diff --git a/.asf.yaml b/.asf.yaml index 5fe94dc04af59..d71e7def36ad1 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -50,6 +50,79 @@ github: main: required_pull_request_reviews: required_approving_review_count: 1 + # needs to be updated as part of the release process + # .asf.yaml doesn't support wildcard branch protection rules, only exact branch names + # https://github.com/apache/infrastructure-asfyaml?tab=readme-ov-file#branch-protection + # Keeping set of protected branches for future releases + # Meanwhile creating a prerelease script that will update the branch protection names + # automatically. Keep track on it https://github.com/apache/datafusion/issues/17134 + branch-50: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-51: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-52: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-53: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-54: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-55: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-56: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-57: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-58: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-59: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-60: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-61: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-62: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-63: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-64: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-65: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-66: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-67: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-68: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-69: + required_pull_request_reviews: + required_approving_review_count: 1 + branch-70: + required_pull_request_reviews: + required_approving_review_count: 1 + pull_requests: + # enable updating head branches of pull requests + allow_update_branch: true + allow_auto_merge: true # publishes the content of the `asf-site` branch to # https://datafusion.apache.org/ diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 9dd627b01abed..49aacd118e19b 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -4,10 +4,12 @@ RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ # Remove imagemagick due to https://security-tracker.debian.org/tracker/CVE-2019-10131 && apt-get purge -y imagemagick imagemagick-6-common -# Add protoc -# https://datafusion.apache.org/contributor-guide/getting_started.html#protoc-installation -RUN curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v25.1/protoc-25.1-linux-x86_64.zip \ - && unzip protoc-25.1-linux-x86_64.zip -d $HOME/.local \ - && rm protoc-25.1-linux-x86_64.zip +# setup the containers WORKDIR so npm install works +# https://stackoverflow.com/questions/57534295/npm-err-tracker-idealtree-already-exists-while-creating-the-docker-image-for +WORKDIR /root -ENV PATH="$PATH:$HOME/.local/bin" \ No newline at end of file +# Add protoc, npm, prettier +# https://datafusion.apache.org/contributor-guide/development_environment.html#protoc-installation +RUN apt-get update \ + && apt-get install -y --no-install-recommends protobuf-compiler libprotobuf-dev npm nodejs\ + && rm -rf /var/lib/apt/lists/* diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index a886cbd74c23a..ac5f082113117 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -1,5 +1,6 @@ name: Bug report description: Create a report to help us improve +type: Bug labels: bug body: - type: textarea diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 2542b28dcae8a..955e59d74d08b 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -1,5 +1,6 @@ name: Feature request description: Suggest an idea for this project +type: Feature labels: enhancement body: - type: textarea diff --git a/.github/actions/setup-macos-aarch64-builder/action.yaml b/.github/actions/setup-macos-aarch64-builder/action.yaml index 288799a284b01..b62370447adea 100644 --- a/.github/actions/setup-macos-aarch64-builder/action.yaml +++ b/.github/actions/setup-macos-aarch64-builder/action.yaml @@ -44,6 +44,8 @@ runs: rustup default stable rustup component add rustfmt - name: Setup rust cache - uses: Swatinem/rust-cache@v2 + uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + with: + save-if: ${{ github.ref_name == 'main' }} - name: Configure rust runtime env uses: ./.github/actions/setup-rust-runtime diff --git a/.github/actions/setup-macos-builder/action.yaml b/.github/actions/setup-macos-builder/action.yaml deleted file mode 100644 index fffdab160b043..0000000000000 --- a/.github/actions/setup-macos-builder/action.yaml +++ /dev/null @@ -1,47 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -name: Prepare Rust Builder for MacOS -description: 'Prepare Rust Build Environment for MacOS' -inputs: - rust-version: - description: 'version of rust to install (e.g. stable)' - required: true - default: 'stable' -runs: - using: "composite" - steps: - - name: Install protobuf compiler - shell: bash - run: | - mkdir -p $HOME/d/protoc - cd $HOME/d/protoc - export PROTO_ZIP="protoc-29.1-osx-x86_64.zip" - curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v29.1/$PROTO_ZIP - unzip $PROTO_ZIP - echo "$HOME/d/protoc/bin" >> $GITHUB_PATH - export PATH=$PATH:$HOME/d/protoc/bin - protoc --version - - name: Setup Rust toolchain - shell: bash - run: | - rustup update stable - rustup toolchain install stable - rustup default stable - rustup component add rustfmt - - name: Configure rust runtime env - uses: ./.github/actions/setup-rust-runtime diff --git a/.github/actions/setup-rust-runtime/action.yaml b/.github/actions/setup-rust-runtime/action.yaml index cd18be9890315..e0341de93b83d 100644 --- a/.github/actions/setup-rust-runtime/action.yaml +++ b/.github/actions/setup-rust-runtime/action.yaml @@ -20,8 +20,6 @@ description: 'Setup Rust Runtime Environment' runs: using: "composite" steps: - - name: Run sccache-cache - uses: mozilla-actions/sccache-action@v0.0.4 - name: Configure runtime env shell: bash # do not produce debug symbols to keep memory usage down @@ -31,8 +29,5 @@ runs: # Set debuginfo=line-tables-only as debuginfo=0 causes immensely slow build # See for more details: https://github.com/rust-lang/rust/issues/119560 run: | - echo "RUSTC_WRAPPER=sccache" >> $GITHUB_ENV - echo "SCCACHE_GHA_ENABLED=true" >> $GITHUB_ENV echo "RUST_BACKTRACE=1" >> $GITHUB_ENV echo "RUSTFLAGS=-C debuginfo=line-tables-only -C incremental=false" >> $GITHUB_ENV - diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 7c2b7e3a5458c..9d1d77d44c378 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -20,7 +20,7 @@ updates: - package-ecosystem: cargo directory: "/" schedule: - interval: daily + interval: weekly target-branch: main labels: [auto-dependencies] ignore: @@ -50,3 +50,8 @@ updates: interval: "daily" open-pull-requests-limit: 10 labels: [auto-dependencies] + - package-ecosystem: "pip" + directory: "/docs" + schedule: + interval: "weekly" + labels: [auto-dependencies] diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml index 0d65b1aa809ff..d2b6e350a5e68 100644 --- a/.github/workflows/audit.yml +++ b/.github/workflows/audit.yml @@ -23,6 +23,8 @@ concurrency: on: push: + branches: + - main paths: - "**/Cargo.toml" - "**/Cargo.lock" @@ -31,13 +33,17 @@ on: paths: - "**/Cargo.toml" - "**/Cargo.lock" + + merge_group: jobs: security_audit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install cargo-audit - run: cargo install cargo-audit + uses: taiki-e/install-action@e7ef886cf8f69c25ecef6bbc2858a42e273496ec # v2.62.28 + with: + tool: cargo-audit - name: Run audit check run: cargo audit diff --git a/.github/workflows/dependencies.yml b/.github/workflows/dependencies.yml index a577725fed4b9..7e736e1a7afbf 100644 --- a/.github/workflows/dependencies.yml +++ b/.github/workflows/dependencies.yml @@ -23,6 +23,8 @@ concurrency: on: push: + branches-ignore: + - 'gh-readonly-queue/**' paths: - "**/Cargo.toml" - "**/Cargo.lock" @@ -30,6 +32,7 @@ on: paths: - "**/Cargo.toml" - "**/Cargo.lock" + merge_group: # manual trigger # https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow workflow_dispatch: @@ -41,7 +44,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: submodules: true fetch-depth: 1 @@ -53,3 +56,14 @@ jobs: run: | cd dev/depcheck cargo run + + detect-unused-dependencies: + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - name: Install cargo-machete + run: cargo install cargo-machete --version ^0.9 --locked + - name: Detect unused dependencies + run: cargo machete --with-metadata \ No newline at end of file diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index aa4bd862e09e4..cc879f66cc936 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -16,7 +16,12 @@ # under the License. name: Dev -on: [push, pull_request] +on: + push: + branches-ignore: + - 'gh-readonly-queue/**' + pull_request: + merge_group: concurrency: group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} @@ -27,15 +32,18 @@ jobs: runs-on: ubuntu-latest name: Check License Header steps: - - uses: actions/checkout@v4 - - uses: korandoru/hawkeye@v6 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - name: Install HawkEye + run: cargo install hawkeye --version 6.2.0 --locked --profile dev + - name: Run license header check + run: ci/scripts/license_header.sh prettier: name: Use prettier to check formatting of documents runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-node@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6.0.0 with: node-version: "20" - name: Prettier check diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 5f1b2c1395982..588bf46aaca70 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -32,16 +32,16 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout docs sources - uses: actions/checkout@v4 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Checkout asf-site branch - uses: actions/checkout@v4 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: ref: asf-site path: asf-site - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.12" diff --git a/.github/workflows/docs_pr.yaml b/.github/workflows/docs_pr.yaml index d3c901c5b71b6..c182f2ef85d23 100644 --- a/.github/workflows/docs_pr.yaml +++ b/.github/workflows/docs_pr.yaml @@ -34,37 +34,18 @@ on: workflow_dispatch: jobs: - # Run doc tests - linux-test-doc: - name: cargo doctest (amd64) - runs-on: ubuntu-latest - container: - image: amd64/rust - steps: - - uses: actions/checkout@v4 - with: - submodules: true - fetch-depth: 1 - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: stable - - name: Run doctests (embedded rust examples) - run: cargo test --doc --features avro,json - - name: Verify Working Directory Clean - run: git diff --exit-code - + # Test doc build linux-test-doc-build: name: Test doc build runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: submodules: true fetch-depth: 1 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.12" - name: Install doc dependencies diff --git a/.github/workflows/extended.yml b/.github/workflows/extended.yml index a5d68ff079b56..9343997e05682 100644 --- a/.github/workflows/extended.yml +++ b/.github/workflows/extended.yml @@ -32,6 +32,10 @@ on: push: branches: - main + # support extended test suite for release candidate branches, + # it is not expected to have many changes in these branches, + # so running extended tests is not a burden + - 'branch-*' workflow_dispatch: inputs: pr_number: @@ -47,7 +51,7 @@ on: permissions: contents: read checks: write - + jobs: # Check crate compiles and base cargo check passes @@ -56,8 +60,9 @@ jobs: runs-on: ubuntu-latest # note: do not use amd/rust container to preserve disk space steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: + ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true fetch-depth: 1 - name: Install Rust @@ -79,12 +84,13 @@ jobs: runs-on: ubuntu-latest # note: do not use amd/rust container to preserve disk space steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: + ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true fetch-depth: 1 - name: Free Disk Space (Ubuntu) - uses: jlumbroso/free-disk-space@54081f138730dfa15788a46383842cd2f914a1be + uses: jlumbroso/free-disk-space@54081f138730dfa15788a46383842cd2f914a1be # v1.3.1 - name: Install Rust run: | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y @@ -99,7 +105,17 @@ jobs: - name: Run tests (excluding doctests) env: RUST_BACKTRACE: 1 - run: cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --workspace --lib --tests --bins --features avro,json,backtrace,extended_tests,recursive_protection + run: | + cargo test \ + --profile ci \ + --exclude datafusion-examples \ + --exclude datafusion-benchmarks \ + --exclude datafusion-cli \ + --workspace \ + --lib \ + --tests \ + --bins \ + --features avro,json,backtrace,extended_tests,recursive_protection - name: Verify Working Directory Clean run: git diff --exit-code - name: Cleanup @@ -112,8 +128,9 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: + ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true fetch-depth: 1 - name: Setup Rust toolchain @@ -123,7 +140,7 @@ jobs: - name: Run tests run: | cd datafusion - cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --exclude datafusion-sqllogictest --workspace --lib --tests --features=force_hash_collisions,avro + cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --exclude datafusion-sqllogictest --exclude datafusion-cli --workspace --lib --tests --features=force_hash_collisions,avro cargo clean sqllogictest-sqlite: @@ -132,8 +149,9 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: + ref: ${{ github.event.inputs.pr_head_sha }} # will be empty if triggered by push submodules: true fetch-depth: 1 - name: Setup Rust toolchain @@ -145,44 +163,6 @@ jobs: cargo test --features backtrace --profile release-nonlto --test sqllogictests -- --include-sqlite cargo clean - # If the workflow was triggered by the PR comment (through pr_comment_commands.yml action) we need to manually update check status to display in UI - update-check-status: - needs: [linux-build-lib, linux-test-extended, hash-collisions, sqllogictest-sqlite] - runs-on: ubuntu-latest - if: ${{ always() && github.event_name == 'workflow_dispatch' }} - steps: - - name: Determine workflow status - id: status - run: | - if [[ "${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') }}" == "true" ]]; then - echo "workflow_status=failure" >> $GITHUB_OUTPUT - echo "conclusion=failure" >> $GITHUB_OUTPUT - else - echo "workflow_status=completed" >> $GITHUB_OUTPUT - echo "conclusion=success" >> $GITHUB_OUTPUT - fi - - - name: Update check run - uses: actions/github-script@v7 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - const workflowRunUrl = `https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}`; - - await github.rest.checks.update({ - owner: context.repo.owner, - repo: context.repo.repo, - check_run_id: ${{ github.event.inputs.check_run_id }}, - status: 'completed', - conclusion: '${{ steps.status.outputs.conclusion }}', - output: { - title: '${{ steps.status.outputs.conclusion == 'success' && 'Extended Tests Passed' || 'Extended Tests Failed' }}', - summary: `Extended tests have completed with status: ${{ steps.status.outputs.conclusion }}.\n\n[View workflow run](${workflowRunUrl})` - }, - details_url: workflowRunUrl - }); - - diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/labeler.yml similarity index 90% rename from .github/workflows/dev_pr.yml rename to .github/workflows/labeler.yml index 11c14c5c2feee..0abf535b9741f 100644 --- a/.github/workflows/dev_pr.yml +++ b/.github/workflows/labeler.yml @@ -39,17 +39,17 @@ jobs: contents: read pull-requests: write steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Assign GitHub labels if: | github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'synchronize') - uses: actions/labeler@v5.0.0 + uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - configuration-path: .github/workflows/dev_pr/labeler.yml + configuration-path: .github/workflows/labeler/labeler-config.yml sync-labels: true # TODO: Enable this when eps1lon/actions-label-merge-conflict is available. diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/labeler/labeler-config.yml similarity index 95% rename from .github/workflows/dev_pr/labeler.yml rename to .github/workflows/labeler/labeler-config.yml index da93e65418551..e408130725215 100644 --- a/.github/workflows/dev_pr/labeler.yml +++ b/.github/workflows/labeler/labeler-config.yml @@ -41,7 +41,7 @@ physical-expr: physical-plan: - changed-files: - - any-glob-to-any-file: [datafusion/physical-plan/**/*'] + - any-glob-to-any-file: ['datafusion/physical-plan/**/*'] catalog: @@ -77,6 +77,10 @@ proto: - changed-files: - any-glob-to-any-file: ['datafusion/proto/**/*', 'datafusion/proto-common/**/*'] +spark: +- changed-files: + - any-glob-to-any-file: ['datafusion/spark/**/*'] + substrait: - changed-files: - any-glob-to-any-file: ['datafusion/substrait/**/*'] diff --git a/.github/workflows/large_files.yml b/.github/workflows/large_files.yml index aa96d55a0d851..9cbfd6030a7f6 100644 --- a/.github/workflows/large_files.yml +++ b/.github/workflows/large_files.yml @@ -23,12 +23,13 @@ concurrency: on: pull_request: + merge_group: jobs: check-files: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: fetch-depth: 0 - name: Check size of new Git objects @@ -38,7 +39,16 @@ jobs: MAX_FILE_SIZE_BYTES: 1048576 shell: bash run: | - git rev-list --objects ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} \ + if [ "${{ github.event_name }}" = "merge_group" ]; then + # For merge queue, compare against the base branch + base_sha="${{ github.event.merge_group.base_sha }}" + head_sha="${{ github.event.merge_group.head_sha }}" + else + # For pull requests + base_sha="${{ github.event.pull_request.base.sha }}" + head_sha="${{ github.event.pull_request.head.sha }}" + fi + git rev-list --objects ${base_sha}..${head_sha} \ > pull-request-objects.txt exit_code=0 while read -r id path; do diff --git a/.github/workflows/pr_comment_commands.yml b/.github/workflows/pr_comment_commands.yml deleted file mode 100644 index a20a5b15965dd..0000000000000 --- a/.github/workflows/pr_comment_commands.yml +++ /dev/null @@ -1,89 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -name: PR commands - -on: - issue_comment: - types: [created] - -permissions: - contents: read - pull-requests: write - actions: write - checks: write - -jobs: - # Starts the extended_tests on a PR branch when someone leaves a `Run extended tests` comment - run_extended_tests: - runs-on: ubuntu-latest - if: ${{ github.event_name == 'issue_comment' && github.event.issue.pull_request && contains(github.event.comment.body, 'Run extended tests') }} - steps: - - name: Dispatch extended tests for a PR branch with comment - uses: actions/github-script@v7 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - // Get PR details to fetch the branch name - const { data: pullRequest } = await github.rest.pulls.get({ - owner: context.repo.owner, - repo: context.repo.repo, - pull_number: context.payload.issue.number - }); - - // Extract the branch name - const branchName = pullRequest.head.ref; - const headSha = pullRequest.head.sha; - const workflowRunsUrl = `https://github.com/${context.repo.owner}/${context.repo.repo}/actions?query=workflow%3A%22Datafusion+extended+tests%22+branch%3A${branchName}`; - - // Create a check run that links to the Actions tab so the run will be visible in GitHub UI - const check = await github.rest.checks.create({ - owner: context.repo.owner, - repo: context.repo.repo, - name: 'Extended Tests', - head_sha: headSha, - status: 'in_progress', - output: { - title: 'Extended Tests Running', - summary: `Extended tests have been triggered for this PR.\n\n[View workflow runs](${workflowRunsUrl})` - }, - details_url: workflowRunsUrl - }); - - // Dispatch the workflow with the PR branch name - await github.rest.actions.createWorkflowDispatch({ - owner: context.repo.owner, - repo: context.repo.repo, - workflow_id: 'extended.yml', - ref: branchName, - inputs: { - pr_number: context.payload.issue.number.toString(), - check_run_id: check.data.id.toString(), - pr_head_sha: headSha - } - }); - - - name: Add reaction to comment - uses: actions/github-script@v7 - with: - script: | - await github.rest.reactions.createForIssueComment({ - owner: context.repo.owner, - repo: context.repo.repo, - comment_id: context.payload.comment.id, - content: 'rocket' - }); \ No newline at end of file diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 1e6cd97acea33..e3e4c881da8c1 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -23,6 +23,8 @@ concurrency: on: push: + branches-ignore: + - 'gh-readonly-queue/**' paths-ignore: - "docs/**" - "**.md" @@ -34,19 +36,12 @@ on: - "**.md" - ".github/ISSUE_TEMPLATE/**" - ".github/pull_request_template.md" + merge_group: # manual trigger # https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow workflow_dispatch: jobs: - # Check license header - license-header-check: - runs-on: ubuntu-latest - name: Check License Header - steps: - - uses: actions/checkout@v4 - - uses: korandoru/hawkeye@v6 - # Check crate compiles and base cargo check passes linux-build-lib: name: linux build test @@ -54,11 +49,16 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + with: + shared-key: "amd-ci-check" # this job uses it's own cache becase check has a separate cache and we need it to be fast as it blocks other jobs + save-if: ${{ github.ref_name == 'main' }} - name: Prepare cargo build run: | # Adding `--locked` here to assert that the `Cargo.lock` file is up to @@ -77,7 +77,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -102,11 +102,16 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + with: + save-if: false # set in linux-test + shared-key: "amd-ci" - name: Check datafusion-substrait (default features) run: cargo check --profile ci --all-targets -p datafusion-substrait # @@ -134,7 +139,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -165,11 +170,16 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: rust-version: stable + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + with: + save-if: false # set in linux-test + shared-key: "amd-ci" - name: Check datafusion (default features) run: cargo check --profile ci --all-targets -p datafusion # @@ -207,10 +217,14 @@ jobs: run: cargo check --profile ci --no-default-features -p datafusion --features=recursive_protection - name: Check datafusion (serde) run: cargo check --profile ci --no-default-features -p datafusion --features=serde + - name: Check datafusion (sql) + run: cargo check --profile ci --no-default-features -p datafusion --features=sql - name: Check datafusion (string_expressions) run: cargo check --profile ci --no-default-features -p datafusion --features=string_expressions - name: Check datafusion (unicode_expressions) run: cargo check --profile ci --no-default-features -p datafusion --features=unicode_expressions + - name: Check parquet encryption (parquet_encryption) + run: cargo check --profile ci --no-default-features -p datafusion --features=parquet_encryption # Check datafusion-functions crate features # @@ -223,7 +237,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -255,15 +269,22 @@ jobs: name: cargo test (amd64) needs: linux-build-lib runs-on: ubuntu-latest + container: + image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: submodules: true fetch-depth: 1 - name: Setup Rust toolchain - run: rustup toolchain install stable - - name: Install Protobuf Compiler - run: sudo apt-get install -y protobuf-compiler + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + with: + save-if: ${{ github.ref_name == 'main' }} + shared-key: "amd-ci" - name: Run tests (excluding doctests and datafusion-cli) env: RUST_BACKTRACE: 1 @@ -278,7 +299,7 @@ jobs: --lib \ --tests \ --bins \ - --features serde,avro,json,backtrace,integration-tests + --features serde,avro,json,backtrace,integration-tests,parquet_encryption - name: Verify Working Directory Clean run: git diff --exit-code @@ -288,24 +309,17 @@ jobs: needs: linux-build-lib runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: submodules: true fetch-depth: 1 - name: Setup Rust toolchain run: rustup toolchain install stable - - name: Setup Minio - S3-compatible storage - run: | - docker run -d --name minio-container \ - -p 9000:9000 \ - -e MINIO_ROOT_USER=TEST-DataFusionLogin -e MINIO_ROOT_PASSWORD=TEST-DataFusionPassword \ - -v $(pwd)/datafusion/core/tests/data:/source quay.io/minio/minio \ - server /data - docker exec minio-container /bin/sh -c "\ - mc ready local - mc alias set localminio http://localhost:9000 TEST-DataFusionLogin TEST-DataFusionPassword && \ - mc mb localminio/data && \ - mc cp -r /source/* localminio/data" + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + with: + save-if: false # set in linux-test + shared-key: "amd-ci" - name: Run tests (excluding doctests) env: RUST_BACKTRACE: 1 @@ -314,12 +328,9 @@ jobs: AWS_SECRET_ACCESS_KEY: TEST-DataFusionPassword TEST_STORAGE_INTEGRATION: 1 AWS_ALLOW_HTTP: true - run: cargo test --profile ci -p datafusion-cli --lib --tests --bins + run: cargo test --features backtrace --profile ci -p datafusion-cli --lib --tests --bins - name: Verify Working Directory Clean run: git diff --exit-code - - name: Minio Output - if: ${{ !cancelled() }} - run: docker logs minio-container linux-test-example: @@ -329,7 +340,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: submodules: true fetch-depth: 1 @@ -337,6 +348,11 @@ jobs: uses: ./.github/actions/setup-builder with: rust-version: stable + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + with: + save-if: ${{ github.ref_name == 'main' }} + shared-key: "amd-ci-linux-test-example" - name: Run examples run: | # test datafusion-sql examples @@ -354,7 +370,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: submodules: true fetch-depth: 1 @@ -375,7 +391,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -384,25 +400,28 @@ jobs: run: ci/scripts/rust_docs.sh linux-wasm-pack: - name: build with wasm-pack - runs-on: ubuntu-latest - container: - image: amd64/rust + name: build and run with wasm-pack + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v4 - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: stable + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + - name: Setup for wasm32 + run: | + rustup target add wasm32-unknown-unknown - name: Install dependencies run: | - apt-get update -qq - apt-get install -y -qq clang - - name: Install wasm-pack - run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh - - name: Build with wasm-pack + sudo apt-get update -qq + sudo apt-get install -y -qq clang + - name: Setup wasm-pack + uses: taiki-e/install-action@e7ef886cf8f69c25ecef6bbc2858a42e273496ec # v2.62.28 + with: + tool: wasm-pack + - name: Run tests with headless mode working-directory: ./datafusion/wasmtest - run: wasm-pack build --dev + run: | + # debuginfo=none because CI tests weren't completing successfully after this upstream PR: + # https://github.com/wasm-bindgen/wasm-bindgen/pull/4635 + RUSTFLAGS='--cfg getrandom_backend="wasm_js" -C debuginfo=none' wasm-pack test --headless --firefox + RUSTFLAGS='--cfg getrandom_backend="wasm_js" -C debuginfo=none' wasm-pack test --headless --chrome --chromedriver $CHROMEWEBDRIVER/chromedriver # verify that the benchmark queries return the correct results verify-benchmark-results: @@ -412,7 +431,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: submodules: true fetch-depth: 1 @@ -459,7 +478,7 @@ jobs: --health-timeout 5s --health-retries 5 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: submodules: true fetch-depth: 1 @@ -476,6 +495,28 @@ jobs: POSTGRES_HOST: postgres POSTGRES_PORT: ${{ job.services.postgres.ports[5432] }} + sqllogictest-substrait: + name: "Run sqllogictest in Substrait round-trip mode" + needs: linux-build-lib + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + submodules: true + fetch-depth: 1 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Run sqllogictest + # TODO: Right now several tests are failing in Substrait round-trip mode, so this + # command cannot be run for all the .slt files. Run it for just one that works (limit.slt) + # until most of the tickets in https://github.com/apache/datafusion/issues/16248 are addressed + # and this command can be run without filters. + run: cargo test --test sqllogictests -- --substrait-round-trip limit.slt + # Temporarily commenting out the Windows flow, the reason is enormously slow running build # Waiting for new Windows 2025 github runner # Details: https://github.com/apache/datafusion/issues/13726 @@ -495,27 +536,11 @@ jobs: # export PATH=$PATH:$HOME/d/protoc/bin # cargo test --lib --tests --bins --features avro,json,backtrace - # Commenting out intel mac build as so few users would ever use it - # Details: https://github.com/apache/datafusion/issues/13846 - # macos: - # name: cargo test (macos) - # runs-on: macos-latest - # steps: - # - uses: actions/checkout@v4 - # with: - # submodules: true - # fetch-depth: 1 - # - name: Setup Rust toolchain - # uses: ./.github/actions/setup-macos-builder - # - name: Run tests (excluding doctests) - # shell: bash - # run: cargo test run --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --workspace --lib --tests --bins --features avro,json,backtrace - macos-aarch64: name: cargo test (macos-aarch64) runs-on: macos-14 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: submodules: true fetch-depth: 1 @@ -532,7 +557,7 @@ jobs: container: image: amd64/rust:bullseye # Use the bullseye tag image which comes with python3.9 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: submodules: true fetch-depth: 1 @@ -547,7 +572,7 @@ jobs: with: rust-version: stable - name: Run datafusion-common tests - run: cargo test --profile ci -p datafusion-common --features=pyarrow + run: cargo test --profile ci -p datafusion-common --features=pyarrow,sql vendor: name: Verify Vendored Code @@ -555,7 +580,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -572,7 +597,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder with: @@ -631,7 +656,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: submodules: true fetch-depth: 1 @@ -641,6 +666,11 @@ jobs: rust-version: stable - name: Install Clippy run: rustup component add clippy + - name: Rust Dependency Cache + uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1 + with: + save-if: ${{ github.ref_name == 'main' }} + shared-key: "amd-ci-clippy" - name: Run clippy run: ci/scripts/rust_clippy.sh @@ -651,7 +681,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: submodules: true fetch-depth: 1 @@ -672,7 +702,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: submodules: true fetch-depth: 1 @@ -680,7 +710,7 @@ jobs: uses: ./.github/actions/setup-builder with: rust-version: stable - - uses: actions/setup-node@v4 + - uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6.0.0 with: node-version: "20" - name: Check if configs.md has been modified @@ -705,11 +735,14 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Install cargo-msrv - run: cargo install cargo-msrv + uses: taiki-e/install-action@e7ef886cf8f69c25ecef6bbc2858a42e273496ec # v2.62.28 + with: + tool: cargo-msrv + - name: Check datafusion working-directory: datafusion/core run: | @@ -719,10 +752,15 @@ jobs: # `rust-version` key of `Cargo.toml`. # # To reproduce: - # 1. Install the version of Rust that is failing. Example: - # rustup install 1.80.1 - # 2. Run the command that failed with that version. Example: - # cargo +1.80.1 check -p datafusion + # 1. Install the version of Rust that is failing. + # 2. Run the command that failed with that version. + # + # Example: + # # MSRV looks like "1.80.0" and is specified in Cargo.toml. We can read the value with the following command: + # msrv="$(cargo metadata --format-version=1 | jq '.packages[] | select( .name == "datafusion" ) | .rust_version' -r)" + # echo "MSRV: ${msrv}" + # rustup install "${msrv}" + # cargo "+${msrv}" check # # To resolve, either: # 1. Change your code to use older Rust features, @@ -739,3 +777,11 @@ jobs: - name: Check datafusion-proto working-directory: datafusion/proto run: cargo msrv --output-format json --log-target stdout verify + typos: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + persist-credentials: false + - uses: crate-ci/typos@80c8a4945eec0f6d464eaf9e65ed98ef085283d1 # v1.38.1 diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 2312526824a91..d5fc9287aa6a5 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -27,7 +27,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v9 + - uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0 with: stale-pr-message: "Thank you for your contribution. Unfortunately, this pull request is stale because it has been open 60 days with no activity. Please remove the stale label or comment or this will be closed in 7 days." days-before-pr-stale: 60 diff --git a/.gitignore b/.gitignore index 4ae32925d908e..8466a72adaec8 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ docker_cache *.orig .*.swp .*.swo +*.pending-snap venv/* diff --git a/Cargo.lock b/Cargo.lock index c708c516ab360..00bd64f21eb11 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "abi_stable" @@ -61,15 +61,9 @@ dependencies = [ [[package]] name = "adler2" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" - -[[package]] -name = "adler32" -version = "1.2.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" [[package]] name = "ahash" @@ -77,23 +71,23 @@ version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "once_cell", "version_check", ] [[package]] name = "ahash" -version = "0.8.11" +version = "0.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ "cfg-if", "const-random", - "getrandom 0.2.15", + "getrandom 0.3.3", "once_cell", "version_check", - "zerocopy 0.7.35", + "zerocopy", ] [[package]] @@ -126,12 +120,6 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" -[[package]] -name = "android-tzdata" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" - [[package]] name = "android_system_properties" version = "0.1.5" @@ -149,9 +137,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstream" -version = "0.6.18" +version = "0.6.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +checksum = "3ae563653d1938f79b1ab1b5e668c87c76a9930414574a6583a7b7e11a8e6192" dependencies = [ "anstyle", "anstyle-parse", @@ -164,69 +152,69 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" +checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" [[package]] name = "anstyle-parse" -version = "0.2.6" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" dependencies = [ "utf8parse", ] [[package]] name = "anstyle-query" -version = "1.1.2" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +checksum = "9e231f6134f61b71076a3eab506c379d4f36122f2af15a9ff04415ea4c3339e2" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] name = "anstyle-wincon" -version = "3.0.7" +version = "3.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" +checksum = "3e0633414522a32ffaac8ac6cc8f748e090c5717661fddeea04219e2344f5f2a" dependencies = [ "anstyle", - "once_cell", - "windows-sys 0.59.0", + "once_cell_polyfill", + "windows-sys 0.60.2", ] [[package]] name = "anyhow" -version = "1.0.95" +version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" [[package]] name = "apache-avro" -version = "0.17.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1aef82843a0ec9f8b19567445ad2421ceeb1d711514384bdd3d49fe37102ee13" +checksum = "3a033b4ced7c585199fb78ef50fca7fe2f444369ec48080c5fd072efa1a03cc7" dependencies = [ "bigdecimal", - "bzip2 0.4.4", + "bon", + "bzip2 0.6.0", "crc32fast", "digest", - "libflate", "log", + "miniz_oxide", "num-bigint", "quad-rand", - "rand 0.8.5", + "rand 0.9.2", "regex-lite", "serde", "serde_bytes", "serde_json", "snap", - "strum 0.26.3", - "strum_macros 0.26.4", - "thiserror 1.0.69", - "typed-builder", + "strum 0.27.2", + "strum_macros 0.27.2", + "thiserror", "uuid", "xz2", "zstd", @@ -246,9 +234,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc208515aa0151028e464cc94a692156e945ce5126abd3537bb7fd6ba2143ed1" +checksum = "6e833808ff2d94ed40d9379848a950d995043c7fb3e81a30b383f4c6033821cc" dependencies = [ "arrow-arith", "arrow-array", @@ -259,20 +247,20 @@ dependencies = [ "arrow-ipc", "arrow-json", "arrow-ord", + "arrow-pyarrow", "arrow-row", "arrow-schema", "arrow-select", "arrow-string", "half", - "pyo3", - "rand 0.8.5", + "rand 0.9.2", ] [[package]] name = "arrow-arith" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e07e726e2b3f7816a85c6a45b6ec118eeeabf0b2a8c208122ad949437181f49a" +checksum = "ad08897b81588f60ba983e3ca39bda2b179bdd84dced378e7df81a5313802ef8" dependencies = [ "arrow-array", "arrow-buffer", @@ -284,26 +272,26 @@ dependencies = [ [[package]] name = "arrow-array" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2262eba4f16c78496adfd559a29fe4b24df6088efc9985a873d58e92be022d5" +checksum = "8548ca7c070d8db9ce7aa43f37393e4bfcf3f2d3681df278490772fd1673d08d" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", "chrono-tz", "half", - "hashbrown 0.15.2", + "hashbrown 0.16.0", "num", ] [[package]] name = "arrow-buffer" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e899dade2c3b7f5642eb8366cfd898958bcca099cde6dfea543c7e8d3ad88d4" +checksum = "e003216336f70446457e280807a73899dd822feaf02087d31febca1363e2fccc" dependencies = [ "bytes", "half", @@ -312,9 +300,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4103d88c5b441525ed4ac23153be7458494c2b0c9a11115848fdb9b81f6f886a" +checksum = "919418a0681298d3a77d1a315f625916cb5678ad0d74b9c60108eb15fd083023" dependencies = [ "arrow-array", "arrow-buffer", @@ -333,9 +321,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d3cb0914486a3cae19a5cad2598e44e225d53157926d0ada03c20521191a65" +checksum = "bfa9bf02705b5cf762b6f764c65f04ae9082c7cfc4e96e0c33548ee3f67012eb" dependencies = [ "arrow-array", "arrow-cast", @@ -343,15 +331,14 @@ dependencies = [ "chrono", "csv", "csv-core", - "lazy_static", "regex", ] [[package]] name = "arrow-data" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a329fb064477c9ec5f0870d2f5130966f91055c7c5bce2b3a084f116bc28c3b" +checksum = "a5c64fff1d142f833d78897a772f2e5b55b36cb3e6320376f0961ab0db7bd6d0" dependencies = [ "arrow-buffer", "arrow-schema", @@ -361,9 +348,9 @@ dependencies = [ [[package]] name = "arrow-flight" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7408f2bf3b978eddda272c7699f439760ebc4ac70feca25fefa82c5b8ce808d" +checksum = "8c8b0ba0784d56bc6266b79f5de7a24b47024e7b3a0045d2ad4df3d9b686099f" dependencies = [ "arrow-arith", "arrow-array", @@ -381,30 +368,32 @@ dependencies = [ "futures", "once_cell", "paste", - "prost", - "prost-types", + "prost 0.13.5", + "prost-types 0.13.5", "tonic", ] [[package]] name = "arrow-ipc" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddecdeab02491b1ce88885986e25002a3da34dd349f682c7cfe67bab7cc17b86" +checksum = "1d3594dcddccc7f20fd069bc8e9828ce37220372680ff638c5e00dea427d88f5" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", + "arrow-select", "flatbuffers", "lz4_flex", + "zstd", ] [[package]] name = "arrow-json" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d03b9340013413eb84868682ace00a1098c81a5ebc96d279f7ebf9a4cac3c0fd" +checksum = "88cf36502b64a127dc659e3b305f1d993a544eab0d48cce704424e62074dc04b" dependencies = [ "arrow-array", "arrow-buffer", @@ -413,18 +402,20 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.8.0", + "indexmap 2.11.4", "lexical-core", + "memchr", "num", "serde", "serde_json", + "simdutf8", ] [[package]] name = "arrow-ord" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f841bfcc1997ef6ac48ee0305c4dfceb1f7c786fe31e67c1186edf775e1f1160" +checksum = "3c8f82583eb4f8d84d4ee55fd1cb306720cddead7596edce95b50ee418edf66f" dependencies = [ "arrow-array", "arrow-buffer", @@ -433,11 +424,23 @@ dependencies = [ "arrow-select", ] +[[package]] +name = "arrow-pyarrow" +version = "56.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d924b32e96f8bb74d94cd82bd97b313c432fcb0ea331689ef9e7c6b8be4b258" +dependencies = [ + "arrow-array", + "arrow-data", + "arrow-schema", + "pyo3", +] + [[package]] name = "arrow-row" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1eeb55b0a0a83851aa01f2ca5ee5648f607e8506ba6802577afdda9d75cdedcd" +checksum = "9d07ba24522229d9085031df6b94605e0f4b26e099fb7cdeec37abd941a73753" dependencies = [ "arrow-array", "arrow-buffer", @@ -448,21 +451,22 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85934a9d0261e0fa5d4e2a5295107d743b543a6e0484a835d4b8db2da15306f9" +checksum = "b3aa9e59c611ebc291c28582077ef25c97f1975383f1479b12f3b9ffee2ffabe" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.4", "serde", + "serde_json", ] [[package]] name = "arrow-select" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e2932aece2d0c869dd2125feb9bd1709ef5c445daa3838ac4112dcfa0fda52c" +checksum = "8c41dbbd1e97bfcaee4fcb30e29105fb2c75e4d82ae4de70b792a5d3f66b2e7a" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow-array", "arrow-buffer", "arrow-data", @@ -472,9 +476,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "912e38bd6a7a7714c1d9b61df80315685553b7455e8a6045c27531d8ecd5b458" +checksum = "53f5183c150fbc619eede22b861ea7c0eebed8eaac0333eaa7f6da5205fd504d" dependencies = [ "arrow-array", "arrow-buffer", @@ -499,22 +503,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "assert_cmd" -version = "2.0.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc1835b7f27878de8525dc71410b5a31cdcc5f230aed5ba5df968e09c201b23d" -dependencies = [ - "anstyle", - "bstr", - "doc-comment", - "libc", - "predicates", - "predicates-core", - "predicates-tree", - "wait-timeout", -] - [[package]] name = "async-compression" version = "0.4.19" @@ -549,40 +537,18 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", -] - -[[package]] -name = "async-stream" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" -dependencies = [ - "async-stream-impl", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "async-trait" -version = "0.1.88" +version = "0.1.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -602,15 +568,15 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-config" -version = "1.6.1" +version = "1.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c39646d1a6b51240a1a23bb57ea4eebede7e16fbc237fdc876980233dcecb4f" +checksum = "04b37ddf8d2e9744a0b9c19ce0b78efe4795339a90b66b7bae77987092cd2e69" dependencies = [ "aws-credential-types", "aws-runtime", @@ -627,7 +593,7 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 1.2.0", + "http 1.3.1", "ring", "time", "tokio", @@ -638,9 +604,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.2.2" +version = "1.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4471bef4c22a06d2c7a1b6492493d3fdf24a805323109d6874f9c94d5906ac14" +checksum = "799a1290207254984cb7c05245111bc77958b92a3c9bb449598044b36341cce6" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -650,9 +616,9 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.12.6" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dabb68eb3a7aa08b46fddfd59a3d55c978243557a90ab804769f7e20e67d2b01" +checksum = "94b8ff6c09cd57b16da53641caa860168b88c172a5ee163b0288d3d6eea12786" dependencies = [ "aws-lc-sys", "zeroize", @@ -660,9 +626,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.27.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bbe221bbf523b625a4dd8585c7f38166e31167ec2ca98051dbcb4c3b6e825d2" +checksum = "0e44d16778acaf6a9ec9899b92cebd65580b83f685446bf2e1f5d3d732f99dcd" dependencies = [ "bindgen", "cc", @@ -673,9 +639,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.5.6" +version = "1.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0aff45ffe35196e593ea3b9dd65b320e51e2dda95aff4390bc459e461d09c6ad" +checksum = "2e1ed337dabcf765ad5f2fb426f13af22d576328aaf09eac8f70953530798ec0" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -689,7 +655,6 @@ dependencies = [ "fastrand", "http 0.2.12", "http-body 0.4.6", - "once_cell", "percent-encoding", "pin-project-lite", "tracing", @@ -698,9 +663,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.63.0" +version = "1.85.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1cb45b83b53b5cd55ee33fd9fd8a70750255a3f286e4dca20e882052f2b256f" +checksum = "2f2c741e2e439f07b5d1b33155e246742353d82167c785a2ff547275b7e32483" dependencies = [ "aws-credential-types", "aws-runtime", @@ -714,16 +679,15 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-ssooidc" -version = "1.64.0" +version = "1.87.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8d4d9bc075ea6238778ed3951b65d3cde8c3864282d64fdcd19f2a90c0609f1" +checksum = "6428ae5686b18c0ee99f6f3c39d94ae3f8b42894cdc35c35d8fb2470e9db2d4c" dependencies = [ "aws-credential-types", "aws-runtime", @@ -737,16 +701,15 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-sts" -version = "1.64.0" +version = "1.87.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "819ccba087f403890fee4825eeab460e64c59345667d2b83a12cf544b581e3a7" +checksum = "5871bec9a79a3e8d928c7788d654f135dde0e71d2dd98089388bab36b37ef607" dependencies = [ "aws-credential-types", "aws-runtime", @@ -761,16 +724,15 @@ dependencies = [ "aws-types", "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sigv4" -version = "1.3.0" +version = "1.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d03c3c05ff80d54ff860fe38c726f6f494c639ae975203a101335f223386db" +checksum = "084c34162187d39e3740cb635acd73c4e3a551a36146ad6fe8883c929c9f876c" dependencies = [ "aws-credential-types", "aws-smithy-http", @@ -781,8 +743,7 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.2.0", - "once_cell", + "http 1.3.1", "percent-encoding", "sha2", "time", @@ -802,9 +763,9 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.62.0" +version = "0.62.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5949124d11e538ca21142d1fba61ab0a2a2c1bc3ed323cdb3e4b878bfb83166" +checksum = "7c4dacf2d38996cf729f55e7a762b30918229917eca115de45dfa8dfb97796c9" dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", @@ -812,9 +773,8 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "http-body 0.4.6", - "once_cell", "percent-encoding", "pin-project-lite", "pin-utils", @@ -823,15 +783,15 @@ dependencies = [ [[package]] name = "aws-smithy-http-client" -version = "1.0.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0497ef5d53065b7cd6a35e9c1654bd1fefeae5c52900d91d1b188b0af0f29324" +checksum = "147e8eea63a40315d704b97bf9bc9b8c1402ae94f89d5ad6f7550d963309da1b" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", "aws-smithy-types", "h2", - "http 1.2.0", + "http 1.3.1", "hyper", "hyper-rustls", "hyper-util", @@ -840,27 +800,27 @@ dependencies = [ "rustls-native-certs", "rustls-pki-types", "tokio", - "tower 0.5.2", + "tokio-rustls", + "tower", "tracing", ] [[package]] name = "aws-smithy-json" -version = "0.61.3" +version = "0.61.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92144e45819cae7dc62af23eac5a038a58aa544432d2102609654376a900bd07" +checksum = "eaa31b350998e703e9826b2104dd6f63be0508666e1aba88137af060e8944047" dependencies = [ "aws-smithy-types", ] [[package]] name = "aws-smithy-observability" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445d065e76bc1ef54963db400319f1dd3ebb3e0a74af20f7f7630625b0cc7cc0" +checksum = "9364d5989ac4dd918e5cc4c4bdcc61c9be17dcd2586ea7f69e348fc7c6cab393" dependencies = [ "aws-smithy-runtime-api", - "once_cell", ] [[package]] @@ -875,9 +835,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.8.1" +version = "1.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0152749e17ce4d1b47c7747bdfec09dac1ccafdcbc741ebf9daa2a373356730f" +checksum = "4fa63ad37685ceb7762fa4d73d06f1d5493feb88e3f27259b9ed277f4c01b185" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -888,10 +848,9 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", - "once_cell", "pin-project-lite", "pin-utils", "tokio", @@ -900,15 +859,15 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.7.4" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3da37cf5d57011cb1753456518ec76e31691f1f474b73934a284eb2a1c76510f" +checksum = "07f5e0fc8a6b3f2303f331b94504bbf754d85488f402d6f1dd7a6080f99afe56" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "pin-project-lite", "tokio", "tracing", @@ -917,15 +876,15 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.3.0" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836155caafba616c0ff9b07944324785de2ab016141c3550bd1c07882f8cee8f" +checksum = "d498595448e43de7f4296b7b7a18a8a02c61ec9349128c80a368f7c3b4ab11a8" dependencies = [ "base64-simd", "bytes", "bytes-utils", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -940,18 +899,18 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.9" +version = "0.60.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab0b0166827aa700d3dc519f72f8b3a91c35d0b8d042dc5d643a91e6f80648fc" +checksum = "3db87b96cb1b16c024980f133968d52882ca0daaee3a086c6decc500f6c99728" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.3.6" +version = "1.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3873f8deed8927ce8d04487630dc9ff73193bab64742a61d050e57a68dec4125" +checksum = "b069d19bf01e46298eaedd7c6f283fe565a59263e53eebec945f3e6398f42390" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -963,15 +922,14 @@ dependencies = [ [[package]] name = "axum" -version = "0.7.9" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" dependencies = [ - "async-trait", "axum-core", "bytes", "futures-util", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "http-body-util", "itoa", @@ -983,21 +941,20 @@ dependencies = [ "rustversion", "serde", "sync_wrapper", - "tower 0.5.2", + "tower", "tower-layer", "tower-service", ] [[package]] name = "axum-core" -version = "0.4.5" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" dependencies = [ - "async-trait", "bytes", - "futures-util", - "http 1.2.0", + "futures-core", + "http 1.3.1", "http-body 1.0.1", "http-body-util", "mime", @@ -1010,9 +967,9 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.74" +version = "0.3.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +checksum = "6806a6321ec58106fea15becdad98371e28d92ccbc7c8f1b3b6dd724fe8f1002" dependencies = [ "addr2line", "cfg-if", @@ -1047,9 +1004,9 @@ dependencies = [ [[package]] name = "bigdecimal" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f31f3af01c5c65a07985c804d3366560e6fa7883d640a122819b14ec327482c" +checksum = "1a22f228ab7a1b23027ccc6c350b72868017af7ea8356fbdf19f8d991c690013" dependencies = [ "autocfg", "libm", @@ -1061,25 +1018,22 @@ dependencies = [ [[package]] name = "bindgen" -version = "0.69.5" +version = "0.72.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.4", "cexpr", "clang-sys", - "itertools 0.10.5", - "lazy_static", - "lazycell", + "itertools 0.13.0", "log", "prettyplease", "proc-macro2", "quote", "regex", - "rustc-hash 1.1.0", + "rustc-hash", "shlex", - "syn 2.0.100", - "which", + "syn 2.0.106", ] [[package]] @@ -1090,9 +1044,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.8.0" +version = "2.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +checksum = "2261d10cca569e4643e526d8dc2e62e433cc8aba21ab764233731f8d369bf394" [[package]] name = "bitvec" @@ -1117,9 +1071,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.8.0" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34a796731680be7931955498a16a10b2270c7762963d5d570fdbfe02dcbf314f" +checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" dependencies = [ "arrayref", "arrayvec", @@ -1150,7 +1104,7 @@ dependencies = [ "futures-util", "hex", "home", - "http 1.2.0", + "http 1.3.1", "http-body-util", "hyper", "hyper-named-pipe", @@ -1168,7 +1122,7 @@ dependencies = [ "serde_json", "serde_repr", "serde_urlencoded", - "thiserror 2.0.12", + "thiserror", "tokio", "tokio-util", "tower-service", @@ -1187,11 +1141,36 @@ dependencies = [ "serde_with", ] +[[package]] +name = "bon" +version = "3.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2529c31017402be841eb45892278a6c21a000c0a17643af326c73a73f83f0fb" +dependencies = [ + "bon-macros", + "rustversion", +] + +[[package]] +name = "bon-macros" +version = "3.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82020dadcb845a345591863adb65d74fa8dc5c18a0b6d408470e13b7adc7005" +dependencies = [ + "darling", + "ident_case", + "prettyplease", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.106", +] + [[package]] name = "borsh" -version = "1.5.5" +version = "1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5430e3be710b68d984d1391c854eb431a9d548640711faa54eecb1df93db91cc" +checksum = "ad8646f98db542e39fc66e68a20b2144f6a732636df7c2354e74645faaa433ce" dependencies = [ "borsh-derive", "cfg_aliases", @@ -1199,22 +1178,22 @@ dependencies = [ [[package]] name = "borsh-derive" -version = "1.5.5" +version = "1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8b668d39970baad5356d7c83a86fee3a539e6f93bf6764c97368243e17a0487" +checksum = "fdd1d3c0c2f5833f22386f252fe8ed005c7f59fdcddeef025c01b4c3b9fd9ac3" dependencies = [ "once_cell", "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "brotli" -version = "7.0.0" +version = "8.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" +checksum = "4bd8b9603c7aa97359dbd97ecf258968c95f3adddd6db2f7e7a5bef101c84560" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -1223,9 +1202,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "4.0.2" +version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74fa05ad7d803d413eb8380983b092cbbaf9a85f151b871360e7b00cd7060b37" +checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -1233,20 +1212,19 @@ dependencies = [ [[package]] name = "bstr" -version = "1.11.3" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0" +checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" dependencies = [ "memchr", - "regex-automata", "serde", ] [[package]] name = "bumpalo" -version = "3.17.0" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "bytecheck" @@ -1294,21 +1272,20 @@ dependencies = [ [[package]] name = "bzip2" -version = "0.4.4" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +checksum = "49ecfb22d906f800d4fe833b6282cf4dc1c298f5057ca0b5445e5c209735ca47" dependencies = [ "bzip2-sys", - "libc", ] [[package]] name = "bzip2" -version = "0.5.2" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49ecfb22d906f800d4fe833b6282cf4dc1c298f5057ca0b5445e5c209735ca47" +checksum = "bea8dcd42434048e4f7a304411d9273a411f647446c1234a65ce0554923f4cff" dependencies = [ - "bzip2-sys", + "libbz2-rs-sys", ] [[package]] @@ -1329,10 +1306,11 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.14" +version = "1.2.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" +checksum = "80f41ae168f955c12fb8960b057d70d0ca153fb83182b57d86380443527be7e9" dependencies = [ + "find-msvc-tools", "jobserver", "libc", "shlex", @@ -1349,9 +1327,9 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" [[package]] name = "cfg_aliases" @@ -1361,38 +1339,26 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.39" +version = "0.4.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" +checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" dependencies = [ - "android-tzdata", "iana-time-zone", "js-sys", "num-traits", "serde", "wasm-bindgen", - "windows-targets 0.52.6", + "windows-link 0.2.0", ] [[package]] name = "chrono-tz" -version = "0.10.3" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efdce149c370f133a071ca8ef6ea340b7b88748ab0810097a9e2976eaa34b4f3" +checksum = "a6139a8597ed92cf816dfb33f5dd6cf0bb93a6adc938f11039f371bc5bcd26c3" dependencies = [ "chrono", - "chrono-tz-build", - "phf", -] - -[[package]] -name = "chrono-tz-build" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94fea34d77a245229e7746bd2beb786cd2a896f306ff491fb8cecb3074b10a7" -dependencies = [ - "parse-zoneinfo", - "phf_codegen", + "phf 0.12.1", ] [[package]] @@ -1430,7 +1396,7 @@ checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" dependencies = [ "glob", "libc", - "libloading 0.8.6", + "libloading 0.8.9", ] [[package]] @@ -1446,9 +1412,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.34" +version = "4.5.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e958897981290da2a852763fe9cdb89cd36977a5d729023127095fa94d95e2ff" +checksum = "e2134bb3ea021b78629caa971416385309e0131b351b25e01dc16fb54e1b5fae" dependencies = [ "clap_builder", "clap_derive", @@ -1456,9 +1422,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.34" +version = "4.5.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83b0f35019843db2160b5bb19ae09b4e6411ac33fc6a712003c33e03090e2489" +checksum = "c2ba64afa3c0a6df7fa517765e31314e983f51dda798ffba27b988194fb65dc9" dependencies = [ "anstream", "anstyle", @@ -1468,27 +1434,27 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.32" +version = "4.5.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" +checksum = "bbfd7eae0b0f1a6e63d4b13c9c478de77c2eb546fba158ad50b4203dc24b9f9c" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "clap_lex" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" +checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" [[package]] name = "clipboard-win" -version = "5.4.0" +version = "5.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15efe7a882b08f34e38556b14f2fb3daa98769d06c7f0c1b076dfd0d983bc892" +checksum = "bde03770d3df201d4fb868f2c9c59e66a3e4e2bd06692a0fe701e7103c7e84d4" dependencies = [ "error-code", ] @@ -1504,33 +1470,46 @@ dependencies = [ [[package]] name = "colorchoice" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" [[package]] name = "comfy-table" -version = "7.1.4" +version = "7.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a65ebfec4fb190b6f90e944a817d60499ee0744e582530e2c9900a22e591d9a" +checksum = "e0d05af1e006a2407bedef5af410552494ce5be9090444dbbcb57258c1af3d56" dependencies = [ - "unicode-segmentation", - "unicode-width 0.2.0", + "strum 0.26.3", + "strum_macros 0.26.4", + "unicode-width 0.2.1", ] [[package]] name = "console" -version = "0.15.10" +version = "0.15.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" dependencies = [ "encode_unicode", "libc", "once_cell", - "unicode-width 0.2.0", "windows-sys 0.59.0", ] +[[package]] +name = "console" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b430743a6eb14e9764d4260d4c0d8123087d504eeb9c48f2b2a5e810dd369df4" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width 0.2.1", + "windows-sys 0.61.0", +] + [[package]] name = "console_error_panic_hook" version = "0.1.7" @@ -1556,16 +1535,19 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "once_cell", "tiny-keccak", ] [[package]] name = "const_panic" -version = "0.2.12" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2459fc9262a1aa204eb4b5764ad4f189caec88aea9634389c0a25f8be7f6265e" +checksum = "e262cdaac42494e3ae34c43969f9cdeb7da178bdb4b66fa6a1ea2edb4c8ae652" +dependencies = [ + "typewit", +] [[package]] name = "constant_time_eq" @@ -1575,9 +1557,9 @@ checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" [[package]] name = "core-foundation" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" dependencies = [ "core-foundation-sys", "libc", @@ -1589,29 +1571,20 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" -[[package]] -name = "core2" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b49ba7ef1ad6107f8824dbe97de947cbaac53c44e7f9756a1fba0d37c1eec505" -dependencies = [ - "memchr", -] - [[package]] name = "core_extensions" -version = "1.5.3" +version = "1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92c71dc07c9721607e7a16108336048ee978c3a8b129294534272e8bac96c0ee" +checksum = "42bb5e5d0269fd4f739ea6cedaf29c16d81c27a7ce7582008e90eb50dcd57003" dependencies = [ "core_extensions_proc_macros", ] [[package]] name = "core_extensions_proc_macros" -version = "1.5.3" +version = "1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69f3b219d28b6e3b4ac87bc1fc522e0803ab22e055da177bff0068c4150c61a6" +checksum = "533d38ecd2709b7608fb8e18e4504deb99e9a72879e6aa66373a76d8dc4259ea" [[package]] name = "cpufeatures" @@ -1624,9 +1597,9 @@ dependencies = [ [[package]] name = "crc32fast" -version = "1.4.2" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" dependencies = [ "cfg-if", ] @@ -1640,7 +1613,7 @@ dependencies = [ "anes", "cast", "ciborium", - "clap 4.5.34", + "clap 4.5.48", "criterion-plot", "futures", "is-terminal", @@ -1671,9 +1644,9 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.14" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ba6d68e24814cb8de6bb986db8222d3a027d15872cabc0d18817bc3c0e4471" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" dependencies = [ "crossbeam-utils", ] @@ -1705,9 +1678,9 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" [[package]] name = "crypto-common" @@ -1742,19 +1715,31 @@ dependencies = [ [[package]] name = "ctor" -version = "0.2.9" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a2785755761f3ddc1492979ce1e48d2c00d09311c39e4466429188f3dd6501" +checksum = "ec09e802f5081de6157da9a75701d6c713d8dc3ba52571fd4bd25f412644e8a6" dependencies = [ - "quote", - "syn 2.0.100", + "ctor-proc-macro", + "dtor", ] +[[package]] +name = "ctor-proc-macro" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2931af7e13dc045d8e9d26afccc6fa115d64e115c9c84b1166288b46f6782c2" + +[[package]] +name = "cty" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b365fabc795046672053e29c954733ec3b05e4be654ab130fe8f1f94d7051f35" + [[package]] name = "darling" -version = "0.20.10" +version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" dependencies = [ "darling_core", "darling_macro", @@ -1762,35 +1747,29 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.20.10" +version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", "strsim", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "darling_macro" -version = "0.20.10" +version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" dependencies = [ "darling_core", "quote", - "syn 2.0.100", + "syn 2.0.106", ] -[[package]] -name = "dary_heap" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" - [[package]] name = "dashmap" version = "6.1.0" @@ -1807,14 +1786,14 @@ dependencies = [ [[package]] name = "datafusion" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "arrow-ipc", "arrow-schema", "async-trait", "bytes", - "bzip2 0.5.2", + "bzip2 0.6.0", "chrono", "criterion", "ctor", @@ -1841,6 +1820,7 @@ dependencies = [ "datafusion-macros", "datafusion-optimizer", "datafusion-physical-expr", + "datafusion-physical-expr-adapter", "datafusion-physical-expr-common", "datafusion-physical-optimizer", "datafusion-physical-plan", @@ -1850,6 +1830,7 @@ dependencies = [ "env_logger", "flate2", "futures", + "glob", "insta", "itertools 0.14.0", "log", @@ -1858,7 +1839,7 @@ dependencies = [ "parking_lot", "parquet", "paste", - "rand 0.8.5", + "rand 0.9.2", "rand_distr", "regex", "rstest", @@ -1877,7 +1858,7 @@ dependencies = [ [[package]] name = "datafusion-benchmarks" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "datafusion", @@ -1885,23 +1866,24 @@ dependencies = [ "datafusion-proto", "env_logger", "futures", + "libmimalloc-sys", "log", "mimalloc", "object_store", "parquet", - "rand 0.8.5", + "rand 0.9.2", + "regex", "serde", "serde_json", "snmalloc-rs", "structopt", - "test-utils", "tokio", "tokio-util", ] [[package]] name = "datafusion-catalog" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "async-trait", @@ -1914,7 +1896,6 @@ dependencies = [ "datafusion-physical-expr", "datafusion-physical-plan", "datafusion-session", - "datafusion-sql", "futures", "itertools 0.14.0", "log", @@ -1925,7 +1906,7 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "async-trait", @@ -1937,24 +1918,22 @@ dependencies = [ "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", - "datafusion-session", "futures", "log", "object_store", - "tempfile", "tokio", ] [[package]] name = "datafusion-cli" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", - "assert_cmd", "async-trait", "aws-config", "aws-credential-types", - "clap 4.5.34", + "chrono", + "clap 4.5.48", "ctor", "datafusion", "dirs", @@ -1962,31 +1941,33 @@ dependencies = [ "futures", "insta", "insta-cmd", + "log", "mimalloc", "object_store", "parking_lot", "parquet", - "predicates", "regex", "rstest", "rustyline", + "testcontainers", + "testcontainers-modules", "tokio", "url", ] [[package]] name = "datafusion-common" -version = "46.0.1" +version = "50.2.0" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "apache-avro", "arrow", "arrow-ipc", - "base64 0.22.1", "chrono", "half", "hashbrown 0.14.5", - "indexmap 2.8.0", + "hex", + "indexmap 2.11.4", "insta", "libc", "log", @@ -1994,7 +1975,7 @@ dependencies = [ "parquet", "paste", "pyo3", - "rand 0.8.5", + "rand 0.9.2", "recursive", "sqlparser", "tokio", @@ -2003,7 +1984,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "46.0.1" +version = "50.2.0" dependencies = [ "futures", "log", @@ -2012,19 +1993,21 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "async-compression", "async-trait", "bytes", - "bzip2 0.5.2", + "bzip2 0.6.0", "chrono", + "criterion", "datafusion-common", "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", "datafusion-physical-expr", + "datafusion-physical-expr-adapter", "datafusion-physical-expr-common", "datafusion-physical-plan", "datafusion-session", @@ -2034,8 +2017,7 @@ dependencies = [ "itertools 0.14.0", "log", "object_store", - "parquet", - "rand 0.8.5", + "rand 0.9.2", "tempfile", "tokio", "tokio-util", @@ -2046,43 +2028,35 @@ dependencies = [ [[package]] name = "datafusion-datasource-avro" -version = "46.0.1" +version = "50.2.0" dependencies = [ "apache-avro", "arrow", "async-trait", "bytes", - "chrono", - "datafusion-catalog", "datafusion-common", "datafusion-datasource", - "datafusion-execution", - "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", "datafusion-session", "futures", "num-traits", "object_store", - "rstest", "serde_json", - "tokio", ] [[package]] name = "datafusion-datasource-csv" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "async-trait", "bytes", - "datafusion-catalog", "datafusion-common", "datafusion-common-runtime", "datafusion-datasource", "datafusion-execution", "datafusion-expr", - "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", "datafusion-session", @@ -2094,46 +2068,43 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "async-trait", "bytes", - "datafusion-catalog", "datafusion-common", "datafusion-common-runtime", "datafusion-datasource", "datafusion-execution", "datafusion-expr", - "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", "datafusion-session", "futures", "object_store", - "serde_json", "tokio", ] [[package]] name = "datafusion-datasource-parquet" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "async-trait", "bytes", "chrono", - "datafusion-catalog", "datafusion-common", "datafusion-common-runtime", "datafusion-datasource", "datafusion-execution", "datafusion-expr", - "datafusion-functions-aggregate", + "datafusion-functions-aggregate-common", "datafusion-physical-expr", + "datafusion-physical-expr-adapter", "datafusion-physical-expr-common", - "datafusion-physical-optimizer", "datafusion-physical-plan", + "datafusion-pruning", "datafusion-session", "futures", "itertools 0.14.0", @@ -2141,25 +2112,27 @@ dependencies = [ "object_store", "parking_lot", "parquet", - "rand 0.8.5", "tokio", ] [[package]] name = "datafusion-doc" -version = "46.0.1" +version = "50.2.0" [[package]] name = "datafusion-examples" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "arrow-flight", "arrow-schema", "async-trait", + "base64 0.22.1", "bytes", "dashmap", "datafusion", + "datafusion-ffi", + "datafusion-physical-expr-adapter", "datafusion-proto", "env_logger", "futures", @@ -2167,7 +2140,9 @@ dependencies = [ "mimalloc", "nix", "object_store", - "prost", + "prost 0.13.5", + "rand 0.9.2", + "serde_json", "tempfile", "test-utils", "tokio", @@ -2180,27 +2155,31 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", + "async-trait", "chrono", "dashmap", "datafusion-common", "datafusion-expr", "futures", + "insta", "log", "object_store", "parking_lot", - "rand 0.8.5", + "parquet", + "rand 0.9.2", "tempfile", "url", ] [[package]] name = "datafusion-expr" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", + "async-trait", "chrono", "ctor", "datafusion-common", @@ -2210,7 +2189,9 @@ dependencies = [ "datafusion-functions-window-common", "datafusion-physical-expr-common", "env_logger", - "indexmap 2.8.0", + "indexmap 2.11.4", + "insta", + "itertools 0.14.0", "paste", "recursive", "serde_json", @@ -2219,36 +2200,40 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "datafusion-common", - "indexmap 2.8.0", + "indexmap 2.11.4", "itertools 0.14.0", "paste", ] [[package]] name = "datafusion-ffi" -version = "46.0.1" +version = "50.2.0" dependencies = [ "abi_stable", "arrow", + "arrow-schema", "async-ffi", "async-trait", "datafusion", + "datafusion-common", + "datafusion-functions-aggregate-common", "datafusion-proto", + "datafusion-proto-common", "doc-comment", "futures", "log", - "prost", + "prost 0.13.5", "semver", "tokio", ] [[package]] name = "datafusion-functions" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "arrow-buffer", @@ -2257,17 +2242,19 @@ dependencies = [ "blake3", "chrono", "criterion", + "ctor", "datafusion-common", "datafusion-doc", "datafusion-execution", "datafusion-expr", "datafusion-expr-common", "datafusion-macros", + "env_logger", "hex", "itertools 0.14.0", "log", "md-5", - "rand 0.8.5", + "rand 0.9.2", "regex", "sha2", "tokio", @@ -2277,9 +2264,9 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "46.0.1" +version = "50.2.0" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "criterion", "datafusion-common", @@ -2293,25 +2280,25 @@ dependencies = [ "half", "log", "paste", - "rand 0.8.5", + "rand 0.9.2", ] [[package]] name = "datafusion-functions-aggregate-common" -version = "46.0.1" +version = "50.2.0" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "criterion", "datafusion-common", "datafusion-expr-common", "datafusion-physical-expr-common", - "rand 0.8.5", + "rand 0.9.2", ] [[package]] name = "datafusion-functions-nested" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "arrow-ord", @@ -2320,19 +2307,21 @@ dependencies = [ "datafusion-doc", "datafusion-execution", "datafusion-expr", + "datafusion-expr-common", "datafusion-functions", "datafusion-functions-aggregate", + "datafusion-functions-aggregate-common", "datafusion-macros", "datafusion-physical-expr-common", "itertools 0.14.0", "log", "paste", - "rand 0.8.5", + "rand 0.9.2", ] [[package]] name = "datafusion-functions-table" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "async-trait", @@ -2346,7 +2335,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "datafusion-common", @@ -2362,7 +2351,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "46.0.1" +version = "50.2.0" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -2370,30 +2359,32 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "46.0.1" +version = "50.2.0" dependencies = [ - "datafusion-expr", + "datafusion-doc", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "datafusion-optimizer" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "async-trait", "chrono", + "criterion", "ctor", "datafusion-common", "datafusion-expr", + "datafusion-expr-common", "datafusion-functions-aggregate", "datafusion-functions-window", "datafusion-functions-window-common", "datafusion-physical-expr", "datafusion-sql", "env_logger", - "indexmap 2.8.0", + "indexmap 2.11.4", "insta", "itertools 0.14.0", "log", @@ -2404,9 +2395,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "46.0.1" +version = "50.2.0" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "criterion", "datafusion-common", @@ -2417,20 +2408,34 @@ dependencies = [ "datafusion-physical-expr-common", "half", "hashbrown 0.14.5", - "indexmap 2.8.0", + "indexmap 2.11.4", + "insta", "itertools 0.14.0", - "log", + "parking_lot", "paste", - "petgraph", - "rand 0.8.5", + "petgraph 0.8.3", + "rand 0.9.2", "rstest", ] +[[package]] +name = "datafusion-physical-expr-adapter" +version = "50.2.0" +dependencies = [ + "arrow", + "datafusion-common", + "datafusion-expr", + "datafusion-functions", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "itertools 0.14.0", +] + [[package]] name = "datafusion-physical-expr-common" -version = "46.0.1" +version = "50.2.0" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "datafusion-common", "datafusion-expr-common", @@ -2440,28 +2445,28 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "datafusion-common", "datafusion-execution", "datafusion-expr", "datafusion-expr-common", - "datafusion-functions-nested", "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", + "datafusion-pruning", "insta", "itertools 0.14.0", - "log", "recursive", + "tokio", ] [[package]] name = "datafusion-physical-plan" -version = "46.0.1" +version = "50.2.0" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "arrow-ord", "arrow-schema", @@ -2473,6 +2478,7 @@ dependencies = [ "datafusion-execution", "datafusion-expr", "datafusion-functions-aggregate", + "datafusion-functions-aggregate-common", "datafusion-functions-window", "datafusion-functions-window-common", "datafusion-physical-expr", @@ -2480,13 +2486,13 @@ dependencies = [ "futures", "half", "hashbrown 0.14.5", - "indexmap 2.8.0", + "indexmap 2.11.4", "insta", "itertools 0.14.0", "log", "parking_lot", "pin-project-lite", - "rand 0.8.5", + "rand 0.9.2", "rstest", "rstest_reuse", "tokio", @@ -2494,7 +2500,7 @@ dependencies = [ [[package]] name = "datafusion-proto" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "chrono", @@ -2508,54 +2514,82 @@ dependencies = [ "doc-comment", "object_store", "pbjson", - "prost", + "pretty_assertions", + "prost 0.13.5", "serde", "serde_json", - "strum 0.27.1", "tokio", ] [[package]] name = "datafusion-proto-common" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "datafusion-common", "doc-comment", "pbjson", - "prost", + "prost 0.13.5", "serde", - "serde_json", ] [[package]] -name = "datafusion-session" -version = "46.0.1" +name = "datafusion-pruning" +version = "50.2.0" dependencies = [ "arrow", - "async-trait", - "dashmap", "datafusion-common", - "datafusion-common-runtime", - "datafusion-execution", + "datafusion-datasource", "datafusion-expr", + "datafusion-expr-common", + "datafusion-functions-nested", "datafusion-physical-expr", + "datafusion-physical-expr-common", "datafusion-physical-plan", - "datafusion-sql", - "futures", + "insta", "itertools 0.14.0", "log", - "object_store", +] + +[[package]] +name = "datafusion-session" +version = "50.2.0" +dependencies = [ + "async-trait", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-plan", "parking_lot", - "tokio", +] + +[[package]] +name = "datafusion-spark" +version = "50.2.0" +dependencies = [ + "arrow", + "bigdecimal", + "chrono", + "crc32fast", + "criterion", + "datafusion-catalog", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-functions", + "log", + "rand 0.9.2", + "sha1", + "url", ] [[package]] name = "datafusion-sql" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "bigdecimal", + "chrono", "ctor", "datafusion-common", "datafusion-expr", @@ -2564,8 +2598,9 @@ dependencies = [ "datafusion-functions-nested", "datafusion-functions-window", "env_logger", - "indexmap 2.8.0", + "indexmap 2.11.4", "insta", + "itertools 0.14.0", "log", "paste", "recursive", @@ -2576,15 +2611,17 @@ dependencies = [ [[package]] name = "datafusion-sqllogictest" -version = "46.0.1" +version = "50.2.0" dependencies = [ "arrow", "async-trait", "bigdecimal", "bytes", "chrono", - "clap 4.5.34", + "clap 4.5.48", "datafusion", + "datafusion-spark", + "datafusion-substrait", "env_logger", "futures", "half", @@ -2594,20 +2631,21 @@ dependencies = [ "object_store", "postgres-protocol", "postgres-types", + "regex", "rust_decimal", "sqllogictest", "sqlparser", "tempfile", "testcontainers", "testcontainers-modules", - "thiserror 2.0.12", + "thiserror", "tokio", "tokio-postgres", ] [[package]] name = "datafusion-substrait" -version = "46.0.1" +version = "50.2.0" dependencies = [ "async-recursion", "async-trait", @@ -2618,16 +2656,17 @@ dependencies = [ "itertools 0.14.0", "object_store", "pbjson-types", - "prost", + "prost 0.13.5", "serde_json", "substrait", "tokio", "url", + "uuid", ] [[package]] name = "datafusion-wasmtest" -version = "46.0.1" +version = "50.2.0" dependencies = [ "chrono", "console_error_panic_hook", @@ -2638,8 +2677,7 @@ dependencies = [ "datafusion-optimizer", "datafusion-physical-plan", "datafusion-sql", - "getrandom 0.2.15", - "insta", + "getrandom 0.3.3", "object_store", "tokio", "url", @@ -2649,19 +2687,19 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.11" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +checksum = "d630bccd429a5bb5a64b5e94f693bfc48c9f8566418fda4c494cc94f911f87cc" dependencies = [ "powerfmt", "serde", ] [[package]] -name = "difflib" -version = "0.4.0" +name = "diff" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" [[package]] name = "digest" @@ -2692,7 +2730,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.59.0", + "windows-sys 0.61.0", ] [[package]] @@ -2703,7 +2741,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -2714,15 +2752,30 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "docker_credential" -version = "1.3.1" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31951f49556e34d90ed28342e1df7e1cb7a229c4cab0aecc627b5d91edd41d07" +checksum = "1d89dfcba45b4afad7450a99b39e751590463e45c04728cf555d36bb66940de8" dependencies = [ "base64 0.21.7", "serde", "serde_json", ] +[[package]] +name = "dtor" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97cbdf2ad6846025e8e25df05171abfb30e3ababa12ee0a0e44b9bbe570633a8" +dependencies = [ + "dtor-proc-macro", +] + +[[package]] +name = "dtor-proc-macro" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7454e41ff9012c00d53cf7f475c5e3afa3b91b7c90568495495e8d9bf47a1055" + [[package]] name = "dunce" version = "1.0.5" @@ -2731,9 +2784,9 @@ checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] name = "dyn-clone" -version = "1.0.18" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" [[package]] name = "educe" @@ -2744,14 +2797,14 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "either" -version = "1.13.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "encode_unicode" @@ -2782,7 +2835,7 @@ checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -2797,9 +2850,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.7" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3716d7a920fb4fac5d84e9d4bce8ceb321e9414b4409da61b07b75c1e3d0697" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" dependencies = [ "anstream", "anstyle", @@ -2816,19 +2869,19 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.10" +version = "0.3.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.61.0", ] [[package]] name = "error-code" -version = "3.3.1" +version = "3.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5d9305ccc6942a704f4335694ecd3de2ea531b114ac2d51f5f843750787a92f" +checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59" [[package]] name = "escape8259" @@ -2838,13 +2891,13 @@ checksum = "5692dd7b5a1978a5aeb0ce83b7655c58ca8efdcb79d21036ea249da95afec2c6" [[package]] name = "etcetera" -version = "0.8.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" +checksum = "26c7b13d0780cb82722fd59f6f57f925e143427e4a75313a6c77243bf5326ae6" dependencies = [ "cfg-if", "home", - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] @@ -2861,13 +2914,13 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "fd-lock" -version = "4.0.2" +version = "4.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e5768da2206272c81ef0b5e951a41862938a6070da63bcea197899942d3b947" +checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" dependencies = [ "cfg-if", - "rustix 0.38.44", - "windows-sys 0.52.0", + "rustix", + "windows-sys 0.59.0", ] [[package]] @@ -2902,16 +2955,22 @@ dependencies = [ [[package]] name = "filetime" -version = "0.2.25" +version = "0.2.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35c0522e981e68cbfa8c3f978441a5f34b30b96e146b33cd3359176b50fe8586" +checksum = "bc0505cd1b6fa6580283f6bdf70a73fcf4aba1184038c90902b92b3dd0df63ed" dependencies = [ "cfg-if", "libc", "libredox", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] +[[package]] +name = "find-msvc-tools" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ced73b1dacfc750a6db6c0a0c3a3853c8b41997e2e2c563dc90804ae6867959" + [[package]] name = "fixedbitset" version = "0.5.7" @@ -2920,33 +2979,25 @@ checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" [[package]] name = "flatbuffers" -version = "24.12.23" +version = "25.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f1baf0dbf96932ec9a3038d57900329c015b0bfb7b63d904f3bc27e2b02a096" +checksum = "1045398c1bfd89168b5fd3f1fc11f6e70b34f6f66300c87d44d3de849463abf1" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.9.4", "rustc_version", ] [[package]] name = "flate2" -version = "1.1.0" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc" +checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9" dependencies = [ "crc32fast", + "libz-rs-sys", "miniz_oxide", ] -[[package]] -name = "float-cmp" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09cf3155332e944990140d967ff5eceb70df778b34f77d8075db46e4704e6d8" -dependencies = [ - "num-traits", -] - [[package]] name = "fnv" version = "1.0.7" @@ -2955,24 +3006,24 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "foldhash" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" [[package]] name = "form_urlencoded" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" dependencies = [ "percent-encoding", ] [[package]] name = "fs-err" -version = "3.1.0" +version = "3.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f89bda4c2a21204059a977ed3bfe746677dfd137b83c339e702b0ac91d482aa" +checksum = "44f150ffc8782f35521cec2b23727707cb4045706ba3c854e86bef66b3a8cdbd" dependencies = [ "autocfg", ] @@ -3045,7 +3096,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -3088,16 +3139,16 @@ dependencies = [ name = "gen" version = "0.1.0" dependencies = [ - "pbjson-build", - "prost-build", + "pbjson-build 0.8.0", + "prost-build 0.14.1", ] [[package]] name = "gen-common" version = "0.1.0" dependencies = [ - "pbjson-build", - "prost-build", + "pbjson-build 0.8.0", + "prost-build 0.14.1", ] [[package]] @@ -3121,27 +3172,29 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "js-sys", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi 0.11.1+wasi-snapshot-preview1", "wasm-bindgen", ] [[package]] name = "getrandom" -version = "0.3.1" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", + "js-sys", "libc", - "wasi 0.13.3+wasi-0.2.2", - "windows-targets 0.52.6", + "r-efi", + "wasi 0.14.7+wasi-0.2.4", + "wasm-bindgen", ] [[package]] @@ -3152,9 +3205,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] name = "globset" @@ -3171,17 +3224,17 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.8" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2" +checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" dependencies = [ "atomic-waker", "bytes", "fnv", "futures-core", "futures-sink", - "http 1.2.0", - "indexmap 2.8.0", + "http 1.3.1", + "indexmap 2.11.4", "slab", "tokio", "tokio-util", @@ -3190,13 +3243,14 @@ dependencies = [ [[package]] name = "half" -version = "2.5.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7db2ff139bba50379da6aa0766b52fdcb62cb5b263009b09ed58ba604e14bbd1" +checksum = "e54c115d4f30f52c67202f079c5f9d8b49db4691f460fdb0b4c2e838261b2ba5" dependencies = [ "cfg-if", "crunchy", "num-traits", + "zerocopy", ] [[package]] @@ -3214,21 +3268,27 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "allocator-api2", ] [[package]] name = "hashbrown" -version = "0.15.2" +version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ "allocator-api2", "equivalent", "foldhash", ] +[[package]] +name = "hashbrown" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" + [[package]] name = "heck" version = "0.3.3" @@ -3246,9 +3306,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermit-abi" -version = "0.4.0" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" [[package]] name = "hex" @@ -3287,9 +3347,9 @@ dependencies = [ [[package]] name = "http" -version = "1.2.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", @@ -3314,27 +3374,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.2.0", + "http 1.3.1", ] [[package]] name = "http-body-util" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", - "futures-util", - "http 1.2.0", + "futures-core", + "http 1.3.1", "http-body 1.0.1", "pin-project-lite", ] [[package]] name = "httparse" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" [[package]] name = "httpdate" @@ -3344,26 +3404,28 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "humantime" -version = "2.1.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" [[package]] name = "hyper" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" dependencies = [ + "atomic-waker", "bytes", "futures-channel", - "futures-util", + "futures-core", "h2", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "httparse", "httpdate", "itoa", "pin-project-lite", + "pin-utils", "smallvec", "tokio", "want", @@ -3386,12 +3448,11 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.5" +version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ - "futures-util", - "http 1.2.0", + "http 1.3.1", "hyper", "hyper-util", "rustls", @@ -3417,18 +3478,23 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.10" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" dependencies = [ + "base64 0.22.1", "bytes", "futures-channel", + "futures-core", "futures-util", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "hyper", + "ipnet", + "libc", + "percent-encoding", "pin-project-lite", - "socket2", + "socket2 0.6.0", "tokio", "tower-service", "tracing", @@ -3451,16 +3517,17 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.61" +version = "0.1.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", + "log", "wasm-bindgen", - "windows-core 0.52.0", + "windows-core 0.62.0", ] [[package]] @@ -3474,21 +3541,22 @@ dependencies = [ [[package]] name = "icu_collections" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" dependencies = [ "displaydoc", + "potential_utf", "yoke", "zerofrom", "zerovec", ] [[package]] -name = "icu_locid" -version = "1.5.0" +name = "icu_locale_core" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" dependencies = [ "displaydoc", "litemap", @@ -3497,31 +3565,11 @@ dependencies = [ "zerovec", ] -[[package]] -name = "icu_locid_transform" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_locid_transform_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_locid_transform_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" - [[package]] name = "icu_normalizer" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" dependencies = [ "displaydoc", "icu_collections", @@ -3529,67 +3577,54 @@ dependencies = [ "icu_properties", "icu_provider", "smallvec", - "utf16_iter", - "utf8_iter", - "write16", "zerovec", ] [[package]] name = "icu_normalizer_data" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" +checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" [[package]] name = "icu_properties" -version = "1.5.1" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" dependencies = [ "displaydoc", "icu_collections", - "icu_locid_transform", + "icu_locale_core", "icu_properties_data", "icu_provider", - "tinystr", + "potential_utf", + "zerotrie", "zerovec", ] [[package]] name = "icu_properties_data" -version = "1.5.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" +checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" [[package]] name = "icu_provider" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" dependencies = [ "displaydoc", - "icu_locid", - "icu_provider_macros", + "icu_locale_core", "stable_deref_trait", "tinystr", "writeable", "yoke", "zerofrom", + "zerotrie", "zerovec", ] -[[package]] -name = "icu_provider_macros" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.100", -] - [[package]] name = "ident_case" version = "1.0.1" @@ -3598,9 +3633,9 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "1.0.3" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" dependencies = [ "idna_adapter", "smallvec", @@ -3609,9 +3644,9 @@ dependencies = [ [[package]] name = "idna_adapter" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" dependencies = [ "icu_normalizer", "icu_properties", @@ -3630,45 +3665,44 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.8.0" +version = "2.11.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058" +checksum = "4b0f83760fb341a774ed326568e19f5a863af4a952def8c39f9ab92fd95b88e5" dependencies = [ "equivalent", - "hashbrown 0.15.2", + "hashbrown 0.16.0", "serde", + "serde_core", ] [[package]] name = "indicatif" -version = "0.17.11" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +checksum = "70a646d946d06bedbbc4cac4c218acf4bbf2d87757a784857025f4d447e4e1cd" dependencies = [ - "console", - "number_prefix", + "console 0.16.1", "portable-atomic", - "unicode-width 0.2.0", + "unicode-width 0.2.1", + "unit-prefix", "web-time", ] [[package]] name = "indoc" -version = "2.0.5" +version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" [[package]] name = "insta" -version = "1.42.2" +version = "1.43.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50259abbaa67d11d2bcafc7ba1d094ed7a0c70e3ce893f0d0997f73558cb3084" +checksum = "46fdb647ebde000f43b5b53f773c30cf9b0cb4300453208713fa38b2c70935a0" dependencies = [ - "console", + "console 0.15.11", "globset", - "linked-hash-map", "once_cell", - "pin-project", "regex", "serde", "similar", @@ -3692,17 +3726,38 @@ version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" +[[package]] +name = "io-uring" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "046fa2d4d00aea763528b4950358d0ead425372445dc8ff86312b3c69ff7727b" +dependencies = [ + "bitflags 2.9.4", + "cfg-if", + "libc", +] + [[package]] name = "ipnet" version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "iri-string" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc5ebe9c3a1a7a5127f920a418f7585e9e758e911d0466ed004f393b0e380b2" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "is-terminal" -version = "0.4.15" +version = "0.4.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e19b23d53f35ce9f56aebc7d1bb4e6ac1e9c0db7ac85c8d1760c04379edced37" +checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" dependencies = [ "hermit-abi", "libc", @@ -3744,15 +3799,15 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jiff" -version = "0.2.4" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d699bc6dfc879fb1bf9bdff0d4c56f0884fc6f0d0eb0fba397a6d00cd9a6b85e" +checksum = "be1f93b8b1eb69c77f24bbb0afdf66f54b632ee39af40ca21c4365a1d7347e49" dependencies = [ "jiff-static", "log", @@ -3763,29 +3818,30 @@ dependencies = [ [[package]] name = "jiff-static" -version = "0.2.4" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d16e75759ee0aa64c57a56acbf43916987b20c77373cb7e808979e02b93c9f9" +checksum = "03343451ff899767262ec32146f6d559dd759fdadf42ff0e227c7c48f72594b4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "jobserver" -version = "0.1.32" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" dependencies = [ + "getrandom 0.3.3", "libc", ] [[package]] name = "js-sys" -version = "0.3.77" +version = "0.3.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" +checksum = "ec48937a97411dcb524a265206ccd4c90bb711fca92b2792c407f268825b9305" dependencies = [ "once_cell", "wasm-bindgen", @@ -3797,17 +3853,11 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" -[[package]] -name = "lazycell" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" - [[package]] name = "lexical-core" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b765c31809609075565a70b4b71402281283aeda7ecaf4818ac14a7b2ade8958" +checksum = "7d8d125a277f807e55a77304455eb7b1cb52f2b18c143b60e766c120bd64a594" dependencies = [ "lexical-parse-float", "lexical-parse-integer", @@ -3818,84 +3868,59 @@ dependencies = [ [[package]] name = "lexical-parse-float" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de6f9cb01fb0b08060209a057c048fcbab8717b4c1ecd2eac66ebfe39a65b0f2" +checksum = "52a9f232fbd6f550bc0137dcb5f99ab674071ac2d690ac69704593cb4abbea56" dependencies = [ "lexical-parse-integer", "lexical-util", - "static_assertions", ] [[package]] name = "lexical-parse-integer" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72207aae22fc0a121ba7b6d479e42cbfea549af1479c3f3a4f12c70dd66df12e" +checksum = "9a7a039f8fb9c19c996cd7b2fcce303c1b2874fe1aca544edc85c4a5f8489b34" dependencies = [ "lexical-util", - "static_assertions", ] [[package]] name = "lexical-util" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a82e24bf537fd24c177ffbbdc6ebcc8d54732c35b50a3f28cc3f4e4c949a0b3" -dependencies = [ - "static_assertions", -] +checksum = "2604dd126bb14f13fb5d1bd6a66155079cb9fa655b37f875b3a742c705dbed17" [[package]] name = "lexical-write-float" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5afc668a27f460fb45a81a757b6bf2f43c2d7e30cb5a2dcd3abf294c78d62bd" +checksum = "50c438c87c013188d415fbabbb1dceb44249ab81664efbd31b14ae55dabb6361" dependencies = [ "lexical-util", "lexical-write-integer", - "static_assertions", ] [[package]] name = "lexical-write-integer" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "629ddff1a914a836fb245616a7888b62903aae58fa771e1d83943035efa0f978" +checksum = "409851a618475d2d5796377cad353802345cba92c867d9fbcde9cf4eac4e14df" dependencies = [ "lexical-util", - "static_assertions", ] [[package]] -name = "libc" -version = "0.2.171" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" - -[[package]] -name = "libflate" -version = "2.1.0" +name = "libbz2-rs-sys" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45d9dfdc14ea4ef0900c1cddbc8dcd553fbaacd8a4a282cf4018ae9dd04fb21e" -dependencies = [ - "adler32", - "core2", - "crc32fast", - "dary_heap", - "libflate_lz77", -] +checksum = "2c4a545a15244c7d945065b5d392b2d2d7f21526fba56ce51467b06ed445e8f7" [[package]] -name = "libflate_lz77" -version = "2.1.0" +name = "libc" +version = "0.2.176" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6e0d73b369f386f1c44abd9c570d5318f55ccde816ff4b562fa452e5182863d" -dependencies = [ - "core2", - "hashbrown 0.14.5", - "rle-decode-fast", -] +checksum = "58f929b4d672ea937a23a1ab494143d968337a5f47e56d0815df1e0890ddf174" [[package]] name = "libloading" @@ -3909,39 +3934,40 @@ dependencies = [ [[package]] name = "libloading" -version = "0.8.6" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-link 0.2.0", ] [[package]] name = "libm" -version = "0.2.11" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "libmimalloc-sys" -version = "0.1.40" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07d0e07885d6a754b9c7993f2625187ad694ee985d60f23355ff0e7077261502" +checksum = "667f4fec20f29dfc6bc7357c582d91796c169ad7e2fce709468aefeb2c099870" dependencies = [ "cc", + "cty", "libc", ] [[package]] name = "libredox" -version = "0.1.3" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.4", "libc", - "redox_syscall 0.5.8", + "redox_syscall 0.5.17", ] [[package]] @@ -3952,39 +3978,36 @@ checksum = "5297962ef19edda4ce33aaa484386e0a5b3d7f2f4e037cbeee00503ef6b29d33" dependencies = [ "anstream", "anstyle", - "clap 4.5.34", + "clap 4.5.48", "escape8259", ] [[package]] -name = "linked-hash-map" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" - -[[package]] -name = "linux-raw-sys" -version = "0.4.15" +name = "libz-rs-sys" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" +checksum = "840db8cf39d9ec4dd794376f38acc40d0fc65eec2a8f484f7fd375b84602becd" +dependencies = [ + "zlib-rs", +] [[package]] name = "linux-raw-sys" -version = "0.9.2" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db9c683daf087dc577b7506e9695b3d556a9f3849903fa28186283afd6809e9" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" [[package]] name = "litemap" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" +checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" [[package]] name = "lock_api" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", @@ -3992,15 +4015,21 @@ dependencies = [ [[package]] name = "log" -version = "0.4.27" +version = "0.4.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" + +[[package]] +name = "lru-slab" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "lz4_flex" -version = "0.11.3" +version = "0.11.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75761162ae2b0e580d7e7c390558127e5f01b4194debd6221fd8c207fc80e3f5" +checksum = "08ab2867e3eeeca90e844d1940eab391c9dc5228783db2ed999acbc0a9ed375a" dependencies = [ "twox-hash", ] @@ -4018,9 +4047,9 @@ dependencies = [ [[package]] name = "matchit" -version = "0.7.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" [[package]] name = "md-5" @@ -4034,9 +4063,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.4" +version = "2.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" [[package]] name = "memoffset" @@ -4049,9 +4078,9 @@ dependencies = [ [[package]] name = "mimalloc" -version = "0.1.44" +version = "0.1.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99585191385958383e13f6b822e6b6d8d9cf928e7d286ceb092da92b43c87bc1" +checksum = "e1ee66a4b64c74f4ef288bcbb9192ad9c3feaad75193129ac8509af543894fd8" dependencies = [ "libmimalloc-sys", ] @@ -4080,29 +4109,30 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.4" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3b1c9bd4fe1f0f8b387f6eb9eb3b4a1aa26185e5750efb9140301703f62cd1b" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", + "simd-adler32", ] [[package]] name = "mio" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" +checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" dependencies = [ "libc", - "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.52.0", + "wasi 0.11.1+wasi-snapshot-preview1", + "windows-sys 0.59.0", ] [[package]] name = "multimap" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" +checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" [[package]] name = "nibble_vec" @@ -4115,11 +4145,11 @@ dependencies = [ [[package]] name = "nix" -version = "0.29.0" +version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.4", "cfg-if", "cfg_aliases", "libc", @@ -4135,12 +4165,6 @@ dependencies = [ "minimal-lexical", ] -[[package]] -name = "normalize-line-endings" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" - [[package]] name = "ntapi" version = "0.4.1" @@ -4152,12 +4176,11 @@ dependencies = [ [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" dependencies = [ - "overload", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -4242,10 +4265,23 @@ dependencies = [ ] [[package]] -name = "number_prefix" -version = "0.4.0" +name = "objc2-core-foundation" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166" +dependencies = [ + "bitflags 2.9.4", +] + +[[package]] +name = "objc2-io-kit" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +checksum = "71c1c64d6120e51cd86033f67176b1cb66780c2efe34dec55176f77befd93c0a" +dependencies = [ + "libc", + "objc2-core-foundation", +] [[package]] name = "object" @@ -4258,46 +4294,58 @@ dependencies = [ [[package]] name = "object_store" -version = "0.11.2" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cfccb68961a56facde1163f9319e0d15743352344e7808a11795fb99698dcaf" +checksum = "4c1be0c6c22ec0817cdc77d3842f721a17fd30ab6965001415b5402a74e6b740" dependencies = [ "async-trait", "base64 0.22.1", "bytes", "chrono", + "form_urlencoded", "futures", + "http 1.3.1", + "http-body-util", "humantime", "hyper", - "itertools 0.13.0", + "itertools 0.14.0", "md-5", "parking_lot", "percent-encoding", "quick-xml", - "rand 0.8.5", + "rand 0.9.2", "reqwest", "ring", "rustls-pemfile", "serde", "serde_json", - "snafu", + "serde_urlencoded", + "thiserror", "tokio", "tracing", "url", "walkdir", + "wasm-bindgen-futures", + "web-time", ] [[package]] name = "once_cell" -version = "1.20.3" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "once_cell_polyfill" +version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" +checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" [[package]] name = "oorandom" -version = "11.1.4" +version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" [[package]] name = "openssl-probe" @@ -4326,23 +4374,17 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "owo-colors" -version = "4.1.0" +version = "4.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb37767f6569cd834a413442455e0f066d0d522de8630436e2a1761d9726ba56" +checksum = "48dd4f4a2c8405440fd0462561f0e5806bd0f77e86f51c761481bdd4018b545e" [[package]] name = "parking_lot" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" dependencies = [ "lock_api", "parking_lot_core", @@ -4350,24 +4392,24 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.10" +version = "0.9.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.5.8", + "redox_syscall 0.5.17", "smallvec", "windows-targets 0.52.6", ] [[package]] name = "parquet" -version = "54.2.1" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f88838dca3b84d41444a0341b19f347e8098a3898b0f21536654b8b799e11abd" +checksum = "f0dbd48ad52d7dccf8ea1b90a3ddbfaea4f69878dd7683e51c507d4bc52b5b27" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow-array", "arrow-buffer", "arrow-cast", @@ -4382,12 +4424,13 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.15.2", + "hashbrown 0.16.0", "lz4_flex", "num", "num-bigint", "object_store", "paste", + "ring", "seq-macro", "simdutf8", "snap", @@ -4395,7 +4438,6 @@ dependencies = [ "tokio", "twox-hash", "zstd", - "zstd-sys", ] [[package]] @@ -4420,16 +4462,7 @@ dependencies = [ "regex", "regex-syntax", "structmeta", - "syn 2.0.100", -] - -[[package]] -name = "parse-zoneinfo" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f2a05b18d44e2957b88f96ba460715e295bc1d7510468a2f3d3b44535d26c24" -dependencies = [ - "regex", + "syn 2.0.106", ] [[package]] @@ -4456,8 +4489,20 @@ checksum = "6eea3058763d6e656105d1403cb04e0a41b7bbac6362d413e7c33be0c32279c9" dependencies = [ "heck 0.5.0", "itertools 0.13.0", - "prost", - "prost-types", + "prost 0.13.5", + "prost-types 0.13.5", +] + +[[package]] +name = "pbjson-build" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af22d08a625a2213a78dbb0ffa253318c5c79ce3133d32d296655a7bdfb02095" +dependencies = [ + "heck 0.5.0", + "itertools 0.14.0", + "prost 0.14.1", + "prost-types 0.14.1", ] [[package]] @@ -4469,17 +4514,17 @@ dependencies = [ "bytes", "chrono", "pbjson", - "pbjson-build", - "prost", - "prost-build", + "pbjson-build 0.7.0", + "prost 0.13.5", + "prost-build 0.13.5", "serde", ] [[package]] name = "percent-encoding" -version = "2.3.1" +version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" [[package]] name = "petgraph" @@ -4488,65 +4533,76 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "indexmap 2.8.0", + "indexmap 2.11.4", +] + +[[package]] +name = "petgraph" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" +dependencies = [ + "fixedbitset", + "hashbrown 0.15.5", + "indexmap 2.11.4", + "serde", ] [[package]] name = "phf" -version = "0.11.3" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +checksum = "913273894cec178f401a31ec4b656318d95473527be05c0752cc41cdc32be8b7" dependencies = [ - "phf_shared", + "phf_shared 0.12.1", ] [[package]] -name = "phf_codegen" -version = "0.11.3" +name = "phf" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" +checksum = "c1562dc717473dbaa4c1f85a36410e03c047b2e7df7f45ee938fbef64ae7fadf" dependencies = [ - "phf_generator", - "phf_shared", + "phf_shared 0.13.1", + "serde", ] [[package]] -name = "phf_generator" -version = "0.11.3" +name = "phf_shared" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" +checksum = "06005508882fb681fd97892ecff4b7fd0fee13ef1aa569f8695dae7ab9099981" dependencies = [ - "phf_shared", - "rand 0.8.5", + "siphasher", ] [[package]] name = "phf_shared" -version = "0.11.3" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +checksum = "e57fef6bc5981e38c2ce2d63bfa546861309f875b8a75f092d1d54ae2d64f266" dependencies = [ "siphasher", ] [[package]] name = "pin-project" -version = "1.1.9" +version = "1.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfe2e71e1471fe07709406bf725f710b02927c9c54b2b5b2ec0e8087d97c327d" +checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.9" +version = "1.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6e859e6e5bd50440ab63c47e3ebabc90f26251f7c73c3d3e837b74a1cc3fa67" +checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -4563,9 +4619,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" [[package]] name = "plotters" @@ -4597,9 +4653,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.10.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" [[package]] name = "portable-atomic-util" @@ -4612,21 +4668,21 @@ dependencies = [ [[package]] name = "postgres-derive" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69700ea4603c5ef32d447708e6a19cd3e8ac197a000842e97f527daea5e4175f" +checksum = "56df96f5394370d1b20e49de146f9e6c25aa9ae750f449c9d665eafecb3ccae6" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "postgres-protocol" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76ff0abab4a9b844b93ef7b81f1efc0a366062aaef2cd702c76256b5dc075c54" +checksum = "fbef655056b916eb868048276cfd5d6a7dea4f81560dfd047f97c8c6fe3fcfd4" dependencies = [ "base64 0.22.1", "byteorder", @@ -4635,16 +4691,16 @@ dependencies = [ "hmac", "md-5", "memchr", - "rand 0.9.0", + "rand 0.9.2", "sha2", "stringprep", ] [[package]] name = "postgres-types" -version = "0.2.9" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48" +checksum = "77a120daaabfcb0e324d5bf6e411e9222994cb3795c79943a0ef28ed27ea76e4" dependencies = [ "bytes", "chrono", @@ -4653,6 +4709,15 @@ dependencies = [ "postgres-protocol", ] +[[package]] +name = "potential_utf" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84df19adbe5b5a0782edcab45899906947ab039ccf4573713735ee7de1e6b08a" +dependencies = [ + "zerovec", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -4661,58 +4726,38 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" -dependencies = [ - "zerocopy 0.7.35", -] - -[[package]] -name = "predicates" -version = "3.1.3" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ - "anstyle", - "difflib", - "float-cmp", - "normalize-line-endings", - "predicates-core", - "regex", + "zerocopy", ] [[package]] -name = "predicates-core" -version = "1.0.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa" - -[[package]] -name = "predicates-tree" -version = "1.0.12" +name = "pretty_assertions" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c" +checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" dependencies = [ - "predicates-core", - "termtree", + "diff", + "yansi", ] [[package]] name = "prettyplease" -version = "0.2.31" +version = "0.2.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "proc-macro-crate" -version = "3.2.0" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +checksum = "219cb19e96be00ab2e37d6e299658a0cfa83e52429179969b0f0121b4ac46983" dependencies = [ "toml_edit", ] @@ -4743,9 +4788,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.93" +version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" dependencies = [ "unicode-ident", ] @@ -4757,7 +4802,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" dependencies = [ "bytes", - "prost-derive", + "prost-derive 0.13.5", +] + +[[package]] +name = "prost" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7231bd9b3d3d33c86b58adbac74b5ec0ad9f496b19d22801d773636feaa95f3d" +dependencies = [ + "bytes", + "prost-derive 0.14.1", ] [[package]] @@ -4771,12 +4826,32 @@ dependencies = [ "log", "multimap", "once_cell", - "petgraph", + "petgraph 0.7.1", + "prettyplease", + "prost 0.13.5", + "prost-types 0.13.5", + "regex", + "syn 2.0.106", + "tempfile", +] + +[[package]] +name = "prost-build" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac6c3320f9abac597dcbc668774ef006702672474aad53c6d596b62e487b40b1" +dependencies = [ + "heck 0.5.0", + "itertools 0.14.0", + "log", + "multimap", + "once_cell", + "petgraph 0.7.1", "prettyplease", - "prost", - "prost-types", + "prost 0.14.1", + "prost-types 0.14.1", "regex", - "syn 2.0.100", + "syn 2.0.106", "tempfile", ] @@ -4790,7 +4865,20 @@ dependencies = [ "itertools 0.14.0", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", +] + +[[package]] +name = "prost-derive" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" +dependencies = [ + "anyhow", + "itertools 0.14.0", + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] @@ -4799,7 +4887,16 @@ version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" dependencies = [ - "prost", + "prost 0.13.5", +] + +[[package]] +name = "prost-types" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9b4db3d6da204ed77bb26ba83b6122a73aeb2e87e25fbf7ad2e84c4ccbf8f72" +dependencies = [ + "prost 0.14.1", ] [[package]] @@ -4813,9 +4910,9 @@ dependencies = [ [[package]] name = "psm" -version = "0.1.25" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f58e5423e24c18cc840e1c98370b3993c6649cd1678b4d24318bcf0a083cbe88" +checksum = "6e944464ec8536cd1beb0bbfd96987eb5e3b72f2ecdafdc5c769a37f1fa2ae1f" dependencies = [ "cc", ] @@ -4842,11 +4939,10 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.23.5" +version = "0.25.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872" +checksum = "8970a78afe0628a3e3430376fc5fd76b6b45c4d43360ffd6cdd40bdde72b682a" dependencies = [ - "cfg-if", "indoc", "libc", "memoffset", @@ -4860,9 +4956,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.23.5" +version = "0.25.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb" +checksum = "458eb0c55e7ece017adeba38f2248ff3ac615e53660d7c71a238d7d2a01c7598" dependencies = [ "once_cell", "target-lexicon", @@ -4870,9 +4966,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.23.5" +version = "0.25.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d" +checksum = "7114fe5457c61b276ab77c5055f206295b812608083644a5c5b2640c3102565c" dependencies = [ "libc", "pyo3-build-config", @@ -4880,27 +4976,27 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.23.5" +version = "0.25.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da" +checksum = "a8725c0a622b374d6cb051d11a0983786448f7785336139c3c94f5aa6bef7e50" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "pyo3-macros-backend" -version = "0.23.5" +version = "0.25.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028" +checksum = "4109984c22491085343c05b0dbc54ddc405c3cf7b4374fc533f5c3313a572ccc" dependencies = [ "heck 0.5.0", "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -4911,9 +5007,9 @@ checksum = "5a651516ddc9168ebd67b24afd085a718be02f8858fe406591b013d101ce2f40" [[package]] name = "quick-xml" -version = "0.37.2" +version = "0.38.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "165859e9e55f79d67b96c5d96f4e88b6f2695a1972849c15a6a3f5c59fc2c003" +checksum = "42a232e7487fc2ef313d96dde7948e7a3c05101870d8985e4fd8d26aedd27b89" dependencies = [ "memchr", "serde", @@ -4921,37 +5017,40 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.6" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" dependencies = [ "bytes", + "cfg_aliases", "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 2.1.1", + "rustc-hash", "rustls", - "socket2", - "thiserror 2.0.12", + "socket2 0.6.0", + "thiserror", "tokio", "tracing", + "web-time", ] [[package]] name = "quinn-proto" -version = "0.11.9" +version = "0.11.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" dependencies = [ "bytes", - "getrandom 0.2.15", - "rand 0.8.5", + "getrandom 0.3.3", + "lru-slab", + "rand 0.9.2", "ring", - "rustc-hash 2.1.1", + "rustc-hash", "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.12", + "thiserror", "tinyvec", "tracing", "web-time", @@ -4959,27 +5058,33 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.10" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2", + "socket2 0.6.0", "tracing", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] name = "quote" -version = "1.0.40" +version = "1.0.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +checksum = "ce25767e7b499d1b604768e7cde645d14cc8584231ea6b295e9c9eb22c02e1d1" dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "radium" version = "0.7.0" @@ -5009,13 +5114,12 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.0" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha 0.9.0", - "rand_core 0.9.1", - "zerocopy 0.8.18", + "rand_core 0.9.3", ] [[package]] @@ -5035,7 +5139,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core 0.9.1", + "rand_core 0.9.3", ] [[package]] @@ -5044,34 +5148,33 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", ] [[package]] name = "rand_core" -version = "0.9.1" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a88e0da7a2c97baa202165137c158d0a2e824ac465d13d81046727b34cb247d3" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.3.1", - "zerocopy 0.8.18", + "getrandom 0.3.3", ] [[package]] name = "rand_distr" -version = "0.4.3" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", - "rand 0.8.5", + "rand 0.9.2", ] [[package]] name = "rayon" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" dependencies = [ "either", "rayon-core", @@ -5079,9 +5182,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.1" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -5104,7 +5207,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -5118,29 +5221,49 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.8" +version = "0.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" +checksum = "5407465600fb0548f1442edf71dd20683c6ed326200ace4b1ef0763521bb3b77" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.4", ] [[package]] name = "redox_users" -version = "0.5.0" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd6f9d3d47bdd2ad6945c5015a226ec6155d0bcdfd8f7cd29f86b71f8de99d2b" +checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "libredox", - "thiserror 2.0.12", + "thiserror", +] + +[[package]] +name = "ref-cast" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] name = "regex" -version = "1.11.1" +version = "1.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +checksum = "8b5288124840bee7b386bc413c487869b360b2b4ec421ea56425128692f2a82c" dependencies = [ "aho-corasick", "memchr", @@ -5150,9 +5273,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.9" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +checksum = "833eb9ce86d40ef33cb1306d8accf7bc8ec2bfea4355cbdebb3df68b40925cad" dependencies = [ "aho-corasick", "memchr", @@ -5161,23 +5284,23 @@ dependencies = [ [[package]] name = "regex-lite" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" +checksum = "943f41321c63ef1c92fd763bfe054d2668f7f225a5c29f0105903dc2fc04ba30" [[package]] name = "regex-syntax" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" [[package]] name = "regress" -version = "0.10.3" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ef7fa9ed0256d64a688a3747d0fef7a88851c18a5e1d57f115f38ec2e09366" +checksum = "145bb27393fe455dd64d6cbc8d059adfa392590a45eadf079c01b11857e7b010" dependencies = [ - "hashbrown 0.15.2", + "hashbrown 0.15.5", "memchr", ] @@ -5207,32 +5330,28 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.12" +version = "0.12.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" +checksum = "d429f34c8092b2d42c7c93cec323bb4adeb7c67698f70839adec842ec10c7ceb" dependencies = [ "base64 0.22.1", "bytes", "futures-core", "futures-util", "h2", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "http-body-util", "hyper", "hyper-rustls", "hyper-util", - "ipnet", "js-sys", "log", - "mime", - "once_cell", "percent-encoding", "pin-project-lite", "quinn", "rustls", "rustls-native-certs", - "rustls-pemfile", "rustls-pki-types", "serde", "serde_json", @@ -5241,25 +5360,25 @@ dependencies = [ "tokio", "tokio-rustls", "tokio-util", - "tower 0.5.2", + "tower", + "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", "wasm-streams", "web-sys", - "windows-registry", ] [[package]] name = "ring" -version = "0.17.13" +version = "0.17.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ac5d832aa16abd7d1def883a8545280c20a60f523a370aa3a9617c2b8550ee" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom 0.2.15", + "getrandom 0.2.16", "libc", "untrusted", "windows-sys 0.52.0", @@ -5294,17 +5413,11 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "rle-decode-fast" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" - [[package]] name = "rstest" -version = "0.24.0" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03e905296805ab93e13c1ec3a03f4b6c4f35e9498a3d5fa96dc626d22c03cd89" +checksum = "6fc39292f8613e913f7df8fa892b8944ceb47c247b78e1b1ae2f09e019be789d" dependencies = [ "futures-timer", "futures-util", @@ -5314,9 +5427,9 @@ dependencies = [ [[package]] name = "rstest_macros" -version = "0.24.0" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef0053bbffce09062bee4bcc499b0fbe7a57b879f1efe088d6d8d4c7adcdef9b" +checksum = "1f168d99749d307be9de54d23fd226628d99768225ef08f6ffb52e0182a27746" dependencies = [ "cfg-if", "glob", @@ -5326,7 +5439,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.100", + "syn 2.0.106", "unicode-ident", ] @@ -5338,14 +5451,14 @@ checksum = "b3a8fb4672e840a587a66fc577a5491375df51ddb88f2a2c2a792598c326fe14" dependencies = [ "quote", "rand 0.8.5", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "rust_decimal" -version = "1.37.1" +version = "1.38.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "faa7de2ba56ac291bd90c6b9bece784a52ae1411f9506544b3eae36dd2356d50" +checksum = "c8975fc98059f365204d635119cf9c5a60ae67b841ed49b5422a9a7e56cdfac0" dependencies = [ "arrayvec", "borsh", @@ -5360,15 +5473,9 @@ dependencies = [ [[package]] name = "rustc-demangle" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" - -[[package]] -name = "rustc-hash" -version = "1.1.0" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" [[package]] name = "rustc-hash" @@ -5387,35 +5494,22 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" -dependencies = [ - "bitflags 2.8.0", - "errno", - "libc", - "linux-raw-sys 0.4.15", - "windows-sys 0.59.0", -] - -[[package]] -name = "rustix" -version = "1.0.2" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7178faa4b75a30e269c71e61c353ce2748cf3d76f0c44c393f4e60abf49b825" +checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.4", "errno", "libc", - "linux-raw-sys 0.9.2", - "windows-sys 0.59.0", + "linux-raw-sys", + "windows-sys 0.61.0", ] [[package]] name = "rustls" -version = "0.23.23" +version = "0.23.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" +checksum = "cd3c25631629d034ce7cd9940adc9d45762d46de2b0f57193c4443b92c6d4d40" dependencies = [ "aws-lc-rs", "once_cell", @@ -5449,18 +5543,19 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" dependencies = [ "web-time", + "zeroize", ] [[package]] name = "rustls-webpki" -version = "0.102.8" +version = "0.103.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +checksum = "8572f3c2cb9934231157b45499fc41e1f58c589fdfb81a844ba873265e80f8eb" dependencies = [ "aws-lc-rs", "ring", @@ -5470,17 +5565,17 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.19" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "rustyline" -version = "15.0.0" +version = "17.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ee1e066dc922e513bda599c6ccb5f3bb2b0ea5870a579448f2622993f0a9a2f" +checksum = "e902948a25149d50edc1a8e0141aad50f54e22ba83ff988cf8f7c9ef07f50564" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.4", "cfg-if", "clipboard-win", "fd-lock", @@ -5491,16 +5586,16 @@ dependencies = [ "nix", "radix_trie", "unicode-segmentation", - "unicode-width 0.2.0", + "unicode-width 0.2.1", "utf8parse", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] name = "ryu" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] name = "same-file" @@ -5513,11 +5608,11 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.0", ] [[package]] @@ -5532,6 +5627,30 @@ dependencies = [ "serde_json", ] +[[package]] +name = "schemars" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd191f9397d57d581cddd31014772520aa448f65ef991055d7f61582c65165f" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + +[[package]] +name = "schemars" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82d20c4491bc164fa2f6c5d44565947a52ad80b9505d8e36f8d54c27c739fcd0" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + [[package]] name = "schemars_derive" version = "0.8.22" @@ -5541,7 +5660,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -5558,11 +5677,11 @@ checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b" [[package]] name = "security-framework" -version = "3.2.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" +checksum = "cc198e42d9b7510827939c9a15f5062a0c913f3371d765977e586d2fe6c16f4a" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.4", "core-foundation", "core-foundation-sys", "libc", @@ -5571,9 +5690,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.14.0" +version = "2.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" dependencies = [ "core-foundation-sys", "libc", @@ -5581,46 +5700,58 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.26" +version = "1.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" dependencies = [ "serde", + "serde_core", ] [[package]] name = "seq-macro" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" +checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ + "serde_core", "serde_derive", ] [[package]] name = "serde_bytes" -version = "0.11.15" +version = "0.11.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "387cc504cb06bb40a96c8e04e951fe01854cf6bc921053c954e4a606d9675c6a" +checksum = "a5d440709e79d88e51ac01c4b72fc6cb7314017bb7da9eeff678aa94c10e3ea8" dependencies = [ "serde", + "serde_core", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -5631,30 +5762,31 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.145" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" dependencies = [ "itoa", "memchr", "ryu", "serde", + "serde_core", ] [[package]] name = "serde_repr" -version = "0.1.19" +version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c64451ba24fc7a6a2d60fc75dd9c83c90903b19028d4eff35e88fc1e86564e9" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -5666,7 +5798,7 @@ dependencies = [ "proc-macro2", "quote", "serde", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -5683,15 +5815,17 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.12.0" +version = "3.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6b6f7f2fcb69f747921f79f3926bd1e203fce4fef62c268dd3abfb6d86029aa" +checksum = "c522100790450cf78eeac1507263d0a350d4d5b30df0c8e1fe051a10c22b376e" dependencies = [ "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.8.0", + "indexmap 2.11.4", + "schemars 0.9.0", + "schemars 1.0.4", "serde", "serde_derive", "serde_json", @@ -5701,14 +5835,14 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.12.0" +version = "3.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d00caa5193a3c8362ac2b73be6b9e768aa5a4b2f721d8f4b339600c3cb51f8e" +checksum = "327ada00f7d64abaac1e55a6911e90cf665aa051b9a561c7006c157f4633135e" dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -5717,18 +5851,29 @@ version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.8.0", + "indexmap 2.11.4", "itoa", "ryu", "serde", "unsafe-libyaml", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" -version = "0.10.8" +version = "0.10.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", "cpufeatures", @@ -5752,13 +5897,19 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.2" +version = "1.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +checksum = "b2a4719bff48cee6b39d12c020eeb490953ad2443b7055bd0b21fca26bd8c28b" dependencies = [ "libc", ] +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + [[package]] name = "simdutf8" version = "0.1.5" @@ -5779,39 +5930,15 @@ checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" [[package]] name = "slab" -version = "0.4.9" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" -dependencies = [ - "autocfg", -] +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" [[package]] name = "smallvec" -version = "1.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" - -[[package]] -name = "snafu" -version = "0.8.5" +version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "223891c85e2a29c3fe8fb900c1fae5e69c2e42415e3177752e8718475efa5019" -dependencies = [ - "snafu-derive", -] - -[[package]] -name = "snafu-derive" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" -dependencies = [ - "heck 0.5.0", - "proc-macro2", - "quote", - "syn 2.0.100", -] +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] name = "snap" @@ -5839,19 +5966,29 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.8" +version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" dependencies = [ "libc", "windows-sys 0.52.0", ] +[[package]] +name = "socket2" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "sqllogictest" -version = "0.28.0" +version = "0.28.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b2f0b80fc250ed3fdd82fc88c0ada5ad62ee1ed5314ac5474acfa52082f518" +checksum = "3566426f72a13e393aa34ca3d542c5b0eb86da4c0db137ee9b5cfccc6179e52d" dependencies = [ "async-trait", "educe", @@ -5868,15 +6005,15 @@ dependencies = [ "similar", "subst", "tempfile", - "thiserror 2.0.12", + "thiserror", "tracing", ] [[package]] name = "sqlparser" -version = "0.55.0" +version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4521174166bac1ff04fe16ef4524c70144cd29682a45978978ca3d7f4e0be11" +checksum = "4591acadbcf52f0af60eafbb2c003232b2b4cd8de5f0e9437cb8b1b59046cc0f" dependencies = [ "log", "recursive", @@ -5891,7 +6028,7 @@ checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -5902,9 +6039,9 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stacker" -version = "0.1.18" +version = "0.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d08feb8f695b465baed819b03c128dc23f57a694510ab1f06c77f763975685e" +checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" dependencies = [ "cc", "cfg-if", @@ -5913,12 +6050,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" - [[package]] name = "stringprep" version = "0.1.5" @@ -5945,7 +6076,7 @@ dependencies = [ "proc-macro2", "quote", "structmeta-derive", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -5956,7 +6087,7 @@ checksum = "152a0b65a590ff6c3da95cabe2353ee04e6167c896b28e3b14478c2636c922fc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -5991,12 +6122,9 @@ checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" [[package]] name = "strum" -version = "0.27.1" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32" -dependencies = [ - "strum_macros 0.27.1", -] +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" [[package]] name = "strum_macros" @@ -6008,27 +6136,26 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "strum_macros" -version = "0.27.1" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "rustversion", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "subst" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33e7942675ea19db01ef8cf15a1e6443007208e6c74568bd64162da26d40160d" +checksum = "0a9a86e5144f63c2d18334698269a8bfae6eece345c70b64821ea5b35054ec99" dependencies = [ "memchr", "unicode-width 0.1.14", @@ -6036,26 +6163,26 @@ dependencies = [ [[package]] name = "substrait" -version = "0.55.0" +version = "0.58.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3a359aeb711c1e1944c0c4178bbb2d679d39237ac5bfe28f7e0506e522e5ce6" +checksum = "de6d24c270c6c672a86c183c3a8439ba46c1936f93cf7296aa692de3b0ff0228" dependencies = [ "heck 0.5.0", "pbjson", - "pbjson-build", + "pbjson-build 0.7.0", "pbjson-types", "prettyplease", - "prost", - "prost-build", - "prost-types", + "prost 0.13.5", + "prost-build 0.13.5", + "prost-types 0.13.5", "protobuf-src", "regress", - "schemars", + "schemars 0.8.22", "semver", "serde", "serde_json", "serde_yaml", - "syn 2.0.100", + "syn 2.0.106", "typify", "walkdir", ] @@ -6079,9 +6206,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.100" +version = "2.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" dependencies = [ "proc-macro2", "quote", @@ -6099,26 +6226,26 @@ dependencies = [ [[package]] name = "synstructure" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "sysinfo" -version = "0.33.1" +version = "0.37.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fc858248ea01b66f19d8e8a6d55f41deaf91e9d495246fd01368d99935c6c01" +checksum = "16607d5caffd1c07ce073528f9ed972d88db15dd44023fa57142963be3feb11f" dependencies = [ - "core-foundation-sys", "libc", "memchr", "ntapi", - "rayon", + "objc2-core-foundation", + "objc2-io-kit", "windows", ] @@ -6130,29 +6257,23 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "target-lexicon" -version = "0.12.16" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +checksum = "df7f62577c25e07834649fc3b39fafdc597c0a3527dc1c60129201ccfcbaa50c" [[package]] name = "tempfile" -version = "3.19.1" +version = "3.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" dependencies = [ "fastrand", - "getrandom 0.3.1", + "getrandom 0.3.3", "once_cell", - "rustix 1.0.2", - "windows-sys 0.59.0", + "rustix", + "windows-sys 0.61.0", ] -[[package]] -name = "termtree" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" - [[package]] name = "test-utils" version = "0.1.0" @@ -6161,14 +6282,14 @@ dependencies = [ "chrono-tz", "datafusion-common", "env_logger", - "rand 0.8.5", + "rand 0.9.2", ] [[package]] name = "testcontainers" -version = "0.23.3" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59a4f01f39bb10fc2a5ab23eb0d888b1e2bb168c157f61a1b98e6c501c639c74" +checksum = "23bb7577dca13ad86a78e8271ef5d322f37229ec83b8d98da6d996c588a1ddb1" dependencies = [ "async-trait", "bollard", @@ -6185,7 +6306,7 @@ dependencies = [ "serde", "serde_json", "serde_with", - "thiserror 2.0.12", + "thiserror", "tokio", "tokio-stream", "tokio-tar", @@ -6195,9 +6316,9 @@ dependencies = [ [[package]] name = "testcontainers-modules" -version = "0.11.6" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d43ed4e8f58424c3a2c6c56dbea6643c3c23e8666a34df13c54f0a184e6c707" +checksum = "eac95cde96549fc19c6bf19ef34cc42bd56e264c1cb97e700e21555be0ecf9e2" dependencies = [ "testcontainers", ] @@ -6213,52 +6334,31 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.69" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" dependencies = [ - "thiserror-impl 1.0.69", -] - -[[package]] -name = "thiserror" -version = "2.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" -dependencies = [ - "thiserror-impl 2.0.12", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.100", + "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.12" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "thread_local" -version = "1.1.8" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" dependencies = [ "cfg-if", - "once_cell", ] [[package]] @@ -6274,9 +6374,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.37" +version = "0.3.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" dependencies = [ "deranged", "itoa", @@ -6289,15 +6389,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.2" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" [[package]] name = "time-macros" -version = "0.2.19" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" +checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" dependencies = [ "num-conv", "time-core", @@ -6314,9 +6414,9 @@ dependencies = [ [[package]] name = "tinystr" -version = "0.7.6" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" dependencies = [ "displaydoc", "zerovec", @@ -6334,9 +6434,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.8.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" dependencies = [ "tinyvec_macros", ] @@ -6349,20 +6449,22 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.44.1" +version = "1.47.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f382da615b842244d4b8738c82ed1275e6c5dd90c459a30941cd07080b06c91a" +checksum = "89e49afdadebb872d3145a5638b59eb0691ea23e46ca484037cfab3b76b95038" dependencies = [ "backtrace", "bytes", + "io-uring", "libc", "mio", "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "slab", + "socket2 0.6.0", "tokio-macros", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -6373,14 +6475,14 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "tokio-postgres" -version = "0.7.13" +version = "0.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c95d533c83082bb6490e0189acaa0bbeef9084e60471b696ca6988cd0541fb0" +checksum = "a156efe7fff213168257853e1dfde202eed5f487522cbbbf7d219941d753d853" dependencies = [ "async-trait", "byteorder", @@ -6391,12 +6493,12 @@ dependencies = [ "log", "parking_lot", "percent-encoding", - "phf", + "phf 0.13.1", "pin-project-lite", "postgres-protocol", "postgres-types", - "rand 0.9.0", - "socket2", + "rand 0.9.2", + "socket2 0.6.0", "tokio", "tokio-util", "whoami", @@ -6404,9 +6506,9 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.1" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" +checksum = "05f63835928ca123f1bef57abbcd23bb2ba0ac9ae1235f1e65bda0d06e7786bd" dependencies = [ "rustls", "tokio", @@ -6440,9 +6542,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.14" +version = "0.7.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b9590b93e6fcc1739458317cccd391ad3955e2bde8913edf6f95f9e65a8f034" +checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5" dependencies = [ "bytes", "futures-core", @@ -6453,34 +6555,46 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.8" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +checksum = "32f1085dec27c2b6632b04c80b3bb1b4300d6495d1e129693bdda7d91e72eec1" +dependencies = [ + "serde_core", +] [[package]] name = "toml_edit" -version = "0.22.24" +version = "0.23.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" +checksum = "f3effe7c0e86fdff4f69cdd2ccc1b96f933e24811c5441d44904e8683e27184b" dependencies = [ - "indexmap 2.8.0", + "indexmap 2.11.4", "toml_datetime", + "toml_parser", + "winnow", +] + +[[package]] +name = "toml_parser" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cf893c33be71572e0e9aa6dd15e6677937abd686b066eac3f8cd3531688a627" +dependencies = [ "winnow", ] [[package]] name = "tonic" -version = "0.12.3" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" +checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9" dependencies = [ - "async-stream", "async-trait", "axum", "base64 0.22.1", "bytes", "h2", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "http-body-util", "hyper", @@ -6488,11 +6602,11 @@ dependencies = [ "hyper-util", "percent-encoding", "pin-project", - "prost", - "socket2", + "prost 0.13.5", + "socket2 0.5.10", "tokio", "tokio-stream", - "tower 0.4.13", + "tower", "tower-layer", "tower-service", "tracing", @@ -6500,17 +6614,16 @@ dependencies = [ [[package]] name = "tower" -version = "0.4.13" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", - "indexmap 1.9.3", - "pin-project", + "indexmap 2.11.4", "pin-project-lite", - "rand 0.8.5", "slab", + "sync_wrapper", "tokio", "tokio-util", "tower-layer", @@ -6519,16 +6632,19 @@ dependencies = [ ] [[package]] -name = "tower" -version = "0.5.2" +name = "tower-http" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" dependencies = [ - "futures-core", + "bitflags 2.9.4", + "bytes", "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "iri-string", "pin-project-lite", - "sync_wrapper", - "tokio", + "tower", "tower-layer", "tower-service", ] @@ -6558,20 +6674,20 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.28" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" +checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "tracing-core" -version = "0.1.33" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" +checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" dependencies = [ "once_cell", "valuable", @@ -6590,9 +6706,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "nu-ansi-term", "sharded-slab", @@ -6625,13 +6741,9 @@ checksum = "e78122066b0cb818b8afd08f7ed22f7fdbc3e90815035726f0840d0d26c0747a" [[package]] name = "twox-hash" -version = "1.6.3" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" -dependencies = [ - "cfg-if", - "static_assertions", -] +checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" [[package]] name = "typed-arena" @@ -6639,37 +6751,23 @@ version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6af6ae20167a9ece4bcb41af5b80f8a1f1df981f6391189ce00fd257af04126a" -[[package]] -name = "typed-builder" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a06fbd5b8de54c5f7c91f6fe4cebb949be2125d7758e630bb58b1d831dbce600" -dependencies = [ - "typed-builder-macro", -] - -[[package]] -name = "typed-builder-macro" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9534daa9fd3ed0bd911d462a37f172228077e7abf18c18a5f67199d959205f8" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.100", -] - [[package]] name = "typenum" version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +[[package]] +name = "typewit" +version = "1.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8c1ae7cc0fdb8b842d65d127cb981574b0d2b249b74d1c7a2986863dc134f71" + [[package]] name = "typify" -version = "0.3.0" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e03ba3643450cfd95a1aca2e1938fef63c1c1994489337998aff4ad771f21ef8" +checksum = "7144144e97e987c94758a3017c920a027feac0799df325d6df4fc8f08d02068e" dependencies = [ "typify-impl", "typify-macro", @@ -6677,38 +6775,38 @@ dependencies = [ [[package]] name = "typify-impl" -version = "0.3.0" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bce48219a2f3154aaa2c56cbf027728b24a3c8fe0a47ed6399781de2b3f3eeaf" +checksum = "062879d46aa4c9dfe0d33b035bbaf512da192131645d05deacb7033ec8581a09" dependencies = [ "heck 0.5.0", "log", "proc-macro2", "quote", "regress", - "schemars", + "schemars 0.8.22", "semver", "serde", "serde_json", - "syn 2.0.100", - "thiserror 2.0.12", + "syn 2.0.106", + "thiserror", "unicode-ident", ] [[package]] name = "typify-macro" -version = "0.3.0" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68b5780d745920ed73c5b7447496a9b5c42ed2681a9b70859377aec423ecf02b" +checksum = "9708a3ceb6660ba3f8d2b8f0567e7d4b8b198e2b94d093b8a6077a751425de9e" dependencies = [ "proc-macro2", "quote", - "schemars", + "schemars 0.8.22", "semver", "serde", "serde_json", "serde_tokenstream", - "syn 2.0.100", + "syn 2.0.106", "typify-impl", ] @@ -6720,9 +6818,9 @@ checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" [[package]] name = "unicode-ident" -version = "1.0.16" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" +checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d" [[package]] name = "unicode-normalization" @@ -6753,15 +6851,21 @@ checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "unicode-width" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" +checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c" [[package]] name = "unindent" -version = "0.2.3" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + +[[package]] +name = "unit-prefix" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" +checksum = "323402cff2dd658f39ca17c789b502021b3f18707c91cdf22e3838e1b4023817" [[package]] name = "unsafe-libyaml" @@ -6777,9 +6881,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.4" +version = "2.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" dependencies = [ "form_urlencoded", "idna", @@ -6793,12 +6897,6 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" -[[package]] -name = "utf16_iter" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" - [[package]] name = "utf8_iter" version = "1.0.4" @@ -6813,11 +6911,11 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.16.0" +version = "1.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" dependencies = [ - "getrandom 0.3.1", + "getrandom 0.3.3", "js-sys", "serde", "wasm-bindgen", @@ -6841,15 +6939,6 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" -[[package]] -name = "wait-timeout" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11" -dependencies = [ - "libc", -] - [[package]] name = "walkdir" version = "2.5.0" @@ -6871,17 +6960,26 @@ dependencies = [ [[package]] name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" -version = "0.13.3+wasi-0.2.2" +version = "0.14.7+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" +dependencies = [ + "wasip2", +] + +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" dependencies = [ - "wit-bindgen-rt", + "wit-bindgen", ] [[package]] @@ -6892,35 +6990,36 @@ checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] name = "wasm-bindgen" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" +checksum = "c1da10c01ae9f1ae40cbfac0bac3b1e724b320abfcf52229f80b547c0d250e2d" dependencies = [ "cfg-if", "once_cell", "rustversion", "wasm-bindgen-macro", + "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" +checksum = "671c9a5a66f49d8a47345ab942e2cb93c7d1d0339065d4f8139c486121b43b19" dependencies = [ "bumpalo", "log", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.50" +version = "0.4.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" +checksum = "7e038d41e478cc73bae0ff9b36c60cff1c98b8f38f8d7e8061e79ee63608ac5c" dependencies = [ "cfg-if", "js-sys", @@ -6931,9 +7030,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" +checksum = "7ca60477e4c59f5f2986c50191cd972e3a50d8a95603bc9434501cf156a9a119" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6941,31 +7040,31 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" +checksum = "9f07d2f20d4da7b26400c9f4a0511e6e0345b040694e8a75bd41d578fa4421d7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +checksum = "bad67dc8b2a1a6e5448428adec4c3e84c43e561d8c9ee8a9e5aabeb193ec41d1" dependencies = [ "unicode-ident", ] [[package]] name = "wasm-bindgen-test" -version = "0.3.50" +version = "0.3.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66c8d5e33ca3b6d9fa3b4676d774c5778031d27a578c2b007f905acf816152c3" +checksum = "4e381134e148c1062f965a42ed1f5ee933eef2927c3f70d1812158f711d39865" dependencies = [ "js-sys", "minicov", @@ -6976,13 +7075,13 @@ dependencies = [ [[package]] name = "wasm-bindgen-test-macro" -version = "0.3.50" +version = "0.3.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17d5042cc5fa009658f9a7333ef24291b1291a25b6382dd68862a7f3b969f69b" +checksum = "b673bca3298fe582aeef8352330ecbad91849f85090805582400850f8270a2e8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] @@ -7000,9 +7099,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.77" +version = "0.3.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +checksum = "9367c417a924a74cae129e6a2ae3b47fabb1f8995595ab474029da749a8be120" dependencies = [ "js-sys", "wasm-bindgen", @@ -7018,25 +7117,13 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "which" -version = "4.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" -dependencies = [ - "either", - "home", - "once_cell", - "rustix 0.38.44", -] - [[package]] name = "whoami" -version = "1.5.2" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "372d5b87f58ec45c384ba03563b03544dc5fadc3983e434b286913f5b4a9bb6d" +checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d" dependencies = [ - "redox_syscall 0.5.8", + "libredox", "wasite", "web-sys", ] @@ -7059,11 +7146,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.9" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.0", ] [[package]] @@ -7074,103 +7161,141 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows" -version = "0.57.0" +version = "0.61.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12342cb4d8e3b046f3d80effd474a7a02447231330ef77d71daa6fbc40681143" +checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" dependencies = [ - "windows-core 0.57.0", - "windows-targets 0.52.6", + "windows-collections", + "windows-core 0.61.2", + "windows-future", + "windows-link 0.1.3", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +dependencies = [ + "windows-core 0.61.2", ] [[package]] name = "windows-core" -version = "0.52.0" +version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ - "windows-targets 0.52.6", + "windows-implement", + "windows-interface", + "windows-link 0.1.3", + "windows-result 0.3.4", + "windows-strings 0.4.2", ] [[package]] name = "windows-core" -version = "0.57.0" +version = "0.62.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2ed2439a290666cd67ecce2b0ffaad89c2a56b976b736e6ece670297897832d" +checksum = "57fe7168f7de578d2d8a05b07fd61870d2e73b4020e9f49aa00da8471723497c" dependencies = [ "windows-implement", "windows-interface", - "windows-result 0.1.2", - "windows-targets 0.52.6", + "windows-link 0.2.0", + "windows-result 0.4.0", + "windows-strings 0.5.0", +] + +[[package]] +name = "windows-future" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +dependencies = [ + "windows-core 0.61.2", + "windows-link 0.1.3", + "windows-threading", ] [[package]] name = "windows-implement" -version = "0.57.0" +version = "0.60.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "windows-interface" -version = "0.57.0" +version = "0.59.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] -name = "windows-registry" +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + +[[package]] +name = "windows-link" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65" + +[[package]] +name = "windows-numerics" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" dependencies = [ - "windows-result 0.2.0", - "windows-strings", - "windows-targets 0.52.6", + "windows-core 0.61.2", + "windows-link 0.1.3", ] [[package]] name = "windows-result" -version = "0.1.2" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-targets 0.52.6", + "windows-link 0.1.3", ] [[package]] name = "windows-result" -version = "0.2.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +checksum = "7084dcc306f89883455a206237404d3eaf961e5bd7e0f312f7c91f57eb44167f" dependencies = [ - "windows-targets 0.52.6", + "windows-link 0.2.0", ] [[package]] name = "windows-strings" -version = "0.1.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-result 0.2.0", - "windows-targets 0.52.6", + "windows-link 0.1.3", ] [[package]] -name = "windows-sys" -version = "0.48.0" +name = "windows-strings" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +checksum = "7218c655a553b0bed4426cf54b20d7ba363ef543b52d515b3e48d7fd55318dda" dependencies = [ - "windows-targets 0.48.5", + "windows-link 0.2.0", ] [[package]] @@ -7192,18 +7317,21 @@ dependencies = [ ] [[package]] -name = "windows-targets" -version = "0.48.5" +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.3", +] + +[[package]] +name = "windows-sys" +version = "0.61.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +checksum = "e201184e40b2ede64bc2ea34968b28e33622acdbbf37104f0e4a33f7abe657aa" dependencies = [ - "windows_aarch64_gnullvm 0.48.5", - "windows_aarch64_msvc 0.48.5", - "windows_i686_gnu 0.48.5", - "windows_i686_msvc 0.48.5", - "windows_x86_64_gnu 0.48.5", - "windows_x86_64_gnullvm 0.48.5", - "windows_x86_64_msvc 0.48.5", + "windows-link 0.2.0", ] [[package]] @@ -7215,7 +7343,7 @@ dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm", + "windows_i686_gnullvm 0.52.6", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", @@ -7223,10 +7351,30 @@ dependencies = [ ] [[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" +name = "windows-targets" +version = "0.53.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91" +dependencies = [ + "windows-link 0.1.3", + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", +] + +[[package]] +name = "windows-threading" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +dependencies = [ + "windows-link 0.1.3", +] [[package]] name = "windows_aarch64_gnullvm" @@ -7235,10 +7383,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" +name = "windows_aarch64_gnullvm" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" [[package]] name = "windows_aarch64_msvc" @@ -7247,10 +7395,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] -name = "windows_i686_gnu" -version = "0.48.5" +name = "windows_aarch64_msvc" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" [[package]] name = "windows_i686_gnu" @@ -7258,6 +7406,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" @@ -7265,10 +7419,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] -name = "windows_i686_msvc" -version = "0.48.5" +name = "windows_i686_gnullvm" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" [[package]] name = "windows_i686_msvc" @@ -7277,10 +7431,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" +name = "windows_i686_msvc" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" [[package]] name = "windows_x86_64_gnu" @@ -7289,10 +7443,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" +name = "windows_x86_64_gnu" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" [[package]] name = "windows_x86_64_gnullvm" @@ -7301,10 +7455,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" +name = "windows_x86_64_gnullvm" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" [[package]] name = "windows_x86_64_msvc" @@ -7313,34 +7467,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] -name = "winnow" -version = "0.7.2" +name = "windows_x86_64_msvc" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59690dea168f2198d1a3b0cac23b8063efcd11012f10ae4698f284808c8ef603" -dependencies = [ - "memchr", -] +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" [[package]] -name = "wit-bindgen-rt" -version = "0.33.0" +name = "winnow" +version = "0.7.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +checksum = "21a0236b59786fed61e2a80582dd500fe61f18b5dca67a4a067d0bc9039339cf" dependencies = [ - "bitflags 2.8.0", + "memchr", ] [[package]] -name = "write16" -version = "1.0.0" +name = "wit-bindgen" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" [[package]] name = "writeable" -version = "0.5.5" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" [[package]] name = "wyz" @@ -7353,13 +7504,12 @@ dependencies = [ [[package]] name = "xattr" -version = "1.4.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e105d177a3871454f754b33bb0ee637ecaaac997446375fd3e5d43a2ed00c909" +checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" dependencies = [ "libc", - "linux-raw-sys 0.4.15", - "rustix 0.38.44", + "rustix", ] [[package]] @@ -7377,11 +7527,17 @@ dependencies = [ "lzma-sys", ] +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "yoke" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" dependencies = [ "serde", "stable_deref_trait", @@ -7391,75 +7547,54 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", "synstructure", ] [[package]] name = "zerocopy" -version = "0.7.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" -dependencies = [ - "byteorder", - "zerocopy-derive 0.7.35", -] - -[[package]] -name = "zerocopy" -version = "0.8.18" +version = "0.8.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79386d31a42a4996e3336b0919ddb90f81112af416270cff95b5f5af22b839c2" +checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" dependencies = [ - "zerocopy-derive 0.8.18", + "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.35" +version = "0.8.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", -] - -[[package]] -name = "zerocopy-derive" -version = "0.8.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76331675d372f91bf8d17e13afbd5fe639200b73d01f0fc748bb059f9cca2db7" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.100", + "syn 2.0.106", ] [[package]] name = "zerofrom" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", "synstructure", ] @@ -7469,11 +7604,22 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +[[package]] +name = "zerotrie" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + [[package]] name = "zerovec" -version = "0.10.4" +version = "0.11.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +checksum = "e7aa2bd55086f1ab526693ecbe444205da57e25f4489879da80635a46d90e73b" dependencies = [ "yoke", "zerofrom", @@ -7482,15 +7628,21 @@ dependencies = [ [[package]] name = "zerovec-derive" -version = "0.10.3" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.106", ] +[[package]] +name = "zlib-rs" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f06ae92f42f5e5c42443fd094f245eb656abf56dd7cce9b8b263236565e00f2" + [[package]] name = "zstd" version = "0.13.3" @@ -7502,18 +7654,18 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "7.2.1" +version = "7.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.13+zstd.1.5.6" +version = "2.0.16+zstd.1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index d26446c111675..dd0b20de528af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,14 +40,17 @@ members = [ "datafusion/functions-window-common", "datafusion/optimizer", "datafusion/physical-expr", + "datafusion/physical-expr-adapter", "datafusion/physical-expr-common", "datafusion/physical-optimizer", + "datafusion/pruning", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", "datafusion/proto-common", "datafusion/proto-common/gen", "datafusion/session", + "datafusion/spark", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", @@ -73,9 +76,9 @@ license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/datafusion" # Define Minimum Supported Rust Version (MSRV) -rust-version = "1.82.0" +rust-version = "1.87.0" # Define DataFusion version -version = "46.0.1" +version = "50.2.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -86,70 +89,75 @@ version = "46.0.1" ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } -apache-avro = { version = "0.17", default-features = false } -arrow = { version = "54.2.1", features = [ +apache-avro = { version = "0.20", default-features = false } +arrow = { version = "56.2.0", features = [ "prettyprint", "chrono-tz", ] } -arrow-buffer = { version = "54.1.0", default-features = false } -arrow-flight = { version = "54.2.1", features = [ +arrow-buffer = { version = "56.2.0", default-features = false } +arrow-flight = { version = "56.2.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "54.2.0", default-features = false, features = [ +arrow-ipc = { version = "56.2.0", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "54.1.0", default-features = false } -arrow-schema = { version = "54.1.0", default-features = false } -async-trait = "0.1.88" -bigdecimal = "0.4.7" +arrow-ord = { version = "56.2.0", default-features = false } +arrow-schema = { version = "56.2.0", default-features = false } +async-trait = "0.1.89" +bigdecimal = "0.4.8" bytes = "1.10" -chrono = { version = "0.4.38", default-features = false } +chrono = { version = "0.4.42", default-features = false } criterion = "0.5.1" -ctor = "0.2.9" +ctor = "0.4.3" dashmap = "6.0.1" -datafusion = { path = "datafusion/core", version = "46.0.1", default-features = false } -datafusion-catalog = { path = "datafusion/catalog", version = "46.0.1" } -datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "46.0.1" } -datafusion-common = { path = "datafusion/common", version = "46.0.1", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "46.0.1" } -datafusion-datasource = { path = "datafusion/datasource", version = "46.0.1", default-features = false } -datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "46.0.1", default-features = false } -datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "46.0.1", default-features = false } -datafusion-datasource-json = { path = "datafusion/datasource-json", version = "46.0.1", default-features = false } -datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "46.0.1", default-features = false } -datafusion-doc = { path = "datafusion/doc", version = "46.0.1" } -datafusion-execution = { path = "datafusion/execution", version = "46.0.1" } -datafusion-expr = { path = "datafusion/expr", version = "46.0.1" } -datafusion-expr-common = { path = "datafusion/expr-common", version = "46.0.1" } -datafusion-ffi = { path = "datafusion/ffi", version = "46.0.1" } -datafusion-functions = { path = "datafusion/functions", version = "46.0.1" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "46.0.1" } -datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "46.0.1" } -datafusion-functions-nested = { path = "datafusion/functions-nested", version = "46.0.1" } -datafusion-functions-table = { path = "datafusion/functions-table", version = "46.0.1" } -datafusion-functions-window = { path = "datafusion/functions-window", version = "46.0.1" } -datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "46.0.1" } -datafusion-macros = { path = "datafusion/macros", version = "46.0.1" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "46.0.1", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "46.0.1", default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "46.0.1", default-features = false } -datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "46.0.1" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "46.0.1" } -datafusion-proto = { path = "datafusion/proto", version = "46.0.1" } -datafusion-proto-common = { path = "datafusion/proto-common", version = "46.0.1" } -datafusion-session = { path = "datafusion/session", version = "46.0.1" } -datafusion-sql = { path = "datafusion/sql", version = "46.0.1" } +datafusion = { path = "datafusion/core", version = "50.2.0", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "50.2.0" } +datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "50.2.0" } +datafusion-common = { path = "datafusion/common", version = "50.2.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "50.2.0" } +datafusion-datasource = { path = "datafusion/datasource", version = "50.2.0", default-features = false } +datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "50.2.0", default-features = false } +datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "50.2.0", default-features = false } +datafusion-datasource-json = { path = "datafusion/datasource-json", version = "50.2.0", default-features = false } +datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "50.2.0", default-features = false } +datafusion-doc = { path = "datafusion/doc", version = "50.2.0" } +datafusion-execution = { path = "datafusion/execution", version = "50.2.0", default-features = false } +datafusion-expr = { path = "datafusion/expr", version = "50.2.0", default-features = false } +datafusion-expr-common = { path = "datafusion/expr-common", version = "50.2.0" } +datafusion-ffi = { path = "datafusion/ffi", version = "50.2.0" } +datafusion-functions = { path = "datafusion/functions", version = "50.2.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "50.2.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "50.2.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "50.2.0", default-features = false } +datafusion-functions-table = { path = "datafusion/functions-table", version = "50.2.0" } +datafusion-functions-window = { path = "datafusion/functions-window", version = "50.2.0" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "50.2.0" } +datafusion-macros = { path = "datafusion/macros", version = "50.2.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "50.2.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "50.2.0", default-features = false } +datafusion-physical-expr-adapter = { path = "datafusion/physical-expr-adapter", version = "50.2.0", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "50.2.0", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "50.2.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "50.2.0" } +datafusion-proto = { path = "datafusion/proto", version = "50.2.0" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "50.2.0" } +datafusion-pruning = { path = "datafusion/pruning", version = "50.2.0" } +datafusion-session = { path = "datafusion/session", version = "50.2.0" } +datafusion-spark = { path = "datafusion/spark", version = "50.2.0" } +datafusion-sql = { path = "datafusion/sql", version = "50.2.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "50.2.0" } doc-comment = "0.3" env_logger = "0.11" futures = "0.3" -half = { version = "2.5.0", default-features = false } +half = { version = "2.7.0", default-features = false } hashbrown = { version = "0.14.5", features = ["raw"] } -indexmap = "2.8.0" +hex = { version = "0.4.3" } +indexmap = "2.11.4" itertools = "0.14" log = "^0.4" -object_store = { version = "0.11.0", default-features = false } +object_store = { version = "0.12.4", default-features = false } parking_lot = "0.12" -parquet = { version = "54.2.1", default-features = false, features = [ +parquet = { version = "56.2.0", default-features = false, features = [ "arrow", "async", "object_store", @@ -157,25 +165,70 @@ parquet = { version = "54.2.1", default-features = false, features = [ pbjson = { version = "0.7.0" } pbjson-types = "0.7" # Should match arrow-flight's version of prost. -insta = { version = "1.41.1", features = ["glob", "filters"] } +insta = { version = "1.43.2", features = ["glob", "filters"] } prost = "0.13.1" -rand = "0.8.5" +rand = "0.9" recursive = "0.1.1" -regex = "1.8" -rstest = "0.24.0" +regex = "1.11" +rstest = "0.25.0" serde_json = "1" -sqlparser = { version = "0.55.0", features = ["visitor"] } +sqlparser = { version = "0.59.0", default-features = false, features = ["std", "visitor"] } tempfile = "3" -tokio = { version = "1.44", features = ["macros", "rt", "sync"] } -url = "2.5.4" +testcontainers = { version = "0.24", features = ["default"] } +testcontainers-modules = { version = "0.12" } +tokio = { version = "1.47", features = ["macros", "rt", "sync"] } +url = "2.5.7" +[workspace.lints.clippy] +# Detects large stack-allocated futures that may cause stack overflow crashes (see threshold in clippy.toml) +large_futures = "warn" +used_underscore_binding = "warn" +or_fun_call = "warn" +unnecessary_lazy_evaluations = "warn" +uninlined_format_args = "warn" +inefficient_to_string = "warn" + +[workspace.lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = [ + 'cfg(datafusion_coop, values("tokio", "tokio_fallback", "per_stream"))', + "cfg(tarpaulin)", + "cfg(tarpaulin_include)", +] } +unused_qualifications = "deny" + +# -------------------- +# Compilation Profiles +# -------------------- +# A Cargo profile is a preset for the compiler/linker knobs that trade off: +# - Build time: how quickly code compiles and links +# - Runtime performance: how fast the resulting binaries execute +# - Binary size: how large the executables end up +# - Debuggability: how much debug information is preserved for debugging and profiling +# +# Profiles available: +# - dev: default debug build; fastest to compile, slowest to run, full debug info +# for everyday development. +# Run: cargo run +# - release: optimized build; slowest to compile, fastest to run, smallest +# binaries for public releases. +# Run: cargo run --release +# - release-nonlto: skips LTO, so it builds quicker while staying close to +# release performance. It is useful when developing performance optimizations. +# Run: cargo run --profile release-nonlto +# - profiling: inherits release optimizations but retains debug info to support +# profiling tools and flamegraphs. +# Run: cargo run --profile profiling +# - ci: derived from `dev` but disables incremental builds and strips dependency +# symbols to keep CI artifacts small and reproducible. +# Run: cargo run --profile ci +# +# If you want to optimize compilation, the `compile_profile` benchmark can be useful. +# See `benchmarks/README.md` for more details. [profile.release] codegen-units = 1 lto = true strip = true # Eliminate debug information to minimize binary size -# the release profile takes a long time to build so we can use this profile during development to save time -# cargo build --profile release-nonlto [profile.release-nonlto] codegen-units = 16 debug-assertions = false @@ -188,28 +241,20 @@ rpath = false strip = false # Retain debug info for flamegraphs [profile.ci] +debug = false inherits = "dev" incremental = false -# ci turns off debug info, etc. for dependencies to allow for smaller binaries making caching more effective +# This rule applies to every package except workspace members (dependencies +# such as `arrow` and `tokio`). It disables debug info and related features on +# dependencies so their binaries stay smaller, improving cache reuse. [profile.ci.package."*"] debug = false debug-assertions = false strip = "debuginfo" incremental = false -# release inherited profile keeping debug information and symbols -# for mem/cpu profiling [profile.profiling] inherits = "release" debug = true strip = false - -[workspace.lints.clippy] -# Detects large stack-allocated futures that may cause stack overflow crashes (see threshold in clippy.toml) -large_futures = "warn" -used_underscore_binding = "warn" - -[workspace.lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)"] } -unused_qualifications = "deny" diff --git a/NOTICE.txt b/NOTICE.txt index 21be1a20d554f..7f3c80d606c07 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -1,5 +1,5 @@ Apache DataFusion -Copyright 2019-2024 The Apache Software Foundation +Copyright 2019-2025 The Apache Software Foundation This product includes software developed at -The Apache Software Foundation (http://www.apache.org/). \ No newline at end of file +The Apache Software Foundation (http://www.apache.org/). diff --git a/README.md b/README.md index 158033d40599f..4c4b955176b2b 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ [![Open Issues][open-issues-badge]][open-issues-url] [![Discord chat][discord-badge]][discord-url] [![Linkedin][linkedin-badge]][linkedin-url] +![Crates.io MSRV][msrv-badge] [crates-badge]: https://img.shields.io/crates/v/datafusion.svg [crates-url]: https://crates.io/crates/datafusion @@ -40,6 +41,7 @@ [open-issues-url]: https://github.com/apache/datafusion/issues [linkedin-badge]: https://img.shields.io/badge/Follow-Linkedin-blue [linkedin-url]: https://www.linkedin.com/company/apache-datafusion/ +[msrv-badge]: https://img.shields.io/crates/msrv/datafusion?label=Min%20Rust%20Version [Website](https://datafusion.apache.org/) | [API Docs](https://docs.rs/datafusion/latest/datafusion/) | @@ -58,8 +60,6 @@ See [use cases] for examples. The following related subprojects target end users - [DataFusion Python](https://github.com/apache/datafusion-python/) offers a Python interface for SQL and DataFrame queries. -- [DataFusion Ray](https://github.com/apache/datafusion-ray/) provides a distributed version of DataFusion that scales - out on Ray clusters. - [DataFusion Comet](https://github.com/apache/datafusion-comet/) is an accelerator for Apache Spark based on DataFusion. @@ -118,6 +118,7 @@ Default features: - `datetime_expressions`: date and time functions such as `to_timestamp` - `encoding_expressions`: `encode` and `decode` functions - `parquet`: support for reading the [Apache Parquet] format +- `sql`: Support for sql parsing / planning - `regex_expressions`: regular expression functions, such as `regexp_match` - `unicode_expressions`: Include unicode aware functions such as `character_length` - `unparser`: enables support to reverse LogicalPlans back into SQL @@ -127,25 +128,13 @@ Optional features: - `avro`: support for reading the [Apache Avro] format - `backtrace`: include backtrace information in error messages +- `parquet_encryption`: support for using [Parquet Modular Encryption] - `pyarrow`: conversions between PyArrow and DataFusion types - `serde`: enable arrow-schema's `serde` feature [apache avro]: https://avro.apache.org/ [apache parquet]: https://parquet.apache.org/ - -## Rust Version Compatibility Policy - -The Rust toolchain releases are tracked at [Rust Versions](https://releases.rs) and follow -[semantic versioning](https://semver.org/). A Rust toolchain release can be identified -by a version string like `1.80.0`, or more generally `major.minor.patch`. - -DataFusion's supports the last 4 stable Rust minor versions released and any such versions released within the last 4 months. - -For example, given the releases `1.78.0`, `1.79.0`, `1.80.0`, `1.80.1` and `1.81.0` DataFusion will support 1.78.0, which is 3 minor versions prior to the most minor recent `1.81`. - -Note: If a Rust hotfix is released for the current MSRV, the MSRV will be updated to the specific minor version that includes all applicable hotfixes preceding other policies. - -DataFusion enforces MSRV policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) +[parquet modular encryption]: https://parquet.apache.org/docs/file-format/data-pages/encryption/ ## DataFusion API Evolution and Deprecation Guidelines @@ -153,7 +142,7 @@ Public methods in Apache DataFusion evolve over time: while we try to maintain a stable API, we also improve the API over time. As a result, we typically deprecate methods before removing them, according to the [deprecation guidelines]. -[deprecation guidelines]: https://datafusion.apache.org/library-user-guide/api-health.html +[deprecation guidelines]: https://datafusion.apache.org/contributor-guide/api-health.html ## Dependencies and `Cargo.lock` diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 063f4dac22d8d..b3fd520814dbc 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -33,6 +33,7 @@ workspace = true ci = [] default = ["mimalloc"] snmalloc = ["snmalloc-rs"] +mimalloc_extended = ["libmimalloc-sys/extended"] [dependencies] arrow = { workspace = true } @@ -40,18 +41,19 @@ datafusion = { workspace = true, default-features = true } datafusion-common = { workspace = true, default-features = true } env_logger = { workspace = true } futures = { workspace = true } +libmimalloc-sys = { version = "0.1", optional = true } log = { workspace = true } mimalloc = { version = "0.1", optional = true, default-features = false } object_store = { workspace = true } parquet = { workspace = true, default-features = true } rand = { workspace = true } -serde = { version = "1.0.219", features = ["derive"] } +regex.workspace = true +serde = { version = "1.0.228", features = ["derive"] } serde_json = { workspace = true } snmalloc-rs = { version = "0.3", optional = true } structopt = { version = "0.3", default-features = false } -test-utils = { path = "../test-utils/", version = "0.1.0" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } -tokio-util = { version = "0.7.14" } +tokio-util = { version = "0.7.16" } [dev-dependencies] datafusion-proto = { workspace = true } diff --git a/benchmarks/README.md b/benchmarks/README.md index 8acaa298bd3ad..8fed85fa02b80 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -23,7 +23,6 @@ This crate contains benchmarks based on popular public data sets and open source benchmark suites, to help with performance and scalability testing of DataFusion. - ## Other engines The benchmarks measure changes to DataFusion itself, rather than @@ -31,11 +30,11 @@ its performance against other engines. For competitive benchmarking, DataFusion is included in the benchmark setups for several popular benchmarks that compare performance with other engines. For example: -* [ClickBench] scripts are in the [ClickBench repo](https://github.com/ClickHouse/ClickBench/tree/main/datafusion) -* [H2o.ai `db-benchmark`] scripts are in [db-benchmark](https://github.com/apache/datafusion/tree/main/benchmarks/src/h2o.rs) +- [ClickBench] scripts are in the [ClickBench repo](https://github.com/ClickHouse/ClickBench/tree/main/datafusion) +- [H2o.ai `db-benchmark`] scripts are in [db-benchmark](https://github.com/apache/datafusion/tree/main/benchmarks/src/h2o.rs) -[ClickBench]: https://github.com/ClickHouse/ClickBench/tree/main -[H2o.ai `db-benchmark`]: https://github.com/h2oai/db-benchmark +[clickbench]: https://github.com/ClickHouse/ClickBench/tree/main +[h2o.ai `db-benchmark`]: https://github.com/h2oai/db-benchmark # Running the benchmarks @@ -65,31 +64,87 @@ Create / download a specific dataset (TPCH) ```shell ./bench.sh data tpch ``` + Data is placed in the `data` subdirectory. ## Running benchmarks Run benchmark for TPC-H dataset + ```shell ./bench.sh run tpch ``` + or for TPC-H dataset scale 10 + ```shell ./bench.sh run tpch10 ``` To run for specific query, for example Q21 + ```shell ./bench.sh run tpch10 21 ``` -## Select join algorithm +## Compile profile benchmark + +Generate the data required for the compile profile helper (TPC-H SF=1): + +```shell +./bench.sh data compile_profile +``` + +Run the benchmark across all default Cargo profiles (`dev`, `release`, `ci`, `release-nonlto`): + +```shell +./bench.sh run compile_profile +``` + +Limit the run to a single profile: + +```shell +./bench.sh run compile_profile dev +``` + +Or specify a subset of profiles: + +```shell +./bench.sh run compile_profile dev release +``` + +You can also invoke the helper directly if you need to customise arguments further: + +```shell +./benchmarks/compile_profile.py --profiles dev release --data /path/to/tpch_sf1 +``` + + +## Benchmark with modified configurations + +### Select join algorithm + The benchmark runs with `prefer_hash_join == true` by default, which enforces HASH join algorithm. To run TPCH benchmarks with join other than HASH: + ```shell PREFER_HASH_JOIN=false ./bench.sh run tpch ``` +### Configure with environment variables + +Any [datafusion options](https://datafusion.apache.org/user-guide/configs.html) that are provided environment variables are +also considered by the benchmarks. +The following configuration runs the TPCH benchmark with datafusion configured to _not_ repartition join keys. + +```shell +DATAFUSION_OPTIMIZER_REPARTITION_JOINS=false ./bench.sh run tpch +``` + +You might want to adjust the results location to avoid overwriting previous results. +Environment configuration that was picked up by datafusion is logged at `info` level. +To verify that datafusion picked up your configuration, run the benchmarks with `RUST_LOG=info` or higher. + ## Comparing performance of main and a branch ```shell @@ -200,6 +255,16 @@ cargo run --release --bin tpch -- convert --input ./data --output /mnt/tpch-parq Or if you want to verify and run all the queries in the benchmark, you can just run `cargo test`. +#### Sorted Conversion + +The TPCH tables generated by the dbgen utility are sorted by their first column (their primary key for most tables, the `l_orderkey` column for the `lineitem` table.) + +To preserve this sorted order information during conversion (useful for benchmarking execution on pre-sorted data) include the `--sort` flag: + +```bash +cargo run --release --bin tpch -- convert --input ./data --output /mnt/tpch-sorted-parquet --format parquet --sort +``` + ### Comparing results between runs Any `dfbench` execution with `-o ` argument will produce a @@ -251,6 +316,7 @@ This will produce output like: └──────────────┴──────────────┴──────────────┴───────────────┘ ``` + # Benchmark Runner The `dfbench` program contains subcommands to run the various @@ -289,6 +355,66 @@ FLAGS: ... ``` +# Profiling Memory Stats for each benchmark query +The `mem_profile` program wraps benchmark execution to measure memory usage statistics, such as peak RSS. It runs each benchmark query in a separate subprocess, capturing the child process’s stdout to print structured output. + +Subcommands supported by mem_profile are the subset of those in `dfbench`. +Currently supported benchmarks include: Clickbench, H2o, Imdb, SortTpch, Tpch + +Before running benchmarks, `mem_profile` automatically compiles the benchmark binary (`dfbench`) using `cargo build`. Note that the build profile used for `dfbench` is not tied to the profile used for running `mem_profile` itself. We can explicitly specify the desired build profile using the `--bench-profile` option (e.g. release-nonlto). By prebuilding the binary and running each query in a separate process, we can ensure accurate memory statistics. + +Currently, `mem_profile` only supports `mimalloc` as the memory allocator, since it relies on `mimalloc`'s API to collect memory statistics. + +Because it runs the compiled binary directly from the target directory, make sure your working directory is the top-level datafusion/ directory, where the target/ is also located. + +The benchmark subcommand (e.g., `tpch`) and all following arguments are passed directly to `dfbench`. Be sure to specify `--bench-profile` before the benchmark subcommand. + +Example: +```shell +datafusion$ cargo run --profile release-nonlto --bin mem_profile -- --bench-profile release-nonlto tpch --path benchmarks/data/tpch_sf1 --partitions 4 --format parquet +``` +Example Output: +``` +Query Time (ms) Peak RSS Peak Commit Major Page Faults +---------------------------------------------------------------- +1 503.42 283.4 MB 3.0 GB 0 +2 431.09 240.7 MB 3.0 GB 0 +3 594.28 350.1 MB 3.0 GB 0 +4 468.90 462.4 MB 3.0 GB 0 +5 653.58 385.4 MB 3.0 GB 0 +6 296.79 247.3 MB 2.0 GB 0 +7 662.32 652.4 MB 3.0 GB 0 +8 702.48 396.0 MB 3.0 GB 0 +9 774.21 611.5 MB 3.0 GB 0 +10 733.62 397.9 MB 3.0 GB 0 +11 271.71 209.6 MB 3.0 GB 0 +12 512.60 212.5 MB 2.0 GB 0 +13 507.83 381.5 MB 2.0 GB 0 +14 420.89 313.5 MB 3.0 GB 0 +15 539.97 288.0 MB 2.0 GB 0 +16 370.91 229.8 MB 3.0 GB 0 +17 758.33 467.0 MB 2.0 GB 0 +18 1112.32 638.9 MB 3.0 GB 0 +19 712.72 280.9 MB 2.0 GB 0 +20 620.64 402.9 MB 2.9 GB 0 +21 971.63 388.9 MB 2.9 GB 0 +22 404.50 164.8 MB 2.0 GB 0 +``` + +## Reported Metrics +When running benchmarks, `mem_profile` collects several memory-related statistics using the mimalloc API: + +- Peak RSS (Resident Set Size): +The maximum amount of physical memory used by the process. +This is a process-level metric collected via OS-specific mechanisms and is not mimalloc-specific. + +- Peak Commit: +The peak amount of memory committed by the allocator (i.e., total virtual memory reserved). +This is mimalloc-specific. It gives a more allocator-aware view of memory usage than RSS. + +- Major Page Faults: +The number of major page faults triggered during execution. +This metric is obtained from the operating system and is not mimalloc-specific. # Writing a new benchmark ## Creating or downloading data outside of the benchmark @@ -347,37 +473,6 @@ Your benchmark should create and use an instance of `BenchmarkRun` defined in `b The output of `dfbench` help includes a description of each benchmark, which is reproduced here for convenience. -## Cancellation - -Test performance of cancelling queries. - -Queries in DataFusion should stop executing "quickly" after they are -cancelled (the output stream is dropped). - -The queries are executed on a synthetic dataset generated during -the benchmark execution that is an anonymized version of a -real-world data set. - -The query is an anonymized version of a real-world query, and the -test starts the query then cancels it and reports how long it takes -for the runtime to fully exit. - -Example output: - -``` -Using 7 files found on disk -Starting to load data into in-memory object store -Done loading data into in-memory object store -in main, sleeping -Starting spawned -Creating logical plan... -Creating physical plan... -Executing physical plan... -Getting results... -cancelling thread -done dropping runtime in 83.531417ms -``` - ## ClickBench The ClickBench[1] benchmarks are widely cited in the industry and @@ -397,7 +492,7 @@ logs. Example -dfbench parquet-filter --path ./data --scale-factor 1.0 +dfbench parquet-filter --path ./data --scale-factor 1.0 generates the synthetic dataset at `./data/logs.parquet`. The size of the dataset can be controlled through the `size_factor` @@ -429,6 +524,7 @@ Iteration 2 returned 1781686 rows in 1947 ms ``` ## Sort + Test performance of sorting large datasets This test sorts a a synthetic dataset generated during the @@ -445,24 +541,46 @@ Test performance of end-to-end sort SQL queries. (While the `Sort` benchmark foc Sort integration benchmark runs whole table sort queries on TPCH `lineitem` table, with different characteristics. For example, different number of sort keys, different sort key cardinality, different number of payload columns, etc. +If the TPCH tables have been converted as sorted on their first column (see [Sorted Conversion](#sorted-conversion)), you can use the `--sorted` flag to indicate that the input data is pre-sorted, allowing DataFusion to leverage that order during query execution. + +Additionally, an optional `--limit` flag is available for the sort benchmark. When specified, this flag appends a `LIMIT n` clause to the SQL query, effectively converting the query into a TopK query. Combining the `--sorted` and `--limit` options enables benchmarking of TopK queries on pre-sorted inputs. + See [`sort_tpch.rs`](src/sort_tpch.rs) for more details. ### Sort TPCH Benchmark Example Runs + 1. Run all queries with default setting: + ```bash - cargo run --release --bin dfbench -- sort-tpch -p '....../datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' + cargo run --release --bin dfbench -- sort-tpch -p './datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' ``` 2. Run a specific query: + ```bash - cargo run --release --bin dfbench -- sort-tpch -p '....../datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' --query 2 + cargo run --release --bin dfbench -- sort-tpch -p './datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' --query 2 ``` -3. Run all queries with `bench.sh` script: +3. Run all queries as TopK queries on presorted data: + +```bash + cargo run --release --bin dfbench -- sort-tpch --sorted --limit 10 -p './datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' +``` + +4. Run all queries with `bench.sh` script: + ```bash ./bench.sh run sort_tpch ``` +### TopK TPCH + +In addition, topk_tpch is available from the bench.sh script: + +```bash +./bench.sh run topk_tpch +``` + ## IMDB Run Join Order Benchmark (JOB) on IMDB dataset. @@ -496,59 +614,78 @@ External aggregation benchmarks run several aggregation queries with different m This benchmark is inspired by [DuckDB's external aggregation paper](https://hannes.muehleisen.org/publications/icde2024-out-of-core-kuiper-boncz-muehleisen.pdf), specifically Section VI. ### External Aggregation Example Runs + 1. Run all queries with predefined memory limits: + ```bash # Under 'benchmarks/' directory cargo run --release --bin external_aggr -- benchmark -n 4 --iterations 3 -p '....../data/tpch_sf1' -o '/tmp/aggr.json' ``` 2. Run a query with specific memory limit: + ```bash cargo run --release --bin external_aggr -- benchmark -n 4 --iterations 3 -p '....../data/tpch_sf1' -o '/tmp/aggr.json' --query 1 --memory-limit 30M ``` 3. Run all queries with `bench.sh` script: + ```bash ./bench.sh data external_aggr ./bench.sh run external_aggr ``` +## h2o.ai benchmarks + +The h2o.ai benchmarks are a set of performance tests for groupby and join operations. Beyond the standard h2o benchmark, there is also an extended benchmark for window functions. These benchmarks use synthetic data with configurable sizes (small: 1e7 rows, medium: 1e8 rows, big: 1e9 rows) to evaluate DataFusion's performance across different data scales. + +Reference: -## h2o benchmarks for groupby +- [H2O AI Benchmark](https://duckdb.org/2023/04/14/h2oai.html) +- [Extended window benchmark](https://duckdb.org/2024/06/26/benchmarks-over-time.html#window-functions-benchmark) + +### h2o benchmarks for groupby + +#### Generate data for h2o benchmarks -### Generate data for h2o benchmarks There are three options for generating data for h2o benchmarks: `small`, `medium`, and `big`. The data is generated in the `data` directory. 1. Generate small data (1e7 rows) + ```bash ./bench.sh data h2o_small ``` - 2. Generate medium data (1e8 rows) + ```bash ./bench.sh data h2o_medium ``` - 3. Generate large data (1e9 rows) + ```bash ./bench.sh data h2o_big ``` -### Run h2o benchmarks +#### Run h2o benchmarks + There are three options for running h2o benchmarks: `small`, `medium`, and `big`. + 1. Run small data benchmark + ```bash ./bench.sh run h2o_small ``` 2. Run medium data benchmark + ```bash ./bench.sh run h2o_medium ``` 3. Run large data benchmark + ```bash ./bench.sh run h2o_big ``` @@ -556,53 +693,114 @@ There are three options for running h2o benchmarks: `small`, `medium`, and `big` 4. Run a specific query with a specific data path For example, to run query 1 with the small data generated above: + ```bash cargo run --release --bin dfbench -- h2o --path ./benchmarks/data/h2o/G1_1e7_1e7_100_0.csv --query 1 ``` -## h2o benchmarks for join +### h2o benchmarks for join -### Generate data for h2o benchmarks There are three options for generating data for h2o benchmarks: `small`, `medium`, and `big`. The data is generated in the `data` directory. -1. Generate small data (4 table files, the largest is 1e7 rows) +Here is a example to generate `small` dataset and run the benchmark. To run other +dataset size configuration, change the command similar to the previous example. + ```bash +# Generate small data (4 table files, the largest is 1e7 rows) ./bench.sh data h2o_small_join + +# Run the benchmark +./bench.sh run h2o_small_join ``` +To run a specific query with a specific join data paths, the data paths are including 4 table files. + +For example, to run query 1 with the small data generated above: -2. Generate medium data (4 table files, the largest is 1e8 rows) ```bash -./bench.sh data h2o_medium_join +cargo run --release --bin dfbench -- h2o --join-paths ./benchmarks/data/h2o/J1_1e7_NA_0.csv,./benchmarks/data/h2o/J1_1e7_1e1_0.csv,./benchmarks/data/h2o/J1_1e7_1e4_0.csv,./benchmarks/data/h2o/J1_1e7_1e7_NA.csv --queries-path ./benchmarks/queries/h2o/join.sql --query 1 ``` -3. Generate large data (4 table files, the largest is 1e9 rows) +### Extended h2o benchmarks for window + +This benchmark extends the h2o benchmark suite to evaluate window function performance. H2o window benchmark uses the same dataset as the h2o join benchmark. There are three options for generating data for h2o benchmarks: `small`, `medium`, and `big`. + +Here is a example to generate `small` dataset and run the benchmark. To run other +dataset size configuration, change the command similar to the previous example. + ```bash -./bench.sh data h2o_big_join +# Generate small data +./bench.sh data h2o_small_window + +# Run the benchmark +./bench.sh run h2o_small_window ``` -### Run h2o benchmarks -There are three options for running h2o benchmarks: `small`, `medium`, and `big`. -1. Run small data benchmark +To run a specific query with a specific window data paths, the data paths are including 4 table files (the same as h2o-join dataset) + +For example, to run query 1 with the small data generated above: + ```bash -./bench.sh run h2o_small_join +cargo run --release --bin dfbench -- h2o --join-paths ./benchmarks/data/h2o/J1_1e7_NA_0.csv,./benchmarks/data/h2o/J1_1e7_1e1_0.csv,./benchmarks/data/h2o/J1_1e7_1e4_0.csv,./benchmarks/data/h2o/J1_1e7_1e7_NA.csv --queries-path ./benchmarks/queries/h2o/window.sql --query 1 ``` -2. Run medium data benchmark +# Micro-Benchmarks + +## Nested Loop Join + +This benchmark focuses on the performance of queries with nested loop joins, minimizing other overheads such as scanning data sources or evaluating predicates. + +Different queries are included to test nested loop joins under various workloads. + +### Example Run + ```bash -./bench.sh run h2o_medium_join +# No need to generate data: this benchmark uses table function `range()` as the data source + +./bench.sh run nlj ``` -3. Run large data benchmark +## Hash Join + +This benchmark focuses on the performance of queries with nested hash joins, minimizing other overheads such as scanning data sources or evaluating predicates. + +Several queries are included to test hash joins under various workloads. + +### Example Run + ```bash -./bench.sh run h2o_big_join +# No need to generate data: this benchmark uses table function `range()` as the data source + +./bench.sh run hj ``` -4. Run a specific query with a specific join data paths, the data paths are including 4 table files. +## Cancellation + +Test performance of cancelling queries. + +Queries in DataFusion should stop executing "quickly" after they are +cancelled (the output stream is dropped). + +The queries are executed on a synthetic dataset generated during +the benchmark execution that is an anonymized version of a +real-world data set. + +The query is an anonymized version of a real-world query, and the +test starts the query then cancels it and reports how long it takes +for the runtime to fully exit. + +Example output: -For example, to run query 1 with the small data generated above: -```bash -cargo run --release --bin dfbench -- h2o --join-paths ./benchmarks/data/h2o/J1_1e7_NA_0.csv,./benchmarks/data/h2o/J1_1e7_1e1_0.csv,./benchmarks/data/h2o/J1_1e7_1e4_0.csv,./benchmarks/data/h2o/J1_1e7_1e7_NA.csv --queries-path ./benchmarks/queries/h2o/join.sql --query 1 ``` -[1]: http://www.tpc.org/tpch/ -[2]: https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page +Using 7 files found on disk +Starting to load data into in-memory object store +Done loading data into in-memory object store +in main, sleeping +Starting spawned +Creating logical plan... +Creating physical plan... +Executing physical plan... +Getting results... +cancelling thread +done dropping runtime in 83.531417ms +``` diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 5d3ad3446ddb9..dbfd319dd9ad4 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -28,6 +28,12 @@ set -e # https://stackoverflow.com/questions/59895/how-do-i-get-the-directory-where-a-bash-script-is-located-from-within-the-script SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +# Execute command and also print it, for debugging purposes +debug_run() { + set -x + "$@" + set +x +} # Set Defaults COMMAND= @@ -43,61 +49,96 @@ usage() { Orchestrates running benchmarks against DataFusion checkouts Usage: -$0 data [benchmark] [query] -$0 run [benchmark] +$0 data [benchmark] +$0 run [benchmark] [query] $0 compare +$0 compare_detail $0 venv -********** +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ Examples: -********** +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # Create the datasets for all benchmarks in $DATA_DIR ./bench.sh data # Run the 'tpch' benchmark on the datafusion checkout in /source/datafusion DATAFUSION_DIR=/source/datafusion ./bench.sh run tpch -********** -* Commands -********** -data: Generates or downloads data needed for benchmarking -run: Runs the named benchmark -compare: Compares results from benchmark runs -venv: Creates new venv (unless already exists) and installs compare's requirements into it - -********** -* Benchmarks -********** +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +Commands +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +data: Generates or downloads data needed for benchmarking +run: Runs the named benchmark +compare: Compares fastest results from benchmark runs +compare_detail: Compares minimum, average (±stddev), and maximum results from benchmark runs +venv: Creates new venv (unless already exists) and installs compare's requirements into it + +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +Benchmarks +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# Run all of the following benchmarks all(default): Data/Run/Compare for all benchmarks + +# TPC-H Benchmarks tpch: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), single parquet file per table, hash join +tpch_csv: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), single csv file per table, hash join tpch_mem: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), query from memory tpch10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), single parquet file per table, hash join +tpch_csv10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), single csv file per table, hash join tpch_mem10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), query from memory -cancellation: How long cancelling a query takes -parquet: Benchmark of parquet reader's filtering speed -sort: Benchmark of sorting speed -sort_tpch: Benchmark of sorting speed for end-to-end sort queries on TPCH dataset + +# Extended TPC-H Benchmarks +sort_tpch: Benchmark of sorting speed for end-to-end sort queries on TPC-H dataset (SF=1) +sort_tpch10: Benchmark of sorting speed for end-to-end sort queries on TPC-H dataset (SF=10) +topk_tpch: Benchmark of top-k (sorting with limit) queries on TPC-H dataset (SF=1) +external_aggr: External aggregation benchmark on TPC-H dataset (SF=1) + +# ClickBench Benchmarks clickbench_1: ClickBench queries against a single parquet file -clickbench_partitioned: ClickBench queries against a partitioned (100 files) parquet +clickbench_partitioned: ClickBench queries against partitioned (100 files) parquet +clickbench_pushdown: ClickBench queries against partitioned (100 files) parquet w/ filter_pushdown enabled clickbench_extended: ClickBench \"inspired\" queries against a single parquet (DataFusion specific) -external_aggr: External aggregation benchmark -h2o_small: h2oai benchmark with small dataset (1e7 rows) for groupby, default file format is csv -h2o_medium: h2oai benchmark with medium dataset (1e8 rows) for groupby, default file format is csv -h2o_big: h2oai benchmark with large dataset (1e9 rows) for groupby, default file format is csv -h2o_small_join: h2oai benchmark with small dataset (1e7 rows) for join, default file format is csv -h2o_medium_join: h2oai benchmark with medium dataset (1e8 rows) for join, default file format is csv -h2o_big_join: h2oai benchmark with large dataset (1e9 rows) for join, default file format is csv + +# H2O.ai Benchmarks (Group By, Join, Window) +h2o_small: h2oai benchmark with small dataset (1e7 rows) for groupby, default file format is csv +h2o_medium: h2oai benchmark with medium dataset (1e8 rows) for groupby, default file format is csv +h2o_big: h2oai benchmark with large dataset (1e9 rows) for groupby, default file format is csv +h2o_small_join: h2oai benchmark with small dataset (1e7 rows) for join, default file format is csv +h2o_medium_join: h2oai benchmark with medium dataset (1e8 rows) for join, default file format is csv +h2o_big_join: h2oai benchmark with large dataset (1e9 rows) for join, default file format is csv +h2o_small_window: Extended h2oai benchmark with small dataset (1e7 rows) for window, default file format is csv +h2o_medium_window: Extended h2oai benchmark with medium dataset (1e8 rows) for window, default file format is csv +h2o_big_window: Extended h2oai benchmark with large dataset (1e9 rows) for window, default file format is csv +h2o_small_parquet: h2oai benchmark with small dataset (1e7 rows) for groupby, file format is parquet +h2o_medium_parquet: h2oai benchmark with medium dataset (1e8 rows) for groupby, file format is parquet +h2o_big_parquet: h2oai benchmark with large dataset (1e9 rows) for groupby, file format is parquet +h2o_small_join_parquet: h2oai benchmark with small dataset (1e7 rows) for join, file format is parquet +h2o_medium_join_parquet: h2oai benchmark with medium dataset (1e8 rows) for join, file format is parquet +h2o_big_join_parquet: h2oai benchmark with large dataset (1e9 rows) for join, file format is parquet +h2o_small_window_parquet: Extended h2oai benchmark with small dataset (1e7 rows) for window, file format is parquet +h2o_medium_window_parquet: Extended h2oai benchmark with medium dataset (1e8 rows) for window, file format is parquet +h2o_big_window_parquet: Extended h2oai benchmark with large dataset (1e9 rows) for window, file format is parquet + +# Join Order Benchmark (IMDB) imdb: Join Order Benchmark (JOB) using the IMDB dataset converted to parquet -********** -* Supported Configuration (Environment Variables) -********** +# Micro-Benchmarks (specific operators and features) +cancellation: How long cancelling a query takes +nlj: Benchmark for simple nested loop joins, testing various join scenarios +hj: Benchmark for simple hash joins, testing various join scenarios +compile_profile: Compile and execute TPC-H across selected Cargo profiles, reporting timing and binary size + + +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +Supported Configuration (Environment Variables) +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ DATA_DIR directory to store datasets CARGO_COMMAND command that runs the benchmark binary DATAFUSION_DIR directory to use (default $DATAFUSION_DIR) RESULTS_NAME folder where the benchmark files are stored PREFER_HASH_JOIN Prefer hash join algorithm (default true) VENV_PATH Python venv to use for compare and venv commands (default ./venv, override by /bin/activate) +DATAFUSION_* Set the given datafusion configuration " exit 1 } @@ -159,6 +200,7 @@ main() { data_clickbench_1 data_clickbench_partitioned data_imdb + # nlj uses range() function, no data generation needed ;; tpch) data_tpch "1" @@ -180,6 +222,9 @@ main() { clickbench_partitioned) data_clickbench_partitioned ;; + clickbench_pushdown) + data_clickbench_partitioned # same data as clickbench_partitioned + ;; clickbench_extended) data_clickbench_1 ;; @@ -204,6 +249,44 @@ main() { h2o_big_join) data_h2o_join "BIG" "CSV" ;; + # h2o window benchmark uses the same data as the h2o join + h2o_small_window) + data_h2o_join "SMALL" "CSV" + ;; + h2o_medium_window) + data_h2o_join "MEDIUM" "CSV" + ;; + h2o_big_window) + data_h2o_join "BIG" "CSV" + ;; + h2o_small_parquet) + data_h2o "SMALL" "PARQUET" + ;; + h2o_medium_parquet) + data_h2o "MEDIUM" "PARQUET" + ;; + h2o_big_parquet) + data_h2o "BIG" "PARQUET" + ;; + h2o_small_join_parquet) + data_h2o_join "SMALL" "PARQUET" + ;; + h2o_medium_join_parquet) + data_h2o_join "MEDIUM" "PARQUET" + ;; + h2o_big_join_parquet) + data_h2o_join "BIG" "PARQUET" + ;; + # h2o window benchmark uses the same data as the h2o join + h2o_small_window_parquet) + data_h2o_join "SMALL" "PARQUET" + ;; + h2o_medium_window_parquet) + data_h2o_join "MEDIUM" "PARQUET" + ;; + h2o_big_window_parquet) + data_h2o_join "BIG" "PARQUET" + ;; external_aggr) # same data as for tpch data_tpch "1" @@ -212,6 +295,25 @@ main() { # same data as for tpch data_tpch "1" ;; + sort_tpch10) + # same data as for tpch10 + data_tpch "10" + ;; + topk_tpch) + # same data as for tpch + data_tpch "1" + ;; + nlj) + # nlj uses range() function, no data generation needed + echo "NLJ benchmark does not require data generation" + ;; + hj) + # hj uses range() function, no data generation needed + echo "HJ benchmark does not require data generation" + ;; + compile_profile) + data_tpch "1" + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for data generation" usage @@ -221,6 +323,18 @@ main() { run) # Parse positional parameters BENCHMARK=${ARG2:-"${BENCHMARK}"} + EXTRA_ARGS=("${POSITIONAL_ARGS[@]:2}") + PROFILE_ARGS=() + QUERY="" + QUERY_ARG="" + if [ "$BENCHMARK" = "compile_profile" ]; then + PROFILE_ARGS=("${EXTRA_ARGS[@]}") + else + QUERY=${EXTRA_ARGS[0]} + if [ -n "$QUERY" ]; then + QUERY_ARG="--query ${QUERY}" + fi + fi BRANCH_NAME=$(cd "${DATAFUSION_DIR}" && git rev-parse --abbrev-ref HEAD) BRANCH_NAME=${BRANCH_NAME//\//_} # mind blowing syntax to replace / with _ RESULTS_NAME=${RESULTS_NAME:-"${BRANCH_NAME}"} @@ -230,6 +344,11 @@ main() { echo "DataFusion Benchmark Script" echo "COMMAND: ${COMMAND}" echo "BENCHMARK: ${BENCHMARK}" + if [ "$BENCHMARK" = "compile_profile" ]; then + echo "PROFILES: ${PROFILE_ARGS[*]:-All}" + else + echo "QUERY: ${QUERY:-All}" + fi echo "DATAFUSION_DIR: ${DATAFUSION_DIR}" echo "BRANCH_NAME: ${BRANCH_NAME}" echo "DATA_DIR: ${DATA_DIR}" @@ -244,15 +363,16 @@ main() { mkdir -p "${DATA_DIR}" case "$BENCHMARK" in all) - run_tpch "1" + run_tpch "1" "parquet" + run_tpch "1" "csv" run_tpch_mem "1" - run_tpch "10" + run_tpch "10" "parquet" + run_tpch "10" "csv" run_tpch_mem "10" run_cancellation - run_parquet - run_sort run_clickbench_1 run_clickbench_partitioned + run_clickbench_pushdown run_clickbench_extended run_h2o "SMALL" "PARQUET" "groupby" run_h2o "MEDIUM" "PARQUET" "groupby" @@ -262,15 +382,23 @@ main() { run_h2o_join "BIG" "PARQUET" "join" run_imdb run_external_aggr + run_nlj + run_hj ;; tpch) - run_tpch "1" + run_tpch "1" "parquet" + ;; + tpch_csv) + run_tpch "1" "csv" ;; tpch_mem) run_tpch_mem "1" ;; tpch10) - run_tpch "10" + run_tpch "10" "parquet" + ;; + tpch_csv10) + run_tpch "10" "csv" ;; tpch_mem10) run_tpch_mem "10" @@ -278,18 +406,15 @@ main() { cancellation) run_cancellation ;; - parquet) - run_parquet - ;; - sort) - run_sort - ;; clickbench_1) run_clickbench_1 ;; clickbench_partitioned) run_clickbench_partitioned ;; + clickbench_pushdown) + run_clickbench_pushdown + ;; clickbench_extended) run_clickbench_extended ;; @@ -314,11 +439,63 @@ main() { h2o_big_join) run_h2o_join "BIG" "CSV" "join" ;; + h2o_small_window) + run_h2o_window "SMALL" "CSV" "window" + ;; + h2o_medium_window) + run_h2o_window "MEDIUM" "CSV" "window" + ;; + h2o_big_window) + run_h2o_window "BIG" "CSV" "window" + ;; + h2o_small_parquet) + run_h2o "SMALL" "PARQUET" + ;; + h2o_medium_parquet) + run_h2o "MEDIUM" "PARQUET" + ;; + h2o_big_parquet) + run_h2o "BIG" "PARQUET" + ;; + h2o_small_join_parquet) + run_h2o_join "SMALL" "PARQUET" + ;; + h2o_medium_join_parquet) + run_h2o_join "MEDIUM" "PARQUET" + ;; + h2o_big_join_parquet) + run_h2o_join "BIG" "PARQUET" + ;; + # h2o window benchmark uses the same data as the h2o join + h2o_small_window_parquet) + run_h2o_window "SMALL" "PARQUET" + ;; + h2o_medium_window_parquet) + run_h2o_window "MEDIUM" "PARQUET" + ;; + h2o_big_window_parquet) + run_h2o_window "BIG" "PARQUET" + ;; external_aggr) run_external_aggr ;; sort_tpch) - run_sort_tpch + run_sort_tpch "1" + ;; + sort_tpch10) + run_sort_tpch "10" + ;; + topk_tpch) + run_topk_tpch + ;; + nlj) + run_nlj + ;; + hj) + run_hj + ;; + compile_profile) + run_compile_profile "${PROFILE_ARGS[@]}" ;; *) echo "Error: unknown benchmark '$BENCHMARK' for run" @@ -331,6 +508,9 @@ main() { compare) compare_benchmarks "$ARG2" "$ARG3" ;; + compare_detail) + compare_benchmarks "$ARG2" "$ARG3" "--detailed" + ;; venv) setup_venv ;; @@ -396,6 +576,17 @@ data_tpch() { $CARGO_COMMAND --bin tpch -- convert --input "${TPCH_DIR}" --output "${TPCH_DIR}" --format parquet popd > /dev/null fi + + # Create 'csv' files from tbl + FILE="${TPCH_DIR}/csv/supplier" + if test -d "${FILE}"; then + echo " csv files exist ($FILE exists)." + else + echo " creating csv files using benchmark binary ..." + pushd "${SCRIPT_DIR}" > /dev/null + $CARGO_COMMAND --bin tpch -- convert --input "${TPCH_DIR}" --output "${TPCH_DIR}/csv" --format csv + popd > /dev/null + fi } # Runs the tpch benchmark @@ -410,12 +601,9 @@ run_tpch() { RESULTS_FILE="${RESULTS_DIR}/tpch_sf${SCALE_FACTOR}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch benchmark..." - # Optional query filter to run specific query - QUERY=$([ -n "$ARG3" ] && echo "--query $ARG3" || echo "") - # debug the target command - set -x - $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" $QUERY - set +x + + FORMAT=$2 + debug_run $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format ${FORMAT} -o "${RESULTS_FILE}" ${QUERY_ARG} } # Runs the tpch in memory @@ -430,13 +618,22 @@ run_tpch_mem() { RESULTS_FILE="${RESULTS_DIR}/tpch_mem_sf${SCALE_FACTOR}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch_mem benchmark..." - # Optional query filter to run specific query - QUERY=$([ -n "$ARG3" ] && echo "--query $ARG3" || echo "") - # debug the target command - set -x # -m means in memory - $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" -m --format parquet -o "${RESULTS_FILE}" $QUERY - set +x + debug_run $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" -m --format parquet -o "${RESULTS_FILE}" ${QUERY_ARG} +} + +# Runs the compile profile benchmark helper +run_compile_profile() { + local profiles=("$@") + local runner="${SCRIPT_DIR}/compile_profile.py" + local data_path="${DATA_DIR}/tpch_sf1" + + echo "Running compile profile benchmark..." + local cmd=(python3 "${runner}" --data "${data_path}") + if [ ${#profiles[@]} -gt 0 ]; then + cmd+=(--profiles "${profiles[@]}") + fi + debug_run "${cmd[@]}" } # Runs the cancellation benchmark @@ -444,23 +641,7 @@ run_cancellation() { RESULTS_FILE="${RESULTS_DIR}/cancellation.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running cancellation benchmark..." - $CARGO_COMMAND --bin dfbench -- cancellation --iterations 5 --path "${DATA_DIR}/cancellation" -o "${RESULTS_FILE}" -} - -# Runs the parquet filter benchmark -run_parquet() { - RESULTS_FILE="${RESULTS_DIR}/parquet.json" - echo "RESULTS_FILE: ${RESULTS_FILE}" - echo "Running parquet filter benchmark..." - $CARGO_COMMAND --bin parquet -- filter --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" -} - -# Runs the sort benchmark -run_sort() { - RESULTS_FILE="${RESULTS_DIR}/sort.json" - echo "RESULTS_FILE: ${RESULTS_FILE}" - echo "Running sort benchmark..." - $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- cancellation --iterations 5 --path "${DATA_DIR}/cancellation" -o "${RESULTS_FILE}" } @@ -514,23 +695,33 @@ run_clickbench_1() { RESULTS_FILE="${RESULTS_DIR}/clickbench_1.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (1 file) benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries" -o "${RESULTS_FILE}" ${QUERY_ARG} } - # Runs the clickbench benchmark with the partitioned parquet files + # Runs the clickbench benchmark with the partitioned parquet dataset (100 files) run_clickbench_partitioned() { RESULTS_FILE="${RESULTS_DIR}/clickbench_partitioned.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (partitioned, 100 files) benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries" -o "${RESULTS_FILE}" ${QUERY_ARG} +} + + + # Runs the clickbench benchmark with the partitioned parquet files and filter_pushdown enabled +run_clickbench_pushdown() { + RESULTS_FILE="${RESULTS_DIR}/clickbench_pushdown.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running clickbench (partitioned, 100 files) benchmark with pushdown_filters=true, reorder_filters=true..." + debug_run $CARGO_COMMAND --bin dfbench -- clickbench --pushdown --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries" -o "${RESULTS_FILE}" ${QUERY_ARG} } + # Runs the clickbench "extended" benchmark with a single large parquet file run_clickbench_extended() { RESULTS_FILE="${RESULTS_DIR}/clickbench_extended.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (1 file) extended benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended.sql" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended" -o "${RESULTS_FILE}" ${QUERY_ARG} } # Downloads the csv.gz files IMDB datasets from Peter Boncz's homepage(one of the JOB paper authors) @@ -645,7 +836,7 @@ run_imdb() { RESULTS_FILE="${RESULTS_DIR}/imdb.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running imdb benchmark..." - $CARGO_COMMAND --bin imdb -- benchmark datafusion --iterations 5 --path "${IMDB_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin imdb -- benchmark datafusion --iterations 5 --path "${IMDB_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" ${QUERY_ARG} } data_h2o() { @@ -800,6 +991,7 @@ data_h2o_join() { deactivate } +# Runner for h2o groupby benchmark run_h2o() { # Default values for size and data format SIZE=${1:-"SMALL"} @@ -835,14 +1027,16 @@ run_h2o() { QUERY_FILE="${SCRIPT_DIR}/queries/h2o/${RUN_Type}.sql" # Run the benchmark using the dynamically constructed file path and query file - $CARGO_COMMAND --bin dfbench -- h2o \ + debug_run $CARGO_COMMAND --bin dfbench -- h2o \ --iterations 3 \ --path "${H2O_DIR}/${FILE_NAME}" \ --queries-path "${QUERY_FILE}" \ - -o "${RESULTS_FILE}" + -o "${RESULTS_FILE}" \ + ${QUERY_ARG} } -run_h2o_join() { +# Utility function to run h2o join/window benchmark +h2o_runner() { # Default values for size and data format SIZE=${1:-"SMALL"} DATA_FORMAT=${2:-"CSV"} @@ -851,10 +1045,10 @@ run_h2o_join() { # Data directory and results file path H2O_DIR="${DATA_DIR}/h2o" - RESULTS_FILE="${RESULTS_DIR}/h2o_join.json" + RESULTS_FILE="${RESULTS_DIR}/h2o_${RUN_Type}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" - echo "Running h2o join benchmark..." + echo "Running h2o ${RUN_Type} benchmark..." # Set the file name based on the size case "$SIZE" in @@ -882,14 +1076,25 @@ run_h2o_join() { ;; esac - # Set the query file name based on the RUN_Type + # Set the query file name based on the RUN_Type QUERY_FILE="${SCRIPT_DIR}/queries/h2o/${RUN_Type}.sql" - $CARGO_COMMAND --bin dfbench -- h2o \ + debug_run $CARGO_COMMAND --bin dfbench -- h2o \ --iterations 3 \ --join-paths "${H2O_DIR}/${X_TABLE_FILE_NAME},${H2O_DIR}/${SMALL_TABLE_FILE_NAME},${H2O_DIR}/${MEDIUM_TABLE_FILE_NAME},${H2O_DIR}/${LARGE_TABLE_FILE_NAME}" \ --queries-path "${QUERY_FILE}" \ - -o "${RESULTS_FILE}" + -o "${RESULTS_FILE}" \ + ${QUERY_ARG} +} + +# Runners for h2o join benchmark +run_h2o_join() { + h2o_runner "$1" "$2" "join" +} + +# Runners for h2o join benchmark +run_h2o_window() { + h2o_runner "$1" "$2" "window" } # Runs the external aggregation benchmark @@ -905,17 +1110,48 @@ run_external_aggr() { # number-of-partitions), and by default `--partitions` is set to number of # CPU cores, we set a constant number of partitions to prevent this # benchmark to fail on some machines. - $CARGO_COMMAND --bin external_aggr -- benchmark --partitions 4 --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin external_aggr -- benchmark --partitions 4 --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" ${QUERY_ARG} } # Runs the sort integration benchmark run_sort_tpch() { - TPCH_DIR="${DATA_DIR}/tpch_sf1" - RESULTS_FILE="${RESULTS_DIR}/sort_tpch.json" + SCALE_FACTOR=$1 + if [ -z "$SCALE_FACTOR" ] ; then + echo "Internal error: Scale factor not specified" + exit 1 + fi + TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" + RESULTS_FILE="${RESULTS_DIR}/sort_tpch${SCALE_FACTOR}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running sort tpch benchmark..." - $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" ${QUERY_ARG} +} + +# Runs the sort tpch integration benchmark with limit 100 (topk) +run_topk_tpch() { + TPCH_DIR="${DATA_DIR}/tpch_sf1" + RESULTS_FILE="${RESULTS_DIR}/run_topk_tpch.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running topk tpch benchmark..." + + $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" --limit 100 ${QUERY_ARG} +} + +# Runs the nlj benchmark +run_nlj() { + RESULTS_FILE="${RESULTS_DIR}/nlj.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running nlj benchmark..." + debug_run $CARGO_COMMAND --bin dfbench -- nlj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG} +} + +# Runs the hj benchmark +run_hj() { + RESULTS_FILE="${RESULTS_DIR}/hj.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running hj benchmark..." + debug_run $CARGO_COMMAND --bin dfbench -- hj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG} } @@ -923,6 +1159,8 @@ compare_benchmarks() { BASE_RESULTS_DIR="${SCRIPT_DIR}/results" BRANCH1="$1" BRANCH2="$2" + OPTS="$3" + if [ -z "$BRANCH1" ] ; then echo " not specified. Available branches:" ls -1 "${BASE_RESULTS_DIR}" @@ -943,7 +1181,7 @@ compare_benchmarks() { echo "--------------------" echo "Benchmark ${BENCH}" echo "--------------------" - PATH=$VIRTUAL_ENV/bin:$PATH python3 "${SCRIPT_DIR}"/compare.py "${RESULTS_FILE1}" "${RESULTS_FILE2}" + PATH=$VIRTUAL_ENV/bin:$PATH python3 "${SCRIPT_DIR}"/compare.py $OPTS "${RESULTS_FILE1}" "${RESULTS_FILE2}" else echo "Note: Skipping ${RESULTS_FILE1} as ${RESULTS_FILE2} does not exist" fi diff --git a/benchmarks/compare.py b/benchmarks/compare.py index 4b609c744d503..7e51a38a92c2b 100755 --- a/benchmarks/compare.py +++ b/benchmarks/compare.py @@ -18,7 +18,9 @@ from __future__ import annotations +import argparse import json +import math from dataclasses import dataclass from typing import Dict, List, Any from pathlib import Path @@ -47,6 +49,7 @@ class QueryRun: query: int iterations: List[QueryResult] start_time: int + success: bool = True @classmethod def load_from(cls, data: Dict[str, Any]) -> QueryRun: @@ -54,17 +57,57 @@ def load_from(cls, data: Dict[str, Any]) -> QueryRun: query=data["query"], iterations=[QueryResult(**iteration) for iteration in data["iterations"]], start_time=data["start_time"], + success=data.get("success", True), ) @property - def execution_time(self) -> float: + def min_execution_time(self) -> float: assert len(self.iterations) >= 1 - # Use minimum execution time to account for variations / other - # things the system was doing return min(iteration.elapsed for iteration in self.iterations) + @property + def max_execution_time(self) -> float: + assert len(self.iterations) >= 1 + + return max(iteration.elapsed for iteration in self.iterations) + + + @property + def mean_execution_time(self) -> float: + assert len(self.iterations) >= 1 + + total = sum(iteration.elapsed for iteration in self.iterations) + return total / len(self.iterations) + + + @property + def stddev_execution_time(self) -> float: + assert len(self.iterations) >= 1 + + mean = self.mean_execution_time + squared_diffs = [(iteration.elapsed - mean) ** 2 for iteration in self.iterations] + variance = sum(squared_diffs) / len(self.iterations) + return math.sqrt(variance) + + def execution_time_report(self, detailed = False) -> tuple[float, str]: + if detailed: + mean_execution_time = self.mean_execution_time + return ( + mean_execution_time, + f"{self.min_execution_time:.2f} / {mean_execution_time :.2f} ±{self.stddev_execution_time:.2f} / {self.max_execution_time:.2f} ms" + ) + else: + # Use minimum execution time to account for variations / other + # things the system was doing + min_execution_time = self.min_execution_time + return ( + min_execution_time, + f"{min_execution_time :.2f} ms" + ) + + @dataclass class Context: benchmark_version: str @@ -106,6 +149,7 @@ def compare( baseline_path: Path, comparison_path: Path, noise_threshold: float, + detailed: bool, ) -> None: baseline = BenchmarkRun.load_from_file(baseline_path) comparison = BenchmarkRun.load_from_file(comparison_path) @@ -125,16 +169,34 @@ def compare( faster_count = 0 slower_count = 0 no_change_count = 0 + failure_count = 0 total_baseline_time = 0 total_comparison_time = 0 for baseline_result, comparison_result in zip(baseline.queries, comparison.queries): assert baseline_result.query == comparison_result.query - - total_baseline_time += baseline_result.execution_time - total_comparison_time += comparison_result.execution_time - - change = comparison_result.execution_time / baseline_result.execution_time + + base_failed = not baseline_result.success + comp_failed = not comparison_result.success + # If a query fails, its execution time is excluded from the performance comparison + if base_failed or comp_failed: + change_text = "incomparable" + failure_count += 1 + table.add_row( + f"Q{baseline_result.query}", + "FAIL" if base_failed else baseline_result.execution_time_report(detailed)[1], + "FAIL" if comp_failed else comparison_result.execution_time_report(detailed)[1], + change_text, + ) + continue + + baseline_value, baseline_text = baseline_result.execution_time_report(detailed) + comparison_value, comparison_text = comparison_result.execution_time_report(detailed) + + total_baseline_time += baseline_value + total_comparison_time += comparison_value + + change = comparison_value / baseline_value if (1.0 - noise_threshold) <= change <= (1.0 + noise_threshold): change_text = "no change" @@ -148,16 +210,20 @@ def compare( table.add_row( f"Q{baseline_result.query}", - f"{baseline_result.execution_time:.2f}ms", - f"{comparison_result.execution_time:.2f}ms", + baseline_text, + comparison_text, change_text, ) console.print(table) # Calculate averages - avg_baseline_time = total_baseline_time / len(baseline.queries) - avg_comparison_time = total_comparison_time / len(comparison.queries) + avg_baseline_time = 0.0 + avg_comparison_time = 0.0 + if len(baseline.queries) - failure_count > 0: + avg_baseline_time = total_baseline_time / (len(baseline.queries) - failure_count) + if len(comparison.queries) - failure_count > 0: + avg_comparison_time = total_comparison_time / (len(comparison.queries) - failure_count) # Summary table summary_table = Table(show_header=True, header_style="bold magenta") @@ -171,6 +237,7 @@ def compare( summary_table.add_row("Queries Faster", str(faster_count)) summary_table.add_row("Queries Slower", str(slower_count)) summary_table.add_row("Queries with No Change", str(no_change_count)) + summary_table.add_row("Queries with Failure", str(failure_count)) console.print(summary_table) @@ -193,10 +260,16 @@ def main() -> None: default=0.05, help="The threshold for statistically insignificant results (+/- %5).", ) + compare_parser.add_argument( + "--detailed", + action=argparse.BooleanOptionalAction, + default=False, + help="Show detailed result comparison instead of minimum runtime.", + ) options = parser.parse_args() - compare(options.baseline_path, options.comparison_path, options.noise_threshold) + compare(options.baseline_path, options.comparison_path, options.noise_threshold, options.detailed) diff --git a/benchmarks/compile_profile.py b/benchmarks/compile_profile.py new file mode 100644 index 0000000000000..ae51de94937bf --- /dev/null +++ b/benchmarks/compile_profile.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Compile profile benchmark runner for DataFusion. + +Builds the `tpch` benchmark binary with several Cargo profiles (e.g. `--release` or `--profile ci`), runs the full TPC-H suite against the Parquet data under `benchmarks/data/tpch_sf1`, and reports compile time, execution time, and resulting +binary size. + +See `benchmarks/README.md` for usages. +""" + +from __future__ import annotations + +import argparse +import os +import subprocess +import sys +import time +from pathlib import Path +from typing import Iterable, NamedTuple + +REPO_ROOT = Path(__file__).resolve().parents[1] +DEFAULT_DATA_DIR = REPO_ROOT / "benchmarks" / "data" / "tpch_sf1" +DEFAULT_ITERATIONS = 1 +DEFAULT_FORMAT = "parquet" +DEFAULT_PARTITIONS: int | None = None +TPCH_BINARY = "tpch.exe" if os.name == "nt" else "tpch" +PROFILE_TARGET_DIR = { + "dev": "debug", + "release": "release", + "ci": "ci", + "release-nonlto": "release-nonlto", +} + + +class ProfileResult(NamedTuple): + profile: str + compile_seconds: float + run_seconds: float + binary_bytes: int + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--profiles", + nargs="+", + default=list(PROFILE_TARGET_DIR.keys()), + help="Cargo profiles to test (default: dev release ci release-nonlto)", + ) + parser.add_argument( + "--data", + type=Path, + default=DEFAULT_DATA_DIR, + help="Path to TPCH dataset (default: benchmarks/data/tpch_sf1)", + ) + return parser.parse_args() + + +def timed_run(command: Iterable[str]) -> float: + start = time.perf_counter() + try: + subprocess.run(command, cwd=REPO_ROOT, check=True) + except subprocess.CalledProcessError as exc: + raise RuntimeError(f"command failed: {' '.join(command)}") from exc + return time.perf_counter() - start + + +def cargo_build(profile: str) -> float: + if profile == "dev": + command = ["cargo", "build", "--bin", "tpch"] + else: + command = ["cargo", "build", "--profile", profile, "--bin", "tpch"] + return timed_run(command) + + +def cargo_clean(profile: str) -> None: + command = ["cargo", "clean", "--profile", profile] + try: + subprocess.run(command, cwd=REPO_ROOT, check=True) + except subprocess.CalledProcessError as exc: + raise RuntimeError(f"failed to clean cargo artifacts for profile '{profile}'") from exc + + +def run_benchmark(profile: str, data_path: Path) -> float: + binary_dir = PROFILE_TARGET_DIR.get(profile) + if not binary_dir: + raise ValueError(f"unknown profile '{profile}'") + binary_path = REPO_ROOT / "target" / binary_dir / TPCH_BINARY + if not binary_path.exists(): + raise FileNotFoundError(f"compiled binary not found at {binary_path}") + + command = [ + str(binary_path), + "benchmark", + "datafusion", + "--iterations", + str(DEFAULT_ITERATIONS), + "--path", + str(data_path), + "--format", + DEFAULT_FORMAT, + ] + if DEFAULT_PARTITIONS is not None: + command.extend(["--partitions", str(DEFAULT_PARTITIONS)]) + env = os.environ.copy() + env.setdefault("RUST_LOG", "warn") + + start = time.perf_counter() + try: + subprocess.run(command, cwd=REPO_ROOT, env=env, check=True) + except subprocess.CalledProcessError as exc: + raise RuntimeError(f"benchmark failed for profile '{profile}'") from exc + return time.perf_counter() - start + + +def binary_size(profile: str) -> int: + binary_dir = PROFILE_TARGET_DIR[profile] + binary_path = REPO_ROOT / "target" / binary_dir / TPCH_BINARY + return binary_path.stat().st_size + + +def human_time(seconds: float) -> str: + return f"{seconds:6.2f}s" + + +def human_size(size: int) -> str: + value = float(size) + for unit in ("B", "KB", "MB", "GB", "TB"): + if value < 1024 or unit == "TB": + return f"{value:6.1f}{unit}" + value /= 1024 + return f"{value:6.1f}TB" + + +def main() -> None: + args = parse_args() + data_path = args.data.resolve() + if not data_path.exists(): + print(f"Data directory not found: {data_path}", file=sys.stderr) + sys.exit(1) + + results: list[ProfileResult] = [] + for profile in args.profiles: + print(f"\n=== Profile: {profile} ===") + print("Cleaning previous build artifacts...") + cargo_clean(profile) + compile_seconds = cargo_build(profile) + run_seconds = run_benchmark(profile, data_path) + size_bytes = binary_size(profile) + results.append(ProfileResult(profile, compile_seconds, run_seconds, size_bytes)) + + print("\nSummary") + header = f"{'Profile':<15}{'Compile':>12}{'Run':>12}{'Size':>12}" + print(header) + print("-" * len(header)) + for result in results: + print( + f"{result.profile:<15}{human_time(result.compile_seconds):>12}" + f"{human_time(result.run_seconds):>12}{human_size(result.binary_bytes):>12}" + ) + +if __name__ == "__main__": + main() diff --git a/benchmarks/queries/clickbench/README.md b/benchmarks/queries/clickbench/README.md index 2032427e1ef23..877ea0e0c3192 100644 --- a/benchmarks/queries/clickbench/README.md +++ b/benchmarks/queries/clickbench/README.md @@ -5,17 +5,18 @@ This directory contains queries for the ClickBench benchmark https://benchmark.c ClickBench is focused on aggregation and filtering performance (though it has no Joins) ## Files: -* `queries.sql` - Actual ClickBench queries, downloaded from the [ClickBench repository] -* `extended.sql` - "Extended" DataFusion specific queries. -[ClickBench repository]: https://github.com/ClickHouse/ClickBench/blob/main/datafusion/queries.sql +- `queries/*.sql` - Actual ClickBench queries, downloaded from the [ClickBench repository](https://raw.githubusercontent.com/ClickHouse/ClickBench/main/datafusion/queries.sql) and split by the `update_queries.sh` script. +- `extended/*.sql` - "Extended" DataFusion specific queries. -## "Extended" Queries +[clickbench repository]: https://github.com/ClickHouse/ClickBench/blob/main/datafusion/queries.sql + +## "Extended" Queries The "extended" queries are not part of the official ClickBench benchmark. Instead they are used to test other DataFusion features that are not covered by -the standard benchmark. Each description below is for the corresponding line in -`extended.sql` (line 1 is `Q0`, line 2 is `Q1`, etc.) +the standard benchmark. Each description below is for the corresponding file in +`extended` ### Q0: Data Exploration @@ -25,7 +26,7 @@ the standard benchmark. Each description below is for the corresponding line in distinct string columns. ```sql -SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") +SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; ``` @@ -35,7 +36,6 @@ FROM hits; **Important Query Properties**: multiple `COUNT DISTINCT`s. All three are small strings (length either 1 or 2). - ```sql SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; @@ -43,21 +43,20 @@ FROM hits; ### Q2: Top 10 analysis -**Question**: "Find the top 10 "browser country" by number of distinct "social network"s, -including the distinct counts of "hit color", "browser language", +**Question**: "Find the top 10 "browser country" by number of distinct "social network"s, +including the distinct counts of "hit color", "browser language", and "social action"." **Important Query Properties**: GROUP BY short, string, multiple `COUNT DISTINCT`s. There are several small strings (length either 1 or 2). ```sql SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") -FROM hits -GROUP BY 1 -ORDER BY 2 DESC +FROM hits +GROUP BY 1 +ORDER BY 2 DESC LIMIT 10; ``` - ### Q3: What is the income distribution for users in specific regions **Question**: "What regions and social networks have the highest variance of parameter price?" @@ -65,17 +64,17 @@ LIMIT 10; **Important Query Properties**: STDDEV and VAR aggregation functions, GROUP BY multiple small ints ```sql -SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") -FROM 'hits.parquet' -GROUP BY "SocialSourceNetworkID", "RegionID" +SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") +FROM 'hits.parquet' +GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL -ORDER BY s DESC +ORDER BY s DESC LIMIT 10; ``` ### Q4: Response start time distribution analysis (median) -**Question**: Find the WatchIDs with the highest median "ResponseStartTiming" without Java enabled +**Question**: Find the WatchIDs with the highest median "ResponseStartTiming" without Java enabled **Important Query Properties**: MEDIAN, functions, high cardinality grouping that skips intermediate aggregation @@ -102,17 +101,16 @@ Results look like +-------------+---------------------+---+------+------+------+ ``` - ### Q5: Response start time distribution analysis (p95) -**Question**: Find the WatchIDs with the highest p95 "ResponseStartTiming" without Java enabled +**Question**: Find the WatchIDs with the highest p95 "ResponseStartTiming" without Java enabled **Important Query Properties**: APPROX_PERCENTILE_CONT, functions, high cardinality grouping that skips intermediate aggregation Note this query is somewhat synthetic as "WatchID" is almost unique (there are a few duplicates) ```sql -SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT("ResponseStartTiming", 0.95) tp95, MAX("ResponseStartTiming") tmax +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY "ResponseStartTiming") tp95, MAX("ResponseStartTiming") tmax FROM 'hits.parquet' WHERE "JavaEnable" = 0 -- filters to 32M of 100M rows GROUP BY "ClientIP", "WatchID" @@ -122,6 +120,7 @@ LIMIT 10; ``` Results look like + ``` +-------------+---------------------+---+------+------+------+ | ClientIP | WatchID | c | tmin | tp95 | tmax | @@ -132,6 +131,7 @@ Results look like ``` ### Q6: How many social shares meet complex multi-stage filtering criteria? + **Question**: What is the count of sharing actions from iPhone mobile users on specific social networks, within common timezones, participating in seasonal campaigns, with high screen resolutions and closely matched UTM parameters? **Important Query Properties**: Simple filter with high-selectivity, Costly string matching, A large number of filters with high overhead are positioned relatively later in the process @@ -150,20 +150,89 @@ WHERE -- Stage 3: Heavy computations (expensive) AND regexp_match("Referer", '\/campaign\/(spring|summer)_promo') IS NOT NULL -- Find campaign-specific referrers - AND CASE - WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' - THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT - ELSE 0 + AND CASE + WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' + THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT + ELSE 0 END > 1920 -- Extract and validate resolution parameter - AND levenshtein("UTMSource", "UTMCampaign") < 3 -- Verify UTM parameter similarity + AND levenshtein(CAST("UTMSource" AS STRING), CAST("UTMCampaign" AS STRING)) < 3 -- Verify UTM parameter similarity ``` + Result is empty,Since it has already been filtered by `"SocialAction" = 'share'`. +### Q7: Device Resolution and Refresh Behavior Analysis + +**Question**: Identify the top 10 WatchIDs with the highest resolution range (min/max "ResolutionWidth") and total refresh count ("IsRefresh") in descending WatchID order + +**Important Query Properties**: Primitive aggregation functions, group by single primitive column, high cardinality grouping + +```sql +SELECT "WatchID", MIN("ResolutionWidth") as wmin, MAX("ResolutionWidth") as wmax, SUM("IsRefresh") as srefresh +FROM hits +GROUP BY "WatchID" +ORDER BY "WatchID" DESC +LIMIT 10; +``` + +Results look like + +``` ++---------------------+------+------+----------+ +| WatchID | wmin | wmax | srefresh | ++---------------------+------+------+----------+ +| 9223372033328793741 | 1368 | 1368 | 0 | +| 9223371941779979288 | 1479 | 1479 | 0 | +| 9223371906781104763 | 1638 | 1638 | 0 | +| 9223371803397398692 | 1990 | 1990 | 0 | +| 9223371799215233959 | 1638 | 1638 | 0 | +| 9223371785975219972 | 0 | 0 | 0 | +| 9223371776706839366 | 1368 | 1368 | 0 | +| 9223371740707848038 | 1750 | 1750 | 0 | +| 9223371715190479830 | 1368 | 1368 | 0 | +| 9223371620124912624 | 1828 | 1828 | 0 | ++---------------------+------+------+----------+ +``` + +### Q8: Average Latency and Response Time Analysis + +**Question**: Which combinations of operating system, region, and user agent exhibit the highest average latency? For each of these combinations, also report the average response time. + +**Important Query Properties**: Multiple average of Duration, high cardinality grouping + +```sql +SELECT "RegionID", "UserAgent", "OS", AVG(to_timestamp("ResponseEndTiming")-to_timestamp("ResponseStartTiming")) as avg_response_time, AVG(to_timestamp("ResponseEndTiming")-to_timestamp("ConnectTiming")) as avg_latency +FROM hits +GROUP BY "RegionID", "UserAgent", "OS" +ORDER BY avg_latency DESC +LIMIT 10; +``` + +Results look like + +``` ++----------+-----------+-----+------------------------------------------+------------------------------------------+ +| RegionID | UserAgent | OS | avg_response_time | avg_latency | ++----------+-----------+-----+------------------------------------------+------------------------------------------+ +| 22934 | 5 | 126 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 22735 | 82 | 74 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 21687 | 32 | 49 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 18518 | 82 | 77 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 14006 | 7 | 126 | 0 days 7 hours 58 mins 20.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 9803 | 82 | 77 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 107108 | 82 | 77 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 111626 | 7 | 44 | 0 days 7 hours 23 mins 12.500000000 secs | 0 days 8 hours 0 mins 47.000000000 secs | +| 17716 | 56 | 44 | 0 days 6 hours 48 mins 44.500000000 secs | 0 days 7 hours 35 mins 47.000000000 secs | +| 13631 | 82 | 45 | 0 days 7 hours 23 mins 1.000000000 secs | 0 days 7 hours 23 mins 1.000000000 secs | ++----------+-----------+-----+------------------------------------------+------------------------------------------+ +10 row(s) fetched. +Elapsed 30.195 seconds. +``` ## Data Notes Here are some interesting statistics about the data used in the queries Max length of `"SearchPhrase"` is 1113 characters + ```sql > select min(length("SearchPhrase")) as "SearchPhrase_len_min", max(length("SearchPhrase")) "SearchPhrase_len_max" from 'hits.parquet' limit 10; +----------------------+----------------------+ @@ -173,8 +242,8 @@ Max length of `"SearchPhrase"` is 1113 characters +----------------------+----------------------+ ``` - Here is the schema of the data + ```sql > describe 'hits.parquet'; +-----------------------+-----------+-------------+ diff --git a/benchmarks/queries/clickbench/extended.sql b/benchmarks/queries/clickbench/extended.sql deleted file mode 100644 index ef3a409c9c024..0000000000000 --- a/benchmarks/queries/clickbench/extended.sql +++ /dev/null @@ -1,7 +0,0 @@ -SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; -SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; -SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; -SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") FROM hits GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL ORDER BY s DESC LIMIT 10; -SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, MEDIAN("ResponseStartTiming") tmed, MAX("ResponseStartTiming") tmax FROM hits WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tmed DESC LIMIT 10; -SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT("ResponseStartTiming", 0.95) tp95, MAX("ResponseStartTiming") tmax FROM 'hits' WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tp95 DESC LIMIT 10; -SELECT COUNT(*) AS ShareCount FROM hits WHERE "IsMobile" = 1 AND "MobilePhoneModel" LIKE 'iPhone%' AND "SocialAction" = 'share' AND "SocialSourceNetworkID" IN (5, 12) AND "ClientTimeZone" BETWEEN -5 AND 5 AND regexp_match("Referer", '\/campaign\/(spring|summer)_promo') IS NOT NULL AND CASE WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT ELSE 0 END > 1920 AND levenshtein("UTMSource", "UTMCampaign") < 3; \ No newline at end of file diff --git a/benchmarks/queries/clickbench/extended/q0.sql b/benchmarks/queries/clickbench/extended/q0.sql new file mode 100644 index 0000000000000..cb826e5f947e9 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q0.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; diff --git a/benchmarks/queries/clickbench/extended/q1.sql b/benchmarks/queries/clickbench/extended/q1.sql new file mode 100644 index 0000000000000..7862423787d85 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q1.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; diff --git a/benchmarks/queries/clickbench/extended/q2.sql b/benchmarks/queries/clickbench/extended/q2.sql new file mode 100644 index 0000000000000..de2be79885792 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q2.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/extended/q3.sql b/benchmarks/queries/clickbench/extended/q3.sql new file mode 100644 index 0000000000000..f52990b9843a5 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q3.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") FROM hits GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL ORDER BY s DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/extended/q4.sql b/benchmarks/queries/clickbench/extended/q4.sql new file mode 100644 index 0000000000000..5865129db6425 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q4.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, MEDIAN("ResponseStartTiming") tmed, MAX("ResponseStartTiming") tmax FROM hits WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tmed DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/extended/q5.sql b/benchmarks/queries/clickbench/extended/q5.sql new file mode 100644 index 0000000000000..18d3e01c82c4b --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q5.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY "ResponseStartTiming") tp95, MAX("ResponseStartTiming") tmax FROM 'hits' WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tp95 DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/extended/q6.sql b/benchmarks/queries/clickbench/extended/q6.sql new file mode 100644 index 0000000000000..0a6467b8898aa --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q6.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT COUNT(*) AS ShareCount FROM hits WHERE "IsMobile" = 1 AND "MobilePhoneModel" LIKE 'iPhone%' AND "SocialAction" = 'share' AND "SocialSourceNetworkID" IN (5, 12) AND "ClientTimeZone" BETWEEN -5 AND 5 AND regexp_match("Referer", '\/campaign\/(spring|summer)_promo') IS NOT NULL AND CASE WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT ELSE 0 END > 1920 AND levenshtein(CAST("UTMSource" AS STRING), CAST("UTMCampaign" AS STRING)) < 3; diff --git a/benchmarks/queries/clickbench/extended/q7.sql b/benchmarks/queries/clickbench/extended/q7.sql new file mode 100644 index 0000000000000..ddaff7f8804f5 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q7.sql @@ -0,0 +1 @@ +SELECT "WatchID", MIN("ResolutionWidth") as wmin, MAX("ResolutionWidth") as wmax, SUM("IsRefresh") as srefresh FROM hits GROUP BY "WatchID" ORDER BY "WatchID" DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries.sql b/benchmarks/queries/clickbench/queries.sql deleted file mode 100644 index ba70f8a1c6b29..0000000000000 --- a/benchmarks/queries/clickbench/queries.sql +++ /dev/null @@ -1,43 +0,0 @@ -SELECT COUNT(*) FROM hits; -SELECT COUNT(*) FROM hits WHERE "AdvEngineID" <> 0; -SELECT SUM("AdvEngineID"), COUNT(*), AVG("ResolutionWidth") FROM hits; -SELECT AVG("UserID") FROM hits; -SELECT COUNT(DISTINCT "UserID") FROM hits; -SELECT COUNT(DISTINCT "SearchPhrase") FROM hits; -SELECT MIN("EventDate"::INT::DATE), MAX("EventDate"::INT::DATE) FROM hits; -SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC; -SELECT "RegionID", COUNT(DISTINCT "UserID") AS u FROM hits GROUP BY "RegionID" ORDER BY u DESC LIMIT 10; -SELECT "RegionID", SUM("AdvEngineID"), COUNT(*) AS c, AVG("ResolutionWidth"), COUNT(DISTINCT "UserID") FROM hits GROUP BY "RegionID" ORDER BY c DESC LIMIT 10; -SELECT "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhoneModel" ORDER BY u DESC LIMIT 10; -SELECT "MobilePhone", "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhone", "MobilePhoneModel" ORDER BY u DESC LIMIT 10; -SELECT "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT "SearchPhrase", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY u DESC LIMIT 10; -SELECT "SearchEngineID", "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT "UserID", COUNT(*) FROM hits GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; -SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; -SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" LIMIT 10; -SELECT "UserID", extract(minute FROM to_timestamp_seconds("EventTime")) AS m, "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", m, "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; -SELECT "UserID" FROM hits WHERE "UserID" = 435090932899640449; -SELECT COUNT(*) FROM hits WHERE "URL" LIKE '%google%'; -SELECT "SearchPhrase", MIN("URL"), COUNT(*) AS c FROM hits WHERE "URL" LIKE '%google%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT "SearchPhrase", MIN("URL"), MIN("Title"), COUNT(*) AS c, COUNT(DISTINCT "UserID") FROM hits WHERE "Title" LIKE '%Google%' AND "URL" NOT LIKE '%.google.%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT * FROM hits WHERE "URL" LIKE '%google%' ORDER BY "EventTime" LIMIT 10; -SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "EventTime" LIMIT 10; -SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "SearchPhrase" LIMIT 10; -SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "EventTime", "SearchPhrase" LIMIT 10; -SELECT "CounterID", AVG(length("URL")) AS l, COUNT(*) AS c FROM hits WHERE "URL" <> '' GROUP BY "CounterID" HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; -SELECT REGEXP_REPLACE("Referer", '^https?://(?:www\.)?([^/]+)/.*$', '\1') AS k, AVG(length("Referer")) AS l, COUNT(*) AS c, MIN("Referer") FROM hits WHERE "Referer" <> '' GROUP BY k HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; -SELECT SUM("ResolutionWidth"), SUM("ResolutionWidth" + 1), SUM("ResolutionWidth" + 2), SUM("ResolutionWidth" + 3), SUM("ResolutionWidth" + 4), SUM("ResolutionWidth" + 5), SUM("ResolutionWidth" + 6), SUM("ResolutionWidth" + 7), SUM("ResolutionWidth" + 8), SUM("ResolutionWidth" + 9), SUM("ResolutionWidth" + 10), SUM("ResolutionWidth" + 11), SUM("ResolutionWidth" + 12), SUM("ResolutionWidth" + 13), SUM("ResolutionWidth" + 14), SUM("ResolutionWidth" + 15), SUM("ResolutionWidth" + 16), SUM("ResolutionWidth" + 17), SUM("ResolutionWidth" + 18), SUM("ResolutionWidth" + 19), SUM("ResolutionWidth" + 20), SUM("ResolutionWidth" + 21), SUM("ResolutionWidth" + 22), SUM("ResolutionWidth" + 23), SUM("ResolutionWidth" + 24), SUM("ResolutionWidth" + 25), SUM("ResolutionWidth" + 26), SUM("ResolutionWidth" + 27), SUM("ResolutionWidth" + 28), SUM("ResolutionWidth" + 29), SUM("ResolutionWidth" + 30), SUM("ResolutionWidth" + 31), SUM("ResolutionWidth" + 32), SUM("ResolutionWidth" + 33), SUM("ResolutionWidth" + 34), SUM("ResolutionWidth" + 35), SUM("ResolutionWidth" + 36), SUM("ResolutionWidth" + 37), SUM("ResolutionWidth" + 38), SUM("ResolutionWidth" + 39), SUM("ResolutionWidth" + 40), SUM("ResolutionWidth" + 41), SUM("ResolutionWidth" + 42), SUM("ResolutionWidth" + 43), SUM("ResolutionWidth" + 44), SUM("ResolutionWidth" + 45), SUM("ResolutionWidth" + 46), SUM("ResolutionWidth" + 47), SUM("ResolutionWidth" + 48), SUM("ResolutionWidth" + 49), SUM("ResolutionWidth" + 50), SUM("ResolutionWidth" + 51), SUM("ResolutionWidth" + 52), SUM("ResolutionWidth" + 53), SUM("ResolutionWidth" + 54), SUM("ResolutionWidth" + 55), SUM("ResolutionWidth" + 56), SUM("ResolutionWidth" + 57), SUM("ResolutionWidth" + 58), SUM("ResolutionWidth" + 59), SUM("ResolutionWidth" + 60), SUM("ResolutionWidth" + 61), SUM("ResolutionWidth" + 62), SUM("ResolutionWidth" + 63), SUM("ResolutionWidth" + 64), SUM("ResolutionWidth" + 65), SUM("ResolutionWidth" + 66), SUM("ResolutionWidth" + 67), SUM("ResolutionWidth" + 68), SUM("ResolutionWidth" + 69), SUM("ResolutionWidth" + 70), SUM("ResolutionWidth" + 71), SUM("ResolutionWidth" + 72), SUM("ResolutionWidth" + 73), SUM("ResolutionWidth" + 74), SUM("ResolutionWidth" + 75), SUM("ResolutionWidth" + 76), SUM("ResolutionWidth" + 77), SUM("ResolutionWidth" + 78), SUM("ResolutionWidth" + 79), SUM("ResolutionWidth" + 80), SUM("ResolutionWidth" + 81), SUM("ResolutionWidth" + 82), SUM("ResolutionWidth" + 83), SUM("ResolutionWidth" + 84), SUM("ResolutionWidth" + 85), SUM("ResolutionWidth" + 86), SUM("ResolutionWidth" + 87), SUM("ResolutionWidth" + 88), SUM("ResolutionWidth" + 89) FROM hits; -SELECT "SearchEngineID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "ClientIP" ORDER BY c DESC LIMIT 10; -SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; -SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; -SELECT "URL", COUNT(*) AS c FROM hits GROUP BY "URL" ORDER BY c DESC LIMIT 10; -SELECT 1, "URL", COUNT(*) AS c FROM hits GROUP BY 1, "URL" ORDER BY c DESC LIMIT 10; -SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3 ORDER BY c DESC LIMIT 10; -SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "URL" <> '' GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10; -SELECT "Title", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "Title" <> '' GROUP BY "Title" ORDER BY PageViews DESC LIMIT 10; -SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "IsLink" <> 0 AND "IsDownload" = 0 GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; -SELECT "TraficSourceID", "SearchEngineID", "AdvEngineID", CASE WHEN ("SearchEngineID" = 0 AND "AdvEngineID" = 0) THEN "Referer" ELSE '' END AS Src, "URL" AS Dst, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 GROUP BY "TraficSourceID", "SearchEngineID", "AdvEngineID", Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; -SELECT "URLHash", "EventDate"::INT::DATE, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "TraficSourceID" IN (-1, 6) AND "RefererHash" = 3594120000172545465 GROUP BY "URLHash", "EventDate"::INT::DATE ORDER BY PageViews DESC LIMIT 10 OFFSET 100; -SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "DontCountHits" = 0 AND "URLHash" = 2868770270353813622 GROUP BY "WindowClientWidth", "WindowClientHeight" ORDER BY PageViews DESC LIMIT 10 OFFSET 10000; -SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-14' AND "EventDate"::INT::DATE <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000; diff --git a/benchmarks/queries/clickbench/queries/q0.sql b/benchmarks/queries/clickbench/queries/q0.sql new file mode 100644 index 0000000000000..35f2b32ed4863 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q0.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 + +-- set datafusion.execution.parquet.binary_as_string = true +SELECT COUNT(*) FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q1.sql b/benchmarks/queries/clickbench/queries/q1.sql new file mode 100644 index 0000000000000..0bee959ec3c7d --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q1.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT COUNT(*) FROM hits WHERE "AdvEngineID" <> 0; diff --git a/benchmarks/queries/clickbench/queries/q10.sql b/benchmarks/queries/clickbench/queries/q10.sql new file mode 100644 index 0000000000000..0f9114803fecf --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q10.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhoneModel" ORDER BY u DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q11.sql b/benchmarks/queries/clickbench/queries/q11.sql new file mode 100644 index 0000000000000..bed8bb210e130 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q11.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "MobilePhone", "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhone", "MobilePhoneModel" ORDER BY u DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q12.sql b/benchmarks/queries/clickbench/queries/q12.sql new file mode 100644 index 0000000000000..8cf09c0049f3d --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q12.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q13.sql b/benchmarks/queries/clickbench/queries/q13.sql new file mode 100644 index 0000000000000..ef6583c8d1886 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q13.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchPhrase", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY u DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q14.sql b/benchmarks/queries/clickbench/queries/q14.sql new file mode 100644 index 0000000000000..dd267146edec5 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q14.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchEngineID", "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "SearchPhrase" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q15.sql b/benchmarks/queries/clickbench/queries/q15.sql new file mode 100644 index 0000000000000..721d924cb9b95 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q15.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "UserID", COUNT(*) FROM hits GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q16.sql b/benchmarks/queries/clickbench/queries/q16.sql new file mode 100644 index 0000000000000..389725d58d7a3 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q16.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q17.sql b/benchmarks/queries/clickbench/queries/q17.sql new file mode 100644 index 0000000000000..be9976a01d7a4 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q17.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q18.sql b/benchmarks/queries/clickbench/queries/q18.sql new file mode 100644 index 0000000000000..d649f1edfe2a4 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q18.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "UserID", extract(minute FROM to_timestamp_seconds("EventTime")) AS m, "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", m, "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q19.sql b/benchmarks/queries/clickbench/queries/q19.sql new file mode 100644 index 0000000000000..8212a765730a3 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q19.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "UserID" FROM hits WHERE "UserID" = 435090932899640449; diff --git a/benchmarks/queries/clickbench/queries/q2.sql b/benchmarks/queries/clickbench/queries/q2.sql new file mode 100644 index 0000000000000..bcdfad84ec10f --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q2.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT SUM("AdvEngineID"), COUNT(*), AVG("ResolutionWidth") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q20.sql b/benchmarks/queries/clickbench/queries/q20.sql new file mode 100644 index 0000000000000..a7e488c2abcd8 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q20.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT COUNT(*) FROM hits WHERE "URL" LIKE '%google%'; diff --git a/benchmarks/queries/clickbench/queries/q21.sql b/benchmarks/queries/clickbench/queries/q21.sql new file mode 100644 index 0000000000000..3551689728ede --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q21.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchPhrase", MIN("URL"), COUNT(*) AS c FROM hits WHERE "URL" LIKE '%google%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q22.sql b/benchmarks/queries/clickbench/queries/q22.sql new file mode 100644 index 0000000000000..d5f696e75a8c8 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q22.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchPhrase", MIN("URL"), MIN("Title"), COUNT(*) AS c, COUNT(DISTINCT "UserID") FROM hits WHERE "Title" LIKE '%Google%' AND "URL" NOT LIKE '%.google.%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q23.sql b/benchmarks/queries/clickbench/queries/q23.sql new file mode 100644 index 0000000000000..ff399ded6ed8c --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q23.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT * FROM hits WHERE "URL" LIKE '%google%' ORDER BY "EventTime" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q24.sql b/benchmarks/queries/clickbench/queries/q24.sql new file mode 100644 index 0000000000000..bc7a364151e23 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q24.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "EventTime" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q25.sql b/benchmarks/queries/clickbench/queries/q25.sql new file mode 100644 index 0000000000000..5332e3451aeaf --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q25.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "SearchPhrase" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q26.sql b/benchmarks/queries/clickbench/queries/q26.sql new file mode 100644 index 0000000000000..bc1108aea1255 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q26.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "EventTime", "SearchPhrase" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q27.sql b/benchmarks/queries/clickbench/queries/q27.sql new file mode 100644 index 0000000000000..ba234d34f8877 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q27.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "CounterID", AVG(length("URL")) AS l, COUNT(*) AS c FROM hits WHERE "URL" <> '' GROUP BY "CounterID" HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; diff --git a/benchmarks/queries/clickbench/queries/q28.sql b/benchmarks/queries/clickbench/queries/q28.sql new file mode 100644 index 0000000000000..6a3bd037bece7 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q28.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT REGEXP_REPLACE("Referer", '^https?://(?:www\.)?([^/]+)/.*$', '\1') AS k, AVG(length("Referer")) AS l, COUNT(*) AS c, MIN("Referer") FROM hits WHERE "Referer" <> '' GROUP BY k HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; diff --git a/benchmarks/queries/clickbench/queries/q29.sql b/benchmarks/queries/clickbench/queries/q29.sql new file mode 100644 index 0000000000000..bca1eb7bbe54b --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q29.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT SUM("ResolutionWidth"), SUM("ResolutionWidth" + 1), SUM("ResolutionWidth" + 2), SUM("ResolutionWidth" + 3), SUM("ResolutionWidth" + 4), SUM("ResolutionWidth" + 5), SUM("ResolutionWidth" + 6), SUM("ResolutionWidth" + 7), SUM("ResolutionWidth" + 8), SUM("ResolutionWidth" + 9), SUM("ResolutionWidth" + 10), SUM("ResolutionWidth" + 11), SUM("ResolutionWidth" + 12), SUM("ResolutionWidth" + 13), SUM("ResolutionWidth" + 14), SUM("ResolutionWidth" + 15), SUM("ResolutionWidth" + 16), SUM("ResolutionWidth" + 17), SUM("ResolutionWidth" + 18), SUM("ResolutionWidth" + 19), SUM("ResolutionWidth" + 20), SUM("ResolutionWidth" + 21), SUM("ResolutionWidth" + 22), SUM("ResolutionWidth" + 23), SUM("ResolutionWidth" + 24), SUM("ResolutionWidth" + 25), SUM("ResolutionWidth" + 26), SUM("ResolutionWidth" + 27), SUM("ResolutionWidth" + 28), SUM("ResolutionWidth" + 29), SUM("ResolutionWidth" + 30), SUM("ResolutionWidth" + 31), SUM("ResolutionWidth" + 32), SUM("ResolutionWidth" + 33), SUM("ResolutionWidth" + 34), SUM("ResolutionWidth" + 35), SUM("ResolutionWidth" + 36), SUM("ResolutionWidth" + 37), SUM("ResolutionWidth" + 38), SUM("ResolutionWidth" + 39), SUM("ResolutionWidth" + 40), SUM("ResolutionWidth" + 41), SUM("ResolutionWidth" + 42), SUM("ResolutionWidth" + 43), SUM("ResolutionWidth" + 44), SUM("ResolutionWidth" + 45), SUM("ResolutionWidth" + 46), SUM("ResolutionWidth" + 47), SUM("ResolutionWidth" + 48), SUM("ResolutionWidth" + 49), SUM("ResolutionWidth" + 50), SUM("ResolutionWidth" + 51), SUM("ResolutionWidth" + 52), SUM("ResolutionWidth" + 53), SUM("ResolutionWidth" + 54), SUM("ResolutionWidth" + 55), SUM("ResolutionWidth" + 56), SUM("ResolutionWidth" + 57), SUM("ResolutionWidth" + 58), SUM("ResolutionWidth" + 59), SUM("ResolutionWidth" + 60), SUM("ResolutionWidth" + 61), SUM("ResolutionWidth" + 62), SUM("ResolutionWidth" + 63), SUM("ResolutionWidth" + 64), SUM("ResolutionWidth" + 65), SUM("ResolutionWidth" + 66), SUM("ResolutionWidth" + 67), SUM("ResolutionWidth" + 68), SUM("ResolutionWidth" + 69), SUM("ResolutionWidth" + 70), SUM("ResolutionWidth" + 71), SUM("ResolutionWidth" + 72), SUM("ResolutionWidth" + 73), SUM("ResolutionWidth" + 74), SUM("ResolutionWidth" + 75), SUM("ResolutionWidth" + 76), SUM("ResolutionWidth" + 77), SUM("ResolutionWidth" + 78), SUM("ResolutionWidth" + 79), SUM("ResolutionWidth" + 80), SUM("ResolutionWidth" + 81), SUM("ResolutionWidth" + 82), SUM("ResolutionWidth" + 83), SUM("ResolutionWidth" + 84), SUM("ResolutionWidth" + 85), SUM("ResolutionWidth" + 86), SUM("ResolutionWidth" + 87), SUM("ResolutionWidth" + 88), SUM("ResolutionWidth" + 89) FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q3.sql b/benchmarks/queries/clickbench/queries/q3.sql new file mode 100644 index 0000000000000..09cdaca713047 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q3.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT AVG("UserID") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q30.sql b/benchmarks/queries/clickbench/queries/q30.sql new file mode 100644 index 0000000000000..c0d657927478e --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q30.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "SearchEngineID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "ClientIP" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q31.sql b/benchmarks/queries/clickbench/queries/q31.sql new file mode 100644 index 0000000000000..76ab3622ffb57 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q31.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q32.sql b/benchmarks/queries/clickbench/queries/q32.sql new file mode 100644 index 0000000000000..88f1e4ce42d23 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q32.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q33.sql b/benchmarks/queries/clickbench/queries/q33.sql new file mode 100644 index 0000000000000..3740503bbc0e9 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q33.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "URL", COUNT(*) AS c FROM hits GROUP BY "URL" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q34.sql b/benchmarks/queries/clickbench/queries/q34.sql new file mode 100644 index 0000000000000..fdb7edbb656ac --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q34.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT 1, "URL", COUNT(*) AS c FROM hits GROUP BY 1, "URL" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q35.sql b/benchmarks/queries/clickbench/queries/q35.sql new file mode 100644 index 0000000000000..de7e2256eb551 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q35.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3 ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q36.sql b/benchmarks/queries/clickbench/queries/q36.sql new file mode 100644 index 0000000000000..81b1199b0381e --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q36.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "URL" <> '' GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q37.sql b/benchmarks/queries/clickbench/queries/q37.sql new file mode 100644 index 0000000000000..fa4b85ffbd9cb --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q37.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "Title", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "Title" <> '' GROUP BY "Title" ORDER BY PageViews DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q38.sql b/benchmarks/queries/clickbench/queries/q38.sql new file mode 100644 index 0000000000000..18fafab6c888f --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q38.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "IsLink" <> 0 AND "IsDownload" = 0 GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; diff --git a/benchmarks/queries/clickbench/queries/q39.sql b/benchmarks/queries/clickbench/queries/q39.sql new file mode 100644 index 0000000000000..306f0caacff64 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q39.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "TraficSourceID", "SearchEngineID", "AdvEngineID", CASE WHEN ("SearchEngineID" = 0 AND "AdvEngineID" = 0) THEN "Referer" ELSE '' END AS Src, "URL" AS Dst, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 GROUP BY "TraficSourceID", "SearchEngineID", "AdvEngineID", Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; diff --git a/benchmarks/queries/clickbench/queries/q4.sql b/benchmarks/queries/clickbench/queries/q4.sql new file mode 100644 index 0000000000000..d89ca78c2fb6f --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q4.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT COUNT(DISTINCT "UserID") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q40.sql b/benchmarks/queries/clickbench/queries/q40.sql new file mode 100644 index 0000000000000..e9d27f5985fa9 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q40.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "URLHash", "EventDate", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "TraficSourceID" IN (-1, 6) AND "RefererHash" = 3594120000172545465 GROUP BY "URLHash", "EventDate" ORDER BY PageViews DESC LIMIT 10 OFFSET 100; diff --git a/benchmarks/queries/clickbench/queries/q41.sql b/benchmarks/queries/clickbench/queries/q41.sql new file mode 100644 index 0000000000000..0e067e2dfc9da --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q41.sql @@ -0,0 +1,3 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true +SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "DontCountHits" = 0 AND "URLHash" = 2868770270353813622 GROUP BY "WindowClientWidth", "WindowClientHeight" ORDER BY PageViews DESC LIMIT 10 OFFSET 10000; diff --git a/benchmarks/queries/clickbench/queries/q42.sql b/benchmarks/queries/clickbench/queries/q42.sql new file mode 100644 index 0000000000000..111cc1d3c4a9d --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q42.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-14' AND "EventDate" <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000; diff --git a/benchmarks/queries/clickbench/queries/q5.sql b/benchmarks/queries/clickbench/queries/q5.sql new file mode 100644 index 0000000000000..d371cfb6b3557 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q5.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT COUNT(DISTINCT "SearchPhrase") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q6.sql b/benchmarks/queries/clickbench/queries/q6.sql new file mode 100644 index 0000000000000..5b4e896a1df26 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q6.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT MIN("EventDate"), MAX("EventDate") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q7.sql b/benchmarks/queries/clickbench/queries/q7.sql new file mode 100644 index 0000000000000..afffcb1306d54 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q7.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC; diff --git a/benchmarks/queries/clickbench/queries/q8.sql b/benchmarks/queries/clickbench/queries/q8.sql new file mode 100644 index 0000000000000..097880a9da5ed --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q8.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "RegionID", COUNT(DISTINCT "UserID") AS u FROM hits GROUP BY "RegionID" ORDER BY u DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q9.sql b/benchmarks/queries/clickbench/queries/q9.sql new file mode 100644 index 0000000000000..cb1b79bf5bdc1 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q9.sql @@ -0,0 +1,4 @@ +-- Must set for ClickBench hits_partitioned dataset. See https://github.com/apache/datafusion/issues/16591 +-- set datafusion.execution.parquet.binary_as_string = true + +SELECT "RegionID", SUM("AdvEngineID"), COUNT(*) AS c, AVG("ResolutionWidth"), COUNT(DISTINCT "UserID") FROM hits GROUP BY "RegionID" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/update_queries.sh b/benchmarks/queries/clickbench/update_queries.sh new file mode 100755 index 0000000000000..d7db7359aa394 --- /dev/null +++ b/benchmarks/queries/clickbench/update_queries.sh @@ -0,0 +1,80 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This script is meant for developers of DataFusion -- it is runnable +# from the standard DataFusion development environment and uses cargo, +# etc and orchestrates gathering data and run the benchmark binary in +# different configurations. + +# Script to download ClickBench queries and split them into individual files + +set -e # Exit on any error + +# URL for the raw file (not the GitHub page) +URL="https://raw.githubusercontent.com/ClickHouse/ClickBench/main/datafusion/queries.sql" + +# Temporary file to store downloaded content +TEMP_FILE="queries.sql" + +TARGET_DIR="queries" + +# Download the file +echo "Downloading queries from $URL..." +if command -v curl &> /dev/null; then + curl -s -o "$TEMP_FILE" "$URL" +elif command -v wget &> /dev/null; then + wget -q -O "$TEMP_FILE" "$URL" +else + echo "Error: Neither curl nor wget is available. Please install one of them." + exit 1 +fi + +# Check if download was successful +if [ ! -f "$TEMP_FILE" ] || [ ! -s "$TEMP_FILE" ]; then + echo "Error: Failed to download or file is empty" + exit 1 +fi + +# Initialize counter +counter=0 + +# Ensure the target directory exists +if [ ! -d ${TARGET_DIR} ]; then + mkdir -p ${TARGET_DIR} +fi + +# Read the file line by line and create individual query files +mapfile -t lines < $TEMP_FILE +for line in "${lines[@]}"; do + # Skip empty lines + if [ -n "$line" ]; then + # Create filename with zero-padded counter + filename="q${counter}.sql" + + # Write the line to the individual file + echo "$line" > "${TARGET_DIR}/$filename" + + echo "Created ${TARGET_DIR}/$filename" + + # Increment counter + (( counter += 1 )) + fi +done + +# Clean up temporary file +rm "$TEMP_FILE" \ No newline at end of file diff --git a/benchmarks/queries/h2o/groupby.sql b/benchmarks/queries/h2o/groupby.sql index c2101ef8ada2d..4fae7a13810d9 100644 --- a/benchmarks/queries/h2o/groupby.sql +++ b/benchmarks/queries/h2o/groupby.sql @@ -1,10 +1,19 @@ SELECT id1, SUM(v1) AS v1 FROM x GROUP BY id1; + SELECT id1, id2, SUM(v1) AS v1 FROM x GROUP BY id1, id2; + SELECT id3, SUM(v1) AS v1, AVG(v3) AS v3 FROM x GROUP BY id3; + SELECT id4, AVG(v1) AS v1, AVG(v2) AS v2, AVG(v3) AS v3 FROM x GROUP BY id4; + SELECT id6, SUM(v1) AS v1, SUM(v2) AS v2, SUM(v3) AS v3 FROM x GROUP BY id6; + SELECT id4, id5, MEDIAN(v3) AS median_v3, STDDEV(v3) AS sd_v3 FROM x GROUP BY id4, id5; + SELECT id3, MAX(v1) - MIN(v2) AS range_v1_v2 FROM x GROUP BY id3; + SELECT id6, largest2_v3 FROM (SELECT id6, v3 AS largest2_v3, ROW_NUMBER() OVER (PARTITION BY id6 ORDER BY v3 DESC) AS order_v3 FROM x WHERE v3 IS NOT NULL) sub_query WHERE order_v3 <= 2; + SELECT id2, id4, POWER(CORR(v1, v2), 2) AS r2 FROM x GROUP BY id2, id4; -SELECT id1, id2, id3, id4, id5, id6, SUM(v3) AS v3, COUNT(*) AS count FROM x GROUP BY id1, id2, id3, id4, id5, id6; + +SELECT id1, id2, id3, id4, id5, id6, SUM(v3) AS v3, COUNT(*) AS count FROM x GROUP BY id1, id2, id3, id4, id5, id6; \ No newline at end of file diff --git a/benchmarks/queries/h2o/join.sql b/benchmarks/queries/h2o/join.sql index 8546b9292dbb4..84cd661fdd592 100644 --- a/benchmarks/queries/h2o/join.sql +++ b/benchmarks/queries/h2o/join.sql @@ -1,5 +1,9 @@ SELECT x.id1, x.id2, x.id3, x.id4 as xid4, small.id4 as smallid4, x.id5, x.id6, x.v1, small.v2 FROM x INNER JOIN small ON x.id1 = small.id1; + SELECT x.id1 as xid1, medium.id1 as mediumid1, x.id2, x.id3, x.id4 as xid4, medium.id4 as mediumid4, x.id5 as xid5, medium.id5 as mediumid5, x.id6, x.v1, medium.v2 FROM x INNER JOIN medium ON x.id2 = medium.id2; + SELECT x.id1 as xid1, medium.id1 as mediumid1, x.id2, x.id3, x.id4 as xid4, medium.id4 as mediumid4, x.id5 as xid5, medium.id5 as mediumid5, x.id6, x.v1, medium.v2 FROM x LEFT JOIN medium ON x.id2 = medium.id2; + SELECT x.id1 as xid1, medium.id1 as mediumid1, x.id2, x.id3, x.id4 as xid4, medium.id4 as mediumid4, x.id5 as xid5, medium.id5 as mediumid5, x.id6, x.v1, medium.v2 FROM x JOIN medium ON x.id5 = medium.id5; -SELECT x.id1 as xid1, large.id1 as largeid1, x.id2 as xid2, large.id2 as largeid2, x.id3, x.id4 as xid4, large.id4 as largeid4, x.id5 as xid5, large.id5 as largeid5, x.id6 as xid6, large.id6 as largeid6, x.v1, large.v2 FROM x JOIN large ON x.id3 = large.id3; + +SELECT x.id1 as xid1, large.id1 as largeid1, x.id2 as xid2, large.id2 as largeid2, x.id3, x.id4 as xid4, large.id4 as largeid4, x.id5 as xid5, large.id5 as largeid5, x.id6 as xid6, large.id6 as largeid6, x.v1, large.v2 FROM x JOIN large ON x.id3 = large.id3; \ No newline at end of file diff --git a/benchmarks/queries/h2o/window.sql b/benchmarks/queries/h2o/window.sql new file mode 100644 index 0000000000000..071540927a4cf --- /dev/null +++ b/benchmarks/queries/h2o/window.sql @@ -0,0 +1,112 @@ +-- Basic Window +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER () AS window_basic +FROM large; + +-- Sorted Window +SELECT + id1, + id2, + id3, + v2, + first_value(v2) OVER (ORDER BY id3) AS first_order_by, + row_number() OVER (ORDER BY id3) AS row_number_order_by +FROM large; + +-- PARTITION BY +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (PARTITION BY id1) AS sum_by_id1, + sum(v2) OVER (PARTITION BY id2) AS sum_by_id2, + sum(v2) OVER (PARTITION BY id3) AS sum_by_id3 +FROM large; + +-- PARTITION BY ORDER BY +SELECT + id1, + id2, + id3, + v2, + first_value(v2) OVER (PARTITION BY id2 ORDER BY id3) AS first_by_id2_ordered_by_id3 +FROM large; + +-- Lead and Lag +SELECT + id1, + id2, + id3, + v2, + first_value(v2) OVER (ORDER BY id3 ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING) AS my_lag, + first_value(v2) OVER (ORDER BY id3 ROWS BETWEEN 1 FOLLOWING AND 1 FOLLOWING) AS my_lead +FROM large; + +-- Moving Averages +SELECT + id1, + id2, + id3, + v2, + avg(v2) OVER (ORDER BY id3 ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) AS my_moving_average +FROM large; + +-- Rolling Sum +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (ORDER BY id3 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS my_rolling_sum +FROM large; + +-- RANGE BETWEEN +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (ORDER BY v2 RANGE BETWEEN 3 PRECEDING AND CURRENT ROW) AS my_range_between +FROM large; + +-- First PARTITION BY ROWS BETWEEN +SELECT + id1, + id2, + id3, + v2, + first_value(v2) OVER (PARTITION BY id2 ORDER BY id3 ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING) AS my_lag_by_id2, + first_value(v2) OVER (PARTITION BY id2 ORDER BY id3 ROWS BETWEEN 1 FOLLOWING AND 1 FOLLOWING) AS my_lead_by_id2 +FROM large; + +-- Moving Averages PARTITION BY +SELECT + id1, + id2, + id3, + v2, + avg(v2) OVER (PARTITION BY id2 ORDER BY id3 ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) AS my_moving_average_by_id2 +FROM large; + +-- Rolling Sum PARTITION BY +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (PARTITION BY id2 ORDER BY id3 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS my_rolling_sum_by_id2 +FROM large; + +-- RANGE BETWEEN PARTITION BY +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (PARTITION BY id2 ORDER BY v2 RANGE BETWEEN 3 PRECEDING AND CURRENT ROW) AS my_range_between_by_id2 +FROM large; \ No newline at end of file diff --git a/benchmarks/queries/q10.sql b/benchmarks/queries/q10.sql index cf45e43485fb5..8613fd4962837 100644 --- a/benchmarks/queries/q10.sql +++ b/benchmarks/queries/q10.sql @@ -28,4 +28,5 @@ group by c_address, c_comment order by - revenue desc; \ No newline at end of file + revenue desc +limit 20; diff --git a/benchmarks/queries/q18.sql b/benchmarks/queries/q18.sql index 835de28a57be2..ba7ee7f716cf1 100644 --- a/benchmarks/queries/q18.sql +++ b/benchmarks/queries/q18.sql @@ -29,4 +29,5 @@ group by o_totalprice order by o_totalprice desc, - o_orderdate; \ No newline at end of file + o_orderdate +limit 100; diff --git a/benchmarks/queries/q2.sql b/benchmarks/queries/q2.sql index f66af210205e9..68e478f65d3f9 100644 --- a/benchmarks/queries/q2.sql +++ b/benchmarks/queries/q2.sql @@ -40,4 +40,5 @@ order by s_acctbal desc, n_name, s_name, - p_partkey; \ No newline at end of file + p_partkey +limit 100; diff --git a/benchmarks/queries/q21.sql b/benchmarks/queries/q21.sql index 9d2fe32cee228..b95e7b0dfca02 100644 --- a/benchmarks/queries/q21.sql +++ b/benchmarks/queries/q21.sql @@ -36,4 +36,5 @@ group by s_name order by numwait desc, - s_name; \ No newline at end of file + s_name +limit 100; diff --git a/benchmarks/queries/q3.sql b/benchmarks/queries/q3.sql index 7dbc6d9ef6783..e5fa9e38664c3 100644 --- a/benchmarks/queries/q3.sql +++ b/benchmarks/queries/q3.sql @@ -19,4 +19,5 @@ group by o_shippriority order by revenue desc, - o_orderdate; \ No newline at end of file + o_orderdate +limit 10; diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs index 06337cb758885..816cae0e38555 100644 --- a/benchmarks/src/bin/dfbench.rs +++ b/benchmarks/src/bin/dfbench.rs @@ -34,7 +34,7 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; use datafusion_benchmarks::{ - cancellation, clickbench, h2o, imdb, parquet_filter, sort, sort_tpch, tpch, + cancellation, clickbench, h2o, hj, imdb, nlj, sort_tpch, tpch, }; #[derive(Debug, StructOpt)] @@ -43,9 +43,9 @@ enum Options { Cancellation(cancellation::RunOpt), Clickbench(clickbench::RunOpt), H2o(h2o::RunOpt), + HJ(hj::RunOpt), Imdb(imdb::RunOpt), - ParquetFilter(parquet_filter::RunOpt), - Sort(sort::RunOpt), + Nlj(nlj::RunOpt), SortTpch(sort_tpch::RunOpt), Tpch(tpch::RunOpt), TpchConvert(tpch::ConvertOpt), @@ -60,11 +60,11 @@ pub async fn main() -> Result<()> { Options::Cancellation(opt) => opt.run().await, Options::Clickbench(opt) => opt.run().await, Options::H2o(opt) => opt.run().await, - Options::Imdb(opt) => opt.run().await, - Options::ParquetFilter(opt) => opt.run().await, - Options::Sort(opt) => opt.run().await, + Options::HJ(opt) => opt.run().await, + Options::Imdb(opt) => Box::pin(opt.run()).await, + Options::Nlj(opt) => opt.run().await, Options::SortTpch(opt) => opt.run().await, - Options::Tpch(opt) => opt.run().await, + Options::Tpch(opt) => Box::pin(opt.run()).await, Options::TpchConvert(opt) => opt.run().await, } } diff --git a/benchmarks/src/bin/external_aggr.rs b/benchmarks/src/bin/external_aggr.rs index 578f71f8275d5..46b6cc9a80b24 100644 --- a/benchmarks/src/bin/external_aggr.rs +++ b/benchmarks/src/bin/external_aggr.rs @@ -40,7 +40,7 @@ use datafusion::execution::SessionStateBuilder; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::*; -use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt}; +use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt, QueryResult}; use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; use datafusion_common::{exec_err, DEFAULT_PARQUET_EXTENSION}; @@ -77,11 +77,6 @@ struct ExternalAggrConfig { output_path: Option, } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - /// Query Memory Limits /// Map query id to predefined memory limits /// @@ -118,7 +113,7 @@ impl ExternalAggrConfig { "#, ]; - /// If `--query` and `--memory-limit` is not speicified, run all queries + /// If `--query` and `--memory-limit` is not specified, run all queries /// with pre-configured memory limits /// If only `--query` is specified, run the query with all memory limits /// for this query @@ -189,7 +184,7 @@ impl ExternalAggrConfig { ) -> Result> { let query_name = format!("Q{query_id}({})", human_readable_size(mem_limit as usize)); - let config = self.common.config(); + let config = self.common.config()?; let memory_pool: Arc = match mem_pool_type { "fair" => Arc::new(FairSpillPool::new(mem_limit as usize)), "greedy" => Arc::new(GreedyMemoryPool::new(mem_limit as usize)), @@ -335,7 +330,7 @@ impl ExternalAggrConfig { fn partitions(&self) -> usize { self.common .partitions - .unwrap_or(get_available_parallelism()) + .unwrap_or_else(get_available_parallelism) } } diff --git a/benchmarks/src/bin/imdb.rs b/benchmarks/src/bin/imdb.rs index 13421f8a89a9b..5ce99928df662 100644 --- a/benchmarks/src/bin/imdb.rs +++ b/benchmarks/src/bin/imdb.rs @@ -53,7 +53,7 @@ pub async fn main() -> Result<()> { env_logger::init(); match ImdbOpt::from_args() { ImdbOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => { - opt.run().await + Box::pin(opt.run()).await } ImdbOpt::Convert(opt) => opt.run().await, } diff --git a/benchmarks/src/bin/mem_profile.rs b/benchmarks/src/bin/mem_profile.rs new file mode 100644 index 0000000000000..16fc3871bec86 --- /dev/null +++ b/benchmarks/src/bin/mem_profile.rs @@ -0,0 +1,360 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! mem_profile binary entrypoint +use datafusion::error::Result; +use std::{ + env, + io::{BufRead, BufReader}, + path::Path, + process::{Command, Stdio}, +}; +use structopt::StructOpt; + +use datafusion_benchmarks::{ + clickbench, + h2o::{self, AllQueries}, + imdb, sort_tpch, tpch, +}; + +#[derive(Debug, StructOpt)] +#[structopt(name = "Memory Profiling Utility")] +struct MemProfileOpt { + /// Cargo profile to use in dfbench (e.g. release, release-nonlto) + #[structopt(long, default_value = "release")] + bench_profile: String, + + #[structopt(subcommand)] + command: Options, +} + +#[derive(Debug, StructOpt)] +#[structopt(about = "Benchmark command")] +enum Options { + Clickbench(clickbench::RunOpt), + H2o(h2o::RunOpt), + Imdb(imdb::RunOpt), + SortTpch(sort_tpch::RunOpt), + Tpch(tpch::RunOpt), +} + +#[tokio::main] +pub async fn main() -> Result<()> { + // 1. Parse args and check which benchmarks should be run + let mem_profile_opt = MemProfileOpt::from_args(); + let profile = mem_profile_opt.bench_profile; + let query_range = match mem_profile_opt.command { + Options::Clickbench(opt) => { + let entries = std::fs::read_dir(&opt.queries_path)? + .filter_map(Result::ok) + .filter(|e| { + let path = e.path(); + path.extension().map(|ext| ext == "sql").unwrap_or(false) + }) + .collect::>(); + + let max_query_id = entries.len().saturating_sub(1); + match opt.query { + Some(query_id) => query_id..=query_id, + None => 0..=max_query_id, + } + } + Options::H2o(opt) => { + let queries = AllQueries::try_new(&opt.queries_path)?; + match opt.query { + Some(query_id) => query_id..=query_id, + None => queries.min_query_id()..=queries.max_query_id(), + } + } + Options::Imdb(opt) => match opt.query { + Some(query_id) => query_id..=query_id, + None => imdb::IMDB_QUERY_START_ID..=imdb::IMDB_QUERY_END_ID, + }, + Options::SortTpch(opt) => match opt.query { + Some(query_id) => query_id..=query_id, + None => { + sort_tpch::SORT_TPCH_QUERY_START_ID..=sort_tpch::SORT_TPCH_QUERY_END_ID + } + }, + Options::Tpch(opt) => match opt.query { + Some(query_id) => query_id..=query_id, + None => tpch::TPCH_QUERY_START_ID..=tpch::TPCH_QUERY_END_ID, + }, + }; + + // 2. Prebuild dfbench binary so that memory does not blow up due to build process + println!("Pre-building benchmark binary..."); + let status = Command::new("cargo") + .args([ + "build", + "--profile", + &profile, + "--features", + "mimalloc_extended", + "--bin", + "dfbench", + ]) + .status() + .expect("Failed to build dfbench"); + assert!(status.success()); + println!("Benchmark binary built successfully."); + + // 3. Create a new process per each benchmark query and print summary + // Find position of subcommand to collect args for dfbench + let args: Vec<_> = env::args().collect(); + let subcommands = ["tpch", "clickbench", "h2o", "imdb", "sort-tpch"]; + let sub_pos = args + .iter() + .position(|s| subcommands.iter().any(|&cmd| s == cmd)) + .expect("No benchmark subcommand found"); + + // Args starting from subcommand become dfbench args + let mut dfbench_args: Vec = + args[sub_pos..].iter().map(|s| s.to_string()).collect(); + + run_benchmark_as_child_process(&profile, query_range, &mut dfbench_args)?; + + Ok(()) +} + +fn run_benchmark_as_child_process( + profile: &str, + query_range: std::ops::RangeInclusive, + args: &mut Vec, +) -> Result<()> { + let mut query_strings: Vec = Vec::new(); + for i in query_range { + query_strings.push(i.to_string()); + } + + let target_dir = + env::var("CARGO_TARGET_DIR").unwrap_or_else(|_| "target".to_string()); + let command = format!("{target_dir}/{profile}/dfbench"); + // Check whether benchmark binary exists + if !Path::new(&command).exists() { + panic!( + "Benchmark binary not found: `{command}`\nRun this command from the top-level `datafusion/` directory so `target/{profile}/dfbench` can be found.", + ); + } + args.insert(0, command); + let mut results = vec![]; + + // Run Single Query (args already contain --query num) + if args.contains(&"--query".to_string()) { + let _ = run_query(args, &mut results); + print_summary_table(&results); + return Ok(()); + } + + // Run All Queries + args.push("--query".to_string()); + for query_str in query_strings { + args.push(query_str); + let _ = run_query(args, &mut results); + args.pop(); + } + + print_summary_table(&results); + Ok(()) +} + +fn run_query(args: &[String], results: &mut Vec) -> Result<()> { + let exec_path = &args[0]; + let exec_args = &args[1..]; + + let mut child = Command::new(exec_path) + .args(exec_args) + .stdout(Stdio::piped()) + .spawn() + .expect("Failed to start benchmark"); + + let stdout = child.stdout.take().unwrap(); + let reader = BufReader::new(stdout); + + // Buffer child's stdout + let lines: Result, std::io::Error> = + reader.lines().collect::>(); + + child + .wait() + .expect("Benchmark process exited with an error"); + + // Parse after child process terminates + let lines = lines?; + let mut iter = lines.iter().peekable(); + + // Look for lines that contain execution time / memory stats + while let Some(line) = iter.next() { + if let Some((query, duration_ms)) = parse_query_time(line) { + if let Some(next_line) = iter.peek() { + if let Some((peak_rss, peak_commit, page_faults)) = + parse_vm_line(next_line) + { + results.push(QueryResult { + query, + duration_ms, + peak_rss, + peak_commit, + page_faults, + }); + break; + } + } + } + } + + Ok(()) +} + +#[derive(Debug)] +struct QueryResult { + query: usize, + duration_ms: f64, + peak_rss: String, + peak_commit: String, + page_faults: String, +} + +fn parse_query_time(line: &str) -> Option<(usize, f64)> { + let re = regex::Regex::new(r"Query (\d+) avg time: ([\d.]+) ms").unwrap(); + if let Some(caps) = re.captures(line) { + let query_id = caps[1].parse::().ok()?; + let avg_time = caps[2].parse::().ok()?; + Some((query_id, avg_time)) + } else { + None + } +} + +fn parse_vm_line(line: &str) -> Option<(String, String, String)> { + let re = regex::Regex::new( + r"Peak RSS:\s*([\d.]+\s*[A-Z]+),\s*Peak Commit:\s*([\d.]+\s*[A-Z]+),\s*Page Faults:\s*([\d.]+)" + ).ok()?; + let caps = re.captures(line)?; + let peak_rss = caps.get(1)?.as_str().to_string(); + let peak_commit = caps.get(2)?.as_str().to_string(); + let page_faults = caps.get(3)?.as_str().to_string(); + Some((peak_rss, peak_commit, page_faults)) +} + +// Print as simple aligned table +fn print_summary_table(results: &[QueryResult]) { + println!( + "\n{:<8} {:>10} {:>12} {:>12} {:>18}", + "Query", "Time (ms)", "Peak RSS", "Peak Commit", "Major Page Faults" + ); + println!("{}", "-".repeat(64)); + + for r in results { + println!( + "{:<8} {:>10.2} {:>12} {:>12} {:>18}", + r.query, r.duration_ms, r.peak_rss, r.peak_commit, r.page_faults + ); + } +} + +#[cfg(test)] +// Only run with "ci" mode when we have the data +#[cfg(feature = "ci")] +mod tests { + use datafusion::common::exec_err; + use datafusion::error::Result; + use std::path::{Path, PathBuf}; + use std::process::Command; + + fn get_tpch_data_path() -> Result { + let path = + std::env::var("TPCH_DATA").unwrap_or_else(|_| "benchmarks/data".to_string()); + if !Path::new(&path).exists() { + return exec_err!( + "Benchmark data not found (set TPCH_DATA env var to override): {}", + path + ); + } + Ok(path) + } + + // Try to find target/ dir upward + fn find_target_dir(start: &Path) -> Option { + let mut dir = start; + + while let Some(current) = Some(dir) { + if current.join("target").is_dir() { + return Some(current.join("target")); + } + + dir = match current.parent() { + Some(parent) => parent, + None => break, + }; + } + + None + } + + #[test] + // This test checks whether `mem_profile` runs successfully and produces expected output + // using TPC-H query 6 (which runs quickly). + fn mem_profile_e2e_tpch_q6() -> Result<()> { + let profile = "ci"; + let tpch_data = get_tpch_data_path()?; + + // The current working directory may not be the top-level datafusion/ directory, + // so we manually walkdir upward, locate the target directory + // and set it explicitly via CARGO_TARGET_DIR for the mem_profile command. + let target_dir = find_target_dir(&std::env::current_dir()?); + let output = Command::new("cargo") + .env("CARGO_TARGET_DIR", target_dir.unwrap()) + .args([ + "run", + "--profile", + profile, + "--bin", + "mem_profile", + "--", + "--bench-profile", + profile, + "tpch", + "--query", + "6", + "--path", + &tpch_data, + "--format", + "tbl", + ]) + .output() + .expect("Failed to run mem_profile"); + + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + + if !output.status.success() { + panic!( + "mem_profile failed\nstdout:\n{stdout}\nstderr:\n{stderr}---------------------", + ); + } + + assert!( + stdout.contains("Peak RSS") + && stdout.contains("Query") + && stdout.contains("Time"), + "Unexpected output:\n{stdout}", + ); + + Ok(()) + } +} diff --git a/benchmarks/src/bin/parquet.rs b/benchmarks/src/bin/parquet.rs deleted file mode 100644 index 6351a71a7bd3f..0000000000000 --- a/benchmarks/src/bin/parquet.rs +++ /dev/null @@ -1,49 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::common::Result; - -use datafusion_benchmarks::{parquet_filter, sort}; -use structopt::StructOpt; - -#[cfg(feature = "snmalloc")] -#[global_allocator] -static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; - -#[derive(Debug, Clone, StructOpt)] -#[structopt(name = "Benchmarks", about = "Apache DataFusion Rust Benchmarks.")] -enum ParquetBenchCmd { - /// Benchmark sorting parquet files - Sort(sort::RunOpt), - /// Benchmark parquet filter pushdown - Filter(parquet_filter::RunOpt), -} - -#[tokio::main] -async fn main() -> Result<()> { - let cmd = ParquetBenchCmd::from_args(); - match cmd { - ParquetBenchCmd::Filter(opt) => { - println!("running filter benchmarks"); - opt.run().await - } - ParquetBenchCmd::Sort(opt) => { - println!("running sort benchmarks"); - opt.run().await - } - } -} diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 3270b082cfb43..ca2bb8e57c0ec 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -58,7 +58,7 @@ async fn main() -> Result<()> { env_logger::init(); match TpchOpt::from_args() { TpchOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => { - opt.run().await + Box::pin(opt.run()).await } TpchOpt::Convert(opt) => opt.run().await, } diff --git a/benchmarks/src/cancellation.rs b/benchmarks/src/cancellation.rs index f5740bdc96e05..fcf03fbc54550 100644 --- a/benchmarks/src/cancellation.rs +++ b/benchmarks/src/cancellation.rs @@ -38,7 +38,7 @@ use futures::TryStreamExt; use object_store::ObjectStore; use parquet::arrow::async_writer::ParquetObjectWriter; use parquet::arrow::AsyncArrowWriter; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::rngs::ThreadRng; use rand::Rng; use structopt::StructOpt; @@ -237,7 +237,7 @@ fn find_files_on_disk(data_dir: impl AsRef) -> Result> { let path = file.unwrap().path(); if path .extension() - .map(|ext| (ext == "parquet")) + .map(|ext| ext == "parquet") .unwrap_or(false) { Some(path) @@ -312,15 +312,15 @@ async fn generate_data( } fn random_data(column_type: &DataType, rows: usize) -> Arc { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let values = (0..rows).map(|_| random_value(&mut rng, column_type)); ScalarValue::iter_to_array(values).unwrap() } fn random_value(rng: &mut ThreadRng, column_type: &DataType) -> ScalarValue { match column_type { - DataType::Float64 => ScalarValue::Float64(Some(rng.gen())), - DataType::Boolean => ScalarValue::Boolean(Some(rng.gen())), + DataType::Float64 => ScalarValue::Float64(Some(rng.random())), + DataType::Boolean => ScalarValue::Boolean(Some(rng.random())), DataType::Utf8 => ScalarValue::Utf8(Some( rng.sample_iter(&Alphanumeric) .take(10) diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index 923c2bdd7cdf4..a550503390c54 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. -use std::path::Path; -use std::path::PathBuf; +use std::fs; +use std::io::ErrorKind; +use std::path::{Path, PathBuf}; -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; +use datafusion::logical_expr::{ExplainFormat, ExplainOption}; use datafusion::{ error::{DataFusionError, Result}, prelude::SessionContext, @@ -27,7 +29,7 @@ use datafusion_common::exec_datafusion_err; use datafusion_common::instant::Instant; use structopt::StructOpt; -/// Run the clickbench benchmark +/// Driver program to run the ClickBench benchmark /// /// The ClickBench[1] benchmarks are widely cited in the industry and /// focus on grouping / aggregation / filtering. This runner uses the @@ -40,7 +42,15 @@ use structopt::StructOpt; pub struct RunOpt { /// Query number (between 0 and 42). If not specified, runs all queries #[structopt(short, long)] - query: Option, + pub query: Option, + + /// If specified, enables Parquet Filter Pushdown. + /// + /// Specifically, it enables: + /// * `pushdown_filters = true` + /// * `reorder_filters = true` + #[structopt(long = "pushdown")] + pushdown: bool, /// Common options #[structopt(flatten)] @@ -56,108 +66,153 @@ pub struct RunOpt { )] path: PathBuf, - /// Path to queries.sql (single file) + /// Path to queries directory #[structopt( parse(from_os_str), short = "r", long = "queries-path", - default_value = "benchmarks/queries/clickbench/queries.sql" + default_value = "benchmarks/queries/clickbench/queries" )] - queries_path: PathBuf, + pub queries_path: PathBuf, /// If present, write results json here #[structopt(parse(from_os_str), short = "o", long = "output")] output_path: Option, } -struct AllQueries { - queries: Vec, +/// Get the SQL file path +pub fn get_query_path(query_dir: &Path, query: usize) -> PathBuf { + let mut query_path = query_dir.to_path_buf(); + query_path.push(format!("q{query}.sql")); + query_path } -impl AllQueries { - fn try_new(path: &Path) -> Result { - // ClickBench has all queries in a single file identified by line number - let all_queries = std::fs::read_to_string(path) - .map_err(|e| exec_datafusion_err!("Could not open {path:?}: {e}"))?; - Ok(Self { - queries: all_queries.lines().map(|s| s.to_string()).collect(), - }) +/// Get the SQL statement from the specified query file +pub fn get_query_sql(query_path: &Path) -> Result> { + if fs::exists(query_path)? { + Ok(Some(fs::read_to_string(query_path)?)) + } else { + Ok(None) } +} - /// Returns the text of query `query_id` - fn get_query(&self, query_id: usize) -> Result<&str> { - self.queries - .get(query_id) - .ok_or_else(|| { - let min_id = self.min_query_id(); - let max_id = self.max_query_id(); +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running benchmarks with the following options: {self:?}"); + + let query_dir_metadata = fs::metadata(&self.queries_path).map_err(|e| { + if e.kind() == ErrorKind::NotFound { exec_datafusion_err!( - "Invalid query id {query_id}. Must be between {min_id} and {max_id}" + "Query path '{}' does not exist.", + &self.queries_path.to_str().unwrap() ) - }) - .map(|s| s.as_str()) - } + } else { + DataFusionError::External(Box::new(e)) + } + })?; - fn min_query_id(&self) -> usize { - 0 - } + if !query_dir_metadata.is_dir() { + return Err(exec_datafusion_err!( + "Query path '{}' is not a directory.", + &self.queries_path.to_str().unwrap() + )); + } - fn max_query_id(&self) -> usize { - self.queries.len() - 1 - } -} -impl RunOpt { - pub async fn run(self) -> Result<()> { - println!("Running benchmarks with the following options: {self:?}"); - let queries = AllQueries::try_new(self.queries_path.as_path())?; let query_range = match self.query { Some(query_id) => query_id..=query_id, - None => queries.min_query_id()..=queries.max_query_id(), + None => 0..=usize::MAX, }; // configure parquet options - let mut config = self.common.config(); + let mut config = self.common.config()?; { let parquet_options = &mut config.options_mut().execution.parquet; // The hits_partitioned dataset specifies string columns // as binary due to how it was written. Force it to strings parquet_options.binary_as_string = true; + + // Turn on Parquet filter pushdown if requested + if self.pushdown { + parquet_options.pushdown_filters = true; + parquet_options.reorder_filters = true; + } } let rt_builder = self.common.runtime_env_builder()?; let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); self.register_hits(&ctx).await?; - let iterations = self.common.iterations; let mut benchmark_run = BenchmarkRun::new(); for query_id in query_range { - let mut millis = Vec::with_capacity(iterations); + let query_path = get_query_path(&self.queries_path, query_id); + let Some(sql) = get_query_sql(&query_path)? else { + if self.query.is_some() { + return Err(exec_datafusion_err!( + "Could not load query file '{}'.", + &query_path.to_str().unwrap() + )); + } + break; + }; benchmark_run.start_new_case(&format!("Query {query_id}")); - let sql = queries.get_query(query_id)?; - println!("Q{query_id}: {sql}"); - - for i in 0..iterations { - let start = Instant::now(); - let results = ctx.sql(sql).await?.collect().await?; - let elapsed = start.elapsed(); - let ms = elapsed.as_secs_f64() * 1000.0; - millis.push(ms); - let row_count: usize = results.iter().map(|b| b.num_rows()).sum(); - println!( - "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" - ); - benchmark_run.write_iter(elapsed, row_count); + let query_run = self.benchmark_query(&sql, query_id, &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } } - if self.common.debug { - ctx.sql(sql).await?.explain(false, false)?.show().await?; - } - let avg = millis.iter().sum::() / millis.len() as f64; - println!("Query {query_id} avg time: {avg:.2} ms"); } benchmark_run.maybe_write_json(self.output_path.as_ref())?; + benchmark_run.maybe_print_failures(); Ok(()) } + async fn benchmark_query( + &self, + sql: &str, + query_id: usize, + ctx: &SessionContext, + ) -> Result> { + println!("Q{query_id}: {sql}"); + + let mut millis = Vec::with_capacity(self.iterations()); + let mut query_results = vec![]; + for i in 0..self.iterations() { + let start = Instant::now(); + let results = ctx.sql(sql).await?.collect().await?; + let elapsed = start.elapsed(); + let ms = elapsed.as_secs_f64() * 1000.0; + millis.push(ms); + let row_count: usize = results.iter().map(|b| b.num_rows()).sum(); + println!( + "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + query_results.push(QueryResult { elapsed, row_count }) + } + if self.common.debug { + ctx.sql(sql) + .await? + .explain_with_options( + ExplainOption::default().with_format(ExplainFormat::Tree), + )? + .show() + .await?; + } + let avg = millis.iter().sum::() / millis.len() as f64; + println!("Query {query_id} avg time: {avg:.2} ms"); + + // Print memory usage stats using mimalloc (only when compiled with --features mimalloc_extended) + print_memory_stats(); + + Ok(query_results) + } + /// Registers the `hits.parquet` as a table named `hits` async fn register_hits(&self, ctx: &SessionContext) -> Result<()> { let options = Default::default(); @@ -171,4 +226,8 @@ impl RunOpt { ) }) } + + fn iterations(&self) -> usize { + self.common.iterations + } } diff --git a/benchmarks/src/h2o.rs b/benchmarks/src/h2o.rs index cc463e70d74a2..be74252031194 100644 --- a/benchmarks/src/h2o.rs +++ b/benchmarks/src/h2o.rs @@ -15,9 +15,17 @@ // specific language governing permissions and limitations // under the License. -use crate::util::{BenchmarkRun, CommonOpt}; +//! H2O benchmark implementation for groupby, join and window operations +//! Reference: +//! - [H2O AI Benchmark](https://duckdb.org/2023/04/14/h2oai.html) +//! - [Extended window function benchmark](https://duckdb.org/2024/06/26/benchmarks-over-time.html#window-functions-benchmark) + +use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt}; +use datafusion::logical_expr::{ExplainFormat, ExplainOption}; use datafusion::{error::Result, prelude::SessionContext}; -use datafusion_common::{exec_datafusion_err, instant::Instant, DataFusionError}; +use datafusion_common::{ + exec_datafusion_err, instant::Instant, internal_err, DataFusionError, TableReference, +}; use std::path::{Path, PathBuf}; use structopt::StructOpt; @@ -26,7 +34,7 @@ use structopt::StructOpt; #[structopt(verbatim_doc_comment)] pub struct RunOpt { #[structopt(short, long)] - query: Option, + pub query: Option, /// Common options #[structopt(flatten)] @@ -40,7 +48,7 @@ pub struct RunOpt { long = "queries-path", default_value = "benchmarks/queries/h2o/groupby.sql" )] - queries_path: PathBuf, + pub queries_path: PathBuf, /// Path to data file (parquet or csv) /// Default value is the G1_1e7_1e7_100_0.csv file in the h2o benchmark @@ -77,19 +85,28 @@ impl RunOpt { None => queries.min_query_id()..=queries.max_query_id(), }; - let config = self.common.config(); + let config = self.common.config()?; let rt_builder = self.common.runtime_env_builder()?; let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); - if self.queries_path.to_str().unwrap().contains("join") { + // Register tables depending on which h2o benchmark is being run + // (groupby/join/window) + if self.queries_path.to_str().unwrap().ends_with("groupby.sql") { + self.register_data("x", self.path.as_os_str().to_str().unwrap(), &ctx) + .await?; + } else if self.queries_path.to_str().unwrap().ends_with("join.sql") { let join_paths: Vec<&str> = self.join_paths.split(',').collect(); let table_name: Vec<&str> = vec!["x", "small", "medium", "large"]; for (i, path) in join_paths.iter().enumerate() { - ctx.register_csv(table_name[i], path, Default::default()) - .await?; + self.register_data(table_name[i], path, &ctx).await?; } - } else if self.queries_path.to_str().unwrap().contains("groupby") { - self.register_data(&ctx).await?; + } else if self.queries_path.to_str().unwrap().ends_with("window.sql") { + // Only register the 'large' table in h2o-join dataset + let h2o_join_large_path = self.join_paths.split(',').nth(3).unwrap(); + self.register_data("large", h2o_join_large_path, &ctx) + .await?; + } else { + return internal_err!("Invalid query file path"); } let iterations = self.common.iterations; @@ -115,8 +132,17 @@ impl RunOpt { let avg = millis.iter().sum::() / millis.len() as f64; println!("Query {query_id} avg time: {avg:.2} ms"); + // Print memory usage stats using mimalloc (only when compiled with --features mimalloc_extended) + print_memory_stats(); + if self.common.debug { - ctx.sql(sql).await?.explain(false, false)?.show().await?; + ctx.sql(sql) + .await? + .explain_with_options( + ExplainOption::default().with_format(ExplainFormat::Tree), + )? + .show() + .await?; } benchmark_run.maybe_write_json(self.output_path.as_ref())?; } @@ -124,59 +150,72 @@ impl RunOpt { Ok(()) } - async fn register_data(&self, ctx: &SessionContext) -> Result<()> { + async fn register_data( + &self, + table_ref: impl Into, + table_path: impl AsRef, + ctx: &SessionContext, + ) -> Result<()> { let csv_options = Default::default(); let parquet_options = Default::default(); - let path = self.path.as_os_str().to_str().unwrap(); - - if self.path.extension().map(|s| s == "csv").unwrap_or(false) { - ctx.register_csv("x", path, csv_options) - .await - .map_err(|e| { - DataFusionError::Context( - format!("Registering 'table' as {path}"), - Box::new(e), - ) - }) - .expect("error registering csv"); - } - if self - .path + let table_path_str = table_path.as_ref(); + + let extension = Path::new(table_path_str) .extension() - .map(|s| s == "parquet") - .unwrap_or(false) - { - ctx.register_parquet("x", path, parquet_options) - .await - .map_err(|e| { - DataFusionError::Context( - format!("Registering 'table' as {path}"), - Box::new(e), - ) - }) - .expect("error registering parquet"); + .and_then(|s| s.to_str()) + .unwrap_or(""); + + match extension { + "csv" => { + ctx.register_csv(table_ref, table_path_str, csv_options) + .await + .map_err(|e| { + DataFusionError::Context( + format!("Registering 'table' as {table_path_str}"), + Box::new(e), + ) + }) + .expect("error registering csv"); + } + "parquet" => { + ctx.register_parquet(table_ref, table_path_str, parquet_options) + .await + .map_err(|e| { + DataFusionError::Context( + format!("Registering 'table' as {table_path_str}"), + Box::new(e), + ) + }) + .expect("error registering parquet"); + } + _ => { + return Err(DataFusionError::Plan(format!( + "Unsupported file extension: {extension}", + ))); + } } + Ok(()) } } -struct AllQueries { +pub struct AllQueries { queries: Vec, } impl AllQueries { - fn try_new(path: &Path) -> Result { + pub fn try_new(path: &Path) -> Result { let all_queries = std::fs::read_to_string(path) .map_err(|e| exec_datafusion_err!("Could not open {path:?}: {e}"))?; Ok(Self { - queries: all_queries.lines().map(|s| s.to_string()).collect(), + queries: all_queries.split("\n\n").map(|s| s.to_string()).collect(), }) } /// Returns the text of query `query_id` - fn get_query(&self, query_id: usize) -> Result<&str> { + pub fn get_query(&self, query_id: usize) -> Result<&str> { self.queries .get(query_id - 1) .ok_or_else(|| { @@ -189,11 +228,11 @@ impl AllQueries { .map(|s| s.as_str()) } - fn min_query_id(&self) -> usize { + pub fn min_query_id(&self) -> usize { 1 } - fn max_query_id(&self) -> usize { + pub fn max_query_id(&self) -> usize { self.queries.len() } } diff --git a/benchmarks/src/hj.rs b/benchmarks/src/hj.rs new file mode 100644 index 0000000000000..505b322745485 --- /dev/null +++ b/benchmarks/src/hj.rs @@ -0,0 +1,273 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; +use datafusion::physical_plan::execute_stream; +use datafusion::{error::Result, prelude::SessionContext}; +use datafusion_common::instant::Instant; +use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError}; +use structopt::StructOpt; + +use futures::StreamExt; + +// TODO: Add existence joins + +/// Run the Hash Join benchmark +/// +/// This micro-benchmark focuses on the performance characteristics of Hash Joins. +/// It uses simple equality predicates to ensure a hash join is selected. +/// Where we vary selectivity, we do so with additional cheap predicates that +/// do not change the join key (so the physical operator remains HashJoin). +#[derive(Debug, StructOpt, Clone)] +#[structopt(verbatim_doc_comment)] +pub struct RunOpt { + /// Query number (between 1 and 12). If not specified, runs all queries + #[structopt(short, long)] + query: Option, + + /// Common options (iterations, batch size, target_partitions, etc.) + #[structopt(flatten)] + common: CommonOpt, + + /// If present, write results json here + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, +} + +/// Inline SQL queries for Hash Join benchmarks +/// +/// Each query's comment includes: +/// - Left row count × Right row count +/// - Join predicate selectivity (approximate output fraction). +/// - Q11 and Q12 selectivity is relative to cartesian product while the others are +/// relative to probe side. +const HASH_QUERIES: &[&str] = &[ + // Q1: INNER 10 x 10K | LOW ~0.1% + // equality on key + cheap filter to downselect + r#" + SELECT t1.value, t2.value + FROM generate_series(0, 9000, 1000) AS t1(value) + JOIN range(10000) AS t2 + ON t1.value = t2.value; + "#, + // Q2: INNER 10 x 10K | LOW ~0.1% + r#" + SELECT t1.value, t2.value + FROM generate_series(0, 9000, 1000) AS t1 + JOIN range(10000) AS t2 + ON t1.value = t2.value + WHERE t1.value % 5 = 0 + "#, + // Q3: INNER 10K x 10K | HIGH ~90% + r#" + SELECT t1.value, t2.value + FROM range(10000) AS t1 + JOIN range(10000) AS t2 + ON t1.value = t2.value + WHERE t1.value % 10 <> 0 + "#, + // Q4: INNER 30 x 30K | LOW ~0.1% + r#" + SELECT t1.value, t2.value + FROM generate_series(0, 29000, 1000) AS t1 + JOIN range(30000) AS t2 + ON t1.value = t2.value + WHERE t1.value % 5 = 0 + "#, + // Q5: INNER 10 x 200K | VERY LOW ~0.005% (small to large) + r#" + SELECT t1.value, t2.value + FROM generate_series(0, 9000, 1000) AS t1 + JOIN range(200000) AS t2 + ON t1.value = t2.value + WHERE t1.value % 1000 = 0 + "#, + // Q6: INNER 200K x 10 | VERY LOW ~0.005% (large to small) + r#" + SELECT t1.value, t2.value + FROM range(200000) AS t1 + JOIN generate_series(0, 9000, 1000) AS t2 + ON t1.value = t2.value + WHERE t1.value % 1000 = 0 + "#, + // Q7: RIGHT OUTER 10 x 200K | LOW ~0.1% + // Outer join still uses HashJoin for equi-keys; the extra filter reduces matches + r#" + SELECT t1.value AS l, t2.value AS r + FROM generate_series(0, 9000, 1000) AS t1 + RIGHT JOIN range(200000) AS t2 + ON t1.value = t2.value + WHERE t2.value % 1000 = 0 + "#, + // Q8: LEFT OUTER 200K x 10 | LOW ~0.1% + r#" + SELECT t1.value AS l, t2.value AS r + FROM range(200000) AS t1 + LEFT JOIN generate_series(0, 9000, 1000) AS t2 + ON t1.value = t2.value + WHERE t1.value % 1000 = 0 + "#, + // Q9: FULL OUTER 30 x 30K | LOW ~0.1% + r#" + SELECT t1.value AS l, t2.value AS r + FROM generate_series(0, 29000, 1000) AS t1 + FULL JOIN range(30000) AS t2 + ON t1.value = t2.value + WHERE COALESCE(t1.value, t2.value) % 1000 = 0 + "#, + // Q10: FULL OUTER 30 x 30K | HIGH ~90% + r#" + SELECT t1.value AS l, t2.value AS r + FROM generate_series(0, 29000, 1000) AS t1 + FULL JOIN range(30000) AS t2 + ON t1.value = t2.value + WHERE COALESCE(t1.value, t2.value) % 10 <> 0 + "#, + // Q11: INNER 30 x 30K | MEDIUM ~50% | cheap predicate on parity + r#" + SELECT t1.value, t2.value + FROM generate_series(0, 29000, 1000) AS t1 + INNER JOIN range(30000) AS t2 + ON (t1.value % 2) = (t2.value % 2) + "#, + // Q12: FULL OUTER 30 x 30K | MEDIUM ~50% | expression key + r#" + SELECT t1.value AS l, t2.value AS r + FROM generate_series(0, 29000, 1000) AS t1 + FULL JOIN range(30000) AS t2 + ON (t1.value % 2) = (t2.value % 2) + "#, + // Q13: INNER 30 x 30K | LOW 0.1% | modulo with adding values + r#" + SELECT t1.value, t2.value + FROM generate_series(0, 29000, 1000) AS t1 + INNER JOIN range(30000) AS t2 + ON (t1.value = t2.value) AND ((t1.value + t2.value) % 10 < 1) + "#, + // Q14: FULL OUTER 30 x 30K | ALL ~100% | modulo + r#" + SELECT t1.value AS l, t2.value AS r + FROM generate_series(0, 29000, 1000) AS t1 + FULL JOIN range(30000) AS t2 + ON (t1.value = t2.value) AND ((t1.value + t2.value) % 10 = 0) + "#, +]; + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running Hash Join benchmarks with the following options: {self:#?}\n"); + + let query_range = match self.query { + Some(query_id) => { + if query_id >= 1 && query_id <= HASH_QUERIES.len() { + query_id..=query_id + } else { + return exec_err!( + "Query {query_id} not found. Available queries: 1 to {}", + HASH_QUERIES.len() + ); + } + } + None => 1..=HASH_QUERIES.len(), + }; + + let config = self.common.config()?; + let rt_builder = self.common.runtime_env_builder()?; + let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + + let mut benchmark_run = BenchmarkRun::new(); + + for query_id in query_range { + let query_index = query_id - 1; + let sql = HASH_QUERIES[query_index]; + + benchmark_run.start_new_case(&format!("Query {query_id}")); + let query_run = self.benchmark_query(sql, &query_id.to_string(), &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + return Err(DataFusionError::Context( + format!("Hash Join benchmark Q{query_id} failed with error:"), + Box::new(e), + )); + } + } + } + + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + Ok(()) + } + + /// Validates that the physical plan uses a HashJoin, then executes. + async fn benchmark_query( + &self, + sql: &str, + query_name: &str, + ctx: &SessionContext, + ) -> Result> { + let mut query_results = vec![]; + + // Build/validate plan + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let plan_string = format!("{physical_plan:#?}"); + + if !plan_string.contains("HashJoinExec") { + return Err(exec_datafusion_err!( + "Query {query_name} does not use Hash Join. Physical plan: {plan_string}" + )); + } + + // Execute without buffering + for i in 0..self.common.iterations { + let start = Instant::now(); + let row_count = Self::execute_sql_without_result_buffering(sql, ctx).await?; + let elapsed = start.elapsed(); + + println!( + "Query {query_name} iteration {i} returned {row_count} rows in {elapsed:?}" + ); + + query_results.push(QueryResult { elapsed, row_count }); + } + + Ok(query_results) + } + + /// Executes the SQL query and drops each batch to avoid result buffering. + async fn execute_sql_without_result_buffering( + sql: &str, + ctx: &SessionContext, + ) -> Result { + let mut row_count = 0; + + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let mut stream = execute_stream(physical_plan, ctx.task_ctx())?; + + while let Some(batch) = stream.next().await { + row_count += batch?.num_rows(); + // Drop batches immediately to minimize memory pressure + } + + Ok(row_count) + } +} diff --git a/benchmarks/src/imdb/mod.rs b/benchmarks/src/imdb/mod.rs index 6a45242e6ff4b..87462bc3e81ba 100644 --- a/benchmarks/src/imdb/mod.rs +++ b/benchmarks/src/imdb/mod.rs @@ -54,6 +54,9 @@ pub const IMDB_TABLES: &[&str] = &[ "person_info", ]; +pub const IMDB_QUERY_START_ID: usize = 1; +pub const IMDB_QUERY_END_ID: usize = 113; + /// Get the schema for the IMDB dataset tables /// see benchmarks/data/imdb/schematext.sql pub fn get_imdb_table_schema(table: &str) -> Schema { diff --git a/benchmarks/src/imdb/run.rs b/benchmarks/src/imdb/run.rs index d7d7a56d0540e..3d58d5f54d4ba 100644 --- a/benchmarks/src/imdb/run.rs +++ b/benchmarks/src/imdb/run.rs @@ -18,8 +18,11 @@ use std::path::PathBuf; use std::sync::Arc; -use super::{get_imdb_table_schema, get_query_sql, IMDB_TABLES}; -use crate::util::{BenchmarkRun, CommonOpt}; +use super::{ + get_imdb_table_schema, get_query_sql, IMDB_QUERY_END_ID, IMDB_QUERY_START_ID, + IMDB_TABLES, +}; +use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; @@ -51,7 +54,7 @@ type BoolDefaultTrue = bool; /// [2] and [3]. /// /// [1]: https://www.vldb.org/pvldb/vol9/p204-leis.pdf -/// [2]: http://homepages.cwi.nl/~boncz/job/imdb.tgz +/// [2]: https://event.cwi.nl/da/job/imdb.tgz /// [3]: https://db.in.tum.de/~leis/qo/job.tgz #[derive(Debug, StructOpt, Clone)] @@ -59,7 +62,7 @@ type BoolDefaultTrue = bool; pub struct RunOpt { /// Query number. If not specified, runs all queries #[structopt(short, long)] - query: Option, + pub query: Option, /// Common options #[structopt(flatten)] @@ -91,9 +94,6 @@ pub struct RunOpt { prefer_hash_join: BoolDefaultTrue, } -const IMDB_QUERY_START_ID: usize = 1; -const IMDB_QUERY_END_ID: usize = 113; - fn map_query_id_to_str(query_id: usize) -> &'static str { match query_id { // 1 @@ -303,7 +303,7 @@ impl RunOpt { async fn benchmark_query(&self, query_id: usize) -> Result> { let mut config = self .common - .config() + .config()? .with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; let rt_builder = self.common.runtime_env_builder()?; @@ -341,6 +341,9 @@ impl RunOpt { let avg = millis.iter().sum::() / millis.len() as f64; println!("Query {query_id} avg time: {avg:.2} ms"); + // Print memory usage stats using mimalloc (only when compiled with --features mimalloc_extended) + print_memory_stats(); + Ok(query_results) } @@ -471,15 +474,10 @@ impl RunOpt { fn partitions(&self) -> usize { self.common .partitions - .unwrap_or(get_available_parallelism()) + .unwrap_or_else(get_available_parallelism) } } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - #[cfg(test)] // Only run with "ci" mode when we have the data #[cfg(feature = "ci")] @@ -514,7 +512,7 @@ mod tests { let common = CommonOpt { iterations: 1, partitions: Some(2), - batch_size: 8192, + batch_size: Some(8192), mem_pool_type: "fair".to_string(), memory_limit: None, sort_spill_reservation_bytes: None, @@ -550,7 +548,7 @@ mod tests { let common = CommonOpt { iterations: 1, partitions: Some(2), - batch_size: 8192, + batch_size: Some(8192), mem_pool_type: "fair".to_string(), memory_limit: None, sort_spill_reservation_bytes: None, @@ -572,7 +570,7 @@ mod tests { let plan = ctx.sql(&query).await?; let plan = plan.create_physical_plan().await?; let bytes = physical_plan_to_bytes(plan.clone())?; - let plan2 = physical_plan_from_bytes(&bytes, &ctx)?; + let plan2 = physical_plan_from_bytes(&bytes, &ctx.task_ctx())?; let plan_formatted = format!("{}", displayable(plan.as_ref()).indent(false)); let plan2_formatted = format!("{}", displayable(plan2.as_ref()).indent(false)); diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs index a402fc1b8ce04..07cffa5ae468e 100644 --- a/benchmarks/src/lib.rs +++ b/benchmarks/src/lib.rs @@ -19,9 +19,9 @@ pub mod cancellation; pub mod clickbench; pub mod h2o; +pub mod hj; pub mod imdb; -pub mod parquet_filter; -pub mod sort; +pub mod nlj; pub mod sort_tpch; pub mod tpch; pub mod util; diff --git a/benchmarks/src/nlj.rs b/benchmarks/src/nlj.rs new file mode 100644 index 0000000000000..e412c0ade8a83 --- /dev/null +++ b/benchmarks/src/nlj.rs @@ -0,0 +1,264 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; +use datafusion::physical_plan::execute_stream; +use datafusion::{error::Result, prelude::SessionContext}; +use datafusion_common::instant::Instant; +use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError}; +use structopt::StructOpt; + +use futures::StreamExt; + +/// Run the Nested Loop Join (NLJ) benchmark +/// +/// This micro-benchmark focuses on the performance characteristics of NLJs. +/// +/// It always tries to use fast scanners (without decoding overhead) and +/// efficient predicate expressions to ensure it can reflect the performance +/// of the NLJ operator itself. +/// +/// In this micro-benchmark, the following workload characteristics will be +/// varied: +/// - Join type: Inner/Left/Right/Full (all for the NestedLoopJoin physical +/// operator) +/// TODO: Include special join types (Semi/Anti/Mark joins) +/// - Input size: Different combinations of left (build) side and right (probe) +/// side sizes +/// - Selectivity of join filters +#[derive(Debug, StructOpt, Clone)] +#[structopt(verbatim_doc_comment)] +pub struct RunOpt { + /// Query number (between 1 and 10). If not specified, runs all queries + #[structopt(short, long)] + query: Option, + + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// If present, write results json here + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, +} + +/// Inline SQL queries for NLJ benchmarks +/// +/// Each query's comment includes: +/// - Left (build) side row count × Right (probe) side row count +/// - Join predicate selectivity (1% means the output size is 1% * input size) +const NLJ_QUERIES: &[&str] = &[ + // Q1: INNER 10K x 10K | LOW 0.1% + r#" + SELECT * + FROM range(10000) AS t1 + JOIN range(10000) AS t2 + ON (t1.value + t2.value) % 1000 = 0; + "#, + // Q2: INNER 10K x 10K | Medium 20% + r#" + SELECT * + FROM range(10000) AS t1 + JOIN range(10000) AS t2 + ON (t1.value + t2.value) % 5 = 0; + "#, + // Q3: INNER 10K x 10K | High 90% + r#" + SELECT * + FROM range(10000) AS t1 + JOIN range(10000) AS t2 + ON (t1.value + t2.value) % 10 <> 0; + "#, + // Q4: INNER 30K x 30K | Medium 20% + r#" + SELECT * + FROM range(30000) AS t1 + JOIN range(30000) AS t2 + ON (t1.value + t2.value) % 5 = 0; + "#, + // Q5: INNER 10K x 200K | LOW 0.1% (small to large) + r#" + SELECT * + FROM range(10000) AS t1 + JOIN range(200000) AS t2 + ON (t1.value + t2.value) % 1000 = 0; + "#, + // Q6: INNER 200K x 10K | LOW 0.1% (large to small) + r#" + SELECT * + FROM range(200000) AS t1 + JOIN range(10000) AS t2 + ON (t1.value + t2.value) % 1000 = 0; + "#, + // Q7: RIGHT OUTER 10K x 200K | LOW 0.1% + r#" + SELECT * + FROM range(10000) AS t1 + RIGHT JOIN range(200000) AS t2 + ON (t1.value + t2.value) % 1000 = 0; + "#, + // Q8: LEFT OUTER 200K x 10K | LOW 0.1% + r#" + SELECT * + FROM range(200000) AS t1 + LEFT JOIN range(10000) AS t2 + ON (t1.value + t2.value) % 1000 = 0; + "#, + // Q9: FULL OUTER 30K x 30K | LOW 0.1% + r#" + SELECT * + FROM range(30000) AS t1 + FULL JOIN range(30000) AS t2 + ON (t1.value + t2.value) % 1000 = 0; + "#, + // Q10: FULL OUTER 30K x 30K | High 90% + r#" + SELECT * + FROM range(30000) AS t1 + FULL JOIN range(30000) AS t2 + ON (t1.value + t2.value) % 10 <> 0; + "#, + // Q11: INNER 30K x 30K | MEDIUM 50% | cheap predicate + r#" + SELECT * + FROM range(30000) AS t1 + INNER JOIN range(30000) AS t2 + ON (t1.value > t2.value); + "#, + // Q12: FULL OUTER 30K x 30K | MEDIUM 50% | cheap predicate + r#" + SELECT * + FROM range(30000) AS t1 + FULL JOIN range(30000) AS t2 + ON (t1.value > t2.value); + "#, +]; + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running NLJ benchmarks with the following options: {self:#?}\n"); + + // Define query range + let query_range = match self.query { + Some(query_id) => { + if query_id >= 1 && query_id <= NLJ_QUERIES.len() { + query_id..=query_id + } else { + return exec_err!( + "Query {query_id} not found. Available queries: 1 to {}", + NLJ_QUERIES.len() + ); + } + } + None => 1..=NLJ_QUERIES.len(), + }; + + let config = self.common.config()?; + let rt_builder = self.common.runtime_env_builder()?; + let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); + + let mut benchmark_run = BenchmarkRun::new(); + for query_id in query_range { + let query_index = query_id - 1; // Convert 1-based to 0-based index + + let sql = NLJ_QUERIES[query_index]; + benchmark_run.start_new_case(&format!("Query {query_id}")); + let query_run = self.benchmark_query(sql, &query_id.to_string(), &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + return Err(DataFusionError::Context( + "NLJ benchmark Q{query_id} failed with error:".to_string(), + Box::new(e), + )); + } + } + } + + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + Ok(()) + } + + /// Validates that the query's physical plan uses a NestedLoopJoin (NLJ), + /// then executes the query and collects execution times. + /// + /// TODO: ensure the optimizer won't change the join order (it's not at + /// v48.0.0). + async fn benchmark_query( + &self, + sql: &str, + query_name: &str, + ctx: &SessionContext, + ) -> Result> { + let mut query_results = vec![]; + + // Validate that the query plan includes a Nested Loop Join + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let plan_string = format!("{physical_plan:#?}"); + + if !plan_string.contains("NestedLoopJoinExec") { + return Err(exec_datafusion_err!( + "Query {query_name} does not use Nested Loop Join. Physical plan: {plan_string}" + )); + } + + for i in 0..self.common.iterations { + let start = Instant::now(); + + let row_count = Self::execute_sql_without_result_buffering(sql, ctx).await?; + + let elapsed = start.elapsed(); + + println!( + "Query {query_name} iteration {i} returned {row_count} rows in {elapsed:?}" + ); + + query_results.push(QueryResult { elapsed, row_count }); + } + + Ok(query_results) + } + + /// Executes the SQL query and drops each result batch after evaluation, to + /// minimizes memory usage by not buffering results. + /// + /// Returns the total result row count + async fn execute_sql_without_result_buffering( + sql: &str, + ctx: &SessionContext, + ) -> Result { + let mut row_count = 0; + + let df = ctx.sql(sql).await?; + let physical_plan = df.create_physical_plan().await?; + let mut stream = execute_stream(physical_plan, ctx.task_ctx())?; + + while let Some(batch) = stream.next().await { + row_count += batch?.num_rows(); + + // Evaluate the result and do nothing, the result will be dropped + // to reduce memory pressure + } + + Ok(row_count) + } +} diff --git a/benchmarks/src/parquet_filter.rs b/benchmarks/src/parquet_filter.rs deleted file mode 100644 index 34103af0ffd21..0000000000000 --- a/benchmarks/src/parquet_filter.rs +++ /dev/null @@ -1,194 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::path::PathBuf; - -use crate::util::{AccessLogOpt, BenchmarkRun, CommonOpt}; - -use arrow::util::pretty; -use datafusion::common::Result; -use datafusion::logical_expr::utils::disjunction; -use datafusion::logical_expr::{lit, or, Expr}; -use datafusion::physical_plan::collect; -use datafusion::prelude::{col, SessionContext}; -use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; -use datafusion_common::instant::Instant; - -use structopt::StructOpt; - -/// Test performance of parquet filter pushdown -/// -/// The queries are executed on a synthetic dataset generated during -/// the benchmark execution and designed to simulate web server access -/// logs. -/// -/// Example -/// -/// dfbench parquet-filter --path ./data --scale-factor 1.0 -/// -/// generates the synthetic dataset at `./data/logs.parquet`. The size -/// of the dataset can be controlled through the `size_factor` -/// (with the default value of `1.0` generating a ~1GB parquet file). -/// -/// For each filter we will run the query using different -/// `ParquetScanOption` settings. -/// -/// Example output: -/// -/// Running benchmarks with the following options: Opt { debug: false, iterations: 3, partitions: 2, path: "./data", batch_size: 8192, scale_factor: 1.0 } -/// Generated test dataset with 10699521 rows -/// Executing with filter 'request_method = Utf8("GET")' -/// Using scan options ParquetScanOptions { pushdown_filters: false, reorder_predicates: false, enable_page_index: false } -/// Iteration 0 returned 10699521 rows in 1303 ms -/// Iteration 1 returned 10699521 rows in 1288 ms -/// Iteration 2 returned 10699521 rows in 1266 ms -/// Using scan options ParquetScanOptions { pushdown_filters: true, reorder_predicates: true, enable_page_index: true } -/// Iteration 0 returned 1781686 rows in 1970 ms -/// Iteration 1 returned 1781686 rows in 2002 ms -/// Iteration 2 returned 1781686 rows in 1988 ms -/// Using scan options ParquetScanOptions { pushdown_filters: true, reorder_predicates: false, enable_page_index: true } -/// Iteration 0 returned 1781686 rows in 1940 ms -/// Iteration 1 returned 1781686 rows in 1986 ms -/// Iteration 2 returned 1781686 rows in 1947 ms -/// ... -#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] -pub struct RunOpt { - /// Common options - #[structopt(flatten)] - common: CommonOpt, - - /// Create data files - #[structopt(flatten)] - access_log: AccessLogOpt, - - /// Path to machine readable output file - #[structopt(parse(from_os_str), short = "o", long = "output")] - output_path: Option, -} - -impl RunOpt { - pub async fn run(self) -> Result<()> { - let test_file = self.access_log.build()?; - - let mut rundata = BenchmarkRun::new(); - let scan_options_matrix = vec![ - ParquetScanOptions { - pushdown_filters: false, - reorder_filters: false, - enable_page_index: false, - }, - ParquetScanOptions { - pushdown_filters: true, - reorder_filters: true, - enable_page_index: true, - }, - ParquetScanOptions { - pushdown_filters: true, - reorder_filters: true, - enable_page_index: false, - }, - ]; - - let filter_matrix = vec![ - ("Selective-ish filter", col("request_method").eq(lit("GET"))), - ( - "Non-selective filter", - col("request_method").not_eq(lit("GET")), - ), - ( - "Basic conjunction", - col("request_method") - .eq(lit("POST")) - .and(col("response_status").eq(lit(503_u16))), - ), - ( - "Nested filters", - col("request_method").eq(lit("POST")).and(or( - col("response_status").eq(lit(503_u16)), - col("response_status").eq(lit(403_u16)), - )), - ), - ( - "Many filters", - disjunction([ - col("request_method").not_eq(lit("GET")), - col("response_status").eq(lit(400_u16)), - col("service").eq(lit("backend")), - ]) - .unwrap(), - ), - ("Filter everything", col("response_status").eq(lit(429_u16))), - ("Filter nothing", col("response_status").gt(lit(0_u16))), - ]; - - for (name, filter_expr) in &filter_matrix { - println!("Executing '{name}' (filter: {filter_expr})"); - for scan_options in &scan_options_matrix { - println!("Using scan options {scan_options:?}"); - rundata.start_new_case(&format!( - "{name}: {}", - parquet_scan_disp(scan_options) - )); - for i in 0..self.common.iterations { - let config = self.common.update_config(scan_options.config()); - let ctx = SessionContext::new_with_config(config); - - let (rows, elapsed) = exec_scan( - &ctx, - &test_file, - filter_expr.clone(), - self.common.debug, - ) - .await?; - let ms = elapsed.as_secs_f64() * 1000.0; - println!("Iteration {i} returned {rows} rows in {ms} ms"); - rundata.write_iter(elapsed, rows); - } - } - println!("\n"); - } - rundata.maybe_write_json(self.output_path.as_ref())?; - Ok(()) - } -} - -fn parquet_scan_disp(opts: &ParquetScanOptions) -> String { - format!( - "pushdown_filters={}, reorder_filters={}, page_index={}", - opts.pushdown_filters, opts.reorder_filters, opts.enable_page_index - ) -} - -async fn exec_scan( - ctx: &SessionContext, - test_file: &TestParquetFile, - filter: Expr, - debug: bool, -) -> Result<(usize, std::time::Duration)> { - let start = Instant::now(); - let exec = test_file.create_scan(ctx, Some(filter)).await?; - - let task_ctx = ctx.task_ctx(); - let result = collect(exec, task_ctx).await?; - let elapsed = start.elapsed(); - if debug { - pretty::print_batches(&result)?; - } - let rows = result.iter().map(|b| b.num_rows()).sum(); - Ok((rows, elapsed)) -} diff --git a/benchmarks/src/sort.rs b/benchmarks/src/sort.rs deleted file mode 100644 index 9cf09c57205a7..0000000000000 --- a/benchmarks/src/sort.rs +++ /dev/null @@ -1,187 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::path::PathBuf; -use std::sync::Arc; - -use crate::util::{AccessLogOpt, BenchmarkRun, CommonOpt}; - -use arrow::util::pretty; -use datafusion::common::Result; -use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion::physical_plan::collect; -use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion::test_util::parquet::TestParquetFile; -use datafusion_common::instant::Instant; -use datafusion_common::utils::get_available_parallelism; -use structopt::StructOpt; - -/// Test performance of sorting large datasets -/// -/// This test sorts a a synthetic dataset generated during the -/// benchmark execution, designed to simulate sorting web server -/// access logs. Such sorting is often done during data transformation -/// steps. -/// -/// The tests sort the entire dataset using several different sort -/// orders. -/// -/// Example: -/// -/// dfbench sort --path ./data --scale-factor 1.0 -#[derive(Debug, StructOpt, Clone)] -#[structopt(verbatim_doc_comment)] -pub struct RunOpt { - /// Common options - #[structopt(flatten)] - common: CommonOpt, - - /// Create data files - #[structopt(flatten)] - access_log: AccessLogOpt, - - /// Path to machine readable output file - #[structopt(parse(from_os_str), short = "o", long = "output")] - output_path: Option, -} - -impl RunOpt { - pub async fn run(self) -> Result<()> { - let test_file = self.access_log.build()?; - - use datafusion::physical_expr::expressions::col; - let mut rundata = BenchmarkRun::new(); - let schema = test_file.schema(); - let sort_cases = vec![ - ( - "sort utf8", - LexOrdering::new(vec![PhysicalSortExpr { - expr: col("request_method", &schema)?, - options: Default::default(), - }]), - ), - ( - "sort int", - LexOrdering::new(vec![PhysicalSortExpr { - expr: col("response_bytes", &schema)?, - options: Default::default(), - }]), - ), - ( - "sort decimal", - LexOrdering::new(vec![PhysicalSortExpr { - expr: col("decimal_price", &schema)?, - options: Default::default(), - }]), - ), - ( - "sort integer tuple", - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: col("request_bytes", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("response_bytes", &schema)?, - options: Default::default(), - }, - ]), - ), - ( - "sort utf8 tuple", - LexOrdering::new(vec![ - // sort utf8 tuple - PhysicalSortExpr { - expr: col("service", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("host", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("pod", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("image", &schema)?, - options: Default::default(), - }, - ]), - ), - ( - "sort mixed tuple", - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: col("service", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("request_bytes", &schema)?, - options: Default::default(), - }, - PhysicalSortExpr { - expr: col("decimal_price", &schema)?, - options: Default::default(), - }, - ]), - ), - ]; - for (title, expr) in sort_cases { - println!("Executing '{title}' (sorting by: {expr:?})"); - rundata.start_new_case(title); - for i in 0..self.common.iterations { - let config = SessionConfig::new().with_target_partitions( - self.common - .partitions - .unwrap_or(get_available_parallelism()), - ); - let ctx = SessionContext::new_with_config(config); - let (rows, elapsed) = - exec_sort(&ctx, &expr, &test_file, self.common.debug).await?; - let ms = elapsed.as_secs_f64() * 1000.0; - println!("Iteration {i} finished in {ms} ms"); - rundata.write_iter(elapsed, rows); - } - println!("\n"); - } - if let Some(path) = &self.output_path { - std::fs::write(path, rundata.to_json())?; - } - Ok(()) - } -} - -async fn exec_sort( - ctx: &SessionContext, - expr: &LexOrdering, - test_file: &TestParquetFile, - debug: bool, -) -> Result<(usize, std::time::Duration)> { - let start = Instant::now(); - let scan = test_file.create_scan(ctx, None).await?; - let exec = Arc::new(SortExec::new(expr.clone(), scan)); - let task_ctx = ctx.task_ctx(); - let result = collect(exec, task_ctx).await?; - let elapsed = start.elapsed(); - if debug { - pretty::print_batches(&result)?; - } - let rows = result.iter().map(|b| b.num_rows()).sum(); - Ok((rows, elapsed)) -} diff --git a/benchmarks/src/sort_tpch.rs b/benchmarks/src/sort_tpch.rs index 956bb92b6c78d..09b5a676bbff1 100644 --- a/benchmarks/src/sort_tpch.rs +++ b/benchmarks/src/sort_tpch.rs @@ -40,7 +40,7 @@ use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; use datafusion_common::DEFAULT_PARQUET_EXTENSION; -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; #[derive(Debug, StructOpt)] pub struct RunOpt { @@ -50,7 +50,7 @@ pub struct RunOpt { /// Sort query number. If not specified, runs all queries #[structopt(short, long)] - query: Option, + pub query: Option, /// Path to data files (lineitem). Only parquet format is supported #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] @@ -63,13 +63,20 @@ pub struct RunOpt { /// Load the data into a MemTable before executing the query #[structopt(short = "m", long = "mem-table")] mem_table: bool, -} -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, + /// Mark the first column of each table as sorted in ascending order. + /// The tables should have been created with the `--sort` option for this to have any effect. + #[structopt(short = "t", long = "sorted")] + sorted: bool, + + /// Append a `LIMIT n` clause to the query + #[structopt(short = "l", long = "limit")] + limit: Option, } +pub const SORT_TPCH_QUERY_START_ID: usize = 1; +pub const SORT_TPCH_QUERY_END_ID: usize = 11; + impl RunOpt { const SORT_TABLES: [&'static str; 1] = ["lineitem"]; @@ -163,37 +170,45 @@ impl RunOpt { r#" SELECT l_shipmode, l_comment, l_partkey FROM lineitem - ORDER BY l_shipmode; + ORDER BY l_shipmode "#, ]; /// If query is specified from command line, run only that query. /// Otherwise, run all queries. pub async fn run(&self) -> Result<()> { - let mut benchmark_run = BenchmarkRun::new(); + let mut benchmark_run: BenchmarkRun = BenchmarkRun::new(); let query_range = match self.query { Some(query_id) => query_id..=query_id, - None => 1..=Self::SORT_QUERIES.len(), + None => SORT_TPCH_QUERY_START_ID..=SORT_TPCH_QUERY_END_ID, }; for query_id in query_range { benchmark_run.start_new_case(&format!("{query_id}")); - let query_results = self.benchmark_query(query_id).await?; - for iter in query_results { - benchmark_run.write_iter(iter.elapsed, iter.row_count); + let query_results = self.benchmark_query(query_id).await; + match query_results { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } } } benchmark_run.maybe_write_json(self.output_path.as_ref())?; - + benchmark_run.maybe_print_failures(); Ok(()) } /// Benchmark query `query_id` in `SORT_QUERIES` async fn benchmark_query(&self, query_id: usize) -> Result> { - let config = self.common.config(); + let config = self.common.config()?; let rt_builder = self.common.runtime_env_builder()?; let state = SessionStateBuilder::new() .with_config(config) @@ -212,22 +227,30 @@ impl RunOpt { let start = Instant::now(); let query_idx = query_id - 1; // 1-indexed -> 0-indexed - let sql = Self::SORT_QUERIES[query_idx]; + let base_sql = Self::SORT_QUERIES[query_idx].to_string(); + let sql = if let Some(limit) = self.limit { + format!("{base_sql} LIMIT {limit}") + } else { + base_sql + }; - let row_count = self.execute_query(&ctx, sql).await?; + let row_count = self.execute_query(&ctx, sql.as_str()).await?; let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0; let ms = elapsed.as_secs_f64() * 1000.0; millis.push(ms); println!( - "Q{query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" ); query_results.push(QueryResult { elapsed, row_count }); } let avg = millis.iter().sum::() / millis.len() as f64; - println!("Q{query_id} avg time: {avg:.2} ms"); + println!("Query {query_id} avg time: {avg:.2} ms"); + + // Print memory usage stats using mimalloc (only when compiled with --features mimalloc_extended) + print_memory_stats(); Ok(query_results) } @@ -280,7 +303,7 @@ impl RunOpt { let mut stream = execute_stream(physical_plan.clone(), state.task_ctx())?; while let Some(batch) = stream.next().await { - row_count += batch.unwrap().num_rows(); + row_count += batch?.num_rows(); } if debug { @@ -315,8 +338,18 @@ impl RunOpt { .with_collect_stat(state.config().collect_statistics()); let table_path = ListingTableUrl::parse(path)?; - let config = ListingTableConfig::new(table_path).with_listing_options(options); - let config = config.infer_schema(&state).await?; + let schema = options.infer_schema(&state, &table_path).await?; + let options = if self.sorted { + let key_column_name = schema.fields()[0].name(); + options + .with_file_sort_order(vec![vec![col(key_column_name).sort(true, false)]]) + } else { + options + }; + + let config = ListingTableConfig::new(table_path) + .with_listing_options(options) + .with_schema(schema); Ok(Arc::new(ListingTable::try_new(config)?)) } @@ -328,6 +361,6 @@ impl RunOpt { fn partitions(&self) -> usize { self.common .partitions - .unwrap_or(get_available_parallelism()) + .unwrap_or_else(get_available_parallelism) } } diff --git a/benchmarks/src/tpch/convert.rs b/benchmarks/src/tpch/convert.rs index 7f391d930045a..5219e09cd3052 100644 --- a/benchmarks/src/tpch/convert.rs +++ b/benchmarks/src/tpch/convert.rs @@ -22,15 +22,14 @@ use std::path::{Path, PathBuf}; use datafusion::common::not_impl_err; +use super::get_tbl_tpch_table_schema; +use super::TPCH_TABLES; use datafusion::error::Result; use datafusion::prelude::*; use parquet::basic::Compression; use parquet::file::properties::WriterProperties; use structopt::StructOpt; -use super::get_tbl_tpch_table_schema; -use super::TPCH_TABLES; - /// Convert tpch .slt files to .parquet or .csv files #[derive(Debug, StructOpt)] pub struct ConvertOpt { @@ -57,6 +56,10 @@ pub struct ConvertOpt { /// Batch size when reading CSV or Parquet files #[structopt(short = "s", long = "batch-size", default_value = "8192")] batch_size: usize, + + /// Sort each table by its first column in ascending order. + #[structopt(short = "t", long = "sort")] + sort: bool, } impl ConvertOpt { @@ -70,6 +73,7 @@ impl ConvertOpt { for table in TPCH_TABLES { let start = Instant::now(); let schema = get_tbl_tpch_table_schema(table); + let key_column_name = schema.fields()[0].name(); let input_path = format!("{input_path}/{table}.tbl"); let options = CsvReadOptions::new() @@ -77,6 +81,13 @@ impl ConvertOpt { .has_header(false) .delimiter(b'|') .file_extension(".tbl"); + let options = if self.sort { + // indicated that the file is already sorted by its first column to speed up the conversion + options + .file_sort_order(vec![vec![col(key_column_name).sort(true, false)]]) + } else { + options + }; let config = SessionConfig::new().with_batch_size(self.batch_size); let ctx = SessionContext::new_with_config(config); @@ -99,6 +110,11 @@ impl ConvertOpt { if partitions > 1 { csv = csv.repartition(Partitioning::RoundRobinBatch(partitions))? } + let csv = if self.sort { + csv.sort_by(vec![col(key_column_name)])? + } else { + csv + }; // create the physical plan let csv = csv.create_physical_plan().await?; diff --git a/benchmarks/src/tpch/mod.rs b/benchmarks/src/tpch/mod.rs index 23d0681f560c8..233ea94a05c1a 100644 --- a/benchmarks/src/tpch/mod.rs +++ b/benchmarks/src/tpch/mod.rs @@ -34,6 +34,9 @@ pub const TPCH_TABLES: &[&str] = &[ "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region", ]; +pub const TPCH_QUERY_START_ID: usize = 1; +pub const TPCH_QUERY_END_ID: usize = 22; + /// The `.tbl` file contains a trailing column pub fn get_tbl_tpch_table_schema(table: &str) -> Schema { let mut schema = SchemaBuilder::from(get_tpch_table_schema(table).fields); diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index eb9db821db02f..b93bdf254a279 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -19,9 +19,10 @@ use std::path::PathBuf; use std::sync::Arc; use super::{ - get_query_sql, get_tbl_tpch_table_schema, get_tpch_table_schema, TPCH_TABLES, + get_query_sql, get_tbl_tpch_table_schema, get_tpch_table_schema, TPCH_QUERY_END_ID, + TPCH_QUERY_START_ID, TPCH_TABLES, }; -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; @@ -53,14 +54,14 @@ type BoolDefaultTrue = bool; /// [2]. /// /// [1]: http://www.tpc.org/tpch/ -/// [2]: https://github.com/databricks/tpch-dbgen.git, +/// [2]: https://github.com/databricks/tpch-dbgen.git /// [2.17.1]: https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf #[derive(Debug, StructOpt, Clone)] #[structopt(verbatim_doc_comment)] pub struct RunOpt { /// Query number. If not specified, runs all queries #[structopt(short, long)] - query: Option, + pub query: Option, /// Common options #[structopt(flatten)] @@ -90,10 +91,12 @@ pub struct RunOpt { /// True by default. #[structopt(short = "j", long = "prefer_hash_join", default_value = "true")] prefer_hash_join: BoolDefaultTrue, -} -const TPCH_QUERY_START_ID: usize = 1; -const TPCH_QUERY_END_ID: usize = 22; + /// Mark the first column of each table as sorted in ascending order. + /// The tables should have been created with the `--sort` option for this to have any effect. + #[structopt(short = "t", long = "sorted")] + sorted: bool, +} impl RunOpt { pub async fn run(self) -> Result<()> { @@ -104,55 +107,68 @@ impl RunOpt { }; let mut benchmark_run = BenchmarkRun::new(); - for query_id in query_range { - benchmark_run.start_new_case(&format!("Query {query_id}")); - let query_run = self.benchmark_query(query_id).await?; - for iter in query_run { - benchmark_run.write_iter(iter.elapsed, iter.row_count); - } - } - benchmark_run.maybe_write_json(self.output_path.as_ref())?; - Ok(()) - } - - async fn benchmark_query(&self, query_id: usize) -> Result> { let mut config = self .common - .config() + .config()? .with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; let rt_builder = self.common.runtime_env_builder()?; let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); - // register tables self.register_tables(&ctx).await?; + for query_id in query_range { + benchmark_run.start_new_case(&format!("Query {query_id}")); + let query_run = self.benchmark_query(query_id, &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } + } + } + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + benchmark_run.maybe_print_failures(); + Ok(()) + } + + async fn benchmark_query( + &self, + query_id: usize, + ctx: &SessionContext, + ) -> Result> { let mut millis = vec![]; // run benchmark let mut query_results = vec![]; + + let sql = &get_query_sql(query_id)?; + for i in 0..self.iterations() { let start = Instant::now(); - let sql = &get_query_sql(query_id)?; - // query 15 is special, with 3 statements. the second statement is the one from which we // want to capture the results let mut result = vec![]; if query_id == 15 { for (n, query) in sql.iter().enumerate() { if n == 1 { - result = self.execute_query(&ctx, query).await?; + result = self.execute_query(ctx, query).await?; } else { - self.execute_query(&ctx, query).await?; + self.execute_query(ctx, query).await?; } } } else { for query in sql { - result = self.execute_query(&ctx, query).await?; + result = self.execute_query(ctx, query).await?; } } - let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0; + let elapsed = start.elapsed(); let ms = elapsed.as_secs_f64() * 1000.0; millis.push(ms); info!("output:\n\n{}\n\n", pretty_format_batches(&result)?); @@ -166,6 +182,9 @@ impl RunOpt { let avg = millis.iter().sum::() / millis.len() as f64; println!("Query {query_id} avg time: {avg:.2} ms"); + // Print memory stats using mimalloc (only when compiled with --features mimalloc_extended) + print_memory_stats(); + Ok(query_results) } @@ -256,7 +275,7 @@ impl RunOpt { (Arc::new(format), path, ".tbl") } "csv" => { - let path = format!("{path}/{table}"); + let path = format!("{path}/csv/{table}"); let format = CsvFormat::default() .with_delimiter(b',') .with_has_header(true); @@ -275,20 +294,28 @@ impl RunOpt { } }; + let table_path = ListingTableUrl::parse(path)?; let options = ListingOptions::new(format) .with_file_extension(extension) .with_target_partitions(target_partitions) .with_collect_stat(state.config().collect_statistics()); - - let table_path = ListingTableUrl::parse(path)?; - let config = ListingTableConfig::new(table_path).with_listing_options(options); - - let config = match table_format { - "parquet" => config.infer_schema(&state).await?, - "tbl" => config.with_schema(Arc::new(get_tbl_tpch_table_schema(table))), - "csv" => config.with_schema(Arc::new(get_tpch_table_schema(table))), + let schema = match table_format { + "parquet" => options.infer_schema(&state, &table_path).await?, + "tbl" => Arc::new(get_tbl_tpch_table_schema(table)), + "csv" => Arc::new(get_tpch_table_schema(table)), _ => unreachable!(), }; + let options = if self.sorted { + let key_column_name = schema.fields()[0].name(); + options + .with_file_sort_order(vec![vec![col(key_column_name).sort(true, false)]]) + } else { + options + }; + + let config = ListingTableConfig::new(table_path) + .with_listing_options(options) + .with_schema(schema); Ok(Arc::new(ListingTable::try_new(config)?)) } @@ -300,15 +327,10 @@ impl RunOpt { fn partitions(&self) -> usize { self.common .partitions - .unwrap_or(get_available_parallelism()) + .unwrap_or_else(get_available_parallelism) } } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - #[cfg(test)] // Only run with "ci" mode when we have the data #[cfg(feature = "ci")] @@ -342,7 +364,7 @@ mod tests { let common = CommonOpt { iterations: 1, partitions: Some(2), - batch_size: 8192, + batch_size: Some(8192), mem_pool_type: "fair".to_string(), memory_limit: None, sort_spill_reservation_bytes: None, @@ -357,6 +379,7 @@ mod tests { output_path: None, disable_statistics: false, prefer_hash_join: true, + sorted: false, }; opt.register_tables(&ctx).await?; let queries = get_query_sql(query)?; @@ -378,7 +401,7 @@ mod tests { let common = CommonOpt { iterations: 1, partitions: Some(2), - batch_size: 8192, + batch_size: Some(8192), mem_pool_type: "fair".to_string(), memory_limit: None, sort_spill_reservation_bytes: None, @@ -393,6 +416,7 @@ mod tests { output_path: None, disable_statistics: false, prefer_hash_join: true, + sorted: false, }; opt.register_tables(&ctx).await?; let queries = get_query_sql(query)?; @@ -400,7 +424,7 @@ mod tests { let plan = ctx.sql(&query).await?; let plan = plan.create_physical_plan().await?; let bytes = physical_plan_to_bytes(plan.clone())?; - let plan2 = physical_plan_from_bytes(&bytes, &ctx)?; + let plan2 = physical_plan_from_bytes(&bytes, &ctx.task_ctx())?; let plan_formatted = format!("{}", displayable(plan.as_ref()).indent(false)); let plan2_formatted = format!("{}", displayable(plan2.as_ref()).indent(false)); diff --git a/benchmarks/src/util/access_log.rs b/benchmarks/src/util/access_log.rs deleted file mode 100644 index 2b29465ee20e3..0000000000000 --- a/benchmarks/src/util/access_log.rs +++ /dev/null @@ -1,74 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Benchmark data generation - -use datafusion::common::Result; -use datafusion::test_util::parquet::TestParquetFile; -use parquet::file::properties::WriterProperties; -use std::path::PathBuf; -use structopt::StructOpt; -use test_utils::AccessLogGenerator; - -// Options and builder for making an access log test file -// Note don't use docstring or else it ends up in help -#[derive(Debug, StructOpt, Clone)] -pub struct AccessLogOpt { - /// Path to folder where access log file will be generated - #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] - path: PathBuf, - - /// Data page size of the generated parquet file - #[structopt(long = "page-size")] - page_size: Option, - - /// Data page size of the generated parquet file - #[structopt(long = "row-group-size")] - row_group_size: Option, - - /// Total size of generated dataset. The default scale factor of 1.0 will generate a roughly 1GB parquet file - #[structopt(long = "scale-factor", default_value = "1.0")] - scale_factor: f32, -} - -impl AccessLogOpt { - /// Create the access log and return the file. - /// - /// See [`TestParquetFile`] for more details - pub fn build(self) -> Result { - let path = self.path.join("logs.parquet"); - - let mut props_builder = WriterProperties::builder(); - - if let Some(s) = self.page_size { - props_builder = props_builder - .set_data_page_size_limit(s) - .set_write_batch_size(s); - } - - if let Some(s) = self.row_group_size { - props_builder = props_builder.set_max_row_group_size(s); - } - let props = props_builder.build(); - - let generator = AccessLogGenerator::new(); - - let num_batches = 100_f32 * self.scale_factor; - - TestParquetFile::try_new(path, props, generator.take(num_batches as usize)) - } -} diff --git a/benchmarks/src/util/memory.rs b/benchmarks/src/util/memory.rs new file mode 100644 index 0000000000000..944239df31cfd --- /dev/null +++ b/benchmarks/src/util/memory.rs @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Print Peak RSS, Peak Commit, Page Faults based on mimalloc api +pub fn print_memory_stats() { + #[cfg(all(feature = "mimalloc", feature = "mimalloc_extended"))] + { + use datafusion::execution::memory_pool::human_readable_size; + let mut peak_rss = 0; + let mut peak_commit = 0; + let mut page_faults = 0; + unsafe { + libmimalloc_sys::mi_process_info( + std::ptr::null_mut(), + std::ptr::null_mut(), + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut peak_rss, + std::ptr::null_mut(), + &mut peak_commit, + &mut page_faults, + ); + } + + // When modifying this output format, make sure to update the corresponding + // parsers in `mem_profile.rs`, specifically `parse_vm_line` and `parse_query_time`, + // to keep the log output and parser logic in sync. + println!( + "Peak RSS: {}, Peak Commit: {}, Page Faults: {}", + if peak_rss == 0 { + "N/A".to_string() + } else { + human_readable_size(peak_rss) + }, + if peak_commit == 0 { + "N/A".to_string() + } else { + human_readable_size(peak_commit) + }, + page_faults + ); + } +} diff --git a/benchmarks/src/util/mod.rs b/benchmarks/src/util/mod.rs index 95c6e5f53d0f0..ab4579a566f66 100644 --- a/benchmarks/src/util/mod.rs +++ b/benchmarks/src/util/mod.rs @@ -16,10 +16,10 @@ // under the License. //! Shared benchmark utilities -mod access_log; +mod memory; mod options; mod run; -pub use access_log::AccessLogOpt; +pub use memory::print_memory_stats; pub use options::CommonOpt; -pub use run::{BenchQuery, BenchmarkRun}; +pub use run::{BenchQuery, BenchmarkRun, QueryResult}; diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs index a1cf31525dd92..6627a287dfcd4 100644 --- a/benchmarks/src/util/options.rs +++ b/benchmarks/src/util/options.rs @@ -19,13 +19,13 @@ use std::{num::NonZeroUsize, sync::Arc}; use datafusion::{ execution::{ - disk_manager::DiskManagerConfig, + disk_manager::DiskManagerBuilder, memory_pool::{FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool}, runtime_env::RuntimeEnvBuilder, }, prelude::SessionConfig, }; -use datafusion_common::{utils::get_available_parallelism, DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result}; use structopt::StructOpt; // Common benchmark options (don't use doc comments otherwise this doc @@ -41,8 +41,8 @@ pub struct CommonOpt { pub partitions: Option, /// Batch size when reading CSV or Parquet files - #[structopt(short = "s", long = "batch-size", default_value = "8192")] - pub batch_size: usize, + #[structopt(short = "s", long = "batch-size")] + pub batch_size: Option, /// The memory pool type to use, should be one of "fair" or "greedy" #[structopt(long = "mem-pool-type", default_value = "fair")] @@ -65,21 +65,25 @@ pub struct CommonOpt { impl CommonOpt { /// Return an appropriately configured `SessionConfig` - pub fn config(&self) -> SessionConfig { - self.update_config(SessionConfig::new()) + pub fn config(&self) -> Result { + SessionConfig::from_env().map(|config| self.update_config(config)) } /// Modify the existing config appropriately - pub fn update_config(&self, config: SessionConfig) -> SessionConfig { - let mut config = config - .with_target_partitions( - self.partitions.unwrap_or(get_available_parallelism()), - ) - .with_batch_size(self.batch_size); + pub fn update_config(&self, mut config: SessionConfig) -> SessionConfig { + if let Some(batch_size) = self.batch_size { + config = config.with_batch_size(batch_size); + } + + if let Some(partitions) = self.partitions { + config = config.with_target_partitions(partitions); + } + if let Some(sort_spill_reservation_bytes) = self.sort_spill_reservation_bytes { config = config.with_sort_spill_reservation_bytes(sort_spill_reservation_bytes); } + config } @@ -106,7 +110,7 @@ impl CommonOpt { }; rt_builder = rt_builder .with_memory_pool(pool) - .with_disk_manager(DiskManagerConfig::NewOs); + .with_disk_manager_builder(DiskManagerBuilder::default()); } Ok(rt_builder) } @@ -118,15 +122,14 @@ fn parse_memory_limit(limit: &str) -> Result { let (number, unit) = limit.split_at(limit.len() - 1); let number: f64 = number .parse() - .map_err(|_| format!("Failed to parse number from memory limit '{}'", limit))?; + .map_err(|_| format!("Failed to parse number from memory limit '{limit}'"))?; match unit { "K" => Ok((number * 1024.0) as usize), "M" => Ok((number * 1024.0 * 1024.0) as usize), "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize), _ => Err(format!( - "Unsupported unit '{}' in memory limit '{}'", - unit, limit + "Unsupported unit '{unit}' in memory limit '{limit}'" )), } } diff --git a/benchmarks/src/util/run.rs b/benchmarks/src/util/run.rs index 13969f4d39497..764ea648ff725 100644 --- a/benchmarks/src/util/run.rs +++ b/benchmarks/src/util/run.rs @@ -90,8 +90,13 @@ pub struct BenchQuery { iterations: Vec, #[serde(serialize_with = "serialize_start_time")] start_time: SystemTime, + success: bool, +} +/// Internal representation of a single benchmark query iteration result. +pub struct QueryResult { + pub elapsed: Duration, + pub row_count: usize, } - /// collects benchmark run data and then serializes it at the end pub struct BenchmarkRun { context: RunContext, @@ -120,6 +125,7 @@ impl BenchmarkRun { query: id.to_owned(), iterations: vec![], start_time: SystemTime::now(), + success: true, }); if let Some(c) = self.current_case.as_mut() { *c += 1; @@ -138,6 +144,28 @@ impl BenchmarkRun { } } + /// Print the names of failed queries, if any + pub fn maybe_print_failures(&self) { + let failed_queries: Vec<&str> = self + .queries + .iter() + .filter_map(|q| (!q.success).then_some(q.query.as_str())) + .collect(); + + if !failed_queries.is_empty() { + println!("Failed Queries: {}", failed_queries.join(", ")); + } + } + + /// Mark current query + pub fn mark_failed(&mut self) { + if let Some(idx) = self.current_case { + self.queries[idx].success = false; + } else { + unreachable!("Cannot mark failure: no current case"); + } + } + /// Stringify data into formatted json pub fn to_json(&self) -> String { let mut output = HashMap::<&str, Value>::new(); diff --git a/ci/scripts/license_header.sh b/ci/scripts/license_header.sh new file mode 100755 index 0000000000000..5345728f9cdf0 --- /dev/null +++ b/ci/scripts/license_header.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Check Apache license header +set -ex +hawkeye check --config licenserc.toml diff --git a/ci/scripts/rust_clippy.sh b/ci/scripts/rust_clippy.sh index 8118ecc577007..1557bd56eab4a 100755 --- a/ci/scripts/rust_clippy.sh +++ b/ci/scripts/rust_clippy.sh @@ -18,4 +18,4 @@ # under the License. set -ex -cargo clippy --all-targets --workspace --features avro,pyarrow,integration-tests -- -D warnings +cargo clippy --all-targets --workspace --features avro,pyarrow,integration-tests,extended_tests -- -D warnings diff --git a/clippy.toml b/clippy.toml index 114e3bfceb272..ea3609b574c06 100644 --- a/clippy.toml +++ b/clippy.toml @@ -9,4 +9,14 @@ disallowed-types = [ # Lowering the threshold to help prevent stack overflows (default is 16384) # See: https://rust-lang.github.io/rust-clippy/master/index.html#/large_futures -future-size-threshold = 10000 \ No newline at end of file +future-size-threshold = 10000 + +# Be more aware of large error variants which can impact the "happy path" due +# to large stack footprint when considering async state machines (default is 128). +# +# Value of 70 picked arbitrarily as something less than 100. +# +# See: +# - https://github.com/apache/datafusion/issues/16652 +# - https://rust-lang.github.io/rust-clippy/master/index.html#result_large_err +large-error-threshold = 70 diff --git a/datafusion-cli/CONTRIBUTING.md b/datafusion-cli/CONTRIBUTING.md index 4b464dffc57ce..8be656ec4ee34 100644 --- a/datafusion-cli/CONTRIBUTING.md +++ b/datafusion-cli/CONTRIBUTING.md @@ -21,55 +21,40 @@ ## Running Tests -Tests can be run using `cargo` +First check out test files with ```shell -cargo test +git submodule update --init ``` -## Running Storage Integration Tests - -By default, storage integration tests are not run. To run them you will need to set `TEST_STORAGE_INTEGRATION=1` and -then provide the necessary configuration for that object store. +Then run all the tests with -For some of the tests, [snapshots](https://datafusion.apache.org/contributor-guide/testing.html#snapshot-testing) are used. +```shell +cargo test --all-targets +``` -### AWS +## Running Storage Integration Tests -To test the S3 integration against [Minio](https://github.com/minio/minio) +By default, storage integration tests are not run. These tests use the `testcontainers` crate to start up a local MinIO server using Docker on port 9000. -First start up a container with Minio and load test files. +To run them you will need to set `TEST_STORAGE_INTEGRATION`: ```shell -docker run -d \ - --name datafusion-test-minio \ - -p 9000:9000 \ - -e MINIO_ROOT_USER=TEST-DataFusionLogin \ - -e MINIO_ROOT_PASSWORD=TEST-DataFusionPassword \ - -v $(pwd)/../datafusion/core/tests/data:/source \ - quay.io/minio/minio server /data - -docker exec datafusion-test-minio /bin/sh -c "\ - mc ready local - mc alias set localminio http://localhost:9000 TEST-DataFusionLogin TEST-DataFusionPassword && \ - mc mb localminio/data && \ - mc cp -r /source/* localminio/data" +TEST_STORAGE_INTEGRATION=1 cargo test ``` -Setup environment +For some of the tests, [snapshots](https://datafusion.apache.org/contributor-guide/testing.html#snapshot-testing) are used. -```shell -export TEST_STORAGE_INTEGRATION=1 -export AWS_ACCESS_KEY_ID=TEST-DataFusionLogin -export AWS_SECRET_ACCESS_KEY=TEST-DataFusionPassword -export AWS_ENDPOINT=http://127.0.0.1:9000 -export AWS_ALLOW_HTTP=true -``` +### AWS -Note that `AWS_ENDPOINT` is set without slash at the end. +S3 integration is tested against [Minio](https://github.com/minio/minio) with [TestContainers](https://github.com/testcontainers/testcontainers-rs) +This requires Docker to be running on your machine and port 9000 to be free. -Run tests +If you see an error mentioning "failed to load IMDS session token" such as -```shell -cargo test -``` +> ---- object_storage::tests::s3_object_store_builder_resolves_region_when_none_provided stdout ---- +> Error: ObjectStore(Generic { store: "S3", source: "Error getting credentials from provider: an error occurred while loading credentials: failed to load IMDS session token" }) + +You may need to disable trying to fetch S3 credentials from the environment using the `AWS_EC2_METADATA_DISABLED`, for example: + +> $ AWS_EC2_METADATA_DISABLED=true TEST_STORAGE_INTEGRATION=1 cargo test diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index c70e3fc1caec5..d186cd711945d 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -37,37 +37,41 @@ backtrace = ["datafusion/backtrace"] [dependencies] arrow = { workspace = true } async-trait = { workspace = true } -aws-config = "1.6.1" -aws-credential-types = "1.2.0" -clap = { version = "4.5.34", features = ["derive", "cargo"] } +aws-config = "1.8.7" +aws-credential-types = "1.2.7" +chrono = { workspace = true } +clap = { version = "4.5.47", features = ["derive", "cargo"] } datafusion = { workspace = true, features = [ "avro", + "compression", "crypto_expressions", "datetime_expressions", "encoding_expressions", "nested_expressions", "parquet", + "parquet_encryption", "recursive_protection", "regex_expressions", + "sql", "unicode_expressions", - "compression", ] } dirs = "6.0.0" env_logger = { workspace = true } futures = { workspace = true } +log = { workspace = true } mimalloc = { version = "0.1", default-features = false } object_store = { workspace = true, features = ["aws", "gcp", "http"] } parking_lot = { workspace = true } parquet = { workspace = true, default-features = false } regex = { workspace = true } -rustyline = "15.0" +rustyline = "17.0" tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } url = { workspace = true } [dev-dependencies] -assert_cmd = "2.0" ctor = { workspace = true } insta = { workspace = true } insta-cmd = "0.6.0" -predicates = "3.0" rstest = { workspace = true } +testcontainers = { workspace = true } +testcontainers-modules = { workspace = true, features = ["minio"] } diff --git a/datafusion-cli/README.md b/datafusion-cli/README.md index ca796b525fa15..b34aa770374da 100644 --- a/datafusion-cli/README.md +++ b/datafusion-cli/README.md @@ -19,12 +19,15 @@ -# DataFusion Command-line Interface +# Apache DataFusion Command-line Interface -[DataFusion](https://datafusion.apache.org/) is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. DataFusion CLI (`datafusion-cli`) is a small command line utility that runs SQL queries using the DataFusion engine. +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ + # Frequently Asked Questions ## Where can I find more information? diff --git a/datafusion-cli/examples/cli-session-context.rs b/datafusion-cli/examples/cli-session-context.rs index 1a8f15c8731b2..bd2dbb736781f 100644 --- a/datafusion-cli/examples/cli-session-context.rs +++ b/datafusion-cli/examples/cli-session-context.rs @@ -28,7 +28,9 @@ use datafusion::{ prelude::SessionContext, }; use datafusion_cli::{ - cli_context::CliSessionContext, exec::exec_from_repl, print_options::PrintOptions, + cli_context::CliSessionContext, exec::exec_from_repl, + object_storage::instrumented::InstrumentedObjectStoreRegistry, + print_options::PrintOptions, }; use object_store::ObjectStore; @@ -89,6 +91,7 @@ pub async fn main() { quiet: false, maxrows: datafusion_cli::print_options::MaxRows::Unlimited, color: true, + instrumented_registry: Arc::new(InstrumentedObjectStoreRegistry::new()), }; exec_from_repl(&my_ctx, &mut print_options).await.unwrap(); diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index ceb72dbc546bd..20d62eabc3901 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -200,6 +200,7 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider { table_url.scheme(), url, &state.default_table_options(), + false, ) .await?; state.runtime_env().register_object_store(url, store); @@ -229,6 +230,7 @@ pub fn substitute_tilde(cur: String) -> String { } #[cfg(test)] mod tests { + use std::{env, vec}; use super::*; @@ -284,6 +286,19 @@ mod tests { #[tokio::test] async fn query_s3_location_test() -> Result<()> { + let aws_envs = vec![ + "AWS_ENDPOINT", + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_ALLOW_HTTP", + ]; + for aws_env in aws_envs { + if env::var(aws_env).is_err() { + eprint!("aws envs not set, skipping s3 test"); + return Ok(()); + } + } + let bucket = "examples3bucket"; let location = format!("s3://{bucket}/file.parquet"); @@ -337,8 +352,7 @@ mod tests { #[cfg(not(target_os = "windows"))] #[test] fn test_substitute_tilde() { - use std::env; - use std::path::MAIN_SEPARATOR; + use std::{env, path::PathBuf}; let original_home = home_dir(); let test_home_path = if cfg!(windows) { "C:\\Users\\user" @@ -350,17 +364,16 @@ mod tests { test_home_path, ); let input = "~/Code/datafusion/benchmarks/data/tpch_sf1/part/part-0.parquet"; - let expected = format!( - "{}{}Code{}datafusion{}benchmarks{}data{}tpch_sf1{}part{}part-0.parquet", - test_home_path, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR - ); + let expected = PathBuf::from(test_home_path) + .join("Code") + .join("datafusion") + .join("benchmarks") + .join("data") + .join("tpch_sf1") + .join("part") + .join("part-0.parquet") + .to_string_lossy() + .to_string(); let actual = substitute_tilde(input.to_string()); assert_eq!(actual, expected); match original_home { diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs index fc7d1a2617cf6..48fb37e8a8880 100644 --- a/datafusion-cli/src/command.rs +++ b/datafusion-cli/src/command.rs @@ -26,9 +26,9 @@ use clap::ValueEnum; use datafusion::arrow::array::{ArrayRef, StringArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::common::exec_err; use datafusion::common::instant::Instant; -use datafusion::error::{DataFusionError, Result}; +use datafusion::common::{exec_datafusion_err, exec_err}; +use datafusion::error::Result; use std::fs::File; use std::io::BufReader; use std::str::FromStr; @@ -46,6 +46,7 @@ pub enum Command { SearchFunctions(String), QuietMode(Option), OutputFormat(Option), + ObjectStoreProfileMode(Option), } pub enum OutputFormat { @@ -64,22 +65,27 @@ impl Command { let command_batch = all_commands_info(); let schema = command_batch.schema(); let num_rows = command_batch.num_rows(); - print_options.print_batches(schema, &[command_batch], now, num_rows) + let task_ctx = ctx.task_ctx(); + let config = &task_ctx.session_config().options().format; + print_options.print_batches( + schema, + &[command_batch], + now, + num_rows, + config, + ) } Self::ListTables => { exec_and_print(ctx, print_options, "SHOW TABLES".into()).await } Self::DescribeTableStmt(name) => { - exec_and_print(ctx, print_options, format!("SHOW COLUMNS FROM {}", name)) + exec_and_print(ctx, print_options, format!("SHOW COLUMNS FROM {name}")) .await } Self::Include(filename) => { if let Some(filename) = filename { let file = File::open(filename).map_err(|e| { - DataFusionError::Execution(format!( - "Error opening {:?} {}", - filename, e - )) + exec_datafusion_err!("Error opening {filename:?} {e}") })?; exec_from_lines(ctx, &mut BufReader::new(file), print_options) .await?; @@ -108,7 +114,7 @@ impl Command { Self::SearchFunctions(function) => { if let Ok(func) = function.parse::() { let details = func.function_details()?; - println!("{}", details); + println!("{details}"); Ok(()) } else { exec_err!("{function} is not a supported function") @@ -117,6 +123,29 @@ impl Command { Self::OutputFormat(_) => exec_err!( "Unexpected change output format, this should be handled outside" ), + Self::ObjectStoreProfileMode(mode) => { + if let Some(mode) = mode { + let profile_mode = mode + .parse() + .map_err(|_| + exec_datafusion_err!("Failed to parse input: {mode}. Valid options are disabled, enabled") + )?; + print_options + .instrumented_registry + .set_instrument_mode(profile_mode); + println!( + "ObjectStore Profile mode set to {}", + print_options.instrumented_registry.instrument_mode() + ); + } else { + println!( + "ObjectStore Profile mode is {}", + print_options.instrumented_registry.instrument_mode() + ); + } + + Ok(()) + } } } @@ -135,11 +164,15 @@ impl Command { Self::OutputFormat(_) => { ("\\pset [NAME [VALUE]]", "set table output option\n(format)") } + Self::ObjectStoreProfileMode(_) => ( + "\\object_store_profiling (disabled|enabled)", + "print or set object store profile mode", + ), } } } -const ALL_COMMANDS: [Command; 9] = [ +const ALL_COMMANDS: [Command; 10] = [ Command::ListTables, Command::DescribeTableStmt(String::new()), Command::Quit, @@ -149,6 +182,7 @@ const ALL_COMMANDS: [Command; 9] = [ Command::SearchFunctions(String::new()), Command::QuietMode(None), Command::OutputFormat(None), + Command::ObjectStoreProfileMode(None), ]; fn all_commands_info() -> RecordBatch { @@ -199,6 +233,10 @@ impl FromStr for Command { Self::OutputFormat(Some(subcommand.to_string())) } ("pset", None) => Self::OutputFormat(None), + ("object_store_profiling", Some(mode)) => { + Self::ObjectStoreProfileMode(Some(mode.to_string())) + } + ("object_store_profiling", None) => Self::ObjectStoreProfileMode(None), _ => return Err(()), }) } @@ -239,3 +277,53 @@ impl OutputFormat { } } } + +#[cfg(test)] +mod tests { + use datafusion::prelude::SessionContext; + + use crate::{ + object_storage::instrumented::{ + InstrumentedObjectStoreMode, InstrumentedObjectStoreRegistry, + }, + print_options::MaxRows, + }; + + use super::*; + + #[tokio::test] + async fn command_execute_profile_mode() { + let ctx = SessionContext::new(); + + let mut print_options = PrintOptions { + format: PrintFormat::Automatic, + quiet: false, + maxrows: MaxRows::Unlimited, + color: true, + instrumented_registry: Arc::new(InstrumentedObjectStoreRegistry::new()), + }; + + let mut cmd: Command = "object_store_profiling" + .parse() + .expect("expected parse to succeed"); + assert!(cmd.execute(&ctx, &mut print_options).await.is_ok()); + assert_eq!( + print_options.instrumented_registry.instrument_mode(), + InstrumentedObjectStoreMode::default() + ); + + cmd = "object_store_profiling enabled" + .parse() + .expect("expected parse to succeed"); + assert!(cmd.execute(&ctx, &mut print_options).await.is_ok()); + assert_eq!( + print_options.instrumented_registry.instrument_mode(), + InstrumentedObjectStoreMode::Enabled + ); + + cmd = "object_store_profiling does_not_exist" + .parse() + .expect("expected parse to succeed"); + assert!(cmd.execute(&ctx, &mut print_options).await.is_err()); + } +} diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 0f4d70c1cca97..d079a88a6440e 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -26,28 +26,28 @@ use crate::{ object_storage::get_object_store, print_options::{MaxRows, PrintOptions}, }; -use futures::StreamExt; -use std::collections::HashMap; -use std::fs::File; -use std::io::prelude::*; -use std::io::BufReader; - use datafusion::common::instant::Instant; use datafusion::common::{plan_datafusion_err, plan_err}; use datafusion::config::ConfigFileType; use datafusion::datasource::listing::ListingTableUrl; use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::memory_pool::MemoryConsumer; use datafusion::logical_expr::{DdlStatement, LogicalPlan}; use datafusion::physical_plan::execution_plan::EmissionType; +use datafusion::physical_plan::spill::get_record_batch_memory_size; use datafusion::physical_plan::{execute_stream, ExecutionPlanProperties}; use datafusion::sql::parser::{DFParser, Statement}; -use datafusion::sql::sqlparser::dialect::dialect_from_str; - -use datafusion::execution::memory_pool::MemoryConsumer; -use datafusion::physical_plan::spill::get_record_batch_memory_size; use datafusion::sql::sqlparser; +use datafusion::sql::sqlparser::dialect::dialect_from_str; +use futures::StreamExt; +use log::warn; +use object_store::Error::Generic; use rustyline::error::ReadlineError; use rustyline::Editor; +use std::collections::HashMap; +use std::fs::File; +use std::io::prelude::*; +use std::io::BufReader; use tokio::signal; /// run and execute SQL statements and commands, against a context with the given print options @@ -200,7 +200,7 @@ pub async fn exec_from_repl( break; } Err(err) => { - eprintln!("Unknown error happened {:?}", err); + eprintln!("Unknown error happened {err:?}"); break; } } @@ -214,9 +214,9 @@ pub(super) async fn exec_and_print( print_options: &PrintOptions, sql: String, ) -> Result<()> { - let now = Instant::now(); let task_ctx = ctx.task_ctx(); - let dialect = &task_ctx.session_config().options().sql_parser.dialect; + let options = task_ctx.session_config().options(); + let dialect = &options.sql_parser.dialect; let dialect = dialect_from_str(dialect).ok_or_else(|| { plan_datafusion_err!( "Unsupported SQL dialect: {dialect}. Available dialects: \ @@ -227,14 +227,43 @@ pub(super) async fn exec_and_print( let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { - let adjusted = - AdjustedPrintOptions::new(print_options.clone()).with_statement(&statement); + StatementExecutor::new(statement) + .execute(ctx, print_options) + .await?; + } - let plan = create_plan(ctx, statement).await?; - let adjusted = adjusted.with_plan(&plan); + Ok(()) +} - let df = ctx.execute_logical_plan(plan).await?; +/// Executor for SQL statements, including special handling for S3 region detection retry logic +struct StatementExecutor { + statement: Statement, + statement_for_retry: Option, +} + +impl StatementExecutor { + fn new(statement: Statement) -> Self { + let statement_for_retry = matches!(statement, Statement::CreateExternalTable(_)) + .then(|| statement.clone()); + + Self { + statement, + statement_for_retry, + } + } + + async fn execute( + self, + ctx: &dyn CliSessionContext, + print_options: &PrintOptions, + ) -> Result<()> { + let now = Instant::now(); + let (df, adjusted) = self + .create_and_execute_logical_plan(ctx, print_options) + .await?; let physical_plan = df.create_physical_plan().await?; + let task_ctx = ctx.task_ctx(); + let options = task_ctx.session_config().options(); // Track memory usage for the query result if it's bounded let mut reservation = @@ -250,7 +279,9 @@ pub(super) async fn exec_and_print( // As the input stream comes, we can generate results. // However, memory safety is not guaranteed. let stream = execute_stream(physical_plan, task_ctx.clone())?; - print_options.print_stream(stream, now).await?; + print_options + .print_stream(stream, now, &options.format) + .await?; } else { // Bounded stream; collected results size is limited by the maxrows option let schema = physical_plan.schema(); @@ -273,14 +304,47 @@ pub(super) async fn exec_and_print( } row_count += curr_num_rows; } - adjusted - .into_inner() - .print_batches(schema, &results, now, row_count)?; + adjusted.into_inner().print_batches( + schema, + &results, + now, + row_count, + &options.format, + )?; reservation.free(); } + + Ok(()) } - Ok(()) + async fn create_and_execute_logical_plan( + mut self, + ctx: &dyn CliSessionContext, + print_options: &PrintOptions, + ) -> Result<(datafusion::dataframe::DataFrame, AdjustedPrintOptions)> { + let adjusted = AdjustedPrintOptions::new(print_options.clone()) + .with_statement(&self.statement); + + let plan = create_plan(ctx, self.statement, false).await?; + let adjusted = adjusted.with_plan(&plan); + + let df = match ctx.execute_logical_plan(plan).await { + Ok(df) => Ok(df), + Err(DataFusionError::ObjectStore(err)) + if matches!(err.as_ref(), Generic { store, source: _ } if "S3".eq_ignore_ascii_case(store)) + && self.statement_for_retry.is_some() => + { + warn!("S3 region is incorrect, auto-detecting the correct region (this may be slow). Consider updating your region configuration."); + let plan = + create_plan(ctx, self.statement_for_retry.take().unwrap(), true) + .await?; + ctx.execute_logical_plan(plan).await + } + Err(e) => Err(e), + }?; + + Ok((df, adjusted)) + } } /// Track adjustments to the print options based on the plan / statement being executed @@ -341,6 +405,7 @@ fn config_file_type_from_str(ext: &str) -> Option { async fn create_plan( ctx: &dyn CliSessionContext, statement: Statement, + resolve_region: bool, ) -> Result { let mut plan = ctx.session_state().statement_to_plan(statement).await?; @@ -355,6 +420,7 @@ async fn create_plan( &cmd.location, &cmd.options, format, + resolve_region, ) .await?; } @@ -367,6 +433,7 @@ async fn create_plan( ©_to.output_url, ©_to.options, format, + false, ) .await?; } @@ -405,6 +472,7 @@ pub(crate) async fn register_object_store_and_config_extensions( location: &String, options: &HashMap, format: Option, + resolve_region: bool, ) -> Result<()> { // Parse the location URL to extract the scheme and other components let table_path = ListingTableUrl::parse(location)?; @@ -426,8 +494,14 @@ pub(crate) async fn register_object_store_and_config_extensions( table_options.alter_with_string_hash_map(options)?; // Retrieve the appropriate object store based on the scheme, URL, and modified table options - let store = - get_object_store(&ctx.session_state(), scheme, url, &table_options).await?; + let store = get_object_store( + &ctx.session_state(), + scheme, + url, + &table_options, + resolve_region, + ) + .await?; // Register the retrieved object store in the session context's runtime environment ctx.register_object_store(url, store); @@ -455,6 +529,7 @@ mod tests { &cmd.location, &cmd.options, format, + false, ) .await?; } else { @@ -481,6 +556,7 @@ mod tests { &cmd.output_url, &cmd.options, format, + false, ) .await?; } else { @@ -506,6 +582,19 @@ mod tests { } #[tokio::test] async fn copy_to_external_object_store_test() -> Result<()> { + let aws_envs = vec![ + "AWS_ENDPOINT", + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_ALLOW_HTTP", + ]; + for aws_env in aws_envs { + if std::env::var(aws_env).is_err() { + eprint!("aws envs not set, skipping s3 test"); + return Ok(()); + } + } + let locations = vec![ "s3://bucket/path/file.parquet", "oss://bucket/path/file.parquet", @@ -523,11 +612,11 @@ mod tests { ) })?; for location in locations { - let sql = format!("copy (values (1,2)) to '{}' STORED AS PARQUET;", location); + let sql = format!("copy (values (1,2)) to '{location}' STORED AS PARQUET;"); let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { //Should not fail - let mut plan = create_plan(&ctx, statement).await?; + let mut plan = create_plan(&ctx, statement, false).await?; if let LogicalPlan::Copy(copy_to) = &mut plan { assert_eq!(copy_to.output_url, location); assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string()); diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index 13d2d5fd3547b..3ec446c515836 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -22,8 +22,8 @@ use std::fs::File; use std::str::FromStr; use std::sync::Arc; -use arrow::array::{Int64Array, StringArray}; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::array::{Int64Array, StringArray, TimestampMillisecondArray, UInt64Array}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use datafusion::catalog::{Session, TableFunctionImpl}; @@ -31,6 +31,7 @@ use datafusion::common::{plan_err, Column}; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::TableProvider; use datafusion::error::Result; +use datafusion::execution::cache::cache_manager::CacheManager; use datafusion::logical_expr::Expr; use datafusion::physical_plan::ExecutionPlan; use datafusion::scalar::ScalarValue; @@ -205,7 +206,7 @@ pub fn display_all_functions() -> Result<()> { let array = StringArray::from( ALL_FUNCTIONS .iter() - .map(|f| format!("{}", f)) + .map(|f| format!("{f}")) .collect::>(), ); let schema = Schema::new(vec![Field::new("Function", DataType::Utf8, false)]); @@ -322,7 +323,7 @@ pub struct ParquetMetadataFunc {} impl TableFunctionImpl for ParquetMetadataFunc { fn call(&self, exprs: &[Expr]) -> Result> { let filename = match exprs.first() { - Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') + Some(Expr::Literal(ScalarValue::Utf8(Some(s)), _)) => s, // single quote: parquet_metadata('x.parquet') Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") _ => { return plan_err!( @@ -460,3 +461,121 @@ impl TableFunctionImpl for ParquetMetadataFunc { Ok(Arc::new(parquet_metadata)) } } + +/// METADATA_CACHE table function +#[derive(Debug)] +struct MetadataCacheTable { + schema: SchemaRef, + batch: RecordBatch, +} + +#[async_trait] +impl TableProvider for MetadataCacheTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow::datatypes::SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> datafusion::logical_expr::TableType { + datafusion::logical_expr::TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(MemorySourceConfig::try_new_exec( + &[vec![self.batch.clone()]], + TableProvider::schema(self), + projection.cloned(), + )?) + } +} + +#[derive(Debug)] +pub struct MetadataCacheFunc { + cache_manager: Arc, +} + +impl MetadataCacheFunc { + pub fn new(cache_manager: Arc) -> Self { + Self { cache_manager } + } +} + +impl TableFunctionImpl for MetadataCacheFunc { + fn call(&self, exprs: &[Expr]) -> Result> { + if !exprs.is_empty() { + return plan_err!("metadata_cache should have no arguments"); + } + + let schema = Arc::new(Schema::new(vec![ + Field::new("path", DataType::Utf8, false), + Field::new( + "file_modified", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("file_size_bytes", DataType::UInt64, false), + Field::new("e_tag", DataType::Utf8, true), + Field::new("version", DataType::Utf8, true), + Field::new("metadata_size_bytes", DataType::UInt64, false), + Field::new("hits", DataType::UInt64, false), + Field::new("extra", DataType::Utf8, true), + ])); + + // construct record batch from metadata + let mut path_arr = vec![]; + let mut file_modified_arr = vec![]; + let mut file_size_bytes_arr = vec![]; + let mut e_tag_arr = vec![]; + let mut version_arr = vec![]; + let mut metadata_size_bytes = vec![]; + let mut hits_arr = vec![]; + let mut extra_arr = vec![]; + + let cached_entries = self.cache_manager.get_file_metadata_cache().list_entries(); + + for (path, entry) in cached_entries { + path_arr.push(path.to_string()); + file_modified_arr + .push(Some(entry.object_meta.last_modified.timestamp_millis())); + file_size_bytes_arr.push(entry.object_meta.size); + e_tag_arr.push(entry.object_meta.e_tag); + version_arr.push(entry.object_meta.version); + metadata_size_bytes.push(entry.size_bytes as u64); + hits_arr.push(entry.hits as u64); + + let mut extra = entry + .extra + .iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>(); + extra.sort(); + extra_arr.push(extra.join(" ")); + } + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(path_arr)), + Arc::new(TimestampMillisecondArray::from(file_modified_arr)), + Arc::new(UInt64Array::from(file_size_bytes_arr)), + Arc::new(StringArray::from(e_tag_arr)), + Arc::new(StringArray::from(version_arr)), + Arc::new(UInt64Array::from(metadata_size_bytes)), + Arc::new(UInt64Array::from(hits_arr)), + Arc::new(StringArray::from(extra_arr)), + ], + )?; + + let metadata_cache = MetadataCacheTable { schema, batch }; + Ok(Arc::new(metadata_cache)) + } +} diff --git a/datafusion-cli/src/lib.rs b/datafusion-cli/src/lib.rs index 34fba6f79304b..f0b0bc23fd73d 100644 --- a/datafusion-cli/src/lib.rs +++ b/datafusion-cli/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] #![doc = include_str!("../README.md")] pub const DATAFUSION_CLI_VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 0b7a98f652018..3dbe839d3c9b3 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -17,17 +17,24 @@ use std::collections::HashMap; use std::env; +use std::num::NonZeroUsize; use std::path::Path; use std::process::ExitCode; use std::sync::{Arc, LazyLock}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionConfig; -use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, MemoryPool}; +use datafusion::execution::memory_pool::{ + FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, +}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::logical_expr::ExplainFormat; use datafusion::prelude::SessionContext; use datafusion_cli::catalog::DynamicObjectStoreCatalog; -use datafusion_cli::functions::ParquetMetadataFunc; +use datafusion_cli::functions::{MetadataCacheFunc, ParquetMetadataFunc}; +use datafusion_cli::object_storage::instrumented::{ + InstrumentedObjectStoreMode, InstrumentedObjectStoreRegistry, +}; use datafusion_cli::{ exec, pool_type::PoolType, @@ -39,6 +46,7 @@ use datafusion_cli::{ use clap::Parser; use datafusion::common::config_err; use datafusion::config::ConfigOptions; +use datafusion::execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use mimalloc::MiMalloc; #[global_allocator] @@ -116,6 +124,13 @@ struct Args { )] mem_pool_type: PoolType, + #[clap( + long, + help = "The number of top memory consumers to display when query fails due to memory exhaustion. To disable memory consumer tracking, set this value to 0", + default_value = "3" + )] + top_memory_consumers: usize, + #[clap( long, help = "The max number of rows to display for 'Table' format\n[possible values: numbers(0/10/...), inf(no limit)]", @@ -125,6 +140,21 @@ struct Args { #[clap(long, help = "Enables console syntax highlighting")] color: bool, + + #[clap( + short = 'd', + long, + help = "Available disk space for spilling queries (e.g. '10g'), default to None (uses DataFusion's default value of '100g')", + value_parser(extract_disk_limit) + )] + disk_limit: Option, + + #[clap( + long, + help = "Specify the default object_store_profiling mode, defaults to 'disabled'.\n[possible values: disabled, enabled]", + default_value_t = InstrumentedObjectStoreMode::Disabled + )] + object_store_profiling: InstrumentedObjectStoreMode, } #[tokio::main] @@ -144,7 +174,7 @@ async fn main_inner() -> Result<()> { let args = Args::parse(); if !args.quiet { - println!("DataFusion CLI v{}", DATAFUSION_CLI_VERSION); + println!("DataFusion CLI v{DATAFUSION_CLI_VERSION}"); } if let Some(ref path) = args.data_path { @@ -159,12 +189,39 @@ async fn main_inner() -> Result<()> { if let Some(memory_limit) = args.memory_limit { // set memory pool type let pool: Arc = match args.mem_pool_type { - PoolType::Fair => Arc::new(FairSpillPool::new(memory_limit)), - PoolType::Greedy => Arc::new(GreedyMemoryPool::new(memory_limit)), + PoolType::Fair if args.top_memory_consumers == 0 => { + Arc::new(FairSpillPool::new(memory_limit)) + } + PoolType::Fair => Arc::new(TrackConsumersPool::new( + FairSpillPool::new(memory_limit), + NonZeroUsize::new(args.top_memory_consumers).unwrap(), + )), + PoolType::Greedy if args.top_memory_consumers == 0 => { + Arc::new(GreedyMemoryPool::new(memory_limit)) + } + PoolType::Greedy => Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(memory_limit), + NonZeroUsize::new(args.top_memory_consumers).unwrap(), + )), }; + rt_builder = rt_builder.with_memory_pool(pool) } + // set disk limit + if let Some(disk_limit) = args.disk_limit { + let builder = DiskManagerBuilder::default() + .with_mode(DiskManagerMode::OsTmpDirectory) + .with_max_temp_directory_size(disk_limit.try_into().unwrap()); + rt_builder = rt_builder.with_disk_manager_builder(builder); + } + + let instrumented_registry = Arc::new( + InstrumentedObjectStoreRegistry::new() + .with_profile_mode(args.object_store_profiling), + ); + rt_builder = rt_builder.with_object_store_registry(instrumented_registry.clone()); + let runtime_env = rt_builder.build_arc()?; // enable dynamic file query @@ -179,11 +236,20 @@ async fn main_inner() -> Result<()> { // register `parquet_metadata` table function to get metadata from parquet files ctx.register_udtf("parquet_metadata", Arc::new(ParquetMetadataFunc {})); + // register `metadata_cache` table function to get the contents of the file metadata cache + ctx.register_udtf( + "metadata_cache", + Arc::new(MetadataCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + let mut print_options = PrintOptions { format: args.format, quiet: args.quiet, maxrows: args.maxrows, color: args.color, + instrumented_registry: Arc::clone(&instrumented_registry), }; let commands = args.command; @@ -240,7 +306,12 @@ fn get_session_config(args: &Args) -> Result { // use easier to understand "tree" mode by default // if the user hasn't specified an explain format in the environment if env::var_os("DATAFUSION_EXPLAIN_FORMAT").is_none() { - config_options.explain.format = String::from("tree"); + config_options.explain.format = ExplainFormat::Tree; + } + + // in the CLI, we want to show NULL values rather the empty strings + if env::var_os("DATAFUSION_FORMAT_NULL").is_none() { + config_options.format.null = String::from("NULL"); } let session_config = @@ -252,7 +323,7 @@ fn parse_valid_file(dir: &str) -> Result { if Path::new(dir).is_file() { Ok(dir.to_string()) } else { - Err(format!("Invalid file '{}'", dir)) + Err(format!("Invalid file '{dir}'")) } } @@ -260,14 +331,14 @@ fn parse_valid_data_dir(dir: &str) -> Result { if Path::new(dir).is_dir() { Ok(dir.to_string()) } else { - Err(format!("Invalid data directory '{}'", dir)) + Err(format!("Invalid data directory '{dir}'")) } } fn parse_batch_size(size: &str) -> Result { match size.parse::() { Ok(size) if size > 0 => Ok(size), - _ => Err(format!("Invalid batch size '{}'", size)), + _ => Err(format!("Invalid batch size '{size}'")), } } @@ -300,7 +371,7 @@ impl ByteUnit { } } -fn extract_memory_pool_size(size: &str) -> Result { +fn parse_size_string(size: &str, label: &str) -> Result { static BYTE_SUFFIXES: LazyLock> = LazyLock::new(|| { let mut m = HashMap::new(); @@ -322,29 +393,37 @@ fn extract_memory_pool_size(size: &str) -> Result { let lower = size.to_lowercase(); if let Some(caps) = SUFFIX_REGEX.captures(&lower) { let num_str = caps.get(1).unwrap().as_str(); - let num = num_str.parse::().map_err(|_| { - format!("Invalid numeric value in memory pool size '{}'", size) - })?; + let num = num_str + .parse::() + .map_err(|_| format!("Invalid numeric value in {label} '{size}'"))?; let suffix = caps.get(2).map(|m| m.as_str()).unwrap_or("b"); - let unit = &BYTE_SUFFIXES + let unit = BYTE_SUFFIXES .get(suffix) - .ok_or_else(|| format!("Invalid memory pool size '{}'", size))?; - let memory_pool_size = usize::try_from(unit.multiplier()) + .ok_or_else(|| format!("Invalid {label} '{size}'"))?; + let total_bytes = usize::try_from(unit.multiplier()) .ok() .and_then(|multiplier| num.checked_mul(multiplier)) - .ok_or_else(|| format!("Memory pool size '{}' is too large", size))?; + .ok_or_else(|| format!("{label} '{size}' is too large"))?; - Ok(memory_pool_size) + Ok(total_bytes) } else { - Err(format!("Invalid memory pool size '{}'", size)) + Err(format!("Invalid {label} '{size}'")) } } +pub fn extract_memory_pool_size(size: &str) -> Result { + parse_size_string(size, "memory pool size") +} + +pub fn extract_disk_limit(size: &str) -> Result { + parse_size_string(size, "disk limit") +} + #[cfg(test)] mod tests { use super::*; - use datafusion::common::test_util::batches_to_string; + use datafusion::{common::test_util::batches_to_string, prelude::ParquetReadOptions}; use insta::assert_snapshot; fn assert_conversion(input: &str, expected: Result) { @@ -459,4 +538,97 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_metadata_cache() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_udtf( + "metadata_cache", + Arc::new(MetadataCacheFunc::new( + ctx.task_ctx().runtime_env().cache_manager.clone(), + )), + ); + + ctx.register_parquet( + "alltypes_plain", + "../parquet-testing/data/alltypes_plain.parquet", + ParquetReadOptions::new(), + ) + .await?; + + ctx.register_parquet( + "alltypes_tiny_pages", + "../parquet-testing/data/alltypes_tiny_pages.parquet", + ParquetReadOptions::new(), + ) + .await?; + + ctx.register_parquet( + "lz4_raw_compressed_larger", + "../parquet-testing/data/lz4_raw_compressed_larger.parquet", + ParquetReadOptions::new(), + ) + .await?; + + ctx.sql("select * from alltypes_plain") + .await? + .collect() + .await?; + ctx.sql("select * from alltypes_tiny_pages") + .await? + .collect() + .await?; + ctx.sql("select * from lz4_raw_compressed_larger") + .await? + .collect() + .await?; + + // initial state + let sql = "SELECT split_part(path, '/', -1) as filename, file_size_bytes, metadata_size_bytes, hits, extra from metadata_cache() order by filename"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + + assert_snapshot!(batches_to_string(&rbs),@r" + +-----------------------------------+-----------------+---------------------+------+------------------+ + | filename | file_size_bytes | metadata_size_bytes | hits | extra | + +-----------------------------------+-----------------+---------------------+------+------------------+ + | alltypes_plain.parquet | 1851 | 10181 | 2 | page_index=false | + | alltypes_tiny_pages.parquet | 454233 | 881418 | 2 | page_index=true | + | lz4_raw_compressed_larger.parquet | 380836 | 2939 | 2 | page_index=false | + +-----------------------------------+-----------------+---------------------+------+------------------+ + "); + + // increase the number of hits + ctx.sql("select * from alltypes_plain") + .await? + .collect() + .await?; + ctx.sql("select * from alltypes_plain") + .await? + .collect() + .await?; + ctx.sql("select * from alltypes_plain") + .await? + .collect() + .await?; + ctx.sql("select * from lz4_raw_compressed_larger") + .await? + .collect() + .await?; + let sql = "select split_part(path, '/', -1) as filename, file_size_bytes, metadata_size_bytes, hits, extra from metadata_cache() order by filename"; + let df = ctx.sql(sql).await?; + let rbs = df.collect().await?; + + assert_snapshot!(batches_to_string(&rbs),@r" + +-----------------------------------+-----------------+---------------------+------+------------------+ + | filename | file_size_bytes | metadata_size_bytes | hits | extra | + +-----------------------------------+-----------------+---------------------+------+------------------+ + | alltypes_plain.parquet | 1851 | 10181 | 5 | page_index=false | + | alltypes_tiny_pages.parquet | 454233 | 881418 | 2 | page_index=true | + | lz4_raw_compressed_larger.parquet | 380836 | 2939 | 3 | page_index=false | + +-----------------------------------+-----------------+---------------------+------+------------------+ + "); + + Ok(()) + } } diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index c31310093ac6b..e6e6be42c7ad0 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -15,29 +15,55 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::fmt::{Debug, Display}; -use std::sync::Arc; - -use datafusion::common::config::{ - ConfigEntry, ConfigExtension, ConfigField, ExtensionOptions, TableOptions, Visit, -}; -use datafusion::common::{config_err, exec_datafusion_err, exec_err}; -use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::context::SessionState; +pub mod instrumented; use async_trait::async_trait; use aws_config::BehaviorVersion; -use aws_credential_types::provider::ProvideCredentials; -use object_store::aws::{AmazonS3Builder, AwsCredential}; -use object_store::gcp::GoogleCloudStorageBuilder; -use object_store::http::HttpBuilder; -use object_store::{ClientOptions, CredentialProvider, ObjectStore}; +use aws_credential_types::provider::{ + error::CredentialsError, ProvideCredentials, SharedCredentialsProvider, +}; +use datafusion::{ + common::{ + config::ConfigEntry, config::ConfigExtension, config::ConfigField, + config::ExtensionOptions, config::TableOptions, config::Visit, config_err, + exec_datafusion_err, exec_err, + }, + error::{DataFusionError, Result}, + execution::context::SessionState, +}; +use log::debug; +use object_store::{ + aws::{AmazonS3Builder, AmazonS3ConfigKey, AwsCredential}, + gcp::GoogleCloudStorageBuilder, + http::HttpBuilder, + ClientOptions, CredentialProvider, + Error::Generic, + ObjectStore, +}; +use std::{ + any::Any, + error::Error, + fmt::{Debug, Display}, + sync::Arc, +}; use url::Url; +#[cfg(not(test))] +use object_store::aws::resolve_bucket_region; + +// Provide a local mock when running tests so we don't make network calls +#[cfg(test)] +async fn resolve_bucket_region( + _bucket: &str, + _client_options: &ClientOptions, +) -> object_store::Result { + Ok("eu-central-1".to_string()) +} + pub async fn get_s3_object_store_builder( url: &Url, aws_options: &AwsOptions, + resolve_region: bool, ) -> Result { let AwsOptions { access_key_id, @@ -46,6 +72,7 @@ pub async fn get_s3_object_store_builder( region, endpoint, allow_http, + skip_signature, } = aws_options; let bucket_name = get_bucket_name(url)?; @@ -54,6 +81,7 @@ pub async fn get_s3_object_store_builder( if let (Some(access_key_id), Some(secret_access_key)) = (access_key_id, secret_access_key) { + debug!("Using explicitly provided S3 access_key_id and secret_access_key"); builder = builder .with_access_key_id(access_key_id) .with_secret_access_key(secret_access_key); @@ -62,29 +90,37 @@ pub async fn get_s3_object_store_builder( builder = builder.with_token(session_token); } } else { - let config = aws_config::defaults(BehaviorVersion::latest()).load().await; - if let Some(region) = config.region() { - builder = builder.with_region(region.to_string()); + debug!("Using AWS S3 SDK to determine credentials"); + let CredentialsFromConfig { + region, + credentials, + } = CredentialsFromConfig::try_new().await?; + if let Some(region) = region { + builder = builder.with_region(region); + } + if let Some(credentials) = credentials { + let credentials = Arc::new(S3CredentialProvider { credentials }); + builder = builder.with_credentials(credentials); + } else { + debug!("No credentials found, defaulting to skip signature "); + builder = builder.with_skip_signature(true); } - - let credentials = config - .credentials_provider() - .ok_or_else(|| { - DataFusionError::ObjectStore(object_store::Error::Generic { - store: "S3", - source: "Failed to get S3 credentials from the environment".into(), - }) - })? - .clone(); - - let credentials = Arc::new(S3CredentialProvider { credentials }); - builder = builder.with_credentials(credentials); } if let Some(region) = region { builder = builder.with_region(region); } + // If the region is not set or auto_detect_region is true, resolve the region. + if builder + .get_config_value(&AmazonS3ConfigKey::Region) + .is_none() + || resolve_region + { + let region = resolve_bucket_region(bucket_name, &ClientOptions::new()).await?; + builder = builder.with_region(region); + } + if let Some(endpoint) = endpoint { // Make a nicer error if the user hasn't allowed http and the endpoint // is http as the default message is "URL scheme is not allowed" @@ -105,9 +141,71 @@ pub async fn get_s3_object_store_builder( builder = builder.with_allow_http(*allow_http); } + if let Some(skip_signature) = skip_signature { + builder = builder.with_skip_signature(*skip_signature); + } + Ok(builder) } +/// Credentials from the AWS SDK +struct CredentialsFromConfig { + region: Option, + credentials: Option, +} + +impl CredentialsFromConfig { + /// Attempt find AWS S3 credentials via the AWS SDK + pub async fn try_new() -> Result { + let config = aws_config::defaults(BehaviorVersion::latest()).load().await; + let region = config.region().map(|r| r.to_string()); + + let credentials = config + .credentials_provider() + .ok_or_else(|| { + DataFusionError::ObjectStore(Box::new(Generic { + store: "S3", + source: "Failed to get S3 credentials aws_config".into(), + })) + })? + .clone(); + + // The credential provider is lazy, so it does not fetch credentials + // until they are needed. To ensure that the credentials are valid, + // we can call `provide_credentials` here. + let credentials = match credentials.provide_credentials().await { + Ok(_) => Some(credentials), + Err(CredentialsError::CredentialsNotLoaded(_)) => { + debug!("Could not use AWS SDK to get credentials"); + None + } + // other errors like `CredentialsError::InvalidConfiguration` + // should be returned to the user so they can be fixed + Err(e) => { + // Pass back underlying error to the user, including underlying source + let source_message = if let Some(source) = e.source() { + format!(": {source}") + } else { + String::new() + }; + + let message = format!( + "Error getting credentials from provider: {e}{source_message}", + ); + + return Err(DataFusionError::ObjectStore(Box::new(Generic { + store: "S3", + source: message.into(), + }))); + } + }; + Ok(Self { + region, + credentials, + }) + } +} + #[derive(Debug)] struct S3CredentialProvider { credentials: aws_credential_types::provider::SharedCredentialsProvider, @@ -118,12 +216,14 @@ impl CredentialProvider for S3CredentialProvider { type Credential = AwsCredential; async fn get_credential(&self) -> object_store::Result> { - let creds = self.credentials.provide_credentials().await.map_err(|e| { - object_store::Error::Generic { - store: "S3", - source: Box::new(e), - } - })?; + let creds = + self.credentials + .provide_credentials() + .await + .map_err(|e| Generic { + store: "S3", + source: Box::new(e), + })?; Ok(Arc::new(AwsCredential { key_id: creds.access_key_id().to_string(), secret_key: creds.secret_access_key().to_string(), @@ -197,10 +297,7 @@ pub fn get_gcs_object_store_builder( fn get_bucket_name(url: &Url) -> Result<&str> { url.host_str().ok_or_else(|| { - DataFusionError::Execution(format!( - "Not able to parse bucket name from url: {}", - url.as_str() - )) + exec_datafusion_err!("Not able to parse bucket name from url: {}", url.as_str()) }) } @@ -219,6 +316,11 @@ pub struct AwsOptions { pub endpoint: Option, /// Allow HTTP (otherwise will always use https) pub allow_http: Option, + /// Do not fetch credentials and do not sign requests + /// + /// This can be useful when interacting with public S3 buckets that deny + /// authorized requests + pub skip_signature: Option, } impl ExtensionOptions for AwsOptions { @@ -256,6 +358,9 @@ impl ExtensionOptions for AwsOptions { "allow_http" => { self.allow_http.set(rem, value)?; } + "skip_signature" | "nosign" => { + self.skip_signature.set(rem, value)?; + } _ => { return config_err!("Config value \"{}\" not found on AwsOptions", rem); } @@ -397,6 +502,7 @@ pub(crate) async fn get_object_store( scheme: &str, url: &Url, table_options: &TableOptions, + resolve_region: bool, ) -> Result, DataFusionError> { let store: Arc = match scheme { "s3" => { @@ -405,7 +511,8 @@ pub(crate) async fn get_object_store( "Given table options incompatible with the 's3' scheme" ); }; - let builder = get_s3_object_store_builder(url, options).await?; + let builder = + get_s3_object_store_builder(url, options, resolve_region).await?; Arc::new(builder.build()?) } "oss" => { @@ -461,7 +568,6 @@ mod tests { use super::*; - use datafusion::common::plan_err; use datafusion::{ datasource::listing::ListingTableUrl, logical_expr::{DdlStatement, LogicalPlan}, @@ -470,6 +576,72 @@ mod tests { use object_store::{aws::AmazonS3ConfigKey, gcp::GoogleConfigKey}; + #[tokio::test] + async fn s3_object_store_builder_default() -> Result<()> { + if let Err(DataFusionError::Execution(e)) = check_aws_envs().await { + // Skip test if AWS envs are not set + eprintln!("{e}"); + return Ok(()); + } + + let location = "s3://bucket/path/FAKE/file.parquet"; + // Set it to a non-existent file to avoid reading the default configuration file + std::env::set_var("AWS_CONFIG_FILE", "data/aws.config"); + std::env::set_var("AWS_SHARED_CREDENTIALS_FILE", "data/aws.credentials"); + + // No options + let table_url = ListingTableUrl::parse(location)?; + let scheme = table_url.scheme(); + let sql = + format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'"); + + let ctx = SessionContext::new(); + ctx.register_table_options_extension_from_scheme(scheme); + let table_options = get_table_options(&ctx, &sql).await; + let aws_options = table_options.extensions.get::().unwrap(); + let builder = + get_s3_object_store_builder(table_url.as_ref(), aws_options, false).await?; + + // If the environment variables are set (as they are in CI) use them + let expected_access_key_id = std::env::var("AWS_ACCESS_KEY_ID").ok(); + let expected_secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY").ok(); + let expected_region = Some( + std::env::var("AWS_REGION").unwrap_or_else(|_| "eu-central-1".to_string()), + ); + let expected_endpoint = std::env::var("AWS_ENDPOINT").ok(); + + // get the actual configuration information, then assert_eq! + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::AccessKeyId), + expected_access_key_id + ); + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::SecretAccessKey), + expected_secret_access_key + ); + // Default is to skip signature when no credentials are provided + let expected_skip_signature = + if expected_access_key_id.is_none() && expected_secret_access_key.is_none() { + Some(String::from("true")) + } else { + Some(String::from("false")) + }; + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Region), + expected_region + ); + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Endpoint), + expected_endpoint + ); + assert_eq!(builder.get_config_value(&AmazonS3ConfigKey::Token), None); + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::SkipSignature), + expected_skip_signature + ); + Ok(()) + } + #[tokio::test] async fn s3_object_store_builder() -> Result<()> { // "fake" is uppercase to ensure the values are not lowercased when parsed @@ -493,29 +665,27 @@ mod tests { ); let ctx = SessionContext::new(); - let mut plan = ctx.state().create_logical_plan(&sql).await?; - - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let aws_options = table_options.extensions.get::().unwrap(); - let builder = - get_s3_object_store_builder(table_url.as_ref(), aws_options).await?; - // get the actual configuration information, then assert_eq! - let config = [ - (AmazonS3ConfigKey::AccessKeyId, access_key_id), - (AmazonS3ConfigKey::SecretAccessKey, secret_access_key), - (AmazonS3ConfigKey::Region, region), - (AmazonS3ConfigKey::Endpoint, endpoint), - (AmazonS3ConfigKey::Token, session_token), - ]; - for (key, value) in config { - assert_eq!(value, builder.get_config_value(&key).unwrap()); - } - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); + ctx.register_table_options_extension_from_scheme(scheme); + let table_options = get_table_options(&ctx, &sql).await; + let aws_options = table_options.extensions.get::().unwrap(); + let builder = + get_s3_object_store_builder(table_url.as_ref(), aws_options, false).await?; + // get the actual configuration information, then assert_eq! + let config = [ + (AmazonS3ConfigKey::AccessKeyId, access_key_id), + (AmazonS3ConfigKey::SecretAccessKey, secret_access_key), + (AmazonS3ConfigKey::Region, region), + (AmazonS3ConfigKey::Endpoint, endpoint), + (AmazonS3ConfigKey::Token, session_token), + ]; + for (key, value) in config { + assert_eq!(value, builder.get_config_value(&key).unwrap()); } + // Should not skip signature when credentials are provided + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::SkipSignature), + Some("false".into()) + ); Ok(()) } @@ -538,21 +708,15 @@ mod tests { ); let ctx = SessionContext::new(); - let mut plan = ctx.state().create_logical_plan(&sql).await?; - - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let aws_options = table_options.extensions.get::().unwrap(); - let err = get_s3_object_store_builder(table_url.as_ref(), aws_options) - .await - .unwrap_err(); + ctx.register_table_options_extension_from_scheme(scheme); - assert_eq!(err.to_string().lines().next().unwrap_or_default(), "Invalid or Unsupported Configuration: Invalid endpoint: http://endpoint33. HTTP is not allowed for S3 endpoints. To allow HTTP, set 'aws.allow_http' to true"); - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); - } + let table_options = get_table_options(&ctx, &sql).await; + let aws_options = table_options.extensions.get::().unwrap(); + let err = get_s3_object_store_builder(table_url.as_ref(), aws_options, false) + .await + .unwrap_err(); + + assert_eq!(err.to_string().lines().next().unwrap_or_default(), "Invalid or Unsupported Configuration: Invalid endpoint: http://endpoint33. HTTP is not allowed for S3 endpoints. To allow HTTP, set 'aws.allow_http' to true"); // Now add `allow_http` to the options and check if it works let sql = format!( @@ -563,19 +727,72 @@ mod tests { 'aws.allow_http' 'true'\ ) LOCATION '{location}'" ); + let table_options = get_table_options(&ctx, &sql).await; - let mut plan = ctx.state().create_logical_plan(&sql).await?; + let aws_options = table_options.extensions.get::().unwrap(); + // ensure this isn't an error + get_s3_object_store_builder(table_url.as_ref(), aws_options, false).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let aws_options = table_options.extensions.get::().unwrap(); - // ensure this isn't an error - get_s3_object_store_builder(table_url.as_ref(), aws_options).await?; - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); + Ok(()) + } + + #[tokio::test] + async fn s3_object_store_builder_resolves_region_when_none_provided() -> Result<()> { + if let Err(DataFusionError::Execution(e)) = check_aws_envs().await { + // Skip test if AWS envs are not set + eprintln!("{e}"); + return Ok(()); } + let expected_region = "eu-central-1"; + let location = "s3://test-bucket/path/file.parquet"; + // Set it to a non-existent file to avoid reading the default configuration file + std::env::set_var("AWS_CONFIG_FILE", "data/aws.config"); + + let table_url = ListingTableUrl::parse(location)?; + let aws_options = AwsOptions { + region: None, // No region specified - should auto-detect + ..Default::default() + }; + + let builder = + get_s3_object_store_builder(table_url.as_ref(), &aws_options, false).await?; + + // Verify that the region was auto-detected in test environment + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Region), + Some(expected_region.to_string()) + ); + + Ok(()) + } + + #[tokio::test] + async fn s3_object_store_builder_overrides_region_when_resolve_region_enabled( + ) -> Result<()> { + if let Err(DataFusionError::Execution(e)) = check_aws_envs().await { + // Skip test if AWS envs are not set + eprintln!("{e}"); + return Ok(()); + } + + let original_region = "us-east-1"; + let expected_region = "eu-central-1"; // This should be the auto-detected region + let location = "s3://test-bucket/path/file.parquet"; + + let table_url = ListingTableUrl::parse(location)?; + let aws_options = AwsOptions { + region: Some(original_region.to_string()), // Explicit region provided + ..Default::default() + }; + + let builder = + get_s3_object_store_builder(table_url.as_ref(), &aws_options, true).await?; + + // Verify that the region was overridden by auto-detection + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Region), + Some(expected_region.to_string()) + ); Ok(()) } @@ -592,25 +809,19 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.oss.endpoint' '{endpoint}') LOCATION '{location}'"); let ctx = SessionContext::new(); - let mut plan = ctx.state().create_logical_plan(&sql).await?; - - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let aws_options = table_options.extensions.get::().unwrap(); - let builder = get_oss_object_store_builder(table_url.as_ref(), aws_options)?; - // get the actual configuration information, then assert_eq! - let config = [ - (AmazonS3ConfigKey::AccessKeyId, access_key_id), - (AmazonS3ConfigKey::SecretAccessKey, secret_access_key), - (AmazonS3ConfigKey::Endpoint, endpoint), - ]; - for (key, value) in config { - assert_eq!(value, builder.get_config_value(&key).unwrap()); - } - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); + ctx.register_table_options_extension_from_scheme(scheme); + let table_options = get_table_options(&ctx, &sql).await; + + let aws_options = table_options.extensions.get::().unwrap(); + let builder = get_oss_object_store_builder(table_url.as_ref(), aws_options)?; + // get the actual configuration information, then assert_eq! + let config = [ + (AmazonS3ConfigKey::AccessKeyId, access_key_id), + (AmazonS3ConfigKey::SecretAccessKey, secret_access_key), + (AmazonS3ConfigKey::Endpoint, endpoint), + ]; + for (key, value) in config { + assert_eq!(value, builder.get_config_value(&key).unwrap()); } Ok(()) @@ -629,30 +840,55 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_path' '{service_account_path}', 'gcp.service_account_key' '{service_account_key}', 'gcp.application_credentials_path' '{application_credentials_path}') LOCATION '{location}'"); let ctx = SessionContext::new(); - let mut plan = ctx.state().create_logical_plan(&sql).await?; - - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let gcp_options = table_options.extensions.get::().unwrap(); - let builder = get_gcs_object_store_builder(table_url.as_ref(), gcp_options)?; - // get the actual configuration information, then assert_eq! - let config = [ - (GoogleConfigKey::ServiceAccount, service_account_path), - (GoogleConfigKey::ServiceAccountKey, service_account_key), - ( - GoogleConfigKey::ApplicationCredentials, - application_credentials_path, - ), - ]; - for (key, value) in config { - assert_eq!(value, builder.get_config_value(&key).unwrap()); - } - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); + ctx.register_table_options_extension_from_scheme(scheme); + let table_options = get_table_options(&ctx, &sql).await; + + let gcp_options = table_options.extensions.get::().unwrap(); + let builder = get_gcs_object_store_builder(table_url.as_ref(), gcp_options)?; + // get the actual configuration information, then assert_eq! + let config = [ + (GoogleConfigKey::ServiceAccount, service_account_path), + (GoogleConfigKey::ServiceAccountKey, service_account_key), + ( + GoogleConfigKey::ApplicationCredentials, + application_credentials_path, + ), + ]; + for (key, value) in config { + assert_eq!(value, builder.get_config_value(&key).unwrap()); } Ok(()) } + + /// Plans the `CREATE EXTERNAL TABLE` SQL statement and returns the + /// resulting resolved `CreateExternalTable` command. + async fn get_table_options(ctx: &SessionContext, sql: &str) -> TableOptions { + let mut plan = ctx.state().create_logical_plan(sql).await.unwrap(); + + let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan else { + panic!("plan is not a CreateExternalTable"); + }; + + let mut table_options = ctx.state().default_table_options(); + table_options + .alter_with_string_hash_map(&cmd.options) + .unwrap(); + table_options + } + + async fn check_aws_envs() -> Result<()> { + let aws_envs = [ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_REGION", + "AWS_ALLOW_HTTP", + ]; + for aws_env in aws_envs { + std::env::var(aws_env).map_err(|_| { + exec_datafusion_err!("aws envs not set, skipping s3 tests") + })?; + } + Ok(()) + } } diff --git a/datafusion-cli/src/object_storage/instrumented.rs b/datafusion-cli/src/object_storage/instrumented.rs new file mode 100644 index 0000000000000..9252e0688c35a --- /dev/null +++ b/datafusion-cli/src/object_storage/instrumented.rs @@ -0,0 +1,676 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + cmp, fmt, + ops::AddAssign, + str::FromStr, + sync::{ + atomic::{AtomicU8, Ordering}, + Arc, + }, + time::Duration, +}; + +use async_trait::async_trait; +use chrono::Utc; +use datafusion::{ + common::{instant::Instant, HashMap}, + error::DataFusionError, + execution::object_store::{DefaultObjectStoreRegistry, ObjectStoreRegistry}, +}; +use futures::stream::BoxStream; +use object_store::{ + path::Path, GetOptions, GetRange, GetResult, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, Result, +}; +use parking_lot::{Mutex, RwLock}; +use url::Url; + +/// The profiling mode to use for an [`InstrumentedObjectStore`] instance. Collecting profiling +/// data will have a small negative impact on both CPU and memory usage. Default is `Disabled` +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] +pub enum InstrumentedObjectStoreMode { + /// Disable collection of profiling data + #[default] + Disabled, + /// Enable collection of profiling data + Enabled, +} + +impl fmt::Display for InstrumentedObjectStoreMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl FromStr for InstrumentedObjectStoreMode { + type Err = DataFusionError; + + fn from_str(s: &str) -> std::result::Result { + match s.to_lowercase().as_str() { + "disabled" => Ok(Self::Disabled), + "enabled" => Ok(Self::Enabled), + _ => Err(DataFusionError::Execution(format!("Unrecognized mode {s}"))), + } + } +} + +impl From for InstrumentedObjectStoreMode { + fn from(value: u8) -> Self { + match value { + 1 => InstrumentedObjectStoreMode::Enabled, + _ => InstrumentedObjectStoreMode::Disabled, + } + } +} + +/// Wrapped [`ObjectStore`] instances that record information for reporting on the usage of the +/// inner [`ObjectStore`] +#[derive(Debug)] +pub struct InstrumentedObjectStore { + inner: Arc, + instrument_mode: AtomicU8, + requests: Mutex>, +} + +impl InstrumentedObjectStore { + /// Returns a new [`InstrumentedObjectStore`] that wraps the provided [`ObjectStore`] + fn new(object_store: Arc, instrument_mode: AtomicU8) -> Self { + Self { + inner: object_store, + instrument_mode, + requests: Mutex::new(Vec::new()), + } + } + + fn set_instrument_mode(&self, mode: InstrumentedObjectStoreMode) { + self.instrument_mode.store(mode as u8, Ordering::Relaxed) + } + + /// Returns all [`RequestDetails`] accumulated in this [`InstrumentedObjectStore`] and clears + /// the stored requests + pub fn take_requests(&self) -> Vec { + let mut req = self.requests.lock(); + + req.drain(..).collect() + } + + async fn instrumented_get_opts( + &self, + location: &Path, + options: GetOptions, + ) -> Result { + let timestamp = Utc::now(); + let range = options.range.clone(); + + let start = Instant::now(); + let ret = self.inner.get_opts(location, options).await?; + let elapsed = start.elapsed(); + + self.requests.lock().push(RequestDetails { + op: Operation::Get, + path: location.clone(), + timestamp, + duration: Some(elapsed), + size: Some((ret.range.end - ret.range.start) as usize), + range, + extra_display: None, + }); + + Ok(ret) + } +} + +impl fmt::Display for InstrumentedObjectStore { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mode: InstrumentedObjectStoreMode = + self.instrument_mode.load(Ordering::Relaxed).into(); + write!( + f, + "Instrumented Object Store: instrument_mode: {mode}, inner: {}", + self.inner + ) + } +} + +#[async_trait] +impl ObjectStore for InstrumentedObjectStore { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + self.inner.put_opts(location, payload, opts).await + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOptions, + ) -> Result> { + self.inner.put_multipart_opts(location, opts).await + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { + if self.instrument_mode.load(Ordering::Relaxed) + != InstrumentedObjectStoreMode::Disabled as u8 + { + return self.instrumented_get_opts(location, options).await; + } + + self.inner.get_opts(location, options).await + } + + async fn delete(&self, location: &Path) -> Result<()> { + self.inner.delete(location).await + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, Result> { + self.inner.list(prefix) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + self.inner.list_with_delimiter(prefix).await + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + self.inner.copy(from, to).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + self.inner.copy_if_not_exists(from, to).await + } + + async fn head(&self, location: &Path) -> Result { + self.inner.head(location).await + } +} + +/// Object store operation types tracked by [`InstrumentedObjectStore`] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum Operation { + _Copy, + _Delete, + Get, + _Head, + _List, + _Put, +} + +/// Holds profiling details about individual requests made through an [`InstrumentedObjectStore`] +#[derive(Debug)] +pub struct RequestDetails { + op: Operation, + path: Path, + timestamp: chrono::DateTime, + duration: Option, + size: Option, + range: Option, + extra_display: Option, +} + +impl fmt::Display for RequestDetails { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut output_parts = vec![format!( + "{} operation={:?}", + self.timestamp.to_rfc3339(), + self.op + )]; + + if let Some(d) = self.duration { + output_parts.push(format!("duration={:.6}s", d.as_secs_f32())); + } + if let Some(s) = self.size { + output_parts.push(format!("size={s}")); + } + if let Some(r) = &self.range { + output_parts.push(format!("range: {r}")); + } + output_parts.push(format!("path={}", self.path)); + + if let Some(ed) = &self.extra_display { + output_parts.push(ed.clone()); + } + + write!(f, "{}", output_parts.join(" ")) + } +} + +/// Summary statistics for an [`InstrumentedObjectStore`]'s [`RequestDetails`] +#[derive(Default)] +pub struct RequestSummary { + count: usize, + duration_stats: Option>, + size_stats: Option>, +} + +impl RequestSummary { + /// Generates a set of [RequestSummaries](RequestSummary) from the input [`RequestDetails`] + /// grouped by the input's [`Operation`] + pub fn summarize_by_operation( + requests: &[RequestDetails], + ) -> HashMap { + let mut summaries: HashMap = HashMap::new(); + for rd in requests { + match summaries.get_mut(&rd.op) { + Some(rs) => rs.push(rd), + None => { + let mut rs = RequestSummary::default(); + rs.push(rd); + summaries.insert(rd.op, rs); + } + } + } + + summaries + } + + fn push(&mut self, request: &RequestDetails) { + self.count += 1; + if let Some(dur) = request.duration { + self.duration_stats.get_or_insert_default().push(dur) + } + if let Some(size) = request.size { + self.size_stats.get_or_insert_default().push(size) + } + } +} + +impl fmt::Display for RequestSummary { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "count: {}", self.count)?; + + if let Some(dur_stats) = &self.duration_stats { + writeln!(f, "duration min: {:.6}s", dur_stats.min.as_secs_f32())?; + writeln!(f, "duration max: {:.6}s", dur_stats.max.as_secs_f32())?; + let avg = dur_stats.sum.as_secs_f32() / (self.count as f32); + writeln!(f, "duration avg: {:.6}s", avg)?; + } + + if let Some(size_stats) = &self.size_stats { + writeln!(f, "size min: {} B", size_stats.min)?; + writeln!(f, "size max: {} B", size_stats.max)?; + let avg = size_stats.sum / self.count; + writeln!(f, "size avg: {} B", avg)?; + writeln!(f, "size sum: {} B", size_stats.sum)?; + } + + Ok(()) + } +} + +struct Stats> { + min: T, + max: T, + sum: T, +} + +impl> Stats { + fn push(&mut self, val: T) { + self.min = cmp::min(val, self.min); + self.max = cmp::max(val, self.max); + self.sum += val; + } +} + +impl Default for Stats { + fn default() -> Self { + Self { + min: Duration::MAX, + max: Duration::ZERO, + sum: Duration::ZERO, + } + } +} + +impl Default for Stats { + fn default() -> Self { + Self { + min: usize::MAX, + max: usize::MIN, + sum: 0, + } + } +} + +/// Provides access to [`InstrumentedObjectStore`] instances that record requests for reporting +#[derive(Debug)] +pub struct InstrumentedObjectStoreRegistry { + inner: Arc, + instrument_mode: AtomicU8, + stores: RwLock>>, +} + +impl Default for InstrumentedObjectStoreRegistry { + fn default() -> Self { + Self::new() + } +} + +impl InstrumentedObjectStoreRegistry { + /// Returns a new [`InstrumentedObjectStoreRegistry`] that wraps the provided + /// [`ObjectStoreRegistry`] + pub fn new() -> Self { + Self { + inner: Arc::new(DefaultObjectStoreRegistry::new()), + instrument_mode: AtomicU8::new(InstrumentedObjectStoreMode::default() as u8), + stores: RwLock::new(Vec::new()), + } + } + + pub fn with_profile_mode(self, mode: InstrumentedObjectStoreMode) -> Self { + self.instrument_mode.store(mode as u8, Ordering::Relaxed); + self + } + + /// Provides access to all of the [`InstrumentedObjectStore`]s managed by this + /// [`InstrumentedObjectStoreRegistry`] + pub fn stores(&self) -> Vec> { + self.stores.read().clone() + } + + /// Returns the current [`InstrumentedObjectStoreMode`] for this + /// [`InstrumentedObjectStoreRegistry`] + pub fn instrument_mode(&self) -> InstrumentedObjectStoreMode { + self.instrument_mode.load(Ordering::Relaxed).into() + } + + /// Sets the [`InstrumentedObjectStoreMode`] for this [`InstrumentedObjectStoreRegistry`] + pub fn set_instrument_mode(&self, mode: InstrumentedObjectStoreMode) { + self.instrument_mode.store(mode as u8, Ordering::Relaxed); + for s in self.stores.read().iter() { + s.set_instrument_mode(mode) + } + } +} + +impl ObjectStoreRegistry for InstrumentedObjectStoreRegistry { + fn register_store( + &self, + url: &Url, + store: Arc, + ) -> Option> { + let mode = self.instrument_mode.load(Ordering::Relaxed); + let instrumented = + Arc::new(InstrumentedObjectStore::new(store, AtomicU8::new(mode))); + self.stores.write().push(Arc::clone(&instrumented)); + self.inner.register_store(url, instrumented) + } + + fn get_store(&self, url: &Url) -> datafusion::common::Result> { + self.inner.get_store(url) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn instrumented_mode() { + assert!(matches!( + InstrumentedObjectStoreMode::default(), + InstrumentedObjectStoreMode::Disabled + )); + + assert!(matches!( + "dIsABleD".parse().unwrap(), + InstrumentedObjectStoreMode::Disabled + )); + assert!(matches!( + "EnABlEd".parse().unwrap(), + InstrumentedObjectStoreMode::Enabled + )); + assert!("does_not_exist" + .parse::() + .is_err()); + + assert!(matches!(0.into(), InstrumentedObjectStoreMode::Disabled)); + assert!(matches!(1.into(), InstrumentedObjectStoreMode::Enabled)); + assert!(matches!(2.into(), InstrumentedObjectStoreMode::Disabled)); + } + + #[test] + fn instrumented_registry() { + let mut reg = InstrumentedObjectStoreRegistry::new(); + assert!(reg.stores().is_empty()); + assert_eq!( + reg.instrument_mode(), + InstrumentedObjectStoreMode::default() + ); + + reg = reg.with_profile_mode(InstrumentedObjectStoreMode::Enabled); + assert_eq!(reg.instrument_mode(), InstrumentedObjectStoreMode::Enabled); + + let store = object_store::memory::InMemory::new(); + let url = "mem://test".parse().unwrap(); + let registered = reg.register_store(&url, Arc::new(store)); + assert!(registered.is_none()); + + let fetched = reg.get_store(&url); + assert!(fetched.is_ok()); + assert_eq!(reg.stores().len(), 1); + } + + #[tokio::test] + async fn instrumented_store() { + let store = Arc::new(object_store::memory::InMemory::new()); + let mode = AtomicU8::new(InstrumentedObjectStoreMode::default() as u8); + let instrumented = InstrumentedObjectStore::new(store, mode); + + // Load the test store with some data we can read + let path = Path::from("test/data"); + let payload = PutPayload::from_static(b"test_data"); + instrumented.put(&path, payload).await.unwrap(); + + // By default no requests should be instrumented/stored + assert!(instrumented.requests.lock().is_empty()); + let _ = instrumented.get(&path).await.unwrap(); + assert!(instrumented.requests.lock().is_empty()); + + instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Enabled); + assert!(instrumented.requests.lock().is_empty()); + let _ = instrumented.get(&path).await.unwrap(); + assert_eq!(instrumented.requests.lock().len(), 1); + + let mut requests = instrumented.take_requests(); + assert_eq!(requests.len(), 1); + assert!(instrumented.requests.lock().is_empty()); + + let request = requests.pop().unwrap(); + assert_eq!(request.op, Operation::Get); + assert_eq!(request.path, path); + assert!(request.duration.is_some()); + assert_eq!(request.size, Some(9)); + assert_eq!(request.range, None); + assert!(request.extra_display.is_none()); + } + + #[test] + fn request_details() { + let rd = RequestDetails { + op: Operation::Get, + path: Path::from("test"), + timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), + duration: Some(Duration::new(5, 0)), + size: Some(10), + range: Some((..10).into()), + extra_display: Some(String::from("extra info")), + }; + + assert_eq!( + format!("{rd}"), + "1970-01-01T00:00:00+00:00 operation=Get duration=5.000000s size=10 range: bytes=0-9 path=test extra info" + ); + } + + #[test] + fn request_summary() { + // Test empty request list + let mut requests = Vec::new(); + let summaries = RequestSummary::summarize_by_operation(&requests); + assert!(summaries.is_empty()); + + requests.push(RequestDetails { + op: Operation::Get, + path: Path::from("test1"), + timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), + duration: Some(Duration::from_secs(5)), + size: Some(100), + range: None, + extra_display: None, + }); + + let summaries = RequestSummary::summarize_by_operation(&requests); + assert_eq!(summaries.len(), 1); + + let summary = summaries.get(&Operation::Get).unwrap(); + assert_eq!(summary.count, 1); + assert_eq!( + summary.duration_stats.as_ref().unwrap().min, + Duration::from_secs(5) + ); + assert_eq!( + summary.duration_stats.as_ref().unwrap().max, + Duration::from_secs(5) + ); + assert_eq!( + summary.duration_stats.as_ref().unwrap().sum, + Duration::from_secs(5) + ); + assert_eq!(summary.size_stats.as_ref().unwrap().min, 100); + assert_eq!(summary.size_stats.as_ref().unwrap().max, 100); + assert_eq!(summary.size_stats.as_ref().unwrap().sum, 100); + + // Add more Get requests to test aggregation + requests.push(RequestDetails { + op: Operation::Get, + path: Path::from("test2"), + timestamp: chrono::DateTime::from_timestamp(1, 0).unwrap(), + duration: Some(Duration::from_secs(8)), + size: Some(150), + range: None, + extra_display: None, + }); + requests.push(RequestDetails { + op: Operation::Get, + path: Path::from("test3"), + timestamp: chrono::DateTime::from_timestamp(2, 0).unwrap(), + duration: Some(Duration::from_secs(2)), + size: Some(50), + range: None, + extra_display: None, + }); + + let summaries = RequestSummary::summarize_by_operation(&requests); + assert_eq!(summaries.len(), 1); + + let summary = summaries.get(&Operation::Get).unwrap(); + assert_eq!(summary.count, 3); + assert_eq!( + summary.duration_stats.as_ref().unwrap().min, + Duration::from_secs(2) + ); + assert_eq!( + summary.duration_stats.as_ref().unwrap().max, + Duration::from_secs(8) + ); + assert_eq!( + summary.duration_stats.as_ref().unwrap().sum, + Duration::from_secs(15) + ); + assert_eq!(summary.size_stats.as_ref().unwrap().min, 50); + assert_eq!(summary.size_stats.as_ref().unwrap().max, 150); + assert_eq!(summary.size_stats.as_ref().unwrap().sum, 300); + + // Add Put requests to test grouping + requests.push(RequestDetails { + op: Operation::_Put, + path: Path::from("test4"), + timestamp: chrono::DateTime::from_timestamp(3, 0).unwrap(), + duration: Some(Duration::from_millis(200)), + size: Some(75), + range: None, + extra_display: None, + }); + + let summaries = RequestSummary::summarize_by_operation(&requests); + assert_eq!(summaries.len(), 2); + + let get_summary = summaries.get(&Operation::Get).unwrap(); + assert_eq!(get_summary.count, 3); + + let put_summary = summaries.get(&Operation::_Put).unwrap(); + assert_eq!(put_summary.count, 1); + assert_eq!( + put_summary.duration_stats.as_ref().unwrap().min, + Duration::from_millis(200) + ); + assert_eq!(put_summary.size_stats.as_ref().unwrap().sum, 75); + + // Test request with only duration (no size) + let only_duration = vec![RequestDetails { + op: Operation::Get, + path: Path::from("test1"), + timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), + duration: Some(Duration::from_secs(3)), + size: None, + range: None, + extra_display: None, + }]; + let summaries = RequestSummary::summarize_by_operation(&only_duration); + let summary = summaries.get(&Operation::Get).unwrap(); + assert_eq!(summary.count, 1); + assert!(summary.duration_stats.is_some()); + assert!(summary.size_stats.is_none()); + + // Test request with only size (no duration) + let only_size = vec![RequestDetails { + op: Operation::Get, + path: Path::from("test1"), + timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), + duration: None, + size: Some(200), + range: None, + extra_display: None, + }]; + let summaries = RequestSummary::summarize_by_operation(&only_size); + let summary = summaries.get(&Operation::Get).unwrap(); + assert_eq!(summary.count, 1); + assert!(summary.duration_stats.is_none()); + assert!(summary.size_stats.is_some()); + assert_eq!(summary.size_stats.as_ref().unwrap().sum, 200); + + // Test request with neither duration nor size + let no_stats = vec![RequestDetails { + op: Operation::Get, + path: Path::from("test1"), + timestamp: chrono::DateTime::from_timestamp(0, 0).unwrap(), + duration: None, + size: None, + range: None, + extra_display: None, + }]; + let summaries = RequestSummary::summarize_by_operation(&no_stats); + let summary = summaries.get(&Operation::Get).unwrap(); + assert_eq!(summary.count, 1); + assert!(summary.duration_stats.is_none()); + assert!(summary.size_stats.is_none()); + } +} diff --git a/datafusion-cli/src/pool_type.rs b/datafusion-cli/src/pool_type.rs index 269790b61f5a5..a2164cc3c7392 100644 --- a/datafusion-cli/src/pool_type.rs +++ b/datafusion-cli/src/pool_type.rs @@ -33,7 +33,7 @@ impl FromStr for PoolType { match s { "Greedy" | "greedy" => Ok(PoolType::Greedy), "Fair" | "fair" => Ok(PoolType::Fair), - _ => Err(format!("Invalid memory pool type '{}'", s)), + _ => Err(format!("Invalid memory pool type '{s}'")), } } } diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 1fc949593512b..56bdb15a315d9 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -26,7 +26,7 @@ use arrow::datatypes::SchemaRef; use arrow::json::{ArrayWriter, LineDelimitedWriter}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches_with_options; -use datafusion::common::format::DEFAULT_CLI_FORMAT_OPTIONS; +use datafusion::config::FormatOptions; use datafusion::error::Result; /// Allow records to be printed in different formats @@ -110,7 +110,10 @@ fn format_batches_with_maxrows( writer: &mut W, batches: &[RecordBatch], maxrows: MaxRows, + format_options: &FormatOptions, ) -> Result<()> { + let options: arrow::util::display::FormatOptions = format_options.try_into()?; + match maxrows { MaxRows::Limited(maxrows) => { // Filter batches to meet the maxrows condition @@ -131,22 +134,19 @@ fn format_batches_with_maxrows( } } - let formatted = pretty_format_batches_with_options( - &filtered_batches, - &DEFAULT_CLI_FORMAT_OPTIONS, - )?; + let formatted = + pretty_format_batches_with_options(&filtered_batches, &options)?; if over_limit { - let mut formatted_str = format!("{}", formatted); + let mut formatted_str = format!("{formatted}"); formatted_str = keep_only_maxrows(&formatted_str, maxrows); - writeln!(writer, "{}", formatted_str)?; + writeln!(writer, "{formatted_str}")?; } else { - writeln!(writer, "{}", formatted)?; + writeln!(writer, "{formatted}")?; } } MaxRows::Unlimited => { - let formatted = - pretty_format_batches_with_options(batches, &DEFAULT_CLI_FORMAT_OPTIONS)?; - writeln!(writer, "{}", formatted)?; + let formatted = pretty_format_batches_with_options(batches, &options)?; + writeln!(writer, "{formatted}")?; } } @@ -162,6 +162,7 @@ impl PrintFormat { batches: &[RecordBatch], maxrows: MaxRows, with_header: bool, + format_options: &FormatOptions, ) -> Result<()> { // filter out any empty batches let batches: Vec<_> = batches @@ -170,7 +171,7 @@ impl PrintFormat { .cloned() .collect(); if batches.is_empty() { - return self.print_empty(writer, schema); + return self.print_empty(writer, schema, format_options); } match self { @@ -182,7 +183,7 @@ impl PrintFormat { if maxrows == MaxRows::Limited(0) { return Ok(()); } - format_batches_with_maxrows(writer, &batches, maxrows) + format_batches_with_maxrows(writer, &batches, maxrows, format_options) } Self::Json => batches_to_json!(ArrayWriter, writer, &batches), Self::NdJson => batches_to_json!(LineDelimitedWriter, writer, &batches), @@ -194,16 +195,18 @@ impl PrintFormat { &self, writer: &mut W, schema: SchemaRef, + format_options: &FormatOptions, ) -> Result<()> { match self { // Print column headers for Table format Self::Table if !schema.fields().is_empty() => { + let format_options: arrow::util::display::FormatOptions = + format_options.try_into()?; + let empty_batch = RecordBatch::new_empty(schema); - let formatted = pretty_format_batches_with_options( - &[empty_batch], - &DEFAULT_CLI_FORMAT_OPTIONS, - )?; - writeln!(writer, "{}", formatted)?; + let formatted = + pretty_format_batches_with_options(&[empty_batch], &format_options)?; + writeln!(writer, "{formatted}")?; } _ => {} } @@ -218,6 +221,7 @@ mod tests { use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; + use insta::{allow_duplicates, assert_snapshot}; #[test] fn print_empty() { @@ -229,249 +233,204 @@ mod tests { PrintFormat::Automatic, ] { // no output for empty batches, even with header set - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(format) .with_schema(three_column_schema()) .with_batches(vec![]) - .with_expected(&[""]) .run(); + assert_eq!(output, "") } // output column headers for empty batches when format is Table - #[rustfmt::skip] - let expected = &[ - "+---+---+---+", - "| a | b | c |", - "+---+---+---+", - "+---+---+---+", - ]; - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_schema(three_column_schema()) .with_batches(vec![]) - .with_expected(expected) .run(); + assert_snapshot!(output, @r#" + +---+---+---+ + | a | b | c | + +---+---+---+ + +---+---+---+ + "#); } #[test] fn print_csv_no_header() { - #[rustfmt::skip] - let expected = &[ - "1,4,7", - "2,5,8", - "3,6,9", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Csv) .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::No) - .with_expected(expected) .run(); + assert_snapshot!(output, @r#" + 1,4,7 + 2,5,8 + 3,6,9 + "#); } #[test] fn print_csv_with_header() { - #[rustfmt::skip] - let expected = &[ - "a,b,c", - "1,4,7", - "2,5,8", - "3,6,9", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Csv) .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::Yes) - .with_expected(expected) .run(); + assert_snapshot!(output, @r#" + a,b,c + 1,4,7 + 2,5,8 + 3,6,9 + "#); } #[test] fn print_tsv_no_header() { - #[rustfmt::skip] - let expected = &[ - "1\t4\t7", - "2\t5\t8", - "3\t6\t9", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Tsv) .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::No) - .with_expected(expected) .run(); + assert_snapshot!(output, @" + 1\t4\t7 + 2\t5\t8 + 3\t6\t9 + ") } #[test] fn print_tsv_with_header() { - #[rustfmt::skip] - let expected = &[ - "a\tb\tc", - "1\t4\t7", - "2\t5\t8", - "3\t6\t9", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Tsv) .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::Yes) - .with_expected(expected) .run(); + assert_snapshot!(output, @" + a\tb\tc + 1\t4\t7 + 2\t5\t8 + 3\t6\t9 + "); } #[test] fn print_table() { - let expected = &[ - "+---+---+---+", - "| a | b | c |", - "+---+---+---+", - "| 1 | 4 | 7 |", - "| 2 | 5 | 8 |", - "| 3 | 6 | 9 |", - "+---+---+---+", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::Ignored) - .with_expected(expected) .run(); + assert_snapshot!(output, @r#" + +---+---+---+ + | a | b | c | + +---+---+---+ + | 1 | 4 | 7 | + | 2 | 5 | 8 | + | 3 | 6 | 9 | + +---+---+---+ + "#); } #[test] fn print_json() { - let expected = - &[r#"[{"a":1,"b":4,"c":7},{"a":2,"b":5,"c":8},{"a":3,"b":6,"c":9}]"#]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Json) .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::Ignored) - .with_expected(expected) .run(); + assert_snapshot!(output, @r#" + [{"a":1,"b":4,"c":7},{"a":2,"b":5,"c":8},{"a":3,"b":6,"c":9}] + "#); } #[test] fn print_ndjson() { - let expected = &[ - r#"{"a":1,"b":4,"c":7}"#, - r#"{"a":2,"b":5,"c":8}"#, - r#"{"a":3,"b":6,"c":9}"#, - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::NdJson) .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::Ignored) - .with_expected(expected) .run(); + assert_snapshot!(output, @r#" + {"a":1,"b":4,"c":7} + {"a":2,"b":5,"c":8} + {"a":3,"b":6,"c":9} + "#); } #[test] fn print_automatic_no_header() { - #[rustfmt::skip] - let expected = &[ - "1,4,7", - "2,5,8", - "3,6,9", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Automatic) .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::No) - .with_expected(expected) .run(); + assert_snapshot!(output, @r#" + 1,4,7 + 2,5,8 + 3,6,9 + "#); } #[test] fn print_automatic_with_header() { - #[rustfmt::skip] - let expected = &[ - "a,b,c", - "1,4,7", - "2,5,8", - "3,6,9", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Automatic) .with_batches(split_batch(three_column_batch())) .with_header(WithHeader::Yes) - .with_expected(expected) .run(); + assert_snapshot!(output, @r#" + a,b,c + 1,4,7 + 2,5,8 + 3,6,9 + "#); } #[test] fn print_maxrows_unlimited() { - #[rustfmt::skip] - let expected = &[ - "+---+", - "| a |", - "+---+", - "| 1 |", - "| 2 |", - "| 3 |", - "+---+", - ]; - // should print out entire output with no truncation if unlimited or // limit greater than number of batches or equal to the number of batches for max_rows in [MaxRows::Unlimited, MaxRows::Limited(5), MaxRows::Limited(3)] { - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_schema(one_column_schema()) .with_batches(vec![one_column_batch()]) .with_maxrows(max_rows) - .with_expected(expected) .run(); + allow_duplicates! { + assert_snapshot!(output, @r#" + +---+ + | a | + +---+ + | 1 | + | 2 | + | 3 | + +---+ + "#); + } } } #[test] fn print_maxrows_limited_one_batch() { - #[rustfmt::skip] - let expected = &[ - "+---+", - "| a |", - "+---+", - "| 1 |", - "| . |", - "| . |", - "| . |", - "+---+", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_batches(vec![one_column_batch()]) .with_maxrows(MaxRows::Limited(1)) - .with_expected(expected) .run(); + assert_snapshot!(output, @r#" + +---+ + | a | + +---+ + | 1 | + | . | + | . | + | . | + +---+ + "#); } #[test] fn print_maxrows_limited_multi_batched() { - #[rustfmt::skip] - let expected = &[ - "+---+", - "| a |", - "+---+", - "| 1 |", - "| 2 |", - "| 3 |", - "| 1 |", - "| 2 |", - "| . |", - "| . |", - "| . |", - "+---+", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_batches(vec![ one_column_batch(), @@ -479,8 +438,21 @@ mod tests { one_column_batch(), ]) .with_maxrows(MaxRows::Limited(5)) - .with_expected(expected) .run(); + assert_snapshot!(output, @r#" + +---+ + | a | + +---+ + | 1 | + | 2 | + | 3 | + | 1 | + | 2 | + | . | + | . | + | . | + +---+ + "#); } #[test] @@ -488,22 +460,19 @@ mod tests { let batch = one_column_batch(); let empty_batch = RecordBatch::new_empty(batch.schema()); - #[rustfmt::skip] - let expected =&[ - "+---+", - "| a |", - "+---+", - "| 1 |", - "| 2 |", - "| 3 |", - "+---+", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_batches(vec![empty_batch.clone(), batch, empty_batch]) - .with_expected(expected) .run(); + assert_snapshot!(output, @r#" + +---+ + | a | + +---+ + | 1 | + | 2 | + | 3 | + +---+ + "#); } #[test] @@ -511,32 +480,28 @@ mod tests { let empty_batch = RecordBatch::new_empty(one_column_batch().schema()); // Print column headers for empty batch when format is Table - #[rustfmt::skip] - let expected =&[ - "+---+", - "| a |", - "+---+", - "+---+", - ]; - - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_schema(one_column_schema()) .with_batches(vec![empty_batch]) .with_header(WithHeader::Yes) - .with_expected(expected) .run(); + assert_snapshot!(output, @r#" + +---+ + | a | + +---+ + +---+ + "#); // No output for empty batch when schema contains no columns let empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty())); - let expected = &[""]; - PrintBatchesTest::new() + let output = PrintBatchesTest::new() .with_format(PrintFormat::Table) .with_schema(Arc::new(Schema::empty())) .with_batches(vec![empty_batch]) .with_header(WithHeader::Yes) - .with_expected(expected) .run(); + assert_eq!(output, "") } #[derive(Debug)] @@ -546,7 +511,6 @@ mod tests { batches: Vec, maxrows: MaxRows, with_header: WithHeader, - expected: Vec<&'static str>, } /// How to test with_header @@ -566,7 +530,6 @@ mod tests { batches: vec![], maxrows: MaxRows::Unlimited, with_header: WithHeader::Ignored, - expected: vec![], } } @@ -600,25 +563,9 @@ mod tests { self } - /// set expected output - fn with_expected(mut self, expected: &[&'static str]) -> Self { - self.expected = expected.to_vec(); - self - } - /// run the test - fn run(self) { - let actual = self.output(); - let actual: Vec<_> = actual.trim_end().split('\n').collect(); - let expected = self.expected; - assert_eq!( - actual, expected, - "\n\nactual:\n{actual:#?}\n\nexpected:\n{expected:#?}" - ); - } - /// formats batches using parameters and returns the resulting output - fn output(&self) -> String { + fn run(self) -> String { match self.with_header { WithHeader::Yes => self.output_with_header(true), WithHeader::No => self.output_with_header(false), @@ -644,6 +591,7 @@ mod tests { &self.batches, self.maxrows, with_header, + &FormatOptions::default(), ) .unwrap(); String::from_utf8(buffer).unwrap() diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 9557e783e8a7c..f54de189b4ef5 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -16,10 +16,14 @@ // under the License. use std::fmt::{Display, Formatter}; -use std::io::Write; +use std::io; use std::pin::Pin; use std::str::FromStr; +use std::sync::Arc; +use crate::object_storage::instrumented::{ + InstrumentedObjectStoreMode, InstrumentedObjectStoreRegistry, RequestSummary, +}; use crate::print_format::PrintFormat; use arrow::datatypes::SchemaRef; @@ -29,6 +33,7 @@ use datafusion::common::DataFusionError; use datafusion::error::Result; use datafusion::physical_plan::RecordBatchStream; +use datafusion::config::FormatOptions; use futures::StreamExt; #[derive(Debug, Clone, PartialEq, Copy)] @@ -51,7 +56,7 @@ impl FromStr for MaxRows { } else { match maxrows.parse::() { Ok(nrows) => Ok(Self::Limited(nrows)), - _ => Err(format!("Invalid maxrows {}. Valid inputs are natural numbers or \'none\', \'inf\', or \'infinite\' for no limit.", maxrows)), + _ => Err(format!("Invalid maxrows {maxrows}. Valid inputs are natural numbers or \'none\', \'inf\', or \'infinite\' for no limit.")), } } } @@ -66,12 +71,15 @@ impl Display for MaxRows { } } +const OBJECT_STORE_PROFILING_HEADER: &str = "Object Store Profiling"; + #[derive(Debug, Clone)] pub struct PrintOptions { pub format: PrintFormat, pub quiet: bool, pub maxrows: MaxRows, pub color: bool, + pub instrumented_registry: Arc, } // Returns the query execution details formatted @@ -103,12 +111,19 @@ impl PrintOptions { batches: &[RecordBatch], query_start_time: Instant, row_count: usize, + format_options: &FormatOptions, ) -> Result<()> { let stdout = std::io::stdout(); let mut writer = stdout.lock(); - self.format - .print_batches(&mut writer, schema, batches, self.maxrows, true)?; + self.format.print_batches( + &mut writer, + schema, + batches, + self.maxrows, + true, + format_options, + )?; let formatted_exec_details = get_execution_details_formatted( row_count, @@ -120,11 +135,7 @@ impl PrintOptions { query_start_time, ); - if !self.quiet { - writeln!(writer, "{formatted_exec_details}")?; - } - - Ok(()) + self.write_output(&mut writer, formatted_exec_details) } /// Print the stream to stdout using the specified format @@ -132,6 +143,7 @@ impl PrintOptions { &self, mut stream: Pin>, query_start_time: Instant, + format_options: &FormatOptions, ) -> Result<()> { if self.format == PrintFormat::Table { return Err(DataFusionError::External( @@ -154,6 +166,7 @@ impl PrintOptions { &[batch], MaxRows::Unlimited, with_header, + format_options, )?; with_header = false; } @@ -164,10 +177,90 @@ impl PrintOptions { query_start_time, ); + self.write_output(&mut writer, formatted_exec_details) + } + + fn write_output( + &self, + writer: &mut W, + formatted_exec_details: String, + ) -> Result<()> { if !self.quiet { writeln!(writer, "{formatted_exec_details}")?; + + if self.instrumented_registry.instrument_mode() + != InstrumentedObjectStoreMode::Disabled + { + writeln!(writer, "{OBJECT_STORE_PROFILING_HEADER}")?; + for store in self.instrumented_registry.stores() { + let requests = store.take_requests(); + + if !requests.is_empty() { + writeln!(writer, "{store}")?; + for req in requests.iter() { + writeln!(writer, "{req}")?; + } + // Add an extra blank line to help visually organize the output + writeln!(writer)?; + + writeln!(writer, "Summaries:")?; + let summaries = RequestSummary::summarize_by_operation(&requests); + for (op, summary) in summaries { + writeln!(writer, "{op:?}")?; + writeln!(writer, "{summary}")?; + } + } + } + } } Ok(()) } } + +#[cfg(test)] +mod tests { + use datafusion::error::Result; + + use super::*; + + #[test] + fn write_output() -> Result<()> { + let instrumented_registry = Arc::new(InstrumentedObjectStoreRegistry::new()); + let mut print_options = PrintOptions { + format: PrintFormat::Automatic, + quiet: true, + maxrows: MaxRows::Unlimited, + color: true, + instrumented_registry: Arc::clone(&instrumented_registry), + }; + + let mut print_output: Vec = Vec::new(); + let exec_out = String::from("Formatted Exec Output"); + print_options.write_output(&mut print_output, exec_out.clone())?; + assert!(print_output.is_empty()); + + print_options.quiet = false; + print_options.write_output(&mut print_output, exec_out.clone())?; + let out_str: String = print_output + .clone() + .try_into() + .expect("Expected successful String conversion"); + assert!(out_str.contains(&exec_out)); + + // clear the previous data from the output so it doesn't pollute the next test + print_output.clear(); + print_options + .instrumented_registry + .set_instrument_mode(InstrumentedObjectStoreMode::Enabled); + print_options.write_output(&mut print_output, exec_out.clone())?; + let out_str: String = print_output + .clone() + .try_into() + .expect("Expected successful String conversion"); + assert!(out_str.contains(&exec_out)); + assert!(out_str.contains(OBJECT_STORE_PROFILING_HEADER)); + + Ok(()) + } +} diff --git a/datafusion-cli/tests/cli_integration.rs b/datafusion-cli/tests/cli_integration.rs index 9ac09955512b8..a67924fef2537 100644 --- a/datafusion-cli/tests/cli_integration.rs +++ b/datafusion-cli/tests/cli_integration.rs @@ -19,9 +19,15 @@ use std::process::Command; use rstest::rstest; +use async_trait::async_trait; use insta::{glob, Settings}; use insta_cmd::{assert_cmd_snapshot, get_cargo_bin}; +use std::path::PathBuf; use std::{env, fs}; +use testcontainers::core::{CmdWaitFor, ExecCommand, Mount}; +use testcontainers::runners::AsyncRunner; +use testcontainers::{ContainerAsync, ImageExt, TestcontainersError}; +use testcontainers_modules::minio; fn cli() -> Command { Command::new(get_cargo_bin("datafusion-cli")) @@ -32,9 +38,87 @@ fn make_settings() -> Settings { settings.set_prepend_module_to_snapshot(false); settings.add_filter(r"Elapsed .* seconds\.", "[ELAPSED]"); settings.add_filter(r"DataFusion CLI v.*", "[CLI_VERSION]"); + settings.add_filter(r"(?s)backtrace:.*?\n\n\n", ""); settings } +async fn setup_minio_container() -> ContainerAsync { + const MINIO_ROOT_USER: &str = "TEST-DataFusionLogin"; + const MINIO_ROOT_PASSWORD: &str = "TEST-DataFusionPassword"; + + let data_path = + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../datafusion/core/tests/data"); + + let absolute_data_path = data_path + .canonicalize() + .expect("Failed to get absolute path for test data"); + + let container = minio::MinIO::default() + .with_env_var("MINIO_ROOT_USER", MINIO_ROOT_USER) + .with_env_var("MINIO_ROOT_PASSWORD", MINIO_ROOT_PASSWORD) + .with_mount(Mount::bind_mount( + absolute_data_path.to_str().unwrap(), + "/source", + )) + .start() + .await; + + match container { + Ok(container) => { + // We wait for MinIO to be healthy and prepare test files. We do it via CLI to avoid s3 dependency + let commands = [ + ExecCommand::new(["/usr/bin/mc", "ready", "local"]), + ExecCommand::new([ + "/usr/bin/mc", + "alias", + "set", + "localminio", + "http://localhost:9000", + MINIO_ROOT_USER, + MINIO_ROOT_PASSWORD, + ]), + ExecCommand::new(["/usr/bin/mc", "mb", "localminio/data"]), + ExecCommand::new([ + "/usr/bin/mc", + "cp", + "-r", + "/source/", + "localminio/data/", + ]), + ]; + + for command in commands { + let command = + command.with_cmd_ready_condition(CmdWaitFor::Exit { code: Some(0) }); + + let cmd_ref = format!("{command:?}"); + + if let Err(e) = container.exec(command).await { + let stdout = container.stdout_to_vec().await.unwrap_or_default(); + let stderr = container.stderr_to_vec().await.unwrap_or_default(); + + panic!( + "Failed to execute command: {}\nError: {}\nStdout: {:?}\nStderr: {:?}", + cmd_ref, + e, + String::from_utf8_lossy(&stdout), + String::from_utf8_lossy(&stderr) + ); + } + } + + container + } + + Err(TestcontainersError::Client(e)) => { + panic!("Failed to start MinIO container. Ensure Docker is running and accessible: {e}"); + } + Err(e) => { + panic!("Failed to start MinIO container: {e}"); + } + } +} + #[cfg(test)] #[ctor::ctor] fn init() { @@ -69,6 +153,10 @@ fn init() { // can choose the old explain format too ["--command", "EXPLAIN FORMAT indent SELECT 123"], )] +#[case::change_format_version( + "change_format_version", + ["--file", "tests/sql/types_format.sql", "-q"], +)] #[test] fn cli_quick_test<'a>( #[case] snapshot_name: &'a str, @@ -118,6 +206,42 @@ fn test_cli_format<'a>(#[case] format: &'a str) { assert_cmd_snapshot!(cmd); } +#[rstest] +#[case("no_track", ["--top-memory-consumers", "0"])] +#[case("top2", ["--top-memory-consumers", "2"])] +#[case("top3_default", [])] +#[test] +fn test_cli_top_memory_consumers<'a>( + #[case] snapshot_name: &str, + #[case] top_memory_consumers: impl IntoIterator, +) { + let mut settings = make_settings(); + + settings.set_snapshot_suffix(snapshot_name); + + settings.add_filter( + r"[^\s]+\#\d+\(can spill: (true|false)\) consumed .*?B, peak .*?B", + "Consumer(can spill: bool) consumed XB, peak XB", + ); + settings.add_filter( + r"Error: Failed to allocate additional .*? for .*? with .*? already allocated for this reservation - .*? remain available for the total pool", + "Error: Failed to allocate ", + ); + settings.add_filter( + r"Resources exhausted: Failed to allocate additional .*? for .*? with .*? already allocated for this reservation - .*? remain available for the total pool", + "Resources exhausted: Failed to allocate", + ); + + let _bound = settings.bind_to_scope(); + + let mut cmd = cli(); + let sql = "select * from generate_series(1,500000) as t1(v1) order by v1;"; + cmd.args(["--memory-limit", "10M", "--command", sql]); + cmd.args(top_memory_consumers); + + assert_cmd_snapshot!(cmd); +} + #[tokio::test] async fn test_cli() { if env::var("TEST_STORAGE_INTEGRATION").is_err() { @@ -125,12 +249,22 @@ async fn test_cli() { return; } + let container = setup_minio_container().await; + let settings = make_settings(); let _bound = settings.bind_to_scope(); + let port = container.get_host_port_ipv4(9000).await.unwrap(); + glob!("sql/integration/*.sql", |path| { let input = fs::read_to_string(path).unwrap(); - assert_cmd_snapshot!(cli().pass_stdin(input)) + assert_cmd_snapshot!(cli() + .env_clear() + .env("AWS_ACCESS_KEY_ID", "TEST-DataFusionLogin") + .env("AWS_SECRET_ACCESS_KEY", "TEST-DataFusionPassword") + .env("AWS_ENDPOINT", format!("http://localhost:{port}")) + .env("AWS_ALLOW_HTTP", "true") + .pass_stdin(input)) }); } @@ -146,27 +280,190 @@ async fn test_aws_options() { let settings = make_settings(); let _bound = settings.bind_to_scope(); - let access_key_id = - env::var("AWS_ACCESS_KEY_ID").expect("AWS_ACCESS_KEY_ID is not set"); - let secret_access_key = - env::var("AWS_SECRET_ACCESS_KEY").expect("AWS_SECRET_ACCESS_KEY is not set"); - let endpoint_url = env::var("AWS_ENDPOINT").expect("AWS_ENDPOINT is not set"); + let container = setup_minio_container().await; + let port = container.get_host_port_ipv4(9000).await.unwrap(); let input = format!( r#"CREATE EXTERNAL TABLE CARS STORED AS CSV LOCATION 's3://data/cars.csv' OPTIONS( - 'aws.access_key_id' '{}', - 'aws.secret_access_key' '{}', - 'aws.endpoint' '{}', + 'aws.access_key_id' 'TEST-DataFusionLogin', + 'aws.secret_access_key' 'TEST-DataFusionPassword', + 'aws.endpoint' 'http://localhost:{port}', 'aws.allow_http' 'true' ); SELECT * FROM CARS limit 1; -"#, - access_key_id, secret_access_key, endpoint_url +"# ); assert_cmd_snapshot!(cli().env_clear().pass_stdin(input)); } + +#[tokio::test] +async fn test_aws_region_auto_resolution() { + if env::var("TEST_STORAGE_INTEGRATION").is_err() { + eprintln!("Skipping external storages integration tests"); + return; + } + + let mut settings = make_settings(); + settings.add_filter(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z", "[TIME]"); + let _bound = settings.bind_to_scope(); + + let bucket = "s3://clickhouse-public-datasets/hits_compatible/athena_partitioned/hits_1.parquet"; + let region = "us-east-1"; + + let input = format!( + r#"CREATE EXTERNAL TABLE hits +STORED AS PARQUET +LOCATION '{bucket}' +OPTIONS( + 'aws.region' '{region}', + 'aws.skip_signature' true +); + +SELECT COUNT(*) FROM hits; +"# + ); + + assert_cmd_snapshot!(cli() + .env("RUST_LOG", "warn") + .env_remove("AWS_ENDPOINT") + .pass_stdin(input)); +} + +/// Ensure backtrace will be printed, if executing `datafusion-cli` with a query +/// that triggers error. +/// Example: +/// RUST_BACKTRACE=1 cargo run --features backtrace -- -c 'select pow(1,'foo');' +#[rstest] +#[case("SELECT pow(1,'foo')")] +#[case("SELECT CAST('not_a_number' AS INTEGER);")] +#[cfg(feature = "backtrace")] +fn test_backtrace_output(#[case] query: &str) { + let mut cmd = cli(); + // Use a command that will cause an error and trigger backtrace + cmd.args(["--command", query, "-q"]) + .env("RUST_BACKTRACE", "1"); // Enable backtrace + + let output = cmd.output().expect("Failed to execute command"); + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + let combined_output = format!("{}{}", stdout, stderr); + + // Assert that the output includes literal 'backtrace' + assert!( + combined_output.to_lowercase().contains("backtrace"), + "Expected output to contain 'backtrace', but got stdout: '{}' stderr: '{}'", + stdout, + stderr + ); +} + +#[tokio::test] +async fn test_s3_url_fallback() { + if env::var("TEST_STORAGE_INTEGRATION").is_err() { + eprintln!("Skipping external storages integration tests"); + return; + } + + let container = setup_minio_container().await; + + let mut settings = make_settings(); + settings.set_snapshot_suffix("s3_url_fallback"); + let _bound = settings.bind_to_scope(); + + // Create a table using a prefix path (without trailing slash) + // This should trigger the fallback logic where head() fails on the prefix + // and list() is used to discover the actual files + let input = r#"CREATE EXTERNAL TABLE partitioned_data +STORED AS CSV +LOCATION 's3://data/partitioned_csv' +OPTIONS ( + 'format.has_header' 'false' +); + +SELECT * FROM partitioned_data ORDER BY column_1, column_2 LIMIT 5; +"#; + + assert_cmd_snapshot!(cli().with_minio(&container).await.pass_stdin(input)); +} + +/// Validate object store profiling output +#[tokio::test] +async fn test_object_store_profiling() { + if env::var("TEST_STORAGE_INTEGRATION").is_err() { + eprintln!("Skipping external storages integration tests"); + return; + } + + let container = setup_minio_container().await; + + let mut settings = make_settings(); + settings.set_snapshot_suffix("s3_url_fallback"); + + // as the object store profiling contains timestamps and durations, we must + // filter them out to have stable snapshots + // + // Example line to filter: + // 2025-10-11T12:02:59.722646+00:00 operation=Get duration=0.001495s size=1006 path=cars.csv + // Output: + // operation=Get duration=[DURATION] size=1006 path=cars.csv + settings.add_filter( + r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?[+-]\d{2}:\d{2} operation=(Get|Put|Delete|List|Head) duration=\d+\.\d{6}s size=(\d+) path=(.*)", + " operation=$1 duration=[DURATION] size=$2 path=$3", + ); + + // We also need to filter out the durations reported in the summary output + // + // Example line(s) to filter: + // + // duration min: 0.000729s + // duration max: 0.000729s + // duration avg: 0.000729s + settings.add_filter(r"duration (min|max|avg): \d+\.\d{6}s", "[SUMMARY_DURATION]"); + + let _bound = settings.bind_to_scope(); + + let input = r#" + CREATE EXTERNAL TABLE CARS +STORED AS CSV +LOCATION 's3://data/cars.csv'; + +-- Initial query should not show any profiling as the object store is not instrumented yet +SELECT * from CARS LIMIT 1; +\object_store_profiling enabled +-- Query again to see the profiling output +SELECT * from CARS LIMIT 1; +\object_store_profiling disabled +-- Final query should not show any profiling as we disabled it again +SELECT * from CARS LIMIT 1; +"#; + + assert_cmd_snapshot!(cli().with_minio(&container).await.pass_stdin(input)); +} + +/// Extension trait to Add the minio connection information to a Command +#[async_trait] +trait MinioCommandExt { + async fn with_minio(&mut self, container: &ContainerAsync) + -> &mut Self; +} + +#[async_trait] +impl MinioCommandExt for Command { + async fn with_minio( + &mut self, + container: &ContainerAsync, + ) -> &mut Self { + let port = container.get_host_port_ipv4(9000).await.unwrap(); + + self.env_clear() + .env("AWS_ACCESS_KEY_ID", "TEST-DataFusionLogin") + .env("AWS_SECRET_ACCESS_KEY", "TEST-DataFusionPassword") + .env("AWS_ENDPOINT", format!("http://localhost:{port}")) + .env("AWS_ALLOW_HTTP", "true") + } +} diff --git a/datafusion-cli/tests/snapshots/aws_region_auto_resolution.snap b/datafusion-cli/tests/snapshots/aws_region_auto_resolution.snap new file mode 100644 index 0000000000000..cd6d918b78d99 --- /dev/null +++ b/datafusion-cli/tests/snapshots/aws_region_auto_resolution.snap @@ -0,0 +1,29 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: [] + env: + AWS_ENDPOINT: "" + RUST_LOG: warn + stdin: "CREATE EXTERNAL TABLE hits\nSTORED AS PARQUET\nLOCATION 's3://clickhouse-public-datasets/hits_compatible/athena_partitioned/hits_1.parquet'\nOPTIONS(\n 'aws.region' 'us-east-1',\n 'aws.skip_signature' true\n);\n\nSELECT COUNT(*) FROM hits;\n" +--- +success: true +exit_code: 0 +----- stdout ----- +[CLI_VERSION] +0 row(s) fetched. +[ELAPSED] + ++----------+ +| count(*) | ++----------+ +| 1000000 | ++----------+ +1 row(s) fetched. +[ELAPSED] + +\q + +----- stderr ----- +[[TIME] WARN datafusion_cli::exec] S3 region is incorrect, auto-detecting the correct region (this may be slow). Consider updating your region configuration. diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@can_see_indent_format.snap b/datafusion-cli/tests/snapshots/cli_quick_test@can_see_indent_format.snap index b2fb64709974e..8275041acaecc 100644 --- a/datafusion-cli/tests/snapshots/cli_quick_test@can_see_indent_format.snap +++ b/datafusion-cli/tests/snapshots/cli_quick_test@can_see_indent_format.snap @@ -5,7 +5,6 @@ info: args: - "--command" - EXPLAIN FORMAT indent SELECT 123 -snapshot_kind: text --- success: true exit_code: 0 @@ -15,7 +14,7 @@ exit_code: 0 | plan_type | plan | +---------------+------------------------------------------+ | logical_plan | Projection: Int64(123) | -| | EmptyRelation | +| | EmptyRelation: rows=1 | | physical_plan | ProjectionExec: expr=[123 as Int64(123)] | | | PlaceholderRowExec | | | | diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@change_format_version.snap b/datafusion-cli/tests/snapshots/cli_quick_test@change_format_version.snap new file mode 100644 index 0000000000000..74059b2a6103c --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_quick_test@change_format_version.snap @@ -0,0 +1,20 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--file" + - tests/sql/types_format.sql + - "-q" +--- +success: true +exit_code: 0 +----- stdout ----- ++-----------+ +| Int64(54) | +| Int64 | ++-----------+ +| 54 | ++-----------+ + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap new file mode 100644 index 0000000000000..89b646a531f8b --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap @@ -0,0 +1,21 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--memory-limit" + - 10M + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" + - "--top-memory-consumers" + - "0" +--- +success: false +exit_code: 1 +----- stdout ----- +[CLI_VERSION] +Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +caused by +Resources exhausted: Failed to allocate + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap new file mode 100644 index 0000000000000..62f864b3adb6e --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap @@ -0,0 +1,24 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--memory-limit" + - 10M + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" + - "--top-memory-consumers" + - "2" +--- +success: false +exit_code: 1 +----- stdout ----- +[CLI_VERSION] +Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +caused by +Resources exhausted: Additional allocation failed for ExternalSorter[0] with top memory consumers (across reservations) as: + Consumer(can spill: bool) consumed XB, peak XB, + Consumer(can spill: bool) consumed XB, peak XB. +Error: Failed to allocate + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap new file mode 100644 index 0000000000000..9845d095c9180 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap @@ -0,0 +1,23 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--memory-limit" + - 10M + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" +--- +success: false +exit_code: 1 +----- stdout ----- +[CLI_VERSION] +Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +caused by +Resources exhausted: Additional allocation failed for ExternalSorter[0] with top memory consumers (across reservations) as: + Consumer(can spill: bool) consumed XB, peak XB, + Consumer(can spill: bool) consumed XB, peak XB, + Consumer(can spill: bool) consumed XB, peak XB. +Error: Failed to allocate + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/object_store_profiling@s3_url_fallback.snap b/datafusion-cli/tests/snapshots/object_store_profiling@s3_url_fallback.snap new file mode 100644 index 0000000000000..50c6cc8eab99f --- /dev/null +++ b/datafusion-cli/tests/snapshots/object_store_profiling@s3_url_fallback.snap @@ -0,0 +1,64 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: [] + env: + AWS_ACCESS_KEY_ID: TEST-DataFusionLogin + AWS_ALLOW_HTTP: "true" + AWS_ENDPOINT: "http://localhost:55031" + AWS_SECRET_ACCESS_KEY: TEST-DataFusionPassword + stdin: "\n CREATE EXTERNAL TABLE CARS\nSTORED AS CSV\nLOCATION 's3://data/cars.csv';\n\n-- Initial query should not show any profiling as the object store is not instrumented yet\nSELECT * from CARS LIMIT 1;\n\\object_store_profiling enabled\n-- Query again to see the profiling output\nSELECT * from CARS LIMIT 1;\n\\object_store_profiling disabled\n-- Final query should not show any profiling as we disabled it again\nSELECT * from CARS LIMIT 1;\n" +snapshot_kind: text +--- +success: true +exit_code: 0 +----- stdout ----- +[CLI_VERSION] +0 row(s) fetched. +[ELAPSED] + ++-----+-------+---------------------+ +| car | speed | time | ++-----+-------+---------------------+ +| red | 20.0 | 1996-04-12T12:05:03 | ++-----+-------+---------------------+ +1 row(s) fetched. +[ELAPSED] + +ObjectStore Profile mode set to Enabled ++-----+-------+---------------------+ +| car | speed | time | ++-----+-------+---------------------+ +| red | 20.0 | 1996-04-12T12:05:03 | ++-----+-------+---------------------+ +1 row(s) fetched. +[ELAPSED] + +Object Store Profiling +Instrumented Object Store: instrument_mode: Enabled, inner: AmazonS3(data) + operation=Get duration=[DURATION] size=1006 path=cars.csv + +Summaries: +Get +count: 1 +[SUMMARY_DURATION] +[SUMMARY_DURATION] +[SUMMARY_DURATION] +size min: 1006 B +size max: 1006 B +size avg: 1006 B +size sum: 1006 B + +ObjectStore Profile mode set to Disabled ++-----+-------+---------------------+ +| car | speed | time | ++-----+-------+---------------------+ +| red | 20.0 | 1996-04-12T12:05:03 | ++-----+-------+---------------------+ +1 row(s) fetched. +[ELAPSED] + +\q + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/s3_url_fallback@s3_url_fallback.snap b/datafusion-cli/tests/snapshots/s3_url_fallback@s3_url_fallback.snap new file mode 100644 index 0000000000000..07036d041b42c --- /dev/null +++ b/datafusion-cli/tests/snapshots/s3_url_fallback@s3_url_fallback.snap @@ -0,0 +1,34 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: [] + env: + AWS_ACCESS_KEY_ID: TEST-DataFusionLogin + AWS_ALLOW_HTTP: "true" + AWS_ENDPOINT: "http://localhost:32771" + AWS_SECRET_ACCESS_KEY: TEST-DataFusionPassword + stdin: "CREATE EXTERNAL TABLE partitioned_data\nSTORED AS CSV\nLOCATION 's3://data/partitioned_csv'\nOPTIONS (\n 'format.has_header' 'false'\n);\n\nSELECT * FROM partitioned_data ORDER BY column_1, column_2 LIMIT 5;\n" +--- +success: true +exit_code: 0 +----- stdout ----- +[CLI_VERSION] +0 row(s) fetched. +[ELAPSED] + ++----------+----------+----------+ +| column_1 | column_2 | column_3 | ++----------+----------+----------+ +| 0 | 0 | true | +| 0 | 1 | false | +| 0 | 2 | true | +| 0 | 3 | false | +| 0 | 4 | true | ++----------+----------+----------+ +5 row(s) fetched. +[ELAPSED] + +\q + +----- stderr ----- diff --git a/datafusion-cli/tests/sql/types_format.sql b/datafusion-cli/tests/sql/types_format.sql new file mode 100644 index 0000000000000..637929c980a15 --- /dev/null +++ b/datafusion-cli/tests/sql/types_format.sql @@ -0,0 +1,3 @@ +set datafusion.format.types_info to true; + +select 54 \ No newline at end of file diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index f6b7d641d1264..68bb5376a1acc 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -52,6 +52,10 @@ path = "examples/external_dependency/dataframe-to-s3.rs" name = "query_aws_s3" path = "examples/external_dependency/query-aws-s3.rs" +[[example]] +name = "custom_file_casts" +path = "examples/custom_file_casts.rs" + [dev-dependencies] arrow = { workspace = true } # arrow_schema is required for record_batch! macro :sad: @@ -61,7 +65,10 @@ async-trait = { workspace = true } bytes = { workspace = true } dashmap = { workspace = true } # note only use main datafusion crate for examples -datafusion = { workspace = true, default-features = true } +base64 = "0.22.1" +datafusion = { workspace = true, default-features = true, features = ["parquet_encryption"] } +datafusion-ffi = { workspace = true } +datafusion-physical-expr-adapter = { workspace = true } datafusion-proto = { workspace = true } env_logger = { workspace = true } futures = { workspace = true } @@ -69,14 +76,16 @@ log = { workspace = true } mimalloc = { version = "0.1", default-features = false } object_store = { workspace = true, features = ["aws", "http"] } prost = { workspace = true } +rand = { workspace = true } +serde_json = { workspace = true } tempfile = { workspace = true } test-utils = { path = "../test-utils" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } -tonic = "0.12.1" +tonic = "0.13.1" tracing = { version = "0.1" } tracing-subscriber = { version = "0.3" } url = { workspace = true } -uuid = "1.16" +uuid = "1.18" [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = { version = "0.29.0", features = ["fs"] } +nix = { version = "0.30.1", features = ["fs"] } diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 3ba4c77cd84c3..f1bcbcce82004 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -50,21 +50,29 @@ cargo run --example dataframe - [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) - [`advanced_parquet_index.rs`](examples/advanced_parquet_index.rs): Creates a detailed secondary index that covers the contents of several parquet files +- [`async_udf.rs`](examples/async_udf.rs): Define and invoke an asynchronous User Defined Scalar Function (UDF) - [`analyzer_rule.rs`](examples/analyzer_rule.rs): Use a custom AnalyzerRule to change a query's semantics (row level access control) - [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog - [`composed_extension_codec`](examples/composed_extension_codec.rs): Example of using multiple extension codecs for serialization / deserialization - [`csv_sql_streaming.rs`](examples/csv_sql_streaming.rs): Build and run a streaming query plan from a SQL statement against a local CSV file - [`csv_json_opener.rs`](examples/csv_json_opener.rs): Use low level `FileOpener` APIs to read CSV/JSON into Arrow `RecordBatch`es - [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) +- [`custom_file_casts.rs`](examples/custom_file_casts.rs): Implement custom casting rules to adapt file schemas - [`custom_file_format.rs`](examples/custom_file_format.rs): Write data to a custom file format - [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 and writing back to s3 - [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame API against parquet files, csv files, and in-memory data, including multiple subqueries. Also demonstrates the various methods to write out a DataFrame to a table, parquet file, csv file, and json file. +- [`default_column_values.rs`](examples/default_column_values.rs): Implement custom default value handling for missing columns using field metadata and PhysicalExprAdapter - [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results (Arrow ArrayRefs) into Rust structs - [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify, analyze and coerce `Expr`s - [`file_stream_provider.rs`](examples/file_stream_provider.rs): Run a query on `FileStreamProvider` which implements `StreamProvider` for reading and writing to arbitrary stream sources / sinks. - [`flight_sql_server.rs`](examples/flight/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients - [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros +- [`memory_pool_tracking.rs`](examples/memory_pool_tracking.rs): Demonstrates TrackConsumersPool for memory tracking and debugging with enhanced error messages +- [`memory_pool_execution_plan.rs`](examples/memory_pool_execution_plan.rs): Shows how to implement memory-aware ExecutionPlan with memory reservation and spilling - [`optimizer_rule.rs`](examples/optimizer_rule.rs): Use a custom OptimizerRule to replace certain predicates +- [`parquet_embedded_index.rs`](examples/parquet_embedded_index.rs): Store a custom index inside a Parquet file and use it to speed up queries +- [`parquet_encrypted.rs`](examples/parquet_encrypted.rs): Read and write encrypted Parquet files using DataFusion +- [`parquet_encrypted_with_kms.rs`](examples/parquet_encrypted_with_kms.rs): Read and write encrypted Parquet files using an encryption factory - [`parquet_index.rs`](examples/parquet_index.rs): Create an secondary index over several parquet files and use it to speed up queries - [`parquet_exec_visitor.rs`](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution - [`parse_sql_expr.rs`](examples/parse_sql_expr.rs): Parse SQL text into DataFusion `Expr`. diff --git a/datafusion-examples/examples/advanced_parquet_index.rs b/datafusion-examples/examples/advanced_parquet_index.rs index b8c303e221618..55400e2192832 100644 --- a/datafusion-examples/examples/advanced_parquet_index.rs +++ b/datafusion-examples/examples/advanced_parquet_index.rs @@ -30,7 +30,7 @@ use datafusion::common::{ use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::parquet::ParquetAccessPlan; use datafusion::datasource::physical_plan::{ - FileMeta, FileScanConfigBuilder, ParquetFileReaderFactory, ParquetSource, + FileScanConfigBuilder, ParquetFileReaderFactory, ParquetSource, }; use datafusion::datasource::TableProvider; use datafusion::execution::object_store::ObjectStoreUrl; @@ -495,7 +495,7 @@ impl TableProvider for IndexTableProvider { ParquetSource::default() // provide the predicate so the DataSourceExec can try and prune // row groups internally - .with_predicate(Arc::clone(&schema), predicate) + .with_predicate(predicate) // provide the factory to create parquet reader without re-reading metadata .with_parquet_file_reader_factory(Arc::new(reader_factory)), ); @@ -555,15 +555,16 @@ impl ParquetFileReaderFactory for CachedParquetFileReaderFactory { fn create_reader( &self, _partition_index: usize, - file_meta: FileMeta, + partitioned_file: PartitionedFile, metadata_size_hint: Option, _metrics: &ExecutionPlanMetricsSet, ) -> Result> { // for this example we ignore the partition index and metrics // but in a real system you would likely use them to report details on // the performance of the reader. - let filename = file_meta - .location() + let filename = partitioned_file + .object_meta + .location .parts() .last() .expect("No path in location") @@ -571,7 +572,9 @@ impl ParquetFileReaderFactory for CachedParquetFileReaderFactory { .to_string(); let object_store = Arc::clone(&self.object_store); - let mut inner = ParquetObjectReader::new(object_store, file_meta.object_meta); + let mut inner = + ParquetObjectReader::new(object_store, partitioned_file.object_meta.location) + .with_file_size(partitioned_file.object_meta.size); if let Some(hint) = metadata_size_hint { inner = inner.with_footer_size_hint(hint) @@ -599,7 +602,7 @@ struct ParquetReaderWithCache { impl AsyncFileReader for ParquetReaderWithCache { fn get_bytes( &mut self, - range: Range, + range: Range, ) -> BoxFuture<'_, datafusion::parquet::errors::Result> { println!("get_bytes: {} Reading range {:?}", self.filename, range); self.inner.get_bytes(range) @@ -607,7 +610,7 @@ impl AsyncFileReader for ParquetReaderWithCache { fn get_byte_ranges( &mut self, - ranges: Vec>, + ranges: Vec>, ) -> BoxFuture<'_, datafusion::parquet::errors::Result>> { println!( "get_byte_ranges: {} Reading ranges {:?}", @@ -618,6 +621,7 @@ impl AsyncFileReader for ParquetReaderWithCache { fn get_metadata( &mut self, + _options: Option<&ArrowReaderOptions>, ) -> BoxFuture<'_, datafusion::parquet::errors::Result>> { println!("get_metadata: {} returning cached metadata", self.filename); diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 9cda726db7197..89f0a470e32e4 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -25,6 +25,7 @@ use arrow::array::{ }; use arrow::datatypes::{ArrowNativeTypeOp, ArrowPrimitiveType, Float64Type, UInt32Type}; use arrow::record_batch::RecordBatch; +use arrow_schema::FieldRef; use datafusion::common::{cast::as_float64_array, ScalarValue}; use datafusion::error::Result; use datafusion::logical_expr::{ @@ -40,7 +41,7 @@ use datafusion::prelude::*; /// a function `accumulator` that returns the `Accumulator` instance. /// /// To do so, we must implement the `AggregateUDFImpl` trait. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] struct GeoMeanUdaf { signature: Signature, } @@ -92,10 +93,10 @@ impl AggregateUDFImpl for GeoMeanUdaf { } /// This is the description of the state. accumulator's state() must match the types here. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ - Field::new("prod", args.return_type.clone(), true), - Field::new("n", DataType::UInt32, true), + Field::new("prod", args.return_type().clone(), true).into(), + Field::new("n", DataType::UInt32, true).into(), ]) } @@ -367,7 +368,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { /// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user /// defined aggregate function with a different expression which is defined in the `simplify` method. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimplifiedGeoMeanUdaf { signature: Signature, } @@ -401,7 +402,7 @@ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf { unimplemented!("should not be invoked") } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!("should not be invoked") } @@ -482,7 +483,7 @@ async fn main() -> Result<()> { ctx.register_udaf(udf.clone()); let sql_df = ctx - .sql(&format!("SELECT {}(a) FROM t GROUP BY b", udf_name)) + .sql(&format!("SELECT {udf_name}(a) FROM t GROUP BY b")) .await?; sql_df.show().await?; diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs index 290d1c53334b7..56ae599efa11b 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/advanced_udf.rs @@ -39,7 +39,7 @@ use datafusion::prelude::*; /// the power of the second argument `a^b`. /// /// To do so, we must implement the `ScalarUDFImpl` trait. -#[derive(Debug, Clone)] +#[derive(Debug, PartialEq, Eq, Hash)] struct PowUdf { signature: Signature, aliases: Vec, diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index 8330e783319d5..ba4c377fd6762 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -15,14 +15,12 @@ // specific language governing permissions and limitations // under the License. -use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; -use std::any::Any; - use arrow::datatypes::Field; use arrow::{ array::{ArrayRef, AsArray, Float64Array}, datatypes::Float64Type, }; +use arrow_schema::FieldRef; use datafusion::common::ScalarValue; use datafusion::error::Result; use datafusion::functions_aggregate::average::avg_udaf; @@ -32,17 +30,21 @@ use datafusion::logical_expr::function::{ }; use datafusion::logical_expr::simplify::SimplifyInfo; use datafusion::logical_expr::{ - Expr, PartitionEvaluator, Signature, WindowFrame, WindowFunctionDefinition, - WindowUDF, WindowUDFImpl, + Expr, LimitEffect, PartitionEvaluator, Signature, WindowFrame, + WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; +use datafusion::physical_expr::PhysicalExpr; use datafusion::prelude::*; +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; +use std::any::Any; +use std::sync::Arc; /// This example shows how to use the full WindowUDFImpl API to implement a user /// defined window function. As in the `simple_udwf.rs` example, this struct implements /// a function `partition_evaluator` that returns the `MyPartitionEvaluator` instance. /// /// To do so, we must implement the `WindowUDFImpl` trait. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SmoothItUdf { signature: Signature, } @@ -87,8 +89,12 @@ impl WindowUDFImpl for SmoothItUdf { Ok(Box::new(MyPartitionEvaluator::new())) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Float64, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, true).into()) + } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown } } @@ -148,7 +154,7 @@ impl PartitionEvaluator for MyPartitionEvaluator { } /// This UDWF will show how to use the WindowUDFImpl::simplify() API -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimplifySmoothItUdf { signature: Signature, } @@ -190,7 +196,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { /// default implementation will not be called (left as `todo!()`) fn simplify(&self) -> Option { let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { - Ok(Expr::WindowFunction(WindowFunction { + Ok(Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(avg_udaf()), params: WindowFunctionParams { args: window_function.params.args, @@ -198,6 +204,8 @@ impl WindowUDFImpl for SimplifySmoothItUdf { order_by: window_function.params.order_by, window_frame: window_function.params.window_frame, null_treatment: window_function.params.null_treatment, + distinct: window_function.params.distinct, + filter: window_function.params.filter, }, })) }; @@ -205,8 +213,12 @@ impl WindowUDFImpl for SimplifySmoothItUdf { Some(Box::new(simplify)) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Float64, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, true).into()) + } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown } } diff --git a/datafusion-examples/examples/async_udf.rs b/datafusion-examples/examples/async_udf.rs new file mode 100644 index 0000000000000..b52ec68ea4422 --- /dev/null +++ b/datafusion-examples/examples/async_udf.rs @@ -0,0 +1,237 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This example shows how to create and use "Async UDFs" in DataFusion. +//! +//! Async UDFs allow you to perform asynchronous operations, such as +//! making network requests. This can be used for tasks like fetching +//! data from an external API such as a LLM service or an external database. + +use arrow::array::{ArrayRef, BooleanArray, Int64Array, RecordBatch, StringArray}; +use arrow_schema::{DataType, Field, Schema}; +use async_trait::async_trait; +use datafusion::assert_batches_eq; +use datafusion::common::cast::as_string_view_array; +use datafusion::common::error::Result; +use datafusion::common::not_impl_err; +use datafusion::common::utils::take_function_args; +use datafusion::execution::SessionStateBuilder; +use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use std::any::Any; +use std::sync::Arc; + +#[tokio::main] +async fn main() -> Result<()> { + // Use a hard coded parallelism level of 4 so the explain plan + // is consistent across machines. + let config = SessionConfig::new().with_target_partitions(4); + let ctx = + SessionContext::from(SessionStateBuilder::new().with_config(config).build()); + + // Similarly to regular UDFs, you create an AsyncScalarUDF by implementing + // `AsyncScalarUDFImpl` and creating an instance of `AsyncScalarUDF`. + let async_equal = AskLLM::new(); + let udf = AsyncScalarUDF::new(Arc::new(async_equal)); + + // Async UDFs are registered with the SessionContext, using the same + // `register_udf` method as regular UDFs. + ctx.register_udf(udf.into_scalar_udf()); + + // Create a table named 'animal' with some sample data + ctx.register_batch("animal", animal()?)?; + + // You can use the async UDF as normal in SQL queries + // + // Note: Async UDFs can currently be used in the select list and filter conditions. + let results = ctx + .sql("select * from animal a where ask_llm(a.name, 'Is this animal furry?')") + .await? + .collect() + .await?; + + assert_batches_eq!( + [ + "+----+------+", + "| id | name |", + "+----+------+", + "| 1 | cat |", + "| 2 | dog |", + "+----+------+", + ], + &results + ); + + // While the interface is the same for both normal and async UDFs, you can + // use `EXPLAIN` output to see that the async UDF uses a special + // `AsyncFuncExec` node in the physical plan: + let results = ctx + .sql("explain select * from animal a where ask_llm(a.name, 'Is this animal furry?')") + .await? + .collect() + .await?; + + assert_batches_eq!( + [ + "+---------------+--------------------------------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+--------------------------------------------------------------------------------------------------------------------------------+", + "| logical_plan | SubqueryAlias: a |", + "| | Filter: ask_llm(CAST(animal.name AS Utf8View), Utf8View(\"Is this animal furry?\")) |", + "| | TableScan: animal projection=[id, name] |", + "| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |", + "| | FilterExec: __async_fn_0@2, projection=[id@0, name@1] |", + "| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |", + "| | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=ask_llm(CAST(name@1 AS Utf8View), Is this animal furry?))] |", + "| | CoalesceBatchesExec: target_batch_size=8192 |", + "| | DataSourceExec: partitions=1, partition_sizes=[1] |", + "| | |", + "+---------------+--------------------------------------------------------------------------------------------------------------------------------+", + ], + &results + ); + + Ok(()) +} + +/// Returns a sample `RecordBatch` representing an "animal" table with two columns: +fn animal() -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + + let id_array = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])); + let name_array = Arc::new(StringArray::from(vec![ + "cat", "dog", "fish", "bird", "snake", + ])); + + Ok(RecordBatch::try_new(schema, vec![id_array, name_array])?) +} + +/// An async UDF that simulates asking a large language model (LLM) service a +/// question based on the content of two columns. The UDF will return a boolean +/// indicating whether the LLM thinks the first argument matches the question in +/// the second argument. +/// +/// Since this is a simplified example, it does not call an LLM service, but +/// could be extended to do so in a real-world scenario. +#[derive(Debug, PartialEq, Eq, Hash)] +struct AskLLM { + signature: Signature, +} + +impl Default for AskLLM { + fn default() -> Self { + Self::new() + } +} + +impl AskLLM { + pub fn new() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Utf8View, DataType::Utf8View], + Volatility::Volatile, + ), + } + } +} + +/// All async UDFs implement the `ScalarUDFImpl` trait, which provides the basic +/// information for the function, such as its name, signature, and return type. +/// [async_trait] +impl ScalarUDFImpl for AskLLM { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ask_llm" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + /// Since this is an async UDF, the `invoke_with_args` method will not be + /// called directly. + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + not_impl_err!("AskLLM can only be called from async contexts") + } +} + +/// In addition to [`ScalarUDFImpl`], we also need to implement the +/// [`AsyncScalarUDFImpl`] trait. +#[async_trait] +impl AsyncScalarUDFImpl for AskLLM { + /// The `invoke_async_with_args` method is similar to `invoke_with_args`, + /// but it returns a `Future` that resolves to the result. + /// + /// Since this signature is `async`, it can do any `async` operations, such + /// as network requests. This method is run on the same tokio `Runtime` that + /// is processing the query, so you may wish to make actual network requests + /// on a different `Runtime`, as explained in the `thread_pools.rs` example + /// in this directory. + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + // in a real UDF you would likely want to special case constant + // arguments to improve performance, but this example converts the + // arguments to arrays for simplicity. + let args = ColumnarValue::values_to_arrays(&args.args)?; + let [content_column, question_column] = take_function_args(self.name(), args)?; + + // In a real function, you would use a library such as `reqwest` here to + // make an async HTTP request. Credentials and other configurations can + // be supplied via the `ConfigOptions` parameter. + + // In this example, we will simulate the LLM response by comparing the two + // input arguments using some static strings + let content_column = as_string_view_array(&content_column)?; + let question_column = as_string_view_array(&question_column)?; + + let result_array: BooleanArray = content_column + .iter() + .zip(question_column.iter()) + .map(|(a, b)| { + // If either value is null, return None + let a = a?; + let b = b?; + // Simulate an LLM response by checking the arguments to some + // hardcoded conditions. + if a.contains("cat") && b.contains("furry") + || a.contains("dog") && b.contains("furry") + { + Some(true) + } else { + Some(false) + } + }) + .collect(); + + Ok(ColumnarValue::from(Arc::new(result_array) as ArrayRef)) + } +} diff --git a/datafusion-examples/examples/catalog.rs b/datafusion-examples/examples/catalog.rs index 655438b78b9fa..229867cdfc5bb 100644 --- a/datafusion-examples/examples/catalog.rs +++ b/datafusion-examples/examples/catalog.rs @@ -309,7 +309,7 @@ fn prepare_example_data() -> Result { 3,baz"#; for i in 0..5 { - let mut file = File::create(path.join(format!("{}.csv", i)))?; + let mut file = File::create(path.join(format!("{i}.csv")))?; file.write_all(content.as_bytes())?; } diff --git a/datafusion-examples/examples/composed_extension_codec.rs b/datafusion-examples/examples/composed_extension_codec.rs index 4baefcae507f6..57f2c370413aa 100644 --- a/datafusion-examples/examples/composed_extension_codec.rs +++ b/datafusion-examples/examples/composed_extension_codec.rs @@ -32,16 +32,16 @@ use std::any::Any; use std::fmt::Debug; -use std::ops::Deref; use std::sync::Arc; +use datafusion::common::internal_err; use datafusion::common::Result; -use datafusion::common::{internal_err, DataFusionError}; -use datafusion::logical_expr::registry::FunctionRegistry; -use datafusion::logical_expr::{AggregateUDF, ScalarUDF}; +use datafusion::execution::TaskContext; use datafusion::physical_plan::{DisplayAs, ExecutionPlan}; use datafusion::prelude::SessionContext; -use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec}; +use datafusion_proto::physical_plan::{ + AsExecutionPlan, ComposedPhysicalExtensionCodec, PhysicalExtensionCodec, +}; use datafusion_proto::protobuf; #[tokio::main] @@ -54,12 +54,12 @@ async fn main() { }); let ctx = SessionContext::new(); - let composed_codec = ComposedPhysicalExtensionCodec { - codecs: vec![ - Arc::new(ParentPhysicalExtensionCodec {}), - Arc::new(ChildPhysicalExtensionCodec {}), - ], - }; + // Position in this list is important as it will be used for decoding. + // If new codec is added it should go to last position. + let composed_codec = ComposedPhysicalExtensionCodec::new(vec![ + Arc::new(ParentPhysicalExtensionCodec {}), + Arc::new(ChildPhysicalExtensionCodec {}), + ]); // serialize execution plan to proto let proto: protobuf::PhysicalPlanNode = @@ -70,9 +70,8 @@ async fn main() { .expect("to proto"); // deserialize proto back to execution plan - let runtime = ctx.runtime_env(); let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx, runtime.deref(), &composed_codec) + .try_into_physical_plan(&ctx.task_ctx(), &composed_codec) .expect("from proto"); // assert that the original and deserialized execution plans are equal @@ -123,7 +122,7 @@ impl ExecutionPlan for ParentExec { fn execute( &self, _partition: usize, - _context: Arc, + _context: Arc, ) -> Result { unreachable!() } @@ -138,7 +137,7 @@ impl PhysicalExtensionCodec for ParentPhysicalExtensionCodec { &self, buf: &[u8], inputs: &[Arc], - _registry: &dyn FunctionRegistry, + _ctx: &TaskContext, ) -> Result> { if buf == "ParentExec".as_bytes() { Ok(Arc::new(ParentExec { @@ -199,7 +198,7 @@ impl ExecutionPlan for ChildExec { fn execute( &self, _partition: usize, - _context: Arc, + _context: Arc, ) -> Result { unreachable!() } @@ -214,7 +213,7 @@ impl PhysicalExtensionCodec for ChildPhysicalExtensionCodec { &self, buf: &[u8], _inputs: &[Arc], - _registry: &dyn FunctionRegistry, + _ctx: &TaskContext, ) -> Result> { if buf == "ChildExec".as_bytes() { Ok(Arc::new(ChildExec {})) @@ -232,60 +231,3 @@ impl PhysicalExtensionCodec for ChildPhysicalExtensionCodec { } } } - -/// A PhysicalExtensionCodec that tries one of multiple inner codecs -/// until one works -#[derive(Debug)] -struct ComposedPhysicalExtensionCodec { - codecs: Vec>, -} - -impl ComposedPhysicalExtensionCodec { - fn try_any( - &self, - mut f: impl FnMut(&dyn PhysicalExtensionCodec) -> Result, - ) -> Result { - let mut last_err = None; - for codec in &self.codecs { - match f(codec.as_ref()) { - Ok(node) => return Ok(node), - Err(err) => last_err = Some(err), - } - } - - Err(last_err.unwrap_or_else(|| { - DataFusionError::NotImplemented("Empty list of composed codecs".to_owned()) - })) - } -} - -impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { - fn try_decode( - &self, - buf: &[u8], - inputs: &[Arc], - registry: &dyn FunctionRegistry, - ) -> Result> { - self.try_any(|codec| codec.try_decode(buf, inputs, registry)) - } - - fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { - self.try_any(|codec| codec.try_encode(node.clone(), buf)) - } - - fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { - self.try_any(|codec| codec.try_decode_udf(name, buf)) - } - - fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { - self.try_any(|codec| codec.try_encode_udf(node, buf)) - } - - fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { - self.try_any(|codec| codec.try_decode_udaf(name, buf)) - } - - fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { - self.try_any(|codec| codec.try_encode_udaf(node, buf)) - } -} diff --git a/datafusion-examples/examples/custom_file_casts.rs b/datafusion-examples/examples/custom_file_casts.rs new file mode 100644 index 0000000000000..65ca096820640 --- /dev/null +++ b/datafusion-examples/examples/custom_file_casts.rs @@ -0,0 +1,209 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{record_batch, RecordBatch}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; + +use datafusion::assert_batches_eq; +use datafusion::common::not_impl_err; +use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion::common::{Result, ScalarValue}; +use datafusion::datasource::listing::{ + ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::physical_expr::expressions::CastExpr; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::prelude::SessionConfig; +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, +}; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, PutPayload}; + +// Example showing how to implement custom casting rules to adapt file schemas. +// This example enforces that casts must be strictly widening: if the file type is Int64 and the table type is Int32, it will error +// before even reading the data. +// Without this custom cast rule DataFusion would happily do the narrowing cast, potentially erroring only if it found a row with data it could not cast. + +#[tokio::main] +async fn main() -> Result<()> { + println!("=== Creating example data ==="); + + // Create a logical / table schema with an Int32 column + let logical_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + // Create some data that can be cast (Int16 -> Int32 is widening) and some that cannot (Int64 -> Int32 is narrowing) + let store = Arc::new(InMemory::new()) as Arc; + let path = Path::from("good.parquet"); + let batch = record_batch!(("id", Int16, [1, 2, 3]))?; + write_data(&store, &path, &batch).await?; + let path = Path::from("bad.parquet"); + let batch = record_batch!(("id", Int64, [1, 2, 3]))?; + write_data(&store, &path, &batch).await?; + + // Set up query execution + let mut cfg = SessionConfig::new(); + // Turn on filter pushdown so that the PhysicalExprAdapter is used + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.runtime_env() + .register_object_store(ObjectStoreUrl::parse("memory://")?.as_ref(), store); + + // Register our good and bad files via ListingTable + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///good.parquet")?) + .infer_options(&ctx.state()) + .await? + .with_schema(Arc::clone(&logical_schema)) + .with_expr_adapter_factory(Arc::new( + CustomCastPhysicalExprAdapterFactory::new(Arc::new( + DefaultPhysicalExprAdapterFactory, + )), + )); + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.register_table("good_table", Arc::new(table))?; + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///bad.parquet")?) + .infer_options(&ctx.state()) + .await? + .with_schema(Arc::clone(&logical_schema)) + .with_expr_adapter_factory(Arc::new( + CustomCastPhysicalExprAdapterFactory::new(Arc::new( + DefaultPhysicalExprAdapterFactory, + )), + )); + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.register_table("bad_table", Arc::new(table))?; + + println!("\n=== File with narrower schema is cast ==="); + let query = "SELECT id FROM good_table WHERE id > 1"; + println!("Query: {query}"); + let batches = ctx.sql(query).await?.collect().await?; + #[rustfmt::skip] + let expected = [ + "+----+", + "| id |", + "+----+", + "| 2 |", + "| 3 |", + "+----+", + ]; + arrow::util::pretty::print_batches(&batches)?; + assert_batches_eq!(expected, &batches); + + println!("\n=== File with wider schema errors ==="); + let query = "SELECT id FROM bad_table WHERE id > 1"; + println!("Query: {query}"); + match ctx.sql(query).await?.collect().await { + Ok(_) => panic!("Expected error for narrowing cast, but query succeeded"), + Err(e) => { + println!("Caught expected error: {e}"); + } + } + Ok(()) +} + +async fn write_data( + store: &dyn ObjectStore, + path: &Path, + batch: &RecordBatch, +) -> Result<()> { + let mut buf = vec![]; + let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), None)?; + writer.write(batch)?; + writer.close()?; + + let payload = PutPayload::from_bytes(buf.into()); + store.put(path, payload).await?; + Ok(()) +} + +/// Factory for creating DefaultValuePhysicalExprAdapter instances +#[derive(Debug)] +struct CustomCastPhysicalExprAdapterFactory { + inner: Arc, +} + +impl CustomCastPhysicalExprAdapterFactory { + fn new(inner: Arc) -> Self { + Self { inner } + } +} + +impl PhysicalExprAdapterFactory for CustomCastPhysicalExprAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Arc { + let inner = self + .inner + .create(logical_file_schema, Arc::clone(&physical_file_schema)); + Arc::new(CustomCastsPhysicalExprAdapter { + physical_file_schema, + inner, + }) + } +} + +/// Custom PhysicalExprAdapter that handles missing columns with default values from metadata +/// and wraps DefaultPhysicalExprAdapter for standard schema adaptation +#[derive(Debug, Clone)] +struct CustomCastsPhysicalExprAdapter { + physical_file_schema: SchemaRef, + inner: Arc, +} + +impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter { + fn rewrite(&self, mut expr: Arc) -> Result> { + // First delegate to the inner adapter to handle missing columns and discover any necessary casts + expr = self.inner.rewrite(expr)?; + // Now we can apply custom casting rules or even swap out all CastExprs for a custom cast kernel / expression + // For example, [DataFusion Comet](https://github.com/apache/datafusion-comet) has a [custom cast kernel](https://github.com/apache/datafusion-comet/blob/b4ac876ab420ed403ac7fc8e1b29f42f1f442566/native/spark-expr/src/conversion_funcs/cast.rs#L133-L138). + expr.transform(|expr| { + if let Some(cast) = expr.as_any().downcast_ref::() { + let input_data_type = + cast.expr().data_type(&self.physical_file_schema)?; + let output_data_type = cast.data_type(&self.physical_file_schema)?; + if !cast.is_bigger_cast(&input_data_type) { + return not_impl_err!( + "Unsupported CAST from {input_data_type} to {output_data_type}" + ); + } + } + Ok(Transformed::no(expr)) + }) + .data() + } + + fn with_partition_values( + &self, + partition_values: Vec<(FieldRef, ScalarValue)>, + ) -> Arc { + Arc::new(Self { + inner: self.inner.with_partition_values(partition_values), + ..self.clone() + }) + } +} diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs index 165d826270613..67fe642fd46ee 100644 --- a/datafusion-examples/examples/custom_file_format.rs +++ b/datafusion-examples/examples/custom_file_format.rs @@ -21,28 +21,24 @@ use arrow::{ array::{AsArray, RecordBatch, StringArray, UInt8Array}, datatypes::{DataType, Field, Schema, SchemaRef, UInt64Type}, }; -use datafusion::physical_expr::LexRequirement; -use datafusion::physical_expr::PhysicalExpr; use datafusion::{ catalog::Session, common::{GetExt, Statistics}, -}; -use datafusion::{ - datasource::physical_plan::FileSource, execution::session_state::SessionStateBuilder, -}; -use datafusion::{ datasource::{ file_format::{ csv::CsvFormatFactory, file_compression_type::FileCompressionType, FileFormat, FileFormatFactory, }, - physical_plan::{FileScanConfig, FileSinkConfig}, + physical_plan::{FileScanConfig, FileSinkConfig, FileSource}, MemTable, }, error::Result, + execution::session_state::SessionStateBuilder, + physical_expr_common::sort_expr::LexRequirement, physical_plan::ExecutionPlan, prelude::SessionContext, }; + use object_store::{ObjectMeta, ObjectStore}; use tempfile::tempdir; @@ -85,6 +81,10 @@ impl FileFormat for TSVFileFormat { } } + fn compression_type(&self) -> Option { + None + } + async fn infer_schema( &self, state: &dyn Session, @@ -112,11 +112,8 @@ impl FileFormat for TSVFileFormat { &self, state: &dyn Session, conf: FileScanConfig, - filters: Option<&Arc>, ) -> Result> { - self.csv_file_format - .create_physical_plan(state, conf, filters) - .await + self.csv_file_format.create_physical_plan(state, conf).await } async fn create_writer_physical_plan( diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs index 6f61c164f41df..a5ee571a14764 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray, StringViewArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::catalog::MemTable; use datafusion::common::config::CsvOptions; use datafusion::common::parsers::CompressionTypeVariant; use datafusion::common::DataFusionError; @@ -58,11 +59,13 @@ use tempfile::tempdir; /// * [query_to_date]: execute queries against parquet files #[tokio::main] async fn main() -> Result<()> { + env_logger::init(); // The SessionContext is the main high level API for interacting with DataFusion let ctx = SessionContext::new(); read_parquet(&ctx).await?; read_csv(&ctx).await?; read_memory(&ctx).await?; + read_memory_macro().await?; write_out(&ctx).await?; register_aggregate_test_data("t1", &ctx).await?; register_aggregate_test_data("t2", &ctx).await?; @@ -144,7 +147,7 @@ async fn read_csv(ctx: &SessionContext) -> Result<()> { // and using the `enable_url_table` refer to local files directly let dyn_ctx = ctx.clone().enable_url_table(); let csv_df = dyn_ctx - .sql(&format!("SELECT rating, unixtime FROM '{}'", file_path)) + .sql(&format!("SELECT rating, unixtime FROM '{file_path}'")) .await?; csv_df.show().await?; @@ -173,16 +176,40 @@ async fn read_memory(ctx: &SessionContext) -> Result<()> { Ok(()) } +/// Use the DataFrame API to: +/// 1. Read in-memory data. +async fn read_memory_macro() -> Result<()> { + // create a DataFrame using macro + let df = dataframe!( + "a" => ["a", "b", "c", "d"], + "b" => [1, 10, 10, 100] + )?; + // print the results + df.show().await?; + + // create empty DataFrame using macro + let df_empty = dataframe!()?; + df_empty.show().await?; + + Ok(()) +} + /// Use the DataFrame API to: /// 1. Write out a DataFrame to a table /// 2. Write out a DataFrame to a parquet file /// 3. Write out a DataFrame to a csv file /// 4. Write out a DataFrame to a json file async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionError> { - let mut df = ctx.sql("values ('a'), ('b'), ('c')").await.unwrap(); - - // Ensure the column names and types match the target table - df = df.with_column_renamed("column1", "tablecol1").unwrap(); + let array = StringViewArray::from(vec!["a", "b", "c"]); + let schema = Arc::new(Schema::new(vec![Field::new( + "tablecol1", + DataType::Utf8View, + false, + )])); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)])?; + let mem_table = MemTable::try_new(schema.clone(), vec![vec![batch]])?; + ctx.register_table("initial_data", Arc::new(mem_table))?; + let df = ctx.table("initial_data").await?; ctx.sql( "create external table diff --git a/datafusion-examples/examples/date_time_functions.rs b/datafusion-examples/examples/date_time_functions.rs index dbe9970439df7..2628319ae31f0 100644 --- a/datafusion-examples/examples/date_time_functions.rs +++ b/datafusion-examples/examples/date_time_functions.rs @@ -492,14 +492,14 @@ async fn query_to_char() -> Result<()> { assert_batches_eq!( &[ - "+------------------------------+", - "| to_char(t.values,t.patterns) |", - "+------------------------------+", - "| 2020-09-01 |", - "| 2020:09:02 |", - "| 20200903 |", - "| 04-09-2020 |", - "+------------------------------+", + "+----------------------------------+", + "| date_format(t.values,t.patterns) |", + "+----------------------------------+", + "| 2020-09-01 |", + "| 2020:09:02 |", + "| 20200903 |", + "| 04-09-2020 |", + "+----------------------------------+", ], &result ); diff --git a/datafusion-examples/examples/default_column_values.rs b/datafusion-examples/examples/default_column_values.rs new file mode 100644 index 0000000000000..43e2d4ca09884 --- /dev/null +++ b/datafusion-examples/examples/default_column_values.rs @@ -0,0 +1,398 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; +use async_trait::async_trait; + +use datafusion::assert_batches_eq; +use datafusion::catalog::memory::DataSourceExec; +use datafusion::catalog::{Session, TableProvider}; +use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion::common::DFSchema; +use datafusion::common::{Result, ScalarValue}; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::logical_expr::utils::conjunction; +use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType}; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::parquet::file::properties::WriterProperties; +use datafusion::physical_expr::expressions::{CastExpr, Column, Literal}; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::{lit, SessionConfig}; +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, +}; +use futures::StreamExt; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, PutPayload}; + +// Metadata key for storing default values in field metadata +const DEFAULT_VALUE_METADATA_KEY: &str = "example.default_value"; + +// Example showing how to implement custom default value handling for missing columns +// using field metadata and PhysicalExprAdapter. +// +// This example demonstrates how to: +// 1. Store default values in field metadata using a constant key +// 2. Create a custom PhysicalExprAdapter that reads these defaults +// 3. Inject default values for missing columns in filter predicates +// 4. Use the DefaultPhysicalExprAdapter as a fallback for standard schema adaptation +// 5. Wrap string default values in cast expressions for proper type conversion +// +// Important: PhysicalExprAdapter is specifically designed for rewriting filter predicates +// that get pushed down to file scans. For handling missing columns in projections, +// other mechanisms in DataFusion are used (like SchemaAdapter). +// +// The metadata-based approach provides a flexible way to store default values as strings +// and cast them to the appropriate types at query time. + +#[tokio::main] +async fn main() -> Result<()> { + println!("=== Creating example data with missing columns and default values ==="); + + // Create sample data where the logical schema has more columns than the physical schema + let (logical_schema, physical_schema, batch) = create_sample_data_with_defaults(); + + let store = InMemory::new(); + let buf = { + let mut buf = vec![]; + + let props = WriterProperties::builder() + .set_max_row_group_size(2) + .build(); + + let mut writer = + ArrowWriter::try_new(&mut buf, physical_schema.clone(), Some(props)) + .expect("creating writer"); + + writer.write(&batch).expect("Writing batch"); + writer.close().unwrap(); + buf + }; + let path = Path::from("example.parquet"); + let payload = PutPayload::from_bytes(buf.into()); + store.put(&path, payload).await?; + + // Create a custom table provider that handles missing columns with defaults + let table_provider = Arc::new(DefaultValueTableProvider::new(logical_schema)); + + // Set up query execution + let mut cfg = SessionConfig::new(); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + + // Register our table + ctx.register_table("example_table", table_provider)?; + + ctx.runtime_env().register_object_store( + ObjectStoreUrl::parse("memory://")?.as_ref(), + Arc::new(store), + ); + + println!("\n=== Demonstrating default value injection in filter predicates ==="); + let query = "SELECT id, name FROM example_table WHERE status = 'active' ORDER BY id"; + println!("Query: {query}"); + println!("Note: The 'status' column doesn't exist in the physical schema,"); + println!( + "but our adapter injects the default value 'active' for the filter predicate." + ); + + let batches = ctx.sql(query).await?.collect().await?; + + #[rustfmt::skip] + let expected = [ + "+----+-------+", + "| id | name |", + "+----+-------+", + "| 1 | Alice |", + "| 2 | Bob |", + "| 3 | Carol |", + "+----+-------+", + ]; + arrow::util::pretty::print_batches(&batches)?; + assert_batches_eq!(expected, &batches); + + println!("\n=== Key Insight ==="); + println!("This example demonstrates how PhysicalExprAdapter works:"); + println!("1. Physical schema only has 'id' and 'name' columns"); + println!("2. Logical schema has 'id', 'name', 'status', and 'priority' columns with defaults"); + println!("3. Our custom adapter intercepts filter expressions on missing columns"); + println!("4. Default values from metadata are injected as cast expressions"); + println!("5. The DefaultPhysicalExprAdapter handles other schema adaptations"); + println!("\nNote: PhysicalExprAdapter is specifically for filter predicates."); + println!("For projection columns, different mechanisms handle missing columns."); + + Ok(()) +} + +/// Create sample data with a logical schema that has default values in metadata +/// and a physical schema that's missing some columns +fn create_sample_data_with_defaults() -> (SchemaRef, SchemaRef, RecordBatch) { + // Create metadata for default values + let mut status_metadata = HashMap::new(); + status_metadata.insert(DEFAULT_VALUE_METADATA_KEY.to_string(), "active".to_string()); + + let mut priority_metadata = HashMap::new(); + priority_metadata.insert(DEFAULT_VALUE_METADATA_KEY.to_string(), "1".to_string()); + + // The logical schema includes all columns with their default values in metadata + // Note: We make the columns with defaults nullable to allow the default adapter to handle them + let logical_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new("status", DataType::Utf8, true).with_metadata(status_metadata), + Field::new("priority", DataType::Int32, true).with_metadata(priority_metadata), + ]); + + // The physical schema only has some columns (simulating missing columns in storage) + let physical_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ]); + + // Create sample data for the physical schema + let batch = RecordBatch::try_new( + Arc::new(physical_schema.clone()), + vec![ + Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])), + Arc::new(arrow::array::StringArray::from(vec![ + "Alice", "Bob", "Carol", + ])), + ], + ) + .unwrap(); + + (Arc::new(logical_schema), Arc::new(physical_schema), batch) +} + +/// Custom TableProvider that uses DefaultValuePhysicalExprAdapter +#[derive(Debug)] +struct DefaultValueTableProvider { + schema: SchemaRef, +} + +impl DefaultValueTableProvider { + fn new(schema: SchemaRef) -> Self { + Self { schema } + } +} + +#[async_trait] +impl TableProvider for DefaultValueTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()]) + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let schema = self.schema.clone(); + let df_schema = DFSchema::try_from(schema.clone())?; + let filter = state.create_physical_expr( + conjunction(filters.iter().cloned()).unwrap_or_else(|| lit(true)), + &df_schema, + )?; + + let parquet_source = ParquetSource::default() + .with_predicate(filter) + .with_pushdown_filters(true); + + let object_store_url = ObjectStoreUrl::parse("memory://")?; + let store = state.runtime_env().object_store(object_store_url)?; + + let mut files = vec![]; + let mut listing = store.list(None); + while let Some(file) = listing.next().await { + if let Ok(file) = file { + files.push(file); + } + } + + let file_group = files + .iter() + .map(|file| PartitionedFile::new(file.location.clone(), file.size)) + .collect(); + + let file_scan_config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("memory://")?, + self.schema.clone(), + Arc::new(parquet_source), + ) + .with_projection(projection.cloned()) + .with_limit(limit) + .with_file_group(file_group) + .with_expr_adapter(Some(Arc::new(DefaultValuePhysicalExprAdapterFactory) as _)); + + Ok(Arc::new(DataSourceExec::new(Arc::new( + file_scan_config.build(), + )))) + } +} + +/// Factory for creating DefaultValuePhysicalExprAdapter instances +#[derive(Debug)] +struct DefaultValuePhysicalExprAdapterFactory; + +impl PhysicalExprAdapterFactory for DefaultValuePhysicalExprAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Arc { + let default_factory = DefaultPhysicalExprAdapterFactory; + let default_adapter = default_factory + .create(logical_file_schema.clone(), physical_file_schema.clone()); + + Arc::new(DefaultValuePhysicalExprAdapter { + logical_file_schema, + physical_file_schema, + default_adapter, + partition_values: Vec::new(), + }) + } +} + +/// Custom PhysicalExprAdapter that handles missing columns with default values from metadata +/// and wraps DefaultPhysicalExprAdapter for standard schema adaptation +#[derive(Debug)] +struct DefaultValuePhysicalExprAdapter { + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + default_adapter: Arc, + partition_values: Vec<(FieldRef, ScalarValue)>, +} + +impl PhysicalExprAdapter for DefaultValuePhysicalExprAdapter { + fn rewrite(&self, expr: Arc) -> Result> { + // First try our custom default value injection for missing columns + let rewritten = expr + .transform(|expr| { + self.inject_default_values( + expr, + &self.logical_file_schema, + &self.physical_file_schema, + ) + }) + .data()?; + + // Then apply the default adapter as a fallback to handle standard schema differences + // like type casting, partition column handling, etc. + let default_adapter = if !self.partition_values.is_empty() { + self.default_adapter + .with_partition_values(self.partition_values.clone()) + } else { + self.default_adapter.clone() + }; + + default_adapter.rewrite(rewritten) + } + + fn with_partition_values( + &self, + partition_values: Vec<(FieldRef, ScalarValue)>, + ) -> Arc { + Arc::new(DefaultValuePhysicalExprAdapter { + logical_file_schema: self.logical_file_schema.clone(), + physical_file_schema: self.physical_file_schema.clone(), + default_adapter: self.default_adapter.clone(), + partition_values, + }) + } +} + +impl DefaultValuePhysicalExprAdapter { + fn inject_default_values( + &self, + expr: Arc, + logical_file_schema: &Schema, + physical_file_schema: &Schema, + ) -> Result>> { + if let Some(column) = expr.as_any().downcast_ref::() { + let column_name = column.name(); + + // Check if this column exists in the physical schema + if physical_file_schema.index_of(column_name).is_err() { + // Column is missing from physical schema, check if logical schema has a default + if let Ok(logical_field) = + logical_file_schema.field_with_name(column_name) + { + if let Some(default_value_str) = + logical_field.metadata().get(DEFAULT_VALUE_METADATA_KEY) + { + // Create a string literal and wrap it in a cast expression + let default_literal = self.create_default_value_expr( + default_value_str, + logical_field.data_type(), + )?; + return Ok(Transformed::yes(default_literal)); + } + } + } + } + + // No transformation needed + Ok(Transformed::no(expr)) + } + + fn create_default_value_expr( + &self, + value_str: &str, + data_type: &DataType, + ) -> Result> { + // Create a string literal with the default value + let string_literal = + Arc::new(Literal::new(ScalarValue::Utf8(Some(value_str.to_string())))); + + // If the target type is already Utf8, return the string literal directly + if matches!(data_type, DataType::Utf8) { + return Ok(string_literal); + } + + // Otherwise, wrap the string literal in a cast expression + let cast_expr = Arc::new(CastExpr::new(string_literal, data_type.clone(), None)); + + Ok(cast_expr) + } +} diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index b61a350a5a9c4..56f960870e58a 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -65,7 +65,7 @@ async fn main() -> Result<()> { let expr2 = Expr::BinaryExpr(BinaryExpr::new( Box::new(col("a")), Operator::Plus, - Box::new(Expr::Literal(ScalarValue::Int32(Some(5)))), + Box::new(Expr::Literal(ScalarValue::Int32(Some(5)), None)), )); assert_eq!(expr, expr2); @@ -85,7 +85,7 @@ async fn main() -> Result<()> { boundary_analysis_and_selectivity_demo()?; // See how boundary analysis works for `AND` & `OR` conjunctions. - boundary_analysis_in_conjuctions_demo()?; + boundary_analysis_in_conjunctions_demo()?; // See how to determine the data types of expressions expression_type_demo()?; @@ -147,8 +147,7 @@ fn evaluate_demo() -> Result<()> { ])) as _; assert!( matches!(&result, ColumnarValue::Array(r) if r == &expected_result), - "result: {:?}", - result + "result: {result:?}" ); Ok(()) @@ -352,7 +351,7 @@ fn boundary_analysis_and_selectivity_demo() -> Result<()> { /// This function shows how to think about and leverage the analysis API /// to infer boundaries in `AND` & `OR` conjunctions. -fn boundary_analysis_in_conjuctions_demo() -> Result<()> { +fn boundary_analysis_in_conjunctions_demo() -> Result<()> { // Let us consider the more common case of AND & OR conjunctions. // // age > 18 AND age <= 25 @@ -424,7 +423,7 @@ fn boundary_analysis_in_conjuctions_demo() -> Result<()> { // // But `AND` conjunctions are easier to reason with because their interval // arithmetic follows naturally from set intersection operations, let us - // now look at an example that is a tad more complicated `OR` conjunctions. + // now look at an example that is a tad more complicated `OR` disjunctions. // The expression we will look at is `age > 60 OR age <= 18`. let age_greater_than_60_less_than_18 = @@ -435,7 +434,7 @@ fn boundary_analysis_in_conjuctions_demo() -> Result<()> { // // Initial range: [14, 79] as described in our column statistics. // - // From the left-hand side and right-hand side of our `OR` conjunctions + // From the left-hand side and right-hand side of our `OR` disjunctions // we end up with two ranges, instead of just one. // // - age > 60: [61, 79] @@ -446,7 +445,8 @@ fn boundary_analysis_in_conjuctions_demo() -> Result<()> { let physical_expr = SessionContext::new() .create_physical_expr(age_greater_than_60_less_than_18, &df_schema)?; - // Since we don't handle interval arithmetic for `OR` operator this will error out. + // However, analysis only supports a single interval, so we don't yet deal + // with the multiple possibilities of the `OR` disjunctions. let analysis = analyze( &physical_expr, AnalysisContext::new(initial_boundaries), @@ -519,7 +519,7 @@ fn type_coercion_demo() -> Result<()> { )?; let i8_array = Int8Array::from_iter_values(vec![0, 1, 2]); let batch = RecordBatch::try_new( - Arc::new(df_schema.as_arrow().to_owned()), + Arc::clone(df_schema.inner()), vec![Arc::new(i8_array) as _], )?; diff --git a/datafusion-examples/examples/flight/flight_server.rs b/datafusion-examples/examples/flight/flight_server.rs index cc5f43746ddfb..58bfb7a341c19 100644 --- a/datafusion-examples/examples/flight/flight_server.rs +++ b/datafusion-examples/examples/flight/flight_server.rs @@ -98,7 +98,7 @@ impl FlightService for FlightServiceImpl { let df = ctx.sql(sql).await.map_err(to_tonic_err)?; // execute the query - let schema = df.schema().clone().into(); + let schema = Arc::clone(df.schema().inner()); let results = df.collect().await.map_err(to_tonic_err)?; if results.is_empty() { return Err(Status::internal("There were no results from ticket")); diff --git a/datafusion-examples/examples/flight/flight_sql_server.rs b/datafusion-examples/examples/flight/flight_sql_server.rs index 54e8de7177cbe..c35debec7d712 100644 --- a/datafusion-examples/examples/flight/flight_sql_server.rs +++ b/datafusion-examples/examples/flight/flight_sql_server.rs @@ -115,6 +115,7 @@ impl FlightSqlServiceImpl { Ok(uuid) } + #[allow(clippy::result_large_err)] fn get_ctx(&self, req: &Request) -> Result, Status> { // get the token from the authorization header on Request let auth = req @@ -140,6 +141,7 @@ impl FlightSqlServiceImpl { } } + #[allow(clippy::result_large_err)] fn get_plan(&self, handle: &str) -> Result { if let Some(plan) = self.statements.get(handle) { Ok(plan.clone()) @@ -148,6 +150,7 @@ impl FlightSqlServiceImpl { } } + #[allow(clippy::result_large_err)] fn get_result(&self, handle: &str) -> Result, Status> { if let Some(result) = self.results.get(handle) { Ok(result.clone()) @@ -195,11 +198,13 @@ impl FlightSqlServiceImpl { .unwrap() } + #[allow(clippy::result_large_err)] fn remove_plan(&self, handle: &str) -> Result<(), Status> { self.statements.remove(&handle.to_string()); Ok(()) } + #[allow(clippy::result_large_err)] fn remove_result(&self, handle: &str) -> Result<(), Status> { self.results.remove(&handle.to_string()); Ok(()) @@ -390,10 +395,8 @@ impl FlightSqlService for FlightSqlServiceImpl { let plan_uuid = Uuid::new_v4().hyphenated().to_string(); self.statements.insert(plan_uuid.clone(), plan.clone()); - let plan_schema = plan.schema(); - - let arrow_schema = (&**plan_schema).into(); - let message = SchemaAsIpc::new(&arrow_schema, &IpcWriteOptions::default()) + let arrow_schema = plan.schema().as_arrow(); + let message = SchemaAsIpc::new(arrow_schema, &IpcWriteOptions::default()) .try_into() .map_err(|e| status!("Unable to serialize schema", e))?; let IpcMessage(schema_bytes) = message; diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs index 06367f5c09e36..d4312ae594091 100644 --- a/datafusion-examples/examples/function_factory.rs +++ b/datafusion-examples/examples/function_factory.rs @@ -17,7 +17,7 @@ use arrow::datatypes::DataType; use datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::{exec_err, internal_err, DataFusionError}; +use datafusion::common::{exec_datafusion_err, exec_err, internal_err, DataFusionError}; use datafusion::error::Result; use datafusion::execution::context::{ FunctionFactory, RegisterFunction, SessionContext, SessionState, @@ -28,6 +28,7 @@ use datafusion::logical_expr::{ ColumnarValue, CreateFunction, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; +use std::hash::Hash; use std::result::Result as RResult; use std::sync::Arc; @@ -106,7 +107,7 @@ impl FunctionFactory for CustomFunctionFactory { } /// this function represents the newly created execution engine. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct ScalarFunctionWrapper { /// The text of the function body, `$1 + f1($2)` in our example name: String, @@ -150,10 +151,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { Ok(ExprSimplifyResult::Simplified(replacement)) } - fn aliases(&self) -> &[String] { - &[] - } - fn output_ordering(&self, _input: &[ExprProperties]) -> Result { Ok(SortProperties::Unordered) } @@ -188,10 +185,7 @@ impl ScalarFunctionWrapper { fn parse_placeholder_identifier(placeholder: &str) -> Result { if let Some(value) = placeholder.strip_prefix('$') { Ok(value.parse().map(|v: usize| v - 1).map_err(|e| { - DataFusionError::Execution(format!( - "Placeholder `{}` parsing error: {}!", - placeholder, e - )) + exec_datafusion_err!("Placeholder `{placeholder}` parsing error: {e}!") })?) } else { exec_err!("Placeholder should start with `$`!") diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs new file mode 100644 index 0000000000000..c7d0146a001f7 --- /dev/null +++ b/datafusion-examples/examples/json_shredding.rs @@ -0,0 +1,384 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{RecordBatch, StringArray}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; + +use datafusion::assert_batches_eq; +use datafusion::common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; +use datafusion::common::{assert_contains, exec_datafusion_err, Result}; +use datafusion::datasource::listing::{ + ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::parquet::file::properties::WriterProperties; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::{expressions, ScalarFunctionExpr}; +use datafusion::prelude::SessionConfig; +use datafusion::scalar::ScalarValue; +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, +}; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, PutPayload}; + +// Example showing how to implement custom filter rewriting for JSON shredding. +// +// JSON shredding is a technique for optimizing queries on semi-structured data +// by materializing commonly accessed fields into separate columns for better +// columnar storage performance. +// +// In this example, we have a table with both: +// - Original JSON data: data: '{"age": 30}' +// - Shredded flat columns: _data.name: "Alice" (extracted from JSON) +// +// Our custom TableProvider uses a PhysicalExprAdapter to rewrite +// expressions like `json_get_str('name', data)` to use the pre-computed +// flat column `_data.name` when available. This allows the query engine to: +// 1. Push down predicates for better filtering +// 2. Avoid expensive JSON parsing at query time +// 3. Leverage columnar storage benefits for the materialized fields +#[tokio::main] +async fn main() -> Result<()> { + println!("=== Creating example data with flat columns and underscore prefixes ==="); + + // Create sample data with flat columns using underscore prefixes + let (table_schema, batch) = create_sample_data(); + + let store = InMemory::new(); + let buf = { + let mut buf = vec![]; + + let props = WriterProperties::builder() + .set_max_row_group_size(2) + .build(); + + let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), Some(props)) + .expect("creating writer"); + + writer.write(&batch).expect("Writing batch"); + writer.close().unwrap(); + buf + }; + let path = Path::from("example.parquet"); + let payload = PutPayload::from_bytes(buf.into()); + store.put(&path, payload).await?; + + // Set up query execution + let mut cfg = SessionConfig::new(); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.runtime_env().register_object_store( + ObjectStoreUrl::parse("memory://")?.as_ref(), + Arc::new(store), + ); + + // Create a custom table provider that rewrites struct field access + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///example.parquet")?) + .infer_options(&ctx.state()) + .await? + .with_schema(table_schema) + .with_expr_adapter_factory(Arc::new(ShreddedJsonRewriterFactory)); + let table = ListingTable::try_new(listing_table_config).unwrap(); + let table_provider = Arc::new(table); + + // Register our table + ctx.register_table("structs", table_provider)?; + ctx.register_udf(ScalarUDF::new_from_impl(JsonGetStr::default())); + + println!("\n=== Showing all data ==="); + let batches = ctx.sql("SELECT * FROM structs").await?.collect().await?; + arrow::util::pretty::print_batches(&batches)?; + + println!("\n=== Running query with flat column access and filter ==="); + let query = "SELECT json_get_str('age', data) as age FROM structs WHERE json_get_str('name', data) = 'Bob'"; + println!("Query: {query}"); + + let batches = ctx.sql(query).await?.collect().await?; + + #[rustfmt::skip] + let expected = [ + "+-----+", + "| age |", + "+-----+", + "| 25 |", + "+-----+", + ]; + arrow::util::pretty::print_batches(&batches)?; + assert_batches_eq!(expected, &batches); + + println!("\n=== Running explain analyze to confirm row group pruning ==="); + + let batches = ctx + .sql(&format!("EXPLAIN ANALYZE {query}")) + .await? + .collect() + .await?; + let plan = format!("{}", arrow::util::pretty::pretty_format_batches(&batches)?); + println!("{plan}"); + assert_contains!(&plan, "row_groups_pruned_statistics=1"); + assert_contains!(&plan, "pushdown_rows_pruned=1"); + + Ok(()) +} + +/// Create the example data with flat columns using underscore prefixes. +/// +/// This demonstrates the logical data structure: +/// - Table schema: What users see (just the 'data' JSON column) +/// - File schema: What's physically stored (both 'data' and materialized '_data.name') +/// +/// The naming convention uses underscore prefixes to indicate shredded columns: +/// - `data` -> original JSON column +/// - `_data.name` -> materialized field from JSON data.name +fn create_sample_data() -> (SchemaRef, RecordBatch) { + // The table schema only has the main data column - this is what users query against + let table_schema = Schema::new(vec![Field::new("data", DataType::Utf8, false)]); + + // The file schema has both the main column and the shredded flat column with underscore prefix + // This represents the actual physical storage with pre-computed columns + let file_schema = Schema::new(vec![ + Field::new("data", DataType::Utf8, false), // Original JSON data + Field::new("_data.name", DataType::Utf8, false), // Materialized name field + ]); + + let batch = create_sample_record_batch(&file_schema); + + (Arc::new(table_schema), batch) +} + +/// Create the actual RecordBatch with sample data +fn create_sample_record_batch(file_schema: &Schema) -> RecordBatch { + // Build a RecordBatch with flat columns + let data_array = StringArray::from(vec![ + r#"{"age": 30}"#, + r#"{"age": 25}"#, + r#"{"age": 35}"#, + r#"{"age": 22}"#, + ]); + let names_array = StringArray::from(vec!["Alice", "Bob", "Charlie", "Dave"]); + + RecordBatch::try_new( + Arc::new(file_schema.clone()), + vec![Arc::new(data_array), Arc::new(names_array)], + ) + .unwrap() +} + +/// Scalar UDF that uses serde_json to access json fields +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct JsonGetStr { + signature: Signature, +} + +impl Default for JsonGetStr { + fn default() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for JsonGetStr { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "json_get_str" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert!( + args.args.len() == 2, + "json_get_str requires exactly 2 arguments" + ); + let key = match &args.args[0] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(key))) => key, + _ => { + return Err(exec_datafusion_err!( + "json_get_str first argument must be a string" + )) + } + }; + // We expect a string array that contains JSON strings + let json_array = match &args.args[1] { + ColumnarValue::Array(array) => array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + exec_datafusion_err!( + "json_get_str second argument must be a string array" + ) + })?, + _ => { + return Err(exec_datafusion_err!( + "json_get_str second argument must be a string array" + )) + } + }; + let values = json_array + .iter() + .map(|value| { + value.and_then(|v| { + let json_value: serde_json::Value = + serde_json::from_str(v).unwrap_or_default(); + json_value.get(key).map(|v| v.to_string()) + }) + }) + .collect::(); + Ok(ColumnarValue::Array(Arc::new(values))) + } +} + +/// Factory for creating ShreddedJsonRewriter instances +#[derive(Debug)] +struct ShreddedJsonRewriterFactory; + +impl PhysicalExprAdapterFactory for ShreddedJsonRewriterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Arc { + let default_factory = DefaultPhysicalExprAdapterFactory; + let default_adapter = default_factory + .create(logical_file_schema.clone(), physical_file_schema.clone()); + + Arc::new(ShreddedJsonRewriter { + logical_file_schema, + physical_file_schema, + default_adapter, + partition_values: Vec::new(), + }) + } +} + +/// Rewriter that converts json_get_str calls to direct flat column references +/// and wraps DefaultPhysicalExprAdapter for standard schema adaptation +#[derive(Debug)] +struct ShreddedJsonRewriter { + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + default_adapter: Arc, + partition_values: Vec<(FieldRef, ScalarValue)>, +} + +impl PhysicalExprAdapter for ShreddedJsonRewriter { + fn rewrite(&self, expr: Arc) -> Result> { + // First try our custom JSON shredding rewrite + let rewritten = expr + .transform(|expr| self.rewrite_impl(expr, &self.physical_file_schema)) + .data()?; + + // Then apply the default adapter as a fallback to handle standard schema differences + // like type casting, missing columns, and partition column handling + let default_adapter = if !self.partition_values.is_empty() { + self.default_adapter + .with_partition_values(self.partition_values.clone()) + } else { + self.default_adapter.clone() + }; + + default_adapter.rewrite(rewritten) + } + + fn with_partition_values( + &self, + partition_values: Vec<(FieldRef, ScalarValue)>, + ) -> Arc { + Arc::new(ShreddedJsonRewriter { + logical_file_schema: self.logical_file_schema.clone(), + physical_file_schema: self.physical_file_schema.clone(), + default_adapter: self.default_adapter.clone(), + partition_values, + }) + } +} + +impl ShreddedJsonRewriter { + fn rewrite_impl( + &self, + expr: Arc, + physical_file_schema: &Schema, + ) -> Result>> { + if let Some(func) = expr.as_any().downcast_ref::() { + if func.name() == "json_get_str" && func.args().len() == 2 { + // Get the key from the first argument + if let Some(literal) = func.args()[0] + .as_any() + .downcast_ref::() + { + if let ScalarValue::Utf8(Some(field_name)) = literal.value() { + // Get the column from the second argument + if let Some(column) = func.args()[1] + .as_any() + .downcast_ref::() + { + let column_name = column.name(); + // Check if there's a flat column with underscore prefix + let flat_column_name = format!("_{column_name}.{field_name}"); + + if let Ok(flat_field_index) = + physical_file_schema.index_of(&flat_column_name) + { + let flat_field = + physical_file_schema.field(flat_field_index); + + if flat_field.data_type() == &DataType::Utf8 { + // Replace the whole expression with a direct column reference + let new_expr = Arc::new(expressions::Column::new( + &flat_column_name, + flat_field_index, + )) + as Arc; + + return Ok(Transformed { + data: new_expr, + tnr: TreeNodeRecursion::Stop, + transformed: true, + }); + } + } + } + } + } + } + } + Ok(Transformed::no(expr)) + } +} diff --git a/datafusion-examples/examples/memory_pool_execution_plan.rs b/datafusion-examples/examples/memory_pool_execution_plan.rs new file mode 100644 index 0000000000000..3258cde17625f --- /dev/null +++ b/datafusion-examples/examples/memory_pool_execution_plan.rs @@ -0,0 +1,300 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This example demonstrates how to implement custom ExecutionPlans that properly +//! use memory tracking through TrackConsumersPool. +//! +//! This shows the pattern for implementing memory-aware operators that: +//! - Register memory consumers with the pool +//! - Reserve memory before allocating +//! - Handle memory pressure by spilling to disk +//! - Release memory when done + +use arrow::record_batch::RecordBatch; +use arrow_schema::SchemaRef; +use datafusion::common::record_batch; +use datafusion::common::{exec_datafusion_err, internal_err}; +use datafusion::datasource::{memory::MemTable, DefaultTableSource}; +use datafusion::error::Result; +use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::logical_expr::LogicalPlanBuilder; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, Statistics, +}; +use datafusion::prelude::*; +use futures::stream::{StreamExt, TryStreamExt}; +use std::any::Any; +use std::fmt; +use std::sync::Arc; + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("=== DataFusion ExecutionPlan Memory Tracking Example ===\n"); + + // Set up a runtime with memory tracking + // Set a low memory limit to trigger spilling on the second batch + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(15000, 1.0) // Allow only enough for 1 batch at once + .build_arc()?; + + let config = SessionConfig::new().with_coalesce_batches(false); + let ctx = SessionContext::new_with_config_rt(config, runtime.clone()); + + // Create smaller batches to ensure we get multiple RecordBatches from the scan + // Make each batch smaller than the memory limit to force multiple batches + let batch1 = record_batch!( + ("id", Int32, vec![1; 800]), + ("name", Utf8, vec!["Alice"; 800]) + )?; + + let batch2 = record_batch!( + ("id", Int32, vec![2; 800]), + ("name", Utf8, vec!["Bob"; 800]) + )?; + + let batch3 = record_batch!( + ("id", Int32, vec![3; 800]), + ("name", Utf8, vec!["Charlie"; 800]) + )?; + + let batch4 = record_batch!( + ("id", Int32, vec![4; 800]), + ("name", Utf8, vec!["David"; 800]) + )?; + + let schema = batch1.schema(); + + // Create a single MemTable with all batches in one partition to preserve order but ensure streaming + let mem_table = Arc::new(MemTable::try_new( + Arc::clone(&schema), + vec![vec![batch1, batch2, batch3, batch4]], // Single partition with multiple batches + )?); + + // Build logical plan with a single scan that will yield multiple batches + let table_source = Arc::new(DefaultTableSource::new(mem_table)); + let logical_plan = + LogicalPlanBuilder::scan("multi_batch_table", table_source, None)?.build()?; + + // Convert to physical plan + let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; + + println!("Example: Custom Memory-Aware BufferingExecutionPlan"); + println!("---------------------------------------------------"); + + // Wrap our input plan with our custom BufferingExecutionPlan + let buffering_plan = Arc::new(BufferingExecutionPlan::new(schema, physical_plan)); + + // Create a task context from our runtime + let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime)); + + // Execute the plan directly to demonstrate memory tracking + println!("Executing BufferingExecutionPlan with memory tracking..."); + println!("Memory limit: 15000 bytes - should trigger spill on later batches\n"); + + let stream = buffering_plan.execute(0, task_ctx.clone())?; + let _results: Vec = stream.try_collect().await?; + + println!("\nSuccessfully executed BufferingExecutionPlan!"); + + println!("\nThe BufferingExecutionPlan processed 4 input batches and"); + println!("demonstrated memory tracking with spilling behavior when the"); + println!("memory limit was exceeded by later batches."); + println!("Check the console output above to see the spill messages."); + + Ok(()) +} + +/// Example of an external batch bufferer that uses memory reservation. +/// +/// It's a simple example which spills all existing data to disk +/// whenever the memory limit is reached. +struct ExternalBatchBufferer { + buffer: Vec, + reservation: MemoryReservation, + spill_count: usize, +} + +impl ExternalBatchBufferer { + fn new(reservation: MemoryReservation) -> Self { + Self { + buffer: Vec::new(), + reservation, + spill_count: 0, + } + } + + fn add_batch(&mut self, batch_data: Vec) -> Result<()> { + let additional_memory = batch_data.len(); + + // Try to reserve memory before allocating + if self.reservation.try_grow(additional_memory).is_err() { + // Memory limit reached - handle by spilling + println!( + "Memory limit reached, spilling {} bytes to disk", + self.buffer.len() + ); + self.spill_to_disk()?; + + // Try again after spilling + self.reservation.try_grow(additional_memory)?; + } + + self.buffer.extend_from_slice(&batch_data); + println!( + "Added batch of {} bytes, total buffered: {} bytes", + additional_memory, + self.buffer.len() + ); + Ok(()) + } + + fn spill_to_disk(&mut self) -> Result<()> { + // Simulate writing buffer to disk + self.spill_count += 1; + println!( + "Spill #{}: Writing {} bytes to disk", + self.spill_count, + self.buffer.len() + ); + + // Free the memory after spilling + let freed_bytes = self.buffer.len(); + self.buffer.clear(); + self.reservation.shrink(freed_bytes); + + Ok(()) + } + + fn finish(&mut self) -> Vec { + let result = std::mem::take(&mut self.buffer); + // Free the memory when done + self.reservation.free(); + println!("Finished processing, released {} bytes", result.len()); + result + } +} + +/// Example of an ExecutionPlan that uses the ExternalBatchBufferer. +#[derive(Debug)] +struct BufferingExecutionPlan { + schema: SchemaRef, + input: Arc, + properties: PlanProperties, +} + +impl BufferingExecutionPlan { + fn new(schema: SchemaRef, input: Arc) -> Self { + let properties = input.properties().clone(); + + Self { + schema, + input, + properties, + } + } +} + +impl DisplayAs for BufferingExecutionPlan { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "BufferingExecutionPlan") + } +} + +impl ExecutionPlan for BufferingExecutionPlan { + fn name(&self) -> &'static str { + "BufferingExecutionPlan" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() == 1 { + Ok(Arc::new(BufferingExecutionPlan::new( + self.schema.clone(), + children[0].clone(), + ))) + } else { + internal_err!("BufferingExecutionPlan must have exactly one child") + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + // Register memory consumer with the context's memory pool + let reservation = MemoryConsumer::new("MyExternalBatchBufferer") + .with_can_spill(true) + .register(context.memory_pool()); + + // Incoming stream of batches + let mut input_stream = self.input.execute(partition, context)?; + + // Process the stream and collect all batches + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::once(async move { + let mut operator = ExternalBatchBufferer::new(reservation); + + while let Some(batch) = input_stream.next().await { + let batch = batch?; + + // Convert RecordBatch to bytes for this example + let batch_data = vec![1u8; batch.get_array_memory_size()]; + + operator.add_batch(batch_data)?; + } + + // Finish processing and get results + let _final_result = operator.finish(); + // In a real implementation, you would convert final_result back to RecordBatches + + // Since this is a simplified example, return an empty batch + // In a real implementation, you would create a batch stream from the processed results + record_batch!(("id", Int32, vec![5]), ("name", Utf8, vec!["Eve"])) + .map_err(|e| { + exec_datafusion_err!("Failed to create final RecordBatch: {e}") + }) + }), + ))) + } + + fn statistics(&self) -> Result { + Ok(Statistics::new_unknown(&self.schema)) + } +} diff --git a/datafusion-examples/examples/memory_pool_tracking.rs b/datafusion-examples/examples/memory_pool_tracking.rs new file mode 100644 index 0000000000000..d5823b1173ab3 --- /dev/null +++ b/datafusion-examples/examples/memory_pool_tracking.rs @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This example demonstrates how to use TrackConsumersPool for memory tracking and debugging. +//! +//! The TrackConsumersPool provides enhanced error messages that show the top memory consumers +//! when memory allocation fails, making it easier to debug memory issues in DataFusion queries. +//! +//! # Examples +//! +//! * [`automatic_usage_example`]: Shows how to use RuntimeEnvBuilder to automatically enable memory tracking + +use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::prelude::*; + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("=== DataFusion Memory Pool Tracking Example ===\n"); + + // Example 1: Automatic Usage with RuntimeEnvBuilder + automatic_usage_example().await?; + + Ok(()) +} + +/// Example 1: Automatic Usage with RuntimeEnvBuilder +/// +/// This shows the recommended way to use TrackConsumersPool through RuntimeEnvBuilder, +/// which automatically creates a TrackConsumersPool with sensible defaults. +async fn automatic_usage_example() -> datafusion::error::Result<()> { + println!("Example 1: Automatic Usage with RuntimeEnvBuilder"); + println!("------------------------------------------------"); + + // Success case: Create a runtime with reasonable memory limit + println!("Success case: Normal operation with sufficient memory"); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(5_000_000, 1.0) // 5MB, 100% utilization + .build_arc()?; + + let config = SessionConfig::new(); + let ctx = SessionContext::new_with_config_rt(config, runtime); + + // Create a simple table for demonstration + ctx.sql("CREATE TABLE test AS VALUES (1, 'a'), (2, 'b'), (3, 'c')") + .await? + .collect() + .await?; + + println!("✓ Created table with memory tracking enabled"); + + // Run a simple query to show it works + let results = ctx.sql("SELECT * FROM test").await?.collect().await?; + println!( + "✓ Query executed successfully. Found {} rows", + results.len() + ); + + println!("\n{}", "-".repeat(50)); + + // Error case: Create a runtime with low memory limit to trigger errors + println!("Error case: Triggering memory limit error with detailed error messages"); + + // Use a WITH query that generates data and then processes it to trigger memory usage + match ctx.sql(" + WITH large_dataset AS ( + SELECT + column1 as id, + column1 * 2 as doubled, + repeat('data_', 20) || column1 as text_field, + column1 * column1 as squared + FROM generate_series(1, 2000) as t(column1) + ), + aggregated AS ( + SELECT + id, + doubled, + text_field, + squared, + sum(doubled) OVER (ORDER BY id ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) as running_sum + FROM large_dataset + ) + SELECT + a1.id, + a1.text_field, + a2.text_field as text_field2, + a1.running_sum + a2.running_sum as combined_sum + FROM aggregated a1 + JOIN aggregated a2 ON a1.id = a2.id - 1 + ORDER BY a1.id + ").await?.collect().await { + Ok(results) => panic!("Should not succeed! Yet got {} batches", results.len()), + Err(e) => { + println!("✓ Expected memory limit error during data processing:"); + println!("Error: {e}"); + /* Example error message: + Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes + caused by + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + ExternalSorterMerge[3]#112(can spill: false) consumed 10.0 MB, peak 10.0 MB, + ExternalSorterMerge[10]#147(can spill: false) consumed 10.0 MB, peak 10.0 MB, + ExternalSorter[1]#93(can spill: true) consumed 69.0 KB, peak 69.0 KB, + ExternalSorter[13]#155(can spill: true) consumed 67.6 KB, peak 67.6 KB, + ExternalSorter[8]#140(can spill: true) consumed 67.2 KB, peak 67.2 KB. + Error: Failed to allocate additional 10.0 MB for ExternalSorterMerge[0] with 0.0 B already allocated for this reservation - 7.1 MB remain available for the total pool + */ + } + } + + println!("\nNote: The error message above shows which memory consumers"); + println!("were using the most memory when the limit was exceeded."); + + Ok(()) +} diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs index 63f17484809e2..9c137b67432c5 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -171,11 +171,11 @@ fn is_binary_eq(binary_expr: &BinaryExpr) -> bool { /// Return true if the expression is a literal or column reference fn is_lit_or_col(expr: &Expr) -> bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_)) + matches!(expr, Expr::Column(_) | Expr::Literal(_, _)) } /// A simple user defined filter function -#[derive(Debug, Clone)] +#[derive(Debug, PartialEq, Eq, Hash, Clone)] struct MyEq { signature: Signature, } diff --git a/datafusion-examples/examples/parquet_embedded_index.rs b/datafusion-examples/examples/parquet_embedded_index.rs new file mode 100644 index 0000000000000..3cbe189147752 --- /dev/null +++ b/datafusion-examples/examples/parquet_embedded_index.rs @@ -0,0 +1,477 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Embedding and using a custom index in Parquet files +//! +//! # Background +//! +//! This example shows how to add an application‑specific index to an Apache +//! Parquet file without modifying the Parquet format itself. The resulting +//! files can be read by any standard Parquet reader, which will simply +//! ignore the extra index data. +//! +//! A “distinct value” index, similar to a ["set" Skip Index in ClickHouse], +//! is stored in a custom binary format within the parquet file. Only the +//! location of index is stored in Parquet footer key/value metadata. +//! This approach is more efficient than storing the index itself in the footer +//! metadata because the footer must be read and parsed by all readers, +//! even those that do not use the index. +//! +//! This example uses a file level index for skipping entire files, but any +//! index can be stored using the same techniques and used skip row groups, +//! data pages, or rows using the APIs on [`TableProvider`] and [`ParquetSource`]. +//! +//! The resulting Parquet file layout is as follows: +//! +//! ```text +//! ┌──────────────────────┐ +//! │┌───────────────────┐ │ +//! ││ DataPage │ │ +//! │└───────────────────┘ │ +//! Standard Parquet │┌───────────────────┐ │ +//! Data Pages ││ DataPage │ │ +//! │└───────────────────┘ │ +//! │ ... │ +//! │┌───────────────────┐ │ +//! ││ DataPage │ │ +//! │└───────────────────┘ │ +//! │┏━━━━━━━━━━━━━━━━━━━┓ │ +//! Non standard │┃ ┃ │ +//! index (ignored by │┃Custom Binary Index┃ │ +//! other Parquet │┃ (Distinct Values) ┃◀│─ ─ ─ +//! readers) │┃ ┃ │ │ +//! │┗━━━━━━━━━━━━━━━━━━━┛ │ +//! Standard Parquet │┏━━━━━━━━━━━━━━━━━━━┓ │ │ key/value metadata +//! Page Index │┃ Page Index ┃ │ contains location +//! │┗━━━━━━━━━━━━━━━━━━━┛ │ │ of special index +//! │╔═══════════════════╗ │ +//! │║ Parquet Footer w/ ║ │ │ +//! │║ Metadata ║ ┼ ─ ─ +//! │║ (Thrift Encoded) ║ │ +//! │╚═══════════════════╝ │ +//! └──────────────────────┘ +//! +//! Parquet File +//! +//! # High Level Flow +//! +//! To create a custom Parquet index: +//! +//! 1. Compute the index and serialize it to a binary format. +//! +//! 2. Write the Parquet file with: +//! - regular data pages +//! - the serialized index inline +//! - footer key/value metadata entry to locate the index +//! +//! To read and use the index are: +//! +//! 1. Read and deserialize the file’s footer to locate the index. +//! +//! 2. Read and deserialize the index. +//! +//! 3. Create a `TableProvider` that knows how to use the index to quickly find +//! the relevant files, row groups, data pages or rows based on on pushed down +//! filters. +//! +//! # FAQ: Why do other Parquet readers skip over the custom index? +//! +//! The flow for reading a parquet file is: +//! +//! 1. Seek to the end of the file and read the last 8 bytes (a 4‑byte +//! little‑endian footer length followed by the `PAR1` magic bytes). +//! +//! 2. Seek backwards by that length to parse the Thrift‑encoded footer +//! metadata (including key/value pairs). +//! +//! 3. Read data required for decoding such as data pages based on the offsets +//! encoded in the metadata. +//! +//! Since parquet readers do not scan from the start of the file they will read +//! data in the file unless it is explicitly referenced in the footer metadata. +//! +//! Thus other readers will encounter and ignore an unknown key +//! (`distinct_index_offset`) in the footer key/value metadata. Unless they +//! know how to use that information, they will not attempt to read or +//! the bytes that make up the index. +//! +//! ["set" Skip Index in ClickHouse]: https://clickhouse.com/docs/optimize/skipping-indexes#set + +use arrow::array::{ArrayRef, StringArray}; +use arrow::record_batch::RecordBatch; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use async_trait::async_trait; +use datafusion::catalog::{Session, TableProvider}; +use datafusion::common::{exec_err, HashMap, HashSet, Result}; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::memory::DataSourceExec; +use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; +use datafusion::datasource::TableType; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::logical_expr::{Operator, TableProviderFilterPushDown}; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::parquet::errors::ParquetError; +use datafusion::parquet::file::metadata::{FileMetaData, KeyValue}; +use datafusion::parquet::file::reader::{FileReader, SerializedFileReader}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::*; +use datafusion::scalar::ScalarValue; +use std::fs::{read_dir, File}; +use std::io::{Read, Seek, SeekFrom, Write}; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use tempfile::TempDir; + +/// An index of distinct values for a single column +/// +/// In this example the index is a simple set of strings, but in a real +/// application it could be any arbitrary data structure. +/// +/// Also, this example indexes the distinct values for an entire file +/// but a real application could create multiple indexes for multiple +/// row groups and/or columns, depending on the use case. +#[derive(Debug, Clone)] +struct DistinctIndex { + inner: HashSet, +} + +impl DistinctIndex { + /// Create a DistinctIndex from an iterator of strings + pub fn new>(iter: I) -> Self { + Self { + inner: iter.into_iter().collect(), + } + } + + /// Returns true if the index contains the given value + pub fn contains(&self, value: &str) -> bool { + self.inner.contains(value) + } + + /// Serialize the distinct index to a writer as bytes + /// + /// In this example, we use a simple newline-separated format, + /// but a real application can use any arbitrary binary format. + /// + /// Note that we must use the ArrowWriter to write the index so that its + /// internal accounting of offsets can correctly track the actual size of + /// the file. If we wrote directly to the underlying writer, the PageIndex + /// written right before the would be incorrect as they would not account + /// for the extra bytes written. + fn serialize( + &self, + arrow_writer: &mut ArrowWriter, + ) -> Result<()> { + let serialized = self + .inner + .iter() + .map(|s| s.as_str()) + .collect::>() + .join("\n"); + let index_bytes = serialized.into_bytes(); + + // Set the offset for the index + let offset = arrow_writer.bytes_written(); + let index_len = index_bytes.len() as u64; + + println!("Writing custom index at offset: {offset}, length: {index_len}"); + // Write the index magic and length to the file + arrow_writer.write_all(INDEX_MAGIC)?; + arrow_writer.write_all(&index_len.to_le_bytes())?; + + // Write the index bytes + arrow_writer.write_all(&index_bytes)?; + + // Append metadata about the index to the Parquet file footer + arrow_writer.append_key_value_metadata(KeyValue::new( + "distinct_index_offset".to_string(), + offset.to_string(), + )); + Ok(()) + } + + /// Read the distinct values index from a reader at the given offset and length + pub fn new_from_reader(mut reader: R, offset: u64) -> Result { + reader.seek(SeekFrom::Start(offset))?; + + let mut magic_buf = [0u8; 4]; + reader.read_exact(&mut magic_buf)?; + if magic_buf != INDEX_MAGIC { + return exec_err!("Invalid index magic number at offset {offset}"); + } + + let mut len_buf = [0u8; 8]; + reader.read_exact(&mut len_buf)?; + let stored_len = u64::from_le_bytes(len_buf) as usize; + + let mut index_buf = vec![0u8; stored_len]; + reader.read_exact(&mut index_buf)?; + + let Ok(s) = String::from_utf8(index_buf) else { + return exec_err!("Invalid UTF-8 in index data"); + }; + + Ok(Self { + inner: s.lines().map(|s| s.to_string()).collect(), + }) + } +} + +/// DataFusion [`TableProvider]` that reads Parquet files and uses a +/// `DistinctIndex` to prune files based on pushed down filters. +#[derive(Debug)] +struct DistinctIndexTable { + /// The schema of the table + schema: SchemaRef, + /// Key is file name, value is DistinctIndex for that file + files_and_index: HashMap, + /// Directory containing the Parquet files + dir: PathBuf, +} + +impl DistinctIndexTable { + /// Create a new DistinctIndexTable for files in the given directory + /// + /// Scans the directory, reading the `DistinctIndex` from each file + fn try_new(dir: impl Into, schema: SchemaRef) -> Result { + let dir = dir.into(); + let mut index = HashMap::new(); + + for entry in read_dir(&dir)? { + let path = entry?.path(); + if path.extension().and_then(|s| s.to_str()) != Some("parquet") { + continue; + } + let file_name = path.file_name().unwrap().to_string_lossy().to_string(); + + let distinct_set = read_distinct_index(&path)?; + + println!("Read distinct index for {file_name}: {file_name:?}"); + index.insert(file_name, distinct_set); + } + + Ok(Self { + schema, + files_and_index: index, + dir, + }) + } +} + +/// Wrapper around ArrowWriter to write Parquet files with an embedded index +struct IndexedParquetWriter { + writer: ArrowWriter, +} + +/// Magic bytes to identify our custom index format +const INDEX_MAGIC: &[u8] = b"IDX1"; + +impl IndexedParquetWriter { + pub fn try_new(sink: W, schema: Arc) -> Result { + let writer = ArrowWriter::try_new(sink, schema, None)?; + Ok(Self { writer }) + } + + /// Write a RecordBatch to the Parquet file + pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + self.writer.write(batch)?; + Ok(()) + } + + /// Flush the current row group + pub fn flush(&mut self) -> Result<()> { + self.writer.flush()?; + Ok(()) + } + + /// Close the Parquet file, flushing any remaining data + pub fn close(self) -> Result<()> { + self.writer.close()?; + Ok(()) + } + + /// write the DistinctIndex to the Parquet file + pub fn write_index(&mut self, index: &DistinctIndex) -> Result<()> { + index.serialize(&mut self.writer) + } +} + +/// Write a Parquet file with a single column "category" containing the +/// strings in `values` and a DistinctIndex for that column. +fn write_file_with_index(path: &Path, values: &[&str]) -> Result<()> { + // form an input RecordBatch with the string values + let field = Field::new("category", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field.clone()])); + let arr: ArrayRef = Arc::new(StringArray::from(values.to_vec())); + let batch = RecordBatch::try_new(schema.clone(), vec![arr])?; + + // compute the distinct index + let distinct_index: DistinctIndex = + DistinctIndex::new(values.iter().map(|s| (*s).to_string())); + + let file = File::create(path)?; + + let mut writer = IndexedParquetWriter::try_new(file, schema.clone())?; + writer.write(&batch)?; + writer.flush()?; + writer.write_index(&distinct_index)?; + writer.close()?; + + println!("Finished writing file to {}", path.display()); + Ok(()) +} + +/// Read a `DistinctIndex` from a Parquet file +fn read_distinct_index(path: &Path) -> Result { + let file = File::open(path)?; + + let file_size = file.metadata()?.len(); + println!("Reading index from {} (size: {file_size})", path.display(),); + + let reader = SerializedFileReader::new(file.try_clone()?)?; + let meta = reader.metadata().file_metadata(); + + let offset = get_key_value(meta, "distinct_index_offset") + .ok_or_else(|| ParquetError::General("Missing index offset".into()))? + .parse::() + .map_err(|e| ParquetError::General(e.to_string()))?; + + println!("Reading index at offset: {offset}, length"); + DistinctIndex::new_from_reader(file, offset) +} + +/// Returns the value of a named key from the Parquet file metadata +/// +/// Returns None if the key is not found +fn get_key_value<'a>(file_meta_data: &'a FileMetaData, key: &'_ str) -> Option<&'a str> { + let kvs = file_meta_data.key_value_metadata()?; + let kv = kvs.iter().find(|kv| kv.key == key)?; + kv.value.as_deref() +} + +/// Implement TableProvider for DistinctIndexTable, using the distinct index to prune files +#[async_trait] +impl TableProvider for DistinctIndexTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + fn table_type(&self) -> TableType { + TableType::Base + } + + /// Prune files before reading: only keep files whose distinct set + /// contains the filter value + async fn scan( + &self, + _ctx: &dyn Session, + _proj: Option<&Vec>, + filters: &[Expr], + _limit: Option, + ) -> Result> { + // This example only handles filters of the form + // `category = 'X'` where X is a string literal + // + // You can use `PruningPredicate` for much more general range and + // equality analysis or write your own custom logic. + let mut target: Option<&str> = None; + + if filters.len() == 1 { + if let Expr::BinaryExpr(expr) = &filters[0] { + if expr.op == Operator::Eq { + if let ( + Expr::Column(c), + Expr::Literal(ScalarValue::Utf8(Some(v)), _), + ) = (&*expr.left, &*expr.right) + { + if c.name == "category" { + println!("Filtering for category: {v}"); + target = Some(v); + } + } + } + } + } + // Determine which files to scan + let files_to_scan: Vec<_> = self + .files_and_index + .iter() + .filter_map(|(f, distinct_index)| { + // keep file if no target or target is in the distinct set + if target.is_none() || distinct_index.contains(target?) { + Some(f) + } else { + None + } + }) + .collect(); + + println!("Scanning only files: {files_to_scan:?}"); + + // Build ParquetSource to actually read the files + let url = ObjectStoreUrl::parse("file://")?; + let source = Arc::new(ParquetSource::default().with_enable_page_index(true)); + let mut builder = FileScanConfigBuilder::new(url, self.schema.clone(), source); + for file in files_to_scan { + let path = self.dir.join(file); + let len = std::fs::metadata(&path)?.len(); + // If the index contained information about row groups or pages, + // you could also pass that information here to further prune + // the data read from the file. + let partitioned_file = + PartitionedFile::new(path.to_str().unwrap().to_string(), len); + builder = builder.with_file(partitioned_file); + } + Ok(DataSourceExec::from_data_source(builder.build())) + } + + /// Tell DataFusion that we can handle filters on the "category" column + fn supports_filters_pushdown( + &self, + fs: &[&Expr], + ) -> Result> { + // Mark as inexact since pruning is file‑granular + Ok(vec![TableProviderFilterPushDown::Inexact; fs.len()]) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // 1. Create temp dir and write 3 Parquet files with different category sets + let tmp = TempDir::new()?; + let dir = tmp.path(); + write_file_with_index(&dir.join("a.parquet"), &["foo", "bar", "foo"])?; + write_file_with_index(&dir.join("b.parquet"), &["baz", "qux"])?; + write_file_with_index(&dir.join("c.parquet"), &["foo", "quux", "quux"])?; + + // 2. Register our custom TableProvider + let field = Field::new("category", DataType::Utf8, false); + let schema_ref = Arc::new(Schema::new(vec![field])); + let provider = Arc::new(DistinctIndexTable::try_new(dir, schema_ref.clone())?); + + let ctx = SessionContext::new(); + ctx.register_table("t", provider)?; + + // 3. Run a query: only files containing 'foo' get scanned. The rest are pruned. + // based on the distinct index. + let df = ctx.sql("SELECT * FROM t WHERE category = 'foo'").await?; + df.show().await?; + + Ok(()) +} diff --git a/datafusion-examples/examples/parquet_encrypted.rs b/datafusion-examples/examples/parquet_encrypted.rs new file mode 100644 index 0000000000000..e9e239b7a1c32 --- /dev/null +++ b/datafusion-examples/examples/parquet_encrypted.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::DataFusionError; +use datafusion::config::TableParquetOptions; +use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; +use datafusion::logical_expr::{col, lit}; +use datafusion::parquet::encryption::decrypt::FileDecryptionProperties; +use datafusion::parquet::encryption::encrypt::FileEncryptionProperties; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use tempfile::TempDir; + +#[tokio::main] +async fn main() -> datafusion::common::Result<()> { + // The SessionContext is the main high level API for interacting with DataFusion + let ctx = SessionContext::new(); + + // Find the local path of "alltypes_plain.parquet" + let testdata = datafusion::test_util::parquet_test_data(); + let filename = &format!("{testdata}/alltypes_plain.parquet"); + + // Read the sample parquet file + let parquet_df = ctx + .read_parquet(filename, ParquetReadOptions::default()) + .await?; + + // Show information from the dataframe + println!( + "===============================================================================" + ); + println!("Original Parquet DataFrame:"); + query_dataframe(&parquet_df).await?; + + // Setup encryption and decryption properties + let (encrypt, decrypt) = setup_encryption(&parquet_df)?; + + // Create a temporary file location for the encrypted parquet file + let tmp_dir = TempDir::new()?; + let tempfile = tmp_dir.path().join("alltypes_plain-encrypted.parquet"); + let tempfile_str = tempfile.into_os_string().into_string().unwrap(); + + // Write encrypted parquet + let mut options = TableParquetOptions::default(); + options.crypto.file_encryption = Some((&encrypt).into()); + parquet_df + .write_parquet( + tempfile_str.as_str(), + DataFrameWriteOptions::new().with_single_file_output(true), + Some(options), + ) + .await?; + + // Read encrypted parquet + let ctx: SessionContext = SessionContext::new(); + let read_options = + ParquetReadOptions::default().file_decryption_properties((&decrypt).into()); + + let encrypted_parquet_df = ctx.read_parquet(tempfile_str, read_options).await?; + + // Show information from the dataframe + println!("\n\n==============================================================================="); + println!("Encrypted Parquet DataFrame:"); + query_dataframe(&encrypted_parquet_df).await?; + + Ok(()) +} + +// Show information from the dataframe +async fn query_dataframe(df: &DataFrame) -> Result<(), DataFusionError> { + // show its schema using 'describe' + println!("Schema:"); + df.clone().describe().await?.show().await?; + + // Select three columns and filter the results + // so that only rows where id > 1 are returned + println!("\nSelected rows and columns:"); + df.clone() + .select_columns(&["id", "bool_col", "timestamp_col"])? + .filter(col("id").gt(lit(5)))? + .show() + .await?; + + Ok(()) +} + +// Setup encryption and decryption properties +fn setup_encryption( + parquet_df: &DataFrame, +) -> Result<(FileEncryptionProperties, FileDecryptionProperties), DataFusionError> { + let schema = parquet_df.schema(); + let footer_key = b"0123456789012345".to_vec(); // 128bit/16 + let column_key = b"1234567890123450".to_vec(); // 128bit/16 + + let mut encrypt = FileEncryptionProperties::builder(footer_key.clone()); + let mut decrypt = FileDecryptionProperties::builder(footer_key.clone()); + + for field in schema.fields().iter() { + encrypt = encrypt.with_column_key(field.name().as_str(), column_key.clone()); + decrypt = decrypt.with_column_key(field.name().as_str(), column_key.clone()); + } + + let encrypt = encrypt.build()?; + let decrypt = decrypt.build()?; + Ok((encrypt, decrypt)) +} diff --git a/datafusion-examples/examples/parquet_encrypted_with_kms.rs b/datafusion-examples/examples/parquet_encrypted_with_kms.rs new file mode 100644 index 0000000000000..19b0e8d0b1994 --- /dev/null +++ b/datafusion-examples/examples/parquet_encrypted_with_kms.rs @@ -0,0 +1,303 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use arrow_schema::SchemaRef; +use async_trait::async_trait; +use base64::Engine; +use datafusion::common::extensions_options; +use datafusion::config::{EncryptionFactoryOptions, TableParquetOptions}; +use datafusion::dataframe::DataFrameWriteOptions; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::ListingOptions; +use datafusion::error::Result; +use datafusion::execution::parquet_encryption::EncryptionFactory; +use datafusion::parquet::encryption::decrypt::KeyRetriever; +use datafusion::parquet::encryption::{ + decrypt::FileDecryptionProperties, encrypt::FileEncryptionProperties, +}; +use datafusion::prelude::SessionContext; +use futures::StreamExt; +use object_store::path::Path; +use rand::rand_core::{OsRng, TryRngCore}; +use std::collections::HashSet; +use std::sync::Arc; +use tempfile::TempDir; + +const ENCRYPTION_FACTORY_ID: &str = "example.mock_kms_encryption"; + +/// This example demonstrates reading and writing Parquet files that +/// are encrypted using Parquet Modular Encryption. +/// +/// Compared to the `parquet_encrypted` example, where AES keys +/// are specified directly, this example implements an `EncryptionFactory` that +/// generates encryption keys dynamically per file. +/// Encryption key metadata is stored inline in the Parquet files and is used to determine +/// the decryption keys when reading the files. +/// +/// In this example, encryption keys are simply stored base64 encoded in the Parquet metadata, +/// which is not a secure way to store encryption keys. +/// For production use, it is recommended to use a key-management service (KMS) to encrypt +/// data encryption keys. +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + + // Register an `EncryptionFactory` implementation to be used for Parquet encryption + // in the runtime environment. + // `EncryptionFactory` instances are registered with a name to identify them so + // they can be later referenced in configuration options, and it's possible to register + // multiple different factories to handle different ways of encrypting Parquet. + let encryption_factory = TestEncryptionFactory::default(); + ctx.runtime_env().register_parquet_encryption_factory( + ENCRYPTION_FACTORY_ID, + Arc::new(encryption_factory), + ); + + // Register some simple test data + let a: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d"])); + let b: ArrayRef = Arc::new(Int32Array::from(vec![1, 10, 10, 100])); + let c: ArrayRef = Arc::new(Int32Array::from(vec![2, 20, 20, 200])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?; + ctx.register_batch("test_data", batch)?; + + { + // Write and read encrypted Parquet with the programmatic API + let tmpdir = TempDir::new()?; + let table_path = format!("{}/", tmpdir.path().to_str().unwrap()); + write_encrypted(&ctx, &table_path).await?; + read_encrypted(&ctx, &table_path).await?; + } + + { + // Write and read encrypted Parquet with the SQL API + let tmpdir = TempDir::new()?; + let table_path = format!("{}/", tmpdir.path().to_str().unwrap()); + write_encrypted_with_sql(&ctx, &table_path).await?; + read_encrypted_with_sql(&ctx, &table_path).await?; + } + + Ok(()) +} + +/// Write an encrypted Parquet file +async fn write_encrypted(ctx: &SessionContext, table_path: &str) -> Result<()> { + let df = ctx.table("test_data").await?; + + let mut parquet_options = TableParquetOptions::new(); + // We specify that we want to use Parquet encryption by setting the identifier of the + // encryption factory to use and providing the factory-specific configuration. + // Our encryption factory only requires specifying the columns to encrypt. + let encryption_config = EncryptionConfig { + encrypted_columns: "b,c".to_owned(), + }; + parquet_options + .crypto + .configure_factory(ENCRYPTION_FACTORY_ID, &encryption_config); + + df.write_parquet( + table_path, + DataFrameWriteOptions::new(), + Some(parquet_options), + ) + .await?; + + println!("Encrypted Parquet written to {table_path}"); + Ok(()) +} + +/// Read from an encrypted Parquet file +async fn read_encrypted(ctx: &SessionContext, table_path: &str) -> Result<()> { + let mut parquet_options = TableParquetOptions::new(); + // Specify the encryption factory to use for decrypting Parquet. + // In this example, we don't require any additional configuration options when reading + // as we only need the key metadata from the Parquet files to determine the decryption keys. + parquet_options + .crypto + .configure_factory(ENCRYPTION_FACTORY_ID, &EncryptionConfig::default()); + + let file_format = ParquetFormat::default().with_options(parquet_options); + let listing_options = ListingOptions::new(Arc::new(file_format)); + + ctx.register_listing_table( + "encrypted_parquet_table", + &table_path, + listing_options.clone(), + None, + None, + ) + .await?; + + let mut batch_stream = ctx + .table("encrypted_parquet_table") + .await? + .execute_stream() + .await?; + println!("Reading encrypted Parquet as a RecordBatch stream"); + while let Some(batch) = batch_stream.next().await { + let batch = batch?; + println!("Read batch with {} rows", batch.num_rows()); + } + + println!("Finished reading"); + Ok(()) +} + +/// Write an encrypted Parquet file using only SQL syntax with string configuration +async fn write_encrypted_with_sql(ctx: &SessionContext, table_path: &str) -> Result<()> { + let query = format!( + "COPY test_data \ + TO '{table_path}' \ + STORED AS parquet + OPTIONS (\ + 'format.crypto.factory_id' '{ENCRYPTION_FACTORY_ID}', \ + 'format.crypto.factory_options.encrypted_columns' 'b,c' \ + )" + ); + let _ = ctx.sql(&query).await?.collect().await?; + + println!("Encrypted Parquet written to {table_path}"); + Ok(()) +} + +/// Read from an encrypted Parquet file using only the SQL API and string-based configuration +async fn read_encrypted_with_sql(ctx: &SessionContext, table_path: &str) -> Result<()> { + let ddl = format!( + "CREATE EXTERNAL TABLE encrypted_parquet_table_2 \ + STORED AS PARQUET LOCATION '{table_path}' OPTIONS (\ + 'format.crypto.factory_id' '{ENCRYPTION_FACTORY_ID}' \ + )" + ); + ctx.sql(&ddl).await?; + let df = ctx.sql("SELECT * FROM encrypted_parquet_table_2").await?; + let mut batch_stream = df.execute_stream().await?; + + println!("Reading encrypted Parquet as a RecordBatch stream"); + while let Some(batch) = batch_stream.next().await { + let batch = batch?; + println!("Read batch with {} rows", batch.num_rows()); + } + println!("Finished reading"); + Ok(()) +} + +// Options used to configure our example encryption factory +extensions_options! { + struct EncryptionConfig { + /// Comma-separated list of columns to encrypt + pub encrypted_columns: String, default = "".to_owned() + } +} + +/// Mock implementation of an `EncryptionFactory` that stores encryption keys +/// base64 encoded in the Parquet encryption metadata. +/// For production use, integrating with a key-management service to encrypt +/// data encryption keys is recommended. +#[derive(Default, Debug)] +struct TestEncryptionFactory {} + +/// `EncryptionFactory` is a DataFusion trait for types that generate +/// file encryption and decryption properties. +#[async_trait] +impl EncryptionFactory for TestEncryptionFactory { + /// Generate file encryption properties to use when writing a Parquet file. + /// The `schema` is provided so that it may be used to dynamically configure + /// per-column encryption keys. + /// The file path is also available. We don't use the path in this example, + /// but other implementations may want to use this to compute an + /// AAD prefix for the file, or to allow use of external key material + /// (where key metadata is stored in a JSON file alongside Parquet files). + async fn get_file_encryption_properties( + &self, + options: &EncryptionFactoryOptions, + schema: &SchemaRef, + _file_path: &Path, + ) -> Result> { + let config: EncryptionConfig = options.to_extension_options()?; + + // Generate a random encryption key for this file. + let mut key = vec![0u8; 16]; + OsRng.try_fill_bytes(&mut key).unwrap(); + + // Generate the key metadata that allows retrieving the key when reading the file. + let key_metadata = wrap_key(&key); + + let mut builder = FileEncryptionProperties::builder(key.to_vec()) + .with_footer_key_metadata(key_metadata.clone()); + + let encrypted_columns: HashSet<&str> = + config.encrypted_columns.split(",").collect(); + if !encrypted_columns.is_empty() { + // Set up per-column encryption. + for field in schema.fields().iter() { + if encrypted_columns.contains(field.name().as_str()) { + // Here we re-use the same key for all encrypted columns, + // but new keys could also be generated per column. + builder = builder.with_column_key_and_metadata( + field.name().as_str(), + key.clone(), + key_metadata.clone(), + ); + } + } + } + + let encryption_properties = builder.build()?; + + Ok(Some(encryption_properties)) + } + + /// Generate file decryption properties to use when reading a Parquet file. + /// Rather than provide the AES keys directly for decryption, we set a `KeyRetriever` + /// that can determine the keys using the encryption metadata. + async fn get_file_decryption_properties( + &self, + _options: &EncryptionFactoryOptions, + _file_path: &Path, + ) -> Result> { + let decryption_properties = + FileDecryptionProperties::with_key_retriever(Arc::new(TestKeyRetriever {})) + .build()?; + Ok(Some(decryption_properties)) + } +} + +/// Mock implementation of encrypting a key that simply base64 encodes the key. +/// Note that this is not a secure way to store encryption keys, +/// and for production use keys should be encrypted with a KMS. +fn wrap_key(key: &[u8]) -> Vec { + base64::prelude::BASE64_STANDARD + .encode(key) + .as_bytes() + .to_vec() +} + +struct TestKeyRetriever {} + +impl KeyRetriever for TestKeyRetriever { + /// Get a data encryption key using the metadata stored in the Parquet file. + fn retrieve_key( + &self, + key_metadata: &[u8], + ) -> datafusion::parquet::errors::Result> { + let key_metadata = std::str::from_utf8(key_metadata)?; + let key = base64::prelude::BASE64_STANDARD + .decode(key_metadata) + .unwrap(); + Ok(key) + } +} diff --git a/datafusion-examples/examples/parquet_index.rs b/datafusion-examples/examples/parquet_index.rs index 0b6bccc27b1d1..afc3b279f4a9f 100644 --- a/datafusion-examples/examples/parquet_index.rs +++ b/datafusion-examples/examples/parquet_index.rs @@ -23,6 +23,7 @@ use arrow::datatypes::{Int32Type, SchemaRef}; use arrow::util::pretty::pretty_format_batches; use async_trait::async_trait; use datafusion::catalog::Session; +use datafusion::common::pruning::PruningStatistics; use datafusion::common::{ internal_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, }; @@ -39,7 +40,7 @@ use datafusion::parquet::arrow::{ arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter, }; use datafusion::physical_expr::PhysicalExpr; -use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion::physical_optimizer::pruning::PruningPredicate; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::*; use std::any::Any; @@ -70,7 +71,7 @@ use url::Url; /// (using the same underlying APIs) /// /// For a more advanced example of using an index to prune row groups within a -/// file, see the (forthcoming) `advanced_parquet_index` example. +/// file, see the `advanced_parquet_index` example. /// /// # Diagram /// @@ -242,8 +243,7 @@ impl TableProvider for IndexTableProvider { let files = self.index.get_files(predicate.clone())?; let object_store_url = ObjectStoreUrl::parse("file://")?; - let source = - Arc::new(ParquetSource::default().with_predicate(self.schema(), predicate)); + let source = Arc::new(ParquetSource::default().with_predicate(predicate)); let mut file_scan_config_builder = FileScanConfigBuilder::new(object_store_url, self.schema(), source) .with_projection(projection.cloned()) @@ -313,7 +313,7 @@ impl Display for ParquetMetadataIndex { "ParquetMetadataIndex(last_num_pruned: {})", self.last_num_pruned() )?; - let batches = pretty_format_batches(&[self.index.clone()]).unwrap(); + let batches = pretty_format_batches(std::slice::from_ref(&self.index)).unwrap(); write!(f, "{batches}",) } } @@ -685,7 +685,7 @@ fn make_demo_file(path: impl AsRef, value_range: Range) -> Result<()> let num_values = value_range.len(); let file_names = - StringArray::from_iter_values(std::iter::repeat(&filename).take(num_values)); + StringArray::from_iter_values(std::iter::repeat_n(&filename, num_values)); let values = Int32Array::from_iter_values(value_range); let batch = RecordBatch::try_from_iter(vec![ ("file_name", Arc::new(file_names) as ArrayRef), diff --git a/datafusion-examples/examples/planner_api.rs b/datafusion-examples/examples/planner_api.rs index 4943e593bd0bf..55aec7b0108a4 100644 --- a/datafusion-examples/examples/planner_api.rs +++ b/datafusion-examples/examples/planner_api.rs @@ -96,7 +96,7 @@ async fn to_physical_plan_step_by_step_demo( ctx.state().config_options(), |_, _| (), )?; - println!("Analyzed logical plan:\n\n{:?}\n\n", analyzed_logical_plan); + println!("Analyzed logical plan:\n\n{analyzed_logical_plan:?}\n\n"); // Optimize the analyzed logical plan let optimized_logical_plan = ctx.state().optimizer().optimize( @@ -104,10 +104,7 @@ async fn to_physical_plan_step_by_step_demo( &ctx.state(), |_, _| (), )?; - println!( - "Optimized logical plan:\n\n{:?}\n\n", - optimized_logical_plan - ); + println!("Optimized logical plan:\n\n{optimized_logical_plan:?}\n\n"); // Create the physical plan let physical_plan = ctx diff --git a/datafusion-examples/examples/pruning.rs b/datafusion-examples/examples/pruning.rs index 4c802bcdbda04..9a61789662cdd 100644 --- a/datafusion-examples/examples/pruning.rs +++ b/datafusion-examples/examples/pruning.rs @@ -20,10 +20,11 @@ use std::sync::Arc; use arrow::array::{ArrayRef, BooleanArray, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::common::pruning::PruningStatistics; use datafusion::common::{DFSchema, ScalarValue}; use datafusion::execution::context::ExecutionProps; use datafusion::physical_expr::create_physical_expr; -use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion::physical_optimizer::pruning::PruningPredicate; use datafusion::prelude::*; /// This example shows how to use DataFusion's `PruningPredicate` to prove @@ -186,10 +187,10 @@ impl PruningStatistics for MyCatalog { } fn create_pruning_predicate(expr: Expr, schema: &SchemaRef) -> PruningPredicate { - let df_schema = DFSchema::try_from(schema.as_ref().clone()).unwrap(); + let df_schema = DFSchema::try_from(Arc::clone(schema)).unwrap(); let props = ExecutionProps::new(); let physical_expr = create_physical_expr(&expr, &df_schema, &props).unwrap(); - PruningPredicate::try_new(physical_expr, schema.clone()).unwrap() + PruningPredicate::try_new(physical_expr, Arc::clone(schema)).unwrap() } fn i32_array<'a>(values: impl Iterator>) -> ArrayRef { diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index d2b2d1bf96551..b65ffb8d71748 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -133,7 +133,8 @@ struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else { + let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)), _)) = exprs.first() + else { return plan_err!("read_csv requires at least one string argument"); }; @@ -145,7 +146,7 @@ impl TableFunctionImpl for LocalCsvTableFunc { let info = SimplifyContext::new(&execution_props); let expr = ExprSimplifier::new(info).simplify(expr.clone())?; - if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr { + if let Expr::Literal(ScalarValue::Int64(Some(limit)), _) = expr { Ok(limit as usize) } else { plan_err!("Limit must be an integer") diff --git a/datafusion-examples/examples/sql_analysis.rs b/datafusion-examples/examples/sql_analysis.rs index d3826026a9725..4ff669faf1d0c 100644 --- a/datafusion-examples/examples/sql_analysis.rs +++ b/datafusion-examples/examples/sql_analysis.rs @@ -274,7 +274,10 @@ from for table in tables { ctx.register_table( table.name, - Arc::new(MemTable::try_new(Arc::new(table.schema.clone()), vec![])?), + Arc::new(MemTable::try_new( + Arc::new(table.schema.clone()), + vec![vec![]], + )?), )?; } // We can create a LogicalPlan from a SQL query like this diff --git a/datafusion-examples/examples/sql_dialect.rs b/datafusion-examples/examples/sql_dialect.rs index 12141847ca361..20b515506f3b4 100644 --- a/datafusion-examples/examples/sql_dialect.rs +++ b/datafusion-examples/examples/sql_dialect.rs @@ -17,10 +17,10 @@ use std::fmt::Display; -use datafusion::error::Result; +use datafusion::error::{DataFusionError, Result}; use datafusion::sql::{ parser::{CopyToSource, CopyToStatement, DFParser, DFParserBuilder, Statement}, - sqlparser::{keywords::Keyword, parser::ParserError, tokenizer::Token}, + sqlparser::{keywords::Keyword, tokenizer::Token}, }; /// This example demonstrates how to use the DFParser to parse a statement in a custom way @@ -34,8 +34,8 @@ async fn main() -> Result<()> { let my_statement = my_parser.parse_statement()?; match my_statement { - MyStatement::DFStatement(s) => println!("df: {}", s), - MyStatement::MyCopyTo(s) => println!("my_copy: {}", s), + MyStatement::DFStatement(s) => println!("df: {s}"), + MyStatement::MyCopyTo(s) => println!("my_copy: {s}"), } Ok(()) @@ -62,7 +62,7 @@ impl<'a> MyParser<'a> { /// This is the entry point to our parser -- it handles `COPY` statements specially /// but otherwise delegates to the existing DataFusion parser. - pub fn parse_statement(&mut self) -> Result { + pub fn parse_statement(&mut self) -> Result { if self.is_copy() { self.df_parser.parser.next_token(); // COPY let df_statement = self.df_parser.parse_copy()?; @@ -87,8 +87,8 @@ enum MyStatement { impl Display for MyStatement { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - MyStatement::DFStatement(s) => write!(f, "{}", s), - MyStatement::MyCopyTo(s) => write!(f, "{}", s), + MyStatement::DFStatement(s) => write!(f, "{s}"), + MyStatement::MyCopyTo(s) => write!(f, "{s}"), } } } diff --git a/datafusion-examples/examples/sql_frontend.rs b/datafusion-examples/examples/sql_frontend.rs index 3955d5038cfb0..1fc9ce24ecbb5 100644 --- a/datafusion-examples/examples/sql_frontend.rs +++ b/datafusion-examples/examples/sql_frontend.rs @@ -83,7 +83,7 @@ pub fn main() -> Result<()> { let config = OptimizerContext::default().with_skip_failing_rules(false); let analyzed_plan = Analyzer::new().execute_and_check( logical_plan, - config.options(), + &config.options(), observe_analyzer, )?; // Note that the Analyzer has added a CAST to the plan to align the types diff --git a/datafusion-examples/examples/thread_pools.rs b/datafusion-examples/examples/thread_pools.rs new file mode 100644 index 0000000000000..bba56b2932abc --- /dev/null +++ b/datafusion-examples/examples/thread_pools.rs @@ -0,0 +1,350 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This example shows how to use separate thread pools (tokio [`Runtime`]))s to +//! run the IO and CPU intensive parts of DataFusion plans. +//! +//! # Background +//! +//! DataFusion, by default, plans and executes all operations (both CPU and IO) +//! on the same thread pool. This makes it fast and easy to get started, but +//! can cause issues when running at scale, especially when fetching and operating +//! on data directly from remote sources. +//! +//! Specifically, without configuration such as in this example, DataFusion +//! plans and executes everything the same thread pool (Tokio Runtime), including +//! any I/O, such as reading Parquet files from remote object storage +//! (e.g. AWS S3), catalog access, and CPU intensive work. Running this diverse +//! workload can lead to issues described in the [Architecture section] such as +//! throttled network bandwidth (due to congestion control) and increased +//! latencies or timeouts while processing network messages. +//! +//! [Architecture section]: https://docs.rs/datafusion/latest/datafusion/index.html#thread-scheduling-cpu--io-thread-pools-and-tokio-runtimes + +use arrow::util::pretty::pretty_format_batches; +use datafusion::common::runtime::JoinSet; +use datafusion::error::Result; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::prelude::*; +use futures::stream::StreamExt; +use object_store::client::SpawnedReqwestConnector; +use object_store::http::HttpBuilder; +use std::sync::Arc; +use tokio::runtime::Handle; +use tokio::sync::Notify; +use url::Url; + +/// Normally, you don't need to worry about the details of the tokio +/// [`Runtime`], but for this example it is important to understand how the +/// [`Runtime`]s work. +/// +/// Each thread has "current" runtime that is installed in a thread local +/// variable which is used by the `tokio::spawn` function. +/// +/// The `#[tokio::main]` macro creates a [`Runtime`] and installs it as +/// as the "current" runtime in a thread local variable, on which any `async` +/// [`Future`], [`Stream]`s and [`Task]`s are run. +/// +/// This example uses the runtime created by [`tokio::main`] to do I/O and spawn +/// CPU intensive tasks on a separate [`Runtime`], mirroring the common pattern +/// when using Rust libraries such as `tonic`. Using a separate `Runtime` for +/// CPU bound tasks will often be simpler in larger applications, even though it +/// makes this example slightly more complex. +#[tokio::main] +async fn main() -> Result<()> { + // The first two examples read local files. Enabling the URL table feature + // lets us treat filenames as tables in SQL. + let ctx = SessionContext::new().enable_url_table(); + let sql = format!( + "SELECT * FROM '{}/alltypes_plain.parquet'", + datafusion::test_util::parquet_test_data() + ); + + // Run a query on the current runtime. Calling `await` means the future + // (in this case the `async` function and all spawned work in DataFusion + // plans) on the current runtime. + same_runtime(&ctx, &sql).await?; + + // Run the same query but this time on a different runtime. + // + // Since we call `await` here, the `async` function itself runs on the + // current runtime, but internally `different_runtime_basic` executes the + // DataFusion plan on a different Runtime. + different_runtime_basic(ctx, sql).await?; + + // Run the same query on a different runtime, including remote IO. + // + // NOTE: This is best practice for production systems + different_runtime_advanced().await?; + + Ok(()) +} + +/// Run queries directly on the current tokio `Runtime` +/// +/// This is how most examples in DataFusion are written and works well for +/// development, local query processing, and non latency sensitive workloads. +async fn same_runtime(ctx: &SessionContext, sql: &str) -> Result<()> { + // Calling .sql is an async function as it may also do network + // I/O, for example to contact a remote catalog or do an object store LIST + let df = ctx.sql(sql).await?; + + // While many examples call `collect` or `show()`, those methods buffers the + // results. Internally DataFusion generates output a RecordBatch at a time + + // Calling `execute_stream` return a `SendableRecordBatchStream`. Depending + // on the plan, this may also do network I/O, for example to begin reading a + // parquet file from a remote object store. + let mut stream: SendableRecordBatchStream = df.execute_stream().await?; + + // `next()` drives the plan, incrementally producing new `RecordBatch`es + // using the current runtime. + // + // Perhaps somewhat non obviously, calling `next()` can also result in other + // tasks being spawned on the current runtime (e.g. for `RepartitionExec` to + // read data from each of its input partitions in parallel). + // + // Executing the plan using this pattern intermixes any IO and CPU intensive + // work on same Runtime + while let Some(batch) = stream.next().await { + println!("{}", pretty_format_batches(&[batch?]).unwrap()); + } + Ok(()) +} + +/// Run queries on a **different** Runtime dedicated for CPU bound work +/// +/// This example is suitable for running DataFusion plans against local data +/// sources (e.g. files) and returning results to an async destination, as might +/// be done to return query results to a remote client. +/// +/// Production systems which also read data locally or require very low latency +/// should follow the recommendations on [`different_runtime_advanced`] when +/// processing data from a remote source such as object storage. +async fn different_runtime_basic(ctx: SessionContext, sql: String) -> Result<()> { + // Since we are already in the context of runtime (installed by + // #[tokio::main]), we need a new Runtime (threadpool) for CPU bound tasks + let cpu_runtime = CpuRuntime::try_new()?; + + // Prepare a task that runs the plan on cpu_runtime and sends + // the results back to the original runtime via a channel. + let (tx, mut rx) = tokio::sync::mpsc::channel(2); + let driver_task = async move { + // Plan the query (which might require CPU work to evaluate statistics) + let df = ctx.sql(&sql).await?; + let mut stream: SendableRecordBatchStream = df.execute_stream().await?; + + // Calling `next()` to drive the plan in this task drives the + // execution from the cpu runtime the other thread pool + // + // NOTE any IO run by this plan (for example, reading from an + // `ObjectStore`) will be done on this new thread pool as well. + while let Some(batch) = stream.next().await { + if tx.send(batch).await.is_err() { + // error means dropped receiver, so nothing will get results anymore + return Ok(()); + } + } + Ok(()) as Result<()> + }; + + // Run the driver task on the cpu runtime. Use a JoinSet to + // ensure the spawned task is canceled on error/drop + let mut join_set = JoinSet::new(); + join_set.spawn_on(driver_task, cpu_runtime.handle()); + + // Retrieve the results in the original (IO) runtime. This requires only + // minimal work (pass pointers around). + while let Some(batch) = rx.recv().await { + println!("{}", pretty_format_batches(&[batch?])?); + } + + // wait for completion of the driver task + drain_join_set(join_set).await; + + Ok(()) +} + +/// Run CPU intensive work on a different runtime but do IO operations (object +/// store access) on the current runtime. +async fn different_runtime_advanced() -> Result<()> { + // In this example, we will query a file via https, reading + // the data directly from the plan + + // The current runtime (created by tokio::main) is used for IO + // + // Note this handle should be used for *ALL* remote IO operations in your + // systems, including remote catalog access, which is not included in this + // example. + let cpu_runtime = CpuRuntime::try_new()?; + let io_handle = Handle::current(); + + let ctx = SessionContext::new(); + + // By default, the HttpStore use the same runtime that calls `await` for IO + // operations. This means that if the DataFusion plan is called from the + // cpu_runtime, the HttpStore IO operations will *also* run on the CPU + // runtime, which will error. + // + // To avoid this, we use a `SpawnedReqwestConnector` to configure the + // `ObjectStore` to run the HTTP requests on the IO runtime. + let base_url = Url::parse("https://github.com").unwrap(); + let http_store = HttpBuilder::new() + .with_url(base_url.clone()) + // Use the io_runtime to run the HTTP requests. Without this line, + // you will see an error such as: + // A Tokio 1.x context was found, but IO is disabled. + .with_http_connector(SpawnedReqwestConnector::new(io_handle)) + .build()?; + + // Tell DataFusion to process `http://` urls with this wrapped object store + ctx.register_object_store(&base_url, Arc::new(http_store)); + + // As above, plan and execute the query on the cpu runtime. + let (tx, mut rx) = tokio::sync::mpsc::channel(2); + let driver_task = async move { + // Plan / execute the query + let url = "https://github.com/apache/arrow-testing/raw/master/data/csv/aggregate_test_100.csv"; + let df = ctx + .sql(&format!("SELECT c1,c2,c3 FROM '{url}' LIMIT 5")) + .await?; + + let mut stream: SendableRecordBatchStream = df.execute_stream().await?; + + // Note you can do other non trivial CPU work on the results of the + // stream before sending it back to the original runtime. For example, + // calling a FlightDataEncoder to convert the results to flight messages + // to send over the network + + // send results, as above + while let Some(batch) = stream.next().await { + if tx.send(batch).await.is_err() { + return Ok(()); + } + } + Ok(()) as Result<()> + }; + + let mut join_set = JoinSet::new(); + join_set.spawn_on(driver_task, cpu_runtime.handle()); + while let Some(batch) = rx.recv().await { + println!("{}", pretty_format_batches(&[batch?])?); + } + + Ok(()) +} + +/// Waits for all tasks in the JoinSet to complete and reports any errors that +/// occurred. +/// +/// If we don't do this, any errors that occur in the task (such as IO errors) +/// are not reported. +async fn drain_join_set(mut join_set: JoinSet>) { + // retrieve any errors from the tasks + while let Some(result) = join_set.join_next().await { + match result { + Ok(Ok(())) => {} // task completed successfully + Ok(Err(e)) => eprintln!("Task failed: {e}"), // task failed + Err(e) => eprintln!("JoinSet error: {e}"), // JoinSet error + } + } +} + +/// Creates a Tokio [`Runtime`] for use with CPU bound tasks +/// +/// Tokio forbids dropping `Runtime`s in async contexts, so creating a separate +/// `Runtime` correctly is somewhat tricky. This structure manages the creation +/// and shutdown of a separate thread. +/// +/// # Notes +/// On drop, the thread will wait for all remaining tasks to complete. +/// +/// Depending on your application, more sophisticated shutdown logic may be +/// required, such as ensuring that no new tasks are added to the runtime. +/// +/// # Credits +/// This code is derived from code originally written for [InfluxDB 3.0] +/// +/// [InfluxDB 3.0]: https://github.com/influxdata/influxdb3_core/tree/6fcbb004232738d55655f32f4ad2385523d10696/executor +struct CpuRuntime { + /// Handle is the tokio structure for interacting with a Runtime. + handle: Handle, + /// Signal to start shutting down + notify_shutdown: Arc, + /// When thread is active, is Some + thread_join_handle: Option>, +} + +impl Drop for CpuRuntime { + fn drop(&mut self) { + // Notify the thread to shutdown. + self.notify_shutdown.notify_one(); + // In a production system you also need to ensure your code stops adding + // new tasks to the underlying runtime after this point to allow the + // thread to complete its work and exit cleanly. + if let Some(thread_join_handle) = self.thread_join_handle.take() { + // If the thread is still running, we wait for it to finish + print!("Shutting down CPU runtime thread..."); + if let Err(e) = thread_join_handle.join() { + eprintln!("Error joining CPU runtime thread: {e:?}",); + } else { + println!("CPU runtime thread shutdown successfully."); + } + } + } +} + +impl CpuRuntime { + /// Create a new Tokio Runtime for CPU bound tasks + pub fn try_new() -> Result { + let cpu_runtime = tokio::runtime::Builder::new_multi_thread() + .enable_time() + .build()?; + let handle = cpu_runtime.handle().clone(); + let notify_shutdown = Arc::new(Notify::new()); + let notify_shutdown_captured = Arc::clone(¬ify_shutdown); + + // The cpu_runtime runs and is dropped on a separate thread + let thread_join_handle = std::thread::spawn(move || { + cpu_runtime.block_on(async move { + notify_shutdown_captured.notified().await; + }); + // Note: cpu_runtime is dropped here, which will wait for all tasks + // to complete + }); + + Ok(Self { + handle, + notify_shutdown, + thread_join_handle: Some(thread_join_handle), + }) + } + + /// Return a handle suitable for spawning CPU bound tasks + /// + /// # Notes + /// + /// If a task spawned on this handle attempts to do IO, it will error with a + /// message such as: + /// + /// ```text + ///A Tokio 1.x context was found, but IO is disabled. + /// ``` + pub fn handle(&self) -> &Handle { + &self.handle + } +} diff --git a/datafusion-testing b/datafusion-testing index 243047b9dd682..905df5f65cc9d 160000 --- a/datafusion-testing +++ b/datafusion-testing @@ -1 +1 @@ -Subproject commit 243047b9dd682be688628539c604daaddfe640f9 +Subproject commit 905df5f65cc9d0851719c21f5a4dd5cd77621f19 diff --git a/datafusion/catalog-listing/Cargo.toml b/datafusion/catalog-listing/Cargo.toml index 734580202232b..69f952ae98407 100644 --- a/datafusion/catalog-listing/Cargo.toml +++ b/datafusion/catalog-listing/Cargo.toml @@ -18,11 +18,11 @@ [package] name = "datafusion-catalog-listing" description = "datafusion-catalog-listing" +readme = "README.md" authors.workspace = true edition.workspace = true homepage.workspace = true license.workspace = true -readme.workspace = true repository.workspace = true rust-version.workspace = true version.workspace = true @@ -41,14 +41,12 @@ datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } -datafusion-session = { workspace = true } futures = { workspace = true } log = { workspace = true } object_store = { workspace = true } tokio = { workspace = true } [dev-dependencies] -tempfile = { workspace = true } [lints] workspace = true diff --git a/datafusion/catalog-listing/README.md b/datafusion/catalog-listing/README.md index b4760c413d60b..81a7c7b1da3ae 100644 --- a/datafusion/catalog-listing/README.md +++ b/datafusion/catalog-listing/README.md @@ -17,14 +17,20 @@ under the License. --> -# DataFusion catalog-listing +# Apache DataFusion Catalog Listing -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion with [ListingTable], an implementation of [TableProvider] based on files in a directory (either locally or on remote object storage such as S3). -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ [listingtable]: https://docs.rs/datafusion/latest/datafusion/datasource/listing/struct.ListingTable.html [tableprovider]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 8efb74d4ea1ee..00e9c71df3489 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -61,7 +61,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { Ok(TreeNodeRecursion::Stop) } } - Expr::Literal(_) + Expr::Literal(_, _) | Expr::Alias(_) | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) @@ -346,8 +346,8 @@ fn populate_partition_values<'a>( { match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { - (Expr::Column(Column { ref name, .. }), Expr::Literal(val)) - | (Expr::Literal(val), Expr::Column(Column { ref name, .. })) => { + (Expr::Column(Column { ref name, .. }), Expr::Literal(val, _)) + | (Expr::Literal(val, _), Expr::Column(Column { ref name, .. })) => { if partition_values .insert(name, PartitionValue::Single(val.to_string())) .is_some() @@ -507,11 +507,7 @@ where Some((name, val)) if name == pn => part_values.push(val), _ => { debug!( - "Ignoring file: file_path='{}', table_path='{}', part='{}', partition_col='{}'", - file_path, - table_path, - part, - pn, + "Ignoring file: file_path='{file_path}', table_path='{table_path}', part='{part}', partition_col='{pn}'", ); return None; } @@ -988,7 +984,7 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3))))], + &[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3)), None))], ), Some(Path::from("a=1970-01-04")), ); @@ -997,9 +993,10 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(Expr::Literal(ScalarValue::Date64(Some( - 4 * 24 * 60 * 60 * 1000 - )))),], + &[col("a").eq(Expr::Literal( + ScalarValue::Date64(Some(4 * 24 * 60 * 60 * 1000)), + None + )),], ), Some(Path::from("a=1970-01-05")), ); diff --git a/datafusion/catalog-listing/src/mod.rs b/datafusion/catalog-listing/src/mod.rs index fb0a960f37b6a..1322577b207ab 100644 --- a/datafusion/catalog-listing/src/mod.rs +++ b/datafusion/catalog-listing/src/mod.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/catalog/Cargo.toml b/datafusion/catalog/Cargo.toml index 7307c4de87a8a..a1db45654be01 100644 --- a/datafusion/catalog/Cargo.toml +++ b/datafusion/catalog/Cargo.toml @@ -18,11 +18,11 @@ [package] name = "datafusion-catalog" description = "datafusion-catalog" +readme = "README.md" authors.workspace = true edition.workspace = true homepage.workspace = true license.workspace = true -readme.workspace = true repository.workspace = true rust-version.workspace = true version.workspace = true @@ -42,7 +42,6 @@ datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } -datafusion-sql = { workspace = true } futures = { workspace = true } itertools = { workspace = true } log = { workspace = true } diff --git a/datafusion/catalog/README.md b/datafusion/catalog/README.md index 5b201e736fdc4..48c61b43c025b 100644 --- a/datafusion/catalog/README.md +++ b/datafusion/catalog/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion Catalog +# Apache DataFusion Catalog -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion that provides catalog management functionality, including catalogs, schemas, and tables. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/catalog/src/async.rs b/datafusion/catalog/src/async.rs index 5d7a51ad71232..1c830c976d8b8 100644 --- a/datafusion/catalog/src/async.rs +++ b/datafusion/catalog/src/async.rs @@ -737,7 +737,7 @@ mod tests { ] { let async_provider = MockAsyncCatalogProviderList::default(); let cached_provider = async_provider - .resolve(&[table_ref.clone()], &test_config()) + .resolve(std::slice::from_ref(table_ref), &test_config()) .await .unwrap(); diff --git a/datafusion/catalog/src/cte_worktable.rs b/datafusion/catalog/src/cte_worktable.rs index d72a30909c02c..d6b2a453118c9 100644 --- a/datafusion/catalog/src/cte_worktable.rs +++ b/datafusion/catalog/src/cte_worktable.rs @@ -71,7 +71,7 @@ impl TableProvider for CteWorkTable { self } - fn get_logical_plan(&self) -> Option> { + fn get_logical_plan(&'_ self) -> Option> { None } diff --git a/datafusion/catalog/src/default_table_source.rs b/datafusion/catalog/src/default_table_source.rs index 9db8242caa999..11963c06c88f5 100644 --- a/datafusion/catalog/src/default_table_source.rs +++ b/datafusion/catalog/src/default_table_source.rs @@ -33,8 +33,6 @@ use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource, TableType} /// /// It is used so logical plans in the `datafusion_expr` crate do not have a /// direct dependency on physical plans, such as [`TableProvider`]s. -/// -/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/datasource/provider/trait.TableProvider.html pub struct DefaultTableSource { /// table provider pub table_provider: Arc, @@ -78,7 +76,7 @@ impl TableSource for DefaultTableSource { self.table_provider.supports_filters_pushdown(filter) } - fn get_logical_plan(&self) -> Option> { + fn get_logical_plan(&'_ self) -> Option> { self.table_provider.get_logical_plan() } diff --git a/datafusion/catalog/src/information_schema.rs b/datafusion/catalog/src/information_schema.rs index 7948c0299d393..d733551f44051 100644 --- a/datafusion/catalog/src/information_schema.rs +++ b/datafusion/catalog/src/information_schema.rs @@ -30,6 +30,7 @@ use arrow::{ use async_trait::async_trait; use datafusion_common::config::{ConfigEntry, ConfigOptions}; use datafusion_common::error::Result; +use datafusion_common::types::NativeType; use datafusion_common::DataFusionError; use datafusion_execution::TaskContext; use datafusion_expr::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; @@ -37,7 +38,7 @@ use datafusion_expr::{TableType, Volatility}; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::streaming::PartitionStream; use datafusion_physical_plan::SendableRecordBatchStream; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt::Debug; use std::{any::Any, sync::Arc}; @@ -102,12 +103,14 @@ impl InformationSchemaConfig { // schema name may not exist in the catalog, so we need to check if let Some(schema) = catalog.schema(&schema_name) { for table_name in schema.table_names() { - if let Some(table) = schema.table(&table_name).await? { + if let Some(table_type) = + schema.table_type(&table_name).await? + { builder.add_table( &catalog_name, &schema_name, &table_name, - table.table_type(), + table_type, ); } } @@ -403,58 +406,63 @@ impl InformationSchemaConfig { /// returns a tuple of (arg_types, return_type) fn get_udf_args_and_return_types( udf: &Arc, -) -> Result, Option)>> { +) -> Result, Option)>> { let signature = udf.signature(); let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { - Ok(vec![(vec![], None)]) + Ok(vec![(vec![], None)].into_iter().collect::>()) } else { Ok(arg_types .into_iter() .map(|arg_types| { // only handle the function which implemented [`ScalarUDFImpl::return_type`] method - let return_type = udf.return_type(&arg_types).ok().map(|t| t.to_string()); + let return_type = udf + .return_type(&arg_types) + .map(|t| remove_native_type_prefix(NativeType::from(t))) + .ok(); let arg_types = arg_types .into_iter() - .map(|t| t.to_string()) + .map(|t| remove_native_type_prefix(NativeType::from(t))) .collect::>(); (arg_types, return_type) }) - .collect::>()) + .collect::>()) } } fn get_udaf_args_and_return_types( udaf: &Arc, -) -> Result, Option)>> { +) -> Result, Option)>> { let signature = udaf.signature(); let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { - Ok(vec![(vec![], None)]) + Ok(vec![(vec![], None)].into_iter().collect::>()) } else { Ok(arg_types .into_iter() .map(|arg_types| { // only handle the function which implemented [`ScalarUDFImpl::return_type`] method - let return_type = - udaf.return_type(&arg_types).ok().map(|t| t.to_string()); + let return_type = udaf + .return_type(&arg_types) + .ok() + .map(|t| remove_native_type_prefix(NativeType::from(t))); let arg_types = arg_types .into_iter() - .map(|t| t.to_string()) + .map(|t| remove_native_type_prefix(NativeType::from(t))) .collect::>(); (arg_types, return_type) }) - .collect::>()) + .collect::>()) } } fn get_udwf_args_and_return_types( udwf: &Arc, -) -> Result, Option)>> { +) -> Result, Option)>> { let signature = udwf.signature(); let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { - Ok(vec![(vec![], None)]) + Ok(vec![(vec![], None)].into_iter().collect::>()) } else { Ok(arg_types .into_iter() @@ -462,14 +470,19 @@ fn get_udwf_args_and_return_types( // only handle the function which implemented [`ScalarUDFImpl::return_type`] method let arg_types = arg_types .into_iter() - .map(|t| t.to_string()) + .map(|t| remove_native_type_prefix(NativeType::from(t))) .collect::>(); (arg_types, None) }) - .collect::>()) + .collect::>()) } } +#[inline] +fn remove_native_type_prefix(native_type: NativeType) -> String { + format!("{native_type}") +} + #[async_trait] impl SchemaProvider for InformationSchemaProvider { fn as_any(&self) -> &dyn Any { @@ -479,7 +492,7 @@ impl SchemaProvider for InformationSchemaProvider { fn table_names(&self) -> Vec { INFORMATION_SCHEMA_TABLES .iter() - .map(|t| t.to_string()) + .map(|t| (*t).to_string()) .collect() } @@ -797,7 +810,7 @@ impl InformationSchemaColumnsBuilder { ) { use DataType::*; - // Note: append_value is actually infallable. + // Note: append_value is actually infallible. self.catalog_names.append_value(catalog_name); self.schema_names.append_value(schema_name); self.table_names.append_value(table_name); @@ -814,8 +827,7 @@ impl InformationSchemaColumnsBuilder { self.is_nullables.append_value(nullable_str); // "System supplied type" --> Use debug format of the datatype - self.data_types - .append_value(format!("{:?}", field.data_type())); + self.data_types.append_value(field.data_type().to_string()); // "If data_type identifies a character or bit string type, the // declared maximum length; null for all other data types or @@ -1348,3 +1360,92 @@ impl PartitionStream for InformationSchemaParameters { )) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::CatalogProvider; + + #[tokio::test] + async fn make_tables_uses_table_type() { + let config = InformationSchemaConfig { + catalog_list: Arc::new(Fixture), + }; + let mut builder = InformationSchemaTablesBuilder { + catalog_names: StringBuilder::new(), + schema_names: StringBuilder::new(), + table_names: StringBuilder::new(), + table_types: StringBuilder::new(), + schema: Arc::new(Schema::empty()), + }; + + assert!(config.make_tables(&mut builder).await.is_ok()); + + assert_eq!("BASE TABLE", builder.table_types.finish().value(0)); + } + + #[derive(Debug)] + struct Fixture; + + #[async_trait] + impl SchemaProvider for Fixture { + // InformationSchemaConfig::make_tables should use this. + async fn table_type(&self, _: &str) -> Result> { + Ok(Some(TableType::Base)) + } + + // InformationSchemaConfig::make_tables used this before `table_type` + // existed but should not, as it may be expensive. + async fn table(&self, _: &str) -> Result>> { + panic!("InformationSchemaConfig::make_tables called SchemaProvider::table instead of table_type") + } + + fn as_any(&self) -> &dyn Any { + unimplemented!("not required for these tests") + } + + fn table_names(&self) -> Vec { + vec!["atable".to_string()] + } + + fn table_exist(&self, _: &str) -> bool { + unimplemented!("not required for these tests") + } + } + + impl CatalogProviderList for Fixture { + fn as_any(&self) -> &dyn Any { + unimplemented!("not required for these tests") + } + + fn register_catalog( + &self, + _: String, + _: Arc, + ) -> Option> { + unimplemented!("not required for these tests") + } + + fn catalog_names(&self) -> Vec { + vec!["acatalog".to_string()] + } + + fn catalog(&self, _: &str) -> Option> { + Some(Arc::new(Self)) + } + } + + impl CatalogProvider for Fixture { + fn as_any(&self) -> &dyn Any { + unimplemented!("not required for these tests") + } + + fn schema_names(&self) -> Vec { + vec!["aschema".to_string()] + } + + fn schema(&self, _: &str) -> Option> { + Some(Arc::new(Self)) + } + } +} diff --git a/datafusion/catalog/src/lib.rs b/datafusion/catalog/src/lib.rs index 0394b05277dac..1c5e38438724e 100644 --- a/datafusion/catalog/src/lib.rs +++ b/datafusion/catalog/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/catalog/src/listing_schema.rs b/datafusion/catalog/src/listing_schema.rs index cc2c2ee606b3d..af96cfc15fc82 100644 --- a/datafusion/catalog/src/listing_schema.rs +++ b/datafusion/catalog/src/listing_schema.rs @@ -26,7 +26,7 @@ use crate::{SchemaProvider, TableProvider, TableProviderFactory}; use crate::Session; use datafusion_common::{ - Constraints, DFSchema, DataFusionError, HashMap, TableReference, + internal_datafusion_err, DFSchema, DataFusionError, HashMap, TableReference, }; use datafusion_expr::CreateExternalTable; @@ -111,17 +111,13 @@ impl ListingSchemaProvider { let file_name = table .path .file_name() - .ok_or_else(|| { - DataFusionError::Internal("Cannot parse file name!".to_string()) - })? + .ok_or_else(|| internal_datafusion_err!("Cannot parse file name!"))? .to_str() - .ok_or_else(|| { - DataFusionError::Internal("Cannot parse file name!".to_string()) - })?; + .ok_or_else(|| internal_datafusion_err!("Cannot parse file name!"))?; let table_name = file_name.split('.').collect_vec()[0]; - let table_path = table.to_string().ok_or_else(|| { - DataFusionError::Internal("Cannot parse file name!".to_string()) - })?; + let table_path = table + .to_string() + .ok_or_else(|| internal_datafusion_err!("Cannot parse file name!"))?; if !self.table_exist(table_name) { let table_url = format!("{}/{}", self.authority, table_path); @@ -138,12 +134,13 @@ impl ListingSchemaProvider { file_type: self.format.clone(), table_partition_cols: vec![], if_not_exists: false, + or_replace: false, temporary: false, definition: None, order_exprs: vec![], unbounded: false, options: Default::default(), - constraints: Constraints::empty(), + constraints: Default::default(), column_defaults: Default::default(), }, ) diff --git a/datafusion/catalog/src/memory/table.rs b/datafusion/catalog/src/memory/table.rs index 81243e2c4889e..90224f6a37bc3 100644 --- a/datafusion/catalog/src/memory/table.rs +++ b/datafusion/catalog/src/memory/table.rs @@ -23,25 +23,22 @@ use std::fmt::Debug; use std::sync::Arc; use crate::TableProvider; -use datafusion_common::error::Result; -use datafusion_expr::Expr; -use datafusion_expr::TableType; -use datafusion_physical_expr::create_physical_sort_exprs; -use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::{ - common, ExecutionPlan, ExecutionPlanProperties, Partitioning, -}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use datafusion_common::error::Result; use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt}; use datafusion_common_runtime::JoinSet; -use datafusion_datasource::memory::MemSink; -use datafusion_datasource::memory::MemorySourceConfig; +use datafusion_datasource::memory::{MemSink, MemorySourceConfig}; use datafusion_datasource::sink::DataSinkExec; use datafusion_datasource::source::DataSourceExec; use datafusion_expr::dml::InsertOp; -use datafusion_expr::SortExpr; +use datafusion_expr::{Expr, SortExpr, TableType}; +use datafusion_physical_expr::{create_physical_sort_exprs, LexOrdering}; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::{ + common, ExecutionPlan, ExecutionPlanProperties, Partitioning, +}; use datafusion_session::Session; use async_trait::async_trait; @@ -70,8 +67,16 @@ pub struct MemTable { } impl MemTable { - /// Create a new in-memory table from the provided schema and record batches + /// Create a new in-memory table from the provided schema and record batches. + /// + /// Requires at least one partition. To construct an empty `MemTable`, pass + /// `vec![vec![]]` as the `partitions` argument, this represents one partition with + /// no batches. pub fn try_new(schema: SchemaRef, partitions: Vec>) -> Result { + if partitions.is_empty() { + return plan_err!("No partitions provided, expected at least one partition"); + } + for batches in partitions.iter().flatten() { let batches_schema = batches.schema(); if !schema.contains(&batches_schema) { @@ -89,7 +94,7 @@ impl MemTable { .into_iter() .map(|e| Arc::new(RwLock::new(e))) .collect::>(), - constraints: Constraints::empty(), + constraints: Constraints::default(), column_defaults: HashMap::new(), sort_order: Arc::new(Mutex::new(vec![])), }) @@ -237,18 +242,15 @@ impl TableProvider for MemTable { // add sort information if present let sort_order = self.sort_order.lock(); if !sort_order.is_empty() { - let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?; - - let file_sort_order = sort_order - .iter() - .map(|sort_exprs| { - create_physical_sort_exprs( - sort_exprs, - &df_schema, - state.execution_props(), - ) - }) - .collect::>>()?; + let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?; + + let eqp = state.execution_props(); + let mut file_sort_order = vec![]; + for sort_exprs in sort_order.iter() { + let physical_exprs = + create_physical_sort_exprs(sort_exprs, &df_schema, eqp)?; + file_sort_order.extend(LexOrdering::new(physical_exprs)); + } source = source.try_with_sort_information(file_sort_order)?; } diff --git a/datafusion/catalog/src/schema.rs b/datafusion/catalog/src/schema.rs index 5b37348fd7427..9ba55256f1824 100644 --- a/datafusion/catalog/src/schema.rs +++ b/datafusion/catalog/src/schema.rs @@ -26,6 +26,7 @@ use std::sync::Arc; use crate::table::TableProvider; use datafusion_common::Result; +use datafusion_expr::TableType; /// Represents a schema, comprising a number of named tables. /// @@ -54,6 +55,14 @@ pub trait SchemaProvider: Debug + Sync + Send { name: &str, ) -> Result>, DataFusionError>; + /// Retrieves the type of a specific table from the schema by name, if it exists, otherwise + /// returns `None`. Implementations for which this operation is cheap but [Self::table] is + /// expensive can override this to improve operations that only need the type, e.g. + /// `SELECT * FROM information_schema.tables`. + async fn table_type(&self, name: &str) -> Result> { + self.table(name).await.map(|o| o.map(|t| t.table_type())) + } + /// If supported by the implementation, adds a new table named `name` to /// this schema. /// diff --git a/datafusion/catalog/src/stream.rs b/datafusion/catalog/src/stream.rs index fbfab513229e0..f4a2338b8eecb 100644 --- a/datafusion/catalog/src/stream.rs +++ b/datafusion/catalog/src/stream.rs @@ -34,7 +34,7 @@ use datafusion_datasource::sink::{DataSink, DataSinkExec}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; -use datafusion_physical_expr::create_ordering; +use datafusion_physical_expr::create_lex_ordering; use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder; use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; @@ -53,7 +53,7 @@ impl TableProviderFactory for StreamTableFactory { state: &dyn Session, cmd: &CreateExternalTable, ) -> Result> { - let schema: SchemaRef = Arc::new(cmd.schema.as_ref().into()); + let schema: SchemaRef = Arc::clone(cmd.schema.inner()); let location = cmd.location.clone(); let encoding = cmd.file_type.parse()?; let header = if let Ok(opt) = cmd @@ -256,7 +256,7 @@ impl StreamConfig { Self { source, order: vec![], - constraints: Constraints::empty(), + constraints: Constraints::default(), } } @@ -321,17 +321,21 @@ impl TableProvider for StreamTable { async fn scan( &self, - _state: &dyn Session, + state: &dyn Session, projection: Option<&Vec>, _filters: &[Expr], limit: Option, ) -> Result> { let projected_schema = match projection { Some(p) => { - let projected = self.0.source.schema().project(p)?; - create_ordering(&projected, &self.0.order)? + let projected = Arc::new(self.0.source.schema().project(p)?); + create_lex_ordering(&projected, &self.0.order, state.execution_props())? } - None => create_ordering(self.0.source.schema(), &self.0.order)?, + None => create_lex_ordering( + self.0.source.schema(), + &self.0.order, + state.execution_props(), + )?, }; Ok(Arc::new(StreamingTableExec::try_new( @@ -350,15 +354,11 @@ impl TableProvider for StreamTable { input: Arc, _insert_op: InsertOp, ) -> Result> { - let ordering = match self.0.order.first() { - Some(x) => { - let schema = self.0.source.schema(); - let orders = create_ordering(schema, std::slice::from_ref(x))?; - let ordering = orders.into_iter().next().unwrap(); - Some(ordering.into_iter().map(Into::into).collect()) - } - None => None, - }; + let schema = self.0.source.schema(); + let orders = + create_lex_ordering(schema, &self.0.order, _state.execution_props())?; + // It is sufficient to pass only one of the equivalent orderings: + let ordering = orders.into_iter().next().map(Into::into); Ok(Arc::new(DataSinkExec::new( input, @@ -440,6 +440,6 @@ impl DataSink for StreamWrite { write_task .join_unwind() .await - .map_err(DataFusionError::ExecutionJoin)? + .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))? } } diff --git a/datafusion/catalog/src/streaming.rs b/datafusion/catalog/src/streaming.rs index 654e6755d7d4c..082e74dab9a15 100644 --- a/datafusion/catalog/src/streaming.rs +++ b/datafusion/catalog/src/streaming.rs @@ -20,15 +20,17 @@ use std::any::Any; use std::sync::Arc; -use arrow::datatypes::SchemaRef; -use async_trait::async_trait; - use crate::Session; use crate::TableProvider; -use datafusion_common::{plan_err, Result}; -use datafusion_expr::{Expr, TableType}; + +use arrow::datatypes::SchemaRef; +use datafusion_common::{plan_err, DFSchema, Result}; +use datafusion_expr::{Expr, SortExpr, TableType}; +use datafusion_physical_expr::{create_physical_sort_exprs, LexOrdering}; use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::ExecutionPlan; + +use async_trait::async_trait; use log::debug; /// A [`TableProvider`] that streams a set of [`PartitionStream`] @@ -37,6 +39,7 @@ pub struct StreamingTable { schema: SchemaRef, partitions: Vec>, infinite: bool, + sort_order: Vec, } impl StreamingTable { @@ -60,13 +63,21 @@ impl StreamingTable { schema, partitions, infinite: false, + sort_order: vec![], }) } + /// Sets streaming table can be infinite. pub fn with_infinite_table(mut self, infinite: bool) -> Self { self.infinite = infinite; self } + + /// Sets the existing ordering of streaming table. + pub fn with_sort_order(mut self, sort_order: Vec) -> Self { + self.sort_order = sort_order; + self + } } #[async_trait] @@ -85,16 +96,25 @@ impl TableProvider for StreamingTable { async fn scan( &self, - _state: &dyn Session, + state: &dyn Session, projection: Option<&Vec>, _filters: &[Expr], limit: Option, ) -> Result> { + let physical_sort = if !self.sort_order.is_empty() { + let df_schema = DFSchema::try_from(Arc::clone(&self.schema))?; + let eqp = state.execution_props(); + + create_physical_sort_exprs(&self.sort_order, &df_schema, eqp)? + } else { + vec![] + }; + Ok(Arc::new(StreamingTableExec::try_new( Arc::clone(&self.schema), self.partitions.clone(), projection, - None, + LexOrdering::new(physical_sort), self.infinite, limit, )?)) diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index 207abb9c66703..11c9af01a7a54 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -49,7 +49,7 @@ use datafusion_physical_plan::ExecutionPlan; /// [`CatalogProvider`]: super::CatalogProvider #[async_trait] pub trait TableProvider: Debug + Sync + Send { - /// Returns the table provider as [`Any`](std::any::Any) so that it can be + /// Returns the table provider as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -75,7 +75,7 @@ pub trait TableProvider: Debug + Sync + Send { } /// Get the [`LogicalPlan`] of this table, if available. - fn get_logical_plan(&self) -> Option> { + fn get_logical_plan(&'_ self) -> Option> { None } @@ -171,6 +171,37 @@ pub trait TableProvider: Debug + Sync + Send { limit: Option, ) -> Result>; + /// Create an [`ExecutionPlan`] for scanning the table using structured arguments. + /// + /// This method uses [`ScanArgs`] to pass scan parameters in a structured way + /// and returns a [`ScanResult`] containing the execution plan. + /// + /// Table providers can override this method to take advantage of additional + /// parameters like the upcoming `preferred_ordering` that may not be available through + /// other scan methods. + /// + /// # Arguments + /// * `state` - The session state containing configuration and context + /// * `args` - Structured scan arguments including projection, filters, limit, and ordering preferences + /// + /// # Returns + /// A [`ScanResult`] containing the [`ExecutionPlan`] for scanning the table + /// + /// See [`Self::scan`] for detailed documentation about projection, filters, and limits. + async fn scan_with_args<'a>( + &self, + state: &dyn Session, + args: ScanArgs<'a>, + ) -> Result { + let filters = args.filters().unwrap_or(&[]); + let projection = args.projection().map(|p| p.to_vec()); + let limit = args.limit(); + let plan = self + .scan(state, projection.as_ref(), filters, limit) + .await?; + Ok(plan.into()) + } + /// Specify if DataFusion should provide filter expressions to the /// TableProvider to apply *during* the scan. /// @@ -299,6 +330,114 @@ pub trait TableProvider: Debug + Sync + Send { } } +/// Arguments for scanning a table with [`TableProvider::scan_with_args`]. +#[derive(Debug, Clone, Default)] +pub struct ScanArgs<'a> { + filters: Option<&'a [Expr]>, + projection: Option<&'a [usize]>, + limit: Option, +} + +impl<'a> ScanArgs<'a> { + /// Set the column projection for the scan. + /// + /// The projection is a list of column indices from [`TableProvider::schema`] + /// that should be included in the scan results. If `None`, all columns are included. + /// + /// # Arguments + /// * `projection` - Optional slice of column indices to project + pub fn with_projection(mut self, projection: Option<&'a [usize]>) -> Self { + self.projection = projection; + self + } + + /// Get the column projection for the scan. + /// + /// Returns a reference to the projection column indices, or `None` if + /// no projection was specified (meaning all columns should be included). + pub fn projection(&self) -> Option<&'a [usize]> { + self.projection + } + + /// Set the filter expressions for the scan. + /// + /// Filters are boolean expressions that should be evaluated during the scan + /// to reduce the number of rows returned. All expressions are combined with AND logic. + /// Whether filters are actually pushed down depends on [`TableProvider::supports_filters_pushdown`]. + /// + /// # Arguments + /// * `filters` - Optional slice of filter expressions + pub fn with_filters(mut self, filters: Option<&'a [Expr]>) -> Self { + self.filters = filters; + self + } + + /// Get the filter expressions for the scan. + /// + /// Returns a reference to the filter expressions, or `None` if no filters were specified. + pub fn filters(&self) -> Option<&'a [Expr]> { + self.filters + } + + /// Set the maximum number of rows to return from the scan. + /// + /// If specified, the scan should return at most this many rows. This is typically + /// used to optimize queries with `LIMIT` clauses. + /// + /// # Arguments + /// * `limit` - Optional maximum number of rows to return + pub fn with_limit(mut self, limit: Option) -> Self { + self.limit = limit; + self + } + + /// Get the maximum number of rows to return from the scan. + /// + /// Returns the row limit, or `None` if no limit was specified. + pub fn limit(&self) -> Option { + self.limit + } +} + +/// Result of a table scan operation from [`TableProvider::scan_with_args`]. +#[derive(Debug, Clone)] +pub struct ScanResult { + /// The ExecutionPlan to run. + plan: Arc, +} + +impl ScanResult { + /// Create a new `ScanResult` with the given execution plan. + /// + /// # Arguments + /// * `plan` - The execution plan that will perform the table scan + pub fn new(plan: Arc) -> Self { + Self { plan } + } + + /// Get a reference to the execution plan for this scan result. + /// + /// Returns a reference to the [`ExecutionPlan`] that will perform + /// the actual table scanning and data retrieval. + pub fn plan(&self) -> &Arc { + &self.plan + } + + /// Consume this ScanResult and return the execution plan. + /// + /// Returns the owned [`ExecutionPlan`] that will perform + /// the actual table scanning and data retrieval. + pub fn into_inner(self) -> Arc { + self.plan + } +} + +impl From> for ScanResult { + fn from(plan: Arc) -> Self { + Self::new(plan) + } +} + /// A factory which creates [`TableProvider`]s at runtime given a URL. /// /// For example, this can be used to create a table "on the fly" @@ -320,7 +459,7 @@ pub trait TableFunctionImpl: Debug + Sync + Send { } /// A table that uses a function to generate data -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct TableFunction { /// Name of the table function name: String, diff --git a/datafusion/catalog/src/view.rs b/datafusion/catalog/src/view.rs index 8dfb79718c9bb..89c6a4a224511 100644 --- a/datafusion/catalog/src/view.rs +++ b/datafusion/catalog/src/view.rs @@ -51,7 +51,7 @@ impl ViewTable { /// Notes: the `LogicalPlan` is not validated or type coerced. If this is /// needed it should be done after calling this function. pub fn new(logical_plan: LogicalPlan, definition: Option) -> Self { - let table_schema = logical_plan.schema().as_ref().to_owned().into(); + let table_schema = Arc::clone(logical_plan.schema().inner()); Self { logical_plan, table_schema, @@ -87,7 +87,7 @@ impl TableProvider for ViewTable { self } - fn get_logical_plan(&self) -> Option> { + fn get_logical_plan(&'_ self) -> Option> { Some(Cow::Borrowed(&self.logical_plan)) } diff --git a/datafusion/common-runtime/Cargo.toml b/datafusion/common-runtime/Cargo.toml index 5e7816b669de2..e53d97b41360a 100644 --- a/datafusion/common-runtime/Cargo.toml +++ b/datafusion/common-runtime/Cargo.toml @@ -43,4 +43,4 @@ log = { workspace = true } tokio = { workspace = true } [dev-dependencies] -tokio = { version = "1.44", features = ["rt", "rt-multi-thread", "time"] } +tokio = { workspace = true, features = ["rt", "rt-multi-thread", "time"] } diff --git a/datafusion/common-runtime/README.md b/datafusion/common-runtime/README.md index 77100e52603c9..ff44e6c3e209e 100644 --- a/datafusion/common-runtime/README.md +++ b/datafusion/common-runtime/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion Common Runtime +# Apache DataFusion Common Runtime -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion that provides common utilities. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/common-runtime/src/common.rs b/datafusion/common-runtime/src/common.rs index 361f6af95cf13..cebd6e04cd1b1 100644 --- a/datafusion/common-runtime/src/common.rs +++ b/datafusion/common-runtime/src/common.rs @@ -15,18 +15,25 @@ // specific language governing permissions and limitations // under the License. -use std::future::Future; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; -use crate::JoinSet; -use tokio::task::JoinError; +use tokio::task::{JoinError, JoinHandle}; + +use crate::trace_utils::{trace_block, trace_future}; /// Helper that provides a simple API to spawn a single task and join it. /// Provides guarantees of aborting on `Drop` to keep it cancel-safe. +/// Note that if the task was spawned with `spawn_blocking`, it will only be +/// aborted if it hasn't started yet. /// -/// Technically, it's just a wrapper of `JoinSet` (with size=1). +/// Technically, it's just a wrapper of a `JoinHandle` overriding drop. #[derive(Debug)] pub struct SpawnedTask { - inner: JoinSet, + inner: JoinHandle, } impl SpawnedTask { @@ -36,8 +43,9 @@ impl SpawnedTask { T: Send + 'static, R: Send, { - let mut inner = JoinSet::new(); - inner.spawn(task); + // Ok to use spawn here as SpawnedTask handles aborting/cancelling the task on Drop + #[allow(clippy::disallowed_methods)] + let inner = tokio::task::spawn(trace_future(task)); Self { inner } } @@ -47,29 +55,41 @@ impl SpawnedTask { T: Send + 'static, R: Send, { - let mut inner = JoinSet::new(); - inner.spawn_blocking(task); + // Ok to use spawn_blocking here as SpawnedTask handles aborting/cancelling the task on Drop + #[allow(clippy::disallowed_methods)] + let inner = tokio::task::spawn_blocking(trace_block(task)); Self { inner } } /// Joins the task, returning the result of join (`Result`). - pub async fn join(mut self) -> Result { - self.inner - .join_next() - .await - .expect("`SpawnedTask` instance always contains exactly 1 task") + /// Same as awaiting the spawned task, but left for backwards compatibility. + pub async fn join(self) -> Result { + self.await } /// Joins the task and unwinds the panic if it happens. - pub async fn join_unwind(self) -> Result { - self.join().await.map_err(|e| { + pub async fn join_unwind(mut self) -> Result { + self.join_unwind_mut().await + } + + /// Joins the task using a mutable reference and unwinds the panic if it happens. + /// + /// This method is similar to [`join_unwind`](Self::join_unwind), but takes a mutable + /// reference instead of consuming `self`. This allows the `SpawnedTask` to remain + /// usable after the call. + /// + /// If called multiple times on the same task: + /// - If the task is still running, it will continue waiting for completion + /// - If the task has already completed successfully, subsequent calls will + /// continue to return the same `JoinError` indicating the task is finished + /// - If the task panicked, the first call will resume the panic, and the + /// program will not reach subsequent calls + pub async fn join_unwind_mut(&mut self) -> Result { + self.await.map_err(|e| { // `JoinError` can be caused either by panic or cancellation. We have to handle panics: if e.is_panic() { std::panic::resume_unwind(e.into_panic()); } else { - // Cancellation may be caused by two reasons: - // 1. Abort is called, but since we consumed `self`, it's not our case (`JoinHandle` not accessible outside). - // 2. The runtime is shutting down. log::warn!("SpawnedTask was polled during shutdown"); e } @@ -77,17 +97,32 @@ impl SpawnedTask { } } +impl Future for SpawnedTask { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.inner).poll(cx) + } +} + +impl Drop for SpawnedTask { + fn drop(&mut self) { + self.inner.abort(); + } +} + #[cfg(test)] mod tests { use super::*; use std::future::{pending, Pending}; - use tokio::runtime::Runtime; + use tokio::{runtime::Runtime, sync::oneshot}; #[tokio::test] async fn runtime_shutdown() { let rt = Runtime::new().unwrap(); + #[allow(clippy::async_yields_async)] let task = rt .spawn(async { SpawnedTask::spawn(async { @@ -119,4 +154,36 @@ mod tests { .await .ok(); } + + #[tokio::test] + async fn cancel_not_started_task() { + let (sender, receiver) = oneshot::channel::(); + let task = SpawnedTask::spawn(async { + // Shouldn't be reached. + sender.send(42).unwrap(); + }); + + drop(task); + + // If the task was cancelled, the sender was also dropped, + // and awaiting the receiver should result in an error. + assert!(receiver.await.is_err()); + } + + #[tokio::test] + async fn cancel_ongoing_task() { + let (sender, mut receiver) = tokio::sync::mpsc::channel(1); + let task = SpawnedTask::spawn(async move { + sender.send(1).await.unwrap(); + // This line will never be reached because the channel has a buffer + // of 1. + sender.send(2).await.unwrap(); + }); + // Let the task start. + assert_eq!(receiver.recv().await.unwrap(), 1); + drop(task); + + // The sender was dropped so we receive `None`. + assert!(receiver.recv().await.is_none()); + } } diff --git a/datafusion/common-runtime/src/lib.rs b/datafusion/common-runtime/src/lib.rs index ec8db0bdcd911..5d404d99e7760 100644 --- a/datafusion/common-runtime/src/lib.rs +++ b/datafusion/common-runtime/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -30,4 +30,6 @@ mod trace_utils; pub use common::SpawnedTask; pub use join_set::JoinSet; -pub use trace_utils::{set_join_set_tracer, JoinSetTracer}; +pub use trace_utils::{ + set_join_set_tracer, trace_block, trace_future, JoinSetTracer, JoinSetTracerError, +}; diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 39b47a96bccf3..f5e51cb236d47 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -40,13 +40,20 @@ name = "datafusion_common" [features] avro = ["apache-avro"] backtrace = [] +parquet_encryption = [ + "parquet", + "parquet/encryption", + "dep:hex", +] pyarrow = ["pyo3", "arrow/pyarrow", "parquet"] force_hash_collisions = [] recursive_protection = ["dep:recursive"] +parquet = ["dep:parquet"] +sql = ["sqlparser"] [dependencies] ahash = { workspace = true } -apache-avro = { version = "0.17", default-features = false, features = [ +apache-avro = { version = "0.20", default-features = false, features = [ "bzip", "snappy", "xz", @@ -54,18 +61,19 @@ apache-avro = { version = "0.17", default-features = false, features = [ ], optional = true } arrow = { workspace = true } arrow-ipc = { workspace = true } -base64 = "0.22.1" +chrono = { workspace = true } half = { workspace = true } hashbrown = { workspace = true } +hex = { workspace = true, optional = true } indexmap = { workspace = true } -libc = "0.2.171" +libc = "0.2.176" log = { workspace = true } object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } paste = "1.0.15" -pyo3 = { version = "0.23.5", optional = true } +pyo3 = { version = "0.25", optional = true } recursive = { workspace = true, optional = true } -sqlparser = { workspace = true } +sqlparser = { workspace = true, optional = true } tokio = { workspace = true } [target.'cfg(target_family = "wasm")'.dependencies] @@ -75,3 +83,4 @@ web-time = "1.1.0" chrono = { workspace = true } insta = { workspace = true } rand = { workspace = true } +sqlparser = { workspace = true } diff --git a/datafusion/common/README.md b/datafusion/common/README.md index 524ab4420d2a8..4948c8c581be9 100644 --- a/datafusion/common/README.md +++ b/datafusion/common/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion Common +# Apache DataFusion Common -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion that provides common data types and utilities. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 28202c6684b50..e6eda3c585e89 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -22,8 +22,10 @@ use crate::{downcast_value, Result}; use arrow::array::{ - BinaryViewArray, Float16Array, Int16Array, Int8Array, LargeBinaryArray, - LargeStringArray, StringViewArray, UInt16Array, + BinaryViewArray, Decimal32Array, Decimal64Array, DurationMicrosecondArray, + DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array, + Int16Array, Int8Array, LargeBinaryArray, LargeStringArray, StringViewArray, + UInt16Array, }; use arrow::{ array::{ @@ -41,246 +43,282 @@ use arrow::{ datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}, }; -// Downcast ArrayRef to Date32Array +// Downcast Array to Date32Array pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> { Ok(downcast_value!(array, Date32Array)) } -// Downcast ArrayRef to Date64Array +// Downcast Array to Date64Array pub fn as_date64_array(array: &dyn Array) -> Result<&Date64Array> { Ok(downcast_value!(array, Date64Array)) } -// Downcast ArrayRef to StructArray +// Downcast Array to StructArray pub fn as_struct_array(array: &dyn Array) -> Result<&StructArray> { Ok(downcast_value!(array, StructArray)) } -// Downcast ArrayRef to Int8Array +// Downcast Array to Int8Array pub fn as_int8_array(array: &dyn Array) -> Result<&Int8Array> { Ok(downcast_value!(array, Int8Array)) } -// Downcast ArrayRef to UInt8Array +// Downcast Array to UInt8Array pub fn as_uint8_array(array: &dyn Array) -> Result<&UInt8Array> { Ok(downcast_value!(array, UInt8Array)) } -// Downcast ArrayRef to Int16Array +// Downcast Array to Int16Array pub fn as_int16_array(array: &dyn Array) -> Result<&Int16Array> { Ok(downcast_value!(array, Int16Array)) } -// Downcast ArrayRef to UInt16Array +// Downcast Array to UInt16Array pub fn as_uint16_array(array: &dyn Array) -> Result<&UInt16Array> { Ok(downcast_value!(array, UInt16Array)) } -// Downcast ArrayRef to Int32Array +// Downcast Array to Int32Array pub fn as_int32_array(array: &dyn Array) -> Result<&Int32Array> { Ok(downcast_value!(array, Int32Array)) } -// Downcast ArrayRef to UInt32Array +// Downcast Array to UInt32Array pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array> { Ok(downcast_value!(array, UInt32Array)) } -// Downcast ArrayRef to Int64Array +// Downcast Array to Int64Array pub fn as_int64_array(array: &dyn Array) -> Result<&Int64Array> { Ok(downcast_value!(array, Int64Array)) } -// Downcast ArrayRef to UInt64Array +// Downcast Array to UInt64Array pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array> { Ok(downcast_value!(array, UInt64Array)) } -// Downcast ArrayRef to Decimal128Array +// Downcast Array to Decimal32Array +pub fn as_decimal32_array(array: &dyn Array) -> Result<&Decimal32Array> { + Ok(downcast_value!(array, Decimal32Array)) +} + +// Downcast Array to Decimal64Array +pub fn as_decimal64_array(array: &dyn Array) -> Result<&Decimal64Array> { + Ok(downcast_value!(array, Decimal64Array)) +} + +// Downcast Array to Decimal128Array pub fn as_decimal128_array(array: &dyn Array) -> Result<&Decimal128Array> { Ok(downcast_value!(array, Decimal128Array)) } -// Downcast ArrayRef to Decimal256Array +// Downcast Array to Decimal256Array pub fn as_decimal256_array(array: &dyn Array) -> Result<&Decimal256Array> { Ok(downcast_value!(array, Decimal256Array)) } -// Downcast ArrayRef to Float16Array +// Downcast Array to Float16Array pub fn as_float16_array(array: &dyn Array) -> Result<&Float16Array> { Ok(downcast_value!(array, Float16Array)) } -// Downcast ArrayRef to Float32Array +// Downcast Array to Float32Array pub fn as_float32_array(array: &dyn Array) -> Result<&Float32Array> { Ok(downcast_value!(array, Float32Array)) } -// Downcast ArrayRef to Float64Array +// Downcast Array to Float64Array pub fn as_float64_array(array: &dyn Array) -> Result<&Float64Array> { Ok(downcast_value!(array, Float64Array)) } -// Downcast ArrayRef to StringArray +// Downcast Array to StringArray pub fn as_string_array(array: &dyn Array) -> Result<&StringArray> { Ok(downcast_value!(array, StringArray)) } -// Downcast ArrayRef to StringViewArray +// Downcast Array to StringViewArray pub fn as_string_view_array(array: &dyn Array) -> Result<&StringViewArray> { Ok(downcast_value!(array, StringViewArray)) } -// Downcast ArrayRef to LargeStringArray +// Downcast Array to LargeStringArray pub fn as_large_string_array(array: &dyn Array) -> Result<&LargeStringArray> { Ok(downcast_value!(array, LargeStringArray)) } -// Downcast ArrayRef to BooleanArray +// Downcast Array to BooleanArray pub fn as_boolean_array(array: &dyn Array) -> Result<&BooleanArray> { Ok(downcast_value!(array, BooleanArray)) } -// Downcast ArrayRef to ListArray +// Downcast Array to ListArray pub fn as_list_array(array: &dyn Array) -> Result<&ListArray> { Ok(downcast_value!(array, ListArray)) } -// Downcast ArrayRef to DictionaryArray +// Downcast Array to DictionaryArray pub fn as_dictionary_array( array: &dyn Array, ) -> Result<&DictionaryArray> { Ok(downcast_value!(array, DictionaryArray, T)) } -// Downcast ArrayRef to GenericBinaryArray +// Downcast Array to GenericBinaryArray pub fn as_generic_binary_array( array: &dyn Array, ) -> Result<&GenericBinaryArray> { Ok(downcast_value!(array, GenericBinaryArray, T)) } -// Downcast ArrayRef to GenericListArray +// Downcast Array to GenericListArray pub fn as_generic_list_array( array: &dyn Array, ) -> Result<&GenericListArray> { Ok(downcast_value!(array, GenericListArray, T)) } -// Downcast ArrayRef to LargeListArray +// Downcast Array to LargeListArray pub fn as_large_list_array(array: &dyn Array) -> Result<&LargeListArray> { Ok(downcast_value!(array, LargeListArray)) } -// Downcast ArrayRef to PrimitiveArray +// Downcast Array to PrimitiveArray pub fn as_primitive_array( array: &dyn Array, ) -> Result<&PrimitiveArray> { Ok(downcast_value!(array, PrimitiveArray, T)) } -// Downcast ArrayRef to MapArray +// Downcast Array to MapArray pub fn as_map_array(array: &dyn Array) -> Result<&MapArray> { Ok(downcast_value!(array, MapArray)) } -// Downcast ArrayRef to NullArray +// Downcast Array to NullArray pub fn as_null_array(array: &dyn Array) -> Result<&NullArray> { Ok(downcast_value!(array, NullArray)) } -// Downcast ArrayRef to NullArray +// Downcast Array to NullArray pub fn as_union_array(array: &dyn Array) -> Result<&UnionArray> { Ok(downcast_value!(array, UnionArray)) } -// Downcast ArrayRef to Time32SecondArray +// Downcast Array to Time32SecondArray pub fn as_time32_second_array(array: &dyn Array) -> Result<&Time32SecondArray> { Ok(downcast_value!(array, Time32SecondArray)) } -// Downcast ArrayRef to Time32MillisecondArray +// Downcast Array to Time32MillisecondArray pub fn as_time32_millisecond_array(array: &dyn Array) -> Result<&Time32MillisecondArray> { Ok(downcast_value!(array, Time32MillisecondArray)) } -// Downcast ArrayRef to Time64MicrosecondArray +// Downcast Array to Time64MicrosecondArray pub fn as_time64_microsecond_array(array: &dyn Array) -> Result<&Time64MicrosecondArray> { Ok(downcast_value!(array, Time64MicrosecondArray)) } -// Downcast ArrayRef to Time64NanosecondArray +// Downcast Array to Time64NanosecondArray pub fn as_time64_nanosecond_array(array: &dyn Array) -> Result<&Time64NanosecondArray> { Ok(downcast_value!(array, Time64NanosecondArray)) } -// Downcast ArrayRef to TimestampNanosecondArray +// Downcast Array to TimestampNanosecondArray pub fn as_timestamp_nanosecond_array( array: &dyn Array, ) -> Result<&TimestampNanosecondArray> { Ok(downcast_value!(array, TimestampNanosecondArray)) } -// Downcast ArrayRef to TimestampMillisecondArray +// Downcast Array to TimestampMillisecondArray pub fn as_timestamp_millisecond_array( array: &dyn Array, ) -> Result<&TimestampMillisecondArray> { Ok(downcast_value!(array, TimestampMillisecondArray)) } -// Downcast ArrayRef to TimestampMicrosecondArray +// Downcast Array to TimestampMicrosecondArray pub fn as_timestamp_microsecond_array( array: &dyn Array, ) -> Result<&TimestampMicrosecondArray> { Ok(downcast_value!(array, TimestampMicrosecondArray)) } -// Downcast ArrayRef to TimestampSecondArray +// Downcast Array to TimestampSecondArray pub fn as_timestamp_second_array(array: &dyn Array) -> Result<&TimestampSecondArray> { Ok(downcast_value!(array, TimestampSecondArray)) } -// Downcast ArrayRef to IntervalYearMonthArray +// Downcast Array to IntervalYearMonthArray pub fn as_interval_ym_array(array: &dyn Array) -> Result<&IntervalYearMonthArray> { Ok(downcast_value!(array, IntervalYearMonthArray)) } -// Downcast ArrayRef to IntervalDayTimeArray +// Downcast Array to IntervalDayTimeArray pub fn as_interval_dt_array(array: &dyn Array) -> Result<&IntervalDayTimeArray> { Ok(downcast_value!(array, IntervalDayTimeArray)) } -// Downcast ArrayRef to IntervalMonthDayNanoArray +// Downcast Array to IntervalMonthDayNanoArray pub fn as_interval_mdn_array(array: &dyn Array) -> Result<&IntervalMonthDayNanoArray> { Ok(downcast_value!(array, IntervalMonthDayNanoArray)) } -// Downcast ArrayRef to BinaryArray +// Downcast Array to DurationSecondArray +pub fn as_duration_second_array(array: &dyn Array) -> Result<&DurationSecondArray> { + Ok(downcast_value!(array, DurationSecondArray)) +} + +// Downcast Array to DurationMillisecondArray +pub fn as_duration_millisecond_array( + array: &dyn Array, +) -> Result<&DurationMillisecondArray> { + Ok(downcast_value!(array, DurationMillisecondArray)) +} + +// Downcast Array to DurationMicrosecondArray +pub fn as_duration_microsecond_array( + array: &dyn Array, +) -> Result<&DurationMicrosecondArray> { + Ok(downcast_value!(array, DurationMicrosecondArray)) +} + +// Downcast Array to DurationNanosecondArray +pub fn as_duration_nanosecond_array( + array: &dyn Array, +) -> Result<&DurationNanosecondArray> { + Ok(downcast_value!(array, DurationNanosecondArray)) +} + +// Downcast Array to BinaryArray pub fn as_binary_array(array: &dyn Array) -> Result<&BinaryArray> { Ok(downcast_value!(array, BinaryArray)) } -// Downcast ArrayRef to BinaryViewArray +// Downcast Array to BinaryViewArray pub fn as_binary_view_array(array: &dyn Array) -> Result<&BinaryViewArray> { Ok(downcast_value!(array, BinaryViewArray)) } -// Downcast ArrayRef to LargeBinaryArray +// Downcast Array to LargeBinaryArray pub fn as_large_binary_array(array: &dyn Array) -> Result<&LargeBinaryArray> { Ok(downcast_value!(array, LargeBinaryArray)) } -// Downcast ArrayRef to FixedSizeListArray +// Downcast Array to FixedSizeListArray pub fn as_fixed_size_list_array(array: &dyn Array) -> Result<&FixedSizeListArray> { Ok(downcast_value!(array, FixedSizeListArray)) } -// Downcast ArrayRef to FixedSizeListArray +// Downcast Array to FixedSizeBinaryArray pub fn as_fixed_size_binary_array(array: &dyn Array) -> Result<&FixedSizeBinaryArray> { Ok(downcast_value!(array, FixedSizeBinaryArray)) } -// Downcast ArrayRef to GenericBinaryArray +// Downcast Array to GenericBinaryArray pub fn as_generic_string_array( array: &dyn Array, ) -> Result<&GenericStringArray> { diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index 50a4e257d1c99..c7f0b5a4f4881 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -18,13 +18,12 @@ //! Column use crate::error::{_schema_err, add_possible_columns_to_diag}; -use crate::utils::{parse_identifiers_normalized, quote_identifier}; +use crate::utils::parse_identifiers_normalized; +use crate::utils::quote_identifier; use crate::{DFSchema, Diagnostic, Result, SchemaError, Spans, TableReference}; use arrow::datatypes::{Field, FieldRef}; use std::collections::HashSet; -use std::convert::Infallible; use std::fmt; -use std::str::FromStr; /// A named reference to a qualified field in a schema. #[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -130,8 +129,8 @@ impl Column { /// where `"foo.BAR"` would be parsed to a reference to column named `foo.BAR` pub fn from_qualified_name(flat_name: impl Into) -> Self { let flat_name = flat_name.into(); - Self::from_idents(parse_identifiers_normalized(&flat_name, false)).unwrap_or( - Self { + Self::from_idents(parse_identifiers_normalized(&flat_name, false)).unwrap_or_else( + || Self { relation: None, name: flat_name, spans: Spans::new(), @@ -140,10 +139,11 @@ impl Column { } /// Deserialize a fully qualified name string into a column preserving column text case + #[cfg(feature = "sql")] pub fn from_qualified_name_ignore_case(flat_name: impl Into) -> Self { let flat_name = flat_name.into(); - Self::from_idents(parse_identifiers_normalized(&flat_name, true)).unwrap_or( - Self { + Self::from_idents(parse_identifiers_normalized(&flat_name, true)).unwrap_or_else( + || Self { relation: None, name: flat_name, spans: Spans::new(), @@ -151,6 +151,11 @@ impl Column { ) } + #[cfg(not(feature = "sql"))] + pub fn from_qualified_name_ignore_case(flat_name: impl Into) -> Self { + Self::from_qualified_name(flat_name) + } + /// return the column's name. /// /// Note: This ignores the relation and returns the column name only. @@ -262,7 +267,7 @@ impl Column { // If not due to USING columns then due to ambiguous column name return _schema_err!(SchemaError::AmbiguousReference { - field: Column::new_unqualified(&self.name), + field: Box::new(Column::new_unqualified(&self.name)), }) .map_err(|err| { let mut diagnostic = Diagnostic::new_error( @@ -356,8 +361,9 @@ impl From<(Option<&TableReference>, &FieldRef)> for Column { } } -impl FromStr for Column { - type Err = Infallible; +#[cfg(feature = "sql")] +impl std::str::FromStr for Column { + type Err = std::convert::Infallible; fn from_str(s: &str) -> Result { Ok(s.into()) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index b0f17630c910c..39d730eaafb49 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -17,16 +17,23 @@ //! Runtime configuration, via [`ConfigOptions`] +use arrow_ipc::CompressionType; + +#[cfg(feature = "parquet_encryption")] +use crate::encryption::{FileDecryptionProperties, FileEncryptionProperties}; +use crate::error::_config_err; +use crate::format::ExplainFormat; +use crate::parsers::CompressionTypeVariant; +use crate::utils::get_available_parallelism; +use crate::{DataFusionError, Result}; use std::any::Any; use std::collections::{BTreeMap, HashMap}; use std::error::Error; use std::fmt::{self, Display}; use std::str::FromStr; -use crate::error::_config_err; -use crate::parsers::CompressionTypeVariant; -use crate::utils::get_available_parallelism; -use crate::{DataFusionError, Result}; +#[cfg(feature = "parquet_encryption")] +use hex; /// A macro that wraps a configuration struct and automatically derives /// [`Default`] and [`ConfigField`] for it, allowing it to be used @@ -149,9 +156,17 @@ macro_rules! config_namespace { // $(#[allow(deprecated)])? { $(let value = $transform(value);)? // Apply transformation if specified - $(log::warn!($warn);)? // Log warning if specified #[allow(deprecated)] - self.$field_name.set(rem, value.as_ref()) + let ret = self.$field_name.set(rem, value.as_ref()); + + $(if !$warn.is_empty() { + let default: $field_type = $default; + #[allow(deprecated)] + if default != self.$field_name { + log::warn!($warn); + } + })? // Log warning if specified, and the value is not the default + ret } }, )* @@ -252,10 +267,10 @@ config_namespace! { /// string length and thus DataFusion can not enforce such limits. pub support_varchar_with_length: bool, default = true - /// If true, `VARCHAR` is mapped to `Utf8View` during SQL planning. - /// If false, `VARCHAR` is mapped to `Utf8` during SQL planning. - /// Default is false. - pub map_varchar_to_utf8view: bool, default = false + /// If true, string types (VARCHAR, CHAR, Text, and String) are mapped to `Utf8View` during SQL planning. + /// If false, they are mapped to `Utf8`. + /// Default is true. + pub map_string_types_to_utf8view: bool, default = true /// When set to true, the source locations relative to the original SQL /// query (i.e. [`Span`](https://docs.rs/sqlparser/latest/sqlparser/tokenizer/struct.Span.html)) will be collected @@ -264,6 +279,71 @@ config_namespace! { /// Specifies the recursion depth limit when parsing complex SQL Queries pub recursion_limit: usize, default = 50 + + /// Specifies the default null ordering for query results. There are 4 options: + /// - `nulls_max`: Nulls appear last in ascending order. + /// - `nulls_min`: Nulls appear first in ascending order. + /// - `nulls_first`: Nulls always be first in any order. + /// - `nulls_last`: Nulls always be last in any order. + /// + /// By default, `nulls_max` is used to follow Postgres's behavior. + /// postgres rule: + pub default_null_ordering: String, default = "nulls_max".to_string() + } +} + +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub enum SpillCompression { + Zstd, + Lz4Frame, + #[default] + Uncompressed, +} + +impl FromStr for SpillCompression { + type Err = DataFusionError; + + fn from_str(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "zstd" => Ok(Self::Zstd), + "lz4_frame" => Ok(Self::Lz4Frame), + "uncompressed" | "" => Ok(Self::Uncompressed), + other => Err(DataFusionError::Configuration(format!( + "Invalid Spill file compression type: {other}. Expected one of: zstd, lz4_frame, uncompressed" + ))), + } + } +} + +impl ConfigField for SpillCompression { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { + v.some(key, self, description) + } + + fn set(&mut self, _: &str, value: &str) -> Result<()> { + *self = SpillCompression::from_str(value)?; + Ok(()) + } +} + +impl Display for SpillCompression { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let str = match self { + Self::Zstd => "zstd", + Self::Lz4Frame => "lz4_frame", + Self::Uncompressed => "uncompressed", + }; + write!(f, "{str}") + } +} + +impl From for Option { + fn from(c: SpillCompression) -> Self { + match c { + SpillCompression::Zstd => Some(CompressionType::ZSTD), + SpillCompression::Lz4Frame => Some(CompressionType::LZ4_FRAME), + SpillCompression::Uncompressed => None, + } } } @@ -285,20 +365,22 @@ config_namespace! { /// target batch size is determined by the configuration setting pub coalesce_batches: bool, default = true - /// Should DataFusion collect statistics after listing files - pub collect_statistics: bool, default = false + /// Should DataFusion collect statistics when first creating a table. + /// Has no effect after the table is created. Applies to the default + /// `ListingTableProvider` in DataFusion. Defaults to true. + pub collect_statistics: bool, default = true /// Number of partitions for query execution. Increasing partitions can increase /// concurrency. /// /// Defaults to the number of CPU cores on the system - pub target_partitions: usize, default = get_available_parallelism() + pub target_partitions: usize, transform = ExecutionOptions::normalized_parallelism, default = get_available_parallelism() /// The default time zone /// /// Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime /// according to this time zone, and then extract the hour - pub time_zone: Option, default = Some("+00:00".into()) + pub time_zone: String, default = "+00:00".into() /// Parquet options pub parquet: ParquetOptions, default = Default::default() @@ -308,7 +390,7 @@ config_namespace! { /// This is mostly use to plan `UNION` children in parallel. /// /// Defaults to the number of CPU cores on the system - pub planning_concurrency: usize, default = get_available_parallelism() + pub planning_concurrency: usize, transform = ExecutionOptions::normalized_parallelism, default = get_available_parallelism() /// When set to true, skips verifying that the schema produced by /// planning the input of `LogicalPlan::Aggregate` exactly matches the @@ -321,6 +403,16 @@ config_namespace! { /// the new schema verification step. pub skip_physical_aggregate_schema_check: bool, default = false + /// Sets the compression codec used when spilling data to disk. + /// + /// Since datafusion writes spill files using the Arrow IPC Stream format, + /// only codecs supported by the Arrow IPC Stream Writer are allowed. + /// Valid values are: uncompressed, lz4_frame, zstd. + /// Note: lz4_frame offers faster (de)compression, but typically results in + /// larger spill files. In contrast, zstd achieves + /// higher compression ratios at the cost of slower (de)compression speed. + pub spill_compression: SpillCompression, default = SpillCompression::Uncompressed + /// Specifies the reserved memory for each spillable sort operation to /// facilitate an in-memory merge. /// @@ -364,6 +456,11 @@ config_namespace! { /// tables (e.g. `/table/year=2021/month=01/data.parquet`). pub listing_table_ignore_subdirectory: bool, default = true + /// Should a `ListingTable` created through the `ListingTableFactory` infer table + /// partitions from Hive compliant directories. Defaults to true (partition columns are + /// inferred and will be represented in the table schema). + pub listing_table_factory_infer_partitions: bool, default = true + /// Should DataFusion support recursive CTEs pub enable_recursive_ctes: bool, default = true @@ -397,6 +494,13 @@ config_namespace! { /// in joins can reduce memory usage when joining large /// tables with a highly-selective join filter, but is also slightly slower. pub enforce_batch_size_in_joins: bool, default = false + + /// Size (bytes) of data buffer DataFusion uses when writing output files. + /// This affects the size of the data chunks that are uploaded to remote + /// object stores (e.g. AWS S3). If very large (>= 100 GiB) output files are being + /// written, it may be necessary to increase this size to avoid errors from + /// the remote end point. + pub objectstore_writer_buffer_size: usize, default = 10 * 1024 * 1024 } } @@ -451,6 +555,25 @@ config_namespace! { /// BLOB instead. pub binary_as_string: bool, default = false + /// (reading) If true, parquet reader will read columns of + /// physical type int96 as originating from a different resolution + /// than nanosecond. This is useful for reading data from systems like Spark + /// which stores microsecond resolution timestamps in an int96 allowing it + /// to write values with a larger date range than 64-bit timestamps with + /// nanosecond resolution. + pub coerce_int96: Option, transform = str::to_lowercase, default = None + + /// (reading) Use any available bloom filters when reading parquet files + pub bloom_filter_on_read: bool, default = true + + /// (reading) The maximum predicate cache size, in bytes. When + /// `pushdown_filters` is enabled, sets the maximum memory used to cache + /// the results of predicate evaluation between filter evaluation and + /// output generation. Decreasing this value will reduce memory usage, + /// but may increase IO and CPU usage. None means use the default + /// parquet reader setting. 0 means no caching. + pub max_predicate_cache_size: Option, default = None + // The following options affect writing to parquet files // and map to parquet::file::properties::WriterProperties @@ -493,13 +616,6 @@ config_namespace! { /// default parquet writer setting pub statistics_enabled: Option, transform = str::to_lowercase, default = Some("page".into()) - /// (writing) Sets max statistics size for any column. If NULL, uses - /// default parquet writer setting - /// max_statistics_size is deprecated, currently it is not being used - // TODO: remove once deprecated - #[deprecated(since = "45.0.0", note = "Setting does not do anything")] - pub max_statistics_size: Option, default = Some(4096) - /// (writing) Target maximum number of rows in each row group (defaults to 1M /// rows). Writing larger row groups requires more memory to write, but /// can get better compression and be faster to read. @@ -511,9 +627,9 @@ config_namespace! { /// (writing) Sets column index truncate length pub column_index_truncate_length: Option, default = Some(64) - /// (writing) Sets statictics truncate length. If NULL, uses + /// (writing) Sets statistics truncate length. If NULL, uses /// default parquet writer setting - pub statistics_truncate_length: Option, default = None + pub statistics_truncate_length: Option, default = Some(64) /// (writing) Sets best effort maximum number of rows in data page pub data_page_row_count_limit: usize, default = 20_000 @@ -526,9 +642,6 @@ config_namespace! { /// default parquet writer setting pub encoding: Option, transform = str::to_lowercase, default = None - /// (writing) Use any available bloom filters when reading parquet files - pub bloom_filter_on_read: bool, default = true - /// (writing) Write bloom filters for all columns when creating parquet files pub bloom_filter_on_write: bool, default = false @@ -570,6 +683,44 @@ config_namespace! { } } +config_namespace! { + /// Options for configuring Parquet Modular Encryption + /// + /// To use Parquet encryption, you must enable the `parquet_encryption` feature flag, as it is not activated by default. + pub struct ParquetEncryptionOptions { + /// Optional file decryption properties + pub file_decryption: Option, default = None + + /// Optional file encryption properties + pub file_encryption: Option, default = None + + /// Identifier for the encryption factory to use to create file encryption and decryption properties. + /// Encryption factories can be registered in the runtime environment with + /// `RuntimeEnv::register_parquet_encryption_factory`. + pub factory_id: Option, default = None + + /// Any encryption factory specific options + pub factory_options: EncryptionFactoryOptions, default = EncryptionFactoryOptions::default() + } +} + +impl ParquetEncryptionOptions { + /// Specify the encryption factory to use for Parquet modular encryption, along with its configuration + pub fn configure_factory( + &mut self, + factory_id: &str, + config: &impl ExtensionOptions, + ) { + self.factory_id = Some(factory_id.to_owned()); + self.factory_options.options.clear(); + for entry in config.entries() { + if let Some(value) = entry.value { + self.factory_options.options.insert(entry.key, value); + } + } + } +} + config_namespace! { /// Options related to query optimization /// @@ -590,6 +741,17 @@ config_namespace! { /// during aggregations, if possible pub enable_topk_aggregation: bool, default = true + /// When set to true, the optimizer will attempt to push limit operations + /// past window functions, if possible + pub enable_window_limits: bool, default = true + + /// When set to true attempts to push down dynamic filters generated by operators into the file scan phase. + /// For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer + /// will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. + /// This means that if we already have 10 timestamps in the year 2025 + /// any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. + pub enable_dynamic_filter_pushdown: bool, default = true + /// When set to true, the optimizer will insert filters before a join between /// a nullable and non-nullable column to filter out nulls on the nullable side. This /// filter can add additional overhead when the file format does not fully support @@ -616,13 +778,20 @@ config_namespace! { /// long runner execution, all types of joins may encounter out-of-memory errors. pub allow_symmetric_joins_without_pruning: bool, default = true - /// When set to `true`, file groups will be repartitioned to achieve maximum parallelism. - /// Currently Parquet and CSV formats are supported. + /// When set to `true`, datasource partitions will be repartitioned to achieve maximum parallelism. + /// This applies to both in-memory partitions and FileSource's file groups (1 group is 1 partition). + /// + /// For FileSources, only Parquet and CSV formats are currently supported. /// - /// If set to `true`, all files will be repartitioned evenly (i.e., a single large file + /// If set to `true` for a FileSource, all files will be repartitioned evenly (i.e., a single large file /// might be partitioned into smaller chunks) for parallel scanning. - /// If set to `false`, different files will be read in parallel, but repartitioning won't + /// If set to `false` for a FileSource, different files will be read in parallel, but repartitioning won't /// happen within a single file. + /// + /// If set to `true` for an in-memory source, all memtable's partitions will have their batches + /// repartitioned evenly to the desired number of `target_partitions`. Repartitioning can change + /// the total number of partitions and batches per partition, but does not slice the initial + /// record tables provided to the MemTable on creation. pub repartition_file_scans: bool, default = true /// Should DataFusion repartition data using the partitions keys to execute window @@ -719,12 +888,82 @@ config_namespace! { /// Display format of explain. Default is "indent". /// When set to "tree", it will print the plan in a tree-rendered format. - pub format: String, default = "indent".to_string() + pub format: ExplainFormat, default = ExplainFormat::Indent + + /// (format=tree only) Maximum total width of the rendered tree. + /// When set to 0, the tree will have no width limit. + pub tree_maximum_render_width: usize, default = 240 + } +} + +impl ExecutionOptions { + /// Returns the correct parallelism based on the provided `value`. + /// If `value` is `"0"`, returns the default available parallelism, computed with + /// `get_available_parallelism`. Otherwise, returns `value`. + fn normalized_parallelism(value: &str) -> String { + if value.parse::() == Ok(0) { + get_available_parallelism().to_string() + } else { + value.to_owned() + } + } +} + +config_namespace! { + /// Options controlling the format of output when printing record batches + /// Copies [`arrow::util::display::FormatOptions`] + pub struct FormatOptions { + /// If set to `true` any formatting errors will be written to the output + /// instead of being converted into a [`std::fmt::Error`] + pub safe: bool, default = true + /// Format string for nulls + pub null: String, default = "".into() + /// Date format for date arrays + pub date_format: Option, default = Some("%Y-%m-%d".to_string()) + /// Format for DateTime arrays + pub datetime_format: Option, default = Some("%Y-%m-%dT%H:%M:%S%.f".to_string()) + /// Timestamp format for timestamp arrays + pub timestamp_format: Option, default = Some("%Y-%m-%dT%H:%M:%S%.f".to_string()) + /// Timestamp format for timestamp with timezone arrays. When `None`, ISO 8601 format is used. + pub timestamp_tz_format: Option, default = None + /// Time format for time arrays + pub time_format: Option, default = Some("%H:%M:%S%.f".to_string()) + /// Duration format. Can be either `"pretty"` or `"ISO8601"` + pub duration_format: String, transform = str::to_lowercase, default = "pretty".into() + /// Show types in visual representation batches + pub types_info: bool, default = false + } +} + +impl<'a> TryInto> for &'a FormatOptions { + type Error = DataFusionError; + fn try_into(self) -> Result> { + let duration_format = match self.duration_format.as_str() { + "pretty" => arrow::util::display::DurationFormat::Pretty, + "iso8601" => arrow::util::display::DurationFormat::ISO8601, + _ => { + return _config_err!( + "Invalid duration format: {}. Valid values are pretty or iso8601", + self.duration_format + ) + } + }; + + Ok(arrow::util::display::FormatOptions::new() + .with_display_error(self.safe) + .with_null(&self.null) + .with_date_format(self.date_format.as_deref()) + .with_datetime_format(self.datetime_format.as_deref()) + .with_timestamp_format(self.timestamp_format.as_deref()) + .with_timestamp_tz_format(self.timestamp_tz_format.as_deref()) + .with_time_format(self.time_format.as_deref()) + .with_duration_format(duration_format) + .with_types_info(self.types_info)) } } /// A key value pair, with a corresponding description -#[derive(Debug)] +#[derive(Debug, Hash, PartialEq, Eq)] pub struct ConfigEntry { /// A unique string to identify this config value pub key: String, @@ -752,6 +991,8 @@ pub struct ConfigOptions { pub explain: ExplainOptions, /// Optional extensions registered using [`Extensions::insert`] pub extensions: Extensions, + /// Formatting options when printing batches + pub format: FormatOptions, } impl ConfigField for ConfigOptions { @@ -764,6 +1005,7 @@ impl ConfigField for ConfigOptions { "optimizer" => self.optimizer.set(rem, value), "explain" => self.explain.set(rem, value), "sql_parser" => self.sql_parser.set(rem, value), + "format" => self.format.set(rem, value), _ => _config_err!("Config value \"{key}\" not found on ConfigOptions"), } } @@ -774,6 +1016,7 @@ impl ConfigField for ConfigOptions { self.optimizer.visit(v, "datafusion.optimizer", ""); self.explain.visit(v, "datafusion.explain", ""); self.sql_parser.visit(v, "datafusion.sql_parser", ""); + self.format.visit(v, "datafusion.format", ""); } } @@ -805,11 +1048,23 @@ impl ConfigOptions { e.0.set(key, value) } - /// Create new ConfigOptions struct, taking values from - /// environment variables where possible. + /// Create new [`ConfigOptions`], taking values from environment variables + /// where possible. + /// + /// For example, to configure `datafusion.execution.batch_size` + /// ([`ExecutionOptions::batch_size`]) you would set the + /// `DATAFUSION_EXECUTION_BATCH_SIZE` environment variable. + /// + /// The name of the environment variable is the option's key, transformed to + /// uppercase and with periods replaced with underscores. /// - /// For example, setting `DATAFUSION_EXECUTION_BATCH_SIZE` will - /// control `datafusion.execution.batch_size`. + /// Values are parsed according to the [same rules used in casts from + /// Utf8](https://docs.rs/arrow/latest/arrow/compute/kernels/cast/fn.cast.html). + /// + /// If the value in the environment variable cannot be cast to the type of + /// the configuration option, the default value will be used instead and a + /// warning emitted. Environment variables are read when this method is + /// called, and are not re-read later. pub fn from_env() -> Result { struct Visitor(Vec); @@ -835,7 +1090,9 @@ impl ConfigOptions { for key in keys.0 { let env = key.to_uppercase().replace('.', "_"); if let Some(var) = std::env::var_os(env) { - ret.set(&key, var.to_string_lossy().as_ref())?; + let value = var.to_string_lossy(); + log::info!("Set {key} to {value} from the environment variable"); + ret.set(&key, value.as_ref())?; } } @@ -1071,7 +1328,10 @@ impl ConfigField for Option { } } -fn default_transform(input: &str) -> Result +/// Default transformation to parse a [`ConfigField`] for a string. +/// +/// This uses [`FromStr`] to parse the data. +pub fn default_config_transform(input: &str) -> Result where T: FromStr, ::Err: Sync + Send + Error + 'static, @@ -1088,19 +1348,45 @@ where }) } +/// Macro that generates [`ConfigField`] for a given type. +/// +/// # Usage +/// This always requires [`Display`] to be implemented for the given type. +/// +/// There are two ways to invoke this macro. The first one uses +/// [`default_config_transform`]/[`FromStr`] to parse the data: +/// +/// ```ignore +/// config_field(MyType); +/// ``` +/// +/// Note that the parsing error MUST implement [`std::error::Error`]! +/// +/// Or you can specify how you want to parse an [`str`] into the type: +/// +/// ```ignore +/// fn parse_it(s: &str) -> Result { +/// ... +/// } +/// +/// config_field( +/// MyType, +/// value => parse_it(value) +/// ); +/// ``` #[macro_export] macro_rules! config_field { ($t:ty) => { - config_field!($t, value => default_transform(value)?); + config_field!($t, value => $crate::config::default_config_transform(value)?); }; ($t:ty, $arg:ident => $transform:expr) => { - impl ConfigField for $t { - fn visit(&self, v: &mut V, key: &str, description: &'static str) { + impl $crate::config::ConfigField for $t { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { v.some(key, self, description) } - fn set(&mut self, _: &str, $arg: &str) -> Result<()> { + fn set(&mut self, _: &str, $arg: &str) -> $crate::error::Result<()> { *self = $transform; Ok(()) } @@ -1109,7 +1395,7 @@ macro_rules! config_field { } config_field!(String); -config_field!(bool, value => default_transform(value.to_lowercase().as_str())?); +config_field!(bool, value => default_config_transform(value.to_lowercase().as_str())?); config_field!(usize); config_field!(f64); config_field!(u64); @@ -1122,8 +1408,7 @@ impl ConfigField for u8 { fn set(&mut self, key: &str, value: &str) -> Result<()> { if value.is_empty() { return Err(DataFusionError::Configuration(format!( - "Input string for {} key is empty", - key + "Input string for {key} key is empty" ))); } // Check if the string is a valid number @@ -1135,8 +1420,7 @@ impl ConfigField for u8 { // Check if the first character is ASCII (single byte) if bytes.len() > 1 || !value.chars().next().unwrap().is_ascii() { return Err(DataFusionError::Configuration(format!( - "Error parsing {} as u8. Non-ASCII string provided", - value + "Error parsing {value} as u8. Non-ASCII string provided" ))); } *self = bytes[0]; @@ -1615,6 +1899,26 @@ pub struct TableParquetOptions { /// ) /// ``` pub key_value_metadata: HashMap>, + /// Options for configuring Parquet modular encryption + /// + /// To use Parquet encryption, you must enable the `parquet_encryption` feature flag, as it is not activated by default. + /// See ConfigFileEncryptionProperties and ConfigFileDecryptionProperties in datafusion/common/src/config.rs + /// These can be set via 'format.crypto', for example: + /// ```sql + /// OPTIONS ( + /// 'format.crypto.file_encryption.encrypt_footer' 'true', + /// 'format.crypto.file_encryption.footer_key_as_hex' '30313233343536373839303132333435', -- b"0123456789012345" */ + /// 'format.crypto.file_encryption.column_key_as_hex::double_field' '31323334353637383930313233343530', -- b"1234567890123450" + /// 'format.crypto.file_encryption.column_key_as_hex::float_field' '31323334353637383930313233343531', -- b"1234567890123451" + /// -- Same for decryption + /// 'format.crypto.file_decryption.footer_key_as_hex' '30313233343536373839303132333435', -- b"0123456789012345" + /// 'format.crypto.file_decryption.column_key_as_hex::double_field' '31323334353637383930313233343530', -- b"1234567890123450" + /// 'format.crypto.file_decryption.column_key_as_hex::float_field' '31323334353637383930313233343531', -- b"1234567890123451" + /// ) + /// ``` + /// See datafusion-cli/tests/sql/encrypted_parquet.sql for a more complete example. + /// Note that keys must be provided as in hex format since these are binary strings. + pub crypto: ParquetEncryptionOptions, } impl TableParquetOptions { @@ -1636,13 +1940,52 @@ impl TableParquetOptions { ..self } } + + /// Retrieves all configuration entries from this `TableParquetOptions`. + /// + /// # Returns + /// + /// A vector of `ConfigEntry` instances, representing all the configuration options within this + pub fn entries(self: &TableParquetOptions) -> Vec { + struct Visitor(Vec); + + impl Visit for Visitor { + fn some( + &mut self, + key: &str, + value: V, + description: &'static str, + ) { + self.0.push(ConfigEntry { + key: key[1..].to_string(), + value: Some(value.to_string()), + description, + }) + } + + fn none(&mut self, key: &str, description: &'static str) { + self.0.push(ConfigEntry { + key: key[1..].to_string(), + value: None, + description, + }) + } + } + + let mut v = Visitor(vec![]); + self.visit(&mut v, "", ""); + + v.0 + } } impl ConfigField for TableParquetOptions { fn visit(&self, v: &mut V, key_prefix: &str, description: &'static str) { self.global.visit(v, key_prefix, description); self.column_specific_options - .visit(v, key_prefix, description) + .visit(v, key_prefix, description); + self.crypto + .visit(v, &format!("{key_prefix}.crypto"), description); } fn set(&mut self, key: &str, value: &str) -> Result<()> { @@ -1663,6 +2006,8 @@ impl ConfigField for TableParquetOptions { }; self.key_value_metadata.insert(k, Some(value.into())); Ok(()) + } else if let Some(crypto_feature) = key.strip_prefix("crypto.") { + self.crypto.set(crypto_feature, value) } else if key.contains("::") { self.column_specific_options.set(key, value) } else { @@ -1803,13 +2148,358 @@ config_namespace_with_hashmap! { /// Sets bloom filter number of distinct values. If NULL, uses /// default parquet options pub bloom_filter_ndv: Option, default = None + } +} - /// Sets max statistics size for the column path. If NULL, uses - /// default parquet options - /// max_statistics_size is deprecated, currently it is not being used - // TODO: remove once deprecated - #[deprecated(since = "45.0.0", note = "Setting does not do anything")] - pub max_statistics_size: Option, default = None +#[derive(Clone, Debug, PartialEq)] +pub struct ConfigFileEncryptionProperties { + /// Should the parquet footer be encrypted + /// default is true + pub encrypt_footer: bool, + /// Key to use for the parquet footer encoded in hex format + pub footer_key_as_hex: String, + /// Metadata information for footer key + pub footer_key_metadata_as_hex: String, + /// HashMap of column names --> (key in hex format, metadata) + pub column_encryption_properties: HashMap, + /// AAD prefix string uniquely identifies the file and prevents file swapping + pub aad_prefix_as_hex: String, + /// If true, store the AAD prefix in the file + /// default is false + pub store_aad_prefix: bool, +} + +// Setup to match EncryptionPropertiesBuilder::new() +impl Default for ConfigFileEncryptionProperties { + fn default() -> Self { + ConfigFileEncryptionProperties { + encrypt_footer: true, + footer_key_as_hex: String::new(), + footer_key_metadata_as_hex: String::new(), + column_encryption_properties: Default::default(), + aad_prefix_as_hex: String::new(), + store_aad_prefix: false, + } + } +} + +config_namespace_with_hashmap! { + pub struct ColumnEncryptionProperties { + /// Per column encryption key + pub column_key_as_hex: String, default = "".to_string() + /// Per column encryption key metadata + pub column_metadata_as_hex: Option, default = None + } +} + +impl ConfigField for ConfigFileEncryptionProperties { + fn visit(&self, v: &mut V, key_prefix: &str, _description: &'static str) { + let key = format!("{key_prefix}.encrypt_footer"); + let desc = "Encrypt the footer"; + self.encrypt_footer.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.footer_key_as_hex"); + let desc = "Key to use for the parquet footer"; + self.footer_key_as_hex.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.footer_key_metadata_as_hex"); + let desc = "Metadata to use for the parquet footer"; + self.footer_key_metadata_as_hex.visit(v, key.as_str(), desc); + + self.column_encryption_properties.visit(v, key_prefix, desc); + + let key = format!("{key_prefix}.aad_prefix_as_hex"); + let desc = "AAD prefix to use"; + self.aad_prefix_as_hex.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.store_aad_prefix"); + let desc = "If true, store the AAD prefix"; + self.store_aad_prefix.visit(v, key.as_str(), desc); + + self.aad_prefix_as_hex.visit(v, key.as_str(), desc); + } + + fn set(&mut self, key: &str, value: &str) -> Result<()> { + // Any hex encoded values must be pre-encoded using + // hex::encode() before calling set. + + if key.contains("::") { + // Handle any column specific properties + return self.column_encryption_properties.set(key, value); + }; + + let (key, rem) = key.split_once('.').unwrap_or((key, "")); + match key { + "encrypt_footer" => self.encrypt_footer.set(rem, value.as_ref()), + "footer_key_as_hex" => self.footer_key_as_hex.set(rem, value.as_ref()), + "footer_key_metadata_as_hex" => { + self.footer_key_metadata_as_hex.set(rem, value.as_ref()) + } + "aad_prefix_as_hex" => self.aad_prefix_as_hex.set(rem, value.as_ref()), + "store_aad_prefix" => self.store_aad_prefix.set(rem, value.as_ref()), + _ => _config_err!( + "Config value \"{}\" not found on ConfigFileEncryptionProperties", + key + ), + } + } +} + +#[cfg(feature = "parquet_encryption")] +impl From for FileEncryptionProperties { + fn from(val: ConfigFileEncryptionProperties) -> Self { + let mut fep = FileEncryptionProperties::builder( + hex::decode(val.footer_key_as_hex).unwrap(), + ) + .with_plaintext_footer(!val.encrypt_footer) + .with_aad_prefix_storage(val.store_aad_prefix); + + if !val.footer_key_metadata_as_hex.is_empty() { + fep = fep.with_footer_key_metadata( + hex::decode(&val.footer_key_metadata_as_hex) + .expect("Invalid footer key metadata"), + ); + } + + for (column_name, encryption_props) in val.column_encryption_properties.iter() { + let encryption_key = hex::decode(&encryption_props.column_key_as_hex) + .expect("Invalid column encryption key"); + let key_metadata = encryption_props + .column_metadata_as_hex + .as_ref() + .map(|x| hex::decode(x).expect("Invalid column metadata")); + match key_metadata { + Some(key_metadata) => { + fep = fep.with_column_key_and_metadata( + column_name, + encryption_key, + key_metadata, + ); + } + None => { + fep = fep.with_column_key(column_name, encryption_key); + } + } + } + + if !val.aad_prefix_as_hex.is_empty() { + let aad_prefix: Vec = + hex::decode(&val.aad_prefix_as_hex).expect("Invalid AAD prefix"); + fep = fep.with_aad_prefix(aad_prefix); + } + fep.build().unwrap() + } +} + +#[cfg(feature = "parquet_encryption")] +impl From<&FileEncryptionProperties> for ConfigFileEncryptionProperties { + fn from(f: &FileEncryptionProperties) -> Self { + let (column_names_vec, column_keys_vec, column_metas_vec) = f.column_keys(); + + let mut column_encryption_properties: HashMap< + String, + ColumnEncryptionProperties, + > = HashMap::new(); + + for (i, column_name) in column_names_vec.iter().enumerate() { + let column_key_as_hex = hex::encode(&column_keys_vec[i]); + let column_metadata_as_hex: Option = + column_metas_vec.get(i).map(hex::encode); + column_encryption_properties.insert( + column_name.clone(), + ColumnEncryptionProperties { + column_key_as_hex, + column_metadata_as_hex, + }, + ); + } + let mut aad_prefix: Vec = Vec::new(); + if let Some(prefix) = f.aad_prefix() { + aad_prefix = prefix.clone(); + } + ConfigFileEncryptionProperties { + encrypt_footer: f.encrypt_footer(), + footer_key_as_hex: hex::encode(f.footer_key()), + footer_key_metadata_as_hex: f + .footer_key_metadata() + .map(hex::encode) + .unwrap_or_default(), + column_encryption_properties, + aad_prefix_as_hex: hex::encode(aad_prefix), + store_aad_prefix: f.store_aad_prefix(), + } + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct ConfigFileDecryptionProperties { + /// Binary string to use for the parquet footer encoded in hex format + pub footer_key_as_hex: String, + /// HashMap of column names --> key in hex format + pub column_decryption_properties: HashMap, + /// AAD prefix string uniquely identifies the file and prevents file swapping + pub aad_prefix_as_hex: String, + /// If true, then verify signature for files with plaintext footers. + /// default = true + pub footer_signature_verification: bool, +} + +config_namespace_with_hashmap! { + pub struct ColumnDecryptionProperties { + /// Per column encryption key + pub column_key_as_hex: String, default = "".to_string() + } +} + +// Setup to match DecryptionPropertiesBuilder::new() +impl Default for ConfigFileDecryptionProperties { + fn default() -> Self { + ConfigFileDecryptionProperties { + footer_key_as_hex: String::new(), + column_decryption_properties: Default::default(), + aad_prefix_as_hex: String::new(), + footer_signature_verification: true, + } + } +} + +impl ConfigField for ConfigFileDecryptionProperties { + fn visit(&self, v: &mut V, key_prefix: &str, _description: &'static str) { + let key = format!("{key_prefix}.footer_key_as_hex"); + let desc = "Key to use for the parquet footer"; + self.footer_key_as_hex.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.aad_prefix_as_hex"); + let desc = "AAD prefix to use"; + self.aad_prefix_as_hex.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.footer_signature_verification"); + let desc = "If true, verify the footer signature"; + self.footer_signature_verification + .visit(v, key.as_str(), desc); + + self.column_decryption_properties.visit(v, key_prefix, desc); + } + + fn set(&mut self, key: &str, value: &str) -> Result<()> { + // Any hex encoded values must be pre-encoded using + // hex::encode() before calling set. + + if key.contains("::") { + // Handle any column specific properties + return self.column_decryption_properties.set(key, value); + }; + + let (key, rem) = key.split_once('.').unwrap_or((key, "")); + match key { + "footer_key_as_hex" => self.footer_key_as_hex.set(rem, value.as_ref()), + "aad_prefix_as_hex" => self.aad_prefix_as_hex.set(rem, value.as_ref()), + "footer_signature_verification" => { + self.footer_signature_verification.set(rem, value.as_ref()) + } + _ => _config_err!( + "Config value \"{}\" not found on ConfigFileEncryptionProperties", + key + ), + } + } +} + +#[cfg(feature = "parquet_encryption")] +impl From for FileDecryptionProperties { + fn from(val: ConfigFileDecryptionProperties) -> Self { + let mut column_names: Vec<&str> = Vec::new(); + let mut column_keys: Vec> = Vec::new(); + + for (col_name, decryption_properties) in val.column_decryption_properties.iter() { + column_names.push(col_name.as_str()); + column_keys.push( + hex::decode(&decryption_properties.column_key_as_hex) + .expect("Invalid column decryption key"), + ); + } + + let mut fep = FileDecryptionProperties::builder( + hex::decode(val.footer_key_as_hex).expect("Invalid footer key"), + ) + .with_column_keys(column_names, column_keys) + .unwrap(); + + if !val.footer_signature_verification { + fep = fep.disable_footer_signature_verification(); + } + + if !val.aad_prefix_as_hex.is_empty() { + let aad_prefix = + hex::decode(&val.aad_prefix_as_hex).expect("Invalid AAD prefix"); + fep = fep.with_aad_prefix(aad_prefix); + } + + fep.build().unwrap() + } +} + +#[cfg(feature = "parquet_encryption")] +impl From<&FileDecryptionProperties> for ConfigFileDecryptionProperties { + fn from(f: &FileDecryptionProperties) -> Self { + let (column_names_vec, column_keys_vec) = f.column_keys(); + let mut column_decryption_properties: HashMap< + String, + ColumnDecryptionProperties, + > = HashMap::new(); + for (i, column_name) in column_names_vec.iter().enumerate() { + let props = ColumnDecryptionProperties { + column_key_as_hex: hex::encode(column_keys_vec[i].clone()), + }; + column_decryption_properties.insert(column_name.clone(), props); + } + + let mut aad_prefix: Vec = Vec::new(); + if let Some(prefix) = f.aad_prefix() { + aad_prefix = prefix.clone(); + } + ConfigFileDecryptionProperties { + footer_key_as_hex: hex::encode( + f.footer_key(None).unwrap_or_default().as_ref(), + ), + column_decryption_properties, + aad_prefix_as_hex: hex::encode(aad_prefix), + footer_signature_verification: f.check_plaintext_footer_integrity(), + } + } +} + +/// Holds implementation-specific options for an encryption factory +#[derive(Clone, Debug, Default, PartialEq)] +pub struct EncryptionFactoryOptions { + pub options: HashMap, +} + +impl ConfigField for EncryptionFactoryOptions { + fn visit(&self, v: &mut V, key: &str, _description: &'static str) { + for (option_key, option_value) in &self.options { + v.some( + &format!("{key}.{option_key}"), + option_value, + "Encryption factory specific option", + ); + } + } + + fn set(&mut self, key: &str, value: &str) -> Result<()> { + self.options.insert(key.to_owned(), value.to_owned()); + Ok(()) + } +} + +impl EncryptionFactoryOptions { + /// Convert these encryption factory options to an [`ExtensionOptions`] instance. + pub fn to_extension_options(&self) -> Result { + let mut options = T::default(); + for (key, value) in &self.options { + options.set(key, value)?; + } + Ok(options) } } @@ -1845,6 +2535,16 @@ config_namespace! { // The input regex for Nulls when loading CSVs. pub null_regex: Option, default = None pub comment: Option, default = None + /// Whether to allow truncated rows when parsing, both within a single file and across files. + /// + /// When set to false (default), reading a single CSV file which has rows of different lengths will + /// error; if reading multiple CSV files with different number of columns, it will also fail. + /// + /// When set to true, reading a single CSV file with rows of different lengths will pad the truncated + /// rows with null values for the missing columns; if reading multiple CSV files with different number + /// of columns, it creates a union schema containing all columns found across the files, and will + /// pad any files missing columns with null values for their rows. + pub truncated_rows: Option, default = None } } @@ -1937,6 +2637,15 @@ impl CsvOptions { self } + /// Whether to allow truncated rows when parsing. + /// By default this is set to false and will error if the CSV rows have different lengths. + /// When set to true then it will allow records with less than the expected number of columns and fill the missing columns with nulls. + /// If the record’s schema is not nullable, then it will still return an error. + pub fn with_truncated_rows(mut self, allow: bool) -> Self { + self.truncated_rows = Some(allow); + self + } + /// The delimiter character. pub fn delimiter(&self) -> u8 { self.delimiter @@ -1966,11 +2675,11 @@ config_namespace! { } } -pub trait FormatOptionsExt: Display {} +pub trait OutputFormatExt: Display {} #[derive(Debug, Clone, PartialEq)] #[allow(clippy::large_enum_variant)] -pub enum FormatOptions { +pub enum OutputFormat { CSV(CsvOptions), JSON(JsonOptions), #[cfg(feature = "parquet")] @@ -1979,29 +2688,30 @@ pub enum FormatOptions { ARROW, } -impl Display for FormatOptions { +impl Display for OutputFormat { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let out = match self { - FormatOptions::CSV(_) => "csv", - FormatOptions::JSON(_) => "json", + OutputFormat::CSV(_) => "csv", + OutputFormat::JSON(_) => "json", #[cfg(feature = "parquet")] - FormatOptions::PARQUET(_) => "parquet", - FormatOptions::AVRO => "avro", - FormatOptions::ARROW => "arrow", + OutputFormat::PARQUET(_) => "parquet", + OutputFormat::AVRO => "avro", + OutputFormat::ARROW => "arrow", }; - write!(f, "{}", out) + write!(f, "{out}") } } #[cfg(test)] mod tests { - use std::any::Any; - use std::collections::HashMap; - + #[cfg(feature = "parquet")] + use crate::config::TableParquetOptions; use crate::config::{ - ConfigEntry, ConfigExtension, ConfigFileType, ExtensionOptions, Extensions, - TableOptions, + ConfigEntry, ConfigExtension, ConfigField, ConfigFileType, ExtensionOptions, + Extensions, TableOptions, }; + use std::any::Any; + use std::collections::HashMap; #[derive(Default, Debug, Clone)] pub struct TestExtensionConfig { @@ -2085,6 +2795,37 @@ mod tests { assert_eq!(table_config.csv.escape.unwrap() as char, '\''); } + #[test] + fn warning_only_not_default() { + use std::sync::atomic::AtomicUsize; + static COUNT: AtomicUsize = AtomicUsize::new(0); + use log::{Level, LevelFilter, Metadata, Record}; + struct SimpleLogger; + impl log::Log for SimpleLogger { + fn enabled(&self, metadata: &Metadata) -> bool { + metadata.level() <= Level::Info + } + + fn log(&self, record: &Record) { + if self.enabled(record.metadata()) { + COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + } + fn flush(&self) {} + } + log::set_logger(&SimpleLogger).unwrap(); + log::set_max_level(LevelFilter::Info); + let mut sql_parser_options = crate::config::SqlParserOptions::default(); + sql_parser_options + .set("enable_options_value_normalization", "false") + .unwrap(); + assert_eq!(COUNT.load(std::sync::atomic::Ordering::Relaxed), 0); + sql_parser_options + .set("enable_options_value_normalization", "true") + .unwrap(); + assert_eq!(COUNT.load(std::sync::atomic::Ordering::Relaxed), 1); + } + #[cfg(feature = "parquet")] #[test] fn parquet_table_options() { @@ -2099,6 +2840,159 @@ mod tests { ); } + #[cfg(feature = "parquet_encryption")] + #[test] + fn parquet_table_encryption() { + use crate::config::{ + ConfigFileDecryptionProperties, ConfigFileEncryptionProperties, + }; + use parquet::encryption::decrypt::FileDecryptionProperties; + use parquet::encryption::encrypt::FileEncryptionProperties; + + let footer_key = b"0123456789012345".to_vec(); // 128bit/16 + let column_names = vec!["double_field", "float_field"]; + let column_keys = + vec![b"1234567890123450".to_vec(), b"1234567890123451".to_vec()]; + + let file_encryption_properties = + FileEncryptionProperties::builder(footer_key.clone()) + .with_column_keys(column_names.clone(), column_keys.clone()) + .unwrap() + .build() + .unwrap(); + + let decryption_properties = FileDecryptionProperties::builder(footer_key.clone()) + .with_column_keys(column_names.clone(), column_keys.clone()) + .unwrap() + .build() + .unwrap(); + + // Test round-trip + let config_encrypt: ConfigFileEncryptionProperties = + (&file_encryption_properties).into(); + let encryption_properties_built: FileEncryptionProperties = + config_encrypt.clone().into(); + assert_eq!(file_encryption_properties, encryption_properties_built); + + let config_decrypt: ConfigFileDecryptionProperties = + (&decryption_properties).into(); + let decryption_properties_built: FileDecryptionProperties = + config_decrypt.clone().into(); + assert_eq!(decryption_properties, decryption_properties_built); + + /////////////////////////////////////////////////////////////////////////////////// + // Test encryption config + + // Display original encryption config + // println!("{:#?}", config_encrypt); + + let mut table_config = TableOptions::new(); + table_config.set_config_format(ConfigFileType::PARQUET); + table_config + .parquet + .set( + "crypto.file_encryption.encrypt_footer", + config_encrypt.encrypt_footer.to_string().as_str(), + ) + .unwrap(); + table_config + .parquet + .set( + "crypto.file_encryption.footer_key_as_hex", + config_encrypt.footer_key_as_hex.as_str(), + ) + .unwrap(); + + for (i, col_name) in column_names.iter().enumerate() { + let key = format!("crypto.file_encryption.column_key_as_hex::{col_name}"); + let value = hex::encode(column_keys[i].clone()); + table_config + .parquet + .set(key.as_str(), value.as_str()) + .unwrap(); + } + + // Print matching final encryption config + // println!("{:#?}", table_config.parquet.crypto.file_encryption); + + assert_eq!( + table_config.parquet.crypto.file_encryption, + Some(config_encrypt) + ); + + /////////////////////////////////////////////////////////////////////////////////// + // Test decryption config + + // Display original decryption config + // println!("{:#?}", config_decrypt); + + let mut table_config = TableOptions::new(); + table_config.set_config_format(ConfigFileType::PARQUET); + table_config + .parquet + .set( + "crypto.file_decryption.footer_key_as_hex", + config_decrypt.footer_key_as_hex.as_str(), + ) + .unwrap(); + + for (i, col_name) in column_names.iter().enumerate() { + let key = format!("crypto.file_decryption.column_key_as_hex::{col_name}"); + let value = hex::encode(column_keys[i].clone()); + table_config + .parquet + .set(key.as_str(), value.as_str()) + .unwrap(); + } + + // Print matching final decryption config + // println!("{:#?}", table_config.parquet.crypto.file_decryption); + + assert_eq!( + table_config.parquet.crypto.file_decryption, + Some(config_decrypt.clone()) + ); + + // Set config directly + let mut table_config = TableOptions::new(); + table_config.set_config_format(ConfigFileType::PARQUET); + table_config.parquet.crypto.file_decryption = Some(config_decrypt.clone()); + assert_eq!( + table_config.parquet.crypto.file_decryption, + Some(config_decrypt.clone()) + ); + } + + #[cfg(feature = "parquet_encryption")] + #[test] + fn parquet_encryption_factory_config() { + let mut parquet_options = TableParquetOptions::default(); + + assert_eq!(parquet_options.crypto.factory_id, None); + assert_eq!(parquet_options.crypto.factory_options.options.len(), 0); + + let mut input_config = TestExtensionConfig::default(); + input_config + .properties + .insert("key1".to_string(), "value 1".to_string()); + input_config + .properties + .insert("key2".to_string(), "value 2".to_string()); + + parquet_options + .crypto + .configure_factory("example_factory", &input_config); + + assert_eq!( + parquet_options.crypto.factory_id, + Some("example_factory".to_string()) + ); + let factory_options = &parquet_options.crypto.factory_options.options; + assert_eq!(factory_options.len(), 2); + assert_eq!(factory_options.get("key1"), Some(&"value 1".to_string())); + assert_eq!(factory_options.get("key2"), Some(&"value 2".to_string())); + } + #[cfg(feature = "parquet")] #[test] fn parquet_table_options_config_entry() { @@ -2113,6 +3007,23 @@ mod tests { .any(|item| item.key == "format.bloom_filter_enabled::col1")) } + #[cfg(feature = "parquet")] + #[test] + fn parquet_table_parquet_options_config_entry() { + let mut table_parquet_options = TableParquetOptions::new(); + table_parquet_options + .set( + "crypto.file_encryption.column_key_as_hex::double_field", + "31323334353637383930313233343530", + ) + .unwrap(); + let entries = table_parquet_options.entries(); + assert!(entries + .iter() + .any(|item| item.key + == "crypto.file_encryption.column_key_as_hex::double_field")) + } + #[cfg(feature = "parquet")] #[test] fn parquet_table_options_config_metadata_entry() { diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 43d082f9dc936..6866b4011f9ec 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -101,7 +101,7 @@ pub type DFSchemaRef = Arc; /// let df_schema = DFSchema::from_unqualified_fields(vec![ /// Field::new("c1", arrow::datatypes::DataType::Int32, false), /// ].into(),HashMap::new()).unwrap(); -/// let schema = Schema::from(df_schema); +/// let schema: &Schema = df_schema.as_arrow(); /// assert_eq!(schema.fields().len(), 1); /// ``` #[derive(Debug, Clone, PartialEq, Eq)] @@ -206,6 +206,25 @@ impl DFSchema { Ok(dfschema) } + /// Return the same schema, where all fields have a given qualifier. + pub fn with_field_specific_qualified_schema( + &self, + qualifiers: Vec>, + ) -> Result { + if qualifiers.len() != self.fields().len() { + return _plan_err!( + "Number of qualifiers must match number of fields. Expected {}, got {}", + self.fields().len(), + qualifiers.len() + ); + } + Ok(DFSchema { + inner: Arc::clone(&self.inner), + field_qualifiers: qualifiers, + functional_dependencies: self.functional_dependencies.clone(), + }) + } + /// Check if the schema have some fields with the same name pub fn check_names(&self) -> Result<()> { let mut qualified_names = BTreeSet::new(); @@ -229,7 +248,7 @@ impl DFSchema { for (qualifier, name) in qualified_names { if unqualified_names.contains(name) { return _schema_err!(SchemaError::AmbiguousReference { - field: Column::new(Some(qualifier.clone()), name) + field: Box::new(Column::new(Some(qualifier.clone()), name)) }); } } @@ -278,6 +297,20 @@ impl DFSchema { /// Modify this schema by appending the fields from the supplied schema, ignoring any /// duplicate fields. + /// + /// ## Merge Precedence + /// + /// **Schema-level metadata**: Metadata from both schemas is merged. + /// If both schemas have the same metadata key, the value from the `other_schema` parameter takes precedence. + /// + /// **Field-level merging**: Only non-duplicate fields are added. This means that the + /// `self` fields will always take precedence over the `other_schema` fields. + /// Duplicate field detection is based on: + /// - For qualified fields: both qualifier and field name must match + /// - For unqualified fields: only field name needs to match + /// + /// Take note how the precedence for fields & metadata merging differs; + /// merging prefers fields from `self` but prefers metadata from `other_schema`. pub fn merge(&mut self, other_schema: &DFSchema) { if other_schema.inner.fields.is_empty() { return; @@ -472,7 +505,7 @@ impl DFSchema { let matches = self.qualified_fields_with_unqualified_name(name); match matches.len() { 0 => Err(unqualified_field_not_found(name, self)), - 1 => Ok((matches[0].0, (matches[0].1))), + 1 => Ok((matches[0].0, matches[0].1)), _ => { // When `matches` size > 1, it doesn't necessarily mean an `ambiguous name` problem. // Because name may generate from Alias/... . It means that it don't own qualifier. @@ -489,7 +522,7 @@ impl DFSchema { Ok((fields_without_qualifier[0].0, fields_without_qualifier[0].1)) } else { _schema_err!(SchemaError::AmbiguousReference { - field: Column::new_unqualified(name.to_string(),), + field: Box::new(Column::new_unqualified(name.to_string())) }) } } @@ -515,14 +548,6 @@ impl DFSchema { Ok(self.field(idx)) } - /// Find the field with the given qualified column - pub fn field_from_column(&self, column: &Column) -> Result<&Field> { - match &column.relation { - Some(r) => self.field_with_qualified_name(r, &column.name), - None => self.field_with_unqualified_name(&column.name), - } - } - /// Find the field with the given qualified column pub fn qualified_field_from_column( &self, @@ -569,7 +594,7 @@ impl DFSchema { &self, arrow_schema: &Schema, ) -> Result<()> { - let self_arrow_schema: Schema = self.into(); + let self_arrow_schema = self.as_arrow(); self_arrow_schema .fields() .iter() @@ -641,11 +666,11 @@ impl DFSchema { || (!DFSchema::datatype_is_semantically_equal( f1.data_type(), f2.data_type(), - ) && !can_cast_types(f2.data_type(), f1.data_type())) + )) { _plan_err!( - "Schema mismatch: Expected field '{}' with type {:?}, \ - but got '{}' with type {:?}.", + "Schema mismatch: Expected field '{}' with type {}, \ + but got '{}' with type {}.", f1.name(), f1.data_type(), f2.name(), @@ -659,9 +684,12 @@ impl DFSchema { } /// Checks if two [`DataType`]s are logically equal. This is a notably weaker constraint - /// than datatype_is_semantically_equal in that a Dictionary type is logically - /// equal to a plain V type, but not semantically equal. Dictionary is also - /// logically equal to Dictionary. + /// than datatype_is_semantically_equal in that different representations of same data can be + /// logically but not semantically equivalent. Semantically equivalent types are always also + /// logically equivalent. For example: + /// - a Dictionary type is logically equal to a plain V type + /// - a Dictionary is also logically equal to Dictionary + /// - Utf8 and Utf8View are logically equal pub fn datatype_is_logically_equal(dt1: &DataType, dt2: &DataType) -> bool { // check nested fields match (dt1, dt2) { @@ -711,12 +739,16 @@ impl DFSchema { .zip(iter2) .all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_logically_equal(f1, f2)) } - _ => dt1 == dt2, + // Utf8 and Utf8View are logically equivalent + (DataType::Utf8, DataType::Utf8View) => true, + (DataType::Utf8View, DataType::Utf8) => true, + _ => Self::datatype_is_semantically_equal(dt1, dt2), } } /// Returns true of two [`DataType`]s are semantically equal (same - /// name and type), ignoring both metadata and nullability. + /// name and type), ignoring both metadata and nullability, decimal precision/scale, + /// and timezone time units/timezones. /// /// request to upstream: pub fn datatype_is_semantically_equal(dt1: &DataType, dt2: &DataType) -> bool { @@ -767,6 +799,14 @@ impl DFSchema { .zip(iter2) .all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_semantically_equal(f1, f2)) } + ( + DataType::Decimal32(_l_precision, _l_scale), + DataType::Decimal32(_r_precision, _r_scale), + ) => true, + ( + DataType::Decimal64(_l_precision, _l_scale), + DataType::Decimal64(_r_precision, _r_scale), + ) => true, ( DataType::Decimal128(_l_precision, _l_scale), DataType::Decimal128(_r_precision, _r_scale), @@ -775,6 +815,10 @@ impl DFSchema { DataType::Decimal256(_l_precision, _l_scale), DataType::Decimal256(_r_precision, _r_scale), ) => true, + ( + DataType::Timestamp(_l_time_unit, _l_timezone), + DataType::Timestamp(_r_time_unit, _r_timezone), + ) => true, _ => dt1 == dt2, } } @@ -832,21 +876,213 @@ impl DFSchema { .zip(self.inner.fields().iter()) .map(|(qualifier, field)| (qualifier.as_ref(), field)) } + /// Returns a tree-like string representation of the schema. + /// + /// This method formats the schema + /// with a tree-like structure showing field names, types, and nullability. + /// + /// # Example + /// + /// ``` + /// use datafusion_common::DFSchema; + /// use arrow::datatypes::{DataType, Field, Schema}; + /// use std::collections::HashMap; + /// + /// let schema = DFSchema::from_unqualified_fields( + /// vec![ + /// Field::new("id", DataType::Int32, false), + /// Field::new("name", DataType::Utf8, true), + /// ].into(), + /// HashMap::new() + /// ).unwrap(); + /// + /// assert_eq!(schema.tree_string().to_string(), + /// r#"root + /// |-- id: int32 (nullable = false) + /// |-- name: utf8 (nullable = true)"#); + /// ``` + pub fn tree_string(&self) -> impl Display + '_ { + let mut result = String::from("root\n"); + + for (qualifier, field) in self.iter() { + let field_name = match qualifier { + Some(q) => format!("{}.{}", q, field.name()), + None => field.name().to_string(), + }; + + format_field_with_indent( + &mut result, + &field_name, + field.data_type(), + field.is_nullable(), + " ", + ); + } + + // Remove the trailing newline + if result.ends_with('\n') { + result.pop(); + } + + result + } } -impl From for Schema { - /// Convert DFSchema into a Schema - fn from(df_schema: DFSchema) -> Self { - let fields: Fields = df_schema.inner.fields.clone(); - Schema::new_with_metadata(fields, df_schema.inner.metadata.clone()) +/// Format field with proper nested indentation for complex types +fn format_field_with_indent( + result: &mut String, + field_name: &str, + data_type: &DataType, + nullable: bool, + indent: &str, +) { + let nullable_str = nullable.to_string().to_lowercase(); + let child_indent = format!("{indent}| "); + + match data_type { + DataType::List(field) => { + result.push_str(&format!( + "{indent}|-- {field_name}: list (nullable = {nullable_str})\n" + )); + format_field_with_indent( + result, + field.name(), + field.data_type(), + field.is_nullable(), + &child_indent, + ); + } + DataType::LargeList(field) => { + result.push_str(&format!( + "{indent}|-- {field_name}: large list (nullable = {nullable_str})\n" + )); + format_field_with_indent( + result, + field.name(), + field.data_type(), + field.is_nullable(), + &child_indent, + ); + } + DataType::FixedSizeList(field, _size) => { + result.push_str(&format!( + "{indent}|-- {field_name}: fixed size list (nullable = {nullable_str})\n" + )); + format_field_with_indent( + result, + field.name(), + field.data_type(), + field.is_nullable(), + &child_indent, + ); + } + DataType::Map(field, _) => { + result.push_str(&format!( + "{indent}|-- {field_name}: map (nullable = {nullable_str})\n" + )); + if let DataType::Struct(inner_fields) = field.data_type() { + if inner_fields.len() == 2 { + format_field_with_indent( + result, + "key", + inner_fields[0].data_type(), + inner_fields[0].is_nullable(), + &child_indent, + ); + let value_contains_null = + field.is_nullable().to_string().to_lowercase(); + // Handle complex value types properly + match inner_fields[1].data_type() { + DataType::Struct(_) + | DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Map(_, _) => { + format_field_with_indent( + result, + "value", + inner_fields[1].data_type(), + inner_fields[1].is_nullable(), + &child_indent, + ); + } + _ => { + result.push_str(&format!("{child_indent}|-- value: {} (nullable = {value_contains_null})\n", + format_simple_data_type(inner_fields[1].data_type()))); + } + } + } + } + } + DataType::Struct(fields) => { + result.push_str(&format!( + "{indent}|-- {field_name}: struct (nullable = {nullable_str})\n" + )); + for struct_field in fields { + format_field_with_indent( + result, + struct_field.name(), + struct_field.data_type(), + struct_field.is_nullable(), + &child_indent, + ); + } + } + _ => { + let type_str = format_simple_data_type(data_type); + result.push_str(&format!( + "{indent}|-- {field_name}: {type_str} (nullable = {nullable_str})\n" + )); + } } } -impl From<&DFSchema> for Schema { - /// Convert DFSchema reference into a Schema - fn from(df_schema: &DFSchema) -> Self { - let fields: Fields = df_schema.inner.fields.clone(); - Schema::new_with_metadata(fields, df_schema.inner.metadata.clone()) +/// Format simple DataType in lowercase format (for leaf nodes) +fn format_simple_data_type(data_type: &DataType) -> String { + match data_type { + DataType::Boolean => "boolean".to_string(), + DataType::Int8 => "int8".to_string(), + DataType::Int16 => "int16".to_string(), + DataType::Int32 => "int32".to_string(), + DataType::Int64 => "int64".to_string(), + DataType::UInt8 => "uint8".to_string(), + DataType::UInt16 => "uint16".to_string(), + DataType::UInt32 => "uint32".to_string(), + DataType::UInt64 => "uint64".to_string(), + DataType::Float16 => "float16".to_string(), + DataType::Float32 => "float32".to_string(), + DataType::Float64 => "float64".to_string(), + DataType::Utf8 => "utf8".to_string(), + DataType::LargeUtf8 => "large_utf8".to_string(), + DataType::Binary => "binary".to_string(), + DataType::LargeBinary => "large_binary".to_string(), + DataType::FixedSizeBinary(_) => "fixed_size_binary".to_string(), + DataType::Date32 => "date32".to_string(), + DataType::Date64 => "date64".to_string(), + DataType::Time32(_) => "time32".to_string(), + DataType::Time64(_) => "time64".to_string(), + DataType::Timestamp(_, tz) => match tz { + Some(tz_str) => format!("timestamp ({tz_str})"), + None => "timestamp".to_string(), + }, + DataType::Interval(_) => "interval".to_string(), + DataType::Dictionary(_, value_type) => { + format_simple_data_type(value_type.as_ref()) + } + DataType::Decimal32(precision, scale) => { + format!("decimal32({precision}, {scale})") + } + DataType::Decimal64(precision, scale) => { + format!("decimal64({precision}, {scale})") + } + DataType::Decimal128(precision, scale) => { + format!("decimal128({precision}, {scale})") + } + DataType::Decimal256(precision, scale) => { + format!("decimal256({precision}, {scale})") + } + DataType::Null => "null".to_string(), + _ => format!("{data_type}").to_lowercase(), } } @@ -882,16 +1118,15 @@ impl TryFrom for DFSchema { field_qualifiers: vec![None; field_count], functional_dependencies: FunctionalDependencies::empty(), }; + // Without checking names, because schema here may have duplicate field names. + // For example, Partial AggregateMode will generate duplicate field names from + // state_fields. + // See + // dfschema.check_names()?; Ok(dfschema) } } -impl From for SchemaRef { - fn from(df_schema: DFSchema) -> Self { - SchemaRef::new(df_schema.into()) - } -} - // Hashing refers to a subset of fields considered in PartialEq. impl Hash for DFSchema { fn hash(&self, state: &mut H) { @@ -963,16 +1198,28 @@ impl Display for DFSchema { /// widely used in the DataFusion codebase. pub trait ExprSchema: std::fmt::Debug { /// Is this column reference nullable? - fn nullable(&self, col: &Column) -> Result; + fn nullable(&self, col: &Column) -> Result { + Ok(self.field_from_column(col)?.is_nullable()) + } /// What is the datatype of this column? - fn data_type(&self, col: &Column) -> Result<&DataType>; + fn data_type(&self, col: &Column) -> Result<&DataType> { + Ok(self.field_from_column(col)?.data_type()) + } /// Returns the column's optional metadata. - fn metadata(&self, col: &Column) -> Result<&HashMap>; + fn metadata(&self, col: &Column) -> Result<&HashMap> { + Ok(self.field_from_column(col)?.metadata()) + } /// Return the column's datatype and nullability - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)>; + fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { + let field = self.field_from_column(col)?; + Ok((field.data_type(), field.is_nullable())) + } + + // Return the column's field + fn field_from_column(&self, col: &Column) -> Result<&Field>; } // Implement `ExprSchema` for `Arc` @@ -992,24 +1239,18 @@ impl + std::fmt::Debug> ExprSchema for P { fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { self.as_ref().data_type_and_nullable(col) } -} - -impl ExprSchema for DFSchema { - fn nullable(&self, col: &Column) -> Result { - Ok(self.field_from_column(col)?.is_nullable()) - } - - fn data_type(&self, col: &Column) -> Result<&DataType> { - Ok(self.field_from_column(col)?.data_type()) - } - fn metadata(&self, col: &Column) -> Result<&HashMap> { - Ok(self.field_from_column(col)?.metadata()) + fn field_from_column(&self, col: &Column) -> Result<&Field> { + self.as_ref().field_from_column(col) } +} - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { - let field = self.field_from_column(col)?; - Ok((field.data_type(), field.is_nullable())) +impl ExprSchema for DFSchema { + fn field_from_column(&self, col: &Column) -> Result<&Field> { + match &col.relation { + Some(r) => self.field_with_qualified_name(r, &col.name), + None => self.field_with_unqualified_name(&col.name), + } } } @@ -1068,8 +1309,8 @@ impl SchemaExt for Schema { .try_for_each(|(f1, f2)| { if f1.name() != f2.name() || (!DFSchema::datatype_is_logically_equal(f1.data_type(), f2.data_type()) && !can_cast_types(f2.data_type(), f1.data_type())) { _plan_err!( - "Inserting query schema mismatch: Expected table field '{}' with type {:?}, \ - but got '{}' with type {:?}.", + "Inserting query schema mismatch: Expected table field '{}' with type {}, \ + but got '{}' with type {}.", f1.name(), f1.data_type(), f2.name(), @@ -1084,7 +1325,7 @@ impl SchemaExt for Schema { pub fn qualified_name(qualifier: Option<&TableReference>, name: &str) -> String { match qualifier { - Some(q) => format!("{}.{}", q, name), + Some(q) => format!("{q}.{name}"), None => name.to_string(), } } @@ -1175,10 +1416,8 @@ mod tests { #[test] fn from_qualified_schema_into_arrow_schema() -> Result<()> { let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - let arrow_schema: Schema = schema.into(); - let expected = "Field { name: \"c0\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, \ - Field { name: \"c1\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }"; - assert_eq!(expected, arrow_schema.to_string()); + let arrow_schema = schema.as_arrow(); + insta::assert_snapshot!(arrow_schema, @r#"Field { name: "c0", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }"#); Ok(()) } @@ -1554,6 +1793,36 @@ mod tests { &DataType::Int16 )); + // Succeeds if decimal precision and scale are different + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Decimal32(1, 2), + &DataType::Decimal32(2, 1), + )); + + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Decimal64(1, 2), + &DataType::Decimal64(2, 1), + )); + + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Decimal128(1, 2), + &DataType::Decimal128(2, 1), + )); + + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Decimal256(1, 2), + &DataType::Decimal256(2, 1), + )); + + // Any two timestamp types should match + assert!(DFSchema::datatype_is_semantically_equal( + &DataType::Timestamp( + arrow::datatypes::TimeUnit::Microsecond, + Some("UTC".into()) + ), + &DataType::Timestamp(arrow::datatypes::TimeUnit::Millisecond, None), + )); + // Test lists // Succeeds if both have the same element type, disregards names and nullability @@ -1696,4 +1965,488 @@ mod tests { fn test_metadata_n(n: usize) -> HashMap { (0..n).map(|i| (format!("k{i}"), format!("v{i}"))).collect() } + + #[test] + fn test_print_schema_unqualified() { + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("age", DataType::Int64, true), + Field::new("active", DataType::Boolean, false), + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.tree_string(); + + insta::assert_snapshot!(output, @r" + root + |-- id: int32 (nullable = false) + |-- name: utf8 (nullable = true) + |-- age: int64 (nullable = true) + |-- active: boolean (nullable = false) + "); + } + + #[test] + fn test_print_schema_qualified() { + let schema = DFSchema::try_from_qualified_schema( + "table1", + &Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ]), + ) + .unwrap(); + + let output = schema.tree_string(); + + insta::assert_snapshot!(output, @r" + root + |-- table1.id: int32 (nullable = false) + |-- table1.name: utf8 (nullable = true) + "); + } + + #[test] + fn test_print_schema_complex_types() { + let struct_field = Field::new( + "address", + DataType::Struct(Fields::from(vec![ + Field::new("street", DataType::Utf8, true), + Field::new("city", DataType::Utf8, true), + ])), + true, + ); + + let list_field = Field::new( + "tags", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + true, + ); + + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("id", DataType::Int32, false), + struct_field, + list_field, + Field::new("score", DataType::Decimal128(10, 2), true), + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.tree_string(); + insta::assert_snapshot!(output, @r" + root + |-- id: int32 (nullable = false) + |-- address: struct (nullable = true) + | |-- street: utf8 (nullable = true) + | |-- city: utf8 (nullable = true) + |-- tags: list (nullable = true) + | |-- item: utf8 (nullable = true) + |-- score: decimal128(10, 2) (nullable = true) + "); + } + + #[test] + fn test_print_schema_empty() { + let schema = DFSchema::empty(); + let output = schema.tree_string(); + insta::assert_snapshot!(output, @r###"root"###); + } + + #[test] + fn test_print_schema_deeply_nested_types() { + // Create a deeply nested structure to test indentation and complex type formatting + let inner_struct = Field::new( + "inner", + DataType::Struct(Fields::from(vec![ + Field::new("level1", DataType::Utf8, true), + Field::new("level2", DataType::Int32, false), + ])), + true, + ); + + let nested_list = Field::new( + "nested_list", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![ + Field::new("id", DataType::Int64, false), + Field::new("value", DataType::Float64, true), + ])), + true, + ))), + true, + ); + + let map_field = Field::new( + "map_data", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + true, + ))), + true, + ), + ])), + false, + )), + false, + ), + true, + ); + + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("simple_field", DataType::Utf8, true), + inner_struct, + nested_list, + map_field, + Field::new( + "timestamp_field", + DataType::Timestamp( + arrow::datatypes::TimeUnit::Microsecond, + Some("UTC".into()), + ), + false, + ), + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.tree_string(); + + insta::assert_snapshot!(output, @r" + root + |-- simple_field: utf8 (nullable = true) + |-- inner: struct (nullable = true) + | |-- level1: utf8 (nullable = true) + | |-- level2: int32 (nullable = false) + |-- nested_list: list (nullable = true) + | |-- item: struct (nullable = true) + | | |-- id: int64 (nullable = false) + | | |-- value: float64 (nullable = true) + |-- map_data: map (nullable = true) + | |-- key: utf8 (nullable = false) + | |-- value: list (nullable = true) + | | |-- item: int32 (nullable = true) + |-- timestamp_field: timestamp (UTC) (nullable = false) + "); + } + + #[test] + fn test_print_schema_mixed_qualified_unqualified() { + // Test a schema with mixed qualified and unqualified fields + let schema = DFSchema::new_with_metadata( + vec![ + ( + Some("table1".into()), + Arc::new(Field::new("id", DataType::Int32, false)), + ), + (None, Arc::new(Field::new("name", DataType::Utf8, true))), + ( + Some("table2".into()), + Arc::new(Field::new("score", DataType::Float64, true)), + ), + ( + None, + Arc::new(Field::new("active", DataType::Boolean, false)), + ), + ], + HashMap::new(), + ) + .unwrap(); + + let output = schema.tree_string(); + + insta::assert_snapshot!(output, @r" + root + |-- table1.id: int32 (nullable = false) + |-- name: utf8 (nullable = true) + |-- table2.score: float64 (nullable = true) + |-- active: boolean (nullable = false) + "); + } + + #[test] + fn test_print_schema_array_of_map() { + // Test the specific example from user feedback: array of map + let map_field = Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Utf8, false), + ])), + false, + ); + + let array_of_map_field = Field::new( + "array_map_field", + DataType::List(Arc::new(Field::new( + "item", + DataType::Map(Arc::new(map_field), false), + false, + ))), + false, + ); + + let schema = DFSchema::from_unqualified_fields( + vec![array_of_map_field].into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.tree_string(); + + insta::assert_snapshot!(output, @r" + root + |-- array_map_field: list (nullable = false) + | |-- item: map (nullable = false) + | | |-- key: utf8 (nullable = false) + | | |-- value: utf8 (nullable = false) + "); + } + + #[test] + fn test_print_schema_complex_type_combinations() { + // Test various combinations of list, struct, and map types + + // List of structs + let list_of_structs = Field::new( + "list_of_structs", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("score", DataType::Float64, true), + ])), + true, + ))), + true, + ); + + // Struct containing lists + let struct_with_lists = Field::new( + "struct_with_lists", + DataType::Struct(Fields::from(vec![ + Field::new( + "tags", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + true, + ), + Field::new( + "scores", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + false, + ), + Field::new("metadata", DataType::Utf8, true), + ])), + false, + ); + + // Map with struct values + let map_with_struct_values = Field::new( + "map_with_struct_values", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::Struct(Fields::from(vec![ + Field::new("count", DataType::Int64, false), + Field::new("active", DataType::Boolean, true), + ])), + true, + ), + ])), + false, + )), + false, + ), + true, + ); + + // List of maps + let list_of_maps = Field::new( + "list_of_maps", + DataType::List(Arc::new(Field::new( + "item", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ])), + false, + )), + false, + ), + true, + ))), + true, + ); + + // Deeply nested: struct containing list of structs containing maps + let deeply_nested = Field::new( + "deeply_nested", + DataType::Struct(Fields::from(vec![ + Field::new("level1", DataType::Utf8, true), + Field::new( + "level2", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "properties", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Float64, true), + ])), + false, + )), + false, + ), + true, + ), + ])), + true, + ))), + false, + ), + ])), + true, + ); + + let schema = DFSchema::from_unqualified_fields( + vec![ + list_of_structs, + struct_with_lists, + map_with_struct_values, + list_of_maps, + deeply_nested, + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.tree_string(); + + insta::assert_snapshot!(output, @r" + root + |-- list_of_structs: list (nullable = true) + | |-- item: struct (nullable = true) + | | |-- id: int32 (nullable = false) + | | |-- name: utf8 (nullable = true) + | | |-- score: float64 (nullable = true) + |-- struct_with_lists: struct (nullable = false) + | |-- tags: list (nullable = true) + | | |-- item: utf8 (nullable = true) + | |-- scores: list (nullable = false) + | | |-- item: int32 (nullable = true) + | |-- metadata: utf8 (nullable = true) + |-- map_with_struct_values: map (nullable = true) + | |-- key: utf8 (nullable = false) + | |-- value: struct (nullable = true) + | | |-- count: int64 (nullable = false) + | | |-- active: boolean (nullable = true) + |-- list_of_maps: list (nullable = true) + | |-- item: map (nullable = true) + | | |-- key: utf8 (nullable = false) + | | |-- value: int32 (nullable = false) + |-- deeply_nested: struct (nullable = true) + | |-- level1: utf8 (nullable = true) + | |-- level2: list (nullable = false) + | | |-- item: struct (nullable = true) + | | | |-- id: int32 (nullable = false) + | | | |-- properties: map (nullable = true) + | | | | |-- key: utf8 (nullable = false) + | | | | |-- value: float64 (nullable = false) + "); + } + + #[test] + fn test_print_schema_edge_case_types() { + // Test edge cases and special types + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("null_field", DataType::Null, true), + Field::new("binary_field", DataType::Binary, false), + Field::new("large_binary", DataType::LargeBinary, true), + Field::new("large_utf8", DataType::LargeUtf8, false), + Field::new("fixed_size_binary", DataType::FixedSizeBinary(16), true), + Field::new( + "fixed_size_list", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Int32, true)), + 5, + ), + false, + ), + Field::new("decimal32", DataType::Decimal32(9, 4), true), + Field::new("decimal64", DataType::Decimal64(9, 4), true), + Field::new("decimal128", DataType::Decimal128(18, 4), true), + Field::new("decimal256", DataType::Decimal256(38, 10), false), + Field::new("date32", DataType::Date32, true), + Field::new("date64", DataType::Date64, false), + Field::new( + "time32_seconds", + DataType::Time32(arrow::datatypes::TimeUnit::Second), + true, + ), + Field::new( + "time64_nanoseconds", + DataType::Time64(arrow::datatypes::TimeUnit::Nanosecond), + false, + ), + ] + .into(), + HashMap::new(), + ) + .unwrap(); + + let output = schema.tree_string(); + + insta::assert_snapshot!(output, @r" + root + |-- null_field: null (nullable = true) + |-- binary_field: binary (nullable = false) + |-- large_binary: large_binary (nullable = true) + |-- large_utf8: large_utf8 (nullable = false) + |-- fixed_size_binary: fixed_size_binary (nullable = true) + |-- fixed_size_list: fixed size list (nullable = false) + | |-- item: int32 (nullable = true) + |-- decimal32: decimal32(9, 4) (nullable = true) + |-- decimal64: decimal64(9, 4) (nullable = true) + |-- decimal128: decimal128(18, 4) (nullable = true) + |-- decimal256: decimal256(38, 10) (nullable = false) + |-- date32: date32 (nullable = true) + |-- date64: date64 (nullable = false) + |-- time32_seconds: time32 (nullable = true) + |-- time64_nanoseconds: time64 (nullable = false) + "); + } } diff --git a/datafusion/common/src/encryption.rs b/datafusion/common/src/encryption.rs new file mode 100644 index 0000000000000..b764ad77cff19 --- /dev/null +++ b/datafusion/common/src/encryption.rs @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Support optional features for encryption in Parquet files. +//! This module provides types and functions related to encryption in Parquet files. + +#[cfg(feature = "parquet_encryption")] +pub use parquet::encryption::decrypt::FileDecryptionProperties; +#[cfg(feature = "parquet_encryption")] +pub use parquet::encryption::encrypt::FileEncryptionProperties; + +#[cfg(not(feature = "parquet_encryption"))] +#[derive(Default, Debug)] +pub struct FileDecryptionProperties; +#[cfg(not(feature = "parquet_encryption"))] +#[derive(Default, Debug)] +pub struct FileEncryptionProperties; + +pub use crate::config::{ConfigFileDecryptionProperties, ConfigFileEncryptionProperties}; + +#[cfg(feature = "parquet_encryption")] +pub fn map_encryption_to_config_encryption( + encryption: Option<&FileEncryptionProperties>, +) -> Option { + encryption.map(|fe| fe.into()) +} + +#[cfg(not(feature = "parquet_encryption"))] +pub fn map_encryption_to_config_encryption( + _encryption: Option<&FileEncryptionProperties>, +) -> Option { + None +} + +#[cfg(feature = "parquet_encryption")] +pub fn map_config_decryption_to_decryption( + decryption: &ConfigFileDecryptionProperties, +) -> FileDecryptionProperties { + decryption.clone().into() +} + +#[cfg(not(feature = "parquet_encryption"))] +pub fn map_config_decryption_to_decryption( + _decryption: &ConfigFileDecryptionProperties, +) -> FileDecryptionProperties { + FileDecryptionProperties {} +} diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index c50ec64759d55..210f0442972d2 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -35,6 +35,7 @@ use apache_avro::Error as AvroError; use arrow::error::ArrowError; #[cfg(feature = "parquet")] use parquet::errors::ParquetError; +#[cfg(feature = "sql")] use sqlparser::parser::ParserError; use tokio::task::JoinError; @@ -53,22 +54,23 @@ pub enum DataFusionError { /// Error returned by arrow. /// /// 2nd argument is for optional backtrace - ArrowError(ArrowError, Option), + ArrowError(Box, Option), /// Error when reading / writing Parquet data. #[cfg(feature = "parquet")] - ParquetError(ParquetError), + ParquetError(Box), /// Error when reading Avro data. #[cfg(feature = "avro")] - AvroError(AvroError), + AvroError(Box), /// Error when reading / writing to / from an object_store (e.g. S3 or LocalFile) #[cfg(feature = "object_store")] - ObjectStore(object_store::Error), + ObjectStore(Box), /// Error when an I/O operation fails IoError(io::Error), /// Error when SQL is syntactically incorrect. /// /// 2nd argument is for optional backtrace - SQL(ParserError, Option), + #[cfg(feature = "sql")] + SQL(Box, Option), /// Error when a feature is not yet implemented. /// /// These errors are sometimes returned for features that are still in @@ -107,7 +109,7 @@ pub enum DataFusionError { /// /// 2nd argument is for optional backtrace /// Boxing the optional backtrace to prevent - SchemaError(SchemaError, Box>), + SchemaError(Box, Box>), /// Error during execution of the query. /// /// This error is returned when an error happens during execution due to a @@ -118,7 +120,7 @@ pub enum DataFusionError { /// [`JoinError`] during execution of the query. /// /// This error can't occur for unjoined tasks, such as execution shutdown. - ExecutionJoin(JoinError), + ExecutionJoin(Box), /// Error when resources (such as memory of scratch disk space) are exhausted. /// /// This error is thrown when a consumer cannot acquire additional memory @@ -164,7 +166,7 @@ macro_rules! context { #[derive(Debug)] pub enum SchemaError { /// Schema contains a (possibly) qualified and unqualified field with same unqualified name - AmbiguousReference { field: Column }, + AmbiguousReference { field: Box }, /// Schema contains duplicate qualified field name DuplicateQualifiedField { qualifier: Box, @@ -276,14 +278,14 @@ impl From for DataFusionError { impl From for DataFusionError { fn from(e: ArrowError) -> Self { - DataFusionError::ArrowError(e, None) + DataFusionError::ArrowError(Box::new(e), Some(DataFusionError::get_back_trace())) } } impl From for ArrowError { fn from(e: DataFusionError) -> Self { match e { - DataFusionError::ArrowError(e, _) => e, + DataFusionError::ArrowError(e, _) => *e, DataFusionError::External(e) => ArrowError::ExternalError(e), other => ArrowError::ExternalError(Box::new(other)), } @@ -304,34 +306,35 @@ impl From<&Arc> for DataFusionError { #[cfg(feature = "parquet")] impl From for DataFusionError { fn from(e: ParquetError) -> Self { - DataFusionError::ParquetError(e) + DataFusionError::ParquetError(Box::new(e)) } } #[cfg(feature = "avro")] impl From for DataFusionError { fn from(e: AvroError) -> Self { - DataFusionError::AvroError(e) + DataFusionError::AvroError(Box::new(e)) } } #[cfg(feature = "object_store")] impl From for DataFusionError { fn from(e: object_store::Error) -> Self { - DataFusionError::ObjectStore(e) + DataFusionError::ObjectStore(Box::new(e)) } } #[cfg(feature = "object_store")] impl From for DataFusionError { fn from(e: object_store::path::Error) -> Self { - DataFusionError::ObjectStore(e.into()) + DataFusionError::ObjectStore(Box::new(e.into())) } } +#[cfg(feature = "sql")] impl From for DataFusionError { fn from(e: ParserError) -> Self { - DataFusionError::SQL(e, None) + DataFusionError::SQL(Box::new(e), None) } } @@ -361,22 +364,23 @@ impl Display for DataFusionError { impl Error for DataFusionError { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { - DataFusionError::ArrowError(e, _) => Some(e), + DataFusionError::ArrowError(e, _) => Some(e.as_ref()), #[cfg(feature = "parquet")] - DataFusionError::ParquetError(e) => Some(e), + DataFusionError::ParquetError(e) => Some(e.as_ref()), #[cfg(feature = "avro")] - DataFusionError::AvroError(e) => Some(e), + DataFusionError::AvroError(e) => Some(e.as_ref()), #[cfg(feature = "object_store")] - DataFusionError::ObjectStore(e) => Some(e), + DataFusionError::ObjectStore(e) => Some(e.as_ref()), DataFusionError::IoError(e) => Some(e), - DataFusionError::SQL(e, _) => Some(e), + #[cfg(feature = "sql")] + DataFusionError::SQL(e, _) => Some(e.as_ref()), DataFusionError::NotImplemented(_) => None, DataFusionError::Internal(_) => None, DataFusionError::Configuration(_) => None, DataFusionError::Plan(_) => None, - DataFusionError::SchemaError(e, _) => Some(e), + DataFusionError::SchemaError(e, _) => Some(e.as_ref()), DataFusionError::Execution(_) => None, - DataFusionError::ExecutionJoin(e) => Some(e), + DataFusionError::ExecutionJoin(e) => Some(e.as_ref()), DataFusionError::ResourcesExhausted(_) => None, DataFusionError::External(e) => Some(e.as_ref()), DataFusionError::Context(_, e) => Some(e.as_ref()), @@ -397,7 +401,7 @@ impl Error for DataFusionError { impl From for io::Error { fn from(e: DataFusionError) -> Self { - io::Error::new(io::ErrorKind::Other, e) + io::Error::other(e) } } @@ -451,12 +455,13 @@ impl DataFusionError { /// If backtrace enabled then error has a format "message" [`Self::BACK_TRACE_SEP`] "backtrace" /// The method strips the backtrace and outputs "message" pub fn strip_backtrace(&self) -> String { - self.to_string() + (*self + .to_string() .split(Self::BACK_TRACE_SEP) .collect::>() .first() - .unwrap_or(&"") - .to_string() + .unwrap_or(&"")) + .to_string() } /// To enable optional rust backtrace in DataFusion: @@ -497,6 +502,7 @@ impl DataFusionError { #[cfg(feature = "object_store")] DataFusionError::ObjectStore(_) => "Object Store error: ", DataFusionError::IoError(_) => "IO error: ", + #[cfg(feature = "sql")] DataFusionError::SQL(_, _) => "SQL error: ", DataFusionError::NotImplemented(_) => { "This feature is not implemented: " @@ -523,10 +529,10 @@ impl DataFusionError { } } - pub fn message(&self) -> Cow { + pub fn message(&self) -> Cow<'_, str> { match *self { DataFusionError::ArrowError(ref desc, ref backtrace) => { - let backtrace = backtrace.clone().unwrap_or("".to_owned()); + let backtrace = backtrace.clone().unwrap_or_else(|| "".to_owned()); Cow::Owned(format!("{desc}{backtrace}")) } #[cfg(feature = "parquet")] @@ -534,20 +540,23 @@ impl DataFusionError { #[cfg(feature = "avro")] DataFusionError::AvroError(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::IoError(ref desc) => Cow::Owned(desc.to_string()), + #[cfg(feature = "sql")] DataFusionError::SQL(ref desc, ref backtrace) => { - let backtrace: String = backtrace.clone().unwrap_or("".to_owned()); + let backtrace: String = + backtrace.clone().unwrap_or_else(|| "".to_owned()); Cow::Owned(format!("{desc:?}{backtrace}")) } DataFusionError::Configuration(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::NotImplemented(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::Internal(ref desc) => Cow::Owned(format!( - "{desc}.\nThis was likely caused by a bug in DataFusion's \ - code and we would welcome that you file an bug report in our issue tracker" + "{desc}.\nThis issue was likely caused by a bug in DataFusion's code. \ + Please help us to resolve this by filing a bug report in our issue tracker: \ + https://github.com/apache/datafusion/issues" )), DataFusionError::Plan(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::SchemaError(ref desc, ref backtrace) => { let backtrace: &str = - &backtrace.as_ref().clone().unwrap_or("".to_owned()); + &backtrace.as_ref().clone().unwrap_or_else(|| "".to_owned()); Cow::Owned(format!("{desc}{backtrace}")) } DataFusionError::Execution(ref desc) => Cow::Owned(desc.to_string()), @@ -759,23 +768,33 @@ macro_rules! make_error { /// Macro wraps `$ERR` to add backtrace feature #[macro_export] macro_rules! $NAME_DF_ERR { - ($d($d args:expr),*) => { - $crate::DataFusionError::$ERR( + ($d($d args:expr),* $d(; diagnostic=$d DIAG:expr)?) => {{ + let err =$crate::DataFusionError::$ERR( ::std::format!( "{}{}", ::std::format!($d($d args),*), $crate::DataFusionError::get_back_trace(), ).into() - ) + ); + $d ( + let err = err.with_diagnostic($d DIAG); + )? + err } } + } /// Macro wraps Err(`$ERR`) to add backtrace feature #[macro_export] macro_rules! $NAME_ERR { - ($d($d args:expr),*) => { - Err($crate::[<_ $NAME_DF_ERR>]!($d($d args),*)) - } + ($d($d args:expr),* $d(; diagnostic = $d DIAG:expr)?) => {{ + let err = $crate::[<_ $NAME_DF_ERR>]!($d($d args),*); + $d ( + let err = err.with_diagnostic($d DIAG); + )? + Err(err) + + }} } @@ -816,54 +835,80 @@ make_error!(resources_err, resources_datafusion_err, ResourcesExhausted); // Exposes a macro to create `DataFusionError::SQL` with optional backtrace #[macro_export] macro_rules! sql_datafusion_err { - ($ERR:expr) => { - DataFusionError::SQL($ERR, Some(DataFusionError::get_back_trace())) - }; + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = DataFusionError::SQL(Box::new($ERR), Some(DataFusionError::get_back_trace())); + $( + let err = err.with_diagnostic($DIAG); + )? + err + }}; } // Exposes a macro to create `Err(DataFusionError::SQL)` with optional backtrace #[macro_export] macro_rules! sql_err { - ($ERR:expr) => { - Err(datafusion_common::sql_datafusion_err!($ERR)) - }; + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = datafusion_common::sql_datafusion_err!($ERR); + $( + let err = err.with_diagnostic($DIAG); + )? + Err(err) + }}; } // Exposes a macro to create `DataFusionError::ArrowError` with optional backtrace #[macro_export] macro_rules! arrow_datafusion_err { - ($ERR:expr) => { - DataFusionError::ArrowError($ERR, Some(DataFusionError::get_back_trace())) - }; + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = DataFusionError::ArrowError(Box::new($ERR), Some(DataFusionError::get_back_trace())); + $( + let err = err.with_diagnostic($DIAG); + )? + err + }}; } // Exposes a macro to create `Err(DataFusionError::ArrowError)` with optional backtrace #[macro_export] macro_rules! arrow_err { - ($ERR:expr) => { - Err(datafusion_common::arrow_datafusion_err!($ERR)) - }; + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => { + { + let err = datafusion_common::arrow_datafusion_err!($ERR); + $( + let err = err.with_diagnostic($DIAG); + )? + Err(err) + }}; } // Exposes a macro to create `DataFusionError::SchemaError` with optional backtrace #[macro_export] macro_rules! schema_datafusion_err { - ($ERR:expr) => { - $crate::error::DataFusionError::SchemaError( - $ERR, + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = $crate::error::DataFusionError::SchemaError( + Box::new($ERR), Box::new(Some($crate::error::DataFusionError::get_back_trace())), - ) - }; + ); + $( + let err = err.with_diagnostic($DIAG); + )? + err + }}; } // Exposes a macro to create `Err(DataFusionError::SchemaError)` with optional backtrace #[macro_export] macro_rules! schema_err { - ($ERR:expr) => { - Err($crate::error::DataFusionError::SchemaError( - $ERR, + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = $crate::error::DataFusionError::SchemaError( + Box::new($ERR), Box::new(Some($crate::error::DataFusionError::get_back_trace())), - )) + ); + $( + let err = err.with_diagnostic($DIAG); + )? + Err(err) + } }; } @@ -908,17 +953,27 @@ pub fn add_possible_columns_to_diag( .collect(); for name in field_names { - diagnostic.add_note(format!("possible column {}", name), None); + diagnostic.add_note(format!("possible column {name}"), None); } } #[cfg(test)] mod test { + use super::*; + + use std::mem::size_of; use std::sync::Arc; - use crate::error::{DataFusionError, GenericError}; use arrow::error::ArrowError; + #[test] + fn test_error_size() { + // Since Errors influence the size of Result which influence the size of the stack + // please don't allow this to grow larger + assert_eq!(size_of::(), 40); + assert_eq!(size_of::(), 40); + } + #[test] fn datafusion_error_to_arrow() { let res = return_arrow_error().unwrap_err(); @@ -983,8 +1038,8 @@ mod test { do_root_test( DataFusionError::ArrowError( - ArrowError::ExternalError(Box::new(DataFusionError::ResourcesExhausted( - "foo".to_string(), + Box::new(ArrowError::ExternalError(Box::new( + DataFusionError::ResourcesExhausted("foo".to_string()), ))), None, ), @@ -1007,9 +1062,11 @@ mod test { do_root_test( DataFusionError::ArrowError( - ArrowError::ExternalError(Box::new(ArrowError::ExternalError(Box::new( - DataFusionError::ResourcesExhausted("foo".to_string()), - )))), + Box::new(ArrowError::ExternalError(Box::new( + ArrowError::ExternalError(Box::new( + DataFusionError::ResourcesExhausted("foo".to_string()), + )), + ))), None, ), DataFusionError::ResourcesExhausted("foo".to_string()), @@ -1083,8 +1140,7 @@ mod test { ); // assert wrapping other Error - let generic_error: GenericError = - Box::new(std::io::Error::new(std::io::ErrorKind::Other, "io error")); + let generic_error: GenericError = Box::new(io::Error::other("io error")); let datafusion_error: DataFusionError = generic_error.into(); println!("{}", datafusion_error.strip_backtrace()); assert_eq!( @@ -1095,13 +1151,12 @@ mod test { #[test] fn external_error_no_recursive() { - let generic_error_1: GenericError = - Box::new(std::io::Error::new(std::io::ErrorKind::Other, "io error")); + let generic_error_1: GenericError = Box::new(io::Error::other("io error")); let external_error_1: DataFusionError = generic_error_1.into(); let generic_error_2: GenericError = Box::new(external_error_1); let external_error_2: DataFusionError = generic_error_2.into(); - println!("{}", external_error_2); + println!("{external_error_2}"); assert!(external_error_2 .to_string() .starts_with("External error: io error")); @@ -1116,7 +1171,7 @@ mod test { /// Model what happens when using arrow kernels in DataFusion /// code: need to turn an ArrowError into a DataFusionError - fn return_datafusion_error() -> crate::error::Result<()> { + fn return_datafusion_error() -> Result<()> { // Expect the '?' to work Err(ArrowError::SchemaError("bar".to_string()).into()) } diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 939cb5e1a3578..3977f2b489e18 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -17,7 +17,6 @@ //! Options related to how parquet files should be written -use base64::Engine; use std::sync::Arc; use crate::{ @@ -26,6 +25,7 @@ use crate::{ }; use arrow::datatypes::Schema; +use parquet::arrow::encode_arrow_schema; // TODO: handle once deprecated #[allow(deprecated)] use parquet::{ @@ -35,7 +35,7 @@ use parquet::{ metadata::KeyValue, properties::{ EnabledStatistics, WriterProperties, WriterPropertiesBuilder, WriterVersion, - DEFAULT_MAX_STATISTICS_SIZE, DEFAULT_STATISTICS_ENABLED, + DEFAULT_STATISTICS_ENABLED, }, }, schema::types::ColumnPath, @@ -89,12 +89,15 @@ impl TryFrom<&TableParquetOptions> for WriterPropertiesBuilder { /// Convert the session's [`TableParquetOptions`] into a single write action's [`WriterPropertiesBuilder`]. /// /// The returned [`WriterPropertiesBuilder`] includes customizations applicable per column. + /// Note that any encryption options are ignored as building the `FileEncryptionProperties` + /// might require other inputs besides the [`TableParquetOptions`]. fn try_from(table_parquet_options: &TableParquetOptions) -> Result { // Table options include kv_metadata and col-specific options let TableParquetOptions { global, column_specific_options, key_value_metadata, + crypto: _, } = table_parquet_options; let mut builder = global.into_writer_properties_builder()?; @@ -157,47 +160,12 @@ impl TryFrom<&TableParquetOptions> for WriterPropertiesBuilder { builder = builder.set_column_bloom_filter_ndv(path.clone(), bloom_filter_ndv); } - - // max_statistics_size is deprecated, currently it is not being used - // TODO: remove once deprecated - #[allow(deprecated)] - if let Some(max_statistics_size) = options.max_statistics_size { - builder = { - #[allow(deprecated)] - builder.set_column_max_statistics_size(path, max_statistics_size) - } - } } Ok(builder) } } -/// Encodes the Arrow schema into the IPC format, and base64 encodes it -/// -/// TODO: use extern parquet's private method, once publicly available. -/// Refer to -fn encode_arrow_schema(schema: &Arc) -> String { - let options = arrow_ipc::writer::IpcWriteOptions::default(); - let mut dictionary_tracker = arrow_ipc::writer::DictionaryTracker::new(true); - let data_gen = arrow_ipc::writer::IpcDataGenerator::default(); - let mut serialized_schema = data_gen.schema_to_bytes_with_dictionary_tracker( - schema, - &mut dictionary_tracker, - &options, - ); - - // manually prepending the length to the schema as arrow uses the legacy IPC format - // TODO: change after addressing ARROW-9777 - let schema_len = serialized_schema.ipc_message.len(); - let mut len_prefix_schema = Vec::with_capacity(schema_len + 8); - len_prefix_schema.append(&mut vec![255u8, 255, 255, 255]); - len_prefix_schema.append((schema_len as u32).to_le_bytes().to_vec().as_mut()); - len_prefix_schema.append(&mut serialized_schema.ipc_message); - - base64::prelude::BASE64_STANDARD.encode(&len_prefix_schema) -} - impl ParquetOptions { /// Convert the global session options, [`ParquetOptions`], into a single write action's [`WriterPropertiesBuilder`]. /// @@ -215,7 +183,6 @@ impl ParquetOptions { dictionary_enabled, dictionary_page_size_limit, statistics_enabled, - max_statistics_size, max_row_group_size, created_by, column_index_truncate_length, @@ -239,7 +206,9 @@ impl ParquetOptions { bloom_filter_on_read: _, // reads not used for writer props schema_force_view_types: _, binary_as_string: _, // not used for writer props + coerce_int96: _, // not used for writer props skip_arrow_metadata: _, + max_predicate_cache_size: _, } = self; let mut builder = WriterProperties::builder() @@ -260,13 +229,6 @@ impl ParquetOptions { .set_data_page_row_count_limit(*data_page_row_count_limit) .set_bloom_filter_enabled(*bloom_filter_on_write); - builder = { - #[allow(deprecated)] - builder.set_max_statistics_size( - max_statistics_size.unwrap_or(DEFAULT_MAX_STATISTICS_SIZE), - ) - }; - if let Some(bloom_filter_fpp) = bloom_filter_fpp { builder = builder.set_bloom_filter_fpp(*bloom_filter_fpp); }; @@ -329,8 +291,7 @@ fn split_compression_string(str_setting: &str) -> Result<(String, Option)> let level = &rh[..rh.len() - 1].parse::().map_err(|_| { DataFusionError::Configuration(format!( "Could not parse compression string. \ - Got codec: {} and unknown level from {}", - codec, str_setting + Got codec: {codec} and unknown level from {str_setting}" )) })?; Ok((codec.to_owned(), Some(*level))) @@ -440,6 +401,10 @@ pub(crate) fn parse_statistics_string(str_setting: &str) -> Result ParquetColumnOptions { - #[allow(deprecated)] // max_statistics_size ParquetColumnOptions { compression: Some("zstd(22)".into()), dictionary_enabled: src_col_defaults.dictionary_enabled.map(|v| !v), statistics_enabled: Some("none".into()), - max_statistics_size: Some(72), encoding: Some("RLE".into()), bloom_filter_enabled: Some(true), bloom_filter_fpp: Some(0.72), @@ -489,7 +448,6 @@ mod tests { dictionary_enabled: Some(!defaults.dictionary_enabled.unwrap_or(false)), dictionary_page_size_limit: 42, statistics_enabled: Some("chunk".into()), - max_statistics_size: Some(42), max_row_group_size: 42, created_by: "wordy".into(), column_index_truncate_length: Some(42), @@ -516,6 +474,8 @@ mod tests { schema_force_view_types: defaults.schema_force_view_types, binary_as_string: defaults.binary_as_string, skip_arrow_metadata: defaults.skip_arrow_metadata, + coerce_int96: None, + max_predicate_cache_size: defaults.max_predicate_cache_size, } } @@ -546,7 +506,6 @@ mod tests { ), bloom_filter_fpp: bloom_filter_default_props.map(|p| p.fpp), bloom_filter_ndv: bloom_filter_default_props.map(|p| p.ndv), - max_statistics_size: Some(props.max_statistics_size(&col)), } } @@ -579,6 +538,11 @@ mod tests { HashMap::from([(COL_NAME.into(), configured_col_props)]) }; + #[cfg(feature = "parquet_encryption")] + let fep = map_encryption_to_config_encryption(props.file_encryption_properties()); + #[cfg(not(feature = "parquet_encryption"))] + let fep = None; + #[allow(deprecated)] // max_statistics_size TableParquetOptions { global: ParquetOptions { @@ -598,7 +562,6 @@ mod tests { compression: default_col_props.compression, dictionary_enabled: default_col_props.dictionary_enabled, statistics_enabled: default_col_props.statistics_enabled, - max_statistics_size: default_col_props.max_statistics_size, bloom_filter_on_write: default_col_props .bloom_filter_enabled .unwrap_or_default(), @@ -619,12 +582,21 @@ mod tests { maximum_buffered_record_batches_per_stream: global_options_defaults .maximum_buffered_record_batches_per_stream, bloom_filter_on_read: global_options_defaults.bloom_filter_on_read, + max_predicate_cache_size: global_options_defaults + .max_predicate_cache_size, schema_force_view_types: global_options_defaults.schema_force_view_types, binary_as_string: global_options_defaults.binary_as_string, skip_arrow_metadata: global_options_defaults.skip_arrow_metadata, + coerce_int96: None, }, column_specific_options, key_value_metadata, + crypto: ParquetEncryptionOptions { + file_encryption: fep, + file_decryption: None, + factory_id: None, + factory_options: Default::default(), + }, } } @@ -679,6 +651,7 @@ mod tests { )] .into(), key_value_metadata: [(key, value)].into(), + crypto: Default::default(), }; let writer_props = WriterPropertiesBuilder::try_from(&table_parquet_opts) diff --git a/datafusion/common/src/format.rs b/datafusion/common/src/format.rs index 23cfb72314a3c..06ec519ef356c 100644 --- a/datafusion/common/src/format.rs +++ b/datafusion/common/src/format.rs @@ -15,10 +15,17 @@ // specific language governing permissions and limitations // under the License. +use std::fmt::{self, Display}; +use std::str::FromStr; + use arrow::compute::CastOptions; use arrow::util::display::{DurationFormat, FormatOptions}; +use crate::config::{ConfigField, Visit}; +use crate::error::{DataFusionError, Result}; + /// The default [`FormatOptions`] to use within DataFusion +/// Also see [`crate::config::FormatOptions`] pub const DEFAULT_FORMAT_OPTIONS: FormatOptions<'static> = FormatOptions::new().with_duration_format(DurationFormat::Pretty); @@ -28,6 +35,173 @@ pub const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { format_options: DEFAULT_FORMAT_OPTIONS, }; -pub const DEFAULT_CLI_FORMAT_OPTIONS: FormatOptions<'static> = FormatOptions::new() - .with_duration_format(DurationFormat::Pretty) - .with_null("NULL"); +/// Output formats for controlling for Explain plans +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ExplainFormat { + /// Indent mode + /// + /// Example: + /// ```text + /// > explain format indent select x from values (1) t(x); + /// +---------------+-----------------------------------------------------+ + /// | plan_type | plan | + /// +---------------+-----------------------------------------------------+ + /// | logical_plan | SubqueryAlias: t | + /// | | Projection: column1 AS x | + /// | | Values: (Int64(1)) | + /// | physical_plan | ProjectionExec: expr=[column1@0 as x] | + /// | | DataSourceExec: partitions=1, partition_sizes=[1] | + /// | | | + /// +---------------+-----------------------------------------------------+ + /// ``` + Indent, + /// Tree mode + /// + /// Example: + /// ```text + /// > explain format tree select x from values (1) t(x); + /// +---------------+-------------------------------+ + /// | plan_type | plan | + /// +---------------+-------------------------------+ + /// | physical_plan | ┌───────────────────────────┐ | + /// | | │ ProjectionExec │ | + /// | | │ -------------------- │ | + /// | | │ x: column1@0 │ | + /// | | └─────────────┬─────────────┘ | + /// | | ┌─────────────┴─────────────┐ | + /// | | │ DataSourceExec │ | + /// | | │ -------------------- │ | + /// | | │ bytes: 128 │ | + /// | | │ format: memory │ | + /// | | │ rows: 1 │ | + /// | | └───────────────────────────┘ | + /// | | | + /// +---------------+-------------------------------+ + /// ``` + Tree, + /// Postgres Json mode + /// + /// A displayable structure that produces plan in postgresql JSON format. + /// + /// Users can use this format to visualize the plan in existing plan + /// visualization tools, for example [dalibo](https://explain.dalibo.com/) + /// + /// Example: + /// ```text + /// > explain format pgjson select x from values (1) t(x); + /// +--------------+--------------------------------------+ + /// | plan_type | plan | + /// +--------------+--------------------------------------+ + /// | logical_plan | [ | + /// | | { | + /// | | "Plan": { | + /// | | "Alias": "t", | + /// | | "Node Type": "Subquery", | + /// | | "Output": [ | + /// | | "x" | + /// | | ], | + /// | | "Plans": [ | + /// | | { | + /// | | "Expressions": [ | + /// | | "column1 AS x" | + /// | | ], | + /// | | "Node Type": "Projection", | + /// | | "Output": [ | + /// | | "x" | + /// | | ], | + /// | | "Plans": [ | + /// | | { | + /// | | "Node Type": "Values", | + /// | | "Output": [ | + /// | | "column1" | + /// | | ], | + /// | | "Plans": [], | + /// | | "Values": "(Int64(1))" | + /// | | } | + /// | | ] | + /// | | } | + /// | | ] | + /// | | } | + /// | | } | + /// | | ] | + /// +--------------+--------------------------------------+ + /// ``` + PostgresJSON, + /// Graphviz mode + /// + /// Example: + /// ```text + /// > explain format graphviz select x from values (1) t(x); + /// +--------------+------------------------------------------------------------------------+ + /// | plan_type | plan | + /// +--------------+------------------------------------------------------------------------+ + /// | logical_plan | | + /// | | // Begin DataFusion GraphViz Plan, | + /// | | // display it online here: https://dreampuf.github.io/GraphvizOnline | + /// | | | + /// | | digraph { | + /// | | subgraph cluster_1 | + /// | | { | + /// | | graph[label="LogicalPlan"] | + /// | | 2[shape=box label="SubqueryAlias: t"] | + /// | | 3[shape=box label="Projection: column1 AS x"] | + /// | | 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] | + /// | | 4[shape=box label="Values: (Int64(1))"] | + /// | | 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] | + /// | | } | + /// | | subgraph cluster_5 | + /// | | { | + /// | | graph[label="Detailed LogicalPlan"] | + /// | | 6[shape=box label="SubqueryAlias: t\nSchema: [x:Int64;N]"] | + /// | | 7[shape=box label="Projection: column1 AS x\nSchema: [x:Int64;N]"] | + /// | | 6 -> 7 [arrowhead=none, arrowtail=normal, dir=back] | + /// | | 8[shape=box label="Values: (Int64(1))\nSchema: [column1:Int64;N]"] | + /// | | 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] | + /// | | } | + /// | | } | + /// | | // End DataFusion GraphViz Plan | + /// | | | + /// +--------------+------------------------------------------------------------------------+ + /// ``` + Graphviz, +} + +/// Implement parsing strings to `ExplainFormat` +impl FromStr for ExplainFormat { + type Err = DataFusionError; + + fn from_str(format: &str) -> Result { + match format.to_lowercase().as_str() { + "indent" => Ok(ExplainFormat::Indent), + "tree" => Ok(ExplainFormat::Tree), + "pgjson" => Ok(ExplainFormat::PostgresJSON), + "graphviz" => Ok(ExplainFormat::Graphviz), + _ => { + Err(DataFusionError::Configuration(format!("Invalid explain format. Expected 'indent', 'tree', 'pgjson' or 'graphviz'. Got '{format}'"))) + } + } + } +} + +impl Display for ExplainFormat { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + ExplainFormat::Indent => "indent", + ExplainFormat::Tree => "tree", + ExplainFormat::PostgresJSON => "pgjson", + ExplainFormat::Graphviz => "graphviz", + }; + write!(f, "{s}") + } +} + +impl ConfigField for ExplainFormat { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { + v.some(key, self, description) + } + + fn set(&mut self, _: &str, value: &str) -> Result<()> { + *self = ExplainFormat::from_str(value)?; + Ok(()) + } +} diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 5f262d634af37..63962998ad18b 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -36,33 +36,31 @@ pub enum Constraint { } /// This object encapsulates a list of functional constraints: -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, PartialOrd)] pub struct Constraints { inner: Vec, } impl Constraints { - /// Create empty constraints - pub fn empty() -> Self { - Constraints::new_unverified(vec![]) - } - - /// Create a new `Constraints` object from the given `constraints`. - /// Users should use the `empty` or `new_from_table_constraints` functions - /// for constructing `Constraints`. This constructor is for internal - /// purposes only and does not check whether the argument is valid. The user - /// is responsible for supplying a valid vector of `Constraint` objects. + /// Create a new [`Constraints`] object from the given `constraints`. + /// Users should use the [`Constraints::default`] or [`SqlToRel::new_constraint_from_table_constraints`] + /// functions for constructing [`Constraints`] instances. This constructor + /// is for internal purposes only and does not check whether the argument + /// is valid. The user is responsible for supplying a valid vector of + /// [`Constraint`] objects. + /// + /// [`SqlToRel::new_constraint_from_table_constraints`]: https://docs.rs/datafusion/latest/datafusion/sql/planner/struct.SqlToRel.html#method.new_constraint_from_table_constraints pub fn new_unverified(constraints: Vec) -> Self { Self { inner: constraints } } - /// Check whether constraints is empty - pub fn is_empty(&self) -> bool { - self.inner.is_empty() + /// Extends the current constraints with the given `other` constraints. + pub fn extend(&mut self, other: Constraints) { + self.inner.extend(other.inner); } - /// Projects constraints using the given projection indices. - /// Returns None if any of the constraint columns are not included in the projection. + /// Projects constraints using the given projection indices. Returns `None` + /// if any of the constraint columns are not included in the projection. pub fn project(&self, proj_indices: &[usize]) -> Option { let projected = self .inner @@ -72,14 +70,14 @@ impl Constraints { Constraint::PrimaryKey(indices) => { let new_indices = update_elements_with_matching_indices(indices, proj_indices); - // Only keep constraint if all columns are preserved + // Only keep the constraint if all columns are preserved: (new_indices.len() == indices.len()) .then_some(Constraint::PrimaryKey(new_indices)) } Constraint::Unique(indices) => { let new_indices = update_elements_with_matching_indices(indices, proj_indices); - // Only keep constraint if all columns are preserved + // Only keep the constraint if all columns are preserved: (new_indices.len() == indices.len()) .then_some(Constraint::Unique(new_indices)) } @@ -91,15 +89,9 @@ impl Constraints { } } -impl Default for Constraints { - fn default() -> Self { - Constraints::empty() - } -} - impl IntoIterator for Constraints { type Item = Constraint; - type IntoIter = IntoIter; + type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { self.inner.into_iter() @@ -111,7 +103,7 @@ impl Display for Constraints { let pk = self .inner .iter() - .map(|c| format!("{:?}", c)) + .map(|c| format!("{c:?}")) .collect::>(); let pk = pk.join(", "); write!(f, "constraints=[{pk}]") @@ -372,7 +364,7 @@ impl FunctionalDependencies { // These joins preserve functional dependencies of the left side: left_func_dependencies } - JoinType::RightSemi | JoinType::RightAnti => { + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { // These joins preserve functional dependencies of the right side: right_func_dependencies } diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index e78d42257b9cb..4b18351f708b7 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -184,6 +184,26 @@ fn hash_array( } } +/// Helper function to update hash for a dictionary key if the value is valid +#[cfg(not(feature = "force_hash_collisions"))] +#[inline] +fn update_hash_for_dict_key( + hash: &mut u64, + dict_hashes: &[u64], + dict_values: &dyn Array, + idx: usize, + multi_col: bool, +) { + if dict_values.is_valid(idx) { + if multi_col { + *hash = combine_hashes(dict_hashes[idx], *hash); + } else { + *hash = dict_hashes[idx]; + } + } + // no update for invalid dictionary value +} + /// Hash the values in a dictionary array #[cfg(not(feature = "force_hash_collisions"))] fn hash_dictionary( @@ -195,23 +215,23 @@ fn hash_dictionary( // Hash each dictionary value once, and then use that computed // hash for each key value to avoid a potentially expensive // redundant hashing for large dictionary elements (e.g. strings) - let values = Arc::clone(array.values()); - let mut dict_hashes = vec![0; values.len()]; - create_hashes(&[values], random_state, &mut dict_hashes)?; + let dict_values = Arc::clone(array.values()); + let mut dict_hashes = vec![0; dict_values.len()]; + create_hashes(&[dict_values], random_state, &mut dict_hashes)?; // combine hash for each index in values - if multi_col { - for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().iter()) { - if let Some(key) = key { - *hash = combine_hashes(dict_hashes[key.as_usize()], *hash) - } // no update for Null, consistent with other hashes - } - } else { - for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().iter()) { - if let Some(key) = key { - *hash = dict_hashes[key.as_usize()] - } // no update for Null, consistent with other hashes - } + let dict_values = array.values(); + for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().iter()) { + if let Some(key) = key { + let idx = key.as_usize(); + update_hash_for_dict_key( + hash, + &dict_hashes, + dict_values.as_ref(), + idx, + multi_col, + ); + } // no update for Null key } Ok(()) } diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index ac81d977b7296..e6a90db2dc3eb 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -67,6 +67,11 @@ pub enum JoinType { /// /// [1]: http://btw2017.informatik.uni-stuttgart.de/slidesandpapers/F1-10-37/paper_web.pdf LeftMark, + /// Right Mark Join + /// + /// Same logic as the LeftMark Join above, however it returns a record for each record from the + /// right input. + RightMark, } impl JoinType { @@ -87,13 +92,12 @@ impl JoinType { JoinType::RightSemi => JoinType::LeftSemi, JoinType::LeftAnti => JoinType::RightAnti, JoinType::RightAnti => JoinType::LeftAnti, - JoinType::LeftMark => { - unreachable!("LeftMark join type does not support swapping") - } + JoinType::LeftMark => JoinType::RightMark, + JoinType::RightMark => JoinType::LeftMark, } } - /// Does the join type support swapping inputs? + /// Does the join type support swapping inputs? pub fn supports_swap(&self) -> bool { matches!( self, @@ -105,6 +109,8 @@ impl JoinType { | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark ) } } @@ -121,6 +127,7 @@ impl Display for JoinType { JoinType::LeftAnti => "LeftAnti", JoinType::RightAnti => "RightAnti", JoinType::LeftMark => "LeftMark", + JoinType::RightMark => "RightMark", }; write!(f, "{join_type}") } @@ -141,6 +148,7 @@ impl FromStr for JoinType { "LEFTANTI" => Ok(JoinType::LeftAnti), "RIGHTANTI" => Ok(JoinType::RightAnti), "LEFTMARK" => Ok(JoinType::LeftMark), + "RIGHTMARK" => Ok(JoinType::RightMark), _ => _not_impl_err!("The join type {s} does not exist or is not implemented"), } } diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index b137624532b92..24ec9b7be3233 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -41,12 +41,16 @@ pub mod config; pub mod cse; pub mod diagnostic; pub mod display; +pub mod encryption; pub mod error; pub mod file_options; pub mod format; pub mod hash_utils; pub mod instant; +pub mod nested_struct; +mod null_equality; pub mod parsers; +pub mod pruning; pub mod rounding; pub mod scalar; pub mod spans; @@ -78,6 +82,8 @@ pub use functional_dependencies::{ }; use hashbrown::hash_map::DefaultHashBuilder; pub use join_type::{JoinConstraint, JoinSide, JoinType}; +pub use nested_struct::cast_column; +pub use null_equality::NullEquality; pub use param_value::ParamValues; pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::SchemaReference; @@ -135,10 +141,12 @@ pub mod __private { impl DowncastArrayHelper for T { fn downcast_array_helper(&self) -> Result<&U> { self.as_any().downcast_ref().ok_or_else(|| { + let actual_type = self.data_type(); + let desired_type_name = type_name::(); _internal_datafusion_err!( "could not cast array of type {} to {}", - self.data_type(), - type_name::() + actual_type, + desired_type_name ) }) } @@ -185,9 +193,7 @@ mod tests { let expected_prefix = expected_prefix.as_ref(); assert!( actual.starts_with(expected_prefix), - "Expected '{}' to start with '{}'", - actual, - expected_prefix + "Expected '{actual}' to start with '{expected_prefix}'" ); } } diff --git a/datafusion/common/src/nested_struct.rs b/datafusion/common/src/nested_struct.rs new file mode 100644 index 0000000000000..38060e370bfa1 --- /dev/null +++ b/datafusion/common/src/nested_struct.rs @@ -0,0 +1,704 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::{Result, _plan_err}; +use arrow::{ + array::{new_null_array, Array, ArrayRef, StructArray}, + compute::{cast_with_options, CastOptions}, + datatypes::{DataType::Struct, Field, FieldRef}, +}; +use std::sync::Arc; + +/// Cast a struct column to match target struct fields, handling nested structs recursively. +/// +/// This function implements struct-to-struct casting with the assumption that **structs should +/// always be allowed to cast to other structs**. However, the source column must already be +/// a struct type - non-struct sources will result in an error. +/// +/// ## Field Matching Strategy +/// - **By Name**: Source struct fields are matched to target fields by name (case-sensitive) +/// - **Type Adaptation**: When a matching field is found, it is recursively cast to the target field's type +/// - **Missing Fields**: Target fields not present in the source are filled with null values +/// - **Extra Fields**: Source fields not present in the target are ignored +/// +/// ## Nested Struct Handling +/// - Nested structs are handled recursively using the same casting rules +/// - Each level of nesting follows the same field matching and null-filling strategy +/// - This allows for complex struct transformations while maintaining data integrity +/// +/// # Arguments +/// * `source_col` - The source array to cast (must be a struct array) +/// * `target_fields` - The target struct field definitions to cast to +/// +/// # Returns +/// A `Result` containing the cast struct array +/// +/// # Errors +/// Returns a `DataFusionError::Plan` if the source column is not a struct type +fn cast_struct_column( + source_col: &ArrayRef, + target_fields: &[Arc], + cast_options: &CastOptions, +) -> Result { + if let Some(source_struct) = source_col.as_any().downcast_ref::() { + validate_struct_compatibility(source_struct.fields(), target_fields)?; + + let mut fields: Vec> = Vec::with_capacity(target_fields.len()); + let mut arrays: Vec = Vec::with_capacity(target_fields.len()); + let num_rows = source_col.len(); + + for target_child_field in target_fields { + fields.push(Arc::clone(target_child_field)); + match source_struct.column_by_name(target_child_field.name()) { + Some(source_child_col) => { + let adapted_child = + cast_column(source_child_col, target_child_field, cast_options) + .map_err(|e| { + e.context(format!( + "While casting struct field '{}'", + target_child_field.name() + )) + })?; + arrays.push(adapted_child); + } + None => { + arrays.push(new_null_array(target_child_field.data_type(), num_rows)); + } + } + } + + let struct_array = + StructArray::new(fields.into(), arrays, source_struct.nulls().cloned()); + Ok(Arc::new(struct_array)) + } else { + // Return error if source is not a struct type + _plan_err!( + "Cannot cast column of type {} to struct type. Source must be a struct to cast to struct.", + source_col.data_type() + ) + } +} + +/// Cast a column to match the target field type, with special handling for nested structs. +/// +/// This function serves as the main entry point for column casting operations. For struct +/// types, it enforces that **only struct columns can be cast to struct types**. +/// +/// ## Casting Behavior +/// - **Struct Types**: Delegates to `cast_struct_column` for struct-to-struct casting only +/// - **Non-Struct Types**: Uses Arrow's standard `cast` function for primitive type conversions +/// +/// ## Cast Options +/// The `cast_options` argument controls how Arrow handles values that cannot be represented +/// in the target type. When `safe` is `false` (DataFusion's default) the cast will return an +/// error if such a value is encountered. Setting `safe` to `true` instead produces `NULL` +/// for out-of-range or otherwise invalid values. The options also allow customizing how +/// temporal values are formatted when cast to strings. +/// +/// ``` +/// use std::sync::Arc; +/// use arrow::array::{Int64Array, ArrayRef}; +/// use arrow::compute::CastOptions; +/// use arrow::datatypes::{DataType, Field}; +/// use datafusion_common::nested_struct::cast_column; +/// +/// let source: ArrayRef = Arc::new(Int64Array::from(vec![1, i64::MAX])); +/// let target = Field::new("ints", DataType::Int32, true); +/// // Permit lossy conversions by producing NULL on overflow instead of erroring +/// let options = CastOptions { safe: true, ..Default::default() }; +/// let result = cast_column(&source, &target, &options).unwrap(); +/// assert!(result.is_null(1)); +/// ``` +/// +/// ## Struct Casting Requirements +/// The struct casting logic requires that the source column must already be a struct type. +/// This makes the function useful for: +/// - Schema evolution scenarios where struct layouts change over time +/// - Data migration between different struct schemas +/// - Type-safe data processing pipelines that maintain struct type integrity +/// +/// # Arguments +/// * `source_col` - The source array to cast +/// * `target_field` - The target field definition (including type and metadata) +/// * `cast_options` - Options that govern strictness and formatting of the cast +/// +/// # Returns +/// A `Result` containing the cast array +/// +/// # Errors +/// Returns an error if: +/// - Attempting to cast a non-struct column to a struct type +/// - Arrow's cast function fails for non-struct types +/// - Memory allocation fails during struct construction +/// - Invalid data type combinations are encountered +pub fn cast_column( + source_col: &ArrayRef, + target_field: &Field, + cast_options: &CastOptions, +) -> Result { + match target_field.data_type() { + Struct(target_fields) => { + cast_struct_column(source_col, target_fields, cast_options) + } + _ => Ok(cast_with_options( + source_col, + target_field.data_type(), + cast_options, + )?), + } +} + +/// Validates compatibility between source and target struct fields for casting operations. +/// +/// This function implements comprehensive struct compatibility checking by examining: +/// - Field name matching between source and target structs +/// - Type castability for each matching field (including recursive struct validation) +/// - Proper handling of missing fields (target fields not in source are allowed - filled with nulls) +/// - Proper handling of extra fields (source fields not in target are allowed - ignored) +/// +/// # Compatibility Rules +/// - **Field Matching**: Fields are matched by name (case-sensitive) +/// - **Missing Target Fields**: Allowed - will be filled with null values during casting +/// - **Extra Source Fields**: Allowed - will be ignored during casting +/// - **Type Compatibility**: Each matching field must be castable using Arrow's type system +/// - **Nested Structs**: Recursively validates nested struct compatibility +/// +/// # Arguments +/// * `source_fields` - Fields from the source struct type +/// * `target_fields` - Fields from the target struct type +/// +/// # Returns +/// * `Ok(())` if the structs are compatible for casting +/// * `Err(DataFusionError)` with detailed error message if incompatible +/// +/// # Examples +/// ```text +/// // Compatible: source has extra field, target has missing field +/// // Source: {a: i32, b: string, c: f64} +/// // Target: {a: i64, d: bool} +/// // Result: Ok(()) - 'a' can cast i32->i64, 'b','c' ignored, 'd' filled with nulls +/// +/// // Incompatible: matching field has incompatible types +/// // Source: {a: string} +/// // Target: {a: binary} +/// // Result: Err(...) - string cannot cast to binary +/// ``` +pub fn validate_struct_compatibility( + source_fields: &[FieldRef], + target_fields: &[FieldRef], +) -> Result<()> { + // Check compatibility for each target field + for target_field in target_fields { + // Look for matching field in source by name + if let Some(source_field) = source_fields + .iter() + .find(|f| f.name() == target_field.name()) + { + // Ensure nullability is compatible. It is invalid to cast a nullable + // source field to a non-nullable target field as this may discard + // null values. + if source_field.is_nullable() && !target_field.is_nullable() { + return _plan_err!( + "Cannot cast nullable struct field '{}' to non-nullable field", + target_field.name() + ); + } + // Check if the matching field types are compatible + match (source_field.data_type(), target_field.data_type()) { + // Recursively validate nested structs + (Struct(source_nested), Struct(target_nested)) => { + validate_struct_compatibility(source_nested, target_nested)?; + } + // For non-struct types, use the existing castability check + _ => { + if !arrow::compute::can_cast_types( + source_field.data_type(), + target_field.data_type(), + ) { + return _plan_err!( + "Cannot cast struct field '{}' from type {} to type {}", + target_field.name(), + source_field.data_type(), + target_field.data_type() + ); + } + } + } + } + // Missing fields in source are OK - they'll be filled with nulls + } + + // Extra fields in source are OK - they'll be ignored + Ok(()) +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::format::DEFAULT_CAST_OPTIONS; + use arrow::{ + array::{ + BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, MapArray, + MapBuilder, StringArray, StringBuilder, + }, + buffer::NullBuffer, + datatypes::{DataType, Field, FieldRef, Int32Type}, + }; + /// Macro to extract and downcast a column from a StructArray + macro_rules! get_column_as { + ($struct_array:expr, $column_name:expr, $array_type:ty) => { + $struct_array + .column_by_name($column_name) + .unwrap() + .as_any() + .downcast_ref::<$array_type>() + .unwrap() + }; + } + + fn field(name: &str, data_type: DataType) -> Field { + Field::new(name, data_type, true) + } + + fn non_null_field(name: &str, data_type: DataType) -> Field { + Field::new(name, data_type, false) + } + + fn arc_field(name: &str, data_type: DataType) -> FieldRef { + Arc::new(field(name, data_type)) + } + + fn struct_type(fields: Vec) -> DataType { + Struct(fields.into()) + } + + fn struct_field(name: &str, fields: Vec) -> Field { + field(name, struct_type(fields)) + } + + fn arc_struct_field(name: &str, fields: Vec) -> FieldRef { + Arc::new(struct_field(name, fields)) + } + + #[test] + fn test_cast_simple_column() { + let source = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + let target_field = field("ints", DataType::Int64); + let result = cast_column(&source, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.len(), 3); + assert_eq!(result.value(0), 1); + assert_eq!(result.value(1), 2); + assert_eq!(result.value(2), 3); + } + + #[test] + fn test_cast_column_with_options() { + let source = Arc::new(Int64Array::from(vec![1, i64::MAX])) as ArrayRef; + let target_field = field("ints", DataType::Int32); + + let safe_opts = CastOptions { + // safe: false - return Err for failure + safe: false, + ..DEFAULT_CAST_OPTIONS + }; + assert!(cast_column(&source, &target_field, &safe_opts).is_err()); + + let unsafe_opts = CastOptions { + // safe: true - return Null for failure + safe: true, + ..DEFAULT_CAST_OPTIONS + }; + let result = cast_column(&source, &target_field, &unsafe_opts).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.value(0), 1); + assert!(result.is_null(1)); + } + + #[test] + fn test_cast_struct_with_missing_field() { + let a_array = Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef; + let source_struct = StructArray::from(vec![( + arc_field("a", DataType::Int32), + Arc::clone(&a_array), + )]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "s", + vec![field("a", DataType::Int32), field("b", DataType::Utf8)], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_array.fields().len(), 2); + let a_result = get_column_as!(&struct_array, "a", Int32Array); + assert_eq!(a_result.value(0), 1); + assert_eq!(a_result.value(1), 2); + + let b_result = get_column_as!(&struct_array, "b", StringArray); + assert_eq!(b_result.len(), 2); + assert!(b_result.is_null(0)); + assert!(b_result.is_null(1)); + } + + #[test] + fn test_cast_struct_source_not_struct() { + let source = Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef; + let target_field = struct_field("s", vec![field("a", DataType::Int32)]); + + let result = cast_column(&source, &target_field, &DEFAULT_CAST_OPTIONS); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast column of type")); + assert!(error_msg.contains("to struct type")); + assert!(error_msg.contains("Source must be a struct")); + } + + #[test] + fn test_cast_struct_incompatible_child_type() { + let a_array = Arc::new(BinaryArray::from(vec![ + Some(b"a".as_ref()), + Some(b"b".as_ref()), + ])) as ArrayRef; + let source_struct = + StructArray::from(vec![(arc_field("a", DataType::Binary), a_array)]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field("s", vec![field("a", DataType::Int32)]); + + let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast struct field 'a'")); + } + + #[test] + fn test_validate_struct_compatibility_incompatible_types() { + // Source struct: {field1: Binary, field2: String} + let source_fields = vec![ + arc_field("field1", DataType::Binary), + arc_field("field2", DataType::Utf8), + ]; + + // Target struct: {field1: Int32} + let target_fields = vec![arc_field("field1", DataType::Int32)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast struct field 'field1'")); + assert!(error_msg.contains("Binary")); + assert!(error_msg.contains("Int32")); + } + + #[test] + fn test_validate_struct_compatibility_compatible_types() { + // Source struct: {field1: Int32, field2: String} + let source_fields = vec![ + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), + ]; + + // Target struct: {field1: Int64} (Int32 can cast to Int64) + let target_fields = vec![arc_field("field1", DataType::Int64)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_missing_field_in_source() { + // Source struct: {field2: String} (missing field1) + let source_fields = vec![arc_field("field2", DataType::Utf8)]; + + // Target struct: {field1: Int32} + let target_fields = vec![arc_field("field1", DataType::Int32)]; + + // Should be OK - missing fields will be filled with nulls + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_additional_field_in_source() { + // Source struct: {field1: Int32, field2: String} (extra field2) + let source_fields = vec![ + arc_field("field1", DataType::Int32), + arc_field("field2", DataType::Utf8), + ]; + + // Target struct: {field1: Int32} + let target_fields = vec![arc_field("field1", DataType::Int32)]; + + // Should be OK - extra fields in source are ignored + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_cast_struct_parent_nulls_retained() { + let a_array = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; + let fields = vec![arc_field("a", DataType::Int32)]; + let nulls = Some(NullBuffer::from(vec![true, false])); + let source_struct = StructArray::new(fields.clone().into(), vec![a_array], nulls); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field("s", vec![field("a", DataType::Int64)]); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_array.null_count(), 1); + assert!(struct_array.is_valid(0)); + assert!(struct_array.is_null(1)); + + let a_result = get_column_as!(&struct_array, "a", Int64Array); + assert_eq!(a_result.value(0), 1); + assert_eq!(a_result.value(1), 2); + } + + #[test] + fn test_validate_struct_compatibility_nullable_to_non_nullable() { + // Source struct: {field1: Int32 nullable} + let source_fields = vec![arc_field("field1", DataType::Int32)]; + + // Target struct: {field1: Int32 non-nullable} + let target_fields = vec![Arc::new(non_null_field("field1", DataType::Int32))]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("field1")); + assert!(error_msg.contains("non-nullable")); + } + + #[test] + fn test_validate_struct_compatibility_non_nullable_to_nullable() { + // Source struct: {field1: Int32 non-nullable} + let source_fields = vec![Arc::new(non_null_field("field1", DataType::Int32))]; + + // Target struct: {field1: Int32 nullable} + let target_fields = vec![arc_field("field1", DataType::Int32)]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_struct_compatibility_nested_nullable_to_non_nullable() { + // Source struct: {field1: {nested: Int32 nullable}} + let source_fields = vec![Arc::new(non_null_field( + "field1", + struct_type(vec![field("nested", DataType::Int32)]), + ))]; + + // Target struct: {field1: {nested: Int32 non-nullable}} + let target_fields = vec![Arc::new(non_null_field( + "field1", + struct_type(vec![non_null_field("nested", DataType::Int32)]), + ))]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("nested")); + assert!(error_msg.contains("non-nullable")); + } + + #[test] + fn test_cast_nested_struct_with_extra_and_missing_fields() { + // Source inner struct has fields a, b, extra + let a = Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef; + let b = Arc::new(Int32Array::from(vec![Some(2), Some(3)])) as ArrayRef; + let extra = Arc::new(Int32Array::from(vec![Some(9), Some(10)])) as ArrayRef; + + let inner = StructArray::from(vec![ + (arc_field("a", DataType::Int32), a), + (arc_field("b", DataType::Int32), b), + (arc_field("extra", DataType::Int32), extra), + ]); + + let source_struct = StructArray::from(vec![( + arc_struct_field( + "inner", + vec![ + field("a", DataType::Int32), + field("b", DataType::Int32), + field("extra", DataType::Int32), + ], + ), + Arc::new(inner) as ArrayRef, + )]); + let source_col = Arc::new(source_struct) as ArrayRef; + + // Target inner struct reorders fields, adds "missing", and drops "extra" + let target_field = struct_field( + "outer", + vec![struct_field( + "inner", + vec![ + field("b", DataType::Int64), + field("a", DataType::Int32), + field("missing", DataType::Int32), + ], + )], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let outer = result.as_any().downcast_ref::().unwrap(); + let inner = get_column_as!(&outer, "inner", StructArray); + assert_eq!(inner.fields().len(), 3); + + let b = get_column_as!(inner, "b", Int64Array); + assert_eq!(b.value(0), 2); + assert_eq!(b.value(1), 3); + assert!(!b.is_null(0)); + assert!(!b.is_null(1)); + + let a = get_column_as!(inner, "a", Int32Array); + assert_eq!(a.value(0), 1); + assert!(a.is_null(1)); + + let missing = get_column_as!(inner, "missing", Int32Array); + assert!(missing.is_null(0)); + assert!(missing.is_null(1)); + } + + #[test] + fn test_cast_struct_with_array_and_map_fields() { + // Array field with second row null + let arr_array = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + ])) as ArrayRef; + + // Map field with second row null + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::new(); + let mut map_builder = MapBuilder::new(None, string_builder, int_builder); + map_builder.keys().append_value("a"); + map_builder.values().append_value(1); + map_builder.append(true).unwrap(); + map_builder.append(false).unwrap(); + let map_array = Arc::new(map_builder.finish()) as ArrayRef; + + let source_struct = StructArray::from(vec![ + ( + arc_field( + "arr", + DataType::List(Arc::new(field("item", DataType::Int32))), + ), + arr_array, + ), + ( + arc_field( + "map", + DataType::Map( + Arc::new(non_null_field( + "entries", + struct_type(vec![ + non_null_field("keys", DataType::Utf8), + field("values", DataType::Int32), + ]), + )), + false, + ), + ), + map_array, + ), + ]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "s", + vec![ + field( + "arr", + DataType::List(Arc::new(field("item", DataType::Int32))), + ), + field( + "map", + DataType::Map( + Arc::new(non_null_field( + "entries", + struct_type(vec![ + non_null_field("keys", DataType::Utf8), + field("values", DataType::Int32), + ]), + )), + false, + ), + ), + ], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + + let arr = get_column_as!(&struct_array, "arr", ListArray); + assert!(!arr.is_null(0)); + assert!(arr.is_null(1)); + let arr0 = arr.value(0); + let values = arr0.as_any().downcast_ref::().unwrap(); + assert_eq!(values.value(0), 1); + assert_eq!(values.value(1), 2); + + let map = get_column_as!(&struct_array, "map", MapArray); + assert!(!map.is_null(0)); + assert!(map.is_null(1)); + let map0 = map.value(0); + let entries = map0.as_any().downcast_ref::().unwrap(); + let keys = get_column_as!(entries, "keys", StringArray); + let vals = get_column_as!(entries, "values", Int32Array); + assert_eq!(keys.value(0), "a"); + assert_eq!(vals.value(0), 1); + } + + #[test] + fn test_cast_struct_field_order_differs() { + let a = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef; + let b = Arc::new(Int32Array::from(vec![Some(3), None])) as ArrayRef; + + let source_struct = StructArray::from(vec![ + (arc_field("a", DataType::Int32), a), + (arc_field("b", DataType::Int32), b), + ]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = struct_field( + "s", + vec![field("b", DataType::Int64), field("a", DataType::Int32)], + ); + + let result = + cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + + let b_col = get_column_as!(&struct_array, "b", Int64Array); + assert_eq!(b_col.value(0), 3); + assert!(b_col.is_null(1)); + + let a_col = get_column_as!(&struct_array, "a", Int32Array); + assert_eq!(a_col.value(0), 1); + assert_eq!(a_col.value(1), 2); + } +} diff --git a/datafusion/common/src/null_equality.rs b/datafusion/common/src/null_equality.rs new file mode 100644 index 0000000000000..847fb0975703e --- /dev/null +++ b/datafusion/common/src/null_equality.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Represents the behavior for null values when evaluating equality. Currently, its primary use +/// case is to define the behavior of joins for null values. +/// +/// # Examples +/// +/// The following table shows the expected equality behavior for `NullEquality`. +/// +/// | A | B | NullEqualsNothing | NullEqualsNull | +/// |------|------|-------------------|----------------| +/// | NULL | NULL | false | true | +/// | NULL | 'b' | false | false | +/// | 'a' | NULL | false | false | +/// | 'a' | 'b' | false | false | +/// +/// # Order +/// +/// The order on this type represents the "restrictiveness" of the behavior. The more restrictive +/// a behavior is, the fewer elements are considered to be equal to null. +/// [NullEquality::NullEqualsNothing] represents the most restrictive behavior. +/// +/// This mirrors the old order with `null_equals_null` booleans, as `false` indicated that +/// `null != null`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] +pub enum NullEquality { + /// Null is *not* equal to anything (`null != null`) + NullEqualsNothing, + /// Null is equal to null (`null == null`) + NullEqualsNull, +} diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index d2802c096da1b..7582cff56f87a 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -48,7 +48,7 @@ impl ParamValues { for (i, (param_type, value)) in iter.enumerate() { if *param_type != value.data_type() { return _plan_err!( - "Expected parameter of type {:?}, got {:?} at index {}", + "Expected parameter of type {}, got {:?} at index {}", param_type, value.data_type(), i diff --git a/datafusion/common/src/parsers.rs b/datafusion/common/src/parsers.rs index c73c8a55f18c5..cd3d607dacd88 100644 --- a/datafusion/common/src/parsers.rs +++ b/datafusion/common/src/parsers.rs @@ -20,7 +20,7 @@ use std::fmt::Display; use std::str::FromStr; -use sqlparser::parser::ParserError; +use crate::DataFusionError; /// Readable file compression type #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -38,9 +38,9 @@ pub enum CompressionTypeVariant { } impl FromStr for CompressionTypeVariant { - type Err = ParserError; + type Err = DataFusionError; - fn from_str(s: &str) -> Result { + fn from_str(s: &str) -> Result { let s = s.to_uppercase(); match s.as_str() { "GZIP" | "GZ" => Ok(Self::GZIP), @@ -48,7 +48,7 @@ impl FromStr for CompressionTypeVariant { "XZ" => Ok(Self::XZ), "ZST" | "ZSTD" => Ok(Self::ZSTD), "" | "UNCOMPRESSED" => Ok(Self::UNCOMPRESSED), - _ => Err(ParserError::ParserError(format!( + _ => Err(DataFusionError::NotImplemented(format!( "Unsupported file compression type {s}" ))), } @@ -64,7 +64,7 @@ impl Display for CompressionTypeVariant { Self::ZSTD => "ZSTD", Self::UNCOMPRESSED => "", }; - write!(f, "{}", str) + write!(f, "{str}") } } diff --git a/datafusion/common/src/pruning.rs b/datafusion/common/src/pruning.rs new file mode 100644 index 0000000000000..48750e3c995c4 --- /dev/null +++ b/datafusion/common/src/pruning.rs @@ -0,0 +1,1122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, NullArray, UInt64Array}; +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::datatypes::{FieldRef, Schema, SchemaRef}; +use std::collections::HashSet; +use std::sync::Arc; + +use crate::error::DataFusionError; +use crate::stats::Precision; +use crate::{Column, Statistics}; +use crate::{ColumnStatistics, ScalarValue}; + +/// A source of runtime statistical information to [`PruningPredicate`]s. +/// +/// # Supported Information +/// +/// 1. Minimum and maximum values for columns +/// +/// 2. Null counts and row counts for columns +/// +/// 3. Whether the values in a column are contained in a set of literals +/// +/// # Vectorized Interface +/// +/// Information for containers / files are returned as Arrow [`ArrayRef`], so +/// the evaluation happens once on a single `RecordBatch`, which amortizes the +/// overhead of evaluating the predicate. This is important when pruning 1000s +/// of containers which often happens in analytic systems that have 1000s of +/// potential files to consider. +/// +/// For example, for the following three files with a single column `a`: +/// ```text +/// file1: column a: min=5, max=10 +/// file2: column a: No stats +/// file2: column a: min=20, max=30 +/// ``` +/// +/// PruningStatistics would return: +/// +/// ```text +/// min_values("a") -> Some([5, Null, 20]) +/// max_values("a") -> Some([10, Null, 30]) +/// min_values("X") -> None +/// ``` +/// +/// [`PruningPredicate`]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/pruning/struct.PruningPredicate.html +pub trait PruningStatistics { + /// Return the minimum values for the named column, if known. + /// + /// If the minimum value for a particular container is not known, the + /// returned array should have `null` in that row. If the minimum value is + /// not known for any row, return `None`. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn min_values(&self, column: &Column) -> Option; + + /// Return the maximum values for the named column, if known. + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn max_values(&self, column: &Column) -> Option; + + /// Return the number of containers (e.g. Row Groups) being pruned with + /// these statistics. + /// + /// This value corresponds to the size of the [`ArrayRef`] returned by + /// [`Self::min_values`], [`Self::max_values`], [`Self::null_counts`], + /// and [`Self::row_counts`]. + fn num_containers(&self) -> usize; + + /// Return the number of null values for the named column as an + /// [`UInt64Array`] + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + /// + /// [`UInt64Array`]: arrow::array::UInt64Array + fn null_counts(&self, column: &Column) -> Option; + + /// Return the number of rows for the named column in each container + /// as an [`UInt64Array`]. + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + /// + /// [`UInt64Array`]: arrow::array::UInt64Array + fn row_counts(&self, column: &Column) -> Option; + + /// Returns [`BooleanArray`] where each row represents information known + /// about specific literal `values` in a column. + /// + /// For example, Parquet Bloom Filters implement this API to communicate + /// that `values` are known not to be present in a Row Group. + /// + /// The returned array has one row for each container, with the following + /// meanings: + /// * `true` if the values in `column` ONLY contain values from `values` + /// * `false` if the values in `column` are NOT ANY of `values` + /// * `null` if the neither of the above holds or is unknown. + /// + /// If these statistics can not determine column membership for any + /// container, return `None` (the default). + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option; +} + +/// Prune files based on their partition values. +/// +/// This is used both at planning time and execution time to prune +/// files based on their partition values. +/// This feeds into [`CompositePruningStatistics`] to allow pruning +/// with filters that depend both on partition columns and data columns +/// (e.g. `WHERE partition_col = data_col`). +#[derive(Clone)] +pub struct PartitionPruningStatistics { + /// Values for each column for each container. + /// + /// The outer vectors represent the columns while the inner vectors + /// represent the containers. The order must match the order of the + /// partition columns in [`PartitionPruningStatistics::partition_schema`]. + partition_values: Vec, + /// The number of containers. + /// + /// Stored since the partition values are column-major and if + /// there are no columns we wouldn't know the number of containers. + num_containers: usize, + /// The schema of the partition columns. + /// + /// This must **not** be the schema of the entire file or table: it must + /// only be the schema of the partition columns, in the same order as the + /// values in [`PartitionPruningStatistics::partition_values`]. + partition_schema: SchemaRef, +} + +impl PartitionPruningStatistics { + /// Create a new instance of [`PartitionPruningStatistics`]. + /// + /// Args: + /// * `partition_values`: A vector of vectors of [`ScalarValue`]s. + /// The outer vector represents the containers while the inner + /// vector represents the partition values for each column. + /// Note that this is the **opposite** of the order of the + /// partition columns in `PartitionPruningStatistics::partition_schema`. + /// * `partition_schema`: The schema of the partition columns. + /// This must **not** be the schema of the entire file or table: + /// instead it must only be the schema of the partition columns, + /// in the same order as the values in `partition_values`. + pub fn try_new( + partition_values: Vec>, + partition_fields: Vec, + ) -> Result { + let num_containers = partition_values.len(); + let partition_schema = Arc::new(Schema::new(partition_fields)); + let mut partition_values_by_column = + vec![ + Vec::with_capacity(partition_values.len()); + partition_schema.fields().len() + ]; + for partition_value in partition_values { + for (i, value) in partition_value.into_iter().enumerate() { + partition_values_by_column[i].push(value); + } + } + Ok(Self { + partition_values: partition_values_by_column + .into_iter() + .map(|v| { + if v.is_empty() { + Ok(Arc::new(NullArray::new(0)) as ArrayRef) + } else { + ScalarValue::iter_to_array(v) + } + }) + .collect::, _>>()?, + num_containers, + partition_schema, + }) + } +} + +impl PruningStatistics for PartitionPruningStatistics { + fn min_values(&self, column: &Column) -> Option { + let index = self.partition_schema.index_of(column.name()).ok()?; + self.partition_values.get(index).and_then(|v| { + if v.is_empty() || v.null_count() == v.len() { + // If the array is empty or all nulls, return None + None + } else { + // Otherwise, return the array as is + Some(Arc::clone(v)) + } + }) + } + + fn max_values(&self, column: &Column) -> Option { + self.min_values(column) + } + + fn num_containers(&self) -> usize { + self.num_containers + } + + fn null_counts(&self, _column: &Column) -> Option { + None + } + + fn row_counts(&self, _column: &Column) -> Option { + None + } + + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + let index = self.partition_schema.index_of(column.name()).ok()?; + let array = self.partition_values.get(index)?; + let boolean_array = values.iter().try_fold(None, |acc, v| { + let arrow_value = v.to_scalar().ok()?; + let eq_result = arrow::compute::kernels::cmp::eq(array, &arrow_value).ok()?; + match acc { + None => Some(Some(eq_result)), + Some(acc_array) => { + arrow::compute::kernels::boolean::and(&acc_array, &eq_result) + .map(Some) + .ok() + } + } + })??; + // If the boolean array is empty or all null values, return None + if boolean_array.is_empty() || boolean_array.null_count() == boolean_array.len() { + None + } else { + Some(boolean_array) + } + } +} + +/// Prune a set of containers represented by their statistics. +/// +/// Each [`Statistics`] represents a "container" -- some collection of data +/// that has statistics of its columns. +/// +/// It is up to the caller to decide what each container represents. For +/// example, they can come from a file (e.g. [`PartitionedFile`]) or a set of of +/// files (e.g. [`FileGroup`]) +/// +/// [`PartitionedFile`]: https://docs.rs/datafusion/latest/datafusion/datasource/listing/struct.PartitionedFile.html +/// [`FileGroup`]: https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.FileGroup.html +#[derive(Clone)] +pub struct PrunableStatistics { + /// Statistics for each container. + /// These are taken as a reference since they may be rather large / expensive to clone + /// and we often won't return all of them as ArrayRefs (we only return the columns the predicate requests). + statistics: Vec>, + /// The schema of the file these statistics are for. + schema: SchemaRef, +} + +impl PrunableStatistics { + /// Create a new instance of [`PrunableStatistics`]. + /// Each [`Statistics`] represents a container (e.g. a file or a partition of files). + /// The `schema` is the schema of the data in the containers and should apply to all files. + pub fn new(statistics: Vec>, schema: SchemaRef) -> Self { + Self { statistics, schema } + } + + fn get_exact_column_statistics( + &self, + column: &Column, + get_stat: impl Fn(&ColumnStatistics) -> &Precision, + ) -> Option { + let index = self.schema.index_of(column.name()).ok()?; + let mut has_value = false; + match ScalarValue::iter_to_array(self.statistics.iter().map(|s| { + s.column_statistics + .get(index) + .and_then(|stat| { + if let Precision::Exact(min) = get_stat(stat) { + has_value = true; + Some(min.clone()) + } else { + None + } + }) + .unwrap_or(ScalarValue::Null) + })) { + // If there is any non-null value and no errors, return the array + Ok(array) => has_value.then_some(array), + Err(_) => { + log::warn!( + "Failed to convert min values to array for column {}", + column.name() + ); + None + } + } + } +} + +impl PruningStatistics for PrunableStatistics { + fn min_values(&self, column: &Column) -> Option { + self.get_exact_column_statistics(column, |stat| &stat.min_value) + } + + fn max_values(&self, column: &Column) -> Option { + self.get_exact_column_statistics(column, |stat| &stat.max_value) + } + + fn num_containers(&self) -> usize { + self.statistics.len() + } + + fn null_counts(&self, column: &Column) -> Option { + let index = self.schema.index_of(column.name()).ok()?; + if self.statistics.iter().any(|s| { + s.column_statistics + .get(index) + .is_some_and(|stat| stat.null_count.is_exact().unwrap_or(false)) + }) { + Some(Arc::new( + self.statistics + .iter() + .map(|s| { + s.column_statistics.get(index).and_then(|stat| { + if let Precision::Exact(null_count) = &stat.null_count { + u64::try_from(*null_count).ok() + } else { + None + } + }) + }) + .collect::(), + )) + } else { + None + } + } + + fn row_counts(&self, column: &Column) -> Option { + // If the column does not exist in the schema, return None + if self.schema.index_of(column.name()).is_err() { + return None; + } + if self + .statistics + .iter() + .any(|s| s.num_rows.is_exact().unwrap_or(false)) + { + Some(Arc::new( + self.statistics + .iter() + .map(|s| { + if let Precision::Exact(row_count) = &s.num_rows { + u64::try_from(*row_count).ok() + } else { + None + } + }) + .collect::(), + )) + } else { + None + } + } + + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } +} + +/// Combine multiple [`PruningStatistics`] into a single +/// [`CompositePruningStatistics`]. +/// This can be used to combine statistics from different sources, +/// for example partition values and file statistics. +/// This allows pruning with filters that depend on multiple sources of statistics, +/// such as `WHERE partition_col = data_col`. +/// This is done by iterating over the statistics and returning the first +/// one that has information for the requested column. +/// If multiple statistics have information for the same column, +/// the first one is returned without any regard for completeness or accuracy. +/// That is: if the first statistics has information for a column, even if it is incomplete, +/// that is returned even if a later statistics has more complete information. +pub struct CompositePruningStatistics { + pub statistics: Vec>, +} + +impl CompositePruningStatistics { + /// Create a new instance of [`CompositePruningStatistics`] from + /// a vector of [`PruningStatistics`]. + pub fn new(statistics: Vec>) -> Self { + assert!(!statistics.is_empty()); + // Check that all statistics have the same number of containers + let num_containers = statistics[0].num_containers(); + for stats in &statistics { + assert_eq!(num_containers, stats.num_containers()); + } + Self { statistics } + } +} + +impl PruningStatistics for CompositePruningStatistics { + fn min_values(&self, column: &Column) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.min_values(column) { + return Some(array); + } + } + None + } + + fn max_values(&self, column: &Column) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.max_values(column) { + return Some(array); + } + } + None + } + + fn num_containers(&self) -> usize { + self.statistics[0].num_containers() + } + + fn null_counts(&self, column: &Column) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.null_counts(column) { + return Some(array); + } + } + None + } + + fn row_counts(&self, column: &Column) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.row_counts(column) { + return Some(array); + } + } + None + } + + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.contained(column, values) { + return Some(array); + } + } + None + } +} + +#[cfg(test)] +mod tests { + use crate::{ + cast::{as_int32_array, as_uint64_array}, + ColumnStatistics, + }; + + use super::*; + use arrow::datatypes::{DataType, Field}; + use std::sync::Arc; + + #[test] + fn test_partition_pruning_statistics() { + let partition_values = vec![ + vec![ScalarValue::from(1i32), ScalarValue::from(2i32)], + vec![ScalarValue::from(3i32), ScalarValue::from(4i32)], + ]; + let partition_fields = vec![ + Arc::new(Field::new("a", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, false)), + ]; + let partition_stats = + PartitionPruningStatistics::try_new(partition_values, partition_fields) + .unwrap(); + + let column_a = Column::new_unqualified("a"); + let column_b = Column::new_unqualified("b"); + + // Partition values don't know anything about nulls or row counts + assert!(partition_stats.null_counts(&column_a).is_none()); + assert!(partition_stats.row_counts(&column_a).is_none()); + assert!(partition_stats.null_counts(&column_b).is_none()); + assert!(partition_stats.row_counts(&column_b).is_none()); + + // Min/max values are the same as the partition values + let min_values_a = + as_int32_array(&partition_stats.min_values(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_a = vec![Some(1), Some(3)]; + assert_eq!(min_values_a, expected_values_a); + let max_values_a = + as_int32_array(&partition_stats.max_values(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_a = vec![Some(1), Some(3)]; + assert_eq!(max_values_a, expected_values_a); + + let min_values_b = + as_int32_array(&partition_stats.min_values(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_b = vec![Some(2), Some(4)]; + assert_eq!(min_values_b, expected_values_b); + let max_values_b = + as_int32_array(&partition_stats.max_values(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_b = vec![Some(2), Some(4)]; + assert_eq!(max_values_b, expected_values_b); + + // Contained values are only true for the partition values + let values = HashSet::from([ScalarValue::from(1i32)]); + let contained_a = partition_stats.contained(&column_a, &values).unwrap(); + let expected_contained_a = BooleanArray::from(vec![true, false]); + assert_eq!(contained_a, expected_contained_a); + let contained_b = partition_stats.contained(&column_b, &values).unwrap(); + let expected_contained_b = BooleanArray::from(vec![false, false]); + assert_eq!(contained_b, expected_contained_b); + + // The number of containers is the length of the partition values + assert_eq!(partition_stats.num_containers(), 2); + } + + #[test] + fn test_partition_pruning_statistics_empty() { + let partition_values = vec![]; + let partition_fields = vec![ + Arc::new(Field::new("a", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, false)), + ]; + let partition_stats = + PartitionPruningStatistics::try_new(partition_values, partition_fields) + .unwrap(); + + let column_a = Column::new_unqualified("a"); + let column_b = Column::new_unqualified("b"); + + // Partition values don't know anything about nulls or row counts + assert!(partition_stats.null_counts(&column_a).is_none()); + assert!(partition_stats.row_counts(&column_a).is_none()); + assert!(partition_stats.null_counts(&column_b).is_none()); + assert!(partition_stats.row_counts(&column_b).is_none()); + + // Min/max values are all missing + assert!(partition_stats.min_values(&column_a).is_none()); + assert!(partition_stats.max_values(&column_a).is_none()); + assert!(partition_stats.min_values(&column_b).is_none()); + assert!(partition_stats.max_values(&column_b).is_none()); + + // Contained values are all empty + let values = HashSet::from([ScalarValue::from(1i32)]); + assert!(partition_stats.contained(&column_a, &values).is_none()); + } + + #[test] + fn test_statistics_pruning_statistics() { + let statistics = vec![ + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(0i32))) + .with_max_value(Precision::Exact(ScalarValue::from(100i32))) + .with_null_count(Precision::Exact(0)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(100i32))) + .with_max_value(Precision::Exact(ScalarValue::from(200i32))) + .with_null_count(Precision::Exact(5)), + ) + .with_num_rows(Precision::Exact(100)), + ), + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(50i32))) + .with_max_value(Precision::Exact(ScalarValue::from(300i32))) + .with_null_count(Precision::Exact(10)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(200i32))) + .with_max_value(Precision::Exact(ScalarValue::from(400i32))) + .with_null_count(Precision::Exact(0)), + ) + .with_num_rows(Precision::Exact(200)), + ), + ]; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + let pruning_stats = PrunableStatistics::new(statistics, schema); + + let column_a = Column::new_unqualified("a"); + let column_b = Column::new_unqualified("b"); + + // Min/max values are the same as the statistics + let min_values_a = as_int32_array(&pruning_stats.min_values(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_a = vec![Some(0), Some(50)]; + assert_eq!(min_values_a, expected_values_a); + let max_values_a = as_int32_array(&pruning_stats.max_values(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_a = vec![Some(100), Some(300)]; + assert_eq!(max_values_a, expected_values_a); + let min_values_b = as_int32_array(&pruning_stats.min_values(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_b = vec![Some(100), Some(200)]; + assert_eq!(min_values_b, expected_values_b); + let max_values_b = as_int32_array(&pruning_stats.max_values(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_b = vec![Some(200), Some(400)]; + assert_eq!(max_values_b, expected_values_b); + + // Null counts are the same as the statistics + let null_counts_a = + as_uint64_array(&pruning_stats.null_counts(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_null_counts_a = vec![Some(0), Some(10)]; + assert_eq!(null_counts_a, expected_null_counts_a); + let null_counts_b = + as_uint64_array(&pruning_stats.null_counts(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_null_counts_b = vec![Some(5), Some(0)]; + assert_eq!(null_counts_b, expected_null_counts_b); + + // Row counts are the same as the statistics + let row_counts_a = as_uint64_array(&pruning_stats.row_counts(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts_a = vec![Some(100), Some(200)]; + assert_eq!(row_counts_a, expected_row_counts_a); + let row_counts_b = as_uint64_array(&pruning_stats.row_counts(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts_b = vec![Some(100), Some(200)]; + assert_eq!(row_counts_b, expected_row_counts_b); + + // Contained values are all null/missing (we can't know this just from statistics) + let values = HashSet::from([ScalarValue::from(0i32)]); + assert!(pruning_stats.contained(&column_a, &values).is_none()); + assert!(pruning_stats.contained(&column_b, &values).is_none()); + + // The number of containers is the length of the statistics + assert_eq!(pruning_stats.num_containers(), 2); + + // Test with a column that has no statistics + let column_c = Column::new_unqualified("c"); + assert!(pruning_stats.min_values(&column_c).is_none()); + assert!(pruning_stats.max_values(&column_c).is_none()); + assert!(pruning_stats.null_counts(&column_c).is_none()); + // Since row counts uses the first column that has row counts we get them back even + // if this columns does not have them set. + // This is debatable, personally I think `row_count` should not take a `Column` as an argument + // at all since all columns should have the same number of rows. + // But for now we just document the current behavior in this test. + let row_counts_c = as_uint64_array(&pruning_stats.row_counts(&column_c).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts_c = vec![Some(100), Some(200)]; + assert_eq!(row_counts_c, expected_row_counts_c); + assert!(pruning_stats.contained(&column_c, &values).is_none()); + + // Test with a column that doesn't exist + let column_d = Column::new_unqualified("d"); + assert!(pruning_stats.min_values(&column_d).is_none()); + assert!(pruning_stats.max_values(&column_d).is_none()); + assert!(pruning_stats.null_counts(&column_d).is_none()); + assert!(pruning_stats.row_counts(&column_d).is_none()); + assert!(pruning_stats.contained(&column_d, &values).is_none()); + } + + #[test] + fn test_statistics_pruning_statistics_empty() { + let statistics = vec![]; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + let pruning_stats = PrunableStatistics::new(statistics, schema); + + let column_a = Column::new_unqualified("a"); + let column_b = Column::new_unqualified("b"); + + // Min/max values are all missing + assert!(pruning_stats.min_values(&column_a).is_none()); + assert!(pruning_stats.max_values(&column_a).is_none()); + assert!(pruning_stats.min_values(&column_b).is_none()); + assert!(pruning_stats.max_values(&column_b).is_none()); + + // Null counts are all missing + assert!(pruning_stats.null_counts(&column_a).is_none()); + assert!(pruning_stats.null_counts(&column_b).is_none()); + + // Row counts are all missing + assert!(pruning_stats.row_counts(&column_a).is_none()); + assert!(pruning_stats.row_counts(&column_b).is_none()); + + // Contained values are all empty + let values = HashSet::from([ScalarValue::from(1i32)]); + assert!(pruning_stats.contained(&column_a, &values).is_none()); + } + + #[test] + fn test_composite_pruning_statistics_partition_and_file() { + // Create partition statistics + let partition_values = vec![ + vec![ScalarValue::from(1i32), ScalarValue::from(10i32)], + vec![ScalarValue::from(2i32), ScalarValue::from(20i32)], + ]; + let partition_fields = vec![ + Arc::new(Field::new("part_a", DataType::Int32, false)), + Arc::new(Field::new("part_b", DataType::Int32, false)), + ]; + let partition_stats = + PartitionPruningStatistics::try_new(partition_values, partition_fields) + .unwrap(); + + // Create file statistics + let file_statistics = vec![ + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(100i32))) + .with_max_value(Precision::Exact(ScalarValue::from(200i32))) + .with_null_count(Precision::Exact(0)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(300i32))) + .with_max_value(Precision::Exact(ScalarValue::from(400i32))) + .with_null_count(Precision::Exact(5)), + ) + .with_num_rows(Precision::Exact(100)), + ), + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(500i32))) + .with_max_value(Precision::Exact(ScalarValue::from(600i32))) + .with_null_count(Precision::Exact(10)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(700i32))) + .with_max_value(Precision::Exact(ScalarValue::from(800i32))) + .with_null_count(Precision::Exact(0)), + ) + .with_num_rows(Precision::Exact(200)), + ), + ]; + + let file_schema = Arc::new(Schema::new(vec![ + Field::new("col_x", DataType::Int32, false), + Field::new("col_y", DataType::Int32, false), + ])); + let file_stats = PrunableStatistics::new(file_statistics, file_schema); + + // Create composite statistics + let composite_stats = CompositePruningStatistics::new(vec![ + Box::new(partition_stats), + Box::new(file_stats), + ]); + + // Test accessing columns that are only in partition statistics + let part_a = Column::new_unqualified("part_a"); + let part_b = Column::new_unqualified("part_b"); + + // Test accessing columns that are only in file statistics + let col_x = Column::new_unqualified("col_x"); + let col_y = Column::new_unqualified("col_y"); + + // For partition columns, should get values from partition statistics + let min_values_part_a = + as_int32_array(&composite_stats.min_values(&part_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_part_a = vec![Some(1), Some(2)]; + assert_eq!(min_values_part_a, expected_values_part_a); + + let max_values_part_a = + as_int32_array(&composite_stats.max_values(&part_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + // For partition values, min and max are the same + assert_eq!(max_values_part_a, expected_values_part_a); + + let min_values_part_b = + as_int32_array(&composite_stats.min_values(&part_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_part_b = vec![Some(10), Some(20)]; + assert_eq!(min_values_part_b, expected_values_part_b); + + // For file columns, should get values from file statistics + let min_values_col_x = + as_int32_array(&composite_stats.min_values(&col_x).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_col_x = vec![Some(100), Some(500)]; + assert_eq!(min_values_col_x, expected_values_col_x); + + let max_values_col_x = + as_int32_array(&composite_stats.max_values(&col_x).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_max_values_col_x = vec![Some(200), Some(600)]; + assert_eq!(max_values_col_x, expected_max_values_col_x); + + let min_values_col_y = + as_int32_array(&composite_stats.min_values(&col_y).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_col_y = vec![Some(300), Some(700)]; + assert_eq!(min_values_col_y, expected_values_col_y); + + // Test null counts - only available from file statistics + assert!(composite_stats.null_counts(&part_a).is_none()); + assert!(composite_stats.null_counts(&part_b).is_none()); + + let null_counts_col_x = + as_uint64_array(&composite_stats.null_counts(&col_x).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_null_counts_col_x = vec![Some(0), Some(10)]; + assert_eq!(null_counts_col_x, expected_null_counts_col_x); + + // Test row counts - only available from file statistics + assert!(composite_stats.row_counts(&part_a).is_none()); + let row_counts_col_x = + as_uint64_array(&composite_stats.row_counts(&col_x).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts = vec![Some(100), Some(200)]; + assert_eq!(row_counts_col_x, expected_row_counts); + + // Test contained values - only available from partition statistics + let values = HashSet::from([ScalarValue::from(1i32)]); + let contained_part_a = composite_stats.contained(&part_a, &values).unwrap(); + let expected_contained_part_a = BooleanArray::from(vec![true, false]); + assert_eq!(contained_part_a, expected_contained_part_a); + + // File statistics don't implement contained + assert!(composite_stats.contained(&col_x, &values).is_none()); + + // Non-existent column should return None for everything + let non_existent = Column::new_unqualified("non_existent"); + assert!(composite_stats.min_values(&non_existent).is_none()); + assert!(composite_stats.max_values(&non_existent).is_none()); + assert!(composite_stats.null_counts(&non_existent).is_none()); + assert!(composite_stats.row_counts(&non_existent).is_none()); + assert!(composite_stats.contained(&non_existent, &values).is_none()); + + // Verify num_containers matches + assert_eq!(composite_stats.num_containers(), 2); + } + + #[test] + fn test_composite_pruning_statistics_priority() { + // Create two sets of file statistics with the same column names + // but different values to test that the first one gets priority + + // First set of statistics + let first_statistics = vec![ + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(100i32))) + .with_max_value(Precision::Exact(ScalarValue::from(200i32))) + .with_null_count(Precision::Exact(0)), + ) + .with_num_rows(Precision::Exact(100)), + ), + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(300i32))) + .with_max_value(Precision::Exact(ScalarValue::from(400i32))) + .with_null_count(Precision::Exact(5)), + ) + .with_num_rows(Precision::Exact(200)), + ), + ]; + + let first_schema = Arc::new(Schema::new(vec![Field::new( + "col_a", + DataType::Int32, + false, + )])); + let first_stats = PrunableStatistics::new(first_statistics, first_schema); + + // Second set of statistics with the same column name but different values + let second_statistics = vec![ + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(1000i32))) + .with_max_value(Precision::Exact(ScalarValue::from(2000i32))) + .with_null_count(Precision::Exact(10)), + ) + .with_num_rows(Precision::Exact(1000)), + ), + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(3000i32))) + .with_max_value(Precision::Exact(ScalarValue::from(4000i32))) + .with_null_count(Precision::Exact(20)), + ) + .with_num_rows(Precision::Exact(2000)), + ), + ]; + + let second_schema = Arc::new(Schema::new(vec![Field::new( + "col_a", + DataType::Int32, + false, + )])); + let second_stats = PrunableStatistics::new(second_statistics, second_schema); + + // Create composite statistics with first stats having priority + let composite_stats = CompositePruningStatistics::new(vec![ + Box::new(first_stats.clone()), + Box::new(second_stats.clone()), + ]); + + let col_a = Column::new_unqualified("col_a"); + + // Should get values from first statistics since it has priority + let min_values = as_int32_array(&composite_stats.min_values(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_min_values = vec![Some(100), Some(300)]; + assert_eq!(min_values, expected_min_values); + + let max_values = as_int32_array(&composite_stats.max_values(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_max_values = vec![Some(200), Some(400)]; + assert_eq!(max_values, expected_max_values); + + let null_counts = as_uint64_array(&composite_stats.null_counts(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_null_counts = vec![Some(0), Some(5)]; + assert_eq!(null_counts, expected_null_counts); + + let row_counts = as_uint64_array(&composite_stats.row_counts(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts = vec![Some(100), Some(200)]; + assert_eq!(row_counts, expected_row_counts); + + // Create composite statistics with second stats having priority + // Now that we've added Clone trait to PrunableStatistics, we can just clone them + + let composite_stats_reversed = CompositePruningStatistics::new(vec![ + Box::new(second_stats.clone()), + Box::new(first_stats.clone()), + ]); + + // Should get values from second statistics since it now has priority + let min_values = + as_int32_array(&composite_stats_reversed.min_values(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_min_values = vec![Some(1000), Some(3000)]; + assert_eq!(min_values, expected_min_values); + + let max_values = + as_int32_array(&composite_stats_reversed.max_values(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_max_values = vec![Some(2000), Some(4000)]; + assert_eq!(max_values, expected_max_values); + + let null_counts = + as_uint64_array(&composite_stats_reversed.null_counts(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_null_counts = vec![Some(10), Some(20)]; + assert_eq!(null_counts, expected_null_counts); + + let row_counts = + as_uint64_array(&composite_stats_reversed.row_counts(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts = vec![Some(1000), Some(2000)]; + assert_eq!(row_counts, expected_row_counts); + } + + #[test] + fn test_composite_pruning_statistics_empty_and_mismatched_containers() { + // Test with empty statistics vector + // This should never happen, so we panic instead of returning a Result which would burned callers + let result = std::panic::catch_unwind(|| { + CompositePruningStatistics::new(vec![]); + }); + assert!(result.is_err()); + + // We should panic here because the number of containers is different + let result = std::panic::catch_unwind(|| { + // Create statistics with different number of containers + // Use partition stats for the test + let partition_values_1 = vec![ + vec![ScalarValue::from(1i32), ScalarValue::from(10i32)], + vec![ScalarValue::from(2i32), ScalarValue::from(20i32)], + ]; + let partition_fields_1 = vec![ + Arc::new(Field::new("part_a", DataType::Int32, false)), + Arc::new(Field::new("part_b", DataType::Int32, false)), + ]; + let partition_stats_1 = PartitionPruningStatistics::try_new( + partition_values_1, + partition_fields_1, + ) + .unwrap(); + let partition_values_2 = vec![ + vec![ScalarValue::from(3i32), ScalarValue::from(30i32)], + vec![ScalarValue::from(4i32), ScalarValue::from(40i32)], + vec![ScalarValue::from(5i32), ScalarValue::from(50i32)], + ]; + let partition_fields_2 = vec![ + Arc::new(Field::new("part_x", DataType::Int32, false)), + Arc::new(Field::new("part_y", DataType::Int32, false)), + ]; + let partition_stats_2 = PartitionPruningStatistics::try_new( + partition_values_2, + partition_fields_2, + ) + .unwrap(); + + CompositePruningStatistics::new(vec![ + Box::new(partition_stats_1), + Box::new(partition_stats_2), + ]); + }); + assert!(result.is_err()); + } +} diff --git a/datafusion/common/src/rounding.rs b/datafusion/common/src/rounding.rs index 413067ecd61ed..95eefd3235b5f 100644 --- a/datafusion/common/src/rounding.rs +++ b/datafusion/common/src/rounding.rs @@ -77,6 +77,7 @@ pub trait FloatBits { /// The integer value 0, used in bitwise operations. const ZERO: Self::Item; + const NEG_ZERO: Self::Item; /// Converts the floating-point value to its bitwise representation. fn to_bits(self) -> Self::Item; @@ -101,6 +102,7 @@ impl FloatBits for f32 { const CLEAR_SIGN_MASK: u32 = 0x7fff_ffff; const ONE: Self::Item = 1; const ZERO: Self::Item = 0; + const NEG_ZERO: Self::Item = 0x8000_0000; fn to_bits(self) -> Self::Item { self.to_bits() @@ -130,6 +132,7 @@ impl FloatBits for f64 { const CLEAR_SIGN_MASK: u64 = 0x7fff_ffff_ffff_ffff; const ONE: Self::Item = 1; const ZERO: Self::Item = 0; + const NEG_ZERO: Self::Item = 0x8000_0000_0000_0000; fn to_bits(self) -> Self::Item { self.to_bits() @@ -175,8 +178,10 @@ pub fn next_up(float: F) -> F { } let abs = bits & F::CLEAR_SIGN_MASK; - let next_bits = if abs == F::ZERO { + let next_bits = if bits == F::ZERO { F::TINY_BITS + } else if abs == F::ZERO { + F::ZERO } else if bits == abs { bits + F::ONE } else { @@ -206,8 +211,11 @@ pub fn next_down(float: F) -> F { if float.float_is_nan() || bits == F::neg_infinity().to_bits() { return float; } + let abs = bits & F::CLEAR_SIGN_MASK; - let next_bits = if abs == F::ZERO { + let next_bits = if bits == F::ZERO { + F::NEG_ZERO + } else if abs == F::ZERO { F::NEG_TINY_BITS } else if bits == abs { bits - F::ONE @@ -396,4 +404,32 @@ mod tests { let result = next_down(value); assert!(result.is_nan()); } + + #[test] + fn test_next_up_neg_zero_f32() { + let value: f32 = -0.0; + let result = next_up(value); + assert_eq!(result, 0.0); + } + + #[test] + fn test_next_down_zero_f32() { + let value: f32 = 0.0; + let result = next_down(value); + assert_eq!(result, -0.0); + } + + #[test] + fn test_next_up_neg_zero_f64() { + let value: f64 = -0.0; + let result = next_up(value); + assert_eq!(result, 0.0); + } + + #[test] + fn test_next_down_zero_f64() { + let value: f64 = 0.0; + let result = next_down(value); + assert_eq!(result, -0.0); + } } diff --git a/datafusion/common/src/scalar/cache.rs b/datafusion/common/src/scalar/cache.rs new file mode 100644 index 0000000000000..f1476a518774b --- /dev/null +++ b/datafusion/common/src/scalar/cache.rs @@ -0,0 +1,215 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Array caching utilities for scalar values + +use std::iter::repeat_n; +use std::sync::{Arc, LazyLock, Mutex}; + +use arrow::array::{new_null_array, Array, ArrayRef, PrimitiveArray}; +use arrow::datatypes::{ + ArrowDictionaryKeyType, DataType, Int16Type, Int32Type, Int64Type, Int8Type, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; + +/// Maximum number of rows to cache to be conservative on memory usage +const MAX_CACHE_SIZE: usize = 1024 * 1024; + +/// Cache for dictionary key arrays to avoid repeated allocations +/// when the same size is used frequently. +/// +/// Similar to PartitionColumnProjector's ZeroBufferGenerators, this cache +/// stores key arrays for different dictionary key types. The cache is +/// limited to 1 entry per type (the last size used) to prevent memory leaks +/// for extremely large array requests. +#[derive(Debug)] +struct KeyArrayCache { + cache: Option<(usize, bool, PrimitiveArray)>, // (num_rows, is_null, key_array) +} + +impl Default for KeyArrayCache { + fn default() -> Self { + Self { cache: None } + } +} + +impl KeyArrayCache { + /// Get or create a cached key array for the given number of rows and null status + fn get_or_create(&mut self, num_rows: usize, is_null: bool) -> PrimitiveArray { + // Check cache size limit to prevent memory leaks + if num_rows > MAX_CACHE_SIZE { + // For very large arrays, don't cache them - just create and return + return self.create_key_array(num_rows, is_null); + } + + match &self.cache { + Some((cached_num_rows, cached_is_null, cached_array)) + if *cached_num_rows == num_rows && *cached_is_null == is_null => + { + // Cache hit: reuse existing array if same size and null status + cached_array.clone() + } + _ => { + // Cache miss: create new array and cache it + let key_array = self.create_key_array(num_rows, is_null); + self.cache = Some((num_rows, is_null, key_array.clone())); + key_array + } + } + } + + /// Create a new key array with the specified number of rows and null status + fn create_key_array(&self, num_rows: usize, is_null: bool) -> PrimitiveArray { + let key_array: PrimitiveArray = repeat_n( + if is_null { + None + } else { + Some(K::default_value()) + }, + num_rows, + ) + .collect(); + key_array + } +} + +/// Cache for null arrays to avoid repeated allocations +/// when the same size is used frequently. +#[derive(Debug, Default)] +struct NullArrayCache { + cache: Option<(usize, ArrayRef)>, // (num_rows, null_array) +} + +impl NullArrayCache { + /// Get or create a cached null array for the given number of rows + fn get_or_create(&mut self, num_rows: usize) -> ArrayRef { + // Check cache size limit to prevent memory leaks + if num_rows > MAX_CACHE_SIZE { + // For very large arrays, don't cache them - just create and return + return new_null_array(&DataType::Null, num_rows); + } + + match &self.cache { + Some((cached_num_rows, cached_array)) if *cached_num_rows == num_rows => { + // Cache hit: reuse existing array if same size + Arc::clone(cached_array) + } + _ => { + // Cache miss: create new array and cache it + let null_array = new_null_array(&DataType::Null, num_rows); + self.cache = Some((num_rows, Arc::clone(&null_array))); + null_array + } + } + } +} + +/// Global cache for dictionary key arrays and null arrays +#[derive(Debug, Default)] +struct ArrayCaches { + cache_i8: KeyArrayCache, + cache_i16: KeyArrayCache, + cache_i32: KeyArrayCache, + cache_i64: KeyArrayCache, + cache_u8: KeyArrayCache, + cache_u16: KeyArrayCache, + cache_u32: KeyArrayCache, + cache_u64: KeyArrayCache, + null_cache: NullArrayCache, +} + +static ARRAY_CACHES: LazyLock> = + LazyLock::new(|| Mutex::new(ArrayCaches::default())); + +/// Get the global cache for arrays +fn get_array_caches() -> &'static Mutex { + &ARRAY_CACHES +} + +/// Get or create a cached null array for the given number of rows +pub(crate) fn get_or_create_cached_null_array(num_rows: usize) -> ArrayRef { + let cache = get_array_caches(); + let mut caches = cache.lock().unwrap(); + caches.null_cache.get_or_create(num_rows) +} + +/// Get or create a cached key array for a specific key type +pub(crate) fn get_or_create_cached_key_array( + num_rows: usize, + is_null: bool, +) -> PrimitiveArray { + let cache = get_array_caches(); + let mut caches = cache.lock().unwrap(); + + // Use the DATA_TYPE to dispatch to the correct cache, similar to original implementation + match K::DATA_TYPE { + DataType::Int8 => { + let array = caches.cache_i8.get_or_create(num_rows, is_null); + // Convert using ArrayData to avoid unsafe transmute + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + DataType::Int16 => { + let array = caches.cache_i16.get_or_create(num_rows, is_null); + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + DataType::Int32 => { + let array = caches.cache_i32.get_or_create(num_rows, is_null); + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + DataType::Int64 => { + let array = caches.cache_i64.get_or_create(num_rows, is_null); + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + DataType::UInt8 => { + let array = caches.cache_u8.get_or_create(num_rows, is_null); + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + DataType::UInt16 => { + let array = caches.cache_u16.get_or_create(num_rows, is_null); + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + DataType::UInt32 => { + let array = caches.cache_u32.get_or_create(num_rows, is_null); + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + DataType::UInt64 => { + let array = caches.cache_u64.get_or_create(num_rows, is_null); + let array_data = array.to_data(); + PrimitiveArray::::from(array_data) + } + _ => { + // Fallback for unsupported types - create array directly without caching + let key_array: PrimitiveArray = repeat_n( + if is_null { + None + } else { + Some(K::default_value()) + }, + num_rows, + ) + .collect(); + key_array + } + } +} diff --git a/datafusion/common/src/scalar/consts.rs b/datafusion/common/src/scalar/consts.rs index efcde651841b0..8cb446b1c9211 100644 --- a/datafusion/common/src/scalar/consts.rs +++ b/datafusion/common/src/scalar/consts.rs @@ -17,28 +17,28 @@ // Constants defined for scalar construction. -// PI ~ 3.1415927 in f32 -#[allow(clippy::approx_constant)] -pub(super) const PI_UPPER_F32: f32 = 3.141593_f32; +// Next f32 value above π (upper bound) +pub(super) const PI_UPPER_F32: f32 = std::f32::consts::PI.next_up(); -// PI ~ 3.141592653589793 in f64 -pub(super) const PI_UPPER_F64: f64 = 3.141592653589794_f64; +// Next f64 value above π (upper bound) +pub(super) const PI_UPPER_F64: f64 = std::f64::consts::PI.next_up(); -// -PI ~ -3.1415927 in f32 -#[allow(clippy::approx_constant)] -pub(super) const NEGATIVE_PI_LOWER_F32: f32 = -3.141593_f32; +// Next f32 value below -π (lower bound) +pub(super) const NEGATIVE_PI_LOWER_F32: f32 = (-std::f32::consts::PI).next_down(); -// -PI ~ -3.141592653589793 in f64 -pub(super) const NEGATIVE_PI_LOWER_F64: f64 = -3.141592653589794_f64; +// Next f64 value below -π (lower bound) +pub(super) const NEGATIVE_PI_LOWER_F64: f64 = (-std::f64::consts::PI).next_down(); -// PI / 2 ~ 1.5707964 in f32 -pub(super) const FRAC_PI_2_UPPER_F32: f32 = 1.5707965_f32; +// Next f32 value above π/2 (upper bound) +pub(super) const FRAC_PI_2_UPPER_F32: f32 = std::f32::consts::FRAC_PI_2.next_up(); -// PI / 2 ~ 1.5707963267948966 in f64 -pub(super) const FRAC_PI_2_UPPER_F64: f64 = 1.5707963267948967_f64; +// Next f64 value above π/2 (upper bound) +pub(super) const FRAC_PI_2_UPPER_F64: f64 = std::f64::consts::FRAC_PI_2.next_up(); -// -PI / 2 ~ -1.5707964 in f32 -pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F32: f32 = -1.5707965_f32; +// Next f32 value below -π/2 (lower bound) +pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F32: f32 = + (-std::f32::consts::FRAC_PI_2).next_down(); -// -PI / 2 ~ -1.5707963267948966 in f64 -pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F64: f64 = -1.5707963267948967_f64; +// Next f64 value below -π/2 (lower bound) +pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F64: f64 = + (-std::f64::consts::FRAC_PI_2).next_down(); diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 4a58530edf9ee..60ff1f4b2ed44 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -17,6 +17,7 @@ //! [`ScalarValue`]: stores single values +mod cache; mod consts; mod struct_builder; @@ -27,38 +28,66 @@ use std::convert::Infallible; use std::fmt; use std::hash::Hash; use std::hash::Hasher; -use std::iter::repeat; +use std::iter::repeat_n; use std::mem::{size_of, size_of_val}; use std::str::FromStr; use std::sync::Arc; -use crate::arrow_datafusion_err; use crate::cast::{ - as_decimal128_array, as_decimal256_array, as_dictionary_array, - as_fixed_size_binary_array, as_fixed_size_list_array, + as_binary_array, as_binary_view_array, as_boolean_array, as_date32_array, + as_date64_array, as_decimal128_array, as_decimal256_array, as_decimal32_array, + as_decimal64_array, as_dictionary_array, as_duration_microsecond_array, + as_duration_millisecond_array, as_duration_nanosecond_array, + as_duration_second_array, as_fixed_size_binary_array, as_fixed_size_list_array, + as_float16_array, as_float32_array, as_float64_array, as_int16_array, as_int32_array, + as_int64_array, as_int8_array, as_interval_dt_array, as_interval_mdn_array, + as_interval_ym_array, as_large_binary_array, as_large_list_array, + as_large_string_array, as_string_array, as_string_view_array, + as_time32_millisecond_array, as_time32_second_array, as_time64_microsecond_array, + as_time64_nanosecond_array, as_timestamp_microsecond_array, + as_timestamp_millisecond_array, as_timestamp_nanosecond_array, + as_timestamp_second_array, as_uint16_array, as_uint32_array, as_uint64_array, + as_uint8_array, as_union_array, }; use crate::error::{DataFusionError, Result, _exec_err, _internal_err, _not_impl_err}; use crate::format::DEFAULT_CAST_OPTIONS; use crate::hash_utils::create_hashes; use crate::utils::SingleRowListArrayBuilder; +use crate::{_internal_datafusion_err, arrow_datafusion_err}; use arrow::array::{ - types::{IntervalDayTime, IntervalMonthDayNano}, - *, + new_empty_array, new_null_array, Array, ArrayData, ArrayRef, ArrowNativeTypeOp, + ArrowPrimitiveType, AsArray, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, + Date64Array, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, + DictionaryArray, DurationMicrosecondArray, DurationMillisecondArray, + DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray, + FixedSizeListArray, Float16Array, Float32Array, Float64Array, GenericListArray, + Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, LargeListArray, + LargeStringArray, ListArray, MapArray, MutableArrayData, OffsetSizeTrait, + PrimitiveArray, Scalar, StringArray, StringViewArray, StructArray, + Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, UnionArray, }; use arrow::buffer::ScalarBuffer; -use arrow::compute::kernels::{ - cast::{cast_with_options, CastOptions}, - numeric::*, +use arrow::compute::kernels::cast::{cast_with_options, CastOptions}; +use arrow::compute::kernels::numeric::{ + add, add_wrapping, div, mul, mul_wrapping, rem, sub, sub_wrapping, }; use arrow::datatypes::{ - i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType, - Date32Type, Date64Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, - Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, - IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION, + i256, validate_decimal_precision_and_scale, ArrowDictionaryKeyType, ArrowNativeType, + ArrowTimestampType, DataType, Date32Type, Decimal128Type, Decimal256Type, + Decimal32Type, Decimal64Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, + Int8Type, IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano, + IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, UnionFields, + UnionMode, DECIMAL128_MAX_PRECISION, }; use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions}; +use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array}; +use chrono::{Duration, NaiveDate}; use half::f16; pub use struct_builder::ScalarStructBuilder; @@ -191,6 +220,8 @@ pub use struct_builder::ScalarStructBuilder; /// See [datatypes](https://arrow.apache.org/docs/python/api/datatypes.html) for /// details on datatypes and the [format](https://github.com/apache/arrow/blob/master/format/Schema.fbs#L354-L375) /// for the definitive reference. +/// +/// [`NullArray`]: arrow::array::NullArray #[derive(Clone)] pub enum ScalarValue { /// represents `DataType::Null` (castable to/from any other type) @@ -203,6 +234,10 @@ pub enum ScalarValue { Float32(Option), /// 64bit float Float64(Option), + /// 32bit decimal, using the i32 to represent the decimal, precision scale + Decimal32(Option, u8, i8), + /// 64bit decimal, using the i64 to represent the decimal, precision scale + Decimal64(Option, u8, i8), /// 128bit decimal, using the i128 to represent the decimal, precision scale Decimal128(Option, u8, i8), /// 256bit decimal, using the i256 to represent the decimal, precision scale @@ -312,6 +347,14 @@ impl PartialEq for ScalarValue { // any newly added enum variant will require editing this list // or else face a compile error match (self, other) { + (Decimal32(v1, p1, s1), Decimal32(v2, p2, s2)) => { + v1.eq(v2) && p1.eq(p2) && s1.eq(s2) + } + (Decimal32(_, _, _), _) => false, + (Decimal64(v1, p1, s1), Decimal64(v2, p2, s2)) => { + v1.eq(v2) && p1.eq(p2) && s1.eq(s2) + } + (Decimal64(_, _, _), _) => false, (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { v1.eq(v2) && p1.eq(p2) && s1.eq(s2) } @@ -431,6 +474,24 @@ impl PartialOrd for ScalarValue { // any newly added enum variant will require editing this list // or else face a compile error match (self, other) { + (Decimal32(v1, p1, s1), Decimal32(v2, p2, s2)) => { + if p1.eq(p2) && s1.eq(s2) { + v1.partial_cmp(v2) + } else { + // Two decimal values can be compared if they have the same precision and scale. + None + } + } + (Decimal32(_, _, _), _) => None, + (Decimal64(v1, p1, s1), Decimal64(v2, p2, s2)) => { + if p1.eq(p2) && s1.eq(s2) { + v1.partial_cmp(v2) + } else { + // Two decimal values can be compared if they have the same precision and scale. + None + } + } + (Decimal64(_, _, _), _) => None, (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { if p1.eq(p2) && s1.eq(s2) { v1.partial_cmp(v2) @@ -506,7 +567,7 @@ impl PartialOrd for ScalarValue { } (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, (Struct(struct_arr1), Struct(struct_arr2)) => { - partial_cmp_struct(struct_arr1, struct_arr2) + partial_cmp_struct(struct_arr1.as_ref(), struct_arr2.as_ref()) } (Struct(_), _) => None, (Map(map_arr1), Map(map_arr2)) => partial_cmp_map(map_arr1, map_arr2), @@ -597,10 +658,28 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { let arr1 = first_array_for_list(arr1); let arr2 = first_array_for_list(arr2); - let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; + let min_length = arr1.len().min(arr2.len()); + let arr1_trimmed = arr1.slice(0, min_length); + let arr2_trimmed = arr2.slice(0, min_length); + + let lt_res = arrow::compute::kernels::cmp::lt(&arr1_trimmed, &arr2_trimmed).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(&arr1_trimmed, &arr2_trimmed).ok()?; for j in 0..lt_res.len() { + // In Postgres, NULL values in lists are always considered to be greater than non-NULL values: + // + // $ SELECT ARRAY[NULL]::integer[] > ARRAY[1] + // true + // + // These next two if statements are introduced for replicating Postgres behavior, as + // arrow::compute does not account for this. + if arr1_trimmed.is_null(j) && !arr2_trimmed.is_null(j) { + return Some(Ordering::Greater); + } + if !arr1_trimmed.is_null(j) && arr2_trimmed.is_null(j) { + return Some(Ordering::Less); + } + if lt_res.is_valid(j) && lt_res.value(j) { return Some(Ordering::Less); } @@ -609,10 +688,23 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { } } - Some(Ordering::Equal) + Some(arr1.len().cmp(&arr2.len())) +} + +fn flatten<'a>(array: &'a StructArray, columns: &mut Vec<&'a ArrayRef>) { + for i in 0..array.num_columns() { + let column = array.column(i); + if let Some(nested_struct) = column.as_any().downcast_ref::() { + // If it's a nested struct, recursively expand + flatten(nested_struct, columns); + } else { + // If it's a primitive type, add directly + columns.push(column); + } + } } -fn partial_cmp_struct(s1: &Arc, s2: &Arc) -> Option { +pub fn partial_cmp_struct(s1: &StructArray, s2: &StructArray) -> Option { if s1.len() != s2.len() { return None; } @@ -621,9 +713,15 @@ fn partial_cmp_struct(s1: &Arc, s2: &Arc) -> Option(&self, state: &mut H) { use ScalarValue::*; match self { + Decimal32(v, p, s) => { + v.hash(state); + p.hash(state); + s.hash(state) + } + Decimal64(v, p, s) => { + v.hash(state); + p.hash(state); + s.hash(state) + } Decimal128(v, p, s) => { v.hash(state); p.hash(state); @@ -769,8 +877,9 @@ impl Hash for ScalarValue { } fn hash_nested_array(arr: ArrayRef, state: &mut H) { - let arrays = vec![arr.to_owned()]; - let hashes_buffer = &mut vec![0; arr.len()]; + let len = arr.len(); + let arrays = vec![arr]; + let hashes_buffer = &mut vec![0; len]; let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); let hashes = create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); // Hash back to std::hash::Hasher @@ -802,13 +911,9 @@ fn dict_from_scalar( let values_array = value.to_array_of_size(1)?; // Create a key array with `size` elements, each of 0 - let key_array: PrimitiveArray = repeat(if value.is_null() { - None - } else { - Some(K::default_value()) - }) - .take(size) - .collect(); + // Use cache to avoid repeated allocations for the same size + let key_array: PrimitiveArray = + get_or_create_cached_key_array::(size, value.is_null()); // create a new DictionaryArray // @@ -820,8 +925,21 @@ fn dict_from_scalar( )) } -/// Create a dictionary array representing all the values in values -fn dict_from_values( +/// Create a `DictionaryArray` from the provided values array. +/// +/// Each element gets a unique key (`0..N-1`), without deduplication. +/// Useful for wrapping arrays in dictionary form. +/// +/// # Input +/// ["alice", "bob", "alice", null, "carol"] +/// +/// # Output +/// `DictionaryArray` +/// { +/// keys: [0, 1, 2, 3, 4], +/// values: ["alice", "bob", "alice", null, "carol"] +/// } +pub fn dict_from_values( values_array: ArrayRef, ) -> Result { // Create a key array with `size` elements of 0..array_len for all @@ -830,11 +948,10 @@ fn dict_from_values( .map(|index| { if values_array.is_valid(index) { let native_index = K::Native::from_usize(index).ok_or_else(|| { - DataFusionError::Internal(format!( - "Can not create index of type {} from value {}", - K::DATA_TYPE, - index - )) + _internal_datafusion_err!( + "Can not create index of type {} from value {index}", + K::DATA_TYPE + ) })?; Ok(Some(native_index)) } else { @@ -855,17 +972,8 @@ fn dict_from_values( } macro_rules! typed_cast_tz { - ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{ - use std::any::type_name; - let array = $array - .as_any() - .downcast_ref::<$ARRAYTYPE>() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast value to {}", - type_name::<$ARRAYTYPE>() - )) - })?; + ($array:expr, $index:expr, $array_cast:ident, $SCALAR:ident, $TZ:expr) => {{ + let array = $array_cast($array)?; Ok::(ScalarValue::$SCALAR( match array.is_null($index) { true => None, @@ -877,17 +985,8 @@ macro_rules! typed_cast_tz { } macro_rules! typed_cast { - ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ - use std::any::type_name; - let array = $array - .as_any() - .downcast_ref::<$ARRAYTYPE>() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast value to {}", - type_name::<$ARRAYTYPE>() - )) - })?; + ($array:expr, $index:expr, $array_cast:ident, $SCALAR:ident) => {{ + let array = $array_cast($array)?; Ok::(ScalarValue::$SCALAR( match array.is_null($index) { true => None, @@ -924,17 +1023,8 @@ macro_rules! build_timestamp_array_from_option { } macro_rules! eq_array_primitive { - ($array:expr, $index:expr, $ARRAYTYPE:ident, $VALUE:expr) => {{ - use std::any::type_name; - let array = $array - .as_any() - .downcast_ref::<$ARRAYTYPE>() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast value to {}", - type_name::<$ARRAYTYPE>() - )) - })?; + ($array:expr, $index:expr, $array_cast:ident, $VALUE:expr) => {{ + let array = $array_cast($array)?; let is_valid = array.is_valid($index); Ok::(match $VALUE { Some(val) => is_valid && &array.value($index) == val, @@ -999,6 +1089,12 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(None), DataType::UInt32 => ScalarValue::UInt32(None), DataType::UInt64 => ScalarValue::UInt64(None), + DataType::Decimal32(precision, scale) => { + ScalarValue::Decimal32(None, *precision, *scale) + } + DataType::Decimal64(precision, scale) => { + ScalarValue::Decimal64(None, *precision, *scale) + } DataType::Decimal128(precision, scale) => { ScalarValue::Decimal128(None, *precision, *scale) } @@ -1091,7 +1187,7 @@ impl ScalarValue { DataType::Null => ScalarValue::Null, _ => { return _not_impl_err!( - "Can't create a null scalar from data_type \"{data_type:?}\"" + "Can't create a null scalar from data_type \"{data_type}\"" ); } }) @@ -1147,19 +1243,17 @@ impl ScalarValue { match datatype { DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::PI)), DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::PI)), - _ => _internal_err!("PI is not supported for data type: {:?}", datatype), + _ => _internal_err!("PI is not supported for data type: {}", datatype), } } /// Returns a [`ScalarValue`] representing PI's upper bound pub fn new_pi_upper(datatype: &DataType) -> Result { - // TODO: replace the constants with next_up/next_down when - // they are stabilized: https://doc.rust-lang.org/std/primitive.f64.html#method.next_up match datatype { DataType::Float32 => Ok(ScalarValue::from(consts::PI_UPPER_F32)), DataType::Float64 => Ok(ScalarValue::from(consts::PI_UPPER_F64)), _ => { - _internal_err!("PI_UPPER is not supported for data type: {:?}", datatype) + _internal_err!("PI_UPPER is not supported for data type: {}", datatype) } } } @@ -1170,7 +1264,7 @@ impl ScalarValue { DataType::Float32 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F32)), DataType::Float64 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F64)), _ => { - _internal_err!("-PI_LOWER is not supported for data type: {:?}", datatype) + _internal_err!("-PI_LOWER is not supported for data type: {}", datatype) } } } @@ -1181,10 +1275,7 @@ impl ScalarValue { DataType::Float32 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F32)), DataType::Float64 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F64)), _ => { - _internal_err!( - "PI_UPPER/2 is not supported for data type: {:?}", - datatype - ) + _internal_err!("PI_UPPER/2 is not supported for data type: {}", datatype) } } } @@ -1199,10 +1290,7 @@ impl ScalarValue { Ok(ScalarValue::from(consts::NEGATIVE_FRAC_PI_2_LOWER_F64)) } _ => { - _internal_err!( - "-PI/2_LOWER is not supported for data type: {:?}", - datatype - ) + _internal_err!("-PI/2_LOWER is not supported for data type: {}", datatype) } } } @@ -1212,7 +1300,7 @@ impl ScalarValue { match datatype { DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::PI)), DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::PI)), - _ => _internal_err!("-PI is not supported for data type: {:?}", datatype), + _ => _internal_err!("-PI is not supported for data type: {}", datatype), } } @@ -1221,7 +1309,7 @@ impl ScalarValue { match datatype { DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::FRAC_PI_2)), DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::FRAC_PI_2)), - _ => _internal_err!("PI/2 is not supported for data type: {:?}", datatype), + _ => _internal_err!("PI/2 is not supported for data type: {}", datatype), } } @@ -1230,7 +1318,7 @@ impl ScalarValue { match datatype { DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::FRAC_PI_2)), DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::FRAC_PI_2)), - _ => _internal_err!("-PI/2 is not supported for data type: {:?}", datatype), + _ => _internal_err!("-PI/2 is not supported for data type: {}", datatype), } } @@ -1240,7 +1328,7 @@ impl ScalarValue { DataType::Float32 => Ok(ScalarValue::from(f32::INFINITY)), DataType::Float64 => Ok(ScalarValue::from(f64::INFINITY)), _ => { - _internal_err!("Infinity is not supported for data type: {:?}", datatype) + _internal_err!("Infinity is not supported for data type: {}", datatype) } } } @@ -1252,7 +1340,7 @@ impl ScalarValue { DataType::Float64 => Ok(ScalarValue::from(f64::NEG_INFINITY)), _ => { _internal_err!( - "Negative Infinity is not supported for data type: {:?}", + "Negative Infinity is not supported for data type: {}", datatype ) } @@ -1274,6 +1362,12 @@ impl ScalarValue { DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))), DataType::Float32 => ScalarValue::Float32(Some(0.0)), DataType::Float64 => ScalarValue::Float64(Some(0.0)), + DataType::Decimal32(precision, scale) => { + ScalarValue::Decimal32(Some(0), *precision, *scale) + } + DataType::Decimal64(precision, scale) => { + ScalarValue::Decimal64(Some(0), *precision, *scale) + } DataType::Decimal128(precision, scale) => { ScalarValue::Decimal128(Some(0), *precision, *scale) } @@ -1325,12 +1419,150 @@ impl ScalarValue { DataType::Date64 => ScalarValue::Date64(Some(0)), _ => { return _not_impl_err!( - "Can't create a zero scalar from data_type \"{datatype:?}\"" + "Can't create a zero scalar from data_type \"{datatype}\"" ); } }) } + /// Returns a default value for the given `DataType`. + /// + /// This function is useful when you need to initialize a column with + /// non-null values in a DataFrame or when you need a "zero" value + /// for a specific data type. + /// + /// # Default Values + /// + /// - **Numeric types**: Returns zero (via [`new_zero`]) + /// - **String types**: Returns empty string (`""`) + /// - **Binary types**: Returns empty byte array + /// - **Temporal types**: Returns zero/epoch value + /// - **List types**: Returns empty list + /// - **Struct types**: Returns struct with all fields set to their defaults + /// - **Dictionary types**: Returns dictionary with default value + /// - **Map types**: Returns empty map + /// - **Union types**: Returns first variant with default value + /// + /// # Errors + /// + /// Returns an error for data types that don't have a clear default value + /// or are not yet supported (e.g., `RunEndEncoded`). + /// + /// [`new_zero`]: Self::new_zero + pub fn new_default(datatype: &DataType) -> Result { + match datatype { + // Null type + DataType::Null => Ok(ScalarValue::Null), + + // Numeric types + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Timestamp(_, _) + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Interval(_) + | DataType::Duration(_) + | DataType::Date32 + | DataType::Date64 => ScalarValue::new_zero(datatype), + + // String types + DataType::Utf8 => Ok(ScalarValue::Utf8(Some("".to_string()))), + DataType::LargeUtf8 => Ok(ScalarValue::LargeUtf8(Some("".to_string()))), + DataType::Utf8View => Ok(ScalarValue::Utf8View(Some("".to_string()))), + + // Binary types + DataType::Binary => Ok(ScalarValue::Binary(Some(vec![]))), + DataType::LargeBinary => Ok(ScalarValue::LargeBinary(Some(vec![]))), + DataType::BinaryView => Ok(ScalarValue::BinaryView(Some(vec![]))), + + // Fixed-size binary + DataType::FixedSizeBinary(size) => Ok(ScalarValue::FixedSizeBinary( + *size, + Some(vec![0; *size as usize]), + )), + + // List types + DataType::List(field) => { + let list = + ScalarValue::new_list(&[], field.data_type(), field.is_nullable()); + Ok(ScalarValue::List(list)) + } + DataType::FixedSizeList(field, _size) => { + let empty_arr = new_empty_array(field.data_type()); + let values = Arc::new( + SingleRowListArrayBuilder::new(empty_arr) + .with_nullable(field.is_nullable()) + .build_fixed_size_list_array(0), + ); + Ok(ScalarValue::FixedSizeList(values)) + } + DataType::LargeList(field) => { + let list = ScalarValue::new_large_list(&[], field.data_type()); + Ok(ScalarValue::LargeList(list)) + } + + // Struct types + DataType::Struct(fields) => { + let values = fields + .iter() + .map(|f| ScalarValue::new_default(f.data_type())) + .collect::>>()?; + Ok(ScalarValue::Struct(Arc::new(StructArray::new( + fields.clone(), + values + .into_iter() + .map(|v| v.to_array()) + .collect::>()?, + None, + )))) + } + + // Dictionary types + DataType::Dictionary(key_type, value_type) => Ok(ScalarValue::Dictionary( + key_type.clone(), + Box::new(ScalarValue::new_default(value_type)?), + )), + + // Map types + DataType::Map(field, _) => Ok(ScalarValue::Map(Arc::new(MapArray::from( + ArrayData::new_empty(field.data_type()), + )))), + + // Union types - return first variant with default value + DataType::Union(fields, mode) => { + if let Some((type_id, field)) = fields.iter().next() { + let default_value = ScalarValue::new_default(field.data_type())?; + Ok(ScalarValue::Union( + Some((type_id, Box::new(default_value))), + fields.clone(), + *mode, + )) + } else { + _internal_err!("Union type must have at least one field") + } + } + + // Unsupported types for now + _ => { + _not_impl_err!( + "Default value for data_type \"{datatype}\" is not implemented yet" + ) + } + } + } + /// Create an one value in the given type. pub fn new_one(datatype: &DataType) -> Result { Ok(match datatype { @@ -1345,9 +1577,65 @@ impl ScalarValue { DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))), DataType::Float32 => ScalarValue::Float32(Some(1.0)), DataType::Float64 => ScalarValue::Float64(Some(1.0)), + DataType::Decimal32(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match 10_i32.checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal32(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal64(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match i64::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal64(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal128(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match i128::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal128(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal256(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match i256::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal256(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } _ => { return _not_impl_err!( - "Can't create an one scalar from data_type \"{datatype:?}\"" + "Can't create an one scalar from data_type \"{datatype}\"" ); } }) @@ -1363,9 +1651,65 @@ impl ScalarValue { DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))), DataType::Float32 => ScalarValue::Float32(Some(-1.0)), DataType::Float64 => ScalarValue::Float64(Some(-1.0)), + DataType::Decimal32(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match 10_i32.checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal32(Some(-value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal64(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match i64::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal64(Some(-value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal128(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match i128::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal128(Some(-value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal256(precision, scale) => { + validate_decimal_precision_and_scale::( + *precision, *scale, + )?; + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match i256::from(10).checked_pow(*scale as u32) { + Some(value) => { + ScalarValue::Decimal256(Some(-value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } _ => { return _not_impl_err!( - "Can't create a negative one scalar from data_type \"{datatype:?}\"" + "Can't create a negative one scalar from data_type \"{datatype}\"" ); } }) @@ -1384,9 +1728,73 @@ impl ScalarValue { DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(10.0))), DataType::Float32 => ScalarValue::Float32(Some(10.0)), DataType::Float64 => ScalarValue::Float64(Some(10.0)), + DataType::Decimal32(precision, scale) => { + if let Err(err) = validate_decimal_precision_and_scale::( + *precision, *scale, + ) { + return _internal_err!("Invalid precision and scale {err}"); + } + if *scale <= 0 { + return _internal_err!("Negative scale is not supported"); + } + match 10_i32.checked_pow((*scale + 1) as u32) { + Some(value) => { + ScalarValue::Decimal32(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal64(precision, scale) => { + if let Err(err) = validate_decimal_precision_and_scale::( + *precision, *scale, + ) { + return _internal_err!("Invalid precision and scale {err}"); + } + if *scale <= 0 { + return _internal_err!("Negative scale is not supported"); + } + match i64::from(10).checked_pow((*scale + 1) as u32) { + Some(value) => { + ScalarValue::Decimal64(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal128(precision, scale) => { + if let Err(err) = validate_decimal_precision_and_scale::( + *precision, *scale, + ) { + return _internal_err!("Invalid precision and scale {err}"); + } + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match i128::from(10).checked_pow((*scale + 1) as u32) { + Some(value) => { + ScalarValue::Decimal128(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } + DataType::Decimal256(precision, scale) => { + if let Err(err) = validate_decimal_precision_and_scale::( + *precision, *scale, + ) { + return _internal_err!("Invalid precision and scale {err}"); + } + if *scale < 0 { + return _internal_err!("Negative scale is not supported"); + } + match i256::from(10).checked_pow((*scale + 1) as u32) { + Some(value) => { + ScalarValue::Decimal256(Some(value), *precision, *scale) + } + None => return _internal_err!("Unsupported scale {scale}"), + } + } _ => { return _not_impl_err!( - "Can't create a ten scalar from data_type \"{datatype:?}\"" + "Can't create a ten scalar from data_type \"{datatype}\"" ); } }) @@ -1404,6 +1812,12 @@ impl ScalarValue { ScalarValue::Int16(_) => DataType::Int16, ScalarValue::Int32(_) => DataType::Int32, ScalarValue::Int64(_) => DataType::Int64, + ScalarValue::Decimal32(_, precision, scale) => { + DataType::Decimal32(*precision, *scale) + } + ScalarValue::Decimal64(_, precision, scale) => { + DataType::Decimal64(*precision, *scale) + } ScalarValue::Decimal128(_, precision, scale) => { DataType::Decimal128(*precision, *scale) } @@ -1526,6 +1940,24 @@ impl ScalarValue { ); Ok(ScalarValue::IntervalMonthDayNano(Some(val))) } + ScalarValue::Decimal32(Some(v), precision, scale) => { + Ok(ScalarValue::Decimal32( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of Decimal32({v}, {precision}, {scale})") + })?), + *precision, + *scale, + )) + } + ScalarValue::Decimal64(Some(v), precision, scale) => { + Ok(ScalarValue::Decimal64( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of Decimal64({v}, {precision}, {scale})") + })?), + *precision, + *scale, + )) + } ScalarValue::Decimal128(Some(v), precision, scale) => { Ok(ScalarValue::Decimal128( Some(neg_checked_with_ctx(*v, || { @@ -1677,6 +2109,8 @@ impl ScalarValue { ScalarValue::Float16(v) => v.is_none(), ScalarValue::Float32(v) => v.is_none(), ScalarValue::Float64(v) => v.is_none(), + ScalarValue::Decimal32(v, _, _) => v.is_none(), + ScalarValue::Decimal64(v, _, _) => v.is_none(), ScalarValue::Decimal128(v, _, _) => v.is_none(), ScalarValue::Decimal256(v, _, _) => v.is_none(), ScalarValue::Int8(v) => v.is_none(), @@ -1753,6 +2187,26 @@ impl ScalarValue { (Self::Float64(Some(l)), Self::Float64(Some(r))) => { Some((l - r).abs().round() as _) } + ( + Self::Decimal128(Some(l), lprecision, lscale), + Self::Decimal128(Some(r), rprecision, rscale), + ) => { + if lprecision == rprecision && lscale == rscale { + l.checked_sub(*r)?.checked_abs()?.to_usize() + } else { + None + } + } + ( + Self::Decimal256(Some(l), lprecision, lscale), + Self::Decimal256(Some(r), rprecision, rscale), + ) => { + if lprecision == rprecision && lscale == rscale { + l.checked_sub(*r)?.checked_abs()?.to_usize() + } else { + None + } + } _ => None, } } @@ -1809,10 +2263,6 @@ impl ScalarValue { /// Returns an error if the iterator is empty or if the /// [`ScalarValue`]s are not all the same type /// - /// # Panics - /// - /// Panics if `self` is a dictionary with invalid key type - /// /// # Example /// ``` /// use datafusion_common::ScalarValue; @@ -1916,9 +2366,19 @@ impl ScalarValue { } let array: ArrayRef = match &data_type { + DataType::Decimal32(precision, scale) => { + let decimal_array = + ScalarValue::iter_to_decimal32_array(scalars, *precision, *scale)?; + Arc::new(decimal_array) + } + DataType::Decimal64(precision, scale) => { + let decimal_array = + ScalarValue::iter_to_decimal64_array(scalars, *precision, *scale)?; + Arc::new(decimal_array) + } DataType::Decimal128(precision, scale) => { let decimal_array = - ScalarValue::iter_to_decimal_array(scalars, *precision, *scale)?; + ScalarValue::iter_to_decimal128_array(scalars, *precision, *scale)?; Arc::new(decimal_array) } DataType::Decimal256(precision, scale) => { @@ -2068,7 +2528,7 @@ impl ScalarValue { DataType::UInt16 => dict_from_values::(values)?, DataType::UInt32 => dict_from_values::(values)?, DataType::UInt64 => dict_from_values::(values)?, - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + _ => unreachable!("Invalid dictionary keys type: {}", key_type), } } DataType::FixedSizeBinary(size) => { @@ -2079,7 +2539,7 @@ impl ScalarValue { } else { _exec_err!( "Inconsistent types in ScalarValue::iter_to_array. \ - Expected {data_type:?}, got {sv:?}" + Expected {data_type}, got {sv:?}" ) } }) @@ -2091,7 +2551,7 @@ impl ScalarValue { Arc::new(array) } // explicitly enumerate unsupported types so newly added - // types must be aknowledged, Time32 and Time64 types are + // types must be acknowledged, Time32 and Time64 types are // not supported if the TimeUnit is not valid (Time32 can // only be used with Second and Millisecond, Time64 only // with Microsecond and Nanosecond) @@ -2127,7 +2587,43 @@ impl ScalarValue { Ok(new_null_array(&DataType::Null, length)) } - fn iter_to_decimal_array( + fn iter_to_decimal32_array( + scalars: impl IntoIterator, + precision: u8, + scale: i8, + ) -> Result { + let array = scalars + .into_iter() + .map(|element: ScalarValue| match element { + ScalarValue::Decimal32(v1, _, _) => Ok(v1), + s => { + _internal_err!("Expected ScalarValue::Null element. Received {s:?}") + } + }) + .collect::>()? + .with_precision_and_scale(precision, scale)?; + Ok(array) + } + + fn iter_to_decimal64_array( + scalars: impl IntoIterator, + precision: u8, + scale: i8, + ) -> Result { + let array = scalars + .into_iter() + .map(|element: ScalarValue| match element { + ScalarValue::Decimal64(v1, _, _) => Ok(v1), + s => { + _internal_err!("Expected ScalarValue::Null element. Received {s:?}") + } + }) + .collect::>()? + .with_precision_and_scale(precision, scale)?; + Ok(array) + } + + fn iter_to_decimal128_array( scalars: impl IntoIterator, precision: u8, scale: i8, @@ -2165,17 +2661,17 @@ impl ScalarValue { Ok(array) } - fn build_decimal_array( - value: Option, + fn build_decimal32_array( + value: Option, precision: u8, scale: i8, size: usize, - ) -> Result { + ) -> Result { Ok(match value { - Some(val) => Decimal128Array::from(vec![val; size]) + Some(val) => Decimal32Array::from(vec![val; size]) .with_precision_and_scale(precision, scale)?, None => { - let mut builder = Decimal128Array::builder(size) + let mut builder = Decimal32Array::builder(size) .with_precision_and_scale(precision, scale)?; builder.append_nulls(size); builder.finish() @@ -2183,26 +2679,61 @@ impl ScalarValue { }) } - fn build_decimal256_array( - value: Option, + fn build_decimal64_array( + value: Option, precision: u8, scale: i8, size: usize, - ) -> Result { - Ok(repeat(value) - .take(size) - .collect::() - .with_precision_and_scale(precision, scale)?) + ) -> Result { + Ok(match value { + Some(val) => Decimal64Array::from(vec![val; size]) + .with_precision_and_scale(precision, scale)?, + None => { + let mut builder = Decimal64Array::builder(size) + .with_precision_and_scale(precision, scale)?; + builder.append_nulls(size); + builder.finish() + } + }) } - /// Converts `Vec` where each element has type corresponding to - /// `data_type`, to a single element [`ListArray`]. - /// - /// Example - /// ``` - /// use datafusion_common::ScalarValue; - /// use arrow::array::{ListArray, Int32Array}; - /// use arrow::datatypes::{DataType, Int32Type}; + fn build_decimal128_array( + value: Option, + precision: u8, + scale: i8, + size: usize, + ) -> Result { + Ok(match value { + Some(val) => Decimal128Array::from(vec![val; size]) + .with_precision_and_scale(precision, scale)?, + None => { + let mut builder = Decimal128Array::builder(size) + .with_precision_and_scale(precision, scale)?; + builder.append_nulls(size); + builder.finish() + } + }) + } + + fn build_decimal256_array( + value: Option, + precision: u8, + scale: i8, + size: usize, + ) -> Result { + Ok(repeat_n(value, size) + .collect::() + .with_precision_and_scale(precision, scale)?) + } + + /// Converts `Vec` where each element has type corresponding to + /// `data_type`, to a single element [`ListArray`]. + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::{ListArray, Int32Array}; + /// use arrow::datatypes::{DataType, Int32Type}; /// use datafusion_common::cast::as_list_array; /// /// let scalars = vec![ @@ -2345,8 +2876,14 @@ impl ScalarValue { /// - a `Dictionary` that fails be converted to a dictionary array of size pub fn to_array_of_size(&self, size: usize) -> Result { Ok(match self { + ScalarValue::Decimal32(e, precision, scale) => Arc::new( + ScalarValue::build_decimal32_array(*e, *precision, *scale, size)?, + ), + ScalarValue::Decimal64(e, precision, scale) => Arc::new( + ScalarValue::build_decimal64_array(*e, *precision, *scale, size)?, + ), ScalarValue::Decimal128(e, precision, scale) => Arc::new( - ScalarValue::build_decimal_array(*e, *precision, *scale, size)?, + ScalarValue::build_decimal128_array(*e, *precision, *scale, size)?, ), ScalarValue::Decimal256(e, precision, scale) => Arc::new( ScalarValue::build_decimal256_array(*e, *precision, *scale, size)?, @@ -2416,53 +2953,47 @@ impl ScalarValue { } ScalarValue::Utf8(e) => match e { Some(value) => { - Arc::new(StringArray::from_iter_values(repeat(value).take(size))) + Arc::new(StringArray::from_iter_values(repeat_n(value, size))) } None => new_null_array(&DataType::Utf8, size), }, ScalarValue::Utf8View(e) => match e { Some(value) => { - Arc::new(StringViewArray::from_iter_values(repeat(value).take(size))) + Arc::new(StringViewArray::from_iter_values(repeat_n(value, size))) } None => new_null_array(&DataType::Utf8View, size), }, ScalarValue::LargeUtf8(e) => match e { Some(value) => { - Arc::new(LargeStringArray::from_iter_values(repeat(value).take(size))) + Arc::new(LargeStringArray::from_iter_values(repeat_n(value, size))) } None => new_null_array(&DataType::LargeUtf8, size), }, ScalarValue::Binary(e) => match e { Some(value) => Arc::new( - repeat(Some(value.as_slice())) - .take(size) - .collect::(), + repeat_n(Some(value.as_slice()), size).collect::(), ), - None => { - Arc::new(repeat(None::<&str>).take(size).collect::()) - } + None => Arc::new(repeat_n(None::<&str>, size).collect::()), }, ScalarValue::BinaryView(e) => match e { Some(value) => Arc::new( - repeat(Some(value.as_slice())) - .take(size) - .collect::(), + repeat_n(Some(value.as_slice()), size).collect::(), ), None => { - Arc::new(repeat(None::<&str>).take(size).collect::()) + Arc::new(repeat_n(None::<&str>, size).collect::()) } }, ScalarValue::FixedSizeBinary(s, e) => match e { Some(value) => Arc::new( FixedSizeBinaryArray::try_from_sparse_iter_with_size( - repeat(Some(value.as_slice())).take(size), + repeat_n(Some(value.as_slice()), size), *s, ) .unwrap(), ), None => Arc::new( FixedSizeBinaryArray::try_from_sparse_iter_with_size( - repeat(None::<&[u8]>).take(size), + repeat_n(None::<&[u8]>, size), *s, ) .unwrap(), @@ -2470,29 +3001,40 @@ impl ScalarValue { }, ScalarValue::LargeBinary(e) => match e { Some(value) => Arc::new( - repeat(Some(value.as_slice())) - .take(size) - .collect::(), - ), - None => Arc::new( - repeat(None::<&str>) - .take(size) - .collect::(), + repeat_n(Some(value.as_slice()), size).collect::(), ), + None => { + Arc::new(repeat_n(None::<&str>, size).collect::()) + } }, ScalarValue::List(arr) => { + if size == 1 { + return Ok(Arc::clone(arr) as Arc); + } Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } ScalarValue::LargeList(arr) => { + if size == 1 { + return Ok(Arc::clone(arr) as Arc); + } Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } ScalarValue::FixedSizeList(arr) => { + if size == 1 { + return Ok(Arc::clone(arr) as Arc); + } Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } ScalarValue::Struct(arr) => { + if size == 1 { + return Ok(Arc::clone(arr) as Arc); + } Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } ScalarValue::Map(arr) => { + if size == 1 { + return Ok(Arc::clone(arr) as Arc); + } Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } ScalarValue::Date32(e) => { @@ -2606,7 +3148,7 @@ impl ScalarValue { child_arrays.push(ar); new_fields.push(field.clone()); } - let type_ids = repeat(*v_id).take(size); + let type_ids = repeat_n(*v_id, size); let type_ids = ScalarBuffer::::from_iter(type_ids); let value_offsets = match mode { UnionMode::Sparse => None, @@ -2618,7 +3160,7 @@ impl ScalarValue { value_offsets, child_arrays, ) - .map_err(|e| DataFusionError::ArrowError(e, None))?; + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; Arc::new(ar) } None => { @@ -2637,10 +3179,10 @@ impl ScalarValue { DataType::UInt16 => dict_from_scalar::(v, size)?, DataType::UInt32 => dict_from_scalar::(v, size)?, DataType::UInt64 => dict_from_scalar::(v, size)?, - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + _ => unreachable!("Invalid dictionary keys type: {}", key_type), } } - ScalarValue::Null => new_null_array(&DataType::Null, size), + ScalarValue::Null => get_or_create_cached_null_array(size), }) } @@ -2651,6 +3193,24 @@ impl ScalarValue { scale: i8, ) -> Result { match array.data_type() { + DataType::Decimal32(_, _) => { + let array = as_decimal32_array(array)?; + if array.is_null(index) { + Ok(ScalarValue::Decimal32(None, precision, scale)) + } else { + let value = array.value(index); + Ok(ScalarValue::Decimal32(Some(value), precision, scale)) + } + } + DataType::Decimal64(_, _) => { + let array = as_decimal64_array(array)?; + if array.is_null(index) { + Ok(ScalarValue::Decimal64(None, precision, scale)) + } else { + let value = array.value(index); + Ok(ScalarValue::Decimal64(Some(value), precision, scale)) + } + } DataType::Decimal128(_, _) => { let array = as_decimal128_array(array)?; if array.is_null(index) { @@ -2669,12 +3229,14 @@ impl ScalarValue { Ok(ScalarValue::Decimal256(Some(value), precision, scale)) } } - _ => _internal_err!("Unsupported decimal type"), + other => { + unreachable!("Invalid type isn't decimal: {other:?}") + } } } fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result { - let arrays = repeat(arr).take(size).collect::>(); + let arrays = repeat_n(arr, size).collect::>(); let ret = match !arrays.is_empty() { true => arrow::compute::concat(arrays.as_slice())?, false => arr.slice(0, 0), @@ -2684,6 +3246,8 @@ impl ScalarValue { /// Retrieve ScalarValue for each row in `array` /// + /// Elements in `array` may be NULL, in which case the corresponding element in the returned vector is None. + /// /// Example 1: Array (ScalarValue::Int32) /// ``` /// use datafusion_common::ScalarValue; @@ -2700,15 +3264,15 @@ impl ScalarValue { /// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); /// /// let expected = vec![ - /// vec![ - /// ScalarValue::Int32(Some(1)), - /// ScalarValue::Int32(Some(2)), - /// ScalarValue::Int32(Some(3)), - /// ], - /// vec![ - /// ScalarValue::Int32(Some(4)), - /// ScalarValue::Int32(Some(5)), - /// ], + /// Some(vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(2)), + /// ScalarValue::Int32(Some(3)), + /// ]), + /// Some(vec![ + /// ScalarValue::Int32(Some(4)), + /// ScalarValue::Int32(Some(5)), + /// ]), /// ]; /// /// assert_eq!(scalar_vec, expected); @@ -2741,26 +3305,73 @@ impl ScalarValue { /// ]); /// /// let expected = vec![ - /// vec![ + /// Some(vec![ /// ScalarValue::List(Arc::new(l1)), /// ScalarValue::List(Arc::new(l2)), - /// ], + /// ]), + /// ]; + /// + /// assert_eq!(scalar_vec, expected); + /// ``` + /// + /// Example 3: Nullable array + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::ListArray; + /// use arrow::datatypes::{DataType, Int32Type}; + /// + /// let list_arr = ListArray::from_iter_primitive::(vec![ + /// Some(vec![Some(1), Some(2), Some(3)]), + /// None, + /// Some(vec![Some(4), Some(5)]) + /// ]); + /// + /// // Convert the array into Scalar Values for each row + /// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); + /// + /// let expected = vec![ + /// Some(vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(2)), + /// ScalarValue::Int32(Some(3)), + /// ]), + /// None, + /// Some(vec![ + /// ScalarValue::Int32(Some(4)), + /// ScalarValue::Int32(Some(5)), + /// ]), /// ]; /// /// assert_eq!(scalar_vec, expected); /// ``` - pub fn convert_array_to_scalar_vec(array: &dyn Array) -> Result>> { - let mut scalars = Vec::with_capacity(array.len()); - - for index in 0..array.len() { - let nested_array = array.as_list::().value(index); - let scalar_values = (0..nested_array.len()) - .map(|i| ScalarValue::try_from_array(&nested_array, i)) - .collect::>>()?; - scalars.push(scalar_values); + pub fn convert_array_to_scalar_vec( + array: &dyn Array, + ) -> Result>>> { + fn generic_collect( + array: &dyn Array, + ) -> Result>>> { + array + .as_list::() + .iter() + .map(|nested_array| { + nested_array + .map(|array| { + (0..array.len()) + .map(|i| ScalarValue::try_from_array(&array, i)) + .collect::>>() + }) + .transpose() + }) + .collect() } - Ok(scalars) + match array.data_type() { + DataType::List(_) => generic_collect::(array), + DataType::LargeList(_) => generic_collect::(array), + _ => _internal_err!( + "ScalarValue::convert_array_to_scalar_vec input must be a List/LargeList type" + ), + } } #[deprecated( @@ -2783,6 +3394,16 @@ impl ScalarValue { Ok(match array.data_type() { DataType::Null => ScalarValue::Null, + DataType::Decimal32(precision, scale) => { + ScalarValue::get_decimal_value_from_array( + array, index, *precision, *scale, + )? + } + DataType::Decimal64(precision, scale) => { + ScalarValue::get_decimal_value_from_array( + array, index, *precision, *scale, + )? + } DataType::Decimal128(precision, scale) => { ScalarValue::get_decimal_value_from_array( array, index, *precision, *scale, @@ -2793,30 +3414,32 @@ impl ScalarValue { array, index, *precision, *scale, )? } - DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean)?, - DataType::Float64 => typed_cast!(array, index, Float64Array, Float64)?, - DataType::Float32 => typed_cast!(array, index, Float32Array, Float32)?, - DataType::Float16 => typed_cast!(array, index, Float16Array, Float16)?, - DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64)?, - DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32)?, - DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16)?, - DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8)?, - DataType::Int64 => typed_cast!(array, index, Int64Array, Int64)?, - DataType::Int32 => typed_cast!(array, index, Int32Array, Int32)?, - DataType::Int16 => typed_cast!(array, index, Int16Array, Int16)?, - DataType::Int8 => typed_cast!(array, index, Int8Array, Int8)?, - DataType::Binary => typed_cast!(array, index, BinaryArray, Binary)?, + DataType::Boolean => typed_cast!(array, index, as_boolean_array, Boolean)?, + DataType::Float64 => typed_cast!(array, index, as_float64_array, Float64)?, + DataType::Float32 => typed_cast!(array, index, as_float32_array, Float32)?, + DataType::Float16 => typed_cast!(array, index, as_float16_array, Float16)?, + DataType::UInt64 => typed_cast!(array, index, as_uint64_array, UInt64)?, + DataType::UInt32 => typed_cast!(array, index, as_uint32_array, UInt32)?, + DataType::UInt16 => typed_cast!(array, index, as_uint16_array, UInt16)?, + DataType::UInt8 => typed_cast!(array, index, as_uint8_array, UInt8)?, + DataType::Int64 => typed_cast!(array, index, as_int64_array, Int64)?, + DataType::Int32 => typed_cast!(array, index, as_int32_array, Int32)?, + DataType::Int16 => typed_cast!(array, index, as_int16_array, Int16)?, + DataType::Int8 => typed_cast!(array, index, as_int8_array, Int8)?, + DataType::Binary => typed_cast!(array, index, as_binary_array, Binary)?, DataType::LargeBinary => { - typed_cast!(array, index, LargeBinaryArray, LargeBinary)? + typed_cast!(array, index, as_large_binary_array, LargeBinary)? } DataType::BinaryView => { - typed_cast!(array, index, BinaryViewArray, BinaryView)? + typed_cast!(array, index, as_binary_view_array, BinaryView)? } - DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8)?, + DataType::Utf8 => typed_cast!(array, index, as_string_array, Utf8)?, DataType::LargeUtf8 => { - typed_cast!(array, index, LargeStringArray, LargeUtf8)? + typed_cast!(array, index, as_large_string_array, LargeUtf8)? + } + DataType::Utf8View => { + typed_cast!(array, index, as_string_view_array, Utf8View)? } - DataType::Utf8View => typed_cast!(array, index, StringViewArray, Utf8View)?, DataType::List(field) => { let list_array = array.as_list::(); let nested_array = list_array.value(index); @@ -2826,7 +3449,7 @@ impl ScalarValue { .build_list_scalar() } DataType::LargeList(field) => { - let list_array = as_large_list_array(array); + let list_array = as_large_list_array(array)?; let nested_array = list_array.value(index); // Produces a single element `LargeListArray` with the value at `index`. SingleRowListArrayBuilder::new(nested_array) @@ -2843,45 +3466,45 @@ impl ScalarValue { .with_field(field) .build_fixed_size_list_scalar(list_size) } - DataType::Date32 => typed_cast!(array, index, Date32Array, Date32)?, - DataType::Date64 => typed_cast!(array, index, Date64Array, Date64)?, + DataType::Date32 => typed_cast!(array, index, as_date32_array, Date32)?, + DataType::Date64 => typed_cast!(array, index, as_date64_array, Date64)?, DataType::Time32(TimeUnit::Second) => { - typed_cast!(array, index, Time32SecondArray, Time32Second)? + typed_cast!(array, index, as_time32_second_array, Time32Second)? } DataType::Time32(TimeUnit::Millisecond) => { - typed_cast!(array, index, Time32MillisecondArray, Time32Millisecond)? + typed_cast!(array, index, as_time32_millisecond_array, Time32Millisecond)? } DataType::Time64(TimeUnit::Microsecond) => { - typed_cast!(array, index, Time64MicrosecondArray, Time64Microsecond)? + typed_cast!(array, index, as_time64_microsecond_array, Time64Microsecond)? } DataType::Time64(TimeUnit::Nanosecond) => { - typed_cast!(array, index, Time64NanosecondArray, Time64Nanosecond)? + typed_cast!(array, index, as_time64_nanosecond_array, Time64Nanosecond)? } DataType::Timestamp(TimeUnit::Second, tz_opt) => typed_cast_tz!( array, index, - TimestampSecondArray, + as_timestamp_second_array, TimestampSecond, tz_opt )?, DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_cast_tz!( array, index, - TimestampMillisecondArray, + as_timestamp_millisecond_array, TimestampMillisecond, tz_opt )?, DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_cast_tz!( array, index, - TimestampMicrosecondArray, + as_timestamp_microsecond_array, TimestampMicrosecond, tz_opt )?, DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_cast_tz!( array, index, - TimestampNanosecondArray, + as_timestamp_nanosecond_array, TimestampNanosecond, tz_opt )?, @@ -2895,7 +3518,7 @@ impl ScalarValue { DataType::UInt16 => get_dict_value::(array, index)?, DataType::UInt32 => get_dict_value::(array, index)?, DataType::UInt64 => get_dict_value::(array, index)?, - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + _ => unreachable!("Invalid dictionary keys type: {}", key_type), }; // look up the index in the values dictionary let value = match values_index { @@ -2927,36 +3550,42 @@ impl ScalarValue { ) } DataType::Interval(IntervalUnit::DayTime) => { - typed_cast!(array, index, IntervalDayTimeArray, IntervalDayTime)? + typed_cast!(array, index, as_interval_dt_array, IntervalDayTime)? } DataType::Interval(IntervalUnit::YearMonth) => { - typed_cast!(array, index, IntervalYearMonthArray, IntervalYearMonth)? + typed_cast!(array, index, as_interval_ym_array, IntervalYearMonth)? + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + typed_cast!(array, index, as_interval_mdn_array, IntervalMonthDayNano)? } - DataType::Interval(IntervalUnit::MonthDayNano) => typed_cast!( - array, - index, - IntervalMonthDayNanoArray, - IntervalMonthDayNano - )?, DataType::Duration(TimeUnit::Second) => { - typed_cast!(array, index, DurationSecondArray, DurationSecond)? - } - DataType::Duration(TimeUnit::Millisecond) => { - typed_cast!(array, index, DurationMillisecondArray, DurationMillisecond)? - } - DataType::Duration(TimeUnit::Microsecond) => { - typed_cast!(array, index, DurationMicrosecondArray, DurationMicrosecond)? - } - DataType::Duration(TimeUnit::Nanosecond) => { - typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond)? + typed_cast!(array, index, as_duration_second_array, DurationSecond)? } + DataType::Duration(TimeUnit::Millisecond) => typed_cast!( + array, + index, + as_duration_millisecond_array, + DurationMillisecond + )?, + DataType::Duration(TimeUnit::Microsecond) => typed_cast!( + array, + index, + as_duration_microsecond_array, + DurationMicrosecond + )?, + DataType::Duration(TimeUnit::Nanosecond) => typed_cast!( + array, + index, + as_duration_nanosecond_array, + DurationNanosecond + )?, DataType::Map(_, _) => { let a = array.slice(index, 1); Self::Map(Arc::new(a.as_map().to_owned())) } DataType::Union(fields, mode) => { - let array = as_union_array(array); + let array = as_union_array(array)?; let ti = array.type_id(index); let index = array.value_offset(index); let value = ScalarValue::try_from_array(array.child(ti), index)?; @@ -3030,47 +3659,49 @@ impl ScalarValue { target_type: &DataType, cast_options: &CastOptions<'static>, ) -> Result { - let scalar_array = match (self, target_type) { - ( - ScalarValue::Float64(Some(float_ts)), - DataType::Timestamp(TimeUnit::Nanosecond, None), - ) => ScalarValue::Int64(Some((float_ts * 1_000_000_000_f64).trunc() as i64)) - .to_array()?, - ( - ScalarValue::Decimal128(Some(decimal_value), _, scale), - DataType::Timestamp(time_unit, None), - ) => { - let scale_factor = 10_i128.pow(*scale as u32); - let seconds = decimal_value / scale_factor; - let fraction = decimal_value % scale_factor; - - let timestamp_value = match time_unit { - TimeUnit::Second => ScalarValue::Int64(Some(seconds as i64)), - TimeUnit::Millisecond => { - let millis = seconds * 1_000 + (fraction * 1_000) / scale_factor; - ScalarValue::Int64(Some(millis as i64)) - } - TimeUnit::Microsecond => { - let micros = - seconds * 1_000_000 + (fraction * 1_000_000) / scale_factor; - ScalarValue::Int64(Some(micros as i64)) - } - TimeUnit::Nanosecond => { - let nanos = seconds * 1_000_000_000 - + (fraction * 1_000_000_000) / scale_factor; - ScalarValue::Int64(Some(nanos as i64)) - } - }; - - timestamp_value.to_array()? - } - _ => self.to_array()?, - }; - + let scalar_array = self.to_array()?; let cast_arr = cast_with_options(&scalar_array, target_type, cast_options)?; ScalarValue::try_from_array(&cast_arr, 0) } + fn eq_array_decimal32( + array: &ArrayRef, + index: usize, + value: Option<&i32>, + precision: u8, + scale: i8, + ) -> Result { + let array = as_decimal32_array(array)?; + if array.precision() != precision || array.scale() != scale { + return Ok(false); + } + let is_null = array.is_null(index); + if let Some(v) = value { + Ok(!array.is_null(index) && array.value(index) == *v) + } else { + Ok(is_null) + } + } + + fn eq_array_decimal64( + array: &ArrayRef, + index: usize, + value: Option<&i64>, + precision: u8, + scale: i8, + ) -> Result { + let array = as_decimal64_array(array)?; + if array.precision() != precision || array.scale() != scale { + return Ok(false); + } + let is_null = array.is_null(index); + if let Some(v) = value { + Ok(!array.is_null(index) && array.value(index) == *v) + } else { + Ok(is_null) + } + } + fn eq_array_decimal( array: &ArrayRef, index: usize, @@ -3138,6 +3769,24 @@ impl ScalarValue { #[inline] pub fn eq_array(&self, array: &ArrayRef, index: usize) -> Result { Ok(match self { + ScalarValue::Decimal32(v, precision, scale) => { + ScalarValue::eq_array_decimal32( + array, + index, + v.as_ref(), + *precision, + *scale, + )? + } + ScalarValue::Decimal64(v, precision, scale) => { + ScalarValue::eq_array_decimal64( + array, + index, + v.as_ref(), + *precision, + *scale, + )? + } ScalarValue::Decimal128(v, precision, scale) => { ScalarValue::eq_array_decimal( array, @@ -3157,59 +3806,61 @@ impl ScalarValue { )? } ScalarValue::Boolean(val) => { - eq_array_primitive!(array, index, BooleanArray, val)? + eq_array_primitive!(array, index, as_boolean_array, val)? } ScalarValue::Float16(val) => { - eq_array_primitive!(array, index, Float16Array, val)? + eq_array_primitive!(array, index, as_float16_array, val)? } ScalarValue::Float32(val) => { - eq_array_primitive!(array, index, Float32Array, val)? + eq_array_primitive!(array, index, as_float32_array, val)? } ScalarValue::Float64(val) => { - eq_array_primitive!(array, index, Float64Array, val)? + eq_array_primitive!(array, index, as_float64_array, val)? + } + ScalarValue::Int8(val) => { + eq_array_primitive!(array, index, as_int8_array, val)? } - ScalarValue::Int8(val) => eq_array_primitive!(array, index, Int8Array, val)?, ScalarValue::Int16(val) => { - eq_array_primitive!(array, index, Int16Array, val)? + eq_array_primitive!(array, index, as_int16_array, val)? } ScalarValue::Int32(val) => { - eq_array_primitive!(array, index, Int32Array, val)? + eq_array_primitive!(array, index, as_int32_array, val)? } ScalarValue::Int64(val) => { - eq_array_primitive!(array, index, Int64Array, val)? + eq_array_primitive!(array, index, as_int64_array, val)? } ScalarValue::UInt8(val) => { - eq_array_primitive!(array, index, UInt8Array, val)? + eq_array_primitive!(array, index, as_uint8_array, val)? } ScalarValue::UInt16(val) => { - eq_array_primitive!(array, index, UInt16Array, val)? + eq_array_primitive!(array, index, as_uint16_array, val)? } ScalarValue::UInt32(val) => { - eq_array_primitive!(array, index, UInt32Array, val)? + eq_array_primitive!(array, index, as_uint32_array, val)? } ScalarValue::UInt64(val) => { - eq_array_primitive!(array, index, UInt64Array, val)? + eq_array_primitive!(array, index, as_uint64_array, val)? } ScalarValue::Utf8(val) => { - eq_array_primitive!(array, index, StringArray, val)? + eq_array_primitive!(array, index, as_string_array, val)? } ScalarValue::Utf8View(val) => { - eq_array_primitive!(array, index, StringViewArray, val)? + eq_array_primitive!(array, index, as_string_view_array, val)? } ScalarValue::LargeUtf8(val) => { - eq_array_primitive!(array, index, LargeStringArray, val)? + eq_array_primitive!(array, index, as_large_string_array, val)? } ScalarValue::Binary(val) => { - eq_array_primitive!(array, index, BinaryArray, val)? + eq_array_primitive!(array, index, as_binary_array, val)? } ScalarValue::BinaryView(val) => { - eq_array_primitive!(array, index, BinaryViewArray, val)? + eq_array_primitive!(array, index, as_binary_view_array, val)? } ScalarValue::FixedSizeBinary(_, val) => { - eq_array_primitive!(array, index, FixedSizeBinaryArray, val)? + eq_array_primitive!(array, index, as_fixed_size_binary_array, val)? } ScalarValue::LargeBinary(val) => { - eq_array_primitive!(array, index, LargeBinaryArray, val)? + eq_array_primitive!(array, index, as_large_binary_array, val)? } ScalarValue::List(arr) => { Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) @@ -3227,58 +3878,58 @@ impl ScalarValue { Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) } ScalarValue::Date32(val) => { - eq_array_primitive!(array, index, Date32Array, val)? + eq_array_primitive!(array, index, as_date32_array, val)? } ScalarValue::Date64(val) => { - eq_array_primitive!(array, index, Date64Array, val)? + eq_array_primitive!(array, index, as_date64_array, val)? } ScalarValue::Time32Second(val) => { - eq_array_primitive!(array, index, Time32SecondArray, val)? + eq_array_primitive!(array, index, as_time32_second_array, val)? } ScalarValue::Time32Millisecond(val) => { - eq_array_primitive!(array, index, Time32MillisecondArray, val)? + eq_array_primitive!(array, index, as_time32_millisecond_array, val)? } ScalarValue::Time64Microsecond(val) => { - eq_array_primitive!(array, index, Time64MicrosecondArray, val)? + eq_array_primitive!(array, index, as_time64_microsecond_array, val)? } ScalarValue::Time64Nanosecond(val) => { - eq_array_primitive!(array, index, Time64NanosecondArray, val)? + eq_array_primitive!(array, index, as_time64_nanosecond_array, val)? } ScalarValue::TimestampSecond(val, _) => { - eq_array_primitive!(array, index, TimestampSecondArray, val)? + eq_array_primitive!(array, index, as_timestamp_second_array, val)? } ScalarValue::TimestampMillisecond(val, _) => { - eq_array_primitive!(array, index, TimestampMillisecondArray, val)? + eq_array_primitive!(array, index, as_timestamp_millisecond_array, val)? } ScalarValue::TimestampMicrosecond(val, _) => { - eq_array_primitive!(array, index, TimestampMicrosecondArray, val)? + eq_array_primitive!(array, index, as_timestamp_microsecond_array, val)? } ScalarValue::TimestampNanosecond(val, _) => { - eq_array_primitive!(array, index, TimestampNanosecondArray, val)? + eq_array_primitive!(array, index, as_timestamp_nanosecond_array, val)? } ScalarValue::IntervalYearMonth(val) => { - eq_array_primitive!(array, index, IntervalYearMonthArray, val)? + eq_array_primitive!(array, index, as_interval_ym_array, val)? } ScalarValue::IntervalDayTime(val) => { - eq_array_primitive!(array, index, IntervalDayTimeArray, val)? + eq_array_primitive!(array, index, as_interval_dt_array, val)? } ScalarValue::IntervalMonthDayNano(val) => { - eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val)? + eq_array_primitive!(array, index, as_interval_mdn_array, val)? } ScalarValue::DurationSecond(val) => { - eq_array_primitive!(array, index, DurationSecondArray, val)? + eq_array_primitive!(array, index, as_duration_second_array, val)? } ScalarValue::DurationMillisecond(val) => { - eq_array_primitive!(array, index, DurationMillisecondArray, val)? + eq_array_primitive!(array, index, as_duration_millisecond_array, val)? } ScalarValue::DurationMicrosecond(val) => { - eq_array_primitive!(array, index, DurationMicrosecondArray, val)? + eq_array_primitive!(array, index, as_duration_microsecond_array, val)? } ScalarValue::DurationNanosecond(val) => { - eq_array_primitive!(array, index, DurationNanosecondArray, val)? + eq_array_primitive!(array, index, as_duration_nanosecond_array, val)? } ScalarValue::Union(value, _, _) => { - let array = as_union_array(array); + let array = as_union_array(array)?; let ti = array.type_id(index); let index = array.value_offset(index); if let Some((ti_v, value)) = value { @@ -3297,7 +3948,7 @@ impl ScalarValue { DataType::UInt16 => get_dict_value::(array, index)?, DataType::UInt32 => get_dict_value::(array, index)?, DataType::UInt64 => get_dict_value::(array, index)?, - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + _ => unreachable!("Invalid dictionary keys type: {}", key_type), }; // was the value in the array non null? match values_index { @@ -3314,6 +3965,16 @@ impl ScalarValue { arr1 == &right } + /// Compare `self` with `other` and return an `Ordering`. + /// + /// This is the same as [`PartialOrd`] except that it returns + /// `Err` if the values cannot be compared, e.g., they have incompatible data types. + pub fn try_cmp(&self, other: &Self) -> Result { + self.partial_cmp(other).ok_or_else(|| { + _internal_datafusion_err!("Uncomparable values: {self:?}, {other:?}") + }) + } + /// Estimate size if bytes including `Self`. For values with internal containers such as `String` /// includes the allocated size (`capacity`) rather than the current length (`len`) pub fn size(&self) -> usize { @@ -3324,6 +3985,8 @@ impl ScalarValue { | ScalarValue::Float16(_) | ScalarValue::Float32(_) | ScalarValue::Float64(_) + | ScalarValue::Decimal32(_, _, _) + | ScalarValue::Decimal64(_, _, _) | ScalarValue::Decimal128(_, _, _) | ScalarValue::Decimal256(_, _, _) | ScalarValue::Int8(_) @@ -3420,6 +4083,319 @@ impl ScalarValue { .map(|sv| sv.size() - size_of_val(sv)) .sum::() } + + /// Compacts the allocation referenced by `self` to the minimum, copying the data if + /// necessary. + /// + /// This can be relevant when `self` is a list or contains a list as a nested value, as + /// a single list holds an Arc to its entire original array buffer. + pub fn compact(&mut self) { + match self { + ScalarValue::Null + | ScalarValue::Boolean(_) + | ScalarValue::Float16(_) + | ScalarValue::Float32(_) + | ScalarValue::Float64(_) + | ScalarValue::Decimal32(_, _, _) + | ScalarValue::Decimal64(_, _, _) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Decimal256(_, _, _) + | ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + | ScalarValue::Date32(_) + | ScalarValue::Date64(_) + | ScalarValue::Time32Second(_) + | ScalarValue::Time32Millisecond(_) + | ScalarValue::Time64Microsecond(_) + | ScalarValue::Time64Nanosecond(_) + | ScalarValue::IntervalYearMonth(_) + | ScalarValue::IntervalDayTime(_) + | ScalarValue::IntervalMonthDayNano(_) + | ScalarValue::DurationSecond(_) + | ScalarValue::DurationMillisecond(_) + | ScalarValue::DurationMicrosecond(_) + | ScalarValue::DurationNanosecond(_) + | ScalarValue::Utf8(_) + | ScalarValue::LargeUtf8(_) + | ScalarValue::Utf8View(_) + | ScalarValue::TimestampSecond(_, _) + | ScalarValue::TimestampMillisecond(_, _) + | ScalarValue::TimestampMicrosecond(_, _) + | ScalarValue::TimestampNanosecond(_, _) + | ScalarValue::Binary(_) + | ScalarValue::FixedSizeBinary(_, _) + | ScalarValue::LargeBinary(_) + | ScalarValue::BinaryView(_) => (), + ScalarValue::FixedSizeList(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = FixedSizeListArray::from(array); + } + ScalarValue::List(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = ListArray::from(array); + } + ScalarValue::LargeList(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = LargeListArray::from(array) + } + ScalarValue::Struct(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = StructArray::from(array); + } + ScalarValue::Map(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = MapArray::from(array); + } + ScalarValue::Union(val, _, _) => { + if let Some((_, value)) = val.as_mut() { + value.compact(); + } + } + ScalarValue::Dictionary(_, value) => { + value.compact(); + } + } + } + + /// Compacts ([ScalarValue::compact]) the current [ScalarValue] and returns it. + pub fn compacted(mut self) -> Self { + self.compact(); + self + } + + /// Returns the minimum value for the given numeric `DataType`. + /// + /// This function returns the smallest representable value for numeric + /// and temporal data types. For non-numeric types, it returns `None`. + /// + /// # Supported Types + /// + /// - **Integer types**: `i8::MIN`, `i16::MIN`, etc. + /// - **Unsigned types**: Always 0 (`u8::MIN`, `u16::MIN`, etc.) + /// - **Float types**: Negative infinity (IEEE 754) + /// - **Decimal types**: Smallest value based on precision + /// - **Temporal types**: Minimum timestamp/date values + /// - **Time types**: 0 (midnight) + /// - **Duration types**: `i64::MIN` + pub fn min(datatype: &DataType) -> Option { + match datatype { + DataType::Int8 => Some(ScalarValue::Int8(Some(i8::MIN))), + DataType::Int16 => Some(ScalarValue::Int16(Some(i16::MIN))), + DataType::Int32 => Some(ScalarValue::Int32(Some(i32::MIN))), + DataType::Int64 => Some(ScalarValue::Int64(Some(i64::MIN))), + DataType::UInt8 => Some(ScalarValue::UInt8(Some(u8::MIN))), + DataType::UInt16 => Some(ScalarValue::UInt16(Some(u16::MIN))), + DataType::UInt32 => Some(ScalarValue::UInt32(Some(u32::MIN))), + DataType::UInt64 => Some(ScalarValue::UInt64(Some(u64::MIN))), + DataType::Float16 => Some(ScalarValue::Float16(Some(f16::NEG_INFINITY))), + DataType::Float32 => Some(ScalarValue::Float32(Some(f32::NEG_INFINITY))), + DataType::Float64 => Some(ScalarValue::Float64(Some(f64::NEG_INFINITY))), + DataType::Decimal128(precision, scale) => { + // For decimal, min is -10^(precision-scale) + 10^(-scale) + // But for simplicity, we use the minimum i128 value that fits the precision + let max_digits = 10_i128.pow(*precision as u32) - 1; + Some(ScalarValue::Decimal128( + Some(-max_digits), + *precision, + *scale, + )) + } + DataType::Decimal256(precision, scale) => { + // Similar to Decimal128 but with i256 + // For now, use a large negative value + let max_digits = i256::from_i128(10_i128) + .checked_pow(*precision as u32) + .and_then(|v| v.checked_sub(i256::from_i128(1))) + .unwrap_or(i256::MAX); + Some(ScalarValue::Decimal256( + Some(max_digits.neg_wrapping()), + *precision, + *scale, + )) + } + DataType::Date32 => Some(ScalarValue::Date32(Some(i32::MIN))), + DataType::Date64 => Some(ScalarValue::Date64(Some(i64::MIN))), + DataType::Time32(TimeUnit::Second) => { + Some(ScalarValue::Time32Second(Some(0))) + } + DataType::Time32(TimeUnit::Millisecond) => { + Some(ScalarValue::Time32Millisecond(Some(0))) + } + DataType::Time64(TimeUnit::Microsecond) => { + Some(ScalarValue::Time64Microsecond(Some(0))) + } + DataType::Time64(TimeUnit::Nanosecond) => { + Some(ScalarValue::Time64Nanosecond(Some(0))) + } + DataType::Timestamp(unit, tz) => match unit { + TimeUnit::Second => { + Some(ScalarValue::TimestampSecond(Some(i64::MIN), tz.clone())) + } + TimeUnit::Millisecond => Some(ScalarValue::TimestampMillisecond( + Some(i64::MIN), + tz.clone(), + )), + TimeUnit::Microsecond => Some(ScalarValue::TimestampMicrosecond( + Some(i64::MIN), + tz.clone(), + )), + TimeUnit::Nanosecond => { + Some(ScalarValue::TimestampNanosecond(Some(i64::MIN), tz.clone())) + } + }, + DataType::Duration(unit) => match unit { + TimeUnit::Second => Some(ScalarValue::DurationSecond(Some(i64::MIN))), + TimeUnit::Millisecond => { + Some(ScalarValue::DurationMillisecond(Some(i64::MIN))) + } + TimeUnit::Microsecond => { + Some(ScalarValue::DurationMicrosecond(Some(i64::MIN))) + } + TimeUnit::Nanosecond => { + Some(ScalarValue::DurationNanosecond(Some(i64::MIN))) + } + }, + _ => None, + } + } + + /// Returns the maximum value for the given numeric `DataType`. + /// + /// This function returns the largest representable value for numeric + /// and temporal data types. For non-numeric types, it returns `None`. + /// + /// # Supported Types + /// + /// - **Integer types**: `i8::MAX`, `i16::MAX`, etc. + /// - **Unsigned types**: `u8::MAX`, `u16::MAX`, etc. + /// - **Float types**: Positive infinity (IEEE 754) + /// - **Decimal types**: Largest value based on precision + /// - **Temporal types**: Maximum timestamp/date values + /// - **Time types**: Maximum time in the day (1 day - 1 unit) + /// - **Duration types**: `i64::MAX` + pub fn max(datatype: &DataType) -> Option { + match datatype { + DataType::Int8 => Some(ScalarValue::Int8(Some(i8::MAX))), + DataType::Int16 => Some(ScalarValue::Int16(Some(i16::MAX))), + DataType::Int32 => Some(ScalarValue::Int32(Some(i32::MAX))), + DataType::Int64 => Some(ScalarValue::Int64(Some(i64::MAX))), + DataType::UInt8 => Some(ScalarValue::UInt8(Some(u8::MAX))), + DataType::UInt16 => Some(ScalarValue::UInt16(Some(u16::MAX))), + DataType::UInt32 => Some(ScalarValue::UInt32(Some(u32::MAX))), + DataType::UInt64 => Some(ScalarValue::UInt64(Some(u64::MAX))), + DataType::Float16 => Some(ScalarValue::Float16(Some(f16::INFINITY))), + DataType::Float32 => Some(ScalarValue::Float32(Some(f32::INFINITY))), + DataType::Float64 => Some(ScalarValue::Float64(Some(f64::INFINITY))), + DataType::Decimal128(precision, scale) => { + // For decimal, max is 10^(precision-scale) - 10^(-scale) + // But for simplicity, we use the maximum i128 value that fits the precision + let max_digits = 10_i128.pow(*precision as u32) - 1; + Some(ScalarValue::Decimal128( + Some(max_digits), + *precision, + *scale, + )) + } + DataType::Decimal256(precision, scale) => { + // Similar to Decimal128 but with i256 + let max_digits = i256::from_i128(10_i128) + .checked_pow(*precision as u32) + .and_then(|v| v.checked_sub(i256::from_i128(1))) + .unwrap_or(i256::MAX); + Some(ScalarValue::Decimal256( + Some(max_digits), + *precision, + *scale, + )) + } + DataType::Date32 => Some(ScalarValue::Date32(Some(i32::MAX))), + DataType::Date64 => Some(ScalarValue::Date64(Some(i64::MAX))), + DataType::Time32(TimeUnit::Second) => { + // 86399 seconds = 23:59:59 + Some(ScalarValue::Time32Second(Some(86_399))) + } + DataType::Time32(TimeUnit::Millisecond) => { + // 86_399_999 milliseconds = 23:59:59.999 + Some(ScalarValue::Time32Millisecond(Some(86_399_999))) + } + DataType::Time64(TimeUnit::Microsecond) => { + // 86_399_999_999 microseconds = 23:59:59.999999 + Some(ScalarValue::Time64Microsecond(Some(86_399_999_999))) + } + DataType::Time64(TimeUnit::Nanosecond) => { + // 86_399_999_999_999 nanoseconds = 23:59:59.999999999 + Some(ScalarValue::Time64Nanosecond(Some(86_399_999_999_999))) + } + DataType::Timestamp(unit, tz) => match unit { + TimeUnit::Second => { + Some(ScalarValue::TimestampSecond(Some(i64::MAX), tz.clone())) + } + TimeUnit::Millisecond => Some(ScalarValue::TimestampMillisecond( + Some(i64::MAX), + tz.clone(), + )), + TimeUnit::Microsecond => Some(ScalarValue::TimestampMicrosecond( + Some(i64::MAX), + tz.clone(), + )), + TimeUnit::Nanosecond => { + Some(ScalarValue::TimestampNanosecond(Some(i64::MAX), tz.clone())) + } + }, + DataType::Duration(unit) => match unit { + TimeUnit::Second => Some(ScalarValue::DurationSecond(Some(i64::MAX))), + TimeUnit::Millisecond => { + Some(ScalarValue::DurationMillisecond(Some(i64::MAX))) + } + TimeUnit::Microsecond => { + Some(ScalarValue::DurationMicrosecond(Some(i64::MAX))) + } + TimeUnit::Nanosecond => { + Some(ScalarValue::DurationNanosecond(Some(i64::MAX))) + } + }, + _ => None, + } + } +} + +/// Compacts the data of an `ArrayData` into a new `ArrayData`. +/// +/// This is useful when you want to minimize the memory footprint of an +/// `ArrayData`. For example, the value returned by [`Array::slice`] still +/// points at the same underlying data buffers as the original array, which may +/// hold many more values. Calling `copy_array_data` on the sliced array will +/// create a new, smaller, `ArrayData` that only contains the data for the +/// sliced array. +/// +/// # Example +/// ``` +/// # use arrow::array::{make_array, Array, Int32Array}; +/// use datafusion_common::scalar::copy_array_data; +/// let array = Int32Array::from_iter_values(0..8192); +/// // Take only the first 2 elements +/// let sliced_array = array.slice(0, 2); +/// // The memory footprint of `sliced_array` is close to 8192 * 4 bytes +/// assert_eq!(32864, sliced_array.get_array_memory_size()); +/// // however, we can copy the data to a new `ArrayData` +/// let new_array = make_array(copy_array_data(&sliced_array.into_data())); +/// // The memory footprint of `new_array` is now only 2 * 4 bytes +/// // and overhead: +/// assert_eq!(160, new_array.get_array_memory_size()); +/// ``` +/// +/// See also [`ScalarValue::compact`] which applies to `ScalarValue` instances +/// as necessary. +pub fn copy_array_data(src_data: &ArrayData) -> ArrayData { + let mut copy = MutableArrayData::new(vec![&src_data], true, src_data.len()); + copy.extend(0, 0, src_data.len()); + copy.freeze() } macro_rules! impl_scalar { @@ -3459,7 +4435,7 @@ impl From<&str> for ScalarValue { impl From> for ScalarValue { fn from(value: Option<&str>) -> Self { let value = value.map(|s| s.to_string()); - ScalarValue::Utf8(value) + value.into() } } @@ -3486,7 +4462,13 @@ impl FromStr for ScalarValue { impl From for ScalarValue { fn from(value: String) -> Self { - ScalarValue::Utf8(Some(value)) + Some(value).into() + } +} + +impl From> for ScalarValue { + fn from(value: Option) -> Self { + ScalarValue::Utf8(value) } } @@ -3629,6 +4611,12 @@ macro_rules! format_option { impl fmt::Display for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { + ScalarValue::Decimal32(v, p, s) => { + write!(f, "{v:?},{p:?},{s:?}")?; + } + ScalarValue::Decimal64(v, p, s) => { + write!(f, "{v:?},{p:?},{s:?}")?; + } ScalarValue::Decimal128(v, p, s) => { write!(f, "{v:?},{p:?},{s:?}")?; } @@ -3672,12 +4660,28 @@ impl fmt::Display for ScalarValue { ScalarValue::List(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, ScalarValue::LargeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, ScalarValue::FixedSizeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, - ScalarValue::Date32(e) => { - format_option!(f, e.map(|v| Date32Type::to_naive_date(v).to_string()))? - } - ScalarValue::Date64(e) => { - format_option!(f, e.map(|v| Date64Type::to_naive_date(v).to_string()))? - } + ScalarValue::Date32(e) => format_option!( + f, + e.map(|v| { + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + match epoch.checked_add_signed(Duration::try_days(v as i64).unwrap()) + { + Some(date) => date.to_string(), + None => "".to_string(), + } + }) + )?, + ScalarValue::Date64(e) => format_option!( + f, + e.map(|v| { + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + match epoch.checked_add_signed(Duration::try_milliseconds(v).unwrap()) + { + Some(date) => date.to_string(), + None => "".to_string(), + } + }) + )?, ScalarValue::Time32Second(e) => format_option!(f, e)?, ScalarValue::Time32Millisecond(e) => format_option!(f, e)?, ScalarValue::Time64Microsecond(e) => format_option!(f, e)?, @@ -3748,7 +4752,7 @@ impl fmt::Display for ScalarValue { array_value_to_string(arr.column(0), i).unwrap(); let value = array_value_to_string(arr.column(1), i).unwrap(); - buffer.push_back(format!("{}:{}", key, value)); + buffer.push_back(format!("{key}:{value}")); } format!( "{{{}}}", @@ -3767,7 +4771,7 @@ impl fmt::Display for ScalarValue { )? } ScalarValue::Union(val, _fields, _mode) => match val { - Some((id, val)) => write!(f, "{}:{}", id, val)?, + Some((id, val)) => write!(f, "{id}:{val}")?, None => write!(f, "NULL")?, }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, @@ -3802,6 +4806,8 @@ fn fmt_binary(data: &[u8], f: &mut fmt::Formatter) -> fmt::Result { impl fmt::Debug for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { + ScalarValue::Decimal32(_, _, _) => write!(f, "Decimal32({self})"), + ScalarValue::Decimal64(_, _, _) => write!(f, "Decimal64({self})"), ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({self})"), ScalarValue::Decimal256(_, _, _) => write!(f, "Decimal256({self})"), ScalarValue::Boolean(_) => write!(f, "Boolean({self})"), @@ -3944,7 +4950,7 @@ impl fmt::Debug for ScalarValue { write!(f, "DurationNanosecond(\"{self}\")") } ScalarValue::Union(val, _fields, _mode) => match val { - Some((id, val)) => write!(f, "Union {}:{}", id, val), + Some((id, val)) => write!(f, "Union {id}:{val}"), None => write!(f, "Union(NULL)"), }, ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), @@ -3997,17 +5003,21 @@ impl ScalarType for Date32Type { #[cfg(test)] mod tests { + use std::sync::Arc; use super::*; - use crate::cast::{ - as_map_array, as_string_array, as_struct_array, as_uint32_array, as_uint64_array, - }; - + use crate::cast::{as_list_array, as_map_array, as_struct_array}; use crate::test_util::batches_to_string; - use arrow::array::{types::Float64Type, NullBufferBuilder}; - use arrow::buffer::{Buffer, OffsetBuffer}; + use arrow::array::{ + FixedSizeListBuilder, Int32Builder, LargeListBuilder, ListBuilder, MapBuilder, + NullArray, NullBufferBuilder, OffsetSizeTrait, PrimitiveBuilder, RecordBatch, + StringBuilder, StringDictionaryBuilder, StructBuilder, UnionBuilder, + }; + use arrow::buffer::{Buffer, NullBuffer, OffsetBuffer}; use arrow::compute::{is_null, kernels}; - use arrow::datatypes::Fields; + use arrow::datatypes::{ + ArrowNumericType, Fields, Float64Type, DECIMAL256_MAX_PRECISION, + }; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_columns; use chrono::NaiveDate; @@ -4068,7 +5078,7 @@ mod tests { #[test] #[should_panic( - expected = "Error building ScalarValue::Struct. Expected array with exactly one element, found array with 4 elements" + expected = "InvalidArgumentError(\"Incorrect array length for StructArray field \\\"bool\\\", expected 1 got 4\")" )] fn test_scalar_value_from_for_struct_should_panic() { let _ = ScalarStructBuilder::new() @@ -4364,7 +5374,7 @@ mod tests { ]); let array = ScalarValue::iter_to_array(scalars).unwrap(); - let list_array = as_list_array(&array); + let list_array = as_list_array(&array).unwrap(); // List[[1,2,3], null, [4,5]] let expected = ListArray::from_iter_primitive::(vec![ Some(vec![Some(1), Some(2), Some(3)]), @@ -4380,7 +5390,7 @@ mod tests { ]); let array = ScalarValue::iter_to_array(scalars).unwrap(); - let list_array = as_large_list_array(&array); + let list_array = as_large_list_array(&array).unwrap(); let expected = LargeListArray::from_iter_primitive::(vec![ Some(vec![Some(1), Some(2), Some(3)]), None, @@ -4444,6 +5454,17 @@ mod tests { } } + #[test] + fn test_eq_array_err_message() { + assert_starts_with( + ScalarValue::Utf8(Some("123".to_string())) + .eq_array(&(Arc::new(Int32Array::from(vec![123])) as ArrayRef), 0) + .unwrap_err() + .message(), + "could not cast array of type Int32 to arrow_array::array::byte_array::GenericByteArray>", + ); + } + #[test] fn scalar_add_trait_test() -> Result<()> { let float_value = ScalarValue::Float64(Some(123.)); @@ -4600,6 +5621,32 @@ mod tests { Ok(()) } + #[test] + fn test_try_cmp() { + assert_eq!( + ScalarValue::try_cmp( + &ScalarValue::Int32(Some(1)), + &ScalarValue::Int32(Some(2)) + ) + .unwrap(), + Ordering::Less + ); + assert_eq!( + ScalarValue::try_cmp(&ScalarValue::Int32(None), &ScalarValue::Int32(Some(2))) + .unwrap(), + Ordering::Less + ); + assert_starts_with( + ScalarValue::try_cmp( + &ScalarValue::Int32(Some(1)), + &ScalarValue::Int64(Some(2)), + ) + .unwrap_err() + .message(), + "Uncomparable values: Int32(1), Int64(2)", + ); + } + #[test] fn scalar_decimal_test() -> Result<()> { let decimal_value = ScalarValue::Decimal128(Some(123), 10, 1); @@ -4706,6 +5753,114 @@ mod tests { Ok(()) } + #[test] + fn test_new_one_decimal128() { + assert_eq!( + ScalarValue::new_one(&DataType::Decimal128(5, 0)).unwrap(), + ScalarValue::Decimal128(Some(1), 5, 0) + ); + assert_eq!( + ScalarValue::new_one(&DataType::Decimal128(5, 1)).unwrap(), + ScalarValue::Decimal128(Some(10), 5, 1) + ); + assert_eq!( + ScalarValue::new_one(&DataType::Decimal128(5, 2)).unwrap(), + ScalarValue::Decimal128(Some(100), 5, 2) + ); + // More precision + assert_eq!( + ScalarValue::new_one(&DataType::Decimal128(7, 2)).unwrap(), + ScalarValue::Decimal128(Some(100), 7, 2) + ); + // No negative scale + assert!(ScalarValue::new_one(&DataType::Decimal128(5, -1)).is_err()); + // Invalid combination + assert!(ScalarValue::new_one(&DataType::Decimal128(0, 2)).is_err()); + assert!(ScalarValue::new_one(&DataType::Decimal128(5, 7)).is_err()); + } + + #[test] + fn test_new_one_decimal256() { + assert_eq!( + ScalarValue::new_one(&DataType::Decimal256(5, 0)).unwrap(), + ScalarValue::Decimal256(Some(1.into()), 5, 0) + ); + assert_eq!( + ScalarValue::new_one(&DataType::Decimal256(5, 1)).unwrap(), + ScalarValue::Decimal256(Some(10.into()), 5, 1) + ); + assert_eq!( + ScalarValue::new_one(&DataType::Decimal256(5, 2)).unwrap(), + ScalarValue::Decimal256(Some(100.into()), 5, 2) + ); + // More precision + assert_eq!( + ScalarValue::new_one(&DataType::Decimal256(7, 2)).unwrap(), + ScalarValue::Decimal256(Some(100.into()), 7, 2) + ); + // No negative scale + assert!(ScalarValue::new_one(&DataType::Decimal256(5, -1)).is_err()); + // Invalid combination + assert!(ScalarValue::new_one(&DataType::Decimal256(0, 2)).is_err()); + assert!(ScalarValue::new_one(&DataType::Decimal256(5, 7)).is_err()); + } + + #[test] + fn test_new_ten_decimal128() { + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal128(5, 1)).unwrap(), + ScalarValue::Decimal128(Some(100), 5, 1) + ); + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal128(5, 2)).unwrap(), + ScalarValue::Decimal128(Some(1000), 5, 2) + ); + // More precision + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal128(7, 2)).unwrap(), + ScalarValue::Decimal128(Some(1000), 7, 2) + ); + // No negative scale + assert!(ScalarValue::new_ten(&DataType::Decimal128(5, -1)).is_err()); + // Invalid combination + assert!(ScalarValue::new_ten(&DataType::Decimal128(0, 2)).is_err()); + assert!(ScalarValue::new_ten(&DataType::Decimal128(5, 7)).is_err()); + } + + #[test] + fn test_new_ten_decimal256() { + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal256(5, 1)).unwrap(), + ScalarValue::Decimal256(Some(100.into()), 5, 1) + ); + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal256(5, 2)).unwrap(), + ScalarValue::Decimal256(Some(1000.into()), 5, 2) + ); + // More precision + assert_eq!( + ScalarValue::new_ten(&DataType::Decimal256(7, 2)).unwrap(), + ScalarValue::Decimal256(Some(1000.into()), 7, 2) + ); + // No negative scale + assert!(ScalarValue::new_ten(&DataType::Decimal256(5, -1)).is_err()); + // Invalid combination + assert!(ScalarValue::new_ten(&DataType::Decimal256(0, 2)).is_err()); + assert!(ScalarValue::new_ten(&DataType::Decimal256(5, 7)).is_err()); + } + + #[test] + fn test_new_negative_one_decimal128() { + assert_eq!( + ScalarValue::new_negative_one(&DataType::Decimal128(5, 0)).unwrap(), + ScalarValue::Decimal128(Some(-1), 5, 0) + ); + assert_eq!( + ScalarValue::new_negative_one(&DataType::Decimal128(5, 2)).unwrap(), + ScalarValue::Decimal128(Some(-100), 5, 2) + ); + } + #[test] fn test_list_partial_cmp() { let a = @@ -4757,10 +5912,113 @@ mod tests { ListArray::from_iter_primitive::(vec![Some(vec![ Some(10), Some(2), - Some(30), + Some(30), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(2), + Some(3), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(2), + Some(3), + Some(4), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + None, + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), ])]), )); - assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = ScalarValue::LargeList(Arc::new(LargeListArray::from_iter_primitive::< + Int64Type, + _, + _, + >(vec![Some(vec![ + None, + Some(2), + Some(3), + ])]))); + let b = ScalarValue::LargeList(Arc::new(LargeListArray::from_iter_primitive::< + Int64Type, + _, + _, + >(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]))); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = ScalarValue::FixedSizeList(Arc::new( + FixedSizeListArray::from_iter_primitive::( + vec![Some(vec![None, Some(2), Some(3)])], + 3, + ), + )); + let b = ScalarValue::FixedSizeList(Arc::new( + FixedSizeListArray::from_iter_primitive::( + vec![Some(vec![Some(1), Some(2), Some(3)])], + 3, + ), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); } #[test] @@ -6365,10 +7623,7 @@ mod tests { let err = value.arithmetic_negate().expect_err("Should receive overflow error on negating {value:?}"); let root_err = err.find_root(); match root_err{ - DataFusionError::ArrowError( - ArrowError::ArithmeticOverflow(_), - _, - ) => {} + DataFusionError::ArrowError(err, _) if matches!(err.as_ref(), ArrowError::ArithmeticOverflow(_)) => {} _ => return Err(err), }; } @@ -6656,6 +7911,26 @@ mod tests { ScalarValue::Float64(Some(-9.9)), 5, ), + ( + ScalarValue::Decimal128(Some(10), 1, 0), + ScalarValue::Decimal128(Some(5), 1, 0), + 5, + ), + ( + ScalarValue::Decimal128(Some(5), 1, 0), + ScalarValue::Decimal128(Some(10), 1, 0), + 5, + ), + ( + ScalarValue::Decimal256(Some(10.into()), 1, 0), + ScalarValue::Decimal256(Some(5.into()), 1, 0), + 5, + ), + ( + ScalarValue::Decimal256(Some(5.into()), 1, 0), + ScalarValue::Decimal256(Some(10.into()), 1, 0), + 5, + ), ]; for (lhs, rhs, expected) in cases.iter() { let distance = lhs.distance(rhs).unwrap(); @@ -6663,6 +7938,24 @@ mod tests { } } + #[test] + fn test_distance_none() { + let cases = [ + ( + ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0), + ScalarValue::Decimal128(Some(-i128::MAX), DECIMAL128_MAX_PRECISION, 0), + ), + ( + ScalarValue::Decimal256(Some(i256::MAX), DECIMAL256_MAX_PRECISION, 0), + ScalarValue::Decimal256(Some(-i256::MAX), DECIMAL256_MAX_PRECISION, 0), + ), + ]; + for (lhs, rhs) in cases.iter() { + let distance = lhs.distance(rhs); + assert!(distance.is_none(), "{lhs} vs {rhs}"); + } + } + #[test] fn test_scalar_distance_invalid() { let cases = [ @@ -6704,7 +7997,33 @@ mod tests { (ScalarValue::Date64(Some(0)), ScalarValue::Date64(Some(1))), ( ScalarValue::Decimal128(Some(123), 5, 5), - ScalarValue::Decimal128(Some(120), 5, 5), + ScalarValue::Decimal128(Some(120), 5, 3), + ), + ( + ScalarValue::Decimal128(Some(123), 5, 5), + ScalarValue::Decimal128(Some(120), 3, 5), + ), + ( + ScalarValue::Decimal256(Some(123.into()), 5, 5), + ScalarValue::Decimal256(Some(120.into()), 3, 5), + ), + // Distance 2 * 2^50 is larger than usize + ( + ScalarValue::Decimal256( + Some(i256::from_parts(0, 2_i64.pow(50).into())), + 1, + 0, + ), + ScalarValue::Decimal256( + Some(i256::from_parts(0, (-(2_i64).pow(50)).into())), + 1, + 0, + ), + ), + // Distance overflow + ( + ScalarValue::Decimal256(Some(i256::from_parts(0, i128::MAX)), 1, 0), + ScalarValue::Decimal256(Some(i256::from_parts(0, -i128::MAX)), 1, 0), ), ]; for (lhs, rhs) in cases { @@ -6982,6 +8301,19 @@ mod tests { "); } + #[test] + fn test_display_date64_large_values() { + assert_eq!( + format!("{}", ScalarValue::Date64(Some(790179464505))), + "1995-01-15" + ); + // This used to panic, see https://github.com/apache/arrow-rs/issues/7728 + assert_eq!( + format!("{}", ScalarValue::Date64(Some(-790179464505600000))), + "" + ); + } + #[test] fn test_struct_display_null() { let fields = vec![Field::new("a", DataType::Int32, false)]; @@ -7171,14 +8503,14 @@ mod tests { fn get_random_timestamps(sample_size: u64) -> Vec { let vector_size = sample_size; let mut timestamp = vec![]; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for i in 0..vector_size { - let year = rng.gen_range(1995..=2050); - let month = rng.gen_range(1..=12); - let day = rng.gen_range(1..=28); // to exclude invalid dates - let hour = rng.gen_range(0..=23); - let minute = rng.gen_range(0..=59); - let second = rng.gen_range(0..=59); + let year = rng.random_range(1995..=2050); + let month = rng.random_range(1..=12); + let day = rng.random_range(1..=28); // to exclude invalid dates + let hour = rng.random_range(0..=23); + let minute = rng.random_range(0..=59); + let second = rng.random_range(0..=59); if i % 4 == 0 { timestamp.push(ScalarValue::TimestampSecond( Some( @@ -7192,7 +8524,7 @@ mod tests { None, )) } else if i % 4 == 1 { - let millisec = rng.gen_range(0..=999); + let millisec = rng.random_range(0..=999); timestamp.push(ScalarValue::TimestampMillisecond( Some( NaiveDate::from_ymd_opt(year, month, day) @@ -7205,7 +8537,7 @@ mod tests { None, )) } else if i % 4 == 2 { - let microsec = rng.gen_range(0..=999_999); + let microsec = rng.random_range(0..=999_999); timestamp.push(ScalarValue::TimestampMicrosecond( Some( NaiveDate::from_ymd_opt(year, month, day) @@ -7218,7 +8550,7 @@ mod tests { None, )) } else if i % 4 == 3 { - let nanosec = rng.gen_range(0..=999_999_999); + let nanosec = rng.random_range(0..=999_999_999); timestamp.push(ScalarValue::TimestampNanosecond( Some( NaiveDate::from_ymd_opt(year, month, day) @@ -7242,27 +8574,27 @@ mod tests { let vector_size = sample_size; let mut intervals = vec![]; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); const SECS_IN_ONE_DAY: i32 = 86_400; const MICROSECS_IN_ONE_DAY: i64 = 86_400_000_000; for i in 0..vector_size { if i % 4 == 0 { - let days = rng.gen_range(0..5000); + let days = rng.random_range(0..5000); // to not break second precision - let millis = rng.gen_range(0..SECS_IN_ONE_DAY) * 1000; + let millis = rng.random_range(0..SECS_IN_ONE_DAY) * 1000; intervals.push(ScalarValue::new_interval_dt(days, millis)); } else if i % 4 == 1 { - let days = rng.gen_range(0..5000); - let millisec = rng.gen_range(0..(MILLISECS_IN_ONE_DAY as i32)); + let days = rng.random_range(0..5000); + let millisec = rng.random_range(0..(MILLISECS_IN_ONE_DAY as i32)); intervals.push(ScalarValue::new_interval_dt(days, millisec)); } else if i % 4 == 2 { - let days = rng.gen_range(0..5000); + let days = rng.random_range(0..5000); // to not break microsec precision - let nanosec = rng.gen_range(0..MICROSECS_IN_ONE_DAY) * 1000; + let nanosec = rng.random_range(0..MICROSECS_IN_ONE_DAY) * 1000; intervals.push(ScalarValue::new_interval_mdn(0, days, nanosec)); } else { - let days = rng.gen_range(0..5000); - let nanosec = rng.gen_range(0..NANOSECS_IN_ONE_DAY); + let days = rng.random_range(0..5000); + let nanosec = rng.random_range(0..NANOSECS_IN_ONE_DAY); intervals.push(ScalarValue::new_interval_mdn(0, days, nanosec)); } } @@ -7392,4 +8724,456 @@ mod tests { ]; assert!(scalars.iter().all(|s| s.is_null())); } + + // `err.to_string()` depends on backtrace being present (may have backtrace appended) + // `err.strip_backtrace()` also depends on backtrace being present (may have "This was likely caused by ..." stripped) + fn assert_starts_with(actual: impl AsRef, expected_prefix: impl AsRef) { + let actual = actual.as_ref(); + let expected_prefix = expected_prefix.as_ref(); + assert!( + actual.starts_with(expected_prefix), + "Expected '{actual}' to start with '{expected_prefix}'" + ); + } + + #[test] + fn test_new_default() { + // Test numeric types + assert_eq!( + ScalarValue::new_default(&DataType::Int32).unwrap(), + ScalarValue::Int32(Some(0)) + ); + assert_eq!( + ScalarValue::new_default(&DataType::Float64).unwrap(), + ScalarValue::Float64(Some(0.0)) + ); + assert_eq!( + ScalarValue::new_default(&DataType::Boolean).unwrap(), + ScalarValue::Boolean(Some(false)) + ); + + // Test string types + assert_eq!( + ScalarValue::new_default(&DataType::Utf8).unwrap(), + ScalarValue::Utf8(Some("".to_string())) + ); + assert_eq!( + ScalarValue::new_default(&DataType::LargeUtf8).unwrap(), + ScalarValue::LargeUtf8(Some("".to_string())) + ); + + // Test binary types + assert_eq!( + ScalarValue::new_default(&DataType::Binary).unwrap(), + ScalarValue::Binary(Some(vec![])) + ); + + // Test fixed size binary + assert_eq!( + ScalarValue::new_default(&DataType::FixedSizeBinary(5)).unwrap(), + ScalarValue::FixedSizeBinary(5, Some(vec![0, 0, 0, 0, 0])) + ); + + // Test temporal types + assert_eq!( + ScalarValue::new_default(&DataType::Date32).unwrap(), + ScalarValue::Date32(Some(0)) + ); + assert_eq!( + ScalarValue::new_default(&DataType::Time32(TimeUnit::Second)).unwrap(), + ScalarValue::Time32Second(Some(0)) + ); + + // Test decimal types + assert_eq!( + ScalarValue::new_default(&DataType::Decimal128(10, 2)).unwrap(), + ScalarValue::Decimal128(Some(0), 10, 2) + ); + + // Test list type + let list_field = Field::new_list_field(DataType::Int32, true); + let list_result = + ScalarValue::new_default(&DataType::List(Arc::new(list_field.clone()))) + .unwrap(); + match list_result { + ScalarValue::List(arr) => { + assert_eq!(arr.len(), 1); + assert_eq!(arr.value_length(0), 0); // empty list + } + _ => panic!("Expected List"), + } + + // Test struct type + let struct_fields = Fields::from(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ]); + let struct_result = + ScalarValue::new_default(&DataType::Struct(struct_fields.clone())).unwrap(); + match struct_result { + ScalarValue::Struct(arr) => { + assert_eq!(arr.len(), 1); + assert_eq!(arr.column(0).as_primitive::().value(0), 0); + assert_eq!(arr.column(1).as_string::().value(0), ""); + } + _ => panic!("Expected Struct"), + } + + // Test union type + let union_fields = UnionFields::new( + vec![0, 1], + vec![ + Field::new("i32", DataType::Int32, false), + Field::new("f64", DataType::Float64, false), + ], + ); + let union_result = ScalarValue::new_default(&DataType::Union( + union_fields.clone(), + UnionMode::Sparse, + )) + .unwrap(); + match union_result { + ScalarValue::Union(Some((type_id, value)), _, _) => { + assert_eq!(type_id, 0); + assert_eq!(*value, ScalarValue::Int32(Some(0))); + } + _ => panic!("Expected Union"), + } + } + + #[test] + fn test_scalar_min() { + // Test integer types + assert_eq!( + ScalarValue::min(&DataType::Int8), + Some(ScalarValue::Int8(Some(i8::MIN))) + ); + assert_eq!( + ScalarValue::min(&DataType::Int32), + Some(ScalarValue::Int32(Some(i32::MIN))) + ); + assert_eq!( + ScalarValue::min(&DataType::UInt8), + Some(ScalarValue::UInt8(Some(0))) + ); + assert_eq!( + ScalarValue::min(&DataType::UInt64), + Some(ScalarValue::UInt64(Some(0))) + ); + + // Test float types + assert_eq!( + ScalarValue::min(&DataType::Float32), + Some(ScalarValue::Float32(Some(f32::NEG_INFINITY))) + ); + assert_eq!( + ScalarValue::min(&DataType::Float64), + Some(ScalarValue::Float64(Some(f64::NEG_INFINITY))) + ); + + // Test decimal types + let decimal_min = ScalarValue::min(&DataType::Decimal128(5, 2)).unwrap(); + match decimal_min { + ScalarValue::Decimal128(Some(val), 5, 2) => { + assert_eq!(val, -99999); // -999.99 with scale 2 + } + _ => panic!("Expected Decimal128"), + } + + // Test temporal types + assert_eq!( + ScalarValue::min(&DataType::Date32), + Some(ScalarValue::Date32(Some(i32::MIN))) + ); + assert_eq!( + ScalarValue::min(&DataType::Time32(TimeUnit::Second)), + Some(ScalarValue::Time32Second(Some(0))) + ); + assert_eq!( + ScalarValue::min(&DataType::Timestamp(TimeUnit::Nanosecond, None)), + Some(ScalarValue::TimestampNanosecond(Some(i64::MIN), None)) + ); + + // Test duration types + assert_eq!( + ScalarValue::min(&DataType::Duration(TimeUnit::Second)), + Some(ScalarValue::DurationSecond(Some(i64::MIN))) + ); + + // Test unsupported types + assert_eq!(ScalarValue::min(&DataType::Utf8), None); + assert_eq!(ScalarValue::min(&DataType::Binary), None); + assert_eq!( + ScalarValue::min(&DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + true + )))), + None + ); + } + + #[test] + fn test_scalar_max() { + // Test integer types + assert_eq!( + ScalarValue::max(&DataType::Int8), + Some(ScalarValue::Int8(Some(i8::MAX))) + ); + assert_eq!( + ScalarValue::max(&DataType::Int32), + Some(ScalarValue::Int32(Some(i32::MAX))) + ); + assert_eq!( + ScalarValue::max(&DataType::UInt8), + Some(ScalarValue::UInt8(Some(u8::MAX))) + ); + assert_eq!( + ScalarValue::max(&DataType::UInt64), + Some(ScalarValue::UInt64(Some(u64::MAX))) + ); + + // Test float types + assert_eq!( + ScalarValue::max(&DataType::Float32), + Some(ScalarValue::Float32(Some(f32::INFINITY))) + ); + assert_eq!( + ScalarValue::max(&DataType::Float64), + Some(ScalarValue::Float64(Some(f64::INFINITY))) + ); + + // Test decimal types + let decimal_max = ScalarValue::max(&DataType::Decimal128(5, 2)).unwrap(); + match decimal_max { + ScalarValue::Decimal128(Some(val), 5, 2) => { + assert_eq!(val, 99999); // 999.99 with scale 2 + } + _ => panic!("Expected Decimal128"), + } + + // Test temporal types + assert_eq!( + ScalarValue::max(&DataType::Date32), + Some(ScalarValue::Date32(Some(i32::MAX))) + ); + assert_eq!( + ScalarValue::max(&DataType::Time32(TimeUnit::Second)), + Some(ScalarValue::Time32Second(Some(86_399))) // 23:59:59 + ); + assert_eq!( + ScalarValue::max(&DataType::Time64(TimeUnit::Microsecond)), + Some(ScalarValue::Time64Microsecond(Some(86_399_999_999))) // 23:59:59.999999 + ); + assert_eq!( + ScalarValue::max(&DataType::Timestamp(TimeUnit::Nanosecond, None)), + Some(ScalarValue::TimestampNanosecond(Some(i64::MAX), None)) + ); + + // Test duration types + assert_eq!( + ScalarValue::max(&DataType::Duration(TimeUnit::Millisecond)), + Some(ScalarValue::DurationMillisecond(Some(i64::MAX))) + ); + + // Test unsupported types + assert_eq!(ScalarValue::max(&DataType::Utf8), None); + assert_eq!(ScalarValue::max(&DataType::Binary), None); + assert_eq!( + ScalarValue::max(&DataType::Struct(Fields::from(vec![Field::new( + "field", + DataType::Int32, + true + )]))), + None + ); + } + + #[test] + fn test_min_max_float16() { + // Test Float16 min and max + let min_f16 = ScalarValue::min(&DataType::Float16).unwrap(); + match min_f16 { + ScalarValue::Float16(Some(val)) => { + assert_eq!(val, f16::NEG_INFINITY); + } + _ => panic!("Expected Float16"), + } + + let max_f16 = ScalarValue::max(&DataType::Float16).unwrap(); + match max_f16 { + ScalarValue::Float16(Some(val)) => { + assert_eq!(val, f16::INFINITY); + } + _ => panic!("Expected Float16"), + } + } + + #[test] + fn test_new_default_interval() { + // Test all interval types + assert_eq!( + ScalarValue::new_default(&DataType::Interval(IntervalUnit::YearMonth)) + .unwrap(), + ScalarValue::IntervalYearMonth(Some(0)) + ); + assert_eq!( + ScalarValue::new_default(&DataType::Interval(IntervalUnit::DayTime)).unwrap(), + ScalarValue::IntervalDayTime(Some(IntervalDayTime::ZERO)) + ); + assert_eq!( + ScalarValue::new_default(&DataType::Interval(IntervalUnit::MonthDayNano)) + .unwrap(), + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::ZERO)) + ); + } + + #[test] + fn test_min_max_with_timezone() { + let tz = Some(Arc::from("UTC")); + + // Test timestamp with timezone + let min_ts = + ScalarValue::min(&DataType::Timestamp(TimeUnit::Second, tz.clone())).unwrap(); + match min_ts { + ScalarValue::TimestampSecond(Some(val), Some(tz_str)) => { + assert_eq!(val, i64::MIN); + assert_eq!(tz_str.as_ref(), "UTC"); + } + _ => panic!("Expected TimestampSecond with timezone"), + } + + let max_ts = + ScalarValue::max(&DataType::Timestamp(TimeUnit::Millisecond, tz.clone())) + .unwrap(); + match max_ts { + ScalarValue::TimestampMillisecond(Some(val), Some(tz_str)) => { + assert_eq!(val, i64::MAX); + assert_eq!(tz_str.as_ref(), "UTC"); + } + _ => panic!("Expected TimestampMillisecond with timezone"), + } + } + + #[test] + fn test_convert_array_to_scalar_vec() { + // 1: Regular ListArray + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(4)]), + ]); + let converted = ScalarValue::convert_array_to_scalar_vec(&list).unwrap(); + assert_eq!( + converted, + vec![ + Some(vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)) + ]), + None, + Some(vec![ + ScalarValue::Int64(Some(3)), + ScalarValue::Int64(None), + ScalarValue::Int64(Some(4)) + ]), + ] + ); + + // 2: Regular LargeListArray + let large_list = LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(4)]), + ]); + let converted = ScalarValue::convert_array_to_scalar_vec(&large_list).unwrap(); + assert_eq!( + converted, + vec![ + Some(vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)) + ]), + None, + Some(vec![ + ScalarValue::Int64(Some(3)), + ScalarValue::Int64(None), + ScalarValue::Int64(Some(4)) + ]), + ] + ); + + // 3: Funky (null slot has non-zero list offsets) + // Offsets + Values looks like this: [[1, 2], [3, 4], [5]] + // But with NullBuffer it's like this: [[1, 2], NULL, [5]] + let funky = ListArray::new( + Field::new_list_field(DataType::Int64, true).into(), + OffsetBuffer::new(vec![0, 2, 4, 5].into()), + Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5, 6])), + Some(NullBuffer::from(vec![true, false, true])), + ); + let converted = ScalarValue::convert_array_to_scalar_vec(&funky).unwrap(); + assert_eq!( + converted, + vec![ + Some(vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)) + ]), + None, + Some(vec![ScalarValue::Int64(Some(5))]), + ] + ); + + // 4: Offsets + Values looks like this: [[1, 2], [], [5]] + // But with NullBuffer it's like this: [[1, 2], NULL, [5]] + // The converted result is: [[1, 2], None, [5]] + let array4 = ListArray::new( + Field::new_list_field(DataType::Int64, true).into(), + OffsetBuffer::new(vec![0, 2, 2, 5].into()), + Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5, 6])), + Some(NullBuffer::from(vec![true, false, true])), + ); + let converted = ScalarValue::convert_array_to_scalar_vec(&array4).unwrap(); + assert_eq!( + converted, + vec![ + Some(vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)) + ]), + None, + Some(vec![ + ScalarValue::Int64(Some(3)), + ScalarValue::Int64(Some(4)), + ScalarValue::Int64(Some(5)), + ]), + ] + ); + + // 5: Offsets + Values looks like this: [[1, 2], [], [5]] + // Same as 4, but the middle array is not null, so after conversion it's empty. + let array5 = ListArray::new( + Field::new_list_field(DataType::Int64, true).into(), + OffsetBuffer::new(vec![0, 2, 2, 5].into()), + Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5, 6])), + Some(NullBuffer::from(vec![true, true, true])), + ); + let converted = ScalarValue::convert_array_to_scalar_vec(&array5).unwrap(); + assert_eq!( + converted, + vec![ + Some(vec![ + ScalarValue::Int64(Some(1)), + ScalarValue::Int64(Some(2)) + ]), + Some(vec![]), + Some(vec![ + ScalarValue::Int64(Some(3)), + ScalarValue::Int64(Some(4)), + ScalarValue::Int64(Some(5)), + ]), + ] + ); + } } diff --git a/datafusion/common/src/scalar/struct_builder.rs b/datafusion/common/src/scalar/struct_builder.rs index 5ed464018401d..fd19dccf89636 100644 --- a/datafusion/common/src/scalar/struct_builder.rs +++ b/datafusion/common/src/scalar/struct_builder.rs @@ -17,7 +17,6 @@ //! [`ScalarStructBuilder`] for building [`ScalarValue::Struct`] -use crate::error::_internal_err; use crate::{Result, ScalarValue}; use arrow::array::{ArrayRef, StructArray}; use arrow::datatypes::{DataType, Field, FieldRef, Fields}; @@ -109,17 +108,8 @@ impl ScalarStructBuilder { pub fn build(self) -> Result { let Self { fields, arrays } = self; - for array in &arrays { - if array.len() != 1 { - return _internal_err!( - "Error building ScalarValue::Struct. \ - Expected array with exactly one element, found array with {} elements", - array.len() - ); - } - } - - let struct_array = StructArray::try_new(Fields::from(fields), arrays, None)?; + let struct_array = + StructArray::try_new_with_length(Fields::from(fields), arrays, None, 1)?; Ok(ScalarValue::Struct(Arc::new(struct_array))) } } @@ -181,3 +171,15 @@ impl IntoFields for Vec { Fields::from(self) } } + +#[cfg(test)] +mod tests { + use super::*; + + // Other cases are tested by doc tests + #[test] + fn test_empty_struct() { + let sv = ScalarStructBuilder::new().build().unwrap(); + assert_eq!(format!("{sv}"), "{}"); + } +} diff --git a/datafusion/common/src/spans.rs b/datafusion/common/src/spans.rs index 5111e264123ce..c0b52977e14a9 100644 --- a/datafusion/common/src/spans.rs +++ b/datafusion/common/src/spans.rs @@ -39,6 +39,7 @@ impl fmt::Debug for Location { } } +#[cfg(feature = "sql")] impl From for Location { fn from(value: sqlparser::tokenizer::Location) -> Self { Self { @@ -70,6 +71,7 @@ impl Span { /// Convert a [`Span`](sqlparser::tokenizer::Span) from the parser, into a /// DataFusion [`Span`]. If the input span is empty (line 0 column 0, to /// line 0 column 0), then [`None`] is returned. + #[cfg(feature = "sql")] pub fn try_from_sqlparser_span(span: sqlparser::tokenizer::Span) -> Option { if span == sqlparser::tokenizer::Span::empty() { None diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index 5b841db53c5ee..2481a88676efb 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -21,7 +21,8 @@ use std::fmt::{self, Debug, Display}; use crate::{Result, ScalarValue}; -use arrow::datatypes::{DataType, Schema, SchemaRef}; +use crate::error::_plan_err; +use arrow::datatypes::{DataType, Schema}; /// Represents a value with a degree of certainty. `Precision` is used to /// propagate information the precision of statistical values. @@ -119,10 +120,15 @@ impl Precision { /// values is [`Precision::Absent`], the result is `Absent` too. pub fn add(&self, other: &Precision) -> Precision { match (self, other) { - (Precision::Exact(a), Precision::Exact(b)) => Precision::Exact(a + b), + (Precision::Exact(a), Precision::Exact(b)) => a.checked_add(*b).map_or_else( + || Precision::Inexact(a.saturating_add(*b)), + Precision::Exact, + ), (Precision::Inexact(a), Precision::Exact(b)) | (Precision::Exact(a), Precision::Inexact(b)) - | (Precision::Inexact(a), Precision::Inexact(b)) => Precision::Inexact(a + b), + | (Precision::Inexact(a), Precision::Inexact(b)) => { + Precision::Inexact(a.saturating_add(*b)) + } (_, _) => Precision::Absent, } } @@ -132,10 +138,15 @@ impl Precision { /// values is [`Precision::Absent`], the result is `Absent` too. pub fn sub(&self, other: &Precision) -> Precision { match (self, other) { - (Precision::Exact(a), Precision::Exact(b)) => Precision::Exact(a - b), + (Precision::Exact(a), Precision::Exact(b)) => a.checked_sub(*b).map_or_else( + || Precision::Inexact(a.saturating_sub(*b)), + Precision::Exact, + ), (Precision::Inexact(a), Precision::Exact(b)) | (Precision::Exact(a), Precision::Inexact(b)) - | (Precision::Inexact(a), Precision::Inexact(b)) => Precision::Inexact(a - b), + | (Precision::Inexact(a), Precision::Inexact(b)) => { + Precision::Inexact(a.saturating_sub(*b)) + } (_, _) => Precision::Absent, } } @@ -145,10 +156,15 @@ impl Precision { /// values is [`Precision::Absent`], the result is `Absent` too. pub fn multiply(&self, other: &Precision) -> Precision { match (self, other) { - (Precision::Exact(a), Precision::Exact(b)) => Precision::Exact(a * b), + (Precision::Exact(a), Precision::Exact(b)) => a.checked_mul(*b).map_or_else( + || Precision::Inexact(a.saturating_mul(*b)), + Precision::Exact, + ), (Precision::Inexact(a), Precision::Exact(b)) | (Precision::Exact(a), Precision::Inexact(b)) - | (Precision::Inexact(a), Precision::Inexact(b)) => Precision::Inexact(a * b), + | (Precision::Inexact(a), Precision::Inexact(b)) => { + Precision::Inexact(a.saturating_mul(*b)) + } (_, _) => Precision::Absent, } } @@ -232,8 +248,8 @@ impl Precision { impl Debug for Precision { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Precision::Exact(inner) => write!(f, "Exact({:?})", inner), - Precision::Inexact(inner) => write!(f, "Inexact({:?})", inner), + Precision::Exact(inner) => write!(f, "Exact({inner:?})"), + Precision::Inexact(inner) => write!(f, "Inexact({inner:?})"), Precision::Absent => write!(f, "Absent"), } } @@ -242,8 +258,8 @@ impl Debug for Precision { impl Display for Precision { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Precision::Exact(inner) => write!(f, "Exact({:?})", inner), - Precision::Inexact(inner) => write!(f, "Inexact({:?})", inner), + Precision::Exact(inner) => write!(f, "Exact({inner:?})"), + Precision::Inexact(inner) => write!(f, "Inexact({inner:?})"), Precision::Absent => write!(f, "Absent"), } } @@ -271,11 +287,25 @@ pub struct Statistics { pub num_rows: Precision, /// Total bytes of the table rows. pub total_byte_size: Precision, - /// Statistics on a column level. It contains a [`ColumnStatistics`] for - /// each field in the schema of the table to which the [`Statistics`] refer. + /// Statistics on a column level. + /// + /// It must contains a [`ColumnStatistics`] for each field in the schema of + /// the table to which the [`Statistics`] refer. pub column_statistics: Vec, } +impl Default for Statistics { + /// Returns a new [`Statistics`] instance with all fields set to unknown + /// and no columns. + fn default() -> Self { + Self { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![], + } + } +} + impl Statistics { /// Returns a [`Statistics`] instance for the given schema by assigning /// unknown statistics to each column in the schema. @@ -296,6 +326,24 @@ impl Statistics { .collect() } + /// Set the number of rows + pub fn with_num_rows(mut self, num_rows: Precision) -> Self { + self.num_rows = num_rows; + self + } + + /// Set the total size, in bytes + pub fn with_total_byte_size(mut self, total_byte_size: Precision) -> Self { + self.total_byte_size = total_byte_size; + self + } + + /// Add a column to the column statistics + pub fn add_column_statistics(mut self, column_stats: ColumnStatistics) -> Self { + self.column_statistics.push(column_stats); + self + } + /// If the exactness of a [`Statistics`] instance is lost, this function relaxes /// the exactness of all information by converting them [`Precision::Inexact`]. pub fn to_inexact(mut self) -> Self { @@ -319,6 +367,7 @@ impl Statistics { return self; }; + #[allow(clippy::large_enum_variant)] enum Slot { /// The column is taken and put into the specified statistics location Taken(usize), @@ -351,18 +400,21 @@ impl Statistics { self } - /// Calculates the statistics after `fetch` and `skip` operations apply. + /// Calculates the statistics after applying `fetch` and `skip` operations. + /// /// Here, `self` denotes per-partition statistics. Use the `n_partitions` /// parameter to compute global statistics in a multi-partition setting. pub fn with_fetch( mut self, - schema: SchemaRef, fetch: Option, skip: usize, n_partitions: usize, ) -> Result { let fetch_val = fetch.unwrap_or(usize::MAX); + // Get the ratio of rows after / rows before on a per-partition basis + let num_rows_before = self.num_rows; + self.num_rows = match self { Statistics { num_rows: Precision::Exact(nr), @@ -396,8 +448,7 @@ impl Statistics { // At this point we know that we were given a `fetch` value // as the `None` case would go into the branch above. Since // the input has more rows than `fetch + skip`, the number - // of rows will be the `fetch`, but we won't be able to - // predict the other statistics. + // of rows will be the `fetch`, other statistics will have to be downgraded to inexact. check_num_rows( fetch_val.checked_mul(n_partitions), // We know that we have an estimate for the number of rows: @@ -410,10 +461,132 @@ impl Statistics { .. } => check_num_rows(fetch.and_then(|v| v.checked_mul(n_partitions)), false), }; - self.column_statistics = Statistics::unknown_column(&schema); - self.total_byte_size = Precision::Absent; + let ratio: f64 = match (num_rows_before, self.num_rows) { + ( + Precision::Exact(nr_before) | Precision::Inexact(nr_before), + Precision::Exact(nr_after) | Precision::Inexact(nr_after), + ) => { + if nr_before == 0 { + 0.0 + } else { + nr_after as f64 / nr_before as f64 + } + } + _ => 0.0, + }; + self.column_statistics = self + .column_statistics + .into_iter() + .map(ColumnStatistics::to_inexact) + .collect(); + // Adjust the total_byte_size for the ratio of rows before and after, also marking it as inexact + self.total_byte_size = match &self.total_byte_size { + Precision::Exact(n) | Precision::Inexact(n) => { + let adjusted = (*n as f64 * ratio) as usize; + Precision::Inexact(adjusted) + } + Precision::Absent => Precision::Absent, + }; Ok(self) } + + /// Summarize zero or more statistics into a single `Statistics` instance. + /// + /// The method assumes that all statistics are for the same schema. + /// If not, maybe you can call `SchemaMapper::map_column_statistics` to make them consistent. + /// + /// Returns an error if the statistics do not match the specified schemas. + pub fn try_merge_iter<'a, I>(items: I, schema: &Schema) -> Result + where + I: IntoIterator, + { + let mut items = items.into_iter(); + + let Some(init) = items.next() else { + return Ok(Statistics::new_unknown(schema)); + }; + items.try_fold(init.clone(), |acc: Statistics, item_stats: &Statistics| { + acc.try_merge(item_stats) + }) + } + + /// Merge this Statistics value with another Statistics value. + /// + /// Returns an error if the statistics do not match (different schemas). + /// + /// # Example + /// ``` + /// # use datafusion_common::{ColumnStatistics, ScalarValue, Statistics}; + /// # use arrow::datatypes::{Field, Schema, DataType}; + /// # use datafusion_common::stats::Precision; + /// let stats1 = Statistics::default() + /// .with_num_rows(Precision::Exact(1)) + /// .with_total_byte_size(Precision::Exact(2)) + /// .add_column_statistics(ColumnStatistics::new_unknown() + /// .with_null_count(Precision::Exact(3)) + /// .with_min_value(Precision::Exact(ScalarValue::from(4))) + /// .with_max_value(Precision::Exact(ScalarValue::from(5))) + /// ); + /// + /// let stats2 = Statistics::default() + /// .with_num_rows(Precision::Exact(10)) + /// .with_total_byte_size(Precision::Inexact(20)) + /// .add_column_statistics(ColumnStatistics::new_unknown() + /// // absent null count + /// .with_min_value(Precision::Exact(ScalarValue::from(40))) + /// .with_max_value(Precision::Exact(ScalarValue::from(50))) + /// ); + /// + /// let merged_stats = stats1.try_merge(&stats2).unwrap(); + /// let expected_stats = Statistics::default() + /// .with_num_rows(Precision::Exact(11)) + /// .with_total_byte_size(Precision::Inexact(22)) // inexact in stats2 --> inexact + /// .add_column_statistics( + /// ColumnStatistics::new_unknown() + /// .with_null_count(Precision::Absent) // missing from stats2 --> absent + /// .with_min_value(Precision::Exact(ScalarValue::from(4))) + /// .with_max_value(Precision::Exact(ScalarValue::from(50))) + /// ); + /// + /// assert_eq!(merged_stats, expected_stats) + /// ``` + pub fn try_merge(self, other: &Statistics) -> Result { + let Self { + mut num_rows, + mut total_byte_size, + mut column_statistics, + } = self; + + // Accumulate statistics for subsequent items + num_rows = num_rows.add(&other.num_rows); + total_byte_size = total_byte_size.add(&other.total_byte_size); + + if column_statistics.len() != other.column_statistics.len() { + return _plan_err!( + "Cannot merge statistics with different number of columns: {} vs {}", + column_statistics.len(), + other.column_statistics.len() + ); + } + + for (item_col_stats, col_stats) in other + .column_statistics + .iter() + .zip(column_statistics.iter_mut()) + { + col_stats.null_count = col_stats.null_count.add(&item_col_stats.null_count); + col_stats.max_value = col_stats.max_value.max(&item_col_stats.max_value); + col_stats.min_value = col_stats.min_value.min(&item_col_stats.min_value); + col_stats.sum_value = col_stats.sum_value.add(&item_col_stats.sum_value); + col_stats.distinct_count = Precision::Absent; + } + + Ok(Statistics { + num_rows, + total_byte_size, + column_statistics, + }) + } } /// Creates an estimate of the number of rows in the output using the given @@ -441,7 +614,7 @@ impl Display for Statistics { .iter() .enumerate() .map(|(i, cs)| { - let s = format!("(Col[{}]:", i); + let s = format!("(Col[{i}]:"); let s = if cs.min_value != Precision::Absent { format!("{} Min={}", s, cs.min_value) } else { @@ -521,6 +694,36 @@ impl ColumnStatistics { } } + /// Set the null count + pub fn with_null_count(mut self, null_count: Precision) -> Self { + self.null_count = null_count; + self + } + + /// Set the max value + pub fn with_max_value(mut self, max_value: Precision) -> Self { + self.max_value = max_value; + self + } + + /// Set the min value + pub fn with_min_value(mut self, min_value: Precision) -> Self { + self.min_value = min_value; + self + } + + /// Set the sum value + pub fn with_sum_value(mut self, sum_value: Precision) -> Self { + self.sum_value = sum_value; + self + } + + /// Set the distinct count + pub fn with_distinct_count(mut self, distinct_count: Precision) -> Self { + self.distinct_count = distinct_count; + self + } + /// If the exactness of a [`ColumnStatistics`] instance is lost, this /// function relaxes the exactness of all information by converting them /// [`Precision::Inexact`]. @@ -537,6 +740,9 @@ impl ColumnStatistics { #[cfg(test)] mod tests { use super::*; + use crate::assert_contains; + use arrow::datatypes::Field; + use std::sync::Arc; #[test] fn test_get_value() { @@ -616,11 +822,21 @@ mod tests { let precision2 = Precision::Inexact(23); let precision3 = Precision::Exact(30); let absent_precision = Precision::Absent; + let precision_max_exact = Precision::Exact(usize::MAX); + let precision_max_inexact = Precision::Exact(usize::MAX); assert_eq!(precision1.add(&precision2), Precision::Inexact(65)); assert_eq!(precision1.add(&precision3), Precision::Exact(72)); assert_eq!(precision2.add(&precision3), Precision::Inexact(53)); assert_eq!(precision1.add(&absent_precision), Precision::Absent); + assert_eq!( + precision_max_exact.add(&precision1), + Precision::Inexact(usize::MAX) + ); + assert_eq!( + precision_max_inexact.add(&precision1), + Precision::Inexact(usize::MAX) + ); } #[test] @@ -652,6 +868,8 @@ mod tests { assert_eq!(precision1.sub(&precision2), Precision::Inexact(19)); assert_eq!(precision1.sub(&precision3), Precision::Exact(12)); + assert_eq!(precision2.sub(&precision1), Precision::Inexact(0)); + assert_eq!(precision3.sub(&precision1), Precision::Inexact(0)); assert_eq!(precision1.sub(&absent_precision), Precision::Absent); } @@ -680,12 +898,22 @@ mod tests { let precision1 = Precision::Exact(6); let precision2 = Precision::Inexact(3); let precision3 = Precision::Exact(5); + let precision_max_exact = Precision::Exact(usize::MAX); + let precision_max_inexact = Precision::Exact(usize::MAX); let absent_precision = Precision::Absent; assert_eq!(precision1.multiply(&precision2), Precision::Inexact(18)); assert_eq!(precision1.multiply(&precision3), Precision::Exact(30)); assert_eq!(precision2.multiply(&precision3), Precision::Inexact(15)); assert_eq!(precision1.multiply(&absent_precision), Precision::Absent); + assert_eq!( + precision_max_exact.multiply(&precision1), + Precision::Inexact(usize::MAX) + ); + assert_eq!( + precision_max_inexact.multiply(&precision1), + Precision::Inexact(usize::MAX) + ); } #[test] @@ -798,4 +1026,500 @@ mod tests { distinct_count: Precision::Exact(100), } } + + #[test] + fn test_try_merge_basic() { + // Create a schema with two columns + let schema = Arc::new(Schema::new(vec![ + Field::new("col1", DataType::Int32, false), + Field::new("col2", DataType::Int32, false), + ])); + + // Create items with statistics + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), + distinct_count: Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::Int32(Some(200))), + min_value: Precision::Exact(ScalarValue::Int32(Some(10))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(1000))), + distinct_count: Precision::Absent, + }, + ], + }; + + let stats2 = Statistics { + num_rows: Precision::Exact(15), + total_byte_size: Precision::Exact(150), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::Int32(Some(120))), + min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(600))), + distinct_count: Precision::Absent, + }, + ColumnStatistics { + null_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int32(Some(180))), + min_value: Precision::Exact(ScalarValue::Int32(Some(5))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(1200))), + distinct_count: Precision::Absent, + }, + ], + }; + + let items = vec![stats1, stats2]; + + let summary_stats = Statistics::try_merge_iter(&items, &schema).unwrap(); + + // Verify the results + assert_eq!(summary_stats.num_rows, Precision::Exact(25)); // 10 + 15 + assert_eq!(summary_stats.total_byte_size, Precision::Exact(250)); // 100 + 150 + + // Verify column statistics + let col1_stats = &summary_stats.column_statistics[0]; + assert_eq!(col1_stats.null_count, Precision::Exact(3)); // 1 + 2 + assert_eq!( + col1_stats.max_value, + Precision::Exact(ScalarValue::Int32(Some(120))) + ); + assert_eq!( + col1_stats.min_value, + Precision::Exact(ScalarValue::Int32(Some(-10))) + ); + assert_eq!( + col1_stats.sum_value, + Precision::Exact(ScalarValue::Int32(Some(1100))) + ); // 500 + 600 + + let col2_stats = &summary_stats.column_statistics[1]; + assert_eq!(col2_stats.null_count, Precision::Exact(5)); // 2 + 3 + assert_eq!( + col2_stats.max_value, + Precision::Exact(ScalarValue::Int32(Some(200))) + ); + assert_eq!( + col2_stats.min_value, + Precision::Exact(ScalarValue::Int32(Some(5))) + ); + assert_eq!( + col2_stats.sum_value, + Precision::Exact(ScalarValue::Int32(Some(2200))) + ); // 1000 + 1200 + } + + #[test] + fn test_try_merge_mixed_precision() { + // Create a schema with one column + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + // Create items with different precision levels + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Inexact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), + distinct_count: Precision::Absent, + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Inexact(15), + total_byte_size: Precision::Exact(150), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(2), + max_value: Precision::Inexact(ScalarValue::Int32(Some(120))), + min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }], + }; + + let items = vec![stats1, stats2]; + + let summary_stats = Statistics::try_merge_iter(&items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Inexact(25)); + assert_eq!(summary_stats.total_byte_size, Precision::Inexact(250)); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!(col_stats.null_count, Precision::Inexact(3)); + assert_eq!( + col_stats.max_value, + Precision::Inexact(ScalarValue::Int32(Some(120))) + ); + assert_eq!( + col_stats.min_value, + Precision::Inexact(ScalarValue::Int32(Some(-10))) + ); + assert!(matches!(col_stats.sum_value, Precision::Absent)); + } + + #[test] + fn test_try_merge_empty() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + // Empty collection + let items: Vec = vec![]; + + let summary_stats = Statistics::try_merge_iter(&items, &schema).unwrap(); + + // Verify default values for empty collection + assert_eq!(summary_stats.num_rows, Precision::Absent); + assert_eq!(summary_stats.total_byte_size, Precision::Absent); + assert_eq!(summary_stats.column_statistics.len(), 1); + assert_eq!( + summary_stats.column_statistics[0].null_count, + Precision::Absent + ); + } + + #[test] + fn test_try_merge_mismatched_size() { + // Create a schema with one column + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + // No column statistics + let stats1 = Statistics::default(); + + let stats2 = + Statistics::default().add_column_statistics(ColumnStatistics::new_unknown()); + + let items = vec![stats1, stats2]; + + let e = Statistics::try_merge_iter(&items, &schema).unwrap_err(); + assert_contains!(e.to_string(), "Error during planning: Cannot merge statistics with different number of columns: 0 vs 1"); + } + + #[test] + fn test_try_merge_distinct_count_absent() { + // Create statistics with known distinct counts + let stats1 = Statistics::default() + .with_num_rows(Precision::Exact(10)) + .with_total_byte_size(Precision::Exact(100)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_null_count(Precision::Exact(0)) + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(1)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(10)))) + .with_distinct_count(Precision::Exact(5)), + ); + + let stats2 = Statistics::default() + .with_num_rows(Precision::Exact(15)) + .with_total_byte_size(Precision::Exact(150)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_null_count(Precision::Exact(0)) + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(5)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(20)))) + .with_distinct_count(Precision::Exact(7)), + ); + + // Merge statistics + let merged_stats = stats1.try_merge(&stats2).unwrap(); + + // Verify the results + assert_eq!(merged_stats.num_rows, Precision::Exact(25)); + assert_eq!(merged_stats.total_byte_size, Precision::Exact(250)); + + let col_stats = &merged_stats.column_statistics[0]; + assert_eq!(col_stats.null_count, Precision::Exact(0)); + assert_eq!( + col_stats.min_value, + Precision::Exact(ScalarValue::Int32(Some(1))) + ); + assert_eq!( + col_stats.max_value, + Precision::Exact(ScalarValue::Int32(Some(20))) + ); + // Distinct count should be Absent after merge + assert_eq!(col_stats.distinct_count, Precision::Absent); + } + + #[test] + fn test_with_fetch_basic_preservation() { + // Test that column statistics and byte size are preserved (as inexact) when applying fetch + let original_stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(10), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + min_value: Precision::Exact(ScalarValue::Int32(Some(0))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(5050))), + distinct_count: Precision::Exact(50), + }, + ColumnStatistics { + null_count: Precision::Exact(20), + max_value: Precision::Exact(ScalarValue::Int64(Some(200))), + min_value: Precision::Exact(ScalarValue::Int64(Some(10))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(10100))), + distinct_count: Precision::Exact(75), + }, + ], + }; + + // Apply fetch of 100 rows (10% of original) + let result = original_stats.clone().with_fetch(Some(100), 0, 1).unwrap(); + + // Check num_rows + assert_eq!(result.num_rows, Precision::Exact(100)); + + // Check total_byte_size is scaled proportionally and marked as inexact + // 100/1000 = 0.1, so 8000 * 0.1 = 800 + assert_eq!(result.total_byte_size, Precision::Inexact(800)); + + // Check column statistics are preserved but marked as inexact + assert_eq!(result.column_statistics.len(), 2); + + // First column + assert_eq!( + result.column_statistics[0].null_count, + Precision::Inexact(10) + ); + assert_eq!( + result.column_statistics[0].max_value, + Precision::Inexact(ScalarValue::Int32(Some(100))) + ); + assert_eq!( + result.column_statistics[0].min_value, + Precision::Inexact(ScalarValue::Int32(Some(0))) + ); + assert_eq!( + result.column_statistics[0].sum_value, + Precision::Inexact(ScalarValue::Int32(Some(5050))) + ); + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Inexact(50) + ); + + // Second column + assert_eq!( + result.column_statistics[1].null_count, + Precision::Inexact(20) + ); + assert_eq!( + result.column_statistics[1].max_value, + Precision::Inexact(ScalarValue::Int64(Some(200))) + ); + assert_eq!( + result.column_statistics[1].min_value, + Precision::Inexact(ScalarValue::Int64(Some(10))) + ); + assert_eq!( + result.column_statistics[1].sum_value, + Precision::Inexact(ScalarValue::Int64(Some(10100))) + ); + assert_eq!( + result.column_statistics[1].distinct_count, + Precision::Inexact(75) + ); + } + + #[test] + fn test_with_fetch_inexact_input() { + // Test that inexact input statistics remain inexact + let original_stats = Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(8000), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(10), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(0))), + sum_value: Precision::Inexact(ScalarValue::Int32(Some(5050))), + distinct_count: Precision::Inexact(50), + }], + }; + + let result = original_stats.clone().with_fetch(Some(500), 0, 1).unwrap(); + + // Check num_rows is inexact + assert_eq!(result.num_rows, Precision::Inexact(500)); + + // Check total_byte_size is scaled and inexact + // 500/1000 = 0.5, so 8000 * 0.5 = 4000 + assert_eq!(result.total_byte_size, Precision::Inexact(4000)); + + // Column stats remain inexact + assert_eq!( + result.column_statistics[0].null_count, + Precision::Inexact(10) + ); + } + + #[test] + fn test_with_fetch_skip_all_rows() { + // Test when skip >= num_rows (all rows are skipped) + let original_stats = Statistics { + num_rows: Precision::Exact(100), + total_byte_size: Precision::Exact(800), + column_statistics: vec![col_stats_i64(10)], + }; + + let result = original_stats.clone().with_fetch(Some(50), 100, 1).unwrap(); + + assert_eq!(result.num_rows, Precision::Exact(0)); + // When ratio is 0/100 = 0, byte size should be 0 + assert_eq!(result.total_byte_size, Precision::Inexact(0)); + } + + #[test] + fn test_with_fetch_no_limit() { + // Test when fetch is None and skip is 0 (no limit applied) + let original_stats = Statistics { + num_rows: Precision::Exact(100), + total_byte_size: Precision::Exact(800), + column_statistics: vec![col_stats_i64(10)], + }; + + let result = original_stats.clone().with_fetch(None, 0, 1).unwrap(); + + // Stats should be unchanged when no fetch and no skip + assert_eq!(result.num_rows, Precision::Exact(100)); + assert_eq!(result.total_byte_size, Precision::Exact(800)); + } + + #[test] + fn test_with_fetch_with_skip() { + // Test with both skip and fetch + let original_stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![col_stats_i64(10)], + }; + + // Skip 200, fetch 300, so we get rows 200-500 + let result = original_stats + .clone() + .with_fetch(Some(300), 200, 1) + .unwrap(); + + assert_eq!(result.num_rows, Precision::Exact(300)); + // 300/1000 = 0.3, so 8000 * 0.3 = 2400 + assert_eq!(result.total_byte_size, Precision::Inexact(2400)); + } + + #[test] + fn test_with_fetch_multi_partition() { + // Test with multiple partitions + let original_stats = Statistics { + num_rows: Precision::Exact(1000), // per partition + total_byte_size: Precision::Exact(8000), + column_statistics: vec![col_stats_i64(10)], + }; + + // Fetch 100 per partition, 4 partitions = 400 total + let result = original_stats.clone().with_fetch(Some(100), 0, 4).unwrap(); + + assert_eq!(result.num_rows, Precision::Exact(400)); + // 400/1000 = 0.4, so 8000 * 0.4 = 3200 + assert_eq!(result.total_byte_size, Precision::Inexact(3200)); + } + + #[test] + fn test_with_fetch_absent_stats() { + // Test with absent statistics + let original_stats = Statistics { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }], + }; + + let result = original_stats.clone().with_fetch(Some(100), 0, 1).unwrap(); + + // With absent input stats, output should be inexact estimate + assert_eq!(result.num_rows, Precision::Inexact(100)); + assert_eq!(result.total_byte_size, Precision::Absent); + // Column stats should remain absent + assert_eq!(result.column_statistics[0].null_count, Precision::Absent); + } + + #[test] + fn test_with_fetch_fetch_exceeds_rows() { + // Test when fetch is larger than available rows after skip + let original_stats = Statistics { + num_rows: Precision::Exact(100), + total_byte_size: Precision::Exact(800), + column_statistics: vec![col_stats_i64(10)], + }; + + // Skip 50, fetch 100, but only 50 rows remain + let result = original_stats.clone().with_fetch(Some(100), 50, 1).unwrap(); + + assert_eq!(result.num_rows, Precision::Exact(50)); + // 50/100 = 0.5, so 800 * 0.5 = 400 + assert_eq!(result.total_byte_size, Precision::Inexact(400)); + } + + #[test] + fn test_with_fetch_preserves_all_column_stats() { + // Comprehensive test that all column statistic fields are preserved + let original_col_stats = ColumnStatistics { + null_count: Precision::Exact(42), + max_value: Precision::Exact(ScalarValue::Int32(Some(999))), + min_value: Precision::Exact(ScalarValue::Int32(Some(-100))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(123456))), + distinct_count: Precision::Exact(789), + }; + + let original_stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![original_col_stats.clone()], + }; + + let result = original_stats.with_fetch(Some(250), 0, 1).unwrap(); + + let result_col_stats = &result.column_statistics[0]; + + // All values should be preserved but marked as inexact + assert_eq!(result_col_stats.null_count, Precision::Inexact(42)); + assert_eq!( + result_col_stats.max_value, + Precision::Inexact(ScalarValue::Int32(Some(999))) + ); + assert_eq!( + result_col_stats.min_value, + Precision::Inexact(ScalarValue::Int32(Some(-100))) + ); + assert_eq!( + result_col_stats.sum_value, + Precision::Inexact(ScalarValue::Int32(Some(123456))) + ); + assert_eq!(result_col_stats.distinct_count, Precision::Inexact(789)); + } } diff --git a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs index 9b6f9696c00bb..7cf8e7af1a794 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::{parse_identifiers_normalized, quote_identifier}; +use crate::utils::parse_identifiers_normalized; +use crate::utils::quote_identifier; use std::sync::Arc; /// A fully resolved path to a table of the form "catalog.schema.table" @@ -367,26 +368,32 @@ mod tests { let actual = TableReference::from("TABLE"); assert_eq!(expected, actual); - // if fail to parse, take entire input string as identifier - let expected = TableReference::Bare { - table: "TABLE()".into(), - }; - let actual = TableReference::from("TABLE()"); - assert_eq!(expected, actual); + // Disable this test for non-sql features so that we don't need to reproduce + // things like table function upper case conventions, since those will not + // be used if SQL is not selected. + #[cfg(feature = "sql")] + { + // if fail to parse, take entire input string as identifier + let expected = TableReference::Bare { + table: "TABLE()".into(), + }; + let actual = TableReference::from("TABLE()"); + assert_eq!(expected, actual); + } } #[test] fn test_table_reference_to_vector() { - let table_reference = TableReference::parse_str("table"); + let table_reference = TableReference::from("table"); assert_eq!(vec!["table".to_string()], table_reference.to_vec()); - let table_reference = TableReference::parse_str("schema.table"); + let table_reference = TableReference::from("schema.table"); assert_eq!( vec!["schema".to_string(), "table".to_string()], table_reference.to_vec() ); - let table_reference = TableReference::parse_str("catalog.schema.table"); + let table_reference = TableReference::from("catalog.schema.table"); assert_eq!( vec![ "catalog".to_string(), diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index b801c452af2c9..d97d4003e7292 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -18,10 +18,25 @@ //! Utility functions to make testing DataFusion based crates easier use crate::arrow::util::pretty::pretty_format_batches_with_options; -use crate::format::DEFAULT_FORMAT_OPTIONS; -use arrow::array::RecordBatch; +use arrow::array::{ArrayRef, RecordBatch}; +use arrow::error::ArrowError; +use std::fmt::Display; use std::{error::Error, path::PathBuf}; +/// Converts a vector or array into an ArrayRef. +pub trait IntoArrayRef { + fn into_array_ref(self) -> ArrayRef; +} + +pub fn format_batches(results: &[RecordBatch]) -> Result { + let datafusion_format_options = crate::config::FormatOptions::default(); + + let arrow_format_options: arrow::util::display::FormatOptions = + (&datafusion_format_options).try_into().unwrap(); + + pretty_format_batches_with_options(results, &arrow_format_options) +} + /// Compares formatted output of a record batch with an expected /// vector of strings, with the result of pretty formatting record /// batches. This is a macro so errors appear on the correct line @@ -59,12 +74,9 @@ macro_rules! assert_batches_eq { let expected_lines: Vec = $EXPECTED_LINES.iter().map(|&s| s.into()).collect(); - let formatted = $crate::arrow::util::pretty::pretty_format_batches_with_options( - $CHUNKS, - &$crate::format::DEFAULT_FORMAT_OPTIONS, - ) - .unwrap() - .to_string(); + let formatted = $crate::test_util::format_batches($CHUNKS) + .unwrap() + .to_string(); let actual_lines: Vec<&str> = formatted.trim().lines().collect(); @@ -77,18 +89,13 @@ macro_rules! assert_batches_eq { } pub fn batches_to_string(batches: &[RecordBatch]) -> String { - let actual = pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS) - .unwrap() - .to_string(); + let actual = format_batches(batches).unwrap().to_string(); actual.trim().to_string() } pub fn batches_to_sort_string(batches: &[RecordBatch]) -> String { - let actual_lines = - pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS) - .unwrap() - .to_string(); + let actual_lines = format_batches(batches).unwrap().to_string(); let mut actual_lines: Vec<&str> = actual_lines.trim().lines().collect(); @@ -122,12 +129,9 @@ macro_rules! assert_batches_sorted_eq { expected_lines.as_mut_slice()[2..num_lines - 1].sort_unstable() } - let formatted = $crate::arrow::util::pretty::pretty_format_batches_with_options( - $CHUNKS, - &$crate::format::DEFAULT_FORMAT_OPTIONS, - ) - .unwrap() - .to_string(); + let formatted = $crate::test_util::format_batches($CHUNKS) + .unwrap() + .to_string(); // fix for windows: \r\n --> let mut actual_lines: Vec<&str> = formatted.trim().lines().collect(); @@ -154,7 +158,7 @@ macro_rules! assert_batches_sorted_eq { /// Is a macro so test error /// messages are on the same line as the failure; /// -/// Both arguments must be convertable into Strings ([`Into`]<[`String`]>) +/// Both arguments must be convertible into Strings ([`Into`]<[`String`]>) #[macro_export] macro_rules! assert_contains { ($ACTUAL: expr, $EXPECTED: expr) => { @@ -177,7 +181,7 @@ macro_rules! assert_contains { /// Is a macro so test error /// messages are on the same line as the failure; /// -/// Both arguments must be convertable into Strings ([`Into`]<[`String`]>) +/// Both arguments must be convertible into Strings ([`Into`]<[`String`]>) #[macro_export] macro_rules! assert_not_contains { ($ACTUAL: expr, $UNEXPECTED: expr) => { @@ -251,7 +255,14 @@ pub fn arrow_test_data() -> String { #[cfg(feature = "parquet")] pub fn parquet_test_data() -> String { match get_data_dir("PARQUET_TEST_DATA", "../../parquet-testing/data") { - Ok(pb) => pb.display().to_string(), + Ok(pb) => { + let mut path = pb.display().to_string(); + if cfg!(target_os = "windows") { + // Replace backslashes (Windows paths; avoids some test issues). + path = path.replace("\\", "/"); + } + path + } Err(err) => panic!("failed to get parquet data dir: {err}"), } } @@ -310,43 +321,43 @@ pub fn get_data_dir( #[macro_export] macro_rules! create_array { (Boolean, $values: expr) => { - std::sync::Arc::new(arrow::array::BooleanArray::from($values)) + std::sync::Arc::new($crate::arrow::array::BooleanArray::from($values)) }; (Int8, $values: expr) => { - std::sync::Arc::new(arrow::array::Int8Array::from($values)) + std::sync::Arc::new($crate::arrow::array::Int8Array::from($values)) }; (Int16, $values: expr) => { - std::sync::Arc::new(arrow::array::Int16Array::from($values)) + std::sync::Arc::new($crate::arrow::array::Int16Array::from($values)) }; (Int32, $values: expr) => { - std::sync::Arc::new(arrow::array::Int32Array::from($values)) + std::sync::Arc::new($crate::arrow::array::Int32Array::from($values)) }; (Int64, $values: expr) => { - std::sync::Arc::new(arrow::array::Int64Array::from($values)) + std::sync::Arc::new($crate::arrow::array::Int64Array::from($values)) }; (UInt8, $values: expr) => { - std::sync::Arc::new(arrow::array::UInt8Array::from($values)) + std::sync::Arc::new($crate::arrow::array::UInt8Array::from($values)) }; (UInt16, $values: expr) => { - std::sync::Arc::new(arrow::array::UInt16Array::from($values)) + std::sync::Arc::new($crate::arrow::array::UInt16Array::from($values)) }; (UInt32, $values: expr) => { - std::sync::Arc::new(arrow::array::UInt32Array::from($values)) + std::sync::Arc::new($crate::arrow::array::UInt32Array::from($values)) }; (UInt64, $values: expr) => { - std::sync::Arc::new(arrow::array::UInt64Array::from($values)) + std::sync::Arc::new($crate::arrow::array::UInt64Array::from($values)) }; (Float16, $values: expr) => { - std::sync::Arc::new(arrow::array::Float16Array::from($values)) + std::sync::Arc::new($crate::arrow::array::Float16Array::from($values)) }; (Float32, $values: expr) => { - std::sync::Arc::new(arrow::array::Float32Array::from($values)) + std::sync::Arc::new($crate::arrow::array::Float32Array::from($values)) }; (Float64, $values: expr) => { - std::sync::Arc::new(arrow::array::Float64Array::from($values)) + std::sync::Arc::new($crate::arrow::array::Float64Array::from($values)) }; (Utf8, $values: expr) => { - std::sync::Arc::new(arrow::array::StringArray::from($values)) + std::sync::Arc::new($crate::arrow::array::StringArray::from($values)) }; } @@ -355,7 +366,7 @@ macro_rules! create_array { /// /// Example: /// ``` -/// use datafusion_common::{record_batch, create_array}; +/// use datafusion_common::record_batch; /// let batch = record_batch!( /// ("a", Int32, vec![1, 2, 3]), /// ("b", Float64, vec![Some(4.0), None, Some(5.0)]), @@ -366,13 +377,13 @@ macro_rules! create_array { macro_rules! record_batch { ($(($name: expr, $type: ident, $values: expr)),*) => { { - let schema = std::sync::Arc::new(arrow::datatypes::Schema::new(vec![ + let schema = std::sync::Arc::new($crate::arrow::datatypes::Schema::new(vec![ $( - arrow::datatypes::Field::new($name, arrow::datatypes::DataType::$type, true), + $crate::arrow::datatypes::Field::new($name, $crate::arrow::datatypes::DataType::$type, true), )* ])); - let batch = arrow::array::RecordBatch::try_new( + let batch = $crate::arrow::array::RecordBatch::try_new( schema, vec![$( $crate::create_array!($type, $values), @@ -384,6 +395,326 @@ macro_rules! record_batch { } } +pub mod array_conversion { + use arrow::array::ArrayRef; + + use super::IntoArrayRef; + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Boolean, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Boolean, self) + } + } + + impl IntoArrayRef for &[bool] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Boolean, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Boolean, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int8, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int8, self) + } + } + + impl IntoArrayRef for &[i8] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int8, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int8, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int16, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int16, self) + } + } + + impl IntoArrayRef for &[i16] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int16, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int16, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int32, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int32, self) + } + } + + impl IntoArrayRef for &[i32] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int32, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int32, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int64, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int64, self) + } + } + + impl IntoArrayRef for &[i64] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int64, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int64, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt8, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt8, self) + } + } + + impl IntoArrayRef for &[u8] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt8, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt8, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt16, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt16, self) + } + } + + impl IntoArrayRef for &[u16] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt16, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt16, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt32, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt32, self) + } + } + + impl IntoArrayRef for &[u32] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt32, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt32, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt64, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt64, self) + } + } + + impl IntoArrayRef for &[u64] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt64, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt64, self.to_vec()) + } + } + + //#TODO add impl for f16 + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float32, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float32, self) + } + } + + impl IntoArrayRef for &[f32] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float32, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float32, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float64, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float64, self) + } + } + + impl IntoArrayRef for &[f64] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float64, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float64, self.to_vec()) + } + } + + impl IntoArrayRef for Vec<&str> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self) + } + } + + impl IntoArrayRef for &[&str] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option<&str>] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self) + } + } + + impl IntoArrayRef for &[String] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self.to_vec()) + } + } +} + #[cfg(test)] mod tests { use crate::cast::{as_float64_array, as_int32_array, as_string_array}; diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index c70389b631773..ea0aa28c938d5 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -680,6 +680,11 @@ impl Transformed { Self::new(data, true, TreeNodeRecursion::Continue) } + /// Wrapper for transformed data with [`TreeNodeRecursion::Stop`] statement. + pub fn complete(data: T) -> Self { + Self::new(data, true, TreeNodeRecursion::Stop) + } + /// Wrapper for unchanged data with [`TreeNodeRecursion::Continue`] statement. pub fn no(data: T) -> Self { Self::new(data, false, TreeNodeRecursion::Continue) @@ -985,6 +990,48 @@ impl< } } +impl< + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + C3: TreeNodeContainer<'a, T>, + > TreeNodeContainer<'a, T> for (C0, C1, C2, C3) +{ + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f))? + .visit_sibling(|| self.2.apply_elements(&mut f))? + .visit_sibling(|| self.3.apply_elements(&mut f)) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + self.0 + .map_elements(&mut f)? + .map_data(|new_c0| Ok((new_c0, self.1, self.2, self.3)))? + .transform_sibling(|(new_c0, c1, c2, c3)| { + c1.map_elements(&mut f)? + .map_data(|new_c1| Ok((new_c0, new_c1, c2, c3))) + })? + .transform_sibling(|(new_c0, new_c1, c2, c3)| { + c2.map_elements(&mut f)? + .map_data(|new_c2| Ok((new_c0, new_c1, new_c2, c3))) + })? + .transform_sibling(|(new_c0, new_c1, new_c2, c3)| { + c3.map_elements(&mut f)? + .map_data(|new_c3| Ok((new_c0, new_c1, new_c2, new_c3))) + }) + } +} + /// [`TreeNodeRefContainer`] contains references to elements that a function can be /// applied on. The elements of the container are siblings so the continuation rules are /// similar to [`TreeNodeRecursion::visit_sibling`]. @@ -1060,6 +1107,27 @@ impl< } } +impl< + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + C3: TreeNodeContainer<'a, T>, + > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2, &'a C3) +{ + fn apply_ref_elements Result>( + &self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f))? + .visit_sibling(|| self.2.apply_elements(&mut f))? + .visit_sibling(|| self.3.apply_elements(&mut f)) + } +} + /// Transformation helper to process a sequence of iterable tree nodes that are siblings. pub trait TreeNodeIterator: Iterator { /// Apples `f` to each item in this iterator @@ -2354,7 +2422,7 @@ pub(crate) mod tests { fn test_large_tree() { let mut item = TestTreeNode::new_leaf("initial".to_string()); for i in 0..3000 { - item = TestTreeNode::new(vec![item], format!("parent-{}", i)); + item = TestTreeNode::new(vec![item], format!("parent-{i}")); } let mut visitor = diff --git a/datafusion/common/src/types/logical.rs b/datafusion/common/src/types/logical.rs index 884ce20fd9e29..eb7cf88e00753 100644 --- a/datafusion/common/src/types/logical.rs +++ b/datafusion/common/src/types/logical.rs @@ -106,6 +106,7 @@ impl std::fmt::Display for dyn LogicalType { impl PartialEq for dyn LogicalType { fn eq(&self, other: &Self) -> bool { + // Logical types with identical signatures are considered equal. self.signature().eq(&other.signature()) } } @@ -120,15 +121,14 @@ impl PartialOrd for dyn LogicalType { impl Ord for dyn LogicalType { fn cmp(&self, other: &Self) -> Ordering { - self.signature() - .cmp(&other.signature()) - .then(self.native().cmp(other.native())) + // Logical types with identical signatures are considered equal. + self.signature().cmp(&other.signature()) } } impl Hash for dyn LogicalType { fn hash(&self, state: &mut H) { + // Logical types with identical signatures are considered equal. self.signature().hash(state); - self.native().hash(state); } } diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index 39c79b4b99742..5cef0adfbde80 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -23,6 +23,7 @@ use crate::error::{Result, _internal_err}; use arrow::compute::can_cast_types; use arrow::datatypes::{ DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, + DECIMAL128_MAX_PRECISION, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, }; use std::{fmt::Display, sync::Arc}; @@ -185,7 +186,7 @@ pub enum NativeType { impl Display for NativeType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "NativeType::{self:?}") + write!(f, "{self:?}") // TODO: nicer formatting } } @@ -228,7 +229,15 @@ impl LogicalType for NativeType { (Self::Float16, _) => Float16, (Self::Float32, _) => Float32, (Self::Float64, _) => Float64, - (Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s), + (Self::Decimal(p, s), _) if *p <= DECIMAL32_MAX_PRECISION => { + Decimal32(*p, *s) + } + (Self::Decimal(p, s), _) if *p <= DECIMAL64_MAX_PRECISION => { + Decimal64(*p, *s) + } + (Self::Decimal(p, s), _) if *p <= DECIMAL128_MAX_PRECISION => { + Decimal128(*p, *s) + } (Self::Decimal(p, s), _) => Decimal256(*p, *s), (Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()), // If given type is Date, return the same type @@ -352,10 +361,10 @@ impl LogicalType for NativeType { } _ => { return _internal_err!( - "Unavailable default cast for native type {:?} from physical type {:?}", - self, - origin - ) + "Unavailable default cast for native type {} from physical type {}", + self, + origin + ) } }) } @@ -407,7 +416,10 @@ impl From for NativeType { DataType::Union(union_fields, _) => { Union(LogicalUnionFields::from(&union_fields)) } - DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(p, s), + DataType::Decimal32(p, s) + | DataType::Decimal64(p, s) + | DataType::Decimal128(p, s) + | DataType::Decimal256(p, s) => Decimal(p, s), DataType::Map(field, _) => Map(Arc::new(field.as_ref().into())), DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(), DataType::RunEndEncoded(_, field) => field.data_type().clone().into(), @@ -469,4 +481,9 @@ impl NativeType { pub fn is_duration(&self) -> bool { matches!(self, NativeType::Duration(_)) } + + #[inline] + pub fn is_binary(&self) -> bool { + matches!(self, NativeType::Binary | NativeType::FixedSizeBinary(_)) + } } diff --git a/datafusion/common/src/utils/memory.rs b/datafusion/common/src/utils/memory.rs index ab73996fcd8b7..29e523996cf4c 100644 --- a/datafusion/common/src/utils/memory.rs +++ b/datafusion/common/src/utils/memory.rs @@ -17,7 +17,8 @@ //! This module provides a function to estimate the memory size of a HashTable prior to allocation -use crate::{DataFusionError, Result}; +use crate::error::_exec_datafusion_err; +use crate::Result; use std::mem::size_of; /// Estimates the memory size required for a hash table prior to allocation. @@ -25,7 +26,7 @@ use std::mem::size_of; /// # Parameters /// - `num_elements`: The number of elements expected in the hash table. /// - `fixed_size`: A fixed overhead size associated with the collection -/// (e.g., HashSet or HashTable). +/// (e.g., HashSet or HashTable). /// - `T`: The type of elements stored in the hash table. /// /// # Details @@ -36,7 +37,7 @@ use std::mem::size_of; /// buckets. /// - One byte overhead for each bucket. /// - The fixed size overhead of the collection. -/// - If the estimation overflows, we return a [`DataFusionError`] +/// - If the estimation overflows, we return a [`crate::error::DataFusionError`] /// /// # Examples /// --- @@ -94,9 +95,7 @@ pub fn estimate_memory_size(num_elements: usize, fixed_size: usize) -> Result .checked_add(fixed_size) }) .ok_or_else(|| { - DataFusionError::Execution( - "usize overflow while estimating the number of buckets".to_string(), - ) + _exec_datafusion_err!("usize overflow while estimating the number of buckets") }) } diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 409f248621f7f..c72e3b3a8df74 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -23,7 +23,7 @@ pub mod proxy; pub mod string_utils; use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err}; -use crate::{DataFusionError, Result, ScalarValue}; +use crate::{Result, ScalarValue}; use arrow::array::{ cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, @@ -31,9 +31,8 @@ use arrow::array::{ use arrow::buffer::OffsetBuffer; use arrow::compute::{partition, SortColumn, SortOptions}; use arrow::datatypes::{DataType, Field, SchemaRef}; -use sqlparser::ast::Ident; -use sqlparser::dialect::GenericDialect; -use sqlparser::parser::Parser; +#[cfg(feature = "sql")] +use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; use std::borrow::{Borrow, Cow}; use std::cmp::{min, Ordering}; use std::collections::HashSet; @@ -120,14 +119,13 @@ pub fn compare_rows( let result = match (lhs.is_null(), rhs.is_null(), sort_options.nulls_first) { (true, false, false) | (false, true, true) => Ordering::Greater, (true, false, true) | (false, true, false) => Ordering::Less, - (false, false, _) => if sort_options.descending { - rhs.partial_cmp(lhs) - } else { - lhs.partial_cmp(rhs) + (false, false, _) => { + if sort_options.descending { + rhs.try_cmp(lhs)? + } else { + lhs.try_cmp(rhs)? + } } - .ok_or_else(|| { - _internal_datafusion_err!("Column array shouldn't be empty") - })?, (true, true, _) => continue, }; if result != Ordering::Equal { @@ -149,9 +147,7 @@ pub fn bisect( let low: usize = 0; let high: usize = item_columns .first() - .ok_or_else(|| { - DataFusionError::Internal("Column array shouldn't be empty".to_string()) - })? + .ok_or_else(|| _internal_datafusion_err!("Column array shouldn't be empty"))? .len(); let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { let cmp = compare_rows(current, target, sort_options)?; @@ -200,9 +196,7 @@ pub fn linear_search( let low: usize = 0; let high: usize = item_columns .first() - .ok_or_else(|| { - DataFusionError::Internal("Column array shouldn't be empty".to_string()) - })? + .ok_or_else(|| _internal_datafusion_err!("Column array shouldn't be empty"))? .len(); let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { let cmp = compare_rows(current, target, sort_options)?; @@ -261,7 +255,7 @@ pub fn evaluate_partition_ranges( /// the identifier by replacing it with two double quotes /// /// e.g. identifier `tab.le"name` becomes `"tab.le""name"` -pub fn quote_identifier(s: &str) -> Cow { +pub fn quote_identifier(s: &str) -> Cow<'_, str> { if needs_quotes(s) { Cow::Owned(format!("\"{}\"", s.replace('"', "\"\""))) } else { @@ -283,6 +277,7 @@ fn needs_quotes(s: &str) -> bool { !chars.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_') } +#[cfg(feature = "sql")] pub(crate) fn parse_identifiers(s: &str) -> Result> { let dialect = GenericDialect; let mut parser = Parser::new(&dialect).try_with_sql(s)?; @@ -290,6 +285,7 @@ pub(crate) fn parse_identifiers(s: &str) -> Result> { Ok(idents) } +#[cfg(feature = "sql")] pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec { parse_identifiers(s) .unwrap_or_default() @@ -302,6 +298,59 @@ pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec>() } +#[cfg(not(feature = "sql"))] +pub(crate) fn parse_identifiers(s: &str) -> Result> { + let mut result = Vec::new(); + let mut current = String::new(); + let mut in_quotes = false; + + for ch in s.chars() { + match ch { + '"' => { + in_quotes = !in_quotes; + current.push(ch); + } + '.' if !in_quotes => { + result.push(current.clone()); + current.clear(); + } + _ => { + current.push(ch); + } + } + } + + // Push the last part if it's not empty + if !current.is_empty() { + result.push(current); + } + + Ok(result) +} + +#[cfg(not(feature = "sql"))] +pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec { + parse_identifiers(s) + .unwrap_or_default() + .into_iter() + .map(|id| { + let is_double_quoted = if id.len() > 2 { + let mut chars = id.chars(); + chars.next() == Some('"') && chars.last() == Some('"') + } else { + false + }; + if is_double_quoted { + id[1..id.len() - 1].to_string().replace("\"\"", "\"") + } else if ignore_case { + id + } else { + id.to_ascii_lowercase() + } + }) + .collect::>() +} + /// This function "takes" the elements at `indices` from the slice `items`. pub fn get_at_indices>( items: &[T], @@ -312,9 +361,7 @@ pub fn get_at_indices>( .map(|idx| items.get(*idx.borrow()).cloned()) .collect::>>() .ok_or_else(|| { - DataFusionError::Execution( - "Expects indices to be in the range of searched vector".to_string(), - ) + _exec_datafusion_err!("Expects indices to be in the range of searched vector") }) } @@ -445,94 +492,6 @@ impl SingleRowListArrayBuilder { } } -/// Wrap an array into a single element `ListArray`. -/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` -/// The field in the list array is nullable. -#[deprecated( - since = "44.0.0", - note = "please use `SingleRowListArrayBuilder` instead" -)] -pub fn array_into_list_array_nullable(arr: ArrayRef) -> ListArray { - SingleRowListArrayBuilder::new(arr) - .with_nullable(true) - .build_list_array() -} - -/// Wrap an array into a single element `ListArray`. -/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` -#[deprecated( - since = "44.0.0", - note = "please use `SingleRowListArrayBuilder` instead" -)] -pub fn array_into_list_array(arr: ArrayRef, nullable: bool) -> ListArray { - SingleRowListArrayBuilder::new(arr) - .with_nullable(nullable) - .build_list_array() -} - -#[deprecated( - since = "44.0.0", - note = "please use `SingleRowListArrayBuilder` instead" -)] -pub fn array_into_list_array_with_field_name( - arr: ArrayRef, - nullable: bool, - field_name: &str, -) -> ListArray { - SingleRowListArrayBuilder::new(arr) - .with_nullable(nullable) - .with_field_name(Some(field_name.to_string())) - .build_list_array() -} - -/// Wrap an array into a single element `LargeListArray`. -/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` -#[deprecated( - since = "44.0.0", - note = "please use `SingleRowListArrayBuilder` instead" -)] -pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { - SingleRowListArrayBuilder::new(arr).build_large_list_array() -} - -#[deprecated( - since = "44.0.0", - note = "please use `SingleRowListArrayBuilder` instead" -)] -pub fn array_into_large_list_array_with_field_name( - arr: ArrayRef, - field_name: &str, -) -> LargeListArray { - SingleRowListArrayBuilder::new(arr) - .with_field_name(Some(field_name.to_string())) - .build_large_list_array() -} - -#[deprecated( - since = "44.0.0", - note = "please use `SingleRowListArrayBuilder` instead" -)] -pub fn array_into_fixed_size_list_array( - arr: ArrayRef, - list_size: usize, -) -> FixedSizeListArray { - SingleRowListArrayBuilder::new(arr).build_fixed_size_list_array(list_size) -} - -#[deprecated( - since = "44.0.0", - note = "please use `SingleRowListArrayBuilder` instead" -)] -pub fn array_into_fixed_size_list_array_with_field_name( - arr: ArrayRef, - list_size: usize, - field_name: &str, -) -> FixedSizeListArray { - SingleRowListArrayBuilder::new(arr) - .with_field_name(Some(field_name.to_string())) - .build_fixed_size_list_array(list_size) -} - /// Wrap arrays into a single element `ListArray`. /// /// Example: @@ -833,21 +792,6 @@ pub fn set_difference, S: Borrow>( .collect() } -/// Checks whether the given index sequence is monotonically non-decreasing. -#[deprecated(since = "45.0.0", note = "Use std::Iterator::is_sorted instead")] -pub fn is_sorted>(sequence: impl IntoIterator) -> bool { - // TODO: Remove this function when `is_sorted` graduates from Rust nightly. - let mut previous = 0; - for item in sequence.into_iter() { - let current = *item.borrow(); - if current < previous { - return false; - } - previous = current; - } - true -} - /// Find indices of each element in `targets` inside `items`. If one of the /// elements is absent in `items`, returns an error. pub fn find_indices>( @@ -858,7 +802,7 @@ pub fn find_indices>( .into_iter() .map(|target| items.iter().position(|e| target.borrow().eq(e))) .collect::>() - .ok_or_else(|| DataFusionError::Execution("Target not found".to_string())) + .ok_or_else(|| _exec_datafusion_err!("Target not found")) } /// Transposes the given vector of vectors. @@ -950,7 +894,7 @@ pub fn get_available_parallelism() -> usize { .get() } -/// Converts a collection of function arguments into an fixed-size array of length N +/// Converts a collection of function arguments into a fixed-size array of length N /// producing a reasonable error message in case of unexpected number of arguments. /// /// # Example @@ -994,6 +938,7 @@ mod tests { use super::*; use crate::ScalarValue::Null; use arrow::array::Float64Array; + use sqlparser::ast::Ident; use sqlparser::tokenizer::Span; #[test] @@ -1190,6 +1135,7 @@ mod tests { Ok(()) } + #[cfg(feature = "sql")] #[test] fn test_quote_identifier() -> Result<()> { let cases = vec![ @@ -1275,19 +1221,6 @@ mod tests { assert_eq!(set_difference([3, 4, 0], [4, 1, 2]), vec![3, 0]); } - #[test] - #[expect(deprecated)] - fn test_is_sorted() { - assert!(is_sorted::([])); - assert!(is_sorted([0])); - assert!(is_sorted([0, 3, 4])); - assert!(is_sorted([0, 1, 2])); - assert!(is_sorted([0, 1, 4])); - assert!(is_sorted([0usize; 0])); - assert!(is_sorted([1, 2])); - assert!(!is_sorted([3, 2])); - } - #[test] fn test_find_indices() -> Result<()> { assert_eq!(find_indices(&[0, 3, 4], [0, 3, 4])?, vec![0, 1, 2]); diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 56698e4d7e255..d3bc4546588de 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -47,6 +47,7 @@ compression = [ "bzip2", "flate2", "zstd", + "arrow-ipc/zstd", "datafusion-datasource/compression", ] crypto_expressions = ["datafusion-functions/crypto_expressions"] @@ -62,12 +63,19 @@ default = [ "compression", "parquet", "recursive_protection", + "sql", ] encoding_expressions = ["datafusion-functions/encoding_expressions"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = ["datafusion-physical-plan/force_hash_collisions", "datafusion-common/force_hash_collisions"] math_expressions = ["datafusion-functions/math_expressions"] parquet = ["datafusion-common/parquet", "dep:parquet", "datafusion-datasource-parquet"] +parquet_encryption = [ + "parquet", + "parquet/encryption", + "datafusion-common/parquet_encryption", + "datafusion-datasource-parquet/parquet_encryption", +] pyarrow = ["datafusion-common/pyarrow", "parquet"] regex_expressions = [ "datafusion-functions/regex_expressions", @@ -78,6 +86,7 @@ recursive_protection = [ "datafusion-optimizer/recursive_protection", "datafusion-physical-optimizer/recursive_protection", "datafusion-sql/recursive_protection", + "sqlparser/recursive-protection", ] serde = [ "dep:serde", @@ -85,9 +94,15 @@ serde = [ # statements in `arrow-schema` crate "arrow-schema/serde", ] +sql = [ + "datafusion-common/sql", + "datafusion-functions-nested?/sql", + "datafusion-sql", + "sqlparser", +] string_expressions = ["datafusion-functions/string_expressions"] unicode_expressions = [ - "datafusion-sql/unicode_expressions", + "datafusion-sql?/unicode_expressions", "datafusion-functions/unicode_expressions", ] extended_tests = [] @@ -95,10 +110,10 @@ extended_tests = [] [dependencies] arrow = { workspace = true } arrow-ipc = { workspace = true } -arrow-schema = { workspace = true } +arrow-schema = { workspace = true, features = ["canonical_extension_types"] } async-trait = { workspace = true } bytes = { workspace = true } -bzip2 = { version = "0.5.2", optional = true } +bzip2 = { version = "0.6.0", optional = true } chrono = { workspace = true } datafusion-catalog = { workspace = true } datafusion-catalog-listing = { workspace = true } @@ -110,22 +125,22 @@ datafusion-datasource-csv = { workspace = true } datafusion-datasource-json = { workspace = true } datafusion-datasource-parquet = { workspace = true, optional = true } datafusion-execution = { workspace = true } -datafusion-expr = { workspace = true } +datafusion-expr = { workspace = true, default-features = false } datafusion-expr-common = { workspace = true } datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } -datafusion-functions-nested = { workspace = true, optional = true } +datafusion-functions-nested = { workspace = true, default-features = false, optional = true } datafusion-functions-table = { workspace = true } datafusion-functions-window = { workspace = true } -datafusion-macros = { workspace = true } datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } +datafusion-physical-expr-adapter = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-optimizer = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } -datafusion-sql = { workspace = true } -flate2 = { version = "1.1.0", optional = true } +datafusion-sql = { workspace = true, optional = true } +flate2 = { version = "1.1.4", optional = true } futures = { workspace = true } itertools = { workspace = true } log = { workspace = true } @@ -134,12 +149,13 @@ parking_lot = { workspace = true } parquet = { workspace = true, optional = true, default-features = true } rand = { workspace = true } regex = { workspace = true } +rstest = { workspace = true } serde = { version = "1.0", default-features = false, features = ["derive"], optional = true } -sqlparser = { workspace = true } +sqlparser = { workspace = true, optional = true } tempfile = { workspace = true } tokio = { workspace = true } url = { workspace = true } -uuid = { version = "1.16", features = ["v4", "js"] } +uuid = { version = "1.18", features = ["v4", "js"] } xz2 = { version = "0.1", optional = true, features = ["static"] } zstd = { version = "0.13", optional = true, default-features = false } @@ -150,22 +166,27 @@ ctor = { workspace = true } dashmap = "6.1.0" datafusion-doc = { workspace = true } datafusion-functions-window-common = { workspace = true } +datafusion-macros = { workspace = true } datafusion-physical-optimizer = { workspace = true } doc-comment = { workspace = true } env_logger = { workspace = true } +glob = { version = "0.3.0" } insta = { workspace = true } paste = "^1.0" rand = { workspace = true, features = ["small_rng"] } -rand_distr = "0.4.3" +rand_distr = "0.5" regex = { workspace = true } rstest = { workspace = true } serde_json = { workspace = true } -sysinfo = "0.33.1" +sysinfo = "0.37.2" test-utils = { path = "../../test-utils" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } +[package.metadata.cargo-machete] +ignored = ["datafusion-doc", "datafusion-macros", "dashmap"] + [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = { version = "0.29.0", features = ["fs"] } +nix = { version = "0.30.1", features = ["fs"] } [[bench]] harness = false @@ -179,6 +200,10 @@ name = "csv_load" harness = false name = "distinct_query_sql" +[[bench]] +harness = false +name = "push_down_filter" + [[bench]] harness = false name = "sort_limit_query_sql" diff --git a/datafusion/core/README.md b/datafusion/core/README.md index b5501087d2647..859fcb9c0dff9 100644 --- a/datafusion/core/README.md +++ b/datafusion/core/README.md @@ -17,15 +17,12 @@ under the License. --> -# DataFusion Core + -DataFusion is an extensible query execution framework, written in Rust, -that uses Apache Arrow as its in-memory format. +# Apache DataFusion Core This crate contains the main entry points and high level DataFusion APIs such as `SessionContext`, `DataFrame` and `ListingTable`. - -For more information, please see: - -- [DataFusion Website](https://datafusion.apache.org) -- [DataFusion API Docs](https://docs.rs/datafusion/latest/datafusion/) diff --git a/datafusion/core/benches/aggregate_query_sql.rs b/datafusion/core/benches/aggregate_query_sql.rs index b29bfc487340d..9da341ce2e926 100644 --- a/datafusion/core/benches/aggregate_query_sql.rs +++ b/datafusion/core/benches/aggregate_query_sql.rs @@ -153,12 +153,44 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + c.bench_function( + "aggregate_query_group_by_wide_u64_and_string_without_aggregate_expressions", + |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + // Due to the large number of distinct values in u64_wide, + // this query test the actual grouping performance for more than 1 column + "SELECT u64_wide, utf8 \ + FROM t GROUP BY u64_wide, utf8", + ) + }) + }, + ); + + c.bench_function( + "aggregate_query_group_by_wide_u64_and_f32_without_aggregate_expressions", + |b| { + b.iter(|| { + query( + ctx.clone(), + &rt, + // Due to the large number of distinct values in u64_wide, + // this query test the actual grouping performance for more than 1 column + "SELECT u64_wide, f32 \ + FROM t GROUP BY u64_wide, f32", + ) + }) + }, + ); + c.bench_function("aggregate_query_approx_percentile_cont_on_u64", |b| { b.iter(|| { query( ctx.clone(), &rt, - "SELECT utf8, approx_percentile_cont(u64_wide, 0.5, 2500) \ + "SELECT utf8, approx_percentile_cont(0.5, 2500) WITHIN GROUP (ORDER BY u64_wide) \ FROM t GROUP BY utf8", ) }) @@ -169,7 +201,7 @@ fn criterion_benchmark(c: &mut Criterion) { query( ctx.clone(), &rt, - "SELECT utf8, approx_percentile_cont(f32, 0.5, 2500) \ + "SELECT utf8, approx_percentile_cont(0.5, 2500) WITHIN GROUP (ORDER BY f32) \ FROM t GROUP BY utf8", ) }) diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs index 38f6a2c76df6d..fffe2e2d17522 100644 --- a/datafusion/core/benches/data_utils/mod.rs +++ b/datafusion/core/benches/data_utils/mod.rs @@ -19,14 +19,15 @@ use arrow::array::{ builder::{Int64Builder, StringBuilder}, - Float32Array, Float64Array, RecordBatch, StringArray, UInt64Array, + ArrayRef, Float32Array, Float64Array, RecordBatch, StringArray, StringViewBuilder, + UInt64Array, }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::MemTable; use datafusion::error::Result; use datafusion_common::DataFusionError; +use rand::prelude::IndexedRandom; use rand::rngs::StdRng; -use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; use rand_distr::Distribution; use rand_distr::{Normal, Pareto}; @@ -48,11 +49,6 @@ pub fn create_table_provider( MemTable::try_new(schema, partitions).map(Arc::new) } -/// create a seedable [`StdRng`](rand::StdRng) -fn seedable_rng() -> StdRng { - StdRng::seed_from_u64(42) -} - /// Create test data schema pub fn create_schema() -> Schema { Schema::new(vec![ @@ -72,29 +68,30 @@ pub fn create_schema() -> Schema { fn create_data(size: usize, null_density: f64) -> Vec> { // use random numbers to avoid spurious compiler optimizations wrt to branching - let mut rng = seedable_rng(); + let mut rng = StdRng::seed_from_u64(42); (0..size) .map(|_| { - if rng.gen::() > null_density { + if rng.random::() > null_density { None } else { - Some(rng.gen::()) + Some(rng.random::()) } }) .collect() } -fn create_integer_data(size: usize, value_density: f64) -> Vec> { - // use random numbers to avoid spurious compiler optimizations wrt to branching - let mut rng = seedable_rng(); - +fn create_integer_data( + rng: &mut StdRng, + size: usize, + value_density: f64, +) -> Vec> { (0..size) .map(|_| { - if rng.gen::() > value_density { + if rng.random::() > value_density { None } else { - Some(rng.gen::()) + Some(rng.random::()) } }) .collect() @@ -120,11 +117,11 @@ fn create_record_batch( let values = create_data(batch_size, 0.5); // Integer values between [0, u64::MAX]. - let integer_values_wide = create_integer_data(batch_size, 9.0); + let integer_values_wide = create_integer_data(rng, batch_size, 9.0); // Integer values between [0, 9]. let integer_values_narrow = (0..batch_size) - .map(|_| rng.gen_range(0_u64..10)) + .map(|_| rng.random_range(0_u64..10)) .collect::>(); RecordBatch::try_new( @@ -148,7 +145,7 @@ pub fn create_record_batches( partitions_len: usize, batch_size: usize, ) -> Vec> { - let mut rng = seedable_rng(); + let mut rng = StdRng::seed_from_u64(42); (0..partitions_len) .map(|_| { (0..array_len / batch_size / partitions_len) @@ -158,6 +155,31 @@ pub fn create_record_batches( .collect::>() } +/// An enum that wraps either a regular StringBuilder or a GenericByteViewBuilder +/// so that both can be used interchangeably. +enum TraceIdBuilder { + Utf8(StringBuilder), + Utf8View(StringViewBuilder), +} + +impl TraceIdBuilder { + /// Append a value to the builder. + fn append_value(&mut self, value: &str) { + match self { + TraceIdBuilder::Utf8(builder) => builder.append_value(value), + TraceIdBuilder::Utf8View(builder) => builder.append_value(value), + } + } + + /// Finish building and return the ArrayRef. + fn finish(self) -> ArrayRef { + match self { + TraceIdBuilder::Utf8(mut builder) => Arc::new(builder.finish()), + TraceIdBuilder::Utf8View(mut builder) => Arc::new(builder.finish()), + } + } +} + /// Create time series data with `partition_cnt` partitions and `sample_cnt` rows per partition /// in ascending order, if `asc` is true, otherwise randomly sampled using a Pareto distribution #[allow(dead_code)] @@ -165,6 +187,7 @@ pub(crate) fn make_data( partition_cnt: i32, sample_cnt: i32, asc: bool, + use_view: bool, ) -> Result<(Arc, Vec>), DataFusionError> { // constants observed from trace data let simultaneous_group_cnt = 2000; @@ -177,14 +200,20 @@ pub(crate) fn make_data( let mut rng = rand::rngs::SmallRng::from_seed([0; 32]); // populate data - let schema = test_schema(); + let schema = test_schema(use_view); let mut partitions = vec![]; let mut cur_time = 16909000000000i64; for _ in 0..partition_cnt { - let mut id_builder = StringBuilder::new(); + // Choose the appropriate builder based on use_view. + let mut id_builder = if use_view { + TraceIdBuilder::Utf8View(StringViewBuilder::new()) + } else { + TraceIdBuilder::Utf8(StringBuilder::new()) + }; + let mut ts_builder = Int64Builder::new(); let gen_id = |rng: &mut rand::rngs::SmallRng| { - rng.gen::<[u8; 16]>() + rng.random::<[u8; 16]>() .iter() .fold(String::new(), |mut output, b| { let _ = write!(output, "{b:02X}"); @@ -200,7 +229,7 @@ pub(crate) fn make_data( .map(|_| gen_sample_cnt(&mut rng)) .collect::>(); for _ in 0..sample_cnt { - let random_index = rng.gen_range(0..simultaneous_group_cnt); + let random_index = rng.random_range(0..simultaneous_group_cnt); let trace_id = &mut group_ids[random_index]; let sample_cnt = &mut group_sample_cnts[random_index]; *sample_cnt -= 1; @@ -230,10 +259,19 @@ pub(crate) fn make_data( Ok((schema, partitions)) } -/// The Schema used by make_data -fn test_schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("trace_id", DataType::Utf8, false), - Field::new("timestamp_ms", DataType::Int64, false), - ])) +/// Returns a Schema based on the use_view flag +fn test_schema(use_view: bool) -> SchemaRef { + if use_view { + // Return Utf8View schema + Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::Utf8View, false), + Field::new("timestamp_ms", DataType::Int64, false), + ])) + } else { + // Return regular Utf8 schema + Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::Utf8, false), + Field::new("timestamp_ms", DataType::Int64, false), + ])) + } } diff --git a/datafusion/core/benches/dataframe.rs b/datafusion/core/benches/dataframe.rs index 832553ebed82a..12eb34719e4ba 100644 --- a/datafusion/core/benches/dataframe.rs +++ b/datafusion/core/benches/dataframe.rs @@ -32,7 +32,7 @@ use tokio::runtime::Runtime; fn create_context(field_count: u32) -> datafusion_common::Result> { let mut fields = vec![]; for i in 0..field_count { - fields.push(Field::new(format!("str{}", i), DataType::Utf8, true)) + fields.push(Field::new(format!("str{i}"), DataType::Utf8, true)) } let schema = Arc::new(Schema::new(fields)); @@ -49,8 +49,8 @@ fn run(column_count: u32, ctx: Arc, rt: &Runtime) { let mut data_frame = ctx.table("t").await.unwrap(); for i in 0..column_count { - let field_name = &format!("str{}", i); - let new_field_name = &format!("newstr{}", i); + let field_name = &format!("str{i}"); + let new_field_name = &format!("newstr{i}"); data_frame = data_frame .with_column_renamed(field_name, new_field_name) diff --git a/datafusion/core/benches/distinct_query_sql.rs b/datafusion/core/benches/distinct_query_sql.rs index 4992ae6607666..c1ef55992689e 100644 --- a/datafusion/core/benches/distinct_query_sql.rs +++ b/datafusion/core/benches/distinct_query_sql.rs @@ -133,7 +133,8 @@ pub async fn create_context_sampled_data( partition_cnt: i32, sample_cnt: i32, ) -> Result<(Arc, Arc)> { - let (schema, parts) = make_data(partition_cnt, sample_cnt, false /* asc */).unwrap(); + let (schema, parts) = + make_data(partition_cnt, sample_cnt, false /* asc */, false).unwrap(); let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap()); // Create the DataFrame @@ -153,7 +154,7 @@ fn criterion_benchmark_limited_distinct_sampled(c: &mut Criterion) { let sql = format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); c.bench_function( - format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + format!("distinct query with {partitions} partitions and {samples} samples per partition with limit {limit}").as_str(), |b| b.iter(|| { let (plan, ctx) = rt.block_on( create_context_sampled_data(sql.as_str(), partitions, samples) @@ -167,7 +168,7 @@ fn criterion_benchmark_limited_distinct_sampled(c: &mut Criterion) { let sql = format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); c.bench_function( - format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + format!("distinct query with {partitions} partitions and {samples} samples per partition with limit {limit}").as_str(), |b| b.iter(|| { let (plan, ctx) = rt.block_on( create_context_sampled_data(sql.as_str(), partitions, samples) @@ -181,7 +182,7 @@ fn criterion_benchmark_limited_distinct_sampled(c: &mut Criterion) { let sql = format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); c.bench_function( - format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + format!("distinct query with {partitions} partitions and {samples} samples per partition with limit {limit}").as_str(), |b| b.iter(|| { let (plan, ctx) = rt.block_on( create_context_sampled_data(sql.as_str(), partitions, samples) diff --git a/datafusion/core/benches/map_query_sql.rs b/datafusion/core/benches/map_query_sql.rs index 79229dfc2fbdb..063b8e6c86bbf 100644 --- a/datafusion/core/benches/map_query_sql.rs +++ b/datafusion/core/benches/map_query_sql.rs @@ -34,7 +34,7 @@ mod data_utils; fn build_keys(rng: &mut ThreadRng) -> Vec { let mut keys = vec![]; for _ in 0..1000 { - keys.push(rng.gen_range(0..9999).to_string()); + keys.push(rng.random_range(0..9999).to_string()); } keys } @@ -42,7 +42,7 @@ fn build_keys(rng: &mut ThreadRng) -> Vec { fn build_values(rng: &mut ThreadRng) -> Vec { let mut values = vec![]; for _ in 0..1000 { - values.push(rng.gen_range(0..9999)); + values.push(rng.random_range(0..9999)); } values } @@ -64,15 +64,18 @@ fn criterion_benchmark(c: &mut Criterion) { let rt = Runtime::new().unwrap(); let df = rt.block_on(ctx.lock().table("t")).unwrap(); - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let keys = build_keys(&mut rng); let values = build_values(&mut rng); let mut key_buffer = Vec::new(); let mut value_buffer = Vec::new(); for i in 0..1000 { - key_buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); - value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + key_buffer.push(Expr::Literal( + ScalarValue::Utf8(Some(keys[i].clone())), + None, + )); + value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])), None)); } c.bench_function("map_1000_1", |b| { b.iter(|| { diff --git a/datafusion/core/benches/parquet_query_sql.rs b/datafusion/core/benches/parquet_query_sql.rs index f82a126c56520..14dcdf15f173b 100644 --- a/datafusion/core/benches/parquet_query_sql.rs +++ b/datafusion/core/benches/parquet_query_sql.rs @@ -29,9 +29,10 @@ use datafusion_common::instant::Instant; use futures::stream::StreamExt; use parquet::arrow::ArrowWriter; use parquet::file::properties::{WriterProperties, WriterVersion}; -use rand::distributions::uniform::SampleUniform; -use rand::distributions::Alphanumeric; +use rand::distr::uniform::SampleUniform; +use rand::distr::Alphanumeric; use rand::prelude::*; +use rand::rng; use std::fs::File; use std::io::Read; use std::ops::Range; @@ -97,13 +98,13 @@ fn generate_string_dictionary( len: usize, valid_percent: f64, ) -> ArrayRef { - let mut rng = thread_rng(); + let mut rng = rng(); let strings: Vec<_> = (0..cardinality).map(|x| format!("{prefix}#{x}")).collect(); Arc::new(DictionaryArray::::from_iter((0..len).map( |_| { - rng.gen_bool(valid_percent) - .then(|| strings[rng.gen_range(0..cardinality)].as_str()) + rng.random_bool(valid_percent) + .then(|| strings[rng.random_range(0..cardinality)].as_str()) }, ))) } @@ -113,10 +114,10 @@ fn generate_strings( len: usize, valid_percent: f64, ) -> ArrayRef { - let mut rng = thread_rng(); + let mut rng = rng(); Arc::new(StringArray::from_iter((0..len).map(|_| { - rng.gen_bool(valid_percent).then(|| { - let string_len = rng.gen_range(string_length_range.clone()); + rng.random_bool(valid_percent).then(|| { + let string_len = rng.random_range(string_length_range.clone()); (0..string_len) .map(|_| char::from(rng.sample(Alphanumeric))) .collect::() @@ -133,10 +134,10 @@ where T: ArrowPrimitiveType, T::Native: SampleUniform, { - let mut rng = thread_rng(); + let mut rng = rng(); Arc::new(PrimitiveArray::::from_iter((0..len).map(|_| { - rng.gen_bool(valid_percent) - .then(|| rng.gen_range(range.clone())) + rng.random_bool(valid_percent) + .then(|| rng.random_range(range.clone())) }))) } diff --git a/datafusion/core/benches/physical_plan.rs b/datafusion/core/benches/physical_plan.rs index 0a65c52f72def..e4838572f60fb 100644 --- a/datafusion/core/benches/physical_plan.rs +++ b/datafusion/core/benches/physical_plan.rs @@ -50,11 +50,8 @@ fn sort_preserving_merge_operator( let sort = sort .iter() - .map(|name| PhysicalSortExpr { - expr: col(name, &schema).unwrap(), - options: Default::default(), - }) - .collect::(); + .map(|name| PhysicalSortExpr::new_default(col(name, &schema).unwrap())); + let sort = LexOrdering::new(sort).unwrap(); let exec = MemorySourceConfig::try_new_exec( &batches.into_iter().map(|rb| vec![rb]).collect::>(), diff --git a/datafusion/core/benches/push_down_filter.rs b/datafusion/core/benches/push_down_filter.rs new file mode 100644 index 0000000000000..139fb12c30947 --- /dev/null +++ b/datafusion/core/benches/push_down_filter.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, Schema}; +use bytes::{BufMut, BytesMut}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion::config::ConfigOptions; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::ExecutionPlan; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::ObjectStore; +use parquet::arrow::ArrowWriter; +use std::sync::Arc; + +async fn create_plan() -> Arc { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + Field::new("age", DataType::UInt16, true), + Field::new("salary", DataType::Float64, true), + ])); + let batch = RecordBatch::new_empty(schema); + + let store = Arc::new(InMemory::new()) as Arc; + let mut out = BytesMut::new().writer(); + { + let mut writer = ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + } + let data = out.into_inner().freeze(); + store + .put(&Path::from("test.parquet"), data.into()) + .await + .unwrap(); + ctx.register_object_store( + ObjectStoreUrl::parse("memory://").unwrap().as_ref(), + store, + ); + + ctx.register_parquet("t", "memory:///", ParquetReadOptions::default()) + .await + .unwrap(); + + let df = ctx + .sql( + r" + WITH brackets AS ( + SELECT age % 10 AS age_bracket + FROM t + GROUP BY age % 10 + HAVING COUNT(*) > 10 + ) + SELECT id, name, age, salary + FROM t + JOIN brackets ON t.age % 10 = brackets.age_bracket + WHERE age > 20 AND t.salary > 1000 + ORDER BY t.salary DESC + LIMIT 100 + ", + ) + .await + .unwrap(); + + df.create_physical_plan().await.unwrap() +} + +#[derive(Clone)] +struct BenchmarkPlan { + plan: Arc, + config: ConfigOptions, +} + +impl std::fmt::Display for BenchmarkPlan { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "BenchmarkPlan") + } +} + +fn bench_push_down_filter(c: &mut Criterion) { + // Create a relatively complex plan + let plan = tokio::runtime::Runtime::new() + .unwrap() + .block_on(create_plan()); + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + let plan = BenchmarkPlan { plan, config }; + let optimizer = FilterPushdown::new(); + + c.bench_function("push_down_filter", |b| { + b.iter(|| { + optimizer + .optimize(Arc::clone(&plan.plan), &plan.config) + .unwrap(); + }); + }); +} + +// It's a bit absurd that it's this complicated but to generate a flamegraph you can run: +// `cargo flamegraph -p datafusion --bench push_down_filter --flamechart --root --profile profiling --freq 1000 -- --bench` +// See https://github.com/flamegraph-rs/flamegraph +criterion_group!(benches, bench_push_down_filter); +criterion_main!(benches); diff --git a/datafusion/core/benches/sort.rs b/datafusion/core/benches/sort.rs index 85f456ce5dc22..276151e253f7e 100644 --- a/datafusion/core/benches/sort.rs +++ b/datafusion/core/benches/sort.rs @@ -71,7 +71,6 @@ use std::sync::Arc; use arrow::array::StringViewArray; use arrow::{ array::{DictionaryArray, Float64Array, Int64Array, StringArray}, - compute::SortOptions, datatypes::{Int32Type, Schema}, record_batch::RecordBatch, }; @@ -272,14 +271,11 @@ impl BenchCase { /// Make sort exprs for each column in `schema` fn make_sort_exprs(schema: &Schema) -> LexOrdering { - schema + let sort_exprs = schema .fields() .iter() - .map(|f| PhysicalSortExpr { - expr: col(f.name(), schema).unwrap(), - options: SortOptions::default(), - }) - .collect() + .map(|f| PhysicalSortExpr::new_default(col(f.name(), schema).unwrap())); + LexOrdering::new(sort_exprs).unwrap() } /// Create streams of int64 (where approximately 1/3 values is repeated) @@ -595,7 +591,7 @@ impl DataGenerator { /// Create an array of i64 sorted values (where approximately 1/3 values is repeated) fn i64_values(&mut self) -> Vec { let mut vec: Vec<_> = (0..INPUT_SIZE) - .map(|_| self.rng.gen_range(0..INPUT_SIZE as i64)) + .map(|_| self.rng.random_range(0..INPUT_SIZE as i64)) .collect(); vec.sort_unstable(); @@ -620,7 +616,7 @@ impl DataGenerator { // pick from the 100 strings randomly let mut input = (0..INPUT_SIZE) .map(|_| { - let idx = self.rng.gen_range(0..strings.len()); + let idx = self.rng.random_range(0..strings.len()); let s = Arc::clone(&strings[idx]); Some(s) }) @@ -643,7 +639,7 @@ impl DataGenerator { fn random_string(&mut self) -> String { let rng = &mut self.rng; - rng.sample_iter(rand::distributions::Alphanumeric) + rng.sample_iter(rand::distr::Alphanumeric) .filter(|c| c.is_ascii_alphabetic()) .take(20) .map(char::from) @@ -665,7 +661,7 @@ where let mut outputs: Vec>> = (0..NUM_STREAMS).map(|_| Vec::new()).collect(); for i in input { - let stream_idx = rng.gen_range(0..NUM_STREAMS); + let stream_idx = rng.random_range(0..NUM_STREAMS); let stream = &mut outputs[stream_idx]; match stream.last_mut() { Some(x) if x.len() < BATCH_SIZE => x.push(i), diff --git a/datafusion/core/benches/spm.rs b/datafusion/core/benches/spm.rs index 63b06f20cd86a..5c244832300e4 100644 --- a/datafusion/core/benches/spm.rs +++ b/datafusion/core/benches/spm.rs @@ -21,7 +21,6 @@ use arrow::array::{ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray}; use datafusion_execution::TaskContext; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::{collect, ExecutionPlan}; @@ -67,10 +66,10 @@ fn generate_spm_for_round_robin_tie_breaker( }; let rbs = (0..batch_count).map(|_| rb.clone()).collect::>(); - let partitiones = vec![rbs.clone(); partition_count]; + let partitions = vec![rbs.clone(); partition_count]; let schema = rb.schema(); - let sort = LexOrdering::new(vec![ + let sort = [ PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: Default::default(), @@ -79,9 +78,10 @@ fn generate_spm_for_round_robin_tie_breaker( expr: col("c", &schema).unwrap(), options: Default::default(), }, - ]); + ] + .into(); - let exec = MemorySourceConfig::try_new_exec(&partitiones, schema, None).unwrap(); + let exec = MemorySourceConfig::try_new_exec(&partitions, schema, None).unwrap(); SortPreservingMergeExec::new(sort, exec) .with_round_robin_repartition(enable_round_robin_repartition) } @@ -125,8 +125,7 @@ fn criterion_benchmark(c: &mut Criterion) { for &batch_count in &batch_counts { for &partition_count in &partition_counts { let description = format!( - "{}_batch_count_{}_partition_count_{}", - cardinality_label, batch_count, partition_count + "{cardinality_label}_batch_count_{batch_count}_partition_count_{partition_count}" ); run_bench( c, diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 49cc830d58bc4..3be8668b2b8c4 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -25,14 +25,18 @@ mod data_utils; use crate::criterion::Criterion; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::{DataType, Field, Fields, Schema}; +use arrow_schema::TimeUnit::Nanosecond; use criterion::Bencher; use datafusion::datasource::MemTable; use datafusion::execution::context::SessionContext; +use datafusion::prelude::DataFrame; use datafusion_common::ScalarValue; -use datafusion_expr::col; -use itertools::Itertools; -use std::fs::File; -use std::io::{BufRead, BufReader}; +use datafusion_expr::Expr::Literal; +use datafusion_expr::{cast, col, lit, not, try_cast, when}; +use datafusion_functions::expr_fn::{ + btrim, length, regexp_like, regexp_replace, to_timestamp, upper, +}; +use std::ops::Rem; use std::path::PathBuf; use std::sync::Arc; use test_utils::tpcds::tpcds_schemas; @@ -61,6 +65,150 @@ fn physical_plan(ctx: &SessionContext, rt: &Runtime, sql: &str) { })); } +/// Build a dataframe for testing logical plan optimization +fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame { + register_string_table(ctx, 100, 1000); + + rt.block_on(async { + let mut df = ctx.table("t").await.unwrap(); + // add some columns in + for i in 100..150 { + df = df + .with_column(&format!("c{i}"), Literal(ScalarValue::Utf8(None), None)) + .unwrap(); + } + // add in some columns with string encoded timestamps + for i in 150..175 { + df = df + .with_column( + &format!("c{i}"), + Literal(ScalarValue::Utf8(Some("2025-08-21 09:43:17".into())), None), + ) + .unwrap(); + } + // do a bunch of ops on the columns + for i in 0..175 { + // trim the columns + df = df + .with_column(&format!("c{i}"), btrim(vec![col(format!("c{i}"))])) + .unwrap(); + } + + for i in 0..175 { + let c_name = format!("c{i}"); + let c = col(&c_name); + + // random ops + if i % 5 == 0 && i < 150 { + // the actual ops here are largely unimportant as they are just a sample + // of ops that could occur on a dataframe + df = df + .with_column(&c_name, cast(c.clone(), DataType::Utf8)) + .unwrap() + .with_column( + &c_name, + when( + cast(c.clone(), DataType::Int32).gt(lit(135)), + cast( + cast(c.clone(), DataType::Int32) - lit(i + 3), + DataType::Utf8, + ), + ) + .otherwise(c.clone()) + .unwrap(), + ) + .unwrap() + .with_column( + &c_name, + when( + c.clone().is_not_null().and( + cast(c.clone(), DataType::Int32) + .between(lit(120), lit(130)), + ), + Literal(ScalarValue::Utf8(None), None), + ) + .otherwise( + when( + c.clone().is_not_null().and(regexp_like( + cast(c.clone(), DataType::Utf8View), + lit("[0-9]*"), + None, + )), + upper(c.clone()), + ) + .otherwise(c.clone()) + .unwrap(), + ) + .unwrap(), + ) + .unwrap() + .with_column( + &c_name, + when( + c.clone().is_not_null().and( + cast(c.clone(), DataType::Int32) + .between(lit(90), lit(100)), + ), + cast(c.clone(), DataType::Utf8View), + ) + .otherwise(Literal(ScalarValue::Date32(None), None)) + .unwrap(), + ) + .unwrap() + .with_column( + &c_name, + when( + c.clone().is_not_null().and( + cast(c.clone(), DataType::Int32).rem(lit(10)).gt(lit(7)), + ), + regexp_replace( + cast(c.clone(), DataType::Utf8View), + lit("1"), + lit("a"), + None, + ), + ) + .otherwise(Literal(ScalarValue::Date32(None), None)) + .unwrap(), + ) + .unwrap() + } + if i >= 150 { + df = df + .with_column( + &c_name, + try_cast( + to_timestamp(vec![c.clone(), lit("%Y-%m-%d %H:%M:%S")]), + DataType::Timestamp(Nanosecond, Some("UTC".into())), + ), + ) + .unwrap() + .with_column(&c_name, try_cast(c.clone(), DataType::Date32)) + .unwrap() + } + + // add in a few unions + if i % 30 == 0 { + let df1 = df + .clone() + .filter(length(c.clone()).gt(lit(2))) + .unwrap() + .with_column(&format!("c{i}_filtered"), lit(true)) + .unwrap(); + let df2 = df + .filter(not(length(c.clone()).gt(lit(2)))) + .unwrap() + .with_column(&format!("c{i}_filtered"), lit(false)) + .unwrap(); + + df = df1.union_by_name(df2).unwrap() + } + } + + df + }) +} + /// Create schema with the specified number of columns fn create_schema(column_prefix: &str, num_columns: usize) -> Schema { let fields: Fields = (0..num_columns) @@ -136,10 +284,10 @@ fn benchmark_with_param_values_many_columns( if i > 0 { aggregates.push_str(", "); } - aggregates.push_str(format!("MAX(a{})", i).as_str()); + aggregates.push_str(format!("MAX(a{i})").as_str()); } // SELECT max(attr0), ..., max(attrN) FROM t1. - let query = format!("SELECT {} FROM t1", aggregates); + let query = format!("SELECT {aggregates} FROM t1"); let statement = ctx.state().sql_to_statement(&query, "Generic").unwrap(); let plan = rt.block_on(async { ctx.state().statement_to_plan(statement).await.unwrap() }); @@ -164,7 +312,7 @@ fn register_union_order_table(ctx: &SessionContext, num_columns: usize, num_rows .map(|j| j as u64 * 100 + i) .collect::>(), )); - (format!("c{}", i), array) + (format!("c{i}"), array) }); let batch = RecordBatch::try_from_iter(iter).unwrap(); let schema = batch.schema(); @@ -172,7 +320,7 @@ fn register_union_order_table(ctx: &SessionContext, num_columns: usize, num_rows // tell DataFusion that the table is sorted by all columns let sort_order = (0..num_columns) - .map(|i| col(format!("c{}", i)).sort(true, true)) + .map(|i| col(format!("c{i}")).sort(true, true)) .collect::>(); // create the table @@ -183,13 +331,40 @@ fn register_union_order_table(ctx: &SessionContext, num_columns: usize, num_rows ctx.register_table("t", Arc::new(table)).unwrap(); } +/// Registers a table like this: +/// c0,c1,c2...,c99 +/// "0","100"..."9900" +/// "0","200"..."19800" +/// "0","300"..."29700" +fn register_string_table(ctx: &SessionContext, num_columns: usize, num_rows: usize) { + // ("c0", ["0", "0", ...]) + // ("c1": ["100", "200", ...]) + // etc + let iter = (0..num_columns).map(|i| i as u64).map(|i| { + let array: ArrayRef = Arc::new(arrow::array::StringViewArray::from_iter_values( + (0..num_rows) + .map(|j| format!("c{}", j as u64 * 100 + i)) + .collect::>(), + )); + (format!("c{i}"), array) + }); + let batch = RecordBatch::try_from_iter(iter).unwrap(); + let schema = batch.schema(); + let partitions = vec![vec![batch]]; + + // create the table + let table = MemTable::try_new(schema, partitions).unwrap(); + + ctx.register_table("t", Arc::new(table)).unwrap(); +} + /// return a query like /// ```sql -/// select c1, null as c2, ... null as cn from t ORDER BY c1 +/// select c1, 2 as c2, ... n as cn from t ORDER BY c1 /// UNION ALL -/// select null as c1, c2, ... null as cn from t ORDER BY c2 +/// select 1 as c1, c2, ... n as cn from t ORDER BY c2 /// ... -/// select null as c1, null as c2, ... cn from t ORDER BY cn +/// select 1 as c1, 2 as c2, ... cn from t ORDER BY cn /// ORDER BY c1, c2 ... CN /// ``` fn union_orderby_query(n: usize) -> String { @@ -203,17 +378,17 @@ fn union_orderby_query(n: usize) -> String { if i == j { format!("c{j}") } else { - format!("null as c{j}") + format!("{j} as c{j}") } }) .collect::>() .join(", "); - query.push_str(&format!("(SELECT {} FROM t ORDER BY c{})", select_list, i)); + query.push_str(&format!("(SELECT {select_list} FROM t ORDER BY c{i})")); } query.push_str(&format!( "\nORDER BY {}", (0..n) - .map(|i| format!("c{}", i)) + .map(|i| format!("c{i}")) .collect::>() .join(", ") )); @@ -293,14 +468,42 @@ fn criterion_benchmark(c: &mut Criterion) { if i > 0 { aggregates.push_str(", "); } - aggregates.push_str(format!("MAX(a{})", i).as_str()); + aggregates.push_str(format!("MAX(a{i})").as_str()); } - let query = format!("SELECT {} FROM t1", aggregates); + let query = format!("SELECT {aggregates} FROM t1"); b.iter(|| { physical_plan(&ctx, &rt, &query); }); }); + // It was observed in production that queries with window functions sometimes partition over more than 30 columns + for partitioning_columns in [4, 7, 8, 12, 30] { + c.bench_function( + &format!( + "physical_window_function_partition_by_{partitioning_columns}_on_values" + ), + |b| { + let source = format!( + "SELECT 1 AS n{}", + (0..partitioning_columns) + .map(|i| format!(", {i} AS c{i}")) + .collect::() + ); + let window = format!( + "SUM(n) OVER (PARTITION BY {}) AS sum_n", + (0..partitioning_columns) + .map(|i| format!("c{i}")) + .collect::>() + .join(", ") + ); + let query = format!("SELECT {window} FROM ({source})"); + b.iter(|| { + physical_plan(&ctx, &rt, &query); + }); + }, + ); + } + // Benchmark for Physical Planning Joins c.bench_function("physical_join_consider_sort", |b| { b.iter(|| { @@ -373,16 +576,37 @@ fn criterion_benchmark(c: &mut Criterion) { }); // -- Sorted Queries -- - register_union_order_table(&ctx, 100, 1000); - - // this query has many expressions in its sort order so stresses - // order equivalence validation - c.bench_function("physical_sorted_union_orderby", |b| { - // SELECT ... UNION ALL ... - let query = union_orderby_query(20); - b.iter(|| physical_plan(&ctx, &rt, &query)) + for column_count in [10, 50, 100, 200, 300] { + register_union_order_table(&ctx, column_count, 1000); + + // this query has many expressions in its sort order so stresses + // order equivalence validation + c.bench_function( + &format!("physical_sorted_union_order_by_{column_count}"), + |b| { + // SELECT ... UNION ALL ... + let query = union_orderby_query(column_count); + b.iter(|| physical_plan(&ctx, &rt, &query)) + }, + ); + + let _ = ctx.deregister_table("t"); + } + + // -- validate logical plan optimize performance + let df = build_test_data_frame(&ctx, &rt); + + c.bench_function("logical_plan_optimize", |b| { + b.iter(|| { + let df_clone = df.clone(); + criterion::black_box( + rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), + ); + }) }); + let _ = ctx.deregister_table("t"); + // --- TPC-H --- let tpch_ctx = register_defs(SessionContext::new(), tpch_schemas()); @@ -402,7 +626,7 @@ fn criterion_benchmark(c: &mut Criterion) { for q in tpch_queries { let sql = std::fs::read_to_string(format!("{benchmarks_path}queries/{q}.sql")).unwrap(); - c.bench_function(&format!("physical_plan_tpch_{}", q), |b| { + c.bench_function(&format!("physical_plan_tpch_{q}"), |b| { b.iter(|| physical_plan(&tpch_ctx, &rt, &sql)) }); } @@ -440,6 +664,9 @@ fn criterion_benchmark(c: &mut Criterion) { }; let raw_tpcds_sql_queries = (1..100) + // skip query 75 until it is fixed + // https://github.com/apache/datafusion/issues/17801 + .filter(|q| *q != 75) .map(|q| std::fs::read_to_string(format!("{tests_path}tpc-ds/{q}.sql")).unwrap()) .collect::>(); @@ -466,17 +693,20 @@ fn criterion_benchmark(c: &mut Criterion) { // }); // -- clickbench -- - - let queries_file = - File::open(format!("{benchmarks_path}queries/clickbench/queries.sql")).unwrap(); - let extended_file = - File::open(format!("{benchmarks_path}queries/clickbench/extended.sql")).unwrap(); - - let clickbench_queries: Vec = BufReader::new(queries_file) - .lines() - .chain(BufReader::new(extended_file).lines()) - .map(|l| l.expect("Could not parse line")) - .collect_vec(); + let clickbench_queries = (0..=42) + .map(|q| { + std::fs::read_to_string(format!( + "{benchmarks_path}queries/clickbench/queries/q{q}.sql" + )) + .unwrap() + }) + .chain((0..=7).map(|q| { + std::fs::read_to_string(format!( + "{benchmarks_path}queries/clickbench/extended/q{q}.sql" + )) + .unwrap() + })) + .collect::>(); let clickbench_ctx = register_clickbench_hits_table(&rt); diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 58d71ee5b2eb8..58797dfed6b67 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ b/datafusion/core/benches/sql_query_with_io.rs @@ -66,7 +66,7 @@ fn create_parquet_file(rng: &mut StdRng, id_offset: usize) -> Bytes { let mut payload_builder = Int64Builder::new(); for row in 0..FILE_ROWS { id_builder.append_value((row + id_offset) as u64); - payload_builder.append_value(rng.gen()); + payload_builder.append_value(rng.random()); } let batch = RecordBatch::try_new( Arc::clone(&schema), diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs index 777d586b344c4..cf3c7fa2e26fe 100644 --- a/datafusion/core/benches/topk_aggregate.rs +++ b/datafusion/core/benches/topk_aggregate.rs @@ -33,8 +33,9 @@ async fn create_context( sample_cnt: i32, asc: bool, use_topk: bool, + use_view: bool, ) -> Result<(Arc, Arc)> { - let (schema, parts) = make_data(partition_cnt, sample_cnt, asc).unwrap(); + let (schema, parts) = make_data(partition_cnt, sample_cnt, asc, use_view).unwrap(); let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap()); // Create the DataFrame @@ -108,7 +109,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let real = rt.block_on(async { - create_context(limit, partitions, samples, false, false) + create_context(limit, partitions, samples, false, false, false) .await .unwrap() }); @@ -122,7 +123,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let asc = rt.block_on(async { - create_context(limit, partitions, samples, true, false) + create_context(limit, partitions, samples, true, false, false) .await .unwrap() }); @@ -140,7 +141,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let topk_real = rt.block_on(async { - create_context(limit, partitions, samples, false, true) + create_context(limit, partitions, samples, false, true, false) .await .unwrap() }); @@ -158,7 +159,45 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let topk_asc = rt.block_on(async { - create_context(limit, partitions, samples, true, true) + create_context(limit, partitions, samples, true, true, false) + .await + .unwrap() + }); + run(&rt, topk_asc.0.clone(), topk_asc.1.clone(), true) + }) + }, + ); + + // Utf8View schema,time-series rows + c.bench_function( + format!( + "top k={limit} aggregate {} time-series rows [Utf8View]", + partitions * samples + ) + .as_str(), + |b| { + b.iter(|| { + let topk_real = rt.block_on(async { + create_context(limit, partitions, samples, false, true, true) + .await + .unwrap() + }); + run(&rt, topk_real.0.clone(), topk_real.1.clone(), false) + }) + }, + ); + + // Utf8View schema,worst-case rows + c.bench_function( + format!( + "top k={limit} aggregate {} worst-case rows [Utf8View]", + partitions * samples + ) + .as_str(), + |b| { + b.iter(|| { + let topk_asc = rt.block_on(async { + create_context(limit, partitions, samples, true, true, true) .await .unwrap() }); diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index 7afb90282a80a..63387c023b11a 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -46,7 +46,7 @@ fn main() -> Result<()> { "scalar" => print_scalar_docs(), "window" => print_window_docs(), _ => { - panic!("Unknown function type: {}", function_type) + panic!("Unknown function type: {function_type}") } }?; @@ -92,7 +92,7 @@ fn print_window_docs() -> Result { fn save_doc_code_text(documentation: &Documentation, name: &str) { let attr_text = documentation.to_doc_attribute(); - let file_path = format!("{}.txt", name); + let file_path = format!("{name}.txt"); if std::path::Path::new(&file_path).exists() { std::fs::remove_file(&file_path).unwrap(); } @@ -215,16 +215,15 @@ fn print_docs( r#" #### Example -{} -"#, - example +{example} +"# ); } if let Some(alt_syntax) = &documentation.alternative_syntax { let _ = writeln!(docs, "#### Alternative Syntax\n"); for syntax in alt_syntax { - let _ = writeln!(docs, "```sql\n{}\n```", syntax); + let _ = writeln!(docs, "```sql\n{syntax}\n```"); } } @@ -261,7 +260,7 @@ fn print_docs( } } -/// Trait for accessing name / aliases / documentation for differnet functions +/// Trait for accessing name / aliases / documentation for different functions trait DocProvider { fn get_name(&self) -> String; fn get_aliases(&self) -> Vec; diff --git a/datafusion/core/src/bin/print_runtime_config_docs.rs b/datafusion/core/src/bin/print_runtime_config_docs.rs new file mode 100644 index 0000000000000..31425da73d354 --- /dev/null +++ b/datafusion/core/src/bin/print_runtime_config_docs.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_execution::runtime_env::RuntimeEnvBuilder; + +fn main() { + let docs = RuntimeEnvBuilder::generate_config_markdown(); + println!("{docs}"); +} diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 9a70f8f43fb61..287a133273d87 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -33,8 +33,8 @@ use crate::execution::context::{SessionState, TaskContext}; use crate::execution::FunctionRegistry; use crate::logical_expr::utils::find_window_exprs; use crate::logical_expr::{ - col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions, - Partitioning, TableType, + col, ident, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, + LogicalPlanBuilderOptions, Partitioning, TableType, }; use crate::physical_plan::{ collect, collect_partitioned, execute_stream, execute_stream_partitioned, @@ -43,7 +43,7 @@ use crate::physical_plan::{ use crate::prelude::SessionContext; use std::any::Any; use std::borrow::Cow; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; @@ -51,8 +51,9 @@ use arrow::compute::{cast, concat}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ - exec_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, - DataFusionError, ParamValues, ScalarValue, SchemaError, UnnestOptions, + exec_err, internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, + Column, DFSchema, DataFusionError, ParamValues, ScalarValue, SchemaError, + TableReference, UnnestOptions, }; use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::{ @@ -61,7 +62,7 @@ use datafusion_expr::{ expr::{Alias, ScalarFunction}, is_null, lit, utils::COUNT_STAR_EXPANSION, - SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, + ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, }; use datafusion_functions::core::coalesce; use datafusion_functions_aggregate::expr_fn::{ @@ -70,7 +71,6 @@ use datafusion_functions_aggregate::expr_fn::{ use async_trait::async_trait; use datafusion_catalog::Session; -use datafusion_sql::TableReference; /// Contains options that control how data is /// written out from a DataFrame @@ -166,9 +166,12 @@ impl Default for DataFrameWriteOptions { /// /// # Example /// ``` +/// # use std::sync::Arc; /// # use datafusion::prelude::*; /// # use datafusion::error::Result; /// # use datafusion::functions_aggregate::expr_fn::min; +/// # use datafusion::arrow::array::{Int32Array, RecordBatch, StringArray}; +/// # use datafusion::arrow::datatypes::{DataType, Field, Schema}; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); @@ -181,6 +184,28 @@ impl Default for DataFrameWriteOptions { /// .limit(0, Some(100))?; /// // Perform the actual computation /// let results = df.collect(); +/// +/// // Create a new dataframe with in-memory data +/// let schema = Schema::new(vec![ +/// Field::new("id", DataType::Int32, true), +/// Field::new("name", DataType::Utf8, true), +/// ]); +/// let batch = RecordBatch::try_new( +/// Arc::new(schema), +/// vec![ +/// Arc::new(Int32Array::from(vec![1, 2, 3])), +/// Arc::new(StringArray::from(vec!["foo", "bar", "baz"])), +/// ], +/// )?; +/// let df = ctx.read_batch(batch)?; +/// df.show().await?; +/// +/// // Create a new dataframe with in-memory data using macro +/// let df = dataframe!( +/// "id" => [1, 2, 3], +/// "name" => ["foo", "bar", "baz"] +/// )?; +/// df.show().await?; /// # Ok(()) /// # } /// ``` @@ -242,6 +267,7 @@ impl DataFrame { /// # Ok(()) /// # } /// ``` + #[cfg(feature = "sql")] pub fn parse_sql_expr(&self, sql: &str) -> Result { let df_schema = self.schema(); @@ -308,6 +334,7 @@ impl DataFrame { /// # Ok(()) /// # } /// ``` + #[cfg(feature = "sql")] pub fn select_exprs(self, exprs: &[&str]) -> Result { let expr_list = exprs .iter() @@ -350,15 +377,12 @@ impl DataFrame { let expr_list: Vec = expr_list.into_iter().map(|e| e.into()).collect::>(); - let expressions = expr_list - .iter() - .filter_map(|e| match e { - SelectExpr::Expression(expr) => Some(expr.clone()), - _ => None, - }) - .collect::>(); + let expressions = expr_list.iter().filter_map(|e| match e { + SelectExpr::Expression(expr) => Some(expr), + _ => None, + }); - let window_func_exprs = find_window_exprs(&expressions); + let window_func_exprs = find_window_exprs(expressions); let plan = if window_func_exprs.is_empty() { self.plan } else { @@ -934,7 +958,7 @@ impl DataFrame { vec![], original_schema_fields .clone() - .map(|f| count(col(f.name())).alias(f.name())) + .map(|f| count(ident(f.name())).alias(f.name())) .collect::>(), ), // null_count aggregation @@ -943,7 +967,7 @@ impl DataFrame { original_schema_fields .clone() .map(|f| { - sum(case(is_null(col(f.name()))) + sum(case(is_null(ident(f.name()))) .when(lit(true), lit(1)) .otherwise(lit(0)) .unwrap()) @@ -957,7 +981,7 @@ impl DataFrame { original_schema_fields .clone() .filter(|f| f.data_type().is_numeric()) - .map(|f| avg(col(f.name())).alias(f.name())) + .map(|f| avg(ident(f.name())).alias(f.name())) .collect::>(), ), // std aggregation @@ -966,7 +990,7 @@ impl DataFrame { original_schema_fields .clone() .filter(|f| f.data_type().is_numeric()) - .map(|f| stddev(col(f.name())).alias(f.name())) + .map(|f| stddev(ident(f.name())).alias(f.name())) .collect::>(), ), // min aggregation @@ -977,7 +1001,7 @@ impl DataFrame { .filter(|f| { !matches!(f.data_type(), DataType::Binary | DataType::Boolean) }) - .map(|f| min(col(f.name())).alias(f.name())) + .map(|f| min(ident(f.name())).alias(f.name())) .collect::>(), ), // max aggregation @@ -988,7 +1012,7 @@ impl DataFrame { .filter(|f| { !matches!(f.data_type(), DataType::Binary | DataType::Boolean) }) - .map(|f| max(col(f.name())).alias(f.name())) + .map(|f| max(ident(f.name())).alias(f.name())) .collect::>(), ), // median aggregation @@ -997,7 +1021,7 @@ impl DataFrame { original_schema_fields .clone() .filter(|f| f.data_type().is_numeric()) - .map(|f| median(col(f.name())).alias(f.name())) + .map(|f| median(ident(f.name())).alias(f.name())) .collect::>(), ), ]; @@ -1312,7 +1336,10 @@ impl DataFrame { /// ``` pub async fn count(self) -> Result { let rows = self - .aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])? + .aggregate( + vec![], + vec![count(Expr::Literal(COUNT_STAR_EXPANSION, None))], + )? .collect() .await?; let len = *rows @@ -1320,9 +1347,9 @@ impl DataFrame { .and_then(|r| r.columns().first()) .and_then(|c| c.as_any().downcast_ref::()) .and_then(|a| a.values().first()) - .ok_or(DataFusionError::Internal( - "Unexpected output when collecting for count()".to_string(), - ))? as usize; + .ok_or_else(|| { + internal_datafusion_err!("Unexpected output when collecting for count()") + })? as usize; Ok(len) } @@ -1366,8 +1393,47 @@ impl DataFrame { /// # } /// ``` pub async fn show(self) -> Result<()> { + println!("{}", self.to_string().await?); + Ok(()) + } + + /// Execute the `DataFrame` and return a string representation of the results. + /// + /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion::execution::SessionStateBuilder; + /// + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let cfg = SessionConfig::new() + /// .set_str("datafusion.format.null", "no-value"); + /// let session_state = SessionStateBuilder::new() + /// .with_config(cfg) + /// .with_default_features() + /// .build(); + /// let ctx = SessionContext::new_with_state(session_state); + /// let df = ctx.sql("select null as 'null-column'").await?; + /// let result = df.to_string().await?; + /// assert_eq!(result, + /// "+-------------+ + /// | null-column | + /// +-------------+ + /// | no-value | + /// +-------------+" + /// ); + /// # Ok(()) + /// # } + pub async fn to_string(self) -> Result { + let options = self.session_state.config().options().format.clone(); + let arrow_options: arrow::util::display::FormatOptions = (&options).try_into()?; + let results = self.collect().await?; - Ok(pretty::print_batches(&results)?) + Ok( + pretty::pretty_format_batches_with_options(&results, &arrow_options)? + .to_string(), + ) } /// Execute the `DataFrame` and print only the first `num` rows of the @@ -1538,6 +1604,8 @@ impl DataFrame { /// Return a DataFrame with the explanation of its plan so far. /// /// if `analyze` is specified, runs the plan and reports metrics + /// if `verbose` is true, prints out additional details. + /// The default format is Indent format. /// /// ``` /// # use datafusion::prelude::*; @@ -1551,11 +1619,38 @@ impl DataFrame { /// # } /// ``` pub fn explain(self, verbose: bool, analyze: bool) -> Result { + // Set the default format to Indent to keep the previous behavior + let opts = ExplainOption::default() + .with_verbose(verbose) + .with_analyze(analyze); + self.explain_with_options(opts) + } + + /// Return a DataFrame with the explanation of its plan so far. + /// + /// `opt` is used to specify the options for the explain operation. + /// Details of the options can be found in [`ExplainOption`]. + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// use datafusion_expr::{Explain, ExplainOption}; + /// let ctx = SessionContext::new(); + /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let batches = df.limit(0, Some(100))?.explain_with_options(ExplainOption::default().with_verbose(false).with_analyze(false))?.collect().await?; + /// # Ok(()) + /// # } + /// ``` + pub fn explain_with_options( + self, + explain_option: ExplainOption, + ) -> Result { if matches!(self.plan, LogicalPlan::Explain(_)) { return plan_err!("Nested EXPLAINs are not supported"); } let plan = LogicalPlanBuilder::from(self.plan) - .explain(verbose, analyze)? + .explain_option_format(explain_option)? .build()?; Ok(DataFrame { session_state: self.session_state, @@ -1617,6 +1712,40 @@ impl DataFrame { }) } + /// Calculate the distinct intersection of two [`DataFrame`]s. The two [`DataFrame`]s must have exactly the same schema + /// + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let d2 = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; + /// let df = df.intersect_distinct(d2)?; + /// let expected = vec![ + /// "+---+---+---+", + /// "| a | b | c |", + /// "+---+---+---+", + /// "| 1 | 2 | 3 |", + /// "+---+---+---+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &df.collect().await?); + /// # Ok(()) + /// # } + /// ``` + pub fn intersect_distinct(self, dataframe: DataFrame) -> Result { + let left_plan = self.plan; + let right_plan = dataframe.plan; + let plan = LogicalPlanBuilder::intersect(left_plan, right_plan, false)?; + Ok(DataFrame { + session_state: self.session_state, + plan, + projection_requires_validation: true, + }) + } + /// Calculate the exception of two [`DataFrame`]s. The two [`DataFrame`]s must have exactly the same schema /// /// ``` @@ -1653,6 +1782,42 @@ impl DataFrame { }) } + /// Calculate the distinct exception of two [`DataFrame`]s. The two [`DataFrame`]s must have exactly the same schema + /// + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; + /// let d2 = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let result = df.except_distinct(d2)?; + /// // those columns are not in example.csv, but in example_long.csv + /// let expected = vec![ + /// "+---+---+---+", + /// "| a | b | c |", + /// "+---+---+---+", + /// "| 4 | 5 | 6 |", + /// "| 7 | 8 | 9 |", + /// "+---+---+---+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &result.collect().await?); + /// # Ok(()) + /// # } + /// ``` + pub fn except_distinct(self, dataframe: DataFrame) -> Result { + let left_plan = self.plan; + let right_plan = dataframe.plan; + let plan = LogicalPlanBuilder::except(left_plan, right_plan, false)?; + Ok(DataFrame { + session_state: self.session_state, + plan, + projection_requires_validation: true, + }) + } + /// Execute this `DataFrame` and write the results to `table_name`. /// /// Returns a single [RecordBatch] containing a single column and @@ -1856,33 +2021,40 @@ impl DataFrame { /// # } /// ``` pub fn with_column(self, name: &str, expr: Expr) -> Result { - let window_func_exprs = find_window_exprs(std::slice::from_ref(&expr)); + let window_func_exprs = find_window_exprs([&expr]); - let (window_fn_str, plan) = if window_func_exprs.is_empty() { - (None, self.plan) + let original_names: HashSet = self + .plan + .schema() + .iter() + .map(|(_, f)| f.name().clone()) + .collect(); + + // Maybe build window plan + let plan = if window_func_exprs.is_empty() { + self.plan } else { - ( - Some(window_func_exprs[0].to_string()), - LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?, - ) + LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)? }; - let mut col_exists = false; let new_column = expr.alias(name); + let mut col_exists = false; + let mut fields: Vec<(Expr, bool)> = plan .schema() .iter() .filter_map(|(qualifier, field)| { + // Skip new fields introduced by window_plan + if !original_names.contains(field.name()) { + return None; + } + if field.name() == name { col_exists = true; Some((new_column.clone(), true)) } else { let e = col(Column::from((qualifier, field))); - window_fn_str - .as_ref() - .filter(|s| *s == &e.to_string()) - .is_none() - .then_some((e, self.projection_requires_validation)) + Some((e, self.projection_requires_validation)) } }) .collect(); @@ -1943,10 +2115,11 @@ impl DataFrame { match self.plan.schema().qualified_field_from_column(&old_column) { Ok(qualifier_and_field) => qualifier_and_field, // no-op if field not found - Err(DataFusionError::SchemaError( - SchemaError::FieldNotFound { .. }, - _, - )) => return Ok(self), + Err(DataFusionError::SchemaError(e, _)) + if matches!(*e, SchemaError::FieldNotFound { .. }) => + { + return Ok(self); + } Err(err) => return Err(err), }; let projection = self @@ -2160,6 +2333,94 @@ impl DataFrame { }) .collect() } + + /// Helper for creating DataFrame. + /// # Example + /// ``` + /// use std::sync::Arc; + /// use arrow::array::{ArrayRef, Int32Array, StringArray}; + /// use datafusion::prelude::DataFrame; + /// let id: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + /// let name: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar", "baz"])); + /// let df = DataFrame::from_columns(vec![("id", id), ("name", name)]).unwrap(); + /// // +----+------+, + /// // | id | name |, + /// // +----+------+, + /// // | 1 | foo |, + /// // | 2 | bar |, + /// // | 3 | baz |, + /// // +----+------+, + /// ``` + pub fn from_columns(columns: Vec<(&str, ArrayRef)>) -> Result { + let fields = columns + .iter() + .map(|(name, array)| Field::new(*name, array.data_type().clone(), true)) + .collect::>(); + + let arrays = columns + .into_iter() + .map(|(_, array)| array) + .collect::>(); + + let schema = Arc::new(Schema::new(fields)); + let batch = RecordBatch::try_new(schema, arrays)?; + let ctx = SessionContext::new(); + let df = ctx.read_batch(batch)?; + Ok(df) + } +} + +/// Macro for creating DataFrame. +/// # Example +/// ``` +/// use datafusion::prelude::dataframe; +/// # use datafusion::error::Result; +/// # #[tokio::main] +/// # async fn main() -> Result<()> { +/// let df = dataframe!( +/// "id" => [1, 2, 3], +/// "name" => ["foo", "bar", "baz"] +/// )?; +/// df.show().await?; +/// // +----+------+, +/// // | id | name |, +/// // +----+------+, +/// // | 1 | foo |, +/// // | 2 | bar |, +/// // | 3 | baz |, +/// // +----+------+, +/// let df_empty = dataframe!()?; // empty DataFrame +/// assert_eq!(df_empty.schema().fields().len(), 0); +/// assert_eq!(df_empty.count().await?, 0); +/// # Ok(()) +/// # } +/// ``` +#[macro_export] +macro_rules! dataframe { + () => {{ + use std::sync::Arc; + + use datafusion::prelude::SessionContext; + use datafusion::arrow::array::RecordBatch; + use datafusion::arrow::datatypes::Schema; + + let ctx = SessionContext::new(); + let batch = RecordBatch::new_empty(Arc::new(Schema::empty())); + ctx.read_batch(batch) + }}; + + ($($name:expr => $data:expr),+ $(,)?) => {{ + use datafusion::prelude::DataFrame; + use datafusion::common::test_util::IntoArrayRef; + + let columns = vec![ + $( + ($name, $data.into_array_ref()), + )+ + ]; + + DataFrame::from_columns(columns) + }}; } #[derive(Debug)] @@ -2173,7 +2434,7 @@ impl TableProvider for DataFrameTableProvider { self } - fn get_logical_plan(&self) -> Option> { + fn get_logical_plan(&self) -> Option> { Some(Cow::Borrowed(&self.plan)) } @@ -2186,8 +2447,7 @@ impl TableProvider for DataFrameTableProvider { } fn schema(&self) -> SchemaRef { - let schema: Schema = self.plan.schema().as_ref().into(); - Arc::new(schema) + Arc::clone(self.plan.schema().inner()) } fn table_type(&self) -> TableType { diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index 1bb5444ca009f..d46a902ca5139 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -205,7 +205,7 @@ mod tests { &HashMap::from_iter( [("datafusion.execution.batch_size", "10")] .iter() - .map(|(s1, s2)| (s1.to_string(), s2.to_string())), + .map(|(s1, s2)| ((*s1).to_string(), (*s2).to_string())), ), )?); register_aggregate_csv(&ctx, "aggregate_test_100").await?; @@ -246,4 +246,77 @@ mod tests { Ok(()) } + + #[rstest::rstest] + #[cfg(feature = "parquet_encryption")] + #[tokio::test] + async fn roundtrip_parquet_with_encryption( + #[values(false, true)] allow_single_file_parallelism: bool, + ) -> Result<()> { + use parquet::encryption::decrypt::FileDecryptionProperties; + use parquet::encryption::encrypt::FileEncryptionProperties; + + let test_df = test_util::test_table().await?; + + let schema = test_df.schema(); + let footer_key = b"0123456789012345".to_vec(); // 128bit/16 + let column_key = b"1234567890123450".to_vec(); // 128bit/16 + + let mut encrypt = FileEncryptionProperties::builder(footer_key.clone()); + let mut decrypt = FileDecryptionProperties::builder(footer_key.clone()); + + for field in schema.fields().iter() { + encrypt = encrypt.with_column_key(field.name().as_str(), column_key.clone()); + decrypt = decrypt.with_column_key(field.name().as_str(), column_key.clone()); + } + + let encrypt = encrypt.build()?; + let decrypt = decrypt.build()?; + + let df = test_df.clone(); + let tmp_dir = TempDir::new()?; + let tempfile = tmp_dir.path().join("roundtrip.parquet"); + let tempfile_str = tempfile.into_os_string().into_string().unwrap(); + + // Write encrypted parquet using write_parquet + let mut options = TableParquetOptions::default(); + options.crypto.file_encryption = Some((&encrypt).into()); + options.global.allow_single_file_parallelism = allow_single_file_parallelism; + + df.write_parquet( + tempfile_str.as_str(), + DataFrameWriteOptions::new().with_single_file_output(true), + Some(options), + ) + .await?; + let num_rows_written = test_df.count().await?; + + // Read encrypted parquet + let ctx: SessionContext = SessionContext::new(); + let read_options = + ParquetReadOptions::default().file_decryption_properties((&decrypt).into()); + + ctx.register_parquet("roundtrip_parquet", &tempfile_str, read_options.clone()) + .await?; + + let df_enc = ctx.sql("SELECT * FROM roundtrip_parquet").await?; + let num_rows_read = df_enc.count().await?; + + assert_eq!(num_rows_read, num_rows_written); + + // Read encrypted parquet and subset rows + columns + let encrypted_parquet_df = ctx.read_parquet(tempfile_str, read_options).await?; + + // Select three columns and filter the results + // Test that the filter works as expected + let selected = encrypted_parquet_df + .clone() + .select_columns(&["c1", "c2", "c3"])? + .filter(col("c2").gt(lit(4)))?; + + let num_rows_selected = selected.count().await?; + assert_eq!(num_rows_selected, 14); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 6c7c9463cf3b7..25bc166d657a5 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -27,7 +27,7 @@ use std::sync::Arc; use super::file_compression_type::FileCompressionType; use super::write::demux::DemuxedStreamReceiver; -use super::write::{create_writer, SharedBuffer}; +use super::write::SharedBuffer; use super::FileFormatFactory; use crate::datasource::file_format::write::get_writer_schema; use crate::datasource::file_format::FileFormat; @@ -44,16 +44,17 @@ use arrow::ipc::{root_as_message, CompressionType}; use datafusion_catalog::Session; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - not_impl_err, DataFusionError, GetExt, Statistics, DEFAULT_ARROW_EXTENSION, + internal_datafusion_err, not_impl_err, DataFusionError, GetExt, Statistics, + DEFAULT_ARROW_EXTENSION, }; use datafusion_common_runtime::{JoinSet, SpawnedTask}; use datafusion_datasource::display::FileGroupDisplay; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::sink::{DataSink, DataSinkExec}; +use datafusion_datasource::write::ObjectWriterBuilder; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexRequirement; use async_trait::async_trait; @@ -128,12 +129,16 @@ impl FileFormat for ArrowFormat { let ext = self.get_ext(); match file_compression_type.get_variant() { CompressionTypeVariant::UNCOMPRESSED => Ok(ext), - _ => Err(DataFusionError::Internal( - "Arrow FileFormat does not support compression.".into(), + _ => Err(internal_datafusion_err!( + "Arrow FileFormat does not support compression." )), } } + fn compression_type(&self) -> Option { + None + } + async fn infer_schema( &self, _state: &dyn Session, @@ -144,6 +149,7 @@ impl FileFormat for ArrowFormat { for object in objects { let r = store.as_ref().get(&object.location).await?; let schema = match r.payload { + #[cfg(not(target_arch = "wasm32"))] GetResultPayload::File(mut file, _) => { let reader = FileReader::try_new(&mut file, None)?; reader.schema() @@ -172,7 +178,6 @@ impl FileFormat for ArrowFormat { &self, _state: &dyn Session, conf: FileScanConfig, - _filters: Option<&Arc>, ) -> Result> { let source = Arc::new(ArrowSource::default()); let config = FileScanConfigBuilder::from(conf) @@ -222,7 +227,7 @@ impl FileSink for ArrowFileSink { async fn spawn_writer_tasks_and_join( &self, - _context: &Arc, + context: &Arc, demux_task: SpawnedTask>, mut file_stream_rx: DemuxedStreamReceiver, object_store: Arc, @@ -240,12 +245,19 @@ impl FileSink for ArrowFileSink { &get_writer_schema(&self.config), ipc_options.clone(), )?; - let mut object_store_writer = create_writer( + let mut object_store_writer = ObjectWriterBuilder::new( FileCompressionType::UNCOMPRESSED, &path, Arc::clone(&object_store), ) - .await?; + .with_buffer_size(Some( + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + )) + .build()?; file_write_tasks.spawn(async move { let mut row_count = 0; while let Some(batch) = rx.recv().await { @@ -287,7 +299,7 @@ impl FileSink for ArrowFileSink { demux_task .join_unwind() .await - .map_err(DataFusionError::ExecutionJoin)??; + .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; Ok(row_count as u64) } } @@ -442,7 +454,7 @@ mod tests { let object_meta = ObjectMeta { location, last_modified: DateTime::default(), - size: usize::MAX, + size: u64::MAX, e_tag: None, version: None, }; @@ -485,7 +497,7 @@ mod tests { let object_meta = ObjectMeta { location, last_modified: DateTime::default(), - size: usize::MAX, + size: u64::MAX, e_tag: None, version: None, }; @@ -504,7 +516,7 @@ mod tests { assert!(err.is_err()); assert_eq!( "Arrow error: Parser error: Unexpected end of byte stream for Arrow IPC file", - err.unwrap_err().to_string() + err.unwrap_err().to_string().lines().next().unwrap() ); Ok(()) diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index a9516aad9e22d..3428d08a6ae52 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -382,6 +382,15 @@ mod tests { let testdata = test_util::arrow_test_data(); let store_root = format!("{testdata}/avro"); let format = AvroFormat {}; - scan_format(state, &format, &store_root, file_name, projection, limit).await + scan_format( + state, + &format, + None, + &store_root, + file_name, + projection, + limit, + ) + .await } } diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 309458975ab6c..52fb8ae904ebf 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -33,6 +33,7 @@ mod tests { use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_catalog::Session; use datafusion_common::cast::as_string_array; + use datafusion_common::config::CsvOptions; use datafusion_common::internal_err; use datafusion_common::stats::Precision; use datafusion_common::test_util::{arrow_test_data, batches_to_string}; @@ -47,7 +48,7 @@ mod tests { use datafusion_physical_plan::{collect, ExecutionPlan}; use arrow::array::{ - BooleanArray, Float64Array, Int32Array, RecordBatch, StringArray, + Array, BooleanArray, Float64Array, Int32Array, RecordBatch, StringArray, }; use arrow::compute::concat_batches; use arrow::csv::ReaderBuilder; @@ -55,14 +56,16 @@ mod tests { use async_trait::async_trait; use bytes::Bytes; use chrono::DateTime; + use datafusion_common::parsers::CompressionTypeVariant; use futures::stream::BoxStream; use futures::StreamExt; use insta::assert_snapshot; + use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; use object_store::path::Path; use object_store::{ Attributes, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, - ObjectMeta, ObjectStore, PutMultipartOpts, PutOptions, PutPayload, PutResult, + ObjectMeta, ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, }; use regex::Regex; use rstest::*; @@ -72,7 +75,7 @@ mod tests { #[derive(Debug)] struct VariableStream { bytes_to_repeat: Bytes, - max_iterations: usize, + max_iterations: u64, iterations_detected: Arc>, } @@ -96,21 +99,30 @@ mod tests { async fn put_multipart_opts( &self, _location: &Path, - _opts: PutMultipartOpts, + _opts: PutMultipartOptions, ) -> object_store::Result> { unimplemented!() } async fn get(&self, location: &Path) -> object_store::Result { + self.get_opts(location, GetOptions::default()).await + } + + async fn get_opts( + &self, + location: &Path, + _opts: GetOptions, + ) -> object_store::Result { let bytes = self.bytes_to_repeat.clone(); - let range = 0..bytes.len() * self.max_iterations; + let len = bytes.len() as u64; + let range = 0..len * self.max_iterations; let arc = self.iterations_detected.clone(); let stream = futures::stream::repeat_with(move || { let arc_inner = arc.clone(); *arc_inner.lock().unwrap() += 1; Ok(bytes.clone()) }) - .take(self.max_iterations) + .take(self.max_iterations as usize) .boxed(); Ok(GetResult { @@ -127,18 +139,10 @@ mod tests { }) } - async fn get_opts( - &self, - _location: &Path, - _opts: GetOptions, - ) -> object_store::Result { - unimplemented!() - } - async fn get_ranges( &self, _location: &Path, - _ranges: &[Range], + _ranges: &[Range], ) -> object_store::Result> { unimplemented!() } @@ -154,7 +158,7 @@ mod tests { fn list( &self, _prefix: Option<&Path>, - ) -> BoxStream<'_, object_store::Result> { + ) -> BoxStream<'static, object_store::Result> { unimplemented!() } @@ -179,7 +183,7 @@ mod tests { } impl VariableStream { - pub fn new(bytes_to_repeat: Bytes, max_iterations: usize) -> Self { + pub fn new(bytes_to_repeat: Bytes, max_iterations: u64) -> Self { Self { bytes_to_repeat, max_iterations, @@ -216,8 +220,11 @@ mod tests { assert_eq!(tt_batches, 50 /* 100/2 */); // test metadata - assert_eq!(exec.statistics()?.num_rows, Precision::Absent); - assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); + assert_eq!(exec.partition_statistics(None)?.num_rows, Precision::Absent); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Absent + ); Ok(()) } @@ -249,6 +256,7 @@ mod tests { let exec = scan_format( &state, &format, + None, root, "aggregate_test_100_with_nulls.csv", projection, @@ -299,6 +307,7 @@ mod tests { let exec = scan_format( &state, &format, + None, root, "aggregate_test_100_with_nulls.csv", projection, @@ -371,7 +380,7 @@ mod tests { let object_meta = ObjectMeta { location: Path::parse("/")?, last_modified: DateTime::default(), - size: usize::MAX, + size: u64::MAX, e_tag: None, version: None, }; @@ -429,7 +438,7 @@ mod tests { let object_meta = ObjectMeta { location: Path::parse("/")?, last_modified: DateTime::default(), - size: usize::MAX, + size: u64::MAX, e_tag: None, version: None, }; @@ -462,6 +471,59 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_infer_schema_stream_null_chunks() -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + // a stream where each line is read as a separate chunk, + // data type for each chunk is inferred separately. + // +----+-----+----+ + // | c1 | c2 | c3 | + // +----+-----+----+ + // | 1 | 1.0 | | type: Int64, Float64, Null + // | | | | type: Null, Null, Null + // +----+-----+----+ + let chunked_object_store = Arc::new(ChunkedStore::new( + Arc::new(VariableStream::new( + Bytes::from( + r#"c1,c2,c3 +1,1.0, +,, +"#, + ), + 1, + )), + 1, + )); + let object_meta = ObjectMeta { + location: Path::parse("/")?, + last_modified: DateTime::default(), + size: u64::MAX, + e_tag: None, + version: None, + }; + + let csv_format = CsvFormat::default().with_has_header(true); + let inferred_schema = csv_format + .infer_schema( + &state, + &(chunked_object_store as Arc), + &[object_meta], + ) + .await?; + + let actual_fields: Vec<_> = inferred_schema + .fields() + .iter() + .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .collect(); + + // ensure null chunks don't skew type inference + assert_eq!(vec!["c1: Int64", "c2: Float64", "c3: Null"], actual_fields); + Ok(()) + } + #[rstest( file_compression_type, case(FileCompressionType::UNCOMPRESSED), @@ -581,7 +643,7 @@ mod tests { ) -> Result> { let root = format!("{}/csv", arrow_test_data()); let format = CsvFormat::default().with_has_header(has_header); - scan_format(state, &format, &root, file_name, projection, limit).await + scan_format(state, &format, None, &root, file_name, projection, limit).await } #[tokio::test] @@ -814,6 +876,128 @@ mod tests { Ok(()) } + /// Read multiple csv files (some are empty) with header + /// + /// some_empty_with_header + /// ├── a_empty.csv + /// ├── b.csv + /// └── c_nulls_column.csv + /// + /// a_empty.csv: + /// c1,c2,c3 + /// + /// b.csv: + /// c1,c2,c3 + /// 1,1,1 + /// 2,2,2 + /// + /// c_nulls_column.csv: + /// c1,c2,c3 + /// 3,3, + #[tokio::test] + async fn test_csv_some_empty_with_header() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_csv( + "some_empty_with_header", + "tests/data/empty_files/some_empty_with_header", + CsvReadOptions::new().has_header(true), + ) + .await?; + + let query = "select sum(c3) from some_empty_with_header;"; + let query_result = ctx.sql(query).await?.collect().await?; + + assert_snapshot!(batches_to_string(&query_result),@r" + +--------------------------------+ + | sum(some_empty_with_header.c3) | + +--------------------------------+ + | 3 | + +--------------------------------+ + "); + + Ok(()) + } + + #[tokio::test] + async fn test_csv_extension_compressed() -> Result<()> { + // Write compressed CSV files + // Expect: under the directory, a file is created with ".csv.gz" extension + let ctx = SessionContext::new(); + + let df = ctx + .read_csv( + &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()), + CsvReadOptions::default().has_header(true), + ) + .await?; + + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = format!("{}", tmp_dir.path().to_string_lossy()); + + let cfg1 = crate::dataframe::DataFrameWriteOptions::new(); + let cfg2 = CsvOptions::default() + .with_has_header(true) + .with_compression(CompressionTypeVariant::GZIP); + + df.write_csv(&path, cfg1, Some(cfg2)).await?; + assert!(std::path::Path::new(&path).exists()); + + let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect(); + assert_eq!(files.len(), 1); + assert!(files + .last() + .unwrap() + .as_ref() + .unwrap() + .path() + .file_name() + .unwrap() + .to_str() + .unwrap() + .ends_with(".csv.gz")); + + Ok(()) + } + + #[tokio::test] + async fn test_csv_extension_uncompressed() -> Result<()> { + // Write plain uncompressed CSV files + // Expect: under the directory, a file is created with ".csv" extension + let ctx = SessionContext::new(); + + let df = ctx + .read_csv( + &format!("{}/csv/aggregate_test_100.csv", arrow_test_data()), + CsvReadOptions::default().has_header(true), + ) + .await?; + + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = format!("{}", tmp_dir.path().to_string_lossy()); + + let cfg1 = crate::dataframe::DataFrameWriteOptions::new(); + let cfg2 = CsvOptions::default().with_has_header(true); + + df.write_csv(&path, cfg1, Some(cfg2)).await?; + assert!(std::path::Path::new(&path).exists()); + + let files: Vec<_> = std::fs::read_dir(&path).unwrap().collect(); + assert_eq!(files.len(), 1); + assert!(files + .last() + .unwrap() + .as_ref() + .unwrap() + .path() + .file_name() + .unwrap() + .to_str() + .unwrap() + .ends_with(".csv")); + + Ok(()) + } + /// Read multiple empty csv files /// /// all_empty @@ -1020,7 +1204,7 @@ mod tests { for _ in 0..batch_count { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { - panic!("Expected RecordBatch, got {:?}", output); + panic!("Expected RecordBatch, got {output:?}"); }; all_batches = concat_batches(&schema, &[all_batches, batch])?; } @@ -1058,7 +1242,7 @@ mod tests { for _ in 0..batch_count { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { - panic!("Expected RecordBatch, got {:?}", output); + panic!("Expected RecordBatch, got {output:?}"); }; all_batches = concat_batches(&schema, &[all_batches, batch])?; } @@ -1139,18 +1323,14 @@ mod tests { fn csv_line(line_number: usize) -> Bytes { let (int_value, float_value, bool_value, char_value) = csv_values(line_number); - format!( - "{},{},{},{}\n", - int_value, float_value, bool_value, char_value - ) - .into() + format!("{int_value},{float_value},{bool_value},{char_value}\n").into() } fn csv_values(line_number: usize) -> (i32, f64, bool, String) { let int_value = line_number as i32; let float_value = line_number as f64; - let bool_value = line_number % 2 == 0; - let char_value = format!("{}-string", line_number); + let bool_value = line_number.is_multiple_of(2); + let char_value = format!("{line_number}-string"); (int_value, float_value, bool_value, char_value) } @@ -1172,4 +1352,181 @@ mod tests { .build_decoder(); DecoderDeserializer::new(CsvDecoder::new(decoder)) } + + fn csv_deserializer_with_truncated( + batch_size: usize, + schema: &Arc, + ) -> impl BatchDeserializer { + // using Arrow's ReaderBuilder and enabling truncated_rows + let decoder = ReaderBuilder::new(schema.clone()) + .with_batch_size(batch_size) + .with_truncated_rows(true) // <- enable runtime truncated_rows + .build_decoder(); + DecoderDeserializer::new(CsvDecoder::new(decoder)) + } + + #[tokio::test] + async fn infer_schema_with_truncated_rows_true() -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + // CSV: header has 3 columns, but first data row has only 2 columns, second row has 3 + let csv_data = Bytes::from("a,b,c\n1,2\n3,4,5\n"); + let variable_object_store = Arc::new(VariableStream::new(csv_data, 1)); + let object_meta = ObjectMeta { + location: Path::parse("/")?, + last_modified: DateTime::default(), + size: u64::MAX, + e_tag: None, + version: None, + }; + + // Construct CsvFormat and enable truncated_rows via CsvOptions + let csv_options = CsvOptions::default().with_truncated_rows(true); + let csv_format = CsvFormat::default() + .with_has_header(true) + .with_options(csv_options) + .with_schema_infer_max_rec(10); + + let inferred_schema = csv_format + .infer_schema( + &state, + &(variable_object_store.clone() as Arc), + &[object_meta], + ) + .await?; + + // header has 3 columns; inferred schema should also have 3 + assert_eq!(inferred_schema.fields().len(), 3); + + // inferred columns should be nullable + for f in inferred_schema.fields() { + assert!(f.is_nullable()); + } + + Ok(()) + } + #[test] + fn test_decoder_truncated_rows_runtime() -> Result<()> { + // Synchronous test: Decoder API used here is synchronous + let schema = csv_schema(); // helper already defined in file + + // Construct a decoder that enables truncated_rows at runtime + let mut deserializer = csv_deserializer_with_truncated(10, &schema); + + // Provide two rows: first row complete, second row missing last column + let input = Bytes::from("0,0.0,true,0-string\n1,1.0,true\n"); + deserializer.digest(input); + + // Finish and collect output + deserializer.finish(); + + let output = deserializer.next()?; + match output { + DeserializerOutput::RecordBatch(batch) => { + // ensure at least two rows present + assert!(batch.num_rows() >= 2); + // column 4 (index 3) should be a StringArray where second row is NULL + let col4 = batch + .column(3) + .as_any() + .downcast_ref::() + .expect("column 4 should be StringArray"); + + // first row present, second row should be null + assert!(!col4.is_null(0)); + assert!(col4.is_null(1)); + } + other => panic!("expected RecordBatch but got {other:?}"), + } + Ok(()) + } + + #[tokio::test] + async fn infer_schema_truncated_rows_false_error() -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + // CSV: header has 4 cols, first data row has 3 cols -> truncated at end + let csv_data = Bytes::from("id,a,b,c\n1,foo,bar\n2,foo,bar,baz\n"); + let variable_object_store = Arc::new(VariableStream::new(csv_data, 1)); + let object_meta = ObjectMeta { + location: Path::parse("/")?, + last_modified: DateTime::default(), + size: u64::MAX, + e_tag: None, + version: None, + }; + + // CsvFormat without enabling truncated_rows (default behavior = false) + let csv_format = CsvFormat::default() + .with_has_header(true) + .with_schema_infer_max_rec(10); + + let res = csv_format + .infer_schema( + &state, + &(variable_object_store.clone() as Arc), + &[object_meta], + ) + .await; + + // Expect an error due to unequal lengths / incorrect number of fields + assert!( + res.is_err(), + "expected infer_schema to error on truncated rows when disabled" + ); + + // Optional: check message contains indicative text (two known possibilities) + if let Err(err) = res { + let msg = format!("{err}"); + assert!( + msg.contains("Encountered unequal lengths") + || msg.contains("incorrect number of fields"), + "unexpected error message: {msg}", + ); + } + + Ok(()) + } + + #[tokio::test] + async fn test_read_csv_truncated_rows_via_tempfile() -> Result<()> { + use std::io::Write; + + // create a SessionContext + let ctx = SessionContext::new(); + + // Create a temp file with a .csv suffix so the reader accepts it + let mut tmp = tempfile::Builder::new().suffix(".csv").tempfile()?; // ensures path ends with .csv + // CSV has header "a,b,c". First data row is truncated (only "1,2"), second row is complete. + write!(tmp, "a,b,c\n1,2\n3,4,5\n")?; + let path = tmp.path().to_str().unwrap().to_string(); + + // Build CsvReadOptions: header present, enable truncated_rows. + // (Use the exact builder method your crate exposes: `truncated_rows(true)` here, + // if the method name differs in your codebase use the appropriate one.) + let options = CsvReadOptions::default().truncated_rows(true); + + println!("options: {}, path: {path}", options.truncated_rows); + + // Call the API under test + let df = ctx.read_csv(&path, options).await?; + + // Collect the results and combine batches so we can inspect columns + let batches = df.collect().await?; + let combined = concat_batches(&batches[0].schema(), &batches)?; + + // Column 'c' is the 3rd column (index 2). The first data row was truncated -> should be NULL. + let col_c = combined.column(2); + assert!( + col_c.is_null(0), + "expected first row column 'c' to be NULL due to truncated row" + ); + + // Also ensure we read at least one row + assert!(combined.num_rows() >= 2); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index d533dcf7646da..34d3d64f07fb2 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -75,8 +75,11 @@ mod tests { assert_eq!(tt_batches, 6 /* 12/2 */); // test metadata - assert_eq!(exec.statistics()?.num_rows, Precision::Absent); - assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); + assert_eq!(exec.partition_statistics(None)?.num_rows, Precision::Absent); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Absent + ); Ok(()) } @@ -149,7 +152,7 @@ mod tests { ) -> Result> { let filename = "tests/data/2.json"; let format = JsonFormat::default(); - scan_format(state, &format, ".", filename, projection, limit).await + scan_format(state, &format, None, ".", filename, projection, limit).await } #[tokio::test] @@ -275,7 +278,7 @@ mod tests { for _ in 0..3 { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { - panic!("Expected RecordBatch, got {:?}", output); + panic!("Expected RecordBatch, got {output:?}"); }; all_batches = concat_batches(&schema, &[all_batches, batch])? } @@ -315,7 +318,7 @@ mod tests { for _ in 0..2 { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { - panic!("Expected RecordBatch, got {:?}", output); + panic!("Expected RecordBatch, got {output:?}"); }; all_batches = concat_batches(&schema, &[all_batches, batch])? } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index e921f0158e540..e165707c2eb0e 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -36,19 +36,20 @@ pub use datafusion_datasource::write; #[cfg(test)] pub(crate) mod test_util { - use std::sync::Arc; - + use arrow_schema::SchemaRef; use datafusion_catalog::Session; use datafusion_common::Result; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::{file_format::FileFormat, PartitionedFile}; use datafusion_execution::object_store::ObjectStoreUrl; + use std::sync::Arc; use crate::test::object_store::local_unpartitioned_file; pub async fn scan_format( state: &dyn Session, format: &dyn FileFormat, + schema: Option, store_root: &str, file_name: &str, projection: Option>, @@ -57,9 +58,13 @@ pub(crate) mod test_util { let store = Arc::new(object_store::local::LocalFileSystem::new()) as _; let meta = local_unpartitioned_file(format!("{store_root}/{file_name}")); - let file_schema = format - .infer_schema(state, &store, std::slice::from_ref(&meta)) - .await?; + let file_schema = if let Some(file_schema) = schema { + file_schema + } else { + format + .infer_schema(state, &store, std::slice::from_ref(&meta)) + .await? + }; let statistics = format .infer_stats(state, &store, file_schema.clone(), &meta) @@ -88,7 +93,6 @@ pub(crate) mod test_util { .with_projection(projection) .with_limit(limit) .build(), - None, ) .await?; Ok(exec) @@ -127,7 +131,7 @@ mod tests { .write_parquet(out_dir_url, DataFrameWriteOptions::new(), None) .await .expect_err("should fail because input file does not match inferred schema"); - assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value d for column 0 at line 4"); + assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'"); Ok(()) } } diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 08e9a628dd611..8c1bb02ef0737 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -34,7 +34,7 @@ use crate::error::Result; use crate::execution::context::{SessionConfig, SessionState}; use arrow::datatypes::{DataType, Schema, SchemaRef}; -use datafusion_common::config::TableOptions; +use datafusion_common::config::{ConfigFileDecryptionProperties, TableOptions}; use datafusion_common::{ DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION, DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, @@ -91,6 +91,11 @@ pub struct CsvReadOptions<'a> { pub file_sort_order: Vec>, /// Optional regex to match null values pub null_regex: Option, + /// Whether to allow truncated rows when parsing. + /// By default this is set to false and will error if the CSV rows have different lengths. + /// When set to true then it will allow records with less than the expected number of columns and fill the missing columns with nulls. + /// If the record’s schema is not nullable, then it will still return an error. + pub truncated_rows: bool, } impl Default for CsvReadOptions<'_> { @@ -117,6 +122,7 @@ impl<'a> CsvReadOptions<'a> { file_sort_order: vec![], comment: None, null_regex: None, + truncated_rows: false, } } @@ -223,6 +229,15 @@ impl<'a> CsvReadOptions<'a> { self.null_regex = null_regex; self } + + /// Configure whether to allow truncated rows when parsing. + /// By default this is set to false and will error if the CSV rows have different lengths + /// When set to true then it will allow records with less than the expected number of columns and fill the missing columns with nulls. + /// If the record’s schema is not nullable, then it will still return an error. + pub fn truncated_rows(mut self, truncated_rows: bool) -> Self { + self.truncated_rows = truncated_rows; + self + } } /// Options that control the reading of Parquet files. @@ -252,6 +267,8 @@ pub struct ParquetReadOptions<'a> { pub schema: Option<&'a Schema>, /// Indicates how the file is sorted pub file_sort_order: Vec>, + /// Properties for decryption of Parquet files that use modular encryption + pub file_decryption_properties: Option, } impl Default for ParquetReadOptions<'_> { @@ -263,6 +280,7 @@ impl Default for ParquetReadOptions<'_> { skip_metadata: None, schema: None, file_sort_order: vec![], + file_decryption_properties: None, } } } @@ -313,6 +331,15 @@ impl<'a> ParquetReadOptions<'a> { self.file_sort_order = file_sort_order; self } + + /// Configure file decryption properties for reading encrypted Parquet files + pub fn file_decryption_properties( + mut self, + file_decryption_properties: ConfigFileDecryptionProperties, + ) -> Self { + self.file_decryption_properties = Some(file_decryption_properties); + self + } } /// Options that control the reading of ARROW files. @@ -546,11 +573,12 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { .with_newlines_in_values(self.newlines_in_values) .with_schema_infer_max_rec(self.schema_infer_max_records) .with_file_compression_type(self.file_compression_type.to_owned()) - .with_null_regex(self.null_regex.clone()); + .with_null_regex(self.null_regex.clone()) + .with_truncated_rows(self.truncated_rows); ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) + .with_session_config_options(config) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) } @@ -574,7 +602,11 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { config: &SessionConfig, table_options: TableOptions, ) -> ListingOptions { - let mut file_format = ParquetFormat::new().with_options(table_options.parquet); + let mut options = table_options.parquet; + if let Some(file_decryption_properties) = &self.file_decryption_properties { + options.crypto.file_decryption = Some(file_decryption_properties.clone()); + } + let mut file_format = ParquetFormat::new().with_options(options); if let Some(parquet_pruning) = self.parquet_pruning { file_format = file_format.with_enable_pruning(parquet_pruning) @@ -585,9 +617,9 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) + .with_session_config_options(config) } async fn get_resolved_schema( @@ -615,7 +647,7 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) + .with_session_config_options(config) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) } @@ -643,7 +675,7 @@ impl ReadOptions<'_> for AvroReadOptions<'_> { ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) + .with_session_config_options(config) .with_table_partition_cols(self.table_partition_cols.clone()) } @@ -669,7 +701,7 @@ impl ReadOptions<'_> for ArrowReadOptions<'_> { ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) + .with_session_config_options(config) .with_table_partition_cols(self.table_partition_cols.clone()) } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 27a7e7ae3c061..088c4408fff57 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -27,7 +27,10 @@ pub(crate) mod test_util { use crate::test::object_store::local_unpartitioned_file; - /// Writes `batches` to a temporary parquet file + /// Writes each `batch` to at least one temporary parquet file + /// + /// For example, if `batches` contains 2 batches, the function will create + /// 2 temporary files, each containing the contents of one batch /// /// If multi_page is set to `true`, the parquet file(s) are written /// with 2 rows per data page (used to test page filtering and @@ -52,7 +55,7 @@ pub(crate) mod test_util { } } - // we need the tmp files to be sorted as some tests rely on the how the returning files are ordered + // we need the tmp files to be sorted as some tests rely on the returned file ordering // https://github.com/apache/datafusion/pull/6629 let tmp_files = { let mut tmp_files: Vec<_> = (0..batches.len()) @@ -67,13 +70,13 @@ pub(crate) mod test_util { .into_iter() .zip(tmp_files.into_iter()) .map(|(batch, mut output)| { - let builder = parquet::file::properties::WriterProperties::builder(); - let props = if multi_page { - builder.set_data_page_row_count_limit(ROWS_PER_PAGE) - } else { - builder + let mut builder = parquet::file::properties::WriterProperties::builder(); + if multi_page { + builder = builder.set_data_page_row_count_limit(ROWS_PER_PAGE) } - .build(); + builder = builder.set_bloom_filter_enabled(true); + + let props = builder.build(); let mut writer = parquet::arrow::ArrowWriter::try_new( &mut output, @@ -104,10 +107,8 @@ pub(crate) mod test_util { mod tests { use std::fmt::{self, Display, Formatter}; - use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; - use std::task::{Context, Poll}; use std::time::Duration; use crate::datasource::file_format::parquet::test_util::store_parquet; @@ -117,7 +118,7 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use arrow::array::RecordBatch; - use arrow_schema::{Schema, SchemaRef}; + use arrow_schema::Schema; use datafusion_catalog::Session; use datafusion_common::cast::{ as_binary_array, as_binary_view_array, as_boolean_array, as_float32_array, @@ -132,16 +133,16 @@ mod tests { use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; use datafusion_datasource::{ListingTableUrl, PartitionedFile}; use datafusion_datasource_parquet::{ - fetch_parquet_metadata, fetch_statistics, statistics_from_parquet_meta_calc, ParquetFormat, ParquetFormatFactory, ParquetSink, }; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; - use datafusion_execution::{RecordBatchStream, TaskContext}; + use datafusion_execution::TaskContext; use datafusion_expr::dml::InsertOp; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::{collect, ExecutionPlan}; + use crate::test_util::bounded_stream; use arrow::array::{ types::Int32Type, Array, ArrayRef, DictionaryArray, Int32Array, Int64Array, StringArray, @@ -149,15 +150,16 @@ mod tests { use arrow::datatypes::{DataType, Field}; use async_trait::async_trait; use datafusion_datasource::file_groups::FileGroup; + use datafusion_datasource_parquet::metadata::DFParquetMetadata; use futures::stream::BoxStream; - use futures::{Stream, StreamExt}; + use futures::StreamExt; use insta::assert_snapshot; use log::error; use object_store::local::LocalFileSystem; use object_store::ObjectMeta; use object_store::{ path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectStore, - PutMultipartOpts, PutOptions, PutPayload, PutResult, + PutMultipartOptions, PutOptions, PutPayload, PutResult, }; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::ParquetRecordBatchStreamBuilder; @@ -177,8 +179,8 @@ mod tests { let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); - let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())]).unwrap(); - let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)]).unwrap(); + let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())])?; + let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)])?; let store = Arc::new(LocalFileSystem::new()) as _; let (meta, _files) = store_parquet(vec![batch1, batch2], false).await?; @@ -190,10 +192,14 @@ mod tests { ForceViews::No => false, }; let format = ParquetFormat::default().with_force_view_types(force_views); - let schema = format.infer_schema(&ctx, &store, &meta).await.unwrap(); + let schema = format.infer_schema(&ctx, &store, &meta).await?; - let stats = - fetch_statistics(store.as_ref(), schema.clone(), &meta[0], None).await?; + let file_metadata_cache = + ctx.runtime_env().cache_manager.get_file_metadata_cache(); + let stats = DFParquetMetadata::new(&store, &meta[0]) + .with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))) + .fetch_statistics(&schema) + .await?; assert_eq!(stats.num_rows, Precision::Exact(3)); let c1_stats = &stats.column_statistics[0]; @@ -201,7 +207,11 @@ mod tests { assert_eq!(c1_stats.null_count, Precision::Exact(1)); assert_eq!(c2_stats.null_count, Precision::Exact(3)); - let stats = fetch_statistics(store.as_ref(), schema, &meta[1], None).await?; + let stats = DFParquetMetadata::new(&store, &meta[1]) + .with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))) + .fetch_statistics(&schema) + .await?; + assert_eq!(stats.num_rows, Precision::Exact(3)); let c1_stats = &stats.column_statistics[0]; let c2_stats = &stats.column_statistics[1]; @@ -235,11 +245,9 @@ mod tests { let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); let batch1 = - RecordBatch::try_from_iter(vec![("a", c1.clone()), ("b", c1.clone())]) - .unwrap(); + RecordBatch::try_from_iter(vec![("a", c1.clone()), ("b", c1.clone())])?; let batch2 = - RecordBatch::try_from_iter(vec![("c", c2.clone()), ("d", c2.clone())]) - .unwrap(); + RecordBatch::try_from_iter(vec![("c", c2.clone()), ("d", c2.clone())])?; let store = Arc::new(LocalFileSystem::new()) as _; let (meta, _files) = store_parquet(vec![batch1, batch2], false).await?; @@ -247,7 +255,7 @@ mod tests { let session = SessionContext::new(); let ctx = session.state(); let format = ParquetFormat::default(); - let schema = format.infer_schema(&ctx, &store, &meta).await.unwrap(); + let schema = format.infer_schema(&ctx, &store, &meta).await?; let order: Vec<_> = ["a", "b", "c", "d"] .into_iter() @@ -306,7 +314,7 @@ mod tests { async fn put_multipart_opts( &self, _location: &Path, - _opts: PutMultipartOpts, + _opts: PutMultipartOptions, ) -> object_store::Result> { Err(object_store::Error::NotImplemented) } @@ -331,7 +339,7 @@ mod tests { fn list( &self, _prefix: Option<&Path>, - ) -> BoxStream<'_, object_store::Result> { + ) -> BoxStream<'static, object_store::Result> { Box::pin(futures::stream::once(async { Err(object_store::Error::NotImplemented) })) @@ -363,24 +371,42 @@ mod tests { let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); - let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())]).unwrap(); - let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)]).unwrap(); + let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())])?; + let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)])?; let store = Arc::new(RequestCountingObjectStore::new(Arc::new( LocalFileSystem::new(), ))); let (meta, _files) = store_parquet(vec![batch1, batch2], false).await?; + let session = SessionContext::new(); + let ctx = session.state(); + // Use a size hint larger than the parquet footer but smaller than the actual metadata, requiring a second fetch // for the remaining metadata - fetch_parquet_metadata(store.as_ref() as &dyn ObjectStore, &meta[0], Some(9)) - .await - .expect("error reading metadata with hint"); - + let file_metadata_cache = + ctx.runtime_env().cache_manager.get_file_metadata_cache(); + let df_meta = DFParquetMetadata::new(store.as_ref(), &meta[0]) + .with_metadata_size_hint(Some(9)); + df_meta.fetch_metadata().await?; assert_eq!(store.request_count(), 2); - let session = SessionContext::new(); - let ctx = session.state(); + let df_meta = + df_meta.with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))); + + // Increases by 3 because cache has no entries yet + df_meta.fetch_metadata().await?; + assert_eq!(store.request_count(), 5); + + // No increase because cache has an entry + df_meta.fetch_metadata().await?; + assert_eq!(store.request_count(), 5); + + // Increase by 2 because `get_file_metadata_cache()` is None + let df_meta = df_meta.with_file_metadata_cache(None); + df_meta.fetch_metadata().await?; + assert_eq!(store.request_count(), 7); + let force_views = match force_views { ForceViews::Yes => true, ForceViews::No => false, @@ -388,14 +414,18 @@ mod tests { let format = ParquetFormat::default() .with_metadata_size_hint(Some(9)) .with_force_view_types(force_views); - let schema = format - .infer_schema(&ctx, &store.upcast(), &meta) - .await - .unwrap(); - - let stats = - fetch_statistics(store.upcast().as_ref(), schema.clone(), &meta[0], Some(9)) - .await?; + // Increase by 3, partial cache being used. + let _schema = format.infer_schema(&ctx, &store.upcast(), &meta).await?; + assert_eq!(store.request_count(), 10); + // No increase, full cache being used. + let schema = format.infer_schema(&ctx, &store.upcast(), &meta).await?; + assert_eq!(store.request_count(), 10); + + // No increase, cache being used + let df_meta = + df_meta.with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))); + let stats = df_meta.fetch_statistics(&schema).await?; + assert_eq!(store.request_count(), 10); assert_eq!(stats.num_rows, Precision::Exact(3)); let c1_stats = &stats.column_statistics[0]; @@ -408,29 +438,47 @@ mod tests { ))); // Use the file size as the hint so we can get the full metadata from the first fetch - let size_hint = meta[0].size; - - fetch_parquet_metadata(store.upcast().as_ref(), &meta[0], Some(size_hint)) - .await - .expect("error reading metadata with hint"); + let size_hint = meta[0].size as usize; + let df_meta = DFParquetMetadata::new(store.as_ref(), &meta[0]) + .with_metadata_size_hint(Some(size_hint)); + df_meta.fetch_metadata().await?; // ensure the requests were coalesced into a single request assert_eq!(store.request_count(), 1); + let session = SessionContext::new(); + let ctx = session.state(); + let file_metadata_cache = + ctx.runtime_env().cache_manager.get_file_metadata_cache(); + let df_meta = + df_meta.with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))); + // Increases by 1 because cache has no entries yet and new session context + df_meta.fetch_metadata().await?; + assert_eq!(store.request_count(), 2); + + // No increase because cache has an entry + df_meta.fetch_metadata().await?; + assert_eq!(store.request_count(), 2); + + // Increase by 1 because `get_file_metadata_cache` is None + let df_meta = df_meta.with_file_metadata_cache(None); + df_meta.fetch_metadata().await?; + assert_eq!(store.request_count(), 3); + let format = ParquetFormat::default() .with_metadata_size_hint(Some(size_hint)) .with_force_view_types(force_views); - let schema = format - .infer_schema(&ctx, &store.upcast(), &meta) - .await - .unwrap(); - let stats = fetch_statistics( - store.upcast().as_ref(), - schema.clone(), - &meta[0], - Some(size_hint), - ) - .await?; + // Increase by 1, partial cache being used. + let _schema = format.infer_schema(&ctx, &store.upcast(), &meta).await?; + assert_eq!(store.request_count(), 4); + // No increase, full cache being used. + let schema = format.infer_schema(&ctx, &store.upcast(), &meta).await?; + assert_eq!(store.request_count(), 4); + // No increase, cache being used + let df_meta = + df_meta.with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))); + let stats = df_meta.fetch_statistics(&schema).await?; + assert_eq!(store.request_count(), 4); assert_eq!(stats.num_rows, Precision::Exact(3)); let c1_stats = &stats.column_statistics[0]; @@ -442,13 +490,18 @@ mod tests { LocalFileSystem::new(), ))); - // Use the a size hint larger than the file size to make sure we don't panic - let size_hint = meta[0].size + 100; + // Use a size hint larger than the file size to make sure we don't panic + let size_hint = (meta[0].size + 100) as usize; + let df_meta = DFParquetMetadata::new(store.as_ref(), &meta[0]) + .with_metadata_size_hint(Some(size_hint)); - fetch_parquet_metadata(store.upcast().as_ref(), &meta[0], Some(size_hint)) - .await - .expect("error reading metadata with hint"); + df_meta.fetch_metadata().await?; + assert_eq!(store.request_count(), 1); + // No increase because cache has an entry + let df_meta = + df_meta.with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))); + df_meta.fetch_metadata().await?; assert_eq!(store.request_count(), 1); Ok(()) @@ -467,25 +520,46 @@ mod tests { // Data for column c_dic: ["a", "b", "c", "d"] let values = StringArray::from_iter_values(["a", "b", "c", "d"]); let keys = Int32Array::from_iter_values([0, 1, 2, 3]); - let dic_array = - DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + let dic_array = DictionaryArray::::try_new(keys, Arc::new(values))?; let c_dic: ArrayRef = Arc::new(dic_array); - let batch1 = RecordBatch::try_from_iter(vec![("c_dic", c_dic)]).unwrap(); + // Data for column string_truncation: ["a".repeat(128), null, "b".repeat(128), null] + let string_truncation: ArrayRef = Arc::new(StringArray::from(vec![ + Some("a".repeat(128)), + None, + Some("b".repeat(128)), + None, + ])); + + let batch1 = RecordBatch::try_from_iter(vec![ + ("c_dic", c_dic), + ("string_truncation", string_truncation), + ])?; // Use store_parquet to write each batch to its own file // . batch1 written into first file and includes: // - column c_dic that has 4 rows with no null. Stats min and max of dictionary column is available. - let store = Arc::new(LocalFileSystem::new()) as _; + // - column string_truncation that has 4 rows with 2 nulls. Stats min and max of string column is available but not exact. + let store = Arc::new(RequestCountingObjectStore::new(Arc::new( + LocalFileSystem::new(), + ))); let (files, _file_names) = store_parquet(vec![batch1], false).await?; let state = SessionContext::new().state(); let format = ParquetFormat::default(); - let schema = format.infer_schema(&state, &store, &files).await.unwrap(); - - // Fetch statistics for first file - let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[0], None).await?; - let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; + let _schema = format.infer_schema(&state, &store.upcast(), &files).await?; + assert_eq!(store.request_count(), 3); + // No increase, cache being used. + let schema = format.infer_schema(&state, &store.upcast(), &files).await?; + assert_eq!(store.request_count(), 3); + + // No increase in request count because cache is not empty + let file_metadata_cache = + state.runtime_env().cache_manager.get_file_metadata_cache(); + let stats = DFParquetMetadata::new(store.as_ref(), &files[0]) + .with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))) + .fetch_statistics(&schema) + .await?; assert_eq!(stats.num_rows, Precision::Exact(4)); // column c_dic @@ -501,6 +575,19 @@ mod tests { Precision::Exact(Utf8(Some("a".into()))) ); + // column string_truncation + let string_truncation_stats = &stats.column_statistics[1]; + + assert_eq!(string_truncation_stats.null_count, Precision::Exact(2)); + assert_eq!( + string_truncation_stats.max_value, + Precision::Inexact(ScalarValue::Utf8View(Some("b".repeat(63) + "c"))) + ); + assert_eq!( + string_truncation_stats.min_value, + Precision::Inexact(ScalarValue::Utf8View(Some("a".repeat(64)))) + ); + Ok(()) } @@ -510,18 +597,20 @@ mod tests { // Data for column c1: ["Foo", null, "bar"] let c1: ArrayRef = Arc::new(StringArray::from(vec![Some("Foo"), None, Some("bar")])); - let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())]).unwrap(); + let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())])?; // Data for column c2: [1, 2, null] let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); - let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)]).unwrap(); + let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)])?; // Use store_parquet to write each batch to its own file // . batch1 written into first file and includes: // - column c1 that has 3 rows with one null. Stats min and max of string column is missing for this test even the column has values // . batch2 written into second file and includes: // - column c2 that has 3 rows with one null. Stats min and max of int are available and 1 and 2 respectively - let store = Arc::new(LocalFileSystem::new()) as _; + let store = Arc::new(RequestCountingObjectStore::new(Arc::new( + LocalFileSystem::new(), + ))); let (files, _file_names) = store_parquet(vec![batch1, batch2], false).await?; let force_views = match force_views { @@ -532,7 +621,8 @@ mod tests { let mut state = SessionContext::new().state(); state = set_view_state(state, force_views); let format = ParquetFormat::default().with_force_view_types(force_views); - let schema = format.infer_schema(&state, &store, &files).await.unwrap(); + let schema = format.infer_schema(&state, &store.upcast(), &files).await?; + assert_eq!(store.request_count(), 6); let null_i64 = ScalarValue::Int64(None); let null_utf8 = if force_views { @@ -541,9 +631,14 @@ mod tests { Utf8(None) }; - // Fetch statistics for first file - let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[0], None).await?; - let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; + // No increase in request count because cache is not empty + let file_metadata_cache = + state.runtime_env().cache_manager.get_file_metadata_cache(); + let stats = DFParquetMetadata::new(store.as_ref(), &files[0]) + .with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))) + .fetch_statistics(&schema) + .await?; + assert_eq!(store.request_count(), 6); assert_eq!(stats.num_rows, Precision::Exact(3)); // column c1 let c1_stats = &stats.column_statistics[0]; @@ -567,9 +662,12 @@ mod tests { assert_eq!(c2_stats.max_value, Precision::Exact(null_i64.clone())); assert_eq!(c2_stats.min_value, Precision::Exact(null_i64.clone())); - // Fetch statistics for second file - let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[1], None).await?; - let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; + // No increase in request count because cache is not empty + let stats = DFParquetMetadata::new(store.as_ref(), &files[1]) + .with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))) + .fetch_statistics(&schema) + .await?; + assert_eq!(store.request_count(), 6); assert_eq!(stats.num_rows, Precision::Exact(3)); // column c1: missing from the file so the table treats all 3 rows as null let c1_stats = &stats.column_statistics[0]; @@ -616,9 +714,15 @@ mod tests { assert_eq!(tt_batches, 4 /* 8/2 */); // test metadata - assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!( + exec.partition_statistics(None)?.num_rows, + Precision::Exact(8) + ); // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Exact(671) + ); Ok(()) } @@ -659,9 +763,15 @@ mod tests { get_exec(&state, "alltypes_plain.parquet", projection, Some(1)).await?; // note: even if the limit is set, the executor rounds up to the batch size - assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!( + exec.partition_statistics(None)?.num_rows, + Precision::Exact(8) + ); // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Exact(671) + ); let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); @@ -987,22 +1097,20 @@ mod tests { async fn test_read_parquet_page_index() -> Result<()> { let testdata = datafusion_common::test_util::parquet_test_data(); let path = format!("{testdata}/alltypes_tiny_pages.parquet"); - let file = File::open(path).await.unwrap(); + let file = File::open(path).await?; let options = ArrowReaderOptions::new().with_page_index(true); let builder = ParquetRecordBatchStreamBuilder::new_with_options(file, options.clone()) - .await - .unwrap() + .await? .metadata() .clone(); check_page_index_validation(builder.column_index(), builder.offset_index()); let path = format!("{testdata}/alltypes_tiny_pages_plain.parquet"); - let file = File::open(path).await.unwrap(); + let file = File::open(path).await?; let builder = ParquetRecordBatchStreamBuilder::new_with_options(file, options) - .await - .unwrap() + .await? .metadata() .clone(); check_page_index_validation(builder.column_index(), builder.offset_index()); @@ -1073,15 +1181,18 @@ mod tests { let format = state .get_file_format_factory("parquet") .map(|factory| factory.create(state, &Default::default()).unwrap()) - .unwrap_or(Arc::new(ParquetFormat::new())); + .unwrap_or_else(|| Arc::new(ParquetFormat::new())); - scan_format(state, &*format, &testdata, file_name, projection, limit).await + scan_format( + state, &*format, None, &testdata, file_name, projection, limit, + ) + .await } /// Test that 0-byte files don't break while reading #[tokio::test] async fn test_read_empty_parquet() -> Result<()> { - let tmp_dir = tempfile::TempDir::new().unwrap(); + let tmp_dir = tempfile::TempDir::new()?; let path = format!("{}/empty.parquet", tmp_dir.path().to_string_lossy()); File::create(&path).await?; @@ -1105,12 +1216,10 @@ mod tests { /// Test that 0-byte files don't break while reading #[tokio::test] async fn test_read_partitioned_empty_parquet() -> Result<()> { - let tmp_dir = tempfile::TempDir::new().unwrap(); + let tmp_dir = tempfile::TempDir::new()?; let partition_dir = tmp_dir.path().join("col1=a"); - std::fs::create_dir(&partition_dir).unwrap(); - File::create(partition_dir.join("empty.parquet")) - .await - .unwrap(); + std::fs::create_dir(&partition_dir)?; + File::create(partition_dir.join("empty.parquet")).await?; let ctx = SessionContext::new(); @@ -1228,6 +1337,34 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_write_empty_recordbatch_creates_file() -> Result<()> { + let empty_record_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(Vec::::new()))], + ) + .expect("Failed to create empty RecordBatch"); + + let tmp_dir = tempfile::TempDir::new()?; + let path = format!("{}/empty2.parquet", tmp_dir.path().to_string_lossy()); + + let ctx = SessionContext::new(); + let df = ctx.read_batch(empty_record_batch.clone())?; + df.write_parquet(&path, crate::dataframe::DataFrameWriteOptions::new(), None) + .await?; + assert!(std::path::Path::new(&path).exists()); + + let stream = ctx + .read_parquet(&path, ParquetReadOptions::new()) + .await? + .execute_stream() + .await?; + assert_eq!(stream.schema(), empty_record_batch.schema()); + let results = stream.collect::>().await; + assert_eq!(results.len(), 0); + Ok(()) + } + #[tokio::test] async fn parquet_sink_write_insert_schema_into_metadata() -> Result<()> { // expected kv metadata without schema @@ -1305,7 +1442,7 @@ mod tests { #[tokio::test] async fn parquet_sink_write_with_extension() -> Result<()> { let filename = "test_file.custom_ext"; - let file_path = format!("file:///path/to/{}", filename); + let file_path = format!("file:///path/to/{filename}"); let parquet_sink = create_written_parquet_sink(file_path.as_str()).await?; // assert written to proper path @@ -1403,7 +1540,7 @@ mod tests { // create data let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); - let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); + let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)])?; // write stream FileSink::write_all( @@ -1482,7 +1619,7 @@ mod tests { // create data with 2 partitions let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); - let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); + let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)])?; // write stream FileSink::write_all( @@ -1520,8 +1657,7 @@ mod tests { let prefix = path_parts[0].as_ref(); assert!( expected_partitions.contains(prefix), - "expected path prefix to match partition, instead found {:?}", - prefix + "expected path prefix to match partition, instead found {prefix:?}" ); expected_partitions.remove(prefix); @@ -1576,8 +1712,7 @@ mod tests { // create data let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); - let batch = - RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); + let batch = RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)])?; // create task context let task_context = build_ctx(object_store_url.as_ref()); @@ -1645,43 +1780,4 @@ mod tests { Ok(()) } - - /// Creates an bounded stream for testing purposes. - fn bounded_stream( - batch: RecordBatch, - limit: usize, - ) -> datafusion_execution::SendableRecordBatchStream { - Box::pin(BoundedStream { - count: 0, - limit, - batch, - }) - } - - struct BoundedStream { - limit: usize, - count: usize, - batch: RecordBatch, - } - - impl Stream for BoundedStream { - type Item = Result; - - fn poll_next( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - if self.count >= self.limit { - return Poll::Ready(None); - } - self.count += 1; - Poll::Ready(Some(Ok(self.batch.clone()))) - } - } - - impl RecordBatchStream for BoundedStream { - fn schema(&self) -> SchemaRef { - self.batch.schema() - } - } } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 6049614f37e8e..3ce58938d77e4 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -17,107 +17,199 @@ //! The table implementation. -use super::helpers::{expr_applicable_for_cols, pruned_partition_list}; -use super::{ListingTableUrl, PartitionedFile}; -use std::collections::HashMap; -use std::{any::Any, str::FromStr, sync::Arc}; - -use crate::datasource::{ - create_ordering, - file_format::{ - file_compression_type::FileCompressionType, FileFormat, FilePushdownSupport, - }, - physical_plan::FileSinkConfig, +use super::{ + helpers::{expr_applicable_for_cols, pruned_partition_list}, + ListingTableUrl, PartitionedFile, }; -use crate::execution::context::SessionState; -use datafusion_catalog::TableProvider; -use datafusion_common::{config_err, DataFusionError, Result}; -use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; -use datafusion_expr::dml::InsertOp; -use datafusion_expr::{utils::conjunction, Expr, TableProviderFilterPushDown}; -use datafusion_expr::{SortExpr, TableType}; -use datafusion_physical_plan::empty::EmptyExec; -use datafusion_physical_plan::{ExecutionPlan, Statistics}; - -use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, SchemaRef}; +use crate::{ + datasource::file_format::{file_compression_type::FileCompressionType, FileFormat}, + datasource::physical_plan::FileSinkConfig, + execution::context::SessionState, +}; +use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; +use arrow_schema::Schema; +use async_trait::async_trait; +use datafusion_catalog::{ScanArgs, ScanResult, Session, TableProvider}; use datafusion_common::{ - config_datafusion_err, internal_err, plan_err, project_schema, Constraints, - SchemaExt, ToDFSchema, + config_datafusion_err, config_err, internal_datafusion_err, internal_err, plan_err, + project_schema, stats::Precision, Constraints, DataFusionError, Result, SchemaExt, }; -use datafusion_execution::cache::{ - cache_manager::FileStatisticsCache, cache_unit::DefaultFileStatisticsCache, +use datafusion_datasource::{ + compute_all_files_statistics, + file::FileSource, + file_groups::FileGroup, + file_scan_config::{FileScanConfig, FileScanConfigBuilder}, + schema_adapter::{DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory}, }; -use datafusion_physical_expr::{ - create_physical_expr, LexOrdering, PhysicalSortRequirement, +use datafusion_execution::{ + cache::{cache_manager::FileStatisticsCache, cache_unit::DefaultFileStatisticsCache}, + config::SessionConfig, }; - -use async_trait::async_trait; -use datafusion_catalog::Session; -use datafusion_common::stats::Precision; -use datafusion_datasource::add_row_stats; -use datafusion_datasource::compute_all_files_statistics; -use datafusion_datasource::file_groups::FileGroup; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::{ + dml::InsertOp, Expr, SortExpr, TableProviderFilterPushDown, TableType, +}; +use datafusion_physical_expr::create_lex_ordering; +use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}; use futures::{future, stream, Stream, StreamExt, TryStreamExt}; use itertools::Itertools; use object_store::ObjectStore; +use std::{any::Any, collections::HashMap, str::FromStr, sync::Arc}; + +/// Indicates the source of the schema for a [`ListingTable`] +// PartialEq required for assert_eq! in tests +#[derive(Debug, Clone, Copy, PartialEq, Default)] +pub enum SchemaSource { + /// Schema is not yet set (initial state) + #[default] + Unset, + /// Schema was inferred from first table_path + Inferred, + /// Schema was specified explicitly via with_schema + Specified, +} /// Configuration for creating a [`ListingTable`] -#[derive(Debug, Clone)] +/// +/// # Schema Evolution Support +/// +/// This configuration supports schema evolution through the optional +/// [`SchemaAdapterFactory`]. You might want to override the default factory when you need: +/// +/// - **Type coercion requirements**: When you need custom logic for converting between +/// different Arrow data types (e.g., Int32 ↔ Int64, Utf8 ↔ LargeUtf8) +/// - **Column mapping**: You need to map columns with a legacy name to a new name +/// - **Custom handling of missing columns**: By default they are filled in with nulls, but you may e.g. want to fill them in with `0` or `""`. +/// +/// If not specified, a [`DefaultSchemaAdapterFactory`] will be used, which handles +/// basic schema compatibility cases. +/// +#[derive(Debug, Clone, Default)] pub struct ListingTableConfig { /// Paths on the `ObjectStore` for creating `ListingTable`. /// They should share the same schema and object store. pub table_paths: Vec, /// Optional `SchemaRef` for the to be created `ListingTable`. + /// + /// See details on [`ListingTableConfig::with_schema`] pub file_schema: Option, - /// Optional `ListingOptions` for the to be created `ListingTable`. + /// Optional [`ListingOptions`] for the to be created [`ListingTable`]. + /// + /// See details on [`ListingTableConfig::with_listing_options`] pub options: Option, + /// Tracks the source of the schema information + schema_source: SchemaSource, + /// Optional [`SchemaAdapterFactory`] for creating schema adapters + schema_adapter_factory: Option>, + /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters + expr_adapter_factory: Option>, } impl ListingTableConfig { - /// Creates new [`ListingTableConfig`]. - /// - /// The [`SchemaRef`] and [`ListingOptions`] are inferred based on - /// the suffix of the provided `table_paths` first element. + /// Creates new [`ListingTableConfig`] for reading the specified URL pub fn new(table_path: ListingTableUrl) -> Self { - let table_paths = vec![table_path]; Self { - table_paths, - file_schema: None, - options: None, + table_paths: vec![table_path], + ..Default::default() } } /// Creates new [`ListingTableConfig`] with multiple table paths. /// - /// The [`SchemaRef`] and [`ListingOptions`] are inferred based on - /// the suffix of the provided `table_paths` first element. + /// See [`Self::infer_options`] for details on what happens with multiple paths pub fn new_with_multi_paths(table_paths: Vec) -> Self { Self { table_paths, - file_schema: None, - options: None, + ..Default::default() } } - /// Add `schema` to [`ListingTableConfig`] + + /// Returns the source of the schema for this configuration + pub fn schema_source(&self) -> SchemaSource { + self.schema_source + } + /// Set the `schema` for the overall [`ListingTable`] + /// + /// [`ListingTable`] will automatically coerce, when possible, the schema + /// for individual files to match this schema. + /// + /// If a schema is not provided, it is inferred using + /// [`Self::infer_schema`]. + /// + /// If the schema is provided, it must contain only the fields in the file + /// without the table partitioning columns. + /// + /// # Example: Specifying Table Schema + /// ```rust + /// # use std::sync::Arc; + /// # use datafusion::datasource::listing::{ListingTableConfig, ListingOptions, ListingTableUrl}; + /// # use datafusion::datasource::file_format::parquet::ParquetFormat; + /// # use arrow::datatypes::{Schema, Field, DataType}; + /// # let table_paths = ListingTableUrl::parse("file:///path/to/data").unwrap(); + /// # let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default())); + /// let schema = Arc::new(Schema::new(vec![ + /// Field::new("id", DataType::Int64, false), + /// Field::new("name", DataType::Utf8, true), + /// ])); + /// + /// let config = ListingTableConfig::new(table_paths) + /// .with_listing_options(listing_options) // Set options first + /// .with_schema(schema); // Then set schema + /// ``` pub fn with_schema(self, schema: SchemaRef) -> Self { + // Note: We preserve existing options state, but downstream code may expect + // options to be set. Consider calling with_listing_options() or infer_options() + // before operations that require options to be present. + debug_assert!( + self.options.is_some() || cfg!(test), + "ListingTableConfig::with_schema called without options set. \ + Consider calling with_listing_options() or infer_options() first to avoid panics in downstream code." + ); + Self { - table_paths: self.table_paths, file_schema: Some(schema), - options: self.options, + schema_source: SchemaSource::Specified, + ..self } } /// Add `listing_options` to [`ListingTableConfig`] + /// + /// If not provided, format and other options are inferred via + /// [`Self::infer_options`]. + /// + /// # Example: Configuring Parquet Files with Custom Options + /// ```rust + /// # use std::sync::Arc; + /// # use datafusion::datasource::listing::{ListingTableConfig, ListingOptions, ListingTableUrl}; + /// # use datafusion::datasource::file_format::parquet::ParquetFormat; + /// # let table_paths = ListingTableUrl::parse("file:///path/to/data").unwrap(); + /// let options = ListingOptions::new(Arc::new(ParquetFormat::default())) + /// .with_file_extension(".parquet") + /// .with_collect_stat(true); + /// + /// let config = ListingTableConfig::new(table_paths) + /// .with_listing_options(options); // Configure file format and options + /// ``` pub fn with_listing_options(self, listing_options: ListingOptions) -> Self { + // Note: This method properly sets options, but be aware that downstream + // methods like infer_schema() and try_new() require both schema and options + // to be set to function correctly. + debug_assert!( + !self.table_paths.is_empty() || cfg!(test), + "ListingTableConfig::with_listing_options called without table_paths set. \ + Consider calling new() or new_with_multi_paths() first to establish table paths." + ); + Self { - table_paths: self.table_paths, - file_schema: self.file_schema, options: Some(listing_options), + ..self } } - ///Returns a tupe of (file_extension, optional compression_extension) + /// Returns a tuple of `(file_extension, optional compression_extension)` /// /// For example a path ending with blah.test.csv.gz returns `("csv", Some("gz"))` /// For example a path ending with blah.test.csv returns `("csv", None)` @@ -126,20 +218,22 @@ impl ListingTableConfig { ) -> Result<(String, Option)> { let mut exts = path.rsplit('.'); - let splitted = exts.next().unwrap_or(""); + let split = exts.next().unwrap_or(""); - let file_compression_type = FileCompressionType::from_str(splitted) + let file_compression_type = FileCompressionType::from_str(split) .unwrap_or(FileCompressionType::UNCOMPRESSED); if file_compression_type.is_compressed() { - let splitted2 = exts.next().unwrap_or(""); - Ok((splitted2.to_string(), Some(splitted.to_string()))) + let split2 = exts.next().unwrap_or(""); + Ok((split2.to_string(), Some(split.to_string()))) } else { - Ok((splitted.to_string(), None)) + Ok((split.to_string(), None)) } } - /// Infer `ListingOptions` based on `table_path` suffix. + /// Infer `ListingOptions` based on `table_path` and file suffix. + /// + /// The format is inferred based on the first `table_path`. pub async fn infer_options(self, state: &dyn Session) -> Result { let store = if let Some(url) = self.table_paths.first() { state.runtime_env().object_store(url)? @@ -155,7 +249,7 @@ impl ListingTableConfig { .await? .next() .await - .ok_or_else(|| DataFusionError::Internal("No files for table".into()))??; + .ok_or_else(|| internal_datafusion_err!("No files for table"))??; let (file_extension, maybe_compression_type) = ListingTableConfig::infer_file_extension_and_compression_type( @@ -184,41 +278,74 @@ impl ListingTableConfig { let listing_options = ListingOptions::new(file_format) .with_file_extension(listing_file_extension) - .with_target_partitions(state.config().target_partitions()); + .with_target_partitions(state.config().target_partitions()) + .with_collect_stat(state.config().collect_statistics()); Ok(Self { table_paths: self.table_paths, file_schema: self.file_schema, options: Some(listing_options), + schema_source: self.schema_source, + schema_adapter_factory: self.schema_adapter_factory, + expr_adapter_factory: self.expr_adapter_factory, }) } - /// Infer the [`SchemaRef`] based on `table_path` suffix. Requires `self.options` to be set prior to using. + /// Infer the [`SchemaRef`] based on `table_path`s. + /// + /// This method infers the table schema using the first `table_path`. + /// See [`ListingOptions::infer_schema`] for more details + /// + /// # Errors + /// * if `self.options` is not set. See [`Self::with_listing_options`] pub async fn infer_schema(self, state: &dyn Session) -> Result { match self.options { Some(options) => { - let schema = if let Some(url) = self.table_paths.first() { - options.infer_schema(state, url).await? - } else { - Arc::new(Schema::empty()) + let ListingTableConfig { + table_paths, + file_schema, + options: _, + schema_source, + schema_adapter_factory, + expr_adapter_factory: physical_expr_adapter_factory, + } = self; + + let (schema, new_schema_source) = match file_schema { + Some(schema) => (schema, schema_source), // Keep existing source if schema exists + None => { + if let Some(url) = table_paths.first() { + ( + options.infer_schema(state, url).await?, + SchemaSource::Inferred, + ) + } else { + (Arc::new(Schema::empty()), SchemaSource::Inferred) + } + } }; Ok(Self { - table_paths: self.table_paths, + table_paths, file_schema: Some(schema), options: Some(options), + schema_source: new_schema_source, + schema_adapter_factory, + expr_adapter_factory: physical_expr_adapter_factory, }) } None => internal_err!("No `ListingOptions` set for inferring schema"), } } - /// Convenience wrapper for calling `infer_options` and `infer_schema` + /// Convenience method to call both [`Self::infer_options`] and [`Self::infer_schema`] pub async fn infer(self, state: &dyn Session) -> Result { self.infer_options(state).await?.infer_schema(state).await } - /// Infer the partition columns from the path. Requires `self.options` to be set prior to using. + /// Infer the partition columns from `table_paths`. + /// + /// # Errors + /// * if `self.options` is not set. See [`Self::with_listing_options`] pub async fn infer_partitions_from_path(self, state: &dyn Session) -> Result { match self.options { Some(options) => { @@ -244,11 +371,80 @@ impl ListingTableConfig { table_paths: self.table_paths, file_schema: self.file_schema, options: Some(options), + schema_source: self.schema_source, + schema_adapter_factory: self.schema_adapter_factory, + expr_adapter_factory: self.expr_adapter_factory, }) } None => config_err!("No `ListingOptions` set for inferring schema"), } } + + /// Set the [`SchemaAdapterFactory`] for the [`ListingTable`] + /// + /// The schema adapter factory is used to create schema adapters that can + /// handle schema evolution and type conversions when reading files with + /// different schemas than the table schema. + /// + /// If not provided, a default schema adapter factory will be used. + /// + /// # Example: Custom Schema Adapter for Type Coercion + /// ```rust + /// # use std::sync::Arc; + /// # use datafusion::datasource::listing::{ListingTableConfig, ListingOptions, ListingTableUrl}; + /// # use datafusion::datasource::schema_adapter::{SchemaAdapterFactory, SchemaAdapter}; + /// # use datafusion::datasource::file_format::parquet::ParquetFormat; + /// # use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; + /// # + /// # #[derive(Debug)] + /// # struct MySchemaAdapterFactory; + /// # impl SchemaAdapterFactory for MySchemaAdapterFactory { + /// # fn create(&self, _projected_table_schema: SchemaRef, _file_schema: SchemaRef) -> Box { + /// # unimplemented!() + /// # } + /// # } + /// # let table_paths = ListingTableUrl::parse("file:///path/to/data").unwrap(); + /// # let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default())); + /// # let table_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); + /// let config = ListingTableConfig::new(table_paths) + /// .with_listing_options(listing_options) + /// .with_schema(table_schema) + /// .with_schema_adapter_factory(Arc::new(MySchemaAdapterFactory)); + /// ``` + pub fn with_schema_adapter_factory( + self, + schema_adapter_factory: Arc, + ) -> Self { + Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self + } + } + + /// Get the [`SchemaAdapterFactory`] for this configuration + pub fn schema_adapter_factory(&self) -> Option<&Arc> { + self.schema_adapter_factory.as_ref() + } + + /// Set the [`PhysicalExprAdapterFactory`] for the [`ListingTable`] + /// + /// The expression adapter factory is used to create physical expression adapters that can + /// handle schema evolution and type conversions when evaluating expressions + /// with different schemas than the table schema. + /// + /// If not provided, a default physical expression adapter factory will be used unless a custom + /// `SchemaAdapterFactory` is set, in which case only the `SchemaAdapterFactory` will be used. + /// + /// See for details on this transition. + pub fn with_expr_adapter_factory( + self, + expr_adapter_factory: Arc, + ) -> Self { + Self { + expr_adapter_factory: Some(expr_adapter_factory), + ..self + } + } } /// Options for creating a [`ListingTable`] @@ -278,6 +474,7 @@ pub struct ListingOptions { /// parquet metadata. /// /// See + /// /// NOTE: This attribute stores all equivalent orderings (the outer `Vec`) /// where each ordering consists of an individual lexicographic /// ordering (encapsulated by a `Vec`). If there aren't @@ -292,18 +489,29 @@ impl ListingOptions { /// - use default file extension filter /// - no input partition to discover /// - one target partition - /// - stat collection + /// - do not collect statistics pub fn new(format: Arc) -> Self { Self { file_extension: format.get_ext(), format, table_partition_cols: vec![], - collect_stat: true, + collect_stat: false, target_partitions: 1, file_sort_order: vec![], } } + /// Set options from [`SessionConfig`] and returns self. + /// + /// Currently this sets `target_partitions` and `collect_stat` + /// but if more options are added in the future that need to be coordinated + /// they will be synchronized through this method. + pub fn with_session_config_options(mut self, config: &SessionConfig) -> Self { + self = self.with_target_partitions(config.target_partitions()); + self = self.with_collect_stat(config.collect_statistics()); + self + } + /// Set file extension on [`ListingOptions`] and returns self. /// /// # Example @@ -480,11 +688,13 @@ impl ListingOptions { } /// Infer the schema of the files at the given path on the provided object store. - /// The inferred schema does not include the partitioning columns. /// - /// This method will not be called by the table itself but before creating it. - /// This way when creating the logical plan we can decide to resolve the schema - /// locally or ask a remote service to do it (e.g a scheduler). + /// If the table_path contains one or more files (i.e. it is a directory / + /// prefix of files) their schema is merged by calling [`FileFormat::infer_schema`] + /// + /// Note: The inferred schema does not include any partitioning columns. + /// + /// This method is called as part of creating a [`ListingTable`]. pub async fn infer_schema<'a>( &'a self, state: &dyn Session, @@ -595,6 +805,9 @@ impl ListingOptions { .rev() .skip(1) // get parents only; skip the file itself .rev() + // Partitions are expected to follow the format "column_name=value", so we + // should ignore any path part that cannot be parsed into the expected format + .filter(|s| s.contains('=')) .map(|s| s.split('=').take(1).collect()) .collect_vec() }) @@ -612,13 +825,26 @@ impl ListingOptions { } } -/// Reads data from one or more files as a single table. +/// Built in [`TableProvider`] that reads data from one or more files as a single table. +/// +/// The files are read using an [`ObjectStore`] instance, for example from +/// local files or objects from AWS S3. +/// +/// # Features: +/// * Reading multiple files as a single table +/// * Hive style partitioning (e.g., directories named `date=2024-06-01`) +/// * Merges schemas from files with compatible but not identical schemas (see [`ListingTableConfig::file_schema`]) +/// * `limit`, `filter` and `projection` pushdown for formats that support it (e.g., +/// Parquet) +/// * Statistics collection and pruning based on file metadata +/// * Pre-existing sort order (see [`ListingOptions::file_sort_order`]) +/// * Metadata caching to speed up repeated queries (see [`FileMetadataCache`]) +/// * Statistics caching (see [`FileStatisticsCache`]) /// -/// Implements [`TableProvider`], a DataFusion data source. The files are read -/// using an [`ObjectStore`] instance, for example from local files or objects -/// from AWS S3. +/// [`FileMetadataCache`]: datafusion_execution::cache::cache_manager::FileMetadataCache +/// +/// # Reading Directories and Hive Style Partitioning /// -/// # Reading Directories /// For example, given the `table1` directory (or object store prefix) /// /// ```text @@ -654,19 +880,24 @@ impl ListingOptions { /// If the query has a predicate like `WHERE date = '2024-06-01'` /// only the corresponding directory will be read. /// -/// `ListingTable` also supports limit, filter and projection pushdown for formats that -/// support it as such as Parquet. -/// -/// # Implementation +/// # See Also /// -/// `ListingTable` Uses [`DataSourceExec`] to execute the data. See that struct -/// for more details. +/// 1. [`ListingTableConfig`]: Configuration options +/// 1. [`DataSourceExec`]: `ExecutionPlan` used by `ListingTable` /// /// [`DataSourceExec`]: crate::datasource::source::DataSourceExec /// -/// # Example +/// # Caching Metadata +/// +/// Some formats, such as Parquet, use the `FileMetadataCache` to cache file +/// metadata that is needed to execute but expensive to read, such as row +/// groups and statistics. The cache is scoped to the [`SessionContext`] and can +/// be configured via the [runtime config options]. +/// +/// [`SessionContext`]: crate::prelude::SessionContext +/// [runtime config options]: https://datafusion.apache.org/user-guide/configs.html#runtime-configuration-settings /// -/// To read a directory of parquet files using a [`ListingTable`]: +/// # Example: Read a directory of parquet files using a [`ListingTable`] /// /// ```no_run /// # use datafusion::prelude::SessionContext; @@ -713,39 +944,51 @@ impl ListingOptions { /// # Ok(()) /// # } /// ``` -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ListingTable { table_paths: Vec, - /// File fields only + /// `file_schema` contains only the columns physically stored in the data files themselves. + /// - Represents the actual fields found in files like Parquet, CSV, etc. + /// - Used when reading the raw data from files file_schema: SchemaRef, - /// File fields + partition columns + /// `table_schema` combines `file_schema` + partition columns + /// - Partition columns are derived from directory paths (not stored in files) + /// - These are columns like "year=2022/month=01" in paths like `/data/year=2022/month=01/file.parquet` table_schema: SchemaRef, + /// Indicates how the schema was derived (inferred or explicitly specified) + schema_source: SchemaSource, + /// Options used to configure the listing table such as the file format + /// and partitioning information options: ListingOptions, + /// The SQL definition for this table, if any definition: Option, + /// Cache for collected file statistics collected_statistics: FileStatisticsCache, + /// Constraints applied to this table constraints: Constraints, + /// Column default expressions for columns that are not physically present in the data files column_defaults: HashMap, + /// Optional [`SchemaAdapterFactory`] for creating schema adapters + schema_adapter_factory: Option>, + /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters + expr_adapter_factory: Option>, } impl ListingTable { - /// Create new [`ListingTable`] that lists the FS to get the files - /// to scan. See [`ListingTable`] for and example. - /// - /// Takes a `ListingTableConfig` as input which requires an `ObjectStore` and `table_path`. - /// `ListingOptions` and `SchemaRef` are optional. If they are not - /// provided the file type is inferred based on the file suffix. - /// If the schema is provided then it must be resolved before creating the table - /// and should contain the fields of the file without the table - /// partitioning columns. + /// Create new [`ListingTable`] /// + /// See documentation and example on [`ListingTable`] and [`ListingTableConfig`] pub fn try_new(config: ListingTableConfig) -> Result { + // Extract schema_source before moving other parts of the config + let schema_source = config.schema_source(); + let file_schema = config .file_schema - .ok_or_else(|| DataFusionError::Internal("No schema provided.".into()))?; + .ok_or_else(|| internal_datafusion_err!("No schema provided."))?; - let options = config.options.ok_or_else(|| { - DataFusionError::Internal("No ListingOptions provided".into()) - })?; + let options = config + .options + .ok_or_else(|| internal_datafusion_err!("No ListingOptions provided"))?; // Add the partition columns to the file schema let mut builder = SchemaBuilder::from(file_schema.as_ref().to_owned()); @@ -763,11 +1006,14 @@ impl ListingTable { table_paths: config.table_paths, file_schema, table_schema, + schema_source, options, definition: None, collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), - constraints: Constraints::empty(), + constraints: Constraints::default(), column_defaults: HashMap::new(), + schema_adapter_factory: config.schema_adapter_factory, + expr_adapter_factory: config.expr_adapter_factory, }; Ok(table) @@ -796,7 +1042,7 @@ impl ListingTable { /// If `None`, creates a new [`DefaultFileStatisticsCache`] scoped to this query. pub fn with_cache(mut self, cache: Option) -> Self { self.collected_statistics = - cache.unwrap_or(Arc::new(DefaultFileStatisticsCache::default())); + cache.unwrap_or_else(|| Arc::new(DefaultFileStatisticsCache::default())); self } @@ -816,15 +1062,91 @@ impl ListingTable { &self.options } + /// Get the schema source + pub fn schema_source(&self) -> SchemaSource { + self.schema_source + } + + /// Set the [`SchemaAdapterFactory`] for this [`ListingTable`] + /// + /// The schema adapter factory is used to create schema adapters that can + /// handle schema evolution and type conversions when reading files with + /// different schemas than the table schema. + /// + /// # Example: Adding Schema Evolution Support + /// ```rust + /// # use std::sync::Arc; + /// # use datafusion::datasource::listing::{ListingTable, ListingTableConfig, ListingOptions, ListingTableUrl}; + /// # use datafusion::datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaAdapter}; + /// # use datafusion::datasource::file_format::parquet::ParquetFormat; + /// # use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; + /// # let table_path = ListingTableUrl::parse("file:///path/to/data").unwrap(); + /// # let options = ListingOptions::new(Arc::new(ParquetFormat::default())); + /// # let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); + /// # let config = ListingTableConfig::new(table_path).with_listing_options(options).with_schema(schema); + /// # let table = ListingTable::try_new(config).unwrap(); + /// let table_with_evolution = table + /// .with_schema_adapter_factory(Arc::new(DefaultSchemaAdapterFactory)); + /// ``` + /// See [`ListingTableConfig::with_schema_adapter_factory`] for an example of custom SchemaAdapterFactory. + pub fn with_schema_adapter_factory( + self, + schema_adapter_factory: Arc, + ) -> Self { + Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self + } + } + + /// Get the [`SchemaAdapterFactory`] for this table + pub fn schema_adapter_factory(&self) -> Option<&Arc> { + self.schema_adapter_factory.as_ref() + } + + /// Creates a schema adapter for mapping between file and table schemas + /// + /// Uses the configured schema adapter factory if available, otherwise falls back + /// to the default implementation. + fn create_schema_adapter(&self) -> Box { + let table_schema = self.schema(); + match &self.schema_adapter_factory { + Some(factory) => { + factory.create_with_projected_schema(Arc::clone(&table_schema)) + } + None => DefaultSchemaAdapterFactory::from_schema(Arc::clone(&table_schema)), + } + } + + /// Creates a file source and applies schema adapter factory if available + fn create_file_source_with_schema_adapter(&self) -> Result> { + let mut source = self.options.format.file_source(); + // Apply schema adapter to source if available + // + // The source will use this SchemaAdapter to adapt data batches as they flow up the plan. + // Note: ListingTable also creates a SchemaAdapter in `scan()` but that is only used to adapt collected statistics. + if let Some(factory) = &self.schema_adapter_factory { + source = source.with_schema_adapter_factory(Arc::clone(factory))?; + } + Ok(source) + } + /// If file_sort_order is specified, creates the appropriate physical expressions - fn try_create_output_ordering(&self) -> Result> { - create_ordering(&self.table_schema, &self.options.file_sort_order) + fn try_create_output_ordering( + &self, + execution_props: &ExecutionProps, + ) -> Result> { + create_lex_ordering( + &self.table_schema, + &self.options.file_sort_order, + execution_props, + ) } } -// Expressions can be used for parttion pruning if they can be evaluated using -// only the partiton columns and there are partition columns. -fn can_be_evaluted_for_partition_pruning( +// Expressions can be used for partition pruning if they can be evaluated using +// only the partition columns and there are partition columns. +fn can_be_evaluated_for_partition_pruning( partition_column_names: &[&str], expr: &Expr, ) -> bool { @@ -857,6 +1179,22 @@ impl TableProvider for ListingTable { filters: &[Expr], limit: Option, ) -> Result> { + let options = ScanArgs::default() + .with_projection(projection.map(|p| p.as_slice())) + .with_filters(Some(filters)) + .with_limit(limit); + Ok(self.scan_with_args(state, options).await?.into_inner()) + } + + async fn scan_with_args<'a>( + &self, + state: &dyn Session, + args: ScanArgs<'a>, + ) -> Result { + let projection = args.projection().map(|p| p.to_vec()); + let filters = args.filters().map(|f| f.to_vec()).unwrap_or_default(); + let limit = args.limit(); + // extract types of partition columns let table_partition_cols = self .options @@ -869,40 +1207,40 @@ impl TableProvider for ListingTable { .iter() .map(|field| field.name().as_str()) .collect::>(); + // If the filters can be resolved using only partition cols, there is no need to // pushdown it to TableScan, otherwise, `unhandled` pruning predicates will be generated let (partition_filters, filters): (Vec<_>, Vec<_>) = filters.iter().cloned().partition(|filter| { - can_be_evaluted_for_partition_pruning(&table_partition_col_names, filter) + can_be_evaluated_for_partition_pruning(&table_partition_col_names, filter) }); - // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here? - let session_state = state.as_any().downcast_ref::().unwrap(); // We should not limit the number of partitioned files to scan if there are filters and limit // at the same time. This is because the limit should be applied after the filters are applied. let statistic_file_limit = if filters.is_empty() { limit } else { None }; let (mut partitioned_file_lists, statistics) = self - .list_files_for_scan(session_state, &partition_filters, statistic_file_limit) + .list_files_for_scan(state, &partition_filters, statistic_file_limit) .await?; // if no files need to be read, return an `EmptyExec` if partitioned_file_lists.is_empty() { - let projected_schema = project_schema(&self.schema(), projection)?; - return Ok(Arc::new(EmptyExec::new(projected_schema))); + let projected_schema = project_schema(&self.schema(), projection.as_ref())?; + return Ok(ScanResult::new(Arc::new(EmptyExec::new(projected_schema)))); } - let output_ordering = self.try_create_output_ordering()?; + let output_ordering = self.try_create_output_ordering(state.execution_props())?; match state .config_options() .execution .split_file_groups_by_statistics .then(|| { output_ordering.first().map(|output_ordering| { - FileScanConfig::split_groups_by_statistics( + FileScanConfig::split_groups_by_statistics_with_target_partitions( &self.table_schema, &partitioned_file_lists, output_ordering, + self.options.target_partitions, ) }) }) @@ -919,46 +1257,40 @@ impl TableProvider for ListingTable { None => {} // no ordering required }; - let filters = match conjunction(filters.to_vec()) { - Some(expr) => { - let table_df_schema = self.table_schema.as_ref().clone().to_dfschema()?; - let filters = create_physical_expr( - &expr, - &table_df_schema, - state.execution_props(), - )?; - Some(filters) - } - None => None, - }; - let Some(object_store_url) = self.table_paths.first().map(ListingTableUrl::object_store) else { - return Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))); + return Ok(ScanResult::new(Arc::new(EmptyExec::new(Arc::new( + Schema::empty(), + ))))); }; + let file_source = self.create_file_source_with_schema_adapter()?; + // create the execution plan - self.options + let plan = self + .options .format .create_physical_plan( - session_state, + state, FileScanConfigBuilder::new( object_store_url, Arc::clone(&self.file_schema), - self.options.format.file_source(), + file_source, ) .with_file_groups(partitioned_file_lists) .with_constraints(self.constraints.clone()) .with_statistics(statistics) - .with_projection(projection.cloned()) + .with_projection(projection) .with_limit(limit) .with_output_ordering(output_ordering) .with_table_partition_cols(table_partition_cols) + .with_expr_adapter(self.expr_adapter_factory.clone()) .build(), - filters.as_ref(), ) - .await + .await?; + + Ok(ScanResult::new(plan)) } fn supports_filters_pushdown( @@ -974,24 +1306,12 @@ impl TableProvider for ListingTable { filters .iter() .map(|filter| { - if can_be_evaluted_for_partition_pruning(&partition_column_names, filter) + if can_be_evaluated_for_partition_pruning(&partition_column_names, filter) { // if filter can be handled by partition pruning, it is exact return Ok(TableProviderFilterPushDown::Exact); } - // if we can't push it down completely with only the filename-based/path-based - // column names, then we should check if we can do parquet predicate pushdown - let supports_pushdown = self.options.format.supports_filters_pushdown( - &self.file_schema, - &self.table_schema, - &[filter], - )?; - - if supports_pushdown == FilePushdownSupport::Supported { - return Ok(TableProviderFilterPushDown::Exact); - } - Ok(TableProviderFilterPushDown::Inexact) }) .collect() @@ -1022,10 +1342,8 @@ impl TableProvider for ListingTable { // Get the object store for the table path. let store = state.runtime_env().object_store(table_path)?; - // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here? - let session_state = state.as_any().downcast_ref::().unwrap(); let file_list_stream = pruned_partition_list( - session_state, + state, store.as_ref(), table_path, &[], @@ -1051,29 +1369,13 @@ impl TableProvider for ListingTable { file_extension: self.options().format.get_ext(), }; - let order_requirements = if !self.options().file_sort_order.is_empty() { - // Multiple sort orders in outer vec are equivalent, so we pass only the first one - let orderings = self.try_create_output_ordering()?; - let Some(ordering) = orderings.first() else { - return internal_err!( - "Expected ListingTable to have a sort order, but none found!" - ); - }; - // Converts Vec> into type required by execution plan to specify its required input ordering - Some(LexRequirement::new( - ordering - .into_iter() - .cloned() - .map(PhysicalSortRequirement::from) - .collect::>(), - )) - } else { - None - }; + let orderings = self.try_create_output_ordering(state.execution_props())?; + // It is sufficient to pass only one of the equivalent orderings: + let order_requirements = orderings.into_iter().next().map(Into::into); self.options() .format - .create_writer_physical_plan(input, session_state, config, order_requirements) + .create_writer_physical_plan(input, state, config, order_requirements) .await } @@ -1130,12 +1432,26 @@ impl ListingTable { get_files_with_limit(files, limit, self.options.collect_stat).await?; let file_groups = file_group.split_files(self.options.target_partitions); - compute_all_files_statistics( + let (mut file_groups, mut stats) = compute_all_files_statistics( file_groups, self.schema(), self.options.collect_stat, inexact_stats, - ) + )?; + + let schema_adapter = self.create_schema_adapter(); + let (schema_mapper, _) = schema_adapter.map_schema(self.file_schema.as_ref())?; + + stats.column_statistics = + schema_mapper.map_column_statistics(&stats.column_statistics)?; + file_groups.iter_mut().try_for_each(|file_group| { + if let Some(stat) = file_group.statistics_mut() { + stat.column_statistics = + schema_mapper.map_column_statistics(&stat.column_statistics)?; + } + Ok::<_, DataFusionError>(()) + })?; + Ok((file_groups, stats)) } /// Collects statistics for a given partitioned file. @@ -1186,7 +1502,7 @@ impl ListingTable { /// # Arguments /// * `files` - A stream of `Result` items to process /// * `limit` - An optional row count limit. If provided, the function will stop collecting files -/// once the accumulated number of rows exceeds this limit +/// once the accumulated number of rows exceeds this limit /// * `collect_stats` - Whether to collect and accumulate statistics from the files /// /// # Returns @@ -1229,7 +1545,7 @@ async fn get_files_with_limit( file_stats.num_rows } else { // For subsequent files, accumulate the counts - add_row_stats(num_rows, file_stats.num_rows) + num_rows.add(&file_stats.num_rows) }; } } @@ -1256,100 +1572,146 @@ async fn get_files_with_limit( #[cfg(test)] mod tests { use super::*; - use crate::datasource::file_format::csv::CsvFormat; - use crate::datasource::file_format::json::JsonFormat; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; - use crate::datasource::{provider_as_source, DefaultTableSource, MemTable}; - use crate::execution::options::ArrowReadOptions; use crate::prelude::*; - use crate::test::{columns, object_store::register_test_store}; - - use arrow::compute::SortOptions; - use arrow::record_batch::RecordBatch; - use datafusion_common::stats::Precision; - use datafusion_common::test_util::batches_to_string; - use datafusion_common::{assert_contains, ScalarValue}; + use crate::{ + datasource::{ + file_format::csv::CsvFormat, file_format::json::JsonFormat, + provider_as_source, DefaultTableSource, MemTable, + }, + execution::options::ArrowReadOptions, + test::{ + columns, object_store::ensure_head_concurrency, + object_store::make_test_store_and_state, object_store::register_test_store, + }, + }; + use arrow::{compute::SortOptions, record_batch::RecordBatch}; + use datafusion_common::{ + assert_contains, + stats::Precision, + test_util::{batches_to_string, datafusion_test_data}, + ColumnStatistics, ScalarValue, + }; + use datafusion_datasource::schema_adapter::{ + SchemaAdapter, SchemaAdapterFactory, SchemaMapper, + }; use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; + use datafusion_physical_expr::expressions::binary; use datafusion_physical_expr::PhysicalSortExpr; - use datafusion_physical_plan::collect; - use datafusion_physical_plan::ExecutionPlanProperties; - - use crate::test::object_store::{ensure_head_concurrency, make_test_store_and_state}; + use datafusion_physical_plan::{collect, ExecutionPlanProperties}; + use rstest::rstest; + use std::io::Write; use tempfile::TempDir; use url::Url; - #[tokio::test] - async fn read_single_file() -> Result<()> { - let ctx = SessionContext::new(); - - let table = load_table(&ctx, "alltypes_plain.parquet").await?; - let projection = None; - let exec = table - .scan(&ctx.state(), projection, &[], None) - .await - .expect("Scan table"); + const DUMMY_NULL_COUNT: Precision = Precision::Exact(42); - assert_eq!(exec.children().len(), 0); - assert_eq!(exec.output_partitioning().partition_count(), 1); + /// Creates a test schema with standard field types used in tests + fn create_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Float32, true), + Field::new("c2", DataType::Float64, true), + Field::new("c3", DataType::Boolean, true), + Field::new("c4", DataType::Utf8, true), + ])) + } - // test metadata - assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); - assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); + /// Helper function to generate test file paths with given prefix, count, and optional start index + fn generate_test_files(prefix: &str, count: usize) -> Vec { + generate_test_files_with_start(prefix, count, 0) + } - Ok(()) + /// Helper function to generate test file paths with given prefix, count, and start index + fn generate_test_files_with_start( + prefix: &str, + count: usize, + start_index: usize, + ) -> Vec { + (start_index..start_index + count) + .map(|i| format!("{prefix}/file{i}")) + .collect() } - #[cfg(feature = "parquet")] #[tokio::test] - async fn load_table_stats_by_default() -> Result<()> { - use crate::datasource::file_format::parquet::ParquetFormat; - - let testdata = crate::test_util::parquet_test_data(); - let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); + async fn test_schema_source_tracking_comprehensive() -> Result<()> { + let ctx = SessionContext::new(); + let testdata = datafusion_test_data(); + let filename = format!("{testdata}/aggregate_simple.csv"); let table_path = ListingTableUrl::parse(filename).unwrap(); - let ctx = SessionContext::new(); - let state = ctx.state(); + // Test default schema source + let config = ListingTableConfig::new(table_path.clone()); + assert_eq!(config.schema_source(), SchemaSource::Unset); - let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); - let schema = opt.infer_schema(&state, &table_path).await?; - let config = ListingTableConfig::new(table_path) - .with_listing_options(opt) - .with_schema(schema); - let table = ListingTable::try_new(config)?; + // Test schema source after setting a schema explicitly + let provided_schema = create_test_schema(); + let config_with_schema = config.clone().with_schema(provided_schema.clone()); + assert_eq!(config_with_schema.schema_source(), SchemaSource::Specified); - let exec = table.scan(&state, None, &[], None).await?; - assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); + // Test schema source after inferring schema + let format = CsvFormat::default(); + let options = ListingOptions::new(Arc::new(format)); + let config_with_options = config.with_listing_options(options.clone()); + assert_eq!(config_with_options.schema_source(), SchemaSource::Unset); + + let config_with_inferred = config_with_options.infer_schema(&ctx.state()).await?; + assert_eq!(config_with_inferred.schema_source(), SchemaSource::Inferred); + + // Test schema preservation through operations + let config_with_schema_and_options = config_with_schema + .clone() + .with_listing_options(options.clone()); + assert_eq!( + config_with_schema_and_options.schema_source(), + SchemaSource::Specified + ); + + // Make sure inferred schema doesn't override specified schema + let config_with_schema_and_infer = config_with_schema_and_options + .clone() + .infer(&ctx.state()) + .await?; + assert_eq!( + config_with_schema_and_infer.schema_source(), + SchemaSource::Specified + ); + + // Verify sources in actual ListingTable objects + let table_specified = ListingTable::try_new(config_with_schema_and_options)?; + assert_eq!(table_specified.schema_source(), SchemaSource::Specified); + + let table_inferred = ListingTable::try_new(config_with_inferred)?; + assert_eq!(table_inferred.schema_source(), SchemaSource::Inferred); Ok(()) } - #[cfg(feature = "parquet")] #[tokio::test] - async fn load_table_stats_when_no_stats() -> Result<()> { - use crate::datasource::file_format::parquet::ParquetFormat; - - let testdata = crate::test_util::parquet_test_data(); - let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); - let table_path = ListingTableUrl::parse(filename).unwrap(); + async fn read_single_file() -> Result<()> { + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_collect_statistics(true), + ); - let ctx = SessionContext::new(); - let state = ctx.state(); + let table = load_table(&ctx, "alltypes_plain.parquet").await?; + let projection = None; + let exec = table + .scan(&ctx.state(), projection, &[], None) + .await + .expect("Scan table"); - let opt = ListingOptions::new(Arc::new(ParquetFormat::default())) - .with_collect_stat(false); - let schema = opt.infer_schema(&state, &table_path).await?; - let config = ListingTableConfig::new(table_path) - .with_listing_options(opt) - .with_schema(schema); - let table = ListingTable::try_new(config)?; + assert_eq!(exec.children().len(), 0); + assert_eq!(exec.output_partitioning().partition_count(), 1); - let exec = table.scan(&state, None, &[], None).await?; - assert_eq!(exec.statistics()?.num_rows, Precision::Absent); - assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); + // test metadata + assert_eq!( + exec.partition_statistics(None)?.num_rows, + Precision::Exact(8) + ); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Exact(671) + ); Ok(()) } @@ -1368,31 +1730,44 @@ mod tests { use crate::datasource::file_format::parquet::ParquetFormat; use datafusion_physical_plan::expressions::col as physical_col; + use datafusion_physical_plan::expressions::lit as physical_lit; use std::ops::Add; // (file_sort_order, expected_result) let cases = vec![ - (vec![], Ok(vec![])), + ( + vec![], + Ok::, DataFusionError>(Vec::::new()), + ), // sort expr, but non column ( - vec![vec![ - col("int_col").add(lit(1)).sort(true, true), - ]], - Err("Expected single column reference in sort_order[0][0], got int_col + Int32(1)"), + vec![vec![col("int_col").add(lit(1)).sort(true, true)]], + Ok(vec![[PhysicalSortExpr { + expr: binary( + physical_col("int_col", &schema).unwrap(), + Operator::Plus, + physical_lit(1), + &schema, + ) + .unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }] + .into()]), ), // ok with one column ( vec![vec![col("string_col").sort(true, false)]], - Ok(vec![LexOrdering::new( - vec![PhysicalSortExpr { - expr: physical_col("string_col", &schema).unwrap(), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }], - ) - ]) + Ok(vec![[PhysicalSortExpr { + expr: physical_col("string_col", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }] + .into()]), ), // ok with two columns, different options ( @@ -1400,17 +1775,19 @@ mod tests { col("string_col").sort(true, false), col("int_col").sort(false, true), ]], - Ok(vec![LexOrdering::new( - vec![ - PhysicalSortExpr::new_default(physical_col("string_col", &schema).unwrap()) - .asc() - .nulls_last(), - PhysicalSortExpr::new_default(physical_col("int_col", &schema).unwrap()) - .desc() - .nulls_first() - ], - ) - ]) + Ok(vec![[ + PhysicalSortExpr::new_default( + physical_col("string_col", &schema).unwrap(), + ) + .asc() + .nulls_last(), + PhysicalSortExpr::new_default( + physical_col("int_col", &schema).unwrap(), + ) + .desc() + .nulls_first(), + ] + .into()]), ), ]; @@ -1423,7 +1800,8 @@ mod tests { let table = ListingTable::try_new(config.clone()).expect("Creating the table"); - let ordering_result = table.try_create_output_ordering(); + let ordering_result = + table.try_create_output_ordering(state.execution_props()); match (expected_result, ordering_result) { (Ok(expected), Ok(result)) => { @@ -1488,295 +1866,38 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_assert_list_files_for_scan_grouping() -> Result<()> { - // more expected partitions than files - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0", - "bucket/key-prefix/file1", - "bucket/key-prefix/file2", - "bucket/key-prefix/file3", - "bucket/key-prefix/file4", - ], - "test:///bucket/key-prefix/", - 12, - 5, - Some(""), - ) - .await?; + async fn load_table( + ctx: &SessionContext, + name: &str, + ) -> Result> { + let testdata = crate::test_util::parquet_test_data(); + let filename = format!("{testdata}/{name}"); + let table_path = ListingTableUrl::parse(filename).unwrap(); - // as many expected partitions as files - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0", - "bucket/key-prefix/file1", - "bucket/key-prefix/file2", - "bucket/key-prefix/file3", - ], - "test:///bucket/key-prefix/", - 4, - 4, - Some(""), - ) - .await?; + let config = ListingTableConfig::new(table_path) + .infer(&ctx.state()) + .await?; + let table = ListingTable::try_new(config)?; + Ok(Arc::new(table)) + } - // more files as expected partitions - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0", - "bucket/key-prefix/file1", - "bucket/key-prefix/file2", - "bucket/key-prefix/file3", - "bucket/key-prefix/file4", - ], - "test:///bucket/key-prefix/", - 2, - 2, - Some(""), - ) - .await?; + /// Check that the files listed by the table match the specified `output_partitioning` + /// when the object store contains `files`. + async fn assert_list_files_for_scan_grouping( + files: &[&str], + table_prefix: &str, + target_partitions: usize, + output_partitioning: usize, + file_ext: Option<&str>, + ) -> Result<()> { + let ctx = SessionContext::new(); + register_test_store(&ctx, &files.iter().map(|f| (*f, 10)).collect::>()); - // no files => no groups - assert_list_files_for_scan_grouping( - &[], - "test:///bucket/key-prefix/", - 2, - 0, - Some(""), - ) - .await?; + let opt = ListingOptions::new(Arc::new(JsonFormat::default())) + .with_file_extension_opt(file_ext) + .with_target_partitions(target_partitions); - // files that don't match the prefix - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0", - "bucket/key-prefix/file1", - "bucket/other-prefix/roguefile", - ], - "test:///bucket/key-prefix/", - 10, - 2, - Some(""), - ) - .await?; - - // files that don't match the prefix or the default file extention - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0.json", - "bucket/key-prefix/file1.parquet", - "bucket/other-prefix/roguefile.json", - ], - "test:///bucket/key-prefix/", - 10, - 1, - None, - ) - .await?; - Ok(()) - } - - #[tokio::test] - async fn test_assert_list_files_for_multi_path() -> Result<()> { - // more expected partitions than files - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - "bucket/key3/file5", - ], - &["test:///bucket/key1/", "test:///bucket/key2/"], - 12, - 5, - Some(""), - ) - .await?; - - // as many expected partitions as files - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - "bucket/key3/file5", - ], - &["test:///bucket/key1/", "test:///bucket/key2/"], - 5, - 5, - Some(""), - ) - .await?; - - // more files as expected partitions - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - "bucket/key3/file5", - ], - &["test:///bucket/key1/"], - 2, - 2, - Some(""), - ) - .await?; - - // no files => no groups - assert_list_files_for_multi_paths(&[], &["test:///bucket/key1/"], 2, 0, Some("")) - .await?; - - // files that don't match the prefix - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - "bucket/key3/file5", - ], - &["test:///bucket/key3/"], - 2, - 1, - Some(""), - ) - .await?; - - // files that don't match the prefix or the default file ext - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0.json", - "bucket/key1/file1.csv", - "bucket/key1/file2.json", - "bucket/key2/file3.csv", - "bucket/key2/file4.json", - "bucket/key3/file5.csv", - ], - &["test:///bucket/key1/", "test:///bucket/key3/"], - 2, - 2, - None, - ) - .await?; - Ok(()) - } - - #[tokio::test] - async fn test_assert_list_files_for_exact_paths() -> Result<()> { - // more expected partitions than files - assert_list_files_for_exact_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - ], - 12, - 5, - Some(""), - ) - .await?; - - // more files than meta_fetch_concurrency (32) - let files: Vec = - (0..64).map(|i| format!("bucket/key1/file{}", i)).collect(); - // Collect references to each string - let file_refs: Vec<&str> = files.iter().map(|s| s.as_str()).collect(); - assert_list_files_for_exact_paths(file_refs.as_slice(), 5, 5, Some("")).await?; - - // as many expected partitions as files - assert_list_files_for_exact_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - ], - 5, - 5, - Some(""), - ) - .await?; - - // more files as expected partitions - assert_list_files_for_exact_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - ], - 2, - 2, - Some(""), - ) - .await?; - - // no files => no groups - assert_list_files_for_exact_paths(&[], 2, 0, Some("")).await?; - - // files that don't match the default file ext - assert_list_files_for_exact_paths( - &[ - "bucket/key1/file0.json", - "bucket/key1/file1.csv", - "bucket/key1/file2.json", - "bucket/key2/file3.csv", - "bucket/key2/file4.json", - "bucket/key3/file5.csv", - ], - 2, - 2, - None, - ) - .await?; - Ok(()) - } - - async fn load_table( - ctx: &SessionContext, - name: &str, - ) -> Result> { - let testdata = crate::test_util::parquet_test_data(); - let filename = format!("{testdata}/{name}"); - let table_path = ListingTableUrl::parse(filename).unwrap(); - - let config = ListingTableConfig::new(table_path) - .infer(&ctx.state()) - .await?; - let table = ListingTable::try_new(config)?; - Ok(Arc::new(table)) - } - - /// Check that the files listed by the table match the specified `output_partitioning` - /// when the object store contains `files`. - async fn assert_list_files_for_scan_grouping( - files: &[&str], - table_prefix: &str, - target_partitions: usize, - output_partitioning: usize, - file_ext: Option<&str>, - ) -> Result<()> { - let ctx = SessionContext::new(); - register_test_store(&ctx, &files.iter().map(|f| (*f, 10)).collect::>()); - - let opt = ListingOptions::new(Arc::new(JsonFormat::default())) - .with_file_extension_opt(file_ext) - .with_target_partitions(target_partitions); - - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); let table_path = ListingTableUrl::parse(table_prefix).unwrap(); let config = ListingTableConfig::new(table_path) @@ -1847,10 +1968,10 @@ mod tests { .execution .meta_fetch_concurrency; let expected_concurrency = files.len().min(meta_fetch_concurrency); - let head_blocking_store = ensure_head_concurrency(store, expected_concurrency); + let head_concurrency_store = ensure_head_concurrency(store, expected_concurrency); let url = Url::parse("test://").unwrap(); - ctx.register_object_store(&url, head_blocking_store.clone()); + ctx.register_object_store(&url, head_concurrency_store.clone()); let format = JsonFormat::default(); @@ -1862,7 +1983,7 @@ mod tests { let table_paths = files .iter() - .map(|t| ListingTableUrl::parse(format!("test:///{}", t)).unwrap()) + .map(|t| ListingTableUrl::parse(format!("test:///{t}")).unwrap()) .collect(); let config = ListingTableConfig::new_with_multi_paths(table_paths) .with_listing_options(opt) @@ -1877,80 +1998,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_insert_into_append_new_json_files() -> Result<()> { - let mut config_map: HashMap = HashMap::new(); - config_map.insert("datafusion.execution.batch_size".into(), "10".into()); - config_map.insert( - "datafusion.execution.soft_max_rows_per_output_file".into(), - "10".into(), - ); - helper_test_append_new_files_to_table( - JsonFormat::default().get_ext(), - FileCompressionType::UNCOMPRESSED, - Some(config_map), - 2, - ) - .await?; - Ok(()) - } - - #[tokio::test] - async fn test_insert_into_append_new_csv_files() -> Result<()> { - let mut config_map: HashMap = HashMap::new(); - config_map.insert("datafusion.execution.batch_size".into(), "10".into()); - config_map.insert( - "datafusion.execution.soft_max_rows_per_output_file".into(), - "10".into(), - ); - helper_test_append_new_files_to_table( - CsvFormat::default().get_ext(), - FileCompressionType::UNCOMPRESSED, - Some(config_map), - 2, - ) - .await?; - Ok(()) - } - - #[cfg(feature = "parquet")] - #[tokio::test] - async fn test_insert_into_append_2_new_parquet_files_defaults() -> Result<()> { - let mut config_map: HashMap = HashMap::new(); - config_map.insert("datafusion.execution.batch_size".into(), "10".into()); - config_map.insert( - "datafusion.execution.soft_max_rows_per_output_file".into(), - "10".into(), - ); - helper_test_append_new_files_to_table( - ParquetFormat::default().get_ext(), - FileCompressionType::UNCOMPRESSED, - Some(config_map), - 2, - ) - .await?; - Ok(()) - } - - #[cfg(feature = "parquet")] - #[tokio::test] - async fn test_insert_into_append_1_new_parquet_files_defaults() -> Result<()> { - let mut config_map: HashMap = HashMap::new(); - config_map.insert("datafusion.execution.batch_size".into(), "20".into()); - config_map.insert( - "datafusion.execution.soft_max_rows_per_output_file".into(), - "20".into(), - ); - helper_test_append_new_files_to_table( - ParquetFormat::default().get_ext(), - FileCompressionType::UNCOMPRESSED, - Some(config_map), - 1, - ) - .await?; - Ok(()) - } - #[tokio::test] async fn test_insert_into_sql_csv_defaults() -> Result<()> { helper_test_insert_into_sql("csv", FileCompressionType::UNCOMPRESSED, "", None) @@ -2061,7 +2108,6 @@ mod tests { #[tokio::test] async fn test_insert_into_append_new_parquet_files_session_overrides() -> Result<()> { let mut config_map: HashMap = HashMap::new(); - config_map.insert("datafusion.execution.batch_size".into(), "10".into()); config_map.insert( "datafusion.execution.soft_max_rows_per_output_file".into(), "10".into(), @@ -2126,7 +2172,7 @@ mod tests { "datafusion.execution.parquet.write_batch_size".into(), "5".into(), ); - config_map.insert("datafusion.execution.batch_size".into(), "1".into()); + config_map.insert("datafusion.execution.batch_size".into(), "10".into()); helper_test_append_new_files_to_table( ParquetFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, @@ -2183,7 +2229,7 @@ mod tests { let filter_predicate = Expr::BinaryExpr(BinaryExpr::new( Box::new(Expr::Column("column1".into())), Operator::GtEq, - Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))), + Box::new(Expr::Literal(ScalarValue::Int32(Some(0)), None)), )); // Create a new batch of data to insert into the table @@ -2367,8 +2413,10 @@ mod tests { // create table let tmp_dir = TempDir::new()?; - let tmp_path = tmp_dir.into_path(); - let str_path = tmp_path.to_str().expect("Temp path should convert to &str"); + let str_path = tmp_dir + .path() + .to_str() + .expect("Temp path should convert to &str"); session_ctx .sql(&format!( "create external table foo(a varchar, b varchar, c int) \ @@ -2409,7 +2457,7 @@ mod tests { #[tokio::test] async fn test_infer_options_compressed_csv() -> Result<()> { let testdata = crate::test_util::arrow_test_data(); - let filename = format!("{}/csv/aggregate_test_100.csv.gz", testdata); + let filename = format!("{testdata}/csv/aggregate_test_100.csv.gz"); let table_path = ListingTableUrl::parse(filename).unwrap(); let ctx = SessionContext::new(); @@ -2424,4 +2472,640 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn infer_preserves_provided_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let testdata = datafusion_test_data(); + let filename = format!("{testdata}/aggregate_simple.csv"); + let table_path = ListingTableUrl::parse(filename).unwrap(); + + let provided_schema = create_test_schema(); + + let config = + ListingTableConfig::new(table_path).with_schema(Arc::clone(&provided_schema)); + + let config = config.infer(&ctx.state()).await?; + + assert_eq!(*config.file_schema.unwrap(), *provided_schema); + + Ok(()) + } + + #[tokio::test] + async fn test_listing_table_config_with_multiple_files_comprehensive() -> Result<()> { + let ctx = SessionContext::new(); + + // Create test files with different schemas + let tmp_dir = TempDir::new()?; + let file_path1 = tmp_dir.path().join("file1.csv"); + let file_path2 = tmp_dir.path().join("file2.csv"); + + // File 1: c1,c2,c3 + let mut file1 = std::fs::File::create(&file_path1)?; + writeln!(file1, "c1,c2,c3")?; + writeln!(file1, "1,2,3")?; + writeln!(file1, "4,5,6")?; + + // File 2: c1,c2,c3,c4 + let mut file2 = std::fs::File::create(&file_path2)?; + writeln!(file2, "c1,c2,c3,c4")?; + writeln!(file2, "7,8,9,10")?; + writeln!(file2, "11,12,13,14")?; + + // Parse paths + let table_path1 = ListingTableUrl::parse(file_path1.to_str().unwrap())?; + let table_path2 = ListingTableUrl::parse(file_path2.to_str().unwrap())?; + + // Create format and options + let format = CsvFormat::default().with_has_header(true); + let options = ListingOptions::new(Arc::new(format)); + + // Test case 1: Infer schema using first file's schema + let config1 = ListingTableConfig::new_with_multi_paths(vec![ + table_path1.clone(), + table_path2.clone(), + ]) + .with_listing_options(options.clone()); + let config1 = config1.infer_schema(&ctx.state()).await?; + assert_eq!(config1.schema_source(), SchemaSource::Inferred); + + // Verify schema matches first file + let schema1 = config1.file_schema.as_ref().unwrap().clone(); + assert_eq!(schema1.fields().len(), 3); + assert_eq!(schema1.field(0).name(), "c1"); + assert_eq!(schema1.field(1).name(), "c2"); + assert_eq!(schema1.field(2).name(), "c3"); + + // Test case 2: Use specified schema with 3 columns + let schema_3cols = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Utf8, true), + Field::new("c3", DataType::Utf8, true), + ])); + + let config2 = ListingTableConfig::new_with_multi_paths(vec![ + table_path1.clone(), + table_path2.clone(), + ]) + .with_schema(schema_3cols) + .with_listing_options(options.clone()); + let config2 = config2.infer_schema(&ctx.state()).await?; + assert_eq!(config2.schema_source(), SchemaSource::Specified); + + // Verify that the schema is still the one we specified (3 columns) + let schema2 = config2.file_schema.as_ref().unwrap().clone(); + assert_eq!(schema2.fields().len(), 3); + assert_eq!(schema2.field(0).name(), "c1"); + assert_eq!(schema2.field(1).name(), "c2"); + assert_eq!(schema2.field(2).name(), "c3"); + + // Test case 3: Use specified schema with 4 columns + let schema_4cols = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Utf8, true), + Field::new("c3", DataType::Utf8, true), + Field::new("c4", DataType::Utf8, true), + ])); + + let config3 = ListingTableConfig::new_with_multi_paths(vec![ + table_path1.clone(), + table_path2.clone(), + ]) + .with_schema(schema_4cols) + .with_listing_options(options.clone()); + let config3 = config3.infer_schema(&ctx.state()).await?; + assert_eq!(config3.schema_source(), SchemaSource::Specified); + + // Verify that the schema is still the one we specified (4 columns) + let schema3 = config3.file_schema.as_ref().unwrap().clone(); + assert_eq!(schema3.fields().len(), 4); + assert_eq!(schema3.field(0).name(), "c1"); + assert_eq!(schema3.field(1).name(), "c2"); + assert_eq!(schema3.field(2).name(), "c3"); + assert_eq!(schema3.field(3).name(), "c4"); + + // Test case 4: Verify order matters when inferring schema + let config4 = ListingTableConfig::new_with_multi_paths(vec![ + table_path2.clone(), + table_path1.clone(), + ]) + .with_listing_options(options); + let config4 = config4.infer_schema(&ctx.state()).await?; + + // Should use first file's schema, which now has 4 columns + let schema4 = config4.file_schema.as_ref().unwrap().clone(); + assert_eq!(schema4.fields().len(), 4); + assert_eq!(schema4.field(0).name(), "c1"); + assert_eq!(schema4.field(1).name(), "c2"); + assert_eq!(schema4.field(2).name(), "c3"); + assert_eq!(schema4.field(3).name(), "c4"); + + Ok(()) + } + + #[tokio::test] + async fn test_list_files_configurations() -> Result<()> { + // Define common test cases as (description, files, paths, target_partitions, expected_partitions, file_ext) + let test_cases = vec![ + // Single path cases + ( + "Single path, more partitions than files", + generate_test_files("bucket/key-prefix", 5), + vec!["test:///bucket/key-prefix/"], + 12, + 5, + Some(""), + ), + ( + "Single path, equal partitions and files", + generate_test_files("bucket/key-prefix", 4), + vec!["test:///bucket/key-prefix/"], + 4, + 4, + Some(""), + ), + ( + "Single path, more files than partitions", + generate_test_files("bucket/key-prefix", 5), + vec!["test:///bucket/key-prefix/"], + 2, + 2, + Some(""), + ), + // Multi path cases + ( + "Multi path, more partitions than files", + { + let mut files = generate_test_files("bucket/key1", 3); + files.extend(generate_test_files_with_start("bucket/key2", 2, 3)); + files.extend(generate_test_files_with_start("bucket/key3", 1, 5)); + files + }, + vec!["test:///bucket/key1/", "test:///bucket/key2/"], + 12, + 5, + Some(""), + ), + // No files case + ( + "No files", + vec![], + vec!["test:///bucket/key-prefix/"], + 2, + 0, + Some(""), + ), + // Exact path cases + ( + "Exact paths test", + { + let mut files = generate_test_files("bucket/key1", 3); + files.extend(generate_test_files_with_start("bucket/key2", 2, 3)); + files + }, + vec![ + "test:///bucket/key1/file0", + "test:///bucket/key1/file1", + "test:///bucket/key1/file2", + "test:///bucket/key2/file3", + "test:///bucket/key2/file4", + ], + 12, + 5, + Some(""), + ), + ]; + + // Run each test case + for (test_name, files, paths, target_partitions, expected_partitions, file_ext) in + test_cases + { + println!("Running test: {test_name}"); + + if files.is_empty() { + // Test empty files case + assert_list_files_for_multi_paths( + &[], + &paths, + target_partitions, + expected_partitions, + file_ext, + ) + .await?; + } else if paths.len() == 1 { + // Test using single path API + let file_refs: Vec<&str> = files.iter().map(|s| s.as_str()).collect(); + assert_list_files_for_scan_grouping( + &file_refs, + paths[0], + target_partitions, + expected_partitions, + file_ext, + ) + .await?; + } else if paths[0].contains("test:///bucket/key") { + // Test using multi path API + let file_refs: Vec<&str> = files.iter().map(|s| s.as_str()).collect(); + assert_list_files_for_multi_paths( + &file_refs, + &paths, + target_partitions, + expected_partitions, + file_ext, + ) + .await?; + } else { + // Test using exact path API for specific cases + let file_refs: Vec<&str> = files.iter().map(|s| s.as_str()).collect(); + assert_list_files_for_exact_paths( + &file_refs, + target_partitions, + expected_partitions, + file_ext, + ) + .await?; + } + } + + Ok(()) + } + + #[cfg(feature = "parquet")] + #[tokio::test] + async fn test_table_stats_behaviors() -> Result<()> { + use crate::datasource::file_format::parquet::ParquetFormat; + + let testdata = crate::test_util::parquet_test_data(); + let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); + let table_path = ListingTableUrl::parse(filename).unwrap(); + + let ctx = SessionContext::new(); + let state = ctx.state(); + + // Test 1: Default behavior - stats not collected + let opt_default = ListingOptions::new(Arc::new(ParquetFormat::default())); + let schema_default = opt_default.infer_schema(&state, &table_path).await?; + let config_default = ListingTableConfig::new(table_path.clone()) + .with_listing_options(opt_default) + .with_schema(schema_default); + let table_default = ListingTable::try_new(config_default)?; + + let exec_default = table_default.scan(&state, None, &[], None).await?; + assert_eq!( + exec_default.partition_statistics(None)?.num_rows, + Precision::Absent + ); + + // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 + assert_eq!( + exec_default.partition_statistics(None)?.total_byte_size, + Precision::Absent + ); + + // Test 2: Explicitly disable stats + let opt_disabled = ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(false); + let schema_disabled = opt_disabled.infer_schema(&state, &table_path).await?; + let config_disabled = ListingTableConfig::new(table_path.clone()) + .with_listing_options(opt_disabled) + .with_schema(schema_disabled); + let table_disabled = ListingTable::try_new(config_disabled)?; + + let exec_disabled = table_disabled.scan(&state, None, &[], None).await?; + assert_eq!( + exec_disabled.partition_statistics(None)?.num_rows, + Precision::Absent + ); + assert_eq!( + exec_disabled.partition_statistics(None)?.total_byte_size, + Precision::Absent + ); + + // Test 3: Explicitly enable stats + let opt_enabled = ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(true); + let schema_enabled = opt_enabled.infer_schema(&state, &table_path).await?; + let config_enabled = ListingTableConfig::new(table_path) + .with_listing_options(opt_enabled) + .with_schema(schema_enabled); + let table_enabled = ListingTable::try_new(config_enabled)?; + + let exec_enabled = table_enabled.scan(&state, None, &[], None).await?; + assert_eq!( + exec_enabled.partition_statistics(None)?.num_rows, + Precision::Exact(8) + ); + // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 + assert_eq!( + exec_enabled.partition_statistics(None)?.total_byte_size, + Precision::Exact(671) + ); + + Ok(()) + } + + #[tokio::test] + async fn test_insert_into_parameterized() -> Result<()> { + let test_cases = vec![ + // (file_format, batch_size, soft_max_rows, expected_files) + ("json", 10, 10, 2), + ("csv", 10, 10, 2), + #[cfg(feature = "parquet")] + ("parquet", 10, 10, 2), + #[cfg(feature = "parquet")] + ("parquet", 20, 20, 1), + ]; + + for (format, batch_size, soft_max_rows, expected_files) in test_cases { + println!("Testing insert with format: {format}, batch_size: {batch_size}, expected files: {expected_files}"); + + let mut config_map = HashMap::new(); + config_map.insert( + "datafusion.execution.batch_size".into(), + batch_size.to_string(), + ); + config_map.insert( + "datafusion.execution.soft_max_rows_per_output_file".into(), + soft_max_rows.to_string(), + ); + + let file_extension = match format { + "json" => JsonFormat::default().get_ext(), + "csv" => CsvFormat::default().get_ext(), + #[cfg(feature = "parquet")] + "parquet" => ParquetFormat::default().get_ext(), + _ => unreachable!("Unsupported format"), + }; + + helper_test_append_new_files_to_table( + file_extension, + FileCompressionType::UNCOMPRESSED, + Some(config_map), + expected_files, + ) + .await?; + } + + Ok(()) + } + + #[tokio::test] + async fn test_statistics_mapping_with_custom_factory() -> Result<()> { + let ctx = SessionContext::new(); + let table = create_test_listing_table_with_json_and_adapter( + &ctx, + false, + // NullStatsAdapterFactory sets column_statistics null_count to DUMMY_NULL_COUNT + Arc::new(NullStatsAdapterFactory {}), + )?; + + let (groups, stats) = table.list_files_for_scan(&ctx.state(), &[], None).await?; + + assert_eq!(stats.column_statistics[0].null_count, DUMMY_NULL_COUNT); + for g in groups { + if let Some(s) = g.file_statistics(None) { + assert_eq!(s.column_statistics[0].null_count, DUMMY_NULL_COUNT); + } + } + + Ok(()) + } + + #[tokio::test] + async fn test_statistics_mapping_with_default_factory() -> Result<()> { + let ctx = SessionContext::new(); + + // Create a table without providing a custom schema adapter factory + // This should fall back to using DefaultSchemaAdapterFactory + let path = "table/file.json"; + register_test_store(&ctx, &[(path, 10)]); + + let format = JsonFormat::default(); + let opt = ListingOptions::new(Arc::new(format)).with_collect_stat(false); + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let table_path = ListingTableUrl::parse("test:///table/").unwrap(); + + let config = ListingTableConfig::new(table_path) + .with_listing_options(opt) + .with_schema(Arc::new(schema)); + // Note: NOT calling .with_schema_adapter_factory() to test default behavior + + let table = ListingTable::try_new(config)?; + + // Verify that no custom schema adapter factory is set + assert!(table.schema_adapter_factory().is_none()); + + // The scan should work correctly with the default schema adapter + let scan_result = table.scan(&ctx.state(), None, &[], None).await; + assert!( + scan_result.is_ok(), + "Scan should succeed with default schema adapter" + ); + + // Verify that the default adapter handles basic schema compatibility + let (groups, _stats) = table.list_files_for_scan(&ctx.state(), &[], None).await?; + assert!( + !groups.is_empty(), + "Should list files successfully with default adapter" + ); + + Ok(()) + } + + #[rstest] + #[case(MapSchemaError::TypeIncompatible, "Cannot map incompatible types")] + #[case(MapSchemaError::GeneralFailure, "Schema adapter mapping failed")] + #[case( + MapSchemaError::InvalidProjection, + "Invalid projection in schema mapping" + )] + #[tokio::test] + async fn test_schema_adapter_map_schema_errors( + #[case] error_type: MapSchemaError, + #[case] expected_error_msg: &str, + ) -> Result<()> { + let ctx = SessionContext::new(); + let table = create_test_listing_table_with_json_and_adapter( + &ctx, + false, + Arc::new(FailingMapSchemaAdapterFactory { error_type }), + )?; + + // The error should bubble up from the scan operation when schema mapping fails + let scan_result = table.scan(&ctx.state(), None, &[], None).await; + + assert!(scan_result.is_err()); + let error_msg = scan_result.unwrap_err().to_string(); + assert!( + error_msg.contains(expected_error_msg), + "Expected error containing '{expected_error_msg}', got: {error_msg}" + ); + + Ok(()) + } + + // Test that errors during file listing also bubble up correctly + #[tokio::test] + async fn test_schema_adapter_error_during_file_listing() -> Result<()> { + let ctx = SessionContext::new(); + let table = create_test_listing_table_with_json_and_adapter( + &ctx, + true, + Arc::new(FailingMapSchemaAdapterFactory { + error_type: MapSchemaError::TypeIncompatible, + }), + )?; + + // The error should bubble up from list_files_for_scan when collecting statistics + let list_result = table.list_files_for_scan(&ctx.state(), &[], None).await; + + assert!(list_result.is_err()); + let error_msg = list_result.unwrap_err().to_string(); + assert!( + error_msg.contains("Cannot map incompatible types"), + "Expected type incompatibility error during file listing, got: {error_msg}" + ); + + Ok(()) + } + + #[derive(Debug, Copy, Clone)] + enum MapSchemaError { + TypeIncompatible, + GeneralFailure, + InvalidProjection, + } + + #[derive(Debug)] + struct FailingMapSchemaAdapterFactory { + error_type: MapSchemaError, + } + + impl SchemaAdapterFactory for FailingMapSchemaAdapterFactory { + fn create( + &self, + projected_table_schema: SchemaRef, + _table_schema: SchemaRef, + ) -> Box { + Box::new(FailingMapSchemaAdapter { + schema: projected_table_schema, + error_type: self.error_type, + }) + } + } + + #[derive(Debug)] + struct FailingMapSchemaAdapter { + schema: SchemaRef, + error_type: MapSchemaError, + } + + impl SchemaAdapter for FailingMapSchemaAdapter { + fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { + let field = self.schema.field(index); + file_schema.fields.find(field.name()).map(|(i, _)| i) + } + + fn map_schema( + &self, + _file_schema: &Schema, + ) -> Result<(Arc, Vec)> { + // Always fail with different error types based on the configured error_type + match self.error_type { + MapSchemaError::TypeIncompatible => { + plan_err!( + "Cannot map incompatible types: Boolean cannot be cast to Utf8" + ) + } + MapSchemaError::GeneralFailure => { + plan_err!("Schema adapter mapping failed due to internal error") + } + MapSchemaError::InvalidProjection => { + plan_err!("Invalid projection in schema mapping: column index out of bounds") + } + } + } + } + + #[derive(Debug)] + struct NullStatsAdapterFactory; + + impl SchemaAdapterFactory for NullStatsAdapterFactory { + fn create( + &self, + projected_table_schema: SchemaRef, + _table_schema: SchemaRef, + ) -> Box { + Box::new(NullStatsAdapter { + schema: projected_table_schema, + }) + } + } + + #[derive(Debug)] + struct NullStatsAdapter { + schema: SchemaRef, + } + + impl SchemaAdapter for NullStatsAdapter { + fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { + let field = self.schema.field(index); + file_schema.fields.find(field.name()).map(|(i, _)| i) + } + + fn map_schema( + &self, + file_schema: &Schema, + ) -> Result<(Arc, Vec)> { + let projection = (0..file_schema.fields().len()).collect(); + Ok((Arc::new(NullStatsMapper {}), projection)) + } + } + + #[derive(Debug)] + struct NullStatsMapper; + + impl SchemaMapper for NullStatsMapper { + fn map_batch(&self, batch: RecordBatch) -> Result { + Ok(batch) + } + + fn map_column_statistics( + &self, + stats: &[ColumnStatistics], + ) -> Result> { + Ok(stats + .iter() + .map(|s| { + let mut s = s.clone(); + s.null_count = DUMMY_NULL_COUNT; + s + }) + .collect()) + } + } + + /// Helper function to create a test ListingTable with JSON format and custom schema adapter factory + fn create_test_listing_table_with_json_and_adapter( + ctx: &SessionContext, + collect_stat: bool, + schema_adapter_factory: Arc, + ) -> Result { + let path = "table/file.json"; + register_test_store(ctx, &[(path, 10)]); + + let format = JsonFormat::default(); + let opt = ListingOptions::new(Arc::new(format)).with_collect_stat(collect_stat); + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let table_path = ListingTableUrl::parse("test:///table/").unwrap(); + + let config = ListingTableConfig::new(table_path) + .with_listing_options(opt) + .with_schema(Arc::new(schema)) + .with_schema_adapter_factory(schema_adapter_factory); + + ListingTable::try_new(config) + } } diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 636d1623c5e91..f98297d0e3f7f 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -27,7 +27,7 @@ use crate::datasource::listing::{ }; use crate::execution::context::SessionState; -use arrow::datatypes::{DataType, SchemaRef}; +use arrow::datatypes::DataType; use datafusion_common::{arrow_datafusion_err, plan_err, DataFusionError, ToDFSchema}; use datafusion_common::{config_datafusion_err, Result}; use datafusion_expr::CreateExternalTable; @@ -63,16 +63,39 @@ impl TableProviderFactory for ListingTableFactory { ))? .create(session_state, &cmd.options)?; - let file_extension = get_extension(cmd.location.as_str()); + let mut table_path = ListingTableUrl::parse(&cmd.location)?; + let file_extension = match table_path.is_collection() { + // Setting the extension to be empty instead of allowing the default extension seems + // odd, but was done to ensure existing behavior isn't modified. It seems like this + // could be refactored to either use the default extension or set the fully expected + // extension when compression is included (e.g. ".csv.gz") + true => "", + false => &get_extension(cmd.location.as_str()), + }; + let mut options = ListingOptions::new(file_format) + .with_session_config_options(session_state.config()) + .with_file_extension(file_extension); let (provided_schema, table_partition_cols) = if cmd.schema.fields().is_empty() { + let infer_parts = session_state + .config_options() + .execution + .listing_table_factory_infer_partitions; + let part_cols = if cmd.table_partition_cols.is_empty() && infer_parts { + options + .infer_partitions(session_state, &table_path) + .await? + .into_iter() + } else { + cmd.table_partition_cols.clone().into_iter() + }; + ( None, - cmd.table_partition_cols - .iter() - .map(|x| { + part_cols + .map(|p| { ( - x.clone(), + p, DataType::Dictionary( Box::new(DataType::UInt16), Box::new(DataType::Utf8), @@ -82,7 +105,7 @@ impl TableProviderFactory for ListingTableFactory { .collect::>(), ) } else { - let schema: SchemaRef = Arc::new(cmd.schema.as_ref().to_owned().into()); + let schema = Arc::clone(cmd.schema.inner()); let table_partition_cols = cmd .table_partition_cols .iter() @@ -108,13 +131,7 @@ impl TableProviderFactory for ListingTableFactory { (Some(schema), table_partition_cols) }; - let table_path = ListingTableUrl::parse(&cmd.location)?; - - let options = ListingOptions::new(file_format) - .with_collect_stat(state.config().collect_statistics()) - .with_file_extension(file_extension) - .with_target_partitions(state.config().target_partitions()) - .with_table_partition_cols(table_partition_cols); + options = options.with_table_partition_cols(table_partition_cols); options .validate_partitions(session_state, &table_path) @@ -126,6 +143,25 @@ impl TableProviderFactory for ListingTableFactory { // specifically for parquet file format. // See: https://github.com/apache/datafusion/issues/7317 None => { + // if the folder then rewrite a file path as 'path/*.parquet' + // to only read the files the reader can understand + if table_path.is_folder() && table_path.get_glob().is_none() { + // Since there are no files yet to infer an actual extension, + // derive the pattern based on compression type. + // So for gzipped CSV the pattern is `*.csv.gz` + let glob = match options.format.compression_type() { + Some(compression) => { + match options.format.get_ext_with_compression(&compression) { + // Use glob based on `FileFormat` extension + Ok(ext) => format!("*.{ext}"), + // Fallback to `file_type`, if not supported by `FileFormat` + Err(_) => format!("*.{}", cmd.file_type.to_lowercase()), + } + } + None => format!("*.{}", cmd.file_type.to_lowercase()), + }; + table_path = table_path.with_glob(glob.as_ref())?; + } let schema = options.infer_schema(session_state, &table_path).await?; let df_schema = Arc::clone(&schema).to_dfschema()?; let column_refs: HashSet<_> = cmd @@ -169,13 +205,18 @@ fn get_extension(path: &str) -> String { #[cfg(test)] mod tests { + use datafusion_execution::config::SessionConfig; + use glob::Pattern; use std::collections::HashMap; + use std::fs; + use std::path::PathBuf; use super::*; use crate::{ datasource::file_format::csv::CsvFormat, execution::context::SessionContext, }; + use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{Constraints, DFSchema, TableReference}; #[tokio::test] @@ -197,12 +238,13 @@ mod tests { schema: Arc::new(DFSchema::empty()), table_partition_cols: vec![], if_not_exists: false, + or_replace: false, temporary: false, definition: None, order_exprs: vec![], unbounded: false, options: HashMap::from([("format.has_header".into(), "true".into())]), - constraints: Constraints::empty(), + constraints: Constraints::default(), column_defaults: HashMap::new(), }; let table_provider = factory.create(&state, &cmd).await.unwrap(); @@ -237,12 +279,13 @@ mod tests { schema: Arc::new(DFSchema::empty()), table_partition_cols: vec![], if_not_exists: false, + or_replace: false, temporary: false, definition: None, order_exprs: vec![], unbounded: false, options, - constraints: Constraints::empty(), + constraints: Constraints::default(), column_defaults: HashMap::new(), }; let table_provider = factory.create(&state, &cmd).await.unwrap(); @@ -258,4 +301,222 @@ mod tests { let listing_options = listing_table.options(); assert_eq!(".tbl", listing_options.file_extension); } + + /// Validates that CreateExternalTable with compression + /// searches for gzipped files in a directory location + #[tokio::test] + async fn test_create_using_folder_with_compression() { + let dir = tempfile::tempdir().unwrap(); + + let factory = ListingTableFactory::new(); + let context = SessionContext::new(); + let state = context.state(); + let name = TableReference::bare("foo"); + + let mut options = HashMap::new(); + options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); + options.insert("format.has_header".into(), "true".into()); + options.insert("format.compression".into(), "gzip".into()); + let cmd = CreateExternalTable { + name, + location: dir.path().to_str().unwrap().to_string(), + file_type: "csv".to_string(), + schema: Arc::new(DFSchema::empty()), + table_partition_cols: vec![], + if_not_exists: false, + or_replace: false, + temporary: false, + definition: None, + order_exprs: vec![], + unbounded: false, + options, + constraints: Constraints::default(), + column_defaults: HashMap::new(), + }; + let table_provider = factory.create(&state, &cmd).await.unwrap(); + let listing_table = table_provider + .as_any() + .downcast_ref::() + .unwrap(); + + // Verify compression is used + let format = listing_table.options().format.clone(); + let csv_format = format.as_any().downcast_ref::().unwrap(); + let csv_options = csv_format.options().clone(); + assert_eq!(csv_options.compression, CompressionTypeVariant::GZIP); + + let listing_options = listing_table.options(); + assert_eq!("", listing_options.file_extension); + // Glob pattern is set to search for gzipped files + let table_path = listing_table.table_paths().first().unwrap(); + assert_eq!( + table_path.get_glob().clone().unwrap(), + Pattern::new("*.csv.gz").unwrap() + ); + } + + /// Validates that CreateExternalTable without compression + /// searches for normal files in a directory location + #[tokio::test] + async fn test_create_using_folder_without_compression() { + let dir = tempfile::tempdir().unwrap(); + + let factory = ListingTableFactory::new(); + let context = SessionContext::new(); + let state = context.state(); + let name = TableReference::bare("foo"); + + let mut options = HashMap::new(); + options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); + options.insert("format.has_header".into(), "true".into()); + let cmd = CreateExternalTable { + name, + location: dir.path().to_str().unwrap().to_string(), + file_type: "csv".to_string(), + schema: Arc::new(DFSchema::empty()), + table_partition_cols: vec![], + if_not_exists: false, + or_replace: false, + temporary: false, + definition: None, + order_exprs: vec![], + unbounded: false, + options, + constraints: Constraints::default(), + column_defaults: HashMap::new(), + }; + let table_provider = factory.create(&state, &cmd).await.unwrap(); + let listing_table = table_provider + .as_any() + .downcast_ref::() + .unwrap(); + + let listing_options = listing_table.options(); + assert_eq!("", listing_options.file_extension); + // Glob pattern is set to search for gzipped files + let table_path = listing_table.table_paths().first().unwrap(); + assert_eq!( + table_path.get_glob().clone().unwrap(), + Pattern::new("*.csv").unwrap() + ); + } + + #[tokio::test] + async fn test_odd_directory_names() { + let dir = tempfile::tempdir().unwrap(); + let mut path = PathBuf::from(dir.path()); + path.extend(["odd.v1", "odd.v2"]); + fs::create_dir_all(&path).unwrap(); + + let factory = ListingTableFactory::new(); + let context = SessionContext::new(); + let state = context.state(); + let name = TableReference::bare("foo"); + + let cmd = CreateExternalTable { + name, + location: String::from(path.to_str().unwrap()), + file_type: "parquet".to_string(), + schema: Arc::new(DFSchema::empty()), + table_partition_cols: vec![], + if_not_exists: false, + or_replace: false, + temporary: false, + definition: None, + order_exprs: vec![], + unbounded: false, + options: HashMap::new(), + constraints: Constraints::default(), + column_defaults: HashMap::new(), + }; + let table_provider = factory.create(&state, &cmd).await.unwrap(); + let listing_table = table_provider + .as_any() + .downcast_ref::() + .unwrap(); + + let listing_options = listing_table.options(); + assert_eq!("", listing_options.file_extension); + } + + #[tokio::test] + async fn test_create_with_hive_partitions() { + let dir = tempfile::tempdir().unwrap(); + let mut path = PathBuf::from(dir.path()); + path.extend(["key1=value1", "key2=value2"]); + fs::create_dir_all(&path).unwrap(); + path.push("data.parquet"); + fs::File::create_new(&path).unwrap(); + + let factory = ListingTableFactory::new(); + let context = SessionContext::new(); + let state = context.state(); + let name = TableReference::bare("foo"); + + let cmd = CreateExternalTable { + name, + location: dir.path().to_str().unwrap().to_string(), + file_type: "parquet".to_string(), + schema: Arc::new(DFSchema::empty()), + table_partition_cols: vec![], + if_not_exists: false, + or_replace: false, + temporary: false, + definition: None, + order_exprs: vec![], + unbounded: false, + options: HashMap::new(), + constraints: Constraints::default(), + column_defaults: HashMap::new(), + }; + let table_provider = factory.create(&state, &cmd).await.unwrap(); + let listing_table = table_provider + .as_any() + .downcast_ref::() + .unwrap(); + + let listing_options = listing_table.options(); + let dtype = + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)); + let expected_cols = vec![ + (String::from("key1"), dtype.clone()), + (String::from("key2"), dtype.clone()), + ]; + assert_eq!(expected_cols, listing_options.table_partition_cols); + + // Ensure partition detection can be disabled via config + let factory = ListingTableFactory::new(); + let mut cfg = SessionConfig::new(); + cfg.options_mut() + .execution + .listing_table_factory_infer_partitions = false; + let context = SessionContext::new_with_config(cfg); + let state = context.state(); + let name = TableReference::bare("foo"); + + let cmd = CreateExternalTable { + name, + location: dir.path().to_str().unwrap().to_string(), + file_type: "parquet".to_string(), + schema: Arc::new(DFSchema::empty()), + table_partition_cols: vec![], + if_not_exists: false, + or_replace: false, + temporary: false, + definition: None, + order_exprs: vec![], + unbounded: false, + options: HashMap::new(), + constraints: Constraints::default(), + column_defaults: HashMap::new(), + }; + let table_provider = factory.create(&state, &cmd).await.unwrap(); + let listing_table = table_provider + .as_any() + .downcast_ref::() + .unwrap(); + + let listing_options = listing_table.options(); + assert!(listing_options.table_partition_cols.is_empty()); + } } diff --git a/datafusion/core/src/datasource/memory_test.rs b/datafusion/core/src/datasource/memory_test.rs index 381000ab8ee1e..c16837c73b4f1 100644 --- a/datafusion/core/src/datasource/memory_test.rs +++ b/datafusion/core/src/datasource/memory_test.rs @@ -130,12 +130,15 @@ mod tests { .scan(&session_ctx.state(), Some(&projection), &[], None) .await { - Err(DataFusionError::ArrowError(ArrowError::SchemaError(e), _)) => { - assert_eq!( - "\"project index 4 out of bounds, max field 3\"", - format!("{e:?}") - ) - } + Err(DataFusionError::ArrowError(err, _)) => match err.as_ref() { + ArrowError::SchemaError(e) => { + assert_eq!( + "\"project index 4 out of bounds, max field 3\"", + format!("{e:?}") + ) + } + _ => panic!("unexpected error"), + }, res => panic!("Scan should failed on invalid projection, got {res:?}"), }; @@ -443,7 +446,7 @@ mod tests { .unwrap_err(); // Ensure that there is a descriptive error message assert_eq!( - "Error during planning: Cannot insert into MemTable with zero partitions", + "Error during planning: No partitions provided, expected at least one partition", experiment_result.strip_backtrace() ); Ok(()) diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index a15b2b6ffe137..94d651ddadd5c 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -52,27 +52,26 @@ pub use datafusion_physical_expr::create_ordering; mod tests { use crate::prelude::SessionContext; - - use std::fs; - use std::sync::Arc; - - use arrow::array::{Int32Array, StringArray}; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use arrow::record_batch::RecordBatch; - use datafusion_common::test_util::batches_to_sort_string; - use datafusion_datasource::file_scan_config::FileScanConfigBuilder; - use datafusion_datasource::schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, SchemaMapper, + use ::object_store::{path::Path, ObjectMeta}; + use arrow::{ + array::{Int32Array, StringArray}, + datatypes::{DataType, Field, Schema, SchemaRef}, + record_batch::RecordBatch, + }; + use datafusion_common::{record_batch, test_util::batches_to_sort_string}; + use datafusion_datasource::{ + file::FileSource, + file_scan_config::FileScanConfigBuilder, + schema_adapter::{ + DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, + SchemaMapper, + }, + source::DataSourceExec, + PartitionedFile, }; - use datafusion_datasource::PartitionedFile; use datafusion_datasource_parquet::source::ParquetSource; - - use datafusion_common::record_batch; - - use ::object_store::path::Path; - use ::object_store::ObjectMeta; - use datafusion_datasource::source::DataSourceExec; use datafusion_physical_plan::collect; + use std::{fs, sync::Arc}; use tempfile::TempDir; #[tokio::test] @@ -106,7 +105,7 @@ mod tests { let meta = ObjectMeta { location, last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), - size: metadata.len() as usize, + size: metadata.len(), e_tag: None, version: None, }; @@ -124,10 +123,9 @@ mod tests { let f2 = Field::new("extra_column", DataType::Utf8, true); let schema = Arc::new(Schema::new(vec![f1.clone(), f2.clone()])); - let source = Arc::new( - ParquetSource::default() - .with_schema_adapter_factory(Arc::new(TestSchemaAdapterFactory {})), - ); + let source = ParquetSource::default() + .with_schema_adapter_factory(Arc::new(TestSchemaAdapterFactory {})) + .unwrap(); let base_conf = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), schema, @@ -264,5 +262,12 @@ mod tests { Ok(RecordBatch::try_new(schema, new_columns).unwrap()) } + + fn map_column_statistics( + &self, + _file_col_statistics: &[datafusion_common::ColumnStatistics], + ) -> datafusion_common::Result> { + unimplemented!() + } } } diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index 5dcf4df73f57a..b37dc499d4035 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -15,196 +15,40 @@ // specific language governing permissions and limitations // under the License. -//! Execution plan for reading Arrow files - use std::any::Any; use std::sync::Arc; -use crate::datasource::physical_plan::{FileMeta, FileOpenFuture, FileOpener}; +use crate::datasource::physical_plan::{FileOpenFuture, FileOpener}; use crate::error::Result; +use datafusion_datasource::as_file_source; +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use arrow::buffer::Buffer; use arrow::datatypes::SchemaRef; use arrow_ipc::reader::FileDecoder; -use datafusion_common::config::ConfigOptions; -use datafusion_common::{Constraints, Statistics}; +use datafusion_common::{exec_datafusion_err, Statistics}; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::source::DataSourceExec; -use datafusion_datasource_json::source::JsonSource; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, -}; +use datafusion_datasource::PartitionedFile; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; -use datafusion_datasource::file_groups::FileGroup; use futures::StreamExt; use itertools::Itertools; use object_store::{GetOptions, GetRange, GetResultPayload, ObjectStore}; -/// Execution plan for scanning Arrow data source -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -pub struct ArrowExec { - inner: DataSourceExec, - base_config: FileScanConfig, -} - -#[allow(unused, deprecated)] -impl ArrowExec { - /// Create a new Arrow reader execution plan provided base configurations - pub fn new(base_config: FileScanConfig) -> Self { - let ( - projected_schema, - projected_constraints, - projected_statistics, - projected_output_ordering, - ) = base_config.project(); - let cache = Self::compute_properties( - Arc::clone(&projected_schema), - &projected_output_ordering, - projected_constraints, - &base_config, - ); - let arrow = ArrowSource::default(); - let base_config = base_config.with_source(Arc::new(arrow)); - Self { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - base_config, - } - } - /// Ref to the base configs - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - - fn file_scan_config(&self) -> FileScanConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn json_source(&self) -> JsonSource { - self.file_scan_config() - .file_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn output_partitioning_helper(file_scan_config: &FileScanConfig) -> Partitioning { - Partitioning::UnknownPartitioning(file_scan_config.file_groups.len()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - schema: SchemaRef, - output_ordering: &[LexOrdering], - constraints: Constraints, - file_scan_config: &FileScanConfig, - ) -> PlanProperties { - // Equivalence Properties - let eq_properties = - EquivalenceProperties::new_with_orderings(schema, output_ordering) - .with_constraints(constraints); - - PlanProperties::new( - eq_properties, - Self::output_partitioning_helper(file_scan_config), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } - - fn with_file_groups(mut self, file_groups: Vec) -> Self { - self.base_config.file_groups = file_groups.clone(); - let mut file_source = self.file_scan_config(); - file_source = file_source.with_file_groups(file_groups); - self.inner = self.inner.with_data_source(Arc::new(file_source)); - self - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for ArrowExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for ArrowExec { - fn name(&self) -> &'static str { - "ArrowExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - fn children(&self) -> Vec<&Arc> { - Vec::new() - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - /// Redistribute files across partitions according to their size - /// See comments on `FileGroupPartitioner` for more detail. - fn repartitioned( - &self, - target_partitions: usize, - config: &ConfigOptions, - ) -> Result>> { - self.inner.repartitioned(target_partitions, config) - } - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, context) - } - fn metrics(&self) -> Option { - self.inner.metrics() - } - fn statistics(&self) -> Result { - self.inner.statistics() - } - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } -} - /// Arrow configuration struct that is given to DataSourceExec /// Does not hold anything special, since [`FileScanConfig`] is sufficient for arrow #[derive(Clone, Default)] pub struct ArrowSource { metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, +} + +impl From for Arc { + fn from(source: ArrowSource) -> Self { + as_file_source(source) + } } impl FileSource for ArrowSource { @@ -255,6 +99,20 @@ impl FileSource for ArrowSource { fn file_type(&self) -> &str { "arrow" } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } } /// The struct arrow that implements `[FileOpener]` trait @@ -264,20 +122,25 @@ pub struct ArrowOpener { } impl FileOpener for ArrowOpener { - fn open(&self, file_meta: FileMeta) -> Result { + fn open(&self, partitioned_file: PartitionedFile) -> Result { let object_store = Arc::clone(&self.object_store); let projection = self.projection.clone(); Ok(Box::pin(async move { - let range = file_meta.range.clone(); + let range = partitioned_file.range.clone(); match range { None => { - let r = object_store.get(file_meta.location()).await?; + let r = object_store + .get(&partitioned_file.object_meta.location) + .await?; match r.payload { + #[cfg(not(target_arch = "wasm32"))] GetResultPayload::File(file, _) => { let arrow_reader = arrow::ipc::reader::FileReader::try_new( file, projection, )?; - Ok(futures::stream::iter(arrow_reader).boxed()) + Ok(futures::stream::iter(arrow_reader) + .map(|r| r.map_err(Into::into)) + .boxed()) } GetResultPayload::Stream(_) => { let bytes = r.bytes().await?; @@ -285,7 +148,9 @@ impl FileOpener for ArrowOpener { let arrow_reader = arrow::ipc::reader::FileReader::try_new( cursor, projection, )?; - Ok(futures::stream::iter(arrow_reader).boxed()) + Ok(futures::stream::iter(arrow_reader) + .map(|r| r.map_err(Into::into)) + .boxed()) } } } @@ -297,7 +162,7 @@ impl FileOpener for ArrowOpener { ..Default::default() }; let get_result = object_store - .get_opts(file_meta.location(), get_option) + .get_opts(&partitioned_file.object_meta.location, get_option) .await?; let footer_len_buf = get_result.bytes().await?; let footer_len = arrow_ipc::reader::read_footer_length( @@ -305,20 +170,18 @@ impl FileOpener for ArrowOpener { )?; // read footer according to footer_len let get_option = GetOptions { - range: Some(GetRange::Suffix(10 + footer_len)), + range: Some(GetRange::Suffix(10 + (footer_len as u64))), ..Default::default() }; let get_result = object_store - .get_opts(file_meta.location(), get_option) + .get_opts(&partitioned_file.object_meta.location, get_option) .await?; let footer_buf = get_result.bytes().await?; let footer = arrow_ipc::root_as_footer( footer_buf[..footer_len].try_into().unwrap(), ) .map_err(|err| { - arrow::error::ArrowError::ParseError(format!( - "Unable to get root as footer: {err:?}" - )) + exec_datafusion_err!("Unable to get root as footer: {err:?}") })?; // build decoder according to footer & projection let schema = @@ -332,14 +195,14 @@ impl FileOpener for ArrowOpener { .iter() .flatten() .map(|block| { - let block_len = block.bodyLength() as usize - + block.metaDataLength() as usize; - let block_offset = block.offset() as usize; + let block_len = + block.bodyLength() as u64 + block.metaDataLength() as u64; + let block_offset = block.offset() as u64; block_offset..block_offset + block_len }) .collect_vec(); let dict_results = object_store - .get_ranges(file_meta.location(), &dict_ranges) + .get_ranges(&partitioned_file.object_meta.location, &dict_ranges) .await?; for (dict_block, dict_result) in footer.dictionaries().iter().flatten().zip(dict_results) @@ -354,9 +217,9 @@ impl FileOpener for ArrowOpener { .iter() .flatten() .filter(|block| { - let block_offset = block.offset() as usize; - block_offset >= range.start as usize - && block_offset < range.end as usize + let block_offset = block.offset() as u64; + block_offset >= range.start as u64 + && block_offset < range.end as u64 }) .copied() .collect_vec(); @@ -364,15 +227,18 @@ impl FileOpener for ArrowOpener { let recordbatch_ranges = recordbatches .iter() .map(|block| { - let block_len = block.bodyLength() as usize - + block.metaDataLength() as usize; - let block_offset = block.offset() as usize; + let block_len = + block.bodyLength() as u64 + block.metaDataLength() as u64; + let block_offset = block.offset() as u64; block_offset..block_offset + block_len }) .collect_vec(); let recordbatch_results = object_store - .get_ranges(file_meta.location(), &recordbatch_ranges) + .get_ranges( + &partitioned_file.object_meta.location, + &recordbatch_ranges, + ) .await?; Ok(futures::stream::iter( @@ -385,6 +251,7 @@ impl FileOpener for ArrowOpener { .transpose() }), ) + .map(|r| r.map_err(Into::into)) .boxed()) } } diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 5914924797dce..e33761a0abb3a 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -369,7 +369,8 @@ mod tests { .build(); // Add partition columns - config.table_partition_cols = vec![Field::new("date", DataType::Utf8, false)]; + config.table_partition_cols = + vec![Arc::new(Field::new("date", DataType::Utf8, false))]; config.file_groups[0][0].partition_values = vec![ScalarValue::from("2021-10-26")]; // We should be able to project on the partition column @@ -658,7 +659,7 @@ mod tests { ) .await .expect_err("should fail because input file does not match inferred schema"); - assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value d for column 0 at line 4"); + assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'"); Ok(()) } diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 910c4316d9734..0d45711c76fb0 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -19,7 +19,6 @@ //! //! [`FileSource`]: datafusion_datasource::file::FileSource -#[allow(deprecated)] pub use datafusion_datasource_json::source::*; #[cfg(test)] @@ -495,7 +494,7 @@ mod tests { .write_json(out_dir_url, DataFrameWriteOptions::new(), None) .await .expect_err("should fail because input file does not match inferred schema"); - assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value d for column 0 at line 4"); + assert_eq!(e.strip_backtrace(), "Arrow error: Parser error: Error while parsing value 'd' as type 'Int64' for column 0 at line 4. Row data: '[d,4]'"); Ok(()) } diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index e3f237803b34a..3a9dedaa028f2 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -27,35 +27,22 @@ pub mod parquet; #[cfg(feature = "avro")] pub mod avro; -#[allow(deprecated)] #[cfg(feature = "avro")] -pub use avro::{AvroExec, AvroSource}; +pub use avro::AvroSource; #[cfg(feature = "parquet")] pub use datafusion_datasource_parquet::source::ParquetSource; #[cfg(feature = "parquet")] -#[allow(deprecated)] -pub use datafusion_datasource_parquet::{ - ParquetExec, ParquetExecBuilder, ParquetFileMetrics, ParquetFileReaderFactory, -}; +pub use datafusion_datasource_parquet::{ParquetFileMetrics, ParquetFileReaderFactory}; -#[allow(deprecated)] -pub use arrow_file::ArrowExec; pub use arrow_file::ArrowSource; -#[allow(deprecated)] -pub use json::NdJsonExec; - pub use json::{JsonOpener, JsonSource}; -#[allow(deprecated)] -pub use csv::{CsvExec, CsvExecBuilder}; - pub use csv::{CsvOpener, CsvSource}; pub use datafusion_datasource::file::FileSource; pub use datafusion_datasource::file_groups::FileGroup; pub use datafusion_datasource::file_groups::FileGroupPartitioner; -pub use datafusion_datasource::file_meta::FileMeta; pub use datafusion_datasource::file_scan_config::{ wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, FileScanConfigBuilder, diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs index 9e1b2822e8540..d0774e57174ee 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet.rs @@ -38,21 +38,22 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; use arrow::array::{ - ArrayRef, Date64Array, Int32Array, Int64Array, Int8Array, StringArray, - StructArray, + ArrayRef, AsArray, Date64Array, Int32Array, Int64Array, Int8Array, StringArray, + StringViewArray, StructArray, TimestampNanosecondArray, }; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; - use arrow_schema::SchemaRef; + use arrow::util::pretty::pretty_format_batches; + use arrow_schema::{SchemaRef, TimeUnit}; use bytes::{BufMut, BytesMut}; use datafusion_common::config::TableParquetOptions; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; use datafusion_common::{assert_contains, Result, ScalarValue}; use datafusion_datasource::file_format::FileFormat; - use datafusion_datasource::file_meta::FileMeta; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; + use datafusion_datasource::file::FileSource; use datafusion_datasource::{FileRange, PartitionedFile}; use datafusion_datasource_parquet::source::ParquetSource; use datafusion_datasource_parquet::{ @@ -61,8 +62,9 @@ mod tests { use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{col, lit, when, Expr}; use datafusion_physical_expr::planner::logical2physical; + use datafusion_physical_plan::analyze::AnalyzeExec; + use datafusion_physical_plan::collect; use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; - use datafusion_physical_plan::{collect, displayable}; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use chrono::{TimeZone, Utc}; @@ -81,10 +83,10 @@ mod tests { struct RoundTripResult { /// Data that was read back from ParquetFiles batches: Result>, + /// The EXPLAIN ANALYZE output + explain: Result, /// The physical plan that was created (that has statistics, etc) parquet_exec: Arc, - /// The ParquetSource that is used in plan - parquet_source: ParquetSource, } /// round-trip record batches by writing each individual RecordBatch to @@ -93,10 +95,15 @@ mod tests { #[derive(Debug, Default)] struct RoundTrip { projection: Option>, - schema: Option, + /// Optional logical table schema to use when reading the parquet files + /// + /// If None, the logical schema to use will be inferred from the + /// original data via [`Schema::try_merge`] + table_schema: Option, predicate: Option, pushdown_predicate: bool, page_index_predicate: bool, + bloom_filters: bool, } impl RoundTrip { @@ -109,8 +116,11 @@ mod tests { self } - fn with_schema(mut self, schema: SchemaRef) -> Self { - self.schema = Some(schema); + /// Specify table schema. + /// + ///See [`Self::table_schema`] for more details + fn with_table_schema(mut self, schema: SchemaRef) -> Self { + self.table_schema = Some(schema); self } @@ -129,6 +139,11 @@ mod tests { self } + fn with_bloom_filters(mut self) -> Self { + self.bloom_filters = true; + self + } + /// run the test, returning only the resulting RecordBatches async fn round_trip_to_batches( self, @@ -137,71 +152,124 @@ mod tests { self.round_trip(batches).await.batches } - /// run the test, returning the `RoundTripResult` - async fn round_trip(self, batches: Vec) -> RoundTripResult { - let Self { - projection, - schema, - predicate, - pushdown_predicate, - page_index_predicate, - } = self; - - let file_schema = match schema { - Some(schema) => schema, - None => Arc::new( - Schema::try_merge( - batches.iter().map(|b| b.schema().as_ref().clone()), - ) - .unwrap(), - ), - }; - // If testing with page_index_predicate, write parquet - // files with multiple pages - let multi_page = page_index_predicate; - let (meta, _files) = store_parquet(batches, multi_page).await.unwrap(); - let file_group = meta.into_iter().map(Into::into).collect(); - + fn build_file_source(&self, table_schema: SchemaRef) -> Arc { // set up predicate (this is normally done by a layer higher up) - let predicate = predicate.map(|p| logical2physical(&p, &file_schema)); + let predicate = self + .predicate + .as_ref() + .map(|p| logical2physical(p, &table_schema)); let mut source = ParquetSource::default(); if let Some(predicate) = predicate { - source = source.with_predicate(Arc::clone(&file_schema), predicate); + source = source.with_predicate(predicate); } - if pushdown_predicate { + if self.pushdown_predicate { source = source .with_pushdown_filters(true) .with_reorder_filters(true); + } else { + source = source.with_pushdown_filters(false); } - if page_index_predicate { + if self.page_index_predicate { source = source.with_enable_page_index(true); + } else { + source = source.with_enable_page_index(false); } + if self.bloom_filters { + source = source.with_bloom_filter_on_read(true); + } else { + source = source.with_bloom_filter_on_read(false); + } + + source.with_schema(Arc::clone(&table_schema)) + } + + fn build_parquet_exec( + &self, + file_schema: SchemaRef, + file_group: FileGroup, + source: Arc, + ) -> Arc { let base_config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), file_schema, - Arc::new(source.clone()), + source, ) .with_file_group(file_group) - .with_projection(projection) + .with_projection(self.projection.clone()) .build(); + DataSourceExec::from_data_source(base_config) + } + + /// run the test, returning the `RoundTripResult` + /// + /// Each input batch is written into one or more parquet files (and thus + /// they could potentially have different schemas). The resulting + /// parquet files are then read back and filters are applied to the + async fn round_trip(&self, batches: Vec) -> RoundTripResult { + // If table_schema is not set, we need to merge the schema of the + // input batches to get a unified schema. + let table_schema = match &self.table_schema { + Some(schema) => schema, + None => &Arc::new( + Schema::try_merge( + batches.iter().map(|b| b.schema().as_ref().clone()), + ) + .unwrap(), + ), + }; + // If testing with page_index_predicate, write parquet + // files with multiple pages + let multi_page = self.page_index_predicate; + let (meta, _files) = store_parquet(batches, multi_page).await.unwrap(); + let file_group: FileGroup = meta.into_iter().map(Into::into).collect(); + + // build a ParquetExec to return the results + let parquet_source = self.build_file_source(Arc::clone(table_schema)); + let parquet_exec = self.build_parquet_exec( + Arc::clone(table_schema), + file_group.clone(), + Arc::clone(&parquet_source), + ); + + let analyze_exec = Arc::new(AnalyzeExec::new( + false, + false, + // use a new ParquetSource to avoid sharing execution metrics + self.build_parquet_exec( + Arc::clone(table_schema), + file_group.clone(), + self.build_file_source(Arc::clone(table_schema)), + ), + Arc::new(Schema::new(vec![ + Field::new("plan_type", DataType::Utf8, true), + Field::new("plan", DataType::Utf8, true), + ])), + )); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); - let parquet_exec = DataSourceExec::from_data_source(base_config.clone()); + let batches = collect( + Arc::clone(&parquet_exec) as Arc, + task_ctx.clone(), + ) + .await; + + let explain = collect(analyze_exec, task_ctx.clone()) + .await + .map(|batches| { + let batches = pretty_format_batches(&batches).unwrap(); + format!("{batches}") + }); + RoundTripResult { - batches: collect(parquet_exec.clone(), task_ctx).await, + batches, + explain, parquet_exec, - parquet_source: base_config - .file_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone(), } } } @@ -247,7 +315,7 @@ mod tests { // Thus this predicate will come back as false. let filter = col("c2").eq(lit(1_i32)); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -263,10 +331,10 @@ mod tests { let metric = get_value(&metrics, "pushdown_rows_pruned"); assert_eq!(metric, 3, "Expected all rows to be pruned"); - // If we excplicitly allow nulls the rest of the predicate should work + // If we explicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c1").eq(lit(1_i32))); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -305,7 +373,7 @@ mod tests { // Thus this predicate will come back as false. let filter = col("c2").eq(lit("abc")); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -321,10 +389,10 @@ mod tests { let metric = get_value(&metrics, "pushdown_rows_pruned"); assert_eq!(metric, 3, "Expected all rows to be pruned"); - // If we excplicitly allow nulls the rest of the predicate should work + // If we explicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c1").eq(lit(1_i32))); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -367,7 +435,7 @@ mod tests { // Thus this predicate will come back as false. let filter = col("c2").eq(lit("abc")); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -383,10 +451,10 @@ mod tests { let metric = get_value(&metrics, "pushdown_rows_pruned"); assert_eq!(metric, 3, "Expected all rows to be pruned"); - // If we excplicitly allow nulls the rest of the predicate should work + // If we explicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c1").eq(lit(1_i32))); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -429,7 +497,7 @@ mod tests { // Thus this predicate will come back as false. let filter = col("c2").eq(lit("abc")); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -445,10 +513,10 @@ mod tests { let metric = get_value(&metrics, "pushdown_rows_pruned"); assert_eq!(metric, 3, "Expected all rows to be pruned"); - // If we excplicitly allow nulls the rest of the predicate should work + // If we explicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c3").eq(lit(7_i32))); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -496,7 +564,7 @@ mod tests { .and(col("c3").eq(lit(10_i32)).or(col("c2").is_null())); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -526,7 +594,7 @@ mod tests { .or(col("c3").gt(lit(20_i32)).and(col("c2").is_null())); let rt = RoundTrip::new() - .with_schema(table_schema) + .with_table_schema(table_schema) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch]) @@ -776,7 +844,7 @@ mod tests { } #[tokio::test] - async fn evolved_schema_filter() { + async fn evolved_schema_column_order_filter() { let c1: ArrayRef = Arc::new(StringArray::from(vec![Some("Foo"), None, Some("bar")])); @@ -807,6 +875,156 @@ mod tests { assert_eq!(read.len(), 0); } + #[tokio::test] + async fn evolved_schema_column_type_filter_strings() { + // The table and filter have a common data type, but the file schema differs + let c1: ArrayRef = + Arc::new(StringViewArray::from(vec![Some("foo"), Some("bar")])); + let batch = create_batch(vec![("c1", c1.clone())]); + + // Table schema is Utf8 but file schema is StringView + let table_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, false)])); + + // Predicate should prune all row groups + let filter = col("c1").eq(lit(ScalarValue::Utf8(Some("aaa".to_string())))); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_table_schema(table_schema.clone()) + .round_trip(vec![batch.clone()]) + .await; + // There should be no predicate evaluation errors + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + assert_eq!(get_value(&metrics, "pushdown_rows_matched"), 0); + assert_eq!(rt.batches.unwrap().len(), 0); + + // Predicate should prune no row groups + let filter = col("c1").eq(lit(ScalarValue::Utf8(Some("foo".to_string())))); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_table_schema(table_schema) + .round_trip(vec![batch]) + .await; + // There should be no predicate evaluation errors + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + assert_eq!(get_value(&metrics, "pushdown_rows_matched"), 0); + let read = rt + .batches + .unwrap() + .iter() + .map(|b| b.num_rows()) + .sum::(); + assert_eq!(read, 2, "Expected 2 rows to match the predicate"); + } + + #[tokio::test] + async fn evolved_schema_column_type_filter_ints() { + // The table and filter have a common data type, but the file schema differs + let c1: ArrayRef = Arc::new(Int8Array::from(vec![Some(1), Some(2)])); + let batch = create_batch(vec![("c1", c1.clone())]); + + let table_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::UInt64, false)])); + + // Predicate should prune all row groups + let filter = col("c1").eq(lit(ScalarValue::UInt64(Some(5)))); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_table_schema(table_schema.clone()) + .round_trip(vec![batch.clone()]) + .await; + // There should be no predicate evaluation errors + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + assert_eq!(rt.batches.unwrap().len(), 0); + + // Predicate should prune no row groups + let filter = col("c1").eq(lit(ScalarValue::UInt64(Some(1)))); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_table_schema(table_schema) + .round_trip(vec![batch]) + .await; + // There should be no predicate evaluation errors + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + let read = rt + .batches + .unwrap() + .iter() + .map(|b| b.num_rows()) + .sum::(); + assert_eq!(read, 2, "Expected 2 rows to match the predicate"); + } + + #[tokio::test] + async fn evolved_schema_column_type_filter_timestamp_units() { + // The table and filter have a common data type + // The table schema is in milliseconds, but the file schema is in nanoseconds + let c1: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![ + Some(1_000_000_000), // 1970-01-01T00:00:01Z + Some(2_000_000_000), // 1970-01-01T00:00:02Z + Some(3_000_000_000), // 1970-01-01T00:00:03Z + Some(4_000_000_000), // 1970-01-01T00:00:04Z + ])); + let batch = create_batch(vec![("c1", c1.clone())]); + let table_schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + false, + )])); + // One row should match, 2 pruned via page index, 1 pruned via filter pushdown + let filter = col("c1").eq(lit(ScalarValue::TimestampMillisecond( + Some(1_000), + Some("UTC".into()), + ))); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_pushdown_predicate() + .with_page_index_predicate() // produces pages with 2 rows each (2 pages total for our data) + .with_table_schema(table_schema.clone()) + .round_trip(vec![batch.clone()]) + .await; + // There should be no predicate evaluation errors and we keep 1 row + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + let read = rt + .batches + .unwrap() + .iter() + .map(|b| b.num_rows()) + .sum::(); + assert_eq!(read, 1, "Expected 1 rows to match the predicate"); + assert_eq!(get_value(&metrics, "row_groups_pruned_statistics"), 0); + assert_eq!(get_value(&metrics, "page_index_rows_pruned"), 2); + assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 1); + // If we filter with a value that is completely out of the range of the data + // we prune at the row group level. + let filter = col("c1").eq(lit(ScalarValue::TimestampMillisecond( + Some(5_000), + Some("UTC".into()), + ))); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_pushdown_predicate() + .with_table_schema(table_schema) + .round_trip(vec![batch]) + .await; + // There should be no predicate evaluation errors and we keep 0 rows + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + let read = rt + .batches + .unwrap() + .iter() + .map(|b| b.num_rows()) + .sum::(); + assert_eq!(read, 0, "Expected 0 rows to match the predicate"); + assert_eq!(get_value(&metrics, "row_groups_pruned_statistics"), 1); + } + #[tokio::test] async fn evolved_schema_disjoint_schema_filter() { let c1: ArrayRef = @@ -1044,7 +1262,7 @@ mod tests { // batch2: c3(int8), c2(int64), c1(string), c4(string) let batch2 = create_batch(vec![("c3", c4), ("c2", c2), ("c1", c1)]); - let schema = Schema::new(vec![ + let table_schema = Schema::new(vec![ Field::new("c1", DataType::Utf8, true), Field::new("c2", DataType::Int64, true), Field::new("c3", DataType::Int8, true), @@ -1052,7 +1270,7 @@ mod tests { // read/write them files: let read = RoundTrip::new() - .with_schema(Arc::new(schema)) + .with_table_schema(Arc::new(table_schema)) .round_trip_to_batches(vec![batch1, batch2]) .await; assert_contains!(read.unwrap_err().to_string(), @@ -1069,6 +1287,7 @@ mod tests { let parquet_exec = scan_format( &state, &ParquetFormat::default(), + None, &testdata, filename, Some(vec![0, 1, 2]), @@ -1101,6 +1320,210 @@ mod tests { Ok(()) } + #[tokio::test] + async fn parquet_exec_with_int96_from_spark() -> Result<()> { + // arrow-rs relies on the chrono library to convert between timestamps and strings, so + // instead compare as Int64. The underlying type should be a PrimitiveArray of Int64 + // anyway, so this should be a zero-copy non-modifying cast at the SchemaAdapter. + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); + let testdata = datafusion_common::test_util::parquet_test_data(); + let filename = "int96_from_spark.parquet"; + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + + let time_units_and_expected = vec![ + ( + None, // Same as "ns" time_unit + Arc::new(Int64Array::from(vec![ + Some(1704141296123456000), // Reads as nanosecond fine (note 3 extra 0s) + Some(1704070800000000000), // Reads as nanosecond fine (note 3 extra 0s) + Some(-4852191831933722624), // Cannot be represented with nanos timestamp (year 9999) + Some(1735599600000000000), // Reads as nanosecond fine (note 3 extra 0s) + None, + Some(-4864435138808946688), // Cannot be represented with nanos timestamp (year 290000) + ])), + ), + ( + Some("ns".to_string()), + Arc::new(Int64Array::from(vec![ + Some(1704141296123456000), + Some(1704070800000000000), + Some(-4852191831933722624), + Some(1735599600000000000), + None, + Some(-4864435138808946688), + ])), + ), + ( + Some("us".to_string()), + Arc::new(Int64Array::from(vec![ + Some(1704141296123456), + Some(1704070800000000), + Some(253402225200000000), + Some(1735599600000000), + None, + Some(9089380393200000000), + ])), + ), + ]; + + for (time_unit, expected) in time_units_and_expected { + let parquet_exec = scan_format( + &state, + &ParquetFormat::default().with_coerce_int96(time_unit.clone()), + Some(schema.clone()), + &testdata, + filename, + Some(vec![0]), + None, + ) + .await + .unwrap(); + assert_eq!(parquet_exec.output_partitioning().partition_count(), 1); + + let mut results = parquet_exec.execute(0, task_ctx.clone())?; + let batch = results.next().await.unwrap()?; + + assert_eq!(6, batch.num_rows()); + assert_eq!(1, batch.num_columns()); + + assert_eq!(batch.num_columns(), 1); + let column = batch.column(0); + + assert_eq!(column.len(), expected.len()); + + column + .as_primitive::() + .iter() + .zip(expected.iter()) + .for_each(|(lhs, rhs)| { + assert_eq!(lhs, rhs); + }); + } + + Ok(()) + } + + #[tokio::test] + async fn parquet_exec_with_int96_nested() -> Result<()> { + // This test ensures that we maintain compatibility with coercing int96 to the desired + // resolution when they're within a nested type (e.g., struct, map, list). This file + // originates from a modified CometFuzzTestSuite ParquetGenerator to generate combinations + // of primitive and complex columns using int96. Other tests cover reading the data + // correctly with this coercion. Here we're only checking the coerced schema is correct. + let testdata = "../../datafusion/core/tests/data"; + let filename = "int96_nested.parquet"; + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + + let parquet_exec = scan_format( + &state, + &ParquetFormat::default().with_coerce_int96(Some("us".to_string())), + None, + testdata, + filename, + None, + None, + ) + .await + .unwrap(); + assert_eq!(parquet_exec.output_partitioning().partition_count(), 1); + + let mut results = parquet_exec.execute(0, task_ctx.clone())?; + let batch = results.next().await.unwrap()?; + + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("c0", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new_struct( + "c1", + vec![Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )], + true, + ), + Field::new_struct( + "c2", + vec![Field::new_list( + "c0", + Field::new( + "element", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + true, + )], + true, + ), + Field::new_map( + "c3", + "key_value", + Field::new( + "key", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "value", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + false, + true, + ), + Field::new_list( + "c4", + Field::new( + "element", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + true, + ), + Field::new_list( + "c5", + Field::new_struct( + "element", + vec![Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )], + true, + ), + true, + ), + Field::new_list( + "c6", + Field::new_map( + "element", + "key_value", + Field::new( + "key", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "value", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + false, + true, + ), + true, + ), + ])); + + assert_eq!(batch.schema(), expected_schema); + + Ok(()) + } + #[tokio::test] async fn parquet_exec_with_range() -> Result<()> { fn file_range(meta: &ObjectMeta, start: i64, end: i64) -> PartitionedFile { @@ -1375,26 +1798,6 @@ mod tests { create_batch(vec![("c1", c1.clone())]) } - /// Returns a int64 array with contents: - /// "[-1, 1, null, 2, 3, null, null]" - fn int64_batch() -> RecordBatch { - let contents: ArrayRef = Arc::new(Int64Array::from(vec![ - Some(-1), - Some(1), - None, - Some(2), - Some(3), - None, - None, - ])); - - create_batch(vec![ - ("a", contents.clone()), - ("b", contents.clone()), - ("c", contents.clone()), - ]) - } - #[tokio::test] async fn parquet_exec_metrics() { // batch1: c1(string) @@ -1454,110 +1857,17 @@ mod tests { .round_trip(vec![batch1]) .await; - // should have a pruning predicate - let pruning_predicate = rt.parquet_source.pruning_predicate(); - assert!(pruning_predicate.is_some()); - - // convert to explain plan form - let display = displayable(rt.parquet_exec.as_ref()) - .indent(true) - .to_string(); + let explain = rt.explain.unwrap(); - assert_contains!( - &display, - "pruning_predicate=c1_null_count@2 != row_count@3 AND (c1_min@0 != bar OR bar != c1_max@1)" - ); - - assert_contains!(&display, r#"predicate=c1@0 != bar"#); - - assert_contains!(&display, "projection=[c1]"); - } - - #[tokio::test] - async fn parquet_exec_display_deterministic() { - // batches: a(int64), b(int64), c(int64) - let batches = int64_batch(); - - fn extract_required_guarantees(s: &str) -> Option<&str> { - s.split("required_guarantees=").nth(1) - } - - // Ensuring that the required_guarantees remain consistent across every display plan of the filter conditions - for _ in 0..100 { - // c = 1 AND b = 1 AND a = 1 - let filter0 = col("c") - .eq(lit(1)) - .and(col("b").eq(lit(1))) - .and(col("a").eq(lit(1))); - - let rt0 = RoundTrip::new() - .with_predicate(filter0) - .with_pushdown_predicate() - .round_trip(vec![batches.clone()]) - .await; - - let pruning_predicate = rt0.parquet_source.pruning_predicate(); - assert!(pruning_predicate.is_some()); - - let display0 = displayable(rt0.parquet_exec.as_ref()) - .indent(true) - .to_string(); - - let guarantees0: &str = extract_required_guarantees(&display0) - .expect("Failed to extract required_guarantees"); - // Compare only the required_guarantees part (Because the file_groups part will not be the same) - assert_eq!( - guarantees0.trim(), - "[a in (1), b in (1), c in (1)]", - "required_guarantees don't match" - ); - } - - // c = 1 AND a = 1 AND b = 1 - let filter1 = col("c") - .eq(lit(1)) - .and(col("a").eq(lit(1))) - .and(col("b").eq(lit(1))); + // check that there was a pruning predicate -> row groups got pruned + assert_contains!(&explain, "predicate=c1@0 != bar"); - let rt1 = RoundTrip::new() - .with_predicate(filter1) - .with_pushdown_predicate() - .round_trip(vec![batches.clone()]) - .await; + // there's a single row group, but we can check that it matched + // if no pruning was done this would be 0 instead of 1 + assert_contains!(&explain, "row_groups_matched_statistics=1"); - // b = 1 AND a = 1 AND c = 1 - let filter2 = col("b") - .eq(lit(1)) - .and(col("a").eq(lit(1))) - .and(col("c").eq(lit(1))); - - let rt2 = RoundTrip::new() - .with_predicate(filter2) - .with_pushdown_predicate() - .round_trip(vec![batches]) - .await; - - // should have a pruning predicate - let pruning_predicate = rt1.parquet_source.pruning_predicate(); - assert!(pruning_predicate.is_some()); - let pruning_predicate = rt2.parquet_source.predicate(); - assert!(pruning_predicate.is_some()); - - // convert to explain plan form - let display1 = displayable(rt1.parquet_exec.as_ref()) - .indent(true) - .to_string(); - let display2 = displayable(rt2.parquet_exec.as_ref()) - .indent(true) - .to_string(); - - let guarantees1 = extract_required_guarantees(&display1) - .expect("Failed to extract required_guarantees"); - let guarantees2 = extract_required_guarantees(&display2) - .expect("Failed to extract required_guarantees"); - - // Compare only the required_guarantees part (Because the predicate part will not be the same) - assert_eq!(guarantees1, guarantees2, "required_guarantees don't match"); + // check the projection + assert_contains!(&explain, "projection=[c1]"); } #[tokio::test] @@ -1581,16 +1891,19 @@ mod tests { .await; // Should not contain a pruning predicate (since nothing can be pruned) - let pruning_predicate = rt.parquet_source.pruning_predicate(); - assert!( - pruning_predicate.is_none(), - "Still had pruning predicate: {pruning_predicate:?}" - ); + let explain = rt.explain.unwrap(); + + // When both matched and pruned are 0, it means that the pruning predicate + // was not used at all. + assert_contains!(&explain, "row_groups_matched_statistics=0"); + assert_contains!(&explain, "row_groups_pruned_statistics=0"); - // but does still has a pushdown down predicate - let predicate = rt.parquet_source.predicate(); - let filter_phys = logical2physical(&filter, rt.parquet_exec.schema().as_ref()); - assert_eq!(predicate.unwrap().to_string(), filter_phys.to_string()); + // But pushdown predicate should be present + assert_contains!( + &explain, + "predicate=CASE WHEN c1@0 != bar THEN true ELSE false END" + ); + assert_contains!(&explain, "pushdown_rows_pruned=5"); } #[tokio::test] @@ -1612,12 +1925,19 @@ mod tests { let rt = RoundTrip::new() .with_predicate(filter.clone()) .with_pushdown_predicate() + .with_bloom_filters() .round_trip(vec![batch1]) .await; // Should have a pruning predicate - let pruning_predicate = rt.parquet_source.pruning_predicate(); - assert!(pruning_predicate.is_some()); + let explain = rt.explain.unwrap(); + assert_contains!( + &explain, + "predicate=c1@0 = foo AND CASE WHEN c1@0 != bar THEN true ELSE false END" + ); + + // And bloom filters should have been evaluated + assert_contains!(&explain, "row_groups_pruned_bloom_filter=1"); } /// Returns the sum of all the metrics with the specified name @@ -1697,14 +2017,14 @@ mod tests { let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; fs::create_dir(&out_dir).unwrap(); let df = ctx.sql("SELECT c1, c2 FROM test").await?; - let schema: Schema = df.schema().into(); + let schema = Arc::clone(df.schema().inner()); // Register a listing table - this will use all files in the directory as data sources // for the query ctx.register_listing_table( "my_table", &out_dir, listing_options, - Some(Arc::new(schema)), + Some(schema), None, ) .await @@ -1850,13 +2170,13 @@ mod tests { path: &str, store: Arc, batch: RecordBatch, - ) -> usize { + ) -> u64 { let mut writer = ArrowWriter::try_new(BytesMut::new().writer(), batch.schema(), None).unwrap(); writer.write(&batch).unwrap(); writer.flush().unwrap(); let bytes = writer.into_inner().unwrap().into_inner().freeze(); - let total_size = bytes.len(); + let total_size = bytes.len() as u64; let path = Path::from(path); let payload = object_store::PutPayload::from_bytes(bytes); store @@ -1886,7 +2206,7 @@ mod tests { fn create_reader( &self, partition_index: usize, - file_meta: FileMeta, + partitioned_file: PartitionedFile, metadata_size_hint: Option, metrics: &ExecutionPlanMetricsSet, ) -> Result> @@ -1897,7 +2217,7 @@ mod tests { .push(metadata_size_hint); self.inner.create_reader( partition_index, - file_meta, + partitioned_file, metadata_size_hint, metrics, ) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index fc110a0699df2..a8148b80495e6 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -34,8 +34,12 @@ use crate::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }, datasource::{provider_as_source, MemTable, ViewTable}, - error::{DataFusionError, Result}, - execution::{options::ArrowReadOptions, runtime_env::RuntimeEnv, FunctionRegistry}, + error::Result, + execution::{ + options::ArrowReadOptions, + runtime_env::{RuntimeEnv, RuntimeEnvBuilder}, + FunctionRegistry, + }, logical_expr::AggregateUDF, logical_expr::ScalarUDF, logical_expr::{ @@ -62,9 +66,10 @@ use datafusion_catalog::{ use datafusion_common::config::ConfigOptions; use datafusion_common::{ config::{ConfigExtension, TableOptions}, - exec_datafusion_err, exec_err, not_impl_err, plan_datafusion_err, plan_err, + exec_datafusion_err, exec_err, internal_datafusion_err, not_impl_err, + plan_datafusion_err, plan_err, tree_node::{TreeNodeRecursion, TreeNodeVisitor}, - DFSchema, ParamValues, ScalarValue, SchemaReference, TableReference, + DFSchema, DataFusionError, ParamValues, ScalarValue, SchemaReference, TableReference, }; pub use datafusion_execution::config::SessionConfig; use datafusion_execution::registry::SerializerRegistry; @@ -222,7 +227,7 @@ where /// # use datafusion::execution::SessionStateBuilder; /// # use datafusion_execution::runtime_env::RuntimeEnvBuilder; /// // Configure a 4k batch size -/// let config = SessionConfig::new() .with_batch_size(4 * 1024); +/// let config = SessionConfig::new().with_batch_size(4 * 1024); /// /// // configure a memory limit of 1GB with 20% slop /// let runtime_env = RuntimeEnvBuilder::new() @@ -293,13 +298,13 @@ impl SessionContext { pub async fn refresh_catalogs(&self) -> Result<()> { let cat_names = self.catalog_names().clone(); for cat_name in cat_names.iter() { - let cat = self.catalog(cat_name.as_str()).ok_or_else(|| { - DataFusionError::Internal("Catalog not found!".to_string()) - })?; + let cat = self + .catalog(cat_name.as_str()) + .ok_or_else(|| internal_datafusion_err!("Catalog not found!"))?; for schema_name in cat.schema_names() { - let schema = cat.schema(schema_name.as_str()).ok_or_else(|| { - DataFusionError::Internal("Schema not found!".to_string()) - })?; + let schema = cat + .schema(schema_name.as_str()) + .ok_or_else(|| internal_datafusion_err!("Schema not found!"))?; let lister = schema.as_any().downcast_ref::(); if let Some(lister) = lister { lister.refresh(&self.state()).await?; @@ -581,6 +586,7 @@ impl SessionContext { /// # Ok(()) /// # } /// ``` + #[cfg(feature = "sql")] pub async fn sql(&self, sql: &str) -> Result { self.sql_with_options(sql, SQLOptions::new()).await } @@ -611,6 +617,7 @@ impl SessionContext { /// # Ok(()) /// # } /// ``` + #[cfg(feature = "sql")] pub async fn sql_with_options( &self, sql: &str, @@ -644,6 +651,7 @@ impl SessionContext { /// # Ok(()) /// # } /// ``` + #[cfg(feature = "sql")] pub fn parse_sql_expr(&self, sql: &str, df_schema: &DFSchema) -> Result { self.state.read().create_logical_expr(sql, df_schema) } @@ -785,19 +793,44 @@ impl SessionContext { return not_impl_err!("Temporary tables not supported"); } - if exist { - match cmd.if_not_exists { - true => return self.return_empty_dataframe(), - false => { - return exec_err!("Table '{}' already exists", cmd.name); + match (cmd.if_not_exists, cmd.or_replace, exist) { + (true, false, true) => self.return_empty_dataframe(), + (false, true, true) => { + let result = self + .find_and_deregister(cmd.name.clone(), TableType::Base) + .await; + + match result { + Ok(true) => { + let table_provider: Arc = + self.create_custom_table(cmd).await?; + self.register_table(cmd.name.clone(), table_provider)?; + self.return_empty_dataframe() + } + Ok(false) => { + let table_provider: Arc = + self.create_custom_table(cmd).await?; + self.register_table(cmd.name.clone(), table_provider)?; + self.return_empty_dataframe() + } + Err(e) => { + exec_err!("Errored while deregistering external table: {}", e) + } } } + (true, true, true) => { + exec_err!("'IF NOT EXISTS' cannot coexist with 'REPLACE'") + } + (_, _, false) => { + let table_provider: Arc = + self.create_custom_table(cmd).await?; + self.register_table(cmd.name.clone(), table_provider)?; + self.return_empty_dataframe() + } + (false, false, true) => { + exec_err!("External table '{}' already exists", cmd.name) + } } - - let table_provider: Arc = - self.create_custom_table(cmd).await?; - self.register_table(cmd.name.clone(), table_provider)?; - self.return_empty_dataframe() } async fn create_memory_table(&self, cmd: CreateMemoryTable) -> Result { @@ -823,7 +856,7 @@ impl SessionContext { (true, false, Ok(_)) => self.return_empty_dataframe(), (false, true, Ok(_)) => { self.deregister_table(name.clone())?; - let schema = Arc::new(input.schema().as_ref().into()); + let schema = Arc::clone(input.schema().inner()); let physical = DataFrame::new(self.state(), input); let batches: Vec<_> = physical.collect_partitioned().await?; @@ -841,8 +874,7 @@ impl SessionContext { exec_err!("'IF NOT EXISTS' cannot coexist with 'REPLACE'") } (_, _, Err(_)) => { - let df_schema = input.schema(); - let schema = Arc::new(df_schema.as_ref().into()); + let schema = Arc::clone(input.schema().inner()); let physical = DataFrame::new(self.state(), input); let batches: Vec<_> = physical.collect_partitioned().await?; @@ -910,7 +942,7 @@ impl SessionContext { .. } = cmd; - // sqlparser doesnt accept database / catalog as parameter to CREATE SCHEMA + // sqlparser doesn't accept database / catalog as parameter to CREATE SCHEMA // so for now, we default to default catalog let tokens: Vec<&str> = schema_name.split('.').collect(); let (catalog, schema_name) = match tokens.len() { @@ -918,17 +950,15 @@ impl SessionContext { let state = self.state.read(); let name = &state.config().options().catalog.default_catalog; let catalog = state.catalog_list().catalog(name).ok_or_else(|| { - DataFusionError::Execution(format!( - "Missing default catalog '{name}'" - )) + exec_datafusion_err!("Missing default catalog '{name}'") })?; (catalog, tokens[0]) } 2 => { let name = &tokens[0]; - let catalog = self.catalog(name).ok_or_else(|| { - DataFusionError::Execution(format!("Missing catalog '{name}'")) - })?; + let catalog = self + .catalog(name) + .ok_or_else(|| exec_datafusion_err!("Missing catalog '{name}'"))?; (catalog, tokens[1]) } _ => return exec_err!("Unable to parse catalog from {schema_name}"), @@ -1036,13 +1066,72 @@ impl SessionContext { variable, value, .. } = stmt; - let mut state = self.state.write(); - state.config_mut().options_mut().set(&variable, &value)?; - drop(state); + // Check if this is a runtime configuration + if variable.starts_with("datafusion.runtime.") { + self.set_runtime_variable(&variable, &value)?; + } else { + let mut state = self.state.write(); + state.config_mut().options_mut().set(&variable, &value)?; + drop(state); + } self.return_empty_dataframe() } + fn set_runtime_variable(&self, variable: &str, value: &str) -> Result<()> { + let key = variable.strip_prefix("datafusion.runtime.").unwrap(); + + let mut state = self.state.write(); + + let mut builder = RuntimeEnvBuilder::from_runtime_env(state.runtime_env()); + builder = match key { + "memory_limit" => { + let memory_limit = Self::parse_memory_limit(value)?; + builder.with_memory_limit(memory_limit, 1.0) + } + "max_temp_directory_size" => { + let directory_size = Self::parse_memory_limit(value)?; + builder.with_max_temp_directory_size(directory_size as u64) + } + "temp_directory" => builder.with_temp_file_path(value), + "metadata_cache_limit" => { + let limit = Self::parse_memory_limit(value)?; + builder.with_metadata_cache_limit(limit) + } + _ => return plan_err!("Unknown runtime configuration: {variable}"), + }; + + *state = SessionStateBuilder::from(state.clone()) + .with_runtime_env(Arc::new(builder.build()?)) + .build(); + + Ok(()) + } + + /// Parse memory limit from string to number of bytes + /// Supports formats like '1.5G', '100M', '512K' + /// + /// # Examples + /// ``` + /// use datafusion::execution::context::SessionContext; + /// + /// assert_eq!(SessionContext::parse_memory_limit("1M").unwrap(), 1024 * 1024); + /// assert_eq!(SessionContext::parse_memory_limit("1.5G").unwrap(), (1.5 * 1024.0 * 1024.0 * 1024.0) as usize); + /// ``` + pub fn parse_memory_limit(limit: &str) -> Result { + let (number, unit) = limit.split_at(limit.len() - 1); + let number: f64 = number.parse().map_err(|_| { + plan_datafusion_err!("Failed to parse number from memory limit '{limit}'") + })?; + + match unit { + "K" => Ok((number * 1024.0) as usize), + "M" => Ok((number * 1024.0 * 1024.0) as usize), + "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize), + _ => plan_err!("Unsupported unit '{unit}' in memory limit '{limit}'"), + } + } + async fn create_custom_table( &self, cmd: &CreateExternalTable, @@ -1054,10 +1143,7 @@ impl SessionContext { .table_factories() .get(file_type.as_str()) .ok_or_else(|| { - DataFusionError::Execution(format!( - "Unable to find factory for {}", - cmd.file_type - )) + exec_datafusion_err!("Unable to find factory for {}", cmd.file_type) })?; let table = (*factory).create(&state, cmd).await?; Ok(table) @@ -1098,9 +1184,11 @@ impl SessionContext { match function_factory { Some(f) => f.create(&state, stmt).await?, - _ => Err(DataFusionError::Configuration( - "Function factory has not been configured".into(), - ))?, + _ => { + return Err(DataFusionError::Configuration( + "Function factory has not been configured".to_string(), + )) + } } }; @@ -1153,7 +1241,7 @@ impl SessionContext { let mut params: Vec = parameters .into_iter() .map(|e| match e { - Expr::Literal(scalar) => Ok(scalar), + Expr::Literal(scalar, _) => Ok(scalar), _ => not_impl_err!("Unsupported parameter type: {}", e), }) .collect::>()?; @@ -1579,7 +1667,7 @@ impl SessionContext { /// [`ConfigOptions`]: crate::config::ConfigOptions pub fn state(&self) -> SessionState { let mut state = self.state.read().clone(); - state.execution_props_mut().start_execution(); + state.mark_start_execution(); state } @@ -1647,7 +1735,7 @@ impl FunctionRegistry for SessionContext { } fn expr_planners(&self) -> Vec> { - self.state.read().expr_planners() + self.state.read().expr_planners().to_vec() } fn register_expr_planner( @@ -1656,6 +1744,14 @@ impl FunctionRegistry for SessionContext { ) -> Result<()> { self.state.write().register_expr_planner(expr_planner) } + + fn udafs(&self) -> HashSet { + self.state.read().udafs() + } + + fn udwfs(&self) -> HashSet { + self.state.read().udwfs() + } } /// Create a new task context instance from SessionContext @@ -1680,7 +1776,7 @@ impl From for SessionStateBuilder { /// A planner used to add extensions to DataFusion logical and physical plans. #[async_trait] pub trait QueryPlanner: Debug { - /// Given a `LogicalPlan`, create an [`ExecutionPlan`] suitable for execution + /// Given a [`LogicalPlan`], create an [`ExecutionPlan`] suitable for execution async fn create_physical_plan( &self, logical_plan: &LogicalPlan, @@ -1688,12 +1784,46 @@ pub trait QueryPlanner: Debug { ) -> Result>; } -/// A pluggable interface to handle `CREATE FUNCTION` statements -/// and interact with [SessionState] to registers new udf, udaf or udwf. +/// Interface for handling `CREATE FUNCTION` statements and interacting with +/// [SessionState] to create and register functions ([`ScalarUDF`], +/// [`AggregateUDF`], [`WindowUDF`], and [`TableFunctionImpl`]) dynamically. +/// +/// Implement this trait to create user-defined functions in a custom way, such +/// as loading from external libraries or defining them programmatically. +/// DataFusion will parse `CREATE FUNCTION` statements into [`CreateFunction`] +/// structs and pass them to the [`create`](Self::create) method. +/// +/// Note there is no default implementation of this trait provided in DataFusion, +/// because the implementation and requirements vary widely. Please see +/// [function_factory example] for a reference implementation. +/// +/// [function_factory example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/function_factory.rs +/// +/// # Examples of syntax that can be supported +/// +/// ```sql +/// CREATE FUNCTION f1(BIGINT) +/// RETURNS BIGINT +/// RETURN $1 + 1; +/// ``` +/// or +/// ```sql +/// CREATE FUNCTION to_miles(DOUBLE) +/// RETURNS DOUBLE +/// LANGUAGE PYTHON +/// AS ' +/// import pyarrow.compute as pc +/// +/// conversation_rate_multiplier = 0.62137119 +/// +/// def to_miles(km_data): +/// return pc.multiply(km_data, conversation_rate_multiplier) +/// ' +/// ``` #[async_trait] pub trait FunctionFactory: Debug + Sync + Send { - /// Handles creation of user defined function specified in [CreateFunction] statement + /// Creates a new dynamic function from the SQL in the [CreateFunction] statement async fn create( &self, state: &SessionState, @@ -1701,7 +1831,8 @@ pub trait FunctionFactory: Debug + Sync + Send { ) -> Result; } -/// Type of function to create +/// The result of processing a [`CreateFunction`] statement with [`FunctionFactory`]. +#[derive(Debug, Clone)] pub enum RegisterFunction { /// Scalar user defined function Scalar(Arc), @@ -1833,7 +1964,7 @@ mod tests { use crate::test; use crate::test_util::{plan_and_collect, populate_csv_partitions}; use arrow::datatypes::{DataType, TimeUnit}; - use std::env; + use datafusion_common::DataFusionError; use std::error::Error; use std::path::PathBuf; diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs index 6ec9796fe90d5..731f7e59ecfaf 100644 --- a/datafusion/core/src/execution/context/parquet.rs +++ b/datafusion/core/src/execution/context/parquet.rs @@ -31,6 +31,21 @@ impl SessionContext { /// [`read_table`](Self::read_table) with a [`super::ListingTable`]. /// /// For an example, see [`read_csv`](Self::read_csv) + /// + /// # Note: Statistics + /// + /// NOTE: by default, statistics are collected when reading the Parquet + /// files This can slow down the initial DataFrame creation while + /// greatly accelerating queries with certain filters. + /// + /// To disable statistics collection, set the [config option] + /// `datafusion.execution.collect_statistics` to `false`. See + /// [`ConfigOptions`] and [`ExecutionOptions::collect_statistics`] for more + /// details. + /// + /// [config option]: https://datafusion.apache.org/user-guide/configs.html + /// [`ConfigOptions`]: crate::config::ConfigOptions + /// [`ExecutionOptions::collect_statistics`]: crate::config::ExecutionOptions::collect_statistics pub async fn read_parquet( &self, table_paths: P, @@ -41,6 +56,13 @@ impl SessionContext { /// Registers a Parquet file as a table that can be referenced from SQL /// statements executed against this context. + /// + /// # Note: Statistics + /// + /// Statistics are not collected by default. See [`read_parquet`] for more + /// details and how to enable them. + /// + /// [`read_parquet`]: Self::read_parquet pub async fn register_parquet( &self, table_ref: impl Into, @@ -84,10 +106,14 @@ mod tests { use crate::parquet::basic::Compression; use crate::test_util::parquet_test_data; + use arrow::util::pretty::pretty_format_batches; use datafusion_common::config::TableParquetOptions; + use datafusion_common::{ + assert_batches_eq, assert_batches_sorted_eq, assert_contains, + }; use datafusion_execution::config::SessionConfig; - use tempfile::tempdir; + use tempfile::{tempdir, TempDir}; #[tokio::test] async fn read_with_glob_path() -> Result<()> { @@ -129,6 +155,49 @@ mod tests { Ok(()) } + async fn explain_query_all_with_config(config: SessionConfig) -> Result { + let ctx = SessionContext::new_with_config(config); + + ctx.register_parquet( + "test", + &format!("{}/alltypes_plain*.parquet", parquet_test_data()), + ParquetReadOptions::default(), + ) + .await?; + let df = ctx.sql("EXPLAIN SELECT * FROM test").await?; + let results = df.collect().await?; + let content = pretty_format_batches(&results).unwrap().to_string(); + Ok(content) + } + + #[tokio::test] + async fn register_parquet_respects_collect_statistics_config() -> Result<()> { + // The default is true + let mut config = SessionConfig::new(); + config.options_mut().explain.physical_plan_only = true; + config.options_mut().explain.show_statistics = true; + let content = explain_query_all_with_config(config).await?; + assert_contains!(content, "statistics=[Rows=Exact("); + + // Explicitly set to true + let mut config = SessionConfig::new(); + config.options_mut().explain.physical_plan_only = true; + config.options_mut().explain.show_statistics = true; + config.options_mut().execution.collect_statistics = true; + let content = explain_query_all_with_config(config).await?; + assert_contains!(content, "statistics=[Rows=Exact("); + + // Explicitly set to false + let mut config = SessionConfig::new(); + config.options_mut().explain.physical_plan_only = true; + config.options_mut().explain.show_statistics = true; + config.options_mut().execution.collect_statistics = false; + let content = explain_query_all_with_config(config).await?; + assert_contains!(content, "statistics=[Rows=Absent,"); + + Ok(()) + } + #[tokio::test] async fn read_from_registered_table_with_glob_path() -> Result<()> { let ctx = SessionContext::new(); @@ -286,7 +355,7 @@ mod tests { let expected_path = binding[0].as_str(); assert_eq!( read_df.unwrap_err().strip_backtrace(), - format!("Execution error: File path '{}' does not match the expected extension '.parquet'", expected_path) + format!("Execution error: File path '{expected_path}' does not match the expected extension '.parquet'") ); // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension. @@ -333,4 +402,124 @@ mod tests { assert_eq!(total_rows, 5); Ok(()) } + + #[tokio::test] + async fn read_from_parquet_folder() -> Result<()> { + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + let test_path = tmp_dir.path().to_str().unwrap().to_string(); + + ctx.sql("SELECT 1 a") + .await? + .write_parquet(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + ctx.sql("SELECT 2 a") + .await? + .write_parquet(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + // Adding CSV to check it is not read with Parquet reader + ctx.sql("SELECT 3 a") + .await? + .write_csv(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + let actual = ctx + .read_parquet(&test_path, ParquetReadOptions::default()) + .await? + .collect() + .await?; + + #[cfg_attr(any(), rustfmt::skip)] + assert_batches_sorted_eq!(&[ + "+---+", + "| a |", + "+---+", + "| 2 |", + "| 1 |", + "+---+", + ], &actual); + + let actual = ctx + .read_parquet(test_path, ParquetReadOptions::default()) + .await? + .collect() + .await?; + + #[cfg_attr(any(), rustfmt::skip)] + assert_batches_sorted_eq!(&[ + "+---+", + "| a |", + "+---+", + "| 2 |", + "| 1 |", + "+---+", + ], &actual); + + Ok(()) + } + + #[tokio::test] + async fn read_from_parquet_folder_table() -> Result<()> { + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + let test_path = tmp_dir.path().to_str().unwrap().to_string(); + + ctx.sql("SELECT 1 a") + .await? + .write_parquet(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + ctx.sql("SELECT 2 a") + .await? + .write_parquet(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + // Adding CSV to check it is not read with Parquet reader + ctx.sql("SELECT 3 a") + .await? + .write_csv(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + ctx.sql(format!("CREATE EXTERNAL TABLE parquet_folder_t1 STORED AS PARQUET LOCATION '{test_path}'").as_ref()) + .await?; + + let actual = ctx + .sql("select * from parquet_folder_t1") + .await? + .collect() + .await?; + #[cfg_attr(any(), rustfmt::skip)] + assert_batches_sorted_eq!(&[ + "+---+", + "| a |", + "+---+", + "| 2 |", + "| 1 |", + "+---+", + ], &actual); + + Ok(()) + } + + #[tokio::test] + async fn read_dummy_folder() -> Result<()> { + let ctx = SessionContext::new(); + let test_path = "/foo/"; + + let actual = ctx + .read_parquet(test_path, ParquetReadOptions::default()) + .await? + .collect() + .await?; + + #[cfg_attr(any(), rustfmt::skip)] + assert_batches_eq!(&[ + "++", + "++", + ], &actual); + + Ok(()) + } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 28f599304f8c8..b04004dd495c8 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -24,8 +24,8 @@ use std::fmt::Debug; use std::sync::Arc; use crate::catalog::{CatalogProviderList, SchemaProvider, TableProviderFactory}; -use crate::datasource::cte_worktable::CteWorkTable; -use crate::datasource::file_format::{format_as_file_type, FileFormatFactory}; +use crate::datasource::file_format::FileFormatFactory; +#[cfg(feature = "sql")] use crate::datasource::provider_as_source; use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner}; use crate::execution::SessionStateDefaults; @@ -34,16 +34,15 @@ use datafusion_catalog::information_schema::{ InformationSchemaProvider, INFORMATION_SCHEMA, }; -use arrow::datatypes::{DataType, SchemaRef}; +use arrow::datatypes::DataType; use datafusion_catalog::MemoryCatalogProviderList; use datafusion_catalog::{TableFunction, TableFunctionImpl}; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; -use datafusion_common::file_options::file_type::FileType; use datafusion_common::tree_node::TreeNode; use datafusion_common::{ - config_err, exec_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError, + config_err, exec_err, plan_datafusion_err, DFSchema, DataFusionError, ResolvedTableReference, TableReference, }; use datafusion_execution::config::SessionConfig; @@ -51,13 +50,15 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::planner::{ExprPlanner, TypePlanner}; +use datafusion_expr::planner::ExprPlanner; +#[cfg(feature = "sql")] +use datafusion_expr::planner::TypePlanner; use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; use datafusion_expr::simplify::SimplifyInfo; -use datafusion_expr::var_provider::{is_system_variables, VarType}; +#[cfg(feature = "sql")] +use datafusion_expr::TableSource; use datafusion_expr::{ - AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, TableSource, - WindowUDF, + AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, WindowUDF, }; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{ @@ -69,16 +70,22 @@ use datafusion_physical_optimizer::optimizer::PhysicalOptimizer; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::ExecutionPlan; use datafusion_session::Session; -use datafusion_sql::parser::{DFParserBuilder, Statement}; -use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}; +#[cfg(feature = "sql")] +use datafusion_sql::{ + parser::{DFParserBuilder, Statement}, + planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}, +}; use async_trait::async_trait; use chrono::{DateTime, Utc}; use itertools::Itertools; use log::{debug, info}; use object_store::ObjectStore; -use sqlparser::ast::{Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias}; -use sqlparser::dialect::dialect_from_str; +#[cfg(feature = "sql")] +use sqlparser::{ + ast::{Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias}, + dialect::dialect_from_str, +}; use url::Url; use uuid::Uuid; @@ -132,6 +139,7 @@ pub struct SessionState { /// Provides support for customizing the SQL planner, e.g. to add support for custom operators like `->>` or `?` expr_planners: Vec>, /// Provides support for customizing the SQL type planning + #[cfg(feature = "sql")] type_planner: Option>, /// Responsible for optimizing a logical plan optimizer: Optimizer, @@ -185,7 +193,8 @@ impl Debug for SessionState { /// Prefer having short fields at the top and long vector fields near the end /// Group fields by fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SessionState") + let mut debug_struct = f.debug_struct("SessionState"); + let ret = debug_struct .field("session_id", &self.session_id) .field("config", &self.config) .field("runtime_env", &self.runtime_env) @@ -196,9 +205,12 @@ impl Debug for SessionState { .field("table_options", &self.table_options) .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) - .field("expr_planners", &self.expr_planners) - .field("type_planner", &self.type_planner) - .field("query_planners", &self.query_planner) + .field("expr_planners", &self.expr_planners); + + #[cfg(feature = "sql")] + let ret = ret.field("type_planner", &self.type_planner); + + ret.field("query_planners", &self.query_planner) .field("analyzer", &self.analyzer) .field("optimizer", &self.optimizer) .field("physical_optimizers", &self.physical_optimizers) @@ -274,17 +286,6 @@ impl Session for SessionState { } impl SessionState { - /// Returns new [`SessionState`] using the provided - /// [`SessionConfig`] and [`RuntimeEnv`]. - #[deprecated(since = "41.0.0", note = "Use SessionStateBuilder")] - pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - SessionStateBuilder::new() - .with_config(config) - .with_runtime_env(runtime) - .with_default_features() - .build() - } - pub(crate) fn resolve_table_ref( &self, table_ref: impl Into, @@ -369,6 +370,7 @@ impl SessionState { /// [`Statement`]. See [`SessionContext::sql`] for running queries. /// /// [`SessionContext::sql`]: crate::execution::context::SessionContext::sql + #[cfg(feature = "sql")] pub fn sql_to_statement( &self, sql: &str, @@ -391,7 +393,7 @@ impl SessionState { .parse_statements()?; if statements.len() > 1 { - return not_impl_err!( + return datafusion_common::not_impl_err!( "The context currently only supports a single SQL statement" ); } @@ -405,6 +407,7 @@ impl SessionState { /// parse a sql string into a sqlparser-rs AST [`SQLExpr`]. /// /// See [`Self::create_logical_expr`] for parsing sql to [`Expr`]. + #[cfg(feature = "sql")] pub fn sql_to_expr( &self, sql: &str, @@ -416,6 +419,7 @@ impl SessionState { /// parse a sql string into a sqlparser-rs AST [`SQLExprWithAlias`]. /// /// See [`Self::create_logical_expr`] for parsing sql to [`Expr`]. + #[cfg(feature = "sql")] pub fn sql_to_expr_with_alias( &self, sql: &str, @@ -434,7 +438,7 @@ impl SessionState { .with_dialect(dialect.as_ref()) .with_recursion_limit(recursion_limit) .build()? - .parse_expr()?; + .parse_into_expr()?; Ok(expr) } @@ -444,6 +448,7 @@ impl SessionState { /// See [`datafusion_sql::resolve::resolve_table_references`] for more information. /// /// [`datafusion_sql::resolve::resolve_table_references`]: datafusion_sql::resolve::resolve_table_references + #[cfg(feature = "sql")] pub fn resolve_table_references( &self, statement: &Statement, @@ -458,6 +463,7 @@ impl SessionState { } /// Convert an AST Statement into a LogicalPlan + #[cfg(feature = "sql")] pub async fn statement_to_plan( &self, statement: Statement, @@ -485,6 +491,7 @@ impl SessionState { query.statement_to_plan(statement) } + #[cfg(feature = "sql")] fn get_parser_options(&self) -> ParserOptions { let sql_parser_options = &self.config.options().sql_parser; @@ -494,8 +501,12 @@ impl SessionState { enable_options_value_normalization: sql_parser_options .enable_options_value_normalization, support_varchar_with_length: sql_parser_options.support_varchar_with_length, - map_varchar_to_utf8view: sql_parser_options.map_varchar_to_utf8view, + map_string_types_to_utf8view: sql_parser_options.map_string_types_to_utf8view, collect_spans: sql_parser_options.collect_spans, + default_null_ordering: sql_parser_options + .default_null_ordering + .as_str() + .into(), } } @@ -511,6 +522,7 @@ impl SessionState { /// /// [`SessionContext::sql`]: crate::execution::context::SessionContext::sql /// [`SessionContext::sql_with_options`]: crate::execution::context::SessionContext::sql_with_options + #[cfg(feature = "sql")] pub async fn create_logical_plan( &self, sql: &str, @@ -524,6 +536,7 @@ impl SessionState { /// Creates a datafusion style AST [`Expr`] from a SQL string. /// /// See example on [SessionContext::parse_sql_expr](crate::execution::context::SessionContext::parse_sql_expr) + #[cfg(feature = "sql")] pub fn create_logical_expr( &self, sql: &str, @@ -552,6 +565,11 @@ impl SessionState { &self.optimizer } + /// Returns the [`ExprPlanner`]s for this session + pub fn expr_planners(&self) -> &[Arc] { + &self.expr_planners + } + /// Returns the [`QueryPlanner`] for this session pub fn query_planner(&self) -> &Arc { &self.query_planner @@ -565,7 +583,7 @@ impl SessionState { // analyze & capture output of each rule let analyzer_result = self.analyzer.execute_and_check( e.plan.as_ref().clone(), - self.options(), + &self.options(), |analyzed_plan, analyzer| { let analyzer_name = analyzer.name().to_string(); let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; @@ -627,7 +645,7 @@ impl SessionState { } else { let analyzed_plan = self.analyzer.execute_and_check( plan.clone(), - self.options(), + &self.options(), |_, _| {}, )?; self.optimizer.optimize(analyzed_plan, self, |_, _| {}) @@ -729,10 +747,16 @@ impl SessionState { } /// return the configuration options - pub fn config_options(&self) -> &ConfigOptions { + pub fn config_options(&self) -> &Arc { self.config.options() } + /// Mark the start of the execution + pub fn mark_start_execution(&mut self) { + let config = Arc::clone(self.config.options()); + self.execution_props.mark_start_execution(config); + } + /// Return the table options pub fn table_options(&self) -> &TableOptions { &self.table_options @@ -888,6 +912,7 @@ pub struct SessionStateBuilder { session_id: Option, analyzer: Option, expr_planners: Option>>, + #[cfg(feature = "sql")] type_planner: Option>, optimizer: Option, physical_optimizers: Option, @@ -924,6 +949,7 @@ impl SessionStateBuilder { session_id: None, analyzer: None, expr_planners: None, + #[cfg(feature = "sql")] type_planner: None, optimizer: None, physical_optimizers: None, @@ -973,6 +999,7 @@ impl SessionStateBuilder { session_id: None, analyzer: Some(existing.analyzer), expr_planners: Some(existing.expr_planners), + #[cfg(feature = "sql")] type_planner: existing.type_planner, optimizer: Some(existing.optimizer), physical_optimizers: Some(existing.physical_optimizers), @@ -1114,6 +1141,7 @@ impl SessionStateBuilder { } /// Set the [`TypePlanner`] used to customize the behavior of the SQL planner. + #[cfg(feature = "sql")] pub fn with_type_planner(mut self, type_planner: Arc) -> Self { self.type_planner = Some(type_planner); self @@ -1325,6 +1353,7 @@ impl SessionStateBuilder { session_id, analyzer, expr_planners, + #[cfg(feature = "sql")] type_planner, optimizer, physical_optimizers, @@ -1348,28 +1377,31 @@ impl SessionStateBuilder { } = self; let config = config.unwrap_or_default(); - let runtime_env = runtime_env.unwrap_or(Arc::new(RuntimeEnv::default())); + let runtime_env = runtime_env.unwrap_or_else(|| Arc::new(RuntimeEnv::default())); let mut state = SessionState { - session_id: session_id.unwrap_or(Uuid::new_v4().to_string()), + session_id: session_id.unwrap_or_else(|| Uuid::new_v4().to_string()), analyzer: analyzer.unwrap_or_default(), expr_planners: expr_planners.unwrap_or_default(), + #[cfg(feature = "sql")] type_planner, optimizer: optimizer.unwrap_or_default(), physical_optimizers: physical_optimizers.unwrap_or_default(), - query_planner: query_planner.unwrap_or(Arc::new(DefaultQueryPlanner {})), - catalog_list: catalog_list - .unwrap_or(Arc::new(MemoryCatalogProviderList::new()) - as Arc), + query_planner: query_planner + .unwrap_or_else(|| Arc::new(DefaultQueryPlanner {})), + catalog_list: catalog_list.unwrap_or_else(|| { + Arc::new(MemoryCatalogProviderList::new()) as Arc + }), table_functions: table_functions.unwrap_or_default(), scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), serializer_registry: serializer_registry - .unwrap_or(Arc::new(EmptySerializerRegistry)), + .unwrap_or_else(|| Arc::new(EmptySerializerRegistry)), file_formats: HashMap::new(), - table_options: table_options - .unwrap_or(TableOptions::default_from_session_config(config.options())), + table_options: table_options.unwrap_or_else(|| { + TableOptions::default_from_session_config(config.options()) + }), config, execution_props: execution_props.unwrap_or_default(), table_factories: table_factories.unwrap_or_default(), @@ -1470,6 +1502,7 @@ impl SessionStateBuilder { } /// Returns the current type_planner value + #[cfg(feature = "sql")] pub fn type_planner(&mut self) -> &mut Option> { &mut self.type_planner } @@ -1584,7 +1617,8 @@ impl Debug for SessionStateBuilder { /// Prefer having short fields at the top and long vector fields near the end /// Group fields by fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SessionStateBuilder") + let mut debug_struct = f.debug_struct("SessionStateBuilder"); + let ret = debug_struct .field("session_id", &self.session_id) .field("config", &self.config) .field("runtime_env", &self.runtime_env) @@ -1595,9 +1629,10 @@ impl Debug for SessionStateBuilder { .field("table_options", &self.table_options) .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) - .field("expr_planners", &self.expr_planners) - .field("type_planner", &self.type_planner) - .field("query_planners", &self.query_planner) + .field("expr_planners", &self.expr_planners); + #[cfg(feature = "sql")] + let ret = ret.field("type_planner", &self.type_planner); + ret.field("query_planners", &self.query_planner) .field("analyzer_rules", &self.analyzer_rules) .field("analyzer", &self.analyzer) .field("optimizer_rules", &self.optimizer_rules) @@ -1628,14 +1663,16 @@ impl From for SessionStateBuilder { /// /// This is used so the SQL planner can access the state of the session without /// having a direct dependency on the [`SessionState`] struct (and core crate) +#[cfg(feature = "sql")] struct SessionContextProvider<'a> { state: &'a SessionState, tables: HashMap>, } +#[cfg(feature = "sql")] impl ContextProvider for SessionContextProvider<'_> { fn get_expr_planners(&self) -> &[Arc] { - &self.state.expr_planners + self.state.expr_planners() } fn get_type_planner(&self) -> Option> { @@ -1668,6 +1705,13 @@ impl ContextProvider for SessionContextProvider<'_> { .get(name) .cloned() .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; + let dummy_schema = DFSchema::empty(); + let simplifier = + ExprSimplifier::new(SessionSimplifyProvider::new(self.state, &dummy_schema)); + let args = args + .into_iter() + .map(|arg| simplifier.simplify(arg)) + .collect::>>()?; let provider = tbl_func.create_table_provider(&args)?; Ok(provider_as_source(provider)) @@ -1679,9 +1723,11 @@ impl ContextProvider for SessionContextProvider<'_> { fn create_cte_work_table( &self, name: &str, - schema: SchemaRef, + schema: arrow::datatypes::SchemaRef, ) -> datafusion_common::Result> { - let table = Arc::new(CteWorkTable::new(name, schema)); + let table = Arc::new(crate::datasource::cte_worktable::CteWorkTable::new( + name, schema, + )); Ok(provider_as_source(table)) } @@ -1698,6 +1744,8 @@ impl ContextProvider for SessionContextProvider<'_> { } fn get_variable_type(&self, variable_names: &[String]) -> Option { + use datafusion_expr::var_provider::{is_system_variables, VarType}; + if variable_names.is_empty() { return None; } @@ -1731,14 +1779,21 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.window_functions().keys().cloned().collect() } - fn get_file_type(&self, ext: &str) -> datafusion_common::Result> { + fn get_file_type( + &self, + ext: &str, + ) -> datafusion_common::Result< + Arc, + > { self.state .file_formats .get(&ext.to_lowercase()) .ok_or(plan_datafusion_err!( "There is no registered file format with ext {ext}" )) - .map(|file_type| format_as_file_type(Arc::clone(file_type))) + .map(|file_type| { + crate::datasource::file_format::format_as_file_type(Arc::clone(file_type)) + }) } } @@ -1751,7 +1806,7 @@ impl FunctionRegistry for SessionState { let result = self.scalar_functions.get(name); result.cloned().ok_or_else(|| { - plan_datafusion_err!("There is no UDF named \"{name}\" in the registry") + plan_datafusion_err!("There is no UDF named \"{name}\" in the registry. Use session context `register_udf` function to register a custom UDF") }) } @@ -1759,7 +1814,7 @@ impl FunctionRegistry for SessionState { let result = self.aggregate_functions.get(name); result.cloned().ok_or_else(|| { - plan_datafusion_err!("There is no UDAF named \"{name}\" in the registry") + plan_datafusion_err!("There is no UDAF named \"{name}\" in the registry. Use session context `register_udaf` function to register a custom UDAF") }) } @@ -1767,7 +1822,7 @@ impl FunctionRegistry for SessionState { let result = self.window_functions.get(name); result.cloned().ok_or_else(|| { - plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry") + plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry. Use session context `register_udwf` function to register a custom UDWF") }) } @@ -1862,6 +1917,14 @@ impl FunctionRegistry for SessionState { self.expr_planners.push(expr_planner); Ok(()) } + + fn udafs(&self) -> HashSet { + self.aggregate_functions.keys().cloned().collect() + } + + fn udwfs(&self) -> HashSet { + self.window_functions.keys().cloned().collect() + } } impl OptimizerConfig for SessionState { @@ -1873,8 +1936,8 @@ impl OptimizerConfig for SessionState { &self.execution_props.alias_generator } - fn options(&self) -> &ConfigOptions { - self.config_options() + fn options(&self) -> Arc { + Arc::clone(self.config.options()) } fn function_registry(&self) -> Option<&dyn FunctionRegistry> { @@ -1957,8 +2020,17 @@ pub(crate) struct PreparedPlan { #[cfg(test)] mod tests { use super::{SessionContextProvider, SessionStateBuilder}; + use crate::common::assert_contains; + use crate::config::ConfigOptions; + use crate::datasource::empty::EmptyTable; + use crate::datasource::provider_as_source; use crate::datasource::MemTable; use crate::execution::context::SessionState; + use crate::logical_expr::planner::ExprPlanner; + use crate::logical_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; + use crate::physical_plan::ExecutionPlan; + use crate::sql::planner::ContextProvider; + use crate::sql::{ResolvedTableReference, TableReference}; use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_catalog::MemoryCatalogProviderList; @@ -1968,13 +2040,16 @@ mod tests { use datafusion_expr::Expr; use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_optimizer::Optimizer; + use datafusion_physical_plan::display::DisplayableExecutionPlan; use datafusion_sql::planner::{PlannerContext, SqlToRel}; use std::collections::HashMap; use std::sync::Arc; #[test] + #[cfg(feature = "sql")] fn test_session_state_with_default_features() { // test array planners with and without builtin planners + #[cfg(feature = "sql")] fn sql_to_expr(state: &SessionState) -> Result { let provider = SessionContextProvider { state, @@ -2125,4 +2200,148 @@ mod tests { Ok(()) } + + /// This test demonstrates why it's more convenient and somewhat necessary to provide + /// an `expr_planners` method for `SessionState`. + #[tokio::test] + async fn test_with_expr_planners() -> Result<()> { + // A helper method for planning count wildcard with or without expr planners. + async fn plan_count_wildcard( + with_expr_planners: bool, + ) -> Result> { + let mut context_provider = MyContextProvider::new().with_table( + "t", + provider_as_source(Arc::new(EmptyTable::new(Schema::empty().into()))), + ); + if with_expr_planners { + context_provider = context_provider.with_expr_planners(); + } + + let state = &context_provider.state; + let statement = state.sql_to_statement("select count(*) from t", "mysql")?; + let plan = SqlToRel::new(&context_provider).statement_to_plan(statement)?; + state.create_physical_plan(&plan).await + } + + // Planning count wildcard without expr planners should fail. + let got = plan_count_wildcard(false).await; + assert_contains!( + got.unwrap_err().to_string(), + "Physical plan does not support logical expression Wildcard" + ); + + // Planning count wildcard with expr planners should succeed. + let got = plan_count_wildcard(true).await?; + let displayable = DisplayableExecutionPlan::new(got.as_ref()); + assert_eq!( + displayable.indent(false).to_string(), + "ProjectionExec: expr=[0 as count(*)]\n PlaceholderRowExec\n" + ); + + Ok(()) + } + + /// A `ContextProvider` based on `SessionState`. + /// + /// Almost all planning context are retrieved from the `SessionState`. + struct MyContextProvider { + /// The session state. + state: SessionState, + /// Registered tables. + tables: HashMap>, + /// Controls whether to return expression planners when called `ContextProvider::expr_planners`. + return_expr_planners: bool, + } + + impl MyContextProvider { + /// Creates a new `SessionContextProvider`. + pub fn new() -> Self { + Self { + state: SessionStateBuilder::default() + .with_default_features() + .build(), + tables: HashMap::new(), + return_expr_planners: false, + } + } + + /// Registers a table. + /// + /// The catalog and schema are provided by default. + pub fn with_table(mut self, table: &str, source: Arc) -> Self { + self.tables.insert( + ResolvedTableReference { + catalog: "default".to_string().into(), + schema: "public".to_string().into(), + table: table.to_string().into(), + }, + source, + ); + self + } + + /// Sets the `return_expr_planners` flag to true. + pub fn with_expr_planners(self) -> Self { + Self { + return_expr_planners: true, + ..self + } + } + } + + impl ContextProvider for MyContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { + let resolved_table_ref = ResolvedTableReference { + catalog: "default".to_string().into(), + schema: "public".to_string().into(), + table: name.table().to_string().into(), + }; + let source = self.tables.get(&resolved_table_ref).cloned().unwrap(); + Ok(source) + } + + /// We use a `return_expr_planners` flag to demonstrate why it's necessary to + /// return the expression planners in the `SessionState`. + /// + /// Note, the default implementation returns an empty slice. + fn get_expr_planners(&self) -> &[Arc] { + if self.return_expr_planners { + self.state.expr_planners() + } else { + &[] + } + } + + fn get_function_meta(&self, name: &str) -> Option> { + self.state.scalar_functions().get(name).cloned() + } + + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.state.aggregate_functions().get(name).cloned() + } + + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() + } + + fn get_variable_type(&self, _variable_names: &[String]) -> Option { + None + } + + fn options(&self) -> &ConfigOptions { + self.state.config_options() + } + + fn udf_names(&self) -> Vec { + self.state.scalar_functions().keys().cloned().collect() + } + + fn udaf_names(&self) -> Vec { + self.state.aggregate_functions().keys().cloned().collect() + } + + fn udwf_names(&self) -> Vec { + self.state.window_functions().keys().cloned().collect() + } + } } diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index a241738bd3a42..baf396f3f1c52 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -90,11 +90,10 @@ impl SessionStateDefaults { Arc::new(functions_nested::planner::NestedFunctionPlanner), #[cfg(feature = "nested_expressions")] Arc::new(functions_nested::planner::FieldAccessPlanner), - #[cfg(any( - feature = "datetime_expressions", - feature = "unicode_expressions" - ))] - Arc::new(functions::planner::UserDefinedFunctionPlanner), + #[cfg(feature = "datetime_expressions")] + Arc::new(functions::datetime::planner::DatetimeFunctionPlanner), + #[cfg(feature = "unicode_expressions")] + Arc::new(functions::unicode::planner::UnicodeFunctionPlanner), Arc::new(functions_aggregate::planner::AggregateFunctionPlanner), Arc::new(functions_window::planner::WindowFunctionPlanner), ]; diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index cc510bc81f1a8..e7ace544a11cf 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -19,16 +19,28 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 -#![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] +// +// Eliminate unnecessary function calls(some may be not cheap) due to `xxx_or` +// for performance. Also avoid abusing `xxx_or_else` for readability: +// https://github.com/apache/datafusion/issues/15802 +#![cfg_attr( + not(test), + deny( + clippy::clone_on_ref_ptr, + clippy::or_fun_call, + clippy::unnecessary_lazy_evaluations + ) +)] #![warn(missing_docs, clippy::needless_borrow)] //! [DataFusion] is an extensible query engine written in Rust that //! uses [Apache Arrow] as its in-memory format. DataFusion's target users are //! developers building fast and feature rich database and analytic systems, -//! customized to particular workloads. See [use cases] for examples. +//! customized to particular workloads. Please see the [DataFusion website] for +//! additional documentation, [use cases] and examples. //! //! "Out of the box," DataFusion offers [SQL] and [`Dataframe`] APIs, //! excellent [performance], built-in support for CSV, Parquet, JSON, and Avro, @@ -42,6 +54,7 @@ //! See the [Architecture] section below for more details. //! //! [DataFusion]: https://datafusion.apache.org/ +//! [DataFusion website]: https://datafusion.apache.org //! [Apache Arrow]: https://arrow.apache.org //! [use cases]: https://datafusion.apache.org/user-guide/introduction.html#use-cases //! [SQL]: https://datafusion.apache.org/user-guide/sql/index.html @@ -300,14 +313,17 @@ //! ``` //! //! A [`TableProvider`] provides information for planning and -//! an [`ExecutionPlan`]s for execution. DataFusion includes [`ListingTable`] -//! which supports reading several common file formats, and you can support any -//! new file format by implementing the [`TableProvider`] trait. See also: +//! an [`ExecutionPlan`] for execution. DataFusion includes two built-in +//! table providers that support common file formats and require no runtime services, +//! [`ListingTable`] and [`MemTable`]. You can add support for any other data +//! source and/or file formats by implementing the [`TableProvider`] trait. +//! +//! See also: //! -//! 1. [`ListingTable`]: Reads data from Parquet, JSON, CSV, or AVRO -//! files. Supports single files or multiple files with HIVE style +//! 1. [`ListingTable`]: Reads data from one or more Parquet, JSON, CSV, or AVRO +//! files in one or more local or remote directories. Supports HIVE style //! partitioning, optional compression, directly reading from remote -//! object store and more. +//! object store, file metadata caching, and more. //! //! 2. [`MemTable`]: Reads data from in memory [`RecordBatch`]es. //! @@ -326,11 +342,11 @@ //! A [`LogicalPlan`] is a Directed Acyclic Graph (DAG) of other //! [`LogicalPlan`]s, each potentially containing embedded [`Expr`]s. //! -//! `LogicalPlan`s can be rewritten with [`TreeNode`] API, see the +//! [`LogicalPlan`]s can be rewritten with [`TreeNode`] API, see the //! [`tree_node module`] for more details. //! //! [`Expr`]s can also be rewritten with [`TreeNode`] API and simplified using -//! [`ExprSimplifier`]. Examples of working with and executing `Expr`s can be +//! [`ExprSimplifier`]. Examples of working with and executing [`Expr`]s can be //! found in the [`expr_api`.rs] example //! //! [`TreeNode`]: datafusion_common::tree_node::TreeNode @@ -415,17 +431,17 @@ //! //! ## Streaming Execution //! -//! DataFusion is a "streaming" query engine which means `ExecutionPlan`s incrementally +//! DataFusion is a "streaming" query engine which means [`ExecutionPlan`]s incrementally //! read from their input(s) and compute output one [`RecordBatch`] at a time //! by continually polling [`SendableRecordBatchStream`]s. Output and -//! intermediate `RecordBatch`s each have approximately `batch_size` rows, +//! intermediate [`RecordBatch`]s each have approximately `batch_size` rows, //! which amortizes per-batch overhead of execution. //! //! Note that certain operations, sometimes called "pipeline breakers", //! (for example full sorts or hash aggregations) are fundamentally non streaming and //! must read their input fully before producing **any** output. As much as possible, //! other operators read a single [`RecordBatch`] from their input to produce a -//! single `RecordBatch` as output. +//! single [`RecordBatch`] as output. //! //! For example, given this SQL query: //! @@ -434,9 +450,9 @@ //! ``` //! //! The diagram below shows the call sequence when a consumer calls [`next()`] to -//! get the next `RecordBatch` of output. While it is possible that some +//! get the next [`RecordBatch`] of output. While it is possible that some //! steps run on different threads, typically tokio will use the same thread -//! that called `next()` to read from the input, apply the filter, and +//! that called [`next()`] to read from the input, apply the filter, and //! return the results without interleaving any other operations. This results //! in excellent cache locality as the same CPU core that produces the data often //! consumes it immediately as well. @@ -474,39 +490,53 @@ //! DataFusion automatically runs each plan with multiple CPU cores using //! a [Tokio] [`Runtime`] as a thread pool. While tokio is most commonly used //! for asynchronous network I/O, the combination of an efficient, work-stealing -//! scheduler and first class compiler support for automatic continuation -//! generation (`async`), also makes it a compelling choice for CPU intensive +//! scheduler, and first class compiler support for automatic continuation +//! generation (`async`) also makes it a compelling choice for CPU intensive //! applications as explained in the [Using Rustlang’s Async Tokio //! Runtime for CPU-Bound Tasks] blog. //! //! The number of cores used is determined by the `target_partitions` //! configuration setting, which defaults to the number of CPU cores. //! While preparing for execution, DataFusion tries to create this many distinct -//! `async` [`Stream`]s for each `ExecutionPlan`. -//! The `Stream`s for certain `ExecutionPlans`, such as as [`RepartitionExec`] -//! and [`CoalescePartitionsExec`], spawn [Tokio] [`task`]s, that are run by -//! threads managed by the `Runtime`. -//! Many DataFusion `Stream`s perform CPU intensive processing. +//! `async` [`Stream`]s for each [`ExecutionPlan`]. +//! The [`Stream`]s for certain [`ExecutionPlan`]s, such as [`RepartitionExec`] +//! and [`CoalescePartitionsExec`], spawn [Tokio] [`task`]s, that run on +//! threads managed by the [`Runtime`]. +//! Many DataFusion [`Stream`]s perform CPU intensive processing. +//! +//! ### Cooperative Scheduling +//! +//! DataFusion uses cooperative scheduling, which means that each [`Stream`] +//! is responsible for yielding control back to the [`Runtime`] after +//! some amount of work is done. Please see the [`coop`] module documentation +//! for more details. +//! +//! [`coop`]: datafusion_physical_plan::coop +//! +//! ### Network I/O and CPU intensive tasks //! //! Using `async` for CPU intensive tasks makes it easy for [`TableProvider`]s //! to perform network I/O using standard Rust `async` during execution. //! However, this design also makes it very easy to mix CPU intensive and latency //! sensitive I/O work on the same thread pool ([`Runtime`]). -//! Using the same (default) `Runtime` is convenient, and often works well for +//! Using the same (default) [`Runtime`] is convenient, and often works well for //! initial development and processing local files, but it can lead to problems //! under load and/or when reading from network sources such as AWS S3. //! +//! ### Optimizing Latency: Throttled CPU / IO under Highly Concurrent Load +//! //! If your system does not fully utilize either the CPU or network bandwidth //! during execution, or you see significantly higher tail (e.g. p99) latencies //! responding to network requests, **it is likely you need to use a different -//! `Runtime` for CPU intensive DataFusion plans**. This effect can be especially -//! pronounced when running several queries concurrently. +//! [`Runtime`] for DataFusion plans**. The [thread_pools example] +//! has an example of how to do so. //! -//! As shown in the following figure, using the same `Runtime` for both CPU -//! intensive processing and network requests can introduce significant -//! delays in responding to those network requests. Delays in processing network -//! requests can and does lead network flow control to throttle the available -//! bandwidth in response. +//! As shown below, using the same [`Runtime`] for both CPU intensive processing +//! and network requests can introduce significant delays in responding to +//! those network requests. Delays in processing network requests can and does +//! lead network flow control to throttle the available bandwidth in response. +//! This effect can be especially pronounced when running multiple queries +//! concurrently. //! //! ```text //! Legend @@ -588,6 +618,7 @@ //! //! [Tokio]: https://tokio.rs //! [`Runtime`]: tokio::runtime::Runtime +//! [thread_pools example]: https://github.com/apache/datafusion/tree/main/datafusion-examples/examples/thread_pools.rs //! [`task`]: tokio::task //! [Using Rustlang’s Async Tokio Runtime for CPU-Bound Tasks]: https://thenewstack.io/using-rustlangs-async-tokio-runtime-for-cpu-bound-tasks/ //! [`RepartitionExec`]: physical_plan::repartition::RepartitionExec @@ -603,8 +634,8 @@ //! The state required to execute queries is managed by the following //! structures: //! -//! 1. [`SessionContext`]: State needed for create [`LogicalPlan`]s such -//! as the table definitions, and the function registries. +//! 1. [`SessionContext`]: State needed to create [`LogicalPlan`]s such +//! as the table definitions and the function registries. //! //! 2. [`TaskContext`]: State needed for execution such as the //! [`MemoryPool`], [`DiskManager`], and [`ObjectStoreRegistry`]. @@ -703,6 +734,8 @@ pub const DATAFUSION_VERSION: &str = env!("CARGO_PKG_VERSION"); extern crate core; + +#[cfg(feature = "sql")] extern crate sqlparser; pub mod dataframe; @@ -713,11 +746,16 @@ pub mod physical_planner; pub mod prelude; pub mod scalar; -// re-export dependencies from arrow-rs to minimize version maintenance for crate users +// Re-export dependencies that are part of DataFusion public API (e.g. via DataFusionError) pub use arrow; +pub use object_store; + #[cfg(feature = "parquet")] pub use parquet; +#[cfg(feature = "avro")] +pub use datafusion_datasource_avro::apache_avro; + // re-export DataFusion sub-crates at the top level. Use `pub use *` // so that the contents of the subcrates appears in rustdocs // for details, see https://github.com/apache/datafusion/issues/6648 @@ -772,6 +810,11 @@ pub mod physical_expr { pub use datafusion_physical_expr::*; } +/// re-export of [`datafusion_physical_expr_adapter`] crate +pub mod physical_expr_adapter { + pub use datafusion_physical_expr_adapter::*; +} + /// re-export of [`datafusion_physical_plan`] crate pub mod physical_plan { pub use datafusion_physical_plan::*; @@ -782,6 +825,7 @@ pub use datafusion_common::assert_batches_eq; pub use datafusion_common::assert_batches_sorted_eq; /// re-export of [`datafusion_sql`] crate +#[cfg(feature = "sql")] pub mod sql { pub use datafusion_sql::*; } @@ -797,13 +841,6 @@ pub mod functions_nested { pub use datafusion_functions_nested::*; } -/// re-export of [`datafusion_functions_nested`] crate as [`functions_array`] for backward compatibility, if "nested_expressions" feature is enabled -#[deprecated(since = "41.0.0", note = "use datafusion-functions-nested instead")] -pub mod functions_array { - #[cfg(feature = "nested_expressions")] - pub use datafusion_functions_nested::*; -} - /// re-export of [`datafusion_functions_aggregate`] crate pub mod functions_aggregate { pub use datafusion_functions_aggregate::*; @@ -1021,14 +1058,20 @@ doc_comment::doctest!( #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/user-guide/sql/write_options.md", - user_guide_sql_write_options + "../../../docs/source/user-guide/sql/format_options.md", + user_guide_sql_format_options +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/functions/adding-udfs.md", + library_user_guide_functions_adding_udfs ); #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/library-user-guide/adding-udfs.md", - library_user_guide_adding_udfs + "../../../docs/source/library-user-guide/functions/spark.md", + library_user_guide_functions_spark ); #[cfg(doctest)] diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index f1a99a7714ac4..c28e56790e660 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -45,7 +45,7 @@ use crate::physical_plan::joins::{ CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, }; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; -use crate::physical_plan::projection::ProjectionExec; +use crate::physical_plan::projection::{ProjectionExec, ProjectionExpr}; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::union::UnionExec; @@ -55,23 +55,25 @@ use crate::physical_plan::{ displayable, windows, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, Partitioning, PhysicalExpr, WindowExpr, }; -use datafusion_physical_plan::empty::EmptyExec; -use datafusion_physical_plan::recursive_query::RecursiveQueryExec; +use crate::schema_equivalence::schema_satisfied_by; use arrow::array::{builder::StringBuilder, RecordBatch}; use arrow::compute::SortOptions; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::Schema; +use datafusion_catalog::ScanArgs; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; +use datafusion_common::TableReference; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, ScalarValue, }; +use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::memory::MemorySourceConfig; use datafusion_expr::dml::{CopyTo, InsertOp}; use datafusion_expr::expr::{ physical_name, AggregateFunction, AggregateFunctionParams, Alias, GroupingSet, - WindowFunction, WindowFunctionParams, + NullTreatment, WindowFunction, WindowFunctionParams, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; @@ -82,19 +84,21 @@ use datafusion_expr::{ }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::Literal; -use datafusion_physical_expr::LexOrdering; +use datafusion_physical_expr::{ + create_physical_sort_exprs, LexOrdering, PhysicalSortExpr, +}; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::execution_plan::InvariantLevel; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion_physical_plan::recursive_query::RecursiveQueryExec; use datafusion_physical_plan::unnest::ListUnnest; -use crate::schema_equivalence::schema_satisfied_by; use async_trait::async_trait; -use datafusion_datasource::file_groups::FileGroup; +use datafusion_physical_plan::async_func::{AsyncFuncExec, AsyncMapper}; use futures::{StreamExt, TryStreamExt}; use itertools::{multiunzip, Itertools}; -use log::{debug, trace}; -use sqlparser::ast::NullTreatment; +use log::debug; use tokio::sync::Mutex; /// Physical query planner that converts a `LogicalPlan` to an @@ -453,12 +457,15 @@ impl DefaultPhysicalPlanner { // doesn't know (nor should care) how the relation was // referred to in the query let filters = unnormalize_cols(filters.iter().cloned()); - source - .scan(session_state, projection.as_ref(), &filters, *fetch) - .await? + let filters_vec = filters.into_iter().collect::>(); + let opts = ScanArgs::default() + .with_projection(projection.as_deref()) + .with_filters(Some(&filters_vec)) + .with_limit(*fetch); + let res = source.scan_with_args(session_state, opts).await?; + Arc::clone(res.plan()) } LogicalPlan::Values(Values { values, schema }) => { - let exec_schema = schema.as_ref().to_owned().into(); let exprs = values .iter() .map(|row| { @@ -469,27 +476,23 @@ impl DefaultPhysicalPlanner { .collect::>>>() }) .collect::>>()?; - MemorySourceConfig::try_new_as_values(SchemaRef::new(exec_schema), exprs)? + MemorySourceConfig::try_new_as_values(Arc::clone(schema.inner()), exprs)? as _ } LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema, - }) => Arc::new(EmptyExec::new(SchemaRef::new( - schema.as_ref().to_owned().into(), - ))), + }) => Arc::new(EmptyExec::new(Arc::clone(schema.inner()))), LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: true, schema, - }) => Arc::new(PlaceholderRowExec::new(SchemaRef::new( - schema.as_ref().to_owned().into(), - ))), + }) => Arc::new(PlaceholderRowExec::new(Arc::clone(schema.inner()))), LogicalPlan::DescribeTable(DescribeTable { schema, output_schema, }) => { - let output_schema: Schema = output_schema.as_ref().into(); - self.plan_describe(Arc::clone(schema), Arc::new(output_schema))? + let output_schema = Arc::clone(output_schema.inner()); + self.plan_describe(Arc::clone(schema), output_schema)? } // 1 Child @@ -499,13 +502,14 @@ impl DefaultPhysicalPlanner { file_type, partition_by, options: source_option_tuples, + output_schema: _, }) => { let original_url = output_url.clone(); let input_exec = children.one()?; let parsed_url = ListingTableUrl::parse(output_url)?; let object_store_url = parsed_url.object_store(); - let schema: Schema = (**input.schema()).clone().into(); + let schema = Arc::clone(input.schema().inner()); // Note: the DataType passed here is ignored for the purposes of writing and inferred instead // from the schema of the RecordBatch being written. This allows COPY statements to specify only @@ -522,27 +526,42 @@ impl DefaultPhysicalPlanner { Some("true") => true, Some("false") => false, Some(value) => - return Err(DataFusionError::Configuration(format!("provided value for 'execution.keep_partition_by_columns' was not recognized: \"{}\"", value))), + return Err(DataFusionError::Configuration(format!("provided value for 'execution.keep_partition_by_columns' was not recognized: \"{value}\""))), }; let sink_format = file_type_to_format(file_type)? .create(session_state, source_option_tuples)?; + // Determine extension based on format extension and compression + let file_extension = match sink_format.compression_type() { + Some(compression_type) => sink_format + .get_ext_with_compression(&compression_type) + .unwrap_or_else(|_| sink_format.get_ext()), + None => sink_format.get_ext(), + }; + // Set file sink related options let config = FileSinkConfig { original_url, object_store_url, table_paths: vec![parsed_url], file_group: FileGroup::default(), - output_schema: Arc::new(schema), + output_schema: schema, table_partition_cols, insert_op: InsertOp::Append, keep_partition_by_columns, - file_extension: sink_format.get_ext(), + file_extension, }; + let ordering = input_exec.properties().output_ordering().cloned(); + sink_format - .create_writer_physical_plan(input_exec, session_state, config, None) + .create_writer_physical_plan( + input_exec, + session_state, + config, + ordering.map(Into::into), + ) .await? } LogicalPlan::Dml(DmlStatement { @@ -572,27 +591,25 @@ impl DefaultPhysicalPlanner { let input_exec = children.one()?; let get_sort_keys = |expr: &Expr| match expr { - Expr::WindowFunction(WindowFunction { - params: - WindowFunctionParams { - ref partition_by, - ref order_by, - .. - }, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { + ref partition_by, + ref order_by, + .. + } = &window_fun.as_ref().params; + generate_sort_key(partition_by, order_by) + } Expr::Alias(Alias { expr, .. }) => { // Convert &Box to &T match &**expr { - Expr::WindowFunction(WindowFunction { - params: - WindowFunctionParams { - ref partition_by, - ref order_by, - .. - }, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { + ref partition_by, + ref order_by, + .. + } = &window_fun.as_ref().params; + generate_sort_key(partition_by, order_by) + } _ => unreachable!(), } } @@ -694,7 +711,7 @@ impl DefaultPhysicalPlanner { } return internal_err!("Physical input schema should be the same as the one converted from logical input schema. Differences: {}", differences .iter() - .map(|s| format!("\n\t- {}", s)) + .map(|s| format!("\n\t- {s}")) .join("")); } @@ -717,9 +734,54 @@ impl DefaultPhysicalPlanner { }) .collect::>>()?; - let (aggregates, filters, _order_bys): (Vec<_>, Vec<_>, Vec<_>) = + let (mut aggregates, filters, _order_bys): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(agg_filter); + let mut async_exprs = Vec::new(); + let num_input_columns = physical_input_schema.fields().len(); + + for agg_func in &mut aggregates { + match self.try_plan_async_exprs( + num_input_columns, + PlannedExprResult::Expr(agg_func.expressions()), + physical_input_schema.as_ref(), + )? { + PlanAsyncExpr::Async( + async_map, + PlannedExprResult::Expr(physical_exprs), + ) => { + async_exprs.extend(async_map.async_exprs); + + if let Some(new_agg_func) = agg_func.with_new_expressions( + physical_exprs, + agg_func + .order_bys() + .iter() + .cloned() + .map(|x| x.expr) + .collect(), + ) { + *agg_func = Arc::new(new_agg_func); + } else { + return internal_err!("Failed to plan async expression"); + } + } + PlanAsyncExpr::Sync(PlannedExprResult::Expr(_)) => { + // Do nothing + } + _ => { + return internal_err!( + "Unexpected result from try_plan_async_exprs" + ) + } + } + } + let input_exec = if !async_exprs.is_empty() { + Arc::new(AsyncFuncExec::try_new(async_exprs, input_exec)?) + } else { + input_exec + }; + let initial_aggr = Arc::new(AggregateExec::try_new( AggregateMode::Partial, groups.clone(), @@ -775,12 +837,46 @@ impl DefaultPhysicalPlanner { let runtime_expr = self.create_physical_expr(predicate, input_dfschema, session_state)?; + + let input_schema = input.schema(); + let filter = match self.try_plan_async_exprs( + input_schema.fields().len(), + PlannedExprResult::Expr(vec![runtime_expr]), + input_schema.as_arrow(), + )? { + PlanAsyncExpr::Sync(PlannedExprResult::Expr(runtime_expr)) => { + FilterExec::try_new(Arc::clone(&runtime_expr[0]), physical_input)? + } + PlanAsyncExpr::Async( + async_map, + PlannedExprResult::Expr(runtime_expr), + ) => { + let async_exec = AsyncFuncExec::try_new( + async_map.async_exprs, + physical_input, + )?; + FilterExec::try_new( + Arc::clone(&runtime_expr[0]), + Arc::new(async_exec), + )? + // project the output columns excluding the async functions + // The async functions are always appended to the end of the schema. + .with_projection(Some( + (0..input.schema().fields().len()).collect(), + ))? + } + _ => { + return internal_err!( + "Unexpected result from try_plan_async_exprs" + ) + } + }; + let selectivity = session_state .config() .options() .optimizer .default_filter_selectivity; - let filter = FilterExec::try_new(runtime_expr, physical_input)?; Arc::new(filter.with_default_selectivity(selectivity)?) } LogicalPlan::Repartition(Repartition { @@ -822,13 +918,17 @@ impl DefaultPhysicalPlanner { }) => { let physical_input = children.one()?; let input_dfschema = input.as_ref().schema(); - let sort_expr = create_physical_sort_exprs( + let sort_exprs = create_physical_sort_exprs( expr, input_dfschema, session_state.execution_props(), )?; - let new_sort = - SortExec::new(sort_expr, physical_input).with_fetch(*fetch); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + return internal_err!( + "SortExec requires at least one sort expression" + ); + }; + let new_sort = SortExec::new(ordering, physical_input).with_fetch(*fetch); Arc::new(new_sort) } LogicalPlan::Subquery(_) => todo!(), @@ -871,7 +971,7 @@ impl DefaultPhysicalPlanner { .. }) => { let input = children.one()?; - let schema = SchemaRef::new(schema.as_ref().to_owned().into()); + let schema = Arc::clone(schema.inner()); let list_column_indices = list_type_columns .iter() .map(|(index, unnesting)| ListUnnest { @@ -890,17 +990,15 @@ impl DefaultPhysicalPlanner { // 2 Children LogicalPlan::Join(Join { - left, - right, + left: original_left, + right: original_right, on: keys, filter, join_type, - null_equals_null, + null_equality, schema: join_schema, .. }) => { - let null_equals_null = *null_equals_null; - let [physical_left, physical_right] = children.two()?; // If join has expression equijoin keys, add physical projection. @@ -916,23 +1014,25 @@ impl DefaultPhysicalPlanner { let (left, left_col_keys, left_projected) = wrap_projection_for_join_if_necessary( &left_keys, - left.as_ref().clone(), + original_left.as_ref().clone(), )?; let (right, right_col_keys, right_projected) = wrap_projection_for_join_if_necessary( &right_keys, - right.as_ref().clone(), + original_right.as_ref().clone(), )?; let column_on = (left_col_keys, right_col_keys); let left = Arc::new(left); let right = Arc::new(right); - let new_join = LogicalPlan::Join(Join::try_new_with_project_input( + let (new_join, requalified) = Join::try_new_with_project_input( node, Arc::clone(&left), Arc::clone(&right), column_on, - )?); + )?; + + let new_join = LogicalPlan::Join(new_join); // If inputs were projected then create ExecutionPlan for these new // LogicalPlan nodes. @@ -965,8 +1065,24 @@ impl DefaultPhysicalPlanner { // Remove temporary projected columns if left_projected || right_projected { - let final_join_result = - join_schema.iter().map(Expr::from).collect::>(); + // Re-qualify the join schema only if the inputs were previously requalified in + // `try_new_with_project_input`. This ensures that when building the Projection + // it can correctly resolve field nullability and data types + // by disambiguating fields from the left and right sides of the join. + let qualified_join_schema = if requalified { + Arc::new(qualify_join_schema_sides( + join_schema, + original_left, + original_right, + )?) + } else { + Arc::clone(join_schema) + }; + + let final_join_result = qualified_join_schema + .iter() + .map(Expr::from) + .collect::>(); let projection = LogicalPlan::Projection(Projection::try_new( final_join_result, Arc::new(new_join), @@ -1023,18 +1139,12 @@ impl DefaultPhysicalPlanner { // Collect left & right field indices, the field indices are sorted in ascending order let left_field_indices = cols .iter() - .filter_map(|c| match left_df_schema.index_of_column(c) { - Ok(idx) => Some(idx), - _ => None, - }) + .filter_map(|c| left_df_schema.index_of_column(c).ok()) .sorted() .collect::>(); let right_field_indices = cols .iter() - .filter_map(|c| match right_df_schema.index_of_column(c) { - Ok(idx) => Some(idx), - _ => None, - }) + .filter_map(|c| right_df_schema.index_of_column(c).ok()) .sorted() .collect::>(); @@ -1119,8 +1229,6 @@ impl DefaultPhysicalPlanner { && !prefer_hash_join { // Use SortMergeJoin if hash join is not preferred - // Sort-Merge join support currently is experimental - let join_on_len = join_on.len(); Arc::new(SortMergeJoinExec::try_new( physical_left, @@ -1129,7 +1237,7 @@ impl DefaultPhysicalPlanner { join_filter, *join_type, vec![SortOptions::default(); join_on_len], - null_equals_null, + *null_equality, )?) } else if session_state.config().target_partitions() > 1 && session_state.config().repartition_joins() @@ -1143,7 +1251,7 @@ impl DefaultPhysicalPlanner { join_type, None, PartitionMode::Auto, - null_equals_null, + *null_equality, )?) } else { Arc::new(HashJoinExec::try_new( @@ -1154,7 +1262,7 @@ impl DefaultPhysicalPlanner { join_type, None, PartitionMode::CollectLeft, - null_equals_null, + *null_equality, )?) }; @@ -1179,7 +1287,7 @@ impl DefaultPhysicalPlanner { } // N Children - LogicalPlan::Union(_) => Arc::new(UnionExec::new(children.vec())), + LogicalPlan::Union(_) => UnionExec::try_new(children.vec())?, LogicalPlan::Extension(Extension { node }) => { let mut maybe_plan = None; let children = children.vec(); @@ -1292,6 +1400,9 @@ impl DefaultPhysicalPlanner { physical_name(expr), ))?])), } + } else if group_expr.is_empty() { + // No GROUP BY clause - create empty PhysicalGroupBy + Ok(PhysicalGroupBy::new(vec![], vec![], vec![])) } else { Ok(PhysicalGroupBy::new_single( group_expr @@ -1471,6 +1582,64 @@ fn get_null_physical_expr_pair( Ok((Arc::new(null_value), physical_name)) } +/// Qualifies the fields in a join schema with "left" and "right" qualifiers +/// without mutating the original schema. This function should only be used when +/// the join inputs have already been requalified earlier in `try_new_with_project_input`. +/// +/// The purpose is to avoid ambiguity errors later in planning (e.g., in nullability or data type resolution) +/// when converting expressions to fields. +fn qualify_join_schema_sides( + join_schema: &DFSchema, + left: &LogicalPlan, + right: &LogicalPlan, +) -> Result { + let left_fields = left.schema().fields(); + let right_fields = right.schema().fields(); + let join_fields = join_schema.fields(); + + // Validate lengths + if join_fields.len() != left_fields.len() + right_fields.len() { + return internal_err!( + "Join schema field count must match left and right field count." + ); + } + + // Validate field names match + for (i, (field, expected)) in join_fields + .iter() + .zip(left_fields.iter().chain(right_fields.iter())) + .enumerate() + { + if field.name() != expected.name() { + return internal_err!( + "Field name mismatch at index {}: expected '{}', found '{}'", + i, + expected.name(), + field.name() + ); + } + } + + // qualify sides + let qualifiers = join_fields + .iter() + .enumerate() + .map(|(i, _)| { + if i < left_fields.len() { + Some(TableReference::Bare { + table: Arc::from("left"), + }) + } else { + Some(TableReference::Bare { + table: Arc::from("right"), + }) + } + }) + .collect(); + + join_schema.with_field_specific_qualified_schema(qualifiers) +} + fn get_physical_expr_pair( expr: &Expr, input_dfschema: &DFSchema, @@ -1510,19 +1679,22 @@ pub fn create_window_expr_with_name( execution_props: &ExecutionProps, ) -> Result> { let name = name.into(); - let physical_schema: &Schema = &logical_schema.into(); + let physical_schema = Arc::clone(logical_schema.inner()); match e { - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + distinct, + filter, + }, + } = window_fun.as_ref(); let physical_args = create_physical_exprs(args, logical_schema, execution_props)?; let partition_by = @@ -1540,15 +1712,22 @@ pub fn create_window_expr_with_name( let window_frame = Arc::new(window_frame.clone()); let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; + let physical_filter = filter + .as_ref() + .map(|f| create_physical_expr(f, logical_schema, execution_props)) + .transpose()?; + windows::create_window_expr( fun, name, &physical_args, &partition_by, - order_by.as_ref(), + &order_by, window_frame, physical_schema, ignore_nulls, + *distinct, + physical_filter, ) } other => plan_err!("Invalid window expression '{other:?}'"), @@ -1573,8 +1752,8 @@ type AggregateExprWithOptionalArgs = ( Arc, // The filter clause, if any Option>, - // Ordering requirements, if any - Option, + // Expressions in the ORDER BY clause + Vec, ); /// Create an aggregate expression with a name from a logical expression @@ -1618,22 +1797,16 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; - let (agg_expr, filter, order_by) = { - let physical_sort_exprs = match order_by { - Some(exprs) => Some(create_physical_sort_exprs( - exprs, - logical_input_schema, - execution_props, - )?), - None => None, - }; - - let ordering_reqs: LexOrdering = - physical_sort_exprs.clone().unwrap_or_default(); + let (agg_expr, filter, order_bys) = { + let order_bys = create_physical_sort_exprs( + order_by, + logical_input_schema, + execution_props, + )?; let agg_expr = AggregateExprBuilder::new(func.to_owned(), physical_args.to_vec()) - .order_by(ordering_reqs) + .order_by(order_bys.clone()) .schema(Arc::new(physical_input_schema.to_owned())) .alias(name) .human_display(human_displan) @@ -1642,10 +1815,10 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( .build() .map(Arc::new)?; - (agg_expr, filter, physical_sort_exprs) + (agg_expr, filter, order_bys) }; - Ok((agg_expr, filter, order_by)) + Ok((agg_expr, filter, order_bys)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), } @@ -1658,21 +1831,24 @@ pub fn create_aggregate_expr_and_maybe_filter( physical_input_schema: &Schema, execution_props: &ExecutionProps, ) -> Result { - // unpack (nested) aliased logical expressions, e.g. "sum(col) as total" + // Unpack (potentially nested) aliased logical expressions, e.g. "sum(col) as total" + // Some functions like `count_all()` create internal aliases, + // Unwrap all alias layers to get to the underlying aggregate function let (name, human_display, e) = match e { - Expr::Alias(Alias { expr, name, .. }) => { - (Some(name.clone()), String::default(), expr.as_ref()) + Expr::Alias(Alias { name, .. }) => { + let unaliased = e.clone().unalias_nested().data; + (Some(name.clone()), e.human_display().to_string(), unaliased) } Expr::AggregateFunction(_) => ( Some(e.schema_name().to_string()), e.human_display().to_string(), - e, + e.clone(), ), - _ => (None, String::default(), e), + _ => (None, String::default(), e.clone()), }; create_aggregate_expr_with_name_and_maybe_filter( - e, + &e, name, human_display, logical_input_schema, @@ -1681,14 +1857,6 @@ pub fn create_aggregate_expr_and_maybe_filter( ) } -#[deprecated( - since = "47.0.0", - note = "use datafusion::{create_physical_sort_expr, create_physical_sort_exprs}" -)] -pub use datafusion_physical_expr::{ - create_physical_sort_expr, create_physical_sort_exprs, -}; - impl DefaultPhysicalPlanner { /// Handles capturing the various plans for EXPLAIN queries /// @@ -1720,6 +1888,14 @@ impl DefaultPhysicalPlanner { let config = &session_state.config_options().explain; let explain_format = &e.explain_format; + if !e.logical_optimization_succeeded { + return Ok(Arc::new(ExplainExec::new( + Arc::clone(e.schema.inner()), + e.stringified_plans.clone(), + true, + ))); + } + match explain_format { ExplainFormat::Indent => { /* fall through */ } ExplainFormat::Tree => { @@ -1737,6 +1913,7 @@ impl DefaultPhysicalPlanner { stringified_plans.push(StringifiedPlan::new( FinalPhysicalPlan, displayable(optimized_plan.as_ref()) + .set_tree_maximum_render_width(config.tree_maximum_render_width) .tree_render() .to_string(), )); @@ -1894,7 +2071,7 @@ impl DefaultPhysicalPlanner { session_state: &SessionState, ) -> Result> { let input = self.create_physical_plan(&a.input, session_state).await?; - let schema = SchemaRef::new((*a.schema).clone().into()); + let schema = Arc::clone(a.schema.inner()); let show_statistics = session_state.config_options().explain.show_statistics; Ok(Arc::new(AnalyzeExec::new( a.verbose, @@ -1920,7 +2097,7 @@ impl DefaultPhysicalPlanner { "Input physical plan:\n{}\n", displayable(plan.as_ref()).indent(false) ); - trace!( + debug!( "Detailed input physical plan:\n{}", displayable(plan.as_ref()).indent(true) ); @@ -1942,7 +2119,7 @@ impl DefaultPhysicalPlanner { OptimizationInvariantChecker::new(optimizer) .check(&new_plan, before_schema)?; - trace!( + debug!( "Optimized physical plan by {}:\n{}\n", optimizer.name(), displayable(new_plan.as_ref()).indent(false) @@ -1958,7 +2135,15 @@ impl DefaultPhysicalPlanner { "Optimized physical plan:\n{}\n", displayable(new_plan.as_ref()).indent(false) ); - trace!("Detailed optimized physical plan:\n{:?}", new_plan); + + // Don't print new_plan directly, as that may overflow the stack. + // For example: + // thread 'tokio-runtime-worker' has overflowed its stack + // fatal runtime error: stack overflow, aborting + debug!( + "Detailed optimized physical plan:\n{}\n", + displayable(new_plan.as_ref()).indent(true) + ); Ok(new_plan) } @@ -1976,7 +2161,7 @@ impl DefaultPhysicalPlanner { // "System supplied type" --> Use debug format of the datatype let data_type = field.data_type(); - data_types.append_value(format!("{data_type:?}")); + data_types.append_value(format!("{data_type}")); // "YES if the column is possibly nullable, NO if it is known not nullable. " let nullable_str = if field.is_nullable() { "YES" } else { "NO" }; @@ -2006,7 +2191,8 @@ impl DefaultPhysicalPlanner { input: &Arc, expr: &[Expr], ) -> Result> { - let input_schema = input.as_ref().schema(); + let input_logical_schema = input.as_ref().schema(); + let input_physical_schema = input_exec.schema(); let physical_exprs = expr .iter() .map(|e| { @@ -2025,7 +2211,7 @@ impl DefaultPhysicalPlanner { // This depends on the invariant that logical schema field index MUST match // with physical schema field index. let physical_name = if let Expr::Column(col) = e { - match input_schema.index_of_column(col) { + match input_logical_schema.index_of_column(col) { Ok(idx) => { // index physical field using logical field index Ok(input_exec.schema().field(idx).name().to_string()) @@ -2038,20 +2224,108 @@ impl DefaultPhysicalPlanner { physical_name(e) }; - tuple_err(( - self.create_physical_expr(e, input_schema, session_state), - physical_name, - )) + let physical_expr = + self.create_physical_expr(e, input_logical_schema, session_state); + + tuple_err((physical_expr, physical_name)) }) .collect::>>()?; - Ok(Arc::new(ProjectionExec::try_new( - physical_exprs, - input_exec, - )?)) + let num_input_columns = input_exec.schema().fields().len(); + + match self.try_plan_async_exprs( + num_input_columns, + PlannedExprResult::ExprWithName(physical_exprs), + input_physical_schema.as_ref(), + )? { + PlanAsyncExpr::Sync(PlannedExprResult::ExprWithName(physical_exprs)) => { + let proj_exprs: Vec = physical_exprs + .into_iter() + .map(|(expr, alias)| ProjectionExpr { expr, alias }) + .collect(); + Ok(Arc::new(ProjectionExec::try_new(proj_exprs, input_exec)?)) + } + PlanAsyncExpr::Async( + async_map, + PlannedExprResult::ExprWithName(physical_exprs), + ) => { + let async_exec = + AsyncFuncExec::try_new(async_map.async_exprs, input_exec)?; + let proj_exprs: Vec = physical_exprs + .into_iter() + .map(|(expr, alias)| ProjectionExpr { expr, alias }) + .collect(); + let new_proj_exec = + ProjectionExec::try_new(proj_exprs, Arc::new(async_exec))?; + Ok(Arc::new(new_proj_exec)) + } + _ => internal_err!("Unexpected PlanAsyncExpressions variant"), + } + } + + fn try_plan_async_exprs( + &self, + num_input_columns: usize, + physical_expr: PlannedExprResult, + schema: &Schema, + ) -> Result { + let mut async_map = AsyncMapper::new(num_input_columns); + match &physical_expr { + PlannedExprResult::ExprWithName(exprs) => { + exprs + .iter() + .try_for_each(|(expr, _)| async_map.find_references(expr, schema))?; + } + PlannedExprResult::Expr(exprs) => { + exprs + .iter() + .try_for_each(|expr| async_map.find_references(expr, schema))?; + } + } + + if async_map.is_empty() { + return Ok(PlanAsyncExpr::Sync(physical_expr)); + } + + let new_exprs = match physical_expr { + PlannedExprResult::ExprWithName(exprs) => PlannedExprResult::ExprWithName( + exprs + .iter() + .map(|(expr, column_name)| { + let new_expr = Arc::clone(expr) + .transform_up(|e| Ok(async_map.map_expr(e)))?; + Ok((new_expr.data, column_name.to_string())) + }) + .collect::>()?, + ), + PlannedExprResult::Expr(exprs) => PlannedExprResult::Expr( + exprs + .iter() + .map(|expr| { + let new_expr = Arc::clone(expr) + .transform_up(|e| Ok(async_map.map_expr(e)))?; + Ok(new_expr.data) + }) + .collect::>()?, + ), + }; + // rewrite the projection's expressions in terms of the columns with the result of async evaluation + Ok(PlanAsyncExpr::Async(async_map, new_exprs)) } } +#[derive(Debug)] +enum PlannedExprResult { + ExprWithName(Vec<(Arc, String)>), + Expr(Vec>), +} + +#[derive(Debug)] +enum PlanAsyncExpr { + Sync(PlannedExprResult), + Async(AsyncMapper, PlannedExprResult), +} + fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { match value { (Ok(e), Ok(e1)) => Ok((e, e1)), @@ -2103,7 +2377,7 @@ impl<'n> TreeNodeVisitor<'n> for OptimizationInvariantChecker<'_> { fn f_down(&mut self, node: &'n Self::Node) -> Result { // Checks for the more permissive `InvariantLevel::Always`. - // Plans are not guarenteed to be executable after each physical optimizer run. + // Plans are not guaranteed to be executable after each physical optimizer run. node.check_invariants(InvariantLevel::Always).map_err(|e| e.context(format!("Invariant for ExecutionPlan node '{}' failed for PhysicalOptimizer rule '{}'", node.name(), self.rule.name())) )?; @@ -2158,11 +2432,16 @@ mod tests { use crate::execution::session_state::SessionStateBuilder; use arrow::array::{ArrayRef, DictionaryArray, Int32Array}; use arrow::datatypes::{DataType, Field, Int32Type}; + use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; - use datafusion_common::{assert_contains, DFSchemaRef, TableReference}; + use datafusion_common::{ + assert_contains, DFSchemaRef, TableReference, ToDFSchema as _, + }; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; + use datafusion_expr::builder::subquery_alias; use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore}; + use datafusion_functions_aggregate::count::count_all; use datafusion_functions_aggregate::expr_fn::sum; use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -2205,8 +2484,9 @@ mod tests { // verify that the plan correctly casts u8 to i64 // the cast from u8 to i64 for literal will be simplified, and get lit(int64(5)) // the cast here is implicit so has CastOptions with safe=true - let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }"; - assert!(format!("{exec_plan:?}").contains(expected)); + let expected = r#"BinaryExpr { left: Column { name: "c7", index: 2 }, op: Lt, right: Literal { value: Int64(5), field: Field { name: "lit", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#; + + assert_contains!(format!("{exec_plan:?}"), expected); Ok(()) } @@ -2230,9 +2510,121 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; - - assert_eq!(format!("{cube:?}"), expected); + insta::assert_debug_snapshot!(cube, @r#" + Ok( + PhysicalGroupBy { + expr: [ + ( + Column { + name: "c1", + index: 0, + }, + "c1", + ), + ( + Column { + name: "c2", + index: 1, + }, + "c2", + ), + ( + Column { + name: "c3", + index: 2, + }, + "c3", + ), + ], + null_expr: [ + ( + Literal { + value: Utf8(NULL), + field: Field { + name: "lit", + data_type: Utf8, + nullable: true, + dict_id: 0, + dict_is_ordered: false, + metadata: {}, + }, + }, + "c1", + ), + ( + Literal { + value: Int64(NULL), + field: Field { + name: "lit", + data_type: Int64, + nullable: true, + dict_id: 0, + dict_is_ordered: false, + metadata: {}, + }, + }, + "c2", + ), + ( + Literal { + value: Int64(NULL), + field: Field { + name: "lit", + data_type: Int64, + nullable: true, + dict_id: 0, + dict_is_ordered: false, + metadata: {}, + }, + }, + "c3", + ), + ], + groups: [ + [ + false, + false, + false, + ], + [ + true, + false, + false, + ], + [ + false, + true, + false, + ], + [ + false, + false, + true, + ], + [ + true, + true, + false, + ], + [ + true, + false, + true, + ], + [ + false, + true, + true, + ], + [ + true, + true, + true, + ], + ], + }, + ) + "#); Ok(()) } @@ -2257,9 +2649,101 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; - - assert_eq!(format!("{rollup:?}"), expected); + insta::assert_debug_snapshot!(rollup, @r#" + Ok( + PhysicalGroupBy { + expr: [ + ( + Column { + name: "c1", + index: 0, + }, + "c1", + ), + ( + Column { + name: "c2", + index: 1, + }, + "c2", + ), + ( + Column { + name: "c3", + index: 2, + }, + "c3", + ), + ], + null_expr: [ + ( + Literal { + value: Utf8(NULL), + field: Field { + name: "lit", + data_type: Utf8, + nullable: true, + dict_id: 0, + dict_is_ordered: false, + metadata: {}, + }, + }, + "c1", + ), + ( + Literal { + value: Int64(NULL), + field: Field { + name: "lit", + data_type: Int64, + nullable: true, + dict_id: 0, + dict_is_ordered: false, + metadata: {}, + }, + }, + "c2", + ), + ( + Literal { + value: Int64(NULL), + field: Field { + name: "lit", + data_type: Int64, + nullable: true, + dict_id: 0, + dict_is_ordered: false, + metadata: {}, + }, + }, + "c3", + ), + ], + groups: [ + [ + true, + true, + true, + ], + [ + false, + true, + true, + ], + [ + false, + false, + true, + ], + [ + false, + false, + false, + ], + ], + }, + ) + "#); Ok(()) } @@ -2397,35 +2881,13 @@ mod tests { let logical_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoOpExtensionNode::default()), }); - let plan = planner + let e = planner .create_physical_plan(&logical_plan, &session_state) - .await; + .await + .expect_err("planning error") + .strip_backtrace(); - let expected_error: &str = "Error during planning: \ - Extension planner for NoOp created an ExecutionPlan with mismatched schema. \ - LogicalPlan schema: \ - DFSchema { inner: Schema { fields: \ - [Field { name: \"a\", \ - data_type: Int32, \ - nullable: false, \ - dict_id: 0, \ - dict_is_ordered: false, metadata: {} }], \ - metadata: {} }, field_qualifiers: [None], \ - functional_dependencies: FunctionalDependencies { deps: [] } }, \ - ExecutionPlan schema: Schema { fields: \ - [Field { name: \"b\", \ - data_type: Int32, \ - nullable: false, \ - dict_id: 0, \ - dict_is_ordered: false, metadata: {} }], \ - metadata: {} }"; - match plan { - Ok(_) => panic!("Expected planning failure"), - Err(e) => assert!( - e.to_string().contains(expected_error), - "Error '{e}' did not contain expected error '{expected_error}'" - ), - } + insta::assert_snapshot!(e, @r#"Error during planning: Extension planner for NoOp created an ExecutionPlan with mismatched schema. LogicalPlan schema: DFSchema { inner: Schema { fields: [Field { name: "a", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} }, field_qualifiers: [None], functional_dependencies: FunctionalDependencies { deps: [] } }, ExecutionPlan schema: Schema { fields: [Field { name: "b", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} }"#); } #[tokio::test] @@ -2441,10 +2903,9 @@ mod tests { let execution_plan = plan(&logical_plan).await?; // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated. - let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }"; + let expected = "expr: [ProjectionExpr { expr: BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }, fail_on_overflow: false }"; - let actual = format!("{execution_plan:?}"); - assert!(actual.contains(expected), "{}", actual); + assert_contains!(format!("{execution_plan:?}"), expected); Ok(()) } @@ -2464,7 +2925,7 @@ mod tests { assert_contains!( &e, - r#"Error during planning: Can not find compatible types to compare Boolean with [Struct([Field { name: "foo", data_type: Boolean, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }]), Utf8]"# + r#"Error during planning: Can not find compatible types to compare Boolean with [Struct(foo Boolean), Utf8]"# ); Ok(()) @@ -2621,6 +3082,25 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_count_all_with_alias() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::UInt32, false), + ])); + + let logical_plan = scan_empty(None, schema.as_ref(), None)? + .aggregate(Vec::::new(), vec![count_all().alias("total_rows")])? + .build()?; + + let physical_plan = plan(&logical_plan).await?; + assert_eq!( + "total_rows", + physical_plan.schema().field(0).name().as_str() + ); + Ok(()) + } + #[tokio::test] async fn test_explain() { let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); @@ -2656,6 +3136,54 @@ mod tests { } } + #[tokio::test] + async fn test_explain_indent_err() { + let planner = DefaultPhysicalPlanner::default(); + let ctx = SessionContext::new(); + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + let plan = Arc::new( + scan_empty(Some("employee"), &schema, None) + .unwrap() + .explain(true, false) + .unwrap() + .build() + .unwrap(), + ); + + // Create a schema + let schema = Arc::new(Schema::new(vec![ + Field::new("plan_type", DataType::Utf8, false), + Field::new("plan", DataType::Utf8, false), + ])); + + // Create invalid indentation in the plan + let stringified_plans = + vec![StringifiedPlan::new(PlanType::FinalLogicalPlan, "Test Err")]; + + let explain = Explain { + verbose: false, + explain_format: ExplainFormat::Indent, + plan, + stringified_plans, + schema: schema.to_dfschema_ref().unwrap(), + logical_optimization_succeeded: false, + }; + let plan = planner + .handle_explain(&explain, &ctx.state()) + .await + .unwrap(); + if let Some(plan) = plan.as_any().downcast_ref::() { + let stringified_plans = plan.stringified_plans(); + assert_eq!(stringified_plans.len(), 1); + assert_eq!(stringified_plans[0].plan.as_str(), "Test Err"); + } else { + panic!( + "Plan was not an explain plan: {}", + displayable(plan.as_ref()).indent(true) + ); + } + } + struct ErrorExtensionPlanner {} #[async_trait] @@ -3152,4 +3680,61 @@ digraph { Ok(()) } + + // Reproducer for DataFusion issue #17405: + // + // The following SQL is semantically invalid. Notably, the `SELECT left_table.a, right_table.a` + // clause is missing from the explicit logical plan: + // + // SELECT a FROM ( + // -- SELECT left_table.a, right_table.a + // FROM left_table + // FULL JOIN right_table ON left_table.a = right_table.a + // ) AS alias + // GROUP BY a; + // + // As a result, the variables within `alias` subquery are not properly distinguished, which + // leads to a bug for logical and physical planning. + // + // The fix is to implicitly insert a Projection node to represent the missing SELECT clause to + // ensure each field is correctly aliased to a unique name when the SubqueryAlias node is added. + #[tokio::test] + async fn subquery_alias_confusing_the_optimizer() -> Result<()> { + let state = make_session_state(); + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let schema = Arc::new(schema); + + let table = MemTable::try_new(schema.clone(), vec![vec![]])?; + let table = Arc::new(table); + + let source = DefaultTableSource::new(table); + let source = Arc::new(source); + + let left = LogicalPlanBuilder::scan("left", source.clone(), None)?; + let right = LogicalPlanBuilder::scan("right", source, None)?.build()?; + + let join_keys = ( + vec![datafusion_common::Column::new(Some("left"), "a")], + vec![datafusion_common::Column::new(Some("right"), "a")], + ); + + let join = left.join(right, JoinType::Full, join_keys, None)?.build()?; + + let alias = subquery_alias(join, "alias")?; + + let planner = DefaultPhysicalPlanner::default(); + + let logical_plan = LogicalPlanBuilder::new(alias) + .aggregate(vec![col("a:1")], Vec::::new())? + .build()?; + let _physical_plan = planner.create_physical_plan(&logical_plan, &state).await?; + + let optimized_logical_plan = state.optimize(&logical_plan)?; + let _optimized_physical_plan = planner + .create_physical_plan(&optimized_logical_plan, &state) + .await?; + + Ok(()) + } } diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index 9c9fcd04bf09a..d723620d32323 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -25,6 +25,7 @@ //! use datafusion::prelude::*; //! ``` +pub use crate::dataframe; pub use crate::dataframe::DataFrame; pub use crate::execution::context::{SQLOptions, SessionConfig, SessionContext}; pub use crate::execution::options::{ diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 8719a16f4919f..68f83e7f1f115 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -38,6 +38,7 @@ use crate::test_util::{aggr_test_schema, arrow_test_data}; use arrow::array::{self, Array, ArrayRef, Decimal128Builder, Int32Array}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; +#[cfg(feature = "compression")] use datafusion_common::DataFusionError; use datafusion_datasource::source::DataSourceExec; diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index e1328770cabdd..d31c2719973ec 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -17,18 +17,20 @@ //! Object store implementation used for testing -use crate::execution::context::SessionState; -use crate::execution::session_state::SessionStateBuilder; -use crate::prelude::SessionContext; -use futures::stream::BoxStream; -use futures::FutureExt; -use object_store::{ - memory::InMemory, path::Path, Error, GetOptions, GetResult, ListResult, - MultipartUpload, ObjectMeta, ObjectStore, PutMultipartOpts, PutOptions, PutPayload, - PutResult, +use crate::{ + execution::{context::SessionState, session_state::SessionStateBuilder}, + object_store::{ + memory::InMemory, path::Path, Error, GetOptions, GetResult, ListResult, + MultipartUpload, ObjectMeta, ObjectStore, PutMultipartOptions, PutOptions, + PutPayload, PutResult, + }, + prelude::SessionContext, +}; +use futures::{stream::BoxStream, FutureExt}; +use std::{ + fmt::{Debug, Display, Formatter}, + sync::Arc, }; -use std::fmt::{Debug, Display, Formatter}; -use std::sync::Arc; use tokio::{ sync::Barrier, time::{timeout, Duration}, @@ -66,7 +68,7 @@ pub fn local_unpartitioned_file(path: impl AsRef) -> ObjectMeta ObjectMeta { location, last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), - size: metadata.len() as usize, + size: metadata.len(), e_tag: None, version: None, } @@ -118,7 +120,7 @@ impl ObjectStore for BlockingObjectStore { async fn put_multipart_opts( &self, location: &Path, - opts: PutMultipartOpts, + opts: PutMultipartOptions, ) -> object_store::Result> { self.inner.put_multipart_opts(location, opts).await } @@ -148,7 +150,7 @@ impl ObjectStore for BlockingObjectStore { "{} barrier wait timed out for {location}", BlockingObjectStore::NAME ); - log::error!("{}", error_message); + log::error!("{error_message}"); return Err(Error::Generic { store: BlockingObjectStore::NAME, source: error_message.into(), @@ -166,7 +168,7 @@ impl ObjectStore for BlockingObjectStore { fn list( &self, prefix: Option<&Path>, - ) -> BoxStream<'_, object_store::Result> { + ) -> BoxStream<'static, object_store::Result> { self.inner.list(prefix) } diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index d6865ca3d532a..7149c5b0bd8ca 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -22,12 +22,14 @@ pub mod parquet; pub mod csv; +use futures::Stream; use std::any::Any; use std::collections::HashMap; use std::fs::File; use std::io::Write; use std::path::Path; use std::sync::Arc; +use std::task::{Context, Poll}; use crate::catalog::{TableProvider, TableProviderFactory}; use crate::dataframe::DataFrame; @@ -38,11 +40,13 @@ use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::ExecutionPlan; use crate::prelude::{CsvReadOptions, SessionContext}; +use crate::execution::SendableRecordBatchStream; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_catalog::Session; use datafusion_common::TableReference; use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; +use std::pin::Pin; use async_trait::async_trait; @@ -52,6 +56,8 @@ use tempfile::TempDir; pub use datafusion_common::test_util::parquet_test_data; pub use datafusion_common::test_util::{arrow_test_data, get_data_dir}; +use crate::execution::RecordBatchStream; + /// Scan an empty data source, mainly used in tests pub fn scan_empty( name: Option<&str>, @@ -129,6 +135,7 @@ pub async fn test_table() -> Result { } /// Execute SQL and return results +#[cfg(feature = "sql")] pub async fn plan_and_collect( ctx: &SessionContext, sql: &str, @@ -178,7 +185,7 @@ impl TableProviderFactory for TestTableFactory { ) -> Result> { Ok(Arc::new(TestTableProvider { url: cmd.location.to_string(), - schema: Arc::new(cmd.schema.as_ref().into()), + schema: Arc::clone(cmd.schema.inner()), })) } } @@ -234,3 +241,44 @@ pub fn register_unbounded_file_with_ordering( ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } + +/// Creates a bounded stream that emits the same record batch a specified number of times. +/// This is useful for testing purposes. +pub fn bounded_stream( + record_batch: RecordBatch, + limit: usize, +) -> SendableRecordBatchStream { + Box::pin(BoundedStream { + record_batch, + count: 0, + limit, + }) +} + +struct BoundedStream { + record_batch: RecordBatch, + count: usize, + limit: usize, +} + +impl Stream for BoundedStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if self.count >= self.limit { + Poll::Ready(None) + } else { + self.count += 1; + Poll::Ready(Some(Ok(self.record_batch.clone()))) + } + } +} + +impl RecordBatchStream for BoundedStream { + fn schema(&self) -> SchemaRef { + self.record_batch.schema() + } +} diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 084554eecbdb0..eb4c61c025248 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -37,6 +37,7 @@ use crate::physical_plan::metrics::MetricsSet; use crate::physical_plan::ExecutionPlan; use crate::prelude::{Expr, SessionConfig, SessionContext}; +use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; use object_store::path::Path; @@ -82,27 +83,26 @@ impl TestParquetFile { props: WriterProperties, batches: impl IntoIterator, ) -> Result { - let file = File::create(&path).unwrap(); + let file = File::create(&path)?; let mut batches = batches.into_iter(); let first_batch = batches.next().expect("need at least one record batch"); let schema = first_batch.schema(); - let mut writer = - ArrowWriter::try_new(file, Arc::clone(&schema), Some(props)).unwrap(); + let mut writer = ArrowWriter::try_new(file, Arc::clone(&schema), Some(props))?; - writer.write(&first_batch).unwrap(); + writer.write(&first_batch)?; let mut num_rows = first_batch.num_rows(); for batch in batches { - writer.write(&batch).unwrap(); + writer.write(&batch)?; num_rows += batch.num_rows(); } - writer.close().unwrap(); + writer.close()?; println!("Generated test dataset with {num_rows} rows"); - let size = std::fs::metadata(&path)?.len() as usize; + let size = std::fs::metadata(&path)?.len(); let mut canonical_path = path.canonicalize()?; @@ -182,10 +182,11 @@ impl TestParquetFile { let physical_filter_expr = create_physical_expr(&filter, &df_schema, &ExecutionProps::default())?; - let source = Arc::new(ParquetSource::new(parquet_options).with_predicate( - Arc::clone(&self.schema), - Arc::clone(&physical_filter_expr), - )); + let source = Arc::new( + ParquetSource::new(parquet_options) + .with_predicate(Arc::clone(&physical_filter_expr)), + ) + .with_schema(Arc::clone(&self.schema)); let config = scan_config_builder.with_source(source).build(); let parquet_exec = DataSourceExec::from_data_source(config); diff --git a/datafusion/core/tests/catalog/memory.rs b/datafusion/core/tests/catalog/memory.rs index b0753eb5c9494..ea9e71fc37467 100644 --- a/datafusion/core/tests/catalog/memory.rs +++ b/datafusion/core/tests/catalog/memory.rs @@ -47,6 +47,20 @@ fn memory_catalog_dereg_nonempty_schema() { assert!(cat.deregister_schema("foo", true).unwrap().is_some()); } +#[test] +fn memory_catalog_dereg_nonempty_schema_with_table_removal() { + let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; + + let schema = Arc::new(MemorySchemaProvider::new()) as Arc; + let test_table = + Arc::new(EmptyTable::new(Arc::new(Schema::empty()))) as Arc; + schema.register_table("t".into(), test_table).unwrap(); + + cat.register_schema("foo", schema.clone()).unwrap(); + schema.deregister_table("t").unwrap(); + assert!(cat.deregister_schema("foo", false).unwrap().is_some()); +} + #[test] fn memory_catalog_dereg_empty_schema() { let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; @@ -109,7 +123,7 @@ async fn test_mem_provider() { assert!(provider.table_exist(table_name)); let other_table = EmptyTable::new(Arc::new(Schema::empty())); let result = provider.register_table(table_name.to_string(), Arc::new(other_table)); - assert!(result.is_err()); + assert!(result.is_err(), "The table test_table_exist already exists"); } #[tokio::test] diff --git a/datafusion/core/tests/core_integration.rs b/datafusion/core/tests/core_integration.rs index 9bcb9e41f86a9..e37a368f07719 100644 --- a/datafusion/core/tests/core_integration.rs +++ b/datafusion/core/tests/core_integration.rs @@ -45,12 +45,18 @@ mod optimizer; /// Run all tests that are found in the `physical_optimizer` directory mod physical_optimizer; +/// Run all tests that are found in the `schema_adapter` directory +mod schema_adapter; + /// Run all tests that are found in the `serde` directory mod serde; /// Run all tests that are found in the `catalog` directory mod catalog; +/// Run all tests that are found in the `tracing` directory +mod tracing; + #[cfg(test)] #[ctor::ctor] fn init() { diff --git a/datafusion/core/tests/csv_schema_fix_test.rs b/datafusion/core/tests/csv_schema_fix_test.rs new file mode 100644 index 0000000000000..2e1daa113b096 --- /dev/null +++ b/datafusion/core/tests/csv_schema_fix_test.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Test for CSV schema inference with different column counts (GitHub issue #17516) + +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::test_util::batches_to_sort_string; +use insta::assert_snapshot; +use std::fs; +use tempfile::TempDir; + +#[tokio::test] +async fn test_csv_schema_inference_different_column_counts() -> Result<()> { + // Create temporary directory for test files + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let temp_path = temp_dir.path(); + + // Create CSV file 1 with 3 columns (simulating older railway services format) + let csv1_content = r#"service_id,route_type,agency_id +1,bus,agency1 +2,rail,agency2 +3,bus,agency3 +"#; + fs::write(temp_path.join("services_2024.csv"), csv1_content)?; + + // Create CSV file 2 with 6 columns (simulating newer railway services format) + let csv2_content = r#"service_id,route_type,agency_id,stop_platform_change,stop_planned_platform,stop_actual_platform +4,rail,agency2,true,Platform A,Platform B +5,bus,agency1,false,Stop 1,Stop 1 +6,rail,agency3,true,Platform C,Platform D +"#; + fs::write(temp_path.join("services_2025.csv"), csv2_content)?; + + // Create DataFusion context + let ctx = SessionContext::new(); + + // This should now work (previously would have failed with column count mismatch) + // Enable truncated_rows to handle files with different column counts + let df = ctx + .read_csv( + temp_path.to_str().unwrap(), + CsvReadOptions::new().truncated_rows(true), + ) + .await + .expect("Should successfully read CSV directory with different column counts"); + + // Verify the schema contains all 6 columns (union of both files) + let df_clone = df.clone(); + let schema = df_clone.schema(); + assert_eq!( + schema.fields().len(), + 6, + "Schema should contain all 6 columns" + ); + + // Check that we have all expected columns + let field_names: Vec<&str> = + schema.fields().iter().map(|f| f.name().as_str()).collect(); + assert!(field_names.contains(&"service_id")); + assert!(field_names.contains(&"route_type")); + assert!(field_names.contains(&"agency_id")); + assert!(field_names.contains(&"stop_platform_change")); + assert!(field_names.contains(&"stop_planned_platform")); + assert!(field_names.contains(&"stop_actual_platform")); + + // All fields should be nullable since they don't appear in all files + for field in schema.fields() { + assert!( + field.is_nullable(), + "Field {} should be nullable", + field.name() + ); + } + + // Verify we can actually read the data + let results = df.collect().await?; + + // Calculate total rows across all batches + let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!(total_rows, 6, "Should have 6 total rows across all batches"); + + // All batches should have 6 columns (the union schema) + for batch in &results { + assert_eq!(batch.num_columns(), 6, "All batches should have 6 columns"); + assert_eq!( + batch.schema().fields().len(), + 6, + "Each batch should use the union schema with 6 fields" + ); + } + + // Verify the actual content of the data using snapshot testing + assert_snapshot!(batches_to_sort_string(&results), @r" + +------------+------------+-----------+----------------------+-----------------------+----------------------+ + | service_id | route_type | agency_id | stop_platform_change | stop_planned_platform | stop_actual_platform | + +------------+------------+-----------+----------------------+-----------------------+----------------------+ + | 1 | bus | agency1 | | | | + | 2 | rail | agency2 | | | | + | 3 | bus | agency3 | | | | + | 4 | rail | agency2 | true | Platform A | Platform B | + | 5 | bus | agency1 | false | Stop 1 | Stop 1 | + | 6 | rail | agency3 | true | Platform C | Platform D | + +------------+------------+-----------+----------------------+-----------------------+----------------------+ + "); + + Ok(()) +} diff --git a/datafusion/core/tests/custom_sources_cases/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs index eb930b9a60bc9..cbdc4a448ea41 100644 --- a/datafusion/core/tests/custom_sources_cases/mod.rs +++ b/datafusion/core/tests/custom_sources_cases/mod.rs @@ -180,6 +180,13 @@ impl ExecutionPlan for CustomExecutionPlan { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } let batch = TEST_CUSTOM_RECORD_BATCH!().unwrap(); Ok(Statistics { num_rows: Precision::Exact(batch.num_rows()), diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index f68bcfaf15507..c80c0b4bf54ba 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -179,12 +179,12 @@ impl TableProvider for CustomProvider { match &filters[0] { Expr::BinaryExpr(BinaryExpr { right, .. }) => { let int_value = match &**right { - Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int64(Some(i))) => *i, + Expr::Literal(ScalarValue::Int8(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i, Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() { - Expr::Literal(lit_value) => match lit_value { + Expr::Literal(lit_value, _) => match lit_value { ScalarValue::Int8(Some(v)) => *v as i64, ScalarValue::Int16(Some(v)) => *v as i64, ScalarValue::Int32(Some(v)) => *v as i64, diff --git a/datafusion/core/tests/custom_sources_cases/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs index 66c886510e96b..403c04f1737e1 100644 --- a/datafusion/core/tests/custom_sources_cases/statistics.rs +++ b/datafusion/core/tests/custom_sources_cases/statistics.rs @@ -184,6 +184,14 @@ impl ExecutionPlan for StatisticsValidation { fn statistics(&self) -> Result { Ok(self.stats.clone()) } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + Ok(Statistics::new_unknown(&self.schema)) + } else { + Ok(self.stats.clone()) + } + } } fn init_ctx(stats: Statistics, schema: Schema) -> Result { @@ -232,7 +240,7 @@ async fn sql_basic() -> Result<()> { let physical_plan = df.create_physical_plan().await.unwrap(); // the statistics should be those of the source - assert_eq!(stats, physical_plan.statistics()?); + assert_eq!(stats, physical_plan.partition_statistics(None)?); Ok(()) } @@ -248,7 +256,7 @@ async fn sql_filter() -> Result<()> { .unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); - let stats = physical_plan.statistics()?; + let stats = physical_plan.partition_statistics(None)?; assert_eq!(stats.num_rows, Precision::Inexact(1)); Ok(()) @@ -257,20 +265,22 @@ async fn sql_filter() -> Result<()> { #[tokio::test] async fn sql_limit() -> Result<()> { let (stats, schema) = fully_defined(); - let col_stats = Statistics::unknown_column(&schema); let ctx = init_ctx(stats.clone(), schema)?; let df = ctx.sql("SELECT * FROM stats_table LIMIT 5").await.unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); - // when the limit is smaller than the original number of lines - // we loose all statistics except the for number of rows which becomes the limit + // when the limit is smaller than the original number of lines we mark the statistics as inexact assert_eq!( Statistics { num_rows: Precision::Exact(5), - column_statistics: col_stats, + column_statistics: stats + .column_statistics + .iter() + .map(|c| c.clone().to_inexact()) + .collect(), total_byte_size: Precision::Absent }, - physical_plan.statistics()? + physical_plan.partition_statistics(None)? ); let df = ctx @@ -279,7 +289,7 @@ async fn sql_limit() -> Result<()> { .unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); // when the limit is larger than the original number of lines, statistics remain unchanged - assert_eq!(stats, physical_plan.statistics()?); + assert_eq!(stats, physical_plan.partition_statistics(None)?); Ok(()) } @@ -296,7 +306,7 @@ async fn sql_window() -> Result<()> { let physical_plan = df.create_physical_plan().await.unwrap(); - let result = physical_plan.statistics()?; + let result = physical_plan.partition_statistics(None)?; assert_eq!(stats.num_rows, result.num_rows); let col_stats = result.column_statistics; diff --git a/datafusion/core/tests/data/empty_files/some_empty_with_header/a_empty.csv b/datafusion/core/tests/data/empty_files/some_empty_with_header/a_empty.csv new file mode 100644 index 0000000000000..f1968a0906d09 --- /dev/null +++ b/datafusion/core/tests/data/empty_files/some_empty_with_header/a_empty.csv @@ -0,0 +1 @@ +c1,c2,c3 diff --git a/datafusion/core/tests/data/empty_files/some_empty_with_header/b.csv b/datafusion/core/tests/data/empty_files/some_empty_with_header/b.csv new file mode 100644 index 0000000000000..ff596071444c3 --- /dev/null +++ b/datafusion/core/tests/data/empty_files/some_empty_with_header/b.csv @@ -0,0 +1,3 @@ +c1,c2,c3 +1,1,1 +2,2,2 diff --git a/datafusion/core/tests/data/empty_files/some_empty_with_header/c_nulls_column.csv b/datafusion/core/tests/data/empty_files/some_empty_with_header/c_nulls_column.csv new file mode 100644 index 0000000000000..bf86844cb0293 --- /dev/null +++ b/datafusion/core/tests/data/empty_files/some_empty_with_header/c_nulls_column.csv @@ -0,0 +1,2 @@ +c1,c2,c3 +3,3, diff --git a/datafusion/core/tests/data/filter_pushdown/single_file.gz.parquet b/datafusion/core/tests/data/filter_pushdown/single_file.gz.parquet new file mode 100644 index 0000000000000..ed700576a5afb Binary files /dev/null and b/datafusion/core/tests/data/filter_pushdown/single_file.gz.parquet differ diff --git a/datafusion/core/tests/data/filter_pushdown/single_file_small_pages.gz.parquet b/datafusion/core/tests/data/filter_pushdown/single_file_small_pages.gz.parquet new file mode 100644 index 0000000000000..29282cfbb6222 Binary files /dev/null and b/datafusion/core/tests/data/filter_pushdown/single_file_small_pages.gz.parquet differ diff --git a/datafusion/core/tests/data/int96_nested.parquet b/datafusion/core/tests/data/int96_nested.parquet new file mode 100644 index 0000000000000..708823ded6faf Binary files /dev/null and b/datafusion/core/tests/data/int96_nested.parquet differ diff --git a/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet new file mode 100644 index 0000000000000..ec164c6df7b5e Binary files /dev/null and b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet differ diff --git a/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet new file mode 100644 index 0000000000000..4b78cf963c111 Binary files /dev/null and b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet differ diff --git a/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet new file mode 100644 index 0000000000000..09a01771d503c Binary files /dev/null and b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet differ diff --git a/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet new file mode 100644 index 0000000000000..6398cc43a2f5d Binary files /dev/null and b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet differ diff --git a/datafusion/core/tests/data/tpch_customer_small.parquet b/datafusion/core/tests/data/tpch_customer_small.parquet new file mode 100644 index 0000000000000..3d5f73ef3a066 Binary files /dev/null and b/datafusion/core/tests/data/tpch_customer_small.parquet differ diff --git a/datafusion/core/tests/data/tpch_lineitem_small.parquet b/datafusion/core/tests/data/tpch_lineitem_small.parquet new file mode 100644 index 0000000000000..5e98706669d3b Binary files /dev/null and b/datafusion/core/tests/data/tpch_lineitem_small.parquet differ diff --git a/datafusion/core/tests/data/tpch_nation_small.parquet b/datafusion/core/tests/data/tpch_nation_small.parquet new file mode 100644 index 0000000000000..99da99594cf89 Binary files /dev/null and b/datafusion/core/tests/data/tpch_nation_small.parquet differ diff --git a/datafusion/core/tests/data/tpch_orders_small.parquet b/datafusion/core/tests/data/tpch_orders_small.parquet new file mode 100644 index 0000000000000..79e043137caf6 Binary files /dev/null and b/datafusion/core/tests/data/tpch_orders_small.parquet differ diff --git a/datafusion/core/tests/data/tpch_part_small.parquet b/datafusion/core/tests/data/tpch_part_small.parquet new file mode 100644 index 0000000000000..d8e1d7d680aa2 Binary files /dev/null and b/datafusion/core/tests/data/tpch_part_small.parquet differ diff --git a/datafusion/core/tests/data/tpch_partsupp_small.parquet b/datafusion/core/tests/data/tpch_partsupp_small.parquet new file mode 100644 index 0000000000000..711d58dda7493 Binary files /dev/null and b/datafusion/core/tests/data/tpch_partsupp_small.parquet differ diff --git a/datafusion/core/tests/data/tpch_region_small.parquet b/datafusion/core/tests/data/tpch_region_small.parquet new file mode 100644 index 0000000000000..5e00a1f6da1d9 Binary files /dev/null and b/datafusion/core/tests/data/tpch_region_small.parquet differ diff --git a/datafusion/core/tests/data/tpch_supplier_small.parquet b/datafusion/core/tests/data/tpch_supplier_small.parquet new file mode 100644 index 0000000000000..18323395fcbed Binary files /dev/null and b/datafusion/core/tests/data/tpch_supplier_small.parquet differ diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index c763d4c8de2d6..b664fccdfa800 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -384,7 +384,7 @@ async fn test_fn_approx_median() -> Result<()> { #[tokio::test] async fn test_fn_approx_percentile_cont() -> Result<()> { - let expr = approx_percentile_cont(col("b"), lit(0.5), None); + let expr = approx_percentile_cont(col("b").sort(true, false), lit(0.5), None); let df = create_test_table().await?; let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; @@ -392,11 +392,26 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { assert_snapshot!( batches_to_string(&batches), @r" - +---------------------------------------------+ - | approx_percentile_cont(test.b,Float64(0.5)) | - +---------------------------------------------+ - | 10 | - +---------------------------------------------+ + +---------------------------------------------------------------------------+ + | approx_percentile_cont(Float64(0.5)) WITHIN GROUP [test.b ASC NULLS LAST] | + +---------------------------------------------------------------------------+ + | 10 | + +---------------------------------------------------------------------------+ + "); + + let expr = approx_percentile_cont(col("b").sort(false, false), lit(0.1), None); + + let df = create_test_table().await?; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_snapshot!( + batches_to_string(&batches), + @r" + +----------------------------------------------------------------------------+ + | approx_percentile_cont(Float64(0.1)) WITHIN GROUP [test.b DESC NULLS LAST] | + +----------------------------------------------------------------------------+ + | 100 | + +----------------------------------------------------------------------------+ "); // the arg2 parameter is a complex expr, but it can be evaluated to the literal value @@ -405,23 +420,59 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { None::<&str>, "arg_2".to_string(), )); - let expr = approx_percentile_cont(col("b"), alias_expr, None); + let expr = approx_percentile_cont(col("b").sort(true, false), alias_expr, None); let df = create_test_table().await?; let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; assert_snapshot!( batches_to_string(&batches), @r" - +--------------------------------------+ - | approx_percentile_cont(test.b,arg_2) | - +--------------------------------------+ - | 10 | - +--------------------------------------+ + +--------------------------------------------------------------------+ + | approx_percentile_cont(arg_2) WITHIN GROUP [test.b ASC NULLS LAST] | + +--------------------------------------------------------------------+ + | 10 | + +--------------------------------------------------------------------+ + " + ); + + let alias_expr = Expr::Alias(Alias::new( + cast(lit(0.1), DataType::Float32), + None::<&str>, + "arg_2".to_string(), + )); + let expr = approx_percentile_cont(col("b").sort(false, false), alias_expr, None); + let df = create_test_table().await?; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_snapshot!( + batches_to_string(&batches), + @r" + +---------------------------------------------------------------------+ + | approx_percentile_cont(arg_2) WITHIN GROUP [test.b DESC NULLS LAST] | + +---------------------------------------------------------------------+ + | 100 | + +---------------------------------------------------------------------+ " ); // with number of centroids set - let expr = approx_percentile_cont(col("b"), lit(0.5), Some(lit(2))); + let expr = approx_percentile_cont(col("b").sort(true, false), lit(0.5), Some(lit(2))); + + let df = create_test_table().await?; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_snapshot!( + batches_to_string(&batches), + @r" + +------------------------------------------------------------------------------------+ + | approx_percentile_cont(Float64(0.5),Int32(2)) WITHIN GROUP [test.b ASC NULLS LAST] | + +------------------------------------------------------------------------------------+ + | 30 | + +------------------------------------------------------------------------------------+ + "); + + let expr = + approx_percentile_cont(col("b").sort(false, false), lit(0.1), Some(lit(2))); let df = create_test_table().await?; let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; @@ -429,11 +480,11 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { assert_snapshot!( batches_to_string(&batches), @r" - +------------------------------------------------------+ - | approx_percentile_cont(test.b,Float64(0.5),Int32(2)) | - +------------------------------------------------------+ - | 30 | - +------------------------------------------------------+ + +-------------------------------------------------------------------------------------+ + | approx_percentile_cont(Float64(0.1),Int32(2)) WITHIN GROUP [test.b DESC NULLS LAST] | + +-------------------------------------------------------------------------------------+ + | 69 | + +-------------------------------------------------------------------------------------+ "); Ok(()) @@ -1164,7 +1215,7 @@ async fn test_fn_decode() -> Result<()> { // Note that the decode function returns binary, and the default display of // binary is "hexadecimal" and therefore the output looks like decode did // nothing. So compare to a constant. - let df_schema = DFSchema::try_from(test_schema().as_ref().clone())?; + let df_schema = DFSchema::try_from(test_schema())?; let expr = decode(encode(col("a"), lit("hex")), lit("hex")) // need to cast to utf8 otherwise the default display of binary array is hex // so it looks like nothing is done @@ -1265,3 +1316,28 @@ async fn test_count_wildcard() -> Result<()> { Ok(()) } + +/// Call count wildcard with alias from dataframe API +#[tokio::test] +async fn test_count_wildcard_with_alias() -> Result<()> { + let df = create_test_table().await?; + let result_df = df.aggregate(vec![], vec![count_all().alias("total_count")])?; + + let schema = result_df.schema(); + assert_eq!(schema.fields().len(), 1); + assert_eq!(schema.field(0).name(), "total_count"); + assert_eq!(*schema.field(0).data_type(), DataType::Int64); + + let batches = result_df.collect().await?; + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 1); + + let count_array = batches[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(count_array.value(0), 4); + + Ok(()) +} diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 1855a512048d6..aa538f6dee813 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -27,20 +27,21 @@ use arrow::array::{ }; use arrow::buffer::ScalarBuffer; use arrow::datatypes::{ - DataType, Field, Float32Type, Int32Type, Schema, SchemaRef, UInt64Type, UnionFields, - UnionMode, + DataType, Field, Float32Type, Int32Type, Schema, UInt64Type, UnionFields, UnionMode, }; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_batches; +use arrow_schema::{SortOptions, TimeUnit}; +use datafusion::{assert_batches_eq, dataframe}; use datafusion_functions_aggregate::count::{count_all, count_all_window}; use datafusion_functions_aggregate::expr_fn::{ - array_agg, avg, count, count_distinct, max, median, min, sum, + array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum, + sum_distinct, }; use datafusion_functions_nested::make_array::make_array_udf; -use datafusion_functions_window::expr_fn::{first_value, row_number}; +use datafusion_functions_window::expr_fn::{first_value, lead, row_number}; use insta::assert_snapshot; use object_store::local::LocalFileSystem; -use sqlparser::ast::NullTreatment; use std::collections::HashMap; use std::fs; use std::sync::Arc; @@ -63,24 +64,36 @@ use datafusion::test_util::{ use datafusion_catalog::TableProvider; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; use datafusion_common::{ - assert_contains, Constraint, Constraints, DataFusionError, ParamValues, ScalarValue, - TableReference, UnnestOptions, + assert_contains, internal_datafusion_err, Constraint, Constraints, DFSchema, + DataFusionError, ParamValues, ScalarValue, TableReference, UnnestOptions, }; use datafusion_common_runtime::SpawnedTask; +use datafusion_datasource::file_format::format_as_file_type; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_expr::expr::{GroupingSet, Sort, WindowFunction}; +use datafusion_expr::expr::{ + FieldMetadata, GroupingSet, NullTreatment, Sort, WindowFunction, +}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ cast, col, create_udf, exists, in_subquery, lit, out_ref_col, placeholder, scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, - ScalarFunctionImplementation, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, + LogicalPlanBuilder, ScalarFunctionImplementation, SortExpr, WindowFrame, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::Partitioning; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_plan::{displayable, ExecutionPlanProperties}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion_physical_plan::empty::EmptyExec; +use datafusion_physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties}; + +use datafusion::error::Result as DataFusionResult; +use datafusion_functions_window::expr_fn::lag; // Get string representation of the plan async fn physical_plan_to_string(df: &DataFrame) -> String { @@ -90,8 +103,8 @@ async fn physical_plan_to_string(df: &DataFrame) -> String { .await .expect("Error creating physical plan"); - let formated = displayable(physical_plan.as_ref()).indent(true); - formated.to_string() + let formatted = displayable(physical_plan.as_ref()).indent(true); + formatted.to_string() } pub fn table_with_constraints() -> Arc { @@ -116,8 +129,7 @@ pub fn table_with_constraints() -> Arc { } async fn assert_logical_expr_schema_eq_physical_expr_schema(df: DataFrame) -> Result<()> { - let logical_expr_dfschema = df.schema(); - let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned()); + let logical_expr_schema = Arc::clone(df.schema().inner()); let batches = df.collect().await?; let physical_expr_schema = batches[0].schema(); assert_eq!(logical_expr_schema, physical_expr_schema); @@ -149,6 +161,46 @@ async fn test_array_agg_ord_schema() -> Result<()> { Ok(()) } +type WindowFnCase = (fn() -> Expr, &'static str); + +#[tokio::test] +async fn with_column_window_functions() -> DataFusionResult<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], + )?; + + let ctx = SessionContext::new(); + + let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + // Define test cases: (expr builder, alias name) + let test_cases: Vec = vec![ + (|| lag(col("a"), Some(1), None), "lag_val"), + (|| lead(col("a"), Some(1), None), "lead_val"), + (row_number, "row_num"), + ]; + + for (make_expr, alias) in test_cases { + let df = ctx.table("t").await?; + let expr = make_expr(); + let df_with = df.with_column(alias, expr)?; + let df_schema = df_with.schema().clone(); + + assert!( + df_schema.has_column_with_unqualified_name(alias), + "Schema does not contain expected column {alias}", + ); + + assert_eq!(2, df_schema.columns().len()); + } + + Ok(()) +} + #[tokio::test] async fn test_coalesce_schema() -> Result<()> { let ctx = SessionContext::new(); @@ -494,32 +546,35 @@ async fn drop_with_periods() -> Result<()> { #[tokio::test] async fn aggregate() -> Result<()> { // build plan using DataFrame API - let df = test_table().await?; + // union so some of the distincts have a clearly distinct result + let df = test_table().await?.union(test_table().await?)?; let group_expr = vec![col("c1")]; let aggr_expr = vec![ - min(col("c12")), - max(col("c12")), - avg(col("c12")), - sum(col("c12")), - count(col("c12")), - count_distinct(col("c12")), + min(col("c4")).alias("min(c4)"), + max(col("c4")).alias("max(c4)"), + avg(col("c4")).alias("avg(c4)"), + avg_distinct(col("c4")).alias("avg_distinct(c4)"), + sum(col("c4")).alias("sum(c4)"), + sum_distinct(col("c4")).alias("sum_distinct(c4)"), + count(col("c4")).alias("count(c4)"), + count_distinct(col("c4")).alias("count_distinct(c4)"), ]; let df: Vec = df.aggregate(group_expr, aggr_expr)?.collect().await?; assert_snapshot!( batches_to_sort_string(&df), - @r###" - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ - | c1 | min(aggregate_test_100.c12) | max(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) | - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ - | a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 | - | b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 | - | c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 | - | d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 | - | e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 | - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ - "### + @r" + +----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+ + | c1 | min(c4) | max(c4) | avg(c4) | avg_distinct(c4) | sum(c4) | sum_distinct(c4) | count(c4) | count_distinct(c4) | + +----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+ + | a | -28462 | 32064 | 306.04761904761904 | 306.04761904761904 | 12854 | 6427 | 42 | 21 | + | b | -28070 | 25286 | 7732.315789473684 | 7732.315789473684 | 293828 | 146914 | 38 | 19 | + | c | -30508 | 29106 | -1320.5238095238096 | -1320.5238095238096 | -55462 | -27731 | 42 | 21 | + | d | -24558 | 31106 | 10890.111111111111 | 10890.111111111111 | 392044 | 196022 | 36 | 18 | + | e | -31500 | 32514 | -4268.333333333333 | -4268.333333333333 | -179270 | -89635 | 42 | 21 | + +----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+ + " ); Ok(()) @@ -534,7 +589,9 @@ async fn aggregate_assert_no_empty_batches() -> Result<()> { min(col("c12")), max(col("c12")), avg(col("c12")), + avg_distinct(col("c12")), sum(col("c12")), + sum_distinct(col("c12")), count(col("c12")), count_distinct(col("c12")), median(col("c12")), @@ -610,12 +667,12 @@ async fn test_aggregate_with_pk2() -> Result<()> { let df = df.filter(predicate)?; assert_snapshot!( physical_plan_to_string(&df).await, - @r###" - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: id@0 = 1 AND name@1 = a - AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[] + @r" + AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[], ordering_mode=Sorted + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: id@0 = 1 AND name@1 = a DataSourceExec: partitions=1, partition_sizes=[1] - "### + " ); // Since id and name are functionally dependant, we can use name among expression @@ -659,12 +716,12 @@ async fn test_aggregate_with_pk3() -> Result<()> { let df = df.select(vec![col("id"), col("name")])?; assert_snapshot!( physical_plan_to_string(&df).await, - @r###" - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: id@0 = 1 - AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[] + @r" + AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[], ordering_mode=PartiallySorted([0]) + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: id@0 = 1 DataSourceExec: partitions=1, partition_sizes=[1] - "### + " ); // Since id and name are functionally dependant, we can use name among expression @@ -710,12 +767,12 @@ async fn test_aggregate_with_pk4() -> Result<()> { // columns are not used. assert_snapshot!( physical_plan_to_string(&df).await, - @r###" - CoalesceBatchesExec: target_batch_size=8192 - FilterExec: id@0 = 1 - AggregateExec: mode=Single, gby=[id@0 as id], aggr=[] + @r" + AggregateExec: mode=Single, gby=[id@0 as id], aggr=[], ordering_mode=Sorted + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: id@0 = 1 DataSourceExec: partitions=1, partition_sizes=[1] - "### + " ); let df_results = df.collect().await?; @@ -906,7 +963,7 @@ async fn window_using_aggregates() -> Result<()> { vec![col("c3")], ); - Expr::WindowFunction(w) + Expr::from(w) .null_treatment(NullTreatment::IgnoreNulls) .order_by(vec![col("c2").sort(true, true), col("c3").sort(true, true)]) .window_frame(WindowFrame::new_bounds( @@ -957,6 +1014,83 @@ async fn window_using_aggregates() -> Result<()> { Ok(()) } +#[tokio::test] +async fn window_aggregates_with_filter() -> Result<()> { + // Define a small in-memory table to make expected values clear + let ts: Int32Array = [1, 2, 3, 4, 5].into_iter().collect(); + let val: Int32Array = [-3, -2, 1, 4, -1].into_iter().collect(); + let batch = RecordBatch::try_from_iter(vec![ + ("ts", Arc::new(ts) as _), + ("val", Arc::new(val) as _), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + + let df = ctx.table("t").await?; + + // Build filtered window aggregates over ORDER BY ts ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + let mut exprs = vec![ + (datafusion_functions_aggregate::sum::sum_udaf(), "sum_pos"), + ( + datafusion_functions_aggregate::average::avg_udaf(), + "avg_pos", + ), + ( + datafusion_functions_aggregate::min_max::min_udaf(), + "min_pos", + ), + ( + datafusion_functions_aggregate::min_max::max_udaf(), + "max_pos", + ), + ( + datafusion_functions_aggregate::count::count_udaf(), + "cnt_pos", + ), + ] + .into_iter() + .map(|(func, alias)| { + let w = WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(func), + vec![col("val")], + ); + + Expr::from(w) + .order_by(vec![col("ts").sort(true, true)]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Rows, + WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + WindowFrameBound::CurrentRow, + )) + .filter(col("val").gt(lit(0))) + .build() + .unwrap() + .alias(alias) + }) + .collect::>(); + exprs.extend_from_slice(&[col("ts"), col("val")]); + + let results = df.select(exprs)?.collect().await?; + + assert_snapshot!( + batches_to_string(&results), + @r###" + +---------+---------+---------+---------+---------+----+-----+ + | sum_pos | avg_pos | min_pos | max_pos | cnt_pos | ts | val | + +---------+---------+---------+---------+---------+----+-----+ + | | | | | 0 | 1 | -3 | + | | | | | 0 | 2 | -2 | + | 1 | 1.0 | 1 | 1 | 1 | 3 | 1 | + | 5 | 2.5 | 1 | 4 | 2 | 4 | 4 | + | 5 | 2.5 | 1 | 4 | 2 | 5 | -1 | + +---------+---------+---------+---------+---------+----+-----+ + "### + ); + + Ok(()) +} + // Test issue: https://github.com/apache/datafusion/issues/10346 #[tokio::test] async fn test_select_over_aggregate_schema() -> Result<()> { @@ -1209,9 +1343,9 @@ async fn join_on_filter_datatype() -> Result<()> { let join = left.clone().join_on( right.clone(), JoinType::Inner, - Some(Expr::Literal(ScalarValue::Null)), + Some(Expr::Literal(ScalarValue::Null, None)), )?; - assert_snapshot!(join.into_optimized_plan().unwrap(), @"EmptyRelation"); + assert_snapshot!(join.into_optimized_plan().unwrap(), @"EmptyRelation: rows=0"); // JOIN ON expression must be boolean type let join = left.join_on(right, JoinType::Inner, Some(lit("TRUE")))?; @@ -1359,6 +1493,36 @@ async fn except() -> Result<()> { Ok(()) } +#[tokio::test] +async fn except_distinct() -> Result<()> { + let df = test_table().await?.select_columns(&["c1", "c3"])?; + let d2 = df.clone(); + let plan = df.except_distinct(d2)?; + let result = plan.logical_plan().clone(); + let expected = create_plan( + "SELECT c1, c3 FROM aggregate_test_100 + EXCEPT DISTINCT SELECT c1, c3 FROM aggregate_test_100", + ) + .await?; + assert_same_plan(&result, &expected); + Ok(()) +} + +#[tokio::test] +async fn intersect_distinct() -> Result<()> { + let df = test_table().await?.select_columns(&["c1", "c3"])?; + let d2 = df.clone(); + let plan = df.intersect_distinct(d2)?; + let result = plan.logical_plan().clone(); + let expected = create_plan( + "SELECT c1, c3 FROM aggregate_test_100 + INTERSECT DISTINCT SELECT c1, c3 FROM aggregate_test_100", + ) + .await?; + assert_same_plan(&result, &expected); + Ok(()) +} + #[tokio::test] async fn register_table() -> Result<()> { let df = test_table().await?.select_columns(&["c1", "c12"])?; @@ -1852,6 +2016,56 @@ async fn with_column_renamed_case_sensitive() -> Result<()> { Ok(()) } +#[tokio::test] +async fn describe_lookup_via_quoted_identifier() -> Result<()> { + let ctx = SessionContext::new(); + let name = "aggregate_test_100"; + register_aggregate_csv(&ctx, name).await?; + let df = ctx.table(name); + + let df = df + .await? + .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))? + .limit(0, Some(1))? + .sort(vec![ + // make the test deterministic + col("c1").sort(true, true), + col("c2").sort(true, true), + col("c3").sort(true, true), + ])? + .select_columns(&["c1"])?; + + let df_renamed = df.clone().with_column_renamed("c1", "CoLu.Mn[\"1\"]")?; + + let describe_result = df_renamed.describe().await?; + describe_result + .clone() + .sort(vec![ + col("describe").sort(true, true), + col("CoLu.Mn[\"1\"]").sort(true, true), + ])? + .show() + .await?; + assert_snapshot!( + batches_to_sort_string(&describe_result.clone().collect().await?), + @r###" + +------------+--------------+ + | describe | CoLu.Mn["1"] | + +------------+--------------+ + | count | 1 | + | max | a | + | mean | null | + | median | null | + | min | a | + | null_count | 0 | + | std | null | + +------------+--------------+ + "### + ); + + Ok(()) +} + #[tokio::test] async fn cast_expr_test() -> Result<()> { let df = test_table() @@ -2094,6 +2308,7 @@ async fn verify_join_output_partitioning() -> Result<()> { JoinType::LeftAnti, JoinType::RightAnti, JoinType::LeftMark, + JoinType::RightMark, ]; let default_partition_count = SessionConfig::new().target_partitions(); @@ -2127,7 +2342,8 @@ async fn verify_join_output_partitioning() -> Result<()> { JoinType::Inner | JoinType::Right | JoinType::RightSemi - | JoinType::RightAnti => { + | JoinType::RightAnti + | JoinType::RightMark => { let right_exprs: Vec> = vec![ Arc::new(Column::new_with_schema("c2_c1", &join_schema)?), Arc::new(Column::new_with_schema("c2_c2", &join_schema)?), @@ -2454,6 +2670,11 @@ async fn write_table_with_order() -> Result<()> { write_df = write_df .with_column_renamed("column1", "tablecol1") .unwrap(); + + // Ensure the column type matches the target table + write_df = + write_df.with_column("tablecol1", cast(col("tablecol1"), DataType::Utf8View))?; + let sql_str = "create external table data(tablecol1 varchar) stored as parquet location '" .to_owned() @@ -2525,7 +2746,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> { | | TableScan: t1 projection=[b] | | physical_plan | ProjectionExec: expr=[b@0 as b, count(*)@1 as count(*)] | | | SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] | - | | SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true] | + | | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] | | | ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*), count(Int64(1))@1 as count(Int64(1))] | | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] | | | CoalesceBatchesExec: target_batch_size=8192 | @@ -2655,23 +2876,20 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), - @r###" - +---------------+---------------------------------------------------------+ - | plan_type | plan | - +---------------+---------------------------------------------------------+ - | logical_plan | LeftSemi Join: | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __correlated_sq_1 | - | | Projection: | - | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] | - | | TableScan: t2 projection=[] | - | physical_plan | NestedLoopJoinExec: join_type=RightSemi | - | | ProjectionExec: expr=[] | - | | PlaceholderRowExec | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+---------------------------------------------------------+ - "### + @r" + +---------------+-----------------------------------------------------+ + | plan_type | plan | + +---------------+-----------------------------------------------------+ + | logical_plan | LeftSemi Join: | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __correlated_sq_1 | + | | EmptyRelation: rows=1 | + | physical_plan | NestedLoopJoinExec: join_type=RightSemi | + | | PlaceholderRowExec | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+-----------------------------------------------------+ + " ); let df_results = ctx @@ -2694,23 +2912,20 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { assert_snapshot!( pretty_format_batches(&df_results).unwrap(), - @r###" - +---------------+---------------------------------------------------------------------+ - | plan_type | plan | - +---------------+---------------------------------------------------------------------+ - | logical_plan | LeftSemi Join: | - | | TableScan: t1 projection=[a, b] | - | | SubqueryAlias: __correlated_sq_1 | - | | Projection: | - | | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] | - | | TableScan: t2 projection=[] | - | physical_plan | NestedLoopJoinExec: join_type=RightSemi | - | | ProjectionExec: expr=[] | - | | PlaceholderRowExec | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+---------------------------------------------------------------------+ - "### + @r" + +---------------+-----------------------------------------------------+ + | plan_type | plan | + +---------------+-----------------------------------------------------+ + | logical_plan | LeftSemi Join: | + | | TableScan: t1 projection=[a, b] | + | | SubqueryAlias: __correlated_sq_1 | + | | EmptyRelation: rows=1 | + | physical_plan | NestedLoopJoinExec: join_type=RightSemi | + | | PlaceholderRowExec | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+-----------------------------------------------------+ + " ); Ok(()) @@ -2729,20 +2944,20 @@ async fn test_count_wildcard_on_window() -> Result<()> { assert_snapshot!( pretty_format_batches(&sql_results).unwrap(), - @r###" - +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING | - | | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] | - | | TableScan: t1 projection=[a] | - | physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] | - | | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { name: "count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(6)), end_bound: Following(UInt32(2)), is_causal: false }], mode=[Sorted] | - | | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - "### + @r#" + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING | + | | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] | + | | TableScan: t1 projection=[a] | + | physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] | + | | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Field { name: "count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING], mode=[Sorted] | + | | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + "# ); let df_results = ctx @@ -2763,20 +2978,20 @@ async fn test_count_wildcard_on_window() -> Result<()> { assert_snapshot!( pretty_format_batches(&df_results).unwrap(), - @r###" - +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - | plan_type | plan | - +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - | logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING | - | | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] | - | | TableScan: t1 projection=[a] | - | physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] | - | | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { name: "count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(6)), end_bound: Following(UInt32(2)), is_causal: false }], mode=[Sorted] | - | | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] | - | | DataSourceExec: partitions=1, partition_sizes=[1] | - | | | - +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ - "### + @r#" + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | plan_type | plan | + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING | + | | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] | + | | TableScan: t1 projection=[a] | + | physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] | + | | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Field { name: "count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING], mode=[Sorted] | + | | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] | + | | DataSourceExec: partitions=1, partition_sizes=[1] | + | | | + +---------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + "# ); Ok(()) @@ -3570,16 +3785,15 @@ async fn unnest_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +----------+------------------------------------------------+--------------------+ - | shape_id | points | tags | - +----------+------------------------------------------------+--------------------+ - | 1 | [{x: -3, y: -4}, {x: -3, y: 6}, {x: 2, y: -2}] | [tag1] | - | 2 | | [tag1, tag2] | - | 3 | [{x: -9, y: 2}, {x: -10, y: -4}] | | - | 4 | [{x: -3, y: 5}, {x: 2, y: -1}] | [tag1, tag2, tag3] | - +----------+------------------------------------------------+--------------------+ - "### - ); + +----------+---------------------------------+--------------------------+ + | shape_id | points | tags | + +----------+---------------------------------+--------------------------+ + | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | [tag1] | + | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | [tag1] | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | [tag1, tag2, tag3, tag4] | + | 4 | | [tag1, tag2, tag3] | + +----------+---------------------------------+--------------------------+ + "###); // Unnest tags let df = table_with_nested_types(NUM_ROWS).await?; @@ -3587,19 +3801,20 @@ async fn unnest_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +----------+------------------------------------------------+------+ - | shape_id | points | tags | - +----------+------------------------------------------------+------+ - | 1 | [{x: -3, y: -4}, {x: -3, y: 6}, {x: 2, y: -2}] | tag1 | - | 2 | | tag1 | - | 2 | | tag2 | - | 3 | [{x: -9, y: 2}, {x: -10, y: -4}] | | - | 4 | [{x: -3, y: 5}, {x: 2, y: -1}] | tag1 | - | 4 | [{x: -3, y: 5}, {x: 2, y: -1}] | tag2 | - | 4 | [{x: -3, y: 5}, {x: 2, y: -1}] | tag3 | - +----------+------------------------------------------------+------+ - "### - ); + +----------+---------------------------------+------+ + | shape_id | points | tags | + +----------+---------------------------------+------+ + | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | tag1 | + | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | tag1 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag1 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag2 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag3 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag4 | + | 4 | | tag1 | + | 4 | | tag2 | + | 4 | | tag3 | + +----------+---------------------------------+------+ + "###); // Test aggregate results for tags. let df = table_with_nested_types(NUM_ROWS).await?; @@ -3612,20 +3827,18 @@ async fn unnest_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +----------+-----------------+--------------------+ - | shape_id | points | tags | - +----------+-----------------+--------------------+ - | 1 | {x: -3, y: -4} | [tag1] | - | 1 | {x: -3, y: 6} | [tag1] | - | 1 | {x: 2, y: -2} | [tag1] | - | 2 | | [tag1, tag2] | - | 3 | {x: -10, y: -4} | | - | 3 | {x: -9, y: 2} | | - | 4 | {x: -3, y: 5} | [tag1, tag2, tag3] | - | 4 | {x: 2, y: -1} | [tag1, tag2, tag3] | - +----------+-----------------+--------------------+ - "### - ); + +----------+----------------+--------------------------+ + | shape_id | points | tags | + +----------+----------------+--------------------------+ + | 1 | {x: -3, y: -4} | [tag1] | + | 1 | {x: 5, y: -8} | [tag1] | + | 2 | {x: -2, y: -8} | [tag1] | + | 2 | {x: 6, y: 2} | [tag1] | + | 3 | {x: -2, y: 5} | [tag1, tag2, tag3, tag4] | + | 3 | {x: -9, y: -7} | [tag1, tag2, tag3, tag4] | + | 4 | | [tag1, tag2, tag3] | + +----------+----------------+--------------------------+ + "###); // Test aggregate results for points. let df = table_with_nested_types(NUM_ROWS).await?; @@ -3642,25 +3855,26 @@ async fn unnest_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +----------+-----------------+------+ - | shape_id | points | tags | - +----------+-----------------+------+ - | 1 | {x: -3, y: -4} | tag1 | - | 1 | {x: -3, y: 6} | tag1 | - | 1 | {x: 2, y: -2} | tag1 | - | 2 | | tag1 | - | 2 | | tag2 | - | 3 | {x: -10, y: -4} | | - | 3 | {x: -9, y: 2} | | - | 4 | {x: -3, y: 5} | tag1 | - | 4 | {x: -3, y: 5} | tag2 | - | 4 | {x: -3, y: 5} | tag3 | - | 4 | {x: 2, y: -1} | tag1 | - | 4 | {x: 2, y: -1} | tag2 | - | 4 | {x: 2, y: -1} | tag3 | - +----------+-----------------+------+ - "### - ); + +----------+----------------+------+ + | shape_id | points | tags | + +----------+----------------+------+ + | 1 | {x: -3, y: -4} | tag1 | + | 1 | {x: 5, y: -8} | tag1 | + | 2 | {x: -2, y: -8} | tag1 | + | 2 | {x: 6, y: 2} | tag1 | + | 3 | {x: -2, y: 5} | tag1 | + | 3 | {x: -2, y: 5} | tag2 | + | 3 | {x: -2, y: 5} | tag3 | + | 3 | {x: -2, y: 5} | tag4 | + | 3 | {x: -9, y: -7} | tag1 | + | 3 | {x: -9, y: -7} | tag2 | + | 3 | {x: -9, y: -7} | tag3 | + | 3 | {x: -9, y: -7} | tag4 | + | 4 | | tag1 | + | 4 | | tag2 | + | 4 | | tag3 | + +----------+----------------+------+ + "###); // Test aggregate results for points and tags. let df = table_with_nested_types(NUM_ROWS).await?; @@ -3994,15 +4208,15 @@ async fn unnest_aggregate_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +--------------------+ - | tags | - +--------------------+ - | | - | [tag1, tag2, tag3] | - | [tag1, tag2, tag3] | - | [tag1, tag2] | - | [tag1] | - +--------------------+ + +--------------------------+ + | tags | + +--------------------------+ + | [tag1, tag2, tag3, tag4] | + | [tag1, tag2, tag3] | + | [tag1, tag2] | + | [tag1] | + | [tag1] | + +--------------------------+ "### ); @@ -4018,7 +4232,7 @@ async fn unnest_aggregate_columns() -> Result<()> { +-------------+ | count(tags) | +-------------+ - | 9 | + | 11 | +-------------+ "### ); @@ -4267,7 +4481,7 @@ async fn unnest_analyze_metrics() -> Result<()> { assert_contains!(&formatted, "elapsed_compute="); assert_contains!(&formatted, "input_batches=1"); assert_contains!(&formatted, "input_rows=5"); - assert_contains!(&formatted, "output_rows=10"); + assert_contains!(&formatted, "output_rows=11"); assert_contains!(&formatted, "output_batches=1"); Ok(()) @@ -4472,7 +4686,10 @@ async fn consecutive_projection_same_schema() -> Result<()> { // Add `t` column full of nulls let df = df - .with_column("t", cast(Expr::Literal(ScalarValue::Null), DataType::Int32)) + .with_column( + "t", + cast(Expr::Literal(ScalarValue::Null, None), DataType::Int32), + ) .unwrap(); df.clone().show().await.unwrap(); @@ -4614,7 +4831,7 @@ async fn table_with_nested_types(n: usize) -> Result { shape_id_builder.append_value(idx as u32 + 1); // Add a random number of points - let num_points: usize = rng.gen_range(0..4); + let num_points: usize = rng.random_range(0..4); if num_points > 0 { for _ in 0..num_points.max(2) { // Add x value @@ -4622,13 +4839,13 @@ async fn table_with_nested_types(n: usize) -> Result { .values() .field_builder::(0) .unwrap() - .append_value(rng.gen_range(-10..10)); + .append_value(rng.random_range(-10..10)); // Add y value points_builder .values() .field_builder::(1) .unwrap() - .append_value(rng.gen_range(-10..10)); + .append_value(rng.random_range(-10..10)); points_builder.values().append(true); } } @@ -4637,7 +4854,7 @@ async fn table_with_nested_types(n: usize) -> Result { points_builder.append(num_points > 0); // Append tags. - let num_tags: usize = rng.gen_range(0..5); + let num_tags: usize = rng.random_range(0..5); for id in 0..num_tags { tags_builder.values().append_value(format!("tag{}", id + 1)); } @@ -4791,7 +5008,7 @@ async fn use_var_provider() -> Result<()> { Field::new("bar", DataType::Int64, false), ])); - let mem_table = Arc::new(MemTable::try_new(schema, vec![])?); + let mem_table = Arc::new(MemTable::try_new(schema, vec![vec![]])?); let config = SessionConfig::new() .with_target_partitions(4) @@ -4849,11 +5066,11 @@ async fn test_dataframe_placeholder_missing_param_values() -> Result<()> { assert_snapshot!( actual, - @r###" + @r" Filter: a = $0 [a:Int32] Projection: Int32(1) AS a [a:Int32] - EmptyRelation [] - "### + EmptyRelation: rows=1 [] + " ); // Executing LogicalPlans with placeholders that don't have bound values @@ -4882,11 +5099,11 @@ async fn test_dataframe_placeholder_missing_param_values() -> Result<()> { assert_snapshot!( actual, - @r###" + @r" Filter: a = Int32(3) [a:Int32] Projection: Int32(1) AS a [a:Int32] - EmptyRelation [] - "### + EmptyRelation: rows=1 [] + " ); // N.B., the test is basically `SELECT 1 as a WHERE a = 3;` which returns no results. @@ -4913,10 +5130,10 @@ async fn test_dataframe_placeholder_column_parameter() -> Result<()> { assert_snapshot!( actual, - @r###" + @r" Projection: $1 [$1:Null;N] - EmptyRelation [] - "### + EmptyRelation: rows=1 [] + " ); // Executing LogicalPlans with placeholders that don't have bound values @@ -4943,10 +5160,10 @@ async fn test_dataframe_placeholder_column_parameter() -> Result<()> { assert_snapshot!( actual, - @r###" + @r" Projection: Int32(3) AS $1 [$1:Null;N] - EmptyRelation [] - "### + EmptyRelation: rows=1 [] + " ); assert_snapshot!( @@ -4982,11 +5199,11 @@ async fn test_dataframe_placeholder_like_expression() -> Result<()> { assert_snapshot!( actual, - @r###" + @r#" Filter: a LIKE $1 [a:Utf8] Projection: Utf8("foo") AS a [a:Utf8] - EmptyRelation [] - "### + EmptyRelation: rows=1 [] + "# ); // Executing LogicalPlans with placeholders that don't have bound values @@ -5015,11 +5232,11 @@ async fn test_dataframe_placeholder_like_expression() -> Result<()> { assert_snapshot!( actual, - @r###" + @r#" Filter: a LIKE Utf8("f%") [a:Utf8] Projection: Utf8("foo") AS a [a:Utf8] - EmptyRelation [] - "### + EmptyRelation: rows=1 [] + "# ); assert_snapshot!( @@ -5079,7 +5296,7 @@ async fn write_partitioned_parquet_results() -> Result<()> { .await?; // Explicitly read the parquet file at c2=123 to verify the physical files are partitioned - let partitioned_file = format!("{out_dir}/c2=123", out_dir = out_dir); + let partitioned_file = format!("{out_dir}/c2=123"); let filter_df = ctx .read_parquet(&partitioned_file, ParquetReadOptions::default()) .await?; @@ -5229,11 +5446,11 @@ async fn union_literal_is_null_and_not_null() -> Result<()> { for batch in batches { // Verify schema is the same for all batches if !schema.contains(&batch.schema()) { - return Err(DataFusionError::Internal(format!( + return Err(internal_datafusion_err!( "Schema mismatch. Previously had\n{:#?}\n\nGot:\n{:#?}", &schema, batch.schema() - ))); + )); } } @@ -5575,7 +5792,7 @@ async fn test_alias() -> Result<()> { .await? .select(vec![col("a"), col("test.b"), lit(1).alias("one")])? .alias("table_alias")?; - // All ouput column qualifiers are changed to "table_alias" + // All output column qualifiers are changed to "table_alias" df.schema().columns().iter().for_each(|c| { assert_eq!(c.relation, Some("table_alias".into())); }); @@ -5616,6 +5833,7 @@ async fn test_alias() -> Result<()> { async fn test_alias_with_metadata() -> Result<()> { let mut metadata = HashMap::new(); metadata.insert(String::from("k"), String::from("v")); + let metadata = FieldMetadata::from(metadata); let df = create_test_table("test") .await? .select(vec![col("a").alias_with_metadata("b", Some(metadata))])? @@ -6017,3 +6235,247 @@ async fn test_insert_into_casting_support() -> Result<()> { ); Ok(()) } + +#[tokio::test] +async fn test_dataframe_from_columns() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let b: ArrayRef = Arc::new(BooleanArray::from(vec![true, true, false])); + let c: ArrayRef = Arc::new(StringArray::from(vec![Some("foo"), Some("bar"), None])); + let df = DataFrame::from_columns(vec![("a", a), ("b", b), ("c", c)])?; + + assert_eq!(df.schema().fields().len(), 3); + assert_eq!(df.clone().count().await?, 3); + + let rows = df.sort(vec![col("a").sort(true, true)])?; + assert_batches_eq!( + &[ + "+---+-------+-----+", + "| a | b | c |", + "+---+-------+-----+", + "| 1 | true | foo |", + "| 2 | true | bar |", + "| 3 | false | |", + "+---+-------+-----+", + ], + &rows.collect().await? + ); + + Ok(()) +} + +#[tokio::test] +async fn test_dataframe_macro() -> Result<()> { + let df = dataframe!( + "a" => [1, 2, 3], + "b" => [true, true, false], + "c" => [Some("foo"), Some("bar"), None] + )?; + + assert_eq!(df.schema().fields().len(), 3); + assert_eq!(df.clone().count().await?, 3); + + let rows = df.sort(vec![col("a").sort(true, true)])?; + assert_batches_eq!( + &[ + "+---+-------+-----+", + "| a | b | c |", + "+---+-------+-----+", + "| 1 | true | foo |", + "| 2 | true | bar |", + "| 3 | false | |", + "+---+-------+-----+", + ], + &rows.collect().await? + ); + + let df_empty = dataframe!()?; + assert_eq!(df_empty.schema().fields().len(), 0); + assert_eq!(df_empty.count().await?, 0); + + Ok(()) +} + +#[tokio::test] +async fn test_copy_schema() -> Result<()> { + let tmp_dir = TempDir::new()?; + + let session_state = SessionStateBuilder::new_with_default_features().build(); + + let session_ctx = SessionContext::new_with_state(session_state); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); + + // Create and register the source table with the provided schema and data + let source_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?); + session_ctx.register_table("source_table", source_table.clone())?; + + let target_path = tmp_dir.path().join("target.csv"); + + let query = format!( + "COPY source_table TO '{:?}' STORED AS csv", + target_path.to_str().unwrap() + ); + + let result = session_ctx.sql(&query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) +} + +#[tokio::test] +async fn test_copy_to_preserves_order() -> Result<()> { + let tmp_dir = TempDir::new()?; + + let session_state = SessionStateBuilder::new_with_default_features().build(); + let session_ctx = SessionContext::new_with_state(session_state); + + let target_path = tmp_dir.path().join("target_ordered.csv"); + let csv_file_format = session_ctx + .state() + .get_file_format_factory("csv") + .map(format_as_file_type) + .unwrap(); + + let ordered_select_plan = LogicalPlanBuilder::values(vec![ + vec![lit(1u64)], + vec![lit(10u64)], + vec![lit(20u64)], + vec![lit(100u64)], + ])? + .sort(vec![SortExpr::new(col("column1"), false, true)])? + .build()?; + + let copy_to_plan = LogicalPlanBuilder::copy_to( + ordered_select_plan, + target_path.to_str().unwrap().to_string(), + csv_file_format, + HashMap::new(), + vec![], + )? + .build()?; + + let union_side_branch = LogicalPlanBuilder::values(vec![vec![lit(1u64)]])?.build()?; + let union_plan = LogicalPlanBuilder::from(copy_to_plan) + .union(union_side_branch)? + .build()?; + + let frame = session_ctx.execute_logical_plan(union_plan).await?; + let physical_plan = frame.create_physical_plan().await?; + + let physical_plan_format = + displayable(physical_plan.as_ref()).indent(true).to_string(); + + // Expect that input to the DataSinkExec is sorted correctly + assert_snapshot!( + physical_plan_format, + @r###" + UnionExec + DataSinkExec: sink=CsvSink(file_groups=[]) + SortExec: expr=[column1@0 DESC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[1] + DataSourceExec: partitions=1, partition_sizes=[1] + "### + ); + Ok(()) +} + +#[tokio::test] +async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> { + let ctx = SessionContext::new(); + + // Simple schema with just the fields we need + let file_schema = Arc::new(Schema::new(vec![ + Field::new( + "timestamp", + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), + true, + ), + Field::new("ticker", DataType::Utf8, true), + Field::new("value", DataType::Float64, true), + Field::new("date", DataType::Utf8, false), + ])); + + let df_schema = DFSchema::try_from(file_schema.clone())?; + + let timestamp = col("timestamp"); + let value = col("value"); + let ticker = col("ticker"); + let date = col("date"); + + let mock_exec = Arc::new(EmptyExec::new(file_schema.clone())); + + // Build first_value aggregate + let first_value = Arc::new( + AggregateExprBuilder::new( + datafusion_functions_aggregate::first_last::first_value_udaf(), + vec![ctx.create_physical_expr(value.clone(), &df_schema)?], + ) + .alias("first_value(value)") + .order_by(vec![PhysicalSortExpr::new( + ctx.create_physical_expr(timestamp.clone(), &df_schema)?, + SortOptions::new(false, false), + )]) + .schema(file_schema.clone()) + .build() + .expect("Failed to build first_value"), + ); + + // Build last_value aggregate + let last_value = Arc::new( + AggregateExprBuilder::new( + datafusion_functions_aggregate::first_last::last_value_udaf(), + vec![ctx.create_physical_expr(value.clone(), &df_schema)?], + ) + .alias("last_value(value)") + .order_by(vec![PhysicalSortExpr::new( + ctx.create_physical_expr(timestamp.clone(), &df_schema)?, + SortOptions::new(false, false), + )]) + .schema(file_schema.clone()) + .build() + .expect("Failed to build last_value"), + ); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![ + ( + ctx.create_physical_expr(date.clone(), &df_schema)?, + "date".to_string(), + ), + ( + ctx.create_physical_expr(ticker.clone(), &df_schema)?, + "ticker".to_string(), + ), + ]), + vec![first_value, last_value], + vec![None, None], + mock_exec, + file_schema, + ) + .expect("Failed to build partial agg"); + + // Assert that the schema field names match the expected names + let expected_field_names = vec![ + "date", + "ticker", + "first_value(value)[first_value]", + "timestamp@0", + "is_set", + "last_value(value)[last_value]", + "timestamp@0", + "is_set", + ]; + + let binding = partial_agg.schema(); + let actual_field_names: Vec<_> = binding.fields().iter().map(|f| f.name()).collect(); + assert_eq!(actual_field_names, expected_field_names); + + // Ensure that DFSchema::try_from does not fail + let partial_agg_exec_schema = DFSchema::try_from(partial_agg.schema()); + assert!( + partial_agg_exec_schema.is_ok(), + "Expected get AggregateExec schema to succeed with duplicate state fields" + ); + + Ok(()) +} diff --git a/datafusion/core/tests/execution/coop.rs b/datafusion/core/tests/execution/coop.rs new file mode 100644 index 0000000000000..b6f406e967509 --- /dev/null +++ b/datafusion/core/tests/execution/coop.rs @@ -0,0 +1,821 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Int64Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow_schema::SortOptions; +use datafusion::common::NullEquality; +use datafusion::functions_aggregate::sum; +use datafusion::physical_expr::aggregate::AggregateExprBuilder; +use datafusion::physical_plan; +use datafusion::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion::physical_plan::execution_plan::Boundedness; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use datafusion_common::{exec_datafusion_err, DataFusionError, JoinType, ScalarValue}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr_common::operator::Operator; +use datafusion_expr_common::operator::Operator::{Divide, Eq, Gt, Modulo}; +use datafusion_functions_aggregate::min_max; +use datafusion_physical_expr::expressions::{ + binary, col, lit, BinaryExpr, Column, Literal, +}; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_optimizer::ensure_coop::EnsureCooperative; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_physical_plan::coop::make_cooperative; +use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; +use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec}; +use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_physical_plan::union::InterleaveExec; +use futures::StreamExt; +use parking_lot::RwLock; +use rstest::rstest; +use std::any::Any; +use std::error::Error; +use std::fmt::Formatter; +use std::ops::Range; +use std::sync::Arc; +use std::task::Poll; +use std::time::Duration; +use tokio::runtime::{Handle, Runtime}; +use tokio::select; + +#[derive(Debug)] +struct RangeBatchGenerator { + schema: SchemaRef, + value_range: Range, + boundedness: Boundedness, + batch_size: usize, + poll_count: usize, +} + +impl std::fmt::Display for RangeBatchGenerator { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + // Display current counter + write!(f, "InfiniteGenerator(counter={})", self.poll_count) + } +} + +impl LazyBatchGenerator for RangeBatchGenerator { + fn as_any(&self) -> &dyn Any { + self + } + + fn boundedness(&self) -> Boundedness { + self.boundedness + } + + /// Generate the next RecordBatch. + fn generate_next_batch(&mut self) -> datafusion_common::Result> { + self.poll_count += 1; + + let mut builder = Int64Array::builder(self.batch_size); + for _ in 0..self.batch_size { + match self.value_range.next() { + None => break, + Some(v) => builder.append_value(v), + } + } + let array = builder.finish(); + + if array.is_empty() { + return Ok(None); + } + + let batch = + RecordBatch::try_new(Arc::clone(&self.schema), vec![Arc::new(array)])?; + Ok(Some(batch)) + } +} + +fn make_lazy_exec(column_name: &str, pretend_infinite: bool) -> LazyMemoryExec { + make_lazy_exec_with_range(column_name, i64::MIN..i64::MAX, pretend_infinite) +} + +fn make_lazy_exec_with_range( + column_name: &str, + range: Range, + pretend_infinite: bool, +) -> LazyMemoryExec { + let schema = Arc::new(Schema::new(vec![Field::new( + column_name, + DataType::Int64, + false, + )])); + + let boundedness = if pretend_infinite { + Boundedness::Unbounded { + requires_infinite_memory: false, + } + } else { + Boundedness::Bounded + }; + + // Instantiate the generator with the batch and limit + let gen = RangeBatchGenerator { + schema: Arc::clone(&schema), + boundedness, + value_range: range, + batch_size: 8192, + poll_count: 0, + }; + + // Wrap the generator in a trait object behind Arc> + let generator: Arc> = Arc::new(RwLock::new(gen)); + + // Create a LazyMemoryExec with one partition using our generator + let mut exec = LazyMemoryExec::try_new(schema, vec![generator]).unwrap(); + + exec.add_ordering(vec![PhysicalSortExpr::new( + Arc::new(Column::new(column_name, 0)), + SortOptions::new(false, true), + )]); + + exec +} + +#[rstest] +#[tokio::test] +async fn agg_no_grouping_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up an aggregation without grouping + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![], vec![], vec![]), + vec![Arc::new( + AggregateExprBuilder::new( + sum::sum_udaf(), + vec![col("value", &inf.schema())?], + ) + .schema(inf.schema()) + .alias("sum") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn agg_grouping_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up an aggregation with grouping + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + let value_col = col("value", &inf.schema())?; + let group = binary(value_col.clone(), Divide, lit(1000000i64), &inf.schema())?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![(group, "group".to_string())], vec![], vec![]), + vec![Arc::new( + AggregateExprBuilder::new(sum::sum_udaf(), vec![value_col.clone()]) + .schema(inf.schema()) + .alias("sum") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn agg_grouped_topk_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up a top-k aggregation + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + let value_col = col("value", &inf.schema())?; + let group = binary(value_col.clone(), Divide, lit(1000000i64), &inf.schema())?; + + let aggr = Arc::new( + AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new( + vec![(group, "group".to_string())], + vec![], + vec![vec![false]], + ), + vec![Arc::new( + AggregateExprBuilder::new(min_max::max_udaf(), vec![value_col.clone()]) + .schema(inf.schema()) + .alias("max") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )? + .with_limit(Some(100)), + ); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +// A test that mocks the behavior of `SpillManager::read_spill_as_stream` without file access +// to verify that a cooperative stream would properly yields in a spill file read scenario +async fn spill_reader_stream_yield() -> Result<(), Box> { + use datafusion_physical_plan::common::spawn_buffered; + + // A mock stream that always returns `Poll::Ready(Some(...))` immediately + let always_ready = + make_lazy_exec("value", false).execute(0, SessionContext::new().task_ctx())?; + + // this function makes a consumer stream that resembles how read_stream from spill file is constructed + let stream = make_cooperative(always_ready); + + // Set large buffer so that buffer always has free space for the producer/sender + let buffer_capacity = 100_000; + let mut mock_stream = spawn_buffered(stream, buffer_capacity); + let schema = mock_stream.schema(); + + let consumer_stream = futures::stream::poll_fn(move |cx| { + let mut collected = vec![]; + // To make sure that inner stream is polled multiple times, loop until the buffer is full + // Ideally, the stream will yield before the loop ends + for _ in 0..buffer_capacity { + match mock_stream.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(batch))) => { + collected.push(batch); + } + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Some(Err(e))); + } + Poll::Ready(None) => { + break; + } + Poll::Pending => { + // polling inner stream may return Pending only when it reaches budget, since + // we intentionally made ProducerStream always return Ready + return Poll::Pending; + } + } + } + + // This should be unreachable since the stream is canceled + unreachable!("Expected the stream to be canceled, but it continued polling"); + }); + + let consumer_record_batch_stream = + Box::pin(RecordBatchStreamAdapter::new(schema, consumer_stream)); + + stream_yields(consumer_record_batch_stream).await +} + +#[rstest] +#[tokio::test] +async fn sort_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up the infinite source + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + // set up a SortExec that will not be able to finish in time because input is very large + let sort_expr = PhysicalSortExpr::new( + col("value", &inf.schema())?, + SortOptions { + descending: true, + nulls_first: true, + }, + ); + + let lex_ordering = LexOrdering::new(vec![sort_expr]).unwrap(); + let sort_exec = Arc::new(SortExec::new(lex_ordering, inf.clone())); + + query_yields(sort_exec, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn sort_merge_join_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up the join sources + let inf1 = Arc::new(make_lazy_exec_with_range( + "value1", + i64::MIN..0, + pretend_infinite, + )); + let inf2 = Arc::new(make_lazy_exec_with_range( + "value2", + 0..i64::MAX, + pretend_infinite, + )); + + // set up a SortMergeJoinExec that will take a long time skipping left side content to find + // the first right side match + let join = Arc::new(SortMergeJoinExec::try_new( + inf1.clone(), + inf2.clone(), + vec![( + col("value1", &inf1.schema())?, + col("value2", &inf2.schema())?, + )], + None, + JoinType::Inner, + vec![inf1.properties().eq_properties.output_ordering().unwrap()[0].options], + NullEquality::NullEqualsNull, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn filter_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up the infinite source + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + // set up a FilterExec that will filter out entire batches + let filter_expr = binary( + col("value", &inf.schema())?, + Operator::Lt, + lit(i64::MIN), + &inf.schema(), + )?; + let filter = Arc::new(FilterExec::try_new(filter_expr, inf.clone())?); + + query_yields(filter, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn filter_reject_all_batches_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Create a Session, Schema, and an 8K-row RecordBatch + let session_ctx = SessionContext::new(); + + // Wrap this batch in an InfiniteExec + let infinite = make_lazy_exec_with_range("value", i64::MIN..0, pretend_infinite); + + // 2b) Construct a FilterExec that is always false: “value > 10000” (no rows pass) + let false_predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("value", 0)), + Gt, + Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), + )); + let filtered = Arc::new(FilterExec::try_new(false_predicate, Arc::new(infinite))?); + + // Use CoalesceBatchesExec to guarantee each Filter pull always yields an 8192-row batch + let coalesced = Arc::new(CoalesceBatchesExec::new(filtered, 8_192)); + + query_yields(coalesced, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn interleave_then_filter_all_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Build a session and a schema with one i64 column. + let session_ctx = SessionContext::new(); + + // Create multiple infinite sources, each filtered by a different threshold. + // This ensures InterleaveExec has many children. + let mut infinite_children = vec![]; + + // Use 32 distinct thresholds (each >0 and <8 192) to force 32 infinite inputs + for threshold in 1..32 { + // One infinite exec: + let mut inf = make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Now repartition so that all children share identical Hash partitioning + // on “value” into 1 bucket. This is required for InterleaveExec::try_new. + let exprs = vec![Arc::new(Column::new("value", 0)) as _]; + let partitioning = Partitioning::Hash(exprs, 1); + inf.try_set_partitioning(partitioning)?; + + // Apply a FilterExec: “(value / 8192) % threshold == 0”. + let filter_expr = binary( + binary( + binary( + col("value", &inf.schema())?, + Divide, + lit(8192i64), + &inf.schema(), + )?, + Modulo, + lit(threshold as i64), + &inf.schema(), + )?, + Eq, + lit(0i64), + &inf.schema(), + )?; + let filtered = Arc::new(FilterExec::try_new(filter_expr, Arc::new(inf))?); + + infinite_children.push(filtered as _); + } + + // Build an InterleaveExec over all infinite children. + let interleave = Arc::new(InterleaveExec::try_new(infinite_children)?); + + // Wrap the InterleaveExec in a FilterExec that always returns false, + // ensuring that no rows are ever emitted. + let always_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))); + let filtered_interleave = Arc::new(FilterExec::try_new(always_false, interleave)?); + + query_yields(filtered_interleave, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn interleave_then_aggregate_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Build session, schema, and a sample batch. + let session_ctx = SessionContext::new(); + + // Create N infinite sources, each filtered by a different predicate. + // That way, the InterleaveExec will have multiple children. + let mut infinite_children = vec![]; + + // Use 32 distinct thresholds (each >0 and <8 192) to force 32 infinite inputs + for threshold in 1..32 { + // One infinite exec: + let mut inf = make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Now repartition so that all children share identical Hash partitioning + // on “value” into 1 bucket. This is required for InterleaveExec::try_new. + let exprs = vec![Arc::new(Column::new("value", 0)) as _]; + let partitioning = Partitioning::Hash(exprs, 1); + inf.try_set_partitioning(partitioning)?; + + // Apply a FilterExec: “(value / 8192) % threshold == 0”. + let filter_expr = binary( + binary( + binary( + col("value", &inf.schema())?, + Divide, + lit(8192i64), + &inf.schema(), + )?, + Modulo, + lit(threshold as i64), + &inf.schema(), + )?, + Eq, + lit(0i64), + &inf.schema(), + )?; + let filtered = Arc::new(FilterExec::try_new(filter_expr, Arc::new(inf))?); + + infinite_children.push(filtered as _); + } + + // Build an InterleaveExec over all N children. + // Since each child now has Partitioning::Hash([col "value"], 1), InterleaveExec::try_new succeeds. + let interleave = Arc::new(InterleaveExec::try_new(infinite_children)?); + let interleave_schema = interleave.schema(); + + // Build a global AggregateExec that sums “value” over all rows. + // Because we use `AggregateMode::Single` with no GROUP BY columns, this plan will + // only produce one “final” row once all inputs finish. But our inputs never finish, + // so we should never get any output. + let aggregate_expr = AggregateExprBuilder::new( + sum::sum_udaf(), + vec![Arc::new(Column::new("value", 0))], + ) + .schema(interleave_schema.clone()) + .alias("total") + .build()?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new( + vec![], // no GROUP BY columns + vec![], // no GROUP BY expressions + vec![], // no GROUP BY physical expressions + ), + vec![Arc::new(aggregate_expr)], + vec![None], // no “distinct” flags + interleave, + interleave_schema, + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn join_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Session, schema, and a single 8 K‐row batch for each side + let session_ctx = SessionContext::new(); + + // on the right side, we’ll shift each value by +1 so that not everything joins, + // but plenty of matching keys exist (e.g. 0 on left matches 1 on right, etc.) + let infinite_left = make_lazy_exec_with_range("value", -10..10, false); + let infinite_right = + make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Create Join keys → join on “value” = “value” + let left_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; + let right_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; + + // Wrap each side in CoalesceBatches + Repartition so they are both hashed into 1 partition + let coalesced_left = + Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_left), 8_192)); + let coalesced_right = + Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_right), 8_192)); + + let part_left = Partitioning::Hash(left_keys, 1); + let part_right = Partitioning::Hash(right_keys, 1); + + let hashed_left = Arc::new(RepartitionExec::try_new(coalesced_left, part_left)?); + let hashed_right = Arc::new(RepartitionExec::try_new(coalesced_right, part_right)?); + + // Build an Inner HashJoinExec → left.value = right.value + let join = Arc::new(HashJoinExec::try_new( + hashed_left, + hashed_right, + vec![( + Arc::new(Column::new("value", 0)), + Arc::new(Column::new("value", 0)), + )], + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn join_agg_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Session, schema, and a single 8 K‐row batch for each side + let session_ctx = SessionContext::new(); + + // on the right side, we’ll shift each value by +1 so that not everything joins, + // but plenty of matching keys exist (e.g. 0 on left matches 1 on right, etc.) + let infinite_left = make_lazy_exec_with_range("value", -10..10, false); + let infinite_right = + make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // 2b) Create Join keys → join on “value” = “value” + let left_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; + let right_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; + + // Wrap each side in CoalesceBatches + Repartition so they are both hashed into 1 partition + let coalesced_left = + Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_left), 8_192)); + let coalesced_right = + Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_right), 8_192)); + + let part_left = Partitioning::Hash(left_keys, 1); + let part_right = Partitioning::Hash(right_keys, 1); + + let hashed_left = Arc::new(RepartitionExec::try_new(coalesced_left, part_left)?); + let hashed_right = Arc::new(RepartitionExec::try_new(coalesced_right, part_right)?); + + // Build an Inner HashJoinExec → left.value = right.value + let join = Arc::new(HashJoinExec::try_new( + hashed_left, + hashed_right, + vec![( + Arc::new(Column::new("value", 0)), + Arc::new(Column::new("value", 0)), + )], + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + )?); + + // Project only one column (“value” from the left side) because we just want to sum that + let input_schema = join.schema(); + + let proj_expr = vec![ProjectionExpr::new( + Arc::new(Column::new_with_schema("value", &input_schema)?) as _, + "value".to_string(), + )]; + + let projection = Arc::new(ProjectionExec::try_new(proj_expr, join)?); + let projection_schema = projection.schema(); + + let output_fields = vec![Field::new("total", DataType::Int64, true)]; + let output_schema = Arc::new(Schema::new(output_fields)); + + // 4) Global aggregate (Single) over “value” + let aggregate_expr = AggregateExprBuilder::new( + sum::sum_udaf(), + vec![Arc::new(Column::new_with_schema( + "value", + &projection.schema(), + )?)], + ) + .schema(output_schema) + .alias("total") + .build()?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![], vec![], vec![]), + vec![Arc::new(aggregate_expr)], + vec![None], + projection, + projection_schema, + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn hash_join_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up the join sources + let inf1 = Arc::new(make_lazy_exec("value1", pretend_infinite)); + let inf2 = Arc::new(make_lazy_exec("value2", pretend_infinite)); + + // set up a HashJoinExec that will take a long time in the build phase + let join = Arc::new(HashJoinExec::try_new( + inf1.clone(), + inf2.clone(), + vec![( + col("value1", &inf1.schema())?, + col("value2", &inf2.schema())?, + )], + None, + &JoinType::Left, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn hash_join_without_repartition_and_no_agg( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Create Session, schema, and an 8K-row RecordBatch for each side + let session_ctx = SessionContext::new(); + + // on the right side, we’ll shift each value by +1 so that not everything joins, + // but plenty of matching keys exist (e.g. 0 on left matches 1 on right, etc.) + let infinite_left = make_lazy_exec_with_range("value", -10..10, false); + let infinite_right = + make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Directly feed `infinite_left` and `infinite_right` into HashJoinExec. + // Do not use aggregation or repartition. + let join = Arc::new(HashJoinExec::try_new( + Arc::new(infinite_left), + Arc::new(infinite_right), + vec![( + Arc::new(Column::new("value", 0)), + Arc::new(Column::new("value", 0)), + )], + /* filter */ None, + &JoinType::Inner, + /* output64 */ None, + // Using CollectLeft is fine—just avoid RepartitionExec’s partitioned channels. + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[derive(Debug)] +enum Yielded { + ReadyOrPending, + Err(#[allow(dead_code)] DataFusionError), + Timeout, +} + +async fn stream_yields( + mut stream: SendableRecordBatchStream, +) -> Result<(), Box> { + // Create an independent executor pool + let child_runtime = Runtime::new()?; + + // Spawn a task that tries to poll the stream + // The task returns Ready when the stream yielded with either Ready or Pending + let join_handle = child_runtime.spawn(std::future::poll_fn(move |cx| { + match stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(_))) => Poll::Ready(Poll::Ready(Ok(()))), + Poll::Ready(Some(Err(e))) => Poll::Ready(Poll::Ready(Err(e))), + Poll::Ready(None) => Poll::Ready(Poll::Ready(Ok(()))), + Poll::Pending => Poll::Ready(Poll::Pending), + } + })); + + let abort_handle = join_handle.abort_handle(); + + // Now select on the join handle of the task running in the child executor with a timeout + let yielded = select! { + result = join_handle => { + match result { + Ok(Pending) => Yielded::ReadyOrPending, + Ok(Ready(Ok(_))) => Yielded::ReadyOrPending, + Ok(Ready(Err(e))) => Yielded::Err(e), + Err(_) => Yielded::Err(exec_datafusion_err!("join error")), + } + }, + _ = tokio::time::sleep(Duration::from_secs(10)) => { + Yielded::Timeout + } + }; + + // Try to abort the poll task and shutdown the child runtime + abort_handle.abort(); + Handle::current().spawn_blocking(move || { + child_runtime.shutdown_timeout(Duration::from_secs(5)); + }); + + // Finally, check if poll_next yielded + assert!( + matches!(yielded, Yielded::ReadyOrPending), + "Result is not Ready or Pending: {yielded:?}" + ); + Ok(()) +} + +async fn query_yields( + plan: Arc, + task_ctx: Arc, +) -> Result<(), Box> { + // Run plan through EnsureCooperative + let optimized = + EnsureCooperative::new().optimize(plan, task_ctx.session_config().options())?; + + // Get the stream + let stream = physical_plan::execute_stream(optimized, task_ctx)?; + + // Spawn a task that tries to poll the stream and check whether given stream yields + stream_yields(stream).await +} diff --git a/datafusion/core/tests/execution/datasource_split.rs b/datafusion/core/tests/execution/datasource_split.rs new file mode 100644 index 0000000000000..0b90c6f326168 --- /dev/null +++ b/datafusion/core/tests/execution/datasource_split.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{ArrayRef, Int32Array}, + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, +}; +use datafusion_datasource::memory::MemorySourceConfig; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::{common::collect, ExecutionPlan}; +use std::sync::Arc; + +/// Helper function to create a memory source with the given batch size and collect all batches +async fn create_and_collect_batches( + batch_size: usize, +) -> datafusion_common::Result> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let array = Int32Array::from_iter_values(0..batch_size as i32); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array) as ArrayRef])?; + let exec = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None)?; + let ctx = Arc::new(TaskContext::default()); + let stream = exec.execute(0, ctx)?; + collect(stream).await +} + +/// Helper function to create a memory source with multiple batches and collect all results +async fn create_and_collect_multiple_batches( + input_batches: Vec, +) -> datafusion_common::Result> { + let schema = input_batches[0].schema(); + let exec = MemorySourceConfig::try_new_exec(&[input_batches], schema, None)?; + let ctx = Arc::new(TaskContext::default()); + let stream = exec.execute(0, ctx)?; + collect(stream).await +} + +#[tokio::test] +async fn datasource_splits_large_batches() -> datafusion_common::Result<()> { + let batch_size = 20000; + let batches = create_and_collect_batches(batch_size).await?; + + assert!(batches.len() > 1); + let max = batches.iter().map(|b| b.num_rows()).max().unwrap(); + assert!( + max <= datafusion_execution::config::SessionConfig::new() + .options() + .execution + .batch_size + ); + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total, batch_size); + Ok(()) +} + +#[tokio::test] +async fn datasource_exact_batch_size_no_split() -> datafusion_common::Result<()> { + let session_config = datafusion_execution::config::SessionConfig::new(); + let configured_batch_size = session_config.options().execution.batch_size; + + let batches = create_and_collect_batches(configured_batch_size).await?; + + // Should not split when exactly equal to batch_size + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), configured_batch_size); + Ok(()) +} + +#[tokio::test] +async fn datasource_small_batch_no_split() -> datafusion_common::Result<()> { + // Test with batch smaller than the batch size (8192) + let small_batch_size = 512; // Less than 8192 + + let batches = create_and_collect_batches(small_batch_size).await?; + + // Should not split small batches below the batch size + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), small_batch_size); + Ok(()) +} + +#[tokio::test] +async fn datasource_empty_batch_clean_termination() -> datafusion_common::Result<()> { + let batches = create_and_collect_batches(0).await?; + + // Empty batch should result in one empty batch + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 0); + Ok(()) +} + +#[tokio::test] +async fn datasource_multiple_empty_batches() -> datafusion_common::Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let empty_array = Int32Array::from_iter_values(std::iter::empty::()); + let empty_batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(empty_array) as ArrayRef])?; + + // Create multiple empty batches + let input_batches = vec![empty_batch.clone(), empty_batch.clone(), empty_batch]; + let batches = create_and_collect_multiple_batches(input_batches).await?; + + // Should preserve empty batches without issues + assert_eq!(batches.len(), 3); + for batch in &batches { + assert_eq!(batch.num_rows(), 0); + } + Ok(()) +} diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index b30636ddf6a81..ef2e263f2c467 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -19,15 +19,19 @@ //! create them and depend on them. Test executable semantics of logical plans. use arrow::array::Int64Array; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::datasource::{provider_as_source, ViewTable}; use datafusion::execution::session_state::SessionStateBuilder; -use datafusion_common::{Column, DFSchema, Result, ScalarValue, Spans}; +use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; use datafusion_execution::TaskContext; use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams}; use datafusion_expr::logical_plan::{LogicalPlan, Values}; -use datafusion_expr::{Aggregate, AggregateUDF, Expr}; +use datafusion_expr::{ + Aggregate, AggregateUDF, EmptyRelation, Expr, LogicalPlanBuilder, UNNAMED_TABLE, +}; use datafusion_functions_aggregate::count::Count; use datafusion_physical_plan::collect; +use insta::assert_snapshot; use std::collections::HashMap; use std::fmt::Debug; use std::ops::Deref; @@ -43,9 +47,9 @@ async fn count_only_nulls() -> Result<()> { let input = Arc::new(LogicalPlan::Values(Values { schema: input_schema, values: vec![ - vec![Expr::Literal(ScalarValue::Null)], - vec![Expr::Literal(ScalarValue::Null)], - vec![Expr::Literal(ScalarValue::Null)], + vec![Expr::Literal(ScalarValue::Null, None)], + vec![Expr::Literal(ScalarValue::Null, None)], + vec![Expr::Literal(ScalarValue::Null, None)], ], })); let input_col_ref = Expr::Column(Column { @@ -64,7 +68,7 @@ async fn count_only_nulls() -> Result<()> { args: vec![input_col_ref], distinct: false, filter: None, - order_by: None, + order_by: vec![], null_treatment: None, }, })], @@ -92,7 +96,41 @@ where T: Debug, { let [element] = elements else { - panic!("Expected exactly one element, got {:?}", elements); + panic!("Expected exactly one element, got {elements:?}"); }; element } + +#[test] +fn inline_scan_projection_test() -> Result<()> { + let name = UNNAMED_TABLE; + let column = "a"; + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let projection = vec![schema.index_of(column)?]; + + let provider = ViewTable::new( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(DFSchema::try_from(schema)?), + }), + None, + ); + let source = provider_as_source(Arc::new(provider)); + + let plan = LogicalPlanBuilder::scan(name, source, Some(projection))?.build()?; + + assert_snapshot!( + format!("{plan}"), + @r" + SubqueryAlias: ?table? + Projection: a + EmptyRelation: rows=0 + " + ); + + Ok(()) +} diff --git a/datafusion/core/tests/execution/mod.rs b/datafusion/core/tests/execution/mod.rs index 8169db1a4611e..8770b2a201051 100644 --- a/datafusion/core/tests/execution/mod.rs +++ b/datafusion/core/tests/execution/mod.rs @@ -15,4 +15,6 @@ // specific language governing permissions and limitations // under the License. +mod coop; +mod datasource_split; mod logical_plan; diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index aef10379da074..4aee274de9083 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -24,6 +24,7 @@ use arrow::util::pretty::{pretty_format_batches, pretty_format_columns}; use datafusion::prelude::*; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::expr::NullTreatment; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::ExprFunctionExt; use datafusion_functions::core::expr_ext::FieldAccessor; @@ -31,7 +32,6 @@ use datafusion_functions_aggregate::first_last::first_value_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_functions_nested::expr_ext::{IndexAccessor, SliceAccessor}; use datafusion_optimizer::simplify_expressions::ExprSimplifier; -use sqlparser::ast::NullTreatment; /// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan use std::sync::{Arc, LazyLock}; @@ -358,8 +358,7 @@ async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) { assert_eq!( expected_lines, actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines + "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n" ); } @@ -379,8 +378,7 @@ fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) { assert_eq!( expected_lines, actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines + "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n" ); } diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 7bb21725ef401..89651726a69a4 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -17,6 +17,8 @@ //! This program demonstrates the DataFusion expression simplification API. +use insta::assert_snapshot; + use arrow::array::types::IntervalDayTime; use arrow::array::{ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Field, Schema}; @@ -237,11 +239,15 @@ fn to_timestamp_expr_folded() -> Result<()> { .project(proj)? .build()?; - let expected = "Projection: TimestampNanosecond(1599566400000000000, None) AS to_timestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\ - \n TableScan: test" - .to_string(); - let actual = get_optimized_plan_formatted(plan, &Utc::now()); - assert_eq!(expected, actual); + let formatted = get_optimized_plan_formatted(plan, &Utc::now()); + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r###" + Projection: TimestampNanosecond(1599566400000000000, None) AS to_timestamp(Utf8("2020-09-08T12:00:00+00:00")) + TableScan: test + "### + ); Ok(()) } @@ -262,11 +268,16 @@ fn now_less_than_timestamp() -> Result<()> { // Note that constant folder runs and folds the entire // expression down to a single constant (true) - let expected = "Filter: Boolean(true)\ - \n TableScan: test"; - let actual = get_optimized_plan_formatted(plan, &time); - - assert_eq!(expected, actual); + let formatted = get_optimized_plan_formatted(plan, &time); + let actual = formatted.trim(); + + assert_snapshot!( + actual, + @r###" + Filter: Boolean(true) + TableScan: test + "### + ); Ok(()) } @@ -282,10 +293,13 @@ fn select_date_plus_interval() -> Result<()> { let date_plus_interval_expr = to_timestamp_expr(ts_string) .cast_to(&DataType::Date32, schema)? - + Expr::Literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days: 123, - milliseconds: 0, - }))); + + Expr::Literal( + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 123, + milliseconds: 0, + })), + None, + ); let plan = LogicalPlanBuilder::from(table_scan.clone()) .project(vec![date_plus_interval_expr])? @@ -293,11 +307,16 @@ fn select_date_plus_interval() -> Result<()> { // Note that constant folder runs and folds the entire // expression down to a single constant (true) - let expected = r#"Projection: Date32("2021-01-09") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 0 }") - TableScan: test"#; - let actual = get_optimized_plan_formatted(plan, &time); - - assert_eq!(expected, actual); + let formatted = get_optimized_plan_formatted(plan, &time); + let actual = formatted.trim(); + + assert_snapshot!( + actual, + @r###" + Projection: Date32("2021-01-09") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 0 }") + TableScan: test + "### + ); Ok(()) } @@ -311,10 +330,15 @@ fn simplify_project_scalar_fn() -> Result<()> { // before simplify: power(t.f, 1.0) // after simplify: t.f as "power(t.f, 1.0)" - let expected = "Projection: test.f AS power(test.f,Float64(1))\ - \n TableScan: test"; - let actual = get_optimized_plan_formatted(plan, &Utc::now()); - assert_eq!(expected, actual); + let formatter = get_optimized_plan_formatted(plan, &Utc::now()); + let actual = formatter.trim(); + assert_snapshot!( + actual, + @r###" + Projection: test.f AS power(test.f,Float64(1)) + TableScan: test + "### + ); Ok(()) } @@ -334,9 +358,9 @@ fn simplify_scan_predicate() -> Result<()> { // before simplify: t.g = power(t.f, 1.0) // after simplify: t.g = t.f" - let expected = "TableScan: test, full_filters=[g = f]"; - let actual = get_optimized_plan_formatted(plan, &Utc::now()); - assert_eq!(expected, actual); + let formatted = get_optimized_plan_formatted(plan, &Utc::now()); + let actual = formatted.trim(); + assert_snapshot!(actual, @"TableScan: test, full_filters=[g = f]"); Ok(()) } @@ -547,9 +571,9 @@ fn test_simplify_with_cycle_count( }; let simplifier = ExprSimplifier::new(info); let (simplified_expr, count) = simplifier - .simplify_with_cycle_count(input_expr.clone()) + .simplify_with_cycle_count_transformed(input_expr.clone()) .expect("successfully evaluated"); - + let simplified_expr = simplified_expr.data; assert_eq!( simplified_expr, expected_expr, "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}" diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index dcf477135a377..4e04da26f70b6 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -17,39 +17,44 @@ use std::sync::Arc; +use super::record_batch_generator::get_supported_types_columns; +use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder; use crate::fuzz_cases::aggregation_fuzzer::{ - AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig, QueryBuilder, + AggregationFuzzerBuilder, DatasetGeneratorConfig, }; -use arrow::array::{types::Int64Type, Array, ArrayRef, AsArray, Int64Array, RecordBatch}; -use arrow::compute::{concat_batches, SortOptions}; -use arrow::datatypes::{ - DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, - DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, +use arrow::array::{ + types::Int64Type, Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch, + StringArray, }; +use arrow::compute::concat_batches; +use arrow::datatypes::DataType; use arrow::util::pretty::pretty_format_batches; -use datafusion::common::Result; +use arrow_schema::{Field, Schema, SchemaRef}; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; use datafusion::datasource::MemTable; -use datafusion::physical_expr::aggregate::AggregateExprBuilder; -use datafusion::physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, -}; -use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; -use datafusion_common::HashMap; +use datafusion_common::{HashMap, Result}; use datafusion_common_runtime::JoinSet; use datafusion_functions_aggregate::sum::sum_udaf; -use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::expressions::{col, lit, Column}; use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::InputOrderMode; use test_utils::{add_empty_batches, StringBatchGenerator}; +use datafusion_execution::memory_pool::FairSpillPool; +use datafusion_execution::runtime_env::RuntimeEnvBuilder; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion_physical_plan::metrics::MetricValue; +use datafusion_physical_plan::{collect, displayable, ExecutionPlan}; use rand::rngs::StdRng; -use rand::{thread_rng, Rng, SeedableRng}; +use rand::{random, rng, Rng, SeedableRng}; // ======================================================================== // The new aggregation fuzz tests based on [`AggregationFuzzer`] @@ -78,6 +83,7 @@ async fn test_min() { .with_aggregate_function("min") // min works on all column types .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -104,6 +110,34 @@ async fn test_first_val() { .with_table_name("fuzz_table") .with_aggregate_function("first_value") .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) + .set_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_last_val() { + let mut data_gen_config = baseline_config(); + + for i in 0..data_gen_config.columns.len() { + if data_gen_config.columns[i].get_max_num_distinct().is_none() { + data_gen_config.columns[i] = data_gen_config.columns[i] + .clone() + // Minimize the chance of identical values in the order by columns to make the test more stable + .with_max_num_distinct(usize::MAX); + } + } + + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("last_value") + .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -123,6 +157,7 @@ async fn test_max() { .with_aggregate_function("max") // max works on all column types .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -143,6 +178,7 @@ async fn test_sum() { .with_distinct_aggregate_function("sum") // sum only works on numeric columns .with_aggregate_arguments(data_gen_config.numeric_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -163,6 +199,7 @@ async fn test_count() { .with_distinct_aggregate_function("count") // count work for all arguments .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -183,6 +220,7 @@ async fn test_median() { .with_distinct_aggregate_function("median") // median only works on numeric columns .with_aggregate_arguments(data_gen_config.numeric_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -200,82 +238,8 @@ async fn test_median() { /// 1. Floating point numbers /// 1. structured types fn baseline_config() -> DatasetGeneratorConfig { - let mut rng = thread_rng(); - let columns = vec![ - ColumnDescr::new("i8", DataType::Int8), - ColumnDescr::new("i16", DataType::Int16), - ColumnDescr::new("i32", DataType::Int32), - ColumnDescr::new("i64", DataType::Int64), - ColumnDescr::new("u8", DataType::UInt8), - ColumnDescr::new("u16", DataType::UInt16), - ColumnDescr::new("u32", DataType::UInt32), - ColumnDescr::new("u64", DataType::UInt64), - ColumnDescr::new("date32", DataType::Date32), - ColumnDescr::new("date64", DataType::Date64), - ColumnDescr::new("time32_s", DataType::Time32(TimeUnit::Second)), - ColumnDescr::new("time32_ms", DataType::Time32(TimeUnit::Millisecond)), - ColumnDescr::new("time64_us", DataType::Time64(TimeUnit::Microsecond)), - ColumnDescr::new("time64_ns", DataType::Time64(TimeUnit::Nanosecond)), - // `None` is passed in here however when generating the array, it will generate - // random timezones. - ColumnDescr::new("timestamp_s", DataType::Timestamp(TimeUnit::Second, None)), - ColumnDescr::new( - "timestamp_ms", - DataType::Timestamp(TimeUnit::Millisecond, None), - ), - ColumnDescr::new( - "timestamp_us", - DataType::Timestamp(TimeUnit::Microsecond, None), - ), - ColumnDescr::new( - "timestamp_ns", - DataType::Timestamp(TimeUnit::Nanosecond, None), - ), - ColumnDescr::new("float32", DataType::Float32), - ColumnDescr::new("float64", DataType::Float64), - ColumnDescr::new( - "interval_year_month", - DataType::Interval(IntervalUnit::YearMonth), - ), - ColumnDescr::new( - "interval_day_time", - DataType::Interval(IntervalUnit::DayTime), - ), - ColumnDescr::new( - "interval_month_day_nano", - DataType::Interval(IntervalUnit::MonthDayNano), - ), - // begin decimal columns - ColumnDescr::new("decimal128", { - // Generate valid precision and scale for Decimal128 randomly. - let precision: u8 = rng.gen_range(1..=DECIMAL128_MAX_PRECISION); - // It's safe to cast `precision` to i8 type directly. - let scale: i8 = rng.gen_range( - i8::MIN..=std::cmp::min(precision as i8, DECIMAL128_MAX_SCALE), - ); - DataType::Decimal128(precision, scale) - }), - ColumnDescr::new("decimal256", { - // Generate valid precision and scale for Decimal256 randomly. - let precision: u8 = rng.gen_range(1..=DECIMAL256_MAX_PRECISION); - // It's safe to cast `precision` to i8 type directly. - let scale: i8 = rng.gen_range( - i8::MIN..=std::cmp::min(precision as i8, DECIMAL256_MAX_SCALE), - ); - DataType::Decimal256(precision, scale) - }), - // begin string columns - ColumnDescr::new("utf8", DataType::Utf8), - ColumnDescr::new("largeutf8", DataType::LargeUtf8), - ColumnDescr::new("utf8view", DataType::Utf8View), - // low cardinality columns - ColumnDescr::new("u8_low", DataType::UInt8).with_max_num_distinct(10), - ColumnDescr::new("utf8_low", DataType::Utf8).with_max_num_distinct(10), - ColumnDescr::new("bool", DataType::Boolean), - ColumnDescr::new("binary", DataType::Binary), - ColumnDescr::new("large_binary", DataType::LargeBinary), - ColumnDescr::new("binaryview", DataType::BinaryView), - ]; + let mut rng = rng(); + let columns = get_supported_types_columns(rng.random()); let min_num_rows = 512; let max_num_rows = 1024; @@ -287,6 +251,12 @@ fn baseline_config() -> DatasetGeneratorConfig { // low cardinality to try and get many repeated runs vec![String::from("u8_low")], vec![String::from("utf8_low"), String::from("u8_low")], + vec![String::from("dictionary_utf8_low")], + vec![ + String::from("dictionary_utf8_low"), + String::from("utf8_low"), + String::from("u8_low"), + ], ], } } @@ -336,13 +306,9 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); - let mut sort_keys = LexOrdering::default(); - for ordering_col in ["a", "b", "c"] { - sort_keys.push(PhysicalSortExpr { - expr: col(ordering_col, &schema).unwrap(), - options: SortOptions::default(), - }) - } + let sort_keys = ["a", "b", "c"].map(|ordering_col| { + PhysicalSortExpr::new_default(col(ordering_col, &schema).unwrap()) + }); let concat_input_record = concat_batches(&schema, &input1).unwrap(); @@ -354,9 +320,9 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str .unwrap(); let running_source = DataSourceExec::from_data_source( - MemorySourceConfig::try_new(&[input1.clone()], schema.clone(), None) + MemorySourceConfig::try_new(std::slice::from_ref(&input1), schema.clone(), None) .unwrap() - .try_with_sort_information(vec![sort_keys]) + .try_with_sort_information(vec![sort_keys.into()]) .unwrap(), ); @@ -371,7 +337,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str ]; let expr = group_by_columns .iter() - .map(|elem| (col(elem, &schema).unwrap(), elem.to_string())) + .map(|elem| (col(elem, &schema).unwrap(), (*elem).to_string())) .collect::>(); let group_by = PhysicalGroupBy::new_single(expr); @@ -437,7 +403,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str Left Plan:\n{}\n\ Right Plan:\n{}\n\ schema:\n{schema}\n\ - Left Ouptut:\n{}\n\ + Left Output:\n{}\n\ Right Output:\n{}\n\ input:\n{}\n\ ", @@ -464,13 +430,13 @@ pub(crate) fn make_staggered_batches( let mut input4: Vec = vec![0; len]; input123.iter_mut().for_each(|v| { *v = ( - rng.gen_range(0..n_distinct) as i64, - rng.gen_range(0..n_distinct) as i64, - rng.gen_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, ) }); input4.iter_mut().for_each(|v| { - *v = rng.gen_range(0..n_distinct) as i64; + *v = rng.random_range(0..n_distinct) as i64; }); input123.sort(); let input1 = Int64Array::from_iter_values(input123.clone().into_iter().map(|k| k.0)); @@ -490,7 +456,7 @@ pub(crate) fn make_staggered_batches( let mut batches = vec![]; if STREAM { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..50); + let batch_size = rng.random_range(0..50); if remainder.num_rows() < batch_size { break; } @@ -499,7 +465,7 @@ pub(crate) fn make_staggered_batches( } } else { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + let batch_size = rng.random_range(0..remainder.num_rows() + 1); batches.push(remainder.slice(0, batch_size)); remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); } @@ -545,7 +511,9 @@ async fn group_by_string_test( let expected = compute_counts(&input, column_name); let schema = input[0].schema(); - let session_config = SessionConfig::new().with_batch_size(50); + let session_config = SessionConfig::new() + .with_batch_size(50) + .with_repartition_file_scans(false); let ctx = SessionContext::new_with_config(session_config); let provider = MemTable::try_new(schema.clone(), vec![input]).unwrap(); @@ -663,3 +631,139 @@ fn extract_result_counts(results: Vec) -> HashMap, i } output } + +pub(crate) fn assert_spill_count_metric( + expect_spill: bool, + plan_that_spills: Arc, +) -> usize { + if let Some(metrics_set) = plan_that_spills.metrics() { + let mut spill_count = 0; + + // Inspect metrics for SpillCount + for metric in metrics_set.iter() { + if let MetricValue::SpillCount(count) = metric.value() { + spill_count = count.value(); + break; + } + } + + if expect_spill && spill_count == 0 { + panic!("Expected spill but SpillCount metric not found or SpillCount was 0."); + } else if !expect_spill && spill_count > 0 { + panic!("Expected no spill but found SpillCount metric with value greater than 0."); + } + + spill_count + } else { + panic!("No metrics returned from the operator; cannot verify spilling."); + } +} + +// Fix for https://github.com/apache/datafusion/issues/15530 +#[tokio::test] +async fn test_single_mode_aggregate_single_mode_aggregate_with_spill() -> Result<()> { + let scan_schema = Arc::new(Schema::new(vec![ + Field::new("col_0", DataType::Int64, true), + Field::new("col_1", DataType::Utf8, true), + Field::new("col_2", DataType::Utf8, true), + Field::new("col_3", DataType::Utf8, true), + Field::new("col_4", DataType::Utf8, true), + Field::new("col_5", DataType::Int32, true), + Field::new("col_6", DataType::Utf8, true), + Field::new("col_7", DataType::Utf8, true), + Field::new("col_8", DataType::Utf8, true), + ])); + + let group_by = PhysicalGroupBy::new_single(vec![ + (Arc::new(Column::new("col_1", 1)), "col_1".to_string()), + (Arc::new(Column::new("col_7", 7)), "col_7".to_string()), + (Arc::new(Column::new("col_0", 0)), "col_0".to_string()), + (Arc::new(Column::new("col_8", 8)), "col_8".to_string()), + ]); + + fn generate_int64_array() -> ArrayRef { + Arc::new(Int64Array::from_iter_values( + (0..1024).map(|_| random::()), + )) + } + fn generate_int32_array() -> ArrayRef { + Arc::new(Int32Array::from_iter_values( + (0..1024).map(|_| random::()), + )) + } + + fn generate_string_array() -> ArrayRef { + Arc::new(StringArray::from( + (0..1024) + .map(|_| -> String { + rng() + .sample_iter::(rand::distr::StandardUniform) + .take(5) + .collect() + }) + .collect::>(), + )) + } + + fn generate_record_batch(schema: &SchemaRef) -> Result { + RecordBatch::try_new( + Arc::clone(schema), + vec![ + generate_int64_array(), + generate_string_array(), + generate_string_array(), + generate_string_array(), + generate_string_array(), + generate_int32_array(), + generate_string_array(), + generate_string_array(), + generate_string_array(), + ], + ) + .map_err(|err| err.into()) + } + + let aggregate_expressions = vec![Arc::new( + AggregateExprBuilder::new(sum_udaf(), vec![lit(1i64)]) + .schema(Arc::clone(&scan_schema)) + .alias("SUM(1i64)") + .build()?, + )]; + + let batches = (0..5) + .map(|_| generate_record_batch(&scan_schema)) + .collect::>>()?; + + let plan: Arc = + MemorySourceConfig::try_new_exec(&[batches], Arc::clone(&scan_schema), None) + .unwrap(); + + let single_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + group_by, + aggregate_expressions.clone(), + vec![None; aggregate_expressions.len()], + plan, + Arc::clone(&scan_schema), + )?); + + let memory_pool = Arc::new(FairSpillPool::new(250000)); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(248)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )), + ); + + datafusion_physical_plan::common::collect( + single_aggregate.execute(0, Arc::clone(&task_ctx))?, + ) + .await?; + + assert_spill_count_metric(true, single_aggregate); + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs index 8a8aa180b3c44..2abfcd8417cbc 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -25,7 +25,7 @@ use datafusion_catalog::TableProvider; use datafusion_common::ScalarValue; use datafusion_common::{error::Result, utils::get_available_parallelism}; use datafusion_expr::col; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use crate::fuzz_cases::aggregation_fuzzer::data_generator::Dataset; @@ -43,7 +43,7 @@ use crate::fuzz_cases::aggregation_fuzzer::data_generator::Dataset; /// - `skip_partial parameters` /// - hint `sorted` or not /// - `spilling` or not (TODO, I think a special `MemoryPool` may be needed -/// to support this) +/// to support this) /// pub struct SessionContextGenerator { /// Current testing dataset @@ -112,7 +112,7 @@ impl SessionContextGenerator { /// Randomly generate session context pub fn generate(&self) -> Result { - let mut rng = thread_rng(); + let mut rng = rng(); let schema = self.dataset.batches[0].schema(); let batches = self.dataset.batches.clone(); let provider = MemTable::try_new(schema, vec![batches])?; @@ -123,17 +123,17 @@ impl SessionContextGenerator { // - `skip_partial`, trigger or not trigger currently for simplicity // - `sorted`, if found a sorted dataset, will or will not push down this information // - `spilling`(TODO) - let batch_size = rng.gen_range(1..=self.max_batch_size); + let batch_size = rng.random_range(1..=self.max_batch_size); - let target_partitions = rng.gen_range(1..=self.max_target_partitions); + let target_partitions = rng.random_range(1..=self.max_target_partitions); let skip_partial_params_idx = - rng.gen_range(0..self.candidate_skip_partial_params.len()); + rng.random_range(0..self.candidate_skip_partial_params.len()); let skip_partial_params = self.candidate_skip_partial_params[skip_partial_params_idx]; let (provider, sort_hint) = - if rng.gen_bool(0.5) && !self.dataset.sort_keys.is_empty() { + if rng.random_bool(0.5) && !self.dataset.sort_keys.is_empty() { // Sort keys exist and random to push down let sort_exprs = self .dataset diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs index d61835a0804ed..753a74995d8ff 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -15,34 +15,15 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - -use arrow::array::{ArrayRef, RecordBatch}; -use arrow::datatypes::{ - BinaryType, BinaryViewType, BooleanType, ByteArrayType, ByteViewType, DataType, - Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, - IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, LargeBinaryType, - LargeUtf8Type, Schema, StringViewType, Time32MillisecondType, Time32SecondType, - Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, Utf8Type, -}; -use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; +use arrow::array::RecordBatch; +use arrow::datatypes::DataType; +use datafusion_common::Result; use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::sorts::sort::sort_batch; -use rand::{ - rngs::{StdRng, ThreadRng}, - thread_rng, Rng, SeedableRng, -}; -use test_utils::{ - array_gen::{ - BinaryArrayGenerator, BooleanArrayGenerator, DecimalArrayGenerator, - PrimitiveArrayGenerator, StringArrayGenerator, - }, - stagger_batch, -}; +use test_utils::stagger_batch; + +use crate::fuzz_cases::record_batch_generator::{ColumnDescr, RecordBatchGenerator}; /// Config for Dataset generator /// @@ -52,12 +33,12 @@ use test_utils::{ /// when you call `generate` function /// /// - `rows_num_range`, the number of rows in the datasets will be randomly generated -/// within this range +/// within this range /// /// - `sort_keys`, if `sort_keys` are defined, when you call the `generate` function, the generator -/// will generate one `base dataset` firstly. Then the `base dataset` will be sorted -/// based on each `sort_key` respectively. And finally `len(sort_keys) + 1` datasets -/// will be returned +/// will generate one `base dataset` firstly. Then the `base dataset` will be sorted +/// based on each `sort_key` respectively. And finally `len(sort_keys) + 1` datasets +/// will be returned /// #[derive(Debug, Clone)] pub struct DatasetGeneratorConfig { @@ -154,7 +135,7 @@ impl DatasetGenerator { } } - pub fn generate(&self) -> Result> { + pub fn generate(&mut self) -> Result> { let mut datasets = Vec::with_capacity(self.sort_keys_set.len() + 1); // Generate the base batch (unsorted) @@ -168,14 +149,14 @@ impl DatasetGenerator { for sort_keys in self.sort_keys_set.clone() { let sort_exprs = sort_keys .iter() - .map(|key| { - let col_expr = col(key, schema)?; - Ok(PhysicalSortExpr::new_default(col_expr)) - }) - .collect::>()?; - let sorted_batch = sort_batch(&base_batch, sort_exprs.as_ref(), None)?; - - let batches = stagger_batch(sorted_batch); + .map(|key| col(key, schema).map(PhysicalSortExpr::new_default)) + .collect::>>()?; + let batch = if let Some(ordering) = LexOrdering::new(sort_exprs) { + sort_batch(&base_batch, &ordering, None)? + } else { + base_batch.clone() + }; + let batches = stagger_batch(batch); let dataset = Dataset::new(batches, sort_keys); datasets.push(dataset); } @@ -204,553 +185,6 @@ impl Dataset { } } -#[derive(Debug, Clone)] -pub struct ColumnDescr { - /// Column name - name: String, - - /// Data type of this column - column_type: DataType, - - /// The maximum number of distinct values in this column. - /// - /// See [`ColumnDescr::with_max_num_distinct`] for more information - max_num_distinct: Option, -} - -impl ColumnDescr { - #[inline] - pub fn new(name: &str, column_type: DataType) -> Self { - Self { - name: name.to_string(), - column_type, - max_num_distinct: None, - } - } - - pub fn get_max_num_distinct(&self) -> Option { - self.max_num_distinct - } - - /// set the maximum number of distinct values in this column - /// - /// If `None`, the number of distinct values is randomly selected between 1 - /// and the number of rows. - pub fn with_max_num_distinct(mut self, num_distinct: usize) -> Self { - self.max_num_distinct = Some(num_distinct); - self - } -} - -/// Record batch generator -struct RecordBatchGenerator { - min_rows_nun: usize, - - max_rows_num: usize, - - columns: Vec, - - candidate_null_pcts: Vec, -} - -macro_rules! generate_string_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE: ident) => {{ - let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - let max_len = $BATCH_GEN_RNG.gen_range(1..50); - - let mut generator = StringArrayGenerator { - max_len, - num_strings: $NUM_ROWS, - num_distinct_strings: $MAX_NUM_DISTINCT, - null_pct, - rng: $ARRAY_GEN_RNG, - }; - - match $ARROW_TYPE::DATA_TYPE { - DataType::Utf8 => generator.gen_data::(), - DataType::LargeUtf8 => generator.gen_data::(), - DataType::Utf8View => generator.gen_string_view(), - _ => unreachable!(), - } - }}; -} - -macro_rules! generate_decimal_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT: expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $PRECISION: ident, $SCALE: ident, $ARROW_TYPE: ident) => {{ - let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - - let mut generator = DecimalArrayGenerator { - precision: $PRECISION, - scale: $SCALE, - num_decimals: $NUM_ROWS, - num_distinct_decimals: $MAX_NUM_DISTINCT, - null_pct, - rng: $ARRAY_GEN_RNG, - }; - - generator.gen_data::<$ARROW_TYPE>() - }}; -} - -// Generating `BooleanArray` due to it being a special type in Arrow (bit-packed) -macro_rules! generate_boolean_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE: ident) => {{ - // Select a null percentage from the candidate percentages - let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - - let num_distinct_booleans = if $MAX_NUM_DISTINCT >= 2 { 2 } else { 1 }; - - let mut generator = BooleanArrayGenerator { - num_booleans: $NUM_ROWS, - num_distinct_booleans, - null_pct, - rng: $ARRAY_GEN_RNG, - }; - - generator.gen_data::<$ARROW_TYPE>() - }}; -} - -macro_rules! generate_primitive_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => {{ - let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - - let mut generator = PrimitiveArrayGenerator { - num_primitives: $NUM_ROWS, - num_distinct_primitives: $MAX_NUM_DISTINCT, - null_pct, - rng: $ARRAY_GEN_RNG, - }; - - generator.gen_data::<$ARROW_TYPE>() - }}; -} - -macro_rules! generate_binary_array { - ( - $SELF:ident, - $NUM_ROWS:ident, - $MAX_NUM_DISTINCT:expr, - $BATCH_GEN_RNG:ident, - $ARRAY_GEN_RNG:ident, - $ARROW_TYPE:ident - ) => {{ - let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - - let max_len = $BATCH_GEN_RNG.gen_range(1..100); - - let mut generator = BinaryArrayGenerator { - max_len, - num_binaries: $NUM_ROWS, - num_distinct_binaries: $MAX_NUM_DISTINCT, - null_pct, - rng: $ARRAY_GEN_RNG, - }; - - match $ARROW_TYPE::DATA_TYPE { - DataType::Binary => generator.gen_data::(), - DataType::LargeBinary => generator.gen_data::(), - DataType::BinaryView => generator.gen_binary_view(), - _ => unreachable!(), - } - }}; -} - -impl RecordBatchGenerator { - fn new(min_rows_nun: usize, max_rows_num: usize, columns: Vec) -> Self { - let candidate_null_pcts = vec![0.0, 0.01, 0.1, 0.5]; - - Self { - min_rows_nun, - max_rows_num, - columns, - candidate_null_pcts, - } - } - - fn generate(&self) -> Result { - let mut rng = thread_rng(); - let num_rows = rng.gen_range(self.min_rows_nun..=self.max_rows_num); - let array_gen_rng = StdRng::from_seed(rng.gen()); - - // Build arrays - let mut arrays = Vec::with_capacity(self.columns.len()); - for col in self.columns.iter() { - let array = self.generate_array_of_type( - col, - num_rows, - &mut rng, - array_gen_rng.clone(), - ); - arrays.push(array); - } - - // Build schema - let fields = self - .columns - .iter() - .map(|col| Field::new(col.name.clone(), col.column_type.clone(), true)) - .collect::>(); - let schema = Arc::new(Schema::new(fields)); - - RecordBatch::try_new(schema, arrays).map_err(|e| arrow_datafusion_err!(e)) - } - - fn generate_array_of_type( - &self, - col: &ColumnDescr, - num_rows: usize, - batch_gen_rng: &mut ThreadRng, - array_gen_rng: StdRng, - ) -> ArrayRef { - let num_distinct = if num_rows > 1 { - batch_gen_rng.gen_range(1..num_rows) - } else { - num_rows - }; - // cap to at most the num_distinct values - let max_num_distinct = col - .max_num_distinct - .map(|max| num_distinct.min(max)) - .unwrap_or(num_distinct); - - match col.column_type { - DataType::Int8 => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - Int8Type - ) - } - DataType::Int16 => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - Int16Type - ) - } - DataType::Int32 => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - Int32Type - ) - } - DataType::Int64 => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - Int64Type - ) - } - DataType::UInt8 => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - UInt8Type - ) - } - DataType::UInt16 => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - UInt16Type - ) - } - DataType::UInt32 => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - UInt32Type - ) - } - DataType::UInt64 => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - UInt64Type - ) - } - DataType::Float32 => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - Float32Type - ) - } - DataType::Float64 => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - Float64Type - ) - } - DataType::Date32 => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - Date32Type - ) - } - DataType::Date64 => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - Date64Type - ) - } - DataType::Time32(TimeUnit::Second) => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - Time32SecondType - ) - } - DataType::Time32(TimeUnit::Millisecond) => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - Time32MillisecondType - ) - } - DataType::Time64(TimeUnit::Microsecond) => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - Time64MicrosecondType - ) - } - DataType::Time64(TimeUnit::Nanosecond) => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - Time64NanosecondType - ) - } - DataType::Interval(IntervalUnit::YearMonth) => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - IntervalYearMonthType - ) - } - DataType::Interval(IntervalUnit::DayTime) => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - IntervalDayTimeType - ) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - IntervalMonthDayNanoType - ) - } - DataType::Timestamp(TimeUnit::Second, None) => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - TimestampSecondType - ) - } - DataType::Timestamp(TimeUnit::Millisecond, None) => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - TimestampMillisecondType - ) - } - DataType::Timestamp(TimeUnit::Microsecond, None) => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - TimestampMicrosecondType - ) - } - DataType::Timestamp(TimeUnit::Nanosecond, None) => { - generate_primitive_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - TimestampNanosecondType - ) - } - DataType::Binary => { - generate_binary_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - BinaryType - ) - } - DataType::LargeBinary => { - generate_binary_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - LargeBinaryType - ) - } - DataType::BinaryView => { - generate_binary_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - BinaryViewType - ) - } - DataType::Decimal128(precision, scale) => { - generate_decimal_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - precision, - scale, - Decimal128Type - ) - } - DataType::Decimal256(precision, scale) => { - generate_decimal_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - precision, - scale, - Decimal256Type - ) - } - DataType::Utf8 => { - generate_string_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - Utf8Type - ) - } - DataType::LargeUtf8 => { - generate_string_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - LargeUtf8Type - ) - } - DataType::Utf8View => { - generate_string_array!( - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - StringViewType - ) - } - DataType::Boolean => { - generate_boolean_array! { - self, - num_rows, - max_num_distinct, - batch_gen_rng, - array_gen_rng, - BooleanType - } - } - _ => { - panic!("Unsupported data generator type: {}", col.column_type) - } - } - } -} - #[cfg(test)] mod test { use arrow::array::UInt32Array; @@ -777,7 +211,7 @@ mod test { sort_keys_set: vec![vec!["b".to_string()]], }; - let gen = DatasetGenerator::new(config); + let mut gen = DatasetGenerator::new(config); let datasets = gen.generate().unwrap(); // Should Generate 2 datasets diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs index bb24fb554d65a..b90b3e5e32df7 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -16,15 +16,14 @@ // under the License. use std::sync::Arc; -use std::{collections::HashSet, str::FromStr}; use arrow::array::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_datafusion_err, Result}; use datafusion_common_runtime::JoinSet; -use rand::seq::SliceRandom; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; +use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder; use crate::fuzz_cases::aggregation_fuzzer::{ check_equality_of_batches, context_generator::{SessionContextGenerator, SessionContextWithParams}, @@ -69,30 +68,16 @@ impl AggregationFuzzerBuilder { /// - 3 random queries /// - 3 random queries for each group by selected from the sort keys /// - 1 random query with no grouping - pub fn add_query_builder(mut self, mut query_builder: QueryBuilder) -> Self { - const NUM_QUERIES: usize = 3; - for _ in 0..NUM_QUERIES { - let sql = query_builder.generate_query(); - self.candidate_sqls.push(Arc::from(sql)); - } - // also add several queries limited to grouping on the group by columns only, if any - // So if the data is sorted on `a,b` only group by `a,b` or`a` or `b` - if let Some(data_gen_config) = &self.data_gen_config { - for sort_keys in &data_gen_config.sort_keys_set { - let group_by_columns = sort_keys.iter().map(|s| s.as_str()); - query_builder = query_builder.set_group_by_columns(group_by_columns); - for _ in 0..NUM_QUERIES { - let sql = query_builder.generate_query(); - self.candidate_sqls.push(Arc::from(sql)); - } - } - } - // also add a query with no grouping - query_builder = query_builder.set_group_by_columns(vec![]); - let sql = query_builder.generate_query(); - self.candidate_sqls.push(Arc::from(sql)); + pub fn add_query_builder(mut self, query_builder: QueryBuilder) -> Self { + self = self.table_name(query_builder.table_name()); + + let sqls = query_builder + .generate_queries() + .into_iter() + .map(|sql| Arc::from(sql.as_str())); + self.candidate_sqls.extend(sqls); - self.table_name(query_builder.table_name()) + self } pub fn table_name(mut self, table_name: &str) -> Self { @@ -164,7 +149,7 @@ struct QueryGroup { impl AggregationFuzzer { /// Run the fuzzer, printing an error and panicking if any of the tasks fail - pub async fn run(&self) { + pub async fn run(&mut self) { let res = self.run_inner().await; if let Err(e) = res { @@ -176,9 +161,9 @@ impl AggregationFuzzer { } } - async fn run_inner(&self) -> Result<()> { + async fn run_inner(&mut self) -> Result<()> { let mut join_set = JoinSet::new(); - let mut rng = thread_rng(); + let mut rng = rng(); // Loop to generate datasets and its query for _ in 0..self.data_gen_rounds { @@ -186,13 +171,13 @@ impl AggregationFuzzer { let datasets = self .dataset_generator .generate() - .expect("should success to generate dataset"); + .expect("should succeed to generate dataset"); // Then for each of them, we random select a test sql for it let query_groups = datasets .into_iter() .map(|dataset| { - let sql_idx = rng.gen_range(0..self.candidate_sqls.len()); + let sql_idx = rng.random_range(0..self.candidate_sqls.len()); let sql = self.candidate_sqls[sql_idx].clone(); QueryGroup { dataset, sql } @@ -212,10 +197,7 @@ impl AggregationFuzzer { while let Some(join_handle) = join_set.join_next().await { // propagate errors join_handle.map_err(|e| { - DataFusionError::Internal(format!( - "AggregationFuzzer task error: {:?}", - e - )) + internal_datafusion_err!("AggregationFuzzer task error: {e:?}") })??; } Ok(()) @@ -234,16 +216,16 @@ impl AggregationFuzzer { // Generate the baseline context, and get the baseline result firstly let baseline_ctx_with_params = ctx_generator .generate_baseline() - .expect("should success to generate baseline session context"); + .expect("should succeed to generate baseline session context"); let baseline_result = run_sql(&sql, &baseline_ctx_with_params.ctx) .await - .expect("should success to run baseline sql"); + .expect("should succeed to run baseline sql"); let baseline_result = Arc::new(baseline_result); // Generate test tasks for _ in 0..CTX_GEN_ROUNDS { let ctx_with_params = ctx_generator .generate() - .expect("should success to generate session context"); + .expect("should succeed to generate session context"); let task = AggregationFuzzTestTask { dataset_ref: dataset_ref.clone(), expected_result: baseline_result.clone(), @@ -270,7 +252,7 @@ impl AggregationFuzzer { /// - `sql`, the selected test sql /// /// - `dataset_ref`, the input dataset, store it for error reported when found -/// the inconsistency between the one for `ctx` and `expected results`. +/// the inconsistency between the one for `ctx` and `expected results`. /// struct AggregationFuzzTestTask { /// Generated session context in current test case @@ -326,7 +308,7 @@ impl AggregationFuzzTestTask { format_batches_with_limit(expected_result), format_batches_with_limit(&self.dataset_ref.batches), ); - DataFusionError::Internal(message) + internal_datafusion_err!("{message}") }) } @@ -371,215 +353,3 @@ fn format_batches_with_limit(batches: &[RecordBatch]) -> impl std::fmt::Display pretty_format_batches(&to_print).unwrap() } - -/// Random aggregate query builder -/// -/// Creates queries like -/// ```sql -/// SELECT AGG(..) FROM table_name GROUP BY -///``` -#[derive(Debug, Default, Clone)] -pub struct QueryBuilder { - /// The name of the table to query - table_name: String, - /// Aggregate functions to be used in the query - /// (function_name, is_distinct) - aggregate_functions: Vec<(String, bool)>, - /// Columns to be used in group by - group_by_columns: Vec, - /// Possible columns for arguments in the aggregate functions - /// - /// Assumes each - arguments: Vec, -} -impl QueryBuilder { - pub fn new() -> Self { - Default::default() - } - - /// return the table name if any - pub fn table_name(&self) -> &str { - &self.table_name - } - - /// Set the table name for the query builder - pub fn with_table_name(mut self, table_name: impl Into) -> Self { - self.table_name = table_name.into(); - self - } - - /// Add a new possible aggregate function to the query builder - pub fn with_aggregate_function( - mut self, - aggregate_function: impl Into, - ) -> Self { - self.aggregate_functions - .push((aggregate_function.into(), false)); - self - } - - /// Add a new possible `DISTINCT` aggregate function to the query - /// - /// This is different than `with_aggregate_function` because only certain - /// aggregates support `DISTINCT` - pub fn with_distinct_aggregate_function( - mut self, - aggregate_function: impl Into, - ) -> Self { - self.aggregate_functions - .push((aggregate_function.into(), true)); - self - } - - /// Set the columns to be used in the group bys clauses - pub fn set_group_by_columns<'a>( - mut self, - group_by: impl IntoIterator, - ) -> Self { - self.group_by_columns = group_by.into_iter().map(String::from).collect(); - self - } - - /// Add one or more columns to be used as an argument in the aggregate functions - pub fn with_aggregate_arguments<'a>( - mut self, - arguments: impl IntoIterator, - ) -> Self { - let arguments = arguments.into_iter().map(String::from); - self.arguments.extend(arguments); - self - } - - pub fn generate_query(&self) -> String { - let group_by = self.random_group_by(); - let mut query = String::from("SELECT "); - query.push_str(&group_by.join(", ")); - if !group_by.is_empty() { - query.push_str(", "); - } - query.push_str(&self.random_aggregate_functions(&group_by).join(", ")); - query.push_str(" FROM "); - query.push_str(&self.table_name); - if !group_by.is_empty() { - query.push_str(" GROUP BY "); - query.push_str(&group_by.join(", ")); - } - query - } - - /// Generate a some random aggregate function invocations (potentially repeating). - /// - /// Each aggregate function invocation is of the form - /// - /// ```sql - /// function_name( argument) as alias - /// ``` - /// - /// where - /// * `function_names` are randomly selected from [`Self::aggregate_functions`] - /// * ` argument` is randomly selected from [`Self::arguments`] - /// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names) - fn random_aggregate_functions(&self, group_by_cols: &[String]) -> Vec { - const MAX_NUM_FUNCTIONS: usize = 5; - let mut rng = thread_rng(); - let num_aggregate_functions = rng.gen_range(1..MAX_NUM_FUNCTIONS); - - let mut alias_gen = 1; - - let mut aggregate_functions = vec![]; - - let mut order_by_black_list: HashSet = - group_by_cols.iter().cloned().collect(); - // remove one random col - if let Some(first) = order_by_black_list.iter().next().cloned() { - order_by_black_list.remove(&first); - } - - while aggregate_functions.len() < num_aggregate_functions { - let idx = rng.gen_range(0..self.aggregate_functions.len()); - let (function_name, is_distinct) = &self.aggregate_functions[idx]; - let argument = self.random_argument(); - let alias = format!("col{}", alias_gen); - let distinct = if *is_distinct { "DISTINCT " } else { "" }; - alias_gen += 1; - - let (order_by, null_opt) = if function_name.eq("first_value") { - ( - self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */ - self.null_opt(), - ) - } else { - ("".to_string(), "".to_string()) - }; - - let function = format!( - "{function_name}({distinct}{argument}{order_by}) {null_opt} as {alias}" - ); - aggregate_functions.push(function); - } - aggregate_functions - } - - /// Pick a random aggregate function argument - fn random_argument(&self) -> String { - let mut rng = thread_rng(); - let idx = rng.gen_range(0..self.arguments.len()); - self.arguments[idx].clone() - } - - fn order_by(&self, black_list: &HashSet) -> String { - let mut available_columns: Vec = self - .arguments - .iter() - .filter(|col| !black_list.contains(*col)) - .cloned() - .collect(); - - available_columns.shuffle(&mut thread_rng()); - - let num_of_order_by_col = 12; - let column_count = std::cmp::min(num_of_order_by_col, available_columns.len()); - - let selected_columns = &available_columns[0..column_count]; - - let mut rng = thread_rng(); - let mut result = String::from_str(" order by ").unwrap(); - for col in selected_columns { - let order = if rng.gen_bool(0.5) { "ASC" } else { "DESC" }; - result.push_str(&format!("{} {},", col, order)); - } - - result.strip_suffix(",").unwrap().to_string() - } - - fn null_opt(&self) -> String { - if thread_rng().gen_bool(0.5) { - "RESPECT NULLS".to_string() - } else { - "IGNORE NULLS".to_string() - } - } - - /// Pick a random number of fields to group by (non-repeating) - /// - /// Limited to 3 group by columns to ensure coverage for large groups. With - /// larger numbers of columns, each group has many fewer values. - fn random_group_by(&self) -> Vec { - let mut rng = thread_rng(); - const MAX_GROUPS: usize = 3; - let max_groups = self.group_by_columns.len().max(MAX_GROUPS); - let num_group_by = rng.gen_range(1..max_groups); - - let mut already_used = HashSet::new(); - let mut group_by = vec![]; - while group_by.len() < num_group_by - && already_used.len() != self.group_by_columns.len() - { - let idx = rng.gen_range(0..self.group_by_columns.len()); - if already_used.insert(idx) { - group_by.push(self.group_by_columns[idx].clone()); - } - } - group_by - } -} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs index 1e42ac1f4b30b..e7ce557d2267d 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs @@ -43,8 +43,10 @@ use datafusion_common::error::Result; mod context_generator; mod data_generator; mod fuzzer; +pub mod query_builder; -pub use data_generator::{ColumnDescr, DatasetGeneratorConfig}; +pub use crate::fuzz_cases::record_batch_generator::ColumnDescr; +pub use data_generator::DatasetGeneratorConfig; pub use fuzzer::*; #[derive(Debug)] @@ -75,8 +77,8 @@ pub(crate) fn check_equality_of_batches( if lhs_row != rhs_row { return Err(InconsistentResult { row_idx, - lhs_row: lhs_row.to_string(), - rhs_row: rhs_row.to_string(), + lhs_row: (*lhs_row).to_string(), + rhs_row: (*rhs_row).to_string(), }); } } diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs new file mode 100644 index 0000000000000..209278385b7b5 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs @@ -0,0 +1,384 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{collections::HashSet, str::FromStr}; + +use rand::{rng, seq::SliceRandom, Rng}; + +/// Random aggregate query builder +/// +/// Creates queries like +/// ```sql +/// SELECT AGG(..) FROM table_name GROUP BY +///``` +#[derive(Debug, Default, Clone)] +pub struct QueryBuilder { + // =================================== + // Table settings + // =================================== + /// The name of the table to query + table_name: String, + + // =================================== + // Grouping settings + // =================================== + /// Columns to be used in randomly generate `groupings` + /// + /// # Example + /// + /// Columns: + /// + /// ```text + /// [a,b,c,d] + /// ``` + /// + /// And randomly generated `groupings` (at least 1 column) + /// can be: + /// + /// ```text + /// [a] + /// [a,b] + /// [a,b,d] + /// ... + /// ``` + /// + /// So the finally generated sqls will be: + /// + /// ```text + /// SELECT aggr FROM t GROUP BY a; + /// SELECT aggr FROM t GROUP BY a,b; + /// SELECT aggr FROM t GROUP BY a,b,d; + /// ... + /// ``` + group_by_columns: Vec, + + /// Max columns num in randomly generated `groupings` + max_group_by_columns: usize, + + /// Min columns num in randomly generated `groupings` + min_group_by_columns: usize, + + /// The sort keys of dataset + /// + /// Due to optimizations will be triggered when all or some + /// grouping columns are the sort keys of dataset. + /// So it is necessary to randomly generate some `groupings` basing on + /// dataset sort keys for test coverage. + /// + /// # Example + /// + /// Dataset including columns [a,b,c], and sorted by [a,b] + /// + /// And we may generate sqls to try covering the sort-optimization cases like: + /// + /// ```text + /// SELECT aggr FROM t GROUP BY b; // no permutation case + /// SELECT aggr FROM t GROUP BY a,c; // partial permutation case + /// SELECT aggr FROM t GROUP BY a,b,c; // full permutation case + /// ... + /// ``` + /// + /// More details can see [`GroupOrdering`]. + /// + /// [`GroupOrdering`]: datafusion_physical_plan::aggregates::order::GroupOrdering + /// + dataset_sort_keys: Vec>, + + /// If we will also test the no grouping case like: + /// + /// ```text + /// SELECT aggr FROM t; + /// ``` + /// + no_grouping: bool, + + // ==================================== + // Aggregation function settings + // ==================================== + /// Aggregate functions to be used in the query + /// (function_name, is_distinct) + aggregate_functions: Vec<(String, bool)>, + + /// Possible columns for arguments in the aggregate functions + /// + /// Assumes each + arguments: Vec, +} + +impl QueryBuilder { + pub fn new() -> Self { + Self { + no_grouping: true, + max_group_by_columns: 5, + min_group_by_columns: 1, + ..Default::default() + } + } + + /// return the table name if any + pub fn table_name(&self) -> &str { + &self.table_name + } + + /// Set the table name for the query builder + pub fn with_table_name(mut self, table_name: impl Into) -> Self { + self.table_name = table_name.into(); + self + } + + /// Add a new possible aggregate function to the query builder + pub fn with_aggregate_function( + mut self, + aggregate_function: impl Into, + ) -> Self { + self.aggregate_functions + .push((aggregate_function.into(), false)); + self + } + + /// Add a new possible `DISTINCT` aggregate function to the query + /// + /// This is different than `with_aggregate_function` because only certain + /// aggregates support `DISTINCT` + pub fn with_distinct_aggregate_function( + mut self, + aggregate_function: impl Into, + ) -> Self { + self.aggregate_functions + .push((aggregate_function.into(), true)); + self + } + + /// Set the columns to be used in the group bys clauses + pub fn set_group_by_columns<'a>( + mut self, + group_by: impl IntoIterator, + ) -> Self { + self.group_by_columns = group_by.into_iter().map(String::from).collect(); + self + } + + /// Add one or more columns to be used as an argument in the aggregate functions + pub fn with_aggregate_arguments<'a>( + mut self, + arguments: impl IntoIterator, + ) -> Self { + let arguments = arguments.into_iter().map(String::from); + self.arguments.extend(arguments); + self + } + + /// Add max columns num in group by(default: 3), for example if it is set to 1, + /// the generated sql will group by at most 1 column + #[allow(dead_code)] + pub fn with_max_group_by_columns(mut self, max_group_by_columns: usize) -> Self { + self.max_group_by_columns = max_group_by_columns; + self + } + + #[allow(dead_code)] + pub fn with_min_group_by_columns(mut self, min_group_by_columns: usize) -> Self { + self.min_group_by_columns = min_group_by_columns; + self + } + + /// Add sort keys of dataset if any, then the builder will generate queries basing on it + /// to cover the sort-optimization cases + pub fn with_dataset_sort_keys(mut self, dataset_sort_keys: Vec>) -> Self { + self.dataset_sort_keys = dataset_sort_keys; + self + } + + /// Add if also test the no grouping aggregation case(default: true) + #[allow(dead_code)] + pub fn with_no_grouping(mut self, no_grouping: bool) -> Self { + self.no_grouping = no_grouping; + self + } + + pub fn generate_queries(mut self) -> Vec { + const NUM_QUERIES: usize = 3; + let mut sqls = Vec::new(); + + // Add several queries group on randomly picked columns + for _ in 0..NUM_QUERIES { + let sql = self.generate_query(); + sqls.push(sql); + } + + // Also add several queries limited to grouping on the group by + // dataset sorted columns only, if any. + // So if the data is sorted on `a,b` only group by `a,b` or`a` or `b`. + if !self.dataset_sort_keys.is_empty() { + let dataset_sort_keys = self.dataset_sort_keys.clone(); + for sort_keys in dataset_sort_keys { + let group_by_columns = sort_keys.iter().map(|s| s.as_str()); + self = self.set_group_by_columns(group_by_columns); + for _ in 0..NUM_QUERIES { + let sql = self.generate_query(); + sqls.push(sql); + } + } + } + + // Also add a query with no grouping + if self.no_grouping { + self = self.set_group_by_columns(vec![]); + let sql = self.generate_query(); + sqls.push(sql); + } + + sqls + } + + fn generate_query(&self) -> String { + let group_by = self.random_group_by(); + dbg!(&group_by); + let mut query = String::from("SELECT "); + query.push_str(&group_by.join(", ")); + if !group_by.is_empty() { + query.push_str(", "); + } + query.push_str(&self.random_aggregate_functions(&group_by).join(", ")); + query.push_str(" FROM "); + query.push_str(&self.table_name); + if !group_by.is_empty() { + query.push_str(" GROUP BY "); + query.push_str(&group_by.join(", ")); + } + query + } + + /// Generate a some random aggregate function invocations (potentially repeating). + /// + /// Each aggregate function invocation is of the form + /// + /// ```sql + /// function_name( argument) as alias + /// ``` + /// + /// where + /// * `function_names` are randomly selected from [`Self::aggregate_functions`] + /// * ` argument` is randomly selected from [`Self::arguments`] + /// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names) + fn random_aggregate_functions(&self, group_by_cols: &[String]) -> Vec { + const MAX_NUM_FUNCTIONS: usize = 5; + let mut rng = rng(); + let num_aggregate_functions = rng.random_range(1..=MAX_NUM_FUNCTIONS); + + let mut alias_gen = 1; + + let mut aggregate_functions = vec![]; + + let mut order_by_black_list: HashSet = + group_by_cols.iter().cloned().collect(); + // remove one random col + if let Some(first) = order_by_black_list.iter().next().cloned() { + order_by_black_list.remove(&first); + } + + while aggregate_functions.len() < num_aggregate_functions { + let idx = rng.random_range(0..self.aggregate_functions.len()); + let (function_name, is_distinct) = &self.aggregate_functions[idx]; + let argument = self.random_argument(); + let alias = format!("col{alias_gen}"); + let distinct = if *is_distinct { "DISTINCT " } else { "" }; + alias_gen += 1; + + let (order_by, null_opt) = if function_name.eq("first_value") + || function_name.eq("last_value") + { + ( + self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */ + self.null_opt(), + ) + } else { + ("".to_string(), "".to_string()) + }; + + let function = format!( + "{function_name}({distinct}{argument}{order_by}) {null_opt} as {alias}" + ); + aggregate_functions.push(function); + } + aggregate_functions + } + + /// Pick a random aggregate function argument + fn random_argument(&self) -> String { + let mut rng = rng(); + let idx = rng.random_range(0..self.arguments.len()); + self.arguments[idx].clone() + } + + fn order_by(&self, black_list: &HashSet) -> String { + let mut available_columns: Vec = self + .arguments + .iter() + .filter(|col| !black_list.contains(*col)) + .cloned() + .collect(); + + available_columns.shuffle(&mut rng()); + + let num_of_order_by_col = 12; + let column_count = std::cmp::min(num_of_order_by_col, available_columns.len()); + + let selected_columns = &available_columns[0..column_count]; + + let mut rng = rng(); + let mut result = String::from_str(" order by ").unwrap(); + for col in selected_columns { + let order = if rng.random_bool(0.5) { "ASC" } else { "DESC" }; + result.push_str(&format!("{col} {order},")); + } + + result.strip_suffix(",").unwrap().to_string() + } + + fn null_opt(&self) -> String { + if rng().random_bool(0.5) { + "RESPECT NULLS".to_string() + } else { + "IGNORE NULLS".to_string() + } + } + + /// Pick a random number of fields to group by (non-repeating) + /// + /// Limited to `max_group_by_columns` group by columns to ensure coverage for large groups. + /// With larger numbers of columns, each group has many fewer values. + fn random_group_by(&self) -> Vec { + let mut rng = rng(); + let min_groups = self.min_group_by_columns; + let max_groups = self.max_group_by_columns; + assert!(min_groups <= max_groups); + let num_group_by = rng.random_range(min_groups..=max_groups); + + let mut already_used = HashSet::new(); + let mut group_by = vec![]; + while group_by.len() < num_group_by + && already_used.len() != self.group_by_columns.len() + { + let idx = rng.random_range(0..self.group_by_columns.len()); + if already_used.insert(idx) { + group_by.push(self.group_by_columns[idx].clone()); + } + } + group_by + } +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs index 769deef1187d6..171839b390ffa 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -16,13 +16,17 @@ // under the License. use crate::fuzz_cases::equivalence::utils::{ - convert_to_orderings, create_random_schema, create_test_params, create_test_schema_2, + create_random_schema, create_test_params, create_test_schema_2, generate_table_for_eq_properties, generate_table_for_orderings, is_table_same_after_sort, TestScalarUDF, }; use arrow::compute::SortOptions; +use datafusion_common::config::ConfigOptions; use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::equivalence::{ + convert_to_orderings, convert_to_sort_exprs, +}; use datafusion_physical_expr::expressions::{col, BinaryExpr}; use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -55,30 +59,27 @@ fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { col("f", &test_schema)?, ]; - for n_req in 0..=col_exprs.len() { + for n_req in 1..=col_exprs.len() { for exprs in col_exprs.iter().combinations(n_req) { - let requirement = exprs + let sort_exprs = exprs .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::(); + .map(|expr| PhysicalSortExpr::new(Arc::clone(expr), SORT_OPTIONS)); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + unreachable!("Test should always produce non-degenerate orderings"); + }; let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), + ordering.clone(), + &table_data_with_properties, )?; let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties {}", - requirement, expected, eq_properties + "Error in test case requirement:{ordering:?}, expected: {expected:?}, eq_properties {eq_properties}" ); // Check whether ordering_satisfy API result and // experimental result matches. assert_eq!( - eq_properties.ordering_satisfy(requirement.as_ref()), + eq_properties.ordering_satisfy(ordering)?, expected, - "{}", - err_msg + "{err_msg}" ); } } @@ -110,6 +111,7 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { Arc::clone(&test_fun), vec![col_a], &test_schema, + Arc::new(ConfigOptions::default()), )?); let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, @@ -127,31 +129,28 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { a_plus_b, ]; - for n_req in 0..=exprs.len() { + for n_req in 1..=exprs.len() { for exprs in exprs.iter().combinations(n_req) { - let requirement = exprs + let sort_exprs = exprs .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::(); + .map(|expr| PhysicalSortExpr::new(Arc::clone(expr), SORT_OPTIONS)); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + unreachable!("Test should always produce non-degenerate orderings"); + }; let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), + ordering.clone(), + &table_data_with_properties, )?; let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties: {}", - requirement, expected, eq_properties, + "Error in test case requirement:{ordering:?}, expected: {expected:?}, eq_properties: {eq_properties}", ); // Check whether ordering_satisfy API result and // experimental result matches. assert_eq!( - eq_properties.ordering_satisfy(requirement.as_ref()), + eq_properties.ordering_satisfy(ordering)?, (expected | false), - "{}", - err_msg + "{err_msg}" ); } } @@ -304,25 +303,19 @@ fn test_ordering_satisfy_with_equivalence() -> Result<()> { ]; for (cols, expected) in requirements { - let err_msg = format!("Error in test case:{cols:?}"); - let required = cols - .into_iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(expr), - options, - }) - .collect::(); + let err_msg = format!("Error in test case: {cols:?}"); + let sort_exprs = convert_to_sort_exprs(&cols); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + unreachable!("Test should always produce non-degenerate orderings"); + }; // Check expected result with experimental result. assert_eq!( - is_table_same_after_sort( - required.clone(), - table_data_with_properties.clone() - )?, + is_table_same_after_sort(ordering.clone(), &table_data_with_properties)?, expected ); assert_eq!( - eq_properties.ordering_satisfy(required.as_ref()), + eq_properties.ordering_satisfy(ordering)?, expected, "{err_msg}" ); @@ -375,7 +368,7 @@ fn test_ordering_satisfy_on_data() -> Result<()> { (col_d, option_asc), ]; let ordering = convert_to_orderings(&[ordering])[0].clone(); - assert!(!is_table_same_after_sort(ordering, batch.clone())?); + assert!(!is_table_same_after_sort(ordering, &batch)?); // [a ASC, b ASC, d ASC] cannot be deduced let ordering = vec![ @@ -384,12 +377,12 @@ fn test_ordering_satisfy_on_data() -> Result<()> { (col_d, option_asc), ]; let ordering = convert_to_orderings(&[ordering])[0].clone(); - assert!(!is_table_same_after_sort(ordering, batch.clone())?); + assert!(!is_table_same_after_sort(ordering, &batch)?); // [a ASC, b ASC] can be deduced let ordering = vec![(col_a, option_asc), (col_b, option_asc)]; let ordering = convert_to_orderings(&[ordering])[0].clone(); - assert!(is_table_same_after_sort(ordering, batch.clone())?); + assert!(is_table_same_after_sort(ordering, &batch)?); Ok(()) } diff --git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs index a3fa1157b38f4..69639b3e09fdf 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs @@ -20,6 +20,7 @@ use crate::fuzz_cases::equivalence::utils::{ is_table_same_after_sort, TestScalarUDF, }; use arrow::compute::SortOptions; +use datafusion_common::config::ConfigOptions; use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::equivalence::ProjectionMapping; @@ -49,6 +50,7 @@ fn project_orderings_random() -> Result<()> { Arc::clone(&test_fun), vec![col_a], &test_schema, + Arc::new(ConfigOptions::default()), )?); // a + b let a_plus_b = Arc::new(BinaryExpr::new( @@ -71,7 +73,7 @@ fn project_orderings_random() -> Result<()> { for proj_exprs in proj_exprs.iter().combinations(n_req) { let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name.to_string())) + .map(|(expr, name)| (Arc::clone(expr), (*name).to_string())) .collect::>(); let (projected_batch, projected_eq) = apply_projection( proj_exprs.clone(), @@ -82,16 +84,12 @@ fn project_orderings_random() -> Result<()> { // Make sure each ordering after projection is valid. for ordering in projected_eq.oeq_class().iter() { let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties {}, proj_exprs: {:?}", - ordering, eq_properties, proj_exprs, + "Error in test case ordering:{ordering:?}, eq_properties {eq_properties}, proj_exprs: {proj_exprs:?}", ); // Since ordered section satisfies schema, we expect // that result will be same after sort (e.g sort was unnecessary). assert!( - is_table_same_after_sort( - ordering.clone(), - projected_batch.clone(), - )?, + is_table_same_after_sort(ordering.clone(), &projected_batch)?, "{}", err_msg ); @@ -126,6 +124,7 @@ fn ordering_satisfy_after_projection_random() -> Result<()> { Arc::clone(&test_fun), vec![col_a], &test_schema, + Arc::new(ConfigOptions::default()), )?) as PhysicalExprRef; // a + b let a_plus_b = Arc::new(BinaryExpr::new( @@ -148,8 +147,7 @@ fn ordering_satisfy_after_projection_random() -> Result<()> { for proj_exprs in proj_exprs.iter().combinations(n_req) { let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name.to_string())) - .collect::>(); + .map(|(expr, name)| (Arc::clone(expr), (*name).to_string())); let (projected_batch, projected_eq) = apply_projection( proj_exprs.clone(), &table_data_with_properties, @@ -157,37 +155,36 @@ fn ordering_satisfy_after_projection_random() -> Result<()> { )?; let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &test_schema)?; + ProjectionMapping::try_new(proj_exprs, &test_schema)?; let projected_exprs = projection_mapping .iter() - .map(|(_source, target)| Arc::clone(target)) + .flat_map(|(_, targets)| { + targets.iter().map(|(target, _)| Arc::clone(target)) + }) .collect::>(); - for n_req in 0..=projected_exprs.len() { + for n_req in 1..=projected_exprs.len() { for exprs in projected_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::(); - let expected = is_table_same_after_sort( - requirement.clone(), - projected_batch.clone(), - )?; + let sort_exprs = exprs.into_iter().map(|expr| { + PhysicalSortExpr::new(Arc::clone(expr), SORT_OPTIONS) + }); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + unreachable!( + "Test should always produce non-degenerate orderings" + ); + }; + let expected = + is_table_same_after_sort(ordering.clone(), &projected_batch)?; let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties: {}, projected_eq: {}, projection_mapping: {:?}", - requirement, expected, eq_properties, projected_eq, projection_mapping + "Error in test case requirement:{ordering:?}, expected: {expected:?}, eq_properties: {eq_properties}, projected_eq: {projected_eq}, projection_mapping: {projection_mapping:?}" ); // Check whether ordering_satisfy API result and // experimental result matches. assert_eq!( - projected_eq.ordering_satisfy(requirement.as_ref()), + projected_eq.ordering_satisfy(ordering)?, expected, - "{}", - err_msg + "{err_msg}" ); } } diff --git a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs index 593e1c6c2dca0..382c4da943219 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs @@ -15,18 +15,21 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::fuzz_cases::equivalence::utils::{ create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, TestScalarUDF, }; + use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::expressions::{col, BinaryExpr}; -use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr}; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr::{LexOrdering, ScalarFunctionExpr}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + +use datafusion_common::config::ConfigOptions; use itertools::Itertools; -use std::sync::Arc; #[test] fn test_find_longest_permutation_random() -> Result<()> { @@ -47,13 +50,14 @@ fn test_find_longest_permutation_random() -> Result<()> { Arc::clone(&test_fun), vec![col_a], &test_schema, - )?) as PhysicalExprRef; + Arc::new(ConfigOptions::default()), + )?) as _; let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, Operator::Plus, col("b", &test_schema)?, - )) as Arc; + )) as _; let exprs = [ col("a", &test_schema)?, col("b", &test_schema)?, @@ -68,33 +72,32 @@ fn test_find_longest_permutation_random() -> Result<()> { for n_req in 0..=exprs.len() { for exprs in exprs.iter().combinations(n_req) { let exprs = exprs.into_iter().cloned().collect::>(); - let (ordering, indices) = eq_properties.find_longest_permutation(&exprs); + let (ordering, indices) = + eq_properties.find_longest_permutation(&exprs)?; // Make sure that find_longest_permutation return values are consistent let ordering2 = indices .iter() .zip(ordering.iter()) - .map(|(&idx, sort_expr)| PhysicalSortExpr { - expr: Arc::clone(&exprs[idx]), - options: sort_expr.options, + .map(|(&idx, sort_expr)| { + PhysicalSortExpr::new(Arc::clone(&exprs[idx]), sort_expr.options) }) - .collect::(); + .collect::>(); assert_eq!( ordering, ordering2, "indices and lexicographical ordering do not match" ); let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties: {}", - ordering, eq_properties + "Error in test case ordering:{ordering:?}, eq_properties: {eq_properties}" ); - assert_eq!(ordering.len(), indices.len(), "{}", err_msg); + assert_eq!(ordering.len(), indices.len(), "{err_msg}"); // Since ordered section satisfies schema, we expect // that result will be same after sort (e.g sort was unnecessary). + let Some(ordering) = LexOrdering::new(ordering) else { + continue; + }; assert!( - is_table_same_after_sort( - ordering.clone(), - table_data_with_properties.clone(), - )?, + is_table_same_after_sort(ordering, &table_data_with_properties)?, "{}", err_msg ); diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index d4b41b6866315..be35ddca8f02d 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -15,55 +15,50 @@ // specific language governing permissions and limitations // under the License. -use datafusion::physical_plan::expressions::col; -use datafusion::physical_plan::expressions::Column; -use datafusion_physical_expr::{ConstExpr, EquivalenceProperties, PhysicalSortExpr}; use std::any::Any; use std::cmp::Ordering; use std::sync::Arc; use arrow::array::{ArrayRef, Float32Array, Float64Array, RecordBatch, UInt32Array}; -use arrow::compute::SortOptions; -use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn}; +use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn, SortOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::utils::{compare_rows, get_row_at_idx}; -use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; +use datafusion_common::{exec_err, internal_datafusion_err, plan_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; -use datafusion_physical_expr::equivalence::{EquivalenceClass, ProjectionMapping}; +use datafusion_physical_expr::equivalence::{ + convert_to_orderings, EquivalenceClass, ProjectionMapping, +}; +use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::expressions::{col, Column}; use itertools::izip; use rand::prelude::*; +/// Projects the input schema based on the given projection mapping. pub fn output_schema( mapping: &ProjectionMapping, input_schema: &Arc, ) -> Result { - // Calculate output schema - let fields: Result> = mapping - .iter() - .map(|(source, target)| { - let name = target - .as_any() - .downcast_ref::() - .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? - .name(); - let field = Field::new( - name, - source.data_type(input_schema)?, - source.nullable(input_schema)?, - ); - - Ok(field) - }) - .collect(); + // Calculate output schema: + let mut fields = vec![]; + for (source, targets) in mapping.iter() { + let data_type = source.data_type(input_schema)?; + let nullable = source.nullable(input_schema)?; + for (target, _) in targets.iter() { + let Some(column) = target.as_any().downcast_ref::() else { + return plan_err!("Expects to have column"); + }; + fields.push(Field::new(column.name(), data_type.clone(), nullable)); + } + } let output_schema = Arc::new(Schema::new_with_metadata( - fields?, + fields, input_schema.metadata().clone(), )); @@ -100,9 +95,9 @@ pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperti let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_f))?; // Column e has constant value. - eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); + eq_properties.add_constants([ConstExpr::from(Arc::clone(col_e))])?; // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); @@ -114,18 +109,18 @@ pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperti }; while !remaining_exprs.is_empty() { - let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + let n_sort_expr = rng.random_range(1..remaining_exprs.len() + 1); remaining_exprs.shuffle(&mut rng); - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: options_asc, - }) - .collect(); + let ordering = + remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: options_asc, + }); - eq_properties.add_new_orderings([ordering]); + eq_properties.add_ordering(ordering); } Ok((test_schema, eq_properties)) @@ -133,12 +128,12 @@ pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperti // Apply projection to the input_data, return projected equivalence properties and record batch pub fn apply_projection( - proj_exprs: Vec<(Arc, String)>, + proj_exprs: impl IntoIterator, String)>, input_data: &RecordBatch, input_eq_properties: &EquivalenceProperties, ) -> Result<(RecordBatch, EquivalenceProperties)> { let input_schema = input_data.schema(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; let output_schema = output_schema(&projection_mapping, &input_schema)?; let num_rows = input_data.num_rows(); @@ -168,49 +163,49 @@ fn add_equal_conditions_test() -> Result<()> { ])); let mut eq_properties = EquivalenceProperties::new(schema); - let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; - let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; - let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + let col_a = Arc::new(Column::new("a", 0)) as _; + let col_b = Arc::new(Column::new("b", 1)) as _; + let col_c = Arc::new(Column::new("c", 2)) as _; + let col_x = Arc::new(Column::new("x", 3)) as _; + let col_y = Arc::new(Column::new("y", 4)) as _; // a and b are aliases - eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_b))?; assert_eq!(eq_properties.eq_group().len(), 1); // This new entry is redundant, size shouldn't increase - eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_a))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 2); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); // b and c are aliases. Existing equivalence class should expand, // however there shouldn't be any new equivalence class - eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_c))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 3); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); + assert!(eq_groups.contains(&col_c)); // This is a new set of equality. Hence equivalent class count should be 2. - eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_y))?; assert_eq!(eq_properties.eq_group().len(), 2); // This equality bridges distinct equality sets. // Hence equivalent class count should decrease from 2 to 1. - eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_a))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 5); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); - assert!(eq_groups.contains(&col_x_expr)); - assert!(eq_groups.contains(&col_y_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); + assert!(eq_groups.contains(&col_c)); + assert!(eq_groups.contains(&col_x)); + assert!(eq_groups.contains(&col_y)); Ok(()) } @@ -226,7 +221,7 @@ fn add_equal_conditions_test() -> Result<()> { /// already sorted according to `required_ordering` to begin with. pub fn is_table_same_after_sort( mut required_ordering: LexOrdering, - batch: RecordBatch, + batch: &RecordBatch, ) -> Result { // Clone the original schema and columns let original_schema = batch.schema(); @@ -327,7 +322,7 @@ pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { let col_f = &col("f", &test_schema)?; let col_g = &col("g", &test_schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); - eq_properties.add_equal_conditions(col_a, col_c)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_c))?; let option_asc = SortOptions { descending: false, @@ -350,7 +345,7 @@ pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { ], ]; let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); Ok((test_schema, eq_properties)) } @@ -369,14 +364,14 @@ pub fn generate_table_for_eq_properties( // Utility closure to generate random array let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { let values: Vec = (0..num_elems) - .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) + .map(|_| rng.random_range(0..max_val) as f64 / 2.0) .collect(); Arc::new(Float64Array::from_iter_values(values)) }; // Fill constant columns for constant in eq_properties.constants() { - let col = constant.expr().as_any().downcast_ref::().unwrap(); + let col = constant.expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) as ArrayRef; @@ -461,7 +456,7 @@ pub fn generate_table_for_orderings( let batch = RecordBatch::try_from_iter(arrays)?; // Sort batch according to first ordering expression - let sort_columns = get_sort_columns(&batch, orderings[0].as_ref())?; + let sort_columns = get_sort_columns(&batch, &orderings[0])?; let sort_indices = lexsort_to_indices(&sort_columns, None)?; let mut batch = take_record_batch(&batch, &sort_indices)?; @@ -494,29 +489,6 @@ pub fn generate_table_for_orderings( Ok(batch) } -// Convert each tuple to PhysicalSortExpr -pub fn convert_to_sort_exprs( - in_data: &[(&Arc, SortOptions)], -) -> LexOrdering { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(*expr), - options: *options, - }) - .collect() -} - -// Convert each inner tuple to PhysicalSortExpr -pub fn convert_to_orderings( - orderings: &[Vec<(&Arc, SortOptions)>], -) -> Vec { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) - .collect() -} - // Utility function to generate random f64 array fn generate_random_f64_array( n_elems: usize, @@ -524,7 +496,7 @@ fn generate_random_f64_array( rng: &mut StdRng, ) -> ArrayRef { let values: Vec = (0..n_elems) - .map(|_| rng.gen_range(0..n_distinct) as f64 / 2.0) + .map(|_| rng.random_range(0..n_distinct) as f64 / 2.0) .collect(); Arc::new(Float64Array::from_iter_values(values)) } @@ -540,7 +512,7 @@ fn get_sort_columns( .collect::>>() } -#[derive(Debug, Clone)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct TestScalarUDF { pub(crate) signature: Signature, } @@ -590,11 +562,11 @@ impl ScalarUDFImpl for TestScalarUDF { DataType::Float64 => Arc::new({ let arg = &args[0].as_any().downcast_ref::().ok_or_else( || { - DataFusionError::Internal(format!( + internal_datafusion_err!( "could not cast {} to {}", self.name(), std::any::type_name::() - )) + ) }, )?; @@ -605,11 +577,11 @@ impl ScalarUDFImpl for TestScalarUDF { DataType::Float32 => Arc::new({ let arg = &args[0].as_any().downcast_ref::().ok_or_else( || { - DataFusionError::Internal(format!( + internal_datafusion_err!( "could not cast {} to {}", self.name(), std::any::type_name::() - )) + ) }, )?; diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index da93dd5edf291..e8ff1ccf06704 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -20,7 +20,7 @@ use std::time::SystemTime; use crate::fuzz_cases::join_fuzz::JoinTestType::{HjSmj, NljHj}; -use arrow::array::{ArrayRef, Int32Array}; +use arrow::array::{ArrayRef, BinaryArray, Int32Array}; use arrow::compute::SortOptions; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; @@ -37,7 +37,7 @@ use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::ScalarValue; +use datafusion_common::{NullEquality, ScalarValue}; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::PhysicalExprRef; @@ -92,8 +92,8 @@ fn col_lt_col_filter(schema1: Arc, schema2: Arc) -> JoinFilter { #[tokio::test] async fn test_inner_join_1k_filtered() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::Inner, Some(Box::new(col_lt_col_filter)), ) @@ -104,8 +104,8 @@ async fn test_inner_join_1k_filtered() { #[tokio::test] async fn test_inner_join_1k() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::Inner, None, ) @@ -116,8 +116,8 @@ async fn test_inner_join_1k() { #[tokio::test] async fn test_left_join_1k() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::Left, None, ) @@ -128,8 +128,8 @@ async fn test_left_join_1k() { #[tokio::test] async fn test_left_join_1k_filtered() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::Left, Some(Box::new(col_lt_col_filter)), ) @@ -140,8 +140,8 @@ async fn test_left_join_1k_filtered() { #[tokio::test] async fn test_right_join_1k() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::Right, None, ) @@ -152,8 +152,8 @@ async fn test_right_join_1k() { #[tokio::test] async fn test_right_join_1k_filtered() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::Right, Some(Box::new(col_lt_col_filter)), ) @@ -164,8 +164,8 @@ async fn test_right_join_1k_filtered() { #[tokio::test] async fn test_full_join_1k() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::Full, None, ) @@ -176,8 +176,8 @@ async fn test_full_join_1k() { #[tokio::test] async fn test_full_join_1k_filtered() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::Full, Some(Box::new(col_lt_col_filter)), ) @@ -186,10 +186,10 @@ async fn test_full_join_1k_filtered() { } #[tokio::test] -async fn test_semi_join_1k() { +async fn test_left_semi_join_1k() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::LeftSemi, None, ) @@ -198,10 +198,10 @@ async fn test_semi_join_1k() { } #[tokio::test] -async fn test_semi_join_1k_filtered() { +async fn test_left_semi_join_1k_filtered() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::LeftSemi, Some(Box::new(col_lt_col_filter)), ) @@ -209,11 +209,35 @@ async fn test_semi_join_1k_filtered() { .await } +#[tokio::test] +async fn test_right_semi_join_1k() { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), + JoinType::RightSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_semi_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), + JoinType::RightSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + #[tokio::test] async fn test_left_anti_join_1k() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::LeftAnti, None, ) @@ -224,8 +248,8 @@ async fn test_left_anti_join_1k() { #[tokio::test] async fn test_left_anti_join_1k_filtered() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::LeftAnti, Some(Box::new(col_lt_col_filter)), ) @@ -236,8 +260,8 @@ async fn test_left_anti_join_1k_filtered() { #[tokio::test] async fn test_right_anti_join_1k() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::RightAnti, None, ) @@ -248,8 +272,8 @@ async fn test_right_anti_join_1k() { #[tokio::test] async fn test_right_anti_join_1k_filtered() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::RightAnti, Some(Box::new(col_lt_col_filter)), ) @@ -260,8 +284,8 @@ async fn test_right_anti_join_1k_filtered() { #[tokio::test] async fn test_left_mark_join_1k() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), JoinType::LeftMark, None, ) @@ -272,8 +296,249 @@ async fn test_left_mark_join_1k() { #[tokio::test] async fn test_left_mark_join_1k_filtered() { JoinFuzzTestCase::new( - make_staggered_batches(1000), - make_staggered_batches(1000), + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), + JoinType::LeftMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +// todo: add JoinTestType::HjSmj after Right mark SortMergeJoin support +#[tokio::test] +async fn test_right_mark_join_1k() { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), + JoinType::RightMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_mark_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000), + make_staggered_batches_i32(1000), + JoinType::RightMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_inner_join_1k_binary_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::Inner, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_inner_join_1k_binary() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::Inner, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_left_join_1k_binary() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::Left, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_left_join_1k_binary_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::Left, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_join_1k_binary() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::Right, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_join_1k_binary_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::Right, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_full_join_1k_binary() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::Full, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_full_join_1k_binary_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::Full, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[NljHj, HjSmj], false) + .await +} + +#[tokio::test] +async fn test_left_semi_join_1k_binary() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::LeftSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_left_semi_join_1k_binary_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::LeftSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_semi_join_1k_binary() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::RightSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_semi_join_1k_binary_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::RightSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_left_anti_join_1k_binary() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::LeftAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_left_anti_join_1k_binary_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::LeftAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_anti_join_1k_binary() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::RightAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_anti_join_1k_binary_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::RightAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_left_mark_join_1k_binary() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::LeftMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_left_mark_join_1k_binary_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), JoinType::LeftMark, Some(Box::new(col_lt_col_filter)), ) @@ -281,6 +546,31 @@ async fn test_left_mark_join_1k_filtered() { .await } +// todo: add JoinTestType::HjSmj after Right mark SortMergeJoin support +#[tokio::test] +async fn test_right_mark_join_1k_binary() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::RightMark, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_mark_join_1k_binary_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000), + make_staggered_batches_binary(1000), + JoinType::RightMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + type JoinFilterBuilder = Box, Arc) -> JoinFilter>; struct JoinFuzzTestCase { @@ -428,12 +718,18 @@ impl JoinFuzzTestCase { fn left_right(&self) -> (Arc, Arc) { let schema1 = self.input1[0].schema(); let schema2 = self.input2[0].schema(); - let left = - MemorySourceConfig::try_new_exec(&[self.input1.clone()], schema1, None) - .unwrap(); - let right = - MemorySourceConfig::try_new_exec(&[self.input2.clone()], schema2, None) - .unwrap(); + let left = MemorySourceConfig::try_new_exec( + std::slice::from_ref(&self.input1), + schema1, + None, + ) + .unwrap(); + let right = MemorySourceConfig::try_new_exec( + std::slice::from_ref(&self.input2), + schema2, + None, + ) + .unwrap(); (left, right) } @@ -455,7 +751,7 @@ impl JoinFuzzTestCase { self.join_filter(), self.join_type, vec![SortOptions::default(); self.on_columns().len()], - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ) @@ -472,7 +768,7 @@ impl JoinFuzzTestCase { &self.join_type, None, PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ) @@ -545,7 +841,7 @@ impl JoinFuzzTestCase { std::fs::remove_dir_all(fuzz_debug).unwrap_or(()); std::fs::create_dir_all(fuzz_debug).unwrap(); let out_dir_name = &format!("{fuzz_debug}/batch_size_{batch_size}"); - println!("Test result data mismatch found. HJ rows {}, SMJ rows {}, NLJ rows {}", hj_rows, smj_rows, nlj_rows); + println!("Test result data mismatch found. HJ rows {hj_rows}, SMJ rows {smj_rows}, NLJ rows {nlj_rows}"); println!("The debug is ON. Input data will be saved to {out_dir_name}"); Self::save_partitioned_batches_as_parquet( @@ -561,10 +857,9 @@ impl JoinFuzzTestCase { if join_tests.contains(&NljHj) && nlj_rows != hj_rows { println!("=============== HashJoinExec =================="); - hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + hj_formatted_sorted.iter().for_each(|s| println!("{s}")); println!("=============== NestedLoopJoinExec =================="); - nlj_formatted_sorted.iter().for_each(|s| println!("{}", s)); - + nlj_formatted_sorted.iter().for_each(|s| println!("{s}")); Self::save_partitioned_batches_as_parquet( &nlj_collected, out_dir_name, @@ -579,9 +874,9 @@ impl JoinFuzzTestCase { if join_tests.contains(&HjSmj) && smj_rows != hj_rows { println!("=============== HashJoinExec =================="); - hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + hj_formatted_sorted.iter().for_each(|s| println!("{s}")); println!("=============== SortMergeJoinExec =================="); - smj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + smj_formatted_sorted.iter().for_each(|s| println!("{s}")); Self::save_partitioned_batches_as_parquet( &hj_collected, @@ -597,10 +892,10 @@ impl JoinFuzzTestCase { } if join_tests.contains(&NljHj) { - let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size); + let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {batch_size}"); assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str()); - let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {}", batch_size); + let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {batch_size}"); // row level compare if any of joins returns the result // the reason is different formatting when there is no rows for (i, (nlj_line, hj_line)) in nlj_formatted_sorted @@ -671,7 +966,7 @@ impl JoinFuzzTestCase { std::fs::create_dir_all(out_path).unwrap(); input.iter().enumerate().for_each(|(idx, batch)| { - let file_path = format!("{out_path}/file_{}.parquet", idx); + let file_path = format!("{out_path}/file_{idx}.parquet"); let mut file = std::fs::File::create(&file_path).unwrap(); println!( "{}: Saving batch idx {} rows {} to parquet {}", @@ -722,11 +1017,9 @@ impl JoinFuzzTestCase { path.to_str().unwrap(), datafusion::prelude::ParquetReadOptions::default(), ) - .await - .unwrap() + .await? .collect() - .await - .unwrap(); + .await?; batches.append(&mut batch); } @@ -738,14 +1031,14 @@ impl JoinFuzzTestCase { /// Return randomly sized record batches with: /// two sorted int32 columns 'a', 'b' ranged from 0..99 as join columns /// two random int32 columns 'x', 'y' as other columns -fn make_staggered_batches(len: usize) -> Vec { - let mut rng = rand::thread_rng(); +fn make_staggered_batches_i32(len: usize) -> Vec { + let mut rng = rand::rng(); let mut input12: Vec<(i32, i32)> = vec![(0, 0); len]; let mut input3: Vec = vec![0; len]; let mut input4: Vec = vec![0; len]; input12 .iter_mut() - .for_each(|v| *v = (rng.gen_range(0..100), rng.gen_range(0..100))); + .for_each(|v| *v = (rng.random_range(0..100), rng.random_range(0..100))); rng.fill(&mut input3[..]); rng.fill(&mut input4[..]); input12.sort_unstable(); @@ -766,3 +1059,43 @@ fn make_staggered_batches(len: usize) -> Vec { // use a random number generator to pick a random sized output stagger_batch_with_seed(batch, 42) } + +fn rand_bytes(rng: &mut R, min: usize, max: usize) -> Vec { + let n = rng.random_range(min..=max); + let mut v = vec![0u8; n]; + rng.fill(&mut v[..]); + v +} + +/// Return randomly sized record batches with: +/// two sorted binary columns 'a', 'b' (lexicographically) as join columns +/// two random binary columns 'x', 'y' as other columns +fn make_staggered_batches_binary(len: usize) -> Vec { + let mut rng = rand::rng(); + + // produce (a,b) pairs then sort lexicographically so SMJ has naturally sorted keys + let mut input12: Vec<(Vec, Vec)> = (0..len) + .map(|_| (rand_bytes(&mut rng, 4, 16), rand_bytes(&mut rng, 4, 16))) + .collect(); + input12.sort_unstable(); // lexicographic on Vec + + // payload cols (also binary so the existing x < x filter is well-typed) + let input3: Vec> = (0..len).map(|_| rand_bytes(&mut rng, 4, 24)).collect(); + let input4: Vec> = (0..len).map(|_| rand_bytes(&mut rng, 4, 24)).collect(); + + let a = BinaryArray::from_iter_values(input12.iter().map(|k| &k.0)); + let b = BinaryArray::from_iter_values(input12.iter().map(|k| &k.1)); + let x = BinaryArray::from_iter_values(input3.iter()); + let y = BinaryArray::from_iter_values(input4.iter()); + + let batch = RecordBatch::try_from_iter(vec![ + ("a", Arc::new(a) as ArrayRef), + ("b", Arc::new(b) as ArrayRef), + ("x", Arc::new(x) as ArrayRef), + ("y", Arc::new(y) as ArrayRef), + ]) + .unwrap(); + + // preserve your existing randomized partitioning + stagger_batch_with_seed(batch, 42) +} diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs index 987a732eb294b..4c5ebf0402414 100644 --- a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -24,7 +24,7 @@ use arrow::util::pretty::pretty_format_batches; use datafusion::datasource::MemTable; use datafusion::prelude::SessionContext; use datafusion_common::assert_contains; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use std::sync::Arc; use test_utils::stagger_batch; @@ -54,11 +54,11 @@ async fn run_limit_fuzz_test(make_data: F) where F: Fn(usize) -> SortedData, { - let mut rng = thread_rng(); + let mut rng = rng(); for size in [10, 1_0000, 10_000, 100_000] { let data = make_data(size); // test various limits including some random ones - for limit in [1, 3, 7, 17, 10000, rng.gen_range(1..size * 2)] { + for limit in [1, 3, 7, 17, 10000, rng.random_range(1..size * 2)] { // limit can be larger than the number of rows in the input run_limit_test(limit, &data).await; } @@ -97,13 +97,13 @@ impl SortedData { /// Create an i32 column of random values, with the specified number of /// rows, sorted the default fn new_i32(size: usize) -> Self { - let mut rng = thread_rng(); + let mut rng = rng(); // have some repeats (approximately 1/3 of the values are the same) let max = size as i32 / 3; let data: Vec> = (0..size) .map(|_| { // no nulls for now - Some(rng.gen_range(0..max)) + Some(rng.random_range(0..max)) }) .collect(); @@ -118,17 +118,17 @@ impl SortedData { /// Create an f64 column of random values, with the specified number of /// rows, sorted the default fn new_f64(size: usize) -> Self { - let mut rng = thread_rng(); + let mut rng = rng(); let mut data: Vec> = (0..size / 3) .map(|_| { // no nulls for now - Some(rng.gen_range(0.0..1.0f64)) + Some(rng.random_range(0.0..1.0f64)) }) .collect(); // have some repeats (approximately 1/3 of the values are the same) while data.len() < size { - data.push(data[rng.gen_range(0..data.len())]); + data.push(data[rng.random_range(0..data.len())]); } let batches = stagger_batch(f64_batch(data.iter().cloned())); @@ -142,7 +142,7 @@ impl SortedData { /// Create an string column of random values, with the specified number of /// rows, sorted the default fn new_str(size: usize) -> Self { - let mut rng = thread_rng(); + let mut rng = rng(); let mut data: Vec> = (0..size / 3) .map(|_| { // no nulls for now @@ -152,7 +152,7 @@ impl SortedData { // have some repeats (approximately 1/3 of the values are the same) while data.len() < size { - data.push(data[rng.gen_range(0..data.len())].clone()); + data.push(data[rng.random_range(0..data.len())].clone()); } let batches = stagger_batch(string_batch(data.iter())); @@ -166,7 +166,7 @@ impl SortedData { /// Create two columns of random values (int64, string), with the specified number of /// rows, sorted the default fn new_i64str(size: usize) -> Self { - let mut rng = thread_rng(); + let mut rng = rng(); // 100 distinct values let strings: Vec> = (0..100) @@ -180,8 +180,8 @@ impl SortedData { let data = (0..size) .map(|_| { ( - Some(rng.gen_range(0..10)), - strings[rng.gen_range(0..strings.len())].clone(), + Some(rng.random_range(0..10)), + strings[rng.random_range(0..strings.len())].clone(), ) }) .collect::>(); @@ -340,8 +340,8 @@ async fn run_limit_test(fetch: usize, data: &SortedData) { /// Return random ASCII String with len fn get_random_string(len: usize) -> String { - thread_rng() - .sample_iter(rand::distributions::Alphanumeric) + rng() + .sample_iter(rand::distr::Alphanumeric) .take(len) .map(char::from) .collect() diff --git a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs index 92f3755250663..b92dec64e3f19 100644 --- a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs @@ -31,7 +31,6 @@ use datafusion::physical_plan::{ sorts::sort_preserving_merge::SortPreservingMergeExec, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; use test_utils::{batches_to_vec, partitions_to_sorted_vec, stagger_batch_with_seed}; @@ -109,13 +108,14 @@ async fn run_merge_test(input: Vec>) { .expect("at least one batch"); let schema = first_batch.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("x", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let exec = MemorySourceConfig::try_new_exec(&input, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index d5511e2970f4d..9e2fd170f7f0c 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -20,6 +20,8 @@ mod distinct_count_string_fuzz; mod join_fuzz; mod merge_fuzz; mod sort_fuzz; +mod sort_query_fuzz; +mod topk_filter_pushdown; mod aggregation_fuzzer; mod equivalence; @@ -29,3 +31,8 @@ mod pruning; mod limit_fuzz; mod sort_preserving_repartition_fuzz; mod window_fuzz; + +// Utility modules +mod once_exec; +mod record_batch_generator; +mod spilling_fuzz_in_memory_constrained_env; diff --git a/datafusion/core/tests/fuzz_cases/once_exec.rs b/datafusion/core/tests/fuzz_cases/once_exec.rs new file mode 100644 index 0000000000000..49e2caaa7417c --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/once_exec.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::SchemaRef; +use datafusion_common::internal_datafusion_err; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, +}; +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::sync::{Arc, Mutex}; + +/// Execution plan that return the stream on the call to `execute`. further calls to `execute` will +/// return an error +pub struct OnceExec { + /// the results to send back + stream: Mutex>, + cache: PlanProperties, +} + +impl Debug for OnceExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "OnceExec") + } +} + +impl OnceExec { + pub fn new(stream: SendableRecordBatchStream) -> Self { + let cache = Self::compute_properties(stream.schema()); + Self { + stream: Mutex::new(Some(stream)), + cache, + } + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties(schema: SchemaRef) -> PlanProperties { + PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ) + } +} + +impl DisplayAs for OnceExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "OnceExec:") + } + DisplayFormatType::TreeRender => { + write!(f, "") + } + } + } +} + +impl ExecutionPlan for OnceExec { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion_common::Result> { + unimplemented!() + } + + /// Returns a stream which yields data + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> datafusion_common::Result { + assert_eq!(partition, 0); + + let stream = self.stream.lock().unwrap().take(); + + stream.ok_or_else(|| internal_datafusion_err!("Stream already consumed")) + } +} diff --git a/datafusion/core/tests/fuzz_cases/pruning.rs b/datafusion/core/tests/fuzz_cases/pruning.rs index 11dd961a54ee5..f8bd4dbc1a768 100644 --- a/datafusion/core/tests/fuzz_cases/pruning.rs +++ b/datafusion/core/tests/fuzz_cases/pruning.rs @@ -90,42 +90,42 @@ async fn test_utf8_not_like() { #[tokio::test] async fn test_utf8_like_prefix() { - Utf8Test::new(|value| col("a").like(lit(format!("%{}", value)))) + Utf8Test::new(|value| col("a").like(lit(format!("%{value}")))) .run() .await; } #[tokio::test] async fn test_utf8_like_suffix() { - Utf8Test::new(|value| col("a").like(lit(format!("{}%", value)))) + Utf8Test::new(|value| col("a").like(lit(format!("{value}%")))) .run() .await; } #[tokio::test] async fn test_utf8_not_like_prefix() { - Utf8Test::new(|value| col("a").not_like(lit(format!("%{}", value)))) + Utf8Test::new(|value| col("a").not_like(lit(format!("%{value}")))) .run() .await; } #[tokio::test] async fn test_utf8_not_like_ecsape() { - Utf8Test::new(|value| col("a").not_like(lit(format!("\\%{}%", value)))) + Utf8Test::new(|value| col("a").not_like(lit(format!("\\%{value}%")))) .run() .await; } #[tokio::test] async fn test_utf8_not_like_suffix() { - Utf8Test::new(|value| col("a").not_like(lit(format!("{}%", value)))) + Utf8Test::new(|value| col("a").not_like(lit(format!("{value}%")))) .run() .await; } #[tokio::test] async fn test_utf8_not_like_suffix_one() { - Utf8Test::new(|value| col("a").not_like(lit(format!("{}_", value)))) + Utf8Test::new(|value| col("a").not_like(lit(format!("{value}_")))) .run() .await; } @@ -201,7 +201,7 @@ impl Utf8Test { } } - /// all combinations of interesting charactes with lengths ranging from 1 to 4 + /// all combinations of interesting characters with lengths ranging from 1 to 4 fn values() -> &'static [String] { &VALUES } @@ -226,7 +226,7 @@ impl Utf8Test { return (*files).clone(); } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let values = Self::values(); let mut row_groups = vec![]; @@ -276,7 +276,7 @@ async fn execute_with_predicate( ctx: &SessionContext, ) -> Vec { let parquet_source = if prune_stats { - ParquetSource::default().with_predicate(Arc::clone(&schema), predicate.clone()) + ParquetSource::default().with_predicate(predicate.clone()) } else { ParquetSource::default() }; @@ -319,14 +319,9 @@ async fn write_parquet_file( row_groups: Vec>, ) -> Bytes { let mut buf = BytesMut::new().writer(); - let mut props = WriterProperties::builder(); - if let Some(truncation_length) = truncation_length { - props = { - #[allow(deprecated)] - props.set_max_statistics_size(truncation_length) - } - } - props = props.set_statistics_enabled(EnabledStatistics::Chunk); // row group level + let props = WriterProperties::builder() + .set_statistics_enabled(EnabledStatistics::Chunk) // row group level + .set_statistics_truncate_length(truncation_length); let props = props.build(); { let mut writer = @@ -345,7 +340,7 @@ async fn write_parquet_file( /// The string values for [Utf8Test::values] static VALUES: LazyLock> = LazyLock::new(|| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let characters = [ "z", diff --git a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs new file mode 100644 index 0000000000000..45dba5f7864b1 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs @@ -0,0 +1,833 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{ArrayRef, DictionaryArray, PrimitiveArray, RecordBatch}; +use arrow::datatypes::{ + ArrowPrimitiveType, BooleanType, DataType, Date32Type, Date64Type, Decimal128Type, + Decimal256Type, Decimal32Type, Decimal64Type, DurationMicrosecondType, + DurationMillisecondType, DurationNanosecondType, DurationSecondType, Field, + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, + Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, +}; +use arrow_schema::{ + DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, + DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, +}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; +use rand::{rng, rngs::StdRng, Rng, SeedableRng}; +use test_utils::array_gen::{ + BinaryArrayGenerator, BooleanArrayGenerator, DecimalArrayGenerator, + PrimitiveArrayGenerator, StringArrayGenerator, +}; + +/// Columns that are supported by the record batch generator +/// The RNG is used to generate the precision and scale for the decimal columns, thread +/// RNG is not used because this is used in fuzzing and deterministic results are preferred +pub fn get_supported_types_columns(rng_seed: u64) -> Vec { + let mut rng = StdRng::seed_from_u64(rng_seed); + vec![ + ColumnDescr::new("i8", DataType::Int8), + ColumnDescr::new("i16", DataType::Int16), + ColumnDescr::new("i32", DataType::Int32), + ColumnDescr::new("i64", DataType::Int64), + ColumnDescr::new("u8", DataType::UInt8), + ColumnDescr::new("u16", DataType::UInt16), + ColumnDescr::new("u32", DataType::UInt32), + ColumnDescr::new("u64", DataType::UInt64), + ColumnDescr::new("date32", DataType::Date32), + ColumnDescr::new("date64", DataType::Date64), + ColumnDescr::new("time32_s", DataType::Time32(TimeUnit::Second)), + ColumnDescr::new("time32_ms", DataType::Time32(TimeUnit::Millisecond)), + ColumnDescr::new("time64_us", DataType::Time64(TimeUnit::Microsecond)), + ColumnDescr::new("time64_ns", DataType::Time64(TimeUnit::Nanosecond)), + ColumnDescr::new("timestamp_s", DataType::Timestamp(TimeUnit::Second, None)), + ColumnDescr::new( + "timestamp_ms", + DataType::Timestamp(TimeUnit::Millisecond, None), + ), + ColumnDescr::new( + "timestamp_us", + DataType::Timestamp(TimeUnit::Microsecond, None), + ), + ColumnDescr::new( + "timestamp_ns", + DataType::Timestamp(TimeUnit::Nanosecond, None), + ), + ColumnDescr::new("float32", DataType::Float32), + ColumnDescr::new("float64", DataType::Float64), + ColumnDescr::new( + "interval_year_month", + DataType::Interval(IntervalUnit::YearMonth), + ), + ColumnDescr::new( + "interval_day_time", + DataType::Interval(IntervalUnit::DayTime), + ), + ColumnDescr::new( + "interval_month_day_nano", + DataType::Interval(IntervalUnit::MonthDayNano), + ), + // Internal error: AggregationFuzzer task error: JoinError::Panic(Id(29108), "called `Option::unwrap()` on a `None` value", ...). + // ColumnDescr::new( + // "duration_seconds", + // DataType::Duration(TimeUnit::Second), + // ), + ColumnDescr::new( + "duration_milliseconds", + DataType::Duration(TimeUnit::Millisecond), + ), + ColumnDescr::new( + "duration_microsecond", + DataType::Duration(TimeUnit::Microsecond), + ), + ColumnDescr::new( + "duration_nanosecond", + DataType::Duration(TimeUnit::Nanosecond), + ), + ColumnDescr::new("decimal32", { + let precision: u8 = rng.random_range(1..=DECIMAL32_MAX_PRECISION); + let scale: i8 = rng.random_range( + i8::MIN..=std::cmp::min(precision as i8, DECIMAL32_MAX_SCALE), + ); + DataType::Decimal32(precision, scale) + }), + ColumnDescr::new("decimal64", { + let precision: u8 = rng.random_range(1..=DECIMAL64_MAX_PRECISION); + let scale: i8 = rng.random_range( + i8::MIN..=std::cmp::min(precision as i8, DECIMAL64_MAX_SCALE), + ); + DataType::Decimal64(precision, scale) + }), + ColumnDescr::new("decimal128", { + let precision: u8 = rng.random_range(1..=DECIMAL128_MAX_PRECISION); + let scale: i8 = rng.random_range( + i8::MIN..=std::cmp::min(precision as i8, DECIMAL128_MAX_SCALE), + ); + DataType::Decimal128(precision, scale) + }), + ColumnDescr::new("decimal256", { + let precision: u8 = rng.random_range(1..=DECIMAL256_MAX_PRECISION); + let scale: i8 = rng.random_range( + i8::MIN..=std::cmp::min(precision as i8, DECIMAL256_MAX_SCALE), + ); + DataType::Decimal256(precision, scale) + }), + ColumnDescr::new("utf8", DataType::Utf8), + ColumnDescr::new("largeutf8", DataType::LargeUtf8), + ColumnDescr::new("utf8view", DataType::Utf8View), + ColumnDescr::new("u8_low", DataType::UInt8).with_max_num_distinct(10), + ColumnDescr::new("utf8_low", DataType::Utf8).with_max_num_distinct(10), + ColumnDescr::new("bool", DataType::Boolean), + ColumnDescr::new("binary", DataType::Binary), + ColumnDescr::new("large_binary", DataType::LargeBinary), + ColumnDescr::new("binaryview", DataType::BinaryView), + ColumnDescr::new( + "dictionary_utf8_low", + DataType::Dictionary(Box::new(DataType::UInt64), Box::new(DataType::Utf8)), + ) + .with_max_num_distinct(10), + ] +} + +#[derive(Debug, Clone)] +pub struct ColumnDescr { + /// Column name + pub name: String, + + /// Data type of this column + pub column_type: DataType, + + /// The maximum number of distinct values in this column. + /// + /// See [`ColumnDescr::with_max_num_distinct`] for more information + max_num_distinct: Option, +} + +impl ColumnDescr { + #[inline] + pub fn new(name: &str, column_type: DataType) -> Self { + Self { + name: name.to_string(), + column_type, + max_num_distinct: None, + } + } + + pub fn get_max_num_distinct(&self) -> Option { + self.max_num_distinct + } + + /// set the maximum number of distinct values in this column + /// + /// If `None`, the number of distinct values is randomly selected between 1 + /// and the number of rows. + pub fn with_max_num_distinct(mut self, num_distinct: usize) -> Self { + self.max_num_distinct = Some(num_distinct); + self + } +} + +/// Record batch generator +pub struct RecordBatchGenerator { + pub min_rows_num: usize, + + pub max_rows_num: usize, + + pub columns: Vec, + + pub candidate_null_pcts: Vec, + + /// If a seed is provided when constructing the generator, it will be used to + /// create `rng` and the pseudo-randomly generated batches will be deterministic. + /// Otherwise, `rng` will be initialized using `rng()` and the batches + /// generated will be different each time. + rng: StdRng, +} + +macro_rules! generate_decimal_array { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT: expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $PRECISION: ident, $SCALE: ident, $ARROW_TYPE: ident) => {{ + let mut generator = DecimalArrayGenerator { + precision: $PRECISION, + scale: $SCALE, + num_decimals: $NUM_ROWS, + num_distinct_decimals: $MAX_NUM_DISTINCT, + null_pct: $NULL_PCT, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$ARROW_TYPE>() + }}; +} + +// Generating `BooleanArray` due to it being a special type in Arrow (bit-packed) +macro_rules! generate_boolean_array { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE: ident) => {{ + let num_distinct_booleans = if $MAX_NUM_DISTINCT >= 2 { 2 } else { 1 }; + + let mut generator = BooleanArrayGenerator { + num_booleans: $NUM_ROWS, + num_distinct_booleans, + null_pct: $NULL_PCT, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$ARROW_TYPE>() + }}; +} + +macro_rules! generate_primitive_array { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => {{ + let mut generator = PrimitiveArrayGenerator { + num_primitives: $NUM_ROWS, + num_distinct_primitives: $MAX_NUM_DISTINCT, + null_pct: $NULL_PCT, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$ARROW_TYPE>() + }}; +} + +macro_rules! generate_dict { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident, $VALUES: ident) => {{ + debug_assert_eq!($VALUES.len(), $MAX_NUM_DISTINCT); + let keys: PrimitiveArray<$ARROW_TYPE> = (0..$NUM_ROWS) + .map(|_| { + if $BATCH_GEN_RNG.random::() < $NULL_PCT { + None + } else if $MAX_NUM_DISTINCT > 1 { + let range = 0..($MAX_NUM_DISTINCT + as <$ARROW_TYPE as ArrowPrimitiveType>::Native); + Some($ARRAY_GEN_RNG.random_range(range)) + } else { + Some(0) + } + }) + .collect(); + + let dict = DictionaryArray::new(keys, $VALUES); + Arc::new(dict) as ArrayRef + }}; +} + +impl RecordBatchGenerator { + /// Create a new `RecordBatchGenerator` with a random seed. The generated + /// batches will be different each time. + pub fn new( + min_rows_nun: usize, + max_rows_num: usize, + columns: Vec, + ) -> Self { + let candidate_null_pcts = vec![0.0, 0.01, 0.1, 0.5]; + + Self { + min_rows_num: min_rows_nun, + max_rows_num, + columns, + candidate_null_pcts, + rng: StdRng::from_rng(&mut rng()), + } + } + + /// Set a seed for the generator. The pseudo-randomly generated batches will be + /// deterministic for the same seed. + pub fn with_seed(mut self, seed: u64) -> Self { + self.rng = StdRng::seed_from_u64(seed); + self + } + + pub fn generate(&mut self) -> Result { + let num_rows = self.rng.random_range(self.min_rows_num..=self.max_rows_num); + let array_gen_rng = StdRng::from_seed(self.rng.random()); + let mut batch_gen_rng = StdRng::from_seed(self.rng.random()); + let columns = self.columns.clone(); + + // Build arrays + let mut arrays = Vec::with_capacity(columns.len()); + for col in columns.iter() { + let array = self.generate_array_of_type( + col, + num_rows, + &mut batch_gen_rng, + array_gen_rng.clone(), + ); + arrays.push(array); + } + + // Build schema + let fields = self + .columns + .iter() + .map(|col| Field::new(col.name.clone(), col.column_type.clone(), true)) + .collect::>(); + let schema = Arc::new(Schema::new(fields)); + + RecordBatch::try_new(schema, arrays).map_err(|e| arrow_datafusion_err!(e)) + } + + fn generate_array_of_type( + &mut self, + col: &ColumnDescr, + num_rows: usize, + batch_gen_rng: &mut StdRng, + array_gen_rng: StdRng, + ) -> ArrayRef { + let null_pct_idx = batch_gen_rng.random_range(0..self.candidate_null_pcts.len()); + let null_pct = self.candidate_null_pcts[null_pct_idx]; + + Self::generate_array_of_type_inner( + col, + num_rows, + batch_gen_rng, + array_gen_rng, + null_pct, + ) + } + + fn generate_array_of_type_inner( + col: &ColumnDescr, + num_rows: usize, + batch_gen_rng: &mut StdRng, + array_gen_rng: StdRng, + null_pct: f64, + ) -> ArrayRef { + let num_distinct = if num_rows > 1 { + batch_gen_rng.random_range(1..num_rows) + } else { + num_rows + }; + // cap to at most the num_distinct values + let max_num_distinct = col + .max_num_distinct + .map(|max| num_distinct.min(max)) + .unwrap_or(num_distinct); + + match col.column_type { + DataType::Int8 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + Int8Type + ) + } + DataType::Int16 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + Int16Type + ) + } + DataType::Int32 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + Int32Type + ) + } + DataType::Int64 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + Int64Type + ) + } + DataType::UInt8 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + UInt8Type + ) + } + DataType::UInt16 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + UInt16Type + ) + } + DataType::UInt32 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + UInt32Type + ) + } + DataType::UInt64 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + UInt64Type + ) + } + DataType::Float32 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + Float32Type + ) + } + DataType::Float64 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + Float64Type + ) + } + DataType::Date32 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + Date32Type + ) + } + DataType::Date64 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + Date64Type + ) + } + DataType::Time32(TimeUnit::Second) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + Time32SecondType + ) + } + DataType::Time32(TimeUnit::Millisecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + Time32MillisecondType + ) + } + DataType::Time64(TimeUnit::Microsecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + Time64MicrosecondType + ) + } + DataType::Time64(TimeUnit::Nanosecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + Time64NanosecondType + ) + } + DataType::Interval(IntervalUnit::YearMonth) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + IntervalYearMonthType + ) + } + DataType::Interval(IntervalUnit::DayTime) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + IntervalDayTimeType + ) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + IntervalMonthDayNanoType + ) + } + DataType::Duration(TimeUnit::Second) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + DurationSecondType + ) + } + DataType::Duration(TimeUnit::Millisecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + DurationMillisecondType + ) + } + DataType::Duration(TimeUnit::Microsecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + DurationMicrosecondType + ) + } + DataType::Duration(TimeUnit::Nanosecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + DurationNanosecondType + ) + } + DataType::Timestamp(TimeUnit::Second, None) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + TimestampSecondType + ) + } + DataType::Timestamp(TimeUnit::Millisecond, None) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + TimestampMillisecondType + ) + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + TimestampMicrosecondType + ) + } + DataType::Timestamp(TimeUnit::Nanosecond, None) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + TimestampNanosecondType + ) + } + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + let max_len = batch_gen_rng.random_range(1..50); + + let mut generator = StringArrayGenerator { + max_len, + num_strings: num_rows, + num_distinct_strings: max_num_distinct, + null_pct, + rng: array_gen_rng, + }; + + match col.column_type { + DataType::Utf8 => generator.gen_data::(), + DataType::LargeUtf8 => generator.gen_data::(), + DataType::Utf8View => generator.gen_string_view(), + _ => unreachable!(), + } + } + DataType::Binary | DataType::LargeBinary | DataType::BinaryView => { + let max_len = batch_gen_rng.random_range(1..100); + + let mut generator = BinaryArrayGenerator { + max_len, + num_binaries: num_rows, + num_distinct_binaries: max_num_distinct, + null_pct, + rng: array_gen_rng, + }; + + match col.column_type { + DataType::Binary => generator.gen_data::(), + DataType::LargeBinary => generator.gen_data::(), + DataType::BinaryView => generator.gen_binary_view(), + _ => unreachable!(), + } + } + DataType::Decimal32(precision, scale) => { + generate_decimal_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + precision, + scale, + Decimal32Type + ) + } + DataType::Decimal64(precision, scale) => { + generate_decimal_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + precision, + scale, + Decimal64Type + ) + } + DataType::Decimal128(precision, scale) => { + generate_decimal_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + precision, + scale, + Decimal128Type + ) + } + DataType::Decimal256(precision, scale) => { + generate_decimal_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + precision, + scale, + Decimal256Type + ) + } + DataType::Boolean => { + generate_boolean_array! { + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + BooleanType + } + } + DataType::Dictionary(ref key_type, ref value_type) + if key_type.is_dictionary_key_type() => + { + // We generate just num_distinct values because they will be reused by different keys + let mut array_gen_rng = array_gen_rng; + debug_assert!((0.0..=1.0).contains(&null_pct)); + let values = Self::generate_array_of_type_inner( + &ColumnDescr::new("values", *value_type.clone()), + num_distinct, + batch_gen_rng, + array_gen_rng.clone(), + null_pct, // generate some null values + ); + + match key_type.as_ref() { + // new key types can be added here + DataType::UInt64 => generate_dict!( + self, + num_rows, + num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + UInt64Type, + values + ), + _ => panic!("Invalid dictionary keys type: {key_type}"), + } + } + _ => { + panic!("Unsupported data generator type: {}", col.column_type) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generator_with_fixed_seed_deterministic() { + let mut gen1 = RecordBatchGenerator::new( + 16, + 32, + vec![ + ColumnDescr::new("a", DataType::Utf8), + ColumnDescr::new("b", DataType::UInt32), + ], + ) + .with_seed(310104); + + let mut gen2 = RecordBatchGenerator::new( + 16, + 32, + vec![ + ColumnDescr::new("a", DataType::Utf8), + ColumnDescr::new("b", DataType::UInt32), + ], + ) + .with_seed(310104); + + let batch1 = gen1.generate().unwrap(); + let batch2 = gen2.generate().unwrap(); + + let batch1_formatted = format!("{batch1:?}"); + let batch2_formatted = format!("{batch2:?}"); + + assert_eq!(batch1_formatted, batch2_formatted); + } +} diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index 0b0f0aa2f105a..28d28a6622a76 100644 --- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -188,7 +188,7 @@ impl SortTest { } fn with_sort_columns(mut self, sort_columns: Vec<&str>) -> Self { - self.sort_columns = sort_columns.iter().map(|s| s.to_string()).collect(); + self.sort_columns = sort_columns.iter().map(|s| (*s).to_string()).collect(); self } @@ -232,23 +232,20 @@ impl SortTest { .expect("at least one batch"); let schema = first_batch.schema(); - let sort_ordering = LexOrdering::new( - self.sort_columns - .iter() - .map(|c| PhysicalSortExpr { - expr: col(c, &schema).unwrap(), - options: SortOptions { - descending: false, - nulls_first: true, - }, - }) - .collect(), - ); + let sort_ordering = + LexOrdering::new(self.sort_columns.iter().map(|c| PhysicalSortExpr { + expr: col(c, &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + })) + .unwrap(); let exec = MemorySourceConfig::try_new_exec(&input, schema, None).unwrap(); let sort = Arc::new(SortExec::new(sort_ordering, exec)); - let session_config = SessionConfig::new(); + let session_config = SessionConfig::new().with_repartition_file_scans(false); let session_ctx = if let Some(pool_size) = self.pool_size { // Make sure there is enough space for the initial spill // reservation @@ -298,20 +295,20 @@ impl SortTest { /// Return randomly sized record batches in a field named 'x' of type `Int32` /// with randomized i32 content fn make_staggered_i32_batches(len: usize) -> Vec { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let max_batch = 1024; let mut batches = vec![]; let mut remaining = len; while remaining != 0 { - let to_read = rng.gen_range(0..=remaining.min(max_batch)); + let to_read = rng.random_range(0..=remaining.min(max_batch)); remaining -= to_read; batches.push( RecordBatch::try_from_iter(vec![( "x", Arc::new(Int32Array::from_iter_values( - (0..to_read).map(|_| rng.gen()), + (0..to_read).map(|_| rng.random()), )) as ArrayRef, )]) .unwrap(), @@ -323,20 +320,20 @@ fn make_staggered_i32_batches(len: usize) -> Vec { /// Return randomly sized record batches in a field named 'x' of type `Utf8` /// with randomized content fn make_staggered_utf8_batches(len: usize) -> Vec { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let max_batch = 1024; let mut batches = vec![]; let mut remaining = len; while remaining != 0 { - let to_read = rng.gen_range(0..=remaining.min(max_batch)); + let to_read = rng.random_range(0..=remaining.min(max_batch)); remaining -= to_read; batches.push( RecordBatch::try_from_iter(vec![( "x", Arc::new(StringArray::from_iter_values( - (0..to_read).map(|_| format!("test_string_{}", rng.gen::())), + (0..to_read).map(|_| format!("test_string_{}", rng.random::())), )) as ArrayRef, )]) .unwrap(), @@ -349,13 +346,13 @@ fn make_staggered_utf8_batches(len: usize) -> Vec { /// with randomized i32 content and a field named 'y' of type `Utf8` /// with randomized content fn make_staggered_i32_utf8_batches(len: usize) -> Vec { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let max_batch = 1024; let mut batches = vec![]; let mut remaining = len; while remaining != 0 { - let to_read = rng.gen_range(0..=remaining.min(max_batch)); + let to_read = rng.random_range(0..=remaining.min(max_batch)); remaining -= to_read; batches.push( @@ -363,13 +360,14 @@ fn make_staggered_i32_utf8_batches(len: usize) -> Vec { ( "x", Arc::new(Int32Array::from_iter_values( - (0..to_read).map(|_| rng.gen()), + (0..to_read).map(|_| rng.random()), )) as ArrayRef, ), ( "y", Arc::new(StringArray::from_iter_values( - (0..to_read).map(|_| format!("test_string_{}", rng.gen::())), + (0..to_read) + .map(|_| format!("test_string_{}", rng.random::())), )) as ArrayRef, ), ]) diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 06b93d41af362..99b20790fc46b 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -23,6 +23,8 @@ mod sp_repartition_fuzz_tests { use arrow::compute::{concat_batches, lexsort, SortColumn, SortOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion::datasource::memory::MemorySourceConfig; + use datafusion::datasource::source::DataSourceExec; use datafusion::physical_plan::{ collect, metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, @@ -34,19 +36,16 @@ mod sp_repartition_fuzz_tests { }; use datafusion::prelude::SessionContext; use datafusion_common::Result; - use datafusion_execution::{ - config::SessionConfig, memory_pool::MemoryConsumer, SendableRecordBatchStream, - }; - use datafusion_physical_expr::{ - equivalence::{EquivalenceClass, EquivalenceProperties}, - expressions::{col, Column}, - ConstExpr, PhysicalExpr, PhysicalSortExpr, + use datafusion_execution::{config::SessionConfig, memory_pool::MemoryConsumer}; + use datafusion_physical_expr::equivalence::{ + EquivalenceClass, EquivalenceProperties, }; + use datafusion_physical_expr::expressions::{col, Column}; + use datafusion_physical_expr::ConstExpr; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use test_utils::add_empty_batches; - use datafusion::datasource::memory::MemorySourceConfig; - use datafusion::datasource::source::DataSourceExec; - use datafusion_physical_expr_common::sort_expr::LexOrdering; use itertools::izip; use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; @@ -80,9 +79,9 @@ mod sp_repartition_fuzz_tests { let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_f))?; // Column e has constant value. - eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); + eq_properties.add_constants([ConstExpr::from(Arc::clone(col_e))])?; // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); @@ -94,18 +93,18 @@ mod sp_repartition_fuzz_tests { }; while !remaining_exprs.is_empty() { - let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + let n_sort_expr = rng.random_range(1..remaining_exprs.len() + 1); remaining_exprs.shuffle(&mut rng); - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: options_asc, - }) - .collect(); + let ordering = + remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: options_asc, + }); - eq_properties.add_new_orderings([ordering]); + eq_properties.add_ordering(ordering); } Ok((test_schema, eq_properties)) @@ -144,14 +143,14 @@ mod sp_repartition_fuzz_tests { // Utility closure to generate random array let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { let values: Vec = (0..num_elems) - .map(|_| rng.gen_range(0..max_val) as u64) + .map(|_| rng.random_range(0..max_val) as u64) .collect(); Arc::new(UInt64Array::from_iter_values(values)) }; // Fill constant columns for constant in eq_properties.constants() { - let col = constant.expr().as_any().downcast_ref::().unwrap(); + let col = constant.expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = Arc::new(UInt64Array::from_iter_values(vec![0; n_elem])) as ArrayRef; @@ -227,21 +226,21 @@ mod sp_repartition_fuzz_tests { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEM, N_DISTINCT)?; let schema = table_data_with_properties.schema(); - let streams: Vec = (0..N_PARTITION) + let streams = (0..N_PARTITION) .map(|_idx| { let batch = table_data_with_properties.clone(); Box::pin(RecordBatchStreamAdapter::new( schema.clone(), futures::stream::once(async { Ok(batch) }), - )) as SendableRecordBatchStream + )) as _ }) .collect::>(); - // Returns concatenated version of the all available orderings - let exprs = eq_properties - .oeq_class() - .output_ordering() - .unwrap_or_default(); + // Returns concatenated version of the all available orderings: + let Some(exprs) = eq_properties.oeq_class().output_ordering() else { + // We always should have an ordering due to the way we generate the schema: + unreachable!("No ordering found in eq_properties: {:?}", eq_properties); + }; let context = SessionContext::new().task_ctx(); let mem_reservation = @@ -261,7 +260,7 @@ mod sp_repartition_fuzz_tests { let res = concat_batches(&res[0].schema(), &res)?; for ordering in eq_properties.oeq_class().iter() { - let err_msg = format!("error in eq properties: {:?}", eq_properties); + let err_msg = format!("error in eq properties: {eq_properties:?}"); let sort_columns = ordering .iter() .map(|sort_expr| sort_expr.evaluate_to_sort_column(&res)) @@ -273,7 +272,7 @@ mod sp_repartition_fuzz_tests { let sorted_columns = lexsort(&sort_columns, None)?; // Make sure after merging ordering is still valid. - assert_eq!(orig_columns.len(), sorted_columns.len(), "{}", err_msg); + assert_eq!(orig_columns.len(), sorted_columns.len(), "{err_msg}"); assert!( izip!(orig_columns.into_iter(), sorted_columns.into_iter()) .all(|(lhs, rhs)| { lhs == rhs }), @@ -347,20 +346,16 @@ mod sp_repartition_fuzz_tests { let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); - let mut sort_keys = LexOrdering::default(); - for ordering_col in ["a", "b", "c"] { - sort_keys.push(PhysicalSortExpr { - expr: col(ordering_col, &schema).unwrap(), - options: SortOptions::default(), - }) - } + let sort_keys = ["a", "b", "c"].map(|ordering_col| { + PhysicalSortExpr::new_default(col(ordering_col, &schema).unwrap()) + }); let concat_input_record = concat_batches(&schema, &input1).unwrap(); let running_source = Arc::new( - MemorySourceConfig::try_new(&[input1.clone()], schema.clone(), None) + MemorySourceConfig::try_new(&[input1], schema.clone(), None) .unwrap() - .try_with_sort_information(vec![sort_keys.clone()]) + .try_with_sort_information(vec![sort_keys.clone().into()]) .unwrap(), ); let running_source = Arc::new(DataSourceExec::new(running_source)); @@ -381,7 +376,7 @@ mod sp_repartition_fuzz_tests { sort_preserving_repartition_exec_hash(intermediate, hash_exprs.clone()) }; - let final_plan = sort_preserving_merge_exec(sort_keys.clone(), intermediate); + let final_plan = sort_preserving_merge_exec(sort_keys.into(), intermediate); let task_ctx = ctx.task_ctx(); let collected_running = collect(final_plan, task_ctx.clone()).await.unwrap(); @@ -428,10 +423,9 @@ mod sp_repartition_fuzz_tests { } fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, + sort_exprs: LexOrdering, input: Arc, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) } @@ -447,9 +441,9 @@ mod sp_repartition_fuzz_tests { let mut input123: Vec<(i64, i64, i64)> = vec![(0, 0, 0); len]; input123.iter_mut().for_each(|v| { *v = ( - rng.gen_range(0..n_distinct) as i64, - rng.gen_range(0..n_distinct) as i64, - rng.gen_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, ) }); input123.sort(); @@ -471,7 +465,7 @@ mod sp_repartition_fuzz_tests { let mut batches = vec![]; if STREAM { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..50); + let batch_size = rng.random_range(0..50); if remainder.num_rows() < batch_size { break; } @@ -481,7 +475,7 @@ mod sp_repartition_fuzz_tests { } } else { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + let batch_size = rng.random_range(0..remainder.num_rows() + 1); batches.push(remainder.slice(0, batch_size)); remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); diff --git a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs new file mode 100644 index 0000000000000..2ce7db3ea4bc7 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs @@ -0,0 +1,624 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Fuzz Test for various corner cases sorting RecordBatches exceeds available memory and should spill + +use std::cmp::min; +use std::sync::Arc; + +use arrow::array::RecordBatch; +use arrow_schema::SchemaRef; +use datafusion::datasource::MemTable; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_common::{instant::Instant, Result}; +use datafusion_execution::disk_manager::DiskManagerBuilder; +use datafusion_execution::memory_pool::{ + human_readable_size, MemoryPool, UnboundedMemoryPool, +}; +use datafusion_expr::display_schema; +use datafusion_physical_plan::spill::get_record_batch_memory_size; +use std::time::Duration; + +use datafusion_execution::{memory_pool::FairSpillPool, runtime_env::RuntimeEnvBuilder}; +use rand::prelude::IndexedRandom; +use rand::Rng; +use rand::{rngs::StdRng, SeedableRng}; + +use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; + +use super::aggregation_fuzzer::ColumnDescr; +use super::record_batch_generator::{get_supported_types_columns, RecordBatchGenerator}; + +/// Entry point for executing the sort query fuzzer. +/// +/// Now memory limiting is disabled by default. See TODOs in `SortQueryFuzzer`. +#[tokio::test(flavor = "multi_thread")] +async fn sort_query_fuzzer_runner() { + let random_seed = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + let test_generator = SortFuzzerTestGenerator::new( + 2000, + 3, + "sort_fuzz_table".to_string(), + get_supported_types_columns(random_seed), + false, + random_seed, + ); + let mut fuzzer = SortQueryFuzzer::new(random_seed) + // Configs for how many random query to test + .with_max_rounds(Some(5)) + .with_queries_per_round(4) + .with_config_variations_per_query(5) + // Will stop early if the time limit is reached + .with_time_limit(Duration::from_secs(5)) + .with_test_generator(test_generator); + + fuzzer.run().await.unwrap(); +} + +/// SortQueryFuzzer holds the runner configuration for executing sort query fuzz tests. The fuzzing details are managed inside `SortFuzzerTestGenerator`. +/// +/// It defines: +/// - `max_rounds`: Maximum number of rounds to run (or None to run until `time_limit`). +/// - `queries_per_round`: Number of different queries to run in each round. +/// - `config_variations_per_query`: Number of different configurations to test per query. +/// - `time_limit`: Time limit for the entire fuzzer execution. +/// +/// TODO: The following improvements are blocked on https://github.com/apache/datafusion/issues/14748: +/// 1. Support generating queries with arbitrary number of ORDER BY clauses +/// Currently limited to be smaller than number of projected columns +/// 2. Enable special type columns like utf8_low to be used in ORDER BY clauses +/// 3. Enable memory limiting functionality in the fuzzer runner +pub struct SortQueryFuzzer { + test_gen: SortFuzzerTestGenerator, + /// Random number generator for the runner, used to generate seeds for inner components. + /// Seeds for each choice (query, config, etc.) are printed out for reproducibility. + runner_rng: StdRng, + + // ======================================================================== + // Runner configurations + // ======================================================================== + /// For each round, a new dataset is generated. If `None`, keep running until + /// the time limit is reached + max_rounds: Option, + /// How many different queries to run in each round + queries_per_round: usize, + /// For each query, how many different configurations to try and make sure their + /// results are consistent + config_variations_per_query: usize, + /// The time limit for the entire sort query fuzzer execution. + time_limit: Option, +} + +impl SortQueryFuzzer { + pub fn new(seed: u64) -> Self { + let max_rounds = Some(2); + let queries_per_round = 3; + let config_variations_per_query = 5; + let time_limit = None; + + // Filtered out one column due to a known bug https://github.com/apache/datafusion/issues/14748 + // TODO: Remove this once the bug is fixed + let candidate_columns = get_supported_types_columns(seed) + .into_iter() + .filter(|col| { + col.name != "utf8_low" + && col.name != "utf8view" + && col.name != "binaryview" + }) + .collect::>(); + + let test_gen = SortFuzzerTestGenerator::new( + 10000, + 4, + "sort_fuzz_table".to_string(), + candidate_columns, + false, + seed, + ); + + Self { + max_rounds, + queries_per_round, + config_variations_per_query, + time_limit, + test_gen, + runner_rng: StdRng::seed_from_u64(seed), + } + } + + pub fn with_test_generator(mut self, test_gen: SortFuzzerTestGenerator) -> Self { + self.test_gen = test_gen; + self + } + + pub fn with_max_rounds(mut self, max_rounds: Option) -> Self { + self.max_rounds = max_rounds; + self + } + + pub fn with_queries_per_round(mut self, queries_per_round: usize) -> Self { + self.queries_per_round = queries_per_round; + self + } + + pub fn with_config_variations_per_query( + mut self, + config_variations_per_query: usize, + ) -> Self { + self.config_variations_per_query = config_variations_per_query; + self + } + + pub fn with_time_limit(mut self, time_limit: Duration) -> Self { + self.time_limit = Some(time_limit); + self + } + + fn should_stop_due_to_time_limit( + &self, + start_time: Instant, + n_round: usize, + n_query: usize, + ) -> bool { + if let Some(time_limit) = self.time_limit { + if Instant::now().duration_since(start_time) > time_limit { + println!( + "[SortQueryFuzzer] Time limit reached: {} queries ({} random configs each) in {} rounds", + n_round * self.queries_per_round + n_query, + self.config_variations_per_query, + n_round + ); + return true; + } + } + false + } + + pub async fn run(&mut self) -> Result<()> { + let start_time = Instant::now(); + + // Execute until either`max_rounds` or `time_limit` is reached + let max_rounds = self.max_rounds.unwrap_or(usize::MAX); + for round in 0..max_rounds { + let init_seed = self.runner_rng.random(); + for query_i in 0..self.queries_per_round { + let query_seed = self.runner_rng.random(); + let mut expected_results: Option> = None; // use first config's result as the expected result + for config_i in 0..self.config_variations_per_query { + if self.should_stop_due_to_time_limit(start_time, round, query_i) { + return Ok(()); + } + + let config_seed = self.runner_rng.random(); + + println!( + "[SortQueryFuzzer] Round {round}, Query {query_i} (Config {config_i})" + ); + println!(" Seeds:"); + println!(" init_seed = {init_seed}"); + println!(" query_seed = {query_seed}"); + println!(" config_seed = {config_seed}"); + + let results = self + .test_gen + .fuzzer_run(init_seed, query_seed, config_seed) + .await?; + println!("\n"); // Separator between tested runs + + if expected_results.is_none() { + expected_results = Some(results); + } else if let Some(ref expected) = expected_results { + // `fuzzer_run` might append `LIMIT k` to either the + // expected or actual query. The number of results is + // checked inside `fuzzer_run()`. Here we only check + // that the first k rows of each result are consistent. + check_equality_of_batches(expected, &results).unwrap(); + } else { + unreachable!(); + } + } + } + } + Ok(()) + } +} + +/// Struct to generate and manage a random dataset for fuzz testing. +/// It is able to re-run the failed test cases by setting the same seed printed out. +/// See the unit tests for examples. +/// +/// To use this struct: +/// 1. Call `init_partitioned_staggered_batches` to generate a random dataset. +/// 2. Use `generate_random_query` to create a random SQL query. +/// 3. Use `generate_random_config` to create a random configuration. +/// 4. Run the fuzzer check with the generated query and configuration. +pub struct SortFuzzerTestGenerator { + /// The total number of rows for the registered table + num_rows: usize, + /// Max number of partitions for the registered table + max_partitions: usize, + /// The name of the registered table + table_name: String, + /// The selected columns from all available candidate columns to be used for + /// this dataset + selected_columns: Vec, + /// If true, will randomly generate a memory limit for the query. Otherwise + /// the query will run under the context with unlimited memory. + set_memory_limit: bool, + + /// States related to the randomly generated dataset. `None` if not initialized + /// by calling `init_partitioned_staggered_batches()` + dataset_state: Option, +} + +/// Struct to hold states related to the randomly generated dataset +pub struct DatasetState { + /// Dataset to construct the partitioned memory table. Outer vector is the + /// partitions, inner vector is staggered batches within the same partition. + partitioned_staggered_batches: Vec>, + /// Number of rows in the whole dataset + dataset_size: usize, + /// The approximate number of rows of a batch (staggered batches will be generated + /// with random number of rows between 1 and `approx_batch_size`) + approx_batch_num_rows: usize, + /// The schema of the dataset + schema: SchemaRef, + /// The memory size of the whole dataset + mem_size: usize, +} + +impl SortFuzzerTestGenerator { + /// Randomly pick a subset of `candidate_columns` to be used for this dataset + pub fn new( + num_rows: usize, + max_partitions: usize, + table_name: String, + candidate_columns: Vec, + set_memory_limit: bool, + rng_seed: u64, + ) -> Self { + let mut rng = StdRng::seed_from_u64(rng_seed); + let min_ncol = min(candidate_columns.len(), 5); + let max_ncol = min(candidate_columns.len(), 10); + let amount = rng.random_range(min_ncol..=max_ncol); + let selected_columns = candidate_columns + .choose_multiple(&mut rng, amount) + .cloned() + .collect(); + + Self { + num_rows, + max_partitions, + table_name, + selected_columns, + set_memory_limit, + dataset_state: None, + } + } + + /// The outer vector is the partitions, the inner vector is the chunked batches + /// within each partition. + /// The partition number is determined by `self.max_partitions`. + /// The chunked batch length is a random number between 1 and `self.num_rows` / + /// 100 (make sure a single batch won't exceed memory budget for external sort + /// executions) + /// + /// Hack: If we want the query to run under certain degree of parallelism, the + /// memory table should be generated with more partitions, due to https://github.com/apache/datafusion/issues/15088 + fn init_partitioned_staggered_batches(&mut self, rng_seed: u64) { + let mut rng = StdRng::seed_from_u64(rng_seed); + let num_partitions = rng.random_range(1..=self.max_partitions); + + let max_batch_size = self.num_rows / num_partitions / 50; + let target_partition_size = self.num_rows / num_partitions; + + let mut partitions = Vec::new(); + let mut schema = None; + for _ in 0..num_partitions { + let mut partition = Vec::new(); + let mut num_rows = 0; + + // For each partition, generate random batches until there is about enough + // rows for the specified total number of rows + while num_rows < target_partition_size { + // Generate a random batch of size between 1 and max_batch_size + + // Let edge case (1-row batch) more common + let (min_nrow, max_nrow) = if rng.random_bool(0.1) { + (1, 3) + } else { + (1, max_batch_size) + }; + + let mut record_batch_generator = RecordBatchGenerator::new( + min_nrow, + max_nrow, + self.selected_columns.clone(), + ) + .with_seed(rng.random()); + + let record_batch = record_batch_generator.generate().unwrap(); + num_rows += record_batch.num_rows(); + + if schema.is_none() { + schema = Some(record_batch.schema()); + println!(" Dataset schema:"); + println!(" {}", display_schema(schema.as_ref().unwrap())); + } + + partition.push(record_batch); + } + + partitions.push(partition); + } + + // After all partitions are created, optionally make one partition have 0/1 batch + if num_partitions > 2 && rng.random_bool(0.1) { + let partition_index = rng.random_range(0..num_partitions); + if rng.random_bool(0.5) { + // 0 batch + partitions[partition_index] = Vec::new(); + } else { + // 1 batch, keep the old first batch + let first_batch = partitions[partition_index].first().cloned(); + if let Some(batch) = first_batch { + partitions[partition_index] = vec![batch]; + } + } + } + + // Init self fields + let mem_size: usize = partitions + .iter() + .map(|partition| { + partition + .iter() + .map(get_record_batch_memory_size) + .sum::() + }) + .sum(); + + let dataset_size = partitions + .iter() + .map(|partition| { + partition + .iter() + .map(|batch| batch.num_rows()) + .sum::() + }) + .sum::(); + + let approx_batch_num_rows = max_batch_size; + + self.dataset_state = Some(DatasetState { + partitioned_staggered_batches: partitions, + dataset_size, + approx_batch_num_rows, + schema: schema.unwrap(), + mem_size, + }); + } + + /// Generates a random SQL query string and an optional limit value. + /// Returns a tuple containing the query string and an optional limit. + pub fn generate_random_query(&self, rng_seed: u64) -> (String, Option) { + let mut rng = StdRng::seed_from_u64(rng_seed); + + let num_columns = rng.random_range(1..=3).min(self.selected_columns.len()); + let selected_columns: Vec<_> = self + .selected_columns + .choose_multiple(&mut rng, num_columns) + .collect(); + + let mut order_by_clauses = Vec::new(); + for col in &selected_columns { + let mut clause = col.name.clone(); + if rng.random_bool(0.5) { + let order = if rng.random_bool(0.5) { "ASC" } else { "DESC" }; + clause.push_str(&format!(" {order}")); + } + if rng.random_bool(0.5) { + let nulls = if rng.random_bool(0.5) { + "NULLS FIRST" + } else { + "NULLS LAST" + }; + clause.push_str(&format!(" {nulls}")); + } + order_by_clauses.push(clause); + } + + let dataset_size = self.dataset_state.as_ref().unwrap().dataset_size; + + let limit = if rng.random_bool(0.2) { + // Prefer edge cases for k like 1, dataset_size, etc. + Some(if rng.random_bool(0.5) { + let edge_cases = + [1, 2, 3, dataset_size - 1, dataset_size, dataset_size + 1]; + *edge_cases.choose(&mut rng).unwrap() + } else { + rng.random_range(1..=dataset_size) + }) + } else { + None + }; + + let limit_clause = limit.map_or(String::new(), |l| format!(" LIMIT {l}")); + + let query = format!( + "SELECT {} FROM {} ORDER BY {}{}", + selected_columns + .iter() + .map(|col| col.name.clone()) + .collect::>() + .join(", "), + self.table_name, + order_by_clauses.join(", "), + limit_clause + ); + + (query, limit) + } + + pub fn generate_random_config( + &self, + rng_seed: u64, + with_memory_limit: bool, + ) -> Result { + let mut rng = StdRng::seed_from_u64(rng_seed); + let init_state = self.dataset_state.as_ref().unwrap(); + let dataset_size = init_state.mem_size; + let num_partitions = init_state.partitioned_staggered_batches.len(); + + // 30% to 200% of the dataset size (if `with_memory_limit` is false, config + // will use the default unbounded pool to override it later) + let memory_limit = rng.random_range( + (dataset_size as f64 * 0.5) as usize..=(dataset_size as f64 * 2.0) as usize, + ); + // 10% to 20% of the per-partition memory limit size + let per_partition_mem_limit = memory_limit / num_partitions; + let sort_spill_reservation_bytes = rng.random_range( + (per_partition_mem_limit as f64 * 0.2) as usize + ..=(per_partition_mem_limit as f64 * 0.3) as usize, + ); + + // 1 to 3 times of the approx batch size. Setting this to a very large nvalue + // will cause external sort to fail. + let sort_in_place_threshold_bytes = if with_memory_limit { + // For memory-limited query, setting `sort_in_place_threshold_bytes` too + // large will cause failure. + 0 + } else { + let dataset_size = self.dataset_state.as_ref().unwrap().dataset_size; + rng.random_range(0..=dataset_size * 2_usize) + }; + + // Set up strings for printing + let memory_limit_str = if with_memory_limit { + human_readable_size(memory_limit) + } else { + "Unbounded".to_string() + }; + let per_partition_limit_str = if with_memory_limit { + human_readable_size(per_partition_mem_limit) + } else { + "Unbounded".to_string() + }; + + println!(" Config: "); + println!(" Dataset size: {}", human_readable_size(dataset_size)); + println!(" Number of partitions: {num_partitions}"); + println!(" Batch size: {}", init_state.approx_batch_num_rows / 2); + println!(" Memory limit: {memory_limit_str}"); + println!(" Per partition memory limit: {per_partition_limit_str}"); + println!( + " Sort spill reservation bytes: {}", + human_readable_size(sort_spill_reservation_bytes) + ); + println!( + " Sort in place threshold bytes: {}", + human_readable_size(sort_in_place_threshold_bytes) + ); + + let config = SessionConfig::new() + .with_target_partitions(num_partitions) + .with_batch_size(init_state.approx_batch_num_rows / 2) + .with_sort_spill_reservation_bytes(sort_spill_reservation_bytes) + .with_sort_in_place_threshold_bytes(sort_in_place_threshold_bytes); + + let memory_pool: Arc = if with_memory_limit { + Arc::new(FairSpillPool::new(memory_limit)) + } else { + Arc::new(UnboundedMemoryPool::default()) + }; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .with_disk_manager_builder(DiskManagerBuilder::default()) + .build_arc()?; + + let ctx = SessionContext::new_with_config_rt(config, runtime); + + let dataset = &init_state.partitioned_staggered_batches; + let schema = &init_state.schema; + + let provider = MemTable::try_new(schema.clone(), dataset.clone())?; + ctx.register_table("sort_fuzz_table", Arc::new(provider))?; + + Ok(ctx) + } + + async fn fuzzer_run( + &mut self, + dataset_seed: u64, + query_seed: u64, + config_seed: u64, + ) -> Result> { + self.init_partitioned_staggered_batches(dataset_seed); + let (query_str, limit) = self.generate_random_query(query_seed); + println!(" Query:"); + println!(" {query_str}"); + + // ==== Execute the query ==== + + // Only enable memory limits if: + // 1. Query does not contain LIMIT (since topK does not support external execution) + // 2. Memory limiting is enabled in the test generator config + let with_mem_limit = !query_str.contains("LIMIT") && self.set_memory_limit; + + let ctx = self.generate_random_config(config_seed, with_mem_limit)?; + let df = ctx.sql(&query_str).await.unwrap(); + let results = df.collect().await.unwrap(); + + // ==== Check the result size is consistent with the limit ==== + let result_num_rows = results.iter().map(|batch| batch.num_rows()).sum::(); + let dataset_size = self.dataset_state.as_ref().unwrap().dataset_size; + + if let Some(limit) = limit { + let expected_num_rows = min(limit, dataset_size); + assert_eq!(result_num_rows, expected_num_rows); + } + + Ok(results) + } +} + +#[cfg(test)] +mod test { + use super::*; + + /// Given the same seed, the result should be the same + #[tokio::test] + async fn test_sort_query_fuzzer_deterministic() { + let gen_seed = 310104; + let mut test_generator = SortFuzzerTestGenerator::new( + 2000, + 3, + "sort_fuzz_table".to_string(), + get_supported_types_columns(gen_seed), + false, + gen_seed, + ); + + let res1 = test_generator.fuzzer_run(1, 2, 3).await.unwrap(); + let res2 = test_generator.fuzzer_run(1, 2, 3).await.unwrap(); + check_equality_of_batches(&res1, &res2).unwrap(); + } +} diff --git a/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs new file mode 100644 index 0000000000000..6c1bd316cdd39 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/spilling_fuzz_in_memory_constrained_env.rs @@ -0,0 +1,654 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Fuzz Test for different operators in memory constrained environment + +use std::pin::Pin; +use std::sync::Arc; + +use crate::fuzz_cases::aggregate_fuzz::assert_spill_count_metric; +use crate::fuzz_cases::once_exec::OnceExec; +use arrow::array::UInt64Array; +use arrow::{array::StringArray, compute::SortOptions, record_batch::RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::common::Result; +use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::physical_plan::expressions::PhysicalSortExpr; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionConfig; +use datafusion_execution::memory_pool::units::{KB, MB}; +use datafusion_execution::memory_pool::{ + FairSpillPool, MemoryConsumer, MemoryReservation, +}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_functions_aggregate::array_agg::array_agg_udaf; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; +use datafusion_physical_expr::expressions::{col, Column}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use futures::StreamExt; + +#[tokio::test] +async fn test_sort_with_limited_memory() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + let record_batch_size = pool_size / 16; + + // Basic test with a lot of groups that cannot all fit in memory and 1 record batch + // from each spill file is too much memory + let spill_count = run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), + memory_behavior: Default::default(), + }) + .await?; + + let total_spill_files_size = spill_count * record_batch_size; + assert!( + total_spill_files_size > pool_size, + "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", + ); + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() -> Result<()> +{ + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + 16 * KB as usize + } + }), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + 16 * KB as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + 16 * KB as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_large_record_batch() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory + run_sort_test_with_limited_memory(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 6), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +struct RunTestWithLimitedMemoryArgs { + pool_size: usize, + task_ctx: Arc, + number_of_record_batches: usize, + get_size_of_record_batch_to_generate: + Pin usize + Send + 'static>>, + memory_behavior: MemoryBehavior, +} + +#[derive(Default)] +enum MemoryBehavior { + #[default] + AsIs, + TakeAllMemoryAtTheBeginning, + TakeAllMemoryAndReleaseEveryNthBatch(usize), +} + +async fn run_sort_test_with_limited_memory( + mut args: RunTestWithLimitedMemoryArgs, +) -> Result { + let get_size_of_record_batch_to_generate = std::mem::replace( + &mut args.get_size_of_record_batch_to_generate, + Box::pin(move |_| unreachable!("should not be called after take")), + ); + + let scan_schema = Arc::new(Schema::new(vec![ + Field::new("col_0", DataType::UInt64, true), + Field::new("col_1", DataType::Utf8, true), + ])); + + let record_batch_size = args.task_ctx.session_config().batch_size() as u64; + + let schema = Arc::clone(&scan_schema); + let plan: Arc = + Arc::new(OnceExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..args.number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = + get_size_of_record_batch_to_generate(index as usize); + record_batch_memory_size = record_batch_memory_size + .saturating_sub(size_of::() * record_batch_size as usize); + + let string_item_size = + record_batch_memory_size / record_batch_size as usize; + let string_array = Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(string_item_size)), + )); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + string_array, + ], + ) + .map_err(|err| err.into()) + }, + )), + )))); + let sort_exec = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: col("col_0", &scan_schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]) + .unwrap(), + plan, + )); + + let result = sort_exec.execute(0, Arc::clone(&args.task_ctx))?; + + run_test(args, sort_exec, result).await +} + +fn grow_memory_as_much_as_possible( + memory_step: usize, + memory_reservation: &mut MemoryReservation, +) -> Result { + let mut was_able_to_grow = false; + while memory_reservation.try_grow(memory_step).is_ok() { + was_able_to_grow = true; + } + + Ok(was_able_to_grow) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + let record_batch_size = pool_size / 16; + + // Basic test with a lot of groups that cannot all fit in memory and 1 record batch + // from each spill file is too much memory + let spill_count = + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), + memory_behavior: Default::default(), + }) + .await?; + + let total_spill_files_size = spill_count * record_batch_size; + assert!( + total_spill_files_size > pool_size, + "Total spill files size {total_spill_files_size} should be greater than pool size {pool_size}", + ); + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + (16 * KB) as usize + } + }), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + (16 * KB) as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 6 + } else { + (16 * KB) as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_aggregate_with_high_cardinality_with_limited_memory_and_large_record_batch( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory + run_test_aggregate_with_high_cardinality(RunTestWithLimitedMemoryArgs { + pool_size, + task_ctx: Arc::new(task_ctx), + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 6), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +async fn run_test_aggregate_with_high_cardinality( + mut args: RunTestWithLimitedMemoryArgs, +) -> Result { + let get_size_of_record_batch_to_generate = std::mem::replace( + &mut args.get_size_of_record_batch_to_generate, + Box::pin(move |_| unreachable!("should not be called after take")), + ); + let scan_schema = Arc::new(Schema::new(vec![ + Field::new("col_0", DataType::UInt64, true), + Field::new("col_1", DataType::Utf8, true), + ])); + + let group_by = PhysicalGroupBy::new_single(vec![( + Arc::new(Column::new("col_0", 0)), + "col_0".to_string(), + )]); + + let aggregate_expressions = vec![Arc::new( + AggregateExprBuilder::new( + array_agg_udaf(), + vec![col("col_1", &scan_schema).unwrap()], + ) + .schema(Arc::clone(&scan_schema)) + .alias("array_agg(col_1)") + .build()?, + )]; + + let record_batch_size = args.task_ctx.session_config().batch_size() as u64; + + let schema = Arc::clone(&scan_schema); + let plan: Arc = + Arc::new(OnceExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..args.number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = + get_size_of_record_batch_to_generate(index as usize); + record_batch_memory_size = record_batch_memory_size + .saturating_sub(size_of::() * record_batch_size as usize); + + let string_item_size = + record_batch_memory_size / record_batch_size as usize; + let string_array = Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(string_item_size)), + )); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + // Grouping key + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + // Grouping value + string_array, + ], + ) + .map_err(|err| err.into()) + }, + )), + )))); + + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggregate_expressions.clone(), + vec![None; aggregate_expressions.len()], + plan, + Arc::clone(&scan_schema), + )?); + let aggregate_final = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expressions.clone(), + vec![None; aggregate_expressions.len()], + aggregate_exec, + Arc::clone(&scan_schema), + )?); + + let result = aggregate_final.execute(0, Arc::clone(&args.task_ctx))?; + + run_test(args, aggregate_final, result).await +} + +async fn run_test( + args: RunTestWithLimitedMemoryArgs, + plan: Arc, + result_stream: SendableRecordBatchStream, +) -> Result { + let number_of_record_batches = args.number_of_record_batches; + + consume_stream_and_simulate_other_running_memory_consumers(args, result_stream) + .await?; + + let spill_count = assert_spill_count_metric(true, plan); + + assert!( + spill_count > 0, + "Expected spill, but did not, number of record batches: {number_of_record_batches}", + ); + + Ok(spill_count) +} + +/// Consume the stream and change the amount of memory used while consuming it based on the [`MemoryBehavior`] provided +async fn consume_stream_and_simulate_other_running_memory_consumers( + args: RunTestWithLimitedMemoryArgs, + mut result_stream: SendableRecordBatchStream, +) -> Result<()> { + let mut number_of_rows = 0; + let record_batch_size = args.task_ctx.session_config().batch_size() as u64; + + let memory_pool = args.task_ctx.memory_pool(); + let memory_consumer = MemoryConsumer::new("mock_memory_consumer"); + let mut memory_reservation = memory_consumer.register(memory_pool); + + let mut index = 0; + let mut memory_took = false; + + while let Some(batch) = result_stream.next().await { + match args.memory_behavior { + MemoryBehavior::AsIs => { + // Do nothing + } + MemoryBehavior::TakeAllMemoryAtTheBeginning => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(10, &mut memory_reservation)?; + } + } + MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible( + args.pool_size, + &mut memory_reservation, + )?; + } else if index % n == 0 { + // release memory + memory_reservation.free(); + } + } + } + + let batch = batch?; + number_of_rows += batch.num_rows(); + + index += 1; + } + + assert_eq!( + number_of_rows, + args.number_of_record_batches * record_batch_size as usize + ); + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs b/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs new file mode 100644 index 0000000000000..7f994daeaa58c --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs @@ -0,0 +1,387 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::sync::{Arc, LazyLock}; + +use arrow::array::{Int32Array, StringArray, StringDictionaryBuilder}; +use arrow::datatypes::Int32Type; +use arrow::record_batch::RecordBatch; +use arrow::util::pretty::pretty_format_batches; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::datasource::listing::{ListingOptions, ListingTable, ListingTableConfig}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_datasource::ListingTableUrl; +use datafusion_datasource_parquet::ParquetFormat; +use datafusion_execution::object_store::ObjectStoreUrl; +use itertools::Itertools; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, PutPayload}; +use parquet::arrow::ArrowWriter; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use tokio::sync::Mutex; +use tokio::task::JoinSet; + +#[derive(Clone)] +struct TestDataSet { + store: Arc, + schema: Arc, +} + +/// List of in memory parquet files with UTF8 data +// Use a mutex rather than LazyLock to allow for async initialization +static TESTFILES: LazyLock>> = + LazyLock::new(|| Mutex::new(vec![])); + +async fn test_files() -> Vec { + let files_mutex = &TESTFILES; + let mut files = files_mutex.lock().await; + if !files.is_empty() { + return (*files).clone(); + } + + let mut rng = StdRng::seed_from_u64(0); + + for nulls_in_ids in [false, true] { + for nulls_in_names in [false, true] { + for nulls_in_departments in [false, true] { + let store = Arc::new(InMemory::new()); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, nulls_in_ids), + Field::new("name", DataType::Utf8, nulls_in_names), + Field::new( + "department", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), + nulls_in_departments, + ), + ])); + + let name_choices = if nulls_in_names { + [Some("Alice"), Some("Bob"), None, Some("David"), None] + } else { + [ + Some("Alice"), + Some("Bob"), + Some("Charlie"), + Some("David"), + Some("Eve"), + ] + }; + + let department_choices = if nulls_in_departments { + [ + Some("Theater"), + Some("Engineering"), + None, + Some("Arts"), + None, + ] + } else { + [ + Some("Theater"), + Some("Engineering"), + Some("Healthcare"), + Some("Arts"), + Some("Music"), + ] + }; + + // Generate 5 files, some with overlapping or repeated ids some without + for i in 0..5 { + let num_batches = rng.random_range(1..3); + let mut batches = Vec::with_capacity(num_batches); + for _ in 0..num_batches { + let num_rows = 25; + let ids = Int32Array::from_iter((0..num_rows).map(|file| { + if nulls_in_ids { + if rng.random_bool(1.0 / 10.0) { + None + } else { + Some(rng.random_range(file..file + 5)) + } + } else { + Some(rng.random_range(file..file + 5)) + } + })); + let names = StringArray::from_iter((0..num_rows).map(|_| { + // randomly select a name + let idx = rng.random_range(0..name_choices.len()); + name_choices[idx].map(|s| s.to_string()) + })); + let mut departments = StringDictionaryBuilder::::new(); + for _ in 0..num_rows { + // randomly select a department + let idx = rng.random_range(0..department_choices.len()); + departments.append_option(department_choices[idx].as_ref()); + } + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(ids), + Arc::new(names), + Arc::new(departments.finish()), + ], + ) + .unwrap(); + batches.push(batch); + } + let mut buf = vec![]; + { + let mut writer = + ArrowWriter::try_new(&mut buf, schema.clone(), None).unwrap(); + for batch in batches { + writer.write(&batch).unwrap(); + writer.flush().unwrap(); + } + writer.flush().unwrap(); + writer.finish().unwrap(); + } + let payload = PutPayload::from(buf); + let path = Path::from(format!("file_{i}.parquet")); + store.put(&path, payload).await.unwrap(); + } + files.push(TestDataSet { store, schema }); + } + } + } + (*files).clone() +} + +struct RunResult { + results: Vec, + explain_plan: String, +} + +async fn run_query_with_config( + query: &str, + config: SessionConfig, + dataset: TestDataSet, +) -> RunResult { + let store = dataset.store; + let schema = dataset.schema; + let ctx = SessionContext::new_with_config(config); + let url = ObjectStoreUrl::parse("memory://").unwrap(); + ctx.register_object_store(url.as_ref(), store.clone()); + + let format = Arc::new( + ParquetFormat::default() + .with_options(ctx.state().table_options().parquet.clone()), + ); + let options = ListingOptions::new(format); + let table_path = ListingTableUrl::parse("memory:///").unwrap(); + let config = ListingTableConfig::new(table_path) + .with_listing_options(options) + .with_schema(schema); + let table = Arc::new(ListingTable::try_new(config).unwrap()); + + ctx.register_table("test_table", table).unwrap(); + + let results = ctx.sql(query).await.unwrap().collect().await.unwrap(); + let explain_batches = ctx + .sql(&format!("EXPLAIN ANALYZE {query}")) + .await + .unwrap() + .collect() + .await + .unwrap(); + let explain_plan = pretty_format_batches(&explain_batches).unwrap().to_string(); + RunResult { + results, + explain_plan, + } +} + +#[derive(Debug)] +struct RunQueryResult { + query: String, + result: Vec, + expected: Vec, +} + +impl RunQueryResult { + fn expected_formatted(&self) -> String { + format!("{}", pretty_format_batches(&self.expected).unwrap()) + } + + fn result_formatted(&self) -> String { + format!("{}", pretty_format_batches(&self.result).unwrap()) + } + + fn is_ok(&self) -> bool { + self.expected_formatted() == self.result_formatted() + } +} + +/// Iterate over each line in the plan and check that one of them has `DataSourceExec` and `DynamicFilter` in the same line. +fn has_dynamic_filter_expr_pushdown(plan: &str) -> bool { + for line in plan.lines() { + if line.contains("DataSourceExec") && line.contains("DynamicFilter") { + return true; + } + } + false +} + +async fn run_query( + query: String, + cfg: SessionConfig, + dataset: TestDataSet, +) -> RunQueryResult { + let cfg_with_dynamic_filters = cfg + .clone() + .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true); + let cfg_without_dynamic_filters = cfg + .clone() + .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", false); + + let expected_result = + run_query_with_config(&query, cfg_without_dynamic_filters, dataset.clone()).await; + let result = + run_query_with_config(&query, cfg_with_dynamic_filters, dataset.clone()).await; + // Check that dynamic filters were actually pushed down + if !has_dynamic_filter_expr_pushdown(&result.explain_plan) { + panic!( + "Dynamic filter was not pushed down in query: {query}\n\n{}", + result.explain_plan + ); + } + + RunQueryResult { + query: query.to_string(), + result: result.results, + expected: expected_result.results, + } +} + +struct TestCase { + query: String, + cfg: SessionConfig, + dataset: TestDataSet, +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_fuzz_topk_filter_pushdown() { + let order_columns = ["id", "name", "department"]; + let order_directions = ["ASC", "DESC"]; + let null_orders = ["NULLS FIRST", "NULLS LAST"]; + + let start = datafusion_common::instant::Instant::now(); + let mut orders: HashMap> = HashMap::new(); + for order_column in &order_columns { + for order_direction in &order_directions { + for null_order in &null_orders { + // if there is a vec for this column insert the order, otherwise create a new vec + let ordering = format!("{order_column} {order_direction} {null_order}"); + match orders.get_mut(*order_column) { + Some(order_vec) => { + order_vec.push(ordering); + } + None => { + orders.insert((*order_column).to_string(), vec![ordering]); + } + } + } + } + } + + let mut queries = vec![]; + + for limit in [1, 10] { + for num_order_by_columns in [1, 2, 3] { + for order_columns in ["id", "name", "department"] + .iter() + .combinations(num_order_by_columns) + { + for orderings in order_columns + .iter() + .map(|col| orders.get(**col).unwrap()) + .multi_cartesian_product() + { + let query = format!( + "SELECT * FROM test_table ORDER BY {} LIMIT {}", + orderings.into_iter().join(", "), + limit + ); + queries.push(query); + } + } + } + } + + queries.sort_unstable(); + println!( + "Generated {} queries in {:?}", + queries.len(), + start.elapsed() + ); + + let start = datafusion_common::instant::Instant::now(); + let datasets = test_files().await; + println!("Generated test files in {:?}", start.elapsed()); + + let mut test_cases = vec![]; + for enable_filter_pushdown in [true, false] { + for query in &queries { + for dataset in &datasets { + let mut cfg = SessionConfig::new(); + cfg = cfg.set_bool( + "datafusion.optimizer.enable_dynamic_filter_pushdown", + enable_filter_pushdown, + ); + test_cases.push(TestCase { + query: query.to_string(), + cfg, + dataset: dataset.clone(), + }); + } + } + } + + let start = datafusion_common::instant::Instant::now(); + let mut join_set = JoinSet::new(); + for tc in test_cases { + join_set.spawn(run_query(tc.query, tc.cfg, tc.dataset)); + } + let mut results = join_set.join_all().await; + results.sort_unstable_by(|a, b| a.query.cmp(&b.query)); + println!("Ran {} test cases in {:?}", results.len(), start.elapsed()); + + let failures = results + .iter() + .filter(|result| !result.is_ok()) + .collect::>(); + + for failure in &failures { + println!("Failure:"); + println!("Query:\n{}", failure.query); + println!("\nExpected:\n{}", failure.expected_formatted()); + println!("\nResult:\n{}", failure.result_formatted()); + println!("\n\n"); + } + + if !failures.is_empty() { + panic!("Some test cases failed"); + } else { + println!("All test cases passed"); + } +} diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 6b166dd32782f..65a41d39d3c54 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -35,7 +35,7 @@ use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::HashMap; use datafusion_common::{Result, ScalarValue}; use datafusion_common_runtime::SpawnedTask; -use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf; +use datafusion_expr::type_coercion::functions::fields_with_aggregate_udf; use datafusion_expr::{ WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; @@ -51,7 +51,7 @@ use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use test_utils::add_empty_batches; @@ -252,7 +252,6 @@ async fn bounded_window_causal_non_causal() -> Result<()> { ]; let partitionby_exprs = vec![]; - let orderby_exprs = LexOrdering::default(); // Window frame starts with "UNBOUNDED PRECEDING": let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None)); @@ -285,10 +284,12 @@ async fn bounded_window_causal_non_causal() -> Result<()> { fn_name.to_string(), &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &[], Arc::new(window_frame), - &extended_schema, + extended_schema, false, + false, + None, )?; let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( vec![window_expr], @@ -398,8 +399,8 @@ fn get_random_function( WindowFunctionDefinition::WindowUDF(lead_udwf()), vec![ arg.clone(), - lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), - lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..10)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..1000)))), ], ), ); @@ -409,8 +410,8 @@ fn get_random_function( WindowFunctionDefinition::WindowUDF(lag_udwf()), vec![ arg.clone(), - lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), - lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..10)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..1000)))), ], ), ); @@ -435,12 +436,12 @@ fn get_random_function( WindowFunctionDefinition::WindowUDF(nth_value_udwf()), vec![ arg.clone(), - lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..10)))), ], ), ); - let rand_fn_idx = rng.gen_range(0..window_fn_map.len()); + let rand_fn_idx = rng.random_range(0..window_fn_map.len()); let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, args) = window_fn_map.values().collect::>()[rand_fn_idx]; let mut args = args.clone(); @@ -448,13 +449,13 @@ fn get_random_function( if !args.is_empty() { // Do type coercion first argument let a = args[0].clone(); - let dt = a.data_type(schema.as_ref()).unwrap(); - let coerced = data_types_with_aggregate_udf(&[dt], udf).unwrap(); - args[0] = cast(a, schema, coerced[0].clone()).unwrap(); + let dt = a.return_field(schema.as_ref()).unwrap(); + let coerced = fields_with_aggregate_udf(&[dt], udf).unwrap(); + args[0] = cast(a, schema, coerced[0].data_type().clone()).unwrap(); } } - (window_fn.clone(), args, fn_name.to_string()) + (window_fn.clone(), args, (*fn_name).to_string()) } fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { @@ -463,12 +464,12 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { is_preceding: bool, } let first_bound = Utils { - val: rng.gen_range(0..10), - is_preceding: rng.gen_range(0..2) == 0, + val: rng.random_range(0..10), + is_preceding: rng.random_range(0..2) == 0, }; let second_bound = Utils { - val: rng.gen_range(0..10), - is_preceding: rng.gen_range(0..2) == 0, + val: rng.random_range(0..10), + is_preceding: rng.random_range(0..2) == 0, }; let (start_bound, end_bound) = if first_bound.is_preceding == second_bound.is_preceding { @@ -485,7 +486,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { (second_bound, first_bound) }; // 0 means Range, 1 means Rows, 2 means GROUPS - let rand_num = rng.gen_range(0..3); + let rand_num = rng.random_range(0..3); let units = if rand_num < 1 { WindowFrameUnits::Range } else if rand_num < 2 { @@ -517,7 +518,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { }; let mut window_frame = WindowFrame::new_bounds(units, start_bound, end_bound); // with 10% use unbounded preceding in tests - if rng.gen_range(0..10) == 0 { + if rng.random_range(0..10) == 0 { window_frame.start_bound = WindowFrameBound::Preceding(ScalarValue::Int32(None)); } @@ -545,7 +546,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { }; let mut window_frame = WindowFrame::new_bounds(units, start_bound, end_bound); // with 10% use unbounded preceding in tests - if rng.gen_range(0..10) == 0 { + if rng.random_range(0..10) == 0 { window_frame.start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None)); } @@ -569,7 +570,7 @@ fn convert_bound_to_current_row_if_applicable( match bound { WindowFrameBound::Preceding(value) | WindowFrameBound::Following(value) => { if let Ok(zero) = ScalarValue::new_zero(&value.data_type()) { - if value == &zero && rng.gen_range(0..2) == 0 { + if value == &zero && rng.random_range(0..2) == 0 { *bound = WindowFrameBound::CurrentRow; } } @@ -594,7 +595,7 @@ async fn run_window_test( let ctx = SessionContext::new_with_config(session_config); let (window_fn, args, fn_name) = get_random_function(&schema, &mut rng, is_linear); let window_frame = get_random_window_frame(&mut rng, is_linear); - let mut orderby_exprs = LexOrdering::default(); + let mut orderby_exprs = vec![]; for column in &orderby_columns { orderby_exprs.push(PhysicalSortExpr { expr: col(column, &schema)?, @@ -602,13 +603,13 @@ async fn run_window_test( }) } if orderby_exprs.len() > 1 && !window_frame.can_accept_multi_orderby() { - orderby_exprs = LexOrdering::new(orderby_exprs[0..1].to_vec()); + orderby_exprs.truncate(1); } let mut partitionby_exprs = vec![]; for column in &partition_by_columns { partitionby_exprs.push(col(column, &schema)?); } - let mut sort_keys = LexOrdering::default(); + let mut sort_keys = vec![]; for partition_by_expr in &partitionby_exprs { sort_keys.push(PhysicalSortExpr { expr: partition_by_expr.clone(), @@ -622,7 +623,7 @@ async fn run_window_test( } let concat_input_record = concat_batches(&schema, &input1)?; - let source_sort_keys = LexOrdering::new(vec![ + let source_sort_keys: LexOrdering = [ PhysicalSortExpr { expr: col("a", &schema)?, options: Default::default(), @@ -635,7 +636,8 @@ async fn run_window_test( expr: col("c", &schema)?, options: Default::default(), }, - ]); + ] + .into(); let mut exec1 = DataSourceExec::from_data_source( MemorySourceConfig::try_new(&[vec![concat_input_record]], schema.clone(), None)? .try_with_sort_information(vec![source_sort_keys.clone()])?, @@ -643,7 +645,9 @@ async fn run_window_test( // Table is ordered according to ORDER BY a, b, c In linear test we use PARTITION BY b, ORDER BY a // For WindowAggExec to produce correct result it need table to be ordered by b,a. Hence add a sort. if is_linear { - exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _; + if let Some(ordering) = LexOrdering::new(sort_keys) { + exec1 = Arc::new(SortExec::new(ordering, exec1)) as _; + } } let extended_schema = schema_add_window_field(&args, &schema, &window_fn, &fn_name)?; @@ -654,17 +658,19 @@ async fn run_window_test( fn_name.clone(), &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &orderby_exprs.clone(), Arc::new(window_frame.clone()), - &extended_schema, + Arc::clone(&extended_schema), false, + false, + None, )?], exec1, false, )?) as _; let exec2 = DataSourceExec::from_data_source( - MemorySourceConfig::try_new(&[input1.clone()], schema.clone(), None)? - .try_with_sort_information(vec![source_sort_keys.clone()])?, + MemorySourceConfig::try_new(&[input1], schema, None)? + .try_with_sort_information(vec![source_sort_keys])?, ); let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( vec![create_window_expr( @@ -672,10 +678,12 @@ async fn run_window_test( fn_name, &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &orderby_exprs, Arc::new(window_frame.clone()), - &extended_schema, + extended_schema, + false, false, + None, )?], exec2, search_mode.clone(), @@ -728,7 +736,7 @@ async fn run_window_test( for (line1, line2) in usual_formatted_sorted.iter().zip(running_formatted_sorted) { - println!("{:?} --- {:?}", line1, line2); + println!("{line1:?} --- {line2:?}"); } unreachable!(); } @@ -758,9 +766,9 @@ pub(crate) fn make_staggered_batches( let mut input5: Vec = vec!["".to_string(); len]; input123.iter_mut().for_each(|v| { *v = ( - rng.gen_range(0..n_distinct) as i32, - rng.gen_range(0..n_distinct) as i32, - rng.gen_range(0..n_distinct) as i32, + rng.random_range(0..n_distinct) as i32, + rng.random_range(0..n_distinct) as i32, + rng.random_range(0..n_distinct) as i32, ) }); input123.sort(); @@ -788,7 +796,7 @@ pub(crate) fn make_staggered_batches( let mut batches = vec![]; if STREAM { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..50); + let batch_size = rng.random_range(0..50); if remainder.num_rows() < batch_size { batches.push(remainder); break; @@ -798,7 +806,7 @@ pub(crate) fn make_staggered_batches( } } else { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + let batch_size = rng.random_range(0..remainder.num_rows() + 1); batches.push(remainder.slice(0, batch_size)); remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); } diff --git a/datafusion/core/tests/macro_hygiene/mod.rs b/datafusion/core/tests/macro_hygiene/mod.rs index 9196efec972c1..c9f33f6fdf0f4 100644 --- a/datafusion/core/tests/macro_hygiene/mod.rs +++ b/datafusion/core/tests/macro_hygiene/mod.rs @@ -65,3 +65,42 @@ mod config_namespace { } } } + +mod config_field { + // NO other imports! + use datafusion_common::config_field; + + #[test] + fn test_macro() { + #[derive(Debug)] + #[allow(dead_code)] + struct E; + + impl std::fmt::Display for E { + fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + unimplemented!() + } + } + + impl std::error::Error for E {} + + #[allow(dead_code)] + struct S; + + impl std::str::FromStr for S { + type Err = E; + + fn from_str(_s: &str) -> Result { + unimplemented!() + } + } + + impl std::fmt::Display for S { + fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + unimplemented!() + } + } + + config_field!(S); + } +} diff --git a/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs b/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs index 64ab1378340aa..e1d5f1b1ab198 100644 --- a/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs +++ b/datafusion/core/tests/memory_limit/memory_limit_validation/sort_mem_validation.rs @@ -31,7 +31,7 @@ static INIT: Once = Once::new(); // =========================================================================== // Test runners: -// Runners are splitted into multiple tests to run in parallel +// Runners are split into multiple tests to run in parallel // =========================================================================== #[test] @@ -98,11 +98,9 @@ fn init_once() { fn spawn_test_process(test: &str) { init_once(); - let test_path = format!( - "memory_limit::memory_limit_validation::sort_mem_validation::{}", - test - ); - info!("Running test: {}", test_path); + let test_path = + format!("memory_limit::memory_limit_validation::sort_mem_validation::{test}"); + info!("Running test: {test_path}"); // Run the test command let output = Command::new("cargo") @@ -125,7 +123,7 @@ fn spawn_test_process(test: &str) { let stdout = str::from_utf8(&output.stdout).unwrap_or(""); let stderr = str::from_utf8(&output.stderr).unwrap_or(""); - info!("{}", stdout); + info!("{stdout}"); assert!( output.status.success(), diff --git a/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs b/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs index bdf30c140afff..7b157b707a6de 100644 --- a/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs +++ b/datafusion/core/tests/memory_limit/memory_limit_validation/utils.rs @@ -18,7 +18,7 @@ use datafusion_common_runtime::SpawnedTask; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -use sysinfo::System; +use sysinfo::{ProcessRefreshKind, ProcessesToUpdate, System}; use tokio::time::{interval, Duration}; use datafusion::prelude::{SessionConfig, SessionContext}; @@ -62,7 +62,11 @@ where loop { interval.tick().await; - sys.refresh_all(); + sys.refresh_processes_specifics( + ProcessesToUpdate::Some(&[pid]), + true, + ProcessRefreshKind::nothing().with_memory(), + ); if let Some(process) = sys.process(pid) { let rss_bytes = process.memory(); max_rss_clone @@ -116,8 +120,8 @@ where /// # Example /// /// utils::validate_query_with_memory_limits( -/// 40_000_000 * 2, -/// Some(40_000_000), +/// 40_000_000 * 2, +/// Some(40_000_000), /// "SELECT * FROM generate_series(1, 100000000) AS t(i) ORDER BY i", /// "SELECT * FROM generate_series(1, 10000000) AS t(i) ORDER BY i" /// ); diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 2702954e77830..89bc48b1e6348 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -28,10 +28,10 @@ use arrow::compute::SortOptions; use arrow::datatypes::{Int32Type, SchemaRef}; use arrow_schema::{DataType, Field, Schema}; use datafusion::assert_batches_eq; +use datafusion::config::SpillCompression; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; use datafusion::datasource::{MemTable, TableProvider}; -use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; @@ -41,14 +41,17 @@ use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_catalog::streaming::StreamingTable; use datafusion_catalog::Session; use datafusion_common::{assert_contains, Result}; +use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use datafusion_execution::memory_pool::{ FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, }; +use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::{Expr, TableType}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_optimizer::join_selection::JoinSelection; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::collect as collect_batches; use datafusion_physical_plan::common::collect; use datafusion_physical_plan::spill::get_record_batch_memory_size; use rand::Rng; @@ -82,7 +85,8 @@ async fn group_by_none() { TestCase::new() .with_query("select median(request_bytes) from t") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: AggregateStream" + "Resources exhausted: Additional allocation failed", + "with top memory consumers (across reservations) as:\n AggregateStream", ]) .with_memory_limit(2_000) .run() @@ -94,7 +98,7 @@ async fn group_by_row_hash() { TestCase::new() .with_query("select count(*) from t GROUP BY response_bytes") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: GroupedHashAggregateStream" + "Resources exhausted: Additional allocation failed", "with top memory consumers (across reservations) as:\n GroupedHashAggregateStream" ]) .with_memory_limit(2_000) .run() @@ -107,7 +111,7 @@ async fn group_by_hash() { // group by dict column .with_query("select count(*) from t GROUP BY service, host, pod, container") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: GroupedHashAggregateStream" + "Resources exhausted: Additional allocation failed", "with top memory consumers (across reservations) as:\n GroupedHashAggregateStream" ]) .with_memory_limit(1_000) .run() @@ -120,7 +124,8 @@ async fn join_by_key_multiple_partitions() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service = t2.service") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput", + "Resources exhausted: Additional allocation failed", + "with top memory consumers (across reservations) as:\n HashJoinInput", ]) .with_memory_limit(1_000) .with_config(config) @@ -134,7 +139,8 @@ async fn join_by_key_single_partition() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service = t2.service") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput", + "Resources exhausted: Additional allocation failed", + "with top memory consumers (across reservations) as:\n HashJoinInput", ]) .with_memory_limit(1_000) .with_config(config) @@ -147,7 +153,7 @@ async fn join_by_expression() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service != t2.service") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]", + "Resources exhausted: Additional allocation failed", "with top memory consumers (across reservations) as:\n NestedLoopJoinLoad[0]", ]) .with_memory_limit(1_000) .run() @@ -159,7 +165,8 @@ async fn cross_join() { TestCase::new() .with_query("select t1.*, t2.* from t t1 CROSS JOIN t t2") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: CrossJoinExec", + "Resources exhausted: Additional allocation failed", + "with top memory consumers (across reservations) as:\n CrossJoinExec", ]) .with_memory_limit(1_000) .run() @@ -202,7 +209,7 @@ async fn sort_merge_join_spill() { ) .with_memory_limit(1_000) .with_config(config) - .with_disk_manager_config(DiskManagerConfig::NewOs) + .with_disk_manager_builder(DiskManagerBuilder::default()) .with_scenario(Scenario::AccessLogStreaming) .run() .await @@ -215,7 +222,7 @@ async fn symmetric_hash_join() { "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", ) .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: SymmetricHashJoinStream", + "Resources exhausted: Additional allocation failed", "with top memory consumers (across reservations) as:\n SymmetricHashJoinStream", ]) .with_memory_limit(1_000) .with_scenario(Scenario::AccessLogStreaming) @@ -233,7 +240,7 @@ async fn sort_preserving_merge() { // so only a merge is needed .with_query("select * from t ORDER BY a ASC NULLS LAST, b ASC NULLS LAST LIMIT 10") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: SortPreservingMergeExec", + "Resources exhausted: Additional allocation failed", "with top memory consumers (across reservations) as:\n SortPreservingMergeExec", ]) // provide insufficient memory to merge .with_memory_limit(partition_size / 2) @@ -286,7 +293,7 @@ async fn sort_spill_reservation() { .with_memory_limit(mem_limit) // use a single partition so only a sort is needed .with_scenario(scenario) - .with_disk_manager_config(DiskManagerConfig::NewOs) + .with_disk_manager_builder(DiskManagerBuilder::default()) .with_expected_plan( // It is important that this plan only has a SortExec, not // also merge, so we can ensure the sort could finish @@ -312,8 +319,9 @@ async fn sort_spill_reservation() { test.clone() .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:", - "bytes for ExternalSorterMerge", + "Resources exhausted: Additional allocation failed", + "with top memory consumers (across reservations) as:", + "B for ExternalSorterMerge", ]) .with_config(config) .run() @@ -342,7 +350,8 @@ async fn oom_recursive_cte() { SELECT * FROM nodes;", ) .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: RecursiveQuery", + "Resources exhausted: Additional allocation failed", + "with top memory consumers (across reservations) as:\n RecursiveQuery", ]) .with_memory_limit(2_000) .run() @@ -352,7 +361,7 @@ async fn oom_recursive_cte() { #[tokio::test] async fn oom_parquet_sink() { let dir = tempfile::tempdir().unwrap(); - let path = dir.into_path().join("test.parquet"); + let path = dir.path().join("test.parquet"); let _ = File::create(path.clone()).await.unwrap(); TestCase::new() @@ -376,7 +385,7 @@ async fn oom_parquet_sink() { #[tokio::test] async fn oom_with_tracked_consumer_pool() { let dir = tempfile::tempdir().unwrap(); - let path = dir.into_path().join("test.parquet"); + let path = dir.path().join("test.parquet"); let _ = File::create(path.clone()).await.unwrap(); TestCase::new() @@ -394,7 +403,7 @@ async fn oom_with_tracked_consumer_pool() { .with_expected_errors(vec![ "Failed to allocate additional", "for ParquetSink(ArrowColumnWriter)", - "Additional allocation failed with top memory consumers (across reservations) as: ParquetSink(ArrowColumnWriter)" + "Additional allocation failed", "with top memory consumers (across reservations) as:\n ParquetSink(ArrowColumnWriter)" ]) .with_memory_pool(Arc::new( TrackConsumersPool::new( @@ -406,6 +415,19 @@ async fn oom_with_tracked_consumer_pool() { .await } +#[tokio::test] +async fn oom_grouped_hash_aggregate() { + TestCase::new() + .with_query("SELECT COUNT(*), SUM(request_bytes) FROM t GROUP BY host") + .with_expected_errors(vec![ + "Failed to allocate additional", + "GroupedHashAggregateStream[0] (count(1), sum(t.request_bytes))", + ]) + .with_memory_limit(1_000) + .run() + .await +} + /// For regression case: if spilled `StringViewArray`'s buffer will be referenced by /// other batches which are also need to be spilled, then the spill writer will /// repeatedly write out the same buffer, and after reading back, each batch's size @@ -415,7 +437,7 @@ async fn oom_with_tracked_consumer_pool() { /// If there is memory explosion for spilled record batch, this test will fail. #[tokio::test] async fn test_stringview_external_sort() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let array_length = 1000; let num_batches = 200; // Batches contain two columns: random 100-byte string, and random i32 @@ -425,7 +447,7 @@ async fn test_stringview_external_sort() { let strings: Vec = (0..array_length) .map(|_| { (0..100) - .map(|_| rng.gen_range(0..=u8::MAX) as char) + .map(|_| rng.random_range(0..=u8::MAX) as char) .collect() }) .collect(); @@ -433,8 +455,9 @@ async fn test_stringview_external_sort() { let string_array = StringViewArray::from(strings); let array_ref: ArrayRef = Arc::new(string_array); - let random_numbers: Vec = - (0..array_length).map(|_| rng.gen_range(0..=1000)).collect(); + let random_numbers: Vec = (0..array_length) + .map(|_| rng.random_range(0..=1000)) + .collect(); let int_array = Int32Array::from(random_numbers); let int_array_ref: ArrayRef = Arc::new(int_array); @@ -456,7 +479,9 @@ async fn test_stringview_external_sort() { .with_memory_pool(Arc::new(FairSpillPool::new(60 * 1024 * 1024))); let runtime = builder.build_arc().unwrap(); - let config = SessionConfig::new().with_sort_spill_reservation_bytes(40 * 1024 * 1024); + let config = SessionConfig::new() + .with_sort_spill_reservation_bytes(40 * 1024 * 1024) + .with_repartition_file_scans(false); let ctx = SessionContext::new_with_config_rt(config, runtime); ctx.register_table("t", Arc::new(table)).unwrap(); @@ -524,6 +549,173 @@ async fn test_external_sort_zero_merge_reservation() { assert!(spill_count > 0); } +// Tests for disk limit (`max_temp_directory_size` in `DiskManager`) +// ------------------------------------------------------------------ + +// Create a new `SessionContext` with specified disk limit, memory pool limit, and spill compression codec +async fn setup_context( + disk_limit: u64, + memory_pool_limit: usize, + spill_compression: SpillCompression, +) -> Result { + let disk_manager = DiskManagerBuilder::default() + .with_mode(DiskManagerMode::OsTmpDirectory) + .with_max_temp_directory_size(disk_limit) + .build()?; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_pool(Arc::new(FairSpillPool::new(memory_pool_limit))) + .build_arc() + .unwrap(); + + let runtime = Arc::new(RuntimeEnv { + memory_pool: runtime.memory_pool.clone(), + disk_manager: Arc::new(disk_manager), + cache_manager: runtime.cache_manager.clone(), + object_store_registry: runtime.object_store_registry.clone(), + #[cfg(feature = "parquet_encryption")] + parquet_encryption_factory_registry: runtime + .parquet_encryption_factory_registry + .clone(), + }); + + let config = SessionConfig::new() + .with_sort_spill_reservation_bytes(64 * 1024) // 256KB + .with_sort_in_place_threshold_bytes(0) + .with_spill_compression(spill_compression) + .with_batch_size(64) // To reduce test memory usage + .with_target_partitions(1); + + Ok(SessionContext::new_with_config_rt(config, runtime)) +} + +/// If the spilled bytes exceed the disk limit, the query should fail +/// (specified by `max_temp_directory_size` in `DiskManager`) +#[tokio::test] +async fn test_disk_spill_limit_reached() -> Result<()> { + let spill_compression = SpillCompression::Uncompressed; + let ctx = setup_context(1024 * 1024, 1024 * 1024, spill_compression).await?; // 1MB disk limit, 1MB memory limit + + let df = ctx + .sql("select * from generate_series(1, 1000000000000) as t1(v1) order by v1") + .await + .unwrap(); + + let err = df.collect().await.unwrap_err(); + assert_contains!( + err.to_string(), + "The used disk space during the spilling process has exceeded the allowable limit" + ); + + Ok(()) +} + +/// External query should succeed, if the spilled bytes is less than the disk limit +/// Also verify that after the query is finished, all the disk usage accounted by +/// tempfiles are cleaned up. +#[tokio::test] +async fn test_disk_spill_limit_not_reached() -> Result<()> { + let disk_spill_limit = 1024 * 1024; // 1MB + let spill_compression = SpillCompression::Uncompressed; + let ctx = setup_context(disk_spill_limit, 128 * 1024, spill_compression).await?; // 1MB disk limit, 128KB memory limit + + let df = ctx + .sql("select * from generate_series(1, 10000) as t1(v1) order by v1") + .await + .unwrap(); + let plan = df.create_physical_plan().await.unwrap(); + + let task_ctx = ctx.task_ctx(); + let _ = collect_batches(Arc::clone(&plan), task_ctx) + .await + .expect("Query execution failed"); + + let spill_count = plan.metrics().unwrap().spill_count().unwrap(); + let spilled_bytes = plan.metrics().unwrap().spilled_bytes().unwrap(); + + println!("spill count {spill_count}, spill bytes {spilled_bytes}"); + assert!(spill_count > 0); + assert!((spilled_bytes as u64) < disk_spill_limit); + + // Verify that all temporary files have been properly cleaned up by checking + // that the total disk usage tracked by the disk manager is zero + let current_disk_usage = ctx.runtime_env().disk_manager.used_disk_space(); + assert_eq!(current_disk_usage, 0); + + Ok(()) +} + +/// External query should succeed using zstd as spill compression codec and +/// and all temporary spill files are properly cleaned up after execution. +/// Note: This test does not inspect file contents (e.g. magic number), +/// as spill files are automatically deleted on drop. +#[tokio::test] +async fn test_spill_file_compressed_with_zstd() -> Result<()> { + let disk_spill_limit = 1024 * 1024; // 1MB + let spill_compression = SpillCompression::Zstd; + let ctx = setup_context(disk_spill_limit, 128 * 1024, spill_compression).await?; // 1MB disk limit, 128KB memory limit, zstd + + let df = ctx + .sql("select * from generate_series(1, 100000) as t1(v1) order by v1") + .await + .unwrap(); + let plan = df.create_physical_plan().await.unwrap(); + + let task_ctx = ctx.task_ctx(); + let _ = collect_batches(Arc::clone(&plan), task_ctx) + .await + .expect("Query execution failed"); + + let spill_count = plan.metrics().unwrap().spill_count().unwrap(); + let spilled_bytes = plan.metrics().unwrap().spilled_bytes().unwrap(); + + println!("spill count {spill_count}"); + assert!(spill_count > 0); + assert!((spilled_bytes as u64) < disk_spill_limit); + + // Verify that all temporary files have been properly cleaned up by checking + // that the total disk usage tracked by the disk manager is zero + let current_disk_usage = ctx.runtime_env().disk_manager.used_disk_space(); + assert_eq!(current_disk_usage, 0); + + Ok(()) +} + +/// External query should succeed using lz4_frame as spill compression codec and +/// and all temporary spill files are properly cleaned up after execution. +/// Note: This test does not inspect file contents (e.g. magic number), +/// as spill files are automatically deleted on drop. +#[tokio::test] +async fn test_spill_file_compressed_with_lz4_frame() -> Result<()> { + let disk_spill_limit = 1024 * 1024; // 1MB + let spill_compression = SpillCompression::Lz4Frame; + let ctx = setup_context(disk_spill_limit, 128 * 1024, spill_compression).await?; // 1MB disk limit, 128KB memory limit, lz4_frame + + let df = ctx + .sql("select * from generate_series(1, 100000) as t1(v1) order by v1") + .await + .unwrap(); + let plan = df.create_physical_plan().await.unwrap(); + + let task_ctx = ctx.task_ctx(); + let _ = collect_batches(Arc::clone(&plan), task_ctx) + .await + .expect("Query execution failed"); + + let spill_count = plan.metrics().unwrap().spill_count().unwrap(); + let spilled_bytes = plan.metrics().unwrap().spilled_bytes().unwrap(); + + println!("spill count {spill_count}"); + assert!(spill_count > 0); + assert!((spilled_bytes as u64) < disk_spill_limit); + + // Verify that all temporary files have been properly cleaned up by checking + // that the total disk usage tracked by the disk manager is zero + let current_disk_usage = ctx.runtime_env().disk_manager.used_disk_space(); + assert_eq!(current_disk_usage, 0); + + Ok(()) +} /// Run the query with the specified memory limit, /// and verifies the expected errors are returned #[derive(Clone, Debug)] @@ -536,7 +728,7 @@ struct TestCase { scenario: Scenario, /// How should the disk manager (that allows spilling) be /// configured? Defaults to `Disabled` - disk_manager_config: DiskManagerConfig, + disk_manager_builder: DiskManagerBuilder, /// Expected explain plan, if non-empty expected_plan: Vec, /// Is the plan expected to pass? Defaults to false @@ -552,7 +744,8 @@ impl TestCase { config: SessionConfig::new(), memory_pool: None, scenario: Scenario::AccessLog, - disk_manager_config: DiskManagerConfig::Disabled, + disk_manager_builder: DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Disabled), expected_plan: vec![], expected_success: false, } @@ -609,17 +802,17 @@ impl TestCase { /// Specify if the disk manager should be enabled. If true, /// operators that support it can spill - pub fn with_disk_manager_config( + pub fn with_disk_manager_builder( mut self, - disk_manager_config: DiskManagerConfig, + disk_manager_builder: DiskManagerBuilder, ) -> Self { - self.disk_manager_config = disk_manager_config; + self.disk_manager_builder = disk_manager_builder; self } /// Specify an expected plan to review pub fn with_expected_plan(mut self, expected_plan: &[&str]) -> Self { - self.expected_plan = expected_plan.iter().map(|s| s.to_string()).collect(); + self.expected_plan = expected_plan.iter().map(|s| (*s).to_string()).collect(); self } @@ -632,7 +825,7 @@ impl TestCase { memory_pool, config, scenario, - disk_manager_config, + disk_manager_builder, expected_plan, expected_success, } = self; @@ -641,7 +834,7 @@ impl TestCase { let mut builder = RuntimeEnvBuilder::new() // disk manager setting controls the spilling - .with_disk_manager(disk_manager_config) + .with_disk_manager_builder(disk_manager_builder) .with_memory_limit(memory_limit, MEMORY_FRACTION); if let Some(pool) = memory_pool { @@ -772,11 +965,10 @@ impl Scenario { single_row_batches, } => { use datafusion::physical_expr::expressions::col; - let batches: Vec> = std::iter::repeat(maybe_split_batches( - dict_batches(), - *single_row_batches, - )) - .take(*partitions) + let batches: Vec> = std::iter::repeat_n( + maybe_split_batches(dict_batches(), *single_row_batches), + *partitions, + ) .collect(); let schema = batches[0][0].schema(); @@ -784,16 +976,11 @@ impl Scenario { descending: false, nulls_first: false, }; - let sort_information = vec![LexOrdering::new(vec![ - PhysicalSortExpr { - expr: col("a", &schema).unwrap(), - options, - }, - PhysicalSortExpr { - expr: col("b", &schema).unwrap(), - options, - }, - ])]; + let sort_information = vec![[ + PhysicalSortExpr::new(col("a", &schema).unwrap(), options), + PhysicalSortExpr::new(col("b", &schema).unwrap(), options), + ] + .into()]; let table = SortedTableProvider::new(batches, sort_information); Arc::new(table) @@ -907,9 +1094,9 @@ fn batches_byte_size(batches: &[RecordBatch]) -> usize { } #[derive(Debug)] -struct DummyStreamPartition { - schema: SchemaRef, - batches: Vec, +pub(crate) struct DummyStreamPartition { + pub(crate) schema: SchemaRef, + pub(crate) batches: Vec, } impl PartitionStream for DummyStreamPartition { diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 585540bd58754..9899a0158fb8a 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -18,6 +18,7 @@ //! Tests for the DataFusion SQL query planner that require functions from the //! datafusion-functions crate. +use insta::assert_snapshot; use std::any::Any; use std::collections::HashMap; use std::sync::Arc; @@ -56,9 +57,14 @@ fn init() { #[test] fn select_arrow_cast() { let sql = "SELECT arrow_cast(1234, 'Float64') as f64, arrow_cast('foo', 'LargeUtf8') as large"; - let expected = "Projection: Float64(1234) AS f64, LargeUtf8(\"foo\") AS large\ - \n EmptyRelation"; - quick_test(sql, expected); + let plan = test_sql(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: Float64(1234) AS f64, LargeUtf8("foo") AS large + EmptyRelation: rows=1 + "# + ); } #[test] fn timestamp_nano_ts_none_predicates() -> Result<()> { @@ -68,11 +74,15 @@ fn timestamp_nano_ts_none_predicates() -> Result<()> { // a scan should have the now()... predicate folded to a single // constant and compared to the column without a cast so it can be // pushed down / pruned - let expected = - "Projection: test.col_int32\ - \n Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None)\ - \n TableScan: test projection=[col_int32, col_ts_nano_none]"; - quick_test(sql, expected); + let plan = test_sql(sql).unwrap(); + assert_snapshot!( + plan, + @r" + Projection: test.col_int32 + Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None) + TableScan: test projection=[col_int32, col_ts_nano_none] + " + ); Ok(()) } @@ -84,10 +94,15 @@ fn timestamp_nano_ts_utc_predicates() { // a scan should have the now()... predicate folded to a single // constant and compared to the column without a cast so it can be // pushed down / pruned - let expected = - "Projection: test.col_int32\n Filter: test.col_ts_nano_utc < TimestampNanosecond(1666612093000000000, Some(\"+00:00\"))\ - \n TableScan: test projection=[col_int32, col_ts_nano_utc]"; - quick_test(sql, expected); + let plan = test_sql(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: test.col_int32 + Filter: test.col_ts_nano_utc < TimestampNanosecond(1666612093000000000, Some("+00:00")) + TableScan: test projection=[col_int32, col_ts_nano_utc] + "# + ); } #[test] @@ -95,10 +110,14 @@ fn concat_literals() -> Result<()> { let sql = "SELECT concat(true, col_int32, false, null, 'hello', col_utf8, 12, 3.4) \ AS col FROM test"; - let expected = - "Projection: concat(Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"falsehello\"), test.col_utf8, Utf8(\"123.4\")) AS col\ - \n TableScan: test projection=[col_int32, col_utf8]"; - quick_test(sql, expected); + let plan = test_sql(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: concat(Utf8("true"), CAST(test.col_int32 AS Utf8), Utf8("falsehello"), test.col_utf8, Utf8("123.4")) AS col + TableScan: test projection=[col_int32, col_utf8] + "# + ); Ok(()) } @@ -107,16 +126,15 @@ fn concat_ws_literals() -> Result<()> { let sql = "SELECT concat_ws('-', true, col_int32, false, null, 'hello', col_utf8, 12, '', 3.4) \ AS col FROM test"; - let expected = - "Projection: concat_ws(Utf8(\"-\"), Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"false-hello\"), test.col_utf8, Utf8(\"12--3.4\")) AS col\ - \n TableScan: test projection=[col_int32, col_utf8]"; - quick_test(sql, expected); - Ok(()) -} - -fn quick_test(sql: &str, expected_plan: &str) { let plan = test_sql(sql).unwrap(); - assert_eq!(expected_plan, format!("{}", plan)); + assert_snapshot!( + plan, + @r#" + Projection: concat_ws(Utf8("-"), Utf8("true"), CAST(test.col_int32 AS Utf8), Utf8("false-hello"), test.col_utf8, Utf8("12--3.4")) AS col + TableScan: test projection=[col_int32, col_utf8] + "# + ); + Ok(()) } fn test_sql(sql: &str) -> Result { @@ -142,7 +160,7 @@ fn test_sql(sql: &str) -> Result { let analyzer = Analyzer::new(); let optimizer = Optimizer::new(); // analyze and optimize the logical plan - let plan = analyzer.execute_and_check(plan, config.options(), |_, _| {})?; + let plan = analyzer.execute_and_check(plan, &config.options(), |_, _| {})?; optimizer.optimize(plan, &config, |_, _| {}) } @@ -342,8 +360,7 @@ where let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, - "{} simplified to {}, but expected {}", - expr, output, expected + "{expr} simplified to {output}, but expected {expected}" ); } } @@ -352,8 +369,7 @@ fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { let output = expr.clone().rewrite(rewriter).data().unwrap(); assert_eq!( &output, expr, - "{} was simplified to {}, but expected it to be unchanged", - expr, output + "{expr} was simplified to {output}, but expected it to be unchanged" ); } } diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index ce5c0d720174d..3a1f06656236c 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -23,11 +23,10 @@ use std::time::SystemTime; use arrow::array::{ArrayRef, Int64Array, Int8Array, StringArray}; use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; -use datafusion::datasource::file_format::parquet::fetch_parquet_metadata; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{ - FileMeta, ParquetFileMetrics, ParquetFileReaderFactory, ParquetSource, + ParquetFileMetrics, ParquetFileReaderFactory, ParquetSource, }; use datafusion::physical_plan::collect; use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; @@ -38,12 +37,14 @@ use datafusion_common::Result; use bytes::Bytes; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; +use datafusion_datasource_parquet::metadata::DFParquetMetadata; use futures::future::BoxFuture; use futures::{FutureExt, TryFutureExt}; use insta::assert_snapshot; use object_store::memory::InMemory; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; +use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::async_reader::AsyncFileReader; use parquet::arrow::ArrowWriter; use parquet::errors::ParquetError; @@ -118,11 +119,11 @@ impl ParquetFileReaderFactory for InMemoryParquetFileReaderFactory { fn create_reader( &self, partition_index: usize, - file_meta: FileMeta, + partitioned_file: PartitionedFile, metadata_size_hint: Option, metrics: &ExecutionPlanMetricsSet, ) -> Result> { - let metadata = file_meta + let metadata = partitioned_file .extensions .as_ref() .expect("has user defined metadata"); @@ -134,13 +135,13 @@ impl ParquetFileReaderFactory for InMemoryParquetFileReaderFactory { let parquet_file_metrics = ParquetFileMetrics::new( partition_index, - file_meta.location().as_ref(), + partitioned_file.object_meta.location.as_ref(), metrics, ); Ok(Box::new(ParquetFileReader { store: Arc::clone(&self.0), - meta: file_meta.object_meta, + meta: partitioned_file.object_meta, metrics: parquet_file_metrics, metadata_size_hint, })) @@ -186,7 +187,7 @@ async fn store_parquet_in_memory( location: Path::parse(format!("file-{offset}.parquet")) .expect("creating path"), last_modified: chrono::DateTime::from(SystemTime::now()), - size: buf.len(), + size: buf.len() as u64, e_tag: None, version: None, }; @@ -218,9 +219,10 @@ struct ParquetFileReader { impl AsyncFileReader for ParquetFileReader { fn get_bytes( &mut self, - range: Range, + range: Range, ) -> BoxFuture<'_, parquet::errors::Result> { - self.metrics.bytes_scanned.add(range.end - range.start); + let bytes_scanned = range.end - range.start; + self.metrics.bytes_scanned.add(bytes_scanned as usize); self.store .get_range(&self.meta.location, range) @@ -232,20 +234,19 @@ impl AsyncFileReader for ParquetFileReader { fn get_metadata( &mut self, + _options: Option<&ArrowReaderOptions>, ) -> BoxFuture<'_, parquet::errors::Result>> { Box::pin(async move { - let metadata = fetch_parquet_metadata( - self.store.as_ref(), - &self.meta, - self.metadata_size_hint, - ) - .await - .map_err(|e| { - ParquetError::General(format!( - "AsyncChunkReader::get_metadata error: {e}" - )) - })?; - Ok(Arc::new(metadata)) + let metadata = DFParquetMetadata::new(self.store.as_ref(), &self.meta) + .with_metadata_size_hint(self.metadata_size_hint) + .fetch_metadata() + .await + .map_err(|e| { + ParquetError::General(format!( + "AsyncChunkReader::get_metadata error: {e}" + )) + })?; + Ok(metadata) }) } } diff --git a/datafusion/core/tests/parquet/encryption.rs b/datafusion/core/tests/parquet/encryption.rs new file mode 100644 index 0000000000000..819d8bf3a283d --- /dev/null +++ b/datafusion/core/tests/parquet/encryption.rs @@ -0,0 +1,369 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for reading and writing Parquet files that use Parquet modular encryption + +use arrow::array::{ArrayRef, Int32Array, StringArray}; +use arrow::record_batch::RecordBatch; +use arrow_schema::{DataType, SchemaRef}; +use async_trait::async_trait; +use datafusion::dataframe::DataFrameWriteOptions; +use datafusion::datasource::listing::ListingOptions; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use datafusion_common::config::{EncryptionFactoryOptions, TableParquetOptions}; +use datafusion_common::{assert_batches_sorted_eq, exec_datafusion_err, DataFusionError}; +use datafusion_datasource_parquet::ParquetFormat; +use datafusion_execution::parquet_encryption::EncryptionFactory; +use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}; +use parquet::arrow::ArrowWriter; +use parquet::encryption::decrypt::FileDecryptionProperties; +use parquet::encryption::encrypt::FileEncryptionProperties; +use parquet::file::column_crypto_metadata::ColumnCryptoMetaData; +use parquet::file::properties::WriterProperties; +use std::collections::HashMap; +use std::fs::File; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::{Arc, Mutex}; +use tempfile::TempDir; + +async fn read_parquet_test_data<'a, T: Into>( + path: T, + ctx: &SessionContext, + options: ParquetReadOptions<'a>, +) -> Vec { + ctx.read_parquet(path.into(), options) + .await + .unwrap() + .collect() + .await + .unwrap() +} + +pub fn write_batches( + path: PathBuf, + props: WriterProperties, + batches: impl IntoIterator, +) -> datafusion_common::Result { + let mut batches = batches.into_iter(); + let first_batch = batches.next().expect("need at least one record batch"); + let schema = first_batch.schema(); + + let file = File::create(&path)?; + let mut writer = ArrowWriter::try_new(file, Arc::clone(&schema), Some(props))?; + + writer.write(&first_batch)?; + let mut num_rows = first_batch.num_rows(); + + for batch in batches { + writer.write(&batch)?; + num_rows += batch.num_rows(); + } + writer.close()?; + Ok(num_rows) +} + +#[tokio::test] +async fn round_trip_encryption() { + let ctx: SessionContext = SessionContext::new(); + + let options = ParquetReadOptions::default(); + let batches = read_parquet_test_data( + "tests/data/filter_pushdown/single_file.gz.parquet", + &ctx, + options, + ) + .await; + + let schema = batches[0].schema(); + let footer_key = b"0123456789012345".to_vec(); // 128bit/16 + let column_key = b"1234567890123450".to_vec(); // 128bit/16 + + let mut encrypt = FileEncryptionProperties::builder(footer_key.clone()); + let mut decrypt = FileDecryptionProperties::builder(footer_key.clone()); + + for field in schema.fields.iter() { + encrypt = encrypt.with_column_key(field.name().as_str(), column_key.clone()); + decrypt = decrypt.with_column_key(field.name().as_str(), column_key.clone()); + } + let encrypt = encrypt.build().unwrap(); + let decrypt = decrypt.build().unwrap(); + + // Write encrypted parquet + let props = WriterProperties::builder() + .with_file_encryption_properties(encrypt) + .build(); + + let tempdir = TempDir::new_in(Path::new(".")).unwrap(); + let tempfile = tempdir.path().join("data.parquet"); + let num_rows_written = write_batches(tempfile.clone(), props, batches).unwrap(); + + // Read encrypted parquet + let ctx: SessionContext = SessionContext::new(); + let options = + ParquetReadOptions::default().file_decryption_properties((&decrypt).into()); + + let encrypted_batches = read_parquet_test_data( + tempfile.into_os_string().into_string().unwrap(), + &ctx, + options, + ) + .await; + + let num_rows_read = encrypted_batches + .iter() + .fold(0, |acc, x| acc + x.num_rows()); + + assert_eq!(num_rows_written, num_rows_read); +} + +#[tokio::test] +async fn round_trip_parquet_with_encryption_factory() { + let ctx = SessionContext::new(); + let encryption_factory = Arc::new(MockEncryptionFactory::default()); + ctx.runtime_env().register_parquet_encryption_factory( + "test_encryption_factory", + Arc::clone(&encryption_factory) as Arc, + ); + + let tmpdir = TempDir::new().unwrap(); + + // Register some simple test data + let strings: ArrayRef = + Arc::new(StringArray::from(vec!["a", "b", "c", "a", "b", "c"])); + let x1: ArrayRef = Arc::new(Int32Array::from(vec![1, 10, 11, 100, 101, 111])); + let x2: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])); + let batch = + RecordBatch::try_from_iter(vec![("string", strings), ("x1", x1), ("x2", x2)]) + .unwrap(); + let test_data_schema = batch.schema(); + ctx.register_batch("test_data", batch).unwrap(); + let df = ctx.table("test_data").await.unwrap(); + + // Write encrypted Parquet, partitioned by string column into separate files + let mut parquet_options = TableParquetOptions::new(); + parquet_options.crypto.factory_id = Some("test_encryption_factory".to_string()); + parquet_options + .crypto + .factory_options + .options + .insert("test_key".to_string(), "test value".to_string()); + + let df_write_options = + DataFrameWriteOptions::default().with_partition_by(vec!["string".to_string()]); + df.write_parquet( + tmpdir.path().to_str().unwrap(), + df_write_options, + Some(parquet_options.clone()), + ) + .await + .unwrap(); + + // Crypto factory should have generated one key per partition file + assert_eq!(encryption_factory.encryption_keys.lock().unwrap().len(), 3); + + verify_table_encrypted(tmpdir.path(), &encryption_factory) + .await + .unwrap(); + + // Registering table without decryption properties should fail + let table_path = format!("file://{}/", tmpdir.path().to_str().unwrap()); + let without_decryption_register = ctx + .register_listing_table( + "parquet_missing_decryption", + &table_path, + ListingOptions::new(Arc::new(ParquetFormat::default())), + None, + None, + ) + .await; + assert!(matches!( + without_decryption_register.unwrap_err(), + DataFusionError::ParquetError(_) + )); + + // Registering table succeeds if schema is provided + ctx.register_listing_table( + "parquet_missing_decryption", + &table_path, + ListingOptions::new(Arc::new(ParquetFormat::default())), + Some(test_data_schema), + None, + ) + .await + .unwrap(); + + // But trying to read from the table should fail + let without_decryption_read = ctx + .table("parquet_missing_decryption") + .await + .unwrap() + .collect() + .await; + assert!(matches!( + without_decryption_read.unwrap_err(), + DataFusionError::ParquetError(_) + )); + + // Register table with encryption factory specified + let listing_options = ListingOptions::new(Arc::new( + ParquetFormat::default().with_options(parquet_options), + )) + .with_table_partition_cols(vec![("string".to_string(), DataType::Utf8)]); + ctx.register_listing_table( + "parquet_with_decryption", + &table_path, + listing_options, + None, + None, + ) + .await + .unwrap(); + + // Can read correct data when encryption factory has been specified + let table = ctx + .table("parquet_with_decryption") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let expected = [ + "+-----+----+--------+", + "| x1 | x2 | string |", + "+-----+----+--------+", + "| 1 | 1 | a |", + "| 100 | 4 | a |", + "| 10 | 2 | b |", + "| 101 | 5 | b |", + "| 11 | 3 | c |", + "| 111 | 6 | c |", + "+-----+----+--------+", + ]; + assert_batches_sorted_eq!(expected, &table); +} + +async fn verify_table_encrypted( + table_path: &Path, + encryption_factory: &Arc, +) -> datafusion_common::Result<()> { + let mut directories = vec![table_path.to_path_buf()]; + let mut files_visited = 0; + while let Some(directory) = directories.pop() { + for entry in std::fs::read_dir(&directory)? { + let path = entry?.path(); + if path.is_dir() { + directories.push(path); + } else { + verify_file_encrypted(&path, encryption_factory).await?; + files_visited += 1; + } + } + } + assert!(files_visited > 0); + Ok(()) +} + +async fn verify_file_encrypted( + file_path: &Path, + encryption_factory: &Arc, +) -> datafusion_common::Result<()> { + let mut options = EncryptionFactoryOptions::default(); + options + .options + .insert("test_key".to_string(), "test value".to_string()); + + let file_path_str = if cfg!(target_os = "windows") { + // Windows backslashes are eventually converted to slashes when writing the Parquet files, + // through `ListingTableUrl::parse`, making `encryption_factory.encryption_keys` store them + // it that format. So we also replace backslashes here to ensure they match. + file_path.to_str().unwrap().replace("\\", "/") + } else { + file_path.to_str().unwrap().to_owned() + }; + + let object_path = object_store::path::Path::from(file_path_str); + let decryption_properties = encryption_factory + .get_file_decryption_properties(&options, &object_path) + .await? + .unwrap(); + + let reader_options = + ArrowReaderOptions::new().with_file_decryption_properties(decryption_properties); + let file = File::open(file_path)?; + let reader_metadata = ArrowReaderMetadata::load(&file, reader_options)?; + let metadata = reader_metadata.metadata(); + assert!(metadata.num_row_groups() > 0); + for row_group in metadata.row_groups() { + assert!(row_group.num_columns() > 0); + for col in row_group.columns() { + assert!(matches!( + col.crypto_metadata(), + Some(ColumnCryptoMetaData::EncryptionWithFooterKey) + )); + } + } + Ok(()) +} + +/// Encryption factory implementation for use in tests, +/// which generates encryption keys in a sequence +#[derive(Debug, Default)] +struct MockEncryptionFactory { + pub encryption_keys: Mutex>>, + pub counter: AtomicU8, +} + +#[async_trait] +impl EncryptionFactory for MockEncryptionFactory { + async fn get_file_encryption_properties( + &self, + config: &EncryptionFactoryOptions, + _schema: &SchemaRef, + file_path: &object_store::path::Path, + ) -> datafusion_common::Result> { + assert_eq!( + config.options.get("test_key"), + Some(&"test value".to_string()) + ); + let file_idx = self.counter.fetch_add(1, Ordering::Relaxed); + let key = vec![file_idx, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + let mut keys = self.encryption_keys.lock().unwrap(); + keys.insert(file_path.clone(), key.clone()); + let encryption_properties = FileEncryptionProperties::builder(key).build()?; + Ok(Some(encryption_properties)) + } + + async fn get_file_decryption_properties( + &self, + config: &EncryptionFactoryOptions, + file_path: &object_store::path::Path, + ) -> datafusion_common::Result> { + assert_eq!( + config.options.get("test_key"), + Some(&"test value".to_string()) + ); + let keys = self.encryption_keys.lock().unwrap(); + let key = keys + .get(file_path) + .ok_or_else(|| exec_datafusion_err!("No key for file {file_path:?}"))?; + let decryption_properties = + FileDecryptionProperties::builder(key.clone()).build()?; + Ok(Some(decryption_properties)) + } +} diff --git a/datafusion/core/tests/parquet/external_access_plan.rs b/datafusion/core/tests/parquet/external_access_plan.rs index bbef073345b73..a5397c5a397ca 100644 --- a/datafusion/core/tests/parquet/external_access_plan.rs +++ b/datafusion/core/tests/parquet/external_access_plan.rs @@ -346,7 +346,7 @@ impl TestFull { let source = if let Some(predicate) = predicate { let df_schema = DFSchema::try_from(schema.clone())?; let predicate = ctx.create_physical_expr(predicate, &df_schema)?; - Arc::new(ParquetSource::default().with_predicate(schema.clone(), predicate)) + Arc::new(ParquetSource::default().with_predicate(predicate)) } else { Arc::new(ParquetSource::default()) }; diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 7e98ebed6c9a7..64ee92eda2545 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -28,6 +28,7 @@ use datafusion::execution::context::SessionState; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::SessionContext; use datafusion_common::stats::Precision; +use datafusion_common::DFSchema; use datafusion_execution::cache::cache_manager::CacheManagerConfig; use datafusion_execution::cache::cache_unit::{ DefaultFileStatisticsCache, DefaultListFilesCache, @@ -37,6 +38,11 @@ use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::{col, lit, Expr}; use datafusion::datasource::physical_plan::FileScanConfig; +use datafusion_common::config::ConfigOptions; +use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::ExecutionPlan; use tempfile::tempdir; #[tokio::test] @@ -45,18 +51,53 @@ async fn check_stats_precision_with_filter_pushdown() { let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); let table_path = ListingTableUrl::parse(filename).unwrap(); - let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); + let opt = + ListingOptions::new(Arc::new(ParquetFormat::default())).with_collect_stat(true); let table = get_listing_table(&table_path, None, &opt).await; + let (_, _, state) = get_cache_runtime_state(); + let mut options: ConfigOptions = state.config().options().as_ref().clone(); + options.execution.parquet.pushdown_filters = true; + // Scan without filter, stats are exact let exec = table.scan(&state, None, &[], None).await.unwrap(); - assert_eq!(exec.statistics().unwrap().num_rows, Precision::Exact(8)); + assert_eq!( + exec.partition_statistics(None).unwrap().num_rows, + Precision::Exact(8), + "Stats without filter should be exact" + ); - // Scan with filter pushdown, stats are inexact - let filter = Expr::gt(col("id"), lit(1)); + // This is a filter that cannot be evaluated by the table provider scanning + // (it is not a partition filter). Therefore; it will be pushed down to the + // source operator after the appropriate optimizer pass. + let filter_expr = Expr::gt(col("id"), lit(1)); + let exec_with_filter = table + .scan(&state, None, std::slice::from_ref(&filter_expr), None) + .await + .unwrap(); + + let ctx = SessionContext::new(); + let df_schema = DFSchema::try_from(table.schema()).unwrap(); + let physical_filter = ctx.create_physical_expr(filter_expr, &df_schema).unwrap(); - let exec = table.scan(&state, None, &[filter], None).await.unwrap(); - assert_eq!(exec.statistics().unwrap().num_rows, Precision::Inexact(8)); + let filtered_exec = + Arc::new(FilterExec::try_new(physical_filter, exec_with_filter).unwrap()) + as Arc; + + let optimized_exec = FilterPushdown::new() + .optimize(filtered_exec, &options) + .unwrap(); + + assert!( + optimized_exec.as_any().is::(), + "Sanity check that the pushdown did what we expected" + ); + // Scan with filter pushdown, stats are inexact + assert_eq!( + optimized_exec.partition_statistics(None).unwrap().num_rows, + Precision::Inexact(8), + "Stats after filter pushdown should be inexact" + ); } #[tokio::test] @@ -70,7 +111,8 @@ async fn load_table_stats_with_session_level_cache() { // Create a separate DefaultFileStatisticsCache let (cache2, _, state2) = get_cache_runtime_state(); - let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); + let opt = + ListingOptions::new(Arc::new(ParquetFormat::default())).with_collect_stat(true); let table1 = get_listing_table(&table_path, Some(cache1), &opt).await; let table2 = get_listing_table(&table_path, Some(cache2), &opt).await; @@ -79,9 +121,12 @@ async fn load_table_stats_with_session_level_cache() { assert_eq!(get_static_cache_size(&state1), 0); let exec1 = table1.scan(&state1, None, &[], None).await.unwrap(); - assert_eq!(exec1.statistics().unwrap().num_rows, Precision::Exact(8)); assert_eq!( - exec1.statistics().unwrap().total_byte_size, + exec1.partition_statistics(None).unwrap().num_rows, + Precision::Exact(8) + ); + assert_eq!( + exec1.partition_statistics(None).unwrap().total_byte_size, // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 Precision::Exact(671), ); @@ -91,9 +136,12 @@ async fn load_table_stats_with_session_level_cache() { //check session 1 cache result not show in session 2 assert_eq!(get_static_cache_size(&state2), 0); let exec2 = table2.scan(&state2, None, &[], None).await.unwrap(); - assert_eq!(exec2.statistics().unwrap().num_rows, Precision::Exact(8)); assert_eq!( - exec2.statistics().unwrap().total_byte_size, + exec2.partition_statistics(None).unwrap().num_rows, + Precision::Exact(8) + ); + assert_eq!( + exec2.partition_statistics(None).unwrap().total_byte_size, // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 Precision::Exact(671), ); @@ -103,9 +151,12 @@ async fn load_table_stats_with_session_level_cache() { //check session 1 cache result not show in session 2 assert_eq!(get_static_cache_size(&state1), 1); let exec3 = table1.scan(&state1, None, &[], None).await.unwrap(); - assert_eq!(exec3.statistics().unwrap().num_rows, Precision::Exact(8)); assert_eq!( - exec3.statistics().unwrap().total_byte_size, + exec3.partition_statistics(None).unwrap().num_rows, + Precision::Exact(8) + ); + assert_eq!( + exec3.partition_statistics(None).unwrap().total_byte_size, // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 Precision::Exact(671), ); @@ -117,23 +168,15 @@ async fn load_table_stats_with_session_level_cache() { async fn list_files_with_session_level_cache() { let p_name = "alltypes_plain.parquet"; let testdata = datafusion::test_util::parquet_test_data(); - let filename = format!("{}/{}", testdata, p_name); + let filename = format!("{testdata}/{p_name}"); - let temp_path1 = tempdir() - .unwrap() - .into_path() - .into_os_string() - .into_string() - .unwrap(); - let temp_filename1 = format!("{}/{}", temp_path1, p_name); + let temp_dir1 = tempdir().unwrap(); + let temp_path1 = temp_dir1.path().to_str().unwrap(); + let temp_filename1 = format!("{temp_path1}/{p_name}"); - let temp_path2 = tempdir() - .unwrap() - .into_path() - .into_os_string() - .into_string() - .unwrap(); - let temp_filename2 = format!("{}/{}", temp_path2, p_name); + let temp_dir2 = tempdir().unwrap(); + let temp_path2 = temp_dir2.path().to_str().unwrap(); + let temp_filename2 = format!("{temp_path2}/{p_name}"); fs::copy(filename.clone(), temp_filename1).expect("panic"); fs::copy(filename, temp_filename2).expect("panic"); diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index 02fb59740493f..b769fec7d3728 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -26,56 +26,52 @@ //! select * from data limit 10; //! ``` -use std::path::Path; - use arrow::compute::concat_batches; use arrow::record_batch::RecordBatch; use datafusion::physical_plan::collect; use datafusion::physical_plan::metrics::MetricsSet; -use datafusion::prelude::{col, lit, lit_timestamp_nano, Expr, SessionContext}; +use datafusion::prelude::{ + col, lit, lit_timestamp_nano, Expr, ParquetReadOptions, SessionContext, +}; use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; -use datafusion_common::instant::Instant; use datafusion_expr::utils::{conjunction, disjunction, split_conjunction}; +use std::path::Path; +use datafusion_common::test_util::parquet_test_data; +use datafusion_execution::config::SessionConfig; use itertools::Itertools; use parquet::file::properties::WriterProperties; use tempfile::TempDir; -use test_utils::AccessLogGenerator; /// how many rows of generated data to write to our parquet file (arbitrary) const NUM_ROWS: usize = 4096; -fn generate_file(tempdir: &TempDir, props: WriterProperties) -> TestParquetFile { - // Tune down the generator for smaller files - let generator = AccessLogGenerator::new() - .with_row_limit(NUM_ROWS) - .with_pods_per_host(1..4) - .with_containers_per_pod(1..2) - .with_entries_per_container(128..256); - - let file = tempdir.path().join("data.parquet"); - - let start = Instant::now(); - println!("Writing test data to {file:?}"); - let test_parquet_file = TestParquetFile::try_new(file, props, generator).unwrap(); - println!( - "Completed generating test data in {:?}", - Instant::now() - start - ); - test_parquet_file +async fn read_parquet_test_data>(path: T) -> Vec { + let ctx: SessionContext = SessionContext::new(); + ctx.read_parquet(path.into(), ParquetReadOptions::default()) + .await + .unwrap() + .collect() + .await + .unwrap() } #[tokio::test] async fn single_file() { - // Only create the parquet file once as it is fairly large + let batches = + read_parquet_test_data("tests/data/filter_pushdown/single_file.gz.parquet").await; - let tempdir = TempDir::new_in(Path::new(".")).unwrap(); - // Set row group size smaller so can test with fewer rows + // Set the row group size smaller so can test with fewer rows let props = WriterProperties::builder() .set_max_row_group_size(1024) .build(); - let test_parquet_file = generate_file(&tempdir, props); + // Only create the parquet file once as it is fairly large + let tempdir = TempDir::new_in(Path::new(".")).unwrap(); + + let test_parquet_file = + TestParquetFile::try_new(tempdir.path().join("data.parquet"), props, batches) + .unwrap(); let case = TestCase::new(&test_parquet_file) .with_name("selective") // request_method = 'GET' @@ -224,16 +220,25 @@ async fn single_file() { } #[tokio::test] +#[allow(dead_code)] async fn single_file_small_data_pages() { + let batches = read_parquet_test_data( + "tests/data/filter_pushdown/single_file_small_pages.gz.parquet", + ) + .await; + let tempdir = TempDir::new_in(Path::new(".")).unwrap(); - // Set low row count limit to improve page filtering + // Set a low row count limit to improve page filtering let props = WriterProperties::builder() .set_max_row_group_size(2048) .set_data_page_row_count_limit(512) .set_write_batch_size(512) .build(); - let test_parquet_file = generate_file(&tempdir, props); + + let test_parquet_file = + TestParquetFile::try_new(tempdir.path().join("data.parquet"), props, batches) + .unwrap(); // The statistics on the 'pod' column are as follows: // @@ -597,3 +602,99 @@ fn get_value(metrics: &MetricsSet, metric_name: &str) -> usize { } } } + +#[tokio::test] +async fn predicate_cache_default() -> datafusion_common::Result<()> { + let ctx = SessionContext::new(); + // The cache is on by default, but not used unless filter pushdown is enabled + PredicateCacheTest { + expected_inner_records: 0, + expected_records: 0, + } + .run(&ctx) + .await +} + +#[tokio::test] +async fn predicate_cache_pushdown_default() -> datafusion_common::Result<()> { + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(config); + // The cache is on by default, and used when filter pushdown is enabled + PredicateCacheTest { + expected_inner_records: 8, + expected_records: 4, + } + .run(&ctx) + .await +} + +#[tokio::test] +async fn predicate_cache_pushdown_disable() -> datafusion_common::Result<()> { + // Can disable the cache even with filter pushdown by setting the size to 0. In this case we + // expect the inner records are reported but no records are read from the cache + let mut config = SessionConfig::new(); + config.options_mut().execution.parquet.pushdown_filters = true; + config + .options_mut() + .execution + .parquet + .max_predicate_cache_size = Some(0); + let ctx = SessionContext::new_with_config(config); + PredicateCacheTest { + // file has 8 rows, which need to be read twice, one for filter, one for + // final output + expected_inner_records: 16, + // Expect this to 0 records read as the cache is disabled. However, it is + // non zero due to https://github.com/apache/arrow-rs/issues/8307 + expected_records: 3, + } + .run(&ctx) + .await +} + +/// Runs the query "SELECT * FROM alltypes_plain WHERE double_col != 0.0" +/// with a given SessionContext and asserts that the predicate cache metrics +/// are as expected +#[derive(Debug)] +struct PredicateCacheTest { + /// Expected records read from the underlying reader (to evaluate filters) + /// -- this is the total number of records in the file + expected_inner_records: usize, + /// Expected records to be read from the cache (after filtering) + expected_records: usize, +} + +impl PredicateCacheTest { + async fn run(self, ctx: &SessionContext) -> datafusion_common::Result<()> { + let Self { + expected_inner_records, + expected_records, + } = self; + // Create a dataframe that scans the "alltypes_plain.parquet" file with + // a filter on `double_col != 0.0` + let path = parquet_test_data() + "/alltypes_plain.parquet"; + let exec = ctx + .read_parquet(path, ParquetReadOptions::default()) + .await? + .filter(col("double_col").not_eq(lit(0.0)))? + .create_physical_plan() + .await?; + + // run the plan to completion + let _ = collect(exec.clone(), ctx.task_ctx()).await?; // run plan + let metrics = + TestParquetFile::parquet_metrics(&exec).expect("found parquet metrics"); + + // verify the predicate cache metrics + assert_eq!( + get_value(&metrics, "predicate_cache_inner_records"), + expected_inner_records + ); + assert_eq!( + get_value(&metrics, "predicate_cache_records"), + expected_records + ); + Ok(()) + } +} diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index f45eacce18df5..c44d14abd381a 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -43,12 +43,15 @@ use std::sync::Arc; use tempfile::NamedTempFile; mod custom_reader; +#[cfg(feature = "parquet_encryption")] +mod encryption; mod external_access_plan; mod file_statistics; mod filter_pushdown; mod page_pruning; mod row_group_pruning; mod schema; +mod schema_adapter; mod schema_coercion; mod utils; @@ -107,11 +110,11 @@ struct ContextWithParquet { /// The output of running one of the test cases struct TestOutput { - /// The input string + /// The input query SQL sql: String, /// Execution metrics for the Parquet Scan parquet_metrics: MetricsSet, - /// number of rows in results + /// number of actual rows in results result_rows: usize, /// the contents of the input, as a string pretty_input: String, @@ -152,6 +155,10 @@ impl TestOutput { self.metric_value("row_groups_pruned_statistics") } + fn files_ranges_pruned_statistics(&self) -> Option { + self.metric_value("files_ranges_pruned_statistics") + } + /// The number of row_groups matched by bloom filter or statistics fn row_groups_matched(&self) -> Option { self.row_groups_matched_bloom_filter() @@ -192,6 +199,8 @@ impl ContextWithParquet { unit: Unit, mut config: SessionConfig, ) -> Self { + // Use a single partition for deterministic results no matter how many CPUs the host has + config = config.with_target_partitions(1); let file = match unit { Unit::RowGroup(row_per_group) => { config = config.with_parquet_bloom_filter_pruning(true); @@ -611,7 +620,7 @@ fn make_bytearray_batch( large_binary_values: Vec<&[u8]>, ) -> RecordBatch { let num_rows = string_values.len(); - let name: StringArray = std::iter::repeat(Some(name)).take(num_rows).collect(); + let name: StringArray = std::iter::repeat_n(Some(name), num_rows).collect(); let service_string: StringArray = string_values.iter().map(Some).collect(); let service_binary: BinaryArray = binary_values.iter().map(Some).collect(); let service_fixedsize: FixedSizeBinaryArray = fixedsize_values @@ -659,7 +668,7 @@ fn make_bytearray_batch( /// name | service.name fn make_names_batch(name: &str, service_name_values: Vec<&str>) -> RecordBatch { let num_rows = service_name_values.len(); - let name: StringArray = std::iter::repeat(Some(name)).take(num_rows).collect(); + let name: StringArray = std::iter::repeat_n(Some(name), num_rows).collect(); let service_name: StringArray = service_name_values.iter().map(Some).collect(); let schema = Schema::new(vec![ @@ -698,7 +707,7 @@ fn make_int_batches_with_null( Int8Array::from_iter( v8.into_iter() .map(Some) - .chain(std::iter::repeat(None).take(null_values)), + .chain(std::iter::repeat_n(None, null_values)), ) .to_data(), ), @@ -706,7 +715,7 @@ fn make_int_batches_with_null( Int16Array::from_iter( v16.into_iter() .map(Some) - .chain(std::iter::repeat(None).take(null_values)), + .chain(std::iter::repeat_n(None, null_values)), ) .to_data(), ), @@ -714,7 +723,7 @@ fn make_int_batches_with_null( Int32Array::from_iter( v32.into_iter() .map(Some) - .chain(std::iter::repeat(None).take(null_values)), + .chain(std::iter::repeat_n(None, null_values)), ) .to_data(), ), @@ -722,7 +731,7 @@ fn make_int_batches_with_null( Int64Array::from_iter( v64.into_iter() .map(Some) - .chain(std::iter::repeat(None).take(null_values)), + .chain(std::iter::repeat_n(None, null_values)), ) .to_data(), ), diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index 7006bf083eeed..27bee10234b57 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use crate::parquet::Unit::Page; use crate::parquet::{ContextWithParquet, Scenario}; +use arrow::array::RecordBatch; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::listing::PartitionedFile; @@ -40,7 +41,11 @@ use futures::StreamExt; use object_store::path::Path; use object_store::ObjectMeta; -async fn get_parquet_exec(state: &SessionState, filter: Expr) -> DataSourceExec { +async fn get_parquet_exec( + state: &SessionState, + filter: Expr, + pushdown_filters: bool, +) -> DataSourceExec { let object_store_url = ObjectStoreUrl::local_filesystem(); let store = state.runtime_env().object_store(&object_store_url).unwrap(); @@ -52,7 +57,7 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> DataSourceExec let meta = ObjectMeta { location, last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), - size: metadata.len() as usize, + size: metadata.len(), e_tag: None, version: None, }; @@ -77,8 +82,9 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> DataSourceExec let source = Arc::new( ParquetSource::default() - .with_predicate(Arc::clone(&schema), predicate) - .with_enable_page_index(true), + .with_predicate(predicate) + .with_enable_page_index(true) + .with_pushdown_filters(pushdown_filters), ); let base_config = FileScanConfigBuilder::new(object_store_url, schema, source) .with_file(partitioned_file) @@ -87,38 +93,44 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> DataSourceExec DataSourceExec::new(Arc::new(base_config)) } +async fn get_filter_results( + state: &SessionState, + filter: Expr, + pushdown_filters: bool, +) -> Vec { + let parquet_exec = get_parquet_exec(state, filter, pushdown_filters).await; + let task_ctx = state.task_ctx(); + let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); + let mut batches = Vec::new(); + while let Some(Ok(batch)) = results.next().await { + batches.push(batch); + } + batches +} + #[tokio::test] async fn page_index_filter_one_col() { let session_ctx = SessionContext::new(); let state = session_ctx.state(); - let task_ctx = state.task_ctx(); // 1.create filter month == 1; let filter = col("month").eq(lit(1_i32)); - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await.unwrap().unwrap(); - + let batches = get_filter_results(&state, filter.clone(), false).await; // `month = 1` from the page index should create below RowSelection // vec.push(RowSelector::select(312)); // vec.push(RowSelector::skip(3330)); // vec.push(RowSelector::select(339)); // vec.push(RowSelector::skip(3319)); // total 651 row - assert_eq!(batch.num_rows(), 651); + assert_eq!(batches[0].num_rows(), 651); + + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 620); // 2. create filter month == 1 or month == 2; let filter = col("month").eq(lit(1_i32)).or(col("month").eq(lit(2_i32))); - - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await.unwrap().unwrap(); - + let batches = get_filter_results(&state, filter.clone(), false).await; // `month = 1` or `month = 2` from the page index should create below RowSelection // vec.push(RowSelector::select(312)); // vec.push(RowSelector::skip(900)); @@ -128,95 +140,78 @@ async fn page_index_filter_one_col() { // vec.push(RowSelector::skip(873)); // vec.push(RowSelector::select(318)); // vec.push(RowSelector::skip(2128)); - assert_eq!(batch.num_rows(), 1281); + assert_eq!(batches[0].num_rows(), 1281); + + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 1180); // 3. create filter month == 1 and month == 12; let filter = col("month") .eq(lit(1_i32)) .and(col("month").eq(lit(12_i32))); + let batches = get_filter_results(&state, filter.clone(), false).await; + assert!(batches.is_empty()); - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await; - - assert!(batch.is_none()); + let batches = get_filter_results(&state, filter, true).await; + assert!(batches.is_empty()); // 4.create filter 0 < month < 2 ; let filter = col("month").gt(lit(0_i32)).and(col("month").lt(lit(2_i32))); - - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await.unwrap().unwrap(); - + let batches = get_filter_results(&state, filter.clone(), false).await; // should same with `month = 1` - assert_eq!(batch.num_rows(), 651); - - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); + assert_eq!(batches[0].num_rows(), 651); + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 620); // 5.create filter date_string_col == "01/01/09"`; // Note this test doesn't apply type coercion so the literal must match the actual view type let filter = col("date_string_col").eq(lit(ScalarValue::new_utf8view("01/01/09"))); - let parquet_exec = get_parquet_exec(&state, filter).await; - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - let batch = results.next().await.unwrap().unwrap(); + let batches = get_filter_results(&state, filter.clone(), false).await; + assert_eq!(batches[0].num_rows(), 14); // there should only two pages match the filter // min max // page-20 0 01/01/09 01/02/09 // page-21 0 01/01/09 01/01/09 // each 7 rows - assert_eq!(batch.num_rows(), 14); + assert_eq!(batches[0].num_rows(), 14); + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 10); } #[tokio::test] async fn page_index_filter_multi_col() { let session_ctx = SessionContext::new(); let state = session_ctx.state(); - let task_ctx = session_ctx.task_ctx(); // create filter month == 1 and year = 2009; let filter = col("month").eq(lit(1_i32)).and(col("year").eq(lit(2009))); - - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await.unwrap().unwrap(); - + let batches = get_filter_results(&state, filter.clone(), false).await; // `year = 2009` from the page index should create below RowSelection // vec.push(RowSelector::select(3663)); // vec.push(RowSelector::skip(3642)); // combine with `month = 1` total 333 row - assert_eq!(batch.num_rows(), 333); + assert_eq!(batches[0].num_rows(), 333); + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 310); // create filter (year = 2009 or id = 1) and month = 1; // this should only use `month = 1` to evaluate the page index. let filter = col("month") .eq(lit(1_i32)) .and(col("year").eq(lit(2009)).or(col("id").eq(lit(1)))); - - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await.unwrap().unwrap(); - assert_eq!(batch.num_rows(), 651); + let batches = get_filter_results(&state, filter.clone(), false).await; + assert_eq!(batches[0].num_rows(), 651); + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 310); // create filter (year = 2009 or id = 1) // this filter use two columns will not push down let filter = col("year").eq(lit(2009)).or(col("id").eq(lit(1))); - - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await.unwrap().unwrap(); - assert_eq!(batch.num_rows(), 7300); + let batches = get_filter_results(&state, filter.clone(), false).await; + assert_eq!(batches[0].num_rows(), 7300); + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 3650); // create filter (year = 2009 and id = 1) or (year = 2010) // this filter use two columns will not push down @@ -226,13 +221,10 @@ async fn page_index_filter_multi_col() { .eq(lit(2009)) .and(col("id").eq(lit(1))) .or(col("year").eq(lit(2010))); - - let parquet_exec = get_parquet_exec(&state, filter).await; - - let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); - - let batch = results.next().await.unwrap().unwrap(); - assert_eq!(batch.num_rows(), 7300); + let batches = get_filter_results(&state, filter.clone(), false).await; + assert_eq!(batches[0].num_rows(), 7300); + let batches = get_filter_results(&state, filter, true).await; + assert_eq!(batches[0].num_rows(), 3651); } async fn test_prune( @@ -911,8 +903,8 @@ async fn without_pushdown_filter() { ) .unwrap(); - // Without filter will not read pageIndex. - assert!(bytes_scanned_with_filter > bytes_scanned_without_filter); + // Same amount of bytes are scanned when defaulting to cache parquet metadata + assert_eq!(bytes_scanned_with_filter, bytes_scanned_without_filter); } #[tokio::test] diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 5a85f47c015a9..44409166d3ce3 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -31,9 +31,10 @@ struct RowGroupPruningTest { expected_errors: Option, expected_row_group_matched_by_statistics: Option, expected_row_group_pruned_by_statistics: Option, + expected_files_pruned_by_statistics: Option, expected_row_group_matched_by_bloom_filter: Option, expected_row_group_pruned_by_bloom_filter: Option, - expected_results: usize, + expected_rows: usize, } impl RowGroupPruningTest { // Start building the test configuration @@ -44,9 +45,10 @@ impl RowGroupPruningTest { expected_errors: None, expected_row_group_matched_by_statistics: None, expected_row_group_pruned_by_statistics: None, + expected_files_pruned_by_statistics: None, expected_row_group_matched_by_bloom_filter: None, expected_row_group_pruned_by_bloom_filter: None, - expected_results: 0, + expected_rows: 0, } } @@ -80,6 +82,11 @@ impl RowGroupPruningTest { self } + fn with_pruned_files(mut self, pruned_files: Option) -> Self { + self.expected_files_pruned_by_statistics = pruned_files; + self + } + // Set the expected matched row groups by bloom filter fn with_matched_by_bloom_filter(mut self, matched_by_bf: Option) -> Self { self.expected_row_group_matched_by_bloom_filter = matched_by_bf; @@ -92,9 +99,9 @@ impl RowGroupPruningTest { self } - // Set the expected rows for the test + /// Set the number of expected rows from the output of this test fn with_expected_rows(mut self, rows: usize) -> Self { - self.expected_results = rows; + self.expected_rows = rows; self } @@ -121,6 +128,11 @@ impl RowGroupPruningTest { self.expected_row_group_pruned_by_statistics, "mismatched row_groups_pruned_statistics", ); + assert_eq!( + output.files_ranges_pruned_statistics(), + self.expected_files_pruned_by_statistics, + "mismatched files_ranges_pruned_statistics", + ); assert_eq!( output.row_groups_matched_bloom_filter(), self.expected_row_group_matched_by_bloom_filter, @@ -133,8 +145,10 @@ impl RowGroupPruningTest { ); assert_eq!( output.result_rows, - self.expected_results, - "mismatched expected rows: {}", + self.expected_rows, + "Expected {} rows, got {}: {}", + output.result_rows, + self.expected_rows, output.description(), ); } @@ -148,6 +162,7 @@ async fn prune_timestamps_nanos() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) @@ -165,6 +180,7 @@ async fn prune_timestamps_micros() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) @@ -182,6 +198,7 @@ async fn prune_timestamps_millis() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) @@ -199,6 +216,7 @@ async fn prune_timestamps_seconds() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) @@ -214,6 +232,7 @@ async fn prune_date32() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -256,6 +275,7 @@ async fn prune_disabled() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) @@ -301,6 +321,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) @@ -315,6 +336,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) @@ -330,6 +352,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -344,6 +367,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -359,6 +383,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(3) @@ -374,6 +399,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -389,6 +415,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) @@ -405,6 +432,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -421,7 +449,8 @@ macro_rules! int_tests { .with_query(&format!("SELECT * FROM t where i{} in (100)", $bits)) .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(1)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(0) @@ -438,6 +467,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(4)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(4)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(19) @@ -467,6 +497,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) @@ -482,6 +513,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -496,6 +528,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -511,6 +544,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -526,6 +560,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -542,6 +577,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -559,6 +595,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(4)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(0) @@ -575,6 +612,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(4)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(4)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(19) @@ -604,6 +642,7 @@ async fn prune_int32_eq_large_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -626,6 +665,7 @@ async fn prune_uint32_eq_large_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -641,6 +681,7 @@ async fn prune_f64_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) @@ -652,6 +693,7 @@ async fn prune_f64_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) @@ -669,6 +711,7 @@ async fn prune_f64_scalar_fun_and_gt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -685,6 +728,7 @@ async fn prune_f64_scalar_fun() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -701,6 +745,7 @@ async fn prune_f64_complex_expr() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) @@ -717,6 +762,7 @@ async fn prune_f64_complex_expr_subtract() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) @@ -735,6 +781,7 @@ async fn prune_decimal_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) @@ -746,6 +793,7 @@ async fn prune_decimal_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(8) @@ -757,6 +805,7 @@ async fn prune_decimal_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) @@ -768,6 +817,7 @@ async fn prune_decimal_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(8) @@ -786,6 +836,7 @@ async fn prune_decimal_eq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -797,6 +848,7 @@ async fn prune_decimal_eq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -809,6 +861,7 @@ async fn prune_decimal_eq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -820,6 +873,7 @@ async fn prune_decimal_eq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -839,6 +893,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) @@ -850,6 +905,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) @@ -861,6 +917,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) @@ -872,6 +929,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) @@ -885,6 +943,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) @@ -898,6 +957,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) @@ -911,6 +971,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) @@ -929,6 +990,7 @@ async fn prune_string_eq_match() { // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(1) @@ -947,6 +1009,7 @@ async fn prune_string_eq_no_match() { // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -963,6 +1026,7 @@ async fn prune_string_eq_no_match() { // false positive on 'mixed' batch: 'backend one' < 'frontend nine' < 'frontend six' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(0) @@ -980,6 +1044,7 @@ async fn prune_string_neq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(14) @@ -998,6 +1063,7 @@ async fn prune_string_lt() { // matches 'all backends' only .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(3) @@ -1012,6 +1078,7 @@ async fn prune_string_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) // all backends from 'mixed' and 'all backends' @@ -1031,6 +1098,7 @@ async fn prune_binary_eq_match() { // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(1) @@ -1049,6 +1117,7 @@ async fn prune_binary_eq_no_match() { // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -1065,6 +1134,7 @@ async fn prune_binary_eq_no_match() { // false positive on 'mixed' batch: 'backend one' < 'frontend nine' < 'frontend six' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(0) @@ -1082,6 +1152,7 @@ async fn prune_binary_neq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(14) @@ -1100,6 +1171,7 @@ async fn prune_binary_lt() { // matches 'all backends' only .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(3) @@ -1114,6 +1186,7 @@ async fn prune_binary_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) // all backends from 'mixed' and 'all backends' @@ -1133,6 +1206,7 @@ async fn prune_fixedsizebinary_eq_match() { // false positive on 'all frontends' batch: 'fe1' < 'fe6' < 'fe7' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(1) @@ -1148,6 +1222,7 @@ async fn prune_fixedsizebinary_eq_match() { // false positive on 'all frontends' batch: 'fe1' < 'fe6' < 'fe7' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(1) @@ -1166,6 +1241,7 @@ async fn prune_fixedsizebinary_eq_no_match() { // false positive on 'mixed' batch: 'be1' < 'be9' < 'fe4' .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -1183,6 +1259,7 @@ async fn prune_fixedsizebinary_neq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(14) @@ -1201,6 +1278,7 @@ async fn prune_fixedsizebinary_lt() { // matches 'all backends' only .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -1215,6 +1293,7 @@ async fn prune_fixedsizebinary_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) // all backends from 'mixed' and 'all backends' @@ -1235,6 +1314,7 @@ async fn prune_periods_in_column_names() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(7) @@ -1246,6 +1326,7 @@ async fn prune_periods_in_column_names() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) @@ -1257,6 +1338,7 @@ async fn prune_periods_in_column_names() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -1277,6 +1359,7 @@ async fn test_row_group_with_null_values() { .with_query("SELECT * FROM t WHERE \"i8\" <= 5") .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_pruned_by_stats(Some(2)) .with_expected_rows(5) .with_matched_by_bloom_filter(Some(0)) @@ -1290,6 +1373,7 @@ async fn test_row_group_with_null_values() { .with_query("SELECT * FROM t WHERE \"i8\" is Null") .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_pruned_by_stats(Some(1)) .with_expected_rows(10) .with_matched_by_bloom_filter(Some(0)) @@ -1303,6 +1387,7 @@ async fn test_row_group_with_null_values() { .with_query("SELECT * FROM t WHERE \"i16\" is Not Null") .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_pruned_by_stats(Some(2)) .with_expected_rows(5) .with_matched_by_bloom_filter(Some(0)) @@ -1316,7 +1401,8 @@ async fn test_row_group_with_null_values() { .with_query("SELECT * FROM t WHERE \"i32\" > 7") .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(3)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(1)) .with_expected_rows(0) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) @@ -1332,6 +1418,7 @@ async fn test_bloom_filter_utf8_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1344,6 +1431,7 @@ async fn test_bloom_filter_utf8_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1356,6 +1444,7 @@ async fn test_bloom_filter_utf8_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1368,6 +1457,7 @@ async fn test_bloom_filter_utf8_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1383,6 +1473,7 @@ async fn test_bloom_filter_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1395,6 +1486,7 @@ async fn test_bloom_filter_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1407,6 +1499,7 @@ async fn test_bloom_filter_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1419,6 +1512,7 @@ async fn test_bloom_filter_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1434,6 +1528,7 @@ async fn test_bloom_filter_unsigned_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1446,6 +1541,7 @@ async fn test_bloom_filter_unsigned_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1461,6 +1557,7 @@ async fn test_bloom_filter_binary_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1473,6 +1570,7 @@ async fn test_bloom_filter_binary_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1485,6 +1583,7 @@ async fn test_bloom_filter_binary_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1499,6 +1598,7 @@ async fn test_bloom_filter_binary_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1514,6 +1614,7 @@ async fn test_bloom_filter_decimal_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1526,6 +1627,7 @@ async fn test_bloom_filter_decimal_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) diff --git a/datafusion/core/tests/parquet/schema_adapter.rs b/datafusion/core/tests/parquet/schema_adapter.rs new file mode 100644 index 0000000000000..4ae2fa9b4c399 --- /dev/null +++ b/datafusion/core/tests/parquet/schema_adapter.rs @@ -0,0 +1,551 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{record_batch, RecordBatch, RecordBatchOptions}; +use arrow::compute::{cast_with_options, CastOptions}; +use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef}; +use bytes::{BufMut, BytesMut}; +use datafusion::assert_batches_eq; +use datafusion::common::Result; +use datafusion::datasource::listing::{ListingTable, ListingTableConfig}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::DataFusionError; +use datafusion_common::{ColumnStatistics, ScalarValue}; +use datafusion_datasource::file::FileSource; +use datafusion_datasource::file_scan_config::FileScanConfigBuilder; +use datafusion_datasource::schema_adapter::{ + DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, SchemaMapper, +}; +use datafusion_datasource::ListingTableUrl; +use datafusion_datasource_parquet::source::ParquetSource; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_physical_expr::expressions::{self, Column}; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, + PhysicalExprAdapterFactory, +}; +use itertools::Itertools; +use object_store::{memory::InMemory, path::Path, ObjectStore}; +use parquet::arrow::ArrowWriter; + +async fn write_parquet(batch: RecordBatch, store: Arc, path: &str) { + let mut out = BytesMut::new().writer(); + { + let mut writer = ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + } + let data = out.into_inner().freeze(); + store.put(&Path::from(path), data.into()).await.unwrap(); +} + +#[derive(Debug)] +struct CustomSchemaAdapterFactory; + +impl SchemaAdapterFactory for CustomSchemaAdapterFactory { + fn create( + &self, + projected_table_schema: SchemaRef, + _table_schema: SchemaRef, + ) -> Box { + Box::new(CustomSchemaAdapter { + logical_file_schema: projected_table_schema, + }) + } +} + +#[derive(Debug)] +struct CustomSchemaAdapter { + logical_file_schema: SchemaRef, +} + +impl SchemaAdapter for CustomSchemaAdapter { + fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { + for (idx, field) in file_schema.fields().iter().enumerate() { + if field.name() == self.logical_file_schema.field(index).name() { + return Some(idx); + } + } + None + } + + fn map_schema( + &self, + file_schema: &Schema, + ) -> Result<(Arc, Vec)> { + let projection = (0..file_schema.fields().len()).collect_vec(); + Ok(( + Arc::new(CustomSchemaMapper { + logical_file_schema: Arc::clone(&self.logical_file_schema), + }), + projection, + )) + } +} + +#[derive(Debug)] +struct CustomSchemaMapper { + logical_file_schema: SchemaRef, +} + +impl SchemaMapper for CustomSchemaMapper { + fn map_batch(&self, batch: RecordBatch) -> Result { + let mut output_columns = + Vec::with_capacity(self.logical_file_schema.fields().len()); + for field in self.logical_file_schema.fields() { + if let Some(array) = batch.column_by_name(field.name()) { + output_columns.push(cast_with_options( + array, + field.data_type(), + &CastOptions::default(), + )?); + } else { + // Create a new array with the default value for the field type + let default_value = match field.data_type() { + DataType::Int64 => ScalarValue::Int64(Some(0)), + DataType::Utf8 => ScalarValue::Utf8(Some("a".to_string())), + _ => unimplemented!("Unsupported data type: {}", field.data_type()), + }; + output_columns + .push(default_value.to_array_of_size(batch.num_rows()).unwrap()); + } + } + let batch = RecordBatch::try_new_with_options( + Arc::clone(&self.logical_file_schema), + output_columns, + &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), + ) + .unwrap(); + Ok(batch) + } + + fn map_column_statistics( + &self, + _file_col_statistics: &[ColumnStatistics], + ) -> Result> { + Ok(vec![ + ColumnStatistics::new_unknown(); + self.logical_file_schema.fields().len() + ]) + } +} + +// Implement a custom PhysicalExprAdapterFactory that fills in missing columns with the default value for the field type +#[derive(Debug)] +struct CustomPhysicalExprAdapterFactory; + +impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Arc { + Arc::new(CustomPhysicalExprAdapter { + logical_file_schema: Arc::clone(&logical_file_schema), + physical_file_schema: Arc::clone(&physical_file_schema), + inner: Arc::new(DefaultPhysicalExprAdapter::new( + logical_file_schema, + physical_file_schema, + )), + }) + } +} + +#[derive(Debug, Clone)] +struct CustomPhysicalExprAdapter { + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + inner: Arc, +} + +impl PhysicalExprAdapter for CustomPhysicalExprAdapter { + fn rewrite(&self, mut expr: Arc) -> Result> { + expr = expr + .transform(|expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + let field_name = column.name(); + if self + .physical_file_schema + .field_with_name(field_name) + .ok() + .is_none() + { + let field = self + .logical_file_schema + .field_with_name(field_name) + .map_err(|_| { + DataFusionError::Plan(format!( + "Field '{field_name}' not found in logical file schema", + )) + })?; + // If the field does not exist, create a default value expression + // Note that we use slightly different logic here to create a default value so that we can see different behavior in tests + let default_value = match field.data_type() { + DataType::Int64 => ScalarValue::Int64(Some(1)), + DataType::Utf8 => ScalarValue::Utf8(Some("b".to_string())), + _ => unimplemented!( + "Unsupported data type: {}", + field.data_type() + ), + }; + return Ok(Transformed::yes(Arc::new( + expressions::Literal::new(default_value), + ))); + } + } + + Ok(Transformed::no(expr)) + }) + .data()?; + self.inner.rewrite(expr) + } + + fn with_partition_values( + &self, + partition_values: Vec<(FieldRef, ScalarValue)>, + ) -> Arc { + assert!( + partition_values.is_empty(), + "Partition values are not supported in this test" + ); + Arc::new(self.clone()) + } +} + +#[tokio::test] +async fn test_custom_schema_adapter_and_custom_expression_adapter() { + let batch = + record_batch!(("extra", Int64, [1, 2, 3]), ("c1", Int32, [1, 2, 3])).unwrap(); + + let store = Arc::new(InMemory::new()) as Arc; + let store_url = ObjectStoreUrl::parse("memory://").unwrap(); + let path = "test.parquet"; + write_parquet(batch, store.clone(), path).await; + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int64, false), + Field::new("c2", DataType::Utf8, true), + ])); + + let mut cfg = SessionConfig::new() + // Disable statistics collection for this test otherwise early pruning makes it hard to demonstrate data adaptation + .with_collect_statistics(false) + .with_parquet_pruning(false) + .with_parquet_page_index_pruning(false); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + assert!( + !ctx.state() + .config_mut() + .options_mut() + .execution + .collect_statistics + ); + assert!(!ctx.state().config().collect_statistics()); + + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_schema_adapter_factory(Arc::new(DefaultSchemaAdapterFactory)) + .with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory)); + + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + + let batches = ctx + .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 IS NULL") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let expected = [ + "+----+----+", + "| c2 | c1 |", + "+----+----+", + "| | 2 |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // Test using a custom schema adapter and no explicit physical expr adapter + // This should use the custom schema adapter both for projections and predicate pushdown + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_schema_adapter_factory(Arc::new(CustomSchemaAdapterFactory)); + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.deregister_table("t").unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + let batches = ctx + .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'a'") + .await + .unwrap() + .collect() + .await + .unwrap(); + let expected = [ + "+----+----+", + "| c2 | c1 |", + "+----+----+", + "| a | 2 |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // Do the same test but with a custom physical expr adapter + // Now the default schema adapter will be used for projections, but the custom physical expr adapter will be used for predicate pushdown + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory)); + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.deregister_table("t").unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + let batches = ctx + .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'b'") + .await + .unwrap() + .collect() + .await + .unwrap(); + let expected = [ + "+----+----+", + "| c2 | c1 |", + "+----+----+", + "| | 2 |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); + + // If we use both then the custom physical expr adapter will be used for predicate pushdown and the custom schema adapter will be used for projections + let listing_table_config = + ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap()) + .infer_options(&ctx.state()) + .await + .unwrap() + .with_schema(table_schema.clone()) + .with_schema_adapter_factory(Arc::new(CustomSchemaAdapterFactory)) + .with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory)); + let table = ListingTable::try_new(listing_table_config).unwrap(); + ctx.deregister_table("t").unwrap(); + ctx.register_table("t", Arc::new(table)).unwrap(); + let batches = ctx + .sql("SELECT c2, c1 FROM t WHERE c1 = 2 AND c2 = 'b'") + .await + .unwrap() + .collect() + .await + .unwrap(); + let expected = [ + "+----+----+", + "| c2 | c1 |", + "+----+----+", + "| a | 2 |", + "+----+----+", + ]; + assert_batches_eq!(expected, &batches); +} + +/// A test schema adapter factory that adds prefix to column names +#[derive(Debug)] +struct PrefixAdapterFactory { + prefix: String, +} + +impl SchemaAdapterFactory for PrefixAdapterFactory { + fn create( + &self, + projected_table_schema: SchemaRef, + _table_schema: SchemaRef, + ) -> Box { + Box::new(PrefixAdapter { + input_schema: projected_table_schema, + prefix: self.prefix.clone(), + }) + } +} + +/// A test schema adapter that adds prefix to column names +#[derive(Debug)] +struct PrefixAdapter { + input_schema: SchemaRef, + prefix: String, +} + +impl SchemaAdapter for PrefixAdapter { + fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { + let field = self.input_schema.field(index); + file_schema.fields.find(field.name()).map(|(i, _)| i) + } + + fn map_schema( + &self, + file_schema: &Schema, + ) -> Result<(Arc, Vec)> { + let mut projection = Vec::with_capacity(file_schema.fields().len()); + for (file_idx, file_field) in file_schema.fields().iter().enumerate() { + if self.input_schema.fields().find(file_field.name()).is_some() { + projection.push(file_idx); + } + } + + // Create a schema mapper that adds a prefix to column names + #[derive(Debug)] + struct PrefixSchemaMapping { + // Keep only the prefix field which is actually used in the implementation + prefix: String, + } + + impl SchemaMapper for PrefixSchemaMapping { + fn map_batch(&self, batch: RecordBatch) -> Result { + // Create a new schema with prefixed field names + let prefixed_fields: Vec = batch + .schema() + .fields() + .iter() + .map(|field| { + Field::new( + format!("{}{}", self.prefix, field.name()), + field.data_type().clone(), + field.is_nullable(), + ) + }) + .collect(); + let prefixed_schema = Arc::new(Schema::new(prefixed_fields)); + + // Create a new batch with the prefixed schema but the same data + let options = RecordBatchOptions::default(); + RecordBatch::try_new_with_options( + prefixed_schema, + batch.columns().to_vec(), + &options, + ) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) + } + + fn map_column_statistics( + &self, + stats: &[ColumnStatistics], + ) -> Result> { + // For testing, just return the input statistics + Ok(stats.to_vec()) + } + } + + Ok(( + Arc::new(PrefixSchemaMapping { + prefix: self.prefix.clone(), + }), + projection, + )) + } +} + +#[test] +fn test_apply_schema_adapter_with_factory() { + // Create a schema + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + // Create a parquet source + let source = ParquetSource::default(); + + // Create a file scan config with source that has a schema adapter factory + let factory = Arc::new(PrefixAdapterFactory { + prefix: "test_".to_string(), + }); + + let file_source = source.clone().with_schema_adapter_factory(factory).unwrap(); + + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::local_filesystem(), + schema.clone(), + file_source, + ) + .build(); + + // Apply schema adapter to a new source + let result_source = source.apply_schema_adapter(&config).unwrap(); + + // Verify the adapter was applied + assert!(result_source.schema_adapter_factory().is_some()); + + // Create adapter and test it produces expected schema + let adapter_factory = result_source.schema_adapter_factory().unwrap(); + let adapter = adapter_factory.create(schema.clone(), schema.clone()); + + // Create a dummy batch to test the schema mapping + let dummy_batch = RecordBatch::new_empty(schema.clone()); + + // Get the file schema (which is the same as the table schema in this test) + let (mapper, _) = adapter.map_schema(&schema).unwrap(); + + // Apply the mapping to get the output schema + let mapped_batch = mapper.map_batch(dummy_batch).unwrap(); + let output_schema = mapped_batch.schema(); + + // Check the column names have the prefix + assert_eq!(output_schema.field(0).name(), "test_id"); + assert_eq!(output_schema.field(1).name(), "test_name"); +} + +#[test] +fn test_apply_schema_adapter_without_factory() { + // Create a schema + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + // Create a parquet source + let source = ParquetSource::default(); + + // Convert to Arc + let file_source: Arc = Arc::new(source.clone()); + + // Create a file scan config without a schema adapter factory + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::local_filesystem(), + schema.clone(), + file_source, + ) + .build(); + + // Apply schema adapter function - should pass through the source unchanged + let result_source = source.apply_schema_adapter(&config).unwrap(); + + // Verify no adapter was applied + assert!(result_source.schema_adapter_factory().is_none()); +} diff --git a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs index 568be0d18f245..9c76f6ab6f58b 100644 --- a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs @@ -20,9 +20,10 @@ //! Note these tests are not in the same module as the optimizer pass because //! they rely on `DataSourceExec` which is in the core crate. +use insta::assert_snapshot; use std::sync::Arc; -use crate::physical_optimizer::test_utils::{parquet_exec, trim_plan_display}; +use crate::physical_optimizer::test_utils::parquet_exec; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; @@ -43,22 +44,16 @@ use datafusion_physical_plan::ExecutionPlan; /// Runs the CombinePartialFinalAggregate optimizer and asserts the plan against the expected macro_rules! assert_optimized { - ($EXPECTED_LINES: expr, $PLAN: expr) => { - let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); - + ($PLAN: expr, @ $EXPECTED_LINES: literal $(,)?) => { // run optimizer let optimizer = CombinePartialFinalAggregate {}; let config = ConfigOptions::new(); let optimized = optimizer.optimize($PLAN, &config)?; // Now format correctly let plan = displayable(optimized.as_ref()).indent(true).to_string(); - let actual_lines = trim_plan_display(&plan); + let actual_lines = plan.trim(); - assert_eq!( - &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines - ); + assert_snapshot!(actual_lines, @ $EXPECTED_LINES); }; } @@ -136,7 +131,7 @@ fn aggregations_not_combined() -> datafusion_common::Result<()> { let plan = final_aggregate_exec( repartition_exec(partial_aggregate_exec( - parquet_exec(&schema), + parquet_exec(schema.clone()), PhysicalGroupBy::default(), aggr_expr.clone(), )), @@ -144,20 +139,22 @@ fn aggregations_not_combined() -> datafusion_common::Result<()> { aggr_expr, ); // should not combine the Partial/Final AggregateExecs - let expected = &[ - "AggregateExec: mode=Final, gby=[], aggr=[COUNT(1)]", - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "AggregateExec: mode=Partial, gby=[], aggr=[COUNT(1)]", - "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet", - ]; - assert_optimized!(expected, plan); + assert_optimized!( + plan, + @ r" + AggregateExec: mode=Final, gby=[], aggr=[COUNT(1)] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + AggregateExec: mode=Partial, gby=[], aggr=[COUNT(1)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet + " + ); let aggr_expr1 = vec![count_expr(lit(1i8), "COUNT(1)", &schema)]; let aggr_expr2 = vec![count_expr(lit(1i8), "COUNT(2)", &schema)]; let plan = final_aggregate_exec( partial_aggregate_exec( - parquet_exec(&schema), + parquet_exec(schema), PhysicalGroupBy::default(), aggr_expr1, ), @@ -165,13 +162,14 @@ fn aggregations_not_combined() -> datafusion_common::Result<()> { aggr_expr2, ); // should not combine the Partial/Final AggregateExecs - let expected = &[ - "AggregateExec: mode=Final, gby=[], aggr=[COUNT(2)]", - "AggregateExec: mode=Partial, gby=[], aggr=[COUNT(1)]", - "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet", - ]; - - assert_optimized!(expected, plan); + assert_optimized!( + plan, + @ r" + AggregateExec: mode=Final, gby=[], aggr=[COUNT(2)] + AggregateExec: mode=Partial, gby=[], aggr=[COUNT(1)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet + " + ); Ok(()) } @@ -183,7 +181,7 @@ fn aggregations_combined() -> datafusion_common::Result<()> { let plan = final_aggregate_exec( partial_aggregate_exec( - parquet_exec(&schema), + parquet_exec(schema), PhysicalGroupBy::default(), aggr_expr.clone(), ), @@ -191,12 +189,13 @@ fn aggregations_combined() -> datafusion_common::Result<()> { aggr_expr, ); // should combine the Partial/Final AggregateExecs to the Single AggregateExec - let expected = &[ - "AggregateExec: mode=Single, gby=[], aggr=[COUNT(1)]", - "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet", - ]; - - assert_optimized!(expected, plan); + assert_optimized!( + plan, + @ " + AggregateExec: mode=Single, gby=[], aggr=[COUNT(1)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet + " + ); Ok(()) } @@ -215,11 +214,8 @@ fn aggregations_with_group_combined() -> datafusion_common::Result<()> { vec![(col("c", &schema)?, "c".to_string())]; let partial_group_by = PhysicalGroupBy::new_single(groups); - let partial_agg = partial_aggregate_exec( - parquet_exec(&schema), - partial_group_by, - aggr_expr.clone(), - ); + let partial_agg = + partial_aggregate_exec(parquet_exec(schema), partial_group_by, aggr_expr.clone()); let groups: Vec<(Arc, String)> = vec![(col("c", &partial_agg.schema())?, "c".to_string())]; @@ -227,12 +223,13 @@ fn aggregations_with_group_combined() -> datafusion_common::Result<()> { let plan = final_aggregate_exec(partial_agg, final_group_by, aggr_expr); // should combine the Partial/Final AggregateExecs to the Single AggregateExec - let expected = &[ - "AggregateExec: mode=Single, gby=[c@2 as c], aggr=[Sum(b)]", - "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet", - ]; - - assert_optimized!(expected, plan); + assert_optimized!( + plan, + @ r" + AggregateExec: mode=Single, gby=[c@2 as c], aggr=[Sum(b)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet + " + ); Ok(()) } @@ -245,11 +242,8 @@ fn aggregations_with_limit_combined() -> datafusion_common::Result<()> { vec![(col("c", &schema)?, "c".to_string())]; let partial_group_by = PhysicalGroupBy::new_single(groups); - let partial_agg = partial_aggregate_exec( - parquet_exec(&schema), - partial_group_by, - aggr_expr.clone(), - ); + let partial_agg = + partial_aggregate_exec(parquet_exec(schema), partial_group_by, aggr_expr.clone()); let groups: Vec<(Arc, String)> = vec![(col("c", &partial_agg.schema())?, "c".to_string())]; @@ -271,11 +265,12 @@ fn aggregations_with_limit_combined() -> datafusion_common::Result<()> { let plan: Arc = final_agg; // should combine the Partial/Final AggregateExecs to a Single AggregateExec // with the final limit preserved - let expected = &[ - "AggregateExec: mode=Single, gby=[c@2 as c], aggr=[], lim=[5]", - "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet", - ]; - - assert_optimized!(expected, plan); + assert_optimized!( + plan, + @ r" + AggregateExec: mode=Single, gby=[c@2 as c], aggr=[], lim=[5] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c], file_type=parquet + " + ); Ok(()) } diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index 9898f6204e880..63111f43806b3 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -19,31 +19,35 @@ use std::fmt::Debug; use std::ops::Deref; use std::sync::Arc; -use crate::physical_optimizer::test_utils::parquet_exec_with_sort; use crate::physical_optimizer::test_utils::{ - check_integrity, coalesce_partitions_exec, repartition_exec, schema, - sort_merge_join_exec, sort_preserving_merge_exec, + check_integrity, coalesce_partitions_exec, parquet_exec_with_sort, + parquet_exec_with_stats, repartition_exec, schema, sort_exec, + sort_exec_with_preserve_partitioning, sort_merge_join_exec, + sort_preserving_merge_exec, union_exec, }; +use arrow::array::{RecordBatch, UInt64Array, UInt8Array}; use arrow::compute::SortOptions; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion::config::ConfigOptions; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{CsvSource, ParquetSource}; use datafusion::datasource::source::DataSourceExec; +use datafusion::datasource::MemTable; +use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::error::Result; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_expr::{JoinType, Operator}; -use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; -use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_expr::{ - expressions::binary, expressions::lit, LexOrdering, PhysicalSortExpr, +use datafusion_physical_expr::expressions::{binary, lit, BinaryExpr, Column, Literal}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, OrderingRequirements, PhysicalSortExpr, }; -use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_optimizer::enforce_distribution::*; use datafusion_physical_optimizer::enforce_sorting::EnforceSorting; use datafusion_physical_optimizer::output_requirements::OutputRequirements; @@ -52,19 +56,18 @@ use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::ExecutionPlan; use datafusion_physical_plan::expressions::col; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::JoinOn; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; -use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::UnionExec; -use datafusion_physical_plan::ExecutionPlanProperties; -use datafusion_physical_plan::PlanProperties; use datafusion_physical_plan::{ - get_plan_string, DisplayAs, DisplayFormatType, Statistics, + get_plan_string, DisplayAs, DisplayFormatType, ExecutionPlanProperties, + PlanProperties, Statistics, }; /// Models operators like BoundedWindowExec that require an input @@ -140,12 +143,8 @@ impl ExecutionPlan for SortRequiredExec { } // model that it requires the output ordering of its input - fn required_input_ordering(&self) -> Vec> { - if self.expr.is_empty() { - vec![None] - } else { - vec![Some(LexRequirement::from(self.expr.clone()))] - } + fn required_input_ordering(&self) -> Vec> { + vec![Some(OrderingRequirements::from(self.expr.clone()))] } fn with_new_children( @@ -169,12 +168,12 @@ impl ExecutionPlan for SortRequiredExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) } } fn parquet_exec() -> Arc { - parquet_exec_with_sort(vec![]) + parquet_exec_with_sort(schema(), vec![]) } fn parquet_exec_multiple() -> Arc { @@ -244,7 +243,10 @@ fn projection_exec_with_alias( ) -> Arc { let mut exprs = vec![]; for (column, alias) in alias_pairs.iter() { - exprs.push((col(column, &input.schema()).unwrap(), alias.to_string())); + exprs.push(ProjectionExpr { + expr: col(column, &input.schema()).unwrap(), + alias: alias.to_string(), + }); } Arc::new(ProjectionExec::try_new(exprs, input).unwrap()) } @@ -320,16 +322,6 @@ fn filter_exec(input: Arc) -> Arc { Arc::new(FilterExec::try_new(predicate, input).unwrap()) } -fn sort_exec( - sort_exprs: LexOrdering, - input: Arc, - preserve_partitioning: bool, -) -> Arc { - let new_sort = SortExec::new(sort_exprs, input) - .with_preserve_partitioning(preserve_partitioning); - Arc::new(new_sort) -} - fn limit_exec(input: Arc) -> Arc { Arc::new(GlobalLimitExec::new( Arc::new(LocalLimitExec::new(input, 100)), @@ -338,10 +330,6 @@ fn limit_exec(input: Arc) -> Arc { )) } -fn union_exec(input: Vec>) -> Arc { - Arc::new(UnionExec::new(input)) -} - fn sort_required_exec_with_req( input: Arc, sort_exprs: LexOrdering, @@ -456,7 +444,7 @@ impl TestConfig { /// Perform a series of runs using the current [`TestConfig`], /// assert the expected plan result, - /// and return the result plan (for potentional subsequent runs). + /// and return the result plan (for potential subsequent runs). fn run( &self, expected_lines: &[&str], @@ -524,8 +512,7 @@ impl TestConfig { assert_eq!( &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines + "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n" ); Ok(optimized) @@ -643,7 +630,7 @@ fn multi_hash_joins() -> Result<()> { test_config.run(&expected, top_join.clone(), &DISTRIB_DISTRIB_SORT)?; test_config.run(&expected, top_join, &SORT_DISTRIB_DISTRIB)?; } - JoinType::RightSemi | JoinType::RightAnti => {} + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {} } match join_type { @@ -652,7 +639,8 @@ fn multi_hash_joins() -> Result<()> { | JoinType::Right | JoinType::Full | JoinType::RightSemi - | JoinType::RightAnti => { + | JoinType::RightAnti + | JoinType::RightMark => { // This time we use (b1 == c) for top join // Join on (b1 == c) let top_join_on = vec![( @@ -1736,10 +1724,11 @@ fn smj_join_key_ordering() -> Result<()> { fn merge_does_not_need_sort() -> Result<()> { // see https://github.com/apache/datafusion/issues/4331 let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // Scan some sorted parquet files let exec = parquet_exec_multiple_sorted(vec![sort_key.clone()]); @@ -1794,7 +1783,7 @@ fn union_to_interleave() -> Result<()> { ); // Union - let plan = Arc::new(UnionExec::new(vec![left, right])); + let plan = UnionExec::try_new(vec![left, right])?; // final agg let plan = @@ -1838,7 +1827,7 @@ fn union_not_to_interleave() -> Result<()> { ); // Union - let plan = Arc::new(UnionExec::new(vec![left, right])); + let plan = UnionExec::try_new(vec![left, right])?; // final agg let plan = @@ -1936,11 +1925,12 @@ fn repartition_unsorted_limit() -> Result<()> { #[test] fn repartition_sorted_limit() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let plan = limit_exec(sort_exec(sort_key, parquet_exec(), false)); + }] + .into(); + let plan = limit_exec(sort_exec(sort_key, parquet_exec())); let expected = &[ "GlobalLimitExec: skip=0, fetch=100", @@ -1960,12 +1950,13 @@ fn repartition_sorted_limit() -> Result<()> { #[test] fn repartition_sorted_limit_with_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_required_exec_with_req( - filter_exec(sort_exec(sort_key.clone(), parquet_exec(), false)), + filter_exec(sort_exec(sort_key.clone(), parquet_exec())), sort_key, ); @@ -2043,10 +2034,11 @@ fn repartition_ignores_union() -> Result<()> { fn repartition_through_sort_preserving_merge() -> Result<()> { // sort preserving merge with non-sorted input let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_preserving_merge_exec(sort_key, parquet_exec()); // need resort as the data was not sorted correctly @@ -2066,10 +2058,11 @@ fn repartition_through_sort_preserving_merge() -> Result<()> { fn repartition_ignores_sort_preserving_merge() -> Result<()> { // sort preserving merge already sorted input, let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_preserving_merge_exec( sort_key.clone(), parquet_exec_multiple_sorted(vec![sort_key]), @@ -2101,11 +2094,15 @@ fn repartition_ignores_sort_preserving_merge() -> Result<()> { fn repartition_ignores_sort_preserving_merge_with_union() -> Result<()> { // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let input = union_exec(vec![parquet_exec_with_sort(vec![sort_key.clone()]); 2]); + }] + .into(); + let input = union_exec(vec![ + parquet_exec_with_sort(schema, vec![sort_key.clone()]); + 2 + ]); let plan = sort_preserving_merge_exec(sort_key, input); // Test: run EnforceDistribution, then EnforceSort. @@ -2139,12 +2136,13 @@ fn repartition_does_not_destroy_sort() -> Result<()> { // SortRequired // Parquet(sorted) let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("d", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("d", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_required_exec_with_req( - filter_exec(parquet_exec_with_sort(vec![sort_key.clone()])), + filter_exec(parquet_exec_with_sort(schema, vec![sort_key.clone()])), sort_key, ); @@ -2177,12 +2175,13 @@ fn repartition_does_not_destroy_sort_more_complex() -> Result<()> { // Parquet(unsorted) let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input1 = sort_required_exec_with_req( - parquet_exec_with_sort(vec![sort_key.clone()]), + parquet_exec_with_sort(schema, vec![sort_key.clone()]), sort_key, ); let input2 = filter_exec(parquet_exec()); @@ -2211,20 +2210,21 @@ fn repartition_does_not_destroy_sort_more_complex() -> Result<()> { #[test] fn repartition_transitively_with_projection() -> Result<()> { let schema = schema(); - let proj_exprs = vec![( - Arc::new(BinaryExpr::new( - col("a", &schema).unwrap(), + let proj_exprs = vec![ProjectionExpr { + expr: Arc::new(BinaryExpr::new( + col("a", &schema)?, Operator::Plus, - col("b", &schema).unwrap(), - )) as Arc, - "sum".to_string(), - )]; + col("b", &schema)?, + )) as _, + alias: "sum".to_string(), + }]; // non sorted input let proj = Arc::new(ProjectionExec::try_new(proj_exprs, parquet_exec())?); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("sum", &proj.schema()).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("sum", &proj.schema())?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_preserving_merge_exec(sort_key, proj); // Test: run EnforceDistribution, then EnforceSort. @@ -2256,10 +2256,11 @@ fn repartition_transitively_with_projection() -> Result<()> { #[test] fn repartition_ignores_transitively_with_projection() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![ ("a".to_string(), "a".to_string()), ("b".to_string(), "b".to_string()), @@ -2291,10 +2292,11 @@ fn repartition_ignores_transitively_with_projection() -> Result<()> { #[test] fn repartition_transitively_past_sort_with_projection() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![ ("a".to_string(), "a".to_string()), ("b".to_string(), "b".to_string()), @@ -2302,10 +2304,9 @@ fn repartition_transitively_past_sort_with_projection() -> Result<()> { ]; let plan = sort_preserving_merge_exec( sort_key.clone(), - sort_exec( + sort_exec_with_preserve_partitioning( sort_key, projection_exec_with_alias(parquet_exec(), alias), - true, ), ); @@ -2326,11 +2327,12 @@ fn repartition_transitively_past_sort_with_projection() -> Result<()> { #[test] fn repartition_transitively_past_sort_with_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); - let plan = sort_exec(sort_key, filter_exec(parquet_exec()), false); + }] + .into(); + let plan = sort_exec(sort_key, filter_exec(parquet_exec())); // Test: run EnforceDistribution, then EnforceSort. let expected = &[ @@ -2362,10 +2364,11 @@ fn repartition_transitively_past_sort_with_filter() -> Result<()> { #[cfg(feature = "parquet")] fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_exec( sort_key, projection_exec_with_alias( @@ -2376,7 +2379,6 @@ fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> ("c".to_string(), "c".to_string()), ], ), - false, ); // Test: run EnforceDistribution, then EnforceSort. @@ -2447,10 +2449,11 @@ fn parallelization_single_partition() -> Result<()> { #[test] fn parallelization_multiple_files() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = filter_exec(parquet_exec_multiple_sorted(vec![sort_key.clone()])); let plan = sort_required_exec_with_req(plan, sort_key); @@ -2610,7 +2613,7 @@ fn parallelization_two_partitions_into_four() -> Result<()> { "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", " RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4", " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - // Multiple source files splitted across partitions + // Multiple source files split across partitions " DataSourceExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e], file_type=parquet", ]; test_config.run( @@ -2625,7 +2628,7 @@ fn parallelization_two_partitions_into_four() -> Result<()> { "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", " RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4", " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[]", - // Multiple source files splitted across partitions + // Multiple source files split across partitions " DataSourceExec: file_groups={4 groups: [[x:0..50], [x:50..100], [y:0..50], [y:50..100]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", ]; test_config.run(&expected_csv, plan_csv.clone(), &DISTRIB_DISTRIB_SORT)?; @@ -2637,12 +2640,13 @@ fn parallelization_two_partitions_into_four() -> Result<()> { #[test] fn parallelization_sorted_limit() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let plan_parquet = limit_exec(sort_exec(sort_key.clone(), parquet_exec(), false)); - let plan_csv = limit_exec(sort_exec(sort_key, csv_exec(), false)); + }] + .into(); + let plan_parquet = limit_exec(sort_exec(sort_key.clone(), parquet_exec())); + let plan_csv = limit_exec(sort_exec(sort_key, csv_exec())); let test_config = TestConfig::default(); @@ -2680,16 +2684,14 @@ fn parallelization_sorted_limit() -> Result<()> { #[test] fn parallelization_limit_with_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let plan_parquet = limit_exec(filter_exec(sort_exec( - sort_key.clone(), - parquet_exec(), - false, - ))); - let plan_csv = limit_exec(filter_exec(sort_exec(sort_key, csv_exec(), false))); + }] + .into(); + let plan_parquet = + limit_exec(filter_exec(sort_exec(sort_key.clone(), parquet_exec()))); + let plan_csv = limit_exec(filter_exec(sort_exec(sort_key, csv_exec()))); let test_config = TestConfig::default(); @@ -2834,14 +2836,15 @@ fn parallelization_union_inputs() -> Result<()> { #[test] fn parallelization_prior_to_sort_preserving_merge() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // sort preserving merge already sorted input, let plan_parquet = sort_preserving_merge_exec( sort_key.clone(), - parquet_exec_with_sort(vec![sort_key.clone()]), + parquet_exec_with_sort(schema, vec![sort_key.clone()]), ); let plan_csv = sort_preserving_merge_exec(sort_key.clone(), csv_exec_with_sort(vec![sort_key])); @@ -2875,13 +2878,17 @@ fn parallelization_prior_to_sort_preserving_merge() -> Result<()> { #[test] fn parallelization_sort_preserving_merge_with_union() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) let input_parquet = - union_exec(vec![parquet_exec_with_sort(vec![sort_key.clone()]); 2]); + union_exec(vec![ + parquet_exec_with_sort(schema, vec![sort_key.clone()]); + 2 + ]); let input_csv = union_exec(vec![csv_exec_with_sort(vec![sort_key.clone()]); 2]); let plan_parquet = sort_preserving_merge_exec(sort_key.clone(), input_parquet); let plan_csv = sort_preserving_merge_exec(sort_key, input_csv); @@ -2948,14 +2955,15 @@ fn parallelization_sort_preserving_merge_with_union() -> Result<()> { #[test] fn parallelization_does_not_benefit() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // SortRequired // Parquet(sorted) let plan_parquet = sort_required_exec_with_req( - parquet_exec_with_sort(vec![sort_key.clone()]), + parquet_exec_with_sort(schema, vec![sort_key.clone()]), sort_key.clone(), ); let plan_csv = @@ -2993,22 +3001,26 @@ fn parallelization_does_not_benefit() -> Result<()> { fn parallelization_ignores_transitively_with_projection_parquet() -> Result<()> { // sorted input let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); //Projection(a as a2, b as b2) let alias_pairs: Vec<(String, String)> = vec![ ("a".to_string(), "a2".to_string()), ("c".to_string(), "c2".to_string()), ]; - let proj_parquet = - projection_exec_with_alias(parquet_exec_with_sort(vec![sort_key]), alias_pairs); - let sort_key_after_projection = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c2", &proj_parquet.schema()).unwrap(), + let proj_parquet = projection_exec_with_alias( + parquet_exec_with_sort(schema, vec![sort_key]), + alias_pairs, + ); + let sort_key_after_projection = [PhysicalSortExpr { + expr: col("c2", &proj_parquet.schema())?, options: SortOptions::default(), - }]); + }] + .into(); let plan_parquet = sort_preserving_merge_exec(sort_key_after_projection, proj_parquet); let expected = &[ @@ -3039,10 +3051,11 @@ fn parallelization_ignores_transitively_with_projection_parquet() -> Result<()> fn parallelization_ignores_transitively_with_projection_csv() -> Result<()> { // sorted input let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); //Projection(a as a2, b as b2) let alias_pairs: Vec<(String, String)> = vec![ @@ -3052,10 +3065,11 @@ fn parallelization_ignores_transitively_with_projection_csv() -> Result<()> { let proj_csv = projection_exec_with_alias(csv_exec_with_sort(vec![sort_key]), alias_pairs); - let sort_key_after_projection = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c2", &proj_csv.schema()).unwrap(), + let sort_key_after_projection = [PhysicalSortExpr { + expr: col("c2", &proj_csv.schema())?, options: SortOptions::default(), - }]); + }] + .into(); let plan_csv = sort_preserving_merge_exec(sort_key_after_projection, proj_csv); let expected = &[ "SortPreservingMergeExec: [c2@1 ASC]", @@ -3108,10 +3122,11 @@ fn remove_redundant_roundrobins() -> Result<()> { #[test] fn remove_unnecessary_spm_after_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -3138,10 +3153,11 @@ fn remove_unnecessary_spm_after_filter() -> Result<()> { #[test] fn preserve_ordering_through_repartition() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("d", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("d", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -3163,10 +3179,11 @@ fn preserve_ordering_through_repartition() -> Result<()> { #[test] fn do_not_preserve_ordering_through_repartition() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -3202,10 +3219,11 @@ fn do_not_preserve_ordering_through_repartition() -> Result<()> { #[test] fn no_need_for_sort_after_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -3227,16 +3245,18 @@ fn no_need_for_sort_after_filter() -> Result<()> { #[test] fn do_not_preserve_ordering_through_repartition2() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key]); - let sort_req = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_req = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let physical_plan = sort_preserving_merge_exec(sort_req, filter_exec(input)); let test_config = TestConfig::default(); @@ -3272,10 +3292,11 @@ fn do_not_preserve_ordering_through_repartition2() -> Result<()> { #[test] fn do_not_preserve_ordering_through_repartition3() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key]); let physical_plan = filter_exec(input); @@ -3294,10 +3315,11 @@ fn do_not_preserve_ordering_through_repartition3() -> Result<()> { #[test] fn do_not_put_sort_when_input_is_invalid() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec(); let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); let expected = &[ @@ -3331,10 +3353,11 @@ fn do_not_put_sort_when_input_is_invalid() -> Result<()> { #[test] fn put_sort_when_input_is_valid() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); @@ -3368,12 +3391,13 @@ fn put_sort_when_input_is_valid() -> Result<()> { #[test] fn do_not_add_unnecessary_hash() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![("a".to_string(), "a".to_string())]; - let input = parquet_exec_with_sort(vec![sort_key]); + let input = parquet_exec_with_sort(schema, vec![sort_key]); let physical_plan = aggregate_exec_with_alias(input, alias); // TestConfig: @@ -3394,10 +3418,11 @@ fn do_not_add_unnecessary_hash() -> Result<()> { #[test] fn do_not_add_unnecessary_hash2() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![("a".to_string(), "a".to_string())]; let input = parquet_exec_multiple_sorted(vec![sort_key]); let aggregate = aggregate_exec_with_alias(input, alias.clone()); @@ -3471,3 +3496,140 @@ fn optimize_away_unnecessary_repartition2() -> Result<()> { Ok(()) } + +/// Ensures that `DataSourceExec` has been repartitioned into `target_partitions` file groups +#[tokio::test] +async fn test_distribute_sort_parquet() -> Result<()> { + let test_config: TestConfig = + TestConfig::default().with_prefer_repartition_file_scans(1000); + assert!( + test_config.config.optimizer.repartition_file_scans, + "should enable scans to be repartitioned" + ); + + let schema = schema(); + let sort_key = [PhysicalSortExpr::new_default(col("c", &schema)?)].into(); + let physical_plan = sort_exec(sort_key, parquet_exec_with_stats(10000 * 8192)); + + // prior to optimization, this is the starting plan + let starting = &[ + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + plans_matches_expected!(starting, physical_plan.clone()); + + // what the enforce distribution run does. + let expected = &[ + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", + " CoalescePartitionsExec", + " DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + test_config.run(expected, physical_plan.clone(), &[Run::Distribution])?; + + // what the sort parallelization (in enforce sorting), does after the enforce distribution changes + let expected = &[ + "SortPreservingMergeExec: [c@2 ASC]", + " SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", + " DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + test_config.run(expected, physical_plan, &[Run::Distribution, Run::Sorting])?; + Ok(()) +} + +/// Ensures that `DataSourceExec` has been repartitioned into `target_partitions` memtable groups +#[tokio::test] +async fn test_distribute_sort_memtable() -> Result<()> { + let test_config: TestConfig = + TestConfig::default().with_prefer_repartition_file_scans(1000); + assert!( + test_config.config.optimizer.repartition_file_scans, + "should enable scans to be repartitioned" + ); + + let mem_table = create_memtable()?; + let session_config = SessionConfig::new() + .with_repartition_file_min_size(1000) + .with_target_partitions(3); + let ctx = SessionContext::new_with_config(session_config); + ctx.register_table("users", Arc::new(mem_table))?; + + let dataframe = ctx.sql("SELECT * FROM users order by id;").await?; + let physical_plan = dataframe.create_physical_plan().await?; + + // this is the final, optimized plan + let expected = &[ + "SortPreservingMergeExec: [id@0 ASC NULLS LAST]", + " SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true]", + " DataSourceExec: partitions=3, partition_sizes=[34, 33, 33]", + ]; + plans_matches_expected!(expected, physical_plan); + + Ok(()) +} + +/// Create a [`MemTable`] with 100 batches of 8192 rows each, in 1 partition +fn create_memtable() -> Result { + let mut batches = Vec::with_capacity(100); + for _ in 0..100 { + batches.push(create_record_batch()?); + } + let partitions = vec![batches]; + MemTable::try_new(get_schema(), partitions) +} + +fn create_record_batch() -> Result { + let id_array = UInt8Array::from(vec![1; 8192]); + let account_array = UInt64Array::from(vec![9000; 8192]); + + Ok(RecordBatch::try_new( + get_schema(), + vec![Arc::new(id_array), Arc::new(account_array)], + ) + .unwrap()) +} + +fn get_schema() -> SchemaRef { + SchemaRef::new(Schema::new(vec![ + Field::new("id", DataType::UInt8, false), + Field::new("bank_account", DataType::UInt64, true), + ])) +} +#[test] +fn test_replace_order_preserving_variants_with_fetch() -> Result<()> { + // Create a base plan + let parquet_exec = parquet_exec(); + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("id", 0))); + + // Create a SortPreservingMergeExec with fetch=5 + let spm_exec = Arc::new( + SortPreservingMergeExec::new([sort_expr].into(), parquet_exec.clone()) + .with_fetch(Some(5)), + ); + + // Create distribution context + let dist_context = DistributionContext::new( + spm_exec, + true, + vec![DistributionContext::new(parquet_exec, false, vec![])], + ); + + // Apply the function + let result = replace_order_preserving_variants(dist_context)?; + + // Verify the plan was transformed to CoalescePartitionsExec + result + .plan + .as_any() + .downcast_ref::() + .expect("Expected CoalescePartitionsExec"); + + // Verify fetch was preserved + assert_eq!( + result.plan.fetch(), + Some(5), + "Fetch value was not preserved after transformation" + ); + + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index 4d2c875d3f1d4..a2c604a84e76f 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -17,130 +17,119 @@ use std::sync::Arc; +use crate::memory_limit::DummyStreamPartition; use crate::physical_optimizer::test_utils::{ - aggregate_exec, bounded_window_exec, check_integrity, coalesce_batches_exec, - coalesce_partitions_exec, create_test_schema, create_test_schema2, - create_test_schema3, filter_exec, global_limit_exec, hash_join_exec, limit_exec, - local_limit_exec, memory_exec, parquet_exec, repartition_exec, sort_exec, - sort_exec_with_fetch, sort_expr, sort_expr_options, sort_merge_join_exec, - sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch, - spr_repartition_exec, stream_exec_ordered, union_exec, RequirementsTestExec, + aggregate_exec, bounded_window_exec, bounded_window_exec_with_partition, + check_integrity, coalesce_batches_exec, coalesce_partitions_exec, create_test_schema, + create_test_schema2, create_test_schema3, filter_exec, global_limit_exec, + hash_join_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_with_sort, + projection_exec, repartition_exec, sort_exec, sort_exec_with_fetch, sort_expr, + sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, + sort_preserving_merge_exec_with_fetch, spr_repartition_exec, stream_exec_ordered, + union_exec, RequirementsTestExec, }; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TreeNode, TransformedResult}; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::{JoinType, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition}; +use datafusion_common::{Result, ScalarValue, TableReference}; +use datafusion_datasource::file_scan_config::FileScanConfigBuilder; +use datafusion_datasource::source::DataSourceExec; +use datafusion_expr_common::operator::Operator; +use datafusion_expr::{JoinType, SortExpr, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition}; use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_functions_aggregate::average::avg_udaf; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_expr::expressions::{col, Column, NotExpr}; -use datafusion_physical_expr::Partitioning; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, PhysicalSortExpr, PhysicalSortRequirement, OrderingRequirements +}; +use datafusion_physical_expr::{Distribution, Partitioning}; +use datafusion_physical_expr::expressions::{col, BinaryExpr, Column, NotExpr}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::windows::{create_window_expr, BoundedWindowAggExec, WindowAggExec}; use datafusion_physical_plan::{displayable, get_plan_string, ExecutionPlan, InputOrderMode}; -use datafusion::datasource::physical_plan::{CsvSource, ParquetSource}; +use datafusion::datasource::physical_plan::CsvSource; use datafusion::datasource::listing::PartitionedFile; use datafusion_physical_optimizer::enforce_sorting::{EnforceSorting, PlanWithCorrespondingCoalescePartitions, PlanWithCorrespondingSort, parallelize_sorts, ensure_sorting}; use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{replace_with_order_preserving_variants, OrderPreservationContext}; use datafusion_physical_optimizer::enforce_sorting::sort_pushdown::{SortPushDown, assign_initial_requirements, pushdown_sorts}; use datafusion_physical_optimizer::enforce_distribution::EnforceDistribution; +use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_functions_aggregate::average::avg_udaf; -use datafusion_functions_aggregate::count::count_udaf; -use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; - -use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_datasource::source::DataSourceExec; -use rstest::rstest; - -/// Create a csv exec for tests -fn csv_exec_ordered( - schema: &SchemaRef, - sort_exprs: impl IntoIterator, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), - Arc::new(CsvSource::new(true, 0, b'"')), - ) - .with_file(PartitionedFile::new("file_path".to_string(), 100)) - .with_output_ordering(vec![sort_exprs]) - .build(); - - DataSourceExec::from_data_source(config) -} - -/// Created a sorted parquet exec -pub fn parquet_exec_sorted( - schema: &SchemaRef, - sort_exprs: impl IntoIterator, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - - let source = Arc::new(ParquetSource::default()); - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), - source, - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_output_ordering(vec![sort_exprs]) - .build(); +use datafusion::prelude::*; +use arrow::array::{Int32Array, RecordBatch}; +use arrow::datatypes::{Field}; +use arrow_schema::Schema; +use datafusion_execution::TaskContext; +use datafusion_catalog::streaming::StreamingTable; - DataSourceExec::from_data_source(config) -} +use futures::StreamExt; +use insta::{assert_snapshot, Settings}; /// Create a sorted Csv exec fn csv_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - - let config = FileScanConfigBuilder::new( + let mut builder = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), schema.clone(), Arc::new(CsvSource::new(false, 0, 0)), ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_output_ordering(vec![sort_exprs]) - .build(); + .with_file(PartitionedFile::new("x".to_string(), 100)); + if let Some(ordering) = LexOrdering::new(sort_exprs) { + builder = builder.with_output_ordering(vec![ordering]); + } + let config = builder.build(); DataSourceExec::from_data_source(config) } /// Runs the sort enforcement optimizer and asserts the plan /// against the original and expected plans -/// -/// `$EXPECTED_PLAN_LINES`: input plan -/// `$EXPECTED_OPTIMIZED_PLAN_LINES`: optimized plan -/// `$PLAN`: the plan to optimized -/// `REPARTITION_SORTS`: Flag to set `config.options.optimizer.repartition_sorts` option. -/// -macro_rules! assert_optimized { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $REPARTITION_SORTS: expr) => { +struct EnforceSortingTest { + plan: Arc, + repartition_sorts: bool, +} + +impl EnforceSortingTest { + fn new(plan: Arc) -> Self { + Self { + plan, + repartition_sorts: false, + } + } + + /// Set whether to repartition sorts + fn with_repartition_sorts(mut self, repartition_sorts: bool) -> Self { + self.repartition_sorts = repartition_sorts; + self + } + + /// Runs the enforce sorting test and returns a string with the input and + /// optimized plan as strings for snapshot comparison using insta + fn run(&self) -> String { let mut config = ConfigOptions::new(); - config.optimizer.repartition_sorts = $REPARTITION_SORTS; + config.optimizer.repartition_sorts = self.repartition_sorts; // This file has 4 rules that use tree node, apply these rules as in the // EnforceSorting::optimize implementation // After these operations tree nodes should be in a consistent state. // This code block makes sure that these rules doesn't violate tree node integrity. { - let plan_requirements = PlanWithCorrespondingSort::new_default($PLAN.clone()); + let plan_requirements = + PlanWithCorrespondingSort::new_default(Arc::clone(&self.plan)); let adjusted = plan_requirements .transform_up(ensure_sorting) .data() - .and_then(check_integrity)?; + .and_then(check_integrity) + .expect("check_integrity failed after ensure_sorting"); // TODO: End state payloads will be checked here. let new_plan = if config.optimizer.repartition_sorts { @@ -149,60 +138,60 @@ macro_rules! assert_optimized { let parallel = plan_with_coalesce_partitions .transform_up(parallelize_sorts) .data() - .and_then(check_integrity)?; + .and_then(check_integrity) + .expect("check_integrity failed after parallelize_sorts"); // TODO: End state payloads will be checked here. parallel.plan } else { adjusted.plan }; - let plan_with_pipeline_fixer = OrderPreservationContext::new_default(new_plan); + let plan_with_pipeline_fixer = + OrderPreservationContext::new_default(new_plan); let updated_plan = plan_with_pipeline_fixer .transform_up(|plan_with_pipeline_fixer| { replace_with_order_preserving_variants( plan_with_pipeline_fixer, false, true, - &config, + &config, ) }) .data() - .and_then(check_integrity)?; + .and_then(check_integrity) + .expect( + "check_integrity failed after replace_with_order_preserving_variants", + ); // TODO: End state payloads will be checked here. let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan); assign_initial_requirements(&mut sort_pushdown); - check_integrity(pushdown_sorts(sort_pushdown)?)?; + check_integrity( + pushdown_sorts(sort_pushdown).expect("pushdown_sorts failed"), + ) + .expect("check_integrity failed after pushdown_sorts"); // TODO: End state payloads will be checked here. } - - let physical_plan = $PLAN; - let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - - let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES - .iter().map(|s| *s).collect(); - - assert_eq!( - expected_plan_lines, actual, - "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - let expected_optimized_lines: Vec<&str> = $EXPECTED_OPTIMIZED_PLAN_LINES - .iter().map(|s| *s).collect(); + let input_plan_string = displayable(self.plan.as_ref()).indent(true).to_string(); // Run the actual optimizer - let optimized_physical_plan = - EnforceSorting::new().optimize(physical_plan,&config)?; + let optimized_physical_plan = EnforceSorting::new() + .optimize(Arc::clone(&self.plan), &config) + .expect("enforce_sorting failed"); // Get string representation of the plan - let actual = get_plan_string(&optimized_physical_plan); - assert_eq!( - expected_optimized_lines, actual, - "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected_optimized_lines:#?}\nactual:\n\n{actual:#?}\n\n" - ); + let optimized_plan_string = displayable(optimized_physical_plan.as_ref()) + .indent(true) + .to_string(); - }; + if input_plan_string == optimized_plan_string { + format!("Input / Optimized Plan:\n{input_plan_string}",) + } else { + format!( + "Input Plan:\n{input_plan_string}\nOptimized Plan:\n{optimized_plan_string}", + ) + } + } } #[tokio::test] @@ -210,96 +199,97 @@ async fn test_remove_unnecessary_sort5() -> Result<()> { let left_schema = create_test_schema2()?; let right_schema = create_test_schema3()?; let left_input = memory_exec(&left_schema); - let parquet_sort_exprs = vec![sort_expr("a", &right_schema)]; - let right_input = parquet_exec_sorted(&right_schema, parquet_sort_exprs); - + let parquet_ordering = [sort_expr("a", &right_schema)].into(); + let right_input = + parquet_exec_with_sort(right_schema.clone(), vec![parquet_ordering]); let on = vec![( Arc::new(Column::new_with_schema("col_a", &left_schema)?) as _, Arc::new(Column::new_with_schema("c", &right_schema)?) as _, )]; let join = hash_join_exec(left_input, right_input, on, None, &JoinType::Inner)?; - let physical_plan = sort_exec(vec![sort_expr("a", &join.schema())], join); - - let expected_input = ["SortExec: expr=[a@2 ASC], preserve_partitioning=[false]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet"]; - - let expected_optimized = ["HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); - + let physical_plan = sort_exec([sort_expr("a", &join.schema())].into(), join); + + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[a@2 ASC], preserve_partitioning=[false] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)] + DataSourceExec: partitions=1, partition_sizes=[0] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + + Optimized Plan: + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)] + DataSourceExec: partitions=1, partition_sizes=[0] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + "); Ok(()) } #[tokio::test] async fn test_do_not_remove_sort_with_limit() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort = sort_exec(sort_exprs.clone(), source1); - let limit = limit_exec(sort); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs); - + ] + .into(); + let sort = sort_exec(ordering.clone(), source1); + let limit = local_limit_exec(sort, 100); + let parquet_ordering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering]); let union = union_exec(vec![source2, limit]); let repartition = repartition_exec(union); - let physical_plan = sort_preserving_merge_exec(sort_exprs, repartition); - - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - + let physical_plan = sort_preserving_merge_exec(ordering, repartition); + + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + LocalLimitExec: fetch=100 + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2 + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + LocalLimitExec: fetch=100 + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); // We should keep the bottom `SortExec`. - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); - Ok(()) } #[tokio::test] async fn test_union_inputs_sorted() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source1); - - let source2 = parquet_exec_sorted(&schema, sort_exprs.clone()); - + let source1 = parquet_exec(schema.clone()); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source1); + let source2 = parquet_exec_with_sort(schema, vec![ordering.clone()]); let union = union_exec(vec![source2, sort]); - let physical_plan = sort_preserving_merge_exec(sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(ordering, union); // one input to the union is already sorted, one is not. - let expected_input = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - ]; + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); // should not add a sort at the output of the union, input plan should not be changed - let expected_optimized = expected_input.clone(); - assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -307,31 +297,30 @@ async fn test_union_inputs_sorted() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source1); - - let parquet_sort_exprs = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source1); + let parquet_ordering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs); - + ] + .into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering]); let union = union_exec(vec![source2, sort]); - let physical_plan = sort_preserving_merge_exec(sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(ordering, union); // one input to the union is already sorted, one is not. - let expected_input = vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - ]; + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); // should not add a sort at the output of the union, input plan should not be changed - let expected_optimized = expected_input.clone(); - assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -339,36 +328,38 @@ async fn test_union_inputs_different_sorted() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted2() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![ + let source1 = parquet_exec(schema.clone()); + let sort_exprs: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; + ] + .into(); let sort = sort_exec(sort_exprs.clone(), source1); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs); - + let parquet_ordering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering]); let union = union_exec(vec![source2, sort]); let physical_plan = sort_preserving_merge_exec(sort_exprs, union); // Input is an invalid plan. In this case rule should add required sorting in appropriate places. // First DataSourceExec has output ordering(nullable_col@0 ASC). However, it doesn't satisfy the // required ordering of SortPreservingMergeExec. - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); Ok(()) } @@ -376,83 +367,89 @@ async fn test_union_inputs_different_sorted2() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted3() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - let sort2 = sort_exec(sort_exprs2, source1); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs.clone()); - + ] + .into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let ordering2 = [sort_expr("nullable_col", &schema)].into(); + let sort2 = sort_exec(ordering2, source1); + let parquet_ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering.clone()]); let union = union_exec(vec![sort1, source2, sort2]); - let physical_plan = sort_preserving_merge_exec(parquet_sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(parquet_ordering, union); // First input to the union is not Sorted (SortExec is finer than required ordering by the SortPreservingMergeExec above). // Second input to the union is already Sorted (matches with the required ordering by the SortPreservingMergeExec above). // Third input to the union is not Sorted (SortExec is matches required ordering by the SortPreservingMergeExec above). - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); // should adjust sorting in the first input of the union such that it is not unnecessarily fine - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); - Ok(()) } #[tokio::test] async fn test_union_inputs_different_sorted4() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs2.clone(), source1.clone()); - let sort2 = sort_exec(sort_exprs2.clone(), source1); - - let source2 = parquet_exec_sorted(&schema, sort_exprs2); - + ] + .into(); + let ordering2: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort1 = sort_exec(ordering2.clone(), source1.clone()); + let sort2 = sort_exec(ordering2.clone(), source1); + let source2 = parquet_exec_with_sort(schema, vec![ordering2]); let union = union_exec(vec![sort1, source2, sort2]); - let physical_plan = sort_preserving_merge_exec(sort_exprs1, union); + let physical_plan = sort_preserving_merge_exec(ordering1, union); // Ordering requirement of the `SortPreservingMergeExec` is not met. // Should modify the plan to ensure that all three inputs to the // `UnionExec` satisfy the ordering, OR add a single sort after // the `UnionExec` (both of which are equally good for this example). - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); Ok(()) } @@ -460,13 +457,13 @@ async fn test_union_inputs_different_sorted4() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted5() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![ + ] + .into(); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr_options( "non_nullable_col", @@ -476,30 +473,35 @@ async fn test_union_inputs_different_sorted5() -> Result<()> { nulls_first: false, }, ), - ]; - let sort_exprs3 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort2 = sort_exec(sort_exprs2, source1); - + ] + .into(); + let ordering3 = [sort_expr("nullable_col", &schema)].into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let sort2 = sort_exec(ordering2, source1); let union = union_exec(vec![sort1, sort2]); - let physical_plan = sort_preserving_merge_exec(sort_exprs3, union); + let physical_plan = sort_preserving_merge_exec(ordering3, union); // The `UnionExec` doesn't preserve any of the inputs ordering in the // example below. However, we should be able to change the unnecessarily // fine `SortExec`s below with required `SortExec`s that are absolutely necessary. - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); Ok(()) } @@ -507,22 +509,20 @@ async fn test_union_inputs_different_sorted5() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted6() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort_exprs2 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [sort_expr("nullable_col", &schema)].into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; + ] + .into(); let repartition = repartition_exec(source1); - let spm = sort_preserving_merge_exec(sort_exprs2, repartition); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs.clone()); - + let spm = sort_preserving_merge_exec(ordering2, repartition); + let parquet_ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering.clone()]); let union = union_exec(vec![sort1, source2, spm]); - let physical_plan = sort_preserving_merge_exec(parquet_sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(parquet_ordering, union); // The plan is not valid as it is -- the input ordering requirement // of the `SortPreservingMergeExec` under the third child of the @@ -530,25 +530,30 @@ async fn test_union_inputs_different_sorted6() -> Result<()> { // At the same time, this ordering requirement is unnecessarily fine. // The final plan should be valid AND the ordering of the third child // shouldn't be finer than necessary. - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); // Should adjust the requirement in the third input of the union so // that it is not unnecessarily fine. - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } @@ -556,34 +561,38 @@ async fn test_union_inputs_different_sorted6() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted7() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs3 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1.clone(), source1.clone()); - let sort2 = sort_exec(sort_exprs1, source1); - + ] + .into(); + let sort1 = sort_exec(ordering1.clone(), source1.clone()); + let sort2 = sort_exec(ordering1, source1); let union = union_exec(vec![sort1, sort2]); - let physical_plan = sort_preserving_merge_exec(sort_exprs3, union); + let ordering2 = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_preserving_merge_exec(ordering2, union); // Union has unnecessarily fine ordering below it. We should be able to replace them with absolutely necessary ordering. - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - // Union preserves the inputs ordering and we should not change any of the SortExecs under UnionExec - let expected_output = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_output, physical_plan, true); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); + // Union preserves the inputs ordering, and we should not change any of the SortExecs under UnionExec Ok(()) } @@ -591,13 +600,13 @@ async fn test_union_inputs_different_sorted7() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted8() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![ + ] + .into(); + let ordering2 = [ sort_expr_options( "nullable_col", &schema, @@ -614,75 +623,484 @@ async fn test_union_inputs_different_sorted8() -> Result<()> { nulls_first: false, }, ), - ]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort2 = sort_exec(sort_exprs2, source1); - + ] + .into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let sort2 = sort_exec(ordering2, source1); let physical_plan = union_exec(vec![sort1, sort2]); // The `UnionExec` doesn't preserve any of the inputs ordering in the // example below. - let expected_input = ["UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[nullable_col@0 DESC NULLS LAST, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[nullable_col@0 DESC NULLS LAST, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); // Since `UnionExec` doesn't preserve ordering in the plan above. // We shouldn't keep SortExecs in the plan. - let expected_optimized = ["UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } #[tokio::test] -async fn test_window_multi_path_sort() -> Result<()> { +async fn test_soft_hard_requirements_remove_soft_requirement() -> Result<()> { let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let sort_exprs = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(sort_exprs, source); + let partition_bys = &[col("nullable_col", &schema)?]; + let physical_plan = + bounded_window_exec_with_partition("nullable_col", vec![], partition_bys, sort); + + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "#); + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + Ok(()) +} - let sort_exprs1 = vec![ - sort_expr("nullable_col", &schema), - sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - // reverse sorting of sort_exprs2 - let sort_exprs3 = vec![sort_expr_options( +#[tokio::test] +async fn test_soft_hard_requirements_remove_soft_requirement_without_pushdowns( +) -> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source.clone()); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "count".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let bounded_window = + bounded_window_exec_with_partition("nullable_col", vec![], partition_bys, sort); + let physical_plan = projection_exec(proj_exprs, bounded_window)?; + + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as count] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as count] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "#); + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as count]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let physical_plan = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "#); + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + Ok(()) +} + +#[tokio::test] +async fn test_soft_hard_requirements_multiple_soft_requirements() -> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source.clone()); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let bounded_window = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + let physical_plan = bounded_window_exec_with_partition( + "count", + vec![], + partition_bys, + bounded_window, + ); + + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "#); + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let bounded_window = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + + let ordering2: LexOrdering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort2 = sort_exec(ordering2.clone(), bounded_window); + let sort3 = sort_exec(ordering2, sort2); + let physical_plan = + bounded_window_exec_with_partition("count", vec![], partition_bys, sort3); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "#); + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + Ok(()) +} + +#[tokio::test] +async fn test_soft_hard_requirements_multiple_sorts() -> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( "nullable_col", &schema, SortOptions { descending: true, nulls_first: false, }, + )] + .into(); + let sort = sort_exec(ordering, source); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), )]; - let source1 = parquet_exec_sorted(&schema, sort_exprs1); - let source2 = parquet_exec_sorted(&schema, sort_exprs2); - let sort1 = sort_exec(sort_exprs3.clone(), source1); - let sort2 = sort_exec(sort_exprs3.clone(), source2); + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let bounded_window = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + let ordering2: LexOrdering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort2 = sort_exec(ordering2.clone(), bounded_window); + let physical_plan = sort_exec(ordering2, sort2); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "#); + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + Ok(()) +} + +#[tokio::test] +async fn test_soft_hard_requirements_with_multiple_soft_requirements_and_output_requirement( +) -> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source); + let partition_bys1 = &[col("nullable_col", &schema)?]; + let bounded_window = + bounded_window_exec_with_partition("nullable_col", vec![], partition_bys1, sort); + let partition_bys2 = &[col("non_nullable_col", &schema)?]; + let bounded_window2 = bounded_window_exec_with_partition( + "non_nullable_col", + vec![], + partition_bys2, + bounded_window, + ); + let requirement = [PhysicalSortRequirement::new( + col("non_nullable_col", &schema)?, + Some(SortOptions::new(false, true)), + )] + .into(); + let physical_plan = Arc::new(OutputRequirementExec::new( + bounded_window2, + Some(OrderingRequirements::new(requirement)), + Distribution::SinglePartition, + None, + )); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + OutputRequirementExec: order_by=[(non_nullable_col@1, asc)], dist_by=SinglePartition + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + OutputRequirementExec: order_by=[(non_nullable_col@1, asc)], dist_by=SinglePartition + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "#); + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "OutputRequirementExec", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + // " SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + Ok(()) +} +#[tokio::test] +async fn test_window_multi_path_sort() -> Result<()> { + let schema = create_test_schema()?; + let ordering1 = [ + sort_expr("nullable_col", &schema), + sort_expr("non_nullable_col", &schema), + ] + .into(); + let ordering2 = [sort_expr("nullable_col", &schema)].into(); + // Reverse of the above + let ordering3: LexOrdering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let source1 = parquet_exec_with_sort(schema.clone(), vec![ordering1]); + let source2 = parquet_exec_with_sort(schema, vec![ordering2]); + let sort1 = sort_exec(ordering3.clone(), source1); + let sort2 = sort_exec(ordering3.clone(), source2); let union = union_exec(vec![sort1, sort2]); - let spm = sort_preserving_merge_exec(sort_exprs3.clone(), union); - let physical_plan = bounded_window_exec("nullable_col", sort_exprs3, spm); + let spm = sort_preserving_merge_exec(ordering3.clone(), union); + let physical_plan = bounded_window_exec("nullable_col", ordering3, spm); // The `WindowAggExec` gets its sorting from multiple children jointly. // During the removal of `SortExec`s, it should be able to remove the // corresponding SortExecs together. Also, the inputs of these `SortExec`s // are not necessarily the same to be able to remove them. - let expected_input = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortPreservingMergeExec: [nullable_col@0 DESC NULLS LAST]", - " UnionExec", - " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; - let expected_optimized = [ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortPreservingMergeExec: [nullable_col@0 DESC NULLS LAST] + UnionExec + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + "#); Ok(()) } @@ -690,36 +1108,40 @@ async fn test_window_multi_path_sort() -> Result<()> { #[tokio::test] async fn test_window_multi_path_sort2() -> Result<()> { let schema = create_test_schema()?; - - let sort_exprs1 = LexOrdering::new(vec![ + let ordering1: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]); - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - let source1 = parquet_exec_sorted(&schema, sort_exprs2.clone()); - let source2 = parquet_exec_sorted(&schema, sort_exprs2.clone()); - let sort1 = sort_exec(sort_exprs1.clone(), source1); - let sort2 = sort_exec(sort_exprs1.clone(), source2); - + ] + .into(); + let ordering2: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let source1 = parquet_exec_with_sort(schema.clone(), vec![ordering2.clone()]); + let source2 = parquet_exec_with_sort(schema, vec![ordering2.clone()]); + let sort1 = sort_exec(ordering1.clone(), source1); + let sort2 = sort_exec(ordering1.clone(), source2); let union = union_exec(vec![sort1, sort2]); - let spm = Arc::new(SortPreservingMergeExec::new(sort_exprs1, union)) as _; - let physical_plan = bounded_window_exec("nullable_col", sort_exprs2, spm); + let spm = Arc::new(SortPreservingMergeExec::new(ordering1, union)) as _; + let physical_plan = bounded_window_exec("nullable_col", ordering2, spm); // The `WindowAggExec` can get its required sorting from the leaf nodes directly. // The unnecessary SortExecs should be removed - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; - let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet + "#); Ok(()) } @@ -727,13 +1149,13 @@ async fn test_window_multi_path_sort2() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted_with_limit() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![ + ] + .into(); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr_options( "non_nullable_col", @@ -743,35 +1165,39 @@ async fn test_union_inputs_different_sorted_with_limit() -> Result<()> { nulls_first: false, }, ), - ]; - let sort_exprs3 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - - let sort2 = sort_exec(sort_exprs2, source1); - let limit = local_limit_exec(sort2); - let limit = global_limit_exec(limit); - + ] + .into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let sort2 = sort_exec(ordering2, source1); + let limit = local_limit_exec(sort2, 100); + let limit = global_limit_exec(limit, 0, Some(100)); let union = union_exec(vec![sort1, limit]); - let physical_plan = sort_preserving_merge_exec(sort_exprs3, union); + let ordering3 = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_preserving_merge_exec(ordering3, union); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); // Should not change the unnecessarily fine `SortExec`s because there is `LimitExec` - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " UnionExec", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + UnionExec + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + GlobalLimitExec: skip=0, fetch=100 + LocalLimitExec: fetch=100 + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); Ok(()) } @@ -781,15 +1207,17 @@ async fn test_sort_merge_join_order_by_left() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; - let left = parquet_exec(&left_schema); - let right = parquet_exec(&right_schema); + let left = parquet_exec(left_schema); + let right = parquet_exec(right_schema); // Join on (nullable_col == col_a) let join_on = vec![( - Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("nullable_col", &left.schema())?) as _, + Arc::new(Column::new_with_schema("col_a", &right.schema())?) as _, )]; + let settings = Settings::clone_current(); + let join_types = vec![ JoinType::Inner, JoinType::Left, @@ -801,49 +1229,69 @@ async fn test_sort_merge_join_order_by_left() -> Result<()> { for join_type in join_types { let join = sort_merge_join_exec(left.clone(), right.clone(), &join_on, &join_type); - let sort_exprs = vec![ + let ordering = [ sort_expr("nullable_col", &join.schema()), sort_expr("non_nullable_col", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs.clone(), join); + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering, join); - let join_plan = format!( - "SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" - ); - let join_plan2 = format!( - " SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + let mut settings = settings.clone(); + + settings.add_filter( + // join_type={} replace with join_type=... to avoid snapshot name issue + format!("join_type={join_type}").as_str(), + "join_type=...", ); - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - join_plan2.as_str(), - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - let expected_optimized = match join_type { + + insta::allow_duplicates! { + settings.bind( || { + + + match join_type { JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { // can push down the sort requirements and save 1 SortExec - vec![ - join_plan.as_str(), - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", - ] + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + + Optimized Plan: + SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + "); } _ => { // can not push down the sort requirements - vec![ - "SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - join_plan2.as_str(), - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", - ] + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + + Optimized Plan: + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + "); } }; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + }) + } } Ok(()) } @@ -853,15 +1301,17 @@ async fn test_sort_merge_join_order_by_right() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; - let left = parquet_exec(&left_schema); - let right = parquet_exec(&right_schema); + let left = parquet_exec(left_schema); + let right = parquet_exec(right_schema); // Join on (nullable_col == col_a) let join_on = vec![( - Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("nullable_col", &left.schema())?) as _, + Arc::new(Column::new_with_schema("col_a", &right.schema())?) as _, )]; + let settings = Settings::clone_current(); + let join_types = vec![ JoinType::Inner, JoinType::Left, @@ -872,50 +1322,83 @@ async fn test_sort_merge_join_order_by_right() -> Result<()> { for join_type in join_types { let join = sort_merge_join_exec(left.clone(), right.clone(), &join_on, &join_type); - let sort_exprs = vec![ + let ordering = [ sort_expr("col_a", &join.schema()), sort_expr("col_b", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs, join); + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering, join); - let join_plan = format!( - "SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" - ); - let spm_plan = match join_type { - JoinType::RightAnti => "SortPreservingMergeExec: [col_a@0 ASC, col_b@1 ASC]", - _ => "SortPreservingMergeExec: [col_a@2 ASC, col_b@3 ASC]", - }; - let join_plan2 = format!( - " SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + let mut settings = settings.clone(); + + settings.add_filter( + // join_type={} replace with join_type=... to avoid snapshot name issue + format!("join_type={join_type}").as_str(), + "join_type=...", ); - let expected_input = [spm_plan, - join_plan2.as_str(), - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - let expected_optimized = match join_type { - JoinType::Inner | JoinType::Right | JoinType::RightAnti => { + + insta::allow_duplicates! { + settings.bind( || { + + + match join_type { + JoinType::Inner | JoinType::Right => { + // can push down the sort requirements and save 1 SortExec + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [col_a@2 ASC, col_b@3 ASC] + SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + + Optimized Plan: + SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + "); + } + JoinType::RightAnti => { // can push down the sort requirements and save 1 SortExec - vec![ - join_plan.as_str(), - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", - ] + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [col_a@0 ASC, col_b@1 ASC] + SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + + Optimized Plan: + SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + "); } _ => { // can not push down the sort requirements for Left and Full join. - vec![ - "SortExec: expr=[col_a@2 ASC, col_b@3 ASC], preserve_partitioning=[false]", - join_plan2.as_str(), - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", - ] + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [col_a@2 ASC, col_b@3 ASC] + SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + + Optimized Plan: + SortExec: expr=[col_a@2 ASC, col_b@3 ASC], preserve_partitioning=[false] + SortMergeJoin: join_type=..., on=[(nullable_col@0, col_a@0)] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + "); } }; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + }) + } } Ok(()) } @@ -925,59 +1408,69 @@ async fn test_sort_merge_join_complex_order_by() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; - let left = parquet_exec(&left_schema); - let right = parquet_exec(&right_schema); + let left = parquet_exec(left_schema); + let right = parquet_exec(right_schema); // Join on (nullable_col == col_a) let join_on = vec![( - Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("nullable_col", &left.schema())?) as _, + Arc::new(Column::new_with_schema("col_a", &right.schema())?) as _, )]; let join = sort_merge_join_exec(left, right, &join_on, &JoinType::Inner); // order by (col_b, col_a) - let sort_exprs1 = vec![ + let ordering = [ sort_expr("col_b", &join.schema()), sort_expr("col_a", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs1, join.clone()); - - let expected_input = ["SortPreservingMergeExec: [col_b@3 ASC, col_a@2 ASC]", - " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering, join.clone()); + + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [col_b@3 ASC, col_a@2 ASC] + SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + + Optimized Plan: + SortExec: expr=[col_b@3 ASC, nullable_col@0 ASC], preserve_partitioning=[false] + SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + "); // can not push down the sort requirements, need to add SortExec - let expected_optimized = ["SortExec: expr=[col_b@3 ASC, col_a@2 ASC], preserve_partitioning=[false]", - " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); // order by (nullable_col, col_b, col_a) - let sort_exprs2 = vec![ + let ordering2 = [ sort_expr("nullable_col", &join.schema()), sort_expr("col_b", &join.schema()), sort_expr("col_a", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs2, join); - - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC]", - " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - - // can not push down the sort requirements, need to add SortExec - let expected_optimized = ["SortExec: expr=[nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC], preserve_partitioning=[false]", - " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering2, join); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC] + SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + + Optimized Plan: + SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet + "); + // Can push down the sort requirements since col_a = nullable_col Ok(()) } @@ -985,152 +1478,136 @@ async fn test_sort_merge_join_complex_order_by() -> Result<()> { #[tokio::test] async fn test_multilayer_coalesce_partitions() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); + let source1 = parquet_exec(schema.clone()); let repartition = repartition_exec(source1); - let coalesce = Arc::new(CoalescePartitionsExec::new(repartition)) as _; + let coalesce = coalesce_partitions_exec(repartition) as _; // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before let filter = filter_exec( - Arc::new(NotExpr::new( - col("non_nullable_col", schema.as_ref()).unwrap(), - )), + Arc::new(NotExpr::new(col("non_nullable_col", schema.as_ref())?)), coalesce, ); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let physical_plan = sort_exec(sort_exprs, filter); + let ordering = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_exec(ordering, filter); // CoalescePartitionsExec and SortExec are not directly consecutive. In this case // we should be able to parallelize Sorting also (given that executors in between don't require) // single partition. - let expected_input = ["SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " FilterExec: NOT non_nullable_col@1", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", - " FilterExec: NOT non_nullable_col@1", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + FilterExec: NOT non_nullable_col@1 + CoalescePartitionsExec + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true] + FilterExec: NOT non_nullable_col@1 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet + "); Ok(()) } -#[tokio::test] -async fn test_with_lost_ordering_bounded() -> Result<()> { +fn create_lost_ordering_plan(source_unbounded: bool) -> Result> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = csv_exec_sorted(&schema, sort_exprs); - let repartition_rr = repartition_exec(source); - let repartition_hash = Arc::new(RepartitionExec::try_new( - repartition_rr, - Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), - )?) as _; - let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); - - let expected_input = ["SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); - - Ok(()) -} - -#[rstest] -#[tokio::test] -async fn test_with_lost_ordering_unbounded_bounded( - #[values(false, true)] source_unbounded: bool, -) -> Result<()> { - let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let sort_exprs = [sort_expr("a", &schema)]; // create either bounded or unbounded source let source = if source_unbounded { - stream_exec_ordered(&schema, sort_exprs) + stream_exec_ordered(&schema, sort_exprs.clone().into()) } else { - csv_exec_ordered(&schema, sort_exprs) + csv_exec_sorted(&schema, sort_exprs.clone()) }; let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, - Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), + Partitioning::Hash(vec![col("c", &schema)?], 10), )?) as _; let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = vec![ - "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", - ]; - let expected_input_bounded = vec![ - "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=true", - ]; + let physical_plan = sort_exec(sort_exprs.into(), coalesce_partitions); + Ok(physical_plan) +} - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = vec![ - "SortPreservingMergeExec: [a@0 ASC]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", - ]; +#[tokio::test] +async fn test_with_lost_ordering_unbounded() -> Result<()> { + let physical_plan = create_lost_ordering_plan(true)?; + + let test_no_repartition_sorts = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(false); + + assert_snapshot!(test_no_repartition_sorts.run(), @r" + Input Plan: + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] + + Optimized Plan: + SortPreservingMergeExec: [a@0 ASC] + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] + "); + + let test_with_repartition_sorts = + EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test_with_repartition_sorts.run(), @r" + Input Plan: + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] + + Optimized Plan: + SortPreservingMergeExec: [a@0 ASC] + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] + "); - // Expected bounded results with and without flag - let expected_optimized_bounded = vec![ - "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=true", - ]; - let expected_optimized_bounded_parallelize_sort = vec![ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=true", - ]; - let (expected_input, expected_optimized, expected_optimized_sort_parallelize) = - if source_unbounded { - ( - expected_input_unbounded, - expected_optimized_unbounded.clone(), - expected_optimized_unbounded, - ) - } else { - ( - expected_input_bounded, - expected_optimized_bounded, - expected_optimized_bounded_parallelize_sort, - ) - }; - assert_optimized!( - expected_input, - expected_optimized, - physical_plan.clone(), - false - ); - assert_optimized!( - expected_input, - expected_optimized_sort_parallelize, - physical_plan, - true - ); + Ok(()) +} + +#[tokio::test] +async fn test_with_lost_ordering_bounded() -> Result<()> { + let physical_plan = create_lost_ordering_plan(false)?; + + let test_no_repartition_sorts = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(false); + + assert_snapshot!(test_no_repartition_sorts.run(), @r" + Input / Optimized Plan: + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false + "); + + let test_with_repartition_sorts = + EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test_with_repartition_sorts.run(), @r" + Input Plan: + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false + + Optimized Plan: + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false + "); Ok(()) } @@ -1138,21 +1615,21 @@ async fn test_with_lost_ordering_unbounded_bounded( #[tokio::test] async fn test_do_not_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let sort_exprs = [sort_expr("a", &schema), sort_expr("b", &schema)]; let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); - let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); - let physical_plan = sort_exec(vec![sort_expr("b", &schema)], spm); - - let expected_input = ["SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; - let expected_optimized = ["SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; - assert_optimized!(expected_input, expected_optimized, physical_plan, false); + let spm = sort_preserving_merge_exec(sort_exprs.into(), repartition_rr); + let physical_plan = sort_exec([sort_expr("b", &schema)].into(), spm); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + SortPreservingMergeExec: [a@0 ASC, b@1 ASC] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false + "); Ok(()) } @@ -1160,192 +1637,115 @@ async fn test_do_not_pushdown_through_spm() -> Result<()> { #[tokio::test] async fn test_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let sort_exprs = [sort_expr("a", &schema), sort_expr("b", &schema)]; let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); - let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); + let spm = sort_preserving_merge_exec(sort_exprs.into(), repartition_rr); let physical_plan = sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("b", &schema), sort_expr("c", &schema), - ], + ] + .into(), spm, ); - - let expected_input = ["SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", - " SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; - assert_optimized!(expected_input, expected_optimized, physical_plan, false); - + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false] + SortPreservingMergeExec: [a@0 ASC, b@1 ASC] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false + + Optimized Plan: + SortPreservingMergeExec: [a@0 ASC, b@1 ASC] + SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false + "); Ok(()) } #[tokio::test] async fn test_window_multi_layer_requirement() -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let sort_exprs = [sort_expr("a", &schema), sort_expr("b", &schema)]; let source = csv_exec_sorted(&schema, vec![]); - let sort = sort_exec(sort_exprs.clone(), source); + let sort = sort_exec(sort_exprs.clone().into(), source); let repartition = repartition_exec(sort); let repartition = spr_repartition_exec(repartition); - let spm = sort_preserving_merge_exec(sort_exprs.clone(), repartition); - + let spm = sort_preserving_merge_exec(sort_exprs.clone().into(), repartition); let physical_plan = bounded_window_exec("a", sort_exprs, spm); - let expected_input = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC, b@1 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - let expected_optimized = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, false); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortPreservingMergeExec: [a@0 ASC, b@1 ASC] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC, b@1 ASC + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortPreservingMergeExec: [a@0 ASC, b@1 ASC] + SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + "#); Ok(()) } #[tokio::test] async fn test_not_replaced_with_partial_sort_for_bounded_input() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let parquet_input = parquet_exec_sorted(&schema, input_sort_exprs); - + let parquet_ordering = [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let parquet_input = parquet_exec_with_sort(schema.clone(), vec![parquet_ordering]); let physical_plan = sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("b", &schema), sort_expr("c", &schema), - ], + ] + .into(), parquet_input, ); - let expected_input = [ - "SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[b@1 ASC, c@2 ASC], file_type=parquet" - ]; - let expected_no_change = expected_input; - assert_optimized!(expected_input, expected_no_change, physical_plan, false); - Ok(()) -} - -/// Runs the sort enforcement optimizer and asserts the plan -/// against the original and expected plans -/// -/// `$EXPECTED_PLAN_LINES`: input plan -/// `$EXPECTED_OPTIMIZED_PLAN_LINES`: optimized plan -/// `$PLAN`: the plan to optimized -/// `REPARTITION_SORTS`: Flag to set `config.options.optimizer.repartition_sorts` option. -/// `$CASE_NUMBER` (optional): The test case number to print on failure. -macro_rules! assert_optimized { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $REPARTITION_SORTS: expr $(, $CASE_NUMBER: expr)?) => { - let mut config = ConfigOptions::new(); - config.optimizer.repartition_sorts = $REPARTITION_SORTS; - - // This file has 4 rules that use tree node, apply these rules as in the - // EnforceSorting::optimize implementation - // After these operations tree nodes should be in a consistent state. - // This code block makes sure that these rules doesn't violate tree node integrity. - { - let plan_requirements = PlanWithCorrespondingSort::new_default($PLAN.clone()); - let adjusted = plan_requirements - .transform_up(ensure_sorting) - .data() - .and_then(check_integrity)?; - // TODO: End state payloads will be checked here. - - let new_plan = if config.optimizer.repartition_sorts { - let plan_with_coalesce_partitions = - PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan); - let parallel = plan_with_coalesce_partitions - .transform_up(parallelize_sorts) - .data() - .and_then(check_integrity)?; - // TODO: End state payloads will be checked here. - parallel.plan - } else { - adjusted.plan - }; + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(false); - let plan_with_pipeline_fixer = OrderPreservationContext::new_default(new_plan); - let updated_plan = plan_with_pipeline_fixer - .transform_up(|plan_with_pipeline_fixer| { - replace_with_order_preserving_variants( - plan_with_pipeline_fixer, - false, - true, - &config, - ) - }) - .data() - .and_then(check_integrity)?; - // TODO: End state payloads will be checked here. - - let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan); - assign_initial_requirements(&mut sort_pushdown); - check_integrity(pushdown_sorts(sort_pushdown)?)?; - // TODO: End state payloads will be checked here. - } - - let physical_plan = $PLAN; - let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - - let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES - .iter().map(|s| *s).collect(); - - if expected_plan_lines != actual { - $(println!("\n**Original Plan Mismatch in case {}**", $CASE_NUMBER);)? - println!("\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", expected_plan_lines, actual); - assert_eq!(expected_plan_lines, actual); - } - - let expected_optimized_lines: Vec<&str> = $EXPECTED_OPTIMIZED_PLAN_LINES - .iter().map(|s| *s).collect(); - - // Run the actual optimizer - let optimized_physical_plan = - EnforceSorting::new().optimize(physical_plan, &config)?; + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[b@1 ASC, c@2 ASC], file_type=parquet + "); - // Get string representation of the plan - let actual = get_plan_string(&optimized_physical_plan); - if expected_optimized_lines != actual { - $(println!("\n**Optimized Plan Mismatch in case {}**", $CASE_NUMBER);)? - println!("\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", expected_optimized_lines, actual); - assert_eq!(expected_optimized_lines, actual); - } - }; + Ok(()) } #[tokio::test] async fn test_remove_unnecessary_sort() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source); - let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], input); - - let expected_input = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let input = sort_exec([sort_expr("non_nullable_col", &schema)].into(), source); + let physical_plan = sort_exec([sort_expr("nullable_col", &schema)].into(), input); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1354,58 +1754,56 @@ async fn test_remove_unnecessary_sort() -> Result<()> { async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - - let sort_exprs = vec![sort_expr_options( + let ordering: LexOrdering = [sort_expr_options( "non_nullable_col", &source.schema(), SortOptions { descending: true, nulls_first: true, }, - )]; - let sort = sort_exec(sort_exprs.clone(), source); + )] + .into(); + let sort = sort_exec(ordering.clone(), source); // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before - let coalesce_batches = coalesce_batches_exec(sort); - - let window_agg = - bounded_window_exec("non_nullable_col", sort_exprs, coalesce_batches); - - let sort_exprs = vec![sort_expr_options( + let coalesce_batches = coalesce_batches_exec(sort, 128); + let window_agg = bounded_window_exec("non_nullable_col", ordering, coalesce_batches); + let ordering2: LexOrdering = [sort_expr_options( "non_nullable_col", &window_agg.schema(), SortOptions { descending: false, nulls_first: false, }, - )]; - - let sort = sort_exec(sort_exprs.clone(), window_agg); - + )] + .into(); + let sort = sort_exec(ordering2.clone(), window_agg); // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before let filter = filter_exec( - Arc::new(NotExpr::new( - col("non_nullable_col", schema.as_ref()).unwrap(), - )), + Arc::new(NotExpr::new(col("non_nullable_col", schema.as_ref())?)), sort, ); - - let physical_plan = bounded_window_exec("non_nullable_col", sort_exprs, filter); - - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " FilterExec: NOT non_nullable_col@1", - " SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " CoalesceBatchesExec: target_batch_size=128", - " SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - - let expected_optimized = ["WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " FilterExec: NOT non_nullable_col@1", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " CoalesceBatchesExec: target_batch_size=128", - " SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let physical_plan = bounded_window_exec("non_nullable_col", ordering2, filter); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + FilterExec: NOT non_nullable_col@1 + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + CoalesceBatchesExec: target_batch_size=128 + SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + FilterExec: NOT non_nullable_col@1 + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + CoalesceBatchesExec: target_batch_size=128 + SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "#); Ok(()) } @@ -1414,20 +1812,20 @@ async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { async fn test_add_required_sort() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); + let ordering = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_preserving_merge_exec(ordering, source); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + DataSourceExec: partitions=1, partition_sizes=[0] - let physical_plan = sort_preserving_merge_exec(sort_exprs, source); - - let expected_input = [ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Optimized Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1436,25 +1834,26 @@ async fn test_add_required_sort() -> Result<()> { async fn test_remove_unnecessary_sort1() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source); - let spm = sort_preserving_merge_exec(sort_exprs, sort); - - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), spm); - let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); - let expected_input = [ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source); + let spm = sort_preserving_merge_exec(ordering.clone(), sort); + let sort = sort_exec(ordering.clone(), spm); + let physical_plan = sort_preserving_merge_exec(ordering, sort); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + SortPreservingMergeExec: [nullable_col@0 ASC] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1463,38 +1862,38 @@ async fn test_remove_unnecessary_sort1() -> Result<()> { async fn test_remove_unnecessary_sort2() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr("non_nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source); - let spm = sort_preserving_merge_exec(sort_exprs, sort); - - let sort_exprs = vec![ + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source); + let spm = sort_preserving_merge_exec(ordering, sort); + let ordering2: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort2 = sort_exec(sort_exprs.clone(), spm); - let spm2 = sort_preserving_merge_exec(sort_exprs, sort2); - - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort3 = sort_exec(sort_exprs, spm2); + ] + .into(); + let sort2 = sort_exec(ordering2.clone(), spm); + let spm2 = sort_preserving_merge_exec(ordering2, sort2); + let ordering3 = [sort_expr("nullable_col", &schema)].into(); + let sort3 = sort_exec(ordering3, spm2); let physical_plan = repartition_exec(repartition_exec(sort3)); - let expected_input = [ - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - - let expected_optimized = [ - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + SortPreservingMergeExec: [non_nullable_col@1 ASC] + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1503,43 +1902,43 @@ async fn test_remove_unnecessary_sort2() -> Result<()> { async fn test_remove_unnecessary_sort3() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr("non_nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source); - let spm = sort_preserving_merge_exec(sort_exprs, sort); - - let sort_exprs = LexOrdering::new(vec![ + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source); + let spm = sort_preserving_merge_exec(ordering, sort); + let repartition_exec = repartition_exec(spm); + let ordering2: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]); - let repartition_exec = repartition_exec(spm); + ] + .into(); let sort2 = Arc::new( - SortExec::new(sort_exprs.clone(), repartition_exec) + SortExec::new(ordering2.clone(), repartition_exec) .with_preserve_partitioning(true), ) as _; - let spm2 = sort_preserving_merge_exec(sort_exprs, sort2); - + let spm2 = sort_preserving_merge_exec(ordering2, sort2); let physical_plan = aggregate_exec(spm2); // When removing a `SortPreservingMergeExec`, make sure that partitioning // requirements are not violated. In some cases, we may need to replace // it with a `CoalescePartitionsExec` instead of directly removing it. - let expected_input = [ - "AggregateExec: mode=Final, gby=[], aggr=[]", - " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - - let expected_optimized = [ - "AggregateExec: mode=Final, gby=[], aggr=[]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + AggregateExec: mode=Final, gby=[], aggr=[] + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + SortPreservingMergeExec: [non_nullable_col@1 ASC] + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + AggregateExec: mode=Final, gby=[], aggr=[] + CoalescePartitionsExec + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1548,52 +1947,51 @@ async fn test_remove_unnecessary_sort3() -> Result<()> { async fn test_remove_unnecessary_sort4() -> Result<()> { let schema = create_test_schema()?; let source1 = repartition_exec(memory_exec(&schema)); - let source2 = repartition_exec(memory_exec(&schema)); let union = union_exec(vec![source1, source2]); - - let sort_exprs = LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]); - // let sort = sort_exec(sort_exprs.clone(), union); - let sort = Arc::new( - SortExec::new(sort_exprs.clone(), union).with_preserve_partitioning(true), - ) as _; - let spm = sort_preserving_merge_exec(sort_exprs, sort); - + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let sort = + Arc::new(SortExec::new(ordering.clone(), union).with_preserve_partitioning(true)) + as _; + let spm = sort_preserving_merge_exec(ordering, sort); let filter = filter_exec( - Arc::new(NotExpr::new( - col("non_nullable_col", schema.as_ref()).unwrap(), - )), + Arc::new(NotExpr::new(col("non_nullable_col", schema.as_ref())?)), spm, ); - - let sort_exprs = vec![ + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let physical_plan = sort_exec(sort_exprs, filter); + ] + .into(); + let physical_plan = sort_exec(ordering2, filter); // When removing a `SortPreservingMergeExec`, make sure that partitioning // requirements are not violated. In some cases, we may need to replace // it with a `CoalescePartitionsExec` instead of directly removing it. - let expected_input = ["SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " FilterExec: NOT non_nullable_col@1", - " SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[true]", - " UnionExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", - " FilterExec: NOT non_nullable_col@1", - " UnionExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + FilterExec: NOT non_nullable_col@1 + SortPreservingMergeExec: [non_nullable_col@1 ASC] + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[true] + UnionExec + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true] + FilterExec: NOT non_nullable_col@1 + UnionExec + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1602,31 +2000,31 @@ async fn test_remove_unnecessary_sort4() -> Result<()> { async fn test_remove_unnecessary_sort6() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = Arc::new( - SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - source, - ) - .with_fetch(Some(2)), + let input = sort_exec_with_fetch( + [sort_expr("non_nullable_col", &schema)].into(), + Some(2), + source, ); let physical_plan = sort_exec( - vec![ + [ sort_expr("non_nullable_col", &schema), sort_expr("nullable_col", &schema), - ], + ] + .into(), input, ); - - let expected_input = [ - "SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", - " SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false] + SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1635,33 +2033,33 @@ async fn test_remove_unnecessary_sort6() -> Result<()> { async fn test_remove_unnecessary_sort7() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = Arc::new(SortExec::new( - LexOrdering::new(vec![ + let input = sort_exec( + [ sort_expr("non_nullable_col", &schema), sort_expr("nullable_col", &schema), - ]), + ] + .into(), source, - )); + ); + let physical_plan = sort_exec_with_fetch( + [sort_expr("non_nullable_col", &schema)].into(), + Some(2), + input, + ); - let physical_plan = Arc::new( - SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - input, - ) - .with_fetch(Some(2)), - ) as Arc; + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false], sort_prefix=[non_nullable_col@1 ASC] + SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] - let expected_input = [ - "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "GlobalLimitExec: skip=0, fetch=2", - " SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Optimized Plan: + GlobalLimitExec: skip=0, fetch=2 + SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1670,31 +2068,31 @@ async fn test_remove_unnecessary_sort7() -> Result<()> { async fn test_remove_unnecessary_sort8() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = Arc::new(SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - source, - )); + let input = sort_exec([sort_expr("non_nullable_col", &schema)].into(), source); let limit = Arc::new(LocalLimitExec::new(input, 2)); let physical_plan = sort_exec( - vec![ + [ sort_expr("non_nullable_col", &schema), sort_expr("nullable_col", &schema), - ], + ] + .into(), limit, ); - let expected_input = [ - "SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", - " LocalLimitExec: fetch=2", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "LocalLimitExec: fetch=2", - " SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false] + LocalLimitExec: fetch=2 + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + LocalLimitExec: fetch=2 + SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1703,27 +2101,19 @@ async fn test_remove_unnecessary_sort8() -> Result<()> { async fn test_do_not_pushdown_through_limit() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - // let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source); - let input = Arc::new(SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - source, - )); + let input = sort_exec([sort_expr("non_nullable_col", &schema)].into(), source); let limit = Arc::new(GlobalLimitExec::new(input, 0, Some(5))) as _; - let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], limit); - - let expected_input = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " GlobalLimitExec: skip=0, fetch=5", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " GlobalLimitExec: skip=0, fetch=5", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let physical_plan = sort_exec([sort_expr("nullable_col", &schema)].into(), limit); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + GlobalLimitExec: skip=0, fetch=5 + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1732,24 +2122,25 @@ async fn test_do_not_pushdown_through_limit() -> Result<()> { async fn test_remove_unnecessary_spm1() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = - sort_preserving_merge_exec(vec![sort_expr("non_nullable_col", &schema)], source); - let input2 = - sort_preserving_merge_exec(vec![sort_expr("non_nullable_col", &schema)], input); + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let input = sort_preserving_merge_exec(ordering.clone(), source); + let input2 = sort_preserving_merge_exec(ordering, input); let physical_plan = - sort_preserving_merge_exec(vec![sort_expr("nullable_col", &schema)], input2); - - let expected_input = [ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + sort_preserving_merge_exec([sort_expr("nullable_col", &schema)].into(), input2); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + SortPreservingMergeExec: [non_nullable_col@1 ASC] + SortPreservingMergeExec: [non_nullable_col@1 ASC] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1759,21 +2150,22 @@ async fn test_remove_unnecessary_spm2() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); let input = sort_preserving_merge_exec_with_fetch( - vec![sort_expr("non_nullable_col", &schema)], + [sort_expr("non_nullable_col", &schema)].into(), source, 100, ); - let expected_input = [ - "SortPreservingMergeExec: [non_nullable_col@1 ASC], fetch=100", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "LocalLimitExec: fetch=100", - " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, input, true); + let test = EnforceSortingTest::new(input.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [non_nullable_col@1 ASC], fetch=100 + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + LocalLimitExec: fetch=100 + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1782,22 +2174,25 @@ async fn test_remove_unnecessary_spm2() -> Result<()> { async fn test_change_wrong_sorting() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![ + let sort_exprs = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), ]; - let sort = sort_exec(vec![sort_exprs[0].clone()], source); - let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); - let expected_input = [ - "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let sort = sort_exec([sort_exprs[0].clone()].into(), source); + let physical_plan = sort_preserving_merge_exec(sort_exprs.into(), sort); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1806,25 +2201,26 @@ async fn test_change_wrong_sorting() -> Result<()> { async fn test_change_wrong_sorting2() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![ + let sort_exprs = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), ]; - let spm1 = sort_preserving_merge_exec(sort_exprs.clone(), source); - let sort2 = sort_exec(vec![sort_exprs[0].clone()], spm1); - let physical_plan = sort_preserving_merge_exec(vec![sort_exprs[1].clone()], sort2); - - let expected_input = [ - "SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let spm1 = sort_preserving_merge_exec(sort_exprs.clone().into(), source); + let sort2 = sort_exec([sort_exprs[0].clone()].into(), spm1); + let physical_plan = sort_preserving_merge_exec([sort_exprs[1].clone()].into(), sort2); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortPreservingMergeExec: [non_nullable_col@1 ASC] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1833,32 +2229,34 @@ async fn test_change_wrong_sorting2() -> Result<()> { async fn test_multiple_sort_window_exec() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - - let sort_exprs1 = vec![sort_expr("nullable_col", &schema)]; - let sort_exprs2 = vec![ + let ordering1 = [sort_expr("nullable_col", &schema)]; + let sort1 = sort_exec(ordering1.clone().into(), source); + let window_agg1 = bounded_window_exec("non_nullable_col", ordering1.clone(), sort1); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), ]; - - let sort1 = sort_exec(sort_exprs1.clone(), source); - let window_agg1 = bounded_window_exec("non_nullable_col", sort_exprs1.clone(), sort1); - let window_agg2 = bounded_window_exec("non_nullable_col", sort_exprs2, window_agg1); - // let filter_exec = sort_exec; - let physical_plan = bounded_window_exec("non_nullable_col", sort_exprs1, window_agg2); - - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - - let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let window_agg2 = bounded_window_exec("non_nullable_col", ordering2, window_agg1); + let physical_plan = bounded_window_exec("non_nullable_col", ordering1, window_agg2); + + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r#" + Input Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "#); Ok(()) } @@ -1871,47 +2269,38 @@ async fn test_multiple_sort_window_exec() -> Result<()> { // EnforceDistribution may invalidate ordering invariant. async fn test_commutativity() -> Result<()> { let schema = create_test_schema()?; - let config = ConfigOptions::new(); - let memory_exec = memory_exec(&schema); - let sort_exprs = LexOrdering::new(vec![sort_expr("nullable_col", &schema)]); + let sort_exprs = [sort_expr("nullable_col", &schema)]; let window = bounded_window_exec("nullable_col", sort_exprs.clone(), memory_exec); let repartition = repartition_exec(window); + let orig_plan = sort_exec(sort_exprs.into(), repartition); - let orig_plan = - Arc::new(SortExec::new(sort_exprs, repartition)) as Arc; - let actual = get_plan_string(&orig_plan); - let expected_input = vec![ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_eq!( - expected_input, actual, - "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_input:#?}\nactual:\n\n{actual:#?}\n\n" - ); + assert_snapshot!(displayable(orig_plan.as_ref()).indent(true), @r#" + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: partitions=1, partition_sizes=[0] + "#); - let mut plan = orig_plan.clone(); + let config = ConfigOptions::new(); let rules = vec![ Arc::new(EnforceDistribution::new()) as Arc, Arc::new(EnforceSorting::new()) as Arc, ]; + let mut first_plan = orig_plan.clone(); for rule in rules { - plan = rule.optimize(plan, &config)?; + first_plan = rule.optimize(first_plan, &config)?; } - let first_plan = plan.clone(); - let mut plan = orig_plan.clone(); let rules = vec![ Arc::new(EnforceSorting::new()) as Arc, Arc::new(EnforceDistribution::new()) as Arc, Arc::new(EnforceSorting::new()) as Arc, ]; + let mut second_plan = orig_plan.clone(); for rule in rules { - plan = rule.optimize(plan, &config)?; + second_plan = rule.optimize(second_plan, &config)?; } - let second_plan = plan.clone(); assert_eq!(get_plan_string(&first_plan), get_plan_string(&second_plan)); Ok(()) @@ -1922,35 +2311,37 @@ async fn test_coalesce_propagate() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); let repartition = repartition_exec(source); - let coalesce_partitions = Arc::new(CoalescePartitionsExec::new(repartition)); + let coalesce_partitions = coalesce_partitions_exec(repartition); let repartition = repartition_exec(coalesce_partitions); - let sort_exprs = LexOrdering::new(vec![sort_expr("nullable_col", &schema)]); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); // Add local sort let sort = Arc::new( - SortExec::new(sort_exprs.clone(), repartition).with_preserve_partitioning(true), + SortExec::new(ordering.clone(), repartition).with_preserve_partitioning(true), ) as _; - let spm = sort_preserving_merge_exec(sort_exprs.clone(), sort); - let sort = sort_exec(sort_exprs, spm); + let spm = sort_preserving_merge_exec(ordering.clone(), sort); + let sort = sort_exec(ordering, spm); let physical_plan = sort.clone(); // Sort Parallelize rule should end Coalesce + Sort linkage when Sort is Global Sort // Also input plan is not valid as it is. We need to add SortExec before SortPreservingMergeExec. - let expected_input = [ - "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - let expected_optimized = [ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false] + SortPreservingMergeExec: [nullable_col@0 ASC] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + CoalescePartitionsExec + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + SortPreservingMergeExec: [nullable_col@0 ASC] + SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } @@ -1958,153 +2349,226 @@ async fn test_coalesce_propagate() -> Result<()> { #[tokio::test] async fn test_replace_with_partial_sort2() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("a", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs); - + let input_ordering = [sort_expr("a", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering); let physical_plan = sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("c", &schema), sort_expr("d", &schema), - ], + ] + .into(), unbounded_input, ); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[a@0 ASC, c@2 ASC, d@3 ASC], preserve_partitioning=[false] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC, c@2 ASC] + + Optimized Plan: + PartialSortExec: expr=[a@0 ASC, c@2 ASC, d@3 ASC], common_prefix_length=[2] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC, c@2 ASC] + "); - let expected_input = [ - "SortExec: expr=[a@0 ASC, c@2 ASC, d@3 ASC], preserve_partitioning=[false]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC, c@2 ASC]" - ]; - // let optimized - let expected_optimized = [ - "PartialSortExec: expr=[a@0 ASC, c@2 ASC, d@3 ASC], common_prefix_length=[2]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC, c@2 ASC]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } #[tokio::test] async fn test_push_with_required_input_ordering_prohibited() -> Result<()> { - // SortExec: expr=[b] <-- can't push this down - // RequiredInputOrder expr=[a] <-- this requires input sorted by a, and preserves the input order - // SortExec: expr=[a] - // DataSourceExec let schema = create_test_schema3()?; - let sort_exprs_a = LexOrdering::new(vec![sort_expr("a", &schema)]); - let sort_exprs_b = LexOrdering::new(vec![sort_expr("b", &schema)]); + let ordering_a: LexOrdering = [sort_expr("a", &schema)].into(); + let ordering_b: LexOrdering = [sort_expr("b", &schema)].into(); let plan = memory_exec(&schema); - let plan = sort_exec(sort_exprs_a.clone(), plan); + let plan = sort_exec(ordering_a.clone(), plan); let plan = RequirementsTestExec::new(plan) - .with_required_input_ordering(sort_exprs_a) + .with_required_input_ordering(Some(ordering_a)) .with_maintains_input_order(true) .into_arc(); - let plan = sort_exec(sort_exprs_b, plan); - - let expected_input = [ - "SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", - " RequiredInputOrderingExec", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; + let plan = sort_exec(ordering_b, plan); + let test = EnforceSortingTest::new(plan.clone()).with_repartition_sorts(true); // should not be able to push shorts - let expected_no_change = expected_input; - assert_optimized!(expected_input, expected_no_change, plan, true); + + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortExec: expr=[b@1 ASC], preserve_partitioning=[false] + RequiredInputOrderingExec + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); Ok(()) } // test when the required input ordering is satisfied so could push through #[tokio::test] async fn test_push_with_required_input_ordering_allowed() -> Result<()> { - // SortExec: expr=[a,b] <-- can push this down (as it is compatible with the required input ordering) - // RequiredInputOrder expr=[a] <-- this requires input sorted by a, and preserves the input order - // SortExec: expr=[a] - // DataSourceExec let schema = create_test_schema3()?; - let sort_exprs_a = LexOrdering::new(vec![sort_expr("a", &schema)]); - let sort_exprs_ab = - LexOrdering::new(vec![sort_expr("a", &schema), sort_expr("b", &schema)]); + let ordering_a: LexOrdering = [sort_expr("a", &schema)].into(); + let ordering_ab = [sort_expr("a", &schema), sort_expr("b", &schema)].into(); let plan = memory_exec(&schema); - let plan = sort_exec(sort_exprs_a.clone(), plan); + let plan = sort_exec(ordering_a.clone(), plan); let plan = RequirementsTestExec::new(plan) - .with_required_input_ordering(sort_exprs_a) + .with_required_input_ordering(Some(ordering_a)) .with_maintains_input_order(true) .into_arc(); - let plan = sort_exec(sort_exprs_ab, plan); + let plan = sort_exec(ordering_ab, plan); + /* let expected_input = [ - "SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", - " RequiredInputOrderingExec", + "SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", // <-- can push this down (as it is compatible with the required input ordering) + " RequiredInputOrderingExec", // <-- this requires input sorted by a, and preserves the input order " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " DataSourceExec: partitions=1, partition_sizes=[0]", ]; - // should able to push shorts - let expected = [ - "RequiredInputOrderingExec", - " SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ]; - assert_optimized!(expected_input, expected, plan, true); + */ + let test = EnforceSortingTest::new(plan.clone()).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false] + RequiredInputOrderingExec + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + RequiredInputOrderingExec + SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "); + // Should be able to push down Ok(()) } #[tokio::test] async fn test_replace_with_partial_sort() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("a", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs); - + let input_ordering = [sort_expr("a", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering); let physical_plan = sort_exec( - vec![sort_expr("a", &schema), sort_expr("c", &schema)], + [sort_expr("a", &schema), sort_expr("c", &schema)].into(), unbounded_input, ); - let expected_input = [ - "SortExec: expr=[a@0 ASC, c@2 ASC], preserve_partitioning=[false]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]" - ]; - let expected_optimized = [ - "PartialSortExec: expr=[a@0 ASC, c@2 ASC], common_prefix_length=[1]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", - ]; - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[a@0 ASC, c@2 ASC], preserve_partitioning=[false] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] + + Optimized Plan: + PartialSortExec: expr=[a@0 ASC, c@2 ASC], common_prefix_length=[1] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC] + "); Ok(()) } #[tokio::test] async fn test_not_replaced_with_partial_sort_for_unbounded_input() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs); - + let input_ordering = [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering); let physical_plan = sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("b", &schema), sort_expr("c", &schema), - ], + ] + .into(), unbounded_input, ); - let expected_input = [ - "SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" - ]; - let expected_no_change = expected_input; - assert_optimized!(expected_input, expected_no_change, physical_plan, true); + let test = + EnforceSortingTest::new(physical_plan.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC] + "); Ok(()) } +// aal here #[tokio::test] async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { let input_schema = create_test_schema()?; - let sort_exprs = vec![sort_expr_options( + let ordering = [sort_expr_options( "nullable_col", &input_schema, SortOptions { descending: false, nulls_first: false, }, - )]; - let source = parquet_exec_sorted(&input_schema, sort_exprs); + )] + .into(); + let source = parquet_exec_with_sort(input_schema.clone(), vec![ordering]) as _; + + // Macro for testing window function optimization with snapshots + macro_rules! test_window_case { + ( + partition_by: $partition_by:expr, + window_frame: $window_frame:expr, + func: ($func_def:expr, $func_name:expr, $func_args:expr), + required_sort: [$($col:expr, $asc:expr, $nulls_first:expr),*], + @ $expected:literal + ) => {{ + let partition_by_exprs = if $partition_by { + vec![col("nullable_col", &input_schema)?] + } else { + vec![] + }; + + let window_expr = create_window_expr( + &$func_def, + $func_name, + &$func_args, + &partition_by_exprs, + &[], + $window_frame, + Arc::clone(&input_schema), + false, + false, + None, + )?; + + let window_exec = if window_expr.uses_bounded_memory() { + Arc::new(BoundedWindowAggExec::try_new( + vec![window_expr], + Arc::clone(&source), + InputOrderMode::Sorted, + $partition_by, + )?) as Arc + } else { + Arc::new(WindowAggExec::try_new( + vec![window_expr], + Arc::clone(&source), + $partition_by, + )?) as Arc + }; + + let output_schema = window_exec.schema(); + let sort_expr = vec![ + $( + sort_expr_options( + $col, + &output_schema, + SortOptions { + descending: !$asc, + nulls_first: $nulls_first, + }, + ) + ),* + ]; + let ordering = LexOrdering::new(sort_expr).unwrap(); + let physical_plan = sort_exec(ordering, window_exec); + + let test = EnforceSortingTest::new(physical_plan).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @ $expected); + + Result::<(), datafusion_common::DataFusionError>::Ok(()) + }}; + } // Function definition - Alias of the resulting column - Arguments of the function #[derive(Clone)] @@ -2151,1232 +2615,1097 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { "avg".to_string(), function_arg_unordered, ); - struct TestCase<'a> { - // Whether window expression has a partition_by expression or not. - // If it does, it will be on the ordered column -- `nullable_col`. - partition_by: bool, - // Whether the frame is unbounded in both directions, or unbounded in - // only one direction (when set-monotonicity has a meaning), or it is - // a sliding window. - window_frame: Arc, - // Function definition - Alias of the resulting column - Arguments of the function - func: WindowFuncParam, - // Global sort requirement at the root and its direction, - // which is required to be removed or preserved -- (asc, nulls_first) - required_sort_columns: Vec<(&'a str, bool, bool)>, - initial_plan: Vec<&'a str>, - expected_plan: Vec<&'a str>, - } - let test_cases = vec![ - // ============================================REGION STARTS============================================ - // WindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on ordered column - // Case 0: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 1: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 2: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("min", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 3: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_avg_on_ordered.clone(), - required_sort_columns: vec![("avg", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on unordered column - // Case 4: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![("non_nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 5: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("non_nullable_col", false, false), ("max", false, false)], - initial_plan: vec![ - "SortExec: expr=[non_nullable_col@1 DESC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 6: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("min", true, false), ("non_nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 ASC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 7: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("avg", false, false), ("nullable_col", false, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on ordered column - // Case 8: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 9: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 10: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("min", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 11: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_avg_on_ordered.clone(), - required_sort_columns: vec![("avg", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on unordered column - // Case 12: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![("non_nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 13: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("non_nullable_col", true, false), ("max", false, false)], - initial_plan: vec![ - "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 14: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("min", false, false), ("non_nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 15: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(None)), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("avg", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Sliding(current row, unbounded following) + no partition_by + on ordered column - // Case 16: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 17: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("max", false, true), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[max@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 18: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("min", true, true), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 ASC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 19: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_avg_on_ordered.clone(), - required_sort_columns: vec![("avg", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Sliding(current row, unbounded following) + no partition_by + on unordered column - // Case 20: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 21: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", false, true)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 22: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("min", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 23: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("avg", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Sliding(current row, unbounded following) + partition_by + on ordered column - // Case 24: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 25: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 26: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("min", false, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 27: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_avg_on_ordered.clone(), - required_sort_columns: vec![("avg", false, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // WindowAggExec + Sliding(current row, unbounded following) + partition_by + on unordered column - // Case 28: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![("count", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[count@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[count@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - }, - // Case 29: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", false, true)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 30: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("min", false, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 31: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("avg", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on ordered column - // Case 32: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 33: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("max", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[max@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[max@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - }, - // Case 34: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("min", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 35: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_avg_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("avg", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on unordered column - // Case 36: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", true, true)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 37: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("max", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 38: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("min", false, true), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 39: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("avg", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on ordered column - // Case 40: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 41: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("max", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - expected_plan: vec![ - "SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - }, - // Case 42: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("min", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 43: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_avg_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("avg", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on unordered column - // Case 44: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![ ("count", true, true)], - initial_plan: vec![ - "SortExec: expr=[count@2 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[count@2 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], - }, - // Case 45: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 46: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("min", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 47: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new(Some(true))), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + no partition_by + on ordered column - // Case 48: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("count", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 49: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("max", true, false)], - initial_plan: vec![ - "SortExec: expr=[max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 50: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("min", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 51: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_avg_on_ordered.clone(), - required_sort_columns: vec![("avg", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + no partition_by + on unordered column - // Case 52: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![("count", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" - ], - }, - // Case 53: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 54: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("min", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 55: - TestCase { - partition_by: false, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + partition_by + on ordered column - // Case 56: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_count_on_ordered.clone(), - required_sort_columns: vec![("count", true, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 57: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), - func: fn_max_on_ordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: Following(UInt32(1)), is_causal: false }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 58: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_min_on_ordered.clone(), - required_sort_columns: vec![("min", false, false), ("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 59: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_avg_on_ordered.clone(), - required_sort_columns: vec![("avg", true, false)], - initial_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = - // ============================================REGION STARTS============================================ - // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + partition_by + on unordered column - // Case 60: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_count_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 61: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_max_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("max", true, true)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 62: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_min_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false), ("min", false, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // Case 63: - TestCase { - partition_by: true, - window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), - func: fn_avg_on_unordered.clone(), - required_sort_columns: vec![("nullable_col", true, false)], - initial_plan: vec![ - "SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - expected_plan: vec![ - "BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt32(1)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", - ], - }, - // =============================================REGION ENDS============================================= - ]; - for (case_idx, case) in test_cases.into_iter().enumerate() { - let partition_by = if case.partition_by { - vec![col("nullable_col", &input_schema)?] - } else { - vec![] - }; - let window_expr = create_window_expr( - &case.func.0, - case.func.1, - &case.func.2, - &partition_by, - &LexOrdering::default(), - case.window_frame, - input_schema.as_ref(), - false, - )?; - let window_exec = if window_expr.uses_bounded_memory() { - Arc::new(BoundedWindowAggExec::try_new( - vec![window_expr], - Arc::clone(&source), - InputOrderMode::Sorted, - case.partition_by, - )?) as Arc - } else { - Arc::new(WindowAggExec::try_new( - vec![window_expr], - Arc::clone(&source), - case.partition_by, - )?) as _ - }; - let output_schema = window_exec.schema(); - let sort_expr = case - .required_sort_columns - .iter() - .map(|(col_name, asc, nf)| { - sort_expr_options( - col_name, - &output_schema, - SortOptions { - descending: !asc, - nulls_first: *nf, - }, - ) - }) - .collect::>(); - let physical_plan = sort_exec(sort_expr, window_exec); - - assert_optimized!( - case.initial_plan, - case.expected_plan, - physical_plan, - true, - case_idx - ); - } + // ============================================REGION STARTS============================================ + // WindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on ordered column + // Case 0: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), + required_sort: ["nullable_col", true, false, "count", true, false], + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 1: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), + required_sort: ["nullable_col", true, false, "max", false, false], + @ r#" + Input Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 2: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), + required_sort: ["min", false, false, "nullable_col", true, false], + @ r#" + Input Plan: + SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + + Optimized Plan: + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 3: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), + required_sort: ["avg", true, false, "nullable_col", true, false], + @ r#" +Input Plan: +SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // WindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on unordered column + // Case 4: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), + required_sort: ["non_nullable_col", true, false, "count", true, false], + @ r#" +Input Plan: +SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 5: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), + required_sort: ["non_nullable_col", false, false, "max", false, false], + @ r#" +Input Plan: +SortExec: expr=[non_nullable_col@1 DESC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +SortExec: expr=[non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 6: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), + required_sort: ["min", true, false, "non_nullable_col", true, false], + @ r#" +Input Plan: +SortExec: expr=[min@2 ASC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 7: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), + required_sort: ["avg", false, false, "nullable_col", false, false], + @ r#" +Input Plan: +SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // WindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on ordered column + // Case 8: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), + required_sort: ["nullable_col", true, false, "count", true, false], + @ r#" +Input Plan: +SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 9: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), + required_sort: ["nullable_col", true, false, "max", false, false], + @ r#" +Input Plan: +SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 10: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), + required_sort: ["min", false, false, "nullable_col", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 11: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), + required_sort: ["avg", true, false, "nullable_col", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // WindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on unordered column + // Case 12: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), + required_sort: ["non_nullable_col", true, false, "count", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 13: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), + required_sort: ["non_nullable_col", true, false, "max", false, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[non_nullable_col@1 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 14: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), + required_sort: ["min", false, false, "non_nullable_col", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST, non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 15: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(None)), + func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), + required_sort: ["avg", true, false, "nullable_col", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // WindowAggExec + Sliding(current row, unbounded following) + no partition_by + on ordered column + // Case 16: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), + required_sort: ["nullable_col", true, false, "count", false, false], + @ r#" +Input Plan: +SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 17: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), + required_sort: ["max", false, true, "nullable_col", true, false], + @ r#" +Input Plan: +SortExec: expr=[max@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 18: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), + required_sort: ["min", true, true, "nullable_col", true, false], + @ r#" +Input Plan: +SortExec: expr=[min@2 ASC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 19: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), + required_sort: ["avg", false, false, "nullable_col", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // WindowAggExec + Sliding(current row, unbounded following) + no partition_by + on unordered column + // Case 20: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), + required_sort: ["nullable_col", true, false, "count", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 21: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), + required_sort: ["nullable_col", true, false, "max", false, true], + @ r#" +Input Plan: +SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 22: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), + required_sort: ["min", true, false, "nullable_col", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 23: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), + required_sort: ["avg", false, false, "nullable_col", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // WindowAggExec + Sliding(current row, unbounded following) + partition_by + on ordered column + // Case 24: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), + required_sort: ["nullable_col", true, false, "count", false, false], + @ r#" +Input Plan: +SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 25: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), + required_sort: ["nullable_col", true, false, "max", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 26: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), + required_sort: ["min", false, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 27: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), + required_sort: ["avg", false, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // WindowAggExec + Sliding(current row, unbounded following) + partition_by + on unordered column + // Case 28: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), + required_sort: ["count", false, false, "nullable_col", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[count@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[count: Ok(Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 29: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), + required_sort: ["nullable_col", true, false, "max", false, true], + @ r#" +Input Plan: +SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC], preserve_partitioning=[false] + WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +WindowAggExec: wdw=[max: Ok(Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 30: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), + required_sort: ["min", false, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[min: Ok(Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 31: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true)).reverse()), + func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), + required_sort: ["nullable_col", true, false, "avg", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false] + WindowAggExec: wdw=[avg: Ok(Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on ordered column + // Case 32: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), + required_sort: ["nullable_col", true, false, "count", true, false], + @ r#" +Input Plan: +SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 33: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), + required_sort: ["max", false, false, "nullable_col", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[max@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 34: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), + required_sort: ["min", false, false, "nullable_col", true, false], + @ r#" +Input Plan: +SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 35: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), + required_sort: ["nullable_col", true, false, "avg", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + no partition_by + on unordered column + // Case 36: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), + required_sort: ["nullable_col", true, false, "count", true, true], + @ r#" +Input Plan: +SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 37: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), + required_sort: ["max", true, false, "nullable_col", true, false], + @ r#" +Input Plan: +SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 38: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), + required_sort: ["min", false, true, "nullable_col", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 39: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), + required_sort: ["avg", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on ordered column + // Case 40: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), + required_sort: ["nullable_col", true, false, "count", true, false], + @ r#" +Input Plan: +SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 41: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), + required_sort: ["max", true, false, "nullable_col", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 42: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), + required_sort: ["min", false, false, "nullable_col", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 43: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), + required_sort: ["nullable_col", true, false, "avg", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // BoundedWindowAggExec + Plain(unbounded preceding, unbounded following) + partition_by + on unordered column + // Case 44: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), + required_sort: ["count", true, true], + @ r#" + Input / Optimized Plan: + SortExec: expr=[count@2 ASC], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 45: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), + required_sort: ["nullable_col", true, false, "max", false, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 46: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), + required_sort: ["nullable_col", true, false, "min", false, false], + @ r#" +Input Plan: +SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 47: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new(Some(true))), + func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), + required_sort: ["nullable_col", true, false], + @ r#" +Input Plan: +SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + no partition_by + on ordered column + // Case 48: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), + func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), + required_sort: ["count", true, false, "nullable_col", true, false], + @ r#" +Input Plan: +SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 49: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), + func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), + required_sort: ["max", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[max@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 50: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), + func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), + required_sort: ["nullable_col", true, false, "min", false, false], + @ r#" +Input Plan: +SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 51: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), + func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), + required_sort: ["avg", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + no partition_by + on unordered column + // Case 52: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), + func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), + required_sort: ["count", true, false, "nullable_col", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 53: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), + func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), + required_sort: ["nullable_col", true, false, "max", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 54: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), + func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), + required_sort: ["min", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 55: + test_window_case!( + partition_by: false, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), + func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), + required_sort: ["nullable_col", true, false], + @ r#" +Input Plan: +SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + partition_by + on ordered column + // Case 56: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), + func: (fn_count_on_ordered.0.clone(), fn_count_on_ordered.1.clone(), fn_count_on_ordered.2.clone()), + required_sort: ["count", true, false, "nullable_col", true, false], + @ r#" +Input Plan: +SortExec: expr=[count@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + + // Case 57: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::Following(ScalarValue::new_one(&DataType::UInt32)?))), + func: (fn_max_on_ordered.0.clone(), fn_max_on_ordered.1.clone(), fn_max_on_ordered.2.clone()), + required_sort: ["nullable_col", true, false, "max", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 58: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), + func: (fn_min_on_ordered.0.clone(), fn_min_on_ordered.1.clone(), fn_min_on_ordered.2.clone()), + required_sort: ["min", false, false, "nullable_col", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 59: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), + func: (fn_avg_on_ordered.0.clone(), fn_avg_on_ordered.1.clone(), fn_avg_on_ordered.2.clone()), + required_sort: ["avg", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + // =============================================REGION ENDS============================================= + // = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = + // ============================================REGION STARTS============================================ + // BoundedWindowAggExec + Sliding(bounded preceding, bounded following) + partition_by + on unordered column + // Case 60: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), + func: (fn_count_on_unordered.0.clone(), fn_count_on_unordered.1.clone(), fn_count_on_unordered.2.clone()), + required_sort: ["nullable_col", true, false, "count", true, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 61: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), + func: (fn_max_on_unordered.0.clone(), fn_max_on_unordered.1.clone(), fn_max_on_unordered.2.clone()), + required_sort: ["nullable_col", true, false, "max", true, true], + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[max: Field { name: "max", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 62: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), + func: (fn_min_on_unordered.0.clone(), fn_min_on_unordered.1.clone(), fn_min_on_unordered.2.clone()), + required_sort: ["nullable_col", true, false, "min", false, false], + @ r#" + Input / Optimized Plan: + SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[min: Field { name: "min", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + "# + )?; + + // Case 63: + test_window_case!( + partition_by: true, + window_frame: Arc::new(WindowFrame::new_bounds(WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::new_one(&DataType::UInt32)?), WindowFrameBound::CurrentRow)), + func: (fn_avg_on_unordered.0.clone(), fn_avg_on_unordered.1.clone(), fn_avg_on_unordered.2.clone()), + required_sort: ["nullable_col", true, false], + @ r#" +Input Plan: +SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false] + BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet + +Optimized Plan: +BoundedWindowAggExec: wdw=[avg: Field { name: "avg", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet +"# + )?; + // =============================================REGION ENDS============================================= Ok(()) } - #[test] fn test_removes_unused_orthogonal_sort() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone()); - - let orthogonal_sort = sort_exec(vec![sort_expr("a", &schema)], unbounded_input); - let output_sort = sort_exec(input_sort_exprs, orthogonal_sort); // same sort as data source + let input_ordering: LexOrdering = + [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering.clone()); + let orthogonal_sort = sort_exec([sort_expr("a", &schema)].into(), unbounded_input); + let output_sort = sort_exec(input_ordering, orthogonal_sort); // same sort as data source // Test scenario/input has an orthogonal sort: - let expected_input = [ - "SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" - ]; - assert_eq!(get_plan_string(&output_sort), expected_input,); + let test = EnforceSortingTest::new(output_sort).with_repartition_sorts(true); + + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false] + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC] + Optimized Plan: + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC] + "); // Test: should remove orthogonal sort, and the uppermost (unneeded) sort: - let expected_optimized = [ - "StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" - ]; - assert_optimized!(expected_input, expected_optimized, output_sort, true); Ok(()) } @@ -3384,24 +3713,23 @@ fn test_removes_unused_orthogonal_sort() -> Result<()> { #[test] fn test_keeps_used_orthogonal_sort() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone()); - + let input_ordering: LexOrdering = + [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering.clone()); let orthogonal_sort = - sort_exec_with_fetch(vec![sort_expr("a", &schema)], Some(3), unbounded_input); // has fetch, so this orthogonal sort changes the output - let output_sort = sort_exec(input_sort_exprs, orthogonal_sort); + sort_exec_with_fetch([sort_expr("a", &schema)].into(), Some(3), unbounded_input); // has fetch, so this orthogonal sort changes the output + let output_sort = sort_exec(input_ordering, orthogonal_sort); // Test scenario/input has an orthogonal sort: - let expected_input = [ - "SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false]", - " SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" - ]; - assert_eq!(get_plan_string(&output_sort), expected_input,); + let test = EnforceSortingTest::new(output_sort).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input / Optimized Plan: + SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false] + SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC] + "); // Test: should keep the orthogonal sort, since it modifies the output: - let expected_optimized = expected_input; - assert_optimized!(expected_input, expected_optimized, output_sort, true); Ok(()) } @@ -3409,34 +3737,191 @@ fn test_keeps_used_orthogonal_sort() -> Result<()> { #[test] fn test_handles_multiple_orthogonal_sorts() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone()); - - let orthogonal_sort_0 = sort_exec(vec![sort_expr("c", &schema)], unbounded_input); // has no fetch, so can be removed + let input_ordering: LexOrdering = + [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering.clone()); + let ordering0: LexOrdering = [sort_expr("c", &schema)].into(); + let orthogonal_sort_0 = sort_exec(ordering0.clone(), unbounded_input); // has no fetch, so can be removed + let ordering1: LexOrdering = [sort_expr("a", &schema)].into(); let orthogonal_sort_1 = - sort_exec_with_fetch(vec![sort_expr("a", &schema)], Some(3), orthogonal_sort_0); // has fetch, so this orthogonal sort changes the output - let orthogonal_sort_2 = sort_exec(vec![sort_expr("c", &schema)], orthogonal_sort_1); // has no fetch, so can be removed - let orthogonal_sort_3 = sort_exec(vec![sort_expr("a", &schema)], orthogonal_sort_2); // has no fetch, so can be removed - let output_sort = sort_exec(input_sort_exprs, orthogonal_sort_3); // final sort + sort_exec_with_fetch(ordering1.clone(), Some(3), orthogonal_sort_0); // has fetch, so this orthogonal sort changes the output + let orthogonal_sort_2 = sort_exec(ordering0, orthogonal_sort_1); // has no fetch, so can be removed + let orthogonal_sort_3 = sort_exec(ordering1, orthogonal_sort_2); // has no fetch, so can be removed + let output_sort = sort_exec(input_ordering, orthogonal_sort_3); // final sort // Test scenario/input has an orthogonal sort: - let expected_input = [ - "SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false]", - " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]", - ]; - assert_eq!(get_plan_string(&output_sort), expected_input,); + let test = EnforceSortingTest::new(output_sort.clone()).with_repartition_sorts(true); + assert_snapshot!(test.run(), @r" + Input Plan: + SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false] + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false] + SortExec: expr=[c@2 ASC], preserve_partitioning=[false] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC] + + Optimized Plan: + SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false] + SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false] + StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC] + "); // Test: should keep only the needed orthogonal sort, and remove the unneeded ones: - let expected_optimized = [ - "SortExec: expr=[b@1 ASC, c@2 ASC], preserve_partitioning=[false]", - " SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false]", - " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]", + Ok(()) +} + +#[test] +fn test_parallelize_sort_preserves_fetch() -> Result<()> { + // Create a schema + let schema = create_test_schema3()?; + let parquet_exec = parquet_exec(schema); + let coalesced = coalesce_partitions_exec(parquet_exec.clone()); + let top_coalesced = coalesce_partitions_exec(coalesced.clone()) + .with_fetch(Some(10)) + .unwrap(); + + let requirements = PlanWithCorrespondingCoalescePartitions::new( + top_coalesced, + true, + vec![PlanWithCorrespondingCoalescePartitions::new( + coalesced, + true, + vec![PlanWithCorrespondingCoalescePartitions::new( + parquet_exec, + false, + vec![], + )], + )], + ); + + let res = parallelize_sorts(requirements)?; + + // Verify fetch was preserved + assert_eq!( + res.data.plan.fetch(), + Some(10), + "Fetch value was not preserved after transformation" + ); + Ok(()) +} + +#[tokio::test] +async fn test_partial_sort_with_homogeneous_batches() -> Result<()> { + // Create schema for the table + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + + // Create homogeneous batches - each batch has the same values for columns a and b + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1, 1])), + Arc::new(Int32Array::from(vec![1, 1, 1])), + Arc::new(Int32Array::from(vec![3, 2, 1])), + ], + )?; + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![2, 2, 2])), + Arc::new(Int32Array::from(vec![2, 2, 2])), + Arc::new(Int32Array::from(vec![4, 6, 5])), + ], + )?; + let batch3 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![3, 3, 3])), + Arc::new(Int32Array::from(vec![3, 3, 3])), + Arc::new(Int32Array::from(vec![9, 7, 8])), + ], + )?; + + // Create session with batch size of 3 to match our homogeneous batch pattern + let session_config = SessionConfig::new() + .with_batch_size(3) + .with_target_partitions(1); + let ctx = SessionContext::new_with_config(session_config); + + let sort_order = vec![ + SortExpr::new( + Expr::Column(datafusion_common::Column::new( + Option::::None, + "a", + )), + true, + false, + ), + SortExpr::new( + Expr::Column(datafusion_common::Column::new( + Option::::None, + "b", + )), + true, + false, + ), ]; - assert_optimized!(expected_input, expected_optimized, output_sort, true); + let batches = Arc::new(DummyStreamPartition { + schema: schema.clone(), + batches: vec![batch1, batch2, batch3], + }) as _; + let provider = StreamingTable::try_new(schema.clone(), vec![batches])? + .with_sort_order(sort_order) + .with_infinite_table(true); + ctx.register_table("test_table", Arc::new(provider))?; + + let sql = "SELECT * FROM test_table ORDER BY a ASC, c ASC"; + let df = ctx.sql(sql).await?; + + let physical_plan = df.create_physical_plan().await?; + + // Verify that PartialSortExec is used + let plan_str = displayable(physical_plan.as_ref()).indent(true).to_string(); + assert!( + plan_str.contains("PartialSortExec"), + "Expected PartialSortExec in plan:\n{plan_str}", + ); + + let task_ctx = Arc::new(TaskContext::default()); + let mut stream = physical_plan.execute(0, task_ctx.clone())?; + + let mut collected_batches = Vec::new(); + while let Some(batch) = stream.next().await { + let batch = batch?; + if batch.num_rows() > 0 { + collected_batches.push(batch); + } + } + + // Assert we got 3 separate batches (not concatenated into fewer) + assert_eq!( + collected_batches.len(), + 3, + "Expected 3 separate batches, got {}", + collected_batches.len() + ); + + // Verify each batch has been sorted within itself + let expected_values = [vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]; + + for (i, batch) in collected_batches.iter().enumerate() { + let c_array = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + let actual = c_array.values().iter().copied().collect::>(); + assert_eq!(actual, expected_values[i], "Batch {i} not sorted correctly",); + } + + assert_eq!( + task_ctx.runtime_env().memory_pool.reserved(), + 0, + "Memory should be released after execution" + ); Ok(()) } diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs new file mode 100644 index 0000000000000..b91c1732260cf --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs @@ -0,0 +1,1894 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::{Arc, LazyLock}; + +use arrow::{ + array::record_batch, + datatypes::{DataType, Field, Schema, SchemaRef}, + util::pretty::pretty_format_batches, +}; +use arrow_schema::SortOptions; +use datafusion::{ + assert_batches_eq, + logical_expr::Operator, + physical_plan::{ + expressions::{BinaryExpr, Column, Literal}, + PhysicalExpr, + }, + prelude::{ParquetReadOptions, SessionConfig, SessionContext}, + scalar::ScalarValue, +}; +use datafusion_catalog::memory::DataSourceExec; +use datafusion_common::config::ConfigOptions; +use datafusion_datasource::{ + file_groups::FileGroup, file_scan_config::FileScanConfigBuilder, PartitionedFile, +}; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_expr::ScalarUDF; +use datafusion_functions::math::random::RandomFunc; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_physical_expr::{ + aggregate::AggregateExprBuilder, Partitioning, ScalarFunctionExpr, +}; +use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; +use datafusion_physical_optimizer::{ + filter_pushdown::FilterPushdown, PhysicalOptimizerRule, +}; +use datafusion_physical_plan::{ + aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, + coalesce_batches::CoalesceBatchesExec, + coalesce_partitions::CoalescePartitionsExec, + collect, + filter::FilterExec, + repartition::RepartitionExec, + sorts::sort::SortExec, + ExecutionPlan, +}; + +use datafusion_physical_plan::union::UnionExec; +use futures::StreamExt; +use object_store::{memory::InMemory, ObjectStore}; +use util::{format_plan_for_test, OptimizationTest, TestNode, TestScanBuilder}; + +use crate::physical_optimizer::filter_pushdown::util::TestSource; + +mod util; + +#[test] +fn test_pushdown_into_scan() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()); + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_pushdown_volatile_functions_not_allowed() { + // Test that we do not push down filters with volatile functions + // Use random() as an example of a volatile function + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let cfg = Arc::new(ConfigOptions::default()); + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema("a", &schema()).unwrap()), + Operator::Eq, + Arc::new( + ScalarFunctionExpr::try_new( + Arc::new(ScalarUDF::from(RandomFunc::new())), + vec![], + &schema(), + cfg, + ) + .unwrap(), + ), + )) as Arc; + let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()); + // expect the filter to not be pushed down + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = random() + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = random() + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + ", + ); +} + +/// Show that we can use config options to determine how to do pushdown. +#[test] +fn test_pushdown_into_scan_with_config_options() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()) as _; + + let mut cfg = ConfigOptions::default(); + insta::assert_snapshot!( + OptimizationTest::new( + Arc::clone(&plan), + FilterPushdown::new(), + false + ), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + " + ); + + cfg.execution.parquet.pushdown_filters = true; + insta::assert_snapshot!( + OptimizationTest::new( + plan, + FilterPushdown::new(), + true + ), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[tokio::test] +async fn test_dynamic_filter_pushdown_through_hash_join_with_topk() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with limited values + let build_batches = vec![record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8View, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap()]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8View, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![record_batch!( + ("d", Utf8, ["aa", "ab", "ac", "ad"]), + ("e", Utf8View, ["ba", "bb", "bc", "bd"]), + ("f", Float64, [1.0, 2.0, 3.0, 4.0]) + ) + .unwrap()]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("d", DataType::Utf8, false), + Field::new("e", DataType::Utf8View, false), + Field::new("f", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create HashJoinExec + let on = vec![( + col("a", &build_side_schema).unwrap(), + col("d", &probe_side_schema).unwrap(), + )]; + let join = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_scan, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + let join_schema = join.schema(); + + // Finally let's add a SortExec on the outside to test pushdown of dynamic filters + let sort_expr = + PhysicalSortExpr::new(col("e", &join_schema).unwrap(), SortOptions::default()); + let plan = Arc::new( + SortExec::new(LexOrdering::new(vec![sort_expr]).unwrap(), join) + .with_fetch(Some(2)), + ) as Arc; + + let mut config = ConfigOptions::default(); + config.optimizer.enable_dynamic_filter_pushdown = true; + config.execution.parquet.pushdown_filters = true; + + // Apply the FilterPushdown optimizer rule + let plan = FilterPushdown::new_post_optimization() + .optimize(Arc::clone(&plan), &config) + .unwrap(); + + // Test that filters are pushed down correctly to each side of the join + insta::assert_snapshot!( + format_plan_for_test(&plan), + @r" + - SortExec: TopK(fetch=2), expr=[e@4 ASC], preserve_partitioning=[false] + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] AND DynamicFilter [ empty ] + " + ); + + // Put some data through the plan to check that the filter is updated to reflect the TopK state + let session_ctx = SessionContext::new_with_config(SessionConfig::new()); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + // Iterate one batch + stream.next().await.unwrap().unwrap(); + + // Test that filters are pushed down correctly to each side of the join + insta::assert_snapshot!( + format_plan_for_test(&plan), + @r" + - SortExec: TopK(fetch=2), expr=[e@4 ASC], preserve_partitioning=[false], filter=[e@4 IS NULL OR e@4 < bb] + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= aa AND d@0 <= ab ] AND DynamicFilter [ e@1 IS NULL OR e@1 < bb ] + " + ); +} + +// Test both static and dynamic filter pushdown in HashJoinExec. +// Note that static filter pushdown is rare: it should have already happened in the logical optimizer phase. +// However users may manually construct plans that could result in a FilterExec -> HashJoinExec -> Scan setup. +// Dynamic filters arise in cases such as nested inner joins or TopK -> HashJoinExec -> Scan setups. +#[tokio::test] +async fn test_static_filter_pushdown_through_hash_join() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with limited values + let build_batches = vec![record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8View, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap()]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8View, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![record_batch!( + ("d", Utf8, ["aa", "ab", "ac", "ad"]), + ("e", Utf8View, ["ba", "bb", "bc", "bd"]), + ("f", Float64, [1.0, 2.0, 3.0, 4.0]) + ) + .unwrap()]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("d", DataType::Utf8, false), + Field::new("e", DataType::Utf8View, false), + Field::new("f", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create HashJoinExec + let on = vec![( + col("a", &build_side_schema).unwrap(), + col("d", &probe_side_schema).unwrap(), + )]; + let join = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_scan, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + // Create filters that can be pushed down to different sides + // We need to create filters in the context of the join output schema + let join_schema = join.schema(); + + // Filter on build side column: a = 'aa' + let left_filter = col_lit_predicate("a", "aa", &join_schema); + // Filter on probe side column: e = 'ba' + let right_filter = col_lit_predicate("e", "ba", &join_schema); + // Filter that references both sides: a = d (should not be pushed down) + let cross_filter = Arc::new(BinaryExpr::new( + col("a", &join_schema).unwrap(), + Operator::Eq, + col("d", &join_schema).unwrap(), + )) as Arc; + + let filter = + Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap()); + let filter = Arc::new(FilterExec::try_new(right_filter, filter).unwrap()); + let plan = Arc::new(FilterExec::try_new(cross_filter, filter).unwrap()) + as Arc; + + // Test that filters are pushed down correctly to each side of the join + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = d@3 + - FilterExec: e@4 = ba + - FilterExec: a@0 = aa + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = d@3 + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = aa + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=e@1 = ba + " + ); + + // Test left join - filters should NOT be pushed down + let join = Arc::new( + HashJoinExec::try_new( + TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .build(), + TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .build(), + vec![( + col("a", &build_side_schema).unwrap(), + col("d", &probe_side_schema).unwrap(), + )], + None, + &JoinType::Left, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + let join_schema = join.schema(); + let filter = col_lit_predicate("a", "aa", &join_schema); + let plan = + Arc::new(FilterExec::try_new(filter, join).unwrap()) as Arc; + + // Test that filters are NOT pushed down for left join + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = aa + - HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = aa + - HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true + " + ); +} + +#[test] +fn test_filter_collapse() { + // filter should be pushed down into the parquet scan with two filters + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate1 = col_lit_predicate("a", "foo", &schema()); + let filter1 = Arc::new(FilterExec::try_new(predicate1, scan).unwrap()); + let predicate2 = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate2, filter1).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar + " + ); +} + +#[test] +fn test_filter_with_projection() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let projection = vec![1, 0]; + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new( + FilterExec::try_new(predicate, Arc::clone(&scan)) + .unwrap() + .with_projection(Some(projection)) + .unwrap(), + ); + + // expect the predicate to be pushed down into the DataSource but the FilterExec to be converted to ProjectionExec + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo, projection=[b@1, a@0] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - ProjectionExec: expr=[b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + ", + ); + + // add a test where the filter is on a column that isn't included in the output + let projection = vec![1]; + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new( + FilterExec::try_new(predicate, scan) + .unwrap() + .with_projection(Some(projection)) + .unwrap(), + ); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(),true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo, projection=[b@1] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - ProjectionExec: expr=[b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_push_down_through_transparent_nodes() { + // expect the predicate to be pushed down into the DataSource + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let coalesce = Arc::new(CoalesceBatchesExec::new(scan, 1)); + let predicate = col_lit_predicate("a", "foo", &schema()); + let filter = Arc::new(FilterExec::try_new(predicate, coalesce).unwrap()); + let repartition = Arc::new( + RepartitionExec::try_new(filter, Partitioning::RoundRobinBatch(1)).unwrap(), + ); + let predicate = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, repartition).unwrap()); + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(),true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1 + - FilterExec: a@0 = foo + - CoalesceBatchesExec: target_batch_size=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1 + - CoalesceBatchesExec: target_batch_size=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar + " + ); +} + +#[test] +fn test_no_pushdown_through_aggregates() { + // There are 2 important points here: + // 1. The outer filter **is not** pushed down at all because we haven't implemented pushdown support + // yet for AggregateExec. + // 2. The inner filter **is** pushed down into the DataSource. + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + + let coalesce = Arc::new(CoalesceBatchesExec::new(scan, 10)); + + let filter = Arc::new( + FilterExec::try_new(col_lit_predicate("a", "foo", &schema()), coalesce).unwrap(), + ); + + let aggregate_expr = + vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema()).unwrap()]) + .schema(schema()) + .alias("cnt") + .build() + .map(Arc::new) + .unwrap(), + ]; + let group_by = PhysicalGroupBy::new_single(vec![ + (col("a", &schema()).unwrap(), "a".to_string()), + (col("b", &schema()).unwrap(), "b".to_string()), + ]); + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expr.clone(), + vec![None], + filter, + schema(), + ) + .unwrap(), + ); + + let coalesce = Arc::new(CoalesceBatchesExec::new(aggregate, 100)); + + let predicate = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, coalesce).unwrap()); + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - CoalesceBatchesExec: target_batch_size=100 + - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=PartiallySorted([0]) + - FilterExec: a@0 = foo + - CoalesceBatchesExec: target_batch_size=10 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: b@1 = bar + - CoalesceBatchesExec: target_batch_size=100 + - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=PartiallySorted([0]) + - CoalesceBatchesExec: target_batch_size=10 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +/// Test various combinations of handling of child pushdown results +/// in an ExecutionPlan in combination with support/not support in a DataSource. +#[test] +fn test_node_handles_child_pushdown_result() { + // If we set `with_support(true)` + `inject_filter = true` then the filter is pushed down to the DataSource + // and no FilterExec is created. + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(TestNode::new(true, Arc::clone(&scan), predicate)); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - TestInsertExec { inject_filter: true } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - TestInsertExec { inject_filter: true } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + ", + ); + + // If we set `with_support(false)` + `inject_filter = true` then the filter is not pushed down to the DataSource + // and a FilterExec is created. + let scan = TestScanBuilder::new(schema()).with_support(false).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(TestNode::new(true, Arc::clone(&scan), predicate)); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - TestInsertExec { inject_filter: true } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - TestInsertExec { inject_filter: false } + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + ", + ); + + // If we set `with_support(false)` + `inject_filter = false` then the filter is not pushed down to the DataSource + // and no FilterExec is created. + let scan = TestScanBuilder::new(schema()).with_support(false).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(TestNode::new(false, Arc::clone(&scan), predicate)); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - TestInsertExec { inject_filter: false } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - TestInsertExec { inject_filter: false } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + ", + ); +} + +#[tokio::test] +async fn test_topk_dynamic_filter_pushdown() { + let batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["bd", "bc"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + record_batch!( + ("a", Utf8, ["ac", "ad"]), + ("b", Utf8, ["bb", "ba"]), + ("c", Float64, [2.0, 1.0]) + ) + .unwrap(), + ]; + let scan = TestScanBuilder::new(schema()) + .with_support(true) + .with_batches(batches) + .build(); + let plan = Arc::new( + SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("b", &schema()).unwrap(), + SortOptions::new(true, false), // descending, nulls_first + )]) + .unwrap(), + Arc::clone(&scan), + ) + .with_fetch(Some(1)), + ) as Arc; + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); + + // Actually apply the optimization to the plan and put some data through it to check that the filter is updated to reflect the TopK state + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + let config = SessionConfig::new().with_batch_size(2); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + // Iterate one batch + stream.next().await.unwrap().unwrap(); + // Now check what our filter looks like + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false], filter=[b@1 > bd] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ b@1 > bd ] + " + ); +} + +#[tokio::test] +async fn test_topk_dynamic_filter_pushdown_multi_column_sort() { + let batches = vec![ + // We are going to do ORDER BY b ASC NULLS LAST, a DESC + // And we put the values in such a way that the first batch will fill the TopK + // and we skip the second batch. + record_batch!( + ("a", Utf8, ["ac", "ad"]), + ("b", Utf8, ["bb", "ba"]), + ("c", Float64, [2.0, 1.0]) + ) + .unwrap(), + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["bc", "bd"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + ]; + let scan = TestScanBuilder::new(schema()) + .with_support(true) + .with_batches(batches) + .build(); + let plan = Arc::new( + SortExec::new( + LexOrdering::new(vec![ + PhysicalSortExpr::new( + col("b", &schema()).unwrap(), + SortOptions::default().asc().nulls_last(), + ), + PhysicalSortExpr::new( + col("a", &schema()).unwrap(), + SortOptions::default().desc().nulls_first(), + ), + ]) + .unwrap(), + Arc::clone(&scan), + ) + .with_fetch(Some(2)), + ) as Arc; + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=2), expr=[b@1 ASC NULLS LAST, a@0 DESC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - SortExec: TopK(fetch=2), expr=[b@1 ASC NULLS LAST, a@0 DESC], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); + + // Actually apply the optimization to the plan and put some data through it to check that the filter is updated to reflect the TopK state + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + let config = SessionConfig::new().with_batch_size(2); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + // Iterate one batch + let res = stream.next().await.unwrap().unwrap(); + #[rustfmt::skip] + let expected = [ + "+----+----+-----+", + "| a | b | c |", + "+----+----+-----+", + "| ad | ba | 1.0 |", + "| ac | bb | 2.0 |", + "+----+----+-----+", + ]; + assert_batches_eq!(expected, &[res]); + // Now check what our filter looks like + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: TopK(fetch=2), expr=[b@1 ASC NULLS LAST, a@0 DESC], preserve_partitioning=[false], filter=[b@1 < bb OR b@1 = bb AND (a@0 IS NULL OR a@0 > ac)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ b@1 < bb OR b@1 = bb AND (a@0 IS NULL OR a@0 > ac) ] + " + ); + // There should be no more batches + assert!(stream.next().await.is_none()); +} + +#[tokio::test] +async fn test_topk_filter_passes_through_coalesce_partitions() { + // Create multiple batches for different partitions + let batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["bd", "bc"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + record_batch!( + ("a", Utf8, ["ac", "ad"]), + ("b", Utf8, ["bb", "ba"]), + ("c", Float64, [2.0, 1.0]) + ) + .unwrap(), + ]; + + // Create a source that supports all batches + let source = Arc::new(TestSource::new(true, batches)); + + let base_config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("test://").unwrap(), + Arc::clone(&schema()), + source, + ) + .with_file_groups(vec![ + // Partition 0 + FileGroup::new(vec![PartitionedFile::new("test1.parquet", 123)]), + // Partition 1 + FileGroup::new(vec![PartitionedFile::new("test2.parquet", 123)]), + ]) + .build(); + + let scan = DataSourceExec::from_data_source(base_config); + + // Add CoalescePartitionsExec to merge the two partitions + let coalesce = Arc::new(CoalescePartitionsExec::new(scan)) as Arc; + + // Add SortExec with TopK + let plan = Arc::new( + SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("b", &schema()).unwrap(), + SortOptions::new(true, false), + )]) + .unwrap(), + coalesce, + ) + .with_fetch(Some(1)), + ) as Arc; + + // Test optimization - the filter SHOULD pass through CoalescePartitionsExec + // if it properly implements from_children (not all_unsupported) + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - DataSourceExec: file_groups={2 groups: [[test1.parquet], [test2.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - DataSourceExec: file_groups={2 groups: [[test1.parquet], [test2.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); +} + +#[tokio::test] +async fn test_topk_filter_passes_through_coalesce_batches() { + let batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["bd", "bc"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + record_batch!( + ("a", Utf8, ["ac", "ad"]), + ("b", Utf8, ["bb", "ba"]), + ("c", Float64, [2.0, 1.0]) + ) + .unwrap(), + ]; + + let scan = TestScanBuilder::new(schema()) + .with_support(true) + .with_batches(batches) + .build(); + + let coalesce_batches = + Arc::new(CoalesceBatchesExec::new(scan, 1024)) as Arc; + + // Add SortExec with TopK + let plan = Arc::new( + SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("b", &schema()).unwrap(), + SortOptions::new(true, false), + )]) + .unwrap(), + coalesce_batches, + ) + .with_fetch(Some(1)), + ) as Arc; + + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] + - CoalesceBatchesExec: target_batch_size=1024 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] + - CoalesceBatchesExec: target_batch_size=1024 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); +} + +#[tokio::test] +async fn test_hashjoin_dynamic_filter_pushdown() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with limited values + let build_batches = vec![record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) // Extra column not used in join + ) + .unwrap()]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join + ) + .unwrap()]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("e", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create HashJoinExec with dynamic filter + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let plan = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_scan, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ) as Arc; + + // expect the predicate to be pushed down into the probe side DataSource + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true + output: + Ok: + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + ", + ); + + // Actually apply the optimization to the plan and execute to see the filter in action + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + + // Test for https://github.com/apache/datafusion/pull/17371: dynamic filter linking survives `with_new_children` + let children = plan.children().into_iter().map(Arc::clone).collect(); + let plan = plan.with_new_children(children).unwrap(); + + let config = SessionConfig::new().with_batch_size(10); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + // Iterate one batch + stream.next().await.unwrap().unwrap(); + + // Now check what our filter looks like + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ] + " + ); +} + +#[tokio::test] +async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Rough sketch of the MRE we're trying to recreate: + // COPY (select i as k from generate_series(1, 10000000) as t(i)) + // TO 'test_files/scratch/push_down_filter/t1.parquet' + // STORED AS PARQUET; + // COPY (select i as k, i as v from generate_series(1, 10000000) as t(i)) + // TO 'test_files/scratch/push_down_filter/t2.parquet' + // STORED AS PARQUET; + // create external table t1 stored as parquet location 'test_files/scratch/push_down_filter/t1.parquet'; + // create external table t2 stored as parquet location 'test_files/scratch/push_down_filter/t2.parquet'; + // explain + // select * + // from t1 + // join t2 on t1.k = t2.k; + // +---------------+------------------------------------------------------------+ + // | plan_type | plan | + // +---------------+------------------------------------------------------------+ + // | physical_plan | ┌───────────────────────────┐ | + // | | │ CoalesceBatchesExec │ | + // | | │ -------------------- │ | + // | | │ target_batch_size: │ | + // | | │ 8192 │ | + // | | └─────────────┬─────────────┘ | + // | | ┌─────────────┴─────────────┐ | + // | | │ HashJoinExec │ | + // | | │ -------------------- ├──────────────┐ | + // | | │ on: (k = k) │ │ | + // | | └─────────────┬─────────────┘ │ | + // | | ┌─────────────┴─────────────┐┌─────────────┴─────────────┐ | + // | | │ CoalesceBatchesExec ││ CoalesceBatchesExec │ | + // | | │ -------------------- ││ -------------------- │ | + // | | │ target_batch_size: ││ target_batch_size: │ | + // | | │ 8192 ││ 8192 │ | + // | | └─────────────┬─────────────┘└─────────────┬─────────────┘ | + // | | ┌─────────────┴─────────────┐┌─────────────┴─────────────┐ | + // | | │ RepartitionExec ││ RepartitionExec │ | + // | | │ -------------------- ││ -------------------- │ | + // | | │ partition_count(in->out): ││ partition_count(in->out): │ | + // | | │ 12 -> 12 ││ 12 -> 12 │ | + // | | │ ││ │ | + // | | │ partitioning_scheme: ││ partitioning_scheme: │ | + // | | │ Hash([k@0], 12) ││ Hash([k@0], 12) │ | + // | | └─────────────┬─────────────┘└─────────────┬─────────────┘ | + // | | ┌─────────────┴─────────────┐┌─────────────┴─────────────┐ | + // | | │ DataSourceExec ││ DataSourceExec │ | + // | | │ -------------------- ││ -------------------- │ | + // | | │ files: 12 ││ files: 12 │ | + // | | │ format: parquet ││ format: parquet │ | + // | | │ ││ predicate: true │ | + // | | └───────────────────────────┘└───────────────────────────┘ | + // | | | + // +---------------+------------------------------------------------------------+ + + // Create build side with limited values + let build_batches = vec![record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) // Extra column not used in join + ) + .unwrap()]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join + ) + .unwrap()]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("e", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create RepartitionExec nodes for both sides with hash partitioning on join keys + let partition_count = 12; + + // Build side: DataSource -> RepartitionExec (Hash) -> CoalesceBatchesExec + let build_hash_exprs = vec![ + col("a", &build_side_schema).unwrap(), + col("b", &build_side_schema).unwrap(), + ]; + let build_repartition = Arc::new( + RepartitionExec::try_new( + build_scan, + Partitioning::Hash(build_hash_exprs, partition_count), + ) + .unwrap(), + ); + let build_coalesce = Arc::new(CoalesceBatchesExec::new(build_repartition, 8192)); + + // Probe side: DataSource -> RepartitionExec (Hash) -> CoalesceBatchesExec + let probe_hash_exprs = vec![ + col("a", &probe_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ]; + let probe_repartition = Arc::new( + RepartitionExec::try_new( + Arc::clone(&probe_scan), + Partitioning::Hash(probe_hash_exprs, partition_count), + ) + .unwrap(), + ); + let probe_coalesce = Arc::new(CoalesceBatchesExec::new(probe_repartition, 8192)); + + // Create HashJoinExec with partitioned inputs + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let hash_join = Arc::new( + HashJoinExec::try_new( + build_coalesce, + probe_coalesce, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + // Top-level CoalesceBatchesExec + let cb = + Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc; + // Top-level CoalescePartitionsExec + let cp = Arc::new(CoalescePartitionsExec::new(cb)) as Arc; + // Add a sort for deterministic output + let plan = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("a", &probe_side_schema).unwrap(), + SortOptions::new(true, false), // descending, nulls_first + )]) + .unwrap(), + cp, + )) as Arc; + + // expect the predicate to be pushed down into the probe side DataSource + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - CoalesceBatchesExec: target_batch_size=8192 + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - CoalesceBatchesExec: target_batch_size=8192 + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); + + // Actually apply the optimization to the plan and execute to see the filter in action + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + let config = SessionConfig::new().with_batch_size(10); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Now check what our filter looks like + #[cfg(not(feature = "force_hash_collisions"))] + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - CoalesceBatchesExec: target_batch_size=8192 + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= ab AND a@0 <= ab AND b@1 >= bb AND b@1 <= bb OR a@0 >= aa AND a@0 <= aa AND b@1 >= ba AND b@1 <= ba ] + " + ); + + #[cfg(feature = "force_hash_collisions")] + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - CoalesceBatchesExec: target_batch_size=8192 + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ] + " + ); + + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + + let probe_scan_metrics = probe_scan.metrics().unwrap(); + + // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain. + // The number of output rows from the probe side scan should stay consistent across executions. + // Issue: https://github.com/apache/datafusion/issues/17451 + assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); + + insta::assert_snapshot!( + result, + @r" + +----+----+-----+----+----+-----+ + | a | b | c | a | b | e | + +----+----+-----+----+----+-----+ + | ab | bb | 2.0 | ab | bb | 2.0 | + | aa | ba | 1.0 | aa | ba | 1.0 | + +----+----+-----+----+----+-----+ + ", + ); +} + +#[tokio::test] +async fn test_hashjoin_dynamic_filter_pushdown_collect_left() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + let build_batches = vec![record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) // Extra column not used in join + ) + .unwrap()]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![record_batch!( + ("a", Utf8, ["aa", "ab", "ac", "ad"]), + ("b", Utf8, ["ba", "bb", "bc", "bd"]), + ("e", Float64, [1.0, 2.0, 3.0, 4.0]) // Extra column not used in join + ) + .unwrap()]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("e", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create RepartitionExec nodes for both sides with hash partitioning on join keys + let partition_count = 12; + + // Probe side: DataSource -> RepartitionExec(Hash) -> CoalesceBatchesExec + let probe_hash_exprs = vec![ + col("a", &probe_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ]; + let probe_repartition = Arc::new( + RepartitionExec::try_new( + Arc::clone(&probe_scan), + Partitioning::Hash(probe_hash_exprs, partition_count), // create multi partitions on probSide + ) + .unwrap(), + ); + let probe_coalesce = Arc::new(CoalesceBatchesExec::new(probe_repartition, 8192)); + + let on = vec![ + ( + col("a", &build_side_schema).unwrap(), + col("a", &probe_side_schema).unwrap(), + ), + ( + col("b", &build_side_schema).unwrap(), + col("b", &probe_side_schema).unwrap(), + ), + ]; + let hash_join = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_coalesce, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + // Top-level CoalesceBatchesExec + let cb = + Arc::new(CoalesceBatchesExec::new(hash_join, 8192)) as Arc; + // Top-level CoalescePartitionsExec + let cp = Arc::new(CoalescePartitionsExec::new(cb)) as Arc; + // Add a sort for deterministic output + let plan = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("a", &probe_side_schema).unwrap(), + SortOptions::new(true, false), // descending, nulls_first + )]) + .unwrap(), + cp, + )) as Arc; + + // expect the predicate to be pushed down into the probe side DataSource + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - CoalesceBatchesExec: target_batch_size=8192 + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true + output: + Ok: + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - CoalesceBatchesExec: target_batch_size=8192 + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + " + ); + + // Actually apply the optimization to the plan and execute to see the filter in action + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + let config = SessionConfig::new().with_batch_size(10); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let batches = collect(Arc::clone(&plan), Arc::clone(&task_ctx)) + .await + .unwrap(); + + // Now check what our filter looks like + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] + - CoalescePartitionsExec + - CoalesceBatchesExec: target_batch_size=8192 + - HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0), (b@1, b@1)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - CoalesceBatchesExec: target_batch_size=8192 + - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ] + " + ); + + let result = format!("{}", pretty_format_batches(&batches).unwrap()); + + let probe_scan_metrics = probe_scan.metrics().unwrap(); + + // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain. + // The number of output rows from the probe side scan should stay consistent across executions. + // Issue: https://github.com/apache/datafusion/issues/17451 + assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2); + + insta::assert_snapshot!( + result, + @r" + +----+----+-----+----+----+-----+ + | a | b | c | a | b | e | + +----+----+-----+----+----+-----+ + | ab | bb | 2.0 | ab | bb | 2.0 | + | aa | ba | 1.0 | aa | ba | 1.0 | + +----+----+-----+----+----+-----+ + ", + ); +} + +#[tokio::test] +async fn test_nested_hashjoin_dynamic_filter_pushdown() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create test data for three tables: t1, t2, t3 + // t1: small table with limited values (will be build side of outer join) + let t1_batches = + vec![ + record_batch!(("a", Utf8, ["aa", "ab"]), ("x", Float64, [1.0, 2.0])).unwrap(), + ]; + let t1_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("x", DataType::Float64, false), + ])); + let t1_scan = TestScanBuilder::new(Arc::clone(&t1_schema)) + .with_support(true) + .with_batches(t1_batches) + .build(); + + // t2: larger table (will be probe side of inner join, build side of outer join) + let t2_batches = vec![record_batch!( + ("b", Utf8, ["aa", "ab", "ac", "ad", "ae"]), + ("c", Utf8, ["ca", "cb", "cc", "cd", "ce"]), + ("y", Float64, [1.0, 2.0, 3.0, 4.0, 5.0]) + ) + .unwrap()]; + let t2_schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Utf8, false), + Field::new("y", DataType::Float64, false), + ])); + let t2_scan = TestScanBuilder::new(Arc::clone(&t2_schema)) + .with_support(true) + .with_batches(t2_batches) + .build(); + + // t3: largest table (will be probe side of inner join) + let t3_batches = vec![record_batch!( + ("d", Utf8, ["ca", "cb", "cc", "cd", "ce", "cf", "cg", "ch"]), + ("z", Float64, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + ) + .unwrap()]; + let t3_schema = Arc::new(Schema::new(vec![ + Field::new("d", DataType::Utf8, false), + Field::new("z", DataType::Float64, false), + ])); + let t3_scan = TestScanBuilder::new(Arc::clone(&t3_schema)) + .with_support(true) + .with_batches(t3_batches) + .build(); + + // Create nested join structure: + // Join (t1.a = t2.b) + // / \ + // t1 Join(t2.c = t3.d) + // / \ + // t2 t3 + + // First create inner join: t2.c = t3.d + let inner_join_on = + vec![(col("c", &t2_schema).unwrap(), col("d", &t3_schema).unwrap())]; + let inner_join = Arc::new( + HashJoinExec::try_new( + t2_scan, + t3_scan, + inner_join_on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + // Then create outer join: t1.a = t2.b (from inner join result) + let outer_join_on = vec![( + col("a", &t1_schema).unwrap(), + col("b", &inner_join.schema()).unwrap(), + )]; + let outer_join = Arc::new( + HashJoinExec::try_new( + t1_scan, + inner_join as Arc, + outer_join_on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ) as Arc; + + // Test that dynamic filters are pushed down correctly through nested joins + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&outer_join), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true + output: + Ok: + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ empty ] + ", + ); + + // Execute the plan to verify the dynamic filters are properly updated + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_dynamic_filter_pushdown = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(outer_join, &config) + .unwrap(); + let config = SessionConfig::new().with_batch_size(10); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + // Execute to populate the dynamic filters + stream.next().await.unwrap().unwrap(); + + // Verify that both the inner and outer join have updated dynamic filters + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ b@0 >= aa AND b@0 <= ab ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= ca AND d@0 <= cb ] + " + ); +} + +#[tokio::test] +async fn test_hashjoin_parent_filter_pushdown() { + use datafusion_common::JoinType; + use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; + + // Create build side with limited values + let build_batches = vec![record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["ba", "bb"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap()]; + let build_side_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema)) + .with_support(true) + .with_batches(build_batches) + .build(); + + // Create probe side with more values + let probe_batches = vec![record_batch!( + ("d", Utf8, ["aa", "ab", "ac", "ad"]), + ("e", Utf8, ["ba", "bb", "bc", "bd"]), + ("f", Float64, [1.0, 2.0, 3.0, 4.0]) + ) + .unwrap()]; + let probe_side_schema = Arc::new(Schema::new(vec![ + Field::new("d", DataType::Utf8, false), + Field::new("e", DataType::Utf8, false), + Field::new("f", DataType::Float64, false), + ])); + let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema)) + .with_support(true) + .with_batches(probe_batches) + .build(); + + // Create HashJoinExec + let on = vec![( + col("a", &build_side_schema).unwrap(), + col("d", &probe_side_schema).unwrap(), + )]; + let join = Arc::new( + HashJoinExec::try_new( + build_scan, + probe_scan, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + datafusion_common::NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + // Create filters that can be pushed down to different sides + // We need to create filters in the context of the join output schema + let join_schema = join.schema(); + + // Filter on build side column: a = 'aa' + let left_filter = col_lit_predicate("a", "aa", &join_schema); + // Filter on probe side column: e = 'ba' + let right_filter = col_lit_predicate("e", "ba", &join_schema); + // Filter that references both sides: a = d (should not be pushed down) + let cross_filter = Arc::new(BinaryExpr::new( + col("a", &join_schema).unwrap(), + Operator::Eq, + col("d", &join_schema).unwrap(), + )) as Arc; + + let filter = + Arc::new(FilterExec::try_new(left_filter, Arc::clone(&join) as _).unwrap()); + let filter = Arc::new(FilterExec::try_new(right_filter, filter).unwrap()); + let plan = Arc::new(FilterExec::try_new(cross_filter, filter).unwrap()) + as Arc; + + // Test that filters are pushed down correctly to each side of the join + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = d@3 + - FilterExec: e@4 = ba + - FilterExec: a@0 = aa + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = d@3 + - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = aa + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=e@1 = ba + " + ); +} + +/// Integration test for dynamic filter pushdown with TopK. +/// We use an integration test because there are complex interactions in the optimizer rules +/// that the unit tests applying a single optimizer rule do not cover. +#[tokio::test] +async fn test_topk_dynamic_filter_pushdown_integration() { + let store = Arc::new(InMemory::new()) as Arc; + let mut cfg = SessionConfig::new(); + cfg.options_mut().execution.parquet.pushdown_filters = true; + cfg.options_mut().execution.parquet.max_row_group_size = 128; + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store( + ObjectStoreUrl::parse("memory://").unwrap().as_ref(), + Arc::clone(&store), + ); + ctx.sql( + r" +COPY ( + SELECT 1372708800 + value AS t + FROM generate_series(0, 99999) + ORDER BY t + ) TO 'memory:///1.parquet' +STORED AS PARQUET; + ", + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + + // Register the file with the context + ctx.register_parquet( + "topk_pushdown", + "memory:///1.parquet", + ParquetReadOptions::default(), + ) + .await + .unwrap(); + + // Create a TopK query that will use dynamic filter pushdown + let df = ctx + .sql(r"EXPLAIN ANALYZE SELECT t FROM topk_pushdown ORDER BY t LIMIT 10;") + .await + .unwrap(); + let batches = df.collect().await.unwrap(); + let explain = format!("{}", pretty_format_batches(&batches).unwrap()); + + assert!(explain.contains("output_rows=128")); // Read 1 row group + assert!(explain.contains("t@0 < 1372708809")); // Dynamic filter was applied + assert!( + explain.contains("pushdown_rows_matched=128, pushdown_rows_pruned=99872"), + "{explain}" + ); + // Pushdown pruned most rows +} + +#[test] +fn test_filter_pushdown_through_union() { + let scan1 = TestScanBuilder::new(schema()).with_support(true).build(); + let scan2 = TestScanBuilder::new(schema()).with_support(true).build(); + + let union = UnionExec::try_new(vec![scan1, scan2]).unwrap(); + + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, union).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - UnionExec + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - UnionExec + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +/// Schema: +/// a: String +/// b: String +/// c: f64 +static TEST_SCHEMA: LazyLock = LazyLock::new(|| { + let fields = vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ]; + Arc::new(Schema::new(fields)) +}); + +fn schema() -> SchemaRef { + Arc::clone(&TEST_SCHEMA) +} + +/// Returns a predicate that is a binary expression col = lit +fn col_lit_predicate( + column_name: &str, + scalar_value: impl Into, + schema: &Schema, +) -> Arc { + let scalar_value = scalar_value.into(); + Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema(column_name, schema).unwrap()), + Operator::Eq, + Arc::new(Literal::new(scalar_value)), + )) +} diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs new file mode 100644 index 0000000000000..f05f3f00281d6 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs @@ -0,0 +1,574 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::SchemaRef; +use arrow::{array::RecordBatch, compute::concat_batches}; +use datafusion::{datasource::object_store::ObjectStoreUrl, physical_plan::PhysicalExpr}; +use datafusion_common::{config::ConfigOptions, internal_err, Result, Statistics}; +use datafusion_datasource::{ + file::FileSource, file_scan_config::FileScanConfig, + file_scan_config::FileScanConfigBuilder, file_stream::FileOpenFuture, + file_stream::FileOpener, schema_adapter::DefaultSchemaAdapterFactory, + schema_adapter::SchemaAdapterFactory, source::DataSourceExec, PartitionedFile, +}; +use datafusion_physical_expr_common::physical_expr::fmt_sql; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::filter::batch_filter; +use datafusion_physical_plan::filter_pushdown::{FilterPushdownPhase, PushedDown}; +use datafusion_physical_plan::{ + displayable, + filter::FilterExec, + filter_pushdown::{ + ChildFilterDescription, ChildPushdownResult, FilterDescription, + FilterPushdownPropagation, + }, + metrics::ExecutionPlanMetricsSet, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, +}; +use futures::StreamExt; +use futures::{FutureExt, Stream}; +use object_store::ObjectStore; +use std::{ + any::Any, + fmt::{Display, Formatter}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +pub struct TestOpener { + batches: Vec, + batch_size: Option, + schema: Option, + projection: Option>, + predicate: Option>, +} + +impl FileOpener for TestOpener { + fn open(&self, _partitioned_file: PartitionedFile) -> Result { + let mut batches = self.batches.clone(); + if let Some(batch_size) = self.batch_size { + let batch = concat_batches(&batches[0].schema(), &batches)?; + let mut new_batches = Vec::new(); + for i in (0..batch.num_rows()).step_by(batch_size) { + let end = std::cmp::min(i + batch_size, batch.num_rows()); + let batch = batch.slice(i, end - i); + new_batches.push(batch); + } + batches = new_batches.into_iter().collect(); + } + if let Some(schema) = &self.schema { + let factory = DefaultSchemaAdapterFactory::from_schema(Arc::clone(schema)); + let (mapper, projection) = factory.map_schema(&batches[0].schema()).unwrap(); + let mut new_batches = Vec::new(); + for batch in batches { + let batch = if let Some(predicate) = &self.predicate { + batch_filter(&batch, predicate)? + } else { + batch + }; + + let batch = batch.project(&projection).unwrap(); + let batch = mapper.map_batch(batch).unwrap(); + new_batches.push(batch); + } + batches = new_batches; + } + if let Some(projection) = &self.projection { + batches = batches + .into_iter() + .map(|batch| batch.project(projection).unwrap()) + .collect(); + } + + let stream = TestStream::new(batches); + + Ok((async { Ok(stream.boxed()) }).boxed()) + } +} + +/// A placeholder data source that accepts filter pushdown +#[derive(Clone, Default)] +pub struct TestSource { + support: bool, + predicate: Option>, + statistics: Option, + batch_size: Option, + batches: Vec, + schema: Option, + metrics: ExecutionPlanMetricsSet, + projection: Option>, + schema_adapter_factory: Option>, +} + +impl TestSource { + pub fn new(support: bool, batches: Vec) -> Self { + Self { + support, + metrics: ExecutionPlanMetricsSet::new(), + batches, + ..Default::default() + } + } +} + +impl FileSource for TestSource { + fn create_file_opener( + &self, + _object_store: Arc, + _base_config: &FileScanConfig, + _partition: usize, + ) -> Arc { + Arc::new(TestOpener { + batches: self.batches.clone(), + batch_size: self.batch_size, + schema: self.schema.clone(), + projection: self.projection.clone(), + predicate: self.predicate.clone(), + }) + } + + fn filter(&self) -> Option> { + self.predicate.clone() + } + + fn as_any(&self) -> &dyn Any { + todo!("should not be called") + } + + fn with_batch_size(&self, batch_size: usize) -> Arc { + Arc::new(TestSource { + batch_size: Some(batch_size), + ..self.clone() + }) + } + + fn with_schema(&self, schema: SchemaRef) -> Arc { + Arc::new(TestSource { + schema: Some(schema), + ..self.clone() + }) + } + + fn with_projection(&self, config: &FileScanConfig) -> Arc { + Arc::new(TestSource { + projection: config.projection.clone(), + ..self.clone() + }) + } + + fn with_statistics(&self, statistics: Statistics) -> Arc { + Arc::new(TestSource { + statistics: Some(statistics), + ..self.clone() + }) + } + + fn metrics(&self) -> &ExecutionPlanMetricsSet { + &self.metrics + } + + fn statistics(&self) -> Result { + Ok(self + .statistics + .as_ref() + .expect("statistics not set") + .clone()) + } + + fn file_type(&self) -> &str { + "test" + } + + fn fmt_extra(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let support = format!(", pushdown_supported={}", self.support); + + let predicate_string = self + .predicate + .as_ref() + .map(|p| format!(", predicate={p}")) + .unwrap_or_default(); + + write!(f, "{support}{predicate_string}") + } + DisplayFormatType::TreeRender => { + if let Some(predicate) = &self.predicate { + writeln!(f, "pushdown_supported={}", fmt_sql(predicate.as_ref()))?; + writeln!(f, "predicate={}", fmt_sql(predicate.as_ref()))?; + } + Ok(()) + } + } + } + + fn try_pushdown_filters( + &self, + mut filters: Vec>, + config: &ConfigOptions, + ) -> Result>> { + if self.support && config.execution.parquet.pushdown_filters { + if let Some(internal) = self.predicate.as_ref() { + filters.push(Arc::clone(internal)); + } + let new_node = Arc::new(TestSource { + predicate: datafusion_physical_expr::utils::conjunction_opt( + filters.clone(), + ), + ..self.clone() + }); + Ok(FilterPushdownPropagation::with_parent_pushdown_result( + vec![PushedDown::Yes; filters.len()], + ) + .with_updated_node(new_node)) + } else { + Ok(FilterPushdownPropagation::with_parent_pushdown_result( + vec![PushedDown::No; filters.len()], + )) + } + } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } +} + +#[derive(Debug, Clone)] +pub struct TestScanBuilder { + support: bool, + batches: Vec, + schema: SchemaRef, +} + +impl TestScanBuilder { + pub fn new(schema: SchemaRef) -> Self { + Self { + support: false, + batches: vec![], + schema, + } + } + + pub fn with_support(mut self, support: bool) -> Self { + self.support = support; + self + } + + pub fn with_batches(mut self, batches: Vec) -> Self { + self.batches = batches; + self + } + + pub fn build(self) -> Arc { + let source = Arc::new(TestSource::new(self.support, self.batches)); + let base_config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("test://").unwrap(), + Arc::clone(&self.schema), + source, + ) + .with_file(PartitionedFile::new("test.parquet", 123)) + .build(); + DataSourceExec::from_data_source(base_config) + } +} + +/// Index into the data that has been returned so far +#[derive(Debug, Default, Clone)] +pub struct BatchIndex { + inner: Arc>, +} + +impl BatchIndex { + /// Return the current index + pub fn value(&self) -> usize { + let inner = self.inner.lock().unwrap(); + *inner + } + + // increment the current index by one + pub fn incr(&self) { + let mut inner = self.inner.lock().unwrap(); + *inner += 1; + } +} + +/// Iterator over batches +#[derive(Debug, Default)] +pub struct TestStream { + /// Vector of record batches + data: Vec, + /// Index into the data that has been returned so far + index: BatchIndex, +} + +impl TestStream { + /// Create an iterator for a vector of record batches. Assumes at + /// least one entry in data (for the schema) + pub fn new(data: Vec) -> Self { + // check that there is at least one entry in data and that all batches have the same schema + assert!(!data.is_empty(), "data must not be empty"); + assert!( + data.iter().all(|batch| batch.schema() == data[0].schema()), + "all batches must have the same schema" + ); + Self { + data, + ..Default::default() + } + } +} + +impl Stream for TestStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + let next_batch = self.index.value(); + + Poll::Ready(if next_batch < self.data.len() { + let next_batch = self.index.value(); + self.index.incr(); + Some(Ok(self.data[next_batch].clone())) + } else { + None + }) + } + + fn size_hint(&self) -> (usize, Option) { + (self.data.len(), Some(self.data.len())) + } +} + +/// A harness for testing physical optimizers. +/// +/// You can use this to test the output of a physical optimizer rule using insta snapshots +#[derive(Debug)] +pub struct OptimizationTest { + input: Vec, + output: Result, String>, +} + +impl OptimizationTest { + pub fn new( + input_plan: Arc, + opt: O, + allow_pushdown_filters: bool, + ) -> Self + where + O: PhysicalOptimizerRule, + { + let mut parquet_pushdown_config = ConfigOptions::default(); + parquet_pushdown_config.execution.parquet.pushdown_filters = + allow_pushdown_filters; + + let input = format_execution_plan(&input_plan); + let input_schema = input_plan.schema(); + + let output_result = opt.optimize(input_plan, &parquet_pushdown_config); + let output = output_result + .and_then(|plan| { + if opt.schema_check() && (plan.schema() != input_schema) { + internal_err!( + "Schema mismatch:\n\nBefore:\n{:?}\n\nAfter:\n{:?}", + input_schema, + plan.schema() + ) + } else { + Ok(plan) + } + }) + .map(|plan| format_execution_plan(&plan)) + .map_err(|e| e.to_string()); + + Self { input, output } + } +} + +impl Display for OptimizationTest { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + writeln!(f, "OptimizationTest:")?; + writeln!(f, " input:")?; + for line in &self.input { + writeln!(f, " - {line}")?; + } + writeln!(f, " output:")?; + match &self.output { + Ok(output) => { + writeln!(f, " Ok:")?; + for line in output { + writeln!(f, " - {line}")?; + } + } + Err(err) => { + writeln!(f, " Err: {err}")?; + } + } + Ok(()) + } +} + +pub fn format_execution_plan(plan: &Arc) -> Vec { + format_lines(&displayable(plan.as_ref()).indent(false).to_string()) +} + +fn format_lines(s: &str) -> Vec { + s.trim().split('\n').map(|s| s.to_string()).collect() +} + +pub fn format_plan_for_test(plan: &Arc) -> String { + let mut out = String::new(); + for line in format_execution_plan(plan) { + out.push_str(&format!(" - {line}\n")); + } + out.push('\n'); + out +} + +#[derive(Debug)] +pub(crate) struct TestNode { + inject_filter: bool, + input: Arc, + predicate: Arc, +} + +impl TestNode { + pub fn new( + inject_filter: bool, + input: Arc, + predicate: Arc, + ) -> Self { + Self { + inject_filter, + input, + predicate, + } + } +} + +impl DisplayAs for TestNode { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!( + f, + "TestInsertExec {{ inject_filter: {} }}", + self.inject_filter + ) + } +} + +impl ExecutionPlan for TestNode { + fn name(&self) -> &str { + "TestInsertExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + self.input.properties() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert!(children.len() == 1); + Ok(Arc::new(TestNode::new( + self.inject_filter, + children[0].clone(), + self.predicate.clone(), + ))) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!("TestInsertExec is a stub for testing.") + } + + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + // Since TestNode marks all parent filters as supported and adds its own filter, + // we use from_child to create a description with all parent filters supported + let child = &self.input; + let child_desc = ChildFilterDescription::from_child(&parent_filters, child)? + .with_self_filter(Arc::clone(&self.predicate)); + Ok(FilterDescription::new().with_child(child_desc)) + } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + if self.inject_filter { + // Add a FilterExec if our own filter was not handled by the child + + // We have 1 child + assert_eq!(child_pushdown_result.self_filters.len(), 1); + let self_pushdown_result = child_pushdown_result.self_filters[0].clone(); + // And pushed down 1 filter + assert_eq!(self_pushdown_result.len(), 1); + let self_pushdown_result: Vec<_> = self_pushdown_result.into_iter().collect(); + + let first_pushdown_result = self_pushdown_result[0].clone(); + + match &first_pushdown_result.discriminant { + PushedDown::No => { + // We have a filter to push down + let new_child = FilterExec::try_new( + Arc::clone(&first_pushdown_result.predicate), + Arc::clone(&self.input), + )?; + let new_self = + TestNode::new(false, Arc::new(new_child), self.predicate.clone()); + let mut res = + FilterPushdownPropagation::if_all(child_pushdown_result); + res.updated_node = Some(Arc::new(new_self) as Arc); + Ok(res) + } + PushedDown::Yes => { + let res = FilterPushdownPropagation::if_all(child_pushdown_result); + Ok(res) + } + } + } else { + let res = FilterPushdownPropagation::if_all(child_pushdown_result); + Ok(res) + } + } +} diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs index d3b6ec700beec..f9d3a045469e1 100644 --- a/datafusion/core/tests/physical_optimizer/join_selection.rs +++ b/datafusion/core/tests/physical_optimizer/join_selection.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use insta::assert_snapshot; use std::sync::Arc; use std::{ any::Any, @@ -25,8 +26,8 @@ use std::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::config::ConfigOptions; -use datafusion_common::JoinSide; use datafusion_common::{stats::Precision, ColumnStatistics, JoinType, ScalarValue}; +use datafusion_common::{JoinSide, NullEquality}; use datafusion_common::{Result, Statistics}; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; use datafusion_expr::Operator; @@ -35,9 +36,7 @@ use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr}; use datafusion_physical_expr::intervals::utils::check_support; use datafusion_physical_expr::PhysicalExprRef; use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalExpr}; -use datafusion_physical_optimizer::join_selection::{ - hash_join_swap_subrule, JoinSelection, -}; +use datafusion_physical_optimizer::join_selection::JoinSelection; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::displayable; use datafusion_physical_plan::joins::utils::ColumnIndex; @@ -222,7 +221,7 @@ async fn test_join_with_swap() { &JoinType::Left, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ); @@ -237,12 +236,12 @@ async fn test_join_with_swap() { .expect("A proj is required to swap columns back to their original order"); assert_eq!(swapping_projection.expr().len(), 2); - let (col, name) = &swapping_projection.expr()[0]; - assert_eq!(name, "big_col"); - assert_col_expr(col, "big_col", 1); - let (col, name) = &swapping_projection.expr()[1]; - assert_eq!(name, "small_col"); - assert_col_expr(col, "small_col", 0); + let proj_expr = &swapping_projection.expr()[0]; + assert_eq!(proj_expr.alias, "big_col"); + assert_col_expr(&proj_expr.expr, "big_col", 1); + let proj_expr = &swapping_projection.expr()[1]; + assert_eq!(proj_expr.alias, "small_col"); + assert_col_expr(&proj_expr.expr, "small_col", 0); let swapped_join = swapping_projection .input() @@ -251,11 +250,19 @@ async fn test_join_with_swap() { .expect("The type of the plan should not be changed"); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); } @@ -276,7 +283,7 @@ async fn test_left_join_no_swap() { &JoinType::Left, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ); @@ -291,11 +298,19 @@ async fn test_left_join_no_swap() { .expect("The type of the plan should not be changed"); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); } @@ -317,7 +332,7 @@ async fn test_join_with_swap_semi() { &join_type, None, PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); @@ -336,11 +351,74 @@ async fn test_join_with_swap_semi() { assert_eq!(swapped_join.schema().fields().len(), 1); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, + Precision::Inexact(8192) + ); + assert_eq!( + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, + Precision::Inexact(2097152) + ); + assert_eq!(original_schema, swapped_join.schema()); + } +} + +#[tokio::test] +async fn test_join_with_swap_mark() { + let join_types = [JoinType::LeftMark, JoinType::RightMark]; + for join_type in join_types { + let (big, small) = create_big_and_small(); + + let join = HashJoinExec::try_new( + Arc::clone(&big), + Arc::clone(&small), + vec![( + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()), + Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()), + )], + None, + &join_type, + None, + PartitionMode::Partitioned, + NullEquality::NullEqualsNothing, + ) + .unwrap(); + + let original_schema = join.schema(); + + let optimized_join = JoinSelection::new() + .optimize(Arc::new(join), &ConfigOptions::new()) + .unwrap(); + + let swapped_join = optimized_join + .as_any() + .downcast_ref::() + .expect( + "A proj is not required to swap columns back to their original order", + ); + + assert_eq!(swapped_join.schema().fields().len(), 2); + assert_eq!( + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); assert_eq!(original_schema, swapped_join.schema()); @@ -349,8 +427,7 @@ async fn test_join_with_swap_semi() { /// Compare the input plan with the plan after running the probe order optimizer. macro_rules! assert_optimized { - ($EXPECTED_LINES: expr, $PLAN: expr) => { - let expected_lines = $EXPECTED_LINES.iter().map(|s| *s).collect::>(); + ($PLAN: expr, @$EXPECTED_LINES: literal $(,)?) => { let plan = Arc::new($PLAN); let optimized = JoinSelection::new() @@ -358,12 +435,11 @@ macro_rules! assert_optimized { .unwrap(); let plan_string = displayable(optimized.as_ref()).indent(true).to_string(); - let actual_lines = plan_string.split("\n").collect::>(); + let actual = plan_string.trim(); - assert_eq!( - &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines + assert_snapshot!( + actual, + @$EXPECTED_LINES ); }; } @@ -384,7 +460,7 @@ async fn test_nested_join_swap() { &JoinType::Inner, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); let child_schema = child_join.schema(); @@ -401,7 +477,7 @@ async fn test_nested_join_swap() { &JoinType::Left, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); @@ -412,17 +488,18 @@ async fn test_nested_join_swap() { // The first hash join's left is 'small' table (with 1000 rows), and the second hash join's // left is the F(small IJ big) which has an estimated cardinality of 2000 rows (vs medium which // has an exact cardinality of 10_000 rows). - let expected = [ - "ProjectionExec: expr=[medium_col@2 as medium_col, big_col@0 as big_col, small_col@1 as small_col]", - " HashJoinExec: mode=CollectLeft, join_type=Right, on=[(small_col@1, medium_col@0)]", - " ProjectionExec: expr=[big_col@1 as big_col, small_col@0 as small_col]", - " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(small_col@0, big_col@0)]", - " StatisticsExec: col_count=1, row_count=Inexact(1000)", - " StatisticsExec: col_count=1, row_count=Inexact(100000)", - " StatisticsExec: col_count=1, row_count=Inexact(10000)", - "", - ]; - assert_optimized!(expected, join); + assert_optimized!( + join, + @r" + ProjectionExec: expr=[medium_col@2 as medium_col, big_col@0 as big_col, small_col@1 as small_col] + HashJoinExec: mode=CollectLeft, join_type=Right, on=[(small_col@1, medium_col@0)] + ProjectionExec: expr=[big_col@1 as big_col, small_col@0 as small_col] + HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(small_col@0, big_col@0)] + StatisticsExec: col_count=1, row_count=Inexact(1000) + StatisticsExec: col_count=1, row_count=Inexact(100000) + StatisticsExec: col_count=1, row_count=Inexact(10000) + " + ); } #[tokio::test] @@ -440,7 +517,7 @@ async fn test_join_no_swap() { &JoinType::Inner, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ); @@ -455,11 +532,19 @@ async fn test_join_no_swap() { .expect("The type of the plan should not be changed"); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); } @@ -496,12 +581,12 @@ async fn test_nl_join_with_swap(join_type: JoinType) { .expect("A proj is required to swap columns back to their original order"); assert_eq!(swapping_projection.expr().len(), 2); - let (col, name) = &swapping_projection.expr()[0]; - assert_eq!(name, "big_col"); - assert_col_expr(col, "big_col", 1); - let (col, name) = &swapping_projection.expr()[1]; - assert_eq!(name, "small_col"); - assert_col_expr(col, "small_col", 0); + let proj_expr = &swapping_projection.expr()[0]; + assert_eq!(proj_expr.alias, "big_col"); + assert_col_expr(&proj_expr.expr, "big_col", 1); + let proj_expr = &swapping_projection.expr()[1]; + assert_eq!(proj_expr.alias, "small_col"); + assert_col_expr(&proj_expr.expr, "small_col", 0); let swapped_join = swapping_projection .input() @@ -524,11 +609,19 @@ async fn test_nl_join_with_swap(join_type: JoinType) { ); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); } @@ -538,7 +631,8 @@ async fn test_nl_join_with_swap(join_type: JoinType) { case::left_semi(JoinType::LeftSemi), case::left_anti(JoinType::LeftAnti), case::right_semi(JoinType::RightSemi), - case::right_anti(JoinType::RightAnti) + case::right_anti(JoinType::RightAnti), + case::right_mark(JoinType::RightMark) )] #[tokio::test] async fn test_nl_join_with_swap_no_proj(join_type: JoinType) { @@ -589,11 +683,19 @@ async fn test_nl_join_with_swap_no_proj(join_type: JoinType) { ); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); } @@ -642,7 +744,7 @@ async fn test_hash_join_swap_on_joins_with_projections( &join_type, Some(projection), PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, )?); let swapped = join @@ -803,7 +905,7 @@ fn check_join_partition_mode( &JoinType::Inner, None, PartitionMode::Auto, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ); @@ -1067,6 +1169,14 @@ impl ExecutionPlan for StatisticsExec { fn statistics(&self) -> Result { Ok(self.stats.clone()) } + + fn partition_statistics(&self, partition: Option) -> Result { + Ok(if partition.is_some() { + Statistics::new_unknown(&self.schema) + } else { + self.stats.clone() + }) + } } #[test] @@ -1442,10 +1552,11 @@ async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> { &t.initial_join_type, None, t.initial_mode, - false, + NullEquality::NullEqualsNothing, )?) as _; - let optimized_join_plan = hash_join_swap_subrule(join, &ConfigOptions::new())?; + let optimized_join_plan = + JoinSelection::new().optimize(Arc::clone(&join), &ConfigOptions::new())?; // If swap did happen let projection_added = optimized_join_plan.as_any().is::(); diff --git a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs index dd2c1960a6580..56d48901f284d 100644 --- a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs @@ -17,28 +17,26 @@ use std::sync::Arc; +use crate::physical_optimizer::test_utils::{ + coalesce_batches_exec, coalesce_partitions_exec, global_limit_exec, local_limit_exec, + sort_exec, sort_preserving_merge_exec, stream_exec, +}; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::BinaryExpr; -use datafusion_physical_expr::expressions::{col, lit}; -use datafusion_physical_expr::{Partitioning, PhysicalSortExpr}; +use datafusion_physical_expr::expressions::{col, lit, BinaryExpr}; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_optimizer::limit_pushdown::LimitPushdown; use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::filter::FilterExec; -use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::sorts::sort::SortExec; -use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; -use datafusion_physical_plan::{get_plan_string, ExecutionPlan, ExecutionPlanProperties}; +use datafusion_physical_plan::{get_plan_string, ExecutionPlan}; fn create_schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -48,48 +46,6 @@ fn create_schema() -> SchemaRef { ])) } -fn streaming_table_exec(schema: SchemaRef) -> Result> { - Ok(Arc::new(StreamingTableExec::try_new( - Arc::clone(&schema), - vec![Arc::new(DummyStreamPartition { schema }) as _], - None, - None, - true, - None, - )?)) -} - -fn global_limit_exec( - input: Arc, - skip: usize, - fetch: Option, -) -> Arc { - Arc::new(GlobalLimitExec::new(input, skip, fetch)) -} - -fn local_limit_exec( - input: Arc, - fetch: usize, -) -> Arc { - Arc::new(LocalLimitExec::new(input, fetch)) -} - -fn sort_exec( - sort_exprs: impl IntoIterator, - input: Arc, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortExec::new(sort_exprs, input)) -} - -fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, - input: Arc, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) -} - fn projection_exec( schema: SchemaRef, input: Arc, @@ -118,16 +74,6 @@ fn filter_exec( )?)) } -fn coalesce_batches_exec(input: Arc) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, 8192)) -} - -fn coalesce_partitions_exec( - local_limit: Arc, -) -> Arc { - Arc::new(CoalescePartitionsExec::new(local_limit)) -} - fn repartition_exec( streaming_table: Arc, ) -> Result> { @@ -141,24 +87,11 @@ fn empty_exec(schema: SchemaRef) -> Arc { Arc::new(EmptyExec::new(schema)) } -#[derive(Debug)] -struct DummyStreamPartition { - schema: SchemaRef, -} -impl PartitionStream for DummyStreamPartition { - fn schema(&self) -> &SchemaRef { - &self.schema - } - fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { - unreachable!() - } -} - #[test] fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(schema)?; + let streaming_table = stream_exec(&schema); let global_limit = global_limit_exec(streaming_table, 0, Some(5)); let initial = get_plan_string(&global_limit); @@ -183,7 +116,7 @@ fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() -> fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_limit_when_skip_is_nonzero( ) -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(schema)?; + let streaming_table = stream_exec(&schema); let global_limit = global_limit_exec(streaming_table, 2, Some(5)); let initial = get_plan_string(&global_limit); @@ -209,10 +142,10 @@ fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_li fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limit( ) -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema))?; + let streaming_table = stream_exec(&schema); let repartition = repartition_exec(streaming_table)?; let filter = filter_exec(schema, repartition)?; - let coalesce_batches = coalesce_batches_exec(filter); + let coalesce_batches = coalesce_batches_exec(filter, 8192); let local_limit = local_limit_exec(coalesce_batches, 5); let coalesce_partitions = coalesce_partitions_exec(local_limit); let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); @@ -247,7 +180,7 @@ fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limi #[test] fn pushes_global_limit_exec_through_projection_exec() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema))?; + let streaming_table = stream_exec(&schema); let filter = filter_exec(Arc::clone(&schema), streaming_table)?; let projection = projection_exec(schema, filter)?; let global_limit = global_limit_exec(projection, 0, Some(5)); @@ -279,8 +212,8 @@ fn pushes_global_limit_exec_through_projection_exec() -> Result<()> { fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batches_exec_into_fetching_version( ) -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema)).unwrap(); - let coalesce_batches = coalesce_batches_exec(streaming_table); + let streaming_table = stream_exec(&schema); + let coalesce_batches = coalesce_batches_exec(streaming_table, 8192); let projection = projection_exec(schema, coalesce_batches)?; let global_limit = global_limit_exec(projection, 0, Some(5)); @@ -310,18 +243,17 @@ fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batc #[test] fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema)).unwrap(); - let coalesce_batches = coalesce_batches_exec(streaming_table); + let streaming_table = stream_exec(&schema); + let coalesce_batches = coalesce_batches_exec(streaming_table, 8192); let projection = projection_exec(Arc::clone(&schema), coalesce_batches)?; let repartition = repartition_exec(projection)?; - let sort = sort_exec( - vec![PhysicalSortExpr { - expr: col("c1", &schema)?, - options: SortOptions::default(), - }], - repartition, - ); - let spm = sort_preserving_merge_exec(sort.output_ordering().unwrap().to_vec(), sort); + let ordering: LexOrdering = [PhysicalSortExpr { + expr: col("c1", &schema)?, + options: SortOptions::default(), + }] + .into(); + let sort = sort_exec(ordering.clone(), repartition); + let spm = sort_preserving_merge_exec(ordering, sort); let global_limit = global_limit_exec(spm, 0, Some(5)); let initial = get_plan_string(&global_limit); @@ -357,7 +289,7 @@ fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { fn keeps_pushed_local_limit_exec_when_there_are_multiple_input_partitions() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema))?; + let streaming_table = stream_exec(&schema); let repartition = repartition_exec(streaming_table)?; let filter = filter_exec(schema, repartition)?; let coalesce_partitions = coalesce_partitions_exec(filter); diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index f9810eab8f594..ad15d6803413b 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -17,11 +17,12 @@ //! Integration tests for [`LimitedDistinctAggregation`] physical optimizer rule +use insta::assert_snapshot; use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - assert_plan_matches_expected, build_group_by, mock_data, parquet_exec_with_sort, - schema, TestAggregate, + build_group_by, get_optimized_plan, mock_data, parquet_exec_with_sort, schema, + TestAggregate, }; use arrow::datatypes::DataType; @@ -30,9 +31,8 @@ use datafusion::prelude::SessionContext; use datafusion_common::Result; use datafusion_execution::config::SessionConfig; use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::cast; -use datafusion_physical_expr::{expressions, expressions::col, PhysicalSortExpr}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr::expressions::{self, cast, col}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::{ aggregates::{AggregateExec, AggregateMode}, collect, @@ -40,16 +40,12 @@ use datafusion_physical_plan::{ ExecutionPlan, }; -async fn assert_results_match_expected( - plan: Arc, - expected: &str, -) -> Result<()> { +async fn run_plan_and_format(plan: Arc) -> Result { let cfg = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::new_with_config(cfg); let batches = collect(plan, ctx.task_ctx()).await?; let actual = format!("{}", pretty_format_batches(&batches)?); - assert_eq!(actual, expected); - Ok(()) + Ok(actual) } #[tokio::test] @@ -78,27 +74,33 @@ async fn test_partial_final() -> Result<()> { Arc::new(final_agg), 4, // fetch ); - // expected to push the limit to the Partial and Final AggregateExecs - let expected = [ - "LocalLimitExec: fetch=4", - "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[], lim=[4]", - "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[], lim=[4]", - "DataSourceExec: partitions=1, partition_sizes=[1]", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - let expected = r#" -+---+ -| a | -+---+ -| 1 | -| 2 | -| | -| 4 | -+---+ -"# - .trim(); - assert_results_match_expected(plan, expected).await?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=4 + AggregateExec: mode=Final, gby=[a@0 as a], aggr=[], lim=[4] + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[], lim=[4] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); + let expected = run_plan_and_format(plan).await?; + assert_snapshot!( + expected, + @r" + +---+ + | a | + +---+ + | 1 | + | 2 | + | | + | 4 | + +---+ + " + ); + Ok(()) } @@ -121,25 +123,31 @@ async fn test_single_local() -> Result<()> { 4, // fetch ); // expected to push the limit to the AggregateExec - let expected = [ - "LocalLimitExec: fetch=4", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", - "DataSourceExec: partitions=1, partition_sizes=[1]", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - let expected = r#" -+---+ -| a | -+---+ -| 1 | -| 2 | -| | -| 4 | -+---+ -"# - .trim(); - assert_results_match_expected(plan, expected).await?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=4 + AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); + let expected = run_plan_and_format(plan).await?; + assert_snapshot!( + expected, + @r" + +---+ + | a | + +---+ + | 1 | + | 2 | + | | + | 4 | + +---+ + " + ); Ok(()) } @@ -163,24 +171,30 @@ async fn test_single_global() -> Result<()> { Some(3), // fetch ); // expected to push the skip+fetch limit to the AggregateExec - let expected = [ - "GlobalLimitExec: skip=1, fetch=3", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", - "DataSourceExec: partitions=1, partition_sizes=[1]", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - let expected = r#" -+---+ -| a | -+---+ -| 2 | -| | -| 4 | -+---+ -"# - .trim(); - assert_results_match_expected(plan, expected).await?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + GlobalLimitExec: skip=1, fetch=3 + AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); + let expected = run_plan_and_format(plan).await?; + assert_snapshot!( + expected, + @r" + +---+ + | a | + +---+ + | 2 | + | | + | 4 | + +---+ + " + ); Ok(()) } @@ -211,37 +225,44 @@ async fn test_distinct_cols_different_than_group_by_cols() -> Result<()> { 4, // fetch ); // expected to push the limit to the outer AggregateExec only - let expected = [ - "LocalLimitExec: fetch=4", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", - "AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[]", - "DataSourceExec: partitions=1, partition_sizes=[1]", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - let expected = r#" -+---+ -| a | -+---+ -| 1 | -| 2 | -| | -| 4 | -+---+ -"# - .trim(); - assert_results_match_expected(plan, expected).await?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=4 + AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4] + AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); + let expected = run_plan_and_format(plan).await?; + assert_snapshot!( + expected, + @r" + +---+ + | a | + +---+ + | 1 | + | 2 | + | | + | 4 | + +---+ + " + ); Ok(()) } #[test] fn test_has_order_by() -> Result<()> { - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema()).unwrap(), + let schema = schema(); + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); - let source = parquet_exec_with_sort(vec![sort_key]); - let schema = source.schema(); + }] + .into(); + let source = parquet_exec_with_sort(schema.clone(), vec![sort_key]); // `SELECT a FROM DataSourceExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec // the `a > 1` filter is applied in the AggregateExec @@ -258,13 +279,17 @@ fn test_has_order_by() -> Result<()> { 10, // fetch ); // expected not to push the limit to the AggregateExec - let expected = [ - "LocalLimitExec: fetch=10", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], ordering_mode=Sorted", - "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=10 + AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], ordering_mode=Sorted + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet + " + ); Ok(()) } @@ -287,13 +312,17 @@ fn test_no_group_by() -> Result<()> { 10, // fetch ); // expected not to push the limit to the AggregateExec - let expected = [ - "LocalLimitExec: fetch=10", - "AggregateExec: mode=Single, gby=[], aggr=[]", - "DataSourceExec: partitions=1, partition_sizes=[1]", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=10 + AggregateExec: mode=Single, gby=[], aggr=[] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); Ok(()) } @@ -317,13 +346,17 @@ fn test_has_aggregate_expression() -> Result<()> { 10, // fetch ); // expected not to push the limit to the AggregateExec - let expected = [ - "LocalLimitExec: fetch=10", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]", - "DataSourceExec: partitions=1, partition_sizes=[1]", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=10 + AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); Ok(()) } @@ -355,12 +388,16 @@ fn test_has_filter() -> Result<()> { ); // expected not to push the limit to the AggregateExec // TODO(msirek): open an issue for `filter_expr` of `AggregateExec` not printing out - let expected = [ - "LocalLimitExec: fetch=10", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]", - "DataSourceExec: partitions=1, partition_sizes=[1]", - ]; let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=10 + AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); Ok(()) } diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs index 7d5d07715eebc..777c26e80e902 100644 --- a/datafusion/core/tests/physical_optimizer/mod.rs +++ b/datafusion/core/tests/physical_optimizer/mod.rs @@ -21,10 +21,13 @@ mod aggregate_statistics; mod combine_partial_final_agg; mod enforce_distribution; mod enforce_sorting; +mod filter_pushdown; mod join_selection; mod limit_pushdown; mod limited_distinct_aggregation; +mod partition_statistics; mod projection_pushdown; mod replace_with_order_preserving_variants; mod sanity_checker; mod test_utils; +mod window_optimize; diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs new file mode 100644 index 0000000000000..62ab5cbc422be --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs @@ -0,0 +1,979 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{Int32Array, RecordBatch}; + use arrow_schema::{DataType, Field, Schema, SortOptions}; + use datafusion::datasource::listing::ListingTable; + use datafusion::prelude::SessionContext; + use datafusion_catalog::TableProvider; + use datafusion_common::stats::Precision; + use datafusion_common::Result; + use datafusion_common::{ColumnStatistics, ScalarValue, Statistics}; + use datafusion_execution::config::SessionConfig; + use datafusion_execution::TaskContext; + use datafusion_expr_common::operator::Operator; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::expressions::{binary, col, lit, Column}; + use datafusion_physical_expr::Partitioning; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, + }; + use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; + use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; + use datafusion_physical_plan::common::compute_record_batch_statistics; + use datafusion_physical_plan::empty::EmptyExec; + use datafusion_physical_plan::filter::FilterExec; + use datafusion_physical_plan::joins::CrossJoinExec; + use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; + use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; + use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; + use datafusion_physical_plan::repartition::RepartitionExec; + use datafusion_physical_plan::sorts::sort::SortExec; + use datafusion_physical_plan::union::{InterleaveExec, UnionExec}; + use datafusion_physical_plan::{ + execute_stream_partitioned, get_plan_string, ExecutionPlan, + ExecutionPlanProperties, + }; + + use futures::TryStreamExt; + + /// Creates a test table with statistics from the test data directory. + /// + /// This function: + /// - Creates an external table from './tests/data/test_statistics_per_partition' + /// - If we set the `target_partition` to 2, the data contains 2 partitions, each with 2 rows + /// - Each partition has an "id" column (INT) with the following values: + /// - First partition: [3, 4] + /// - Second partition: [1, 2] + /// - Each row is 110 bytes in size + /// + /// @param create_table_sql Optional parameter to set the create table SQL + /// @param target_partition Optional parameter to set the target partitions + /// @return ExecutionPlan representing the scan of the table with statistics + async fn create_scan_exec_with_statistics( + create_table_sql: Option<&str>, + target_partition: Option, + ) -> Arc { + let mut session_config = SessionConfig::new().with_collect_statistics(true); + if let Some(partition) = target_partition { + session_config = session_config.with_target_partitions(partition); + } + let ctx = SessionContext::new_with_config(session_config); + // Create table with partition + let create_table_sql = create_table_sql.unwrap_or( + "CREATE EXTERNAL TABLE t1 (id INT NOT NULL, date DATE) \ + STORED AS PARQUET LOCATION './tests/data/test_statistics_per_partition'\ + PARTITIONED BY (date) \ + WITH ORDER (id ASC);", + ); + // Get table name from `create_table_sql` + let table_name = create_table_sql + .split_whitespace() + .nth(3) + .unwrap_or("t1") + .to_string(); + ctx.sql(create_table_sql) + .await + .unwrap() + .collect() + .await + .unwrap(); + let table = ctx.table_provider(table_name.as_str()).await.unwrap(); + let listing_table = table + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + listing_table + .scan(&ctx.state(), None, &[], None) + .await + .unwrap() + } + + /// Helper function to create expected statistics for a partition with Int32 column + fn create_partition_statistics( + num_rows: usize, + total_byte_size: usize, + min_value: i32, + max_value: i32, + include_date_column: bool, + ) -> Statistics { + let mut column_stats = vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(max_value))), + min_value: Precision::Exact(ScalarValue::Int32(Some(min_value))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }]; + + if include_date_column { + column_stats.push(ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }); + } + + Statistics { + num_rows: Precision::Exact(num_rows), + total_byte_size: Precision::Exact(total_byte_size), + column_statistics: column_stats, + } + } + + #[derive(PartialEq, Eq, Debug)] + enum ExpectedStatistics { + Empty, // row_count == 0 + NonEmpty(i32, i32, usize), // (min_id, max_id, row_count) + } + + /// Helper function to validate that statistics from statistics_by_partition match the actual data + async fn validate_statistics_with_data( + plan: Arc, + expected_stats: Vec, + id_column_index: usize, + ) -> Result<()> { + let ctx = TaskContext::default(); + let partitions = execute_stream_partitioned(plan, Arc::new(ctx))?; + + let mut actual_stats = Vec::new(); + for partition_stream in partitions.into_iter() { + let result: Vec = partition_stream.try_collect().await?; + + let mut min_id = i32::MAX; + let mut max_id = i32::MIN; + let mut row_count = 0; + + for batch in result { + if batch.num_columns() > id_column_index { + let id_array = batch + .column(id_column_index) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + let id_value = id_array.value(i); + min_id = min_id.min(id_value); + max_id = max_id.max(id_value); + row_count += 1; + } + } + } + + if row_count == 0 { + actual_stats.push(ExpectedStatistics::Empty); + } else { + actual_stats + .push(ExpectedStatistics::NonEmpty(min_id, max_id, row_count)); + } + } + + // Compare actual data with expected statistics + assert_eq!( + actual_stats.len(), + expected_stats.len(), + "Number of partitions with data doesn't match expected" + ); + for i in 0..actual_stats.len() { + assert_eq!( + actual_stats[i], expected_stats[i], + "Partition {i} data doesn't match statistics" + ); + } + + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_data_source() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let statistics = (0..scan.output_partitioning().partition_count()) + .map(|idx| scan.partition_statistics(Some(idx))) + .collect::>>()?; + let expected_statistic_partition_1 = + create_partition_statistics(2, 110, 3, 4, true); + let expected_statistic_partition_2 = + create_partition_statistics(2, 110, 1, 2, true); + // Check the statistics of each partition + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), // (min_id, max_id, row_count) for first partition + ExpectedStatistics::NonEmpty(1, 2, 2), // (min_id, max_id, row_count) for second partition + ]; + validate_statistics_with_data(scan, expected_stats, 0).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_projection() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + // Add projection execution plan + let exprs = vec![ProjectionExpr { + expr: Arc::new(Column::new("id", 0)) as Arc, + alias: "id".to_string(), + }]; + let projection: Arc = + Arc::new(ProjectionExec::try_new(exprs, scan)?); + let statistics = (0..projection.output_partitioning().partition_count()) + .map(|idx| projection.partition_statistics(Some(idx))) + .collect::>>()?; + let expected_statistic_partition_1 = + create_partition_statistics(2, 8, 3, 4, false); + let expected_statistic_partition_2 = + create_partition_statistics(2, 8, 1, 2, false); + // Check the statistics of each partition + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ]; + validate_statistics_with_data(projection, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_sort() -> Result<()> { + let scan_1 = create_scan_exec_with_statistics(None, Some(1)).await; + // Add sort execution plan + let ordering = [PhysicalSortExpr::new( + Arc::new(Column::new("id", 0)), + SortOptions::new(false, false), + )]; + let sort = SortExec::new(ordering.clone().into(), scan_1); + let sort_exec: Arc = Arc::new(sort); + let statistics = (0..sort_exec.output_partitioning().partition_count()) + .map(|idx| sort_exec.partition_statistics(Some(idx))) + .collect::>>()?; + let expected_statistic_partition = + create_partition_statistics(4, 220, 1, 4, true); + assert_eq!(statistics.len(), 1); + assert_eq!(statistics[0], expected_statistic_partition); + // Check the statistics_by_partition with real results + let expected_stats = vec![ExpectedStatistics::NonEmpty(1, 4, 4)]; + validate_statistics_with_data(sort_exec.clone(), expected_stats, 0).await?; + + // Sort with preserve_partitioning + let scan_2 = create_scan_exec_with_statistics(None, Some(2)).await; + // Add sort execution plan + let sort_exec: Arc = Arc::new( + SortExec::new(ordering.into(), scan_2).with_preserve_partitioning(true), + ); + let expected_statistic_partition_1 = + create_partition_statistics(2, 110, 3, 4, true); + let expected_statistic_partition_2 = + create_partition_statistics(2, 110, 1, 2, true); + let statistics = (0..sort_exec.output_partitioning().partition_count()) + .map(|idx| sort_exec.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ]; + validate_statistics_with_data(sort_exec, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_filter() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + let predicate = binary( + Arc::new(Column::new("id", 0)), + Operator::Lt, + lit(1i32), + &schema, + )?; + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, scan)?); + let full_statistics = filter.partition_statistics(None)?; + let expected_full_statistic = Statistics { + num_rows: Precision::Inexact(0), + total_byte_size: Precision::Inexact(0), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Null), + min_value: Precision::Exact(ScalarValue::Null), + sum_value: Precision::Exact(ScalarValue::Null), + distinct_count: Precision::Exact(0), + }, + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Null), + min_value: Precision::Exact(ScalarValue::Null), + sum_value: Precision::Exact(ScalarValue::Null), + distinct_count: Precision::Exact(0), + }, + ], + }; + assert_eq!(full_statistics, expected_full_statistic); + + let statistics = (0..filter.output_partitioning().partition_count()) + .map(|idx| filter.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_full_statistic); + assert_eq!(statistics[1], expected_full_statistic); + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_union() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let union_exec: Arc = + UnionExec::try_new(vec![scan.clone(), scan])?; + let statistics = (0..union_exec.output_partitioning().partition_count()) + .map(|idx| union_exec.partition_statistics(Some(idx))) + .collect::>>()?; + // Check that we have 4 partitions (2 from each scan) + assert_eq!(statistics.len(), 4); + let expected_statistic_partition_1 = + create_partition_statistics(2, 110, 3, 4, true); + let expected_statistic_partition_2 = + create_partition_statistics(2, 110, 1, 2, true); + // Verify first partition (from first scan) + assert_eq!(statistics[0], expected_statistic_partition_1); + // Verify second partition (from first scan) + assert_eq!(statistics[1], expected_statistic_partition_2); + // Verify third partition (from second scan - same as first partition) + assert_eq!(statistics[2], expected_statistic_partition_1); + // Verify fourth partition (from second scan - same as second partition) + assert_eq!(statistics[3], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ]; + validate_statistics_with_data(union_exec, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_interleave() -> Result<()> { + let scan1 = create_scan_exec_with_statistics(None, Some(1)).await; + let scan2 = create_scan_exec_with_statistics(None, Some(1)).await; + + // Create same hash partitioning on the 'id' column as InterleaveExec + // requires all children have a consistent hash partitioning + let hash_expr1 = vec![col("id", &scan1.schema())?]; + let repartition1 = Arc::new(RepartitionExec::try_new( + scan1, + Partitioning::Hash(hash_expr1, 2), + )?); + let hash_expr2 = vec![col("id", &scan2.schema())?]; + let repartition2 = Arc::new(RepartitionExec::try_new( + scan2, + Partitioning::Hash(hash_expr2, 2), + )?); + + let interleave: Arc = + Arc::new(InterleaveExec::try_new(vec![repartition1, repartition2])?); + + // Verify the result of partition statistics + let stats = (0..interleave.output_partitioning().partition_count()) + .map(|idx| interleave.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(stats.len(), 2); + + let expected_stats = Statistics { + num_rows: Precision::Inexact(4), + total_byte_size: Precision::Inexact(220), + column_statistics: vec![ + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + assert_eq!(stats[0], expected_stats); + assert_eq!(stats[1], expected_stats); + + // Verify the execution results + let partitions = execute_stream_partitioned( + interleave.clone(), + Arc::new(TaskContext::default()), + )?; + assert_eq!(partitions.len(), 2); + + let mut partition_row_counts = Vec::new(); + for partition_stream in partitions.into_iter() { + let results: Vec = partition_stream.try_collect().await?; + let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum(); + partition_row_counts.push(total_rows); + } + assert_eq!(partition_row_counts.len(), 2); + assert_eq!(partition_row_counts.iter().sum::(), 8); + + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_cross_join() -> Result<()> { + let left_scan = create_scan_exec_with_statistics(None, Some(1)).await; + let right_create_table_sql = "CREATE EXTERNAL TABLE t2 (id INT NOT NULL) \ + STORED AS PARQUET LOCATION './tests/data/test_statistics_per_partition'\ + WITH ORDER (id ASC);"; + let right_scan = + create_scan_exec_with_statistics(Some(right_create_table_sql), Some(2)).await; + let cross_join: Arc = + Arc::new(CrossJoinExec::new(left_scan, right_scan)); + let statistics = (0..cross_join.output_partitioning().partition_count()) + .map(|idx| cross_join.partition_statistics(Some(idx))) + .collect::>>()?; + // Check that we have 2 partitions + assert_eq!(statistics.len(), 2); + let mut expected_statistic_partition_1 = + create_partition_statistics(8, 48400, 1, 4, true); + expected_statistic_partition_1 + .column_statistics + .push(ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }); + let mut expected_statistic_partition_2 = + create_partition_statistics(8, 48400, 1, 4, true); + expected_statistic_partition_2 + .column_statistics + .push(ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(2))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(1, 4, 8), + ExpectedStatistics::NonEmpty(1, 4, 8), + ]; + validate_statistics_with_data(cross_join, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_coalesce_batches() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let coalesce_batches: Arc = + Arc::new(CoalesceBatchesExec::new(scan, 2)); + let expected_statistic_partition_1 = + create_partition_statistics(2, 110, 3, 4, true); + let expected_statistic_partition_2 = + create_partition_statistics(2, 110, 1, 2, true); + let statistics = (0..coalesce_batches.output_partitioning().partition_count()) + .map(|idx| coalesce_batches.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ]; + validate_statistics_with_data(coalesce_batches, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_coalesce_partitions() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let coalesce_partitions: Arc = + Arc::new(CoalescePartitionsExec::new(scan)); + let expected_statistic_partition = + create_partition_statistics(4, 220, 1, 4, true); + let statistics = (0..coalesce_partitions.output_partitioning().partition_count()) + .map(|idx| coalesce_partitions.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 1); + assert_eq!(statistics[0], expected_statistic_partition); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ExpectedStatistics::NonEmpty(1, 4, 4)]; + validate_statistics_with_data(coalesce_partitions, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_local_limit() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let local_limit: Arc = + Arc::new(LocalLimitExec::new(scan.clone(), 1)); + let statistics = (0..local_limit.output_partitioning().partition_count()) + .map(|idx| local_limit.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 2); + let mut expected_0 = statistics[0].clone(); + expected_0.column_statistics = expected_0 + .column_statistics + .into_iter() + .map(|c| c.to_inexact()) + .collect(); + let mut expected_1 = statistics[1].clone(); + expected_1.column_statistics = expected_1 + .column_statistics + .into_iter() + .map(|c| c.to_inexact()) + .collect(); + assert_eq!(statistics[0], expected_0); + assert_eq!(statistics[1], expected_1); + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_global_limit_partitions() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + // Skip 2 rows + let global_limit: Arc = + Arc::new(GlobalLimitExec::new(scan.clone(), 0, Some(2))); + let statistics = (0..global_limit.output_partitioning().partition_count()) + .map(|idx| global_limit.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 1); + let expected_statistic_partition = + create_partition_statistics(2, 110, 3, 4, true); + assert_eq!(statistics[0], expected_statistic_partition); + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_agg() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + + let scan_schema = scan.schema(); + + // select id, 1+id, count(*) from t group by id, 1+id + let group_by = PhysicalGroupBy::new_single(vec![ + (col("id", &scan_schema)?, "id".to_string()), + ( + binary( + lit(1), + Operator::Plus, + col("id", &scan_schema)?, + &scan_schema, + )?, + "expr".to_string(), + ), + ]); + + let aggr_expr = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) + .schema(Arc::clone(&scan_schema)) + .alias(String::from("COUNT(c)")) + .build() + .map(Arc::new)?]; + + let aggregate_exec_partial = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggr_expr.clone(), + vec![None], + Arc::clone(&scan), + scan_schema.clone(), + )?) as _; + + let mut plan_string = get_plan_string(&aggregate_exec_partial); + let _ = plan_string.swap_remove(1); + let expected_plan = vec![ + "AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]", + ]; + assert_eq!(plan_string, expected_plan); + + let p0_statistics = aggregate_exec_partial.partition_statistics(Some(0))?; + + let expected_p0_statistics = Statistics { + num_rows: Precision::Inexact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }, + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + + assert_eq!(&p0_statistics, &expected_p0_statistics); + + let expected_p1_statistics = Statistics { + num_rows: Precision::Inexact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Int32(Some(2))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }, + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + + let p1_statistics = aggregate_exec_partial.partition_statistics(Some(1))?; + assert_eq!(&p1_statistics, &expected_p1_statistics); + + validate_statistics_with_data( + aggregate_exec_partial.clone(), + vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ], + 0, + ) + .await?; + + let agg_final = Arc::new(AggregateExec::try_new( + AggregateMode::FinalPartitioned, + group_by.clone(), + aggr_expr.clone(), + vec![None], + aggregate_exec_partial.clone(), + aggregate_exec_partial.schema(), + )?); + + let p0_statistics = agg_final.partition_statistics(Some(0))?; + assert_eq!(&p0_statistics, &expected_p0_statistics); + + let p1_statistics = agg_final.partition_statistics(Some(1))?; + assert_eq!(&p1_statistics, &expected_p1_statistics); + + validate_statistics_with_data( + agg_final.clone(), + vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ], + 0, + ) + .await?; + + // select id, 1+id, count(*) from empty_table group by id, 1+id + let empty_table = + Arc::new(EmptyExec::new(scan_schema.clone()).with_partitions(2)); + + let agg_partial = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggr_expr.clone(), + vec![None], + empty_table.clone(), + scan_schema.clone(), + )?) as _; + + let agg_plan = get_plan_string(&agg_partial).remove(0); + assert_eq!("AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]",agg_plan); + + let empty_stat = Statistics { + num_rows: Precision::Exact(0), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + + assert_eq!(&empty_stat, &agg_partial.partition_statistics(Some(0))?); + assert_eq!(&empty_stat, &agg_partial.partition_statistics(Some(1))?); + validate_statistics_with_data( + agg_partial.clone(), + vec![ExpectedStatistics::Empty, ExpectedStatistics::Empty], + 0, + ) + .await?; + + let agg_partial = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggr_expr.clone(), + vec![None], + empty_table.clone(), + scan_schema.clone(), + )?); + + let agg_final = Arc::new(AggregateExec::try_new( + AggregateMode::FinalPartitioned, + group_by.clone(), + aggr_expr.clone(), + vec![None], + agg_partial.clone(), + agg_partial.schema(), + )?); + + assert_eq!(&empty_stat, &agg_final.partition_statistics(Some(0))?); + assert_eq!(&empty_stat, &agg_final.partition_statistics(Some(1))?); + + validate_statistics_with_data( + agg_final, + vec![ExpectedStatistics::Empty, ExpectedStatistics::Empty], + 0, + ) + .await?; + + // select count(*) from empty_table + let agg_partial = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + aggr_expr.clone(), + vec![None], + empty_table.clone(), + scan_schema.clone(), + )?); + + let coalesce = Arc::new(CoalescePartitionsExec::new(agg_partial.clone())); + + let agg_final = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + aggr_expr.clone(), + vec![None], + coalesce.clone(), + coalesce.schema(), + )?); + + let expect_stat = Statistics { + num_rows: Precision::Exact(1), + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics::new_unknown()], + }; + + assert_eq!(&expect_stat, &agg_final.partition_statistics(Some(0))?); + + // Verify that the aggregate final result has exactly one partition with one row + let mut partitions = execute_stream_partitioned( + agg_final.clone(), + Arc::new(TaskContext::default()), + )?; + assert_eq!(1, partitions.len()); + let result: Vec = partitions.remove(0).try_collect().await?; + assert_eq!(1, result[0].num_rows()); + + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_placeholder_rows() -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let plan = Arc::new(PlaceholderRowExec::new(schema).with_partitions(2)) + as Arc; + let schema = plan.schema(); + + let ctx = TaskContext::default(); + let partitions = execute_stream_partitioned(Arc::clone(&plan), Arc::new(ctx))?; + + let mut all_batches = vec![]; + for (i, partition_stream) in partitions.into_iter().enumerate() { + let batches: Vec = partition_stream.try_collect().await?; + let actual = plan.partition_statistics(Some(i))?; + let expected = compute_record_batch_statistics( + std::slice::from_ref(&batches), + &schema, + None, + ); + assert_eq!(actual, expected); + all_batches.push(batches); + } + + let actual = plan.partition_statistics(None)?; + let expected = compute_record_batch_statistics(&all_batches, &schema, None); + assert_eq!(actual, expected); + + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_repartition() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + + let repartition = Arc::new(RepartitionExec::try_new( + scan.clone(), + Partitioning::RoundRobinBatch(3), + )?); + + let statistics = (0..repartition.partitioning().partition_count()) + .map(|idx| repartition.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 3); + + let expected_stats = Statistics { + num_rows: Precision::Inexact(1), + total_byte_size: Precision::Inexact(73), + column_statistics: vec![ + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + + // All partitions should have the same statistics + for stat in statistics.iter() { + assert_eq!(stat, &expected_stats); + } + + // Verify that the result has exactly 3 partitions + let partitions = execute_stream_partitioned( + repartition.clone(), + Arc::new(TaskContext::default()), + )?; + assert_eq!(partitions.len(), 3); + + // Collect row counts from each partition + let mut partition_row_counts = Vec::new(); + for partition_stream in partitions.into_iter() { + let results: Vec = partition_stream.try_collect().await?; + let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum(); + partition_row_counts.push(total_rows); + } + assert_eq!(partition_row_counts.len(), 3); + assert_eq!(partition_row_counts[0], 2); + assert_eq!(partition_row_counts[1], 2); + assert_eq!(partition_row_counts[2], 0); + + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_repartition_invalid_partition() -> Result<()> + { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + + let repartition = Arc::new(RepartitionExec::try_new( + scan.clone(), + Partitioning::RoundRobinBatch(2), + )?); + + let result = repartition.partition_statistics(Some(2)); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error + .to_string() + .contains("RepartitionExec invalid partition 2 (expected less than 2)")); + + let partitions = execute_stream_partitioned( + repartition.clone(), + Arc::new(TaskContext::default()), + )?; + assert_eq!(partitions.len(), 2); + + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_repartition_zero_partitions() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let scan_schema = scan.schema(); + + // Create a repartition with 0 partitions + let repartition = Arc::new(RepartitionExec::try_new( + Arc::new(EmptyExec::new(scan_schema.clone())), + Partitioning::RoundRobinBatch(0), + )?); + + let result = repartition.partition_statistics(Some(0))?; + assert_eq!(result, Statistics::new_unknown(&scan_schema)); + + // Verify that the result has exactly 0 partitions + let partitions = execute_stream_partitioned( + repartition.clone(), + Arc::new(TaskContext::default()), + )?; + assert_eq!(partitions.len(), 0); + + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_repartition_hash_partitioning() -> Result<()> + { + let scan = create_scan_exec_with_statistics(None, Some(1)).await; + + // Create hash partitioning on the 'id' column + let hash_expr = vec![col("id", &scan.schema())?]; + let repartition = Arc::new(RepartitionExec::try_new( + scan, + Partitioning::Hash(hash_expr, 2), + )?); + + // Verify the result of partition statistics of repartition + let stats = (0..repartition.partitioning().partition_count()) + .map(|idx| repartition.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(stats.len(), 2); + + let expected_stats = Statistics { + num_rows: Precision::Inexact(2), + total_byte_size: Precision::Inexact(110), + column_statistics: vec![ + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + assert_eq!(stats[0], expected_stats); + assert_eq!(stats[1], expected_stats); + + // Verify the repartition execution results + let partitions = + execute_stream_partitioned(repartition, Arc::new(TaskContext::default()))?; + assert_eq!(partitions.len(), 2); + + let mut partition_row_counts = Vec::new(); + for partition_stream in partitions.into_iter() { + let results: Vec = partition_stream.try_collect().await?; + let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum(); + partition_row_counts.push(total_rows); + } + assert_eq!(partition_row_counts.len(), 2); + assert_eq!(partition_row_counts.iter().sum::(), 4); + + Ok(()) + } +} diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index 911d2c0cee05f..c51a5e02c9c33 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -25,21 +25,22 @@ use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::physical_plan::CsvSource; use datafusion::datasource::source::DataSourceExec; use datafusion_common::config::ConfigOptions; -use datafusion_common::Result; -use datafusion_common::{JoinSide, JoinType, ScalarValue}; +use datafusion_common::{JoinSide, JoinType, NullEquality, Result, ScalarValue}; +use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{ Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr::expressions::{ binary, cast, col, BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, }; -use datafusion_physical_expr::ScalarFunctionExpr; -use datafusion_physical_expr::{ - Distribution, Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, +use datafusion_physical_expr::{Distribution, Partitioning, ScalarFunctionExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{ + OrderingRequirements, PhysicalSortExpr, PhysicalSortRequirement, }; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use datafusion_physical_optimizer::projection_pushdown::ProjectionPushdown; use datafusion_physical_optimizer::PhysicalOptimizerRule; @@ -50,21 +51,19 @@ use datafusion_physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, }; -use datafusion_physical_plan::projection::{update_expr, ProjectionExec}; +use datafusion_physical_plan::projection::{update_expr, ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion_physical_plan::streaming::PartitionStream; -use datafusion_physical_plan::streaming::StreamingTableExec; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::union::UnionExec; -use datafusion_physical_plan::{get_plan_string, ExecutionPlan}; +use datafusion_physical_plan::{displayable, ExecutionPlan}; -use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_expr_common::columnar_value::ColumnarValue; +use insta::assert_snapshot; use itertools::Itertools; /// Mocked UDF -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct DummyUDF { signature: Signature, } @@ -128,7 +127,8 @@ fn test_update_matching_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true).into(), + Arc::new(ConfigOptions::default()), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -193,7 +193,8 @@ fn test_update_matching_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true).into(), + Arc::new(ConfigOptions::default()), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 3))), @@ -223,8 +224,12 @@ fn test_update_matching_exprs() -> Result<()> { )?), ]; + let child_exprs: Vec = child + .iter() + .map(|(expr, alias)| ProjectionExpr::new(expr.clone(), alias.clone())) + .collect(); for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { - assert!(update_expr(&expr, &child, true)? + assert!(update_expr(&expr, &child_exprs, true)? .unwrap() .eq(&expected_expr)); } @@ -261,7 +266,8 @@ fn test_update_projected_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true).into(), + Arc::new(ConfigOptions::default()), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -326,7 +332,8 @@ fn test_update_projected_exprs() -> Result<()> { Arc::new(Column::new("b_new", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true).into(), + Arc::new(ConfigOptions::default()), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d_new", 3))), @@ -356,8 +363,12 @@ fn test_update_projected_exprs() -> Result<()> { )?), ]; + let proj_exprs: Vec = projected_exprs + .iter() + .map(|(expr, alias)| ProjectionExpr::new(expr.clone(), alias.clone())) + .collect(); for (expr, expected_expr) in exprs.into_iter().zip(expected_exprs.into_iter()) { - assert!(update_expr(&expr, &projected_exprs, false)? + assert!(update_expr(&expr, &proj_exprs, false)? .unwrap() .eq(&expected_expr)); } @@ -421,24 +432,34 @@ fn test_csv_after_projection() -> Result<()> { let csv = create_projecting_csv_exec(); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("b", 2)), "b".to_string()), - (Arc::new(Column::new("d", 0)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 2)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("d", 0)), "d".to_string()), ], csv.clone(), )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[b@2 as b, d@0 as d]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[d, c, b], file_type=csv, has_header=false", - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[b@2 as b, d@0 as d] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[d, c, b], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = - ["DataSourceExec: file_groups={1 group: [[x]]}, projection=[b, d], file_type=csv, has_header=false"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @"DataSourceExec: file_groups={1 group: [[x]]}, projection=[b, d], file_type=csv, has_header=false" + ); Ok(()) } @@ -448,24 +469,36 @@ fn test_memory_after_projection() -> Result<()> { let memory = create_projecting_memory_exec(); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("d", 2)), "d".to_string()), - (Arc::new(Column::new("e", 3)), "e".to_string()), - (Arc::new(Column::new("a", 1)), "a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("d", 2)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("e", 3)), "e".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 1)), "a".to_string()), ], memory.clone(), )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[d@2 as d, e@3 as e, a@1 as a]", - " DataSourceExec: partitions=0, partition_sizes=[]", - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[d@2 as d, e@3 as e, a@1 as a] + DataSourceExec: partitions=0, partition_sizes=[] + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = ["DataSourceExec: partitions=0, partition_sizes=[]"]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @"DataSourceExec: partitions=0, partition_sizes=[]" + ); + assert_eq!( after_optimize .clone() @@ -519,7 +552,7 @@ fn test_streaming_table_after_projection() -> Result<()> { }) as _], Some(&vec![0_usize, 2, 4, 3]), vec![ - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: Arc::new(Column::new("e", 2)), options: SortOptions::default(), @@ -528,11 +561,13 @@ fn test_streaming_table_after_projection() -> Result<()> { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), }, - ]), - LexOrdering::new(vec![PhysicalSortExpr { + ] + .into(), + [PhysicalSortExpr { expr: Arc::new(Column::new("d", 3)), options: SortOptions::default(), - }]), + }] + .into(), ] .into_iter(), true, @@ -540,9 +575,9 @@ fn test_streaming_table_after_projection() -> Result<()> { )?; let projection = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("d", 3)), "d".to_string()), - (Arc::new(Column::new("e", 2)), "e".to_string()), - (Arc::new(Column::new("a", 0)), "a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("e", 2)), "e".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), ], Arc::new(streaming_table) as _, )?) as _; @@ -579,7 +614,7 @@ fn test_streaming_table_after_projection() -> Result<()> { assert_eq!( result.projected_output_ordering().into_iter().collect_vec(), vec![ - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: Arc::new(Column::new("e", 1)), options: SortOptions::default(), @@ -588,11 +623,13 @@ fn test_streaming_table_after_projection() -> Result<()> { expr: Arc::new(Column::new("a", 2)), options: SortOptions::default(), }, - ]), - LexOrdering::new(vec![PhysicalSortExpr { + ] + .into(), + [PhysicalSortExpr { expr: Arc::new(Column::new("d", 0)), options: SortOptions::default(), - }]), + }] + .into(), ] ); assert!(result.is_infinite()); @@ -605,17 +642,17 @@ fn test_projection_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); let child_projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("c", 2)), "c".to_string()), - (Arc::new(Column::new("e", 4)), "new_e".to_string()), - (Arc::new(Column::new("a", 0)), "a".to_string()), - (Arc::new(Column::new("b", 1)), "new_b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), + ProjectionExpr::new(Arc::new(Column::new("e", 4)), "new_e".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "new_b".to_string()), ], csv.clone(), )?); let top_projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("new_b", 3)), "new_b".to_string()), - ( + ProjectionExpr::new(Arc::new(Column::new("new_b", 3)), "new_b".to_string()), + ProjectionExpr::new( Arc::new(BinaryExpr::new( Arc::new(Column::new("c", 0)), Operator::Plus, @@ -623,27 +660,43 @@ fn test_projection_after_projection() -> Result<()> { )), "binary".to_string(), ), - (Arc::new(Column::new("new_b", 3)), "newest_b".to_string()), + ProjectionExpr::new( + Arc::new(Column::new("new_b", 3)), + "newest_b".to_string(), + ), ], child_projection.clone(), )?); - let initial = get_plan_string(&top_projection); - let expected_initial = [ - "ProjectionExec: expr=[new_b@3 as new_b, c@0 + new_e@1 as binary, new_b@3 as newest_b]", - " ProjectionExec: expr=[c@2 as c, e@4 as new_e, a@0 as a, b@1 as new_b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(top_projection.as_ref()) + .indent(true) + .to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[new_b@3 as new_b, c@0 + new_e@1 as binary, new_b@3 as newest_b] + ProjectionExec: expr=[c@2 as c, e@4 as new_e, a@0 as a, b@1 as new_b] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(top_projection, &ConfigOptions::new())?; - let expected = [ - "ProjectionExec: expr=[b@1 as new_b, c@2 + e@4 as binary, b@1 as newest_b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[b@1 as new_b, c@2 + e@4 as binary, b@1 as newest_b] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); Ok(()) } @@ -652,67 +705,85 @@ fn test_projection_after_projection() -> Result<()> { fn test_output_req_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); let sort_req: Arc = Arc::new(OutputRequirementExec::new( - csv.clone(), - Some(LexRequirement::new(vec![ - PhysicalSortRequirement { - expr: Arc::new(Column::new("b", 1)), - options: Some(SortOptions::default()), - }, - PhysicalSortRequirement { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 2)), - Operator::Plus, - Arc::new(Column::new("a", 0)), - )), - options: Some(SortOptions::default()), - }, - ])), + csv, + Some(OrderingRequirements::new( + [ + PhysicalSortRequirement::new( + Arc::new(Column::new("b", 1)), + Some(SortOptions::default()), + ), + PhysicalSortRequirement::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + Some(SortOptions::default()), + ), + ] + .into(), + )), Distribution::HashPartitioned(vec![ Arc::new(Column::new("a", 0)), Arc::new(Column::new("b", 1)), ]), + None, )); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("c", 2)), "c".to_string()), - (Arc::new(Column::new("a", 0)), "new_a".to_string()), - (Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), ], sort_req.clone(), )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " OutputRequirementExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] + OutputRequirementExec: order_by=[(b@1, asc), (c@2 + a@0, asc)], dist_by=HashPartitioned[[a@0, b@1]]) + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected: [&str; 3] = [ - "OutputRequirementExec", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - - assert_eq!(get_plan_string(&after_optimize), expected); - let expected_reqs = LexRequirement::new(vec![ - PhysicalSortRequirement { - expr: Arc::new(Column::new("b", 2)), - options: Some(SortOptions::default()), - }, - PhysicalSortRequirement { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 0)), - Operator::Plus, - Arc::new(Column::new("new_a", 1)), - )), - options: Some(SortOptions::default()), - }, - ]); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @r" + OutputRequirementExec: order_by=[(b@2, asc), (c@0 + new_a@1, asc)], dist_by=HashPartitioned[[new_a@1, b@2]]) + ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); + + let expected_reqs = OrderingRequirements::new( + [ + PhysicalSortRequirement::new( + Arc::new(Column::new("b", 2)), + Some(SortOptions::default()), + ), + PhysicalSortRequirement::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Plus, + Arc::new(Column::new("new_a", 1)), + )), + Some(SortOptions::default()), + ), + ] + .into(), + ); assert_eq!( after_optimize .as_any() @@ -752,29 +823,40 @@ fn test_coalesce_partitions_after_projection() -> Result<()> { Arc::new(CoalescePartitionsExec::new(csv)); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("b", 1)), "b".to_string()), - (Arc::new(Column::new("a", 0)), "a_new".to_string()), - (Arc::new(Column::new("d", 3)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_new".to_string()), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d".to_string()), ], coalesce_partitions, )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d]", - " CoalescePartitionsExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d] + CoalescePartitionsExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "CoalescePartitionsExec", - " ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @r" + CoalescePartitionsExec + ProjectionExec: expr=[b@1 as b, a@0 as a_new, d@3 as d] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); Ok(()) } @@ -795,33 +877,44 @@ fn test_filter_after_projection() -> Result<()> { Arc::new(Column::new("a", 0)), )), )); - let filter: Arc = Arc::new(FilterExec::try_new(predicate, csv)?); + let filter = Arc::new(FilterExec::try_new(predicate, csv)?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("a", 0)), "a_new".to_string()), - (Arc::new(Column::new("b", 1)), "b".to_string()), - (Arc::new(Column::new("d", 3)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_new".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d".to_string()), ], filter.clone(), - )?); + )?) as _; + + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d]", - " FilterExec: b@1 - a@0 > d@3 - a@0", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(initial, expected_initial); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d] + FilterExec: b@1 - a@0 > d@3 - a@0 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "FilterExec: b@1 - a_new@0 > d@2 - a_new@0", - " ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @r" + FilterExec: b@1 - a_new@0 > d@2 - a_new@0 + ProjectionExec: expr=[a@0 as a_new, b@1 as b, d@3 as d] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); Ok(()) } @@ -875,41 +968,58 @@ fn test_join_after_projection() -> Result<()> { ])), )), &JoinType::Inner, - true, + NullEquality::NullEqualsNull, None, None, StreamJoinPartitionMode::SinglePartition, )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("c", 2)), "c_from_left".to_string()), - (Arc::new(Column::new("b", 1)), "b_from_left".to_string()), - (Arc::new(Column::new("a", 0)), "a_from_left".to_string()), - (Arc::new(Column::new("a", 5)), "a_from_right".to_string()), - (Arc::new(Column::new("c", 7)), "c_from_right".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c_from_left".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_from_left".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_from_left".to_string()), + ProjectionExpr::new( + Arc::new(Column::new("a", 5)), + "a_from_right".to_string(), + ), + ProjectionExpr::new( + Arc::new(Column::new("c", 7)), + "c_from_right".to_string(), + ), ], join, - )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, a@5 as a_from_right, c@7 as c_from_right]", - " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + )?) as _; + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, a@5 as a_from_right, c@7 as c_from_right] + SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b_from_left@1, c_from_right@1)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", - " ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " ProjectionExec: expr=[a@0 as a_from_right, c@2 as c_from_right]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @r" + SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b_from_left@1, c_from_right@1)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2 + ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + ProjectionExec: expr=[a@0 as a_from_right, c@2 as c_from_right] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let expected_filter_col_ind = vec![ ColumnIndex { @@ -945,7 +1055,7 @@ fn test_join_after_required_projection() -> Result<()> { let left_csv = create_simple_csv_exec(); let right_csv = create_simple_csv_exec(); - let join: Arc = Arc::new(SymmetricHashJoinExec::try_new( + let join = Arc::new(SymmetricHashJoinExec::try_new( left_csv, right_csv, vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))], @@ -989,45 +1099,56 @@ fn test_join_after_required_projection() -> Result<()> { ])), )), &JoinType::Inner, - true, + NullEquality::NullEqualsNull, None, None, StreamJoinPartitionMode::SinglePartition, )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("a", 5)), "a".to_string()), - (Arc::new(Column::new("b", 6)), "b".to_string()), - (Arc::new(Column::new("c", 7)), "c".to_string()), - (Arc::new(Column::new("d", 8)), "d".to_string()), - (Arc::new(Column::new("e", 9)), "e".to_string()), - (Arc::new(Column::new("a", 0)), "a".to_string()), - (Arc::new(Column::new("b", 1)), "b".to_string()), - (Arc::new(Column::new("c", 2)), "c".to_string()), - (Arc::new(Column::new("d", 3)), "d".to_string()), - (Arc::new(Column::new("e", 4)), "e".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 5)), "a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 6)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c".to_string()), + ProjectionExpr::new(Arc::new(Column::new("d", 8)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("e", 9)), "e".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d".to_string()), + ProjectionExpr::new(Arc::new(Column::new("e", 4)), "e".to_string()), ], join, - )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e]", - " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + )?) as _; + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e] + SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e]", - " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e] + SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); Ok(()) } @@ -1061,7 +1182,7 @@ fn test_nested_loop_join_after_projection() -> Result<()> { Field::new("c", DataType::Int32, true), ]); - let join: Arc = Arc::new(NestedLoopJoinExec::try_new( + let join = Arc::new(NestedLoopJoinExec::try_new( left_csv, right_csv, Some(JoinFilter::new( @@ -1071,29 +1192,39 @@ fn test_nested_loop_join_after_projection() -> Result<()> { )), &JoinType::Inner, None, - )?); + )?) as _; let projection: Arc = Arc::new(ProjectionExec::try_new( - vec![(col_left_c, "c".to_string())], + vec![ProjectionExpr::new(col_left_c, "c".to_string())], Arc::clone(&join), - )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[c@2 as c]", - " NestedLoopJoinExec: join_type=Inner, filter=a@0 < b@1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(initial, expected_initial); + )?) as _; + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c] + NestedLoopJoinExec: join_type=Inner, filter=a@0 < b@1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); - let after_optimize = + let after_optimize_string = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "NestedLoopJoinExec: join_type=Inner, filter=a@0 < b@1, projection=[c@2]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize_string.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @r" + NestedLoopJoinExec: join_type=Inner, filter=a@0 < b@1, projection=[c@2] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + + ); Ok(()) } @@ -1104,7 +1235,7 @@ fn test_hash_join_after_projection() -> Result<()> { let left_csv = create_simple_csv_exec(); let right_csv = create_simple_csv_exec(); - let join: Arc = Arc::new(HashJoinExec::try_new( + let join = Arc::new(HashJoinExec::try_new( left_csv, right_csv, vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))], @@ -1150,46 +1281,76 @@ fn test_hash_join_after_projection() -> Result<()> { &JoinType::Inner, None, PartitionMode::Auto, - true, + NullEquality::NullEqualsNothing, )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("c", 2)), "c_from_left".to_string()), - (Arc::new(Column::new("b", 1)), "b_from_left".to_string()), - (Arc::new(Column::new("a", 0)), "a_from_left".to_string()), - (Arc::new(Column::new("c", 7)), "c_from_right".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c_from_left".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_from_left".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a_from_left".to_string()), + ProjectionExpr::new( + Arc::new(Column::new("c", 7)), + "c_from_right".to_string(), + ), ], join.clone(), - )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@7 as c_from_right]", " HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + )?) as _; + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@7 as c_from_right] + HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); // HashJoinExec only returns result after projection. Because there are some alias columns in the projection, the ProjectionExec is not removed. - let expected = ["ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@3 as c_from_right]", " HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@7]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false"]; - assert_eq!(get_plan_string(&after_optimize), expected); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@3 as c_from_right] + HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@7] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let projection = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("a", 0)), "a".to_string()), - (Arc::new(Column::new("b", 1)), "b".to_string()), - (Arc::new(Column::new("c", 2)), "c".to_string()), - (Arc::new(Column::new("c", 7)), "c".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 7)), "c".to_string()), ], join.clone(), )?); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); // Comparing to the previous result, this projection don't have alias columns either change the order of output fields. So the ProjectionExec is removed. - let expected = ["HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@7]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false"]; - assert_eq!(get_plan_string(&after_optimize), expected); + assert_snapshot!( + actual, + @r" + HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@7] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); Ok(()) } @@ -1197,7 +1358,7 @@ fn test_hash_join_after_projection() -> Result<()> { #[test] fn test_repartition_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let repartition: Arc = Arc::new(RepartitionExec::try_new( + let repartition = Arc::new(RepartitionExec::try_new( csv, Partitioning::Hash( vec![ @@ -1210,29 +1371,38 @@ fn test_repartition_after_projection() -> Result<()> { )?); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("b", 1)), "b_new".to_string()), - (Arc::new(Column::new("a", 0)), "a".to_string()), - (Arc::new(Column::new("d", 3)), "d_new".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b_new".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("d", 3)), "d_new".to_string()), ], repartition, - )?); - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", - " RepartitionExec: partitioning=Hash([a@0, b@1, d@3], 6), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(initial, expected_initial); + )?) as _; + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new] + RepartitionExec: partitioning=Hash([a@0, b@1, d@3], 6), input_partitions=1 + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "RepartitionExec: partitioning=Hash([a@1, b_new@0, d_new@2], 6), input_partitions=1", - " ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @r" + RepartitionExec: partitioning=Hash([a@1, b_new@0, d_new@2], 6), input_partitions=1 + ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); assert_eq!( after_optimize @@ -1257,49 +1427,53 @@ fn test_repartition_after_projection() -> Result<()> { #[test] fn test_sort_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let sort_req: Arc = Arc::new(SortExec::new( - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 2)), - Operator::Plus, - Arc::new(Column::new("a", 0)), - )), - options: SortOptions::default(), - }, - ]), - csv.clone(), - )); + let sort_exec = SortExec::new( + [ + PhysicalSortExpr::new_default(Arc::new(Column::new("b", 1))), + PhysicalSortExpr::new_default(Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + ))), + ] + .into(), + csv, + ); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("c", 2)), "c".to_string()), - (Arc::new(Column::new("a", 0)), "new_a".to_string()), - (Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), ], - sort_req.clone(), - )?); + Arc::new(sort_exec), + )?) as _; - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " SortExec: expr=[b@1 ASC, c@2 + a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] + SortExec: expr=[b@1 ASC, c@2 + a@0 ASC], preserve_partitioning=[false] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "SortExec: expr=[b@2 ASC, c@0 + new_a@1 ASC], preserve_partitioning=[false]", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @r" + SortExec: expr=[b@2 ASC, c@0 + new_a@1 ASC], preserve_partitioning=[false] + ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); Ok(()) } @@ -1307,49 +1481,53 @@ fn test_sort_after_projection() -> Result<()> { #[test] fn test_sort_preserving_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let sort_req: Arc = Arc::new(SortPreservingMergeExec::new( - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 2)), - Operator::Plus, - Arc::new(Column::new("a", 0)), - )), - options: SortOptions::default(), - }, - ]), - csv.clone(), - )); + let sort_exec = SortPreservingMergeExec::new( + [ + PhysicalSortExpr::new_default(Arc::new(Column::new("b", 1))), + PhysicalSortExpr::new_default(Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + ))), + ] + .into(), + csv, + ); let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("c", 2)), "c".to_string()), - (Arc::new(Column::new("a", 0)), "new_a".to_string()), - (Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), ], - sort_req.clone(), - )?); + Arc::new(sort_exec), + )?) as _; - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " SortPreservingMergeExec: [b@1 ASC, c@2 + a@0 ASC]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] + SortPreservingMergeExec: [b@1 ASC, c@2 + a@0 ASC] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "SortPreservingMergeExec: [b@2 ASC, c@0 + new_a@1 ASC]", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @r" + SortPreservingMergeExec: [b@2 ASC, c@0 + new_a@1 ASC] + ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); Ok(()) } @@ -1357,40 +1535,48 @@ fn test_sort_preserving_after_projection() -> Result<()> { #[test] fn test_union_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let union: Arc = - Arc::new(UnionExec::new(vec![csv.clone(), csv.clone(), csv])); + let union = UnionExec::try_new(vec![csv.clone(), csv.clone(), csv])?; let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - (Arc::new(Column::new("c", 2)), "c".to_string()), - (Arc::new(Column::new("a", 0)), "new_a".to_string()), - (Arc::new(Column::new("b", 1)), "b".to_string()), + ProjectionExpr::new(Arc::new(Column::new("c", 2)), "c".to_string()), + ProjectionExpr::new(Arc::new(Column::new("a", 0)), "new_a".to_string()), + ProjectionExpr::new(Arc::new(Column::new("b", 1)), "b".to_string()), ], union.clone(), - )?); + )?) as _; - let initial = get_plan_string(&projection); - let expected_initial = [ - "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " UnionExec", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(initial, expected_initial); + let initial = displayable(projection.as_ref()).indent(true).to_string(); + let actual = initial.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] + UnionExec + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "UnionExec", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @r" + UnionExec + ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false + " + ); Ok(()) } @@ -1421,17 +1607,17 @@ fn test_partition_col_projection_pushdown() -> Result<()> { let source = partitioned_data_source(); let partitioned_schema = source.schema(); - let projection = Arc::new(ProjectionExec::try_new( + let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ( + ProjectionExpr::new( col("string_col", partitioned_schema.as_ref())?, "string_col".to_string(), ), - ( + ProjectionExpr::new( col("partition_col", partitioned_schema.as_ref())?, "partition_col".to_string(), ), - ( + ProjectionExpr::new( col("int_col", partitioned_schema.as_ref())?, "int_col".to_string(), ), @@ -1442,11 +1628,17 @@ fn test_partition_col_projection_pushdown() -> Result<()> { let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "ProjectionExec: expr=[string_col@1 as string_col, partition_col@2 as partition_col, int_col@0 as int_col]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[int_col, string_col, partition_col], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[string_col@1 as string_col, partition_col@2 as partition_col, int_col@0 as int_col] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[int_col, string_col, partition_col], file_type=csv, has_header=false + " + ); Ok(()) } @@ -1456,13 +1648,13 @@ fn test_partition_col_projection_pushdown_expr() -> Result<()> { let source = partitioned_data_source(); let partitioned_schema = source.schema(); - let projection = Arc::new(ProjectionExec::try_new( + let projection: Arc = Arc::new(ProjectionExec::try_new( vec![ - ( + ProjectionExpr::new( col("string_col", partitioned_schema.as_ref())?, "string_col".to_string(), ), - ( + ProjectionExpr::new( // CAST(partition_col, Utf8View) cast( col("partition_col", partitioned_schema.as_ref())?, @@ -1471,7 +1663,7 @@ fn test_partition_col_projection_pushdown_expr() -> Result<()> { )?, "partition_col".to_string(), ), - ( + ProjectionExpr::new( col("int_col", partitioned_schema.as_ref())?, "int_col".to_string(), ), @@ -1482,11 +1674,17 @@ fn test_partition_col_projection_pushdown_expr() -> Result<()> { let after_optimize = ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; - let expected = [ - "ProjectionExec: expr=[string_col@1 as string_col, CAST(partition_col@2 AS Utf8View) as partition_col, int_col@0 as int_col]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[int_col, string_col, partition_col], file_type=csv, has_header=false" - ]; - assert_eq!(get_plan_string(&after_optimize), expected); + let after_optimize_string = displayable(after_optimize.as_ref()) + .indent(true) + .to_string(); + let actual = after_optimize_string.trim(); + assert_snapshot!( + actual, + @r" + ProjectionExec: expr=[string_col@1 as string_col, CAST(partition_col@2 AS Utf8View) as partition_col, int_col@0 as int_col] + DataSourceExec: file_groups={1 group: [[x]]}, projection=[int_col, string_col, partition_col], file_type=csv, has_header=false + " + ); Ok(()) } diff --git a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs index 58eb866c590cc..066e52614a12e 100644 --- a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs @@ -18,7 +18,10 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - check_integrity, sort_preserving_merge_exec, stream_exec_ordered_with_projection, + check_integrity, coalesce_batches_exec, coalesce_partitions_exec, + create_test_schema3, parquet_exec_with_sort, sort_exec, + sort_exec_with_preserve_partitioning, sort_preserving_merge_exec, + sort_preserving_merge_exec_with_fetch, stream_exec_ordered_with_projection, }; use datafusion::prelude::SessionContext; @@ -26,1011 +29,1026 @@ use arrow::array::{ArrayRef, Int32Array}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use insta::{allow_duplicates, assert_snapshot}; +use datafusion_common::tree_node::{TransformedResult, TreeNode}; +use datafusion_common::{assert_contains, NullEquality, Result}; +use datafusion_common::config::ConfigOptions; +use datafusion_datasource::source::DataSourceExec; use datafusion_execution::TaskContext; +use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::expressions::{self, col, Column}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{ + plan_with_order_breaking_variants, plan_with_order_preserving_variants, replace_with_order_preserving_variants, OrderPreservationContext +}; use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; -use datafusion_physical_plan::collect; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::datasource::memory::MemorySourceConfig; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::{ - displayable, get_plan_string, ExecutionPlan, Partitioning, + collect, displayable, ExecutionPlan, Partitioning, }; -use datafusion::datasource::source::DataSourceExec; -use datafusion_common::tree_node::{TransformedResult, TreeNode}; -use datafusion_common::Result; -use datafusion_expr::{JoinType, Operator}; -use datafusion_physical_expr::expressions::{self, col, Column}; -use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{replace_with_order_preserving_variants, OrderPreservationContext}; -use datafusion_common::config::ConfigOptions; use object_store::memory::InMemory; use object_store::ObjectStore; use rstest::rstest; use url::Url; -/// Runs the `replace_with_order_preserving_variants` sub-rule and asserts -/// the plan against the original and expected plans. -/// -/// # Parameters -/// -/// * `$EXPECTED_PLAN_LINES`: Expected input plan. -/// * `EXPECTED_OPTIMIZED_PLAN_LINES`: Optimized plan when the flag -/// `prefer_existing_sort` is `false`. -/// * `EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES`: Optimized plan when -/// the flag `prefer_existing_sort` is `true`. -/// * `$PLAN`: The plan to optimize. -macro_rules! assert_optimized_prefer_sort_on_off { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $PREFER_EXISTING_SORT: expr, $SOURCE_UNBOUNDED: expr) => { - if $PREFER_EXISTING_SORT { - assert_optimized!( - $EXPECTED_PLAN_LINES, - $EXPECTED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, - $PLAN, - $PREFER_EXISTING_SORT, - $SOURCE_UNBOUNDED - ); - } else { - assert_optimized!( - $EXPECTED_PLAN_LINES, - $EXPECTED_OPTIMIZED_PLAN_LINES, - $PLAN, - $PREFER_EXISTING_SORT, - $SOURCE_UNBOUNDED - ); - } - }; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Boundedness { + Unbounded, + Bounded, } -/// Runs the `replace_with_order_preserving_variants` sub-rule and asserts -/// the plan against the original and expected plans for both bounded and -/// unbounded cases. -/// -/// # Parameters -/// -/// * `EXPECTED_UNBOUNDED_PLAN_LINES`: Expected input unbounded plan. -/// * `EXPECTED_BOUNDED_PLAN_LINES`: Expected input bounded plan. -/// * `EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES`: Optimized plan, which is -/// the same regardless of the value of the `prefer_existing_sort` flag. -/// * `EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES`: Optimized plan when the flag -/// `prefer_existing_sort` is `false` for bounded cases. -/// * `EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES`: Optimized plan -/// when the flag `prefer_existing_sort` is `true` for bounded cases. -/// * `$PLAN`: The plan to optimize. -/// * `$SOURCE_UNBOUNDED`: Whether the given plan contains an unbounded source. -macro_rules! assert_optimized_in_all_boundedness_situations { - ($EXPECTED_UNBOUNDED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PLAN_LINES: expr, $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES: expr, $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $SOURCE_UNBOUNDED: expr, $PREFER_EXISTING_SORT: expr) => { - if $SOURCE_UNBOUNDED { - assert_optimized_prefer_sort_on_off!( - $EXPECTED_UNBOUNDED_PLAN_LINES, - $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, - $EXPECTED_UNBOUNDED_OPTIMIZED_PLAN_LINES, - $PLAN, - $PREFER_EXISTING_SORT, - $SOURCE_UNBOUNDED - ); - } else { - assert_optimized_prefer_sort_on_off!( - $EXPECTED_BOUNDED_PLAN_LINES, - $EXPECTED_BOUNDED_OPTIMIZED_PLAN_LINES, - $EXPECTED_BOUNDED_PREFER_SORT_ON_OPTIMIZED_PLAN_LINES, - $PLAN, - $PREFER_EXISTING_SORT, - $SOURCE_UNBOUNDED - ); - } - }; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SortPreference { + PreserveOrder, + MaximizeParallelism, } -/// Runs the `replace_with_order_preserving_variants` sub-rule and asserts -/// the plan against the original and expected plans. -/// -/// # Parameters -/// -/// * `$EXPECTED_PLAN_LINES`: Expected input plan. -/// * `$EXPECTED_OPTIMIZED_PLAN_LINES`: Expected optimized plan. -/// * `$PLAN`: The plan to optimize. -/// * `$PREFER_EXISTING_SORT`: Value of the `prefer_existing_sort` flag. -#[macro_export] -macro_rules! assert_optimized { - ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr, $PREFER_EXISTING_SORT: expr, $SOURCE_UNBOUNDED: expr) => { - let physical_plan = $PLAN; - let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - - let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES - .iter().map(|s| *s).collect(); - - assert_eq!( - expected_plan_lines, actual, - "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n" - ); +struct ReplaceTest { + plan: Arc, + boundedness: Boundedness, + sort_preference: SortPreference, +} - let expected_optimized_lines: Vec<&str> = $EXPECTED_OPTIMIZED_PLAN_LINES.iter().map(|s| *s).collect(); +impl ReplaceTest { + fn new(plan: Arc) -> Self { + Self { + plan, + boundedness: Boundedness::Bounded, + sort_preference: SortPreference::MaximizeParallelism, + } + } - // Run the rule top-down - let mut config = ConfigOptions::new(); - config.optimizer.prefer_existing_sort=$PREFER_EXISTING_SORT; - let plan_with_pipeline_fixer = OrderPreservationContext::new_default(physical_plan); - let parallel = plan_with_pipeline_fixer.transform_up(|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, &config)).data().and_then(check_integrity)?; - let optimized_physical_plan = parallel.plan; + fn with_boundedness(mut self, boundedness: Boundedness) -> Self { + self.boundedness = boundedness; + self + } + + fn with_sort_preference(mut self, sort_preference: SortPreference) -> Self { + self.sort_preference = sort_preference; + self + } - // Get string representation of the plan - let actual = get_plan_string(&optimized_physical_plan); - assert_eq!( - expected_optimized_lines, actual, - "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected_optimized_lines:#?}\nactual:\n\n{actual:#?}\n\n" + async fn execute_plan(&self) -> String { + let mut config = ConfigOptions::new(); + config.optimizer.prefer_existing_sort = + self.sort_preference == SortPreference::PreserveOrder; + + let plan_with_pipeline_fixer = OrderPreservationContext::new_default( + self.plan.clone().reset_state().unwrap(), + ); + + let parallel = plan_with_pipeline_fixer + .transform_up(|plan_with_pipeline_fixer| { + replace_with_order_preserving_variants( + plan_with_pipeline_fixer, + false, + false, + &config, + ) + }) + .data() + .and_then(check_integrity) + .unwrap(); + + let optimized_physical_plan = parallel.plan; + let optimized_plan_string = displayable(optimized_physical_plan.as_ref()) + .indent(true) + .to_string(); + + if self.boundedness == Boundedness::Bounded { + let ctx = SessionContext::new(); + let object_store = InMemory::new(); + object_store + .put( + &object_store::path::Path::from("file_path"), + bytes::Bytes::from("").into(), + ) + .await + .expect("could not create object store"); + ctx.register_object_store( + &Url::parse("test://").unwrap(), + Arc::new(object_store), + ); + let task_ctx = Arc::new(TaskContext::from(&ctx)); + let res = collect(optimized_physical_plan, task_ctx).await; + assert!( + res.is_ok(), + "Some errors occurred while executing the optimized physical plan: {:?}\nPlan: {}", + res.unwrap_err(), optimized_plan_string ); + } + + optimized_plan_string + } + + async fn run(&self) -> String { + let input_plan_string = displayable(self.plan.as_ref()).indent(true).to_string(); - if !$SOURCE_UNBOUNDED { - let ctx = SessionContext::new(); - let object_store = InMemory::new(); - object_store.put(&object_store::path::Path::from("file_path"), bytes::Bytes::from("").into()).await?; - ctx.register_object_store(&Url::parse("test://").unwrap(), Arc::new(object_store)); - let task_ctx = Arc::new(TaskContext::from(&ctx)); - let res = collect(optimized_physical_plan, task_ctx).await; - assert!( - res.is_ok(), - "Some errors occurred while executing the optimized physical plan: {:?}", res.unwrap_err() - ); - } - }; + let optimized = self.execute_plan().await; + + if input_plan_string == optimized { + format!("Input / Optimized:\n{input_plan_string}") + } else { + format!("Input:\n{input_plan_string}\nOptimized:\n{optimized}") + } } +} #[rstest] #[tokio::test] // Searches for a simple sort and a repartition just after it, the second repartition with 1 input partition should not be affected async fn test_replace_multiple_input_repartition_1( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let sort_exprs: LexOrdering = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, sort_exprs.clone()) + } + Boundedness::Bounded => memory_exec_sorted(&schema, sort_exprs.clone()), }; let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let sort = sort_exec_with_preserve_partitioning(sort_exprs.clone(), repartition); + let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + }, + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_with_inter_children_change_only( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr_default("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering: LexOrdering = [sort_expr_default("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering.clone()) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering.clone()), }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let sort = sort_exec( - vec![sort_expr_default("a", &coalesce_partitions.schema())], - coalesce_partitions, - false, - ); + let sort = sort_exec(ordering.clone(), coalesce_partitions); let repartition_rr2 = repartition_exec_round_robin(sort); let repartition_hash2 = repartition_exec_hash(repartition_rr2); let filter = filter_exec(repartition_hash2); - let sort2 = sort_exec(vec![sort_expr_default("a", &filter.schema())], filter, true); - - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("a", &sort2.schema())], sort2); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortPreservingMergeExec: [a@0 ASC]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [a@0 ASC]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortPreservingMergeExec: [a@0 ASC]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let sort2 = sort_exec_with_preserve_partitioning(ordering.clone(), filter); + + let physical_plan = sort_preserving_merge_exec(ordering, sort2); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC] + + Optimized: + SortPreservingMergeExec: [a@0 ASC] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + SortPreservingMergeExec: [a@0 ASC] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC + + Optimized: + SortPreservingMergeExec: [a@0 ASC] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + SortPreservingMergeExec: [a@0 ASC] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_replace_multiple_input_repartition_2( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering.clone()) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering.clone()), }; let repartition_rr = repartition_exec_round_robin(source); let filter = filter_exec(repartition_rr); let repartition_hash = repartition_exec_hash(filter); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), repartition_hash); + let physical_plan = sort_preserving_merge_exec(ordering, sort); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + FilterExec: c@1 > 3 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_replace_multiple_input_repartition_with_extra_steps( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering.clone()) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering.clone()), }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); - let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let coalesce_batches_exec = coalesce_batches_exec(filter, 8192); + let sort = + sort_exec_with_preserve_partitioning(ordering.clone(), coalesce_batches_exec); + let physical_plan = sort_preserving_merge_exec(ordering, sort); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_replace_multiple_input_repartition_with_extra_steps_2( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering.clone()) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering.clone()), }; let repartition_rr = repartition_exec_round_robin(source); - let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); + let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr, 8192); let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec_2 = coalesce_batches_exec(filter); - let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec_2, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let coalesce_batches_exec_2 = coalesce_batches_exec(filter, 8192); + let sort = + sort_exec_with_preserve_partitioning(ordering.clone(), coalesce_batches_exec_2); + let physical_plan = sort_preserving_merge_exec(ordering, sort); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + CoalesceBatchesExec: target_batch_size=8192 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + CoalesceBatchesExec: target_batch_size=8192 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + CoalesceBatchesExec: target_batch_size=8192 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + CoalesceBatchesExec: target_batch_size=8192 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + CoalesceBatchesExec: target_batch_size=8192 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_not_replacing_when_no_need_to_preserve_sorting( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => stream_exec_ordered_with_projection(&schema, ordering), + Boundedness::Bounded => memory_exec_sorted(&schema, ordering), }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); - - let physical_plan: Arc = - coalesce_partitions_exec(coalesce_batches_exec); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "CoalescePartitionsExec", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "CoalescePartitionsExec", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "CoalescePartitionsExec", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results same with and without flag, because there is no executor with ordering requirement - let expected_optimized_bounded = [ - "CoalescePartitionsExec", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; - - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let coalesce_batches_exec = coalesce_batches_exec(filter, 8192); + let physical_plan = coalesce_partitions_exec(coalesce_batches_exec); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + CoalescePartitionsExec + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + CoalescePartitionsExec + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + // Expected bounded results same with and without flag, because there is no executor with ordering requirement + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + CoalescePartitionsExec + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] -async fn test_with_multiple_replacable_repartitions( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, +async fn test_with_multiple_replaceable_repartitions( + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering.clone()) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering.clone()), }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches = coalesce_batches_exec(filter); + let coalesce_batches = coalesce_batches_exec(filter, 8192); let repartition_hash_2 = repartition_exec_hash(coalesce_batches); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash_2, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " CoalesceBatchesExec: target_batch_size=8192", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), repartition_hash_2); + let physical_plan = sort_preserving_merge_exec(ordering, sort); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_not_replace_with_different_orderings( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { + use datafusion_physical_expr::LexOrdering; + let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering_a = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering_a) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering_a), }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); - let sort = sort_exec( - vec![sort_expr_default("c", &repartition_hash.schema())], - repartition_hash, - true, - ); + let ordering_c: LexOrdering = + [sort_expr_default("c", &repartition_hash.schema())].into(); + let sort = sort_exec_with_preserve_partitioning(ordering_c.clone(), repartition_hash); + let physical_plan = sort_preserving_merge_exec(ordering_c, sort); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [c@1 ASC] + SortExec: expr=[c@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [c@1 ASC] + SortExec: expr=[c@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + // Expected bounded results same with and without flag, because ordering requirement of the executor is + // different from the existing ordering. + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [c@1 ASC] + SortExec: expr=[c@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("c", &sort.schema())], sort); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results same with and without flag, because ordering requirement of the executor is different than the existing ordering. - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; - - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); Ok(()) } #[rstest] #[tokio::test] async fn test_with_lost_ordering( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering.clone()) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering.clone()), }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = - sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions, false); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let physical_plan = sort_exec(ordering, coalesce_partitions); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [a@0 ASC NULLS LAST] + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_with_lost_and_kept_ordering( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { + use datafusion_physical_expr::LexOrdering; + let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; - let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) - } else { - memory_exec_sorted(&schema, sort_exprs) + let ordering_a = [sort_expr("a", &schema)].into(); + let source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, ordering_a) + } + Boundedness::Bounded => memory_exec_sorted(&schema, ordering_a), }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let sort = sort_exec( - vec![sort_expr_default("c", &coalesce_partitions.schema())], - coalesce_partitions, - false, - ); + let ordering_c: LexOrdering = + [sort_expr_default("c", &coalesce_partitions.schema())].into(); + let sort = sort_exec(ordering_c.clone(), coalesce_partitions); let repartition_rr2 = repartition_exec_round_robin(sort); let repartition_hash2 = repartition_exec_hash(repartition_rr2); let filter = filter_exec(repartition_hash2); - let sort2 = sort_exec(vec![sort_expr_default("c", &filter.schema())], filter, true); - - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("c", &sort2.schema())], sort2); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results with and without flag - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = [ - "SortPreservingMergeExec: [c@1 ASC]", - " FilterExec: c@1 > 3", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", - " CoalescePartitionsExec", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); + let sort2 = sort_exec_with_preserve_partitioning(ordering_c.clone(), filter); + let physical_plan = sort_preserving_merge_exec(ordering_c, sort2); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [c@1 ASC] + SortExec: expr=[c@1 ASC], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + SortExec: expr=[c@1 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + + Optimized: + SortPreservingMergeExec: [c@1 ASC] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + SortExec: expr=[c@1 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, SortPreference::MaximizeParallelism) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [c@1 ASC] + SortExec: expr=[c@1 ASC], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + SortExec: expr=[c@1 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + }, + (Boundedness::Bounded, SortPreference::PreserveOrder) => { + assert_snapshot!(physical_plan, @r" + Input: + SortPreservingMergeExec: [c@1 ASC] + SortExec: expr=[c@1 ASC], preserve_partitioning=[true] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + SortExec: expr=[c@1 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + + Optimized: + SortPreservingMergeExec: [c@1 ASC] + FilterExec: c@1 > 3 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + SortExec: expr=[c@1 ASC], preserve_partitioning=[false] + CoalescePartitionsExec + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + } + } + } + Ok(()) } #[rstest] #[tokio::test] async fn test_with_multiple_child_trees( - #[values(false, true)] source_unbounded: bool, - #[values(false, true)] prefer_existing_sort: bool, + #[values(Boundedness::Unbounded, Boundedness::Bounded)] boundedness: Boundedness, + #[values(SortPreference::PreserveOrder, SortPreference::MaximizeParallelism)] + sort_pref: SortPreference, ) -> Result<()> { let schema = create_test_schema()?; - let left_sort_exprs = vec![sort_expr("a", &schema)]; - let left_source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, left_sort_exprs) - } else { - memory_exec_sorted(&schema, left_sort_exprs) + let left_ordering = [sort_expr("a", &schema)].into(); + let left_source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, left_ordering) + } + Boundedness::Bounded => memory_exec_sorted(&schema, left_ordering), }; let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); let left_coalesce_partitions = Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); - let right_sort_exprs = vec![sort_expr("a", &schema)]; - let right_source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, right_sort_exprs) - } else { - memory_exec_sorted(&schema, right_sort_exprs) + let right_ordering = [sort_expr("a", &schema)].into(); + let right_source = match boundedness { + Boundedness::Unbounded => { + stream_exec_ordered_with_projection(&schema, right_ordering) + } + Boundedness::Bounded => memory_exec_sorted(&schema, right_ordering), }; let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); @@ -1039,85 +1057,55 @@ async fn test_with_multiple_child_trees( let hash_join_exec = hash_join_exec(left_coalesce_partitions, right_coalesce_partitions); - let sort = sort_exec( - vec![sort_expr_default("a", &hash_join_exec.schema())], - hash_join_exec, - true, - ); + let ordering: LexOrdering = [sort_expr_default("a", &hash_join_exec.schema())].into(); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), hash_join_exec); + let physical_plan = sort_preserving_merge_exec(ordering, sort); + + let run = ReplaceTest::new(physical_plan) + .with_boundedness(boundedness) + .with_sort_preference(sort_pref); + + let physical_plan = run.run().await; + + allow_duplicates! { + match (boundedness, sort_pref) { + (Boundedness::Unbounded, _) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)] + CoalesceBatchesExec: target_batch_size=4096 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + CoalesceBatchesExec: target_batch_size=4096 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] + "); + }, + (Boundedness::Bounded, _) => { + assert_snapshot!(physical_plan, @r" + Input / Optimized: + SortPreservingMergeExec: [a@0 ASC] + SortExec: expr=[a@0 ASC], preserve_partitioning=[true] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)] + CoalesceBatchesExec: target_batch_size=4096 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + CoalesceBatchesExec: target_batch_size=4096 + RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST + "); + // Expected bounded results same with and without flag, because ordering get lost during intermediate executor anyway. + // Hence, no need to preserve existing ordering. + } + } + } - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("a", &sort.schema())], sort); - - // Expected inputs unbounded and bounded - let expected_input_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - let expected_input_bounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - - // Expected unbounded result (same for with and without flag) - let expected_optimized_unbounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", - ]; - - // Expected bounded results same with and without flag, because ordering get lost during intermediate executor anyway. Hence no need to preserve - // existing ordering. - let expected_optimized_bounded = [ - "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[1], output_ordering=a@0 ASC NULLS LAST", - ]; - let expected_optimized_bounded_sort_preserve = expected_optimized_bounded; - - assert_optimized_in_all_boundedness_situations!( - expected_input_unbounded, - expected_input_bounded, - expected_optimized_unbounded, - expected_optimized_bounded, - expected_optimized_bounded_sort_preserve, - physical_plan, - source_unbounded, - prefer_existing_sort - ); Ok(()) } @@ -1145,18 +1133,6 @@ fn sort_expr_options( } } -fn sort_exec( - sort_exprs: impl IntoIterator, - input: Arc, - preserve_partitioning: bool, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new( - SortExec::new(sort_exprs, input) - .with_preserve_partitioning(preserve_partitioning), - ) -} - fn repartition_exec_round_robin(input: Arc) -> Arc { Arc::new(RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(8)).unwrap()) } @@ -1184,14 +1160,6 @@ fn filter_exec(input: Arc) -> Arc { Arc::new(FilterExec::try_new(predicate, input).unwrap()) } -fn coalesce_batches_exec(input: Arc) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, 8192)) -} - -fn coalesce_partitions_exec(input: Arc) -> Arc { - Arc::new(CoalescePartitionsExec::new(input)) -} - fn hash_join_exec( left: Arc, right: Arc, @@ -1209,7 +1177,7 @@ fn hash_join_exec( &JoinType::Inner, None, PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ) @@ -1229,7 +1197,7 @@ fn create_test_schema() -> Result { // projection parameter is given static due to testing needs fn memory_exec_sorted( schema: &SchemaRef, - sort_exprs: impl IntoIterator, + ordering: LexOrdering, ) -> Arc { pub fn make_partition(schema: &SchemaRef, sz: i32) -> RecordBatch { let values = (0..sz).collect::>(); @@ -1245,7 +1213,6 @@ fn memory_exec_sorted( let rows = 5; let partitions = 1; - let sort_exprs = sort_exprs.into_iter().collect(); Arc::new({ let data: Vec> = (0..partitions) .map(|_| vec![make_partition(schema, rows)]) @@ -1254,8 +1221,79 @@ fn memory_exec_sorted( DataSourceExec::new(Arc::new( MemorySourceConfig::try_new(&data, schema.clone(), Some(projection)) .unwrap() - .try_with_sort_information(vec![sort_exprs]) + .try_with_sort_information(vec![ordering]) .unwrap(), )) }) } + +#[test] +fn test_plan_with_order_preserving_variants_preserves_fetch() -> Result<()> { + // Create a schema + let schema = create_test_schema3()?; + let parquet_sort_exprs = vec![[sort_expr("a", &schema)].into()]; + let parquet_exec = parquet_exec_with_sort(schema, parquet_sort_exprs); + let coalesced = coalesce_partitions_exec(parquet_exec.clone()) + .with_fetch(Some(10)) + .unwrap(); + + // Test sort's fetch is greater than coalesce fetch, return error because it's not reasonable + let requirements = OrderPreservationContext::new( + coalesced.clone(), + false, + vec![OrderPreservationContext::new( + parquet_exec.clone(), + false, + vec![], + )], + ); + let res = plan_with_order_preserving_variants(requirements, false, true, Some(15)); + assert_contains!(res.unwrap_err().to_string(), "CoalescePartitionsExec fetch [10] should be greater than or equal to SortExec fetch [15]"); + + // Test sort is without fetch, expected to get the fetch value from the coalesced + let requirements = OrderPreservationContext::new( + coalesced.clone(), + false, + vec![OrderPreservationContext::new( + parquet_exec.clone(), + false, + vec![], + )], + ); + let res = plan_with_order_preserving_variants(requirements, false, true, None)?; + assert_eq!(res.plan.fetch(), Some(10),); + + // Test sort's fetch is less than coalesces fetch, expected to get the fetch value from the sort + let requirements = OrderPreservationContext::new( + coalesced, + false, + vec![OrderPreservationContext::new(parquet_exec, false, vec![])], + ); + let res = plan_with_order_preserving_variants(requirements, false, true, Some(5))?; + assert_eq!(res.plan.fetch(), Some(5),); + Ok(()) +} + +#[test] +fn test_plan_with_order_breaking_variants_preserves_fetch() -> Result<()> { + let schema = create_test_schema3()?; + let parquet_sort_exprs: LexOrdering = [sort_expr("a", &schema)].into(); + let parquet_exec = parquet_exec_with_sort(schema, vec![parquet_sort_exprs.clone()]); + let spm = sort_preserving_merge_exec_with_fetch( + parquet_sort_exprs, + parquet_exec.clone(), + 10, + ); + let requirements = OrderPreservationContext::new( + spm, + true, + vec![OrderPreservationContext::new( + parquet_exec.clone(), + true, + vec![], + )], + ); + let res = plan_with_order_breaking_variants(requirements)?; + assert_eq!(res.plan.fetch(), Some(10)); + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/sanity_checker.rs b/datafusion/core/tests/physical_optimizer/sanity_checker.rs index a73d084a081f3..ce6eb13c86c44 100644 --- a/datafusion/core/tests/physical_optimizer/sanity_checker.rs +++ b/datafusion/core/tests/physical_optimizer/sanity_checker.rs @@ -15,11 +15,13 @@ // specific language governing permissions and limitations // under the License. +use insta::assert_snapshot; use std::sync::Arc; use crate::physical_optimizer::test_utils::{ bounded_window_exec, global_limit_exec, local_limit_exec, memory_exec, - repartition_exec, sort_exec, sort_expr_options, sort_merge_join_exec, + projection_exec, repartition_exec, sort_exec, sort_expr, sort_expr_options, + sort_merge_join_exec, sort_preserving_merge_exec, union_exec, }; use arrow::compute::SortOptions; @@ -27,9 +29,10 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{JoinType, Result}; -use datafusion_physical_expr::expressions::col; +use datafusion_common::{JoinType, Result, ScalarValue}; +use datafusion_physical_expr::expressions::{col, Literal}; use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::repartition::RepartitionExec; @@ -397,34 +400,32 @@ fn assert_sanity_check(plan: &Arc, is_sane: bool) { ); } -/// Check if the plan we created is as expected by comparing the plan -/// formatted as a string. -fn assert_plan(plan: &dyn ExecutionPlan, expected_lines: Vec<&str>) { - let plan_str = displayable(plan).indent(true).to_string(); - let actual_lines: Vec<&str> = plan_str.trim().lines().collect(); - assert_eq!(actual_lines, expected_lines); -} - #[tokio::test] /// Tests that plan is valid when the sort requirements are satisfied. async fn test_bounded_window_agg_sort_requirement() -> Result<()> { let schema = create_test_schema(); let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr_options( + let ordering: LexOrdering = [sort_expr_options( "c9", &source.schema(), SortOptions { descending: false, nulls_first: false, }, - )]; - let sort = sort_exec(sort_exprs.clone(), source); - let bw = bounded_window_exec("c9", sort_exprs, sort); - assert_plan(bw.as_ref(), vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]" - ]); + )] + .into(); + let sort = sort_exec(ordering.clone(), source); + let bw = bounded_window_exec("c9", ordering, sort); + let plan_str = displayable(bw.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r#" + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + "# + ); assert_sanity_check(&bw, true); Ok(()) } @@ -443,10 +444,15 @@ async fn test_bounded_window_agg_no_sort_requirement() -> Result<()> { }, )]; let bw = bounded_window_exec("c9", sort_exprs, source); - assert_plan(bw.as_ref(), vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " DataSourceExec: partitions=1, partition_sizes=[0]" - ]); + let plan_str = displayable(bw.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r#" + BoundedWindowAggExec: wdw=[count: Field { name: "count", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: partitions=1, partition_sizes=[0] + "# + ); // Order requirement of the `BoundedWindowAggExec` is not satisfied. We expect to receive error during sanity check. assert_sanity_check(&bw, false); Ok(()) @@ -458,14 +464,16 @@ async fn test_bounded_window_agg_no_sort_requirement() -> Result<()> { async fn test_global_limit_single_partition() -> Result<()> { let schema = create_test_schema(); let source = memory_exec(&schema); - let limit = global_limit_exec(source); - - assert_plan( - limit.as_ref(), - vec![ - "GlobalLimitExec: skip=0, fetch=100", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ], + let limit = global_limit_exec(source, 0, Some(100)); + + let plan_str = displayable(limit.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r" + GlobalLimitExec: skip=0, fetch=100 + DataSourceExec: partitions=1, partition_sizes=[0] + " ); assert_sanity_check(&limit, true); Ok(()) @@ -477,15 +485,17 @@ async fn test_global_limit_single_partition() -> Result<()> { async fn test_global_limit_multi_partition() -> Result<()> { let schema = create_test_schema(); let source = memory_exec(&schema); - let limit = global_limit_exec(repartition_exec(source)); - - assert_plan( - limit.as_ref(), - vec![ - "GlobalLimitExec: skip=0, fetch=100", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ], + let limit = global_limit_exec(repartition_exec(source), 0, Some(100)); + + let plan_str = displayable(limit.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r" + GlobalLimitExec: skip=0, fetch=100 + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + " ); // Distribution requirement of the `GlobalLimitExec` is not satisfied. We expect to receive error during sanity check. assert_sanity_check(&limit, false); @@ -497,14 +507,16 @@ async fn test_global_limit_multi_partition() -> Result<()> { async fn test_local_limit() -> Result<()> { let schema = create_test_schema(); let source = memory_exec(&schema); - let limit = local_limit_exec(source); - - assert_plan( - limit.as_ref(), - vec![ - "LocalLimitExec: fetch=100", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ], + let limit = local_limit_exec(source, 100); + + let plan_str = displayable(limit.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=100 + DataSourceExec: partitions=1, partition_sizes=[0] + " ); assert_sanity_check(&limit, true); Ok(()) @@ -518,12 +530,12 @@ async fn test_sort_merge_join_satisfied() -> Result<()> { let source1 = memory_exec(&schema1); let source2 = memory_exec(&schema2); let sort_opts = SortOptions::default(); - let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; - let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; - let left = sort_exec(sort_exprs1, source1); - let right = sort_exec(sort_exprs2, source2); - let left_jcol = col("c9", &left.schema()).unwrap(); - let right_jcol = col("a", &right.schema()).unwrap(); + let ordering1 = [sort_expr_options("c9", &source1.schema(), sort_opts)].into(); + let ordering2 = [sort_expr_options("a", &source2.schema(), sort_opts)].into(); + let left = sort_exec(ordering1, source1); + let right = sort_exec(ordering2, source2); + let left_jcol = col("c9", &left.schema())?; + let right_jcol = col("a", &right.schema())?; let left = Arc::new(RepartitionExec::try_new( left, Partitioning::Hash(vec![left_jcol.clone()], 10), @@ -538,17 +550,19 @@ async fn test_sort_merge_join_satisfied() -> Result<()> { let join_ty = JoinType::Inner; let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); - assert_plan( - smj.as_ref(), - vec![ - "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", - " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", - " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ], + let plan_str = displayable(smj.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r" + SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)] + RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1 + SortExec: expr=[c9@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + " ); assert_sanity_check(&smj, true); Ok(()) @@ -562,15 +576,16 @@ async fn test_sort_merge_join_order_missing() -> Result<()> { let schema2 = create_test_schema2(); let source1 = memory_exec(&schema1); let right = memory_exec(&schema2); - let sort_exprs1 = vec![sort_expr_options( + let ordering1 = [sort_expr_options( "c9", &source1.schema(), SortOptions::default(), - )]; - let left = sort_exec(sort_exprs1, source1); + )] + .into(); + let left = sort_exec(ordering1, source1); // Missing sort of the right child here.. - let left_jcol = col("c9", &left.schema()).unwrap(); - let right_jcol = col("a", &right.schema()).unwrap(); + let left_jcol = col("c9", &left.schema())?; + let right_jcol = col("a", &right.schema())?; let left = Arc::new(RepartitionExec::try_new( left, Partitioning::Hash(vec![left_jcol.clone()], 10), @@ -585,16 +600,18 @@ async fn test_sort_merge_join_order_missing() -> Result<()> { let join_ty = JoinType::Inner; let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); - assert_plan( - smj.as_ref(), - vec![ - "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", - " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", - " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ], + let plan_str = displayable(smj.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r" + SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)] + RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1 + SortExec: expr=[c9@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[0] + " ); // Order requirement for the `SortMergeJoin` is not satisfied for right child. We expect to receive error during sanity check. assert_sanity_check(&smj, false); @@ -610,16 +627,16 @@ async fn test_sort_merge_join_dist_missing() -> Result<()> { let source1 = memory_exec(&schema1); let source2 = memory_exec(&schema2); let sort_opts = SortOptions::default(); - let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; - let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; - let left = sort_exec(sort_exprs1, source1); - let right = sort_exec(sort_exprs2, source2); + let ordering1 = [sort_expr_options("c9", &source1.schema(), sort_opts)].into(); + let ordering2 = [sort_expr_options("a", &source2.schema(), sort_opts)].into(); + let left = sort_exec(ordering1, source1); + let right = sort_exec(ordering2, source2); let right = Arc::new(RepartitionExec::try_new( right, Partitioning::RoundRobinBatch(10), )?); - let left_jcol = col("c9", &left.schema()).unwrap(); - let right_jcol = col("a", &right.schema()).unwrap(); + let left_jcol = col("c9", &left.schema())?; + let right_jcol = col("a", &right.schema())?; let left = Arc::new(RepartitionExec::try_new( left, Partitioning::Hash(vec![left_jcol.clone()], 10), @@ -631,19 +648,95 @@ async fn test_sort_merge_join_dist_missing() -> Result<()> { let join_ty = JoinType::Inner; let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); - assert_plan( - smj.as_ref(), - vec![ - "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", - " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", - " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]", - ], + let plan_str = displayable(smj.as_ref()).indent(true).to_string(); + let actual = plan_str.trim(); + assert_snapshot!( + actual, + @r" + SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)] + RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1 + SortExec: expr=[c9@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1 + SortExec: expr=[a@0 ASC], preserve_partitioning=[false] + DataSourceExec: partitions=1, partition_sizes=[0] + " ); // Distribution requirement for the `SortMergeJoin` is not satisfied for right child (has round-robin partitioning). We expect to receive error during sanity check. assert_sanity_check(&smj, false); Ok(()) } + +/// A particular edge case. +/// +/// See . +#[tokio::test] +async fn test_union_with_sorts_and_constants() -> Result<()> { + let schema_in = create_test_schema2(); + + let proj_exprs_1 = vec![ + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_1".to_owned(), + ), + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_2".to_owned(), + ), + (col("a", &schema_in).unwrap(), "a".to_owned()), + ]; + let proj_exprs_2 = vec![ + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_1".to_owned(), + ), + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("bar".to_owned())))) as _, + "const_2".to_owned(), + ), + (col("a", &schema_in).unwrap(), "a".to_owned()), + ]; + + let source_1 = memory_exec(&schema_in); + let source_1 = projection_exec(proj_exprs_1.clone(), source_1).unwrap(); + let schema_sources = source_1.schema(); + let ordering_sources: LexOrdering = + [sort_expr("a", &schema_sources).nulls_last()].into(); + let source_1 = sort_exec(ordering_sources.clone(), source_1); + + let source_2 = memory_exec(&schema_in); + let source_2 = projection_exec(proj_exprs_2, source_2).unwrap(); + let source_2 = sort_exec(ordering_sources.clone(), source_2); + + let plan = union_exec(vec![source_1, source_2]); + + let schema_out = plan.schema(); + let ordering_out: LexOrdering = [ + sort_expr("const_1", &schema_out).nulls_last(), + sort_expr("const_2", &schema_out).nulls_last(), + sort_expr("a", &schema_out).nulls_last(), + ] + .into(); + + let plan = sort_preserving_merge_exec(ordering_out, plan); + + let plan_str = displayable(plan.as_ref()).indent(true).to_string(); + let plan_str = plan_str.trim(); + assert_snapshot!( + plan_str, + @r" + SortPreservingMergeExec: [const_1@0 ASC NULLS LAST, const_2@1 ASC NULLS LAST, a@2 ASC NULLS LAST] + UnionExec + SortExec: expr=[a@2 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[foo as const_1, foo as const_2, a@0 as a] + DataSourceExec: partitions=1, partition_sizes=[0] + SortExec: expr=[a@2 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[foo as const_1, bar as const_2, a@0 as a] + DataSourceExec: partitions=1, partition_sizes=[0] + " + ); + + assert_sanity_check(&plan, true); + + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index 4587f99989d34..8ca33f3d4abb9 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::fmt::Formatter; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use arrow::array::Int32Array; use arrow::compute::SortOptions; @@ -30,19 +30,20 @@ use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::physical_plan::ParquetSource; use datafusion::datasource::source::DataSourceExec; use datafusion_common::config::ConfigOptions; +use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; -use datafusion_common::{JoinType, Result}; +use datafusion_common::{ColumnStatistics, JoinType, NullEquality, Result, Statistics}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; -use datafusion_physical_expr::expressions::col; -use datafusion_physical_expr::{expressions, PhysicalExpr}; +use datafusion_physical_expr::expressions::{self, col}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ - LexOrdering, LexRequirement, PhysicalSortExpr, + LexOrdering, OrderingRequirements, PhysicalSortExpr, }; use datafusion_physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; use datafusion_physical_optimizer::PhysicalOptimizerRule; @@ -55,6 +56,7 @@ use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::{JoinFilter, JoinOn}; use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -68,10 +70,10 @@ use datafusion_physical_plan::{ }; /// Create a non sorted parquet exec -pub fn parquet_exec(schema: &SchemaRef) -> Arc { +pub fn parquet_exec(schema: SchemaRef) -> Arc { let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), + schema, Arc::new(ParquetSource::default()), ) .with_file(PartitionedFile::new("x".to_string(), 100)) @@ -82,11 +84,12 @@ pub fn parquet_exec(schema: &SchemaRef) -> Arc { /// Create a single parquet file that is sorted pub(crate) fn parquet_exec_with_sort( + schema: SchemaRef, output_ordering: Vec, ) -> Arc { let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema(), + schema, Arc::new(ParquetSource::default()), ) .with_file(PartitionedFile::new("x".to_string(), 100)) @@ -96,38 +99,92 @@ pub(crate) fn parquet_exec_with_sort( DataSourceExec::from_data_source(config) } +fn int64_stats() -> ColumnStatistics { + ColumnStatistics { + null_count: Precision::Absent, + sum_value: Precision::Absent, + max_value: Precision::Exact(1_000_000.into()), + min_value: Precision::Exact(0.into()), + distinct_count: Precision::Absent, + } +} + +fn column_stats() -> Vec { + vec![ + int64_stats(), // a + int64_stats(), // b + int64_stats(), // c + ColumnStatistics::default(), + ColumnStatistics::default(), + ] +} + +/// Create parquet datasource exec using schema from [`schema`]. +pub(crate) fn parquet_exec_with_stats(file_size: u64) -> Arc { + let mut statistics = Statistics::new_unknown(&schema()); + statistics.num_rows = Precision::Inexact(10000); + statistics.column_statistics = column_stats(); + + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("test:///").unwrap(), + schema(), + Arc::new(ParquetSource::new(Default::default())), + ) + .with_file(PartitionedFile::new("x".to_string(), file_size)) + .with_statistics(statistics) + .build(); + + assert_eq!( + config.file_source.statistics().unwrap().num_rows, + Precision::Inexact(10000) + ); + DataSourceExec::from_data_source(config) +} + pub fn schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - Field::new("d", DataType::Int32, true), - Field::new("e", DataType::Boolean, true), - ])) + static SCHEMA: LazyLock = LazyLock::new(|| { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Boolean, true), + ])) + }); + Arc::clone(&SCHEMA) } pub fn create_test_schema() -> Result { - let nullable_column = Field::new("nullable_col", DataType::Int32, true); - let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false); - let schema = Arc::new(Schema::new(vec![nullable_column, non_nullable_column])); + static SCHEMA: LazyLock = LazyLock::new(|| { + let nullable_column = Field::new("nullable_col", DataType::Int32, true); + let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false); + Arc::new(Schema::new(vec![nullable_column, non_nullable_column])) + }); + let schema = Arc::clone(&SCHEMA); Ok(schema) } pub fn create_test_schema2() -> Result { - let col_a = Field::new("col_a", DataType::Int32, true); - let col_b = Field::new("col_b", DataType::Int32, true); - let schema = Arc::new(Schema::new(vec![col_a, col_b])); + static SCHEMA: LazyLock = LazyLock::new(|| { + let col_a = Field::new("col_a", DataType::Int32, true); + let col_b = Field::new("col_b", DataType::Int32, true); + Arc::new(Schema::new(vec![col_a, col_b])) + }); + let schema = Arc::clone(&SCHEMA); Ok(schema) } // Generate a schema which consists of 5 columns (a, b, c, d, e) pub fn create_test_schema3() -> Result { - let a = Field::new("a", DataType::Int32, true); - let b = Field::new("b", DataType::Int32, false); - let c = Field::new("c", DataType::Int32, true); - let d = Field::new("d", DataType::Int32, false); - let e = Field::new("e", DataType::Int32, false); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e])); + static SCHEMA: LazyLock = LazyLock::new(|| { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, false); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, false); + let e = Field::new("e", DataType::Int32, false); + Arc::new(Schema::new(vec![a, b, c, d, e])) + }); + let schema = Arc::clone(&SCHEMA); Ok(schema) } @@ -145,7 +202,7 @@ pub fn sort_merge_join_exec( None, *join_type, vec![SortOptions::default(); join_on.len()], - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ) @@ -191,7 +248,7 @@ pub fn hash_join_exec( join_type, None, PartitionMode::Partitioned, - true, + NullEquality::NullEqualsNothing, )?)) } @@ -200,17 +257,28 @@ pub fn bounded_window_exec( sort_exprs: impl IntoIterator, input: Arc, ) -> Arc { - let sort_exprs: LexOrdering = sort_exprs.into_iter().collect(); + bounded_window_exec_with_partition(col_name, sort_exprs, &[], input) +} + +pub fn bounded_window_exec_with_partition( + col_name: &str, + sort_exprs: impl IntoIterator, + partition_by: &[Arc], + input: Arc, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect::>(); let schema = input.schema(); let window_expr = create_window_expr( &WindowFunctionDefinition::AggregateUDF(count_udaf()), "count".to_owned(), &[col(col_name, &schema).unwrap()], - &[], - sort_exprs.as_ref(), + partition_by, + &sort_exprs, Arc::new(WindowFrame::new(Some(false))), - schema.as_ref(), + schema, false, + false, + None, ) .unwrap(); @@ -233,36 +301,37 @@ pub fn filter_exec( } pub fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, input: Arc, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) + Arc::new(SortPreservingMergeExec::new(ordering, input)) } pub fn sort_preserving_merge_exec_with_fetch( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, input: Arc, fetch: usize, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input).with_fetch(Some(fetch))) + Arc::new(SortPreservingMergeExec::new(ordering, input).with_fetch(Some(fetch))) } pub fn union_exec(input: Vec>) -> Arc { - Arc::new(UnionExec::new(input)) + UnionExec::try_new(input).unwrap() } -pub fn limit_exec(input: Arc) -> Arc { - global_limit_exec(local_limit_exec(input)) -} - -pub fn local_limit_exec(input: Arc) -> Arc { - Arc::new(LocalLimitExec::new(input, 100)) +pub fn local_limit_exec( + input: Arc, + fetch: usize, +) -> Arc { + Arc::new(LocalLimitExec::new(input, fetch)) } -pub fn global_limit_exec(input: Arc) -> Arc { - Arc::new(GlobalLimitExec::new(input, 0, Some(100))) +pub fn global_limit_exec( + input: Arc, + skip: usize, + fetch: Option, +) -> Arc { + Arc::new(GlobalLimitExec::new(input, skip, fetch)) } pub fn repartition_exec(input: Arc) -> Arc { @@ -292,30 +361,50 @@ pub fn aggregate_exec(input: Arc) -> Arc { ) } -pub fn coalesce_batches_exec(input: Arc) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, 128)) +pub fn coalesce_batches_exec( + input: Arc, + batch_size: usize, +) -> Arc { + Arc::new(CoalesceBatchesExec::new(input, batch_size)) } pub fn sort_exec( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, input: Arc, ) -> Arc { - sort_exec_with_fetch(sort_exprs, None, input) + sort_exec_with_fetch(ordering, None, input) +} + +pub fn sort_exec_with_preserve_partitioning( + ordering: LexOrdering, + input: Arc, +) -> Arc { + Arc::new(SortExec::new(ordering, input).with_preserve_partitioning(true)) } pub fn sort_exec_with_fetch( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, fetch: Option, input: Arc, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortExec::new(sort_exprs, input).with_fetch(fetch)) + Arc::new(SortExec::new(ordering, input).with_fetch(fetch)) +} + +pub fn projection_exec( + expr: Vec<(Arc, String)>, + input: Arc, +) -> Result> { + let proj_exprs: Vec = expr + .into_iter() + .map(|(expr, alias)| ProjectionExpr { expr, alias }) + .collect(); + Ok(Arc::new(ProjectionExec::try_new(proj_exprs, input)?)) } /// A test [`ExecutionPlan`] whose requirements can be configured. #[derive(Debug)] pub struct RequirementsTestExec { - required_input_ordering: LexOrdering, + required_input_ordering: Option, maintains_input_order: bool, input: Arc, } @@ -323,7 +412,7 @@ pub struct RequirementsTestExec { impl RequirementsTestExec { pub fn new(input: Arc) -> Self { Self { - required_input_ordering: LexOrdering::default(), + required_input_ordering: None, maintains_input_order: true, input, } @@ -332,7 +421,7 @@ impl RequirementsTestExec { /// sets the required input ordering pub fn with_required_input_ordering( mut self, - required_input_ordering: LexOrdering, + required_input_ordering: Option, ) -> Self { self.required_input_ordering = required_input_ordering; self @@ -377,9 +466,11 @@ impl ExecutionPlan for RequirementsTestExec { self.input.properties() } - fn required_input_ordering(&self) -> Vec> { - let requirement = LexRequirement::from(self.required_input_ordering.clone()); - vec![Some(requirement)] + fn required_input_ordering(&self) -> Vec> { + vec![self + .required_input_ordering + .as_ref() + .map(|ordering| OrderingRequirements::from(ordering.clone()))] } fn maintains_input_order(&self) -> Vec { @@ -436,13 +527,6 @@ pub fn check_integrity(context: PlanContext) -> Result Vec<&str> { - plan.split('\n') - .map(|s| s.trim()) - .filter(|s| !s.is_empty()) - .collect() -} - // construct a stream partition for test purposes #[derive(Debug)] pub struct TestStreamPartition { @@ -458,13 +542,28 @@ impl PartitionStream for TestStreamPartition { } } -/// Create an unbounded stream exec +/// Create an unbounded stream table without data ordering. +pub fn stream_exec(schema: &SchemaRef) -> Arc { + Arc::new( + StreamingTableExec::try_new( + Arc::clone(schema), + vec![Arc::new(TestStreamPartition { + schema: Arc::clone(schema), + }) as _], + None, + vec![], + true, + None, + ) + .unwrap(), + ) +} + +/// Create an unbounded stream table with data ordering. pub fn stream_exec_ordered( schema: &SchemaRef, - sort_exprs: impl IntoIterator, + ordering: LexOrdering, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new( StreamingTableExec::try_new( Arc::clone(schema), @@ -472,7 +571,7 @@ pub fn stream_exec_ordered( schema: Arc::clone(schema), }) as _], None, - vec![sort_exprs], + vec![ordering], true, None, ) @@ -480,12 +579,11 @@ pub fn stream_exec_ordered( ) } -// Creates a stream exec source for the test purposes +/// Create an unbounded stream table with data ordering and built-in projection. pub fn stream_exec_ordered_with_projection( schema: &SchemaRef, - sort_exprs: impl IntoIterator, + ordering: LexOrdering, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); let projection: Vec = vec![0, 2, 3]; Arc::new( @@ -495,7 +593,7 @@ pub fn stream_exec_ordered_with_projection( schema: Arc::clone(schema), }) as _], Some(&projection), - vec![sort_exprs], + vec![ordering], true, None, ) @@ -542,26 +640,15 @@ pub fn build_group_by(input_schema: &SchemaRef, columns: Vec) -> Physica PhysicalGroupBy::new_single(group_by_expr.clone()) } -pub fn assert_plan_matches_expected( - plan: &Arc, - expected: &[&str], -) -> Result<()> { - let expected_lines: Vec<&str> = expected.to_vec(); +pub fn get_optimized_plan(plan: &Arc) -> Result { let config = ConfigOptions::new(); let optimized = LimitedDistinctAggregation::new().optimize(Arc::clone(plan), &config)?; let optimized_result = displayable(optimized.as_ref()).indent(true).to_string(); - let actual_lines = trim_plan_display(&optimized_result); - - assert_eq!( - &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines - ); - Ok(()) + Ok(optimized_result) } /// Describe the type of aggregate being tested diff --git a/datafusion/core/tests/physical_optimizer/window_optimize.rs b/datafusion/core/tests/physical_optimizer/window_optimize.rs new file mode 100644 index 0000000000000..fc1e6444d756e --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/window_optimize.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +mod test { + use arrow::array::{Int32Array, RecordBatch}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_datasource::memory::MemorySourceConfig; + use datafusion_datasource::source::DataSourceExec; + use datafusion_execution::TaskContext; + use datafusion_expr::WindowFrame; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::expressions::{col, Column}; + use datafusion_physical_expr::window::PlainAggregateWindowExpr; + use datafusion_physical_plan::windows::BoundedWindowAggExec; + use datafusion_physical_plan::{common, ExecutionPlan, InputOrderMode}; + use std::sync::Arc; + + /// Test case for + #[tokio::test] + async fn test_window_constant_aggregate() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let c = Arc::new(Column::new("b", 1)); + let cnt = AggregateExprBuilder::new(count_udaf(), vec![c]) + .schema(schema.clone()) + .alias("t") + .build()?; + let partition = [col("a", &schema)?]; + let frame = WindowFrame::new(None); + let plain = PlainAggregateWindowExpr::new( + Arc::new(cnt), + &partition, + &[], + Arc::new(frame), + None, + ); + + let bounded_agg_exec = BoundedWindowAggExec::try_new( + vec![Arc::new(plain)], + source, + InputOrderMode::Linear, + true, + )?; + let task_ctx = Arc::new(TaskContext::default()); + common::collect(bounded_agg_exec.execute(0, task_ctx)?).await?; + + Ok(()) + } + + pub fn mock_data() -> Result> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![ + Some(1), + Some(1), + Some(3), + Some(2), + Some(1), + ])), + Arc::new(Int32Array::from(vec![ + Some(1), + Some(6), + Some(2), + Some(8), + Some(9), + ])), + ], + )?; + + MemorySourceConfig::try_new_exec(&[vec![batch]], Arc::clone(&schema), None) + } +} diff --git a/datafusion/core/tests/schema_adapter/mod.rs b/datafusion/core/tests/schema_adapter/mod.rs new file mode 100644 index 0000000000000..2f81a43f4736e --- /dev/null +++ b/datafusion/core/tests/schema_adapter/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod schema_adapter_integration_tests; diff --git a/datafusion/core/tests/schema_adapter/schema_adapter_integration_tests.rs b/datafusion/core/tests/schema_adapter/schema_adapter_integration_tests.rs new file mode 100644 index 0000000000000..c3c92a9028d67 --- /dev/null +++ b/datafusion/core/tests/schema_adapter/schema_adapter_integration_tests.rs @@ -0,0 +1,363 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::RecordBatch; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use bytes::{BufMut, BytesMut}; +use datafusion::common::Result; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::{ + ArrowSource, CsvSource, FileSource, JsonSource, ParquetSource, +}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use datafusion_common::ColumnStatistics; +use datafusion_datasource::file_scan_config::FileScanConfigBuilder; +use datafusion_datasource::schema_adapter::{ + SchemaAdapter, SchemaAdapterFactory, SchemaMapper, +}; +use datafusion_datasource::source::DataSourceExec; +use datafusion_execution::object_store::ObjectStoreUrl; +use object_store::{memory::InMemory, path::Path, ObjectStore}; +use parquet::arrow::ArrowWriter; + +async fn write_parquet(batch: RecordBatch, store: Arc, path: &str) { + let mut out = BytesMut::new().writer(); + { + let mut writer = ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + } + let data = out.into_inner().freeze(); + store.put(&Path::from(path), data.into()).await.unwrap(); +} + +/// A schema adapter factory that transforms column names to uppercase +#[derive(Debug, PartialEq)] +struct UppercaseAdapterFactory {} + +impl SchemaAdapterFactory for UppercaseAdapterFactory { + fn create( + &self, + projected_table_schema: SchemaRef, + _table_schema: SchemaRef, + ) -> Box { + Box::new(UppercaseAdapter { + table_schema: projected_table_schema, + }) + } +} + +/// Schema adapter that transforms column names to uppercase +#[derive(Debug)] +struct UppercaseAdapter { + table_schema: SchemaRef, +} + +impl SchemaAdapter for UppercaseAdapter { + fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { + let field = self.table_schema.field(index); + let uppercase_name = field.name().to_uppercase(); + file_schema + .fields() + .iter() + .position(|f| f.name().to_uppercase() == uppercase_name) + } + + fn map_schema( + &self, + file_schema: &Schema, + ) -> Result<(Arc, Vec)> { + let mut projection = Vec::new(); + + // Map each field in the table schema to the corresponding field in the file schema + for table_field in self.table_schema.fields() { + let uppercase_name = table_field.name().to_uppercase(); + if let Some(pos) = file_schema + .fields() + .iter() + .position(|f| f.name().to_uppercase() == uppercase_name) + { + projection.push(pos); + } + } + + let mapper = UppercaseSchemaMapper { + output_schema: self.output_schema(), + projection: projection.clone(), + }; + + Ok((Arc::new(mapper), projection)) + } +} + +impl UppercaseAdapter { + fn output_schema(&self) -> SchemaRef { + let fields: Vec = self + .table_schema + .fields() + .iter() + .map(|f| { + Field::new( + f.name().to_uppercase().as_str(), + f.data_type().clone(), + f.is_nullable(), + ) + }) + .collect(); + + Arc::new(Schema::new(fields)) + } +} + +#[derive(Debug)] +struct UppercaseSchemaMapper { + output_schema: SchemaRef, + projection: Vec, +} + +impl SchemaMapper for UppercaseSchemaMapper { + fn map_batch(&self, batch: RecordBatch) -> Result { + let columns = self + .projection + .iter() + .map(|&i| batch.column(i).clone()) + .collect::>(); + Ok(RecordBatch::try_new(self.output_schema.clone(), columns)?) + } + + fn map_column_statistics( + &self, + stats: &[ColumnStatistics], + ) -> Result> { + Ok(self + .projection + .iter() + .map(|&i| stats.get(i).cloned().unwrap_or_default()) + .collect()) + } +} + +#[cfg(feature = "parquet")] +#[tokio::test] +async fn test_parquet_integration_with_schema_adapter() -> Result<()> { + // Create test data + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])), + vec![ + Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])), + Arc::new(arrow::array::StringArray::from(vec!["a", "b", "c"])), + ], + )?; + + let store = Arc::new(InMemory::new()) as Arc; + let store_url = ObjectStoreUrl::parse("memory://").unwrap(); + let path = "test.parquet"; + write_parquet(batch.clone(), store.clone(), path).await; + + // Get the actual file size from the object store + let object_meta = store.head(&Path::from(path)).await?; + let file_size = object_meta.size; + + // Create a session context and register the object store + let ctx = SessionContext::new(); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + + // Create a ParquetSource with the adapter factory + let file_source = ParquetSource::default() + .with_schema_adapter_factory(Arc::new(UppercaseAdapterFactory {}))?; + + // Create a table schema with uppercase column names + let table_schema = Arc::new(Schema::new(vec![ + Field::new("ID", DataType::Int32, false), + Field::new("NAME", DataType::Utf8, true), + ])); + + let config = FileScanConfigBuilder::new(store_url, table_schema.clone(), file_source) + .with_file(PartitionedFile::new(path, file_size)) + .build(); + + // Create a data source executor + let exec = DataSourceExec::from_data_source(config); + + // Collect results + let task_ctx = ctx.task_ctx(); + let stream = exec.execute(0, task_ctx)?; + let batches = datafusion::physical_plan::common::collect(stream).await?; + + // There should be one batch + assert_eq!(batches.len(), 1); + + // Verify the schema has the uppercase column names + let result_schema = batches[0].schema(); + assert_eq!(result_schema.field(0).name(), "ID"); + assert_eq!(result_schema.field(1).name(), "NAME"); + + Ok(()) +} + +#[cfg(feature = "parquet")] +#[tokio::test] +async fn test_parquet_integration_with_schema_adapter_and_expression_rewriter( +) -> Result<()> { + // Create test data + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])), + vec![ + Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])), + Arc::new(arrow::array::StringArray::from(vec!["a", "b", "c"])), + ], + )?; + + let store = Arc::new(InMemory::new()) as Arc; + let store_url = ObjectStoreUrl::parse("memory://").unwrap(); + let path = "test.parquet"; + write_parquet(batch.clone(), store.clone(), path).await; + + // Get the actual file size from the object store + let object_meta = store.head(&Path::from(path)).await?; + let file_size = object_meta.size; + + // Create a session context and register the object store + let ctx = SessionContext::new(); + ctx.register_object_store(store_url.as_ref(), Arc::clone(&store)); + + // Create a ParquetSource with the adapter factory + let file_source = ParquetSource::default() + .with_schema_adapter_factory(Arc::new(UppercaseAdapterFactory {}))?; + + let config = FileScanConfigBuilder::new(store_url, batch.schema(), file_source) + .with_file(PartitionedFile::new(path, file_size)) + .build(); + + // Create a data source executor + let exec = DataSourceExec::from_data_source(config); + + // Collect results + let task_ctx = ctx.task_ctx(); + let stream = exec.execute(0, task_ctx)?; + let batches = datafusion::physical_plan::common::collect(stream).await?; + + // There should be one batch + assert_eq!(batches.len(), 1); + + // Verify the schema has the original column names (schema adapter not applied in DataSourceExec) + let result_schema = batches[0].schema(); + assert_eq!(result_schema.field(0).name(), "id"); + assert_eq!(result_schema.field(1).name(), "name"); + + Ok(()) +} + +#[tokio::test] +async fn test_multi_source_schema_adapter_reuse() -> Result<()> { + // This test verifies that the same schema adapter factory can be reused + // across different file source types. This is important for ensuring that: + // 1. The schema adapter factory interface works uniformly across all source types + // 2. The factory can be shared and cloned efficiently using Arc + // 3. Various data source implementations correctly implement the schema adapter factory pattern + + // Create a test factory + let factory = Arc::new(UppercaseAdapterFactory {}); + + // Test ArrowSource + { + let source = ArrowSource::default(); + let source_with_adapter = source + .clone() + .with_schema_adapter_factory(factory.clone()) + .unwrap(); + + let base_source: Arc = source.into(); + assert!(base_source.schema_adapter_factory().is_none()); + assert!(source_with_adapter.schema_adapter_factory().is_some()); + + let retrieved_factory = source_with_adapter.schema_adapter_factory().unwrap(); + assert_eq!( + format!("{:?}", retrieved_factory.as_ref()), + format!("{:?}", factory.as_ref()) + ); + } + + // Test ParquetSource + #[cfg(feature = "parquet")] + { + let source = ParquetSource::default(); + let source_with_adapter = source + .clone() + .with_schema_adapter_factory(factory.clone()) + .unwrap(); + + let base_source: Arc = source.into(); + assert!(base_source.schema_adapter_factory().is_none()); + assert!(source_with_adapter.schema_adapter_factory().is_some()); + + let retrieved_factory = source_with_adapter.schema_adapter_factory().unwrap(); + assert_eq!( + format!("{:?}", retrieved_factory.as_ref()), + format!("{:?}", factory.as_ref()) + ); + } + + // Test CsvSource + { + let source = CsvSource::default(); + let source_with_adapter = source + .clone() + .with_schema_adapter_factory(factory.clone()) + .unwrap(); + + let base_source: Arc = source.into(); + assert!(base_source.schema_adapter_factory().is_none()); + assert!(source_with_adapter.schema_adapter_factory().is_some()); + + let retrieved_factory = source_with_adapter.schema_adapter_factory().unwrap(); + assert_eq!( + format!("{:?}", retrieved_factory.as_ref()), + format!("{:?}", factory.as_ref()) + ); + } + + // Test JsonSource + { + let source = JsonSource::default(); + let source_with_adapter = source + .clone() + .with_schema_adapter_factory(factory.clone()) + .unwrap(); + + let base_source: Arc = source.into(); + assert!(base_source.schema_adapter_factory().is_none()); + assert!(source_with_adapter.schema_adapter_factory().is_some()); + + let retrieved_factory = source_with_adapter.schema_adapter_factory().unwrap(); + assert_eq!( + format!("{:?}", retrieved_factory.as_ref()), + format!("{:?}", factory.as_ref()) + ); + } + + Ok(()) +} diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates/basic.rs similarity index 78% rename from datafusion/core/tests/sql/aggregates.rs rename to datafusion/core/tests/sql/aggregates/basic.rs index 52372e01d41ac..4b421b5294e01 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates/basic.rs @@ -16,7 +16,10 @@ // under the License. use super::*; -use datafusion::scalar::ScalarValue; +use datafusion::common::test_util::batches_to_string; +use datafusion_catalog::MemTable; +use datafusion_common::ScalarValue; +use insta::assert_snapshot; #[tokio::test] async fn csv_query_array_agg_distinct() -> Result<()> { @@ -45,11 +48,11 @@ async fn csv_query_array_agg_distinct() -> Result<()> { let column = actual[0].column(0); assert_eq!(column.len(), 1); let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&column)?; - let mut scalars = scalar_vec[0].clone(); + let mut scalars = scalar_vec[0].as_ref().unwrap().clone(); // workaround lack of Ord of ScalarValue let cmp = |a: &ScalarValue, b: &ScalarValue| { - a.partial_cmp(b).expect("Can compare ScalarValues") + a.try_cmp(b).expect("Can compare ScalarValues") }; scalars.sort_by(cmp); assert_eq!( @@ -321,3 +324,120 @@ async fn test_accumulator_row_accumulator() -> Result<()> { Ok(()) } + +/// Test that COUNT(DISTINCT) correctly handles dictionary arrays with all null values. +/// Verifies behavior across both single and multiple partitions. +#[tokio::test] +async fn count_distinct_dictionary_all_null_values() -> Result<()> { + let n: usize = 5; + let num = Arc::new(Int32Array::from_iter(0..n as i32)) as ArrayRef; + + // Create dictionary where all indices point to a null value (index 0) + let dict_values = StringArray::from(vec![None, Some("abc")]); + let dict_indices = Int32Array::from(vec![0; n]); + let dict = DictionaryArray::new(dict_indices, Arc::new(dict_values)); + + let schema = Arc::new(Schema::new(vec![ + Field::new("num1", DataType::Int32, false), + Field::new("num2", DataType::Int32, false), + Field::new( + "dict", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![num.clone(), num.clone(), Arc::new(dict)], + )?; + + // Test with single partition + let ctx = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(1)); + let provider = MemTable::try_new(schema.clone(), vec![vec![batch.clone()]])?; + ctx.register_table("t", Arc::new(provider))?; + + let df = ctx + .sql("SELECT count(distinct dict) as cnt, count(num2) FROM t GROUP BY num1") + .await?; + let results = df.collect().await?; + + assert_snapshot!( + batches_to_string(&results), + @r###" + +-----+---------------+ + | cnt | count(t.num2) | + +-----+---------------+ + | 0 | 1 | + | 0 | 1 | + | 0 | 1 | + | 0 | 1 | + | 0 | 1 | + +-----+---------------+ + "### + ); + + // Test with multiple partitions + let ctx_multi = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(2)); + let provider_multi = MemTable::try_new(schema, vec![vec![batch]])?; + ctx_multi.register_table("t", Arc::new(provider_multi))?; + + let df_multi = ctx_multi + .sql("SELECT count(distinct dict) as cnt, count(num2) FROM t GROUP BY num1") + .await?; + let results_multi = df_multi.collect().await?; + + // Results should be identical across partition configurations + assert_eq!( + batches_to_string(&results), + batches_to_string(&results_multi) + ); + + Ok(()) +} + +/// Test COUNT(DISTINCT) with mixed null and non-null dictionary values +#[tokio::test] +async fn count_distinct_dictionary_mixed_values() -> Result<()> { + let n: usize = 6; + let num = Arc::new(Int32Array::from_iter(0..n as i32)) as ArrayRef; + + // Dictionary values array with nulls and non-nulls + let dict_values = StringArray::from(vec![None, Some("abc"), Some("def"), None]); + // Create indices that point to both null and non-null values + let dict_indices = Int32Array::from(vec![0, 1, 2, 0, 1, 3]); + let dict = DictionaryArray::new(dict_indices, Arc::new(dict_values)); + + let schema = Arc::new(Schema::new(vec![ + Field::new("num1", DataType::Int32, false), + Field::new( + "dict", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + ])); + + let batch = RecordBatch::try_new(schema.clone(), vec![num, Arc::new(dict)])?; + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::new(provider))?; + + // COUNT(DISTINCT) should only count non-null values "abc" and "def" + let df = ctx.sql("SELECT count(distinct dict) FROM t").await?; + let results = df.collect().await?; + + assert_snapshot!( + batches_to_string(&results), + @r###" + +------------------------+ + | count(DISTINCT t.dict) | + +------------------------+ + | 2 | + +------------------------+ + "### + ); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/aggregates/dict_nulls.rs b/datafusion/core/tests/sql/aggregates/dict_nulls.rs new file mode 100644 index 0000000000000..da4b2c8d25c9d --- /dev/null +++ b/datafusion/core/tests/sql/aggregates/dict_nulls.rs @@ -0,0 +1,454 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; +use datafusion::common::test_util::batches_to_string; +use insta::assert_snapshot; + +/// Comprehensive test for aggregate functions with null values and dictionary columns +/// Tests COUNT, SUM, MIN, and MEDIAN null handling in single comprehensive test +#[tokio::test] +async fn test_aggregates_null_handling_comprehensive() -> Result<()> { + let test_data_basic = TestData::new(); + let test_data_extended = TestData::new_extended(); + let test_data_min_max = TestData::new_for_min_max(); + let test_data_median = TestData::new_for_median(); + + // Test COUNT null exclusion with basic data + let sql_count = "SELECT dict_null_keys, COUNT(value) as cnt FROM t GROUP BY dict_null_keys ORDER BY dict_null_keys NULLS FIRST"; + let results_count = run_snapshot_test(&test_data_basic, sql_count).await?; + + assert_snapshot!( + batches_to_string(&results_count), + @r###" + +----------------+-----+ + | dict_null_keys | cnt | + +----------------+-----+ + | | 0 | + | group_a | 2 | + | group_b | 1 | + +----------------+-----+ + "### + ); + + // Test SUM null handling with extended data + let sql_sum = "SELECT dict_null_vals, SUM(value) as total FROM t GROUP BY dict_null_vals ORDER BY dict_null_vals NULLS FIRST"; + let results_sum = run_snapshot_test(&test_data_extended, sql_sum).await?; + + assert_snapshot!( + batches_to_string(&results_sum), + @r" + +----------------+-------+ + | dict_null_vals | total | + +----------------+-------+ + | | 4 | + | group_x | 4 | + | group_y | 2 | + | group_z | 5 | + +----------------+-------+ + " + ); + + // Test MIN null handling with min/max data + let sql_min = "SELECT dict_null_keys, MIN(value) as minimum FROM t GROUP BY dict_null_keys ORDER BY dict_null_keys NULLS FIRST"; + let results_min = run_snapshot_test(&test_data_min_max, sql_min).await?; + + assert_snapshot!( + batches_to_string(&results_min), + @r###" + +----------------+---------+ + | dict_null_keys | minimum | + +----------------+---------+ + | | 2 | + | group_a | 3 | + | group_b | 1 | + | group_c | 7 | + +----------------+---------+ + "### + ); + + // Test MEDIAN null handling with median data + let sql_median = "SELECT dict_null_vals, MEDIAN(value) as median_value FROM t GROUP BY dict_null_vals ORDER BY dict_null_vals NULLS FIRST"; + let results_median = run_snapshot_test(&test_data_median, sql_median).await?; + + assert_snapshot!( + batches_to_string(&results_median), + @r" + +----------------+--------------+ + | dict_null_vals | median_value | + +----------------+--------------+ + | | 3 | + | group_x | 1 | + | group_y | 5 | + | group_z | 7 | + +----------------+--------------+ + "); + + Ok(()) +} + +/// Test FIRST_VAL and LAST_VAL with null values and GROUP BY dict with null keys and null values - may return null if first/last value is null (single and multiple partitions) +#[tokio::test] +async fn test_first_last_val_null_handling() -> Result<()> { + let test_data = TestData::new_for_first_last(); + + // Test FIRST_VALUE and LAST_VALUE with window functions over groups + let sql = "SELECT dict_null_keys, value, FIRST_VALUE(value) OVER (PARTITION BY dict_null_keys ORDER BY value NULLS FIRST) as first_val, LAST_VALUE(value) OVER (PARTITION BY dict_null_keys ORDER BY value NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as last_val FROM t ORDER BY dict_null_keys NULLS FIRST, value NULLS FIRST"; + + let results_single = run_snapshot_test(&test_data, sql).await?; + + assert_snapshot!(batches_to_string(&results_single), @r" + +----------------+-------+-----------+----------+ + | dict_null_keys | value | first_val | last_val | + +----------------+-------+-----------+----------+ + | | 1 | 1 | 3 | + | | 3 | 1 | 3 | + | group_a | | | | + | group_a | | | | + | group_b | 2 | 2 | 2 | + +----------------+-------+-----------+----------+ + "); + + Ok(()) +} + +/// Test FIRST_VALUE and LAST_VALUE with ORDER BY - comprehensive null handling +#[tokio::test] +async fn test_first_last_value_order_by_null_handling() -> Result<()> { + let ctx = SessionContext::new(); + + // Create test data with nulls mixed in + let dict_keys = create_test_dict( + &[Some("group_a"), Some("group_b"), Some("group_c")], + &[Some(0), Some(1), Some(2), Some(0), Some(1)], + ); + + let values = Int32Array::from(vec![None, Some(10), Some(20), Some(5), None]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_group", string_dict_type(), true), + Field::new("value", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(dict_keys), Arc::new(values)], + )?; + + let table = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("test_data", Arc::new(table))?; + + // Test all combinations of FIRST_VALUE and LAST_VALUE with null handling + let sql = "SELECT + dict_group, + value, + FIRST_VALUE(value IGNORE NULLS) OVER (ORDER BY value NULLS LAST) as first_ignore_nulls, + FIRST_VALUE(value RESPECT NULLS) OVER (ORDER BY value NULLS FIRST) as first_respect_nulls, + LAST_VALUE(value IGNORE NULLS) OVER (ORDER BY value NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as last_ignore_nulls, + LAST_VALUE(value RESPECT NULLS) OVER (ORDER BY value NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as last_respect_nulls + FROM test_data + ORDER BY value NULLS LAST"; + + let df = ctx.sql(sql).await?; + let results = df.collect().await?; + + assert_snapshot!( + batches_to_string(&results), + @r###" + +------------+-------+--------------------+---------------------+-------------------+--------------------+ + | dict_group | value | first_ignore_nulls | first_respect_nulls | last_ignore_nulls | last_respect_nulls | + +------------+-------+--------------------+---------------------+-------------------+--------------------+ + | group_a | 5 | 5 | | 20 | | + | group_b | 10 | 5 | | 20 | | + | group_c | 20 | 5 | | 20 | | + | group_a | | 5 | | 20 | | + | group_b | | 5 | | 20 | | + +------------+-------+--------------------+---------------------+-------------------+--------------------+ + "### + ); + + Ok(()) +} + +/// Test GROUP BY with dictionary columns containing null keys and values for FIRST_VALUE/LAST_VALUE +#[tokio::test] +async fn test_first_last_value_group_by_dict_nulls() -> Result<()> { + let ctx = SessionContext::new(); + + // Create dictionary with null keys + let dict_null_keys = create_test_dict( + &[Some("group_a"), Some("group_b")], + &[ + Some(0), // group_a + None, // null key + Some(1), // group_b + None, // null key + Some(0), // group_a + ], + ); + + // Create dictionary with null values + let dict_null_vals = create_test_dict( + &[Some("val_x"), None, Some("val_y")], + &[ + Some(0), // val_x + Some(1), // null value + Some(2), // val_y + Some(1), // null value + Some(0), // val_x + ], + ); + + // Create test values + let values = Int32Array::from(vec![Some(10), Some(20), Some(30), Some(40), Some(50)]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_null_keys", string_dict_type(), true), + Field::new("dict_null_vals", string_dict_type(), true), + Field::new("value", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(dict_null_keys), + Arc::new(dict_null_vals), + Arc::new(values), + ], + )?; + + let table = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("test_data", Arc::new(table))?; + + // Test GROUP BY with null keys + let sql = "SELECT + dict_null_keys, + FIRST_VALUE(value) as first_val, + LAST_VALUE(value) as last_val, + COUNT(*) as cnt + FROM test_data + GROUP BY dict_null_keys + ORDER BY dict_null_keys NULLS FIRST"; + + let df = ctx.sql(sql).await?; + let results = df.collect().await?; + + assert_snapshot!( + batches_to_string(&results), + @r###" + +----------------+-----------+----------+-----+ + | dict_null_keys | first_val | last_val | cnt | + +----------------+-----------+----------+-----+ + | | 20 | 40 | 2 | + | group_a | 10 | 50 | 2 | + | group_b | 30 | 30 | 1 | + +----------------+-----------+----------+-----+ + "### + ); + + // Test GROUP BY with null values in dictionary + let sql2 = "SELECT + dict_null_vals, + FIRST_VALUE(value) as first_val, + LAST_VALUE(value) as last_val, + COUNT(*) as cnt + FROM test_data + GROUP BY dict_null_vals + ORDER BY dict_null_vals NULLS FIRST"; + + let df2 = ctx.sql(sql2).await?; + let results2 = df2.collect().await?; + + assert_snapshot!( + batches_to_string(&results2), + @r###" + +----------------+-----------+----------+-----+ + | dict_null_vals | first_val | last_val | cnt | + +----------------+-----------+----------+-----+ + | | 20 | 40 | 2 | + | val_x | 10 | 50 | 2 | + | val_y | 30 | 30 | 1 | + +----------------+-----------+----------+-----+ + "### + ); + + Ok(()) +} + +/// Test MAX with dictionary columns containing null keys and values as specified in the SQL query +#[tokio::test] +async fn test_max_with_fuzz_table_dict_nulls() -> Result<()> { + let (ctx_single, ctx_multi) = setup_fuzz_test_contexts().await?; + + // Execute the SQL query with MAX aggregations + let sql = "SELECT + u8_low, + dictionary_utf8_low, + utf8_low, + max(utf8_low) as col1, + max(utf8) as col2 + FROM + fuzz_table + GROUP BY + u8_low, + dictionary_utf8_low, + utf8_low + ORDER BY u8_low, dictionary_utf8_low NULLS FIRST, utf8_low"; + + let results = test_query_consistency(&ctx_single, &ctx_multi, sql).await?; + + assert_snapshot!( + batches_to_string(&results), + @r" + +--------+---------------------+----------+-------+---------+ + | u8_low | dictionary_utf8_low | utf8_low | col1 | col2 | + +--------+---------------------+----------+-------+---------+ + | 1 | | str_b | str_b | value_2 | + | 1 | dict_a | str_a | str_a | value_5 | + | 2 | | str_c | str_c | value_7 | + | 2 | | str_d | str_d | value_4 | + | 2 | dict_b | str_c | str_c | value_3 | + | 3 | | str_e | str_e | | + | 3 | dict_c | str_f | str_f | value_6 | + +--------+---------------------+----------+-------+---------+ + "); + + Ok(()) +} + +/// Test MIN with fuzz table containing dictionary columns with null keys and values and timestamp data (single and multiple partitions) +#[tokio::test] +async fn test_min_timestamp_with_fuzz_table_dict_nulls() -> Result<()> { + let (ctx_single, ctx_multi) = setup_fuzz_timestamp_test_contexts().await?; + + // Execute the SQL query with MIN aggregation on timestamp + let sql = "SELECT + utf8_low, + u8_low, + dictionary_utf8_low, + min(timestamp_us) as col1 + FROM + fuzz_table + GROUP BY + utf8_low, + u8_low, + dictionary_utf8_low + ORDER BY utf8_low, u8_low, dictionary_utf8_low NULLS FIRST"; + + let results = test_query_consistency(&ctx_single, &ctx_multi, sql).await?; + + assert_snapshot!( + batches_to_string(&results), + @r" + +----------+--------+---------------------+-------------------------+ + | utf8_low | u8_low | dictionary_utf8_low | col1 | + +----------+--------+---------------------+-------------------------+ + | alpha | 10 | dict_x | 1970-01-01T00:00:01 | + | beta | 20 | | 1970-01-01T00:00:02 | + | delta | 20 | | 1970-01-01T00:00:03.500 | + | epsilon | 40 | | 1970-01-01T00:00:04 | + | gamma | 30 | dict_y | 1970-01-01T00:00:02.800 | + | zeta | 30 | dict_z | 1970-01-01T00:00:02.500 | + +----------+--------+---------------------+-------------------------+ + " + ); + + Ok(()) +} + +/// Test COUNT and COUNT DISTINCT with fuzz table containing dictionary columns with null keys and values (single and multiple partitions) +#[tokio::test] +async fn test_count_distinct_with_fuzz_table_dict_nulls() -> Result<()> { + let (ctx_single, ctx_multi) = setup_fuzz_count_test_contexts().await?; + + // Execute the SQL query with COUNT and COUNT DISTINCT aggregations + let sql = "SELECT + u8_low, + utf8_low, + dictionary_utf8_low, + count(duration_nanosecond) as col1, + count(DISTINCT large_binary) as col2 + FROM + fuzz_table + GROUP BY + u8_low, + utf8_low, + dictionary_utf8_low + ORDER BY u8_low, utf8_low, dictionary_utf8_low NULLS FIRST"; + + let results = test_query_consistency(&ctx_single, &ctx_multi, sql).await?; + + assert_snapshot!( + batches_to_string(&results), + @r###" + +--------+----------+---------------------+------+------+ + | u8_low | utf8_low | dictionary_utf8_low | col1 | col2 | + +--------+----------+---------------------+------+------+ + | 5 | text_a | group_alpha | 3 | 1 | + | 10 | text_b | | 1 | 1 | + | 10 | text_d | | 2 | 0 | + | 15 | text_c | group_beta | 1 | 1 | + | 20 | text_e | | 0 | 1 | + | 25 | text_f | group_gamma | 1 | 1 | + +--------+----------+---------------------+------+------+ + "### + ); + + Ok(()) +} + +/// Test MEDIAN and MEDIAN DISTINCT with fuzz table containing various numeric types and dictionary columns with null keys and values (single and multiple partitions) +#[tokio::test] +async fn test_median_distinct_with_fuzz_table_dict_nulls() -> Result<()> { + let (ctx_single, ctx_multi) = setup_fuzz_median_test_contexts().await?; + + // Execute the SQL query with MEDIAN and MEDIAN DISTINCT aggregations + let sql = "SELECT + u8_low, + dictionary_utf8_low, + median(DISTINCT u64) as col1, + median(DISTINCT u16) as col2, + median(u64) as col3, + median(decimal128) as col4, + median(DISTINCT u32) as col5 + FROM + fuzz_table + GROUP BY + u8_low, + dictionary_utf8_low + ORDER BY u8_low, dictionary_utf8_low NULLS FIRST"; + + let results = test_query_consistency(&ctx_single, &ctx_multi, sql).await?; + + assert_snapshot!( + batches_to_string(&results), + @r" + +--------+---------------------+------+------+------+--------+--------+ + | u8_low | dictionary_utf8_low | col1 | col2 | col3 | col4 | col5 | + +--------+---------------------+------+------+------+--------+--------+ + | 50 | | | 30 | | 987.65 | 400000 | + | 50 | group_three | 5000 | 50 | 5000 | 555.55 | 500000 | + | 75 | | 4000 | | 4000 | | 450000 | + | 100 | group_one | 1100 | 11 | 1000 | 123.45 | 110000 | + | 100 | group_two | 1500 | 15 | 1500 | 111.11 | 150000 | + | 200 | | 2500 | 22 | 2500 | 506.11 | 250000 | + +--------+---------------------+------+------+------+--------+--------+ + " + ); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/aggregates/mod.rs b/datafusion/core/tests/sql/aggregates/mod.rs new file mode 100644 index 0000000000000..321c158628e43 --- /dev/null +++ b/datafusion/core/tests/sql/aggregates/mod.rs @@ -0,0 +1,1026 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Aggregate function tests + +use super::*; +use arrow::{ + array::{ + types::UInt32Type, Decimal128Array, DictionaryArray, DurationNanosecondArray, + Int32Array, LargeBinaryArray, StringArray, TimestampMicrosecondArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, + }, + datatypes::{DataType, Field, Schema, TimeUnit}, + record_batch::RecordBatch, +}; +use datafusion::{ + common::{test_util::batches_to_string, Result}, + execution::{config::SessionConfig, context::SessionContext}, +}; +use datafusion_catalog::MemTable; +use std::{cmp::min, sync::Arc}; +/// Helper function to create the commonly used UInt32 indexed UTF-8 dictionary data type +pub fn string_dict_type() -> DataType { + DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)) +} + +/// Helper functions for aggregate tests with dictionary columns and nulls +/// Creates a dictionary array with null values in the dictionary +pub fn create_test_dict( + values: &[Option<&str>], + indices: &[Option], +) -> DictionaryArray { + let dict_values = StringArray::from(values.to_vec()); + let dict_indices = UInt32Array::from(indices.to_vec()); + DictionaryArray::new(dict_indices, Arc::new(dict_values)) +} + +/// Creates test data with both dictionary columns and value column +pub struct TestData { + pub dict_null_keys: DictionaryArray, + pub dict_null_vals: DictionaryArray, + pub values: Int32Array, + pub schema: Arc, +} + +impl TestData { + pub fn new() -> Self { + // Create dictionary with null keys + let dict_null_keys = create_test_dict( + &[Some("group_a"), Some("group_b")], + &[ + Some(0), // group_a + None, // null key + Some(1), // group_b + None, // null key + Some(0), // group_a + ], + ); + + // Create dictionary with null values + let dict_null_vals = create_test_dict( + &[Some("group_x"), None, Some("group_y")], + &[ + Some(0), // group_x + Some(1), // null value + Some(2), // group_y + Some(1), // null value + Some(0), // group_x + ], + ); + + // Create test data with nulls + let values = Int32Array::from(vec![Some(1), None, Some(2), None, Some(3)]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_null_keys", string_dict_type(), true), + Field::new("dict_null_vals", string_dict_type(), true), + Field::new("value", DataType::Int32, true), + ])); + + Self { + dict_null_keys, + dict_null_vals, + values, + schema, + } + } + + /// Creates extended test data for more comprehensive testing + pub fn new_extended() -> Self { + // Create dictionary with null values in the dictionary array + let dict_null_vals = create_test_dict( + &[Some("group_a"), None, Some("group_b")], + &[ + Some(0), // group_a + Some(1), // null value + Some(2), // group_b + Some(1), // null value + Some(0), // group_a + Some(1), // null value + Some(2), // group_b + Some(1), // null value + ], + ); + + // Create dictionary with null keys + let dict_null_keys = create_test_dict( + &[Some("group_x"), Some("group_y"), Some("group_z")], + &[ + Some(0), // group_x + None, // null key + Some(1), // group_y + None, // null key + Some(0), // group_x + None, // null key + Some(2), // group_z + None, // null key + ], + ); + + // Create test data with nulls + let values = Int32Array::from(vec![ + Some(1), + None, + Some(2), + None, + Some(3), + Some(4), + Some(5), + None, + ]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_null_vals", string_dict_type(), true), + Field::new("dict_null_keys", string_dict_type(), true), + Field::new("value", DataType::Int32, true), + ])); + + Self { + dict_null_keys, + dict_null_vals, + values, + schema, + } + } + + /// Creates test data for MIN/MAX testing with varied values + pub fn new_for_min_max() -> Self { + let dict_null_keys = create_test_dict( + &[Some("group_a"), Some("group_b"), Some("group_c")], + &[ + Some(0), + Some(1), + Some(0), + Some(2), + None, + None, // group_a, group_b, group_a, group_c, null, null + ], + ); + + let dict_null_vals = create_test_dict( + &[Some("group_x"), None, Some("group_y")], + &[ + Some(0), + Some(1), + Some(0), + Some(2), + Some(1), + Some(1), // group_x, null, group_x, group_y, null, null + ], + ); + + let values = + Int32Array::from(vec![Some(5), Some(1), Some(3), Some(7), Some(2), None]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_null_keys", string_dict_type(), true), + Field::new("dict_null_vals", string_dict_type(), true), + Field::new("value", DataType::Int32, true), + ])); + + Self { + dict_null_keys, + dict_null_vals, + values, + schema, + } + } + + /// Creates test data for MEDIAN testing with varied values + pub fn new_for_median() -> Self { + let dict_null_vals = create_test_dict( + &[Some("group_a"), None, Some("group_b")], + &[Some(0), Some(1), Some(2), Some(1), Some(0)], + ); + + let dict_null_keys = create_test_dict( + &[Some("group_x"), Some("group_y"), Some("group_z")], + &[Some(0), None, Some(1), None, Some(2)], + ); + + let values = Int32Array::from(vec![Some(1), None, Some(5), Some(3), Some(7)]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_null_vals", string_dict_type(), true), + Field::new("dict_null_keys", string_dict_type(), true), + Field::new("value", DataType::Int32, true), + ])); + + Self { + dict_null_keys, + dict_null_vals, + values, + schema, + } + } + + /// Creates test data for FIRST_VALUE/LAST_VALUE testing + pub fn new_for_first_last() -> Self { + let dict_null_keys = create_test_dict( + &[Some("group_a"), Some("group_b")], + &[Some(0), None, Some(1), None, Some(0)], + ); + + let dict_null_vals = create_test_dict( + &[Some("group_x"), None, Some("group_y")], + &[Some(0), Some(1), Some(2), Some(1), Some(0)], + ); + + let values = Int32Array::from(vec![None, Some(1), Some(2), Some(3), None]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict_null_keys", string_dict_type(), true), + Field::new("dict_null_vals", string_dict_type(), true), + Field::new("value", DataType::Int32, true), + ])); + + Self { + dict_null_keys, + dict_null_vals, + values, + schema, + } + } +} + +/// Sets up test contexts for TestData with both single and multiple partitions +pub async fn setup_test_contexts( + test_data: &TestData, +) -> Result<(SessionContext, SessionContext)> { + // Single partition context + let ctx_single = create_context_with_partitions(test_data, 1).await?; + + // Multiple partition context + let ctx_multi = create_context_with_partitions(test_data, 3).await?; + + Ok((ctx_single, ctx_multi)) +} + +/// Creates a session context with the specified number of partitions and registers test data +pub async fn create_context_with_partitions( + test_data: &TestData, + num_partitions: usize, +) -> Result { + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(num_partitions), + ); + + let batches = split_test_data_into_batches(test_data, num_partitions)?; + let provider = MemTable::try_new(test_data.schema.clone(), batches)?; + ctx.register_table("t", Arc::new(provider))?; + + Ok(ctx) +} + +/// Splits test data into multiple batches for partitioning +pub fn split_test_data_into_batches( + test_data: &TestData, + num_partitions: usize, +) -> Result>> { + debug_assert!(num_partitions > 0, "num_partitions must be greater than 0"); + let total_len = test_data.values.len(); + let chunk_size = total_len.div_ceil(num_partitions); // Ensure we cover all data + + let mut batches = Vec::new(); + let mut start = 0; + + while start < total_len { + let end = min(start + chunk_size, total_len); + let len = end - start; + + if len > 0 { + let batch = RecordBatch::try_new( + test_data.schema.clone(), + vec![ + Arc::new(test_data.dict_null_keys.slice(start, len)), + Arc::new(test_data.dict_null_vals.slice(start, len)), + Arc::new(test_data.values.slice(start, len)), + ], + )?; + batches.push(vec![batch]); + } + start = end; + } + + Ok(batches) +} + +/// Executes a query on both single and multi-partition contexts and verifies consistency +pub async fn test_query_consistency( + ctx_single: &SessionContext, + ctx_multi: &SessionContext, + sql: &str, +) -> Result> { + let df_single = ctx_single.sql(sql).await?; + let results_single = df_single.collect().await?; + + let df_multi = ctx_multi.sql(sql).await?; + let results_multi = df_multi.collect().await?; + + // Verify results are consistent between single and multiple partitions + assert_eq!( + batches_to_string(&results_single), + batches_to_string(&results_multi), + "Results should be identical between single and multiple partitions" + ); + + Ok(results_single) +} + +/// Helper function to run snapshot tests with consistent setup, execution, and assertion +/// This reduces the repetitive pattern of "setup data → SQL → assert_snapshot!" +pub async fn run_snapshot_test( + test_data: &TestData, + sql: &str, +) -> Result> { + let (ctx_single, ctx_multi) = setup_test_contexts(test_data).await?; + let results = test_query_consistency(&ctx_single, &ctx_multi, sql).await?; + Ok(results) +} + +/// Test data structure for fuzz table with dictionary columns containing nulls +pub struct FuzzTestData { + pub schema: Arc, + pub u8_low: UInt8Array, + pub dictionary_utf8_low: DictionaryArray, + pub utf8_low: StringArray, + pub utf8: StringArray, +} + +impl FuzzTestData { + pub fn new() -> Self { + // Create dictionary columns with null keys and values + let dictionary_utf8_low = create_test_dict( + &[Some("dict_a"), None, Some("dict_b"), Some("dict_c")], + &[ + Some(0), // dict_a + Some(1), // null value + Some(2), // dict_b + None, // null key + Some(0), // dict_a + Some(1), // null value + Some(3), // dict_c + None, // null key + ], + ); + + let u8_low = UInt8Array::from(vec![ + Some(1), + Some(1), + Some(2), + Some(2), + Some(1), + Some(3), + Some(3), + Some(2), + ]); + + let utf8_low = StringArray::from(vec![ + Some("str_a"), + Some("str_b"), + Some("str_c"), + Some("str_d"), + Some("str_a"), + Some("str_e"), + Some("str_f"), + Some("str_c"), + ]); + + let utf8 = StringArray::from(vec![ + Some("value_1"), + Some("value_2"), + Some("value_3"), + Some("value_4"), + Some("value_5"), + None, + Some("value_6"), + Some("value_7"), + ]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("u8_low", DataType::UInt8, true), + Field::new("dictionary_utf8_low", string_dict_type(), true), + Field::new("utf8_low", DataType::Utf8, true), + Field::new("utf8", DataType::Utf8, true), + ])); + + Self { + schema, + u8_low, + dictionary_utf8_low, + utf8_low, + utf8, + } + } +} + +/// Sets up test contexts for fuzz table with both single and multiple partitions +pub async fn setup_fuzz_test_contexts() -> Result<(SessionContext, SessionContext)> { + let test_data = FuzzTestData::new(); + + // Single partition context + let ctx_single = create_fuzz_context_with_partitions(&test_data, 1).await?; + + // Multiple partition context + let ctx_multi = create_fuzz_context_with_partitions(&test_data, 3).await?; + + Ok((ctx_single, ctx_multi)) +} + +/// Creates a session context with fuzz table partitioned into specified number of partitions +pub async fn create_fuzz_context_with_partitions( + test_data: &FuzzTestData, + num_partitions: usize, +) -> Result { + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(num_partitions), + ); + + let batches = split_fuzz_data_into_batches(test_data, num_partitions)?; + let provider = MemTable::try_new(test_data.schema.clone(), batches)?; + ctx.register_table("fuzz_table", Arc::new(provider))?; + + Ok(ctx) +} + +/// Splits fuzz test data into multiple batches for partitioning +pub fn split_fuzz_data_into_batches( + test_data: &FuzzTestData, + num_partitions: usize, +) -> Result>> { + debug_assert!(num_partitions > 0, "num_partitions must be greater than 0"); + let total_len = test_data.u8_low.len(); + let chunk_size = total_len.div_ceil(num_partitions); + + let mut batches = Vec::new(); + let mut start = 0; + + while start < total_len { + let end = min(start + chunk_size, total_len); + let len = end - start; + + if len > 0 { + let batch = RecordBatch::try_new( + test_data.schema.clone(), + vec![ + Arc::new(test_data.u8_low.slice(start, len)), + Arc::new(test_data.dictionary_utf8_low.slice(start, len)), + Arc::new(test_data.utf8_low.slice(start, len)), + Arc::new(test_data.utf8.slice(start, len)), + ], + )?; + batches.push(vec![batch]); + } + start = end; + } + + Ok(batches) +} + +/// Test data structure for fuzz table with duration, large_binary and dictionary columns containing nulls +pub struct FuzzCountTestData { + pub schema: Arc, + pub u8_low: UInt8Array, + pub utf8_low: StringArray, + pub dictionary_utf8_low: DictionaryArray, + pub duration_nanosecond: DurationNanosecondArray, + pub large_binary: LargeBinaryArray, +} + +impl FuzzCountTestData { + pub fn new() -> Self { + // Create dictionary columns with null keys and values + let dictionary_utf8_low = create_test_dict( + &[ + Some("group_alpha"), + None, + Some("group_beta"), + Some("group_gamma"), + ], + &[ + Some(0), // group_alpha + Some(1), // null value + Some(2), // group_beta + None, // null key + Some(0), // group_alpha + Some(1), // null value + Some(3), // group_gamma + None, // null key + Some(2), // group_beta + Some(0), // group_alpha + ], + ); + + let u8_low = UInt8Array::from(vec![ + Some(5), + Some(10), + Some(15), + Some(10), + Some(5), + Some(20), + Some(25), + Some(10), + Some(15), + Some(5), + ]); + + let utf8_low = StringArray::from(vec![ + Some("text_a"), + Some("text_b"), + Some("text_c"), + Some("text_d"), + Some("text_a"), + Some("text_e"), + Some("text_f"), + Some("text_d"), + Some("text_c"), + Some("text_a"), + ]); + + // Create duration data with some nulls (nanoseconds) + let duration_nanosecond = DurationNanosecondArray::from(vec![ + Some(1000000000), // 1 second + Some(2000000000), // 2 seconds + None, // null duration + Some(3000000000), // 3 seconds + Some(1500000000), // 1.5 seconds + None, // null duration + Some(4000000000), // 4 seconds + Some(2500000000), // 2.5 seconds + Some(3500000000), // 3.5 seconds + Some(1200000000), // 1.2 seconds + ]); + + // Create large binary data with some nulls and duplicates + let large_binary = LargeBinaryArray::from(vec![ + Some(b"binary_data_1".as_slice()), + Some(b"binary_data_2".as_slice()), + Some(b"binary_data_3".as_slice()), + None, // null binary + Some(b"binary_data_1".as_slice()), // duplicate + Some(b"binary_data_4".as_slice()), + Some(b"binary_data_5".as_slice()), + None, // null binary + Some(b"binary_data_3".as_slice()), // duplicate + Some(b"binary_data_1".as_slice()), // duplicate + ]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("u8_low", DataType::UInt8, true), + Field::new("utf8_low", DataType::Utf8, true), + Field::new("dictionary_utf8_low", string_dict_type(), true), + Field::new( + "duration_nanosecond", + DataType::Duration(TimeUnit::Nanosecond), + true, + ), + Field::new("large_binary", DataType::LargeBinary, true), + ])); + + Self { + schema, + u8_low, + utf8_low, + dictionary_utf8_low, + duration_nanosecond, + large_binary, + } + } +} + +/// Sets up test contexts for fuzz table with duration/binary columns and both single and multiple partitions +pub async fn setup_fuzz_count_test_contexts() -> Result<(SessionContext, SessionContext)> +{ + let test_data = FuzzCountTestData::new(); + + // Single partition context + let ctx_single = create_fuzz_count_context_with_partitions(&test_data, 1).await?; + + // Multiple partition context + let ctx_multi = create_fuzz_count_context_with_partitions(&test_data, 3).await?; + + Ok((ctx_single, ctx_multi)) +} + +/// Creates a session context with fuzz count table partitioned into specified number of partitions +pub async fn create_fuzz_count_context_with_partitions( + test_data: &FuzzCountTestData, + num_partitions: usize, +) -> Result { + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(num_partitions), + ); + + let batches = split_fuzz_count_data_into_batches(test_data, num_partitions)?; + let provider = MemTable::try_new(test_data.schema.clone(), batches)?; + ctx.register_table("fuzz_table", Arc::new(provider))?; + + Ok(ctx) +} + +/// Splits fuzz count test data into multiple batches for partitioning +pub fn split_fuzz_count_data_into_batches( + test_data: &FuzzCountTestData, + num_partitions: usize, +) -> Result>> { + debug_assert!(num_partitions > 0, "num_partitions must be greater than 0"); + let total_len = test_data.u8_low.len(); + let chunk_size = total_len.div_ceil(num_partitions); + + let mut batches = Vec::new(); + let mut start = 0; + + while start < total_len { + let end = min(start + chunk_size, total_len); + let len = end - start; + + if len > 0 { + let batch = RecordBatch::try_new( + test_data.schema.clone(), + vec![ + Arc::new(test_data.u8_low.slice(start, len)), + Arc::new(test_data.utf8_low.slice(start, len)), + Arc::new(test_data.dictionary_utf8_low.slice(start, len)), + Arc::new(test_data.duration_nanosecond.slice(start, len)), + Arc::new(test_data.large_binary.slice(start, len)), + ], + )?; + batches.push(vec![batch]); + } + start = end; + } + + Ok(batches) +} + +/// Test data structure for fuzz table with numeric types for median testing and dictionary columns containing nulls +pub struct FuzzMedianTestData { + pub schema: Arc, + pub u8_low: UInt8Array, + pub dictionary_utf8_low: DictionaryArray, + pub u64: UInt64Array, + pub u16: UInt16Array, + pub u32: UInt32Array, + pub decimal128: Decimal128Array, +} + +impl FuzzMedianTestData { + pub fn new() -> Self { + // Create dictionary columns with null keys and values + let dictionary_utf8_low = create_test_dict( + &[ + Some("group_one"), + None, + Some("group_two"), + Some("group_three"), + ], + &[ + Some(0), // group_one + Some(1), // null value + Some(2), // group_two + None, // null key + Some(0), // group_one + Some(1), // null value + Some(3), // group_three + None, // null key + Some(2), // group_two + Some(0), // group_one + Some(1), // null value + Some(3), // group_three + ], + ); + + let u8_low = UInt8Array::from(vec![ + Some(100), + Some(200), + Some(100), + Some(200), + Some(100), + Some(50), + Some(50), + Some(200), + Some(100), + Some(100), + Some(75), + Some(50), + ]); + + // Create u64 data with some nulls and duplicates for DISTINCT testing + let u64 = UInt64Array::from(vec![ + Some(1000), + Some(2000), + Some(1500), + Some(3000), + Some(1000), // duplicate + None, // null + Some(5000), + Some(2500), + Some(1500), // duplicate + Some(1200), + Some(4000), + Some(5000), // duplicate + ]); + + // Create u16 data with some nulls and duplicates + let u16 = UInt16Array::from(vec![ + Some(10), + Some(20), + Some(15), + None, // null + Some(10), // duplicate + Some(30), + Some(50), + Some(25), + Some(15), // duplicate + Some(12), + None, // null + Some(50), // duplicate + ]); + + // Create u32 data with some nulls and duplicates + let u32 = UInt32Array::from(vec![ + Some(100000), + Some(200000), + Some(150000), + Some(300000), + Some(100000), // duplicate + Some(400000), + Some(500000), + None, // null + Some(150000), // duplicate + Some(120000), + Some(450000), + None, // null + ]); + + // Create decimal128 data with precision 10, scale 2 + let decimal128 = Decimal128Array::from(vec![ + Some(12345), // 123.45 + Some(67890), // 678.90 + Some(11111), // 111.11 + None, // null + Some(12345), // 123.45 duplicate + Some(98765), // 987.65 + Some(55555), // 555.55 + Some(33333), // 333.33 + Some(11111), // 111.11 duplicate + Some(12500), // 125.00 + None, // null + Some(55555), // 555.55 duplicate + ]) + .with_precision_and_scale(10, 2) + .unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("u8_low", DataType::UInt8, true), + Field::new("dictionary_utf8_low", string_dict_type(), true), + Field::new("u64", DataType::UInt64, true), + Field::new("u16", DataType::UInt16, true), + Field::new("u32", DataType::UInt32, true), + Field::new("decimal128", DataType::Decimal128(10, 2), true), + ])); + + Self { + schema, + u8_low, + dictionary_utf8_low, + u64, + u16, + u32, + decimal128, + } + } +} + +/// Sets up test contexts for fuzz table with numeric types for median testing and both single and multiple partitions +pub async fn setup_fuzz_median_test_contexts() -> Result<(SessionContext, SessionContext)> +{ + let test_data = FuzzMedianTestData::new(); + + // Single partition context + let ctx_single = create_fuzz_median_context_with_partitions(&test_data, 1).await?; + + // Multiple partition context + let ctx_multi = create_fuzz_median_context_with_partitions(&test_data, 3).await?; + + Ok((ctx_single, ctx_multi)) +} + +/// Creates a session context with fuzz median table partitioned into specified number of partitions +pub async fn create_fuzz_median_context_with_partitions( + test_data: &FuzzMedianTestData, + num_partitions: usize, +) -> Result { + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(num_partitions), + ); + + let batches = split_fuzz_median_data_into_batches(test_data, num_partitions)?; + let provider = MemTable::try_new(test_data.schema.clone(), batches)?; + ctx.register_table("fuzz_table", Arc::new(provider))?; + + Ok(ctx) +} + +/// Splits fuzz median test data into multiple batches for partitioning +pub fn split_fuzz_median_data_into_batches( + test_data: &FuzzMedianTestData, + num_partitions: usize, +) -> Result>> { + debug_assert!(num_partitions > 0, "num_partitions must be greater than 0"); + let total_len = test_data.u8_low.len(); + let chunk_size = total_len.div_ceil(num_partitions); + + let mut batches = Vec::new(); + let mut start = 0; + + while start < total_len { + let end = min(start + chunk_size, total_len); + let len = end - start; + + if len > 0 { + let batch = RecordBatch::try_new( + test_data.schema.clone(), + vec![ + Arc::new(test_data.u8_low.slice(start, len)), + Arc::new(test_data.dictionary_utf8_low.slice(start, len)), + Arc::new(test_data.u64.slice(start, len)), + Arc::new(test_data.u16.slice(start, len)), + Arc::new(test_data.u32.slice(start, len)), + Arc::new(test_data.decimal128.slice(start, len)), + ], + )?; + batches.push(vec![batch]); + } + start = end; + } + + Ok(batches) +} + +/// Test data structure for fuzz table with timestamp and dictionary columns containing nulls +pub struct FuzzTimestampTestData { + pub schema: Arc, + pub utf8_low: StringArray, + pub u8_low: UInt8Array, + pub dictionary_utf8_low: DictionaryArray, + pub timestamp_us: TimestampMicrosecondArray, +} + +impl FuzzTimestampTestData { + pub fn new() -> Self { + // Create dictionary columns with null keys and values + let dictionary_utf8_low = create_test_dict( + &[Some("dict_x"), None, Some("dict_y"), Some("dict_z")], + &[ + Some(0), // dict_x + Some(1), // null value + Some(2), // dict_y + None, // null key + Some(0), // dict_x + Some(1), // null value + Some(3), // dict_z + None, // null key + Some(2), // dict_y + ], + ); + + let utf8_low = StringArray::from(vec![ + Some("alpha"), + Some("beta"), + Some("gamma"), + Some("delta"), + Some("alpha"), + Some("epsilon"), + Some("zeta"), + Some("delta"), + Some("gamma"), + ]); + + let u8_low = UInt8Array::from(vec![ + Some(10), + Some(20), + Some(30), + Some(20), + Some(10), + Some(40), + Some(30), + Some(20), + Some(30), + ]); + + // Create timestamp data with some nulls + let timestamp_us = TimestampMicrosecondArray::from(vec![ + Some(1000000), // 1970-01-01 00:00:01 + Some(2000000), // 1970-01-01 00:00:02 + Some(3000000), // 1970-01-01 00:00:03 + None, // null timestamp + Some(1500000), // 1970-01-01 00:00:01.5 + Some(4000000), // 1970-01-01 00:00:04 + Some(2500000), // 1970-01-01 00:00:02.5 + Some(3500000), // 1970-01-01 00:00:03.5 + Some(2800000), // 1970-01-01 00:00:02.8 + ]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("utf8_low", DataType::Utf8, true), + Field::new("u8_low", DataType::UInt8, true), + Field::new("dictionary_utf8_low", string_dict_type(), true), + Field::new( + "timestamp_us", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + ])); + + Self { + schema, + utf8_low, + u8_low, + dictionary_utf8_low, + timestamp_us, + } + } +} + +/// Sets up test contexts for fuzz table with timestamps and both single and multiple partitions +pub async fn setup_fuzz_timestamp_test_contexts( +) -> Result<(SessionContext, SessionContext)> { + let test_data = FuzzTimestampTestData::new(); + + // Single partition context + let ctx_single = create_fuzz_timestamp_context_with_partitions(&test_data, 1).await?; + + // Multiple partition context + let ctx_multi = create_fuzz_timestamp_context_with_partitions(&test_data, 3).await?; + + Ok((ctx_single, ctx_multi)) +} + +/// Creates a session context with fuzz timestamp table partitioned into specified number of partitions +pub async fn create_fuzz_timestamp_context_with_partitions( + test_data: &FuzzTimestampTestData, + num_partitions: usize, +) -> Result { + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_target_partitions(num_partitions), + ); + + let batches = split_fuzz_timestamp_data_into_batches(test_data, num_partitions)?; + let provider = MemTable::try_new(test_data.schema.clone(), batches)?; + ctx.register_table("fuzz_table", Arc::new(provider))?; + + Ok(ctx) +} + +/// Splits fuzz timestamp test data into multiple batches for partitioning +pub fn split_fuzz_timestamp_data_into_batches( + test_data: &FuzzTimestampTestData, + num_partitions: usize, +) -> Result>> { + debug_assert!(num_partitions > 0, "num_partitions must be greater than 0"); + let total_len = test_data.utf8_low.len(); + let chunk_size = total_len.div_ceil(num_partitions); + + let mut batches = Vec::new(); + let mut start = 0; + + while start < total_len { + let end = min(start + chunk_size, total_len); + let len = end - start; + + if len > 0 { + let batch = RecordBatch::try_new( + test_data.schema.clone(), + vec![ + Arc::new(test_data.utf8_low.slice(start, len)), + Arc::new(test_data.u8_low.slice(start, len)), + Arc::new(test_data.dictionary_utf8_low.slice(start, len)), + Arc::new(test_data.timestamp_us.slice(start, len)), + ], + )?; + batches.push(vec![batch]); + } + start = end; + } + + Ok(batches) +} + +pub mod basic; +pub mod dict_nulls; diff --git a/datafusion/core/tests/sql/create_drop.rs b/datafusion/core/tests/sql/create_drop.rs index 83712053b9542..4a60a79ff5de3 100644 --- a/datafusion/core/tests/sql/create_drop.rs +++ b/datafusion/core/tests/sql/create_drop.rs @@ -61,8 +61,31 @@ async fn create_external_table_with_ddl() -> Result<()> { assert_eq!(3, table_schema.fields().len()); assert_eq!(&DataType::Int32, table_schema.field(0).data_type()); - assert_eq!(&DataType::Utf8, table_schema.field(1).data_type()); + assert_eq!(&DataType::Utf8View, table_schema.field(1).data_type()); assert_eq!(&DataType::Boolean, table_schema.field(2).data_type()); Ok(()) } + +#[tokio::test] +async fn create_drop_table() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "CREATE TABLE dt (a_id integer, a_str string, a_bool boolean);"; + ctx.sql(sql).await.unwrap(); + + let cat = ctx.catalog("datafusion").unwrap(); + let schema = cat.schema("public").unwrap(); + + let exists = schema.table_exist("dt"); + assert!(exists, "Table should have been created!"); + + // Drop the table + let sql = "DROP TABLE dt;"; + ctx.sql(sql).await.unwrap(); + + let exists = schema.table_exist("dt"); + assert!(!exists, "Table should have been dropped!"); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index e8ef34c2afe70..e082cabaadaff 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -16,11 +16,13 @@ // under the License. use super::*; +use insta::assert_snapshot; use rstest::rstest; use datafusion::config::ConfigOptions; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::metrics::Timestamp; +use object_store::path::Path; #[tokio::test] async fn explain_analyze_baseline_metrics() { @@ -52,6 +54,7 @@ async fn explain_analyze_baseline_metrics() { let formatted = arrow::util::pretty::pretty_format_batches(&results) .unwrap() .to_string(); + println!("Query Output:\n\n{formatted}"); assert_metrics!( @@ -174,69 +177,66 @@ async fn csv_explain_plans() { println!("SQL: {sql}"); // // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8]", - " Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]", - " TableScan: aggregate_test_100 [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]", - ]; let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain [plan_type:Utf8, plan:Utf8] + Projection: aggregate_test_100.c1 [c1:Utf8View] + Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View] + TableScan: aggregate_test_100 [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View] + " ); // // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: aggregate_test_100.c1", - " Filter: aggregate_test_100.c2 > Int64(10)", - " TableScan: aggregate_test_100", - ]; let formatted = plan.display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r###" + Explain + Projection: aggregate_test_100.c1 + Filter: aggregate_test_100.c2 > Int64(10) + TableScan: aggregate_test_100 + "### ); // // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan,", - "// display it online here: https://dreampuf.github.io/GraphvizOnline", - "", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: aggregate_test_100.c2 > Int64(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; let formatted = plan.display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Explain"] + 3[shape=box label="Projection: aggregate_test_100.c1"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Filter: aggregate_test_100.c2 > Int64(10)"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: aggregate_test_100"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_6 + { + graph[label="Detailed LogicalPlan"] + 7[shape=box label="Explain\nSchema: [plan_type:Utf8, plan:Utf8]"] + 8[shape=box label="Projection: aggregate_test_100.c1\nSchema: [c1:Utf8View]"] + 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] + 9[shape=box label="Filter: aggregate_test_100.c2 > Int64(10)\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="TableScan: aggregate_test_100\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "# ); // Optimized logical plan @@ -248,69 +248,67 @@ async fn csv_explain_plans() { assert_eq!(logical_schema, optimized_logical_schema.as_ref()); // // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8]", - " Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8, c2:Int8]", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8, c2:Int8]", - ]; let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain [plan_type:Utf8, plan:Utf8] + Projection: aggregate_test_100.c1 [c1:Utf8View] + Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8View, c2:Int8] + TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8View, c2:Int8] + " ); // // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: aggregate_test_100.c1", - " Filter: aggregate_test_100.c2 > Int8(10)", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]", - ]; let formatted = plan.display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r###" + Explain + Projection: aggregate_test_100.c1 + Filter: aggregate_test_100.c2 > Int8(10) + TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] + + "### ); // // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan,", - "// display it online here: https://dreampuf.github.io/GraphvizOnline", - "", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\\nSchema: [c1:Utf8, c2:Int8]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\\nSchema: [c1:Utf8, c2:Int8]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; let formatted = plan.display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Explain"] + 3[shape=box label="Projection: aggregate_test_100.c1"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Filter: aggregate_test_100.c2 > Int8(10)"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_6 + { + graph[label="Detailed LogicalPlan"] + 7[shape=box label="Explain\nSchema: [plan_type:Utf8, plan:Utf8]"] + 8[shape=box label="Projection: aggregate_test_100.c1\nSchema: [c1:Utf8View]"] + 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] + 9[shape=box label="Filter: aggregate_test_100.c2 > Int8(10)\nSchema: [c1:Utf8View, c2:Int8]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\nSchema: [c1:Utf8View, c2:Int8]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "# ); // Physical plan @@ -396,69 +394,66 @@ async fn csv_explain_verbose_plans() { // // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8]", - " Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]", - " TableScan: aggregate_test_100 [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]", - ]; let formatted = dataframe.logical_plan().display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain [plan_type:Utf8, plan:Utf8] + Projection: aggregate_test_100.c1 [c1:Utf8View] + Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View] + TableScan: aggregate_test_100 [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View] + " ); // // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: aggregate_test_100.c1", - " Filter: aggregate_test_100.c2 > Int64(10)", - " TableScan: aggregate_test_100", - ]; let formatted = dataframe.logical_plan().display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r###" + Explain + Projection: aggregate_test_100.c1 + Filter: aggregate_test_100.c2 > Int64(10) + TableScan: aggregate_test_100 + "### ); // // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan,", - "// display it online here: https://dreampuf.github.io/GraphvizOnline", - "", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: aggregate_test_100.c2 > Int64(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; let formatted = dataframe.logical_plan().display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Explain"] + 3[shape=box label="Projection: aggregate_test_100.c1"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Filter: aggregate_test_100.c2 > Int64(10)"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: aggregate_test_100"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_6 + { + graph[label="Detailed LogicalPlan"] + 7[shape=box label="Explain\nSchema: [plan_type:Utf8, plan:Utf8]"] + 8[shape=box label="Projection: aggregate_test_100.c1\nSchema: [c1:Utf8View]"] + 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] + 9[shape=box label="Filter: aggregate_test_100.c2 > Int64(10)\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="TableScan: aggregate_test_100\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "# ); // Optimized logical plan @@ -470,69 +465,66 @@ async fn csv_explain_verbose_plans() { assert_eq!(&logical_schema, optimized_logical_schema.as_ref()); // // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8]", - " Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8, c2:Int8]", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8, c2:Int8]", - ]; let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain [plan_type:Utf8, plan:Utf8] + Projection: aggregate_test_100.c1 [c1:Utf8View] + Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8View, c2:Int8] + TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8View, c2:Int8] + " ); // // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: aggregate_test_100.c1", - " Filter: aggregate_test_100.c2 > Int8(10)", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]", - ]; let formatted = plan.display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r###" + Explain + Projection: aggregate_test_100.c1 + Filter: aggregate_test_100.c2 > Int8(10) + TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] + "### ); // // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan,", - "// display it online here: https://dreampuf.github.io/GraphvizOnline", - "", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\\nSchema: [c1:Utf8, c2:Int8]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\\nSchema: [c1:Utf8, c2:Int8]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; let formatted = plan.display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Explain"] + 3[shape=box label="Projection: aggregate_test_100.c1"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Filter: aggregate_test_100.c2 > Int8(10)"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_6 + { + graph[label="Detailed LogicalPlan"] + 7[shape=box label="Explain\nSchema: [plan_type:Utf8, plan:Utf8]"] + 8[shape=box label="Projection: aggregate_test_100.c1\nSchema: [c1:Utf8View]"] + 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] + 9[shape=box label="Filter: aggregate_test_100.c2 > Int8(10)\nSchema: [c1:Utf8View, c2:Int8]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\nSchema: [c1:Utf8View, c2:Int8]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "# ); // Physical plan @@ -561,7 +553,9 @@ async fn csv_explain_verbose_plans() { async fn explain_analyze_runs_optimizers(#[values("*", "1")] count_expr: &str) { // repro for https://github.com/apache/datafusion/issues/917 // where EXPLAIN ANALYZE was not correctly running optimizer - let ctx = SessionContext::new(); + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_collect_statistics(true), + ); register_alltypes_parquet(&ctx).await; // This happens as an optimization pass where count(*)/count(1) can be @@ -600,19 +594,6 @@ async fn test_physical_plan_display_indent() { LIMIT 10"; let dataframe = ctx.sql(sql).await.unwrap(); let physical_plan = dataframe.create_physical_plan().await.unwrap(); - let expected = vec![ - "SortPreservingMergeExec: [the_min@2 DESC], fetch=10", - " SortExec: TopK(fetch=10), expr=[the_min@2 DESC], preserve_partitioning=[true]", - " ProjectionExec: expr=[c1@0 as c1, max(aggregate_test_100.c12)@1 as max(aggregate_test_100.c12), min(aggregate_test_100.c12)@2 as the_min]", - " AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000", - " AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)]", - " CoalesceBatchesExec: target_batch_size=4096", - " FilterExec: c12@1 < 10", - " RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c12], file_type=csv, has_header=true", - ]; let normalizer = ExplainNormalizer::new(); let actual = format!("{}", displayable(physical_plan.as_ref()).indent(true)) @@ -620,10 +601,24 @@ async fn test_physical_plan_display_indent() { .lines() // normalize paths .map(|s| normalizer.normalize(s)) - .collect::>(); - assert_eq!( - expected, actual, - "expected:\n{expected:#?}\nactual:\n\n{actual:#?}\n" + .collect::>() + .join("\n"); + + assert_snapshot!( + actual, + @r###" + SortPreservingMergeExec: [the_min@2 DESC], fetch=10 + SortExec: TopK(fetch=10), expr=[the_min@2 DESC], preserve_partitioning=[true] + ProjectionExec: expr=[c1@0 as c1, max(aggregate_test_100.c12)@1 as max(aggregate_test_100.c12), min(aggregate_test_100.c12)@2 as the_min] + AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)] + CoalesceBatchesExec: target_batch_size=4096 + RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000 + AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)] + CoalesceBatchesExec: target_batch_size=4096 + FilterExec: c12@1 < 10 + RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1 + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c12], file_type=csv, has_header=true + "### ); } @@ -645,19 +640,6 @@ async fn test_physical_plan_display_indent_multi_children() { let dataframe = ctx.sql(sql).await.unwrap(); let physical_plan = dataframe.create_physical_plan().await.unwrap(); - let expected = vec![ - "CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c2@0)], projection=[c1@0]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000", - " RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c2@0], 9000), input_partitions=9000", - " RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1", - " ProjectionExec: expr=[c1@0 as c2]", - " DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true", - ]; let normalizer = ExplainNormalizer::new(); let actual = format!("{}", displayable(physical_plan.as_ref()).indent(true)) @@ -665,11 +647,24 @@ async fn test_physical_plan_display_indent_multi_children() { .lines() // normalize paths .map(|s| normalizer.normalize(s)) - .collect::>(); - - assert_eq!( - expected, actual, - "expected:\n{expected:#?}\nactual:\n\n{actual:#?}\n" + .collect::>() + .join("\n"); + + assert_snapshot!( + actual, + @r###" + CoalesceBatchesExec: target_batch_size=4096 + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c2@0)], projection=[c1@0] + CoalesceBatchesExec: target_batch_size=4096 + RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000 + RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1 + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true + CoalesceBatchesExec: target_batch_size=4096 + RepartitionExec: partitioning=Hash([c2@0], 9000), input_partitions=9000 + RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1 + ProjectionExec: expr=[c1@0 as c2] + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true + "### ); } @@ -733,6 +728,130 @@ async fn parquet_explain_analyze() { assert_contains!(&formatted, "row_groups_pruned_statistics=0"); } +// This test reproduces the behavior described in +// https://github.com/apache/datafusion/issues/16684 where projection +// pushdown with recursive CTEs could fail to remove unused columns +// (e.g. nested/recursive expansion causing full schema to be scanned). +// Keeping this test ensures we don't regress that behavior. +#[tokio::test] +#[cfg_attr(tarpaulin, ignore)] +async fn parquet_recursive_projection_pushdown() -> Result<()> { + use parquet::arrow::arrow_writer::ArrowWriter; + use parquet::file::properties::WriterProperties; + + let temp_dir = TempDir::new().unwrap(); + let parquet_path = temp_dir.path().join("hierarchy.parquet"); + + let ids = Int64Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let parent_ids = Int64Array::from(vec![0, 1, 1, 2, 2, 3, 4, 5, 6, 7]); + let values = Int64Array::from(vec![10, 20, 30, 40, 50, 60, 70, 80, 90, 100]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("parent_id", DataType::Int64, true), + Field::new("value", DataType::Int64, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(ids), Arc::new(parent_ids), Arc::new(values)], + ) + .unwrap(); + + let file = File::create(&parquet_path).unwrap(); + let props = WriterProperties::builder().build(); + let mut writer = ArrowWriter::try_new(file, schema, Some(props)).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + let ctx = SessionContext::new(); + ctx.register_parquet( + "hierarchy", + parquet_path.to_str().unwrap(), + ParquetReadOptions::default(), + ) + .await?; + + let sql = r#" + WITH RECURSIVE number_series AS ( + SELECT id, 1 as level + FROM hierarchy + WHERE id = 1 + + UNION ALL + + SELECT ns.id + 1, ns.level + 1 + FROM number_series ns + WHERE ns.id < 10 + ) + SELECT * FROM number_series ORDER BY id + "#; + + let dataframe = ctx.sql(sql).await?; + let physical_plan = dataframe.create_physical_plan().await?; + + let normalizer = ExplainNormalizer::new(); + let mut actual = format!("{}", displayable(physical_plan.as_ref()).indent(true)) + .trim() + .lines() + .map(|line| normalizer.normalize(line)) + .collect::>() + .join("\n"); + + fn replace_path_variants(actual: &mut String, path: &str) { + let mut candidates = vec![path.to_string()]; + + let trimmed = path.trim_start_matches(std::path::MAIN_SEPARATOR); + if trimmed != path { + candidates.push(trimmed.to_string()); + } + + let forward_slash = path.replace('\\', "/"); + if forward_slash != path { + candidates.push(forward_slash.clone()); + + let trimmed_forward = forward_slash.trim_start_matches('/'); + if trimmed_forward != forward_slash { + candidates.push(trimmed_forward.to_string()); + } + } + + for candidate in candidates { + *actual = actual.replace(&candidate, "TMP_DIR"); + } + } + + let temp_dir_path = temp_dir.path(); + let fs_path = temp_dir_path.to_string_lossy().to_string(); + replace_path_variants(&mut actual, &fs_path); + + if let Ok(url_path) = Path::from_filesystem_path(temp_dir_path) { + replace_path_variants(&mut actual, url_path.as_ref()); + } + + assert_snapshot!( + actual, + @r" + SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] + RecursiveQueryExec: name=number_series, is_distinct=false + CoalescePartitionsExec + ProjectionExec: expr=[id@0 as id, 1 as level] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: id@0 = 1 + RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 + DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)] + CoalescePartitionsExec + ProjectionExec: expr=[id@0 + 1 as ns.id + Int64(1), level@1 + 1 as ns.level + Int64(1)] + CoalesceBatchesExec: target_batch_size=8192 + FilterExec: id@0 < 10 + RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 + WorkTableExec: name=number_series + " + ); + + Ok(()) +} + #[tokio::test] #[cfg_attr(tarpaulin, ignore)] async fn parquet_explain_analyze_verbose() { @@ -777,14 +896,19 @@ async fn explain_logical_plan_only() { let sql = "EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3)"; let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); - - let expected = vec![ - vec!["logical_plan", "Projection: count(Int64(1)) AS count(*)\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ - \n SubqueryAlias: t\ - \n Projection: \ - \n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"]]; - assert_eq!(expected, actual); + let actual = actual.into_iter().map(|r| r.join("\n")).collect::(); + + assert_snapshot!( + actual, + @r#" + logical_plan + Projection: count(Int64(1)) AS count(*) + Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] + SubqueryAlias: t + Projection: + Values: (Utf8("a"), Int64(1), Int64(100)), (Utf8("a"), Int64(2), Int64(150)) + "# + ); } #[tokio::test] @@ -795,14 +919,16 @@ async fn explain_physical_plan_only() { let sql = "EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3)"; let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); - - let expected = vec![vec![ - "physical_plan", - "ProjectionExec: expr=[2 as count(*)]\ - \n PlaceholderRowExec\ - \n", - ]]; - assert_eq!(expected, actual); + let actual = actual.into_iter().map(|r| r.join("\n")).collect::(); + + assert_snapshot!( + actual, + @r###" + physical_plan + ProjectionExec: expr=[2 as count(*)] + PlaceholderRowExec + "### + ); } #[tokio::test] diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 77eec20eac006..7a59834475920 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -15,8 +15,13 @@ // specific language governing permissions and limitations // under the License. +use insta::assert_snapshot; + +use datafusion::assert_batches_eq; +use datafusion::catalog::MemTable; use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::test_util::register_unbounded_file_with_ordering; +use datafusion_sql::unparser::plan_to_sql; use super::*; @@ -61,28 +66,21 @@ async fn join_change_in_planner() -> Result<()> { let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); - let expected = { - [ - "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - // " DataSourceExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], file_type=csv, has_header=false", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - // " DataSourceExec: file_groups={1 group: [[tempdir/right.csv]]}, projection=[a1, a2], file_type=csv, has_header=false" - ] - }; - let mut actual: Vec<&str> = formatted.trim().lines().collect(); - // Remove CSV lines - actual.remove(4); - actual.remove(7); - - assert_eq!( - expected, - actual[..], - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + + assert_snapshot!( + actual, + @r" + SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10 + CoalesceBatchesExec: target_batch_size=8192 + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] + CoalesceBatchesExec: target_batch_size=8192 + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8, preserve_order=true, sort_exprs=a1@0 ASC NULLS LAST + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] + " ); Ok(()) } @@ -129,28 +127,21 @@ async fn join_no_order_on_filter() -> Result<()> { let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); - let expected = { - [ - "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a3@0 AS Int64) > CAST(a3@1 AS Int64) + 3 AND CAST(a3@0 AS Int64) < CAST(a3@1 AS Int64) + 10", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - // " DataSourceExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], file_type=csv, has_header=false", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - // " DataSourceExec: file_groups={1 group: [[tempdir/right.csv]]}, projection=[a1, a2], file_type=csv, has_header=false" - ] - }; - let mut actual: Vec<&str> = formatted.trim().lines().collect(); - // Remove CSV lines - actual.remove(4); - actual.remove(7); - - assert_eq!( - expected, - actual[..], - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + + assert_snapshot!( + actual, + @r" + SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a3@0 AS Int64) > CAST(a3@1 AS Int64) + 3 AND CAST(a3@0 AS Int64) < CAST(a3@1 AS Int64) + 10 + CoalesceBatchesExec: target_batch_size=8192 + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a1, a2, a3], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] + CoalesceBatchesExec: target_batch_size=8192 + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a1, a2, a3], infinite_source=true, output_ordering=[a1@0 ASC NULLS LAST] + " ); Ok(()) } @@ -179,28 +170,21 @@ async fn join_change_in_planner_without_sort() -> Result<()> { let dataframe = ctx.sql(sql).await?; let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); - let expected = { - [ - "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - // " DataSourceExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], file_type=csv, has_header=false", - " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", - " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - // " DataSourceExec: file_groups={1 group: [[tempdir/right.csv]]}, projection=[a1, a2], file_type=csv, has_header=false" - ] - }; - let mut actual: Vec<&str> = formatted.trim().lines().collect(); - // Remove CSV lines - actual.remove(4); - actual.remove(7); - - assert_eq!( - expected, - actual[..], - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + + assert_snapshot!( + actual, + @r" + SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10 + CoalesceBatchesExec: target_batch_size=8192 + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true + CoalesceBatchesExec: target_batch_size=8192 + RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8 + RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 + StreamingTableExec: partition_sizes=1, projection=[a1, a2], infinite_source=true + " ); Ok(()) } @@ -235,3 +219,92 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { } Ok(()) } + +#[tokio::test] +async fn join_using_uppercase_column() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "UPPER", + DataType::UInt32, + false, + )])); + let tmp_dir = TempDir::new()?; + let file_path = tmp_dir.path().join("uppercase-column.csv"); + let mut file = File::create(file_path.clone())?; + file.write_all("0".as_bytes())?; + drop(file); + + let ctx = SessionContext::new(); + ctx.register_csv( + "test", + file_path.to_str().unwrap(), + CsvReadOptions::new().schema(&schema).has_header(false), + ) + .await?; + + let dataframe = ctx + .sql( + r#" + SELECT test."UPPER" FROM "test" + INNER JOIN ( + SELECT test."UPPER" FROM "test" + ) AS selection USING ("UPPER") + ; + "#, + ) + .await?; + + assert_batches_eq!( + [ + "+-------+", + "| UPPER |", + "+-------+", + "| 0 |", + "+-------+", + ], + &dataframe.collect().await? + ); + + Ok(()) +} + +// Issue #17359: https://github.com/apache/datafusion/issues/17359 +#[tokio::test] +async fn unparse_cross_join() -> Result<()> { + let ctx = SessionContext::new(); + + let j1_schema = Arc::new(Schema::new(vec![ + Field::new("j1_id", DataType::Int32, true), + Field::new("j1_string", DataType::Utf8, true), + ])); + let j2_schema = Arc::new(Schema::new(vec![ + Field::new("j2_id", DataType::Int32, true), + Field::new("j2_string", DataType::Utf8, true), + ])); + + ctx.register_table("j1", Arc::new(MemTable::try_new(j1_schema, vec![vec![]])?))?; + ctx.register_table("j2", Arc::new(MemTable::try_new(j2_schema, vec![vec![]])?))?; + + let df = ctx + .sql( + r#" + select j1.j1_id, j2.j2_string + from j1, j2 + where j2.j2_id = 0 + "#, + ) + .await?; + + let unopt_sql = plan_to_sql(df.logical_plan())?; + assert_snapshot!(unopt_sql, @r#" + SELECT j1.j1_id, j2.j2_string FROM j1 CROSS JOIN j2 WHERE (j2.j2_id = 0) + "#); + + let optimized_plan = df.into_optimized_plan()?; + + let opt_sql = plan_to_sql(&optimized_plan)?; + assert_snapshot!(opt_sql, @r#" + SELECT j1.j1_id, j2.j2_string FROM j1 CROSS JOIN j2 WHERE (j2.j2_id = 0) + "#); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 579049692e7dc..e212ee269b151 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -34,7 +34,6 @@ use datafusion::{execution::context::SessionContext, physical_plan::displayable} use datafusion_common::test_util::batches_to_sort_string; use datafusion_common::utils::get_available_parallelism; use datafusion_common::{assert_contains, assert_not_contains}; -use insta::assert_snapshot; use object_store::path::Path; use std::fs::File; use std::io::Write; @@ -63,6 +62,7 @@ pub mod create_drop; pub mod explain_analyze; pub mod joins; mod path_partition; +mod runtime_config; pub mod select; mod sql_api; diff --git a/datafusion/core/tests/sql/path_partition.rs b/datafusion/core/tests/sql/path_partition.rs index bf8466d849f25..05cc723ef05fb 100644 --- a/datafusion/core/tests/sql/path_partition.rs +++ b/datafusion/core/tests/sql/path_partition.rs @@ -25,8 +25,6 @@ use std::sync::Arc; use arrow::datatypes::DataType; use datafusion::datasource::listing::ListingTableUrl; -use datafusion::datasource::physical_plan::ParquetSource; -use datafusion::datasource::source::DataSourceExec; use datafusion::{ datasource::{ file_format::{csv::CsvFormat, parquet::ParquetFormat}, @@ -42,8 +40,6 @@ use datafusion_common::stats::Precision; use datafusion_common::test_util::batches_to_sort_string; use datafusion_common::ScalarValue; use datafusion_execution::config::SessionConfig; -use datafusion_expr::{col, lit, Expr, Operator}; -use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; use async_trait::async_trait; use bytes::Bytes; @@ -54,58 +50,9 @@ use object_store::{ path::Path, GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, PutOptions, PutResult, }; -use object_store::{Attributes, MultipartUpload, PutMultipartOpts, PutPayload}; +use object_store::{Attributes, MultipartUpload, PutMultipartOptions, PutPayload}; use url::Url; -#[tokio::test] -async fn parquet_partition_pruning_filter() -> Result<()> { - let ctx = SessionContext::new(); - - let table = create_partitioned_alltypes_parquet_table( - &ctx, - &[ - "year=2021/month=09/day=09/file.parquet", - "year=2021/month=10/day=09/file.parquet", - "year=2021/month=10/day=28/file.parquet", - ], - &[ - ("year", DataType::Int32), - ("month", DataType::Int32), - ("day", DataType::Int32), - ], - "mirror:///", - "alltypes_plain.parquet", - ) - .await; - - // The first three filters can be resolved using only the partition columns. - let filters = [ - Expr::eq(col("year"), lit(2021)), - Expr::eq(col("month"), lit(10)), - Expr::eq(col("day"), lit(28)), - Expr::gt(col("id"), lit(1)), - ]; - let exec = table.scan(&ctx.state(), None, &filters, None).await?; - let data_source_exec = exec.as_any().downcast_ref::().unwrap(); - if let Some((_, parquet_config)) = - data_source_exec.downcast_to_file_source::() - { - let pred = parquet_config.predicate().unwrap(); - // Only the last filter should be pushdown to TableScan - let expected = Arc::new(BinaryExpr::new( - Arc::new(Column::new_with_schema("id", &exec.schema()).unwrap()), - Operator::Gt, - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), - )); - - assert!(pred.as_any().is::()); - let pred = pred.as_any().downcast_ref::().unwrap(); - - assert_eq!(pred, expected.as_ref()); - } - Ok(()) -} - #[tokio::test] async fn parquet_distinct_partition_col() -> Result<()> { let ctx = SessionContext::new(); @@ -484,7 +431,9 @@ async fn parquet_multiple_nonstring_partitions() -> Result<()> { #[tokio::test] async fn parquet_statistics() -> Result<()> { - let ctx = SessionContext::new(); + let mut config = SessionConfig::new(); + config.options_mut().execution.collect_statistics = true; + let ctx = SessionContext::new_with_config(config); register_partitioned_alltypes_parquet( &ctx, @@ -511,7 +460,7 @@ async fn parquet_statistics() -> Result<()> { let schema = physical_plan.schema(); assert_eq!(schema.fields().len(), 4); - let stat_cols = physical_plan.statistics()?.column_statistics; + let stat_cols = physical_plan.partition_statistics(None)?.column_statistics; assert_eq!(stat_cols.len(), 4); // stats for the first col are read from the parquet file assert_eq!(stat_cols[0].null_count, Precision::Exact(3)); @@ -526,7 +475,7 @@ async fn parquet_statistics() -> Result<()> { let schema = physical_plan.schema(); assert_eq!(schema.fields().len(), 2); - let stat_cols = physical_plan.statistics()?.column_statistics; + let stat_cols = physical_plan.partition_statistics(None)?.column_statistics; assert_eq!(stat_cols.len(), 2); // stats for the first col are read from the parquet file assert_eq!(stat_cols[0].null_count, Precision::Exact(1)); @@ -636,7 +585,8 @@ async fn create_partitioned_alltypes_parquet_table( .iter() .map(|x| (x.0.to_owned(), x.1.clone())) .collect::>(), - ); + ) + .with_session_config_options(&ctx.copied_config()); let table_path = ListingTableUrl::parse(table_path).unwrap(); let store_path = @@ -695,7 +645,7 @@ impl ObjectStore for MirroringObjectStore { async fn put_multipart_opts( &self, _location: &Path, - _opts: PutMultipartOpts, + _opts: PutMultipartOptions, ) -> object_store::Result> { unimplemented!() } @@ -712,7 +662,7 @@ impl ObjectStore for MirroringObjectStore { let meta = ObjectMeta { location: location.clone(), last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), - size: metadata.len() as usize, + size: metadata.len(), e_tag: None, version: None, }; @@ -728,14 +678,15 @@ impl ObjectStore for MirroringObjectStore { async fn get_range( &self, location: &Path, - range: Range, + range: Range, ) -> object_store::Result { self.files.iter().find(|x| *x == location).unwrap(); let path = std::path::PathBuf::from(&self.mirrored_file); let mut file = File::open(path).unwrap(); - file.seek(SeekFrom::Start(range.start as u64)).unwrap(); + file.seek(SeekFrom::Start(range.start)).unwrap(); let to_read = range.end - range.start; + let to_read: usize = to_read.try_into().unwrap(); let mut data = Vec::with_capacity(to_read); let read = file.take(to_read as u64).read_to_end(&mut data).unwrap(); assert_eq!(read, to_read); @@ -750,9 +701,10 @@ impl ObjectStore for MirroringObjectStore { fn list( &self, prefix: Option<&Path>, - ) -> BoxStream<'_, object_store::Result> { + ) -> BoxStream<'static, object_store::Result> { let prefix = prefix.cloned().unwrap_or_default(); - Box::pin(stream::iter(self.files.iter().filter_map( + let size = self.file_size; + Box::pin(stream::iter(self.files.clone().into_iter().filter_map( move |location| { // Don't return for exact prefix match let filter = location @@ -762,9 +714,9 @@ impl ObjectStore for MirroringObjectStore { filter.then(|| { Ok(ObjectMeta { - location: location.clone(), + location, last_modified: Utc.timestamp_nanos(0), - size: self.file_size as usize, + size, e_tag: None, version: None, }) @@ -802,7 +754,7 @@ impl ObjectStore for MirroringObjectStore { let object = ObjectMeta { location: k.clone(), last_modified: Utc.timestamp_nanos(0), - size: self.file_size as usize, + size: self.file_size, e_tag: None, version: None, }; diff --git a/datafusion/core/tests/sql/runtime_config.rs b/datafusion/core/tests/sql/runtime_config.rs new file mode 100644 index 0000000000000..9627d7bccdb04 --- /dev/null +++ b/datafusion/core/tests/sql/runtime_config.rs @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for runtime configuration SQL interface + +use std::sync::Arc; + +use datafusion::execution::context::SessionContext; +use datafusion::execution::context::TaskContext; +use datafusion_physical_plan::common::collect; + +#[tokio::test] +async fn test_memory_limit_with_spill() { + let ctx = SessionContext::new(); + + ctx.sql("SET datafusion.runtime.memory_limit = '1M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + ctx.sql("SET datafusion.execution.sort_spill_reservation_bytes = 0") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let query = "select * from generate_series(1,10000000) as t1(v1) order by v1;"; + let df = ctx.sql(query).await.unwrap(); + + let plan = df.create_physical_plan().await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx.state())); + let stream = plan.execute(0, task_ctx).unwrap(); + + let _results = collect(stream).await; + let metrics = plan.metrics().unwrap(); + let spill_count = metrics.spill_count().unwrap(); + assert!(spill_count > 0, "Expected spills but none occurred"); +} + +#[tokio::test] +async fn test_no_spill_with_adequate_memory() { + let ctx = SessionContext::new(); + + ctx.sql("SET datafusion.runtime.memory_limit = '10M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + ctx.sql("SET datafusion.execution.sort_spill_reservation_bytes = 0") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let query = "select * from generate_series(1,100000) as t1(v1) order by v1;"; + let df = ctx.sql(query).await.unwrap(); + + let plan = df.create_physical_plan().await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx.state())); + let stream = plan.execute(0, task_ctx).unwrap(); + + let _results = collect(stream).await; + let metrics = plan.metrics().unwrap(); + let spill_count = metrics.spill_count().unwrap(); + assert_eq!(spill_count, 0, "Expected no spills but some occurred"); +} + +#[tokio::test] +async fn test_multiple_configs() { + let ctx = SessionContext::new(); + + ctx.sql("SET datafusion.runtime.memory_limit = '100M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + ctx.sql("SET datafusion.execution.batch_size = '2048'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let query = "select * from generate_series(1,100000) as t1(v1) order by v1;"; + let result = ctx.sql(query).await.unwrap().collect().await; + + assert!(result.is_ok(), "Should not fail due to memory limit"); + + let state = ctx.state(); + let batch_size = state.config().options().execution.batch_size; + assert_eq!(batch_size, 2048); +} + +#[tokio::test] +async fn test_memory_limit_enforcement() { + let ctx = SessionContext::new(); + + ctx.sql("SET datafusion.runtime.memory_limit = '1M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let query = "select * from generate_series(1,100000) as t1(v1) order by v1;"; + let result = ctx.sql(query).await.unwrap().collect().await; + + assert!(result.is_err(), "Should fail due to memory limit"); + + ctx.sql("SET datafusion.runtime.memory_limit = '100M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let result = ctx.sql(query).await.unwrap().collect().await; + + assert!(result.is_ok(), "Should not fail due to memory limit"); +} + +#[tokio::test] +async fn test_invalid_memory_limit() { + let ctx = SessionContext::new(); + + let result = ctx + .sql("SET datafusion.runtime.memory_limit = '100X'") + .await; + + assert!(result.is_err()); + let error_message = result.unwrap_err().to_string(); + assert!(error_message.contains("Unsupported unit 'X'")); +} + +#[tokio::test] +async fn test_max_temp_directory_size_enforcement() { + let ctx = SessionContext::new(); + + ctx.sql("SET datafusion.runtime.memory_limit = '1M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + ctx.sql("SET datafusion.execution.sort_spill_reservation_bytes = 0") + .await + .unwrap() + .collect() + .await + .unwrap(); + + ctx.sql("SET datafusion.runtime.max_temp_directory_size = '0K'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let query = "select * from generate_series(1,100000) as t1(v1) order by v1;"; + let result = ctx.sql(query).await.unwrap().collect().await; + + assert!( + result.is_err(), + "Should fail due to max temp directory size limit" + ); + + ctx.sql("SET datafusion.runtime.max_temp_directory_size = '1M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let result = ctx.sql(query).await.unwrap().collect().await; + + assert!( + result.is_ok(), + "Should not fail due to max temp directory size limit" + ); +} + +#[tokio::test] +async fn test_test_metadata_cache_limit() { + let ctx = SessionContext::new(); + + let update_limit = async |ctx: &SessionContext, limit: &str| { + ctx.sql( + format!("SET datafusion.runtime.metadata_cache_limit = '{limit}'").as_str(), + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + }; + + let get_limit = |ctx: &SessionContext| -> usize { + ctx.task_ctx() + .runtime_env() + .cache_manager + .get_file_metadata_cache() + .cache_limit() + }; + + update_limit(&ctx, "100M").await; + assert_eq!(get_limit(&ctx), 100 * 1024 * 1024); + + update_limit(&ctx, "2G").await; + assert_eq!(get_limit(&ctx), 2 * 1024 * 1024 * 1024); + + update_limit(&ctx, "123K").await; + assert_eq!(get_limit(&ctx), 123 * 1024); +} + +#[tokio::test] +async fn test_unknown_runtime_config() { + let ctx = SessionContext::new(); + + let result = ctx + .sql("SET datafusion.runtime.unknown_config = 'value'") + .await; + + assert!(result.is_err()); + let error_message = result.unwrap_err().to_string(); + assert!(error_message.contains("Unknown runtime configuration")); +} diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index f874dd7c08428..98c3e3ccee8a1 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -17,6 +17,7 @@ use super::*; use datafusion_common::ScalarValue; +use insta::assert_snapshot; #[tokio::test] async fn test_list_query_parameters() -> Result<()> { @@ -217,10 +218,12 @@ async fn test_parameter_invalid_types() -> Result<()> { .with_param_values(vec![ScalarValue::from(4_i32)])? .collect() .await; - assert_eq!( - results.unwrap_err().strip_backtrace(), - "type_coercion\ncaused by\nError during planning: Cannot infer common argument type for comparison operation List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) = Int32" -); + assert_snapshot!(results.unwrap_err().strip_backtrace(), + @r#" + type_coercion + caused by + Error during planning: Cannot infer common argument type for comparison operation List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) = Int32 + "#); Ok(()) } @@ -343,3 +346,28 @@ async fn test_version_function() { assert_eq!(version.value(0), expected_version); } + +/// Regression test for https://github.com/apache/datafusion/issues/17513 +/// See https://github.com/apache/datafusion/pull/17520 +#[tokio::test] +async fn test_select_no_projection() -> Result<()> { + let tmp_dir = TempDir::new()?; + // `create_ctx_with_partition` creates 10 rows per partition and we chose 1 partition + let ctx = create_ctx_with_partition(&tmp_dir, 1).await?; + + let results = ctx.sql("SELECT FROM test").await?.collect().await?; + // We should get all of the rows, just without any columns + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 10); + // Check that none of the batches have any columns + for batch in &results { + assert_eq!(batch.num_columns(), 0); + } + // Sanity check the output, should be just empty columns + assert_snapshot!(batches_to_sort_string(&results), @r" + ++ + ++ + ++ + "); + Ok(()) +} diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index 034d6fa23d9cb..b87afd27ddea7 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -19,6 +19,23 @@ use datafusion::prelude::*; use tempfile::TempDir; +#[tokio::test] +async fn test_window_function() { + let ctx = SessionContext::new(); + let df = ctx + .sql( + r#"SELECT + t1.v1, + SUM(t1.v1) OVER w + 1 + FROM + generate_series(1, 10000) AS t1(v1) + WINDOW + w AS (ORDER BY t1.v1 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW);"#, + ) + .await; + assert!(df.is_ok()); +} + #[tokio::test] async fn unsupported_ddl_returns_error() { // Verify SessionContext::with_sql_options errors appropriately @@ -67,8 +84,8 @@ async fn dml_output_schema() { ctx.sql("CREATE TABLE test (x int)").await.unwrap(); let sql = "INSERT INTO test VALUES (1)"; let df = ctx.sql(sql).await.unwrap(); - let count_schema = Schema::new(vec![Field::new("count", DataType::UInt64, false)]); - assert_eq!(Schema::from(df.schema()), count_schema); + let count_schema = &Schema::new(vec![Field::new("count", DataType::UInt64, false)]); + assert_eq!(df.schema().as_arrow(), count_schema); } #[tokio::test] diff --git a/datafusion/core/tests/tpc-ds/49.sql b/datafusion/core/tests/tpc-ds/49.sql index 090e9746c0d81..219877719f227 100644 --- a/datafusion/core/tests/tpc-ds/49.sql +++ b/datafusion/core/tests/tpc-ds/49.sql @@ -110,7 +110,7 @@ select channel, item, return_ratio, return_rank, currency_rank from where sr.sr_return_amt > 10000 and sts.ss_net_profit > 1 - and sts.ss_net_paid > 0 + and sts.ss_net_paid > 0 and sts.ss_quantity > 0 and ss_sold_date_sk = d_date_sk and d_year = 2000 diff --git a/datafusion/core/tests/tracing/asserting_tracer.rs b/datafusion/core/tests/tracing/asserting_tracer.rs new file mode 100644 index 0000000000000..292e066e5f121 --- /dev/null +++ b/datafusion/core/tests/tracing/asserting_tracer.rs @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::collections::VecDeque; +use std::ops::Deref; +use std::sync::{Arc, LazyLock}; + +use datafusion_common::{HashMap, HashSet}; +use datafusion_common_runtime::{set_join_set_tracer, JoinSetTracer}; +use futures::future::BoxFuture; +use tokio::sync::{Mutex, MutexGuard}; + +/// Initializes the global join set tracer with the asserting tracer. +/// Call this function before spawning any tasks that should be traced. +pub fn init_asserting_tracer() { + set_join_set_tracer(ASSERTING_TRACER.deref()) + .expect("Failed to initialize asserting tracer"); +} + +/// Verifies that the current task has a traceable ancestry back to "root". +/// +/// The function performs a breadth-first search (BFS) in the global spawn graph: +/// - It starts at the current task and follows parent links. +/// - If it reaches the "root" task, the ancestry is valid. +/// - If a task is missing from the graph, it panics. +/// +/// Note: Tokio task IDs are unique only while a task is active. +/// Once a task completes, its ID may be reused. +pub async fn assert_traceability() { + // Acquire the spawn graph lock. + let spawn_graph = acquire_spawn_graph().await; + + // Start BFS with the current task. + let mut tasks_to_check = VecDeque::from(vec![current_task()]); + + while let Some(task_id) = tasks_to_check.pop_front() { + if task_id == "root" { + // Ancestry reached the root. + continue; + } + // Obtain parent tasks, panicking if the task is not present. + let parents = spawn_graph + .get(&task_id) + .expect("Task ID not found in spawn graph"); + // Queue each parent for checking. + for parent in parents { + tasks_to_check.push_back(parent.clone()); + } + } +} + +/// Tracer that maintains a graph of task ancestry for tracing purposes. +/// +/// For each task, it records a set of parent task IDs to ensure that every +/// asynchronous task can be traced back to "root". +struct AssertingTracer { + /// An asynchronous map from task IDs to their parent task IDs. + spawn_graph: Arc>>>, +} + +/// Lazily initialized global instance of `AssertingTracer`. +static ASSERTING_TRACER: LazyLock = LazyLock::new(AssertingTracer::new); + +impl AssertingTracer { + /// Creates a new `AssertingTracer` with an empty spawn graph. + fn new() -> Self { + Self { + spawn_graph: Arc::default(), + } + } +} + +/// Returns the current task's ID as a string, or "root" if unavailable. +/// +/// Tokio guarantees task IDs are unique only among active tasks, +/// so completed tasks may have their IDs reused. +fn current_task() -> String { + tokio::task::try_id() + .map(|id| format!("{id}")) + .unwrap_or_else(|| "root".to_string()) +} + +/// Asynchronously locks and returns the spawn graph. +/// +/// The returned guard allows inspection or modification of task ancestry. +async fn acquire_spawn_graph<'a>() -> MutexGuard<'a, HashMap>> { + ASSERTING_TRACER.spawn_graph.lock().await +} + +/// Registers the current task as a child of `parent_id` in the spawn graph. +async fn register_task(parent_id: String) { + acquire_spawn_graph() + .await + .entry(current_task()) + .or_insert_with(HashSet::new) + .insert(parent_id); +} + +impl JoinSetTracer for AssertingTracer { + /// Wraps an asynchronous future to record its parent task before execution. + fn trace_future( + &self, + fut: BoxFuture<'static, Box>, + ) -> BoxFuture<'static, Box> { + // Capture the parent task ID. + let parent_id = current_task(); + Box::pin(async move { + // Register the parent-child relationship. + register_task(parent_id).await; + // Execute the wrapped future. + fut.await + }) + } + + /// Wraps a blocking closure to record its parent task before execution. + fn trace_block( + &self, + f: Box Box + Send>, + ) -> Box Box + Send> { + let parent_id = current_task(); + Box::new(move || { + // Synchronously record the task relationship. + futures::executor::block_on(register_task(parent_id)); + f() + }) + } +} diff --git a/datafusion/core/tests/tracing/mod.rs b/datafusion/core/tests/tracing/mod.rs new file mode 100644 index 0000000000000..0b66a49eea9f4 --- /dev/null +++ b/datafusion/core/tests/tracing/mod.rs @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # JoinSetTracer Integration Tests +//! +//! These are smoke tests that verify `JoinSetTracer` can be correctly injected into DataFusion. +//! +//! They run a SQL query that reads Parquet data and performs an aggregation, +//! which causes DataFusion to spawn multiple tasks. +//! The object store is wrapped to assert that every task can be traced back to the root. +//! +//! These tests don't cover all edge cases, but they should fail if changes to +//! DataFusion's task spawning break tracing. + +mod asserting_tracer; +mod traceable_object_store; + +use asserting_tracer::init_asserting_tracer; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::ListingOptions; +use datafusion::prelude::*; +use datafusion::test_util::parquet_test_data; +use datafusion_common::assert_contains; +use datafusion_common_runtime::SpawnedTask; +use log::info; +use object_store::local::LocalFileSystem; +use std::sync::Arc; +use traceable_object_store::traceable_object_store; +use url::Url; + +/// Combined test that first verifies the query panics when no tracer is registered, +/// then initializes the tracer and confirms the query runs successfully. +/// +/// Using a single test function prevents global tracer leakage between tests. +#[tokio::test(flavor = "multi_thread", worker_threads = 8)] +async fn test_tracer_injection() { + // Without initializing the tracer, run the query. + // Spawn the query in a separate task so we can catch its panic. + info!("Running query without tracer"); + // The absence of the tracer should cause the task to panic inside the `TraceableObjectStore`. + let untraced_result = SpawnedTask::spawn(run_query()).join().await; + if let Err(e) = untraced_result { + // Check if the error message contains the expected error. + assert!(e.is_panic(), "Expected a panic, but got: {e:?}"); + assert_contains!(e.to_string(), "Task ID not found in spawn graph"); + info!("Caught expected panic: {e}"); + } else { + panic!("Expected the task to panic, but it completed successfully"); + }; + + // Initialize the asserting tracer and run the query. + info!("Initializing tracer and re-running query"); + init_asserting_tracer(); + SpawnedTask::spawn(run_query()).join().await.unwrap(); // Should complete without panics or errors. +} + +/// Executes a sample task-spawning SQL query using a traceable object store. +async fn run_query() { + info!("Starting query execution"); + + // Create a new session context + let ctx = SessionContext::new(); + + // Get the test data directory + let test_data = if cfg!(target_os = "windows") { + // Prefix Windows paths with "/", since they start with :/ but the URI should be + // test:///C:/... (https://datatracker.ietf.org/doc/html/rfc8089#appendix-E.2) + format!("/{}", parquet_test_data()) + } else { + parquet_test_data() + }; + + // Define a Parquet file format with pruning enabled + let file_format = ParquetFormat::default().with_enable_pruning(true); + + // Set listing options for the parquet file with a specific extension + let listing_options = ListingOptions::new(Arc::new(file_format)) + .with_file_extension("alltypes_tiny_pages_plain.parquet"); + + // Wrap the local file system in a traceable object store to verify task traceability. + let local_fs = Arc::new(LocalFileSystem::new()); + let traceable_store = traceable_object_store(local_fs); + + // Register the traceable object store with a test URL. + let url = Url::parse("test://").unwrap(); + ctx.register_object_store(&url, traceable_store.clone()); + + // Register a listing table from the test data directory. + let table_path = format!("test://{test_data}/"); + ctx.register_listing_table("alltypes", &table_path, listing_options, None, None) + .await + .expect("Failed to register table"); + + // Define and execute an SQL query against the registered table, which should + // spawn multiple tasks due to the aggregation and parquet file read. + let sql = "SELECT COUNT(*), string_col FROM alltypes GROUP BY string_col"; + let result_batches = ctx.sql(sql).await.unwrap().collect().await.unwrap(); + + info!("Query complete: {} batches returned", result_batches.len()); +} diff --git a/datafusion/core/tests/tracing/traceable_object_store.rs b/datafusion/core/tests/tracing/traceable_object_store.rs new file mode 100644 index 0000000000000..60ef1cc5d6b6a --- /dev/null +++ b/datafusion/core/tests/tracing/traceable_object_store.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Object store implementation used for testing + +use crate::tracing::asserting_tracer::assert_traceability; +use futures::stream::BoxStream; +use object_store::{ + path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, + ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, +}; +use std::fmt::{Debug, Display, Formatter}; +use std::sync::Arc; + +/// Returns an `ObjectStore` that asserts it can trace its calls back to the root tokio task. +pub fn traceable_object_store( + object_store: Arc, +) -> Arc { + Arc::new(TraceableObjectStore::new(object_store)) +} + +/// An object store that asserts it can trace all its calls back to the root tokio task. +#[derive(Debug)] +struct TraceableObjectStore { + inner: Arc, +} + +impl TraceableObjectStore { + fn new(inner: Arc) -> Self { + Self { inner } + } +} + +impl Display for TraceableObjectStore { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.inner, f) + } +} + +/// All trait methods are forwarded to the inner object store, +/// after asserting they can trace their calls back to the root tokio task. +#[async_trait::async_trait] +impl ObjectStore for TraceableObjectStore { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> object_store::Result { + assert_traceability().await; + self.inner.put_opts(location, payload, opts).await + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOptions, + ) -> object_store::Result> { + assert_traceability().await; + self.inner.put_multipart_opts(location, opts).await + } + + async fn get_opts( + &self, + location: &Path, + options: GetOptions, + ) -> object_store::Result { + assert_traceability().await; + self.inner.get_opts(location, options).await + } + + async fn head(&self, location: &Path) -> object_store::Result { + assert_traceability().await; + self.inner.head(location).await + } + + async fn delete(&self, location: &Path) -> object_store::Result<()> { + assert_traceability().await; + self.inner.delete(location).await + } + + fn list( + &self, + prefix: Option<&Path>, + ) -> BoxStream<'static, object_store::Result> { + futures::executor::block_on(assert_traceability()); + self.inner.list(prefix) + } + + async fn list_with_delimiter( + &self, + prefix: Option<&Path>, + ) -> object_store::Result { + assert_traceability().await; + self.inner.list_with_delimiter(prefix).await + } + + async fn copy(&self, from: &Path, to: &Path) -> object_store::Result<()> { + assert_traceability().await; + self.inner.copy(from, to).await + } + + async fn copy_if_not_exists( + &self, + from: &Path, + to: &Path, + ) -> object_store::Result<()> { + assert_traceability().await; + self.inner.copy_if_not_exists(from, to).await + } +} diff --git a/datafusion/core/tests/user_defined/expr_planner.rs b/datafusion/core/tests/user_defined/expr_planner.rs index 1fc6d14c5b229..07d289cab06c2 100644 --- a/datafusion/core/tests/user_defined/expr_planner.rs +++ b/datafusion/core/tests/user_defined/expr_planner.rs @@ -56,7 +56,7 @@ impl ExprPlanner for MyCustomPlanner { } BinaryOperator::Question => { Ok(PlannerResult::Planned(Expr::Alias(Alias::new( - Expr::Literal(ScalarValue::Boolean(Some(true))), + Expr::Literal(ScalarValue::Boolean(Some(true)), None), None::<&str>, format!("{} ? {}", expr.left, expr.right), )))) diff --git a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs index 12f700ce572ba..c8a4279a42110 100644 --- a/datafusion/core/tests/user_defined/insert_operation.rs +++ b/datafusion/core/tests/user_defined/insert_operation.rs @@ -26,6 +26,7 @@ use datafusion::{ use datafusion_catalog::{Session, TableProvider}; use datafusion_expr::{dml::InsertOp, Expr, TableType}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_plan::execution_plan::SchedulingType; use datafusion_physical_plan::{ execution_plan::{Boundedness, EmissionType}, DisplayAs, ExecutionPlan, PlanProperties, @@ -132,7 +133,8 @@ impl TestInsertExec { Partitioning::UnknownPartitioning(1), EmissionType::Incremental, Boundedness::Bounded, - ), + ) + .with_scheduling_type(SchedulingType::Cooperative), } } } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 5cbb05f290a70..982b4804597e6 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -18,7 +18,9 @@ //! This module contains end to end demonstrations of creating //! user defined aggregate functions -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::any::Any; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; use std::mem::{size_of, size_of_val}; use std::sync::{ atomic::{AtomicBool, Ordering}, @@ -26,10 +28,11 @@ use std::sync::{ }; use arrow::array::{ - types::UInt64Type, AsArray, Int32Array, PrimitiveArray, StringArray, StructArray, + record_batch, types::UInt64Type, Array, AsArray, Int32Array, PrimitiveArray, + StringArray, StructArray, UInt64Array, }; use arrow::datatypes::{Fields, Schema}; - +use arrow_schema::FieldRef; use datafusion::common::test_util::batches_to_string; use datafusion::dataframe::DataFrame; use datafusion::datasource::MemTable; @@ -48,11 +51,13 @@ use datafusion::{ prelude::SessionContext, scalar::ScalarValue, }; -use datafusion_common::assert_contains; +use datafusion_common::{assert_contains, exec_datafusion_err}; use datafusion_common::{cast::as_primitive_array, exec_err}; + +use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ - col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, - LogicalPlanBuilder, SimpleAggregateUDF, + col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr, + GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition, }; use datafusion_functions_aggregate::average::AvgAccumulator; @@ -293,10 +298,12 @@ async fn deregister_udaf() -> Result<()> { ctx.register_udaf(my_avg); assert!(ctx.state().aggregate_functions().contains_key("my_avg")); + assert!(datafusion_execution::FunctionRegistry::udafs(&ctx).contains("my_avg")); ctx.deregister_udaf("my_avg"); assert!(!ctx.state().aggregate_functions().contains_key("my_avg")); + assert!(!datafusion_execution::FunctionRegistry::udafs(&ctx).contains("my_avg")); Ok(()) } @@ -375,13 +382,13 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?; - insta::assert_snapshot!(batches_to_string(&alias_result), @r###" - +------------+ - | dummy(t.i) | - +------------+ - | 1.0 | - +------------+ - "###); + insta::assert_snapshot!(batches_to_string(&alias_result), @r" + +------------------+ + | dummy_alias(t.i) | + +------------------+ + | 1.0 | + +------------------+ + "); Ok(()) } @@ -569,7 +576,7 @@ impl TimeSum { // Returns the same type as its input let return_type = timestamp_type.clone(); - let state_fields = vec![Field::new("sum", timestamp_type, true)]; + let state_fields = vec![Field::new("sum", timestamp_type, true).into()]; let volatility = Volatility::Immutable; @@ -669,7 +676,7 @@ impl FirstSelector { let state_fields = state_type .into_iter() .enumerate() - .map(|(i, t)| Field::new(format!("{i}"), t, true)) + .map(|(i, t)| Field::new(format!("{i}"), t, true).into()) .collect::>(); // Possible input signatures @@ -774,14 +781,14 @@ impl Accumulator for FirstSelector { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] struct TestGroupsAccumulator { signature: Signature, result: u64, } impl AggregateUDFImpl for TestGroupsAccumulator { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -812,21 +819,6 @@ impl AggregateUDFImpl for TestGroupsAccumulator { ) -> Result> { Ok(Box::new(self.clone())) } - - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - if let Some(other) = other.as_any().downcast_ref::() { - self.result == other.result && self.signature == other.signature - } else { - false - } - } - - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.signature.hash(hasher); - self.result.hash(hasher); - hasher.finish() - } } impl Accumulator for TestGroupsAccumulator { @@ -890,3 +882,290 @@ impl GroupsAccumulator for TestGroupsAccumulator { size_of::() } } + +#[derive(Debug)] +struct MetadataBasedAggregateUdf { + name: String, + signature: Signature, + metadata: HashMap, +} + +impl PartialEq for MetadataBasedAggregateUdf { + fn eq(&self, other: &Self) -> bool { + let Self { + name, + signature, + metadata, + } = self; + name == &other.name + && signature == &other.signature + && metadata == &other.metadata + } +} +impl Eq for MetadataBasedAggregateUdf {} +impl Hash for MetadataBasedAggregateUdf { + fn hash(&self, state: &mut H) { + let Self { + name, + signature, + metadata: _, // unhashable + } = self; + std::any::type_name::().hash(state); + name.hash(state); + signature.hash(state); + } +} + +impl MetadataBasedAggregateUdf { + fn new(metadata: HashMap) -> Self { + // The name we return must be unique. Otherwise we will not call distinct + // instances of this UDF. This is a small hack for the unit tests to get unique + // names, but you could do something more elegant with the metadata. + let name = format!("metadata_based_udf_{}", metadata.len()); + Self { + name, + signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + metadata, + } + } +} + +impl AggregateUDFImpl for MetadataBasedAggregateUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!("this should never be called since return_field is implemented"); + } + + fn return_field(&self, _arg_fields: &[FieldRef]) -> Result { + Ok(Field::new(self.name(), DataType::UInt64, true) + .with_metadata(self.metadata.clone()) + .into()) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let input_expr = acc_args + .exprs + .first() + .ok_or(exec_datafusion_err!("Expected one argument"))?; + let input_field = input_expr.return_field(acc_args.schema)?; + + let double_output = input_field + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false); + + Ok(Box::new(MetadataBasedAccumulator { + double_output, + curr_sum: 0, + })) + } +} + +#[derive(Debug)] +struct MetadataBasedAccumulator { + double_output: bool, + curr_sum: u64, +} + +impl Accumulator for MetadataBasedAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = values[0] + .as_any() + .downcast_ref::() + .ok_or(exec_datafusion_err!("Expected UInt64Array"))?; + + self.curr_sum = arr.iter().fold(self.curr_sum, |a, b| a + b.unwrap_or(0)); + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let v = match self.double_output { + true => self.curr_sum * 2, + false => self.curr_sum, + }; + + Ok(ScalarValue::from(v)) + } + + fn size(&self) -> usize { + 9 + } + + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::from(self.curr_sum)]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } +} + +#[tokio::test] +async fn test_metadata_based_aggregate() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + let no_output_meta_udf = + AggregateUDF::from(MetadataBasedAggregateUdf::new(HashMap::new())); + let with_output_meta_udf = AggregateUDF::from(MetadataBasedAggregateUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + )); + + let df = df.aggregate( + vec![], + vec![ + no_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_no_out"), + no_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_no_out"), + with_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_with_out"), + with_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_with_out"), + ], + )?; + + let actual = df.collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [50]), + ("meta_with_in_no_out", UInt64, [100]), + ("meta_no_in_with_out", UInt64, [50]), + ("meta_with_in_with_out", UInt64, [100]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} + +#[tokio::test] +async fn test_metadata_based_aggregate_as_window() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + let no_output_meta_udf = Arc::new(AggregateUDF::from( + MetadataBasedAggregateUdf::new(HashMap::new()), + )); + let with_output_meta_udf = + Arc::new(AggregateUDF::from(MetadataBasedAggregateUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + ))); + + let df = df.select(vec![ + Expr::from(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(Arc::clone(&no_output_meta_udf)), + vec![col("no_metadata")], + )) + .alias("meta_no_in_no_out"), + Expr::from(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(no_output_meta_udf), + vec![col("with_metadata")], + )) + .alias("meta_with_in_no_out"), + Expr::from(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(Arc::clone(&with_output_meta_udf)), + vec![col("no_metadata")], + )) + .alias("meta_no_in_with_out"), + Expr::from(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(with_output_meta_udf), + vec![col("with_metadata")], + )) + .alias("meta_with_in_with_out"), + ])?; + + let actual = df.collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_no_out", UInt64, [100, 100, 100, 100, 100]), + ("meta_no_in_with_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_with_out", UInt64, [100, 100, 100, 100, 100]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index e46940e631542..f0bf15d3483ba 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -63,15 +63,14 @@ use std::hash::Hash; use std::task::{Context, Poll}; use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; +use arrow::array::{Array, ArrayRef, StringViewArray}; use arrow::{ - array::{Int64Array, StringArray}, - datatypes::SchemaRef, - record_batch::RecordBatch, + array::Int64Array, datatypes::SchemaRef, record_batch::RecordBatch, util::pretty::pretty_format_batches, }; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::{ - common::cast::{as_int64_array, as_string_array}, + common::cast::as_int64_array, common::{arrow_datafusion_err, internal_err, DFSchemaRef}, error::{DataFusionError, Result}, execution::{ @@ -100,6 +99,7 @@ use datafusion_optimizer::AnalyzerRule; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use async_trait::async_trait; +use datafusion_common::cast::as_string_view_array; use futures::{Stream, StreamExt}; /// Execute the specified sql and return the resulting record batches @@ -580,7 +580,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { self.input.schema() } - fn check_invariants(&self, check: InvariantLevel, _plan: &LogicalPlan) -> Result<()> { + fn check_invariants(&self, check: InvariantLevel) -> Result<()> { if let Some(InvariantMock { should_fail_invariant, kind, @@ -796,22 +796,26 @@ fn accumulate_batch( k: &usize, ) -> BTreeMap { let num_rows = input_batch.num_rows(); + // Assuming the input columns are - // column[0]: customer_id / UTF8 + // column[0]: customer_id UTF8View // column[1]: revenue: Int64 - let customer_id = - as_string_array(input_batch.column(0)).expect("Column 0 is not customer_id"); + let customer_id_column = input_batch.column(0); let revenue = as_int64_array(input_batch.column(1)).unwrap(); for row in 0..num_rows { - add_row( - &mut top_values, - customer_id.value(row), - revenue.value(row), - k, - ); + let customer_id = match customer_id_column.data_type() { + arrow::datatypes::DataType::Utf8View => { + let array = as_string_view_array(customer_id_column).unwrap(); + array.value(row) + } + _ => panic!("Unsupported customer_id type"), + }; + + add_row(&mut top_values, customer_id, revenue.value(row), k); } + top_values } @@ -843,11 +847,19 @@ impl Stream for TopKReader { self.state.iter().rev().unzip(); let customer: Vec<&str> = customer.iter().map(|&s| &**s).collect(); + + let customer_array: ArrayRef = match schema.field(0).data_type() { + arrow::datatypes::DataType::Utf8View => { + Arc::new(StringViewArray::from(customer)) + } + other => panic!("Unsupported customer_id output type: {other:?}"), + }; + Poll::Ready(Some( RecordBatch::try_new( schema, vec![ - Arc::new(StringArray::from(customer)), + Arc::new(customer_array), Arc::new(Int64Array::from(revenue)), ], ) @@ -900,11 +912,12 @@ impl MyAnalyzerRule { .map(|e| { e.transform(|e| { Ok(match e { - Expr::Literal(ScalarValue::Int64(i)) => { + Expr::Literal(ScalarValue::Int64(i), _) => { // transform to UInt64 - Transformed::yes(Expr::Literal(ScalarValue::UInt64( - i.map(|i| i as u64), - ))) + Transformed::yes(Expr::Literal( + ScalarValue::UInt64(i.map(|i| i as u64)), + None, + )) } _ => Transformed::no(e), }) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 264bd6b66a600..f1af66de9b592 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -16,16 +16,19 @@ // under the License. use std::any::Any; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; use std::sync::Arc; -use arrow::array::as_string_array; +use arrow::array::{as_string_array, create_array, record_batch, Int8Array, UInt64Array}; use arrow::array::{ builder::BooleanBuilder, cast::AsArray, Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, }; use arrow::compute::kernels::numeric::add; use arrow::datatypes::{DataType, Field, Schema}; +use arrow_schema::extension::{Bool8, CanonicalExtensionType, ExtensionType}; +use arrow_schema::{ArrowError, FieldRef, SchemaRef}; use datafusion::common::test_util::batches_to_string; use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; @@ -34,14 +37,15 @@ use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::utils::take_function_args; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, not_impl_err, - plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue, + assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_datafusion_err, + exec_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::FieldMetadata; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, - OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, + lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, + LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; @@ -57,7 +61,7 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> { let ctx = create_udf_context(); register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100"; - let actual = plan_and_collect(&ctx, sql).await.unwrap(); + let actual = plan_and_collect(&ctx, sql).await?; insta::assert_snapshot!(batches_to_string(&actual), @r###" +------------------------------------------+ @@ -76,7 +80,7 @@ async fn csv_query_avg_sqrt() -> Result<()> { register_aggregate_csv(&ctx).await?; // Note it is a different column (c12) than above (c11) let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; - let actual = plan_and_collect(&ctx, sql).await.unwrap(); + let actual = plan_and_collect(&ctx, sql).await?; insta::assert_snapshot!(batches_to_string(&actual), @r###" +------------------------------------------+ @@ -177,6 +181,7 @@ async fn scalar_udf() -> Result<()> { Ok(()) } +#[derive(PartialEq, Eq, Hash)] struct Simple0ArgsScalarUDF { name: String, signature: Signature, @@ -389,7 +394,7 @@ async fn udaf_as_window_func() -> Result<()> { WindowAggr: windowExpr=[[my_acc(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] TableScan: my_table"#; - let dataframe = context.sql(sql).await.unwrap(); + let dataframe = context.sql(sql).await?; assert_eq!(format!("{}", dataframe.logical_plan()), expected); Ok(()) } @@ -399,7 +404,7 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch)?; let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { @@ -443,7 +448,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch)?; let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { @@ -473,19 +478,19 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { "###); let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?; - insta::assert_snapshot!(batches_to_string(&alias_result), @r###" - +------------+ - | dummy(t.i) | - +------------+ - | 1 | - +------------+ - "###); + insta::assert_snapshot!(batches_to_string(&alias_result), @r" + +------------------+ + | dummy_alias(t.i) | + +------------------+ + | 1 | + +------------------+ + "); Ok(()) } /// Volatile UDF that should append a different value to each row -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct AddIndexToStringVolatileScalarUDF { name: String, signature: Signature, @@ -656,7 +661,7 @@ async fn volatile_scalar_udf_with_params() -> Result<()> { Ok(()) } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct CastToI64UDF { signature: Signature, } @@ -770,15 +775,17 @@ async fn deregister_udf() -> Result<()> { ctx.register_udf(cast2i64); assert!(ctx.udfs().contains("cast_to_i64")); + assert!(FunctionRegistry::udfs(&ctx).contains("cast_to_i64")); ctx.deregister_udf("cast_to_i64"); assert!(!ctx.udfs().contains("cast_to_i64")); + assert!(!FunctionRegistry::udfs(&ctx).contains("cast_to_i64")); Ok(()) } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct TakeUDF { signature: Signature, } @@ -803,7 +810,7 @@ impl ScalarUDFImpl for TakeUDF { &self.signature } fn return_type(&self, _args: &[DataType]) -> Result { - not_impl_err!("Not called because the return_type_from_args is implemented") + not_impl_err!("Not called because the return_field_from_args is implemented") } /// This function returns the type of the first or second argument based on @@ -811,9 +818,9 @@ impl ScalarUDFImpl for TakeUDF { /// /// 1. If the third argument is '0', return the type of the first argument /// 2. If the third argument is '1', return the type of the second argument - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - if args.arg_types.len() != 3 { - return plan_err!("Expected 3 arguments, got {}.", args.arg_types.len()); + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + if args.arg_fields.len() != 3 { + return plan_err!("Expected 3 arguments, got {}.", args.arg_fields.len()); } let take_idx = if let Some(take_idx) = args.scalar_arguments.get(2) { @@ -838,9 +845,12 @@ impl ScalarUDFImpl for TakeUDF { ); }; - Ok(ReturnInfo::new_nullable( - args.arg_types[take_idx].to_owned(), - )) + Ok(Field::new( + self.name(), + args.arg_fields[take_idx].data_type().to_owned(), + true, + ) + .into()) } // The actual implementation @@ -929,7 +939,7 @@ impl FunctionFactory for CustomFunctionFactory { // // it also defines custom [ScalarUDFImpl::simplify()] // to replace ScalarUDF expression with one instance contains. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct ScalarFunctionWrapper { name: String, expr: Expr, @@ -967,10 +977,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { Ok(ExprSimplifyResult::Simplified(replacement)) } - - fn aliases(&self) -> &[String] { - &[] - } } impl ScalarFunctionWrapper { @@ -1003,10 +1009,7 @@ impl ScalarFunctionWrapper { fn parse_placeholder_identifier(placeholder: &str) -> Result { if let Some(value) = placeholder.strip_prefix('$') { Ok(value.parse().map(|v: usize| v - 1).map_err(|e| { - DataFusionError::Execution(format!( - "Placeholder `{}` parsing error: {}!", - placeholder, e - )) + exec_datafusion_err!("Placeholder `{placeholder}` parsing error: {e}!") })?) } else { exec_err!("Placeholder should start with `$`!") @@ -1160,7 +1163,7 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<( match ctx.sql(sql).await { Ok(_) => {} Err(e) => { - panic!("Error creating function: {}", e); + panic!("Error creating function: {e}"); } } @@ -1179,7 +1182,7 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<( quote_style: None, span: Span::empty(), }), - data_type: DataType::Utf8, + data_type: DataType::Utf8View, default_expr: None, }]), return_type: Some(DataType::Int32), @@ -1206,6 +1209,22 @@ struct MyRegexUdf { regex: Regex, } +impl PartialEq for MyRegexUdf { + fn eq(&self, other: &Self) -> bool { + let Self { signature, regex } = self; + signature == &other.signature && regex.as_str() == other.regex.as_str() + } +} +impl Eq for MyRegexUdf {} + +impl Hash for MyRegexUdf { + fn hash(&self, state: &mut H) { + let Self { signature, regex } = self; + signature.hash(state); + regex.as_str().hash(state); + } +} + impl MyRegexUdf { fn new(pattern: &str) -> Self { Self { @@ -1257,20 +1276,6 @@ impl ScalarUDFImpl for MyRegexUdf { _ => exec_err!("regex_udf only accepts a Utf8 arguments"), } } - - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - if let Some(other) = other.as_any().downcast_ref::() { - self.regex.as_str() == other.regex.as_str() - } else { - false - } - } - - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.regex.as_str().hash(hasher); - hasher.finish() - } } #[tokio::test] @@ -1367,3 +1372,642 @@ async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> { async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } + +#[derive(Debug, PartialEq, Eq)] +struct MetadataBasedUdf { + name: String, + signature: Signature, + metadata: HashMap, +} + +impl Hash for MetadataBasedUdf { + fn hash(&self, state: &mut H) { + let Self { + name, + signature, + metadata: _, // unhashable + } = self; + name.hash(state); + signature.hash(state); + } +} + +impl MetadataBasedUdf { + fn new(metadata: HashMap) -> Self { + // The name we return must be unique. Otherwise we will not call distinct + // instances of this UDF. This is a small hack for the unit tests to get unique + // names, but you could do something more elegant with the metadata. + let name = format!("metadata_based_udf_{}", metadata.len()); + Self { + name, + signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + metadata, + } + } +} + +impl ScalarUDFImpl for MetadataBasedUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _args: &[DataType]) -> Result { + unimplemented!( + "this should never be called since return_field_from_args is implemented" + ); + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new(self.name(), DataType::UInt64, true) + .with_metadata(self.metadata.clone()) + .into()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert_eq!(args.arg_fields.len(), 1); + let should_double = args.arg_fields[0] + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false); + let multiplier = if should_double { 2 } else { 1 }; + + match &args.args[0] { + ColumnarValue::Array(array) => { + let array_values: Vec<_> = array + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.map(|x| x * multiplier)) + .collect(); + let array_ref = Arc::new(UInt64Array::from(array_values)) as ArrayRef; + Ok(ColumnarValue::Array(array_ref)) + } + ColumnarValue::Scalar(value) => { + let ScalarValue::UInt64(value) = value else { + return exec_err!("incorrect data type"); + }; + + Ok(ColumnarValue::Scalar(ScalarValue::UInt64( + value.map(|v| v * multiplier), + ))) + } + } + } +} + +#[tokio::test] +async fn test_metadata_based_udf() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let no_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new(HashMap::new())); + let with_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + )); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .project(vec![ + no_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_no_out"), + no_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_no_out"), + with_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_with_out"), + with_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_with_out"), + ])? + .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [0, 5, 10, 15, 20]), + ("meta_with_in_no_out", UInt64, [0, 10, 20, 30, 40]), + ("meta_no_in_with_out", UInt64, [0, 5, 10, 15, 20]), + ("meta_with_in_with_out", UInt64, [0, 10, 20, 30, 40]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + ctx.deregister_table("t")?; + Ok(()) +} + +#[tokio::test] +async fn test_metadata_based_udf_with_literal() -> Result<()> { + let ctx = SessionContext::new(); + let input_metadata: HashMap = + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(); + let input_metadata = FieldMetadata::from(input_metadata); + let df = ctx.sql("select 0;").await?.select(vec![ + lit(5u64).alias_with_metadata("lit_with_doubling", Some(input_metadata.clone())), + lit(5u64).alias("lit_no_doubling"), + lit_with_metadata(5u64, Some(input_metadata)) + .alias("lit_with_double_no_alias_metadata"), + ])?; + + let output_metadata: HashMap = + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(); + let custom_udf = ScalarUDF::from(MetadataBasedUdf::new(output_metadata.clone())); + + let plan = LogicalPlanBuilder::from(df.into_optimized_plan()?) + .project(vec![ + custom_udf + .call(vec![col("lit_with_doubling")]) + .alias("doubled_output"), + custom_udf + .call(vec![col("lit_no_doubling")]) + .alias("not_doubled_output"), + custom_udf + .call(vec![col("lit_with_double_no_alias_metadata")]) + .alias("double_without_alias_metadata"), + ])? + .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + let schema = Arc::new(Schema::new(vec![ + Field::new("doubled_output", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + Field::new("not_doubled_output", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + Field::new("double_without_alias_metadata", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + ])); + + let expected = RecordBatch::try_new( + schema, + vec![ + create_array!(UInt64, [10]), + create_array!(UInt64, [5]), + create_array!(UInt64, [10]), + ], + )?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} + +/// This UDF is to test extension handling, both on the input and output +/// sides. For the input, we will handle the data differently if there is +/// the canonical extension type Bool8. For the output we will add a +/// user defined extension type. +#[derive(Debug, PartialEq, Eq, Hash)] +struct ExtensionBasedUdf { + name: String, + signature: Signature, +} + +impl Default for ExtensionBasedUdf { + fn default() -> Self { + Self { + name: "canonical_extension_udf".to_string(), + signature: Signature::exact(vec![DataType::Int8], Volatility::Immutable), + } + } +} +impl ScalarUDFImpl for ExtensionBasedUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new("canonical_extension_udf", DataType::Utf8, true) + .with_extension_type(MyUserExtensionType {}) + .into()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert_eq!(args.arg_fields.len(), 1); + let input_field = args.arg_fields[0].as_ref(); + + let output_as_bool = matches!( + CanonicalExtensionType::try_from(input_field), + Ok(CanonicalExtensionType::Bool8(_)) + ); + + // If we have the extension type set, we are outputting a boolean value. + // Otherwise we output a string representation of the numeric value. + fn print_value(v: Option, as_bool: bool) -> Option { + v.map(|x| match as_bool { + true => format!("{}", x != 0), + false => format!("{x}"), + }) + } + + match &args.args[0] { + ColumnarValue::Array(array) => { + let array_values: Vec<_> = array + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| print_value(v, output_as_bool)) + .collect(); + let array_ref = Arc::new(StringArray::from(array_values)) as ArrayRef; + Ok(ColumnarValue::Array(array_ref)) + } + ColumnarValue::Scalar(value) => { + let ScalarValue::Int8(value) = value else { + return exec_err!("incorrect data type"); + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(print_value( + *value, + output_as_bool, + )))) + } + } + } +} + +struct MyUserExtensionType {} + +impl ExtensionType for MyUserExtensionType { + const NAME: &'static str = "my_user_Extension_type"; + type Metadata = (); + + fn metadata(&self) -> &Self::Metadata { + &() + } + + fn serialize_metadata(&self) -> Option { + None + } + + fn deserialize_metadata( + _metadata: Option<&str>, + ) -> std::result::Result { + Ok(()) + } + + fn supports_data_type( + &self, + data_type: &DataType, + ) -> std::result::Result<(), ArrowError> { + if let DataType::Utf8 = data_type { + Ok(()) + } else { + Err(ArrowError::InvalidArgumentError( + "only utf8 supported".to_string(), + )) + } + } + + fn try_new( + _data_type: &DataType, + _metadata: Self::Metadata, + ) -> std::result::Result { + Ok(Self {}) + } +} + +#[tokio::test] +async fn test_extension_based_udf() -> Result<()> { + let data_array = Arc::new(Int8Array::from(vec![0, 0, 10, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_extension", DataType::Int8, true), + Field::new("with_extension", DataType::Int8, true).with_extension_type(Bool8), + ])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let extension_based_udf = ScalarUDF::from(ExtensionBasedUdf::default()); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .project(vec![ + extension_based_udf + .call(vec![col("no_extension")]) + .alias("without_bool8_extension"), + extension_based_udf + .call(vec![col("with_extension")]) + .alias("with_bool8_extension"), + ])? + .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + // To test for output extension handling, we set the expected values on the result + // To test for input extensions handling, we check the strings returned + let expected_schema = Schema::new(vec![ + Field::new("without_bool8_extension", DataType::Utf8, true) + .with_extension_type(MyUserExtensionType {}), + Field::new("with_bool8_extension", DataType::Utf8, true) + .with_extension_type(MyUserExtensionType {}), + ]); + + let expected = record_batch!( + ("without_bool8_extension", Utf8, ["0", "0", "10", "20"]), + ( + "with_bool8_extension", + Utf8, + ["false", "false", "true", "true"] + ) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + ctx.deregister_table("t")?; + Ok(()) +} + +#[tokio::test] +async fn test_config_options_work_for_scalar_func() -> Result<()> { + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestScalarUDF { + signature: Signature, + } + + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "TestScalarUDF" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let tz = args.config_options.execution.time_zone.clone(); + Ok(ColumnarValue::Scalar(ScalarValue::from(tz))) + } + } + + let udf = ScalarUDF::from(TestScalarUDF { + signature: Signature::uniform(1, vec![DataType::Utf8], Volatility::Stable), + }); + + let mut config = SessionConfig::new(); + config.options_mut().execution.time_zone = "AEST".into(); + + let ctx = SessionContext::new_with_config(config); + + ctx.register_udf(udf.clone()); + + let df = ctx.read_empty()?; + let df = df.select(vec![udf.call(vec![lit("a")]).alias("a")])?; + let actual = df.collect().await?; + + let expected_schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + let expected = RecordBatch::try_new( + SchemaRef::from(expected_schema), + vec![create_array!(Utf8, ["AEST"])], + )?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} + +/// https://github.com/apache/datafusion/issues/17425 +#[tokio::test] +async fn test_extension_metadata_preserve_in_sql_values() -> Result<()> { + #[derive(Debug, Hash, PartialEq, Eq)] + struct MakeExtension { + signature: Signature, + } + + impl Default for MakeExtension { + fn default() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for MakeExtension { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "make_extension" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + Ok(arg_types.to_vec()) + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unreachable!("This shouldn't have been called") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + Ok(args.arg_fields[0] + .as_ref() + .clone() + .with_metadata(HashMap::from([( + "ARROW:extension:metadata".to_string(), + "foofy.foofy".to_string(), + )])) + .into()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + Ok(args.args[0].clone()) + } + } + + let ctx = SessionContext::new(); + ctx.register_udf(MakeExtension::default().into()); + + let batches = ctx + .sql( + " +SELECT extension FROM (VALUES + ('one', make_extension('foofy one')), + ('two', make_extension('foofy two')), + ('three', make_extension('foofy three'))) +AS t(string, extension) + ", + ) + .await? + .collect() + .await?; + + assert_eq!( + batches[0] + .schema() + .field(0) + .metadata() + .get("ARROW:extension:metadata"), + Some(&"foofy.foofy".into()) + ); + Ok(()) +} + +/// https://github.com/apache/datafusion/issues/17422 +#[tokio::test] +async fn test_extension_metadata_preserve_in_subquery() -> Result<()> { + #[derive(Debug, PartialEq, Eq, Hash)] + struct ExtensionScalarPredicate { + signature: Signature, + } + + impl Default for ExtensionScalarPredicate { + fn default() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for ExtensionScalarPredicate { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "extension_predicate" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + Ok(arg_types.to_vec()) + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unreachable!("This shouldn't have been called") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + for arg in args.arg_fields { + assert!(arg.metadata().contains_key("ARROW:extension:name")); + } + + Ok(Field::new("", DataType::Boolean, true).into()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + for arg in args.arg_fields { + assert!(arg.metadata().contains_key("ARROW:extension:name")); + } + + let array = + ScalarValue::Boolean(Some(true)).to_array_of_size(args.number_rows)?; + Ok(ColumnarValue::Array(array)) + } + } + + let schema = Schema::new(vec![ + Field::new("id", DataType::Int64, true), + Field::new("geometry", DataType::Utf8, true).with_metadata(HashMap::from([( + "ARROW:extension:name".to_string(), + "foofy.foofy".to_string(), + )])), + ]); + + let batch_lhs = RecordBatch::try_new( + schema.clone().into(), + vec![ + create_array!(Int64, [1, 2]), + create_array!(Utf8, [Some("item1"), Some("item2")]), + ], + )?; + + let batch_rhs = RecordBatch::try_new( + schema.clone().into(), + vec![ + create_array!(Int64, [2, 3]), + create_array!(Utf8, [Some("item2"), Some("item3")]), + ], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("l", batch_lhs)?; + ctx.register_batch("r", batch_rhs)?; + ctx.register_udf(ExtensionScalarPredicate::default().into()); + + let df = ctx + .sql( + " + SELECT L.id l_id FROM L + WHERE EXISTS (SELECT 1 FROM R WHERE extension_predicate(L.geometry, R.geometry)) + ORDER BY l_id + ", + ) + .await?; + assert!(!df.collect().await?.is_empty()); + Ok(()) +} diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index e4aff0b00705d..2c6611f382cea 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -205,7 +205,7 @@ impl TableFunctionImpl for SimpleCsvTableFunc { let mut filepath = String::new(); for expr in exprs { match expr { - Expr::Literal(ScalarValue::Utf8(Some(ref path))) => { + Expr::Literal(ScalarValue::Utf8(Some(ref path)), _) => { filepath.clone_from(path); } expr => new_exprs.push(expr.clone()), diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 28394f0b9dfaf..33607ebc0d2cc 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -18,13 +18,20 @@ //! This module contains end to end tests of creating //! user defined window functions -use arrow::array::{ArrayRef, AsArray, Int64Array, RecordBatch, StringArray}; +use arrow::array::{ + record_batch, Array, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray, + UInt64Array, +}; use arrow::datatypes::{DataType, Field, Schema}; +use arrow_schema::FieldRef; use datafusion::common::test_util::batches_to_string; use datafusion::common::{Result, ScalarValue}; use datafusion::prelude::SessionContext; +use datafusion_common::exec_datafusion_err; +use datafusion_expr::ptr_eq::PtrEq; use datafusion_expr::{ - PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl, + LimitEffect, PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, + WindowUDFImpl, }; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_functions_window_common::{ @@ -34,6 +41,8 @@ use datafusion_physical_expr::{ expressions::{col, lit}, PhysicalExpr, }; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; use std::{ any::Any, ops::Range, @@ -120,10 +129,12 @@ async fn test_deregister_udwf() -> Result<()> { OddCounter::register(&mut ctx, Arc::clone(&test_state)); assert!(ctx.state().window_functions().contains_key("odd_counter")); + assert!(datafusion_execution::FunctionRegistry::udwfs(&ctx).contains("odd_counter")); ctx.deregister_udwf("odd_counter"); assert!(!ctx.state().window_functions().contains_key("odd_counter")); + assert!(!datafusion_execution::FunctionRegistry::udwfs(&ctx).contains("odd_counter")); Ok(()) } @@ -137,22 +148,22 @@ async fn test_udwf_with_alias() { .await .unwrap(); - insta::assert_snapshot!(batches_to_string(&actual), @r###" - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | x | y | val | odd_counter(t.val) PARTITION BY [t.x] ORDER BY [t.y ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - | 1 | a | 0 | 1 | - | 1 | b | 1 | 1 | - | 1 | c | 2 | 1 | - | 2 | d | 3 | 2 | - | 2 | e | 4 | 2 | - | 2 | f | 5 | 2 | - | 2 | g | 6 | 2 | - | 2 | h | 6 | 2 | - | 2 | i | 6 | 2 | - | 2 | j | 6 | 2 | - +---+---+-----+-----------------------------------------------------------------------------------------------------------------------+ - "###); + insta::assert_snapshot!(batches_to_string(&actual), @r" + +---+---+-----+--------------------------+ + | x | y | val | odd_counter_alias(t.val) | + +---+---+-----+--------------------------+ + | 1 | a | 0 | 1 | + | 1 | b | 1 | 1 | + | 1 | c | 2 | 1 | + | 2 | d | 3 | 2 | + | 2 | e | 4 | 2 | + | 2 | f | 5 | 2 | + | 2 | g | 6 | 2 | + | 2 | h | 6 | 2 | + | 2 | i | 6 | 2 | + | 2 | j | 6 | 2 | + +---+---+-----+--------------------------+ + "); } /// Basic user defined window function with bounded window @@ -516,10 +527,10 @@ impl OddCounter { } fn register(ctx: &mut SessionContext, test_state: Arc) { - #[derive(Debug, Clone)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimpleWindowUDF { signature: Signature, - test_state: Arc, + test_state: PtrEq>, aliases: Vec, } @@ -529,7 +540,7 @@ impl OddCounter { Signature::exact(vec![DataType::Float64], Volatility::Immutable); Self { signature, - test_state, + test_state: test_state.into(), aliases: vec!["odd_counter_alias".to_string()], } } @@ -559,8 +570,12 @@ impl OddCounter { &self.aliases } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Int64, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Int64, true).into()) + } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown } } @@ -633,11 +648,11 @@ fn odd_count(arr: &Int64Array) -> i64 { /// returns an array of num_rows that has the number of odd values in `arr` fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> ArrayRef { - let array: Int64Array = std::iter::repeat(odd_count(arr)).take(num_rows).collect(); + let array: Int64Array = std::iter::repeat_n(odd_count(arr), num_rows).collect(); Arc::new(array) } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct VariadicWindowUDF { signature: Signature, } @@ -678,9 +693,13 @@ impl WindowUDFImpl for VariadicWindowUDF { unimplemented!("unnecessary for testing"); } - fn field(&self, _: WindowUDFFieldArgs) -> Result { + fn field(&self, _: WindowUDFFieldArgs) -> Result { unimplemented!("unnecessary for testing"); } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } } #[test] @@ -723,11 +742,11 @@ fn test_default_expressions() -> Result<()> { ]; for input_exprs in &test_cases { - let input_types = input_exprs + let input_fields = input_exprs .iter() - .map(|expr: &Arc| expr.data_type(&schema).unwrap()) + .map(|expr: &Arc| expr.return_field(&schema).unwrap()) .collect::>(); - let expr_args = ExpressionArgs::new(input_exprs, &input_types); + let expr_args = ExpressionArgs::new(input_exprs, &input_fields); let ret_exprs = udwf.expressions(expr_args); @@ -735,9 +754,7 @@ fn test_default_expressions() -> Result<()> { assert_eq!( input_exprs.len(), ret_exprs.len(), - "\nInput expressions: {:?}\nReturned expressions: {:?}", - input_exprs, - ret_exprs + "\nInput expressions: {input_exprs:?}\nReturned expressions: {ret_exprs:?}" ); // Compares each returned expression with original input expressions @@ -753,3 +770,178 @@ fn test_default_expressions() -> Result<()> { } Ok(()) } + +#[derive(Debug)] +struct MetadataBasedWindowUdf { + name: String, + signature: Signature, + metadata: HashMap, +} + +impl PartialEq for MetadataBasedWindowUdf { + fn eq(&self, other: &Self) -> bool { + let Self { + name, + signature, + metadata, + } = self; + name == &other.name + && signature == &other.signature + && metadata == &other.metadata + } +} +impl Eq for MetadataBasedWindowUdf {} +impl Hash for MetadataBasedWindowUdf { + fn hash(&self, state: &mut H) { + let Self { + name, + signature, + metadata: _, // unhashable + } = self; + name.hash(state); + signature.hash(state); + } +} + +impl MetadataBasedWindowUdf { + fn new(metadata: HashMap) -> Self { + // The name we return must be unique. Otherwise we will not call distinct + // instances of this UDF. This is a small hack for the unit tests to get unique + // names, but you could do something more elegant with the metadata. + let name = format!("metadata_based_udf_{}", metadata.len()); + Self { + name, + signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + metadata, + } + } +} + +impl WindowUDFImpl for MetadataBasedWindowUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + let input_field = partition_evaluator_args + .input_fields() + .first() + .ok_or(exec_datafusion_err!("Expected one argument"))?; + + let double_output = input_field + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false); + + Ok(Box::new(MetadataBasedPartitionEvaluator { double_output })) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::UInt64, true) + .with_metadata(self.metadata.clone()) + .into()) + } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } +} + +#[derive(Debug)] +struct MetadataBasedPartitionEvaluator { + double_output: bool, +} + +impl PartitionEvaluator for MetadataBasedPartitionEvaluator { + fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result { + let values = values[0].as_any().downcast_ref::().unwrap(); + let sum = values.iter().fold(0_u64, |acc, v| acc + v.unwrap_or(0)); + + let result = if self.double_output { sum * 2 } else { sum }; + + Ok(Arc::new(UInt64Array::from_value(result, num_rows))) + } +} + +#[tokio::test] +async fn test_metadata_based_window_fn() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + let no_output_meta_udf = WindowUDF::from(MetadataBasedWindowUdf::new(HashMap::new())); + let with_output_meta_udf = WindowUDF::from(MetadataBasedWindowUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + )); + + let df = df.select(vec![ + no_output_meta_udf + .call(vec![datafusion_expr::col("no_metadata")]) + .alias("meta_no_in_no_out"), + no_output_meta_udf + .call(vec![datafusion_expr::col("with_metadata")]) + .alias("meta_with_in_no_out"), + with_output_meta_udf + .call(vec![datafusion_expr::col("no_metadata")]) + .alias("meta_no_in_with_out"), + with_output_meta_udf + .call(vec![datafusion_expr::col("with_metadata")]) + .alias("meta_with_in_with_out"), + ])?; + + let actual = df.collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_no_out", UInt64, [100, 100, 100, 100, 100]), + ("meta_no_in_with_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_with_out", UInt64, [100, 100, 100, 100, 100]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} diff --git a/datafusion/datasource-avro/Cargo.toml b/datafusion/datasource-avro/Cargo.toml index 064f9f87ee9fe..e013e8a3d0934 100644 --- a/datafusion/datasource-avro/Cargo.toml +++ b/datafusion/datasource-avro/Cargo.toml @@ -18,11 +18,11 @@ [package] name = "datafusion-datasource-avro" description = "datafusion-datasource-avro" +readme = "README.md" authors.workspace = true edition.workspace = true homepage.workspace = true license.workspace = true -readme.workspace = true repository.workspace = true rust-version.workspace = true version.workspace = true @@ -35,22 +35,16 @@ apache-avro = { workspace = true } arrow = { workspace = true } async-trait = { workspace = true } bytes = { workspace = true } -chrono = { workspace = true } -datafusion-catalog = { workspace = true } datafusion-common = { workspace = true, features = ["object_store", "avro"] } datafusion-datasource = { workspace = true } -datafusion-execution = { workspace = true } -datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } futures = { workspace = true } num-traits = { version = "0.2" } object_store = { workspace = true } -tokio = { workspace = true } [dev-dependencies] -rstest = { workspace = true } serde_json = { workspace = true } [lints] diff --git a/datafusion/datasource-avro/README.md b/datafusion/datasource-avro/README.md index f8d7aebdcad18..e9b8affe60e36 100644 --- a/datafusion/datasource-avro/README.md +++ b/datafusion/datasource-avro/README.md @@ -17,10 +17,17 @@ under the License. --> -# DataFusion datasource +# Apache DataFusion Avro DataSource -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. -This crate is a submodule of DataFusion that defines a Avro based file source. +This crate is a submodule of DataFusion that defines an [Apache Avro] based file source. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[apache avro]: https://avro.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs index 9a1b54b872ad7..a80f18cf818fe 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs @@ -19,9 +19,10 @@ use apache_avro::schema::RecordSchema; use apache_avro::{ + error::Details as AvroErrorDetails, schema::{Schema as AvroSchema, SchemaKind}, types::Value, - AvroResult, Error as AvroError, Reader as AvroReader, + Error as AvroError, Reader as AvroReader, }; use arrow::array::{ make_array, Array, ArrayBuilder, ArrayData, ArrayDataBuilder, ArrayRef, @@ -33,7 +34,7 @@ use arrow::buffer::{Buffer, MutableBuffer}; use arrow::datatypes::{ ArrowDictionaryKeyType, ArrowNumericType, ArrowPrimitiveType, DataType, Date32Type, Date64Type, Field, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, - Int8Type, Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, @@ -56,23 +57,17 @@ type RecordSlice<'a> = &'a [&'a Vec<(String, Value)>]; pub struct AvroArrowArrayReader<'a, R: Read> { reader: AvroReader<'a, R>, schema: SchemaRef, - projection: Option>, schema_lookup: BTreeMap, } impl AvroArrowArrayReader<'_, R> { - pub fn try_new( - reader: R, - schema: SchemaRef, - projection: Option>, - ) -> Result { + pub fn try_new(reader: R, schema: SchemaRef) -> Result { let reader = AvroReader::new(reader)?; let writer_schema = reader.writer_schema().clone(); let schema_lookup = Self::schema_lookup(writer_schema)?; Ok(Self { reader, schema, - projection, schema_lookup, }) } @@ -123,7 +118,7 @@ impl AvroArrowArrayReader<'_, R> { AvroSchema::Record(RecordSchema { fields, lookup, .. }) => { lookup.iter().for_each(|(field_name, pos)| { schema_lookup - .insert(format!("{}.{}", parent_field_name, field_name), *pos); + .insert(format!("{parent_field_name}.{field_name}"), *pos); }); for field in fields { @@ -137,7 +132,7 @@ impl AvroArrowArrayReader<'_, R> { } } AvroSchema::Array(schema) => { - let sub_parent_field_name = format!("{}.element", parent_field_name); + let sub_parent_field_name = format!("{parent_field_name}.element"); Self::child_schema_lookup( &sub_parent_field_name, &schema.items, @@ -158,7 +153,7 @@ impl AvroArrowArrayReader<'_, R> { .map(|value| match value { Ok(Value::Record(v)) => Ok(v), Err(e) => Err(ArrowError::ParseError(format!( - "Failed to parse avro value: {e:?}" + "Failed to parse avro value: {e}" ))), other => Err(ArrowError::ParseError(format!( "Row needs to be of type object, got: {other:?}" @@ -175,20 +170,9 @@ impl AvroArrowArrayReader<'_, R> { }; let rows = rows.iter().collect::>>(); - let projection = self.projection.clone().unwrap_or_default(); - let arrays = - self.build_struct_array(&rows, "", self.schema.fields(), &projection); - let projected_fields = if projection.is_empty() { - self.schema.fields().clone() - } else { - projection - .iter() - .filter_map(|name| self.schema.column_with_name(name)) - .map(|(_, field)| field.clone()) - .collect() - }; - let projected_schema = Arc::new(Schema::new(projected_fields)); - Some(arrays.and_then(|arr| RecordBatch::try_new(projected_schema, arr))) + let arrays = self.build_struct_array(&rows, "", self.schema.fields()); + + Some(arrays.and_then(|arr| RecordBatch::try_new(Arc::clone(&self.schema), arr))) } fn build_boolean_array(&self, rows: RecordSlice, col_name: &str) -> ArrayRef { @@ -297,7 +281,7 @@ impl AvroArrowArrayReader<'_, R> { self.list_array_string_array_builder::(&dtype, col_name, rows) } ref e => Err(SchemaError(format!( - "Data type is currently not supported for dictionaries in list : {e:?}" + "Data type is currently not supported for dictionaries in list : {e}" ))), } } @@ -324,7 +308,7 @@ impl AvroArrowArrayReader<'_, R> { } e => { return Err(SchemaError(format!( - "Nested list data builder type is not supported: {e:?}" + "Nested list data builder type is not supported: {e}" ))) } }; @@ -389,7 +373,7 @@ impl AvroArrowArrayReader<'_, R> { } e => { return Err(SchemaError(format!( - "Nested list data builder type is not supported: {e:?}" + "Nested list data builder type is not supported: {e}" ))) } } @@ -615,7 +599,7 @@ impl AvroArrowArrayReader<'_, R> { let sub_parent_field_name = format!("{}.{}", parent_field_name, list_field.name()); let arrays = - self.build_struct_array(&rows, &sub_parent_field_name, fields, &[])?; + self.build_struct_array(&rows, &sub_parent_field_name, fields)?; let data_type = DataType::Struct(fields.clone()); ArrayDataBuilder::new(data_type) .len(rows.len()) @@ -626,7 +610,7 @@ impl AvroArrowArrayReader<'_, R> { } datatype => { return Err(SchemaError(format!( - "Nested list of {datatype:?} not supported" + "Nested list of {datatype} not supported" ))); } }; @@ -645,20 +629,14 @@ impl AvroArrowArrayReader<'_, R> { /// The function does not construct the StructArray as some callers would want the child arrays. /// /// *Note*: The function is recursive, and will read nested structs. - /// - /// If `projection` is not empty, then all values are returned. The first level of projection - /// occurs at the `RecordBatch` level. No further projection currently occurs, but would be - /// useful if plucking values from a struct, e.g. getting `a.b.c.e` from `a.b.c.{d, e}`. fn build_struct_array( &self, rows: RecordSlice, parent_field_name: &str, struct_fields: &Fields, - projection: &[String], ) -> ArrowResult> { let arrays: ArrowResult> = struct_fields .iter() - .filter(|field| projection.is_empty() || projection.contains(field.name())) .map(|field| { let field_path = if parent_field_name.is_empty() { field.name().to_string() @@ -840,12 +818,8 @@ impl AvroArrowArrayReader<'_, R> { } }) .collect::>>(); - let arrays = self.build_struct_array( - &struct_rows, - &field_path, - fields, - &[], - )?; + let arrays = + self.build_struct_array(&struct_rows, &field_path, fields)?; // construct a struct array's data in order to set null buffer let data_type = DataType::Struct(fields.clone()); let data = ArrayDataBuilder::new(data_type) @@ -857,7 +831,7 @@ impl AvroArrowArrayReader<'_, R> { } _ => { return Err(SchemaError(format!( - "type {:?} not supported", + "type {} not supported", field.data_type() ))) } @@ -956,49 +930,40 @@ fn resolve_string(v: &Value) -> ArrowResult> { match v { Value::String(s) => Ok(Some(s.clone())), Value::Bytes(bytes) => String::from_utf8(bytes.to_vec()) - .map_err(AvroError::ConvertToUtf8) + .map_err(|e| AvroError::new(AvroErrorDetails::ConvertToUtf8(e))) .map(Some), Value::Enum(_, s) => Ok(Some(s.clone())), Value::Null => Ok(None), - other => Err(AvroError::GetString(other.into())), + other => Err(AvroError::new(AvroErrorDetails::GetString(other.clone()))), } - .map_err(|e| SchemaError(format!("expected resolvable string : {e:?}"))) + .map_err(|e| SchemaError(format!("expected resolvable string : {e}"))) } -fn resolve_u8(v: &Value) -> AvroResult { - let int = match v { - Value::Int(n) => Ok(Value::Int(*n)), - Value::Long(n) => Ok(Value::Int(*n as i32)), - other => Err(AvroError::GetU8(other.into())), - }?; - if let Value::Int(n) = int { - if n >= 0 && n <= From::from(u8::MAX) { - return Ok(n as u8); - } - } +fn resolve_u8(v: &Value) -> Option { + let v = match v { + Value::Union(_, inner) => inner.as_ref(), + _ => v, + }; - Err(AvroError::GetU8(int.into())) + match v { + Value::Int(n) => u8::try_from(*n).ok(), + Value::Long(n) => u8::try_from(*n).ok(), + _ => None, + } } fn resolve_bytes(v: &Value) -> Option> { - let v = if let Value::Union(_, b) = v { b } else { v }; + let v = match v { + Value::Union(_, inner) => inner.as_ref(), + _ => v, + }; + match v { - Value::Bytes(_) => Ok(v.clone()), - Value::String(s) => Ok(Value::Bytes(s.clone().into_bytes())), - Value::Array(items) => Ok(Value::Bytes( - items - .iter() - .map(resolve_u8) - .collect::, _>>() - .ok()?, - )), - other => Err(AvroError::GetBytes(other.into())), - } - .ok() - .and_then(|v| match v { - Value::Bytes(s) => Some(s), + Value::Bytes(bytes) => Some(bytes.clone()), + Value::String(s) => Some(s.as_bytes().to_vec()), + Value::Array(items) => items.iter().map(resolve_u8).collect::>>(), _ => None, - }) + } } fn resolve_fixed(v: &Value, size: usize) -> Option> { @@ -1082,7 +1047,7 @@ mod test { use std::fs::File; use std::sync::Arc; - fn build_reader(name: &str, batch_size: usize) -> Reader { + fn build_reader(name: &'_ str, batch_size: usize) -> Reader<'_, File> { let testdata = datafusion_common::test_util::arrow_test_data(); let filename = format!("{testdata}/avro/{name}"); let builder = ReaderBuilder::new() diff --git a/datafusion/datasource-avro/src/avro_to_arrow/reader.rs b/datafusion/datasource-avro/src/avro_to_arrow/reader.rs index bc7b50a9cdc31..9a4d13fc191da 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/reader.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/reader.rs @@ -16,7 +16,7 @@ // under the License. use super::arrow_array_reader::AvroArrowArrayReader; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Fields, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use datafusion_common::Result; @@ -133,19 +133,35 @@ impl Reader<'_, R> { /// /// If reading a `File`, you can customise the Reader, such as to enable schema /// inference, use `ReaderBuilder`. + /// + /// If projection is provided, it uses a schema with only the fields in the projection, respecting their order. + /// Only the first level of projection is handled. No further projection currently occurs, but would be + /// useful if plucking values from a struct, e.g. getting `a.b.c.e` from `a.b.c.{d, e}`. pub fn try_new( reader: R, schema: SchemaRef, batch_size: usize, projection: Option>, ) -> Result { + let projected_schema = projection.as_ref().filter(|p| !p.is_empty()).map_or_else( + || Arc::clone(&schema), + |proj| { + Arc::new(arrow::datatypes::Schema::new( + proj.iter() + .filter_map(|name| { + schema.column_with_name(name).map(|(_, f)| f.clone()) + }) + .collect::(), + )) + }, + ); + Ok(Self { array_reader: AvroArrowArrayReader::try_new( reader, - Arc::clone(&schema), - projection, + Arc::clone(&projected_schema), )?, - schema, + schema: projected_schema, batch_size, }) } @@ -179,10 +195,13 @@ mod tests { use arrow::datatypes::{DataType, Field}; use std::fs::File; - fn build_reader(name: &str) -> Reader { + fn build_reader(name: &'_ str, projection: Option>) -> Reader<'_, File> { let testdata = datafusion_common::test_util::arrow_test_data(); let filename = format!("{testdata}/avro/{name}"); - let builder = ReaderBuilder::new().read_schema().with_batch_size(64); + let mut builder = ReaderBuilder::new().read_schema().with_batch_size(64); + if let Some(projection) = projection { + builder = builder.with_projection(projection); + } builder.build(File::open(filename).unwrap()).unwrap() } @@ -195,7 +214,7 @@ mod tests { #[test] fn test_avro_basic() { - let mut reader = build_reader("alltypes_dictionary.avro"); + let mut reader = build_reader("alltypes_dictionary.avro", None); let batch = reader.next().unwrap().unwrap(); assert_eq!(11, batch.num_columns()); @@ -281,4 +300,58 @@ mod tests { assert_eq!(1230768000000000, col.value(0)); assert_eq!(1230768060000000, col.value(1)); } + + #[test] + fn test_avro_with_projection() { + // Test projection to filter and reorder columns + let projection = Some(vec![ + "string_col".to_string(), + "double_col".to_string(), + "bool_col".to_string(), + ]); + let mut reader = build_reader("alltypes_dictionary.avro", projection); + let batch = reader.next().unwrap().unwrap(); + + // Only 3 columns should be present (not all 11) + assert_eq!(3, batch.num_columns()); + assert_eq!(2, batch.num_rows()); + + let schema = reader.schema(); + let batch_schema = batch.schema(); + assert_eq!(schema, batch_schema); + + // Verify columns are in the order specified in projection + // First column should be string_col (was at index 9 in original) + assert_eq!("string_col", schema.field(0).name()); + assert_eq!(&DataType::Binary, schema.field(0).data_type()); + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!("0".as_bytes(), col.value(0)); + assert_eq!("1".as_bytes(), col.value(1)); + + // Second column should be double_col (was at index 7 in original) + assert_eq!("double_col", schema.field(1).name()); + assert_eq!(&DataType::Float64, schema.field(1).data_type()); + let col = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(0.0, col.value(0)); + assert_eq!(10.1, col.value(1)); + + // Third column should be bool_col (was at index 1 in original) + assert_eq!("bool_col", schema.field(2).name()); + assert_eq!(&DataType::Boolean, schema.field(2).data_type()); + let col = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(col.value(0)); + assert!(!col.value(1)); + } } diff --git a/datafusion/datasource-avro/src/avro_to_arrow/schema.rs b/datafusion/datasource-avro/src/avro_to_arrow/schema.rs index 276056c24c01c..3fce0d4826a22 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/schema.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/schema.rs @@ -22,7 +22,7 @@ use apache_avro::types::Value; use apache_avro::Schema as AvroSchema; use arrow::datatypes::{DataType, IntervalUnit, Schema, TimeUnit, UnionMode}; use arrow::datatypes::{Field, UnionFields}; -use datafusion_common::error::{DataFusionError, Result}; +use datafusion_common::error::Result; use std::collections::HashMap; use std::sync::Arc; @@ -107,9 +107,10 @@ fn schema_to_field_with_props( .data_type() .clone() } else { - return Err(DataFusionError::AvroError( - apache_avro::Error::GetUnionDuplicate, - )); + return Err(apache_avro::Error::new( + apache_avro::error::Details::GetUnionDuplicate, + ) + .into()); } } else { let fields = sub_schemas @@ -237,6 +238,8 @@ fn default_field_name(dt: &DataType) -> &str { | DataType::LargeListView(_) => { unimplemented!("View support not implemented") } + DataType::Decimal32(_, _) => "decimal", + DataType::Decimal64(_, _) => "decimal", DataType::Decimal128(_, _) => "decimal", DataType::Decimal256(_, _) => "decimal", } diff --git a/datafusion/datasource-avro/src/file_format.rs b/datafusion/datasource-avro/src/file_format.rs index 4b50fee1d326b..60c361b42e771 100644 --- a/datafusion/datasource-avro/src/file_format.rs +++ b/datafusion/datasource-avro/src/file_format.rs @@ -37,7 +37,6 @@ use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_format::{FileFormat, FileFormatFactory}; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::source::DataSourceExec; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use datafusion_session::Session; @@ -111,6 +110,10 @@ impl FileFormat for AvroFormat { } } + fn compression_type(&self) -> Option { + None + } + async fn infer_schema( &self, _state: &dyn Session, @@ -150,7 +153,6 @@ impl FileFormat for AvroFormat { &self, _state: &dyn Session, conf: FileScanConfig, - _filters: Option<&Arc>, ) -> Result> { let config = FileScanConfigBuilder::from(conf) .with_source(self.file_source()) diff --git a/datafusion/datasource-avro/src/mod.rs b/datafusion/datasource-avro/src/mod.rs index 71996f3f0eaa2..ad8ebe11446f5 100644 --- a/datafusion/datasource-avro/src/mod.rs +++ b/datafusion/datasource-avro/src/mod.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] @@ -30,4 +30,5 @@ pub mod avro_to_arrow; pub mod file_format; pub mod source; +pub use apache_avro; pub use file_format::*; diff --git a/datafusion/datasource-avro/src/source.rs b/datafusion/datasource-avro/src/source.rs index ce3722e7b11ee..0916222337b80 100644 --- a/datafusion/datasource-avro/src/source.rs +++ b/datafusion/datasource-avro/src/source.rs @@ -18,142 +18,22 @@ //! Execution plan for reading line-delimited Avro files use std::any::Any; -use std::fmt::Formatter; use std::sync::Arc; use crate::avro_to_arrow::Reader as AvroReader; -use datafusion_common::error::Result; - use arrow::datatypes::SchemaRef; -use datafusion_common::{Constraints, Statistics}; +use datafusion_common::error::Result; +use datafusion_common::Statistics; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; use datafusion_datasource::file_stream::FileOpener; -use datafusion_datasource::source::DataSourceExec; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, -}; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use object_store::ObjectStore; -/// Execution plan for scanning Avro data source -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -pub struct AvroExec { - inner: DataSourceExec, - base_config: FileScanConfig, -} - -#[allow(unused, deprecated)] -impl AvroExec { - /// Create a new Avro reader execution plan provided base configurations - pub fn new(base_config: FileScanConfig) -> Self { - let ( - projected_schema, - projected_constraints, - projected_statistics, - projected_output_ordering, - ) = base_config.project(); - let cache = Self::compute_properties( - Arc::clone(&projected_schema), - &projected_output_ordering, - projected_constraints, - &base_config, - ); - let base_config = base_config.with_source(Arc::new(AvroSource::default())); - Self { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - base_config, - } - } - - /// Ref to the base configs - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - file_scan_config: &FileScanConfig, - ) -> PlanProperties { - // Equivalence Properties - let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints); - let n_partitions = file_scan_config.file_groups.len(); - - PlanProperties::new( - eq_properties, - Partitioning::UnknownPartitioning(n_partitions), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for AvroExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for AvroExec { - fn name(&self) -> &'static str { - "AvroExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - fn children(&self) -> Vec<&Arc> { - Vec::new() - } - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, context) - } - - fn statistics(&self) -> Result { - self.inner.statistics() - } - - fn metrics(&self) -> Option { - self.inner.metrics() - } - - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } -} - /// AvroSource holds the extra configuration that is necessary for opening avro files #[derive(Clone, Default)] pub struct AvroSource { @@ -162,6 +42,7 @@ pub struct AvroSource { projection: Option>, metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, } impl AvroSource { @@ -244,13 +125,27 @@ impl FileSource for AvroSource { ) -> Result> { Ok(None) } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } } mod private { use super::*; use bytes::Buf; - use datafusion_datasource::{file_meta::FileMeta, file_stream::FileOpenFuture}; + use datafusion_datasource::{file_stream::FileOpenFuture, PartitionedFile}; use futures::StreamExt; use object_store::{GetResultPayload, ObjectStore}; @@ -260,20 +155,26 @@ mod private { } impl FileOpener for AvroOpener { - fn open(&self, file_meta: FileMeta) -> Result { + fn open(&self, partitioned_file: PartitionedFile) -> Result { let config = Arc::clone(&self.config); let object_store = Arc::clone(&self.object_store); Ok(Box::pin(async move { - let r = object_store.get(file_meta.location()).await?; + let r = object_store + .get(&partitioned_file.object_meta.location) + .await?; match r.payload { GetResultPayload::File(file, _) => { let reader = config.open(file)?; - Ok(futures::stream::iter(reader).boxed()) + Ok(futures::stream::iter(reader) + .map(|r| r.map_err(Into::into)) + .boxed()) } GetResultPayload::Stream(_) => { let bytes = r.bytes().await?; let reader = config.open(bytes.reader())?; - Ok(futures::stream::iter(reader).boxed()) + Ok(futures::stream::iter(reader) + .map(|r| r.map_err(Into::into)) + .boxed()) } } })) diff --git a/datafusion/datasource-csv/Cargo.toml b/datafusion/datasource-csv/Cargo.toml index c9e4649bdc25d..209cea403896b 100644 --- a/datafusion/datasource-csv/Cargo.toml +++ b/datafusion/datasource-csv/Cargo.toml @@ -18,11 +18,11 @@ [package] name = "datafusion-datasource-csv" description = "datafusion-datasource-csv" +readme = "README.md" authors.workspace = true edition.workspace = true homepage.workspace = true license.workspace = true -readme.workspace = true repository.workspace = true rust-version.workspace = true version.workspace = true @@ -34,13 +34,11 @@ all-features = true arrow = { workspace = true } async-trait = { workspace = true } bytes = { workspace = true } -datafusion-catalog = { workspace = true } datafusion-common = { workspace = true, features = ["object_store"] } datafusion-common-runtime = { workspace = true } datafusion-datasource = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } diff --git a/datafusion/datasource-csv/README.md b/datafusion/datasource-csv/README.md index c5944f9e438fa..8bdadd0fe2c13 100644 --- a/datafusion/datasource-csv/README.md +++ b/datafusion/datasource-csv/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion datasource +# Apache DataFusion CSV DataSource -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion that defines a CSV based file source. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/datasource-csv/src/file_format.rs b/datafusion/datasource-csv/src/file_format.rs index 76f3c50a70a7c..1c39893b23c85 100644 --- a/datafusion/datasource-csv/src/file_format.rs +++ b/datafusion/datasource-csv/src/file_format.rs @@ -50,7 +50,6 @@ use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join; use datafusion_datasource::write::BatchSerializer; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; @@ -152,13 +151,13 @@ impl CsvFormat { let stream = store .get(&object.location) .await - .map_err(DataFusionError::ObjectStore); + .map_err(|e| DataFusionError::ObjectStore(Box::new(e))); let stream = match stream { Ok(stream) => self .read_to_delimited_chunks_from_stream( stream .into_stream() - .map_err(DataFusionError::ObjectStore) + .map_err(|e| DataFusionError::ObjectStore(Box::new(e))) .boxed(), ) .await @@ -182,7 +181,7 @@ impl CsvFormat { let stream = match decoder { Ok(decoded_stream) => { newline_delimited_stream(decoded_stream.map_err(|e| match e { - DataFusionError::ObjectStore(e) => e, + DataFusionError::ObjectStore(e) => *e, err => object_store::Error::Generic { store: "read to delimited chunks failed", source: Box::new(err), @@ -223,6 +222,11 @@ impl CsvFormat { self } + pub fn with_truncated_rows(mut self, truncated_rows: bool) -> Self { + self.options.truncated_rows = Some(truncated_rows); + self + } + /// Set the regex to use for null values in the CSV reader. /// - default to treat empty values as null. pub fn with_null_regex(mut self, null_regex: Option) -> Self { @@ -292,6 +296,13 @@ impl CsvFormat { self } + /// Set whether rows should be truncated to the column width + /// - defaults to false + pub fn with_truncate_rows(mut self, truncate_rows: bool) -> Self { + self.options.truncated_rows = Some(truncate_rows); + self + } + /// The delimiter character. pub fn delimiter(&self) -> u8 { self.options.delimiter @@ -359,6 +370,10 @@ impl FileFormat for CsvFormat { Ok(format!("{}{}", ext, file_compression_type.get_ext())) } + fn compression_type(&self) -> Option { + Some(self.options.compression.into()) + } + async fn infer_schema( &self, state: &dyn Session, @@ -408,27 +423,28 @@ impl FileFormat for CsvFormat { &self, state: &dyn Session, conf: FileScanConfig, - _filters: Option<&Arc>, ) -> Result> { // Consult configuration options for default values let has_header = self .options .has_header - .unwrap_or(state.config_options().catalog.has_header); + .unwrap_or_else(|| state.config_options().catalog.has_header); let newlines_in_values = self .options .newlines_in_values - .unwrap_or(state.config_options().catalog.newlines_in_values); + .unwrap_or_else(|| state.config_options().catalog.newlines_in_values); let conf_builder = FileScanConfigBuilder::from(conf) .with_file_compression_type(self.options.compression.into()) .with_newlines_in_values(newlines_in_values); + let truncated_rows = self.options.truncated_rows.unwrap_or(false); let source = Arc::new( CsvSource::new(has_header, self.options.delimiter, self.options.quote) .with_escape(self.options.escape) .with_terminator(self.options.terminator) - .with_comment(self.options.comment), + .with_comment(self.options.comment) + .with_truncate_rows(truncated_rows), ); let config = conf_builder.with_source(source).build(); @@ -454,11 +470,11 @@ impl FileFormat for CsvFormat { let has_header = self .options() .has_header - .unwrap_or(state.config_options().catalog.has_header); + .unwrap_or_else(|| state.config_options().catalog.has_header); let newlines_in_values = self .options() .newlines_in_values - .unwrap_or(state.config_options().catalog.newlines_in_values); + .unwrap_or_else(|| state.config_options().catalog.newlines_in_values); let options = self .options() @@ -481,7 +497,20 @@ impl FileFormat for CsvFormat { impl CsvFormat { /// Return the inferred schema reading up to records_to_read from a /// stream of delimited chunks returning the inferred schema and the - /// number of lines that were read + /// number of lines that were read. + /// + /// This method can handle CSV files with different numbers of columns. + /// The inferred schema will be the union of all columns found across all files. + /// Files with fewer columns will have missing columns filled with null values. + /// + /// # Example + /// + /// If you have two CSV files: + /// - `file1.csv`: `col1,col2,col3` + /// - `file2.csv`: `col1,col2,col3,col4,col5` + /// + /// The inferred schema will contain all 5 columns, with files that don't + /// have columns 4 and 5 having null values for those columns. pub async fn infer_schema_from_stream( &self, state: &dyn Session, @@ -504,10 +533,11 @@ impl CsvFormat { && self .options .has_header - .unwrap_or(state.config_options().catalog.has_header), + .unwrap_or_else(|| state.config_options().catalog.has_header), ) .with_delimiter(self.options.delimiter) - .with_quote(self.options.quote); + .with_quote(self.options.quote) + .with_truncated_rows(self.options.truncated_rows.unwrap_or(false)); if let Some(null_regex) = &self.options.null_regex { let regex = Regex::new(null_regex.as_str()) @@ -543,21 +573,37 @@ impl CsvFormat { }) .unzip(); } else { - if fields.len() != column_type_possibilities.len() { + if fields.len() != column_type_possibilities.len() + && !self.options.truncated_rows.unwrap_or(false) + { return exec_err!( - "Encountered unequal lengths between records on CSV file whilst inferring schema. \ - Expected {} fields, found {} fields at record {}", - column_type_possibilities.len(), - fields.len(), - record_number + 1 - ); + "Encountered unequal lengths between records on CSV file whilst inferring schema. \ + Expected {} fields, found {} fields at record {}", + column_type_possibilities.len(), + fields.len(), + record_number + 1 + ); } + // First update type possibilities for existing columns using zip column_type_possibilities.iter_mut().zip(&fields).for_each( |(possibilities, field)| { possibilities.insert(field.data_type().clone()); }, ); + + // Handle files with different numbers of columns by extending the schema + if fields.len() > column_type_possibilities.len() { + // New columns found - extend our tracking structures + for field in fields.iter().skip(column_type_possibilities.len()) { + column_names.push(field.name().clone()); + let mut possibilities = HashSet::new(); + if records_read > 0 { + possibilities.insert(field.data_type().clone()); + } + column_type_possibilities.push(possibilities); + } + } } if records_to_read == 0 { @@ -565,20 +611,28 @@ impl CsvFormat { } } - let schema = build_schema_helper(column_names, &column_type_possibilities); + let schema = build_schema_helper(column_names, column_type_possibilities); Ok((schema, total_records_read)) } } -fn build_schema_helper(names: Vec, types: &[HashSet]) -> Schema { +fn build_schema_helper(names: Vec, types: Vec>) -> Schema { let fields = names .into_iter() .zip(types) - .map(|(field_name, data_type_possibilities)| { + .map(|(field_name, mut data_type_possibilities)| { // ripped from arrow::csv::reader::infer_reader_schema_with_csv_options // determine data type based on possible types // if there are incompatible types, use DataType::Utf8 + + // ignore nulls, to avoid conflicting datatypes (e.g. [nulls, int]) being inferred as Utf8. + data_type_possibilities.remove(&DataType::Null); + match data_type_possibilities.len() { + // Return Null for columns with only nulls / empty files + // This allows schema merging to work when reading folders + // such files along with normal files. + 0 => Field::new(field_name, DataType::Null, true), 1 => Field::new( field_name, data_type_possibilities.iter().next().unwrap().clone(), @@ -744,3 +798,82 @@ impl DataSink for CsvSink { FileSink::write_all(self, data, context).await } } + +#[cfg(test)] +mod tests { + use super::build_schema_helper; + use arrow::datatypes::DataType; + use std::collections::HashSet; + + #[test] + fn test_build_schema_helper_different_column_counts() { + // Test the core schema building logic with different column counts + let mut column_names = + vec!["col1".to_string(), "col2".to_string(), "col3".to_string()]; + + // Simulate adding two more columns from another file + column_names.push("col4".to_string()); + column_names.push("col5".to_string()); + + let column_type_possibilities = vec![ + HashSet::from([DataType::Int64]), + HashSet::from([DataType::Utf8]), + HashSet::from([DataType::Float64]), + HashSet::from([DataType::Utf8]), // col4 + HashSet::from([DataType::Utf8]), // col5 + ]; + + let schema = build_schema_helper(column_names, column_type_possibilities); + + // Verify schema has 5 columns + assert_eq!(schema.fields().len(), 5); + assert_eq!(schema.field(0).name(), "col1"); + assert_eq!(schema.field(1).name(), "col2"); + assert_eq!(schema.field(2).name(), "col3"); + assert_eq!(schema.field(3).name(), "col4"); + assert_eq!(schema.field(4).name(), "col5"); + + // All fields should be nullable + for field in schema.fields() { + assert!( + field.is_nullable(), + "Field {} should be nullable", + field.name() + ); + } + } + + #[test] + fn test_build_schema_helper_type_merging() { + // Test type merging logic + let column_names = vec!["col1".to_string(), "col2".to_string()]; + + let column_type_possibilities = vec![ + HashSet::from([DataType::Int64, DataType::Float64]), // Should resolve to Float64 + HashSet::from([DataType::Utf8]), // Should remain Utf8 + ]; + + let schema = build_schema_helper(column_names, column_type_possibilities); + + // col1 should be Float64 due to Int64 + Float64 = Float64 + assert_eq!(*schema.field(0).data_type(), DataType::Float64); + + // col2 should remain Utf8 + assert_eq!(*schema.field(1).data_type(), DataType::Utf8); + } + + #[test] + fn test_build_schema_helper_conflicting_types() { + // Test when we have incompatible types - should default to Utf8 + let column_names = vec!["col1".to_string()]; + + let column_type_possibilities = vec![ + HashSet::from([DataType::Boolean, DataType::Int64, DataType::Utf8]), // Should resolve to Utf8 due to conflicts + ]; + + let schema = build_schema_helper(column_names, column_type_possibilities); + + // Should default to Utf8 for conflicting types + assert_eq!(*schema.field(0).data_type(), DataType::Utf8); + } +} diff --git a/datafusion/datasource-csv/src/source.rs b/datafusion/datasource-csv/src/source.rs index 6db4d18703204..0445329d06530 100644 --- a/datafusion/datasource-csv/src/source.rs +++ b/datafusion/datasource-csv/src/source.rs @@ -17,6 +17,7 @@ //! Execution plan for reading CSV files +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use std::any::Any; use std::fmt; use std::io::{Read, Seek, SeekFrom}; @@ -25,382 +26,30 @@ use std::task::Poll; use datafusion_datasource::decoder::{deserialize_stream, DecoderDeserializer}; use datafusion_datasource::file_compression_type::FileCompressionType; -use datafusion_datasource::file_meta::FileMeta; use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener}; use datafusion_datasource::{ - calculate_range, FileRange, ListingTableUrl, RangeCalculation, + as_file_source, calculate_range, FileRange, ListingTableUrl, PartitionedFile, + RangeCalculation, }; use arrow::csv; use arrow::datatypes::SchemaRef; -use datafusion_common::config::ConfigOptions; -use datafusion_common::{Constraints, DataFusionError, Result, Statistics}; +use datafusion_common::{DataFusionError, Result, Statistics}; use datafusion_common_runtime::JoinSet; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::source::DataSourceExec; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, }; use crate::file_format::CsvDecoder; -use datafusion_datasource::file_groups::FileGroup; use futures::{StreamExt, TryStreamExt}; use object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; -/// Old Csv source, deprecated with DataSourceExec implementation and CsvSource -/// -/// See examples on `CsvSource` -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -pub struct CsvExec { - base_config: FileScanConfig, - inner: DataSourceExec, -} - -/// Builder for [`CsvExec`]. -/// -/// See example on [`CsvExec`]. -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use FileScanConfig instead")] -pub struct CsvExecBuilder { - file_scan_config: FileScanConfig, - file_compression_type: FileCompressionType, - // TODO: it seems like these format options could be reused across all the various CSV config - has_header: bool, - delimiter: u8, - quote: u8, - terminator: Option, - escape: Option, - comment: Option, - newlines_in_values: bool, -} - -#[allow(unused, deprecated)] -impl CsvExecBuilder { - /// Create a new builder to read the provided file scan configuration. - pub fn new(file_scan_config: FileScanConfig) -> Self { - Self { - file_scan_config, - // TODO: these defaults are duplicated from `CsvOptions` - should they be computed? - has_header: false, - delimiter: b',', - quote: b'"', - terminator: None, - escape: None, - comment: None, - newlines_in_values: false, - file_compression_type: FileCompressionType::UNCOMPRESSED, - } - } - - /// Set whether the first row defines the column names. - /// - /// The default value is `false`. - pub fn with_has_header(mut self, has_header: bool) -> Self { - self.has_header = has_header; - self - } - - /// Set the column delimeter. - /// - /// The default is `,`. - pub fn with_delimeter(mut self, delimiter: u8) -> Self { - self.delimiter = delimiter; - self - } - - /// Set the quote character. - /// - /// The default is `"`. - pub fn with_quote(mut self, quote: u8) -> Self { - self.quote = quote; - self - } - - /// Set the line terminator. If not set, the default is CRLF. - /// - /// The default is None. - pub fn with_terminator(mut self, terminator: Option) -> Self { - self.terminator = terminator; - self - } - - /// Set the escape character. - /// - /// The default is `None` (i.e. quotes cannot be escaped). - pub fn with_escape(mut self, escape: Option) -> Self { - self.escape = escape; - self - } - - /// Set the comment character. - /// - /// The default is `None` (i.e. comments are not supported). - pub fn with_comment(mut self, comment: Option) -> Self { - self.comment = comment; - self - } - - /// Set whether newlines in (quoted) values are supported. - /// - /// Parsing newlines in quoted values may be affected by execution behaviour such as - /// parallel file scanning. Setting this to `true` ensures that newlines in values are - /// parsed successfully, which may reduce performance. - /// - /// The default value is `false`. - pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self { - self.newlines_in_values = newlines_in_values; - self - } - - /// Set the file compression type. - /// - /// The default is [`FileCompressionType::UNCOMPRESSED`]. - pub fn with_file_compression_type( - mut self, - file_compression_type: FileCompressionType, - ) -> Self { - self.file_compression_type = file_compression_type; - self - } - - /// Build a [`CsvExec`]. - #[must_use] - pub fn build(self) -> CsvExec { - let Self { - file_scan_config: base_config, - file_compression_type, - has_header, - delimiter, - quote, - terminator, - escape, - comment, - newlines_in_values, - } = self; - - let ( - projected_schema, - projected_constraints, - projected_statistics, - projected_output_ordering, - ) = base_config.project(); - let cache = CsvExec::compute_properties( - projected_schema, - &projected_output_ordering, - projected_constraints, - &base_config, - ); - let csv = CsvSource::new(has_header, delimiter, quote) - .with_comment(comment) - .with_escape(escape) - .with_terminator(terminator); - let base_config = base_config - .with_newlines_in_values(newlines_in_values) - .with_file_compression_type(file_compression_type) - .with_source(Arc::new(csv)); - - CsvExec { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - base_config, - } - } -} - -#[allow(unused, deprecated)] -impl CsvExec { - /// Create a new CSV reader execution plan provided base and specific configurations - #[allow(clippy::too_many_arguments)] - pub fn new( - base_config: FileScanConfig, - has_header: bool, - delimiter: u8, - quote: u8, - terminator: Option, - escape: Option, - comment: Option, - newlines_in_values: bool, - file_compression_type: FileCompressionType, - ) -> Self { - CsvExecBuilder::new(base_config) - .with_has_header(has_header) - .with_delimeter(delimiter) - .with_quote(quote) - .with_terminator(terminator) - .with_escape(escape) - .with_comment(comment) - .with_newlines_in_values(newlines_in_values) - .with_file_compression_type(file_compression_type) - .build() - } - - /// Return a [`CsvExecBuilder`]. - /// - /// See example on [`CsvExec`] and [`CsvExecBuilder`] for specifying CSV table options. - pub fn builder(file_scan_config: FileScanConfig) -> CsvExecBuilder { - CsvExecBuilder::new(file_scan_config) - } - - /// Ref to the base configs - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - - fn file_scan_config(&self) -> FileScanConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn csv_source(&self) -> CsvSource { - let source = self.file_scan_config(); - source - .file_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - /// true if the first line of each file is a header - pub fn has_header(&self) -> bool { - self.csv_source().has_header() - } - - /// Specifies whether newlines in (quoted) values are supported. - /// - /// Parsing newlines in quoted values may be affected by execution behaviour such as - /// parallel file scanning. Setting this to `true` ensures that newlines in values are - /// parsed successfully, which may reduce performance. - /// - /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. - pub fn newlines_in_values(&self) -> bool { - let source = self.file_scan_config(); - source.newlines_in_values() - } - - fn output_partitioning_helper(file_scan_config: &FileScanConfig) -> Partitioning { - Partitioning::UnknownPartitioning(file_scan_config.file_groups.len()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - file_scan_config: &FileScanConfig, - ) -> PlanProperties { - // Equivalence Properties - let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints); - - PlanProperties::new( - eq_properties, - Self::output_partitioning_helper(file_scan_config), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } - - fn with_file_groups(mut self, file_groups: Vec) -> Self { - self.base_config.file_groups = file_groups.clone(); - let mut file_source = self.file_scan_config(); - file_source = file_source.with_file_groups(file_groups); - self.inner = self.inner.with_data_source(Arc::new(file_source)); - self - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for CsvExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for CsvExec { - fn name(&self) -> &'static str { - "CsvExec" - } - - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - - fn children(&self) -> Vec<&Arc> { - // this is a leaf node and has no children - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - /// Redistribute files across partitions according to their size - /// See comments on `FileGroupPartitioner` for more detail. - /// - /// Return `None` if can't get repartitioned (empty, compressed file, or `newlines_in_values` set). - fn repartitioned( - &self, - target_partitions: usize, - config: &ConfigOptions, - ) -> Result>> { - self.inner.repartitioned(target_partitions, config) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, context) - } - - fn statistics(&self) -> Result { - self.inner.statistics() - } - - fn metrics(&self) -> Option { - self.inner.metrics() - } - - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } - - fn try_swapping_with_projection( - &self, - projection: &ProjectionExec, - ) -> Result>> { - self.inner.try_swapping_with_projection(projection) - } -} - /// A Config for [`CsvOpener`] /// /// # Example: create a `DataSourceExec` for CSV @@ -443,6 +92,8 @@ pub struct CsvSource { comment: Option, metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, + truncate_rows: bool, } impl CsvSource { @@ -460,6 +111,11 @@ impl CsvSource { pub fn has_header(&self) -> bool { self.has_header } + + // true if rows length support truncate + pub fn truncate_rows(&self) -> bool { + self.truncate_rows + } /// A column delimiter pub fn delimiter(&self) -> u8 { self.delimiter @@ -505,6 +161,13 @@ impl CsvSource { conf.comment = comment; conf } + + /// Whether to support truncate rows when read csv file + pub fn with_truncate_rows(&self, truncate_rows: bool) -> Self { + let mut conf = self.clone(); + conf.truncate_rows = truncate_rows; + conf + } } impl CsvSource { @@ -524,7 +187,8 @@ impl CsvSource { .expect("Batch size must be set before initializing builder"), ) .with_header(self.has_header) - .with_quote(self.quote); + .with_quote(self.quote) + .with_truncated_rows(self.truncate_rows); if let Some(terminator) = self.terminator { builder = builder.with_terminator(terminator); } @@ -564,6 +228,12 @@ impl CsvOpener { } } +impl From for Arc { + fn from(source: CsvSource) -> Self { + as_file_source(source) + } +} + impl FileSource for CsvSource { fn create_file_opener( &self, @@ -626,6 +296,20 @@ impl FileSource for CsvSource { DisplayFormatType::TreeRender => Ok(()), } } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } } impl FileOpener for CsvOpener { @@ -652,12 +336,12 @@ impl FileOpener for CsvOpener { /// A,1,2,3,4,5,6,7,8,9\n /// A},1,2,3,4,5,6,7,8,9\n /// The lines read would be: [1, 2] - fn open(&self, file_meta: FileMeta) -> Result { + fn open(&self, partitioned_file: PartitionedFile) -> Result { // `self.config.has_header` controls whether to skip reading the 1st line header // If the .csv file is read in parallel and this `CsvOpener` is only reading some middle // partition, then don't skip first line let mut csv_has_header = self.config.has_header; - if let Some(FileRange { start, .. }) = file_meta.range { + if let Some(FileRange { start, .. }) = partitioned_file.range { if start != 0 { csv_has_header = false; } @@ -665,12 +349,13 @@ impl FileOpener for CsvOpener { let config = CsvSource { has_header: csv_has_header, + truncate_rows: self.config.truncate_rows, ..(*self.config).clone() }; let file_compression_type = self.file_compression_type.to_owned(); - if file_meta.range.is_some() { + if partitioned_file.range.is_some() { assert!( !file_compression_type.is_compressed(), "Reading compressed .csv in parallel is not supported" @@ -684,7 +369,7 @@ impl FileOpener for CsvOpener { // Current partition contains bytes [start_byte, end_byte) (might contain incomplete lines at boundaries) let calculated_range = - calculate_range(&file_meta, &store, terminator).await?; + calculate_range(&partitioned_file, &store, terminator).await?; let range = match calculated_range { RangeCalculation::Range(None) => None, @@ -701,11 +386,14 @@ impl FileOpener for CsvOpener { ..Default::default() }; - let result = store.get_opts(file_meta.location(), options).await?; + let result = store + .get_opts(&partitioned_file.object_meta.location, options) + .await?; match result.payload { + #[cfg(not(target_arch = "wasm32"))] GetResultPayload::File(mut file, _) => { - let is_whole_file_scanned = file_meta.range.is_none(); + let is_whole_file_scanned = partitioned_file.range.is_none(); let decoder = if is_whole_file_scanned { // Don't seek if no range as breaks FIFO files file_compression_type.convert_read(file)? @@ -716,17 +404,20 @@ impl FileOpener for CsvOpener { )? }; - Ok(futures::stream::iter(config.open(decoder)?).boxed()) + Ok(futures::stream::iter(config.open(decoder)?) + .map(|r| r.map_err(Into::into)) + .boxed()) } GetResultPayload::Stream(s) => { let decoder = config.builder().build_decoder(); let s = s.map_err(DataFusionError::from); let input = file_compression_type.convert_stream(s.boxed())?.fuse(); - Ok(deserialize_stream( + let stream = deserialize_stream( input, DecoderDeserializer::new(CsvDecoder::new(decoder)), - )) + ); + Ok(stream.map_err(Into::into).boxed()) } } })) @@ -742,6 +433,11 @@ pub async fn plan_to_csv( let parsed = ListingTableUrl::parse(path)?; let object_store_url = parsed.object_store(); let store = task_ctx.runtime_env().object_store(&object_store_url)?; + let writer_buffer_size = task_ctx + .session_config() + .options() + .execution + .objectstore_writer_buffer_size; let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { let storeref = Arc::clone(&store); @@ -751,7 +447,8 @@ pub async fn plan_to_csv( let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { - let mut buf_writer = BufWriter::new(storeref, file.clone()); + let mut buf_writer = + BufWriter::with_capacity(storeref, file.clone(), writer_buffer_size); let mut buffer = Vec::with_capacity(1024); //only write headers on first iteration let mut write_headers = true; diff --git a/datafusion/datasource-json/Cargo.toml b/datafusion/datasource-json/Cargo.toml index 6c74923ff79e9..987ab60c70b7c 100644 --- a/datafusion/datasource-json/Cargo.toml +++ b/datafusion/datasource-json/Cargo.toml @@ -18,11 +18,11 @@ [package] name = "datafusion-datasource-json" description = "datafusion-datasource-json" +readme = "README.md" authors.workspace = true edition.workspace = true homepage.workspace = true license.workspace = true -readme.workspace = true repository.workspace = true rust-version.workspace = true version.workspace = true @@ -34,19 +34,16 @@ all-features = true arrow = { workspace = true } async-trait = { workspace = true } bytes = { workspace = true } -datafusion-catalog = { workspace = true } datafusion-common = { workspace = true, features = ["object_store"] } datafusion-common-runtime = { workspace = true } datafusion-datasource = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } futures = { workspace = true } object_store = { workspace = true } -serde_json = { workspace = true } tokio = { workspace = true } [lints] diff --git a/datafusion/datasource-json/README.md b/datafusion/datasource-json/README.md index 64181814736df..ca2771b9d67e4 100644 --- a/datafusion/datasource-json/README.md +++ b/datafusion/datasource-json/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion datasource +# Apache DataFusion JSON DataSource -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion that defines a JSON based file source. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/datasource-json/src/file_format.rs b/datafusion/datasource-json/src/file_format.rs index a6c52312e4127..51f4bd7e963e0 100644 --- a/datafusion/datasource-json/src/file_format.rs +++ b/datafusion/datasource-json/src/file_format.rs @@ -52,7 +52,6 @@ use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join; use datafusion_datasource::write::BatchSerializer; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; @@ -186,6 +185,10 @@ impl FileFormat for JsonFormat { Ok(format!("{}{}", ext, file_compression_type.get_ext())) } + fn compression_type(&self) -> Option { + Some(self.options.compression.into()) + } + async fn infer_schema( &self, _state: &dyn Session, @@ -209,6 +212,7 @@ impl FileFormat for JsonFormat { let r = store.as_ref().get(&object.location).await?; let schema = match r.payload { + #[cfg(not(target_arch = "wasm32"))] GetResultPayload::File(file, _) => { let decoder = file_compression_type.convert_read(file)?; let mut reader = BufReader::new(decoder); @@ -248,7 +252,6 @@ impl FileFormat for JsonFormat { &self, _state: &dyn Session, conf: FileScanConfig, - _filters: Option<&Arc>, ) -> Result> { let source = Arc::new(JsonSource::new()); let conf = FileScanConfigBuilder::from(conf) diff --git a/datafusion/datasource-json/src/source.rs b/datafusion/datasource-json/src/source.rs index f1adccf9ded7d..0b1eee1dac588 100644 --- a/datafusion/datasource-json/src/source.rs +++ b/datafusion/datasource-json/src/source.rs @@ -28,200 +28,26 @@ use datafusion_common::error::{DataFusionError, Result}; use datafusion_common_runtime::JoinSet; use datafusion_datasource::decoder::{deserialize_stream, DecoderDeserializer}; use datafusion_datasource::file_compression_type::FileCompressionType; -use datafusion_datasource::file_meta::FileMeta; use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener}; -use datafusion_datasource::{calculate_range, ListingTableUrl, RangeCalculation}; +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; +use datafusion_datasource::{ + as_file_source, calculate_range, ListingTableUrl, PartitionedFile, RangeCalculation, +}; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use arrow::json::ReaderBuilder; use arrow::{datatypes::SchemaRef, json}; -use datafusion_common::{Constraints, Statistics}; +use datafusion_common::Statistics; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::source::DataSourceExec; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::{DisplayAs, DisplayFormatType, PlanProperties}; - -use datafusion_datasource::file_groups::FileGroup; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; + use futures::{StreamExt, TryStreamExt}; use object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; -/// Execution plan for scanning NdJson data source -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -pub struct NdJsonExec { - inner: DataSourceExec, - base_config: FileScanConfig, - file_compression_type: FileCompressionType, -} - -#[allow(unused, deprecated)] -impl NdJsonExec { - /// Create a new JSON reader execution plan provided base configurations - pub fn new( - base_config: FileScanConfig, - file_compression_type: FileCompressionType, - ) -> Self { - let ( - projected_schema, - projected_constraints, - projected_statistics, - projected_output_ordering, - ) = base_config.project(); - let cache = Self::compute_properties( - projected_schema, - &projected_output_ordering, - projected_constraints, - &base_config, - ); - - let json = JsonSource::default(); - let base_config = base_config - .with_file_compression_type(file_compression_type) - .with_source(Arc::new(json)); - - Self { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - file_compression_type: base_config.file_compression_type, - base_config, - } - } - - /// Ref to the base configs - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - - /// Ref to file compression type - pub fn file_compression_type(&self) -> &FileCompressionType { - &self.file_compression_type - } - - fn file_scan_config(&self) -> FileScanConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn json_source(&self) -> JsonSource { - let source = self.file_scan_config(); - source - .file_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn output_partitioning_helper(file_scan_config: &FileScanConfig) -> Partitioning { - Partitioning::UnknownPartitioning(file_scan_config.file_groups.len()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - file_scan_config: &FileScanConfig, - ) -> PlanProperties { - // Equivalence Properties - let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints); - - PlanProperties::new( - eq_properties, - Self::output_partitioning_helper(file_scan_config), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } - - fn with_file_groups(mut self, file_groups: Vec) -> Self { - self.base_config.file_groups = file_groups.clone(); - let mut file_source = self.file_scan_config(); - file_source = file_source.with_file_groups(file_groups); - self.inner = self.inner.with_data_source(Arc::new(file_source)); - self - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for NdJsonExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for NdJsonExec { - fn name(&self) -> &'static str { - "NdJsonExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - - fn children(&self) -> Vec<&Arc> { - Vec::new() - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - fn repartitioned( - &self, - target_partitions: usize, - config: &datafusion_common::config::ConfigOptions, - ) -> Result>> { - self.inner.repartitioned(target_partitions, config) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, context) - } - - fn statistics(&self) -> Result { - self.inner.statistics() - } - - fn metrics(&self) -> Option { - self.inner.metrics() - } - - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } -} - /// A [`FileOpener`] that opens a JSON file and yields a [`FileOpenFuture`] pub struct JsonOpener { batch_size: usize, @@ -253,6 +79,7 @@ pub struct JsonSource { batch_size: Option, metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, } impl JsonSource { @@ -262,6 +89,12 @@ impl JsonSource { } } +impl From for Arc { + fn from(source: JsonSource) -> Self { + as_file_source(source) + } +} + impl FileSource for JsonSource { fn create_file_opener( &self, @@ -316,6 +149,20 @@ impl FileSource for JsonSource { fn file_type(&self) -> &str { "json" } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } } impl FileOpener for JsonOpener { @@ -328,14 +175,15 @@ impl FileOpener for JsonOpener { /// are applied to determine which lines to read: /// 1. The first line of the partition is the line in which the index of the first character >= `start`. /// 2. The last line of the partition is the line in which the byte at position `end - 1` resides. - fn open(&self, file_meta: FileMeta) -> Result { + fn open(&self, partitioned_file: PartitionedFile) -> Result { let store = Arc::clone(&self.object_store); let schema = Arc::clone(&self.projected_schema); let batch_size = self.batch_size; let file_compression_type = self.file_compression_type.to_owned(); Ok(Box::pin(async move { - let calculated_range = calculate_range(&file_meta, &store, None).await?; + let calculated_range = + calculate_range(&partitioned_file, &store, None).await?; let range = match calculated_range { RangeCalculation::Range(None) => None, @@ -352,11 +200,14 @@ impl FileOpener for JsonOpener { ..Default::default() }; - let result = store.get_opts(file_meta.location(), options).await?; + let result = store + .get_opts(&partitioned_file.object_meta.location, options) + .await?; match result.payload { + #[cfg(not(target_arch = "wasm32"))] GetResultPayload::File(mut file, _) => { - let bytes = match file_meta.range { + let bytes = match partitioned_file.range { None => file_compression_type.convert_read(file)?, Some(_) => { file.seek(SeekFrom::Start(result.range.start as _))?; @@ -369,7 +220,9 @@ impl FileOpener for JsonOpener { .with_batch_size(batch_size) .build(BufReader::new(bytes))?; - Ok(futures::stream::iter(reader).boxed()) + Ok(futures::stream::iter(reader) + .map(|r| r.map_err(Into::into)) + .boxed()) } GetResultPayload::Stream(s) => { let s = s.map_err(DataFusionError::from); @@ -379,10 +232,11 @@ impl FileOpener for JsonOpener { .build_decoder()?; let input = file_compression_type.convert_stream(s.boxed())?.fuse(); - Ok(deserialize_stream( + let stream = deserialize_stream( input, DecoderDeserializer::new(JsonDecoder::new(decoder)), - )) + ); + Ok(stream.map_err(Into::into).boxed()) } } })) @@ -398,6 +252,11 @@ pub async fn plan_to_json( let parsed = ListingTableUrl::parse(path)?; let object_store_url = parsed.object_store(); let store = task_ctx.runtime_env().object_store(&object_store_url)?; + let writer_buffer_size = task_ctx + .session_config() + .options() + .execution + .objectstore_writer_buffer_size; let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { let storeref = Arc::clone(&store); @@ -407,7 +266,8 @@ pub async fn plan_to_json( let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { - let mut buf_writer = BufWriter::new(storeref, file.clone()); + let mut buf_writer = + BufWriter::with_capacity(storeref, file.clone(), writer_buffer_size); let mut buffer = Vec::with_capacity(1024); while let Some(batch) = stream.next().await.transpose()? { diff --git a/datafusion/datasource-parquet/Cargo.toml b/datafusion/datasource-parquet/Cargo.toml index b6a548c998dc2..1f866ffd6cc2f 100644 --- a/datafusion/datasource-parquet/Cargo.toml +++ b/datafusion/datasource-parquet/Cargo.toml @@ -18,11 +18,11 @@ [package] name = "datafusion-datasource-parquet" description = "datafusion-datasource-parquet" +readme = "README.md" authors.workspace = true edition.workspace = true homepage.workspace = true license.workspace = true -readme.workspace = true repository.workspace = true rust-version.workspace = true version.workspace = true @@ -34,17 +34,17 @@ all-features = true arrow = { workspace = true } async-trait = { workspace = true } bytes = { workspace = true } -datafusion-catalog = { workspace = true } datafusion-common = { workspace = true, features = ["object_store", "parquet"] } datafusion-common-runtime = { workspace = true } -datafusion-datasource = { workspace = true, features = ["parquet"] } +datafusion-datasource = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-functions-aggregate = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } datafusion-physical-expr = { workspace = true } +datafusion-physical-expr-adapter = { workspace = true } datafusion-physical-expr-common = { workspace = true } -datafusion-physical-optimizer = { workspace = true } datafusion-physical-plan = { workspace = true } +datafusion-pruning = { workspace = true } datafusion-session = { workspace = true } futures = { workspace = true } itertools = { workspace = true } @@ -52,7 +52,6 @@ log = { workspace = true } object_store = { workspace = true } parking_lot = { workspace = true } parquet = { workspace = true } -rand = { workspace = true } tokio = { workspace = true } [dev-dependencies] @@ -64,3 +63,10 @@ workspace = true [lib] name = "datafusion_datasource_parquet" path = "src/mod.rs" + +[features] +parquet_encryption = [ + "parquet/encryption", + "datafusion-common/parquet_encryption", + "datafusion-execution/parquet_encryption", +] diff --git a/datafusion/datasource-parquet/README.md b/datafusion/datasource-parquet/README.md index abcdd5ab13402..833fc74a258b3 100644 --- a/datafusion/datasource-parquet/README.md +++ b/datafusion/datasource-parquet/README.md @@ -17,10 +17,17 @@ under the License. --> -# DataFusion datasource +# Apache DataFusion Parquet DataSource -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. -This crate is a submodule of DataFusion that defines a Parquet based file source. +This crate is a submodule of DataFusion that defines an [Apache Parquet] based file source. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[apache parquet]: https://parquet.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/datasource-parquet/src/file_format.rs b/datafusion/datasource-parquet/src/file_format.rs index 1d9a67fd2eb6d..963c1d77950c6 100644 --- a/datafusion/datasource-parquet/src/file_format.rs +++ b/datafusion/datasource-parquet/src/file_format.rs @@ -18,69 +18,74 @@ //! [`ParquetFormat`]: Parquet [`FileFormat`] abstractions use std::any::Any; -use std::fmt; +use std::cell::RefCell; use std::fmt::Debug; use std::ops::Range; +use std::rc::Rc; use std::sync::Arc; +use std::{fmt, vec}; use arrow::array::RecordBatch; -use arrow::compute::sum; +use arrow::datatypes::{Fields, Schema, SchemaRef, TimeUnit}; +use datafusion_datasource::file_compression_type::FileCompressionType; +use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; +use datafusion_datasource::write::{ + get_writer_schema, ObjectWriterBuilder, SharedBuffer, +}; + +use datafusion_datasource::file_format::{FileFormat, FileFormatFactory}; +use datafusion_datasource::write::demux::DemuxedStreamReceiver; + use arrow::datatypes::{DataType, Field, FieldRef}; -use arrow::datatypes::{Fields, Schema, SchemaRef}; use datafusion_common::config::{ConfigField, ConfigFileType, TableParquetOptions}; +#[cfg(feature = "parquet_encryption")] +use datafusion_common::encryption::map_config_decryption_to_decryption; +use datafusion_common::encryption::FileDecryptionProperties; use datafusion_common::parsers::CompressionTypeVariant; -use datafusion_common::stats::Precision; use datafusion_common::{ - internal_datafusion_err, internal_err, not_impl_err, ColumnStatistics, - DataFusionError, GetExt, Result, DEFAULT_PARQUET_EXTENSION, + internal_datafusion_err, internal_err, not_impl_err, DataFusionError, GetExt, + HashSet, Result, DEFAULT_PARQUET_EXTENSION, }; use datafusion_common::{HashMap, Statistics}; use datafusion_common_runtime::{JoinSet, SpawnedTask}; use datafusion_datasource::display::FileGroupDisplay; use datafusion_datasource::file::FileSource; -use datafusion_datasource::file_compression_type::FileCompressionType; -use datafusion_datasource::file_format::{ - FileFormat, FileFormatFactory, FilePushdownSupport, -}; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; -use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; use datafusion_datasource::sink::{DataSink, DataSinkExec}; -use datafusion_datasource::write::demux::DemuxedStreamReceiver; -use datafusion_datasource::write::{create_writer, get_writer_schema, SharedBuffer}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; -use datafusion_expr::Expr; -use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexRequirement; -use datafusion_physical_plan::Accumulator; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; -use crate::can_expr_be_pushed_down_with_schemas; -use crate::source::ParquetSource; +use crate::reader::CachedParquetFileReaderFactory; +use crate::source::{parse_coerce_int96_string, ParquetSource}; use async_trait::async_trait; use bytes::Bytes; use datafusion_datasource::source::DataSourceExec; +use datafusion_execution::runtime_env::RuntimeEnv; use futures::future::BoxFuture; use futures::{FutureExt, StreamExt, TryStreamExt}; -use log::debug; use object_store::buffered::BufWriter; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; -use parquet::arrow::arrow_reader::statistics::StatisticsConverter; use parquet::arrow::arrow_writer::{ - compute_leaves, get_column_writers, ArrowColumnChunk, ArrowColumnWriter, - ArrowLeafColumn, ArrowWriterOptions, + compute_leaves, ArrowColumnChunk, ArrowColumnWriter, ArrowLeafColumn, + ArrowRowGroupWriterFactory, ArrowWriterOptions, }; use parquet::arrow::async_reader::MetadataFetch; -use parquet::arrow::{parquet_to_arrow_schema, ArrowSchemaConverter, AsyncArrowWriter}; +use parquet::arrow::{ArrowWriter, AsyncArrowWriter}; +use parquet::basic::Type; + +use crate::metadata::DFParquetMetadata; +use datafusion_execution::cache::cache_manager::FileMetadataCache; use parquet::errors::ParquetError; -use parquet::file::metadata::{ParquetMetaData, ParquetMetaDataReader, RowGroupMetaData}; +use parquet::file::metadata::ParquetMetaData; use parquet::file::properties::{WriterProperties, WriterPropertiesBuilder}; use parquet::file::writer::SerializedFileWriter; use parquet::format::FileMetaData; +use parquet::schema::types::SchemaDescriptor; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver, Sender}; @@ -268,6 +273,15 @@ impl ParquetFormat { self.options.global.binary_as_string = binary_as_string; self } + + pub fn coerce_int96(&self) -> Option { + self.options.global.coerce_int96.clone() + } + + pub fn with_coerce_int96(mut self, time_unit: Option) -> Self { + self.options.global.coerce_int96 = time_unit; + self + } } /// Clears all metadata (Schema level and field level) on an iterator @@ -287,14 +301,39 @@ fn clear_metadata( }) } -async fn fetch_schema_with_location( - store: &dyn ObjectStore, - file: &ObjectMeta, - metadata_size_hint: Option, -) -> Result<(Path, Schema)> { - let loc_path = file.location.clone(); - let schema = fetch_schema(store, file, metadata_size_hint).await?; - Ok((loc_path, schema)) +#[cfg(feature = "parquet_encryption")] +async fn get_file_decryption_properties( + state: &dyn Session, + options: &TableParquetOptions, + file_path: &Path, +) -> Result> { + let file_decryption_properties: Option = + match &options.crypto.file_decryption { + Some(cfd) => Some(map_config_decryption_to_decryption(cfd)), + None => match &options.crypto.factory_id { + Some(factory_id) => { + let factory = + state.runtime_env().parquet_encryption_factory(factory_id)?; + factory + .get_file_decryption_properties( + &options.crypto.factory_options, + file_path, + ) + .await? + } + None => None, + }, + }; + Ok(file_decryption_properties) +} + +#[cfg(not(feature = "parquet_encryption"))] +async fn get_file_decryption_properties( + _state: &dyn Session, + _options: &TableParquetOptions, + _file_path: &Path, +) -> Result> { + Ok(None) } #[async_trait] @@ -318,21 +357,43 @@ impl FileFormat for ParquetFormat { } } + fn compression_type(&self) -> Option { + None + } + async fn infer_schema( &self, state: &dyn Session, store: &Arc, objects: &[ObjectMeta], ) -> Result { + let coerce_int96 = match self.coerce_int96() { + Some(time_unit) => Some(parse_coerce_int96_string(time_unit.as_str())?), + None => None, + }; + + let file_metadata_cache = + state.runtime_env().cache_manager.get_file_metadata_cache(); + let mut schemas: Vec<_> = futures::stream::iter(objects) - .map(|object| { - fetch_schema_with_location( - store.as_ref(), - object, - self.metadata_size_hint(), + .map(|object| async { + let file_decryption_properties = get_file_decryption_properties( + state, + &self.options, + &object.location, ) + .await?; + let result = DFParquetMetadata::new(store.as_ref(), object) + .with_metadata_size_hint(self.metadata_size_hint()) + .with_decryption_properties(file_decryption_properties.as_ref()) + .with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))) + .with_coerce_int96(coerce_int96) + .fetch_schema_with_location() + .await?; + Ok::<_, DataFusionError>(result) }) .boxed() // Workaround https://github.com/rust-lang/rust/issues/64552 + // fetch schemas concurrently, if requested .buffered(state.config_options().execution.meta_fetch_concurrency) .try_collect() .await?; @@ -373,53 +434,57 @@ impl FileFormat for ParquetFormat { async fn infer_stats( &self, - _state: &dyn Session, + state: &dyn Session, store: &Arc, table_schema: SchemaRef, object: &ObjectMeta, ) -> Result { - let stats = fetch_statistics( - store.as_ref(), - table_schema, - object, - self.metadata_size_hint(), - ) - .await?; - Ok(stats) + let file_decryption_properties = + get_file_decryption_properties(state, &self.options, &object.location) + .await?; + let file_metadata_cache = + state.runtime_env().cache_manager.get_file_metadata_cache(); + DFParquetMetadata::new(store, object) + .with_metadata_size_hint(self.metadata_size_hint()) + .with_decryption_properties(file_decryption_properties.as_ref()) + .with_file_metadata_cache(Some(file_metadata_cache)) + .fetch_statistics(&table_schema) + .await } async fn create_physical_plan( &self, - _state: &dyn Session, + state: &dyn Session, conf: FileScanConfig, - filters: Option<&Arc>, ) -> Result> { - let mut predicate = None; let mut metadata_size_hint = None; - // If enable pruning then combine the filters to build the predicate. - // If disable pruning then set the predicate to None, thus readers - // will not prune data based on the statistics. - if self.enable_pruning() { - if let Some(pred) = filters.cloned() { - predicate = Some(pred); - } - } if let Some(metadata) = self.metadata_size_hint() { metadata_size_hint = Some(metadata); } let mut source = ParquetSource::new(self.options.clone()); - if let Some(predicate) = predicate { - source = source.with_predicate(Arc::clone(&conf.file_schema), predicate); - } + // Use the CachedParquetFileReaderFactory + let metadata_cache = state.runtime_env().cache_manager.get_file_metadata_cache(); + let store = state + .runtime_env() + .object_store(conf.object_store_url.clone())?; + let cached_parquet_read_factory = + Arc::new(CachedParquetFileReaderFactory::new(store, metadata_cache)); + source = source.with_parquet_file_reader_factory(cached_parquet_read_factory); + if let Some(metadata_size_hint) = metadata_size_hint { source = source.with_metadata_size_hint(metadata_size_hint) } + source = self.set_source_encryption_factory(source, state)?; + + // Apply schema adapter factory before building the new config + let file_source = source.apply_schema_adapter(&conf)?; + let conf = FileScanConfigBuilder::from(conf) - .with_source(Arc::new(source)) + .with_source(file_source) .build(); Ok(DataSourceExec::from_data_source(conf)) } @@ -440,29 +505,43 @@ impl FileFormat for ParquetFormat { Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) } - fn supports_filters_pushdown( - &self, - file_schema: &Schema, - table_schema: &Schema, - filters: &[&Expr], - ) -> Result { - if !self.options().global.pushdown_filters { - return Ok(FilePushdownSupport::NoSupport); - } - - let all_supported = filters.iter().all(|filter| { - can_expr_be_pushed_down_with_schemas(filter, file_schema, table_schema) - }); + fn file_source(&self) -> Arc { + Arc::new(ParquetSource::default()) + } +} - Ok(if all_supported { - FilePushdownSupport::Supported +#[cfg(feature = "parquet_encryption")] +impl ParquetFormat { + fn set_source_encryption_factory( + &self, + source: ParquetSource, + state: &dyn Session, + ) -> Result { + if let Some(encryption_factory_id) = &self.options.crypto.factory_id { + Ok(source.with_encryption_factory( + state + .runtime_env() + .parquet_encryption_factory(encryption_factory_id)?, + )) } else { - FilePushdownSupport::NotSupportedForFilter - }) + Ok(source) + } } +} - fn file_source(&self) -> Arc { - Arc::new(ParquetSource::default()) +#[cfg(not(feature = "parquet_encryption"))] +impl ParquetFormat { + fn set_source_encryption_factory( + &self, + source: ParquetSource, + _state: &dyn Session, + ) -> Result { + if let Some(encryption_factory_id) = &self.options.crypto.factory_id { + Err(DataFusionError::Configuration( + format!("Parquet encryption factory id is set to '{encryption_factory_id}' but the parquet_encryption feature is disabled"))) + } else { + Ok(source) + } } } @@ -569,6 +648,194 @@ pub fn apply_file_schema_type_coercions( )) } +/// Coerces the file schema's Timestamps to the provided TimeUnit if Parquet schema contains INT96. +pub fn coerce_int96_to_resolution( + parquet_schema: &SchemaDescriptor, + file_schema: &Schema, + time_unit: &TimeUnit, +) -> Option { + // Traverse the parquet_schema columns looking for int96 physical types. If encountered, insert + // the field's full path into a set. + let int96_fields: HashSet<_> = parquet_schema + .columns() + .iter() + .filter(|f| f.physical_type() == Type::INT96) + .map(|f| f.path().string()) + .collect(); + + if int96_fields.is_empty() { + // The schema doesn't contain any int96 fields, so skip the remaining logic. + return None; + } + + // Do a DFS into the schema using a stack, looking for timestamp(nanos) fields that originated + // as int96 to coerce to the provided time_unit. + + type NestedFields = Rc>>; + type StackContext<'a> = ( + Vec<&'a str>, // The Parquet column path (e.g., "c0.list.element.c1") for the current field. + &'a FieldRef, // The current field to be processed. + NestedFields, // The parent's fields that this field will be (possibly) type-coerced and + // inserted into. All fields have a parent, so this is not an Option type. + Option, // Nested types need to create their own vector of fields for their + // children. For primitive types this will remain None. For nested + // types it is None the first time they are processed. Then, we + // instantiate a vector for its children, push the field back onto the + // stack to be processed again, and DFS into its children. The next + // time we process the field, we know we have DFS'd into the children + // because this field is Some. + ); + + // This is our top-level fields from which we will construct our schema. We pass this into our + // initial stack context as the parent fields, and the DFS populates it. + let fields = Rc::new(RefCell::new(Vec::with_capacity(file_schema.fields.len()))); + + // TODO: It might be possible to only DFS into nested fields that we know contain an int96 if we + // use some sort of LPM data structure to check if we're currently DFS'ing nested types that are + // in a column path that contains an int96. That can be a future optimization for large schemas. + let transformed_schema = { + // Populate the stack with our top-level fields. + let mut stack: Vec = file_schema + .fields() + .iter() + .rev() + .map(|f| (vec![f.name().as_str()], f, Rc::clone(&fields), None)) + .collect(); + + // Pop fields to DFS into until we have exhausted the stack. + while let Some((parquet_path, current_field, parent_fields, child_fields)) = + stack.pop() + { + match (current_field.data_type(), child_fields) { + (DataType::Struct(unprocessed_children), None) => { + // This is the first time popping off this struct. We don't yet know the + // correct types of its children (i.e., if they need coercing) so we create + // a vector for child_fields, push the struct node back onto the stack to be + // processed again (see below) after processing all its children. + let child_fields = Rc::new(RefCell::new(Vec::with_capacity( + unprocessed_children.len(), + ))); + // Note that here we push the struct back onto the stack with its + // parent_fields in the same position, now with Some(child_fields). + stack.push(( + parquet_path.clone(), + current_field, + parent_fields, + Some(Rc::clone(&child_fields)), + )); + // Push all the children in reverse to maintain original schema order due to + // stack processing. + for child in unprocessed_children.into_iter().rev() { + let mut child_path = parquet_path.clone(); + // Build up a normalized path that we'll use as a key into the original + // int96_fields set above to test if this originated as int96. + child_path.push("."); + child_path.push(child.name()); + // Note that here we push the field onto the stack using the struct's + // new child_fields vector as the field's parent_fields. + stack.push((child_path, child, Rc::clone(&child_fields), None)); + } + } + (DataType::Struct(unprocessed_children), Some(processed_children)) => { + // This is the second time popping off this struct. The child_fields vector + // now contains each field that has been DFS'd into, and we can construct + // the resulting struct with correct child types. + let processed_children = processed_children.borrow(); + assert_eq!(processed_children.len(), unprocessed_children.len()); + let processed_struct = Field::new_struct( + current_field.name(), + processed_children.as_slice(), + current_field.is_nullable(), + ); + parent_fields.borrow_mut().push(Arc::new(processed_struct)); + } + (DataType::List(unprocessed_child), None) => { + // This is the first time popping off this list. See struct docs above. + let child_fields = Rc::new(RefCell::new(Vec::with_capacity(1))); + stack.push(( + parquet_path.clone(), + current_field, + parent_fields, + Some(Rc::clone(&child_fields)), + )); + let mut child_path = parquet_path.clone(); + // Spark uses a definition for arrays/lists that results in a group + // named "list" that is not maintained when parsing to Arrow. We just push + // this name into the path. + child_path.push(".list."); + child_path.push(unprocessed_child.name()); + stack.push(( + child_path.clone(), + unprocessed_child, + Rc::clone(&child_fields), + None, + )); + } + (DataType::List(_), Some(processed_children)) => { + // This is the second time popping off this list. See struct docs above. + let processed_children = processed_children.borrow(); + assert_eq!(processed_children.len(), 1); + let processed_list = Field::new_list( + current_field.name(), + Arc::clone(&processed_children[0]), + current_field.is_nullable(), + ); + parent_fields.borrow_mut().push(Arc::new(processed_list)); + } + (DataType::Map(unprocessed_child, _), None) => { + // This is the first time popping off this map. See struct docs above. + let child_fields = Rc::new(RefCell::new(Vec::with_capacity(1))); + stack.push(( + parquet_path.clone(), + current_field, + parent_fields, + Some(Rc::clone(&child_fields)), + )); + let mut child_path = parquet_path.clone(); + child_path.push("."); + child_path.push(unprocessed_child.name()); + stack.push(( + child_path.clone(), + unprocessed_child, + Rc::clone(&child_fields), + None, + )); + } + (DataType::Map(_, sorted), Some(processed_children)) => { + // This is the second time popping off this map. See struct docs above. + let processed_children = processed_children.borrow(); + assert_eq!(processed_children.len(), 1); + let processed_map = Field::new( + current_field.name(), + DataType::Map(Arc::clone(&processed_children[0]), *sorted), + current_field.is_nullable(), + ); + parent_fields.borrow_mut().push(Arc::new(processed_map)); + } + (DataType::Timestamp(TimeUnit::Nanosecond, None), None) + if int96_fields.contains(parquet_path.concat().as_str()) => + // We found a timestamp(nanos) and it originated as int96. Coerce it to the correct + // time_unit. + { + parent_fields.borrow_mut().push(field_with_new_type( + current_field, + DataType::Timestamp(*time_unit, None), + )); + } + // Other types can be cloned as they are. + _ => parent_fields.borrow_mut().push(Arc::clone(current_field)), + } + } + assert_eq!(fields.borrow().len(), file_schema.fields.len()); + Schema::new_with_metadata( + fields.borrow_mut().clone(), + file_schema.metadata.clone(), + ) + }; + + Some(transformed_schema) +} + /// Coerces the file schema if the table schema uses a view type. #[deprecated( since = "47.0.0", @@ -723,22 +990,19 @@ pub fn transform_binary_to_string(schema: &Schema) -> Schema { } /// [`MetadataFetch`] adapter for reading bytes from an [`ObjectStore`] -struct ObjectStoreFetch<'a> { +pub struct ObjectStoreFetch<'a> { store: &'a dyn ObjectStore, meta: &'a ObjectMeta, } impl<'a> ObjectStoreFetch<'a> { - fn new(store: &'a dyn ObjectStore, meta: &'a ObjectMeta) -> Self { + pub fn new(store: &'a dyn ObjectStore, meta: &'a ObjectMeta) -> Self { Self { store, meta } } } impl MetadataFetch for ObjectStoreFetch<'_> { - fn fetch( - &mut self, - range: Range, - ) -> BoxFuture<'_, Result> { + fn fetch(&mut self, range: Range) -> BoxFuture<'_, Result> { async { self.store .get_range(&self.meta.location, range) @@ -755,212 +1019,57 @@ impl MetadataFetch for ObjectStoreFetch<'_> { /// through [`ParquetFileReaderFactory`]. /// /// [`ParquetFileReaderFactory`]: crate::ParquetFileReaderFactory +#[deprecated( + since = "50.0.0", + note = "Use `DFParquetMetadata::fetch_metadata` instead" +)] pub async fn fetch_parquet_metadata( store: &dyn ObjectStore, - meta: &ObjectMeta, + object_meta: &ObjectMeta, size_hint: Option, -) -> Result { - let file_size = meta.size; - let fetch = ObjectStoreFetch::new(store, meta); - - ParquetMetaDataReader::new() - .with_prefetch_hint(size_hint) - .load_and_finish(fetch, file_size) + #[allow(unused)] decryption_properties: Option<&FileDecryptionProperties>, + file_metadata_cache: Option>, +) -> Result> { + DFParquetMetadata::new(store, object_meta) + .with_metadata_size_hint(size_hint) + .with_decryption_properties(decryption_properties) + .with_file_metadata_cache(file_metadata_cache) + .fetch_metadata() .await - .map_err(DataFusionError::from) -} - -/// Read and parse the schema of the Parquet file at location `path` -async fn fetch_schema( - store: &dyn ObjectStore, - file: &ObjectMeta, - metadata_size_hint: Option, -) -> Result { - let metadata = fetch_parquet_metadata(store, file, metadata_size_hint).await?; - let file_metadata = metadata.file_metadata(); - let schema = parquet_to_arrow_schema( - file_metadata.schema_descr(), - file_metadata.key_value_metadata(), - )?; - Ok(schema) } /// Read and parse the statistics of the Parquet file at location `path` /// /// See [`statistics_from_parquet_meta_calc`] for more details +#[deprecated( + since = "50.0.0", + note = "Use `DFParquetMetadata::fetch_statistics` instead" +)] pub async fn fetch_statistics( store: &dyn ObjectStore, table_schema: SchemaRef, file: &ObjectMeta, metadata_size_hint: Option, + decryption_properties: Option<&FileDecryptionProperties>, + file_metadata_cache: Option>, ) -> Result { - let metadata = fetch_parquet_metadata(store, file, metadata_size_hint).await?; - statistics_from_parquet_meta_calc(&metadata, table_schema) + DFParquetMetadata::new(store, file) + .with_metadata_size_hint(metadata_size_hint) + .with_decryption_properties(decryption_properties) + .with_file_metadata_cache(file_metadata_cache) + .fetch_statistics(&table_schema) + .await } -/// Convert statistics in [`ParquetMetaData`] into [`Statistics`] using [`StatisticsConverter`] -/// -/// The statistics are calculated for each column in the table schema -/// using the row group statistics in the parquet metadata. -/// -/// # Key behaviors: -/// -/// 1. Extracts row counts and byte sizes from all row groups -/// 2. Applies schema type coercions to align file schema with table schema -/// 3. Collects and aggregates statistics across row groups when available -/// -/// # When there are no statistics: -/// -/// If the Parquet file doesn't contain any statistics (has_statistics is false), the function returns a Statistics object with: -/// - Exact row count -/// - Exact byte size -/// - All column statistics marked as unknown via Statistics::unknown_column(&table_schema) -/// # When only some columns have statistics: -/// -/// For columns with statistics: -/// - Min/max values are properly extracted and represented as Precision::Exact -/// - Null counts are calculated by summing across row groups -/// -/// For columns without statistics, -/// - For min/max, there are two situations: -/// 1. The column isn't in arrow schema, then min/max values are set to Precision::Absent -/// 2. The column is in arrow schema, but not in parquet schema due to schema revolution, min/max values are set to Precision::Exact(null) -/// - Null counts are set to Precision::Exact(num_rows) (conservatively assuming all values could be null) +#[deprecated( + since = "50.0.0", + note = "Use `DFParquetMetadata::statistics_from_parquet_metadata` instead" +)] pub fn statistics_from_parquet_meta_calc( metadata: &ParquetMetaData, table_schema: SchemaRef, ) -> Result { - let row_groups_metadata = metadata.row_groups(); - - let mut statistics = Statistics::new_unknown(&table_schema); - let mut has_statistics = false; - let mut num_rows = 0_usize; - let mut total_byte_size = 0_usize; - for row_group_meta in row_groups_metadata { - num_rows += row_group_meta.num_rows() as usize; - total_byte_size += row_group_meta.total_byte_size() as usize; - - if !has_statistics { - has_statistics = row_group_meta - .columns() - .iter() - .any(|column| column.statistics().is_some()); - } - } - statistics.num_rows = Precision::Exact(num_rows); - statistics.total_byte_size = Precision::Exact(total_byte_size); - - let file_metadata = metadata.file_metadata(); - let mut file_schema = parquet_to_arrow_schema( - file_metadata.schema_descr(), - file_metadata.key_value_metadata(), - )?; - - if let Some(merged) = apply_file_schema_type_coercions(&table_schema, &file_schema) { - file_schema = merged; - } - - statistics.column_statistics = if has_statistics { - let (mut max_accs, mut min_accs) = create_max_min_accs(&table_schema); - let mut null_counts_array = - vec![Precision::Exact(0); table_schema.fields().len()]; - - table_schema - .fields() - .iter() - .enumerate() - .for_each(|(idx, field)| { - match StatisticsConverter::try_new( - field.name(), - &file_schema, - file_metadata.schema_descr(), - ) { - Ok(stats_converter) => { - summarize_min_max_null_counts( - &mut min_accs, - &mut max_accs, - &mut null_counts_array, - idx, - num_rows, - &stats_converter, - row_groups_metadata, - ) - .ok(); - } - Err(e) => { - debug!("Failed to create statistics converter: {}", e); - null_counts_array[idx] = Precision::Exact(num_rows); - } - } - }); - - get_col_stats( - &table_schema, - null_counts_array, - &mut max_accs, - &mut min_accs, - ) - } else { - Statistics::unknown_column(&table_schema) - }; - - Ok(statistics) -} - -fn get_col_stats( - schema: &Schema, - null_counts: Vec>, - max_values: &mut [Option], - min_values: &mut [Option], -) -> Vec { - (0..schema.fields().len()) - .map(|i| { - let max_value = match max_values.get_mut(i).unwrap() { - Some(max_value) => max_value.evaluate().ok(), - None => None, - }; - let min_value = match min_values.get_mut(i).unwrap() { - Some(min_value) => min_value.evaluate().ok(), - None => None, - }; - ColumnStatistics { - null_count: null_counts[i], - max_value: max_value.map(Precision::Exact).unwrap_or(Precision::Absent), - min_value: min_value.map(Precision::Exact).unwrap_or(Precision::Absent), - sum_value: Precision::Absent, - distinct_count: Precision::Absent, - } - }) - .collect() -} - -fn summarize_min_max_null_counts( - min_accs: &mut [Option], - max_accs: &mut [Option], - null_counts_array: &mut [Precision], - arrow_schema_index: usize, - num_rows: usize, - stats_converter: &StatisticsConverter, - row_groups_metadata: &[RowGroupMetaData], -) -> Result<()> { - let max_values = stats_converter.row_group_maxes(row_groups_metadata)?; - let min_values = stats_converter.row_group_mins(row_groups_metadata)?; - let null_counts = stats_converter.row_group_null_counts(row_groups_metadata)?; - - if let Some(max_acc) = &mut max_accs[arrow_schema_index] { - max_acc.update_batch(&[max_values])?; - } - - if let Some(min_acc) = &mut min_accs[arrow_schema_index] { - min_acc.update_batch(&[min_values])?; - } - - null_counts_array[arrow_schema_index] = Precision::Exact(match sum(&null_counts) { - Some(null_count) => null_count as usize, - None => num_rows, - }); - - Ok(()) + DFParquetMetadata::statistics_from_parquet_metadata(metadata, &table_schema) } /// Implements [`DataSink`] for writing to a parquet file. @@ -1014,15 +1123,12 @@ impl ParquetSink { /// Create writer properties based upon configuration settings, /// including partitioning and the inclusion of arrow schema metadata. - fn create_writer_props(&self) -> Result { - let schema = if self.parquet_options.global.allow_single_file_parallelism { - // If parallelizing writes, we may be also be doing hive style partitioning - // into multiple files which impacts the schema per file. - // Refer to `get_writer_schema()` - &get_writer_schema(&self.config) - } else { - self.config.output_schema() - }; + async fn create_writer_props( + &self, + runtime: &Arc, + path: &Path, + ) -> Result { + let schema = self.config.output_schema(); // TODO: avoid this clone in follow up PR, where the writer properties & schema // are calculated once on `ParquetSink::new` @@ -1031,7 +1137,16 @@ impl ParquetSink { parquet_opts.arrow_schema(schema); } - Ok(WriterPropertiesBuilder::try_from(&parquet_opts)?.build()) + let mut builder = WriterPropertiesBuilder::try_from(&parquet_opts)?; + builder = set_writer_encryption_properties( + builder, + runtime, + &parquet_opts, + schema, + path, + ) + .await?; + Ok(builder.build()) } /// Creates an AsyncArrowWriter which serializes a parquet file to an ObjectStore @@ -1040,9 +1155,18 @@ impl ParquetSink { &self, location: &Path, object_store: Arc, + context: &Arc, parquet_props: WriterProperties, ) -> Result> { - let buf_writer = BufWriter::new(object_store, location.clone()); + let buf_writer = BufWriter::with_capacity( + object_store, + location.clone(), + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + ); let options = ArrowWriterOptions::new() .with_properties(parquet_props) .with_skip_arrow_metadata(self.parquet_options.global.skip_arrow_metadata); @@ -1061,6 +1185,49 @@ impl ParquetSink { } } +#[cfg(feature = "parquet_encryption")] +async fn set_writer_encryption_properties( + builder: WriterPropertiesBuilder, + runtime: &Arc, + parquet_opts: &TableParquetOptions, + schema: &Arc, + path: &Path, +) -> Result { + if let Some(file_encryption_properties) = &parquet_opts.crypto.file_encryption { + // Encryption properties have been specified directly + return Ok(builder + .with_file_encryption_properties(file_encryption_properties.clone().into())); + } else if let Some(encryption_factory_id) = &parquet_opts.crypto.factory_id.as_ref() { + // Encryption properties will be generated by an encryption factory + let encryption_factory = + runtime.parquet_encryption_factory(encryption_factory_id)?; + let file_encryption_properties = encryption_factory + .get_file_encryption_properties( + &parquet_opts.crypto.factory_options, + schema, + path, + ) + .await?; + if let Some(file_encryption_properties) = file_encryption_properties { + return Ok( + builder.with_file_encryption_properties(file_encryption_properties) + ); + } + } + Ok(builder) +} + +#[cfg(not(feature = "parquet_encryption"))] +async fn set_writer_encryption_properties( + builder: WriterPropertiesBuilder, + _runtime: &Arc, + _parquet_opts: &TableParquetOptions, + _schema: &Arc, + _path: &Path, +) -> Result { + Ok(builder) +} + #[async_trait] impl FileSink for ParquetSink { fn config(&self) -> &FileSinkConfig { @@ -1075,14 +1242,12 @@ impl FileSink for ParquetSink { object_store: Arc, ) -> Result { let parquet_opts = &self.parquet_options; - let allow_single_file_parallelism = - parquet_opts.global.allow_single_file_parallelism; let mut file_write_tasks: JoinSet< std::result::Result<(Path, FileMetaData), DataFusionError>, > = JoinSet::new(); - let parquet_props = self.create_writer_props()?; + let runtime = context.runtime_env(); let parallel_options = ParallelParquetWriterOptions { max_parallel_row_groups: parquet_opts .global @@ -1093,17 +1258,18 @@ impl FileSink for ParquetSink { }; while let Some((path, mut rx)) = file_stream_rx.recv().await { - if !allow_single_file_parallelism { + let parquet_props = self.create_writer_props(&runtime, &path).await?; + if !parquet_opts.global.allow_single_file_parallelism { let mut writer = self .create_async_arrow_writer( &path, Arc::clone(&object_store), + context, parquet_props.clone(), ) .await?; - let mut reservation = - MemoryConsumer::new(format!("ParquetSink[{}]", path)) - .register(context.memory_pool()); + let mut reservation = MemoryConsumer::new(format!("ParquetSink[{path}]")) + .register(context.memory_pool()); file_write_tasks.spawn(async move { while let Some(batch) = rx.recv().await { writer.write(&batch).await?; @@ -1112,20 +1278,28 @@ impl FileSink for ParquetSink { let file_metadata = writer .close() .await - .map_err(DataFusionError::ParquetError)?; + .map_err(|e| DataFusionError::ParquetError(Box::new(e)))?; Ok((path, file_metadata)) }); } else { - let writer = create_writer( + let writer = ObjectWriterBuilder::new( // Parquet files as a whole are never compressed, since they // manage compressed blocks themselves. FileCompressionType::UNCOMPRESSED, &path, Arc::clone(&object_store), ) - .await?; + .with_buffer_size(Some( + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + )) + .build()?; let schema = get_writer_schema(&self.config); let props = parquet_props.clone(); + let skip_arrow_metadata = self.parquet_options.global.skip_arrow_metadata; let parallel_options_clone = parallel_options.clone(); let pool = Arc::clone(context.memory_pool()); file_write_tasks.spawn(async move { @@ -1134,6 +1308,7 @@ impl FileSink for ParquetSink { rx, schema, &props, + skip_arrow_metadata, parallel_options_clone, pool, ) @@ -1168,7 +1343,7 @@ impl FileSink for ParquetSink { demux_task .join_unwind() .await - .map_err(DataFusionError::ExecutionJoin)??; + .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; Ok(row_count as u64) } @@ -1214,13 +1389,10 @@ type ColSender = Sender; /// Returns join handles for each columns serialization task along with a send channel /// to send arrow arrays to each serialization task. fn spawn_column_parallel_row_group_writer( - schema: Arc, - parquet_props: Arc, + col_writers: Vec, max_buffer_size: usize, pool: &Arc, ) -> Result<(Vec, Vec)> { - let schema_desc = ArrowSchemaConverter::new().convert(&schema)?; - let col_writers = get_column_writers(&schema_desc, &parquet_props, &schema)?; let num_columns = col_writers.len(); let mut col_writer_tasks = Vec::with_capacity(num_columns); @@ -1296,7 +1468,7 @@ fn spawn_rg_join_and_finalize_task( let (writer, _col_reservation) = task .join_unwind() .await - .map_err(DataFusionError::ExecutionJoin)??; + .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; let encoded_size = writer.get_estimated_total_bytes(); rg_reservation.grow(encoded_size); finalized_rg.push(writer.close()?); @@ -1315,6 +1487,7 @@ fn spawn_rg_join_and_finalize_task( /// across both columns and row_groups, with a theoretical max number of parallel tasks /// given by n_columns * num_row_groups. fn spawn_parquet_parallel_serialization_task( + row_group_writer_factory: ArrowRowGroupWriterFactory, mut data: Receiver, serialize_tx: Sender>, schema: Arc, @@ -1325,13 +1498,11 @@ fn spawn_parquet_parallel_serialization_task( SpawnedTask::spawn(async move { let max_buffer_rb = parallel_options.max_buffered_record_batches_per_stream; let max_row_group_rows = writer_props.max_row_group_size(); + let mut row_group_index = 0; + let col_writers = + row_group_writer_factory.create_column_writers(row_group_index)?; let (mut column_writer_handles, mut col_array_channels) = - spawn_column_parallel_row_group_writer( - Arc::clone(&schema), - Arc::clone(&writer_props), - max_buffer_rb, - &pool, - )?; + spawn_column_parallel_row_group_writer(col_writers, max_buffer_rb, &pool)?; let mut current_rg_rows = 0; while let Some(mut rb) = data.recv().await { @@ -1377,10 +1548,12 @@ fn spawn_parquet_parallel_serialization_task( current_rg_rows = 0; rb = rb.slice(rows_left, rb.num_rows() - rows_left); + row_group_index += 1; + let col_writers = row_group_writer_factory + .create_column_writers(row_group_index)?; (column_writer_handles, col_array_channels) = spawn_column_parallel_row_group_writer( - Arc::clone(&schema), - Arc::clone(&writer_props), + col_writers, max_buffer_rb, &pool, )?; @@ -1411,29 +1584,21 @@ fn spawn_parquet_parallel_serialization_task( /// Consume RowGroups serialized by other parallel tasks and concatenate them in /// to the final parquet file, while flushing finalized bytes to an [ObjectStore] async fn concatenate_parallel_row_groups( + mut parquet_writer: SerializedFileWriter, + merged_buff: SharedBuffer, mut serialize_rx: Receiver>, - schema: Arc, - writer_props: Arc, mut object_store_writer: Box, pool: Arc, ) -> Result { - let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); - let mut file_reservation = MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool); - let schema_desc = ArrowSchemaConverter::new().convert(schema.as_ref())?; - let mut parquet_writer = SerializedFileWriter::new( - merged_buff.clone(), - schema_desc.root_schema_ptr(), - writer_props, - )?; - while let Some(task) = serialize_rx.recv().await { let result = task.join_unwind().await; - let mut rg_out = parquet_writer.next_row_group()?; let (serialized_columns, mut rg_reservation, _cnt) = - result.map_err(DataFusionError::ExecutionJoin)??; + result.map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; + + let mut rg_out = parquet_writer.next_row_group()?; for chunk in serialized_columns { chunk.append_to_row_group(&mut rg_out)?; rg_reservation.free(); @@ -1471,6 +1636,7 @@ async fn output_single_parquet_file_parallelized( data: Receiver, output_schema: Arc, parquet_props: &WriterProperties, + skip_arrow_metadata: bool, parallel_options: ParallelParquetWriterOptions, pool: Arc, ) -> Result { @@ -1480,7 +1646,19 @@ async fn output_single_parquet_file_parallelized( mpsc::channel::>(max_rowgroups); let arc_props = Arc::new(parquet_props.clone()); + let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); + let options = ArrowWriterOptions::new() + .with_properties(parquet_props.clone()) + .with_skip_arrow_metadata(skip_arrow_metadata); + let writer = ArrowWriter::try_new_with_options( + merged_buff.clone(), + Arc::clone(&output_schema), + options, + )?; + let (writer, row_group_writer_factory) = writer.into_serialized_writer()?; + let launch_serialization_task = spawn_parquet_parallel_serialization_task( + row_group_writer_factory, data, serialize_tx, Arc::clone(&output_schema), @@ -1489,9 +1667,9 @@ async fn output_single_parquet_file_parallelized( Arc::clone(&pool), ); let file_metadata = concatenate_parallel_row_groups( + writer, + merged_buff, serialize_rx, - Arc::clone(&output_schema), - Arc::clone(&arc_props), object_store_writer, pool, ) @@ -1500,38 +1678,224 @@ async fn output_single_parquet_file_parallelized( launch_serialization_task .join_unwind() .await - .map_err(DataFusionError::ExecutionJoin)??; + .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; Ok(file_metadata) } -/// Min/max aggregation can take Dictionary encode input but always produces unpacked -/// (aka non Dictionary) output. We need to adjust the output data type to reflect this. -/// The reason min/max aggregate produces unpacked output because there is only one -/// min/max value per group; there is no needs to keep them Dictionary encode -fn min_max_aggregate_data_type(input_type: &DataType) -> &DataType { - if let DataType::Dictionary(_, value_type) = input_type { - value_type.as_ref() - } else { - input_type +#[cfg(test)] +mod tests { + use parquet::arrow::parquet_to_arrow_schema; + use std::sync::Arc; + + use super::*; + + use arrow::datatypes::DataType; + use parquet::schema::parser::parse_message_type; + + #[test] + fn coerce_int96_to_resolution_with_mixed_timestamps() { + // Unclear if Spark (or other writer) could generate a file with mixed timestamps like this, + // but we want to test the scenario just in case since it's at least a valid schema as far + // as the Parquet spec is concerned. + let spark_schema = " + message spark_schema { + optional int96 c0; + optional int64 c1 (TIMESTAMP(NANOS,true)); + optional int64 c2 (TIMESTAMP(NANOS,false)); + optional int64 c3 (TIMESTAMP(MILLIS,true)); + optional int64 c4 (TIMESTAMP(MILLIS,false)); + optional int64 c5 (TIMESTAMP(MICROS,true)); + optional int64 c6 (TIMESTAMP(MICROS,false)); + } + "; + + let schema = parse_message_type(spark_schema).expect("should parse schema"); + let descr = SchemaDescriptor::new(Arc::new(schema)); + + let arrow_schema = parquet_to_arrow_schema(&descr, None).unwrap(); + + let result = + coerce_int96_to_resolution(&descr, &arrow_schema, &TimeUnit::Microsecond) + .unwrap(); + + // Only the first field (c0) should be converted to a microsecond timestamp because it's the + // only timestamp that originated from an INT96. + let expected_schema = Schema::new(vec![ + Field::new("c0", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new( + "c1", + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), + true, + ), + Field::new("c2", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + Field::new( + "c3", + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + true, + ), + Field::new("c4", DataType::Timestamp(TimeUnit::Millisecond, None), true), + Field::new( + "c5", + DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), + true, + ), + Field::new("c6", DataType::Timestamp(TimeUnit::Microsecond, None), true), + ]); + + assert_eq!(result, expected_schema); } -} -fn create_max_min_accs( - schema: &Schema, -) -> (Vec>, Vec>) { - let max_values: Vec> = schema - .fields() - .iter() - .map(|field| { - MaxAccumulator::try_new(min_max_aggregate_data_type(field.data_type())).ok() - }) - .collect(); - let min_values: Vec> = schema - .fields() - .iter() - .map(|field| { - MinAccumulator::try_new(min_max_aggregate_data_type(field.data_type())).ok() - }) - .collect(); - (max_values, min_values) + #[test] + fn coerce_int96_to_resolution_with_nested_types() { + // This schema is derived from Comet's CometFuzzTestSuite ParquetGenerator only using int96 + // primitive types with generateStruct, generateArray, and generateMap set to true, with one + // additional field added to c4's struct to make sure all fields in a struct get modified. + // https://github.com/apache/datafusion-comet/blob/main/spark/src/main/scala/org/apache/comet/testing/ParquetGenerator.scala + let spark_schema = " + message spark_schema { + optional int96 c0; + optional group c1 { + optional int96 c0; + } + optional group c2 { + optional group c0 (LIST) { + repeated group list { + optional int96 element; + } + } + } + optional group c3 (LIST) { + repeated group list { + optional int96 element; + } + } + optional group c4 (LIST) { + repeated group list { + optional group element { + optional int96 c0; + optional int96 c1; + } + } + } + optional group c5 (MAP) { + repeated group key_value { + required int96 key; + optional int96 value; + } + } + optional group c6 (LIST) { + repeated group list { + optional group element (MAP) { + repeated group key_value { + required int96 key; + optional int96 value; + } + } + } + } + } + "; + + let schema = parse_message_type(spark_schema).expect("should parse schema"); + let descr = SchemaDescriptor::new(Arc::new(schema)); + + let arrow_schema = parquet_to_arrow_schema(&descr, None).unwrap(); + + let result = + coerce_int96_to_resolution(&descr, &arrow_schema, &TimeUnit::Microsecond) + .unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new("c0", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new_struct( + "c1", + vec![Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )], + true, + ), + Field::new_struct( + "c2", + vec![Field::new_list( + "c0", + Field::new( + "element", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + true, + )], + true, + ), + Field::new_list( + "c3", + Field::new( + "element", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + true, + ), + Field::new_list( + "c4", + Field::new_struct( + "element", + vec![ + Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "c1", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + ], + true, + ), + true, + ), + Field::new_map( + "c5", + "key_value", + Field::new( + "key", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "value", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + false, + true, + ), + Field::new_list( + "c6", + Field::new_map( + "element", + "key_value", + Field::new( + "key", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "value", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + false, + true, + ), + true, + ), + ]); + + assert_eq!(result, expected_schema); + } } diff --git a/datafusion/datasource-parquet/src/metadata.rs b/datafusion/datasource-parquet/src/metadata.rs new file mode 100644 index 0000000000000..4de68793ce02a --- /dev/null +++ b/datafusion/datasource-parquet/src/metadata.rs @@ -0,0 +1,568 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`DFParquetMetadata`] for fetching Parquet file metadata, statistics +//! and schema information. + +use crate::{ + apply_file_schema_type_coercions, coerce_int96_to_resolution, ObjectStoreFetch, +}; +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::compute::and; +use arrow::compute::kernels::cmp::eq; +use arrow::compute::sum; +use arrow::datatypes::{DataType, Schema, SchemaRef, TimeUnit}; +use datafusion_common::encryption::FileDecryptionProperties; +use datafusion_common::stats::Precision; +use datafusion_common::{ + ColumnStatistics, DataFusionError, Result, ScalarValue, Statistics, +}; +use datafusion_execution::cache::cache_manager::{FileMetadata, FileMetadataCache}; +use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumulator}; +use datafusion_physical_plan::Accumulator; +use log::debug; +use object_store::path::Path; +use object_store::{ObjectMeta, ObjectStore}; +use parquet::arrow::arrow_reader::statistics::StatisticsConverter; +use parquet::arrow::parquet_to_arrow_schema; +use parquet::file::metadata::{ + PageIndexPolicy, ParquetMetaData, ParquetMetaDataReader, RowGroupMetaData, +}; +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + +/// Handles fetching Parquet file schema, metadata and statistics +/// from object store. +/// +/// This component is exposed for low level integrations through +/// [`ParquetFileReaderFactory`]. +/// +/// [`ParquetFileReaderFactory`]: crate::ParquetFileReaderFactory +#[derive(Debug)] +pub struct DFParquetMetadata<'a> { + store: &'a dyn ObjectStore, + object_meta: &'a ObjectMeta, + metadata_size_hint: Option, + decryption_properties: Option<&'a FileDecryptionProperties>, + file_metadata_cache: Option>, + /// timeunit to coerce INT96 timestamps to + pub coerce_int96: Option, +} + +impl<'a> DFParquetMetadata<'a> { + pub fn new(store: &'a dyn ObjectStore, object_meta: &'a ObjectMeta) -> Self { + Self { + store, + object_meta, + metadata_size_hint: None, + decryption_properties: None, + file_metadata_cache: None, + coerce_int96: None, + } + } + + /// set metadata size hint + pub fn with_metadata_size_hint(mut self, metadata_size_hint: Option) -> Self { + self.metadata_size_hint = metadata_size_hint; + self + } + + /// set decryption properties + pub fn with_decryption_properties( + mut self, + decryption_properties: Option<&'a FileDecryptionProperties>, + ) -> Self { + self.decryption_properties = decryption_properties; + self + } + + /// set file metadata cache + pub fn with_file_metadata_cache( + mut self, + file_metadata_cache: Option>, + ) -> Self { + self.file_metadata_cache = file_metadata_cache; + self + } + + /// Set timeunit to coerce INT96 timestamps to + pub fn with_coerce_int96(mut self, time_unit: Option) -> Self { + self.coerce_int96 = time_unit; + self + } + + /// Fetch parquet metadata from the remote object store + pub async fn fetch_metadata(&self) -> Result> { + let Self { + store, + object_meta, + metadata_size_hint, + decryption_properties, + file_metadata_cache, + coerce_int96: _, + } = self; + + let fetch = ObjectStoreFetch::new(*store, object_meta); + + // implementation to fetch parquet metadata + let cache_metadata = + !cfg!(feature = "parquet_encryption") || decryption_properties.is_none(); + + if cache_metadata { + if let Some(parquet_metadata) = file_metadata_cache + .as_ref() + .and_then(|file_metadata_cache| file_metadata_cache.get(object_meta)) + .and_then(|file_metadata| { + file_metadata + .as_any() + .downcast_ref::() + .map(|cached_parquet_metadata| { + Arc::clone(cached_parquet_metadata.parquet_metadata()) + }) + }) + { + return Ok(parquet_metadata); + } + } + + let mut reader = + ParquetMetaDataReader::new().with_prefetch_hint(*metadata_size_hint); + + #[cfg(feature = "parquet_encryption")] + if let Some(decryption_properties) = decryption_properties { + reader = reader.with_decryption_properties(Some(decryption_properties)); + } + + if cache_metadata && file_metadata_cache.is_some() { + // Need to retrieve the entire metadata for the caching to be effective. + reader = reader.with_page_index_policy(PageIndexPolicy::Optional); + } + + let metadata = Arc::new( + reader + .load_and_finish(fetch, object_meta.size) + .await + .map_err(DataFusionError::from)?, + ); + + if cache_metadata { + if let Some(file_metadata_cache) = file_metadata_cache { + file_metadata_cache.put( + object_meta, + Arc::new(CachedParquetMetaData::new(Arc::clone(&metadata))), + ); + } + } + + Ok(metadata) + } + + /// Read and parse the schema of the Parquet file + pub async fn fetch_schema(&self) -> Result { + let metadata = self.fetch_metadata().await?; + + let file_metadata = metadata.file_metadata(); + let schema = parquet_to_arrow_schema( + file_metadata.schema_descr(), + file_metadata.key_value_metadata(), + )?; + let schema = self + .coerce_int96 + .as_ref() + .and_then(|time_unit| { + coerce_int96_to_resolution( + file_metadata.schema_descr(), + &schema, + time_unit, + ) + }) + .unwrap_or(schema); + Ok(schema) + } + + /// Return (path, schema) tuple by fetching the schema from Parquet file + pub(crate) async fn fetch_schema_with_location(&self) -> Result<(Path, Schema)> { + let loc_path = self.object_meta.location.clone(); + let schema = self.fetch_schema().await?; + Ok((loc_path, schema)) + } + + /// Fetch the metadata from the Parquet file via [`Self::fetch_metadata`] and convert + /// the statistics in the metadata using [`Self::statistics_from_parquet_metadata`] + pub async fn fetch_statistics(&self, table_schema: &SchemaRef) -> Result { + let metadata = self.fetch_metadata().await?; + Self::statistics_from_parquet_metadata(&metadata, table_schema) + } + + /// Convert statistics in [`ParquetMetaData`] into [`Statistics`] using [`StatisticsConverter`] + /// + /// The statistics are calculated for each column in the table schema + /// using the row group statistics in the parquet metadata. + /// + /// # Key behaviors: + /// + /// 1. Extracts row counts and byte sizes from all row groups + /// 2. Applies schema type coercions to align file schema with table schema + /// 3. Collects and aggregates statistics across row groups when available + /// + /// # When there are no statistics: + /// + /// If the Parquet file doesn't contain any statistics (has_statistics is false), the function returns a Statistics object with: + /// - Exact row count + /// - Exact byte size + /// - All column statistics marked as unknown via Statistics::unknown_column(&table_schema) + /// # When only some columns have statistics: + /// + /// For columns with statistics: + /// - Min/max values are properly extracted and represented as Precision::Exact + /// - Null counts are calculated by summing across row groups + /// + /// For columns without statistics, + /// - For min/max, there are two situations: + /// 1. The column isn't in arrow schema, then min/max values are set to Precision::Absent + /// 2. The column is in arrow schema, but not in parquet schema due to schema revolution, min/max values are set to Precision::Exact(null) + /// - Null counts are set to Precision::Exact(num_rows) (conservatively assuming all values could be null) + pub fn statistics_from_parquet_metadata( + metadata: &ParquetMetaData, + table_schema: &SchemaRef, + ) -> Result { + let row_groups_metadata = metadata.row_groups(); + + let mut statistics = Statistics::new_unknown(table_schema); + let mut has_statistics = false; + let mut num_rows = 0_usize; + let mut total_byte_size = 0_usize; + for row_group_meta in row_groups_metadata { + num_rows += row_group_meta.num_rows() as usize; + total_byte_size += row_group_meta.total_byte_size() as usize; + + if !has_statistics { + has_statistics = row_group_meta + .columns() + .iter() + .any(|column| column.statistics().is_some()); + } + } + statistics.num_rows = Precision::Exact(num_rows); + statistics.total_byte_size = Precision::Exact(total_byte_size); + + let file_metadata = metadata.file_metadata(); + let mut file_schema = parquet_to_arrow_schema( + file_metadata.schema_descr(), + file_metadata.key_value_metadata(), + )?; + + if let Some(merged) = apply_file_schema_type_coercions(table_schema, &file_schema) + { + file_schema = merged; + } + + statistics.column_statistics = if has_statistics { + let (mut max_accs, mut min_accs) = create_max_min_accs(table_schema); + let mut null_counts_array = + vec![Precision::Exact(0); table_schema.fields().len()]; + let mut is_max_value_exact = vec![Some(true); table_schema.fields().len()]; + let mut is_min_value_exact = vec![Some(true); table_schema.fields().len()]; + table_schema + .fields() + .iter() + .enumerate() + .for_each(|(idx, field)| { + match StatisticsConverter::try_new( + field.name(), + &file_schema, + file_metadata.schema_descr(), + ) { + Ok(stats_converter) => { + let mut accumulators = StatisticsAccumulators { + min_accs: &mut min_accs, + max_accs: &mut max_accs, + null_counts_array: &mut null_counts_array, + is_min_value_exact: &mut is_min_value_exact, + is_max_value_exact: &mut is_max_value_exact, + }; + summarize_min_max_null_counts( + &mut accumulators, + idx, + num_rows, + &stats_converter, + row_groups_metadata, + ) + .ok(); + } + Err(e) => { + debug!("Failed to create statistics converter: {e}"); + null_counts_array[idx] = Precision::Exact(num_rows); + } + } + }); + + get_col_stats( + table_schema, + null_counts_array, + &mut max_accs, + &mut min_accs, + &mut is_max_value_exact, + &mut is_min_value_exact, + ) + } else { + Statistics::unknown_column(table_schema) + }; + + Ok(statistics) + } +} + +/// Min/max aggregation can take Dictionary encode input but always produces unpacked +/// (aka non Dictionary) output. We need to adjust the output data type to reflect this. +/// The reason min/max aggregate produces unpacked output because there is only one +/// min/max value per group; there is no needs to keep them Dictionary encoded +fn min_max_aggregate_data_type(input_type: &DataType) -> &DataType { + if let DataType::Dictionary(_, value_type) = input_type { + value_type.as_ref() + } else { + input_type + } +} + +fn create_max_min_accs( + schema: &Schema, +) -> (Vec>, Vec>) { + let max_values: Vec> = schema + .fields() + .iter() + .map(|field| { + MaxAccumulator::try_new(min_max_aggregate_data_type(field.data_type())).ok() + }) + .collect(); + let min_values: Vec> = schema + .fields() + .iter() + .map(|field| { + MinAccumulator::try_new(min_max_aggregate_data_type(field.data_type())).ok() + }) + .collect(); + (max_values, min_values) +} + +fn get_col_stats( + schema: &Schema, + null_counts: Vec>, + max_values: &mut [Option], + min_values: &mut [Option], + is_max_value_exact: &mut [Option], + is_min_value_exact: &mut [Option], +) -> Vec { + (0..schema.fields().len()) + .map(|i| { + let max_value = match ( + max_values.get_mut(i).unwrap(), + is_max_value_exact.get(i).unwrap(), + ) { + (Some(max_value), Some(true)) => { + max_value.evaluate().ok().map(Precision::Exact) + } + (Some(max_value), Some(false)) | (Some(max_value), None) => { + max_value.evaluate().ok().map(Precision::Inexact) + } + (None, _) => None, + }; + let min_value = match ( + min_values.get_mut(i).unwrap(), + is_min_value_exact.get(i).unwrap(), + ) { + (Some(min_value), Some(true)) => { + min_value.evaluate().ok().map(Precision::Exact) + } + (Some(min_value), Some(false)) | (Some(min_value), None) => { + min_value.evaluate().ok().map(Precision::Inexact) + } + (None, _) => None, + }; + ColumnStatistics { + null_count: null_counts[i], + max_value: max_value.unwrap_or(Precision::Absent), + min_value: min_value.unwrap_or(Precision::Absent), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + } + }) + .collect() +} + +/// Holds the accumulator state for collecting statistics from row groups +struct StatisticsAccumulators<'a> { + min_accs: &'a mut [Option], + max_accs: &'a mut [Option], + null_counts_array: &'a mut [Precision], + is_min_value_exact: &'a mut [Option], + is_max_value_exact: &'a mut [Option], +} + +fn summarize_min_max_null_counts( + accumulators: &mut StatisticsAccumulators, + arrow_schema_index: usize, + num_rows: usize, + stats_converter: &StatisticsConverter, + row_groups_metadata: &[RowGroupMetaData], +) -> Result<()> { + let max_values = stats_converter.row_group_maxes(row_groups_metadata)?; + let min_values = stats_converter.row_group_mins(row_groups_metadata)?; + let null_counts = stats_converter.row_group_null_counts(row_groups_metadata)?; + let is_max_value_exact_stat = + stats_converter.row_group_is_max_value_exact(row_groups_metadata)?; + let is_min_value_exact_stat = + stats_converter.row_group_is_min_value_exact(row_groups_metadata)?; + + if let Some(max_acc) = &mut accumulators.max_accs[arrow_schema_index] { + max_acc.update_batch(&[Arc::clone(&max_values)])?; + let mut cur_max_acc = max_acc.clone(); + accumulators.is_max_value_exact[arrow_schema_index] = has_any_exact_match( + cur_max_acc.evaluate()?, + max_values, + is_max_value_exact_stat, + ); + } + + if let Some(min_acc) = &mut accumulators.min_accs[arrow_schema_index] { + min_acc.update_batch(&[Arc::clone(&min_values)])?; + let mut cur_min_acc = min_acc.clone(); + accumulators.is_min_value_exact[arrow_schema_index] = has_any_exact_match( + cur_min_acc.evaluate()?, + min_values, + is_min_value_exact_stat, + ); + } + + accumulators.null_counts_array[arrow_schema_index] = + Precision::Exact(match sum(&null_counts) { + Some(null_count) => null_count as usize, + None => num_rows, + }); + + Ok(()) +} + +/// Checks if any occurrence of `value` in `array` corresponds to a `true` +/// entry in the `exactness` array. +/// +/// This is used to determine if a calculated statistic (e.g., min or max) +/// is exact, by checking if at least one of its source values was exact. +/// +/// # Example +/// - `value`: `0` +/// - `array`: `[0, 1, 0, 3, 0, 5]` +/// - `exactness`: `[true, false, false, false, false, false]` +/// +/// The value `0` appears at indices `[0, 2, 4]`. The corresponding exactness +/// values are `[true, false, false]`. Since at least one is `true`, the +/// function returns `Some(true)`. +fn has_any_exact_match( + value: ScalarValue, + array: ArrayRef, + exactness: BooleanArray, +) -> Option { + let scalar_array = value.to_scalar().ok()?; + let eq_mask = eq(&scalar_array, &array).ok()?; + let combined_mask = and(&eq_mask, &exactness).ok()?; + Some(combined_mask.true_count() > 0) +} + +/// Wrapper to implement [`FileMetadata`] for [`ParquetMetaData`]. +pub struct CachedParquetMetaData(Arc); + +impl CachedParquetMetaData { + pub fn new(metadata: Arc) -> Self { + Self(metadata) + } + + pub fn parquet_metadata(&self) -> &Arc { + &self.0 + } +} + +impl FileMetadata for CachedParquetMetaData { + fn as_any(&self) -> &dyn Any { + self + } + + fn memory_size(&self) -> usize { + self.0.memory_size() + } + + fn extra_info(&self) -> HashMap { + let page_index = + self.0.column_index().is_some() && self.0.offset_index().is_some(); + HashMap::from([("page_index".to_owned(), page_index.to_string())]) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ArrayRef, BooleanArray, Int32Array}; + use datafusion_common::ScalarValue; + use std::sync::Arc; + + #[test] + fn test_has_any_exact_match() { + // Case 1: Mixed exact and inexact matches + { + let computed_min = ScalarValue::Int32(Some(0)); + let row_group_mins = + Arc::new(Int32Array::from(vec![0, 1, 0, 3, 0, 5])) as ArrayRef; + let exactness = + BooleanArray::from(vec![true, false, false, false, false, false]); + + let result = has_any_exact_match(computed_min, row_group_mins, exactness); + assert_eq!(result, Some(true)); + } + // Case 2: All inexact matches + { + let computed_min = ScalarValue::Int32(Some(0)); + let row_group_mins = + Arc::new(Int32Array::from(vec![0, 1, 0, 3, 0, 5])) as ArrayRef; + let exactness = + BooleanArray::from(vec![false, false, false, false, false, false]); + + let result = has_any_exact_match(computed_min, row_group_mins, exactness); + assert_eq!(result, Some(false)); + } + // Case 3: All exact matches + { + let computed_max = ScalarValue::Int32(Some(5)); + let row_group_maxes = + Arc::new(Int32Array::from(vec![1, 5, 3, 5, 2, 5])) as ArrayRef; + let exactness = + BooleanArray::from(vec![false, true, true, true, false, true]); + + let result = has_any_exact_match(computed_max, row_group_maxes, exactness); + assert_eq!(result, Some(true)); + } + // Case 4: All maxes are null values + { + let computed_max = ScalarValue::Int32(None); + let row_group_maxes = + Arc::new(Int32Array::from(vec![None, None, None, None])) as ArrayRef; + let exactness = BooleanArray::from(vec![None, Some(true), None, Some(false)]); + + let result = has_any_exact_match(computed_max, row_group_maxes, exactness); + assert_eq!(result, Some(false)); + } + } +} diff --git a/datafusion/datasource-parquet/src/metrics.rs b/datafusion/datasource-parquet/src/metrics.rs index 3213d0201295a..d75a979d4cad0 100644 --- a/datafusion/datasource-parquet/src/metrics.rs +++ b/datafusion/datasource-parquet/src/metrics.rs @@ -27,6 +27,21 @@ use datafusion_physical_plan::metrics::{ /// [`ParquetFileReaderFactory`]: super::ParquetFileReaderFactory #[derive(Debug, Clone)] pub struct ParquetFileMetrics { + /// Number of file **ranges** pruned by partition or file level statistics. + /// Pruning of files often happens at planning time but may happen at execution time + /// if dynamic filters (e.g. from a join) result in additional pruning. + /// + /// This does **not** necessarily equal the number of files pruned: + /// files may be scanned in sub-ranges to increase parallelism, + /// in which case this will represent the number of sub-ranges pruned, not the number of files. + /// The number of files pruned will always be less than or equal to this number. + /// + /// A single file may have some ranges that are not pruned and some that are pruned. + /// For example, with a query like `ORDER BY col LIMIT 10`, the TopK dynamic filter + /// pushdown optimization may fill up the TopK heap when reading the first part of a file, + /// then skip the second part if file statistics indicate it cannot contain rows + /// that would be in the TopK. + pub files_ranges_pruned_statistics: Count, /// Number of times the predicate could not be evaluated pub predicate_evaluation_errors: Count, /// Number of row groups whose bloom filters were checked and matched (not pruned) @@ -57,6 +72,13 @@ pub struct ParquetFileMetrics { pub page_index_eval_time: Time, /// Total time spent reading and parsing metadata from the footer pub metadata_load_time: Time, + /// Predicate Cache: number of records read directly from the inner reader. + /// This is the number of rows decoded while evaluating predicates + pub predicate_cache_inner_records: Count, + /// Predicate Cache: number of records read from the cache. This is the + /// number of rows that were stored in the cache after evaluating predicates + /// reused for the output. + pub predicate_cache_records: Count, } impl ParquetFileMetrics { @@ -122,7 +144,19 @@ impl ParquetFileMetrics { .with_new_label("filename", filename.to_string()) .subset_time("metadata_load_time", partition); + let files_ranges_pruned_statistics = MetricBuilder::new(metrics) + .counter("files_ranges_pruned_statistics", partition); + + let predicate_cache_inner_records = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .counter("predicate_cache_inner_records", partition); + + let predicate_cache_records = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .counter("predicate_cache_records", partition); + Self { + files_ranges_pruned_statistics, predicate_evaluation_errors, row_groups_matched_bloom_filter, row_groups_pruned_bloom_filter, @@ -138,6 +172,8 @@ impl ParquetFileMetrics { bloom_filter_eval_time, page_index_eval_time, metadata_load_time, + predicate_cache_inner_records, + predicate_cache_records, } } } diff --git a/datafusion/datasource-parquet/src/mod.rs b/datafusion/datasource-parquet/src/mod.rs index 516b13792189b..2f64f34bc09b4 100644 --- a/datafusion/datasource-parquet/src/mod.rs +++ b/datafusion/datasource-parquet/src/mod.rs @@ -19,10 +19,9 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -//! [`ParquetExec`] FileSource for reading Parquet files - pub mod access_plan; pub mod file_format; +pub mod metadata; mod metrics; mod opener; mod page_filter; @@ -32,520 +31,12 @@ mod row_group_filter; pub mod source; mod writer; -use std::any::Any; -use std::fmt::Formatter; -use std::sync::Arc; - pub use access_plan::{ParquetAccessPlan, RowGroupAccess}; -use arrow::datatypes::SchemaRef; -use datafusion_common::config::{ConfigOptions, TableParquetOptions}; -use datafusion_common::Result; -use datafusion_common::{Constraints, Statistics}; -use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::schema_adapter::SchemaAdapterFactory; -use datafusion_datasource::source::DataSourceExec; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{ - EquivalenceProperties, LexOrdering, Partitioning, PhysicalExpr, -}; -use datafusion_physical_optimizer::pruning::PruningPredicate; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::MetricsSet; -use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, -}; pub use file_format::*; pub use metrics::ParquetFileMetrics; pub use page_filter::PagePruningAccessPlanFilter; -pub use reader::{DefaultParquetFileReaderFactory, ParquetFileReaderFactory}; +pub use reader::*; // Expose so downstream crates can use it pub use row_filter::build_row_filter; pub use row_filter::can_expr_be_pushed_down_with_schemas; pub use row_group_filter::RowGroupAccessPlanFilter; -use source::ParquetSource; pub use writer::plan_to_parquet; - -use datafusion_datasource::file_groups::FileGroup; -use log::debug; - -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -/// Deprecated Execution plan replaced with DataSourceExec -pub struct ParquetExec { - inner: DataSourceExec, - base_config: FileScanConfig, - table_parquet_options: TableParquetOptions, - /// Optional predicate for row filtering during parquet scan - predicate: Option>, - /// Optional predicate for pruning row groups (derived from `predicate`) - pruning_predicate: Option>, - /// Optional user defined parquet file reader factory - parquet_file_reader_factory: Option>, - /// Optional user defined schema adapter - schema_adapter_factory: Option>, -} - -#[allow(unused, deprecated)] -impl From for ParquetExecBuilder { - fn from(exec: ParquetExec) -> Self { - exec.into_builder() - } -} - -/// [`ParquetExecBuilder`], deprecated builder for [`ParquetExec`]. -/// -/// ParquetExec is replaced with `DataSourceExec` and it includes `ParquetSource` -/// -/// See example on [`ParquetSource`]. -#[deprecated( - since = "46.0.0", - note = "use DataSourceExec with ParquetSource instead" -)] -#[allow(unused, deprecated)] -pub struct ParquetExecBuilder { - file_scan_config: FileScanConfig, - predicate: Option>, - metadata_size_hint: Option, - table_parquet_options: TableParquetOptions, - parquet_file_reader_factory: Option>, - schema_adapter_factory: Option>, -} - -#[allow(unused, deprecated)] -impl ParquetExecBuilder { - /// Create a new builder to read the provided file scan configuration - pub fn new(file_scan_config: FileScanConfig) -> Self { - Self::new_with_options(file_scan_config, TableParquetOptions::default()) - } - - /// Create a new builder to read the data specified in the file scan - /// configuration with the provided `TableParquetOptions`. - pub fn new_with_options( - file_scan_config: FileScanConfig, - table_parquet_options: TableParquetOptions, - ) -> Self { - Self { - file_scan_config, - predicate: None, - metadata_size_hint: None, - table_parquet_options, - parquet_file_reader_factory: None, - schema_adapter_factory: None, - } - } - - /// Update the list of files groups to read - pub fn with_file_groups(mut self, file_groups: Vec) -> Self { - self.file_scan_config.file_groups = file_groups; - self - } - - /// Set the filter predicate when reading. - /// - /// See the "Predicate Pushdown" section of the [`ParquetExec`] documentation - /// for more details. - pub fn with_predicate(mut self, predicate: Arc) -> Self { - self.predicate = Some(predicate); - self - } - - /// Set the metadata size hint - /// - /// This value determines how many bytes at the end of the file the default - /// [`ParquetFileReaderFactory`] will request in the initial IO. If this is - /// too small, the ParquetExec will need to make additional IO requests to - /// read the footer. - pub fn with_metadata_size_hint(mut self, metadata_size_hint: usize) -> Self { - self.metadata_size_hint = Some(metadata_size_hint); - self - } - - /// Set the options for controlling how the ParquetExec reads parquet files. - /// - /// See also [`Self::new_with_options`] - pub fn with_table_parquet_options( - mut self, - table_parquet_options: TableParquetOptions, - ) -> Self { - self.table_parquet_options = table_parquet_options; - self - } - - /// Set optional user defined parquet file reader factory. - /// - /// You can use [`ParquetFileReaderFactory`] to more precisely control how - /// data is read from parquet files (e.g. skip re-reading metadata, coalesce - /// I/O operations, etc). - /// - /// The default reader factory reads directly from an [`ObjectStore`] - /// instance using individual I/O operations for the footer and each page. - /// - /// If a custom `ParquetFileReaderFactory` is provided, then data access - /// operations will be routed to this factory instead of [`ObjectStore`]. - /// - /// [`ObjectStore`]: object_store::ObjectStore - pub fn with_parquet_file_reader_factory( - mut self, - parquet_file_reader_factory: Arc, - ) -> Self { - self.parquet_file_reader_factory = Some(parquet_file_reader_factory); - self - } - - /// Set optional schema adapter factory. - /// - /// [`SchemaAdapterFactory`] allows user to specify how fields from the - /// parquet file get mapped to that of the table schema. The default schema - /// adapter uses arrow's cast library to map the parquet fields to the table - /// schema. - pub fn with_schema_adapter_factory( - mut self, - schema_adapter_factory: Arc, - ) -> Self { - self.schema_adapter_factory = Some(schema_adapter_factory); - self - } - - /// Convenience: build an `Arc`d `ParquetExec` from this builder - pub fn build_arc(self) -> Arc { - Arc::new(self.build()) - } - - /// Build a [`ParquetExec`] - #[must_use] - pub fn build(self) -> ParquetExec { - let Self { - file_scan_config, - predicate, - metadata_size_hint, - table_parquet_options, - parquet_file_reader_factory, - schema_adapter_factory, - } = self; - let mut parquet = ParquetSource::new(table_parquet_options); - if let Some(predicate) = predicate.clone() { - parquet = parquet - .with_predicate(Arc::clone(&file_scan_config.file_schema), predicate); - } - if let Some(metadata_size_hint) = metadata_size_hint { - parquet = parquet.with_metadata_size_hint(metadata_size_hint) - } - if let Some(parquet_reader_factory) = parquet_file_reader_factory { - parquet = parquet.with_parquet_file_reader_factory(parquet_reader_factory) - } - if let Some(schema_factory) = schema_adapter_factory { - parquet = parquet.with_schema_adapter_factory(schema_factory); - } - - let base_config = file_scan_config.with_source(Arc::new(parquet.clone())); - debug!("Creating ParquetExec, files: {:?}, projection {:?}, predicate: {:?}, limit: {:?}", - base_config.file_groups, base_config.projection, predicate, base_config.limit); - - ParquetExec { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - base_config, - predicate, - pruning_predicate: parquet.pruning_predicate, - schema_adapter_factory: parquet.schema_adapter_factory, - parquet_file_reader_factory: parquet.parquet_file_reader_factory, - table_parquet_options: parquet.table_parquet_options, - } - } -} - -#[allow(unused, deprecated)] -impl ParquetExec { - /// Create a new Parquet reader execution plan provided file list and schema. - pub fn new( - base_config: FileScanConfig, - predicate: Option>, - metadata_size_hint: Option, - table_parquet_options: TableParquetOptions, - ) -> Self { - let mut builder = - ParquetExecBuilder::new_with_options(base_config, table_parquet_options); - if let Some(predicate) = predicate { - builder = builder.with_predicate(predicate); - } - if let Some(metadata_size_hint) = metadata_size_hint { - builder = builder.with_metadata_size_hint(metadata_size_hint); - } - builder.build() - } - /// Return a [`ParquetExecBuilder`]. - /// - /// See example on [`ParquetExec`] and [`ParquetExecBuilder`] for specifying - /// parquet table options. - pub fn builder(file_scan_config: FileScanConfig) -> ParquetExecBuilder { - ParquetExecBuilder::new(file_scan_config) - } - - /// Convert this `ParquetExec` into a builder for modification - pub fn into_builder(self) -> ParquetExecBuilder { - // list out fields so it is clear what is being dropped - // (note the fields which are dropped are re-created as part of calling - // `build` on the builder) - let file_scan_config = self.file_scan_config(); - let parquet = self.parquet_source(); - - ParquetExecBuilder { - file_scan_config, - predicate: parquet.predicate, - metadata_size_hint: parquet.metadata_size_hint, - table_parquet_options: parquet.table_parquet_options, - parquet_file_reader_factory: parquet.parquet_file_reader_factory, - schema_adapter_factory: parquet.schema_adapter_factory, - } - } - fn file_scan_config(&self) -> FileScanConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn parquet_source(&self) -> ParquetSource { - self.file_scan_config() - .file_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - /// [`FileScanConfig`] that controls this scan (such as which files to read) - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - /// Options passed to the parquet reader for this scan - pub fn table_parquet_options(&self) -> &TableParquetOptions { - &self.table_parquet_options - } - /// Optional predicate. - pub fn predicate(&self) -> Option<&Arc> { - self.predicate.as_ref() - } - /// Optional reference to this parquet scan's pruning predicate - pub fn pruning_predicate(&self) -> Option<&Arc> { - self.pruning_predicate.as_ref() - } - /// return the optional file reader factory - pub fn parquet_file_reader_factory( - &self, - ) -> Option<&Arc> { - self.parquet_file_reader_factory.as_ref() - } - /// Optional user defined parquet file reader factory. - pub fn with_parquet_file_reader_factory( - mut self, - parquet_file_reader_factory: Arc, - ) -> Self { - let mut parquet = self.parquet_source(); - parquet.parquet_file_reader_factory = - Some(Arc::clone(&parquet_file_reader_factory)); - let file_source = self.file_scan_config(); - self.inner = self - .inner - .with_data_source(Arc::new(file_source.with_source(Arc::new(parquet)))); - self.parquet_file_reader_factory = Some(parquet_file_reader_factory); - self - } - /// return the optional schema adapter factory - pub fn schema_adapter_factory(&self) -> Option<&Arc> { - self.schema_adapter_factory.as_ref() - } - /// Set optional schema adapter factory. - /// - /// [`SchemaAdapterFactory`] allows user to specify how fields from the - /// parquet file get mapped to that of the table schema. The default schema - /// adapter uses arrow's cast library to map the parquet fields to the table - /// schema. - pub fn with_schema_adapter_factory( - mut self, - schema_adapter_factory: Arc, - ) -> Self { - let mut parquet = self.parquet_source(); - parquet.schema_adapter_factory = Some(Arc::clone(&schema_adapter_factory)); - let file_source = self.file_scan_config(); - self.inner = self - .inner - .with_data_source(Arc::new(file_source.with_source(Arc::new(parquet)))); - self.schema_adapter_factory = Some(schema_adapter_factory); - self - } - /// If true, the predicate will be used during the parquet scan. - /// Defaults to false - /// - /// [`Expr`]: datafusion_expr::Expr - pub fn with_pushdown_filters(mut self, pushdown_filters: bool) -> Self { - let mut parquet = self.parquet_source(); - parquet.table_parquet_options.global.pushdown_filters = pushdown_filters; - let file_source = self.file_scan_config(); - self.inner = self - .inner - .with_data_source(Arc::new(file_source.with_source(Arc::new(parquet)))); - self.table_parquet_options.global.pushdown_filters = pushdown_filters; - self - } - - /// Return the value described in [`Self::with_pushdown_filters`] - fn pushdown_filters(&self) -> bool { - self.parquet_source() - .table_parquet_options - .global - .pushdown_filters - } - /// If true, the `RowFilter` made by `pushdown_filters` may try to - /// minimize the cost of filter evaluation by reordering the - /// predicate [`Expr`]s. If false, the predicates are applied in - /// the same order as specified in the query. Defaults to false. - /// - /// [`Expr`]: datafusion_expr::Expr - pub fn with_reorder_filters(mut self, reorder_filters: bool) -> Self { - let mut parquet = self.parquet_source(); - parquet.table_parquet_options.global.reorder_filters = reorder_filters; - let file_source = self.file_scan_config(); - self.inner = self - .inner - .with_data_source(Arc::new(file_source.with_source(Arc::new(parquet)))); - self.table_parquet_options.global.reorder_filters = reorder_filters; - self - } - /// Return the value described in [`Self::with_reorder_filters`] - fn reorder_filters(&self) -> bool { - self.parquet_source() - .table_parquet_options - .global - .reorder_filters - } - /// If enabled, the reader will read the page index - /// This is used to optimize filter pushdown - /// via `RowSelector` and `RowFilter` by - /// eliminating unnecessary IO and decoding - fn bloom_filter_on_read(&self) -> bool { - self.parquet_source() - .table_parquet_options - .global - .bloom_filter_on_read - } - /// Return the value described in [`ParquetSource::with_enable_page_index`] - fn enable_page_index(&self) -> bool { - self.parquet_source() - .table_parquet_options - .global - .enable_page_index - } - - fn output_partitioning_helper(file_config: &FileScanConfig) -> Partitioning { - Partitioning::UnknownPartitioning(file_config.file_groups.len()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - file_config: &FileScanConfig, - ) -> PlanProperties { - PlanProperties::new( - EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints), - Self::output_partitioning_helper(file_config), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } - - /// Updates the file groups to read and recalculates the output partitioning - /// - /// Note this function does not update statistics or other properties - /// that depend on the file groups. - fn with_file_groups_and_update_partitioning( - mut self, - file_groups: Vec, - ) -> Self { - let mut config = self.file_scan_config(); - config.file_groups = file_groups; - self.inner = self.inner.with_data_source(Arc::new(config)); - self - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for ParquetExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for ParquetExec { - fn name(&self) -> &'static str { - "ParquetExec" - } - - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - - fn children(&self) -> Vec<&Arc> { - // this is a leaf node and has no children - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - /// Redistribute files across partitions according to their size - /// See comments on `FileGroupPartitioner` for more detail. - fn repartitioned( - &self, - target_partitions: usize, - config: &ConfigOptions, - ) -> Result>> { - self.inner.repartitioned(target_partitions, config) - } - - fn execute( - &self, - partition_index: usize, - ctx: Arc, - ) -> Result { - self.inner.execute(partition_index, ctx) - } - fn metrics(&self) -> Option { - self.inner.metrics() - } - fn statistics(&self) -> Result { - self.inner.statistics() - } - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } -} - -fn should_enable_page_index( - enable_page_index: bool, - page_pruning_predicate: &Option>, -) -> bool { - enable_page_index - && page_pruning_predicate.is_some() - && page_pruning_predicate - .as_ref() - .map(|p| p.filter_number() > 0) - .unwrap_or(false) -} diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index 732fef47d5a75..167fc3c5147e9 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -17,30 +17,44 @@ //! [`ParquetOpener`] for opening Parquet files -use std::sync::Arc; - use crate::page_filter::PagePruningAccessPlanFilter; use crate::row_group_filter::RowGroupAccessPlanFilter; use crate::{ - apply_file_schema_type_coercions, row_filter, should_enable_page_index, + apply_file_schema_type_coercions, coerce_int96_to_resolution, row_filter, ParquetAccessPlan, ParquetFileMetrics, ParquetFileReaderFactory, }; -use datafusion_datasource::file_meta::FileMeta; +use arrow::array::RecordBatch; use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener}; use datafusion_datasource::schema_adapter::SchemaAdapterFactory; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; -use arrow::datatypes::SchemaRef; -use arrow::error::ArrowError; -use datafusion_common::{exec_err, Result}; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_optimizer::pruning::PruningPredicate; -use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +use arrow::datatypes::{FieldRef, SchemaRef, TimeUnit}; +use datafusion_common::encryption::FileDecryptionProperties; -use futures::{StreamExt, TryStreamExt}; +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_datasource::PartitionedFile; +use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; +use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; +use datafusion_physical_expr_common::physical_expr::{ + is_dynamic_physical_expr, PhysicalExpr, +}; +use datafusion_physical_plan::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder}; +use datafusion_pruning::{build_pruning_predicate, FilePruner, PruningPredicate}; + +#[cfg(feature = "parquet_encryption")] +use datafusion_common::config::EncryptionFactoryOptions; +#[cfg(feature = "parquet_encryption")] +use datafusion_execution::parquet_encryption::EncryptionFactory; +use futures::{ready, Stream, StreamExt, TryStreamExt}; +use itertools::Itertools; use log::debug; +use parquet::arrow::arrow_reader::metrics::ArrowReaderMetrics; use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}; use parquet::arrow::async_reader::AsyncFileReader; use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; +use parquet::file::metadata::{PageIndexPolicy, ParquetMetaDataReader}; /// Implements [`FileOpener`] for a parquet file pub(super) struct ParquetOpener { @@ -54,12 +68,11 @@ pub(super) struct ParquetOpener { pub limit: Option, /// Optional predicate to apply during the scan pub predicate: Option>, - /// Optional pruning predicate applied to row group statistics - pub pruning_predicate: Option>, - /// Optional pruning predicate applied to data page statistics - pub page_pruning_predicate: Option>, - /// Schema of the output table - pub table_schema: SchemaRef, + /// Schema of the output table without partition columns. + /// This is the schema we coerce the physical file schema into. + pub logical_file_schema: SchemaRef, + /// Partition columns + pub partition_fields: Vec, /// Optional hint for how large the initial request to read parquet metadata /// should be pub metadata_size_hint: Option, @@ -80,22 +93,41 @@ pub(super) struct ParquetOpener { pub enable_bloom_filter: bool, /// Schema adapter factory pub schema_adapter_factory: Arc, + /// Should row group pruning be applied + pub enable_row_group_stats_pruning: bool, + /// Coerce INT96 timestamps to specific TimeUnit + pub coerce_int96: Option, + /// Optional parquet FileDecryptionProperties + #[cfg(feature = "parquet_encryption")] + pub file_decryption_properties: Option>, + /// Rewrite expressions in the context of the file schema + pub(crate) expr_adapter_factory: Option>, + /// Optional factory to create file decryption properties dynamically + #[cfg(feature = "parquet_encryption")] + pub encryption_factory: + Option<(Arc, EncryptionFactoryOptions)>, + /// Maximum size of the predicate cache, in bytes. If none, uses + /// the arrow-rs default. + pub max_predicate_cache_size: Option, } impl FileOpener for ParquetOpener { - fn open(&self, file_meta: FileMeta) -> Result { - let file_range = file_meta.range.clone(); - let extensions = file_meta.extensions.clone(); - let file_name = file_meta.location().to_string(); + fn open(&self, partitioned_file: PartitionedFile) -> Result { + let file_range = partitioned_file.range.clone(); + let extensions = partitioned_file.extensions.clone(); + let file_location = partitioned_file.object_meta.location.clone(); + let file_name = file_location.to_string(); let file_metrics = ParquetFileMetrics::new(self.partition_index, &file_name, &self.metrics); - let metadata_size_hint = file_meta.metadata_size_hint.or(self.metadata_size_hint); + let metadata_size_hint = partitioned_file + .metadata_size_hint + .or(self.metadata_size_hint); - let mut reader: Box = + let mut async_file_reader: Box = self.parquet_file_reader_factory.create_reader( self.partition_index, - file_meta, + partitioned_file.clone(), metadata_size_hint, &self.metrics, )?; @@ -103,53 +135,184 @@ impl FileOpener for ParquetOpener { let batch_size = self.batch_size; let projected_schema = - SchemaRef::from(self.table_schema.project(&self.projection)?); + SchemaRef::from(self.logical_file_schema.project(&self.projection)?); let schema_adapter_factory = Arc::clone(&self.schema_adapter_factory); let schema_adapter = self .schema_adapter_factory - .create(projected_schema, Arc::clone(&self.table_schema)); - let predicate = self.predicate.clone(); - let pruning_predicate = self.pruning_predicate.clone(); - let page_pruning_predicate = self.page_pruning_predicate.clone(); - let table_schema = Arc::clone(&self.table_schema); + .create(projected_schema, Arc::clone(&self.logical_file_schema)); + let mut predicate = self.predicate.clone(); + let logical_file_schema = Arc::clone(&self.logical_file_schema); + let partition_fields = self.partition_fields.clone(); let reorder_predicates = self.reorder_filters; let pushdown_filters = self.pushdown_filters; - let enable_page_index = should_enable_page_index( - self.enable_page_index, - &self.page_pruning_predicate, - ); + let coerce_int96 = self.coerce_int96; let enable_bloom_filter = self.enable_bloom_filter; + let enable_row_group_stats_pruning = self.enable_row_group_stats_pruning; let limit = self.limit; + let predicate_creation_errors = MetricBuilder::new(&self.metrics) + .global_counter("num_predicate_creation_errors"); + + let expr_adapter_factory = self.expr_adapter_factory.clone(); + let mut predicate_file_schema = Arc::clone(&self.logical_file_schema); + + let enable_page_index = self.enable_page_index; + #[cfg(feature = "parquet_encryption")] + let encryption_context = self.get_encryption_context(); + let max_predicate_cache_size = self.max_predicate_cache_size; + Ok(Box::pin(async move { - let options = ArrowReaderOptions::new().with_page_index(enable_page_index); + #[cfg(feature = "parquet_encryption")] + let file_decryption_properties = encryption_context + .get_file_decryption_properties(&file_location) + .await?; + + // Prune this file using the file level statistics and partition values. + // Since dynamic filters may have been updated since planning it is possible that we are able + // to prune files now that we couldn't prune at planning time. + // It is assumed that there is no point in doing pruning here if the predicate is not dynamic, + // as it would have been done at planning time. + // We'll also check this after every record batch we read, + // and if at some point we are able to prove we can prune the file using just the file level statistics + // we can end the stream early. + let mut file_pruner = predicate + .as_ref() + .map(|p| { + Ok::<_, DataFusionError>( + (is_dynamic_physical_expr(p) | partitioned_file.has_statistics()) + .then_some(FilePruner::new( + Arc::clone(p), + &logical_file_schema, + partition_fields.clone(), + partitioned_file.clone(), + predicate_creation_errors.clone(), + )?), + ) + }) + .transpose()? + .flatten(); + + if let Some(file_pruner) = &mut file_pruner { + if file_pruner.should_prune()? { + // Return an empty stream immediately to skip the work of setting up the actual stream + file_metrics.files_ranges_pruned_statistics.add(1); + return Ok(futures::stream::empty().boxed()); + } + } + // Don't load the page index yet. Since it is not stored inline in + // the footer, loading the page index if it is not needed will do + // unnecessary I/O. We decide later if it is needed to evaluate the + // pruning predicates. Thus default to not requesting if from the + // underlying reader. + let mut options = ArrowReaderOptions::new().with_page_index(false); + #[cfg(feature = "parquet_encryption")] + if let Some(fd_val) = file_decryption_properties { + options = options.with_file_decryption_properties((*fd_val).clone()); + } let mut metadata_timer = file_metrics.metadata_load_time.timer(); - let metadata = - ArrowReaderMetadata::load_async(&mut reader, options.clone()).await?; - let mut schema = Arc::clone(metadata.schema()); - // read with view types - if let Some(merged) = apply_file_schema_type_coercions(&table_schema, &schema) - { - schema = Arc::new(merged); + // Begin by loading the metadata from the underlying reader (note + // the returned metadata may actually include page indexes as some + // readers may return page indexes even when not requested -- for + // example when they are cached) + let mut reader_metadata = + ArrowReaderMetadata::load_async(&mut async_file_reader, options.clone()) + .await?; + + // Note about schemas: we are actually dealing with **3 different schemas** here: + // - The table schema as defined by the TableProvider. + // This is what the user sees, what they get when they `SELECT * FROM table`, etc. + // - The logical file schema: this is the table schema minus any hive partition columns and projections. + // This is what the physicalfile schema is coerced to. + // - The physical file schema: this is the schema as defined by the parquet file. This is what the parquet file actually contains. + let mut physical_file_schema = Arc::clone(reader_metadata.schema()); + + // The schema loaded from the file may not be the same as the + // desired schema (for example if we want to instruct the parquet + // reader to read strings using Utf8View instead). Update if necessary + if let Some(merged) = apply_file_schema_type_coercions( + &logical_file_schema, + &physical_file_schema, + ) { + physical_file_schema = Arc::new(merged); + options = options.with_schema(Arc::clone(&physical_file_schema)); + reader_metadata = ArrowReaderMetadata::try_new( + Arc::clone(reader_metadata.metadata()), + options.clone(), + )?; } - let options = ArrowReaderOptions::new() - .with_page_index(enable_page_index) - .with_schema(Arc::clone(&schema)); - let metadata = - ArrowReaderMetadata::try_new(Arc::clone(metadata.metadata()), options)?; + if let Some(ref coerce) = coerce_int96 { + if let Some(merged) = coerce_int96_to_resolution( + reader_metadata.parquet_schema(), + &physical_file_schema, + coerce, + ) { + physical_file_schema = Arc::new(merged); + options = options.with_schema(Arc::clone(&physical_file_schema)); + reader_metadata = ArrowReaderMetadata::try_new( + Arc::clone(reader_metadata.metadata()), + options.clone(), + )?; + } + } - metadata_timer.stop(); + // Adapt the predicate to the physical file schema. + // This evaluates missing columns and inserts any necessary casts. + if let Some(expr_adapter_factory) = expr_adapter_factory { + predicate = predicate + .map(|p| { + let partition_values = partition_fields + .iter() + .cloned() + .zip(partitioned_file.partition_values) + .collect_vec(); + let expr = expr_adapter_factory + .create( + Arc::clone(&logical_file_schema), + Arc::clone(&physical_file_schema), + ) + .with_partition_values(partition_values) + .rewrite(p)?; + // After rewriting to the file schema, further simplifications may be possible. + // For example, if `'a' = col_that_is_missing` becomes `'a' = NULL` that can then be simplified to `FALSE` + // and we can avoid doing any more work on the file (bloom filters, loading the page index, etc.). + PhysicalExprSimplifier::new(&physical_file_schema).simplify(expr) + }) + .transpose()?; + predicate_file_schema = Arc::clone(&physical_file_schema); + } - let mut builder = - ParquetRecordBatchStreamBuilder::new_with_metadata(reader, metadata); + // Build predicates for this specific file + let (pruning_predicate, page_pruning_predicate) = build_pruning_predicates( + predicate.as_ref(), + &predicate_file_schema, + &predicate_creation_errors, + ); - let file_schema = Arc::clone(builder.schema()); + // The page index is not stored inline in the parquet footer so the + // code above may not have read the page index structures yet. If we + // need them for reading and they aren't yet loaded, we need to load them now. + if should_enable_page_index(enable_page_index, &page_pruning_predicate) { + reader_metadata = load_page_index( + reader_metadata, + &mut async_file_reader, + // Since we're manually loading the page index the option here should not matter but we pass it in for consistency + options.with_page_index(true), + ) + .await?; + } + + metadata_timer.stop(); + + let mut builder = ParquetRecordBatchStreamBuilder::new_with_metadata( + async_file_reader, + reader_metadata, + ); let (schema_mapping, adapted_projections) = - schema_adapter.map_schema(&file_schema)?; + schema_adapter.map_schema(&physical_file_schema)?; let mask = ProjectionMask::roots( builder.parquet_schema(), @@ -160,8 +323,8 @@ impl FileOpener for ParquetOpener { if let Some(predicate) = pushdown_filters.then_some(predicate).flatten() { let row_filter = row_filter::build_row_filter( &predicate, - &file_schema, - &table_schema, + &physical_file_schema, + &predicate_file_schema, builder.metadata(), reorder_predicates, &file_metrics, @@ -175,8 +338,7 @@ impl FileOpener for ParquetOpener { Ok(None) => {} Err(e) => { debug!( - "Ignoring error building row filter for '{:?}': {}", - predicate, e + "Ignoring error building row filter for '{predicate:?}': {e}" ); } }; @@ -197,18 +359,20 @@ impl FileOpener for ParquetOpener { } // If there is a predicate that can be evaluated against the metadata if let Some(predicate) = predicate.as_ref() { - row_groups.prune_by_statistics( - &file_schema, - builder.parquet_schema(), - rg_metadata, - predicate, - &file_metrics, - ); + if enable_row_group_stats_pruning { + row_groups.prune_by_statistics( + &physical_file_schema, + builder.parquet_schema(), + rg_metadata, + predicate, + &file_metrics, + ); + } if enable_bloom_filter && !row_groups.is_empty() { row_groups .prune_by_bloom_filters( - &file_schema, + &physical_file_schema, &mut builder, predicate, &file_metrics, @@ -226,7 +390,7 @@ impl FileOpener for ParquetOpener { if let Some(p) = page_pruning_predicate { access_plan = p.prune_plan_with_page_index( access_plan, - &file_schema, + &physical_file_schema, builder.parquet_schema(), file_metadata.as_ref(), &file_metrics, @@ -245,24 +409,213 @@ impl FileOpener for ParquetOpener { builder = builder.with_limit(limit) } + if let Some(max_predicate_cache_size) = max_predicate_cache_size { + builder = builder.with_max_predicate_cache_size(max_predicate_cache_size); + } + + // metrics from the arrow reader itself + let arrow_reader_metrics = ArrowReaderMetrics::enabled(); + let stream = builder .with_projection(mask) .with_batch_size(batch_size) .with_row_groups(row_group_indexes) + .with_metrics(arrow_reader_metrics.clone()) .build()?; - let adapted = stream - .map_err(|e| ArrowError::ExternalError(Box::new(e))) - .map(move |maybe_batch| { - maybe_batch - .and_then(|b| schema_mapping.map_batch(b).map_err(Into::into)) - }); + let files_ranges_pruned_statistics = + file_metrics.files_ranges_pruned_statistics.clone(); + let predicate_cache_inner_records = + file_metrics.predicate_cache_inner_records.clone(); + let predicate_cache_records = file_metrics.predicate_cache_records.clone(); - Ok(adapted.boxed()) + let stream = stream.map_err(DataFusionError::from).map(move |b| { + b.and_then(|b| { + copy_arrow_reader_metrics( + &arrow_reader_metrics, + &predicate_cache_inner_records, + &predicate_cache_records, + ); + schema_mapping.map_batch(b) + }) + }); + + if let Some(file_pruner) = file_pruner { + Ok(EarlyStoppingStream::new( + stream, + file_pruner, + files_ranges_pruned_statistics, + ) + .boxed()) + } else { + Ok(stream.boxed()) + } })) } } +/// Copies metrics from ArrowReaderMetrics (the metrics collected by the +/// arrow-rs parquet reader) to the parquet file metrics for DataFusion +fn copy_arrow_reader_metrics( + arrow_reader_metrics: &ArrowReaderMetrics, + predicate_cache_inner_records: &Count, + predicate_cache_records: &Count, +) { + if let Some(v) = arrow_reader_metrics.records_read_from_inner() { + predicate_cache_inner_records.add(v); + } + + if let Some(v) = arrow_reader_metrics.records_read_from_cache() { + predicate_cache_records.add(v); + } +} + +/// Wraps an inner RecordBatchStream and a [`FilePruner`] +/// +/// This can terminate the scan early when some dynamic filters is updated after +/// the scan starts, so we discover after the scan starts that the file can be +/// pruned (can't have matching rows). +struct EarlyStoppingStream { + /// Has the stream finished processing? All subsequent polls will return + /// None + done: bool, + file_pruner: FilePruner, + files_ranges_pruned_statistics: Count, + /// The inner stream + inner: S, +} + +impl EarlyStoppingStream { + pub fn new( + stream: S, + file_pruner: FilePruner, + files_ranges_pruned_statistics: Count, + ) -> Self { + Self { + done: false, + inner: stream, + file_pruner, + files_ranges_pruned_statistics, + } + } +} +impl EarlyStoppingStream +where + S: Stream> + Unpin, +{ + fn check_prune(&mut self, input: Result) -> Result> { + let batch = input?; + + // Since dynamic filters may have been updated, see if we can stop + // reading this stream entirely. + if self.file_pruner.should_prune()? { + self.files_ranges_pruned_statistics.add(1); + self.done = true; + Ok(None) + } else { + // Return the adapted batch + Ok(Some(batch)) + } + } +} + +impl Stream for EarlyStoppingStream +where + S: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + if self.done { + return Poll::Ready(None); + } + match ready!(self.inner.poll_next_unpin(cx)) { + None => { + // input done + self.done = true; + Poll::Ready(None) + } + Some(input_batch) => { + let output = self.check_prune(input_batch); + Poll::Ready(output.transpose()) + } + } + } +} + +#[derive(Default)] +#[cfg_attr(not(feature = "parquet_encryption"), allow(dead_code))] +struct EncryptionContext { + #[cfg(feature = "parquet_encryption")] + file_decryption_properties: Option>, + #[cfg(feature = "parquet_encryption")] + encryption_factory: Option<(Arc, EncryptionFactoryOptions)>, +} + +#[cfg(feature = "parquet_encryption")] +impl EncryptionContext { + fn new( + file_decryption_properties: Option>, + encryption_factory: Option<( + Arc, + EncryptionFactoryOptions, + )>, + ) -> Self { + Self { + file_decryption_properties, + encryption_factory, + } + } + + async fn get_file_decryption_properties( + &self, + file_location: &object_store::path::Path, + ) -> Result>> { + match &self.file_decryption_properties { + Some(file_decryption_properties) => { + Ok(Some(Arc::clone(file_decryption_properties))) + } + None => match &self.encryption_factory { + Some((encryption_factory, encryption_config)) => Ok(encryption_factory + .get_file_decryption_properties(encryption_config, file_location) + .await? + .map(Arc::new)), + None => Ok(None), + }, + } + } +} + +#[cfg(not(feature = "parquet_encryption"))] +#[allow(dead_code)] +impl EncryptionContext { + async fn get_file_decryption_properties( + &self, + _file_location: &object_store::path::Path, + ) -> Result>> { + Ok(None) + } +} + +impl ParquetOpener { + #[cfg(feature = "parquet_encryption")] + fn get_encryption_context(&self) -> EncryptionContext { + EncryptionContext::new( + self.file_decryption_properties.clone(), + self.encryption_factory.clone(), + ) + } + + #[cfg(not(feature = "parquet_encryption"))] + #[allow(dead_code)] + fn get_encryption_context(&self) -> EncryptionContext { + EncryptionContext::default() + } +} + /// Return the initial [`ParquetAccessPlan`] /// /// If the user has supplied one as an extension, use that @@ -295,3 +648,743 @@ fn create_initial_plan( // default to scanning all row groups Ok(ParquetAccessPlan::new_all(row_group_count)) } + +/// Build a page pruning predicate from an optional predicate expression. +/// If the predicate is None or the predicate cannot be converted to a page pruning +/// predicate, return None. +pub(crate) fn build_page_pruning_predicate( + predicate: &Arc, + file_schema: &SchemaRef, +) -> Arc { + Arc::new(PagePruningAccessPlanFilter::new( + predicate, + Arc::clone(file_schema), + )) +} + +pub(crate) fn build_pruning_predicates( + predicate: Option<&Arc>, + file_schema: &SchemaRef, + predicate_creation_errors: &Count, +) -> ( + Option>, + Option>, +) { + let Some(predicate) = predicate.as_ref() else { + return (None, None); + }; + let pruning_predicate = build_pruning_predicate( + Arc::clone(predicate), + file_schema, + predicate_creation_errors, + ); + let page_pruning_predicate = build_page_pruning_predicate(predicate, file_schema); + (pruning_predicate, Some(page_pruning_predicate)) +} + +/// Returns a `ArrowReaderMetadata` with the page index loaded, loading +/// it from the underlying `AsyncFileReader` if necessary. +async fn load_page_index( + reader_metadata: ArrowReaderMetadata, + input: &mut T, + options: ArrowReaderOptions, +) -> Result { + let parquet_metadata = reader_metadata.metadata(); + let missing_column_index = parquet_metadata.column_index().is_none(); + let missing_offset_index = parquet_metadata.offset_index().is_none(); + // You may ask yourself: why are we even checking if the page index is already loaded here? + // Didn't we explicitly *not* load it above? + // Well it's possible that a custom implementation of `AsyncFileReader` gives you + // the page index even if you didn't ask for it (e.g. because it's cached) + // so it's important to check that here to avoid extra work. + if missing_column_index || missing_offset_index { + let m = Arc::try_unwrap(Arc::clone(parquet_metadata)) + .unwrap_or_else(|e| e.as_ref().clone()); + let mut reader = ParquetMetaDataReader::new_with_metadata(m) + .with_page_index_policy(PageIndexPolicy::Optional); + reader.load_page_index(input).await?; + let new_parquet_metadata = reader.finish()?; + let new_arrow_reader = + ArrowReaderMetadata::try_new(Arc::new(new_parquet_metadata), options)?; + Ok(new_arrow_reader) + } else { + // No need to load the page index again, just return the existing metadata + Ok(reader_metadata) + } +} + +fn should_enable_page_index( + enable_page_index: bool, + page_pruning_predicate: &Option>, +) -> bool { + enable_page_index + && page_pruning_predicate.is_some() + && page_pruning_predicate + .as_ref() + .map(|p| p.filter_number() > 0) + .unwrap_or(false) +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::{ + compute::cast, + datatypes::{DataType, Field, Schema, SchemaRef}, + }; + use bytes::{BufMut, BytesMut}; + use datafusion_common::{ + assert_batches_eq, record_batch, stats::Precision, ColumnStatistics, + DataFusionError, ScalarValue, Statistics, + }; + use datafusion_datasource::{ + file_stream::FileOpener, + schema_adapter::{ + DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, + SchemaMapper, + }, + PartitionedFile, + }; + use datafusion_expr::{col, lit}; + use datafusion_physical_expr::{ + expressions::DynamicFilterPhysicalExpr, planner::logical2physical, PhysicalExpr, + }; + use datafusion_physical_expr_adapter::DefaultPhysicalExprAdapterFactory; + use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; + use futures::{Stream, StreamExt}; + use object_store::{memory::InMemory, path::Path, ObjectStore}; + use parquet::arrow::ArrowWriter; + + use crate::{opener::ParquetOpener, DefaultParquetFileReaderFactory}; + + async fn count_batches_and_rows( + mut stream: std::pin::Pin< + Box< + dyn Stream> + + Send, + >, + >, + ) -> (usize, usize) { + let mut num_batches = 0; + let mut num_rows = 0; + while let Some(Ok(batch)) = stream.next().await { + num_rows += batch.num_rows(); + num_batches += 1; + } + (num_batches, num_rows) + } + + async fn collect_batches( + mut stream: std::pin::Pin< + Box< + dyn Stream> + + Send, + >, + >, + ) -> Vec { + let mut batches = vec![]; + while let Some(Ok(batch)) = stream.next().await { + batches.push(batch); + } + batches + } + + async fn write_parquet( + store: Arc, + filename: &str, + batch: arrow::record_batch::RecordBatch, + ) -> usize { + let mut out = BytesMut::new().writer(); + { + let mut writer = + ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + } + let data = out.into_inner().freeze(); + let data_len = data.len(); + store.put(&Path::from(filename), data.into()).await.unwrap(); + data_len + } + + fn make_dynamic_expr(expr: Arc) -> Arc { + Arc::new(DynamicFilterPhysicalExpr::new( + expr.children().into_iter().map(Arc::clone).collect(), + expr, + )) + } + + #[tokio::test] + async fn test_prune_on_statistics() { + let store = Arc::new(InMemory::new()) as Arc; + + let batch = record_batch!( + ("a", Int32, vec![Some(1), Some(2), Some(2)]), + ("b", Float32, vec![Some(1.0), Some(2.0), None]) + ) + .unwrap(); + + let data_size = + write_parquet(Arc::clone(&store), "test.parquet", batch.clone()).await; + + let schema = batch.schema(); + let file = PartitionedFile::new( + "test.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ) + .with_statistics(Arc::new( + Statistics::new_unknown(&schema) + .add_column_statistics(ColumnStatistics::new_unknown()) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Float32(Some(1.0)))) + .with_max_value(Precision::Exact(ScalarValue::Float32(Some(2.0)))) + .with_null_count(Precision::Exact(1)), + ), + )); + + let make_opener = |predicate| { + ParquetOpener { + partition_index: 0, + projection: Arc::new([0, 1]), + batch_size: 1024, + limit: None, + predicate: Some(predicate), + logical_file_schema: schema.clone(), + metadata_size_hint: None, + metrics: ExecutionPlanMetricsSet::new(), + parquet_file_reader_factory: Arc::new( + DefaultParquetFileReaderFactory::new(Arc::clone(&store)), + ), + partition_fields: vec![], + pushdown_filters: false, // note that this is false! + reorder_filters: false, + enable_page_index: false, + enable_bloom_filter: false, + schema_adapter_factory: Arc::new(DefaultSchemaAdapterFactory), + enable_row_group_stats_pruning: true, + coerce_int96: None, + #[cfg(feature = "parquet_encryption")] + file_decryption_properties: None, + expr_adapter_factory: Some(Arc::new(DefaultPhysicalExprAdapterFactory)), + #[cfg(feature = "parquet_encryption")] + encryption_factory: None, + max_predicate_cache_size: None, + } + }; + + // A filter on "a" should not exclude any rows even if it matches the data + let expr = col("a").eq(lit(1)); + let predicate = logical2physical(&expr, &schema); + let opener = make_opener(predicate); + let stream = opener.open(file.clone()).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // A filter on `b = 5.0` should exclude all rows + let expr = col("b").eq(lit(ScalarValue::Float32(Some(5.0)))); + let predicate = logical2physical(&expr, &schema); + let opener = make_opener(predicate); + let stream = opener.open(file).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } + + #[tokio::test] + async fn test_prune_on_partition_statistics_with_dynamic_expression() { + let store = Arc::new(InMemory::new()) as Arc; + + let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(3)])).unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await; + + let file_schema = batch.schema(); + let mut file = PartitionedFile::new( + "part=1/file.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + file.partition_values = vec![ScalarValue::Int32(Some(1))]; + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("part", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + + let make_opener = |predicate| { + ParquetOpener { + partition_index: 0, + projection: Arc::new([0]), + batch_size: 1024, + limit: None, + predicate: Some(predicate), + logical_file_schema: file_schema.clone(), + metadata_size_hint: None, + metrics: ExecutionPlanMetricsSet::new(), + parquet_file_reader_factory: Arc::new( + DefaultParquetFileReaderFactory::new(Arc::clone(&store)), + ), + partition_fields: vec![Arc::new(Field::new( + "part", + DataType::Int32, + false, + ))], + pushdown_filters: false, // note that this is false! + reorder_filters: false, + enable_page_index: false, + enable_bloom_filter: false, + schema_adapter_factory: Arc::new(DefaultSchemaAdapterFactory), + enable_row_group_stats_pruning: true, + coerce_int96: None, + #[cfg(feature = "parquet_encryption")] + file_decryption_properties: None, + expr_adapter_factory: Some(Arc::new(DefaultPhysicalExprAdapterFactory)), + #[cfg(feature = "parquet_encryption")] + encryption_factory: None, + max_predicate_cache_size: None, + } + }; + + // Filter should match the partition value + let expr = col("part").eq(lit(1)); + // Mark the expression as dynamic even if it's not to force partition pruning to happen + // Otherwise we assume it already happened at the planning stage and won't re-do the work here + let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let opener = make_opener(predicate); + let stream = opener.open(file.clone()).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // Filter should not match the partition value + let expr = col("part").eq(lit(2)); + // Mark the expression as dynamic even if it's not to force partition pruning to happen + // Otherwise we assume it already happened at the planning stage and won't re-do the work here + let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let opener = make_opener(predicate); + let stream = opener.open(file).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } + + #[tokio::test] + async fn test_prune_on_partition_values_and_file_statistics() { + let store = Arc::new(InMemory::new()) as Arc; + + let batch = record_batch!( + ("a", Int32, vec![Some(1), Some(2), Some(3)]), + ("b", Float64, vec![Some(1.0), Some(2.0), None]) + ) + .unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await; + let file_schema = batch.schema(); + let mut file = PartitionedFile::new( + "part=1/file.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + file.partition_values = vec![ScalarValue::Int32(Some(1))]; + file.statistics = Some(Arc::new( + Statistics::new_unknown(&file_schema) + .add_column_statistics(ColumnStatistics::new_unknown()) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Float64(Some(1.0)))) + .with_max_value(Precision::Exact(ScalarValue::Float64(Some(2.0)))) + .with_null_count(Precision::Exact(1)), + ), + )); + let table_schema = Arc::new(Schema::new(vec![ + Field::new("part", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Float32, true), + ])); + let make_opener = |predicate| { + ParquetOpener { + partition_index: 0, + projection: Arc::new([0]), + batch_size: 1024, + limit: None, + predicate: Some(predicate), + logical_file_schema: file_schema.clone(), + metadata_size_hint: None, + metrics: ExecutionPlanMetricsSet::new(), + parquet_file_reader_factory: Arc::new( + DefaultParquetFileReaderFactory::new(Arc::clone(&store)), + ), + partition_fields: vec![Arc::new(Field::new( + "part", + DataType::Int32, + false, + ))], + pushdown_filters: false, // note that this is false! + reorder_filters: false, + enable_page_index: false, + enable_bloom_filter: false, + schema_adapter_factory: Arc::new(DefaultSchemaAdapterFactory), + enable_row_group_stats_pruning: true, + coerce_int96: None, + #[cfg(feature = "parquet_encryption")] + file_decryption_properties: None, + expr_adapter_factory: Some(Arc::new(DefaultPhysicalExprAdapterFactory)), + #[cfg(feature = "parquet_encryption")] + encryption_factory: None, + max_predicate_cache_size: None, + } + }; + + // Filter should match the partition value and file statistics + let expr = col("part").eq(lit(1)).and(col("b").eq(lit(1.0))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener.open(file.clone()).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // Should prune based on partition value but not file statistics + let expr = col("part").eq(lit(2)).and(col("b").eq(lit(1.0))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener.open(file.clone()).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + + // Should prune based on file statistics but not partition value + let expr = col("part").eq(lit(1)).and(col("b").eq(lit(7.0))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener.open(file.clone()).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + + // Should prune based on both partition value and file statistics + let expr = col("part").eq(lit(2)).and(col("b").eq(lit(7.0))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener.open(file).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } + + #[tokio::test] + async fn test_prune_on_partition_value_and_data_value() { + let store = Arc::new(InMemory::new()) as Arc; + + // Note: number 3 is missing! + let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(4)])).unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await; + + let file_schema = batch.schema(); + let mut file = PartitionedFile::new( + "part=1/file.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + file.partition_values = vec![ScalarValue::Int32(Some(1))]; + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("part", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + + let make_opener = |predicate| { + ParquetOpener { + partition_index: 0, + projection: Arc::new([0]), + batch_size: 1024, + limit: None, + predicate: Some(predicate), + logical_file_schema: file_schema.clone(), + metadata_size_hint: None, + metrics: ExecutionPlanMetricsSet::new(), + parquet_file_reader_factory: Arc::new( + DefaultParquetFileReaderFactory::new(Arc::clone(&store)), + ), + partition_fields: vec![Arc::new(Field::new( + "part", + DataType::Int32, + false, + ))], + pushdown_filters: true, // note that this is true! + reorder_filters: true, + enable_page_index: false, + enable_bloom_filter: false, + schema_adapter_factory: Arc::new(DefaultSchemaAdapterFactory), + enable_row_group_stats_pruning: false, // note that this is false! + coerce_int96: None, + #[cfg(feature = "parquet_encryption")] + file_decryption_properties: None, + expr_adapter_factory: Some(Arc::new(DefaultPhysicalExprAdapterFactory)), + #[cfg(feature = "parquet_encryption")] + encryption_factory: None, + max_predicate_cache_size: None, + } + }; + + // Filter should match the partition value and data value + let expr = col("part").eq(lit(1)).or(col("a").eq(lit(1))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener.open(file.clone()).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // Filter should match the partition value but not the data value + let expr = col("part").eq(lit(1)).or(col("a").eq(lit(3))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener.open(file.clone()).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // Filter should not match the partition value but match the data value + let expr = col("part").eq(lit(2)).or(col("a").eq(lit(1))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener.open(file.clone()).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 1); + + // Filter should not match the partition value or the data value + let expr = col("part").eq(lit(2)).or(col("a").eq(lit(3))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener.open(file).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } + + /// Test that if the filter is not a dynamic filter and we have no stats we don't do extra pruning work at the file level. + #[tokio::test] + async fn test_opener_pruning_skipped_on_static_filters() { + let store = Arc::new(InMemory::new()) as Arc; + + let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(3)])).unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await; + + let file_schema = batch.schema(); + let mut file = PartitionedFile::new( + "part=1/file.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + file.partition_values = vec![ScalarValue::Int32(Some(1))]; + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("part", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + + let make_opener = |predicate| { + ParquetOpener { + partition_index: 0, + projection: Arc::new([0]), + batch_size: 1024, + limit: None, + predicate: Some(predicate), + logical_file_schema: file_schema.clone(), + metadata_size_hint: None, + metrics: ExecutionPlanMetricsSet::new(), + parquet_file_reader_factory: Arc::new( + DefaultParquetFileReaderFactory::new(Arc::clone(&store)), + ), + partition_fields: vec![Arc::new(Field::new( + "part", + DataType::Int32, + false, + ))], + pushdown_filters: false, // note that this is false! + reorder_filters: false, + enable_page_index: false, + enable_bloom_filter: false, + schema_adapter_factory: Arc::new(DefaultSchemaAdapterFactory), + enable_row_group_stats_pruning: true, + coerce_int96: None, + #[cfg(feature = "parquet_encryption")] + file_decryption_properties: None, + expr_adapter_factory: Some(Arc::new(DefaultPhysicalExprAdapterFactory)), + #[cfg(feature = "parquet_encryption")] + encryption_factory: None, + max_predicate_cache_size: None, + } + }; + + // Filter should NOT match the stats but the file is never attempted to be pruned because the filters are not dynamic + let expr = col("part").eq(lit(2)); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener.open(file.clone()).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // If we make the filter dynamic, it should prune + let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let opener = make_opener(predicate); + let stream = opener.open(file.clone()).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } + + fn get_value(metrics: &MetricsSet, metric_name: &str) -> usize { + match metrics.sum_by_name(metric_name) { + Some(v) => v.as_usize(), + _ => { + panic!( + "Expected metric not found. Looking for '{metric_name}' in\n\n{metrics:#?}" + ); + } + } + } + + #[tokio::test] + async fn test_custom_schema_adapter_no_rewriter() { + // Make a hardcoded schema adapter that adds a new column "b" with default value 0.0 + // and converts the first column "a" from Int32 to UInt64. + #[derive(Debug, Clone)] + struct CustomSchemaMapper; + + impl SchemaMapper for CustomSchemaMapper { + fn map_batch( + &self, + batch: arrow::array::RecordBatch, + ) -> datafusion_common::Result { + let a_column = cast(batch.column(0), &DataType::UInt64)?; + // Add in a new column "b" with default value 0.0 + let b_column = + arrow::array::Float64Array::from(vec![Some(0.0); batch.num_rows()]); + let columns = vec![a_column, Arc::new(b_column)]; + let new_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt64, false), + Field::new("b", DataType::Float64, false), + ])); + Ok(arrow::record_batch::RecordBatch::try_new( + new_schema, columns, + )?) + } + + fn map_column_statistics( + &self, + file_col_statistics: &[ColumnStatistics], + ) -> datafusion_common::Result> { + Ok(vec![ + file_col_statistics[0].clone(), + ColumnStatistics::new_unknown(), + ]) + } + } + + #[derive(Debug, Clone)] + struct CustomSchemaAdapter; + + impl SchemaAdapter for CustomSchemaAdapter { + fn map_schema( + &self, + _file_schema: &Schema, + ) -> datafusion_common::Result<(Arc, Vec)> + { + let mapper = Arc::new(CustomSchemaMapper); + let projection = vec![0]; // We only need to read the first column "a" from the file + Ok((mapper, projection)) + } + + fn map_column_index( + &self, + index: usize, + file_schema: &Schema, + ) -> Option { + if index < file_schema.fields().len() { + Some(index) + } else { + None // The new column "b" is not in the original schema + } + } + } + + #[derive(Debug, Clone)] + struct CustomSchemaAdapterFactory; + + impl SchemaAdapterFactory for CustomSchemaAdapterFactory { + fn create( + &self, + _projected_table_schema: SchemaRef, + _table_schema: SchemaRef, + ) -> Box { + Box::new(CustomSchemaAdapter) + } + } + + // Test that if no expression rewriter is provided we use a schemaadapter to adapt the data to the expression + let store = Arc::new(InMemory::new()) as Arc; + let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(3)])).unwrap(); + // Write out the batch to a Parquet file + let data_size = + write_parquet(Arc::clone(&store), "test.parquet", batch.clone()).await; + let file = PartitionedFile::new( + "test.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt64, false), + Field::new("b", DataType::Float64, false), + ])); + + let make_opener = |predicate| ParquetOpener { + partition_index: 0, + projection: Arc::new([0, 1]), + batch_size: 1024, + limit: None, + predicate: Some(predicate), + logical_file_schema: Arc::clone(&table_schema), + metadata_size_hint: None, + metrics: ExecutionPlanMetricsSet::new(), + parquet_file_reader_factory: Arc::new(DefaultParquetFileReaderFactory::new( + Arc::clone(&store), + )), + partition_fields: vec![], + pushdown_filters: true, + reorder_filters: false, + enable_page_index: false, + enable_bloom_filter: false, + schema_adapter_factory: Arc::new(CustomSchemaAdapterFactory), + enable_row_group_stats_pruning: false, + coerce_int96: None, + #[cfg(feature = "parquet_encryption")] + file_decryption_properties: None, + expr_adapter_factory: None, + #[cfg(feature = "parquet_encryption")] + encryption_factory: None, + max_predicate_cache_size: None, + }; + + let predicate = logical2physical(&col("a").eq(lit(1u64)), &table_schema); + let opener = make_opener(predicate); + let stream = opener.open(file.clone()).unwrap().await.unwrap(); + let batches = collect_batches(stream).await; + + #[rustfmt::skip] + let expected = [ + "+---+-----+", + "| a | b |", + "+---+-----+", + "| 1 | 0.0 |", + "+---+-----+", + ]; + assert_batches_eq!(expected, &batches); + let metrics = opener.metrics.clone_inner(); + assert_eq!(get_value(&metrics, "row_groups_pruned_statistics"), 0); + assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 2); + } +} diff --git a/datafusion/datasource-parquet/src/page_filter.rs b/datafusion/datasource-parquet/src/page_filter.rs index ef832d808647c..5f3e05747d404 100644 --- a/datafusion/datasource-parquet/src/page_filter.rs +++ b/datafusion/datasource-parquet/src/page_filter.rs @@ -28,9 +28,10 @@ use arrow::{ array::ArrayRef, datatypes::{Schema, SchemaRef}, }; +use datafusion_common::pruning::PruningStatistics; use datafusion_common::ScalarValue; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; -use datafusion_physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion_pruning::PruningPredicate; use log::{debug, trace}; use parquet::arrow::arrow_reader::statistics::StatisticsConverter; @@ -249,9 +250,9 @@ impl PagePruningAccessPlanFilter { } if let Some(overall_selection) = overall_selection { - if overall_selection.selects_any() { - let rows_skipped = rows_skipped(&overall_selection); - let rows_selected = rows_selected(&overall_selection); + let rows_selected = overall_selection.row_count(); + if rows_selected > 0 { + let rows_skipped = overall_selection.skipped_row_count(); trace!("Overall selection from predicate skipped {rows_skipped}, selected {rows_selected}: {overall_selection:?}"); total_skip += rows_skipped; total_select += rows_selected; @@ -280,22 +281,6 @@ impl PagePruningAccessPlanFilter { } } -/// returns the number of rows skipped in the selection -/// TODO should this be upstreamed to RowSelection? -fn rows_skipped(selection: &RowSelection) -> usize { - selection - .iter() - .fold(0, |acc, x| if x.skip { acc + x.row_count } else { acc }) -} - -/// returns the number of rows not skipped in the selection -/// TODO should this be upstreamed to RowSelection? -fn rows_selected(selection: &RowSelection) -> usize { - selection - .iter() - .fold(0, |acc, x| if x.skip { acc } else { acc + x.row_count }) -} - fn update_selection( current_selection: Option, row_selection: RowSelection, @@ -349,7 +334,7 @@ fn prune_pages_in_one_row_group( assert_eq!(page_row_counts.len(), values.len()); let mut sum_row = *page_row_counts.first().unwrap(); let mut selected = *values.first().unwrap(); - trace!("Pruned to {:?} using {:?}", values, pruning_stats); + trace!("Pruned to {values:?} using {pruning_stats:?}"); for (i, &f) in values.iter().enumerate().skip(1) { if f == selected { sum_row += *page_row_counts.get(i).unwrap(); diff --git a/datafusion/datasource-parquet/src/reader.rs b/datafusion/datasource-parquet/src/reader.rs index 5924a5b5038fc..687a7f15fccc8 100644 --- a/datafusion/datasource-parquet/src/reader.rs +++ b/datafusion/datasource-parquet/src/reader.rs @@ -18,19 +18,25 @@ //! [`ParquetFileReaderFactory`] and [`DefaultParquetFileReaderFactory`] for //! low level control of parquet file readers +use crate::metadata::DFParquetMetadata; +use crate::ParquetFileMetrics; use bytes::Bytes; -use datafusion_datasource::file_meta::FileMeta; +use datafusion_datasource::PartitionedFile; +use datafusion_execution::cache::cache_manager::FileMetadata; +use datafusion_execution::cache::cache_manager::FileMetadataCache; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use futures::future::BoxFuture; +use futures::FutureExt; use object_store::ObjectStore; +use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::async_reader::{AsyncFileReader, ParquetObjectReader}; use parquet::file::metadata::ParquetMetaData; +use std::any::Any; +use std::collections::HashMap; use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; -use crate::ParquetFileMetrics; - /// Interface for reading parquet files. /// /// The combined implementations of [`ParquetFileReaderFactory`] and @@ -50,13 +56,13 @@ pub trait ParquetFileReaderFactory: Debug + Send + Sync + 'static { /// /// # Arguments /// * partition_index - Index of the partition (for reporting metrics) - /// * file_meta - The file to be read + /// * file - The file to be read /// * metadata_size_hint - If specified, the first IO reads this many bytes from the footer /// * metrics - Execution metrics fn create_reader( &self, partition_index: usize, - file_meta: FileMeta, + partitioned_file: PartitionedFile, metadata_size_hint: Option, metrics: &ExecutionPlanMetricsSet, ) -> datafusion_common::Result>; @@ -88,7 +94,7 @@ impl DefaultParquetFileReaderFactory { /// This implementation does not coalesce I/O operations or cache bytes. Such /// optimizations can be done either at the object store level or by providing a /// custom implementation of [`ParquetFileReaderFactory`]. -pub(crate) struct ParquetFileReader { +pub struct ParquetFileReader { pub file_metrics: ParquetFileMetrics, pub inner: ParquetObjectReader, } @@ -96,28 +102,30 @@ pub(crate) struct ParquetFileReader { impl AsyncFileReader for ParquetFileReader { fn get_bytes( &mut self, - range: Range, + range: Range, ) -> BoxFuture<'_, parquet::errors::Result> { - self.file_metrics.bytes_scanned.add(range.end - range.start); + let bytes_scanned = range.end - range.start; + self.file_metrics.bytes_scanned.add(bytes_scanned as usize); self.inner.get_bytes(range) } fn get_byte_ranges( &mut self, - ranges: Vec>, + ranges: Vec>, ) -> BoxFuture<'_, parquet::errors::Result>> where Self: Send, { - let total = ranges.iter().map(|r| r.end - r.start).sum(); - self.file_metrics.bytes_scanned.add(total); + let total: u64 = ranges.iter().map(|r| r.end - r.start).sum(); + self.file_metrics.bytes_scanned.add(total as usize); self.inner.get_byte_ranges(ranges) } - fn get_metadata( - &mut self, - ) -> BoxFuture<'_, parquet::errors::Result>> { - self.inner.get_metadata() + fn get_metadata<'a>( + &'a mut self, + options: Option<&'a ArrowReaderOptions>, + ) -> BoxFuture<'a, parquet::errors::Result>> { + self.inner.get_metadata(options) } } @@ -125,17 +133,21 @@ impl ParquetFileReaderFactory for DefaultParquetFileReaderFactory { fn create_reader( &self, partition_index: usize, - file_meta: FileMeta, + partitioned_file: PartitionedFile, metadata_size_hint: Option, metrics: &ExecutionPlanMetricsSet, ) -> datafusion_common::Result> { let file_metrics = ParquetFileMetrics::new( partition_index, - file_meta.location().as_ref(), + partitioned_file.object_meta.location.as_ref(), metrics, ); let store = Arc::clone(&self.store); - let mut inner = ParquetObjectReader::new(store, file_meta.object_meta); + let mut inner = ParquetObjectReader::new( + store, + partitioned_file.object_meta.location.clone(), + ) + .with_file_size(partitioned_file.object_meta.size); if let Some(hint) = metadata_size_hint { inner = inner.with_footer_size_hint(hint) @@ -147,3 +159,157 @@ impl ParquetFileReaderFactory for DefaultParquetFileReaderFactory { })) } } + +/// Implementation of [`ParquetFileReaderFactory`] supporting the caching of footer and page +/// metadata. Reads and updates the [`FileMetadataCache`] with the [`ParquetMetaData`] data. +/// This reader always loads the entire metadata (including page index, unless the file is +/// encrypted), even if not required by the current query, to ensure it is always available for +/// those that need it. +#[derive(Debug)] +pub struct CachedParquetFileReaderFactory { + store: Arc, + metadata_cache: Arc, +} + +impl CachedParquetFileReaderFactory { + pub fn new( + store: Arc, + metadata_cache: Arc, + ) -> Self { + Self { + store, + metadata_cache, + } + } +} + +impl ParquetFileReaderFactory for CachedParquetFileReaderFactory { + fn create_reader( + &self, + partition_index: usize, + partitioned_file: PartitionedFile, + metadata_size_hint: Option, + metrics: &ExecutionPlanMetricsSet, + ) -> datafusion_common::Result> { + let file_metrics = ParquetFileMetrics::new( + partition_index, + partitioned_file.object_meta.location.as_ref(), + metrics, + ); + let store = Arc::clone(&self.store); + + let mut inner = ParquetObjectReader::new( + store, + partitioned_file.object_meta.location.clone(), + ) + .with_file_size(partitioned_file.object_meta.size); + + if let Some(hint) = metadata_size_hint { + inner = inner.with_footer_size_hint(hint) + }; + + Ok(Box::new(CachedParquetFileReader { + store: Arc::clone(&self.store), + inner, + file_metrics, + partitioned_file, + metadata_cache: Arc::clone(&self.metadata_cache), + metadata_size_hint, + })) + } +} + +/// Implements [`AsyncFileReader`] for a Parquet file in object storage. Reads the file metadata +/// from the [`FileMetadataCache`], if available, otherwise reads it directly from the file and then +/// updates the cache. +pub struct CachedParquetFileReader { + pub file_metrics: ParquetFileMetrics, + store: Arc, + pub inner: ParquetObjectReader, + partitioned_file: PartitionedFile, + metadata_cache: Arc, + metadata_size_hint: Option, +} + +impl AsyncFileReader for CachedParquetFileReader { + fn get_bytes( + &mut self, + range: Range, + ) -> BoxFuture<'_, parquet::errors::Result> { + let bytes_scanned = range.end - range.start; + self.file_metrics.bytes_scanned.add(bytes_scanned as usize); + self.inner.get_bytes(range) + } + + fn get_byte_ranges( + &mut self, + ranges: Vec>, + ) -> BoxFuture<'_, parquet::errors::Result>> + where + Self: Send, + { + let total: u64 = ranges.iter().map(|r| r.end - r.start).sum(); + self.file_metrics.bytes_scanned.add(total as usize); + self.inner.get_byte_ranges(ranges) + } + + fn get_metadata<'a>( + &'a mut self, + #[allow(unused_variables)] options: Option<&'a ArrowReaderOptions>, + ) -> BoxFuture<'a, parquet::errors::Result>> { + let object_meta = self.partitioned_file.object_meta.clone(); + let metadata_cache = Arc::clone(&self.metadata_cache); + + async move { + #[cfg(feature = "parquet_encryption")] + let file_decryption_properties = + options.and_then(|o| o.file_decryption_properties()); + + #[cfg(not(feature = "parquet_encryption"))] + let file_decryption_properties = None; + + DFParquetMetadata::new(&self.store, &object_meta) + .with_decryption_properties(file_decryption_properties) + .with_file_metadata_cache(Some(Arc::clone(&metadata_cache))) + .with_metadata_size_hint(self.metadata_size_hint) + .fetch_metadata() + .await + .map_err(|e| { + parquet::errors::ParquetError::General(format!( + "Failed to fetch metadata for file {}: {e}", + object_meta.location, + )) + }) + } + .boxed() + } +} + +/// Wrapper to implement [`FileMetadata`] for [`ParquetMetaData`]. +pub struct CachedParquetMetaData(Arc); + +impl CachedParquetMetaData { + pub fn new(metadata: Arc) -> Self { + Self(metadata) + } + + pub fn parquet_metadata(&self) -> &Arc { + &self.0 + } +} + +impl FileMetadata for CachedParquetMetaData { + fn as_any(&self) -> &dyn Any { + self + } + + fn memory_size(&self) -> usize { + self.0.memory_size() + } + + fn extra_info(&self) -> HashMap { + let page_index = + self.0.column_index().is_some() && self.0.offset_index().is_some(); + HashMap::from([("page_index".to_owned(), page_index.to_string())]) + } +} diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index da6bf114d71dd..660b32f486120 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -76,7 +76,7 @@ use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor} use datafusion_common::Result; use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaMapper}; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::utils::reassign_predicate_columns; +use datafusion_physical_expr::utils::reassign_expr_columns; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; use datafusion_physical_plan::metrics; @@ -119,9 +119,8 @@ impl DatafusionArrowPredicate { rows_matched: metrics::Count, time: metrics::Time, ) -> Result { - let projected_schema = Arc::clone(&candidate.filter_schema); let physical_expr = - reassign_predicate_columns(candidate.expr, &projected_schema, true)?; + reassign_expr_columns(candidate.expr, &candidate.filter_schema)?; Ok(Self { physical_expr, @@ -184,7 +183,7 @@ pub(crate) struct FilterCandidate { /// Can this filter use an index (e.g. a page index) to prune rows? can_use_index: bool, /// The projection to read from the file schema to get the columns - /// required to pass thorugh a `SchemaMapper` to the table schema + /// required to pass through a `SchemaMapper` to the table schema /// upon which we then evaluate the filter expression. projection: Vec, /// A `SchemaMapper` used to map batches read from the file schema to @@ -299,6 +298,7 @@ struct PushdownChecker<'schema> { non_primitive_columns: bool, /// Does the expression reference any columns that are in the table /// schema but not in the file schema? + /// This includes partition columns and projected columns. projected_columns: bool, // Indices into the table schema of the columns required to evaluate the expression required_columns: BTreeSet, @@ -366,44 +366,19 @@ fn pushdown_columns( .then_some(checker.required_columns.into_iter().collect())) } -/// creates a PushdownChecker for a single use to check a given column with the given schemes. Used -/// to check preemptively if a column name would prevent pushdowning. -/// effectively does the inverse of [`pushdown_columns`] does, but with a single given column -/// (instead of traversing the entire tree to determine this) -fn would_column_prevent_pushdown(column_name: &str, table_schema: &Schema) -> bool { - let mut checker = PushdownChecker::new(table_schema); - - // the return of this is only used for [`PushdownChecker::f_down()`], so we can safely ignore - // it here. I'm just verifying we know the return type of this so nobody accidentally changes - // the return type of this fn and it gets implicitly ignored here. - let _: Option = checker.check_single_column(column_name); - - // and then return a value based on the state of the checker - checker.prevents_pushdown() -} - /// Recurses through expr as a tree, finds all `column`s, and checks if any of them would prevent /// this expression from being predicate pushed down. If any of them would, this returns false. /// Otherwise, true. +/// Note that the schema passed in here is *not* the physical file schema (as it is not available at that point in time); +/// it is the schema of the table that this expression is being evaluated against minus any projected columns and partition columns. pub fn can_expr_be_pushed_down_with_schemas( - expr: &datafusion_expr::Expr, - _file_schema: &Schema, - table_schema: &Schema, + expr: &Arc, + file_schema: &Schema, ) -> bool { - let mut can_be_pushed = true; - expr.apply(|expr| match expr { - datafusion_expr::Expr::Column(column) => { - can_be_pushed &= !would_column_prevent_pushdown(column.name(), table_schema); - Ok(if can_be_pushed { - TreeNodeRecursion::Jump - } else { - TreeNodeRecursion::Stop - }) - } - _ => Ok(TreeNodeRecursion::Continue), - }) - .unwrap(); // we never return an Err, so we can safely unwrap this - can_be_pushed + match pushdown_columns(expr, file_schema) { + Ok(Some(_)) => true, + Ok(None) | Err(_) => false, + } } /// Calculate the total compressed size of all `Column`'s required for @@ -449,8 +424,8 @@ fn columns_sorted(_columns: &[usize], _metadata: &ParquetMetaData) -> Result, - file_schema: &SchemaRef, - table_schema: &SchemaRef, + physical_file_schema: &SchemaRef, + predicate_file_schema: &SchemaRef, metadata: &ParquetMetaData, reorder_predicates: bool, file_metrics: &ParquetFileMetrics, @@ -470,8 +445,8 @@ pub fn build_row_filter( .map(|expr| { FilterCandidateBuilder::new( Arc::clone(expr), - Arc::clone(file_schema), - Arc::clone(table_schema), + Arc::clone(physical_file_schema), + Arc::clone(predicate_file_schema), Arc::clone(schema_adapter_factory), ) .build(metadata) @@ -516,7 +491,7 @@ mod test { use super::*; use datafusion_common::ScalarValue; - use arrow::datatypes::{Field, Fields, TimeUnit::Nanosecond}; + use arrow::datatypes::{Field, TimeUnit::Nanosecond}; use datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory; use datafusion_expr::{col, Expr}; use datafusion_physical_expr::planner::logical2physical; @@ -581,6 +556,7 @@ mod test { // Test all should fail let expr = col("timestamp_col").lt(Expr::Literal( ScalarValue::TimestampNanosecond(Some(1), Some(Arc::from("UTC"))), + None, )); let expr = logical2physical(&expr, &table_schema); let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory); @@ -621,6 +597,7 @@ mod test { // Test all should pass let expr = col("timestamp_col").gt(Expr::Literal( ScalarValue::TimestampNanosecond(Some(0), Some(Arc::from("UTC"))), + None, )); let expr = logical2physical(&expr, &table_schema); let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory); @@ -649,73 +626,45 @@ mod test { #[test] fn nested_data_structures_prevent_pushdown() { - let table_schema = get_basic_table_schema(); - - let file_schema = Schema::new(vec![Field::new( - "list_col", - DataType::Struct(Fields::empty()), - true, - )]); + let table_schema = Arc::new(get_lists_table_schema()); - let expr = col("list_col").is_not_null(); + let expr = col("utf8_list").is_not_null(); + let expr = logical2physical(&expr, &table_schema); + check_expression_can_evaluate_against_schema(&expr, &table_schema); - assert!(!can_expr_be_pushed_down_with_schemas( - &expr, - &file_schema, - &table_schema - )); + assert!(!can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); } #[test] fn projected_columns_prevent_pushdown() { let table_schema = get_basic_table_schema(); - let file_schema = - Schema::new(vec![Field::new("existing_col", DataType::Int64, true)]); + let expr = + Arc::new(Column::new("nonexistent_column", 0)) as Arc; - let expr = col("nonexistent_column").is_null(); - - assert!(!can_expr_be_pushed_down_with_schemas( - &expr, - &file_schema, - &table_schema - )); + assert!(!can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); } #[test] fn basic_expr_doesnt_prevent_pushdown() { let table_schema = get_basic_table_schema(); - let file_schema = - Schema::new(vec![Field::new("string_col", DataType::Utf8, true)]); - let expr = col("string_col").is_null(); + let expr = logical2physical(&expr, &table_schema); - assert!(can_expr_be_pushed_down_with_schemas( - &expr, - &file_schema, - &table_schema - )); + assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); } #[test] fn complex_expr_doesnt_prevent_pushdown() { let table_schema = get_basic_table_schema(); - let file_schema = Schema::new(vec![ - Field::new("string_col", DataType::Utf8, true), - Field::new("bigint_col", DataType::Int64, true), - ]); - let expr = col("string_col") .is_not_null() - .or(col("bigint_col").gt(Expr::Literal(ScalarValue::Int64(Some(5))))); + .or(col("bigint_col").gt(Expr::Literal(ScalarValue::Int64(Some(5)), None))); + let expr = logical2physical(&expr, &table_schema); - assert!(can_expr_be_pushed_down_with_schemas( - &expr, - &file_schema, - &table_schema - )); + assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); } fn get_basic_table_schema() -> Schema { @@ -730,4 +679,27 @@ mod test { parquet_to_arrow_schema(metadata.file_metadata().schema_descr(), None) .expect("parsing schema") } + + fn get_lists_table_schema() -> Schema { + let testdata = datafusion_common::test_util::parquet_test_data(); + let file = std::fs::File::open(format!("{testdata}/list_columns.parquet")) + .expect("opening file"); + + let reader = SerializedFileReader::new(file).expect("creating reader"); + + let metadata = reader.metadata(); + + parquet_to_arrow_schema(metadata.file_metadata().schema_descr(), None) + .expect("parsing schema") + } + + /// Sanity check that the given expression could be evaluated against the given schema without any errors. + /// This will fail if the expression references columns that are not in the schema or if the types of the columns are incompatible, etc. + fn check_expression_can_evaluate_against_schema( + expr: &Arc, + table_schema: &Arc, + ) -> bool { + let batch = RecordBatch::new_empty(Arc::clone(table_schema)); + expr.evaluate(&batch).is_ok() + } } diff --git a/datafusion/datasource-parquet/src/row_group_filter.rs b/datafusion/datasource-parquet/src/row_group_filter.rs index 9d5f9fa16b6eb..51d50d780f103 100644 --- a/datafusion/datasource-parquet/src/row_group_filter.rs +++ b/datafusion/datasource-parquet/src/row_group_filter.rs @@ -21,9 +21,10 @@ use std::sync::Arc; use super::{ParquetAccessPlan, ParquetFileMetrics}; use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::Schema; +use datafusion_common::pruning::PruningStatistics; use datafusion_common::{Column, Result, ScalarValue}; use datafusion_datasource::FileRange; -use datafusion_physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion_pruning::PruningPredicate; use parquet::arrow::arrow_reader::statistics::StatisticsConverter; use parquet::arrow::parquet_column; use parquet::basic::Type; @@ -1241,12 +1242,16 @@ mod tests { .run( lit("1").eq(lit("1")).and( col(r#""String""#) - .eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from( - "Hello_Not_Exists", - ))))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( - Some(String::from("Hello_Not_Exists2")), - )))), + .eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("Hello_Not_Exists"))), + None, + )) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from( + "Hello_Not_Exists2", + ))), + None, + ))), ), ) .await @@ -1265,7 +1270,7 @@ mod tests { let expr = col(r#""String""#).in_list( (1..25) - .map(|i| lit(format!("Hello_Not_Exists{}", i))) + .map(|i| lit(format!("Hello_Not_Exists{i}"))) .collect::>(), false, ); @@ -1326,15 +1331,18 @@ mod tests { // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` .run( col(r#""String""#) - .eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from( - "Hello", - ))))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( - Some(String::from("the quick")), - )))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( - Some(String::from("are you")), - )))), + .eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("Hello"))), + None, + )) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("the quick"))), + None, + ))) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("are you"))), + None, + ))), ) .await } @@ -1513,7 +1521,7 @@ mod tests { let object_meta = ObjectMeta { location: object_store::path::Path::parse(file_name).expect("creating path"), last_modified: chrono::DateTime::from(std::time::SystemTime::now()), - size: data.len(), + size: data.len() as u64, e_tag: None, version: None, }; @@ -1526,8 +1534,11 @@ mod tests { let metrics = ExecutionPlanMetricsSet::new(); let file_metrics = ParquetFileMetrics::new(0, object_meta.location.as_ref(), &metrics); + let inner = ParquetObjectReader::new(Arc::new(in_memory), object_meta.location) + .with_file_size(object_meta.size); + let reader = ParquetFileReader { - inner: ParquetObjectReader::new(Arc::new(in_memory), object_meta), + inner, file_metrics: file_metrics.clone(), }; let mut builder = ParquetRecordBatchStreamBuilder::new(reader).await.unwrap(); diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index 66d4d313d5a61..dd10363079f91 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -17,31 +17,46 @@ //! ParquetSource implementation for reading parquet files use std::any::Any; +use std::fmt::Debug; use std::fmt::Formatter; use std::sync::Arc; +use crate::opener::build_pruning_predicates; use crate::opener::ParquetOpener; -use crate::page_filter::PagePruningAccessPlanFilter; +use crate::row_filter::can_expr_be_pushed_down_with_schemas; use crate::DefaultParquetFileReaderFactory; use crate::ParquetFileReaderFactory; +use datafusion_common::config::ConfigOptions; +#[cfg(feature = "parquet_encryption")] +use datafusion_common::config::EncryptionFactoryOptions; +use datafusion_datasource::as_file_source; use datafusion_datasource::file_stream::FileOpener; use datafusion_datasource::schema_adapter::{ DefaultSchemaAdapterFactory, SchemaAdapterFactory, }; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::{SchemaRef, TimeUnit}; use datafusion_common::config::TableParquetOptions; -use datafusion_common::Statistics; +use datafusion_common::{DataFusionError, Statistics}; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; +use datafusion_physical_expr::conjunction; +use datafusion_physical_expr_adapter::DefaultPhysicalExprAdapterFactory; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_optimizer::pruning::PruningPredicate; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; +use datafusion_physical_plan::filter_pushdown::PushedDown; +use datafusion_physical_plan::filter_pushdown::{ + FilterPushdownPropagation, PushedDownPredicate, +}; +use datafusion_physical_plan::metrics::Count; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion_physical_plan::DisplayFormatType; +#[cfg(feature = "parquet_encryption")] +use datafusion_common::encryption::map_config_decryption_to_decryption; +#[cfg(feature = "parquet_encryption")] +use datafusion_execution::parquet_encryption::EncryptionFactory; use itertools::Itertools; -use log::debug; use object_store::ObjectStore; /// Execution plan for reading one or more Parquet files. @@ -90,7 +105,7 @@ use object_store::ObjectStore; /// # let predicate = lit(true); /// let source = Arc::new( /// ParquetSource::default() -/// .with_predicate(Arc::clone(&file_schema), predicate) +/// .with_predicate(predicate) /// ); /// // Create a DataSourceExec for reading `file1.parquet` with a file size of 100MB /// let config = FileScanConfigBuilder::new(object_store_url, file_schema, source) @@ -159,7 +174,7 @@ use object_store::ObjectStore; /// ```no_run /// # use std::sync::Arc; /// # use arrow::datatypes::Schema; -/// # use datafusion_datasource::file_scan_config::FileScanConfig; +/// # use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; /// # use datafusion_datasource::PartitionedFile; /// # use datafusion_datasource::source::DataSourceExec; /// @@ -173,9 +188,9 @@ use object_store::ObjectStore; /// .iter() /// .map(|file_group| { /// // create a new exec by copying the existing exec's source config -/// let new_config = base_config -/// .clone() -/// .with_file_groups(vec![file_group.clone()]); +/// let new_config = FileScanConfigBuilder::from(base_config.clone()) +/// .with_file_groups(vec![file_group.clone()]) +/// .build(); /// /// (DataSourceExec::from_data_source(new_config)) /// }) @@ -257,12 +272,12 @@ pub struct ParquetSource { pub(crate) table_parquet_options: TableParquetOptions, /// Optional metrics pub(crate) metrics: ExecutionPlanMetricsSet, + /// The schema of the file. + /// In particular, this is the schema of the table without partition columns, + /// *not* the physical schema of the file. + pub(crate) file_schema: Option, /// Optional predicate for row filtering during parquet scan pub(crate) predicate: Option>, - /// Optional predicate for pruning row groups (derived from `predicate`) - pub(crate) pruning_predicate: Option>, - /// Optional predicate for pruning pages (derived from `predicate`) - pub(crate) page_pruning_predicate: Option>, /// Optional user defined parquet file reader factory pub(crate) parquet_file_reader_factory: Option>, /// Optional user defined schema adapter @@ -272,6 +287,8 @@ pub struct ParquetSource { /// Optional hint for the size of the parquet metadata pub(crate) metadata_size_hint: Option, pub(crate) projected_statistics: Option, + #[cfg(feature = "parquet_encryption")] + pub(crate) encryption_factory: Option>, } impl ParquetSource { @@ -296,68 +313,34 @@ impl ParquetSource { self } - fn with_metrics(mut self, metrics: ExecutionPlanMetricsSet) -> Self { - self.metrics = metrics; - self - } - - /// Set predicate information, also sets pruning_predicate and page_pruning_predicate attributes - pub fn with_predicate( - &self, - file_schema: Arc, - predicate: Arc, - ) -> Self { + /// Set predicate information + pub fn with_predicate(&self, predicate: Arc) -> Self { let mut conf = self.clone(); - - let metrics = ExecutionPlanMetricsSet::new(); - let predicate_creation_errors = - MetricBuilder::new(&metrics).global_counter("num_predicate_creation_errors"); - - conf = conf.with_metrics(metrics); conf.predicate = Some(Arc::clone(&predicate)); - - match PruningPredicate::try_new(Arc::clone(&predicate), Arc::clone(&file_schema)) - { - Ok(pruning_predicate) => { - if !pruning_predicate.always_true() { - conf.pruning_predicate = Some(Arc::new(pruning_predicate)); - } - } - Err(e) => { - debug!("Could not create pruning predicate for: {e}"); - predicate_creation_errors.add(1); - } - }; - - let page_pruning_predicate = Arc::new(PagePruningAccessPlanFilter::new( - &predicate, - Arc::clone(&file_schema), - )); - conf.page_pruning_predicate = Some(page_pruning_predicate); - conf } + /// Set the encryption factory to use to generate file decryption properties + #[cfg(feature = "parquet_encryption")] + pub fn with_encryption_factory( + mut self, + encryption_factory: Arc, + ) -> Self { + self.encryption_factory = Some(encryption_factory); + self + } + /// Options passed to the parquet reader for this scan pub fn table_parquet_options(&self) -> &TableParquetOptions { &self.table_parquet_options } /// Optional predicate. + #[deprecated(since = "50.2.0", note = "use `filter` instead")] pub fn predicate(&self) -> Option<&Arc> { self.predicate.as_ref() } - /// Optional reference to this parquet scan's pruning predicate - pub fn pruning_predicate(&self) -> Option<&Arc> { - self.pruning_predicate.as_ref() - } - - /// Optional reference to this parquet scan's page pruning predicate - pub fn page_pruning_predicate(&self) -> Option<&Arc> { - self.page_pruning_predicate.as_ref() - } - /// return the optional file reader factory pub fn parquet_file_reader_factory( &self, @@ -375,29 +358,8 @@ impl ParquetSource { self } - /// return the optional schema adapter factory - pub fn schema_adapter_factory(&self) -> Option<&Arc> { - self.schema_adapter_factory.as_ref() - } - - /// Set optional schema adapter factory. - /// - /// [`SchemaAdapterFactory`] allows user to specify how fields from the - /// parquet file get mapped to that of the table schema. The default schema - /// adapter uses arrow's cast library to map the parquet fields to the table - /// schema. - pub fn with_schema_adapter_factory( - mut self, - schema_adapter_factory: Arc, - ) -> Self { - self.schema_adapter_factory = Some(schema_adapter_factory); - self - } - /// If true, the predicate will be used during the parquet scan. - /// Defaults to false - /// - /// [`Expr`]: datafusion_expr::Expr + /// Defaults to false. pub fn with_pushdown_filters(mut self, pushdown_filters: bool) -> Self { self.table_parquet_options.global.pushdown_filters = pushdown_filters; self @@ -458,6 +420,72 @@ impl ParquetSource { fn bloom_filter_on_read(&self) -> bool { self.table_parquet_options.global.bloom_filter_on_read } + + /// Return the maximum predicate cache size, in bytes, used when + /// `pushdown_filters` + pub fn max_predicate_cache_size(&self) -> Option { + self.table_parquet_options.global.max_predicate_cache_size + } + + /// Applies schema adapter factory from the FileScanConfig if present. + /// + /// # Arguments + /// * `conf` - FileScanConfig that may contain a schema adapter factory + /// # Returns + /// The converted FileSource with schema adapter factory applied if provided + pub fn apply_schema_adapter( + self, + conf: &FileScanConfig, + ) -> datafusion_common::Result> { + let file_source: Arc = self.into(); + + // If the FileScanConfig.file_source() has a schema adapter factory, apply it + if let Some(factory) = conf.file_source().schema_adapter_factory() { + file_source.with_schema_adapter_factory( + Arc::::clone(&factory), + ) + } else { + Ok(file_source) + } + } + + #[cfg(feature = "parquet_encryption")] + fn get_encryption_factory_with_config( + &self, + ) -> Option<(Arc, EncryptionFactoryOptions)> { + match &self.encryption_factory { + None => None, + Some(factory) => Some(( + Arc::clone(factory), + self.table_parquet_options.crypto.factory_options.clone(), + )), + } + } +} + +/// Parses datafusion.common.config.ParquetOptions.coerce_int96 String to a arrow_schema.datatype.TimeUnit +pub(crate) fn parse_coerce_int96_string( + str_setting: &str, +) -> datafusion_common::Result { + let str_setting_lower: &str = &str_setting.to_lowercase(); + + match str_setting_lower { + "ns" => Ok(TimeUnit::Nanosecond), + "us" => Ok(TimeUnit::Microsecond), + "ms" => Ok(TimeUnit::Millisecond), + "s" => Ok(TimeUnit::Second), + _ => Err(DataFusionError::Configuration(format!( + "Unknown or unsupported parquet coerce_int96: \ + {str_setting}. Valid values are: ns, us, ms, and s." + ))), + } +} + +/// Allows easy conversion from ParquetSource to Arc<dyn FileSource> +impl From for Arc { + fn from(source: ParquetSource) -> Self { + as_file_source(source) + } } impl FileSource for ParquetSource { @@ -470,16 +498,66 @@ impl FileSource for ParquetSource { let projection = base_config .file_column_projection_indices() .unwrap_or_else(|| (0..base_config.file_schema.fields().len()).collect()); - let schema_adapter_factory = self - .schema_adapter_factory - .clone() - .unwrap_or_else(|| Arc::new(DefaultSchemaAdapterFactory)); + + let (expr_adapter_factory, schema_adapter_factory) = match ( + base_config.expr_adapter_factory.as_ref(), + self.schema_adapter_factory.as_ref(), + ) { + (Some(expr_adapter_factory), Some(schema_adapter_factory)) => { + // Use both the schema adapter factory and the expr adapter factory. + // This results in the the SchemaAdapter being used for projections (e.g. a column was selected that is a UInt32 in the file and a UInt64 in the table schema) + // but the PhysicalExprAdapterFactory being used for predicate pushdown and stats pruning. + ( + Some(Arc::clone(expr_adapter_factory)), + Arc::clone(schema_adapter_factory), + ) + } + (Some(expr_adapter_factory), None) => { + // If no custom schema adapter factory is provided but an expr adapter factory is provided use the expr adapter factory alongside the default schema adapter factory. + // This means that the PhysicalExprAdapterFactory will be used for predicate pushdown and stats pruning, while the default schema adapter factory will be used for projections. + ( + Some(Arc::clone(expr_adapter_factory)), + Arc::new(DefaultSchemaAdapterFactory) as _, + ) + } + (None, Some(schema_adapter_factory)) => { + // If a custom schema adapter factory is provided but no expr adapter factory is provided use the custom SchemaAdapter for both projections and predicate pushdown. + // This maximizes compatibility with existing code that uses the SchemaAdapter API and did not explicitly opt into the PhysicalExprAdapterFactory API. + (None, Arc::clone(schema_adapter_factory) as _) + } + (None, None) => { + // If no custom schema adapter factory or expr adapter factory is provided, use the default schema adapter factory and the default physical expr adapter factory. + // This means that the default SchemaAdapter will be used for projections (e.g. a column was selected that is a UInt32 in the file and a UInt64 in the table schema) + // and the default PhysicalExprAdapterFactory will be used for predicate pushdown and stats pruning. + // This is the default behavior with not customization and means that most users of DataFusion will be cut over to the new PhysicalExprAdapterFactory API. + ( + Some(Arc::new(DefaultPhysicalExprAdapterFactory) as _), + Arc::new(DefaultSchemaAdapterFactory) as _, + ) + } + }; let parquet_file_reader_factory = self.parquet_file_reader_factory.clone().unwrap_or_else(|| { Arc::new(DefaultParquetFileReaderFactory::new(object_store)) as _ }); + #[cfg(feature = "parquet_encryption")] + let file_decryption_properties = self + .table_parquet_options() + .crypto + .file_decryption + .as_ref() + .map(map_config_decryption_to_decryption) + .map(Arc::new); + + let coerce_int96 = self + .table_parquet_options + .global + .coerce_int96 + .as_ref() + .map(|time_unit| parse_coerce_int96_string(time_unit.as_str()).unwrap()); + Arc::new(ParquetOpener { partition_index: partition, projection: Arc::from(projection), @@ -488,9 +566,8 @@ impl FileSource for ParquetSource { .expect("Batch size must set before creating ParquetOpener"), limit: base_config.limit, predicate: self.predicate.clone(), - pruning_predicate: self.pruning_predicate.clone(), - page_pruning_predicate: self.page_pruning_predicate.clone(), - table_schema: Arc::clone(&base_config.file_schema), + logical_file_schema: Arc::clone(&base_config.file_schema), + partition_fields: base_config.table_partition_cols.clone(), metadata_size_hint: self.metadata_size_hint, metrics: self.metrics().clone(), parquet_file_reader_factory, @@ -498,7 +575,15 @@ impl FileSource for ParquetSource { reorder_filters: self.reorder_filters(), enable_page_index: self.enable_page_index(), enable_bloom_filter: self.bloom_filter_on_read(), + enable_row_group_stats_pruning: self.table_parquet_options.global.pruning, schema_adapter_factory, + coerce_int96, + #[cfg(feature = "parquet_encryption")] + file_decryption_properties, + expr_adapter_factory, + #[cfg(feature = "parquet_encryption")] + encryption_factory: self.get_encryption_factory_with_config(), + max_predicate_cache_size: self.max_predicate_cache_size(), }) } @@ -506,14 +591,21 @@ impl FileSource for ParquetSource { self } + fn filter(&self) -> Option> { + self.predicate.clone() + } + fn with_batch_size(&self, batch_size: usize) -> Arc { let mut conf = self.clone(); conf.batch_size = Some(batch_size); Arc::new(conf) } - fn with_schema(&self, _schema: SchemaRef) -> Arc { - Arc::new(Self { ..self.clone() }) + fn with_schema(&self, schema: SchemaRef) -> Arc { + Arc::new(Self { + file_schema: Some(schema), + ..self.clone() + }) } fn with_statistics(&self, statistics: Statistics) -> Arc { @@ -537,11 +629,10 @@ impl FileSource for ParquetSource { .expect("projected_statistics must be set"); // When filters are pushed down, we have no way of knowing the exact statistics. // Note that pruning predicate is also a kind of filter pushdown. - // (bloom filters use `pruning_predicate` too) - if self.pruning_predicate().is_some() - || self.page_pruning_predicate().is_some() - || (self.predicate().is_some() && self.pushdown_filters()) - { + // (bloom filters use `pruning_predicate` too). + // Because filter pushdown may happen dynamically as long as there is a predicate + // if we have *any* predicate applied, we can't guarantee the statistics are exact. + if self.filter().is_some() { Ok(statistics.to_inexact()) } else { Ok(statistics) @@ -556,34 +647,153 @@ impl FileSource for ParquetSource { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let predicate_string = self - .predicate() + .filter() .map(|p| format!(", predicate={p}")) .unwrap_or_default(); - let pruning_predicate_string = self - .pruning_predicate() - .map(|pre| { - let mut guarantees = pre + + write!(f, "{predicate_string}")?; + + // Try to build a the pruning predicates. + // These are only generated here because it's useful to have *some* + // idea of what pushdown is happening when viewing plans. + // However it is important to note that these predicates are *not* + // necessarily the predicates that are actually evaluated: + // the actual predicates are built in reference to the physical schema of + // each file, which we do not have at this point and hence cannot use. + // Instead we use the logical schema of the file (the table schema without partition columns). + if let (Some(file_schema), Some(predicate)) = + (&self.file_schema, &self.predicate) + { + let predicate_creation_errors = Count::new(); + if let (Some(pruning_predicate), _) = build_pruning_predicates( + Some(predicate), + file_schema, + &predicate_creation_errors, + ) { + let mut guarantees = pruning_predicate .literal_guarantees() .iter() - .map(|item| format!("{}", item)) + .map(|item| format!("{item}")) .collect_vec(); guarantees.sort(); - format!( + write!( + f, ", pruning_predicate={}, required_guarantees=[{}]", - pre.predicate_expr(), + pruning_predicate.predicate_expr(), guarantees.join(", ") - ) - }) - .unwrap_or_default(); - - write!(f, "{}{}", predicate_string, pruning_predicate_string) + )?; + } + }; + Ok(()) } DisplayFormatType::TreeRender => { - if let Some(predicate) = self.predicate() { + if let Some(predicate) = self.filter() { writeln!(f, "predicate={}", fmt_sql(predicate.as_ref()))?; } Ok(()) } } } + + fn try_pushdown_filters( + &self, + filters: Vec>, + config: &ConfigOptions, + ) -> datafusion_common::Result>> { + let Some(file_schema) = self.file_schema.clone() else { + return Ok(FilterPushdownPropagation::with_parent_pushdown_result( + vec![PushedDown::No; filters.len()], + )); + }; + // Determine if based on configs we should push filters down. + // If either the table / scan itself or the config has pushdown enabled, + // we will push down the filters. + // If both are disabled, we will not push down the filters. + // By default they are both disabled. + // Regardless of pushdown, we will update the predicate to include the filters + // because even if scan pushdown is disabled we can still use the filters for stats pruning. + let config_pushdown_enabled = config.execution.parquet.pushdown_filters; + let table_pushdown_enabled = self.pushdown_filters(); + let pushdown_filters = table_pushdown_enabled || config_pushdown_enabled; + + let mut source = self.clone(); + let filters: Vec = filters + .into_iter() + .map(|filter| { + if can_expr_be_pushed_down_with_schemas(&filter, &file_schema) { + PushedDownPredicate::supported(filter) + } else { + PushedDownPredicate::unsupported(filter) + } + }) + .collect(); + if filters + .iter() + .all(|f| matches!(f.discriminant, PushedDown::No)) + { + // No filters can be pushed down, so we can just return the remaining filters + // and avoid replacing the source in the physical plan. + return Ok(FilterPushdownPropagation::with_parent_pushdown_result( + vec![PushedDown::No; filters.len()], + )); + } + let allowed_filters = filters + .iter() + .filter_map(|f| match f.discriminant { + PushedDown::Yes => Some(Arc::clone(&f.predicate)), + PushedDown::No => None, + }) + .collect_vec(); + let predicate = match source.predicate { + Some(predicate) => { + conjunction(std::iter::once(predicate).chain(allowed_filters)) + } + None => conjunction(allowed_filters), + }; + source.predicate = Some(predicate); + source = source.with_pushdown_filters(pushdown_filters); + let source = Arc::new(source); + // If pushdown_filters is false we tell our parents that they still have to handle the filters, + // even if we updated the predicate to include the filters (they will only be used for stats pruning). + if !pushdown_filters { + return Ok(FilterPushdownPropagation::with_parent_pushdown_result( + vec![PushedDown::No; filters.len()], + ) + .with_updated_node(source)); + } + Ok(FilterPushdownPropagation::with_parent_pushdown_result( + filters.iter().map(|f| f.discriminant).collect(), + ) + .with_updated_node(source)) + } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> datafusion_common::Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_physical_expr::expressions::lit; + + #[test] + #[allow(deprecated)] + fn test_parquet_source_predicate_same_as_filter() { + let predicate = lit(true); + + let parquet_source = ParquetSource::default().with_predicate(predicate); + // same value. but filter() call Arc::clone internally + assert_eq!(parquet_source.predicate(), parquet_source.filter().as_ref()); + } } diff --git a/datafusion/datasource-parquet/src/writer.rs b/datafusion/datasource-parquet/src/writer.rs index 64eb37c81f5df..d37b6e26a7536 100644 --- a/datafusion/datasource-parquet/src/writer.rs +++ b/datafusion/datasource-parquet/src/writer.rs @@ -46,7 +46,15 @@ pub async fn plan_to_parquet( let propclone = writer_properties.clone(); let storeref = Arc::clone(&store); - let buf_writer = BufWriter::new(storeref, file.clone()); + let buf_writer = BufWriter::with_capacity( + storeref, + file.clone(), + task_ctx + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + ); let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { let mut writer = diff --git a/datafusion/datasource/Cargo.toml b/datafusion/datasource/Cargo.toml index 2132272b5768d..afd0256ba9720 100644 --- a/datafusion/datasource/Cargo.toml +++ b/datafusion/datasource/Cargo.toml @@ -18,11 +18,11 @@ [package] name = "datafusion-datasource" description = "datafusion-datasource" +readme = "README.md" authors.workspace = true edition.workspace = true homepage.workspace = true license.workspace = true -readme.workspace = true repository.workspace = true rust-version.workspace = true version.workspace = true @@ -31,7 +31,6 @@ version.workspace = true all-features = true [features] -parquet = ["dep:parquet", "tempfile"] compression = ["async-compression", "xz2", "bzip2", "flate2", "zstd", "tokio-util"] default = ["compression"] @@ -46,32 +45,33 @@ async-compression = { version = "0.4.19", features = [ ], optional = true } async-trait = { workspace = true } bytes = { workspace = true } -bzip2 = { version = "0.5.2", optional = true } +bzip2 = { version = "0.6.0", optional = true } chrono = { workspace = true } datafusion-common = { workspace = true, features = ["object_store"] } datafusion-common-runtime = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } +datafusion-physical-expr-adapter = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } -flate2 = { version = "1.0.24", optional = true } +flate2 = { version = "1.1.4", optional = true } futures = { workspace = true } glob = "0.3.0" itertools = { workspace = true } log = { workspace = true } object_store = { workspace = true } -parquet = { workspace = true, optional = true } rand = { workspace = true } tempfile = { workspace = true, optional = true } tokio = { workspace = true } -tokio-util = { version = "0.7.14", features = ["io"], optional = true } +tokio-util = { version = "0.7.16", features = ["io"], optional = true } url = { workspace = true } xz2 = { version = "0.1", optional = true, features = ["static"] } zstd = { version = "0.13", optional = true, default-features = false } [dev-dependencies] +criterion = { workspace = true } tempfile = { workspace = true } [lints] @@ -80,3 +80,7 @@ workspace = true [lib] name = "datafusion_datasource" path = "src/mod.rs" + +[[bench]] +name = "split_groups_by_statistics" +harness = false diff --git a/datafusion/datasource/README.md b/datafusion/datasource/README.md index 750ee9375154f..cf0bb7547c078 100644 --- a/datafusion/datasource/README.md +++ b/datafusion/datasource/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion datasource +# Apache DataFusion DataSource -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion that defines common DataSource related components like FileScanConfig, FileCompression etc. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/datasource/benches/split_groups_by_statistics.rs b/datafusion/datasource/benches/split_groups_by_statistics.rs new file mode 100644 index 0000000000000..d51fdfc0a6e90 --- /dev/null +++ b/datafusion/datasource/benches/split_groups_by_statistics.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; +use std::time::Duration; + +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_datasource::file_scan_config::FileScanConfig; +use datafusion_datasource::{generate_test_files, verify_sort_integrity}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; + +pub fn compare_split_groups_by_statistics_algorithms(c: &mut Criterion) { + let file_schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Float64, + false, + )])); + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("value", 0))); + let sort_ordering = LexOrdering::from([sort_expr]); + + // Small, medium, large number of files + let file_counts = [10, 100, 1000]; + let overlap_factors = [0.0, 0.2, 0.5, 0.8]; // No, low, medium, high overlap + + let target_partitions: [usize; 4] = [4, 8, 16, 32]; + + let mut group = c.benchmark_group("split_groups"); + group.measurement_time(Duration::from_secs(10)); + + for &num_files in &file_counts { + for &overlap in &overlap_factors { + let file_groups = generate_test_files(num_files, overlap); + // Benchmark original algorithm + group.bench_with_input( + BenchmarkId::new( + "original", + format!("files={num_files},overlap={overlap:.1}"), + ), + &( + file_groups.clone(), + file_schema.clone(), + sort_ordering.clone(), + ), + |b, (fg, schema, order)| { + let mut result = Vec::new(); + b.iter(|| { + result = + FileScanConfig::split_groups_by_statistics(schema, fg, order) + .unwrap(); + }); + assert!(verify_sort_integrity(&result)); + }, + ); + + // Benchmark new algorithm with different target partitions + for &tp in &target_partitions { + group.bench_with_input( + BenchmarkId::new( + format!("v2_partitions={tp}"), + format!("files={num_files},overlap={overlap:.1}"), + ), + &( + file_groups.clone(), + file_schema.clone(), + sort_ordering.clone(), + tp, + ), + |b, (fg, schema, order, target)| { + let mut result = Vec::new(); + b.iter(|| { + result = FileScanConfig::split_groups_by_statistics_with_target_partitions( + schema, fg, order, *target, + ) + .unwrap(); + }); + assert!(verify_sort_integrity(&result)); + }, + ); + } + } + } + + group.finish(); +} + +criterion_group!(benches, compare_split_groups_by_statistics_algorithms); +criterion_main!(benches); diff --git a/datafusion/datasource/src/file.rs b/datafusion/datasource/src/file.rs index 0066f39801a1b..7a2cf403fd8d6 100644 --- a/datafusion/datasource/src/file.rs +++ b/datafusion/datasource/src/file.rs @@ -25,17 +25,32 @@ use std::sync::Arc; use crate::file_groups::FileGroupPartitioner; use crate::file_scan_config::FileScanConfig; use crate::file_stream::FileOpener; +use crate::schema_adapter::SchemaAdapterFactory; use arrow::datatypes::SchemaRef; -use datafusion_common::Statistics; -use datafusion_physical_expr::LexOrdering; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{not_impl_err, Result, Statistics}; +use datafusion_physical_expr::{LexOrdering, PhysicalExpr}; +use datafusion_physical_plan::filter_pushdown::{FilterPushdownPropagation, PushedDown}; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion_physical_plan::DisplayFormatType; use object_store::ObjectStore; -/// Common file format behaviors needs to implement. +/// Helper function to convert any type implementing FileSource to Arc<dyn FileSource> +pub fn as_file_source(source: T) -> Arc { + Arc::new(source) +} + +/// file format specific behaviors for elements in [`DataSource`] +/// +/// See more details on specific implementations: +/// * [`ArrowSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.ArrowSource.html) +/// * [`AvroSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.AvroSource.html) +/// * [`CsvSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.CsvSource.html) +/// * [`JsonSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.JsonSource.html) +/// * [`ParquetSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.ParquetSource.html) /// -/// See implementation examples such as `ParquetSource`, `CsvSource` +/// [`DataSource`]: crate::source::DataSource pub trait FileSource: Send + Sync { /// Creates a `dyn FileOpener` based on given parameters fn create_file_opener( @@ -54,10 +69,14 @@ pub trait FileSource: Send + Sync { fn with_projection(&self, config: &FileScanConfig) -> Arc; /// Initialize new instance with projected statistics fn with_statistics(&self, statistics: Statistics) -> Arc; + /// Returns the filter expression that will be applied during the file scan. + fn filter(&self) -> Option> { + None + } /// Return execution plan metrics fn metrics(&self) -> &ExecutionPlanMetricsSet; /// Return projected statistics - fn statistics(&self) -> datafusion_common::Result; + fn statistics(&self) -> Result; /// String representation of file source such as "csv", "json", "parquet" fn file_type(&self) -> &str; /// Format FileType specific information @@ -65,17 +84,19 @@ pub trait FileSource: Send + Sync { Ok(()) } - /// If supported by the [`FileSource`], redistribute files across partitions according to their size. - /// Allows custom file formats to implement their own repartitioning logic. + /// If supported by the [`FileSource`], redistribute files across partitions + /// according to their size. Allows custom file formats to implement their + /// own repartitioning logic. /// - /// Provides a default repartitioning behavior, see comments on [`FileGroupPartitioner`] for more detail. + /// The default implementation uses [`FileGroupPartitioner`]. See that + /// struct for more details. fn repartitioned( &self, target_partitions: usize, repartition_file_min_size: usize, output_ordering: Option, config: &FileScanConfig, - ) -> datafusion_common::Result> { + ) -> Result> { if config.file_compression_type.is_compressed() || config.new_lines_in_values { return Ok(None); } @@ -93,4 +114,44 @@ pub trait FileSource: Send + Sync { } Ok(None) } + + /// Try to push down filters into this FileSource. + /// See [`ExecutionPlan::handle_child_pushdown_result`] for more details. + /// + /// [`ExecutionPlan::handle_child_pushdown_result`]: datafusion_physical_plan::ExecutionPlan::handle_child_pushdown_result + fn try_pushdown_filters( + &self, + filters: Vec>, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::with_parent_pushdown_result( + vec![PushedDown::No; filters.len()], + )) + } + + /// Set optional schema adapter factory. + /// + /// [`SchemaAdapterFactory`] allows user to specify how fields from the + /// file get mapped to that of the table schema. If you implement this + /// method, you should also implement [`schema_adapter_factory`]. + /// + /// The default implementation returns a not implemented error. + /// + /// [`schema_adapter_factory`]: Self::schema_adapter_factory + fn with_schema_adapter_factory( + &self, + _factory: Arc, + ) -> Result> { + not_impl_err!( + "FileSource {} does not support schema adapter factory", + self.file_type() + ) + } + + /// Returns the current schema adapter factory if set + /// + /// Default implementation returns `None`. + fn schema_adapter_factory(&self) -> Option> { + None + } } diff --git a/datafusion/datasource/src/file_format.rs b/datafusion/datasource/src/file_format.rs index 0e0b7b12e16a0..23f68636c156e 100644 --- a/datafusion/datasource/src/file_format.rs +++ b/datafusion/datasource/src/file_format.rs @@ -28,11 +28,10 @@ use crate::file_compression_type::FileCompressionType; use crate::file_scan_config::FileScanConfig; use crate::file_sink_config::FileSinkConfig; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::SchemaRef; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{internal_err, not_impl_err, GetExt, Result, Statistics}; -use datafusion_expr::Expr; -use datafusion_physical_expr::{LexRequirement, PhysicalExpr}; +use datafusion_physical_expr::LexRequirement; use datafusion_physical_plan::ExecutionPlan; use datafusion_session::Session; @@ -49,7 +48,7 @@ pub const DEFAULT_SCHEMA_INFER_MAX_RECORD: usize = 1000; /// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProvider.html #[async_trait] pub trait FileFormat: Send + Sync + fmt::Debug { - /// Returns the table provider as [`Any`](std::any::Any) so that it can be + /// Returns the table provider as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -62,6 +61,9 @@ pub trait FileFormat: Send + Sync + fmt::Debug { _file_compression_type: &FileCompressionType, ) -> Result; + /// Returns whether this instance uses compression if applicable + fn compression_type(&self) -> Option; + /// Infer the common schema of the provided objects. The objects will usually /// be analysed up to a given number of records or files (as specified in the /// format config) then give the estimated common schema. This might fail if @@ -94,7 +96,6 @@ pub trait FileFormat: Send + Sync + fmt::Debug { &self, state: &dyn Session, conf: FileScanConfig, - filters: Option<&Arc>, ) -> Result>; /// Take a list of files and the configuration to convert it to the @@ -109,37 +110,10 @@ pub trait FileFormat: Send + Sync + fmt::Debug { not_impl_err!("Writer not implemented for this format") } - /// Check if the specified file format has support for pushing down the provided filters within - /// the given schemas. Added initially to support the Parquet file format's ability to do this. - fn supports_filters_pushdown( - &self, - _file_schema: &Schema, - _table_schema: &Schema, - _filters: &[&Expr], - ) -> Result { - Ok(FilePushdownSupport::NoSupport) - } - /// Return the related FileSource such as `CsvSource`, `JsonSource`, etc. fn file_source(&self) -> Arc; } -/// An enum to distinguish between different states when determining if certain filters can be -/// pushed down to file scanning -#[derive(Debug, PartialEq)] -pub enum FilePushdownSupport { - /// The file format/system being asked does not support any sort of pushdown. This should be - /// used even if the file format theoretically supports some sort of pushdown, but it's not - /// enabled or implemented yet. - NoSupport, - /// The file format/system being asked *does* support pushdown, but it can't make it work for - /// the provided filter/expression - NotSupportedForFilter, - /// The file format/system being asked *does* support pushdown and *can* make it work for the - /// provided filter/expression - Supported, -} - /// Factory for creating [`FileFormat`] instances based on session and command level options /// /// Users can provide their own `FileFormatFactory` to support arbitrary file formats diff --git a/datafusion/datasource/src/file_groups.rs b/datafusion/datasource/src/file_groups.rs index 75c4160f145ed..998d09285cf1d 100644 --- a/datafusion/datasource/src/file_groups.rs +++ b/datafusion/datasource/src/file_groups.rs @@ -20,11 +20,12 @@ use crate::{FileRange, PartitionedFile}; use datafusion_common::Statistics; use itertools::Itertools; -use std::cmp::min; +use std::cmp::{min, Ordering}; use std::collections::BinaryHeap; use std::iter::repeat_with; use std::mem; -use std::ops::{Index, IndexMut}; +use std::ops::{Deref, DerefMut, Index, IndexMut}; +use std::sync::Arc; /// Repartition input files into `target_partitions` partitions, if total file size exceed /// `repartition_file_min_size` @@ -223,10 +224,11 @@ impl FileGroupPartitioner { return None; } - let target_partition_size = (total_size as usize).div_ceil(target_partitions); + let target_partition_size = + (total_size as u64).div_ceil(target_partitions as u64); let current_partition_index: usize = 0; - let current_partition_size: usize = 0; + let current_partition_size: u64 = 0; // Partition byte range evenly for all `PartitionedFile`s let repartitioned_files = flattened_files @@ -302,6 +304,7 @@ impl FileGroupPartitioner { None } }) + .map(CompareByRangeSize) .collect(); // No files can be redistributed @@ -335,7 +338,7 @@ impl FileGroupPartitioner { source_index, file_size, new_groups, - } = to_repartition; + } = to_repartition.into_inner(); assert_eq!(file_groups[source_index].len(), 1); let original_file = file_groups[source_index].pop().unwrap(); @@ -368,7 +371,7 @@ pub struct FileGroup { /// The files in this group files: Vec, /// Optional statistics for the data across all files in the group - statistics: Option, + statistics: Option>, } impl FileGroup { @@ -386,7 +389,7 @@ impl FileGroup { } /// Set the statistics for this group - pub fn with_statistics(mut self, statistics: Statistics) -> Self { + pub fn with_statistics(mut self, statistics: Arc) -> Self { self.statistics = Some(statistics); self } @@ -414,13 +417,23 @@ impl FileGroup { } /// Adds a file to the group - pub fn push(&mut self, file: PartitionedFile) { - self.files.push(file); + pub fn push(&mut self, partitioned_file: PartitionedFile) { + self.files.push(partitioned_file); } - /// Get the statistics for this group - pub fn statistics(&self) -> Option<&Statistics> { - self.statistics.as_ref() + /// Get the specific file statistics for the given index + /// If the index is None, return the `FileGroup` statistics + pub fn file_statistics(&self, index: Option) -> Option<&Statistics> { + if let Some(index) = index { + self.files.get(index).and_then(|f| f.statistics.as_deref()) + } else { + self.statistics.as_deref() + } + } + + /// Get the mutable reference to the statistics for this group + pub fn statistics_mut(&mut self) -> Option<&mut Statistics> { + self.statistics.as_mut().map(Arc::make_mut) } /// Partition the list of files into `n` groups @@ -491,33 +504,55 @@ impl Default for FileGroup { } /// Tracks how a individual file will be repartitioned -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone)] struct ToRepartition { /// the index from which the original file will be taken source_index: usize, /// the size of the original file - file_size: usize, + file_size: u64, /// indexes of which group(s) will this be distributed to (including `source_index`) new_groups: Vec, } impl ToRepartition { - // how big will each file range be when this file is read in its new groups? - fn range_size(&self) -> usize { - self.file_size / self.new_groups.len() + /// How big will each file range be when this file is read in its new groups? + fn range_size(&self) -> u64 { + self.file_size / (self.new_groups.len() as u64) } } -impl PartialOrd for ToRepartition { - fn partial_cmp(&self, other: &Self) -> Option { +struct CompareByRangeSize(ToRepartition); +impl CompareByRangeSize { + fn into_inner(self) -> ToRepartition { + self.0 + } +} +impl Ord for CompareByRangeSize { + fn cmp(&self, other: &Self) -> Ordering { + self.0.range_size().cmp(&other.0.range_size()) + } +} +impl PartialOrd for CompareByRangeSize { + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } - -/// Order based on individual range -impl Ord for ToRepartition { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.range_size().cmp(&other.range_size()) +impl PartialEq for CompareByRangeSize { + fn eq(&self, other: &Self) -> bool { + // PartialEq must be consistent with PartialOrd + self.cmp(other) == Ordering::Equal + } +} +impl Eq for CompareByRangeSize {} +impl Deref for CompareByRangeSize { + type Target = ToRepartition; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for CompareByRangeSize { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } @@ -951,8 +986,8 @@ mod test { (Some(_), None) => panic!("Expected Some, got None"), (None, Some(_)) => panic!("Expected None, got Some"), (Some(expected), Some(actual)) => { - let expected_string = format!("{:#?}", expected); - let actual_string = format!("{:#?}", actual); + let expected_string = format!("{expected:#?}"); + let actual_string = format!("{actual:#?}"); assert_eq!(expected_string, actual_string); } } diff --git a/datafusion/datasource/src/file_meta.rs b/datafusion/datasource/src/file_meta.rs deleted file mode 100644 index 098a15eeb38a2..0000000000000 --- a/datafusion/datasource/src/file_meta.rs +++ /dev/null @@ -1,52 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::Arc; - -use object_store::{path::Path, ObjectMeta}; - -use crate::FileRange; - -/// A single file or part of a file that should be read, along with its schema, statistics -pub struct FileMeta { - /// Path for the file (e.g. URL, filesystem path, etc) - pub object_meta: ObjectMeta, - /// An optional file range for a more fine-grained parallel execution - pub range: Option, - /// An optional field for user defined per object metadata - pub extensions: Option>, - /// Size hint for the metadata of this file - pub metadata_size_hint: Option, -} - -impl FileMeta { - /// The full path to the object - pub fn location(&self) -> &Path { - &self.object_meta.location - } -} - -impl From for FileMeta { - fn from(object_meta: ObjectMeta) -> Self { - Self { - object_meta, - range: None, - extensions: None, - metadata_size_hint: None, - } - } -} diff --git a/datafusion/datasource/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs index 729283289cafc..e67e1f8273723 100644 --- a/datafusion/datasource/src/file_scan_config.rs +++ b/datafusion/datasource/src/file_scan_config.rs @@ -18,11 +18,15 @@ //! [`FileScanConfig`] to configure scanning of possibly partitioned //! file sources. -use std::{ - any::Any, borrow::Cow, collections::HashMap, fmt::Debug, fmt::Formatter, - fmt::Result as FmtResult, marker::PhantomData, sync::Arc, +use crate::file_groups::FileGroup; +#[allow(unused_imports)] +use crate::schema_adapter::SchemaAdapterFactory; +use crate::{ + display::FileGroupsDisplay, file::FileSource, + file_compression_type::FileCompressionType, file_stream::FileStream, + source::DataSource, statistics::MinMaxStatistics, PartitionedFile, }; - +use arrow::datatypes::FieldRef; use arrow::{ array::{ ArrayData, ArrayRef, BufferBuilder, DictionaryArray, RecordBatch, @@ -31,38 +35,43 @@ use arrow::{ buffer::Buffer, datatypes::{ArrowNativeType, DataType, Field, Schema, SchemaRef, UInt16Type}, }; -use datafusion_common::{exec_err, ColumnStatistics, Constraints, Result, Statistics}; -use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{ + exec_datafusion_err, exec_err, internal_datafusion_err, ColumnStatistics, + Constraints, Result, ScalarValue, Statistics, +}; use datafusion_execution::{ object_store::ObjectStoreUrl, SendableRecordBatchStream, TaskContext, }; -use datafusion_physical_expr::{ - expressions::Column, EquivalenceProperties, LexOrdering, Partitioning, - PhysicalSortExpr, -}; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::BinaryExpr; +use datafusion_physical_expr::{expressions::Column, utils::reassign_expr_columns}; +use datafusion_physical_expr::{split_conjunction, EquivalenceProperties, Partitioning}; +use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::projection::ProjectionExpr; use datafusion_physical_plan::{ display::{display_orderings, ProjectSchemaDisplay}, + filter_pushdown::FilterPushdownPropagation, metrics::ExecutionPlanMetricsSet, - projection::{all_alias_free_columns, new_projections_for_columns, ProjectionExec}, - DisplayAs, DisplayFormatType, ExecutionPlan, + projection::{all_alias_free_columns, new_projections_for_columns}, + DisplayAs, DisplayFormatType, }; -use log::{debug, warn}; - -use crate::file_groups::FileGroup; -use crate::{ - display::FileGroupsDisplay, - file::FileSource, - file_compression_type::FileCompressionType, - file_stream::FileStream, - source::{DataSource, DataSourceExec}, - statistics::MinMaxStatistics, - PartitionedFile, +use std::{ + any::Any, borrow::Cow, collections::HashMap, fmt::Debug, fmt::Formatter, + fmt::Result as FmtResult, marker::PhantomData, sync::Arc, }; +use datafusion_physical_expr::equivalence::project_orderings; +use datafusion_physical_plan::coop::cooperative; +use datafusion_physical_plan::execution_plan::SchedulingType; +use log::{debug, warn}; + /// The base configurations for a [`DataSourceExec`], the a physical plan for /// any given file format. /// -/// Use [`Self::build`] to create a [`DataSourceExec`] from a ``FileScanConfig`. +/// Use [`DataSourceExec::from_data_source`] to create a [`DataSourceExec`] from a ``FileScanConfig`. /// /// # Example /// ``` @@ -71,6 +80,7 @@ use crate::{ /// # use arrow::datatypes::{Field, Fields, DataType, Schema, SchemaRef}; /// # use object_store::ObjectStore; /// # use datafusion_common::Statistics; +/// # use datafusion_common::Result; /// # use datafusion_datasource::file::FileSource; /// # use datafusion_datasource::file_groups::FileGroup; /// # use datafusion_datasource::PartitionedFile; @@ -80,6 +90,7 @@ use crate::{ /// # use datafusion_execution::object_store::ObjectStoreUrl; /// # use datafusion_physical_plan::ExecutionPlan; /// # use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +/// # use datafusion_datasource::schema_adapter::SchemaAdapterFactory; /// # let file_schema = Arc::new(Schema::new(vec![ /// # Field::new("c1", DataType::Int32, false), /// # Field::new("c2", DataType::Int32, false), @@ -87,22 +98,26 @@ use crate::{ /// # Field::new("c4", DataType::Int32, false), /// # ])); /// # // Note: crate mock ParquetSource, as ParquetSource is not in the datasource crate +/// #[derive(Clone)] /// # struct ParquetSource { -/// # projected_statistics: Option +/// # projected_statistics: Option, +/// # schema_adapter_factory: Option> /// # }; /// # impl FileSource for ParquetSource { /// # fn create_file_opener(&self, _: Arc, _: &FileScanConfig, _: usize) -> Arc { unimplemented!() } /// # fn as_any(&self) -> &dyn Any { self } /// # fn with_batch_size(&self, _: usize) -> Arc { unimplemented!() } -/// # fn with_schema(&self, _: SchemaRef) -> Arc { unimplemented!() } +/// # fn with_schema(&self, _: SchemaRef) -> Arc { Arc::new(self.clone()) as Arc } /// # fn with_projection(&self, _: &FileScanConfig) -> Arc { unimplemented!() } -/// # fn with_statistics(&self, statistics: Statistics) -> Arc { Arc::new(Self {projected_statistics: Some(statistics)} ) } +/// # fn with_statistics(&self, statistics: Statistics) -> Arc { Arc::new(Self {projected_statistics: Some(statistics), schema_adapter_factory: self.schema_adapter_factory.clone()} ) } /// # fn metrics(&self) -> &ExecutionPlanMetricsSet { unimplemented!() } -/// # fn statistics(&self) -> datafusion_common::Result { Ok(self.projected_statistics.clone().expect("projected_statistics should be set")) } +/// # fn statistics(&self) -> Result { Ok(self.projected_statistics.clone().expect("projected_statistics should be set")) } /// # fn file_type(&self) -> &str { "parquet" } +/// # fn with_schema_adapter_factory(&self, factory: Arc) -> Result> { Ok(Arc::new(Self {projected_statistics: self.projected_statistics.clone(), schema_adapter_factory: Some(factory)} )) } +/// # fn schema_adapter_factory(&self) -> Option> { self.schema_adapter_factory.clone() } /// # } /// # impl ParquetSource { -/// # fn new() -> Self { Self {projected_statistics: None} } +/// # fn new() -> Self { Self {projected_statistics: None, schema_adapter_factory: None} } /// # } /// // create FileScan config for reading parquet files from file:// /// let object_store_url = ObjectStoreUrl::local_filesystem(); @@ -121,6 +136,9 @@ use crate::{ /// // create an execution plan from the config /// let plan: Arc = DataSourceExec::from_data_source(config); /// ``` +/// +/// [`DataSourceExec`]: crate::source::DataSourceExec +/// [`DataSourceExec::from_data_source`]: crate::source::DataSourceExec::from_data_source #[derive(Clone)] pub struct FileScanConfig { /// Object store URL, used to get an [`ObjectStore`] instance from @@ -138,6 +156,11 @@ pub struct FileScanConfig { /// Schema before `projection` is applied. It contains the all columns that may /// appear in the files. It does not include table partition columns /// that may be added. + /// Note that this is **not** the schema of the physical files. + /// This is the schema that the physical file schema will be + /// mapped onto, and the schema that the [`DataSourceExec`] will return. + /// + /// [`DataSourceExec`]: crate::source::DataSourceExec pub file_schema: SchemaRef, /// List of files to be processed, grouped into partitions /// @@ -151,9 +174,6 @@ pub struct FileScanConfig { pub file_groups: Vec, /// Table constraints pub constraints: Constraints, - /// Estimated overall statistics of the files, taking `filters` into account. - /// Defaults to [`Statistics::new_unknown`]. - pub statistics: Statistics, /// Columns on which to project the data. Indexes that are higher than the /// number of columns of `file_schema` refer to `table_partition_cols`. pub projection: Option>, @@ -161,7 +181,7 @@ pub struct FileScanConfig { /// all records after filtering are returned. pub limit: Option, /// The partitioning columns - pub table_partition_cols: Vec, + pub table_partition_cols: Vec, /// All equivalent lexicographical orderings that describe the schema. pub output_ordering: Vec, /// File compression type @@ -173,6 +193,9 @@ pub struct FileScanConfig { /// Batch size while creating new batches /// Defaults to [`datafusion_common::config::ExecutionOptions`] batch_size. pub batch_size: Option, + /// Expression adapter used to adapt filters and projections that are pushed down into the scan + /// from the logical schema to the physical schema of the file. + pub expr_adapter_factory: Option>, } /// A builder for [`FileScanConfig`]'s. @@ -227,12 +250,23 @@ pub struct FileScanConfig { #[derive(Clone)] pub struct FileScanConfigBuilder { object_store_url: ObjectStoreUrl, + /// Table schema before any projections or partition columns are applied. + /// + /// This schema is used to read the files, but is **not** necessarily the + /// schema of the physical files. Rather this is the schema that the + /// physical file schema will be mapped onto, and the schema that the + /// [`DataSourceExec`] will return. + /// + /// This is usually the same as the table schema as specified by the `TableProvider` minus any partition columns. + /// + /// This probably would be better named `table_schema` + /// + /// [`DataSourceExec`]: crate::source::DataSourceExec file_schema: SchemaRef, file_source: Arc, - limit: Option, projection: Option>, - table_partition_cols: Vec, + table_partition_cols: Vec, constraints: Option, file_groups: Vec, statistics: Option, @@ -240,6 +274,7 @@ pub struct FileScanConfigBuilder { file_compression_type: Option, new_lines_in_values: Option, batch_size: Option, + expr_adapter_factory: Option>, } impl FileScanConfigBuilder { @@ -268,6 +303,7 @@ impl FileScanConfigBuilder { table_partition_cols: vec![], constraints: None, batch_size: None, + expr_adapter_factory: None, } } @@ -296,7 +332,10 @@ impl FileScanConfigBuilder { /// Set the partitioning columns pub fn with_table_partition_cols(mut self, table_partition_cols: Vec) -> Self { - self.table_partition_cols = table_partition_cols; + self.table_partition_cols = table_partition_cols + .into_iter() + .map(|f| Arc::new(f) as FieldRef) + .collect(); self } @@ -338,8 +377,8 @@ impl FileScanConfigBuilder { /// Add a file as a single group /// /// See [`Self::with_file_groups`] for more information. - pub fn with_file(self, file: PartitionedFile) -> Self { - self.with_file_group(FileGroup::new(vec![file])) + pub fn with_file(self, partitioned_file: PartitionedFile) -> Self { + self.with_file_group(FileGroup::new(vec![partitioned_file])) } /// Set the output ordering of the files @@ -373,6 +412,20 @@ impl FileScanConfigBuilder { self } + /// Register an expression adapter used to adapt filters and projections that are pushed down into the scan + /// from the logical schema to the physical schema of the file. + /// This can include things like: + /// - Column ordering changes + /// - Handling of missing columns + /// - Rewriting expression to use pre-computed values or file format specific optimizations + pub fn with_expr_adapter( + mut self, + expr_adapter: Option>, + ) -> Self { + self.expr_adapter_factory = expr_adapter; + self + } + /// Build the final [`FileScanConfig`] with all the configured settings. /// /// This method takes ownership of the builder and returns the constructed `FileScanConfig`. @@ -392,13 +445,16 @@ impl FileScanConfigBuilder { file_compression_type, new_lines_in_values, batch_size, + expr_adapter_factory: expr_adapter, } = self; let constraints = constraints.unwrap_or_default(); let statistics = statistics.unwrap_or_else(|| Statistics::new_unknown(&file_schema)); - let file_source = file_source.with_statistics(statistics.clone()); + let file_source = file_source + .with_statistics(statistics.clone()) + .with_schema(Arc::clone(&file_schema)); let file_compression_type = file_compression_type.unwrap_or(FileCompressionType::UNCOMPRESSED); let new_lines_in_values = new_lines_in_values.unwrap_or(false); @@ -412,11 +468,11 @@ impl FileScanConfigBuilder { table_partition_cols, constraints, file_groups, - statistics, output_ordering, file_compression_type, new_lines_in_values, batch_size, + expr_adapter_factory: expr_adapter, } } } @@ -426,9 +482,9 @@ impl From for FileScanConfigBuilder { Self { object_store_url: config.object_store_url, file_schema: config.file_schema, - file_source: config.file_source, + file_source: Arc::::clone(&config.file_source), file_groups: config.file_groups, - statistics: Some(config.statistics), + statistics: config.file_source.statistics().ok(), output_ordering: config.output_ordering, file_compression_type: Some(config.file_compression_type), new_lines_in_values: Some(config.new_lines_in_values), @@ -437,6 +493,7 @@ impl From for FileScanConfigBuilder { table_partition_cols: config.table_partition_cols, constraints: Some(config.constraints), batch_size: config.batch_size, + expr_adapter_factory: config.expr_adapter_factory, } } } @@ -455,13 +512,12 @@ impl DataSource for FileScanConfig { let source = self .file_source .with_batch_size(batch_size) - .with_schema(Arc::clone(&self.file_schema)) .with_projection(self); let opener = source.create_file_opener(object_store, self, partition); let stream = FileStream::new(self, partition, opener, source.metrics())?; - Ok(Box::pin(stream)) + Ok(Box::pin(cooperative(stream))) } fn as_any(&self) -> &dyn Any { @@ -471,7 +527,8 @@ impl DataSource for FileScanConfig { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - let (schema, _, _, orderings) = self.project(); + let schema = self.projected_schema(); + let orderings = get_projected_output_ordering(self, &schema); write!(f, "file_groups=")?; FileGroupsDisplay(&self.file_groups).fmt_as(t, f)?; @@ -525,8 +582,26 @@ impl DataSource for FileScanConfig { fn eq_properties(&self) -> EquivalenceProperties { let (schema, constraints, _, orderings) = self.project(); - EquivalenceProperties::new_with_orderings(schema, orderings.as_slice()) - .with_constraints(constraints) + let mut eq_properties = + EquivalenceProperties::new_with_orderings(Arc::clone(&schema), orderings) + .with_constraints(constraints); + if let Some(filter) = self.file_source.filter() { + // We need to remap column indexes to match the projected schema since that's what the equivalence properties deal with. + // Note that this will *ignore* any non-projected columns: these don't factor into ordering / equivalence. + match Self::add_filter_equivalence_info(filter, &mut eq_properties, &schema) { + Ok(()) => {} + Err(e) => { + warn!("Failed to add filter equivalence info: {e}"); + #[cfg(debug_assertions)] + panic!("Failed to add filter equivalence info: {e}"); + } + } + } + eq_properties + } + + fn scheduling_type(&self) -> SchedulingType { + SchedulingType::Cooperative } fn statistics(&self) -> Result { @@ -550,20 +625,22 @@ impl DataSource for FileScanConfig { fn try_swapping_with_projection( &self, - projection: &ProjectionExec, - ) -> Result>> { + projection: &[ProjectionExpr], + ) -> Result>> { // This process can be moved into CsvExec, but it would be an overlap of their responsibility. // Must be all column references, with no table partition columns (which can not be projected) - let partitioned_columns_in_proj = projection.expr().iter().any(|(expr, _)| { - expr.as_any() + let partitioned_columns_in_proj = projection.iter().any(|proj_expr| { + proj_expr + .expr + .as_any() .downcast_ref::() .map(|expr| expr.index() >= self.file_schema.fields().len()) .unwrap_or(false) }); // If there is any non-column or alias-carrier expression, Projection should not be removed. - let no_aliases = all_alias_free_columns(projection.expr()); + let no_aliases = all_alias_free_columns(projection); Ok((no_aliases && !partitioned_columns_in_proj).then(|| { let file_scan = self.clone(); @@ -573,9 +650,10 @@ impl DataSource for FileScanConfig { &file_scan .projection .clone() - .unwrap_or((0..self.file_schema.fields().len()).collect()), + .unwrap_or_else(|| (0..self.file_schema.fields().len()).collect()), ); - DataSourceExec::from_data_source( + + Arc::new( FileScanConfigBuilder::from(file_scan) // Assign projected statistics to source .with_projection(Some(new_projections)) @@ -584,66 +662,35 @@ impl DataSource for FileScanConfig { ) as _ })) } -} -impl FileScanConfig { - /// Create a new [`FileScanConfig`] with default settings for scanning files. - /// - /// See example on [`FileScanConfig`] - /// - /// No file groups are added by default. See [`Self::with_file`], [`Self::with_file_group`] and - /// [`Self::with_file_groups`]. - /// - /// # Parameters: - /// * `object_store_url`: See [`Self::object_store_url`] - /// * `file_schema`: See [`Self::file_schema`] - #[allow(deprecated)] // `new` will be removed same time as `with_source` - pub fn new( - object_store_url: ObjectStoreUrl, - file_schema: SchemaRef, - file_source: Arc, - ) -> Self { - let statistics = Statistics::new_unknown(&file_schema); - let file_source = file_source.with_statistics(statistics.clone()); - Self { - object_store_url, - file_schema, - file_groups: vec![], - constraints: Constraints::empty(), - statistics, - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - file_compression_type: FileCompressionType::UNCOMPRESSED, - new_lines_in_values: false, - file_source: Arc::clone(&file_source), - batch_size: None, + fn try_pushdown_filters( + &self, + filters: Vec>, + config: &ConfigOptions, + ) -> Result>> { + let result = self.file_source.try_pushdown_filters(filters, config)?; + match result.updated_node { + Some(new_file_source) => { + let file_scan_config = FileScanConfigBuilder::from(self.clone()) + .with_source(new_file_source) + .build(); + Ok(FilterPushdownPropagation { + filters: result.filters, + updated_node: Some(Arc::new(file_scan_config) as _), + }) + } + None => { + // If the file source does not support filter pushdown, return the original config + Ok(FilterPushdownPropagation { + filters: result.filters, + updated_node: None, + }) + } } } +} - /// Set the file source - #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] - pub fn with_source(mut self, file_source: Arc) -> Self { - self.file_source = file_source.with_statistics(self.statistics.clone()); - self - } - - /// Set the table constraints of the files - #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] - pub fn with_constraints(mut self, constraints: Constraints) -> Self { - self.constraints = constraints; - self - } - - /// Set the statistics of the files - #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] - pub fn with_statistics(mut self, statistics: Statistics) -> Self { - self.statistics = statistics.clone(); - self.file_source = self.file_source.with_statistics(statistics); - self - } - +impl FileScanConfig { fn projection_indices(&self) -> Vec { match &self.projection { Some(proj) => proj.clone(), @@ -653,11 +700,8 @@ impl FileScanConfig { } } - fn projected_stats(&self) -> Statistics { - let statistics = self - .file_source - .statistics() - .unwrap_or(self.statistics.clone()); + pub fn projected_stats(&self) -> Statistics { + let statistics = self.file_source.statistics().unwrap(); let table_cols_stats = self .projection_indices() @@ -680,7 +724,7 @@ impl FileScanConfig { } } - fn projected_schema(&self) -> Arc { + pub fn projected_schema(&self) -> Arc { let table_fields: Vec<_> = self .projection_indices() .into_iter() @@ -689,7 +733,9 @@ impl FileScanConfig { self.file_schema.field(idx).clone() } else { let partition_idx = idx - self.file_schema.fields().len(); - self.table_partition_cols[partition_idx].clone() + Arc::unwrap_or_clone(Arc::clone( + &self.table_partition_cols[partition_idx], + )) } }) .collect(); @@ -700,91 +746,35 @@ impl FileScanConfig { )) } - fn projected_constraints(&self) -> Constraints { - let indexes = self.projection_indices(); - - self.constraints - .project(&indexes) - .unwrap_or_else(Constraints::empty) - } - - /// Set the projection of the files - #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] - pub fn with_projection(mut self, projection: Option>) -> Self { - self.projection = projection; - self - } - - /// Set the limit of the files - #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] - pub fn with_limit(mut self, limit: Option) -> Self { - self.limit = limit; - self - } - - /// Add a file as a single group - /// - /// See [Self::file_groups] for more information. - #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] - #[allow(deprecated)] - pub fn with_file(self, file: PartitionedFile) -> Self { - self.with_file_group(FileGroup::new(vec![file])) - } - - /// Add the file groups - /// - /// See [Self::file_groups] for more information. - #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] - pub fn with_file_groups(mut self, mut file_groups: Vec) -> Self { - self.file_groups.append(&mut file_groups); - self - } - - /// Add a new file group - /// - /// See [Self::file_groups] for more information - #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] - pub fn with_file_group(mut self, file_group: FileGroup) -> Self { - self.file_groups.push(file_group); - self - } - - /// Set the partitioning columns of the files - #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] - pub fn with_table_partition_cols(mut self, table_partition_cols: Vec) -> Self { - self.table_partition_cols = table_partition_cols; - self - } - - /// Set the output ordering of the files - #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] - pub fn with_output_ordering(mut self, output_ordering: Vec) -> Self { - self.output_ordering = output_ordering; - self - } + fn add_filter_equivalence_info( + filter: Arc, + eq_properties: &mut EquivalenceProperties, + schema: &Schema, + ) -> Result<()> { + // Gather valid equality pairs from the filter expression + let equal_pairs = split_conjunction(&filter).into_iter().filter_map(|expr| { + // Ignore any binary expressions that reference non-existent columns in the current schema + // (e.g. due to unnecessary projections being removed) + reassign_expr_columns(Arc::clone(expr), schema) + .ok() + .and_then(|expr| match expr.as_any().downcast_ref::() { + Some(expr) if expr.op() == &Operator::Eq => { + Some((Arc::clone(expr.left()), Arc::clone(expr.right()))) + } + _ => None, + }) + }); - /// Set the file compression type - #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] - pub fn with_file_compression_type( - mut self, - file_compression_type: FileCompressionType, - ) -> Self { - self.file_compression_type = file_compression_type; - self - } + for (lhs, rhs) in equal_pairs { + eq_properties.add_equal_conditions(lhs, rhs)? + } - /// Set the new_lines_in_values property - #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] - pub fn with_newlines_in_values(mut self, new_lines_in_values: bool) -> Self { - self.new_lines_in_values = new_lines_in_values; - self + Ok(()) } - /// Set the batch_size property - #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] - pub fn with_batch_size(mut self, batch_size: Option) -> Self { - self.batch_size = batch_size; - self + pub fn projected_constraints(&self) -> Constraints { + let indexes = self.projection_indices(); + self.constraints.project(&indexes).unwrap_or_default() } /// Specifies whether newlines in (quoted) values are supported. @@ -804,7 +794,7 @@ impl FileScanConfig { return ( Arc::clone(&self.file_schema), self.constraints.clone(), - self.statistics.clone(), + self.file_source.statistics().unwrap().clone(), self.output_ordering.clone(), ); } @@ -858,6 +848,96 @@ impl FileScanConfig { }) } + /// Splits file groups into new groups based on statistics to enable efficient parallel processing. + /// + /// The method distributes files across a target number of partitions while ensuring + /// files within each partition maintain sort order based on their min/max statistics. + /// + /// The algorithm works by: + /// 1. Takes files sorted by minimum values + /// 2. For each file: + /// - Finds eligible groups (empty or where file's min > group's last max) + /// - Selects the smallest eligible group + /// - Creates a new group if needed + /// + /// # Parameters + /// * `table_schema`: Schema containing information about the columns + /// * `file_groups`: The original file groups to split + /// * `sort_order`: The lexicographical ordering to maintain within each group + /// * `target_partitions`: The desired number of output partitions + /// + /// # Returns + /// A new set of file groups, where files within each group are non-overlapping with respect to + /// their min/max statistics and maintain the specified sort order. + pub fn split_groups_by_statistics_with_target_partitions( + table_schema: &SchemaRef, + file_groups: &[FileGroup], + sort_order: &LexOrdering, + target_partitions: usize, + ) -> Result> { + if target_partitions == 0 { + return Err(internal_datafusion_err!( + "target_partitions must be greater than 0" + )); + } + + let flattened_files = file_groups + .iter() + .flat_map(FileGroup::iter) + .collect::>(); + + if flattened_files.is_empty() { + return Ok(vec![]); + } + + let statistics = MinMaxStatistics::new_from_files( + sort_order, + table_schema, + None, + flattened_files.iter().copied(), + )?; + + let indices_sorted_by_min = statistics.min_values_sorted(); + + // Initialize with target_partitions empty groups + let mut file_groups_indices: Vec> = vec![vec![]; target_partitions]; + + for (idx, min) in indices_sorted_by_min { + if let Some((_, group)) = file_groups_indices + .iter_mut() + .enumerate() + .filter(|(_, group)| { + group.is_empty() + || min + > statistics + .max(*group.last().expect("groups should not be empty")) + }) + .min_by_key(|(_, group)| group.len()) + { + group.push(idx); + } else { + // Create a new group if no existing group fits + file_groups_indices.push(vec![idx]); + } + } + + // Remove any empty groups + file_groups_indices.retain(|group| !group.is_empty()); + + // Assemble indices back into groups of PartitionedFiles + Ok(file_groups_indices + .into_iter() + .map(|file_group_indices| { + FileGroup::new( + file_group_indices + .into_iter() + .map(|idx| flattened_files[idx].clone()) + .collect(), + ) + }) + .collect()) + } + /// Attempts to do a bin-packing on files into file groups, such that any two files /// in a file group are ordered and non-overlapping with respect to their statistics. /// It will produce the smallest number of file groups possible. @@ -926,12 +1006,6 @@ impl FileScanConfig { .collect()) } - /// Returns a new [`DataSourceExec`] to scan the files specified by this config - #[deprecated(since = "47.0.0", note = "use DataSourceExec::new instead")] - pub fn build(self) -> Arc { - DataSourceExec::from_data_source(self) - } - /// Write the data_type based on file_source fn fmt_file_source(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { write!(f, ", file_type={}", self.file_source.file_type())?; @@ -949,7 +1023,11 @@ impl Debug for FileScanConfig { write!(f, "FileScanConfig {{")?; write!(f, "object_store_url={:?}, ", self.object_store_url)?; - write!(f, "statistics={:?}, ", self.statistics)?; + write!( + f, + "statistics={:?}, ", + self.file_source.statistics().unwrap() + )?; DisplayAs::fmt_as(self, DisplayFormatType::Verbose, f)?; write!(f, "}}") @@ -958,7 +1036,8 @@ impl Debug for FileScanConfig { impl DisplayAs for FileScanConfig { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { - let (schema, _, _, orderings) = self.project(); + let schema = self.projected_schema(); + let orderings = get_projected_output_ordering(self, &schema); write!(f, "file_groups=")?; FileGroupsDisplay(&self.file_groups).fmt_as(t, f)?; @@ -1044,12 +1123,9 @@ impl PartitionColumnProjector { let mut cols = file_batch.columns().to_vec(); for &(pidx, sidx) in &self.projected_partition_indexes { - let p_value = - partition_values - .get(pidx) - .ok_or(DataFusionError::Execution( - "Invalid partitioning found on disk".to_string(), - ))?; + let p_value = partition_values.get(pidx).ok_or_else(|| { + exec_datafusion_err!("Invalid partitioning found on disk") + })?; let mut partition_value = Cow::Borrowed(p_value); @@ -1292,32 +1368,11 @@ fn get_projected_output_ordering( base_config: &FileScanConfig, projected_schema: &SchemaRef, ) -> Vec { - let mut all_orderings = vec![]; - for output_ordering in &base_config.output_ordering { - let mut new_ordering = LexOrdering::default(); - for PhysicalSortExpr { expr, options } in output_ordering.iter() { - if let Some(col) = expr.as_any().downcast_ref::() { - let name = col.name(); - if let Some((idx, _)) = projected_schema.column_with_name(name) { - // Compute the new sort expression (with correct index) after projection: - new_ordering.push(PhysicalSortExpr { - expr: Arc::new(Column::new(name, idx)), - options: *options, - }); - continue; - } - } - // Cannot find expression in the projected_schema, stop iterating - // since rest of the orderings are violated - break; - } - - // do not push empty entries - // otherwise we may have `Some(vec![])` at the output ordering. - if new_ordering.is_empty() { - continue; - } + let projected_orderings = + project_orderings(&base_config.output_ordering, projected_schema); + let mut all_orderings = vec![]; + for new_ordering in projected_orderings { // Check if any file groups are not sorted if base_config.file_groups.iter().any(|group| { if group.len() <= 1 { @@ -1377,38 +1432,20 @@ pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { #[cfg(test)] mod tests { - use crate::{test_util::MockSource, tests::aggr_test_schema}; - use super::*; - use arrow::{ - array::{Int32Array, RecordBatch}, - compute::SortOptions, + use crate::test_util::col; + use crate::{ + generate_test_files, test_util::MockSource, tests::aggr_test_schema, + verify_sort_integrity, }; + use arrow::array::{Int32Array, RecordBatch}; use datafusion_common::stats::Precision; - use datafusion_common::{assert_batches_eq, DFSchema}; - use datafusion_expr::{execution_props::ExecutionProps, SortExpr}; - use datafusion_physical_expr::create_physical_expr; - use std::collections::HashMap; - - fn create_physical_sort_expr( - e: &SortExpr, - input_dfschema: &DFSchema, - execution_props: &ExecutionProps, - ) -> Result { - let SortExpr { - expr, - asc, - nulls_first, - } = e; - Ok(PhysicalSortExpr { - expr: create_physical_expr(expr, input_dfschema, execution_props)?, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, - }) - } + use datafusion_common::{assert_batches_eq, internal_err}; + use datafusion_expr::{Operator, SortExpr}; + use datafusion_physical_expr::create_physical_sort_expr; + use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; /// Returns the column names on the schema pub fn columns(schema: &Schema) -> Vec { @@ -1468,7 +1505,7 @@ mod tests { ); // verify the proj_schema includes the last column and exactly the same the field it is defined - let (proj_schema, _, _, _) = conf.project(); + let proj_schema = conf.projected_schema(); assert_eq!(proj_schema.fields().len(), file_schema.fields().len() + 1); assert_eq!( *proj_schema.field(file_schema.fields().len()), @@ -1574,7 +1611,7 @@ mod tests { assert_eq!(source_statistics, statistics); assert_eq!(source_statistics.column_statistics.len(), 3); - let (proj_schema, ..) = conf.project(); + let proj_schema = conf.projected_schema(); // created a projector for that projected schema let mut proj = PartitionColumnProjector::new( proj_schema, @@ -1764,13 +1801,28 @@ mod tests { struct File { name: &'static str, date: &'static str, - statistics: Vec>, + statistics: Vec, Option)>>, } impl File { fn new( name: &'static str, date: &'static str, statistics: Vec>, + ) -> Self { + Self::new_nullable( + name, + date, + statistics + .into_iter() + .map(|opt| opt.map(|(min, max)| (Some(min), Some(max)))) + .collect(), + ) + } + + fn new_nullable( + name: &'static str, + date: &'static str, + statistics: Vec, Option)>>, ) -> Self { Self { name, @@ -1837,21 +1889,35 @@ mod tests { sort: vec![col("value").sort(false, true)], expected_result: Ok(vec![vec!["1", "0"], vec!["2"]]), }, - // reject nullable sort columns TestCase { - name: "no nullable sort columns", + name: "nullable sort columns, nulls last", file_schema: Schema::new(vec![Field::new( "value".to_string(), DataType::Float64, - true, // should fail because nullable + true, )]), files: vec![ - File::new("0", "2023-01-01", vec![Some((0.00, 0.49))]), - File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), - File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), + File::new_nullable("0", "2023-01-01", vec![Some((Some(0.00), Some(0.49)))]), + File::new_nullable("1", "2023-01-01", vec![Some((Some(0.50), None))]), + File::new_nullable("2", "2023-01-02", vec![Some((Some(0.00), None))]), ], sort: vec![col("value").sort(true, false)], - expected_result: Err("construct min/max statistics for split_groups_by_statistics\ncaused by\nbuild min rows\ncaused by\ncreate sorting columns\ncaused by\nError during planning: cannot sort by nullable column") + expected_result: Ok(vec![vec!["0", "1"], vec!["2"]]) + }, + TestCase { + name: "nullable sort columns, nulls first", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + true, + )]), + files: vec![ + File::new_nullable("0", "2023-01-01", vec![Some((None, Some(0.49)))]), + File::new_nullable("1", "2023-01-01", vec![Some((Some(0.50), Some(1.00)))]), + File::new_nullable("2", "2023-01-02", vec![Some((None, Some(1.00)))]), + ], + sort: vec![col("value").sort(true, true)], + expected_result: Ok(vec![vec!["0", "1"], vec!["2"]]) }, TestCase { name: "all three non-overlapping", @@ -1925,25 +1991,27 @@ mod tests { )))) .collect::>(), )); - let sort_order = LexOrdering::from( + let Some(sort_order) = LexOrdering::new( case.sort .into_iter() .map(|expr| { create_physical_sort_expr( &expr, - &DFSchema::try_from(table_schema.as_ref().clone())?, + &DFSchema::try_from(Arc::clone(&table_schema))?, &ExecutionProps::default(), ) }) .collect::>>()?, - ); + ) else { + return internal_err!("This test should always use an ordering"); + }; let partitioned_files = FileGroup::new( case.files.into_iter().map(From::from).collect::>(), ); let result = FileScanConfig::split_groups_by_statistics( &table_schema, - &[partitioned_files.clone()], + std::slice::from_ref(&partitioned_files), &sort_order, ); let results_by_name = result @@ -2009,12 +2077,12 @@ mod tests { .map(|stats| { stats .map(|(min, max)| ColumnStatistics { - min_value: Precision::Exact(ScalarValue::from( - min, - )), - max_value: Precision::Exact(ScalarValue::from( - max, - )), + min_value: Precision::Exact( + ScalarValue::Float64(min), + ), + max_value: Precision::Exact( + ScalarValue::Float64(max), + ), ..Default::default() }) .unwrap_or_default() @@ -2099,13 +2167,15 @@ mod tests { wrap_partition_type_in_dict(DataType::Utf8), false, )]) - .with_constraints(Constraints::empty()) .with_statistics(Statistics::new_unknown(&file_schema)) .with_file_groups(vec![FileGroup::new(vec![PartitionedFile::new( "test.parquet".to_string(), 1024, )])]) - .with_output_ordering(vec![LexOrdering::default()]) + .with_output_ordering(vec![[PhysicalSortExpr::new_default(Arc::new( + Column::new("date", 0), + ))] + .into()]) .with_file_compression_type(FileCompressionType::UNCOMPRESSED) .with_newlines_in_values(true) .build(); @@ -2131,6 +2201,54 @@ mod tests { assert_eq!(config.output_ordering.len(), 1); } + #[test] + fn equivalence_properties_after_schema_change() { + let file_schema = aggr_test_schema(); + let object_store_url = ObjectStoreUrl::parse("test:///").unwrap(); + // Create a file source with a filter + let file_source: Arc = + Arc::new(MockSource::default().with_filter(Arc::new(BinaryExpr::new( + col("c2", &file_schema).unwrap(), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )))); + + let config = FileScanConfigBuilder::new( + object_store_url.clone(), + Arc::clone(&file_schema), + Arc::clone(&file_source), + ) + .with_projection(Some(vec![0, 1, 2])) + .build(); + + // Simulate projection being updated. Since the filter has already been pushed down, + // the new projection won't include the filtered column. + let data_source = config + .try_swapping_with_projection(&[ProjectionExpr::new( + col("c3", &file_schema).unwrap(), + "c3".to_string(), + )]) + .unwrap() + .unwrap(); + + // Gather the equivalence properties from the new data source. There should + // be no equivalence class for column c2 since it was removed by the projection. + let eq_properties = data_source.eq_properties(); + let eq_group = eq_properties.eq_group(); + + for class in eq_group.iter() { + for expr in class.iter() { + if let Some(col) = expr.as_any().downcast_ref::() { + assert_ne!( + col.name(), + "c2", + "c2 should not be present in any equivalence class" + ); + } + } + } + } + #[test] fn test_file_scan_config_builder_defaults() { let file_schema = aggr_test_schema(); @@ -2161,13 +2279,24 @@ mod tests { assert!(config.constraints.is_empty()); // Verify statistics are set to unknown - assert_eq!(config.statistics.num_rows, Precision::Absent); - assert_eq!(config.statistics.total_byte_size, Precision::Absent); assert_eq!( - config.statistics.column_statistics.len(), + config.file_source.statistics().unwrap().num_rows, + Precision::Absent + ); + assert_eq!( + config.file_source.statistics().unwrap().total_byte_size, + Precision::Absent + ); + assert_eq!( + config + .file_source + .statistics() + .unwrap() + .column_statistics + .len(), file_schema.fields().len() ); - for stat in config.statistics.column_statistics { + for stat in config.file_source.statistics().unwrap().column_statistics { assert_eq!(stat.distinct_count, Precision::Absent); assert_eq!(stat.min_value, Precision::Absent); assert_eq!(stat.max_value, Precision::Absent); @@ -2208,6 +2337,7 @@ mod tests { let new_config = new_builder.build(); // Verify properties match + let partition_cols = partition_cols.into_iter().map(Arc::new).collect::>(); assert_eq!(new_config.object_store_url, object_store_url); assert_eq!(new_config.file_schema, schema); assert_eq!(new_config.projection, Some(vec![0, 2])); @@ -2222,4 +2352,158 @@ mod tests { assert_eq!(new_config.constraints, Constraints::default()); assert!(new_config.new_lines_in_values); } + + #[test] + fn test_split_groups_by_statistics_with_target_partitions() -> Result<()> { + use datafusion_common::DFSchema; + use datafusion_expr::{col, execution_props::ExecutionProps}; + + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Float64, + false, + )])); + + // Setup sort expression + let exec_props = ExecutionProps::new(); + let df_schema = DFSchema::try_from_qualified_schema("test", schema.as_ref())?; + let sort_expr = [col("value").sort(true, false)]; + let sort_ordering = sort_expr + .map(|expr| { + create_physical_sort_expr(&expr, &df_schema, &exec_props).unwrap() + }) + .into(); + + // Test case parameters + struct TestCase { + name: String, + file_count: usize, + overlap_factor: f64, + target_partitions: usize, + expected_partition_count: usize, + } + + let test_cases = vec![ + // Basic cases + TestCase { + name: "no_overlap_10_files_4_partitions".to_string(), + file_count: 10, + overlap_factor: 0.0, + target_partitions: 4, + expected_partition_count: 4, + }, + TestCase { + name: "medium_overlap_20_files_5_partitions".to_string(), + file_count: 20, + overlap_factor: 0.5, + target_partitions: 5, + expected_partition_count: 5, + }, + TestCase { + name: "high_overlap_30_files_3_partitions".to_string(), + file_count: 30, + overlap_factor: 0.8, + target_partitions: 3, + expected_partition_count: 7, + }, + // Edge cases + TestCase { + name: "fewer_files_than_partitions".to_string(), + file_count: 3, + overlap_factor: 0.0, + target_partitions: 10, + expected_partition_count: 3, // Should only create as many partitions as files + }, + TestCase { + name: "single_file".to_string(), + file_count: 1, + overlap_factor: 0.0, + target_partitions: 5, + expected_partition_count: 1, // Should create only one partition + }, + TestCase { + name: "empty_files".to_string(), + file_count: 0, + overlap_factor: 0.0, + target_partitions: 3, + expected_partition_count: 0, // Empty result for empty input + }, + ]; + + for case in test_cases { + println!("Running test case: {}", case.name); + + // Generate files using bench utility function + let file_groups = generate_test_files(case.file_count, case.overlap_factor); + + // Call the function under test + let result = + FileScanConfig::split_groups_by_statistics_with_target_partitions( + &schema, + &file_groups, + &sort_ordering, + case.target_partitions, + )?; + + // Verify results + println!( + "Created {} partitions (target was {})", + result.len(), + case.target_partitions + ); + + // Check partition count + assert_eq!( + result.len(), + case.expected_partition_count, + "Case '{}': Unexpected partition count", + case.name + ); + + // Verify sort integrity + assert!( + verify_sort_integrity(&result), + "Case '{}': Files within partitions are not properly ordered", + case.name + ); + + // Distribution check for partitions + if case.file_count > 1 && case.expected_partition_count > 1 { + let group_sizes: Vec = result.iter().map(FileGroup::len).collect(); + let max_size = *group_sizes.iter().max().unwrap(); + let min_size = *group_sizes.iter().min().unwrap(); + + // Check partition balancing - difference shouldn't be extreme + let avg_files_per_partition = + case.file_count as f64 / case.expected_partition_count as f64; + assert!( + (max_size as f64) < 2.0 * avg_files_per_partition, + "Case '{}': Unbalanced distribution. Max partition size {} exceeds twice the average {}", + case.name, + max_size, + avg_files_per_partition + ); + + println!("Distribution - min files: {min_size}, max files: {max_size}"); + } + } + + // Test error case: zero target partitions + let empty_groups: Vec = vec![]; + let err = FileScanConfig::split_groups_by_statistics_with_target_partitions( + &schema, + &empty_groups, + &sort_ordering, + 0, + ) + .unwrap_err(); + + assert!( + err.to_string() + .contains("target_partitions must be greater than 0"), + "Expected error for zero target partitions" + ); + + Ok(()) + } } diff --git a/datafusion/datasource/src/file_stream.rs b/datafusion/datasource/src/file_stream.rs index 1caefc3277aca..e0b6c25a19162 100644 --- a/datafusion/datasource/src/file_stream.rs +++ b/datafusion/datasource/src/file_stream.rs @@ -27,7 +27,6 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::file_meta::FileMeta; use crate::file_scan_config::{FileScanConfig, PartitionColumnProjector}; use crate::PartitionedFile; use arrow::datatypes::SchemaRef; @@ -37,7 +36,6 @@ use datafusion_physical_plan::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, Time, }; -use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; use datafusion_common::instant::Instant; use datafusion_common::ScalarValue; @@ -78,7 +76,7 @@ impl FileStream { file_opener: Arc, metrics: &ExecutionPlanMetricsSet, ) -> Result { - let (projected_schema, ..) = config.project(); + let projected_schema = config.projected_schema(); let pc_projector = PartitionColumnProjector::new( Arc::clone(&projected_schema), &config @@ -119,17 +117,11 @@ impl FileStream { fn start_next_file(&mut self) -> Option)>> { let part_file = self.file_iter.pop_front()?; - let file_meta = FileMeta { - object_meta: part_file.object_meta, - range: part_file.range, - extensions: part_file.extensions, - metadata_size_hint: part_file.metadata_size_hint, - }; - + let partition_values = part_file.partition_values.clone(); Some( self.file_opener - .open(file_meta) - .map(|future| (future, part_file.partition_values)), + .open(part_file) + .map(|future| (future, partition_values)), ) } @@ -224,7 +216,6 @@ impl FileStream { let result = self .pc_projector .project(batch, partition_values) - .map_err(|e| ArrowError::ExternalError(e.into())) .map(|batch| match &mut self.remain { Some(remain) => { if *remain > batch.num_rows() { @@ -246,7 +237,7 @@ impl FileStream { self.state = FileStreamState::Error } self.file_stream_metrics.time_scanning_total.start(); - return Poll::Ready(Some(result.map_err(Into::into))); + return Poll::Ready(Some(result)); } Some(Err(err)) => { self.file_stream_metrics.file_scan_errors.add(1); @@ -280,7 +271,7 @@ impl FileStream { }, OnError::Fail => { self.state = FileStreamState::Error; - return Poll::Ready(Some(Err(err.into()))); + return Poll::Ready(Some(Err(err))); } } } @@ -344,7 +335,7 @@ impl RecordBatchStream for FileStream { /// A fallible future that resolves to a stream of [`RecordBatch`] pub type FileOpenFuture = - BoxFuture<'static, Result>>>; + BoxFuture<'static, Result>>>; /// Describes the behavior of the `FileStream` if file opening or scanning fails pub enum OnError { @@ -367,7 +358,7 @@ impl Default for OnError { pub trait FileOpener: Unpin + Send + Sync { /// Asynchronously open the specified file and return a stream /// of [`RecordBatch`] - fn open(&self, file_meta: FileMeta) -> Result; + fn open(&self, partitioned_file: PartitionedFile) -> Result; } /// Represents the state of the next `FileOpenFuture`. Since we need to poll @@ -375,7 +366,7 @@ pub trait FileOpener: Unpin + Send + Sync { /// is ready pub enum NextOpen { Pending(FileOpenFuture), - Ready(Result>>), + Ready(Result>>), } pub enum FileStreamState { @@ -395,7 +386,7 @@ pub enum FileStreamState { /// Partitioning column values for the current batch_iter partition_values: Vec, /// The reader instance - reader: BoxStream<'static, Result>, + reader: BoxStream<'static, Result>, /// A [`FileOpenFuture`] for the next file to be processed, /// and its corresponding partition column values, if any. /// This allows the next file to be opened in parallel while the @@ -435,7 +426,7 @@ impl StartableTime { /// (not cpu time) so they include time spent waiting on I/O as well /// as other operators. /// -/// [`FileStream`]: +/// [`FileStream`]: pub struct FileStreamMetrics { /// Wall clock time elapsed for file opening. /// @@ -446,13 +437,13 @@ pub struct FileStreamMetrics { /// will open the next file in the background while scanning the /// current file. This metric will only capture time spent opening /// while not also scanning. - /// [`FileStream`]: + /// [`FileStream`]: pub time_opening: StartableTime, /// Wall clock time elapsed for file scanning + first record batch of decompression + decoding /// /// Time between when the [`FileStream`] requests data from the /// stream and when the first [`RecordBatch`] is produced. - /// [`FileStream`]: + /// [`FileStream`]: pub time_scanning_until_data: StartableTime, /// Total elapsed wall clock time for scanning + record batch decompression / decoding /// @@ -525,7 +516,6 @@ mod tests { use crate::file_scan_config::FileScanConfigBuilder; use crate::tests::make_partition; use crate::PartitionedFile; - use arrow::error::ArrowError; use datafusion_common::error::Result; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; @@ -533,13 +523,12 @@ mod tests { use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; - use crate::file_meta::FileMeta; use crate::file_stream::{FileOpenFuture, FileOpener, FileStream, OnError}; use crate::test_util::MockSource; use arrow::array::RecordBatch; use arrow::datatypes::Schema; - use datafusion_common::{assert_batches_eq, internal_err}; + use datafusion_common::{assert_batches_eq, exec_err, internal_err}; /// Test `FileOpener` which will simulate errors during file opening or scanning #[derive(Default)] @@ -555,15 +544,13 @@ mod tests { } impl FileOpener for TestOpener { - fn open(&self, _file_meta: FileMeta) -> Result { + fn open(&self, _partitioned_file: PartitionedFile) -> Result { let idx = self.current_idx.fetch_add(1, Ordering::SeqCst); if self.error_opening_idx.contains(&idx) { Ok(futures::future::ready(internal_err!("error opening")).boxed()) } else if self.error_scanning_idx.contains(&idx) { - let error = futures::future::ready(Err(ArrowError::IpcError( - "error scanning".to_owned(), - ))); + let error = futures::future::ready(exec_err!("error scanning")); let stream = futures::stream::once(error).boxed(); Ok(futures::future::ready(Ok(stream)).boxed()) } else { diff --git a/datafusion/datasource/src/memory.rs b/datafusion/datasource/src/memory.rs index 6d0e16ef4b916..eb55aa9b0b0d2 100644 --- a/datafusion/datasource/src/memory.rs +++ b/datafusion/datasource/src/memory.rs @@ -15,345 +15,46 @@ // specific language governing permissions and limitations // under the License. -//! Execution plan for reading in-memory batches of data - use std::any::Any; +use std::cmp::Ordering; +use std::collections::BinaryHeap; use std::fmt; use std::fmt::Debug; +use std::ops::Deref; use std::sync::Arc; use crate::sink::DataSink; use crate::source::{DataSource, DataSourceExec}; -use async_trait::async_trait; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::memory::MemoryStream; -use datafusion_physical_plan::projection::{ - all_alias_free_columns, new_projections_for_columns, ProjectionExec, -}; -use datafusion_physical_plan::{ - common, ColumnarValue, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, - PhysicalExpr, PlanProperties, SendableRecordBatchStream, Statistics, -}; use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::datatypes::{Schema, SchemaRef}; -use datafusion_common::{ - internal_err, plan_err, project_schema, Constraints, Result, ScalarValue, -}; +use datafusion_common::{internal_err, plan_err, project_schema, Result, ScalarValue}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::equivalence::ProjectionMapping; -use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::equivalence::project_orderings; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; +use datafusion_physical_plan::memory::MemoryStream; +use datafusion_physical_plan::projection::{ + all_alias_free_columns, new_projections_for_columns, ProjectionExpr, +}; +use datafusion_physical_plan::{ + common, ColumnarValue, DisplayAs, DisplayFormatType, Partitioning, PhysicalExpr, + SendableRecordBatchStream, Statistics, +}; + +use async_trait::async_trait; +use datafusion_physical_plan::coop::cooperative; +use datafusion_physical_plan::execution_plan::SchedulingType; use futures::StreamExt; +use itertools::Itertools; use tokio::sync::RwLock; -/// Execution plan for reading in-memory batches of data -#[derive(Clone)] -#[deprecated( - since = "46.0.0", - note = "use MemorySourceConfig and DataSourceExec instead" -)] -pub struct MemoryExec { - inner: DataSourceExec, - /// The partitions to query - partitions: Vec>, - /// Optional projection - projection: Option>, - // Sort information: one or more equivalent orderings - sort_information: Vec, - /// if partition sizes should be displayed - show_sizes: bool, -} - -#[allow(unused, deprecated)] -impl Debug for MemoryExec { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.inner.fmt_as(DisplayFormatType::Default, f) - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for MemoryExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for MemoryExec { - fn name(&self) -> &'static str { - "MemoryExec" - } - - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - - fn children(&self) -> Vec<&Arc> { - // This is a leaf node and has no children - vec![] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - // MemoryExec has no children - if children.is_empty() { - Ok(self) - } else { - internal_err!("Children cannot be replaced in {self:?}") - } - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, context) - } - - /// We recompute the statistics dynamically from the arrow metadata as it is pretty cheap to do so - fn statistics(&self) -> Result { - self.inner.statistics() - } - - fn try_swapping_with_projection( - &self, - projection: &ProjectionExec, - ) -> Result>> { - self.inner.try_swapping_with_projection(projection) - } -} - -#[allow(unused, deprecated)] -impl MemoryExec { - /// Create a new execution plan for reading in-memory record batches - /// The provided `schema` should not have the projection applied. - pub fn try_new( - partitions: &[Vec], - schema: SchemaRef, - projection: Option>, - ) -> Result { - let source = MemorySourceConfig::try_new(partitions, schema, projection.clone())?; - let data_source = DataSourceExec::new(Arc::new(source)); - Ok(Self { - inner: data_source, - partitions: partitions.to_vec(), - projection, - sort_information: vec![], - show_sizes: true, - }) - } - - /// Create a new execution plan from a list of constant values (`ValuesExec`) - pub fn try_new_as_values( - schema: SchemaRef, - data: Vec>>, - ) -> Result { - if data.is_empty() { - return plan_err!("Values list cannot be empty"); - } - - let n_row = data.len(); - let n_col = schema.fields().len(); - - // We have this single row batch as a placeholder to satisfy evaluation argument - // and generate a single output row - let placeholder_schema = Arc::new(Schema::empty()); - let placeholder_batch = RecordBatch::try_new_with_options( - Arc::clone(&placeholder_schema), - vec![], - &RecordBatchOptions::new().with_row_count(Some(1)), - )?; - - // Evaluate each column - let arrays = (0..n_col) - .map(|j| { - (0..n_row) - .map(|i| { - let expr = &data[i][j]; - let result = expr.evaluate(&placeholder_batch)?; - - match result { - ColumnarValue::Scalar(scalar) => Ok(scalar), - ColumnarValue::Array(array) if array.len() == 1 => { - ScalarValue::try_from_array(&array, 0) - } - ColumnarValue::Array(_) => { - plan_err!("Cannot have array values in a values list") - } - } - }) - .collect::>>() - .and_then(ScalarValue::iter_to_array) - }) - .collect::>>()?; - - let batch = RecordBatch::try_new_with_options( - Arc::clone(&schema), - arrays, - &RecordBatchOptions::new().with_row_count(Some(n_row)), - )?; - - let partitions = vec![batch]; - Self::try_new_from_batches(Arc::clone(&schema), partitions) - } - - /// Create a new plan using the provided schema and batches. - /// - /// Errors if any of the batches don't match the provided schema, or if no - /// batches are provided. - pub fn try_new_from_batches( - schema: SchemaRef, - batches: Vec, - ) -> Result { - if batches.is_empty() { - return plan_err!("Values list cannot be empty"); - } - - for batch in &batches { - let batch_schema = batch.schema(); - if batch_schema != schema { - return plan_err!( - "Batch has invalid schema. Expected: {}, got: {}", - schema, - batch_schema - ); - } - } - - let partitions = vec![batches]; - let source = MemorySourceConfig { - partitions: partitions.clone(), - schema: Arc::clone(&schema), - projected_schema: Arc::clone(&schema), - projection: None, - sort_information: vec![], - show_sizes: true, - fetch: None, - }; - let data_source = DataSourceExec::new(Arc::new(source)); - Ok(Self { - inner: data_source, - partitions, - projection: None, - sort_information: vec![], - show_sizes: true, - }) - } - - fn memory_source_config(&self) -> MemorySourceConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - pub fn with_constraints(mut self, constraints: Constraints) -> Self { - self.inner = self.inner.with_constraints(constraints); - self - } - - /// Set `show_sizes` to determine whether to display partition sizes - pub fn with_show_sizes(mut self, show_sizes: bool) -> Self { - let mut memory_source = self.memory_source_config(); - memory_source.show_sizes = show_sizes; - self.show_sizes = show_sizes; - self.inner = DataSourceExec::new(Arc::new(memory_source)); - self - } - - /// Ref to constraints - pub fn constraints(&self) -> &Constraints { - self.properties().equivalence_properties().constraints() - } - - /// Ref to partitions - pub fn partitions(&self) -> &[Vec] { - &self.partitions - } - - /// Ref to projection - pub fn projection(&self) -> &Option> { - &self.projection - } - - /// Show sizes - pub fn show_sizes(&self) -> bool { - self.show_sizes - } - - /// Ref to sort information - pub fn sort_information(&self) -> &[LexOrdering] { - &self.sort_information - } - - /// A memory table can be ordered by multiple expressions simultaneously. - /// [`EquivalenceProperties`] keeps track of expressions that describe the - /// global ordering of the schema. These columns are not necessarily same; e.g. - /// ```text - /// ┌-------┐ - /// | a | b | - /// |---|---| - /// | 1 | 9 | - /// | 2 | 8 | - /// | 3 | 7 | - /// | 5 | 5 | - /// └---┴---┘ - /// ``` - /// where both `a ASC` and `b DESC` can describe the table ordering. With - /// [`EquivalenceProperties`], we can keep track of these equivalences - /// and treat `a ASC` and `b DESC` as the same ordering requirement. - /// - /// Note that if there is an internal projection, that projection will be - /// also applied to the given `sort_information`. - pub fn try_with_sort_information( - mut self, - sort_information: Vec, - ) -> Result { - self.sort_information = sort_information.clone(); - let mut memory_source = self.memory_source_config(); - memory_source = memory_source.try_with_sort_information(sort_information)?; - self.inner = DataSourceExec::new(Arc::new(memory_source)); - Ok(self) - } - - /// Arc clone of ref to original schema - pub fn original_schema(&self) -> SchemaRef { - Arc::clone(&self.inner.schema()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - partitions: &[Vec], - ) -> PlanProperties { - PlanProperties::new( - EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints), - Partitioning::UnknownPartitioning(partitions.len()), - EmissionType::Incremental, - Boundedness::Bounded, - ) - } -} - /// Data source configuration for reading in-memory batches of data #[derive(Clone, Debug)] pub struct MemorySourceConfig { - /// The partitions to query + /// The partitions to query. + /// + /// Each partition is a `Vec`. partitions: Vec>, /// Schema representing the data before projection schema: SchemaRef, @@ -376,14 +77,14 @@ impl DataSource for MemorySourceConfig { partition: usize, _context: Arc, ) -> Result { - Ok(Box::pin( + Ok(Box::pin(cooperative( MemoryStream::try_new( self.partitions[partition].clone(), Arc::clone(&self.projected_schema), self.projection.clone(), )? .with_fetch(self.fetch), - )) + ))) } fn as_any(&self) -> &dyn Any { @@ -399,9 +100,7 @@ impl DataSource for MemorySourceConfig { let output_ordering = self .sort_information .first() - .map(|output_ordering| { - format!(", output_ordering={}", output_ordering) - }) + .map(|output_ordering| format!(", output_ordering={output_ordering}")) .unwrap_or_default(); let eq_properties = self.eq_properties(); @@ -409,12 +108,12 @@ impl DataSource for MemorySourceConfig { let constraints = if constraints.is_empty() { String::new() } else { - format!(", {}", constraints) + format!(", {constraints}") }; let limit = self .fetch - .map_or(String::new(), |limit| format!(", fetch={}", limit)); + .map_or(String::new(), |limit| format!(", fetch={limit}")); if self.show_sizes { write!( f, @@ -445,6 +144,39 @@ impl DataSource for MemorySourceConfig { } } + /// If possible, redistribute batches across partitions according to their size. + /// + /// Returns `Ok(None)` if unable to repartition. Preserve output ordering if exists. + /// Refer to [`DataSource::repartitioned`] for further details. + fn repartitioned( + &self, + target_partitions: usize, + _repartition_file_min_size: usize, + output_ordering: Option, + ) -> Result>> { + if self.partitions.is_empty() || self.partitions.len() >= target_partitions + // if have no partitions, or already have more partitions than desired, do not repartition + { + return Ok(None); + } + + let maybe_repartitioned = if let Some(output_ordering) = output_ordering { + self.repartition_preserving_order(target_partitions, output_ordering)? + } else { + self.repartition_evenly_by_size(target_partitions)? + }; + + if let Some(repartitioned) = maybe_repartitioned { + Ok(Some(Arc::new(Self::try_new( + &repartitioned, + self.original_schema(), + self.projection.clone(), + )?))) + } else { + Ok(None) + } + } + fn output_partitioning(&self) -> Partitioning { Partitioning::UnknownPartitioning(self.partitions.len()) } @@ -452,10 +184,14 @@ impl DataSource for MemorySourceConfig { fn eq_properties(&self) -> EquivalenceProperties { EquivalenceProperties::new_with_orderings( Arc::clone(&self.projected_schema), - self.sort_information.as_slice(), + self.sort_information.clone(), ) } + fn scheduling_type(&self) -> SchedulingType { + SchedulingType::Cooperative + } + fn statistics(&self) -> Result { Ok(common::compute_record_batch_statistics( &self.partitions, @@ -475,11 +211,11 @@ impl DataSource for MemorySourceConfig { fn try_swapping_with_projection( &self, - projection: &ProjectionExec, - ) -> Result>> { + projection: &[ProjectionExpr], + ) -> Result>> { // If there is any non-column or alias-carrier expression, Projection should not be removed. // This process can be moved into MemoryExec, but it would be an overlap of their responsibility. - all_alias_free_columns(projection.expr()) + all_alias_free_columns(projection) .then(|| { let all_projections = (0..self.schema.fields().len()).collect(); let new_projections = new_projections_for_columns( @@ -487,12 +223,12 @@ impl DataSource for MemorySourceConfig { self.projection().as_ref().unwrap_or(&all_projections), ); - MemorySourceConfig::try_new_exec( + MemorySourceConfig::try_new( self.partitions(), self.original_schema(), Some(new_projections), ) - .map(|e| e as _) + .map(|s| Arc::new(s) as Arc) }) .transpose() } @@ -694,25 +430,9 @@ impl MemorySourceConfig { } // If there is a projection on the source, we also need to project orderings - if let Some(projection) = &self.projection { - let base_eqp = EquivalenceProperties::new_with_orderings( - self.original_schema(), - &sort_information, - ); - let proj_exprs = projection - .iter() - .map(|idx| { - let base_schema = self.original_schema(); - let name = base_schema.field(*idx).name(); - (Arc::new(Column::new(name, *idx)) as _, name.to_string()) - }) - .collect::>(); - let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?; - sort_information = base_eqp - .project(&projection_mapping, Arc::clone(&self.projected_schema)) - .into_oeq_class() - .into_inner(); + if self.projection.is_some() { + sort_information = + project_orderings(&sort_information, &self.projected_schema); } self.sort_information = sort_information; @@ -723,6 +443,239 @@ impl MemorySourceConfig { pub fn original_schema(&self) -> SchemaRef { Arc::clone(&self.schema) } + + /// Repartition while preserving order. + /// + /// Returns `Ok(None)` if cannot fulfill the requested repartitioning, such + /// as having too few batches to fulfill the `target_partitions` or if unable + /// to preserve output ordering. + fn repartition_preserving_order( + &self, + target_partitions: usize, + output_ordering: LexOrdering, + ) -> Result>>> { + if !self.eq_properties().ordering_satisfy(output_ordering)? { + Ok(None) + } else { + let total_num_batches = + self.partitions.iter().map(|b| b.len()).sum::(); + if total_num_batches < target_partitions { + // no way to create the desired repartitioning + return Ok(None); + } + + let cnt_to_repartition = target_partitions - self.partitions.len(); + + // Label the current partitions and their order. + // Such that when we later split up the partitions into smaller sizes, we are maintaining the order. + let to_repartition = self + .partitions + .iter() + .enumerate() + .map(|(idx, batches)| RePartition { + idx: idx + (cnt_to_repartition * idx), // make space in ordering for split partitions + row_count: batches.iter().map(|batch| batch.num_rows()).sum(), + batches: batches.clone(), + }) + .collect_vec(); + + // Put all of the partitions into a heap ordered by `RePartition::partial_cmp`, which sizes + // by count of rows. + let mut max_heap = BinaryHeap::with_capacity(target_partitions); + for rep in to_repartition { + max_heap.push(CompareByRowCount(rep)); + } + + // Split the largest partitions into smaller partitions. Maintaining the output + // order of the partitions & newly created partitions. + let mut cannot_split_further = Vec::with_capacity(target_partitions); + for _ in 0..cnt_to_repartition { + // triggers loop for the cnt_to_repartition. So if need another 4 partitions, it attempts to split 4 times. + loop { + // Take the largest item off the heap, and attempt to split. + let Some(to_split) = max_heap.pop() else { + // Nothing left to attempt repartition. Break inner loop. + break; + }; + + // Split the partition. The new partitions will be ordered with idx and idx+1. + let mut new_partitions = to_split.into_inner().split(); + if new_partitions.len() > 1 { + for new_partition in new_partitions { + max_heap.push(CompareByRowCount(new_partition)); + } + // Successful repartition. Break inner loop, and return to outer `cnt_to_repartition` loop. + break; + } else { + cannot_split_further.push(new_partitions.remove(0)); + } + } + } + let mut partitions = max_heap + .drain() + .map(CompareByRowCount::into_inner) + .collect_vec(); + partitions.extend(cannot_split_further); + + // Finally, sort all partitions by the output ordering. + // This was the original ordering of the batches within the partition. We are maintaining this ordering. + partitions.sort_by_key(|p| p.idx); + let partitions = partitions.into_iter().map(|rep| rep.batches).collect_vec(); + + Ok(Some(partitions)) + } + } + + /// Repartition into evenly sized chunks (as much as possible without batch splitting), + /// disregarding any ordering. + /// + /// Current implementation uses a first-fit-decreasing bin packing, modified to enable + /// us to still return the desired count of `target_partitions`. + /// + /// Returns `Ok(None)` if cannot fulfill the requested repartitioning, such + /// as having too few batches to fulfill the `target_partitions`. + fn repartition_evenly_by_size( + &self, + target_partitions: usize, + ) -> Result>>> { + // determine if we have enough total batches to fulfill request + let mut flatten_batches = + self.partitions.clone().into_iter().flatten().collect_vec(); + if flatten_batches.len() < target_partitions { + return Ok(None); + } + + // Take all flattened batches (all in 1 partititon/vec) and divide evenly into the desired number of `target_partitions`. + let total_num_rows = flatten_batches.iter().map(|b| b.num_rows()).sum::(); + // sort by size, so we pack multiple smaller batches into the same partition + flatten_batches.sort_by_key(|b| std::cmp::Reverse(b.num_rows())); + + // Divide. + let mut partitions = + vec![Vec::with_capacity(flatten_batches.len()); target_partitions]; + let mut target_partition_size = total_num_rows.div_ceil(target_partitions); + let mut total_rows_seen = 0; + let mut curr_bin_row_count = 0; + let mut idx = 0; + for batch in flatten_batches { + let row_cnt = batch.num_rows(); + idx = std::cmp::min(idx, target_partitions - 1); + + partitions[idx].push(batch); + curr_bin_row_count += row_cnt; + total_rows_seen += row_cnt; + + if curr_bin_row_count >= target_partition_size { + idx += 1; + curr_bin_row_count = 0; + + // update target_partition_size, to handle very lopsided batch distributions + // while still returning the count of `target_partitions` + if total_rows_seen < total_num_rows { + target_partition_size = (total_num_rows - total_rows_seen) + .div_ceil(target_partitions - idx); + } + } + } + + Ok(Some(partitions)) + } +} + +/// For use in repartitioning, track the total size and original partition index. +/// +/// Do not implement clone, in order to avoid unnecessary copying during repartitioning. +struct RePartition { + /// Original output ordering for the partition. + idx: usize, + /// Total size of the partition, for use in heap ordering + /// (a.k.a. splitting up the largest partitions). + row_count: usize, + /// A partition containing record batches. + batches: Vec, +} + +impl RePartition { + /// Split [`RePartition`] into 2 pieces, consuming self. + /// + /// Returns only 1 partition if cannot be split further. + fn split(self) -> Vec { + if self.batches.len() == 1 { + return vec![self]; + } + + let new_0 = RePartition { + idx: self.idx, // output ordering + row_count: 0, + batches: vec![], + }; + let new_1 = RePartition { + idx: self.idx + 1, // output ordering +1 + row_count: 0, + batches: vec![], + }; + let split_pt = self.row_count / 2; + + let [new_0, new_1] = self.batches.into_iter().fold( + [new_0, new_1], + |[mut new0, mut new1], batch| { + if new0.row_count < split_pt { + new0.add_batch(batch); + } else { + new1.add_batch(batch); + } + [new0, new1] + }, + ); + vec![new_0, new_1] + } + + fn add_batch(&mut self, batch: RecordBatch) { + self.row_count += batch.num_rows(); + self.batches.push(batch); + } +} + +impl fmt::Display for RePartition { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}rows-in-{}batches@{}", + self.row_count, + self.batches.len(), + self.idx + ) + } +} + +struct CompareByRowCount(RePartition); +impl CompareByRowCount { + fn into_inner(self) -> RePartition { + self.0 + } +} +impl Ord for CompareByRowCount { + fn cmp(&self, other: &Self) -> Ordering { + self.0.row_count.cmp(&other.0.row_count) + } +} +impl PartialOrd for CompareByRowCount { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl PartialEq for CompareByRowCount { + fn eq(&self, other: &Self) -> bool { + // PartialEq must be consistent with PartialOrd + self.cmp(other) == Ordering::Equal + } +} +impl Eq for CompareByRowCount {} +impl Deref for CompareByRowCount { + type Target = RePartition; + fn deref(&self) -> &Self::Target { + &self.0 + } } /// Type alias for partition data @@ -816,22 +769,22 @@ mod memory_source_tests { use crate::memory::MemorySourceConfig; use crate::source::DataSourceExec; - use datafusion_physical_plan::ExecutionPlan; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr::PhysicalSortExpr; - use datafusion_physical_expr_common::sort_expr::LexOrdering; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_plan::ExecutionPlan; #[test] - fn test_memory_order_eq() -> datafusion_common::Result<()> { + fn test_memory_order_eq() -> Result<()> { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, false), Field::new("b", DataType::Int64, false), Field::new("c", DataType::Int64, false), ])); - let sort1 = LexOrdering::new(vec![ + let sort1: LexOrdering = [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), @@ -840,13 +793,14 @@ mod memory_source_tests { expr: col("b", &schema)?, options: SortOptions::default(), }, - ]); - let sort2 = LexOrdering::new(vec![PhysicalSortExpr { + ] + .into(); + let sort2: LexOrdering = [PhysicalSortExpr { expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let mut expected_output_order = LexOrdering::default(); - expected_output_order.extend(sort1.clone()); + }] + .into(); + let mut expected_output_order = sort1.clone(); expected_output_order.extend(sort2.clone()); let sort_information = vec![sort1.clone(), sort2.clone()]; @@ -868,15 +822,18 @@ mod memory_source_tests { #[cfg(test)] mod tests { - use crate::tests::{aggr_test_schema, make_partition}; - use super::*; + use crate::test_util::col; + use crate::tests::{aggr_test_schema, make_partition}; - use datafusion_physical_plan::expressions::lit; - + use arrow::array::{ArrayRef, Int32Array, Int64Array, StringArray}; use arrow::datatypes::{DataType, Field}; use datafusion_common::assert_batches_eq; use datafusion_common::stats::{ColumnStatistics, Precision}; + use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_plan::expressions::lit; + + use datafusion_physical_plan::ExecutionPlan; use futures::StreamExt; #[tokio::test] @@ -976,7 +933,7 @@ mod tests { )?; assert_eq!( - values.statistics()?, + values.partition_statistics(None)?, Statistics { num_rows: Precision::Exact(rows), total_byte_size: Precision::Exact(8), // not important @@ -992,4 +949,458 @@ mod tests { Ok(()) } + + fn batch(row_size: usize) -> RecordBatch { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("foo"); row_size])); + let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![1; row_size])); + RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap() + } + + fn schema() -> SchemaRef { + batch(1).schema() + } + + fn memorysrcconfig_no_partitions( + sort_information: Vec, + ) -> Result { + let partitions = vec![]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_1_partition_1_batch( + sort_information: Vec, + ) -> Result { + let partitions = vec![vec![batch(100)]]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_3_partitions_1_batch_each( + sort_information: Vec, + ) -> Result { + let partitions = vec![vec![batch(100)], vec![batch(100)], vec![batch(100)]]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_3_partitions_with_2_batches_each( + sort_information: Vec, + ) -> Result { + let partitions = vec![ + vec![batch(100), batch(100)], + vec![batch(100), batch(100)], + vec![batch(100), batch(100)], + ]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + /// Batches of different sizes, with batches ordered by size (100_000, 10_000, 100, 1) + /// in the Memtable partition (a.k.a. vector of batches). + fn memorysrcconfig_1_partition_with_different_sized_batches( + sort_information: Vec, + ) -> Result { + let partitions = vec![vec![batch(100_000), batch(10_000), batch(100), batch(1)]]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + /// Same as [`memorysrcconfig_1_partition_with_different_sized_batches`], + /// but the batches are ordered differently (not by size) + /// in the Memtable partition (a.k.a. vector of batches). + fn memorysrcconfig_1_partition_with_ordering_not_matching_size( + sort_information: Vec, + ) -> Result { + let partitions = vec![vec![batch(100_000), batch(1), batch(100), batch(10_000)]]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_2_partition_with_different_sized_batches( + sort_information: Vec, + ) -> Result { + let partitions = vec![ + vec![batch(100_000), batch(10_000), batch(1_000)], + vec![batch(2_000), batch(20)], + ]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_2_partition_with_extreme_sized_batches( + sort_information: Vec, + ) -> Result { + let partitions = vec![ + vec![ + batch(100_000), + batch(1), + batch(1), + batch(1), + batch(1), + batch(0), + ], + vec![batch(1), batch(1), batch(1), batch(1), batch(0), batch(100)], + ]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + /// Assert that we get the expected count of partitions after repartitioning. + /// + /// If None, then we expected the [`DataSource::repartitioned`] to return None. + fn assert_partitioning( + partitioned_datasrc: Option>, + partition_cnt: Option, + ) { + let should_exist = if let Some(partition_cnt) = partition_cnt { + format!("new datasource should exist and have {partition_cnt:?} partitions") + } else { + "new datasource should not exist".into() + }; + + let actual = partitioned_datasrc + .map(|datasrc| datasrc.output_partitioning().partition_count()); + assert_eq!( + actual, + partition_cnt, + "partitioned datasrc does not match expected, we expected {should_exist}, instead found {actual:?}" + ); + } + + fn run_all_test_scenarios( + output_ordering: Option, + sort_information_on_config: Vec, + ) -> Result<()> { + let not_used = usize::MAX; + + // src has no partitions + let mem_src_config = + memorysrcconfig_no_partitions(sort_information_on_config.clone())?; + let partitioned_datasrc = + mem_src_config.repartitioned(1, not_used, output_ordering.clone())?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions == target partitions (=1) + let target_partitions = 1; + let mem_src_config = + memorysrcconfig_1_partition_1_batch(sort_information_on_config.clone())?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions == target partitions (=3) + let target_partitions = 3; + let mem_src_config = memorysrcconfig_3_partitions_1_batch_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions > target partitions, but we don't merge them + let target_partitions = 2; + let mem_src_config = memorysrcconfig_3_partitions_1_batch_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions < target partitions, but not enough batches (per partition) to split into more partitions + let target_partitions = 4; + let mem_src_config = memorysrcconfig_3_partitions_1_batch_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions < target partitions, and can split to sufficient amount + // has 6 batches across 3 partitions. Will need to split 2 of it's partitions. + let target_partitions = 5; + let mem_src_config = memorysrcconfig_3_partitions_with_2_batches_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, Some(5)); + + // src has partitions < target partitions, and can split to sufficient amount + // has 6 batches across 3 partitions. Will need to split all of it's partitions. + let target_partitions = 6; + let mem_src_config = memorysrcconfig_3_partitions_with_2_batches_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, Some(6)); + + // src has partitions < target partitions, but not enough total batches to fulfill the split (desired target_partitions) + let target_partitions = 3 * 2 + 1; + let mem_src_config = memorysrcconfig_3_partitions_with_2_batches_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has 1 partition with many batches of lopsided sizes + // make sure it handles the split properly + let target_partitions = 2; + let mem_src_config = memorysrcconfig_1_partition_with_different_sized_batches( + sort_information_on_config, + )?; + let partitioned_datasrc = mem_src_config.clone().repartitioned( + target_partitions, + not_used, + output_ordering, + )?; + assert_partitioning(partitioned_datasrc.clone(), Some(2)); + // Starting = batch(100_000), batch(10_000), batch(100), batch(1). + // It should have split as p1=batch(100_000), p2=[batch(10_000), batch(100), batch(1)] + let partitioned_datasrc = partitioned_datasrc.unwrap(); + let Some(mem_src_config) = partitioned_datasrc + .as_any() + .downcast_ref::() + else { + unreachable!() + }; + let repartitioned_raw_batches = mem_src_config.partitions.clone(); + assert_eq!(repartitioned_raw_batches.len(), 2); + let [ref p1, ref p2] = repartitioned_raw_batches[..] else { + unreachable!() + }; + // p1=batch(100_000) + assert_eq!(p1.len(), 1); + assert_eq!(p1[0].num_rows(), 100_000); + // p2=[batch(10_000), batch(100), batch(1)] + assert_eq!(p2.len(), 3); + assert_eq!(p2[0].num_rows(), 10_000); + assert_eq!(p2[1].num_rows(), 100); + assert_eq!(p2[2].num_rows(), 1); + + Ok(()) + } + + #[test] + fn test_repartition_no_sort_information_no_output_ordering() -> Result<()> { + let no_sort = vec![]; + let no_output_ordering = None; + + // Test: common set of functionality + run_all_test_scenarios(no_output_ordering.clone(), no_sort.clone())?; + + // Test: how no-sort-order divides differently. + // * does not preserve separate partitions (with own internal ordering) on even split, + // * nor does it preserve ordering (re-orders batch(2_000) vs batch(1_000)). + let target_partitions = 3; + let mem_src_config = + memorysrcconfig_2_partition_with_different_sized_batches(no_sort)?; + let partitioned_datasrc = mem_src_config.clone().repartitioned( + target_partitions, + usize::MAX, + no_output_ordering, + )?; + assert_partitioning(partitioned_datasrc.clone(), Some(3)); + // Starting = batch(100_000), batch(10_000), batch(1_000), batch(2_000), batch(20) + // It should have split as p1=batch(100_000), p2=batch(10_000), p3=rest(mixed across original partitions) + let repartitioned_raw_batches = mem_src_config + .repartition_evenly_by_size(target_partitions)? + .unwrap(); + assert_eq!(repartitioned_raw_batches.len(), 3); + let [ref p1, ref p2, ref p3] = repartitioned_raw_batches[..] else { + unreachable!() + }; + // p1=batch(100_000) + assert_eq!(p1.len(), 1); + assert_eq!(p1[0].num_rows(), 100_000); + // p2=batch(10_000) + assert_eq!(p2.len(), 1); + assert_eq!(p2[0].num_rows(), 10_000); + // p3= batch(2_000), batch(1_000), batch(20) + assert_eq!(p3.len(), 3); + assert_eq!(p3[0].num_rows(), 2_000); + assert_eq!(p3[1].num_rows(), 1_000); + assert_eq!(p3[2].num_rows(), 20); + + Ok(()) + } + + #[test] + fn test_repartition_no_sort_information_no_output_ordering_lopsized_batches( + ) -> Result<()> { + let no_sort = vec![]; + let no_output_ordering = None; + + // Test: case has two input partitions: + // b(100_000), b(1), b(1), b(1), b(1), b(0) + // b(1), b(1), b(1), b(1), b(0), b(100) + // + // We want an output with target_partitions=5, which means the ideal division is: + // b(100_000) + // b(100) + // b(1), b(1), b(1) + // b(1), b(1), b(1) + // b(1), b(1), b(0) + let target_partitions = 5; + let mem_src_config = + memorysrcconfig_2_partition_with_extreme_sized_batches(no_sort)?; + let partitioned_datasrc = mem_src_config.clone().repartitioned( + target_partitions, + usize::MAX, + no_output_ordering, + )?; + assert_partitioning(partitioned_datasrc.clone(), Some(5)); + // Starting partition 1 = batch(100_000), batch(1), batch(1), batch(1), batch(1), batch(0) + // Starting partition 1 = batch(1), batch(1), batch(1), batch(1), batch(0), batch(100) + // It should have split as p1=batch(100_000), p2=batch(100), p3=[batch(1),batch(1)], p4=[batch(1),batch(1)], p5=[batch(1),batch(1),batch(0),batch(0)] + let repartitioned_raw_batches = mem_src_config + .repartition_evenly_by_size(target_partitions)? + .unwrap(); + assert_eq!(repartitioned_raw_batches.len(), 5); + let [ref p1, ref p2, ref p3, ref p4, ref p5] = repartitioned_raw_batches[..] + else { + unreachable!() + }; + // p1=batch(100_000) + assert_eq!(p1.len(), 1); + assert_eq!(p1[0].num_rows(), 100_000); + // p2=batch(100) + assert_eq!(p2.len(), 1); + assert_eq!(p2[0].num_rows(), 100); + // p3=[batch(1),batch(1),batch(1)] + assert_eq!(p3.len(), 3); + assert_eq!(p3[0].num_rows(), 1); + assert_eq!(p3[1].num_rows(), 1); + assert_eq!(p3[2].num_rows(), 1); + // p4=[batch(1),batch(1),batch(1)] + assert_eq!(p4.len(), 3); + assert_eq!(p4[0].num_rows(), 1); + assert_eq!(p4[1].num_rows(), 1); + assert_eq!(p4[2].num_rows(), 1); + // p5=[batch(1),batch(1),batch(0),batch(0)] + assert_eq!(p5.len(), 4); + assert_eq!(p5[0].num_rows(), 1); + assert_eq!(p5[1].num_rows(), 1); + assert_eq!(p5[2].num_rows(), 0); + assert_eq!(p5[3].num_rows(), 0); + + Ok(()) + } + + #[test] + fn test_repartition_with_sort_information() -> Result<()> { + let schema = schema(); + let sort_key: LexOrdering = + [PhysicalSortExpr::new_default(col("c", &schema)?)].into(); + let has_sort = vec![sort_key.clone()]; + let output_ordering = Some(sort_key); + + // Test: common set of functionality + run_all_test_scenarios(output_ordering.clone(), has_sort.clone())?; + + // Test: DOES preserve separate partitions (with own internal ordering) + let target_partitions = 3; + let mem_src_config = + memorysrcconfig_2_partition_with_different_sized_batches(has_sort)?; + let partitioned_datasrc = mem_src_config.clone().repartitioned( + target_partitions, + usize::MAX, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc.clone(), Some(3)); + // Starting = batch(100_000), batch(10_000), batch(1_000), batch(2_000), batch(20) + // It should have split as p1=batch(100_000), p2=[batch(10_000),batch(1_000)], p3= + let Some(output_ord) = output_ordering else { + unreachable!() + }; + let repartitioned_raw_batches = mem_src_config + .repartition_preserving_order(target_partitions, output_ord)? + .unwrap(); + assert_eq!(repartitioned_raw_batches.len(), 3); + let [ref p1, ref p2, ref p3] = repartitioned_raw_batches[..] else { + unreachable!() + }; + // p1=batch(100_000) + assert_eq!(p1.len(), 1); + assert_eq!(p1[0].num_rows(), 100_000); + // p2=[batch(10_000),batch(1_000)] + assert_eq!(p2.len(), 2); + assert_eq!(p2[0].num_rows(), 10_000); + assert_eq!(p2[1].num_rows(), 1_000); + // p3=batch(2_000), batch(20) + assert_eq!(p3.len(), 2); + assert_eq!(p3[0].num_rows(), 2_000); + assert_eq!(p3[1].num_rows(), 20); + + Ok(()) + } + + #[test] + fn test_repartition_with_batch_ordering_not_matching_sizing() -> Result<()> { + let schema = schema(); + let sort_key: LexOrdering = + [PhysicalSortExpr::new_default(col("c", &schema)?)].into(); + let has_sort = vec![sort_key.clone()]; + let output_ordering = Some(sort_key); + + // src has 1 partition with many batches of lopsided sizes + // note that the input vector of batches are not ordered by decreasing size + let target_partitions = 2; + let mem_src_config = + memorysrcconfig_1_partition_with_ordering_not_matching_size(has_sort)?; + let partitioned_datasrc = mem_src_config.clone().repartitioned( + target_partitions, + usize::MAX, + output_ordering, + )?; + assert_partitioning(partitioned_datasrc.clone(), Some(2)); + // Starting = batch(100_000), batch(1), batch(100), batch(10_000). + // It should have split as p1=batch(100_000), p2=[batch(1), batch(100), batch(10_000)] + let partitioned_datasrc = partitioned_datasrc.unwrap(); + let Some(mem_src_config) = partitioned_datasrc + .as_any() + .downcast_ref::() + else { + unreachable!() + }; + let repartitioned_raw_batches = mem_src_config.partitions.clone(); + assert_eq!(repartitioned_raw_batches.len(), 2); + let [ref p1, ref p2] = repartitioned_raw_batches[..] else { + unreachable!() + }; + // p1=batch(100_000) + assert_eq!(p1.len(), 1); + assert_eq!(p1[0].num_rows(), 100_000); + // p2=[batch(1), batch(100), batch(10_000)] -- **this is preserving the partition order** + assert_eq!(p2.len(), 3); + assert_eq!(p2[0].num_rows(), 1); + assert_eq!(p2[1].num_rows(), 100); + assert_eq!(p2[2].num_rows(), 10_000); + + Ok(()) + } } diff --git a/datafusion/datasource/src/mod.rs b/datafusion/datasource/src/mod.rs index e4461c0b90a44..1f47c0983ea10 100644 --- a/datafusion/datasource/src/mod.rs +++ b/datafusion/datasource/src/mod.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] @@ -33,7 +33,6 @@ pub mod file; pub mod file_compression_type; pub mod file_format; pub mod file_groups; -pub mod file_meta; pub mod file_scan_config; pub mod file_sink_config; pub mod file_stream; @@ -44,25 +43,28 @@ pub mod source; mod statistics; #[cfg(test)] -mod test_util; +pub mod test_util; pub mod url; pub mod write; +pub use self::file::as_file_source; +pub use self::url::ListingTableUrl; +use crate::file_groups::FileGroup; use chrono::TimeZone; -use datafusion_common::Result; +use datafusion_common::stats::Precision; +use datafusion_common::{exec_datafusion_err, ColumnStatistics, Result}; use datafusion_common::{ScalarValue, Statistics}; -use file_meta::FileMeta; use futures::{Stream, StreamExt}; use object_store::{path::Path, ObjectMeta}; use object_store::{GetOptions, GetRange, ObjectStore}; +// Remove when add_row_stats is remove +#[allow(deprecated)] +pub use statistics::add_row_stats; +pub use statistics::compute_all_files_statistics; use std::ops::Range; use std::pin::Pin; use std::sync::Arc; -pub use self::url::ListingTableUrl; -pub use statistics::add_row_stats; -pub use statistics::compute_all_files_statistics; - /// Stream of files get listed from object store pub type PartitionedFileStream = Pin> + Send + Sync + 'static>>; @@ -98,9 +100,9 @@ pub struct PartitionedFile { /// You may use [`wrap_partition_value_in_dict`] to wrap them if you have used [`wrap_partition_type_in_dict`] to wrap the column type. /// /// - /// [`wrap_partition_type_in_dict`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/physical_plan/file_scan_config.rs#L55 - /// [`wrap_partition_value_in_dict`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/physical_plan/file_scan_config.rs#L62 - /// [`table_partition_cols`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/file_format/options.rs#L190 + /// [`wrap_partition_type_in_dict`]: crate::file_scan_config::wrap_partition_type_in_dict + /// [`wrap_partition_value_in_dict`]: crate::file_scan_config::wrap_partition_value_in_dict + /// [`table_partition_cols`]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/datasource/file_format/options.rs#L87 pub partition_values: Vec, /// An optional file range for a more fine-grained parallel execution pub range: Option, @@ -122,7 +124,7 @@ impl PartitionedFile { object_meta: ObjectMeta { location: Path::from(path.into()), last_modified: chrono::Utc.timestamp_nanos(0), - size: size as usize, + size, e_tag: None, version: None, }, @@ -140,7 +142,7 @@ impl PartitionedFile { object_meta: ObjectMeta { location: Path::from(path), last_modified: chrono::Utc.timestamp_nanos(0), - size: size as usize, + size, e_tag: None, version: None, }, @@ -194,6 +196,23 @@ impl PartitionedFile { self.statistics = Some(statistics); self } + + /// Check if this file has any statistics. + /// This returns `true` if the file has any Exact or Inexact statistics + /// and `false` if all statistics are `Precision::Absent`. + pub fn has_statistics(&self) -> bool { + if let Some(stats) = &self.statistics { + stats.column_statistics.iter().any(|col_stats| { + col_stats.null_count != Precision::Absent + || col_stats.max_value != Precision::Absent + || col_stats.min_value != Precision::Absent + || col_stats.sum_value != Precision::Absent + || col_stats.distinct_count != Precision::Absent + }) + } else { + false + } + } } impl From for PartitionedFile { @@ -223,33 +242,38 @@ impl From for PartitionedFile { /// Indicates that the range calculation determined no further action is /// necessary, possibly because the calculated range is empty or invalid. pub enum RangeCalculation { - Range(Option>), + Range(Option>), TerminateEarly, } /// Calculates an appropriate byte range for reading from an object based on the /// provided metadata. /// -/// This asynchronous function examines the `FileMeta` of an object in an object store +/// This asynchronous function examines the [`PartitionedFile`] of an object in an object store /// and determines the range of bytes to be read. The range calculation may adjust /// the start and end points to align with meaningful data boundaries (like newlines). /// -/// Returns a `Result` wrapping a `RangeCalculation`, which is either a calculated byte range or an indication to terminate early. +/// Returns a `Result` wrapping a [`RangeCalculation`], which is either a calculated byte range or an indication to terminate early. /// /// Returns an `Error` if any part of the range calculation fails, such as issues in reading from the object store or invalid range boundaries. pub async fn calculate_range( - file_meta: &FileMeta, + file: &PartitionedFile, store: &Arc, terminator: Option, ) -> Result { - let location = file_meta.location(); - let file_size = file_meta.object_meta.size; + let location = &file.object_meta.location; + let file_size = file.object_meta.size; let newline = terminator.unwrap_or(b'\n'); - match file_meta.range { + match file.range { None => Ok(RangeCalculation::Range(None)), Some(FileRange { start, end }) => { - let (start, end) = (start as usize, end as usize); + let start: u64 = start.try_into().map_err(|_| { + exec_datafusion_err!("Expect start range to fit in u64, got {start}") + })?; + let end: u64 = end.try_into().map_err(|_| { + exec_datafusion_err!("Expect end range to fit in u64, got {end}") + })?; let start_delta = if start != 0 { find_first_newline(store, location, start - 1, file_size, newline).await? @@ -288,10 +312,10 @@ pub async fn calculate_range( async fn find_first_newline( object_store: &Arc, location: &Path, - start: usize, - end: usize, + start: u64, + end: u64, newline: u8, -) -> Result { +) -> Result { let options = GetOptions { range: Some(GetRange::Bounded(start..end)), ..Default::default() @@ -304,15 +328,125 @@ async fn find_first_newline( while let Some(chunk) = result_stream.next().await.transpose()? { if let Some(position) = chunk.iter().position(|&byte| byte == newline) { + let position = position as u64; return Ok(index + position); } - index += chunk.len(); + index += chunk.len() as u64; } Ok(index) } +/// Generates test files with min-max statistics in different overlap patterns. +/// +/// Used by tests and benchmarks. +/// +/// # Overlap Factors +/// +/// The `overlap_factor` parameter controls how much the value ranges in generated test files overlap: +/// - `0.0`: No overlap between files (completely disjoint ranges) +/// - `0.2`: Low overlap (20% of the range size overlaps with adjacent files) +/// - `0.5`: Medium overlap (50% of ranges overlap) +/// - `0.8`: High overlap (80% of ranges overlap between files) +/// +/// # Examples +/// +/// With 5 files and different overlap factors showing `[min, max]` ranges: +/// +/// overlap_factor = 0.0 (no overlap): +/// +/// File 0: [0, 20] +/// File 1: [20, 40] +/// File 2: [40, 60] +/// File 3: [60, 80] +/// File 4: [80, 100] +/// +/// overlap_factor = 0.5 (50% overlap): +/// +/// File 0: [0, 40] +/// File 1: [20, 60] +/// File 2: [40, 80] +/// File 3: [60, 100] +/// File 4: [80, 120] +/// +/// overlap_factor = 0.8 (80% overlap): +/// +/// File 0: [0, 100] +/// File 1: [20, 120] +/// File 2: [40, 140] +/// File 3: [60, 160] +/// File 4: [80, 180] +pub fn generate_test_files(num_files: usize, overlap_factor: f64) -> Vec { + let mut files = Vec::with_capacity(num_files); + if num_files == 0 { + return vec![]; + } + let range_size = if overlap_factor == 0.0 { + 100 / num_files as i64 + } else { + (100.0 / (overlap_factor * num_files as f64)).max(1.0) as i64 + }; + + for i in 0..num_files { + let base = (i as f64 * range_size as f64 * (1.0 - overlap_factor)) as i64; + let min = base as f64; + let max = (base + range_size) as f64; + + let file = PartitionedFile { + object_meta: ObjectMeta { + location: Path::from(format!("file_{i}.parquet")), + last_modified: chrono::Utc::now(), + size: 1000, + e_tag: None, + version: None, + }, + partition_values: vec![], + range: None, + statistics: Some(Arc::new(Statistics { + num_rows: Precision::Exact(100), + total_byte_size: Precision::Exact(1000), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Float64(Some(max))), + min_value: Precision::Exact(ScalarValue::Float64(Some(min))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }], + })), + extensions: None, + metadata_size_hint: None, + }; + files.push(file); + } + + vec![FileGroup::new(files)] +} + +// Helper function to verify that files within each group maintain sort order +/// Used by tests and benchmarks +pub fn verify_sort_integrity(file_groups: &[FileGroup]) -> bool { + for group in file_groups { + let files = group.iter().collect::>(); + for i in 1..files.len() { + let prev_file = files[i - 1]; + let curr_file = files[i]; + + // Check if the min value of current file is greater than max value of previous file + if let (Some(prev_stats), Some(curr_stats)) = + (&prev_file.statistics, &curr_file.statistics) + { + let prev_max = &prev_stats.column_statistics[0].max_value; + let curr_min = &curr_stats.column_statistics[0].min_value; + if curr_min.get_value().unwrap() <= prev_max.get_value().unwrap() { + return false; + } + } + } + } + true +} + #[cfg(test)] mod tests { use super::ListingTableUrl; diff --git a/datafusion/datasource/src/schema_adapter.rs b/datafusion/datasource/src/schema_adapter.rs index 4164cda8cba11..4c7b37113d58d 100644 --- a/datafusion/datasource/src/schema_adapter.rs +++ b/datafusion/datasource/src/schema_adapter.rs @@ -20,13 +20,26 @@ //! Adapter provides a method of translating the RecordBatches that come out of the //! physical format into how they should be used by DataFusion. For instance, a schema //! can be stored external to a parquet file that maps parquet logical types to arrow types. - -use arrow::array::{new_null_array, RecordBatch, RecordBatchOptions}; -use arrow::compute::{can_cast_types, cast}; -use arrow::datatypes::{Schema, SchemaRef}; -use datafusion_common::plan_err; -use std::fmt::Debug; -use std::sync::Arc; +use arrow::{ + array::{new_null_array, ArrayRef, RecordBatch, RecordBatchOptions}, + compute::can_cast_types, + datatypes::{DataType, Field, Schema, SchemaRef}, +}; +use datafusion_common::{ + format::DEFAULT_CAST_OPTIONS, + nested_struct::{cast_column, validate_struct_compatibility}, + plan_err, ColumnStatistics, +}; +use std::{fmt::Debug, sync::Arc}; +/// Function used by [`SchemaMapping`] to adapt a column from the file schema to +/// the table schema. +pub type CastColumnFn = dyn Fn( + &ArrayRef, + &Field, + &arrow::compute::CastOptions, + ) -> datafusion_common::Result + + Send + + Sync; /// Factory for creating [`SchemaAdapter`] /// @@ -42,7 +55,7 @@ pub trait SchemaAdapterFactory: Debug + Send + Sync + 'static { /// Arguments: /// /// * `projected_table_schema`: The schema for the table, projected to - /// include only the fields being output (projected) by the this mapping. + /// include only the fields being output (projected) by the this mapping. /// /// * `table_schema`: The entire table schema for the table fn create( @@ -50,6 +63,17 @@ pub trait SchemaAdapterFactory: Debug + Send + Sync + 'static { projected_table_schema: SchemaRef, table_schema: SchemaRef, ) -> Box; + + /// Create a [`SchemaAdapter`] using only the projected table schema. + /// + /// This is a convenience method for cases where the table schema and the + /// projected table schema are the same. + fn create_with_projected_schema( + &self, + projected_table_schema: SchemaRef, + ) -> Box { + self.create(Arc::clone(&projected_table_schema), projected_table_schema) + } } /// Creates [`SchemaMapper`]s to map file-level [`RecordBatch`]es to a table @@ -96,6 +120,12 @@ pub trait SchemaAdapter: Send + Sync { pub trait SchemaMapper: Debug + Send + Sync { /// Adapts a `RecordBatch` to match the `table_schema` fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result; + + /// Adapts file-level column `Statistics` to match the `table_schema` + fn map_column_statistics( + &self, + file_col_statistics: &[ColumnStatistics], + ) -> datafusion_common::Result>; } /// Default [`SchemaAdapterFactory`] for mapping schemas. @@ -219,6 +249,34 @@ pub(crate) struct DefaultSchemaAdapter { projected_table_schema: SchemaRef, } +/// Checks if a file field can be cast to a table field +/// +/// Returns Ok(true) if casting is possible, or an error explaining why casting is not possible +pub(crate) fn can_cast_field( + file_field: &Field, + table_field: &Field, +) -> datafusion_common::Result { + match (file_field.data_type(), table_field.data_type()) { + (DataType::Struct(source_fields), DataType::Struct(target_fields)) => { + // validate_struct_compatibility returns Result<()>; on success we can cast structs + validate_struct_compatibility(source_fields, target_fields)?; + Ok(true) + } + _ => { + if can_cast_types(file_field.data_type(), table_field.data_type()) { + Ok(true) + } else { + plan_err!( + "Cannot cast file schema field {} of type {} to table schema field of type {}", + file_field.name(), + file_field.data_type(), + table_field.data_type() + ) + } + } + } +} + impl SchemaAdapter for DefaultSchemaAdapter { /// Map a column index in the table schema to a column index in a particular /// file schema @@ -242,40 +300,60 @@ impl SchemaAdapter for DefaultSchemaAdapter { &self, file_schema: &Schema, ) -> datafusion_common::Result<(Arc, Vec)> { - let mut projection = Vec::with_capacity(file_schema.fields().len()); - let mut field_mappings = vec![None; self.projected_table_schema.fields().len()]; - - for (file_idx, file_field) in file_schema.fields.iter().enumerate() { - if let Some((table_idx, table_field)) = - self.projected_table_schema.fields().find(file_field.name()) - { - match can_cast_types(file_field.data_type(), table_field.data_type()) { - true => { - field_mappings[table_idx] = Some(projection.len()); - projection.push(file_idx); - } - false => { - return plan_err!( - "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}", - file_field.name(), - file_field.data_type(), - table_field.data_type() - ) - } - } - } - } + let (field_mappings, projection) = create_field_mapping( + file_schema, + &self.projected_table_schema, + can_cast_field, + )?; Ok(( - Arc::new(SchemaMapping { - projected_table_schema: Arc::clone(&self.projected_table_schema), + Arc::new(SchemaMapping::new( + Arc::clone(&self.projected_table_schema), field_mappings, - }), + Arc::new( + |array: &ArrayRef, + field: &Field, + opts: &arrow::compute::CastOptions| { + cast_column(array, field, opts) + }, + ), + )), projection, )) } } +/// Helper function that creates field mappings between file schema and table schema +/// +/// Maps columns from the file schema to their corresponding positions in the table schema, +/// applying type compatibility checking via the provided predicate function. +/// +/// Returns field mappings (for column reordering) and a projection (for field selection). +pub(crate) fn create_field_mapping( + file_schema: &Schema, + projected_table_schema: &SchemaRef, + can_map_field: F, +) -> datafusion_common::Result<(Vec>, Vec)> +where + F: Fn(&Field, &Field) -> datafusion_common::Result, +{ + let mut projection = Vec::with_capacity(file_schema.fields().len()); + let mut field_mappings = vec![None; projected_table_schema.fields().len()]; + + for (file_idx, file_field) in file_schema.fields.iter().enumerate() { + if let Some((table_idx, table_field)) = + projected_table_schema.fields().find(file_field.name()) + { + if can_map_field(file_field, table_field)? { + field_mappings[table_idx] = Some(projection.len()); + projection.push(file_idx); + } + } + } + + Ok((field_mappings, projection)) +} + /// The SchemaMapping struct holds a mapping from the file schema to the table /// schema and any necessary type conversions. /// @@ -285,7 +363,6 @@ impl SchemaAdapter for DefaultSchemaAdapter { /// `projected_table_schema` as it can only operate on the projected fields. /// /// [`map_batch`]: Self::map_batch -#[derive(Debug)] pub struct SchemaMapping { /// The schema of the table. This is the expected schema after conversion /// and it should match the schema of the query result. @@ -296,6 +373,36 @@ pub struct SchemaMapping { /// They are Options instead of just plain `usize`s because the table could /// have fields that don't exist in the file. field_mappings: Vec>, + /// Function used to adapt a column from the file schema to the table schema + /// when it exists in both schemas + cast_column: Arc, +} + +impl Debug for SchemaMapping { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SchemaMapping") + .field("projected_table_schema", &self.projected_table_schema) + .field("field_mappings", &self.field_mappings) + .field("cast_column", &"") + .finish() + } +} + +impl SchemaMapping { + /// Creates a new SchemaMapping instance + /// + /// Initializes the field mappings needed to transform file data to the projected table schema + pub fn new( + projected_table_schema: SchemaRef, + field_mappings: Vec>, + cast_column: Arc, + ) -> Self { + Self { + projected_table_schema, + field_mappings, + cast_column, + } + } } impl SchemaMapper for SchemaMapping { @@ -303,8 +410,7 @@ impl SchemaMapper for SchemaMapping { /// conversions. /// The produced RecordBatch has a schema that contains only the projected columns. fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result { - let batch_rows = batch.num_rows(); - let batch_cols = batch.columns().to_vec(); + let (_old_schema, batch_cols, batch_rows) = batch.into_parts(); let cols = self .projected_table_schema @@ -320,18 +426,621 @@ impl SchemaMapper for SchemaMapping { // If this field only exists in the table, and not in the file, then we know // that it's null, so just return that. || Ok(new_null_array(field.data_type(), batch_rows)), - // However, if it does exist in both, then try to cast it to the correct output - // type - |batch_idx| cast(&batch_cols[batch_idx], field.data_type()), + // However, if it does exist in both, use the cast_column function + // to perform any necessary conversions + |batch_idx| { + (self.cast_column)( + &batch_cols[batch_idx], + field, + &DEFAULT_CAST_OPTIONS, + ) + }, ) }) .collect::, _>>()?; // Necessary to handle empty batches - let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + let options = RecordBatchOptions::new().with_row_count(Some(batch_rows)); let schema = Arc::clone(&self.projected_table_schema); let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?; Ok(record_batch) } + + /// Adapts file-level column `Statistics` to match the `table_schema` + fn map_column_statistics( + &self, + file_col_statistics: &[ColumnStatistics], + ) -> datafusion_common::Result> { + let mut table_col_statistics = vec![]; + + // Map the statistics for each field in the file schema to the corresponding field in the + // table schema, if a field is not present in the file schema, we need to fill it with `ColumnStatistics::new_unknown` + for (_, file_col_idx) in self + .projected_table_schema + .fields() + .iter() + .zip(&self.field_mappings) + { + if let Some(file_col_idx) = file_col_idx { + table_col_statistics.push( + file_col_statistics + .get(*file_col_idx) + .cloned() + .unwrap_or_default(), + ); + } else { + table_col_statistics.push(ColumnStatistics::new_unknown()); + } + } + + Ok(table_col_statistics) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::{ + array::{Array, ArrayRef, StringBuilder, StructArray, TimestampMillisecondArray}, + compute::cast, + datatypes::{DataType, Field, TimeUnit}, + record_batch::RecordBatch, + }; + use datafusion_common::{stats::Precision, Result, ScalarValue, Statistics}; + + #[test] + fn test_schema_mapping_map_statistics_basic() { + // Create table schema (a, b, c) + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Float64, true), + ])); + + // Create file schema (b, a) - different order, missing c + let file_schema = Schema::new(vec![ + Field::new("b", DataType::Utf8, true), + Field::new("a", DataType::Int32, true), + ]); + + // Create SchemaAdapter + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + + // Get mapper and projection + let (mapper, projection) = adapter.map_schema(&file_schema).unwrap(); + + // Should project columns 0,1 from file + assert_eq!(projection, vec![0, 1]); + + // Create file statistics + let mut file_stats = Statistics::default(); + + // Statistics for column b (index 0 in file) + let b_stats = ColumnStatistics { + null_count: Precision::Exact(5), + ..Default::default() + }; + + // Statistics for column a (index 1 in file) + let a_stats = ColumnStatistics { + null_count: Precision::Exact(10), + ..Default::default() + }; + + file_stats.column_statistics = vec![b_stats, a_stats]; + + // Map statistics + let table_col_stats = mapper + .map_column_statistics(&file_stats.column_statistics) + .unwrap(); + + // Verify stats + assert_eq!(table_col_stats.len(), 3); + assert_eq!(table_col_stats[0].null_count, Precision::Exact(10)); // a from file idx 1 + assert_eq!(table_col_stats[1].null_count, Precision::Exact(5)); // b from file idx 0 + assert_eq!(table_col_stats[2].null_count, Precision::Absent); // c (unknown) + } + + #[test] + fn test_schema_mapping_map_statistics_empty() { + // Create schemas + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ])); + let file_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]); + + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + let (mapper, _) = adapter.map_schema(&file_schema).unwrap(); + + // Empty file statistics + let file_stats = Statistics::default(); + let table_col_stats = mapper + .map_column_statistics(&file_stats.column_statistics) + .unwrap(); + + // All stats should be unknown + assert_eq!(table_col_stats.len(), 2); + assert_eq!(table_col_stats[0], ColumnStatistics::new_unknown(),); + assert_eq!(table_col_stats[1], ColumnStatistics::new_unknown(),); + } + + #[test] + fn test_can_cast_field() { + // Same type should work + let from_field = Field::new("col", DataType::Int32, true); + let to_field = Field::new("col", DataType::Int32, true); + assert!(can_cast_field(&from_field, &to_field).unwrap()); + + // Casting Int32 to Float64 is allowed + let from_field = Field::new("col", DataType::Int32, true); + let to_field = Field::new("col", DataType::Float64, true); + assert!(can_cast_field(&from_field, &to_field).unwrap()); + + // Casting Float64 to Utf8 should work (converts to string) + let from_field = Field::new("col", DataType::Float64, true); + let to_field = Field::new("col", DataType::Utf8, true); + assert!(can_cast_field(&from_field, &to_field).unwrap()); + + // Binary to Utf8 is not supported - this is an example of a cast that should fail + // Note: We use Binary instead of Utf8->Int32 because Arrow actually supports that cast + let from_field = Field::new("col", DataType::Binary, true); + let to_field = Field::new("col", DataType::Decimal128(10, 2), true); + let result = can_cast_field(&from_field, &to_field); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast file schema field col")); + } + + #[test] + fn test_create_field_mapping() { + // Define the table schema + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Float64, true), + ])); + + // Define file schema: different order, missing column c, and b has different type + let file_schema = Schema::new(vec![ + Field::new("b", DataType::Float64, true), // Different type but castable to Utf8 + Field::new("a", DataType::Int32, true), // Same type + Field::new("d", DataType::Boolean, true), // Not in table schema + ]); + + // Custom can_map_field function that allows all mappings for testing + let allow_all = |_: &Field, _: &Field| Ok(true); + + // Test field mapping + let (field_mappings, projection) = + create_field_mapping(&file_schema, &table_schema, allow_all).unwrap(); + + // Expected: + // - field_mappings[0] (a) maps to projection[1] + // - field_mappings[1] (b) maps to projection[0] + // - field_mappings[2] (c) is None (not in file) + assert_eq!(field_mappings, vec![Some(1), Some(0), None]); + assert_eq!(projection, vec![0, 1]); // Projecting file columns b, a + + // Test with a failing mapper + let fails_all = |_: &Field, _: &Field| Ok(false); + let (field_mappings, projection) = + create_field_mapping(&file_schema, &table_schema, fails_all).unwrap(); + + // Should have no mappings or projections if all cast checks fail + assert_eq!(field_mappings, vec![None, None, None]); + assert_eq!(projection, Vec::::new()); + + // Test with error-producing mapper + let error_mapper = |_: &Field, _: &Field| plan_err!("Test error"); + let result = create_field_mapping(&file_schema, &table_schema, error_mapper); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Test error")); + } + + #[test] + fn test_schema_mapping_new() { + // Define the projected table schema + let projected_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ])); + + // Define field mappings from table to file + let field_mappings = vec![Some(1), Some(0)]; + + // Create SchemaMapping manually + let mapping = SchemaMapping::new( + Arc::clone(&projected_schema), + field_mappings.clone(), + Arc::new( + |array: &ArrayRef, field: &Field, opts: &arrow::compute::CastOptions| { + cast_column(array, field, opts) + }, + ), + ); + + // Check that fields were set correctly + assert_eq!(*mapping.projected_table_schema, *projected_schema); + assert_eq!(mapping.field_mappings, field_mappings); + + // Test with a batch to ensure it works properly + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("b_file", DataType::Utf8, true), + Field::new("a_file", DataType::Int32, true), + ])), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["hello", "world"])), + Arc::new(arrow::array::Int32Array::from(vec![1, 2])), + ], + ) + .unwrap(); + + // Test that map_batch works with our manually created mapping + let mapped_batch = mapping.map_batch(batch).unwrap(); + + // Verify the mapped batch has the correct schema and data + assert_eq!(*mapped_batch.schema(), *projected_schema); + assert_eq!(mapped_batch.num_columns(), 2); + assert_eq!(mapped_batch.column(0).len(), 2); // a column + assert_eq!(mapped_batch.column(1).len(), 2); // b column + } + + #[test] + fn test_map_schema_error_path() { + // Define the table schema + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Decimal128(10, 2), true), // Use Decimal which has stricter cast rules + ])); + + // Define file schema with incompatible type for column c + let file_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, true), // Different but castable + Field::new("c", DataType::Binary, true), // Not castable to Decimal128 + ]); + + // Create DefaultSchemaAdapter + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + + // map_schema should error due to incompatible types + let result = adapter.map_schema(&file_schema); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast file schema field c")); + } + + #[test] + fn test_map_schema_happy_path() { + // Define the table schema + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Decimal128(10, 2), true), + ])); + + // Create DefaultSchemaAdapter + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + + // Define compatible file schema (missing column c) + let compatible_file_schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), // Can be cast to Int32 + Field::new("b", DataType::Float64, true), // Can be cast to Utf8 + ]); + + // Test successful schema mapping + let (mapper, projection) = adapter.map_schema(&compatible_file_schema).unwrap(); + + // Verify field_mappings and projection created correctly + assert_eq!(projection, vec![0, 1]); // Projecting a and b + + // Verify the SchemaMapping works with actual data + let file_batch = RecordBatch::try_new( + Arc::new(compatible_file_schema.clone()), + vec![ + Arc::new(arrow::array::Int64Array::from(vec![100, 200])), + Arc::new(arrow::array::Float64Array::from(vec![1.5, 2.5])), + ], + ) + .unwrap(); + + let mapped_batch = mapper.map_batch(file_batch).unwrap(); + + // Verify correct schema mapping + assert_eq!(*mapped_batch.schema(), *table_schema); + assert_eq!(mapped_batch.num_columns(), 3); // a, b, c + + // Column c should be null since it wasn't in the file schema + let c_array = mapped_batch.column(2); + assert_eq!(c_array.len(), 2); + assert_eq!(c_array.null_count(), 2); + } + + #[test] + fn test_adapt_struct_with_added_nested_fields() -> Result<()> { + let (file_schema, table_schema) = create_test_schemas_with_nested_fields(); + let batch = create_test_batch_with_struct_data(&file_schema)?; + + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + let (mapper, _) = adapter.map_schema(file_schema.as_ref())?; + let mapped_batch = mapper.map_batch(batch)?; + + verify_adapted_batch_with_nested_fields(&mapped_batch, &table_schema)?; + Ok(()) + } + + #[test] + fn test_map_column_statistics_struct() -> Result<()> { + let (file_schema, table_schema) = create_test_schemas_with_nested_fields(); + + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + let (mapper, _) = adapter.map_schema(file_schema.as_ref())?; + + let file_stats = vec![ + create_test_column_statistics( + 0, + 100, + Some(ScalarValue::Int32(Some(1))), + Some(ScalarValue::Int32(Some(100))), + Some(ScalarValue::Int32(Some(5100))), + ), + create_test_column_statistics(10, 50, None, None, None), + ]; + + let table_stats = mapper.map_column_statistics(&file_stats)?; + assert_eq!(table_stats.len(), 1); + verify_column_statistics( + &table_stats[0], + Some(0), + Some(100), + Some(ScalarValue::Int32(Some(1))), + Some(ScalarValue::Int32(Some(100))), + Some(ScalarValue::Int32(Some(5100))), + ); + let missing_stats = mapper.map_column_statistics(&[])?; + assert_eq!(missing_stats.len(), 1); + assert_eq!(missing_stats[0], ColumnStatistics::new_unknown()); + Ok(()) + } + + fn create_test_schemas_with_nested_fields() -> (SchemaRef, SchemaRef) { + let file_schema = Arc::new(Schema::new(vec![Field::new( + "info", + DataType::Struct( + vec![ + Field::new("location", DataType::Utf8, true), + Field::new( + "timestamp_utc", + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + true, + ), + ] + .into(), + ), + true, + )])); + + let table_schema = Arc::new(Schema::new(vec![Field::new( + "info", + DataType::Struct( + vec![ + Field::new("location", DataType::Utf8, true), + Field::new( + "timestamp_utc", + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + true, + ), + Field::new( + "reason", + DataType::Struct( + vec![ + Field::new("_level", DataType::Float64, true), + Field::new( + "details", + DataType::Struct( + vec![ + Field::new("rurl", DataType::Utf8, true), + Field::new("s", DataType::Float64, true), + Field::new("t", DataType::Utf8, true), + ] + .into(), + ), + true, + ), + ] + .into(), + ), + true, + ), + ] + .into(), + ), + true, + )])); + + (file_schema, table_schema) + } + + fn create_test_batch_with_struct_data( + file_schema: &SchemaRef, + ) -> Result { + let mut location_builder = StringBuilder::new(); + location_builder.append_value("San Francisco"); + location_builder.append_value("New York"); + + let timestamp_array = TimestampMillisecondArray::from(vec![ + Some(1640995200000), + Some(1641081600000), + ]); + + let timestamp_type = + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())); + let timestamp_array = cast(×tamp_array, ×tamp_type)?; + + let info_struct = StructArray::from(vec![ + ( + Arc::new(Field::new("location", DataType::Utf8, true)), + Arc::new(location_builder.finish()) as ArrayRef, + ), + ( + Arc::new(Field::new("timestamp_utc", timestamp_type, true)), + timestamp_array, + ), + ]); + + Ok(RecordBatch::try_new( + Arc::clone(file_schema), + vec![Arc::new(info_struct)], + )?) + } + + fn verify_adapted_batch_with_nested_fields( + mapped_batch: &RecordBatch, + table_schema: &SchemaRef, + ) -> Result<()> { + assert_eq!(mapped_batch.schema(), *table_schema); + assert_eq!(mapped_batch.num_rows(), 2); + + let info_col = mapped_batch.column(0); + let info_array = info_col + .as_any() + .downcast_ref::() + .expect("Expected info column to be a StructArray"); + + verify_preserved_fields(info_array)?; + verify_reason_field_structure(info_array)?; + Ok(()) + } + + fn verify_preserved_fields(info_array: &StructArray) -> Result<()> { + let location_col = info_array + .column_by_name("location") + .expect("Expected location field in struct"); + let location_array = location_col + .as_any() + .downcast_ref::() + .expect("Expected location to be a StringArray"); + assert_eq!(location_array.value(0), "San Francisco"); + assert_eq!(location_array.value(1), "New York"); + + let timestamp_col = info_array + .column_by_name("timestamp_utc") + .expect("Expected timestamp_utc field in struct"); + let timestamp_array = timestamp_col + .as_any() + .downcast_ref::() + .expect("Expected timestamp_utc to be a TimestampMillisecondArray"); + assert_eq!(timestamp_array.value(0), 1640995200000); + assert_eq!(timestamp_array.value(1), 1641081600000); + Ok(()) + } + + fn verify_reason_field_structure(info_array: &StructArray) -> Result<()> { + let reason_col = info_array + .column_by_name("reason") + .expect("Expected reason field in struct"); + let reason_array = reason_col + .as_any() + .downcast_ref::() + .expect("Expected reason to be a StructArray"); + assert_eq!(reason_array.fields().len(), 2); + assert!(reason_array.column_by_name("_level").is_some()); + assert!(reason_array.column_by_name("details").is_some()); + + let details_col = reason_array + .column_by_name("details") + .expect("Expected details field in reason struct"); + let details_array = details_col + .as_any() + .downcast_ref::() + .expect("Expected details to be a StructArray"); + assert_eq!(details_array.fields().len(), 3); + assert!(details_array.column_by_name("rurl").is_some()); + assert!(details_array.column_by_name("s").is_some()); + assert!(details_array.column_by_name("t").is_some()); + for i in 0..2 { + assert!(reason_array.is_null(i), "reason field should be null"); + } + Ok(()) + } + + fn verify_column_statistics( + stats: &ColumnStatistics, + expected_null_count: Option, + expected_distinct_count: Option, + expected_min: Option, + expected_max: Option, + expected_sum: Option, + ) { + if let Some(count) = expected_null_count { + assert_eq!( + stats.null_count, + Precision::Exact(count), + "Null count should match expected value" + ); + } + if let Some(count) = expected_distinct_count { + assert_eq!( + stats.distinct_count, + Precision::Exact(count), + "Distinct count should match expected value" + ); + } + if let Some(min) = expected_min { + assert_eq!( + stats.min_value, + Precision::Exact(min), + "Min value should match expected value" + ); + } + if let Some(max) = expected_max { + assert_eq!( + stats.max_value, + Precision::Exact(max), + "Max value should match expected value" + ); + } + if let Some(sum) = expected_sum { + assert_eq!( + stats.sum_value, + Precision::Exact(sum), + "Sum value should match expected value" + ); + } + } + + fn create_test_column_statistics( + null_count: usize, + distinct_count: usize, + min_value: Option, + max_value: Option, + sum_value: Option, + ) -> ColumnStatistics { + ColumnStatistics { + null_count: Precision::Exact(null_count), + distinct_count: Precision::Exact(distinct_count), + min_value: min_value.map_or_else(|| Precision::Absent, Precision::Exact), + max_value: max_value.map_or_else(|| Precision::Absent, Precision::Exact), + sum_value: sum_value.map_or_else(|| Precision::Absent, Precision::Exact), + } + } } diff --git a/datafusion/datasource/src/sink.rs b/datafusion/datasource/src/sink.rs index 0552370d8ed0c..a8adb46b96ffa 100644 --- a/datafusion/datasource/src/sink.rs +++ b/datafusion/datasource/src/sink.rs @@ -22,22 +22,21 @@ use std::fmt; use std::fmt::Debug; use std::sync::Arc; -use datafusion_physical_plan::metrics::MetricsSet; -use datafusion_physical_plan::stream::RecordBatchStreamAdapter; -use datafusion_physical_plan::ExecutionPlanProperties; -use datafusion_physical_plan::{ - execute_input_stream, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, - PlanProperties, SendableRecordBatchStream, -}; - use arrow::array::{ArrayRef, RecordBatch, UInt64Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{Distribution, EquivalenceProperties}; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_physical_expr_common::sort_expr::{LexRequirement, OrderingRequirements}; +use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_physical_plan::{ + execute_input_stream, DisplayAs, DisplayFormatType, ExecutionPlan, + ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, +}; use async_trait::async_trait; +use datafusion_physical_plan::execution_plan::{EvaluationType, SchedulingType}; use futures::StreamExt; /// `DataSink` implements writing streams of [`RecordBatch`]es to @@ -47,7 +46,7 @@ use futures::StreamExt; /// output. #[async_trait] pub trait DataSink: DisplayAs + Debug + Send + Sync { - /// Returns the data sink as [`Any`](std::any::Any) so that it can be + /// Returns the data sink as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -101,6 +100,11 @@ impl Debug for DataSinkExec { impl DataSinkExec { /// Create a plan to write to `sink` + /// Note: DataSinkExec requires its input to have a single partition. + /// If the input has multiple partitions, the physical optimizer will + /// automatically insert a Merge-related operator to merge them. + /// If you construct PhysicalPlan without going through the physical optimizer, + /// you must ensure that the input has a single partition. pub fn new( input: Arc, sink: Arc, @@ -143,6 +147,8 @@ impl DataSinkExec { input.pipeline_behavior(), input.boundedness(), ) + .with_scheduling_type(SchedulingType::Cooperative) + .with_evaluation_type(EvaluationType::Eager) } } @@ -184,10 +190,10 @@ impl ExecutionPlan for DataSinkExec { vec![Distribution::SinglePartition; self.children().len()] } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { // The required input ordering is set externally (e.g. by a `ListingTable`). - // Otherwise, there is no specific requirement (i.e. `sort_expr` is `None`). - vec![self.sort_order.as_ref().cloned()] + // Otherwise, there is no specific requirement (i.e. `sort_order` is `None`). + vec![self.sort_order.as_ref().cloned().map(Into::into)] } fn maintains_input_order(&self) -> Vec { diff --git a/datafusion/datasource/src/source.rs b/datafusion/datasource/src/source.rs index 6c9122ce1ac10..20d9a1d6e53f0 100644 --- a/datafusion/datasource/src/source.rs +++ b/datafusion/datasource/src/source.rs @@ -22,53 +22,136 @@ use std::fmt; use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion_physical_plan::execution_plan::{ + Boundedness, EmissionType, SchedulingType, +}; +use datafusion_physical_plan::metrics::SplitMetrics; use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; +use datafusion_physical_plan::stream::BatchSplitStream; use datafusion_physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; +use itertools::Itertools; use crate::file_scan_config::FileScanConfig; use datafusion_common::config::ConfigOptions; -use datafusion_common::{Constraints, Statistics}; +use datafusion_common::{Constraints, Result, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::filter_pushdown::{ + ChildPushdownResult, FilterPushdownPhase, FilterPushdownPropagation, PushedDown, +}; -/// Common behaviors in Data Sources for both from Files and Memory. +/// A source of data, typically a list of files or memory +/// +/// This trait provides common behaviors for abstract sources of data. It has +/// two common implementations: +/// +/// 1. [`FileScanConfig`]: lists of files +/// 2. [`MemorySourceConfig`]: in memory list of `RecordBatch` +/// +/// File format specific behaviors are defined by [`FileSource`] /// /// # See Also -/// * [`DataSourceExec`] for physical plan implementation -/// * [`FileSource`] for file format implementations (Parquet, Json, etc) +/// * [`FileSource`] for file format specific implementations (Parquet, Json, etc) +/// * [`DataSourceExec`]: The [`ExecutionPlan`] that reads from a `DataSource` /// /// # Notes +/// /// Requires `Debug` to assist debugging /// +/// [`FileScanConfig`]: https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.FileScanConfig.html +/// [`MemorySourceConfig`]: https://docs.rs/datafusion/latest/datafusion/datasource/memory/struct.MemorySourceConfig.html /// [`FileSource`]: crate::file::FileSource +/// [`FileFormat``]: https://docs.rs/datafusion/latest/datafusion/datasource/file_format/index.html +/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProvider.html +/// +/// The following diagram shows how DataSource, FileSource, and DataSourceExec are related +/// ```text +/// ┌─────────────────────┐ -----► execute path +/// │ │ ┄┄┄┄┄► init path +/// │ DataSourceExec │ +/// │ │ +/// └───────▲─────────────┘ +/// ┊ │ +/// ┊ │ +/// ┌──────────▼──────────┐ ┌──────────-──────────┐ +/// │ │ | | +/// │ DataSource(trait) │ | TableProvider(trait)| +/// │ │ | | +/// └───────▲─────────────┘ └─────────────────────┘ +/// ┊ │ ┊ +/// ┌───────────────┿──┴────────────────┐ ┊ +/// | ┌┄┄┄┄┄┄┄┄┄┄┄┘ | ┊ +/// | ┊ | ┊ +/// ┌──────────▼──────────┐ ┌──────────▼──────────┐ ┊ +/// │ │ │ │ ┌──────────▼──────────┐ +/// │ FileScanConfig │ │ MemorySourceConfig │ | | +/// │ │ │ │ | FileFormat(trait) | +/// └──────────────▲──────┘ └─────────────────────┘ | | +/// │ ┊ └─────────────────────┘ +/// │ ┊ ┊ +/// │ ┊ ┊ +/// ┌──────────▼──────────┐ ┌──────────▼──────────┐ +/// │ │ │ ArrowSource │ +/// │ FileSource(trait) ◄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄│ ... │ +/// │ │ │ ParquetSource │ +/// └─────────────────────┘ └─────────────────────┘ +/// │ +/// │ +/// │ +/// │ +/// ┌──────────▼──────────┐ +/// │ ArrowSource │ +/// │ ... │ +/// │ ParquetSource │ +/// └─────────────────────┘ +/// | +/// FileOpener (called by FileStream) +/// │ +/// ┌──────────▼──────────┐ +/// │ │ +/// │ RecordBatch │ +/// │ │ +/// └─────────────────────┘ +/// ``` pub trait DataSource: Send + Sync + Debug { fn open( &self, partition: usize, context: Arc, - ) -> datafusion_common::Result; + ) -> Result; fn as_any(&self) -> &dyn Any; /// Format this source for display in explain plans fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result; - /// Return a copy of this DataSource with a new partitioning scheme + /// Return a copy of this DataSource with a new partitioning scheme. + /// + /// Returns `Ok(None)` (the default) if the partitioning cannot be changed. + /// Refer to [`ExecutionPlan::repartitioned`] for details on when None should be returned. + /// + /// Repartitioning should not change the output ordering, if this ordering exists. + /// Refer to [`MemorySourceConfig::repartition_preserving_order`](crate::memory::MemorySourceConfig) + /// and the FileSource's + /// [`FileGroupPartitioner::repartition_file_groups`](crate::file_groups::FileGroupPartitioner::repartition_file_groups) + /// for examples. fn repartitioned( &self, _target_partitions: usize, _repartition_file_min_size: usize, _output_ordering: Option, - ) -> datafusion_common::Result>> { + ) -> Result>> { Ok(None) } fn output_partitioning(&self) -> Partitioning; fn eq_properties(&self) -> EquivalenceProperties; - fn statistics(&self) -> datafusion_common::Result; + fn scheduling_type(&self) -> SchedulingType { + SchedulingType::NonCooperative + } + fn statistics(&self) -> Result; /// Return a copy of this DataSource with a new fetch limit fn with_fetch(&self, _limit: Option) -> Option>; fn fetch(&self) -> Option; @@ -77,18 +160,33 @@ pub trait DataSource: Send + Sync + Debug { } fn try_swapping_with_projection( &self, - _projection: &ProjectionExec, - ) -> datafusion_common::Result>>; + _projection: &[ProjectionExpr], + ) -> Result>>; + /// Try to push down filters into this DataSource. + /// See [`ExecutionPlan::handle_child_pushdown_result`] for more details. + /// + /// [`ExecutionPlan::handle_child_pushdown_result`]: datafusion_physical_plan::ExecutionPlan::handle_child_pushdown_result + fn try_pushdown_filters( + &self, + filters: Vec>, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::with_parent_pushdown_result( + vec![PushedDown::No; filters.len()], + )) + } } -/// [`ExecutionPlan`] handles different file formats like JSON, CSV, AVRO, ARROW, PARQUET +/// [`ExecutionPlan`] that reads one or more files +/// +/// `DataSourceExec` implements common functionality such as applying +/// projections, and caching plan properties. /// -/// `DataSourceExec` implements common functionality such as applying projections, -/// and caching plan properties. +/// The [`DataSource`] describes where to find the data for this data source +/// (for example in files or what in memory partitions). /// -/// The [`DataSource`] trait describes where to find the data for this data -/// source (for example what files or what in memory partitions). Format -/// specifics are implemented with the [`FileSource`] trait. +/// For file based [`DataSource`]s, format specific behavior is implemented in +/// the [`FileSource`] trait. /// /// [`FileSource`]: crate::file::FileSource #[derive(Clone, Debug)] @@ -131,15 +229,19 @@ impl ExecutionPlan for DataSourceExec { fn with_new_children( self: Arc, _: Vec>, - ) -> datafusion_common::Result> { + ) -> Result> { Ok(self) } + /// Implementation of [`ExecutionPlan::repartitioned`] which relies upon the inner [`DataSource::repartitioned`]. + /// + /// If the data source does not support changing its partitioning, returns `Ok(None)` (the default). Refer + /// to [`ExecutionPlan::repartitioned`] for more details. fn repartitioned( &self, target_partitions: usize, config: &ConfigOptions, - ) -> datafusion_common::Result>> { + ) -> Result>> { let data_source = self.data_source.repartitioned( target_partitions, config.optimizer.repartition_file_min_size, @@ -163,16 +265,41 @@ impl ExecutionPlan for DataSourceExec { &self, partition: usize, context: Arc, - ) -> datafusion_common::Result { - self.data_source.open(partition, context) + ) -> Result { + let stream = self.data_source.open(partition, Arc::clone(&context))?; + let batch_size = context.session_config().batch_size(); + log::debug!( + "Batch splitting enabled for partition {partition}: batch_size={batch_size}" + ); + let metrics = self.data_source.metrics(); + let split_metrics = SplitMetrics::new(&metrics, partition); + Ok(Box::pin(BatchSplitStream::new( + stream, + batch_size, + split_metrics, + ))) } fn metrics(&self) -> Option { Some(self.data_source.metrics().clone_inner()) } - fn statistics(&self) -> datafusion_common::Result { - self.data_source.statistics() + fn partition_statistics(&self, partition: Option) -> Result { + if let Some(partition) = partition { + let mut statistics = Statistics::new_unknown(&self.schema()); + if let Some(file_config) = + self.data_source.as_any().downcast_ref::() + { + if let Some(file_group) = file_config.file_groups.get(partition) { + if let Some(stat) = file_group.file_statistics(None) { + statistics = stat.clone(); + } + } + } + Ok(statistics) + } else { + Ok(self.data_source.statistics()?) + } } fn with_fetch(&self, limit: Option) -> Option> { @@ -189,8 +316,51 @@ impl ExecutionPlan for DataSourceExec { fn try_swapping_with_projection( &self, projection: &ProjectionExec, - ) -> datafusion_common::Result>> { - self.data_source.try_swapping_with_projection(projection) + ) -> Result>> { + match self + .data_source + .try_swapping_with_projection(projection.expr())? + { + Some(new_data_source) => { + Ok(Some(Arc::new(DataSourceExec::new(new_data_source)))) + } + None => Ok(None), + } + } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + config: &ConfigOptions, + ) -> Result>> { + // Push any remaining filters into our data source + let parent_filters = child_pushdown_result + .parent_filters + .into_iter() + .map(|f| f.filter) + .collect_vec(); + let res = self + .data_source + .try_pushdown_filters(parent_filters.clone(), config)?; + match res.updated_node { + Some(data_source) => { + let mut new_node = self.clone(); + new_node.data_source = data_source; + // Re-compute properties since we have new filters which will impact equivalence info + new_node.cache = + Self::compute_properties(Arc::clone(&new_node.data_source)); + + Ok(FilterPushdownPropagation { + filters: res.filters, + updated_node: Some(Arc::new(new_node)), + }) + } + None => Ok(FilterPushdownPropagation { + filters: res.filters, + updated_node: None, + }), + } } } @@ -199,6 +369,7 @@ impl DataSourceExec { Arc::new(Self::new(Arc::new(data_source))) } + // Default constructor for `DataSourceExec`, setting the `cooperative` flag to `true`. pub fn new(data_source: Arc) -> Self { let cache = Self::compute_properties(Arc::clone(&data_source)); Self { data_source, cache } @@ -234,6 +405,7 @@ impl DataSourceExec { EmissionType::Incremental, Boundedness::Bounded, ) + .with_scheduling_type(data_source.scheduling_type()) } /// Downcast the `DataSourceExec`'s `data_source` to a specific file source @@ -254,3 +426,13 @@ impl DataSourceExec { }) } } + +/// Create a new `DataSourceExec` from a `DataSource` +impl From for DataSourceExec +where + S: DataSource + 'static, +{ + fn from(source: S) -> Self { + Self::new(Arc::new(source)) + } +} diff --git a/datafusion/datasource/src/statistics.rs b/datafusion/datasource/src/statistics.rs index 7e875513f03fc..0dd9bdb87c40a 100644 --- a/datafusion/datasource/src/statistics.rs +++ b/datafusion/datasource/src/statistics.rs @@ -20,26 +20,25 @@ //! Currently, this module houses code to sort file groups if they are non-overlapping with //! respect to the required sort order. See [`MinMaxStatistics`] -use futures::{Stream, StreamExt}; -use std::mem; use std::sync::Arc; use crate::file_groups::FileGroup; use crate::PartitionedFile; use arrow::array::RecordBatch; +use arrow::compute::SortColumn; use arrow::datatypes::SchemaRef; -use arrow::{ - compute::SortColumn, - row::{Row, Rows}, -}; +use arrow::row::{Row, Rows}; use datafusion_common::stats::Precision; -use datafusion_common::ScalarValue; -use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; -use datafusion_physical_expr::{expressions::Column, PhysicalSortExpr}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_common::{ + plan_datafusion_err, plan_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_plan::{ColumnStatistics, Statistics}; +use futures::{Stream, StreamExt}; + /// A normalized representation of file min/max statistics that allows for efficient sorting & comparison. /// The min/max values are ordered by [`Self::sort_order`]. /// Furthermore, any columns that are reversed in the sort order have their min/max values swapped. @@ -58,12 +57,12 @@ impl MinMaxStatistics { /// Min value at index #[allow(unused)] - pub fn min(&self, idx: usize) -> Row { + pub fn min(&'_ self, idx: usize) -> Row<'_> { self.min_by_sort_order.row(idx) } /// Max value at index - pub fn max(&self, idx: usize) -> Row { + pub fn max(&'_ self, idx: usize) -> Row<'_> { self.max_by_sort_order.row(idx) } @@ -73,9 +72,7 @@ impl MinMaxStatistics { projection: Option<&[usize]>, // Indices of projection in full table schema (None = all columns) files: impl IntoIterator, ) -> Result { - use datafusion_common::ScalarValue; - - let statistics_and_partition_values = files + let Some(statistics_and_partition_values) = files .into_iter() .map(|file| { file.statistics @@ -83,9 +80,9 @@ impl MinMaxStatistics { .zip(Some(file.partition_values.as_slice())) }) .collect::>>() - .ok_or_else(|| { - DataFusionError::Plan("Parquet file missing statistics".to_string()) - })?; + else { + return plan_err!("Parquet file missing statistics"); + }; // Helper function to get min/max statistics for a given column of projected_schema let get_min_max = |i: usize| -> Result<(Vec, Vec)> { @@ -98,9 +95,7 @@ impl MinMaxStatistics { .get_value() .cloned() .zip(s.column_statistics[i].max_value.get_value().cloned()) - .ok_or_else(|| { - DataFusionError::Plan("statistics not found".to_string()) - }) + .ok_or_else(|| plan_datafusion_err!("statistics not found")) } else { let partition_value = &pv[i - s.column_statistics.len()]; Ok((partition_value.clone(), partition_value.clone())) @@ -111,27 +106,28 @@ impl MinMaxStatistics { .unzip()) }; - let sort_columns = sort_columns_from_physical_sort_exprs(projected_sort_order) - .ok_or(DataFusionError::Plan( - "sort expression must be on column".to_string(), - ))?; + let Some(sort_columns) = + sort_columns_from_physical_sort_exprs(projected_sort_order) + else { + return plan_err!("sort expression must be on column"); + }; // Project the schema & sort order down to just the relevant columns let min_max_schema = Arc::new( projected_schema .project(&(sort_columns.iter().map(|c| c.index()).collect::>()))?, ); - let min_max_sort_order = LexOrdering::from( - sort_columns - .iter() - .zip(projected_sort_order.iter()) - .enumerate() - .map(|(i, (col, sort))| PhysicalSortExpr { - expr: Arc::new(Column::new(col.name(), i)), - options: sort.options, - }) - .collect::>(), - ); + + let min_max_sort_order = projected_sort_order + .iter() + .zip(sort_columns.iter()) + .enumerate() + .map(|(idx, (sort_expr, col))| { + let expr = Arc::new(Column::new(col.name(), idx)); + PhysicalSortExpr::new(expr, sort_expr.options) + }); + // Safe to `unwrap` as we know that sort columns are non-empty: + let min_max_sort_order = LexOrdering::new(min_max_sort_order).unwrap(); let (min_values, max_values): (Vec<_>, Vec<_>) = sort_columns .iter() @@ -139,7 +135,9 @@ impl MinMaxStatistics { // Reverse the projection to get the index of the column in the full statistics // The file statistics contains _every_ column , but the sort column's index() // refers to the index in projected_schema - let i = projection.map(|p| p[c.index()]).unwrap_or(c.index()); + let i = projection + .map(|p| p[c.index()]) + .unwrap_or_else(|| c.index()); let (min, max) = get_min_max(i).map_err(|e| { e.context(format!("get min/max for column: '{}'", c.name())) @@ -159,12 +157,18 @@ impl MinMaxStatistics { &min_max_schema, RecordBatch::try_new(Arc::clone(&min_max_schema), min_values).map_err( |e| { - DataFusionError::ArrowError(e, Some("\ncreate min batch".to_string())) + DataFusionError::ArrowError( + Box::new(e), + Some("\ncreate min batch".to_string()), + ) }, )?, RecordBatch::try_new(Arc::clone(&min_max_schema), max_values).map_err( |e| { - DataFusionError::ArrowError(e, Some("\ncreate max batch".to_string())) + DataFusionError::ArrowError( + Box::new(e), + Some("\ncreate max batch".to_string()), + ) }, )?, ) @@ -189,25 +193,23 @@ impl MinMaxStatistics { .map_err(|e| e.context("create sort fields"))?; let converter = RowConverter::new(sort_fields)?; - let sort_columns = sort_columns_from_physical_sort_exprs(sort_order).ok_or( - DataFusionError::Plan("sort expression must be on column".to_string()), - )?; + let Some(sort_columns) = sort_columns_from_physical_sort_exprs(sort_order) else { + return plan_err!("sort expression must be on column"); + }; // swap min/max if they're reversed in the ordering let (new_min_cols, new_max_cols): (Vec<_>, Vec<_>) = sort_order .iter() .zip(sort_columns.iter().copied()) .map(|(sort_expr, column)| { - if sort_expr.options.descending { - max_values - .column_by_name(column.name()) - .zip(min_values.column_by_name(column.name())) + let maxes = max_values.column_by_name(column.name()); + let mins = min_values.column_by_name(column.name()); + let opt_value = if sort_expr.options.descending { + maxes.zip(mins) } else { - min_values - .column_by_name(column.name()) - .zip(max_values.column_by_name(column.name())) - } - .ok_or_else(|| { + mins.zip(maxes) + }; + opt_value.ok_or_else(|| { plan_datafusion_err!( "missing column in MinMaxStatistics::new: '{}'", column.name() @@ -228,14 +230,7 @@ impl MinMaxStatistics { .zip(sort_columns.iter().copied()) .map(|(sort_expr, column)| { let schema = values.schema(); - let idx = schema.index_of(column.name())?; - let field = schema.field(idx); - - // check that sort columns are non-nullable - if field.is_nullable() { - return plan_err!("cannot sort by nullable column"); - } Ok(SortColumn { values: Arc::clone(values.column(idx)), @@ -252,7 +247,10 @@ impl MinMaxStatistics { .collect::>(), ) .map_err(|e| { - DataFusionError::ArrowError(e, Some("convert columns".to_string())) + DataFusionError::ArrowError( + Box::new(e), + Some("convert columns".to_string()), + ) }) }); @@ -285,7 +283,7 @@ fn sort_columns_from_physical_sort_exprs( sort_order .iter() .map(|expr| expr.expr.as_any().downcast_ref::()) - .collect::>>() + .collect() } /// Get all files as well as the file level summary statistics (no statistic for partition columns). @@ -357,10 +355,9 @@ pub async fn get_statistics_with_limit( // counts across all the files in question. If any file does not // provide any information or provides an inexact value, we demote // the statistic precision to inexact. - num_rows = add_row_stats(file_stats.num_rows, num_rows); + num_rows = num_rows.add(&file_stats.num_rows); - total_byte_size = - add_row_stats(file_stats.total_byte_size, total_byte_size); + total_byte_size = total_byte_size.add(&file_stats.total_byte_size); for (file_col_stats, col_stats) in file_stats .column_statistics @@ -375,10 +372,10 @@ pub async fn get_statistics_with_limit( distinct_count: _, } = file_col_stats; - col_stats.null_count = add_row_stats(*file_nc, col_stats.null_count); - set_max_if_greater(file_max, &mut col_stats.max_value); - set_min_if_lesser(file_min, &mut col_stats.min_value); - col_stats.sum_value = file_sum.add(&col_stats.sum_value); + col_stats.null_count = col_stats.null_count.add(file_nc); + col_stats.max_value = col_stats.max_value.max(file_max); + col_stats.min_value = col_stats.min_value.min(file_min); + col_stats.sum_value = col_stats.sum_value.add(file_sum); } // If the number of rows exceeds the limit, we can stop processing @@ -409,62 +406,6 @@ pub async fn get_statistics_with_limit( Ok((result_files, statistics)) } -/// Generic function to compute statistics across multiple items that have statistics -fn compute_summary_statistics( - items: I, - file_schema: &SchemaRef, - stats_extractor: impl Fn(&T) -> Option<&Statistics>, -) -> Statistics -where - I: IntoIterator, -{ - let size = file_schema.fields().len(); - let mut col_stats_set = vec![ColumnStatistics::default(); size]; - let mut num_rows = Precision::::Absent; - let mut total_byte_size = Precision::::Absent; - - for (idx, item) in items.into_iter().enumerate() { - if let Some(item_stats) = stats_extractor(&item) { - if idx == 0 { - // First item, set values directly - num_rows = item_stats.num_rows; - total_byte_size = item_stats.total_byte_size; - for (index, column_stats) in - item_stats.column_statistics.iter().enumerate() - { - col_stats_set[index].null_count = column_stats.null_count; - col_stats_set[index].max_value = column_stats.max_value.clone(); - col_stats_set[index].min_value = column_stats.min_value.clone(); - col_stats_set[index].sum_value = column_stats.sum_value.clone(); - } - continue; - } - - // Accumulate statistics for subsequent items - num_rows = add_row_stats(item_stats.num_rows, num_rows); - total_byte_size = add_row_stats(item_stats.total_byte_size, total_byte_size); - - for (item_col_stats, col_stats) in item_stats - .column_statistics - .iter() - .zip(col_stats_set.iter_mut()) - { - col_stats.null_count = - add_row_stats(item_col_stats.null_count, col_stats.null_count); - set_max_if_greater(&item_col_stats.max_value, &mut col_stats.max_value); - set_min_if_lesser(&item_col_stats.min_value, &mut col_stats.min_value); - col_stats.sum_value = item_col_stats.sum_value.add(&col_stats.sum_value); - } - } - } - - Statistics { - num_rows, - total_byte_size, - column_statistics: col_stats_set, - } -} - /// Computes the summary statistics for a group of files(`FileGroup` level's statistics). /// /// This function combines statistics from all files in the file group to create @@ -489,12 +430,13 @@ pub fn compute_file_group_statistics( return Ok(file_group); } - let statistics = - compute_summary_statistics(file_group.iter(), &file_schema, |file| { - file.statistics.as_ref().map(|stats| stats.as_ref()) - }); + let file_group_stats = file_group.iter().filter_map(|file| { + let stats = file.statistics.as_ref()?; + Some(stats.as_ref()) + }); + let statistics = Statistics::try_merge_iter(file_group_stats, &file_schema)?; - Ok(file_group.with_statistics(statistics)) + Ok(file_group.with_statistics(Arc::new(statistics))) } /// Computes statistics for all files across multiple file groups. @@ -506,7 +448,7 @@ pub fn compute_file_group_statistics( /// /// # Parameters /// * `file_groups` - Vector of file groups to process -/// * `file_schema` - Schema of the files +/// * `table_schema` - Schema of the table /// * `collect_stats` - Whether to collect statistics /// * `inexact_stats` - Whether to mark the resulting statistics as inexact /// @@ -516,26 +458,28 @@ pub fn compute_file_group_statistics( /// * The summary statistics across all file groups, aka all files summary statistics pub fn compute_all_files_statistics( file_groups: Vec, - file_schema: SchemaRef, + table_schema: SchemaRef, collect_stats: bool, inexact_stats: bool, ) -> Result<(Vec, Statistics)> { - let mut file_groups_with_stats = Vec::with_capacity(file_groups.len()); - - // First compute statistics for each file group - for file_group in file_groups { - file_groups_with_stats.push(compute_file_group_statistics( - file_group, - Arc::clone(&file_schema), - collect_stats, - )?); - } + let file_groups_with_stats = file_groups + .into_iter() + .map(|file_group| { + compute_file_group_statistics( + file_group, + Arc::clone(&table_schema), + collect_stats, + ) + }) + .collect::>>()?; // Then summary statistics across all file groups + let file_groups_statistics = file_groups_with_stats + .iter() + .filter_map(|file_group| file_group.file_statistics(None)); + let mut statistics = - compute_summary_statistics(&file_groups_with_stats, &file_schema, |file_group| { - file_group.statistics() - }); + Statistics::try_merge_iter(file_groups_statistics, &table_schema)?; if inexact_stats { statistics = statistics.to_inexact() @@ -544,255 +488,10 @@ pub fn compute_all_files_statistics( Ok((file_groups_with_stats, statistics)) } +#[deprecated(since = "47.0.0", note = "Use Statistics::add")] pub fn add_row_stats( file_num_rows: Precision, num_rows: Precision, ) -> Precision { - match (file_num_rows, &num_rows) { - (Precision::Absent, _) => num_rows.to_inexact(), - (lhs, Precision::Absent) => lhs.to_inexact(), - (lhs, rhs) => lhs.add(rhs), - } -} - -/// If the given value is numerically greater than the original maximum value, -/// return the new maximum value with appropriate exactness information. -fn set_max_if_greater( - max_nominee: &Precision, - max_value: &mut Precision, -) { - match (&max_value, max_nominee) { - (Precision::Exact(val1), Precision::Exact(val2)) if val1 < val2 => { - *max_value = max_nominee.clone(); - } - (Precision::Exact(val1), Precision::Inexact(val2)) - | (Precision::Inexact(val1), Precision::Inexact(val2)) - | (Precision::Inexact(val1), Precision::Exact(val2)) - if val1 < val2 => - { - *max_value = max_nominee.clone().to_inexact(); - } - (Precision::Exact(_), Precision::Absent) => { - let exact_max = mem::take(max_value); - *max_value = exact_max.to_inexact(); - } - (Precision::Absent, Precision::Exact(_)) => { - *max_value = max_nominee.clone().to_inexact(); - } - (Precision::Absent, Precision::Inexact(_)) => { - *max_value = max_nominee.clone(); - } - _ => {} - } -} - -/// If the given value is numerically lesser than the original minimum value, -/// return the new minimum value with appropriate exactness information. -fn set_min_if_lesser( - min_nominee: &Precision, - min_value: &mut Precision, -) { - match (&min_value, min_nominee) { - (Precision::Exact(val1), Precision::Exact(val2)) if val1 > val2 => { - *min_value = min_nominee.clone(); - } - (Precision::Exact(val1), Precision::Inexact(val2)) - | (Precision::Inexact(val1), Precision::Inexact(val2)) - | (Precision::Inexact(val1), Precision::Exact(val2)) - if val1 > val2 => - { - *min_value = min_nominee.clone().to_inexact(); - } - (Precision::Exact(_), Precision::Absent) => { - let exact_min = mem::take(min_value); - *min_value = exact_min.to_inexact(); - } - (Precision::Absent, Precision::Exact(_)) => { - *min_value = min_nominee.clone().to_inexact(); - } - (Precision::Absent, Precision::Inexact(_)) => { - *min_value = min_nominee.clone(); - } - _ => {} - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ScalarValue; - use std::sync::Arc; - - #[test] - fn test_compute_summary_statistics_basic() { - // Create a schema with two columns - let schema = Arc::new(Schema::new(vec![ - Field::new("col1", DataType::Int32, false), - Field::new("col2", DataType::Int32, false), - ])); - - // Create items with statistics - let stats1 = Statistics { - num_rows: Precision::Exact(10), - total_byte_size: Precision::Exact(100), - column_statistics: vec![ - ColumnStatistics { - null_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Int32(Some(100))), - min_value: Precision::Exact(ScalarValue::Int32(Some(1))), - sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), - distinct_count: Precision::Absent, - }, - ColumnStatistics { - null_count: Precision::Exact(2), - max_value: Precision::Exact(ScalarValue::Int32(Some(200))), - min_value: Precision::Exact(ScalarValue::Int32(Some(10))), - sum_value: Precision::Exact(ScalarValue::Int32(Some(1000))), - distinct_count: Precision::Absent, - }, - ], - }; - - let stats2 = Statistics { - num_rows: Precision::Exact(15), - total_byte_size: Precision::Exact(150), - column_statistics: vec![ - ColumnStatistics { - null_count: Precision::Exact(2), - max_value: Precision::Exact(ScalarValue::Int32(Some(120))), - min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), - sum_value: Precision::Exact(ScalarValue::Int32(Some(600))), - distinct_count: Precision::Absent, - }, - ColumnStatistics { - null_count: Precision::Exact(3), - max_value: Precision::Exact(ScalarValue::Int32(Some(180))), - min_value: Precision::Exact(ScalarValue::Int32(Some(5))), - sum_value: Precision::Exact(ScalarValue::Int32(Some(1200))), - distinct_count: Precision::Absent, - }, - ], - }; - - let items = vec![Arc::new(stats1), Arc::new(stats2)]; - - // Call compute_summary_statistics - let summary_stats = - compute_summary_statistics(items, &schema, |item| Some(item.as_ref())); - - // Verify the results - assert_eq!(summary_stats.num_rows, Precision::Exact(25)); // 10 + 15 - assert_eq!(summary_stats.total_byte_size, Precision::Exact(250)); // 100 + 150 - - // Verify column statistics - let col1_stats = &summary_stats.column_statistics[0]; - assert_eq!(col1_stats.null_count, Precision::Exact(3)); // 1 + 2 - assert_eq!( - col1_stats.max_value, - Precision::Exact(ScalarValue::Int32(Some(120))) - ); - assert_eq!( - col1_stats.min_value, - Precision::Exact(ScalarValue::Int32(Some(-10))) - ); - assert_eq!( - col1_stats.sum_value, - Precision::Exact(ScalarValue::Int32(Some(1100))) - ); // 500 + 600 - - let col2_stats = &summary_stats.column_statistics[1]; - assert_eq!(col2_stats.null_count, Precision::Exact(5)); // 2 + 3 - assert_eq!( - col2_stats.max_value, - Precision::Exact(ScalarValue::Int32(Some(200))) - ); - assert_eq!( - col2_stats.min_value, - Precision::Exact(ScalarValue::Int32(Some(5))) - ); - assert_eq!( - col2_stats.sum_value, - Precision::Exact(ScalarValue::Int32(Some(2200))) - ); // 1000 + 1200 - } - - #[test] - fn test_compute_summary_statistics_mixed_precision() { - // Create a schema with one column - let schema = Arc::new(Schema::new(vec![Field::new( - "col1", - DataType::Int32, - false, - )])); - - // Create items with different precision levels - let stats1 = Statistics { - num_rows: Precision::Exact(10), - total_byte_size: Precision::Inexact(100), - column_statistics: vec![ColumnStatistics { - null_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::Int32(Some(100))), - min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), - sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), - distinct_count: Precision::Absent, - }], - }; - - let stats2 = Statistics { - num_rows: Precision::Inexact(15), - total_byte_size: Precision::Exact(150), - column_statistics: vec![ColumnStatistics { - null_count: Precision::Inexact(2), - max_value: Precision::Inexact(ScalarValue::Int32(Some(120))), - min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), - sum_value: Precision::Absent, - distinct_count: Precision::Absent, - }], - }; - - let items = vec![Arc::new(stats1), Arc::new(stats2)]; - - let summary_stats = - compute_summary_statistics(items, &schema, |item| Some(item.as_ref())); - - assert_eq!(summary_stats.num_rows, Precision::Inexact(25)); - assert_eq!(summary_stats.total_byte_size, Precision::Inexact(250)); - - let col_stats = &summary_stats.column_statistics[0]; - assert_eq!(col_stats.null_count, Precision::Inexact(3)); - assert_eq!( - col_stats.max_value, - Precision::Inexact(ScalarValue::Int32(Some(120))) - ); - assert_eq!( - col_stats.min_value, - Precision::Inexact(ScalarValue::Int32(Some(-10))) - ); - assert!(matches!(col_stats.sum_value, Precision::Absent)); - } - - #[test] - fn test_compute_summary_statistics_empty() { - let schema = Arc::new(Schema::new(vec![Field::new( - "col1", - DataType::Int32, - false, - )])); - - // Empty collection - let items: Vec> = vec![]; - - let summary_stats = - compute_summary_statistics(items, &schema, |item| Some(item.as_ref())); - - // Verify default values for empty collection - assert_eq!(summary_stats.num_rows, Precision::Absent); - assert_eq!(summary_stats.total_byte_size, Precision::Absent); - assert_eq!(summary_stats.column_statistics.len(), 1); - assert_eq!( - summary_stats.column_statistics[0].null_count, - Precision::Absent - ); - } + file_num_rows.add(&num_rows) } diff --git a/datafusion/datasource/src/test_util.rs b/datafusion/datasource/src/test_util.rs index 9a9b98d5041b0..f0aff1fa62b70 100644 --- a/datafusion/datasource/src/test_util.rs +++ b/datafusion/datasource/src/test_util.rs @@ -17,12 +17,14 @@ use crate::{ file::FileSource, file_scan_config::FileScanConfig, file_stream::FileOpener, + schema_adapter::SchemaAdapterFactory, }; use std::sync::Arc; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::{Result, Statistics}; +use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use object_store::ObjectStore; @@ -31,6 +33,15 @@ use object_store::ObjectStore; pub(crate) struct MockSource { metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, + filter: Option>, +} + +impl MockSource { + pub fn with_filter(mut self, filter: Arc) -> Self { + self.filter = Some(filter); + self + } } impl FileSource for MockSource { @@ -47,6 +58,10 @@ impl FileSource for MockSource { self } + fn filter(&self) -> Option> { + self.filter.clone() + } + fn with_batch_size(&self, _batch_size: usize) -> Arc { Arc::new(Self { ..self.clone() }) } @@ -80,4 +95,23 @@ impl FileSource for MockSource { fn file_type(&self) -> &str { "mock" } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } +} + +/// Create a column expression +pub(crate) fn col(name: &str, schema: &Schema) -> Result> { + Ok(Arc::new(Column::new_with_schema(name, schema)?)) } diff --git a/datafusion/datasource/src/url.rs b/datafusion/datasource/src/url.rs index 2dbcfa2ef1fae..c87b307c5fb80 100644 --- a/datafusion/datasource/src/url.rs +++ b/datafusion/datasource/src/url.rs @@ -209,10 +209,10 @@ impl ListingTableUrl { /// assert_eq!(url.file_extension(), None); /// ``` pub fn file_extension(&self) -> Option<&str> { - if let Some(segments) = self.url.path_segments() { - if let Some(last_segment) = segments.last() { + if let Some(mut segments) = self.url.path_segments() { + if let Some(last_segment) = segments.next_back() { if last_segment.contains(".") && !last_segment.ends_with(".") { - return last_segment.split('.').last(); + return last_segment.split('.').next_back(); } } } @@ -242,25 +242,20 @@ impl ListingTableUrl { ) -> Result>> { let exec_options = &ctx.config_options().execution; let ignore_subdirectory = exec_options.listing_table_ignore_subdirectory; - // If the prefix is a file, use a head request, otherwise list - let list = match self.is_collection() { - true => match ctx.runtime_env().cache_manager.get_list_files_cache() { - None => store.list(Some(&self.prefix)), - Some(cache) => { - if let Some(res) = cache.get(&self.prefix) { - debug!("Hit list all files cache"); - futures::stream::iter(res.as_ref().clone().into_iter().map(Ok)) - .boxed() - } else { - let list_res = store.list(Some(&self.prefix)); - let vec = list_res.try_collect::>().await?; - cache.put(&self.prefix, Arc::new(vec.clone())); - futures::stream::iter(vec.into_iter().map(Ok)).boxed() - } - } - }, - false => futures::stream::once(store.head(&self.prefix)).boxed(), + + let list: BoxStream<'a, Result> = if self.is_collection() { + list_with_cache(ctx, store, &self.prefix).await? + } else { + match store.head(&self.prefix).await { + Ok(meta) => futures::stream::once(async { Ok(meta) }) + .map_err(|e| DataFusionError::ObjectStore(Box::new(e))) + .boxed(), + // If the head command fails, it is likely that object doesn't exist. + // Retry as though it were a prefix (aka a collection) + Err(_) => list_with_cache(ctx, store, &self.prefix).await?, + } }; + Ok(list .try_filter(move |meta| { let path = &meta.location; @@ -268,7 +263,6 @@ impl ListingTableUrl { let glob_match = self.contains(path, ignore_subdirectory); futures::future::ready(extension_match && glob_match) }) - .map_err(DataFusionError::ObjectStore) .boxed()) } @@ -282,6 +276,55 @@ impl ListingTableUrl { let url = &self.url[url::Position::BeforeScheme..url::Position::BeforePath]; ObjectStoreUrl::parse(url).unwrap() } + + /// Returns true if the [`ListingTableUrl`] points to the folder + pub fn is_folder(&self) -> bool { + self.url.scheme() == "file" && self.is_collection() + } + + /// Return the `url` for [`ListingTableUrl`] + pub fn get_url(&self) -> &Url { + &self.url + } + + /// Return the `glob` for [`ListingTableUrl`] + pub fn get_glob(&self) -> &Option { + &self.glob + } + + /// Returns a copy of current [`ListingTableUrl`] with a specified `glob` + pub fn with_glob(self, glob: &str) -> Result { + let glob = + Pattern::new(glob).map_err(|e| DataFusionError::External(Box::new(e)))?; + Self::try_new(self.url, Some(glob)) + } +} + +async fn list_with_cache<'b>( + ctx: &'b dyn Session, + store: &'b dyn ObjectStore, + prefix: &'b Path, +) -> Result>> { + match ctx.runtime_env().cache_manager.get_list_files_cache() { + None => Ok(store + .list(Some(prefix)) + .map(|res| res.map_err(|e| DataFusionError::ObjectStore(Box::new(e)))) + .boxed()), + Some(cache) => { + let vec = if let Some(res) = cache.get(prefix) { + debug!("Hit list all files cache"); + res.as_ref().clone() + } else { + let vec = store + .list(Some(prefix)) + .try_collect::>() + .await?; + cache.put(prefix, Arc::new(vec.clone())); + vec + }; + Ok(futures::stream::iter(vec.into_iter().map(Ok)).boxed()) + } + } } /// Creates a file URL from a potentially relative filesystem path @@ -362,6 +405,18 @@ fn split_glob_expression(path: &str) -> Option<(&str, &str)> { #[cfg(test)] mod tests { use super::*; + use datafusion_common::config::TableOptions; + use datafusion_common::DFSchema; + use datafusion_execution::config::SessionConfig; + use datafusion_execution::runtime_env::RuntimeEnv; + use datafusion_execution::TaskContext; + use datafusion_expr::execution_props::ExecutionProps; + use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_plan::ExecutionPlan; + use object_store::PutPayload; + use std::any::Any; + use std::collections::HashMap; use tempfile::tempdir; #[test] @@ -575,4 +630,165 @@ mod tests { "file path ends with .ext - extension is ext", ); } + + #[tokio::test] + async fn test_list_files() { + let store = object_store::memory::InMemory::new(); + // Create some files: + create_file(&store, "a.parquet").await; + create_file(&store, "/t/b.parquet").await; + create_file(&store, "/t/c.csv").await; + create_file(&store, "/t/d.csv").await; + + assert_eq!( + list_all_files("/", &store, "parquet").await, + vec!["a.parquet"], + ); + + // test with and without trailing slash + assert_eq!( + list_all_files("/t/", &store, "parquet").await, + vec!["t/b.parquet"], + ); + assert_eq!( + list_all_files("/t", &store, "parquet").await, + vec!["t/b.parquet"], + ); + + // test with and without trailing slash + assert_eq!( + list_all_files("/t", &store, "csv").await, + vec!["t/c.csv", "t/d.csv"], + ); + assert_eq!( + list_all_files("/t/", &store, "csv").await, + vec!["t/c.csv", "t/d.csv"], + ); + + // Test a non existing prefix + assert_eq!( + list_all_files("/NonExisting", &store, "csv").await, + vec![] as Vec + ); + assert_eq!( + list_all_files("/NonExisting/", &store, "csv").await, + vec![] as Vec + ); + } + + /// Creates a file with "hello world" content at the specified path + async fn create_file(object_store: &dyn ObjectStore, path: &str) { + object_store + .put(&Path::from(path), PutPayload::from_static(b"hello world")) + .await + .expect("failed to create test file"); + } + + /// Runs "list_all_files" and returns their paths + /// + /// Panic's on error + async fn list_all_files( + url: &str, + store: &dyn ObjectStore, + file_extension: &str, + ) -> Vec { + try_list_all_files(url, store, file_extension) + .await + .unwrap() + } + + /// Runs "list_all_files" and returns their paths + async fn try_list_all_files( + url: &str, + store: &dyn ObjectStore, + file_extension: &str, + ) -> Result> { + let session = MockSession::new(); + let url = ListingTableUrl::parse(url)?; + let files = url + .list_all_files(&session, store, file_extension) + .await? + .try_collect::>() + .await? + .into_iter() + .map(|meta| meta.location.as_ref().to_string()) + .collect(); + Ok(files) + } + + struct MockSession { + config: SessionConfig, + runtime_env: Arc, + } + + impl MockSession { + fn new() -> Self { + Self { + config: SessionConfig::new(), + runtime_env: Arc::new(RuntimeEnv::default()), + } + } + } + + #[async_trait::async_trait] + impl Session for MockSession { + fn session_id(&self) -> &str { + unimplemented!() + } + + fn config(&self) -> &SessionConfig { + &self.config + } + + async fn create_physical_plan( + &self, + _logical_plan: &LogicalPlan, + ) -> Result> { + unimplemented!() + } + + fn create_physical_expr( + &self, + _expr: Expr, + _df_schema: &DFSchema, + ) -> Result> { + unimplemented!() + } + + fn scalar_functions(&self) -> &HashMap> { + unimplemented!() + } + + fn aggregate_functions(&self) -> &HashMap> { + unimplemented!() + } + + fn window_functions(&self) -> &HashMap> { + unimplemented!() + } + + fn runtime_env(&self) -> &Arc { + &self.runtime_env + } + + fn execution_props(&self) -> &ExecutionProps { + unimplemented!() + } + + fn as_any(&self) -> &dyn Any { + unimplemented!() + } + + fn table_options(&self) -> &TableOptions { + unimplemented!() + } + + fn table_options_mut(&mut self) -> &mut TableOptions { + unimplemented!() + } + + fn task_ctx(&self) -> Arc { + unimplemented!() + } + } } diff --git a/datafusion/datasource/src/write/demux.rs b/datafusion/datasource/src/write/demux.rs index fc2e5daf92b66..e80099823054d 100644 --- a/datafusion/datasource/src/write/demux.rs +++ b/datafusion/datasource/src/write/demux.rs @@ -28,8 +28,8 @@ use datafusion_common::error::Result; use datafusion_physical_plan::SendableRecordBatchStream; use arrow::array::{ - builder::UInt64Builder, cast::AsArray, downcast_dictionary_array, RecordBatch, - StringArray, StructArray, + builder::UInt64Builder, cast::AsArray, downcast_dictionary_array, ArrayAccessor, + RecordBatch, StringArray, StructArray, }; use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::{ @@ -38,14 +38,14 @@ use datafusion_common::cast::{ as_int8_array, as_string_array, as_string_view_array, as_uint16_array, as_uint32_array, as_uint64_array, as_uint8_array, }; -use datafusion_common::{exec_datafusion_err, not_impl_err, DataFusionError}; +use datafusion_common::{exec_datafusion_err, internal_datafusion_err, not_impl_err}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; use chrono::NaiveDate; use futures::StreamExt; use object_store::path::Path; -use rand::distributions::DistString; +use rand::distr::SampleString; use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; type RecordBatchReceiver = Receiver; @@ -151,8 +151,7 @@ async fn row_count_demuxer( let max_buffered_batches = exec_options.max_buffered_batches_per_output_file; let minimum_parallel_files = exec_options.minimum_parallel_output_files; let mut part_idx = 0; - let write_id = - rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16); let mut open_file_streams = Vec::with_capacity(minimum_parallel_files); @@ -204,9 +203,7 @@ async fn row_count_demuxer( .send(rb) .await .map_err(|_| { - DataFusionError::Execution( - "Error sending RecordBatch to file stream!".into(), - ) + exec_datafusion_err!("Error sending RecordBatch to file stream!") })?; next_send_steam = (next_send_steam + 1) % minimum_parallel_files; @@ -225,7 +222,7 @@ fn generate_file_path( if !single_file_output { base_output_path .prefix() - .child(format!("{}_{}.{}", write_id, part_idx, file_extension)) + .child(format!("{write_id}_{part_idx}.{file_extension}")) } else { base_output_path.prefix().to_owned() } @@ -249,9 +246,8 @@ fn create_new_file_stream( single_file_output, ); let (tx_file, rx_file) = mpsc::channel(max_buffered_batches / 2); - tx.send((file_path, rx_file)).map_err(|_| { - DataFusionError::Execution("Error sending RecordBatch to file stream!".into()) - })?; + tx.send((file_path, rx_file)) + .map_err(|_| exec_datafusion_err!("Error sending RecordBatch to file stream!"))?; Ok(tx_file) } @@ -267,8 +263,7 @@ async fn hive_style_partitions_demuxer( file_extension: String, keep_partition_by_columns: bool, ) -> Result<()> { - let write_id = - rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16); let exec_options = &context.session_config().options().execution; let max_buffered_recordbatches = exec_options.max_buffered_batches_per_output_file; @@ -309,17 +304,13 @@ async fn hive_style_partitions_demuxer( ); tx.send((file_path, part_rx)).map_err(|_| { - DataFusionError::Execution( - "Error sending new file stream!".into(), - ) + exec_datafusion_err!("Error sending new file stream!") })?; value_map.insert(part_key.clone(), part_tx); - value_map - .get_mut(&part_key) - .ok_or(DataFusionError::Internal( - "Key must exist since it was just inserted!".into(), - ))? + value_map.get_mut(&part_key).ok_or_else(|| { + exec_datafusion_err!("Key must exist since it was just inserted!") + })? } }; @@ -331,7 +322,7 @@ async fn hive_style_partitions_demuxer( // Finally send the partial batch partitioned by distinct value! part_tx.send(final_batch_to_send).await.map_err(|_| { - DataFusionError::Internal("Unexpected error sending parted batch!".into()) + internal_datafusion_err!("Unexpected error sending parted batch!") })?; } } @@ -482,10 +473,8 @@ fn compute_partition_keys_by_row<'a>( .ok_or(exec_datafusion_err!("it is not yet supported to write to hive partitions with datatype {}", dtype))?; - for val in array.values() { - partition_values.push( - Cow::from(val.ok_or(exec_datafusion_err!("Cannot partition by null value for column {}", col))?), - ); + for i in 0..rb.num_rows() { + partition_values.push(Cow::from(array.value(i))); } }, _ => unreachable!(), @@ -515,7 +504,7 @@ fn compute_take_arrays( for vals in all_partition_values.iter() { part_key.push(vals[i].clone().into()); } - let builder = take_map.entry(part_key).or_insert(UInt64Builder::new()); + let builder = take_map.entry(part_key).or_insert_with(UInt64Builder::new); builder.append_value(i as u64); } take_map @@ -558,5 +547,5 @@ fn compute_hive_style_file_path( file_path = file_path.child(format!("{}={}", partition_by[j].0, part_key[j])); } - file_path.child(format!("{}.{}", write_id, file_extension)) + file_path.child(format!("{write_id}.{file_extension}")) } diff --git a/datafusion/datasource/src/write/mod.rs b/datafusion/datasource/src/write/mod.rs index f581126095a7e..3694568682a5d 100644 --- a/datafusion/datasource/src/write/mod.rs +++ b/datafusion/datasource/src/write/mod.rs @@ -77,15 +77,18 @@ pub trait BatchSerializer: Sync + Send { /// Returns an [`AsyncWrite`] which writes to the given object store location /// with the specified compression. +/// +/// The writer will have a default buffer size as chosen by [`BufWriter::new`]. +/// /// We drop the `AbortableWrite` struct and the writer will not try to cleanup on failure. /// Users can configure automatic cleanup with their cloud provider. +#[deprecated(since = "48.0.0", note = "Use ObjectWriterBuilder::new(...) instead")] pub async fn create_writer( file_compression_type: FileCompressionType, location: &Path, object_store: Arc, ) -> Result> { - let buf_writer = BufWriter::new(object_store, location.clone()); - file_compression_type.convert_async_writer(buf_writer) + ObjectWriterBuilder::new(file_compression_type, location, object_store).build() } /// Converts table schema to writer schema, which may differ in the case @@ -109,3 +112,108 @@ pub fn get_writer_schema(config: &FileSinkConfig) -> Arc { Arc::clone(config.output_schema()) } } + +/// A builder for an [`AsyncWrite`] that writes to an object store location. +/// +/// This can be used to specify file compression on the writer. The writer +/// will have a default buffer size unless altered. The specific default size +/// is chosen by [`BufWriter::new`]. +/// +/// We drop the `AbortableWrite` struct and the writer will not try to cleanup on failure. +/// Users can configure automatic cleanup with their cloud provider. +#[derive(Debug)] +pub struct ObjectWriterBuilder { + /// Compression type for object writer. + file_compression_type: FileCompressionType, + /// Output path + location: Path, + /// The related store that handles the given path + object_store: Arc, + /// The size of the buffer for the object writer. + buffer_size: Option, +} + +impl ObjectWriterBuilder { + /// Create a new [`ObjectWriterBuilder`] for the specified path and compression type. + pub fn new( + file_compression_type: FileCompressionType, + location: &Path, + object_store: Arc, + ) -> Self { + Self { + file_compression_type, + location: location.clone(), + object_store, + buffer_size: None, + } + } + + /// Set buffer size in bytes for object writer. + /// + /// # Example + /// ``` + /// # use datafusion_datasource::file_compression_type::FileCompressionType; + /// # use datafusion_datasource::write::ObjectWriterBuilder; + /// # use object_store::memory::InMemory; + /// # use object_store::path::Path; + /// # use std::sync::Arc; + /// # let compression_type = FileCompressionType::UNCOMPRESSED; + /// # let location = Path::from("/foo/bar"); + /// # let object_store = Arc::new(InMemory::new()); + /// let mut builder = ObjectWriterBuilder::new(compression_type, &location, object_store); + /// builder.set_buffer_size(Some(20 * 1024 * 1024)); //20 MiB + /// assert_eq!(builder.get_buffer_size(), Some(20 * 1024 * 1024), "Internal error: Builder buffer size doesn't match"); + /// ``` + pub fn set_buffer_size(&mut self, buffer_size: Option) { + self.buffer_size = buffer_size; + } + + /// Set buffer size in bytes for object writer, returning the builder. + /// + /// # Example + /// ``` + /// # use datafusion_datasource::file_compression_type::FileCompressionType; + /// # use datafusion_datasource::write::ObjectWriterBuilder; + /// # use object_store::memory::InMemory; + /// # use object_store::path::Path; + /// # use std::sync::Arc; + /// # let compression_type = FileCompressionType::UNCOMPRESSED; + /// # let location = Path::from("/foo/bar"); + /// # let object_store = Arc::new(InMemory::new()); + /// let builder = ObjectWriterBuilder::new(compression_type, &location, object_store) + /// .with_buffer_size(Some(20 * 1024 * 1024)); //20 MiB + /// assert_eq!(builder.get_buffer_size(), Some(20 * 1024 * 1024), "Internal error: Builder buffer size doesn't match"); + /// ``` + pub fn with_buffer_size(mut self, buffer_size: Option) -> Self { + self.buffer_size = buffer_size; + self + } + + /// Currently specified buffer size in bytes. + pub fn get_buffer_size(&self) -> Option { + self.buffer_size + } + + /// Return a writer object that writes to the object store location. + /// + /// If a buffer size has not been set, the default buffer buffer size will + /// be used. + /// + /// # Errors + /// If there is an error applying the compression type. + pub fn build(self) -> Result> { + let Self { + file_compression_type, + location, + object_store, + buffer_size, + } = self; + + let buf_writer = match buffer_size { + Some(size) => BufWriter::with_capacity(object_store, location, size), + None => BufWriter::new(object_store, location), + }; + + file_compression_type.convert_async_writer(buf_writer) + } +} diff --git a/datafusion/datasource/src/write/orchestration.rs b/datafusion/datasource/src/write/orchestration.rs index 0ac1d26c6cc19..ab836b7b7f388 100644 --- a/datafusion/datasource/src/write/orchestration.rs +++ b/datafusion/datasource/src/write/orchestration.rs @@ -22,12 +22,14 @@ use std::sync::Arc; use super::demux::DemuxedStreamReceiver; -use super::{create_writer, BatchSerializer}; +use super::{BatchSerializer, ObjectWriterBuilder}; use crate::file_compression_type::FileCompressionType; use datafusion_common::error::Result; use arrow::array::RecordBatch; -use datafusion_common::{internal_datafusion_err, internal_err, DataFusionError}; +use datafusion_common::{ + exec_datafusion_err, internal_datafusion_err, internal_err, DataFusionError, +}; use datafusion_common_runtime::{JoinSet, SpawnedTask}; use datafusion_execution::TaskContext; @@ -117,9 +119,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store( Err(e) => { return SerializedRecordBatchResult::failure( None, - DataFusionError::Execution(format!( - "Error writing to object store: {e}" - )), + exec_datafusion_err!("Error writing to object store: {e}"), ) } }; @@ -133,9 +133,9 @@ pub(crate) async fn serialize_rb_stream_to_object_store( // Handle task panic or cancellation return SerializedRecordBatchResult::failure( Some(writer), - DataFusionError::Execution(format!( + exec_datafusion_err!( "Serialization task panicked or was cancelled: {e}" - )), + ), ); } } @@ -257,7 +257,15 @@ pub async fn spawn_writer_tasks_and_join( }); while let Some((location, rb_stream)) = file_stream_rx.recv().await { let writer = - create_writer(compression, &location, Arc::clone(&object_store)).await?; + ObjectWriterBuilder::new(compression, &location, Arc::clone(&object_store)) + .with_buffer_size(Some( + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + )) + .build()?; if tx_file_bundle .send((rb_stream, Arc::clone(&serializer), writer)) @@ -277,8 +285,8 @@ pub async fn spawn_writer_tasks_and_join( write_coordinator_task.join_unwind(), demux_task.join_unwind() ); - r1.map_err(DataFusionError::ExecutionJoin)??; - r2.map_err(DataFusionError::ExecutionJoin)??; + r1.map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; + r2.map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; // Return total row count: rx_row_cnt.await.map_err(|_| { diff --git a/datafusion/doc/Cargo.toml b/datafusion/doc/Cargo.toml index fa316348a6daa..b8324565a0c67 100644 --- a/datafusion/doc/Cargo.toml +++ b/datafusion/doc/Cargo.toml @@ -19,6 +19,7 @@ name = "datafusion-doc" description = "Documentation module for DataFusion query engine" keywords = ["datafusion", "query", "sql"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } homepage = { workspace = true } diff --git a/datafusion/doc/README.md b/datafusion/doc/README.md new file mode 100644 index 0000000000000..f137a273e31ab --- /dev/null +++ b/datafusion/doc/README.md @@ -0,0 +1,33 @@ + + +# Apache DataFusion Documentation + +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. + +This crate is a submodule of DataFusion that provides structures and macros +for documenting user defined functions. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/doc/src/lib.rs b/datafusion/doc/src/lib.rs index 68ed1e2352ca4..977130ffc0d6a 100644 --- a/datafusion/doc/src/lib.rs +++ b/datafusion/doc/src/lib.rs @@ -19,7 +19,15 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] + +mod udaf; +mod udf; +mod udwf; + +pub use udaf::aggregate_doc_sections; +pub use udf::scalar_doc_sections; +pub use udwf::window_doc_sections; #[allow(rustdoc::broken_intra_doc_links)] /// Documentation for use by [`ScalarUDFImpl`](ScalarUDFImpl), @@ -39,7 +47,7 @@ /// thus all text should be in English. /// /// [SQL function documentation]: https://datafusion.apache.org/user-guide/sql/index.html -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Documentation { /// The section in the documentation where the UDF will be documented pub doc_section: DocSection, @@ -93,7 +101,7 @@ impl Documentation { self.doc_section.label, self.doc_section .description - .map(|s| format!(", description = \"{}\"", s)) + .map(|s| format!(", description = \"{s}\"")) .unwrap_or_default(), ) .as_ref(), @@ -110,7 +118,7 @@ impl Documentation { &self .sql_example .clone() - .map(|s| format!("\n sql_example = r#\"{}\"#,", s)) + .map(|s| format!("\n sql_example = r#\"{s}\"#,")) .unwrap_or_default(), ); @@ -120,7 +128,7 @@ impl Documentation { args.iter().for_each(|(name, value)| { if value.contains(st_arg_token) { if name.starts_with("The ") { - result.push_str(format!("\n standard_argument(\n name = \"{}\"),", name).as_ref()); + result.push_str(format!("\n standard_argument(\n name = \"{name}\"),").as_ref()); } else { result.push_str(format!("\n standard_argument(\n name = \"{}\",\n prefix = \"{}\"\n ),", name, value.replace(st_arg_token, "")).as_ref()); } @@ -132,7 +140,7 @@ impl Documentation { if let Some(args) = self.arguments.clone() { args.iter().for_each(|(name, value)| { if !value.contains(st_arg_token) { - result.push_str(format!("\n argument(\n name = \"{}\",\n description = \"{}\"\n ),", name, value).as_ref()); + result.push_str(format!("\n argument(\n name = \"{name}\",\n description = \"{value}\"\n ),").as_ref()); } }); } @@ -140,7 +148,7 @@ impl Documentation { if let Some(alt_syntax) = self.alternative_syntax.clone() { alt_syntax.iter().for_each(|syntax| { result.push_str( - format!("\n alternative_syntax = \"{}\",", syntax).as_ref(), + format!("\n alternative_syntax = \"{syntax}\",").as_ref(), ); }); } @@ -148,8 +156,7 @@ impl Documentation { // Related UDFs if let Some(related_udf) = self.related_udfs.clone() { related_udf.iter().for_each(|udf| { - result - .push_str(format!("\n related_udf(name = \"{}\"),", udf).as_ref()); + result.push_str(format!("\n related_udf(name = \"{udf}\"),").as_ref()); }); } @@ -159,7 +166,7 @@ impl Documentation { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DocSection { /// True to include this doc section in the public /// documentation, false otherwise @@ -213,15 +220,6 @@ pub struct DocumentationBuilder { } impl DocumentationBuilder { - #[allow(clippy::new_without_default)] - #[deprecated( - since = "44.0.0", - note = "please use `DocumentationBuilder::new_with_details` instead" - )] - pub fn new() -> Self { - Self::new_with_details(DocSection::default(), "", "") - } - /// Creates a new [`DocumentationBuilder`] with all required fields pub fn new_with_details( doc_section: DocSection, diff --git a/datafusion/doc/src/udaf.rs b/datafusion/doc/src/udaf.rs new file mode 100644 index 0000000000000..c3a0b4adbcb1e --- /dev/null +++ b/datafusion/doc/src/udaf.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Aggregate UDF doc sections for use in public documentation +pub mod aggregate_doc_sections { + use crate::DocSection; + + pub fn doc_sections() -> Vec { + vec![ + DOC_SECTION_GENERAL, + DOC_SECTION_STATISTICAL, + DOC_SECTION_APPROXIMATE, + ] + } + + pub const DOC_SECTION_GENERAL: DocSection = DocSection { + include: true, + label: "General Functions", + description: None, + }; + + pub const DOC_SECTION_STATISTICAL: DocSection = DocSection { + include: true, + label: "Statistical Functions", + description: None, + }; + + pub const DOC_SECTION_APPROXIMATE: DocSection = DocSection { + include: true, + label: "Approximate Functions", + description: None, + }; +} diff --git a/datafusion/doc/src/udf.rs b/datafusion/doc/src/udf.rs new file mode 100644 index 0000000000000..3d18c9ac2714e --- /dev/null +++ b/datafusion/doc/src/udf.rs @@ -0,0 +1,132 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Scalar UDF doc sections for use in public documentation +pub mod scalar_doc_sections { + use crate::DocSection; + + pub fn doc_sections() -> Vec { + vec![ + DOC_SECTION_MATH, + DOC_SECTION_CONDITIONAL, + DOC_SECTION_STRING, + DOC_SECTION_BINARY_STRING, + DOC_SECTION_REGEX, + DOC_SECTION_DATETIME, + DOC_SECTION_ARRAY, + DOC_SECTION_STRUCT, + DOC_SECTION_MAP, + DOC_SECTION_HASHING, + DOC_SECTION_UNION, + DOC_SECTION_OTHER, + ] + } + + pub const fn doc_sections_const() -> &'static [DocSection] { + &[ + DOC_SECTION_MATH, + DOC_SECTION_CONDITIONAL, + DOC_SECTION_STRING, + DOC_SECTION_BINARY_STRING, + DOC_SECTION_REGEX, + DOC_SECTION_DATETIME, + DOC_SECTION_ARRAY, + DOC_SECTION_STRUCT, + DOC_SECTION_MAP, + DOC_SECTION_HASHING, + DOC_SECTION_UNION, + DOC_SECTION_OTHER, + ] + } + + pub const DOC_SECTION_MATH: DocSection = DocSection { + include: true, + label: "Math Functions", + description: None, + }; + + pub const DOC_SECTION_CONDITIONAL: DocSection = DocSection { + include: true, + label: "Conditional Functions", + description: None, + }; + + pub const DOC_SECTION_STRING: DocSection = DocSection { + include: true, + label: "String Functions", + description: None, + }; + + pub const DOC_SECTION_BINARY_STRING: DocSection = DocSection { + include: true, + label: "Binary String Functions", + description: None, + }; + + pub const DOC_SECTION_REGEX: DocSection = DocSection { + include: true, + label: "Regular Expression Functions", + description: Some( + r#"Apache DataFusion uses a [PCRE-like](https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions) +regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax) +(minus support for several features including look-around and backreferences). +The following regular expression functions are supported:"#, + ), + }; + + pub const DOC_SECTION_DATETIME: DocSection = DocSection { + include: true, + label: "Time and Date Functions", + description: None, + }; + + pub const DOC_SECTION_ARRAY: DocSection = DocSection { + include: true, + label: "Array Functions", + description: None, + }; + + pub const DOC_SECTION_STRUCT: DocSection = DocSection { + include: true, + label: "Struct Functions", + description: None, + }; + + pub const DOC_SECTION_MAP: DocSection = DocSection { + include: true, + label: "Map Functions", + description: None, + }; + + pub const DOC_SECTION_HASHING: DocSection = DocSection { + include: true, + label: "Hashing Functions", + description: None, + }; + + pub const DOC_SECTION_OTHER: DocSection = DocSection { + include: true, + label: "Other Functions", + description: None, + }; + + pub const DOC_SECTION_UNION: DocSection = DocSection { + include: true, + label: "Union Functions", + description: Some("Functions to work with the union data type, also know as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator"), + }; +} diff --git a/datafusion/doc/src/udwf.rs b/datafusion/doc/src/udwf.rs new file mode 100644 index 0000000000000..0257ce5ba66b5 --- /dev/null +++ b/datafusion/doc/src/udwf.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Window UDF doc sections for use in public documentation +pub mod window_doc_sections { + use crate::DocSection; + + pub fn doc_sections() -> Vec { + vec![ + DOC_SECTION_AGGREGATE, + DOC_SECTION_RANKING, + DOC_SECTION_ANALYTICAL, + ] + } + + pub const DOC_SECTION_AGGREGATE: DocSection = DocSection { + include: true, + label: "Aggregate Functions", + description: Some("All aggregate functions can be used as window functions."), + }; + + pub const DOC_SECTION_RANKING: DocSection = DocSection { + include: true, + label: "Ranking Functions", + description: None, + }; + + pub const DOC_SECTION_ANALYTICAL: DocSection = DocSection { + include: true, + label: "Analytical Functions", + description: None, + }; +} diff --git a/datafusion/execution/Cargo.toml b/datafusion/execution/Cargo.toml index 8f642f3384d2e..67a37a86c7066 100644 --- a/datafusion/execution/Cargo.toml +++ b/datafusion/execution/Cargo.toml @@ -37,18 +37,29 @@ workspace = true [lib] name = "datafusion_execution" +[features] +default = ["sql"] + +parquet_encryption = [ + "parquet/encryption", +] +sql = [] + [dependencies] arrow = { workspace = true } +async-trait = { workspace = true } dashmap = { workspace = true } -datafusion-common = { workspace = true, default-features = true } -datafusion-expr = { workspace = true } +datafusion-common = { workspace = true, default-features = false } +datafusion-expr = { workspace = true, default-features = false } futures = { workspace = true } log = { workspace = true } -object_store = { workspace = true } +object_store = { workspace = true, features = ["fs"] } parking_lot = { workspace = true } +parquet = { workspace = true, optional = true } rand = { workspace = true } tempfile = { workspace = true } url = { workspace = true } [dev-dependencies] chrono = { workspace = true } +insta = { workspace = true } diff --git a/datafusion/execution/README.md b/datafusion/execution/README.md index 8a03255ee4ad3..5b1528b0daab9 100644 --- a/datafusion/execution/README.md +++ b/datafusion/execution/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion Execution +# Apache DataFusion Execution -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion that provides execution runtime such as the memory pools and disk manager. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/execution/src/cache/cache_manager.rs b/datafusion/execution/src/cache/cache_manager.rs index c2403e34c6657..3e0f4065d13f5 100644 --- a/datafusion/execution/src/cache/cache_manager.rs +++ b/datafusion/execution/src/cache/cache_manager.rs @@ -15,23 +15,94 @@ // specific language governing permissions and limitations // under the License. +use crate::cache::cache_unit::DefaultFilesMetadataCache; use crate::cache::CacheAccessor; use datafusion_common::{Result, Statistics}; use object_store::path::Path; use object_store::ObjectMeta; +use std::any::Any; +use std::collections::HashMap; use std::fmt::{Debug, Formatter}; use std::sync::Arc; -/// The cache of listing files statistics. -/// if set [`CacheManagerConfig::with_files_statistics_cache`] -/// Will avoid infer same file statistics repeatedly during the session lifetime, -/// this cache will store in [`crate::runtime_env::RuntimeEnv`]. +/// A cache for [`Statistics`]. +/// +/// If enabled via [`CacheManagerConfig::with_files_statistics_cache`] this +/// cache avoids inferring the same file statistics repeatedly during the +/// session lifetime. +/// +/// See [`crate::runtime_env::RuntimeEnv`] for more details pub type FileStatisticsCache = Arc, Extra = ObjectMeta>>; +/// Cache for storing the [`ObjectMeta`]s that result from listing a path +/// +/// Listing a path means doing an object store "list" operation or `ls` +/// command on the local filesystem. This operation can be expensive, +/// especially when done over remote object stores. +/// +/// See [`crate::runtime_env::RuntimeEnv`] for more details pub type ListFilesCache = Arc>, Extra = ObjectMeta>>; +/// Generic file-embedded metadata used with [`FileMetadataCache`]. +/// +/// For example, Parquet footers and page metadata can be represented +/// using this trait. +/// +/// See [`crate::runtime_env::RuntimeEnv`] for more details +pub trait FileMetadata: Any + Send + Sync { + /// Returns the file metadata as [`Any`] so that it can be downcast to a specific + /// implementation. + fn as_any(&self) -> &dyn Any; + + /// Returns the size of the metadata in bytes. + fn memory_size(&self) -> usize; + + /// Returns extra information about this entry (used by [`FileMetadataCache::list_entries`]). + fn extra_info(&self) -> HashMap; +} + +/// Cache for file-embedded metadata. +/// +/// This cache stores per-file metadata in the form of [`FileMetadata`], +/// +/// For example, the built in [`ListingTable`] uses this cache to avoid parsing +/// Parquet footers multiple times for the same file. +/// +/// DataFusion provides a default implementation, [`DefaultFilesMetadataCache`], +/// and users can also provide their own implementations to implement custom +/// caching strategies. +/// +/// See [`crate::runtime_env::RuntimeEnv`] for more details. +/// +/// [`ListingTable`]: https://docs.rs/datafusion/latest/datafusion/datasource/listing/struct.ListingTable.html +pub trait FileMetadataCache: + CacheAccessor, Extra = ObjectMeta> +{ + /// Returns the cache's memory limit in bytes. + fn cache_limit(&self) -> usize; + + /// Updates the cache with a new memory limit in bytes. + fn update_cache_limit(&self, limit: usize); + + /// Retrieves the information about the entries currently cached. + fn list_entries(&self) -> HashMap; +} + +#[derive(Debug, Clone, PartialEq, Eq)] +/// Represents information about a cached metadata entry. +/// This is used to expose the metadata cache contents to outside modules. +pub struct FileMetadataCacheEntry { + pub object_meta: ObjectMeta, + /// Size of the cached metadata, in bytes. + pub size_bytes: usize, + /// Number of times this entry was retrieved. + pub hits: usize, + /// Additional object-specific information. + pub extra: HashMap, +} + impl Debug for dyn CacheAccessor, Extra = ObjectMeta> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "Cache name: {} with length: {}", self.name(), self.len()) @@ -44,22 +115,49 @@ impl Debug for dyn CacheAccessor>, Extra = ObjectMeta> } } -#[derive(Default, Debug)] +impl Debug for dyn FileMetadataCache { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Cache name: {} with length: {}", self.name(), self.len()) + } +} + +/// Manages various caches used in DataFusion. +/// +/// Following DataFusion design principles, DataFusion provides default cache +/// implementations, while also allowing users to provide their own custom cache +/// implementations by implementing the relevant traits. +/// +/// See [`CacheManagerConfig`] for configuration options. +#[derive(Debug)] pub struct CacheManager { file_statistic_cache: Option, list_files_cache: Option, + file_metadata_cache: Arc, } impl CacheManager { pub fn try_new(config: &CacheManagerConfig) -> Result> { - let mut manager = CacheManager::default(); - if let Some(cc) = &config.table_files_statistics_cache { - manager.file_statistic_cache = Some(Arc::clone(cc)) - } - if let Some(lc) = &config.list_files_cache { - manager.list_files_cache = Some(Arc::clone(lc)) - } - Ok(Arc::new(manager)) + let file_statistic_cache = + config.table_files_statistics_cache.as_ref().map(Arc::clone); + + let list_files_cache = config.list_files_cache.as_ref().map(Arc::clone); + + let file_metadata_cache = config + .file_metadata_cache + .as_ref() + .map(Arc::clone) + .unwrap_or_else(|| { + Arc::new(DefaultFilesMetadataCache::new(config.metadata_cache_limit)) + }); + + // the cache memory limit might have changed, ensure the limit is updated + file_metadata_cache.update_cache_limit(config.metadata_cache_limit); + + Ok(Arc::new(CacheManager { + file_statistic_cache, + list_files_cache, + file_metadata_cache, + })) } /// Get the cache of listing files statistics. @@ -67,13 +165,25 @@ impl CacheManager { self.file_statistic_cache.clone() } - /// Get the cache of objectMeta under same path. + /// Get the cache for storing the result of listing [`ObjectMeta`]s under the same path. pub fn get_list_files_cache(&self) -> Option { self.list_files_cache.clone() } + + /// Get the file embedded metadata cache. + pub fn get_file_metadata_cache(&self) -> Arc { + Arc::clone(&self.file_metadata_cache) + } + + /// Get the limit of the file embedded metadata cache. + pub fn get_metadata_cache_limit(&self) -> usize { + self.file_metadata_cache.cache_limit() + } } -#[derive(Clone, Default)] +const DEFAULT_METADATA_CACHE_LIMIT: usize = 50 * 1024 * 1024; // 50M + +#[derive(Clone)] pub struct CacheManagerConfig { /// Enable cache of files statistics when listing files. /// Avoid get same file statistics repeatedly in same datafusion session. @@ -86,9 +196,29 @@ pub struct CacheManagerConfig { /// location. /// Default is disable. pub list_files_cache: Option, + /// Cache of file-embedded metadata, used to avoid reading it multiple times when processing a + /// data file (e.g., Parquet footer and page metadata). + /// If not provided, the [`CacheManager`] will create a [`DefaultFilesMetadataCache`]. + pub file_metadata_cache: Option>, + /// Limit of the file-embedded metadata cache, in bytes. + pub metadata_cache_limit: usize, +} + +impl Default for CacheManagerConfig { + fn default() -> Self { + Self { + table_files_statistics_cache: Default::default(), + list_files_cache: Default::default(), + file_metadata_cache: Default::default(), + metadata_cache_limit: DEFAULT_METADATA_CACHE_LIMIT, + } + } } impl CacheManagerConfig { + /// Set the cache for files statistics. + /// + /// Default is `None` (disabled). pub fn with_files_statistics_cache( mut self, cache: Option, @@ -97,8 +227,28 @@ impl CacheManagerConfig { self } + /// Set the cache for listing files. + /// + /// Default is `None` (disabled). pub fn with_list_files_cache(mut self, cache: Option) -> Self { self.list_files_cache = cache; self } + + /// Sets the cache for file-embedded metadata. + /// + /// Default is a [`DefaultFilesMetadataCache`]. + pub fn with_file_metadata_cache( + mut self, + cache: Option>, + ) -> Self { + self.file_metadata_cache = cache; + self + } + + /// Sets the limit of the file-embedded metadata cache, in bytes. + pub fn with_metadata_cache_limit(mut self, limit: usize) -> Self { + self.metadata_cache_limit = limit; + self + } } diff --git a/datafusion/execution/src/cache/cache_unit.rs b/datafusion/execution/src/cache/cache_unit.rs index a9291659a3efa..d27c266b768ad 100644 --- a/datafusion/execution/src/cache/cache_unit.rs +++ b/datafusion/execution/src/cache/cache_unit.rs @@ -15,8 +15,13 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use crate::cache::cache_manager::{ + FileMetadata, FileMetadataCache, FileMetadataCacheEntry, +}; +use crate::cache::lru_queue::LruQueue; use crate::cache::CacheAccessor; use datafusion_common::Statistics; @@ -25,8 +30,13 @@ use dashmap::DashMap; use object_store::path::Path; use object_store::ObjectMeta; -/// Collected statistics for files +/// Default implementation of [`FileStatisticsCache`] +/// +/// Stores collected statistics for files +/// /// Cache is invalided when file size or last modification has changed +/// +/// [`FileStatisticsCache`]: crate::cache::cache_manager::FileStatisticsCache #[derive(Default)] pub struct DefaultFileStatisticsCache { statistics: DashMap)>, @@ -97,8 +107,13 @@ impl CacheAccessor> for DefaultFileStatisticsCache { } } +/// Default implementation of [`ListFilesCache`] +/// /// Collected files metadata for listing files. -/// Cache will not invalided until user call remove or clear. +/// +/// Cache is not invalided until user calls [`Self::remove`] or [`Self::clear`]. +/// +/// [`ListFilesCache`]: crate::cache::cache_manager::ListFilesCache #[derive(Default)] pub struct DefaultListFilesCache { statistics: DashMap>>, @@ -157,9 +172,269 @@ impl CacheAccessor>> for DefaultListFilesCache { } } +/// Handles the inner state of the [`DefaultFilesMetadataCache`] struct. +struct DefaultFilesMetadataCacheState { + lru_queue: LruQueue)>, + memory_limit: usize, + memory_used: usize, + cache_hits: HashMap, +} + +impl DefaultFilesMetadataCacheState { + fn new(memory_limit: usize) -> Self { + Self { + lru_queue: LruQueue::new(), + memory_limit, + memory_used: 0, + cache_hits: HashMap::new(), + } + } + + /// Returns the respective entry from the cache, if it exists and the `size` and `last_modified` + /// properties from [`ObjectMeta`] match. + /// If the entry exists, it becomes the most recently used. + fn get(&mut self, k: &ObjectMeta) -> Option> { + self.lru_queue + .get(&k.location) + .map(|(object_meta, metadata)| { + if object_meta.size != k.size + || object_meta.last_modified != k.last_modified + { + None + } else { + *self.cache_hits.entry(k.location.clone()).or_insert(0) += 1; + Some(Arc::clone(metadata)) + } + }) + .unwrap_or(None) + } + + /// Checks if the metadata is currently cached (entry exists and the `size` and `last_modified` + /// properties of [`ObjectMeta`] match). + /// The LRU queue is not updated. + fn contains_key(&self, k: &ObjectMeta) -> bool { + self.lru_queue + .peek(&k.location) + .map(|(object_meta, _)| { + object_meta.size == k.size && object_meta.last_modified == k.last_modified + }) + .unwrap_or(false) + } + + /// Adds a new key-value pair to cache, meaning LRU entries might be evicted if required. + /// If the key is already in the cache, the previous metadata is returned. + /// If the size of the metadata is greater than the `memory_limit`, the value is not inserted. + fn put( + &mut self, + key: ObjectMeta, + value: Arc, + ) -> Option> { + let value_size = value.memory_size(); + + // no point in trying to add this value to the cache if it cannot fit entirely + if value_size > self.memory_limit { + return None; + } + + self.cache_hits.insert(key.location.clone(), 0); + // if the key is already in the cache, the old value is removed + let old_value = self.lru_queue.put(key.location.clone(), (key, value)); + self.memory_used += value_size; + if let Some((_, ref old_metadata)) = old_value { + self.memory_used -= old_metadata.memory_size(); + } + + self.evict_entries(); + + old_value.map(|v| v.1) + } + + /// Evicts entries from the LRU cache until `memory_used` is lower than `memory_limit`. + fn evict_entries(&mut self) { + while self.memory_used > self.memory_limit { + if let Some(removed) = self.lru_queue.pop() { + let metadata: Arc = removed.1 .1; + self.memory_used -= metadata.memory_size(); + } else { + // cache is empty while memory_used > memory_limit, cannot happen + debug_assert!( + false, + "cache is empty while memory_used > memory_limit, cannot happen" + ); + return; + } + } + } + + /// Removes an entry from the cache and returns it, if it exists. + fn remove(&mut self, k: &ObjectMeta) -> Option> { + if let Some((_, old_metadata)) = self.lru_queue.remove(&k.location) { + self.memory_used -= old_metadata.memory_size(); + self.cache_hits.remove(&k.location); + Some(old_metadata) + } else { + None + } + } + + /// Returns the number of entries currently cached. + fn len(&self) -> usize { + self.lru_queue.len() + } + + /// Removes all entries from the cache. + fn clear(&mut self) { + self.lru_queue.clear(); + self.memory_used = 0; + self.cache_hits.clear(); + } +} + +/// Default implementation of [`FileMetadataCache`] +/// +/// Collected file embedded metadata cache. +/// +/// The metadata for each file is invalidated when the file size or last +/// modification time have been changed. +/// +/// # Internal details +/// +/// The `memory_limit` controls the maximum size of the cache, which uses a +/// Least Recently Used eviction algorithm. When adding a new entry, if the total +/// size of the cached entries exceeds `memory_limit`, the least recently used entries +/// are evicted until the total size is lower than `memory_limit`. +/// +/// # `Extra` Handling +/// +/// Users should use the [`Self::get`] and [`Self::put`] methods. The +/// [`Self::get_with_extra`] and [`Self::put_with_extra`] methods simply call +/// `get` and `put`, respectively. +pub struct DefaultFilesMetadataCache { + // the state is wrapped in a Mutex to ensure the operations are atomic + state: Mutex, +} + +impl DefaultFilesMetadataCache { + /// Create a new instance of [`DefaultFilesMetadataCache`]. + /// + /// # Arguments + /// `memory_limit`: the maximum size of the cache, in bytes + // + pub fn new(memory_limit: usize) -> Self { + Self { + state: Mutex::new(DefaultFilesMetadataCacheState::new(memory_limit)), + } + } + + /// Returns the size of the cached memory, in bytes. + pub fn memory_used(&self) -> usize { + let state = self.state.lock().unwrap(); + state.memory_used + } +} + +impl FileMetadataCache for DefaultFilesMetadataCache { + fn cache_limit(&self) -> usize { + let state = self.state.lock().unwrap(); + state.memory_limit + } + + fn update_cache_limit(&self, limit: usize) { + let mut state = self.state.lock().unwrap(); + state.memory_limit = limit; + state.evict_entries(); + } + + fn list_entries(&self) -> HashMap { + let state = self.state.lock().unwrap(); + let mut entries = HashMap::::new(); + + for (path, (object_meta, metadata)) in state.lru_queue.list_entries() { + entries.insert( + path.clone(), + FileMetadataCacheEntry { + object_meta: object_meta.clone(), + size_bytes: metadata.memory_size(), + hits: *state.cache_hits.get(path).expect("entry must exist"), + extra: metadata.extra_info(), + }, + ); + } + + entries + } +} + +impl CacheAccessor> for DefaultFilesMetadataCache { + type Extra = ObjectMeta; + + fn get(&self, k: &ObjectMeta) -> Option> { + let mut state = self.state.lock().unwrap(); + state.get(k) + } + + fn get_with_extra( + &self, + k: &ObjectMeta, + _e: &Self::Extra, + ) -> Option> { + self.get(k) + } + + fn put( + &self, + key: &ObjectMeta, + value: Arc, + ) -> Option> { + let mut state = self.state.lock().unwrap(); + state.put(key.clone(), value) + } + + fn put_with_extra( + &self, + key: &ObjectMeta, + value: Arc, + _e: &Self::Extra, + ) -> Option> { + self.put(key, value) + } + + fn remove(&mut self, k: &ObjectMeta) -> Option> { + let mut state = self.state.lock().unwrap(); + state.remove(k) + } + + fn contains_key(&self, k: &ObjectMeta) -> bool { + let state = self.state.lock().unwrap(); + state.contains_key(k) + } + + fn len(&self) -> usize { + let state = self.state.lock().unwrap(); + state.len() + } + + fn clear(&self) { + let mut state = self.state.lock().unwrap(); + state.clear(); + } + + fn name(&self) -> String { + "DefaultFilesMetadataCache".to_string() + } +} + #[cfg(test)] mod tests { - use crate::cache::cache_unit::{DefaultFileStatisticsCache, DefaultListFilesCache}; + use std::collections::HashMap; + use std::sync::Arc; + + use crate::cache::cache_manager::{ + FileMetadata, FileMetadataCache, FileMetadataCacheEntry, + }; + use crate::cache::cache_unit::{ + DefaultFileStatisticsCache, DefaultFilesMetadataCache, DefaultListFilesCache, + }; use crate::cache::CacheAccessor; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use chrono::DateTime; @@ -232,4 +507,444 @@ mod tests { meta.clone() ); } + + pub struct TestFileMetadata { + metadata: String, + } + + impl FileMetadata for TestFileMetadata { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn memory_size(&self) -> usize { + self.metadata.len() + } + + fn extra_info(&self) -> HashMap { + HashMap::from([("extra_info".to_owned(), "abc".to_owned())]) + } + } + + #[test] + fn test_default_file_metadata_cache() { + let object_meta = ObjectMeta { + location: Path::from("test"), + last_modified: DateTime::parse_from_rfc3339("2025-07-29T12:12:12+00:00") + .unwrap() + .into(), + size: 1024, + e_tag: None, + version: None, + }; + + let metadata: Arc = Arc::new(TestFileMetadata { + metadata: "retrieved_metadata".to_owned(), + }); + + let mut cache = DefaultFilesMetadataCache::new(1024 * 1024); + assert!(cache.get(&object_meta).is_none()); + + // put + cache.put(&object_meta, Arc::clone(&metadata)); + + // get and contains of a valid entry + assert!(cache.contains_key(&object_meta)); + let value = cache.get(&object_meta); + assert!(value.is_some()); + let test_file_metadata = Arc::downcast::(value.unwrap()); + assert!(test_file_metadata.is_ok()); + assert_eq!(test_file_metadata.unwrap().metadata, "retrieved_metadata"); + + // file size changed + let mut object_meta2 = object_meta.clone(); + object_meta2.size = 2048; + assert!(cache.get(&object_meta2).is_none()); + assert!(!cache.contains_key(&object_meta2)); + + // file last_modified changed + let mut object_meta2 = object_meta.clone(); + object_meta2.last_modified = + DateTime::parse_from_rfc3339("2025-07-29T13:13:13+00:00") + .unwrap() + .into(); + assert!(cache.get(&object_meta2).is_none()); + assert!(!cache.contains_key(&object_meta2)); + + // different file + let mut object_meta2 = object_meta.clone(); + object_meta2.location = Path::from("test2"); + assert!(cache.get(&object_meta2).is_none()); + assert!(!cache.contains_key(&object_meta2)); + + // remove + cache.remove(&object_meta); + assert!(cache.get(&object_meta).is_none()); + assert!(!cache.contains_key(&object_meta)); + + // len and clear + cache.put(&object_meta, Arc::clone(&metadata)); + cache.put(&object_meta2, metadata); + assert_eq!(cache.len(), 2); + cache.clear(); + assert_eq!(cache.len(), 0); + } + + fn generate_test_metadata_with_size( + path: &str, + size: usize, + ) -> (ObjectMeta, Arc) { + let object_meta = ObjectMeta { + location: Path::from(path), + last_modified: chrono::Utc::now(), + size: size as u64, + e_tag: None, + version: None, + }; + let metadata: Arc = Arc::new(TestFileMetadata { + metadata: "a".repeat(size), + }); + + (object_meta, metadata) + } + + #[test] + fn test_default_file_metadata_cache_with_limit() { + let mut cache = DefaultFilesMetadataCache::new(1000); + let (object_meta1, metadata1) = generate_test_metadata_with_size("1", 100); + let (object_meta2, metadata2) = generate_test_metadata_with_size("2", 500); + let (object_meta3, metadata3) = generate_test_metadata_with_size("3", 300); + + cache.put(&object_meta1, metadata1); + cache.put(&object_meta2, metadata2); + cache.put(&object_meta3, metadata3); + + // all entries will fit + assert_eq!(cache.len(), 3); + assert_eq!(cache.memory_used(), 900); + assert!(cache.contains_key(&object_meta1)); + assert!(cache.contains_key(&object_meta2)); + assert!(cache.contains_key(&object_meta3)); + + // add a new entry which will remove the least recently used ("1") + let (object_meta4, metadata4) = generate_test_metadata_with_size("4", 200); + cache.put(&object_meta4, metadata4); + assert_eq!(cache.len(), 3); + assert_eq!(cache.memory_used(), 1000); + assert!(!cache.contains_key(&object_meta1)); + assert!(cache.contains_key(&object_meta4)); + + // get entry "2", which will move it to the top of the queue, and add a new one which will + // remove the new least recently used ("3") + cache.get(&object_meta2); + let (object_meta5, metadata5) = generate_test_metadata_with_size("5", 100); + cache.put(&object_meta5, metadata5); + assert_eq!(cache.len(), 3); + assert_eq!(cache.memory_used(), 800); + assert!(!cache.contains_key(&object_meta3)); + assert!(cache.contains_key(&object_meta5)); + + // new entry which will not be able to fit in the 1000 bytes allocated + let (object_meta6, metadata6) = generate_test_metadata_with_size("6", 1200); + cache.put(&object_meta6, metadata6); + assert_eq!(cache.len(), 3); + assert_eq!(cache.memory_used(), 800); + assert!(!cache.contains_key(&object_meta6)); + + // new entry which is able to fit without removing any entry + let (object_meta7, metadata7) = generate_test_metadata_with_size("7", 200); + cache.put(&object_meta7, metadata7); + assert_eq!(cache.len(), 4); + assert_eq!(cache.memory_used(), 1000); + assert!(cache.contains_key(&object_meta7)); + + // new entry which will remove all other entries + let (object_meta8, metadata8) = generate_test_metadata_with_size("8", 999); + cache.put(&object_meta8, metadata8); + assert_eq!(cache.len(), 1); + assert_eq!(cache.memory_used(), 999); + assert!(cache.contains_key(&object_meta8)); + + // when updating an entry, the previous ones are not unnecessarily removed + let (object_meta9, metadata9) = generate_test_metadata_with_size("9", 300); + let (object_meta10, metadata10) = generate_test_metadata_with_size("10", 200); + let (object_meta11_v1, metadata11_v1) = + generate_test_metadata_with_size("11", 400); + cache.put(&object_meta9, metadata9); + cache.put(&object_meta10, metadata10); + cache.put(&object_meta11_v1, metadata11_v1); + assert_eq!(cache.memory_used(), 900); + assert_eq!(cache.len(), 3); + let (object_meta11_v2, metadata11_v2) = + generate_test_metadata_with_size("11", 500); + cache.put(&object_meta11_v2, metadata11_v2); + assert_eq!(cache.memory_used(), 1000); + assert_eq!(cache.len(), 3); + assert!(cache.contains_key(&object_meta9)); + assert!(cache.contains_key(&object_meta10)); + assert!(cache.contains_key(&object_meta11_v2)); + assert!(!cache.contains_key(&object_meta11_v1)); + + // when updating an entry that now exceeds the limit, the LRU ("9") needs to be removed + let (object_meta11_v3, metadata11_v3) = + generate_test_metadata_with_size("11", 501); + cache.put(&object_meta11_v3, metadata11_v3); + assert_eq!(cache.memory_used(), 701); + assert_eq!(cache.len(), 2); + assert!(cache.contains_key(&object_meta10)); + assert!(cache.contains_key(&object_meta11_v3)); + assert!(!cache.contains_key(&object_meta11_v2)); + + // manually removing an entry that is not the LRU + cache.remove(&object_meta11_v3); + assert_eq!(cache.len(), 1); + assert_eq!(cache.memory_used(), 200); + assert!(cache.contains_key(&object_meta10)); + assert!(!cache.contains_key(&object_meta11_v3)); + + // clear + cache.clear(); + assert_eq!(cache.len(), 0); + assert_eq!(cache.memory_used(), 0); + + // resizing the cache should clear the extra entries + let (object_meta12, metadata12) = generate_test_metadata_with_size("12", 300); + let (object_meta13, metadata13) = generate_test_metadata_with_size("13", 200); + let (object_meta14, metadata14) = generate_test_metadata_with_size("14", 500); + cache.put(&object_meta12, metadata12); + cache.put(&object_meta13, metadata13); + cache.put(&object_meta14, metadata14); + assert_eq!(cache.len(), 3); + assert_eq!(cache.memory_used(), 1000); + cache.update_cache_limit(600); + assert_eq!(cache.len(), 1); + assert_eq!(cache.memory_used(), 500); + assert!(!cache.contains_key(&object_meta12)); + assert!(!cache.contains_key(&object_meta13)); + assert!(cache.contains_key(&object_meta14)); + } + + #[test] + fn test_default_file_metadata_cache_entries_info() { + let mut cache = DefaultFilesMetadataCache::new(1000); + let (object_meta1, metadata1) = generate_test_metadata_with_size("1", 100); + let (object_meta2, metadata2) = generate_test_metadata_with_size("2", 200); + let (object_meta3, metadata3) = generate_test_metadata_with_size("3", 300); + + // initial entries, all will have hits = 0 + cache.put(&object_meta1, metadata1); + cache.put(&object_meta2, metadata2); + cache.put(&object_meta3, metadata3); + assert_eq!( + cache.list_entries(), + HashMap::from([ + ( + Path::from("1"), + FileMetadataCacheEntry { + object_meta: object_meta1.clone(), + size_bytes: 100, + hits: 0, + extra: HashMap::from([( + "extra_info".to_owned(), + "abc".to_owned() + )]), + } + ), + ( + Path::from("2"), + FileMetadataCacheEntry { + object_meta: object_meta2.clone(), + size_bytes: 200, + hits: 0, + extra: HashMap::from([( + "extra_info".to_owned(), + "abc".to_owned() + )]), + } + ), + ( + Path::from("3"), + FileMetadataCacheEntry { + object_meta: object_meta3.clone(), + size_bytes: 300, + hits: 0, + extra: HashMap::from([( + "extra_info".to_owned(), + "abc".to_owned() + )]), + } + ) + ]) + ); + + // new hit on "1" + cache.get(&object_meta1); + assert_eq!( + cache.list_entries(), + HashMap::from([ + ( + Path::from("1"), + FileMetadataCacheEntry { + object_meta: object_meta1.clone(), + size_bytes: 100, + hits: 1, + extra: HashMap::from([( + "extra_info".to_owned(), + "abc".to_owned() + )]), + } + ), + ( + Path::from("2"), + FileMetadataCacheEntry { + object_meta: object_meta2.clone(), + size_bytes: 200, + hits: 0, + extra: HashMap::from([( + "extra_info".to_owned(), + "abc".to_owned() + )]), + } + ), + ( + Path::from("3"), + FileMetadataCacheEntry { + object_meta: object_meta3.clone(), + size_bytes: 300, + hits: 0, + extra: HashMap::from([( + "extra_info".to_owned(), + "abc".to_owned() + )]), + } + ) + ]) + ); + + // new entry, will evict "2" + let (object_meta4, metadata4) = generate_test_metadata_with_size("4", 600); + cache.put(&object_meta4, metadata4); + assert_eq!( + cache.list_entries(), + HashMap::from([ + ( + Path::from("1"), + FileMetadataCacheEntry { + object_meta: object_meta1.clone(), + size_bytes: 100, + hits: 1, + extra: HashMap::from([( + "extra_info".to_owned(), + "abc".to_owned() + )]), + } + ), + ( + Path::from("3"), + FileMetadataCacheEntry { + object_meta: object_meta3.clone(), + size_bytes: 300, + hits: 0, + extra: HashMap::from([( + "extra_info".to_owned(), + "abc".to_owned() + )]), + } + ), + ( + Path::from("4"), + FileMetadataCacheEntry { + object_meta: object_meta4.clone(), + size_bytes: 600, + hits: 0, + extra: HashMap::from([( + "extra_info".to_owned(), + "abc".to_owned() + )]), + } + ) + ]) + ); + + // replace entry "1" + let (object_meta1_new, metadata1_new) = generate_test_metadata_with_size("1", 50); + cache.put(&object_meta1_new, metadata1_new); + assert_eq!( + cache.list_entries(), + HashMap::from([ + ( + Path::from("1"), + FileMetadataCacheEntry { + object_meta: object_meta1_new.clone(), + size_bytes: 50, + hits: 0, + extra: HashMap::from([( + "extra_info".to_owned(), + "abc".to_owned() + )]), + } + ), + ( + Path::from("3"), + FileMetadataCacheEntry { + object_meta: object_meta3.clone(), + size_bytes: 300, + hits: 0, + extra: HashMap::from([( + "extra_info".to_owned(), + "abc".to_owned() + )]), + } + ), + ( + Path::from("4"), + FileMetadataCacheEntry { + object_meta: object_meta4.clone(), + size_bytes: 600, + hits: 0, + extra: HashMap::from([( + "extra_info".to_owned(), + "abc".to_owned() + )]), + } + ) + ]) + ); + + // remove entry "4" + cache.remove(&object_meta4); + assert_eq!( + cache.list_entries(), + HashMap::from([ + ( + Path::from("1"), + FileMetadataCacheEntry { + object_meta: object_meta1_new.clone(), + size_bytes: 50, + hits: 0, + extra: HashMap::from([( + "extra_info".to_owned(), + "abc".to_owned() + )]), + } + ), + ( + Path::from("3"), + FileMetadataCacheEntry { + object_meta: object_meta3.clone(), + size_bytes: 300, + hits: 0, + extra: HashMap::from([( + "extra_info".to_owned(), + "abc".to_owned() + )]), + } + ) + ]) + ); + + // clear + cache.clear(); + assert_eq!(cache.list_entries(), HashMap::from([])); + } } diff --git a/datafusion/execution/src/cache/lru_queue.rs b/datafusion/execution/src/cache/lru_queue.rs new file mode 100644 index 0000000000000..fb3d158ced425 --- /dev/null +++ b/datafusion/execution/src/cache/lru_queue.rs @@ -0,0 +1,542 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + collections::HashMap, + hash::Hash, + sync::{Arc, Weak}, +}; + +use parking_lot::Mutex; + +#[derive(Default)] +/// Provides a Least Recently Used queue with unbounded capacity. +/// +/// # Examples +/// +/// ``` +/// use datafusion_execution::cache::lru_queue::LruQueue; +/// +/// let mut lru_queue: LruQueue = LruQueue::new(); +/// lru_queue.put(1, 10); +/// lru_queue.put(2, 20); +/// lru_queue.put(3, 30); +/// assert_eq!(lru_queue.get(&2), Some(&20)); +/// assert_eq!(lru_queue.pop(), Some((1, 10))); +/// assert_eq!(lru_queue.pop(), Some((3, 30))); +/// assert_eq!(lru_queue.pop(), Some((2, 20))); +/// assert_eq!(lru_queue.pop(), None); +/// ``` +pub struct LruQueue { + data: LruData, + queue: LruList, +} + +/// Maps the key to the [`LruNode`] in queue and the value. +type LruData = HashMap>>, V)>; + +#[derive(Default)] +/// Doubly-linked list that maintains the LRU order +struct LruList { + head: Link, + tail: Link, +} + +/// Doubly-linked list node. +struct LruNode { + key: K, + prev: Link, + next: Link, +} + +/// Weak pointer to a [`LruNode`], used to connect nodes in the doubly-linked list. +/// The strong reference is guaranteed to be stored in the `data` map of the [`LruQueue`]. +type Link = Option>>>; + +impl LruQueue { + pub fn new() -> Self { + Self { + data: HashMap::new(), + queue: LruList { + head: None, + tail: None, + }, + } + } + + /// Returns a reference to value mapped by `key`, if it exists. + /// If the entry exists, it becomes the most recently used. + pub fn get(&mut self, key: &K) -> Option<&V> { + if let Some(value) = self.remove(key) { + self.put(key.clone(), value); + } + self.data.get(key).map(|(_, value)| value) + } + + /// Returns a reference to value mapped by `key`, if it exists. + /// Does not affect the queue order. + pub fn peek(&self, key: &K) -> Option<&V> { + self.data.get(key).map(|(_, value)| value) + } + + /// Checks whether there is an entry with key `key` in the queue. + /// Does not affect the queue order. + pub fn contains_key(&self, key: &K) -> bool { + self.data.contains_key(key) + } + + /// Inserts an entry in the queue, becoming the most recently used. + /// If the entry already exists, returns the previous value. + pub fn put(&mut self, key: K, value: V) -> Option { + let old_value = self.remove(&key); + + let node = Arc::new(Mutex::new(LruNode { + key: key.clone(), + prev: None, + next: None, + })); + + match self.queue.head { + // queue is not empty + Some(ref old_head) => { + old_head + .upgrade() + .expect("value has been unexpectedly dropped") + .lock() + .prev = Some(Arc::downgrade(&node)); + node.lock().next = Some(Weak::clone(old_head)); + self.queue.head = Some(Arc::downgrade(&node)); + } + // queue is empty + _ => { + self.queue.head = Some(Arc::downgrade(&node)); + self.queue.tail = Some(Arc::downgrade(&node)); + } + } + + self.data.insert(key, (node, value)); + + old_value + } + + /// Removes and returns the least recently used value. + /// Returns `None` if the queue is empty. + pub fn pop(&mut self) -> Option<(K, V)> { + let key_to_remove = self.queue.tail.as_ref().map(|n| { + n.upgrade() + .expect("value has been unexpectedly dropped") + .lock() + .key + .clone() + }); + if let Some(k) = key_to_remove { + let value = self.remove(&k).unwrap(); // confirmed above that the entry exists + Some((k, value)) + } else { + None + } + } + + /// Removes a specific entry from the queue, if it exists. + pub fn remove(&mut self, key: &K) -> Option { + if let Some((old_node, old_value)) = self.data.remove(key) { + let LruNode { key: _, prev, next } = &*old_node.lock(); + match (prev, next) { + // single node in the queue + (None, None) => { + self.queue.head = None; + self.queue.tail = None; + } + // removed the head node + (None, Some(n)) => { + let n_strong = + n.upgrade().expect("value has been unexpectedly dropped"); + n_strong.lock().prev = None; + self.queue.head = Some(Weak::clone(n)); + } + // removed the tail node + (Some(p), None) => { + let p_strong = + p.upgrade().expect("value has been unexpectedly dropped"); + p_strong.lock().next = None; + self.queue.tail = Some(Weak::clone(p)); + } + // removed a middle node + (Some(p), Some(n)) => { + let n_strong = + n.upgrade().expect("value has been unexpectedly dropped"); + let p_strong = + p.upgrade().expect("value has been unexpectedly dropped"); + n_strong.lock().prev = Some(Weak::clone(p)); + p_strong.lock().next = Some(Weak::clone(n)); + } + }; + Some(old_value) + } else { + None + } + } + + /// Returns the number of entries in the queue. + pub fn len(&self) -> usize { + self.data.len() + } + + /// Checks whether the queue has no items. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Removes all entries from the queue. + pub fn clear(&mut self) { + self.queue.head = None; + self.queue.tail = None; + self.data.clear(); + } + + /// Returns a reference to the entries currently in the queue. + pub fn list_entries(&self) -> HashMap<&K, &V> { + self.data.iter().map(|(k, (_, v))| (k, v)).collect() + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use rand::seq::IndexedRandom; + + use crate::cache::lru_queue::LruQueue; + + #[test] + fn test_get() { + let mut lru_queue: LruQueue = LruQueue::new(); + + // value does not exist + assert_eq!(lru_queue.get(&1), None); + + // value exists + lru_queue.put(1, 10); + assert_eq!(lru_queue.get(&1), Some(&10)); + assert_eq!(lru_queue.get(&1), Some(&10)); + + // value is removed + lru_queue.remove(&1); + assert_eq!(lru_queue.get(&1), None); + } + + #[test] + fn test_peek() { + let mut lru_queue: LruQueue = LruQueue::new(); + + // value does not exist + assert_eq!(lru_queue.peek(&1), None); + + // value exists + lru_queue.put(1, 10); + assert_eq!(lru_queue.peek(&1), Some(&10)); + assert_eq!(lru_queue.peek(&1), Some(&10)); + + // value is removed + lru_queue.remove(&1); + assert_eq!(lru_queue.peek(&1), None); + } + + #[test] + fn test_put() { + let mut lru_queue: LruQueue = LruQueue::new(); + + // no previous value + assert_eq!(lru_queue.put(1, 10), None); + + // update, the previous value is returned + assert_eq!(lru_queue.put(1, 11), Some(10)); + assert_eq!(lru_queue.put(1, 12), Some(11)); + assert_eq!(lru_queue.put(1, 13), Some(12)); + } + + #[test] + fn test_remove() { + let mut lru_queue: LruQueue = LruQueue::new(); + + // value does not exist + assert_eq!(lru_queue.remove(&1), None); + + // value exists and is returned + lru_queue.put(1, 10); + assert_eq!(lru_queue.remove(&1), Some(10)); + + // value does not exist + assert_eq!(lru_queue.remove(&1), None); + } + + #[test] + fn test_contains_key() { + let mut lru_queue: LruQueue = LruQueue::new(); + + // value does not exist + assert!(!lru_queue.contains_key(&1)); + + // value exists + lru_queue.put(1, 10); + assert!(lru_queue.contains_key(&1)); + + // value is removed + lru_queue.remove(&1); + assert!(!lru_queue.contains_key(&1)); + } + + #[test] + fn test_len() { + let mut lru_queue: LruQueue = LruQueue::new(); + + // empty + assert_eq!(lru_queue.len(), 0); + + // puts + lru_queue.put(1, 10); + assert_eq!(lru_queue.len(), 1); + lru_queue.put(2, 20); + assert_eq!(lru_queue.len(), 2); + lru_queue.put(3, 30); + assert_eq!(lru_queue.len(), 3); + lru_queue.put(1, 11); + lru_queue.put(3, 31); + assert_eq!(lru_queue.len(), 3); + + // removes + lru_queue.remove(&1); + assert_eq!(lru_queue.len(), 2); + lru_queue.remove(&1); + assert_eq!(lru_queue.len(), 2); + lru_queue.remove(&4); + assert_eq!(lru_queue.len(), 2); + lru_queue.remove(&3); + assert_eq!(lru_queue.len(), 1); + lru_queue.remove(&2); + assert_eq!(lru_queue.len(), 0); + lru_queue.remove(&2); + assert_eq!(lru_queue.len(), 0); + + // clear + lru_queue.put(1, 10); + lru_queue.put(2, 20); + lru_queue.put(3, 30); + assert_eq!(lru_queue.len(), 3); + lru_queue.clear(); + assert_eq!(lru_queue.len(), 0); + } + + #[test] + fn test_is_empty() { + let mut lru_queue: LruQueue = LruQueue::new(); + + // empty + assert!(lru_queue.is_empty()); + + // puts + lru_queue.put(1, 10); + assert!(!lru_queue.is_empty()); + lru_queue.put(2, 20); + assert!(!lru_queue.is_empty()); + + // removes + lru_queue.remove(&1); + assert!(!lru_queue.is_empty()); + lru_queue.remove(&1); + assert!(!lru_queue.is_empty()); + lru_queue.remove(&2); + assert!(lru_queue.is_empty()); + + // clear + lru_queue.put(1, 10); + lru_queue.put(2, 20); + lru_queue.put(3, 30); + assert!(!lru_queue.is_empty()); + lru_queue.clear(); + assert!(lru_queue.is_empty()); + } + + #[test] + fn test_clear() { + let mut lru_queue: LruQueue = LruQueue::new(); + + // empty + lru_queue.clear(); + + // filled + lru_queue.put(1, 10); + lru_queue.put(2, 20); + lru_queue.put(3, 30); + assert_eq!(lru_queue.get(&1), Some(&10)); + assert_eq!(lru_queue.get(&2), Some(&20)); + assert_eq!(lru_queue.get(&3), Some(&30)); + lru_queue.clear(); + assert_eq!(lru_queue.get(&1), None); + assert_eq!(lru_queue.get(&2), None); + assert_eq!(lru_queue.get(&3), None); + assert_eq!(lru_queue.len(), 0); + } + + #[test] + fn test_pop() { + let mut lru_queue: LruQueue = LruQueue::new(); + + // empty queue + assert_eq!(lru_queue.pop(), None); + + // simplest case + lru_queue.put(1, 10); + lru_queue.put(2, 20); + lru_queue.put(3, 30); + assert_eq!(lru_queue.pop(), Some((1, 10))); + assert_eq!(lru_queue.pop(), Some((2, 20))); + assert_eq!(lru_queue.pop(), Some((3, 30))); + assert_eq!(lru_queue.pop(), None); + + // 'get' changes the order + lru_queue.put(1, 10); + lru_queue.put(2, 20); + lru_queue.put(3, 30); + lru_queue.get(&2); + assert_eq!(lru_queue.pop(), Some((1, 10))); + assert_eq!(lru_queue.pop(), Some((3, 30))); + assert_eq!(lru_queue.pop(), Some((2, 20))); + assert_eq!(lru_queue.pop(), None); + + // multiple 'gets' + lru_queue.put(1, 10); + lru_queue.put(2, 20); + lru_queue.put(3, 30); + lru_queue.get(&2); + lru_queue.get(&3); + lru_queue.get(&1); + assert_eq!(lru_queue.pop(), Some((2, 20))); + assert_eq!(lru_queue.pop(), Some((3, 30))); + assert_eq!(lru_queue.pop(), Some((1, 10))); + assert_eq!(lru_queue.pop(), None); + + // 'peak' does not change the order + lru_queue.put(1, 10); + lru_queue.put(2, 20); + lru_queue.put(3, 30); + lru_queue.peek(&2); + assert_eq!(lru_queue.pop(), Some((1, 10))); + assert_eq!(lru_queue.pop(), Some((2, 20))); + assert_eq!(lru_queue.pop(), Some((3, 30))); + assert_eq!(lru_queue.pop(), None); + + // 'contains' does not change the order + lru_queue.put(1, 10); + lru_queue.put(2, 20); + lru_queue.put(3, 30); + lru_queue.contains_key(&2); + assert_eq!(lru_queue.pop(), Some((1, 10))); + assert_eq!(lru_queue.pop(), Some((2, 20))); + assert_eq!(lru_queue.pop(), Some((3, 30))); + assert_eq!(lru_queue.pop(), None); + + // 'put' on the same key promotes it + lru_queue.put(1, 10); + lru_queue.put(2, 20); + lru_queue.put(3, 30); + lru_queue.put(2, 21); + assert_eq!(lru_queue.pop(), Some((1, 10))); + assert_eq!(lru_queue.pop(), Some((3, 30))); + assert_eq!(lru_queue.pop(), Some((2, 21))); + assert_eq!(lru_queue.pop(), None); + + // multiple 'puts' + lru_queue.put(1, 10); + lru_queue.put(2, 20); + lru_queue.put(3, 30); + lru_queue.put(2, 21); + lru_queue.put(3, 31); + lru_queue.put(1, 11); + assert_eq!(lru_queue.pop(), Some((2, 21))); + assert_eq!(lru_queue.pop(), Some((3, 31))); + assert_eq!(lru_queue.pop(), Some((1, 11))); + assert_eq!(lru_queue.pop(), None); + + // 'remove' an element in the middle of the queue + lru_queue.put(1, 10); + lru_queue.put(2, 20); + lru_queue.put(3, 30); + lru_queue.remove(&2); + assert_eq!(lru_queue.pop(), Some((1, 10))); + assert_eq!(lru_queue.pop(), Some((3, 30))); + assert_eq!(lru_queue.pop(), None); + + // 'remove' the LRU + lru_queue.put(1, 10); + lru_queue.put(2, 20); + lru_queue.put(3, 30); + lru_queue.remove(&1); + assert_eq!(lru_queue.pop(), Some((2, 20))); + assert_eq!(lru_queue.pop(), Some((3, 30))); + assert_eq!(lru_queue.pop(), None); + + // 'remove' the MRU + lru_queue.put(1, 10); + lru_queue.put(2, 20); + lru_queue.put(3, 30); + lru_queue.remove(&3); + assert_eq!(lru_queue.pop(), Some((1, 10))); + assert_eq!(lru_queue.pop(), Some((2, 20))); + assert_eq!(lru_queue.pop(), None); + } + + #[test] + /// Fuzzy test using an hashmap as the base to check the methods. + fn test_fuzzy() { + let mut lru_queue: LruQueue = LruQueue::new(); + let mut map: HashMap = HashMap::new(); + let max_keys = 1_000; + let methods = ["get", "put", "remove", "pop", "contains", "len"]; + let mut rng = rand::rng(); + + for i in 0..1_000_000 { + match *methods.choose(&mut rng).unwrap() { + "get" => { + assert_eq!(lru_queue.get(&(i % max_keys)), map.get(&(i % max_keys))) + } + "put" => assert_eq!( + lru_queue.put(i % max_keys, i), + map.insert(i % max_keys, i) + ), + "remove" => assert_eq!( + lru_queue.remove(&(i % max_keys)), + map.remove(&(i % max_keys)) + ), + "pop" => { + let removed = lru_queue.pop(); + if let Some((k, v)) = removed { + assert_eq!(Some(v), map.remove(&k)) + } + } + "contains" => { + assert_eq!( + lru_queue.contains_key(&(i % max_keys)), + map.contains_key(&(i % max_keys)) + ) + } + "len" => assert_eq!(lru_queue.len(), map.len()), + _ => unreachable!(), + } + } + } +} diff --git a/datafusion/execution/src/cache/mod.rs b/datafusion/execution/src/cache/mod.rs index 4271bebd0b326..b1857c94facdf 100644 --- a/datafusion/execution/src/cache/mod.rs +++ b/datafusion/execution/src/cache/mod.rs @@ -17,6 +17,7 @@ pub mod cache_manager; pub mod cache_unit; +pub mod lru_queue; /// The cache accessor, users usually working on this interface while manipulating caches. /// This interface does not get `mut` references and thus has to handle its own diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 53646dc5b468e..491b1aca69ea1 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -23,7 +23,7 @@ use std::{ }; use datafusion_common::{ - config::{ConfigExtension, ConfigOptions}, + config::{ConfigExtension, ConfigOptions, SpillCompression}, Result, ScalarValue, }; @@ -91,8 +91,11 @@ use datafusion_common::{ /// [`SessionContext::new_with_config`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.new_with_config #[derive(Clone, Debug)] pub struct SessionConfig { - /// Configuration options - options: ConfigOptions, + /// Configuration options for the current session. + /// + /// A new copy is created on write, if there are other outstanding + /// references to the same options. + options: Arc, /// Opaque extensions. extensions: AnyMap, } @@ -100,7 +103,7 @@ pub struct SessionConfig { impl Default for SessionConfig { fn default() -> Self { Self { - options: ConfigOptions::new(), + options: Arc::new(ConfigOptions::new()), // Assume no extensions by default. extensions: HashMap::with_capacity_and_hasher( 0, @@ -117,6 +120,9 @@ impl SessionConfig { } /// Create an execution config with config options read from the environment + /// + /// See [`ConfigOptions::from_env`] for details on how environment variables + /// are mapped to config options. pub fn from_env() -> Result { Ok(ConfigOptions::from_env()?.into()) } @@ -136,7 +142,7 @@ impl SessionConfig { /// let config = SessionConfig::new(); /// assert!(config.options().execution.batch_size > 0); /// ``` - pub fn options(&self) -> &ConfigOptions { + pub fn options(&self) -> &Arc { &self.options } @@ -152,7 +158,7 @@ impl SessionConfig { /// assert_eq!(config.options().execution.batch_size, 1024); /// ``` pub fn options_mut(&mut self) -> &mut ConfigOptions { - &mut self.options + Arc::make_mut(&mut self.options) } /// Set a configuration option @@ -177,7 +183,7 @@ impl SessionConfig { /// Set a generic `str` configuration option pub fn set_str(mut self, key: &str, value: &str) -> Self { - self.options.set(key, value).unwrap(); + self.options_mut().set(key, value).unwrap(); self } @@ -185,7 +191,7 @@ impl SessionConfig { pub fn with_batch_size(mut self, n: usize) -> Self { // batch size must be greater than zero assert!(n > 0); - self.options.execution.batch_size = n; + self.options_mut().execution.batch_size = n; self } @@ -193,9 +199,11 @@ impl SessionConfig { /// /// [`target_partitions`]: datafusion_common::config::ExecutionOptions::target_partitions pub fn with_target_partitions(mut self, n: usize) -> Self { - // partition count must be greater than zero - assert!(n > 0); - self.options.execution.target_partitions = n; + self.options_mut().execution.target_partitions = if n == 0 { + datafusion_common::config::ExecutionOptions::default().target_partitions + } else { + n + }; self } @@ -256,68 +264,75 @@ impl SessionConfig { self.options.execution.collect_statistics } + /// Compression codec for spill file + pub fn spill_compression(&self) -> SpillCompression { + self.options.execution.spill_compression + } + /// Selects a name for the default catalog and schema pub fn with_default_catalog_and_schema( mut self, catalog: impl Into, schema: impl Into, ) -> Self { - self.options.catalog.default_catalog = catalog.into(); - self.options.catalog.default_schema = schema.into(); + self.options_mut().catalog.default_catalog = catalog.into(); + self.options_mut().catalog.default_schema = schema.into(); self } /// Controls whether the default catalog and schema will be automatically created pub fn with_create_default_catalog_and_schema(mut self, create: bool) -> Self { - self.options.catalog.create_default_catalog_and_schema = create; + self.options_mut().catalog.create_default_catalog_and_schema = create; self } /// Enables or disables the inclusion of `information_schema` virtual tables pub fn with_information_schema(mut self, enabled: bool) -> Self { - self.options.catalog.information_schema = enabled; + self.options_mut().catalog.information_schema = enabled; self } /// Enables or disables the use of repartitioning for joins to improve parallelism pub fn with_repartition_joins(mut self, enabled: bool) -> Self { - self.options.optimizer.repartition_joins = enabled; + self.options_mut().optimizer.repartition_joins = enabled; self } /// Enables or disables the use of repartitioning for aggregations to improve parallelism pub fn with_repartition_aggregations(mut self, enabled: bool) -> Self { - self.options.optimizer.repartition_aggregations = enabled; + self.options_mut().optimizer.repartition_aggregations = enabled; self } /// Sets minimum file range size for repartitioning scans pub fn with_repartition_file_min_size(mut self, size: usize) -> Self { - self.options.optimizer.repartition_file_min_size = size; + self.options_mut().optimizer.repartition_file_min_size = size; self } /// Enables or disables the allowing unordered symmetric hash join pub fn with_allow_symmetric_joins_without_pruning(mut self, enabled: bool) -> Self { - self.options.optimizer.allow_symmetric_joins_without_pruning = enabled; + self.options_mut() + .optimizer + .allow_symmetric_joins_without_pruning = enabled; self } /// Enables or disables the use of repartitioning for file scans pub fn with_repartition_file_scans(mut self, enabled: bool) -> Self { - self.options.optimizer.repartition_file_scans = enabled; + self.options_mut().optimizer.repartition_file_scans = enabled; self } /// Enables or disables the use of repartitioning for window functions to improve parallelism pub fn with_repartition_windows(mut self, enabled: bool) -> Self { - self.options.optimizer.repartition_windows = enabled; + self.options_mut().optimizer.repartition_windows = enabled; self } /// Enables or disables the use of per-partition sorting to improve parallelism pub fn with_repartition_sorts(mut self, enabled: bool) -> Self { - self.options.optimizer.repartition_sorts = enabled; + self.options_mut().optimizer.repartition_sorts = enabled; self } @@ -326,7 +341,7 @@ impl SessionConfig { /// /// [prefer_existing_sort]: datafusion_common::config::OptimizerOptions::prefer_existing_sort pub fn with_prefer_existing_sort(mut self, enabled: bool) -> Self { - self.options.optimizer.prefer_existing_sort = enabled; + self.options_mut().optimizer.prefer_existing_sort = enabled; self } @@ -334,13 +349,13 @@ impl SessionConfig { /// /// [prefer_existing_union]: datafusion_common::config::OptimizerOptions::prefer_existing_union pub fn with_prefer_existing_union(mut self, enabled: bool) -> Self { - self.options.optimizer.prefer_existing_union = enabled; + self.options_mut().optimizer.prefer_existing_union = enabled; self } /// Enables or disables the use of pruning predicate for parquet readers to skip row groups pub fn with_parquet_pruning(mut self, enabled: bool) -> Self { - self.options.execution.parquet.pruning = enabled; + self.options_mut().execution.parquet.pruning = enabled; self } @@ -356,7 +371,7 @@ impl SessionConfig { /// Enables or disables the use of bloom filter for parquet readers to skip row groups pub fn with_parquet_bloom_filter_pruning(mut self, enabled: bool) -> Self { - self.options.execution.parquet.bloom_filter_on_read = enabled; + self.options_mut().execution.parquet.bloom_filter_on_read = enabled; self } @@ -367,13 +382,13 @@ impl SessionConfig { /// Enables or disables the use of page index for parquet readers to skip parquet data pages pub fn with_parquet_page_index_pruning(mut self, enabled: bool) -> Self { - self.options.execution.parquet.enable_page_index = enabled; + self.options_mut().execution.parquet.enable_page_index = enabled; self } /// Enables or disables the collection of statistics after listing files pub fn with_collect_statistics(mut self, enabled: bool) -> Self { - self.options.execution.collect_statistics = enabled; + self.options_mut().execution.collect_statistics = enabled; self } @@ -384,7 +399,7 @@ impl SessionConfig { /// Enables or disables the coalescence of small batches into larger batches pub fn with_coalesce_batches(mut self, enabled: bool) -> Self { - self.options.execution.coalesce_batches = enabled; + self.options_mut().execution.coalesce_batches = enabled; self } @@ -396,7 +411,7 @@ impl SessionConfig { /// Enables or disables the round robin repartition for increasing parallelism pub fn with_round_robin_repartition(mut self, enabled: bool) -> Self { - self.options.optimizer.enable_round_robin_repartition = enabled; + self.options_mut().optimizer.enable_round_robin_repartition = enabled; self } @@ -414,11 +429,19 @@ impl SessionConfig { mut self, sort_spill_reservation_bytes: usize, ) -> Self { - self.options.execution.sort_spill_reservation_bytes = + self.options_mut().execution.sort_spill_reservation_bytes = sort_spill_reservation_bytes; self } + /// Set the compression codec [`spill_compression`] used when spilling data to disk. + /// + /// [`spill_compression`]: datafusion_common::config::ExecutionOptions::spill_compression + pub fn with_spill_compression(mut self, spill_compression: SpillCompression) -> Self { + self.options_mut().execution.spill_compression = spill_compression; + self + } + /// Set the size of [`sort_in_place_threshold_bytes`] to control /// how sort does things. /// @@ -427,7 +450,7 @@ impl SessionConfig { mut self, sort_in_place_threshold_bytes: usize, ) -> Self { - self.options.execution.sort_in_place_threshold_bytes = + self.options_mut().execution.sort_in_place_threshold_bytes = sort_in_place_threshold_bytes; self } @@ -437,7 +460,8 @@ impl SessionConfig { mut self, enforce_batch_size_in_joins: bool, ) -> Self { - self.options.execution.enforce_batch_size_in_joins = enforce_batch_size_in_joins; + self.options_mut().execution.enforce_batch_size_in_joins = + enforce_batch_size_in_joins; self } @@ -575,6 +599,7 @@ impl SessionConfig { impl From for SessionConfig { fn from(options: ConfigOptions) -> Self { + let options = Arc::new(options); Self { options, ..Default::default() diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index caa62eefe14c7..82f2d75ac1b57 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -17,15 +17,110 @@ //! [`DiskManager`]: Manages files generated during query execution -use datafusion_common::{resources_datafusion_err, DataFusionError, Result}; +use datafusion_common::{ + config_err, resources_datafusion_err, resources_err, DataFusionError, Result, +}; use log::debug; use parking_lot::Mutex; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use tempfile::{Builder, NamedTempFile, TempDir}; +use crate::memory_pool::human_readable_size; + +const DEFAULT_MAX_TEMP_DIRECTORY_SIZE: u64 = 100 * 1024 * 1024 * 1024; // 100GB + +/// Builder pattern for the [DiskManager] structure +#[derive(Clone, Debug)] +pub struct DiskManagerBuilder { + /// The storage mode of the disk manager + mode: DiskManagerMode, + /// The maximum amount of data (in bytes) stored inside the temporary directories. + /// Default to 100GB + max_temp_directory_size: u64, +} + +impl Default for DiskManagerBuilder { + fn default() -> Self { + Self { + mode: DiskManagerMode::OsTmpDirectory, + max_temp_directory_size: DEFAULT_MAX_TEMP_DIRECTORY_SIZE, + } + } +} + +impl DiskManagerBuilder { + pub fn set_mode(&mut self, mode: DiskManagerMode) { + self.mode = mode; + } + + pub fn with_mode(mut self, mode: DiskManagerMode) -> Self { + self.set_mode(mode); + self + } + + pub fn set_max_temp_directory_size(&mut self, value: u64) { + self.max_temp_directory_size = value; + } + + pub fn with_max_temp_directory_size(mut self, value: u64) -> Self { + self.set_max_temp_directory_size(value); + self + } + + /// Create a DiskManager given the builder + pub fn build(self) -> Result { + match self.mode { + DiskManagerMode::OsTmpDirectory => Ok(DiskManager { + local_dirs: Mutex::new(Some(vec![])), + max_temp_directory_size: self.max_temp_directory_size, + used_disk_space: Arc::new(AtomicU64::new(0)), + }), + DiskManagerMode::Directories(conf_dirs) => { + let local_dirs = create_local_dirs(conf_dirs)?; + debug!( + "Created local dirs {local_dirs:?} as DataFusion working directory" + ); + Ok(DiskManager { + local_dirs: Mutex::new(Some(local_dirs)), + max_temp_directory_size: self.max_temp_directory_size, + used_disk_space: Arc::new(AtomicU64::new(0)), + }) + } + DiskManagerMode::Disabled => Ok(DiskManager { + local_dirs: Mutex::new(None), + max_temp_directory_size: self.max_temp_directory_size, + used_disk_space: Arc::new(AtomicU64::new(0)), + }), + } + } +} + +#[derive(Clone, Debug)] +pub enum DiskManagerMode { + /// Create a new [DiskManager] that creates temporary files within + /// a temporary directory chosen by the OS + OsTmpDirectory, + + /// Create a new [DiskManager] that creates temporary files within + /// the specified directories. One of the directories will be chosen + /// at random for each temporary file created. + Directories(Vec), + + /// Disable disk manager, attempts to create temporary files will error + Disabled, +} + +impl Default for DiskManagerMode { + fn default() -> Self { + Self::OsTmpDirectory + } +} + /// Configuration for temporary disk access +#[deprecated(since = "48.0.0", note = "Use DiskManagerBuilder instead")] #[derive(Debug, Clone)] pub enum DiskManagerConfig { /// Use the provided [DiskManager] instance @@ -43,12 +138,14 @@ pub enum DiskManagerConfig { Disabled, } +#[allow(deprecated)] impl Default for DiskManagerConfig { fn default() -> Self { Self::NewOs } } +#[allow(deprecated)] impl DiskManagerConfig { /// Create temporary files in a temporary directory chosen by the OS pub fn new() -> Self { @@ -75,32 +172,90 @@ pub struct DiskManager { /// If `Some(vec![])` a new OS specified temporary directory will be created /// If `None` an error will be returned (configured not to spill) local_dirs: Mutex>>>, + /// The maximum amount of data (in bytes) stored inside the temporary directories. + /// Default to 100GB + max_temp_directory_size: u64, + /// Used disk space in the temporary directories. Now only spilled data for + /// external executors are counted. + used_disk_space: Arc, } impl DiskManager { + /// Creates a builder for [DiskManager] + pub fn builder() -> DiskManagerBuilder { + DiskManagerBuilder::default() + } + /// Create a DiskManager given the configuration + #[allow(deprecated)] + #[deprecated(since = "48.0.0", note = "Use DiskManager::builder() instead")] pub fn try_new(config: DiskManagerConfig) -> Result> { match config { DiskManagerConfig::Existing(manager) => Ok(manager), DiskManagerConfig::NewOs => Ok(Arc::new(Self { local_dirs: Mutex::new(Some(vec![])), + max_temp_directory_size: DEFAULT_MAX_TEMP_DIRECTORY_SIZE, + used_disk_space: Arc::new(AtomicU64::new(0)), })), DiskManagerConfig::NewSpecified(conf_dirs) => { let local_dirs = create_local_dirs(conf_dirs)?; debug!( - "Created local dirs {:?} as DataFusion working directory", - local_dirs + "Created local dirs {local_dirs:?} as DataFusion working directory" ); Ok(Arc::new(Self { local_dirs: Mutex::new(Some(local_dirs)), + max_temp_directory_size: DEFAULT_MAX_TEMP_DIRECTORY_SIZE, + used_disk_space: Arc::new(AtomicU64::new(0)), })) } DiskManagerConfig::Disabled => Ok(Arc::new(Self { local_dirs: Mutex::new(None), + max_temp_directory_size: DEFAULT_MAX_TEMP_DIRECTORY_SIZE, + used_disk_space: Arc::new(AtomicU64::new(0)), })), } } + pub fn set_max_temp_directory_size( + &mut self, + max_temp_directory_size: u64, + ) -> Result<()> { + // If the disk manager is disabled and `max_temp_directory_size` is not 0, + // this operation is not meaningful, fail early. + if self.local_dirs.lock().is_none() && max_temp_directory_size != 0 { + return config_err!( + "Cannot set max temp directory size for a disk manager that spilling is disabled" + ); + } + + self.max_temp_directory_size = max_temp_directory_size; + Ok(()) + } + + pub fn set_arc_max_temp_directory_size( + this: &mut Arc, + max_temp_directory_size: u64, + ) -> Result<()> { + if let Some(inner) = Arc::get_mut(this) { + inner.set_max_temp_directory_size(max_temp_directory_size)?; + Ok(()) + } else { + config_err!("DiskManager should be a single instance") + } + } + + pub fn with_max_temp_directory_size( + mut self, + max_temp_directory_size: u64, + ) -> Result { + self.set_max_temp_directory_size(max_temp_directory_size)?; + Ok(self) + } + + pub fn used_disk_space(&self) -> u64 { + self.used_disk_space.load(Ordering::Relaxed) + } + /// Return true if this disk manager supports creating temporary /// files. If this returns false, any call to `create_tmp_file` /// will error. @@ -113,7 +268,7 @@ impl DiskManager { /// If the file can not be created for some reason, returns an /// error message referencing the request description pub fn create_tmp_file( - &self, + self: &Arc, request_description: &str, ) -> Result { let mut guard = self.local_dirs.lock(); @@ -136,24 +291,37 @@ impl DiskManager { local_dirs.push(Arc::new(tempdir)); } - let dir_index = thread_rng().gen_range(0..local_dirs.len()); + let dir_index = rng().random_range(0..local_dirs.len()); Ok(RefCountedTempFile { _parent_temp_dir: Arc::clone(&local_dirs[dir_index]), tempfile: Builder::new() .tempfile_in(local_dirs[dir_index].as_ref()) .map_err(DataFusionError::IoError)?, + current_file_disk_usage: 0, + disk_manager: Arc::clone(self), }) } } /// A wrapper around a [`NamedTempFile`] that also contains -/// a reference to its parent temporary directory +/// a reference to its parent temporary directory. +/// +/// # Note +/// After any modification to the underlying file (e.g., writing data to it), the caller +/// must invoke [`Self::update_disk_usage`] to update the global disk usage counter. +/// This ensures the disk manager can properly enforce usage limits configured by +/// [`DiskManager::with_max_temp_directory_size`]. #[derive(Debug)] pub struct RefCountedTempFile { /// The reference to the directory in which temporary files are created to ensure /// it is not cleaned up prior to the NamedTempFile _parent_temp_dir: Arc, tempfile: NamedTempFile, + /// Tracks the current disk usage of this temporary file. See + /// [`Self::update_disk_usage`] for more details. + current_file_disk_usage: u64, + /// The disk manager that created and manages this temporary file + disk_manager: Arc, } impl RefCountedTempFile { @@ -164,6 +332,54 @@ impl RefCountedTempFile { pub fn inner(&self) -> &NamedTempFile { &self.tempfile } + + /// Updates the global disk usage counter after modifications to the underlying file. + /// + /// # Errors + /// - Returns an error if the global disk usage exceeds the configured limit. + pub fn update_disk_usage(&mut self) -> Result<()> { + // Get new file size from OS + let metadata = self.tempfile.as_file().metadata()?; + let new_disk_usage = metadata.len(); + + // Update the global disk usage by: + // 1. Subtracting the old file size from the global counter + self.disk_manager + .used_disk_space + .fetch_sub(self.current_file_disk_usage, Ordering::Relaxed); + // 2. Adding the new file size to the global counter + self.disk_manager + .used_disk_space + .fetch_add(new_disk_usage, Ordering::Relaxed); + + // 3. Check if the updated global disk usage exceeds the configured limit + let global_disk_usage = self.disk_manager.used_disk_space.load(Ordering::Relaxed); + if global_disk_usage > self.disk_manager.max_temp_directory_size { + return resources_err!( + "The used disk space during the spilling process has exceeded the allowable limit of {}. Try increasing the `max_temp_directory_size` in the disk manager configuration.", + human_readable_size(self.disk_manager.max_temp_directory_size as usize) + ); + } + + // 4. Update the local file size tracking + self.current_file_disk_usage = new_disk_usage; + + Ok(()) + } + + pub fn current_disk_usage(&self) -> u64 { + self.current_file_disk_usage + } +} + +/// When the temporary file is dropped, subtract its disk usage from the disk manager's total +impl Drop for RefCountedTempFile { + fn drop(&mut self) { + // Subtract the current file's disk usage from the global counter + self.disk_manager + .used_disk_space + .fetch_sub(self.current_file_disk_usage, Ordering::Relaxed); + } } /// Setup local dirs by creating one new dir in each of the given dirs @@ -190,8 +406,7 @@ mod tests { #[test] fn lazy_temp_dir_creation() -> Result<()> { // A default configuration should not create temp files until requested - let config = DiskManagerConfig::new(); - let dm = DiskManager::try_new(config)?; + let dm = Arc::new(DiskManagerBuilder::default().build()?); assert_eq!(0, local_dir_snapshot(&dm).len()); @@ -223,11 +438,14 @@ mod tests { let local_dir2 = TempDir::new()?; let local_dir3 = TempDir::new()?; let local_dirs = vec![local_dir1.path(), local_dir2.path(), local_dir3.path()]; - let config = DiskManagerConfig::new_specified( - local_dirs.iter().map(|p| p.into()).collect(), + let dm = Arc::new( + DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Directories( + local_dirs.iter().map(|p| p.into()).collect(), + )) + .build()?, ); - let dm = DiskManager::try_new(config)?; assert!(dm.tmp_files_enabled()); let actual = dm.create_tmp_file("Testing")?; @@ -239,8 +457,12 @@ mod tests { #[test] fn test_disabled_disk_manager() { - let config = DiskManagerConfig::Disabled; - let manager = DiskManager::try_new(config).unwrap(); + let manager = Arc::new( + DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Disabled) + .build() + .unwrap(), + ); assert!(!manager.tmp_files_enabled()); assert_eq!( manager.create_tmp_file("Testing").unwrap_err().strip_backtrace(), @@ -251,11 +473,9 @@ mod tests { #[test] fn test_disk_manager_create_spill_folder() { let dir = TempDir::new().unwrap(); - let config = DiskManagerConfig::new_specified(vec![dir.path().to_owned()]); - - DiskManager::try_new(config) - .unwrap() - .create_tmp_file("Testing") + DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Directories(vec![dir.path().to_path_buf()])) + .build() .unwrap(); } @@ -278,8 +498,7 @@ mod tests { #[test] fn test_temp_file_still_alive_after_disk_manager_dropped() -> Result<()> { // Test for the case using OS arranged temporary directory - let config = DiskManagerConfig::new(); - let dm = DiskManager::try_new(config)?; + let dm = Arc::new(DiskManagerBuilder::default().build()?); let temp_file = dm.create_tmp_file("Testing")?; let temp_file_path = temp_file.path().to_owned(); assert!(temp_file_path.exists()); @@ -295,10 +514,13 @@ mod tests { let local_dir2 = TempDir::new()?; let local_dir3 = TempDir::new()?; let local_dirs = [local_dir1.path(), local_dir2.path(), local_dir3.path()]; - let config = DiskManagerConfig::new_specified( - local_dirs.iter().map(|p| p.into()).collect(), + let dm = Arc::new( + DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Directories( + local_dirs.iter().map(|p| p.into()).collect(), + )) + .build()?, ); - let dm = DiskManager::try_new(config)?; let temp_file = dm.create_tmp_file("Testing")?; let temp_file_path = temp_file.path().to_owned(); assert!(temp_file_path.exists()); diff --git a/datafusion/execution/src/lib.rs b/datafusion/execution/src/lib.rs index 6a0a4b6322ee8..55243e301e0e9 100644 --- a/datafusion/execution/src/lib.rs +++ b/datafusion/execution/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -31,6 +31,8 @@ pub mod config; pub mod disk_manager; pub mod memory_pool; pub mod object_store; +#[cfg(feature = "parquet_encryption")] +pub mod parquet_encryption; pub mod runtime_env; mod stream; mod task; diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 71d40aeab53c7..e620b23267962 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -19,7 +19,8 @@ //! help with allocation accounting. use datafusion_common::{internal_err, Result}; -use std::{cmp::Ordering, sync::Arc}; +use std::hash::{Hash, Hasher}; +use std::{cmp::Ordering, sync::atomic, sync::Arc}; mod pool; pub mod proxy { @@ -56,8 +57,8 @@ pub use pool::*; /// `GroupByHashExec`. It does NOT track and limit memory used internally by /// other operators such as `DataSourceExec` or the `RecordBatch`es that flow /// between operators. Furthermore, operators should not reserve memory for the -/// batches they produce. Instead, if a parent operator needs to hold batches -/// from its children in memory for an extended period, it is the parent +/// batches they produce. Instead, if a consumer operator needs to hold batches +/// from its producers in memory for an extended period, it is the consumer /// operator's responsibility to reserve the necessary memory for those batches. /// /// In order to avoid allocating memory until the OS or the container system @@ -97,6 +98,67 @@ pub use pool::*; /// operator will spill the intermediate buffers to disk, and release memory /// from the memory pool, and continue to retry memory reservation. /// +/// # Related Structs +/// +/// To better understand memory management in DataFusion, here are the key structs +/// and their relationships: +/// +/// - [`MemoryConsumer`]: A named allocation traced by a particular operator. If an +/// execution is parallelized, and there are multiple partitions of the same +/// operator, each partition will have a separate `MemoryConsumer`. +/// - `SharedRegistration`: A registration of a `MemoryConsumer` with a `MemoryPool`. +/// `SharedRegistration` and `MemoryPool` have a many-to-one relationship. `MemoryPool` +/// implementation can decide how to allocate memory based on the registered consumers. +/// (e.g. `FairSpillPool` will try to share available memory evenly among all registered +/// consumers) +/// - [`MemoryReservation`]: Each `MemoryConsumer`/operator can have multiple +/// `MemoryReservation`s for different internal data structures. The relationship +/// between `MemoryConsumer` and `MemoryReservation` is one-to-many. This design +/// enables cleaner operator implementations: +/// - Different `MemoryReservation`s can be used for different purposes +/// - `MemoryReservation` follows RAII principles - to release a reservation, +/// simply drop the `MemoryReservation` object. When all `MemoryReservation`s +/// for a `SharedRegistration` are dropped, the `SharedRegistration` is dropped +/// when its reference count reaches zero, automatically unregistering the +/// `MemoryConsumer` from the `MemoryPool`. +/// +/// ## Relationship Diagram +/// +/// ```text +/// ┌──────────────────┐ ┌──────────────────┐ +/// │MemoryReservation │ │MemoryReservation │ +/// └───┬──────────────┘ └──────────────────┘ ...... +/// │belongs to │ +/// │ ┌───────────────────────┘ │ │ +/// │ │ │ │ +/// ▼ ▼ ▼ ▼ +/// ┌────────────────────────┐ ┌────────────────────────┐ +/// │ SharedRegistration │ │ SharedRegistration │ +/// │ ┌────────────────┐ │ │ ┌────────────────┐ │ +/// │ │ │ │ │ │ │ │ +/// │ │ MemoryConsumer │ │ │ │ MemoryConsumer │ │ +/// │ │ │ │ │ │ │ │ +/// │ └────────────────┘ │ │ └────────────────┘ │ +/// └────────────┬───────────┘ └────────────┬───────────┘ +/// │ │ +/// │ register│into +/// │ │ +/// └─────────────┐ ┌──────────────┘ +/// │ │ +/// ▼ ▼ +/// ╔═══════════════════════════════════════════════════╗ +/// ║ ║ +/// ║ MemoryPool ║ +/// ║ ║ +/// ╚═══════════════════════════════════════════════════╝ +/// ``` +/// +/// For example, there are two parallel partitions of an operator X: each partition +/// corresponds to a `MemoryConsumer` in the above diagram. Inside each partition of +/// operator X, there are typically several `MemoryReservation`s - one for each +/// internal data structure that needs memory tracking (e.g., 1 reservation for the hash +/// table, and 1 reservation for buffered input, etc.). +/// /// # Implementing `MemoryPool` /// /// You can implement a custom allocation policy by implementing the @@ -140,30 +202,101 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug { /// Return the total amount of memory reserved fn reserved(&self) -> usize; + + /// Return the memory limit of the pool + /// + /// The default implementation of `MemoryPool::memory_limit` + /// will return `MemoryLimit::Unknown`. + /// If you are using your custom memory pool, but have the requirement to + /// know the memory usage limit of the pool, please implement this method + /// to return it(`Memory::Finite(limit)`). + fn memory_limit(&self) -> MemoryLimit { + MemoryLimit::Unknown + } +} + +/// Memory limit of `MemoryPool` +pub enum MemoryLimit { + Infinite, + /// Bounded memory limit in bytes. + Finite(usize), + Unknown, } /// A memory consumer is a named allocation traced by a particular /// [`MemoryReservation`] in a [`MemoryPool`]. All allocations are registered to /// a particular `MemoryConsumer`; /// +/// Each `MemoryConsumer` is identifiable by a process-unique id, and is therefor not cloneable, +/// If you want a clone of a `MemoryConsumer`, you should look into [`MemoryConsumer::clone_with_new_id`], +/// but note that this `MemoryConsumer` may be treated as a separate entity based on the used pool, +/// and is only guaranteed to share the name and inner properties. +/// /// For help with allocation accounting, see the [`proxy`] module. /// /// [proxy]: datafusion_common::utils::proxy -#[derive(Debug, PartialEq, Eq, Hash, Clone)] +#[derive(Debug)] pub struct MemoryConsumer { name: String, can_spill: bool, + id: usize, +} + +impl PartialEq for MemoryConsumer { + fn eq(&self, other: &Self) -> bool { + let is_same_id = self.id == other.id; + + #[cfg(debug_assertions)] + if is_same_id { + assert_eq!(self.name, other.name); + assert_eq!(self.can_spill, other.can_spill); + } + + is_same_id + } +} + +impl Eq for MemoryConsumer {} + +impl Hash for MemoryConsumer { + fn hash(&self, state: &mut H) { + self.id.hash(state); + self.name.hash(state); + self.can_spill.hash(state); + } } impl MemoryConsumer { + fn new_unique_id() -> usize { + static ID: atomic::AtomicUsize = atomic::AtomicUsize::new(0); + ID.fetch_add(1, atomic::Ordering::Relaxed) + } + /// Create a new empty [`MemoryConsumer`] that can be grown using [`MemoryReservation`] pub fn new(name: impl Into) -> Self { Self { name: name.into(), can_spill: false, + id: Self::new_unique_id(), } } + /// Returns a clone of this [`MemoryConsumer`] with a new unique id, + /// which can be registered with a [`MemoryPool`], + /// This new consumer is separate from the original. + pub fn clone_with_new_id(&self) -> Self { + Self { + name: self.name.clone(), + can_spill: self.can_spill, + id: Self::new_unique_id(), + } + } + + /// Return the unique id of this [`MemoryConsumer`] + pub fn id(&self) -> usize { + self.id + } + /// Set whether this allocation can be spilled to disk pub fn with_can_spill(self, can_spill: bool) -> Self { Self { can_spill, ..self } @@ -349,7 +482,7 @@ pub mod units { pub const KB: u64 = 1 << 10; } -/// Present size in human readable form +/// Present size in human-readable form pub fn human_readable_size(size: usize) -> String { use units::*; @@ -374,6 +507,15 @@ pub fn human_readable_size(size: usize) -> String { mod tests { use super::*; + #[test] + fn test_id_uniqueness() { + let mut ids = std::collections::HashSet::new(); + for _ in 0..100 { + let consumer = MemoryConsumer::new("test"); + assert!(ids.insert(consumer.id())); // Ensures unique insertion + } + } + #[test] fn test_memory_pool_underflow() { let pool = Arc::new(GreedyMemoryPool::new(50)) as _; diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index 261332180e571..306df3defdbb3 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -15,14 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; +use crate::memory_pool::{ + human_readable_size, MemoryConsumer, MemoryLimit, MemoryPool, MemoryReservation, +}; use datafusion_common::HashMap; use datafusion_common::{resources_datafusion_err, DataFusionError, Result}; use log::debug; use parking_lot::Mutex; use std::{ num::NonZeroUsize, - sync::atomic::{AtomicU64, AtomicUsize, Ordering}, + sync::atomic::{AtomicUsize, Ordering}, }; /// A [`MemoryPool`] that enforces no limit @@ -48,6 +50,10 @@ impl MemoryPool for UnboundedMemoryPool { fn reserved(&self) -> usize { self.used.load(Ordering::Relaxed) } + + fn memory_limit(&self) -> MemoryLimit { + MemoryLimit::Infinite + } } /// A [`MemoryPool`] that implements a greedy first-come first-serve limit. @@ -100,6 +106,10 @@ impl MemoryPool for GreedyMemoryPool { fn reserved(&self) -> usize { self.used.load(Ordering::Relaxed) } + + fn memory_limit(&self) -> MemoryLimit { + MemoryLimit::Finite(self.pool_size) + } } /// A [`MemoryPool`] that prevents spillable reservations from using more than @@ -233,6 +243,10 @@ impl MemoryPool for FairSpillPool { let state = self.state.lock(); state.spillable + state.unspillable } + + fn memory_limit(&self) -> MemoryLimit { + MemoryLimit::Finite(self.pool_size) + } } /// Constructs a resources error based upon the individual [`MemoryReservation`]. @@ -246,7 +260,41 @@ fn insufficient_capacity_err( additional: usize, available: usize, ) -> DataFusionError { - resources_datafusion_err!("Failed to allocate additional {} bytes for {} with {} bytes already allocated for this reservation - {} bytes remain available for the total pool", additional, reservation.registration.consumer.name, reservation.size, available) + resources_datafusion_err!("Failed to allocate additional {} for {} with {} already allocated for this reservation - {} remain available for the total pool", + human_readable_size(additional), reservation.registration.consumer.name, human_readable_size(reservation.size), human_readable_size(available)) +} + +#[derive(Debug)] +struct TrackedConsumer { + name: String, + can_spill: bool, + reserved: AtomicUsize, + peak: AtomicUsize, +} + +impl TrackedConsumer { + /// Shorthand to return the currently reserved value + fn reserved(&self) -> usize { + self.reserved.load(Ordering::Relaxed) + } + + /// Return the peak value + fn peak(&self) -> usize { + self.peak.load(Ordering::Relaxed) + } + + /// Grows the tracked consumer's reserved size, + /// should be called after the pool has successfully performed the grow(). + fn grow(&self, additional: usize) { + self.reserved.fetch_add(additional, Ordering::Relaxed); + self.peak.fetch_max(self.reserved(), Ordering::Relaxed); + } + + /// Reduce the tracked consumer's reserved size, + /// should be called after the pool has successfully performed the shrink(). + fn shrink(&self, shrink: usize) { + self.reserved.fetch_sub(shrink, Ordering::Relaxed); + } } /// A [`MemoryPool`] that tracks the consumers that have @@ -254,19 +302,68 @@ fn insufficient_capacity_err( /// /// By tracking memory reservations more carefully this pool /// can provide better error messages on the largest memory users +/// when memory allocation fails. /// /// Tracking is per hashed [`MemoryConsumer`], not per [`MemoryReservation`]. /// The same consumer can have multiple reservations. +/// +/// # Automatic Usage via [`RuntimeEnvBuilder`] +/// +/// The easiest way to use `TrackConsumersPool` is via +/// [`RuntimeEnvBuilder::with_memory_limit()`]. +/// +/// [`RuntimeEnvBuilder`]: crate::runtime_env::RuntimeEnvBuilder +/// [`RuntimeEnvBuilder::with_memory_limit()`]: crate::runtime_env::RuntimeEnvBuilder::with_memory_limit +/// +/// # Usage Examples +/// +/// For more examples of using `TrackConsumersPool`, see the [memory_pool_tracking.rs] example +/// +/// [memory_pool_tracking.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/memory_pool_tracking.rs +/// [memory_pool_execution_plan.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/memory_pool_execution_plan.rs #[derive(Debug)] pub struct TrackConsumersPool { + /// The wrapped memory pool that actually handles reservation logic inner: I, + /// The amount of consumers to report(ordered top to bottom by reservation size) top: NonZeroUsize, - tracked_consumers: Mutex>, + /// Maps consumer_id --> TrackedConsumer + tracked_consumers: Mutex>, } impl TrackConsumersPool { /// Creates a new [`TrackConsumersPool`]. /// + /// # Arguments + /// * `inner` - The underlying memory pool that handles actual memory allocation + /// * `top` - The number of top memory consumers to include in error messages + /// + /// # Note + /// In most cases, you should use [`RuntimeEnvBuilder::with_memory_limit()`](crate::runtime_env::RuntimeEnvBuilder::with_memory_limit) + /// instead of creating this pool manually, as it automatically sets up tracking with + /// sensible defaults (top 5 consumers). + /// + /// # Example + /// + /// ```rust + /// use std::num::NonZeroUsize; + /// use datafusion_execution::memory_pool::{TrackConsumersPool, GreedyMemoryPool, FairSpillPool}; + /// + /// // Create with a greedy pool backend, reporting top 3 consumers in error messages + /// let tracked_greedy = TrackConsumersPool::new( + /// GreedyMemoryPool::new(1024 * 1024), // 1MB limit + /// NonZeroUsize::new(3).unwrap(), + /// ); + /// + /// // Create with a fair spill pool backend, reporting top 5 consumers in error messages + /// let tracked_fair = TrackConsumersPool::new( + /// FairSpillPool::new(2 * 1024 * 1024), // 2MB limit + /// NonZeroUsize::new(5).unwrap(), + /// ); + /// ``` + /// + /// # Impact on Error Messages + /// /// The `top` determines how many Top K [`MemoryConsumer`]s to include /// in the reported [`DataFusionError::ResourcesExhausted`]. pub fn new(inner: I, top: NonZeroUsize) -> Self { @@ -277,27 +374,21 @@ impl TrackConsumersPool { } } - /// Determine if there are multiple [`MemoryConsumer`]s registered - /// which have the same name. - /// - /// This is very tied to the implementation of the memory consumer. - fn has_multiple_consumers(&self, name: &String) -> bool { - let consumer = MemoryConsumer::new(name); - let consumer_with_spill = consumer.clone().with_can_spill(true); - let guard = self.tracked_consumers.lock(); - guard.contains_key(&consumer) && guard.contains_key(&consumer_with_spill) - } - - /// The top consumers in a report string. + /// Returns a formatted string with the top memory consumers. pub fn report_top(&self, top: usize) -> String { let mut consumers = self .tracked_consumers .lock() .iter() - .map(|(consumer, reserved)| { + .map(|(consumer_id, tracked_consumer)| { ( - (consumer.name().to_owned(), consumer.can_spill()), - reserved.load(Ordering::Acquire), + ( + *consumer_id, + tracked_consumer.name.to_owned(), + tracked_consumer.can_spill, + tracked_consumer.peak(), + ), + tracked_consumer.reserved(), ) }) .collect::>(); @@ -305,15 +396,16 @@ impl TrackConsumersPool { consumers[0..std::cmp::min(top, consumers.len())] .iter() - .map(|((name, can_spill), size)| { - if self.has_multiple_consumers(name) { - format!("{name}(can_spill={}) consumed {:?} bytes", can_spill, size) - } else { - format!("{name} consumed {:?} bytes", size) - } + .map(|((id, name, can_spill, peak), size)| { + format!( + " {name}#{id}(can spill: {can_spill}) consumed {}, peak {}", + human_readable_size(*size), + human_readable_size(*peak), + ) }) .collect::>() - .join(", ") + .join(",\n") + + "." } } @@ -322,29 +414,34 @@ impl MemoryPool for TrackConsumersPool { self.inner.register(consumer); let mut guard = self.tracked_consumers.lock(); - if let Some(already_reserved) = guard.insert(consumer.clone(), Default::default()) - { - guard.entry_ref(consumer).and_modify(|bytes| { - bytes.fetch_add( - already_reserved.load(Ordering::Acquire), - Ordering::AcqRel, - ); - }); - } + let existing = guard.insert( + consumer.id(), + TrackedConsumer { + name: consumer.name().to_string(), + can_spill: consumer.can_spill(), + reserved: Default::default(), + peak: Default::default(), + }, + ); + + debug_assert!( + existing.is_none(), + "Registered was called twice on the same consumer" + ); } fn unregister(&self, consumer: &MemoryConsumer) { self.inner.unregister(consumer); - self.tracked_consumers.lock().remove(consumer); + self.tracked_consumers.lock().remove(&consumer.id()); } fn grow(&self, reservation: &MemoryReservation, additional: usize) { self.inner.grow(reservation, additional); self.tracked_consumers .lock() - .entry_ref(reservation.consumer()) - .and_modify(|bytes| { - bytes.fetch_add(additional as u64, Ordering::AcqRel); + .entry(reservation.consumer().id()) + .and_modify(|tracked_consumer| { + tracked_consumer.grow(additional); }); } @@ -352,9 +449,9 @@ impl MemoryPool for TrackConsumersPool { self.inner.shrink(reservation, shrink); self.tracked_consumers .lock() - .entry_ref(reservation.consumer()) - .and_modify(|bytes| { - bytes.fetch_sub(shrink as u64, Ordering::AcqRel); + .entry(reservation.consumer().id()) + .and_modify(|tracked_consumer| { + tracked_consumer.shrink(shrink); }); } @@ -366,6 +463,7 @@ impl MemoryPool for TrackConsumersPool { // wrap OOM message in top consumers DataFusionError::ResourcesExhausted( provide_top_memory_consumers_to_error_msg( + &reservation.consumer().name, e, self.report_top(self.top.into()), ), @@ -376,9 +474,9 @@ impl MemoryPool for TrackConsumersPool { self.tracked_consumers .lock() - .entry_ref(reservation.consumer()) - .and_modify(|bytes| { - bytes.fetch_add(additional as u64, Ordering::AcqRel); + .entry(reservation.consumer().id()) + .and_modify(|tracked_consumer| { + tracked_consumer.grow(additional); }); Ok(()) } @@ -386,20 +484,35 @@ impl MemoryPool for TrackConsumersPool { fn reserved(&self) -> usize { self.inner.reserved() } + + fn memory_limit(&self) -> MemoryLimit { + self.inner.memory_limit() + } } fn provide_top_memory_consumers_to_error_msg( + consumer_name: &str, error_msg: String, top_consumers: String, ) -> String { - format!("Additional allocation failed with top memory consumers (across reservations) as: {}. Error: {}", top_consumers, error_msg) + format!("Additional allocation failed for {consumer_name} with top memory consumers (across reservations) as:\n{top_consumers}\nError: {error_msg}") } #[cfg(test)] mod tests { use super::*; + use insta::{allow_duplicates, assert_snapshot, Settings}; use std::sync::Arc; + fn make_settings() -> Settings { + let mut settings = Settings::clone_current(); + settings.add_filter( + r"([^\s]+)\#\d+\(can spill: (true|false)\)", + "$1#[ID](can spill: $2)", + ); + settings + } + #[test] fn test_fair() { let pool = Arc::new(FairSpillPool::new(100)) as _; @@ -418,10 +531,10 @@ mod tests { assert_eq!(pool.reserved(), 4000); let err = r2.try_grow(1).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated for this reservation - 0 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 1.0 B for r2 with 2000.0 B already allocated for this reservation - 0.0 B remain available for the total pool"); let err = r2.try_grow(1).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated for this reservation - 0 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 1.0 B for r2 with 2000.0 B already allocated for this reservation - 0.0 B remain available for the total pool"); r1.shrink(1990); r2.shrink(2000); @@ -446,12 +559,12 @@ mod tests { .register(&pool); let err = r3.try_grow(70).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated for this reservation - 40 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 70.0 B for r3 with 0.0 B already allocated for this reservation - 40.0 B remain available for the total pool"); //Shrinking r2 to zero doesn't allow a3 to allocate more than 45 r2.free(); let err = r3.try_grow(70).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated for this reservation - 40 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 70.0 B for r3 with 0.0 B already allocated for this reservation - 40.0 B remain available for the total pool"); // But dropping r2 does drop(r2); @@ -464,11 +577,13 @@ mod tests { let mut r4 = MemoryConsumer::new("s4").register(&pool); let err = r4.try_grow(30).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 30 bytes for s4 with 0 bytes already allocated for this reservation - 20 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 30.0 B for s4 with 0.0 B already allocated for this reservation - 20.0 B remain available for the total pool"); } #[test] fn test_tracked_consumers_pool() { + let setting = make_settings(); + let _bound = setting.bind_to_scope(); let pool: Arc = Arc::new(TrackConsumersPool::new( GreedyMemoryPool::new(100), NonZeroUsize::new(3).unwrap(), @@ -478,7 +593,8 @@ mod tests { // set r1=50, using grow and shrink let mut r1 = MemoryConsumer::new("r1").register(&pool); - r1.grow(70); + r1.grow(50); + r1.grow(20); r1.shrink(20); // set r2=15 using try_grow @@ -501,20 +617,22 @@ mod tests { // Test: reports if new reservation causes error // using the previously set sizes for other consumers let mut r5 = MemoryConsumer::new("r5").register(&pool); - let expected = "Additional allocation failed with top memory consumers (across reservations) as: r1 consumed 50 bytes, r3 consumed 20 bytes, r2 consumed 15 bytes. Error: Failed to allocate additional 150 bytes for r5 with 0 bytes already allocated for this reservation - 5 bytes remain available for the total pool"; let res = r5.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) - ), - "should provide list of top memory consumers, instead found {:?}", - res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed for r5 with top memory consumers (across reservations) as: + r1#[ID](can spill: false) consumed 50.0 B, peak 70.0 B, + r3#[ID](can spill: false) consumed 20.0 B, peak 25.0 B, + r2#[ID](can spill: false) consumed 15.0 B, peak 15.0 B. + Error: Failed to allocate additional 150.0 B for r5 with 0.0 B already allocated for this reservation - 5.0 B remain available for the total pool + "); } #[test] fn test_tracked_consumers_pool_register() { + let setting = make_settings(); + let _bound = setting.bind_to_scope(); let pool: Arc = Arc::new(TrackConsumersPool::new( GreedyMemoryPool::new(100), NonZeroUsize::new(3).unwrap(), @@ -524,120 +642,118 @@ mod tests { // Test: see error message when no consumers recorded yet let mut r0 = MemoryConsumer::new(same_name).register(&pool); - let expected = "Additional allocation failed with top memory consumers (across reservations) as: foo consumed 0 bytes. Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 100 bytes remain available for the total pool"; let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) - ), - "should provide proper error when no reservations have been made yet, instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed for foo with top memory consumers (across reservations) as: + foo#[ID](can spill: false) consumed 0.0 B, peak 0.0 B. + Error: Failed to allocate additional 150.0 B for foo with 0.0 B already allocated for this reservation - 100.0 B remain available for the total pool + "); // API: multiple registrations using the same hashed consumer, - // will be recognized as the same in the TrackConsumersPool. + // will be recognized *differently* in the TrackConsumersPool. - // Test: will be the same per Top Consumers reported. r0.grow(10); // make r0=10, pool available=90 let new_consumer_same_name = MemoryConsumer::new(same_name); let mut r1 = new_consumer_same_name.register(&pool); // TODO: the insufficient_capacity_err() message is per reservation, not per consumer. // a followup PR will clarify this message "0 bytes already allocated for this reservation" - let expected = "Additional allocation failed with top memory consumers (across reservations) as: foo consumed 10 bytes. Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 90 bytes remain available for the total pool"; let res = r1.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) - ), - "should provide proper error with same hashed consumer (a single foo=10 bytes, available=90), instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed for foo with top memory consumers (across reservations) as: + foo#[ID](can spill: false) consumed 10.0 B, peak 10.0 B, + foo#[ID](can spill: false) consumed 0.0 B, peak 0.0 B. + Error: Failed to allocate additional 150.0 B for foo with 0.0 B already allocated for this reservation - 90.0 B remain available for the total pool + "); // Test: will accumulate size changes per consumer, not per reservation r1.grow(20); - let expected = "Additional allocation failed with top memory consumers (across reservations) as: foo consumed 30 bytes. Error: Failed to allocate additional 150 bytes for foo with 20 bytes already allocated for this reservation - 70 bytes remain available for the total pool"; + let res = r1.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) - ), - "should provide proper error with same hashed consumer (a single foo=30 bytes, available=70), instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed for foo with top memory consumers (across reservations) as: + foo#[ID](can spill: false) consumed 20.0 B, peak 20.0 B, + foo#[ID](can spill: false) consumed 10.0 B, peak 10.0 B. + Error: Failed to allocate additional 150.0 B for foo with 20.0 B already allocated for this reservation - 70.0 B remain available for the total pool + "); // Test: different hashed consumer, (even with the same name), // will be recognized as different in the TrackConsumersPool let consumer_with_same_name_but_different_hash = MemoryConsumer::new(same_name).with_can_spill(true); let mut r2 = consumer_with_same_name_but_different_hash.register(&pool); - let expected = "Additional allocation failed with top memory consumers (across reservations) as: foo(can_spill=false) consumed 30 bytes, foo(can_spill=true) consumed 0 bytes. Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 70 bytes remain available for the total pool"; let res = r2.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) - ), - "should provide proper error with different hashed consumer (foo(can_spill=false)=30 bytes and foo(can_spill=true)=0 bytes, available=70), instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed for foo with top memory consumers (across reservations) as: + foo#[ID](can spill: false) consumed 20.0 B, peak 20.0 B, + foo#[ID](can spill: false) consumed 10.0 B, peak 10.0 B, + foo#[ID](can spill: true) consumed 0.0 B, peak 0.0 B. + Error: Failed to allocate additional 150.0 B for foo with 0.0 B already allocated for this reservation - 70.0 B remain available for the total pool + "); } #[test] fn test_tracked_consumers_pool_deregister() { fn test_per_pool_type(pool: Arc) { // Baseline: see the 2 memory consumers + let setting = make_settings(); + let _bound = setting.bind_to_scope(); let mut r0 = MemoryConsumer::new("r0").register(&pool); r0.grow(10); let r1_consumer = MemoryConsumer::new("r1"); - let mut r1 = r1_consumer.clone().register(&pool); + let mut r1 = r1_consumer.register(&pool); r1.grow(20); - let expected = "Additional allocation failed with top memory consumers (across reservations) as: r1 consumed 20 bytes, r0 consumed 10 bytes. Error: Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 70 bytes remain available for the total pool"; + let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) - ), - "should provide proper error with both consumers, instead found {:?}", - res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + allow_duplicates!(assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed for r0 with top memory consumers (across reservations) as: + r1#[ID](can spill: false) consumed 20.0 B, peak 20.0 B, + r0#[ID](can spill: false) consumed 10.0 B, peak 10.0 B. + Error: Failed to allocate additional 150.0 B for r0 with 10.0 B already allocated for this reservation - 70.0 B remain available for the total pool + ")); // Test: unregister one // only the remaining one should be listed - pool.unregister(&r1_consumer); - let expected_consumers = "Additional allocation failed with top memory consumers (across reservations) as: r0 consumed 10 bytes"; + drop(r1); let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected_consumers) - ), - "should provide proper error with only 1 consumer left registered, instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + allow_duplicates!(assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed for r0 with top memory consumers (across reservations) as: + r0#[ID](can spill: false) consumed 10.0 B, peak 10.0 B. + Error: Failed to allocate additional 150.0 B for r0 with 10.0 B already allocated for this reservation - 90.0 B remain available for the total pool + ")); // Test: actual message we see is the `available is 70`. When it should be `available is 90`. // This is because the pool.shrink() does not automatically occur within the inner_pool.deregister(). - let expected_70_available = "Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 70 bytes remain available for the total pool"; let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected_70_available) - ), - "should find that the inner pool will still count all bytes for the deregistered consumer until the reservation is dropped, instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + allow_duplicates!(assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed for r0 with top memory consumers (across reservations) as: + r0#[ID](can spill: false) consumed 10.0 B, peak 10.0 B. + Error: Failed to allocate additional 150.0 B for r0 with 10.0 B already allocated for this reservation - 90.0 B remain available for the total pool + ")); // Test: the registration needs to free itself (or be dropped), // for the proper error message - r1.free(); - let expected_90_available = "Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 90 bytes remain available for the total pool"; let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected_90_available) - ), - "should correctly account the total bytes after reservation is free, instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + allow_duplicates!(assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed for r0 with top memory consumers (across reservations) as: + r0#[ID](can spill: false) consumed 10.0 B, peak 10.0 B. + Error: Failed to allocate additional 150.0 B for r0 with 10.0 B already allocated for this reservation - 90.0 B remain available for the total pool + ")); } let tracked_spill_pool: Arc = Arc::new(TrackConsumersPool::new( @@ -655,6 +771,8 @@ mod tests { #[test] fn test_tracked_consumers_pool_use_beyond_errors() { + let setting = make_settings(); + let _bound = setting.bind_to_scope(); let upcasted: Arc = Arc::new(TrackConsumersPool::new( GreedyMemoryPool::new(100), @@ -678,12 +796,10 @@ mod tests { .unwrap(); // Test: can get runtime metrics, even without an error thrown - let expected = "r3 consumed 45 bytes, r1 consumed 20 bytes"; let res = downcasted.report_top(2); - assert_eq!( - res, expected, - "should provide list of top memory consumers, instead found {:?}", - res - ); + assert_snapshot!(res, @r" + r3#[ID](can spill: false) consumed 45.0 B, peak 45.0 B, + r1#[ID](can spill: false) consumed 20.0 B, peak 20.0 B. + "); } } diff --git a/datafusion/execution/src/object_store.rs b/datafusion/execution/src/object_store.rs index cd75c9f3c49ee..ef83128ac6818 100644 --- a/datafusion/execution/src/object_store.rs +++ b/datafusion/execution/src/object_store.rs @@ -20,7 +20,7 @@ //! and query data inside these systems. use dashmap::DashMap; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{exec_err, internal_datafusion_err, DataFusionError, Result}; #[cfg(not(target_arch = "wasm32"))] use object_store::local::LocalFileSystem; use object_store::ObjectStore; @@ -236,9 +236,7 @@ impl ObjectStoreRegistry for DefaultObjectStoreRegistry { .get(&s) .map(|o| Arc::clone(o.value())) .ok_or_else(|| { - DataFusionError::Internal(format!( - "No suitable object store found for {url}. See `RuntimeEnv::register_object_store`" - )) + internal_datafusion_err!("No suitable object store found for {url}. See `RuntimeEnv::register_object_store`") }) } } diff --git a/datafusion/execution/src/parquet_encryption.rs b/datafusion/execution/src/parquet_encryption.rs new file mode 100644 index 0000000000000..73881e11ca72f --- /dev/null +++ b/datafusion/execution/src/parquet_encryption.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::SchemaRef; +use async_trait::async_trait; +use dashmap::DashMap; +use datafusion_common::config::EncryptionFactoryOptions; +use datafusion_common::error::Result; +use datafusion_common::internal_datafusion_err; +use object_store::path::Path; +use parquet::encryption::decrypt::FileDecryptionProperties; +use parquet::encryption::encrypt::FileEncryptionProperties; +use std::sync::Arc; + +/// Trait for types that generate file encryption and decryption properties to +/// write and read encrypted Parquet files. +/// This allows flexibility in how encryption keys are managed, for example, to +/// integrate with a user's key management service (KMS). +/// For example usage, see the [`parquet_encrypted_with_kms` example]. +/// +/// [`parquet_encrypted_with_kms` example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/parquet_encrypted_with_kms.rs +#[async_trait] +pub trait EncryptionFactory: Send + Sync + std::fmt::Debug + 'static { + /// Generate file encryption properties to use when writing a Parquet file. + async fn get_file_encryption_properties( + &self, + config: &EncryptionFactoryOptions, + schema: &SchemaRef, + file_path: &Path, + ) -> Result>; + + /// Generate file decryption properties to use when reading a Parquet file. + async fn get_file_decryption_properties( + &self, + config: &EncryptionFactoryOptions, + file_path: &Path, + ) -> Result>; +} + +/// Stores [`EncryptionFactory`] implementations that can be retrieved by a unique string identifier +#[derive(Clone, Debug, Default)] +pub struct EncryptionFactoryRegistry { + factories: DashMap>, +} + +impl EncryptionFactoryRegistry { + /// Register an [`EncryptionFactory`] with an associated identifier that can be later + /// used to configure encryption when reading or writing Parquet. + /// If an encryption factory with the same identifier was already registered, it is replaced and returned. + pub fn register_factory( + &self, + id: &str, + factory: Arc, + ) -> Option> { + self.factories.insert(id.to_owned(), factory) + } + + /// Retrieve an [`EncryptionFactory`] by its identifier + pub fn get_factory(&self, id: &str) -> Result> { + self.factories + .get(id) + .map(|f| Arc::clone(f.value())) + .ok_or_else(|| { + internal_datafusion_err!( + "No Parquet encryption factory found for id '{id}'" + ) + }) + } +} diff --git a/datafusion/execution/src/runtime_env.rs b/datafusion/execution/src/runtime_env.rs index 95f14f485792a..db045a8b7e8a7 100644 --- a/datafusion/execution/src/runtime_env.rs +++ b/datafusion/execution/src/runtime_env.rs @@ -18,8 +18,10 @@ //! Execution [`RuntimeEnv`] environment that manages access to object //! store, memory manager, disk manager. +#[allow(deprecated)] +use crate::disk_manager::DiskManagerConfig; use crate::{ - disk_manager::{DiskManager, DiskManagerConfig}, + disk_manager::{DiskManager, DiskManagerBuilder, DiskManagerMode}, memory_pool::{ GreedyMemoryPool, MemoryPool, TrackConsumersPool, UnboundedMemoryPool, }, @@ -27,7 +29,9 @@ use crate::{ }; use crate::cache::cache_manager::{CacheManager, CacheManagerConfig}; -use datafusion_common::Result; +#[cfg(feature = "parquet_encryption")] +use crate::parquet_encryption::{EncryptionFactory, EncryptionFactoryRegistry}; +use datafusion_common::{config::ConfigEntry, Result}; use object_store::ObjectStore; use std::path::PathBuf; use std::sync::Arc; @@ -76,6 +80,9 @@ pub struct RuntimeEnv { pub cache_manager: Arc, /// Object Store Registry pub object_store_registry: Arc, + /// Parquet encryption factory registry + #[cfg(feature = "parquet_encryption")] + pub parquet_encryption_factory_registry: Arc, } impl Debug for RuntimeEnv { @@ -85,18 +92,6 @@ impl Debug for RuntimeEnv { } impl RuntimeEnv { - #[deprecated(since = "43.0.0", note = "please use `RuntimeEnvBuilder` instead")] - #[allow(deprecated)] - pub fn new(config: RuntimeConfig) -> Result { - Self::try_new(config) - } - /// Create env based on configuration - #[deprecated(since = "44.0.0", note = "please use `RuntimeEnvBuilder` instead")] - #[allow(deprecated)] - pub fn try_new(config: RuntimeConfig) -> Result { - config.build() - } - /// Registers a custom `ObjectStore` to be used with a specific url. /// This allows DataFusion to create external tables from urls that do not have /// built in support such as `hdfs://namenode:port/...`. @@ -152,6 +147,28 @@ impl RuntimeEnv { pub fn object_store(&self, url: impl AsRef) -> Result> { self.object_store_registry.get_store(url.as_ref()) } + + /// Register an [`EncryptionFactory`] with an associated identifier that can be later + /// used to configure encryption when reading or writing Parquet. + /// If an encryption factory with the same identifier was already registered, it is replaced and returned. + #[cfg(feature = "parquet_encryption")] + pub fn register_parquet_encryption_factory( + &self, + id: &str, + encryption_factory: Arc, + ) -> Option> { + self.parquet_encryption_factory_registry + .register_factory(id, encryption_factory) + } + + /// Retrieve an [`EncryptionFactory`] by its identifier + #[cfg(feature = "parquet_encryption")] + pub fn parquet_encryption_factory( + &self, + id: &str, + ) -> Result> { + self.parquet_encryption_factory_registry.get_factory(id) + } } impl Default for RuntimeEnv { @@ -160,18 +177,16 @@ impl Default for RuntimeEnv { } } -/// Please see: -/// This a type alias for backwards compatibility. -#[deprecated(since = "43.0.0", note = "please use `RuntimeEnvBuilder` instead")] -pub type RuntimeConfig = RuntimeEnvBuilder; - -#[derive(Clone)] /// Execution runtime configuration builder. /// /// See example on [`RuntimeEnv`] +#[derive(Clone)] pub struct RuntimeEnvBuilder { + #[allow(deprecated)] /// DiskManager to manage temporary disk file usage pub disk_manager: DiskManagerConfig, + /// DiskManager builder to manager temporary disk file usage + pub disk_manager_builder: Option, /// [`MemoryPool`] from which to allocate memory /// /// Defaults to using an [`UnboundedMemoryPool`] if `None` @@ -180,6 +195,9 @@ pub struct RuntimeEnvBuilder { pub cache_manager: CacheManagerConfig, /// ObjectStoreRegistry to get object store based on url pub object_store_registry: Arc, + /// Parquet encryption factory registry + #[cfg(feature = "parquet_encryption")] + pub parquet_encryption_factory_registry: Arc, } impl Default for RuntimeEnvBuilder { @@ -193,18 +211,29 @@ impl RuntimeEnvBuilder { pub fn new() -> Self { Self { disk_manager: Default::default(), + disk_manager_builder: Default::default(), memory_pool: Default::default(), cache_manager: Default::default(), object_store_registry: Arc::new(DefaultObjectStoreRegistry::default()), + #[cfg(feature = "parquet_encryption")] + parquet_encryption_factory_registry: Default::default(), } } + #[allow(deprecated)] + #[deprecated(since = "48.0.0", note = "Use with_disk_manager_builder instead")] /// Customize disk manager pub fn with_disk_manager(mut self, disk_manager: DiskManagerConfig) -> Self { self.disk_manager = disk_manager; self } + /// Customize the disk manager builder + pub fn with_disk_manager_builder(mut self, disk_manager: DiskManagerBuilder) -> Self { + self.disk_manager_builder = Some(disk_manager); + self + } + /// Customize memory policy pub fn with_memory_pool(mut self, memory_pool: Arc) -> Self { self.memory_pool = Some(memory_pool); @@ -229,7 +258,8 @@ impl RuntimeEnvBuilder { /// Specify the total memory to use while running the DataFusion /// plan to `max_memory * memory_fraction` in bytes. /// - /// This defaults to using [`GreedyMemoryPool`] + /// This defaults to using [`GreedyMemoryPool`] wrapped in the + /// [`TrackConsumersPool`] with a maximum of 5 consumers. /// /// Note DataFusion does not yet respect this limit in all cases. pub fn with_memory_limit(self, max_memory: usize, memory_fraction: f64) -> Self { @@ -241,26 +271,51 @@ impl RuntimeEnvBuilder { } /// Use the specified path to create any needed temporary files - pub fn with_temp_file_path(self, path: impl Into) -> Self { - self.with_disk_manager(DiskManagerConfig::new_specified(vec![path.into()])) + pub fn with_temp_file_path(mut self, path: impl Into) -> Self { + let builder = self.disk_manager_builder.take().unwrap_or_default(); + self.with_disk_manager_builder( + builder.with_mode(DiskManagerMode::Directories(vec![path.into()])), + ) + } + + /// Specify a limit on the size of the temporary file directory in bytes + pub fn with_max_temp_directory_size(mut self, size: u64) -> Self { + let builder = self.disk_manager_builder.take().unwrap_or_default(); + self.with_disk_manager_builder(builder.with_max_temp_directory_size(size)) + } + + /// Specify the limit of the file-embedded metadata cache, in bytes. + pub fn with_metadata_cache_limit(mut self, limit: usize) -> Self { + self.cache_manager = self.cache_manager.with_metadata_cache_limit(limit); + self } /// Build a RuntimeEnv pub fn build(self) -> Result { let Self { disk_manager, + disk_manager_builder, memory_pool, cache_manager, object_store_registry, + #[cfg(feature = "parquet_encryption")] + parquet_encryption_factory_registry, } = self; let memory_pool = memory_pool.unwrap_or_else(|| Arc::new(UnboundedMemoryPool::default())); Ok(RuntimeEnv { memory_pool, - disk_manager: DiskManager::try_new(disk_manager)?, + disk_manager: if let Some(builder) = disk_manager_builder { + Arc::new(builder.build()?) + } else { + #[allow(deprecated)] + DiskManager::try_new(disk_manager)? + }, cache_manager: CacheManager::try_new(&cache_manager)?, object_store_registry, + #[cfg(feature = "parquet_encryption")] + parquet_encryption_factory_registry, }) } @@ -268,4 +323,82 @@ impl RuntimeEnvBuilder { pub fn build_arc(self) -> Result> { self.build().map(Arc::new) } + + /// Create a new RuntimeEnvBuilder from an existing RuntimeEnv + pub fn from_runtime_env(runtime_env: &RuntimeEnv) -> Self { + let cache_config = CacheManagerConfig { + table_files_statistics_cache: runtime_env + .cache_manager + .get_file_statistic_cache(), + list_files_cache: runtime_env.cache_manager.get_list_files_cache(), + file_metadata_cache: Some( + runtime_env.cache_manager.get_file_metadata_cache(), + ), + metadata_cache_limit: runtime_env.cache_manager.get_metadata_cache_limit(), + }; + + Self { + #[allow(deprecated)] + disk_manager: DiskManagerConfig::Existing(Arc::clone( + &runtime_env.disk_manager, + )), + disk_manager_builder: None, + memory_pool: Some(Arc::clone(&runtime_env.memory_pool)), + cache_manager: cache_config, + object_store_registry: Arc::clone(&runtime_env.object_store_registry), + #[cfg(feature = "parquet_encryption")] + parquet_encryption_factory_registry: Arc::clone( + &runtime_env.parquet_encryption_factory_registry, + ), + } + } + + /// Returns a list of all available runtime configurations with their current values and descriptions + pub fn entries(&self) -> Vec { + vec![ + ConfigEntry { + key: "datafusion.runtime.memory_limit".to_string(), + value: None, // Default is system-dependent + description: "Maximum memory limit for query execution. Supports suffixes K (kilobytes), M (megabytes), and G (gigabytes). Example: '2G' for 2 gigabytes.", + }, + ConfigEntry { + key: "datafusion.runtime.max_temp_directory_size".to_string(), + value: Some("100G".to_string()), + description: "Maximum temporary file directory size. Supports suffixes K (kilobytes), M (megabytes), and G (gigabytes). Example: '2G' for 2 gigabytes.", + }, + ConfigEntry { + key: "datafusion.runtime.temp_directory".to_string(), + value: None, // Default is system-dependent + description: "The path to the temporary file directory.", + }, + ConfigEntry { + key: "datafusion.runtime.metadata_cache_limit".to_string(), + value: Some("50M".to_owned()), + description: "Maximum memory to use for file metadata cache such as Parquet metadata. Supports suffixes K (kilobytes), M (megabytes), and G (gigabytes). Example: '2G' for 2 gigabytes.", + } + ] + } + + /// Generate documentation that can be included in the user guide + pub fn generate_config_markdown() -> String { + use std::fmt::Write as _; + + let s = Self::default(); + + let mut docs = "| key | default | description |\n".to_string(); + docs += "|-----|---------|-------------|\n"; + let mut entries = s.entries(); + entries.sort_unstable_by(|a, b| a.key.cmp(&b.key)); + + for entry in &entries { + let _ = writeln!( + &mut docs, + "| {} | {} | {} |", + entry.key, + entry.value.as_deref().unwrap_or("NULL"), + entry.description + ); + } + docs + } } diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index b11596c4a30f4..c2a6cfe2c833f 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -19,7 +19,7 @@ use crate::{ config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry, runtime_env::RuntimeEnv, }; -use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; +use datafusion_common::{internal_datafusion_err, plan_datafusion_err, Result}; use datafusion_expr::planner::ExprPlanner; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use std::collections::HashSet; @@ -168,9 +168,9 @@ impl FunctionRegistry for TaskContext { let result = self.window_functions.get(name); result.cloned().ok_or_else(|| { - DataFusionError::Internal(format!( + internal_datafusion_err!( "There is no UDWF named \"{name}\" in the TaskContext" - )) + ) }) } fn register_udaf( @@ -201,6 +201,14 @@ impl FunctionRegistry for TaskContext { fn expr_planners(&self) -> Vec> { vec![] } + + fn udafs(&self) -> HashSet { + self.aggregate_functions.keys().cloned().collect() + } + + fn udwfs(&self) -> HashSet { + self.window_functions.keys().cloned().collect() + } } #[cfg(test)] diff --git a/datafusion/expr-common/Cargo.toml b/datafusion/expr-common/Cargo.toml index 14717dd78135d..db85f32079214 100644 --- a/datafusion/expr-common/Cargo.toml +++ b/datafusion/expr-common/Cargo.toml @@ -19,6 +19,7 @@ name = "datafusion-expr-common" description = "Logical plan and expression representation for DataFusion query engine" keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } homepage = { workspace = true } diff --git a/datafusion/expr-common/README.md b/datafusion/expr-common/README.md new file mode 100644 index 0000000000000..97006702542a0 --- /dev/null +++ b/datafusion/expr-common/README.md @@ -0,0 +1,32 @@ + + +# Apache DataFusion Common Logical Plan and Expressions + +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. + +This crate is a submodule of DataFusion that provides common logical expressions + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/expr-common/src/accumulator.rs b/datafusion/expr-common/src/accumulator.rs index 3a63c32894810..2829a9416f033 100644 --- a/datafusion/expr-common/src/accumulator.rs +++ b/datafusion/expr-common/src/accumulator.rs @@ -42,7 +42,6 @@ use std::fmt::Debug; /// [`state`] and combine the state from multiple accumulators /// via [`merge_batch`], as part of efficient multi-phase grouping. /// -/// [`GroupsAccumulator`]: crate::GroupsAccumulator /// [`update_batch`]: Self::update_batch /// [`retract_batch`]: Self::retract_batch /// [`state`]: Self::state diff --git a/datafusion/expr-common/src/casts.rs b/datafusion/expr-common/src/casts.rs new file mode 100644 index 0000000000000..8939ff1371bb9 --- /dev/null +++ b/datafusion/expr-common/src/casts.rs @@ -0,0 +1,1293 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for casting scalar literals to different data types +//! +//! This module contains functions for casting ScalarValue literals +//! to different data types, originally extracted from the optimizer's +//! unwrap_cast module to be shared between logical and physical layers. + +use std::cmp::Ordering; + +use arrow::datatypes::{ + DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION, + MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION, + MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL32_FOR_EACH_PRECISION, + MIN_DECIMAL64_FOR_EACH_PRECISION, +}; +use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; +use datafusion_common::ScalarValue; + +/// Convert a literal value from one data type to another +pub fn try_cast_literal_to_type( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let lit_data_type = lit_value.data_type(); + if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) { + return None; + } + if lit_value.is_null() { + // null value can be cast to any type of null value + return ScalarValue::try_from(target_type).ok(); + } + try_cast_numeric_literal(lit_value, target_type) + .or_else(|| try_cast_string_literal(lit_value, target_type)) + .or_else(|| try_cast_dictionary(lit_value, target_type)) + .or_else(|| try_cast_binary(lit_value, target_type)) +} + +/// Returns true if unwrap_cast_in_comparison supports this data type +pub fn is_supported_type(data_type: &DataType) -> bool { + is_supported_numeric_type(data_type) + || is_supported_string_type(data_type) + || is_supported_dictionary_type(data_type) + || is_supported_binary_type(data_type) +} + +/// Returns true if unwrap_cast_in_comparison support this numeric type +fn is_supported_numeric_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Timestamp(_, _) + ) +} + +/// Returns true if unwrap_cast_in_comparison supports casting this value as a string +fn is_supported_string_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) +} + +/// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary +fn is_supported_dictionary_type(data_type: &DataType) -> bool { + matches!(data_type, + DataType::Dictionary(_, inner) if is_supported_type(inner)) +} + +fn is_supported_binary_type(data_type: &DataType) -> bool { + matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_)) +} + +/// Convert a numeric value from one numeric data type to another +fn try_cast_numeric_literal( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let lit_data_type = lit_value.data_type(); + if !is_supported_numeric_type(&lit_data_type) + || !is_supported_numeric_type(target_type) + { + return None; + } + + let mul = match target_type { + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 => 1_i128, + DataType::Timestamp(_, _) => 1_i128, + DataType::Decimal32(_, scale) => 10_i128.pow(*scale as u32), + DataType::Decimal64(_, scale) => 10_i128.pow(*scale as u32), + DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), + _ => return None, + }; + let (target_min, target_max) = match target_type { + DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128), + DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128), + DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128), + DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128), + DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), + DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), + DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), + DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), + DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128), + DataType::Decimal32(precision, _) => ( + // Different precision for decimal32 can store different range of value. + // For example, the precision is 3, the max of value is `999` and the min + // value is `-999` + MIN_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128, + MAX_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128, + ), + DataType::Decimal64(precision, _) => ( + // Different precision for decimal64 can store different range of value. + // For example, the precision is 3, the max of value is `999` and the min + // value is `-999` + MIN_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128, + MAX_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128, + ), + DataType::Decimal128(precision, _) => ( + // Different precision for decimal128 can store different range of value. + // For example, the precision is 3, the max of value is `999` and the min + // value is `-999` + MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize], + MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize], + ), + _ => return None, + }; + let lit_value_target_type = match lit_value { + ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::Decimal32(Some(v), _, scale) => { + let v = *v as i128; + let lit_scale_mul = 10_i128.pow(*scale as u32); + if mul >= lit_scale_mul { + // Example: + // lit is decimal(123,3,2) + // target type is decimal(5,3) + // the lit can be converted to the decimal(1230,5,3) + v.checked_mul(mul / lit_scale_mul) + } else if v % (lit_scale_mul / mul) == 0 { + // Example: + // lit is decimal(123000,10,3) + // target type is int32: the lit can be converted to INT32(123) + // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2) + Some(v / (lit_scale_mul / mul)) + } else { + // can't convert the lit decimal to the target data type + None + } + } + ScalarValue::Decimal64(Some(v), _, scale) => { + let v = *v as i128; + let lit_scale_mul = 10_i128.pow(*scale as u32); + if mul >= lit_scale_mul { + // Example: + // lit is decimal(123,3,2) + // target type is decimal(5,3) + // the lit can be converted to the decimal(1230,5,3) + v.checked_mul(mul / lit_scale_mul) + } else if v % (lit_scale_mul / mul) == 0 { + // Example: + // lit is decimal(123000,10,3) + // target type is int32: the lit can be converted to INT32(123) + // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2) + Some(v / (lit_scale_mul / mul)) + } else { + // can't convert the lit decimal to the target data type + None + } + } + ScalarValue::Decimal128(Some(v), _, scale) => { + let lit_scale_mul = 10_i128.pow(*scale as u32); + if mul >= lit_scale_mul { + // Example: + // lit is decimal(123,3,2) + // target type is decimal(5,3) + // the lit can be converted to the decimal(1230,5,3) + (*v).checked_mul(mul / lit_scale_mul) + } else if (*v) % (lit_scale_mul / mul) == 0 { + // Example: + // lit is decimal(123000,10,3) + // target type is int32: the lit can be converted to INT32(123) + // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2) + Some(*v / (lit_scale_mul / mul)) + } else { + // can't convert the lit decimal to the target data type + None + } + } + _ => None, + }; + + match lit_value_target_type { + None => None, + Some(value) => { + if value >= target_min && value <= target_max { + // the value casted from lit to the target type is in the range of target type. + // return the target type of scalar value + let result_scalar = match target_type { + DataType::Int8 => ScalarValue::Int8(Some(value as i8)), + DataType::Int16 => ScalarValue::Int16(Some(value as i16)), + DataType::Int32 => ScalarValue::Int32(Some(value as i32)), + DataType::Int64 => ScalarValue::Int64(Some(value as i64)), + DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)), + DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)), + DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)), + DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)), + DataType::Timestamp(TimeUnit::Second, tz) => { + let value = cast_between_timestamp( + &lit_data_type, + &DataType::Timestamp(TimeUnit::Second, tz.clone()), + value, + ); + ScalarValue::TimestampSecond(value, tz.clone()) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => { + let value = cast_between_timestamp( + &lit_data_type, + &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), + value, + ); + ScalarValue::TimestampMillisecond(value, tz.clone()) + } + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + let value = cast_between_timestamp( + &lit_data_type, + &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), + value, + ); + ScalarValue::TimestampMicrosecond(value, tz.clone()) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + let value = cast_between_timestamp( + &lit_data_type, + &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), + value, + ); + ScalarValue::TimestampNanosecond(value, tz.clone()) + } + DataType::Decimal32(p, s) => { + ScalarValue::Decimal32(Some(value as i32), *p, *s) + } + DataType::Decimal64(p, s) => { + ScalarValue::Decimal64(Some(value as i64), *p, *s) + } + DataType::Decimal128(p, s) => { + ScalarValue::Decimal128(Some(value), *p, *s) + } + _ => { + return None; + } + }; + Some(result_scalar) + } else { + None + } + } + } +} + +fn try_cast_string_literal( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let string_value = lit_value.try_as_str()?.map(|s| s.to_string()); + let scalar_value = match target_type { + DataType::Utf8 => ScalarValue::Utf8(string_value), + DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value), + DataType::Utf8View => ScalarValue::Utf8View(string_value), + _ => return None, + }; + Some(scalar_value) +} + +/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary +fn try_cast_dictionary( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let lit_value_type = lit_value.data_type(); + let result_scalar = match (lit_value, target_type) { + // Unwrap dictionary when inner type matches target type + (ScalarValue::Dictionary(_, inner_value), _) + if inner_value.data_type() == *target_type => + { + (**inner_value).clone() + } + // Wrap type when target type is dictionary + (_, DataType::Dictionary(index_type, inner_type)) + if **inner_type == lit_value_type => + { + ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone())) + } + _ => { + return None; + } + }; + Some(result_scalar) +} + +/// Cast a timestamp value from one unit to another +fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option { + let value = value as i64; + let from_scale = match from { + DataType::Timestamp(TimeUnit::Second, _) => 1, + DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + _ => return Some(value), + }; + + let to_scale = match to { + DataType::Timestamp(TimeUnit::Second, _) => 1, + DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + _ => return Some(value), + }; + + match from_scale.cmp(&to_scale) { + Ordering::Less => value.checked_mul(to_scale / from_scale), + Ordering::Greater => Some(value / (from_scale / to_scale)), + Ordering::Equal => Some(value), + } +} + +fn try_cast_binary( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + match (lit_value, target_type) { + (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n)) + if v.len() == *n as usize => + { + Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone()))) + } + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::compute::{cast_with_options, CastOptions}; + use arrow::datatypes::{Field, Fields, TimeUnit}; + use std::sync::Arc; + + #[derive(Debug, Clone)] + enum ExpectedCast { + /// test successfully cast value and it is as specified + Value(ScalarValue), + /// test returned OK, but could not cast the value + NoValue, + } + + /// Runs try_cast_literal_to_type with the specified inputs and + /// ensure it computes the expected output, and ensures the + /// casting is consistent with the Arrow kernels + fn expect_cast( + literal: ScalarValue, + target_type: DataType, + expected_result: ExpectedCast, + ) { + let actual_value = try_cast_literal_to_type(&literal, &target_type); + + println!("expect_cast: "); + println!(" {literal:?} --> {target_type}"); + println!(" expected_result: {expected_result:?}"); + println!(" actual_result: {actual_value:?}"); + + match expected_result { + ExpectedCast::Value(expected_value) => { + let actual_value = + actual_value.expect("Expected cast value but got None"); + + assert_eq!(actual_value, expected_value); + + // Verify that calling the arrow + // cast kernel yields the same results + // input array + let literal_array = literal + .to_array_of_size(1) + .expect("Failed to convert to array of size"); + let expected_array = expected_value + .to_array_of_size(1) + .expect("Failed to convert to array of size"); + let cast_array = cast_with_options( + &literal_array, + &target_type, + &CastOptions::default(), + ) + .expect("Expected to be cast array with arrow cast kernel"); + + assert_eq!( + &expected_array, &cast_array, + "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}" + ); + + // Verify that for timestamp types the timezones are the same + // (ScalarValue::cmp doesn't account for timezones); + if let ( + DataType::Timestamp(left_unit, left_tz), + DataType::Timestamp(right_unit, right_tz), + ) = (actual_value.data_type(), expected_value.data_type()) + { + assert_eq!(left_unit, right_unit); + assert_eq!(left_tz, right_tz); + } + } + ExpectedCast::NoValue => { + assert!( + actual_value.is_none(), + "Expected no cast value, but got {actual_value:?}" + ); + } + } + } + + #[test] + fn test_try_cast_to_type_nulls() { + // test that nulls can be cast to/from all integer types + let scalars = vec![ + ScalarValue::Int8(None), + ScalarValue::Int16(None), + ScalarValue::Int32(None), + ScalarValue::Int64(None), + ScalarValue::UInt8(None), + ScalarValue::UInt16(None), + ScalarValue::UInt32(None), + ScalarValue::UInt64(None), + ScalarValue::Decimal128(None, 3, 0), + ScalarValue::Decimal128(None, 8, 2), + ScalarValue::Utf8(None), + ScalarValue::LargeUtf8(None), + ]; + + for s1 in &scalars { + for s2 in &scalars { + let expected_value = ExpectedCast::Value(s2.clone()); + + expect_cast(s1.clone(), s2.data_type(), expected_value); + } + } + } + + #[test] + fn test_try_cast_to_type_int_in_range() { + // test values that can be cast to/from all integer types + let scalars = vec![ + ScalarValue::Int8(Some(123)), + ScalarValue::Int16(Some(123)), + ScalarValue::Int32(Some(123)), + ScalarValue::Int64(Some(123)), + ScalarValue::UInt8(Some(123)), + ScalarValue::UInt16(Some(123)), + ScalarValue::UInt32(Some(123)), + ScalarValue::UInt64(Some(123)), + ScalarValue::Decimal128(Some(123), 3, 0), + ScalarValue::Decimal128(Some(12300), 8, 2), + ]; + + for s1 in &scalars { + for s2 in &scalars { + let expected_value = ExpectedCast::Value(s2.clone()); + + expect_cast(s1.clone(), s2.data_type(), expected_value); + } + } + + let max_i32 = ScalarValue::Int32(Some(i32::MAX)); + expect_cast( + max_i32, + DataType::UInt64, + ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))), + ); + + let min_i32 = ScalarValue::Int32(Some(i32::MIN)); + expect_cast( + min_i32, + DataType::Int64, + ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))), + ); + + let max_i64 = ScalarValue::Int64(Some(i64::MAX)); + expect_cast( + max_i64, + DataType::UInt64, + ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))), + ); + } + + #[test] + fn test_try_cast_to_type_int_out_of_range() { + let min_i32 = ScalarValue::Int32(Some(i32::MIN)); + let min_i64 = ScalarValue::Int64(Some(i64::MIN)); + let max_i64 = ScalarValue::Int64(Some(i64::MAX)); + let max_u64 = ScalarValue::UInt64(Some(u64::MAX)); + + expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue); + + expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue); + + expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue); + + expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue); + + expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue); + + expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue); + + // decimal out of range + expect_cast( + ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0), + DataType::Int64, + ExpectedCast::NoValue, + ); + + expect_cast( + ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1), + DataType::Int64, + ExpectedCast::NoValue, + ); + } + + #[test] + fn test_try_decimal_cast_in_range() { + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(3, 0), + ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)), + ); + + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(8, 0), + ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)), + ); + + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(8, 5), + ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)), + ); + } + + #[test] + fn test_try_decimal_cast_out_of_range() { + // decimal would lose precision + expect_cast( + ScalarValue::Decimal128(Some(12345), 5, 2), + DataType::Decimal128(3, 0), + ExpectedCast::NoValue, + ); + + // decimal would lose precision + expect_cast( + ScalarValue::Decimal128(Some(12300), 5, 2), + DataType::Decimal128(2, 0), + ExpectedCast::NoValue, + ); + } + + #[test] + fn test_try_cast_to_type_timestamps() { + for time_unit in [ + TimeUnit::Second, + TimeUnit::Millisecond, + TimeUnit::Microsecond, + TimeUnit::Nanosecond, + ] { + let utc = Some("+00:00".into()); + // No timezone, utc timezone + let (lit_tz_none, lit_tz_utc) = match time_unit { + TimeUnit::Second => ( + ScalarValue::TimestampSecond(Some(12345), None), + ScalarValue::TimestampSecond(Some(12345), utc), + ), + + TimeUnit::Millisecond => ( + ScalarValue::TimestampMillisecond(Some(12345), None), + ScalarValue::TimestampMillisecond(Some(12345), utc), + ), + + TimeUnit::Microsecond => ( + ScalarValue::TimestampMicrosecond(Some(12345), None), + ScalarValue::TimestampMicrosecond(Some(12345), utc), + ), + + TimeUnit::Nanosecond => ( + ScalarValue::TimestampNanosecond(Some(12345), None), + ScalarValue::TimestampNanosecond(Some(12345), utc), + ), + }; + + // DataFusion ignores timezones for comparisons of ScalarValue + // so double check it here + assert_eq!(lit_tz_none, lit_tz_utc); + + // e.g. DataType::Timestamp(_, None) + let dt_tz_none = lit_tz_none.data_type(); + + // e.g. DataType::Timestamp(_, Some(utc)) + let dt_tz_utc = lit_tz_utc.data_type(); + + // None <--> None + expect_cast( + lit_tz_none.clone(), + dt_tz_none.clone(), + ExpectedCast::Value(lit_tz_none.clone()), + ); + + // None <--> Utc + expect_cast( + lit_tz_none.clone(), + dt_tz_utc.clone(), + ExpectedCast::Value(lit_tz_utc.clone()), + ); + + // Utc <--> None + expect_cast( + lit_tz_utc.clone(), + dt_tz_none.clone(), + ExpectedCast::Value(lit_tz_none.clone()), + ); + + // Utc <--> Utc + expect_cast( + lit_tz_utc.clone(), + dt_tz_utc.clone(), + ExpectedCast::Value(lit_tz_utc.clone()), + ); + + // timestamp to int64 + expect_cast( + lit_tz_utc.clone(), + DataType::Int64, + ExpectedCast::Value(ScalarValue::Int64(Some(12345))), + ); + + // int64 to timestamp + expect_cast( + ScalarValue::Int64(Some(12345)), + dt_tz_none.clone(), + ExpectedCast::Value(lit_tz_none.clone()), + ); + + // int64 to timestamp + expect_cast( + ScalarValue::Int64(Some(12345)), + dt_tz_utc.clone(), + ExpectedCast::Value(lit_tz_utc.clone()), + ); + + // timestamp to string (not supported yet) + expect_cast( + lit_tz_utc.clone(), + DataType::LargeUtf8, + ExpectedCast::NoValue, + ); + } + } + + #[test] + fn test_try_cast_to_type_unsupported() { + // int64 to list + expect_cast( + ScalarValue::Int64(Some(12345)), + DataType::List(Arc::new(Field::new("f", DataType::Int32, true))), + ExpectedCast::NoValue, + ); + } + + #[test] + fn test_try_cast_literal_to_timestamp() { + // same timestamp + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123456), None) + ); + + // TimestampNanosecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123), None) + ); + + // TimestampNanosecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); + + // TimestampNanosecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampNanosecond(Some(123456), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None)); + + // TimestampMicrosecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap(); + + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000), None) + ); + + // TimestampMicrosecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap(); + + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); + + // TimestampMicrosecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMicrosecond(Some(123456789), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None)); + + // TimestampMillisecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000000), None) + ); + + // TimestampMillisecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123000), None) + ); + // TimestampMillisecond to TimestampSecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampMillisecond(Some(123456789), None), + &DataType::Timestamp(TimeUnit::Second, None), + ) + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None)); + + // TimestampSecond to TimestampNanosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Nanosecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampNanosecond(Some(123000000000), None) + ); + + // TimestampSecond to TimestampMicrosecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Microsecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMicrosecond(Some(123000000), None) + ); + + // TimestampSecond to TimestampMillisecond + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(123), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap(); + assert_eq!( + new_scalar, + ScalarValue::TimestampMillisecond(Some(123000), None) + ); + + // overflow + let new_scalar = try_cast_literal_to_type( + &ScalarValue::TimestampSecond(Some(i64::MAX), None), + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .unwrap(); + assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None)); + } + + #[test] + fn test_try_cast_to_string_type() { + let scalars = vec![ + ScalarValue::from("string"), + ScalarValue::LargeUtf8(Some("string".to_owned())), + ]; + + for s1 in &scalars { + for s2 in &scalars { + let expected_value = ExpectedCast::Value(s2.clone()); + + expect_cast(s1.clone(), s2.data_type(), expected_value); + } + } + } + + #[test] + fn test_try_cast_to_dictionary_type() { + fn dictionary_type(t: DataType) -> DataType { + DataType::Dictionary(Box::new(DataType::Int32), Box::new(t)) + } + fn dictionary_value(value: ScalarValue) -> ScalarValue { + ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value)) + } + let scalars = vec![ + ScalarValue::from("string"), + ScalarValue::LargeUtf8(Some("string".to_owned())), + ]; + for s in &scalars { + expect_cast( + s.clone(), + dictionary_type(s.data_type()), + ExpectedCast::Value(dictionary_value(s.clone())), + ); + expect_cast( + dictionary_value(s.clone()), + s.data_type(), + ExpectedCast::Value(s.clone()), + ) + } + } + + #[test] + fn test_try_cast_to_fixed_size_binary() { + expect_cast( + ScalarValue::Binary(Some(vec![1, 2, 3])), + DataType::FixedSizeBinary(3), + ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))), + ) + } + + #[test] + fn test_numeric_boundary_values() { + // Test exact boundary values for signed integers + expect_cast( + ScalarValue::Int8(Some(i8::MAX)), + DataType::UInt8, + ExpectedCast::Value(ScalarValue::UInt8(Some(i8::MAX as u8))), + ); + + expect_cast( + ScalarValue::Int8(Some(i8::MIN)), + DataType::UInt8, + ExpectedCast::NoValue, + ); + + expect_cast( + ScalarValue::UInt8(Some(u8::MAX)), + DataType::Int8, + ExpectedCast::NoValue, + ); + + // Test cross-type boundary scenarios + expect_cast( + ScalarValue::Int32(Some(i32::MAX)), + DataType::Int64, + ExpectedCast::Value(ScalarValue::Int64(Some(i32::MAX as i64))), + ); + + expect_cast( + ScalarValue::Int64(Some(i64::MIN)), + DataType::UInt64, + ExpectedCast::NoValue, + ); + + // Test unsigned to signed edge cases + expect_cast( + ScalarValue::UInt32(Some(u32::MAX)), + DataType::Int32, + ExpectedCast::NoValue, + ); + + expect_cast( + ScalarValue::UInt64(Some(u64::MAX)), + DataType::Int64, + ExpectedCast::NoValue, + ); + } + + #[test] + fn test_decimal_precision_limits() { + use arrow::datatypes::{ + MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION, + }; + + // Test maximum precision values + expect_cast( + ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0), + DataType::Decimal128(5, 0), + ExpectedCast::Value(ScalarValue::Decimal128( + Some(MAX_DECIMAL128_FOR_EACH_PRECISION[3]), + 5, + 0, + )), + ); + + // Test minimum precision values + expect_cast( + ScalarValue::Decimal128(Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]), 3, 0), + DataType::Decimal128(5, 0), + ExpectedCast::Value(ScalarValue::Decimal128( + Some(MIN_DECIMAL128_FOR_EACH_PRECISION[3]), + 5, + 0, + )), + ); + + // Test scale increase + expect_cast( + ScalarValue::Decimal128(Some(123), 3, 0), + DataType::Decimal128(5, 2), + ExpectedCast::Value(ScalarValue::Decimal128(Some(12300), 5, 2)), + ); + + // Test precision overflow (value too large for target precision) + expect_cast( + ScalarValue::Decimal128(Some(MAX_DECIMAL128_FOR_EACH_PRECISION[10]), 10, 0), + DataType::Decimal128(3, 0), + ExpectedCast::NoValue, + ); + + // Test non-divisible decimal conversion (should fail) + expect_cast( + ScalarValue::Decimal128(Some(12345), 5, 3), // 12.345 + DataType::Int32, + ExpectedCast::NoValue, // Can't convert 12.345 to integer without loss + ); + + // Test edge case: scale reduction with precision loss + expect_cast( + ScalarValue::Decimal128(Some(12345), 5, 2), // 123.45 + DataType::Decimal128(3, 0), // Can only hold up to 999 + ExpectedCast::NoValue, + ); + } + + #[test] + fn test_timestamp_overflow_scenarios() { + // Test overflow in timestamp conversions + let max_seconds = i64::MAX / 1_000_000_000; // Avoid overflow when converting to nanos + + // This should work - within safe range + expect_cast( + ScalarValue::TimestampSecond(Some(max_seconds), None), + DataType::Timestamp(TimeUnit::Nanosecond, None), + ExpectedCast::Value(ScalarValue::TimestampNanosecond( + Some(max_seconds * 1_000_000_000), + None, + )), + ); + + // Test very large nanosecond value conversion to smaller units + expect_cast( + ScalarValue::TimestampNanosecond(Some(i64::MAX), None), + DataType::Timestamp(TimeUnit::Second, None), + ExpectedCast::Value(ScalarValue::TimestampSecond( + Some(i64::MAX / 1_000_000_000), + None, + )), + ); + + // Test precision loss in downscaling + expect_cast( + ScalarValue::TimestampNanosecond(Some(1), None), + DataType::Timestamp(TimeUnit::Second, None), + ExpectedCast::Value(ScalarValue::TimestampSecond(Some(0), None)), + ); + + expect_cast( + ScalarValue::TimestampMicrosecond(Some(999), None), + DataType::Timestamp(TimeUnit::Millisecond, None), + ExpectedCast::Value(ScalarValue::TimestampMillisecond(Some(0), None)), + ); + } + + #[test] + fn test_string_view() { + // Test Utf8View to other string types + expect_cast( + ScalarValue::Utf8View(Some("test".to_string())), + DataType::Utf8, + ExpectedCast::Value(ScalarValue::Utf8(Some("test".to_string()))), + ); + + expect_cast( + ScalarValue::Utf8View(Some("test".to_string())), + DataType::LargeUtf8, + ExpectedCast::Value(ScalarValue::LargeUtf8(Some("test".to_string()))), + ); + + // Test other string types to Utf8View + expect_cast( + ScalarValue::Utf8(Some("hello".to_string())), + DataType::Utf8View, + ExpectedCast::Value(ScalarValue::Utf8View(Some("hello".to_string()))), + ); + + expect_cast( + ScalarValue::LargeUtf8(Some("world".to_string())), + DataType::Utf8View, + ExpectedCast::Value(ScalarValue::Utf8View(Some("world".to_string()))), + ); + + // Test empty string + expect_cast( + ScalarValue::Utf8(Some("".to_string())), + DataType::Utf8View, + ExpectedCast::Value(ScalarValue::Utf8View(Some("".to_string()))), + ); + + // Test large string + let large_string = "x".repeat(1000); + expect_cast( + ScalarValue::LargeUtf8(Some(large_string.clone())), + DataType::Utf8View, + ExpectedCast::Value(ScalarValue::Utf8View(Some(large_string))), + ); + } + + #[test] + fn test_binary_size_edge_cases() { + // Test size mismatch - too small + expect_cast( + ScalarValue::Binary(Some(vec![1, 2])), + DataType::FixedSizeBinary(3), + ExpectedCast::NoValue, + ); + + // Test size mismatch - too large + expect_cast( + ScalarValue::Binary(Some(vec![1, 2, 3, 4])), + DataType::FixedSizeBinary(3), + ExpectedCast::NoValue, + ); + + // Test empty binary + expect_cast( + ScalarValue::Binary(Some(vec![])), + DataType::FixedSizeBinary(0), + ExpectedCast::Value(ScalarValue::FixedSizeBinary(0, Some(vec![]))), + ); + + // Test exact size match + expect_cast( + ScalarValue::Binary(Some(vec![1, 2, 3])), + DataType::FixedSizeBinary(3), + ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))), + ); + + // Test single byte + expect_cast( + ScalarValue::Binary(Some(vec![42])), + DataType::FixedSizeBinary(1), + ExpectedCast::Value(ScalarValue::FixedSizeBinary(1, Some(vec![42]))), + ); + } + + #[test] + fn test_dictionary_index_types() { + // Test different dictionary index types + let string_value = ScalarValue::Utf8(Some("test".to_string())); + + // Int8 index dictionary + let dict_int8 = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + expect_cast( + string_value.clone(), + dict_int8, + ExpectedCast::Value(ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(string_value.clone()), + )), + ); + + // Int16 index dictionary + let dict_int16 = + DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)); + expect_cast( + string_value.clone(), + dict_int16, + ExpectedCast::Value(ScalarValue::Dictionary( + Box::new(DataType::Int16), + Box::new(string_value.clone()), + )), + ); + + // Int64 index dictionary + let dict_int64 = + DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Utf8)); + expect_cast( + string_value.clone(), + dict_int64, + ExpectedCast::Value(ScalarValue::Dictionary( + Box::new(DataType::Int64), + Box::new(string_value.clone()), + )), + ); + + // Test dictionary unwrapping + let dict_value = ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))), + ); + expect_cast( + dict_value, + DataType::LargeUtf8, + ExpectedCast::Value(ScalarValue::LargeUtf8(Some("unwrap_test".to_string()))), + ); + } + + #[test] + fn test_type_support_functions() { + // Test numeric type support + assert!(is_supported_numeric_type(&DataType::Int8)); + assert!(is_supported_numeric_type(&DataType::UInt64)); + assert!(is_supported_numeric_type(&DataType::Decimal128(10, 2))); + assert!(is_supported_numeric_type(&DataType::Timestamp( + TimeUnit::Nanosecond, + None + ))); + assert!(!is_supported_numeric_type(&DataType::Float32)); + assert!(!is_supported_numeric_type(&DataType::Float64)); + + // Test string type support + assert!(is_supported_string_type(&DataType::Utf8)); + assert!(is_supported_string_type(&DataType::LargeUtf8)); + assert!(is_supported_string_type(&DataType::Utf8View)); + assert!(!is_supported_string_type(&DataType::Binary)); + + // Test binary type support + assert!(is_supported_binary_type(&DataType::Binary)); + assert!(is_supported_binary_type(&DataType::FixedSizeBinary(10))); + assert!(!is_supported_binary_type(&DataType::Utf8)); + + // Test dictionary type support with nested types + assert!(is_supported_dictionary_type(&DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8) + ))); + assert!(is_supported_dictionary_type(&DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Int64) + ))); + assert!(!is_supported_dictionary_type(&DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + true + )))) + ))); + + // Test overall type support + assert!(is_supported_type(&DataType::Int32)); + assert!(is_supported_type(&DataType::Utf8)); + assert!(is_supported_type(&DataType::Binary)); + assert!(is_supported_type(&DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8) + ))); + assert!(!is_supported_type(&DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + true + ))))); + assert!(!is_supported_type(&DataType::Struct(Fields::empty()))); + } + + #[test] + fn test_error_conditions() { + // Test unsupported source type + expect_cast( + ScalarValue::Float32(Some(1.5)), + DataType::Int32, + ExpectedCast::NoValue, + ); + + // Test unsupported target type + expect_cast( + ScalarValue::Int32(Some(123)), + DataType::Float64, + ExpectedCast::NoValue, + ); + + // Test both types unsupported + expect_cast( + ScalarValue::Float64(Some(1.5)), + DataType::Float32, + ExpectedCast::NoValue, + ); + + // Test complex unsupported types + let list_type = + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + expect_cast( + ScalarValue::Int32(Some(123)), + list_type, + ExpectedCast::NoValue, + ); + + // Test dictionary with unsupported inner type + let bad_dict = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + expect_cast( + ScalarValue::Int32(Some(123)), + bad_dict, + ExpectedCast::NoValue, + ); + } +} diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index cb7cbdbac291d..a21ad5bbbcc30 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -237,7 +237,7 @@ impl fmt::Display for ColumnarValue { }; if let Ok(formatted) = formatted { - write!(f, "{}", formatted) + write!(f, "{formatted}") } else { write!(f, "Error formatting columnar value") } diff --git a/datafusion/expr-common/src/dyn_eq.rs b/datafusion/expr-common/src/dyn_eq.rs new file mode 100644 index 0000000000000..e0ebcae4879d6 --- /dev/null +++ b/datafusion/expr-common/src/dyn_eq.rs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::hash::{Hash, Hasher}; + +/// A dyn-compatible version of [`Eq`] trait. +/// The implementation constraints for this trait are the same as for [`Eq`]: +/// the implementation must be reflexive, symmetric, and transitive. +/// Additionally, if two values can be compared with [`DynEq`] and [`PartialEq`] then +/// they must be [`DynEq`]-equal if and only if they are [`PartialEq`]-equal. +/// It is therefore strongly discouraged to implement this trait for types +/// that implement `PartialEq` or `Eq` for any type `Other` other than `Self`. +/// +/// Note: This trait should not be implemented directly. Implement `Eq` and `Any` and use +/// the blanket implementation. +#[allow(private_bounds)] +pub trait DynEq: private::EqSealed { + fn dyn_eq(&self, other: &dyn Any) -> bool; +} + +impl private::EqSealed for T {} +impl DynEq for T { + fn dyn_eq(&self, other: &dyn Any) -> bool { + other.downcast_ref::() == Some(self) + } +} + +/// A dyn-compatible version of [`Hash`] trait. +/// If two values are equal according to [`DynEq`], they must produce the same hash value. +/// +/// Note: This trait should not be implemented directly. Implement `Hash` and `Any` and use +/// the blanket implementation. +#[allow(private_bounds)] +pub trait DynHash: private::HashSealed { + fn dyn_hash(&self, _state: &mut dyn Hasher); +} + +impl private::HashSealed for T {} +impl DynHash for T { + fn dyn_hash(&self, mut state: &mut dyn Hasher) { + self.type_id().hash(&mut state); + self.hash(&mut state) + } +} + +mod private { + pub(super) trait EqSealed {} + pub(super) trait HashSealed {} +} diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 5ff1c1d072164..9bcc1edff8824 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -21,7 +21,7 @@ use arrow::array::{ArrayRef, BooleanArray}; use datafusion_common::{not_impl_err, Result}; /// Describes how many rows should be emitted during grouping. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum EmitTo { /// Emit all groups All, diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 9d00b45962bc2..b5b632076b006 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -174,7 +174,7 @@ macro_rules! value_transition { /// - `INF` values are converted to `NULL`s while constructing an interval to /// ensure consistency, with other data types. /// - `NaN` (Not a Number) results are conservatively result in unbounded -/// endpoints. +/// endpoints. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Interval { lower: ScalarValue, @@ -606,7 +606,7 @@ impl Interval { upper: ScalarValue::Boolean(Some(upper)), }) } - _ => internal_err!("Incompatible data types for logical conjunction"), + _ => internal_err!("Incompatible data types for logical disjunction"), } } @@ -754,6 +754,17 @@ impl Interval { } } + /// Decide if this interval is a superset of `other`. If argument `strict` + /// is `true`, only returns `true` if this interval is a strict superset. + /// + /// NOTE: This function only works with intervals of the same data type. + /// Attempting to compare intervals of different data types will lead + /// to an error. + pub fn is_superset(&self, other: &Interval, strict: bool) -> Result { + Ok(!(strict && self.eq(other)) + && (self.contains(other)? == Interval::CERTAINLY_TRUE)) + } + /// Add the given interval (`other`) to this interval. Say we have intervals /// `[a1, b1]` and `[a2, b2]`, then their sum is `[a1 + a2, b1 + b2]`. Note /// that this represents all possible values the sum can take if one can @@ -949,6 +960,18 @@ impl Display for Interval { } } +impl From for Interval { + fn from(value: ScalarValue) -> Self { + Self::new(value.clone(), value) + } +} + +impl From<&ScalarValue> for Interval { + fn from(value: &ScalarValue) -> Self { + Self::new(value.to_owned(), value.to_owned()) + } +} + /// Applies the given binary operator the `lhs` and `rhs` arguments. pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { match *op { @@ -959,6 +982,7 @@ pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result lhs.lt(rhs), Operator::LtEq => lhs.lt_eq(rhs), Operator::And => lhs.and(rhs), + Operator::Or => lhs.or(rhs), Operator::Plus => lhs.add(rhs), Operator::Minus => lhs.sub(rhs), Operator::Multiply => lhs.mul(rhs), @@ -1683,9 +1707,9 @@ impl Display for NullableInterval { match self { Self::Null { .. } => write!(f, "NullableInterval: {{NULL}}"), Self::MaybeNull { values } => { - write!(f, "NullableInterval: {} U {{NULL}}", values) + write!(f, "NullableInterval: {values} U {{NULL}}") } - Self::NotNull { values } => write!(f, "NullableInterval: {}", values), + Self::NotNull { values } => write!(f, "NullableInterval: {values}"), } } } @@ -2706,8 +2730,8 @@ mod tests { ), ]; for (first, second, expected) in possible_cases { - println!("{}", first); - println!("{}", second); + println!("{first}"); + println!("{second}"); assert_eq!(first.union(second)?, expected) } @@ -3674,6 +3698,76 @@ mod tests { Interval::make(Some(-500.0_f64), Some(1000.0_f64))?, Interval::make(Some(-500.0_f64), Some(500.0_f64))?, ), + ( + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make(Some(-0_i64), Some(0_i64))?, + true, + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make(Some(-0_i64), Some(0_i64))?, + ), + ( + Interval::make(Some(-0_i64), Some(0_i64))?, + Interval::make(Some(-0_i64), Some(-0_i64))?, + true, + Interval::make(Some(-0_i64), Some(0_i64))?, + Interval::make(Some(-0_i64), Some(-0_i64))?, + ), + ( + Interval::make(Some(0.0_f64), Some(0.0_f64))?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + true, + Interval::make(Some(0.0_f64), Some(0.0_f64))?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + ), + ( + Interval::make(Some(0.0_f64), Some(0.0_f64))?, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + false, + Interval::make(Some(0.0_f64), Some(0.0_f64))?, + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + true, + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + ), + ( + Interval::make(Some(-0.0_f64), Some(0.0_f64))?, + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + false, + Interval::make(Some(0.0_f64), Some(0.0_f64))?, + Interval::make(Some(-0.0_f64), Some(-0.0_f64))?, + ), + ( + Interval::make(Some(0_i64), None)?, + Interval::make(Some(-0_i64), None)?, + true, + Interval::make(Some(0_i64), None)?, + Interval::make(Some(-0_i64), None)?, + ), + ( + Interval::make(Some(0_i64), None)?, + Interval::make(Some(-0_i64), None)?, + false, + Interval::make(Some(1_i64), None)?, + Interval::make(Some(-0_i64), None)?, + ), + ( + Interval::make(Some(0.0_f64), None)?, + Interval::make(Some(-0.0_f64), None)?, + true, + Interval::make(Some(0.0_f64), None)?, + Interval::make(Some(-0.0_f64), None)?, + ), + ( + Interval::make(Some(0.0_f64), None)?, + Interval::make(Some(-0.0_f64), None)?, + false, + Interval::make(Some(0.0_f64), None)?, + Interval::make(Some(-0.0_f64), None)?, + ), ]; for (first, second, includes_endpoints, left_modified, right_modified) in cases { assert_eq!( @@ -3693,6 +3787,16 @@ mod tests { Interval::make(Some(1500.0_f32), Some(2000.0_f32))?, false, ), + ( + Interval::make(Some(0_i64), Some(0_i64))?, + Interval::make(Some(-0_i64), Some(0_i64))?, + false, + ), + ( + Interval::make(Some(-0_i64), Some(0_i64))?, + Interval::make(Some(-0_i64), Some(-0_i64))?, + false, + ), ]; for (first, second, includes_endpoints) in infeasible_cases { assert_eq!(satisfy_greater(&first, &second, !includes_endpoints)?, None); @@ -3704,14 +3808,14 @@ mod tests { #[test] fn test_interval_display() { let interval = Interval::make(Some(0.25_f32), Some(0.50_f32)).unwrap(); - assert_eq!(format!("{}", interval), "[0.25, 0.5]"); + assert_eq!(format!("{interval}"), "[0.25, 0.5]"); let interval = Interval::try_new( ScalarValue::Float32(Some(f32::NEG_INFINITY)), ScalarValue::Float32(Some(f32::INFINITY)), ) .unwrap(); - assert_eq!(format!("{}", interval), "[NULL, NULL]"); + assert_eq!(format!("{interval}"), "[NULL, NULL]"); } macro_rules! capture_mode_change { @@ -3792,4 +3896,138 @@ mod tests { let upper = 1.5; capture_mode_change_f32((lower, upper), true, true); } + + #[test] + fn test_is_superset() -> Result<()> { + // Test cases: (interval1, interval2, strict, expected) + let test_cases = vec![ + // Equal intervals - non-strict should be true, strict should be false + ( + Interval::make(Some(10_i32), Some(50_i32))?, + Interval::make(Some(10_i32), Some(50_i32))?, + false, + true, + ), + ( + Interval::make(Some(10_i32), Some(50_i32))?, + Interval::make(Some(10_i32), Some(50_i32))?, + true, + false, + ), + // Unbounded intervals + ( + Interval::make::(None, None)?, + Interval::make(Some(10_i32), Some(50_i32))?, + false, + true, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + false, + true, + ), + ( + Interval::make::(None, None)?, + Interval::make::(None, None)?, + true, + false, + ), + // Half-bounded intervals + ( + Interval::make(Some(0_i32), None)?, + Interval::make(Some(10_i32), Some(50_i32))?, + false, + true, + ), + ( + Interval::make(None, Some(100_i32))?, + Interval::make(Some(10_i32), Some(50_i32))?, + false, + true, + ), + // Non-superset cases - partial overlap + ( + Interval::make(Some(0_i32), Some(50_i32))?, + Interval::make(Some(25_i32), Some(75_i32))?, + false, + false, + ), + ( + Interval::make(Some(0_i32), Some(50_i32))?, + Interval::make(Some(25_i32), Some(75_i32))?, + true, + false, + ), + // Non-superset cases - disjoint intervals + ( + Interval::make(Some(0_i32), Some(50_i32))?, + Interval::make(Some(60_i32), Some(100_i32))?, + false, + false, + ), + // Subset relationship (reversed) + ( + Interval::make(Some(20_i32), Some(80_i32))?, + Interval::make(Some(0_i32), Some(100_i32))?, + false, + false, + ), + // Float cases + ( + Interval::make(Some(0.0_f32), Some(100.0_f32))?, + Interval::make(Some(25.5_f32), Some(75.5_f32))?, + false, + true, + ), + ( + Interval::make(Some(0.0_f64), Some(100.0_f64))?, + Interval::make(Some(0.0_f64), Some(100.0_f64))?, + true, + false, + ), + // Edge cases with single point intervals + ( + Interval::make(Some(0_i32), Some(100_i32))?, + Interval::make(Some(50_i32), Some(50_i32))?, + false, + true, + ), + ( + Interval::make(Some(50_i32), Some(50_i32))?, + Interval::make(Some(50_i32), Some(50_i32))?, + false, + true, + ), + ( + Interval::make(Some(50_i32), Some(50_i32))?, + Interval::make(Some(50_i32), Some(50_i32))?, + true, + false, + ), + // Boundary touch cases + ( + Interval::make(Some(0_i32), Some(50_i32))?, + Interval::make(Some(0_i32), Some(25_i32))?, + false, + true, + ), + ( + Interval::make(Some(0_i32), Some(50_i32))?, + Interval::make(Some(25_i32), Some(50_i32))?, + false, + true, + ), + ]; + + for (interval1, interval2, strict, expected) in test_cases { + let result = interval1.is_superset(&interval2, strict)?; + assert_eq!( + result, expected, + "Failed for interval1: {interval1}, interval2: {interval2}, strict: {strict}", + ); + } + + Ok(()) + } } diff --git a/datafusion/expr-common/src/lib.rs b/datafusion/expr-common/src/lib.rs index 961670a3b7f45..a4f6414a8c51d 100644 --- a/datafusion/expr-common/src/lib.rs +++ b/datafusion/expr-common/src/lib.rs @@ -27,13 +27,15 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] pub mod accumulator; +pub mod casts; pub mod columnar_value; +pub mod dyn_eq; pub mod groups_accumulator; pub mod interval_arithmetic; pub mod operator; diff --git a/datafusion/expr-common/src/operator.rs b/datafusion/expr-common/src/operator.rs index 19fc6b80745e2..33512b0c354d6 100644 --- a/datafusion/expr-common/src/operator.rs +++ b/datafusion/expr-common/src/operator.rs @@ -229,15 +229,6 @@ impl Operator { ) } - /// Return true if the comparison operator can be used in interval arithmetic and constraint - /// propagation - /// - /// For example, 'Binary(a, >, b)' expression supports propagation. - #[deprecated(since = "43.0.0", note = "please use `supports_propagation` instead")] - pub fn is_comparison_operator(&self) -> bool { - self.supports_propagation() - } - /// Return true if the operator is a logic operator. /// /// For example, 'Binary(Binary(a, >, b), AND, Binary(a, <, b + 3))' would @@ -337,6 +328,60 @@ impl Operator { Operator::Multiply | Operator::Divide | Operator::Modulo => 45, } } + + /// Returns true if the `Expr::BinaryOperator` with this operator + /// is guaranteed to return null if either side is null. + pub fn returns_null_on_null(&self) -> bool { + match self { + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + | Operator::Plus + | Operator::Minus + | Operator::Multiply + | Operator::Divide + | Operator::Modulo + | Operator::RegexMatch + | Operator::RegexIMatch + | Operator::RegexNotMatch + | Operator::RegexNotIMatch + | Operator::LikeMatch + | Operator::ILikeMatch + | Operator::NotLikeMatch + | Operator::NotILikeMatch + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor + | Operator::BitwiseShiftRight + | Operator::BitwiseShiftLeft + | Operator::AtArrow + | Operator::ArrowAt + | Operator::Arrow + | Operator::LongArrow + | Operator::HashArrow + | Operator::HashLongArrow + | Operator::AtAt + | Operator::IntegerDivide + | Operator::HashMinus + | Operator::AtQuestion + | Operator::Question + | Operator::QuestionAnd + | Operator::QuestionPipe => true, + + // E.g. `TRUE OR NULL` is `TRUE` + Operator::Or + // E.g. `FALSE AND NULL` is `FALSE` + | Operator::And + // IS DISTINCT FROM and IS NOT DISTINCT FROM always return a TRUE/FALSE value, never NULL + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom + // DataFusion string concatenation operator treats NULL as an empty string + | Operator::StringConcat => false, + } + } } impl fmt::Display for Operator { diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 063417a254be3..5fd4518e2e57f 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Signature module contains foundational types that are used to represent signatures, types, -//! and return types of functions in DataFusion. +//! Function signatures: [`Volatility`], [`Signature`] and [`TypeSignature`] use std::fmt::Display; use std::hash::Hash; @@ -44,42 +43,89 @@ pub const TIMEZONE_WILDCARD: &str = "+TZ"; /// valid length. It exists to avoid the need to enumerate all possible fixed size list lengths. pub const FIXED_SIZE_LIST_WILDCARD: i32 = i32::MIN; -/// A function's volatility, which defines the functions eligibility for certain optimizations +/// How a function's output changes with respect to a fixed input +/// +/// The volatility of a function determines eligibility for certain +/// optimizations. You should always define your function to have the strictest +/// possible volatility to maximize performance and avoid unexpected +/// results. #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] pub enum Volatility { - /// An immutable function will always return the same output when given the same - /// input. DataFusion will attempt to inline immutable functions during planning. + /// Always returns the same output when given the same input. + /// + /// DataFusion will inline immutable functions during planning. + /// + /// For example, the `abs` function is immutable, so `abs(-1)` will be + /// evaluated and replaced with `1` during planning rather than invoking + /// the function at runtime. Immutable, - /// A stable function may return different values given the same input across different - /// queries but must return the same value for a given input within a query. An example of - /// this is the `Now` function. DataFusion will attempt to inline `Stable` functions - /// during planning, when possible. - /// For query `select col1, now() from t1`, it might take a while to execute but - /// `now()` column will be the same for each output row, which is evaluated - /// during planning. + /// May return different values given the same input across different + /// queries but must return the same value for a given input within a query. + /// + /// For example, the `now()` function is stable, because the query `select + /// col1, now() from t1`, will return different results each time it is run, + /// but within the same query, the output of the `now()` function has the + /// same value for each output row. + /// + /// DataFusion will inline `Stable` functions when possible. For example, + /// `Stable` functions are inlined when planning a query for execution, but + /// not in View definitions or prepared statements. Stable, - /// A volatile function may change the return value from evaluation to evaluation. - /// Multiple invocations of a volatile function may return different results when used in the - /// same query. An example of this is the random() function. DataFusion - /// can not evaluate such functions during planning. - /// In the query `select col1, random() from t1`, `random()` function will be evaluated - /// for each output row, resulting in a unique random value for each row. + /// May change the return value from evaluation to evaluation. + /// + /// Multiple invocations of a volatile function may return different results + /// when used in the same query on different rows. An example of this is the + /// `random()` function. + /// + /// DataFusion can not evaluate such functions during planning or push these + /// predicates into scans. In the query `select col1, random() from t1`, + /// `random()` function will be evaluated for each output row, resulting in + /// a unique random value for each row. Volatile, } -/// A function's type signature defines the types of arguments the function supports. +/// The types of arguments for which a function has implementations. +/// +/// [`TypeSignature`] **DOES NOT** define the types that a user query could call the +/// function with. DataFusion will automatically coerce (cast) argument types to +/// one of the supported function signatures, if possible. /// -/// Functions typically support only a few different types of arguments compared to the -/// different datatypes in Arrow. To make functions easy to use, when possible DataFusion -/// automatically coerces (add casts to) function arguments so they match the type signature. +/// # Overview +/// Functions typically provide implementations for a small number of different +/// argument [`DataType`]s, rather than all possible combinations. If a user +/// calls a function with arguments that do not match any of the declared types, +/// DataFusion will attempt to automatically coerce (add casts to) function +/// arguments so they match the [`TypeSignature`]. See the [`type_coercion`] module +/// for more details /// -/// For example, a function like `cos` may only be implemented for `Float64` arguments. To support a query -/// that calls `cos` with a different argument type, such as `cos(int_column)`, type coercion automatically -/// adds a cast such as `cos(CAST int_column AS DOUBLE)` during planning. +/// # Example: Numeric Functions +/// For example, a function like `cos` may only provide an implementation for +/// [`DataType::Float64`]. When users call `cos` with a different argument type, +/// such as `cos(int_column)`, and type coercion automatically adds a cast such +/// as `cos(CAST int_column AS DOUBLE)` during planning. /// -/// # Data Types +/// [`type_coercion`]: crate::type_coercion /// -/// ## Timestamps +/// ## Example: Strings +/// +/// There are several different string types in Arrow, such as +/// [`DataType::Utf8`], [`DataType::LargeUtf8`], and [`DataType::Utf8View`]. +/// +/// Some functions may have specialized implementations for these types, while others +/// may be able to handle only one of them. For example, a function that +/// only works with [`DataType::Utf8View`] would have the following signature: +/// +/// ``` +/// # use arrow::datatypes::DataType; +/// # use datafusion_expr_common::signature::{TypeSignature}; +/// // Declares the function must be invoked with a single argument of type `Utf8View`. +/// // if a user calls the function with `Utf8` or `LargeUtf8`, DataFusion will +/// // automatically add a cast to `Utf8View` during planning. +/// let type_signature = TypeSignature::Exact(vec![DataType::Utf8View]); +/// +/// ``` +/// +/// # Example: Timestamps /// /// Types to match are represented using Arrow's [`DataType`]. [`DataType::Timestamp`] has an optional variable /// timezone specification. To specify a function can handle a timestamp with *ANY* timezone, use @@ -130,8 +176,9 @@ pub enum TypeSignature { Exact(Vec), /// One or more arguments belonging to the [`TypeSignatureClass`], in order. /// - /// [`Coercion`] contains not only the desired type but also the allowed casts. - /// For example, if you expect a function has string type, but you also allow it to be casted from binary type. + /// [`Coercion`] contains not only the desired type but also the allowed + /// casts. For example, if you expect a function has string type, but you + /// also allow it to be casted from binary type. /// /// For functions that take no arguments (e.g. `random()`) see [`TypeSignature::Nullary`]. Coercible(Vec), @@ -170,7 +217,7 @@ pub enum TypeSignature { OneOf(Vec), /// A function that has an [`ArrayFunctionSignature`] ArraySignature(ArrayFunctionSignature), - /// One or more arguments of numeric types. + /// One or more arguments of numeric types, coerced to a common numeric type. /// /// See [`NativeType::is_numeric`] to know which type is considered numeric /// @@ -206,7 +253,7 @@ impl TypeSignature { /// just listing specific DataTypes. For example, TypeSignatureClass::Timestamp matches any timestamp /// type regardless of timezone or precision. /// -/// Used primarily with TypeSignature::Coercible to define function signatures that can accept +/// Used primarily with [`TypeSignature::Coercible`] to define function signatures that can accept /// arguments that can be coerced to a particular class of types. #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Hash)] pub enum TypeSignatureClass { @@ -218,6 +265,8 @@ pub enum TypeSignatureClass { // TODO: // Numeric Integer, + /// Encompasses both the native Binary as well as arbitrarily sized FixedSizeBinary types + Binary, } impl Display for TypeSignatureClass { @@ -255,6 +304,9 @@ impl TypeSignatureClass { TypeSignatureClass::Integer => { vec![DataType::Int64] } + TypeSignatureClass::Binary => { + vec![DataType::Binary] + } } } @@ -274,6 +326,7 @@ impl TypeSignatureClass { TypeSignatureClass::Interval if logical_type.is_interval() => true, TypeSignatureClass::Duration if logical_type.is_duration() => true, TypeSignatureClass::Integer if logical_type.is_integer() => true, + TypeSignatureClass::Binary if logical_type.is_binary() => true, _ => false, } } @@ -304,6 +357,9 @@ impl TypeSignatureClass { TypeSignatureClass::Integer if native_type.is_integer() => { Ok(origin_type.to_owned()) } + TypeSignatureClass::Binary if native_type.is_binary() => { + Ok(origin_type.to_owned()) + } _ => internal_err!("May miss the matching logic in `matches_native_type`"), } } @@ -391,10 +447,11 @@ impl TypeSignature { vec![format!("{}, ..", Self::join_types(types, "/"))] } TypeSignature::Uniform(arg_count, valid_types) => { - vec![std::iter::repeat(Self::join_types(valid_types, "/")) - .take(*arg_count) - .collect::>() - .join(", ")] + vec![ + std::iter::repeat_n(Self::join_types(valid_types, "/"), *arg_count) + .collect::>() + .join(", "), + ] } TypeSignature::String(num) => { vec![format!("String({num})")] @@ -412,8 +469,7 @@ impl TypeSignature { vec![Self::join_types(types, ", ")] } TypeSignature::Any(arg_count) => { - vec![std::iter::repeat("Any") - .take(*arg_count) + vec![std::iter::repeat_n("Any", *arg_count) .collect::>() .join(", ")] } @@ -736,10 +792,12 @@ impl Hash for ImplicitCoercion { } } -/// Defines the supported argument types ([`TypeSignature`]) and [`Volatility`] for a function. +/// Provides information necessary for calling a function. +/// +/// - [`TypeSignature`] defines the argument types that a function has implementations +/// for. /// -/// DataFusion will automatically coerce (cast) argument types to one of the supported -/// function signatures, if possible. +/// - [`Volatility`] defines how the output of the function changes with the input. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Signature { /// The data types that the function accepts. See [TypeSignature] for more information. @@ -779,7 +837,7 @@ impl Signature { } } - /// A specified number of numeric arguments + /// A specified number of string arguments pub fn string(arg_count: usize, volatility: Volatility) -> Self { Self { type_signature: TypeSignature::String(arg_count), @@ -843,6 +901,7 @@ impl Signature { volatility, } } + /// Any one of a list of [TypeSignature]s. pub fn one_of(type_signatures: Vec, volatility: Volatility) -> Self { Signature { @@ -850,7 +909,8 @@ impl Signature { volatility, } } - /// Specialized Signature for ArrayAppend and similar functions + + /// Specialized [Signature] for ArrayAppend and similar functions. pub fn array_and_element(volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::ArraySignature( @@ -865,7 +925,41 @@ impl Signature { volatility, } } - /// Specialized Signature for Array functions with an optional index + + /// Specialized [Signature] for ArrayPrepend and similar functions. + pub fn element_and_array(volatility: Volatility) -> Self { + Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Array, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility, + } + } + + /// Specialized [Signature] for functions that take a fixed number of arrays. + pub fn arrays( + n: usize, + coercion: Option, + volatility: Volatility, + ) -> Self { + Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array; n], + array_coercion: coercion, + }, + ), + volatility, + } + } + + /// Specialized [Signature] for Array functions with an optional index. pub fn array_and_element_and_optional_index(volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::OneOf(vec![ @@ -874,7 +968,7 @@ impl Signature { ArrayFunctionArgument::Array, ArrayFunctionArgument::Element, ], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ @@ -882,14 +976,14 @@ impl Signature { ArrayFunctionArgument::Element, ArrayFunctionArgument::Index, ], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), ]), volatility, } } - /// Specialized Signature for ArrayElement and similar functions + /// Specialized [Signature] for ArrayElement and similar functions. pub fn array_and_index(volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::ArraySignature( @@ -898,23 +992,16 @@ impl Signature { ArrayFunctionArgument::Array, ArrayFunctionArgument::Index, ], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }, ), volatility, } } - /// Specialized Signature for ArrayEmpty and similar functions + + /// Specialized [Signature] for ArrayEmpty and similar functions. pub fn array(volatility: Volatility) -> Self { - Signature { - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Array], - array_coercion: None, - }, - ), - volatility, - } + Signature::arrays(1, Some(ListCoercion::FixedSizedListToList), volatility) } } @@ -940,8 +1027,7 @@ mod tests { for case in positive_cases { assert!( case.supports_zero_argument(), - "Expected {:?} to support zero arguments", - case + "Expected {case:?} to support zero arguments" ); } @@ -960,8 +1046,7 @@ mod tests { for case in negative_cases { assert!( !case.supports_zero_argument(), - "Expected {:?} not to support zero arguments", - case + "Expected {case:?} not to support zero arguments" ); } } diff --git a/datafusion/expr-common/src/statistics.rs b/datafusion/expr-common/src/statistics.rs index 7e0bc88087efb..5c5e397e74e76 100644 --- a/datafusion/expr-common/src/statistics.rs +++ b/datafusion/expr-common/src/statistics.rs @@ -189,7 +189,7 @@ impl Distribution { pub fn target_type(args: &[&ScalarValue]) -> Result { let mut arg_types = args .iter() - .filter(|&&arg| (arg != &ScalarValue::Null)) + .filter(|&&arg| arg != &ScalarValue::Null) .map(|&arg| arg.data_type()); let Some(dt) = arg_types.next().map_or_else( @@ -1559,18 +1559,14 @@ mod tests { assert_eq!( new_generic_from_binary_op(&op, &dist_a, &dist_b)?.range()?, apply_operator(&op, a, b)?, - "Failed for {:?} {op} {:?}", - dist_a, - dist_b + "Failed for {dist_a:?} {op} {dist_b:?}" ); } for op in [Gt, GtEq, Lt, LtEq, Eq, NotEq] { assert_eq!( create_bernoulli_from_comparison(&op, &dist_a, &dist_b)?.range()?, apply_operator(&op, a, b)?, - "Failed for {:?} {op} {:?}", - dist_a, - dist_b + "Failed for {dist_a:?} {op} {dist_b:?}" ); } } diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs index 13d52959aba65..e77a072a84f38 100644 --- a/datafusion/expr-common/src/type_coercion/aggregates.rs +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -17,8 +17,9 @@ use crate::signature::TypeSignature; use arrow::datatypes::{ - DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, - DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, + DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, + DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, }; use datafusion_common::{internal_err, plan_err, Result}; @@ -82,48 +83,48 @@ pub static TIMES: &[DataType] = &[ DataType::Time64(TimeUnit::Nanosecond), ]; -/// Validate the length of `input_types` matches the `signature` for `agg_fun`. +/// Validate the length of `input_fields` matches the `signature` for `agg_fun`. /// -/// This method DOES NOT validate the argument types - only that (at least one, +/// This method DOES NOT validate the argument fields - only that (at least one, /// in the case of [`TypeSignature::OneOf`]) signature matches the desired /// number of input types. pub fn check_arg_count( func_name: &str, - input_types: &[DataType], + input_fields: &[FieldRef], signature: &TypeSignature, ) -> Result<()> { match signature { TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { - if input_types.len() != *agg_count { + if input_fields.len() != *agg_count { return plan_err!( "The function {func_name} expects {:?} arguments, but {:?} were provided", agg_count, - input_types.len() + input_fields.len() ); } } TypeSignature::Exact(types) => { - if types.len() != input_types.len() { + if types.len() != input_fields.len() { return plan_err!( "The function {func_name} expects {:?} arguments, but {:?} were provided", types.len(), - input_types.len() + input_fields.len() ); } } TypeSignature::OneOf(variants) => { let ok = variants .iter() - .any(|v| check_arg_count(func_name, input_types, v).is_ok()); + .any(|v| check_arg_count(func_name, input_fields, v).is_ok()); if !ok { return plan_err!( "The function {func_name} does not accept {:?} function arguments.", - input_types.len() + input_fields.len() ); } } TypeSignature::VariadicAny => { - if input_types.is_empty() { + if input_fields.is_empty() { return plan_err!( "The function {func_name} expects at least one argument" ); @@ -150,6 +151,18 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { DataType::Int64 => Ok(DataType::Int64), DataType::UInt64 => Ok(DataType::UInt64), DataType::Float64 => Ok(DataType::Float64), + DataType::Decimal32(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal32(new_precision, *scale)) + } + DataType::Decimal64(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal64(new_precision, *scale)) + } DataType::Decimal128(precision, scale) => { // In the spark, the result type is DECIMAL(min(38,precision+10), s) // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 @@ -171,7 +184,7 @@ pub fn variance_return_type(arg_type: &DataType) -> Result { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { - plan_err!("VAR does not support {arg_type:?}") + plan_err!("VAR does not support {arg_type}") } } @@ -180,7 +193,7 @@ pub fn covariance_return_type(arg_type: &DataType) -> Result { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { - plan_err!("COVAR does not support {arg_type:?}") + plan_err!("COVAR does not support {arg_type}") } } @@ -189,13 +202,27 @@ pub fn correlation_return_type(arg_type: &DataType) -> Result { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { - plan_err!("CORR does not support {arg_type:?}") + plan_err!("CORR does not support {arg_type}") } } /// Function return type of an average pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result { match arg_type { + DataType::Decimal32(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal32(new_precision, new_scale)) + } + DataType::Decimal64(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal64(new_precision, new_scale)) + } DataType::Decimal128(precision, scale) => { // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 @@ -210,6 +237,7 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); Ok(DataType::Decimal256(new_precision, new_scale)) } + DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), DataType::Dictionary(_, dict_value_type) => { avg_return_type(func_name, dict_value_type.as_ref()) @@ -221,6 +249,16 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result /// Internal sum type of an average pub fn avg_sum_type(arg_type: &DataType) -> Result { match arg_type { + DataType::Decimal32(precision, scale) => { + // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) + let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal32(new_precision, *scale)) + } + DataType::Decimal64(precision, scale) => { + // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) + let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal64(new_precision, *scale)) + } DataType::Decimal128(precision, scale) => { // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); @@ -231,6 +269,7 @@ pub fn avg_sum_type(arg_type: &DataType) -> Result { let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal256(new_precision, *scale)) } + DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), DataType::Dictionary(_, dict_value_type) => { avg_sum_type(dict_value_type.as_ref()) @@ -247,7 +286,7 @@ pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool { _ => matches!( arg_type, arg_type if NUMERICS.contains(arg_type) - || matches!(arg_type, DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) + || matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) ), } } @@ -260,7 +299,7 @@ pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool { _ => matches!( arg_type, arg_type if NUMERICS.contains(arg_type) - || matches!(arg_type, DataType::Decimal128(_, _)| DataType::Decimal256(_, _)) + || matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) ), } } @@ -295,13 +334,16 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result Result { match &data_type { + DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)), + DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)), DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)), DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)), d if d.is_numeric() => Ok(DataType::Float64), + DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), DataType::Dictionary(_, v) => coerced_type(func_name, v.as_ref()), _ => { plan_err!( - "The function {:?} does not support inputs of type {:?}.", + "The function {:?} does not support inputs of type {}.", func_name, data_type ) diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index c49de3984097f..52bb211d9b99b 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -27,6 +27,8 @@ use arrow::compute::can_cast_types; use arrow::datatypes::{ DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, + DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, + DECIMAL64_MAX_SCALE, }; use datafusion_common::types::NativeType; use datafusion_common::{ @@ -124,6 +126,64 @@ impl<'a> BinaryTypeCoercer<'a> { /// Returns a [`Signature`] for applying `op` to arguments of type `lhs` and `rhs` fn signature(&'a self) -> Result { + // Special handling for arithmetic operations with both `lhs` and `rhs` NULL: + // When both operands are NULL, we are providing a concrete numeric type (Int64) + // to allow the arithmetic operation to proceed. This ensures NULL `op` NULL returns NULL + // instead of failing during planning. + if matches!((self.lhs, self.rhs), (DataType::Null, DataType::Null)) + && self.op.is_numerical_operators() + { + return Ok(Signature::uniform(DataType::Int64)); + } + + if let Some(coerced) = null_coercion(self.lhs, self.rhs) { + // Special handling for arithmetic + null coercion: + // For arithmetic operators on non-temporal types, we must handle the result type here using Arrow's numeric kernel. + // This is because Arrow expects concrete numeric types, and this ensures the correct result type (e.g., for NULL + Int32, result is Int32). + // For all other cases (including temporal arithmetic and non-arithmetic operators), + // we can delegate to signature_inner(&coerced, &coerced), which handles the necessary logic for those operators. + // In those cases, signature_inner is designed to work with the coerced type, even if it originated from a NULL. + if self.op.is_numerical_operators() && !coerced.is_temporal() { + let ret = self.get_result(&coerced, &coerced).map_err(|e| { + plan_datafusion_err!( + "Cannot get result type for arithmetic operation {coerced} {} {coerced}: {e}", + self.op + ) + })?; + + return Ok(Signature { + lhs: coerced.clone(), + rhs: coerced, + ret, + }); + } + return self.signature_inner(&coerced, &coerced); + } + self.signature_inner(self.lhs, self.rhs) + } + + /// Returns the result type for arithmetic operations + fn get_result( + &self, + lhs: &DataType, + rhs: &DataType, + ) -> arrow::error::Result { + use arrow::compute::kernels::numeric::*; + let l = new_empty_array(lhs); + let r = new_empty_array(rhs); + + let result = match self.op { + Operator::Plus => add_wrapping(&l, &r), + Operator::Minus => sub_wrapping(&l, &r), + Operator::Multiply => mul_wrapping(&l, &r), + Operator::Divide => div(&l, &r), + Operator::Modulo => rem(&l, &r), + _ => unreachable!(), + }; + result.map(|x| x.data_type().clone()) + } + + fn signature_inner(&'a self, lhs: &DataType, rhs: &DataType) -> Result { use arrow::datatypes::DataType::*; use Operator::*; let result = match self.op { @@ -135,7 +195,7 @@ impl<'a> BinaryTypeCoercer<'a> { GtEq | IsDistinctFrom | IsNotDistinctFrom => { - comparison_coercion(self.lhs, self.rhs).map(Signature::comparison).ok_or_else(|| { + comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { plan_datafusion_err!( "Cannot infer common argument type for comparison operation {} {} {}", self.lhs, @@ -144,9 +204,9 @@ impl<'a> BinaryTypeCoercer<'a> { ) }) } - And | Or => if matches!((self.lhs, self.rhs), (Boolean | Null, Boolean | Null)) { + And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) { // Logical binary boolean operators can only be evaluated for - // boolean or null arguments. + // boolean or null arguments. Ok(Signature::uniform(Boolean)) } else { plan_err!( @@ -154,28 +214,28 @@ impl<'a> BinaryTypeCoercer<'a> { ) } RegexMatch | RegexIMatch | RegexNotMatch | RegexNotIMatch => { - regex_coercion(self.lhs, self.rhs).map(Signature::comparison).ok_or_else(|| { + regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { plan_datafusion_err!( "Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs ) }) } LikeMatch | ILikeMatch | NotLikeMatch | NotILikeMatch => { - regex_coercion(self.lhs, self.rhs).map(Signature::comparison).ok_or_else(|| { + regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { plan_datafusion_err!( "Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs ) }) } BitwiseAnd | BitwiseOr | BitwiseXor | BitwiseShiftRight | BitwiseShiftLeft => { - bitwise_coercion(self.lhs, self.rhs).map(Signature::uniform).ok_or_else(|| { + bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { plan_datafusion_err!( "Cannot infer common type for bitwise operation {} {} {}", self.lhs, self.op, self.rhs ) }) } StringConcat => { - string_concat_coercion(self.lhs, self.rhs).map(Signature::uniform).ok_or_else(|| { + string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| { plan_datafusion_err!( "Cannot infer common string type for string concat operation {} {} {}", self.lhs, self.op, self.rhs ) @@ -183,8 +243,8 @@ impl<'a> BinaryTypeCoercer<'a> { } AtArrow | ArrowAt => { // Array contains or search (similar to LIKE) operation - array_coercion(self.lhs, self.rhs) - .or_else(|| like_coercion(self.lhs, self.rhs)).map(Signature::comparison).ok_or_else(|| { + array_coercion(lhs, rhs) + .or_else(|| like_coercion(lhs, rhs)).map(Signature::comparison).ok_or_else(|| { plan_datafusion_err!( "Cannot infer common argument type for operation {} {} {}", self.lhs, self.op, self.rhs ) @@ -192,40 +252,24 @@ impl<'a> BinaryTypeCoercer<'a> { } AtAt => { // text search has similar signature to LIKE - like_coercion(self.lhs, self.rhs).map(Signature::comparison).ok_or_else(|| { + like_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| { plan_datafusion_err!( "Cannot infer common argument type for AtAt operation {} {} {}", self.lhs, self.op, self.rhs ) }) } Plus | Minus | Multiply | Divide | Modulo => { - let get_result = |lhs, rhs| { - use arrow::compute::kernels::numeric::*; - let l = new_empty_array(lhs); - let r = new_empty_array(rhs); - - let result = match self.op { - Plus => add_wrapping(&l, &r), - Minus => sub_wrapping(&l, &r), - Multiply => mul_wrapping(&l, &r), - Divide => div(&l, &r), - Modulo => rem(&l, &r), - _ => unreachable!(), - }; - result.map(|x| x.data_type().clone()) - }; - - if let Ok(ret) = get_result(self.lhs, self.rhs) { + if let Ok(ret) = self.get_result(lhs, rhs) { // Temporal arithmetic, e.g. Date32 + Interval Ok(Signature{ - lhs: self.lhs.clone(), - rhs: self.rhs.clone(), + lhs: lhs.clone(), + rhs: rhs.clone(), ret, }) - } else if let Some(coerced) = temporal_coercion_strict_timezone(self.lhs, self.rhs) { + } else if let Some(coerced) = temporal_coercion_strict_timezone(lhs, rhs) { // Temporal arithmetic by first coercing to a common time representation // e.g. Date32 - Timestamp - let ret = get_result(&coerced, &coerced).map_err(|e| { + let ret = self.get_result(&coerced, &coerced).map_err(|e| { plan_datafusion_err!( "Cannot get result type for temporal operation {coerced} {} {coerced}: {e}", self.op ) @@ -235,9 +279,9 @@ impl<'a> BinaryTypeCoercer<'a> { rhs: coerced, ret, }) - } else if let Some((lhs, rhs)) = math_decimal_coercion(self.lhs, self.rhs) { + } else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) { // Decimal arithmetic, e.g. Decimal(10, 2) + Decimal(10, 0) - let ret = get_result(&lhs, &rhs).map_err(|e| { + let ret = self.get_result(&lhs, &rhs).map_err(|e| { plan_datafusion_err!( "Cannot get result type for decimal operation {} {} {}: {e}", self.lhs, self.op, self.rhs ) @@ -247,7 +291,7 @@ impl<'a> BinaryTypeCoercer<'a> { rhs, ret, }) - } else if let Some(numeric) = mathematics_numerical_coercion(self.lhs, self.rhs) { + } else if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) { // Numeric arithmetic, e.g. Int32 + Int32 Ok(Signature::uniform(numeric)) } else { @@ -283,6 +327,16 @@ impl<'a> BinaryTypeCoercer<'a> { // TODO Move the rest inside of BinaryTypeCoercer +fn is_decimal(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Decimal32(..) + | DataType::Decimal64(..) + | DataType::Decimal128(..) + | DataType::Decimal256(..) + ) +} + /// Coercion rules for mathematics operators between decimal and non-decimal types. fn math_decimal_coercion( lhs_type: &DataType, @@ -299,25 +353,84 @@ fn math_decimal_coercion( let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type)?; Some((lhs_type, value_type)) } - (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => { - Some((dec_type.clone(), dec_type.clone())) - } - (Decimal128(_, _), Decimal128(_, _)) | (Decimal256(_, _), Decimal256(_, _)) => { + ( + Null, + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + ) => Some((rhs_type.clone(), rhs_type.clone())), + ( + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + Null, + ) => Some((lhs_type.clone(), lhs_type.clone())), + (Decimal32(_, _), Decimal32(_, _)) + | (Decimal64(_, _), Decimal64(_, _)) + | (Decimal128(_, _), Decimal128(_, _)) + | (Decimal256(_, _), Decimal256(_, _)) => { Some((lhs_type.clone(), rhs_type.clone())) } + // Cross-variant decimal coercion - choose larger variant with appropriate precision/scale + (lhs, rhs) + if is_decimal(lhs) + && is_decimal(rhs) + && std::mem::discriminant(lhs) != std::mem::discriminant(rhs) => + { + let coerced_type = get_wider_decimal_type_cross_variant(lhs_type, rhs_type)?; + Some((coerced_type.clone(), coerced_type)) + } // Unlike with comparison we don't coerce to a decimal in the case of floating point // numbers, instead falling back to floating point arithmetic instead - (Decimal128(_, _), Int8 | Int16 | Int32 | Int64) => { - Some((lhs_type.clone(), coerce_numeric_type_to_decimal(rhs_type)?)) - } - (Int8 | Int16 | Int32 | Int64, Decimal128(_, _)) => { - Some((coerce_numeric_type_to_decimal(lhs_type)?, rhs_type.clone())) - } - (Decimal256(_, _), Int8 | Int16 | Int32 | Int64) => Some(( + ( + Decimal32(_, _), + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, + ) => Some(( + lhs_type.clone(), + coerce_numeric_type_to_decimal32(rhs_type)?, + )), + ( + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, + Decimal32(_, _), + ) => Some(( + coerce_numeric_type_to_decimal32(lhs_type)?, + rhs_type.clone(), + )), + ( + Decimal64(_, _), + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, + ) => Some(( + lhs_type.clone(), + coerce_numeric_type_to_decimal64(rhs_type)?, + )), + ( + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, + Decimal64(_, _), + ) => Some(( + coerce_numeric_type_to_decimal64(lhs_type)?, + rhs_type.clone(), + )), + ( + Decimal128(_, _), + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, + ) => Some(( + lhs_type.clone(), + coerce_numeric_type_to_decimal128(rhs_type)?, + )), + ( + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, + Decimal128(_, _), + ) => Some(( + coerce_numeric_type_to_decimal128(lhs_type)?, + rhs_type.clone(), + )), + ( + Decimal256(_, _), + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, + ) => Some(( lhs_type.clone(), coerce_numeric_type_to_decimal256(rhs_type)?, )), - (Int8 | Int16 | Int32 | Int64, Decimal256(_, _)) => Some(( + ( + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, + Decimal256(_, _), + ) => Some(( coerce_numeric_type_to_decimal256(lhs_type)?, rhs_type.clone(), )), @@ -462,7 +575,7 @@ pub fn type_union_resolution(data_types: &[DataType]) -> Option { // If all the data_types are null, return string if data_types.iter().all(|t| t == &DataType::Null) { - return Some(DataType::Utf8); + return Some(DataType::Utf8View); } // Ignore Nulls, if any data_type category is not the same, return None @@ -644,7 +757,7 @@ pub fn try_type_union_resolution_with_struct( keys_string = Some(keys); } } else { - return exec_err!("Expect to get struct but got {}", data_type); + return exec_err!("Expect to get struct but got {data_type}"); } } @@ -676,7 +789,7 @@ pub fn try_type_union_resolution_with_struct( } } } else { - return exec_err!("Expect to get struct but got {}", data_type); + return exec_err!("Expect to get struct but got {data_type}"); } } @@ -733,6 +846,7 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; + // Prefer decimal data type over floating point for comparison operation match (lhs_type, rhs_type) { - // Prefer decimal data type over floating point for comparison operation - (Decimal128(_, _), Decimal128(_, _)) => { + // Same decimal types + (lhs_type, rhs_type) + if is_decimal(lhs_type) + && is_decimal(rhs_type) + && std::mem::discriminant(lhs_type) + == std::mem::discriminant(rhs_type) => + { get_wider_decimal_type(lhs_type, rhs_type) } - (Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), - (_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type), - (Decimal256(_, _), Decimal256(_, _)) => { - get_wider_decimal_type(lhs_type, rhs_type) + // Mismatched decimal types + (lhs_type, rhs_type) + if is_decimal(lhs_type) + && is_decimal(rhs_type) + && std::mem::discriminant(lhs_type) + != std::mem::discriminant(rhs_type) => + { + get_wider_decimal_type_cross_variant(lhs_type, rhs_type) + } + // Decimal + non-decimal types + (Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), _) => { + get_common_decimal_type(lhs_type, rhs_type) + } + (_, Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _)) => { + get_common_decimal_type(rhs_type, lhs_type) } - (Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), - (_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type), (_, _) => None, } } +/// Handle cross-variant decimal widening by choosing the larger variant +fn get_wider_decimal_type_cross_variant( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + + let (p1, s1) = match lhs_type { + Decimal32(p, s) => (*p, *s), + Decimal64(p, s) => (*p, *s), + Decimal128(p, s) => (*p, *s), + Decimal256(p, s) => (*p, *s), + _ => return None, + }; + + let (p2, s2) = match rhs_type { + Decimal32(p, s) => (*p, *s), + Decimal64(p, s) => (*p, *s), + Decimal128(p, s) => (*p, *s), + Decimal256(p, s) => (*p, *s), + _ => return None, + }; + + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + let s = s1.max(s2); + let range = (p1 as i8 - s1).max(p2 as i8 - s2); + let required_precision = (range + s) as u8; + + // Choose the larger variant between the two input types, while making sure we don't overflow the precision. + match (lhs_type, rhs_type) { + (Decimal32(_, _), Decimal64(_, _)) | (Decimal64(_, _), Decimal32(_, _)) + if required_precision <= DECIMAL64_MAX_PRECISION => + { + Some(Decimal64(required_precision, s)) + } + (Decimal32(_, _), Decimal128(_, _)) + | (Decimal128(_, _), Decimal32(_, _)) + | (Decimal64(_, _), Decimal128(_, _)) + | (Decimal128(_, _), Decimal64(_, _)) + if required_precision <= DECIMAL128_MAX_PRECISION => + { + Some(Decimal128(required_precision, s)) + } + (Decimal32(_, _), Decimal256(_, _)) + | (Decimal256(_, _), Decimal32(_, _)) + | (Decimal64(_, _), Decimal256(_, _)) + | (Decimal256(_, _), Decimal64(_, _)) + | (Decimal128(_, _), Decimal256(_, _)) + | (Decimal256(_, _), Decimal128(_, _)) + if required_precision <= DECIMAL256_MAX_PRECISION => + { + Some(Decimal256(required_precision, s)) + } + _ => None, + } +} /// Coerce `lhs_type` and `rhs_type` to a common type. fn get_common_decimal_type( @@ -881,8 +1066,16 @@ fn get_common_decimal_type( ) -> Option { use arrow::datatypes::DataType::*; match decimal_type { + Decimal32(_, _) => { + let other_decimal_type = coerce_numeric_type_to_decimal32(other_type)?; + get_wider_decimal_type(decimal_type, &other_decimal_type) + } + Decimal64(_, _) => { + let other_decimal_type = coerce_numeric_type_to_decimal64(other_type)?; + get_wider_decimal_type(decimal_type, &other_decimal_type) + } Decimal128(_, _) => { - let other_decimal_type = coerce_numeric_type_to_decimal(other_type)?; + let other_decimal_type = coerce_numeric_type_to_decimal128(other_type)?; get_wider_decimal_type(decimal_type, &other_decimal_type) } Decimal256(_, _) => { @@ -893,7 +1086,7 @@ fn get_common_decimal_type( } } -/// Returns a `DataType::Decimal128` that can store any value from either +/// Returns a decimal [`DataType`] variant that can store any value from either /// `lhs_decimal_type` and `rhs_decimal_type` /// /// The result decimal type is `(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))`. @@ -902,11 +1095,23 @@ fn get_wider_decimal_type( rhs_type: &DataType, ) -> Option { match (lhs_decimal_type, rhs_type) { + (DataType::Decimal32(p1, s1), DataType::Decimal32(p2, s2)) => { + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + let s = *s1.max(s2); + let range = (*p1 as i8 - s1).max(*p2 as i8 - s2); + Some(create_decimal32_type((range + s) as u8, s)) + } + (DataType::Decimal64(p1, s1), DataType::Decimal64(p2, s2)) => { + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + let s = *s1.max(s2); + let range = (*p1 as i8 - s1).max(*p2 as i8 - s2); + Some(create_decimal64_type((range + s) as u8, s)) + } (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => { // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) let s = *s1.max(s2); let range = (*p1 as i8 - s1).max(*p2 as i8 - s2); - Some(create_decimal_type((range + s) as u8, s)) + Some(create_decimal128_type((range + s) as u8, s)) } (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => { // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) @@ -920,7 +1125,39 @@ fn get_wider_decimal_type( /// Convert the numeric data type to the decimal data type. /// We support signed and unsigned integer types and floating-point type. -fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option { +fn coerce_numeric_type_to_decimal32(numeric_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + // This conversion rule is from spark + // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 + match numeric_type { + Int8 | UInt8 => Some(Decimal32(3, 0)), + Int16 | UInt16 => Some(Decimal32(5, 0)), + // TODO if we convert the floating-point data to the decimal type, it maybe overflow. + Float16 => Some(Decimal32(6, 3)), + _ => None, + } +} + +/// Convert the numeric data type to the decimal data type. +/// We support signed and unsigned integer types and floating-point type. +fn coerce_numeric_type_to_decimal64(numeric_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + // This conversion rule is from spark + // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 + match numeric_type { + Int8 | UInt8 => Some(Decimal64(3, 0)), + Int16 | UInt16 => Some(Decimal64(5, 0)), + Int32 | UInt32 => Some(Decimal64(10, 0)), + // TODO if we convert the floating-point data to the decimal type, it maybe overflow. + Float16 => Some(Decimal64(6, 3)), + Float32 => Some(Decimal64(14, 7)), + _ => None, + } +} + +/// Convert the numeric data type to the decimal data type. +/// We support signed and unsigned integer types and floating-point type. +fn coerce_numeric_type_to_decimal128(numeric_type: &DataType) -> Option { use arrow::datatypes::DataType::*; // This conversion rule is from spark // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127 @@ -930,6 +1167,7 @@ fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option { Int32 | UInt32 => Some(Decimal128(10, 0)), Int64 | UInt64 => Some(Decimal128(20, 0)), // TODO if we convert the floating-point data to the decimal type, it maybe overflow. + Float16 => Some(Decimal128(6, 3)), Float32 => Some(Decimal128(14, 7)), Float64 => Some(Decimal128(30, 15)), _ => None, @@ -948,6 +1186,7 @@ fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option Some(Decimal256(10, 0)), Int64 | UInt64 => Some(Decimal256(20, 0)), // TODO if we convert the floating-point data to the decimal type, it maybe overflow. + Float16 => Some(Decimal256(6, 3)), Float32 => Some(Decimal256(14, 7)), Float64 => Some(Decimal256(30, 15)), _ => None, @@ -987,6 +1226,25 @@ fn coerce_fields(common_type: DataType, lhs: &FieldRef, rhs: &FieldRef) -> Field Arc::new(Field::new(name, common_type, is_nullable)) } +/// coerce two types if they are Maps by coercing their inner 'entries' fields' types +/// using struct coercion +fn map_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Map(lhs_field, lhs_ordered), Map(rhs_field, rhs_ordered)) => { + struct_coercion(lhs_field.data_type(), rhs_field.data_type()).map( + |key_value_type| { + Map( + Arc::new((**lhs_field).clone().with_data_type(key_value_type)), + *lhs_ordered && *rhs_ordered, + ) + }, + ) + } + _ => None, + } +} + /// Returns the output type of applying mathematics operations such as /// `+` to arguments of `lhs_type` and `rhs_type`. fn mathematics_numerical_coercion( @@ -1024,6 +1282,7 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(Float64), (_, Float32) | (Float32, _) => Some(Float32), + (_, Float16) | (Float16, _) => Some(Float16), // The following match arms encode the following logic: Given the two // integral types, we choose the narrowest possible integral type that // accommodates all values of both types. Note that to avoid information @@ -1047,7 +1306,21 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option DataType { +fn create_decimal32_type(precision: u8, scale: i8) -> DataType { + DataType::Decimal32( + DECIMAL32_MAX_PRECISION.min(precision), + DECIMAL32_MAX_SCALE.min(scale), + ) +} + +fn create_decimal64_type(precision: u8, scale: i8) -> DataType { + DataType::Decimal64( + DECIMAL64_MAX_PRECISION.min(precision), + DECIMAL64_MAX_SCALE.min(scale), + ) +} + +fn create_decimal128_type(precision: u8, scale: i8) -> DataType { DataType::Decimal128( DECIMAL128_MAX_PRECISION.min(precision), DECIMAL128_MAX_SCALE.min(scale), @@ -1118,7 +1391,7 @@ fn dictionary_comparison_coercion( /// 2. Data type of the other side should be able to cast to string type fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; - string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) { + string_coercion(lhs_type, rhs_type).or_else(|| match (lhs_type, rhs_type) { (Utf8View, from_type) | (from_type, Utf8View) => { string_concat_internal_coercion(from_type, &Utf8View) } @@ -1179,7 +1452,8 @@ pub fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (Utf8 | LargeUtf8, other_type) | (other_type, Utf8 | LargeUtf8) + (Utf8 | LargeUtf8 | Utf8View, other_type) + | (other_type, Utf8 | LargeUtf8 | Utf8View) if other_type.is_numeric() => { Some(other_type.clone()) @@ -1277,6 +1551,13 @@ fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(LargeBinary) } (Binary, Utf8) | (Utf8, Binary) => Some(Binary), + + // Cast FixedSizeBinary to Binary + (FixedSizeBinary(_), Binary) | (Binary, FixedSizeBinary(_)) => Some(Binary), + (FixedSizeBinary(_), BinaryView) | (BinaryView, FixedSizeBinary(_)) => { + Some(BinaryView) + } + _ => None, } } @@ -1454,8 +1735,8 @@ fn timeunit_coercion(lhs_unit: &TimeUnit, rhs_unit: &TimeUnit) -> TimeUnit { } } -/// Coercion rules from NULL type. Since NULL can be casted to any other type in arrow, -/// either lhs or rhs is NULL, if NULL can be casted to type of the other side, the coercion is valid. +/// Coercion rules from NULL type. Since NULL can be cast to any other type in arrow, +/// either lhs or rhs is NULL, if NULL can be cast to type of the other side, the coercion is valid. fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { match (lhs_type, rhs_type) { (DataType::Null, other_type) | (other_type, DataType::Null) => { @@ -1470,1017 +1751,4 @@ fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { } #[cfg(test)] -mod tests { - use super::*; - - use datafusion_common::assert_contains; - - #[test] - fn test_coercion_error() -> Result<()> { - let coercer = - BinaryTypeCoercer::new(&DataType::Float32, &Operator::Plus, &DataType::Utf8); - let result_type = coercer.get_input_types(); - - let e = result_type.unwrap_err(); - assert_eq!(e.strip_backtrace(), "Error during planning: Cannot coerce arithmetic expression Float32 + Utf8 to valid types"); - Ok(()) - } - - #[test] - fn test_decimal_binary_comparison_coercion() -> Result<()> { - let input_decimal = DataType::Decimal128(20, 3); - let input_types = [ - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - DataType::Decimal128(38, 10), - DataType::Decimal128(20, 8), - DataType::Null, - ]; - let result_types = [ - DataType::Decimal128(20, 3), - DataType::Decimal128(20, 3), - DataType::Decimal128(20, 3), - DataType::Decimal128(23, 3), - DataType::Decimal128(24, 7), - DataType::Decimal128(32, 15), - DataType::Decimal128(38, 10), - DataType::Decimal128(25, 8), - DataType::Decimal128(20, 3), - ]; - let comparison_op_types = [ - Operator::NotEq, - Operator::Eq, - Operator::Gt, - Operator::GtEq, - Operator::Lt, - Operator::LtEq, - ]; - for (i, input_type) in input_types.iter().enumerate() { - let expect_type = &result_types[i]; - for op in comparison_op_types { - let (lhs, rhs) = BinaryTypeCoercer::new(&input_decimal, &op, input_type) - .get_input_types()?; - assert_eq!(expect_type, &lhs); - assert_eq!(expect_type, &rhs); - } - } - // negative test - let result_type = - BinaryTypeCoercer::new(&input_decimal, &Operator::Eq, &DataType::Boolean) - .get_input_types(); - assert!(result_type.is_err()); - Ok(()) - } - - #[test] - fn test_decimal_mathematics_op_type() { - assert_eq!( - coerce_numeric_type_to_decimal(&DataType::Int8).unwrap(), - DataType::Decimal128(3, 0) - ); - assert_eq!( - coerce_numeric_type_to_decimal(&DataType::Int16).unwrap(), - DataType::Decimal128(5, 0) - ); - assert_eq!( - coerce_numeric_type_to_decimal(&DataType::Int32).unwrap(), - DataType::Decimal128(10, 0) - ); - assert_eq!( - coerce_numeric_type_to_decimal(&DataType::Int64).unwrap(), - DataType::Decimal128(20, 0) - ); - assert_eq!( - coerce_numeric_type_to_decimal(&DataType::Float32).unwrap(), - DataType::Decimal128(14, 7) - ); - assert_eq!( - coerce_numeric_type_to_decimal(&DataType::Float64).unwrap(), - DataType::Decimal128(30, 15) - ); - } - - #[test] - fn test_dictionary_type_coercion() { - use DataType::*; - - let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32)); - let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); - assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), - Some(Int32) - ); - assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, false), - Some(Int32) - ); - - // Since we can coerce values of Int16 to Utf8 can support this - let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); - let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); - assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), - Some(Utf8) - ); - - // Since we can coerce values of Utf8 to Binary can support this - let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); - let rhs_type = Dictionary(Box::new(Int8), Box::new(Binary)); - assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), - Some(Binary) - ); - - let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); - let rhs_type = Utf8; - assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, false), - Some(Utf8) - ); - assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), - Some(lhs_type.clone()) - ); - - let lhs_type = Utf8; - let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); - assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, false), - Some(Utf8) - ); - assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), - Some(rhs_type.clone()) - ); - } - - /// Test coercion rules for binary operators - /// - /// Applies coercion rules for `$LHS_TYPE $OP $RHS_TYPE` and asserts that - /// the result type is `$RESULT_TYPE` - macro_rules! test_coercion_binary_rule { - ($LHS_TYPE:expr, $RHS_TYPE:expr, $OP:expr, $RESULT_TYPE:expr) => {{ - let (lhs, rhs) = - BinaryTypeCoercer::new(&$LHS_TYPE, &$OP, &$RHS_TYPE).get_input_types()?; - assert_eq!(lhs, $RESULT_TYPE); - assert_eq!(rhs, $RESULT_TYPE); - }}; - } - - /// Test coercion rules for binary operators - /// - /// Applies coercion rules for each RHS_TYPE in $RHS_TYPES such that - /// `$LHS_TYPE $OP RHS_TYPE` and asserts that the result type is `$RESULT_TYPE`. - /// Also tests that the inverse `RHS_TYPE $OP $LHS_TYPE` is true - macro_rules! test_coercion_binary_rule_multiple { - ($LHS_TYPE:expr, $RHS_TYPES:expr, $OP:expr, $RESULT_TYPE:expr) => {{ - for rh_type in $RHS_TYPES { - let (lhs, rhs) = BinaryTypeCoercer::new(&$LHS_TYPE, &$OP, &rh_type) - .get_input_types()?; - assert_eq!(lhs, $RESULT_TYPE); - assert_eq!(rhs, $RESULT_TYPE); - - BinaryTypeCoercer::new(&rh_type, &$OP, &$LHS_TYPE).get_input_types()?; - assert_eq!(lhs, $RESULT_TYPE); - assert_eq!(rhs, $RESULT_TYPE); - } - }}; - } - - /// Test coercion rules for like - /// - /// Applies coercion rules for both - /// * `$LHS_TYPE LIKE $RHS_TYPE` - /// * `$RHS_TYPE LIKE $LHS_TYPE` - /// - /// And asserts the result type is `$RESULT_TYPE` - macro_rules! test_like_rule { - ($LHS_TYPE:expr, $RHS_TYPE:expr, $RESULT_TYPE:expr) => {{ - println!("Coercing {} LIKE {}", $LHS_TYPE, $RHS_TYPE); - let result = like_coercion(&$LHS_TYPE, &$RHS_TYPE); - assert_eq!(result, $RESULT_TYPE); - // reverse the order - let result = like_coercion(&$RHS_TYPE, &$LHS_TYPE); - assert_eq!(result, $RESULT_TYPE); - }}; - } - - #[test] - fn test_date_timestamp_arithmetic_error() -> Result<()> { - let (lhs, rhs) = BinaryTypeCoercer::new( - &DataType::Timestamp(TimeUnit::Nanosecond, None), - &Operator::Minus, - &DataType::Timestamp(TimeUnit::Millisecond, None), - ) - .get_input_types()?; - assert_eq!(lhs.to_string(), "Timestamp(Millisecond, None)"); - assert_eq!(rhs.to_string(), "Timestamp(Millisecond, None)"); - - let err = - BinaryTypeCoercer::new(&DataType::Date32, &Operator::Plus, &DataType::Date64) - .get_input_types() - .unwrap_err() - .to_string(); - - assert_contains!( - &err, - "Cannot get result type for temporal operation Date64 + Date64" - ); - - Ok(()) - } - - #[test] - fn test_like_coercion() { - // string coerce to strings - test_like_rule!(DataType::Utf8, DataType::Utf8, Some(DataType::Utf8)); - test_like_rule!( - DataType::LargeUtf8, - DataType::Utf8, - Some(DataType::LargeUtf8) - ); - test_like_rule!( - DataType::Utf8, - DataType::LargeUtf8, - Some(DataType::LargeUtf8) - ); - test_like_rule!( - DataType::LargeUtf8, - DataType::LargeUtf8, - Some(DataType::LargeUtf8) - ); - - // Also coerce binary to strings - test_like_rule!(DataType::Binary, DataType::Utf8, Some(DataType::Utf8)); - test_like_rule!( - DataType::LargeBinary, - DataType::Utf8, - Some(DataType::LargeUtf8) - ); - test_like_rule!( - DataType::Binary, - DataType::LargeUtf8, - Some(DataType::LargeUtf8) - ); - test_like_rule!( - DataType::LargeBinary, - DataType::LargeUtf8, - Some(DataType::LargeUtf8) - ); - } - - #[test] - fn test_type_coercion() -> Result<()> { - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Date32, - Operator::Eq, - DataType::Date32 - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Date64, - Operator::Lt, - DataType::Date64 - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Time32(TimeUnit::Second), - Operator::Eq, - DataType::Time32(TimeUnit::Second) - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Time32(TimeUnit::Millisecond), - Operator::Eq, - DataType::Time32(TimeUnit::Millisecond) - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Time64(TimeUnit::Microsecond), - Operator::Eq, - DataType::Time64(TimeUnit::Microsecond) - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Time64(TimeUnit::Nanosecond), - Operator::Eq, - DataType::Time64(TimeUnit::Nanosecond) - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Timestamp(TimeUnit::Second, None), - Operator::Lt, - DataType::Timestamp(TimeUnit::Nanosecond, None) - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Timestamp(TimeUnit::Millisecond, None), - Operator::Lt, - DataType::Timestamp(TimeUnit::Nanosecond, None) - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Timestamp(TimeUnit::Microsecond, None), - Operator::Lt, - DataType::Timestamp(TimeUnit::Nanosecond, None) - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Timestamp(TimeUnit::Nanosecond, None), - Operator::Lt, - DataType::Timestamp(TimeUnit::Nanosecond, None) - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Utf8, - Operator::RegexMatch, - DataType::Utf8 - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Utf8View, - Operator::RegexMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Utf8View, - DataType::Utf8, - Operator::RegexMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Utf8View, - DataType::Utf8View, - Operator::RegexMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Utf8, - Operator::RegexNotMatch, - DataType::Utf8 - ); - test_coercion_binary_rule!( - DataType::Utf8View, - DataType::Utf8, - Operator::RegexNotMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Utf8View, - Operator::RegexNotMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Utf8View, - DataType::Utf8View, - Operator::RegexNotMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Utf8, - Operator::RegexNotIMatch, - DataType::Utf8 - ); - test_coercion_binary_rule!( - DataType::Utf8View, - DataType::Utf8, - Operator::RegexNotIMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Utf8View, - Operator::RegexNotIMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Utf8View, - DataType::Utf8View, - Operator::RegexNotIMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), - DataType::Utf8, - Operator::RegexMatch, - DataType::Utf8 - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), - DataType::Utf8View, - Operator::RegexMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), - DataType::Utf8, - Operator::RegexMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), - DataType::Utf8View, - Operator::RegexMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), - DataType::Utf8, - Operator::RegexIMatch, - DataType::Utf8 - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), - DataType::Utf8, - Operator::RegexIMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), - DataType::Utf8View, - Operator::RegexIMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), - DataType::Utf8View, - Operator::RegexIMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), - DataType::Utf8, - Operator::RegexNotMatch, - DataType::Utf8 - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), - DataType::Utf8View, - Operator::RegexNotMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), - DataType::Utf8, - Operator::RegexNotMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), - DataType::Utf8View, - Operator::RegexNotMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), - DataType::Utf8, - Operator::RegexNotIMatch, - DataType::Utf8 - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), - DataType::Utf8, - Operator::RegexNotIMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), - DataType::Utf8View, - Operator::RegexNotIMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), - DataType::Utf8View, - Operator::RegexNotIMatch, - DataType::Utf8View - ); - test_coercion_binary_rule!( - DataType::Int16, - DataType::Int64, - Operator::BitwiseAnd, - DataType::Int64 - ); - test_coercion_binary_rule!( - DataType::UInt64, - DataType::UInt64, - Operator::BitwiseAnd, - DataType::UInt64 - ); - test_coercion_binary_rule!( - DataType::Int8, - DataType::UInt32, - Operator::BitwiseAnd, - DataType::Int64 - ); - test_coercion_binary_rule!( - DataType::UInt32, - DataType::Int32, - Operator::BitwiseAnd, - DataType::Int64 - ); - test_coercion_binary_rule!( - DataType::UInt16, - DataType::Int16, - Operator::BitwiseAnd, - DataType::Int32 - ); - test_coercion_binary_rule!( - DataType::UInt32, - DataType::UInt32, - Operator::BitwiseAnd, - DataType::UInt32 - ); - test_coercion_binary_rule!( - DataType::UInt16, - DataType::UInt32, - Operator::BitwiseAnd, - DataType::UInt32 - ); - Ok(()) - } - - #[test] - fn test_type_coercion_arithmetic() -> Result<()> { - use DataType::*; - - // (Float64, _) | (_, Float64) => Some(Float64), - test_coercion_binary_rule_multiple!( - Float64, - [ - Float64, Float32, Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16, - Int8, UInt8 - ], - Operator::Plus, - Float64 - ); - // (_, Float32) | (Float32, _) => Some(Float32), - test_coercion_binary_rule_multiple!( - Float32, - [ - Float32, Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, - UInt8 - ], - Operator::Plus, - Float32 - ); - // (UInt64, Int64 | Int32 | Int16 | Int8) | (Int64 | Int32 | Int16 | Int8, UInt64) => Some(Decimal128(20, 0)), - test_coercion_binary_rule_multiple!( - UInt64, - [Int64, Int32, Int16, Int8], - Operator::Divide, - Decimal128(20, 0) - ); - // (UInt64, _) | (_, UInt64) => Some(UInt64), - test_coercion_binary_rule_multiple!( - UInt64, - [UInt64, UInt32, UInt16, UInt8], - Operator::Modulo, - UInt64 - ); - // (Int64, _) | (_, Int64) => Some(Int64), - test_coercion_binary_rule_multiple!( - Int64, - [Int64, Int32, UInt32, Int16, UInt16, Int8, UInt8], - Operator::Modulo, - Int64 - ); - // (UInt32, Int32 | Int16 | Int8) | (Int32 | Int16 | Int8, UInt32) => Some(Int64) - test_coercion_binary_rule_multiple!( - UInt32, - [Int32, Int16, Int8], - Operator::Modulo, - Int64 - ); - // (UInt32, _) | (_, UInt32) => Some(UInt32), - test_coercion_binary_rule_multiple!( - UInt32, - [UInt32, UInt16, UInt8], - Operator::Modulo, - UInt32 - ); - // (Int32, _) | (_, Int32) => Some(Int32), - test_coercion_binary_rule_multiple!( - Int32, - [Int32, Int16, Int8], - Operator::Modulo, - Int32 - ); - // (UInt16, Int16 | Int8) | (Int16 | Int8, UInt16) => Some(Int32) - test_coercion_binary_rule_multiple!( - UInt16, - [Int16, Int8], - Operator::Minus, - Int32 - ); - // (UInt16, _) | (_, UInt16) => Some(UInt16), - test_coercion_binary_rule_multiple!( - UInt16, - [UInt16, UInt8, UInt8], - Operator::Plus, - UInt16 - ); - // (Int16, _) | (_, Int16) => Some(Int16), - test_coercion_binary_rule_multiple!(Int16, [Int16, Int8], Operator::Plus, Int16); - // (UInt8, Int8) | (Int8, UInt8) => Some(Int16) - test_coercion_binary_rule!(Int8, UInt8, Operator::Minus, Int16); - test_coercion_binary_rule!(UInt8, Int8, Operator::Multiply, Int16); - // (UInt8, _) | (_, UInt8) => Some(UInt8), - test_coercion_binary_rule!(UInt8, UInt8, Operator::Minus, UInt8); - // (Int8, _) | (_, Int8) => Some(Int8), - test_coercion_binary_rule!(Int8, Int8, Operator::Plus, Int8); - - Ok(()) - } - - fn test_math_decimal_coercion_rule( - lhs_type: DataType, - rhs_type: DataType, - expected_lhs_type: DataType, - expected_rhs_type: DataType, - ) { - // The coerced types for lhs and rhs, if any of them is not decimal - let (lhs_type, rhs_type) = math_decimal_coercion(&lhs_type, &rhs_type).unwrap(); - assert_eq!(lhs_type, expected_lhs_type); - assert_eq!(rhs_type, expected_rhs_type); - } - - #[test] - fn test_coercion_arithmetic_decimal() -> Result<()> { - test_math_decimal_coercion_rule( - DataType::Decimal128(10, 2), - DataType::Decimal128(10, 2), - DataType::Decimal128(10, 2), - DataType::Decimal128(10, 2), - ); - - test_math_decimal_coercion_rule( - DataType::Int32, - DataType::Decimal128(10, 2), - DataType::Decimal128(10, 0), - DataType::Decimal128(10, 2), - ); - - test_math_decimal_coercion_rule( - DataType::Int32, - DataType::Decimal128(10, 2), - DataType::Decimal128(10, 0), - DataType::Decimal128(10, 2), - ); - - test_math_decimal_coercion_rule( - DataType::Int32, - DataType::Decimal128(10, 2), - DataType::Decimal128(10, 0), - DataType::Decimal128(10, 2), - ); - - test_math_decimal_coercion_rule( - DataType::Int32, - DataType::Decimal128(10, 2), - DataType::Decimal128(10, 0), - DataType::Decimal128(10, 2), - ); - - test_math_decimal_coercion_rule( - DataType::Int32, - DataType::Decimal128(10, 2), - DataType::Decimal128(10, 0), - DataType::Decimal128(10, 2), - ); - - Ok(()) - } - - #[test] - fn test_type_coercion_compare() -> Result<()> { - // boolean - test_coercion_binary_rule!( - DataType::Boolean, - DataType::Boolean, - Operator::Eq, - DataType::Boolean - ); - // float - test_coercion_binary_rule!( - DataType::Float32, - DataType::Int64, - Operator::Eq, - DataType::Float32 - ); - test_coercion_binary_rule!( - DataType::Float32, - DataType::Float64, - Operator::GtEq, - DataType::Float64 - ); - // signed integer - test_coercion_binary_rule!( - DataType::Int8, - DataType::Int32, - Operator::LtEq, - DataType::Int32 - ); - test_coercion_binary_rule!( - DataType::Int64, - DataType::Int32, - Operator::LtEq, - DataType::Int64 - ); - // unsigned integer - test_coercion_binary_rule!( - DataType::UInt32, - DataType::UInt8, - Operator::Gt, - DataType::UInt32 - ); - test_coercion_binary_rule!( - DataType::UInt64, - DataType::UInt8, - Operator::Eq, - DataType::UInt64 - ); - test_coercion_binary_rule!( - DataType::UInt64, - DataType::Int64, - Operator::Eq, - DataType::Decimal128(20, 0) - ); - // numeric/decimal - test_coercion_binary_rule!( - DataType::Int64, - DataType::Decimal128(10, 0), - Operator::Eq, - DataType::Decimal128(20, 0) - ); - test_coercion_binary_rule!( - DataType::Int64, - DataType::Decimal128(10, 2), - Operator::Lt, - DataType::Decimal128(22, 2) - ); - test_coercion_binary_rule!( - DataType::Float64, - DataType::Decimal128(10, 3), - Operator::Gt, - DataType::Decimal128(30, 15) - ); - test_coercion_binary_rule!( - DataType::Int64, - DataType::Decimal128(10, 0), - Operator::Eq, - DataType::Decimal128(20, 0) - ); - test_coercion_binary_rule!( - DataType::Decimal128(14, 2), - DataType::Decimal128(10, 3), - Operator::GtEq, - DataType::Decimal128(15, 3) - ); - test_coercion_binary_rule!( - DataType::UInt64, - DataType::Decimal128(20, 0), - Operator::Eq, - DataType::Decimal128(20, 0) - ); - - // Binary - test_coercion_binary_rule!( - DataType::Binary, - DataType::Binary, - Operator::Eq, - DataType::Binary - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::Binary, - Operator::Eq, - DataType::Binary - ); - test_coercion_binary_rule!( - DataType::Binary, - DataType::Utf8, - Operator::Eq, - DataType::Binary - ); - - // LargeBinary - test_coercion_binary_rule!( - DataType::LargeBinary, - DataType::LargeBinary, - Operator::Eq, - DataType::LargeBinary - ); - test_coercion_binary_rule!( - DataType::Binary, - DataType::LargeBinary, - Operator::Eq, - DataType::LargeBinary - ); - test_coercion_binary_rule!( - DataType::LargeBinary, - DataType::Binary, - Operator::Eq, - DataType::LargeBinary - ); - test_coercion_binary_rule!( - DataType::Utf8, - DataType::LargeBinary, - Operator::Eq, - DataType::LargeBinary - ); - test_coercion_binary_rule!( - DataType::LargeBinary, - DataType::Utf8, - Operator::Eq, - DataType::LargeBinary - ); - test_coercion_binary_rule!( - DataType::LargeUtf8, - DataType::LargeBinary, - Operator::Eq, - DataType::LargeBinary - ); - test_coercion_binary_rule!( - DataType::LargeBinary, - DataType::LargeUtf8, - Operator::Eq, - DataType::LargeBinary - ); - - // Timestamps - let utc: Option> = Some("UTC".into()); - test_coercion_binary_rule!( - DataType::Timestamp(TimeUnit::Second, utc.clone()), - DataType::Timestamp(TimeUnit::Second, utc.clone()), - Operator::Eq, - DataType::Timestamp(TimeUnit::Second, utc.clone()) - ); - test_coercion_binary_rule!( - DataType::Timestamp(TimeUnit::Second, utc.clone()), - DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), - Operator::Eq, - DataType::Timestamp(TimeUnit::Second, utc.clone()) - ); - test_coercion_binary_rule!( - DataType::Timestamp(TimeUnit::Second, Some("America/New_York".into())), - DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), - Operator::Eq, - DataType::Timestamp(TimeUnit::Second, Some("America/New_York".into())) - ); - test_coercion_binary_rule!( - DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), - DataType::Timestamp(TimeUnit::Second, utc), - Operator::Eq, - DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())) - ); - - // list - let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true)); - test_coercion_binary_rule!( - DataType::List(Arc::clone(&inner_field)), - DataType::List(Arc::clone(&inner_field)), - Operator::Eq, - DataType::List(Arc::clone(&inner_field)) - ); - test_coercion_binary_rule!( - DataType::List(Arc::clone(&inner_field)), - DataType::LargeList(Arc::clone(&inner_field)), - Operator::Eq, - DataType::LargeList(Arc::clone(&inner_field)) - ); - test_coercion_binary_rule!( - DataType::LargeList(Arc::clone(&inner_field)), - DataType::List(Arc::clone(&inner_field)), - Operator::Eq, - DataType::LargeList(Arc::clone(&inner_field)) - ); - test_coercion_binary_rule!( - DataType::LargeList(Arc::clone(&inner_field)), - DataType::LargeList(Arc::clone(&inner_field)), - Operator::Eq, - DataType::LargeList(Arc::clone(&inner_field)) - ); - test_coercion_binary_rule!( - DataType::FixedSizeList(Arc::clone(&inner_field), 10), - DataType::FixedSizeList(Arc::clone(&inner_field), 10), - Operator::Eq, - DataType::FixedSizeList(Arc::clone(&inner_field), 10) - ); - test_coercion_binary_rule!( - DataType::FixedSizeList(Arc::clone(&inner_field), 10), - DataType::LargeList(Arc::clone(&inner_field)), - Operator::Eq, - DataType::LargeList(Arc::clone(&inner_field)) - ); - test_coercion_binary_rule!( - DataType::LargeList(Arc::clone(&inner_field)), - DataType::FixedSizeList(Arc::clone(&inner_field), 10), - Operator::Eq, - DataType::LargeList(Arc::clone(&inner_field)) - ); - test_coercion_binary_rule!( - DataType::List(Arc::clone(&inner_field)), - DataType::FixedSizeList(Arc::clone(&inner_field), 10), - Operator::Eq, - DataType::List(Arc::clone(&inner_field)) - ); - test_coercion_binary_rule!( - DataType::FixedSizeList(Arc::clone(&inner_field), 10), - DataType::List(Arc::clone(&inner_field)), - Operator::Eq, - DataType::List(Arc::clone(&inner_field)) - ); - - // Negative test: inner_timestamp_field and inner_field are not compatible because their inner types are not compatible - let inner_timestamp_field = Arc::new(Field::new_list_field( - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - )); - let result_type = BinaryTypeCoercer::new( - &DataType::List(Arc::clone(&inner_field)), - &Operator::Eq, - &DataType::List(Arc::clone(&inner_timestamp_field)), - ) - .get_input_types(); - assert!(result_type.is_err()); - - // TODO add other data type - Ok(()) - } - - #[test] - fn test_list_coercion() { - let lhs_type = DataType::List(Arc::new(Field::new("lhs", DataType::Int8, false))); - - let rhs_type = DataType::List(Arc::new(Field::new("rhs", DataType::Int64, true))); - - let coerced_type = list_coercion(&lhs_type, &rhs_type).unwrap(); - assert_eq!( - coerced_type, - DataType::List(Arc::new(Field::new("lhs", DataType::Int64, true))) - ); // nullable because the RHS is nullable - } - - #[test] - fn test_type_coercion_logical_op() -> Result<()> { - test_coercion_binary_rule!( - DataType::Boolean, - DataType::Boolean, - Operator::And, - DataType::Boolean - ); - - test_coercion_binary_rule!( - DataType::Boolean, - DataType::Boolean, - Operator::Or, - DataType::Boolean - ); - test_coercion_binary_rule!( - DataType::Boolean, - DataType::Null, - Operator::And, - DataType::Boolean - ); - test_coercion_binary_rule!( - DataType::Boolean, - DataType::Null, - Operator::Or, - DataType::Boolean - ); - test_coercion_binary_rule!( - DataType::Null, - DataType::Null, - Operator::Or, - DataType::Boolean - ); - test_coercion_binary_rule!( - DataType::Null, - DataType::Null, - Operator::And, - DataType::Boolean - ); - test_coercion_binary_rule!( - DataType::Null, - DataType::Boolean, - Operator::And, - DataType::Boolean - ); - test_coercion_binary_rule!( - DataType::Null, - DataType::Boolean, - Operator::Or, - DataType::Boolean - ); - Ok(()) - } -} +mod tests; diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs new file mode 100644 index 0000000000000..63945a4dabd0c --- /dev/null +++ b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs @@ -0,0 +1,423 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; +use datafusion_common::assert_contains; + +#[test] +fn test_coercion_error() -> Result<()> { + let coercer = + BinaryTypeCoercer::new(&DataType::Float32, &Operator::Plus, &DataType::Utf8); + let result_type = coercer.get_input_types(); + + let e = result_type.unwrap_err(); + assert_eq!(e.strip_backtrace(), "Error during planning: Cannot coerce arithmetic expression Float32 + Utf8 to valid types"); + Ok(()) +} + +#[test] +fn test_date_timestamp_arithmetic_error() -> Result<()> { + let (lhs, rhs) = BinaryTypeCoercer::new( + &DataType::Timestamp(TimeUnit::Nanosecond, None), + &Operator::Minus, + &DataType::Timestamp(TimeUnit::Millisecond, None), + ) + .get_input_types()?; + assert_eq!(lhs, DataType::Timestamp(TimeUnit::Millisecond, None)); + assert_eq!(rhs, DataType::Timestamp(TimeUnit::Millisecond, None)); + + let err = + BinaryTypeCoercer::new(&DataType::Date32, &Operator::Plus, &DataType::Date64) + .get_input_types() + .unwrap_err() + .to_string(); + + assert_contains!( + &err, + "Cannot get result type for temporal operation Date64 + Date64" + ); + + Ok(()) +} + +#[test] +fn test_decimal_mathematics_op_type() { + // Decimal32 + assert_eq!( + coerce_numeric_type_to_decimal32(&DataType::Int8).unwrap(), + DataType::Decimal32(3, 0) + ); + assert_eq!( + coerce_numeric_type_to_decimal32(&DataType::Int16).unwrap(), + DataType::Decimal32(5, 0) + ); + assert!(coerce_numeric_type_to_decimal32(&DataType::Int32).is_none()); + assert!(coerce_numeric_type_to_decimal32(&DataType::Int64).is_none(),); + assert_eq!( + coerce_numeric_type_to_decimal32(&DataType::Float16).unwrap(), + DataType::Decimal32(6, 3) + ); + assert!(coerce_numeric_type_to_decimal32(&DataType::Float32).is_none(),); + assert!(coerce_numeric_type_to_decimal32(&DataType::Float64).is_none()); + + // Decimal64 + assert_eq!( + coerce_numeric_type_to_decimal64(&DataType::Int8).unwrap(), + DataType::Decimal64(3, 0) + ); + assert_eq!( + coerce_numeric_type_to_decimal64(&DataType::Int16).unwrap(), + DataType::Decimal64(5, 0) + ); + assert_eq!( + coerce_numeric_type_to_decimal64(&DataType::Int32).unwrap(), + DataType::Decimal64(10, 0) + ); + assert!(coerce_numeric_type_to_decimal64(&DataType::Int64).is_none(),); + assert_eq!( + coerce_numeric_type_to_decimal64(&DataType::Float16).unwrap(), + DataType::Decimal64(6, 3) + ); + assert_eq!( + coerce_numeric_type_to_decimal64(&DataType::Float32).unwrap(), + DataType::Decimal64(14, 7) + ); + assert!(coerce_numeric_type_to_decimal64(&DataType::Float64).is_none()); + + // Decimal128 + assert_eq!( + coerce_numeric_type_to_decimal128(&DataType::Int8).unwrap(), + DataType::Decimal128(3, 0) + ); + assert_eq!( + coerce_numeric_type_to_decimal128(&DataType::Int16).unwrap(), + DataType::Decimal128(5, 0) + ); + assert_eq!( + coerce_numeric_type_to_decimal128(&DataType::Int32).unwrap(), + DataType::Decimal128(10, 0) + ); + assert_eq!( + coerce_numeric_type_to_decimal128(&DataType::Int64).unwrap(), + DataType::Decimal128(20, 0) + ); + assert_eq!( + coerce_numeric_type_to_decimal128(&DataType::Float16).unwrap(), + DataType::Decimal128(6, 3) + ); + assert_eq!( + coerce_numeric_type_to_decimal128(&DataType::Float32).unwrap(), + DataType::Decimal128(14, 7) + ); + assert_eq!( + coerce_numeric_type_to_decimal128(&DataType::Float64).unwrap(), + DataType::Decimal128(30, 15) + ); +} + +#[test] +fn test_type_coercion_arithmetic() -> Result<()> { + use DataType::*; + + // (Float64, _) | (_, Float64) => Some(Float64) + test_coercion_binary_rule_multiple!( + Float64, + [ + Float64, Float32, Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, + UInt8 + ], + Operator::Plus, + Float64 + ); + // (_, Float32) | (Float32, _) => Some(Float32) + test_coercion_binary_rule_multiple!( + Float32, + [Float32, Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8], + Operator::Plus, + Float32 + ); + // (_, Float16) | (Float16, _) => Some(Float16) + test_coercion_binary_rule_multiple!( + Float16, + [Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8], + Operator::Plus, + Float16 + ); + // (UInt64, Int64 | Int32 | Int16 | Int8) | (Int64 | Int32 | Int16 | Int8, UInt64) => Some(Decimal128(20, 0)) + test_coercion_binary_rule_multiple!( + UInt64, + [Int64, Int32, Int16, Int8], + Operator::Divide, + Decimal128(20, 0) + ); + // (UInt64, _) | (_, UInt64) => Some(UInt64) + test_coercion_binary_rule_multiple!( + UInt64, + [UInt64, UInt32, UInt16, UInt8], + Operator::Modulo, + UInt64 + ); + // (Int64, _) | (_, Int64) => Some(Int64) + test_coercion_binary_rule_multiple!( + Int64, + [Int64, Int32, UInt32, Int16, UInt16, Int8, UInt8], + Operator::Modulo, + Int64 + ); + // (UInt32, Int32 | Int16 | Int8) | (Int32 | Int16 | Int8, UInt32) => Some(Int64) + test_coercion_binary_rule_multiple!( + UInt32, + [Int32, Int16, Int8], + Operator::Modulo, + Int64 + ); + // (UInt32, _) | (_, UInt32) => Some(UInt32) + test_coercion_binary_rule_multiple!( + UInt32, + [UInt32, UInt16, UInt8], + Operator::Modulo, + UInt32 + ); + // (Int32, _) | (_, Int32) => Some(Int32) + test_coercion_binary_rule_multiple!( + Int32, + [Int32, Int16, Int8], + Operator::Modulo, + Int32 + ); + // (UInt16, Int16 | Int8) | (Int16 | Int8, UInt16) => Some(Int32) + test_coercion_binary_rule_multiple!(UInt16, [Int16, Int8], Operator::Minus, Int32); + // (UInt16, _) | (_, UInt16) => Some(UInt16) + test_coercion_binary_rule_multiple!( + UInt16, + [UInt16, UInt8, UInt8], + Operator::Plus, + UInt16 + ); + // (Int16, _) | (_, Int16) => Some(Int16) + test_coercion_binary_rule_multiple!(Int16, [Int16, Int8], Operator::Plus, Int16); + // (UInt8, Int8) | (Int8, UInt8) => Some(Int16) + test_coercion_binary_rule!(Int8, UInt8, Operator::Minus, Int16); + test_coercion_binary_rule!(UInt8, Int8, Operator::Multiply, Int16); + // (UInt8, _) | (_, UInt8) => Some(UInt8) + test_coercion_binary_rule!(UInt8, UInt8, Operator::Minus, UInt8); + // (Int8, _) | (_, Int8) => Some(Int8) + test_coercion_binary_rule!(Int8, Int8, Operator::Plus, Int8); + + Ok(()) +} + +fn test_math_decimal_coercion_rule( + lhs_type: DataType, + rhs_type: DataType, + expected_lhs_type: DataType, + expected_rhs_type: DataType, +) { + let (lhs_type, rhs_type) = math_decimal_coercion(&lhs_type, &rhs_type).unwrap(); + assert_eq!(lhs_type, expected_lhs_type); + assert_eq!(rhs_type, expected_rhs_type); +} + +#[test] +fn test_coercion_arithmetic_decimal() -> Result<()> { + test_math_decimal_coercion_rule( + DataType::Decimal128(10, 2), + DataType::Decimal128(10, 2), + DataType::Decimal128(10, 2), + DataType::Decimal128(10, 2), + ); + + test_math_decimal_coercion_rule( + DataType::Int32, + DataType::Decimal128(10, 2), + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), + ); + + test_math_decimal_coercion_rule( + DataType::Int32, + DataType::Decimal128(10, 2), + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), + ); + + test_math_decimal_coercion_rule( + DataType::Int32, + DataType::Decimal128(10, 2), + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), + ); + + test_math_decimal_coercion_rule( + DataType::Int32, + DataType::Decimal128(10, 2), + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), + ); + + test_math_decimal_coercion_rule( + DataType::Int32, + DataType::Decimal128(10, 2), + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), + ); + + test_math_decimal_coercion_rule( + DataType::UInt32, + DataType::Decimal128(10, 2), + DataType::Decimal128(10, 0), + DataType::Decimal128(10, 2), + ); + test_math_decimal_coercion_rule( + DataType::Decimal128(10, 2), + DataType::UInt32, + DataType::Decimal128(10, 2), + DataType::Decimal128(10, 0), + ); + + Ok(()) +} + +#[test] +fn test_coercion_arithmetic_decimal_cross_variant() -> Result<()> { + let test_cases = [ + ( + DataType::Decimal32(5, 2), + DataType::Decimal64(10, 3), + DataType::Decimal64(10, 3), + DataType::Decimal64(10, 3), + ), + ( + DataType::Decimal32(7, 1), + DataType::Decimal128(15, 4), + DataType::Decimal128(15, 4), + DataType::Decimal128(15, 4), + ), + ( + DataType::Decimal32(9, 0), + DataType::Decimal256(20, 5), + DataType::Decimal256(20, 5), + DataType::Decimal256(20, 5), + ), + ( + DataType::Decimal64(12, 3), + DataType::Decimal128(18, 2), + DataType::Decimal128(19, 3), + DataType::Decimal128(19, 3), + ), + ( + DataType::Decimal64(15, 4), + DataType::Decimal256(25, 6), + DataType::Decimal256(25, 6), + DataType::Decimal256(25, 6), + ), + ( + DataType::Decimal128(20, 5), + DataType::Decimal256(30, 8), + DataType::Decimal256(30, 8), + DataType::Decimal256(30, 8), + ), + // Reverse order cases + ( + DataType::Decimal64(10, 3), + DataType::Decimal32(5, 2), + DataType::Decimal64(10, 3), + DataType::Decimal64(10, 3), + ), + ( + DataType::Decimal128(15, 4), + DataType::Decimal32(7, 1), + DataType::Decimal128(15, 4), + DataType::Decimal128(15, 4), + ), + ( + DataType::Decimal256(20, 5), + DataType::Decimal32(9, 0), + DataType::Decimal256(20, 5), + DataType::Decimal256(20, 5), + ), + ( + DataType::Decimal128(18, 2), + DataType::Decimal64(12, 3), + DataType::Decimal128(19, 3), + DataType::Decimal128(19, 3), + ), + ( + DataType::Decimal256(25, 6), + DataType::Decimal64(15, 4), + DataType::Decimal256(25, 6), + DataType::Decimal256(25, 6), + ), + ( + DataType::Decimal256(30, 8), + DataType::Decimal128(20, 5), + DataType::Decimal256(30, 8), + DataType::Decimal256(30, 8), + ), + ]; + + for (lhs_type, rhs_type, expected_lhs_type, expected_rhs_type) in test_cases { + test_math_decimal_coercion_rule( + lhs_type, + rhs_type, + expected_lhs_type, + expected_rhs_type, + ); + } + + Ok(()) +} + +#[test] +fn test_decimal_precision_overflow_cross_variant() -> Result<()> { + // s = max(0, 1) = 1, range = max(76-0, 38-1) = 76, required_precision = 76 + 1 = 77 (overflow) + let result = get_wider_decimal_type_cross_variant( + &DataType::Decimal256(76, 0), + &DataType::Decimal128(38, 1), + ); + assert!(result.is_none()); + + // s = max(0, 10) = 10, range = max(9-0, 18-10) = 9, required_precision = 9 + 10 = 19 (overflow > 18) + let result = get_wider_decimal_type_cross_variant( + &DataType::Decimal32(9, 0), + &DataType::Decimal64(18, 10), + ); + assert!(result.is_none()); + + // s = max(5, 26) = 26, range = max(18-5, 38-26) = 13, required_precision = 13 + 26 = 39 (overflow > 38) + let result = get_wider_decimal_type_cross_variant( + &DataType::Decimal64(18, 5), + &DataType::Decimal128(38, 26), + ); + assert!(result.is_none()); + + // s = max(10, 49) = 49, range = max(38-10, 76-49) = 28, required_precision = 28 + 49 = 77 (overflow > 76) + let result = get_wider_decimal_type_cross_variant( + &DataType::Decimal128(38, 10), + &DataType::Decimal256(76, 49), + ); + assert!(result.is_none()); + + // s = max(2, 3) = 3, range = max(5-2, 10-3) = 7, required_precision = 7 + 3 = 10 (valid <= 18) + let result = get_wider_decimal_type_cross_variant( + &DataType::Decimal32(5, 2), + &DataType::Decimal64(10, 3), + ); + assert!(result.is_some()); + + Ok(()) +} diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs b/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs new file mode 100644 index 0000000000000..5401264e43e39 --- /dev/null +++ b/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs @@ -0,0 +1,787 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +#[test] +fn test_decimal_binary_comparison_coercion() -> Result<()> { + let input_decimal = DataType::Decimal128(20, 3); + let input_types = [ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + DataType::Decimal128(38, 10), + DataType::Decimal128(20, 8), + DataType::Null, + ]; + let result_types = [ + DataType::Decimal128(20, 3), + DataType::Decimal128(20, 3), + DataType::Decimal128(20, 3), + DataType::Decimal128(23, 3), + DataType::Decimal128(24, 7), + DataType::Decimal128(32, 15), + DataType::Decimal128(38, 10), + DataType::Decimal128(25, 8), + DataType::Decimal128(20, 3), + ]; + let comparison_op_types = [ + Operator::NotEq, + Operator::Eq, + Operator::Gt, + Operator::GtEq, + Operator::Lt, + Operator::LtEq, + ]; + for (i, input_type) in input_types.iter().enumerate() { + let expect_type = &result_types[i]; + for op in comparison_op_types { + let (lhs, rhs) = BinaryTypeCoercer::new(&input_decimal, &op, input_type) + .get_input_types()?; + assert_eq!(expect_type, &lhs); + assert_eq!(expect_type, &rhs); + } + } + // negative test + let result_type = + BinaryTypeCoercer::new(&input_decimal, &Operator::Eq, &DataType::Boolean) + .get_input_types(); + assert!(result_type.is_err()); + Ok(()) +} + +#[test] +fn test_like_coercion() { + // string coerce to strings + test_like_rule!(DataType::Utf8, DataType::Utf8, Some(DataType::Utf8)); + test_like_rule!( + DataType::LargeUtf8, + DataType::Utf8, + Some(DataType::LargeUtf8) + ); + test_like_rule!( + DataType::Utf8, + DataType::LargeUtf8, + Some(DataType::LargeUtf8) + ); + test_like_rule!( + DataType::LargeUtf8, + DataType::LargeUtf8, + Some(DataType::LargeUtf8) + ); + + // Also coerce binary to strings + test_like_rule!(DataType::Binary, DataType::Utf8, Some(DataType::Utf8)); + test_like_rule!( + DataType::LargeBinary, + DataType::Utf8, + Some(DataType::LargeUtf8) + ); + test_like_rule!( + DataType::Binary, + DataType::LargeUtf8, + Some(DataType::LargeUtf8) + ); + test_like_rule!( + DataType::LargeBinary, + DataType::LargeUtf8, + Some(DataType::LargeUtf8) + ); +} + +#[test] +fn test_type_coercion() -> Result<()> { + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Date32, + Operator::Eq, + DataType::Date32 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Date64, + Operator::Lt, + DataType::Date64 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Time32(TimeUnit::Second), + Operator::Eq, + DataType::Time32(TimeUnit::Second) + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Time32(TimeUnit::Millisecond), + Operator::Eq, + DataType::Time32(TimeUnit::Millisecond) + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Time64(TimeUnit::Microsecond), + Operator::Eq, + DataType::Time64(TimeUnit::Microsecond) + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Time64(TimeUnit::Nanosecond), + Operator::Eq, + DataType::Time64(TimeUnit::Nanosecond) + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Timestamp(TimeUnit::Second, None), + Operator::Lt, + DataType::Timestamp(TimeUnit::Nanosecond, None) + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Timestamp(TimeUnit::Millisecond, None), + Operator::Lt, + DataType::Timestamp(TimeUnit::Nanosecond, None) + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Timestamp(TimeUnit::Microsecond, None), + Operator::Lt, + DataType::Timestamp(TimeUnit::Nanosecond, None) + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Timestamp(TimeUnit::Nanosecond, None), + Operator::Lt, + DataType::Timestamp(TimeUnit::Nanosecond, None) + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::RegexMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8View, + Operator::RegexMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Utf8View, + DataType::Utf8, + Operator::RegexMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Utf8View, + DataType::Utf8View, + Operator::RegexMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::RegexNotMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Utf8View, + DataType::Utf8, + Operator::RegexNotMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8View, + Operator::RegexNotMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Utf8View, + DataType::Utf8View, + Operator::RegexNotMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::RegexNotIMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Utf8View, + DataType::Utf8, + Operator::RegexNotIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8View, + Operator::RegexNotIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Utf8View, + DataType::Utf8View, + Operator::RegexNotIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), + DataType::Utf8, + Operator::RegexMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), + DataType::Utf8View, + Operator::RegexMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + DataType::Utf8, + Operator::RegexMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + DataType::Utf8View, + Operator::RegexMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), + DataType::Utf8, + Operator::RegexIMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + DataType::Utf8, + Operator::RegexIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), + DataType::Utf8View, + Operator::RegexIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + DataType::Utf8View, + Operator::RegexIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), + DataType::Utf8, + Operator::RegexNotMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), + DataType::Utf8View, + Operator::RegexNotMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + DataType::Utf8, + Operator::RegexNotMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), + DataType::Utf8View, + Operator::RegexNotMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), + DataType::Utf8, + Operator::RegexNotIMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + DataType::Utf8, + Operator::RegexNotIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), + DataType::Utf8View, + Operator::RegexNotIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + DataType::Utf8View, + Operator::RegexNotIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Int16, + DataType::Int64, + Operator::BitwiseAnd, + DataType::Int64 + ); + test_coercion_binary_rule!( + DataType::UInt64, + DataType::UInt64, + Operator::BitwiseAnd, + DataType::UInt64 + ); + test_coercion_binary_rule!( + DataType::Int8, + DataType::UInt32, + Operator::BitwiseAnd, + DataType::Int64 + ); + test_coercion_binary_rule!( + DataType::UInt32, + DataType::Int32, + Operator::BitwiseAnd, + DataType::Int64 + ); + test_coercion_binary_rule!( + DataType::UInt16, + DataType::Int16, + Operator::BitwiseAnd, + DataType::Int32 + ); + test_coercion_binary_rule!( + DataType::UInt32, + DataType::UInt32, + Operator::BitwiseAnd, + DataType::UInt32 + ); + test_coercion_binary_rule!( + DataType::UInt16, + DataType::UInt32, + Operator::BitwiseAnd, + DataType::UInt32 + ); + Ok(()) +} + +#[test] +fn test_type_coercion_compare() -> Result<()> { + // boolean + test_coercion_binary_rule!( + DataType::Boolean, + DataType::Boolean, + Operator::Eq, + DataType::Boolean + ); + // float + test_coercion_binary_rule!( + DataType::Float16, + DataType::Int64, + Operator::Eq, + DataType::Float16 + ); + test_coercion_binary_rule!( + DataType::Float16, + DataType::Float64, + Operator::Eq, + DataType::Float64 + ); + test_coercion_binary_rule!( + DataType::Float32, + DataType::Int64, + Operator::Eq, + DataType::Float32 + ); + test_coercion_binary_rule!( + DataType::Float32, + DataType::Float64, + Operator::GtEq, + DataType::Float64 + ); + // signed integer + test_coercion_binary_rule!( + DataType::Int8, + DataType::Int32, + Operator::LtEq, + DataType::Int32 + ); + test_coercion_binary_rule!( + DataType::Int64, + DataType::Int32, + Operator::LtEq, + DataType::Int64 + ); + // unsigned integer + test_coercion_binary_rule!( + DataType::UInt32, + DataType::UInt8, + Operator::Gt, + DataType::UInt32 + ); + test_coercion_binary_rule!( + DataType::UInt64, + DataType::UInt8, + Operator::Eq, + DataType::UInt64 + ); + test_coercion_binary_rule!( + DataType::UInt64, + DataType::Int64, + Operator::Eq, + DataType::Decimal128(20, 0) + ); + // numeric/decimal + test_coercion_binary_rule!( + DataType::Int64, + DataType::Decimal128(10, 0), + Operator::Eq, + DataType::Decimal128(20, 0) + ); + test_coercion_binary_rule!( + DataType::Int64, + DataType::Decimal128(10, 2), + Operator::Lt, + DataType::Decimal128(22, 2) + ); + test_coercion_binary_rule!( + DataType::Float64, + DataType::Decimal128(10, 3), + Operator::Gt, + DataType::Decimal128(30, 15) + ); + test_coercion_binary_rule!( + DataType::Int64, + DataType::Decimal128(10, 0), + Operator::Eq, + DataType::Decimal128(20, 0) + ); + test_coercion_binary_rule!( + DataType::Decimal128(14, 2), + DataType::Decimal128(10, 3), + Operator::GtEq, + DataType::Decimal128(15, 3) + ); + test_coercion_binary_rule!( + DataType::UInt64, + DataType::Decimal128(20, 0), + Operator::Eq, + DataType::Decimal128(20, 0) + ); + + // Binary + test_coercion_binary_rule!( + DataType::Binary, + DataType::Binary, + Operator::Eq, + DataType::Binary + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Binary, + Operator::Eq, + DataType::Binary + ); + test_coercion_binary_rule!( + DataType::Binary, + DataType::Utf8, + Operator::Eq, + DataType::Binary + ); + + // LargeBinary + test_coercion_binary_rule!( + DataType::LargeBinary, + DataType::LargeBinary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::Binary, + DataType::LargeBinary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::LargeBinary, + DataType::Binary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::LargeBinary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::LargeBinary, + DataType::Utf8, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::LargeUtf8, + DataType::LargeBinary, + Operator::Eq, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::LargeBinary, + DataType::LargeUtf8, + Operator::Eq, + DataType::LargeBinary + ); + + // Timestamps + let utc: Option> = Some("UTC".into()); + test_coercion_binary_rule!( + DataType::Timestamp(TimeUnit::Second, utc.clone()), + DataType::Timestamp(TimeUnit::Second, utc.clone()), + Operator::Eq, + DataType::Timestamp(TimeUnit::Second, utc.clone()) + ); + test_coercion_binary_rule!( + DataType::Timestamp(TimeUnit::Second, utc.clone()), + DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), + Operator::Eq, + DataType::Timestamp(TimeUnit::Second, utc.clone()) + ); + test_coercion_binary_rule!( + DataType::Timestamp(TimeUnit::Second, Some("America/New_York".into())), + DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), + Operator::Eq, + DataType::Timestamp(TimeUnit::Second, Some("America/New_York".into())) + ); + test_coercion_binary_rule!( + DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), + DataType::Timestamp(TimeUnit::Second, utc), + Operator::Eq, + DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())) + ); + + // list + let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true)); + test_coercion_binary_rule!( + DataType::List(Arc::clone(&inner_field)), + DataType::List(Arc::clone(&inner_field)), + Operator::Eq, + DataType::List(Arc::clone(&inner_field)) + ); + test_coercion_binary_rule!( + DataType::List(Arc::clone(&inner_field)), + DataType::LargeList(Arc::clone(&inner_field)), + Operator::Eq, + DataType::LargeList(Arc::clone(&inner_field)) + ); + test_coercion_binary_rule!( + DataType::LargeList(Arc::clone(&inner_field)), + DataType::List(Arc::clone(&inner_field)), + Operator::Eq, + DataType::LargeList(Arc::clone(&inner_field)) + ); + test_coercion_binary_rule!( + DataType::LargeList(Arc::clone(&inner_field)), + DataType::LargeList(Arc::clone(&inner_field)), + Operator::Eq, + DataType::LargeList(Arc::clone(&inner_field)) + ); + test_coercion_binary_rule!( + DataType::FixedSizeList(Arc::clone(&inner_field), 10), + DataType::FixedSizeList(Arc::clone(&inner_field), 10), + Operator::Eq, + DataType::FixedSizeList(Arc::clone(&inner_field), 10) + ); + test_coercion_binary_rule!( + DataType::FixedSizeList(Arc::clone(&inner_field), 10), + DataType::LargeList(Arc::clone(&inner_field)), + Operator::Eq, + DataType::LargeList(Arc::clone(&inner_field)) + ); + test_coercion_binary_rule!( + DataType::LargeList(Arc::clone(&inner_field)), + DataType::FixedSizeList(Arc::clone(&inner_field), 10), + Operator::Eq, + DataType::LargeList(Arc::clone(&inner_field)) + ); + test_coercion_binary_rule!( + DataType::List(Arc::clone(&inner_field)), + DataType::FixedSizeList(Arc::clone(&inner_field), 10), + Operator::Eq, + DataType::List(Arc::clone(&inner_field)) + ); + test_coercion_binary_rule!( + DataType::FixedSizeList(Arc::clone(&inner_field), 10), + DataType::List(Arc::clone(&inner_field)), + Operator::Eq, + DataType::List(Arc::clone(&inner_field)) + ); + + let inner_timestamp_field = Arc::new(Field::new_list_field( + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )); + let result_type = BinaryTypeCoercer::new( + &DataType::List(Arc::clone(&inner_field)), + &Operator::Eq, + &DataType::List(Arc::clone(&inner_timestamp_field)), + ) + .get_input_types(); + assert!(result_type.is_err()); + + Ok(()) +} + +#[test] +fn test_list_coercion() { + let lhs_type = DataType::List(Arc::new(Field::new("lhs", DataType::Int8, false))); + + let rhs_type = DataType::List(Arc::new(Field::new("rhs", DataType::Int64, true))); + + let coerced_type = list_coercion(&lhs_type, &rhs_type).unwrap(); + assert_eq!( + coerced_type, + DataType::List(Arc::new(Field::new("lhs", DataType::Int64, true))) + ); +} + +#[test] +fn test_map_coercion() -> Result<()> { + let lhs = Field::new_map( + "lhs", + "entries", + Arc::new(Field::new("keys", DataType::Utf8, false)), + Arc::new(Field::new("values", DataType::LargeUtf8, false)), + true, + false, + ); + let rhs = Field::new_map( + "rhs", + "kvp", + Arc::new(Field::new("k", DataType::Utf8, false)), + Arc::new(Field::new("v", DataType::Utf8, true)), + false, + true, + ); + + let expected = Field::new_map( + "expected", + "entries", + Arc::new(Field::new("keys", DataType::Utf8, false)), + Arc::new(Field::new("values", DataType::LargeUtf8, true)), + false, + true, + ); + + test_coercion_binary_rule!( + lhs.data_type(), + rhs.data_type(), + Operator::Eq, + expected.data_type().clone() + ); + Ok(()) +} + +#[test] +fn test_decimal_cross_variant_comparison_coercion() -> Result<()> { + let test_cases = [ + // (lhs, rhs, expected_result) + ( + DataType::Decimal32(5, 2), + DataType::Decimal64(10, 3), + DataType::Decimal64(10, 3), + ), + ( + DataType::Decimal32(7, 1), + DataType::Decimal128(15, 4), + DataType::Decimal128(15, 4), + ), + ( + DataType::Decimal32(9, 0), + DataType::Decimal256(20, 5), + DataType::Decimal256(20, 5), + ), + ( + DataType::Decimal64(12, 3), + DataType::Decimal128(18, 2), + DataType::Decimal128(19, 3), + ), + ( + DataType::Decimal64(15, 4), + DataType::Decimal256(25, 6), + DataType::Decimal256(25, 6), + ), + ( + DataType::Decimal128(20, 5), + DataType::Decimal256(30, 8), + DataType::Decimal256(30, 8), + ), + // Reverse order cases + ( + DataType::Decimal64(10, 3), + DataType::Decimal32(5, 2), + DataType::Decimal64(10, 3), + ), + ( + DataType::Decimal128(15, 4), + DataType::Decimal32(7, 1), + DataType::Decimal128(15, 4), + ), + ( + DataType::Decimal256(20, 5), + DataType::Decimal32(9, 0), + DataType::Decimal256(20, 5), + ), + ( + DataType::Decimal128(18, 2), + DataType::Decimal64(12, 3), + DataType::Decimal128(19, 3), + ), + ( + DataType::Decimal256(25, 6), + DataType::Decimal64(15, 4), + DataType::Decimal256(25, 6), + ), + ( + DataType::Decimal256(30, 8), + DataType::Decimal128(20, 5), + DataType::Decimal256(30, 8), + ), + ]; + + let comparison_op_types = [ + Operator::NotEq, + Operator::Eq, + Operator::Gt, + Operator::GtEq, + Operator::Lt, + Operator::LtEq, + ]; + + for (lhs_type, rhs_type, expected_type) in test_cases { + for op in comparison_op_types { + let (lhs, rhs) = + BinaryTypeCoercer::new(&lhs_type, &op, &rhs_type).get_input_types()?; + assert_eq!(expected_type, lhs, "Coercion of type {lhs_type:?} with {rhs_type:?} resulted in unexpected type: {lhs:?}"); + assert_eq!(expected_type, rhs, "Coercion of type {rhs_type:?} with {lhs_type:?} resulted in unexpected type: {rhs:?}"); + } + } + + Ok(()) +} diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/dictionary.rs b/datafusion/expr-common/src/type_coercion/binary/tests/dictionary.rs new file mode 100644 index 0000000000000..0fb56a4a2c536 --- /dev/null +++ b/datafusion/expr-common/src/type_coercion/binary/tests/dictionary.rs @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +#[test] +fn test_dictionary_type_coercion() { + use DataType::*; + + let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32)); + let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + Some(Int32) + ); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, false), + Some(Int32) + ); + + // Since we can coerce values of Int16 to Utf8 can support this + let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + Some(Utf8) + ); + + // Since we can coerce values of Utf8 to Binary can support this + let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let rhs_type = Dictionary(Box::new(Int8), Box::new(Binary)); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + Some(Binary) + ); + + let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let rhs_type = Utf8; + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, false), + Some(Utf8) + ); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + Some(lhs_type.clone()) + ); + + let lhs_type = Utf8; + let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, false), + Some(Utf8) + ); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + Some(rhs_type.clone()) + ); +} diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/mod.rs b/datafusion/expr-common/src/type_coercion/binary/tests/mod.rs new file mode 100644 index 0000000000000..6d21d795e4b72 --- /dev/null +++ b/datafusion/expr-common/src/type_coercion/binary/tests/mod.rs @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +// Common test macros + +/// Tests that coercion for a binary operator between two types yields the expected result type for both sides. +/// +/// Usage: test_coercion_binary_rule!(lhs_type, rhs_type, op, expected_type) +/// - lhs_type: The left-hand side data type +/// - rhs_type: The right-hand side data type +/// - op: The binary operator (e.g., "+", "-", etc.) +/// - expected_type: The type both sides should be coerced to +macro_rules! test_coercion_binary_rule { + ($LHS_TYPE:expr, $RHS_TYPE:expr, $OP:expr, $RESULT_TYPE:expr) => {{ + let (lhs, rhs) = + BinaryTypeCoercer::new(&$LHS_TYPE, &$OP, &$RHS_TYPE).get_input_types()?; + assert_eq!(lhs, $RESULT_TYPE); + assert_eq!(rhs, $RESULT_TYPE); + }}; +} + +/// Tests that coercion for a binary operator between one type and multiple right-hand side types +/// yields the expected result type for both sides, in both lhs/rhs and rhs/lhs order. +/// +/// Usage: test_coercion_binary_rule_multiple!(lhs_type, rhs_types, op, expected_type) +/// - lhs_type: The left-hand side data type +/// - rhs_types: An iterable of right-hand side data types +/// - op: The binary operator +/// - expected_type: The type both sides should be coerced to +macro_rules! test_coercion_binary_rule_multiple { + ($LHS_TYPE:expr, $RHS_TYPES:expr, $OP:expr, $RESULT_TYPE:expr) => {{ + for rh_type in $RHS_TYPES { + let (lhs, rhs) = + BinaryTypeCoercer::new(&$LHS_TYPE, &$OP, &rh_type).get_input_types()?; + assert_eq!(lhs, $RESULT_TYPE); + assert_eq!(rhs, $RESULT_TYPE); + + BinaryTypeCoercer::new(&rh_type, &$OP, &$LHS_TYPE).get_input_types()?; + assert_eq!(lhs, $RESULT_TYPE); + assert_eq!(rhs, $RESULT_TYPE); + } + }}; +} + +/// Tests that the like_coercion function returns the expected result type for both lhs/rhs and rhs/lhs order. +/// +/// Usage: test_like_rule!(lhs_type, rhs_type, expected_type) +/// - lhs_type: The left-hand side data type +/// - rhs_type: The right-hand side data type +/// - expected_type: The expected result type from like_coercion +macro_rules! test_like_rule { + ($LHS_TYPE:expr, $RHS_TYPE:expr, $RESULT_TYPE:expr) => {{ + let result = like_coercion(&$LHS_TYPE, &$RHS_TYPE); + assert_eq!(result, $RESULT_TYPE); + let result = like_coercion(&$RHS_TYPE, &$LHS_TYPE); + assert_eq!(result, $RESULT_TYPE); + }}; +} + +mod arithmetic; +mod comparison; +mod dictionary; +mod null_coercion; diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/null_coercion.rs b/datafusion/expr-common/src/type_coercion/binary/tests/null_coercion.rs new file mode 100644 index 0000000000000..91c826b563c7c --- /dev/null +++ b/datafusion/expr-common/src/type_coercion/binary/tests/null_coercion.rs @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::*; + +#[test] +fn test_type_coercion_logical_op() -> Result<()> { + test_coercion_binary_rule!( + DataType::Boolean, + DataType::Boolean, + Operator::And, + DataType::Boolean + ); + + test_coercion_binary_rule!( + DataType::Boolean, + DataType::Boolean, + Operator::Or, + DataType::Boolean + ); + test_coercion_binary_rule!( + DataType::Boolean, + DataType::Null, + Operator::And, + DataType::Boolean + ); + test_coercion_binary_rule!( + DataType::Boolean, + DataType::Null, + Operator::Or, + DataType::Boolean + ); + test_coercion_binary_rule!( + DataType::Null, + DataType::Null, + Operator::Or, + DataType::Boolean + ); + test_coercion_binary_rule!( + DataType::Null, + DataType::Null, + Operator::And, + DataType::Boolean + ); + test_coercion_binary_rule!( + DataType::Null, + DataType::Boolean, + Operator::And, + DataType::Boolean + ); + test_coercion_binary_rule!( + DataType::Null, + DataType::Boolean, + Operator::Or, + DataType::Boolean + ); + Ok(()) +} diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 37e1ed1936fb4..e6b2734cfff34 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -38,23 +38,28 @@ workspace = true name = "datafusion_expr" [features] +default = ["sql"] recursive_protection = ["dep:recursive"] +sql = ["sqlparser"] [dependencies] arrow = { workspace = true } +async-trait = { workspace = true } chrono = { workspace = true } -datafusion-common = { workspace = true } +datafusion-common = { workspace = true, default-features = false } datafusion-doc = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } indexmap = { workspace = true } +itertools = { workspace = true } paste = "^1.0" recursive = { workspace = true, optional = true } serde_json = { workspace = true } -sqlparser = { workspace = true } +sqlparser = { workspace = true, optional = true } [dev-dependencies] ctor = { workspace = true } env_logger = { workspace = true } +insta = { workspace = true } diff --git a/datafusion/expr/README.md b/datafusion/expr/README.md index b086f930e871b..b3ab9a383dbbd 100644 --- a/datafusion/expr/README.md +++ b/datafusion/expr/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion Logical Plan and Expressions +# Apache DataFusion Logical Plan and Expressions -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion that provides data types and utilities for logical plans and expressions. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs new file mode 100644 index 0000000000000..561ef1dc15e7d --- /dev/null +++ b/datafusion/expr/src/async_udf.rs @@ -0,0 +1,260 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +use arrow::datatypes::{DataType, FieldRef}; +use async_trait::async_trait; +use datafusion_common::error::Result; +use datafusion_common::internal_err; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::signature::Signature; +use std::any::Any; +use std::fmt::{Debug, Display}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +/// A scalar UDF that can invoke using async methods +/// +/// Note this is less efficient than the ScalarUDFImpl, but it can be used +/// to register remote functions in the context. +/// +/// The name is chosen to mirror ScalarUDFImpl +#[async_trait] +pub trait AsyncScalarUDFImpl: ScalarUDFImpl { + /// The ideal batch size for this function. + /// + /// This is used to determine what size of data to be evaluated at once. + /// If None, the whole batch will be evaluated at once. + fn ideal_batch_size(&self) -> Option { + None + } + + /// Invoke the function asynchronously with the async arguments + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result; +} + +/// A scalar UDF that must be invoked using async methods +/// +/// Note this is not meant to be used directly, but is meant to be an implementation detail +/// for AsyncUDFImpl. +#[derive(Debug)] +pub struct AsyncScalarUDF { + inner: Arc, +} + +impl PartialEq for AsyncScalarUDF { + fn eq(&self, other: &Self) -> bool { + // Deconstruct to catch any new fields added in future + let Self { inner } = self; + inner.dyn_eq(other.inner.as_any()) + } +} +impl Eq for AsyncScalarUDF {} + +impl Hash for AsyncScalarUDF { + fn hash(&self, state: &mut H) { + // Deconstruct to catch any new fields added in future + let Self { inner } = self; + inner.dyn_hash(state); + } +} + +impl AsyncScalarUDF { + pub fn new(inner: Arc) -> Self { + Self { inner } + } + + /// The ideal batch size for this function + pub fn ideal_batch_size(&self) -> Option { + self.inner.ideal_batch_size() + } + + /// Turn this AsyncUDF into a ScalarUDF, suitable for + /// registering in the context + pub fn into_scalar_udf(self) -> ScalarUDF { + ScalarUDF::new_from_impl(self) + } + + /// Invoke the function asynchronously with the async arguments + pub async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + self.inner.invoke_async_with_args(args).await + } +} + +impl ScalarUDFImpl for AsyncScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.inner.name() + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner.return_type(arg_types) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.inner.return_field_from_args(args) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("async functions should not be called directly") + } +} + +impl Display for AsyncScalarUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "AsyncScalarUDF: {}", self.inner.name()) + } +} + +#[cfg(test)] +mod tests { + use std::{ + hash::{DefaultHasher, Hash, Hasher}, + sync::Arc, + }; + + use arrow::datatypes::DataType; + use async_trait::async_trait; + use datafusion_common::error::Result; + use datafusion_expr_common::{columnar_value::ColumnarValue, signature::Signature}; + + use crate::{ + async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}, + ScalarFunctionArgs, ScalarUDFImpl, + }; + + #[derive(Debug, PartialEq, Eq, Hash, Clone)] + struct TestAsyncUDFImpl1 { + a: i32, + } + + impl ScalarUDFImpl for TestAsyncUDFImpl1 { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + todo!() + } + + fn signature(&self) -> &Signature { + todo!() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + todo!() + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + todo!() + } + } + + #[async_trait] + impl AsyncScalarUDFImpl for TestAsyncUDFImpl1 { + async fn invoke_async_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> Result { + todo!() + } + } + + #[derive(Debug, PartialEq, Eq, Hash, Clone)] + struct TestAsyncUDFImpl2 { + a: i32, + } + + impl ScalarUDFImpl for TestAsyncUDFImpl2 { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + todo!() + } + + fn signature(&self) -> &Signature { + todo!() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + todo!() + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + todo!() + } + } + + #[async_trait] + impl AsyncScalarUDFImpl for TestAsyncUDFImpl2 { + async fn invoke_async_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> Result { + todo!() + } + } + + fn hash(value: &T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } + + #[test] + fn test_async_udf_partial_eq_and_hash() { + // Inner is same cloned arc -> equal + let inner = Arc::new(TestAsyncUDFImpl1 { a: 1 }); + let a = AsyncScalarUDF::new(Arc::clone(&inner) as Arc); + let b = AsyncScalarUDF::new(inner); + assert_eq!(a, b); + assert_eq!(hash(&a), hash(&b)); + + // Inner is distinct arc -> still equal + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + assert_eq!(a, b); + assert_eq!(hash(&a), hash(&b)); + + // Negative case: inner is different value -> not equal + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 2 })); + assert_ne!(a, b); + assert_ne!(hash(&a), hash(&b)); + + // Negative case: different functions -> not equal + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl2 { a: 1 })); + assert_ne!(a, b); + assert_ne!(hash(&a), hash(&b)); + } +} diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index 9cb51612d0cab..d02f522910c19 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -20,8 +20,10 @@ use crate::expr::Case; use crate::{expr_schema::ExprSchemable, Expr}; use arrow::datatypes::DataType; use datafusion_common::{plan_err, DFSchema, HashSet, Result}; +use itertools::Itertools as _; /// Helper struct for building [Expr::Case] +#[derive(Debug, Clone)] pub struct CaseBuilder { expr: Option>, when_expr: Vec, @@ -72,7 +74,7 @@ impl CaseBuilder { let then_types: Vec = then_expr .iter() .map(|e| match e { - Expr::Literal(_) => e.get_type(&DFSchema::empty()), + Expr::Literal(_, _) => e.get_type(&DFSchema::empty()), _ => Ok(DataType::Null), }) .collect::>>()?; @@ -81,9 +83,12 @@ impl CaseBuilder { // Cannot verify types until execution type } else { let unique_types: HashSet<&DataType> = then_types.iter().collect(); - if unique_types.len() != 1 { + if unique_types.is_empty() { + return plan_err!("CASE expression 'then' values had no data types"); + } else if unique_types.len() != 1 { return plan_err!( - "CASE expression 'then' values had multiple data types: {unique_types:?}" + "CASE expression 'then' values had multiple data types: {}", + unique_types.iter().join(", ") ); } } diff --git a/datafusion/expr/src/execution_props.rs b/datafusion/expr/src/execution_props.rs index d672bd1acc460..d8a8c6bb49e19 100644 --- a/datafusion/expr/src/execution_props.rs +++ b/datafusion/expr/src/execution_props.rs @@ -18,6 +18,7 @@ use crate::var_provider::{VarProvider, VarType}; use chrono::{DateTime, TimeZone, Utc}; use datafusion_common::alias::AliasGenerator; +use datafusion_common::config::ConfigOptions; use datafusion_common::HashMap; use std::sync::Arc; @@ -35,6 +36,8 @@ pub struct ExecutionProps { pub query_execution_start_time: DateTime, /// Alias generator used by subquery optimizer rules pub alias_generator: Arc, + /// Snapshot of config options when the query started + pub config_options: Option>, /// Providers for scalar variables pub var_providers: Option>>, } @@ -53,6 +56,7 @@ impl ExecutionProps { // not being updated / propagated correctly query_execution_start_time: Utc.timestamp_nanos(0), alias_generator: Arc::new(AliasGenerator::new()), + config_options: None, var_providers: None, } } @@ -66,11 +70,18 @@ impl ExecutionProps { self } + #[deprecated(since = "50.0.0", note = "Use mark_start_execution instead")] + pub fn start_execution(&mut self) -> &Self { + let default_config = Arc::new(ConfigOptions::default()); + self.mark_start_execution(default_config) + } + /// Marks the execution of query started timestamp. /// This also instantiates a new alias generator. - pub fn start_execution(&mut self) -> &Self { + pub fn mark_start_execution(&mut self, config_options: Arc) -> &Self { self.query_execution_start_time = Utc::now(); self.alias_generator = Arc::new(AliasGenerator::new()); + self.config_options = Some(config_options); &*self } @@ -99,6 +110,12 @@ impl ExecutionProps { .as_ref() .and_then(|var_providers| var_providers.get(&var_type).cloned()) } + + /// Returns the configuration properties for this execution + /// if the execution has started + pub fn config_options(&self) -> Option<&Arc> { + self.config_options.as_ref() + } } #[cfg(test)] @@ -107,6 +124,6 @@ mod test { #[test] fn debug() { let props = ExecutionProps::new(); - assert_eq!("ExecutionProps { query_execution_start_time: 1970-01-01T00:00:00Z, alias_generator: AliasGenerator { next_id: 1 }, var_providers: None }", format!("{props:?}")); + assert_eq!("ExecutionProps { query_execution_start_time: 1970-01-01T00:00:00Z, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None }", format!("{props:?}")); } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 9f6855b698243..282b3f6a0f55c 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,18 +17,20 @@ //! Logical Expressions: [`Expr`] -use std::collections::HashSet; +use std::cmp::Ordering; +use std::collections::{BTreeMap, HashSet}; use std::fmt::{self, Display, Formatter, Write}; use std::hash::{Hash, Hasher}; use std::mem; use std::sync::Arc; use crate::expr_fn::binary_expr; +use crate::function::WindowFunctionSimplification; use crate::logical_plan::Subquery; -use crate::Volatility; -use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; +use crate::{AggregateUDF, Volatility}; +use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; -use arrow::datatypes::{DataType, FieldRef}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, @@ -37,11 +39,39 @@ use datafusion_common::{ Column, DFSchema, HashMap, Result, ScalarValue, Spans, TableReference, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; +#[cfg(feature = "sql")] use sqlparser::ast::{ display_comma_separated, ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem, - NullTreatment, RenameSelectItem, ReplaceSelectElement, + RenameSelectItem, ReplaceSelectElement, }; +// This mirrors sqlparser::ast::NullTreatment but we need our own variant +// for when the sql feature is disabled. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd)] +pub enum NullTreatment { + IgnoreNulls, + RespectNulls, +} + +impl Display for NullTreatment { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str(match self { + NullTreatment::IgnoreNulls => "IGNORE NULLS", + NullTreatment::RespectNulls => "RESPECT NULLS", + }) + } +} + +#[cfg(feature = "sql")] +impl From for NullTreatment { + fn from(value: sqlparser::ast::NullTreatment) -> Self { + match value { + sqlparser::ast::NullTreatment::IgnoreNulls => Self::IgnoreNulls, + sqlparser::ast::NullTreatment::RespectNulls => Self::RespectNulls, + } + } +} + /// Represents logical expressions such as `A + 1`, or `CAST(c1 AS int)`. /// /// For example the expression `A + 1` will be represented as @@ -50,7 +80,7 @@ use sqlparser::ast::{ /// BinaryExpr { /// left: Expr::Column("A"), /// op: Operator::Plus, -/// right: Expr::Literal(ScalarValue::Int32(Some(1))) +/// right: Expr::Literal(ScalarValue::Int32(Some(1)), None) /// } /// ``` /// @@ -112,10 +142,10 @@ use sqlparser::ast::{ /// # use datafusion_expr::{lit, col, Expr}; /// // All literals are strongly typed in DataFusion. To make an `i64` 42: /// let expr = lit(42i64); -/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)))); -/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)))); +/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)), None)); +/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)), None)); /// // To make a (typed) NULL: -/// let expr = Expr::Literal(ScalarValue::Int64(None)); +/// let expr = Expr::Literal(ScalarValue::Int64(None), None); /// // to make an (untyped) NULL (the optimizer will coerce this to the correct type): /// let expr = lit(ScalarValue::Null); /// ``` @@ -149,7 +179,7 @@ use sqlparser::ast::{ /// if let Expr::BinaryExpr(binary_expr) = expr { /// assert_eq!(*binary_expr.left, col("c1")); /// let scalar = ScalarValue::Int32(Some(42)); -/// assert_eq!(*binary_expr.right, Expr::Literal(scalar)); +/// assert_eq!(*binary_expr.right, Expr::Literal(scalar, None)); /// assert_eq!(binary_expr.op, Operator::Eq); /// } /// ``` @@ -193,7 +223,7 @@ use sqlparser::ast::{ /// ``` /// # use datafusion_expr::{lit, col}; /// let expr = col("c1") + lit(42); -/// assert_eq!(format!("{expr:?}"), "BinaryExpr(BinaryExpr { left: Column(Column { relation: None, name: \"c1\" }), op: Plus, right: Literal(Int32(42)) })"); +/// assert_eq!(format!("{expr:?}"), "BinaryExpr(BinaryExpr { left: Column(Column { relation: None, name: \"c1\" }), op: Plus, right: Literal(Int32(42), None) })"); /// ``` /// /// ## Use the `Display` trait (detailed expression) @@ -239,7 +269,7 @@ use sqlparser::ast::{ /// let mut scalars = HashSet::new(); /// // apply recursively visits all nodes in the expression tree /// expr.apply(|e| { -/// if let Expr::Literal(scalar) = e { +/// if let Expr::Literal(scalar, _) = e { /// scalars.insert(scalar); /// } /// // The return value controls whether to continue visiting the tree @@ -274,7 +304,7 @@ use sqlparser::ast::{ /// assert!(rewritten.transformed); /// // to 42 = 5 AND b = 6 /// assert_eq!(rewritten.data, lit(42).eq(lit(5)).and(col("b").eq(lit(6)))); -#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +#[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] pub enum Expr { /// An expression with a specific name. Alias(Alias), @@ -282,8 +312,8 @@ pub enum Expr { Column(Column), /// A named reference to a variable in a registry. ScalarVariable(DataType, Vec), - /// A constant value. - Literal(ScalarValue), + /// A constant value along with associated [`FieldMetadata`]. + Literal(ScalarValue, Option), /// A binary expression such as "age > 21" BinaryExpr(BinaryExpr), /// LIKE expression @@ -312,27 +342,7 @@ pub enum Expr { Negative(Box), /// Whether an expression is between a given range. Between(Between), - /// The CASE expression is similar to a series of nested if/else and there are two forms that - /// can be used. The first form consists of a series of boolean "when" expressions with - /// corresponding "then" expressions, and an optional "else" expression. - /// - /// ```text - /// CASE WHEN condition THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - /// ``` - /// - /// The second form uses a base expression and then a series of "when" clauses that match on a - /// literal value. - /// - /// ```text - /// CASE expression - /// WHEN value THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - /// ``` + /// A CASE expression (see docs on [`Case`]) Case(Case), /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. @@ -340,7 +350,7 @@ pub enum Expr { /// Casts the expression to a given type and will return a null value if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. TryCast(TryCast), - /// Represents the call of a scalar function with a set of arguments. + /// Call a scalar function with a set of arguments. ScalarFunction(ScalarFunction), /// Calls an aggregate function with arguments, and optional /// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`. @@ -349,8 +359,8 @@ pub enum Expr { /// /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt AggregateFunction(AggregateFunction), - /// Represents the call of a window function with arguments. - WindowFunction(WindowFunction), + /// Call a window function with a set of arguments. + WindowFunction(Box), /// Returns whether the list contains the expr value. InList(InList), /// EXISTS subquery @@ -378,16 +388,22 @@ pub enum Expr { /// A place holder for parameters in a prepared statement /// (e.g. `$foo` or `$1`) Placeholder(Placeholder), - /// A place holder which hold a reference to a qualified field + /// A placeholder which holds a reference to a qualified field /// in the outer query, used for correlated sub queries. - OuterReferenceColumn(DataType, Column), + OuterReferenceColumn(FieldRef, Column), /// Unnest expression Unnest(Unnest), } impl Default for Expr { fn default() -> Self { - Expr::Literal(ScalarValue::Null) + Expr::Literal(ScalarValue::Null, None) + } +} + +impl AsRef for Expr { + fn as_ref(&self) -> &Expr { + self } } @@ -398,6 +414,13 @@ impl From for Expr { } } +/// Create an [`Expr`] from a [`WindowFunction`] +impl From for Expr { + fn from(value: WindowFunction) -> Self { + Expr::WindowFunction(Box::new(value)) + } +} + /// Create an [`Expr`] from an optional qualifier and a [`FieldRef`]. This is /// useful for creating [`Expr`] from a [`DFSchema`]. /// @@ -424,6 +447,294 @@ impl<'a> TreeNodeContainer<'a, Self> for Expr { } } +/// Literal metadata +/// +/// Stores metadata associated with a literal expressions +/// and is designed to be fast to `clone`. +/// +/// This structure is used to store metadata associated with a literal expression, and it +/// corresponds to the `metadata` field on [`Field`]. +/// +/// # Example: Create [`FieldMetadata`] from a [`Field`] +/// ``` +/// # use std::collections::HashMap; +/// # use datafusion_expr::expr::FieldMetadata; +/// # use arrow::datatypes::{Field, DataType}; +/// # let field = Field::new("c1", DataType::Int32, true) +/// # .with_metadata(HashMap::from([("foo".to_string(), "bar".to_string())])); +/// // Create a new `FieldMetadata` instance from a `Field` +/// let metadata = FieldMetadata::new_from_field(&field); +/// // There is also a `From` impl: +/// let metadata = FieldMetadata::from(&field); +/// ``` +/// +/// # Example: Update a [`Field`] with [`FieldMetadata`] +/// ``` +/// # use datafusion_expr::expr::FieldMetadata; +/// # use arrow::datatypes::{Field, DataType}; +/// # let field = Field::new("c1", DataType::Int32, true); +/// # let metadata = FieldMetadata::new_from_field(&field); +/// // Add any metadata from `FieldMetadata` to `Field` +/// let updated_field = metadata.add_to_field(field); +/// ``` +/// +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct FieldMetadata { + /// The inner metadata of a literal expression, which is a map of string + /// keys to string values. + /// + /// Note this is not a `HashMap` because `HashMap` does not provide + /// implementations for traits like `Debug` and `Hash`. + inner: Arc>, +} + +impl Default for FieldMetadata { + fn default() -> Self { + Self::new_empty() + } +} + +impl FieldMetadata { + /// Create a new empty metadata instance. + pub fn new_empty() -> Self { + Self { + inner: Arc::new(BTreeMap::new()), + } + } + + /// Merges two optional `FieldMetadata` instances, overwriting any existing + /// keys in `m` with keys from `n` if present. + /// + /// This function is commonly used in alias operations, particularly for literals + /// with metadata. When creating an alias expression, the metadata from the original + /// expression (such as a literal) is combined with any metadata specified on the alias. + /// + /// # Arguments + /// + /// * `m` - The first metadata (typically from the original expression like a literal) + /// * `n` - The second metadata (typically from the alias definition) + /// + /// # Merge Strategy + /// + /// - If both metadata instances exist, they are merged with `n` taking precedence + /// - Keys from `n` will overwrite keys from `m` if they have the same name + /// - If only one metadata instance exists, it is returned unchanged + /// - If neither exists, `None` is returned + /// + /// # Example usage + /// ```rust + /// use datafusion_expr::expr::FieldMetadata; + /// use std::collections::BTreeMap; + /// + /// // Create metadata for a literal expression + /// let literal_metadata = Some(FieldMetadata::from(BTreeMap::from([ + /// ("source".to_string(), "constant".to_string()), + /// ("type".to_string(), "int".to_string()), + /// ]))); + /// + /// // Create metadata for an alias + /// let alias_metadata = Some(FieldMetadata::from(BTreeMap::from([ + /// ("description".to_string(), "answer".to_string()), + /// ("source".to_string(), "user".to_string()), // This will override literal's "source" + /// ]))); + /// + /// // Merge the metadata + /// let merged = FieldMetadata::merge_options( + /// literal_metadata.as_ref(), + /// alias_metadata.as_ref(), + /// ); + /// + /// // Result contains: {"source": "user", "type": "int", "description": "answer"} + /// assert!(merged.is_some()); + /// ``` + pub fn merge_options( + m: Option<&FieldMetadata>, + n: Option<&FieldMetadata>, + ) -> Option { + match (m, n) { + (Some(m), Some(n)) => { + let mut merged = m.clone(); + merged.extend(n.clone()); + Some(merged) + } + (Some(m), None) => Some(m.clone()), + (None, Some(n)) => Some(n.clone()), + (None, None) => None, + } + } + + /// Create a new metadata instance from a `Field`'s metadata. + pub fn new_from_field(field: &Field) -> Self { + let inner = field + .metadata() + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + Self { + inner: Arc::new(inner), + } + } + + /// Create a new metadata instance from a map of string keys to string values. + pub fn new(inner: BTreeMap) -> Self { + Self { + inner: Arc::new(inner), + } + } + + /// Get the inner metadata as a reference to a `BTreeMap`. + pub fn inner(&self) -> &BTreeMap { + &self.inner + } + + /// Return the inner metadata + pub fn into_inner(self) -> Arc> { + self.inner + } + + /// Adds metadata from `other` into `self`, overwriting any existing keys. + pub fn extend(&mut self, other: Self) { + if other.is_empty() { + return; + } + let other = Arc::unwrap_or_clone(other.into_inner()); + Arc::make_mut(&mut self.inner).extend(other); + } + + /// Returns true if the metadata is empty. + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Returns the number of key-value pairs in the metadata. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Convert this `FieldMetadata` into a `HashMap` + pub fn to_hashmap(&self) -> std::collections::HashMap { + self.inner + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect() + } + + /// Updates the metadata on the Field with this metadata, if it is not empty. + pub fn add_to_field(&self, field: Field) -> Field { + if self.inner.is_empty() { + return field; + } + + field.with_metadata(self.to_hashmap()) + } +} + +impl From<&Field> for FieldMetadata { + fn from(field: &Field) -> Self { + Self::new_from_field(field) + } +} + +impl From> for FieldMetadata { + fn from(inner: BTreeMap) -> Self { + Self::new(inner) + } +} + +impl From> for FieldMetadata { + fn from(map: std::collections::HashMap) -> Self { + Self::new(map.into_iter().collect()) + } +} + +/// From reference +impl From<&std::collections::HashMap> for FieldMetadata { + fn from(map: &std::collections::HashMap) -> Self { + let inner = map + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + Self::new(inner) + } +} + +/// From hashbrown map +impl From> for FieldMetadata { + fn from(map: HashMap) -> Self { + let inner = map.into_iter().collect(); + Self::new(inner) + } +} + +impl From<&HashMap> for FieldMetadata { + fn from(map: &HashMap) -> Self { + let inner = map + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + Self::new(inner) + } +} + +/// The metadata used in [`Field::metadata`]. +/// +/// This represents the metadata associated with an Arrow [`Field`]. The metadata consists of key-value pairs. +/// +/// # Common Use Cases +/// +/// Field metadata is commonly used to store: +/// - Default values for columns when data is missing +/// - Column descriptions or documentation +/// - Data lineage information +/// - Custom application-specific annotations +/// - Encoding hints or display formatting preferences +/// +/// # Example: Storing Default Values +/// +/// A practical example of using field metadata is storing default values for columns +/// that may be missing in the physical data but present in the logical schema. +/// See the [default_column_values.rs] example implementation. +/// +/// [default_column_values.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/default_column_values.rs +pub type SchemaFieldMetadata = std::collections::HashMap; + +/// Intersects multiple metadata instances for UNION operations. +/// +/// This function implements the intersection strategy used by UNION operations, +/// where only metadata keys that exist in ALL inputs with identical values +/// are preserved in the result. +/// +/// # Union Metadata Behavior +/// +/// Union operations require consistent metadata across all branches: +/// - Only metadata keys present in ALL union branches are kept +/// - For each kept key, the value must be identical across all branches +/// - If a key has different values across branches, it is excluded from the result +/// - If any input has no metadata, the result will be empty +/// +/// # Arguments +/// +/// * `metadatas` - An iterator of `SchemaFieldMetadata` instances to intersect +/// +/// # Returns +/// +/// A new `SchemaFieldMetadata` containing only the intersected metadata +pub fn intersect_metadata_for_union<'a>( + metadatas: impl IntoIterator, +) -> SchemaFieldMetadata { + let mut metadatas = metadatas.into_iter(); + let Some(mut intersected) = metadatas.next().cloned() else { + return Default::default(); + }; + + for metadata in metadatas { + // Only keep keys that exist in both with the same value + intersected.retain(|k, v| metadata.get(k) == Some(v)); + } + + intersected +} + /// UNNEST expression. #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Unnest { @@ -450,7 +761,7 @@ pub struct Alias { pub expr: Box, pub relation: Option, pub name: String, - pub metadata: Option>, + pub metadata: Option, } impl Hash for Alias { @@ -462,16 +773,19 @@ impl Hash for Alias { } impl PartialOrd for Alias { - fn partial_cmp(&self, other: &Self) -> Option { + fn partial_cmp(&self, other: &Self) -> Option { let cmp = self.expr.partial_cmp(&other.expr); - let Some(std::cmp::Ordering::Equal) = cmp else { + let Some(Ordering::Equal) = cmp else { return cmp; }; let cmp = self.relation.partial_cmp(&other.relation); - let Some(std::cmp::Ordering::Equal) = cmp else { + let Some(Ordering::Equal) = cmp else { return cmp; }; - self.name.partial_cmp(&other.name) + self.name + .partial_cmp(&other.name) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -490,10 +804,7 @@ impl Alias { } } - pub fn with_metadata( - mut self, - metadata: Option>, - ) -> Self { + pub fn with_metadata(mut self, metadata: Option) -> Self { self.metadata = metadata; self } @@ -551,6 +862,28 @@ impl Display for BinaryExpr { } /// CASE expression +/// +/// The CASE expression is similar to a series of nested if/else and there are two forms that +/// can be used. The first form consists of a series of boolean "when" expressions with +/// corresponding "then" expressions, and an optional "else" expression. +/// +/// ```text +/// CASE WHEN condition THEN result +/// [WHEN ...] +/// [ELSE result] +/// END +/// ``` +/// +/// The second form uses a base expression and then a series of "when" clauses that match on a +/// literal value. +/// +/// ```text +/// CASE expression +/// WHEN value THEN result +/// [WHEN ...] +/// [ELSE result] +/// END +/// ``` #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Hash)] pub struct Case { /// Optional base expression that can be compared to literal values in the "when" expressions @@ -631,7 +964,9 @@ impl Between { } } -/// ScalarFunction expression invokes a built-in scalar function +/// Invoke a [`ScalarUDF`] with a set of arguments +/// +/// [`ScalarUDF`]: crate::ScalarUDF #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct ScalarFunction { /// The function @@ -648,7 +983,9 @@ impl ScalarFunction { } impl ScalarFunction { - /// Create a new ScalarFunction expression with a user-defined function (UDF) + /// Create a new `ScalarFunction` from a [`ScalarUDF`] + /// + /// [`ScalarUDF`]: crate::ScalarUDF pub fn new_udf(udf: Arc, args: Vec) -> Self { Self { func: udf, args } } @@ -784,7 +1121,7 @@ impl<'a> TreeNodeContainer<'a, Expr> for Sort { #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct AggregateFunction { /// Name of the function - pub func: Arc, + pub func: Arc, pub params: AggregateFunctionParams, } @@ -796,18 +1133,18 @@ pub struct AggregateFunctionParams { /// Optional filter pub filter: Option>, /// Optional ordering - pub order_by: Option>, + pub order_by: Vec, pub null_treatment: Option, } impl AggregateFunction { /// Create a new AggregateFunction expression with a user-defined function (UDF) pub fn new_udf( - func: Arc, + func: Arc, args: Vec, distinct: bool, filter: Option>, - order_by: Option>, + order_by: Vec, null_treatment: Option, ) -> Self { Self { @@ -831,26 +1168,25 @@ impl AggregateFunction { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum WindowFunctionDefinition { /// A user defined aggregate function - AggregateUDF(Arc), + AggregateUDF(Arc), /// A user defined aggregate function WindowUDF(Arc), } impl WindowFunctionDefinition { /// Returns the datatype of the window function - pub fn return_type( + pub fn return_field( &self, - input_expr_types: &[DataType], - _input_expr_nullable: &[bool], + input_expr_fields: &[FieldRef], display_name: &str, - ) -> Result { + ) -> Result { match self { WindowFunctionDefinition::AggregateUDF(fun) => { - fun.return_type(input_expr_types) + fun.return_field(input_expr_fields) + } + WindowFunctionDefinition::WindowUDF(fun) => { + fun.field(WindowUDFFieldArgs::new(input_expr_fields, display_name)) } - WindowFunctionDefinition::WindowUDF(fun) => fun - .field(WindowUDFFieldArgs::new(input_expr_types, display_name)) - .map(|field| field.data_type().clone()), } } @@ -869,6 +1205,16 @@ impl WindowFunctionDefinition { WindowFunctionDefinition::AggregateUDF(fun) => fun.name(), } } + + /// Return the the inner window simplification function, if any + /// + /// See [`WindowFunctionSimplification`] for more information + pub fn simplify(&self) -> Option { + match self { + WindowFunctionDefinition::AggregateUDF(_) => None, + WindowFunctionDefinition::WindowUDF(udwf) => udwf.simplify(), + } + } } impl Display for WindowFunctionDefinition { @@ -880,8 +1226,8 @@ impl Display for WindowFunctionDefinition { } } -impl From> for WindowFunctionDefinition { - fn from(value: Arc) -> Self { +impl From> for WindowFunctionDefinition { + fn from(value: Arc) -> Self { Self::AggregateUDF(value) } } @@ -921,8 +1267,12 @@ pub struct WindowFunctionParams { pub order_by: Vec, /// Window frame pub window_frame: WindowFrame, + /// Optional filter expression (FILTER (WHERE ...)) + pub filter: Option>, /// Specifies how NULL value is treated: ignore or respect pub null_treatment: Option, + /// Distinct flag + pub distinct: bool, } impl WindowFunction { @@ -936,10 +1286,19 @@ impl WindowFunction { partition_by: Vec::default(), order_by: Vec::default(), window_frame: WindowFrame::new(None), + filter: None, null_treatment: None, + distinct: false, }, } } + + /// Return the the inner window simplification function, if any + /// + /// See [`WindowFunctionSimplification`] for more information + pub fn simplify(&self) -> Option { + self.fun.simplify() + } } /// EXISTS expression @@ -958,38 +1317,6 @@ impl Exists { } } -/// User Defined Aggregate Function -/// -/// See [`udaf::AggregateUDF`] for more information. -#[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub struct AggregateUDF { - /// The function - pub fun: Arc, - /// List of expressions to feed to the functions as arguments - pub args: Vec, - /// Optional filter - pub filter: Option>, - /// Optional ORDER BY applied prior to aggregating - pub order_by: Option>, -} - -impl AggregateUDF { - /// Create a new AggregateUDF expression - pub fn new( - fun: Arc, - args: Vec, - filter: Option>, - order_by: Option>, - ) -> Self { - Self { - fun, - args, - filter, - order_by, - } - } -} - /// InList expression #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct InList { @@ -1091,6 +1418,130 @@ impl GroupingSet { } } +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +#[cfg(not(feature = "sql"))] +pub struct IlikeSelectItem { + pub pattern: String, +} +#[cfg(not(feature = "sql"))] +impl Display for IlikeSelectItem { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "ILIKE '{}'", &self.pattern)?; + Ok(()) + } +} +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +#[cfg(not(feature = "sql"))] +pub enum ExcludeSelectItem { + Single(Ident), + Multiple(Vec), +} +#[cfg(not(feature = "sql"))] +impl Display for ExcludeSelectItem { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "EXCLUDE")?; + match self { + Self::Single(column) => { + write!(f, " {column}")?; + } + Self::Multiple(columns) => { + write!(f, " ({})", display_comma_separated(columns))?; + } + } + Ok(()) + } +} +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +#[cfg(not(feature = "sql"))] +pub struct ExceptSelectItem { + pub first_element: Ident, + pub additional_elements: Vec, +} +#[cfg(not(feature = "sql"))] +impl Display for ExceptSelectItem { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "EXCEPT ")?; + if self.additional_elements.is_empty() { + write!(f, "({})", self.first_element)?; + } else { + write!( + f, + "({}, {})", + self.first_element, + display_comma_separated(&self.additional_elements) + )?; + } + Ok(()) + } +} + +#[cfg(not(feature = "sql"))] +pub fn display_comma_separated(slice: &[T]) -> String +where + T: Display, +{ + use itertools::Itertools; + slice.iter().map(|v| format!("{v}")).join(", ") +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +#[cfg(not(feature = "sql"))] +pub enum RenameSelectItem { + Single(String), + Multiple(Vec), +} +#[cfg(not(feature = "sql"))] +impl Display for RenameSelectItem { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "RENAME")?; + match self { + Self::Single(column) => { + write!(f, " {column}")?; + } + Self::Multiple(columns) => { + write!(f, " ({})", display_comma_separated(columns))?; + } + } + Ok(()) + } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +#[cfg(not(feature = "sql"))] +pub struct Ident { + /// The value of the identifier without quotes. + pub value: String, + /// The starting quote if any. Valid quote characters are the single quote, + /// double quote, backtick, and opening square bracket. + pub quote_style: Option, + /// The span of the identifier in the original SQL string. + pub span: String, +} +#[cfg(not(feature = "sql"))] +impl Display for Ident { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "[{}]", self.value) + } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +#[cfg(not(feature = "sql"))] +pub struct ReplaceSelectElement { + pub expr: String, + pub column_name: Ident, + pub as_keyword: bool, +} +#[cfg(not(feature = "sql"))] +impl Display for ReplaceSelectElement { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if self.as_keyword { + write!(f, "{} AS {}", self.expr, self.column_name) + } else { + write!(f, "{} {}", self.expr, self.column_name) + } + } +} + /// Additional options for wildcards, e.g. Snowflake `EXCLUDE`/`RENAME` and Bigquery `EXCEPT`. #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug, Default)] pub struct WildcardOptions { @@ -1243,12 +1694,6 @@ impl Expr { } } - /// Returns a full and complete string representation of this expression. - #[deprecated(since = "42.0.0", note = "use format! instead")] - pub fn canonical_name(&self) -> String { - format!("{self}") - } - /// Return String representation of the variant represented by `self` /// Useful for non-rust based bindings pub fn variant_name(&self) -> &str { @@ -1397,15 +1842,17 @@ impl Expr { /// # Example /// ``` /// # use datafusion_expr::col; - /// use std::collections::HashMap; + /// # use std::collections::HashMap; + /// # use datafusion_expr::expr::FieldMetadata; /// let metadata = HashMap::from([("key".to_string(), "value".to_string())]); + /// let metadata = FieldMetadata::from(metadata); /// let expr = col("foo").alias_with_metadata("bar", Some(metadata)); /// ``` /// pub fn alias_with_metadata( self, name: impl Into, - metadata: Option>, + metadata: Option, ) -> Expr { Expr::Alias(Alias::new(self, None::<&str>, name.into()).with_metadata(metadata)) } @@ -1427,8 +1874,10 @@ impl Expr { /// # Example /// ``` /// # use datafusion_expr::col; - /// use std::collections::HashMap; + /// # use std::collections::HashMap; + /// # use datafusion_expr::expr::FieldMetadata; /// let metadata = HashMap::from([("key".to_string(), "value".to_string())]); + /// let metadata = FieldMetadata::from(metadata); /// let expr = col("foo").alias_qualified_with_metadata(Some("tbl"), "bar", Some(metadata)); /// ``` /// @@ -1436,7 +1885,7 @@ impl Expr { self, relation: Option>, name: impl Into, - metadata: Option>, + metadata: Option, ) -> Expr { Expr::Alias(Alias::new(self, relation, name.into()).with_metadata(metadata)) } @@ -1506,8 +1955,16 @@ impl Expr { |expr| { // f_up: unalias on up so we can remove nested aliases like // `(x as foo) as bar` - if let Expr::Alias(Alias { expr, .. }) = expr { - Ok(Transformed::yes(*expr)) + if let Expr::Alias(alias) = expr { + match alias + .metadata + .as_ref() + .map(|h| h.is_empty()) + .unwrap_or(true) + { + true => Ok(Transformed::yes(*alias.expr)), + false => Ok(Transformed::no(Expr::Alias(alias))), + } } else { Ok(Transformed::no(expr)) } @@ -1747,23 +2204,38 @@ impl Expr { pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> { let mut has_placeholder = false; self.transform(|mut expr| { - // Default to assuming the arguments are the same type - if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { - rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; - rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; - }; - if let Expr::Between(Between { - expr, - negated: _, - low, - high, - }) = &mut expr - { - rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; - rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; - } - if let Expr::Placeholder(_) = &expr { - has_placeholder = true; + match &mut expr { + // Default to assuming the arguments are the same type + Expr::BinaryExpr(BinaryExpr { left, op: _, right }) => { + rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; + rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; + } + Expr::Between(Between { + expr, + negated: _, + low, + high, + }) => { + rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; + rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; + } + Expr::InList(InList { + expr, + list, + negated: _, + }) => { + for item in list.iter_mut() { + rewrite_placeholder(item, expr.as_ref(), schema)?; + } + } + Expr::Like(Like { expr, pattern, .. }) + | Expr::SimilarTo(Like { expr, pattern, .. }) => { + rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?; + } + Expr::Placeholder(_) => { + has_placeholder = true; + } + _ => {} } Ok(Transformed::yes(expr)) }) @@ -1827,6 +2299,15 @@ impl Expr { _ => None, } } + + /// Check if the Expr is literal and get the literal value if it is. + pub fn as_literal(&self) -> Option<&ScalarValue> { + if let Expr::Literal(lit, _) = self { + Some(lit) + } else { + None + } + } } impl Normalizeable for Expr { @@ -2052,48 +2533,51 @@ impl NormalizeEq for Expr { (None, None) => true, _ => false, } - && match (self_order_by, other_order_by) { - (Some(self_order_by), Some(other_order_by)) => self_order_by - .iter() - .zip(other_order_by.iter()) - .all(|(a, b)| { - a.asc == b.asc - && a.nulls_first == b.nulls_first - && a.expr.normalize_eq(&b.expr) - }), - (None, None) => true, - _ => false, - } + && self_order_by + .iter() + .zip(other_order_by.iter()) + .all(|(a, b)| { + a.asc == b.asc + && a.nulls_first == b.nulls_first + && a.expr.normalize_eq(&b.expr) + }) + && self_order_by.len() == other_order_by.len() } - ( - Expr::WindowFunction(WindowFunction { + (Expr::WindowFunction(left), Expr::WindowFunction(other)) => { + let WindowFunction { fun: self_fun, - params: self_params, - }), - Expr::WindowFunction(WindowFunction { + params: + WindowFunctionParams { + args: self_args, + window_frame: self_window_frame, + partition_by: self_partition_by, + order_by: self_order_by, + filter: self_filter, + null_treatment: self_null_treatment, + distinct: self_distinct, + }, + } = left.as_ref(); + let WindowFunction { fun: other_fun, - params: other_params, - }), - ) => { - let ( - WindowFunctionParams { - args: self_args, - window_frame: self_window_frame, - partition_by: self_partition_by, - order_by: self_order_by, - null_treatment: self_null_treatment, - }, - WindowFunctionParams { - args: other_args, - window_frame: other_window_frame, - partition_by: other_partition_by, - order_by: other_order_by, - null_treatment: other_null_treatment, - }, - ) = (self_params, other_params); + params: + WindowFunctionParams { + args: other_args, + window_frame: other_window_frame, + partition_by: other_partition_by, + order_by: other_order_by, + filter: other_filter, + null_treatment: other_null_treatment, + distinct: other_distinct, + }, + } = other.as_ref(); self_fun.name() == other_fun.name() && self_window_frame == other_window_frame + && match (self_filter, other_filter) { + (Some(a), Some(b)) => a.normalize_eq(b), + (None, None) => true, + _ => false, + } && self_null_treatment == other_null_treatment && self_args.len() == other_args.len() && self_args @@ -2112,6 +2596,7 @@ impl NormalizeEq for Expr { && a.nulls_first == b.nulls_first && a.expr.normalize_eq(&b.expr) }) + && self_distinct == other_distinct } ( Expr::Exists(Exists { @@ -2256,7 +2741,7 @@ impl HashNode for Expr { data_type.hash(state); name.hash(state); } - Expr::Literal(scalar_value) => { + Expr::Literal(scalar_value, _) => { scalar_value.hash(state); } Expr::BinaryExpr(BinaryExpr { @@ -2335,17 +2820,25 @@ impl HashNode for Expr { distinct.hash(state); null_treatment.hash(state); } - Expr::WindowFunction(WindowFunction { fun, params }) => { - let WindowFunctionParams { - args: _args, - partition_by: _, - order_by: _, - window_frame, - null_treatment, - } = params; + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args: _args, + partition_by: _, + order_by: _, + window_frame, + filter, + null_treatment, + distinct, + }, + } = window_fun.as_ref(); fun.hash(state); window_frame.hash(state); + filter.hash(state); null_treatment.hash(state); + distinct.hash(state); } Expr::InList(InList { expr: _expr, @@ -2384,8 +2877,8 @@ impl HashNode for Expr { Expr::Placeholder(place_holder) => { place_holder.hash(state); } - Expr::OuterReferenceColumn(data_type, column) => { - data_type.hash(state); + Expr::OuterReferenceColumn(field, column) => { + field.hash(state); column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} @@ -2432,7 +2925,7 @@ impl Display for SchemaDisplay<'_> { // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::Column(_) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::ScalarVariable(..) | Expr::OuterReferenceColumn(..) | Expr::Placeholder(_) @@ -2443,7 +2936,7 @@ impl Display for SchemaDisplay<'_> { write!(f, "{name}") } Err(e) => { - write!(f, "got error from schema_name {}", e) + write!(f, "got error from schema_name {e}") } } } @@ -2594,7 +3087,7 @@ impl Display for SchemaDisplay<'_> { write!(f, "{name}") } Err(e) => { - write!(f, "got error from schema_name {}", e) + write!(f, "got error from schema_name {e}") } } } @@ -2625,52 +3118,79 @@ impl Display for SchemaDisplay<'_> { Ok(()) } - Expr::WindowFunction(WindowFunction { fun, params }) => match fun { - WindowFunctionDefinition::AggregateUDF(fun) => { - match fun.window_function_schema_name(params) { - Ok(name) => { - write!(f, "{name}") - } - Err(e) => { - write!(f, "got error from window_function_schema_name {}", e) + Expr::WindowFunction(window_fun) => { + let WindowFunction { fun, params } = window_fun.as_ref(); + match fun { + WindowFunctionDefinition::AggregateUDF(fun) => { + match fun.window_function_schema_name(params) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!( + f, + "got error from window_function_schema_name {e}" + ) + } } } - } - _ => { - let WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - } = params; - - write!( - f, - "{}({})", - fun, - schema_name_from_exprs_comma_separated_without_space(args)? - )?; - - if let Some(null_treatment) = null_treatment { - write!(f, " {}", null_treatment)?; - } + _ => { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + filter, + null_treatment, + distinct, + } = params; + + // Write function name and open parenthesis + write!(f, "{fun}(")?; + + // If DISTINCT, emit the keyword + if *distinct { + write!(f, "DISTINCT ")?; + } - if !partition_by.is_empty() { + // Write the comma‑separated argument list write!( f, - " PARTITION BY [{}]", - schema_name_from_exprs(partition_by)? + "{}", + schema_name_from_exprs_comma_separated_without_space(args)? )?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; - }; + // **Close the argument parenthesis** + write!(f, ")")?; + + if let Some(null_treatment) = null_treatment { + write!(f, " {null_treatment}")?; + } + + if let Some(filter) = filter { + write!(f, " FILTER (WHERE {filter})")?; + } + + if !partition_by.is_empty() { + write!( + f, + " PARTITION BY [{}]", + schema_name_from_exprs(partition_by)? + )?; + } + + if !order_by.is_empty() { + write!( + f, + " ORDER BY [{}]", + schema_name_from_sorts(order_by)? + )?; + }; - write!(f, " {window_frame}") + write!(f, " {window_frame}") + } } - }, + } } } } @@ -2681,7 +3201,7 @@ struct SqlDisplay<'a>(&'a Expr); impl Display for SqlDisplay<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self.0 { - Expr::Literal(scalar) => scalar.fmt(f), + Expr::Literal(scalar, _) => scalar.fmt(f), Expr::Alias(Alias { name, .. }) => write!(f, "{name}"), Expr::Between(Between { expr, @@ -2847,7 +3367,7 @@ impl Display for SqlDisplay<'_> { write!(f, "{name}") } Err(e) => { - write!(f, "got error from schema_name {}", e) + write!(f, "got error from schema_name {e}") } } } @@ -2948,7 +3468,12 @@ impl Display for Expr { write!(f, "{OUTER_REFERENCE_COLUMN_PREFIX}({c})") } Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")), - Expr::Literal(v) => write!(f, "{v:?}"), + Expr::Literal(v, metadata) => { + match metadata.as_ref().map(|m| m.is_empty()).unwrap_or(true) { + false => write!(f, "{v:?} {:?}", metadata.as_ref().unwrap()), + true => write!(f, "{v:?}"), + } + } Expr::Case(case) => { write!(f, "CASE ")?; if let Some(e) = &case.expr { @@ -2963,10 +3488,10 @@ impl Display for Expr { write!(f, "END") } Expr::Cast(Cast { expr, data_type }) => { - write!(f, "CAST({expr} AS {data_type:?})") + write!(f, "CAST({expr} AS {data_type})") } Expr::TryCast(TryCast { expr, data_type }) => { - write!(f, "TRY_CAST({expr} AS {data_type:?})") + write!(f, "TRY_CAST({expr} AS {data_type})") } Expr::Not(expr) => write!(f, "NOT {expr}"), Expr::Negative(expr) => write!(f, "(- {expr})"), @@ -3001,58 +3526,66 @@ impl Display for Expr { Expr::ScalarFunction(fun) => { fmt_function(f, fun.name(), false, &fun.args, true) } - // TODO: use udf's display_name, need to fix the separator issue, - // Expr::ScalarFunction(ScalarFunction { func, args }) => { - // write!(f, "{}", func.display_name(args).unwrap()) - // } - Expr::WindowFunction(WindowFunction { fun, params }) => match fun { - WindowFunctionDefinition::AggregateUDF(fun) => { - match fun.window_function_display_name(params) { - Ok(name) => { - write!(f, "{}", name) - } - Err(e) => { - write!(f, "got error from window_function_display_name {}", e) + Expr::WindowFunction(window_fun) => { + let WindowFunction { fun, params } = window_fun.as_ref(); + match fun { + WindowFunctionDefinition::AggregateUDF(fun) => { + match fun.window_function_display_name(params) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!( + f, + "got error from window_function_display_name {e}" + ) + } } } - } - WindowFunctionDefinition::WindowUDF(fun) => { - let WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - } = params; - - fmt_function(f, &fun.to_string(), false, args, true)?; + WindowFunctionDefinition::WindowUDF(fun) => { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + filter, + null_treatment, + distinct, + } = params; + + fmt_function(f, &fun.to_string(), *distinct, args, true)?; + + if let Some(nt) = null_treatment { + write!(f, "{nt}")?; + } - if let Some(nt) = null_treatment { - write!(f, "{}", nt)?; - } + if let Some(fe) = filter { + write!(f, " FILTER (WHERE {fe})")?; + } - if !partition_by.is_empty() { - write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; + if !partition_by.is_empty() { + write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; + } + if !order_by.is_empty() { + write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; + } + write!( + f, + " {} BETWEEN {} AND {}", + window_frame.units, + window_frame.start_bound, + window_frame.end_bound + ) } - write!( - f, - " {} BETWEEN {} AND {}", - window_frame.units, - window_frame.start_bound, - window_frame.end_bound - ) } - }, + } Expr::AggregateFunction(AggregateFunction { func, params }) => { match func.display_name(params) { Ok(name) => { - write!(f, "{}", name) + write!(f, "{name}") } Err(e) => { - write!(f, "got error from display_name {}", e) + write!(f, "got error from display_name {e}") } } } @@ -3185,32 +3718,135 @@ mod test { case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility, }; + use arrow::datatypes::{Field, Schema}; use sqlparser::ast; use sqlparser::ast::{Ident, IdentWithAlias}; use std::any::Any; #[test] - #[allow(deprecated)] + fn infer_placeholder_in_clause() { + // SELECT * FROM employees WHERE department_id IN ($1, $2, $3); + let column = col("department_id"); + let param_placeholders = vec![ + Expr::Placeholder(Placeholder { + id: "$1".to_string(), + data_type: None, + }), + Expr::Placeholder(Placeholder { + id: "$2".to_string(), + data_type: None, + }), + Expr::Placeholder(Placeholder { + id: "$3".to_string(), + data_type: None, + }), + ]; + let in_list = Expr::InList(InList { + expr: Box::new(column), + list: param_placeholders, + negated: false, + }); + + let schema = Arc::new(Schema::new(vec![ + Field::new("name", DataType::Utf8, true), + Field::new("department_id", DataType::Int32, true), + ])); + let df_schema = DFSchema::try_from(schema).unwrap(); + + let (inferred_expr, contains_placeholder) = + in_list.infer_placeholder_types(&df_schema).unwrap(); + + assert!(contains_placeholder); + + match inferred_expr { + Expr::InList(in_list) => { + for expr in in_list.list { + match expr { + Expr::Placeholder(placeholder) => { + assert_eq!( + placeholder.data_type, + Some(DataType::Int32), + "Placeholder {} should infer Int32", + placeholder.id + ); + } + _ => panic!("Expected Placeholder expression"), + } + } + } + _ => panic!("Expected InList expression"), + } + } + + #[test] + fn infer_placeholder_like_and_similar_to() { + // name LIKE $1 + let schema = + Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)])); + let df_schema = DFSchema::try_from(schema).unwrap(); + + let like = Like { + expr: Box::new(col("name")), + pattern: Box::new(Expr::Placeholder(Placeholder { + id: "$1".to_string(), + data_type: None, + })), + negated: false, + case_insensitive: false, + escape_char: None, + }; + + let expr = Expr::Like(like.clone()); + + let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap(); + match inferred_expr { + Expr::Like(like) => match *like.pattern { + Expr::Placeholder(placeholder) => { + assert_eq!(placeholder.data_type, Some(DataType::Utf8)); + } + _ => panic!("Expected Placeholder"), + }, + _ => panic!("Expected Like"), + } + + // name SIMILAR TO $1 + let expr = Expr::SimilarTo(like); + + let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap(); + match inferred_expr { + Expr::SimilarTo(like) => match *like.pattern { + Expr::Placeholder(placeholder) => { + assert_eq!( + placeholder.data_type, + Some(DataType::Utf8), + "Placeholder {} should infer Utf8", + placeholder.id + ); + } + _ => panic!("Expected Placeholder expression"), + }, + _ => panic!("Expected SimilarTo expression"), + } + } + + #[test] fn format_case_when() -> Result<()> { let expr = case(col("a")) .when(lit(1), lit(true)) .when(lit(0), lit(false)) .otherwise(lit(ScalarValue::Null))?; let expected = "CASE a WHEN Int32(1) THEN Boolean(true) WHEN Int32(0) THEN Boolean(false) ELSE NULL END"; - assert_eq!(expected, expr.canonical_name()); assert_eq!(expected, format!("{expr}")); Ok(()) } #[test] - #[allow(deprecated)] fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))), + expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)), None)), data_type: DataType::Utf8, }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; - assert_eq!(expected_canonical, expr.canonical_name()); assert_eq!(expected_canonical, format!("{expr}")); // Note that CAST intentionally has a name that is different from its `Display` // representation. CAST does not change the name of expressions. @@ -3292,7 +3928,7 @@ mod test { #[test] fn test_is_volatile_scalar_func() { // UDF - #[derive(Debug)] + #[derive(Debug, PartialEq, Eq, Hash)] struct TestScalarUDF { signature: Signature, } @@ -3464,4 +4100,73 @@ mod test { rename: opt_rename, } } + + #[test] + fn test_size_of_expr() { + // because Expr is such a widely used struct in DataFusion + // it is important to keep its size as small as possible + // + // If this test fails when you change `Expr`, please try + // `Box`ing the fields to make `Expr` smaller + // See https://github.com/apache/datafusion/issues/16199 for details + assert_eq!(size_of::(), 112); + assert_eq!(size_of::(), 64); + assert_eq!(size_of::(), 24); // 3 ptrs + assert_eq!(size_of::>(), 24); + assert_eq!(size_of::>(), 8); + } + + #[test] + fn test_accept_exprs() { + fn accept_exprs>(_: &[E]) {} + + let expr = || -> Expr { lit(1) }; + + // Call accept_exprs with owned expressions + let owned_exprs = vec![expr(), expr()]; + accept_exprs(&owned_exprs); + + // Call accept_exprs with expressions from expr tree + let udf = Expr::ScalarFunction(ScalarFunction { + func: Arc::new(ScalarUDF::new_from_impl(TestUDF {})), + args: vec![expr(), expr()], + }); + let Expr::ScalarFunction(scalar) = &udf else { + unreachable!() + }; + accept_exprs(&scalar.args); + + // Call accept_exprs with expressions collected from expr tree, without cloning + let mut collected_refs: Vec<&Expr> = scalar.args.iter().collect(); + collected_refs.extend(&owned_exprs); + accept_exprs(&collected_refs); + + // test helpers + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestUDF {} + impl ScalarUDFImpl for TestUDF { + fn as_any(&self) -> &dyn Any { + unimplemented!() + } + + fn name(&self) -> &str { + unimplemented!() + } + + fn signature(&self) -> &Signature { + unimplemented!() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!() + } + + fn invoke_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> Result { + unimplemented!() + } + } + } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 966aba7d1195e..4666411dd5408 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -19,17 +19,18 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, - Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, WindowFunctionParams, + NullTreatment, Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, StateFieldsArgs, }; +use crate::ptr_eq::PtrEq; use crate::select_expr::SelectExpr; use crate::{ conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery, - AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, ScalarFunctionArgs, - ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, + AggregateUDF, Expr, LimitEffect, LogicalPlan, Operator, PartitionEvaluator, + ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; use crate::{ AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, @@ -37,13 +38,15 @@ use crate::{ use arrow::compute::kernels::cast_utils::{ parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, }; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{plan_err, Column, Result, ScalarValue, Spans, TableReference}; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; -use sqlparser::ast::NullTreatment; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::any::Any; +use std::collections::HashMap; use std::fmt::Debug; +use std::hash::Hash; use std::ops::Not; use std::sync::Arc; @@ -68,8 +71,22 @@ pub fn col(ident: impl Into) -> Expr { /// Create an out reference column which hold a reference that has been resolved to a field /// outside of the current plan. +/// The expression created by this function does not preserve the metadata of the outer column. +/// Please use `out_ref_col_with_metadata` if you want to preserve the metadata. pub fn out_ref_col(dt: DataType, ident: impl Into) -> Expr { - Expr::OuterReferenceColumn(dt, ident.into()) + out_ref_col_with_metadata(dt, HashMap::new(), ident) +} + +/// Create an out reference column from an existing field (preserving metadata) +pub fn out_ref_col_with_metadata( + dt: DataType, + metadata: HashMap, + ident: impl Into, +) -> Expr { + let column = ident.into(); + let field: FieldRef = + Arc::new(Field::new(column.name(), dt, true).with_metadata(metadata)); + Expr::OuterReferenceColumn(field, column) } /// Create an unqualified column expression from the provided name, without normalizing @@ -401,11 +418,12 @@ pub fn create_udf( /// Implements [`ScalarUDFImpl`] for functions that have a single signature and /// return type. +#[derive(PartialEq, Eq, Hash)] pub struct SimpleScalarUDF { name: String, signature: Signature, return_type: DataType, - fun: ScalarFunctionImplementation, + fun: PtrEq, } impl Debug for SimpleScalarUDF { @@ -449,7 +467,7 @@ impl SimpleScalarUDF { name: name.into(), signature, return_type, - fun, + fun: fun.into(), } } } @@ -492,6 +510,7 @@ pub fn create_udaf( .into_iter() .enumerate() .map(|(i, t)| Field::new(format!("{i}"), t, true)) + .map(Arc::new) .collect::>(); AggregateUDF::from(SimpleAggregateUDF::new( name, @@ -505,12 +524,13 @@ pub fn create_udaf( /// Implements [`AggregateUDFImpl`] for functions that have a single signature and /// return type. +#[derive(PartialEq, Eq, Hash)] pub struct SimpleAggregateUDF { name: String, signature: Signature, return_type: DataType, - accumulator: AccumulatorFactoryFunction, - state_fields: Vec, + accumulator: PtrEq, + state_fields: Vec, } impl Debug for SimpleAggregateUDF { @@ -533,7 +553,7 @@ impl SimpleAggregateUDF { return_type: DataType, volatility: Volatility, accumulator: AccumulatorFactoryFunction, - state_fields: Vec, + state_fields: Vec, ) -> Self { let name = name.into(); let signature = Signature::exact(input_type, volatility); @@ -541,7 +561,7 @@ impl SimpleAggregateUDF { name, signature, return_type, - accumulator, + accumulator: accumulator.into(), state_fields, } } @@ -553,14 +573,14 @@ impl SimpleAggregateUDF { signature: Signature, return_type: DataType, accumulator: AccumulatorFactoryFunction, - state_fields: Vec, + state_fields: Vec, ) -> Self { let name = name.into(); Self { name, signature, return_type, - accumulator, + accumulator: accumulator.into(), state_fields, } } @@ -590,7 +610,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { (self.accumulator)(acc_args) } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(self.state_fields.clone()) } } @@ -619,11 +639,12 @@ pub fn create_udwf( /// Implements [`WindowUDFImpl`] for functions that have a single signature and /// return type. +#[derive(PartialEq, Eq, Hash)] pub struct SimpleWindowUDF { name: String, signature: Signature, return_type: DataType, - partition_evaluator_factory: PartitionEvaluatorFactory, + partition_evaluator_factory: PtrEq, } impl Debug for SimpleWindowUDF { @@ -653,7 +674,7 @@ impl SimpleWindowUDF { name, signature, return_type, - partition_evaluator_factory, + partition_evaluator_factory: partition_evaluator_factory.into(), } } } @@ -678,28 +699,32 @@ impl WindowUDFImpl for SimpleWindowUDF { (self.partition_evaluator_factory)() } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new( + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Arc::new(Field::new( field_args.name(), self.return_type.clone(), true, - )) + ))) + } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown } } pub fn interval_year_month_lit(value: &str) -> Expr { let interval = parse_interval_year_month(value).ok(); - Expr::Literal(ScalarValue::IntervalYearMonth(interval)) + Expr::Literal(ScalarValue::IntervalYearMonth(interval), None) } pub fn interval_datetime_lit(value: &str) -> Expr { let interval = parse_interval_day_time(value).ok(); - Expr::Literal(ScalarValue::IntervalDayTime(interval)) + Expr::Literal(ScalarValue::IntervalDayTime(interval), None) } pub fn interval_month_day_nano_lit(value: &str) -> Expr { let interval = parse_interval_month_day_nano(value).ok(); - Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) + Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), None) } /// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] @@ -710,8 +735,8 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { /// # Example /// ```no_run /// # use datafusion_common::Result; +/// # use datafusion_expr::expr::NullTreatment; /// # use datafusion_expr::test::function_stub::count; -/// # use sqlparser::ast::NullTreatment; /// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col}; /// # // first_value is an aggregate function in another crate /// # fn first_value(_arg: Expr) -> Expr { @@ -764,7 +789,7 @@ pub trait ExprFunctionExt { #[derive(Debug, Clone)] pub enum ExprFuncKind { Aggregate(AggregateFunction), - Window(WindowFunction), + Window(Box), } /// Implementation of [`ExprFunctionExt`]. @@ -820,28 +845,22 @@ impl ExprFuncBuilder { let fun_expr = match fun { ExprFuncKind::Aggregate(mut udaf) => { - udaf.params.order_by = order_by; + udaf.params.order_by = order_by.unwrap_or_default(); udaf.params.filter = filter.map(Box::new); udaf.params.distinct = distinct; udaf.params.null_treatment = null_treatment; Expr::AggregateFunction(udaf) } - ExprFuncKind::Window(WindowFunction { - fun, - params: WindowFunctionParams { args, .. }, - }) => { + ExprFuncKind::Window(mut udwf) => { let has_order_by = order_by.as_ref().map(|o| !o.is_empty()); - Expr::WindowFunction(WindowFunction { - fun, - params: WindowFunctionParams { - args, - partition_by: partition_by.unwrap_or_default(), - order_by: order_by.unwrap_or_default(), - window_frame: window_frame - .unwrap_or(WindowFrame::new(has_order_by)), - null_treatment, - }, - }) + udwf.params.partition_by = partition_by.unwrap_or_default(); + udwf.params.order_by = order_by.unwrap_or_default(); + udwf.params.window_frame = + window_frame.unwrap_or_else(|| WindowFrame::new(has_order_by)); + udwf.params.filter = filter.map(Box::new); + udwf.params.null_treatment = null_treatment; + udwf.params.distinct = distinct; + Expr::WindowFunction(udwf) } }; diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 90dcbce46b017..d9fb9f7219c69 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -354,6 +354,7 @@ mod test { use std::ops::Add; use super::*; + use crate::literal::lit_with_metadata; use crate::{col, lit, Cast}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::TreeNodeRewriter; @@ -383,13 +384,13 @@ mod test { // rewrites all "foo" string literals to "bar" let transformer = |expr: Expr| -> Result> { match expr { - Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => { + Expr::Literal(ScalarValue::Utf8(Some(utf8_val)), metadata) => { let utf8_val = if utf8_val == "foo" { "bar".to_string() } else { utf8_val }; - Ok(Transformed::yes(lit(utf8_val))) + Ok(Transformed::yes(lit_with_metadata(utf8_val, metadata))) } // otherwise, return None _ => Ok(Transformed::no(expr)), @@ -476,7 +477,7 @@ mod test { ) -> DFSchema { let fields = fields .iter() - .map(|f| Arc::new(Field::new(f.to_string(), DataType::Int8, false))) + .map(|f| Arc::new(Field::new((*f).to_string(), DataType::Int8, false))) .collect::>(); let schema = Arc::new(Schema::new(fields)); DFSchema::from_field_specific_qualified_schema(qualifiers, &schema).unwrap() diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index a349c83a49340..e803e35341305 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,24 +17,23 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, - InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, + AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, FieldMetadata, + InList, InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; use crate::type_coercion::functions::{ - data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, + data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf, }; -use crate::udf::ReturnTypeArgs; +use crate::udf::ReturnFieldArgs; use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, Spans, TableReference, }; use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer; use datafusion_functions_window_common::field::WindowUDFFieldArgs; -use std::collections::HashMap; use std::sync::Arc; /// Trait to allow expr to typable with respect to a schema @@ -46,7 +45,7 @@ pub trait ExprSchemable { fn nullable(&self, input_schema: &dyn ExprSchema) -> Result; /// Given a schema, return the expr's optional metadata - fn metadata(&self, schema: &dyn ExprSchema) -> Result>; + fn metadata(&self, schema: &dyn ExprSchema) -> Result; /// Convert to a field with respect to a schema fn to_field( @@ -113,9 +112,9 @@ impl ExprSchemable for Expr { }, Expr::Negative(expr) => expr.get_type(schema), Expr::Column(c) => Ok(schema.data_type(c)?.clone()), - Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), + Expr::OuterReferenceColumn(field, _) => Ok(field.data_type().clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), - Expr::Literal(l) => Ok(l.data_type()), + Expr::Literal(l, _) => Ok(l.data_type()), Expr::Case(case) => { for (_, then_expr) in &case.when_then_expr { let then_type = then_expr.get_type(schema)?; @@ -158,12 +157,16 @@ impl ExprSchemable for Expr { func, params: AggregateFunctionParams { args, .. }, }) => { - let data_types = args + let fields = args .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; - let new_types = data_types_with_aggregate_udf(&data_types, func) + let new_fields = fields_with_aggregate_udf(&fields, func) .map_err(|err| { + let data_types = fields + .iter() + .map(|f| f.data_type().clone()) + .collect::>(); plan_datafusion_err!( "{} {}", match err { @@ -176,8 +179,10 @@ impl ExprSchemable for Expr { &data_types ) ) - })?; - Ok(func.return_type(&new_types)?) + })? + .into_iter() + .collect::>(); + Ok(func.return_field(&new_fields)?.data_type().clone()) } Expr::Not(_) | Expr::IsNull(_) @@ -271,8 +276,8 @@ impl ExprSchemable for Expr { || high.nullable(input_schema)?), Expr::Column(c) => input_schema.nullable(c), - Expr::OuterReferenceColumn(_, _) => Ok(true), - Expr::Literal(value) => Ok(value.is_null()), + Expr::OuterReferenceColumn(field, _) => Ok(field.is_nullable()), + Expr::Literal(value, _) => Ok(value.is_null()), Expr::Case(case) => { // This expression is nullable if any of the input expressions are nullable let then_nullable = case @@ -340,22 +345,9 @@ impl ExprSchemable for Expr { } } - fn metadata(&self, schema: &dyn ExprSchema) -> Result> { - match self { - Expr::Column(c) => Ok(schema.metadata(c)?.clone()), - Expr::Alias(Alias { expr, metadata, .. }) => { - let mut ret = expr.metadata(schema)?; - if let Some(metadata) = metadata { - if !metadata.is_empty() { - ret.extend(metadata.clone()); - return Ok(ret); - } - } - Ok(ret) - } - Expr::Cast(Cast { expr, .. }) => expr.metadata(schema), - _ => Ok(HashMap::new()), - } + fn metadata(&self, schema: &dyn ExprSchema) -> Result { + self.to_field(schema) + .map(|(_, field)| FieldMetadata::from(field.metadata())) } /// Returns the datatype and nullability of the expression based on [ExprSchema]. @@ -372,23 +364,112 @@ impl ExprSchemable for Expr { &self, schema: &dyn ExprSchema, ) -> Result<(DataType, bool)> { - match self { - Expr::Alias(Alias { expr, name, .. }) => match &**expr { - Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { - None => schema - .data_type_and_nullable(&Column::from_name(name)) - .map(|(d, n)| (d.clone(), n)), - Some(dt) => Ok((dt.clone(), expr.nullable(schema)?)), - }, - _ => expr.data_type_and_nullable(schema), - }, - Expr::Negative(expr) => expr.data_type_and_nullable(schema), - Expr::Column(c) => schema - .data_type_and_nullable(c) - .map(|(d, n)| (d.clone(), n)), - Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)), - Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)), - Expr::Literal(l) => Ok((l.data_type(), l.is_null())), + let field = self.to_field(schema)?.1; + + Ok((field.data_type().clone(), field.is_nullable())) + } + + /// Returns a [arrow::datatypes::Field] compatible with this expression. + /// + /// This function converts an expression into a field with appropriate metadata + /// and nullability based on the expression type and context. It is the primary + /// mechanism for determining field-level schemas. + /// + /// # Field Property Resolution + /// + /// For each expression, the following properties are determined: + /// + /// ## Data Type Resolution + /// - **Column references**: Data type from input schema field + /// - **Literals**: Data type inferred from literal value + /// - **Aliases**: Data type inherited from the underlying expression (the aliased expression) + /// - **Binary expressions**: Result type from type coercion rules + /// - **Boolean expressions**: Always a boolean type + /// - **Cast expressions**: Target data type from cast operation + /// - **Function calls**: Return type based on function signature and argument types + /// + /// ## Nullability Determination + /// - **Column references**: Inherit nullability from input schema field + /// - **Literals**: Nullable only if literal value is NULL + /// - **Aliases**: Inherit nullability from the underlying expression (the aliased expression) + /// - **Binary expressions**: Nullable if either operand is nullable + /// - **Boolean expressions**: Always non-nullable (IS NULL, EXISTS, etc.) + /// - **Cast expressions**: determined by the input expression's nullability rules + /// - **Function calls**: Based on function nullability rules and input nullability + /// + /// ## Metadata Handling + /// - **Column references**: Preserve original field metadata from input schema + /// - **Literals**: Use explicitly provided metadata, otherwise empty + /// - **Aliases**: Merge underlying expr metadata with alias-specific metadata, preferring the alias metadata + /// - **Binary expressions**: field metadata is empty + /// - **Boolean expressions**: field metadata is empty + /// - **Cast expressions**: determined by the input expression's field metadata handling + /// - **Scalar functions**: Generate metadata via function's [`return_field_from_args`] method, + /// with the default implementation returning empty field metadata + /// - **Aggregate functions**: Generate metadata via function's [`return_field`] method, + /// with the default implementation returning empty field metadata + /// - **Window functions**: field metadata is empty + /// + /// ## Table Reference Scoping + /// - Establishes proper qualified field references when columns belong to specific tables + /// - Maintains table context for accurate field resolution in multi-table scenarios + /// + /// So for example, a projected expression `col(c1) + col(c2)` is + /// placed in an output field **named** col("c1 + c2") + /// + /// [`return_field_from_args`]: crate::ScalarUDF::return_field_from_args + /// [`return_field`]: crate::AggregateUDF::return_field + fn to_field( + &self, + schema: &dyn ExprSchema, + ) -> Result<(Option, Arc)> { + let (relation, schema_name) = self.qualified_name(); + #[allow(deprecated)] + let field = match self { + Expr::Alias(Alias { + expr, + name, + metadata, + .. + }) => { + let field = match &**expr { + Expr::Placeholder(Placeholder { data_type, .. }) => { + match &data_type { + None => schema + .data_type_and_nullable(&Column::from_name(name)) + .map(|(d, n)| Field::new(&schema_name, d.clone(), n)), + Some(dt) => Ok(Field::new( + &schema_name, + dt.clone(), + expr.nullable(schema)?, + )), + } + } + _ => expr.to_field(schema).map(|(_, f)| f.as_ref().clone()), + }?; + + let mut combined_metadata = expr.metadata(schema)?; + if let Some(metadata) = metadata { + combined_metadata.extend(metadata.clone()); + } + + Ok(Arc::new(combined_metadata.add_to_field(field))) + } + Expr::Negative(expr) => expr.to_field(schema).map(|(_, f)| f), + Expr::Column(c) => schema.field_from_column(c).map(|f| Arc::new(f.clone())), + Expr::OuterReferenceColumn(field, _) => { + Ok(Arc::new(field.as_ref().clone().with_name(&schema_name))) + } + Expr::ScalarVariable(ty, _) => { + Ok(Arc::new(Field::new(&schema_name, ty.clone(), true))) + } + Expr::Literal(l, metadata) => { + let mut field = Field::new(&schema_name, l.data_type(), l.is_null()); + if let Some(metadata) = metadata { + field = metadata.add_to_field(field); + } + Ok(Arc::new(field)) + } Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) @@ -397,11 +478,12 @@ impl ExprSchemable for Expr { | Expr::IsNotTrue(_) | Expr::IsNotFalse(_) | Expr::IsNotUnknown(_) - | Expr::Exists { .. } => Ok((DataType::Boolean, false)), - Expr::ScalarSubquery(subquery) => Ok(( - subquery.subquery.schema().field(0).data_type().clone(), - subquery.subquery.schema().field(0).is_nullable(), - )), + | Expr::Exists { .. } => { + Ok(Arc::new(Field::new(&schema_name, DataType::Boolean, false))) + } + Expr::ScalarSubquery(subquery) => { + Ok(Arc::clone(&subquery.subquery.schema().fields()[0])) + } Expr::BinaryExpr(BinaryExpr { ref left, ref right, @@ -412,17 +494,63 @@ impl ExprSchemable for Expr { let mut coercer = BinaryTypeCoercer::new(&lhs_type, op, &rhs_type); coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default()); coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default()); - Ok((coercer.get_result_type()?, lhs_nullable || rhs_nullable)) + Ok(Arc::new(Field::new( + &schema_name, + coercer.get_result_type()?, + lhs_nullable || rhs_nullable, + ))) } Expr::WindowFunction(window_function) => { - self.data_type_and_nullable_with_window_function(schema, window_function) + let (dt, nullable) = self.data_type_and_nullable_with_window_function( + schema, + window_function, + )?; + Ok(Arc::new(Field::new(&schema_name, dt, nullable))) + } + Expr::AggregateFunction(aggregate_function) => { + let AggregateFunction { + func, + params: AggregateFunctionParams { args, .. }, + .. + } = aggregate_function; + + let fields = args + .iter() + .map(|e| e.to_field(schema).map(|(_, f)| f)) + .collect::>>()?; + // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` + let new_fields = fields_with_aggregate_udf(&fields, func) + .map_err(|err| { + let arg_types = fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + plan_datafusion_err!( + "{} {}", + match err { + DataFusionError::Plan(msg) => msg, + err => err.to_string(), + }, + utils::generate_signature_error_msg( + func.name(), + func.signature().clone(), + &arg_types, + ) + ) + })? + .into_iter() + .collect::>(); + + func.return_field(&new_fields) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let (arg_types, nullables): (Vec, Vec) = args + let (arg_types, fields): (Vec, Vec>) = args .iter() - .map(|e| e.data_type_and_nullable(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()? .into_iter() + .map(|f| (f.data_type().clone(), f)) .unzip(); // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_types, func) @@ -440,42 +568,54 @@ impl ExprSchemable for Expr { ) ) })?; + let new_fields = fields + .into_iter() + .zip(new_data_types) + .map(|(f, d)| f.as_ref().clone().with_data_type(d)) + .map(Arc::new) + .collect::>(); let arguments = args .iter() .map(|e| match e { - Expr::Literal(sv) => Some(sv), + Expr::Literal(sv, _) => Some(sv), _ => None, }) .collect::>(); - let args = ReturnTypeArgs { - arg_types: &new_data_types, + let args = ReturnFieldArgs { + arg_fields: &new_fields, scalar_arguments: &arguments, - nullables: &nullables, }; - let (return_type, nullable) = - func.return_type_from_args(args)?.into_parts(); - Ok((return_type, nullable)) + func.return_field_from_args(args) } - _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), - } - } - - /// Returns a [arrow::datatypes::Field] compatible with this expression. - /// - /// So for example, a projected expression `col(c1) + col(c2)` is - /// placed in an output field **named** col("c1 + c2") - fn to_field( - &self, - input_schema: &dyn ExprSchema, - ) -> Result<(Option, Arc)> { - let (relation, schema_name) = self.qualified_name(); - let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; - let field = Field::new(schema_name, data_type, nullable) - .with_metadata(self.metadata(input_schema)?) - .into(); - Ok((relation, field)) + // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), + Expr::Cast(Cast { expr, data_type }) => expr + .to_field(schema) + .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone())) + .map(Arc::new), + Expr::Like(_) + | Expr::SimilarTo(_) + | Expr::Not(_) + | Expr::Between(_) + | Expr::Case(_) + | Expr::TryCast(_) + | Expr::InList(_) + | Expr::InSubquery(_) + | Expr::Wildcard { .. } + | Expr::GroupingSet(_) + | Expr::Placeholder(_) + | Expr::Unnest(_) => Ok(Arc::new(Field::new( + &schema_name, + self.get_type(schema)?, + self.nullable(schema)?, + ))), + }?; + + Ok(( + relation, + Arc::new(field.as_ref().clone().with_name(schema_name)), + )) } /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. @@ -502,7 +642,7 @@ impl ExprSchemable for Expr { _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))), } } else { - plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}") + plan_err!("Cannot automatically convert {this_type} to {cast_to_type}") } } } @@ -528,13 +668,18 @@ impl Expr { .. } = window_function; - let data_types = args + let fields = args .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; match fun { WindowFunctionDefinition::AggregateUDF(udaf) => { - let new_types = data_types_with_aggregate_udf(&data_types, udaf) + let data_types = fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let new_fields = fields_with_aggregate_udf(&fields, udaf) .map_err(|err| { plan_datafusion_err!( "{} {}", @@ -548,16 +693,22 @@ impl Expr { &data_types ) ) - })?; + })? + .into_iter() + .collect::>(); - let return_type = udaf.return_type(&new_types)?; - let nullable = udaf.is_nullable(); + let return_field = udaf.return_field(&new_fields)?; - Ok((return_type, nullable)) + Ok((return_field.data_type().clone(), return_field.is_nullable())) } WindowFunctionDefinition::WindowUDF(udwf) => { - let new_types = - data_types_with_window_udf(&data_types, udwf).map_err(|err| { + let data_types = fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let new_fields = fields_with_window_udf(&fields, udwf) + .map_err(|err| { plan_datafusion_err!( "{} {}", match err { @@ -570,9 +721,11 @@ impl Expr { &data_types ) ) - })?; + })? + .into_iter() + .collect::>(); let (_, function_name) = self.qualified_name(); - let field_args = WindowUDFFieldArgs::new(&new_types, &function_name); + let field_args = WindowUDFFieldArgs::new(&new_fields, &function_name); udwf.field(field_args) .map(|field| (field.data_type().clone(), field.is_nullable())) @@ -624,9 +777,9 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result {{ @@ -732,6 +885,7 @@ mod tests { fn test_expr_metadata() { let mut meta = HashMap::new(); meta.insert("bar".to_string(), "buzz".to_string()); + let meta = FieldMetadata::from(meta); let expr = col("foo"); let schema = MockExprSchema::new() .with_data_type(DataType::Int32) @@ -750,41 +904,44 @@ mod tests { ); let schema = DFSchema::from_unqualified_fields( - vec![Field::new("foo", DataType::Int32, true).with_metadata(meta.clone())] - .into(), - HashMap::new(), + vec![meta.add_to_field(Field::new("foo", DataType::Int32, true))].into(), + std::collections::HashMap::new(), ) .unwrap(); // verify to_field method populates metadata - assert_eq!(&meta, expr.to_field(&schema).unwrap().1.metadata()); + assert_eq!(meta, expr.metadata(&schema).unwrap()); + + // outer ref constructed by `out_ref_col_with_metadata` should be metadata-preserving + let outer_ref = out_ref_col_with_metadata( + DataType::Int32, + meta.to_hashmap(), + Column::from_name("foo"), + ); + assert_eq!(meta, outer_ref.metadata(&schema).unwrap()); } #[derive(Debug)] struct MockExprSchema { - nullable: bool, - data_type: DataType, + field: Field, error_on_nullable: bool, - metadata: HashMap, } impl MockExprSchema { fn new() -> Self { Self { - nullable: false, - data_type: DataType::Null, + field: Field::new("mock_field", DataType::Null, false), error_on_nullable: false, - metadata: HashMap::new(), } } fn with_nullable(mut self, nullable: bool) -> Self { - self.nullable = nullable; + self.field = self.field.with_nullable(nullable); self } fn with_data_type(mut self, data_type: DataType) -> Self { - self.data_type = data_type; + self.field = self.field.with_data_type(data_type); self } @@ -793,8 +950,8 @@ mod tests { self } - fn with_metadata(mut self, metadata: HashMap) -> Self { - self.metadata = metadata; + fn with_metadata(mut self, metadata: FieldMetadata) -> Self { + self.field = metadata.add_to_field(self.field); self } } @@ -804,20 +961,12 @@ mod tests { if self.error_on_nullable { internal_err!("nullable error") } else { - Ok(self.nullable) + Ok(self.field.is_nullable()) } } - fn data_type(&self, _col: &Column) -> Result<&DataType> { - Ok(&self.data_type) - } - - fn metadata(&self, _col: &Column) -> Result<&HashMap> { - Ok(&self.metadata) - } - - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { - Ok((self.data_type(col)?, self.nullable(col)?)) + fn field_from_column(&self, _col: &Column) -> Result<&Field> { + Ok(&self.field) } } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index d3cc881af3616..346d373ff5b4d 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -34,6 +34,8 @@ //! //! The [expr_fn] module contains functions for creating expressions. +extern crate core; + mod literal; mod operation; mod partition_evaluator; @@ -63,18 +65,24 @@ pub mod simplify; pub mod sort_properties { pub use datafusion_expr_common::sort_properties::*; } +pub mod async_udf; pub mod statistics { pub use datafusion_expr_common::statistics::*; } +pub mod ptr_eq; pub mod test; pub mod tree_node; pub mod type_coercion; +pub mod udf_eq; pub mod utils; pub mod var_provider; pub mod window_frame; pub mod window_state; -pub use datafusion_doc::{DocSection, Documentation, DocumentationBuilder}; +pub use datafusion_doc::{ + aggregate_doc_sections, scalar_doc_sections, window_doc_sections, DocSection, + Documentation, DocumentationBuilder, +}; pub use datafusion_expr_common::accumulator::Accumulator; pub use datafusion_expr_common::columnar_value::ColumnarValue; pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; @@ -94,20 +102,22 @@ pub use function::{ AccumulatorFactoryFunction, PartitionEvaluatorFactory, ReturnTypeFunction, ScalarFunctionImplementation, StateTypeFunction, }; -pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; +pub use literal::{ + lit, lit_timestamp_nano, lit_with_metadata, Literal, TimestampLiteral, +}; pub use logical_plan::*; pub use partition_evaluator::PartitionEvaluator; +#[cfg(feature = "sql")] pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{ - aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, - SetMonotonicity, StatisticsArgs, -}; -pub use udf::{ - scalar_doc_sections, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, + udaf_default_display_name, udaf_default_human_display, udaf_default_return_field, + udaf_default_schema_name, udaf_default_window_function_display_name, + udaf_default_window_function_schema_name, AggregateUDF, AggregateUDFImpl, + ReversedUDAF, SetMonotonicity, StatisticsArgs, }; -pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; +pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; #[cfg(test)] diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs index 90ba5a9a693c7..c4bd43bc0a620 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -17,6 +17,7 @@ //! Literal module contains foundational types that are used to represent literals in DataFusion. +use crate::expr::FieldMetadata; use crate::Expr; use datafusion_common::ScalarValue; @@ -25,6 +26,25 @@ pub fn lit(n: T) -> Expr { n.lit() } +pub fn lit_with_metadata(n: T, metadata: Option) -> Expr { + let Some(metadata) = metadata else { + return n.lit(); + }; + + let Expr::Literal(sv, prior_metadata) = n.lit() else { + unreachable!(); + }; + let new_metadata = match prior_metadata { + Some(mut prior) => { + prior.extend(metadata); + prior + } + None => metadata, + }; + + Expr::Literal(sv, Some(new_metadata)) +} + /// Create a literal timestamp expression pub fn lit_timestamp_nano(n: T) -> Expr { n.lit_timestamp_nano() @@ -43,37 +63,37 @@ pub trait TimestampLiteral { impl Literal for &str { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(*self)) + Expr::Literal(ScalarValue::from(*self), None) } } impl Literal for String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(self.as_ref())) + Expr::Literal(ScalarValue::from(self.as_ref()), None) } } impl Literal for &String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(self.as_ref())) + Expr::Literal(ScalarValue::from(self.as_ref()), None) } } impl Literal for Vec { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())), None) } } impl Literal for &[u8] { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())), None) } } impl Literal for ScalarValue { fn lit(&self) -> Expr { - Expr::Literal(self.clone()) + Expr::Literal(self.clone(), None) } } @@ -82,7 +102,7 @@ macro_rules! make_literal { #[doc = $DOC] impl Literal for $TYPE { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.clone()))) + Expr::Literal(ScalarValue::$SCALAR(Some(self.clone())), None) } } }; @@ -93,7 +113,7 @@ macro_rules! make_nonzero_literal { #[doc = $DOC] impl Literal for $TYPE { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.get()))) + Expr::Literal(ScalarValue::$SCALAR(Some(self.get())), None) } } }; @@ -104,10 +124,10 @@ macro_rules! make_timestamp_literal { #[doc = $DOC] impl TimestampLiteral for $TYPE { fn lit_timestamp_nano(&self) -> Expr { - Expr::Literal(ScalarValue::TimestampNanosecond( - Some((self.clone()).into()), + Expr::Literal( + ScalarValue::TimestampNanosecond(Some((self.clone()).into()), None), None, - )) + ) } } }; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 91a871d52e9ad..7a283b0420d3c 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -18,13 +18,14 @@ //! This module provides a builder for creating LogicalPlans use std::any::Any; +use std::borrow::Cow; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::iter::once; use std::sync::Arc; use crate::dml::CopyTo; -use crate::expr::{Alias, PlannedReplaceSelectItem, Sort as SortExpr}; +use crate::expr::{Alias, FieldMetadata, PlannedReplaceSelectItem, Sort as SortExpr}; use crate::expr_rewriter::{ coerce_plan_expr_for_schema, normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_cols, normalize_sorts, @@ -43,20 +44,19 @@ use crate::utils::{ group_window_expr_by_sort_keys, }; use crate::{ - and, binary_expr, lit, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, - Statement, TableProviderFilterPushDown, TableSource, WriteOp, + and, binary_expr, lit, DmlStatement, ExplainOption, Expr, ExprSchemable, Operator, + RecursiveQuery, Statement, TableProviderFilterPushDown, TableSource, WriteOp, }; use super::dml::InsertOp; -use super::plan::{ColumnUnnestList, ExplainFormat}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ - exec_err, get_target_functional_dependencies, internal_err, not_impl_err, + exec_err, get_target_functional_dependencies, internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, - DataFusionError, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, + NullEquality, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; @@ -282,15 +282,14 @@ impl LogicalPlanBuilder { let value = &row[j]; let data_type = value.get_type(schema)?; - if !data_type.equals_datatype(field_type) { - if can_cast_types(&data_type, field_type) { - } else { - return exec_err!( - "type mismatch and can't cast to got {} and {}", - data_type, - field_type - ); - } + if !data_type.equals_datatype(field_type) + && !can_cast_types(&data_type, field_type) + { + return exec_err!( + "type mismatch and can't cast to got {} and {}", + data_type, + field_type + ); } } fields.push(field_type.to_owned(), field_nullable); @@ -306,8 +305,17 @@ impl LogicalPlanBuilder { for j in 0..n_cols { let mut common_type: Option = None; + let mut common_metadata: Option = None; for (i, row) in values.iter().enumerate() { let value = &row[j]; + let metadata = value.metadata(&schema)?; + if let Some(ref cm) = common_metadata { + if &metadata != cm { + return plan_err!("Inconsistent metadata across values list at row {i} column {j}. Was {:?} but found {:?}", cm, metadata); + } + } else { + common_metadata = Some(metadata.clone()); + } let data_type = value.get_type(&schema)?; if data_type == DataType::Null { continue; @@ -326,7 +334,11 @@ impl LogicalPlanBuilder { } // assuming common_type was not set, and no error, therefore the type should be NULL // since the code loop skips NULL - fields.push(common_type.unwrap_or(DataType::Null), true); + fields.push_with_metadata( + common_type.unwrap_or(DataType::Null), + true, + common_metadata, + ); } Self::infer_inner(values, fields, &schema) @@ -341,8 +353,11 @@ impl LogicalPlanBuilder { // wrap cast if data type is not same as common type. for row in &mut values { for (j, field_type) in fields.iter().map(|f| f.data_type()).enumerate() { - if let Expr::Literal(ScalarValue::Null) = row[j] { - row[j] = Expr::Literal(ScalarValue::try_from(field_type)?); + if let Expr::Literal(ScalarValue::Null, metadata) = &row[j] { + row[j] = Expr::Literal( + ScalarValue::try_from(field_type)?, + metadata.clone(), + ); } else { row[j] = std::mem::take(&mut row[j]).cast_to(field_type, schema)?; } @@ -403,13 +418,13 @@ impl LogicalPlanBuilder { options: HashMap, partition_by: Vec, ) -> Result { - Ok(Self::new(LogicalPlan::Copy(CopyTo { - input: Arc::new(input), + Ok(Self::new(LogicalPlan::Copy(CopyTo::new( + Arc::new(input), output_url, partition_by, file_type, options, - }))) + )))) } /// Create a [`DmlStatement`] for inserting the contents of this builder into the named table. @@ -501,6 +516,21 @@ impl LogicalPlanBuilder { if table_scan.filters.is_empty() { if let Some(p) = table_scan.source.get_logical_plan() { let sub_plan = p.into_owned(); + + if let Some(proj) = table_scan.projection { + let projection_exprs = proj + .into_iter() + .map(|i| { + Expr::Column(Column::from( + sub_plan.schema().qualified_field(i), + )) + }) + .collect::>(); + return Self::new(sub_plan) + .project(projection_exprs)? + .alias(table_scan.table_name); + } + // Ensures that the reference to the inlined table remains the // same, meaning we don't have to change any of the parent nodes // that reference this table. @@ -586,7 +616,7 @@ impl LogicalPlanBuilder { /// Apply a filter which is used for a having clause pub fn having(self, expr: impl Into) -> Result { let expr = normalize_col(expr.into(), &self.plan)?; - Filter::try_new_with_having(expr, self.plan) + Filter::try_new(expr, self.plan) .map(LogicalPlan::Filter) .map(Self::from) } @@ -885,7 +915,13 @@ impl LogicalPlanBuilder { join_keys: (Vec>, Vec>), filter: Option, ) -> Result { - self.join_detailed(right, join_type, join_keys, filter, false) + self.join_detailed( + right, + join_type, + join_keys, + filter, + NullEquality::NullEqualsNothing, + ) } /// Apply a join using the specified expressions. @@ -941,15 +977,11 @@ impl LogicalPlanBuilder { join_type, (Vec::::new(), Vec::::new()), filter, - false, + NullEquality::NullEqualsNothing, ) } - pub(crate) fn normalize( - plan: &LogicalPlan, - column: impl Into, - ) -> Result { - let column = column.into(); + pub(crate) fn normalize(plan: &LogicalPlan, column: Column) -> Result { if column.relation.is_some() { // column is already normalized return Ok(column); @@ -969,16 +1001,14 @@ impl LogicalPlanBuilder { /// The behavior is the same as [`join`](Self::join) except that it allows /// specifying the null equality behavior. /// - /// If `null_equals_null=true`, rows where both join keys are `null` will be - /// emitted. Otherwise rows where either or both join keys are `null` will be - /// omitted. + /// The `null_equality` dictates how `null` values are joined. pub fn join_detailed( self, right: LogicalPlan, join_type: JoinType, join_keys: (Vec>, Vec>), filter: Option, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result { if join_keys.0.len() != join_keys.1.len() { return plan_err!("left_keys and right_keys were not the same length"); @@ -1095,7 +1125,7 @@ impl LogicalPlanBuilder { join_type, join_constraint: JoinConstraint::On, schema: DFSchemaRef::new(join_schema), - null_equals_null, + null_equality, }))) } @@ -1104,7 +1134,7 @@ impl LogicalPlanBuilder { self, right: LogicalPlan, join_type: JoinType, - using_keys: Vec + Clone>, + using_keys: Vec, ) -> Result { let left_keys: Vec = using_keys .clone() @@ -1117,19 +1147,29 @@ impl LogicalPlanBuilder { .collect::>()?; let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys).collect(); - let join_schema = - build_join_schema(self.plan.schema(), right.schema(), &join_type)?; let mut join_on: Vec<(Expr, Expr)> = vec![]; let mut filters: Option = None; for (l, r) in &on { if self.plan.schema().has_column(l) && right.schema().has_column(r) - && can_hash(self.plan.schema().field_from_column(l)?.data_type()) + && can_hash( + datafusion_common::ExprSchema::field_from_column( + self.plan.schema(), + l, + )? + .data_type(), + ) { join_on.push((Expr::Column(l.clone()), Expr::Column(r.clone()))); } else if self.plan.schema().has_column(l) && right.schema().has_column(r) - && can_hash(self.plan.schema().field_from_column(r)?.data_type()) + && can_hash( + datafusion_common::ExprSchema::field_from_column( + self.plan.schema(), + r, + )? + .data_type(), + ) { join_on.push((Expr::Column(r.clone()), Expr::Column(l.clone()))); } else { @@ -1148,36 +1188,36 @@ impl LogicalPlanBuilder { if join_on.is_empty() { let join = Self::from(self.plan).cross_join(right)?; join.filter(filters.ok_or_else(|| { - DataFusionError::Internal("filters should not be None here".to_string()) + internal_datafusion_err!("filters should not be None here") })?) } else { - Ok(Self::new(LogicalPlan::Join(Join { - left: self.plan, - right: Arc::new(right), - on: join_on, - filter: filters, + let join = Join::try_new( + self.plan, + Arc::new(right), + join_on, + filters, join_type, - join_constraint: JoinConstraint::Using, - schema: DFSchemaRef::new(join_schema), - null_equals_null: false, - }))) + JoinConstraint::Using, + NullEquality::NullEqualsNothing, + )?; + + Ok(Self::new(LogicalPlan::Join(join))) } } /// Apply a cross join pub fn cross_join(self, right: LogicalPlan) -> Result { - let join_schema = - build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?; - Ok(Self::new(LogicalPlan::Join(Join { - left: self.plan, - right: Arc::new(right), - on: vec![], - filter: None, - join_type: JoinType::Inner, - join_constraint: JoinConstraint::On, - null_equals_null: false, - schema: DFSchemaRef::new(join_schema), - }))) + let join = Join::try_new( + self.plan, + Arc::new(right), + vec![], + None, + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + )?; + + Ok(Self::new(LogicalPlan::Join(join))) } /// Repartition @@ -1230,12 +1270,24 @@ impl LogicalPlanBuilder { /// /// if `verbose` is true, prints out additional details. pub fn explain(self, verbose: bool, analyze: bool) -> Result { + // Keep the format default to Indent + self.explain_option_format( + ExplainOption::default() + .with_verbose(verbose) + .with_analyze(analyze), + ) + } + + /// Create an expression to represent the explanation of the plan + /// The`explain_option` is used to specify the format and verbosity of the explanation. + /// Details see [`ExplainOption`]. + pub fn explain_option_format(self, explain_option: ExplainOption) -> Result { let schema = LogicalPlan::explain_schema(); let schema = schema.to_dfschema_ref()?; - if analyze { + if explain_option.analyze { Ok(Self::new(LogicalPlan::Analyze(Analyze { - verbose, + verbose: explain_option.verbose, input: self.plan, schema, }))) @@ -1244,9 +1296,9 @@ impl LogicalPlanBuilder { vec![self.plan.to_stringified(PlanType::InitialLogicalPlan)]; Ok(Self::new(LogicalPlan::Explain(Explain { - verbose, + verbose: explain_option.verbose, plan: self.plan, - explain_format: ExplainFormat::Indent, + explain_format: explain_option.format, stringified_plans, schema, logical_optimization_succeeded: false, @@ -1312,12 +1364,24 @@ impl LogicalPlanBuilder { .unzip(); if is_all { LogicalPlanBuilder::from(left_plan) - .join_detailed(right_plan, join_type, join_keys, None, true)? + .join_detailed( + right_plan, + join_type, + join_keys, + None, + NullEquality::NullEqualsNull, + )? .build() } else { LogicalPlanBuilder::from(left_plan) .distinct()? - .join_detailed(right_plan, join_type, join_keys, None, true)? + .join_detailed( + right_plan, + join_type, + join_keys, + None, + NullEquality::NullEqualsNull, + )? .build() } } @@ -1338,7 +1402,7 @@ impl LogicalPlanBuilder { /// to columns from the existing input. `r`, the second element of the tuple, /// must only refer to columns from the right input. /// - /// `filter` contains any other other filter expression to apply during the + /// `filter` contains any other filter expression to apply during the /// join. Note that `equi_exprs` predicates are evaluated more efficiently /// than the filter expressions, so they are preferred. pub fn join_with_expr_keys( @@ -1388,19 +1452,17 @@ impl LogicalPlanBuilder { }) .collect::>>()?; - let join_schema = - build_join_schema(self.plan.schema(), right.schema(), &join_type)?; - - Ok(Self::new(LogicalPlan::Join(Join { - left: self.plan, - right: Arc::new(right), - on: join_key_pairs, + let join = Join::try_new( + self.plan, + Arc::new(right), + join_key_pairs, filter, join_type, - join_constraint: JoinConstraint::On, - schema: DFSchemaRef::new(join_schema), - null_equals_null: false, - }))) + JoinConstraint::On, + NullEquality::NullEqualsNothing, + )?; + + Ok(Self::new(LogicalPlan::Join(join))) } /// Unnest the given column. @@ -1457,10 +1519,23 @@ impl ValuesFields { } pub fn push(&mut self, data_type: DataType, nullable: bool) { + self.push_with_metadata(data_type, nullable, None); + } + + pub fn push_with_metadata( + &mut self, + data_type: DataType, + nullable: bool, + metadata: Option, + ) { // Naming follows the convention described here: // https://www.postgresql.org/docs/current/queries-values.html let name = format!("column{}", self.inner.len() + 1); - self.inner.push(Field::new(name, data_type, nullable)); + let mut field = Field::new(name, data_type, nullable); + if let Some(metadata) = metadata { + field.set_metadata(metadata.to_hashmap()); + } + self.inner.push(field); } pub fn into_fields(self) -> Fields { @@ -1468,18 +1543,48 @@ impl ValuesFields { } } -pub fn change_redundant_column(fields: &Fields) -> Vec { - let mut name_map = HashMap::new(); +/// Returns aliases to make field names unique. +/// +/// Returns a vector of optional aliases, one per input field. `None` means keep the original name, +/// `Some(alias)` means rename to the alias to ensure uniqueness. +/// +/// Used when creating [`SubqueryAlias`] or similar operations that strip table qualifiers but need +/// to maintain unique column names. +/// +/// # Example +/// Input fields: `[a, a, b, b, a, a:1]` ([`DFSchema`] valid when duplicate fields have different qualifiers) +/// Returns: `[None, Some("a:1"), None, Some("b:1"), Some("a:2"), Some("a:1:1")]` +pub fn unique_field_aliases(fields: &Fields) -> Vec> { + // Some field names might already come to this function with the count (number of times it appeared) + // as a suffix e.g. id:1, so there's still a chance of name collisions, for example, + // if these three fields passed to this function: "col:1", "col" and "col", the function + // would rename them to -> col:1, col, col:1 causing a posterior error when building the DFSchema. + // That's why we need the `seen` set, so the fields are always unique. + + // Tracks a mapping between a field name and the number of appearances of that field. + let mut name_map = HashMap::<&str, usize>::new(); + // Tracks all the fields and aliases that were previously seen. + let mut seen = HashSet::>::new(); + fields - .into_iter() + .iter() .map(|field| { - let counter = name_map.entry(field.name().to_string()).or_insert(0); - *counter += 1; - if *counter > 1 { - let new_name = format!("{}:{}", field.name(), *counter - 1); - Field::new(new_name, field.data_type().clone(), field.is_nullable()) - } else { - field.as_ref().clone() + let original_name = field.name(); + let mut name = Cow::Borrowed(original_name); + + let count = name_map.entry(original_name).or_insert(0); + + // Loop until we find a name that hasn't been used. + while seen.contains(&name) { + *count += 1; + name = Cow::Owned(format!("{original_name}:{count}")); + } + + seen.insert(name.clone()); + + match name { + Cow::Borrowed(_) => None, + Cow::Owned(alias) => Some(alias), } }) .collect() @@ -1579,22 +1684,68 @@ pub fn build_join_schema( .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect() } + JoinType::RightMark => right_fields + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .chain(once(mark_field(left))) + .collect(), }; let func_dependencies = left.functional_dependencies().join( right.functional_dependencies(), join_type, left.fields().len(), ); - let metadata = left + + let (schema1, schema2) = match join_type { + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => (left, right), + _ => (right, left), + }; + + let metadata = schema1 .metadata() .clone() .into_iter() - .chain(right.metadata().clone()) + .chain(schema2.metadata().clone()) .collect(); + let dfschema = DFSchema::new_with_metadata(qualified_fields, metadata)?; dfschema.with_functional_dependencies(func_dependencies) } +/// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise +/// conflict with the columns from the other. +/// This is especially useful for queries that come as Substrait, since Substrait doesn't currently allow specifying +/// aliases, neither for columns nor for tables. DataFusion requires columns to be uniquely identifiable, in some +/// places (see e.g. DFSchema::check_names). +/// The function returns: +/// - The requalified or original left logical plan +/// - The requalified or original right logical plan +/// - If a requalification was needed or not +pub fn requalify_sides_if_needed( + left: LogicalPlanBuilder, + right: LogicalPlanBuilder, +) -> Result<(LogicalPlanBuilder, LogicalPlanBuilder, bool)> { + let left_cols = left.schema().columns(); + let right_cols = right.schema().columns(); + if left_cols.iter().any(|l| { + right_cols.iter().any(|r| { + l == r || (l.name == r.name && (l.relation.is_none() || r.relation.is_none())) + }) + }) { + // These names have no connection to the original plan, but they'll make the columns + // (mostly) unique. + Ok(( + left.alias(TableReference::bare("left"))?, + right.alias(TableReference::bare("right"))?, + true, + )) + } else { + Ok((left, right, false)) + } +} + /// Add additional "synthetic" group by expressions based on functional /// dependencies. /// @@ -1847,6 +1998,7 @@ pub fn table_scan_with_filter_and_fetch( } pub fn table_source(table_schema: &Schema) -> Arc { + // TODO should we take SchemaRef and avoid cloning? let table_schema = Arc::new(table_schema.clone()); Arc::new(LogicalTableSource { table_schema, @@ -1858,6 +2010,7 @@ pub fn table_source_with_constraints( table_schema: &Schema, constraints: Constraints, ) -> Arc { + // TODO should we take SchemaRef and avoid cloning? let table_schema = Arc::new(table_schema.clone()); Arc::new(LogicalTableSource { table_schema, @@ -1977,27 +2130,6 @@ pub fn unnest(input: LogicalPlan, columns: Vec) -> Result { unnest_with_options(input, columns, UnnestOptions::default()) } -// Get the data type of a multi-dimensional type after unnesting it -// with a given depth -fn get_unnested_list_datatype_recursive( - data_type: &DataType, - depth: usize, -) -> Result { - match data_type { - DataType::List(field) - | DataType::FixedSizeList(field, _) - | DataType::LargeList(field) => { - if depth == 1 { - return Ok(field.data_type().clone()); - } - return get_unnested_list_datatype_recursive(field.data_type(), depth - 1); - } - _ => {} - }; - - internal_err!("trying to unnest on invalid data type {:?}", data_type) -} - pub fn get_struct_unnested_columns( col_name: &String, inner_fields: &Fields, @@ -2008,53 +2140,6 @@ pub fn get_struct_unnested_columns( .collect() } -// Based on data type, either struct or a variant of list -// return a set of columns as the result of unnesting -// the input columns. -// For example, given a column with name "a", -// - List(Element) returns ["a"] with data type Element -// - Struct(field1, field2) returns ["a.field1","a.field2"] -// For list data type, an argument depth is used to specify -// the recursion level -pub fn get_unnested_columns( - col_name: &String, - data_type: &DataType, - depth: usize, -) -> Result)>> { - let mut qualified_columns = Vec::with_capacity(1); - - match data_type { - DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => { - let data_type = get_unnested_list_datatype_recursive(data_type, depth)?; - let new_field = Arc::new(Field::new( - col_name, data_type, - // Unnesting may produce NULLs even if the list is not null. - // For example: unnest([1], []) -> 1, null - true, - )); - let column = Column::from_name(col_name); - // let column = Column::from((None, &new_field)); - qualified_columns.push((column, new_field)); - } - DataType::Struct(fields) => { - qualified_columns.extend(fields.iter().map(|f| { - let new_name = format!("{}.{}", col_name, f.name()); - let column = Column::from_name(&new_name); - let new_field = f.as_ref().clone().with_name(new_name); - // let column = Column::from((None, &f)); - (column, Arc::new(new_field)) - })) - } - _ => { - return internal_err!( - "trying to unnest on invalid data type {:?}", - data_type - ); - } - }; - Ok(qualified_columns) -} - /// Create a [`LogicalPlan::Unnest`] plan with options /// This function receive a list of columns to be unnested /// because multiple unnest can be performed on the same column (e.g unnest with different depth) @@ -2089,136 +2174,27 @@ pub fn unnest_with_options( columns_to_unnest: Vec, options: UnnestOptions, ) -> Result { - let mut list_columns: Vec<(usize, ColumnUnnestList)> = vec![]; - let mut struct_columns = vec![]; - let indices_to_unnest = columns_to_unnest - .iter() - .map(|c| Ok((input.schema().index_of_column(c)?, c))) - .collect::>>()?; - - let input_schema = input.schema(); - - let mut dependency_indices = vec![]; - // Transform input schema into new schema - // Given this comprehensive example - // - // input schema: - // 1.col1_unnest_placeholder: list[list[int]], - // 2.col1: list[list[int]] - // 3.col2: list[int] - // with unnest on unnest(col1,depth=2), unnest(col1,depth=1) and unnest(col2,depth=1) - // output schema: - // 1.unnest_col1_depth_2: int - // 2.unnest_col1_depth_1: list[int] - // 3.col1: list[list[int]] - // 4.unnest_col2_depth_1: int - // Meaning the placeholder column will be replaced by its unnested variation(s), note - // the plural. - let fields = input_schema - .iter() - .enumerate() - .map(|(index, (original_qualifier, original_field))| { - match indices_to_unnest.get(&index) { - Some(column_to_unnest) => { - let recursions_on_column = options - .recursions - .iter() - .filter(|p| -> bool { &p.input_column == *column_to_unnest }) - .collect::>(); - let mut transformed_columns = recursions_on_column - .iter() - .map(|r| { - list_columns.push(( - index, - ColumnUnnestList { - output_column: r.output_column.clone(), - depth: r.depth, - }, - )); - Ok(get_unnested_columns( - &r.output_column.name, - original_field.data_type(), - r.depth, - )? - .into_iter() - .next() - .unwrap()) // because unnesting a list column always result into one result - }) - .collect::)>>>()?; - if transformed_columns.is_empty() { - transformed_columns = get_unnested_columns( - &column_to_unnest.name, - original_field.data_type(), - 1, - )?; - match original_field.data_type() { - DataType::Struct(_) => { - struct_columns.push(index); - } - DataType::List(_) - | DataType::FixedSizeList(_, _) - | DataType::LargeList(_) => { - list_columns.push(( - index, - ColumnUnnestList { - output_column: Column::from_name( - &column_to_unnest.name, - ), - depth: 1, - }, - )); - } - _ => {} - }; - } - - // new columns dependent on the same original index - dependency_indices - .extend(std::iter::repeat(index).take(transformed_columns.len())); - Ok(transformed_columns - .iter() - .map(|(col, field)| (col.relation.to_owned(), field.to_owned())) - .collect()) - } - None => { - dependency_indices.push(index); - Ok(vec![( - original_qualifier.cloned(), - Arc::clone(original_field), - )]) - } - } - }) - .collect::>>()? - .into_iter() - .flatten() - .collect::>(); - - let metadata = input_schema.metadata().clone(); - let df_schema = DFSchema::new_with_metadata(fields, metadata)?; - // We can use the existing functional dependencies: - let deps = input_schema.functional_dependencies().clone(); - let schema = Arc::new(df_schema.with_functional_dependencies(deps)?); - - Ok(LogicalPlan::Unnest(Unnest { - input: Arc::new(input), - exec_columns: columns_to_unnest, - list_type_columns: list_columns, - struct_type_columns: struct_columns, - dependency_indices, - schema, + Ok(LogicalPlan::Unnest(Unnest::try_new( + Arc::new(input), + columns_to_unnest, options, - })) + )?)) } #[cfg(test)] mod tests { + use std::vec; + use super::*; + use crate::lit_with_metadata; use crate::logical_plan::StringifiedPlan; use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; use crate::test::function_stub::sum; - use datafusion_common::{Constraint, RecursionUnnestOption, SchemaError}; + use datafusion_common::{ + Constraint, DataFusionError, RecursionUnnestOption, SchemaError, + }; + use insta::assert_snapshot; #[test] fn plan_builder_simple() -> Result<()> { @@ -2228,11 +2204,11 @@ mod tests { .project(vec![col("id")])? .build()?; - let expected = "Projection: employee_csv.id\ - \n Filter: employee_csv.state = Utf8(\"CO\")\ - \n TableScan: employee_csv projection=[id, state]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r#" + Projection: employee_csv.id + Filter: employee_csv.state = Utf8("CO") + TableScan: employee_csv projection=[id, state] + "#); Ok(()) } @@ -2244,12 +2220,7 @@ mod tests { let plan = LogicalPlanBuilder::scan("employee_csv", table_source(&schema), projection) .unwrap(); - let expected = DFSchema::try_from_qualified_schema( - TableReference::bare("employee_csv"), - &schema, - ) - .unwrap(); - assert_eq!(&expected, plan.schema().as_ref()); + assert_snapshot!(plan.schema().as_ref(), @"fields:[employee_csv.id, employee_csv.first_name, employee_csv.last_name, employee_csv.state, employee_csv.salary], metadata:{}"); // Note scan of "EMPLOYEE_CSV" is treated as a SQL identifier // (and thus normalized to "employee"csv") as well @@ -2257,7 +2228,7 @@ mod tests { let plan = LogicalPlanBuilder::scan("EMPLOYEE_CSV", table_source(&schema), projection) .unwrap(); - assert_eq!(&expected, plan.schema().as_ref()); + assert_snapshot!(plan.schema().as_ref(), @"fields:[employee_csv.id, employee_csv.first_name, employee_csv.last_name, employee_csv.state, employee_csv.salary], metadata:{}"); } #[test] @@ -2266,9 +2237,9 @@ mod tests { let projection = None; let err = LogicalPlanBuilder::scan("", table_source(&schema), projection).unwrap_err(); - assert_eq!( + assert_snapshot!( err.strip_backtrace(), - "Error during planning: table_name cannot be empty" + @"Error during planning: table_name cannot be empty" ); } @@ -2282,10 +2253,10 @@ mod tests { ])? .build()?; - let expected = "Sort: employee_csv.state ASC NULLS FIRST, employee_csv.salary DESC NULLS LAST\ - \n TableScan: employee_csv projection=[state, salary]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Sort: employee_csv.state ASC NULLS FIRST, employee_csv.salary DESC NULLS LAST + TableScan: employee_csv projection=[state, salary] + "); Ok(()) } @@ -2302,15 +2273,15 @@ mod tests { .union(plan.build()?)? .build()?; - let expected = "Union\ - \n Union\ - \n Union\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Union + Union + Union + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + "); Ok(()) } @@ -2327,19 +2298,18 @@ mod tests { .union_distinct(plan.build()?)? .build()?; - let expected = "\ - Distinct:\ - \n Union\ - \n Distinct:\ - \n Union\ - \n Distinct:\ - \n Union\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Distinct: + Union + Distinct: + Union + Distinct: + Union + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + "); Ok(()) } @@ -2353,13 +2323,12 @@ mod tests { .distinct()? .build()?; - let expected = "\ - Distinct:\ - \n Projection: employee_csv.id\ - \n Filter: employee_csv.state = Utf8(\"CO\")\ - \n TableScan: employee_csv projection=[id, state]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r#" + Distinct: + Projection: employee_csv.id + Filter: employee_csv.state = Utf8("CO") + TableScan: employee_csv projection=[id, state] + "#); Ok(()) } @@ -2379,14 +2348,15 @@ mod tests { .filter(exists(Arc::new(subquery)))? .build()?; - let expected = "Filter: EXISTS ()\ - \n Subquery:\ - \n Filter: foo.a = bar.a\ - \n Projection: foo.a\ - \n TableScan: foo\ - \n Projection: bar.a\ - \n TableScan: bar"; - assert_eq!(expected, format!("{outer_query}")); + assert_snapshot!(outer_query, @r" + Filter: EXISTS () + Subquery: + Filter: foo.a = bar.a + Projection: foo.a + TableScan: foo + Projection: bar.a + TableScan: bar + "); Ok(()) } @@ -2407,14 +2377,15 @@ mod tests { .filter(in_subquery(col("a"), Arc::new(subquery)))? .build()?; - let expected = "Filter: bar.a IN ()\ - \n Subquery:\ - \n Filter: foo.a = bar.a\ - \n Projection: foo.a\ - \n TableScan: foo\ - \n Projection: bar.a\ - \n TableScan: bar"; - assert_eq!(expected, format!("{outer_query}")); + assert_snapshot!(outer_query, @r" + Filter: bar.a IN () + Subquery: + Filter: foo.a = bar.a + Projection: foo.a + TableScan: foo + Projection: bar.a + TableScan: bar + "); Ok(()) } @@ -2434,13 +2405,14 @@ mod tests { .project(vec![scalar_subquery(Arc::new(subquery))])? .build()?; - let expected = "Projection: ()\ - \n Subquery:\ - \n Filter: foo.a = bar.a\ - \n Projection: foo.b\ - \n TableScan: foo\ - \n TableScan: bar"; - assert_eq!(expected, format!("{outer_query}")); + assert_snapshot!(outer_query, @r" + Projection: () + Subquery: + Filter: foo.a = bar.a + Projection: foo.b + TableScan: foo + TableScan: bar + "); Ok(()) } @@ -2457,20 +2429,24 @@ mod tests { .project(vec![col("id"), col("first_name").alias("id")]); match plan { - Err(DataFusionError::SchemaError( - SchemaError::AmbiguousReference { - field: - Column { - relation: Some(TableReference::Bare { table }), - name, - spans: _, - }, - }, - _, - )) => { - assert_eq!(*"employee_csv", *table); - assert_eq!("id", &name); - Ok(()) + Err(DataFusionError::SchemaError(err, _)) => { + if let SchemaError::AmbiguousReference { field } = *err { + let Column { + relation, + name, + spans: _, + } = *field; + let Some(TableReference::Bare { table }) = relation else { + return plan_err!( + "wrong relation: {relation:?}, expected table name" + ); + }; + assert_eq!(*"employee_csv", *table); + assert_eq!("id", &name); + Ok(()) + } else { + plan_err!("Plan should have returned an DataFusionError::SchemaError") + } } _ => plan_err!("Plan should have returned an DataFusionError::SchemaError"), } @@ -2534,13 +2510,11 @@ mod tests { let plan2 = table_scan(TableReference::none(), &employee_schema(), Some(vec![3, 4]))?; - let expected = "Error during planning: INTERSECT/EXCEPT query must have the same number of columns. \ - Left is 1 and right is 2."; let err_msg1 = LogicalPlanBuilder::intersect(plan1.build()?, plan2.build()?, true) .unwrap_err(); - assert_eq!(err_msg1.strip_backtrace(), expected); + assert_snapshot!(err_msg1.strip_backtrace(), @"Error during planning: INTERSECT/EXCEPT query must have the same number of columns. Left is 1 and right is 2."); Ok(()) } @@ -2551,19 +2525,29 @@ mod tests { let err = nested_table_scan("test_table")? .unnest_column("scalar") .unwrap_err(); - assert!(err - .to_string() - .starts_with("Internal error: trying to unnest on invalid data type UInt32")); + + let DataFusionError::Internal(desc) = err else { + return plan_err!("Plan should have returned an DataFusionError::Internal"); + }; + + let desc = (*desc + .split(DataFusionError::BACK_TRACE_SEP) + .collect::>() + .first() + .unwrap_or(&"")) + .to_string(); + + assert_snapshot!(desc, @"trying to unnest on invalid data type UInt32"); // Unnesting the strings list. let plan = nested_table_scan("test_table")? .unnest_column("strings")? .build()?; - let expected = "\ - Unnest: lists[test_table.strings|depth=1] structs[]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[test_table.strings|depth=1] structs[] + TableScan: test_table + "); // Check unnested field is a scalar let field = plan.schema().field_with_name(None, "strings").unwrap(); @@ -2574,16 +2558,16 @@ mod tests { .unnest_column("struct_singular")? .build()?; - let expected = "\ - Unnest: lists[] structs[test_table.struct_singular]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[] structs[test_table.struct_singular] + TableScan: test_table + "); for field_name in &["a", "b"] { // Check unnested struct field is a scalar let field = plan .schema() - .field_with_name(None, &format!("struct_singular.{}", field_name)) + .field_with_name(None, &format!("struct_singular.{field_name}")) .unwrap(); assert_eq!(&DataType::UInt32, field.data_type()); } @@ -2595,12 +2579,12 @@ mod tests { .unnest_column("struct_singular")? .build()?; - let expected = "\ - Unnest: lists[] structs[test_table.struct_singular]\ - \n Unnest: lists[test_table.structs|depth=1] structs[]\ - \n Unnest: lists[test_table.strings|depth=1] structs[]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[] structs[test_table.struct_singular] + Unnest: lists[test_table.structs|depth=1] structs[] + Unnest: lists[test_table.strings|depth=1] structs[] + TableScan: test_table + "); // Check unnested struct list field should be a struct. let field = plan.schema().field_with_name(None, "structs").unwrap(); @@ -2616,10 +2600,10 @@ mod tests { .unnest_columns_with_options(cols, UnnestOptions::default())? .build()?; - let expected = "\ - Unnest: lists[test_table.strings|depth=1, test_table.structs|depth=1] structs[test_table.struct_singular]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[test_table.strings|depth=1, test_table.structs|depth=1] structs[test_table.struct_singular] + TableScan: test_table + "); // Unnesting missing column should fail. let plan = nested_table_scan("test_table")?.unnest_column("missing"); @@ -2643,10 +2627,10 @@ mod tests { )? .build()?; - let expected = "\ - Unnest: lists[test_table.stringss|depth=1, test_table.stringss|depth=2] structs[test_table.struct_singular]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[test_table.stringss|depth=1, test_table.stringss|depth=2] structs[test_table.struct_singular] + TableScan: test_table + "); // Check output columns has correct type let field = plan @@ -2666,7 +2650,7 @@ mod tests { for field_name in &["a", "b"] { let field = plan .schema() - .field_with_name(None, &format!("struct_singular.{}", field_name)) + .field_with_name(None, &format!("struct_singular.{field_name}")) .unwrap(); assert_eq!(&DataType::UInt32, field.data_type()); } @@ -2718,34 +2702,24 @@ mod tests { let join = LogicalPlanBuilder::from(left).cross_join(right)?.build()?; - let _ = LogicalPlanBuilder::from(join.clone()) + let plan = LogicalPlanBuilder::from(join.clone()) .union(join)? .build()?; - Ok(()) - } - - #[test] - fn test_change_redundant_column() -> Result<()> { - let t1_field_1 = Field::new("a", DataType::Int32, false); - let t2_field_1 = Field::new("a", DataType::Int32, false); - let t2_field_3 = Field::new("a", DataType::Int32, false); - let t1_field_2 = Field::new("b", DataType::Int32, false); - let t2_field_2 = Field::new("b", DataType::Int32, false); - - let field_vec = vec![t1_field_1, t2_field_1, t1_field_2, t2_field_2, t2_field_3]; - let remove_redundant = change_redundant_column(&Fields::from(field_vec)); + assert_snapshot!(plan, @r" + Union + Cross Join: + SubqueryAlias: left + Values: (Int32(1)) + SubqueryAlias: right + Values: (Int32(1)) + Cross Join: + SubqueryAlias: left + Values: (Int32(1)) + SubqueryAlias: right + Values: (Int32(1)) + "); - assert_eq!( - remove_redundant, - vec![ - Field::new("a", DataType::Int32, false), - Field::new("a:1", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("b:1", DataType::Int32, false), - Field::new("a:2", DataType::Int32, false), - ] - ); Ok(()) } @@ -2777,10 +2751,10 @@ mod tests { .aggregate(vec![col("id")], vec![sum(col("salary"))])? .build()?; - let expected = - "Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]]\ - \n TableScan: employee_csv projection=[id, state, salary]"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]] + TableScan: employee_csv projection=[id, state, salary] + "); Ok(()) } @@ -2799,11 +2773,102 @@ mod tests { .aggregate(vec![col("id")], vec![sum(col("salary"))])? .build()?; - let expected = - "Aggregate: groupBy=[[employee_csv.id, employee_csv.state, employee_csv.salary]], aggr=[[sum(employee_csv.salary)]]\ - \n TableScan: employee_csv projection=[id, state, salary]"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Aggregate: groupBy=[[employee_csv.id, employee_csv.state, employee_csv.salary]], aggr=[[sum(employee_csv.salary)]] + TableScan: employee_csv projection=[id, state, salary] + "); + + Ok(()) + } + + #[test] + fn test_join_metadata() -> Result<()> { + let left_schema = DFSchema::new_with_metadata( + vec![(None, Arc::new(Field::new("a", DataType::Int32, false)))], + HashMap::from([("key".to_string(), "left".to_string())]), + )?; + let right_schema = DFSchema::new_with_metadata( + vec![(None, Arc::new(Field::new("b", DataType::Int32, false)))], + HashMap::from([("key".to_string(), "right".to_string())]), + )?; + + let join_schema = + build_join_schema(&left_schema, &right_schema, &JoinType::Left)?; + assert_eq!( + join_schema.metadata(), + &HashMap::from([("key".to_string(), "left".to_string())]) + ); + let join_schema = + build_join_schema(&left_schema, &right_schema, &JoinType::Right)?; + assert_eq!( + join_schema.metadata(), + &HashMap::from([("key".to_string(), "right".to_string())]) + ); Ok(()) } + + #[test] + fn test_values_metadata() -> Result<()> { + let metadata: HashMap = + [("ARROW:extension:metadata".to_string(), "test".to_string())] + .into_iter() + .collect(); + let metadata = FieldMetadata::from(metadata); + let values = LogicalPlanBuilder::values(vec![ + vec![lit_with_metadata(1, Some(metadata.clone()))], + vec![lit_with_metadata(2, Some(metadata.clone()))], + ])? + .build()?; + assert_eq!(*values.schema().field(0).metadata(), metadata.to_hashmap()); + + // Do not allow VALUES with different metadata mixed together + let metadata2: HashMap = + [("ARROW:extension:metadata".to_string(), "test2".to_string())] + .into_iter() + .collect(); + let metadata2 = FieldMetadata::from(metadata2); + assert!(LogicalPlanBuilder::values(vec![ + vec![lit_with_metadata(1, Some(metadata.clone()))], + vec![lit_with_metadata(2, Some(metadata2.clone()))], + ]) + .is_err()); + + Ok(()) + } + + #[test] + fn test_unique_field_aliases() { + let t1_field_1 = Field::new("a", DataType::Int32, false); + let t2_field_1 = Field::new("a", DataType::Int32, false); + let t2_field_3 = Field::new("a", DataType::Int32, false); + let t2_field_4 = Field::new("a:1", DataType::Int32, false); + let t1_field_2 = Field::new("b", DataType::Int32, false); + let t2_field_2 = Field::new("b", DataType::Int32, false); + + let fields = vec![ + t1_field_1, t2_field_1, t1_field_2, t2_field_2, t2_field_3, t2_field_4, + ]; + let fields = Fields::from(fields); + + let remove_redundant = unique_field_aliases(&fields); + + // Input [a, a, b, b, a, a:1] becomes [None, a:1, None, b:1, a:2, a:1:1] + // First occurrence of each field name keeps original name (None), duplicates get + // incremental suffixes (:1, :2, etc.). + // Crucially in this case the 2nd occurrence of `a` gets rewritten to `a:1` which later + // conflicts with the last column which is _actually_ called `a:1` so we need to rename it + // as well to `a:1:1`. + assert_eq!( + remove_redundant, + vec![ + None, + Some("a:1".to_string()), + None, + Some("b:1".to_string()), + Some("a:2".to_string()), + Some("a:1:1".to_string()), + ] + ); + } } diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 827e2812ecae1..74fe7a2d009d0 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -24,12 +24,15 @@ use std::{ hash::{Hash, Hasher}, }; +#[cfg(not(feature = "sql"))] +use crate::expr::Ident; use crate::expr::Sort; use arrow::datatypes::DataType; use datafusion_common::tree_node::{Transformed, TreeNodeContainer, TreeNodeRecursion}; use datafusion_common::{ Constraints, DFSchemaRef, Result, SchemaReference, TableReference, }; +#[cfg(feature = "sql")] use sqlparser::ast::Ident; /// Various types of DDL (CREATE / DROP) catalog manipulation @@ -213,6 +216,8 @@ pub struct CreateExternalTable { pub table_partition_cols: Vec, /// Option to not error if table already exists pub if_not_exists: bool, + /// Option to replace table content if table already exists + pub or_replace: bool, /// Whether the table is a temporary table pub temporary: bool, /// SQL used to create the table, if available @@ -292,7 +297,10 @@ impl PartialOrd for CreateExternalTable { unbounded: &other.unbounded, constraints: &other.constraints, }; - comparable_self.partial_cmp(&comparable_other) + comparable_self + .partial_cmp(&comparable_other) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -348,6 +356,8 @@ impl PartialOrd for CreateCatalog { Some(Ordering::Equal) => self.if_not_exists.partial_cmp(&other.if_not_exists), cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -369,6 +379,8 @@ impl PartialOrd for CreateCatalogSchema { Some(Ordering::Equal) => self.if_not_exists.partial_cmp(&other.if_not_exists), cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -390,6 +402,8 @@ impl PartialOrd for DropTable { Some(Ordering::Equal) => self.if_exists.partial_cmp(&other.if_exists), cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -411,6 +425,8 @@ impl PartialOrd for DropView { Some(Ordering::Equal) => self.if_exists.partial_cmp(&other.if_exists), cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -437,17 +453,25 @@ impl PartialOrd for DropCatalogSchema { }, cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } -/// Arguments passed to `CREATE FUNCTION` +/// Arguments passed to the `CREATE FUNCTION` statement +/// +/// These statements are turned into executable functions using [`FunctionFactory`] +/// +/// # Notes /// -/// Note this meant to be the same as from sqlparser's [`sqlparser::ast::Statement::CreateFunction`] +/// This structure purposely mirrors the structure in sqlparser's +/// [`sqlparser::ast::Statement::CreateFunction`], but does not use it directly +/// to avoid a dependency on sqlparser in the core crate. +/// +/// +/// [`FunctionFactory`]: https://docs.rs/datafusion/latest/datafusion/execution/context/trait.FunctionFactory.html #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct CreateFunction { - // TODO: There is open question should we expose sqlparser types or redefine them here? - // At the moment it make more sense to expose sqlparser types and leave - // user to convert them as needed pub or_replace: bool, pub temporary: bool, pub name: String, @@ -486,10 +510,16 @@ impl PartialOrd for CreateFunction { return_type: &other.return_type, params: &other.params, }; - comparable_self.partial_cmp(&comparable_other) + comparable_self + .partial_cmp(&comparable_other) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } +/// Part of the `CREATE FUNCTION` statement +/// +/// See [`CreateFunction`] for details #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct OperateFunctionArg { // TODO: figure out how to support mode @@ -520,6 +550,9 @@ impl<'a> TreeNodeContainer<'a, Expr> for OperateFunctionArg { } } +/// Part of the `CREATE FUNCTION` statement +/// +/// See [`CreateFunction`] for details #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct CreateFunctionBody { /// LANGUAGE lang_name @@ -566,6 +599,8 @@ impl PartialOrd for DropFunction { Some(Ordering::Equal) => self.if_exists.partial_cmp(&other.if_exists), cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -608,7 +643,10 @@ impl PartialOrd for CreateIndex { unique: &other.unique, if_not_exists: &other.if_not_exists, }; - comparable_self.partial_cmp(&comparable_other) + comparable_self + .partial_cmp(&comparable_other) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 14758b61e859d..ea08c223e8f4d 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -31,7 +31,7 @@ use crate::dml::CopyTo; use arrow::datatypes::Schema; use datafusion_common::display::GraphvizBuilder; use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor}; -use datafusion_common::{Column, DataFusionError}; +use datafusion_common::{internal_datafusion_err, Column, DataFusionError}; use serde_json::json; /// Formats plans with a single line per node. For example: @@ -72,11 +72,7 @@ impl<'n> TreeNodeVisitor<'n> for IndentVisitor<'_, '_> { write!(self.f, "{:indent$}", "", indent = self.indent * 2)?; write!(self.f, "{}", plan.display())?; if self.with_schema { - write!( - self.f, - " {}", - display_schema(&plan.schema().as_ref().to_owned().into()) - )?; + write!(self.f, " {}", display_schema(plan.schema().as_arrow()))?; } self.indent += 1; @@ -196,7 +192,7 @@ impl<'n> TreeNodeVisitor<'n> for GraphvizVisitor<'_, '_> { format!( r"{}\nSchema: {}", plan.display(), - display_schema(&plan.schema().as_ref().to_owned().into()) + display_schema(plan.schema().as_arrow()) ) } else { format!("{}", plan.display()) @@ -204,14 +200,14 @@ impl<'n> TreeNodeVisitor<'n> for GraphvizVisitor<'_, '_> { self.graphviz_builder .add_node(self.f, id, &label, None) - .map_err(|_e| DataFusionError::Internal("Fail to format".to_string()))?; + .map_err(|_e| internal_datafusion_err!("Fail to format"))?; // Create an edge to our parent node, if any // parent_id -> id if let Some(parent_id) = self.parent_ids.last() { self.graphviz_builder .add_edge(self.f, *parent_id, id) - .map_err(|_e| DataFusionError::Internal("Fail to format".to_string()))?; + .map_err(|_e| internal_datafusion_err!("Fail to format"))?; } self.parent_ids.push(id); @@ -225,7 +221,7 @@ impl<'n> TreeNodeVisitor<'n> for GraphvizVisitor<'_, '_> { // always be non-empty as pre_visit always pushes // So it should always be Ok(true) let res = self.parent_ids.pop(); - res.ok_or(DataFusionError::Internal("Fail to format".to_string())) + res.ok_or(internal_datafusion_err!("Fail to format")) .map(|_| TreeNodeRecursion::Continue) } } @@ -341,7 +337,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { let eclipse = if values.len() > 5 { "..." } else { "" }; - let values_str = format!("{}{}", str_values, eclipse); + let values_str = format!("{str_values}{eclipse}"); json!({ "Node Type": "Values", "Values": values_str @@ -426,10 +422,11 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { file_type, partition_by: _, options, + output_schema: _, }) => { let op_str = options .iter() - .map(|(k, v)| format!("{}={}", k, v)) + .map(|(k, v)| format!("{k}={v}")) .collect::>() .join(", "); json!({ @@ -689,9 +686,10 @@ impl<'n> TreeNodeVisitor<'n> for PgJsonVisitor<'_, '_> { ) -> datafusion_common::Result { let id = self.parent_ids.pop().unwrap(); - let current_node = self.objects.remove(&id).ok_or_else(|| { - DataFusionError::Internal("Missing current node!".to_string()) - })?; + let current_node = self + .objects + .remove(&id) + .ok_or_else(|| internal_datafusion_err!("Missing current node!"))?; if let Some(parent_id) = self.parent_ids.last() { let parent_node = self @@ -722,13 +720,14 @@ impl<'n> TreeNodeVisitor<'n> for PgJsonVisitor<'_, '_> { #[cfg(test)] mod tests { use arrow::datatypes::{DataType, Field}; + use insta::assert_snapshot; use super::*; #[test] fn test_display_empty_schema() { let schema = Schema::empty(); - assert_eq!("[]", format!("{}", display_schema(&schema))); + assert_snapshot!(display_schema(&schema), @"[]"); } #[test] @@ -738,9 +737,6 @@ mod tests { Field::new("first_name", DataType::Utf8, true), ]); - assert_eq!( - "[id:Int32, first_name:Utf8;N]", - format!("{}", display_schema(&schema)) - ); + assert_snapshot!(display_schema(&schema), @"[id:Int32, first_name:Utf8;N]"); } } diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index d4d50ac4eae4e..b8448a5da6c42 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -40,6 +40,8 @@ pub struct CopyTo { pub file_type: Arc, /// SQL Options that can affect the formats pub options: HashMap, + /// The schema of the output (a single column "count") + pub output_schema: DFSchemaRef, } impl Debug for CopyTo { @@ -50,6 +52,7 @@ impl Debug for CopyTo { .field("partition_by", &self.partition_by) .field("file_type", &"...") .field("options", &self.options) + .field("output_schema", &self.output_schema) .finish_non_exhaustive() } } @@ -78,6 +81,8 @@ impl PartialOrd for CopyTo { }, cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -89,8 +94,48 @@ impl Hash for CopyTo { } } -/// The operator that modifies the content of a database (adapted from -/// substrait WriteRel) +impl CopyTo { + pub fn new( + input: Arc, + output_url: String, + partition_by: Vec, + file_type: Arc, + options: HashMap, + ) -> Self { + Self { + input, + output_url, + partition_by, + file_type, + options, + // The output schema is always a single column "count" with the number of rows copied + output_schema: make_count_schema(), + } + } +} + +/// Modifies the content of a database +/// +/// This operator is used to perform DML operations such as INSERT, DELETE, +/// UPDATE, and CTAS (CREATE TABLE AS SELECT). +/// +/// * `INSERT` - Appends new rows to the existing table. Calls +/// [`TableProvider::insert_into`] +/// +/// * `DELETE` - Removes rows from the table. Currently NOT supported by the +/// [`TableProvider`] trait or builtin sources. +/// +/// * `UPDATE` - Modifies existing rows in the table. Currently NOT supported by +/// the [`TableProvider`] trait or builtin sources. +/// +/// * `CREATE TABLE AS SELECT` - Creates a new table and populates it with data +/// from a query. This is similar to the `INSERT` operation, but it creates a new +/// table instead of modifying an existing one. +/// +/// Note that the structure is adapted from substrait WriteRel) +/// +/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html +/// [`TableProvider::insert_into`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html#method.insert_into #[derive(Clone)] pub struct DmlStatement { /// The table name @@ -174,14 +219,23 @@ impl PartialOrd for DmlStatement { }, cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } +/// The type of DML operation to perform. +/// +/// See [`DmlStatement`] for more details. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum WriteOp { + /// `INSERT INTO` operation Insert(InsertOp), + /// `DELETE` operation Delete, + /// `UPDATE` operation Update, + /// `CREATE TABLE AS SELECT` operation Ctas, } diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index 5bf64a36a6540..a8ee7885644a7 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -57,7 +57,7 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { fn schema(&self) -> &DFSchemaRef; /// Perform check of invariants for the extension node. - fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()>; + fn check_invariants(&self, check: InvariantLevel) -> Result<()>; /// Returns all expressions in the current logical plan node. This should /// not include expressions of any inputs (aka non-recursively). @@ -150,7 +150,7 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// directly because it must remain object safe. fn dyn_hash(&self, state: &mut dyn Hasher); - /// Compare `other`, respecting requirements from [std::cmp::Eq]. + /// Compare `other`, respecting requirements from [Eq]. /// /// Note: consider using [`UserDefinedLogicalNodeCore`] instead of /// [`UserDefinedLogicalNode`] directly. @@ -188,6 +188,9 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// Note: [`UserDefinedLogicalNode`] is not constrained by [`Eq`] /// directly because it must remain object safe. fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool; + + /// Compare `other`, respecting requirements from [PartialOrd]. + /// Must return `Some(Equal)` if and only if `self.dyn_eq(other)`. fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option; /// Returns `true` if a limit can be safely pushed down through this @@ -241,11 +244,7 @@ pub trait UserDefinedLogicalNodeCore: /// Perform check of invariants for the extension node. /// /// This is the default implementation for extension nodes. - fn check_invariants( - &self, - _check: InvariantLevel, - _plan: &LogicalPlan, - ) -> Result<()> { + fn check_invariants(&self, _check: InvariantLevel) -> Result<()> { Ok(()) } @@ -316,7 +315,7 @@ pub trait UserDefinedLogicalNodeCore: } /// Automatically derive UserDefinedLogicalNode to `UserDefinedLogicalNode` -/// to avoid boiler plate for implementing `as_any`, `Hash` and `PartialEq` +/// to avoid boiler plate for implementing `as_any`, `Hash`, `PartialEq` and `PartialOrd`. impl UserDefinedLogicalNode for T { fn as_any(&self) -> &dyn Any { self @@ -334,8 +333,8 @@ impl UserDefinedLogicalNode for T { self.schema() } - fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()> { - self.check_invariants(check, plan) + fn check_invariants(&self, check: InvariantLevel) -> Result<()> { + self.check_invariants(check) } fn expressions(&self) -> Vec { diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index d83410bf99c98..ccdf9e444b8fd 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -74,7 +74,7 @@ pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> { fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> { plan.apply_with_subqueries(|plan: &LogicalPlan| { if let LogicalPlan::Extension(Extension { node }) = plan { - node.check_invariants(check, plan)?; + node.check_invariants(check)?; } plan.apply_expressions(|expr| { // recursively look for subqueries @@ -102,7 +102,7 @@ fn assert_unique_field_names(plan: &LogicalPlan) -> Result<()> { plan.schema().check_names() } -/// Returns an error if the plan is not sematically valid. +/// Returns an error if the plan is not semantically valid. fn assert_valid_semantic_plan(plan: &LogicalPlan) -> Result<()> { assert_subqueries_are_valid(plan)?; @@ -112,11 +112,11 @@ fn assert_valid_semantic_plan(plan: &LogicalPlan) -> Result<()> { /// Returns an error if the plan does not have the expected schema. /// Ignores metadata and nullability. pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Result<()> { - let compatible = plan.schema().has_equivalent_names_and_types(schema); + let compatible = plan.schema().logically_equivalent_names_and_types(schema); - if let Err(e) = compatible { + if !compatible { internal_err!( - "Failed due to a difference in schemas: {e}, original schema: {:?}, new schema: {:?}", + "Failed due to a difference in schemas: original schema: {:?}, new schema: {:?}", schema, plan.schema() ) @@ -310,7 +310,10 @@ fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> { check_inner_plan(left)?; check_no_outer_references(right) } - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => { check_no_outer_references(left)?; check_inner_plan(right) } diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index a55f4d97b2126..7de2fd117487a 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -27,8 +27,9 @@ mod statement; pub mod tree_node; pub use builder::{ - build_join_schema, table_scan, union, wrap_projection_for_join_if_necessary, - LogicalPlanBuilder, LogicalPlanBuilderOptions, LogicalTableSource, UNNAMED_TABLE, + build_join_schema, requalify_sides_if_needed, table_scan, union, + wrap_projection_for_join_if_necessary, LogicalPlanBuilder, LogicalPlanBuilderOptions, + LogicalTableSource, UNNAMED_TABLE, }; pub use ddl::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, @@ -38,7 +39,7 @@ pub use ddl::{ pub use dml::{DmlStatement, WriteOp}; pub use plan::{ projection_schema, Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, - DistinctOn, EmptyRelation, Explain, ExplainFormat, Extension, FetchType, Filter, + DistinctOn, EmptyRelation, Explain, ExplainOption, Extension, FetchType, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, @@ -48,6 +49,8 @@ pub use statement::{ TransactionConclusion, TransactionEnd, TransactionIsolationLevel, TransactionStart, }; +pub use datafusion_common::format::ExplainFormat; + pub use display::display_schema; pub use extension::{UserDefinedLogicalNode, UserDefinedLogicalNodeCore}; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 76b45d5d723ae..b8200ab8a48c3 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -21,7 +21,6 @@ use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; -use std::str::FromStr; use std::sync::{Arc, LazyLock}; use super::dml::CopyTo; @@ -30,8 +29,11 @@ use super::invariants::{ InvariantLevel, }; use super::DdlStatement; -use crate::builder::{change_redundant_column, unnest_with_options}; -use crate::expr::{Placeholder, Sort as SortExpr, WindowFunction, WindowFunctionParams}; +use crate::builder::{unique_field_aliases, unnest_with_options}; +use crate::expr::{ + intersect_metadata_for_union, Alias, Placeholder, Sort as SortExpr, WindowFunction, + WindowFunctionParams, +}; use crate::expr_rewriter::{ create_col_from_scalar_expr, normalize_cols, normalize_sorts, NamePreserver, }; @@ -43,21 +45,23 @@ use crate::utils::{ grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction, }; use crate::{ - build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Execute, - Expr, ExprSchemable, LogicalPlanBuilder, Operator, Prepare, - TableProviderFilterPushDown, TableSource, WindowFunctionDefinition, + build_join_schema, expr_vec_fmt, requalify_sides_if_needed, BinaryExpr, + CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable, LogicalPlanBuilder, + Operator, Prepare, TableProviderFilterPushDown, TableSource, + WindowFunctionDefinition, }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::cse::{NormalizeEq, Normalizeable}; +use datafusion_common::format::ExplainFormat; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, - FunctionalDependencies, ParamValues, Result, ScalarValue, Spans, TableReference, - UnnestOptions, + FunctionalDependencies, NullEquality, ParamValues, Result, ScalarValue, Spans, + TableReference, UnnestOptions, }; use indexmap::IndexSet; @@ -344,7 +348,7 @@ impl LogicalPlan { output_schema } LogicalPlan::Dml(DmlStatement { output_schema, .. }) => output_schema, - LogicalPlan::Copy(CopyTo { input, .. }) => input.schema(), + LogicalPlan::Copy(CopyTo { output_schema, .. }) => output_schema, LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { @@ -556,7 +560,9 @@ impl LogicalPlan { JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { left.head_output_expr() } - JoinType::RightSemi | JoinType::RightAnti => right.head_output_expr(), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right.head_output_expr() + } }, LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { static_term.head_output_expr() @@ -630,12 +636,9 @@ impl LogicalPlan { // todo it isn't clear why the schema is not recomputed here Ok(LogicalPlan::Values(Values { schema, values })) } - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => Filter::try_new_internal(predicate, input, having) - .map(LogicalPlan::Filter), + LogicalPlan::Filter(Filter { predicate, input }) => { + Filter::try_new(predicate, input).map(LogicalPlan::Filter) + } LogicalPlan::Repartition(_) => Ok(self), LogicalPlan::Window(Window { input, @@ -658,7 +661,7 @@ impl LogicalPlan { join_constraint, on, schema: _, - null_equals_null, + null_equality, }) => { let schema = build_join_schema(left.schema(), right.schema(), &join_type)?; @@ -679,7 +682,7 @@ impl LogicalPlan { on: new_on, filter, schema: DFSchemaRef::new(schema), - null_equals_null, + null_equality, })) } LogicalPlan::Subquery(_) => Ok(self), @@ -810,16 +813,17 @@ impl LogicalPlan { file_type, options, partition_by, + output_schema: _, }) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; - Ok(LogicalPlan::Copy(CopyTo { - input: Arc::new(input), - output_url: output_url.clone(), - file_type: Arc::clone(file_type), - options: options.clone(), - partition_by: partition_by.clone(), - })) + Ok(LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + output_url.clone(), + partition_by.clone(), + Arc::clone(file_type), + options.clone(), + ))) } LogicalPlan::Values(Values { schema, .. }) => { self.assert_no_inputs(inputs)?; @@ -897,7 +901,7 @@ impl LogicalPlan { join_type, join_constraint, on, - null_equals_null, + null_equality, .. }) => { let (left, right) = self.only_two_inputs(inputs)?; @@ -936,7 +940,7 @@ impl LogicalPlan { on: new_on, filter: filter_expr, schema: DFSchemaRef::new(schema), - null_equals_null: *null_equals_null, + null_equality: *null_equality, })) } LogicalPlan::Subquery(Subquery { @@ -991,7 +995,7 @@ impl LogicalPlan { Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { input: Arc::new(input), - constraints: Constraints::empty(), + constraints: Constraints::default(), name: name.clone(), if_not_exists: *if_not_exists, or_replace: *or_replace, @@ -1308,7 +1312,7 @@ impl LogicalPlan { // Empty group_expr will return Some(1) if group_expr .iter() - .all(|expr| matches!(expr, Expr::Literal(_))) + .all(|expr| matches!(expr, Expr::Literal(_, _))) { Some(1) } else { @@ -1343,7 +1347,9 @@ impl LogicalPlan { JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { left.max_rows() } - JoinType::RightSemi | JoinType::RightAnti => right.max_rows(), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right.max_rows() + } }, LogicalPlan::Repartition(Repartition { input, .. }) => input.max_rows(), LogicalPlan::Union(Union { inputs, .. }) => { @@ -1458,7 +1464,7 @@ impl LogicalPlan { let transformed_expr = e.transform_up(|e| { if let Expr::Placeholder(Placeholder { id, .. }) = e { let value = param_values.get_placeholders_with_values(&id)?; - Ok(Transformed::yes(Expr::Literal(value))) + Ok(Transformed::yes(Expr::Literal(value, None))) } else { Ok(Transformed::no(e)) } @@ -1717,11 +1723,14 @@ impl LogicalPlan { impl Display for Wrapper<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self.0 { - LogicalPlan::EmptyRelation(_) => write!(f, "EmptyRelation"), + LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row, schema: _ }) => { + let rows = if *produce_one_row { 1 } else { 0 }; + write!(f, "EmptyRelation: rows={rows}") + }, LogicalPlan::RecursiveQuery(RecursiveQuery { is_distinct, .. }) => { - write!(f, "RecursiveQuery: is_distinct={}", is_distinct) + write!(f, "RecursiveQuery: is_distinct={is_distinct}") } LogicalPlan::Values(Values { ref values, .. }) => { let str_values: Vec<_> = values @@ -1818,12 +1827,12 @@ impl LogicalPlan { Ok(()) } LogicalPlan::Projection(Projection { ref expr, .. }) => { - write!(f, "Projection: ")?; + write!(f, "Projection:")?; for (i, expr_item) in expr.iter().enumerate() { if i > 0 { - write!(f, ", ")?; + write!(f, ",")?; } - write!(f, "{expr_item}")?; + write!(f, " {expr_item}")?; } Ok(()) } @@ -1964,7 +1973,7 @@ impl LogicalPlan { }; write!( f, - "Limit: skip={}, fetch={}", skip_str,fetch_str, + "Limit: skip={skip_str}, fetch={fetch_str}", ) } LogicalPlan::Subquery(Subquery { .. }) => { @@ -2037,7 +2046,9 @@ impl ToStringifiedPlan for LogicalPlan { } } -/// Produces no rows: An empty relation with an empty schema +/// Relationship produces 0 or 1 placeholder rows with specified output schema +/// In most cases the output schema for `EmptyRelation` would be empty, +/// however, it can be non-empty typically for optimizer rules #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct EmptyRelation { /// Whether to produce a placeholder row @@ -2049,7 +2060,10 @@ pub struct EmptyRelation { // Manual implementation needed because of `schema` field. Comparison excludes this field. impl PartialOrd for EmptyRelation { fn partial_cmp(&self, other: &Self) -> Option { - self.produce_one_row.partial_cmp(&other.produce_one_row) + self.produce_one_row + .partial_cmp(&other.produce_one_row) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -2103,7 +2117,10 @@ pub struct Values { // Manual implementation needed because of `schema` field. Comparison excludes this field. impl PartialOrd for Values { fn partial_cmp(&self, other: &Self) -> Option { - self.values.partial_cmp(&other.values) + self.values + .partial_cmp(&other.values) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -2128,6 +2145,8 @@ impl PartialOrd for Projection { Some(Ordering::Equal) => self.input.partial_cmp(&other.input), cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -2176,14 +2195,22 @@ impl Projection { /// will be computed. /// * `exprs`: A slice of `Expr` expressions representing the projection operation to apply. /// +/// # Metadata Handling +/// +/// - **Schema-level metadata**: Passed through unchanged from the input schema +/// - **Field-level metadata**: Determined by each expression via [`exprlist_to_fields`], which +/// calls [`Expr::to_field`] to handle expression-specific metadata (literals, aliases, etc.) +/// /// # Returns /// /// A `Result` containing an `Arc` representing the schema of the result /// produced by the projection operation. If the schema computation is successful, /// the `Result` will contain the schema; otherwise, it will contain an error. pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result> { + // Preserve input schema metadata at the schema level let metadata = input.schema().metadata().clone(); + // Convert expressions to fields with Field properties determined by `Expr::to_field` let schema = DFSchema::new_with_metadata(exprlist_to_fields(exprs, input)?, metadata)? .with_functional_dependencies(calc_func_dependencies_for_project( @@ -2212,15 +2239,47 @@ impl SubqueryAlias { alias: impl Into, ) -> Result { let alias = alias.into(); - let fields = change_redundant_column(plan.schema().fields()); - let meta_data = plan.schema().as_ref().metadata().clone(); - let schema: Schema = - DFSchema::from_unqualified_fields(fields.into(), meta_data)?.into(); - // Since schema is the same, other than qualifier, we can use existing - // functional dependencies: + + // Since SubqueryAlias will replace all field qualification for the output schema of `plan`, + // no field must share the same column name as this would lead to ambiguity when referencing + // columns in parent logical nodes. + + // Compute unique aliases, if any, for each column of the input's schema. + let aliases = unique_field_aliases(plan.schema().fields()); + let is_projection_needed = aliases.iter().any(Option::is_some); + + // Insert a projection node, if needed, to make sure aliases are applied. + let plan = if is_projection_needed { + let projection_expressions = aliases + .iter() + .zip(plan.schema().iter()) + .map(|(alias, (qualifier, field))| { + let column = + Expr::Column(Column::new(qualifier.cloned(), field.name())); + match alias { + None => column, + Some(alias) => { + Expr::Alias(Alias::new(column, qualifier.cloned(), alias)) + } + } + }) + .collect(); + let projection = Projection::try_new(projection_expressions, plan)?; + Arc::new(LogicalPlan::Projection(projection)) + } else { + plan + }; + + // Requalify fields with the new `alias`. + let fields = plan.schema().fields().clone(); + let meta_data = plan.schema().metadata().clone(); let func_dependencies = plan.schema().functional_dependencies().clone(); + + let schema = DFSchema::from_unqualified_fields(fields, meta_data)?; + let schema = schema.as_arrow(); + let schema = DFSchemaRef::new( - DFSchema::try_from_qualified_schema(alias.clone(), &schema)? + DFSchema::try_from_qualified_schema(alias.clone(), schema)? .with_functional_dependencies(func_dependencies)?, ); Ok(SubqueryAlias { @@ -2238,6 +2297,8 @@ impl PartialOrd for SubqueryAlias { Some(Ordering::Equal) => self.alias.partial_cmp(&other.alias), cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -2259,8 +2320,6 @@ pub struct Filter { pub predicate: Expr, /// The incoming logical plan pub input: Arc, - /// The flag to indicate if the filter is a having clause - pub having: bool, } impl Filter { @@ -2269,13 +2328,14 @@ impl Filter { /// Notes: as Aliases have no effect on the output of a filter operator, /// they are removed from the predicate expression. pub fn try_new(predicate: Expr, input: Arc) -> Result { - Self::try_new_internal(predicate, input, false) + Self::try_new_internal(predicate, input) } /// Create a new filter operator for a having clause. /// This is similar to a filter, but its having flag is set to true. + #[deprecated(since = "48.0.0", note = "Use `try_new` instead")] pub fn try_new_with_having(predicate: Expr, input: Arc) -> Result { - Self::try_new_internal(predicate, input, true) + Self::try_new_internal(predicate, input) } fn is_allowed_filter_type(data_type: &DataType) -> bool { @@ -2289,11 +2349,7 @@ impl Filter { } } - fn try_new_internal( - predicate: Expr, - input: Arc, - having: bool, - ) -> Result { + fn try_new_internal(predicate: Expr, input: Arc) -> Result { // Filter predicates must return a boolean value so we try and validate that here. // Note that it is not always possible to resolve the predicate expression during plan // construction (such as with correlated subqueries) so we make a best effort here and @@ -2309,7 +2365,6 @@ impl Filter { Ok(Self { predicate: predicate.unalias_nested().data, input, - having, }) } @@ -2431,18 +2486,23 @@ impl Window { .iter() .enumerate() .filter_map(|(idx, expr)| { - if let Expr::WindowFunction(WindowFunction { + let Expr::WindowFunction(window_fun) = expr else { + return None; + }; + let WindowFunction { fun: WindowFunctionDefinition::WindowUDF(udwf), params: WindowFunctionParams { partition_by, .. }, - }) = expr - { - // When there is no PARTITION BY, row number will be unique - // across the entire table. - if udwf.name() == "row_number" && partition_by.is_empty() { - return Some(idx + input_len); - } + } = window_fun.as_ref() + else { + return None; + }; + // When there is no PARTITION BY, row number will be unique + // across the entire table. + if udwf.name() == "row_number" && partition_by.is_empty() { + Some(idx + input_len) + } else { + None } - None }) .map(|idx| { FunctionalDependence::new(vec![idx], vec![], false) @@ -2459,6 +2519,20 @@ impl Window { window_func_dependencies.extend(new_deps); } + // Validate that FILTER clauses are only used with aggregate window functions + if let Some(e) = window_expr.iter().find(|e| { + matches!( + e, + Expr::WindowFunction(wf) + if !matches!(wf.fun, WindowFunctionDefinition::AggregateUDF(_)) + && wf.params.filter.is_some() + ) + }) { + return plan_err!( + "FILTER clause can only be used with aggregate window functions. Found in '{e}'" + ); + } + Self::try_new_with_schema( window_expr, input, @@ -2469,16 +2543,22 @@ impl Window { ) } + /// Create a new window function using the provided schema to avoid the overhead of + /// building the schema again when the schema is already known. + /// + /// This method should only be called when you are absolutely sure that the schema being + /// provided is correct for the window function. If in doubt, call [try_new](Self::try_new) instead. pub fn try_new_with_schema( window_expr: Vec, input: Arc, schema: DFSchemaRef, ) -> Result { - if window_expr.len() != schema.fields().len() - input.schema().fields().len() { + let input_fields_count = input.schema().fields().len(); + if schema.fields().len() != input_fields_count + window_expr.len() { return plan_err!( - "Window has mismatch between number of expressions ({}) and number of fields in schema ({})", - window_expr.len(), - schema.fields().len() - input.schema().fields().len() + "Window schema has wrong number of fields. Expected {} got {}", + input_fields_count + window_expr.len(), + schema.fields().len() ); } @@ -2493,9 +2573,22 @@ impl Window { // Manual implementation needed because of `schema` field. Comparison excludes this field. impl PartialOrd for Window { fn partial_cmp(&self, other: &Self) -> Option { - match self.input.partial_cmp(&other.input) { - Some(Ordering::Equal) => self.window_expr.partial_cmp(&other.window_expr), - cmp => cmp, + match self.input.partial_cmp(&other.input)? { + Ordering::Equal => {} // continue + not_equal => return Some(not_equal), + } + + match self.window_expr.partial_cmp(&other.window_expr)? { + Ordering::Equal => {} // continue + not_equal => return Some(not_equal), + } + + // Contract for PartialOrd and PartialEq consistency requires that + // a == b if and only if partial_cmp(a, b) == Some(Equal). + if self == other { + Some(Ordering::Equal) + } else { + None } } } @@ -2569,7 +2662,10 @@ impl PartialOrd for TableScan { filters: &other.filters, fetch: &other.fetch, }; - comparable_self.partial_cmp(&comparable_other) + comparable_self + .partial_cmp(&comparable_other) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -2612,7 +2708,7 @@ impl TableScan { let df_schema = DFSchema::new_with_metadata( p.iter() .map(|i| { - (Some(table_name.clone()), Arc::new(schema.field(*i).clone())) + (Some(table_name.clone()), Arc::clone(&schema.fields()[*i])) }) .collect(), schema.metadata.clone(), @@ -2702,7 +2798,9 @@ impl Union { { expr.push(Expr::Column(column)); } else { - expr.push(Expr::Literal(ScalarValue::Null).alias(column.name())); + expr.push( + Expr::Literal(ScalarValue::Null, None).alias(column.name()), + ); } } wrapped_inputs.push(Arc::new(LogicalPlan::Projection( @@ -2790,15 +2888,16 @@ impl Union { let mut field = Field::new(name, data_type.clone(), final_is_nullable); - field.set_metadata(intersect_maps(unmerged_metadata)); + field.set_metadata(intersect_metadata_for_union(unmerged_metadata)); (None, Arc::new(field)) }, ) .collect::, _)>>(); - let union_schema_metadata = - intersect_maps(inputs.iter().map(|input| input.schema().metadata())); + let union_schema_metadata = intersect_metadata_for_union( + inputs.iter().map(|input| input.schema().metadata()), + ); // Functional Dependencies are not preserved after UNION operation let schema = DFSchema::new_with_metadata(union_fields, union_schema_metadata)?; @@ -2860,21 +2959,23 @@ impl Union { // Generate unique field name let name = if let Some(count) = name_counts.get_mut(&base_name) { *count += 1; - format!("{}_{}", base_name, count) + format!("{base_name}_{count}") } else { name_counts.insert(base_name.clone(), 0); base_name }; let mut field = Field::new(&name, data_type.clone(), nullable); - let field_metadata = - intersect_maps(fields.iter().map(|field| field.metadata())); + let field_metadata = intersect_metadata_for_union( + fields.iter().map(|field| field.metadata()), + ); field.set_metadata(field_metadata); Ok((None, Arc::new(field))) }) .collect::>()?; - let union_schema_metadata = - intersect_maps(inputs.iter().map(|input| input.schema().metadata())); + let union_schema_metadata = intersect_metadata_for_union( + inputs.iter().map(|input| input.schema().metadata()), + ); // Functional Dependencies are not preserved after UNION operation let schema = DFSchema::new_with_metadata(union_fields, union_schema_metadata)?; @@ -2884,25 +2985,13 @@ impl Union { } } -fn intersect_maps<'a>( - inputs: impl IntoIterator>, -) -> HashMap { - let mut inputs = inputs.into_iter(); - let mut merged: HashMap = inputs.next().cloned().unwrap_or_default(); - for input in inputs { - // The extra dereference below (`&*v`) is a workaround for https://github.com/rkyv/rkyv/issues/434. - // When this crate is used in a workspace that enables the `rkyv-64` feature in the `chrono` crate, - // this triggers a Rust compilation error: - // error[E0277]: can't compare `Option<&std::string::String>` with `Option<&mut std::string::String>`. - merged.retain(|k, v| input.get(k) == Some(&*v)); - } - merged -} - // Manual implementation needed because of `schema` field. Comparison excludes this field. impl PartialOrd for Union { fn partial_cmp(&self, other: &Self) -> Option { - self.inputs.partial_cmp(&other.inputs) + self.inputs + .partial_cmp(&other.inputs) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -2945,154 +3034,47 @@ impl PartialOrd for DescribeTable { } } -/// Output formats for controlling for Explain plans +/// Options for EXPLAIN #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum ExplainFormat { - /// Indent mode - /// - /// Example: - /// ```text - /// > explain format indent select x from values (1) t(x); - /// +---------------+-----------------------------------------------------+ - /// | plan_type | plan | - /// +---------------+-----------------------------------------------------+ - /// | logical_plan | SubqueryAlias: t | - /// | | Projection: column1 AS x | - /// | | Values: (Int64(1)) | - /// | physical_plan | ProjectionExec: expr=[column1@0 as x] | - /// | | DataSourceExec: partitions=1, partition_sizes=[1] | - /// | | | - /// +---------------+-----------------------------------------------------+ - /// ``` - Indent, - /// Tree mode - /// - /// Example: - /// ```text - /// > explain format tree select x from values (1) t(x); - /// +---------------+-------------------------------+ - /// | plan_type | plan | - /// +---------------+-------------------------------+ - /// | physical_plan | ┌───────────────────────────┐ | - /// | | │ ProjectionExec │ | - /// | | │ -------------------- │ | - /// | | │ x: column1@0 │ | - /// | | └─────────────┬─────────────┘ | - /// | | ┌─────────────┴─────────────┐ | - /// | | │ DataSourceExec │ | - /// | | │ -------------------- │ | - /// | | │ bytes: 128 │ | - /// | | │ format: memory │ | - /// | | │ rows: 1 │ | - /// | | └───────────────────────────┘ | - /// | | | - /// +---------------+-------------------------------+ - /// ``` - Tree, - /// Postgres Json mode - /// - /// A displayable structure that produces plan in postgresql JSON format. - /// - /// Users can use this format to visualize the plan in existing plan - /// visualization tools, for example [dalibo](https://explain.dalibo.com/) - /// - /// Example: - /// ```text - /// > explain format pgjson select x from values (1) t(x); - /// +--------------+--------------------------------------+ - /// | plan_type | plan | - /// +--------------+--------------------------------------+ - /// | logical_plan | [ | - /// | | { | - /// | | "Plan": { | - /// | | "Alias": "t", | - /// | | "Node Type": "Subquery", | - /// | | "Output": [ | - /// | | "x" | - /// | | ], | - /// | | "Plans": [ | - /// | | { | - /// | | "Expressions": [ | - /// | | "column1 AS x" | - /// | | ], | - /// | | "Node Type": "Projection", | - /// | | "Output": [ | - /// | | "x" | - /// | | ], | - /// | | "Plans": [ | - /// | | { | - /// | | "Node Type": "Values", | - /// | | "Output": [ | - /// | | "column1" | - /// | | ], | - /// | | "Plans": [], | - /// | | "Values": "(Int64(1))" | - /// | | } | - /// | | ] | - /// | | } | - /// | | ] | - /// | | } | - /// | | } | - /// | | ] | - /// +--------------+--------------------------------------+ - /// ``` - PostgresJSON, - /// Graphviz mode - /// - /// Example: - /// ```text - /// > explain format graphviz select x from values (1) t(x); - /// +--------------+------------------------------------------------------------------------+ - /// | plan_type | plan | - /// +--------------+------------------------------------------------------------------------+ - /// | logical_plan | | - /// | | // Begin DataFusion GraphViz Plan, | - /// | | // display it online here: https://dreampuf.github.io/GraphvizOnline | - /// | | | - /// | | digraph { | - /// | | subgraph cluster_1 | - /// | | { | - /// | | graph[label="LogicalPlan"] | - /// | | 2[shape=box label="SubqueryAlias: t"] | - /// | | 3[shape=box label="Projection: column1 AS x"] | - /// | | 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] | - /// | | 4[shape=box label="Values: (Int64(1))"] | - /// | | 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] | - /// | | } | - /// | | subgraph cluster_5 | - /// | | { | - /// | | graph[label="Detailed LogicalPlan"] | - /// | | 6[shape=box label="SubqueryAlias: t\nSchema: [x:Int64;N]"] | - /// | | 7[shape=box label="Projection: column1 AS x\nSchema: [x:Int64;N]"] | - /// | | 6 -> 7 [arrowhead=none, arrowtail=normal, dir=back] | - /// | | 8[shape=box label="Values: (Int64(1))\nSchema: [column1:Int64;N]"] | - /// | | 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] | - /// | | } | - /// | | } | - /// | | // End DataFusion GraphViz Plan | - /// | | | - /// +--------------+------------------------------------------------------------------------+ - /// ``` - Graphviz, +pub struct ExplainOption { + /// Include detailed debug info + pub verbose: bool, + /// Actually execute the plan and report metrics + pub analyze: bool, + /// Output syntax/format + pub format: ExplainFormat, } -/// Implement parsing strings to `ExplainFormat` -impl FromStr for ExplainFormat { - type Err = DataFusionError; - - fn from_str(format: &str) -> std::result::Result { - match format.to_lowercase().as_str() { - "indent" => Ok(ExplainFormat::Indent), - "tree" => Ok(ExplainFormat::Tree), - "pgjson" => Ok(ExplainFormat::PostgresJSON), - "graphviz" => Ok(ExplainFormat::Graphviz), - _ => { - plan_err!("Invalid explain format. Expected 'indent', 'tree', 'pgjson' or 'graphviz'. Got '{format}'") - } +impl Default for ExplainOption { + fn default() -> Self { + ExplainOption { + verbose: false, + analyze: false, + format: ExplainFormat::Indent, } } } +impl ExplainOption { + /// Builder‐style setter for `verbose` + pub fn with_verbose(mut self, verbose: bool) -> Self { + self.verbose = verbose; + self + } + + /// Builder‐style setter for `analyze` + pub fn with_analyze(mut self, analyze: bool) -> Self { + self.analyze = analyze; + self + } + + /// Builder‐style setter for `format` + pub fn with_format(mut self, format: ExplainFormat) -> Self { + self.format = format; + self + } +} + /// Produces a relation with string representations of /// various parts of the plan /// @@ -3142,7 +3124,10 @@ impl PartialOrd for Explain { stringified_plans: &other.stringified_plans, logical_optimization_succeeded: &other.logical_optimization_succeeded, }; - comparable_self.partial_cmp(&comparable_other) + comparable_self + .partial_cmp(&comparable_other) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -3165,6 +3150,8 @@ impl PartialOrd for Analyze { Some(Ordering::Equal) => self.input.partial_cmp(&other.input), cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -3228,7 +3215,7 @@ impl Limit { pub fn get_skip_type(&self) -> Result { match self.skip.as_deref() { Some(expr) => match *expr { - Expr::Literal(ScalarValue::Int64(s)) => { + Expr::Literal(ScalarValue::Int64(s), _) => { // `skip = NULL` is equivalent to `skip = 0` let s = s.unwrap_or(0); if s >= 0 { @@ -3248,14 +3235,16 @@ impl Limit { pub fn get_fetch_type(&self) -> Result { match self.fetch.as_deref() { Some(expr) => match *expr { - Expr::Literal(ScalarValue::Int64(Some(s))) => { + Expr::Literal(ScalarValue::Int64(Some(s)), _) => { if s >= 0 { Ok(FetchType::Literal(Some(s as usize))) } else { plan_err!("LIMIT must be >= 0, '{}' was provided", s) } } - Expr::Literal(ScalarValue::Int64(None)) => Ok(FetchType::Literal(None)), + Expr::Literal(ScalarValue::Int64(None), _) => { + Ok(FetchType::Literal(None)) + } _ => Ok(FetchType::UnsupportedExpr), }, None => Ok(FetchType::Literal(None)), @@ -3390,7 +3379,10 @@ impl PartialOrd for DistinctOn { sort_expr: &other.sort_expr, input: &other.input, }; - comparable_self.partial_cmp(&comparable_other) + comparable_self + .partial_cmp(&comparable_other) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -3475,7 +3467,10 @@ impl Aggregate { ) -> Result { if group_expr.is_empty() && aggr_expr.is_empty() { return plan_err!( - "Aggregate requires at least one grouping or aggregate expression" + "Aggregate requires at least one grouping or aggregate expression. \ + Aggregate without grouping expressions nor aggregate expressions is \ + logically equivalent to, but less efficient than, VALUES producing \ + single row. Please use VALUES instead." ); } let group_expr_count = grouping_set_expr_count(&group_expr)?; @@ -3574,6 +3569,8 @@ impl PartialOrd for Aggregate { } cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -3657,7 +3654,7 @@ fn calc_func_dependencies_for_project( .unwrap_or(vec![])) } _ => { - let name = format!("{}", expr); + let name = format!("{expr}"); Ok(input_fields .iter() .position(|item| *item == name) @@ -3704,42 +3701,107 @@ pub struct Join { pub join_constraint: JoinConstraint, /// The output schema, containing fields from the left and right inputs pub schema: DFSchemaRef, - /// If null_equals_null is true, null == null else null != null - pub null_equals_null: bool, + /// Defines the null equality for the join. + pub null_equality: NullEquality, } impl Join { - /// Create Join with input which wrapped with projection, this method is used to help create physical join. + /// Creates a new Join operator with automatically computed schema. + /// + /// This constructor computes the schema based on the join type and inputs, + /// removing the need to manually specify the schema or call `recompute_schema`. + /// + /// # Arguments + /// + /// * `left` - Left input plan + /// * `right` - Right input plan + /// * `on` - Join condition as a vector of (left_expr, right_expr) pairs + /// * `filter` - Optional filter expression (for non-equijoin conditions) + /// * `join_type` - Type of join (Inner, Left, Right, etc.) + /// * `join_constraint` - Join constraint (On, Using) + /// * `null_equality` - How to handle nulls in join comparisons + /// + /// # Returns + /// + /// A new Join operator with the computed schema + pub fn try_new( + left: Arc, + right: Arc, + on: Vec<(Expr, Expr)>, + filter: Option, + join_type: JoinType, + join_constraint: JoinConstraint, + null_equality: NullEquality, + ) -> Result { + let join_schema = build_join_schema(left.schema(), right.schema(), &join_type)?; + + Ok(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema: Arc::new(join_schema), + null_equality, + }) + } + + /// Create Join with input which wrapped with projection, this method is used in physical planning only to help + /// create the physical join. pub fn try_new_with_project_input( original: &LogicalPlan, left: Arc, right: Arc, column_on: (Vec, Vec), - ) -> Result { + ) -> Result<(Self, bool)> { let original_join = match original { LogicalPlan::Join(join) => join, _ => return plan_err!("Could not create join with project input"), }; + let mut left_sch = LogicalPlanBuilder::from(Arc::clone(&left)); + let mut right_sch = LogicalPlanBuilder::from(Arc::clone(&right)); + + let mut requalified = false; + + // By definition, the resulting schema of an inner/left/right & full join will have first the left side fields and then the right, + // potentially having duplicate field names. Note this will only qualify fields if they have not been qualified before. + if original_join.join_type == JoinType::Inner + || original_join.join_type == JoinType::Left + || original_join.join_type == JoinType::Right + || original_join.join_type == JoinType::Full + { + (left_sch, right_sch, requalified) = + requalify_sides_if_needed(left_sch.clone(), right_sch.clone())?; + } + let on: Vec<(Expr, Expr)> = column_on .0 .into_iter() .zip(column_on.1) .map(|(l, r)| (Expr::Column(l), Expr::Column(r))) .collect(); - let join_schema = - build_join_schema(left.schema(), right.schema(), &original_join.join_type)?; - Ok(Join { - left, - right, - on, - filter: original_join.filter.clone(), - join_type: original_join.join_type, - join_constraint: original_join.join_constraint, - schema: Arc::new(join_schema), - null_equals_null: original_join.null_equals_null, - }) + let join_schema = build_join_schema( + left_sch.schema(), + right_sch.schema(), + &original_join.join_type, + )?; + + Ok(( + Join { + left, + right, + on, + filter: original_join.filter.clone(), + join_type: original_join.join_type, + join_constraint: original_join.join_constraint, + schema: Arc::new(join_schema), + null_equality: original_join.null_equality, + }, + requalified, + )) } } @@ -3760,8 +3822,8 @@ impl PartialOrd for Join { pub join_type: &'a JoinType, /// Join constraint pub join_constraint: &'a JoinConstraint, - /// If null_equals_null is true, null == null else null != null - pub null_equals_null: &'a bool, + /// The null handling behavior for equalities + pub null_equality: &'a NullEquality, } let comparable_self = ComparableJoin { left: &self.left, @@ -3770,7 +3832,7 @@ impl PartialOrd for Join { filter: &self.filter, join_type: &self.join_type, join_constraint: &self.join_constraint, - null_equals_null: &self.null_equals_null, + null_equality: &self.null_equality, }; let comparable_other = ComparableJoin { left: &other.left, @@ -3779,9 +3841,12 @@ impl PartialOrd for Join { filter: &other.filter, join_type: &other.join_type, join_constraint: &other.join_constraint, - null_equals_null: &other.null_equals_null, + null_equality: &other.null_equality, }; - comparable_self.partial_cmp(&comparable_other) + comparable_self + .partial_cmp(&comparable_other) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -3946,27 +4011,231 @@ impl PartialOrd for Unnest { dependency_indices: &other.dependency_indices, options: &other.options, }; - comparable_self.partial_cmp(&comparable_other) + comparable_self + .partial_cmp(&comparable_other) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } +impl Unnest { + pub fn try_new( + input: Arc, + exec_columns: Vec, + options: UnnestOptions, + ) -> Result { + if exec_columns.is_empty() { + return plan_err!("unnest plan requires at least 1 column to unnest"); + } + + let mut list_columns: Vec<(usize, ColumnUnnestList)> = vec![]; + let mut struct_columns = vec![]; + let indices_to_unnest = exec_columns + .iter() + .map(|c| Ok((input.schema().index_of_column(c)?, c))) + .collect::>>()?; + + let input_schema = input.schema(); + + let mut dependency_indices = vec![]; + // Transform input schema into new schema + // Given this comprehensive example + // + // input schema: + // 1.col1_unnest_placeholder: list[list[int]], + // 2.col1: list[list[int]] + // 3.col2: list[int] + // with unnest on unnest(col1,depth=2), unnest(col1,depth=1) and unnest(col2,depth=1) + // output schema: + // 1.unnest_col1_depth_2: int + // 2.unnest_col1_depth_1: list[int] + // 3.col1: list[list[int]] + // 4.unnest_col2_depth_1: int + // Meaning the placeholder column will be replaced by its unnested variation(s), note + // the plural. + let fields = input_schema + .iter() + .enumerate() + .map(|(index, (original_qualifier, original_field))| { + match indices_to_unnest.get(&index) { + Some(column_to_unnest) => { + let recursions_on_column = options + .recursions + .iter() + .filter(|p| -> bool { &p.input_column == *column_to_unnest }) + .collect::>(); + let mut transformed_columns = recursions_on_column + .iter() + .map(|r| { + list_columns.push(( + index, + ColumnUnnestList { + output_column: r.output_column.clone(), + depth: r.depth, + }, + )); + Ok(get_unnested_columns( + &r.output_column.name, + original_field.data_type(), + r.depth, + )? + .into_iter() + .next() + .unwrap()) // because unnesting a list column always result into one result + }) + .collect::)>>>()?; + if transformed_columns.is_empty() { + transformed_columns = get_unnested_columns( + &column_to_unnest.name, + original_field.data_type(), + 1, + )?; + match original_field.data_type() { + DataType::Struct(_) => { + struct_columns.push(index); + } + DataType::List(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) => { + list_columns.push(( + index, + ColumnUnnestList { + output_column: Column::from_name( + &column_to_unnest.name, + ), + depth: 1, + }, + )); + } + _ => {} + }; + } + + // new columns dependent on the same original index + dependency_indices.extend(std::iter::repeat_n( + index, + transformed_columns.len(), + )); + Ok(transformed_columns + .iter() + .map(|(col, field)| { + (col.relation.to_owned(), field.to_owned()) + }) + .collect()) + } + None => { + dependency_indices.push(index); + Ok(vec![( + original_qualifier.cloned(), + Arc::clone(original_field), + )]) + } + } + }) + .collect::>>()? + .into_iter() + .flatten() + .collect::>(); + + let metadata = input_schema.metadata().clone(); + let df_schema = DFSchema::new_with_metadata(fields, metadata)?; + // We can use the existing functional dependencies: + let deps = input_schema.functional_dependencies().clone(); + let schema = Arc::new(df_schema.with_functional_dependencies(deps)?); + + Ok(Unnest { + input, + exec_columns, + list_type_columns: list_columns, + struct_type_columns: struct_columns, + dependency_indices, + schema, + options, + }) + } +} + +// Based on data type, either struct or a variant of list +// return a set of columns as the result of unnesting +// the input columns. +// For example, given a column with name "a", +// - List(Element) returns ["a"] with data type Element +// - Struct(field1, field2) returns ["a.field1","a.field2"] +// For list data type, an argument depth is used to specify +// the recursion level +fn get_unnested_columns( + col_name: &String, + data_type: &DataType, + depth: usize, +) -> Result)>> { + let mut qualified_columns = Vec::with_capacity(1); + + match data_type { + DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => { + let data_type = get_unnested_list_datatype_recursive(data_type, depth)?; + let new_field = Arc::new(Field::new( + col_name, data_type, + // Unnesting may produce NULLs even if the list is not null. + // For example: unnest([1], []) -> 1, null + true, + )); + let column = Column::from_name(col_name); + // let column = Column::from((None, &new_field)); + qualified_columns.push((column, new_field)); + } + DataType::Struct(fields) => { + qualified_columns.extend(fields.iter().map(|f| { + let new_name = format!("{}.{}", col_name, f.name()); + let column = Column::from_name(&new_name); + let new_field = f.as_ref().clone().with_name(new_name); + // let column = Column::from((None, &f)); + (column, Arc::new(new_field)) + })) + } + _ => { + return internal_err!("trying to unnest on invalid data type {data_type}"); + } + }; + Ok(qualified_columns) +} + +// Get the data type of a multi-dimensional type after unnesting it +// with a given depth +fn get_unnested_list_datatype_recursive( + data_type: &DataType, + depth: usize, +) -> Result { + match data_type { + DataType::List(field) + | DataType::FixedSizeList(field, _) + | DataType::LargeList(field) => { + if depth == 1 { + return Ok(field.data_type().clone()); + } + return get_unnested_list_datatype_recursive(field.data_type(), depth - 1); + } + _ => {} + }; + + internal_err!("trying to unnest on invalid data type {data_type}") +} + #[cfg(test)] mod tests { - use super::*; use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; + use crate::test::function_stub::{count, count_udaf}; use crate::{ binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery, GroupingSet, }; - use datafusion_common::tree_node::{ TransformedResult, TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{not_impl_err, Constraint, ScalarValue}; - - use crate::test::function_stub::count; + use insta::{assert_debug_snapshot, assert_snapshot}; + use std::hash::DefaultHasher; fn employee_schema() -> Schema { Schema::new(vec![ @@ -3992,13 +4261,13 @@ mod tests { fn test_display_indent() -> Result<()> { let plan = display_plan()?; - let expected = "Projection: employee_csv.id\ - \n Filter: employee_csv.state IN ()\ - \n Subquery:\ - \n TableScan: employee_csv projection=[state]\ - \n TableScan: employee_csv projection=[id, state]"; - - assert_eq!(expected, format!("{}", plan.display_indent())); + assert_snapshot!(plan.display_indent(), @r" + Projection: employee_csv.id + Filter: employee_csv.state IN () + Subquery: + TableScan: employee_csv projection=[state] + TableScan: employee_csv projection=[id, state] + "); Ok(()) } @@ -4006,13 +4275,13 @@ mod tests { fn test_display_indent_schema() -> Result<()> { let plan = display_plan()?; - let expected = "Projection: employee_csv.id [id:Int32]\ - \n Filter: employee_csv.state IN () [id:Int32, state:Utf8]\ - \n Subquery: [state:Utf8]\ - \n TableScan: employee_csv projection=[state] [state:Utf8]\ - \n TableScan: employee_csv projection=[id, state] [id:Int32, state:Utf8]"; - - assert_eq!(expected, format!("{}", plan.display_indent_schema())); + assert_snapshot!(plan.display_indent_schema(), @r" + Projection: employee_csv.id [id:Int32] + Filter: employee_csv.state IN () [id:Int32, state:Utf8] + Subquery: [state:Utf8] + TableScan: employee_csv projection=[state] [state:Utf8] + TableScan: employee_csv projection=[id, state] [id:Int32, state:Utf8] + "); Ok(()) } @@ -4027,12 +4296,12 @@ mod tests { .project(vec![col("id"), exists(plan1).alias("exists")])? .build(); - let expected = "Projection: employee_csv.id, EXISTS () AS exists\ - \n Subquery:\ - \n TableScan: employee_csv projection=[state]\ - \n TableScan: employee_csv projection=[id, state]"; - - assert_eq!(expected, format!("{}", plan?.display_indent())); + assert_snapshot!(plan?.display_indent(), @r" + Projection: employee_csv.id, EXISTS () AS exists + Subquery: + TableScan: employee_csv projection=[state] + TableScan: employee_csv projection=[id, state] + "); Ok(()) } @@ -4040,46 +4309,42 @@ mod tests { fn test_display_graphviz() -> Result<()> { let plan = display_plan()?; - let expected_graphviz = r#" -// Begin DataFusion GraphViz Plan, -// display it online here: https://dreampuf.github.io/GraphvizOnline - -digraph { - subgraph cluster_1 - { - graph[label="LogicalPlan"] - 2[shape=box label="Projection: employee_csv.id"] - 3[shape=box label="Filter: employee_csv.state IN ()"] - 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] - 4[shape=box label="Subquery:"] - 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] - 5[shape=box label="TableScan: employee_csv projection=[state]"] - 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] - 6[shape=box label="TableScan: employee_csv projection=[id, state]"] - 3 -> 6 [arrowhead=none, arrowtail=normal, dir=back] - } - subgraph cluster_7 - { - graph[label="Detailed LogicalPlan"] - 8[shape=box label="Projection: employee_csv.id\nSchema: [id:Int32]"] - 9[shape=box label="Filter: employee_csv.state IN ()\nSchema: [id:Int32, state:Utf8]"] - 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] - 10[shape=box label="Subquery:\nSchema: [state:Utf8]"] - 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] - 11[shape=box label="TableScan: employee_csv projection=[state]\nSchema: [state:Utf8]"] - 10 -> 11 [arrowhead=none, arrowtail=normal, dir=back] - 12[shape=box label="TableScan: employee_csv projection=[id, state]\nSchema: [id:Int32, state:Utf8]"] - 9 -> 12 [arrowhead=none, arrowtail=normal, dir=back] - } -} -// End DataFusion GraphViz Plan -"#; - // just test for a few key lines in the output rather than the // whole thing to make test maintenance easier. - let graphviz = format!("{}", plan.display_graphviz()); - - assert_eq!(expected_graphviz, graphviz); + assert_snapshot!(plan.display_graphviz(), @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Projection: employee_csv.id"] + 3[shape=box label="Filter: employee_csv.state IN ()"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Subquery:"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: employee_csv projection=[state]"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + 6[shape=box label="TableScan: employee_csv projection=[id, state]"] + 3 -> 6 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_7 + { + graph[label="Detailed LogicalPlan"] + 8[shape=box label="Projection: employee_csv.id\nSchema: [id:Int32]"] + 9[shape=box label="Filter: employee_csv.state IN ()\nSchema: [id:Int32, state:Utf8]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="Subquery:\nSchema: [state:Utf8]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + 11[shape=box label="TableScan: employee_csv projection=[state]\nSchema: [state:Utf8]"] + 10 -> 11 [arrowhead=none, arrowtail=normal, dir=back] + 12[shape=box label="TableScan: employee_csv projection=[id, state]\nSchema: [id:Int32, state:Utf8]"] + 9 -> 12 [arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "#); Ok(()) } @@ -4087,60 +4352,58 @@ digraph { fn test_display_pg_json() -> Result<()> { let plan = display_plan()?; - let expected_pg_json = r#"[ - { - "Plan": { - "Expressions": [ - "employee_csv.id" - ], - "Node Type": "Projection", - "Output": [ - "id" - ], - "Plans": [ - { - "Condition": "employee_csv.state IN ()", - "Node Type": "Filter", - "Output": [ - "id", - "state" - ], - "Plans": [ - { - "Node Type": "Subquery", + assert_snapshot!(plan.display_pg_json(), @r#" + [ + { + "Plan": { + "Expressions": [ + "employee_csv.id" + ], + "Node Type": "Projection", "Output": [ - "state" + "id" ], "Plans": [ { - "Node Type": "TableScan", + "Condition": "employee_csv.state IN ()", + "Node Type": "Filter", "Output": [ + "id", "state" ], - "Plans": [], - "Relation Name": "employee_csv" + "Plans": [ + { + "Node Type": "Subquery", + "Output": [ + "state" + ], + "Plans": [ + { + "Node Type": "TableScan", + "Output": [ + "state" + ], + "Plans": [], + "Relation Name": "employee_csv" + } + ] + }, + { + "Node Type": "TableScan", + "Output": [ + "id", + "state" + ], + "Plans": [], + "Relation Name": "employee_csv" + } + ] } ] - }, - { - "Node Type": "TableScan", - "Output": [ - "id", - "state" - ], - "Plans": [], - "Relation Name": "employee_csv" } - ] - } - ] - } - } -]"#; - - let pg_json = format!("{}", plan.display_pg_json()); - - assert_eq!(expected_pg_json, pg_json); + } + ] + "#); Ok(()) } @@ -4189,17 +4452,16 @@ digraph { let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); - assert_eq!( - visitor.strings, - vec![ - "pre_visit Projection", - "pre_visit Filter", - "pre_visit TableScan", - "post_visit TableScan", - "post_visit Filter", - "post_visit Projection", - ] - ); + assert_debug_snapshot!(visitor.strings, @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + "post_visit Filter", + "post_visit Projection", + ] + "#); } #[derive(Debug, Default)] @@ -4265,9 +4527,14 @@ digraph { let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); - assert_eq!( + assert_debug_snapshot!( visitor.inner.strings, - vec!["pre_visit Projection", "pre_visit Filter"] + @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + ] + "# ); } @@ -4281,14 +4548,16 @@ digraph { let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); - assert_eq!( + assert_debug_snapshot!( visitor.inner.strings, - vec![ - "pre_visit Projection", - "pre_visit Filter", - "pre_visit TableScan", - "post_visit TableScan", - ] + @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + ] + "# ); } @@ -4330,13 +4599,18 @@ digraph { }; let plan = test_plan(); let res = plan.visit_with_subqueries(&mut visitor).unwrap_err(); - assert_eq!( - "This feature is not implemented: Error in pre_visit", - res.strip_backtrace() + assert_snapshot!( + res.strip_backtrace(), + @"This feature is not implemented: Error in pre_visit" ); - assert_eq!( + assert_debug_snapshot!( visitor.inner.strings, - vec!["pre_visit Projection", "pre_visit Filter"] + @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + ] + "# ); } @@ -4348,21 +4622,80 @@ digraph { }; let plan = test_plan(); let res = plan.visit_with_subqueries(&mut visitor).unwrap_err(); - assert_eq!( - "This feature is not implemented: Error in post_visit", - res.strip_backtrace() + assert_snapshot!( + res.strip_backtrace(), + @"This feature is not implemented: Error in post_visit" ); - assert_eq!( + assert_debug_snapshot!( visitor.inner.strings, - vec![ - "pre_visit Projection", - "pre_visit Filter", - "pre_visit TableScan", - "post_visit TableScan", - ] + @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + ] + "# ); } + #[test] + fn test_partial_eq_hash_and_partial_ord() { + let empty_values = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: Arc::new(DFSchema::empty()), + })); + + let count_window_function = |schema| { + Window::try_new_with_schema( + vec![Expr::WindowFunction(Box::new(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(count_udaf()), + vec![], + )))], + Arc::clone(&empty_values), + Arc::new(schema), + ) + .unwrap() + }; + + let schema_without_metadata = || { + DFSchema::from_unqualified_fields( + vec![Field::new("count", DataType::Int64, false)].into(), + HashMap::new(), + ) + .unwrap() + }; + + let schema_with_metadata = || { + DFSchema::from_unqualified_fields( + vec![Field::new("count", DataType::Int64, false)].into(), + [("key".to_string(), "value".to_string())].into(), + ) + .unwrap() + }; + + // A Window + let f = count_window_function(schema_without_metadata()); + + // Same like `f`, different instance + let f2 = count_window_function(schema_without_metadata()); + assert_eq!(f, f2); + assert_eq!(hash(&f), hash(&f2)); + assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal)); + + // Same like `f`, except for schema metadata + let o = count_window_function(schema_with_metadata()); + assert_ne!(f, o); + assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test + assert_eq!(f.partial_cmp(&o), None); + } + + fn hash(value: &T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } + #[test] fn projection_expr_schema_mismatch() -> Result<()> { let empty_schema = Arc::new(DFSchema::empty()); @@ -4374,7 +4707,7 @@ digraph { })), empty_schema, ); - assert_eq!(p.err().unwrap().strip_backtrace(), "Error during planning: Projection has mismatch between number of expressions (1) and number of fields in schema (0)"); + assert_snapshot!(p.unwrap_err().strip_backtrace(), @"Error during planning: Projection has mismatch between number of expressions (1) and number of fields in schema (0)"); Ok(()) } @@ -4494,7 +4827,7 @@ digraph { let col = schema.field_names()[0].clone(); let filter = Filter::try_new( - Expr::Column(col.into()).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), + Expr::Column(col.into()).eq(Expr::Literal(ScalarValue::Int32(Some(1)), None)), scan, ) .unwrap(); @@ -4561,11 +4894,12 @@ digraph { .data() .unwrap(); - let expected = "Explain\ - \n Filter: foo = Boolean(true)\ - \n TableScan: ?table?"; let actual = format!("{}", plan.display_indent()); - assert_eq!(expected.to_string(), actual) + assert_snapshot!(actual, @r" + Explain + Filter: foo = Boolean(true) + TableScan: ?table? + ") } #[test] @@ -4620,12 +4954,14 @@ digraph { skip: None, fetch: Some(Box::new(Expr::Literal( ScalarValue::new_ten(&DataType::UInt32).unwrap(), + None, ))), input: Arc::clone(&input), }), LogicalPlan::Limit(Limit { skip: Some(Box::new(Expr::Literal( ScalarValue::new_ten(&DataType::UInt32).unwrap(), + None, ))), fetch: None, input: Arc::clone(&input), @@ -4633,9 +4969,11 @@ digraph { LogicalPlan::Limit(Limit { skip: Some(Box::new(Expr::Literal( ScalarValue::new_one(&DataType::UInt32).unwrap(), + None, ))), fetch: Some(Box::new(Expr::Literal( ScalarValue::new_ten(&DataType::UInt32).unwrap(), + None, ))), input, }), @@ -4837,7 +5175,7 @@ digraph { join_type: JoinType::Inner, join_constraint: JoinConstraint::On, schema: Arc::new(left_schema.join(&right_schema)?), - null_equals_null: false, + null_equality: NullEquality::NullEqualsNothing, })) } @@ -4916,4 +5254,374 @@ digraph { Ok(()) } + + #[test] + fn test_join_try_new() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let left_scan = table_scan(Some("t1"), &schema, None)?.build()?; + + let right_scan = table_scan(Some("t2"), &schema, None)?.build()?; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightSemi, + JoinType::RightAnti, + JoinType::LeftMark, + ]; + + for join_type in join_types { + let join = Join::try_new( + Arc::new(left_scan.clone()), + Arc::new(right_scan.clone()), + vec![(col("t1.a"), col("t2.a"))], + Some(col("t1.b").gt(col("t2.b"))), + join_type, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + )?; + + match join_type { + JoinType::LeftSemi | JoinType::LeftAnti => { + assert_eq!(join.schema.fields().len(), 2); + + let fields = join.schema.fields(); + assert_eq!( + fields[0].name(), + "a", + "First field should be 'a' from left table" + ); + assert_eq!( + fields[1].name(), + "b", + "Second field should be 'b' from left table" + ); + } + JoinType::RightSemi | JoinType::RightAnti => { + assert_eq!(join.schema.fields().len(), 2); + + let fields = join.schema.fields(); + assert_eq!( + fields[0].name(), + "a", + "First field should be 'a' from right table" + ); + assert_eq!( + fields[1].name(), + "b", + "Second field should be 'b' from right table" + ); + } + JoinType::LeftMark => { + assert_eq!(join.schema.fields().len(), 3); + + let fields = join.schema.fields(); + assert_eq!( + fields[0].name(), + "a", + "First field should be 'a' from left table" + ); + assert_eq!( + fields[1].name(), + "b", + "Second field should be 'b' from left table" + ); + assert_eq!( + fields[2].name(), + "mark", + "Third field should be the mark column" + ); + + assert!(!fields[0].is_nullable()); + assert!(!fields[1].is_nullable()); + assert!(!fields[2].is_nullable()); + } + _ => { + assert_eq!(join.schema.fields().len(), 4); + + let fields = join.schema.fields(); + assert_eq!( + fields[0].name(), + "a", + "First field should be 'a' from left table" + ); + assert_eq!( + fields[1].name(), + "b", + "Second field should be 'b' from left table" + ); + assert_eq!( + fields[2].name(), + "a", + "Third field should be 'a' from right table" + ); + assert_eq!( + fields[3].name(), + "b", + "Fourth field should be 'b' from right table" + ); + + if join_type == JoinType::Left { + // Left side fields (first two) shouldn't be nullable + assert!(!fields[0].is_nullable()); + assert!(!fields[1].is_nullable()); + // Right side fields (third and fourth) should be nullable + assert!(fields[2].is_nullable()); + assert!(fields[3].is_nullable()); + } else if join_type == JoinType::Right { + // Left side fields (first two) should be nullable + assert!(fields[0].is_nullable()); + assert!(fields[1].is_nullable()); + // Right side fields (third and fourth) shouldn't be nullable + assert!(!fields[2].is_nullable()); + assert!(!fields[3].is_nullable()); + } else if join_type == JoinType::Full { + assert!(fields[0].is_nullable()); + assert!(fields[1].is_nullable()); + assert!(fields[2].is_nullable()); + assert!(fields[3].is_nullable()); + } + } + } + + assert_eq!(join.on, vec![(col("t1.a"), col("t2.a"))]); + assert_eq!(join.filter, Some(col("t1.b").gt(col("t2.b")))); + assert_eq!(join.join_type, join_type); + assert_eq!(join.join_constraint, JoinConstraint::On); + assert_eq!(join.null_equality, NullEquality::NullEqualsNothing); + } + + Ok(()) + } + + #[test] + fn test_join_try_new_with_using_constraint_and_overlapping_columns() -> Result<()> { + let left_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), // Common column in both tables + Field::new("name", DataType::Utf8, false), // Unique to left + Field::new("value", DataType::Int32, false), // Common column, different meaning + ]); + + let right_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), // Common column in both tables + Field::new("category", DataType::Utf8, false), // Unique to right + Field::new("value", DataType::Float64, true), // Common column, different meaning + ]); + + let left_plan = table_scan(Some("t1"), &left_schema, None)?.build()?; + + let right_plan = table_scan(Some("t2"), &right_schema, None)?.build()?; + + // Test 1: USING constraint with a common column + { + // In the logical plan, both copies of the `id` column are preserved + // The USING constraint is handled later during physical execution, where the common column appears once + let join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], + None, + JoinType::Inner, + JoinConstraint::Using, + NullEquality::NullEqualsNothing, + )?; + + let fields = join.schema.fields(); + + assert_eq!(fields.len(), 6); + + assert_eq!( + fields[0].name(), + "id", + "First field should be 'id' from left table" + ); + assert_eq!( + fields[1].name(), + "name", + "Second field should be 'name' from left table" + ); + assert_eq!( + fields[2].name(), + "value", + "Third field should be 'value' from left table" + ); + assert_eq!( + fields[3].name(), + "id", + "Fourth field should be 'id' from right table" + ); + assert_eq!( + fields[4].name(), + "category", + "Fifth field should be 'category' from right table" + ); + assert_eq!( + fields[5].name(), + "value", + "Sixth field should be 'value' from right table" + ); + + assert_eq!(join.join_constraint, JoinConstraint::Using); + } + + // Test 2: Complex join condition with expressions + { + // Complex condition: join on id equality AND where left.value < right.value + let join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], // Equijoin condition + Some(col("t1.value").lt(col("t2.value"))), // Non-equi filter condition + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + )?; + + let fields = join.schema.fields(); + assert_eq!(fields.len(), 6); + + assert_eq!( + fields[0].name(), + "id", + "First field should be 'id' from left table" + ); + assert_eq!( + fields[1].name(), + "name", + "Second field should be 'name' from left table" + ); + assert_eq!( + fields[2].name(), + "value", + "Third field should be 'value' from left table" + ); + assert_eq!( + fields[3].name(), + "id", + "Fourth field should be 'id' from right table" + ); + assert_eq!( + fields[4].name(), + "category", + "Fifth field should be 'category' from right table" + ); + assert_eq!( + fields[5].name(), + "value", + "Sixth field should be 'value' from right table" + ); + + assert_eq!(join.filter, Some(col("t1.value").lt(col("t2.value")))); + } + + // Test 3: Join with null equality behavior set to true + { + let join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], + None, + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNull, + )?; + + assert_eq!(join.null_equality, NullEquality::NullEqualsNull); + } + + Ok(()) + } + + #[test] + fn test_join_try_new_schema_validation() -> Result<()> { + let left_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new("value", DataType::Float64, true), + ]); + + let right_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("category", DataType::Utf8, true), + Field::new("code", DataType::Int16, false), + ]); + + let left_plan = table_scan(Some("t1"), &left_schema, None)?.build()?; + + let right_plan = table_scan(Some("t2"), &right_schema, None)?.build()?; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + ]; + + for join_type in join_types { + let join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], + Some(col("t1.value").gt(lit(5.0))), + join_type, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + )?; + + let fields = join.schema.fields(); + assert_eq!(fields.len(), 6, "Expected 6 fields for {join_type} join"); + + for (i, field) in fields.iter().enumerate() { + let expected_nullable = match (i, &join_type) { + // Left table fields (indices 0, 1, 2) + (0, JoinType::Right | JoinType::Full) => true, // id becomes nullable in RIGHT/FULL + (1, JoinType::Right | JoinType::Full) => true, // name becomes nullable in RIGHT/FULL + (2, _) => true, // value is already nullable + + // Right table fields (indices 3, 4, 5) + (3, JoinType::Left | JoinType::Full) => true, // id becomes nullable in LEFT/FULL + (4, _) => true, // category is already nullable + (5, JoinType::Left | JoinType::Full) => true, // code becomes nullable in LEFT/FULL + + _ => false, + }; + + assert_eq!( + field.is_nullable(), + expected_nullable, + "Field {} ({}) nullability incorrect for {:?} join", + i, + field.name(), + join_type + ); + } + } + + let using_join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], + None, + JoinType::Inner, + JoinConstraint::Using, + NullEquality::NullEqualsNothing, + )?; + + assert_eq!( + using_join.schema.fields().len(), + 6, + "USING join should have all fields" + ); + assert_eq!(using_join.join_constraint, JoinConstraint::Using); + + Ok(()) + } } diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs index 82acebee3de66..6d3fe9fa75acf 100644 --- a/datafusion/expr/src/logical_plan/statement.rs +++ b/datafusion/expr/src/logical_plan/statement.rs @@ -17,6 +17,7 @@ use arrow::datatypes::DataType; use datafusion_common::{DFSchema, DFSchemaRef}; +use itertools::Itertools as _; use std::fmt::{self, Display}; use std::sync::{Arc, LazyLock}; @@ -110,7 +111,7 @@ impl Statement { Statement::Prepare(Prepare { name, data_types, .. }) => { - write!(f, "Prepare: {name:?} {data_types:?} ") + write!(f, "Prepare: {name:?} [{}]", data_types.iter().join(", ")) } Statement::Execute(Execute { name, parameters, .. @@ -123,7 +124,7 @@ impl Statement { ) } Statement::Deallocate(Deallocate { name }) => { - write!(f, "Deallocate: {}", name) + write!(f, "Deallocate: {name}") } } } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 7f6e1e025387c..47088370a1d93 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -85,17 +85,9 @@ impl TreeNode for LogicalPlan { schema, }) }), - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) - }), + LogicalPlan::Filter(Filter { predicate, input }) => input + .map_elements(f)? + .update_data(|input| LogicalPlan::Filter(Filter { predicate, input })), LogicalPlan::Repartition(Repartition { input, partitioning_scheme, @@ -140,7 +132,7 @@ impl TreeNode for LogicalPlan { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) => (left, right).map_elements(f)?.update_data(|(left, right)| { LogicalPlan::Join(Join { left, @@ -150,7 +142,7 @@ impl TreeNode for LogicalPlan { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) }), LogicalPlan::Limit(Limit { skip, fetch, input }) => input @@ -251,6 +243,7 @@ impl TreeNode for LogicalPlan { partition_by, file_type, options, + output_schema, }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Copy(CopyTo { input, @@ -258,6 +251,7 @@ impl TreeNode for LogicalPlan { partition_by, file_type, options, + output_schema, }) }), LogicalPlan::Ddl(ddl) => { @@ -321,9 +315,9 @@ impl TreeNode for LogicalPlan { LogicalPlan::Unnest(Unnest { input, exec_columns: input_columns, - dependency_indices, list_type_columns, struct_type_columns, + dependency_indices, schema, options, }) @@ -444,11 +438,11 @@ impl LogicalPlan { filters.apply_elements(f) } LogicalPlan::Unnest(unnest) => { - let columns = unnest.exec_columns.clone(); - - let exprs = columns + let exprs = unnest + .exec_columns .iter() - .map(|c| Expr::Column(c.clone())) + .cloned() + .map(Expr::Column) .collect::>(); exprs.apply_elements(f) } @@ -509,17 +503,10 @@ impl LogicalPlan { LogicalPlan::Values(Values { schema, values }) => values .map_elements(f)? .update_data(|values| LogicalPlan::Values(Values { schema, values })), - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => f(predicate)?.update_data(|predicate| { - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) - }), + LogicalPlan::Filter(Filter { predicate, input }) => f(predicate)? + .update_data(|predicate| { + LogicalPlan::Filter(Filter { predicate, input }) + }), LogicalPlan::Repartition(Repartition { input, partitioning_scheme, @@ -576,7 +563,7 @@ impl LogicalPlan { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { LogicalPlan::Join(Join { left, @@ -586,7 +573,7 @@ impl LogicalPlan { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) }), LogicalPlan::Sort(Sort { expr, input, fetch }) => expr diff --git a/datafusion/expr/src/operation.rs b/datafusion/expr/src/operation.rs index 6b79a8248b293..3158a19dce449 100644 --- a/datafusion/expr/src/operation.rs +++ b/datafusion/expr/src/operation.rs @@ -17,8 +17,8 @@ //! This module contains implementations of operations (unary, binary etc.) for DataFusion expressions. +use crate::expr::{Exists, Expr, InList, InSubquery, Like}; use crate::expr_fn::binary_expr; -use crate::{Expr, Like}; use datafusion_expr_common::operator::Operator; use std::ops::{self, Not}; @@ -153,6 +153,19 @@ impl Not for Expr { escape_char, case_insensitive, )), + Expr::InList(InList { + expr, + list, + negated, + }) => Expr::InList(InList::new(expr, list, !negated)), + Expr::Exists(Exists { subquery, negated }) => { + Expr::Exists(Exists::new(subquery, !negated)) + } + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => Expr::InSubquery(InSubquery::new(expr, subquery, !negated)), _ => Expr::Not(Box::new(self)), } } diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index a2ed0592efdb4..25a0f83947eee 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -20,17 +20,16 @@ use std::fmt::Debug; use std::sync::Arc; +use crate::expr::NullTreatment; +use crate::{ + AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource, WindowFrame, + WindowFunctionDefinition, WindowUDF, +}; use arrow::datatypes::{DataType, Field, SchemaRef}; use datafusion_common::{ config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema, Result, TableReference, }; -use sqlparser::ast::{self, NullTreatment}; - -use crate::{ - AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource, WindowFrame, - WindowFunctionDefinition, WindowUDF, -}; /// Provides the `SQL` query planner meta-data about tables and /// functions referenced in SQL statements, without a direct dependency on the @@ -85,6 +84,7 @@ pub trait ContextProvider { } /// Return [`TypePlanner`] extensions for planning data types + #[cfg(feature = "sql")] fn get_type_planner(&self) -> Option> { None } @@ -261,7 +261,10 @@ pub trait ExprPlanner: Debug + Send + Sync { /// custom expressions. #[derive(Debug, Clone)] pub struct RawBinaryExpr { - pub op: ast::BinaryOperator, + #[cfg(not(feature = "sql"))] + pub op: datafusion_expr_common::operator::Operator, + #[cfg(feature = "sql")] + pub op: sqlparser::ast::BinaryOperator, pub left: Expr, pub right: Expr, } @@ -294,7 +297,7 @@ pub struct RawAggregateExpr { pub args: Vec, pub distinct: bool, pub filter: Option>, - pub order_by: Option>, + pub order_by: Vec, pub null_treatment: Option, } @@ -307,7 +310,9 @@ pub struct RawWindowExpr { pub partition_by: Vec, pub order_by: Vec, pub window_frame: WindowFrame, + pub filter: Option>, pub null_treatment: Option, + pub distinct: bool, } /// Result of planning a raw expr with [`ExprPlanner`] @@ -320,11 +325,15 @@ pub enum PlannerResult { } /// Customize planning SQL types to DataFusion (Arrow) types. +#[cfg(feature = "sql")] pub trait TypePlanner: Debug + Send + Sync { - /// Plan SQL [`ast::DataType`] to DataFusion [`DataType`] + /// Plan SQL [`sqlparser::ast::DataType`] to DataFusion [`DataType`] /// /// Returns None if not possible - fn plan_type(&self, _sql_type: &ast::DataType) -> Result> { + fn plan_type( + &self, + _sql_type: &sqlparser::ast::DataType, + ) -> Result> { Ok(None) } } diff --git a/datafusion/expr/src/ptr_eq.rs b/datafusion/expr/src/ptr_eq.rs new file mode 100644 index 0000000000000..0bbfba5e8d063 --- /dev/null +++ b/datafusion/expr/src/ptr_eq.rs @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::Debug; +use std::hash::{Hash, Hasher}; +use std::ops::Deref; +use std::sync::Arc; + +/// Compares two `Arc` pointers for equality based on their underlying pointers values. +/// This is not equivalent to [`Arc::ptr_eq`] for fat pointers, see that method +/// for more information. +pub fn arc_ptr_eq(a: &Arc, b: &Arc) -> bool { + std::ptr::eq(Arc::as_ptr(a), Arc::as_ptr(b)) +} + +/// Hashes an `Arc` pointer based on its underlying pointer value. +/// The general contract for this function is that if [`arc_ptr_eq`] returns `true` +/// for two `Arc`s, then this function should return the same hash value for both. +pub fn arc_ptr_hash(a: &Arc, hasher: &mut impl Hasher) { + std::ptr::hash(Arc::as_ptr(a), hasher) +} + +/// A wrapper around a pointer that implements `Eq` and `Hash` comparing +/// the underlying pointer address. +/// +/// If you have pointers to a `dyn UDF impl` consider using [`super::udf_eq::UdfEq`]. +#[derive(Clone)] +#[allow(private_bounds)] // This is so that PtrEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse. +pub struct PtrEq(Ptr); + +impl PartialEq for PtrEq> +where + T: ?Sized, +{ + fn eq(&self, other: &Self) -> bool { + arc_ptr_eq(&self.0, &other.0) + } +} +impl Eq for PtrEq> where T: ?Sized {} + +impl Hash for PtrEq> +where + T: ?Sized, +{ + fn hash(&self, state: &mut H) { + arc_ptr_hash(&self.0, state); + } +} + +impl From for PtrEq +where + Ptr: PointerType, +{ + fn from(ptr: Ptr) -> Self { + PtrEq(ptr) + } +} + +impl From>> for Arc +where + T: ?Sized, +{ + fn from(wrapper: PtrEq>) -> Self { + wrapper.0 + } +} + +impl Debug for PtrEq +where + Ptr: PointerType + Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl Deref for PtrEq +where + Ptr: PointerType, +{ + type Target = Ptr; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +trait PointerType {} +impl PointerType for Arc where T: ?Sized {} + +#[cfg(test)] +mod tests { + use super::*; + use std::hash::DefaultHasher; + + #[test] + pub fn test_ptr_eq_wrapper() { + let a = Arc::new("Hello".to_string()); + let b = Arc::new(a.deref().clone()); + let c = Arc::new("world".to_string()); + + let wrapper = PtrEq(Arc::clone(&a)); + assert_eq!(wrapper, wrapper); + + // same address (equal) + assert_eq!(PtrEq(Arc::clone(&a)), PtrEq(Arc::clone(&a))); + assert_eq!(hash(PtrEq(Arc::clone(&a))), hash(PtrEq(Arc::clone(&a)))); + + // different address, same content (not equal) + assert_ne!(PtrEq(Arc::clone(&a)), PtrEq(Arc::clone(&b))); + + // different address, different content (not equal) + assert_ne!(PtrEq(Arc::clone(&a)), PtrEq(Arc::clone(&c))); + } + + fn hash(value: T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } +} diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 4eb49710bcf85..9554dd68e1758 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -27,9 +27,15 @@ use std::sync::Arc; /// A registry knows how to build logical expressions out of user-defined function' names pub trait FunctionRegistry { - /// Set of all available udfs. + /// Returns names of all available scalar user defined functions. fn udfs(&self) -> HashSet; + /// Returns names of all available aggregate user defined functions. + fn udafs(&self) -> HashSet; + + /// Returns names of all available window user defined functions. + fn udwfs(&self) -> HashSet; + /// Returns a reference to the user defined scalar function (udf) named /// `name`. fn udf(&self, name: &str) -> Result>; @@ -200,4 +206,12 @@ impl FunctionRegistry for MemoryFunctionRegistry { fn expr_planners(&self) -> Vec> { vec![] } + + fn udafs(&self) -> HashSet { + self.udafs.keys().cloned().collect() + } + + fn udwfs(&self) -> HashSet { + self.udwfs.keys().cloned().collect() + } } diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index 467ce8bf53e2d..02794271a9ee1 100644 --- a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -18,7 +18,7 @@ //! Structs and traits to provide the information needed for expression simplification. use arrow::datatypes::DataType; -use datafusion_common::{DFSchemaRef, DataFusionError, Result}; +use datafusion_common::{internal_datafusion_err, DFSchemaRef, Result}; use crate::{execution_props::ExecutionProps, Expr, ExprSchemable}; @@ -86,9 +86,7 @@ impl SimplifyInfo for SimplifyContext<'_> { /// Returns true if expr is nullable fn nullable(&self, expr: &Expr) -> Result { let schema = self.schema.as_ref().ok_or_else(|| { - DataFusionError::Internal( - "attempt to get nullability without schema".to_string(), - ) + internal_datafusion_err!("attempt to get nullability without schema") })?; expr.nullable(schema.as_ref()) } @@ -96,9 +94,7 @@ impl SimplifyInfo for SimplifyContext<'_> { /// Returns data type of this expr needed for determining optimized int type of a value fn get_data_type(&self, expr: &Expr) -> Result { let schema = self.schema.as_ref().ok_or_else(|| { - DataFusionError::Internal( - "attempt to get data type without schema".to_string(), - ) + internal_datafusion_err!("attempt to get data type without schema") })?; expr.get_type(schema) } diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs index d6155cfb5dc02..d3b253c0e102c 100644 --- a/datafusion/expr/src/table_source.rs +++ b/datafusion/expr/src/table_source.rs @@ -32,7 +32,7 @@ use std::{any::Any, borrow::Cow}; /// the filter") are returned. Rows that evaluate to `false` or `NULL` are /// omitted. /// -/// [`TableProvider::scan`]: https://docs.rs/datafusion/latest/datafusion/datasource/provider/trait.TableProvider.html#tymethod.scan +/// [`TableProvider::scan`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html#tymethod.scan #[derive(Debug, Clone, PartialEq, Eq)] pub enum TableProviderFilterPushDown { /// The filter cannot be used by the provider and will not be pushed down. @@ -89,7 +89,7 @@ impl std::fmt::Display for TableType { /// plan code be dependent on the DataFusion execution engine. Some projects use /// DataFusion's logical plans and have their own execution engine. /// -/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/datasource/provider/trait.TableProvider.html +/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html /// [`DefaultTableSource`]: https://docs.rs/datafusion/latest/datafusion/datasource/default_table_source/struct.DefaultTableSource.html pub trait TableSource: Sync + Send { fn as_any(&self) -> &dyn Any; @@ -121,7 +121,7 @@ pub trait TableSource: Sync + Send { /// Get the Logical plan of this table provider, if available. /// /// For example, a view may have a logical plan, but a CSV file does not. - fn get_logical_plan(&self) -> Option> { + fn get_logical_plan(&'_ self) -> Option> { None } diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index a753f4c376c63..41bc645058079 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -22,7 +22,8 @@ use std::any::Any; use arrow::datatypes::{ - DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, }; use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result}; @@ -60,7 +61,7 @@ pub fn sum(expr: Expr) -> Expr { vec![expr], false, None, - None, + vec![], None, )) } @@ -73,7 +74,7 @@ pub fn count(expr: Expr) -> Expr { vec![expr], false, None, - None, + vec![], None, )) } @@ -86,13 +87,13 @@ pub fn avg(expr: Expr) -> Expr { vec![expr], false, None, - None, + vec![], None, )) } /// Stub `sum` used for optimizer testing -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Sum { signature: Signature, } @@ -135,13 +136,14 @@ impl AggregateUDFImpl for Sum { DataType::Dictionary(_, v) => coerced_type(v), // in the spark, the result type is DECIMAL(min(38,precision+10), s) // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { - Ok(data_type.clone()) - } + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => Ok(data_type.clone()), dt if dt.is_signed_integer() => Ok(DataType::Int64), dt if dt.is_unsigned_integer() => Ok(DataType::UInt64), dt if dt.is_floating() => Ok(DataType::Float64), - _ => exec_err!("Sum not supported for {}", data_type), + _ => exec_err!("Sum not supported for {data_type}"), } } @@ -153,6 +155,18 @@ impl AggregateUDFImpl for Sum { DataType::Int64 => Ok(DataType::Int64), DataType::UInt64 => Ok(DataType::UInt64), DataType::Float64 => Ok(DataType::Float64), + DataType::Decimal32(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal32(new_precision, *scale)) + } + DataType::Decimal64(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal64(new_precision, *scale)) + } DataType::Decimal128(precision, scale) => { // in the spark, the result type is DECIMAL(min(38,precision+10), s) // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 @@ -175,14 +189,10 @@ impl AggregateUDFImpl for Sum { unreachable!("stub should not have accumulate()") } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unreachable!("stub should not have state_fields()") } - fn aliases(&self) -> &[String] { - &[] - } - fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { false } @@ -204,6 +214,7 @@ impl AggregateUDFImpl for Sum { } /// Testing stub implementation of COUNT aggregate +#[derive(PartialEq, Eq, Hash)] pub struct Count { signature: Signature, aliases: Vec, @@ -254,7 +265,7 @@ impl AggregateUDFImpl for Count { false } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } @@ -286,12 +297,13 @@ pub fn min(expr: Expr) -> Expr { vec![expr], false, None, - None, + vec![], None, )) } /// Testing stub implementation of Min aggregate +#[derive(PartialEq, Eq, Hash)] pub struct Min { signature: Signature, } @@ -336,7 +348,7 @@ impl AggregateUDFImpl for Min { Ok(DataType::Int64) } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } @@ -344,10 +356,6 @@ impl AggregateUDFImpl for Min { not_impl_err!("no impl for stub") } - fn aliases(&self) -> &[String] { - &[] - } - fn create_groups_accumulator( &self, _args: AccumulatorArgs, @@ -371,12 +379,13 @@ pub fn max(expr: Expr) -> Expr { vec![expr], false, None, - None, + vec![], None, )) } /// Testing stub implementation of MAX aggregate +#[derive(PartialEq, Eq, Hash)] pub struct Max { signature: Signature, } @@ -421,7 +430,7 @@ impl AggregateUDFImpl for Max { Ok(DataType::Int64) } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } @@ -429,10 +438,6 @@ impl AggregateUDFImpl for Max { not_impl_err!("no impl for stub") } - fn aliases(&self) -> &[String] { - &[] - } - fn create_groups_accumulator( &self, _args: AccumulatorArgs, @@ -449,7 +454,7 @@ impl AggregateUDFImpl for Max { } /// Testing stub implementation of avg aggregate -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Avg { signature: Signature, aliases: Vec, @@ -491,9 +496,10 @@ impl AggregateUDFImpl for Avg { not_impl_err!("no impl for stub") } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } + fn aliases(&self) -> &[String] { &self.aliases } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f20dab7e165fc..81846b4f80608 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -22,7 +22,7 @@ use crate::expr::{ GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; -use crate::{Expr, ExprFunctionExt}; +use crate::Expr; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, @@ -73,7 +73,7 @@ impl TreeNode for Expr { // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } @@ -92,14 +92,17 @@ impl TreeNode for Expr { (expr, when_then_expr, else_expr).apply_ref_elements(f), Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) => (args, filter, order_by).apply_ref_elements(f), - Expr::WindowFunction(WindowFunction { - params : WindowFunctionParams { + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { args, partition_by, order_by, - ..}, ..}) => { - (args, partition_by, order_by).apply_ref_elements(f) + filter, + .. + } = &window_fun.as_ref().params; + (args, partition_by, order_by, filter).apply_ref_elements(f) } + Expr::InList(InList { expr, list, .. }) => { (expr, list).apply_ref_elements(f) } @@ -124,7 +127,7 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_) => Transformed::no(self), + | Expr::Literal(_, _) => Transformed::no(self), Expr::Unnest(Unnest { expr, .. }) => expr .map_elements(f)? .update_data(|expr| Expr::Unnest(Unnest { expr })), @@ -230,27 +233,40 @@ impl TreeNode for Expr { ))) })? } - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - }, - }) => (args, partition_by, order_by).map_elements(f)?.update_data( - |(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new(fun, new_args)) - .partition_by(new_partition_by) - .order_by(new_order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - .unwrap() - }, - ), + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + filter, + null_treatment, + distinct, + }, + } = *window_fun; + + (args, partition_by, order_by, filter) + .map_elements(f)? + .map_data( + |(new_args, new_partition_by, new_order_by, new_filter)| { + Ok(Expr::from(WindowFunction { + fun, + params: WindowFunctionParams { + args: new_args, + partition_by: new_partition_by, + order_by: new_order_by, + window_frame, + filter: new_filter, + null_treatment, + distinct, + }, + })) + }, + )? + } Expr::AggregateFunction(AggregateFunction { func, params: diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 3b34718062eb4..bcaff11bcdb49 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -15,24 +15,28 @@ // specific language governing permissions and limitations // under the License. -use super::binary::{binary_numeric_coercion, comparison_coercion}; +use super::binary::binary_numeric_coercion; use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; +use arrow::datatypes::FieldRef; use arrow::{ compute::can_cast_types, - datatypes::{DataType, Field, TimeUnit}, + datatypes::{DataType, TimeUnit}, }; use datafusion_common::types::LogicalType; -use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion}; +use datafusion_common::utils::{ + base_type, coerced_fixed_size_list_to_list, ListCoercion, +}; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType, - utils::list_ndims, Result, + exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims, Result, }; use datafusion_expr_common::signature::ArrayFunctionArgument; +use datafusion_expr_common::type_coercion::binary::type_union_resolution; use datafusion_expr_common::{ signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD}, type_coercion::binary::comparison_coercion_numeric, type_coercion::binary::string_coercion, }; +use itertools::Itertools as _; use std::sync::Arc; /// Performs type coercion for scalar function arguments. @@ -75,19 +79,19 @@ pub fn data_types_with_scalar_udf( /// Performs type coercion for aggregate function arguments. /// -/// Returns the data types to which each argument must be coerced to +/// Returns the fields to which each argument must be coerced to /// match `signature`. /// /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. -pub fn data_types_with_aggregate_udf( - current_types: &[DataType], +pub fn fields_with_aggregate_udf( + current_fields: &[FieldRef], func: &AggregateUDF, -) -> Result> { +) -> Result> { let signature = func.signature(); let type_signature = &signature.type_signature; - if current_types.is_empty() && type_signature != &TypeSignature::UserDefined { + if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined { if type_signature.supports_zero_argument() { return Ok(vec![]); } else if type_signature.used_to_support_zero_arguments() { @@ -97,17 +101,32 @@ pub fn data_types_with_aggregate_udf( return plan_err!("'{}' does not support zero arguments", func.name()); } } + let current_types = current_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); let valid_types = - get_valid_types_with_aggregate_udf(type_signature, current_types, func)?; + get_valid_types_with_aggregate_udf(type_signature, ¤t_types, func)?; if valid_types .iter() - .any(|data_type| data_type == current_types) + .any(|data_type| data_type == ¤t_types) { - return Ok(current_types.to_vec()); + return Ok(current_fields.to_vec()); } - try_coerce_types(func.name(), valid_types, current_types, type_signature) + let updated_types = + try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?; + + Ok(current_fields + .iter() + .zip(updated_types) + .map(|(current_field, new_type)| { + current_field.as_ref().clone().with_data_type(new_type) + }) + .map(Arc::new) + .collect()) } /// Performs type coercion for window function arguments. @@ -117,14 +136,14 @@ pub fn data_types_with_aggregate_udf( /// /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. -pub fn data_types_with_window_udf( - current_types: &[DataType], +pub fn fields_with_window_udf( + current_fields: &[FieldRef], func: &WindowUDF, -) -> Result> { +) -> Result> { let signature = func.signature(); let type_signature = &signature.type_signature; - if current_types.is_empty() && type_signature != &TypeSignature::UserDefined { + if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined { if type_signature.supports_zero_argument() { return Ok(vec![]); } else if type_signature.used_to_support_zero_arguments() { @@ -135,16 +154,31 @@ pub fn data_types_with_window_udf( } } + let current_types = current_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); let valid_types = - get_valid_types_with_window_udf(type_signature, current_types, func)?; + get_valid_types_with_window_udf(type_signature, ¤t_types, func)?; if valid_types .iter() - .any(|data_type| data_type == current_types) + .any(|data_type| data_type == ¤t_types) { - return Ok(current_types.to_vec()); + return Ok(current_fields.to_vec()); } - try_coerce_types(func.name(), valid_types, current_types, type_signature) + let updated_types = + try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?; + + Ok(current_fields + .iter() + .zip(updated_types) + .map(|(current_field, new_type)| { + current_field.as_ref().clone().with_data_type(new_type) + }) + .map(Arc::new) + .collect()) } /// Performs type coercion for function arguments. @@ -245,7 +279,8 @@ fn try_coerce_types( // none possible -> Error plan_err!( - "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {current_types:?} to the signature {type_signature:?} failed" + "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {} to the signature {type_signature:?} failed", + current_types.iter().join(", ") ) } @@ -364,98 +399,93 @@ fn get_valid_types( return Ok(vec![vec![]]); } - let array_idx = arguments.iter().enumerate().find_map(|(idx, arg)| { - if *arg == ArrayFunctionArgument::Array { - Some(idx) - } else { - None - } - }); - let Some(array_idx) = array_idx else { - return Err(internal_datafusion_err!("Function '{function_name}' expected at least one argument array argument")); - }; - let Some(array_type) = array(¤t_types[array_idx]) else { - return Ok(vec![vec![]]); - }; - - // We need to find the coerced base type, mainly for cases like: - // `array_append(List(null), i64)` -> `List(i64)` - let mut new_base_type = datafusion_common::utils::base_type(&array_type); - for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) { - match argument_type { - ArrayFunctionArgument::Element | ArrayFunctionArgument::Array => { - new_base_type = - coerce_array_types(function_name, current_type, &new_base_type)?; + let mut large_list = false; + let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList); + let mut list_sizes = Vec::with_capacity(arguments.len()); + let mut element_types = Vec::with_capacity(arguments.len()); + let mut nested_item_nullability = Vec::with_capacity(arguments.len()); + for (argument, current_type) in arguments.iter().zip(current_types.iter()) { + match argument { + ArrayFunctionArgument::Index | ArrayFunctionArgument::String => { + nested_item_nullability.push(None); } - ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {} + ArrayFunctionArgument::Element => { + element_types.push(current_type.clone()); + nested_item_nullability.push(None); + } + ArrayFunctionArgument::Array => match current_type { + DataType::Null => { + element_types.push(DataType::Null); + nested_item_nullability.push(None); + } + DataType::List(field) => { + element_types.push(field.data_type().clone()); + nested_item_nullability.push(Some(field.is_nullable())); + fixed_size = false; + } + DataType::LargeList(field) => { + element_types.push(field.data_type().clone()); + nested_item_nullability.push(Some(field.is_nullable())); + large_list = true; + fixed_size = false; + } + DataType::FixedSizeList(field, size) => { + element_types.push(field.data_type().clone()); + nested_item_nullability.push(Some(field.is_nullable())); + list_sizes.push(*size) + } + arg_type => { + plan_err!("{function_name} does not support type {arg_type}")? + } + }, } } - let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only( - &array_type, - &new_base_type, - array_coercion, - ); - let new_elem_type = match new_array_type { - DataType::List(ref field) - | DataType::LargeList(ref field) - | DataType::FixedSizeList(ref field, _) => field.data_type(), - _ => return Ok(vec![vec![]]), + debug_assert_eq!(nested_item_nullability.len(), arguments.len()); + + let Some(element_type) = type_union_resolution(&element_types) else { + return Ok(vec![vec![]]); + }; + + if !fixed_size { + list_sizes.clear() }; - let mut valid_types = Vec::with_capacity(arguments.len()); - for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) { - let valid_type = match argument_type { - ArrayFunctionArgument::Element => new_elem_type.clone(), - ArrayFunctionArgument::Index => DataType::Int64, - ArrayFunctionArgument::String => DataType::Utf8, - ArrayFunctionArgument::Array => { - let Some(current_type) = array(current_type) else { - return Ok(vec![vec![]]); - }; - let new_type = - datafusion_common::utils::coerced_type_with_base_type_only( - ¤t_type, - &new_base_type, - array_coercion, - ); - // All array arguments must be coercible to the same type - if new_type != new_array_type { - return Ok(vec![vec![]]); + let mut list_sizes = list_sizes.into_iter(); + let valid_types = arguments + .iter() + .zip(current_types.iter()) + .zip(nested_item_nullability) + .map(|((argument_type, current_type), is_nested_item_nullable)| { + match argument_type { + ArrayFunctionArgument::Index => DataType::Int64, + ArrayFunctionArgument::String => DataType::Utf8, + ArrayFunctionArgument::Element => element_type.clone(), + ArrayFunctionArgument::Array => { + if current_type.is_null() { + DataType::Null + } else if large_list { + DataType::new_large_list( + element_type.clone(), + is_nested_item_nullable.unwrap_or(true), + ) + } else if let Some(size) = list_sizes.next() { + DataType::new_fixed_size_list( + element_type.clone(), + size, + is_nested_item_nullable.unwrap_or(true), + ) + } else { + DataType::new_list( + element_type.clone(), + is_nested_item_nullable.unwrap_or(true), + ) + } } - new_type } - }; - valid_types.push(valid_type); - } + }); - Ok(vec![valid_types]) - } - - fn array(array_type: &DataType) -> Option { - match array_type { - DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()), - DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))), - DataType::Null => Some(DataType::List(Arc::new(Field::new_list_field( - DataType::Int64, - true, - )))), - _ => None, - } - } - - fn coerce_array_types( - function_name: &str, - current_type: &DataType, - base_type: &DataType, - ) -> Result { - let current_base_type = datafusion_common::utils::base_type(current_type); - let new_base_type = comparison_coercion(base_type, ¤t_base_type); - new_base_type.ok_or_else(|| { - internal_datafusion_err!( - "Function '{function_name}' does not support coercion from {base_type:?} to {current_base_type:?}" - ) - }) + Ok(vec![valid_types.collect()]) } fn recursive_array(array_type: &DataType) -> Option { @@ -501,7 +531,7 @@ fn get_valid_types( new_types.push(DataType::Utf8); } else { return plan_err!( - "Function '{function_name}' expects NativeType::String but received {logical_data_type}" + "Function '{function_name}' expects NativeType::String but NativeType::received NativeType::{logical_data_type}" ); } } @@ -561,7 +591,7 @@ fn get_valid_types( if !logical_data_type.is_numeric() { return plan_err!( - "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}" + "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}" ); } @@ -582,7 +612,7 @@ fn get_valid_types( valid_type = DataType::Float64; } else if !logical_data_type.is_numeric() { return plan_err!( - "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}" + "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}" ); } @@ -629,7 +659,7 @@ fn get_valid_types( new_types.push(casted_type); } else { return internal_err!( - "Expect {} but received {}, DataType: {}", + "Expect {} but received NativeType::{}, DataType: {}", param.desired_type(), current_native_type, current_type @@ -800,7 +830,7 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { /// /// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32. /// -/// Unlike [comparison_coercion], the coerced type is usually `wider` for lossless conversion. +/// Unlike [crate::binary::comparison_coercion], the coerced type is usually `wider` for lossless conversion. fn coerced_from<'a>( type_into: &'a DataType, type_from: &'a DataType, @@ -849,7 +879,10 @@ fn coerced_from<'a>( | UInt64 | Float32 | Float64 - | Decimal128(_, _), + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _), ) => Some(type_into.clone()), ( Timestamp(TimeUnit::Nanosecond, None), @@ -867,7 +900,7 @@ fn coerced_from<'a>( // Only accept list and largelist with the same number of dimensions unless the type is Null. // List or LargeList with different dimensions should be handled in TypeSignature or other places before this (List(_) | LargeList(_), _) - if datafusion_common::utils::base_type(type_from).eq(&Null) + if base_type(type_from).is_null() || list_ndims(type_from) == list_ndims(type_into) => { Some(type_into.clone()) @@ -906,7 +939,6 @@ fn coerced_from<'a>( #[cfg(test)] mod tests { - use crate::Volatility; use super::*; @@ -1193,4 +1225,167 @@ mod tests { Some(type_into.clone()) ); } + + #[test] + fn test_get_valid_types_array_and_array() -> Result<()> { + let function = "array_and_array"; + let signature = Signature::arrays( + 2, + Some(ListCoercion::FixedSizedListToList), + Volatility::Immutable, + ); + + let data_types = vec![ + DataType::new_list(DataType::Int32, true), + DataType::new_large_list(DataType::Float64, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_large_list(DataType::Float64, true), + DataType::new_large_list(DataType::Float64, true), + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Int64, 3, true), + DataType::new_fixed_size_list(DataType::Int32, 5, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Int64, true), + DataType::new_list(DataType::Int64, true), + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Null, 3, true), + DataType::new_large_list(DataType::Utf8, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_large_list(DataType::Utf8, true), + DataType::new_large_list(DataType::Utf8, true), + ]] + ); + + Ok(()) + } + + #[test] + fn test_get_valid_types_array_and_element() -> Result<()> { + let function = "array_and_element"; + let signature = Signature::array_and_element(Volatility::Immutable); + + let data_types = + vec![DataType::new_list(DataType::Int32, true), DataType::Float64]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Float64, true), + DataType::Float64, + ]] + ); + + let data_types = vec![ + DataType::new_large_list(DataType::Int32, true), + DataType::Null, + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_large_list(DataType::Int32, true), + DataType::Int32, + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Null, 3, true), + DataType::Utf8, + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Utf8, true), + DataType::Utf8, + ]] + ); + + Ok(()) + } + + #[test] + fn test_get_valid_types_element_and_array() -> Result<()> { + let function = "element_and_array"; + let signature = Signature::element_and_array(Volatility::Immutable); + + let data_types = vec![ + DataType::new_large_list(DataType::Null, false), + DataType::new_list(DataType::new_list(DataType::Int64, true), true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_large_list(DataType::Int64, true), + DataType::new_list(DataType::new_large_list(DataType::Int64, true), true), + ]] + ); + + Ok(()) + } + + #[test] + fn test_get_valid_types_fixed_size_arrays() -> Result<()> { + let function = "fixed_size_arrays"; + let signature = Signature::arrays(2, None, Volatility::Immutable); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Int64, 3, true), + DataType::new_fixed_size_list(DataType::Int32, 5, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_fixed_size_list(DataType::Int64, 3, true), + DataType::new_fixed_size_list(DataType::Int64, 5, true), + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Int64, 3, true), + DataType::new_list(DataType::Int32, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Int64, true), + DataType::new_list(DataType::Int64, true), + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Utf8, 3, true), + DataType::new_list(DataType::new_list(DataType::Int32, true), true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Int64, 3, false), + DataType::new_list(DataType::Int32, false), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Int64, false), + DataType::new_list(DataType::Int64, false), + ]] + ); + + Ok(()) + } } diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index 4fc150ef2996a..bd1acd3f3a2e2 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -51,6 +51,8 @@ pub fn is_signed_numeric(dt: &DataType) -> bool { | DataType::Float16 | DataType::Float32 | DataType::Float64 + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) | DataType::Decimal256(_, _), ) @@ -89,5 +91,11 @@ pub fn is_utf8_or_utf8view_or_large_utf8(dt: &DataType) -> bool { /// Determine whether the given data type `dt` is a `Decimal`. pub fn is_decimal(dt: &DataType) -> bool { - matches!(dt, DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) + matches!( + dt, + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + ) } diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b75e8fd3cd3c4..bfd699d814855 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -20,13 +20,14 @@ use std::any::Any; use std::cmp::Ordering; use std::fmt::{self, Debug, Formatter, Write}; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use std::vec; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics}; +use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use crate::expr::{ @@ -38,6 +39,7 @@ use crate::function::{ AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, }; use crate::groups_accumulator::GroupsAccumulator; +use crate::udf_eq::UdfEq; use crate::utils::format_state_name; use crate::utils::AggregateOrderSensitivity; use crate::{expr_vec_fmt, Accumulator, Expr}; @@ -70,7 +72,7 @@ use crate::{Documentation, Signature}; /// /// [the examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples#single-process /// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function -/// [`Accumulator`]: crate::Accumulator +/// [`Accumulator`]: Accumulator /// [`create_udaf`]: crate::expr_fn::create_udaf /// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs /// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs @@ -81,7 +83,7 @@ pub struct AggregateUDF { impl PartialEq for AggregateUDF { fn eq(&self, other: &Self) -> bool { - self.inner.equals(other.inner.as_ref()) + self.inner.dyn_eq(other.inner.as_any()) } } @@ -89,7 +91,7 @@ impl Eq for AggregateUDF {} impl Hash for AggregateUDF { fn hash(&self, state: &mut H) { - self.inner.hash_value().hash(state) + self.inner.dyn_hash(state) } } @@ -158,7 +160,7 @@ impl AggregateUDF { args, false, None, - None, + vec![], None, )) } @@ -170,6 +172,11 @@ impl AggregateUDF { self.inner.name() } + /// Returns the aliases for this function. + pub fn aliases(&self) -> &[String] { + self.inner.aliases() + } + /// See [`AggregateUDFImpl::schema_name`] for more details. pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result { self.inner.schema_name(params) @@ -205,11 +212,6 @@ impl AggregateUDF { self.inner.is_nullable() } - /// Returns the aliases for this function. - pub fn aliases(&self) -> &[String] { - self.inner.aliases() - } - /// Returns this function's signature (what input types are accepted) /// /// See [`AggregateUDFImpl::signature`] for more details. @@ -224,6 +226,13 @@ impl AggregateUDF { self.inner.return_type(args) } + /// Return the field of the function given its input fields + /// + /// See [`AggregateUDFImpl::return_field`] for more details. + pub fn return_field(&self, args: &[FieldRef]) -> Result { + self.inner.return_field(args) + } + /// Return an accumulator the given aggregate, given its return datatype pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { self.inner.accumulator(acc_args) @@ -234,7 +243,7 @@ impl AggregateUDF { /// for more details. /// /// This is used to support multi-phase aggregations - pub fn state_fields(&self, args: StateFieldsArgs) -> Result> { + pub fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.inner.state_fields(args) } @@ -315,6 +324,16 @@ impl AggregateUDF { self.inner.default_value(data_type) } + /// See [`AggregateUDFImpl::supports_null_handling_clause`] for more details. + pub fn supports_null_handling_clause(&self) -> bool { + self.inner.supports_null_handling_clause() + } + + /// See [`AggregateUDFImpl::is_ordered_set_aggregate`] for more details. + pub fn is_ordered_set_aggregate(&self) -> bool { + self.inner.is_ordered_set_aggregate() + } + /// Returns the documentation for this Aggregate UDF. /// /// Documentation can be accessed programmatically as well as @@ -346,8 +365,8 @@ where /// # Basic Example /// ``` /// # use std::any::Any; -/// # use std::sync::LazyLock; -/// # use arrow::datatypes::DataType; +/// # use std::sync::{Arc, LazyLock}; +/// # use arrow::datatypes::{DataType, FieldRef}; /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation}; /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}}; @@ -355,7 +374,7 @@ where /// # use arrow::datatypes::Schema; /// # use arrow::datatypes::Field; /// -/// #[derive(Debug, Clone)] +/// #[derive(Debug, Clone, PartialEq, Eq, Hash)] /// struct GeoMeanUdf { /// signature: Signature, /// } @@ -377,7 +396,7 @@ where /// fn get_doc() -> &'static Documentation { /// &DOCUMENTATION /// } -/// +/// /// /// Implement the AggregateUDFImpl trait for GeoMeanUdf /// impl AggregateUDFImpl for GeoMeanUdf { /// fn as_any(&self) -> &dyn Any { self } @@ -391,14 +410,14 @@ where /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. /// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { unimplemented!() } -/// fn state_fields(&self, args: StateFieldsArgs) -> Result> { +/// fn state_fields(&self, args: StateFieldsArgs) -> Result> { /// Ok(vec![ -/// Field::new("value", args.return_type.clone(), true), -/// Field::new("ordering", DataType::UInt32, true) +/// Arc::new(args.return_field.as_ref().clone().with_name("value")), +/// Arc::new(Field::new("ordering", DataType::UInt32, true)) /// ]) /// } /// fn documentation(&self) -> Option<&Documentation> { -/// Some(get_doc()) +/// Some(get_doc()) /// } /// } /// @@ -408,94 +427,35 @@ where /// // Call the function `geo_mean(col)` /// let expr = geometric_mean.call(vec![col("a")]); /// ``` -pub trait AggregateUDFImpl: Debug + Send + Sync { - // Note: When adding any methods (with default implementations), remember to add them also - // into the AliasedAggregateUDFImpl below! - +pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync { /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; /// Returns this function's name fn name(&self) -> &str; + /// Returns any aliases (alternate names) for this function. + /// + /// Note: `aliases` should only include names other than [`Self::name`]. + /// Defaults to `[]` (no aliases) + fn aliases(&self) -> &[String] { + &[] + } + /// Returns the name of the column this expression would create /// /// See [`Expr::schema_name`] for details /// /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..] fn schema_name(&self, params: &AggregateFunctionParams) -> Result { - let AggregateFunctionParams { - args, - distinct, - filter, - order_by, - null_treatment, - } = params; - - let mut schema_name = String::new(); - - schema_name.write_fmt(format_args!( - "{}({}{})", - self.name(), - if *distinct { "DISTINCT " } else { "" }, - schema_name_from_exprs_comma_separated_without_space(args)? - ))?; - - if let Some(null_treatment) = null_treatment { - schema_name.write_fmt(format_args!(" {}", null_treatment))?; - } - - if let Some(filter) = filter { - schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?; - }; - - if let Some(order_by) = order_by { - schema_name.write_fmt(format_args!( - " ORDER BY [{}]", - schema_name_from_sorts(order_by)? - ))?; - }; - - Ok(schema_name) + udaf_default_schema_name(self, params) } /// Returns a human readable expression. /// /// See [`Expr::human_display`] for details. fn human_display(&self, params: &AggregateFunctionParams) -> Result { - let AggregateFunctionParams { - args, - distinct, - filter, - order_by, - null_treatment, - } = params; - - let mut schema_name = String::new(); - - schema_name.write_fmt(format_args!( - "{}({}{})", - self.name(), - if *distinct { "DISTINCT " } else { "" }, - ExprListDisplay::comma_separated(args.as_slice()) - ))?; - - if let Some(null_treatment) = null_treatment { - schema_name.write_fmt(format_args!(" {}", null_treatment))?; - } - - if let Some(filter) = filter { - schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?; - }; - - if let Some(order_by) = order_by { - schema_name.write_fmt(format_args!( - " ORDER BY [{}]", - schema_name_from_sorts(order_by)? - ))?; - }; - - Ok(schema_name) + udaf_default_human_display(self, params) } /// Returns the name of the column this expression would create @@ -509,42 +469,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { &self, params: &WindowFunctionParams, ) -> Result { - let WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - } = params; - - let mut schema_name = String::new(); - schema_name.write_fmt(format_args!( - "{}({})", - self.name(), - schema_name_from_exprs(args)? - ))?; - - if let Some(null_treatment) = null_treatment { - schema_name.write_fmt(format_args!(" {}", null_treatment))?; - } - - if !partition_by.is_empty() { - schema_name.write_fmt(format_args!( - " PARTITION BY [{}]", - schema_name_from_exprs(partition_by)? - ))?; - } - - if !order_by.is_empty() { - schema_name.write_fmt(format_args!( - " ORDER BY [{}]", - schema_name_from_sorts(order_by)? - ))?; - }; - - schema_name.write_fmt(format_args!(" {window_frame}"))?; - - Ok(schema_name) + udaf_default_window_function_schema_name(self, params) } /// Returns the user-defined display name of function, given the arguments @@ -554,40 +479,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]` fn display_name(&self, params: &AggregateFunctionParams) -> Result { - let AggregateFunctionParams { - args, - distinct, - filter, - order_by, - null_treatment, - } = params; - - let mut display_name = String::new(); - - display_name.write_fmt(format_args!( - "{}({}{})", - self.name(), - if *distinct { "DISTINCT " } else { "" }, - expr_vec_fmt!(args) - ))?; - - if let Some(nt) = null_treatment { - display_name.write_fmt(format_args!(" {}", nt))?; - } - if let Some(fe) = filter { - display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?; - } - if let Some(ob) = order_by { - display_name.write_fmt(format_args!( - " ORDER BY [{}]", - ob.iter() - .map(|o| format!("{o}")) - .collect::>() - .join(", ") - ))?; - } - - Ok(display_name) + udaf_default_display_name(self, params) } /// Returns the user-defined display name of function, given the arguments @@ -602,44 +494,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { &self, params: &WindowFunctionParams, ) -> Result { - let WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - } = params; - - let mut display_name = String::new(); - - display_name.write_fmt(format_args!( - "{}({})", - self.name(), - expr_vec_fmt!(args) - ))?; - - if let Some(null_treatment) = null_treatment { - display_name.write_fmt(format_args!(" {}", null_treatment))?; - } - - if !partition_by.is_empty() { - display_name.write_fmt(format_args!( - " PARTITION BY [{}]", - expr_vec_fmt!(partition_by) - ))?; - } - - if !order_by.is_empty() { - display_name - .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?; - }; - - display_name.write_fmt(format_args!( - " {} BETWEEN {} AND {}", - window_frame.units, window_frame.start_bound, window_frame.end_bound - ))?; - - Ok(display_name) + udaf_default_window_function_display_name(self, params) } /// Returns the function's [`Signature`] for information about what input @@ -650,6 +505,27 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// the arguments fn return_type(&self, arg_types: &[DataType]) -> Result; + /// What type will be returned by this function, given the arguments? + /// + /// By default, this function calls [`Self::return_type`] with the + /// types of each argument. + /// + /// # Notes + /// + /// Most UDFs should implement [`Self::return_type`] and not this + /// function as the output type for most functions only depends on the types + /// of their inputs (e.g. `sum(f64)` is always `f64`). + /// + /// This function can be used for more advanced cases such as: + /// + /// 1. specifying nullability + /// 2. return types based on the **values** of the arguments (rather than + /// their **types**. + /// 3. return types based on metadata within the fields of the inputs + fn return_field(&self, arg_fields: &[FieldRef]) -> Result { + udaf_default_return_field(self, arg_fields) + } + /// Whether the aggregate function is nullable. /// /// Nullable means that the function could return `null` for any inputs. @@ -688,15 +564,16 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// The name of the fields must be unique within the query and thus should /// be derived from `name`. See [`format_state_name`] for a utility function /// to generate a unique name. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { - let fields = vec![Field::new( - format_state_name(args.name, "value"), - args.return_type.clone(), - true, - )]; + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let fields = vec![args + .return_field + .as_ref() + .clone() + .with_name(format_state_name(args.name, "value"))]; Ok(fields .into_iter() + .map(Arc::new) .chain(args.ordering_fields.to_vec()) .collect()) } @@ -727,20 +604,12 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") } - /// Returns any aliases (alternate names) for this function. - /// - /// Note: `aliases` should only include names other than [`Self::name`]. - /// Defaults to `[]` (no aliases) - fn aliases(&self) -> &[String] { - &[] - } - /// Sliding accumulator is an alternative accumulator that can be used for /// window functions. It has retract method to revert the previous update. /// /// See [retract_batch] for more details. /// - /// [retract_batch]: datafusion_expr_common::accumulator::Accumulator::retract_batch + /// [retract_batch]: Accumulator::retract_batch fn create_sliding_accumulator( &self, args: AccumulatorArgs, @@ -797,11 +666,18 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// [None] if simplify is not defined or, /// /// Or, a closure with two arguments: - /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked + /// * 'aggregate_function': [AggregateFunction] for which simplified has been invoked /// * 'info': [crate::simplify::SimplifyInfo] /// /// closure returns simplified [Expr] or an error. /// + /// # Notes + /// + /// The returned expression must have the same schema as the original + /// expression, including both the data type and nullability. For example, + /// if the original expression is nullable, the returned expression must + /// also be nullable, otherwise it may lead to schema verification errors + /// later in query planning. fn simplify(&self) -> Option { None } @@ -834,33 +710,6 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { not_impl_err!("Function {} does not implement coerce_types", self.name()) } - /// Return true if this aggregate UDF is equal to the other. - /// - /// Allows customizing the equality of aggregate UDFs. - /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: - /// - /// - reflexive: `a.equals(a)`; - /// - symmetric: `a.equals(b)` implies `b.equals(a)`; - /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. - /// - /// By default, compares [`Self::name`] and [`Self::signature`]. - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - self.name() == other.name() && self.signature() == other.signature() - } - - /// Returns a hash value for this aggregate UDF. - /// - /// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`], - /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. - /// - /// By default, hashes [`Self::name`] and [`Self::signature`]. - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.name().hash(hasher); - self.signature().hash(hasher); - hasher.finish() - } - /// If this function is max, return true /// If the function is min, return false /// Otherwise return None (the default) @@ -891,6 +740,31 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { ScalarValue::try_from(data_type) } + /// If this function supports `[IGNORE NULLS | RESPECT NULLS]` clause, return true + /// If the function does not, return false + fn supports_null_handling_clause(&self) -> bool { + true + } + + /// If this function is ordered-set aggregate function, return true + /// otherwise, return false + /// + /// Ordered-set aggregate functions require an explicit `ORDER BY` clause + /// because the calculation performed by these functions is dependent on the + /// specific sequence of the input rows, unlike other aggregate functions + /// like `SUM`, `AVG`, or `COUNT`. + /// + /// An example of an ordered-set aggregate function is `percentile_cont` + /// which computes a specific percentile value from a sorted list of values, and + /// is only meaningful when the input data is ordered. + /// + /// In SQL syntax, ordered-set aggregate functions are used with the + /// `WITHIN GROUP (ORDER BY ...)` clause to specify the ordering of the input + /// data. + fn is_ordered_set_aggregate(&self) -> bool { + false + } + /// Returns the documentation for this Aggregate UDF. /// /// Documentation can be accessed programmatically as well as @@ -908,22 +782,290 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { impl PartialEq for dyn AggregateUDFImpl { fn eq(&self, other: &Self) -> bool { - self.equals(other) + self.dyn_eq(other.as_any()) } } -// Manual implementation of `PartialOrd` -// There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl -// https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5 impl PartialOrd for dyn AggregateUDFImpl { fn partial_cmp(&self, other: &Self) -> Option { match self.name().partial_cmp(other.name()) { Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()), cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } +/// Encapsulates default implementation of [`AggregateUDFImpl::schema_name`]. +pub fn udaf_default_schema_name( + func: &F, + params: &AggregateFunctionParams, +) -> Result { + let AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + } = params; + + // exclude the first function argument(= column) in ordered set aggregate function, + // because it is duplicated with the WITHIN GROUP clause in schema name. + let args = if func.is_ordered_set_aggregate() && !order_by.is_empty() { + &args[1..] + } else { + &args[..] + }; + + let mut schema_name = String::new(); + + schema_name.write_fmt(format_args!( + "{}({}{})", + func.name(), + if *distinct { "DISTINCT " } else { "" }, + schema_name_from_exprs_comma_separated_without_space(args)? + ))?; + + if let Some(null_treatment) = null_treatment { + schema_name.write_fmt(format_args!(" {null_treatment}"))?; + } + + if let Some(filter) = filter { + schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?; + }; + + if !order_by.is_empty() { + let clause = match func.is_ordered_set_aggregate() { + true => "WITHIN GROUP", + false => "ORDER BY", + }; + + schema_name.write_fmt(format_args!( + " {} [{}]", + clause, + schema_name_from_sorts(order_by)? + ))?; + }; + + Ok(schema_name) +} + +/// Encapsulates default implementation of [`AggregateUDFImpl::human_display`]. +pub fn udaf_default_human_display( + func: &F, + params: &AggregateFunctionParams, +) -> Result { + let AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + } = params; + + let mut schema_name = String::new(); + + schema_name.write_fmt(format_args!( + "{}({}{})", + func.name(), + if *distinct { "DISTINCT " } else { "" }, + ExprListDisplay::comma_separated(args.as_slice()) + ))?; + + if let Some(null_treatment) = null_treatment { + schema_name.write_fmt(format_args!(" {null_treatment}"))?; + } + + if let Some(filter) = filter { + schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?; + }; + + if !order_by.is_empty() { + schema_name.write_fmt(format_args!( + " ORDER BY [{}]", + schema_name_from_sorts(order_by)? + ))?; + }; + + Ok(schema_name) +} + +/// Encapsulates default implementation of [`AggregateUDFImpl::window_function_schema_name`]. +pub fn udaf_default_window_function_schema_name( + func: &F, + params: &WindowFunctionParams, +) -> Result { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + filter, + null_treatment, + distinct, + } = params; + + let mut schema_name = String::new(); + + // Inject DISTINCT into the schema name when requested + if *distinct { + schema_name.write_fmt(format_args!( + "{}(DISTINCT {})", + func.name(), + schema_name_from_exprs(args)? + ))?; + } else { + schema_name.write_fmt(format_args!( + "{}({})", + func.name(), + schema_name_from_exprs(args)? + ))?; + } + + if let Some(null_treatment) = null_treatment { + schema_name.write_fmt(format_args!(" {null_treatment}"))?; + } + + if let Some(filter) = filter { + schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?; + } + + if !partition_by.is_empty() { + schema_name.write_fmt(format_args!( + " PARTITION BY [{}]", + schema_name_from_exprs(partition_by)? + ))?; + } + + if !order_by.is_empty() { + schema_name.write_fmt(format_args!( + " ORDER BY [{}]", + schema_name_from_sorts(order_by)? + ))?; + } + + schema_name.write_fmt(format_args!(" {window_frame}"))?; + + Ok(schema_name) +} + +/// Encapsulates default implementation of [`AggregateUDFImpl::display_name`]. +pub fn udaf_default_display_name( + func: &F, + params: &AggregateFunctionParams, +) -> Result { + let AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + } = params; + + let mut display_name = String::new(); + + display_name.write_fmt(format_args!( + "{}({}{})", + func.name(), + if *distinct { "DISTINCT " } else { "" }, + expr_vec_fmt!(args) + ))?; + + if let Some(nt) = null_treatment { + display_name.write_fmt(format_args!(" {nt}"))?; + } + if let Some(fe) = filter { + display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?; + } + if !order_by.is_empty() { + display_name.write_fmt(format_args!( + " ORDER BY [{}]", + order_by + .iter() + .map(|o| format!("{o}")) + .collect::>() + .join(", ") + ))?; + } + + Ok(display_name) +} + +/// Encapsulates default implementation of [`AggregateUDFImpl::window_function_display_name`]. +pub fn udaf_default_window_function_display_name( + func: &F, + params: &WindowFunctionParams, +) -> Result { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + filter, + null_treatment, + distinct, + } = params; + + let mut display_name = String::new(); + + if *distinct { + display_name.write_fmt(format_args!( + "{}(DISTINCT {})", + func.name(), + expr_vec_fmt!(args) + ))?; + } else { + display_name.write_fmt(format_args!( + "{}({})", + func.name(), + expr_vec_fmt!(args) + ))?; + } + + if let Some(null_treatment) = null_treatment { + display_name.write_fmt(format_args!(" {null_treatment}"))?; + } + + if let Some(fe) = filter { + display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?; + } + + if !partition_by.is_empty() { + display_name.write_fmt(format_args!( + " PARTITION BY [{}]", + expr_vec_fmt!(partition_by) + ))?; + } + + if !order_by.is_empty() { + display_name + .write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?; + }; + + display_name.write_fmt(format_args!( + " {} BETWEEN {} AND {}", + window_frame.units, window_frame.start_bound, window_frame.end_bound + ))?; + + Ok(display_name) +} + +/// Encapsulates default implementation of [`AggregateUDFImpl::return_field`]. +pub fn udaf_default_return_field( + func: &F, + arg_fields: &[FieldRef], +) -> Result { + let arg_types: Vec<_> = arg_fields.iter().map(|f| f.data_type()).cloned().collect(); + let data_type = func.return_type(&arg_types)?; + + Ok(Arc::new(Field::new( + func.name(), + data_type, + func.is_nullable(), + ))) +} + pub enum ReversedUDAF { /// The expression is the same as the original expression, like SUM, COUNT Identical, @@ -935,9 +1077,9 @@ pub enum ReversedUDAF { /// AggregateUDF that adds an alias to the underlying function. It is better to /// implement [`AggregateUDFImpl`], which supports aliases, directly if possible. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct AliasedAggregateUDFImpl { - inner: Arc, + inner: UdfEq>, aliases: Vec, } @@ -949,10 +1091,14 @@ impl AliasedAggregateUDFImpl { let mut aliases = inner.aliases().to_vec(); aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); - Self { inner, aliases } + Self { + inner: inner.into(), + aliases, + } } } +#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method impl AggregateUDFImpl for AliasedAggregateUDFImpl { fn as_any(&self) -> &dyn Any { self @@ -978,7 +1124,33 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { &self.aliases } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn schema_name(&self, params: &AggregateFunctionParams) -> Result { + self.inner.schema_name(params) + } + + fn human_display(&self, params: &AggregateFunctionParams) -> Result { + self.inner.human_display(params) + } + + fn window_function_schema_name( + &self, + params: &WindowFunctionParams, + ) -> Result { + self.inner.window_function_schema_name(params) + } + + fn display_name(&self, params: &AggregateFunctionParams) -> Result { + self.inner.display_name(params) + } + + fn window_function_display_name( + &self, + params: &WindowFunctionParams, + ) -> Result { + self.inner.window_function_display_name(params) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.inner.state_fields(args) } @@ -1009,7 +1181,7 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { .map(|udf| { udf.map(|udf| { Arc::new(AliasedAggregateUDFImpl { - inner: udf, + inner: udf.into(), aliases: self.aliases.clone(), }) as Arc }) @@ -1032,59 +1204,41 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { self.inner.coerce_types(arg_types) } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - if let Some(other) = other.as_any().downcast_ref::() { - self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases - } else { - false - } + fn return_field(&self, arg_fields: &[FieldRef]) -> Result { + self.inner.return_field(arg_fields) } - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.inner.hash_value().hash(hasher); - self.aliases.hash(hasher); - hasher.finish() + fn is_nullable(&self) -> bool { + self.inner.is_nullable() } fn is_descending(&self) -> Option { self.inner.is_descending() } - fn documentation(&self) -> Option<&Documentation> { - self.inner.documentation() + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { + self.inner.value_from_stats(statistics_args) } -} -// Aggregate UDF doc sections for use in public documentation -pub mod aggregate_doc_sections { - use crate::DocSection; + fn default_value(&self, data_type: &DataType) -> Result { + self.inner.default_value(data_type) + } - pub fn doc_sections() -> Vec { - vec![ - DOC_SECTION_GENERAL, - DOC_SECTION_STATISTICAL, - DOC_SECTION_APPROXIMATE, - ] + fn supports_null_handling_clause(&self) -> bool { + self.inner.supports_null_handling_clause() } - pub const DOC_SECTION_GENERAL: DocSection = DocSection { - include: true, - label: "General Functions", - description: None, - }; + fn is_ordered_set_aggregate(&self) -> bool { + self.inner.is_ordered_set_aggregate() + } - pub const DOC_SECTION_STATISTICAL: DocSection = DocSection { - include: true, - label: "Statistical Functions", - description: None, - }; + fn set_monotonicity(&self, data_type: &DataType) -> SetMonotonicity { + self.inner.set_monotonicity(data_type) + } - pub const DOC_SECTION_APPROXIMATE: DocSection = DocSection { - include: true, - label: "Approximate Functions", - description: None, - }; + fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } } /// Indicates whether an aggregation function is monotonic as a set @@ -1111,7 +1265,7 @@ pub enum SetMonotonicity { #[cfg(test)] mod test { use crate::{AggregateUDF, AggregateUDFImpl}; - use arrow::datatypes::{DataType, Field}; + use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::Result; use datafusion_expr_common::accumulator::Accumulator; use datafusion_expr_common::signature::{Signature, Volatility}; @@ -1120,8 +1274,9 @@ mod test { }; use std::any::Any; use std::cmp::Ordering; + use std::hash::{DefaultHasher, Hash, Hasher}; - #[derive(Debug, Clone)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct AMeanUdf { signature: Signature, } @@ -1157,12 +1312,12 @@ mod test { ) -> Result> { unimplemented!() } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!() } } - #[derive(Debug, Clone)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct BMeanUdf { signature: Signature, } @@ -1197,11 +1352,21 @@ mod test { ) -> Result> { unimplemented!() } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!() } } + #[test] + fn test_partial_eq() { + let a1 = AggregateUDF::from(AMeanUdf::new()); + let a2 = AggregateUDF::from(AMeanUdf::new()); + let eq = a1 == a2; + assert!(eq); + assert_eq!(a1, a2); + assert_eq!(hash(a1), hash(a2)); + } + #[test] fn test_partial_ord() { // Test validates that partial ord is defined for AggregateUDF using the name and signature, @@ -1214,4 +1379,10 @@ mod test { assert!(a1 < b1); assert!(!(a1 == b1)); } + + fn hash(value: T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 9b2400774a3d6..d522158f7b6b7 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,24 +17,28 @@ //! [`ScalarUDF`]: Scalar User Defined Functions +use crate::async_udf::AsyncScalarUDF; use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; +use crate::udf_eq::UdfEq; use crate::{ColumnarValue, Documentation, Expr, Signature}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::config::ConfigOptions; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; +use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; use std::cmp::Ordering; use std::fmt::Debug; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::{Hash, Hasher}; use std::sync::Arc; /// Logical representation of a Scalar User Defined Function. /// /// A scalar function produces a single row output for each row of input. This /// struct contains the information DataFusion needs to plan and invoke -/// functions you supply such name, type signature, return type, and actual +/// functions you supply such as name, type signature, return type, and actual /// implementation. /// /// 1. For simple use cases, use [`create_udf`] (examples in [`simple_udf.rs`]). @@ -42,11 +46,11 @@ use std::sync::Arc; /// 2. For advanced use cases, use [`ScalarUDFImpl`] which provides full API /// access (examples in [`advanced_udf.rs`]). /// -/// See [`Self::call`] to invoke a `ScalarUDF` with arguments. +/// See [`Self::call`] to create an `Expr` which invokes a `ScalarUDF` with arguments. /// /// # API Note /// -/// This is a separate struct from `ScalarUDFImpl` to maintain backwards +/// This is a separate struct from [`ScalarUDFImpl`] to maintain backwards /// compatibility with the older API. /// /// [`create_udf`]: crate::expr_fn::create_udf @@ -59,17 +63,35 @@ pub struct ScalarUDF { impl PartialEq for ScalarUDF { fn eq(&self, other: &Self) -> bool { - self.inner.equals(other.inner.as_ref()) + self.inner.dyn_eq(other.inner.as_any()) } } -// Manual implementation based on `ScalarUDFImpl::equals` impl PartialOrd for ScalarUDF { fn partial_cmp(&self, other: &Self) -> Option { - match self.name().partial_cmp(other.name()) { - Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()), - cmp => cmp, + let mut cmp = self.name().cmp(other.name()); + if cmp == Ordering::Equal { + cmp = self.signature().partial_cmp(other.signature())?; } + if cmp == Ordering::Equal { + cmp = self.aliases().partial_cmp(other.aliases())?; + } + // Contract for PartialOrd and PartialEq consistency requires that + // a == b if and only if partial_cmp(a, b) == Some(Equal). + if cmp == Ordering::Equal && self != other { + // Functions may have other properties besides name and signature + // that differentiate two instances (e.g. type, or arbitrary parameters). + // We cannot return Some(Equal) in such case. + return None; + } + debug_assert!( + cmp == Ordering::Equal || self != other, + "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \ + The functions compare as equal, but they are not equal based on general properties that \ + the PartialOrd implementation observes,", + self.name(), other.name() + ); + Some(cmp) } } @@ -77,7 +99,7 @@ impl Eq for ScalarUDF {} impl Hash for ScalarUDF { fn hash(&self, state: &mut H) { - self.inner.hash_value().hash(state) + self.inner.dyn_hash(state) } } @@ -140,7 +162,12 @@ impl ScalarUDF { /// Returns this function's display_name. /// /// See [`ScalarUDFImpl::display_name`] for more details + #[deprecated( + since = "50.0.0", + note = "This method is unused and will be removed in a future release" + )] pub fn display_name(&self, args: &[Expr]) -> Result { + #[expect(deprecated)] self.inner.display_name(args) } @@ -170,7 +197,7 @@ impl ScalarUDF { /// /// # Notes /// - /// If a function implement [`ScalarUDFImpl::return_type_from_args`], + /// If a function implement [`ScalarUDFImpl::return_field_from_args`], /// its [`ScalarUDFImpl::return_type`] should raise an error. /// /// See [`ScalarUDFImpl::return_type`] for more details. @@ -180,9 +207,9 @@ impl ScalarUDF { /// Return the datatype this function returns given the input argument types. /// - /// See [`ScalarUDFImpl::return_type_from_args`] for more details. - pub fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - self.inner.return_type_from_args(args) + /// See [`ScalarUDFImpl::return_field_from_args`] for more details. + pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.inner.return_field_from_args(args) } /// Do the function rewrite @@ -196,8 +223,9 @@ impl ScalarUDF { self.inner.simplify(args, info) } - #[allow(deprecated)] + #[deprecated(since = "50.0.0", note = "Use `return_field_from_args` instead.")] pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { + #[allow(deprecated)] self.inner.is_nullable(args, schema) } @@ -205,7 +233,23 @@ impl ScalarUDF { /// /// See [`ScalarUDFImpl::invoke_with_args`] for details. pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - self.inner.invoke_with_args(args) + #[cfg(debug_assertions)] + let return_field = Arc::clone(&args.return_field); + let result = self.inner.invoke_with_args(args)?; + // Maybe this could be enabled always? + // This doesn't use debug_assert!, but it's meant to run anywhere except on production. It's same in spirit, thus conditioning on debug_assertions. + #[cfg(debug_assertions)] + { + if &result.data_type() != return_field.data_type() { + return datafusion_common::internal_err!("Function '{}' returned value of type '{:?}' while the following type was promised at planning time and expected: '{:?}'", + self.name(), + result.data_type(), + return_field.data_type() + ); + } + // TODO verify return data is non-null when it was promised to be? + } + Ok(result) } /// Get the circuits of inner implementation @@ -280,6 +324,11 @@ impl ScalarUDF { pub fn documentation(&self) -> Option<&Documentation> { self.inner.documentation() } + + /// Return true if this function is an async function + pub fn as_async(&self) -> Option<&AsyncScalarUDF> { + self.inner().as_any().downcast_ref::() + } } impl From for ScalarUDF @@ -293,14 +342,28 @@ where /// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a /// scalar function. -pub struct ScalarFunctionArgs<'a> { +#[derive(Debug, Clone)] +pub struct ScalarFunctionArgs { /// The evaluated arguments to the function pub args: Vec, + /// Field associated with each arg, if it exists + pub arg_fields: Vec, /// The number of rows in record batch being evaluated pub number_rows: usize, - /// The return type of the scalar function returned (from `return_type` or `return_type_from_args`) - /// when creating the physical expression from the logical expression - pub return_type: &'a DataType, + /// The return field of the scalar function returned (from `return_type` + /// or `return_field_from_args`) when creating the physical expression + /// from the logical expression + pub return_field: FieldRef, + /// The config options at execution time + pub config_options: Arc, +} + +impl ScalarFunctionArgs { + /// The return type of the function. See [`Self::return_field`] for more + /// details. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } } /// Information about arguments passed to the function @@ -309,64 +372,18 @@ pub struct ScalarFunctionArgs<'a> { /// such as the type of the arguments, any scalar arguments and if the /// arguments can (ever) be null /// -/// See [`ScalarUDFImpl::return_type_from_args`] for more information +/// See [`ScalarUDFImpl::return_field_from_args`] for more information #[derive(Debug)] -pub struct ReturnTypeArgs<'a> { +pub struct ReturnFieldArgs<'a> { /// The data types of the arguments to the function - pub arg_types: &'a [DataType], - /// Is argument `i` to the function a scalar (constant) + pub arg_fields: &'a [FieldRef], + /// Is argument `i` to the function a scalar (constant)? /// - /// If argument `i` is not a scalar, it will be None + /// If the argument `i` is not a scalar, it will be None /// /// For example, if a function is called like `my_function(column_a, 5)` /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` pub scalar_arguments: &'a [Option<&'a ScalarValue>], - /// Can argument `i` (ever) null? - pub nullables: &'a [bool], -} - -/// Return metadata for this function. -/// -/// See [`ScalarUDFImpl::return_type_from_args`] for more information -#[derive(Debug)] -pub struct ReturnInfo { - return_type: DataType, - nullable: bool, -} - -impl ReturnInfo { - pub fn new(return_type: DataType, nullable: bool) -> Self { - Self { - return_type, - nullable, - } - } - - pub fn new_nullable(return_type: DataType) -> Self { - Self { - return_type, - nullable: true, - } - } - - pub fn new_non_nullable(return_type: DataType) -> Self { - Self { - return_type, - nullable: false, - } - } - - pub fn return_type(&self) -> &DataType { - &self.return_type - } - - pub fn nullable(&self) -> bool { - self.nullable - } - - pub fn into_parts(self) -> (DataType, bool) { - (self.return_type, self.nullable) - } } /// Trait for implementing user defined scalar functions. @@ -389,7 +406,7 @@ impl ReturnInfo { /// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; /// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; /// /// This struct for a simple UDF that adds one to an int32 -/// #[derive(Debug)] +/// #[derive(Debug, PartialEq, Eq, Hash)] /// struct AddOne { /// signature: Signature, /// } @@ -438,22 +455,36 @@ impl ReturnInfo { /// // Call the function `add_one(col)` /// let expr = add_one.call(vec![col("a")]); /// ``` -pub trait ScalarUDFImpl: Debug + Send + Sync { - // Note: When adding any methods (with default implementations), remember to add them also - // into the AliasedScalarUDFImpl below! - +pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; /// Returns this function's name fn name(&self) -> &str; + /// Returns any aliases (alternate names) for this function. + /// + /// Aliases can be used to invoke the same function using different names. + /// For example in some databases `now()` and `current_timestamp()` are + /// aliases for the same function. This behavior can be obtained by + /// returning `current_timestamp` as an alias for the `now` function. + /// + /// Note: `aliases` should only include names other than [`Self::name`]. + /// Defaults to `[]` (no aliases) + fn aliases(&self) -> &[String] { + &[] + } + /// Returns the user-defined display name of function, given the arguments /// /// This can be used to customize the output column name generated by this /// function. /// /// Defaults to `name(args[0], args[1], ...)` + #[deprecated( + since = "50.0.0", + note = "This method is unused and will be removed in a future release" + )] fn display_name(&self, args: &[Expr]) -> Result { let names: Vec = args.iter().map(ToString::to_string).collect(); // TODO: join with ", " to standardize the formatting of Vec, @@ -471,18 +502,32 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { )) } - /// Returns the function's [`Signature`] for information about what input - /// types are accepted and the function's Volatility. + /// Returns a [`Signature`] describing the argument types for which this + /// function has an implementation, and the function's [`Volatility`]. + /// + /// See [`Signature`] for more details on argument type handling + /// and [`Self::return_type`] for computing the return type. + /// + /// [`Volatility`]: datafusion_expr_common::signature::Volatility fn signature(&self) -> &Signature; - /// What [`DataType`] will be returned by this function, given the types of - /// the arguments. + /// [`DataType`] returned by this function, given the types of the + /// arguments. + /// + /// # Arguments + /// + /// `arg_types` Data types of the arguments. The implementation of + /// `return_type` can assume that some other part of the code has coerced + /// the actual argument types to match [`Self::signature`]. /// /// # Notes /// - /// If you provide an implementation for [`Self::return_type_from_args`], - /// DataFusion will not call `return_type` (this function). In such cases - /// is recommended to return [`DataFusionError::Internal`]. + /// If you provide an implementation for [`Self::return_field_from_args`], + /// DataFusion will not call `return_type` (this function). While it is + /// valid to to put [`unimplemented!()`] or [`unreachable!()`], it is + /// recommended to return [`DataFusionError::Internal`] instead, which + /// reduces the severity of symptoms if bugs occur (an error rather than a + /// panic). /// /// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal fn return_type(&self, arg_types: &[DataType]) -> Result; @@ -494,9 +539,10 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// /// # Notes /// - /// Most UDFs should implement [`Self::return_type`] and not this - /// function as the output type for most functions only depends on the types - /// of their inputs (e.g. `sqrt(f32)` is always `f32`). + /// For the majority of UDFs, implementing [`Self::return_type`] is sufficient, + /// as the result type is typically a deterministic function of the input types + /// (e.g., `sqrt(f32)` consistently yields `f32`). Implementing this method directly + /// is generally unnecessary unless the return type depends on runtime values. /// /// This function can be used for more advanced cases such as: /// @@ -504,6 +550,27 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// 2. return types based on the **values** of the arguments (rather than /// their **types**. /// + /// # Example creating `Field` + /// + /// Note the name of the [`Field`] is ignored, except for structured types such as + /// `DataType::Struct`. + /// + /// ```rust + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field, FieldRef}; + /// # use datafusion_common::Result; + /// # use datafusion_expr::ReturnFieldArgs; + /// # struct Example{} + /// # impl Example { + /// fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + /// // report output is only nullable if any one of the arguments are nullable + /// let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true)); + /// Ok(field) + /// } + /// # } + /// ``` + /// /// # Output Type based on Values /// /// For example, the following two function calls get the same argument @@ -518,14 +585,20 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// This function **must** consistently return the same type for the same /// logical input even if the input is simplified (e.g. it must return the same /// value for `('foo' | 'bar')` as it does for ('foobar'). - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let return_type = self.return_type(args.arg_types)?; - Ok(ReturnInfo::new_nullable(return_type)) + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let data_types = args + .arg_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let return_type = self.return_type(&data_types)?; + Ok(Arc::new(Field::new(self.name(), return_type, true))) } #[deprecated( since = "45.0.0", - note = "Use `return_type_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_type_from_args`, you might have error" + note = "Use `return_field_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_field_from_args`, you might have error" )] fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { true @@ -543,19 +616,6 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// to arrays, which will likely be simpler code, but be slower. fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result; - /// Returns any aliases (alternate names) for this function. - /// - /// Aliases can be used to invoke the same function using different names. - /// For example in some databases `now()` and `current_timestamp()` are - /// aliases for the same function. This behavior can be obtained by - /// returning `current_timestamp` as an alias for the `now` function. - /// - /// Note: `aliases` should only include names other than [`Self::name`]. - /// Defaults to `[]` (no aliases) - fn aliases(&self) -> &[String] { - &[] - } - /// Optionally apply per-UDF simplification / rewrite rules. /// /// This can be used to apply function specific simplification rules during @@ -575,6 +635,14 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// [`ExprSimplifyResult`] indicating the result of the simplification NOTE /// if the function cannot be simplified, the arguments *MUST* be returned /// unmodified + /// + /// # Notes + /// + /// The returned expression must have the same schema as the original + /// expression, including both the data type and nullability. For example, + /// if the original expression is nullable, the returned expression must + /// also be nullable, otherwise it may lead to schema verification errors + /// later in query planning. fn simplify( &self, args: Vec, @@ -584,13 +652,15 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { } /// Returns true if some of this `exprs` subexpressions may not be evaluated - /// and thus any side effects (like divide by zero) may not be encountered - /// Setting this to true prevents certain optimizations such as common subexpression elimination + /// and thus any side effects (like divide by zero) may not be encountered. + /// + /// Setting this to true prevents certain optimizations such as common + /// subexpression elimination fn short_circuits(&self) -> bool { false } - /// Computes the output interval for a [`ScalarUDFImpl`], given the input + /// Computes the output [`Interval`] for a [`ScalarUDFImpl`], given the input /// intervals. /// /// # Parameters @@ -606,9 +676,11 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { Interval::make_unbounded(&DataType::Null) } - /// Updates bounds for child expressions, given a known interval for this - /// function. This is used to propagate constraints down through an expression - /// tree. + /// Updates bounds for child expressions, given a known [`Interval`]s for this + /// function. + /// + /// This function is used to propagate constraints down through an + /// expression tree. /// /// # Parameters /// @@ -657,20 +729,25 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { } } - /// Whether the function preserves lexicographical ordering based on the input ordering + /// Returns true if the function preserves lexicographical ordering based on + /// the input ordering. + /// + /// For example, `concat(a || b)` preserves lexicographical ordering, but `abs(a)` does not. fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result { Ok(false) } /// Coerce arguments of a function call to types that the function can evaluate. /// - /// This function is only called if [`ScalarUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most - /// UDFs should return one of the other variants of `TypeSignature` which handle common - /// cases + /// This function is only called if [`ScalarUDFImpl::signature`] returns + /// [`crate::TypeSignature::UserDefined`]. Most UDFs should return one of + /// the other variants of [`TypeSignature`] which handle common cases. /// /// See the [type coercion module](crate::type_coercion) /// documentation for more details on type coercion /// + /// [`TypeSignature`]: crate::TypeSignature + /// /// For example, if your function requires a floating point arguments, but the user calls /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]` /// to ensure the argument is converted to `1::double` @@ -685,37 +762,10 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { not_impl_err!("Function {} does not implement coerce_types", self.name()) } - /// Return true if this scalar UDF is equal to the other. - /// - /// Allows customizing the equality of scalar UDFs. - /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: - /// - /// - reflexive: `a.equals(a)`; - /// - symmetric: `a.equals(b)` implies `b.equals(a)`; - /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. - /// - /// By default, compares [`Self::name`] and [`Self::signature`]. - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - self.name() == other.name() && self.signature() == other.signature() - } - - /// Returns a hash value for this scalar UDF. - /// - /// Allows customizing the hash code of scalar UDFs. Similarly to [`Hash`] and [`Eq`], - /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. - /// - /// By default, hashes [`Self::name`] and [`Self::signature`]. - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.name().hash(hasher); - self.signature().hash(hasher); - hasher.finish() - } - /// Returns the documentation for this Scalar UDF. /// - /// Documentation can be accessed programmatically as well as - /// generating publicly facing documentation. + /// Documentation can be accessed programmatically as well as generating + /// publicly facing documentation. fn documentation(&self) -> Option<&Documentation> { None } @@ -723,9 +773,9 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// ScalarUDF that adds an alias to the underlying function. It is better to /// implement [`ScalarUDFImpl`], which supports aliases, directly if possible. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct AliasedScalarUDFImpl { - inner: Arc, + inner: UdfEq>, aliases: Vec, } @@ -736,10 +786,14 @@ impl AliasedScalarUDFImpl { ) -> Self { let mut aliases = inner.aliases().to_vec(); aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); - Self { inner, aliases } + Self { + inner: inner.into(), + aliases, + } } } +#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method impl ScalarUDFImpl for AliasedScalarUDFImpl { fn as_any(&self) -> &dyn Any { self @@ -750,6 +804,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { } fn display_name(&self, args: &[Expr]) -> Result { + #[expect(deprecated)] self.inner.display_name(args) } @@ -765,18 +820,23 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type(arg_types) } - fn aliases(&self) -> &[String] { - &self.aliases + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.inner.return_field_from_args(args) } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - self.inner.return_type_from_args(args) + fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { + #[allow(deprecated)] + self.inner.is_nullable(args, schema) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { self.inner.invoke_with_args(args) } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn simplify( &self, args: Vec, @@ -813,138 +873,87 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.coerce_types(arg_types) } - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - if let Some(other) = other.as_any().downcast_ref::() { - self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases - } else { - false + fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_expr_common::signature::Volatility; + use std::hash::DefaultHasher; + + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestScalarUDFImpl { + name: &'static str, + field: &'static str, + signature: Signature, + } + impl ScalarUDFImpl for TestScalarUDFImpl { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!() + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + unimplemented!() } } - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.inner.hash_value().hash(hasher); - self.aliases.hash(hasher); - hasher.finish() + // PartialEq and Hash must be consistent, and also PartialEq and PartialOrd + // must be consistent, so they are tested together. + #[test] + fn test_partial_eq_hash_and_partial_ord() { + // A parameterized function + let f = test_func("foo", "a"); + + // Same like `f`, different instance + let f2 = test_func("foo", "a"); + assert_eq!(f, f2); + assert_eq!(hash(&f), hash(&f2)); + assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal)); + + // Different parameter + let b = test_func("foo", "b"); + assert_ne!(f, b); + assert_ne!(hash(&f), hash(&b)); // hash can collide for different values but does not collide in this test + assert_eq!(f.partial_cmp(&b), None); + + // Different name + let o = test_func("other", "a"); + assert_ne!(f, o); + assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test + assert_eq!(f.partial_cmp(&o), Some(Ordering::Less)); + + // Different name and parameter + assert_ne!(b, o); + assert_ne!(hash(&b), hash(&o)); // hash can collide for different values but does not collide in this test + assert_eq!(b.partial_cmp(&o), Some(Ordering::Less)); } - fn documentation(&self) -> Option<&Documentation> { - self.inner.documentation() + fn test_func(name: &'static str, parameter: &'static str) -> ScalarUDF { + ScalarUDF::from(TestScalarUDFImpl { + name, + field: parameter, + signature: Signature::any(1, Volatility::Immutable), + }) } -} -// Scalar UDF doc sections for use in public documentation -pub mod scalar_doc_sections { - use crate::DocSection; - - pub fn doc_sections() -> Vec { - vec![ - DOC_SECTION_MATH, - DOC_SECTION_CONDITIONAL, - DOC_SECTION_STRING, - DOC_SECTION_BINARY_STRING, - DOC_SECTION_REGEX, - DOC_SECTION_DATETIME, - DOC_SECTION_ARRAY, - DOC_SECTION_STRUCT, - DOC_SECTION_MAP, - DOC_SECTION_HASHING, - DOC_SECTION_UNION, - DOC_SECTION_OTHER, - ] - } - - pub const fn doc_sections_const() -> &'static [DocSection] { - &[ - DOC_SECTION_MATH, - DOC_SECTION_CONDITIONAL, - DOC_SECTION_STRING, - DOC_SECTION_BINARY_STRING, - DOC_SECTION_REGEX, - DOC_SECTION_DATETIME, - DOC_SECTION_ARRAY, - DOC_SECTION_STRUCT, - DOC_SECTION_MAP, - DOC_SECTION_HASHING, - DOC_SECTION_UNION, - DOC_SECTION_OTHER, - ] - } - - pub const DOC_SECTION_MATH: DocSection = DocSection { - include: true, - label: "Math Functions", - description: None, - }; - - pub const DOC_SECTION_CONDITIONAL: DocSection = DocSection { - include: true, - label: "Conditional Functions", - description: None, - }; - - pub const DOC_SECTION_STRING: DocSection = DocSection { - include: true, - label: "String Functions", - description: None, - }; - - pub const DOC_SECTION_BINARY_STRING: DocSection = DocSection { - include: true, - label: "Binary String Functions", - description: None, - }; - - pub const DOC_SECTION_REGEX: DocSection = DocSection { - include: true, - label: "Regular Expression Functions", - description: Some( - r#"Apache DataFusion uses a [PCRE-like](https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions) -regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax) -(minus support for several features including look-around and backreferences). -The following regular expression functions are supported:"#, - ), - }; - - pub const DOC_SECTION_DATETIME: DocSection = DocSection { - include: true, - label: "Time and Date Functions", - description: None, - }; - - pub const DOC_SECTION_ARRAY: DocSection = DocSection { - include: true, - label: "Array Functions", - description: None, - }; - - pub const DOC_SECTION_STRUCT: DocSection = DocSection { - include: true, - label: "Struct Functions", - description: None, - }; - - pub const DOC_SECTION_MAP: DocSection = DocSection { - include: true, - label: "Map Functions", - description: None, - }; - - pub const DOC_SECTION_HASHING: DocSection = DocSection { - include: true, - label: "Hashing Functions", - description: None, - }; - - pub const DOC_SECTION_OTHER: DocSection = DocSection { - include: true, - label: "Other Functions", - description: None, - }; - - pub const DOC_SECTION_UNION: DocSection = DocSection { - include: true, - label: "Union Functions", - description: Some("Functions to work with the union data type, also know as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator"), - }; + fn hash(value: &T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } } diff --git a/datafusion/expr/src/udf_eq.rs b/datafusion/expr/src/udf_eq.rs new file mode 100644 index 0000000000000..6664495267129 --- /dev/null +++ b/datafusion/expr/src/udf_eq.rs @@ -0,0 +1,201 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{AggregateUDFImpl, ScalarUDFImpl, WindowUDFImpl}; +use std::fmt::Debug; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::ops::Deref; +use std::sync::Arc; + +/// A wrapper around a pointer to UDF that implements `Eq` and `Hash` delegating to +/// corresponding methods on the UDF trait. +/// +/// If you want to just compare pointers for equality, use [`super::ptr_eq::PtrEq`]. +#[derive(Clone)] +#[allow(private_bounds)] // This is so that UdfEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse. +pub struct UdfEq(Ptr); + +impl PartialEq for UdfEq +where + Ptr: UdfPointer, +{ + fn eq(&self, other: &Self) -> bool { + self.0.equals(&other.0) + } +} +impl Eq for UdfEq where Ptr: UdfPointer {} +impl Hash for UdfEq +where + Ptr: UdfPointer, +{ + fn hash(&self, state: &mut H) { + self.0.hash_value().hash(state); + } +} + +impl From for UdfEq +where + Ptr: UdfPointer, +{ + fn from(ptr: Ptr) -> Self { + UdfEq(ptr) + } +} + +impl Debug for UdfEq +where + Ptr: UdfPointer + Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl Deref for UdfEq +where + Ptr: UdfPointer, +{ + type Target = Ptr; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +trait UdfPointer: Deref { + fn equals(&self, other: &Self::Target) -> bool; + fn hash_value(&self) -> u64; +} + +impl UdfPointer for Arc { + fn equals(&self, other: &(dyn ScalarUDFImpl + '_)) -> bool { + self.as_ref().dyn_eq(other.as_any()) + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.as_ref().dyn_hash(hasher); + hasher.finish() + } +} + +impl UdfPointer for Arc { + fn equals(&self, other: &(dyn AggregateUDFImpl + '_)) -> bool { + self.as_ref().dyn_eq(other.as_any()) + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.as_ref().dyn_hash(hasher); + hasher.finish() + } +} + +impl UdfPointer for Arc { + fn equals(&self, other: &(dyn WindowUDFImpl + '_)) -> bool { + self.as_ref().dyn_eq(other.as_any()) + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.as_ref().dyn_hash(hasher); + hasher.finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ScalarFunctionArgs; + use arrow::datatypes::DataType; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_expr_common::signature::{Signature, Volatility}; + use std::any::Any; + use std::hash::DefaultHasher; + + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestScalarUDF { + signature: Signature, + name: &'static str, + } + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type( + &self, + _arg_types: &[DataType], + ) -> datafusion_common::Result { + unimplemented!() + } + + fn invoke_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + unimplemented!() + } + } + + #[test] + pub fn test_eq_eq_wrapper() { + let signature = Signature::any(1, Volatility::Immutable); + + let a1: Arc = Arc::new(TestScalarUDF { + signature: signature.clone(), + name: "a", + }); + let a2: Arc = Arc::new(TestScalarUDF { + signature: signature.clone(), + name: "a", + }); + let b: Arc = Arc::new(TestScalarUDF { + signature: signature.clone(), + name: "b", + }); + + // Reflexivity + let wrapper = UdfEq(Arc::clone(&a1)); + assert_eq!(wrapper, wrapper); + + // Two wrappers around equal pointer + assert_eq!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&a1))); + assert_eq!(hash(UdfEq(Arc::clone(&a1))), hash(UdfEq(Arc::clone(&a1)))); + + // Two wrappers around different pointers but equal in ScalarUDFImpl::equals sense + assert_eq!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&a2))); + assert_eq!(hash(UdfEq(Arc::clone(&a1))), hash(UdfEq(Arc::clone(&a2)))); + + // different functions (not equal) + assert_ne!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&b))); + } + + fn hash(value: T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } +} diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 4da63d7955f58..7ca2f0662d48f 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -19,21 +19,23 @@ use arrow::compute::SortOptions; use std::cmp::Ordering; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::{Hash, Hasher}; use std::{ any::Any, fmt::{self, Debug, Display, Formatter}, sync::Arc, }; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, FieldRef}; use crate::expr::WindowFunction; +use crate::udf_eq::UdfEq; use crate::{ function::WindowFunctionSimplification, Expr, PartitionEvaluator, Signature, }; use datafusion_common::{not_impl_err, Result}; use datafusion_doc::Documentation; +use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_functions_window_common::expr::ExpressionArgs; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; @@ -80,7 +82,7 @@ impl Display for WindowUDF { impl PartialEq for WindowUDF { fn eq(&self, other: &Self) -> bool { - self.inner.equals(other.inner.as_ref()) + self.inner.dyn_eq(other.inner.as_any()) } } @@ -88,7 +90,7 @@ impl Eq for WindowUDF {} impl Hash for WindowUDF { fn hash(&self, state: &mut H) { - self.inner.hash_value().hash(state) + self.inner.dyn_hash(state) } } @@ -133,7 +135,7 @@ impl WindowUDF { pub fn call(&self, args: Vec) -> Expr { let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::from(WindowFunction::new(fun, args)) } /// Returns this function's name @@ -179,7 +181,7 @@ impl WindowUDF { /// Returns the field of the final result of evaluating this window function. /// /// See [`WindowUDFImpl::field`] for more details. - pub fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + pub fn field(&self, field_args: WindowUDFFieldArgs) -> Result { self.inner.field(field_args) } @@ -227,6 +229,10 @@ where /// This trait exposes the full API for implementing user defined window functions and /// can be used to implement any function. /// +/// While the trait depends on [`DynEq`] and [`DynHash`] traits, these should not be +/// implemented directly. Instead, implement [`Eq`] and [`Hash`] and leverage the +/// blanket implementations of [`DynEq`] and [`DynHash`]. +/// /// See [`advanced_udwf.rs`] for a full example with complete implementation and /// [`WindowUDF`] for other available options. /// @@ -236,15 +242,17 @@ where /// ``` /// # use std::any::Any; /// # use std::sync::LazyLock; -/// # use arrow::datatypes::{DataType, Field}; +/// # use arrow::datatypes::{DataType, Field, FieldRef}; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame, ExprFunctionExt, Documentation}; +/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame, ExprFunctionExt, Documentation, LimitEffect}; /// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; /// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; /// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// # use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL; +/// # use datafusion_physical_expr_common::physical_expr; +/// # use std::sync::Arc; /// -/// #[derive(Debug, Clone)] +/// #[derive(Debug, Clone, PartialEq, Eq, Hash)] /// struct SmoothIt { /// signature: Signature, /// } @@ -279,9 +287,9 @@ where /// ) -> Result> { /// unimplemented!() /// } -/// fn field(&self, field_args: WindowUDFFieldArgs) -> Result { -/// if let Some(DataType::Int32) = field_args.get_input_type(0) { -/// Ok(Field::new(field_args.name(), DataType::Int32, false)) +/// fn field(&self, field_args: WindowUDFFieldArgs) -> Result { +/// if let Some(DataType::Int32) = field_args.get_input_field(0).map(|f| f.data_type().clone()) { +/// Ok(Field::new(field_args.name(), DataType::Int32, false).into()) /// } else { /// plan_err!("smooth_it only accepts Int32 arguments") /// } @@ -289,6 +297,9 @@ where /// fn documentation(&self) -> Option<&Documentation> { /// Some(get_doc()) /// } +/// fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { +/// LimitEffect::Unknown +/// } /// } /// /// // Create a new WindowUDF from the implementation @@ -303,16 +314,21 @@ where /// .build() /// .unwrap(); /// ``` -pub trait WindowUDFImpl: Debug + Send + Sync { - // Note: When adding any methods (with default implementations), remember to add them also - // into the AliasedWindowUDFImpl below! - +pub trait WindowUDFImpl: Debug + DynEq + DynHash + Send + Sync { /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; /// Returns this function's name fn name(&self) -> &str; + /// Returns any aliases (alternate names) for this function. + /// + /// Note: `aliases` should only include names other than [`Self::name`]. + /// Defaults to `[]` (no aliases) + fn aliases(&self) -> &[String] { + &[] + } + /// Returns the function's [`Signature`] for information about what input /// types are accepted and the function's Volatility. fn signature(&self) -> &Signature; @@ -328,14 +344,6 @@ pub trait WindowUDFImpl: Debug + Send + Sync { partition_evaluator_args: PartitionEvaluatorArgs, ) -> Result>; - /// Returns any aliases (alternate names) for this function. - /// - /// Note: `aliases` should only include names other than [`Self::name`]. - /// Defaults to `[]` (no aliases) - fn aliases(&self) -> &[String] { - &[] - } - /// Optionally apply per-UDWF simplification / rewrite rules. /// /// This can be used to apply function specific simplification rules during @@ -355,43 +363,23 @@ pub trait WindowUDFImpl: Debug + Send + Sync { /// Or, a closure with two arguments: /// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked /// * 'info': [crate::simplify::SimplifyInfo] + /// + /// # Notes + /// The returned expression must have the same schema as the original + /// expression, including both the data type and nullability. For example, + /// if the original expression is nullable, the returned expression must + /// also be nullable, otherwise it may lead to schema verification errors + /// later in query planning. fn simplify(&self) -> Option { None } - /// Return true if this window UDF is equal to the other. - /// - /// Allows customizing the equality of window UDFs. - /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: - /// - /// - reflexive: `a.equals(a)`; - /// - symmetric: `a.equals(b)` implies `b.equals(a)`; - /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. - /// - /// By default, compares [`Self::name`] and [`Self::signature`]. - fn equals(&self, other: &dyn WindowUDFImpl) -> bool { - self.name() == other.name() && self.signature() == other.signature() - } - - /// Returns a hash value for this window UDF. - /// - /// Allows customizing the hash code of window UDFs. Similarly to [`Hash`] and [`Eq`], - /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. - /// - /// By default, hashes [`Self::name`] and [`Self::signature`]. - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.name().hash(hasher); - self.signature().hash(hasher); - hasher.finish() - } - - /// The [`Field`] of the final result of evaluating this window function. + /// The [`FieldRef`] of the final result of evaluating this window function. /// /// Call `field_args.name()` to get the fully qualified name for defining - /// the [`Field`]. For a complete example see the implementation in the + /// the [`FieldRef`]. For a complete example see the implementation in the /// [Basic Example](WindowUDFImpl#basic-example) section. - fn field(&self, field_args: WindowUDFFieldArgs) -> Result; + fn field(&self, field_args: WindowUDFFieldArgs) -> Result; /// Allows the window UDF to define a custom result ordering. /// @@ -438,6 +426,23 @@ pub trait WindowUDFImpl: Debug + Send + Sync { fn documentation(&self) -> Option<&Documentation> { None } + + /// If not causal, returns the effect this function will have on the window + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } +} + +/// the effect this function will have on the limit pushdown +pub enum LimitEffect { + /// Does not affect the limit (i.e. this is causal) + None, + /// Either undeclared, or dynamic (only evaluatable at run time) + Unknown, + /// Grow the limit by N rows + Relative(usize), + /// Limit needs to be at least N rows + Absolute(usize), } pub enum ReversedUDWF { @@ -454,7 +459,7 @@ pub enum ReversedUDWF { impl PartialEq for dyn WindowUDFImpl { fn eq(&self, other: &Self) -> bool { - self.equals(other) + self.dyn_eq(other.as_any()) } } @@ -464,14 +469,16 @@ impl PartialOrd for dyn WindowUDFImpl { Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()), cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } /// WindowUDF that adds an alias to the underlying function. It is better to /// implement [`WindowUDFImpl`], which supports aliases, directly if possible. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct AliasedWindowUDFImpl { - inner: Arc, + inner: UdfEq>, aliases: Vec, } @@ -483,10 +490,14 @@ impl AliasedWindowUDFImpl { let mut aliases = inner.aliases().to_vec(); aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); - Self { inner, aliases } + Self { + inner: inner.into(), + aliases, + } } } +#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method impl WindowUDFImpl for AliasedWindowUDFImpl { fn as_any(&self) -> &dyn Any { self @@ -522,22 +533,7 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { self.inner.simplify() } - fn equals(&self, other: &dyn WindowUDFImpl) -> bool { - if let Some(other) = other.as_any().downcast_ref::() { - self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases - } else { - false - } - } - - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.inner.hash_value().hash(hasher); - self.aliases.hash(hasher); - hasher.finish() - } - - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { self.inner.field(field_args) } @@ -549,54 +545,34 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { self.inner.coerce_types(arg_types) } + fn reverse_expr(&self) -> ReversedUDWF { + self.inner.reverse_expr() + } + fn documentation(&self) -> Option<&Documentation> { self.inner.documentation() } -} -// Window UDF doc sections for use in public documentation -pub mod window_doc_sections { - use datafusion_doc::DocSection; - - pub fn doc_sections() -> Vec { - vec![ - DOC_SECTION_AGGREGATE, - DOC_SECTION_RANKING, - DOC_SECTION_ANALYTICAL, - ] - } - - pub const DOC_SECTION_AGGREGATE: DocSection = DocSection { - include: true, - label: "Aggregate Functions", - description: Some("All aggregate functions can be used as window functions."), - }; - - pub const DOC_SECTION_RANKING: DocSection = DocSection { - include: true, - label: "Ranking Functions", - description: None, - }; - - pub const DOC_SECTION_ANALYTICAL: DocSection = DocSection { - include: true, - label: "Analytical Functions", - description: None, - }; + fn limit_effect(&self, args: &[Arc]) -> LimitEffect { + self.inner.limit_effect(args) + } } #[cfg(test)] mod test { - use crate::{PartitionEvaluator, WindowUDF, WindowUDFImpl}; - use arrow::datatypes::{DataType, Field}; + use crate::{LimitEffect, PartitionEvaluator, WindowUDF, WindowUDFImpl}; + use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::Result; use datafusion_expr_common::signature::{Signature, Volatility}; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::any::Any; use std::cmp::Ordering; + use std::hash::{DefaultHasher, Hash, Hasher}; + use std::sync::Arc; - #[derive(Debug, Clone)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct AWindowUDF { signature: Signature, } @@ -630,12 +606,16 @@ mod test { ) -> Result> { unimplemented!() } - fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { unimplemented!() } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } } - #[derive(Debug, Clone)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct BWindowUDF { signature: Signature, } @@ -669,9 +649,23 @@ mod test { ) -> Result> { unimplemented!() } - fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { unimplemented!() } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } + } + + #[test] + fn test_partial_eq() { + let a1 = WindowUDF::from(AWindowUDF::new()); + let a2 = WindowUDF::from(AWindowUDF::new()); + let eq = a1 == a2; + assert!(eq); + assert_eq!(a1, a2); + assert_eq!(hash(a1), hash(a2)); } #[test] @@ -684,4 +678,10 @@ mod test { assert!(a1 < b1); assert!(!(a1 == b1)); } + + fn hash(value: T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 552ce1502d466..b91db4527b3aa 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -21,7 +21,7 @@ use std::cmp::Ordering; use std::collections::{BTreeSet, HashSet}; use std::sync::Arc; -use crate::expr::{Alias, Sort, WildcardOptions, WindowFunction, WindowFunctionParams}; +use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams}; use crate::expr_rewriter::strip_outer_reference; use crate::{ and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, @@ -38,7 +38,10 @@ use datafusion_common::{ Result, TableReference, }; +#[cfg(not(feature = "sql"))] +use crate::expr::{ExceptSelectItem, ExcludeSelectItem}; use indexmap::IndexSet; +#[cfg(feature = "sql")] use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem}; pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; @@ -276,7 +279,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { Expr::Unnest(_) | Expr::ScalarVariable(_, _) | Expr::Alias(_) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::BinaryExpr { .. } | Expr::Like { .. } | Expr::SimilarTo { .. } @@ -579,7 +582,8 @@ pub fn group_window_expr_by_sort_keys( ) -> Result)>> { let mut result = vec![]; window_expr.into_iter().try_for_each(|expr| match &expr { - Expr::WindowFunction( WindowFunction{ params: WindowFunctionParams { partition_by, order_by, ..}, .. }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams{ partition_by, order_by, ..} = &window_fun.as_ref().params; let sort_key = generate_sort_key(partition_by, order_by)?; if let Some((_, values)) = result.iter_mut().find( |group: &&mut (WindowSortKey, Vec)| matches!(group, (key, _) if *key == sort_key), @@ -608,7 +612,7 @@ pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator) -> Ve /// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence /// (depth first), with duplicates omitted. -pub fn find_window_exprs(exprs: &[Expr]) -> Vec { +pub fn find_window_exprs<'a>(exprs: impl IntoIterator) -> Vec { find_exprs_in_exprs(exprs, &|nested_expr| { matches!(nested_expr, Expr::WindowFunction { .. }) }) @@ -689,7 +693,23 @@ where err } -/// Create field meta-data from an expression, for use in a result set schema +/// Create schema fields from an expression list, for use in result set schema construction +/// +/// This function converts a list of expressions into a list of complete schema fields, +/// making comprehensive determinations about each field's properties including: +/// - **Data type**: Resolved based on expression type and input schema context +/// - **Nullability**: Determined by expression-specific nullability rules +/// - **Metadata**: Computed based on expression type (preserving, merging, or generating new metadata) +/// - **Table reference scoping**: Establishing proper qualified field references +/// +/// Each expression is converted to a field by calling [`Expr::to_field`], which performs +/// the complete field resolution process for all field properties. +/// +/// # Returns +/// +/// A `Result` containing a vector of `(Option, Arc)` tuples, +/// where each Field contains complete schema information (type, nullability, metadata) +/// and proper table reference scoping for the corresponding expression. pub fn exprlist_to_fields<'a>( exprs: impl IntoIterator, plan: &LogicalPlan, @@ -784,7 +804,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( indexes.push(idx); } } - Expr::Literal(_) => { + Expr::Literal(_, _) => { indexes.push(usize::MAX); } _ => {} @@ -813,6 +833,8 @@ pub fn can_hash(data_type: &DataType) -> bool { DataType::Float16 => true, DataType::Float32 => true, DataType::Float64 => true, + DataType::Decimal32(_, _) => true, + DataType::Decimal64(_, _) => true, DataType::Decimal128(_, _) => true, DataType::Decimal256(_, _) => true, DataType::Timestamp(_, _) => true, @@ -1222,6 +1244,9 @@ pub fn only_or_err(slice: &[T]) -> Result<&T> { } /// merge inputs schema into a single schema. +/// +/// This function merges schemas from multiple logical plan inputs using [`DFSchema::merge`]. +/// Refer to that documentation for details on precedence and metadata handling. pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema { if inputs.len() == 1 { inputs[0].schema().as_ref().clone() @@ -1263,9 +1288,11 @@ pub fn collect_subquery_cols( mod tests { use super::*; use crate::{ - col, cube, expr_vec_fmt, grouping_set, lit, rollup, - test::function_stub::max_udaf, test::function_stub::min_udaf, - test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, + col, cube, + expr::WindowFunction, + expr_vec_fmt, grouping_set, lit, rollup, + test::function_stub::{max_udaf, min_udaf, sum_udaf}, + Cast, ExprFunctionExt, WindowFunctionDefinition, }; use arrow::datatypes::{UnionFields, UnionMode}; @@ -1279,19 +1306,19 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )); @@ -1309,25 +1336,25 @@ mod tests { let age_asc = Sort::new(col("age"), true, true); let name_desc = Sort::new(col("name"), false, true); let created_at_desc = Sort::new(col("created_at"), false, true); - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )) diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 8771b25137cf2..f72dc10a6950f 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -24,13 +24,12 @@ //! - An EXCLUDE clause. use crate::{expr::Sort, lit}; -use arrow::datatypes::DataType; use std::fmt::{self, Formatter}; use std::hash::Hash; -use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{plan_err, Result, ScalarValue}; +#[cfg(feature = "sql")] use sqlparser::ast::{self, ValueWithSpan}; -use sqlparser::parser::ParserError::ParserError; /// The frame specification determines which output rows are read by an aggregate /// window function. The ending frame boundary can be omitted if the `BETWEEN` @@ -115,8 +114,9 @@ impl fmt::Debug for WindowFrame { } } +#[cfg(feature = "sql")] impl TryFrom for WindowFrame { - type Error = DataFusionError; + type Error = datafusion_common::error::DataFusionError; fn try_from(value: ast::WindowFrame) -> Result { let start_bound = WindowFrameBound::try_parse(value.start_bound, &value.units)?; @@ -160,7 +160,7 @@ impl WindowFrame { } else { WindowFrameUnits::Range }, - start_bound: WindowFrameBound::Preceding(ScalarValue::Null), + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), end_bound: WindowFrameBound::CurrentRow, causal: strict, } @@ -343,6 +343,7 @@ impl WindowFrameBound { } impl WindowFrameBound { + #[cfg(feature = "sql")] fn try_parse( value: ast::WindowFrameBound, units: &ast::WindowFrameUnits, @@ -351,20 +352,27 @@ impl WindowFrameBound { ast::WindowFrameBound::Preceding(Some(v)) => { Self::Preceding(convert_frame_bound_to_scalar_value(*v, units)?) } - ast::WindowFrameBound::Preceding(None) => Self::Preceding(ScalarValue::Null), + ast::WindowFrameBound::Preceding(None) => { + Self::Preceding(ScalarValue::UInt64(None)) + } ast::WindowFrameBound::Following(Some(v)) => { Self::Following(convert_frame_bound_to_scalar_value(*v, units)?) } - ast::WindowFrameBound::Following(None) => Self::Following(ScalarValue::Null), + ast::WindowFrameBound::Following(None) => { + Self::Following(ScalarValue::UInt64(None)) + } ast::WindowFrameBound::CurrentRow => Self::CurrentRow, }) } } +#[cfg(feature = "sql")] fn convert_frame_bound_to_scalar_value( v: ast::Expr, units: &ast::WindowFrameUnits, ) -> Result { + use arrow::datatypes::DataType; + use datafusion_common::exec_err; match units { // For ROWS and GROUPS we are sure that the ScalarValue must be a non-negative integer ... ast::WindowFrameUnits::Rows | ast::WindowFrameUnits::Groups => match v { @@ -381,9 +389,9 @@ fn convert_frame_bound_to_scalar_value( let value = match *value { ast::Expr::Value(ValueWithSpan{value: ast::Value::SingleQuotedString(item), span: _}) => item, e => { - return sql_err!(ParserError(format!( + return exec_err!( "INTERVAL expression cannot be {e:?}" - ))); + ); } }; Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?) @@ -404,9 +412,9 @@ fn convert_frame_bound_to_scalar_value( let result = match *value { ast::Expr::Value(ValueWithSpan{value: ast::Value::SingleQuotedString(item), span: _}) => item, e => { - return sql_err!(ParserError(format!( + return exec_err!( "INTERVAL expression cannot be {e:?}" - ))); + ); } }; if let Some(leading_field) = leading_field { @@ -473,6 +481,7 @@ impl fmt::Display for WindowFrameUnits { } } +#[cfg(feature = "sql")] impl From for WindowFrameUnits { fn from(value: ast::WindowFrameUnits) -> Self { match value { @@ -570,9 +579,9 @@ mod tests { #[test] fn test_window_frame_bound_creation() -> Result<()> { // Unbounded - test_bound!(Rows, None, ScalarValue::Null); - test_bound!(Groups, None, ScalarValue::Null); - test_bound!(Range, None, ScalarValue::Null); + test_bound!(Rows, None, ScalarValue::UInt64(None)); + test_bound!(Groups, None, ScalarValue::UInt64(None)); + test_bound!(Range, None, ScalarValue::UInt64(None)); // Number let number = Some(Box::new(ast::Expr::Value( diff --git a/datafusion/expr/src/window_state.rs b/datafusion/expr/src/window_state.rs index f1d0ead23ab19..cdfb18ee1ddd7 100644 --- a/datafusion/expr/src/window_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -28,13 +28,13 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion_common::{ - internal_err, + internal_datafusion_err, internal_err, utils::{compare_rows, get_row_at_idx, search_in_slice}, - DataFusionError, Result, ScalarValue, + Result, ScalarValue, }; /// Holds the state of evaluating a window function -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct WindowAggState { /// The range that we calculate the window function pub window_frame_range: Range, @@ -90,7 +90,12 @@ impl WindowAggState { partition_batch_state: &PartitionBatchState, ) -> Result<()> { self.last_calculated_index += out_col.len(); - self.out_col = concat(&[&self.out_col, &out_col])?; + // no need to use concat if the current `out_col` is empty + if self.out_col.is_empty() { + self.out_col = Arc::clone(out_col); + } else { + self.out_col = concat(&[&self.out_col, &out_col])?; + } self.n_row_result_missing = partition_batch_state.record_batch.num_rows() - self.last_calculated_index; self.is_end = partition_batch_state.is_end; @@ -112,7 +117,7 @@ impl WindowAggState { } /// This object stores the window frame state for use in incremental calculations. -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum WindowFrameContext { /// ROWS frames are inherently stateless. Rows(Arc), @@ -193,11 +198,7 @@ impl WindowFrameContext { // UNBOUNDED PRECEDING WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0, WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { - if idx >= n as usize { - idx - n as usize - } else { - 0 - } + idx.saturating_sub(n as usize) } WindowFrameBound::CurrentRow => idx, // UNBOUNDED FOLLOWING @@ -211,7 +212,7 @@ impl WindowFrameContext { } // ERRONEOUS FRAMES WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { - return internal_err!("Rows should be Uint") + return internal_err!("Rows should be UInt64") } }; let end = match window_frame.end_bound { @@ -236,7 +237,7 @@ impl WindowFrameContext { } // ERRONEOUS FRAMES WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { - return internal_err!("Rows should be Uint") + return internal_err!("Rows should be UInt64") } }; Ok(Range { start, end }) @@ -244,7 +245,7 @@ impl WindowFrameContext { } /// State for each unique partition determined according to PARTITION BY column(s) -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq)] pub struct PartitionBatchState { /// The record batch belonging to current partition pub record_batch: RecordBatch, @@ -269,6 +270,15 @@ impl PartitionBatchState { } } + pub fn new_with_batch(batch: RecordBatch) -> Self { + Self { + record_batch: batch, + most_recent_row: None, + is_end: false, + n_out_row: 0, + } + } + pub fn extend(&mut self, batch: &RecordBatch) -> Result<()> { self.record_batch = concat_batches(&self.record_batch.schema(), [&self.record_batch, batch])?; @@ -286,7 +296,7 @@ impl PartitionBatchState { /// ranges of data while processing RANGE frames. /// Attribute `sort_options` stores the column ordering specified by the ORDER /// BY clause. This information is used to calculate the range. -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct WindowFrameStateRange { sort_options: Vec, } @@ -392,8 +402,8 @@ impl WindowFrameStateRange { .sort_options .first() .ok_or_else(|| { - DataFusionError::Internal( - "Sort options unexpectedly absent in a window frame".to_string(), + internal_datafusion_err!( + "Sort options unexpectedly absent in a window frame" ) })? .descending; @@ -458,7 +468,7 @@ impl WindowFrameStateRange { /// This structure encapsulates all the state information we require as we /// scan groups of data while processing window frames. -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct WindowFrameStateGroups { /// A tuple containing group values and the row index where the group ends. /// Example: [[1, 1], [1, 1], [2, 1], [2, 1], ...] would correspond to @@ -602,11 +612,7 @@ impl WindowFrameStateGroups { // Find the group index of the frame boundary: let group_idx = if SEARCH_SIDE { - if self.current_group_idx > delta { - self.current_group_idx - delta - } else { - 0 - } + self.current_group_idx.saturating_sub(delta) } else { self.current_group_idx + delta }; @@ -683,9 +689,9 @@ mod tests { (range_columns, sort_options) } - fn assert_expected( - expected_results: Vec<(Range, usize)>, + fn assert_group_ranges( window_frame: &Arc, + expected_results: Vec<(Range, usize)>, ) -> Result<()> { let mut window_frame_groups = WindowFrameStateGroups::default(); let (range_columns, _) = get_test_data(); @@ -705,6 +711,136 @@ mod tests { Ok(()) } + fn assert_frame_ranges( + window_frame: &Arc, + expected_results: Vec>, + ) -> Result<()> { + let mut window_frame_context = + WindowFrameContext::new(Arc::clone(window_frame), vec![]); + let (range_columns, _) = get_test_data(); + let n_row = range_columns[0].len(); + let mut last_range = Range { start: 0, end: 0 }; + for (idx, expected_range) in expected_results.into_iter().enumerate() { + let range = window_frame_context.calculate_range( + &range_columns, + &last_range, + n_row, + idx, + )?; + assert_eq!(range, expected_range); + last_range = range; + } + Ok(()) + } + + #[test] + fn test_default_window_frame_group_boundaries() -> Result<()> { + let window_frame = Arc::new(WindowFrame::new(None)); + assert_group_ranges( + &window_frame, + vec![ + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + ], + )?; + + assert_frame_ranges( + &window_frame, + vec![ + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + ], + )?; + + Ok(()) + } + + #[test] + fn test_unordered_window_frame_group_boundaries() -> Result<()> { + let window_frame = Arc::new(WindowFrame::new(Some(false))); + assert_group_ranges( + &window_frame, + vec![ + (Range { start: 0, end: 1 }, 0), + (Range { start: 0, end: 2 }, 1), + (Range { start: 0, end: 4 }, 2), + (Range { start: 0, end: 4 }, 2), + (Range { start: 0, end: 5 }, 3), + (Range { start: 0, end: 8 }, 4), + (Range { start: 0, end: 8 }, 4), + (Range { start: 0, end: 8 }, 4), + (Range { start: 0, end: 9 }, 5), + ], + )?; + + assert_frame_ranges( + &window_frame, + vec![ + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + ], + )?; + + Ok(()) + } + + #[test] + fn test_ordered_window_frame_group_boundaries() -> Result<()> { + let window_frame = Arc::new(WindowFrame::new(Some(true))); + assert_group_ranges( + &window_frame, + vec![ + (Range { start: 0, end: 1 }, 0), + (Range { start: 0, end: 2 }, 1), + (Range { start: 0, end: 4 }, 2), + (Range { start: 0, end: 4 }, 2), + (Range { start: 0, end: 5 }, 3), + (Range { start: 0, end: 8 }, 4), + (Range { start: 0, end: 8 }, 4), + (Range { start: 0, end: 8 }, 4), + (Range { start: 0, end: 9 }, 5), + ], + )?; + + assert_frame_ranges( + &window_frame, + vec![ + Range { start: 0, end: 1 }, + Range { start: 0, end: 2 }, + Range { start: 0, end: 3 }, + Range { start: 0, end: 4 }, + Range { start: 0, end: 5 }, + Range { start: 0, end: 6 }, + Range { start: 0, end: 7 }, + Range { start: 0, end: 8 }, + Range { start: 0, end: 9 }, + ], + )?; + + Ok(()) + } + #[test] fn test_window_frame_group_boundaries() -> Result<()> { let window_frame = Arc::new(WindowFrame::new_bounds( @@ -712,18 +848,20 @@ mod tests { WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))), WindowFrameBound::Following(ScalarValue::UInt64(Some(1))), )); - let expected_results = vec![ - (Range { start: 0, end: 2 }, 0), - (Range { start: 0, end: 4 }, 1), - (Range { start: 1, end: 5 }, 2), - (Range { start: 1, end: 5 }, 2), - (Range { start: 2, end: 8 }, 3), - (Range { start: 4, end: 9 }, 4), - (Range { start: 4, end: 9 }, 4), - (Range { start: 4, end: 9 }, 4), - (Range { start: 5, end: 9 }, 5), - ]; - assert_expected(expected_results, &window_frame) + assert_group_ranges( + &window_frame, + vec![ + (Range { start: 0, end: 2 }, 0), + (Range { start: 0, end: 4 }, 1), + (Range { start: 1, end: 5 }, 2), + (Range { start: 1, end: 5 }, 2), + (Range { start: 2, end: 8 }, 3), + (Range { start: 4, end: 9 }, 4), + (Range { start: 4, end: 9 }, 4), + (Range { start: 4, end: 9 }, 4), + (Range { start: 5, end: 9 }, 5), + ], + ) } #[test] @@ -733,18 +871,20 @@ mod tests { WindowFrameBound::Following(ScalarValue::UInt64(Some(1))), WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), )); - let expected_results = vec![ - (Range:: { start: 1, end: 4 }, 0), - (Range:: { start: 2, end: 5 }, 1), - (Range:: { start: 4, end: 8 }, 2), - (Range:: { start: 4, end: 8 }, 2), - (Range:: { start: 5, end: 9 }, 3), - (Range:: { start: 8, end: 9 }, 4), - (Range:: { start: 8, end: 9 }, 4), - (Range:: { start: 8, end: 9 }, 4), - (Range:: { start: 9, end: 9 }, 5), - ]; - assert_expected(expected_results, &window_frame) + assert_group_ranges( + &window_frame, + vec![ + (Range:: { start: 1, end: 4 }, 0), + (Range:: { start: 2, end: 5 }, 1), + (Range:: { start: 4, end: 8 }, 2), + (Range:: { start: 4, end: 8 }, 2), + (Range:: { start: 5, end: 9 }, 3), + (Range:: { start: 8, end: 9 }, 4), + (Range:: { start: 8, end: 9 }, 4), + (Range:: { start: 8, end: 9 }, 4), + (Range:: { start: 9, end: 9 }, 5), + ], + ) } #[test] @@ -754,17 +894,19 @@ mod tests { WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))), )); - let expected_results = vec![ - (Range:: { start: 0, end: 0 }, 0), - (Range:: { start: 0, end: 1 }, 1), - (Range:: { start: 0, end: 2 }, 2), - (Range:: { start: 0, end: 2 }, 2), - (Range:: { start: 1, end: 4 }, 3), - (Range:: { start: 2, end: 5 }, 4), - (Range:: { start: 2, end: 5 }, 4), - (Range:: { start: 2, end: 5 }, 4), - (Range:: { start: 4, end: 8 }, 5), - ]; - assert_expected(expected_results, &window_frame) + assert_group_ranges( + &window_frame, + vec![ + (Range:: { start: 0, end: 0 }, 0), + (Range:: { start: 0, end: 1 }, 1), + (Range:: { start: 0, end: 2 }, 2), + (Range:: { start: 0, end: 2 }, 2), + (Range:: { start: 1, end: 4 }, 3), + (Range:: { start: 2, end: 5 }, 4), + (Range:: { start: 2, end: 5 }, 4), + (Range:: { start: 2, end: 5 }, 4), + (Range:: { start: 4, end: 8 }, 5), + ], + ) } } diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index 5c80c1b042256..babfe28ad5576 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -40,14 +40,18 @@ crate-type = ["cdylib", "rlib"] [dependencies] abi_stable = "0.11.3" arrow = { workspace = true, features = ["ffi"] } +arrow-schema = { workspace = true } async-ffi = { version = "0.5.0", features = ["abi_stable"] } async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } +datafusion-common = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } datafusion-proto = { workspace = true } +datafusion-proto-common = { workspace = true } futures = { workspace = true } log = { workspace = true } prost = { workspace = true } -semver = "1.0.26" +semver = "1.0.27" tokio = { workspace = true } [dev-dependencies] @@ -55,3 +59,4 @@ doc-comment = { workspace = true } [features] integration-tests = [] +tarpaulin_include = [] # Exists only to prevent warnings on stable and still have accurate coverage diff --git a/datafusion/ffi/README.md b/datafusion/ffi/README.md index 48283f4cfdc14..72070984f9315 100644 --- a/datafusion/ffi/README.md +++ b/datafusion/ffi/README.md @@ -17,10 +17,10 @@ under the License. --> -# `datafusion-ffi`: Apache DataFusion Foreign Function Interface +# Apache DataFusion Foreign Function Interface -This crate contains code to allow interoperability of Apache [DataFusion] with -functions from other libraries and/or [DataFusion] versions using a stable +This crate contains code to allow interoperability of [Apache DataFusion] with +functions from other libraries and/or DataFusion versions using a stable interface. One of the limitations of the Rust programming language is that there is no @@ -28,10 +28,10 @@ stable [Rust ABI] (Application Binary Interface). If a library is compiled with one version of the Rust compiler and you attempt to use that library with a program compiled by a different Rust compiler, there is no guarantee that you can access the data structures. In order to share code between libraries loaded -at runtime, you need to use Rust's [FFI](Foreign Function Interface (FFI)). +at runtime, you need to use Rust's [FFI] (Foreign Function Interface (FFI)). -The purpose of this crate is to define interfaces between [DataFusion] libraries -that will remain stable across different versions of [DataFusion]. This allows +The purpose of this crate is to define interfaces between DataFusion libraries +that will remain stable across different versions of DataFusion. This allows users to write libraries that can interface between each other at runtime rather than require compiling all of the code into a single executable. @@ -46,7 +46,7 @@ See [API Docs] for details and examples. Two use cases have been identified for this crate, but they are not intended to be all inclusive. -1. `datafusion-python` which will use the FFI to provide external services such +1. [`datafusion-python`] which will use the FFI to provide external services such as a `TableProvider` without needing to re-export the entire `datafusion-python` code base. With `datafusion-ffi` these packages do not need `datafusion-python` as a dependency at all. @@ -68,8 +68,8 @@ stable interfaces that closely mirror the Rust native approach. To learn more about this approach see the [abi_stable] and [async-ffi] crates. If you have a library in another language that you wish to interface to -[DataFusion] the recommendation is to create a Rust wrapper crate to interface -with your library and then to connect it to [DataFusion] using this crate. +DataFusion the recommendation is to create a Rust wrapper crate to interface +with your library and then to connect it to DataFusion using this crate. Alternatively, you could use [bindgen] to interface directly to the [FFI] provided by this crate, but that is currently not supported. @@ -101,12 +101,12 @@ In this crate we have a variety of structs which closely mimic the behavior of their internal counterparts. To see detailed notes about how to use them, see the example in `FFI_TableProvider`. -[datafusion]: https://datafusion.apache.org +[apache datafusion]: https://datafusion.apache.org/ [api docs]: http://docs.rs/datafusion-ffi/latest [rust abi]: https://doc.rust-lang.org/reference/abi.html [ffi]: https://doc.rust-lang.org/nomicon/ffi.html [abi_stable]: https://crates.io/crates/abi_stable [async-ffi]: https://crates.io/crates/async-ffi [bindgen]: https://crates.io/crates/bindgen -[datafusion-python]: https://datafusion.apache.org/python/ +[`datafusion-python`]: https://datafusion.apache.org/python/ [datafusion-contrib]: https://github.com/datafusion-contrib diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs index a18e6df59bf12..7b3751dcae823 100644 --- a/datafusion/ffi/src/arrow_wrappers.rs +++ b/datafusion/ffi/src/arrow_wrappers.rs @@ -21,7 +21,8 @@ use abi_stable::StableAbi; use arrow::{ array::{make_array, ArrayRef}, datatypes::{Schema, SchemaRef}, - ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, + error::ArrowError, + ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema}, }; use log::error; @@ -36,7 +37,7 @@ impl From for WrappedSchema { let ffi_schema = match FFI_ArrowSchema::try_from(value.as_ref()) { Ok(s) => s, Err(e) => { - error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {}", e); + error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {e}"); FFI_ArrowSchema::empty() } }; @@ -44,16 +45,19 @@ impl From for WrappedSchema { WrappedSchema(ffi_schema) } } +/// Some functions are expected to always succeed, like getting the schema from a TableProvider. +/// Since going through the FFI always has the potential to fail, we need to catch these errors, +/// give the user a warning, and return some kind of result. In this case we default to an +/// empty schema. +#[cfg(not(tarpaulin_include))] +fn catch_df_schema_error(e: ArrowError) -> Schema { + error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {e}"); + Schema::empty() +} impl From for SchemaRef { fn from(value: WrappedSchema) -> Self { - let schema = match Schema::try_from(&value.0) { - Ok(s) => s, - Err(e) => { - error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {}", e); - Schema::empty() - } - }; + let schema = Schema::try_from(&value.0).unwrap_or_else(catch_df_schema_error); Arc::new(schema) } } @@ -71,7 +75,7 @@ pub struct WrappedArray { } impl TryFrom for ArrayRef { - type Error = arrow::error::ArrowError; + type Error = ArrowError; fn try_from(value: WrappedArray) -> Result { let data = unsafe { from_ffi(value.array, &value.schema.0)? }; @@ -79,3 +83,14 @@ impl TryFrom for ArrayRef { Ok(make_array(data)) } } + +impl TryFrom<&ArrayRef> for WrappedArray { + type Error = ArrowError; + + fn try_from(array: &ArrayRef) -> Result { + let (array, schema) = to_ffi(&array.to_data())?; + let schema = WrappedSchema(schema); + + Ok(WrappedArray { array, schema }) + } +} diff --git a/datafusion/ffi/src/catalog_provider.rs b/datafusion/ffi/src/catalog_provider.rs index 0886d4749d723..65dcab34f17d0 100644 --- a/datafusion/ffi/src/catalog_provider.rs +++ b/datafusion/ffi/src/catalog_provider.rs @@ -327,7 +327,7 @@ mod tests { assert!(returned_schema.is_some()); assert_eq!(foreign_catalog.schema_names().len(), 1); - // Retrieve non-existant schema + // Retrieve non-existent schema let returned_schema = foreign_catalog.schema("prior_schema"); assert!(returned_schema.is_none()); diff --git a/datafusion/ffi/src/execution_plan.rs b/datafusion/ffi/src/execution_plan.rs index 14a0908c47954..70c957d8c3733 100644 --- a/datafusion/ffi/src/execution_plan.rs +++ b/datafusion/ffi/src/execution_plan.rs @@ -205,7 +205,8 @@ impl DisplayAs for ForeignExecutionPlan { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!( f, - "FFI_ExecutionPlan(number_of_children={})", + "FFI_ExecutionPlan: {}, number_of_children={}", + self.name, self.children.len(), ) } @@ -390,7 +391,10 @@ mod tests { ); let buf = display.one_line().to_string(); - assert_eq!(buf.trim(), "FFI_ExecutionPlan(number_of_children=0)"); + assert_eq!( + buf.trim(), + "FFI_ExecutionPlan: empty-exec, number_of_children=0" + ); Ok(()) } diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index 877129fc5bb12..0c2340e8ce7b1 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -34,7 +34,10 @@ pub mod schema_provider; pub mod session_config; pub mod table_provider; pub mod table_source; +pub mod udaf; pub mod udf; +pub mod udtf; +pub mod udwf; pub mod util; pub mod volatility; diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs index 3592c16b8fab0..48c2698a58c75 100644 --- a/datafusion/ffi/src/plan_properties.rs +++ b/datafusion/ffi/src/plan_properties.rs @@ -181,6 +181,7 @@ impl TryFrom for PlanProperties { // TODO Extend FFI to get the registry and codex let default_ctx = SessionContext::new(); + let task_context = default_ctx.task_ctx(); let codex = DefaultPhysicalExtensionCodec {}; let ffi_orderings = unsafe { (ffi_props.output_ordering)(&ffi_props) }; @@ -188,12 +189,12 @@ impl TryFrom for PlanProperties { let proto_output_ordering = PhysicalSortExprNodeCollection::decode(df_result!(ffi_orderings)?.as_ref()) .map_err(|e| DataFusionError::External(Box::new(e)))?; - let orderings = Some(parse_physical_sort_exprs( + let sort_exprs = parse_physical_sort_exprs( &proto_output_ordering.physical_sort_expr_nodes, - &default_ctx, + &task_context, &schema, &codex, - )?); + )?; let partitioning_vec = unsafe { df_result!((ffi_props.output_partitioning)(&ffi_props))? }; @@ -202,7 +203,7 @@ impl TryFrom for PlanProperties { .map_err(|e| DataFusionError::External(Box::new(e)))?; let partitioning = parse_protobuf_partitioning( Some(&proto_output_partitioning), - &default_ctx, + &task_context, &schema, &codex, )? @@ -211,11 +212,10 @@ impl TryFrom for PlanProperties { .to_string(), ))?; - let eq_properties = match orderings { - Some(ordering) => { - EquivalenceProperties::new_with_orderings(Arc::new(schema), &[ordering]) - } - None => EquivalenceProperties::new(Arc::new(schema)), + let eq_properties = if sort_exprs.is_empty() { + EquivalenceProperties::new(Arc::new(schema)) + } else { + EquivalenceProperties::new_with_orderings(Arc::new(schema), [sort_exprs]) }; let emission_type: EmissionType = @@ -300,7 +300,7 @@ impl From for EmissionType { #[cfg(test)] mod tests { - use datafusion::physical_plan::Partitioning; + use datafusion::{physical_expr::PhysicalSortExpr, physical_plan::Partitioning}; use super::*; @@ -310,9 +310,13 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + let mut eqp = EquivalenceProperties::new(Arc::clone(&schema)); + let _ = eqp.reorder([PhysicalSortExpr::new_default( + datafusion::physical_plan::expressions::col("a", &schema)?, + )]); let original_props = PlanProperties::new( - EquivalenceProperties::new(schema), - Partitioning::UnknownPartitioning(3), + eqp, + Partitioning::RoundRobinBatch(3), EmissionType::Incremental, Boundedness::Bounded, ); @@ -321,7 +325,7 @@ mod tests { let foreign_props: PlanProperties = local_props_ptr.try_into()?; - assert!(format!("{:?}", foreign_props) == format!("{:?}", original_props)); + assert_eq!(format!("{foreign_props:?}"), format!("{original_props:?}")); Ok(()) } diff --git a/datafusion/ffi/src/record_batch_stream.rs b/datafusion/ffi/src/record_batch_stream.rs index 939c4050028cb..1739235d17036 100644 --- a/datafusion/ffi/src/record_batch_stream.rs +++ b/datafusion/ffi/src/record_batch_stream.rs @@ -32,6 +32,7 @@ use datafusion::{ error::DataFusionError, execution::{RecordBatchStream, SendableRecordBatchStream}, }; +use datafusion_common::{exec_datafusion_err, exec_err}; use futures::{Stream, TryStreamExt}; use tokio::runtime::Handle; @@ -57,6 +58,9 @@ pub struct FFI_RecordBatchStream { /// Return the schema of the record batch pub schema: unsafe extern "C" fn(stream: &Self) -> WrappedSchema, + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(arg: &mut Self), + /// Internal data. This is only to be accessed by the provider of the plan. /// The foreign library should never attempt to access this data. pub private_data: *mut c_void, @@ -82,6 +86,7 @@ impl FFI_RecordBatchStream { FFI_RecordBatchStream { poll_next: poll_next_fn_wrapper, schema: schema_fn_wrapper, + release: release_fn_wrapper, private_data, } } @@ -96,6 +101,12 @@ unsafe extern "C" fn schema_fn_wrapper(stream: &FFI_RecordBatchStream) -> Wrappe (*stream).schema().into() } +unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_RecordBatchStream) { + let private_data = + Box::from_raw(provider.private_data as *mut RecordBatchStreamPrivateData); + drop(private_data); +} + fn record_batch_to_wrapped_array( record_batch: RecordBatch, ) -> RResult { @@ -153,9 +164,8 @@ fn wrapped_array_to_record_batch(array: WrappedArray) -> Result { let struct_array = array .as_any() .downcast_ref::() - .ok_or(DataFusionError::Execution( - "Unexpected array type during record batch collection in FFI_RecordBatchStream" - .to_string(), + .ok_or_else(|| exec_datafusion_err!( + "Unexpected array type during record batch collection in FFI_RecordBatchStream - expected StructArray" ))?; Ok(struct_array.into()) @@ -168,9 +178,7 @@ fn maybe_wrapped_array_to_record_batch( ROption::RSome(RResult::ROk(wrapped_array)) => { Some(wrapped_array_to_record_batch(wrapped_array)) } - ROption::RSome(RResult::RErr(e)) => { - Some(Err(DataFusionError::Execution(e.to_string()))) - } + ROption::RSome(RResult::RErr(e)) => Some(exec_err!("FFI error: {e}")), ROption::RNone => None, } } @@ -190,9 +198,61 @@ impl Stream for FFI_RecordBatchStream { Poll::Ready(maybe_wrapped_array_to_record_batch(array)) } FfiPoll::Pending => Poll::Pending, - FfiPoll::Panicked => Poll::Ready(Some(Err(DataFusionError::Execution( - "Error occurred during poll_next on FFI_RecordBatchStream".to_string(), - )))), + FfiPoll::Panicked => Poll::Ready(Some(exec_err!( + "Panic occurred during poll_next on FFI_RecordBatchStream" + ))), } } } + +impl Drop for FFI_RecordBatchStream { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::{ + common::record_batch, error::Result, execution::SendableRecordBatchStream, + test_util::bounded_stream, + }; + + use super::FFI_RecordBatchStream; + use futures::StreamExt; + + #[tokio::test] + async fn test_round_trip_record_batch_stream() -> Result<()> { + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 3]), + ("b", Float64, vec![Some(4.0), None, Some(5.0)]) + )?; + let original_rbs = bounded_stream(record_batch.clone(), 1); + + let ffi_rbs: FFI_RecordBatchStream = original_rbs.into(); + let mut ffi_rbs: SendableRecordBatchStream = Box::pin(ffi_rbs); + + let schema = ffi_rbs.schema(); + assert_eq!( + schema, + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, true) + ])) + ); + + let batch = ffi_rbs.next().await; + assert!(batch.is_some()); + assert!(batch.as_ref().unwrap().is_ok()); + assert_eq!(batch.unwrap().unwrap(), record_batch); + + // There should only be one batch + let no_batch = ffi_rbs.next().await; + assert!(no_batch.is_none()); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/schema_provider.rs b/datafusion/ffi/src/schema_provider.rs index 6e5a590e1a09d..b5970d5881d6e 100644 --- a/datafusion/ffi/src/schema_provider.rs +++ b/datafusion/ffi/src/schema_provider.rs @@ -366,7 +366,7 @@ mod tests { assert!(returned_schema.is_some()); assert_eq!(foreign_schema_provider.table_names().len(), 1); - // Retrieve non-existant table + // Retrieve non-existent table let returned_schema = foreign_schema_provider .table("prior_table") .await diff --git a/datafusion/ffi/src/session_config.rs b/datafusion/ffi/src/session_config.rs index aea03cf94e0af..a07b66c601962 100644 --- a/datafusion/ffi/src/session_config.rs +++ b/datafusion/ffi/src/session_config.rs @@ -15,17 +15,17 @@ // specific language governing permissions and limitations // under the License. -use std::{ - collections::HashMap, - ffi::{c_char, c_void, CString}, -}; - use abi_stable::{ std_types::{RHashMap, RString}, StableAbi, }; use datafusion::{config::ConfigOptions, error::Result}; use datafusion::{error::DataFusionError, prelude::SessionConfig}; +use std::sync::Arc; +use std::{ + collections::HashMap, + ffi::{c_char, c_void, CString}, +}; /// A stable struct for sharing [`SessionConfig`] across FFI boundaries. /// Instead of attempting to expose the entire SessionConfig interface, we @@ -85,11 +85,9 @@ unsafe extern "C" fn release_fn_wrapper(config: &mut FFI_SessionConfig) { unsafe extern "C" fn clone_fn_wrapper(config: &FFI_SessionConfig) -> FFI_SessionConfig { let old_private_data = config.private_data as *mut SessionConfigPrivateData; - let old_config = &(*old_private_data).config; + let old_config = Arc::clone(&(*old_private_data).config); - let private_data = Box::new(SessionConfigPrivateData { - config: old_config.clone(), - }); + let private_data = Box::new(SessionConfigPrivateData { config: old_config }); FFI_SessionConfig { config_options: config_options_fn_wrapper, @@ -100,7 +98,7 @@ unsafe extern "C" fn clone_fn_wrapper(config: &FFI_SessionConfig) -> FFI_Session } struct SessionConfigPrivateData { - pub config: ConfigOptions, + pub config: Arc, } impl From<&SessionConfig> for FFI_SessionConfig { @@ -120,7 +118,7 @@ impl From<&SessionConfig> for FFI_SessionConfig { } let private_data = Box::new(SessionConfigPrivateData { - config: session.options().clone(), + config: Arc::clone(session.options()), }); Self { diff --git a/datafusion/ffi/src/table_provider.rs b/datafusion/ffi/src/table_provider.rs index a7391a85031e0..890511997a706 100644 --- a/datafusion/ffi/src/table_provider.rs +++ b/datafusion/ffi/src/table_provider.rs @@ -110,8 +110,8 @@ pub struct FFI_TableProvider { /// * `session_config` - session configuration /// * `projections` - if specified, only a subset of the columns are returned /// * `filters_serialized` - filters to apply to the scan, which are a - /// [`LogicalExprList`] protobuf message serialized into bytes to pass - /// across the FFI boundary. + /// [`LogicalExprList`] protobuf message serialized into bytes to pass + /// across the FFI boundary. /// * `limit` - if specified, limit the number of rows returned pub scan: unsafe extern "C" fn( provider: &Self, @@ -259,14 +259,10 @@ unsafe extern "C" fn scan_fn_wrapper( }; let projections: Vec<_> = projections.into_iter().collect(); - let maybe_projections = match projections.is_empty() { - true => None, - false => Some(&projections), - }; let plan = rresult_return!( internal_provider - .scan(&ctx.state(), maybe_projections, &filters, limit.into()) + .scan(&ctx.state(), Some(&projections), &filters, limit.into()) .await ); @@ -600,4 +596,49 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_aggregation() -> Result<()> { + use arrow::datatypes::Field; + use datafusion::arrow::{ + array::Float32Array, datatypes::DataType, record_batch::RecordBatch, + }; + use datafusion::common::assert_batches_eq; + use datafusion::datasource::MemTable; + + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + + // define data in two partitions + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))], + )?; + + let ctx = SessionContext::new(); + + let provider = Arc::new(MemTable::try_new(schema, vec![vec![batch1]])?); + + let ffi_provider = FFI_TableProvider::new(provider, true, None); + + let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into(); + + ctx.register_table("t", Arc::new(foreign_table_provider))?; + + let result = ctx + .sql("SELECT COUNT(*) as cnt FROM t") + .await? + .collect() + .await?; + #[rustfmt::skip] + let expected = [ + "+-----+", + "| cnt |", + "+-----+", + "| 3 |", + "+-----+" + ]; + assert_batches_eq!(expected, &result); + Ok(()) + } } diff --git a/datafusion/ffi/src/tests/async_provider.rs b/datafusion/ffi/src/tests/async_provider.rs index cf05d596308f3..cef4161d8c1fc 100644 --- a/datafusion/ffi/src/tests/async_provider.rs +++ b/datafusion/ffi/src/tests/async_provider.rs @@ -33,12 +33,13 @@ use arrow::datatypes::Schema; use async_trait::async_trait; use datafusion::{ catalog::{Session, TableProvider}, - error::{DataFusionError, Result}, + error::Result, execution::RecordBatchStream, physical_expr::EquivalenceProperties, physical_plan::{ExecutionPlan, Partitioning}, prelude::Expr, }; +use datafusion_common::exec_err; use futures::Stream; use tokio::{ runtime::Handle, @@ -259,9 +260,9 @@ impl Stream for AsyncTestRecordBatchStream { }); if let Err(e) = this.batch_request.try_send(true) { - return std::task::Poll::Ready(Some(Err(DataFusionError::Execution( - format!("Unable to send batch request, {}", e), - )))); + return std::task::Poll::Ready(Some(exec_err!( + "Failed to send batch request: {e}" + ))); } match this.batch_receiver.blocking_recv() { @@ -269,9 +270,9 @@ impl Stream for AsyncTestRecordBatchStream { Some(batch) => std::task::Poll::Ready(Some(Ok(batch))), None => std::task::Poll::Ready(None), }, - Err(e) => std::task::Poll::Ready(Some(Err(DataFusionError::Execution( - format!("Unable receive record batch: {}", e), - )))), + Err(e) => std::task::Poll::Ready(Some(exec_err!( + "Failed to receive record batch: {e}" + ))), } } } diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index c7a9816431e10..816086c320415 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -27,7 +27,11 @@ use abi_stable::{ }; use catalog::create_catalog_provider; -use crate::catalog_provider::FFI_CatalogProvider; +use crate::{catalog_provider::FFI_CatalogProvider, udtf::FFI_TableFunction}; + +use crate::udaf::FFI_AggregateUDF; + +use crate::udwf::FFI_WindowUDF; use super::{table_provider::FFI_TableProvider, udf::FFI_ScalarUDF}; use arrow::array::RecordBatch; @@ -37,7 +41,10 @@ use datafusion::{ common::record_batch, }; use sync_provider::create_sync_table_provider; -use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_random_func}; +use udf_udaf_udwf::{ + create_ffi_abs_func, create_ffi_random_func, create_ffi_rank_func, + create_ffi_stddev_func, create_ffi_sum_func, create_ffi_table_func, +}; mod async_provider; pub mod catalog; @@ -63,6 +70,16 @@ pub struct ForeignLibraryModule { pub create_nullary_udf: extern "C" fn() -> FFI_ScalarUDF, + pub create_table_function: extern "C" fn() -> FFI_TableFunction, + + /// Create an aggregate UDAF using sum + pub create_sum_udaf: extern "C" fn() -> FFI_AggregateUDF, + + /// Create grouping UDAF using stddev + pub create_stddev_udaf: extern "C" fn() -> FFI_AggregateUDF, + + pub create_rank_udwf: extern "C" fn() -> FFI_WindowUDF, + pub version: extern "C" fn() -> u64, } @@ -109,6 +126,10 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { create_table: construct_table_provider, create_scalar_udf: create_ffi_abs_func, create_nullary_udf: create_ffi_random_func, + create_table_function: create_ffi_table_func, + create_sum_udaf: create_ffi_sum_func, + create_stddev_udaf: create_ffi_stddev_func, + create_rank_udwf: create_ffi_rank_func, version: super::version, } .leak_into_prefix() diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index b40bec762bd71..55e31ef3ab770 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -15,10 +15,17 @@ // specific language governing permissions and limitations // under the License. -use crate::udf::FFI_ScalarUDF; +use crate::{ + udaf::FFI_AggregateUDF, udf::FFI_ScalarUDF, udtf::FFI_TableFunction, + udwf::FFI_WindowUDF, +}; use datafusion::{ + catalog::TableFunctionImpl, functions::math::{abs::AbsFunc, random::RandomFunc}, - logical_expr::ScalarUDF, + functions_aggregate::{stddev::Stddev, sum::Sum}, + functions_table::generate_series::RangeFunc, + functions_window::rank::Rank, + logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}, }; use std::sync::Arc; @@ -34,3 +41,33 @@ pub(crate) extern "C" fn create_ffi_random_func() -> FFI_ScalarUDF { udf.into() } + +pub(crate) extern "C" fn create_ffi_table_func() -> FFI_TableFunction { + let udtf: Arc = Arc::new(RangeFunc {}); + + FFI_TableFunction::new(udtf, None) +} + +pub(crate) extern "C" fn create_ffi_sum_func() -> FFI_AggregateUDF { + let udaf: Arc = Arc::new(Sum::new().into()); + + udaf.into() +} + +pub(crate) extern "C" fn create_ffi_stddev_func() -> FFI_AggregateUDF { + let udaf: Arc = Arc::new(Stddev::new().into()); + + udaf.into() +} + +pub(crate) extern "C" fn create_ffi_rank_func() -> FFI_WindowUDF { + let udwf: Arc = Arc::new( + Rank::new( + "rank_demo".to_string(), + datafusion::functions_window::rank::RankType::Basic, + ) + .into(), + ); + + udwf.into() +} diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs new file mode 100644 index 0000000000000..80b872159f483 --- /dev/null +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -0,0 +1,366 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, ops::Deref}; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::{array::ArrayRef, error::ArrowError}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::Accumulator, + scalar::ScalarValue, +}; +use prost::Message; + +use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; + +/// A stable struct for sharing [`Accumulator`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`Accumulator`]. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_Accumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + ) -> RResult<(), RString>, + + // Evaluate and return a ScalarValues as protobuf bytes + pub evaluate: + unsafe extern "C" fn(accumulator: &mut Self) -> RResult, RString>, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: + unsafe extern "C" fn(accumulator: &mut Self) -> RResult>, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + states: RVec, + ) -> RResult<(), RString>, + + pub retract_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + ) -> RResult<(), RString>, + + pub supports_retract_batch: bool, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. + /// A [`ForeignAccumulator`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_Accumulator {} +unsafe impl Sync for FFI_Accumulator {} + +pub struct AccumulatorPrivateData { + pub accumulator: Box, +} + +impl FFI_Accumulator { + #[inline] + unsafe fn inner_mut(&mut self) -> &mut Box { + let private_data = self.private_data as *mut AccumulatorPrivateData; + &mut (*private_data).accumulator + } + + #[inline] + unsafe fn inner(&self) -> &dyn Accumulator { + let private_data = self.private_data as *const AccumulatorPrivateData; + (*private_data).accumulator.deref() + } +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + values: RVec, +) -> RResult<(), RString> { + let accumulator = accumulator.inner_mut(); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + rresult!(accumulator.update_batch(&values_arrays)) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + accumulator: &mut FFI_Accumulator, +) -> RResult, RString> { + let accumulator = accumulator.inner_mut(); + + let scalar_result = rresult_return!(accumulator.evaluate()); + let proto_result: datafusion_proto::protobuf::ScalarValue = + rresult_return!((&scalar_result).try_into()); + + RResult::ROk(proto_result.encode_to_vec().into()) +} + +unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize { + accumulator.inner().size() +} + +unsafe extern "C" fn state_fn_wrapper( + accumulator: &mut FFI_Accumulator, +) -> RResult>, RString> { + let accumulator = accumulator.inner_mut(); + + let state = rresult_return!(accumulator.state()); + let state = state + .into_iter() + .map(|state_val| { + datafusion_proto::protobuf::ScalarValue::try_from(&state_val) + .map_err(DataFusionError::from) + .map(|v| RVec::from(v.encode_to_vec())) + }) + .collect::>>() + .map(|state_vec| state_vec.into()); + + rresult!(state) +} + +unsafe extern "C" fn merge_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + states: RVec, +) -> RResult<(), RString> { + let accumulator = accumulator.inner_mut(); + + let states = rresult_return!(states + .into_iter() + .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) + .collect::>>()); + + rresult!(accumulator.merge_batch(&states)) +} + +unsafe extern "C" fn retract_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + values: RVec, +) -> RResult<(), RString> { + let accumulator = accumulator.inner_mut(); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + rresult!(accumulator.retract_batch(&values_arrays)) +} + +unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { + let private_data = + Box::from_raw(accumulator.private_data as *mut AccumulatorPrivateData); + drop(private_data); +} + +impl From> for FFI_Accumulator { + fn from(accumulator: Box) -> Self { + let supports_retract_batch = accumulator.supports_retract_batch(); + let private_data = AccumulatorPrivateData { accumulator }; + + Self { + update_batch: update_batch_fn_wrapper, + evaluate: evaluate_fn_wrapper, + size: size_fn_wrapper, + state: state_fn_wrapper, + merge_batch: merge_batch_fn_wrapper, + retract_batch: retract_batch_fn_wrapper, + supports_retract_batch, + release: release_fn_wrapper, + private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + } + } +} + +impl Drop for FFI_Accumulator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignAccumulator is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_Accumulator. +#[derive(Debug)] +pub struct ForeignAccumulator { + accumulator: FFI_Accumulator, +} + +unsafe impl Send for ForeignAccumulator {} +unsafe impl Sync for ForeignAccumulator {} + +impl From for ForeignAccumulator { + fn from(accumulator: FFI_Accumulator) -> Self { + Self { accumulator } + } +} + +impl Accumulator for ForeignAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.update_batch)( + &mut self.accumulator, + values.into() + )) + } + } + + fn evaluate(&mut self) -> Result { + unsafe { + let scalar_bytes = + df_result!((self.accumulator.evaluate)(&mut self.accumulator))?; + + let proto_scalar = + datafusion_proto::protobuf::ScalarValue::decode(scalar_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + ScalarValue::try_from(&proto_scalar).map_err(DataFusionError::from) + } + } + + fn size(&self) -> usize { + unsafe { (self.accumulator.size)(&self.accumulator) } + } + + fn state(&mut self) -> Result> { + unsafe { + let state_protos = + df_result!((self.accumulator.state)(&mut self.accumulator))?; + + state_protos + .into_iter() + .map(|proto_bytes| { + datafusion_proto::protobuf::ScalarValue::decode(proto_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e))) + .and_then(|proto_value| { + ScalarValue::try_from(&proto_value) + .map_err(DataFusionError::from) + }) + }) + .collect::>>() + } + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + unsafe { + let states = states + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.merge_batch)( + &mut self.accumulator, + states.into() + )) + } + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.retract_batch)( + &mut self.accumulator, + values.into() + )) + } + } + + fn supports_retract_batch(&self) -> bool { + self.accumulator.supports_retract_batch + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{make_array, Array}; + use datafusion::{ + common::create_array, error::Result, + functions_aggregate::average::AvgAccumulator, logical_expr::Accumulator, + scalar::ScalarValue, + }; + + use super::{FFI_Accumulator, ForeignAccumulator}; + + #[test] + fn test_foreign_avg_accumulator() -> Result<()> { + let original_accum = AvgAccumulator::default(); + let original_size = original_accum.size(); + let original_supports_retract = original_accum.supports_retract_batch(); + + let boxed_accum: Box = Box::new(original_accum); + let ffi_accum: FFI_Accumulator = boxed_accum.into(); + let mut foreign_accum: ForeignAccumulator = ffi_accum.into(); + + // Send in an array to average. There are 5 values and it should average to 30.0 + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + foreign_accum.update_batch(&[values])?; + + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(30.0))); + + let state = foreign_accum.state()?; + assert_eq!(state.len(), 2); + assert_eq!(state[0], ScalarValue::UInt64(Some(5))); + assert_eq!(state[1], ScalarValue::Float64(Some(150.0))); + + // To verify merging batches works, create a second state to add in + // This should cause our average to go down to 25.0 + let second_states = vec![ + make_array(create_array!(UInt64, vec![1]).to_data()), + make_array(create_array!(Float64, vec![0.0]).to_data()), + ]; + + foreign_accum.merge_batch(&second_states)?; + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(25.0))); + + // If we remove a batch that is equivalent to the state we added + // we should go back to our original value of 30.0 + let values = create_array!(Float64, vec![0.0]); + foreign_accum.retract_batch(&[values])?; + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(30.0))); + + assert_eq!(original_size, foreign_accum.size()); + assert_eq!( + original_supports_retract, + foreign_accum.supports_retract_batch() + ); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs new file mode 100644 index 0000000000000..0302c26a2e6b5 --- /dev/null +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -0,0 +1,199 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::arrow_wrappers::WrappedSchema; +use abi_stable::{ + std_types::{RString, RVec}, + StableAbi, +}; +use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema}; +use arrow_schema::FieldRef; +use datafusion::{ + error::DataFusionError, + logical_expr::function::AccumulatorArgs, + physical_expr::{PhysicalExpr, PhysicalSortExpr}, + prelude::SessionContext, +}; +use datafusion_common::exec_datafusion_err; +use datafusion_proto::{ + physical_plan::{ + from_proto::{parse_physical_exprs, parse_physical_sort_exprs}, + to_proto::{serialize_physical_exprs, serialize_physical_sort_exprs}, + DefaultPhysicalExtensionCodec, + }, + protobuf::PhysicalAggregateExprNode, +}; +use prost::Message; + +/// A stable struct for sharing [`AccumulatorArgs`] across FFI boundaries. +/// For an explanation of each field, see the corresponding field +/// defined in [`AccumulatorArgs`]. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AccumulatorArgs { + return_field: WrappedSchema, + schema: WrappedSchema, + is_reversed: bool, + name: RString, + physical_expr_def: RVec, +} + +impl TryFrom> for FFI_AccumulatorArgs { + type Error = DataFusionError; + + fn try_from(args: AccumulatorArgs) -> Result { + let return_field = + WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); + let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?); + + let codec = DefaultPhysicalExtensionCodec {}; + let ordering_req = + serialize_physical_sort_exprs(args.order_bys.to_owned(), &codec)?; + + let expr = serialize_physical_exprs(args.exprs, &codec)?; + + let physical_expr_def = PhysicalAggregateExprNode { + expr, + ordering_req, + distinct: args.is_distinct, + ignore_nulls: args.ignore_nulls, + fun_definition: None, + aggregate_function: None, + human_display: args.name.to_string(), + }; + let physical_expr_def = physical_expr_def.encode_to_vec().into(); + + Ok(Self { + return_field, + schema, + is_reversed: args.is_reversed, + name: args.name.into(), + physical_expr_def, + }) + } +} + +/// This struct mirrors AccumulatorArgs except that it contains owned data. +/// It is necessary to create this struct so that we can parse the protobuf +/// data across the FFI boundary and turn it into owned data that +/// AccumulatorArgs can then reference. +pub struct ForeignAccumulatorArgs { + pub return_field: FieldRef, + pub schema: Schema, + pub ignore_nulls: bool, + pub order_bys: Vec, + pub is_reversed: bool, + pub name: String, + pub is_distinct: bool, + pub exprs: Vec>, +} + +impl TryFrom for ForeignAccumulatorArgs { + type Error = DataFusionError; + + fn try_from(value: FFI_AccumulatorArgs) -> Result { + let proto_def = PhysicalAggregateExprNode::decode( + value.physical_expr_def.as_ref(), + ) + .map_err(|e| { + exec_datafusion_err!("Failed to decode PhysicalAggregateExprNode: {e}") + })?; + + let return_field = Arc::new((&value.return_field.0).try_into()?); + let schema = Schema::try_from(&value.schema.0)?; + + let default_ctx = SessionContext::new(); + let task_ctx = default_ctx.task_ctx(); + let codex = DefaultPhysicalExtensionCodec {}; + + let order_bys = parse_physical_sort_exprs( + &proto_def.ordering_req, + &task_ctx, + &schema, + &codex, + )?; + + let exprs = parse_physical_exprs(&proto_def.expr, &task_ctx, &schema, &codex)?; + + Ok(Self { + return_field, + schema, + ignore_nulls: proto_def.ignore_nulls, + order_bys, + is_reversed: value.is_reversed, + name: value.name.to_string(), + is_distinct: proto_def.distinct, + exprs, + }) + } +} + +impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { + fn from(value: &'a ForeignAccumulatorArgs) -> Self { + Self { + return_field: Arc::clone(&value.return_field), + schema: &value.schema, + ignore_nulls: value.ignore_nulls, + order_bys: &value.order_bys, + is_reversed: value.is_reversed, + name: value.name.as_str(), + is_distinct: value.is_distinct, + exprs: &value.exprs, + } + } +} + +#[cfg(test)] +mod tests { + use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::{ + error::Result, logical_expr::function::AccumulatorArgs, + physical_expr::PhysicalSortExpr, physical_plan::expressions::col, + }; + + #[test] + fn test_round_trip_accumulator_args() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let orig_args = AccumulatorArgs { + return_field: Field::new("f", DataType::Float64, true).into(), + schema: &schema, + ignore_nulls: false, + order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)], + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + let orig_str = format!("{orig_args:?}"); + + let ffi_args: FFI_AccumulatorArgs = orig_args.try_into()?; + let foreign_args: ForeignAccumulatorArgs = ffi_args.try_into()?; + let round_trip_args: AccumulatorArgs = (&foreign_args).into(); + + let round_trip_str = format!("{round_trip_args:?}"); + + // Since AccumulatorArgs doesn't implement Eq, simply compare + // the debug strings. + assert_eq!(orig_str, round_trip_str); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs new file mode 100644 index 0000000000000..58a18c69db7c8 --- /dev/null +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -0,0 +1,513 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, ops::Deref, sync::Arc}; + +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, +}; +use abi_stable::{ + std_types::{ROption, RResult, RString, RVec}, + StableAbi, +}; +use arrow::{ + array::{Array, ArrayRef, BooleanArray}, + error::ArrowError, + ffi::to_ffi, +}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::{EmitTo, GroupsAccumulator}, +}; + +/// A stable struct for sharing [`GroupsAccumulator`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`GroupsAccumulator`]. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_GroupsAccumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, + ) -> RResult<(), RString>, + + // Evaluate and return a ScalarValues as protobuf bytes + pub evaluate: unsafe extern "C" fn( + accumulator: &mut Self, + emit_to: FFI_EmitTo, + ) -> RResult, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: unsafe extern "C" fn( + accumulator: &mut Self, + emit_to: FFI_EmitTo, + ) -> RResult, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, + ) -> RResult<(), RString>, + + pub convert_to_state: unsafe extern "C" fn( + accumulator: &Self, + values: RVec, + opt_filter: ROption, + ) + -> RResult, RString>, + + pub supports_convert_to_state: bool, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. + /// A [`ForeignGroupsAccumulator`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_GroupsAccumulator {} +unsafe impl Sync for FFI_GroupsAccumulator {} + +pub struct GroupsAccumulatorPrivateData { + pub accumulator: Box, +} + +impl FFI_GroupsAccumulator { + #[inline] + unsafe fn inner_mut(&mut self) -> &mut Box { + let private_data = self.private_data as *mut GroupsAccumulatorPrivateData; + &mut (*private_data).accumulator + } + + #[inline] + unsafe fn inner(&self) -> &dyn GroupsAccumulator { + let private_data = self.private_data as *const GroupsAccumulatorPrivateData; + (*private_data).accumulator.deref() + } +} + +fn process_values(values: RVec) -> Result>> { + values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>() +} + +/// Convert C-typed opt_filter into the internal type. +fn process_opt_filter(opt_filter: ROption) -> Result> { + opt_filter + .into_option() + .map(|filter| { + ArrayRef::try_from(filter) + .map_err(DataFusionError::from) + .map(|arr| BooleanArray::from(arr.into_data())) + }) + .transpose() +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, +) -> RResult<(), RString> { + let accumulator = accumulator.inner_mut(); + let values = rresult_return!(process_values(values)); + let group_indices: Vec = group_indices.into_iter().collect(); + let opt_filter = rresult_return!(process_opt_filter(opt_filter)); + + rresult!(accumulator.update_batch( + &values, + &group_indices, + opt_filter.as_ref(), + total_num_groups + )) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + emit_to: FFI_EmitTo, +) -> RResult { + let accumulator = accumulator.inner_mut(); + + let result = rresult_return!(accumulator.evaluate(emit_to.into())); + + rresult!(WrappedArray::try_from(&result)) +} + +unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> usize { + let accumulator = accumulator.inner(); + accumulator.size() +} + +unsafe extern "C" fn state_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + emit_to: FFI_EmitTo, +) -> RResult, RString> { + let accumulator = accumulator.inner_mut(); + + let state = rresult_return!(accumulator.state(emit_to.into())); + rresult!(state + .into_iter() + .map(|arr| WrappedArray::try_from(&arr).map_err(DataFusionError::from)) + .collect::>>()) +} + +unsafe extern "C" fn merge_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, +) -> RResult<(), RString> { + let accumulator = accumulator.inner_mut(); + let values = rresult_return!(process_values(values)); + let group_indices: Vec = group_indices.into_iter().collect(); + let opt_filter = rresult_return!(process_opt_filter(opt_filter)); + + rresult!(accumulator.merge_batch( + &values, + &group_indices, + opt_filter.as_ref(), + total_num_groups + )) +} + +unsafe extern "C" fn convert_to_state_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + values: RVec, + opt_filter: ROption, +) -> RResult, RString> { + let accumulator = accumulator.inner(); + let values = rresult_return!(process_values(values)); + let opt_filter = rresult_return!(process_opt_filter(opt_filter)); + let state = + rresult_return!(accumulator.convert_to_state(&values, opt_filter.as_ref())); + + rresult!(state + .iter() + .map(|arr| WrappedArray::try_from(arr).map_err(DataFusionError::from)) + .collect::>>()) +} + +unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) { + let private_data = + Box::from_raw(accumulator.private_data as *mut GroupsAccumulatorPrivateData); + drop(private_data); +} + +impl From> for FFI_GroupsAccumulator { + fn from(accumulator: Box) -> Self { + let supports_convert_to_state = accumulator.supports_convert_to_state(); + let private_data = GroupsAccumulatorPrivateData { accumulator }; + + Self { + update_batch: update_batch_fn_wrapper, + evaluate: evaluate_fn_wrapper, + size: size_fn_wrapper, + state: state_fn_wrapper, + merge_batch: merge_batch_fn_wrapper, + convert_to_state: convert_to_state_fn_wrapper, + supports_convert_to_state, + + release: release_fn_wrapper, + private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + } + } +} + +impl Drop for FFI_GroupsAccumulator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignGroupsAccumulator is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_GroupsAccumulator. +#[derive(Debug)] +pub struct ForeignGroupsAccumulator { + accumulator: FFI_GroupsAccumulator, +} + +unsafe impl Send for ForeignGroupsAccumulator {} +unsafe impl Sync for ForeignGroupsAccumulator {} + +impl From for ForeignGroupsAccumulator { + fn from(accumulator: FFI_GroupsAccumulator) -> Self { + Self { accumulator } + } +} + +impl GroupsAccumulator for ForeignGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + let group_indices = group_indices.iter().cloned().collect(); + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + df_result!((self.accumulator.update_batch)( + &mut self.accumulator, + values.into(), + group_indices, + opt_filter, + total_num_groups + )) + } + } + + fn size(&self) -> usize { + unsafe { (self.accumulator.size)(&self.accumulator) } + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + unsafe { + let return_array = df_result!((self.accumulator.evaluate)( + &mut self.accumulator, + emit_to.into() + ))?; + + return_array.try_into().map_err(DataFusionError::from) + } + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + unsafe { + let returned_arrays = df_result!((self.accumulator.state)( + &mut self.accumulator, + emit_to.into() + ))?; + + returned_arrays + .into_iter() + .map(|wrapped_array| { + wrapped_array.try_into().map_err(DataFusionError::from) + }) + .collect::>>() + } + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + let group_indices = group_indices.iter().cloned().collect(); + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + df_result!((self.accumulator.merge_batch)( + &mut self.accumulator, + values.into(), + group_indices, + opt_filter, + total_num_groups + )) + } + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + let returned_array = df_result!((self.accumulator.convert_to_state)( + &self.accumulator, + values, + opt_filter + ))?; + + returned_array + .into_iter() + .map(|arr| arr.try_into().map_err(DataFusionError::from)) + .collect() + } + } + + fn supports_convert_to_state(&self) -> bool { + self.accumulator.supports_convert_to_state + } +} + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_EmitTo { + All, + First(usize), +} + +impl From for FFI_EmitTo { + fn from(value: EmitTo) -> Self { + match value { + EmitTo::All => Self::All, + EmitTo::First(v) => Self::First(v), + } + } +} + +impl From for EmitTo { + fn from(value: FFI_EmitTo) -> Self { + match value { + FFI_EmitTo::All => Self::All, + FFI_EmitTo::First(v) => Self::First(v), + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{make_array, Array, BooleanArray}; + use datafusion::{ + common::create_array, + error::Result, + logical_expr::{EmitTo, GroupsAccumulator}, + }; + use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; + + use super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator}; + + #[test] + fn test_foreign_avg_accumulator() -> Result<()> { + let boxed_accum: Box = + Box::new(BooleanGroupsAccumulator::new(|a, b| a && b, true)); + let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into(); + let mut foreign_accum: ForeignGroupsAccumulator = ffi_accum.into(); + + // Send in an array to evaluate. We want a mean of 30 and standard deviation of 4. + let values = create_array!(Boolean, vec![true, true, true, false, true, true]); + let opt_filter = + create_array!(Boolean, vec![true, true, true, true, false, false]); + foreign_accum.update_batch( + &[values], + &[0, 0, 1, 1, 2, 2], + Some(opt_filter.as_ref()), + 3, + )?; + + let groups_bool = foreign_accum.evaluate(EmitTo::All)?; + let groups_bool = groups_bool.as_any().downcast_ref::().unwrap(); + + assert_eq!( + groups_bool, + create_array!(Boolean, vec![Some(true), Some(false), None]).as_ref() + ); + + let state = foreign_accum.state(EmitTo::All)?; + assert_eq!(state.len(), 1); + + // To verify merging batches works, create a second state to add in + // This should cause our average to go down to 25.0 + let second_states = + vec![make_array(create_array!(Boolean, vec![false]).to_data())]; + + let opt_filter = create_array!(Boolean, vec![true]); + foreign_accum.merge_batch(&second_states, &[0], Some(opt_filter.as_ref()), 1)?; + let groups_bool = foreign_accum.evaluate(EmitTo::All)?; + assert_eq!(groups_bool.len(), 1); + assert_eq!( + groups_bool.as_ref(), + make_array(create_array!(Boolean, vec![false]).to_data()).as_ref() + ); + + let values = create_array!(Boolean, vec![false]); + let opt_filter = create_array!(Boolean, vec![true]); + let groups_bool = + foreign_accum.convert_to_state(&[values], Some(opt_filter.as_ref()))?; + + assert_eq!( + groups_bool[0].as_ref(), + make_array(create_array!(Boolean, vec![false]).to_data()).as_ref() + ); + + Ok(()) + } + + fn test_emit_to_round_trip(value: EmitTo) -> Result<()> { + let ffi_value: FFI_EmitTo = value.into(); + let round_trip_value: EmitTo = ffi_value.into(); + + assert_eq!(value, round_trip_value); + Ok(()) + } + + /// This test ensures all enum values are properly translated + #[test] + fn test_all_emit_to_round_trip() -> Result<()> { + test_emit_to_round_trip(EmitTo::All)?; + test_emit_to_round_trip(EmitTo::First(10))?; + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs new file mode 100644 index 0000000000000..1ea1798c7c8be --- /dev/null +++ b/datafusion/ffi/src/udaf/mod.rs @@ -0,0 +1,816 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use abi_stable::{ + std_types::{ROption, RResult, RStr, RString, RVec}, + StableAbi, +}; +use accumulator::{FFI_Accumulator, ForeignAccumulator}; +use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; +use arrow::datatypes::{DataType, Field}; +use arrow::ffi::FFI_ArrowSchema; +use arrow_schema::FieldRef; +use datafusion::{ + error::DataFusionError, + logical_expr::{ + function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + type_coercion::functions::fields_with_aggregate_udf, + utils::AggregateOrderSensitivity, + Accumulator, GroupsAccumulator, + }, +}; +use datafusion::{ + error::Result, + logical_expr::{AggregateUDF, AggregateUDFImpl, Signature}, +}; +use datafusion_common::exec_datafusion_err; +use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; +use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; +use std::hash::{Hash, Hasher}; +use std::{ffi::c_void, sync::Arc}; + +use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; +use crate::{ + arrow_wrappers::WrappedSchema, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; +use prost::{DecodeError, Message}; + +mod accumulator; +mod accumulator_args; +mod groups_accumulator; + +/// A stable struct for sharing a [`AggregateUDF`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AggregateUDF { + /// FFI equivalent to the `name` of a [`AggregateUDF`] + pub name: RString, + + /// FFI equivalent to the `aliases` of a [`AggregateUDF`] + pub aliases: RVec, + + /// FFI equivalent to the `volatility` of a [`AggregateUDF`] + pub volatility: FFI_Volatility, + + /// Determines the return field of the underlying [`AggregateUDF`] based on the + /// argument fields. + pub return_field: unsafe extern "C" fn( + udaf: &Self, + arg_fields: RVec, + ) -> RResult, + + /// FFI equivalent to the `is_nullable` of a [`AggregateUDF`] + pub is_nullable: bool, + + /// FFI equivalent to [`AggregateUDF::groups_accumulator_supported`] + pub groups_accumulator_supported: + unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool, + + /// FFI equivalent to [`AggregateUDF::accumulator`] + pub accumulator: unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult, + + /// FFI equivalent to [`AggregateUDF::create_sliding_accumulator`] + pub create_sliding_accumulator: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult, + + /// FFI equivalent to [`AggregateUDF::state_fields`] + #[allow(clippy::type_complexity)] + pub state_fields: unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + name: &RStr, + input_fields: RVec, + return_field: WrappedSchema, + ordering_fields: RVec>, + is_distinct: bool, + ) -> RResult>, RString>, + + /// FFI equivalent to [`AggregateUDF::create_groups_accumulator`] + pub create_groups_accumulator: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult, + + /// FFI equivalent to [`AggregateUDF::with_beneficial_ordering`] + pub with_beneficial_ordering: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + beneficial_ordering: bool, + ) -> RResult, RString>, + + /// FFI equivalent to [`AggregateUDF::order_sensitivity`] + pub order_sensitivity: + unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity, + + /// Performs type coercion. To simply this interface, all UDFs are treated as having + /// user defined signatures, which will in turn call coerce_types to be called. This + /// call should be transparent to most users as the internal function performs the + /// appropriate calls on the underlying [`AggregateUDF`] + pub coerce_types: unsafe extern "C" fn( + udf: &Self, + arg_types: RVec, + ) -> RResult, RString>, + + /// Used to create a clone on the provider of the udaf. This should + /// only need to be called by the receiver of the udaf. + pub clone: unsafe extern "C" fn(udaf: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(udaf: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the udaf. + /// A [`ForeignAggregateUDF`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_AggregateUDF {} +unsafe impl Sync for FFI_AggregateUDF {} + +pub struct AggregateUDFPrivateData { + pub udaf: Arc, +} + +impl FFI_AggregateUDF { + unsafe fn inner(&self) -> &Arc { + let private_data = self.private_data as *const AggregateUDFPrivateData; + &(*private_data).udaf + } +} + +unsafe extern "C" fn return_field_fn_wrapper( + udaf: &FFI_AggregateUDF, + arg_fields: RVec, +) -> RResult { + let udaf = udaf.inner(); + + let arg_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&arg_fields)); + + let return_field = udaf + .return_field(&arg_fields) + .and_then(|v| { + FFI_ArrowSchema::try_from(v.as_ref()).map_err(DataFusionError::from) + }) + .map(WrappedSchema); + + rresult!(return_field) +} + +unsafe extern "C" fn accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .accumulator(accumulator_args.into()) + .map(FFI_Accumulator::from)) +} + +unsafe extern "C" fn create_sliding_accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .create_sliding_accumulator(accumulator_args.into()) + .map(FFI_Accumulator::from)) +} + +unsafe extern "C" fn create_groups_accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .create_groups_accumulator(accumulator_args.into()) + .map(FFI_GroupsAccumulator::from)) +} + +unsafe extern "C" fn groups_accumulator_supported_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> bool { + let udaf = udaf.inner(); + + ForeignAccumulatorArgs::try_from(args) + .map(|a| udaf.groups_accumulator_supported((&a).into())) + .unwrap_or_else(|e| { + log::warn!("Unable to parse accumulator args. {e}"); + false + }) +} + +unsafe extern "C" fn with_beneficial_ordering_fn_wrapper( + udaf: &FFI_AggregateUDF, + beneficial_ordering: bool, +) -> RResult, RString> { + let udaf = udaf.inner().as_ref().clone(); + + let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering)); + let result = rresult_return!(result + .map(|func| func.with_beneficial_ordering(beneficial_ordering)) + .transpose()) + .flatten() + .map(|func| FFI_AggregateUDF::from(Arc::new(func))); + + RResult::ROk(result.into()) +} + +unsafe extern "C" fn state_fields_fn_wrapper( + udaf: &FFI_AggregateUDF, + name: &RStr, + input_fields: RVec, + return_field: WrappedSchema, + ordering_fields: RVec>, + is_distinct: bool, +) -> RResult>, RString> { + let udaf = udaf.inner(); + + let input_fields = &rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields)); + let return_field = rresult_return!(Field::try_from(&return_field.0)).into(); + + let ordering_fields = &rresult_return!(ordering_fields + .into_iter() + .map(|field_bytes| datafusion_proto_common::Field::decode(field_bytes.as_ref())) + .collect::, DecodeError>>()); + + let ordering_fields = &rresult_return!(parse_proto_fields_to_fields(ordering_fields)) + .into_iter() + .map(Arc::new) + .collect::>(); + + let args = StateFieldsArgs { + name: name.as_str(), + input_fields, + return_field, + ordering_fields, + is_distinct, + }; + + let state_fields = rresult_return!(udaf.state_fields(args)); + let state_fields = rresult_return!(state_fields + .iter() + .map(|f| f.as_ref()) + .map(datafusion_proto::protobuf::Field::try_from) + .map(|v| v.map_err(DataFusionError::from)) + .collect::>>()) + .into_iter() + .map(|field| field.encode_to_vec().into()) + .collect(); + + RResult::ROk(state_fields) +} + +unsafe extern "C" fn order_sensitivity_fn_wrapper( + udaf: &FFI_AggregateUDF, +) -> FFI_AggregateOrderSensitivity { + udaf.inner().order_sensitivity().into() +} + +unsafe extern "C" fn coerce_types_fn_wrapper( + udaf: &FFI_AggregateUDF, + arg_types: RVec, +) -> RResult, RString> { + let udaf = udaf.inner(); + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let arg_fields = arg_types + .iter() + .map(|dt| Field::new("f", dt.clone(), true)) + .map(Arc::new) + .collect::>(); + let return_types = rresult_return!(fields_with_aggregate_udf(&arg_fields, udaf)) + .into_iter() + .map(|f| f.data_type().to_owned()) + .collect::>(); + + rresult!(vec_datatype_to_rvec_wrapped(&return_types)) +} + +unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { + let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF { + Arc::clone(udaf.inner()).into() +} + +impl Clone for FFI_AggregateUDF { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From> for FFI_AggregateUDF { + fn from(udaf: Arc) -> Self { + let name = udaf.name().into(); + let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect(); + let is_nullable = udaf.is_nullable(); + let volatility = udaf.signature().volatility.into(); + + let private_data = Box::new(AggregateUDFPrivateData { udaf }); + + Self { + name, + is_nullable, + volatility, + aliases, + return_field: return_field_fn_wrapper, + accumulator: accumulator_fn_wrapper, + create_sliding_accumulator: create_sliding_accumulator_fn_wrapper, + create_groups_accumulator: create_groups_accumulator_fn_wrapper, + groups_accumulator_supported: groups_accumulator_supported_fn_wrapper, + with_beneficial_ordering: with_beneficial_ordering_fn_wrapper, + state_fields: state_fields_fn_wrapper, + order_sensitivity: order_sensitivity_fn_wrapper, + coerce_types: coerce_types_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_AggregateUDF { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignAggregateUDF is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_AggregateUDF. +#[derive(Debug)] +pub struct ForeignAggregateUDF { + signature: Signature, + aliases: Vec, + udaf: FFI_AggregateUDF, +} + +unsafe impl Send for ForeignAggregateUDF {} +unsafe impl Sync for ForeignAggregateUDF {} + +impl PartialEq for ForeignAggregateUDF { + fn eq(&self, other: &Self) -> bool { + // FFI_AggregateUDF cannot be compared, so identity equality is the best we can do. + std::ptr::eq(self, other) + } +} +impl Eq for ForeignAggregateUDF {} +impl Hash for ForeignAggregateUDF { + fn hash(&self, state: &mut H) { + std::ptr::hash(self, state) + } +} + +impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF { + type Error = DataFusionError; + + fn try_from(udaf: &FFI_AggregateUDF) -> Result { + let signature = Signature::user_defined((&udaf.volatility).into()); + let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect(); + + Ok(Self { + udaf: udaf.clone(), + signature, + aliases, + }) + } +} + +impl AggregateUDFImpl for ForeignAggregateUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + self.udaf.name.as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!() + } + + fn return_field(&self, arg_fields: &[FieldRef]) -> Result { + let arg_fields = vec_fieldref_to_rvec_wrapped(arg_fields)?; + + let result = unsafe { (self.udaf.return_field)(&self.udaf, arg_fields) }; + + let result = df_result!(result); + + result.and_then(|r| { + Field::try_from(&r.0) + .map(Arc::new) + .map_err(DataFusionError::from) + }) + } + + fn is_nullable(&self) -> bool { + self.udaf.is_nullable + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let args = acc_args.try_into()?; + unsafe { + df_result!((self.udaf.accumulator)(&self.udaf, args)).map(|accum| { + Box::new(ForeignAccumulator::from(accum)) as Box + }) + } + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + unsafe { + let name = RStr::from_str(args.name); + let input_fields = vec_fieldref_to_rvec_wrapped(args.input_fields)?; + let return_field = + WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); + let ordering_fields = args + .ordering_fields + .iter() + .map(|f| f.as_ref()) + .map(datafusion_proto::protobuf::Field::try_from) + .map(|v| v.map_err(DataFusionError::from)) + .collect::>>()? + .into_iter() + .map(|proto_field| proto_field.encode_to_vec().into()) + .collect(); + + let fields = df_result!((self.udaf.state_fields)( + &self.udaf, + &name, + input_fields, + return_field, + ordering_fields, + args.is_distinct + ))?; + let fields = fields + .into_iter() + .map(|field_bytes| { + datafusion_proto_common::Field::decode(field_bytes.as_ref()) + .map_err(|e| exec_datafusion_err!("{e}")) + }) + .collect::>>()?; + + parse_proto_fields_to_fields(fields.iter()) + .map(|fields| fields.into_iter().map(Arc::new).collect()) + .map_err(|e| exec_datafusion_err!("{e}")) + } + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + let args = match FFI_AccumulatorArgs::try_from(args) { + Ok(v) => v, + Err(e) => { + log::warn!("Attempting to convert accumulator arguments: {e}"); + return false; + } + }; + + unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) } + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let args = FFI_AccumulatorArgs::try_from(args)?; + + unsafe { + df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)).map( + |accum| { + Box::new(ForeignGroupsAccumulator::from(accum)) + as Box + }, + ) + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let args = args.try_into()?; + unsafe { + df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)).map( + |accum| Box::new(ForeignAccumulator::from(accum)) as Box, + ) + } + } + + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + unsafe { + let result = df_result!((self.udaf.with_beneficial_ordering)( + &self.udaf, + beneficial_ordering + ))? + .into_option(); + + let result = result + .map(|func| ForeignAggregateUDF::try_from(&func)) + .transpose()?; + + Ok(result.map(|func| Arc::new(func) as Arc)) + } + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() } + } + + fn simplify(&self) -> Option { + None + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + unsafe { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + let result_types = + df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?; + Ok(rvec_wrapped_to_vec_datatype(&result_types)?) + } + } +} + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_AggregateOrderSensitivity { + Insensitive, + HardRequirement, + SoftRequirement, + Beneficial, +} + +impl From for AggregateOrderSensitivity { + fn from(value: FFI_AggregateOrderSensitivity) -> Self { + match value { + FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive, + FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement, + FFI_AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement, + FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial, + } + } +} + +impl From for FFI_AggregateOrderSensitivity { + fn from(value: AggregateOrderSensitivity) -> Self { + match value { + AggregateOrderSensitivity::Insensitive => Self::Insensitive, + AggregateOrderSensitivity::HardRequirement => Self::HardRequirement, + AggregateOrderSensitivity::SoftRequirement => Self::SoftRequirement, + AggregateOrderSensitivity::Beneficial => Self::Beneficial, + } + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::Schema; + use datafusion::{ + common::create_array, functions_aggregate::sum::Sum, + physical_expr::PhysicalSortExpr, physical_plan::expressions::col, + scalar::ScalarValue, + }; + use std::any::Any; + use std::collections::HashMap; + + use super::*; + + #[derive(Default, Debug, Hash, Eq, PartialEq)] + struct SumWithCopiedMetadata { + inner: Sum, + } + + impl AggregateUDFImpl for SumWithCopiedMetadata { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.inner.name() + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!() + } + + fn return_field(&self, arg_fields: &[FieldRef]) -> Result { + // Copy the input field, so any metadata gets returned + Ok(Arc::clone(&arg_fields[0])) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + self.inner.accumulator(acc_args) + } + } + + fn create_test_foreign_udaf( + original_udaf: impl AggregateUDFImpl + 'static, + ) -> Result { + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + Ok(foreign_udaf.into()) + } + + #[test] + fn test_round_trip_udaf() -> Result<()> { + let original_udaf = Sum::new(); + let original_name = original_udaf.name().to_owned(); + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + // Convert to FFI format + let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + // Convert back to native format + let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + let foreign_udaf: AggregateUDF = foreign_udaf.into(); + + assert_eq!(original_name, foreign_udaf.name()); + Ok(()) + } + + #[test] + fn test_foreign_udaf_aliases() -> Result<()> { + let foreign_udaf = + create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]); + + let return_field = + foreign_udaf + .return_field(&[Field::new("a", DataType::Float64, true).into()])?; + let return_type = return_field.data_type(); + assert_eq!(return_type, &DataType::Float64); + Ok(()) + } + + #[test] + fn test_foreign_udaf_accumulator() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + let acc_args = AccumulatorArgs { + return_field: Field::new("f", DataType::Float64, true).into(), + schema: &schema, + ignore_nulls: true, + order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)], + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + let mut accumulator = foreign_udaf.accumulator(acc_args)?; + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + accumulator.update_batch(&[values])?; + let resultant_value = accumulator.evaluate()?; + assert_eq!(resultant_value, ScalarValue::Float64(Some(150.))); + + Ok(()) + } + + #[test] + fn test_round_trip_udaf_metadata() -> Result<()> { + let original_udaf = SumWithCopiedMetadata::default(); + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + // Convert to FFI format + let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + // Convert back to native format + let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + let foreign_udaf: AggregateUDF = foreign_udaf.into(); + + let metadata: HashMap = + [("a_key".to_string(), "a_value".to_string())] + .into_iter() + .collect(); + let input_field = Arc::new( + Field::new("a", DataType::Float64, false).with_metadata(metadata.clone()), + ); + let return_field = foreign_udaf.return_field(&[input_field])?; + + assert_eq!(&metadata, return_field.metadata()); + Ok(()) + } + + #[test] + fn test_beneficial_ordering() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf( + datafusion::functions_aggregate::first_last::FirstValue::new(), + )?; + + let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap(); + + assert_eq!( + foreign_udaf.order_sensitivity(), + AggregateOrderSensitivity::Beneficial + ); + + let a_field = Arc::new(Field::new("a", DataType::Float64, true)); + let state_fields = foreign_udaf.state_fields(StateFieldsArgs { + name: "a", + input_fields: &[Field::new("f", DataType::Float64, true).into()], + return_field: Field::new("f", DataType::Float64, true).into(), + ordering_fields: &[Arc::clone(&a_field)], + is_distinct: false, + })?; + + assert_eq!(state_fields.len(), 3); + assert_eq!(state_fields[1], a_field); + Ok(()) + } + + #[test] + fn test_sliding_accumulator() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + // Note: sum distinct is only support Int64 until now + let acc_args = AccumulatorArgs { + return_field: Field::new("f", DataType::Float64, true).into(), + schema: &schema, + ignore_nulls: true, + order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)], + is_reversed: false, + name: "round_trip", + is_distinct: false, + exprs: &[col("a", &schema)?], + }; + + let mut accumulator = foreign_udaf.create_sliding_accumulator(acc_args)?; + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + accumulator.update_batch(&[values])?; + let resultant_value = accumulator.evaluate()?; + assert_eq!(resultant_value, ScalarValue::Float64(Some(150.))); + + Ok(()) + } + + fn test_round_trip_order_sensitivity(sensitivity: AggregateOrderSensitivity) { + let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity.into(); + let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity.into(); + + assert_eq!(sensitivity, round_trip_sensitivity); + } + + #[test] + fn test_round_trip_all_order_sensitivities() { + test_round_trip_order_sensitivity(AggregateOrderSensitivity::Insensitive); + test_round_trip_order_sensitivity(AggregateOrderSensitivity::HardRequirement); + test_round_trip_order_sensitivity(AggregateOrderSensitivity::SoftRequirement); + test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial); + } +} diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 706b9fabedcb4..5e59cfc5ecb07 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -15,23 +15,28 @@ // specific language governing permissions and limitations // under the License. -use std::{ffi::c_void, sync::Arc}; - +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; use abi_stable::{ std_types::{RResult, RString, RVec}, StableAbi, }; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ array::ArrayRef, error::ArrowError, ffi::{from_ffi, to_ffi, FFI_ArrowSchema}, }; +use arrow_schema::FieldRef; +use datafusion::config::ConfigOptions; +use datafusion::logical_expr::ReturnFieldArgs; use datafusion::{ error::DataFusionError, - logical_expr::{ - type_coercion::functions::data_types_with_scalar_udf, ReturnInfo, ReturnTypeArgs, - }, + logical_expr::type_coercion::functions::data_types_with_scalar_udf, }; use datafusion::{ error::Result, @@ -39,19 +44,12 @@ use datafusion::{ ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, }, }; -use return_info::FFI_ReturnInfo; use return_type_args::{ - FFI_ReturnTypeArgs, ForeignReturnTypeArgs, ForeignReturnTypeArgsOwned, -}; - -use crate::{ - arrow_wrappers::{WrappedArray, WrappedSchema}, - df_result, rresult, rresult_return, - util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, - volatility::FFI_Volatility, + FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned, }; +use std::hash::{Hash, Hasher}; +use std::{ffi::c_void, sync::Arc}; -pub mod return_info; pub mod return_type_args; /// A stable struct for sharing a [`ScalarUDF`] across FFI boundaries. @@ -77,25 +75,27 @@ pub struct FFI_ScalarUDF { /// Determines the return info of the underlying [`ScalarUDF`]. Either this /// or return_type may be implemented on a UDF. - pub return_type_from_args: unsafe extern "C" fn( + pub return_field_from_args: unsafe extern "C" fn( udf: &Self, - args: FFI_ReturnTypeArgs, + args: FFI_ReturnFieldArgs, ) - -> RResult, + -> RResult, /// Execute the underlying [`ScalarUDF`] and return the result as a `FFI_ArrowArray` /// within an AbiStable wrapper. + #[allow(clippy::type_complexity)] pub invoke_with_args: unsafe extern "C" fn( udf: &Self, args: RVec, + arg_fields: RVec, num_rows: usize, - return_type: WrappedSchema, + return_field: WrappedSchema, ) -> RResult, /// See [`ScalarUDFImpl`] for details on short_circuits pub short_circuits: bool, - /// Performs type coersion. To simply this interface, all UDFs are treated as having + /// Performs type coercion. To simply this interface, all UDFs are treated as having /// user defined signatures, which will in turn call coerce_types to be called. This /// call should be transparent to most users as the internal function performs the /// appropriate calls on the underlying [`ScalarUDF`] @@ -140,19 +140,20 @@ unsafe extern "C" fn return_type_fn_wrapper( rresult!(return_type) } -unsafe extern "C" fn return_type_from_args_fn_wrapper( +unsafe extern "C" fn return_field_from_args_fn_wrapper( udf: &FFI_ScalarUDF, - args: FFI_ReturnTypeArgs, -) -> RResult { + args: FFI_ReturnFieldArgs, +) -> RResult { let private_data = udf.private_data as *const ScalarUDFPrivateData; let udf = &(*private_data).udf; - let args: ForeignReturnTypeArgsOwned = rresult_return!((&args).try_into()); - let args_ref: ForeignReturnTypeArgs = (&args).into(); + let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into()); + let args_ref: ForeignReturnFieldArgs = (&args).into(); let return_type = udf - .return_type_from_args((&args_ref).into()) - .and_then(FFI_ReturnInfo::try_from); + .return_field_from_args((&args_ref).into()) + .and_then(|f| FFI_ArrowSchema::try_from(&f).map_err(DataFusionError::from)) + .map(WrappedSchema); rresult!(return_type) } @@ -174,8 +175,9 @@ unsafe extern "C" fn coerce_types_fn_wrapper( unsafe extern "C" fn invoke_with_args_fn_wrapper( udf: &FFI_ScalarUDF, args: RVec, + arg_fields: RVec, number_rows: usize, - return_type: WrappedSchema, + return_field: WrappedSchema, ) -> RResult { let private_data = udf.private_data as *const ScalarUDFPrivateData; let udf = &(*private_data).udf; @@ -189,12 +191,25 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( .collect::>(); let args = rresult_return!(args); - let return_type = rresult_return!(DataType::try_from(&return_type.0)); + let return_field = rresult_return!(Field::try_from(&return_field.0)).into(); + + let arg_fields = arg_fields + .into_iter() + .map(|wrapped_field| { + Field::try_from(&wrapped_field.0) + .map(Arc::new) + .map_err(DataFusionError::from) + }) + .collect::>>(); + let arg_fields = rresult_return!(arg_fields); let args = ScalarFunctionArgs { args, + arg_fields, number_rows, - return_type: &return_type, + return_field, + // TODO: pass config options: https://github.com/apache/datafusion/issues/17035 + config_options: Arc::new(ConfigOptions::default()), }; let result = rresult_return!(udf @@ -243,7 +258,7 @@ impl From> for FFI_ScalarUDF { short_circuits, invoke_with_args: invoke_with_args_fn_wrapper, return_type: return_type_fn_wrapper, - return_type_from_args: return_type_from_args_fn_wrapper, + return_field_from_args: return_field_from_args_fn_wrapper, coerce_types: coerce_types_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, @@ -275,6 +290,37 @@ pub struct ForeignScalarUDF { unsafe impl Send for ForeignScalarUDF {} unsafe impl Sync for ForeignScalarUDF {} +impl PartialEq for ForeignScalarUDF { + fn eq(&self, other: &Self) -> bool { + let Self { + name, + aliases, + udf, + signature, + } = self; + name == &other.name + && aliases == &other.aliases + && std::ptr::eq(udf, &other.udf) + && signature == &other.signature + } +} +impl Eq for ForeignScalarUDF {} + +impl Hash for ForeignScalarUDF { + fn hash(&self, state: &mut H) { + let Self { + name, + aliases, + udf, + signature, + } = self; + name.hash(state); + aliases.hash(state); + std::ptr::hash(udf, state); + signature.hash(state); + } +} + impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF { type Error = DataFusionError; @@ -316,21 +362,28 @@ impl ScalarUDFImpl for ForeignScalarUDF { result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let args: FFI_ReturnTypeArgs = args.try_into()?; + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let args: FFI_ReturnFieldArgs = args.try_into()?; - let result = unsafe { (self.udf.return_type_from_args)(&self.udf, args) }; + let result = unsafe { (self.udf.return_field_from_args)(&self.udf, args) }; let result = df_result!(result); - result.and_then(|r| r.try_into()) + result.and_then(|r| { + Field::try_from(&r.0) + .map(Arc::new) + .map_err(DataFusionError::from) + }) } fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, + arg_fields, number_rows, - return_type, + return_field, + // TODO: pass config options: https://github.com/apache/datafusion/issues/17035 + config_options: _config_options, } = invoke_args; let args = args @@ -347,10 +400,27 @@ impl ScalarUDFImpl for ForeignScalarUDF { .collect::, ArrowError>>()? .into(); - let return_type = WrappedSchema(FFI_ArrowSchema::try_from(return_type)?); + let arg_fields_wrapped = arg_fields + .iter() + .map(FFI_ArrowSchema::try_from) + .collect::, ArrowError>>()?; + + let arg_fields = arg_fields_wrapped + .into_iter() + .map(WrappedSchema) + .collect::>(); + + let return_field = return_field.as_ref().clone(); + let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?); let result = unsafe { - (self.udf.invoke_with_args)(&self.udf, args, number_rows, return_type) + (self.udf.invoke_with_args)( + &self.udf, + args, + arg_fields, + number_rows, + return_field, + ) }; let result = df_result!(result)?; @@ -389,7 +459,7 @@ mod tests { let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?; - assert!(original_udf.name() == foreign_udf.name()); + assert_eq!(original_udf.name(), foreign_udf.name()); Ok(()) } diff --git a/datafusion/ffi/src/udf/return_type_args.rs b/datafusion/ffi/src/udf/return_type_args.rs index a0897630e2ea9..c437c9537be6f 100644 --- a/datafusion/ffi/src/udf/return_type_args.rs +++ b/datafusion/ffi/src/udf/return_type_args.rs @@ -19,33 +19,30 @@ use abi_stable::{ std_types::{ROption, RVec}, StableAbi, }; -use arrow::datatypes::DataType; +use arrow_schema::FieldRef; use datafusion::{ - common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnTypeArgs, + common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnFieldArgs, scalar::ScalarValue, }; -use crate::{ - arrow_wrappers::WrappedSchema, - util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, -}; +use crate::arrow_wrappers::WrappedSchema; +use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; use prost::Message; -/// A stable struct for sharing a [`ReturnTypeArgs`] across FFI boundaries. +/// A stable struct for sharing a [`ReturnFieldArgs`] across FFI boundaries. #[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] -pub struct FFI_ReturnTypeArgs { - arg_types: RVec, +pub struct FFI_ReturnFieldArgs { + arg_fields: RVec, scalar_arguments: RVec>>, - nullables: RVec, } -impl TryFrom> for FFI_ReturnTypeArgs { +impl TryFrom> for FFI_ReturnFieldArgs { type Error = DataFusionError; - fn try_from(value: ReturnTypeArgs) -> Result { - let arg_types = vec_datatype_to_rvec_wrapped(value.arg_types)?; + fn try_from(value: ReturnFieldArgs) -> Result { + let arg_fields = vec_fieldref_to_rvec_wrapped(value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments .iter() @@ -62,35 +59,31 @@ impl TryFrom> for FFI_ReturnTypeArgs { .collect(); let scalar_arguments = scalar_arguments?.into_iter().map(ROption::from).collect(); - let nullables = value.nullables.into(); Ok(Self { - arg_types, + arg_fields, scalar_arguments, - nullables, }) } } // TODO(tsaucer) It would be good to find a better way around this, but it // appears a restriction based on the need to have a borrowed ScalarValue -// in the arguments when converted to ReturnTypeArgs -pub struct ForeignReturnTypeArgsOwned { - arg_types: Vec, +// in the arguments when converted to ReturnFieldArgs +pub struct ForeignReturnFieldArgsOwned { + arg_fields: Vec, scalar_arguments: Vec>, - nullables: Vec, } -pub struct ForeignReturnTypeArgs<'a> { - arg_types: &'a [DataType], +pub struct ForeignReturnFieldArgs<'a> { + arg_fields: &'a [FieldRef], scalar_arguments: Vec>, - nullables: &'a [bool], } -impl TryFrom<&FFI_ReturnTypeArgs> for ForeignReturnTypeArgsOwned { +impl TryFrom<&FFI_ReturnFieldArgs> for ForeignReturnFieldArgsOwned { type Error = DataFusionError; - fn try_from(value: &FFI_ReturnTypeArgs) -> Result { - let arg_types = rvec_wrapped_to_vec_datatype(&value.arg_types)?; + fn try_from(value: &FFI_ReturnFieldArgs) -> Result { + let arg_fields = rvec_wrapped_to_vec_fieldref(&value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments .iter() @@ -107,36 +100,31 @@ impl TryFrom<&FFI_ReturnTypeArgs> for ForeignReturnTypeArgsOwned { .collect(); let scalar_arguments = scalar_arguments?.into_iter().collect(); - let nullables = value.nullables.iter().cloned().collect(); - Ok(Self { - arg_types, + arg_fields, scalar_arguments, - nullables, }) } } -impl<'a> From<&'a ForeignReturnTypeArgsOwned> for ForeignReturnTypeArgs<'a> { - fn from(value: &'a ForeignReturnTypeArgsOwned) -> Self { +impl<'a> From<&'a ForeignReturnFieldArgsOwned> for ForeignReturnFieldArgs<'a> { + fn from(value: &'a ForeignReturnFieldArgsOwned) -> Self { Self { - arg_types: &value.arg_types, + arg_fields: &value.arg_fields, scalar_arguments: value .scalar_arguments .iter() .map(|opt| opt.as_ref()) .collect(), - nullables: &value.nullables, } } } -impl<'a> From<&'a ForeignReturnTypeArgs<'a>> for ReturnTypeArgs<'a> { - fn from(value: &'a ForeignReturnTypeArgs) -> Self { - ReturnTypeArgs { - arg_types: value.arg_types, +impl<'a> From<&'a ForeignReturnFieldArgs<'a>> for ReturnFieldArgs<'a> { + fn from(value: &'a ForeignReturnFieldArgs) -> Self { + ReturnFieldArgs { + arg_fields: value.arg_fields, scalar_arguments: &value.scalar_arguments, - nullables: value.nullables, } } } diff --git a/datafusion/ffi/src/udtf.rs b/datafusion/ffi/src/udtf.rs new file mode 100644 index 0000000000000..ceedec2599a29 --- /dev/null +++ b/datafusion/ffi/src/udtf.rs @@ -0,0 +1,321 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; + +use datafusion::error::Result; +use datafusion::{ + catalog::{TableFunctionImpl, TableProvider}, + prelude::{Expr, SessionContext}, +}; +use datafusion_proto::{ + logical_plan::{ + from_proto::parse_exprs, to_proto::serialize_exprs, DefaultLogicalExtensionCodec, + }, + protobuf::LogicalExprList, +}; +use prost::Message; +use tokio::runtime::Handle; + +use crate::{ + df_result, rresult_return, + table_provider::{FFI_TableProvider, ForeignTableProvider}, +}; + +/// A stable struct for sharing a [`TableFunctionImpl`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_TableFunction { + /// Equivalent to the `call` function of the TableFunctionImpl. + /// The arguments are Expr passed as protobuf encoded bytes. + pub call: unsafe extern "C" fn( + udtf: &Self, + args: RVec, + ) -> RResult, + + /// Used to create a clone on the provider of the udtf. This should + /// only need to be called by the receiver of the udtf. + pub clone: unsafe extern "C" fn(udtf: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(udtf: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the udtf. + /// A [`ForeignTableFunction`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_TableFunction {} +unsafe impl Sync for FFI_TableFunction {} + +pub struct TableFunctionPrivateData { + udtf: Arc, + runtime: Option, +} + +impl FFI_TableFunction { + fn inner(&self) -> &Arc { + let private_data = self.private_data as *const TableFunctionPrivateData; + unsafe { &(*private_data).udtf } + } + + fn runtime(&self) -> Option { + let private_data = self.private_data as *const TableFunctionPrivateData; + unsafe { (*private_data).runtime.clone() } + } +} + +unsafe extern "C" fn call_fn_wrapper( + udtf: &FFI_TableFunction, + args: RVec, +) -> RResult { + let runtime = udtf.runtime(); + let udtf = udtf.inner(); + + let default_ctx = SessionContext::new(); + let codec = DefaultLogicalExtensionCodec {}; + + let proto_filters = rresult_return!(LogicalExprList::decode(args.as_ref())); + + let args = + rresult_return!(parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec)); + + let table_provider = rresult_return!(udtf.call(&args)); + RResult::ROk(FFI_TableProvider::new(table_provider, false, runtime)) +} + +unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) { + let private_data = Box::from_raw(udtf.private_data as *mut TableFunctionPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(udtf: &FFI_TableFunction) -> FFI_TableFunction { + let runtime = udtf.runtime(); + let udtf = udtf.inner(); + + FFI_TableFunction::new(Arc::clone(udtf), runtime) +} + +impl Clone for FFI_TableFunction { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl FFI_TableFunction { + pub fn new(udtf: Arc, runtime: Option) -> Self { + let private_data = Box::new(TableFunctionPrivateData { udtf, runtime }); + + Self { + call: call_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl From> for FFI_TableFunction { + fn from(udtf: Arc) -> Self { + let private_data = Box::new(TableFunctionPrivateData { + udtf, + runtime: None, + }); + + Self { + call: call_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_TableFunction { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDTF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignTableFunction is to be used by the caller of the UDTF, so it has +/// no knowledge or access to the private data. All interaction with the UDTF +/// must occur through the functions defined in FFI_TableFunction. +#[derive(Debug)] +pub struct ForeignTableFunction(FFI_TableFunction); + +unsafe impl Send for ForeignTableFunction {} +unsafe impl Sync for ForeignTableFunction {} + +impl From for ForeignTableFunction { + fn from(value: FFI_TableFunction) -> Self { + Self(value) + } +} + +impl TableFunctionImpl for ForeignTableFunction { + fn call(&self, args: &[Expr]) -> Result> { + let codec = DefaultLogicalExtensionCodec {}; + let expr_list = LogicalExprList { + expr: serialize_exprs(args, &codec)?, + }; + let filters_serialized = expr_list.encode_to_vec().into(); + + let table_provider = unsafe { (self.0.call)(&self.0, filters_serialized) }; + + let table_provider = df_result!(table_provider)?; + let table_provider: ForeignTableProvider = (&table_provider).into(); + + Ok(Arc::new(table_provider)) + } +} + +#[cfg(test)] +mod tests { + use arrow::{ + array::{ + record_batch, ArrayRef, Float64Array, RecordBatch, StringArray, UInt64Array, + }, + datatypes::{DataType, Field, Schema}, + }; + use datafusion::{ + catalog::MemTable, common::exec_err, prelude::lit, scalar::ScalarValue, + }; + + use super::*; + + #[derive(Debug)] + struct TestUDTF {} + + impl TableFunctionImpl for TestUDTF { + fn call(&self, args: &[Expr]) -> Result> { + let args = args + .iter() + .map(|arg| { + if let Expr::Literal(scalar, _) = arg { + Ok(scalar) + } else { + exec_err!("Expected only literal arguments to table udf") + } + }) + .collect::>>()?; + + if args.len() < 2 { + exec_err!("Expected at least two arguments to table udf")? + } + + let ScalarValue::UInt64(Some(num_rows)) = args[0].to_owned() else { + exec_err!( + "First argument must be the number of elements to create as u64" + )? + }; + let num_rows = num_rows as usize; + + let mut fields = Vec::default(); + let mut arrays1 = Vec::default(); + let mut arrays2 = Vec::default(); + + let split = num_rows / 3; + for (idx, arg) in args[1..].iter().enumerate() { + let (field, array) = match arg { + ScalarValue::Utf8(s) => { + let s_vec = vec![s.to_owned(); num_rows]; + ( + Field::new(format!("field-{idx}"), DataType::Utf8, true), + Arc::new(StringArray::from(s_vec)) as ArrayRef, + ) + } + ScalarValue::UInt64(v) => { + let v_vec = vec![v.to_owned(); num_rows]; + ( + Field::new(format!("field-{idx}"), DataType::UInt64, true), + Arc::new(UInt64Array::from(v_vec)) as ArrayRef, + ) + } + ScalarValue::Float64(v) => { + let v_vec = vec![v.to_owned(); num_rows]; + ( + Field::new(format!("field-{idx}"), DataType::Float64, true), + Arc::new(Float64Array::from(v_vec)) as ArrayRef, + ) + } + _ => exec_err!( + "Test case only supports utf8, u64, and f64. Found {}", + arg.data_type() + )?, + }; + + fields.push(field); + arrays1.push(array.slice(0, split)); + arrays2.push(array.slice(split, num_rows - split)); + } + + let schema = Arc::new(Schema::new(fields)); + let batches = vec![ + RecordBatch::try_new(Arc::clone(&schema), arrays1)?, + RecordBatch::try_new(Arc::clone(&schema), arrays2)?, + ]; + + let table_provider = MemTable::try_new(schema, vec![batches])?; + + Ok(Arc::new(table_provider)) + } + } + + #[tokio::test] + async fn test_round_trip_udtf() -> Result<()> { + let original_udtf = Arc::new(TestUDTF {}) as Arc; + + let local_udtf: FFI_TableFunction = + FFI_TableFunction::new(Arc::clone(&original_udtf), None); + + let foreign_udf: ForeignTableFunction = local_udtf.into(); + + let table = + foreign_udf.call(&vec![lit(6_u64), lit("one"), lit(2.0), lit(3_u64)])?; + + let ctx = SessionContext::default(); + let _ = ctx.register_table("test-table", table)?; + + let returned_batches = ctx.table("test-table").await?.collect().await?; + + assert_eq!(returned_batches.len(), 2); + let expected_batch_0 = record_batch!( + ("field-0", Utf8, ["one", "one"]), + ("field-1", Float64, [2.0, 2.0]), + ("field-2", UInt64, [3, 3]) + )?; + assert_eq!(returned_batches[0], expected_batch_0); + + let expected_batch_1 = record_batch!( + ("field-0", Utf8, ["one", "one", "one", "one"]), + ("field-1", Float64, [2.0, 2.0, 2.0, 2.0]), + ("field-2", UInt64, [3, 3, 3, 3]) + )?; + assert_eq!(returned_batches[1], expected_batch_1); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udwf/mod.rs b/datafusion/ffi/src/udwf/mod.rs new file mode 100644 index 0000000000000..9f56e2d4788b7 --- /dev/null +++ b/datafusion/ffi/src/udwf/mod.rs @@ -0,0 +1,453 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use abi_stable::{ + std_types::{ROption, RResult, RString, RVec}, + StableAbi, +}; +use arrow::datatypes::Schema; +use arrow::{ + compute::SortOptions, + datatypes::{DataType, SchemaRef}, +}; +use arrow_schema::{Field, FieldRef}; +use datafusion::logical_expr::LimitEffect; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::{ + error::DataFusionError, + logical_expr::{ + function::WindowUDFFieldArgs, type_coercion::functions::fields_with_window_udf, + PartitionEvaluator, + }, +}; +use datafusion::{ + error::Result, + logical_expr::{Signature, WindowUDF, WindowUDFImpl}, +}; +use datafusion_common::exec_err; +use partition_evaluator::{FFI_PartitionEvaluator, ForeignPartitionEvaluator}; +use partition_evaluator_args::{ + FFI_PartitionEvaluatorArgs, ForeignPartitionEvaluatorArgs, +}; +use std::hash::{Hash, Hasher}; +use std::{ffi::c_void, sync::Arc}; + +mod partition_evaluator; +mod partition_evaluator_args; +mod range; + +use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; +use crate::{ + arrow_wrappers::WrappedSchema, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; + +/// A stable struct for sharing a [`WindowUDF`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_WindowUDF { + /// FFI equivalent to the `name` of a [`WindowUDF`] + pub name: RString, + + /// FFI equivalent to the `aliases` of a [`WindowUDF`] + pub aliases: RVec, + + /// FFI equivalent to the `volatility` of a [`WindowUDF`] + pub volatility: FFI_Volatility, + + pub partition_evaluator: + unsafe extern "C" fn( + udwf: &Self, + args: FFI_PartitionEvaluatorArgs, + ) -> RResult, + + pub field: unsafe extern "C" fn( + udwf: &Self, + input_types: RVec, + display_name: RString, + ) -> RResult, + + /// Performs type coercion. To simply this interface, all UDFs are treated as having + /// user defined signatures, which will in turn call coerce_types to be called. This + /// call should be transparent to most users as the internal function performs the + /// appropriate calls on the underlying [`WindowUDF`] + pub coerce_types: unsafe extern "C" fn( + udf: &Self, + arg_types: RVec, + ) -> RResult, RString>, + + pub sort_options: ROption, + + /// Used to create a clone on the provider of the udf. This should + /// only need to be called by the receiver of the udf. + pub clone: unsafe extern "C" fn(udf: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(udf: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the udf. + /// A [`ForeignWindowUDF`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_WindowUDF {} +unsafe impl Sync for FFI_WindowUDF {} + +pub struct WindowUDFPrivateData { + pub udf: Arc, +} + +impl FFI_WindowUDF { + unsafe fn inner(&self) -> &Arc { + let private_data = self.private_data as *const WindowUDFPrivateData; + &(*private_data).udf + } +} + +unsafe extern "C" fn partition_evaluator_fn_wrapper( + udwf: &FFI_WindowUDF, + args: FFI_PartitionEvaluatorArgs, +) -> RResult { + let inner = udwf.inner(); + + let args = rresult_return!(ForeignPartitionEvaluatorArgs::try_from(args)); + + let evaluator = rresult_return!(inner.partition_evaluator_factory((&args).into())); + + RResult::ROk(evaluator.into()) +} + +unsafe extern "C" fn field_fn_wrapper( + udwf: &FFI_WindowUDF, + input_fields: RVec, + display_name: RString, +) -> RResult { + let inner = udwf.inner(); + + let input_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields)); + + let field = rresult_return!(inner.field(WindowUDFFieldArgs::new( + &input_fields, + display_name.as_str() + ))); + + let schema = Arc::new(Schema::new(vec![field])); + + RResult::ROk(WrappedSchema::from(schema)) +} + +unsafe extern "C" fn coerce_types_fn_wrapper( + udwf: &FFI_WindowUDF, + arg_types: RVec, +) -> RResult, RString> { + let inner = udwf.inner(); + + let arg_fields = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)) + .into_iter() + .map(|dt| Field::new("f", dt, false)) + .map(Arc::new) + .collect::>(); + + let return_fields = rresult_return!(fields_with_window_udf(&arg_fields, inner)); + let return_types = return_fields + .into_iter() + .map(|f| f.data_type().to_owned()) + .collect::>(); + + rresult!(vec_datatype_to_rvec_wrapped(&return_types)) +} + +unsafe extern "C" fn release_fn_wrapper(udwf: &mut FFI_WindowUDF) { + let private_data = Box::from_raw(udwf.private_data as *mut WindowUDFPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF { + // let private_data = udf.private_data as *const WindowUDFPrivateData; + // let udf_data = &(*private_data); + + // let private_data = Box::new(WindowUDFPrivateData { + // udf: Arc::clone(&udf_data.udf), + // }); + let private_data = Box::new(WindowUDFPrivateData { + udf: Arc::clone(udwf.inner()), + }); + + FFI_WindowUDF { + name: udwf.name.clone(), + aliases: udwf.aliases.clone(), + volatility: udwf.volatility.clone(), + partition_evaluator: partition_evaluator_fn_wrapper, + sort_options: udwf.sort_options.clone(), + coerce_types: coerce_types_fn_wrapper, + field: field_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } +} + +impl Clone for FFI_WindowUDF { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From> for FFI_WindowUDF { + fn from(udf: Arc) -> Self { + let name = udf.name().into(); + let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect(); + let volatility = udf.signature().volatility.into(); + let sort_options = udf.sort_options().map(|v| (&v).into()).into(); + + let private_data = Box::new(WindowUDFPrivateData { udf }); + + Self { + name, + aliases, + volatility, + partition_evaluator: partition_evaluator_fn_wrapper, + sort_options, + coerce_types: coerce_types_fn_wrapper, + field: field_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_WindowUDF { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignWindowUDF is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_WindowUDF. +#[derive(Debug)] +pub struct ForeignWindowUDF { + name: String, + aliases: Vec, + udf: FFI_WindowUDF, + signature: Signature, +} + +unsafe impl Send for ForeignWindowUDF {} +unsafe impl Sync for ForeignWindowUDF {} + +impl PartialEq for ForeignWindowUDF { + fn eq(&self, other: &Self) -> bool { + // FFI_WindowUDF cannot be compared, so identity equality is the best we can do. + std::ptr::eq(self, other) + } +} +impl Eq for ForeignWindowUDF {} +impl Hash for ForeignWindowUDF { + fn hash(&self, state: &mut H) { + std::ptr::hash(self, state) + } +} + +impl TryFrom<&FFI_WindowUDF> for ForeignWindowUDF { + type Error = DataFusionError; + + fn try_from(udf: &FFI_WindowUDF) -> Result { + let name = udf.name.to_owned().into(); + let signature = Signature::user_defined((&udf.volatility).into()); + + let aliases = udf.aliases.iter().map(|s| s.to_string()).collect(); + + Ok(Self { + name, + udf: udf.clone(), + aliases, + signature, + }) + } +} + +impl WindowUDFImpl for ForeignWindowUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + unsafe { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?; + Ok(rvec_wrapped_to_vec_datatype(&result_types)?) + } + } + + fn partition_evaluator( + &self, + args: datafusion::logical_expr::function::PartitionEvaluatorArgs, + ) -> Result> { + let evaluator = unsafe { + let args = FFI_PartitionEvaluatorArgs::try_from(args)?; + (self.udf.partition_evaluator)(&self.udf, args) + }; + + df_result!(evaluator).map(|evaluator| { + Box::new(ForeignPartitionEvaluator::from(evaluator)) + as Box + }) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + unsafe { + let input_types = vec_fieldref_to_rvec_wrapped(field_args.input_fields())?; + let schema = df_result!((self.udf.field)( + &self.udf, + input_types, + field_args.name().into() + ))?; + let schema: SchemaRef = schema.into(); + + match schema.fields().is_empty() { + true => exec_err!( + "Unable to retrieve field in WindowUDF via FFI - schema has no fields" + ), + false => Ok(schema.field(0).to_owned().into()), + } + } + } + + fn sort_options(&self) -> Option { + let options: Option<&FFI_SortOptions> = self.udf.sort_options.as_ref().into(); + options.map(|s| s.into()) + } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } +} + +#[repr(C)] +#[derive(Debug, StableAbi, Clone)] +#[allow(non_camel_case_types)] +pub struct FFI_SortOptions { + pub descending: bool, + pub nulls_first: bool, +} + +impl From<&SortOptions> for FFI_SortOptions { + fn from(value: &SortOptions) -> Self { + Self { + descending: value.descending, + nulls_first: value.nulls_first, + } + } +} + +impl From<&FFI_SortOptions> for SortOptions { + fn from(value: &FFI_SortOptions) -> Self { + Self { + descending: value.descending, + nulls_first: value.nulls_first, + } + } +} + +#[cfg(test)] +#[cfg(feature = "integration-tests")] +mod tests { + use crate::tests::create_record_batch; + use crate::udwf::{FFI_WindowUDF, ForeignWindowUDF}; + use arrow::array::{create_array, ArrayRef}; + use datafusion::functions_window::lead_lag::{lag_udwf, WindowShift}; + use datafusion::logical_expr::expr::Sort; + use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF, WindowUDFImpl}; + use datafusion::prelude::SessionContext; + use std::sync::Arc; + + fn create_test_foreign_udwf( + original_udwf: impl WindowUDFImpl + 'static, + ) -> datafusion::common::Result { + let original_udwf = Arc::new(WindowUDF::from(original_udwf)); + + let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into(); + + let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?; + Ok(foreign_udwf.into()) + } + + #[test] + fn test_round_trip_udwf() -> datafusion::common::Result<()> { + let original_udwf = lag_udwf(); + let original_name = original_udwf.name().to_owned(); + + // Convert to FFI format + let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into(); + + // Convert back to native format + let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?; + let foreign_udwf: WindowUDF = foreign_udwf.into(); + + assert_eq!(original_name, foreign_udwf.name()); + Ok(()) + } + + #[tokio::test] + async fn test_lag_udwf() -> datafusion::common::Result<()> { + let udwf = create_test_foreign_udwf(WindowShift::lag())?; + + let ctx = SessionContext::default(); + let df = ctx.read_batch(create_record_batch(-5, 5))?; + + let df = df.select(vec![ + col("a"), + udwf.call(vec![col("a")]) + .order_by(vec![Sort::new(col("a"), true, true)]) + .build() + .unwrap() + .alias("lag_a"), + ])?; + + df.clone().show().await?; + + let result = df.collect().await?; + let expected = + create_array!(Int32, [None, Some(-5), Some(-4), Some(-3), Some(-2)]) + as ArrayRef; + + assert_eq!(result.len(), 1); + assert_eq!(result[0].column(1), &expected); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udwf/partition_evaluator.rs b/datafusion/ffi/src/udwf/partition_evaluator.rs new file mode 100644 index 0000000000000..14cf23b919aa3 --- /dev/null +++ b/datafusion/ffi/src/udwf/partition_evaluator.rs @@ -0,0 +1,320 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, ops::Range}; + +use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::{array::ArrayRef, error::ArrowError}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::{window_state::WindowAggState, PartitionEvaluator}, + scalar::ScalarValue, +}; +use prost::Message; + +use super::range::FFI_Range; + +/// A stable struct for sharing [`PartitionEvaluator`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`PartitionEvaluator`]. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_PartitionEvaluator { + pub evaluate_all: unsafe extern "C" fn( + evaluator: &mut Self, + values: RVec, + num_rows: usize, + ) -> RResult, + + pub evaluate: unsafe extern "C" fn( + evaluator: &mut Self, + values: RVec, + range: FFI_Range, + ) -> RResult, RString>, + + pub evaluate_all_with_rank: unsafe extern "C" fn( + evaluator: &Self, + num_rows: usize, + ranks_in_partition: RVec, + ) + -> RResult, + + pub get_range: unsafe extern "C" fn( + evaluator: &Self, + idx: usize, + n_rows: usize, + ) -> RResult, + + pub is_causal: bool, + + pub supports_bounded_execution: bool, + pub uses_window_frame: bool, + pub include_rank: bool, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(evaluator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the evaluator. + /// A [`ForeignPartitionEvaluator`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_PartitionEvaluator {} +unsafe impl Sync for FFI_PartitionEvaluator {} + +pub struct PartitionEvaluatorPrivateData { + pub evaluator: Box, +} + +impl FFI_PartitionEvaluator { + unsafe fn inner_mut(&mut self) -> &mut Box { + let private_data = self.private_data as *mut PartitionEvaluatorPrivateData; + &mut (*private_data).evaluator + } + + unsafe fn inner(&self) -> &(dyn PartitionEvaluator + 'static) { + let private_data = self.private_data as *mut PartitionEvaluatorPrivateData; + (*private_data).evaluator.as_ref() + } +} + +unsafe extern "C" fn evaluate_all_fn_wrapper( + evaluator: &mut FFI_PartitionEvaluator, + values: RVec, + num_rows: usize, +) -> RResult { + let inner = evaluator.inner_mut(); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + let return_array = inner + .evaluate_all(&values_arrays, num_rows) + .and_then(|array| WrappedArray::try_from(&array).map_err(DataFusionError::from)); + + rresult!(return_array) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + evaluator: &mut FFI_PartitionEvaluator, + values: RVec, + range: FFI_Range, +) -> RResult, RString> { + let inner = evaluator.inner_mut(); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + // let return_array = (inner.evaluate(&values_arrays, &range.into())); + // .and_then(|array| WrappedArray::try_from(&array).map_err(DataFusionError::from)); + let scalar_result = rresult_return!(inner.evaluate(&values_arrays, &range.into())); + let proto_result: datafusion_proto::protobuf::ScalarValue = + rresult_return!((&scalar_result).try_into()); + + RResult::ROk(proto_result.encode_to_vec().into()) +} + +unsafe extern "C" fn evaluate_all_with_rank_fn_wrapper( + evaluator: &FFI_PartitionEvaluator, + num_rows: usize, + ranks_in_partition: RVec, +) -> RResult { + let inner = evaluator.inner(); + + let ranks_in_partition = ranks_in_partition + .into_iter() + .map(Range::from) + .collect::>(); + + let return_array = inner + .evaluate_all_with_rank(num_rows, &ranks_in_partition) + .and_then(|array| WrappedArray::try_from(&array).map_err(DataFusionError::from)); + + rresult!(return_array) +} + +unsafe extern "C" fn get_range_fn_wrapper( + evaluator: &FFI_PartitionEvaluator, + idx: usize, + n_rows: usize, +) -> RResult { + let inner = evaluator.inner(); + let range = inner.get_range(idx, n_rows).map(FFI_Range::from); + + rresult!(range) +} + +unsafe extern "C" fn release_fn_wrapper(evaluator: &mut FFI_PartitionEvaluator) { + let private_data = + Box::from_raw(evaluator.private_data as *mut PartitionEvaluatorPrivateData); + drop(private_data); +} + +impl From> for FFI_PartitionEvaluator { + fn from(evaluator: Box) -> Self { + let is_causal = evaluator.is_causal(); + let supports_bounded_execution = evaluator.supports_bounded_execution(); + let include_rank = evaluator.include_rank(); + let uses_window_frame = evaluator.uses_window_frame(); + + let private_data = PartitionEvaluatorPrivateData { evaluator }; + + Self { + evaluate: evaluate_fn_wrapper, + evaluate_all: evaluate_all_fn_wrapper, + evaluate_all_with_rank: evaluate_all_with_rank_fn_wrapper, + get_range: get_range_fn_wrapper, + is_causal, + supports_bounded_execution, + include_rank, + uses_window_frame, + release: release_fn_wrapper, + private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + } + } +} + +impl Drop for FFI_PartitionEvaluator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignPartitionEvaluator is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_PartitionEvaluator. +#[derive(Debug)] +pub struct ForeignPartitionEvaluator { + evaluator: FFI_PartitionEvaluator, +} + +unsafe impl Send for ForeignPartitionEvaluator {} +unsafe impl Sync for ForeignPartitionEvaluator {} + +impl From for ForeignPartitionEvaluator { + fn from(evaluator: FFI_PartitionEvaluator) -> Self { + Self { evaluator } + } +} + +impl PartitionEvaluator for ForeignPartitionEvaluator { + fn memoize(&mut self, _state: &mut WindowAggState) -> Result<()> { + // Exposing `memoize` increases the surface are of the FFI work + // so for now we dot support it. + Ok(()) + } + + fn get_range(&self, idx: usize, n_rows: usize) -> Result> { + let range = unsafe { (self.evaluator.get_range)(&self.evaluator, idx, n_rows) }; + df_result!(range).map(Range::from) + } + + /// Get whether evaluator needs future data for its result (if so returns `false`) or not + fn is_causal(&self) -> bool { + self.evaluator.is_causal + } + + fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result { + let result = unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + (self.evaluator.evaluate_all)(&mut self.evaluator, values, num_rows) + }; + + let array = df_result!(result)?; + + Ok(array.try_into()?) + } + + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &Range, + ) -> Result { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + + let scalar_bytes = df_result!((self.evaluator.evaluate)( + &mut self.evaluator, + values, + range.to_owned().into() + ))?; + + let proto_scalar = + datafusion_proto::protobuf::ScalarValue::decode(scalar_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + ScalarValue::try_from(&proto_scalar).map_err(DataFusionError::from) + } + } + + fn evaluate_all_with_rank( + &self, + num_rows: usize, + ranks_in_partition: &[Range], + ) -> Result { + let result = unsafe { + let ranks_in_partition = ranks_in_partition + .iter() + .map(|rank| FFI_Range::from(rank.to_owned())) + .collect(); + (self.evaluator.evaluate_all_with_rank)( + &self.evaluator, + num_rows, + ranks_in_partition, + ) + }; + + let array = df_result!(result)?; + + Ok(array.try_into()?) + } + + fn supports_bounded_execution(&self) -> bool { + self.evaluator.supports_bounded_execution + } + + fn uses_window_frame(&self) -> bool { + self.evaluator.uses_window_frame + } + + fn include_rank(&self) -> bool { + self.evaluator.include_rank + } +} + +#[cfg(test)] +mod tests {} diff --git a/datafusion/ffi/src/udwf/partition_evaluator_args.rs b/datafusion/ffi/src/udwf/partition_evaluator_args.rs new file mode 100644 index 0000000000000..cd26412564374 --- /dev/null +++ b/datafusion/ffi/src/udwf/partition_evaluator_args.rs @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{collections::HashMap, sync::Arc}; + +use crate::arrow_wrappers::WrappedSchema; +use abi_stable::{std_types::RVec, StableAbi}; +use arrow::{ + datatypes::{DataType, Field, Schema, SchemaRef}, + error::ArrowError, + ffi::FFI_ArrowSchema, +}; +use arrow_schema::FieldRef; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::function::PartitionEvaluatorArgs, + physical_plan::{expressions::Column, PhysicalExpr}, + prelude::SessionContext, +}; +use datafusion_common::exec_datafusion_err; +use datafusion_proto::{ + physical_plan::{ + from_proto::parse_physical_expr, to_proto::serialize_physical_exprs, + DefaultPhysicalExtensionCodec, + }, + protobuf::PhysicalExprNode, +}; +use prost::Message; + +/// A stable struct for sharing [`PartitionEvaluatorArgs`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`PartitionEvaluatorArgs`]. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_PartitionEvaluatorArgs { + input_exprs: RVec>, + input_fields: RVec, + is_reversed: bool, + ignore_nulls: bool, + schema: WrappedSchema, +} + +impl TryFrom> for FFI_PartitionEvaluatorArgs { + type Error = DataFusionError; + fn try_from(args: PartitionEvaluatorArgs) -> Result { + // This is a bit of a hack. Since PartitionEvaluatorArgs does not carry a schema + // around, and instead passes the data types directly we are unable to decode the + // protobuf PhysicalExpr correctly. In evaluating the code the only place these + // appear to be really used are the Column data types. So here we will find all + // of the required columns and create a schema that has empty fields except for + // the ones we require. Ideally we would enhance PartitionEvaluatorArgs to just + // pass along the schema, but that is a larger breaking change. + let required_columns: HashMap = args + .input_exprs() + .iter() + .zip(args.input_fields()) + .filter_map(|(expr, field)| { + expr.as_any() + .downcast_ref::() + .map(|column| (column.index(), (column.name(), field.data_type()))) + }) + .collect(); + + let max_column = required_columns.keys().max(); + let fields: Vec<_> = max_column + .map(|max_column| { + (0..(max_column + 1)) + .map(|idx| match required_columns.get(&idx) { + Some((name, data_type)) => { + Field::new(*name, (*data_type).clone(), true) + } + None => Field::new( + format!("ffi_partition_evaluator_col_{idx}"), + DataType::Null, + true, + ), + }) + .collect() + }) + .unwrap_or_default(); + + let schema = Arc::new(Schema::new(fields)); + + let codec = DefaultPhysicalExtensionCodec {}; + let input_exprs = serialize_physical_exprs(args.input_exprs(), &codec)? + .into_iter() + .map(|expr_node| expr_node.encode_to_vec().into()) + .collect(); + + let input_fields = args + .input_fields() + .iter() + .map(|input_type| FFI_ArrowSchema::try_from(input_type).map(WrappedSchema)) + .collect::, ArrowError>>()? + .into(); + + let schema: WrappedSchema = schema.into(); + + Ok(Self { + input_exprs, + input_fields, + schema, + is_reversed: args.is_reversed(), + ignore_nulls: args.ignore_nulls(), + }) + } +} + +/// This struct mirrors PartitionEvaluatorArgs except that it contains owned data. +/// It is necessary to create this struct so that we can parse the protobuf +/// data across the FFI boundary and turn it into owned data that +/// PartitionEvaluatorArgs can then reference. +pub struct ForeignPartitionEvaluatorArgs { + input_exprs: Vec>, + input_fields: Vec, + is_reversed: bool, + ignore_nulls: bool, +} + +impl TryFrom for ForeignPartitionEvaluatorArgs { + type Error = DataFusionError; + + fn try_from(value: FFI_PartitionEvaluatorArgs) -> Result { + let default_ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + + let schema: SchemaRef = value.schema.into(); + + let input_exprs = value + .input_exprs + .into_iter() + .map(|input_expr_bytes| PhysicalExprNode::decode(input_expr_bytes.as_ref())) + .collect::, prost::DecodeError>>() + .map_err(|e| exec_datafusion_err!("Failed to decode PhysicalExprNode: {e}"))? + .iter() + .map(|expr_node| { + parse_physical_expr(expr_node, &default_ctx.task_ctx(), &schema, &codec) + }) + .collect::>>()?; + + let input_fields = input_exprs + .iter() + .map(|expr| expr.return_field(&schema)) + .collect::>>()?; + + Ok(Self { + input_exprs, + input_fields, + is_reversed: value.is_reversed, + ignore_nulls: value.ignore_nulls, + }) + } +} + +impl<'a> From<&'a ForeignPartitionEvaluatorArgs> for PartitionEvaluatorArgs<'a> { + fn from(value: &'a ForeignPartitionEvaluatorArgs) -> Self { + PartitionEvaluatorArgs::new( + &value.input_exprs, + &value.input_fields, + value.is_reversed, + value.ignore_nulls, + ) + } +} + +#[cfg(test)] +mod tests {} diff --git a/datafusion/ffi/src/udf/return_info.rs b/datafusion/ffi/src/udwf/range.rs similarity index 50% rename from datafusion/ffi/src/udf/return_info.rs rename to datafusion/ffi/src/udwf/range.rs index cf76ddd1db762..1ddcc4199fe28 100644 --- a/datafusion/ffi/src/udf/return_info.rs +++ b/datafusion/ffi/src/udwf/range.rs @@ -15,39 +15,50 @@ // specific language governing permissions and limitations // under the License. -use abi_stable::StableAbi; -use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; -use datafusion::{error::DataFusionError, logical_expr::ReturnInfo}; +use std::ops::Range; -use crate::arrow_wrappers::WrappedSchema; +use abi_stable::StableAbi; -/// A stable struct for sharing a [`ReturnInfo`] across FFI boundaries. +/// A stable struct for sharing [`Range`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`Range`]. #[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] -pub struct FFI_ReturnInfo { - return_type: WrappedSchema, - nullable: bool, +pub struct FFI_Range { + pub start: usize, + pub end: usize, } -impl TryFrom for FFI_ReturnInfo { - type Error = DataFusionError; +impl From> for FFI_Range { + fn from(value: Range) -> Self { + Self { + start: value.start, + end: value.end, + } + } +} - fn try_from(value: ReturnInfo) -> Result { - let return_type = WrappedSchema(FFI_ArrowSchema::try_from(value.return_type())?); - Ok(Self { - return_type, - nullable: value.nullable(), - }) +impl From for Range { + fn from(value: FFI_Range) -> Self { + Self { + start: value.start, + end: value.end, + } } } -impl TryFrom for ReturnInfo { - type Error = DataFusionError; +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_round_trip_ffi_range() { + let original = Range { start: 10, end: 30 }; - fn try_from(value: FFI_ReturnInfo) -> Result { - let return_type = DataType::try_from(&value.return_type.0)?; + let ffi_range: FFI_Range = original.clone().into(); + let round_trip: Range = ffi_range.into(); - Ok(ReturnInfo::new(return_type, value.nullable)) + assert_eq!(original, round_trip); } } diff --git a/datafusion/ffi/src/util.rs b/datafusion/ffi/src/util.rs index 9d5f2aefe324b..151464dc97458 100644 --- a/datafusion/ffi/src/util.rs +++ b/datafusion/ffi/src/util.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. +use crate::arrow_wrappers::WrappedSchema; use abi_stable::std_types::RVec; +use arrow::datatypes::Field; use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; +use arrow_schema::FieldRef; +use std::sync::Arc; -use crate::arrow_wrappers::WrappedSchema; - -/// This macro is a helpful conversion utility to conver from an abi_stable::RResult to a +/// This macro is a helpful conversion utility to convert from an abi_stable::RResult to a /// DataFusion result. #[macro_export] macro_rules! df_result { @@ -28,13 +30,13 @@ macro_rules! df_result { match $x { abi_stable::std_types::RResult::ROk(v) => Ok(v), abi_stable::std_types::RResult::RErr(e) => { - Err(datafusion::error::DataFusionError::Execution(e.to_string())) + datafusion_common::exec_err!("FFI error: {}", e) } } }; } -/// This macro is a helpful conversion utility to conver from a DataFusion Result to an abi_stable::RResult +/// This macro is a helpful conversion utility to convert from a DataFusion Result to an abi_stable::RResult #[macro_export] macro_rules! rresult { ( $x:expr ) => { @@ -47,7 +49,7 @@ macro_rules! rresult { }; } -/// This macro is a helpful conversion utility to conver from a DataFusion Result to an abi_stable::RResult +/// This macro is a helpful conversion utility to convert from a DataFusion Result to an abi_stable::RResult /// and to also call return when it is an error. Since you cannot use `?` on an RResult, this is designed /// to mimic the pattern. #[macro_export] @@ -64,6 +66,31 @@ macro_rules! rresult_return { }; } +/// This is a utility function to convert a slice of [`Field`] to its equivalent +/// FFI friendly counterpart, [`WrappedSchema`] +pub fn vec_fieldref_to_rvec_wrapped( + fields: &[FieldRef], +) -> Result, arrow::error::ArrowError> { + Ok(fields + .iter() + .map(FFI_ArrowSchema::try_from) + .collect::, arrow::error::ArrowError>>()? + .into_iter() + .map(WrappedSchema) + .collect()) +} + +/// This is a utility function to convert an FFI friendly vector of [`WrappedSchema`] +/// to their equivalent [`Field`]. +pub fn rvec_wrapped_to_vec_fieldref( + fields: &RVec, +) -> Result, arrow::error::ArrowError> { + fields + .iter() + .map(|d| Field::try_from(&d.0).map(Arc::new)) + .collect() +} + /// This is a utility function to convert a slice of [`DataType`] to its equivalent /// FFI friendly counterpart, [`WrappedSchema`] pub fn vec_datatype_to_rvec_wrapped( @@ -115,21 +142,21 @@ mod tests { let returned_err_result = df_result!(err_r_result); assert!(returned_err_result.is_err()); assert!( - returned_err_result.unwrap_err().to_string() - == format!("Execution error: {}", ERROR_VALUE) + returned_err_result.unwrap_err().strip_backtrace() + == format!("Execution error: FFI error: {ERROR_VALUE}") ); let ok_result: Result = Ok(VALID_VALUE.to_string()); let err_result: Result = - Err(DataFusionError::Execution(ERROR_VALUE.to_string())); + datafusion_common::exec_err!("{ERROR_VALUE}"); let returned_ok_r_result = wrap_result(ok_result); assert!(returned_ok_r_result == RResult::ROk(VALID_VALUE.into())); let returned_err_r_result = wrap_result(err_result); - assert!( - returned_err_r_result - == RResult::RErr(format!("Execution error: {}", ERROR_VALUE).into()) - ); + assert!(returned_err_r_result.is_err()); + assert!(returned_err_r_result + .unwrap_err() + .starts_with(format!("Execution error: {ERROR_VALUE}").as_str())); } } diff --git a/datafusion/ffi/src/volatility.rs b/datafusion/ffi/src/volatility.rs index 0aaf68a174cfd..f1705da294a39 100644 --- a/datafusion/ffi/src/volatility.rs +++ b/datafusion/ffi/src/volatility.rs @@ -19,7 +19,7 @@ use abi_stable::StableAbi; use datafusion::logical_expr::Volatility; #[repr(C)] -#[derive(Debug, StableAbi)] +#[derive(Debug, StableAbi, Clone)] #[allow(non_camel_case_types)] pub enum FFI_Volatility { Immutable, diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index c6df324e9a17c..eb53e76bfb9b6 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -16,10 +16,9 @@ // under the License. /// Add an additional module here for convenience to scope this to only -/// when the feature integtation-tests is built +/// when the feature integration-tests is built #[cfg(feature = "integration-tests")] mod tests { - use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::SessionContext; use datafusion_ffi::catalog_provider::ForeignCatalogProvider; diff --git a/datafusion/ffi/tests/ffi_udaf.rs b/datafusion/ffi/tests/ffi_udaf.rs new file mode 100644 index 0000000000000..ffd99bac62ecc --- /dev/null +++ b/datafusion/ffi/tests/ffi_udaf.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Add an additional module here for convenience to scope this to only +/// when the feature integration-tests is built +#[cfg(feature = "integration-tests")] +mod tests { + use arrow::array::Float64Array; + use datafusion::common::record_batch; + use datafusion::error::{DataFusionError, Result}; + use datafusion::logical_expr::AggregateUDF; + use datafusion::prelude::{col, SessionContext}; + + use datafusion_ffi::tests::utils::get_module; + use datafusion_ffi::udaf::ForeignAggregateUDF; + + #[tokio::test] + async fn test_ffi_udaf() -> Result<()> { + let module = get_module()?; + + let ffi_sum_func = + module + .create_sum_udaf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_udaf".to_string(), + ))?(); + let foreign_sum_func: ForeignAggregateUDF = (&ffi_sum_func).try_into()?; + + let udaf: AggregateUDF = foreign_sum_func.into(); + + let ctx = SessionContext::default(); + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ("b", Float64, vec![1.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0]) + ) + .unwrap(); + + let df = ctx.read_batch(record_batch)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udaf.call(vec![col("b")]).alias("sum_b")], + )? + .sort_by(vec![col("a")])?; + + let result = df.collect().await?; + + let expected = record_batch!( + ("a", Int32, vec![1, 2, 4]), + ("sum_b", Float64, vec![1.0, 4.0, 16.0]) + )?; + + assert_eq!(result[0], expected); + + Ok(()) + } + + #[tokio::test] + async fn test_ffi_grouping_udaf() -> Result<()> { + let module = get_module()?; + + let ffi_stddev_func = + module + .create_stddev_udaf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_udaf".to_string(), + ))?(); + let foreign_stddev_func: ForeignAggregateUDF = (&ffi_stddev_func).try_into()?; + + let udaf: AggregateUDF = foreign_stddev_func.into(); + + let ctx = SessionContext::default(); + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ( + "b", + Float64, + vec![ + 1.0, + 2.0, + 2.0 + 2.0_f64.sqrt(), + 4.0, + 4.0, + 4.0 + 3.0_f64.sqrt(), + 4.0 + 3.0_f64.sqrt() + ] + ) + ) + .unwrap(); + + let df = ctx.read_batch(record_batch)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udaf.call(vec![col("b")]).alias("stddev_b")], + )? + .sort_by(vec![col("a")])?; + + let result = df.collect().await?; + let result = result[0].column_by_name("stddev_b").unwrap(); + let result = result + .as_any() + .downcast_ref::() + .unwrap() + .values(); + + assert!(result.first().unwrap().is_nan()); + assert!(result.get(1).unwrap() - 1.0 < 0.00001); + assert!(result.get(2).unwrap() - 1.0 < 0.00001); + + Ok(()) + } +} diff --git a/datafusion/ffi/tests/ffi_udf.rs b/datafusion/ffi/tests/ffi_udf.rs index bbc23552def43..fd6a84bcf5b08 100644 --- a/datafusion/ffi/tests/ffi_udf.rs +++ b/datafusion/ffi/tests/ffi_udf.rs @@ -16,7 +16,7 @@ // under the License. /// Add an additional module here for convenience to scope this to only -/// when the feature integtation-tests is built +/// when the feature integration-tests is built #[cfg(feature = "integration-tests")] mod tests { diff --git a/datafusion/ffi/tests/ffi_udtf.rs b/datafusion/ffi/tests/ffi_udtf.rs new file mode 100644 index 0000000000000..8c1c64a092e13 --- /dev/null +++ b/datafusion/ffi/tests/ffi_udtf.rs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Add an additional module here for convenience to scope this to only +/// when the feature integration-tests is built +#[cfg(feature = "integration-tests")] +mod tests { + + use std::sync::Arc; + + use arrow::array::{create_array, ArrayRef}; + use datafusion::error::{DataFusionError, Result}; + use datafusion::prelude::SessionContext; + + use datafusion_ffi::tests::utils::get_module; + use datafusion_ffi::udtf::ForeignTableFunction; + + /// This test validates that we can load an external module and use a scalar + /// udf defined in it via the foreign function interface. In this case we are + /// using the abs() function as our scalar UDF. + #[tokio::test] + async fn test_user_defined_table_function() -> Result<()> { + let module = get_module()?; + + let ffi_table_func = module + .create_table_function() + .ok_or(DataFusionError::NotImplemented( + "External table function provider failed to implement create_table_function" + .to_string(), + ))?(); + let foreign_table_func: ForeignTableFunction = ffi_table_func.into(); + + let udtf = Arc::new(foreign_table_func); + + let ctx = SessionContext::default(); + ctx.register_udtf("my_range", udtf); + + let result = ctx + .sql("SELECT * FROM my_range(5)") + .await? + .collect() + .await?; + let expected = create_array!(Int64, [0, 1, 2, 3, 4]) as ArrayRef; + + assert!(result.len() == 1); + assert!(result[0].column(0) == &expected); + + Ok(()) + } +} diff --git a/datafusion/ffi/tests/ffi_udwf.rs b/datafusion/ffi/tests/ffi_udwf.rs new file mode 100644 index 0000000000000..18ffd0c5bcb79 --- /dev/null +++ b/datafusion/ffi/tests/ffi_udwf.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Add an additional module here for convenience to scope this to only +/// when the feature integration-tests is built +#[cfg(feature = "integration-tests")] +mod tests { + use arrow::array::{create_array, ArrayRef}; + use datafusion::error::{DataFusionError, Result}; + use datafusion::logical_expr::expr::Sort; + use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF}; + use datafusion::prelude::SessionContext; + use datafusion_ffi::tests::create_record_batch; + use datafusion_ffi::tests::utils::get_module; + use datafusion_ffi::udwf::ForeignWindowUDF; + + #[tokio::test] + async fn test_rank_udwf() -> Result<()> { + let module = get_module()?; + + let ffi_rank_func = + module + .create_rank_udwf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_scalar_udf" + .to_string(), + ))?(); + let foreign_rank_func: ForeignWindowUDF = (&ffi_rank_func).try_into()?; + + let udwf: WindowUDF = foreign_rank_func.into(); + + let ctx = SessionContext::default(); + let df = ctx.read_batch(create_record_batch(-5, 5))?; + + let df = df.select(vec![ + col("a"), + udwf.call(vec![]) + .order_by(vec![Sort::new(col("a"), true, true)]) + .build() + .unwrap() + .alias("rank_a"), + ])?; + + df.clone().show().await?; + + let result = df.collect().await?; + let expected = create_array!(UInt64, [1, 2, 3, 4, 5]) as ArrayRef; + + assert_eq!(result.len(), 1); + assert_eq!(result[0].column(1), &expected); + + Ok(()) + } +} diff --git a/datafusion/functions-aggregate-common/Cargo.toml b/datafusion/functions-aggregate-common/Cargo.toml index cf065ca1cb174..a6e0a1fc2f8bb 100644 --- a/datafusion/functions-aggregate-common/Cargo.toml +++ b/datafusion/functions-aggregate-common/Cargo.toml @@ -19,6 +19,7 @@ name = "datafusion-functions-aggregate-common" description = "Utility functions for implementing aggregate functions for the DataFusion query engine" keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } homepage = { workspace = true } diff --git a/datafusion/functions-aggregate-common/README.md b/datafusion/functions-aggregate-common/README.md new file mode 100644 index 0000000000000..3d52aa722033a --- /dev/null +++ b/datafusion/functions-aggregate-common/README.md @@ -0,0 +1,32 @@ + + +# Apache DataFusion Aggregate Function Common Library + +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. + +This crate contains common functionality for implementation aggregate and window functions. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-aggregate-common/src/accumulator.rs b/datafusion/functions-aggregate-common/src/accumulator.rs index a230bb0289091..e0f7af1fb38e3 100644 --- a/datafusion/functions-aggregate-common/src/accumulator.rs +++ b/datafusion/functions-aggregate-common/src/accumulator.rs @@ -15,20 +15,20 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use datafusion_common::Result; use datafusion_expr_common::accumulator::Accumulator; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use std::sync::Arc; /// [`AccumulatorArgs`] contains information about how an aggregate /// function was called, including the types of its arguments and any optional /// ordering expressions. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct AccumulatorArgs<'a> { - /// The return type of the aggregate function. - pub return_type: &'a DataType, + /// The return field of the aggregate function. + pub return_field: FieldRef, /// The schema of the input arguments pub schema: &'a Schema, @@ -50,9 +50,7 @@ pub struct AccumulatorArgs<'a> { /// ```sql /// SELECT FIRST_VALUE(column1 ORDER BY column2) FROM t; /// ``` - /// - /// If no `ORDER BY` is specified, `ordering_req` will be empty. - pub ordering_req: &'a LexOrdering, + pub order_bys: &'a [PhysicalSortExpr], /// Whether the aggregation is running in reverse order pub is_reversed: bool, @@ -71,6 +69,13 @@ pub struct AccumulatorArgs<'a> { pub exprs: &'a [Arc], } +impl AccumulatorArgs<'_> { + /// Returns the return type of the aggregate function. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } +} + /// Factory that returns an accumulator for the given aggregate function. pub type AccumulatorFactoryFunction = Arc Result> + Send + Sync>; @@ -81,15 +86,22 @@ pub struct StateFieldsArgs<'a> { /// The name of the aggregate function. pub name: &'a str, - /// The input types of the aggregate function. - pub input_types: &'a [DataType], + /// The input fields of the aggregate function. + pub input_fields: &'a [FieldRef], - /// The return type of the aggregate function. - pub return_type: &'a DataType, + /// The return fields of the aggregate function. + pub return_field: FieldRef, /// The ordering fields of the aggregate function. - pub ordering_fields: &'a [Field], + pub ordering_fields: &'a [FieldRef], /// Whether the aggregate function is distinct. pub is_distinct: bool, } + +impl StateFieldsArgs<'_> { + /// The return type of the aggregate function. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } +} diff --git a/datafusion/functions-aggregate-common/src/aggregate.rs b/datafusion/functions-aggregate-common/src/aggregate.rs index c9cbaa8396fc5..aadce907e7cc3 100644 --- a/datafusion/functions-aggregate-common/src/aggregate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate.rs @@ -15,5 +15,7 @@ // specific language governing permissions and limitations // under the License. +pub mod avg_distinct; pub mod count_distinct; pub mod groups_accumulator; +pub mod sum_distinct; diff --git a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs new file mode 100644 index 0000000000000..56cdaf6618de5 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod decimal; +mod numeric; + +pub use decimal::DecimalDistinctAvgAccumulator; +pub use numeric::Float64DistinctAvgAccumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs new file mode 100644 index 0000000000000..9920bf5bf4485 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs @@ -0,0 +1,282 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{ArrayRef, ArrowNumericType}, + datatypes::{ + i256, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, DecimalType, + }, +}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr_common::accumulator::Accumulator; +use std::fmt::Debug; +use std::mem::size_of_val; + +use crate::aggregate::sum_distinct::DistinctSumAccumulator; +use crate::utils::DecimalAverager; + +/// Generic implementation of `AVG DISTINCT` for Decimal types. +/// Handles both all Arrow decimal types (32, 64, 128 and 256 bits). +#[derive(Debug)] +pub struct DecimalDistinctAvgAccumulator { + sum_accumulator: DistinctSumAccumulator, + sum_scale: i8, + target_precision: u8, + target_scale: i8, +} + +impl DecimalDistinctAvgAccumulator { + pub fn with_decimal_params( + sum_scale: i8, + target_precision: u8, + target_scale: i8, + ) -> Self { + let data_type = T::TYPE_CONSTRUCTOR(T::MAX_PRECISION, sum_scale); + + Self { + sum_accumulator: DistinctSumAccumulator::new(&data_type), + sum_scale, + target_precision, + target_scale, + } + } +} + +impl Accumulator + for DecimalDistinctAvgAccumulator +{ + fn state(&mut self) -> Result> { + self.sum_accumulator.state() + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.sum_accumulator.update_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.sum_accumulator.merge_batch(states) + } + + fn evaluate(&mut self) -> Result { + if self.sum_accumulator.distinct_count() == 0 { + return ScalarValue::new_primitive::( + None, + &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale), + ); + } + + let sum_scalar = self.sum_accumulator.evaluate()?; + + match sum_scalar { + ScalarValue::Decimal32(Some(sum), _, _) => { + let decimal_averager = DecimalAverager::::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )?; + let avg = decimal_averager + .avg(sum, self.sum_accumulator.distinct_count() as i32)?; + Ok(ScalarValue::Decimal32( + Some(avg), + self.target_precision, + self.target_scale, + )) + } + ScalarValue::Decimal64(Some(sum), _, _) => { + let decimal_averager = DecimalAverager::::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )?; + let avg = decimal_averager + .avg(sum, self.sum_accumulator.distinct_count() as i64)?; + Ok(ScalarValue::Decimal64( + Some(avg), + self.target_precision, + self.target_scale, + )) + } + ScalarValue::Decimal128(Some(sum), _, _) => { + let decimal_averager = DecimalAverager::::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )?; + let avg = decimal_averager + .avg(sum, self.sum_accumulator.distinct_count() as i128)?; + Ok(ScalarValue::Decimal128( + Some(avg), + self.target_precision, + self.target_scale, + )) + } + ScalarValue::Decimal256(Some(sum), _, _) => { + let decimal_averager = DecimalAverager::::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )?; + // `distinct_count` returns `u64`, but `avg` expects `i256` + // first convert `u64` to `i128`, then convert `i128` to `i256` to avoid overflow + let distinct_cnt: i128 = self.sum_accumulator.distinct_count() as i128; + let count: i256 = i256::from_i128(distinct_cnt); + let avg = decimal_averager.avg(sum, count)?; + Ok(ScalarValue::Decimal256( + Some(avg), + self.target_precision, + self.target_scale, + )) + } + + _ => unreachable!("Unsupported decimal type: {:?}", sum_scalar), + } + } + + fn size(&self) -> usize { + let fixed_size = size_of_val(self); + + // Account for the size of the sum_accumulator with its contained values + fixed_size + self.sum_accumulator.size() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, + }; + use std::sync::Arc; + + #[test] + fn test_decimal32_distinct_avg_accumulator() -> Result<()> { + let precision = 5_u8; + let scale = 2_i8; + let array = Decimal32Array::from(vec![ + Some(10_00), + Some(12_50), + Some(17_50), + Some(20_00), + Some(20_00), + Some(30_00), + None, + None, + ]) + .with_precision_and_scale(precision, scale)?; + + let mut accumulator = + DecimalDistinctAvgAccumulator::::with_decimal_params( + scale, 9, 6, + ); + accumulator.update_batch(&[Arc::new(array)])?; + + let result = accumulator.evaluate()?; + let expected_result = ScalarValue::Decimal32(Some(18000000), 9, 6); + assert_eq!(result, expected_result); + + Ok(()) + } + + #[test] + fn test_decimal64_distinct_avg_accumulator() -> Result<()> { + let precision = 10_u8; + let scale = 4_i8; + let array = Decimal64Array::from(vec![ + Some(100_0000), + Some(125_0000), + Some(175_0000), + Some(200_0000), + Some(200_0000), + Some(300_0000), + None, + None, + ]) + .with_precision_and_scale(precision, scale)?; + + let mut accumulator = + DecimalDistinctAvgAccumulator::::with_decimal_params( + scale, 14, 8, + ); + accumulator.update_batch(&[Arc::new(array)])?; + + let result = accumulator.evaluate()?; + let expected_result = ScalarValue::Decimal64(Some(180_00000000), 14, 8); + assert_eq!(result, expected_result); + + Ok(()) + } + + #[test] + fn test_decimal128_distinct_avg_accumulator() -> Result<()> { + let precision = 10_u8; + let scale = 4_i8; + let array = Decimal128Array::from(vec![ + Some(100_0000), + Some(125_0000), + Some(175_0000), + Some(200_0000), + Some(200_0000), + Some(300_0000), + None, + None, + ]) + .with_precision_and_scale(precision, scale)?; + + let mut accumulator = + DecimalDistinctAvgAccumulator::::with_decimal_params( + scale, 14, 8, + ); + accumulator.update_batch(&[Arc::new(array)])?; + + let result = accumulator.evaluate()?; + let expected_result = ScalarValue::Decimal128(Some(180_00000000), 14, 8); + assert_eq!(result, expected_result); + + Ok(()) + } + + #[test] + fn test_decimal256_distinct_avg_accumulator() -> Result<()> { + let precision = 50_u8; + let scale = 2_i8; + + let array = Decimal256Array::from(vec![ + Some(i256::from_i128(10_000)), + Some(i256::from_i128(12_500)), + Some(i256::from_i128(17_500)), + Some(i256::from_i128(20_000)), + Some(i256::from_i128(20_000)), + Some(i256::from_i128(30_000)), + None, + None, + ]) + .with_precision_and_scale(precision, scale)?; + + let mut accumulator = + DecimalDistinctAvgAccumulator::::with_decimal_params( + scale, 54, 6, + ); + accumulator.update_batch(&[Arc::new(array)])?; + + let result = accumulator.evaluate()?; + let expected_result = + ScalarValue::Decimal256(Some(i256::from_i128(180_000000)), 54, 6); + assert_eq!(result, expected_result); + + Ok(()) + } +} diff --git a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/numeric.rs b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/numeric.rs new file mode 100644 index 0000000000000..bb43acc2614f9 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/numeric.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::Debug; + +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Float64Type}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr_common::accumulator::Accumulator; + +use crate::aggregate::sum_distinct::DistinctSumAccumulator; + +/// Specialized implementation of `AVG DISTINCT` for Float64 values, leveraging +/// the existing DistinctSumAccumulator implementation. +#[derive(Debug)] +pub struct Float64DistinctAvgAccumulator { + // We use the DistinctSumAccumulator to handle the set of distinct values + sum_accumulator: DistinctSumAccumulator, +} + +impl Default for Float64DistinctAvgAccumulator { + fn default() -> Self { + Self { + sum_accumulator: DistinctSumAccumulator::::new( + &DataType::Float64, + ), + } + } +} + +impl Accumulator for Float64DistinctAvgAccumulator { + fn state(&mut self) -> Result> { + self.sum_accumulator.state() + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.sum_accumulator.update_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.sum_accumulator.merge_batch(states) + } + + fn evaluate(&mut self) -> Result { + // Get the sum from the DistinctSumAccumulator + let sum_result = self.sum_accumulator.evaluate()?; + + // Extract the sum value + if let ScalarValue::Float64(Some(sum)) = sum_result { + // Get the count of distinct values + let count = self.sum_accumulator.distinct_count() as f64; + // Calculate average + let avg = sum / count; + Ok(ScalarValue::Float64(Some(avg))) + } else { + // If sum is None, return None (null) + Ok(ScalarValue::Float64(None)) + } + } + + fn size(&self) -> usize { + self.sum_accumulator.size() + } +} diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs index 7d772f7c649dc..25b40382299b4 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs @@ -16,9 +16,11 @@ // under the License. mod bytes; +mod dict; mod native; pub use bytes::BytesDistinctCountAccumulator; pub use bytes::BytesViewDistinctCountAccumulator; +pub use dict::DictionaryCountAccumulator; pub use native::FloatDistinctCountAccumulator; pub use native::PrimitiveDistinctCountAccumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/dict.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/dict.rs new file mode 100644 index 0000000000000..089d8d5acded1 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/dict.rs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::downcast_dictionary_array; +use datafusion_common::{arrow_datafusion_err, ScalarValue}; +use datafusion_common::{internal_err, DataFusionError}; +use datafusion_expr_common::accumulator::Accumulator; + +#[derive(Debug)] +pub struct DictionaryCountAccumulator { + inner: Box, +} + +impl DictionaryCountAccumulator { + pub fn new(inner: Box) -> Self { + Self { inner } + } +} + +impl Accumulator for DictionaryCountAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { + let values: Vec<_> = values + .iter() + .map(|dict| { + downcast_dictionary_array! { + dict => { + let buff: BooleanArray = dict.occupancy().into(); + arrow::compute::filter( + dict.values(), + &buff + ).map_err(|e| arrow_datafusion_err!(e)) + }, + _ => internal_err!("DictionaryCountAccumulator only supports dictionary arrays") + } + }) + .collect::, _>>()?; + self.inner.update_batch(values.as_slice()) + } + + fn evaluate(&mut self) -> datafusion_common::Result { + self.inner.evaluate() + } + + fn size(&self) -> usize { + self.inner.size() + } + + fn state(&mut self) -> datafusion_common::Result> { + self.inner.state() + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { + self.inner.merge_batch(states) + } +} diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index e629e99e1657a..987ba57f7719e 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -636,7 +636,7 @@ mod test { #[test] fn accumulate_fuzz() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..100 { Fixture::new_random(&mut rng).run(); } @@ -661,23 +661,23 @@ mod test { impl Fixture { fn new_random(rng: &mut ThreadRng) -> Self { // Number of input values in a batch - let num_values: usize = rng.gen_range(1..200); + let num_values: usize = rng.random_range(1..200); // number of distinct groups - let num_groups: usize = rng.gen_range(2..1000); + let num_groups: usize = rng.random_range(2..1000); let max_group = num_groups - 1; let group_indices: Vec = (0..num_values) - .map(|_| rng.gen_range(0..max_group)) + .map(|_| rng.random_range(0..max_group)) .collect(); - let values: Vec = (0..num_values).map(|_| rng.gen()).collect(); + let values: Vec = (0..num_values).map(|_| rng.random()).collect(); // 10% chance of false // 10% change of null // 80% chance of true let filter: BooleanArray = (0..num_values) .map(|_| { - let filter_value = rng.gen_range(0.0..1.0); + let filter_value = rng.random_range(0.0..1.0); if filter_value < 0.1 { Some(false) } else if filter_value < 0.2 { @@ -690,14 +690,14 @@ mod test { // random values with random number and location of nulls // random null percentage - let null_pct: f32 = rng.gen_range(0.0..1.0); + let null_pct: f32 = rng.random_range(0.0..1.0); let values_with_nulls: Vec> = (0..num_values) .map(|_| { - let is_null = null_pct < rng.gen_range(0.0..1.0); + let is_null = null_pct < rng.random_range(0.0..1.0); if is_null { None } else { - Some(rng.gen()) + Some(rng.random()) } }) .collect(); diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs index 6a8946034cbc3..c8c7736bba14f 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -20,7 +20,7 @@ use arrow::array::{ Array, ArrayRef, ArrowNumericType, AsArray, BinaryArray, BinaryViewArray, BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray, - StringViewArray, + StringViewArray, StructArray, }; use arrow::buffer::NullBuffer; use arrow::datatypes::DataType; @@ -193,6 +193,18 @@ pub fn set_nulls_dyn(input: &dyn Array, nulls: Option) -> Result { + let input = input.as_struct(); + // safety: values / offsets came from a valid struct array + // and we checked nulls has the same length as values + unsafe { + Arc::new(StructArray::new_unchecked( + input.fields().clone(), + input.columns().to_vec(), + nulls, + )) + } + } _ => { return not_impl_err!("Applying nulls {:?}", input.data_type()); } diff --git a/datafusion/functions-aggregate-common/src/aggregate/sum_distinct.rs b/datafusion/functions-aggregate-common/src/aggregate/sum_distinct.rs new file mode 100644 index 0000000000000..932bfba0bf0dc --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/sum_distinct.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sum distinct accumulator implementations + +pub mod numeric; + +pub use numeric::DistinctSumAccumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/numeric.rs b/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/numeric.rs new file mode 100644 index 0000000000000..3021783a2a79c --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/numeric.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the accumulator for `SUM DISTINCT` for primitive numeric types + +use std::collections::HashSet; +use std::fmt::Debug; +use std::mem::{size_of, size_of_val}; + +use ahash::RandomState; +use arrow::array::Array; +use arrow::array::ArrayRef; +use arrow::array::ArrowNativeTypeOp; +use arrow::array::ArrowPrimitiveType; +use arrow::array::AsArray; +use arrow::datatypes::ArrowNativeType; +use arrow::datatypes::DataType; + +use datafusion_common::Result; +use datafusion_common::ScalarValue; +use datafusion_expr_common::accumulator::Accumulator; + +use crate::utils::Hashable; + +/// Accumulator for computing SUM(DISTINCT expr) +pub struct DistinctSumAccumulator { + values: HashSet, RandomState>, + data_type: DataType, +} + +impl Debug for DistinctSumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DistinctSumAccumulator({})", self.data_type) + } +} + +impl DistinctSumAccumulator { + pub fn new(data_type: &DataType) -> Self { + Self { + values: HashSet::default(), + data_type: data_type.clone(), + } + } + + pub fn distinct_count(&self) -> usize { + self.values.len() + } +} + +impl Accumulator for DistinctSumAccumulator { + fn state(&mut self) -> Result> { + // 1. Stores aggregate state in `ScalarValue::List` + // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set + let state_out = { + let distinct_values = self + .values + .iter() + .map(|value| { + ScalarValue::new_primitive::(Some(value.0), &self.data_type) + }) + .collect::>>()?; + + vec![ScalarValue::List(ScalarValue::new_list_nullable( + &distinct_values, + &self.data_type, + ))] + }; + Ok(state_out) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let array = values[0].as_primitive::(); + match array.nulls().filter(|x| x.null_count() > 0) { + Some(n) => { + for idx in n.valid_indices() { + self.values.insert(Hashable(array.value(idx))); + } + } + None => array.values().iter().for_each(|x| { + self.values.insert(Hashable(*x)); + }), + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + for x in states[0].as_list::().iter().flatten() { + self.update_batch(&[x])? + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let mut acc = T::Native::usize_as(0); + for distinct_value in self.values.iter() { + acc = acc.add_wrapping(distinct_value.0) + } + let v = (!self.values.is_empty()).then_some(acc); + ScalarValue::new_primitive::(v, &self.data_type) + } + + fn size(&self) -> usize { + size_of_val(self) + self.values.capacity() * size_of::() + } +} diff --git a/datafusion/functions-aggregate-common/src/lib.rs b/datafusion/functions-aggregate-common/src/lib.rs index da718e7ceefe6..a07ef4d597cf2 100644 --- a/datafusion/functions-aggregate-common/src/lib.rs +++ b/datafusion/functions-aggregate-common/src/lib.rs @@ -26,7 +26,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -34,6 +34,7 @@ pub mod accumulator; pub mod aggregate; pub mod merge_arrays; +pub mod min_max; pub mod order; pub mod stats; pub mod tdigest; diff --git a/datafusion/functions-aggregate-common/src/merge_arrays.rs b/datafusion/functions-aggregate-common/src/merge_arrays.rs index 0cfea662497e1..bdf1490417beb 100644 --- a/datafusion/functions-aggregate-common/src/merge_arrays.rs +++ b/datafusion/functions-aggregate-common/src/merge_arrays.rs @@ -67,6 +67,7 @@ impl<'a> CustomElement<'a> { // - When used inside `BinaryHeap` it is a min-heap. impl Ord for CustomElement<'_> { fn cmp(&self, other: &Self) -> Ordering { + // TODO Ord/PartialOrd is not consistent with PartialEq; PartialOrd contract is violated // Compares according to custom ordering self.ordering(&self.ordering, &other.ordering) // Convert max heap to min heap @@ -86,7 +87,7 @@ impl PartialOrd for CustomElement<'_> { /// This functions merges `values` array (`&[Vec]`) into single array `Vec` /// Merging done according to ordering values stored inside `ordering_values` (`&[Vec>]`) -/// Inner `Vec` in the `ordering_values` can be thought as ordering information for the +/// Inner `Vec` in the `ordering_values` can be thought as ordering information for /// each `ScalarValue` in the `values` array. /// Desired ordering specified by `sort_options` argument (Should have same size with inner `Vec` /// of the `ordering_values` array). @@ -118,17 +119,27 @@ pub fn merge_ordered_arrays( // Defines according to which ordering comparisons should be done. sort_options: &[SortOptions], ) -> datafusion_common::Result<(Vec, Vec>)> { - // Keep track the most recent data of each branch, in binary heap data structure. + // Keep track of the most recent data of each branch, in a binary heap data structure. let mut heap = BinaryHeap::::new(); - if values.len() != ordering_values.len() - || values - .iter() - .zip(ordering_values.iter()) - .any(|(vals, ordering_vals)| vals.len() != ordering_vals.len()) + if values.len() != ordering_values.len() { + return exec_err!( + "Expects values and ordering_values to have same size but got {} and {}", + values.len(), + ordering_values.len() + ); + } + if let Some((idx, (values, ordering_values))) = values + .iter() + .zip(ordering_values.iter()) + .enumerate() + .find(|(_, (vals, ordering_vals))| vals.len() != ordering_vals.len()) { return exec_err!( - "Expects values arguments and/or ordering_values arguments to have same size" + "Expects values elements and ordering_values elements to have same size but got {} and {} at index {}", + values.len(), + ordering_values.len(), + idx ); } let n_branch = values.len(); diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs new file mode 100644 index 0000000000000..7dd60e1c0e1b4 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/min_max.rs @@ -0,0 +1,854 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Basic min/max functionality shared across DataFusion aggregate functions + +use arrow::array::{ + ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, + Date64Array, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, + DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, + DurationSecondArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, + LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray, + Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, +}; +use arrow::compute; +use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; +use datafusion_common::{ + downcast_value, internal_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr_common::accumulator::Accumulator; +use std::{cmp::Ordering, mem::size_of_val}; + +// min/max of two non-string scalar values. +macro_rules! typed_min_max { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ + ScalarValue::$SCALAR( + match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) => Some((*a).$OP(*b)), + }, + $($EXTRA_ARGS.clone()),* + ) + }}; +} + +macro_rules! typed_min_max_float { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ + ScalarValue::$SCALAR(match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) => match a.total_cmp(b) { + choose_min_max!($OP) => Some(*b), + _ => Some(*a), + }, + }) + }}; +} + +// min/max of two scalar string values. +macro_rules! typed_min_max_string { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ + ScalarValue::$SCALAR(match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(a.clone()), + (None, Some(b)) => Some(b.clone()), + (Some(a), Some(b)) => Some((a).$OP(b).clone()), + }) + }}; +} + +// min/max of two scalar string values with a prefix argument. +macro_rules! typed_min_max_string_arg { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident, $ARG:expr) => {{ + ScalarValue::$SCALAR( + $ARG, + match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(a.clone()), + (None, Some(b)) => Some(b.clone()), + (Some(a), Some(b)) => Some((a).$OP(b).clone()), + }, + ) + }}; +} + +macro_rules! choose_min_max { + (min) => { + std::cmp::Ordering::Greater + }; + (max) => { + std::cmp::Ordering::Less + }; +} + +macro_rules! interval_min_max { + ($OP:tt, $LHS:expr, $RHS:expr) => {{ + match $LHS.partial_cmp(&$RHS) { + Some(choose_min_max!($OP)) => $RHS.clone(), + Some(_) => $LHS.clone(), + None => { + return internal_err!("Comparison error while computing interval min/max") + } + } + }}; +} + +macro_rules! min_max_generic { + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ + if $VALUE.is_null() { + let mut delta_copy = $DELTA.clone(); + // When the new value won we want to compact it to + // avoid storing the entire input + delta_copy.compact(); + delta_copy + } else if $DELTA.is_null() { + $VALUE.clone() + } else { + match $VALUE.partial_cmp(&$DELTA) { + Some(choose_min_max!($OP)) => { + // When the new value won we want to compact it to + // avoid storing the entire input + let mut delta_copy = $DELTA.clone(); + delta_copy.compact(); + delta_copy + } + _ => $VALUE.clone(), + } + } + }}; +} + +// min/max of two scalar values of the same type +macro_rules! min_max { + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ + Ok(match ($VALUE, $DELTA) { + (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null, + ( + lhs @ ScalarValue::Decimal32(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal32(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal32, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + ( + lhs @ ScalarValue::Decimal64(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal64(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal64, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + ( + lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + ( + lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { + typed_min_max!(lhs, rhs, Boolean, $OP) + } + (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { + typed_min_max_float!(lhs, rhs, Float64, $OP) + } + (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { + typed_min_max_float!(lhs, rhs, Float32, $OP) + } + (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => { + typed_min_max_float!(lhs, rhs, Float16, $OP) + } + (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { + typed_min_max!(lhs, rhs, UInt64, $OP) + } + (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { + typed_min_max!(lhs, rhs, UInt32, $OP) + } + (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { + typed_min_max!(lhs, rhs, UInt16, $OP) + } + (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { + typed_min_max!(lhs, rhs, UInt8, $OP) + } + (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { + typed_min_max!(lhs, rhs, Int64, $OP) + } + (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { + typed_min_max!(lhs, rhs, Int32, $OP) + } + (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { + typed_min_max!(lhs, rhs, Int16, $OP) + } + (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { + typed_min_max!(lhs, rhs, Int8, $OP) + } + (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => { + typed_min_max_string!(lhs, rhs, Utf8, $OP) + } + (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { + typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) + } + (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => { + typed_min_max_string!(lhs, rhs, Utf8View, $OP) + } + (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { + typed_min_max_string!(lhs, rhs, Binary, $OP) + } + (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { + typed_min_max_string!(lhs, rhs, LargeBinary, $OP) + } + (ScalarValue::FixedSizeBinary(lsize, lhs), ScalarValue::FixedSizeBinary(rsize, rhs)) => { + if lsize == rsize { + typed_min_max_string_arg!(lhs, rhs, FixedSizeBinary, $OP, *lsize) + } + else { + return internal_err!( + "MIN/MAX is not expected to receive FixedSizeBinary of incompatible sizes {:?}", + (lsize, rsize)) + } + } + (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => { + typed_min_max_string!(lhs, rhs, BinaryView, $OP) + } + (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { + typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) + } + ( + ScalarValue::TimestampMillisecond(lhs, l_tz), + ScalarValue::TimestampMillisecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz) + } + ( + ScalarValue::TimestampMicrosecond(lhs, l_tz), + ScalarValue::TimestampMicrosecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz) + } + ( + ScalarValue::TimestampNanosecond(lhs, l_tz), + ScalarValue::TimestampNanosecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz) + } + ( + ScalarValue::Date32(lhs), + ScalarValue::Date32(rhs), + ) => { + typed_min_max!(lhs, rhs, Date32, $OP) + } + ( + ScalarValue::Date64(lhs), + ScalarValue::Date64(rhs), + ) => { + typed_min_max!(lhs, rhs, Date64, $OP) + } + ( + ScalarValue::Time32Second(lhs), + ScalarValue::Time32Second(rhs), + ) => { + typed_min_max!(lhs, rhs, Time32Second, $OP) + } + ( + ScalarValue::Time32Millisecond(lhs), + ScalarValue::Time32Millisecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time32Millisecond, $OP) + } + ( + ScalarValue::Time64Microsecond(lhs), + ScalarValue::Time64Microsecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time64Microsecond, $OP) + } + ( + ScalarValue::Time64Nanosecond(lhs), + ScalarValue::Time64Nanosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time64Nanosecond, $OP) + } + ( + ScalarValue::IntervalYearMonth(lhs), + ScalarValue::IntervalYearMonth(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalYearMonth, $OP) + } + ( + ScalarValue::IntervalMonthDayNano(lhs), + ScalarValue::IntervalMonthDayNano(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP) + } + ( + ScalarValue::IntervalDayTime(lhs), + ScalarValue::IntervalDayTime(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalDayTime, $OP) + } + ( + ScalarValue::IntervalYearMonth(_), + ScalarValue::IntervalMonthDayNano(_), + ) | ( + ScalarValue::IntervalYearMonth(_), + ScalarValue::IntervalDayTime(_), + ) | ( + ScalarValue::IntervalMonthDayNano(_), + ScalarValue::IntervalDayTime(_), + ) | ( + ScalarValue::IntervalMonthDayNano(_), + ScalarValue::IntervalYearMonth(_), + ) | ( + ScalarValue::IntervalDayTime(_), + ScalarValue::IntervalYearMonth(_), + ) | ( + ScalarValue::IntervalDayTime(_), + ScalarValue::IntervalMonthDayNano(_), + ) => { + interval_min_max!($OP, $VALUE, $DELTA) + } + ( + ScalarValue::DurationSecond(lhs), + ScalarValue::DurationSecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationSecond, $OP) + } + ( + ScalarValue::DurationMillisecond(lhs), + ScalarValue::DurationMillisecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationMillisecond, $OP) + } + ( + ScalarValue::DurationMicrosecond(lhs), + ScalarValue::DurationMicrosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationMicrosecond, $OP) + } + ( + ScalarValue::DurationNanosecond(lhs), + ScalarValue::DurationNanosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationNanosecond, $OP) + } + + ( + lhs @ ScalarValue::Struct(_), + rhs @ ScalarValue::Struct(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + ( + lhs @ ScalarValue::List(_), + rhs @ ScalarValue::List(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + + ( + lhs @ ScalarValue::LargeList(_), + rhs @ ScalarValue::LargeList(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + + ( + lhs @ ScalarValue::FixedSizeList(_), + rhs @ ScalarValue::FixedSizeList(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + e => { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + e + ) + } + }) + }}; +} + +/// An accumulator to compute the maximum value +#[derive(Debug, Clone)] +pub struct MaxAccumulator { + max: ScalarValue, +} + +impl MaxAccumulator { + /// new max accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + max: ScalarValue::try_from(datatype)?, + }) + } +} + +impl Accumulator for MaxAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + let delta = &max_batch(values)?; + let new_max: Result = + min_max!(&self.max, delta, max); + self.max = new_max?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + fn evaluate(&mut self) -> Result { + Ok(self.max.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.max) + self.max.size() + } +} + +/// An accumulator to compute the minimum value +#[derive(Debug, Clone)] +pub struct MinAccumulator { + min: ScalarValue, +} + +impl MinAccumulator { + /// new min accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + min: ScalarValue::try_from(datatype)?, + }) + } +} + +impl Accumulator for MinAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + let delta = &min_batch(values)?; + let new_min: Result = + min_max!(&self.min, delta, min); + self.min = new_min?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&mut self) -> Result { + Ok(self.min.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.min) + self.min.size() + } +} + +// Statically-typed version of min/max(array) -> ScalarValue for string types +macro_rules! typed_min_max_batch_string { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + let value = value.and_then(|e| Some(e.to_string())); + ScalarValue::$SCALAR(value) + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for binary types. +macro_rules! typed_min_max_batch_binary { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + let value = value.and_then(|e| Some(e.to_vec())); + ScalarValue::$SCALAR(value) + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for non-string types. +macro_rules! typed_min_max_batch { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + ScalarValue::$SCALAR(value, $($EXTRA_ARGS.clone()),*) + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for non-string types. +// this is a macro to support both operations (min and max). +macro_rules! min_max_batch { + ($VALUES:expr, $OP:ident) => {{ + match $VALUES.data_type() { + DataType::Null => ScalarValue::Null, + DataType::Decimal32(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal32Array, + Decimal32, + $OP, + precision, + scale + ) + } + DataType::Decimal64(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal64Array, + Decimal64, + $OP, + precision, + scale + ) + } + DataType::Decimal128(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal128Array, + Decimal128, + $OP, + precision, + scale + ) + } + DataType::Decimal256(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal256Array, + Decimal256, + $OP, + precision, + scale + ) + } + // all types that have a natural order + DataType::Float64 => { + typed_min_max_batch!($VALUES, Float64Array, Float64, $OP) + } + DataType::Float32 => { + typed_min_max_batch!($VALUES, Float32Array, Float32, $OP) + } + DataType::Float16 => { + typed_min_max_batch!($VALUES, Float16Array, Float16, $OP) + } + DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP), + DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP), + DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP), + DataType::Int8 => typed_min_max_batch!($VALUES, Int8Array, Int8, $OP), + DataType::UInt64 => typed_min_max_batch!($VALUES, UInt64Array, UInt64, $OP), + DataType::UInt32 => typed_min_max_batch!($VALUES, UInt32Array, UInt32, $OP), + DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP), + DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP), + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + typed_min_max_batch!( + $VALUES, + TimestampSecondArray, + TimestampSecond, + $OP, + tz_opt + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampMillisecondArray, + TimestampMillisecond, + $OP, + tz_opt + ), + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampMicrosecondArray, + TimestampMicrosecond, + $OP, + tz_opt + ), + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampNanosecondArray, + TimestampNanosecond, + $OP, + tz_opt + ), + DataType::Date32 => typed_min_max_batch!($VALUES, Date32Array, Date32, $OP), + DataType::Date64 => typed_min_max_batch!($VALUES, Date64Array, Date64, $OP), + DataType::Time32(TimeUnit::Second) => { + typed_min_max_batch!($VALUES, Time32SecondArray, Time32Second, $OP) + } + DataType::Time32(TimeUnit::Millisecond) => { + typed_min_max_batch!( + $VALUES, + Time32MillisecondArray, + Time32Millisecond, + $OP + ) + } + DataType::Time64(TimeUnit::Microsecond) => { + typed_min_max_batch!( + $VALUES, + Time64MicrosecondArray, + Time64Microsecond, + $OP + ) + } + DataType::Time64(TimeUnit::Nanosecond) => { + typed_min_max_batch!( + $VALUES, + Time64NanosecondArray, + Time64Nanosecond, + $OP + ) + } + DataType::Interval(IntervalUnit::YearMonth) => { + typed_min_max_batch!( + $VALUES, + IntervalYearMonthArray, + IntervalYearMonth, + $OP + ) + } + DataType::Interval(IntervalUnit::DayTime) => { + typed_min_max_batch!($VALUES, IntervalDayTimeArray, IntervalDayTime, $OP) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + typed_min_max_batch!( + $VALUES, + IntervalMonthDayNanoArray, + IntervalMonthDayNano, + $OP + ) + } + DataType::Duration(TimeUnit::Second) => { + typed_min_max_batch!($VALUES, DurationSecondArray, DurationSecond, $OP) + } + DataType::Duration(TimeUnit::Millisecond) => { + typed_min_max_batch!( + $VALUES, + DurationMillisecondArray, + DurationMillisecond, + $OP + ) + } + DataType::Duration(TimeUnit::Microsecond) => { + typed_min_max_batch!( + $VALUES, + DurationMicrosecondArray, + DurationMicrosecond, + $OP + ) + } + DataType::Duration(TimeUnit::Nanosecond) => { + typed_min_max_batch!( + $VALUES, + DurationNanosecondArray, + DurationNanosecond, + $OP + ) + } + other => { + // This should have been handled before + return datafusion_common::internal_err!( + "Min/Max accumulator not implemented for type {}", + other + ); + } + } + }}; +} + +/// dynamically-typed min(array) -> ScalarValue +pub fn min_batch(values: &ArrayRef) -> Result { + Ok(match values.data_type() { + DataType::Utf8 => { + typed_min_max_batch_string!(values, StringArray, Utf8, min_string) + } + DataType::LargeUtf8 => { + typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string) + } + DataType::Utf8View => { + typed_min_max_batch_string!( + values, + StringViewArray, + Utf8View, + min_string_view + ) + } + DataType::Boolean => { + typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean) + } + DataType::Binary => { + typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary) + } + DataType::LargeBinary => { + typed_min_max_batch_binary!( + &values, + LargeBinaryArray, + LargeBinary, + min_binary + ) + } + DataType::FixedSizeBinary(size) => { + let array = downcast_value!(&values, FixedSizeBinaryArray); + let value = compute::min_fixed_size_binary(array); + let value = value.map(|e| e.to_vec()); + ScalarValue::FixedSizeBinary(*size, value) + } + DataType::BinaryView => { + typed_min_max_batch_binary!( + &values, + BinaryViewArray, + BinaryView, + min_binary_view + ) + } + DataType::Struct(_) => min_max_batch_generic(values, Ordering::Greater)?, + DataType::List(_) => min_max_batch_generic(values, Ordering::Greater)?, + DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Greater)?, + DataType::FixedSizeList(_, _) => { + min_max_batch_generic(values, Ordering::Greater)? + } + DataType::Dictionary(_, _) => { + let values = values.as_any_dictionary().values(); + min_batch(values)? + } + _ => min_max_batch!(values, min), + }) +} + +/// Generic min/max implementation for complex types +fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result { + if array.len() == array.null_count() { + return ScalarValue::try_from(array.data_type()); + } + let mut extreme = ScalarValue::try_from_array(array, 0)?; + for i in 1..array.len() { + let current = ScalarValue::try_from_array(array, i)?; + if current.is_null() { + continue; + } + if extreme.is_null() { + extreme = current; + continue; + } + let cmp = extreme.try_cmp(¤t)?; + if cmp == ordering { + extreme = current; + } + } + + Ok(extreme) +} + +/// dynamically-typed max(array) -> ScalarValue +pub fn max_batch(values: &ArrayRef) -> Result { + Ok(match values.data_type() { + DataType::Utf8 => { + typed_min_max_batch_string!(values, StringArray, Utf8, max_string) + } + DataType::LargeUtf8 => { + typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string) + } + DataType::Utf8View => { + typed_min_max_batch_string!( + values, + StringViewArray, + Utf8View, + max_string_view + ) + } + DataType::Boolean => { + typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean) + } + DataType::Binary => { + typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary) + } + DataType::BinaryView => { + typed_min_max_batch_binary!( + &values, + BinaryViewArray, + BinaryView, + max_binary_view + ) + } + DataType::LargeBinary => { + typed_min_max_batch_binary!( + &values, + LargeBinaryArray, + LargeBinary, + max_binary + ) + } + DataType::FixedSizeBinary(size) => { + let array = downcast_value!(&values, FixedSizeBinaryArray); + let value = compute::max_fixed_size_binary(array); + let value = value.map(|e| e.to_vec()); + ScalarValue::FixedSizeBinary(*size, value) + } + DataType::Struct(_) => min_max_batch_generic(values, Ordering::Less)?, + DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?, + DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?, + DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?, + DataType::Dictionary(_, _) => { + let values = values.as_any_dictionary().values(); + max_batch(values)? + } + _ => min_max_batch!(values, max), + }) +} diff --git a/datafusion/functions-aggregate-common/src/order.rs b/datafusion/functions-aggregate-common/src/order.rs index bfa6e39138f9e..0908396d78341 100644 --- a/datafusion/functions-aggregate-common/src/order.rs +++ b/datafusion/functions-aggregate-common/src/order.rs @@ -22,9 +22,20 @@ pub enum AggregateOrderSensitivity { /// Ordering at the input is not important for the result of the aggregator. Insensitive, /// Indicates that the aggregate expression has a hard requirement on ordering. - /// The aggregator can not produce a correct result unless its ordering + /// The aggregator cannot produce a correct result unless its ordering /// requirement is satisfied. HardRequirement, + /// Indicates that the aggregator is more efficient when the input is ordered + /// but can still produce its result correctly regardless of the input ordering. + /// This is similar to, but stronger than, [`Self::Beneficial`]. + /// + /// Similarly to [`Self::HardRequirement`], when possible DataFusion will insert + /// a `SortExec`, to reorder the input to match the SoftRequirement. However, + /// when such a `SortExec` cannot be inserted, (for example, due to conflicting + /// [`Self::HardRequirement`] with other ordered aggregates in the query), + /// the aggregate function will still execute, without the preferred order, unlike + /// with [`Self::HardRequirement`] + SoftRequirement, /// Indicates that ordering is beneficial for the aggregate expression in terms /// of evaluation efficiency. The aggregator can produce its result efficiently /// when its required ordering is satisfied; however, it can still produce the @@ -38,7 +49,7 @@ impl AggregateOrderSensitivity { } pub fn is_beneficial(&self) -> bool { - self.eq(&AggregateOrderSensitivity::Beneficial) + matches!(self, Self::SoftRequirement | Self::Beneficial) } pub fn hard_requires(&self) -> bool { diff --git a/datafusion/functions-aggregate-common/src/tdigest.rs b/datafusion/functions-aggregate-common/src/tdigest.rs index 378fc8c42bc66..370a640b046a6 100644 --- a/datafusion/functions-aggregate-common/src/tdigest.rs +++ b/datafusion/functions-aggregate-common/src/tdigest.rs @@ -45,7 +45,7 @@ macro_rules! cast_scalar_f64 { ($value:expr ) => { match &$value { ScalarValue::Float64(Some(v)) => *v, - v => panic!("invalid type {:?}", v), + v => panic!("invalid type {}", v), } }; } @@ -56,7 +56,7 @@ macro_rules! cast_scalar_u64 { ($value:expr ) => { match &$value { ScalarValue::UInt64(Some(v)) => *v, - v => panic!("invalid type {:?}", v), + v => panic!("invalid type {}", v), } }; } @@ -103,20 +103,6 @@ pub struct Centroid { weight: f64, } -impl PartialOrd for Centroid { - fn partial_cmp(&self, other: &Centroid) -> Option { - Some(self.cmp(other)) - } -} - -impl Eq for Centroid {} - -impl Ord for Centroid { - fn cmp(&self, other: &Centroid) -> Ordering { - self.mean.total_cmp(&other.mean) - } -} - impl Centroid { pub fn new(mean: f64, weight: f64) -> Self { Centroid { mean, weight } @@ -139,6 +125,10 @@ impl Centroid { self.mean = new_sum / new_weight; new_sum } + + pub fn cmp_mean(&self, other: &Self) -> Ordering { + self.mean.total_cmp(&other.mean) + } } impl Default for Centroid { @@ -331,7 +321,7 @@ impl TDigest { result.sum += curr.add(sums_to_merge, weights_to_merge); compressed.push(curr); compressed.shrink_to_fit(); - compressed.sort(); + compressed.sort_by(|a, b| a.cmp_mean(b)); result.centroids = compressed; result @@ -349,7 +339,7 @@ impl TDigest { let mut j = middle; while i < middle && j < last { - match centroids[i].cmp(¢roids[j]) { + match centroids[i].cmp_mean(¢roids[j]) { Ordering::Less => { result.push(centroids[i].clone()); i += 1; @@ -466,7 +456,7 @@ impl TDigest { result.sum += curr.add(sums_to_merge, weights_to_merge); compressed.push(curr.clone()); compressed.shrink_to_fit(); - compressed.sort(); + compressed.sort_by(|a, b| a.cmp_mean(b)); result.count = count; result.min = min; diff --git a/datafusion/functions-aggregate-common/src/utils.rs b/datafusion/functions-aggregate-common/src/utils.rs index 083dac615b5d1..b01f2c8629c9b 100644 --- a/datafusion/functions-aggregate-common/src/utils.rs +++ b/datafusion/functions-aggregate-common/src/utils.rs @@ -15,22 +15,15 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - -use arrow::array::{ArrayRef, AsArray}; -use arrow::datatypes::ArrowNativeType; -use arrow::{ - array::ArrowNativeTypeOp, - compute::SortOptions, - datatypes::{ - DataType, Decimal128Type, DecimalType, Field, TimeUnit, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, - ToByteSlice, - }, +use arrow::array::{ArrayRef, ArrowNativeTypeOp}; +use arrow::compute::SortOptions; +use arrow::datatypes::{ + ArrowNativeType, DataType, DecimalType, Field, FieldRef, ToByteSlice, }; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{exec_err, internal_datafusion_err, Result}; use datafusion_expr_common::accumulator::Accumulator; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use std::sync::Arc; /// Convert scalar values from an accumulator into arrays. pub fn get_accum_scalar_values_as_arrays( @@ -43,57 +36,13 @@ pub fn get_accum_scalar_values_as_arrays( .collect() } -/// Adjust array type metadata if needed -/// -/// Since `Decimal128Arrays` created from `Vec` have -/// default precision and scale, this function adjusts the output to -/// match `data_type`, if necessary -#[deprecated(since = "44.0.0", note = "use PrimitiveArray::with_datatype")] -pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result { - let array = match data_type { - DataType::Decimal128(p, s) => Arc::new( - array - .as_primitive::() - .clone() - .with_precision_and_scale(*p, *s)?, - ) as ArrayRef, - DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new( - array - .as_primitive::() - .clone() - .with_timezone_opt(tz.clone()), - ), - DataType::Timestamp(TimeUnit::Microsecond, tz) => Arc::new( - array - .as_primitive::() - .clone() - .with_timezone_opt(tz.clone()), - ), - DataType::Timestamp(TimeUnit::Millisecond, tz) => Arc::new( - array - .as_primitive::() - .clone() - .with_timezone_opt(tz.clone()), - ), - DataType::Timestamp(TimeUnit::Second, tz) => Arc::new( - array - .as_primitive::() - .clone() - .with_timezone_opt(tz.clone()), - ), - // no adjustment needed for other arrays - _ => array, - }; - Ok(array) -} - -/// Construct corresponding fields for lexicographical ordering requirement expression +/// Construct corresponding fields for the expressions in an ORDER BY clause. pub fn ordering_fields( - ordering_req: &LexOrdering, + order_bys: &[PhysicalSortExpr], // Data type of each expression in the ordering requirement data_types: &[DataType], -) -> Vec { - ordering_req +) -> Vec { + order_bys .iter() .zip(data_types.iter()) .map(|(sort_expr, dtype)| { @@ -104,6 +53,7 @@ pub fn ordering_fields( true, ) }) + .map(Arc::new) .collect() } @@ -162,15 +112,17 @@ impl DecimalAverager { ) -> Result { let sum_mul = T::Native::from_usize(10_usize) .map(|b| b.pow_wrapping(sum_scale as u32)) - .ok_or(DataFusionError::Internal( - "Failed to compute sum_mul in DecimalAverager".to_string(), - ))?; + .ok_or_else(|| { + internal_datafusion_err!("Failed to compute sum_mul in DecimalAverager") + })?; let target_mul = T::Native::from_usize(10_usize) .map(|b| b.pow_wrapping(target_scale as u32)) - .ok_or(DataFusionError::Internal( - "Failed to compute target_mul in DecimalAverager".to_string(), - ))?; + .ok_or_else(|| { + internal_datafusion_err!( + "Failed to compute target_mul in DecimalAverager" + ) + })?; if target_mul >= sum_mul { Ok(Self { diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index ec6e6b633bb81..ffc6f3bb7a10a 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -68,3 +68,7 @@ harness = false [[bench]] name = "array_agg" harness = false + +[[bench]] +harness = false +name = "min_max_bytes" diff --git a/datafusion/functions-aggregate/README.md b/datafusion/functions-aggregate/README.md index 29b313d2a9037..aa50eaeedae03 100644 --- a/datafusion/functions-aggregate/README.md +++ b/datafusion/functions-aggregate/README.md @@ -17,11 +17,16 @@ under the License. --> -# DataFusion Aggregate Function Library +# Apache DataFusion Aggregate Function Library -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. -This crate contains packages of function that can be used to customize the -functionality of DataFusion. +This crate contains implementations of aggregate functions. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-aggregate/benches/array_agg.rs b/datafusion/functions-aggregate/benches/array_agg.rs index fb605e87ed0cc..96444b018465c 100644 --- a/datafusion/functions-aggregate/benches/array_agg.rs +++ b/datafusion/functions-aggregate/benches/array_agg.rs @@ -19,17 +19,23 @@ use std::sync::Arc; use arrow::array::{ Array, ArrayRef, ArrowPrimitiveType, AsArray, ListArray, NullBufferBuilder, + PrimitiveArray, }; use arrow::datatypes::{Field, Int64Type}; -use arrow::util::bench_util::create_primitive_array; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::Accumulator; use datafusion_functions_aggregate::array_agg::ArrayAggAccumulator; use arrow::buffer::OffsetBuffer; -use arrow::util::test_util::seedable_rng; -use rand::distributions::{Distribution, Standard}; +use rand::distr::{Distribution, StandardUniform}; +use rand::prelude::StdRng; use rand::Rng; +use rand::SeedableRng; + +/// Returns fixed seedable RNG +pub fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) { let list_item_data_type = values.as_list::().values().data_type().clone(); @@ -37,17 +43,35 @@ fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) { b.iter(|| { #[allow(clippy::unit_arg)] black_box( - ArrayAggAccumulator::try_new(&list_item_data_type) + ArrayAggAccumulator::try_new(&list_item_data_type, false) .unwrap() - .merge_batch(&[values.clone()]) + .merge_batch(std::slice::from_ref(&values)) .unwrap(), ) }) }); } +pub fn create_primitive_array(size: usize, null_density: f32) -> PrimitiveArray +where + T: ArrowPrimitiveType, + StandardUniform: Distribution, +{ + let mut rng = seedable_rng(); + + (0..size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random()) + } + }) + .collect() +} + /// Create List array with the given item data type, null density, null locations and zero length lists density -/// Creates an random (but fixed-seeded) array of a given size and null density +/// Creates a random (but fixed-seeded) array of a given size and null density pub fn create_list_array( size: usize, null_density: f32, @@ -55,20 +79,20 @@ pub fn create_list_array( ) -> ListArray where T: ArrowPrimitiveType, - Standard: Distribution, + StandardUniform: Distribution, { let mut nulls_builder = NullBufferBuilder::new(size); - let mut rng = seedable_rng(); + let mut rng = StdRng::seed_from_u64(42); let offsets = OffsetBuffer::from_lengths((0..size).map(|_| { - let is_null = rng.gen::() < null_density; + let is_null = rng.random::() < null_density; - let mut length = rng.gen_range(1..10); + let mut length = rng.random_range(1..10); if is_null { nulls_builder.append_null(); - if rng.gen::() <= zero_length_lists_probability { + if rng.random::() <= zero_length_lists_probability { length = 0; } } else { diff --git a/datafusion/functions-aggregate/benches/count.rs b/datafusion/functions-aggregate/benches/count.rs index 8bde7d04c44d9..37c7fad4bd32f 100644 --- a/datafusion/functions-aggregate/benches/count.rs +++ b/datafusion/functions-aggregate/benches/count.rs @@ -15,23 +15,29 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::{DataType, Field, Int32Type, Schema}; -use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; +use arrow::util::bench_util::{ + create_boolean_array, create_dict_from_values, create_primitive_array, + create_string_array_with_len, +}; + +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::{Accumulator, AggregateUDFImpl, GroupsAccumulator}; use datafusion_functions_aggregate::count::Count; use datafusion_physical_expr::expressions::col; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use std::sync::Arc; -fn prepare_accumulator() -> Box { +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +fn prepare_group_accumulator() -> Box { let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)])); let accumulator_args = AccumulatorArgs { - return_type: &DataType::Int64, + return_field: Field::new("f", DataType::Int64, true).into(), schema: &schema, ignore_nulls: false, - ordering_req: &LexOrdering::default(), + order_bys: &[], is_reversed: false, name: "COUNT(f)", is_distinct: false, @@ -44,18 +50,39 @@ fn prepare_accumulator() -> Box { .unwrap() } +fn prepare_accumulator() -> Box { + let schema = Arc::new(Schema::new(vec![Field::new( + "f", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + )])); + let accumulator_args = AccumulatorArgs { + return_field: Arc::new(Field::new_list_field(DataType::Int64, true)), + schema: &schema, + ignore_nulls: false, + order_bys: &[], + is_reversed: false, + name: "COUNT(f)", + is_distinct: true, + exprs: &[col("f", &schema).unwrap()], + }; + let count_fn = Count::new(); + + count_fn.accumulator(accumulator_args).unwrap() +} + fn convert_to_state_bench( c: &mut Criterion, name: &str, values: ArrayRef, opt_filter: Option<&BooleanArray>, ) { - let accumulator = prepare_accumulator(); + let accumulator = prepare_group_accumulator(); c.bench_function(name, |b| { b.iter(|| { black_box( accumulator - .convert_to_state(&[values.clone()], opt_filter) + .convert_to_state(std::slice::from_ref(&values), opt_filter) .unwrap(), ) }) @@ -89,6 +116,22 @@ fn count_benchmark(c: &mut Criterion) { values, Some(&filter), ); + + let arr = create_string_array_with_len::(20, 0.0, 50); + let values = + Arc::new(create_dict_from_values::(200_000, 0.8, &arr)) as ArrayRef; + + let mut accumulator = prepare_accumulator(); + c.bench_function("count low cardinality dict 20% nulls, no filter", |b| { + b.iter(|| { + #[allow(clippy::unit_arg)] + black_box( + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap(), + ) + }) + }); } criterion_group!(benches, count_benchmark); diff --git a/datafusion/functions-aggregate/benches/min_max_bytes.rs b/datafusion/functions-aggregate/benches/min_max_bytes.rs new file mode 100644 index 0000000000000..a438ee5697a2c --- /dev/null +++ b/datafusion/functions-aggregate/benches/min_max_bytes.rs @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// A minimal benchmark of the min_max accumulator for byte-like data types. +// +// The benchmark simulates the insertion of NUM_BATCHES batches into an aggregation, +// where every row belongs to a distinct group. The data generated beforehand to +// ensure that (mostly) the cost of the update_batch method is measured. +// +// The throughput value describes the rows per second that are ingested. + +use std::sync::Arc; + +use arrow::{ + array::{ArrayRef, StringArray}, + datatypes::{DataType, Field, Schema}, +}; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use datafusion_expr::{function::AccumulatorArgs, GroupsAccumulator}; +use datafusion_functions_aggregate::min_max; +use datafusion_physical_expr::expressions::col; + +const BATCH_SIZE: usize = 8192; + +fn create_max_bytes_accumulator() -> Box { + let input_schema = + Arc::new(Schema::new(vec![Field::new("value", DataType::Utf8, true)])); + + let max = min_max::max_udaf(); + max.create_groups_accumulator(AccumulatorArgs { + return_field: Arc::new(Field::new("value", DataType::Utf8, true)), + schema: &input_schema, + ignore_nulls: true, + order_bys: &[], + is_reversed: false, + name: "max_utf8", + is_distinct: true, + exprs: &[col("value", &input_schema).unwrap()], + }) + .unwrap() +} + +fn bench_min_max_bytes(c: &mut Criterion) { + let mut group = c.benchmark_group("min_max_bytes"); + + for num_batches in [10, 20, 50, 100, 150, 200, 300, 400, 500] { + let id = BenchmarkId::from_parameter(num_batches); + group.throughput(Throughput::Elements((num_batches * BATCH_SIZE) as u64)); + group.bench_with_input(id, &num_batches, |bencher, num_batches| { + bencher.iter_with_large_drop(|| { + let mut accumulator = create_max_bytes_accumulator(); + let mut group_indices = Vec::with_capacity(BATCH_SIZE); + let strings: ArrayRef = Arc::new(StringArray::from_iter_values( + (0..BATCH_SIZE).map(|i| i.to_string()), + )); + + for batch_idx in 0..*num_batches { + group_indices.clear(); + group_indices + .extend((batch_idx * BATCH_SIZE)..(batch_idx + 1) * BATCH_SIZE); + let total_num_groups = (batch_idx + 1) * BATCH_SIZE; + + accumulator + .update_batch( + &[Arc::clone(&strings)], + &group_indices, + None, + total_num_groups, + ) + .unwrap() + } + }); + }); + } +} + +criterion_group!(benches, bench_min_max_bytes); +criterion_main!(benches); diff --git a/datafusion/functions-aggregate/benches/sum.rs b/datafusion/functions-aggregate/benches/sum.rs index fab53ae94b25d..a1e9894fb86c0 100644 --- a/datafusion/functions-aggregate/benches/sum.rs +++ b/datafusion/functions-aggregate/benches/sum.rs @@ -15,23 +15,26 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::{DataType, Field, Int64Type, Schema}; use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; + use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; use datafusion_functions_aggregate::sum::Sum; use datafusion_physical_expr::expressions::col; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use std::sync::Arc; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; fn prepare_accumulator(data_type: &DataType) -> Box { - let schema = Arc::new(Schema::new(vec![Field::new("f", data_type.clone(), true)])); + let field = Field::new("f", data_type.clone(), true).into(); + let schema = Arc::new(Schema::new(vec![Arc::clone(&field)])); let accumulator_args = AccumulatorArgs { - return_type: data_type, + return_field: field, schema: &schema, ignore_nulls: false, - ordering_req: &LexOrdering::default(), + order_bys: &[], is_reversed: false, name: "SUM(f)", is_distinct: false, @@ -53,7 +56,7 @@ fn convert_to_state_bench( b.iter(|| { black_box( accumulator - .convert_to_state(&[values.clone()], opt_filter) + .convert_to_state(std::slice::from_ref(&values), opt_filter) .unwrap(), ) }) diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index c97dba1925ca9..9affdb3ee5f68 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -23,13 +23,17 @@ use arrow::array::{ GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{ - ArrowPrimitiveType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, + ArrowPrimitiveType, Date32Type, Date64Type, FieldRef, Int16Type, Int32Type, + Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, }; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; use datafusion_common::ScalarValue; use datafusion_common::{ - downcast_value, internal_err, not_impl_err, DataFusionError, Result, + downcast_value, internal_datafusion_err, internal_err, not_impl_err, DataFusionError, + Result, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; @@ -61,9 +65,7 @@ impl TryFrom<&[u8]> for HyperLogLog { type Error = DataFusionError; fn try_from(v: &[u8]) -> Result> { let arr: [u8; 16384] = v.try_into().map_err(|_| { - DataFusionError::Internal( - "Impossibly got invalid binary array from states".into(), - ) + internal_datafusion_err!("Impossibly got invalid binary array from states") })?; Ok(HyperLogLog::::new_with_registers(arr)) } @@ -169,6 +171,9 @@ where } } +#[derive(Debug)] +struct NullHLLAccumulator; + macro_rules! default_accumulator_impl { () => { fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { @@ -176,8 +181,8 @@ macro_rules! default_accumulator_impl { let binary_array = downcast_value!(states[0], BinaryArray); for v in binary_array.iter() { let v = v.ok_or_else(|| { - DataFusionError::Internal( - "Impossibly got empty binary array from states".into(), + internal_datafusion_err!( + "Impossibly got empty binary array from states" ) })?; let other = v.try_into()?; @@ -264,6 +269,29 @@ where default_accumulator_impl!(); } +impl Accumulator for NullHLLAccumulator { + fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> { + // do nothing, all values are null + Ok(()) + } + + fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> { + Ok(()) + } + + fn state(&mut self) -> Result> { + Ok(vec![]) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::UInt64(Some(0))) + } + + fn size(&self) -> usize { + size_of_val(self) + } +} + impl Debug for ApproxDistinct { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("ApproxDistinct") @@ -293,6 +321,7 @@ impl Default for ApproxDistinct { ```"#, standard_argument(name = "expression",) )] +#[derive(PartialEq, Eq, Hash)] pub struct ApproxDistinct { signature: Signature, } @@ -322,12 +351,13 @@ impl AggregateUDFImpl for ApproxDistinct { Ok(DataType::UInt64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, "hll_registers"), DataType::Binary, false, - )]) + ) + .into()]) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -345,11 +375,38 @@ impl AggregateUDFImpl for ApproxDistinct { DataType::Int16 => Box::new(NumericHLLAccumulator::::new()), DataType::Int32 => Box::new(NumericHLLAccumulator::::new()), DataType::Int64 => Box::new(NumericHLLAccumulator::::new()), + DataType::Date32 => Box::new(NumericHLLAccumulator::::new()), + DataType::Date64 => Box::new(NumericHLLAccumulator::::new()), + DataType::Time32(TimeUnit::Second) => { + Box::new(NumericHLLAccumulator::::new()) + } + DataType::Time32(TimeUnit::Millisecond) => { + Box::new(NumericHLLAccumulator::::new()) + } + DataType::Time64(TimeUnit::Microsecond) => { + Box::new(NumericHLLAccumulator::::new()) + } + DataType::Time64(TimeUnit::Nanosecond) => { + Box::new(NumericHLLAccumulator::::new()) + } + DataType::Timestamp(TimeUnit::Second, _) => { + Box::new(NumericHLLAccumulator::::new()) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + Box::new(NumericHLLAccumulator::::new()) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + Box::new(NumericHLLAccumulator::::new()) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + Box::new(NumericHLLAccumulator::::new()) + } DataType::Utf8 => Box::new(StringHLLAccumulator::::new()), DataType::LargeUtf8 => Box::new(StringHLLAccumulator::::new()), DataType::Utf8View => Box::new(StringViewHLLAccumulator::::new()), DataType::Binary => Box::new(BinaryHLLAccumulator::::new()), DataType::LargeBinary => Box::new(BinaryHLLAccumulator::::new()), + DataType::Null => Box::new(NullHLLAccumulator), other => { return not_impl_err!( "Support for 'approx_distinct' for data type {other} is not implemented" diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 787e08bae2867..976f4d2c94801 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -17,11 +17,11 @@ //! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution +use arrow::datatypes::DataType::{Float64, UInt64}; +use arrow::datatypes::{DataType, Field, FieldRef}; use std::any::Any; use std::fmt::Debug; - -use arrow::datatypes::DataType::{Float64, UInt64}; -use arrow::datatypes::{DataType, Field}; +use std::sync::Arc; use datafusion_common::{not_impl_err, plan_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -45,7 +45,7 @@ make_udaf_expr_and_func!( /// APPROX_MEDIAN aggregate expression #[user_doc( doc_section(label = "Approximate Functions"), - description = "Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`.", + description = "Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY x)`.", syntax_example = "approx_median(expression)", sql_example = r#"```sql > SELECT approx_median(column_name) FROM table_name; @@ -57,6 +57,7 @@ make_udaf_expr_and_func!( ```"#, standard_argument(name = "expression",) )] +#[derive(PartialEq, Eq, Hash)] pub struct ApproxMedian { signature: Signature, } @@ -91,7 +92,7 @@ impl AggregateUDFImpl for ApproxMedian { self } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new(format_state_name(args.name, "max_size"), UInt64, false), Field::new(format_state_name(args.name, "sum"), Float64, false), @@ -103,7 +104,10 @@ impl AggregateUDFImpl for ApproxMedian { Field::new_list_field(Float64, true), false, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn name(&self) -> &str { diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 1fad5f73703c7..0deb09184b3f4 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use arrow::array::{Array, RecordBatch}; use arrow::compute::{filter, is_not_null}; +use arrow::datatypes::FieldRef; use arrow::{ array::{ ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, @@ -29,11 +30,11 @@ use arrow::{ }, datatypes::{DataType, Field, Schema}, }; - use datafusion_common::{ downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, Result, ScalarValue, }; +use datafusion_expr::expr::{AggregateFunction, Sort}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; @@ -51,30 +52,63 @@ create_func!(ApproxPercentileCont, approx_percentile_cont_udaf); /// Computes the approximate percentile continuous of a set of numbers pub fn approx_percentile_cont( - expression: Expr, + order_by: Sort, percentile: Expr, centroids: Option, ) -> Expr { + let expr = order_by.expr.clone(); + let args = if let Some(centroids) = centroids { - vec![expression, percentile, centroids] + vec![expr, percentile, centroids] } else { - vec![expression, percentile] + vec![expr, percentile] }; - approx_percentile_cont_udaf().call(args) + + Expr::AggregateFunction(AggregateFunction::new_udf( + approx_percentile_cont_udaf(), + args, + false, + None, + vec![order_by], + None, + )) } #[user_doc( doc_section(label = "Approximate Functions"), description = "Returns the approximate percentile of input values using the t-digest algorithm.", - syntax_example = "approx_percentile_cont(expression, percentile, centroids)", + syntax_example = "approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression)", sql_example = r#"```sql +> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++------------------------------------------------------------------+ +| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) | ++------------------------------------------------------------------+ +| 65.0 | ++------------------------------------------------------------------+ +> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++-----------------------------------------------------------------------+ +| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) | ++-----------------------------------------------------------------------+ +| 65.0 | ++-----------------------------------------------------------------------+ +``` +An alternate syntax is also supported: +```sql +> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name; ++-----------------------------------------------+ +| approx_percentile_cont(column_name, 0.75) | ++-----------------------------------------------+ +| 65.0 | ++-----------------------------------------------+ + > SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; -+-------------------------------------------------+ -| approx_percentile_cont(column_name, 0.75, 100) | -+-------------------------------------------------+ -| 65.0 | -+-------------------------------------------------+ -```"#, ++----------------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++----------------------------------------------------------+ +| 65.0 | ++----------------------------------------------------------+ +``` +"#, standard_argument(name = "expression",), argument( name = "percentile", @@ -85,6 +119,7 @@ pub fn approx_percentile_cont( description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] +#[derive(PartialEq, Eq, Hash)] pub struct ApproxPercentileCont { signature: Signature, } @@ -130,6 +165,19 @@ impl ApproxPercentileCont { args: AccumulatorArgs, ) -> Result { let percentile = validate_input_percentile_expr(&args.exprs[1])?; + + let is_descending = args + .order_bys + .first() + .map(|sort_expr| sort_expr.options.descending) + .unwrap_or(false); + + let percentile = if is_descending { + 1.0 - percentile + } else { + percentile + }; + let tdigest_max_size = if args.exprs.len() == 3 { Some(validate_input_max_size_expr(&args.exprs[2])?) } else { @@ -232,7 +280,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { #[allow(rustdoc::private_intra_doc_links)] /// See [`TDigest::to_scalar_state()`] for a description of the serialized /// state. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "max_size"), @@ -264,7 +312,10 @@ impl AggregateUDFImpl for ApproxPercentileCont { Field::new_list_field(DataType::Float64, true), false, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn name(&self) -> &str { @@ -286,12 +337,20 @@ impl AggregateUDFImpl for ApproxPercentileCont { } if arg_types.len() == 3 && !arg_types[2].is_integer() { return plan_err!( - "approx_percentile_cont requires integer max_size input types" + "approx_percentile_cont requires integer centroids input types" ); } Ok(arg_types[0].clone()) } + fn supports_null_handling_clause(&self) -> bool { + false + } + + fn is_ordered_set_aggregate(&self) -> bool { + true + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -325,14 +384,23 @@ impl ApproxPercentileAccumulator { } } - // public for approx_percentile_cont_with_weight - pub fn merge_digests(&mut self, digests: &[TDigest]) { + // pub(crate) for approx_percentile_cont_with_weight + pub(crate) fn max_size(&self) -> usize { + self.digest.max_size() + } + + // pub(crate) for approx_percentile_cont_with_weight + pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) { let digests = digests.iter().chain(std::iter::once(&self.digest)); self.digest = TDigest::merge_digests(digests) } - // public for approx_percentile_cont_with_weight - pub fn convert_to_float(values: &ArrayRef) -> Result> { + // pub(crate) for approx_percentile_cont_with_weight + pub(crate) fn convert_to_float(values: &ArrayRef) -> Result> { + debug_assert!( + values.null_count() == 0, + "convert_to_float assumes nulls have already been filtered out" + ); match values.data_type() { DataType::Float64 => { let array = downcast_value!(values, Float64Array); @@ -429,7 +497,7 @@ impl Accumulator for ApproxPercentileAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { // Remove any nulls before computing the percentile let mut values = Arc::clone(&values[0]); - if values.nulls().is_some() { + if values.null_count() > 0 { values = filter(&values, &is_not_null(&values)?)?; } let sorted_values = &arrow::compute::sort(&values, None)?; @@ -457,7 +525,7 @@ impl Accumulator for ApproxPercentileAccumulator { DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)), DataType::Float32 => ScalarValue::Float32(Some(q as f32)), DataType::Float64 => ScalarValue::Float64(Some(q)), - v => unreachable!("unexpected return type {:?}", v), + v => unreachable!("unexpected return type {}", v), }) } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index 16dac2c1b8f04..89ff546039e56 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -17,49 +17,85 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; +use std::hash::Hash; use std::mem::size_of_val; use std::sync::Arc; -use arrow::{ - array::ArrayRef, - datatypes::{DataType, Field}, -}; - +use arrow::compute::{and, filter, is_not_null}; +use arrow::datatypes::FieldRef; +use arrow::{array::ArrayRef, datatypes::DataType}; use datafusion_common::ScalarValue; use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_expr::expr::{AggregateFunction, Sort}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, -}; -use datafusion_functions_aggregate_common::tdigest::{ - Centroid, TDigest, DEFAULT_MAX_SIZE, + Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, }; +use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest}; use datafusion_macros::user_doc; use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont}; -make_udaf_expr_and_func!( +create_func!( ApproxPercentileContWithWeight, - approx_percentile_cont_with_weight, - expression weight percentile, - "Computes the approximate percentile continuous with weight of a set of numbers", approx_percentile_cont_with_weight_udaf ); +/// Computes the approximate percentile continuous with weight of a set of numbers +pub fn approx_percentile_cont_with_weight( + order_by: Sort, + weight: Expr, + percentile: Expr, + centroids: Option, +) -> Expr { + let expr = order_by.expr.clone(); + + let args = if let Some(centroids) = centroids { + vec![expr, weight, percentile, centroids] + } else { + vec![expr, weight, percentile] + }; + + Expr::AggregateFunction(AggregateFunction::new_udf( + approx_percentile_cont_with_weight_udaf(), + args, + false, + None, + vec![order_by], + None, + )) +} + /// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression #[user_doc( doc_section(label = "Approximate Functions"), description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.", - syntax_example = "approx_percentile_cont_with_weight(expression, weight, percentile)", + syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)", sql_example = r#"```sql +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++---------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) | ++---------------------------------------------------------------------------------------------+ +| 78.5 | ++---------------------------------------------------------------------------------------------+ +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++--------------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) | ++--------------------------------------------------------------------------------------------------+ +| 78.5 | ++--------------------------------------------------------------------------------------------------+ +``` +An alternative syntax is also supported: + +```sql > SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; -+----------------------------------------------------------------------+ ++--------------------------------------------------+ | approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | -+----------------------------------------------------------------------+ -| 78.5 | -+----------------------------------------------------------------------+ ++--------------------------------------------------+ +| 78.5 | ++--------------------------------------------------+ ```"#, standard_argument(name = "expression", prefix = "The"), argument( @@ -69,8 +105,13 @@ make_udaf_expr_and_func!( argument( name = "percentile", description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)." + ), + argument( + name = "centroids", + description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] +#[derive(PartialEq, Eq, Hash)] pub struct ApproxPercentileContWithWeight { signature: Signature, approx_percentile_cont: ApproxPercentileCont, @@ -93,21 +134,26 @@ impl Default for ApproxPercentileContWithWeight { impl ApproxPercentileContWithWeight { /// Create a new [`ApproxPercentileContWithWeight`] aggregate function. pub fn new() -> Self { + let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); + // Accept any numeric value paired with weight and float64 percentile + for num in NUMERICS { + variants.push(TypeSignature::Exact(vec![ + num.clone(), + num.clone(), + DataType::Float64, + ])); + // Additionally accept an integer number of centroids for T-Digest + for int in INTEGERS { + variants.push(TypeSignature::Exact(vec![ + num.clone(), + num.clone(), + DataType::Float64, + int.clone(), + ])); + } + } Self { - signature: Signature::one_of( - // Accept any numeric value paired with a float64 percentile - NUMERICS - .iter() - .map(|t| { - TypeSignature::Exact(vec![ - t.clone(), - t.clone(), - DataType::Float64, - ]) - }) - .collect(), - Immutable, - ), + signature: Signature::one_of(variants, Immutable), approx_percentile_cont: ApproxPercentileCont::new(), } } @@ -140,6 +186,11 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { if arg_types[2] != DataType::Float64 { return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types"); } + if arg_types.len() == 4 && !arg_types[3].is_integer() { + return plan_err!( + "approx_percentile_cont_with_weight requires integer centroids input types" + ); + } Ok(arg_types[0].clone()) } @@ -150,17 +201,25 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { ); } - if acc_args.exprs.len() != 3 { + if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 { return plan_err!( - "approx_percentile_cont_with_weight requires three arguments: value, weight, percentile" + "approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]" ); } let sub_args = AccumulatorArgs { - exprs: &[ - Arc::clone(&acc_args.exprs[0]), - Arc::clone(&acc_args.exprs[2]), - ], + exprs: if acc_args.exprs.len() == 4 { + &[ + Arc::clone(&acc_args.exprs[0]), // value + Arc::clone(&acc_args.exprs[2]), // percentile + Arc::clone(&acc_args.exprs[3]), // centroids + ] + } else { + &[ + Arc::clone(&acc_args.exprs[0]), // value + Arc::clone(&acc_args.exprs[2]), // percentile + ] + }, ..acc_args }; let approx_percentile_cont_accumulator = @@ -174,10 +233,18 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { #[allow(rustdoc::private_intra_doc_links)] /// See [`TDigest::to_scalar_state()`] for a description of the serialized /// state. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.approx_percentile_cont.state_fields(args) } + fn supports_null_handling_clause(&self) -> bool { + false + } + + fn is_ordered_set_aggregate(&self) -> bool { + true + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -202,19 +269,41 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let means = &values[0]; - let weights = &values[1]; + let mut means = Arc::clone(&values[0]); + let mut weights = Arc::clone(&values[1]); + // If nulls are present in either array, need to filter those rows out in both arrays + match (means.null_count() > 0, weights.null_count() > 0) { + // Both have nulls + (true, true) => { + let predicate = and(&is_not_null(&means)?, &is_not_null(&weights)?)?; + means = filter(&means, &predicate)?; + weights = filter(&weights, &predicate)?; + } + // Only one has nulls + (false, true) => { + let predicate = &is_not_null(&weights)?; + means = filter(&means, predicate)?; + weights = filter(&weights, predicate)?; + } + (true, false) => { + let predicate = &is_not_null(&means)?; + means = filter(&means, predicate)?; + weights = filter(&weights, predicate)?; + } + // No nulls + (false, false) => {} + } debug_assert_eq!( means.len(), weights.len(), "invalid number of values in means and weights" ); - let means_f64 = ApproxPercentileAccumulator::convert_to_float(means)?; - let weights_f64 = ApproxPercentileAccumulator::convert_to_float(weights)?; + let means_f64 = ApproxPercentileAccumulator::convert_to_float(&means)?; + let weights_f64 = ApproxPercentileAccumulator::convert_to_float(&weights)?; let mut digests: Vec = vec![]; for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) { digests.push(TDigest::new_with_centroid( - DEFAULT_MAX_SIZE, + self.approx_percentile_cont_accumulator.max_size(), Centroid::new(*mean, *weight), )) } diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 573624ce4d491..4d8676f24a289 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -17,26 +17,32 @@ //! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`] -use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, ListArray, StructArray}; -use arrow::compute::SortOptions; -use arrow::datatypes::{DataType, Field, Fields}; +use std::cmp::Ordering; +use std::collections::{HashSet, VecDeque}; +use std::mem::{size_of, size_of_val, take}; +use std::sync::Arc; + +use arrow::array::{ + new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray, +}; +use arrow::compute::{filter, SortOptions}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::cast::as_list_array; -use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; -use datafusion_common::{exec_err, ScalarValue}; -use datafusion_common::{internal_err, Result}; +use datafusion_common::utils::{ + compare_rows, get_row_at_idx, take_function_args, SingleRowListArrayBuilder, +}; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{Accumulator, Signature, Volatility}; -use datafusion_expr::{AggregateUDFImpl, Documentation}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; +use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; use datafusion_functions_aggregate_common::utils::ordering_fields; use datafusion_macros::user_doc; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use std::cmp::Ordering; -use std::collections::{HashSet, VecDeque}; -use std::mem::{size_of, size_of_val}; -use std::sync::Arc; make_udaf_expr_and_func!( ArrayAgg, @@ -69,16 +75,18 @@ This aggregation function can only mix DISTINCT and ORDER BY if the ordering exp "#, standard_argument(name = "expression",) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] /// ARRAY_AGG aggregate expression pub struct ArrayAgg { signature: Signature, + is_input_pre_ordered: bool, } impl Default for ArrayAgg { fn default() -> Self { Self { signature: Signature::any(1, Volatility::Immutable), + is_input_pre_ordered: false, } } } @@ -92,10 +100,6 @@ impl AggregateUDFImpl for ArrayAgg { "array_agg" } - fn aliases(&self) -> &[String] { - &[] - } - fn signature(&self) -> &Signature { &self.signature } @@ -107,39 +111,60 @@ impl AggregateUDFImpl for ArrayAgg { )))) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { return Ok(vec![Field::new_list( format_state_name(args.name, "distinct_array_agg"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), true, - )]); + ) + .into()]); } let mut fields = vec![Field::new_list( format_state_name(args.name, "array_agg"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), true, - )]; + ) + .into()]; if args.ordering_fields.is_empty() { return Ok(fields); } let orderings = args.ordering_fields.to_vec(); - fields.push(Field::new_list( - format_state_name(args.name, "array_agg_orderings"), - Field::new_list_field(DataType::Struct(Fields::from(orderings)), true), - false, - )); + fields.push( + Field::new_list( + format_state_name(args.name, "array_agg_orderings"), + Field::new_list_field(DataType::Struct(Fields::from(orderings)), true), + false, + ) + .into(), + ); Ok(fields) } + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::SoftRequirement + } + + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + Ok(Some(Arc::new(Self { + signature: self.signature.clone(), + is_input_pre_ordered: beneficial_ordering, + }))) + } + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; + let ignore_nulls = + acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?; if acc_args.is_distinct { // Limitation similar to Postgres. The aggregation function can only mix @@ -156,28 +181,30 @@ impl AggregateUDFImpl for ArrayAgg { // ARRAY_AGG(DISTINCT concat(col, '') ORDER BY concat(col, '')) <- Valid // ARRAY_AGG(DISTINCT col ORDER BY other_col) <- Invalid // ARRAY_AGG(DISTINCT col ORDER BY concat(col, '')) <- Invalid - if acc_args.ordering_req.len() > 1 { - return exec_err!("In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list"); - } - let mut sort_option: Option = None; - if let Some(order) = acc_args.ordering_req.first() { - if !order.expr.eq(&acc_args.exprs[0]) { - return exec_err!("In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list"); + let sort_option = match acc_args.order_bys { + [single] if single.expr.eq(&acc_args.exprs[0]) => Some(single.options), + [] => None, + _ => { + return exec_err!( + "In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list" + ); } - sort_option = Some(order.options) - } + }; return Ok(Box::new(DistinctArrayAggAccumulator::try_new( &data_type, sort_option, + ignore_nulls, )?)); } - if acc_args.ordering_req.is_empty() { - return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?)); - } + let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else { + return Ok(Box::new(ArrayAggAccumulator::try_new( + &data_type, + ignore_nulls, + )?)); + }; - let ordering_dtypes = acc_args - .ordering_req + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; @@ -185,8 +212,10 @@ impl AggregateUDFImpl for ArrayAgg { OrderSensitiveArrayAggAccumulator::try_new( &data_type, &ordering_dtypes, - acc_args.ordering_req.clone(), + ordering, + self.is_input_pre_ordered, acc_args.is_reversed, + ignore_nulls, ) .map(|acc| Box::new(acc) as _) } @@ -204,18 +233,20 @@ impl AggregateUDFImpl for ArrayAgg { pub struct ArrayAggAccumulator { values: Vec, datatype: DataType, + ignore_nulls: bool, } impl ArrayAggAccumulator { /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType) -> Result { + pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result { Ok(Self { values: vec![], datatype: datatype.clone(), + ignore_nulls, }) } - /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non empty list) + /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non-empty list) /// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option { let offsets = list_array.value_offsets(); @@ -239,7 +270,7 @@ impl ArrayAggAccumulator { return Some(list_array.values().slice(0, 0)); } - // According to the Arrow spec, null values can point to non empty lists + // According to the Arrow spec, null values can point to non-empty lists // So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value // Unwrapping is safe as we just checked if there is a null value @@ -247,7 +278,7 @@ impl ArrayAggAccumulator { let mut valid_slices_iter = nulls.valid_slices(); - // This is safe as we validated that that are at least 1 valid value in the array + // This is safe as we validated that there is at least 1 valid value in the array let (start, end) = valid_slices_iter.next().unwrap(); let start_offset = offsets[start]; @@ -257,7 +288,7 @@ impl ArrayAggAccumulator { let mut end_offset_of_last_valid_value = offsets[end]; for (start, end) in valid_slices_iter { - // If there is a null value that point to a non empty list than the start offset of the valid value + // If there is a null value that point to a non-empty list than the start offset of the valid value // will be different that the end offset of the last valid value if offsets[start] != end_offset_of_last_valid_value { return None; @@ -288,10 +319,23 @@ impl Accumulator for ArrayAggAccumulator { return internal_err!("expects single batch"); } - let val = Arc::clone(&values[0]); - if val.len() > 0 { - self.values.push(val); + let val = &values[0]; + let nulls = if self.ignore_nulls { + val.logical_nulls() + } else { + None + }; + + let val = match nulls { + Some(nulls) if nulls.null_count() >= val.len() => return Ok(()), + Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?, + None => Arc::clone(val), + }; + + if !val.is_empty() { + self.values.push(val) } + Ok(()) } @@ -310,7 +354,7 @@ impl Accumulator for ArrayAggAccumulator { match Self::get_optional_values_to_merge_as_is(list_arr) { Some(values) => { // Make sure we don't insert empty lists - if values.len() > 0 { + if !values.is_empty() { self.values.push(values); } } @@ -348,7 +392,18 @@ impl Accumulator for ArrayAggAccumulator { + self .values .iter() - .map(|arr| arr.get_array_memory_size()) + // Each ArrayRef might be just a reference to a bigger array, and many + // ArrayRefs here might be referencing exactly the same array, so if we + // were to call `arr.get_array_memory_size()`, we would be double-counting + // the same underlying data many times. + // + // Instead, we do an approximation by estimating how much memory each + // ArrayRef would occupy if its underlying data was fully owned by this + // accumulator. + // + // Note that this is just an estimation, but the reality is that this + // accumulator might not own any data. + .map(|arr| arr.to_data().get_slice_memory_size().unwrap_or_default()) .sum::() + self.datatype.size() - size_of_val(&self.datatype) @@ -360,17 +415,20 @@ struct DistinctArrayAggAccumulator { values: HashSet, datatype: DataType, sort_options: Option, + ignore_nulls: bool, } impl DistinctArrayAggAccumulator { pub fn try_new( datatype: &DataType, sort_options: Option, + ignore_nulls: bool, ) -> Result { Ok(Self { values: HashSet::new(), datatype: datatype.clone(), sort_options, + ignore_nulls, }) } } @@ -385,11 +443,21 @@ impl Accumulator for DistinctArrayAggAccumulator { return Ok(()); } - let array = &values[0]; + let val = &values[0]; + let nulls = if self.ignore_nulls { + val.logical_nulls() + } else { + None + }; - for i in 0..array.len() { - let scalar = ScalarValue::try_from_array(&array, i)?; - self.values.insert(scalar); + let nulls = nulls.as_ref(); + if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) { + for i in 0..val.len() { + if nulls.is_none_or(|nulls| nulls.is_valid(i)) { + self.values + .insert(ScalarValue::try_from_array(val, i)?.compacted()); + } + } } Ok(()) @@ -418,6 +486,7 @@ impl Accumulator for DistinctArrayAggAccumulator { } if let Some(opts) = self.sort_options { + let mut delayed_cmp_err = Ok(()); values.sort_by(|a, b| { if a.is_null() { return match opts.nulls_first { @@ -432,10 +501,15 @@ impl Accumulator for DistinctArrayAggAccumulator { }; } match opts.descending { - true => b.partial_cmp(a).unwrap_or(Ordering::Equal), - false => a.partial_cmp(b).unwrap_or(Ordering::Equal), + true => b.try_cmp(a), + false => a.try_cmp(b), } + .unwrap_or_else(|err| { + delayed_cmp_err = Err(err); + Ordering::Equal + }) }); + delayed_cmp_err?; }; let arr = ScalarValue::new_list(&values, &self.datatype, true); @@ -469,8 +543,12 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator { datatypes: Vec, /// Stores the ordering requirement of the `Accumulator`. ordering_req: LexOrdering, + /// Whether the input is known to be pre-ordered + is_input_pre_ordered: bool, /// Whether the aggregation is running in reverse. reverse: bool, + /// Whether the aggregation should ignore null values. + ignore_nulls: bool, } impl OrderSensitiveArrayAggAccumulator { @@ -480,7 +558,9 @@ impl OrderSensitiveArrayAggAccumulator { datatype: &DataType, ordering_dtypes: &[DataType], ordering_req: LexOrdering, + is_input_pre_ordered: bool, reverse: bool, + ignore_nulls: bool, ) -> Result { let mut datatypes = vec![datatype.clone()]; datatypes.extend(ordering_dtypes.iter().cloned()); @@ -489,9 +569,58 @@ impl OrderSensitiveArrayAggAccumulator { ordering_values: vec![], datatypes, ordering_req, + is_input_pre_ordered, reverse, + ignore_nulls, }) } + + fn sort(&mut self) { + let sort_options = self + .ordering_req + .iter() + .map(|sort_expr| sort_expr.options) + .collect::>(); + let mut values = take(&mut self.values) + .into_iter() + .zip(take(&mut self.ordering_values)) + .collect::>(); + let mut delayed_cmp_err = Ok(()); + values.sort_by(|(_, left_ordering), (_, right_ordering)| { + compare_rows(left_ordering, right_ordering, &sort_options).unwrap_or_else( + |err| { + delayed_cmp_err = Err(err); + Ordering::Equal + }, + ) + }); + (self.values, self.ordering_values) = values.into_iter().unzip(); + } + + fn evaluate_orderings(&self) -> Result { + let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); + + let column_wise_ordering_values = if self.ordering_values.is_empty() { + fields + .iter() + .map(|f| new_empty_array(f.data_type())) + .collect::>() + } else { + (0..fields.len()) + .map(|i| { + let column_values = self.ordering_values.iter().map(|x| x[i].clone()); + ScalarValue::iter_to_array(column_values) + }) + .collect::>()? + }; + + let ordering_array = StructArray::try_new( + Fields::from(fields), + column_wise_ordering_values, + None, + )?; + Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) + } } impl Accumulator for OrderSensitiveArrayAggAccumulator { @@ -500,11 +629,28 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { return Ok(()); } - let n_row = values[0].len(); - for index in 0..n_row { - let row = get_row_at_idx(values, index)?; - self.values.push(row[0].clone()); - self.ordering_values.push(row[1..].to_vec()); + let val = &values[0]; + let ord = &values[1..]; + let nulls = if self.ignore_nulls { + val.logical_nulls() + } else { + None + }; + + let nulls = nulls.as_ref(); + if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) { + for i in 0..val.len() { + if nulls.is_none_or(|nulls| nulls.is_valid(i)) { + self.values + .push(ScalarValue::try_from_array(val, i)?.compacted()); + self.ordering_values.push( + get_row_at_idx(ord, i)? + .into_iter() + .map(|v| v.compacted()) + .collect(), + ) + } + } } Ok(()) @@ -521,9 +667,8 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { // inside `ARRAY_AGG` list, we will receive an `Array` that stores values // received from its ordering requirement expression. (This information // is necessary for during merging). - let [array_agg_values, agg_orderings, ..] = &states else { - return exec_err!("State should have two elements"); - }; + let [array_agg_values, agg_orderings] = + take_function_args("OrderSensitiveArrayAggAccumulator::merge_batch", states)?; let Some(agg_orderings) = agg_orderings.as_list_opt::() else { return exec_err!("Expects to receive a list array"); }; @@ -534,18 +679,24 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { let mut partition_ordering_values = vec![]; // Existing values should be merged also. - partition_values.push(self.values.clone().into()); - partition_ordering_values.push(self.ordering_values.clone().into()); + if !self.is_input_pre_ordered { + self.sort(); + } + partition_values.push(take(&mut self.values).into()); + partition_ordering_values.push(take(&mut self.ordering_values).into()); // Convert array to Scalars to sort them easily. Convert back to array at evaluation. let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; - for v in array_agg_res.into_iter() { - partition_values.push(v.into()); + for maybe_v in array_agg_res.into_iter() { + if let Some(v) = maybe_v { + partition_values.push(v.into()); + } else { + partition_values.push(vec![].into()); + } } let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; - - for partition_ordering_rows in orderings.into_iter() { + for partition_ordering_rows in orderings.into_iter().flatten() { // Extract value from struct to ordering_rows for each group/partition let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| { if let ScalarValue::Struct(s) = ordering_row { @@ -584,6 +735,10 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } fn state(&mut self) -> Result> { + if !self.is_input_pre_ordered { + self.sort(); + } + let mut result = vec![self.evaluate()?]; result.push(self.evaluate_orderings()?); @@ -591,6 +746,10 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } fn evaluate(&mut self) -> Result { + if !self.is_input_pre_ordered { + self.sort(); + } + if self.values.is_empty() { return Ok(ScalarValue::new_null_list( self.datatypes[0].clone(), @@ -635,41 +794,15 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } } -impl OrderSensitiveArrayAggAccumulator { - fn evaluate_orderings(&self) -> Result { - let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]); - let num_columns = fields.len(); - let struct_field = Fields::from(fields.clone()); - - let mut column_wise_ordering_values = vec![]; - for i in 0..num_columns { - let column_values = self - .ordering_values - .iter() - .map(|x| x[i].clone()) - .collect::>(); - let array = if column_values.is_empty() { - new_empty_array(fields[i].data_type()) - } else { - ScalarValue::iter_to_array(column_values.into_iter())? - }; - column_wise_ordering_values.push(array); - } - - let ordering_array = - StructArray::try_new(struct_field, column_wise_ordering_values, None)?; - Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) - } -} - #[cfg(test)] mod tests { use super::*; + use arrow::array::{ListBuilder, StringBuilder}; use arrow::datatypes::{FieldRef, Schema}; use datafusion_common::cast::as_generic_string_array; use datafusion_common::internal_err; use datafusion_physical_expr::expressions::Column; - use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use std::sync::Arc; #[test] @@ -931,10 +1064,59 @@ mod tests { Ok(()) } + #[test] + fn does_not_over_account_memory() -> Result<()> { + let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?; + + acc1.update_batch(&[data(["a", "c", "b"])])?; + acc2.update_batch(&[data(["b", "c", "a"])])?; + acc1 = merge(acc1, acc2)?; + + assert_eq!(acc1.size(), 266); + + Ok(()) + } + #[test] + fn does_not_over_account_memory_distinct() -> Result<()> { + let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string() + .distinct() + .build_two()?; + + acc1.update_batch(&[string_list_data([ + vec!["a", "b", "c"], + vec!["d", "e", "f"], + ])])?; + acc2.update_batch(&[string_list_data([vec!["e", "f", "g"]])])?; + acc1 = merge(acc1, acc2)?; + + // without compaction, the size is 16660 + assert_eq!(acc1.size(), 1660); + + Ok(()) + } + + #[test] + fn does_not_over_account_memory_ordered() -> Result<()> { + let mut acc = ArrayAggAccumulatorBuilder::string() + .order_by_col("col", SortOptions::new(false, false)) + .build()?; + + acc.update_batch(&[string_list_data([ + vec!["a", "b", "c"], + vec!["c", "d", "e"], + vec!["b", "c", "d"], + ])])?; + + // without compaction, the size is 17112 + assert_eq!(acc.size(), 2184); + + Ok(()) + } + struct ArrayAggAccumulatorBuilder { - data_type: DataType, + return_field: FieldRef, distinct: bool, - ordering: LexOrdering, + order_bys: Vec, schema: Schema, } @@ -945,15 +1127,13 @@ mod tests { fn new(data_type: DataType) -> Self { Self { - data_type: data_type.clone(), - distinct: Default::default(), - ordering: Default::default(), + return_field: Field::new("f", data_type.clone(), true).into(), + distinct: false, + order_bys: vec![], schema: Schema { fields: Fields::from(vec![Field::new( "col", - DataType::List(FieldRef::new(Field::new( - "item", data_type, true, - ))), + DataType::new_list(data_type, true), true, )]), metadata: Default::default(), @@ -967,22 +1147,23 @@ mod tests { } fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self { - self.ordering.extend([PhysicalSortExpr::new( + let new_order = PhysicalSortExpr::new( Arc::new( Column::new_with_schema(col, &self.schema) .expect("column not available in schema"), ), sort_options, - )]); + ); + self.order_bys.push(new_order); self } fn build(&self) -> Result> { ArrayAgg::default().accumulator(AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: false, - ordering_req: &self.ordering, + order_bys: &self.order_bys, is_reversed: false, name: "", is_distinct: self.distinct, @@ -1007,10 +1188,19 @@ mod tests { fn print_nulls(sort: Vec>) -> Vec { sort.into_iter() - .map(|v| v.unwrap_or("NULL".to_string())) + .map(|v| v.unwrap_or_else(|| "NULL".to_string())) .collect() } + fn string_list_data<'a>(data: impl IntoIterator>) -> ArrayRef { + let mut builder = ListBuilder::new(StringBuilder::new()); + for string_list in data.into_iter() { + builder.append_value(string_list.iter().map(Some).collect::>()); + } + + Arc::new(builder.finish()) + } + fn data(list: [T; N]) -> ArrayRef where ScalarValue: From, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 141771b0412f2..d007163e7c08f 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -24,8 +24,11 @@ use arrow::array::{ use arrow::compute::sum; use arrow::datatypes::{ - i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field, - Float64Type, UInt64Type, + i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type, + Decimal64Type, DecimalType, DurationMicrosecondType, DurationMillisecondType, + DurationNanosecondType, DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, + UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, }; use datafusion_common::{ exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, @@ -35,10 +38,13 @@ use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_typ use datafusion_expr::utils::format_state_name; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, + Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator, ReversedUDAF, Signature, }; +use datafusion_functions_aggregate_common::aggregate::avg_distinct::{ + DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator, +}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ filtered_null_mask, set_nulls, @@ -60,6 +66,17 @@ make_udaf_expr_and_func!( avg_udaf ); +pub fn avg_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + avg_udaf(), + vec![expr], + true, + None, + vec![], + None, + )) +} + #[user_doc( doc_section(label = "General Functions"), description = "Returns the average of numeric values in the specified column.", @@ -74,7 +91,7 @@ make_udaf_expr_and_func!( ```"#, standard_argument(name = "expression",) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Avg { signature: Signature, aliases: Vec, @@ -113,66 +130,176 @@ impl AggregateUDFImpl for Avg { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - if acc_args.is_distinct { - return exec_err!("avg(DISTINCT) aggregations are not available"); - } + let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; use DataType::*; - let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; // instantiate specialized accumulator based for the type - match (&data_type, acc_args.return_type) { - (Float64, Float64) => Ok(Box::::default()), - ( - Decimal128(sum_precision, sum_scale), - Decimal128(target_precision, target_scale), - ) => Ok(Box::new(DecimalAvgAccumulator:: { - sum: None, - count: 0, - sum_scale: *sum_scale, - sum_precision: *sum_precision, - target_precision: *target_precision, - target_scale: *target_scale, - })), + if acc_args.is_distinct { + match (&data_type, acc_args.return_type()) { + // Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation + (Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())), + + ( + Decimal32(_, scale), + Decimal32(target_precision, target_scale), + ) => Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( + *scale, + *target_precision, + *target_scale, + ))), + ( + Decimal64(_, scale), + Decimal64(target_precision, target_scale), + ) => Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( + *scale, + *target_precision, + *target_scale, + ))), + ( + Decimal128(_, scale), + Decimal128(target_precision, target_scale), + ) => Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( + *scale, + *target_precision, + *target_scale, + ))), - ( - Decimal256(sum_precision, sum_scale), - Decimal256(target_precision, target_scale), - ) => Ok(Box::new(DecimalAvgAccumulator:: { - sum: None, - count: 0, - sum_scale: *sum_scale, - sum_precision: *sum_precision, - target_precision: *target_precision, - target_scale: *target_scale, - })), - _ => exec_err!( - "AvgAccumulator for ({} --> {})", - &data_type, - acc_args.return_type - ), + ( + Decimal256(_, scale), + Decimal256(target_precision, target_scale), + ) => Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( + *scale, + *target_precision, + *target_scale, + ))), + + (dt, return_type) => exec_err!( + "AVG(DISTINCT) for ({} --> {}) not supported", + dt, + return_type + ), + } + } else { + match (&data_type, acc_args.return_type()) { + (Float64, Float64) => Ok(Box::::default()), + ( + Decimal32(sum_precision, sum_scale), + Decimal32(target_precision, target_scale), + ) => Ok(Box::new(DecimalAvgAccumulator:: { + sum: None, + count: 0, + sum_scale: *sum_scale, + sum_precision: *sum_precision, + target_precision: *target_precision, + target_scale: *target_scale, + })), + ( + Decimal64(sum_precision, sum_scale), + Decimal64(target_precision, target_scale), + ) => Ok(Box::new(DecimalAvgAccumulator:: { + sum: None, + count: 0, + sum_scale: *sum_scale, + sum_precision: *sum_precision, + target_precision: *target_precision, + target_scale: *target_scale, + })), + ( + Decimal128(sum_precision, sum_scale), + Decimal128(target_precision, target_scale), + ) => Ok(Box::new(DecimalAvgAccumulator:: { + sum: None, + count: 0, + sum_scale: *sum_scale, + sum_precision: *sum_precision, + target_precision: *target_precision, + target_scale: *target_scale, + })), + + ( + Decimal256(sum_precision, sum_scale), + Decimal256(target_precision, target_scale), + ) => Ok(Box::new(DecimalAvgAccumulator:: { + sum: None, + count: 0, + sum_scale: *sum_scale, + sum_precision: *sum_precision, + target_precision: *target_precision, + target_scale: *target_scale, + })), + + (Duration(time_unit), Duration(result_unit)) => { + Ok(Box::new(DurationAvgAccumulator { + sum: None, + count: 0, + time_unit: *time_unit, + result_unit: *result_unit, + })) + } + + (dt, return_type) => { + exec_err!("AvgAccumulator for ({} --> {})", dt, return_type) + } + } } } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { - Ok(vec![ - Field::new( - format_state_name(args.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(args.name, "sum"), - args.input_types[0].clone(), - true, - ), - ]) + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + if args.is_distinct { + // Decimal accumulator actually uses a different precision during accumulation, + // see DecimalDistinctAvgAccumulator::with_decimal_params + let dt = match args.input_fields[0].data_type() { + DataType::Decimal32(_, scale) => { + DataType::Decimal32(DECIMAL32_MAX_PRECISION, *scale) + } + DataType::Decimal64(_, scale) => { + DataType::Decimal64(DECIMAL64_MAX_PRECISION, *scale) + } + DataType::Decimal128(_, scale) => { + DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale) + } + DataType::Decimal256(_, scale) => { + DataType::Decimal256(DECIMAL256_MAX_PRECISION, *scale) + } + _ => args.return_type().clone(), + }; + // Similar to datafusion_functions_aggregate::sum::Sum::state_fields + // since the accumulator uses DistinctSumAccumulator internally. + Ok(vec![Field::new_list( + format_state_name(args.name, "avg distinct"), + Field::new_list_field(dt, true), + false, + ) + .into()]) + } else { + Ok(vec![ + Field::new( + format_state_name(args.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + format_state_name(args.name, "sum"), + args.input_fields[0].data_type().clone(), + true, + ), + ] + .into_iter() + .map(Arc::new) + .collect()) + } } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { matches!( - args.return_type, - DataType::Float64 | DataType::Decimal128(_, _) - ) + args.return_field.data_type(), + DataType::Float64 + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Duration(_) + ) && !args.is_distinct } fn create_groups_accumulator( @@ -183,14 +310,52 @@ impl AggregateUDFImpl for Avg { let data_type = args.exprs[0].data_type(args.schema)?; // instantiate specialized accumulator based for the type - match (&data_type, args.return_type) { + match (&data_type, args.return_field.data_type()) { (Float64, Float64) => { Ok(Box::new(AvgGroupsAccumulator::::new( &data_type, - args.return_type, + args.return_field.data_type(), |sum: f64, count: u64| Ok(sum / count as f64), ))) } + ( + Decimal32(_sum_precision, sum_scale), + Decimal32(target_precision, target_scale), + ) => { + let decimal_averager = DecimalAverager::::try_new( + *sum_scale, + *target_precision, + *target_scale, + )?; + + let avg_fn = + move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32); + + Ok(Box::new(AvgGroupsAccumulator::::new( + &data_type, + args.return_field.data_type(), + avg_fn, + ))) + } + ( + Decimal64(_sum_precision, sum_scale), + Decimal64(target_precision, target_scale), + ) => { + let decimal_averager = DecimalAverager::::try_new( + *sum_scale, + *target_precision, + *target_scale, + )?; + + let avg_fn = + move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64); + + Ok(Box::new(AvgGroupsAccumulator::::new( + &data_type, + args.return_field.data_type(), + avg_fn, + ))) + } ( Decimal128(_sum_precision, sum_scale), Decimal128(target_precision, target_scale), @@ -206,7 +371,7 @@ impl AggregateUDFImpl for Avg { Ok(Box::new(AvgGroupsAccumulator::::new( &data_type, - args.return_type, + args.return_field.data_type(), avg_fn, ))) } @@ -227,15 +392,54 @@ impl AggregateUDFImpl for Avg { Ok(Box::new(AvgGroupsAccumulator::::new( &data_type, - args.return_type, + args.return_field.data_type(), avg_fn, ))) } + (Duration(time_unit), Duration(_result_unit)) => { + let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64); + + match time_unit { + TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::< + DurationSecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::< + DurationMillisecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::< + DurationMicrosecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::< + DurationNanosecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + } + } + _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", &data_type, - args.return_type + args.return_field.data_type() ), } } @@ -335,7 +539,7 @@ impl Accumulator for DecimalAvgAccumu self.count += (values.len() - values.null_count()) as u64; if let Some(x) = sum(values) { - let v = self.sum.get_or_insert(T::Native::default()); + let v = self.sum.get_or_insert_with(T::Native::default); self.sum = Some(v.add_wrapping(x)); } Ok(()) @@ -380,7 +584,7 @@ impl Accumulator for DecimalAvgAccumu // sums are summed if let Some(x) = sum(states[1].as_primitive::()) { - let v = self.sum.get_or_insert(T::Native::default()); + let v = self.sum.get_or_insert_with(T::Native::default); self.sum = Some(v.add_wrapping(x)); } Ok(()) @@ -399,6 +603,105 @@ impl Accumulator for DecimalAvgAccumu } } +/// An accumulator to compute the average for duration values +#[derive(Debug)] +struct DurationAvgAccumulator { + sum: Option, + count: u64, + time_unit: TimeUnit, + result_unit: TimeUnit, +} + +impl Accumulator for DurationAvgAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + self.count += (array.len() - array.null_count()) as u64; + + let sum_value = match self.time_unit { + TimeUnit::Second => sum(array.as_primitive::()), + TimeUnit::Millisecond => sum(array.as_primitive::()), + TimeUnit::Microsecond => sum(array.as_primitive::()), + TimeUnit::Nanosecond => sum(array.as_primitive::()), + }; + + if let Some(x) = sum_value { + let v = self.sum.get_or_insert(0); + *v += x; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let avg = self.sum.map(|sum| sum / self.count as i64); + + match self.result_unit { + TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)), + TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)), + TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)), + TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)), + } + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn state(&mut self) -> Result> { + let duration_value = match self.time_unit { + TimeUnit::Second => ScalarValue::DurationSecond(self.sum), + TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum), + TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum), + TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum), + }; + + Ok(vec![ScalarValue::from(self.count), duration_value]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.count += sum(states[0].as_primitive::()).unwrap_or_default(); + + let sum_value = match self.time_unit { + TimeUnit::Second => sum(states[1].as_primitive::()), + TimeUnit::Millisecond => { + sum(states[1].as_primitive::()) + } + TimeUnit::Microsecond => { + sum(states[1].as_primitive::()) + } + TimeUnit::Nanosecond => { + sum(states[1].as_primitive::()) + } + }; + + if let Some(x) = sum_value { + let v = self.sum.get_or_insert(0); + *v += x; + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + self.count -= (array.len() - array.null_count()) as u64; + + let sum_value = match self.time_unit { + TimeUnit::Second => sum(array.as_primitive::()), + TimeUnit::Millisecond => sum(array.as_primitive::()), + TimeUnit::Microsecond => sum(array.as_primitive::()), + TimeUnit::Nanosecond => sum(array.as_primitive::()), + }; + + if let Some(x) = sum_value { + self.sum = Some(self.sum.unwrap() - x); + } + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } +} + /// An accumulator to compute the average of `[PrimitiveArray]`. /// Stores values as native types, and does overflow checking /// @@ -436,7 +739,7 @@ where { pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self { debug!( - "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}", + "AvgGroupsAccumulator ({}, sum type: {sum_data_type}) --> {return_data_type}", std::any::type_name::() ); diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index 50ab50abc9e2a..e63044c753173 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -20,13 +20,14 @@ use std::any::Any; use std::collections::HashSet; use std::fmt::{Display, Formatter}; +use std::hash::Hash; use std::mem::{size_of, size_of_val}; use ahash::RandomState; use arrow::array::{downcast_integer, Array, ArrayRef, AsArray}; use arrow::datatypes::{ - ArrowNativeType, ArrowNumericType, DataType, Field, Int16Type, Int32Type, Int64Type, - Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowNativeType, ArrowNumericType, DataType, Field, FieldRef, Int16Type, Int32Type, + Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use datafusion_common::cast::as_list_array; @@ -39,7 +40,7 @@ use datafusion_expr::{ Signature, Volatility, }; -use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; +use datafusion_doc::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign}; use std::sync::LazyLock; @@ -87,7 +88,7 @@ macro_rules! accumulator_helper { /// `is_distinct` is boolean value indicating whether the operation is distinct or not. macro_rules! downcast_bitwise_accumulator { ($args:ident, $opr:expr, $is_distinct: expr) => { - match $args.return_type { + match $args.return_field.data_type() { DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct), DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct), DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct), @@ -101,7 +102,7 @@ macro_rules! downcast_bitwise_accumulator { "{} not supported for {}: {}", stringify!($opr), $args.name, - $args.return_type + $args.return_field.data_type() ) } } @@ -196,7 +197,7 @@ make_bitwise_udaf_expr_and_func!( ); /// The different types of bitwise operations that can be performed. -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq, Hash)] enum BitwiseOperationType { And, Or, @@ -205,12 +206,12 @@ enum BitwiseOperationType { impl Display for BitwiseOperationType { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) + write!(f, "{self:?}") } } /// [BitwiseOperation] struct encapsulates information about a bitwise operation. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct BitwiseOperation { signature: Signature, /// `operation` indicates the type of bitwise operation to be performed. @@ -263,7 +264,7 @@ impl AggregateUDFImpl for BitwiseOperation { downcast_bitwise_accumulator!(acc_args, self.operation, acc_args.is_distinct) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if self.operation == BitwiseOperationType::Xor && args.is_distinct { Ok(vec![Field::new_list( format_state_name( @@ -271,15 +272,17 @@ impl AggregateUDFImpl for BitwiseOperation { format!("{} distinct", self.name()).as_str(), ), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.return_type.clone(), true), + Field::new_list_field(args.return_type().clone(), true), false, - )]) + ) + .into()]) } else { Ok(vec![Field::new( format_state_name(args.name, self.name()), - args.return_type.clone(), + args.return_field.data_type().clone(), true, - )]) + ) + .into()]) } } @@ -291,7 +294,7 @@ impl AggregateUDFImpl for BitwiseOperation { &self, args: AccumulatorArgs, ) -> Result> { - let data_type = args.return_type; + let data_type = args.return_field.data_type(); let operation = &self.operation; downcast_integer! { data_type => (group_accumulator_helper, data_type, operation), @@ -379,7 +382,7 @@ where { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { if let Some(x) = arrow::compute::bit_or(values[0].as_primitive::()) { - let v = self.value.get_or_insert(T::Native::usize_as(0)); + let v = self.value.get_or_insert_with(|| T::Native::usize_as(0)); *v = *v | x; } Ok(()) @@ -424,7 +427,7 @@ where { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { if let Some(x) = arrow::compute::bit_xor(values[0].as_primitive::()) { - let v = self.value.get_or_insert(T::Native::usize_as(0)); + let v = self.value.get_or_insert_with(|| T::Native::usize_as(0)); *v = *v ^ x; } Ok(()) @@ -476,7 +479,7 @@ impl Default for DistinctBitXorAccumulator { impl Accumulator for DistinctBitXorAccumulator where - T::Native: std::ops::BitXor + std::hash::Hash + Eq, + T::Native: std::ops::BitXor + Hash + Eq, { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { if values.is_empty() { diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index 1b33a7900c000..ff389bb419e2e 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -24,8 +24,8 @@ use arrow::array::ArrayRef; use arrow::array::BooleanArray; use arrow::compute::bool_and as compute_bool_and; use arrow::compute::bool_or as compute_bool_or; -use arrow::datatypes::DataType; use arrow::datatypes::Field; +use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::internal_err; use datafusion_common::{downcast_value, not_impl_err}; @@ -106,7 +106,7 @@ make_udaf_expr_and_func!( standard_argument(name = "expression", prefix = "The") )] /// BOOL_AND aggregate expression -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct BoolAnd { signature: Signature, } @@ -150,12 +150,13 @@ impl AggregateUDFImpl for BoolAnd { Ok(Box::::default()) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, self.name()), DataType::Boolean, true, - )]) + ) + .into()]) } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { @@ -166,22 +167,18 @@ impl AggregateUDFImpl for BoolAnd { &self, args: AccumulatorArgs, ) -> Result> { - match args.return_type { + match args.return_field.data_type() { DataType::Boolean => { Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x && y, true))) } _ => not_impl_err!( "GroupsAccumulator not supported for {} with {}", args.name, - args.return_type + args.return_field.data_type() ), } } - fn aliases(&self) -> &[String] { - &[] - } - fn order_sensitivity(&self) -> AggregateOrderSensitivity { AggregateOrderSensitivity::Insensitive } @@ -244,7 +241,7 @@ impl Accumulator for BoolAndAccumulator { standard_argument(name = "expression", prefix = "The") )] /// BOOL_OR aggregate expression -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct BoolOr { signature: Signature, } @@ -288,12 +285,13 @@ impl AggregateUDFImpl for BoolOr { Ok(Box::::default()) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, self.name()), DataType::Boolean, true, - )]) + ) + .into()]) } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { @@ -304,7 +302,7 @@ impl AggregateUDFImpl for BoolOr { &self, args: AccumulatorArgs, ) -> Result> { - match args.return_type { + match args.return_field.data_type() { DataType::Boolean => Ok(Box::new(BooleanGroupsAccumulator::new( |x, y| x || y, false, @@ -312,15 +310,11 @@ impl AggregateUDFImpl for BoolOr { _ => not_impl_err!( "GroupsAccumulator not supported for {} with {}", args.name, - args.return_type + args.return_field.data_type() ), } } - fn aliases(&self) -> &[String] { - &[] - } - fn order_sensitivity(&self) -> AggregateOrderSensitivity { AggregateOrderSensitivity::Insensitive } diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index ac57256ce882f..20f23662cadec 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -26,8 +26,8 @@ use arrow::array::{ downcast_array, Array, AsArray, BooleanArray, Float64Array, NullBufferBuilder, UInt64Array, }; -use arrow::compute::{and, filter, is_not_null, kernels::cast}; -use arrow::datatypes::{Float64Type, UInt64Type}; +use arrow::compute::{and, filter, is_not_null}; +use arrow::datatypes::{FieldRef, Float64Type, UInt64Type}; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, @@ -38,10 +38,9 @@ use log::debug; use crate::covariance::CovarianceAccumulator; use crate::stddev::StddevAccumulator; -use datafusion_common::{plan_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, - type_coercion::aggregates::NUMERICS, utils::format_state_name, Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, }; @@ -71,7 +70,7 @@ make_udaf_expr_and_func!( standard_argument(name = "expression1", prefix = "First"), standard_argument(name = "expression2", prefix = "Second") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Correlation { signature: Signature, } @@ -83,10 +82,13 @@ impl Default for Correlation { } impl Correlation { - /// Create a new COVAR_POP aggregate function + /// Create a new CORR aggregate function pub fn new() -> Self { Self { - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::exact( + vec![DataType::Float64, DataType::Float64], + Volatility::Immutable, + ), } } } @@ -105,11 +107,7 @@ impl AggregateUDFImpl for Correlation { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Correlation requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -117,7 +115,7 @@ impl AggregateUDFImpl for Correlation { Ok(Box::new(CorrelationAccumulator::try_new()?)) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), @@ -130,7 +128,10 @@ impl AggregateUDFImpl for Correlation { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn documentation(&self) -> Option<&Documentation> { @@ -193,6 +194,11 @@ impl Accumulator for CorrelationAccumulator { } fn evaluate(&mut self) -> Result { + let n = self.covar.get_count(); + if n < 2 { + return Ok(ScalarValue::Float64(None)); + } + let covar = self.covar.evaluate()?; let stddev1 = self.stddev1.evaluate()?; let stddev2 = self.stddev2.evaluate()?; @@ -201,7 +207,7 @@ impl Accumulator for CorrelationAccumulator { if let ScalarValue::Float64(Some(s1)) = stddev1 { if let ScalarValue::Float64(Some(s2)) = stddev2 { if s1 == 0_f64 || s2 == 0_f64 { - return Ok(ScalarValue::Float64(Some(0_f64))); + return Ok(ScalarValue::Float64(None)); } else { return Ok(ScalarValue::Float64(Some(c / s1 / s2))); } @@ -372,10 +378,8 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { self.sum_xx.resize(total_num_groups, 0.0); self.sum_yy.resize(total_num_groups, 0.0); - let array_x = &cast(&values[0], &DataType::Float64)?; - let array_x = downcast_array::(array_x); - let array_y = &cast(&values[1], &DataType::Float64)?; - let array_y = downcast_array::(array_y); + let array_x = downcast_array::(&values[0]); + let array_y = downcast_array::(&values[1]); accumulate_multiple( group_indices, @@ -460,11 +464,8 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { // the `denominator` state is 0. In these cases, the final aggregation // result should be `Null` (according to PostgreSQL's behavior). // - // TODO: Old datafusion implementation returns 0.0 for these invalid cases. - // Update this to match PostgreSQL's behavior. for i in 0..n { if self.count[i] < 2 { - // TODO: Evaluate as `Null` (see notes above) values.push(0.0); nulls.append_null(); continue; @@ -485,7 +486,6 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { ((sum_xx - sum_x * mean_x) * (sum_yy - sum_y * mean_y)).sqrt(); if denominator == 0.0 { - // TODO: Evaluate as `Null` (see notes above) values.push(0.0); nulls.append_null(); } else { diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 2d995b4a41793..c0d2ba199a131 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -16,53 +16,50 @@ // under the License. use ahash::RandomState; -use datafusion_common::stats::Precision; -use datafusion_expr::expr::WindowFunction; -use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; -use datafusion_macros::user_doc; -use datafusion_physical_expr::expressions; -use std::collections::HashSet; -use std::fmt::Debug; -use std::mem::{size_of, size_of_val}; -use std::ops::BitAnd; -use std::sync::Arc; - use arrow::{ - array::{ArrayRef, AsArray}, + array::{Array, ArrayRef, AsArray, BooleanArray, Int64Array, PrimitiveArray}, + buffer::BooleanBuffer, compute, datatypes::{ DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, - Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + FieldRef, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, + Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }, }; - -use arrow::{ - array::{Array, BooleanArray, Int64Array, PrimitiveArray}, - buffer::BooleanBuffer, -}; use datafusion_common::{ - downcast_value, internal_err, not_impl_err, Result, ScalarValue, -}; -use datafusion_expr::function::StateFieldsArgs; -use datafusion_expr::{ - function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, - Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility, + downcast_value, internal_err, not_impl_err, stats::Precision, + utils::expr::COUNT_STAR_EXPANSION, HashMap, Result, ScalarValue, }; use datafusion_expr::{ - Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition, + expr::WindowFunction, + function::{AccumulatorArgs, StateFieldsArgs}, + utils::format_state_name, + Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator, + ReversedUDAF, SetMonotonicity, Signature, StatisticsArgs, TypeSignature, Volatility, + WindowFunctionDefinition, }; -use datafusion_functions_aggregate_common::aggregate::count_distinct::{ - BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, - PrimitiveDistinctCountAccumulator, +use datafusion_functions_aggregate_common::aggregate::{ + count_distinct::BytesDistinctCountAccumulator, + count_distinct::BytesViewDistinctCountAccumulator, + count_distinct::DictionaryCountAccumulator, + count_distinct::FloatDistinctCountAccumulator, + count_distinct::PrimitiveDistinctCountAccumulator, + groups_accumulator::accumulate::accumulate_indices, }; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; +use datafusion_macros::user_doc; +use datafusion_physical_expr::expressions; use datafusion_physical_expr_common::binary_map::OutputType; +use std::{ + collections::HashSet, + fmt::Debug, + mem::{size_of, size_of_val}, + ops::BitAnd, + sync::Arc, +}; -use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; make_udaf_expr_and_func!( Count, count, @@ -77,7 +74,7 @@ pub fn count_distinct(expr: Expr) -> Expr { vec![expr], true, None, - None, + vec![], None, )) } @@ -100,7 +97,7 @@ pub fn count_distinct(expr: Expr) -> Expr { /// let expr = col(expr.schema_name().to_string()); /// ``` pub fn count_all() -> Expr { - count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)") + count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)") } /// Creates window aggregation to count all rows. @@ -123,9 +120,9 @@ pub fn count_all() -> Expr { /// let expr = col(expr.schema_name().to_string()); /// ``` pub fn count_all_window() -> Expr { - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![Expr::Literal(COUNT_STAR_EXPANSION)], + vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], )) } @@ -150,6 +147,7 @@ pub fn count_all_window() -> Expr { ```"#, standard_argument(name = "expression",) )] +#[derive(PartialEq, Eq, Hash)] pub struct Count { signature: Signature, } @@ -179,6 +177,107 @@ impl Count { } } } +fn get_count_accumulator(data_type: &DataType) -> Box { + match data_type { + // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator + DataType::Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::UInt8 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::UInt16 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< + Decimal128Type, + >::new(data_type)), + DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< + Decimal256Type, + >::new(data_type)), + + DataType::Date32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Date64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time32(TimeUnit::Millisecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time32(TimeUnit::Second) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time64(TimeUnit::Microsecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time64(TimeUnit::Nanosecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Second, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + + DataType::Float16 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + DataType::Float32 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + DataType::Float64 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + + DataType::Utf8 => { + Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + } + DataType::Utf8View => { + Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View)) + } + DataType::LargeUtf8 => { + Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + } + DataType::Binary => Box::new(BytesDistinctCountAccumulator::::new( + OutputType::Binary, + )), + DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new( + OutputType::BinaryView, + )), + DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( + OutputType::Binary, + )), + + // Use the generic accumulator based on `ScalarValue` for all other types + _ => Box::new(DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: data_type.clone(), + }), + } +} impl AggregateUDFImpl for Count { fn as_any(&self) -> &dyn std::any::Any { @@ -201,20 +300,27 @@ impl AggregateUDFImpl for Count { false } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { + let dtype: DataType = match &args.input_fields[0].data_type() { + DataType::Dictionary(_, values_type) => (**values_type).clone(), + &dtype => dtype.clone(), + }; + Ok(vec![Field::new_list( format_state_name(args.name, "count distinct"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(dtype, true), false, - )]) + ) + .into()]) } else { Ok(vec![Field::new( format_state_name(args.name, "count"), DataType::Int64, false, - )]) + ) + .into()]) } } @@ -228,121 +334,16 @@ impl AggregateUDFImpl for Count { } let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?; - Ok(match data_type { - // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator - DataType::Int8 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Int16 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Int32 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Int64 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt8 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt16 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt32 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt64 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal128Type, - >::new(data_type)), - DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal256Type, - >::new(data_type)), - - DataType::Date32 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Date64 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Time32(TimeUnit::Millisecond) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Time32(TimeUnit::Second) => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Time64(TimeUnit::Microsecond) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Time64(TimeUnit::Nanosecond) => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Timestamp(TimeUnit::Second, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - - DataType::Float16 => { - Box::new(FloatDistinctCountAccumulator::::new()) - } - DataType::Float32 => { - Box::new(FloatDistinctCountAccumulator::::new()) - } - DataType::Float64 => { - Box::new(FloatDistinctCountAccumulator::::new()) - } - DataType::Utf8 => { - Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) - } - DataType::Utf8View => { - Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View)) - } - DataType::LargeUtf8 => { - Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + Ok(match data_type { + DataType::Dictionary(_, values_type) => { + let inner = get_count_accumulator(values_type); + Box::new(DictionaryCountAccumulator::new(inner)) } - DataType::Binary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new( - OutputType::BinaryView, - )), - DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - - // Use the generic accumulator based on `ScalarValue` for all other types - _ => Box::new(DistinctCountAccumulator { - values: HashSet::default(), - state_data_type: data_type.clone(), - }), + _ => get_count_accumulator(data_type), }) } - fn aliases(&self) -> &[String] { - &[] - } - fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { // groups accumulator only supports `COUNT(c1)`, not // `COUNT(c1, c2)`, etc @@ -407,6 +408,98 @@ impl AggregateUDFImpl for Count { // the same as new values are seen. SetMonotonicity::Increasing } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + if args.is_distinct { + let acc = + SlidingDistinctCountAccumulator::try_new(args.return_field.data_type())?; + Ok(Box::new(acc)) + } else { + let acc = CountAccumulator::new(); + Ok(Box::new(acc)) + } + } +} + +// DistinctCountAccumulator does not support retract_batch and sliding window +// this is a specialized accumulator for distinct count that supports retract_batch +// and sliding window. +#[derive(Debug)] +pub struct SlidingDistinctCountAccumulator { + counts: HashMap, + data_type: DataType, +} + +impl SlidingDistinctCountAccumulator { + pub fn try_new(data_type: &DataType) -> Result { + Ok(Self { + counts: HashMap::default(), + data_type: data_type.clone(), + }) + } +} + +impl Accumulator for SlidingDistinctCountAccumulator { + fn state(&mut self) -> Result> { + let keys = self.counts.keys().cloned().collect::>(); + Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable( + keys.as_slice(), + &self.data_type, + ))]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = &values[0]; + for i in 0..arr.len() { + let v = ScalarValue::try_from_array(arr, i)?; + if !v.is_null() { + *self.counts.entry(v).or_default() += 1; + } + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = &values[0]; + for i in 0..arr.len() { + let v = ScalarValue::try_from_array(arr, i)?; + if !v.is_null() { + if let Some(cnt) = self.counts.get_mut(&v) { + *cnt -= 1; + if *cnt == 0 { + self.counts.remove(&v); + } + } + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let list_arr = states[0].as_list::(); + for inner in list_arr.iter().flatten() { + for j in 0..inner.len() { + let v = ScalarValue::try_from_array(&*inner, j)?; + *self.counts.entry(v).or_default() += 1; + } + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Int64(Some(self.counts.len() as i64))) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + size_of_val(self) + } } #[derive(Debug)] @@ -708,8 +801,8 @@ impl Accumulator for DistinctCountAccumulator { } (0..arr.len()).try_for_each(|index| { - if !arr.is_null(index) { - let scalar = ScalarValue::try_from_array(arr, index)?; + let scalar = ScalarValue::try_from_array(arr, index)?; + if !scalar.is_null() { self.values.insert(scalar); } Ok(()) @@ -754,8 +847,28 @@ impl Accumulator for DistinctCountAccumulator { #[cfg(test)] mod tests { + use super::*; - use arrow::array::NullArray; + use arrow::{ + array::{DictionaryArray, Int32Array, NullArray, StringArray}, + datatypes::{DataType, Field, Int32Type, Schema}, + }; + use datafusion_expr::function::AccumulatorArgs; + use datafusion_physical_expr::expressions::Column; + use std::sync::Arc; + /// Helper function to create a dictionary array with non-null keys but some null values + /// Returns a dictionary array where: + /// - keys are [0, 1, 2, 0, 1] (all non-null) + /// - values are ["a", null, "c"] + /// - so the keys reference: "a", null, "c", "a", null + fn create_dictionary_with_null_values() -> Result> { + let values = StringArray::from(vec![Some("a"), None, Some("c")]); + let keys = Int32Array::from(vec![0, 1, 2, 0, 1]); // references "a", null, "c", "a", null + Ok(DictionaryArray::::try_new( + keys, + Arc::new(values), + )?) + } #[test] fn count_accumulator_nulls() -> Result<()> { @@ -764,4 +877,167 @@ mod tests { assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); Ok(()) } + + #[test] + fn test_nested_dictionary() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "dict_col", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + )), + ), + true, + )])); + + // Using Count UDAF's accumulator + let count = Count::new(); + let expr = Arc::new(Column::new("dict_col", 0)); + let args = AccumulatorArgs { + schema: &schema, + exprs: &[expr], + is_distinct: true, + name: "count", + ignore_nulls: false, + is_reversed: false, + return_field: Arc::new(Field::new_list_field(DataType::Int64, true)), + order_bys: &[], + }; + + let inner_dict = + DictionaryArray::::from_iter(["a", "b", "c", "d", "a", "b"]); + + let keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]); + let dict_of_dict = + DictionaryArray::::try_new(keys, Arc::new(inner_dict))?; + + let mut acc = count.accumulator(args)?; + acc.update_batch(&[Arc::new(dict_of_dict)])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4))); + + Ok(()) + } + + #[test] + fn count_distinct_accumulator_dictionary_with_null_values() -> Result<()> { + let dict_array = create_dictionary_with_null_values()?; + + // The expected behavior is that count_distinct should count only non-null values + // which in this case are "a" and "c" (appearing as 0 and 2 in keys) + let mut accumulator = DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: dict_array.data_type().clone(), + }; + + accumulator.update_batch(&[Arc::new(dict_array)])?; + + // Should have 2 distinct non-null values ("a" and "c") + assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(2))); + Ok(()) + } + + #[test] + fn count_accumulator_dictionary_with_null_values() -> Result<()> { + let dict_array = create_dictionary_with_null_values()?; + + // The expected behavior is that count should only count non-null values + let mut accumulator = CountAccumulator::new(); + + accumulator.update_batch(&[Arc::new(dict_array)])?; + + // 5 elements in the array, of which 2 reference null values (the two 1s in the keys) + // So we should count 3 non-null values + assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(3))); + Ok(()) + } + + #[test] + fn count_distinct_accumulator_dictionary_all_null_values() -> Result<()> { + // Create a dictionary array that only contains null values + let dict_values = StringArray::from(vec![None, Some("abc")]); + let dict_indices = Int32Array::from(vec![0; 5]); + let dict_array = + DictionaryArray::::try_new(dict_indices, Arc::new(dict_values))?; + + let mut accumulator = DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: dict_array.data_type().clone(), + }; + + accumulator.update_batch(&[Arc::new(dict_array)])?; + + // All referenced values are null so count(distinct) should be 0 + assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); + Ok(()) + } + + #[test] + fn sliding_distinct_count_accumulator_basic() -> Result<()> { + // Basic update_batch + evaluate functionality + let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + // Create an Int32Array: [1, 2, 2, 3, null] + let values: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(2), + Some(3), + None, + ])); + acc.update_batch(&[values])?; + // Expect distinct values {1,2,3} → count = 3 + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3))); + Ok(()) + } + + #[test] + fn sliding_distinct_count_accumulator_retract() -> Result<()> { + // Test that retract_batch properly decrements counts + let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Utf8)?; + // Initial batch: ["a", "b", "a"] + let arr1 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")])) + as ArrayRef; + acc.update_batch(&[arr1])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2))); // {"a","b"} + + // Retract batch: ["a", null, "b"] + let arr2 = + Arc::new(StringArray::from(vec![Some("a"), None, Some("b")])) as ArrayRef; + acc.retract_batch(&[arr2])?; + // Before: a→2, b→1; after retract a→1, b→0 → b removed; remaining {"a"} + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(1))); + Ok(()) + } + + #[test] + fn sliding_distinct_count_accumulator_merge_states() -> Result<()> { + // Test merging multiple accumulator states with merge_batch + let mut acc1 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + let mut acc2 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + // acc1 sees [1, 2] + acc1.update_batch(&[Arc::new(Int32Array::from(vec![Some(1), Some(2)]))])?; + // acc2 sees [2, 3] + acc2.update_batch(&[Arc::new(Int32Array::from(vec![Some(2), Some(3)]))])?; + // Extract their states as Vec + let state_sv1 = acc1.state()?; + let state_sv2 = acc2.state()?; + // Convert ScalarValue states into Vec, propagating errors + // NOTE we pass `1` because each ScalarValue.to_array produces a 1‑row ListArray + let state_arr1: Vec = state_sv1 + .into_iter() + .map(|sv| sv.to_array()) + .collect::>()?; + let state_arr2: Vec = state_sv2 + .into_iter() + .map(|sv| sv.to_array()) + .collect::>()?; + // Merge both states into a fresh accumulator + let mut merged = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + merged.merge_batch(&state_arr1)?; + merged.merge_batch(&state_arr2)?; + // Expect distinct {1,2,3} → count = 3 + assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3))); + Ok(()) + } } diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index d4ae27533c6db..f74fddd603319 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -17,15 +17,12 @@ //! [`CovarianceSample`]: covariance sample aggregations. -use std::fmt::Debug; -use std::mem::size_of_val; - +use arrow::datatypes::FieldRef; use arrow::{ array::{ArrayRef, Float64Array, UInt64Array}, compute::kernels::cast, datatypes::{DataType, Field}, }; - use datafusion_common::{ downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result, ScalarValue, @@ -38,6 +35,9 @@ use datafusion_expr::{ }; use datafusion_functions_aggregate_common::stats::StatsType; use datafusion_macros::user_doc; +use std::fmt::Debug; +use std::mem::size_of_val; +use std::sync::Arc; make_udaf_expr_and_func!( CovarianceSample, @@ -70,6 +70,7 @@ make_udaf_expr_and_func!( standard_argument(name = "expression1", prefix = "First"), standard_argument(name = "expression2", prefix = "Second") )] +#[derive(PartialEq, Eq, Hash)] pub struct CovarianceSample { signature: Signature, aliases: Vec, @@ -120,7 +121,7 @@ impl AggregateUDFImpl for CovarianceSample { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), @@ -131,7 +132,10 @@ impl AggregateUDFImpl for CovarianceSample { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { @@ -162,6 +166,7 @@ impl AggregateUDFImpl for CovarianceSample { standard_argument(name = "expression1", prefix = "First"), standard_argument(name = "expression2", prefix = "Second") )] +#[derive(PartialEq, Eq, Hash)] pub struct CovariancePopulation { signature: Signature, } @@ -210,7 +215,7 @@ impl AggregateUDFImpl for CovariancePopulation { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), @@ -221,7 +226,10 @@ impl AggregateUDFImpl for CovariancePopulation { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 28e6a8723dfd4..28755427c7325 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::fmt::Debug; +use std::hash::Hash; use std::mem::size_of_val; use std::sync::Arc; @@ -29,12 +30,12 @@ use arrow::array::{ use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions}; use arrow::datatypes::{ - DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, Float16Type, - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, - TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, + DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Decimal32Type, + Decimal64Type, Field, FieldRef, Float16Type, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, }; use datafusion_common::cast::as_boolean_array; use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx}; @@ -45,26 +46,33 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt, - GroupsAccumulator, Signature, SortExpr, Volatility, + GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility, }; use datafusion_functions_aggregate_common::utils::get_sort_options; use datafusion_macros::user_doc; use datafusion_physical_expr_common::sort_expr::LexOrdering; create_func!(FirstValue, first_value_udaf); +create_func!(LastValue, last_value_udaf); /// Returns the first value in a group of values. -pub fn first_value(expression: Expr, order_by: Option>) -> Expr { - if let Some(order_by) = order_by { - first_value_udaf() - .call(vec![expression]) - .order_by(order_by) - .build() - // guaranteed to be `Expr::AggregateFunction` - .unwrap() - } else { - first_value_udaf().call(vec![expression]) - } +pub fn first_value(expression: Expr, order_by: Vec) -> Expr { + first_value_udaf() + .call(vec![expression]) + .order_by(order_by) + .build() + // guaranteed to be `Expr::AggregateFunction` + .unwrap() +} + +/// Returns the last value in a group of values. +pub fn last_value(expression: Expr, order_by: Vec) -> Expr { + last_value_udaf() + .call(vec![expression]) + .order_by(order_by) + .build() + // guaranteed to be `Expr::AggregateFunction` + .unwrap() } #[user_doc( @@ -81,9 +89,10 @@ pub fn first_value(expression: Expr, order_by: Option>) -> Expr { ```"#, standard_argument(name = "expression",) )] +#[derive(PartialEq, Eq, Hash)] pub struct FirstValue { signature: Signature, - requirement_satisfied: bool, + is_input_pre_ordered: bool, } impl Debug for FirstValue { @@ -106,14 +115,9 @@ impl FirstValue { pub fn new() -> Self { Self { signature: Signature::any(1, Volatility::Immutable), - requirement_satisfied: false, + is_input_pre_ordered: false, } } - - fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { - self.requirement_satisfied = requirement_satisfied; - self - } } impl AggregateUDFImpl for FirstValue { @@ -134,86 +138,92 @@ impl AggregateUDFImpl for FirstValue { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let ordering_dtypes = acc_args - .ordering_req + let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else { + return TrivialFirstValueAccumulator::try_new( + acc_args.return_field.data_type(), + acc_args.ignore_nulls, + ) + .map(|acc| Box::new(acc) as _); + }; + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; - - // When requirement is empty, or it is signalled by outside caller that - // the ordering requirement is/will be satisfied. - let requirement_satisfied = - acc_args.ordering_req.is_empty() || self.requirement_satisfied; - - FirstValueAccumulator::try_new( - acc_args.return_type, + Ok(Box::new(FirstValueAccumulator::try_new( + acc_args.return_field.data_type(), &ordering_dtypes, - acc_args.ordering_req.clone(), + ordering, + self.is_input_pre_ordered, acc_args.ignore_nulls, - ) - .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) + )?)) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new( format_state_name(args.name, "first_value"), - args.return_type.clone(), + args.return_type().clone(), true, - )]; - fields.extend(args.ordering_fields.to_vec()); - fields.push(Field::new("is_set", DataType::Boolean, true)); + ) + .into()]; + fields.extend(args.ordering_fields.iter().cloned()); + fields.push(Field::new("is_set", DataType::Boolean, true).into()); Ok(fields) } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; - matches!( - args.return_type, - Int8 | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float16 - | Float32 - | Float64 - | Decimal128(_, _) - | Decimal256(_, _) - | Date32 - | Date64 - | Time32(_) - | Time64(_) - | Timestamp(_, _) - ) + !args.order_bys.is_empty() + && matches!( + args.return_field.data_type(), + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float16 + | Float32 + | Float64 + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + ) } fn create_groups_accumulator( &self, args: AccumulatorArgs, ) -> Result> { - fn create_accumulator( + fn create_accumulator( args: AccumulatorArgs, - ) -> Result> - where - T: ArrowPrimitiveType + Send, - { - let ordering_dtypes = args - .ordering_req + ) -> Result> { + let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else { + return internal_err!("Groups accumulator must have an ordering."); + }; + + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(args.schema)) .collect::>>()?; - Ok(Box::new(FirstPrimitiveGroupsAccumulator::::try_new( - args.ordering_req.clone(), + FirstPrimitiveGroupsAccumulator::::try_new( + ordering, args.ignore_nulls, - args.return_type, + args.return_field.data_type(), &ordering_dtypes, - )?)) + true, + ) + .map(|acc| Box::new(acc) as _) } - match args.return_type { + match args.return_field.data_type() { DataType::Int8 => create_accumulator::(args), DataType::Int16 => create_accumulator::(args), DataType::Int32 => create_accumulator::(args), @@ -226,6 +236,8 @@ impl AggregateUDFImpl for FirstValue { DataType::Float32 => create_accumulator::(args), DataType::Float64 => create_accumulator::(args), + DataType::Decimal32(_, _) => create_accumulator::(args), + DataType::Decimal64(_, _) => create_accumulator::(args), DataType::Decimal128(_, _) => create_accumulator::(args), DataType::Decimal256(_, _) => create_accumulator::(args), @@ -259,31 +271,28 @@ impl AggregateUDFImpl for FirstValue { } _ => internal_err!( - "GroupsAccumulator not supported for first({})", - args.return_type + "GroupsAccumulator not supported for first_value({})", + args.return_field.data_type() ), } } - fn aliases(&self) -> &[String] { - &[] - } - fn with_beneficial_ordering( self: Arc, beneficial_ordering: bool, ) -> Result>> { - Ok(Some(Arc::new( - FirstValue::new().with_requirement_satisfied(beneficial_ordering), - ))) + Ok(Some(Arc::new(Self { + signature: self.signature.clone(), + is_input_pre_ordered: beneficial_ordering, + }))) } fn order_sensitivity(&self) -> AggregateOrderSensitivity { AggregateOrderSensitivity::Beneficial } - fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { - datafusion_expr::ReversedUDAF::Reversed(last_value_udaf()) + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Reversed(last_value_udaf()) } fn documentation(&self) -> Option<&Documentation> { @@ -291,6 +300,7 @@ impl AggregateUDFImpl for FirstValue { } } +// TODO: rename to PrimitiveGroupsAccumulator struct FirstPrimitiveGroupsAccumulator where T: ArrowPrimitiveType + Send, @@ -316,16 +326,18 @@ where // buffer for `get_filtered_min_of_each_group` // filter_min_of_each_group_buf.0[group_idx] -> idx_in_val // only valid if filter_min_of_each_group_buf.1[group_idx] == true + // TODO: rename to extreme_of_each_group_buf min_of_each_group_buf: (Vec, BooleanBufferBuilder), // =========== option ============ // Stores the applicable ordering requirement. ordering_req: LexOrdering, + // true: take first element in an aggregation group according to the requested ordering. + // false: take last element in an aggregation group according to the requested ordering. + pick_first_in_group: bool, // derived from `ordering_req`. sort_options: Vec, - // Stores whether incoming data already satisfies the ordering requirement. - input_requirement_satisfied: bool, // Ignore null values. ignore_nulls: bool, /// The output type @@ -342,21 +354,19 @@ where ignore_nulls: bool, data_type: &DataType, ordering_dtypes: &[DataType], + pick_first_in_group: bool, ) -> Result { - let requirement_satisfied = ordering_req.is_empty(); - let default_orderings = ordering_dtypes .iter() .map(ScalarValue::try_from) - .collect::>>()?; + .collect::>()?; - let sort_options = get_sort_options(ordering_req.as_ref()); + let sort_options = get_sort_options(&ordering_req); Ok(Self { null_builder: BooleanBufferBuilder::new(0), ordering_req, sort_options, - input_requirement_satisfied: requirement_satisfied, ignore_nulls, default_orderings, data_type: data_type.clone(), @@ -365,21 +375,10 @@ where is_sets: BooleanBufferBuilder::new(0), size_of_orderings: 0, min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)), + pick_first_in_group, }) } - fn need_update(&self, group_idx: usize) -> bool { - if !self.is_sets.get_bit(group_idx) { - return true; - } - - if self.ignore_nulls && !self.null_builder.get_bit(group_idx) { - return true; - } - - !self.input_requirement_satisfied - } - fn should_update_state( &self, group_idx: usize, @@ -391,8 +390,13 @@ where assert!(new_ordering_values.len() == self.ordering_req.len()); let current_ordering = &self.orderings[group_idx]; - compare_rows(current_ordering, new_ordering_values, &self.sort_options) - .map(|x| x.is_gt()) + compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| { + if self.pick_first_in_group { + x.is_gt() + } else { + x.is_lt() + } + }) } fn take_orderings(&mut self, emit_to: EmitTo) -> Vec> { @@ -501,10 +505,10 @@ where .map(ScalarValue::size_of_vec) .sum::() } - /// Returns a vector of tuples `(group_idx, idx_in_val)` representing the index of the /// minimum value in `orderings` for each group, using lexicographical comparison. /// Values are filtered using `opt_filter` and `is_set_arr` if provided. + /// TODO: rename to get_filtered_extreme_of_each_group fn get_filtered_min_of_each_group( &mut self, orderings: &[ArrayRef], @@ -540,31 +544,30 @@ where let group_idx = *group_idx; let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val)); - let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val)); if !passed_filter || !is_set { continue; } - if !self.need_update(group_idx) { - continue; - } - if self.ignore_nulls && vals.is_null(idx_in_val) { continue; } let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx); - if is_valid - && comparator - .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val) - .is_gt() - { - self.min_of_each_group_buf.0[group_idx] = idx_in_val; - } else if !is_valid { + + if !is_valid { self.min_of_each_group_buf.1.set_bit(group_idx, true); self.min_of_each_group_buf.0[group_idx] = idx_in_val; + } else { + let ordering = comparator + .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val); + + if (ordering.is_gt() && self.pick_first_in_group) + || (ordering.is_lt() && !self.pick_first_in_group) + { + self.min_of_each_group_buf.0[group_idx] = idx_in_val; + } } } @@ -683,7 +686,7 @@ where let (is_set_arr, val_and_order_cols) = match values.split_last() { Some(result) => result, - None => return internal_err!("Empty row in FISRT_VALUE"), + None => return internal_err!("Empty row in FIRST_VALUE"), }; let is_set_arr = as_boolean_array(is_set_arr)?; @@ -716,7 +719,7 @@ where fn size(&self) -> usize { self.vals.capacity() * size_of::() - + self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes + + self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes + self.is_sets.capacity() / 8 + self.size_of_orderings + self.min_of_each_group_buf.0.capacity() * size_of::() @@ -745,19 +748,101 @@ where } } } + +/// This accumulator is used when there is no ordering specified for the +/// `FIRST_VALUE` aggregation. It simply returns the first value it sees +/// according to the pre-existing ordering of the input data, and provides +/// a fast path for this case without needing to maintain any ordering state. +#[derive(Debug)] +pub struct TrivialFirstValueAccumulator { + first: ScalarValue, + // Whether we have seen the first value yet. + is_set: bool, + // Ignore null values. + ignore_nulls: bool, +} + +impl TrivialFirstValueAccumulator { + /// Creates a new `TrivialFirstValueAccumulator` for the given `data_type`. + pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result { + ScalarValue::try_from(data_type).map(|first| Self { + first, + is_set: false, + ignore_nulls, + }) + } +} + +impl Accumulator for TrivialFirstValueAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.first.clone(), ScalarValue::from(self.is_set)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if !self.is_set { + // Get first entry according to the pre-existing ordering (0th index): + let value = &values[0]; + let mut first_idx = None; + if self.ignore_nulls { + // If ignoring nulls, find the first non-null value. + for i in 0..value.len() { + if !value.is_null(i) { + first_idx = Some(i); + break; + } + } + } else if !value.is_empty() { + // If not ignoring nulls, return the first value if it exists. + first_idx = Some(0); + } + if let Some(first_idx) = first_idx { + let mut row = get_row_at_idx(values, first_idx)?; + self.first = row.swap_remove(0); + self.first.compact(); + self.is_set = true; + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // FIRST_VALUE(first1, first2, first3, ...) + // Second index contains is_set flag. + if !self.is_set { + let flags = states[1].as_boolean(); + let filtered_states = + filter_states_according_to_is_set(&states[0..1], flags)?; + if let Some(first) = filtered_states.first() { + if !first.is_empty() { + self.first = ScalarValue::try_from_array(first, 0)?; + self.is_set = true; + } + } + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(self.first.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.first) + self.first.size() + } +} + #[derive(Debug)] pub struct FirstValueAccumulator { first: ScalarValue, - // At the beginning, `is_set` is false, which means `first` is not seen yet. - // Once we see the first value, we set the `is_set` flag and do not update `first` anymore. + // Whether we have seen the first value yet. is_set: bool, - // Stores ordering values, of the aggregator requirement corresponding to first value - // of the aggregator. These values are used during merging of multiple partitions. + // Stores values of the ordering columns corresponding to the first value. + // These values are used during merging of multiple partitions. orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, // Stores whether incoming data already satisfies the ordering requirement. - requirement_satisfied: bool, + is_input_pre_ordered: bool, // Ignore null values. ignore_nulls: bool, } @@ -768,32 +853,31 @@ impl FirstValueAccumulator { data_type: &DataType, ordering_dtypes: &[DataType], ordering_req: LexOrdering, + is_input_pre_ordered: bool, ignore_nulls: bool, ) -> Result { let orderings = ordering_dtypes .iter() .map(ScalarValue::try_from) - .collect::>>()?; - let requirement_satisfied = ordering_req.is_empty(); + .collect::>()?; ScalarValue::try_from(data_type).map(|first| Self { first, is_set: false, orderings, ordering_req, - requirement_satisfied, + is_input_pre_ordered, ignore_nulls, }) } - pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { - self.requirement_satisfied = requirement_satisfied; - self - } - // Updates state with the values in the given row. - fn update_with_new_row(&mut self, row: &[ScalarValue]) { - self.first = row[0].clone(); - self.orderings = row[1..].to_vec(); + fn update_with_new_row(&mut self, mut row: Vec) { + // Ensure any Array based scalars hold have a single value to reduce memory pressure + for s in row.iter_mut() { + s.compact(); + } + self.first = row.remove(0); + self.orderings = row; self.is_set = true; } @@ -801,7 +885,7 @@ impl FirstValueAccumulator { let [value, ordering_values @ ..] = values else { return internal_err!("Empty row in FIRST_VALUE"); }; - if self.requirement_satisfied { + if self.is_input_pre_ordered { // Get first entry according to the pre-existing ordering (0th index): if self.ignore_nulls { // If ignoring nulls, find the first non-null value. @@ -844,29 +928,23 @@ impl Accumulator for FirstValueAccumulator { fn state(&mut self) -> Result> { let mut result = vec![self.first.clone()]; result.extend(self.orderings.iter().cloned()); - result.push(ScalarValue::Boolean(Some(self.is_set))); + result.push(ScalarValue::from(self.is_set)); Ok(result) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if !self.is_set { - if let Some(first_idx) = self.get_first_idx(values)? { - let row = get_row_at_idx(values, first_idx)?; - self.update_with_new_row(&row); - } - } else if !self.requirement_satisfied { - if let Some(first_idx) = self.get_first_idx(values)? { - let row = get_row_at_idx(values, first_idx)?; - let orderings = &row[1..]; - if compare_rows( - &self.orderings, - orderings, - &get_sort_options(self.ordering_req.as_ref()), - )? - .is_gt() - { - self.update_with_new_row(&row); - } + if let Some(first_idx) = self.get_first_idx(values)? { + let row = get_row_at_idx(values, first_idx)?; + if !self.is_set + || (!self.is_input_pre_ordered + && compare_rows( + &self.orderings, + &row[1..], + &get_sort_options(&self.ordering_req), + )? + .is_gt()) + { + self.update_with_new_row(row); } } Ok(()) @@ -880,19 +958,17 @@ impl Accumulator for FirstValueAccumulator { let filtered_states = filter_states_according_to_is_set(&states[0..is_set_idx], flags)?; // 1..is_set_idx range corresponds to ordering section - let sort_columns = convert_to_sort_cols( - &filtered_states[1..is_set_idx], - self.ordering_req.as_ref(), - ); + let sort_columns = + convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req); let comparator = LexicographicalComparator::try_new(&sort_columns)?; let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b)); if let Some(first_idx) = min { - let first_row = get_row_at_idx(&filtered_states, first_idx)?; + let mut first_row = get_row_at_idx(&filtered_states, first_idx)?; // When collecting orderings, we exclude the is_set flag from the state. let first_ordering = &first_row[1..is_set_idx]; - let sort_options = get_sort_options(self.ordering_req.as_ref()); + let sort_options = get_sort_options(&self.ordering_req); // Either there is no existing value, or there is an earlier version in new data. if !self.is_set || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt() @@ -900,7 +976,9 @@ impl Accumulator for FirstValueAccumulator { // Update with first value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state // containing two is_set flags. - self.update_with_new_row(&first_row[0..is_set_idx]); + assert!(is_set_idx <= first_row.len()); + first_row.resize(is_set_idx, ScalarValue::Null); + self.update_with_new_row(first_row); } } Ok(()) @@ -918,13 +996,6 @@ impl Accumulator for FirstValueAccumulator { } } -make_udaf_expr_and_func!( - LastValue, - last_value, - "Returns the last value in a group of values.", - last_value_udaf -); - #[user_doc( doc_section(label = "General Functions"), description = "Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.", @@ -939,9 +1010,10 @@ make_udaf_expr_and_func!( ```"#, standard_argument(name = "expression",) )] +#[derive(PartialEq, Eq, Hash)] pub struct LastValue { signature: Signature, - requirement_satisfied: bool, + is_input_pre_ordered: bool, } impl Debug for LastValue { @@ -964,14 +1036,9 @@ impl LastValue { pub fn new() -> Self { Self { signature: Signature::any(1, Volatility::Immutable), - requirement_satisfied: false, + is_input_pre_ordered: false, } } - - fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { - self.requirement_satisfied = requirement_satisfied; - self - } } impl AggregateUDFImpl for LastValue { @@ -992,66 +1059,249 @@ impl AggregateUDFImpl for LastValue { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let ordering_dtypes = acc_args - .ordering_req + let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else { + return TrivialLastValueAccumulator::try_new( + acc_args.return_field.data_type(), + acc_args.ignore_nulls, + ) + .map(|acc| Box::new(acc) as _); + }; + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; - - let requirement_satisfied = - acc_args.ordering_req.is_empty() || self.requirement_satisfied; - - LastValueAccumulator::try_new( - acc_args.return_type, + Ok(Box::new(LastValueAccumulator::try_new( + acc_args.return_field.data_type(), &ordering_dtypes, - acc_args.ordering_req.clone(), + ordering, + self.is_input_pre_ordered, acc_args.ignore_nulls, - ) - .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) + )?)) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { - let StateFieldsArgs { - name, - input_types, - return_type: _, - ordering_fields, - is_distinct: _, - } = args; + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new( - format_state_name(name, "last_value"), - input_types[0].clone(), + format_state_name(args.name, "last_value"), + args.return_field.data_type().clone(), true, - )]; - fields.extend(ordering_fields.to_vec()); - fields.push(Field::new("is_set", DataType::Boolean, true)); + ) + .into()]; + fields.extend(args.ordering_fields.iter().cloned()); + fields.push(Field::new("is_set", DataType::Boolean, true).into()); Ok(fields) } - fn aliases(&self) -> &[String] { - &[] - } - fn with_beneficial_ordering( self: Arc, beneficial_ordering: bool, ) -> Result>> { - Ok(Some(Arc::new( - LastValue::new().with_requirement_satisfied(beneficial_ordering), - ))) + Ok(Some(Arc::new(Self { + signature: self.signature.clone(), + is_input_pre_ordered: beneficial_ordering, + }))) } fn order_sensitivity(&self) -> AggregateOrderSensitivity { AggregateOrderSensitivity::Beneficial } - fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { - datafusion_expr::ReversedUDAF::Reversed(first_value_udaf()) + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Reversed(first_value_udaf()) } fn documentation(&self) -> Option<&Documentation> { self.doc() } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + use DataType::*; + !args.order_bys.is_empty() + && matches!( + args.return_field.data_type(), + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float16 + | Float32 + | Float64 + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + ) + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + fn create_accumulator( + args: AccumulatorArgs, + ) -> Result> + where + T: ArrowPrimitiveType + Send, + { + let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else { + return internal_err!("Groups accumulator must have an ordering."); + }; + + let ordering_dtypes = ordering + .iter() + .map(|e| e.expr.data_type(args.schema)) + .collect::>>()?; + + Ok(Box::new(FirstPrimitiveGroupsAccumulator::::try_new( + ordering, + args.ignore_nulls, + args.return_field.data_type(), + &ordering_dtypes, + false, + )?)) + } + + match args.return_field.data_type() { + DataType::Int8 => create_accumulator::(args), + DataType::Int16 => create_accumulator::(args), + DataType::Int32 => create_accumulator::(args), + DataType::Int64 => create_accumulator::(args), + DataType::UInt8 => create_accumulator::(args), + DataType::UInt16 => create_accumulator::(args), + DataType::UInt32 => create_accumulator::(args), + DataType::UInt64 => create_accumulator::(args), + DataType::Float16 => create_accumulator::(args), + DataType::Float32 => create_accumulator::(args), + DataType::Float64 => create_accumulator::(args), + + DataType::Decimal32(_, _) => create_accumulator::(args), + DataType::Decimal64(_, _) => create_accumulator::(args), + DataType::Decimal128(_, _) => create_accumulator::(args), + DataType::Decimal256(_, _) => create_accumulator::(args), + + DataType::Timestamp(TimeUnit::Second, _) => { + create_accumulator::(args) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + create_accumulator::(args) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + create_accumulator::(args) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + create_accumulator::(args) + } + + DataType::Date32 => create_accumulator::(args), + DataType::Date64 => create_accumulator::(args), + DataType::Time32(TimeUnit::Second) => { + create_accumulator::(args) + } + DataType::Time32(TimeUnit::Millisecond) => { + create_accumulator::(args) + } + + DataType::Time64(TimeUnit::Microsecond) => { + create_accumulator::(args) + } + DataType::Time64(TimeUnit::Nanosecond) => { + create_accumulator::(args) + } + + _ => { + internal_err!( + "GroupsAccumulator not supported for last_value({})", + args.return_field.data_type() + ) + } + } + } +} + +/// This accumulator is used when there is no ordering specified for the +/// `LAST_VALUE` aggregation. It simply updates the last value it sees +/// according to the pre-existing ordering of the input data, and provides +/// a fast path for this case without needing to maintain any ordering state. +#[derive(Debug)] +pub struct TrivialLastValueAccumulator { + last: ScalarValue, + // The `is_set` flag keeps track of whether the last value is finalized. + // This information is used to discriminate genuine NULLs and NULLS that + // occur due to empty partitions. + is_set: bool, + // Ignore null values. + ignore_nulls: bool, +} + +impl TrivialLastValueAccumulator { + /// Creates a new `TrivialLastValueAccumulator` for the given `data_type`. + pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result { + ScalarValue::try_from(data_type).map(|last| Self { + last, + is_set: false, + ignore_nulls, + }) + } +} + +impl Accumulator for TrivialLastValueAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.last.clone(), ScalarValue::from(self.is_set)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // Get last entry according to the pre-existing ordering (0th index): + let value = &values[0]; + let mut last_idx = None; + if self.ignore_nulls { + // If ignoring nulls, find the last non-null value. + for i in (0..value.len()).rev() { + if !value.is_null(i) { + last_idx = Some(i); + break; + } + } + } else if !value.is_empty() { + // If not ignoring nulls, return the last value if it exists. + last_idx = Some(value.len() - 1); + } + if let Some(last_idx) = last_idx { + let mut row = get_row_at_idx(values, last_idx)?; + self.last = row.swap_remove(0); + self.last.compact(); + self.is_set = true; + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // LAST_VALUE(last1, last2, last3, ...) + // Second index contains is_set flag. + let flags = states[1].as_boolean(); + let filtered_states = filter_states_according_to_is_set(&states[0..1], flags)?; + if let Some(last) = filtered_states.last() { + if !last.is_empty() { + self.last = ScalarValue::try_from_array(last, 0)?; + self.is_set = true; + } + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(self.last.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.last) + self.last.size() + } } #[derive(Debug)] @@ -1061,11 +1311,13 @@ struct LastValueAccumulator { // This information is used to discriminate genuine NULLs and NULLS that // occur due to empty partitions. is_set: bool, + // Stores values of the ordering columns corresponding to the first value. + // These values are used during merging of multiple partitions. orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, // Stores whether incoming data already satisfies the ordering requirement. - requirement_satisfied: bool, + is_input_pre_ordered: bool, // Ignore null values. ignore_nulls: bool, } @@ -1076,27 +1328,31 @@ impl LastValueAccumulator { data_type: &DataType, ordering_dtypes: &[DataType], ordering_req: LexOrdering, + is_input_pre_ordered: bool, ignore_nulls: bool, ) -> Result { let orderings = ordering_dtypes .iter() .map(ScalarValue::try_from) - .collect::>>()?; - let requirement_satisfied = ordering_req.is_empty(); + .collect::>()?; ScalarValue::try_from(data_type).map(|last| Self { last, is_set: false, orderings, ordering_req, - requirement_satisfied, + is_input_pre_ordered, ignore_nulls, }) } // Updates state with the values in the given row. - fn update_with_new_row(&mut self, row: &[ScalarValue]) { - self.last = row[0].clone(); - self.orderings = row[1..].to_vec(); + fn update_with_new_row(&mut self, mut row: Vec) { + // Ensure any Array based scalars hold have a single value to reduce memory pressure + for s in row.iter_mut() { + s.compact(); + } + self.last = row.remove(0); + self.orderings = row; self.is_set = true; } @@ -1104,7 +1360,7 @@ impl LastValueAccumulator { let [value, ordering_values @ ..] = values else { return internal_err!("Empty row in LAST_VALUE"); }; - if self.requirement_satisfied { + if self.is_input_pre_ordered { // Get last entry according to the order of data: if self.ignore_nulls { // If ignoring nulls, find the last non-null value. @@ -1118,6 +1374,7 @@ impl LastValueAccumulator { return Ok((!value.is_empty()).then_some(value.len() - 1)); } } + let sort_columns = ordering_values .iter() .zip(self.ordering_req.iter()) @@ -1138,42 +1395,33 @@ impl LastValueAccumulator { Ok(max_ind) } - - fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { - self.requirement_satisfied = requirement_satisfied; - self - } } impl Accumulator for LastValueAccumulator { fn state(&mut self) -> Result> { let mut result = vec![self.last.clone()]; result.extend(self.orderings.clone()); - result.push(ScalarValue::Boolean(Some(self.is_set))); + result.push(ScalarValue::from(self.is_set)); Ok(result) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if !self.is_set || self.requirement_satisfied { - if let Some(last_idx) = self.get_last_idx(values)? { - let row = get_row_at_idx(values, last_idx)?; - self.update_with_new_row(&row); - } - } else if let Some(last_idx) = self.get_last_idx(values)? { + if let Some(last_idx) = self.get_last_idx(values)? { let row = get_row_at_idx(values, last_idx)?; let orderings = &row[1..]; // Update when there is a more recent entry - if compare_rows( - &self.orderings, - orderings, - &get_sort_options(self.ordering_req.as_ref()), - )? - .is_lt() + if !self.is_set + || self.is_input_pre_ordered + || compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )? + .is_lt() { - self.update_with_new_row(&row); + self.update_with_new_row(row); } } - Ok(()) } @@ -1185,29 +1433,29 @@ impl Accumulator for LastValueAccumulator { let filtered_states = filter_states_according_to_is_set(&states[0..is_set_idx], flags)?; // 1..is_set_idx range corresponds to ordering section - let sort_columns = convert_to_sort_cols( - &filtered_states[1..is_set_idx], - self.ordering_req.as_ref(), - ); + let sort_columns = + convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req); let comparator = LexicographicalComparator::try_new(&sort_columns)?; let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b)); if let Some(last_idx) = max { - let last_row = get_row_at_idx(&filtered_states, last_idx)?; + let mut last_row = get_row_at_idx(&filtered_states, last_idx)?; // When collecting orderings, we exclude the is_set flag from the state. let last_ordering = &last_row[1..is_set_idx]; - let sort_options = get_sort_options(self.ordering_req.as_ref()); + let sort_options = get_sort_options(&self.ordering_req); // Either there is no existing value, or there is a newer (latest) // version in the new data: if !self.is_set - || self.requirement_satisfied + || self.is_input_pre_ordered || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt() { // Update with last value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state // containing two is_set flags. - self.update_with_new_row(&last_row[0..is_set_idx]); + assert!(is_set_idx <= last_row.len()); + last_row.resize(is_set_idx, ScalarValue::Null); + self.update_with_new_row(last_row); } } Ok(()) @@ -1234,7 +1482,7 @@ fn filter_states_according_to_is_set( states .iter() .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e))) - .collect::>>() + .collect() } /// Combines array refs and their corresponding orderings to construct `SortColumn`s. @@ -1245,30 +1493,28 @@ fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec>() + .collect() } #[cfg(test)] mod tests { - use arrow::{array::Int64Array, compute::SortOptions, datatypes::Schema}; + use std::iter::repeat_with; + + use arrow::{ + array::{Int64Array, ListArray}, + compute::SortOptions, + datatypes::Schema, + }; use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; use super::*; #[test] fn test_first_last_value_value() -> Result<()> { - let mut first_accumulator = FirstValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; - let mut last_accumulator = LastValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut first_accumulator = + TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?; + let mut last_accumulator = + TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?; // first value in the tuple is start of the range (inclusive), // second value in the tuple is end of the range (exclusive) let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)]; @@ -1305,22 +1551,14 @@ mod tests { .collect::>(); // FirstValueAccumulator - let mut first_accumulator = FirstValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut first_accumulator = + TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?; first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?; let state1 = first_accumulator.state()?; - let mut first_accumulator = FirstValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut first_accumulator = + TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?; first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?; let state2 = first_accumulator.state()?; @@ -1335,34 +1573,22 @@ mod tests { ])?); } - let mut first_accumulator = FirstValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut first_accumulator = + TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?; first_accumulator.merge_batch(&states)?; let merged_state = first_accumulator.state()?; assert_eq!(merged_state.len(), state1.len()); // LastValueAccumulator - let mut last_accumulator = LastValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut last_accumulator = + TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?; last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?; let state1 = last_accumulator.state()?; - let mut last_accumulator = LastValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut last_accumulator = + TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?; last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?; let state2 = last_accumulator.state()?; @@ -1377,12 +1603,8 @@ mod tests { ])?); } - let mut last_accumulator = LastValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut last_accumulator = + TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?; last_accumulator.merge_batch(&states)?; let merged_state = last_accumulator.state()?; @@ -1392,7 +1614,7 @@ mod tests { } #[test] - fn test_frist_group_acc() -> Result<()> { + fn test_first_group_acc() -> Result<()> { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, true), Field::new("b", DataType::Int64, true), @@ -1401,16 +1623,17 @@ mod tests { Field::new("e", DataType::Boolean, true), ])); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + let sort_keys = [PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]); + }]; let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( - sort_key, + sort_keys.into(), true, &DataType::Int64, &[DataType::Int64], + true, )?; let mut val_with_orderings = { @@ -1485,7 +1708,7 @@ mod tests { } #[test] - fn test_frist_group_acc_size_of_ordering() -> Result<()> { + fn test_group_acc_size_of_ordering() -> Result<()> { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, true), Field::new("b", DataType::Int64, true), @@ -1494,16 +1717,17 @@ mod tests { Field::new("e", DataType::Boolean, true), ])); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + let sort_keys = [PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]); + }]; let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( - sort_key, + sort_keys.into(), true, &DataType::Int64, &[DataType::Int64], + true, )?; let val_with_orderings = { @@ -1563,4 +1787,131 @@ mod tests { Ok(()) } + + #[test] + fn test_last_group_acc() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Boolean, true), + ])); + + let sort_keys = [PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: SortOptions::default(), + }]; + + let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( + sort_keys.into(), + true, + &DataType::Int64, + &[DataType::Int64], + false, + )?; + + let mut val_with_orderings = { + let mut val_with_orderings = Vec::::new(); + + let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)])); + let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6])); + + val_with_orderings.push(vals); + val_with_orderings.push(orderings); + + val_with_orderings + }; + + group_acc.update_batch( + &val_with_orderings, + &[0, 1, 2, 1], + Some(&BooleanArray::from(vec![true, true, false, true])), + 3, + )?; + + let state = group_acc.state(EmitTo::All)?; + + let expected_state: Vec> = vec![ + Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])), + Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])), + Arc::new(BooleanArray::from(vec![true, true, false])), + ]; + assert_eq!(state, expected_state); + + group_acc.merge_batch( + &state, + &[0, 1, 2], + Some(&BooleanArray::from(vec![true, false, false])), + 3, + )?; + + val_with_orderings.clear(); + val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6]))); + val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6]))); + + group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?; + + let binding = group_acc.evaluate(EmitTo::All)?; + let eval_result = binding.as_any().downcast_ref::().unwrap(); + + let expect: PrimitiveArray = + Int64Array::from(vec![Some(1), Some(66), Some(6), None]); + + assert_eq!(eval_result, &expect); + + Ok(()) + } + + #[test] + fn test_first_list_acc_size() -> Result<()> { + fn size_after_batch(values: &[ArrayRef]) -> Result { + let mut first_accumulator = TrivialFirstValueAccumulator::try_new( + &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))), + false, + )?; + + first_accumulator.update_batch(values)?; + + Ok(first_accumulator.size()) + } + + let batch1 = ListArray::from_iter_primitive::( + repeat_with(|| Some(vec![Some(1)])).take(10000), + ); + let batch2 = + ListArray::from_iter_primitive::([Some(vec![Some(1)])]); + + let size1 = size_after_batch(&[Arc::new(batch1)])?; + let size2 = size_after_batch(&[Arc::new(batch2)])?; + assert_eq!(size1, size2); + + Ok(()) + } + + #[test] + fn test_last_list_acc_size() -> Result<()> { + fn size_after_batch(values: &[ArrayRef]) -> Result { + let mut last_accumulator = TrivialLastValueAccumulator::try_new( + &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))), + false, + )?; + + last_accumulator.update_batch(values)?; + + Ok(last_accumulator.size()) + } + + let batch1 = ListArray::from_iter_primitive::( + repeat_with(|| Some(vec![Some(1)])).take(10000), + ); + let batch2 = + ListArray::from_iter_primitive::([Some(vec![Some(1)])]); + + let size1 = size_after_batch(&[Arc::new(batch1)])?; + let size2 = size_after_batch(&[Arc::new(batch2)])?; + assert_eq!(size1, size2); + + Ok(()) + } } diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 445774ff11e7d..4d1da1dad5949 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -20,8 +20,8 @@ use std::any::Any; use std::fmt; -use arrow::datatypes::DataType; use arrow::datatypes::Field; +use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::{not_impl_err, Result}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; @@ -60,6 +60,7 @@ make_udaf_expr_and_func!( description = "Expression to evaluate whether data is aggregated across the specified column. Can be a constant, column, or function." ) )] +#[derive(PartialEq, Eq, Hash)] pub struct Grouping { signature: Signature, } @@ -105,12 +106,13 @@ impl AggregateUDFImpl for Grouping { Ok(DataType::Int32) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, "grouping"), DataType::Int32, true, - )]) + ) + .into()]) } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 7944280291eb4..4f282301ce5bd 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -105,6 +105,7 @@ pub mod expr_fn { pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; pub use super::array_agg::array_agg; pub use super::average::avg; + pub use super::average::avg_distinct; pub use super::bit_and_or_xor::bit_and; pub use super::bit_and_or_xor::bit_or; pub use super::bit_and_or_xor::bit_xor; @@ -134,6 +135,7 @@ pub mod expr_fn { pub use super::stddev::stddev; pub use super::stddev::stddev_pop; pub use super::sum::sum; + pub use super::sum::sum_distinct; pub use super::variance::var_pop; pub use super::variance::var_sample; } @@ -220,8 +222,7 @@ mod tests { for alias in func.aliases() { assert!( names.insert(alias.to_string().to_lowercase()), - "duplicate function name: {}", - alias + "duplicate function name: {alias}" ); } } diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index b464dde6ccab5..6c6bf72838899 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +#[macro_export] macro_rules! make_udaf_expr { ($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { // "fluent expr_fn" style function @@ -27,13 +28,14 @@ macro_rules! make_udaf_expr { vec![$($arg),*], false, None, - None, + vec![], None, )) } }; } +#[macro_export] macro_rules! make_udaf_expr_and_func { ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { make_udaf_expr!($EXPR_FN, $($arg)*, $DOC, $AGGREGATE_UDF_FN); @@ -50,7 +52,7 @@ macro_rules! make_udaf_expr_and_func { args, false, None, - None, + vec![], None, )) } @@ -59,6 +61,7 @@ macro_rules! make_udaf_expr_and_func { }; } +#[macro_export] macro_rules! create_func { ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default()); diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index ba6b63260e068..a65759594eac2 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -35,7 +35,9 @@ use arrow::{ use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; -use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType}; +use arrow::datatypes::{ + ArrowNativeType, ArrowPrimitiveType, Decimal32Type, Decimal64Type, FieldRef, +}; use datafusion_common::{ internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue, @@ -81,6 +83,7 @@ make_udaf_expr_and_func!( /// If using the distinct variation, the memory usage will be similarly high if the /// cardinality is high as it stores all distinct values in memory before computing the /// result, but if cardinality is low then memory usage will also be lower. +#[derive(PartialEq, Eq, Hash)] pub struct Median { signature: Signature, } @@ -125,9 +128,9 @@ impl AggregateUDFImpl for Median { Ok(arg_types[0].clone()) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { //Intermediate state is a list of the elements we have collected so far - let field = Field::new_list_field(args.input_types[0].clone(), true); + let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true); let state_name = if args.is_distinct { "distinct_median" } else { @@ -138,7 +141,8 @@ impl AggregateUDFImpl for Median { format_state_name(args.name, state_name), DataType::List(Arc::new(field)), true, - )]) + ) + .into()]) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -164,6 +168,8 @@ impl AggregateUDFImpl for Median { DataType::Float16 => helper!(Float16Type, dt), DataType::Float32 => helper!(Float32Type, dt), DataType::Float64 => helper!(Float64Type, dt), + DataType::Decimal32(_, _) => helper!(Decimal32Type, dt), + DataType::Decimal64(_, _) => helper!(Decimal64Type, dt), DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), _ => Err(DataFusionError::NotImplemented(format!( @@ -203,6 +209,8 @@ impl AggregateUDFImpl for Median { DataType::Float16 => helper!(Float16Type, dt), DataType::Float32 => helper!(Float32Type, dt), DataType::Float64 => helper!(Float64Type, dt), + DataType::Decimal32(_, _) => helper!(Decimal32Type, dt), + DataType::Decimal64(_, _) => helper!(Decimal64Type, dt), DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), _ => Err(DataFusionError::NotImplemented(format!( @@ -213,10 +221,6 @@ impl AggregateUDFImpl for Median { } } - fn aliases(&self) -> &[String] { - &[] - } - fn documentation(&self) -> Option<&Documentation> { self.doc() } diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index ea4cad5488031..1a46afefffb3b 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -19,29 +19,17 @@ //! [`Min`] and [`MinAccumulator`] accumulator for the `min` function mod min_max_bytes; +mod min_max_struct; -use arrow::array::{ - ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, - Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray, - DurationNanosecondArray, DurationSecondArray, Float16Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, - IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, - LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray, - Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, -}; -use arrow::compute; +use arrow::array::ArrayRef; use arrow::datatypes::{ - DataType, Decimal128Type, Decimal256Type, DurationMicrosecondType, - DurationMillisecondType, DurationNanosecondType, DurationSecondType, Float16Type, - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalUnit, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, + DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, + DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, + DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, + Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use datafusion_common::stats::Precision; -use datafusion_common::{ - downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, Result, -}; +use datafusion_common::{exec_err, internal_err, ColumnStatistics, Result}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use datafusion_physical_expr::expressions; use std::cmp::Ordering; @@ -55,6 +43,7 @@ use arrow::datatypes::{ }; use crate::min_max::min_max_bytes::MinMaxBytesAccumulator; +use crate::min_max::min_max_struct::MinMaxStructAccumulator; use datafusion_common::ScalarValue; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation, @@ -102,7 +91,7 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result> { standard_argument(name = "expression",) )] // MAX aggregate UDF -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Max { signature: Signature, } @@ -231,17 +220,15 @@ impl AggregateUDFImpl for Max { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - Ok(Box::new(MaxAccumulator::try_new(acc_args.return_type)?)) - } - - fn aliases(&self) -> &[String] { - &[] + Ok(Box::new(MaxAccumulator::try_new( + acc_args.return_field.data_type(), + )?)) } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - args.return_type, + args.return_field.data_type(), Int8 | Int16 | Int32 | Int64 @@ -252,6 +239,8 @@ impl AggregateUDFImpl for Max { | Float16 | Float32 | Float64 + | Decimal32(_, _) + | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _) | Date32 @@ -266,6 +255,7 @@ impl AggregateUDFImpl for Max { | LargeBinary | BinaryView | Duration(_) + | Struct(_) ) } @@ -275,7 +265,7 @@ impl AggregateUDFImpl for Max { ) -> Result> { use DataType::*; use TimeUnit::*; - let data_type = args.return_type; + let data_type = args.return_field.data_type(); match data_type { Int8 => primitive_max_accumulator!(data_type, i8, Int8Type), Int16 => primitive_max_accumulator!(data_type, i16, Int16Type), @@ -332,6 +322,12 @@ impl AggregateUDFImpl for Max { Duration(Nanosecond) => { primitive_max_accumulator!(data_type, i64, DurationNanosecondType) } + Decimal32(_, _) => { + primitive_max_accumulator!(data_type, i32, Decimal32Type) + } + Decimal64(_, _) => { + primitive_max_accumulator!(data_type, i64, Decimal64Type) + } Decimal128(_, _) => { primitive_max_accumulator!(data_type, i128, Decimal128Type) } @@ -341,7 +337,9 @@ impl AggregateUDFImpl for Max { Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone()))) } - + Struct(_) => Ok(Box::new(MinMaxStructAccumulator::new_max( + data_type.clone(), + ))), // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), } @@ -351,7 +349,9 @@ impl AggregateUDFImpl for Max { &self, args: AccumulatorArgs, ) -> Result> { - Ok(Box::new(SlidingMaxAccumulator::try_new(args.return_type)?)) + Ok(Box::new(SlidingMaxAccumulator::try_new( + args.return_field.data_type(), + )?)) } fn is_descending(&self) -> Option { @@ -383,597 +383,6 @@ impl AggregateUDFImpl for Max { } } -// Statically-typed version of min/max(array) -> ScalarValue for string types -macro_rules! typed_min_max_batch_string { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let value = compute::$OP(array); - let value = value.and_then(|e| Some(e.to_string())); - ScalarValue::$SCALAR(value) - }}; -} -// Statically-typed version of min/max(array) -> ScalarValue for binary types. -macro_rules! typed_min_max_batch_binary { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let value = compute::$OP(array); - let value = value.and_then(|e| Some(e.to_vec())); - ScalarValue::$SCALAR(value) - }}; -} - -// Statically-typed version of min/max(array) -> ScalarValue for non-string types. -macro_rules! typed_min_max_batch { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let value = compute::$OP(array); - ScalarValue::$SCALAR(value, $($EXTRA_ARGS.clone()),*) - }}; -} - -// Statically-typed version of min/max(array) -> ScalarValue for non-string types. -// this is a macro to support both operations (min and max). -macro_rules! min_max_batch { - ($VALUES:expr, $OP:ident) => {{ - match $VALUES.data_type() { - DataType::Null => ScalarValue::Null, - DataType::Decimal128(precision, scale) => { - typed_min_max_batch!( - $VALUES, - Decimal128Array, - Decimal128, - $OP, - precision, - scale - ) - } - DataType::Decimal256(precision, scale) => { - typed_min_max_batch!( - $VALUES, - Decimal256Array, - Decimal256, - $OP, - precision, - scale - ) - } - // all types that have a natural order - DataType::Float64 => { - typed_min_max_batch!($VALUES, Float64Array, Float64, $OP) - } - DataType::Float32 => { - typed_min_max_batch!($VALUES, Float32Array, Float32, $OP) - } - DataType::Float16 => { - typed_min_max_batch!($VALUES, Float16Array, Float16, $OP) - } - DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP), - DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP), - DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP), - DataType::Int8 => typed_min_max_batch!($VALUES, Int8Array, Int8, $OP), - DataType::UInt64 => typed_min_max_batch!($VALUES, UInt64Array, UInt64, $OP), - DataType::UInt32 => typed_min_max_batch!($VALUES, UInt32Array, UInt32, $OP), - DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP), - DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP), - DataType::Timestamp(TimeUnit::Second, tz_opt) => { - typed_min_max_batch!( - $VALUES, - TimestampSecondArray, - TimestampSecond, - $OP, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_min_max_batch!( - $VALUES, - TimestampMillisecondArray, - TimestampMillisecond, - $OP, - tz_opt - ), - DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_min_max_batch!( - $VALUES, - TimestampMicrosecondArray, - TimestampMicrosecond, - $OP, - tz_opt - ), - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_min_max_batch!( - $VALUES, - TimestampNanosecondArray, - TimestampNanosecond, - $OP, - tz_opt - ), - DataType::Date32 => typed_min_max_batch!($VALUES, Date32Array, Date32, $OP), - DataType::Date64 => typed_min_max_batch!($VALUES, Date64Array, Date64, $OP), - DataType::Time32(TimeUnit::Second) => { - typed_min_max_batch!($VALUES, Time32SecondArray, Time32Second, $OP) - } - DataType::Time32(TimeUnit::Millisecond) => { - typed_min_max_batch!( - $VALUES, - Time32MillisecondArray, - Time32Millisecond, - $OP - ) - } - DataType::Time64(TimeUnit::Microsecond) => { - typed_min_max_batch!( - $VALUES, - Time64MicrosecondArray, - Time64Microsecond, - $OP - ) - } - DataType::Time64(TimeUnit::Nanosecond) => { - typed_min_max_batch!( - $VALUES, - Time64NanosecondArray, - Time64Nanosecond, - $OP - ) - } - DataType::Interval(IntervalUnit::YearMonth) => { - typed_min_max_batch!( - $VALUES, - IntervalYearMonthArray, - IntervalYearMonth, - $OP - ) - } - DataType::Interval(IntervalUnit::DayTime) => { - typed_min_max_batch!($VALUES, IntervalDayTimeArray, IntervalDayTime, $OP) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - typed_min_max_batch!( - $VALUES, - IntervalMonthDayNanoArray, - IntervalMonthDayNano, - $OP - ) - } - DataType::Duration(TimeUnit::Second) => { - typed_min_max_batch!($VALUES, DurationSecondArray, DurationSecond, $OP) - } - DataType::Duration(TimeUnit::Millisecond) => { - typed_min_max_batch!( - $VALUES, - DurationMillisecondArray, - DurationMillisecond, - $OP - ) - } - DataType::Duration(TimeUnit::Microsecond) => { - typed_min_max_batch!( - $VALUES, - DurationMicrosecondArray, - DurationMicrosecond, - $OP - ) - } - DataType::Duration(TimeUnit::Nanosecond) => { - typed_min_max_batch!( - $VALUES, - DurationNanosecondArray, - DurationNanosecond, - $OP - ) - } - other => { - // This should have been handled before - return internal_err!( - "Min/Max accumulator not implemented for type {:?}", - other - ); - } - } - }}; -} - -/// dynamically-typed min(array) -> ScalarValue -fn min_batch(values: &ArrayRef) -> Result { - Ok(match values.data_type() { - DataType::Utf8 => { - typed_min_max_batch_string!(values, StringArray, Utf8, min_string) - } - DataType::LargeUtf8 => { - typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string) - } - DataType::Utf8View => { - typed_min_max_batch_string!( - values, - StringViewArray, - Utf8View, - min_string_view - ) - } - DataType::Boolean => { - typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean) - } - DataType::Binary => { - typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary) - } - DataType::LargeBinary => { - typed_min_max_batch_binary!( - &values, - LargeBinaryArray, - LargeBinary, - min_binary - ) - } - DataType::BinaryView => { - typed_min_max_batch_binary!( - &values, - BinaryViewArray, - BinaryView, - min_binary_view - ) - } - _ => min_max_batch!(values, min), - }) -} - -/// dynamically-typed max(array) -> ScalarValue -pub fn max_batch(values: &ArrayRef) -> Result { - Ok(match values.data_type() { - DataType::Utf8 => { - typed_min_max_batch_string!(values, StringArray, Utf8, max_string) - } - DataType::LargeUtf8 => { - typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string) - } - DataType::Utf8View => { - typed_min_max_batch_string!( - values, - StringViewArray, - Utf8View, - max_string_view - ) - } - DataType::Boolean => { - typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean) - } - DataType::Binary => { - typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary) - } - DataType::BinaryView => { - typed_min_max_batch_binary!( - &values, - BinaryViewArray, - BinaryView, - max_binary_view - ) - } - DataType::LargeBinary => { - typed_min_max_batch_binary!( - &values, - LargeBinaryArray, - LargeBinary, - max_binary - ) - } - _ => min_max_batch!(values, max), - }) -} - -// min/max of two non-string scalar values. -macro_rules! typed_min_max { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ - ScalarValue::$SCALAR( - match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(*a), - (None, Some(b)) => Some(*b), - (Some(a), Some(b)) => Some((*a).$OP(*b)), - }, - $($EXTRA_ARGS.clone()),* - ) - }}; -} -macro_rules! typed_min_max_float { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ - ScalarValue::$SCALAR(match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(*a), - (None, Some(b)) => Some(*b), - (Some(a), Some(b)) => match a.total_cmp(b) { - choose_min_max!($OP) => Some(*b), - _ => Some(*a), - }, - }) - }}; -} - -// min/max of two scalar string values. -macro_rules! typed_min_max_string { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ - ScalarValue::$SCALAR(match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(a.clone()), - (None, Some(b)) => Some(b.clone()), - (Some(a), Some(b)) => Some((a).$OP(b).clone()), - }) - }}; -} - -macro_rules! choose_min_max { - (min) => { - std::cmp::Ordering::Greater - }; - (max) => { - std::cmp::Ordering::Less - }; -} - -macro_rules! interval_min_max { - ($OP:tt, $LHS:expr, $RHS:expr) => {{ - match $LHS.partial_cmp(&$RHS) { - Some(choose_min_max!($OP)) => $RHS.clone(), - Some(_) => $LHS.clone(), - None => { - return internal_err!("Comparison error while computing interval min/max") - } - } - }}; -} - -// min/max of two scalar values of the same type -macro_rules! min_max { - ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ - Ok(match ($VALUE, $DELTA) { - (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null, - ( - lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss), - rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss) - ) => { - if lhsp.eq(rhsp) && lhss.eq(rhss) { - typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss) - } else { - return internal_err!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - (lhs, rhs) - ); - } - } - ( - lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss), - rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss) - ) => { - if lhsp.eq(rhsp) && lhss.eq(rhss) { - typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss) - } else { - return internal_err!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - (lhs, rhs) - ); - } - } - (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { - typed_min_max!(lhs, rhs, Boolean, $OP) - } - (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { - typed_min_max_float!(lhs, rhs, Float64, $OP) - } - (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { - typed_min_max_float!(lhs, rhs, Float32, $OP) - } - (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => { - typed_min_max_float!(lhs, rhs, Float16, $OP) - } - (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { - typed_min_max!(lhs, rhs, UInt64, $OP) - } - (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { - typed_min_max!(lhs, rhs, UInt32, $OP) - } - (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { - typed_min_max!(lhs, rhs, UInt16, $OP) - } - (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { - typed_min_max!(lhs, rhs, UInt8, $OP) - } - (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { - typed_min_max!(lhs, rhs, Int64, $OP) - } - (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { - typed_min_max!(lhs, rhs, Int32, $OP) - } - (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { - typed_min_max!(lhs, rhs, Int16, $OP) - } - (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { - typed_min_max!(lhs, rhs, Int8, $OP) - } - (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => { - typed_min_max_string!(lhs, rhs, Utf8, $OP) - } - (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { - typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) - } - (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => { - typed_min_max_string!(lhs, rhs, Utf8View, $OP) - } - (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { - typed_min_max_string!(lhs, rhs, Binary, $OP) - } - (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { - typed_min_max_string!(lhs, rhs, LargeBinary, $OP) - } - (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => { - typed_min_max_string!(lhs, rhs, BinaryView, $OP) - } - (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { - typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) - } - ( - ScalarValue::TimestampMillisecond(lhs, l_tz), - ScalarValue::TimestampMillisecond(rhs, _), - ) => { - typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz) - } - ( - ScalarValue::TimestampMicrosecond(lhs, l_tz), - ScalarValue::TimestampMicrosecond(rhs, _), - ) => { - typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz) - } - ( - ScalarValue::TimestampNanosecond(lhs, l_tz), - ScalarValue::TimestampNanosecond(rhs, _), - ) => { - typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz) - } - ( - ScalarValue::Date32(lhs), - ScalarValue::Date32(rhs), - ) => { - typed_min_max!(lhs, rhs, Date32, $OP) - } - ( - ScalarValue::Date64(lhs), - ScalarValue::Date64(rhs), - ) => { - typed_min_max!(lhs, rhs, Date64, $OP) - } - ( - ScalarValue::Time32Second(lhs), - ScalarValue::Time32Second(rhs), - ) => { - typed_min_max!(lhs, rhs, Time32Second, $OP) - } - ( - ScalarValue::Time32Millisecond(lhs), - ScalarValue::Time32Millisecond(rhs), - ) => { - typed_min_max!(lhs, rhs, Time32Millisecond, $OP) - } - ( - ScalarValue::Time64Microsecond(lhs), - ScalarValue::Time64Microsecond(rhs), - ) => { - typed_min_max!(lhs, rhs, Time64Microsecond, $OP) - } - ( - ScalarValue::Time64Nanosecond(lhs), - ScalarValue::Time64Nanosecond(rhs), - ) => { - typed_min_max!(lhs, rhs, Time64Nanosecond, $OP) - } - ( - ScalarValue::IntervalYearMonth(lhs), - ScalarValue::IntervalYearMonth(rhs), - ) => { - typed_min_max!(lhs, rhs, IntervalYearMonth, $OP) - } - ( - ScalarValue::IntervalMonthDayNano(lhs), - ScalarValue::IntervalMonthDayNano(rhs), - ) => { - typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP) - } - ( - ScalarValue::IntervalDayTime(lhs), - ScalarValue::IntervalDayTime(rhs), - ) => { - typed_min_max!(lhs, rhs, IntervalDayTime, $OP) - } - ( - ScalarValue::IntervalYearMonth(_), - ScalarValue::IntervalMonthDayNano(_), - ) | ( - ScalarValue::IntervalYearMonth(_), - ScalarValue::IntervalDayTime(_), - ) | ( - ScalarValue::IntervalMonthDayNano(_), - ScalarValue::IntervalDayTime(_), - ) | ( - ScalarValue::IntervalMonthDayNano(_), - ScalarValue::IntervalYearMonth(_), - ) | ( - ScalarValue::IntervalDayTime(_), - ScalarValue::IntervalYearMonth(_), - ) | ( - ScalarValue::IntervalDayTime(_), - ScalarValue::IntervalMonthDayNano(_), - ) => { - interval_min_max!($OP, $VALUE, $DELTA) - } - ( - ScalarValue::DurationSecond(lhs), - ScalarValue::DurationSecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationSecond, $OP) - } - ( - ScalarValue::DurationMillisecond(lhs), - ScalarValue::DurationMillisecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationMillisecond, $OP) - } - ( - ScalarValue::DurationMicrosecond(lhs), - ScalarValue::DurationMicrosecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationMicrosecond, $OP) - } - ( - ScalarValue::DurationNanosecond(lhs), - ScalarValue::DurationNanosecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationNanosecond, $OP) - } - e => { - return internal_err!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - e - ) - } - }) - }}; -} - -/// An accumulator to compute the maximum value -#[derive(Debug)] -pub struct MaxAccumulator { - max: ScalarValue, -} - -impl MaxAccumulator { - /// new max accumulator - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - max: ScalarValue::try_from(datatype)?, - }) - } -} - -impl Accumulator for MaxAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - let delta = &max_batch(values)?; - let new_max: Result = - min_max!(&self.max, delta, max); - self.max = new_max?; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - fn evaluate(&mut self) -> Result { - Ok(self.max.clone()) - } - - fn size(&self) -> usize { - size_of_val(self) - size_of_val(&self.max) + self.max.size() - } -} - #[derive(Debug)] pub struct SlidingMaxAccumulator { max: ScalarValue, @@ -1047,7 +456,7 @@ impl Accumulator for SlidingMaxAccumulator { ```"#, standard_argument(name = "expression",) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Min { signature: Signature, } @@ -1098,17 +507,15 @@ impl AggregateUDFImpl for Min { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - Ok(Box::new(MinAccumulator::try_new(acc_args.return_type)?)) - } - - fn aliases(&self) -> &[String] { - &[] + Ok(Box::new(MinAccumulator::try_new( + acc_args.return_field.data_type(), + )?)) } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - args.return_type, + args.return_field.data_type(), Int8 | Int16 | Int32 | Int64 @@ -1119,6 +526,8 @@ impl AggregateUDFImpl for Min { | Float16 | Float32 | Float64 + | Decimal32(_, _) + | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _) | Date32 @@ -1133,6 +542,7 @@ impl AggregateUDFImpl for Min { | LargeBinary | BinaryView | Duration(_) + | Struct(_) ) } @@ -1142,7 +552,7 @@ impl AggregateUDFImpl for Min { ) -> Result> { use DataType::*; use TimeUnit::*; - let data_type = args.return_type; + let data_type = args.return_field.data_type(); match data_type { Int8 => primitive_min_accumulator!(data_type, i8, Int8Type), Int16 => primitive_min_accumulator!(data_type, i16, Int16Type), @@ -1199,6 +609,12 @@ impl AggregateUDFImpl for Min { Duration(Nanosecond) => { primitive_min_accumulator!(data_type, i64, DurationNanosecondType) } + Decimal32(_, _) => { + primitive_min_accumulator!(data_type, i32, Decimal32Type) + } + Decimal64(_, _) => { + primitive_min_accumulator!(data_type, i64, Decimal64Type) + } Decimal128(_, _) => { primitive_min_accumulator!(data_type, i128, Decimal128Type) } @@ -1208,7 +624,9 @@ impl AggregateUDFImpl for Min { Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone()))) } - + Struct(_) => Ok(Box::new(MinMaxStructAccumulator::new_min( + data_type.clone(), + ))), // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), } @@ -1218,7 +636,9 @@ impl AggregateUDFImpl for Min { &self, args: AccumulatorArgs, ) -> Result> { - Ok(Box::new(SlidingMinAccumulator::try_new(args.return_type)?)) + Ok(Box::new(SlidingMinAccumulator::try_new( + args.return_field.data_type(), + )?)) } fn is_descending(&self) -> Option { @@ -1251,48 +671,6 @@ impl AggregateUDFImpl for Min { } } -/// An accumulator to compute the minimum value -#[derive(Debug)] -pub struct MinAccumulator { - min: ScalarValue, -} - -impl MinAccumulator { - /// new min accumulator - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - min: ScalarValue::try_from(datatype)?, - }) - } -} - -impl Accumulator for MinAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - let delta = &min_batch(values)?; - let new_min: Result = - min_max!(&self.min, delta, min); - self.min = new_min?; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn evaluate(&mut self) -> Result { - Ok(self.min.clone()) - } - - fn size(&self) -> usize { - size_of_val(self) - size_of_val(&self.min) + self.min.size() - } -} - #[derive(Debug)] pub struct SlidingMinAccumulator { min: ScalarValue, @@ -1624,11 +1002,23 @@ make_udaf_expr_and_func!( min_udaf ); +// Re-export accumulators from the common module for backwards compatibility +pub use datafusion_functions_aggregate_common::min_max::{ + MaxAccumulator, MinAccumulator, +}; + #[cfg(test)] mod tests { use super::*; - use arrow::datatypes::{ - IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, + use arrow::{ + array::{ + DictionaryArray, Float32Array, Int32Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, StringArray, + }, + datatypes::{ + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, + IntervalYearMonthType, + }, }; use std::sync::Arc; @@ -1768,10 +1158,10 @@ mod tests { use rand::Rng; fn get_random_vec_i32(len: usize) -> Vec { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let mut input = Vec::with_capacity(len); for _i in 0..len { - input.push(rng.gen_range(0..100)); + input.push(rng.random_range(0..100)); } input } @@ -1854,9 +1244,31 @@ mod tests { #[test] fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> { let data_type = - DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32)); + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); let result = get_min_max_result_type(&[data_type])?; - assert_eq!(result, vec![DataType::Int32]); + assert_eq!(result, vec![DataType::Utf8]); + Ok(()) + } + + #[test] + fn test_min_max_dictionary() -> Result<()> { + let values = StringArray::from(vec!["b", "c", "a", "🦀", "d"]); + let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), None, Some(4)]); + let dict_array = + DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap(); + let dict_array_ref = Arc::new(dict_array) as ArrayRef; + let rt_type = + get_min_max_result_type(&[dict_array_ref.data_type().clone()])?[0].clone(); + + let mut min_acc = MinAccumulator::try_new(&rt_type)?; + min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; + let min_result = min_acc.evaluate()?; + assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string()))); + + let mut max_acc = MaxAccumulator::try_new(&rt_type)?; + max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; + let max_result = max_acc.evaluate()?; + assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string()))); Ok(()) } } diff --git a/datafusion/functions-aggregate/src/min_max/min_max_struct.rs b/datafusion/functions-aggregate/src/min_max/min_max_struct.rs new file mode 100644 index 0000000000000..8038f2f01d90c --- /dev/null +++ b/datafusion/functions-aggregate/src/min_max/min_max_struct.rs @@ -0,0 +1,544 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{cmp::Ordering, sync::Arc}; + +use arrow::{ + array::{ + Array, ArrayData, ArrayRef, AsArray, BooleanArray, MutableArrayData, StructArray, + }, + datatypes::DataType, +}; +use datafusion_common::{ + internal_err, + scalar::{copy_array_data, partial_cmp_struct}, + Result, +}; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; + +/// Accumulator for MIN/MAX operations on Struct data types. +/// +/// This accumulator tracks the minimum or maximum struct value encountered +/// during aggregation, depending on the `is_min` flag. +/// +/// The comparison is done based on the struct fields in order. +pub(crate) struct MinMaxStructAccumulator { + /// Inner data storage. + inner: MinMaxStructState, + /// if true, is `MIN` otherwise is `MAX` + is_min: bool, +} + +impl MinMaxStructAccumulator { + pub fn new_min(data_type: DataType) -> Self { + Self { + inner: MinMaxStructState::new(data_type), + is_min: true, + } + } + + pub fn new_max(data_type: DataType) -> Self { + Self { + inner: MinMaxStructState::new(data_type), + is_min: false, + } + } +} + +impl GroupsAccumulator for MinMaxStructAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let array = &values[0]; + assert_eq!(array.len(), group_indices.len()); + assert_eq!(array.data_type(), &self.inner.data_type); + // apply filter if needed + let array = apply_filter_as_nulls(array, opt_filter)?; + + fn struct_min(a: &StructArray, b: &StructArray) -> bool { + matches!(partial_cmp_struct(a, b), Some(Ordering::Less)) + } + + fn struct_max(a: &StructArray, b: &StructArray) -> bool { + matches!(partial_cmp_struct(a, b), Some(Ordering::Greater)) + } + + if self.is_min { + self.inner.update_batch( + array.as_struct(), + group_indices, + total_num_groups, + struct_min, + ) + } else { + self.inner.update_batch( + array.as_struct(), + group_indices, + total_num_groups, + struct_max, + ) + } + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let (_, min_maxes) = self.inner.emit_to(emit_to); + let fields = match &self.inner.data_type { + DataType::Struct(fields) => fields, + _ => return internal_err!("Data type is not a struct"), + }; + let null_array = StructArray::new_null(fields.clone(), 1); + let min_maxes_data: Vec = min_maxes + .iter() + .map(|v| match v { + Some(v) => v.to_data(), + None => null_array.to_data(), + }) + .collect(); + let min_maxes_refs: Vec<&ArrayData> = min_maxes_data.iter().collect(); + let mut copy = MutableArrayData::new(min_maxes_refs, true, min_maxes_data.len()); + + for (i, item) in min_maxes_data.iter().enumerate() { + copy.extend(i, 0, item.len()); + } + let result = copy.freeze(); + assert_eq!(&self.inner.data_type, result.data_type()); + Ok(Arc::new(StructArray::from(result))) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + // min/max are their own states (no transition needed) + self.evaluate(emit_to).map(|arr| vec![arr]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // min/max are their own states (no transition needed) + self.update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + // Min/max do not change the values as they are their own states + // apply the filter by combining with the null mask, if any + let output = apply_filter_as_nulls(&values[0], opt_filter)?; + Ok(vec![output]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.inner.size() + } +} + +#[derive(Debug)] +struct MinMaxStructState { + /// The minimum/maximum value for each group + min_max: Vec>, + /// The data type of the array + data_type: DataType, + /// The total bytes of the string data (for pre-allocating the final array, + /// and tracking memory usage) + total_data_bytes: usize, +} + +#[derive(Debug, Clone)] +enum MinMaxLocation { + /// the min/max value is stored in the existing `min_max` array + ExistingMinMax, + /// the min/max value is stored in the input array at the given index + Input(StructArray), +} + +/// Implement the MinMaxStructState with a comparison function +/// for comparing structs +impl MinMaxStructState { + /// Create a new MinMaxStructState + /// + /// # Arguments: + /// * `data_type`: The data type of the arrays that will be passed to this accumulator + fn new(data_type: DataType) -> Self { + Self { + min_max: vec![], + data_type, + total_data_bytes: 0, + } + } + + /// Set the specified group to the given value, updating memory usage appropriately + fn set_value(&mut self, group_index: usize, new_val: &StructArray) { + let new_val = StructArray::from(copy_array_data(&new_val.to_data())); + match self.min_max[group_index].as_mut() { + None => { + self.total_data_bytes += new_val.get_array_memory_size(); + self.min_max[group_index] = Some(new_val); + } + Some(existing_val) => { + // Copy data over to avoid re-allocating + self.total_data_bytes -= existing_val.get_array_memory_size(); + self.total_data_bytes += new_val.get_array_memory_size(); + *existing_val = new_val; + } + } + } + + /// Updates the min/max values for the given string values + /// + /// `cmp` is the comparison function to use, called like `cmp(new_val, existing_val)` + /// returns true if the `new_val` should replace `existing_val` + fn update_batch( + &mut self, + array: &StructArray, + group_indices: &[usize], + total_num_groups: usize, + mut cmp: F, + ) -> Result<()> + where + F: FnMut(&StructArray, &StructArray) -> bool + Send + Sync, + { + self.min_max.resize(total_num_groups, None); + // Minimize value copies by calculating the new min/maxes for each group + // in this batch (either the existing min/max or the new input value) + // and updating the owned values in `self.min_maxes` at most once + let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups]; + + // Figure out the new min value for each group + for (index, group_index) in (0..array.len()).zip(group_indices.iter()) { + let group_index = *group_index; + if array.is_null(index) { + continue; + } + let new_val = array.slice(index, 1); + + let existing_val = match &locations[group_index] { + // previous input value was the min/max, so compare it + MinMaxLocation::Input(existing_val) => existing_val, + MinMaxLocation::ExistingMinMax => { + let Some(existing_val) = self.min_max[group_index].as_ref() else { + // no existing min/max, so this is the new min/max + locations[group_index] = MinMaxLocation::Input(new_val); + continue; + }; + existing_val + } + }; + + // Compare the new value to the existing value, replacing if necessary + if cmp(&new_val, existing_val) { + locations[group_index] = MinMaxLocation::Input(new_val); + } + } + + // Update self.min_max with any new min/max values we found in the input + for (group_index, location) in locations.iter().enumerate() { + match location { + MinMaxLocation::ExistingMinMax => {} + MinMaxLocation::Input(new_val) => self.set_value(group_index, new_val), + } + } + Ok(()) + } + + /// Emits the specified min_max values + /// + /// Returns (data_capacity, min_maxes), updating the current value of total_data_bytes + /// + /// - `data_capacity`: the total length of all strings and their contents, + /// - `min_maxes`: the actual min/max values for each group + fn emit_to(&mut self, emit_to: EmitTo) -> (usize, Vec>) { + match emit_to { + EmitTo::All => { + ( + std::mem::take(&mut self.total_data_bytes), // reset total bytes and min_max + std::mem::take(&mut self.min_max), + ) + } + EmitTo::First(n) => { + let first_min_maxes: Vec<_> = self.min_max.drain(..n).collect(); + let first_data_capacity: usize = first_min_maxes + .iter() + .map(|opt| opt.as_ref().map(|s| s.len()).unwrap_or(0)) + .sum(); + self.total_data_bytes -= first_data_capacity; + (first_data_capacity, first_min_maxes) + } + } + } + + fn size(&self) -> usize { + self.total_data_bytes + self.min_max.len() * size_of::>() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray, StructArray}; + use arrow::datatypes::{DataType, Field, Fields, Int32Type}; + use std::sync::Arc; + + fn create_test_struct_array( + int_values: Vec>, + str_values: Vec>, + ) -> StructArray { + let int_array = Int32Array::from(int_values); + let str_array = StringArray::from(str_values); + + let fields = vec![ + Field::new("int_field", DataType::Int32, true), + Field::new("str_field", DataType::Utf8, true), + ]; + + StructArray::new( + Fields::from(fields), + vec![ + Arc::new(int_array) as ArrayRef, + Arc::new(str_array) as ArrayRef, + ], + None, + ) + } + + fn create_nested_struct_array( + int_values: Vec>, + str_values: Vec>, + ) -> StructArray { + let inner_struct = create_test_struct_array(int_values, str_values); + + let fields = vec![Field::new("inner", inner_struct.data_type().clone(), true)]; + + StructArray::new( + Fields::from(fields), + vec![Arc::new(inner_struct) as ArrayRef], + None, + ) + } + + #[test] + fn test_min_max_simple_struct() { + let array = create_test_struct_array( + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + + assert_eq!(max_result.len(), 1); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + } + + #[test] + fn test_min_max_nested_struct() { + let array = create_nested_struct_array( + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let inner = min_result.column(0).as_struct(); + let int_array = inner.column(0).as_primitive::(); + let str_array = inner.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + + assert_eq!(max_result.len(), 1); + let inner = max_result.column(0).as_struct(); + let int_array = inner.column(0).as_primitive::(); + let str_array = inner.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + } + + #[test] + fn test_min_max_with_nulls() { + let array = create_test_struct_array( + vec![Some(1), None, Some(3)], + vec![Some("a"), None, Some("c")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + + assert_eq!(max_result.len(), 1); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + } + + #[test] + fn test_min_max_multiple_groups() { + let array = create_test_struct_array( + vec![Some(1), Some(2), Some(3), Some(4)], + vec![Some("a"), Some("b"), Some("c"), Some("d")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 1, 0, 1]; + + min_accumulator + .update_batch(&values, &group_indices, None, 2) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 2) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 2); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + assert_eq!(int_array.value(1), 2); + assert_eq!(str_array.value(1), "b"); + + assert_eq!(max_result.len(), 2); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + assert_eq!(int_array.value(1), 4); + assert_eq!(str_array.value(1), "d"); + } + + #[test] + fn test_min_max_with_filter() { + let array = create_test_struct_array( + vec![Some(1), Some(2), Some(3), Some(4)], + vec![Some("a"), Some("b"), Some("c"), Some("d")], + ); + + // Create a filter that only keeps even numbers + let filter = BooleanArray::from(vec![false, true, false, true]); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, Some(&filter), 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, Some(&filter), 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 2); + assert_eq!(str_array.value(0), "b"); + + assert_eq!(max_result.len(), 1); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 4); + assert_eq!(str_array.value(0), "d"); + } +} diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index d84bd02a6bafe..b9dc498ee7469 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -24,7 +24,7 @@ use std::mem::{size_of, size_of_val}; use std::sync::Arc; use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; -use arrow::datatypes::{DataType, Field, Fields}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; @@ -86,10 +86,10 @@ pub fn nth_value( description = "The position (nth) of the value to retrieve, based on the ordering." ) )] -/// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi +/// Expression for a `NTH_VALUE(..., ... ORDER BY ...)` aggregation. In a multi /// partition setting, partial aggregations are computed for every partition, /// and then their results are merged. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct NthValueAgg { signature: Signature, } @@ -148,27 +148,28 @@ impl AggregateUDFImpl for NthValueAgg { } }; - let ordering_dtypes = acc_args - .ordering_req + let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else { + return TrivialNthValueAccumulator::try_new( + n, + acc_args.return_field.data_type(), + ) + .map(|acc| Box::new(acc) as _); + }; + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; - NthValueAccumulator::try_new( - n, - &data_type, - &ordering_dtypes, - acc_args.ordering_req.clone(), - ) - .map(|acc| Box::new(acc) as _) + NthValueAccumulator::try_new(n, &data_type, &ordering_dtypes, ordering) + .map(|acc| Box::new(acc) as _) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new_list( format_state_name(self.name(), "nth_value"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), false, )]; let orderings = args.ordering_fields.to_vec(); @@ -179,11 +180,7 @@ impl AggregateUDFImpl for NthValueAgg { false, )); } - Ok(fields) - } - - fn aliases(&self) -> &[String] { - &[] + Ok(fields.into_iter().map(Arc::new).collect()) } fn reverse_expr(&self) -> ReversedUDAF { @@ -195,6 +192,126 @@ impl AggregateUDFImpl for NthValueAgg { } } +#[derive(Debug)] +pub struct TrivialNthValueAccumulator { + /// The `N` value. + n: i64, + /// Stores entries in the `NTH_VALUE` result. + values: VecDeque, + /// Data types of the value. + datatype: DataType, +} + +impl TrivialNthValueAccumulator { + /// Create a new order-insensitive NTH_VALUE accumulator based on the given + /// item data type. + pub fn try_new(n: i64, datatype: &DataType) -> Result { + if n == 0 { + // n cannot be 0 + return internal_err!("Nth value indices are 1 based. 0 is invalid index"); + } + Ok(Self { + n, + values: VecDeque::new(), + datatype: datatype.clone(), + }) + } + + /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete + /// None represents all of the new `values` need to be added to the state. + fn append_new_data( + &mut self, + values: &[ArrayRef], + fetch: Option, + ) -> Result<()> { + let n_row = values[0].len(); + let n_to_add = if let Some(fetch) = fetch { + std::cmp::min(fetch, n_row) + } else { + n_row + }; + for index in 0..n_to_add { + let mut row = get_row_at_idx(values, index)?; + self.values.push_back(row.swap_remove(0)); + // At index 1, we have n index argument, which is constant. + } + Ok(()) + } +} + +impl Accumulator for TrivialNthValueAccumulator { + /// Updates its state with the `values`. Assumes data in the `values` satisfies the required + /// ordering for the accumulator (across consecutive batches, not just batch-wise). + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if !values.is_empty() { + let n_required = self.n.unsigned_abs() as usize; + let from_start = self.n > 0; + if from_start { + // direction is from start + let n_remaining = n_required.saturating_sub(self.values.len()); + self.append_new_data(values, Some(n_remaining))?; + } else { + // direction is from end + self.append_new_data(values, None)?; + let start_offset = self.values.len().saturating_sub(n_required); + if start_offset > 0 { + self.values.drain(0..start_offset); + } + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if !states.is_empty() { + // First entry in the state is the aggregation result. + let n_required = self.n.unsigned_abs() as usize; + let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?; + for v in array_agg_res.into_iter().flatten() { + self.values.extend(v); + if self.values.len() > n_required { + // There is enough data collected, can stop merging: + break; + } + } + } + Ok(()) + } + + fn state(&mut self) -> Result> { + let mut values_cloned = self.values.clone(); + let values_slice = values_cloned.make_contiguous(); + Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable( + values_slice, + &self.datatype, + ))]) + } + + fn evaluate(&mut self) -> Result { + let n_required = self.n.unsigned_abs() as usize; + let from_start = self.n > 0; + let nth_value_idx = if from_start { + // index is from start + let forward_idx = n_required - 1; + (forward_idx < self.values.len()).then_some(forward_idx) + } else { + // index is from end + self.values.len().checked_sub(n_required) + }; + if let Some(idx) = nth_value_idx { + Ok(self.values[idx].clone()) + } else { + ScalarValue::try_from(self.datatype.clone()) + } + } + + fn size(&self) -> usize { + size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values) + - size_of_val(&self.values) + + size_of::() + } +} + #[derive(Debug)] pub struct NthValueAccumulator { /// The `N` value. @@ -236,6 +353,64 @@ impl NthValueAccumulator { ordering_req, }) } + + fn evaluate_orderings(&self) -> Result { + let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); + + let mut column_wise_ordering_values = vec![]; + let num_columns = fields.len(); + for i in 0..num_columns { + let column_values = self + .ordering_values + .iter() + .map(|x| x[i].clone()) + .collect::>(); + let array = if column_values.is_empty() { + new_empty_array(fields[i].data_type()) + } else { + ScalarValue::iter_to_array(column_values.into_iter())? + }; + column_wise_ordering_values.push(array); + } + + let struct_field = Fields::from(fields); + let ordering_array = + StructArray::try_new(struct_field, column_wise_ordering_values, None)?; + + Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) + } + + fn evaluate_values(&self) -> ScalarValue { + let mut values_cloned = self.values.clone(); + let values_slice = values_cloned.make_contiguous(); + ScalarValue::List(ScalarValue::new_list_nullable( + values_slice, + &self.datatypes[0], + )) + } + + /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete + /// None represents all of the new `values` need to be added to the state. + fn append_new_data( + &mut self, + values: &[ArrayRef], + fetch: Option, + ) -> Result<()> { + let n_row = values[0].len(); + let n_to_add = if let Some(fetch) = fetch { + std::cmp::min(fetch, n_row) + } else { + n_row + }; + for index in 0..n_to_add { + let row = get_row_at_idx(values, index)?; + self.values.push_back(row[0].clone()); + // At index 1, we have n index argument. + // Ordering values cover starting from 2nd index to end + self.ordering_values.push_back(row[2..].to_vec()); + } + Ok(()) + } } impl Accumulator for NthValueAccumulator { @@ -269,91 +444,60 @@ impl Accumulator for NthValueAccumulator { if states.is_empty() { return Ok(()); } - // First entry in the state is the aggregation result. - let array_agg_values = &states[0]; - let n_required = self.n.unsigned_abs() as usize; - if self.ordering_req.is_empty() { - let array_agg_res = - ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; - for v in array_agg_res.into_iter() { - self.values.extend(v); - if self.values.len() > n_required { - // There is enough data collected can stop merging - break; - } - } - } else if let Some(agg_orderings) = states[1].as_list_opt::() { - // 2nd entry stores values received for ordering requirement columns, for each aggregation value inside NTH_VALUE list. - // For each `StructArray` inside NTH_VALUE list, we will receive an `Array` that stores - // values received from its ordering requirement expression. (This information is necessary for during merging). - - // Stores NTH_VALUE results coming from each partition - let mut partition_values: Vec> = vec![]; - // Stores ordering requirement expression results coming from each partition - let mut partition_ordering_values: Vec>> = vec![]; - - // Existing values should be merged also. - partition_values.push(self.values.clone()); - - partition_ordering_values.push(self.ordering_values.clone()); - - let array_agg_res = - ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; - - for v in array_agg_res.into_iter() { - partition_values.push(v.into()); - } - - let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; - - let ordering_values = orderings.into_iter().map(|partition_ordering_rows| { - // Extract value from struct to ordering_rows for each group/partition - partition_ordering_rows.into_iter().map(|ordering_row| { - if let ScalarValue::Struct(s) = ordering_row { - let mut ordering_columns_per_row = vec![]; - - for column in s.columns() { - let sv = ScalarValue::try_from_array(column, 0)?; - ordering_columns_per_row.push(sv); - } - - Ok(ordering_columns_per_row) - } else { - exec_err!( - "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}", - ordering_row.data_type() - ) - } - }).collect::>>() - }).collect::>>()?; - for ordering_values in ordering_values.into_iter() { - partition_ordering_values.push(ordering_values.into()); - } - - let sort_options = self - .ordering_req - .iter() - .map(|sort_expr| sort_expr.options) - .collect::>(); - let (new_values, new_orderings) = merge_ordered_arrays( - &mut partition_values, - &mut partition_ordering_values, - &sort_options, - )?; - self.values = new_values.into(); - self.ordering_values = new_orderings.into(); - } else { + // Second entry stores values received for ordering requirement columns + // for each aggregation value inside NTH_VALUE list. For each `StructArray` + // inside this list, we will receive an `Array` that stores values received + // from its ordering requirement expression. This information is necessary + // during merging. + let Some(agg_orderings) = states[1].as_list_opt::() else { return exec_err!("Expects to receive a list array"); + }; + + // Stores NTH_VALUE results coming from each partition + let mut partition_values = vec![self.values.clone()]; + // First entry in the state is the aggregation result. + let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?; + for v in array_agg_res.into_iter().flatten() { + partition_values.push(v.into()); + } + // Stores ordering requirement expression results coming from each partition: + let mut partition_ordering_values = vec![self.ordering_values.clone()]; + let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; + // Extract value from struct to ordering_rows for each group/partition: + for partition_ordering_rows in orderings.into_iter().flatten() { + let ordering_values = partition_ordering_rows.into_iter().map(|ordering_row| { + let ScalarValue::Struct(s_array) = ordering_row else { + return exec_err!( + "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}", + ordering_row.data_type() + ); + }; + s_array + .columns() + .iter() + .map(|column| ScalarValue::try_from_array(column, 0)) + .collect() + }).collect::>>()?; + partition_ordering_values.push(ordering_values); } + + let sort_options = self + .ordering_req + .iter() + .map(|sort_expr| sort_expr.options) + .collect::>(); + let (new_values, new_orderings) = merge_ordered_arrays( + &mut partition_values, + &mut partition_ordering_values, + &sort_options, + )?; + self.values = new_values.into(); + self.ordering_values = new_orderings.into(); Ok(()) } fn state(&mut self) -> Result> { - let mut result = vec![self.evaluate_values()]; - if !self.ordering_req.is_empty() { - result.push(self.evaluate_orderings()?); - } - Ok(result) + Ok(vec![self.evaluate_values(), self.evaluate_orderings()?]) } fn evaluate(&mut self) -> Result { @@ -396,63 +540,3 @@ impl Accumulator for NthValueAccumulator { total } } - -impl NthValueAccumulator { - fn evaluate_orderings(&self) -> Result { - let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]); - let struct_field = Fields::from(fields.clone()); - - let mut column_wise_ordering_values = vec![]; - let num_columns = fields.len(); - for i in 0..num_columns { - let column_values = self - .ordering_values - .iter() - .map(|x| x[i].clone()) - .collect::>(); - let array = if column_values.is_empty() { - new_empty_array(fields[i].data_type()) - } else { - ScalarValue::iter_to_array(column_values.into_iter())? - }; - column_wise_ordering_values.push(array); - } - - let ordering_array = - StructArray::try_new(struct_field, column_wise_ordering_values, None)?; - - Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) - } - - fn evaluate_values(&self) -> ScalarValue { - let mut values_cloned = self.values.clone(); - let values_slice = values_cloned.make_contiguous(); - ScalarValue::List(ScalarValue::new_list_nullable( - values_slice, - &self.datatypes[0], - )) - } - - /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete - /// None represents all of the new `values` need to be added to the state. - fn append_new_data( - &mut self, - values: &[ArrayRef], - fetch: Option, - ) -> Result<()> { - let n_row = values[0].len(); - let n_to_add = if let Some(fetch) = fetch { - std::cmp::min(fetch, n_row) - } else { - n_row - }; - for index in 0..n_to_add { - let row = get_row_at_idx(values, index)?; - self.values.push_back(row[0].clone()); - // At index 1, we have n index argument. - // Ordering values cover starting from 2nd index to end - self.ordering_values.push_back(row[2..].to_vec()); - } - Ok(()) - } -} diff --git a/datafusion/functions-aggregate/src/planner.rs b/datafusion/functions-aggregate/src/planner.rs index c8cb841189954..f0e37f6b1dbe4 100644 --- a/datafusion/functions-aggregate/src/planner.rs +++ b/datafusion/functions-aggregate/src/planner.rs @@ -100,7 +100,7 @@ impl ExprPlanner for AggregateFunctionPlanner { let new_expr = Expr::AggregateFunction(AggregateFunction::new_udf( func, - vec![Expr::Literal(COUNT_STAR_EXPANSION)], + vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], distinct, filter, order_by, diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index 82575d15e50b8..44ce0bd48ead6 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -18,6 +18,7 @@ //! Defines physical expressions that can evaluated at runtime during query execution use arrow::array::Float64Array; +use arrow::datatypes::FieldRef; use arrow::{ array::{ArrayRef, UInt64Array}, compute::cast, @@ -28,7 +29,7 @@ use datafusion_common::{ downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, HashMap, Result, ScalarValue, }; -use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; +use datafusion_doc::aggregate_doc_sections::DOC_SECTION_STATISTICAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; @@ -37,8 +38,9 @@ use datafusion_expr::{ }; use std::any::Any; use std::fmt::Debug; +use std::hash::Hash; use std::mem::size_of_val; -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock}; macro_rules! make_regr_udaf_expr_and_func { ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => { @@ -57,6 +59,7 @@ make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX); make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY); make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY); +#[derive(PartialEq, Eq, Hash)] pub struct Regr { signature: Signature, regr_type: RegrType, @@ -142,6 +145,29 @@ static DOCUMENTATION: LazyLock> = LazyLock::new Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.", "regr_slope(expression_y, expression_x)") + .with_sql_example( + r#"```sql +create table weekly_performance(day int, user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80); +select * from weekly_performance; ++-----+--------------+ +| day | user_signups | ++-----+--------------+ +| 1 | 60 | +| 2 | 65 | +| 3 | 70 | +| 4 | 75 | +| 5 | 80 | ++-----+--------------+ + +SELECT regr_slope(user_signups, day) AS slope FROM weekly_performance; ++--------+ +| slope | ++--------+ +| 5.0 | ++--------+ +``` +"# + ) .with_standard_argument("expression_y", Some("Dependent variable")) .with_standard_argument("expression_x", Some("Independent variable")) .build() @@ -155,6 +181,30 @@ static DOCUMENTATION: LazyLock> = LazyLock::new this function returns b.", "regr_intercept(expression_y, expression_x)") + .with_sql_example( + r#"```sql +create table weekly_performance(week int, productivity_score int) as values (1,60), (2,65), (3, 70), (4,75), (5,80); +select * from weekly_performance; ++------+---------------------+ +| week | productivity_score | +| ---- | ------------------- | +| 1 | 60 | +| 2 | 65 | +| 3 | 70 | +| 4 | 75 | +| 5 | 80 | ++------+---------------------+ + +SELECT regr_intercept(productivity_score, week) AS intercept FROM weekly_performance; ++----------+ +|intercept| +|intercept | ++----------+ +| 55 | ++----------+ +``` +"# + ) .with_standard_argument("expression_y", Some("Dependent variable")) .with_standard_argument("expression_x", Some("Independent variable")) .build() @@ -167,6 +217,29 @@ static DOCUMENTATION: LazyLock> = LazyLock::new "Counts the number of non-null paired data points.", "regr_count(expression_y, expression_x)", ) + .with_sql_example( + r#"```sql +create table daily_metrics(day int, user_signups int) as values (1,100), (2,120), (3, NULL), (4,110), (5,NULL); +select * from daily_metrics; ++-----+---------------+ +| day | user_signups | +| --- | ------------- | +| 1 | 100 | +| 2 | 120 | +| 3 | NULL | +| 4 | 110 | +| 5 | NULL | ++-----+---------------+ + +SELECT regr_count(user_signups, day) AS valid_pairs FROM daily_metrics; ++-------------+ +| valid_pairs | ++-------------+ +| 3 | ++-------------+ +``` +"# + ) .with_standard_argument("expression_y", Some("Dependent variable")) .with_standard_argument("expression_x", Some("Independent variable")) .build(), @@ -179,6 +252,29 @@ static DOCUMENTATION: LazyLock> = LazyLock::new "Computes the square of the correlation coefficient between the independent and dependent variables.", "regr_r2(expression_y, expression_x)") + .with_sql_example( + r#"```sql +create table weekly_performance(day int ,user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80); +select * from weekly_performance; ++-----+--------------+ +| day | user_signups | ++-----+--------------+ +| 1 | 60 | +| 2 | 65 | +| 3 | 70 | +| 4 | 75 | +| 5 | 80 | ++-----+--------------+ + +SELECT regr_r2(user_signups, day) AS r_squared FROM weekly_performance; ++---------+ +|r_squared| ++---------+ +| 1.0 | ++---------+ +``` +"# + ) .with_standard_argument("expression_y", Some("Dependent variable")) .with_standard_argument("expression_x", Some("Independent variable")) .build() @@ -191,6 +287,29 @@ static DOCUMENTATION: LazyLock> = LazyLock::new "Computes the average of the independent variable (input) expression_x for the non-null paired data points.", "regr_avgx(expression_y, expression_x)") + .with_sql_example( + r#"```sql +create table daily_sales(day int, total_sales int) as values (1,100), (2,150), (3,200), (4,NULL), (5,250); +select * from daily_sales; ++-----+-------------+ +| day | total_sales | +| --- | ----------- | +| 1 | 100 | +| 2 | 150 | +| 3 | 200 | +| 4 | NULL | +| 5 | 250 | ++-----+-------------+ + +SELECT regr_avgx(total_sales, day) AS avg_day FROM daily_sales; ++----------+ +| avg_day | ++----------+ +| 2.75 | ++----------+ +``` +"# + ) .with_standard_argument("expression_y", Some("Dependent variable")) .with_standard_argument("expression_x", Some("Independent variable")) .build() @@ -203,6 +322,30 @@ static DOCUMENTATION: LazyLock> = LazyLock::new "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.", "regr_avgy(expression_y, expression_x)") + .with_sql_example( + r#"```sql +create table daily_temperature(day int, temperature int) as values (1,30), (2,32), (3, NULL), (4,35), (5,36); +select * from daily_temperature; ++-----+-------------+ +| day | temperature | +| --- | ----------- | +| 1 | 30 | +| 2 | 32 | +| 3 | NULL | +| 4 | 35 | +| 5 | 36 | ++-----+-------------+ + +-- temperature as Dependent Variable(Y), day as Independent Variable(X) +SELECT regr_avgy(temperature, day) AS avg_temperature FROM daily_temperature; ++-----------------+ +| avg_temperature | ++-----------------+ +| 33.25 | ++-----------------+ +``` +"# + ) .with_standard_argument("expression_y", Some("Dependent variable")) .with_standard_argument("expression_x", Some("Independent variable")) .build() @@ -215,6 +358,29 @@ static DOCUMENTATION: LazyLock> = LazyLock::new "Computes the sum of squares of the independent variable.", "regr_sxx(expression_y, expression_x)", ) + .with_sql_example( + r#"```sql +create table study_hours(student_id int, hours int, test_score int) as values (1,2,55), (2,4,65), (3,6,75), (4,8,85), (5,10,95); +select * from study_hours; ++------------+-------+------------+ +| student_id | hours | test_score | ++------------+-------+------------+ +| 1 | 2 | 55 | +| 2 | 4 | 65 | +| 3 | 6 | 75 | +| 4 | 8 | 85 | +| 5 | 10 | 95 | ++------------+-------+------------+ + +SELECT regr_sxx(test_score, hours) AS sxx FROM study_hours; ++------+ +| sxx | ++------+ +| 40.0 | ++------+ +``` +"# + ) .with_standard_argument("expression_y", Some("Dependent variable")) .with_standard_argument("expression_x", Some("Independent variable")) .build(), @@ -227,6 +393,27 @@ static DOCUMENTATION: LazyLock> = LazyLock::new "Computes the sum of squares of the dependent variable.", "regr_syy(expression_y, expression_x)", ) + .with_sql_example( + r#"```sql +create table employee_productivity(week int, productivity_score int) as values (1,60), (2,65), (3,70); +select * from employee_productivity; ++------+--------------------+ +| week | productivity_score | ++------+--------------------+ +| 1 | 60 | +| 2 | 65 | +| 3 | 70 | ++------+--------------------+ + +SELECT regr_syy(productivity_score, week) AS sum_squares_y FROM employee_productivity; ++---------------+ +| sum_squares_y | ++---------------+ +| 50.0 | ++---------------+ +``` +"# + ) .with_standard_argument("expression_y", Some("Dependent variable")) .with_standard_argument("expression_x", Some("Independent variable")) .build(), @@ -239,6 +426,27 @@ static DOCUMENTATION: LazyLock> = LazyLock::new "Computes the sum of products of paired data points.", "regr_sxy(expression_y, expression_x)", ) + .with_sql_example( + r#"```sql +create table employee_productivity(week int, productivity_score int) as values(1,60), (2,65), (3,70); +select * from employee_productivity; ++------+--------------------+ +| week | productivity_score | ++------+--------------------+ +| 1 | 60 | +| 2 | 65 | +| 3 | 70 | ++------+--------------------+ + +SELECT regr_sxy(productivity_score, week) AS sum_product_deviations FROM employee_productivity; ++------------------------+ +| sum_product_deviations | ++------------------------+ +| 10.0 | ++------------------------+ +``` +"# + ) .with_standard_argument("expression_y", Some("Dependent variable")) .with_standard_argument("expression_x", Some("Independent variable")) .build(), @@ -278,7 +486,7 @@ impl AggregateUDFImpl for Regr { Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ -310,7 +518,10 @@ impl AggregateUDFImpl for Regr { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index adf86a128cfb1..312d5f11b4771 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -19,12 +19,13 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; +use std::hash::Hash; use std::mem::align_of_val; use std::sync::Arc; use arrow::array::Float64Array; +use arrow::datatypes::FieldRef; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; - use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -61,6 +62,7 @@ make_udaf_expr_and_func!( standard_argument(name = "expression",) )] /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression +#[derive(PartialEq, Eq, Hash)] pub struct Stddev { signature: Signature, alias: Vec, @@ -109,7 +111,7 @@ impl AggregateUDFImpl for Stddev { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ -122,7 +124,10 @@ impl AggregateUDFImpl for Stddev { true, ), Field::new(format_state_name(args.name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -175,6 +180,7 @@ make_udaf_expr_and_func!( standard_argument(name = "expression",) )] /// STDDEV_POP population aggregate expression +#[derive(PartialEq, Eq, Hash)] pub struct StddevPop { signature: Signature, } @@ -217,7 +223,7 @@ impl AggregateUDFImpl for StddevPop { &self.signature } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ -230,7 +236,10 @@ impl AggregateUDFImpl for StddevPop { true, ), Field::new(format_state_name(args.name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -387,7 +396,6 @@ mod tests { use datafusion_expr::AggregateUDF; use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays; use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr_common::sort_expr::LexOrdering; use std::sync::Arc; #[test] @@ -436,10 +444,10 @@ mod tests { schema: &Schema, ) -> Result { let args1 = AccumulatorArgs { - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), schema, ignore_nulls: false, - ordering_req: &LexOrdering::default(), + order_bys: &[], name: "a", is_distinct: false, is_reversed: false, @@ -447,10 +455,10 @@ mod tests { }; let args2 = AccumulatorArgs { - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), schema, ignore_nulls: false, - ordering_req: &LexOrdering::default(), + order_bys: &[], name: "a", is_distinct: false, is_reversed: false, diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 64314ef6df687..a091ed34da70c 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -17,19 +17,28 @@ //! [`StringAgg`] accumulator for the `string_agg` function +use std::any::Any; +use std::hash::Hash; +use std::mem::size_of_val; + +use crate::array_agg::ArrayAgg; + use arrow::array::ArrayRef; -use arrow::datatypes::DataType; -use datafusion_common::cast::as_generic_string_array; -use datafusion_common::Result; -use datafusion_common::{not_impl_err, ScalarValue}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::{ + as_generic_string_array, as_string_array, as_string_view_array, +}; +use datafusion_common::{ + internal_datafusion_err, internal_err, not_impl_err, Result, ScalarValue, +}; use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::utils::format_state_name; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, }; +use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs; use datafusion_macros::user_doc; use datafusion_physical_expr::expressions::Literal; -use std::any::Any; -use std::mem::size_of_val; make_udaf_expr_and_func!( StringAgg, @@ -41,15 +50,31 @@ make_udaf_expr_and_func!( #[user_doc( doc_section(label = "General Functions"), - description = "Concatenates the values of string expressions and places separator values between them.", - syntax_example = "string_agg(expression, delimiter)", + description = "Concatenates the values of string expressions and places separator values between them. \ +If ordering is required, strings are concatenated in the specified order. \ +This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression.", + syntax_example = "string_agg([DISTINCT] expression, delimiter [ORDER BY expression])", sql_example = r#"```sql > SELECT string_agg(name, ', ') AS names_list FROM employee; +--------------------------+ | names_list | +--------------------------+ -| Alice, Bob, Charlie | +| Alice, Bob, Bob, Charlie | ++--------------------------+ +> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list + FROM employee; ++--------------------------+ +| names_list | ++--------------------------+ +| Charlie, Bob, Bob, Alice | ++--------------------------+ +> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list + FROM employee; ++--------------------------+ +| names_list | ++--------------------------+ +| Charlie, Bob, Alice | +--------------------------+ ```"#, argument( @@ -62,9 +87,10 @@ make_udaf_expr_and_func!( ) )] /// STRING_AGG aggregate expression -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct StringAgg { signature: Signature, + array_agg: ArrayAgg, } impl StringAgg { @@ -76,9 +102,19 @@ impl StringAgg { TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]), + TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]), + TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]), + TypeSignature::Exact(vec![DataType::Utf8View, DataType::Null]), + TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]), ], Volatility::Immutable, ), + array_agg: Default::default(), } } } @@ -89,6 +125,8 @@ impl Default for StringAgg { } } +/// If there is no `distinct` and `order by` required by the `string_agg` call, a +/// more efficient accumulator `SimpleStringAggAccumulator` will be used. impl AggregateUDFImpl for StringAgg { fn as_any(&self) -> &dyn Any { self @@ -106,20 +144,73 @@ impl AggregateUDFImpl for StringAgg { Ok(DataType::LargeUtf8) } + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + // See comments in `impl AggregateUDFImpl ...` for more detail + let no_order_no_distinct = + (args.ordering_fields.is_empty()) && (!args.is_distinct); + if no_order_no_distinct { + // Case `SimpleStringAggAccumulator` + Ok(vec![Field::new( + format_state_name(args.name, "string_agg"), + DataType::LargeUtf8, + true, + ) + .into()]) + } else { + // Case `StringAggAccumulator` + self.array_agg.state_fields(args) + } + } + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::() { - return match lit.value().try_as_str() { - Some(Some(delimiter)) => { - Ok(Box::new(StringAggAccumulator::new(delimiter))) - } - Some(None) => Ok(Box::new(StringAggAccumulator::new(""))), - None => { - not_impl_err!("StringAgg not supported for delimiter {}", lit.value()) - } - }; + let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::() else { + return not_impl_err!( + "The second argument of the string_agg function must be a string literal" + ); + }; + + let delimiter = if lit.value().is_null() { + // If the second argument (the delimiter that joins strings) is NULL, join + // on an empty string. (e.g. [a, b, c] => "abc"). + "" + } else if let Some(lit_string) = lit.value().try_as_str() { + lit_string.unwrap_or("") + } else { + return not_impl_err!( + "StringAgg not supported for delimiter \"{}\"", + lit.value() + ); + }; + + // See comments in `impl AggregateUDFImpl ...` for more detail + let no_order_no_distinct = + acc_args.order_bys.is_empty() && (!acc_args.is_distinct); + + if no_order_no_distinct { + // simple case (more efficient) + Ok(Box::new(SimpleStringAggAccumulator::new(delimiter))) + } else { + // general case + let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs { + return_field: Field::new( + "f", + DataType::new_list(acc_args.return_field.data_type().clone(), true), + true, + ) + .into(), + exprs: &filter_index(acc_args.exprs, 1), + ..acc_args + })?; + + Ok(Box::new(StringAggAccumulator::new( + array_agg_acc, + delimiter, + ))) } + } - not_impl_err!("expect literal") + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Reversed(string_agg_udaf()) } fn documentation(&self) -> Option<&Documentation> { @@ -127,16 +218,17 @@ impl AggregateUDFImpl for StringAgg { } } +/// StringAgg accumulator for the general case (with order or distinct specified) #[derive(Debug)] pub(crate) struct StringAggAccumulator { - values: Option, + array_agg_acc: Box, delimiter: String, } impl StringAggAccumulator { - pub fn new(delimiter: &str) -> Self { + pub fn new(array_agg_acc: Box, delimiter: &str) -> Self { Self { - values: None, + array_agg_acc, delimiter: delimiter.to_string(), } } @@ -144,37 +236,414 @@ impl StringAggAccumulator { impl Accumulator for StringAggAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let string_array: Vec<_> = as_generic_string_array::(&values[0])? - .iter() - .filter_map(|v| v.as_ref().map(ToString::to_string)) - .collect(); - if !string_array.is_empty() { - let s = string_array.join(self.delimiter.as_str()); - let v = self.values.get_or_insert("".to_string()); - if !v.is_empty() { - v.push_str(self.delimiter.as_str()); + self.array_agg_acc.update_batch(&filter_index(values, 1)) + } + + fn evaluate(&mut self) -> Result { + let scalar = self.array_agg_acc.evaluate()?; + + let ScalarValue::List(list) = scalar else { + return internal_err!("Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}", scalar.data_type()); + }; + + let string_arr: Vec<_> = match list.value_type() { + DataType::LargeUtf8 => as_generic_string_array::(list.values())? + .iter() + .flatten() + .collect(), + DataType::Utf8 => as_generic_string_array::(list.values())? + .iter() + .flatten() + .collect(), + DataType::Utf8View => as_string_view_array(list.values())? + .iter() + .flatten() + .collect(), + _ => { + return internal_err!( + "Expected elements to of type Utf8 or LargeUtf8, but got {}", + list.value_type() + ) } - v.push_str(s.as_str()); + }; + + if string_arr.is_empty() { + return Ok(ScalarValue::LargeUtf8(None)); } - Ok(()) + + Ok(ScalarValue::LargeUtf8(Some( + string_arr.join(&self.delimiter), + ))) } - fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - self.update_batch(values)?; - Ok(()) + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.array_agg_acc) + + self.array_agg_acc.size() + + self.delimiter.capacity() } fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) + self.array_agg_acc.state() + } + + fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.array_agg_acc.merge_batch(values) + } +} + +fn filter_index(values: &[T], index: usize) -> Vec { + values + .iter() + .enumerate() + .filter(|(i, _)| *i != index) + .map(|(_, v)| v) + .cloned() + .collect::>() +} + +/// StringAgg accumulator for the simple case (no order or distinct specified) +/// This accumulator is more efficient than `StringAggAccumulator` +/// because it accumulates the string directly, +/// whereas `StringAggAccumulator` uses `ArrayAggAccumulator`. +#[derive(Debug)] +pub(crate) struct SimpleStringAggAccumulator { + delimiter: String, + /// Updated during `update_batch()`. e.g. "foo,bar" + accumulated_string: String, + has_value: bool, +} + +impl SimpleStringAggAccumulator { + pub fn new(delimiter: &str) -> Self { + Self { + delimiter: delimiter.to_string(), + accumulated_string: "".to_string(), + has_value: false, + } + } + + #[inline] + fn append_strings<'a, I>(&mut self, iter: I) + where + I: Iterator>, + { + for value in iter.flatten() { + if self.has_value { + self.accumulated_string.push_str(&self.delimiter); + } + + self.accumulated_string.push_str(value); + self.has_value = true; + } + } +} + +impl Accumulator for SimpleStringAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let string_arr = values.first().ok_or_else(|| { + internal_datafusion_err!( + "Planner should ensure its first arg is Utf8/Utf8View" + ) + })?; + + match string_arr.data_type() { + DataType::Utf8 => { + let array = as_string_array(string_arr)?; + self.append_strings(array.iter()); + } + DataType::LargeUtf8 => { + let array = as_generic_string_array::(string_arr)?; + self.append_strings(array.iter()); + } + DataType::Utf8View => { + let array = as_string_view_array(string_arr)?; + self.append_strings(array.iter()); + } + other => { + return internal_err!( + "Planner should ensure string_agg first argument is Utf8-like, found {other}" + ); + } + } + + Ok(()) } fn evaluate(&mut self) -> Result { - Ok(ScalarValue::LargeUtf8(self.values.clone())) + let result = if self.has_value { + ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string))) + } else { + ScalarValue::LargeUtf8(None) + }; + + self.has_value = false; + Ok(result) } fn size(&self) -> usize { - size_of_val(self) - + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) - + self.delimiter.capacity() + size_of_val(self) + self.delimiter.capacity() + self.accumulated_string.capacity() + } + + fn state(&mut self) -> Result> { + let result = if self.has_value { + ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string))) + } else { + ScalarValue::LargeUtf8(None) + }; + self.has_value = false; + + Ok(vec![result]) + } + + fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.update_batch(values) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::LargeStringArray; + use arrow::compute::SortOptions; + use arrow::datatypes::{Fields, Schema}; + use datafusion_common::internal_err; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use std::sync::Arc; + + #[test] + fn no_duplicates_no_distinct() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?; + + acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?; + acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str(acc1.evaluate()?); + + assert_eq!(result, "a,b,c,d,e,f"); + + Ok(()) + } + + #[test] + fn no_duplicates_distinct() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",") + .distinct() + .build_two()?; + + acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?; + acc2.update_batch(&[data(["d", "e", "f"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str_sorted(acc1.evaluate()?, ","); + + assert_eq!(result, "a,b,c,d,e,f"); + + Ok(()) + } + + #[test] + fn duplicates_no_distinct() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",").build_two()?; + + acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?; + acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str(acc1.evaluate()?); + + assert_eq!(result, "a,b,c,a,b,c"); + + Ok(()) + } + + #[test] + fn duplicates_distinct() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",") + .distinct() + .build_two()?; + + acc1.update_batch(&[data(["a", "b", "c"]), data([","])])?; + acc2.update_batch(&[data(["a", "b", "c"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str_sorted(acc1.evaluate()?, ","); + + assert_eq!(result, "a,b,c"); + + Ok(()) + } + + #[test] + fn no_duplicates_distinct_sort_asc() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",") + .distinct() + .order_by_col("col", SortOptions::new(false, false)) + .build_two()?; + + acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?; + acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str(acc1.evaluate()?); + + assert_eq!(result, "a,b,c,d,e,f"); + + Ok(()) + } + + #[test] + fn no_duplicates_distinct_sort_desc() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",") + .distinct() + .order_by_col("col", SortOptions::new(true, false)) + .build_two()?; + + acc1.update_batch(&[data(["e", "b", "d"]), data([","])])?; + acc2.update_batch(&[data(["f", "a", "c"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str(acc1.evaluate()?); + + assert_eq!(result, "f,e,d,c,b,a"); + + Ok(()) + } + + #[test] + fn duplicates_distinct_sort_asc() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",") + .distinct() + .order_by_col("col", SortOptions::new(false, false)) + .build_two()?; + + acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?; + acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str(acc1.evaluate()?); + + assert_eq!(result, "a,b,c"); + + Ok(()) + } + + #[test] + fn duplicates_distinct_sort_desc() -> Result<()> { + let (mut acc1, mut acc2) = StringAggAccumulatorBuilder::new(",") + .distinct() + .order_by_col("col", SortOptions::new(true, false)) + .build_two()?; + + acc1.update_batch(&[data(["a", "c", "b"]), data([","])])?; + acc2.update_batch(&[data(["b", "c", "a"]), data([","])])?; + acc1 = merge(acc1, acc2)?; + + let result = some_str(acc1.evaluate()?); + + assert_eq!(result, "c,b,a"); + + Ok(()) + } + + struct StringAggAccumulatorBuilder { + sep: String, + distinct: bool, + order_bys: Vec, + schema: Schema, + } + + impl StringAggAccumulatorBuilder { + fn new(sep: &str) -> Self { + Self { + sep: sep.to_string(), + distinct: Default::default(), + order_bys: vec![], + schema: Schema { + fields: Fields::from(vec![Field::new( + "col", + DataType::LargeUtf8, + true, + )]), + metadata: Default::default(), + }, + } + } + fn distinct(mut self) -> Self { + self.distinct = true; + self + } + + fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self { + self.order_bys.extend([PhysicalSortExpr::new( + Arc::new( + Column::new_with_schema(col, &self.schema) + .expect("column not available in schema"), + ), + sort_options, + )]); + self + } + + fn build(&self) -> Result> { + StringAgg::new().accumulator(AccumulatorArgs { + return_field: Field::new("f", DataType::LargeUtf8, true).into(), + schema: &self.schema, + ignore_nulls: false, + order_bys: &self.order_bys, + is_reversed: false, + name: "", + is_distinct: self.distinct, + exprs: &[ + Arc::new(Column::new("col", 0)), + Arc::new(Literal::new(ScalarValue::Utf8(Some(self.sep.to_string())))), + ], + }) + } + + fn build_two(&self) -> Result<(Box, Box)> { + Ok((self.build()?, self.build()?)) + } + } + + fn some_str(value: ScalarValue) -> String { + str(value) + .expect("ScalarValue was not a String") + .expect("ScalarValue was None") + } + + fn some_str_sorted(value: ScalarValue, sep: &str) -> String { + let value = some_str(value); + let mut parts: Vec<&str> = value.split(sep).collect(); + parts.sort(); + parts.join(sep) + } + + fn str(value: ScalarValue) -> Result> { + match value { + ScalarValue::LargeUtf8(v) => Ok(v), + _ => internal_err!( + "Expected ScalarValue::LargeUtf8, got {}", + value.data_type() + ), + } + } + + fn data(list: [&str; N]) -> ArrayRef { + Arc::new(LargeStringArray::from(list.to_vec())) + } + + fn merge( + mut acc1: Box, + mut acc2: Box, + ) -> Result> { + let intermediate_state = acc2.state().and_then(|e| { + e.iter() + .map(|v| v.to_array()) + .collect::>>() + })?; + acc1.merge_batch(&intermediate_state)?; + Ok(acc1) } } diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 76a1315c2d889..958553d78ca51 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -18,23 +18,24 @@ //! Defines `SUM` and `SUM DISTINCT` aggregate accumulators use ahash::RandomState; +use arrow::datatypes::DECIMAL32_MAX_PRECISION; +use arrow::datatypes::DECIMAL64_MAX_PRECISION; use datafusion_expr::utils::AggregateOrderSensitivity; +use datafusion_expr::Expr; use std::any::Any; -use std::collections::HashSet; -use std::mem::{size_of, size_of_val}; +use std::mem::size_of_val; use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; use arrow::array::{ArrowNumericType, AsArray}; -use arrow::datatypes::ArrowNativeType; -use arrow::datatypes::ArrowPrimitiveType; +use arrow::datatypes::{ArrowNativeType, FieldRef}; use arrow::datatypes::{ - DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type, - DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, Float64Type, + Int64Type, UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::{ - exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, + exec_err, not_impl_err, utils::take_function_args, HashMap, Result, ScalarValue, }; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; @@ -44,7 +45,7 @@ use datafusion_expr::{ SetMonotonicity, Signature, Volatility, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; -use datafusion_functions_aggregate_common::utils::Hashable; +use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator; use datafusion_macros::user_doc; make_udaf_expr_and_func!( @@ -55,6 +56,17 @@ make_udaf_expr_and_func!( sum_udaf ); +pub fn sum_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + sum_udaf(), + vec![expr], + true, + None, + vec![], + None, + )) +} + /// Sum only supports a subset of numeric types, instead relying on type coercion /// /// This macro is similar to [downcast_primitive](arrow::array::downcast_primitive) @@ -63,17 +75,33 @@ make_udaf_expr_and_func!( /// `helper` is a macro accepting (ArrowPrimitiveType, DataType) macro_rules! downcast_sum { ($args:ident, $helper:ident) => { - match $args.return_type { - DataType::UInt64 => $helper!(UInt64Type, $args.return_type), - DataType::Int64 => $helper!(Int64Type, $args.return_type), - DataType::Float64 => $helper!(Float64Type, $args.return_type), - DataType::Decimal128(_, _) => $helper!(Decimal128Type, $args.return_type), - DataType::Decimal256(_, _) => $helper!(Decimal256Type, $args.return_type), + match $args.return_field.data_type().clone() { + DataType::UInt64 => { + $helper!(UInt64Type, $args.return_field.data_type().clone()) + } + DataType::Int64 => { + $helper!(Int64Type, $args.return_field.data_type().clone()) + } + DataType::Float64 => { + $helper!(Float64Type, $args.return_field.data_type().clone()) + } + DataType::Decimal32(_, _) => { + $helper!(Decimal32Type, $args.return_field.data_type().clone()) + } + DataType::Decimal64(_, _) => { + $helper!(Decimal64Type, $args.return_field.data_type().clone()) + } + DataType::Decimal128(_, _) => { + $helper!(Decimal128Type, $args.return_field.data_type().clone()) + } + DataType::Decimal256(_, _) => { + $helper!(Decimal256Type, $args.return_field.data_type().clone()) + } _ => { not_impl_err!( "Sum not supported for {}: {}", $args.name, - $args.return_type + $args.return_field.data_type() ) } } @@ -94,7 +122,7 @@ macro_rules! downcast_sum { ```"#, standard_argument(name = "expression",) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Sum { signature: Signature, } @@ -137,13 +165,14 @@ impl AggregateUDFImpl for Sum { DataType::Dictionary(_, v) => coerced_type(v), // in the spark, the result type is DECIMAL(min(38,precision+10), s) // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { - Ok(data_type.clone()) - } + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => Ok(data_type.clone()), dt if dt.is_signed_integer() => Ok(DataType::Int64), dt if dt.is_unsigned_integer() => Ok(DataType::UInt64), dt if dt.is_floating() => Ok(DataType::Float64), - _ => exec_err!("Sum not supported for {}", data_type), + _ => exec_err!("Sum not supported for {data_type}"), } } @@ -155,6 +184,18 @@ impl AggregateUDFImpl for Sum { DataType::Int64 => Ok(DataType::Int64), DataType::UInt64 => Ok(DataType::UInt64), DataType::Float64 => Ok(DataType::Float64), + DataType::Decimal32(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal32(new_precision, *scale)) + } + DataType::Decimal64(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal64(new_precision, *scale)) + } DataType::Decimal128(precision, scale) => { // in the spark, the result type is DECIMAL(min(38,precision+10), s) // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 @@ -177,7 +218,7 @@ impl AggregateUDFImpl for Sum { if args.is_distinct { macro_rules! helper { ($t:ty, $dt:expr) => { - Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?)) + Ok(Box::new(DistinctSumAccumulator::<$t>::new(&$dt))) }; } downcast_sum!(args, helper) @@ -191,27 +232,25 @@ impl AggregateUDFImpl for Sum { } } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { Ok(vec![Field::new_list( format_state_name(args.name, "sum distinct"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.return_type.clone(), true), + Field::new_list_field(args.return_type().clone(), true), false, - )]) + ) + .into()]) } else { Ok(vec![Field::new( format_state_name(args.name, "sum"), - args.return_type.clone(), + args.return_type().clone(), true, - )]) + ) + .into()]) } } - fn aliases(&self) -> &[String] { - &[] - } - fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { !args.is_distinct } @@ -235,12 +274,23 @@ impl AggregateUDFImpl for Sum { &self, args: AccumulatorArgs, ) -> Result> { - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone()))) - }; + if args.is_distinct { + // distinct path: use our sliding‐window distinct‐sum + macro_rules! helper_distinct { + ($t:ty, $dt:expr) => { + Ok(Box::new(SlidingDistinctSumAccumulator::try_new(&$dt)?)) + }; + } + downcast_sum!(args, helper_distinct) + } else { + // non‐distinct path: existing sliding sum + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone()))) + }; + } + downcast_sum!(args, helper) } - downcast_sum!(args, helper) } fn reverse_expr(&self) -> ReversedUDAF { @@ -297,7 +347,7 @@ impl Accumulator for SumAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = values[0].as_primitive::(); if let Some(x) = arrow::compute::sum(values) { - let v = self.sum.get_or_insert(T::Native::usize_as(0)); + let v = self.sum.get_or_insert_with(|| T::Native::usize_as(0)); *v = v.add_wrapping(x); } Ok(()) @@ -389,83 +439,106 @@ impl Accumulator for SlidingSumAccumulator { } } -struct DistinctSumAccumulator { - values: HashSet, RandomState>, +/// A sliding‐window accumulator for `SUM(DISTINCT)` over Int64 columns. +/// Maintains a running sum so that `evaluate()` is O(1). +#[derive(Debug)] +pub struct SlidingDistinctSumAccumulator { + /// Map each distinct value → its current count in the window + counts: HashMap, + /// Running sum of all distinct keys currently in the window + sum: i64, + /// Data type (must be Int64) data_type: DataType, } -impl std::fmt::Debug for DistinctSumAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "DistinctSumAccumulator({})", self.data_type) - } -} - -impl DistinctSumAccumulator { +impl SlidingDistinctSumAccumulator { + /// Create a new accumulator; only `DataType::Int64` is supported. pub fn try_new(data_type: &DataType) -> Result { + // TODO support other numeric types + if *data_type != DataType::Int64 { + return exec_err!("SlidingDistinctSumAccumulator only supports Int64"); + } Ok(Self { - values: HashSet::default(), + counts: HashMap::default(), + sum: 0, data_type: data_type.clone(), }) } } -impl Accumulator for DistinctSumAccumulator { - fn state(&mut self) -> Result> { - // 1. Stores aggregate state in `ScalarValue::List` - // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set - let state_out = { - let distinct_values = self - .values - .iter() - .map(|value| { - ScalarValue::new_primitive::(Some(value.0), &self.data_type) - }) - .collect::>>()?; - - vec![ScalarValue::List(ScalarValue::new_list_nullable( - &distinct_values, - &self.data_type, - ))] - }; - Ok(state_out) - } - +impl Accumulator for SlidingDistinctSumAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let array = values[0].as_primitive::(); - match array.nulls().filter(|x| x.null_count() > 0) { - Some(n) => { - for idx in n.valid_indices() { - self.values.insert(Hashable(array.value(idx))); - } + let arr = values[0].as_primitive::(); + for &v in arr.values() { + let cnt = self.counts.entry(v).or_insert(0); + if *cnt == 0 { + // first occurrence in window + self.sum = self.sum.wrapping_add(v); } - None => array.values().iter().for_each(|x| { - self.values.insert(Hashable(*x)); - }), + *cnt += 1; } Ok(()) } + fn evaluate(&mut self) -> Result { + // O(1) wrap of running sum + Ok(ScalarValue::Int64(Some(self.sum))) + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn state(&mut self) -> Result> { + // Serialize distinct keys for cross-partition merge if needed + let keys = self + .counts + .keys() + .cloned() + .map(Some) + .map(ScalarValue::Int64) + .collect::>(); + Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable( + &keys, + &self.data_type, + ))]) + } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - for x in states[0].as_list::().iter().flatten() { - self.update_batch(&[x])? + // Merge distinct keys from other partitions + let list_arr = states[0].as_list::(); + for maybe_inner in list_arr.iter().flatten() { + for idx in 0..maybe_inner.len() { + if let ScalarValue::Int64(Some(v)) = + ScalarValue::try_from_array(&*maybe_inner, idx)? + { + let cnt = self.counts.entry(v).or_insert(0); + if *cnt == 0 { + self.sum = self.sum.wrapping_add(v); + } + *cnt += 1; + } + } } Ok(()) } - fn evaluate(&mut self) -> Result { - let mut acc = T::Native::usize_as(0); - for distinct_value in self.values.iter() { - acc = acc.add_wrapping(distinct_value.0) + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = values[0].as_primitive::(); + for &v in arr.values() { + if let Some(cnt) = self.counts.get_mut(&v) { + *cnt -= 1; + if *cnt == 0 { + // last copy leaving window + self.sum = self.sum.wrapping_sub(v); + self.counts.remove(&v); + } + } } - let v = (!self.values.is_empty()).then_some(acc); - ScalarValue::new_primitive::(v, &self.data_type) + Ok(()) } - fn size(&self) -> usize { - size_of_val(self) + self.values.capacity() * size_of::() + fn supports_retract_batch(&self) -> bool { + true } } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 53e3e0cc56cd2..846c145cb11e7 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -18,15 +18,13 @@ //! [`VarianceSample`]: variance sample aggregations. //! [`VariancePopulation`]: variance population aggregations. +use arrow::datatypes::FieldRef; use arrow::{ array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array}, buffer::NullBuffer, compute::kernels::cast, datatypes::{DataType, Field}, }; -use std::mem::{size_of, size_of_val}; -use std::{fmt::Debug, sync::Arc}; - use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarValue}; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, @@ -38,6 +36,8 @@ use datafusion_functions_aggregate_common::{ aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, }; use datafusion_macros::user_doc; +use std::mem::{size_of, size_of_val}; +use std::{fmt::Debug, sync::Arc}; make_udaf_expr_and_func!( VarianceSample, @@ -61,6 +61,7 @@ make_udaf_expr_and_func!( syntax_example = "var(expression)", standard_argument(name = "expression", prefix = "Numeric") )] +#[derive(PartialEq, Eq, Hash)] pub struct VarianceSample { signature: Signature, aliases: Vec, @@ -107,13 +108,16 @@ impl AggregateUDFImpl for VarianceSample { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean"), DataType::Float64, true), Field::new(format_state_name(name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -150,6 +154,7 @@ impl AggregateUDFImpl for VarianceSample { syntax_example = "var_pop(expression)", standard_argument(name = "expression", prefix = "Numeric") )] +#[derive(PartialEq, Eq, Hash)] pub struct VariancePopulation { signature: Signature, aliases: Vec, @@ -200,13 +205,16 @@ impl AggregateUDFImpl for VariancePopulation { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean"), DataType::Float64, true), Field::new(format_state_name(name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-nested/Cargo.toml b/datafusion/functions-nested/Cargo.toml index 9a7b1f460ef54..9c0b7a16f9a9b 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -37,15 +37,21 @@ workspace = true [lib] name = "datafusion_functions_nested" +[features] +default = ["sql"] +sql = ["datafusion-expr/sql"] + [dependencies] arrow = { workspace = true } arrow-ord = { workspace = true } datafusion-common = { workspace = true } datafusion-doc = { workspace = true } datafusion-execution = { workspace = true } -datafusion-expr = { workspace = true } +datafusion-expr = { workspace = true, default-features = false } +datafusion-expr-common = { workspace = true } datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } datafusion-macros = { workspace = true } datafusion-physical-expr-common = { workspace = true } itertools = { workspace = true, features = ["use_std"] } diff --git a/datafusion/functions-nested/README.md b/datafusion/functions-nested/README.md index 8a5047c838ab0..6ab456edb1925 100644 --- a/datafusion/functions-nested/README.md +++ b/datafusion/functions-nested/README.md @@ -17,11 +17,18 @@ under the License. --> -# DataFusion Nested Type Function Library +# Apache DataFusion Nested Type Function Library -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate contains functions for working with arrays, maps and structs, such as `array_append` that work with -`ListArray`, `LargeListArray` and `FixedListArray` types from the `arrow` crate. +`ListArray`, `LargeListArray` and `FixedListArray` types from the [`arrow`] crate. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`arrow`]: https://crates.io/crates/arrow +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 2774b24b902a7..ca12dde1f5c39 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -21,22 +21,22 @@ use arrow::array::{Int32Array, ListArray, StringArray}; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use rand::prelude::ThreadRng; -use rand::Rng; -use std::collections::HashSet; -use std::sync::Arc; - +use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; use datafusion_expr::planner::ExprPlanner; use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionArgs}; use datafusion_functions_nested::map::map_udf; use datafusion_functions_nested::planner::NestedFunctionPlanner; +use rand::prelude::ThreadRng; +use rand::Rng; +use std::collections::HashSet; +use std::sync::Arc; fn keys(rng: &mut ThreadRng) -> Vec { let mut keys = HashSet::with_capacity(1000); while keys.len() < 1000 { - keys.insert(rng.gen_range(0..10000).to_string()); + keys.insert(rng.random_range(0..10000).to_string()); } keys.into_iter().collect() @@ -46,20 +46,23 @@ fn values(rng: &mut ThreadRng) -> Vec { let mut values = HashSet::with_capacity(1000); while values.len() < 1000 { - values.insert(rng.gen_range(0..10000)); + values.insert(rng.random_range(0..10000)); } values.into_iter().collect() } fn criterion_benchmark(c: &mut Criterion) { c.bench_function("make_map_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let keys = keys(&mut rng); let values = values(&mut rng); let mut buffer = Vec::new(); for i in 0..1000 { - buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); - buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + buffer.push(Expr::Literal( + ScalarValue::Utf8(Some(keys[i].clone())), + None, + )); + buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])), None)); } let planner = NestedFunctionPlanner {}; @@ -74,7 +77,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("map_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let field = Arc::new(Field::new_list_field(DataType::Utf8, true)); let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); let key_list = ListArray::new( @@ -94,17 +97,25 @@ fn criterion_benchmark(c: &mut Criterion) { let keys = ColumnarValue::Scalar(ScalarValue::List(Arc::new(key_list))); let values = ColumnarValue::Scalar(ScalarValue::List(Arc::new(value_list))); - let return_type = &map_udf() + let return_type = map_udf() .return_type(&[DataType::Utf8, DataType::Int32]) .expect("should get return type"); + let arg_fields = vec![ + Field::new("a", keys.data_type(), true).into(), + Field::new("a", values.data_type(), true).into(), + ]; + let return_field = Field::new("f", return_type, true).into(); + let config_options = Arc::new(ConfigOptions::default()); b.iter(|| { black_box( map_udf() .invoke_with_args(ScalarFunctionArgs { args: vec![keys.clone(), values.clone()], + arg_fields: arg_fields.clone(), number_rows: 1, - return_type, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .expect("map should work on valid values"), ); diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 48ee341566b90..f34fea0c4ba07 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -17,20 +17,18 @@ //! [`ScalarUDFImpl`] definitions for array_has, array_has_all and array_has_any functions. -use arrow::array::{ - Array, ArrayRef, BooleanArray, Datum, GenericListArray, OffsetSizeTrait, Scalar, -}; +use arrow::array::{Array, ArrayRef, BooleanArray, Datum, Scalar}; use arrow::buffer::BooleanBuffer; use arrow::datatypes::DataType; use arrow::row::{RowConverter, Rows, SortField}; -use datafusion_common::cast::as_generic_list_array; +use datafusion_common::cast::{as_fixed_size_list_array, as_generic_list_array}; use datafusion_common::utils::string_utils::string_array_to_vec; use datafusion_common::utils::take_function_args; -use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::expr::{InList, ScalarFunction}; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, + in_list, ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; use datafusion_physical_expr_common::datum::compare_with_eq; @@ -83,7 +81,7 @@ make_udf_expr_and_func!(ArrayHasAny, description = "Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayHas { signature: Signature, aliases: Vec, @@ -133,40 +131,45 @@ impl ScalarUDFImpl for ArrayHas { // if the haystack is a constant list, we can use an inlist expression which is more // efficient because the haystack is not varying per-row - if let Expr::Literal(ScalarValue::List(array)) = haystack { - // TODO: support LargeList - // (not supported by `convert_array_to_scalar_vec`) - // (FixedSizeList not supported either, but seems to have worked fine when attempting to - // build a reproducer) - - assert_eq!(array.len(), 1); // guarantee of ScalarValue - if let Ok(scalar_values) = - ScalarValue::convert_array_to_scalar_vec(array.as_ref()) - { - assert_eq!(scalar_values.len(), 1); - let list = scalar_values - .into_iter() - .flatten() - .map(Expr::Literal) - .collect(); - - return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList { - expr: Box::new(std::mem::take(needle)), - list, - negated: false, - }))); + match haystack { + Expr::Literal( + // FixedSizeList gets coerced to List + scalar @ ScalarValue::List(_) | scalar @ ScalarValue::LargeList(_), + _, + ) => { + let array = scalar.to_array().unwrap(); // guarantee of ScalarValue + if let Ok(scalar_values) = + ScalarValue::convert_array_to_scalar_vec(&array) + { + assert_eq!(scalar_values.len(), 1); + let list = scalar_values + .into_iter() + // If the vec is a singular null, `list` will be empty due to this flatten(). + // It would be more clear if we handled the None separately, but this is more performant. + .flatten() + .flatten() + .map(|v| Expr::Literal(v.clone(), None)) + .collect(); + + return Ok(ExprSimplifyResult::Simplified(in_list( + std::mem::take(needle), + list, + false, + ))); + } } - } else if let Expr::ScalarFunction(ScalarFunction { func, args }) = haystack { - // make_array has a static set of arguments, so we can pull the arguments out from it - if func == &make_array_udf() { - return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList { - expr: Box::new(std::mem::take(needle)), - list: std::mem::take(args), - negated: false, - }))); + Expr::ScalarFunction(ScalarFunction { func, args }) + if func == &make_array_udf() => + { + // make_array has a static set of arguments, so we can pull the arguments out from it + return Ok(ExprSimplifyResult::Simplified(in_list( + std::mem::take(needle), + std::mem::take(args), + false, + ))); } - } - + _ => {} + }; Ok(ExprSimplifyResult::Original(args)) } @@ -218,34 +221,98 @@ fn array_has_inner_for_scalar( haystack: &ArrayRef, needle: &dyn Datum, ) -> Result { - match haystack.data_type() { - DataType::List(_) => array_has_dispatch_for_scalar::(haystack, needle), - DataType::LargeList(_) => array_has_dispatch_for_scalar::(haystack, needle), - _ => exec_err!( - "array_has does not support type '{:?}'.", - haystack.data_type() - ), - } + let haystack = haystack.as_ref().try_into()?; + array_has_dispatch_for_scalar(haystack, needle) } fn array_has_inner_for_array(haystack: &ArrayRef, needle: &ArrayRef) -> Result { - match haystack.data_type() { - DataType::List(_) => array_has_dispatch_for_array::(haystack, needle), - DataType::LargeList(_) => array_has_dispatch_for_array::(haystack, needle), - _ => exec_err!( - "array_has does not support type '{:?}'.", - haystack.data_type() - ), + let haystack = haystack.as_ref().try_into()?; + array_has_dispatch_for_array(haystack, needle) +} + +enum ArrayWrapper<'a> { + FixedSizeList(&'a arrow::array::FixedSizeListArray), + List(&'a arrow::array::GenericListArray), + LargeList(&'a arrow::array::GenericListArray), +} + +impl<'a> TryFrom<&'a dyn Array> for ArrayWrapper<'a> { + type Error = DataFusionError; + + fn try_from( + value: &'a dyn Array, + ) -> std::result::Result, Self::Error> { + match value.data_type() { + DataType::List(_) => { + Ok(ArrayWrapper::List(as_generic_list_array::(value)?)) + } + DataType::LargeList(_) => Ok(ArrayWrapper::LargeList( + as_generic_list_array::(value)?, + )), + DataType::FixedSizeList(_, _) => Ok(ArrayWrapper::FixedSizeList( + as_fixed_size_list_array(value)?, + )), + _ => exec_err!("array_has does not support type '{:?}'.", value.data_type()), + } } } -fn array_has_dispatch_for_array( - haystack: &ArrayRef, +impl<'a> ArrayWrapper<'a> { + fn len(&self) -> usize { + match self { + ArrayWrapper::FixedSizeList(arr) => arr.len(), + ArrayWrapper::List(arr) => arr.len(), + ArrayWrapper::LargeList(arr) => arr.len(), + } + } + + fn iter(&self) -> Box> + 'a> { + match self { + ArrayWrapper::FixedSizeList(arr) => Box::new(arr.iter()), + ArrayWrapper::List(arr) => Box::new(arr.iter()), + ArrayWrapper::LargeList(arr) => Box::new(arr.iter()), + } + } + + fn values(&self) -> &ArrayRef { + match self { + ArrayWrapper::FixedSizeList(arr) => arr.values(), + ArrayWrapper::List(arr) => arr.values(), + ArrayWrapper::LargeList(arr) => arr.values(), + } + } + + fn value_type(&self) -> DataType { + match self { + ArrayWrapper::FixedSizeList(arr) => arr.value_type(), + ArrayWrapper::List(arr) => arr.value_type(), + ArrayWrapper::LargeList(arr) => arr.value_type(), + } + } + + fn offsets(&self) -> Box + 'a> { + match self { + ArrayWrapper::FixedSizeList(arr) => { + let offsets = (0..=arr.len()) + .step_by(arr.value_length() as usize) + .collect::>(); + Box::new(offsets.into_iter()) + } + ArrayWrapper::List(arr) => { + Box::new(arr.offsets().iter().map(|o| (*o) as usize)) + } + ArrayWrapper::LargeList(arr) => { + Box::new(arr.offsets().iter().map(|o| (*o) as usize)) + } + } + } +} + +fn array_has_dispatch_for_array( + haystack: ArrayWrapper<'_>, needle: &ArrayRef, ) -> Result { - let haystack = as_generic_list_array::(haystack)?; let mut boolean_builder = BooleanArray::builder(haystack.len()); - for (i, arr) in haystack.iter().enumerate() { if arr.is_none() || needle.is_null(i) { boolean_builder.append_null(); @@ -261,17 +328,15 @@ fn array_has_dispatch_for_array( Ok(Arc::new(boolean_builder.finish())) } -fn array_has_dispatch_for_scalar( - haystack: &ArrayRef, +fn array_has_dispatch_for_scalar( + haystack: ArrayWrapper<'_>, needle: &dyn Datum, ) -> Result { - let haystack = as_generic_list_array::(haystack)?; let values = haystack.values(); let is_nested = values.data_type().is_nested(); - let offsets = haystack.value_offsets(); // If first argument is empty list (second argument is non-null), return false // i.e. array_has([], non-null element) -> false - if values.len() == 0 { + if haystack.len() == 0 { return Ok(Arc::new(BooleanArray::new( BooleanBuffer::new_unset(haystack.len()), None, @@ -279,51 +344,128 @@ fn array_has_dispatch_for_scalar( } let eq_array = compare_with_eq(values, needle, is_nested)?; let mut final_contained = vec![None; haystack.len()]; - for (i, offset) in offsets.windows(2).enumerate() { - let start = offset[0].to_usize().unwrap(); - let end = offset[1].to_usize().unwrap(); + + // Check validity buffer to distinguish between null and empty arrays + let validity = match &haystack { + ArrayWrapper::FixedSizeList(arr) => arr.nulls(), + ArrayWrapper::List(arr) => arr.nulls(), + ArrayWrapper::LargeList(arr) => arr.nulls(), + }; + + for (i, (start, end)) in haystack.offsets().tuple_windows().enumerate() { let length = end - start; - // For non-nested list, length is 0 for null + + // Check if the array at this position is null + if let Some(validity_buffer) = validity { + if !validity_buffer.is_valid(i) { + final_contained[i] = None; // null array -> null result + continue; + } + } + + // For non-null arrays: length is 0 for empty arrays if length == 0 { - continue; + final_contained[i] = Some(false); // empty array -> false + } else { + let sliced_array = eq_array.slice(start, length); + final_contained[i] = Some(sliced_array.true_count() > 0); } - let sliced_array = eq_array.slice(start, length); - final_contained[i] = Some(sliced_array.true_count() > 0); } Ok(Arc::new(BooleanArray::from(final_contained))) } fn array_has_all_inner(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::List(_) => { - array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::All) - } - DataType::LargeList(_) => { - array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::All) + array_has_all_and_any_inner(args, ComparisonType::All) +} + +// General row comparison for array_has_all and array_has_any +fn general_array_has_for_all_and_any<'a>( + haystack: &ArrayWrapper<'a>, + needle: &ArrayWrapper<'a>, + comparison_type: ComparisonType, +) -> Result { + let mut boolean_builder = BooleanArray::builder(haystack.len()); + let converter = RowConverter::new(vec![SortField::new(haystack.value_type())])?; + + for (arr, sub_arr) in haystack.iter().zip(needle.iter()) { + if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { + let arr_values = converter.convert_columns(&[arr])?; + let sub_arr_values = converter.convert_columns(&[sub_arr])?; + boolean_builder.append_value(general_array_has_all_and_any_kernel( + arr_values, + sub_arr_values, + comparison_type, + )); + } else { + boolean_builder.append_null(); } - _ => exec_err!( - "array_has does not support type '{:?}'.", - args[0].data_type() - ), } + + Ok(Arc::new(boolean_builder.finish())) } -fn array_has_any_inner(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::List(_) => { - array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::Any) +// String comparison for array_has_all and array_has_any +fn array_has_all_and_any_string_internal<'a>( + haystack: &ArrayWrapper<'a>, + needle: &ArrayWrapper<'a>, + comparison_type: ComparisonType, +) -> Result { + let mut boolean_builder = BooleanArray::builder(haystack.len()); + for (arr, sub_arr) in haystack.iter().zip(needle.iter()) { + match (arr, sub_arr) { + (Some(arr), Some(sub_arr)) => { + let haystack_array = string_array_to_vec(&arr); + let needle_array = string_array_to_vec(&sub_arr); + boolean_builder.append_value(array_has_string_kernel( + haystack_array, + needle_array, + comparison_type, + )); + } + (_, _) => { + boolean_builder.append_null(); + } } - DataType::LargeList(_) => { - array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::Any) + } + + Ok(Arc::new(boolean_builder.finish())) +} + +fn array_has_all_and_any_dispatch<'a>( + haystack: &ArrayWrapper<'a>, + needle: &ArrayWrapper<'a>, + comparison_type: ComparisonType, +) -> Result { + if needle.values().is_empty() { + let buffer = match comparison_type { + ComparisonType::All => BooleanBuffer::new_set(haystack.len()), + ComparisonType::Any => BooleanBuffer::new_unset(haystack.len()), + }; + Ok(Arc::new(BooleanArray::from(buffer))) + } else { + match needle.value_type() { + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + array_has_all_and_any_string_internal(haystack, needle, comparison_type) + } + _ => general_array_has_for_all_and_any(haystack, needle, comparison_type), } - _ => exec_err!( - "array_has does not support type '{:?}'.", - args[0].data_type() - ), } } +fn array_has_all_and_any_inner( + args: &[ArrayRef], + comparison_type: ComparisonType, +) -> Result { + let haystack: ArrayWrapper = args[0].as_ref().try_into()?; + let needle: ArrayWrapper = args[1].as_ref().try_into()?; + array_has_all_and_any_dispatch(&haystack, &needle, comparison_type) +} + +fn array_has_any_inner(args: &[ArrayRef]) -> Result { + array_has_all_and_any_inner(args, ComparisonType::Any) +} + #[user_doc( doc_section(label = "Array Functions"), description = "Returns true if all elements of sub-array exist in array.", @@ -345,7 +487,7 @@ fn array_has_any_inner(args: &[ArrayRef]) -> Result { description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayHasAll { signature: Signature, aliases: Vec, @@ -360,7 +502,7 @@ impl Default for ArrayHasAll { impl ArrayHasAll { pub fn new() -> Self { Self { - signature: Signature::any(2, Volatility::Immutable), + signature: Signature::arrays(2, None, Volatility::Immutable), aliases: vec![String::from("list_has_all")], } } @@ -419,7 +561,7 @@ impl ScalarUDFImpl for ArrayHasAll { description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayHasAny { signature: Signature, aliases: Vec, @@ -434,7 +576,7 @@ impl Default for ArrayHasAny { impl ArrayHasAny { pub fn new() -> Self { Self { - signature: Signature::any(2, Volatility::Immutable), + signature: Signature::arrays(2, None, Volatility::Immutable), aliases: vec![String::from("list_has_any"), String::from("arrays_overlap")], } } @@ -481,55 +623,6 @@ enum ComparisonType { Any, } -fn array_has_all_and_any_dispatch( - haystack: &ArrayRef, - needle: &ArrayRef, - comparison_type: ComparisonType, -) -> Result { - let haystack = as_generic_list_array::(haystack)?; - let needle = as_generic_list_array::(needle)?; - if needle.values().len() == 0 { - let buffer = match comparison_type { - ComparisonType::All => BooleanBuffer::new_set(haystack.len()), - ComparisonType::Any => BooleanBuffer::new_unset(haystack.len()), - }; - return Ok(Arc::new(BooleanArray::from(buffer))); - } - match needle.data_type() { - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { - array_has_all_and_any_string_internal::(haystack, needle, comparison_type) - } - _ => general_array_has_for_all_and_any::(haystack, needle, comparison_type), - } -} - -// String comparison for array_has_all and array_has_any -fn array_has_all_and_any_string_internal( - array: &GenericListArray, - needle: &GenericListArray, - comparison_type: ComparisonType, -) -> Result { - let mut boolean_builder = BooleanArray::builder(array.len()); - for (arr, sub_arr) in array.iter().zip(needle.iter()) { - match (arr, sub_arr) { - (Some(arr), Some(sub_arr)) => { - let haystack_array = string_array_to_vec(&arr); - let needle_array = string_array_to_vec(&sub_arr); - boolean_builder.append_value(array_has_string_kernel( - haystack_array, - needle_array, - comparison_type, - )); - } - (_, _) => { - boolean_builder.append_null(); - } - } - } - - Ok(Arc::new(boolean_builder.finish())) -} - fn array_has_string_kernel( haystack: Vec>, needle: Vec>, @@ -547,32 +640,6 @@ fn array_has_string_kernel( } } -// General row comparison for array_has_all and array_has_any -fn general_array_has_for_all_and_any( - haystack: &GenericListArray, - needle: &GenericListArray, - comparison_type: ComparisonType, -) -> Result { - let mut boolean_builder = BooleanArray::builder(haystack.len()); - let converter = RowConverter::new(vec![SortField::new(haystack.value_type())])?; - - for (arr, sub_arr) in haystack.iter().zip(needle.iter()) { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = converter.convert_columns(&[sub_arr])?; - boolean_builder.append_value(general_array_has_all_and_any_kernel( - arr_values, - sub_arr_values, - comparison_type, - )); - } else { - boolean_builder.append_null(); - } - } - - Ok(Arc::new(boolean_builder.finish())) -} - fn general_array_has_all_and_any_kernel( haystack_rows: Rows, needle_rows: Rows, @@ -594,11 +661,20 @@ fn general_array_has_all_and_any_kernel( #[cfg(test)] mod tests { - use arrow::array::create_array; - use datafusion_common::utils::SingleRowListArrayBuilder; + use std::sync::Arc; + + use arrow::{ + array::{create_array, Array, ArrayRef, AsArray, Int32Array, ListArray}, + buffer::OffsetBuffer, + datatypes::{DataType, Field}, + }; + use datafusion_common::{ + config::ConfigOptions, utils::SingleRowListArrayBuilder, DataFusionError, + ScalarValue, + }; use datafusion_expr::{ - col, execution_props::ExecutionProps, lit, simplify::ExprSimplifyResult, Expr, - ScalarUDFImpl, + col, execution_props::ExecutionProps, lit, simplify::ExprSimplifyResult, + ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDFImpl, }; use crate::expr_fn::make_array; @@ -673,4 +749,44 @@ mod tests { assert_eq!(args, vec![col("c1"), col("c2")],); } + + #[test] + fn test_array_has_list_empty_child() -> Result<(), DataFusionError> { + let haystack_field = Arc::new(Field::new_list( + "haystack", + Field::new_list("", Field::new("", DataType::Int32, true), true), + true, + )); + let needle_field = Arc::new(Field::new("needle", DataType::Int32, true)); + let return_field = Arc::new(Field::new_list( + "return", + Field::new("", DataType::Boolean, true), + true, + )); + + let haystack = ListArray::new( + Field::new_list_field(DataType::Int32, true).into(), + OffsetBuffer::new(vec![0, 0].into()), + Arc::new(Int32Array::from(Vec::::new())) as ArrayRef, + Some(vec![false].into()), + ); + + let haystack = ColumnarValue::Array(Arc::new(haystack)); + let needle = ColumnarValue::Scalar(ScalarValue::Int32(Some(1))); + + let result = ArrayHas::new().invoke_with_args(ScalarFunctionArgs { + args: vec![haystack, needle], + arg_fields: vec![haystack_field, needle_field], + number_rows: 1, + return_field, + config_options: Arc::new(ConfigOptions::default()), + })?; + + let output = result.into_array(1)?; + let output = output.as_boolean(); + assert_eq!(output.len(), 1); + assert!(output.is_null(0)); + + Ok(()) + } } diff --git a/datafusion/functions-nested/src/cardinality.rs b/datafusion/functions-nested/src/cardinality.rs index f2f23841586ce..6db0011cd0784 100644 --- a/datafusion/functions-nested/src/cardinality.rs +++ b/datafusion/functions-nested/src/cardinality.rs @@ -23,12 +23,12 @@ use arrow::array::{ }; use arrow::datatypes::{ DataType, - DataType::{FixedSizeList, LargeList, List, Map, UInt64}, + DataType::{LargeList, List, Map, Null, UInt64}, }; use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array}; -use datafusion_common::utils::take_function_args; +use datafusion_common::exec_err; +use datafusion_common::utils::{take_function_args, ListCoercion}; use datafusion_common::Result; -use datafusion_common::{exec_err, plan_err}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -52,13 +52,12 @@ impl Cardinality { vec![ TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ArrayFunctionArgument::Array], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), ], Volatility::Immutable, ), - aliases: vec![], } } } @@ -80,10 +79,9 @@ impl Cardinality { description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Cardinality { signature: Signature, - aliases: Vec, } impl Default for Cardinality { @@ -103,13 +101,8 @@ impl ScalarUDFImpl for Cardinality { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) | Map(_, _) => UInt64, - _ => { - return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList/Map."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(UInt64) } fn invoke_with_args( @@ -119,10 +112,6 @@ impl ScalarUDFImpl for Cardinality { make_scalar_function(cardinality_inner)(&args.args) } - fn aliases(&self) -> &[String] { - &self.aliases - } - fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -131,21 +120,22 @@ impl ScalarUDFImpl for Cardinality { /// Cardinality SQL function pub fn cardinality_inner(args: &[ArrayRef]) -> Result { let [array] = take_function_args("cardinality", args)?; - match &array.data_type() { + match array.data_type() { + Null => Ok(Arc::new(UInt64Array::from_value(0, array.len()))), List(_) => { - let list_array = as_list_array(&array)?; + let list_array = as_list_array(array)?; generic_list_cardinality::(list_array) } LargeList(_) => { - let list_array = as_large_list_array(&array)?; + let list_array = as_large_list_array(array)?; generic_list_cardinality::(list_array) } Map(_, _) => { - let map_array = as_map_array(&array)?; + let map_array = as_map_array(array)?; generic_map_cardinality(map_array) } - other => { - exec_err!("cardinality does not support type '{:?}'", other) + arg_type => { + exec_err!("cardinality does not support type {arg_type}") } } } diff --git a/datafusion/functions-nested/src/concat.rs b/datafusion/functions-nested/src/concat.rs index f4b9208e5c83a..9a12db525f954 100644 --- a/datafusion/functions-nested/src/concat.rs +++ b/datafusion/functions-nested/src/concat.rs @@ -17,29 +17,32 @@ //! [`ScalarUDFImpl`] definitions for `array_append`, `array_prepend` and `array_concat` functions. +use std::any::Any; use std::sync::Arc; -use std::{any::Any, cmp::Ordering}; +use crate::make_array::make_array_inner; +use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function}; use arrow::array::{ - Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, NullBufferBuilder, - OffsetSizeTrait, + Array, ArrayData, ArrayRef, Capacities, GenericListArray, MutableArrayData, + NullBufferBuilder, OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field}; -use datafusion_common::utils::ListCoercion; +use datafusion_common::utils::{ + base_type, coerced_type_with_base_type_only, ListCoercion, +}; use datafusion_common::Result; use datafusion_common::{ cast::as_generic_list_array, - exec_err, not_impl_err, plan_err, + exec_err, plan_err, utils::{list_ndims, take_function_args}, }; +use datafusion_expr::binary::type_union_resolution; use datafusion_expr::{ - ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, - ScalarUDFImpl, Signature, TypeSignature, Volatility, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; - -use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function}; +use itertools::Itertools; make_udf_expr_and_func!( ArrayAppend, @@ -67,7 +70,7 @@ make_udf_expr_and_func!( ), argument(name = "element", description = "Element to append to the array.") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayAppend { signature: Signature, aliases: Vec, @@ -106,7 +109,12 @@ impl ScalarUDFImpl for ArrayAppend { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + let [array_type, element_type] = take_function_args(self.name(), arg_types)?; + if array_type.is_null() { + Ok(DataType::new_list(element_type.clone(), true)) + } else { + Ok(array_type.clone()) + } } fn invoke_with_args( @@ -151,7 +159,7 @@ make_udf_expr_and_func!( ), argument(name = "element", description = "Element to prepend to the array.") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayPrepend { signature: Signature, aliases: Vec, @@ -166,18 +174,7 @@ impl Default for ArrayPrepend { impl ArrayPrepend { pub fn new() -> Self { Self { - signature: Signature { - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::Array { - arguments: vec![ - ArrayFunctionArgument::Element, - ArrayFunctionArgument::Array, - ], - array_coercion: Some(ListCoercion::FixedSizedListToList), - }, - ), - volatility: Volatility::Immutable, - }, + signature: Signature::element_and_array(Volatility::Immutable), aliases: vec![ String::from("list_prepend"), String::from("array_push_front"), @@ -201,7 +198,12 @@ impl ScalarUDFImpl for ArrayPrepend { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[1].clone()) + let [element_type, array_type] = take_function_args(self.name(), arg_types)?; + if array_type.is_null() { + Ok(DataType::new_list(element_type.clone(), true)) + } else { + Ok(array_type.clone()) + } } fn invoke_with_args( @@ -248,7 +250,7 @@ make_udf_expr_and_func!( description = "Subsequent array column or literal array to concatenate." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayConcat { signature: Signature, aliases: Vec, @@ -263,7 +265,7 @@ impl Default for ArrayConcat { impl ArrayConcat { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable), aliases: vec![ String::from("array_cat"), String::from("list_concat"), @@ -287,39 +289,41 @@ impl ScalarUDFImpl for ArrayConcat { } fn return_type(&self, arg_types: &[DataType]) -> Result { - let mut expr_type = DataType::Null; let mut max_dims = 0; + let mut large_list = false; + let mut element_types = Vec::with_capacity(arg_types.len()); for arg_type in arg_types { - let DataType::List(field) = arg_type else { - return plan_err!( - "The array_concat function can only accept list as the args." - ); - }; - if !field.data_type().equals_datatype(&DataType::Null) { - let dims = list_ndims(arg_type); - expr_type = match max_dims.cmp(&dims) { - Ordering::Greater => expr_type, - Ordering::Equal => { - if expr_type == DataType::Null { - arg_type.clone() - } else if !expr_type.equals_datatype(arg_type) { - return plan_err!( - "It is not possible to concatenate arrays of different types. Expected: {}, got: {}", expr_type, arg_type - ); - } else { - expr_type - } - } - - Ordering::Less => { - max_dims = dims; - arg_type.clone() - } - }; + match arg_type { + DataType::Null | DataType::List(_) | DataType::FixedSizeList(..) => (), + DataType::LargeList(_) => large_list = true, + arg_type => { + return plan_err!("{} does not support type {arg_type}", self.name()) + } } + + max_dims = max_dims.max(list_ndims(arg_type)); + element_types.push(base_type(arg_type)) } - Ok(expr_type) + if max_dims == 0 { + Ok(DataType::Null) + } else if let Some(mut return_type) = type_union_resolution(&element_types) { + for _ in 1..max_dims { + return_type = DataType::new_list(return_type, true) + } + + if large_list { + Ok(DataType::new_large_list(return_type, true)) + } else { + Ok(DataType::new_list(return_type, true)) + } + } else { + plan_err!( + "Failed to unify argument types of {}: [{}]", + self.name(), + arg_types.iter().join(", ") + ) + } } fn invoke_with_args( @@ -333,6 +337,16 @@ impl ScalarUDFImpl for ArrayConcat { &self.aliases } + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let base_type = base_type(&self.return_type(arg_types)?); + let coercion = Some(&ListCoercion::FixedSizedListToList); + let arg_types = arg_types.iter().map(|arg_type| { + coerced_type_with_base_type_only(arg_type, &base_type, coercion) + }); + + Ok(arg_types.collect()) + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -341,24 +355,38 @@ impl ScalarUDFImpl for ArrayConcat { /// Array_concat/Array_cat SQL function pub(crate) fn array_concat_inner(args: &[ArrayRef]) -> Result { if args.is_empty() { - return exec_err!("array_concat expects at least one arguments"); + return exec_err!("array_concat expects at least one argument"); } - let mut new_args = vec![]; + let mut all_null = true; + let mut large_list = false; for arg in args { - let ndim = list_ndims(arg.data_type()); - let base_type = datafusion_common::utils::base_type(arg.data_type()); - if ndim == 0 { - return not_impl_err!("Array is not type '{base_type:?}'."); + match arg.data_type() { + DataType::Null => continue, + DataType::LargeList(_) => large_list = true, + _ => (), } - if !base_type.eq(&DataType::Null) { - new_args.push(Arc::clone(arg)); + if arg.null_count() < arg.len() { + all_null = false; } } - match &args[0].data_type() { - DataType::LargeList(_) => concat_internal::(new_args.as_slice()), - _ => concat_internal::(new_args.as_slice()), + if all_null { + // Return a null array with the same type as the first non-null-type argument + let return_type = args + .iter() + .map(|arg| arg.data_type()) + .find_or_first(|d| !d.is_null()) + .unwrap(); // Safe because args is non-empty + + Ok(arrow::array::make_array(ArrayData::new_null( + return_type, + args[0].len(), + ))) + } else if large_list { + concat_internal::(args) + } else { + concat_internal::(args) } } @@ -427,21 +455,23 @@ fn concat_internal(args: &[ArrayRef]) -> Result { /// Array_append SQL function pub(crate) fn array_append_inner(args: &[ArrayRef]) -> Result { - let [array, _] = take_function_args("array_append", args)?; - + let [array, values] = take_function_args("array_append", args)?; match array.data_type() { + DataType::Null => make_array_inner(&[Arc::clone(values)]), + DataType::List(_) => general_append_and_prepend::(args, true), DataType::LargeList(_) => general_append_and_prepend::(args, true), - _ => general_append_and_prepend::(args, true), + arg_type => exec_err!("array_append does not support type {arg_type}"), } } /// Array_prepend SQL function pub(crate) fn array_prepend_inner(args: &[ArrayRef]) -> Result { - let [_, array] = take_function_args("array_prepend", args)?; - + let [values, array] = take_function_args("array_prepend", args)?; match array.data_type() { + DataType::Null => make_array_inner(&[Arc::clone(values)]), + DataType::List(_) => general_append_and_prepend::(args, false), DataType::LargeList(_) => general_append_and_prepend::(args, false), - _ => general_append_and_prepend::(args, false), + arg_type => exec_err!("array_prepend does not support type {arg_type}"), } } diff --git a/datafusion/functions-nested/src/dimension.rs b/datafusion/functions-nested/src/dimension.rs index a7d0336414131..b0fc5bee5494d 100644 --- a/datafusion/functions-nested/src/dimension.rs +++ b/datafusion/functions-nested/src/dimension.rs @@ -17,24 +17,26 @@ //! [`ScalarUDFImpl`] definitions for array_dims and array_ndims functions. -use arrow::array::{ - Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array, -}; +use arrow::array::{Array, ArrayRef, ListArray, UInt64Array}; use arrow::datatypes::{ DataType, - DataType::{FixedSizeList, LargeList, List, UInt64}, - Field, UInt64Type, + DataType::{FixedSizeList, LargeList, List, Null, UInt64}, + UInt64Type, }; use std::any::Any; -use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result}; +use datafusion_common::cast::{ + as_fixed_size_list_array, as_large_list_array, as_list_array, +}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use crate::utils::{compute_array_dims, make_scalar_function}; +use datafusion_common::utils::list_ndims; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; +use itertools::Itertools; use std::sync::Arc; make_udf_expr_and_func!( @@ -62,7 +64,7 @@ make_udf_expr_and_func!( description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayDims { signature: Signature, aliases: Vec, @@ -77,7 +79,7 @@ impl Default for ArrayDims { impl ArrayDims { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature::arrays(1, None, Volatility::Immutable), aliases: vec!["list_dims".to_string()], } } @@ -95,15 +97,8 @@ impl ScalarUDFImpl for ArrayDims { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => { - List(Arc::new(Field::new_list_field(UInt64, true))) - } - _ => { - return plan_err!("The array_dims function can only accept List/LargeList/FixedSizeList."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::new_list(UInt64, true)) } fn invoke_with_args( @@ -148,7 +143,7 @@ make_udf_expr_and_func!( ), argument(name = "element", description = "Array element.") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(super) struct ArrayNdims { signature: Signature, aliases: Vec, @@ -156,7 +151,7 @@ pub(super) struct ArrayNdims { impl ArrayNdims { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature::arrays(1, None, Volatility::Immutable), aliases: vec![String::from("list_ndims")], } } @@ -174,13 +169,8 @@ impl ScalarUDFImpl for ArrayNdims { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The array_ndims function can only accept List/LargeList/FixedSizeList."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(UInt64) } fn invoke_with_args( @@ -202,61 +192,42 @@ impl ScalarUDFImpl for ArrayNdims { /// Array_dims SQL function pub fn array_dims_inner(args: &[ArrayRef]) -> Result { let [array] = take_function_args("array_dims", args)?; - - let data = match array.data_type() { - List(_) => { - let array = as_list_array(&array)?; - array - .iter() - .map(compute_array_dims) - .collect::>>()? - } - LargeList(_) => { - let array = as_large_list_array(&array)?; - array - .iter() - .map(compute_array_dims) - .collect::>>()? - } - array_type => { - return exec_err!("array_dims does not support type '{array_type:?}'"); + let data: Vec<_> = match array.data_type() { + List(_) => as_list_array(&array)? + .iter() + .map(compute_array_dims) + .try_collect()?, + LargeList(_) => as_large_list_array(&array)? + .iter() + .map(compute_array_dims) + .try_collect()?, + FixedSizeList(..) => as_fixed_size_list_array(&array)? + .iter() + .map(compute_array_dims) + .try_collect()?, + arg_type => { + return exec_err!("array_dims does not support type {arg_type}"); } }; let result = ListArray::from_iter_primitive::(data); - - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(result)) } /// Array_ndims SQL function pub fn array_ndims_inner(args: &[ArrayRef]) -> Result { - let [array_dim] = take_function_args("array_ndims", args)?; + let [array] = take_function_args("array_ndims", args)?; - fn general_list_ndims( - array: &GenericListArray, - ) -> Result { - let mut data = Vec::new(); - let ndims = datafusion_common::utils::list_ndims(array.data_type()); - - for arr in array.iter() { - if arr.is_some() { - data.push(Some(ndims)) - } else { - data.push(None) - } - } - - Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) + fn general_list_ndims(array: &ArrayRef) -> Result { + let ndims = list_ndims(array.data_type()); + let data = vec![ndims; array.len()]; + let result = UInt64Array::new(data.into(), array.nulls().cloned()); + Ok(Arc::new(result)) } - match array_dim.data_type() { - List(_) => { - let array = as_list_array(&array_dim)?; - general_list_ndims::(array) - } - LargeList(_) => { - let array = as_large_list_array(&array_dim)?; - general_list_ndims::(array) - } - array_type => exec_err!("array_ndims does not support type {array_type:?}"), + + match array.data_type() { + Null => Ok(Arc::new(UInt64Array::new_null(array.len()))), + List(_) | LargeList(_) | FixedSizeList(..) => general_list_ndims(array), + arg_type => exec_err!("array_ndims does not support type {arg_type}"), } } diff --git a/datafusion/functions-nested/src/distance.rs b/datafusion/functions-nested/src/distance.rs index cfc7fccdd70c4..e2e38fbd0d836 100644 --- a/datafusion/functions-nested/src/distance.rs +++ b/datafusion/functions-nested/src/distance.rs @@ -23,21 +23,20 @@ use arrow::array::{ }; use arrow::datatypes::{ DataType, - DataType::{FixedSizeList, Float64, LargeList, List}, + DataType::{FixedSizeList, LargeList, List, Null}, }; use datafusion_common::cast::{ as_float32_array, as_float64_array, as_generic_list_array, as_int32_array, as_int64_array, }; -use datafusion_common::utils::coerced_fixed_size_list_to_list; -use datafusion_common::{ - exec_err, internal_datafusion_err, utils::take_function_args, Result, -}; +use datafusion_common::utils::{coerced_type_with_base_type_only, ListCoercion}; +use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; -use datafusion_functions::{downcast_arg, downcast_named_arg}; +use datafusion_functions::downcast_arg; use datafusion_macros::user_doc; +use itertools::Itertools; use std::any::Any; use std::sync::Arc; @@ -70,7 +69,7 @@ make_udf_expr_and_func!( description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayDistance { signature: Signature, aliases: Vec, @@ -104,24 +103,26 @@ impl ScalarUDFImpl for ArrayDistance { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Ok(Float64), - _ => exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."), - } + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { let [_, _] = take_function_args(self.name(), arg_types)?; - let mut result = Vec::new(); - for arg_type in arg_types { - match arg_type { - List(_) | LargeList(_) | FixedSizeList(_, _) => result.push(coerced_fixed_size_list_to_list(arg_type)), - _ => return exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."), + let coercion = Some(&ListCoercion::FixedSizedListToList); + let arg_types = arg_types.iter().map(|arg_type| { + if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + Ok(coerced_type_with_base_type_only( + arg_type, + &DataType::Float64, + coercion, + )) + } else { + plan_err!("{} does not support type {arg_type}", self.name()) } - } + }); - Ok(result) + arg_types.try_collect() } fn invoke_with_args( @@ -142,12 +143,11 @@ impl ScalarUDFImpl for ArrayDistance { pub fn array_distance_inner(args: &[ArrayRef]) -> Result { let [array1, array2] = take_function_args("array_distance", args)?; - - match (&array1.data_type(), &array2.data_type()) { + match (array1.data_type(), array2.data_type()) { (List(_), List(_)) => general_array_distance::(args), (LargeList(_), LargeList(_)) => general_array_distance::(args), - (array_type1, array_type2) => { - exec_err!("array_distance does not support types '{array_type1:?}' and '{array_type2:?}'") + (arg_type1, arg_type2) => { + exec_err!("array_distance does not support types {arg_type1} and {arg_type2}") } } } @@ -243,7 +243,7 @@ fn compute_array_distance( /// Converts an array of any numeric type to a Float64Array. fn convert_to_f64_array(array: &ArrayRef) -> Result { match array.data_type() { - Float64 => Ok(as_float64_array(array)?.clone()), + DataType::Float64 => Ok(as_float64_array(array)?.clone()), DataType::Float32 => { let array = as_float32_array(array)?; let converted: Float64Array = diff --git a/datafusion/functions-nested/src/empty.rs b/datafusion/functions-nested/src/empty.rs index dcefd583e9377..27a90ab0442bc 100644 --- a/datafusion/functions-nested/src/empty.rs +++ b/datafusion/functions-nested/src/empty.rs @@ -18,13 +18,14 @@ //! [`ScalarUDFImpl`] definitions for array_empty function. use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, BooleanArray, OffsetSizeTrait}; +use arrow::array::{Array, ArrayRef, BooleanArray, OffsetSizeTrait}; +use arrow::buffer::BooleanBuffer; use arrow::datatypes::{ DataType, DataType::{Boolean, FixedSizeList, LargeList, List}, }; use datafusion_common::cast::as_generic_list_array; -use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -57,7 +58,7 @@ make_udf_expr_and_func!( description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayEmpty { signature: Signature, aliases: Vec, @@ -71,7 +72,7 @@ impl Default for ArrayEmpty { impl ArrayEmpty { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature::arrays(1, None, Volatility::Immutable), aliases: vec!["array_empty".to_string(), "list_empty".to_string()], } } @@ -89,13 +90,8 @@ impl ScalarUDFImpl for ArrayEmpty { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Boolean, - _ => { - return plan_err!("The array_empty function can only accept List/LargeList/FixedSizeList."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Boolean) } fn invoke_with_args( @@ -117,21 +113,25 @@ impl ScalarUDFImpl for ArrayEmpty { /// Array_empty SQL function pub fn array_empty_inner(args: &[ArrayRef]) -> Result { let [array] = take_function_args("array_empty", args)?; - - let array_type = array.data_type(); - match array_type { + match array.data_type() { List(_) => general_array_empty::(array), LargeList(_) => general_array_empty::(array), - _ => exec_err!("array_empty does not support type '{array_type:?}'."), + FixedSizeList(_, size) => { + let values = if *size == 0 { + BooleanBuffer::new_set(array.len()) + } else { + BooleanBuffer::new_unset(array.len()) + }; + Ok(Arc::new(BooleanArray::new(values, array.nulls().cloned()))) + } + arg_type => exec_err!("array_empty does not support type {arg_type}"), } } fn general_array_empty(array: &ArrayRef) -> Result { - let array = as_generic_list_array::(array)?; - - let builder = array + let result = as_generic_list_array::(array)? .iter() .map(|arr| arr.map(|arr| arr.is_empty())) .collect::(); - Ok(Arc::new(builder)) + Ok(Arc::new(result)) } diff --git a/datafusion/functions-nested/src/except.rs b/datafusion/functions-nested/src/except.rs index 2385f6d12d43e..d6982ab5a2ab0 100644 --- a/datafusion/functions-nested/src/except.rs +++ b/datafusion/functions-nested/src/except.rs @@ -22,7 +22,7 @@ use arrow::array::{cast::AsArray, Array, ArrayRef, GenericListArray, OffsetSizeT use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, FieldRef}; use arrow::row::{RowConverter, SortField}; -use datafusion_common::utils::take_function_args; +use datafusion_common::utils::{take_function_args, ListCoercion}; use datafusion_common::{internal_err, HashSet, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -66,7 +66,7 @@ make_udf_expr_and_func!( description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayExcept { signature: Signature, aliases: Vec, @@ -81,7 +81,11 @@ impl Default for ArrayExcept { impl ArrayExcept { pub fn new() -> Self { Self { - signature: Signature::any(2, Volatility::Immutable), + signature: Signature::arrays( + 2, + Some(ListCoercion::FixedSizedListToList), + Volatility::Immutable, + ), aliases: vec!["list_except".to_string()], } } diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index 321dda55ce097..a46c9c75094c6 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -19,12 +19,12 @@ use arrow::array::{ Array, ArrayRef, ArrowNativeTypeOp, Capacities, GenericListArray, Int64Array, - MutableArrayData, NullBufferBuilder, OffsetSizeTrait, + MutableArrayData, NullArray, NullBufferBuilder, OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType; use arrow::datatypes::{ - DataType::{FixedSizeList, LargeList, List}, + DataType::{FixedSizeList, LargeList, List, Null}, Field, }; use datafusion_common::cast::as_int64_array; @@ -32,8 +32,8 @@ use datafusion_common::cast::as_large_list_array; use datafusion_common::cast::as_list_array; use datafusion_common::utils::ListCoercion; use datafusion_common::{ - exec_err, internal_datafusion_err, plan_err, utils::take_function_args, - DataFusionError, Result, + exec_datafusion_err, exec_err, internal_datafusion_err, plan_err, + utils::take_function_args, Result, }; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, Expr, TypeSignature, @@ -103,7 +103,7 @@ make_udf_expr_and_func!( description = "Index to extract the element from the array." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayElement { signature: Signature, aliases: Vec, @@ -163,13 +163,9 @@ impl ScalarUDFImpl for ArrayElement { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) - | LargeList(field) - | FixedSizeList(field, _) => Ok(field.data_type().clone()), - DataType::Null => Ok(List(Arc::new(Field::new_list_field(DataType::Int64, true)))), - _ => plan_err!( - "ArrayElement can only accept List, LargeList or FixedSizeList as the first argument" - ), + Null => Ok(Null), + List(field) | LargeList(field) => Ok(field.data_type().clone()), + arg_type => plan_err!("{} does not support type {arg_type}", self.name()), } } @@ -200,6 +196,7 @@ fn array_element_inner(args: &[ArrayRef]) -> Result { let [array, indexes] = take_function_args("array_element", args)?; match &array.data_type() { + Null => Ok(Arc::new(NullArray::new(array.len()))), List(_) => { let array = as_list_array(&array)?; let indexes = as_int64_array(&indexes)?; @@ -210,10 +207,9 @@ fn array_element_inner(args: &[ArrayRef]) -> Result { let indexes = as_int64_array(&indexes)?; general_array_element::(array, indexes) } - _ => exec_err!( - "array_element does not support type: {:?}", - array.data_type() - ), + arg_type => { + exec_err!("array_element does not support type {arg_type}") + } } } @@ -225,6 +221,10 @@ where i64: TryInto, { let values = array.values(); + if values.data_type().is_null() { + return Ok(Arc::new(NullArray::new(array.len()))); + } + let original_data = values.to_data(); let capacity = Capacities::Array(original_data.len()); @@ -237,10 +237,7 @@ where i64: TryInto, { let index: O = index.try_into().map_err(|_| { - DataFusionError::Execution(format!( - "array_element got invalid index: {}", - index - )) + exec_datafusion_err!("array_element got invalid index: {index}") })?; // 0 ~ len - 1 let adjusted_zero_index = if index < O::usize_as(0) { @@ -321,7 +318,7 @@ pub fn array_slice(array: Expr, begin: Expr, end: Expr, stride: Option) -> description = "Stride of the array slice. The default is 1." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(super) struct ArraySlice { signature: Signature, aliases: Vec, @@ -338,7 +335,7 @@ impl ArraySlice { ArrayFunctionArgument::Index, ArrayFunctionArgument::Index, ], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ @@ -347,7 +344,7 @@ impl ArraySlice { ArrayFunctionArgument::Index, ArrayFunctionArgument::Index, ], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), ], Volatility::Immutable, @@ -452,7 +449,7 @@ fn array_slice_inner(args: &[ArrayRef]) -> Result { let array = as_large_list_array(&args[0])?; general_array_slice::(array, from_array, to_array, stride) } - _ => exec_err!("array_slice does not support type: {:?}", array_data_type), + _ => exec_err!("array_slice does not support type: {}", array_data_type), } } @@ -664,7 +661,7 @@ where description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(super) struct ArrayPopFront { signature: Signature, aliases: Vec, @@ -673,15 +670,7 @@ pub(super) struct ArrayPopFront { impl ArrayPopFront { pub fn new() -> Self { Self { - signature: Signature { - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Array], - array_coercion: Some(ListCoercion::FixedSizedListToList), - }, - ), - volatility: Volatility::Immutable, - }, + signature: Signature::array(Volatility::Immutable), aliases: vec![String::from("list_pop_front")], } } @@ -731,10 +720,7 @@ fn array_pop_front_inner(args: &[ArrayRef]) -> Result { let array = as_large_list_array(&args[0])?; general_pop_front_list::(array) } - _ => exec_err!( - "array_pop_front does not support type: {:?}", - array_data_type - ), + _ => exec_err!("array_pop_front does not support type: {}", array_data_type), } } @@ -771,7 +757,7 @@ where description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(super) struct ArrayPopBack { signature: Signature, aliases: Vec, @@ -780,15 +766,7 @@ pub(super) struct ArrayPopBack { impl ArrayPopBack { pub fn new() -> Self { Self { - signature: Signature { - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Array], - array_coercion: Some(ListCoercion::FixedSizedListToList), - }, - ), - volatility: Volatility::Immutable, - }, + signature: Signature::array(Volatility::Immutable), aliases: vec![String::from("list_pop_back")], } } @@ -840,7 +818,7 @@ fn array_pop_back_inner(args: &[ArrayRef]) -> Result { general_pop_back_list::(array) } _ => exec_err!( - "array_pop_back does not support type: {:?}", + "array_pop_back does not support type: {}", array.data_type() ), } @@ -879,7 +857,7 @@ where description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(super) struct ArrayAnyValue { signature: Signature, aliases: Vec, @@ -943,7 +921,7 @@ fn array_any_value_inner(args: &[ArrayRef]) -> Result { let array = as_large_list_array(&array)?; general_array_any_value::(array) } - data_type => exec_err!("array_any_value does not support type: {:?}", data_type), + data_type => exec_err!("array_any_value does not support type: {data_type}"), } } diff --git a/datafusion/functions-nested/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs index f288035948dcb..1b74af643c0c0 100644 --- a/datafusion/functions-nested/src/flatten.rs +++ b/datafusion/functions-nested/src/flatten.rs @@ -18,19 +18,16 @@ //! [`ScalarUDFImpl`] definitions for flatten function. use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{ DataType, DataType::{FixedSizeList, LargeList, List, Null}, }; -use datafusion_common::cast::{ - as_generic_list_array, as_large_list_array, as_list_array, -}; +use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, - TypeSignature, Volatility, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -61,7 +58,7 @@ make_udf_expr_and_func!( description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Flatten { signature: Signature, aliases: Vec, @@ -76,13 +73,7 @@ impl Default for Flatten { impl Flatten { pub fn new() -> Self { Self { - signature: Signature { - // TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::RecursiveArray, - ), - volatility: Volatility::Immutable, - }, + signature: Signature::array(Volatility::Immutable), aliases: vec![], } } @@ -102,25 +93,23 @@ impl ScalarUDFImpl for Flatten { } fn return_type(&self, arg_types: &[DataType]) -> Result { - fn get_base_type(data_type: &DataType) -> Result { - match data_type { - List(field) | FixedSizeList(field, _) - if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) => - { - get_base_type(field.data_type()) - } - LargeList(field) if matches!(field.data_type(), LargeList(_)) => { - get_base_type(field.data_type()) + let data_type = match &arg_types[0] { + List(field) => match field.data_type() { + List(field) | FixedSizeList(field, _) => List(Arc::clone(field)), + _ => arg_types[0].clone(), + }, + LargeList(field) => match field.data_type() { + List(field) | LargeList(field) | FixedSizeList(field, _) => { + LargeList(Arc::clone(field)) } - Null | List(_) | LargeList(_) => Ok(data_type.to_owned()), - FixedSizeList(field, _) => Ok(List(Arc::clone(field))), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), - } - } + _ => arg_types[0].clone(), + }, + Null => Null, + _ => exec_err!( + "Not reachable, data_type should be List, LargeList or FixedSizeList" + )?, + }; - let data_type = get_base_type(&arg_types[0])?; Ok(data_type) } @@ -146,14 +135,64 @@ pub fn flatten_inner(args: &[ArrayRef]) -> Result { match array.data_type() { List(_) => { - let list_arr = as_list_array(&array)?; - let flattened_array = flatten_internal::(list_arr.clone(), None)?; - Ok(Arc::new(flattened_array) as ArrayRef) + let (_field, offsets, values, nulls) = + as_list_array(&array)?.clone().into_parts(); + let values = cast_fsl_to_list(values)?; + + match values.data_type() { + List(_) => { + let (inner_field, inner_offsets, inner_values, _) = + as_list_array(&values)?.clone().into_parts(); + let offsets = get_offsets_for_flatten::(inner_offsets, offsets); + let flattened_array = GenericListArray::::new( + inner_field, + offsets, + inner_values, + nulls, + ); + + Ok(Arc::new(flattened_array) as ArrayRef) + } + LargeList(_) => { + exec_err!("flatten does not support type '{:?}'", array.data_type())? + } + _ => Ok(Arc::clone(array) as ArrayRef), + } } LargeList(_) => { - let list_arr = as_large_list_array(&array)?; - let flattened_array = flatten_internal::(list_arr.clone(), None)?; - Ok(Arc::new(flattened_array) as ArrayRef) + let (_field, offsets, values, nulls) = + as_large_list_array(&array)?.clone().into_parts(); + let values = cast_fsl_to_list(values)?; + + match values.data_type() { + List(_) => { + let (inner_field, inner_offsets, inner_values, _) = + as_list_array(&values)?.clone().into_parts(); + let offsets = get_large_offsets_for_flatten(inner_offsets, offsets); + let flattened_array = GenericListArray::::new( + inner_field, + offsets, + inner_values, + nulls, + ); + + Ok(Arc::new(flattened_array) as ArrayRef) + } + LargeList(_) => { + let (inner_field, inner_offsets, inner_values, nulls) = + as_large_list_array(&values)?.clone().into_parts(); + let offsets = get_offsets_for_flatten::(inner_offsets, offsets); + let flattened_array = GenericListArray::::new( + inner_field, + offsets, + inner_values, + nulls, + ); + + Ok(Arc::new(flattened_array) as ArrayRef) + } + _ => Ok(Arc::clone(array) as ArrayRef), + } } Null => Ok(Arc::clone(array)), _ => { @@ -162,37 +201,6 @@ pub fn flatten_inner(args: &[ArrayRef]) -> Result { } } -fn flatten_internal( - list_arr: GenericListArray, - indexes: Option>, -) -> Result> { - let (field, offsets, values, _) = list_arr.clone().into_parts(); - let data_type = field.data_type(); - - match data_type { - // Recursively get the base offsets for flattened array - List(_) | LargeList(_) => { - let sub_list = as_generic_list_array::(&values)?; - if let Some(indexes) = indexes { - let offsets = get_offsets_for_flatten(offsets, indexes); - flatten_internal::(sub_list.clone(), Some(offsets)) - } else { - flatten_internal::(sub_list.clone(), Some(offsets)) - } - } - // Reach the base level, create a new list array - _ => { - if let Some(indexes) = indexes { - let offsets = get_offsets_for_flatten(offsets, indexes); - let list_arr = GenericListArray::::new(field, offsets, values, None); - Ok(list_arr) - } else { - Ok(list_arr) - } - } - } -} - // Create new offsets that are equivalent to `flatten` the array. fn get_offsets_for_flatten( offsets: OffsetBuffer, @@ -205,3 +213,25 @@ fn get_offsets_for_flatten( .collect(); OffsetBuffer::new(offsets.into()) } + +// Create new large offsets that are equivalent to `flatten` the array. +fn get_large_offsets_for_flatten( + offsets: OffsetBuffer, + indexes: OffsetBuffer

, +) -> OffsetBuffer { + let buffer = offsets.into_inner(); + let offsets: Vec = indexes + .iter() + .map(|i| buffer[i.to_usize().unwrap()].to_i64().unwrap()) + .collect(); + OffsetBuffer::new(offsets.into()) +} + +fn cast_fsl_to_list(array: ArrayRef) -> Result { + match array.data_type() { + FixedSizeList(field, _) => { + Ok(arrow::compute::cast(&array, &List(Arc::clone(field)))?) + } + _ => Ok(array), + } +} diff --git a/datafusion/functions-nested/src/length.rs b/datafusion/functions-nested/src/length.rs index 3c3a42da0d692..060a978185e51 100644 --- a/datafusion/functions-nested/src/length.rs +++ b/datafusion/functions-nested/src/length.rs @@ -19,18 +19,22 @@ use crate::utils::make_scalar_function; use arrow::array::{ - Array, ArrayRef, Int64Array, LargeListArray, ListArray, OffsetSizeTrait, UInt64Array, + Array, ArrayRef, FixedSizeListArray, Int64Array, LargeListArray, ListArray, + OffsetSizeTrait, UInt64Array, }; use arrow::datatypes::{ DataType, DataType::{FixedSizeList, LargeList, List, UInt64}, }; -use datafusion_common::cast::{as_generic_list_array, as_int64_array}; -use datafusion_common::{exec_err, internal_datafusion_err, plan_err, Result}; +use datafusion_common::cast::{ + as_fixed_size_list_array, as_generic_list_array, as_int64_array, +}; +use datafusion_common::{exec_err, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; -use datafusion_functions::{downcast_arg, downcast_named_arg}; +use datafusion_functions::downcast_arg; use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -61,7 +65,7 @@ make_udf_expr_and_func!( ), argument(name = "dimension", description = "Array dimension.") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayLength { signature: Signature, aliases: Vec, @@ -76,7 +80,22 @@ impl Default for ArrayLength { impl ArrayLength { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::one_of( + vec![ + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, + }), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ], + array_coercion: None, + }), + ], + Volatility::Immutable, + ), aliases: vec![String::from("list_length")], } } @@ -94,13 +113,8 @@ impl ScalarUDFImpl for ArrayLength { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The array_length function can only accept List/LargeList/FixedSizeList."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(UInt64) } fn invoke_with_args( @@ -119,6 +133,23 @@ impl ScalarUDFImpl for ArrayLength { } } +macro_rules! array_length_impl { + ($array:expr, $dimension:expr) => {{ + let array = $array; + let dimension = match $dimension { + Some(d) => as_int64_array(d)?.clone(), + None => Int64Array::from_value(1, array.len()), + }; + let result = array + .iter() + .zip(dimension.iter()) + .map(|(arr, dim)| compute_array_length(arr, dim)) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) + }}; +} + /// Array_length SQL function pub fn array_length_inner(args: &[ArrayRef]) -> Result { if args.len() != 1 && args.len() != 2 { @@ -128,26 +159,18 @@ pub fn array_length_inner(args: &[ArrayRef]) -> Result { match &args[0].data_type() { List(_) => general_array_length::(args), LargeList(_) => general_array_length::(args), - array_type => exec_err!("array_length does not support type '{array_type:?}'"), + FixedSizeList(_, _) => fixed_size_array_length(args), + array_type => exec_err!("array_length does not support type '{array_type}'"), } } +fn fixed_size_array_length(array: &[ArrayRef]) -> Result { + array_length_impl!(as_fixed_size_list_array(&array[0])?, array.get(1)) +} + /// Dispatch array length computation based on the offset type. fn general_array_length(array: &[ArrayRef]) -> Result { - let list_array = as_generic_list_array::(&array[0])?; - let dimension = if array.len() == 2 { - as_int64_array(&array[1])?.clone() - } else { - Int64Array::from_value(1, list_array.len()) - }; - - let result = list_array - .iter() - .zip(dimension.iter()) - .map(|(arr, dim)| compute_array_length(arr, dim)) - .collect::>()?; - - Ok(Arc::new(result) as ArrayRef) + array_length_impl!(as_generic_list_array::(&array[0])?, array.get(1)) } /// Returns the length of a concrete array dimension @@ -185,6 +208,10 @@ fn compute_array_length( value = downcast_arg!(value, LargeListArray).value(0); current_dimension += 1; } + FixedSizeList(_, _) => { + value = downcast_arg!(value, FixedSizeListArray).value(0); + current_dimension += 1; + } _ => return Ok(None), } } diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index c9a61d98cd446..0a549fb294c6e 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -50,10 +50,11 @@ pub mod flatten; pub mod length; pub mod make_array; pub mod map; +pub mod map_entries; pub mod map_extract; pub mod map_keys; pub mod map_values; -pub mod max; +pub mod min_max; pub mod planner; pub mod position; pub mod range; @@ -95,9 +96,12 @@ pub mod expr_fn { pub use super::flatten::flatten; pub use super::length::array_length; pub use super::make_array::make_array; + pub use super::map_entries::map_entries; pub use super::map_extract::map_extract; pub use super::map_keys::map_keys; pub use super::map_values::map_values; + pub use super::min_max::array_max; + pub use super::min_max::array_min; pub use super::position::array_position; pub use super::position::array_positions; pub use super::range::gen_series; @@ -146,7 +150,8 @@ pub fn all_default_nested_functions() -> Vec> { length::array_length_udf(), distance::array_distance_udf(), flatten::flatten_udf(), - max::array_max_udf(), + min_max::array_max_udf(), + min_max::array_min_udf(), sort::array_sort_udf(), repeat::array_repeat_udf(), resize::array_resize_udf(), @@ -163,6 +168,7 @@ pub fn all_default_nested_functions() -> Vec> { replace::array_replace_all_udf(), replace::array_replace_udf(), map::map_udf(), + map_entries::map_entries_udf(), map_extract::map_extract_udf(), map_keys::map_keys_udf(), map_values::map_values_udf(), @@ -201,8 +207,7 @@ mod tests { for alias in func.aliases() { assert!( names.insert(alias.to_string().to_lowercase()), - "duplicate function name: {}", - alias + "duplicate function name: {alias}" ); } } diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index 4daaafc5a8888..97d64c70cd364 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -28,10 +28,7 @@ use arrow::array::{ }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType; -use arrow::datatypes::{ - DataType::{List, Null}, - Field, -}; +use arrow::datatypes::{DataType::Null, Field}; use datafusion_common::utils::SingleRowListArrayBuilder; use datafusion_common::{plan_err, Result}; use datafusion_expr::binary::{ @@ -42,6 +39,7 @@ use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; +use itertools::Itertools as _; make_udf_expr_and_func!( MakeArray, @@ -67,7 +65,7 @@ make_udf_expr_and_func!( description = "Expression to include in the output array. Can be a constant, column, or function, and any combination of arithmetic or string operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct MakeArray { signature: Signature, aliases: Vec, @@ -105,16 +103,14 @@ impl ScalarUDFImpl for MakeArray { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match arg_types.len() { - 0 => Ok(empty_array_type()), - _ => { - // At this point, all the type in array should be coerced to the same one - Ok(List(Arc::new(Field::new_list_field( - arg_types[0].to_owned(), - true, - )))) - } - } + let element_type = if arg_types.is_empty() { + Null + } else { + // At this point, all the type in array should be coerced to the same one. + arg_types[0].to_owned() + }; + + Ok(DataType::new_list(element_type, true)) } fn invoke_with_args( @@ -129,26 +125,17 @@ impl ScalarUDFImpl for MakeArray { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let mut errors = vec![]; - match try_type_union_resolution_with_struct(arg_types) { - Ok(r) => return Ok(r), - Err(e) => { - errors.push(e); - } + if let Ok(unified) = try_type_union_resolution_with_struct(arg_types) { + return Ok(unified); } - if let Some(new_type) = type_union_resolution(arg_types) { - if new_type.is_null() { - Ok(vec![DataType::Int64; arg_types.len()]) - } else { - Ok(vec![new_type; arg_types.len()]) - } + if let Some(unified) = type_union_resolution(arg_types) { + Ok(vec![unified; arg_types.len()]) } else { plan_err!( - "Fail to find the valid type between {:?} for {}, errors are {:?}", - arg_types, + "Failed to unify argument types of {}: [{}]", self.name(), - errors + arg_types.iter().join(", ") ) } } @@ -158,35 +145,25 @@ impl ScalarUDFImpl for MakeArray { } } -// Empty array is a special case that is useful for many other array functions -pub(super) fn empty_array_type() -> DataType { - List(Arc::new(Field::new_list_field(DataType::Int64, true))) -} - /// `make_array_inner` is the implementation of the `make_array` function. /// Constructs an array using the input `data` as `ArrayRef`. /// Returns a reference-counted `Array` instance result. pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { - let mut data_type = Null; - for arg in arrays { - let arg_data_type = arg.data_type(); - if !arg_data_type.equals_datatype(&Null) { - data_type = arg_data_type.clone(); - break; - } - } + let data_type = arrays.iter().find_map(|arg| { + let arg_type = arg.data_type(); + (!arg_type.is_null()).then_some(arg_type) + }); - match data_type { + let data_type = data_type.unwrap_or(&Null); + if data_type.is_null() { // Either an empty array or all nulls: - Null => { - let length = arrays.iter().map(|a| a.len()).sum(); - // By default Int64 - let array = new_null_array(&DataType::Int64, length); - Ok(Arc::new( - SingleRowListArrayBuilder::new(array).build_list_array(), - )) - } - _ => array_array::(arrays, data_type), + let length = arrays.iter().map(|a| a.len()).sum(); + let array = new_null_array(&Null, length); + Ok(Arc::new( + SingleRowListArrayBuilder::new(array).build_list_array(), + )) + } else { + array_array::(arrays, data_type.clone()) } } diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index 828f2e244112b..03cfdb52c6de7 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -221,7 +221,7 @@ SELECT MAKE_MAP(['key1', 'key2'], ['value1', null]); For `make_map`: The list of values to be mapped to the corresponding keys." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct MapFunc { signature: Signature, } diff --git a/datafusion/functions-nested/src/map_entries.rs b/datafusion/functions-nested/src/map_entries.rs new file mode 100644 index 0000000000000..7d9d103206dbc --- /dev/null +++ b/datafusion/functions-nested/src/map_entries.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for map_entries function. + +use crate::utils::{get_map_entry_field, make_scalar_function}; +use arrow::array::{Array, ArrayRef, ListArray}; +use arrow::datatypes::{DataType, Field, Fields}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_expr::{ + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, +}; +use datafusion_macros::user_doc; +use std::any::Any; +use std::sync::Arc; + +make_udf_expr_and_func!( + MapEntriesFunc, + map_entries, + map, + "Return a list of all entries in the map.", + map_entries_udf +); + +#[user_doc( + doc_section(label = "Map Functions"), + description = "Returns a list of all entries in the map.", + syntax_example = "map_entries(map)", + sql_example = r#"```sql +SELECT map_entries(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[{'key': a, 'value': 1}, {'key': b, 'value': NULL}, {'key': c, 'value': 3}] + +SELECT map_entries(map([100, 5], [42, 43])); +---- +[{'key': 100, 'value': 42}, {'key': 5, 'value': 43}] +```"#, + argument( + name = "map", + description = "Map expression. Can be a constant, column, or function, and any combination of map operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct MapEntriesFunc { + signature: Signature, +} + +impl Default for MapEntriesFunc { + fn default() -> Self { + Self::new() + } +} + +impl MapEntriesFunc { + pub fn new() -> Self { + Self { + signature: Signature::new( + TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for MapEntriesFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "map_entries" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let [map_type] = take_function_args(self.name(), arg_types)?; + let map_fields = get_map_entry_field(map_type)?; + Ok(DataType::List(Arc::new(Field::new_list_field( + DataType::Struct(Fields::from(vec![ + Field::new( + "key", + map_fields.first().unwrap().data_type().clone(), + false, + ), + Field::new( + "value", + map_fields.get(1).unwrap().data_type().clone(), + true, + ), + ])), + false, + )))) + } + + fn invoke_with_args( + &self, + args: datafusion_expr::ScalarFunctionArgs, + ) -> Result { + make_scalar_function(map_entries_inner)(&args.args) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn map_entries_inner(args: &[ArrayRef]) -> Result { + let [map_arg] = take_function_args("map_entries", args)?; + + let map_array = match map_arg.data_type() { + DataType::Map(_, _) => as_map_array(&map_arg)?, + _ => return exec_err!("Argument for map_entries should be a map"), + }; + + Ok(Arc::new(ListArray::new( + Arc::new(Field::new_list_field( + DataType::Struct(Fields::from(vec![ + Field::new("key", map_array.key_type().clone(), false), + Field::new("value", map_array.value_type().clone(), true), + ])), + false, + )), + map_array.offsets().clone(), + Arc::new(map_array.entries().clone()), + map_array.nulls().cloned(), + ))) +} diff --git a/datafusion/functions-nested/src/map_extract.rs b/datafusion/functions-nested/src/map_extract.rs index 55ab8447c54f1..4aab5d7a60d18 100644 --- a/datafusion/functions-nested/src/map_extract.rs +++ b/datafusion/functions-nested/src/map_extract.rs @@ -68,7 +68,7 @@ SELECT map_extract(MAP {'x': 10, 'y': NULL, 'z': 30}, 'y'); description = "Key to extract from the map. Can be a constant, column, or function, any combination of arithmetic or string operators, or a named expression of the previously listed." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct MapExtract { signature: Signature, aliases: Vec, diff --git a/datafusion/functions-nested/src/map_keys.rs b/datafusion/functions-nested/src/map_keys.rs index 0f15c06d86d15..2fc44670d74a2 100644 --- a/datafusion/functions-nested/src/map_keys.rs +++ b/datafusion/functions-nested/src/map_keys.rs @@ -56,7 +56,7 @@ SELECT map_keys(map([100, 5], [42, 43])); description = "Map expression. Can be a constant, column, or function, and any combination of map operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct MapKeysFunc { signature: Signature, } @@ -94,9 +94,10 @@ impl ScalarUDFImpl for MapKeysFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { let [map_type] = take_function_args(self.name(), arg_types)?; let map_fields = get_map_entry_field(map_type)?; + // internal array nullability is true to be in sync with DuckDB Ok(DataType::List(Arc::new(Field::new_list_field( map_fields.first().unwrap().data_type().clone(), - false, + true, )))) } @@ -121,7 +122,8 @@ fn map_keys_inner(args: &[ArrayRef]) -> Result { }; Ok(Arc::new(ListArray::new( - Arc::new(Field::new_list_field(map_array.key_type().clone(), false)), + // internal array nullability is true to be in sync with DuckDB + Arc::new(Field::new_list_field(map_array.key_type().clone(), true)), map_array.offsets().clone(), Arc::clone(map_array.keys()), map_array.nulls().cloned(), diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index f82e4bfa1a897..6ae8a278063da 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -19,15 +19,16 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{Array, ArrayRef, ListArray}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::utils::take_function_args; -use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_common::{cast::as_map_array, exec_err, internal_err, Result}; use datafusion_expr::{ ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; +use std::ops::Deref; use std::sync::Arc; make_udf_expr_and_func!( @@ -56,7 +57,7 @@ SELECT map_values(map([100, 5], [42, 43])); description = "Map expression. Can be a constant, column, or function, and any combination of map operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(crate) struct MapValuesFunc { signature: Signature, } @@ -91,13 +92,23 @@ impl ScalarUDFImpl for MapValuesFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - let [map_type] = take_function_args(self.name(), arg_types)?; - let map_fields = get_map_entry_field(map_type)?; - Ok(DataType::List(Arc::new(Field::new_list_field( - map_fields.last().unwrap().data_type().clone(), - true, - )))) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + let [map_type] = take_function_args(self.name(), args.arg_fields)?; + + Ok(Field::new( + self.name(), + DataType::List(get_map_values_field_as_list_field(map_type.data_type())?), + // Nullable if the map is nullable + args.arg_fields.iter().any(|x| x.is_nullable()), + ) + .into()) } fn invoke_with_args( @@ -121,9 +132,139 @@ fn map_values_inner(args: &[ArrayRef]) -> Result { }; Ok(Arc::new(ListArray::new( - Arc::new(Field::new_list_field(map_array.value_type().clone(), true)), + get_map_values_field_as_list_field(map_arg.data_type())?, map_array.offsets().clone(), Arc::clone(map_array.values()), map_array.nulls().cloned(), ))) } + +fn get_map_values_field_as_list_field(map_type: &DataType) -> Result { + let map_fields = get_map_entry_field(map_type)?; + + let values_field = map_fields + .last() + .unwrap() + .deref() + .clone() + .with_name(Field::LIST_FIELD_DEFAULT_NAME); + + Ok(Arc::new(values_field)) +} + +#[cfg(test)] +mod tests { + use crate::map_values::MapValuesFunc; + use arrow::datatypes::{DataType, Field, FieldRef}; + use datafusion_common::ScalarValue; + use datafusion_expr::ScalarUDFImpl; + use std::sync::Arc; + + #[test] + fn return_type_field() { + fn get_map_field( + is_map_nullable: bool, + is_keys_nullable: bool, + is_values_nullable: bool, + ) -> FieldRef { + Field::new_map( + "something", + "entries", + Arc::new(Field::new("keys", DataType::Utf8, is_keys_nullable)), + Arc::new(Field::new( + "values", + DataType::LargeUtf8, + is_values_nullable, + )), + false, + is_map_nullable, + ) + .into() + } + + fn get_list_field( + name: &str, + is_list_nullable: bool, + list_item_type: DataType, + is_list_items_nullable: bool, + ) -> FieldRef { + Field::new_list( + name, + Arc::new(Field::new_list_field( + list_item_type, + is_list_items_nullable, + )), + is_list_nullable, + ) + .into() + } + + fn get_return_field(field: FieldRef) -> FieldRef { + let func = MapValuesFunc::new(); + let args = datafusion_expr::ReturnFieldArgs { + arg_fields: &[field], + scalar_arguments: &[None::<&ScalarValue>], + }; + + func.return_field_from_args(args).unwrap() + } + + // Test cases: + // + // | Input Map || Expected Output | + // | ------------------------------------------------------ || ----------------------------------------------------- | + // | map nullable | map keys nullable | map values nullable || expected list nullable | expected list items nullable | + // | ------------ | ----------------- | ------------------- || ---------------------- | ---------------------------- | + // | false | false | false || false | false | + // | false | false | true || false | true | + // | false | true | false || false | false | + // | false | true | true || false | true | + // | true | false | false || true | false | + // | true | false | true || true | true | + // | true | true | false || true | false | + // | true | true | true || true | true | + // + // --------------- + // We added the key nullability to show that it does not affect the nullability of the list or the list items. + + assert_eq!( + get_return_field(get_map_field(false, false, false)), + get_list_field("map_values", false, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(false, false, true)), + get_list_field("map_values", false, DataType::LargeUtf8, true) + ); + + assert_eq!( + get_return_field(get_map_field(false, true, false)), + get_list_field("map_values", false, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(false, true, true)), + get_list_field("map_values", false, DataType::LargeUtf8, true) + ); + + assert_eq!( + get_return_field(get_map_field(true, false, false)), + get_list_field("map_values", true, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(true, false, true)), + get_list_field("map_values", true, DataType::LargeUtf8, true) + ); + + assert_eq!( + get_return_field(get_map_field(true, true, false)), + get_list_field("map_values", true, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(true, true, true)), + get_list_field("map_values", true, DataType::LargeUtf8, true) + ); + } +} diff --git a/datafusion/functions-nested/src/max.rs b/datafusion/functions-nested/src/max.rs deleted file mode 100644 index 32957edc62b5c..0000000000000 --- a/datafusion/functions-nested/src/max.rs +++ /dev/null @@ -1,138 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`ScalarUDFImpl`] definitions for array_max function. -use crate::utils::make_scalar_function; -use arrow::array::ArrayRef; -use arrow::datatypes::DataType; -use arrow::datatypes::DataType::List; -use datafusion_common::cast::as_list_array; -use datafusion_common::utils::take_function_args; -use datafusion_common::{exec_err, ScalarValue}; -use datafusion_doc::Documentation; -use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, -}; -use datafusion_functions_aggregate::min_max; -use datafusion_macros::user_doc; -use itertools::Itertools; -use std::any::Any; - -make_udf_expr_and_func!( - ArrayMax, - array_max, - array, - "returns the maximum value in the array.", - array_max_udf -); - -#[user_doc( - doc_section(label = "Array Functions"), - description = "Returns the maximum value in the array.", - syntax_example = "array_max(array)", - sql_example = r#"```sql -> select array_max([3,1,4,2]); -+-----------------------------------------+ -| array_max(List([3,1,4,2])) | -+-----------------------------------------+ -| 4 | -+-----------------------------------------+ -```"#, - argument( - name = "array", - description = "Array expression. Can be a constant, column, or function, and any combination of array operators." - ) -)] -#[derive(Debug)] -pub struct ArrayMax { - signature: Signature, - aliases: Vec, -} - -impl Default for ArrayMax { - fn default() -> Self { - Self::new() - } -} - -impl ArrayMax { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec!["list_max".to_string()], - } - } -} - -impl ScalarUDFImpl for ArrayMax { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - "array_max" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { - match &arg_types[0] { - List(field) => Ok(field.data_type().clone()), - _ => exec_err!("Not reachable, data_type should be List"), - } - } - - fn invoke_with_args( - &self, - args: ScalarFunctionArgs, - ) -> datafusion_common::Result { - make_scalar_function(array_max_inner)(&args.args) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } - - fn documentation(&self) -> Option<&Documentation> { - self.doc() - } -} - -/// array_max SQL function -/// -/// There is one argument for array_max as the array. -/// `array_max(array)` -/// -/// For example: -/// > array_max(\[1, 3, 2]) -> 3 -pub fn array_max_inner(args: &[ArrayRef]) -> datafusion_common::Result { - let [arg1] = take_function_args("array_max", args)?; - - match arg1.data_type() { - List(_) => { - let input_list_array = as_list_array(&arg1)?; - let result_vec = input_list_array - .iter() - .flat_map(|arr| min_max::max_batch(&arr.unwrap())) - .collect_vec(); - ScalarValue::iter_to_array(result_vec) - } - _ => exec_err!("array_max does not support type: {:?}", arg1.data_type()), - } -} diff --git a/datafusion/functions-nested/src/min_max.rs b/datafusion/functions-nested/src/min_max.rs new file mode 100644 index 0000000000000..117cfbeaa2b2c --- /dev/null +++ b/datafusion/functions-nested/src/min_max.rs @@ -0,0 +1,224 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_max function. +use crate::utils::make_scalar_function; +use arrow::array::{ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{LargeList, List}; +use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::utils::take_function_args; +use datafusion_common::Result; +use datafusion_common::{exec_err, plan_err, ScalarValue}; +use datafusion_doc::Documentation; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions_aggregate_common::min_max::{max_batch, min_batch}; +use datafusion_macros::user_doc; +use itertools::Itertools; +use std::any::Any; + +make_udf_expr_and_func!( + ArrayMax, + array_max, + array, + "returns the maximum value in the array.", + array_max_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the maximum value in the array.", + syntax_example = "array_max(array)", + sql_example = r#"```sql +> select array_max([3,1,4,2]); ++-----------------------------------------+ +| array_max(List([3,1,4,2])) | ++-----------------------------------------+ +| 4 | ++-----------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayMax { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayMax { + fn default() -> Self { + Self::new() + } +} + +impl ArrayMax { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec!["list_max".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayMax { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_max" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let [array] = take_function_args(self.name(), arg_types)?; + match array { + List(field) | LargeList(field) => Ok(field.data_type().clone()), + arg_type => plan_err!("{} does not support type {arg_type}", self.name()), + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_max_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +/// array_max SQL function +/// +/// There is one argument for array_max as the array. +/// `array_max(array)` +/// +/// For example: +/// > array_max(\[1, 3, 2]) -> 3 +pub fn array_max_inner(args: &[ArrayRef]) -> Result { + let [array] = take_function_args("array_max", args)?; + match array.data_type() { + List(_) => array_min_max_helper(as_list_array(array)?, max_batch), + LargeList(_) => array_min_max_helper(as_large_list_array(array)?, max_batch), + arg_type => exec_err!("array_max does not support type: {arg_type}"), + } +} + +make_udf_expr_and_func!( + ArrayMin, + array_min, + array, + "returns the minimum value in the array", + array_min_udf +); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the minimum value in the array.", + syntax_example = "array_min(array)", + sql_example = r#"```sql +> select array_min([3,1,4,2]); ++-----------------------------------------+ +| array_min(List([3,1,4,2])) | ++-----------------------------------------+ +| 1 | ++-----------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +struct ArrayMin { + signature: Signature, +} + +impl Default for ArrayMin { + fn default() -> Self { + Self::new() + } +} + +impl ArrayMin { + fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ArrayMin { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_min" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let [array] = take_function_args(self.name(), arg_types)?; + match array { + List(field) | LargeList(field) => Ok(field.data_type().clone()), + arg_type => plan_err!("{} does not support type {}", self.name(), arg_type), + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_min_inner)(&args.args) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +pub fn array_min_inner(args: &[ArrayRef]) -> Result { + let [array] = take_function_args("array_min", args)?; + match array.data_type() { + List(_) => array_min_max_helper(as_list_array(array)?, min_batch), + LargeList(_) => array_min_max_helper(as_large_list_array(array)?, min_batch), + arg_type => exec_err!("array_min does not support type: {arg_type}"), + } +} + +fn array_min_max_helper( + array: &GenericListArray, + agg_fn: fn(&ArrayRef) -> Result, +) -> Result { + let null_value = ScalarValue::try_from(array.value_type())?; + let result_vec: Vec = array + .iter() + .map(|arr| arr.as_ref().map_or_else(|| Ok(null_value.clone()), agg_fn)) + .try_collect()?; + ScalarValue::iter_to_array(result_vec) +} diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index 369eaecb1905f..f4fa8630a8d37 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -22,11 +22,15 @@ use datafusion_common::ExprSchema; use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams}; +#[cfg(feature = "sql")] +use datafusion_expr::sqlparser::ast::BinaryOperator; use datafusion_expr::AggregateUDF; use datafusion_expr::{ planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, - sqlparser, Expr, ExprSchemable, GetFieldAccess, + Expr, ExprSchemable, GetFieldAccess, }; +#[cfg(not(feature = "sql"))] +use datafusion_expr_common::operator::Operator as BinaryOperator; use datafusion_functions::core::get_field as get_field_inner; use datafusion_functions::expr_fn::get_field; use datafusion_functions_aggregate::nth_value::nth_value_udaf; @@ -51,7 +55,7 @@ impl ExprPlanner for NestedFunctionPlanner { ) -> Result> { let RawBinaryExpr { op, left, right } = expr; - if op == sqlparser::ast::BinaryOperator::StringConcat { + if op == BinaryOperator::StringConcat { let left_type = left.get_type(schema)?; let right_type = right.get_type(schema)?; let left_list_ndims = list_ndims(&left_type); @@ -75,18 +79,14 @@ impl ExprPlanner for NestedFunctionPlanner { } else if left_list_ndims < right_list_ndims { return Ok(PlannerResult::Planned(array_prepend(left, right))); } - } else if matches!( - op, - sqlparser::ast::BinaryOperator::AtArrow - | sqlparser::ast::BinaryOperator::ArrowAt - ) { + } else if matches!(op, BinaryOperator::AtArrow | BinaryOperator::ArrowAt) { let left_type = left.get_type(schema)?; let right_type = right.get_type(schema)?; let left_list_ndims = list_ndims(&left_type); let right_list_ndims = list_ndims(&right_type); // if both are list if left_list_ndims > 0 && right_list_ndims > 0 { - if op == sqlparser::ast::BinaryOperator::AtArrow { + if op == BinaryOperator::AtArrow { // array1 @> array2 -> array_has_all(array1, array2) return Ok(PlannerResult::Planned(array_has_all(left, right))); } else { @@ -108,7 +108,7 @@ impl ExprPlanner for NestedFunctionPlanner { } fn plan_make_map(&self, args: Vec) -> Result>> { - if args.len() % 2 != 0 { + if !args.len().is_multiple_of(2) { return plan_err!("make_map requires an even number of arguments"); } @@ -123,7 +123,7 @@ impl ExprPlanner for NestedFunctionPlanner { } fn plan_any(&self, expr: RawBinaryExpr) -> Result> { - if expr.op == sqlparser::ast::BinaryOperator::Eq { + if expr.op == BinaryOperator::Eq { Ok(PlannerResult::Planned(Expr::ScalarFunction( ScalarFunction::new_udf( array_has_udf(), diff --git a/datafusion/functions-nested/src/position.rs b/datafusion/functions-nested/src/position.rs index b186b65407c32..dae946def8f53 100644 --- a/datafusion/functions-nested/src/position.rs +++ b/datafusion/functions-nested/src/position.rs @@ -52,7 +52,7 @@ make_udf_expr_and_func!( #[user_doc( doc_section(label = "Array Functions"), - description = "Returns the position of the first occurrence of the specified element in the array.", + description = "Returns the position of the first occurrence of the specified element in the array, or NULL if not found.", syntax_example = "array_position(array, element)\narray_position(array, element, index)", sql_example = r#"```sql > select array_position([1, 2, 2, 3, 1, 4], 2); @@ -76,9 +76,12 @@ make_udf_expr_and_func!( name = "element", description = "Element to search for position in the array." ), - argument(name = "index", description = "Index at which to start searching.") + argument( + name = "index", + description = "Index at which to start searching (1-indexed)." + ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayPosition { signature: Signature, aliases: Vec, @@ -144,7 +147,7 @@ pub fn array_position_inner(args: &[ArrayRef]) -> Result { match &args[0].data_type() { List(_) => general_position_dispatch::(args), LargeList(_) => general_position_dispatch::(args), - array_type => exec_err!("array_position does not support type '{array_type:?}'."), + array_type => exec_err!("array_position does not support type '{array_type}'."), } } fn general_position_dispatch(args: &[ArrayRef]) -> Result { @@ -170,7 +173,7 @@ fn general_position_dispatch(args: &[ArrayRef]) -> Result= arr.len() { + if from < 0 || from as usize > arr.len() { return internal_err!("start_from index out of bounds"); } } else { @@ -242,7 +245,7 @@ make_udf_expr_and_func!( description = "Element to search for position in the array." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(super) struct ArrayPositions { signature: Signature, aliases: Vec, @@ -305,7 +308,7 @@ pub fn array_positions_inner(args: &[ArrayRef]) -> Result { general_positions::(arr, element) } array_type => { - exec_err!("array_positions does not support type '{array_type:?}'.") + exec_err!("array_positions does not support type '{array_type}'.") } } } diff --git a/datafusion/functions-nested/src/range.rs b/datafusion/functions-nested/src/range.rs index 637a78d158ab2..619b0e84c19a7 100644 --- a/datafusion/functions-nested/src/range.rs +++ b/datafusion/functions-nested/src/range.rs @@ -88,7 +88,7 @@ make_udf_expr_and_func!( description = "Increase by step (cannot be 0). Steps less than a day are supported only for timestamp ranges." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Range { signature: Signature, aliases: Vec, @@ -218,7 +218,7 @@ make_udf_expr_and_func!( description = "Increase by step (can not be 0). Steps less than a day are supported only for timestamp ranges." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(super) struct GenSeries { signature: Signature, aliases: Vec, diff --git a/datafusion/functions-nested/src/remove.rs b/datafusion/functions-nested/src/remove.rs index 7f5baa18e7693..d330606cdd894 100644 --- a/datafusion/functions-nested/src/remove.rs +++ b/datafusion/functions-nested/src/remove.rs @@ -26,9 +26,11 @@ use arrow::array::{ use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field}; use datafusion_common::cast::as_int64_array; +use datafusion_common::utils::ListCoercion; use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -63,7 +65,7 @@ make_udf_expr_and_func!( description = "Element to be removed from the array." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayRemove { signature: Signature, aliases: Vec, @@ -147,7 +149,7 @@ make_udf_expr_and_func!( ), argument(name = "max", description = "Number of first occurrences to remove.") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(super) struct ArrayRemoveN { signature: Signature, aliases: Vec, @@ -156,7 +158,17 @@ pub(super) struct ArrayRemoveN { impl ArrayRemoveN { pub fn new() -> Self { Self { - signature: Signature::any(3, Volatility::Immutable), + signature: Signature::new( + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Index, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }), + Volatility::Immutable, + ), aliases: vec!["list_remove_n".to_string()], } } @@ -224,7 +236,7 @@ make_udf_expr_and_func!( description = "Element to be removed from the array." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(super) struct ArrayRemoveAll { signature: Signature, aliases: Vec, @@ -311,7 +323,7 @@ fn array_remove_internal( general_remove::(list_array, element_array, arr_n) } array_type => { - exec_err!("array_remove_all does not support type '{array_type:?}'.") + exec_err!("array_remove_all does not support type '{array_type}'.") } } } diff --git a/datafusion/functions-nested/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs index 26d67ad3113ff..ed66b9e396762 100644 --- a/datafusion/functions-nested/src/repeat.rs +++ b/datafusion/functions-nested/src/repeat.rs @@ -74,7 +74,7 @@ make_udf_expr_and_func!( description = "Value of how many times to repeat the element." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayRepeat { signature: Signature, aliases: Vec, diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index 3dbe672c5b028..59f851a776a18 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -78,7 +78,7 @@ make_udf_expr_and_func!(ArrayReplaceAll, argument(name = "from", description = "Initial element."), argument(name = "to", description = "Final element.") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayReplace { signature: Signature, aliases: Vec, @@ -164,7 +164,7 @@ impl ScalarUDFImpl for ArrayReplace { argument(name = "to", description = "Final element."), argument(name = "max", description = "Number of first occurrences to replace.") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(super) struct ArrayReplaceN { signature: Signature, aliases: Vec, @@ -244,7 +244,7 @@ impl ScalarUDFImpl for ArrayReplaceN { argument(name = "from", description = "Initial element."), argument(name = "to", description = "Final element.") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(super) struct ArrayReplaceAll { signature: Signature, aliases: Vec, @@ -430,7 +430,7 @@ pub(crate) fn array_replace_inner(args: &[ArrayRef]) -> Result { general_replace::(list_array, from, to, arr_n) } DataType::Null => Ok(new_null_array(array.data_type(), 1)), - array_type => exec_err!("array_replace does not support type '{array_type:?}'."), + array_type => exec_err!("array_replace does not support type '{array_type}'."), } } @@ -450,7 +450,7 @@ pub(crate) fn array_replace_n_inner(args: &[ArrayRef]) -> Result { } DataType::Null => Ok(new_null_array(array.data_type(), 1)), array_type => { - exec_err!("array_replace_n does not support type '{array_type:?}'.") + exec_err!("array_replace_n does not support type '{array_type}'.") } } } @@ -471,7 +471,7 @@ pub(crate) fn array_replace_all_inner(args: &[ArrayRef]) -> Result { } DataType::Null => Ok(new_null_array(array.data_type(), 1)), array_type => { - exec_err!("array_replace_all does not support type '{array_type:?}'.") + exec_err!("array_replace_all does not support type '{array_type}'.") } } } diff --git a/datafusion/functions-nested/src/resize.rs b/datafusion/functions-nested/src/resize.rs index 145d7e80043b8..09f67a75fd56a 100644 --- a/datafusion/functions-nested/src/resize.rs +++ b/datafusion/functions-nested/src/resize.rs @@ -26,7 +26,7 @@ use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType; use arrow::datatypes::{ArrowNativeType, Field}; use arrow::datatypes::{ - DataType::{FixedSizeList, LargeList, List}, + DataType::{LargeList, List}, FieldRef, }; use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; @@ -70,7 +70,7 @@ make_udf_expr_and_func!( description = "Defines new elements' value or empty if value is not set." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayResize { signature: Signature, aliases: Vec, @@ -125,7 +125,7 @@ impl ScalarUDFImpl for ArrayResize { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::clone(field))), + List(field) => Ok(List(Arc::clone(field))), LargeList(field) => Ok(LargeList(Arc::clone(field))), DataType::Null => { Ok(List(Arc::new(Field::new_list_field(DataType::Int64, true)))) @@ -191,7 +191,7 @@ pub(crate) fn array_resize_inner(arg: &[ArrayRef]) -> Result { let array = as_large_list_array(&arg[0])?; general_list_resize::(array, new_len, field, new_element) } - array_type => exec_err!("array_resize does not support type '{array_type:?}'."), + array_type => exec_err!("array_resize does not support type '{array_type}'."), } } diff --git a/datafusion/functions-nested/src/reverse.rs b/datafusion/functions-nested/src/reverse.rs index 140cd19aeff9c..8440d890d2528 100644 --- a/datafusion/functions-nested/src/reverse.rs +++ b/datafusion/functions-nested/src/reverse.rs @@ -19,12 +19,15 @@ use crate::utils::make_scalar_function; use arrow::array::{ - Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait, + Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray, MutableArrayData, + OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; -use arrow::datatypes::DataType::{LargeList, List, Null}; +use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null}; use arrow::datatypes::{DataType, FieldRef}; -use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::cast::{ + as_fixed_size_list_array, as_large_list_array, as_list_array, +}; use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -58,7 +61,7 @@ make_udf_expr_and_func!( description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayReverse { signature: Signature, aliases: Vec, @@ -73,7 +76,7 @@ impl Default for ArrayReverse { impl ArrayReverse { pub fn new() -> Self { Self { - signature: Signature::any(1, Volatility::Immutable), + signature: Signature::array(Volatility::Immutable), aliases: vec!["list_reverse".to_string()], } } @@ -125,8 +128,12 @@ pub fn array_reverse_inner(arg: &[ArrayRef]) -> Result { let array = as_large_list_array(input_array)?; general_array_reverse::(array, field) } + FixedSizeList(field, _) => { + let array = as_fixed_size_list_array(input_array)?; + fixed_size_array_reverse(array, field) + } Null => Ok(Arc::clone(input_array)), - array_type => exec_err!("array_reverse does not support type '{array_type:?}'."), + array_type => exec_err!("array_reverse does not support type '{array_type}'."), } } @@ -175,3 +182,40 @@ fn general_array_reverse>( Some(nulls.into()), )?)) } + +fn fixed_size_array_reverse( + array: &FixedSizeListArray, + field: &FieldRef, +) -> Result { + let values = array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + let mut nulls = vec![]; + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + let value_length = array.value_length() as usize; + + for row_index in 0..array.len() { + // skip the null value + if array.is_null(row_index) { + nulls.push(false); + mutable.extend(0, 0, value_length); + continue; + } else { + nulls.push(true); + } + let start = row_index * value_length; + let end = start + value_length; + for idx in (start..end).rev() { + mutable.extend(0, idx, idx + 1); + } + } + + let data = mutable.freeze(); + Ok(Arc::new(FixedSizeListArray::try_new( + Arc::clone(field), + array.value_length(), + arrow::array::make_array(data), + Some(nulls.into()), + )?)) +} diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index a67945b1f1e1e..555767f8f070b 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -17,16 +17,21 @@ //! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions. -use crate::make_array::{empty_array_type, make_array_inner}; use crate::utils::make_scalar_function; -use arrow::array::{new_empty_array, Array, ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow::array::{ + new_null_array, Array, ArrayRef, GenericListArray, LargeListArray, ListArray, + OffsetSizeTrait, +}; use arrow::buffer::OffsetBuffer; use arrow::compute; -use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null}; +use arrow::datatypes::DataType::{LargeList, List, Null}; use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::row::{RowConverter, SortField}; use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, internal_err, utils::take_function_args, Result}; +use datafusion_common::utils::ListCoercion; +use datafusion_common::{ + exec_err, internal_err, plan_err, utils::take_function_args, Result, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -89,7 +94,7 @@ make_udf_expr_and_func!( description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayUnion { signature: Signature, aliases: Vec, @@ -104,7 +109,11 @@ impl Default for ArrayUnion { impl ArrayUnion { pub fn new() -> Self { Self { - signature: Signature::any(2, Volatility::Immutable), + signature: Signature::arrays( + 2, + Some(ListCoercion::FixedSizedListToList), + Volatility::Immutable, + ), aliases: vec![String::from("list_union")], } } @@ -124,8 +133,10 @@ impl ScalarUDFImpl for ArrayUnion { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match (&arg_types[0], &arg_types[1]) { - (&Null, dt) => Ok(dt.clone()), + let [array1, array2] = take_function_args(self.name(), arg_types)?; + match (array1, array2) { + (Null, Null) => Ok(DataType::new_list(Null, true)), + (Null, dt) => Ok(dt.clone()), (dt, Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } @@ -174,7 +185,7 @@ impl ScalarUDFImpl for ArrayUnion { description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(super) struct ArrayIntersect { signature: Signature, aliases: Vec, @@ -183,7 +194,11 @@ pub(super) struct ArrayIntersect { impl ArrayIntersect { pub fn new() -> Self { Self { - signature: Signature::any(2, Volatility::Immutable), + signature: Signature::arrays( + 2, + Some(ListCoercion::FixedSizedListToList), + Volatility::Immutable, + ), aliases: vec![String::from("list_intersect")], } } @@ -203,10 +218,12 @@ impl ScalarUDFImpl for ArrayIntersect { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match (arg_types[0].clone(), arg_types[1].clone()) { - (Null, Null) | (Null, _) => Ok(Null), - (_, Null) => Ok(empty_array_type()), - (dt, _) => Ok(dt), + let [array1, array2] = take_function_args(self.name(), arg_types)?; + match (array1, array2) { + (Null, Null) => Ok(DataType::new_list(Null, true)), + (Null, dt) => Ok(dt.clone()), + (dt, Null) => Ok(dt.clone()), + (dt, _) => Ok(dt.clone()), } } @@ -243,7 +260,7 @@ impl ScalarUDFImpl for ArrayIntersect { description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(super) struct ArrayDistinct { signature: Signature, aliases: Vec, @@ -273,16 +290,11 @@ impl ScalarUDFImpl for ArrayDistinct { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::new( - Field::new_list_field(field.data_type().clone(), true), - ))), - LargeList(field) => Ok(LargeList(Arc::new(Field::new_list_field( - field.data_type().clone(), - true, - )))), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), + List(field) => Ok(DataType::new_list(field.data_type().clone(), true)), + LargeList(field) => { + Ok(DataType::new_large_list(field.data_type().clone(), true)) + } + arg_type => plan_err!("{} does not support type {arg_type}", self.name()), } } @@ -305,24 +317,18 @@ impl ScalarUDFImpl for ArrayDistinct { /// array_distinct SQL function /// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4] fn array_distinct_inner(args: &[ArrayRef]) -> Result { - let [input_array] = take_function_args("array_distinct", args)?; - - // handle null - if input_array.data_type() == &Null { - return Ok(Arc::clone(input_array)); - } - - // handle for list & largelist - match input_array.data_type() { + let [array] = take_function_args("array_distinct", args)?; + match array.data_type() { + Null => Ok(Arc::clone(array)), List(field) => { - let array = as_list_array(&input_array)?; + let array = as_list_array(&array)?; general_array_distinct(array, field) } LargeList(field) => { - let array = as_large_list_array(&input_array)?; + let array = as_large_list_array(&array)?; general_array_distinct(array, field) } - array_type => exec_err!("array_distinct does not support type '{array_type:?}'"), + arg_type => exec_err!("array_distinct does not support type {arg_type}"), } } @@ -347,80 +353,76 @@ fn generic_set_lists( field: Arc, set_op: SetOp, ) -> Result { - if matches!(l.value_type(), Null) { + if l.is_empty() || l.value_type().is_null() { let field = Arc::new(Field::new_list_field(r.value_type(), true)); return general_array_distinct::(r, &field); - } else if matches!(r.value_type(), Null) { + } else if r.is_empty() || r.value_type().is_null() { let field = Arc::new(Field::new_list_field(l.value_type(), true)); return general_array_distinct::(l, &field); } - // Handle empty array at rhs case - // array_union(arr, []) -> arr; - // array_intersect(arr, []) -> []; - if r.value_length(0).is_zero() { - if set_op == SetOp::Union { - return Ok(Arc::new(l.clone()) as ArrayRef); - } else { - return Ok(Arc::new(r.clone()) as ArrayRef); - } - } - if l.value_type() != r.value_type() { return internal_err!("{set_op:?} is not implemented for '{l:?}' and '{r:?}'"); } - let dt = l.value_type(); - let mut offsets = vec![OffsetSize::usize_as(0)]; let mut new_arrays = vec![]; - - let converter = RowConverter::new(vec![SortField::new(dt)])?; + let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; for (first_arr, second_arr) in l.iter().zip(r.iter()) { - if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { - let l_values = converter.convert_columns(&[first_arr])?; - let r_values = converter.convert_columns(&[second_arr])?; - - let l_iter = l_values.iter().sorted().dedup(); - let values_set: HashSet<_> = l_iter.clone().collect(); - let mut rows = if set_op == SetOp::Union { - l_iter.collect::>() - } else { - vec![] - }; - for r_val in r_values.iter().sorted().dedup() { - match set_op { - SetOp::Union => { - if !values_set.contains(&r_val) { - rows.push(r_val); - } + let l_values = if let Some(first_arr) = first_arr { + converter.convert_columns(&[first_arr])? + } else { + converter.convert_columns(&[])? + }; + + let r_values = if let Some(second_arr) = second_arr { + converter.convert_columns(&[second_arr])? + } else { + converter.convert_columns(&[])? + }; + + let l_iter = l_values.iter().sorted().dedup(); + let values_set: HashSet<_> = l_iter.clone().collect(); + let mut rows = if set_op == SetOp::Union { + l_iter.collect() + } else { + vec![] + }; + + for r_val in r_values.iter().sorted().dedup() { + match set_op { + SetOp::Union => { + if !values_set.contains(&r_val) { + rows.push(r_val); } - SetOp::Intersect => { - if values_set.contains(&r_val) { - rows.push(r_val); - } + } + SetOp::Intersect => { + if values_set.contains(&r_val) { + rows.push(r_val); } } } - - let last_offset = match offsets.last().copied() { - Some(offset) => offset, - None => return internal_err!("offsets should not be empty"), - }; - offsets.push(last_offset + OffsetSize::usize_as(rows.len())); - let arrays = converter.convert_rows(rows)?; - let array = match arrays.first() { - Some(array) => Arc::clone(array), - None => { - return internal_err!("{set_op}: failed to get array from rows"); - } - }; - new_arrays.push(array); } + + let last_offset = match offsets.last() { + Some(offset) => *offset, + None => return internal_err!("offsets should not be empty"), + }; + + offsets.push(last_offset + OffsetSize::usize_as(rows.len())); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.first() { + Some(array) => Arc::clone(array), + None => { + return internal_err!("{set_op}: failed to get array from rows"); + } + }; + + new_arrays.push(array); } let offsets = OffsetBuffer::new(offsets.into()); - let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let new_arrays_ref: Vec<_> = new_arrays.iter().map(|v| v.as_ref()).collect(); let values = compute::concat(&new_arrays_ref)?; let arr = GenericListArray::::try_new(field, offsets, values, None)?; Ok(Arc::new(arr)) @@ -431,38 +433,59 @@ fn general_set_op( array2: &ArrayRef, set_op: SetOp, ) -> Result { + fn empty_array(data_type: &DataType, len: usize, large: bool) -> Result { + let field = Arc::new(Field::new_list_field(data_type.clone(), true)); + let values = new_null_array(data_type, len); + if large { + Ok(Arc::new(LargeListArray::try_new( + field, + OffsetBuffer::new_zeroed(len), + values, + None, + )?)) + } else { + Ok(Arc::new(ListArray::try_new( + field, + OffsetBuffer::new_zeroed(len), + values, + None, + )?)) + } + } + match (array1.data_type(), array2.data_type()) { + (Null, Null) => Ok(Arc::new(ListArray::new_null( + Arc::new(Field::new_list_field(Null, true)), + array1.len(), + ))), (Null, List(field)) => { if set_op == SetOp::Intersect { - return Ok(new_empty_array(&Null)); + return empty_array(field.data_type(), array1.len(), false); } let array = as_list_array(&array2)?; general_array_distinct::(array, field) } - (List(field), Null) => { if set_op == SetOp::Intersect { - return make_array_inner(&[]); + return empty_array(field.data_type(), array1.len(), false); } let array = as_list_array(&array1)?; general_array_distinct::(array, field) } (Null, LargeList(field)) => { if set_op == SetOp::Intersect { - return Ok(new_empty_array(&Null)); + return empty_array(field.data_type(), array1.len(), true); } let array = as_large_list_array(&array2)?; general_array_distinct::(array, field) } (LargeList(field), Null) => { if set_op == SetOp::Intersect { - return make_array_inner(&[]); + return empty_array(field.data_type(), array1.len(), true); } let array = as_large_list_array(&array1)?; general_array_distinct::(array, field) } - (Null, Null) => Ok(new_empty_array(&Null)), - (List(field), List(_)) => { let array1 = as_list_array(&array1)?; let array2 = as_list_array(&array2)?; diff --git a/datafusion/functions-nested/src/sort.rs b/datafusion/functions-nested/src/sort.rs index 1db245fe52fed..4a7aa31c755b7 100644 --- a/datafusion/functions-nested/src/sort.rs +++ b/datafusion/functions-nested/src/sort.rs @@ -18,13 +18,16 @@ //! [`ScalarUDFImpl`] definitions for array_sort function. use crate::utils::make_scalar_function; -use arrow::array::{new_null_array, Array, ArrayRef, ListArray, NullBufferBuilder}; +use arrow::array::{ + new_null_array, Array, ArrayRef, GenericListArray, NullBufferBuilder, OffsetSizeTrait, +}; use arrow::buffer::OffsetBuffer; -use arrow::datatypes::DataType::{FixedSizeList, LargeList, List}; -use arrow::datatypes::{DataType, Field}; +use arrow::compute::SortColumn; +use arrow::datatypes::{DataType, FieldRef}; use arrow::{compute, compute::SortOptions}; -use datafusion_common::cast::{as_list_array, as_string_array}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array}; +use datafusion_common::utils::ListCoercion; +use datafusion_common::{exec_err, plan_err, Result}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -73,7 +76,7 @@ make_udf_expr_and_func!( description = "Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`)." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArraySort { signature: Signature, aliases: Vec, @@ -92,14 +95,14 @@ impl ArraySort { vec![ TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ArrayFunctionArgument::Array], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ ArrayFunctionArgument::Array, ArrayFunctionArgument::String, ], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ @@ -107,7 +110,7 @@ impl ArraySort { ArrayFunctionArgument::String, ArrayFunctionArgument::String, ], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), ], Volatility::Immutable, @@ -132,17 +135,16 @@ impl ScalarUDFImpl for ArraySort { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::new( - Field::new_list_field(field.data_type().clone(), true), - ))), - LargeList(field) => Ok(LargeList(Arc::new(Field::new_list_field( - field.data_type().clone(), - true, - )))), DataType::Null => Ok(DataType::Null), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), + DataType::List(field) => { + Ok(DataType::new_list(field.data_type().clone(), true)) + } + DataType::LargeList(field) => { + Ok(DataType::new_large_list(field.data_type().clone(), true)) + } + arg_type => { + plan_err!("{} does not support type {arg_type}", self.name()) + } } } @@ -168,11 +170,15 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result { return exec_err!("array_sort expects one to three arguments"); } + if args[0].is_empty() || args[0].data_type().is_null() { + return Ok(Arc::clone(&args[0])); + } + if args[1..].iter().any(|array| array.is_null(0)) { return Ok(new_null_array(args[0].data_type(), args[0].len())); } - let sort_option = match args.len() { + let sort_options = match args.len() { 1 => None, 2 => { let sort = as_string_array(&args[1])?.value(0); @@ -189,14 +195,36 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result { nulls_first: order_nulls_first(nulls_first)?, }) } - _ => return exec_err!("array_sort expects 1 to 3 arguments"), + // We guard at the top + _ => unreachable!(), }; - let list_array = as_list_array(&args[0])?; - let row_count = list_array.len(); - if row_count == 0 { - return Ok(Arc::clone(&args[0])); + match args[0].data_type() { + DataType::List(field) | DataType::LargeList(field) + if field.data_type().is_null() => + { + Ok(Arc::clone(&args[0])) + } + DataType::List(field) => { + let array = as_list_array(&args[0])?; + array_sort_generic(array, field, sort_options) + } + DataType::LargeList(field) => { + let array = as_large_list_array(&args[0])?; + array_sort_generic(array, field, sort_options) + } + // Signature should prevent this arm ever occurring + _ => exec_err!("array_sort expects list for first argument"), } +} + +/// Array_sort SQL function +pub fn array_sort_generic( + list_array: &GenericListArray, + field: &FieldRef, + sort_options: Option, +) -> Result { + let row_count = list_array.len(); let mut array_lengths = vec![]; let mut arrays = vec![]; @@ -207,17 +235,30 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result { valid.append_null(); } else { let arr_ref = list_array.value(i); - let arr_ref = arr_ref.as_ref(); - let sorted_array = compute::sort(arr_ref, sort_option)?; + // arrow sort kernel does not support Structs, so use + // lexsort_to_indices instead: + // https://github.com/apache/arrow-rs/issues/6911#issuecomment-2562928843 + let sorted_array = match arr_ref.data_type() { + DataType::Struct(_) => { + let sort_columns: Vec = vec![SortColumn { + values: Arc::clone(&arr_ref), + options: sort_options, + }]; + let indices = compute::lexsort_to_indices(&sort_columns, None)?; + compute::take(arr_ref.as_ref(), &indices, None)? + } + _ => { + let arr_ref = arr_ref.as_ref(); + compute::sort(arr_ref, sort_options)? + } + }; array_lengths.push(sorted_array.len()); arrays.push(sorted_array); valid.append_non_null(); } } - // Assume all arrays have the same data type - let data_type = list_array.value_type(); let buffer = valid.finish(); let elements = arrays @@ -226,10 +267,10 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result { .collect::>(); let list_arr = if elements.is_empty() { - ListArray::new_null(Arc::new(Field::new_list_field(data_type, true)), row_count) + GenericListArray::::new_null(Arc::clone(field), row_count) } else { - ListArray::new( - Arc::new(Field::new_list_field(data_type, true)), + GenericListArray::::new( + Arc::clone(field), OffsetBuffer::from_lengths(array_lengths), Arc::new(compute::concat(elements.as_slice())?), buffer, diff --git a/datafusion/functions-nested/src/string.rs b/datafusion/functions-nested/src/string.rs index d60d1a6e4de02..3373f7a9838e1 100644 --- a/datafusion/functions-nested/src/string.rs +++ b/datafusion/functions-nested/src/string.rs @@ -25,9 +25,8 @@ use arrow::array::{ }; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{ - internal_datafusion_err, not_impl_err, plan_err, DataFusionError, Result, -}; +use datafusion_common::utils::ListCoercion; +use datafusion_common::{not_impl_err, DataFusionError, Result}; use std::any::Any; @@ -41,14 +40,17 @@ use arrow::compute::cast; use arrow::datatypes::DataType::{ Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, Utf8View, }; -use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::cast::{ + as_fixed_size_list_array, as_large_list_array, as_list_array, +}; use datafusion_common::exec_err; use datafusion_common::types::logical_string; use datafusion_expr::{ - Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, - TypeSignatureClass, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, Coercion, ColumnarValue, + Documentation, ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, + Volatility, }; -use datafusion_functions::{downcast_arg, downcast_named_arg}; +use datafusion_functions::downcast_arg; use datafusion_macros::user_doc; use std::sync::Arc; @@ -146,7 +148,7 @@ make_udf_expr_and_func!( description = "Optional. String to replace null values in the array. If not provided, nulls will be handled by default behavior." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayToString { signature: Signature, aliases: Vec, @@ -161,7 +163,26 @@ impl Default for ArrayToString { impl ArrayToString { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::one_of( + vec![ + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::String, + ArrayFunctionArgument::String, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }), + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::String, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }), + ], + Volatility::Immutable, + ), aliases: vec![ String::from("list_to_string"), String::from("array_join"), @@ -184,13 +205,8 @@ impl ScalarUDFImpl for ArrayToString { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Utf8, - _ => { - return plan_err!("The array_to_string function can only accept List/LargeList/FixedSizeList."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) } fn invoke_with_args( @@ -242,7 +258,7 @@ make_udf_expr_and_func!( description = "Substring values to be replaced with `NULL`." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(super) struct StringToArray { signature: Signature, aliases: Vec, @@ -284,16 +300,10 @@ impl ScalarUDFImpl for StringToArray { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - Utf8 | Utf8View | LargeUtf8 => { - List(Arc::new(Field::new_list_field(arg_types[0].clone(), true))) - } - _ => { - return plan_err!( - "The string_to_array function can only accept Utf8, Utf8View or LargeUtf8." - ); - } - }) + Ok(List(Arc::new(Field::new_list_field( + arg_types[0].clone(), + true, + )))) } fn invoke_with_args( @@ -370,6 +380,20 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { Ok(arg) } + FixedSizeList(..) => { + let list_array = as_fixed_size_list_array(&arr)?; + for i in 0..list_array.len() { + compute_array_to_string( + arg, + list_array.value(i), + delimiter.clone(), + null_string.clone(), + with_null_string, + )?; + } + + Ok(arg) + } LargeList(..) => { let list_array = as_large_list_array(&arr)?; for i in 0..list_array.len() { @@ -451,9 +475,8 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { Ok(StringArray::from(res)) } - let arr_type = arr.data_type(); - let string_arr = match arr_type { - List(_) | FixedSizeList(_, _) => { + let string_arr = match arr.data_type() { + List(_) => { let list_array = as_list_array(&arr)?; generate_string_array::( list_array, @@ -471,29 +494,8 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { with_null_string, )? } - _ => { - let mut arg = String::from(""); - let mut res: Vec> = Vec::new(); - // delimiter length is 1 - assert_eq!(delimiters.len(), 1); - let delimiter = delimiters[0].unwrap(); - let s = compute_array_to_string( - &mut arg, - Arc::clone(arr), - delimiter.to_string(), - null_string, - with_null_string, - )? - .clone(); - - if !s.is_empty() { - let s = s.strip_suffix(delimiter).unwrap().to_string(); - res.push(Some(s)); - } else { - res.push(Some(s)); - } - StringArray::from(res) - } + // Signature guards against this arm + _ => return exec_err!("array_to_string expects list as first argument"), }; Ok(Arc::new(string_arr)) diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index 74b21a3ceb479..464301b6ffcf0 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -22,17 +22,16 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field, Fields}; use arrow::array::{ - Array, ArrayRef, BooleanArray, GenericListArray, ListArray, OffsetSizeTrait, Scalar, - UInt32Array, + Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, UInt32Array, }; use arrow::buffer::OffsetBuffer; -use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, plan_err, Result, ScalarValue, +use datafusion_common::cast::{ + as_fixed_size_list_array, as_large_list_array, as_list_array, }; +use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::ColumnarValue; -use datafusion_functions::{downcast_arg, downcast_named_arg}; +use itertools::Itertools as _; pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { let data_type = args[0].data_type(); @@ -41,7 +40,10 @@ pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { || arg.data_type().equals_datatype(&DataType::Null) }) { let types = args.iter().map(|arg| arg.data_type()).collect::>(); - return plan_err!("{name} received incompatible types: '{types:?}'."); + return plan_err!( + "{name} received incompatible types: {}", + types.iter().join(", ") + ); } Ok(()) @@ -234,8 +236,16 @@ pub(crate) fn compute_array_dims( loop { match value.data_type() { - DataType::List(..) => { - value = downcast_arg!(value, ListArray).value(0); + DataType::List(_) => { + value = as_list_array(&value)?.value(0); + res.push(Some(value.len() as u64)); + } + DataType::LargeList(_) => { + value = as_large_list_array(&value)?.value(0); + res.push(Some(value.len() as u64)); + } + DataType::FixedSizeList(..) => { + value = as_fixed_size_list_array(&value)?.value(0); res.push(Some(value.len() as u64)); } _ => return Ok(Some(res)), @@ -254,13 +264,14 @@ pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> { } } } - _ => internal_err!("Expected a Map type, got {:?}", data_type), + _ => internal_err!("Expected a Map type, got {data_type}"), } } #[cfg(test)] mod tests { use super::*; + use arrow::array::ListArray; use arrow::datatypes::Int64Type; use datafusion_common::utils::SingleRowListArrayBuilder; diff --git a/datafusion/functions-table/README.md b/datafusion/functions-table/README.md index c4e7a5aff9993..89f589a9584c5 100644 --- a/datafusion/functions-table/README.md +++ b/datafusion/functions-table/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion Table Function Library +# Apache DataFusion Table Function Library -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate contains table functions that can be used in DataFusion queries. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-table/src/generate_series.rs b/datafusion/functions-table/src/generate_series.rs index 5bb56f28bc8d3..d00f3d734d76a 100644 --- a/datafusion/functions-table/src/generate_series.rs +++ b/datafusion/functions-table/src/generate_series.rs @@ -15,8 +15,12 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::Int64Array; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::array::timezone::Tz; +use arrow::array::types::TimestampNanosecondType; +use arrow::array::{ArrayRef, Int64Array, TimestampNanosecondArray}; +use arrow::datatypes::{ + DataType, Field, IntervalMonthDayNano, Schema, SchemaRef, TimeUnit, +}; use arrow::record_batch::RecordBatch; use async_trait::async_trait; use datafusion_catalog::Session; @@ -27,103 +31,428 @@ use datafusion_expr::{Expr, TableType}; use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec}; use datafusion_physical_plan::ExecutionPlan; use parking_lot::RwLock; +use std::any::Any; use std::fmt; +use std::str::FromStr; use std::sync::Arc; +/// Empty generator that produces no rows - used when series arguments contain null values +#[derive(Debug, Clone)] +pub struct Empty { + name: &'static str, +} + +impl Empty { + pub fn name(&self) -> &'static str { + self.name + } +} + +impl LazyBatchGenerator for Empty { + fn as_any(&self) -> &dyn Any { + self + } + + fn generate_next_batch(&mut self) -> Result> { + Ok(None) + } +} + +impl fmt::Display for Empty { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}: empty", self.name) + } +} + +/// Trait for values that can be generated in a series +pub trait SeriesValue: fmt::Debug + Clone + Send + Sync + 'static { + type StepType: fmt::Debug + Clone + Send + Sync; + type ValueType: fmt::Debug + Clone + Send + Sync; + + /// Check if we've reached the end of the series + fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool; + + /// Advance to the next value in the series + fn advance(&mut self, step: &Self::StepType) -> Result<()>; + + /// Create an Arrow array from a vector of values + fn create_array(&self, values: Vec) -> Result; + + /// Convert self to ValueType for array creation + fn to_value_type(&self) -> Self::ValueType; + + /// Display the value for debugging + fn display_value(&self) -> String; +} + +impl SeriesValue for i64 { + type StepType = i64; + type ValueType = i64; + + fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool { + reach_end_int64(*self, end, *step, include_end) + } + + fn advance(&mut self, step: &Self::StepType) -> Result<()> { + *self += step; + Ok(()) + } + + fn create_array(&self, values: Vec) -> Result { + Ok(Arc::new(Int64Array::from(values))) + } + + fn to_value_type(&self) -> Self::ValueType { + *self + } + + fn display_value(&self) -> String { + self.to_string() + } +} + +#[derive(Debug, Clone)] +pub struct TimestampValue { + value: i64, + parsed_tz: Option, + tz_str: Option>, +} + +impl TimestampValue { + pub fn value(&self) -> i64 { + self.value + } + + pub fn tz_str(&self) -> Option<&Arc> { + self.tz_str.as_ref() + } +} + +impl SeriesValue for TimestampValue { + type StepType = IntervalMonthDayNano; + type ValueType = i64; + + fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool { + let step_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0; + + if include_end { + if step_negative { + self.value < end.value + } else { + self.value > end.value + } + } else if step_negative { + self.value <= end.value + } else { + self.value >= end.value + } + } + + fn advance(&mut self, step: &Self::StepType) -> Result<()> { + let tz = self + .parsed_tz + .unwrap_or_else(|| Tz::from_str("+00:00").unwrap()); + let Some(next_ts) = + TimestampNanosecondType::add_month_day_nano(self.value, *step, tz) + else { + return plan_err!( + "Failed to add interval {:?} to timestamp {}", + step, + self.value + ); + }; + self.value = next_ts; + Ok(()) + } + + fn create_array(&self, values: Vec) -> Result { + let array = TimestampNanosecondArray::from(values); + + // Use timezone from self (now we have access to tz through &self) + let array = match self.tz_str.as_ref() { + Some(tz_str) => array.with_timezone(Arc::clone(tz_str)), + None => array, + }; + + Ok(Arc::new(array)) + } + + fn to_value_type(&self) -> Self::ValueType { + self.value + } + + fn display_value(&self) -> String { + self.value.to_string() + } +} + /// Indicates the arguments used for generating a series. #[derive(Debug, Clone)] -enum GenSeriesArgs { +pub enum GenSeriesArgs { /// ContainsNull signifies that at least one argument(start, end, step) was null, thus no series will be generated. - ContainsNull { + ContainsNull { name: &'static str }, + /// Int64Args holds the start, end, and step values for generating integer series when all arguments are not null. + Int64Args { + start: i64, + end: i64, + step: i64, + /// Indicates whether the end value should be included in the series. include_end: bool, name: &'static str, }, - /// AllNotNullArgs holds the start, end, and step values for generating the series when all arguments are not null. - AllNotNullArgs { + /// TimestampArgs holds the start, end, and step values for generating timestamp series when all arguments are not null. + TimestampArgs { start: i64, end: i64, - step: i64, + step: IntervalMonthDayNano, + tz: Option>, + /// Indicates whether the end value should be included in the series. + include_end: bool, + name: &'static str, + }, + /// DateArgs holds the start, end, and step values for generating date series when all arguments are not null. + /// Internally, dates are converted to timestamps and use the timestamp logic. + DateArgs { + start: i64, + end: i64, + step: IntervalMonthDayNano, /// Indicates whether the end value should be included in the series. include_end: bool, name: &'static str, }, } -/// Table that generates a series of integers from `start`(inclusive) to `end`(inclusive), incrementing by step +/// Table that generates a series of integers/timestamps from `start`(inclusive) to `end`, incrementing by step #[derive(Debug, Clone)] -struct GenerateSeriesTable { +pub struct GenerateSeriesTable { schema: SchemaRef, args: GenSeriesArgs, } -/// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive), incrementing by step +impl GenerateSeriesTable { + pub fn new(schema: SchemaRef, args: GenSeriesArgs) -> Self { + Self { schema, args } + } + + pub fn as_generator( + &self, + batch_size: usize, + ) -> Result>> { + let generator: Arc> = match &self.args { + GenSeriesArgs::ContainsNull { name } => Arc::new(RwLock::new(Empty { name })), + GenSeriesArgs::Int64Args { + start, + end, + step, + include_end, + name, + } => Arc::new(RwLock::new(GenericSeriesState { + schema: self.schema(), + start: *start, + end: *end, + step: *step, + current: *start, + batch_size, + include_end: *include_end, + name, + })), + GenSeriesArgs::TimestampArgs { + start, + end, + step, + tz, + include_end, + name, + } => { + let parsed_tz = tz + .as_ref() + .map(|s| Tz::from_str(s.as_ref())) + .transpose() + .map_err(|e| { + datafusion_common::internal_datafusion_err!( + "Failed to parse timezone: {e}" + ) + })? + .unwrap_or_else(|| Tz::from_str("+00:00").unwrap()); + Arc::new(RwLock::new(GenericSeriesState { + schema: self.schema(), + start: TimestampValue { + value: *start, + parsed_tz: Some(parsed_tz), + tz_str: tz.clone(), + }, + end: TimestampValue { + value: *end, + parsed_tz: Some(parsed_tz), + tz_str: tz.clone(), + }, + step: *step, + current: TimestampValue { + value: *start, + parsed_tz: Some(parsed_tz), + tz_str: tz.clone(), + }, + batch_size, + include_end: *include_end, + name, + })) + } + GenSeriesArgs::DateArgs { + start, + end, + step, + include_end, + name, + } => Arc::new(RwLock::new(GenericSeriesState { + schema: self.schema(), + start: TimestampValue { + value: *start, + parsed_tz: None, + tz_str: None, + }, + end: TimestampValue { + value: *end, + parsed_tz: None, + tz_str: None, + }, + step: *step, + current: TimestampValue { + value: *start, + parsed_tz: None, + tz_str: None, + }, + batch_size, + include_end: *include_end, + name, + })), + }; + + Ok(generator) + } +} + #[derive(Debug, Clone)] -struct GenerateSeriesState { +pub struct GenericSeriesState { schema: SchemaRef, - start: i64, // Kept for display - end: i64, - step: i64, + start: T, + end: T, + step: T::StepType, batch_size: usize, - - /// Tracks current position when generating table - current: i64, - /// Indicates whether the end value should be included in the series. + current: T, include_end: bool, name: &'static str, } -impl GenerateSeriesState { - fn reach_end(&self, val: i64) -> bool { - if self.step > 0 { - if self.include_end { - return val > self.end; - } else { - return val >= self.end; - } +impl GenericSeriesState { + pub fn name(&self) -> &'static str { + self.name + } + + pub fn batch_size(&self) -> usize { + self.batch_size + } + + pub fn include_end(&self) -> bool { + self.include_end + } + + pub fn start(&self) -> &T { + &self.start + } + + pub fn end(&self) -> &T { + &self.end + } + + pub fn step(&self) -> &T::StepType { + &self.step + } + + pub fn current(&self) -> &T { + &self.current + } +} + +impl LazyBatchGenerator for GenericSeriesState { + fn as_any(&self) -> &dyn Any { + self + } + + fn generate_next_batch(&mut self) -> Result> { + let mut buf = Vec::with_capacity(self.batch_size); + + while buf.len() < self.batch_size + && !self + .current + .should_stop(self.end.clone(), &self.step, self.include_end) + { + buf.push(self.current.to_value_type()); + self.current.advance(&self.step)?; } - if self.include_end { - val < self.end - } else { - val <= self.end + if buf.is_empty() { + return Ok(None); } + + let array = self.current.create_array(buf)?; + let batch = RecordBatch::try_new(Arc::clone(&self.schema), vec![array])?; + Ok(Some(batch)) } } -/// Detail to display for 'Explain' plan -impl fmt::Display for GenerateSeriesState { +impl fmt::Display for GenericSeriesState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "{}: start={}, end={}, batch_size={}", - self.name, self.start, self.end, self.batch_size + self.name, + self.start.display_value(), + self.end.display_value(), + self.batch_size ) } } -impl LazyBatchGenerator for GenerateSeriesState { - fn generate_next_batch(&mut self) -> Result> { - let mut buf = Vec::with_capacity(self.batch_size); - while buf.len() < self.batch_size && !self.reach_end(self.current) { - buf.push(self.current); - self.current += self.step; +fn reach_end_int64(val: i64, end: i64, step: i64, include_end: bool) -> bool { + if step > 0 { + if include_end { + val > end + } else { + val >= end } - let array = Int64Array::from(buf); + } else if include_end { + val < end + } else { + val <= end + } +} - if array.is_empty() { - return Ok(None); - } +fn validate_interval_step( + step: IntervalMonthDayNano, + start: i64, + end: i64, +) -> Result<()> { + if step.months == 0 && step.days == 0 && step.nanoseconds == 0 { + return plan_err!("Step interval cannot be zero"); + } - let batch = - RecordBatch::try_new(Arc::clone(&self.schema), vec![Arc::new(array)])?; + let step_is_positive = step.months > 0 || step.days > 0 || step.nanoseconds > 0; + let step_is_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0; - Ok(Some(batch)) + if start > end && step_is_positive { + return plan_err!("Start is bigger than end, but increment is positive: Cannot generate infinite series"); } + + if start < end && step_is_negative { + return plan_err!("Start is smaller than end, but increment is negative: Cannot generate infinite series"); + } + + Ok(()) } #[async_trait] impl TableProvider for GenerateSeriesTable { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -138,46 +467,19 @@ impl TableProvider for GenerateSeriesTable { async fn scan( &self, state: &dyn Session, - _projection: Option<&Vec>, + projection: Option<&Vec>, _filters: &[Expr], _limit: Option, ) -> Result> { let batch_size = state.config_options().execution.batch_size; - - let state = match self.args { - // if args have null, then return 0 row - GenSeriesArgs::ContainsNull { include_end, name } => GenerateSeriesState { - schema: self.schema(), - start: 0, - end: 0, - step: 1, - current: 1, - batch_size, - include_end, - name, - }, - GenSeriesArgs::AllNotNullArgs { - start, - end, - step, - include_end, - name, - } => GenerateSeriesState { - schema: self.schema(), - start, - end, - step, - current: start, - batch_size, - include_end, - name, - }, + let schema = match projection { + Some(projection) => Arc::new(self.schema.project(projection)?), + None => self.schema(), }; - Ok(Arc::new(LazyMemoryExec::try_new( - self.schema(), - vec![Arc::new(RwLock::new(state))], - )?)) + let generator = self.as_generator(batch_size)?; + + Ok(Arc::new(LazyMemoryExec::try_new(schema, vec![generator])?)) } } @@ -193,12 +495,44 @@ impl TableFunctionImpl for GenerateSeriesFuncImpl { return plan_err!("{} function requires 1 to 3 arguments", self.name); } + // Determine the data type from the first argument + match &exprs[0] { + Expr::Literal( + // Default to int64 for null + ScalarValue::Null | ScalarValue::Int64(_), + _, + ) => self.call_int64(exprs), + Expr::Literal(s, _) if matches!(s.data_type(), DataType::Timestamp(_, _)) => { + self.call_timestamp(exprs) + } + Expr::Literal(s, _) if matches!(s.data_type(), DataType::Date32) => { + self.call_date(exprs) + } + Expr::Literal(scalar, _) => { + plan_err!( + "Argument #1 must be an INTEGER, TIMESTAMP, DATE or NULL, got {:?}", + scalar.data_type() + ) + } + _ => plan_err!("Arguments must be literals"), + } + } +} + +impl GenerateSeriesFuncImpl { + fn call_int64(&self, exprs: &[Expr]) -> Result> { let mut normalize_args = Vec::new(); - for expr in exprs { + for (expr_index, expr) in exprs.iter().enumerate() { match expr { - Expr::Literal(ScalarValue::Null) => {} - Expr::Literal(ScalarValue::Int64(Some(n))) => normalize_args.push(*n), - _ => return plan_err!("First argument must be an integer literal"), + Expr::Literal(ScalarValue::Null, _) => {} + Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n), + other => { + return plan_err!( + "Argument #{} must be an INTEGER or NULL, got {:?}", + expr_index + 1, + other + ) + } }; } @@ -212,10 +546,7 @@ impl TableFunctionImpl for GenerateSeriesFuncImpl { // contain null return Ok(Arc::new(GenerateSeriesTable { schema, - args: GenSeriesArgs::ContainsNull { - include_end: self.include_end, - name: self.name, - }, + args: GenSeriesArgs::ContainsNull { name: self.name }, })); } @@ -229,20 +560,20 @@ impl TableFunctionImpl for GenerateSeriesFuncImpl { }; if start > end && step > 0 { - return plan_err!("start is bigger than end, but increment is positive: cannot generate infinite series"); + return plan_err!("Start is bigger than end, but increment is positive: Cannot generate infinite series"); } if start < end && step < 0 { - return plan_err!("start is smaller than end, but increment is negative: cannot generate infinite series"); + return plan_err!("Start is smaller than end, but increment is negative: Cannot generate infinite series"); } if step == 0 { - return plan_err!("step cannot be zero"); + return plan_err!("Step cannot be zero"); } Ok(Arc::new(GenerateSeriesTable { schema, - args: GenSeriesArgs::AllNotNullArgs { + args: GenSeriesArgs::Int64Args { start, end, step, @@ -251,6 +582,174 @@ impl TableFunctionImpl for GenerateSeriesFuncImpl { }, })) } + + fn call_timestamp(&self, exprs: &[Expr]) -> Result> { + if exprs.len() != 3 { + return plan_err!( + "{} function with timestamps requires exactly 3 arguments", + self.name + ); + } + + // Parse start timestamp + let (start_ts, tz) = match &exprs[0] { + Expr::Literal(ScalarValue::TimestampNanosecond(ts, tz), _) => { + (*ts, tz.clone()) + } + other => { + return plan_err!( + "First argument must be a timestamp or NULL, got {:?}", + other + ) + } + }; + + // Parse end timestamp + let end_ts = match &exprs[1] { + Expr::Literal(ScalarValue::Null, _) => None, + Expr::Literal(ScalarValue::TimestampNanosecond(ts, _), _) => *ts, + other => { + return plan_err!( + "Second argument must be a timestamp or NULL, got {:?}", + other + ) + } + }; + + // Parse step interval + let step_interval = match &exprs[2] { + Expr::Literal(ScalarValue::Null, _) => None, + Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), _) => *interval, + other => { + return plan_err!( + "Third argument must be an interval or NULL, got {:?}", + other + ) + } + }; + + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), + false, + )])); + + // Check if any argument is null + let (Some(start), Some(end), Some(step)) = (start_ts, end_ts, step_interval) + else { + return Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::ContainsNull { name: self.name }, + })); + }; + + // Validate step interval + validate_interval_step(step, start, end)?; + + Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::TimestampArgs { + start, + end, + step, + tz, + include_end: self.include_end, + name: self.name, + }, + })) + } + + fn call_date(&self, exprs: &[Expr]) -> Result> { + if exprs.len() != 3 { + return plan_err!( + "{} function with dates requires exactly 3 arguments", + self.name + ); + } + + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + )])); + + // Parse start date + let start_date = match &exprs[0] { + Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date, + Expr::Literal(ScalarValue::Date32(None), _) + | Expr::Literal(ScalarValue::Null, _) => { + return Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::ContainsNull { name: self.name }, + })); + } + other => { + return plan_err!( + "First argument must be a date or NULL, got {:?}", + other + ) + } + }; + + // Parse end date + let end_date = match &exprs[1] { + Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date, + Expr::Literal(ScalarValue::Date32(None), _) + | Expr::Literal(ScalarValue::Null, _) => { + return Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::ContainsNull { name: self.name }, + })); + } + other => { + return plan_err!( + "Second argument must be a date or NULL, got {:?}", + other + ) + } + }; + + // Parse step interval + let step_interval = match &exprs[2] { + Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(interval)), _) => { + *interval + } + Expr::Literal(ScalarValue::IntervalMonthDayNano(None), _) + | Expr::Literal(ScalarValue::Null, _) => { + return Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::ContainsNull { name: self.name }, + })); + } + other => { + return plan_err!( + "Third argument must be an interval or NULL, got {:?}", + other + ) + } + }; + + // Convert Date32 (days since epoch) to timestamp nanoseconds (nanoseconds since epoch) + // Date32 is days since 1970-01-01, so multiply by nanoseconds per day + const NANOS_PER_DAY: i64 = 24 * 60 * 60 * 1_000_000_000; + + let start_ts = start_date as i64 * NANOS_PER_DAY; + let end_ts = end_date as i64 * NANOS_PER_DAY; + + // Validate step interval + validate_interval_step(step_interval, start_ts, end_ts)?; + + Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::DateArgs { + start: start_ts, + end: end_ts, + step: step_interval, + include_end: self.include_end, + name: self.name, + }, + })) + } } #[derive(Debug)] diff --git a/datafusion/functions-table/src/lib.rs b/datafusion/functions-table/src/lib.rs index 36fcdc7ede56c..b339a8f4a52f3 100644 --- a/datafusion/functions-table/src/lib.rs +++ b/datafusion/functions-table/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/functions-window-common/README.md b/datafusion/functions-window-common/README.md index de12d25f97319..f2e45880724e0 100644 --- a/datafusion/functions-window-common/README.md +++ b/datafusion/functions-window-common/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion Window Function Common Library +# Apache DataFusion Window Function Common Library -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. -This crate contains common functions for implementing user-defined window functions. +This crate contains common functions for implementing window functions. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-window-common/src/expr.rs b/datafusion/functions-window-common/src/expr.rs index 1d99fe7acf152..774cd5182b30b 100644 --- a/datafusion/functions-window-common/src/expr.rs +++ b/datafusion/functions-window-common/src/expr.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::FieldRef; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; @@ -25,9 +25,9 @@ pub struct ExpressionArgs<'a> { /// The expressions passed as arguments to the user-defined window /// function. input_exprs: &'a [Arc], - /// The corresponding data types of expressions passed as arguments + /// The corresponding fields of expressions passed as arguments /// to the user-defined window function. - input_types: &'a [DataType], + input_fields: &'a [FieldRef], } impl<'a> ExpressionArgs<'a> { @@ -36,17 +36,17 @@ impl<'a> ExpressionArgs<'a> { /// # Arguments /// /// * `input_exprs` - The expressions passed as arguments - /// to the user-defined window function. + /// to the user-defined window function. /// * `input_types` - The data types corresponding to the - /// arguments to the user-defined window function. + /// arguments to the user-defined window function. /// pub fn new( input_exprs: &'a [Arc], - input_types: &'a [DataType], + input_fields: &'a [FieldRef], ) -> Self { Self { input_exprs, - input_types, + input_fields, } } @@ -56,9 +56,9 @@ impl<'a> ExpressionArgs<'a> { self.input_exprs } - /// Returns the [`DataType`]s corresponding to the input expressions + /// Returns the [`FieldRef`]s corresponding to the input expressions /// to the user-defined window function. - pub fn input_types(&self) -> &'a [DataType] { - self.input_types + pub fn input_fields(&self) -> &'a [FieldRef] { + self.input_fields } } diff --git a/datafusion/functions-window-common/src/field.rs b/datafusion/functions-window-common/src/field.rs index 8011b7b0f05f0..8d22efa3bcf44 100644 --- a/datafusion/functions-window-common/src/field.rs +++ b/datafusion/functions-window-common/src/field.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::FieldRef; /// Metadata for defining the result field from evaluating a /// user-defined window function. pub struct WindowUDFFieldArgs<'a> { - /// The data types corresponding to the arguments to the + /// The fields corresponding to the arguments to the /// user-defined window function. - input_types: &'a [DataType], + input_fields: &'a [FieldRef], /// The display name of the user-defined window function. display_name: &'a str, } @@ -32,22 +32,22 @@ impl<'a> WindowUDFFieldArgs<'a> { /// /// # Arguments /// - /// * `input_types` - The data types corresponding to the - /// arguments to the user-defined window function. + /// * `input_fields` - The fields corresponding to the + /// arguments to the user-defined window function. /// * `function_name` - The qualified schema name of the - /// user-defined window function expression. + /// user-defined window function expression. /// - pub fn new(input_types: &'a [DataType], display_name: &'a str) -> Self { + pub fn new(input_fields: &'a [FieldRef], display_name: &'a str) -> Self { WindowUDFFieldArgs { - input_types, + input_fields, display_name, } } - /// Returns the data type of input expressions passed as arguments + /// Returns the field of input expressions passed as arguments /// to the user-defined window function. - pub fn input_types(&self) -> &[DataType] { - self.input_types + pub fn input_fields(&self) -> &[FieldRef] { + self.input_fields } /// Returns the name for the field of the final result of evaluating @@ -56,9 +56,9 @@ impl<'a> WindowUDFFieldArgs<'a> { self.display_name } - /// Returns `Some(DataType)` of input expression at index, otherwise + /// Returns `Some(Field)` of input expression at index, otherwise /// returns `None` if the index is out of bounds. - pub fn get_input_type(&self, index: usize) -> Option { - self.input_types.get(index).cloned() + pub fn get_input_field(&self, index: usize) -> Option { + self.input_fields.get(index).cloned() } } diff --git a/datafusion/functions-window-common/src/lib.rs b/datafusion/functions-window-common/src/lib.rs index 7f668a20a76a6..76341239f6a5a 100644 --- a/datafusion/functions-window-common/src/lib.rs +++ b/datafusion/functions-window-common/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/functions-window-common/src/partition.rs b/datafusion/functions-window-common/src/partition.rs index 64786d2fe7c70..61125e596130b 100644 --- a/datafusion/functions-window-common/src/partition.rs +++ b/datafusion/functions-window-common/src/partition.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::FieldRef; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; @@ -26,9 +26,9 @@ pub struct PartitionEvaluatorArgs<'a> { /// The expressions passed as arguments to the user-defined window /// function. input_exprs: &'a [Arc], - /// The corresponding data types of expressions passed as arguments + /// The corresponding fields of expressions passed as arguments /// to the user-defined window function. - input_types: &'a [DataType], + input_fields: &'a [FieldRef], /// Set to `true` if the user-defined window function is reversed. is_reversed: bool, /// Set to `true` if `IGNORE NULLS` is specified. @@ -41,23 +41,23 @@ impl<'a> PartitionEvaluatorArgs<'a> { /// # Arguments /// /// * `input_exprs` - The expressions passed as arguments - /// to the user-defined window function. + /// to the user-defined window function. /// * `input_types` - The data types corresponding to the - /// arguments to the user-defined window function. + /// arguments to the user-defined window function. /// * `is_reversed` - Set to `true` if and only if the user-defined - /// window function is reversible and is reversed. + /// window function is reversible and is reversed. /// * `ignore_nulls` - Set to `true` when `IGNORE NULLS` is - /// specified. + /// specified. /// pub fn new( input_exprs: &'a [Arc], - input_types: &'a [DataType], + input_fields: &'a [FieldRef], is_reversed: bool, ignore_nulls: bool, ) -> Self { Self { input_exprs, - input_types, + input_fields, is_reversed, ignore_nulls, } @@ -69,10 +69,10 @@ impl<'a> PartitionEvaluatorArgs<'a> { self.input_exprs } - /// Returns the [`DataType`]s corresponding to the input expressions + /// Returns the [`FieldRef`]s corresponding to the input expressions /// to the user-defined window function. - pub fn input_types(&self) -> &'a [DataType] { - self.input_types + pub fn input_fields(&self) -> &'a [FieldRef] { + self.input_fields } /// Returns `true` when the user-defined window function is diff --git a/datafusion/functions-window/Cargo.toml b/datafusion/functions-window/Cargo.toml index e0c17c579b196..23ee608a82675 100644 --- a/datafusion/functions-window/Cargo.toml +++ b/datafusion/functions-window/Cargo.toml @@ -38,6 +38,7 @@ workspace = true name = "datafusion_functions_window" [dependencies] +arrow = { workspace = true } datafusion-common = { workspace = true } datafusion-doc = { workspace = true } datafusion-expr = { workspace = true } @@ -47,6 +48,3 @@ datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = "1.0.15" - -[dev-dependencies] -arrow = { workspace = true } diff --git a/datafusion/functions-window/README.md b/datafusion/functions-window/README.md index 18590983ca473..f2bb9f53f5307 100644 --- a/datafusion/functions-window/README.md +++ b/datafusion/functions-window/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion Window Function Library +# Apache DataFusion Window Function Library -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. -This crate contains user-defined window functions. +This crate contains window function definitions. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-window/src/cume_dist.rs b/datafusion/functions-window/src/cume_dist.rs index d777f7932b0e6..372086b12d5ee 100644 --- a/datafusion/functions-window/src/cume_dist.rs +++ b/datafusion/functions-window/src/cume_dist.rs @@ -17,16 +17,18 @@ //! `cume_dist` window function implementation +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::{ArrayRef, Float64Array}; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; use datafusion_common::Result; use datafusion_expr::{ - Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, + Documentation, LimitEffect, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_macros::user_doc; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use field::WindowUDFFieldArgs; use std::any::Any; use std::fmt::Debug; @@ -43,10 +45,26 @@ define_udwf_and_expr!( /// CumeDist calculates the cume_dist in the window function with order by #[user_doc( doc_section(label = "Ranking Functions"), - description = "Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows).", - syntax_example = "cume_dist()" + description = "Relative rank of the current row: (number of rows preceding or peer with the current row) / (total rows).", + syntax_example = "cume_dist()", + sql_example = r#" +```sql +-- Example usage of the cume_dist window function: +SELECT salary, + cume_dist() OVER (ORDER BY salary) AS cume_dist +FROM employees; + ++--------+-----------+ +| salary | cume_dist | ++--------+-----------+ +| 30000 | 0.33 | +| 50000 | 0.67 | +| 70000 | 1.00 | ++--------+-----------+ +``` +"# )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct CumeDist { signature: Signature, } @@ -86,13 +104,17 @@ impl WindowUDFImpl for CumeDist { Ok(Box::::default()) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Float64, false)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, false).into()) } fn documentation(&self) -> Option<&Documentation> { self.doc() } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } } #[derive(Debug, Default)] @@ -113,7 +135,7 @@ impl PartitionEvaluator for CumeDistEvaluator { let len = range.end - range.start; *acc += len as u64; let value: f64 = (*acc as f64) / scalar; - let result = iter::repeat(value).take(len); + let result = iter::repeat_n(value, len); Some(result) }) .flatten(), diff --git a/datafusion/functions-window/src/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs index 5df20cf5b9808..3910a0be574d8 100644 --- a/datafusion/functions-window/src/lead_lag.rs +++ b/datafusion/functions-window/src/lead_lag.rs @@ -18,22 +18,25 @@ //! `lead` and `lag` window function implementations use crate::utils::{get_scalar_value_from_args, get_signed_integer}; +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL; +use datafusion_doc::window_doc_sections::DOC_SECTION_ANALYTICAL; use datafusion_expr::{ - Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature, - Volatility, WindowUDFImpl, + Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature, + TypeSignature, Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::expr::ExpressionArgs; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr::expressions; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::any::Any; use std::cmp::min; use std::collections::VecDeque; +use std::hash::Hash; use std::ops::{Neg, Range}; use std::sync::{Arc, LazyLock}; @@ -92,8 +95,8 @@ pub fn lead( lead_udwf().call(vec![arg, shift_offset_lit, default_lit]) } -#[derive(Debug)] -enum WindowShiftKind { +#[derive(Debug, PartialEq, Eq, Hash)] +pub enum WindowShiftKind { Lag, Lead, } @@ -118,7 +121,7 @@ impl WindowShiftKind { } /// window shift expression -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct WindowShift { signature: Signature, kind: WindowShiftKind, @@ -146,6 +149,10 @@ impl WindowShift { pub fn lead() -> Self { Self::new(WindowShiftKind::Lead) } + + pub fn kind(&self) -> &WindowShiftKind { + &self.kind + } } static LAG_DOCUMENTATION: LazyLock = LazyLock::new(|| { @@ -157,6 +164,24 @@ static LAG_DOCUMENTATION: LazyLock = LazyLock::new(|| { the value of expression should be retrieved. Defaults to 1.") .with_argument("default", "The default value if the offset is \ not within the partition. Must be of the same type as expression.") + .with_sql_example(r#" +```sql +-- Example usage of the lag window function: +SELECT employee_id, + salary, + lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary +FROM employees; + ++-------------+--------+-------------+ +| employee_id | salary | prev_salary | ++-------------+--------+-------------+ +| 1 | 30000 | 0 | +| 2 | 50000 | 30000 | +| 3 | 70000 | 50000 | +| 4 | 60000 | 70000 | ++-------------+--------+-------------+ +``` +"#) .build() }); @@ -175,6 +200,27 @@ static LEAD_DOCUMENTATION: LazyLock = LazyLock::new(|| { forward the value of expression should be retrieved. Defaults to 1.") .with_argument("default", "The default value if the offset is \ not within the partition. Must be of the same type as expression.") + .with_sql_example(r#" +```sql +-- Example usage of lead window function: +SELECT + employee_id, + department, + salary, + lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary +FROM employees; + ++-------------+-------------+--------+--------------+ +| employee_id | department | salary | next_salary | ++-------------+-------------+--------+--------------+ +| 1 | Sales | 30000 | 50000 | +| 2 | Sales | 50000 | 70000 | +| 3 | Sales | 70000 | 0 | +| 4 | Engineering | 40000 | 60000 | +| 5 | Engineering | 60000 | 0 | ++-------------+-------------+--------+--------------+ +``` +"#) .build() }); @@ -201,7 +247,7 @@ impl WindowUDFImpl for WindowShift { /// /// For more details see: fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { - parse_expr(expr_args.input_exprs(), expr_args.input_types()) + parse_expr(expr_args.input_exprs(), expr_args.input_fields()) .into_iter() .collect::>() } @@ -224,7 +270,7 @@ impl WindowUDFImpl for WindowShift { })?; let default_value = parse_default_value( partition_evaluator_args.input_exprs(), - partition_evaluator_args.input_types(), + partition_evaluator_args.input_fields(), )?; Ok(Box::new(WindowShiftEvaluator { @@ -235,10 +281,14 @@ impl WindowUDFImpl for WindowShift { })) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - let return_type = parse_expr_type(field_args.input_types())?; + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let return_field = parse_expr_field(field_args.input_fields())?; - Ok(Field::new(field_args.name(), return_type, true)) + Ok(return_field + .as_ref() + .clone() + .with_name(field_args.name()) + .into()) } fn reverse_expr(&self) -> ReversedUDWF { @@ -254,6 +304,26 @@ impl WindowUDFImpl for WindowShift { WindowShiftKind::Lead => Some(get_lead_doc()), } } + + fn limit_effect(&self, args: &[Arc]) -> LimitEffect { + if self.kind == WindowShiftKind::Lag { + return LimitEffect::None; + } + match args { + [_, expr, ..] => { + let Some(lit) = expr.as_any().downcast_ref::() + else { + return LimitEffect::Unknown; + }; + let ScalarValue::Int64(Some(amount)) = lit.value() else { + return LimitEffect::Unknown; // we should only get int64 from the parser + }; + LimitEffect::Relative((*amount).max(0) as usize) + } + [_] => LimitEffect::Relative(1), // default value + _ => LimitEffect::Unknown, // invalid arguments + } + } } /// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to @@ -270,58 +340,63 @@ impl WindowUDFImpl for WindowShift { /// For more details see: fn parse_expr( input_exprs: &[Arc], - input_types: &[DataType], + input_fields: &[FieldRef], ) -> Result> { assert!(!input_exprs.is_empty()); - assert!(!input_types.is_empty()); + assert!(!input_fields.is_empty()); let expr = Arc::clone(input_exprs.first().unwrap()); - let expr_type = input_types.first().unwrap(); + let expr_field = input_fields.first().unwrap(); // Handles the most common case where NULL is unexpected - if !expr_type.is_null() { + if !expr_field.data_type().is_null() { return Ok(expr); } let default_value = get_scalar_value_from_args(input_exprs, 2)?; default_value.map_or(Ok(expr), |value| { - ScalarValue::try_from(&value.data_type()).map(|v| { - Arc::new(datafusion_physical_expr::expressions::Literal::new(v)) - as Arc - }) + ScalarValue::try_from(&value.data_type()) + .map(|v| Arc::new(expressions::Literal::new(v)) as Arc) }) } -/// Returns the data type of the default value(if provided) when the +static NULL_FIELD: LazyLock = + LazyLock::new(|| Field::new("value", DataType::Null, true).into()); + +/// Returns the field of the default value(if provided) when the /// expression is `NULL`. /// -/// Otherwise, returns the expression type unchanged. -fn parse_expr_type(input_types: &[DataType]) -> Result { - assert!(!input_types.is_empty()); - let expr_type = input_types.first().unwrap_or(&DataType::Null); +/// Otherwise, returns the expression field unchanged. +fn parse_expr_field(input_fields: &[FieldRef]) -> Result { + assert!(!input_fields.is_empty()); + let expr_field = input_fields.first().unwrap_or(&NULL_FIELD); // Handles the most common case where NULL is unexpected - if !expr_type.is_null() { - return Ok(expr_type.clone()); + if !expr_field.data_type().is_null() { + return Ok(expr_field.as_ref().clone().with_nullable(true).into()); } - let default_value_type = input_types.get(2).unwrap_or(&DataType::Null); - Ok(default_value_type.clone()) + let default_value_field = input_fields.get(2).unwrap_or(&NULL_FIELD); + Ok(default_value_field + .as_ref() + .clone() + .with_nullable(true) + .into()) } /// Handles type coercion and null value refinement for default value /// argument depending on the data type of the input expression. fn parse_default_value( input_exprs: &[Arc], - input_types: &[DataType], + input_types: &[FieldRef], ) -> Result { - let expr_type = parse_expr_type(input_types)?; + let expr_field = parse_expr_field(input_types)?; let unparsed = get_scalar_value_from_args(input_exprs, 2)?; unparsed .filter(|v| !v.data_type().is_null()) - .map(|v| v.cast_to(&expr_type)) - .unwrap_or(ScalarValue::try_from(expr_type)) + .map(|v| v.cast_to(expr_field.data_type())) + .unwrap_or_else(|| ScalarValue::try_from(expr_field.data_type())) } #[derive(Debug)] @@ -666,7 +741,12 @@ mod tests { test_i32_result( WindowShift::lead(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true).into()], + false, + false, + ), [ Some(-2), Some(3), @@ -688,7 +768,12 @@ mod tests { test_i32_result( WindowShift::lag(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true).into()], + false, + false, + ), [ None, Some(1), @@ -713,12 +798,15 @@ mod tests { as Arc; let input_exprs = &[expr, shift_offset, default_value]; - let input_types: &[DataType] = - &[DataType::Int32, DataType::Int32, DataType::Int32]; + let input_fields = [DataType::Int32, DataType::Int32, DataType::Int32] + .into_iter() + .map(|d| Field::new("f", d, true)) + .map(Arc::new) + .collect::>(); test_i32_result( WindowShift::lag(), - PartitionEvaluatorArgs::new(input_exprs, input_types, false, false), + PartitionEvaluatorArgs::new(input_exprs, &input_fields, false, false), [ Some(100), Some(1), diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs index 10e09542d7c5d..139ace4bf7097 100644 --- a/datafusion/functions-window/src/lib.rs +++ b/datafusion/functions-window/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/functions-window/src/macros.rs b/datafusion/functions-window/src/macros.rs index 0a86ba6255330..890ced90a9a21 100644 --- a/datafusion/functions-window/src/macros.rs +++ b/datafusion/functions-window/src/macros.rs @@ -29,17 +29,18 @@ /// # Parameters /// /// * `$UDWF`: The struct which defines the [`Signature`](datafusion_expr::Signature) -/// of the user-defined window function. +/// of the user-defined window function. /// * `$OUT_FN_NAME`: The basename to generate a unique function name like -/// `$OUT_FN_NAME_udwf`. +/// `$OUT_FN_NAME_udwf`. /// * `$DOC`: Doc comments for UDWF. /// * (optional) `$CTOR`: Pass a custom constructor. When omitted it -/// automatically resolves to `$UDWF::default()`. +/// automatically resolves to `$UDWF::default()`. /// /// # Example /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # @@ -56,7 +57,7 @@ /// # /// # assert_eq!(simple_udwf().name(), "simple_user_defined_window_function"); /// # -/// # #[derive(Debug)] +/// # #[derive(Debug, PartialEq, Eq, Hash)] /// # struct SimpleUDWF { /// # signature: Signature, /// # } @@ -85,8 +86,8 @@ /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::Int64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::Int64, false).into()) /// # } /// # } /// # @@ -122,13 +123,13 @@ macro_rules! get_or_init_udwf { /// # Parameters /// /// * `$UDWF`: The struct which defines the [`Signature`] of the -/// user-defined window function. +/// user-defined window function. /// * `$OUT_FN_NAME`: The basename to generate a unique function name like -/// `$OUT_FN_NAME_udwf`. +/// `$OUT_FN_NAME_udwf`. /// * `$DOC`: Doc comments for UDWF. /// * (optional) `[$($PARAM:ident),+]`: An array of 1 or more parameters -/// for the generated function. The type of parameters is [`Expr`]. -/// When omitted this creates a function with zero parameters. +/// for the generated function. The type of parameters is [`Expr`]. +/// When omitted this creates a function with zero parameters. /// /// [`Signature`]: datafusion_expr::Signature /// [`Expr`]: datafusion_expr::Expr @@ -138,6 +139,7 @@ macro_rules! get_or_init_udwf { /// 1. With Zero Parameters /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # use datafusion_functions_window::{create_udwf_expr, get_or_init_udwf}; @@ -169,7 +171,7 @@ macro_rules! get_or_init_udwf { /// # "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" /// # ); /// # -/// # #[derive(Debug)] +/// # #[derive(Debug, PartialEq, Eq, Hash)] /// # struct RowNumber { /// # signature: Signature, /// # } @@ -196,8 +198,8 @@ macro_rules! get_or_init_udwf { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::UInt64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) /// # } /// # } /// ``` @@ -205,6 +207,7 @@ macro_rules! get_or_init_udwf { /// 2. With Multiple Parameters /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # /// # use datafusion_expr::{ /// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, @@ -247,7 +250,7 @@ macro_rules! get_or_init_udwf { /// # "lead(a,Int64(1),NULL) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" /// # ); /// # -/// # #[derive(Debug)] +/// # #[derive(Debug, PartialEq, Eq, Hash)] /// # struct Lead { /// # signature: Signature, /// # } @@ -283,12 +286,12 @@ macro_rules! get_or_init_udwf { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), -/// # field_args.get_input_type(0).unwrap(), +/// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, -/// # )) +/// # ).into()) /// # } /// # } /// ``` @@ -332,15 +335,15 @@ macro_rules! create_udwf_expr { /// # Arguments /// /// * `$UDWF`: The struct which defines the [`Signature`] of the -/// user-defined window function. +/// user-defined window function. /// * `$OUT_FN_NAME`: The basename to generate a unique function name like -/// `$OUT_FN_NAME_udwf`. +/// `$OUT_FN_NAME_udwf`. /// * (optional) `[$($PARAM:ident),+]`: An array of 1 or more parameters -/// for the generated function. The type of parameters is [`Expr`]. -/// When omitted this creates a function with zero parameters. +/// for the generated function. The type of parameters is [`Expr`]. +/// When omitted this creates a function with zero parameters. /// * `$DOC`: Doc comments for UDWF. /// * (optional) `$CTOR`: Pass a custom constructor. When omitted it -/// automatically resolves to `$UDWF::default()`. +/// automatically resolves to `$UDWF::default()`. /// /// [`Signature`]: datafusion_expr::Signature /// [`Expr`]: datafusion_expr::Expr @@ -352,6 +355,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # @@ -375,7 +379,7 @@ macro_rules! create_udwf_expr { /// # /// # assert_eq!(simple_udwf().name(), "simple_user_defined_window_function"); /// # -/// # #[derive(Debug)] +/// # #[derive(Debug, PartialEq, Eq, Hash)] /// # struct SimpleUDWF { /// # signature: Signature, /// # } @@ -404,8 +408,8 @@ macro_rules! create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::Int64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::Int64, false).into()) /// # } /// # } /// # @@ -415,6 +419,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf}; @@ -441,7 +446,7 @@ macro_rules! create_udwf_expr { /// # "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" /// # ); /// # -/// # #[derive(Debug)] +/// # #[derive(Debug, PartialEq, Eq, Hash)] /// # struct RowNumber { /// # signature: Signature, /// # } @@ -468,8 +473,8 @@ macro_rules! create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::UInt64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) /// # } /// # } /// ``` @@ -479,6 +484,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # /// # use datafusion_expr::{ /// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, @@ -518,7 +524,7 @@ macro_rules! create_udwf_expr { /// # "lead(a,Int64(1),NULL) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" /// # ); /// # -/// # #[derive(Debug)] +/// # #[derive(Debug, PartialEq, Eq, Hash)] /// # struct Lead { /// # signature: Signature, /// # } @@ -554,12 +560,12 @@ macro_rules! create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), -/// # field_args.get_input_type(0).unwrap(), +/// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, -/// # )) +/// # ).into()) /// # } /// # } /// ``` @@ -567,6 +573,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # /// # use datafusion_expr::{ /// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, @@ -607,7 +614,7 @@ macro_rules! create_udwf_expr { /// # "lead(a,Int64(1),NULL) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" /// # ); /// # -/// # #[derive(Debug)] +/// # #[derive(Debug, PartialEq, Eq, Hash)] /// # struct Lead { /// # signature: Signature, /// # } @@ -643,12 +650,12 @@ macro_rules! create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), -/// # field_args.get_input_type(0).unwrap(), +/// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, -/// # )) +/// # ).into()) /// # } /// # } /// ``` diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs index 1c781bd8e5f3f..329d8aa5ab178 100644 --- a/datafusion/functions-window/src/nth_value.rs +++ b/datafusion/functions-window/src/nth_value.rs @@ -19,24 +19,26 @@ use crate::utils::{get_scalar_value_from_args, get_signed_integer}; -use std::any::Any; -use std::cmp::Ordering; -use std::fmt::Debug; -use std::ops::Range; -use std::sync::LazyLock; - +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; -use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL; +use datafusion_doc::window_doc_sections::DOC_SECTION_ANALYTICAL; use datafusion_expr::window_state::WindowAggState; use datafusion_expr::{ - Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature, - Volatility, WindowUDFImpl, + Documentation, LimitEffect, Literal, PartitionEvaluator, ReversedUDWF, Signature, + TypeSignature, Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use field::WindowUDFFieldArgs; +use std::any::Any; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::hash::Hash; +use std::ops::Range; +use std::sync::{Arc, LazyLock}; get_or_init_udwf!( First, @@ -76,7 +78,7 @@ pub fn nth_value(arg: datafusion_expr::Expr, n: i64) -> datafusion_expr::Expr { } /// Tag to differentiate special use cases of the NTH_VALUE built-in window function. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum NthValueKind { First, Last, @@ -93,7 +95,7 @@ impl NthValueKind { } } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct NthValue { signature: Signature, kind: NthValueKind, @@ -125,6 +127,10 @@ impl NthValue { pub fn nth() -> Self { Self::new(NthValueKind::Nth) } + + pub fn kind(&self) -> &NthValueKind { + &self.kind + } } static FIRST_VALUE_DOCUMENTATION: LazyLock = LazyLock::new(|| { @@ -135,6 +141,28 @@ static FIRST_VALUE_DOCUMENTATION: LazyLock = LazyLock::new(|| { "first_value(expression)", ) .with_argument("expression", "Expression to operate on") + .with_sql_example( + r#" +```sql +-- Example usage of the first_value window function: +SELECT department, + employee_id, + salary, + first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary +FROM employees; + ++-------------+-------------+--------+------------+ +| department | employee_id | salary | top_salary | ++-------------+-------------+--------+------------+ +| Sales | 1 | 70000 | 70000 | +| Sales | 2 | 50000 | 70000 | +| Sales | 3 | 30000 | 70000 | +| Engineering | 4 | 90000 | 90000 | +| Engineering | 5 | 80000 | 90000 | ++-------------+-------------+--------+------------+ +``` +"#, + ) .build() }); @@ -150,6 +178,25 @@ static LAST_VALUE_DOCUMENTATION: LazyLock = LazyLock::new(|| { "last_value(expression)", ) .with_argument("expression", "Expression to operate on") + .with_sql_example(r#"```sql +-- SQL example of last_value: +SELECT department, + employee_id, + salary, + last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary +FROM employees; + ++-------------+-------------+--------+---------------------+ +| department | employee_id | salary | running_last_salary | ++-------------+-------------+--------+---------------------+ +| Sales | 1 | 30000 | 30000 | +| Sales | 2 | 50000 | 50000 | +| Sales | 3 | 70000 | 70000 | +| Engineering | 4 | 40000 | 40000 | +| Engineering | 5 | 60000 | 60000 | ++-------------+-------------+--------+---------------------+ +``` +"#) .build() }); @@ -160,16 +207,49 @@ fn get_last_value_doc() -> &'static Documentation { static NTH_VALUE_DOCUMENTATION: LazyLock = LazyLock::new(|| { Documentation::builder( DOC_SECTION_ANALYTICAL, - "Returns value evaluated at the row that is the nth row of the window \ - frame (counting from 1); null if no such row.", + "Returns the value evaluated at the nth row of the window frame \ + (counting from 1). Returns NULL if no such row exists.", "nth_value(expression, n)", ) .with_argument( "expression", - "The name the column of which nth \ - value to retrieve", + "The column from which to retrieve the nth value.", + ) + .with_argument( + "n", + "Integer. Specifies the row number (starting from 1) in the window frame.", + ) + .with_sql_example( + r#" +```sql +-- Sample employees table: +CREATE TABLE employees (id INT, salary INT); +INSERT INTO employees (id, salary) VALUES +(1, 30000), +(2, 40000), +(3, 50000), +(4, 60000), +(5, 70000); + +-- Example usage of nth_value: +SELECT nth_value(salary, 2) OVER ( + ORDER BY salary + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +) AS nth_value +FROM employees; + ++-----------+ +| nth_value | ++-----------+ +| 40000 | +| 40000 | +| 40000 | +| 40000 | +| 40000 | ++-----------+ +``` +"#, ) - .with_argument("n", "Integer. Specifies the n in nth") .build() }); @@ -236,11 +316,15 @@ impl WindowUDFImpl for NthValue { })) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - let nullable = true; - let return_type = field_args.input_types().first().unwrap_or(&DataType::Null); + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let return_type = field_args + .input_fields() + .first() + .map(|f| f.data_type()) + .cloned() + .unwrap_or(DataType::Null); - Ok(Field::new(field_args.name(), return_type.clone(), nullable)) + Ok(Field::new(field_args.name(), return_type, true).into()) } fn reverse_expr(&self) -> ReversedUDWF { @@ -258,6 +342,10 @@ impl WindowUDFImpl for NthValue { NthValueKind::Nth => Some(get_nth_value_doc()), } } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::None // NthValue is causal + } } #[derive(Debug, Clone)] @@ -478,7 +566,12 @@ mod tests { let expr = Arc::new(Column::new("c3", 0)) as Arc; test_i32_result( NthValue::first(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true).into()], + false, + false, + ), Int32Array::from(vec![1; 8]).iter().collect::(), ) } @@ -488,7 +581,12 @@ mod tests { let expr = Arc::new(Column::new("c3", 0)) as Arc; test_i32_result( NthValue::last(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true).into()], + false, + false, + ), Int32Array::from(vec![ Some(1), Some(-2), @@ -512,7 +610,7 @@ mod tests { NthValue::nth(), PartitionEvaluatorArgs::new( &[expr, n_value], - &[DataType::Int32], + &[Field::new("f", DataType::Int32, true).into()], false, false, ), @@ -531,7 +629,7 @@ mod tests { NthValue::nth(), PartitionEvaluatorArgs::new( &[expr, n_value], - &[DataType::Int32], + &[Field::new("f", DataType::Int32, true).into()], false, false, ), diff --git a/datafusion/functions-window/src/ntile.rs b/datafusion/functions-window/src/ntile.rs index 180f7ab02c03b..d188db3bbf59e 100644 --- a/datafusion/functions-window/src/ntile.rs +++ b/datafusion/functions-window/src/ntile.rs @@ -17,23 +17,25 @@ //! `ntile` window function implementation -use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; - use crate::utils::{ get_scalar_value_from_args, get_signed_integer, get_unsigned_integer, }; +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::{ArrayRef, UInt64Array}; use datafusion_common::arrow::datatypes::{DataType, Field}; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{exec_datafusion_err, exec_err, Result}; use datafusion_expr::{ - Documentation, Expr, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, + Documentation, Expr, LimitEffect, PartitionEvaluator, Signature, Volatility, + WindowUDFImpl, }; use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_macros::user_doc; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use field::WindowUDFFieldArgs; +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; get_or_init_udwf!( Ntile, @@ -52,9 +54,31 @@ pub fn ntile(arg: Expr) -> Expr { argument( name = "expression", description = "An integer describing the number groups the partition should be split into" - ) + ), + sql_example = r#" +```sql +-- Example usage of the ntile window function: +SELECT employee_id, + salary, + ntile(4) OVER (ORDER BY salary DESC) AS quartile +FROM employees; + ++-------------+--------+----------+ +| employee_id | salary | quartile | ++-------------+--------+----------+ +| 1 | 90000 | 1 | +| 2 | 85000 | 1 | +| 3 | 80000 | 2 | +| 4 | 70000 | 2 | +| 5 | 60000 | 3 | +| 6 | 50000 | 3 | +| 7 | 40000 | 4 | +| 8 | 30000 | 4 | ++-------------+--------+----------+ +``` +"# )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Ntile { signature: Signature, } @@ -107,9 +131,7 @@ impl WindowUDFImpl for Ntile { let scalar_n = get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 0)? .ok_or_else(|| { - DataFusionError::Execution( - "NTILE requires a positive integer".to_string(), - ) + exec_datafusion_err!("NTILE requires a positive integer") })?; if scalar_n.is_null() { @@ -127,15 +149,19 @@ impl WindowUDFImpl for Ntile { Ok(Box::new(NtileEvaluator { n: n as u64 })) } } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { let nullable = false; - Ok(Field::new(field_args.name(), DataType::UInt64, nullable)) + Ok(Field::new(field_args.name(), DataType::UInt64, nullable).into()) } fn documentation(&self) -> Option<&Documentation> { self.doc() } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } } #[derive(Debug)] diff --git a/datafusion/functions-window/src/planner.rs b/datafusion/functions-window/src/planner.rs index 1ddd8b27c4205..84836ad569ff8 100644 --- a/datafusion/functions-window/src/planner.rs +++ b/datafusion/functions-window/src/planner.rs @@ -23,7 +23,7 @@ use datafusion_expr::{ expr_rewriter::NamePreserver, planner::{ExprPlanner, PlannerResult, RawWindowExpr}, utils::COUNT_STAR_EXPANSION, - Expr, ExprFunctionExt, + Expr, }; #[derive(Debug)] @@ -40,23 +40,30 @@ impl ExprPlanner for WindowFunctionPlanner { partition_by, order_by, window_frame, + filter, null_treatment, + distinct, } = raw_expr; - let origin_expr = Expr::WindowFunction(WindowFunction { + let origin_expr = Expr::from(WindowFunction { fun: func_def, params: WindowFunctionParams { args, partition_by, order_by, window_frame, + filter, null_treatment, + distinct, }, }); let saved_name = NamePreserver::new_for_projection().save(&origin_expr); - let Expr::WindowFunction(WindowFunction { + let Expr::WindowFunction(window_fun) = origin_expr else { + unreachable!("") + }; + let WindowFunction { fun, params: WindowFunctionParams { @@ -65,18 +72,19 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, + filter, }, - }) = origin_expr - else { - unreachable!("") - }; + } = *window_fun; let raw_expr = RawWindowExpr { func_def: fun, args, partition_by, order_by, window_frame, + filter, null_treatment, + distinct, }; // TODO: remove the next line after `Expr::Wildcard` is removed @@ -92,19 +100,23 @@ impl ExprPlanner for WindowFunctionPlanner { partition_by, order_by, window_frame, + filter, null_treatment, + distinct, } = raw_expr; - let new_expr = Expr::WindowFunction(WindowFunction::new( - func_def, - vec![Expr::Literal(COUNT_STAR_EXPANSION)], - )) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build()?; - + let new_expr = Expr::from(WindowFunction { + fun: func_def, + params: WindowFunctionParams { + args: vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], + partition_by, + order_by, + window_frame, + filter, + null_treatment, + distinct, + }, + }); let new_expr = saved_name.restore(new_expr); return Ok(PlannerResult::Planned(new_expr)); diff --git a/datafusion/functions-window/src/rank.rs b/datafusion/functions-window/src/rank.rs index bd2edc5722eb6..6d891e76671d7 100644 --- a/datafusion/functions-window/src/rank.rs +++ b/datafusion/functions-window/src/rank.rs @@ -18,13 +18,8 @@ //! Implementation of `rank`, `dense_rank`, and `percent_rank` window functions, //! which can be evaluated at runtime during query execution. -use std::any::Any; -use std::fmt::Debug; -use std::iter; -use std::ops::Range; -use std::sync::{Arc, LazyLock}; - use crate::define_udwf_and_expr; +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::array::{Float64Array, UInt64Array}; use datafusion_common::arrow::compute::SortOptions; @@ -32,13 +27,20 @@ use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; use datafusion_common::utils::get_row_at_idx; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING; +use datafusion_doc::window_doc_sections::DOC_SECTION_RANKING; use datafusion_expr::{ - Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, + Documentation, LimitEffect, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use field::WindowUDFFieldArgs; +use std::any::Any; +use std::fmt::Debug; +use std::hash::Hash; +use std::iter; +use std::ops::Range; +use std::sync::{Arc, LazyLock}; define_udwf_and_expr!( Rank, @@ -62,7 +64,7 @@ define_udwf_and_expr!( ); /// Rank calculates the rank in the window function with order by -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Rank { name: String, signature: Signature, @@ -95,7 +97,7 @@ impl Rank { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum RankType { Basic, Dense, @@ -110,6 +112,26 @@ static RANK_DOCUMENTATION: LazyLock = LazyLock::new(|| { skips ranks for identical values.", "rank()") + .with_sql_example(r#" +```sql +-- Example usage of the rank window function: +SELECT department, + salary, + rank() OVER (PARTITION BY department ORDER BY salary DESC) AS rank +FROM employees; + ++-------------+--------+------+ +| department | salary | rank | ++-------------+--------+------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 2 | +| Sales | 30000 | 4 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+------+ +``` +"#) .build() }); @@ -121,6 +143,25 @@ static DENSE_RANK_DOCUMENTATION: LazyLock = LazyLock::new(|| { Documentation::builder(DOC_SECTION_RANKING, "Returns the rank of the current row without gaps. This function ranks \ rows in a dense manner, meaning consecutive ranks are assigned even for identical \ values.", "dense_rank()") + .with_sql_example(r#" +```sql +-- Example usage of the dense_rank window function: +SELECT department, + salary, + dense_rank() OVER (PARTITION BY department ORDER BY salary DESC) AS dense_rank +FROM employees; + ++-------------+--------+------------+ +| department | salary | dense_rank | ++-------------+--------+------------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 2 | +| Sales | 30000 | 3 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+------------+ +```"#) .build() }); @@ -131,6 +172,21 @@ fn get_dense_rank_doc() -> &'static Documentation { static PERCENT_RANK_DOCUMENTATION: LazyLock = LazyLock::new(|| { Documentation::builder(DOC_SECTION_RANKING, "Returns the percentage rank of the current row within its partition. \ The value ranges from 0 to 1 and is computed as `(rank - 1) / (total_rows - 1)`.", "percent_rank()") + .with_sql_example(r#"```sql + -- Example usage of the percent_rank window function: +SELECT employee_id, + salary, + percent_rank() OVER (ORDER BY salary) AS percent_rank +FROM employees; + ++-------------+--------+---------------+ +| employee_id | salary | percent_rank | ++-------------+--------+---------------+ +| 1 | 30000 | 0.00 | +| 2 | 50000 | 0.50 | +| 3 | 70000 | 1.00 | ++-------------+--------+---------------+ +```"#) .build() }); @@ -161,14 +217,14 @@ impl WindowUDFImpl for Rank { })) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { let return_type = match self.rank_type { RankType::Basic | RankType::Dense => DataType::UInt64, RankType::Percent => DataType::Float64, }; let nullable = false; - Ok(Field::new(field_args.name(), return_type, nullable)) + Ok(Field::new(field_args.name(), return_type, nullable).into()) } fn sort_options(&self) -> Option { @@ -185,6 +241,14 @@ impl WindowUDFImpl for Rank { RankType::Percent => Some(get_percent_rank_doc()), } } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + match self.rank_type { + RankType::Basic => LimitEffect::None, + RankType::Dense => LimitEffect::None, + RankType::Percent => LimitEffect::Unknown, + } + } } /// State for the RANK(rank) built-in window function. @@ -261,7 +325,7 @@ impl PartitionEvaluator for RankEvaluator { .iter() .scan(1_u64, |acc, range| { let len = range.end - range.start; - let result = iter::repeat(*acc).take(len); + let result = iter::repeat_n(*acc, len); *acc += len as u64; Some(result) }) @@ -274,7 +338,7 @@ impl PartitionEvaluator for RankEvaluator { .zip(1u64..) .flat_map(|(range, rank)| { let len = range.end - range.start; - iter::repeat(rank).take(len) + iter::repeat_n(rank, len) }), )), @@ -287,7 +351,7 @@ impl PartitionEvaluator for RankEvaluator { .scan(0_u64, |acc, range| { let len = range.end - range.start; let value = (*acc as f64) / (denominator - 1.0).max(1.0); - let result = iter::repeat(value).take(len); + let result = iter::repeat_n(value, len); *acc += len as u64; Some(result) }) diff --git a/datafusion/functions-window/src/row_number.rs b/datafusion/functions-window/src/row_number.rs index 8f462528dbedc..d7d298cecead8 100644 --- a/datafusion/functions-window/src/row_number.rs +++ b/datafusion/functions-window/src/row_number.rs @@ -17,6 +17,7 @@ //! `row_number` window function implementation +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::array::UInt64Array; use datafusion_common::arrow::compute::SortOptions; @@ -24,15 +25,17 @@ use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ - Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, + Documentation, LimitEffect, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_macros::user_doc; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use field::WindowUDFFieldArgs; use std::any::Any; use std::fmt::Debug; use std::ops::Range; +use std::sync::Arc; define_udwf_and_expr!( RowNumber, @@ -44,9 +47,29 @@ define_udwf_and_expr!( #[user_doc( doc_section(label = "Ranking Functions"), description = "Number of the current row within its partition, counting from 1.", - syntax_example = "row_number()" + syntax_example = "row_number()", + sql_example = r#" +```sql +-- Example usage of the row_number window function: +SELECT department, + salary, + row_number() OVER (PARTITION BY department ORDER BY salary DESC) AS row_num +FROM employees; + ++-------------+--------+---------+ +| department | salary | row_num | ++-------------+--------+---------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 3 | +| Sales | 30000 | 4 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+---------+ +``` +"# )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct RowNumber { signature: Signature, } @@ -86,8 +109,8 @@ impl WindowUDFImpl for RowNumber { Ok(Box::::default()) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::UInt64, false)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) } fn sort_options(&self) -> Option { @@ -100,6 +123,10 @@ impl WindowUDFImpl for RowNumber { fn documentation(&self) -> Option<&Documentation> { self.doc() } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::None + } } /// State for the `row_number` built-in window function. @@ -119,7 +146,7 @@ impl PartitionEvaluator for NumRowsEvaluator { _values: &[ArrayRef], num_rows: usize, ) -> Result { - Ok(std::sync::Arc::new(UInt64Array::from_iter_values( + Ok(Arc::new(UInt64Array::from_iter_values( 1..(num_rows as u64) + 1, ))) } diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 729770b8a65c6..90331fbccaf06 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -80,16 +80,23 @@ log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } rand = { workspace = true } regex = { workspace = true, optional = true } -sha2 = { version = "^0.10.1", optional = true } +sha2 = { version = "^0.10.9", optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } -uuid = { version = "1.16", features = ["v4"], optional = true } +uuid = { version = "1.18", features = ["v4"], optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } criterion = { workspace = true } +ctor = { workspace = true } +env_logger = { workspace = true } rand = { workspace = true } tokio = { workspace = true, features = ["macros", "rt", "sync"] } +[[bench]] +harness = false +name = "ascii" +required-features = ["string_expressions"] + [[bench]] harness = false name = "concat" diff --git a/datafusion/functions/README.md b/datafusion/functions/README.md index a610d135c0f68..dee1330422727 100644 --- a/datafusion/functions/README.md +++ b/datafusion/functions/README.md @@ -17,11 +17,17 @@ under the License. --> -# DataFusion Function Library +# Apache DataFusion Function Library -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate contains packages of function that can be used to customize the functionality of DataFusion. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/functions/benches/ascii.rs b/datafusion/functions/benches/ascii.rs new file mode 100644 index 0000000000000..55471817d2778 --- /dev/null +++ b/datafusion/functions/benches/ascii.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; +mod helper; + +use arrow::datatypes::{DataType, Field}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::ScalarFunctionArgs; +use helper::gen_string_array; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let ascii = datafusion_functions::string::ascii(); + + // All benches are single batch run with 8192 rows + const N_ROWS: usize = 8192; + const STR_LEN: usize = 16; + const UTF8_DENSITY_OF_ALL_ASCII: f32 = 0.0; + const NORMAL_UTF8_DENSITY: f32 = 0.8; + + for null_density in [0.0, 0.5] { + // StringArray ASCII only + let args_string_ascii = gen_string_array( + N_ROWS, + STR_LEN, + null_density, + UTF8_DENSITY_OF_ALL_ASCII, + false, + ); + + let arg_fields = + vec![Field::new("a", args_string_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function( + format!("ascii/string_ascii_only (null_density={null_density})").as_str(), + |b| { + b.iter(|| { + black_box(ascii.invoke_with_args(ScalarFunctionArgs { + args: args_string_ascii.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // StringArray UTF8 + let args_string_utf8 = + gen_string_array(N_ROWS, STR_LEN, null_density, NORMAL_UTF8_DENSITY, false); + let arg_fields = + vec![Field::new("a", args_string_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); + c.bench_function( + format!("ascii/string_utf8 (null_density={null_density})").as_str(), + |b| { + b.iter(|| { + black_box(ascii.invoke_with_args(ScalarFunctionArgs { + args: args_string_utf8.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // StringViewArray ASCII only + let args_string_view_ascii = gen_string_array( + N_ROWS, + STR_LEN, + null_density, + UTF8_DENSITY_OF_ALL_ASCII, + true, + ); + let arg_fields = + vec![Field::new("a", args_string_view_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); + c.bench_function( + format!("ascii/string_view_ascii_only (null_density={null_density})") + .as_str(), + |b| { + b.iter(|| { + black_box(ascii.invoke_with_args(ScalarFunctionArgs { + args: args_string_view_ascii.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // StringViewArray UTF8 + let args_string_view_utf8 = + gen_string_array(N_ROWS, STR_LEN, null_density, NORMAL_UTF8_DENSITY, true); + let arg_fields = + vec![Field::new("a", args_string_view_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); + c.bench_function( + format!("ascii/string_view_utf8 (null_density={null_density})").as_str(), + |b| { + b.iter(|| { + black_box(ascii.invoke_with_args(ScalarFunctionArgs { + args: args_string_view_utf8.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index bbcfed021064a..edb61c013e242 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -17,10 +17,12 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_expr::ScalarFunctionArgs; use helper::gen_string_array; +use std::sync::Arc; mod helper; @@ -28,20 +30,30 @@ fn criterion_benchmark(c: &mut Criterion) { // All benches are single batch run with 8192 rows let character_length = datafusion_functions::unicode::character_length(); - let return_type = DataType::Utf8; + let return_field = Arc::new(Field::new("f", DataType::Utf8, true)); + let config_options = Arc::new(ConfigOptions::default()); let n_rows = 8192; for str_len in [8, 32, 128, 4096] { // StringArray ASCII only let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); + let arg_fields = args_string_ascii + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); c.bench_function( - &format!("character_length_StringArray_ascii_str_len_{}", str_len), + &format!("character_length_StringArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), })) }) }, @@ -49,14 +61,23 @@ fn criterion_benchmark(c: &mut Criterion) { // StringArray UTF8 let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); + let arg_fields = args_string_utf8 + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); c.bench_function( - &format!("character_length_StringArray_utf8_str_len_{}", str_len), + &format!("character_length_StringArray_utf8_str_len_{str_len}"), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), })) }) }, @@ -64,14 +85,23 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray ASCII only let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); + let arg_fields = args_string_view_ascii + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); c.bench_function( - &format!("character_length_StringViewArray_ascii_str_len_{}", str_len), + &format!("character_length_StringViewArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), })) }) }, @@ -79,14 +109,23 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray UTF8 let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); + let arg_fields = args_string_view_utf8 + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); c.bench_function( - &format!("character_length_StringViewArray_utf8_str_len_{}", str_len), + &format!("character_length_StringViewArray_utf8_str_len_{str_len}"), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), })) }) }, diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 4750fb4666532..ec3f188f90844 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -17,41 +17,57 @@ extern crate criterion; -use arrow::{array::PrimitiveArray, datatypes::Int64Type, util::test_util::seedable_rng}; +use arrow::{array::PrimitiveArray, datatypes::Int64Type}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string::chr; -use rand::Rng; +use rand::{Rng, SeedableRng}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::config::ConfigOptions; +use rand::rngs::StdRng; use std::sync::Arc; +/// Returns fixed seedable RNG +pub fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + fn criterion_benchmark(c: &mut Criterion) { let cot_fn = chr(); let size = 1024; let input: PrimitiveArray = { let null_density = 0.2; - let mut rng = seedable_rng(); + let mut rng = StdRng::seed_from_u64(42); (0..size) .map(|_| { - if rng.gen::() < null_density { + if rng.random::() < null_density { None } else { - Some(rng.gen_range::(1i64..10_000)) + Some(rng.random_range::(1i64..10_000)) } }) .collect() }; let input = Arc::new(input); let args = vec![ColumnarValue::Array(input)]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + c.bench_function("chr", |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), }) .unwrap(), ) diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 45ca076e754f6..15f9ffbd78025 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -16,9 +16,10 @@ // under the License. use arrow::array::ArrayRef; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string::concat; @@ -37,6 +38,15 @@ fn create_args(size: usize, str_len: usize) -> Vec { fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let args = create_args(size, 32); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + let mut group = c.benchmark_group("concat function"); group.bench_function(BenchmarkId::new("concat", size), |b| { b.iter(|| { @@ -45,8 +55,10 @@ fn criterion_benchmark(c: &mut Criterion) { concat() .invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), }) .unwrap(), ) diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index b2a9ca0b9f470..937d092cc0282 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -25,7 +25,8 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::cot; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::config::ConfigOptions; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { @@ -33,14 +34,25 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("cot f32 array: {}", size), |b| { + let arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("cot f32 array: {size}"), |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::clone(&config_options), }) .unwrap(), ) @@ -48,14 +60,25 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("cot f64 array: {}", size), |b| { + let arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Float64, true)); + + c.bench_function(&format!("cot f64 array: {size}"), |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .unwrap(), ) diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 7ea5fdcb2be2e..ea8705984f386 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -20,18 +20,19 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, TimestampSecondArray}; +use arrow::datatypes::Field; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; -use rand::rngs::ThreadRng; -use rand::Rng; - use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::datetime::date_bin; +use rand::rngs::ThreadRng; +use rand::Rng; fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { let mut seconds = vec![]; for _ in 0..1000 { - seconds.push(rng.gen_range(0..1_000_000)); + seconds.push(rng.random_range(0..1_000_000)); } TimestampSecondArray::from(seconds) @@ -39,7 +40,7 @@ fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("date_bin_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let timestamps_array = Arc::new(timestamps(&mut rng)) as ArrayRef; let batch_len = timestamps_array.len(); let interval = ColumnarValue::Scalar(ScalarValue::new_interval_dt(0, 1_000_000)); @@ -48,13 +49,22 @@ fn criterion_benchmark(c: &mut Criterion) { let return_type = udf .return_type(&[interval.data_type(), timestamps.data_type()]) .unwrap(); + let return_field = Arc::new(Field::new("f", return_type, true)); + + let arg_fields = vec![ + Field::new("a", interval.data_type(), true).into(), + Field::new("b", timestamps.data_type(), true).into(), + ]; + let config_options = Arc::new(ConfigOptions::default()); b.iter(|| { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![interval.clone(), timestamps.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &return_type, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index e7e96fb7a9fa7..70d372429b2d0 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -20,18 +20,19 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, TimestampSecondArray}; +use arrow::datatypes::Field; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; -use rand::rngs::ThreadRng; -use rand::Rng; - use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::datetime::date_trunc; +use rand::rngs::ThreadRng; +use rand::Rng; fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { let mut seconds = vec![]; for _ in 0..1000 { - seconds.push(rng.gen_range(0..1_000_000)); + seconds.push(rng.random_range(0..1_000_000)); } TimestampSecondArray::from(seconds) @@ -39,7 +40,7 @@ fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("date_trunc_minute_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let timestamps_array = Arc::new(timestamps(&mut rng)) as ArrayRef; let batch_len = timestamps_array.len(); let precision = @@ -47,15 +48,28 @@ fn criterion_benchmark(c: &mut Criterion) { let timestamps = ColumnarValue::Array(timestamps_array); let udf = date_trunc(); let args = vec![precision, timestamps]; - let return_type = &udf + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + let return_type = udf .return_type(&args.iter().map(|arg| arg.data_type()).collect::>()) .unwrap(); + let return_field = Arc::new(Field::new("f", return_type, true)); + let config_options = Arc::new(ConfigOptions::default()); + b.iter(|| { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .expect("date_trunc should work on valid values"), ) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index cf8f8d2fd62c7..dc2529cd9fd76 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -17,15 +17,19 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::array::Array; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::encoding; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let decode = encoding::decode(); + let config_options = Arc::new(ConfigOptions::default()); + for size in [1024, 4096, 8192] { let str_array = Arc::new(create_string_array_with_len::(size, 0.2, 32)); c.bench_function(&format!("base64_decode/{size}"), |b| { @@ -33,19 +37,31 @@ fn criterion_benchmark(c: &mut Criterion) { let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + arg_fields: vec![ + Field::new("a", str_array.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), + ], number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), }) .unwrap(); + let arg_fields = vec![ + Field::new("a", encoded.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), + ]; let args = vec![encoded, method]; + b.iter(|| { black_box( decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), }) .unwrap(), ) @@ -54,22 +70,36 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function(&format!("hex_decode/{size}"), |b| { let method = ColumnarValue::Scalar("hex".into()); + let arg_fields = vec![ + Field::new("a", str_array.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), + ]; let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + arg_fields, number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), }) .unwrap(); + let arg_fields = vec![ + Field::new("a", encoded.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), + ]; + let return_field = Field::new("f", DataType::Utf8, true).into(); let args = vec![encoded, method]; + b.iter(|| { black_box( decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .unwrap(), ) diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index 9307525482c2b..df7d7cc09dd23 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -18,14 +18,15 @@ extern crate criterion; use arrow::array::{StringArray, StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; use std::sync::Arc; @@ -51,7 +52,7 @@ fn gen_args_array( let mut output_set_vec: Vec> = Vec::with_capacity(n_rows); let mut output_element_vec: Vec> = Vec::with_capacity(n_rows); for _ in 0..n_rows { - let rand_num = rng_ref.gen::(); // [0.0, 1.0) + let rand_num = rng_ref.random::(); // [0.0, 1.0) if rand_num < null_density { output_element_vec.push(None); output_set_vec.push(None); @@ -60,7 +61,7 @@ fn gen_args_array( let mut generated_string = String::with_capacity(str_len_chars); for i in 0..num_elements { for _ in 0..str_len_chars { - let idx = rng_ref.gen_range(0..corpus_char_count); + let idx = rng_ref.random_range(0..corpus_char_count); let char = utf8.chars().nth(idx).unwrap(); generated_string.push(char); } @@ -112,7 +113,7 @@ fn random_element_in_set(string: &str) -> String { } let mut rng = StdRng::seed_from_u64(44); - let random_index = rng.gen_range(0..elements.len()); + let random_index = rng.random_range(0..elements.len()); elements[random_index].to_string() } @@ -153,23 +154,37 @@ fn criterion_benchmark(c: &mut Criterion) { group.measurement_time(Duration::from_secs(10)); let args = gen_args_array(n_rows, str_len, 0.1, 0.5, false); - group.bench_function(format!("string_len_{}", str_len), |b| { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let return_field = Field::new("f", DataType::Int32, true).into(); + group.bench_function(format!("string_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), + config_options: Arc::new(ConfigOptions::default()), })) }) }); let args = gen_args_array(n_rows, str_len, 0.1, 0.5, true); - group.bench_function(format!("string_view_len_{}", str_len), |b| { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Int32, true)); + group.bench_function(format!("string_view_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), + config_options: Arc::new(ConfigOptions::default()), })) }) }); @@ -179,23 +194,39 @@ fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("find_in_set_scalar"); let args = gen_args_scalar(n_rows, str_len, 0.1, false); - group.bench_function(format!("string_len_{}", str_len), |b| { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Int32, true)); + group.bench_function(format!("string_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), + config_options: Arc::new(ConfigOptions::default()), })) }) }); let args = gen_args_scalar(n_rows, str_len, 0.1, true); - group.bench_function(format!("string_view_len_{}", str_len), |b| { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Int32, true)); + let config_options = Arc::new(ConfigOptions::default()); + + group.bench_function(format!("string_view_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), })) }) }); diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index f8c855c82ad4a..913ed523543e0 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -17,11 +17,13 @@ extern crate criterion; +use arrow::datatypes::Field; use arrow::{ array::{ArrayRef, Int64Array}, datatypes::DataType, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::gcd; @@ -29,9 +31,9 @@ use rand::Rng; use std::sync::Arc; fn generate_i64_array(n_rows: usize) -> ArrayRef { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let values = (0..n_rows) - .map(|_| rng.gen_range(0..1000)) + .map(|_| rng.random_range(0..1000)) .collect::>(); Arc::new(Int64Array::from(values)) as ArrayRef } @@ -41,14 +43,20 @@ fn criterion_benchmark(c: &mut Criterion) { let array_a = ColumnarValue::Array(generate_i64_array(n_rows)); let array_b = ColumnarValue::Array(generate_i64_array(n_rows)); let udf = gcd(); + let config_options = Arc::new(ConfigOptions::default()); c.bench_function("gcd both array", |b| { b.iter(|| { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), array_b.clone()], + arg_fields: vec![ + Field::new("a", array_a.data_type(), true).into(), + Field::new("b", array_b.data_type(), true).into(), + ], number_rows: 0, - return_type: &DataType::Int64, + return_field: Field::new("f", DataType::Int64, true).into(), + config_options: Arc::clone(&config_options), }) .expect("date_bin should work on valid values"), ) @@ -63,8 +71,13 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), scalar_b.clone()], + arg_fields: vec![ + Field::new("a", array_a.data_type(), true).into(), + Field::new("b", scalar_b.data_type(), true).into(), + ], number_rows: 0, - return_type: &DataType::Int64, + return_field: Field::new("f", DataType::Int64, true).into(), + config_options: Arc::clone(&config_options), }) .expect("date_bin should work on valid values"), ) @@ -79,8 +92,13 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![scalar_a.clone(), scalar_b.clone()], + arg_fields: vec![ + Field::new("a", scalar_a.data_type(), true).into(), + Field::new("b", scalar_b.data_type(), true).into(), + ], number_rows: 0, - return_type: &DataType::Int64, + return_field: Field::new("f", DataType::Int64, true).into(), + config_options: Arc::clone(&config_options), }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/helper.rs b/datafusion/functions/benches/helper.rs index 0dbb4b0027d42..a2b110ae4d63b 100644 --- a/datafusion/functions/benches/helper.rs +++ b/datafusion/functions/benches/helper.rs @@ -17,7 +17,7 @@ use arrow::array::{StringArray, StringViewArray}; use datafusion_expr::ColumnarValue; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::{rngs::StdRng, Rng, SeedableRng}; use std::sync::Arc; @@ -39,14 +39,14 @@ pub fn gen_string_array( let mut output_string_vec: Vec> = Vec::with_capacity(n_rows); for _ in 0..n_rows { - let rand_num = rng_ref.gen::(); // [0.0, 1.0) + let rand_num = rng_ref.random::(); // [0.0, 1.0) if rand_num < null_density { output_string_vec.push(None); } else if rand_num < null_density + utf8_density { // Generate random UTF8 string let mut generated_string = String::with_capacity(str_len_chars); for _ in 0..str_len_chars { - let char = corpus[rng_ref.gen_range(0..corpus.len())]; + let char = corpus[rng_ref.random_range(0..corpus.len())]; generated_string.push(char); } output_string_vec.push(Some(generated_string)); diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index 97c76831b33c8..7562e990ca16c 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -18,11 +18,12 @@ extern crate criterion; use arrow::array::OffsetSizeTrait; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode; use std::sync::Arc; @@ -49,14 +50,25 @@ fn criterion_benchmark(c: &mut Criterion) { let initcap = unicode::initcap(); for size in [1024, 4096] { let args = create_args::(size, 8, true); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + c.bench_function( - format!("initcap string view shorter than 12 [size={}]", size).as_str(), + format!("initcap string view shorter than 12 [size={size}]").as_str(), |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8View, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), })) }) }, @@ -64,25 +76,29 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args::(size, 16, true); c.bench_function( - format!("initcap string view longer than 12 [size={}]", size).as_str(), + format!("initcap string view longer than 12 [size={size}]").as_str(), |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8View, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), })) }) }, ); let args = create_args::(size, 16, false); - c.bench_function(format!("initcap string [size={}]", size).as_str(), |b| { + c.bench_function(format!("initcap string [size={size}]").as_str(), |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), })) }) }); diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index 42004cc24f69d..f59c7af939ab2 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -17,12 +17,13 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::isnan; use std::sync::Arc; @@ -32,14 +33,25 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("isnan f32 array: {}", size), |b| { + let arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("isnan f32 array: {size}"), |b| { b.iter(|| { black_box( isnan .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Boolean, + return_field: Field::new("f", DataType::Boolean, true).into(), + config_options: Arc::clone(&config_options), }) .unwrap(), ) @@ -47,14 +59,23 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("isnan f64 array: {}", size), |b| { + let arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + c.bench_function(&format!("isnan f64 array: {size}"), |b| { b.iter(|| { black_box( isnan .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Boolean, + return_field: Field::new("f", DataType::Boolean, true).into(), + config_options: Arc::clone(&config_options), }) .unwrap(), ) diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index 9e5f6a84804bc..9752a9364b9f3 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -17,12 +17,13 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::iszero; use std::sync::Arc; @@ -33,14 +34,26 @@ fn criterion_benchmark(c: &mut Criterion) { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("iszero f32 array: {}", size), |b| { + let arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Boolean, true)); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("iszero f32 array: {size}"), |b| { b.iter(|| { black_box( iszero .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Boolean, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .unwrap(), ) @@ -49,14 +62,25 @@ fn criterion_benchmark(c: &mut Criterion) { let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("iszero f64 array: {}", size), |b| { + let arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Boolean, true)); + + c.bench_function(&format!("iszero f64 array: {size}"), |b| { b.iter(|| { black_box( iszero .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Boolean, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .unwrap(), ) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 534e5739225d7..83d437c6caa63 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -18,11 +18,12 @@ extern crate criterion; use arrow::array::{ArrayRef, StringArray, StringViewBuilder}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::sync::Arc; @@ -44,7 +45,7 @@ fn create_args2(size: usize) -> Vec { let mut items = Vec::with_capacity(size); items.push("农历新年".to_string()); for i in 1..size { - items.push(format!("DATAFUSION {}", i)); + items.push(format!("DATAFUSION {i}")); } let array = Arc::new(StringArray::from(items)) as ArrayRef; vec![ColumnarValue::Array(array)] @@ -58,11 +59,11 @@ fn create_args3(size: usize) -> Vec { let mut items = Vec::with_capacity(size); let half = size / 2; for i in 0..half { - items.push(format!("DATAFUSION {}", i)); + items.push(format!("DATAFUSION {i}")); } items.push("Ⱦ".to_string()); for i in half + 1..size { - items.push(format!("DATAFUSION {}", i)); + items.push(format!("DATAFUSION {i}")); } let array = Arc::new(StringArray::from(items)) as ArrayRef; vec![ColumnarValue::Array(array)] @@ -122,44 +123,73 @@ fn create_args5( fn criterion_benchmark(c: &mut Criterion) { let lower = string::lower(); + let config_options = Arc::new(ConfigOptions::default()); + for size in [1024, 4096, 8192] { let args = create_args1(size, 32); - c.bench_function(&format!("lower_all_values_are_ascii: {}", size), |b| { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + c.bench_function(&format!("lower_all_values_are_ascii: {size}"), |b| { b.iter(|| { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), })) }) }); let args = create_args2(size); - c.bench_function( - &format!("lower_the_first_value_is_nonascii: {}", size), - |b| { - b.iter(|| { - let args_cloned = args.clone(); - black_box(lower.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: size, - return_type: &DataType::Utf8, - })) - }) - }, - ); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + c.bench_function(&format!("lower_the_first_value_is_nonascii: {size}"), |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); let args = create_args3(size); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + c.bench_function( - &format!("lower_the_middle_value_is_nonascii: {}", size), + &format!("lower_the_middle_value_is_nonascii: {size}"), |b| { b.iter(|| { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), })) }) }, @@ -176,29 +206,39 @@ fn criterion_benchmark(c: &mut Criterion) { for &str_len in &str_lens { for &size in &sizes { let args = create_args4(size, str_len, *null_density, mixed); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + c.bench_function( - &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", - size, str_len, null_density, mixed), + &format!("lower_all_values_are_ascii_string_views: size: {size}, str_len: {str_len}, null_density: {null_density}, mixed: {mixed}"), |b| b.iter(|| { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), })) }), ); let args = create_args4(size, str_len, *null_density, mixed); c.bench_function( - &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", - size, str_len, null_density, mixed), + &format!("lower_all_values_are_ascii_string_views: size: {size}, str_len: {str_len}, null_density: {null_density}, mixed: {mixed}"), |b| b.iter(|| { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), })) }), ); @@ -211,8 +251,10 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), })) }), ); diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 457fb499f5a10..2712223506b9e 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -18,21 +18,18 @@ extern crate criterion; use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{ black_box, criterion_group, criterion_main, measurement::Measurement, BenchmarkGroup, Criterion, SamplingMode, }; +use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF}; use datafusion_functions::string; -use rand::{distributions::Alphanumeric, rngs::StdRng, Rng, SeedableRng}; +use rand::{distr::Alphanumeric, rngs::StdRng, Rng, SeedableRng}; use std::{fmt, sync::Arc}; -pub fn seedable_rng() -> StdRng { - StdRng::seed_from_u64(42) -} - #[derive(Clone, Copy)] pub enum StringArrayType { Utf8View, @@ -58,14 +55,14 @@ pub fn create_string_array_and_characters( remaining_len: usize, string_array_type: StringArrayType, ) -> (ArrayRef, ScalarValue) { - let rng = &mut seedable_rng(); + let rng = &mut StdRng::seed_from_u64(42); // Create `size` rows: // - 10% rows will be `None` // - Other 90% will be strings with same `remaining_len` lengths // We will build the string array on it later. let string_iter = (0..size).map(|_| { - if rng.gen::() < 0.1 { + if rng.random::() < 0.1 { None } else { let mut value = trimmed.as_bytes().to_vec(); @@ -136,6 +133,13 @@ fn run_with_string_type( string_type: StringArrayType, ) { let args = create_args(size, characters, trimmed, remaining_len, string_type); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + group.bench_function( format!( "{string_type} [size={size}, len_before={len}, len_after={remaining_len}]", @@ -145,8 +149,10 @@ fn run_with_string_type( let args_cloned = args.clone(); black_box(ltrim.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), })) }) }, diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 8dd7a7a59773c..f0494a9d3b4e4 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -20,19 +20,19 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, Int32Array}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use rand::rngs::ThreadRng; -use rand::Rng; - +use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::datetime::make_date; +use rand::rngs::ThreadRng; +use rand::Rng; fn years(rng: &mut ThreadRng) -> Int32Array { let mut years = vec![]; for _ in 0..1000 { - years.push(rng.gen_range(1900..2050)); + years.push(rng.random_range(1900..2050)); } Int32Array::from(years) @@ -41,7 +41,7 @@ fn years(rng: &mut ThreadRng) -> Int32Array { fn months(rng: &mut ThreadRng) -> Int32Array { let mut months = vec![]; for _ in 0..1000 { - months.push(rng.gen_range(1..13)); + months.push(rng.random_range(1..13)); } Int32Array::from(months) @@ -50,27 +50,36 @@ fn months(rng: &mut ThreadRng) -> Int32Array { fn days(rng: &mut ThreadRng) -> Int32Array { let mut days = vec![]; for _ in 0..1000 { - days.push(rng.gen_range(1..29)); + days.push(rng.random_range(1..29)); } Int32Array::from(days) } fn criterion_benchmark(c: &mut Criterion) { c.bench_function("make_date_col_col_col_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let years_array = Arc::new(years(&mut rng)) as ArrayRef; let batch_len = years_array.len(); let years = ColumnarValue::Array(years_array); let months = ColumnarValue::Array(Arc::new(months(&mut rng)) as ArrayRef); let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); + let arg_fields = vec![ + Field::new("a", years.data_type(), true).into(), + Field::new("a", months.data_type(), true).into(), + Field::new("a", days.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); + let config_options = Arc::new(ConfigOptions::default()); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![years.clone(), months.clone(), days.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Date32, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .expect("make_date should work on valid values"), ) @@ -78,20 +87,29 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("make_date_scalar_col_col_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); let months_arr = Arc::new(months(&mut rng)) as ArrayRef; let batch_len = months_arr.len(); let months = ColumnarValue::Array(months_arr); let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); + let arg_fields = vec![ + Field::new("a", year.data_type(), true).into(), + Field::new("a", months.data_type(), true).into(), + Field::new("a", days.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); + let config_options = Arc::new(ConfigOptions::default()); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), months.clone(), days.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Date32, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .expect("make_date should work on valid values"), ) @@ -99,20 +117,29 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("make_date_scalar_scalar_col_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); let month = ColumnarValue::Scalar(ScalarValue::Int32(Some(11))); let day_arr = Arc::new(days(&mut rng)); let batch_len = day_arr.len(); let days = ColumnarValue::Array(day_arr); + let arg_fields = vec![ + Field::new("a", year.data_type(), true).into(), + Field::new("a", month.data_type(), true).into(), + Field::new("a", days.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); + let config_options = Arc::new(ConfigOptions::default()); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), days.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Date32, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .expect("make_date should work on valid values"), ) @@ -123,14 +150,23 @@ fn criterion_benchmark(c: &mut Criterion) { let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); let month = ColumnarValue::Scalar(ScalarValue::Int32(Some(11))); let day = ColumnarValue::Scalar(ScalarValue::Int32(Some(26))); + let arg_fields = vec![ + Field::new("a", year.data_type(), true).into(), + Field::new("a", month.data_type(), true).into(), + Field::new("a", day.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); + let config_options = Arc::new(ConfigOptions::default()); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), day.clone()], + arg_fields: arg_fields.clone(), number_rows: 1, - return_type: &DataType::Date32, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .expect("make_date should work on valid values"), ) diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index 9096c976bf31d..93ec687c4d0e4 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -17,9 +17,10 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::core::nullif; @@ -33,14 +34,25 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::Scalar(ScalarValue::Utf8(Some("abcd".to_string()))), ColumnarValue::Array(array), ]; - c.bench_function(&format!("nullif scalar array: {}", size), |b| { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("nullif scalar array: {size}"), |b| { b.iter(|| { black_box( nullif .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), }) .unwrap(), ) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index f78a53fbee191..125559269a4f6 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -16,14 +16,16 @@ // under the License. use arrow::array::{ArrayRef, ArrowPrimitiveType, OffsetSizeTrait, PrimitiveArray}; -use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::{DataType, Field, Int64Type}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode::{lpad, rpad}; -use rand::distributions::{Distribution, Uniform}; +use rand::distr::{Distribution, Uniform}; use rand::Rng; use std::sync::Arc; @@ -52,13 +54,13 @@ where dist: Uniform::new_inclusive::(0, len as i64), }; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); (0..size) .map(|_| { - if rng.gen::() < null_density { + if rng.random::() < null_density { None } else { - Some(rng.sample(&dist)) + Some(rng.sample(dist.dist.unwrap())) } }) .collect() @@ -95,21 +97,43 @@ fn create_args( } } +fn invoke_pad_with_args( + args: Vec, + number_rows: usize, + left_pad: bool, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + let scalar_args = ScalarFunctionArgs { + args: args.clone(), + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }; + + if left_pad { + lpad().invoke_with_args(scalar_args) + } else { + rpad().invoke_with_args(scalar_args) + } +} + fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 2048] { let mut group = c.benchmark_group("lpad function"); let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("utf8 type", size), |b| { b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -118,13 +142,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::LargeUtf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -133,13 +151,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("stringview type", size), |b| { b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -152,13 +164,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("utf8 type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, false).unwrap(), ) }) }); @@ -167,13 +173,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::LargeUtf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, false).unwrap(), ) }) }); @@ -183,13 +183,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("stringview type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, false).unwrap(), ) }) }); diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 78ebf23e02e07..ac92aed586bae 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -17,13 +17,17 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; use datafusion_functions::math::random::RandomFunc; +use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let random_func = RandomFunc::new(); + let return_field = Field::new("f", DataType::Float64, true).into(); + let config_options = Arc::new(ConfigOptions::default()); // Benchmark to evaluate 1M rows in batch size 8192 let iterations = 1_000_000 / 8192; // Calculate how many iterations are needed to reach approximately 1M rows @@ -34,8 +38,10 @@ fn criterion_benchmark(c: &mut Criterion) { random_func .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 8192, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .unwrap(), ); @@ -43,6 +49,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + let return_field = Field::new("f", DataType::Float64, true).into(); // Benchmark to evaluate 1M rows in batch size 128 let iterations_128 = 1_000_000 / 128; // Calculate how many iterations are needed to reach approximately 1M rows with batch size 128 c.bench_function("random_1M_rows_batch_128", |b| { @@ -52,8 +59,10 @@ fn criterion_benchmark(c: &mut Criterion) { random_func .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 128, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .unwrap(), ); diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 1f99cc3a5f0bc..c18241f799e36 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -23,12 +23,13 @@ use arrow::compute::cast; use arrow::datatypes::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_functions::regex::regexpcount::regexp_count_func; +use datafusion_functions::regex::regexpinstr::regexp_instr_func; use datafusion_functions::regex::regexplike::regexp_like; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; +use rand::prelude::IndexedRandom; use rand::rngs::ThreadRng; -use rand::seq::SliceRandom; use rand::Rng; use std::iter; use std::sync::Arc; @@ -65,7 +66,16 @@ fn regex(rng: &mut ThreadRng) -> StringArray { fn start(rng: &mut ThreadRng) -> Int64Array { let mut data: Vec = vec![]; for _ in 0..1000 { - data.push(rng.gen_range(1..5)); + data.push(rng.random_range(1..5)); + } + + Int64Array::from(data) +} + +fn n(rng: &mut ThreadRng) -> Int64Array { + let mut data: Vec = vec![]; + for _ in 0..1000 { + data.push(rng.random_range(1..5)); } Int64Array::from(data) @@ -86,9 +96,18 @@ fn flags(rng: &mut ThreadRng) -> StringArray { sb.finish() } +fn subexp(rng: &mut ThreadRng) -> Int64Array { + let mut data: Vec = vec![]; + for _ in 0..1000 { + data.push(rng.random_range(1..5)); + } + + Int64Array::from(data) +} + fn criterion_benchmark(c: &mut Criterion) { c.bench_function("regexp_count_1000 string", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; let regex = Arc::new(regex(&mut rng)) as ArrayRef; let start = Arc::new(start(&mut rng)) as ArrayRef; @@ -108,7 +127,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_count_1000 utf8view", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); let start = Arc::new(start(&mut rng)) as ArrayRef; @@ -127,8 +146,54 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + c.bench_function("regexp_instr_1000 string", |b| { + let mut rng = rand::rng(); + let data = Arc::new(data(&mut rng)) as ArrayRef; + let regex = Arc::new(regex(&mut rng)) as ArrayRef; + let start = Arc::new(start(&mut rng)) as ArrayRef; + let n = Arc::new(n(&mut rng)) as ArrayRef; + let flags = Arc::new(flags(&mut rng)) as ArrayRef; + let subexp = Arc::new(subexp(&mut rng)) as ArrayRef; + + b.iter(|| { + black_box( + regexp_instr_func(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&n), + Arc::clone(&flags), + Arc::clone(&subexp), + ]) + .expect("regexp_instr should work on utf8"), + ) + }) + }); + + c.bench_function("regexp_instr_1000 utf8view", |b| { + let mut rng = rand::rng(); + let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); + let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); + let start = Arc::new(start(&mut rng)) as ArrayRef; + let n = Arc::new(n(&mut rng)) as ArrayRef; + let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); + + b.iter(|| { + black_box( + regexp_instr_func(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&n), + Arc::clone(&flags), + ]) + .expect("regexp_instr should work on utf8view"), + ) + }) + }); + c.bench_function("regexp_like_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; let regex = Arc::new(regex(&mut rng)) as ArrayRef; let flags = Arc::new(flags(&mut rng)) as ArrayRef; @@ -142,7 +207,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_like_1000 utf8view", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); @@ -156,7 +221,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_match_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; let regex = Arc::new(regex(&mut rng)) as ArrayRef; let flags = Arc::new(flags(&mut rng)) as ArrayRef; @@ -174,7 +239,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_match_1000 utf8view", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); @@ -192,21 +257,21 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_replace_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; let regex = Arc::new(regex(&mut rng)) as ArrayRef; let flags = Arc::new(flags(&mut rng)) as ArrayRef; let replacement = - Arc::new(StringArray::from_iter_values(iter::repeat("XX").take(1000))) + Arc::new(StringArray::from_iter_values(iter::repeat_n("XX", 1000))) as ArrayRef; b.iter(|| { black_box( - regexp_replace::( + regexp_replace::( data.as_string::(), regex.as_string::(), replacement.as_string::(), - Some(&flags), + Some(flags.as_string::()), ) .expect("regexp_replace should work on valid values"), ) @@ -214,22 +279,21 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_replace_1000 utf8view", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); - // flags are not allowed to be utf8view according to the function - let flags = Arc::new(flags(&mut rng)) as ArrayRef; - let replacement = Arc::new(StringViewArray::from_iter_values( - iter::repeat("XX").take(1000), - )); + let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); + let replacement = Arc::new(StringViewArray::from_iter_values(iter::repeat_n( + "XX", 1000, + ))); b.iter(|| { black_box( - regexp_replace::( + regexp_replace::( data.as_string_view(), regex.as_string_view(), - &replacement, - Some(&flags), + &*replacement, + Some(flags.as_string_view()), ) .expect("regexp_replace should work on valid values"), ) diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 5cc6a177d9d9a..991a5a467c0e3 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -18,11 +18,13 @@ extern crate criterion; use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::sync::Arc; @@ -56,66 +58,64 @@ fn create_args( } } +fn invoke_repeat_with_args( + args: Vec, + repeat_times: i64, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + string::repeat().invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: repeat_times as usize, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) +} + fn criterion_benchmark(c: &mut Criterion) { - let repeat = string::repeat(); for size in [1024, 4096] { // REPEAT 3 TIMES let repeat_times = 3; - let mut group = c.benchmark_group(format!("repeat {} times", repeat_times)); + let mut group = c.benchmark_group(format!("repeat {repeat_times} times")); group.sampling_mode(SamplingMode::Flat); group.sample_size(10); group.measurement_time(Duration::from_secs(10)); let args = create_args::(size, 32, repeat_times, true); group.bench_function( - format!( - "repeat_string_view [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_string_view [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); let args = create_args::(size, 32, repeat_times, false); group.bench_function( - format!( - "repeat_string [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_string [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); let args = create_args::(size, 32, repeat_times, false); group.bench_function( - format!( - "repeat_large_string [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_large_string [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -124,61 +124,40 @@ fn criterion_benchmark(c: &mut Criterion) { // REPEAT 30 TIMES let repeat_times = 30; - let mut group = c.benchmark_group(format!("repeat {} times", repeat_times)); + let mut group = c.benchmark_group(format!("repeat {repeat_times} times")); group.sampling_mode(SamplingMode::Flat); group.sample_size(10); group.measurement_time(Duration::from_secs(10)); let args = create_args::(size, 32, repeat_times, true); group.bench_function( - format!( - "repeat_string_view [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_string_view [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); let args = create_args::(size, 32, repeat_times, false); group.bench_function( - format!( - "repeat_string [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_string [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); let args = create_args::(size, 32, repeat_times, false); group.bench_function( - format!( - "repeat_large_string [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_large_string [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -187,25 +166,18 @@ fn criterion_benchmark(c: &mut Criterion) { // REPEAT overflow let repeat_times = 1073741824; - let mut group = c.benchmark_group(format!("repeat {} times", repeat_times)); + let mut group = c.benchmark_group(format!("repeat {repeat_times} times")); group.sampling_mode(SamplingMode::Flat); group.sample_size(10); group.measurement_time(Duration::from_secs(10)); let args = create_args::(size, 2, repeat_times, false); group.bench_function( - format!( - "repeat_string overflow [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_string overflow [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index d61f8fb805175..acac674a6de06 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -18,14 +18,17 @@ extern crate criterion; mod helper; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_expr::ScalarFunctionArgs; use helper::gen_string_array; +use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { // All benches are single batch run with 8192 rows let reverse = datafusion_functions::unicode::reverse(); + let config_options = Arc::new(ConfigOptions::default()); const N_ROWS: usize = 8192; const NULL_DENSITY: f32 = 0.1; @@ -41,13 +44,19 @@ fn criterion_benchmark(c: &mut Criterion) { false, ); c.bench_function( - &format!("reverse_StringArray_ascii_str_len_{}", str_len), + &format!("reverse_StringArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_fields: vec![Field::new( + "a", + args_string_ascii[0].data_type(), + true, + ).into()], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), })) }) }, @@ -58,15 +67,18 @@ fn criterion_benchmark(c: &mut Criterion) { gen_string_array(N_ROWS, str_len, NULL_DENSITY, NORMAL_UTF8_DENSITY, false); c.bench_function( &format!( - "reverse_StringArray_utf8_density_{}_str_len_{}", - NORMAL_UTF8_DENSITY, str_len + "reverse_StringArray_utf8_density_{NORMAL_UTF8_DENSITY}_str_len_{str_len}" ), |b| { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), + arg_fields: vec![ + Field::new("a", args_string_utf8[0].data_type(), true).into(), + ], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), })) }) }, @@ -81,13 +93,19 @@ fn criterion_benchmark(c: &mut Criterion) { true, ); c.bench_function( - &format!("reverse_StringViewArray_ascii_str_len_{}", str_len), + &format!("reverse_StringViewArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_fields: vec![Field::new( + "a", + args_string_view_ascii[0].data_type(), + true, + ).into()], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), })) }) }, @@ -98,15 +116,20 @@ fn criterion_benchmark(c: &mut Criterion) { gen_string_array(N_ROWS, str_len, NULL_DENSITY, NORMAL_UTF8_DENSITY, true); c.bench_function( &format!( - "reverse_StringViewArray_utf8_density_{}_str_len_{}", - NORMAL_UTF8_DENSITY, str_len + "reverse_StringViewArray_utf8_density_{NORMAL_UTF8_DENSITY}_str_len_{str_len}" ), |b| { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_fields: vec![Field::new( + "a", + args_string_view_utf8[0].data_type(), + true, + ).into()], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), })) }) }, diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 01939fad5f34e..d56f3930d2678 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -19,10 +19,11 @@ extern crate criterion; use arrow::datatypes::DataType; use arrow::{ - datatypes::{Float32Type, Float64Type}, + datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::signum; use std::sync::Arc; @@ -33,14 +34,26 @@ fn criterion_benchmark(c: &mut Criterion) { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("signum f32 array: {}", size), |b| { + let arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Field::new("f", DataType::Float32, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("signum f32 array: {size}"), |b| { b.iter(|| { black_box( signum .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Float32, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .unwrap(), ) @@ -50,14 +63,25 @@ fn criterion_benchmark(c: &mut Criterion) { let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("signum f64 array: {}", size), |b| { + let arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Field::new("f", DataType::Float64, true).into(); + + c.bench_function(&format!("signum f64 array: {size}"), |b| { b.iter(|| { black_box( signum .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .unwrap(), ) diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index df57c229e0ad8..fc31abb23d849 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -18,10 +18,11 @@ extern crate criterion; use arrow::array::{StringArray, StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; use std::str::Chars; @@ -46,7 +47,7 @@ fn gen_string_array( let mut output_string_vec: Vec> = Vec::with_capacity(n_rows); let mut output_sub_string_vec: Vec> = Vec::with_capacity(n_rows); for _ in 0..n_rows { - let rand_num = rng_ref.gen::(); // [0.0, 1.0) + let rand_num = rng_ref.random::(); // [0.0, 1.0) if rand_num < null_density { output_sub_string_vec.push(None); output_string_vec.push(None); @@ -54,7 +55,7 @@ fn gen_string_array( // Generate random UTF8 string let mut generated_string = String::with_capacity(str_len_chars); for _ in 0..str_len_chars { - let idx = rng_ref.gen_range(0..corpus_char_count); + let idx = rng_ref.random_range(0..corpus_char_count); let char = utf8.chars().nth(idx).unwrap(); generated_string.push(char); } @@ -94,8 +95,8 @@ fn random_substring(chars: Chars) -> String { // get the substring of a random length from the input string by byte unit let mut rng = StdRng::seed_from_u64(44); let count = chars.clone().count(); - let start = rng.gen_range(0..count - 1); - let end = rng.gen_range(start + 1..count); + let start = rng.random_range(0..count - 1); + let end = rng.random_range(start + 1..count); chars .enumerate() .filter(|(i, _)| *i >= start && *i < end) @@ -111,14 +112,21 @@ fn criterion_benchmark(c: &mut Criterion) { for str_len in [8, 32, 128, 4096] { // StringArray ASCII only let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); + let arg_fields = + vec![Field::new("a", args_string_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + c.bench_function( - &format!("strpos_StringArray_ascii_str_len_{}", str_len), + &format!("strpos_StringArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), })) }) }, @@ -126,29 +134,36 @@ fn criterion_benchmark(c: &mut Criterion) { // StringArray UTF8 let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); - c.bench_function( - &format!("strpos_StringArray_utf8_str_len_{}", str_len), - |b| { - b.iter(|| { - black_box(strpos.invoke_with_args(ScalarFunctionArgs { - args: args_string_utf8.clone(), - number_rows: n_rows, - return_type: &DataType::Int32, - })) - }) - }, - ); + let arg_fields = + vec![Field::new("a", args_string_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); + c.bench_function(&format!("strpos_StringArray_utf8_str_len_{str_len}"), |b| { + b.iter(|| { + black_box(strpos.invoke_with_args(ScalarFunctionArgs { + args: args_string_utf8.clone(), + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); // StringViewArray ASCII only let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); + let arg_fields = + vec![Field::new("a", args_string_view_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); c.bench_function( - &format!("strpos_StringViewArray_ascii_str_len_{}", str_len), + &format!("strpos_StringViewArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), })) }) }, @@ -156,14 +171,19 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray UTF8 let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); + let arg_fields = + vec![Field::new("a", args_string_view_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); c.bench_function( - &format!("strpos_StringViewArray_utf8_str_len_{}", str_len), + &format!("strpos_StringViewArray_utf8_str_len_{str_len}"), |b| { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), })) }) }, diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index 80ab70ef71b06..f14f10894649f 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -18,11 +18,13 @@ extern crate criterion; use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode; use std::sync::Arc; @@ -96,8 +98,27 @@ fn create_args_with_count( } } +fn invoke_substr_with_args( + args: Vec, + number_rows: usize, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + unicode::substr().invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + }) +} + fn criterion_benchmark(c: &mut Criterion) { - let substr = unicode::substr(); for size in [1024, 4096] { // string_len = 12, substring_len=6 (see `create_args_without_count`) let len = 12; @@ -107,44 +128,19 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args_without_count::(size, len, true, true); group.bench_function( - format!("substr_string_view [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string_view [size={size}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_without_count::(size, len, false, false); - group.bench_function( - format!("substr_string [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, - ); + group.bench_function(format!("substr_string [size={size}, strlen={len}]"), |b| { + b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))) + }); let args = create_args_without_count::(size, len, true, false); group.bench_function( - format!("substr_large_string [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_large_string [size={size}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); @@ -158,53 +154,20 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args_with_count::(size, len, count, true); group.bench_function( - format!( - "substr_string_view [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string_view [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); group.bench_function( - format!( - "substr_string [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); group.bench_function( - format!( - "substr_large_string [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_large_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); @@ -218,53 +181,20 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args_with_count::(size, len, count, true); group.bench_function( - format!( - "substr_string_view [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string_view [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); group.bench_function( - format!( - "substr_string [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); group.bench_function( - format!( - "substr_large_string [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_large_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index b1c1c3c34a95b..2cc381e4545ee 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -20,14 +20,14 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{ArrayRef, Int64Array, StringArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use rand::distributions::{Alphanumeric, Uniform}; -use rand::prelude::Distribution; -use rand::Rng; - +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode::substr_index; +use rand::distr::{Alphanumeric, Uniform}; +use rand::prelude::Distribution; +use rand::Rng; struct Filter { dist: Dist, @@ -54,21 +54,21 @@ fn data() -> (StringArray, StringArray, Int64Array) { dist: Uniform::new(-4, 5), test: |x: &i64| x != &0, }; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let mut strings: Vec = vec![]; let mut delimiters: Vec = vec![]; let mut counts: Vec = vec![]; for _ in 0..1000 { - let length = rng.gen_range(20..50); + let length = rng.random_range(20..50); let text: String = (&mut rng) .sample_iter(&Alphanumeric) .take(length) .map(char::from) .collect(); - let char = rng.gen_range(0..text.len()); + let char = rng.random_range(0..text.len()); let delimiter = &text.chars().nth(char).unwrap(); - let count = rng.sample(&dist); + let count = rng.sample(dist.dist.unwrap()); strings.push(text); delimiters.push(delimiter.to_string()); @@ -91,13 +91,24 @@ fn criterion_benchmark(c: &mut Criterion) { let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); let args = vec![strings, delimiters, counts]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + b.iter(|| { black_box( substr_index() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), }) .expect("substr_index should work on valid values"), ) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 6f20a20dc219f..9599b8677216c 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -20,30 +20,30 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{ArrayRef, Date32Array, StringArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use chrono::prelude::*; use chrono::TimeDelta; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use rand::rngs::ThreadRng; -use rand::seq::SliceRandom; -use rand::Rng; - +use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; use datafusion_common::ScalarValue::TimestampNanosecond; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::datetime::to_char; +use rand::prelude::IndexedRandom; +use rand::rngs::ThreadRng; +use rand::Rng; -fn random_date_in_range( +fn pick_date_in_range( rng: &mut ThreadRng, start_date: NaiveDate, end_date: NaiveDate, ) -> NaiveDate { let days_in_range = (end_date - start_date).num_days(); - let random_days: i64 = rng.gen_range(0..days_in_range); + let random_days: i64 = rng.random_range(0..days_in_range); start_date + TimeDelta::try_days(random_days).unwrap() } -fn data(rng: &mut ThreadRng) -> Date32Array { +fn generate_date32_array(rng: &mut ThreadRng) -> Date32Array { let mut data: Vec = vec![]; let unix_days_from_ce = NaiveDate::from_ymd_opt(1970, 1, 1) .unwrap() @@ -56,7 +56,7 @@ fn data(rng: &mut ThreadRng) -> Date32Array { .expect("Date should parse"); for _ in 0..1000 { data.push( - random_date_in_range(rng, start_date, end_date).num_days_from_ce() + pick_date_in_range(rng, start_date, end_date).num_days_from_ce() - unix_days_from_ce, ); } @@ -64,65 +64,205 @@ fn data(rng: &mut ThreadRng) -> Date32Array { Date32Array::from(data) } -fn patterns(rng: &mut ThreadRng) -> StringArray { - let samples = [ - "%Y:%m:%d".to_string(), - "%d-%m-%Y".to_string(), - "%d%m%Y".to_string(), - "%Y%m%d".to_string(), - "%Y...%m...%d".to_string(), - ]; - let mut data: Vec = vec![]; +const DATE_PATTERNS: [&str; 5] = + ["%Y:%m:%d", "%d-%m-%Y", "%d%m%Y", "%Y%m%d", "%Y...%m...%d"]; + +const DATETIME_PATTERNS: [&str; 8] = [ + "%Y:%m:%d %H:%M%S", + "%Y:%m:%d %_H:%M%S", + "%Y:%m:%d %k:%M%S", + "%d-%m-%Y %I%P-%M-%S %f", + "%d%m%Y %H", + "%Y%m%d %M-%S %.3f", + "%Y...%m...%d %T%3f", + "%c", +]; + +fn pick_date_pattern(rng: &mut ThreadRng) -> String { + (*DATE_PATTERNS + .choose(rng) + .expect("Empty list of date patterns")) + .to_string() +} + +fn pick_date_time_pattern(rng: &mut ThreadRng) -> String { + (*DATETIME_PATTERNS + .choose(rng) + .expect("Empty list of date time patterns")) + .to_string() +} + +fn pick_date_and_date_time_mixed_pattern(rng: &mut ThreadRng) -> String { + match rng.random_bool(0.5) { + true => pick_date_pattern(rng), + false => pick_date_time_pattern(rng), + } +} + +fn generate_pattern_array( + rng: &mut ThreadRng, + pick_fn: impl Fn(&mut ThreadRng) -> String, +) -> StringArray { + let mut data = Vec::with_capacity(1000); + for _ in 0..1000 { - data.push(samples.choose(rng).unwrap().to_string()); + data.push(pick_fn(rng)); } StringArray::from(data) } +fn generate_date_pattern_array(rng: &mut ThreadRng) -> StringArray { + generate_pattern_array(rng, pick_date_pattern) +} + +fn generate_datetime_pattern_array(rng: &mut ThreadRng) -> StringArray { + generate_pattern_array(rng, pick_date_time_pattern) +} + +fn generate_mixed_pattern_array(rng: &mut ThreadRng) -> StringArray { + generate_pattern_array(rng, pick_date_and_date_time_mixed_pattern) +} + fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("to_char_array_array_1000", |b| { - let mut rng = rand::thread_rng(); - let data_arr = data(&mut rng); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function("to_char_array_date_only_patterns_1000", |b| { + let mut rng = rand::rng(); + let data_arr = generate_date32_array(&mut rng); + let batch_len = data_arr.len(); + let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); + let patterns = ColumnarValue::Array(Arc::new(generate_date_pattern_array( + &mut rng, + )) as ArrayRef); + + b.iter(|| { + black_box( + to_char() + .invoke_with_args(ScalarFunctionArgs { + args: vec![data.clone(), patterns.clone()], + arg_fields: vec![ + Field::new("a", data.data_type(), true).into(), + Field::new("b", patterns.data_type(), true).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .expect("to_char should work on valid values"), + ) + }) + }); + + c.bench_function("to_char_array_datetime_patterns_1000", |b| { + let mut rng = rand::rng(); + let data_arr = generate_date32_array(&mut rng); + let batch_len = data_arr.len(); + let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); + let patterns = ColumnarValue::Array(Arc::new(generate_datetime_pattern_array( + &mut rng, + )) as ArrayRef); + + b.iter(|| { + black_box( + to_char() + .invoke_with_args(ScalarFunctionArgs { + args: vec![data.clone(), patterns.clone()], + arg_fields: vec![ + Field::new("a", data.data_type(), true).into(), + Field::new("b", patterns.data_type(), true).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .expect("to_char should work on valid values"), + ) + }) + }); + + c.bench_function("to_char_array_mixed_patterns_1000", |b| { + let mut rng = rand::rng(); + let data_arr = generate_date32_array(&mut rng); let batch_len = data_arr.len(); let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); - let patterns = ColumnarValue::Array(Arc::new(patterns(&mut rng)) as ArrayRef); + let patterns = ColumnarValue::Array(Arc::new(generate_mixed_pattern_array( + &mut rng, + )) as ArrayRef); b.iter(|| { black_box( to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), patterns.clone()], + arg_fields: vec![ + Field::new("a", data.data_type(), true).into(), + Field::new("b", patterns.data_type(), true).into(), + ], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), }) .expect("to_char should work on valid values"), ) }) }); - c.bench_function("to_char_array_scalar_1000", |b| { - let mut rng = rand::thread_rng(); - let data_arr = data(&mut rng); + c.bench_function("to_char_scalar_date_only_pattern_1000", |b| { + let mut rng = rand::rng(); + let data_arr = generate_date32_array(&mut rng); let batch_len = data_arr.len(); let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); let patterns = - ColumnarValue::Scalar(ScalarValue::Utf8(Some("%Y-%m-%d".to_string()))); + ColumnarValue::Scalar(ScalarValue::Utf8(Some(pick_date_pattern(&mut rng)))); b.iter(|| { black_box( to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), patterns.clone()], + arg_fields: vec![ + Field::new("a", data.data_type(), true).into(), + Field::new("b", patterns.data_type(), true).into(), + ], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), }) .expect("to_char should work on valid values"), ) }) }); - c.bench_function("to_char_scalar_scalar_1000", |b| { + c.bench_function("to_char_scalar_datetime_pattern_1000", |b| { + let mut rng = rand::rng(); + let data_arr = generate_date32_array(&mut rng); + let batch_len = data_arr.len(); + let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); + let patterns = ColumnarValue::Scalar(ScalarValue::Utf8(Some( + pick_date_time_pattern(&mut rng), + ))); + + b.iter(|| { + black_box( + to_char() + .invoke_with_args(ScalarFunctionArgs { + args: vec![data.clone(), patterns.clone()], + arg_fields: vec![ + Field::new("a", data.data_type(), true).into(), + Field::new("b", patterns.data_type(), true).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .expect("to_char should work on valid values"), + ) + }) + }); + + c.bench_function("to_char_scalar_1000", |b| { + let mut rng = rand::rng(); let timestamp = "2026-07-08T09:10:11" .parse::() .unwrap() @@ -132,17 +272,21 @@ fn criterion_benchmark(c: &mut Criterion) { .timestamp_nanos_opt() .unwrap(); let data = ColumnarValue::Scalar(TimestampNanosecond(Some(timestamp), None)); - let pattern = ColumnarValue::Scalar(ScalarValue::Utf8(Some( - "%d-%m-%Y %H:%M:%S".to_string(), - ))); + let pattern = + ColumnarValue::Scalar(ScalarValue::Utf8(Some(pick_date_pattern(&mut rng)))); b.iter(|| { black_box( to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), pattern.clone()], + arg_fields: vec![ + Field::new("a", data.data_type(), true).into(), + Field::new("b", pattern.data_type(), true).into(), + ], number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), }) .expect("to_char should work on valid values"), ) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index a45d936c0a52d..cad9addab10ec 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -17,9 +17,10 @@ extern crate criterion; -use arrow::datatypes::{DataType, Int32Type, Int64Type}; +use arrow::datatypes::{DataType, Field, Int32Type, Int64Type}; use arrow::util::bench_util::create_primitive_array; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::sync::Arc; @@ -30,14 +31,18 @@ fn criterion_benchmark(c: &mut Criterion) { let i32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = i32_array.len(); let i32_args = vec![ColumnarValue::Array(i32_array)]; - c.bench_function(&format!("to_hex i32 array: {}", size), |b| { + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("to_hex i32 array: {size}"), |b| { b.iter(|| { let args_cloned = i32_args.clone(); black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Int32, false).into()], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), }) .unwrap(), ) @@ -46,14 +51,16 @@ fn criterion_benchmark(c: &mut Criterion) { let i64_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = i64_array.len(); let i64_args = vec![ColumnarValue::Array(i64_array)]; - c.bench_function(&format!("to_hex i64 array: {}", size), |b| { + c.bench_function(&format!("to_hex i64 array: {size}"), |b| { b.iter(|| { let args_cloned = i64_args.clone(); black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Int64, false).into()], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), }) .unwrap(), ) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index aec56697691fc..7e15d896f83e3 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -22,9 +22,9 @@ use std::sync::Arc; use arrow::array::builder::StringBuilder; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::compute::cast; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::{DataType, Field, TimeUnit}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; - +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::datetime::to_timestamp; @@ -109,7 +109,12 @@ fn data_with_formats() -> (StringArray, StringArray, StringArray, StringArray) { ) } fn criterion_benchmark(c: &mut Criterion) { - let return_type = &DataType::Timestamp(TimeUnit::Nanosecond, None); + let return_field = + Field::new("f", DataType::Timestamp(TimeUnit::Nanosecond, None), true).into(); + let arg_field = Field::new("a", DataType::Utf8, false).into(); + let arg_fields = vec![arg_field]; + let config_options = Arc::new(ConfigOptions::default()); + c.bench_function("to_timestamp_no_formats_utf8", |b| { let arr_data = data(); let batch_len = arr_data.len(); @@ -120,8 +125,10 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .expect("to_timestamp should work on valid values"), ) @@ -138,8 +145,10 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .expect("to_timestamp should work on valid values"), ) @@ -156,8 +165,10 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .expect("to_timestamp should work on valid values"), ) @@ -174,13 +185,23 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::Array(Arc::new(format2) as ArrayRef), ColumnarValue::Array(Arc::new(format3) as ArrayRef), ]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + b.iter(|| { black_box( to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .expect("to_timestamp should work on valid values"), ) @@ -205,13 +226,23 @@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(cast(&format3, &DataType::LargeUtf8).unwrap()) as ArrayRef ), ]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + b.iter(|| { black_box( to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .expect("to_timestamp should work on valid values"), ) @@ -237,13 +268,23 @@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(cast(&format3, &DataType::Utf8View).unwrap()) as ArrayRef ), ]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + b.iter(|| { black_box( to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .expect("to_timestamp should work on valid values"), ) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index 7fc93921d2e7b..160eac913d2b6 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::{ - datatypes::{Float32Type, Float64Type}, + datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -26,6 +26,7 @@ use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::trunc; use arrow::datatypes::DataType; +use datafusion_common::config::ConfigOptions; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { @@ -33,14 +34,20 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("trunc f32 array: {}", size), |b| { + let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let return_field = Field::new("f", DataType::Float32, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("trunc f32 array: {size}"), |b| { b.iter(|| { black_box( trunc .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Float32, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .unwrap(), ) @@ -48,14 +55,18 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("trunc f64 array: {}", size), |b| { + let arg_fields = vec![Field::new("a", DataType::Float64, true).into()]; + let return_field = Field::new("f", DataType::Float64, true).into(); + c.bench_function(&format!("trunc f64 array: {size}"), |b| { b.iter(|| { black_box( trunc .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), }) .unwrap(), ) diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index f0bee89c7d376..700f70b4b4f36 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -17,9 +17,10 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::sync::Arc; @@ -35,6 +36,8 @@ fn create_args(size: usize, str_len: usize) -> Vec { fn criterion_benchmark(c: &mut Criterion) { let upper = string::upper(); + let config_options = Arc::new(ConfigOptions::default()); + for size in [1024, 4096, 8192] { let args = create_args(size, 32); c.bench_function("upper_all_values_are_ascii", |b| { @@ -42,8 +45,10 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(upper.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Utf8, true).into()], number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), })) }) }); diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index 7b8d156fec219..f9345a97eb53c 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -17,19 +17,25 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; use datafusion_expr::ScalarFunctionArgs; use datafusion_functions::string; +use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let uuid = string::uuid(); + let config_options = Arc::new(ConfigOptions::default()); + c.bench_function("uuid", |b| { b.iter(|| { black_box(uuid.invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 1024, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), })) }) }); diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 2686dbf8be3cc..94a41ba4bb251 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -17,7 +17,7 @@ //! [`ArrowCastFunc`]: Implementation of the `arrow_cast` -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::error::ArrowError; use datafusion_common::{ arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue, @@ -29,7 +29,7 @@ use std::any::Any; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; @@ -80,7 +80,7 @@ use datafusion_macros::user_doc; description = "[Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]" ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrowCastFunc { signature: Signature, } @@ -113,11 +113,11 @@ impl ScalarUDFImpl for ArrowCastFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let nullable = args.nullables.iter().any(|&nullable| nullable); + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?; @@ -131,7 +131,7 @@ impl ScalarUDFImpl for ArrowCastFunc { ) }, |casted_type| match casted_type.parse::() { - Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)), + Ok(data_type) => Ok(Field::new(self.name(), data_type, nullable).into()), Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), Err(e) => Err(arrow_datafusion_err!(e)), }, @@ -177,7 +177,7 @@ impl ScalarUDFImpl for ArrowCastFunc { fn data_type_from_args(args: &[Expr]) -> Result { let [_, type_arg] = take_function_args("arrow_cast", args)?; - let Expr::Literal(ScalarValue::Utf8(Some(val))) = type_arg else { + let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = type_arg else { return exec_err!( "arrow_cast requires its second argument to be a constant string, got {:?}", type_arg diff --git a/datafusion/functions/src/core/arrowtypeof.rs b/datafusion/functions/src/core/arrowtypeof.rs index 2509ed246ac7c..f178890f93704 100644 --- a/datafusion/functions/src/core/arrowtypeof.rs +++ b/datafusion/functions/src/core/arrowtypeof.rs @@ -40,7 +40,7 @@ use std::any::Any; description = "Expression to evaluate. The expression can be a constant, column, or function, and any combination of operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrowTypeOfFunc { signature: Signature, } diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index ba20c23828eb8..3fba539dd04b4 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{new_null_array, BooleanArray}; -use arrow::compute::kernels::zip::zip; -use arrow::compute::{and, is_not_null, is_null}; -use arrow::datatypes::DataType; -use datafusion_common::{exec_err, internal_err, Result}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{exec_err, internal_err, plan_err, Result}; use datafusion_expr::binary::try_type_union_resolution; +use datafusion_expr::conditional_expressions::CaseBuilder; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -46,7 +45,7 @@ use std::any::Any; description = "Expression to use if previous expressions are _null_. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct CoalesceFunc { signature: Signature, } @@ -79,76 +78,52 @@ impl ScalarUDFImpl for CoalesceFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // If any the arguments in coalesce is non-null, the result is non-null - let nullable = args.nullables.iter().all(|&nullable| nullable); + let nullable = args.arg_fields.iter().all(|f| f.is_nullable()); let return_type = args - .arg_types + .arg_fields .iter() + .map(|f| f.data_type()) .find_or_first(|d| !d.is_null()) .unwrap() .clone(); - Ok(ReturnInfo::new(return_type, nullable)) + Ok(Field::new(self.name(), return_type, nullable).into()) } - /// coalesce evaluates to the first value which is not NULL - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args = args.args; - // do not accept 0 arguments. + fn simplify( + &self, + args: Vec, + _info: &dyn SimplifyInfo, + ) -> Result { if args.is_empty() { - return exec_err!( - "coalesce was called with {} arguments. It requires at least 1.", - args.len() - ); + return plan_err!("coalesce must have at least one argument"); } - - let return_type = args[0].data_type(); - let mut return_array = args.iter().filter_map(|x| match x { - ColumnarValue::Array(array) => Some(array.len()), - _ => None, - }); - - if let Some(size) = return_array.next() { - // start with nulls as default output - let mut current_value = new_null_array(&return_type, size); - let mut remainder = BooleanArray::from(vec![true; size]); - - for arg in args { - match arg { - ColumnarValue::Array(ref array) => { - let to_apply = and(&remainder, &is_not_null(array.as_ref())?)?; - current_value = zip(&to_apply, array, ¤t_value)?; - remainder = and(&remainder, &is_null(array)?)?; - } - ColumnarValue::Scalar(value) => { - if value.is_null() { - continue; - } else { - let last_value = value.to_scalar()?; - current_value = zip(&remainder, &last_value, ¤t_value)?; - break; - } - } - } - if remainder.iter().all(|x| x == Some(false)) { - break; - } - } - Ok(ColumnarValue::Array(current_value)) - } else { - let result = args - .iter() - .filter_map(|x| match x { - ColumnarValue::Scalar(s) if !s.is_null() => Some(x.clone()), - _ => None, - }) - .next() - .unwrap_or_else(|| args[0].clone()); - Ok(result) + if args.len() == 1 { + return Ok(ExprSimplifyResult::Simplified( + args.into_iter().next().unwrap(), + )); } + + let n = args.len(); + let (init, last_elem) = args.split_at(n - 1); + let whens = init + .iter() + .map(|x| x.clone().is_not_null()) + .collect::>(); + let cases = init.to_vec(); + Ok(ExprSimplifyResult::Simplified( + CaseBuilder::new(None, whens, cases, Some(Box::new(last_elem[0].clone()))) + .end()?, + )) + } + + /// coalesce evaluates to the first value which is not NULL + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("coalesce should have been simplified to case") } fn short_circuits(&self) -> bool { diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 3ac26b98359bb..d18bd6e31f72e 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -20,7 +20,7 @@ use arrow::array::{ Scalar, }; use arrow::compute::SortOptions; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use arrow_buffer::NullBuffer; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ @@ -28,7 +28,7 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -75,7 +75,7 @@ use std::sync::Arc; description = "The field name in the map or struct to retrieve data for. Must evaluate to a string." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct GetFieldFunc { signature: Signature, } @@ -108,8 +108,8 @@ impl ScalarUDFImpl for GetFieldFunc { let [base, field_name] = take_function_args(self.name(), args)?; let name = match field_name { - Expr::Literal(name) => name, - other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), + Expr::Literal(name, _) => name.to_string(), + other => other.schema_name().to_string(), }; Ok(format!("{base}[{name}]")) @@ -118,8 +118,8 @@ impl ScalarUDFImpl for GetFieldFunc { fn schema_name(&self, args: &[Expr]) -> Result { let [base, field_name] = take_function_args(self.name(), args)?; let name = match field_name { - Expr::Literal(name) => name, - other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), + Expr::Literal(name, _) => name.to_string(), + other => other.schema_name().to_string(), }; Ok(format!("{}[{}]", base.schema_name(), name)) @@ -130,14 +130,14 @@ impl ScalarUDFImpl for GetFieldFunc { } fn return_type(&self, _: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // Length check handled in the signature debug_assert_eq!(args.scalar_arguments.len(), 2); - match (&args.arg_types[0], args.scalar_arguments[1].as_ref()) { + match (&args.arg_fields[0].data_type(), args.scalar_arguments[1].as_ref()) { (DataType::Map(fields, _), _) => { match fields.data_type() { DataType::Struct(fields) if fields.len() == 2 => { @@ -146,7 +146,8 @@ impl ScalarUDFImpl for GetFieldFunc { // instead, we assume that the second column is the "value" column both here and in // execution. let value_field = fields.get(1).expect("fields should have exactly two members"); - Ok(ReturnInfo::new_nullable(value_field.data_type().clone())) + + Ok(value_field.as_ref().clone().with_nullable(true).into()) }, _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"), } @@ -158,10 +159,20 @@ impl ScalarUDFImpl for GetFieldFunc { |field_name| { fields.iter().find(|f| f.name() == field_name) .ok_or(plan_datafusion_err!("Field {field_name} not found in struct")) - .map(|f| ReturnInfo::new_nullable(f.data_type().to_owned())) + .map(|f| { + let mut child_field = f.as_ref().clone(); + + // If the parent is nullable, then getting the child must be nullable, + // so potentially override the return value + + if args.arg_fields[0].is_nullable() { + child_field = child_field.with_nullable(true); + } + Arc::new(child_field) + }) }) }, - (DataType::Null, _) => Ok(ReturnInfo::new_nullable(DataType::Null)), + (DataType::Null, _) => Ok(Field::new(self.name(), DataType::Null, true).into()), (other, _) => exec_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), } } @@ -245,7 +256,7 @@ impl ScalarUDFImpl for GetFieldFunc { (DataType::Map(_, _), other) => { let data_type = other.data_type(); if data_type.is_nested() { - exec_err!("unsupported type {:?} for map access", data_type) + exec_err!("unsupported type {} for map access", data_type) } else { process_map_array(array, other.to_array()?) } @@ -264,7 +275,7 @@ impl ScalarUDFImpl for GetFieldFunc { (DataType::Null, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)), (dt, name) => exec_err!( "get_field is only possible on maps with utf8 indexes or struct \ - with utf8 indexes. Received {dt:?} with {name:?} index" + with utf8 indexes. Received {dt} with {name:?} index" ), } } diff --git a/datafusion/functions/src/core/greatest.rs b/datafusion/functions/src/core/greatest.rs index 2d7ad2be3986f..6afc5b25512f4 100644 --- a/datafusion/functions/src/core/greatest.rs +++ b/datafusion/functions/src/core/greatest.rs @@ -53,7 +53,7 @@ const SORT_OPTIONS: SortOptions = SortOptions { description = "Expressions to compare and return the greatest value.. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct GreatestFunc { signature: Signature, } diff --git a/datafusion/functions/src/core/least.rs b/datafusion/functions/src/core/least.rs index 662dac3e699fb..31cdf54441117 100644 --- a/datafusion/functions/src/core/least.rs +++ b/datafusion/functions/src/core/least.rs @@ -53,7 +53,7 @@ const SORT_OPTIONS: SortOptions = SortOptions { description = "Expressions to compare and return the smallest value. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct LeastFunc { signature: Signature, } diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index c6329b1ee0afd..db080cd628478 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -36,6 +36,7 @@ pub mod overlay; pub mod planner; pub mod r#struct; pub mod union_extract; +pub mod union_tag; pub mod version; // create UDFs @@ -52,6 +53,7 @@ make_udf_function!(coalesce::CoalesceFunc, coalesce); make_udf_function!(greatest::GreatestFunc, greatest); make_udf_function!(least::LeastFunc, least); make_udf_function!(union_extract::UnionExtractFun, union_extract); +make_udf_function!(union_tag::UnionTagFunc, union_tag); make_udf_function!(version::VersionFunc, version); pub mod expr_fn { @@ -101,6 +103,10 @@ pub mod expr_fn { least, "Returns `least(args...)`, which evaluates to the smallest value in the list of expressions or NULL if all the expressions are NULL", args, + ),( + union_tag, + "Returns the name of the currently selected field in the union", + arg1 )); #[doc = "Returns the value of the field with the given name from the struct"] @@ -136,6 +142,7 @@ pub fn functions() -> Vec> { greatest(), least(), union_extract(), + union_tag(), version(), r#struct(), ] diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index bba884d96483d..1da5148474f8c 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -16,10 +16,10 @@ // under the License. use arrow::array::StructArray; -use arrow::datatypes::{DataType, Field, Fields}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -58,7 +58,7 @@ a struct type of fields `field_a` and `field_b`: description = "Expression to include in the output struct. Can be a constant, column, or function, and any combination of arithmetic or string operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct NamedStructFunc { signature: Signature, } @@ -91,10 +91,12 @@ impl ScalarUDFImpl for NamedStructFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("named_struct: return_type called instead of return_type_from_args") + internal_err!( + "named_struct: return_type called instead of return_field_from_args" + ) } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // do not accept 0 arguments. if args.scalar_arguments.is_empty() { return exec_err!( @@ -102,7 +104,7 @@ impl ScalarUDFImpl for NamedStructFunc { ); } - if args.scalar_arguments.len() % 2 != 0 { + if !args.scalar_arguments.len().is_multiple_of(2) { return exec_err!( "named_struct requires an even number of arguments, got {} instead", args.scalar_arguments.len() @@ -126,7 +128,13 @@ impl ScalarUDFImpl for NamedStructFunc { ) ) .collect::>>()?; - let types = args.arg_types.iter().skip(1).step_by(2).collect::>(); + let types = args + .arg_fields + .iter() + .skip(1) + .step_by(2) + .map(|f| f.data_type()) + .collect::>(); let return_fields = names .into_iter() @@ -134,13 +142,16 @@ impl ScalarUDFImpl for NamedStructFunc { .map(|(name, data_type)| Ok(Field::new(name, data_type.to_owned(), true))) .collect::>>()?; - Ok(ReturnInfo::new_nullable(DataType::Struct(Fields::from( - return_fields, - )))) + Ok(Field::new( + self.name(), + DataType::Struct(Fields::from(return_fields)), + true, + ) + .into()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let DataType::Struct(fields) = args.return_type else { + let DataType::Struct(fields) = args.return_type() else { return internal_err!("incorrect named_struct return type"); }; diff --git a/datafusion/functions/src/core/nullif.rs b/datafusion/functions/src/core/nullif.rs index ee29714da16b6..be2dd0d2ca160 100644 --- a/datafusion/functions/src/core/nullif.rs +++ b/datafusion/functions/src/core/nullif.rs @@ -53,7 +53,7 @@ This can be used to perform the inverse operation of [`coalesce`](#coalesce).", description = "Expression to compare to expression1. Can be a constant, column, or function, and any combination of operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct NullIfFunc { signature: Signature, } diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs index 82d367072a256..c8b34c4b17800 100644 --- a/datafusion/functions/src/core/nvl.rs +++ b/datafusion/functions/src/core/nvl.rs @@ -55,7 +55,7 @@ use std::sync::Arc; description = "Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct NVLFunc { signature: Signature, aliases: Vec, diff --git a/datafusion/functions/src/core/nvl2.rs b/datafusion/functions/src/core/nvl2.rs index d20b01e29fba8..82aa8d2a4cd54 100644 --- a/datafusion/functions/src/core/nvl2.rs +++ b/datafusion/functions/src/core/nvl2.rs @@ -59,7 +59,7 @@ use std::sync::Arc; description = "Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct NVL2Func { signature: Signature, } @@ -113,7 +113,7 @@ impl ScalarUDFImpl for NVL2Func { if let Some(coerced_type) = coerced_type { Ok(coerced_type) } else { - internal_err!("Coercion from {acc:?} to {x:?} failed.") + internal_err!("Coercion from {acc} to {x} failed.") } })?; Ok(vec![new_type; arg_types.len()]) diff --git a/datafusion/functions/src/core/overlay.rs b/datafusion/functions/src/core/overlay.rs index 0ea5359e9621d..165bc571afe09 100644 --- a/datafusion/functions/src/core/overlay.rs +++ b/datafusion/functions/src/core/overlay.rs @@ -53,7 +53,7 @@ use datafusion_macros::user_doc; description = "The count of characters to be replaced from start position of str. If not specified, will use substr length instead." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct OverlayFunc { signature: Signature, } diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index 8792bf1bd1b98..32c7af80e397f 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -64,7 +64,7 @@ select struct(a as field_a, b) from t; description = "Expression to include in the output struct. Can be a constant, column, or function, any combination of arithmetic or string operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct StructFunc { signature: Signature, aliases: Vec, @@ -117,7 +117,7 @@ impl ScalarUDFImpl for StructFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let DataType::Struct(fields) = args.return_type else { + let DataType::Struct(fields) = args.return_type() else { return internal_err!("incorrect struct return type"); }; diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index 420eeed42cc3b..a71e2e87388d5 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -16,14 +16,14 @@ // under the License. use arrow::array::Array; -use arrow::datatypes::{DataType, FieldRef, UnionFields}; +use arrow::datatypes::{DataType, Field, FieldRef, UnionFields}; use datafusion_common::cast::as_union_array; use datafusion_common::utils::take_function_args; use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, Result, ScalarValue, }; use datafusion_doc::Documentation; -use datafusion_expr::{ColumnarValue, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs}; +use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -49,7 +49,7 @@ use datafusion_macros::user_doc; description = "String expression to operate on. Must be a constant." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct UnionExtractFun { signature: Signature, } @@ -82,35 +82,35 @@ impl ScalarUDFImpl for UnionExtractFun { } fn return_type(&self, _: &[DataType]) -> Result { - // should be using return_type_from_args and not calling the default implementation + // should be using return_field_from_args and not calling the default implementation internal_err!("union_extract should return type from args") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - if args.arg_types.len() != 2 { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + if args.arg_fields.len() != 2 { return exec_err!( "union_extract expects 2 arguments, got {} instead", - args.arg_types.len() + args.arg_fields.len() ); } - let DataType::Union(fields, _) = &args.arg_types[0] else { + let DataType::Union(fields, _) = &args.arg_fields[0].data_type() else { return exec_err!( "union_extract first argument must be a union, got {} instead", - args.arg_types[0] + args.arg_fields[0].data_type() ); }; let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else { return exec_err!( "union_extract second argument must be a non-null string literal, got {} instead", - args.arg_types[1] + args.arg_fields[1].data_type() ); }; let field = find_field(fields, field_name)?.1; - Ok(ReturnInfo::new_nullable(field.data_type().clone())) + Ok(Field::new(self.name(), field.data_type().clone(), true).into()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -169,10 +169,11 @@ fn find_field<'a>(fields: &'a UnionFields, name: &str) -> Result<(i8, &'a FieldR #[cfg(test)] mod tests { - use arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; + use datafusion_common::config::ConfigOptions; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use std::sync::Arc; use super::UnionExtractFun; @@ -189,47 +190,70 @@ mod tests { ], ); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Union( + None, + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - None, - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], + args, + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), })?; assert_scalar(result, ScalarValue::Utf8(None)); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((3, Box::new(ScalarValue::Int32(Some(42))))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - Some((3, Box::new(ScalarValue::Int32(Some(42))))), - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], + args, + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), })?; assert_scalar(result, ScalarValue::Utf8(None)); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((1, Box::new(ScalarValue::new_utf8("42")))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - Some((1, Box::new(ScalarValue::new_utf8("42")))), - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], + args, + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), })?; assert_scalar(result, ScalarValue::new_utf8("42")); diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs new file mode 100644 index 0000000000000..aeadb8292ba1e --- /dev/null +++ b/datafusion/functions/src/core/union_tag.rs @@ -0,0 +1,228 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, AsArray, DictionaryArray, Int8Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion_common::utils::take_function_args; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; +use datafusion_doc::Documentation; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +#[user_doc( + doc_section(label = "Union Functions"), + description = "Returns the name of the currently selected field in the union", + syntax_example = "union_tag(union_expression)", + sql_example = r#"```sql +❯ select union_column, union_tag(union_column) from table_with_union; ++--------------+-------------------------+ +| union_column | union_tag(union_column) | ++--------------+-------------------------+ +| {a=1} | a | +| {b=3.0} | b | +| {a=4} | a | +| {b=} | b | +| {a=} | a | ++--------------+-------------------------+ +```"#, + standard_argument(name = "union", prefix = "Union") +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct UnionTagFunc { + signature: Signature, +} + +impl Default for UnionTagFunc { + fn default() -> Self { + Self::new() + } +} + +impl UnionTagFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for UnionTagFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "union_tag" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8), + )) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [union_] = take_function_args("union_tag", args.args)?; + + match union_ { + ColumnarValue::Array(array) + if matches!(array.data_type(), DataType::Union(_, _)) => + { + let union_array = array.as_union(); + + let keys = Int8Array::try_new(union_array.type_ids().clone(), None)?; + + let fields = match union_array.data_type() { + DataType::Union(fields, _) => fields, + _ => unreachable!(), + }; + + // Union fields type IDs only constraints are being unique and in the 0..128 range: + // They may not start at 0, be sequential, or even contiguous. + // Therefore, we allocate a values vector with a length equal to the highest type ID plus one, + // ensuring that each field's name can be placed at the index corresponding to its type ID. + let values_len = fields + .iter() + .map(|(type_id, _)| type_id + 1) + .max() + .unwrap_or_default() as usize; + + let mut values = vec![""; values_len]; + + for (type_id, field) in fields.iter() { + values[type_id as usize] = field.name().as_str() + } + + let values = Arc::new(StringArray::from(values)); + + // SAFETY: union type_ids are validated to not be smaller than zero. + // values len is the union biggest type id plus one. + // keys is built from the union type_ids, which contains only valid type ids + // therefore, `keys[i] >= values.len() || keys[i] < 0` never occurs + let dict = unsafe { DictionaryArray::new_unchecked(keys, values) }; + + Ok(ColumnarValue::Array(Arc::new(dict))) + } + ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => match value { + Some((value_type_id, _)) => fields + .iter() + .find(|(type_id, _)| value_type_id == *type_id) + .map(|(_, field)| { + ColumnarValue::Scalar(ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(field.name().as_str().into()), + )) + }) + .ok_or_else(|| { + exec_datafusion_err!( + "union_tag: union scalar with unknown type_id {value_type_id}" + ) + }), + None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + args.return_field.data_type(), + )?)), + }, + v => exec_err!("union_tag only support unions, got {:?}", v.data_type()), + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +#[cfg(test)] +mod tests { + use super::UnionTagFunc; + use arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; + use datafusion_common::config::ConfigOptions; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use std::sync::Arc; + + // when it becomes possible to construct union scalars in SQL, this should go to sqllogictests + #[test] + fn union_scalar() { + let fields = [(0, Arc::new(Field::new("a", DataType::UInt32, false)))] + .into_iter() + .collect(); + + let scalar = ScalarValue::Union( + Some((0, Box::new(ScalarValue::UInt32(Some(0))))), + fields, + UnionMode::Dense, + ); + + let return_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + + let result = UnionTagFunc::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(scalar)], + number_rows: 1, + return_field: Field::new("res", return_type, true).into(), + arg_fields: vec![], + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(); + + assert_scalar( + result, + ScalarValue::Dictionary(Box::new(DataType::Int8), Box::new("a".into())), + ); + } + + #[test] + fn union_scalar_empty() { + let scalar = ScalarValue::Union(None, UnionFields::empty(), UnionMode::Dense); + + let return_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + + let result = UnionTagFunc::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(scalar)], + number_rows: 1, + return_field: Field::new("res", return_type, true).into(), + arg_fields: vec![], + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(); + + assert_scalar( + result, + ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(ScalarValue::Utf8(None)), + ), + ); + } + + fn assert_scalar(value: ColumnarValue, expected: ScalarValue) { + match value { + ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"), + ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected), + } + } +} diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 34038022f2dc7..ef3c5aafa4801 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -39,7 +39,7 @@ use std::any::Any; +--------------------------------------------+ ```"# )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct VersionFunc { signature: Signature, } @@ -97,7 +97,10 @@ impl ScalarUDFImpl for VersionFunc { #[cfg(test)] mod test { use super::*; + use arrow::datatypes::Field; + use datafusion_common::config::ConfigOptions; use datafusion_expr::ScalarUDF; + use std::sync::Arc; #[tokio::test] async fn test_version_udf() { @@ -105,8 +108,10 @@ mod test { let version = version_udf .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 0, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), }) .unwrap(); diff --git a/datafusion/functions/src/crypto/basic.rs b/datafusion/functions/src/crypto/basic.rs index eaa688c1c3359..5bf83943a92da 100644 --- a/datafusion/functions/src/crypto/basic.rs +++ b/datafusion/functions/src/crypto/basic.rs @@ -21,7 +21,7 @@ use arrow::array::{ Array, ArrayRef, BinaryArray, BinaryArrayType, BinaryViewArray, GenericBinaryArray, OffsetSizeTrait, }; -use arrow::array::{AsArray, GenericStringArray, StringArray, StringViewArray}; +use arrow::array::{AsArray, GenericStringArray, StringViewArray}; use arrow::datatypes::DataType; use blake2::{Blake2b512, Blake2s256, Digest}; use blake3::Hasher as Blake3; @@ -169,18 +169,18 @@ pub fn md5(args: &[ColumnarValue]) -> Result { let [data] = take_function_args("md5", args)?; let value = digest_process(data, DigestAlgorithm::Md5)?; - // md5 requires special handling because of its unique utf8 return type + // md5 requires special handling because of its unique utf8view return type Ok(match value { ColumnarValue::Array(array) => { let binary_array = as_binary_array(&array)?; - let string_array: StringArray = binary_array + let string_array: StringViewArray = binary_array .iter() .map(|opt| opt.map(hex_encode::<_>)) .collect(); ColumnarValue::Array(Arc::new(string_array)) } ColumnarValue::Scalar(ScalarValue::Binary(opt)) => { - ColumnarValue::Scalar(ScalarValue::Utf8(opt.map(hex_encode::<_>))) + ColumnarValue::Scalar(ScalarValue::Utf8View(opt.map(hex_encode::<_>))) } _ => return exec_err!("Impossibly got invalid results from digest"), }) diff --git a/datafusion/functions/src/crypto/digest.rs b/datafusion/functions/src/crypto/digest.rs index 2840006169be4..a4999f72f8d56 100644 --- a/datafusion/functions/src/crypto/digest.rs +++ b/datafusion/functions/src/crypto/digest.rs @@ -56,7 +56,7 @@ use std::any::Any; - blake3" ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct DigestFunc { signature: Signature, } diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index c1540450029cf..88859fdee34a7 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ b/datafusion/functions/src/crypto/md5.rs @@ -45,7 +45,7 @@ use std::any::Any; ```"#, standard_argument(name = "expression", prefix = "String") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Md5Func { signature: Signature, } @@ -92,12 +92,12 @@ impl ScalarUDFImpl for Md5Func { fn return_type(&self, arg_types: &[DataType]) -> Result { use DataType::*; Ok(match &arg_types[0] { - LargeUtf8 | LargeBinary => Utf8, - Utf8View | Utf8 | Binary | BinaryView => Utf8, + LargeUtf8 | LargeBinary => Utf8View, + Utf8View | Utf8 | Binary | BinaryView => Utf8View, Null => Null, Dictionary(_, t) => match **t { - LargeUtf8 | LargeBinary => Utf8, - Utf8 | Binary | BinaryView => Utf8, + LargeUtf8 | LargeBinary => Utf8View, + Utf8 | Binary | BinaryView => Utf8View, Null => Null, _ => { return plan_err!( diff --git a/datafusion/functions/src/crypto/sha224.rs b/datafusion/functions/src/crypto/sha224.rs index a64a3ef803197..69b79cce72c4e 100644 --- a/datafusion/functions/src/crypto/sha224.rs +++ b/datafusion/functions/src/crypto/sha224.rs @@ -44,7 +44,7 @@ use std::any::Any; ```"#, standard_argument(name = "expression", prefix = "String") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct SHA224Func { signature: Signature, } diff --git a/datafusion/functions/src/crypto/sha256.rs b/datafusion/functions/src/crypto/sha256.rs index 94f3ea3b49fa6..9a948ba50c9e1 100644 --- a/datafusion/functions/src/crypto/sha256.rs +++ b/datafusion/functions/src/crypto/sha256.rs @@ -44,7 +44,7 @@ use std::any::Any; ```"#, standard_argument(name = "expression", prefix = "String") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct SHA256Func { signature: Signature, } diff --git a/datafusion/functions/src/crypto/sha384.rs b/datafusion/functions/src/crypto/sha384.rs index 023730469c7bd..9e363cf883d29 100644 --- a/datafusion/functions/src/crypto/sha384.rs +++ b/datafusion/functions/src/crypto/sha384.rs @@ -44,7 +44,7 @@ use std::any::Any; ```"#, standard_argument(name = "expression", prefix = "String") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct SHA384Func { signature: Signature, } diff --git a/datafusion/functions/src/crypto/sha512.rs b/datafusion/functions/src/crypto/sha512.rs index f48737e5751f0..a185698ca46ff 100644 --- a/datafusion/functions/src/crypto/sha512.rs +++ b/datafusion/functions/src/crypto/sha512.rs @@ -44,7 +44,7 @@ use std::any::Any; ```"#, standard_argument(name = "expression", prefix = "String") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct SHA512Func { signature: Signature, } diff --git a/datafusion/functions/src/datetime/common.rs b/datafusion/functions/src/datetime/common.rs index fd9f37d8052c7..65f9c9323925f 100644 --- a/datafusion/functions/src/datetime/common.rs +++ b/datafusion/functions/src/datetime/common.rs @@ -29,7 +29,8 @@ use chrono::{DateTime, TimeZone, Utc}; use datafusion_common::cast::as_generic_string_array; use datafusion_common::{ - exec_err, unwrap_or_internal_err, DataFusionError, Result, ScalarType, ScalarValue, + exec_datafusion_err, exec_err, unwrap_or_internal_err, DataFusionError, Result, + ScalarType, ScalarValue, }; use datafusion_expr::ColumnarValue; @@ -83,9 +84,9 @@ pub(crate) fn string_to_datetime_formatted( format: &str, ) -> Result, DataFusionError> { let err = |err_ctx: &str| { - DataFusionError::Execution(format!( + exec_datafusion_err!( "Error parsing timestamp from '{s}' using format '{format}': {err_ctx}" - )) + ) }; let mut parsed = Parsed::new(); @@ -149,9 +150,7 @@ pub(crate) fn string_to_timestamp_nanos_formatted( .naive_utc() .and_utc() .timestamp_nanos_opt() - .ok_or_else(|| { - DataFusionError::Execution(ERR_NANOSECONDS_NOT_SUPPORTED.to_string()) - }) + .ok_or_else(|| exec_datafusion_err!("{ERR_NANOSECONDS_NOT_SUPPORTED}")) } /// Accepts a string with a `chrono` format and converts it to a @@ -412,8 +411,8 @@ where }?; let r = op(x, v); - if r.is_ok() { - val = Some(Ok(op2(r.unwrap()))); + if let Ok(inner) = r { + val = Some(Ok(op2(inner))); break; } else { val = Some(r); diff --git a/datafusion/functions/src/datetime/current_date.rs b/datafusion/functions/src/datetime/current_date.rs index 9998e7d3758e0..0ba3afd19bedb 100644 --- a/datafusion/functions/src/datetime/current_date.rs +++ b/datafusion/functions/src/datetime/current_date.rs @@ -17,9 +17,10 @@ use std::any::Any; +use arrow::array::timezone::Tz; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Date32; -use chrono::{Datelike, NaiveDate}; +use chrono::{Datelike, NaiveDate, TimeZone}; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; @@ -31,13 +32,13 @@ use datafusion_macros::user_doc; #[user_doc( doc_section(label = "Time and Date Functions"), description = r#" -Returns the current UTC date. +Returns the current date in the session time zone. The `current_date()` return value is determined at query time and will return the same date, no matter when in the query plan the function executes. "#, syntax_example = "current_date()" )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct CurrentDateFunc { signature: Signature, aliases: Vec, @@ -100,14 +101,22 @@ impl ScalarUDFImpl for CurrentDateFunc { info: &dyn SimplifyInfo, ) -> Result { let now_ts = info.execution_props().query_execution_start_time; - let days = Some( - now_ts.num_days_from_ce() - - NaiveDate::from_ymd_opt(1970, 1, 1) - .unwrap() - .num_days_from_ce(), - ); + + // Get timezone from config and convert to local time + let days = info + .execution_props() + .config_options() + .and_then(|config| config.execution.time_zone.parse::().ok()) + .map_or_else( + || datetime_to_days(&now_ts), + |tz| { + let local_now = tz.from_utc_datetime(&now_ts.naive_utc()); + datetime_to_days(&local_now) + }, + ); Ok(ExprSimplifyResult::Simplified(Expr::Literal( - ScalarValue::Date32(days), + ScalarValue::Date32(Some(days)), + None, ))) } @@ -115,3 +124,11 @@ impl ScalarUDFImpl for CurrentDateFunc { self.doc() } } + +/// Converts a DateTime to the number of days since Unix epoch (1970-01-01) +fn datetime_to_days(dt: &T) -> i32 { + dt.num_days_from_ce() + - NaiveDate::from_ymd_opt(1970, 1, 1) + .unwrap() + .num_days_from_ce() +} diff --git a/datafusion/functions/src/datetime/current_time.rs b/datafusion/functions/src/datetime/current_time.rs index c416d0240b13c..79d5bfc1783c1 100644 --- a/datafusion/functions/src/datetime/current_time.rs +++ b/datafusion/functions/src/datetime/current_time.rs @@ -36,7 +36,7 @@ The `current_time()` return value is determined at query time and will return th "#, syntax_example = "current_time()" )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct CurrentTimeFunc { signature: Signature, } @@ -96,6 +96,7 @@ impl ScalarUDFImpl for CurrentTimeFunc { let nano = now_ts.timestamp_nanos_opt().map(|ts| ts % 86400000000000); Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::Time64Nanosecond(nano), + None, ))) } diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 5ffae46dde48f..74e286de0f584 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -95,7 +95,7 @@ FROM VALUES ('2023-01-01T18:18:18Z'), ('2023-01-03T19:00:03Z') t(time); "# ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct DateBinFunc { signature: Signature, } @@ -505,85 +505,90 @@ mod tests { use arrow::array::types::TimestampNanosecondType; use arrow::array::{Array, IntervalDayTimeArray, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; - use datafusion_common::ScalarValue; + use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use chrono::TimeDelta; + use datafusion_common::config::ConfigOptions; + + fn invoke_date_bin_with_args( + args: Vec, + number_rows: usize, + return_field: &FieldRef, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true).into()) + .collect::>(); + + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Arc::clone(return_field), + config_options: Arc::new(ConfigOptions::default()), + }; + DateBinFunc::new().invoke_with_args(args) + } #[test] fn test_date_bin() { - let mut args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + let return_field = &Arc::new(Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + )); + + let mut args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); let timestamps = Arc::new((1..6).map(Some).collect::()); let batch_len = timestamps.len(); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Array(timestamps), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Array(timestamps), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, batch_len, return_field); assert!(res.is_ok()); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // stride supports month-day-nano - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNano { - months: 0, - days: 0, - nanoseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano { + months: 0, + days: 0, + nanoseconds: 1, + }, + ))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // @@ -591,33 +596,25 @@ mod tests { // // invalid number of arguments - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - )))], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + IntervalDayTime { + days: 0, + milliseconds: 1, + }, + )))]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expected two or three arguments" ); // stride: invalid type - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expects stride argument to be an INTERVAL but got Interval(YearMonth)" @@ -625,113 +622,83 @@ mod tests { // stride: invalid value - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 0, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 0, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; - let res = DateBinFunc::new().invoke_with_args(args); + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride must be non-zero" ); // stride: overflow of day-time interval - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime::MAX, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + IntervalDayTime::MAX, + ))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride argument is too large" ); // stride: overflow of month-day-nano interval - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride argument is too large" ); // stride: month intervals - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1, 1, 1)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1, 1, 1)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN stride does not support combination of month, day and nanosecond intervals" ); // origin: invalid type - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expects origin argument to be a TIMESTAMP with nanosecond precision but got Timestamp(Microsecond, None)" ); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // unsupported array type for stride @@ -745,16 +712,12 @@ mod tests { }) .collect::(), ); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(intervals), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Array(intervals), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN only supports literal values for the stride argument, not arrays" @@ -763,21 +726,15 @@ mod tests { // unsupported array type for origin let timestamps = Arc::new((1..6).map(Some).collect::()); let batch_len = timestamps.len(); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Array(timestamps), - ], - number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Array(timestamps), + ]; + let res = invoke_date_bin_with_args(args, batch_len, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN only supports literal values for the origin argument, not arrays" @@ -893,22 +850,22 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), - ColumnarValue::Array(Arc::new(input)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - Some(string_to_timestamp_nanos(origin).unwrap()), - tz_opt.clone(), - )), - ], - number_rows: batch_len, - return_type: &DataType::Timestamp( - TimeUnit::Nanosecond, + let args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), + ColumnarValue::Array(Arc::new(input)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(string_to_timestamp_nanos(origin).unwrap()), tz_opt.clone(), - ), - }; - let result = DateBinFunc::new().invoke_with_args(args).unwrap(); + )), + ]; + let return_field = &Arc::new(Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + )); + let result = + invoke_date_bin_with_args(args, batch_len, return_field).unwrap(); + if let ColumnarValue::Array(result) = result { assert_eq!( result.data_type(), diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index bfd06b39d2067..aa23a5028dd81 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -26,7 +26,7 @@ use arrow::datatypes::DataType::{ Date32, Date64, Duration, Interval, Time32, Time64, Timestamp, }; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; use datafusion_common::types::{logical_date, NativeType}; use datafusion_common::{ @@ -42,7 +42,7 @@ use datafusion_common::{ Result, ScalarValue, }; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; @@ -68,9 +68,10 @@ use datafusion_macros::user_doc; - millisecond - microsecond - nanosecond - - dow (day of the week) + - dow (day of the week where Sunday is 0) - doy (day of the year) - epoch (seconds since Unix epoch) + - isodow (day of the week where Monday is 0) "# ), argument( @@ -78,7 +79,7 @@ use datafusion_macros::user_doc; description = "Time expression to operate on. Can be a constant, column, or function." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct DatePartFunc { signature: Signature, aliases: Vec, @@ -142,10 +143,10 @@ impl ScalarUDFImpl for DatePartFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { let [field, _] = take_function_args(self.name(), args.scalar_arguments)?; field @@ -155,12 +156,13 @@ impl ScalarUDFImpl for DatePartFunc { .filter(|s| !s.is_empty()) .map(|part| { if is_epoch(part) { - ReturnInfo::new_nullable(DataType::Float64) + Field::new(self.name(), DataType::Float64, true) } else { - ReturnInfo::new_nullable(DataType::Int32) + Field::new(self.name(), DataType::Int32, true) } }) }) + .map(Arc::new) .map_or_else( || exec_err!("{} requires non-empty constant string", self.name()), Ok, @@ -216,6 +218,7 @@ impl ScalarUDFImpl for DatePartFunc { "qtr" | "quarter" => date_part(array.as_ref(), DatePart::Quarter)?, "doy" => date_part(array.as_ref(), DatePart::DayOfYear)?, "dow" => date_part(array.as_ref(), DatePart::DayOfWeekSunday0)?, + "isodow" => date_part(array.as_ref(), DatePart::DayOfWeekMonday0)?, "epoch" => epoch(array.as_ref())?, _ => return exec_err!("Date part '{part}' not supported"), } @@ -231,6 +234,7 @@ impl ScalarUDFImpl for DatePartFunc { fn aliases(&self) -> &[String] { &self.aliases } + fn documentation(&self) -> Option<&Documentation> { self.doc() } diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index ed3eb228bf034..405aabfde9917 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -28,11 +28,13 @@ use arrow::array::types::{ ArrowTimestampType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use arrow::array::{Array, PrimitiveArray}; +use arrow::array::{Array, ArrayRef, Int64Array, PrimitiveArray}; use arrow::datatypes::DataType::{self, Null, Timestamp, Utf8, Utf8View}; use arrow::datatypes::TimeUnit::{self, Microsecond, Millisecond, Nanosecond, Second}; use datafusion_common::cast::as_primitive_array; -use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + exec_datafusion_err, exec_err, plan_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ @@ -60,6 +62,8 @@ use chrono::{ - hour / HOUR - minute / MINUTE - second / SECOND + - millisecond / MILLISECOND + - microsecond / MICROSECOND "# ), argument( @@ -67,7 +71,7 @@ use chrono::{ description = "Time expression to operate on. Can be a constant, column, or function." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct DateTruncFunc { signature: Signature, aliases: Vec, @@ -185,6 +189,26 @@ impl ScalarUDFImpl for DateTruncFunc { ) -> Result { let parsed_tz = parse_tz(tz_opt)?; let array = as_primitive_array::(array)?; + + // fast path for fine granularities + if matches!( + granularity.as_str(), + // For modern timezones, it's correct to truncate "minute" in this way. + // Both datafusion and arrow are ignoring historical timezone's non-minute granularity + // bias (e.g., Asia/Kathmandu before 1919 is UTC+05:41:16). + "second" | "minute" | "millisecond" | "microsecond" + ) || + // In UTC, "hour" and "day" have uniform durations and can be truncated with simple arithmetic + (parsed_tz.is_none() && matches!(granularity.as_str(), "hour" | "day")) + { + let result = general_date_trunc_array_fine_granularity( + T::UNIT, + array, + granularity.as_str(), + )?; + return Ok(ColumnarValue::Array(result)); + } + let array: PrimitiveArray = array .try_unary(|x| { general_date_trunc(T::UNIT, x, parsed_tz, granularity.as_str()) @@ -405,16 +429,13 @@ fn date_trunc_coarse(granularity: &str, value: i64, tz: Option) -> Result to clear the various fields because need to clear per timezone, // and NaiveDateTime (ISO 8601) has no concept of timezones let value = as_datetime_with_timezone::(value, tz) - .ok_or(DataFusionError::Execution(format!( - "Timestamp {value} out of range" - )))?; + .ok_or(exec_datafusion_err!("Timestamp {value} out of range"))?; _date_trunc_coarse_with_tz(granularity, Some(value)) } None => { // Use chrono NaiveDateTime to clear the various fields, if we don't have a timezone. - let value = timestamp_ns_to_datetime(value).ok_or_else(|| { - DataFusionError::Execution(format!("Timestamp {value} out of range")) - })?; + let value = timestamp_ns_to_datetime(value) + .ok_or_else(|| exec_datafusion_err!("Timestamp {value} out of range"))?; _date_trunc_coarse_without_tz(granularity, Some(value)) } }?; @@ -423,6 +444,55 @@ fn date_trunc_coarse(granularity: &str, value: i64, tz: Option) -> Result( + tu: TimeUnit, + array: &PrimitiveArray, + granularity: &str, +) -> Result { + let unit = match (tu, granularity) { + (Second, "minute") => Some(Int64Array::new_scalar(60)), + (Second, "hour") => Some(Int64Array::new_scalar(3600)), + (Second, "day") => Some(Int64Array::new_scalar(86400)), + + (Millisecond, "second") => Some(Int64Array::new_scalar(1_000)), + (Millisecond, "minute") => Some(Int64Array::new_scalar(60_000)), + (Millisecond, "hour") => Some(Int64Array::new_scalar(3_600_000)), + (Millisecond, "day") => Some(Int64Array::new_scalar(86_400_000)), + + (Microsecond, "millisecond") => Some(Int64Array::new_scalar(1_000)), + (Microsecond, "second") => Some(Int64Array::new_scalar(1_000_000)), + (Microsecond, "minute") => Some(Int64Array::new_scalar(60_000_000)), + (Microsecond, "hour") => Some(Int64Array::new_scalar(3_600_000_000)), + (Microsecond, "day") => Some(Int64Array::new_scalar(86_400_000_000)), + + (Nanosecond, "microsecond") => Some(Int64Array::new_scalar(1_000)), + (Nanosecond, "millisecond") => Some(Int64Array::new_scalar(1_000_000)), + (Nanosecond, "second") => Some(Int64Array::new_scalar(1_000_000_000)), + (Nanosecond, "minute") => Some(Int64Array::new_scalar(60_000_000_000)), + (Nanosecond, "hour") => Some(Int64Array::new_scalar(3_600_000_000_000)), + (Nanosecond, "day") => Some(Int64Array::new_scalar(86_400_000_000_000)), + _ => None, + }; + + if let Some(unit) = unit { + let original_type = array.data_type(); + let array = arrow::compute::cast(array, &DataType::Int64)?; + let array = arrow::compute::kernels::numeric::div(&array, &unit)?; + let array = arrow::compute::kernels::numeric::mul(&array, &unit)?; + let array = arrow::compute::cast(&array, original_type)?; + Ok(array) + } else { + // truncate to the same or smaller unit + Ok(Arc::new(array.clone())) + } +} + // truncates a single value with the given timeunit to the specified granularity fn general_date_trunc( tu: TimeUnit, @@ -470,9 +540,8 @@ fn general_date_trunc( fn parse_tz(tz: &Option>) -> Result> { tz.as_ref() .map(|tz| { - Tz::from_str(tz).map_err(|op| { - DataFusionError::Execution(format!("failed on timezone {tz}: {:?}", op)) - }) + Tz::from_str(tz) + .map_err(|op| exec_datafusion_err!("failed on timezone {tz}: {op:?}")) }) .transpose() } @@ -487,7 +556,8 @@ mod tests { use arrow::array::types::TimestampNanosecondType; use arrow::array::{Array, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, TimeUnit}; + use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -726,13 +796,24 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, false).into(), + Field::new("b", input.data_type().clone(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from("day")), ColumnarValue::Array(Arc::new(input)), ], + arg_fields, number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + return_field: Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { @@ -874,6 +955,21 @@ mod tests { "2018-11-04T02:00:00-02", ], ), + ( + vec![ + "2024-10-26T23:30:00Z", + "2024-10-27T00:30:00Z", + "2024-10-27T01:30:00Z", + "2024-10-27T02:30:00Z", + ], + Some("Asia/Kathmandu".into()), // UTC+5:45 + vec![ + "2024-10-27T05:00:00+05:45", + "2024-10-27T06:00:00+05:45", + "2024-10-27T07:00:00+05:45", + "2024-10-27T08:00:00+05:45", + ], + ), ]; cases.iter().for_each(|(original, tz_opt, expected)| { @@ -888,13 +984,24 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, false).into(), + Field::new("b", input.data_type().clone(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from("hour")), ColumnarValue::Array(Arc::new(input)), ], + arg_fields, number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + return_field: Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index ed8181452dbd9..5d6adfb6f119a 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -18,20 +18,19 @@ use std::any::Any; use std::sync::Arc; -use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Int64, Timestamp, Utf8}; use arrow::datatypes::TimeUnit::Second; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, - Volatility, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; #[user_doc( doc_section(label = "Time and Date Functions"), - description = "Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp.", + description = "Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp.", syntax_example = "from_unixtime(expression[, timezone])", sql_example = r#"```sql > select from_unixtime(1599572549, 'America/New_York'); @@ -47,7 +46,7 @@ use datafusion_macros::user_doc; description = "Optional timezone to use when converting the integer to a timestamp. If not provided, the default timezone is UTC." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct FromUnixtimeFunc { signature: Signature, } @@ -82,12 +81,12 @@ impl ScalarUDFImpl for FromUnixtimeFunc { &self.signature } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // Length check handled in the signature debug_assert!(matches!(args.scalar_arguments.len(), 1 | 2)); if args.scalar_arguments.len() == 1 { - Ok(ReturnInfo::new_nullable(Timestamp(Second, None))) + Ok(Field::new(self.name(), Timestamp(Second, None), true).into()) } else { args.scalar_arguments[1] .and_then(|sv| { @@ -95,12 +94,14 @@ impl ScalarUDFImpl for FromUnixtimeFunc { .flatten() .filter(|s| !s.is_empty()) .map(|tz| { - ReturnInfo::new_nullable(Timestamp( - Second, - Some(Arc::from(tz.to_string())), - )) + Field::new( + self.name(), + Timestamp(Second, Some(Arc::from(tz.to_string()))), + true, + ) }) }) + .map(Arc::new) .map_or_else( || { exec_err!( @@ -114,7 +115,7 @@ impl ScalarUDFImpl for FromUnixtimeFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("call return_type_from_args instead") + internal_err!("call return_field_from_args instead") } fn invoke_with_args( @@ -132,7 +133,7 @@ impl ScalarUDFImpl for FromUnixtimeFunc { if args[0].data_type() != Int64 { return exec_err!( - "Unsupported data type {:?} for function from_unixtime", + "Unsupported data type {} for function from_unixtime", args[0].data_type() ); } @@ -144,7 +145,7 @@ impl ScalarUDFImpl for FromUnixtimeFunc { .cast_to(&Timestamp(Second, Some(Arc::from(tz.to_string()))), None), _ => { exec_err!( - "Unsupported data type {:?} for function from_unixtime", + "Unsupported data type {} for function from_unixtime", args[1].data_type() ) } @@ -161,8 +162,9 @@ impl ScalarUDFImpl for FromUnixtimeFunc { #[cfg(test)] mod test { use crate::datetime::from_unixtime::FromUnixtimeFunc; - use arrow::datatypes::DataType; use arrow::datatypes::TimeUnit::Second; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; use datafusion_common::ScalarValue::Int64; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -170,10 +172,13 @@ mod test { #[test] fn test_without_timezone() { + let arg_field = Arc::new(Field::new("a", DataType::Int64, true)); let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(Int64(Some(1729900800)))], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Timestamp(Second, None), + return_field: Field::new("f", DataType::Timestamp(Second, None), true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); @@ -187,6 +192,10 @@ mod test { #[test] fn test_with_timezone() { + let arg_fields = vec![ + Field::new("a", DataType::Int64, true).into(), + Field::new("a", DataType::Utf8, true).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(Int64(Some(1729900800))), @@ -194,11 +203,15 @@ mod test { "America/New_York".to_string(), ))), ], + arg_fields, number_rows: 2, - return_type: &DataType::Timestamp( - Second, - Some(Arc::from("America/New_York")), - ), + return_field: Field::new( + "f", + DataType::Timestamp(Second, Some(Arc::from("America/New_York"))), + true, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 929fa601f1076..0fe5d156a8383 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -51,7 +51,7 @@ use datafusion_macros::user_doc; +-----------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/make_date.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) "#, argument( name = "year", @@ -66,7 +66,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo description = "Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct MakeDateFunc { signature: Signature, } @@ -122,6 +122,13 @@ impl ScalarUDFImpl for MakeDateFunc { let [years, months, days] = take_function_args(self.name(), args)?; + if matches!(years, ColumnarValue::Scalar(ScalarValue::Null)) + || matches!(months, ColumnarValue::Scalar(ScalarValue::Null)) + || matches!(days, ColumnarValue::Scalar(ScalarValue::Null)) + { + return Ok(ColumnarValue::Scalar(ScalarValue::Null)); + } + let years = years.cast_to(&Int32, None)?; let months = months.cast_to(&Int32, None)?; let days = days.cast_to(&Int32, None)?; @@ -223,25 +230,41 @@ fn make_date_inner( mod tests { use crate::datetime::make_date::MakeDateFunc; use arrow::array::{Array, Date32Array, Int32Array, Int64Array, UInt32Array}; - use arrow::datatypes::DataType; - use datafusion_common::ScalarValue; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::config::ConfigOptions; + use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use std::sync::Arc; + fn invoke_make_date_with_args( + args: Vec, + number_rows: usize, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true).into()) + .collect::>(); + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Date32, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + MakeDateFunc::new().invoke_with_args(args) + } + #[test] fn test_make_date() { - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2024))), ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -249,18 +272,15 @@ mod tests { panic!("Expected a scalar value") } - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int64(Some(2024))), ColumnarValue::Scalar(ScalarValue::UInt64(Some(1))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -268,18 +288,15 @@ mod tests { panic!("Expected a scalar value") } - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some("2024".to_string()))), ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("1".to_string()))), ColumnarValue::Scalar(ScalarValue::Utf8(Some("14".to_string()))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -291,18 +308,15 @@ mod tests { let months = Arc::new((1..5).map(Some).collect::()); let days = Arc::new((11..15).map(Some).collect::()); let batch_len = years.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Array(years), ColumnarValue::Array(months), ColumnarValue::Array(days), ], - number_rows: batch_len, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + batch_len, + ) + .unwrap(); if let ColumnarValue::Array(array) = res { assert_eq!(array.len(), 4); @@ -321,63 +335,70 @@ mod tests { // // invalid number of arguments - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + let res = invoke_make_date_with_args( + vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: make_date function requires 3 arguments, got 1" ); // invalid type - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Casting from Interval(YearMonth) to Int32 not supported" ); // overflow of month - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), ColumnarValue::Scalar(ScalarValue::UInt64(Some(u64::MAX))), ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Can't cast value 18446744073709551615 to type Int32" ); // overflow of day - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(u32::MAX))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Can't cast value 4294967295 to type Int32" ); } + + #[test] + fn test_make_date_null_param() { + let res = invoke_make_date_with_args( + vec![ + ColumnarValue::Scalar(ScalarValue::Null), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), + ], + 1, + ) + .expect("that make_date parsed values without error"); + + assert!(matches!(res, ColumnarValue::Scalar(ScalarValue::Null))); + } } diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs index dee40215c9ea5..5729b1edae958 100644 --- a/datafusion/functions/src/datetime/mod.rs +++ b/datafusion/functions/src/datetime/mod.rs @@ -30,6 +30,7 @@ pub mod date_trunc; pub mod from_unixtime; pub mod make_date; pub mod now; +pub mod planner; pub mod to_char; pub mod to_date; pub mod to_local_time; diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index b26dc52cee4d6..65dadb42a89e1 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::DataType; use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; +use arrow::datatypes::{DataType, Field, FieldRef}; use std::any::Any; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, - Signature, Volatility, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -37,7 +37,7 @@ The `now()` return value is determined at query time and will return the same ti "#, syntax_example = "now()" )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct NowFunc { signature: Signature, aliases: Vec, @@ -77,15 +77,17 @@ impl ScalarUDFImpl for NowFunc { &self.signature } - fn return_type_from_args(&self, _args: ReturnTypeArgs) -> Result { - Ok(ReturnInfo::new_non_nullable(Timestamp( - Nanosecond, - Some("+00:00".into()), - ))) + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new( + self.name(), + Timestamp(Nanosecond, Some("+00:00".into())), + false, + ) + .into()) } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } fn invoke_with_args( @@ -106,6 +108,7 @@ impl ScalarUDFImpl for NowFunc { .timestamp_nanos_opt(); Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::TimestampNanosecond(now_ts, Some("+00:00".into())), + None, ))) } diff --git a/datafusion/functions/src/datetime/planner.rs b/datafusion/functions/src/datetime/planner.rs new file mode 100644 index 0000000000000..f4b64c3711e2c --- /dev/null +++ b/datafusion/functions/src/datetime/planner.rs @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SQL planning extensions like [`DatetimeFunctionPlanner`] +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::planner::{ExprPlanner, PlannerResult}; +use datafusion_expr::Expr; + +#[derive(Default, Debug)] +pub struct DatetimeFunctionPlanner; + +impl ExprPlanner for DatetimeFunctionPlanner { + fn plan_extract( + &self, + args: Vec, + ) -> datafusion_common::Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::datetime::date_part(), args), + ))) + } +} diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 8b2e5ad874717..7d9b2bc241e1a 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use arrow::array::cast::AsArray; use arrow::array::{new_null_array, Array, ArrayRef, StringArray}; +use arrow::compute::cast; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{ Date32, Date64, Duration, Time32, Time64, Timestamp, Utf8, @@ -27,7 +28,6 @@ use arrow::datatypes::DataType::{ use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::error::ArrowError; use arrow::util::display::{ArrayFormatter, DurationFormat, FormatOptions}; - use datafusion_common::{exec_err, utils::take_function_args, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ @@ -48,7 +48,7 @@ use datafusion_macros::user_doc; +----------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_char.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) "#, argument( name = "expression", @@ -63,7 +63,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo description = "Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ToCharFunc { signature: Signature, aliases: Vec, @@ -145,14 +145,14 @@ impl ScalarUDFImpl for ToCharFunc { match format { ColumnarValue::Scalar(ScalarValue::Utf8(None)) | ColumnarValue::Scalar(ScalarValue::Null) => { - _to_char_scalar(date_time.clone(), None) + to_char_scalar(date_time.clone(), None) } // constant format ColumnarValue::Scalar(ScalarValue::Utf8(Some(format))) => { // invoke to_char_scalar with the known string, without converting to array - _to_char_scalar(date_time.clone(), Some(format)) + to_char_scalar(date_time.clone(), Some(format)) } - ColumnarValue::Array(_) => _to_char_array(&args), + ColumnarValue::Array(_) => to_char_array(&args), _ => { exec_err!( "Format for `to_char` must be non-null Utf8, received {:?}", @@ -165,12 +165,13 @@ impl ScalarUDFImpl for ToCharFunc { fn aliases(&self) -> &[String] { &self.aliases } + fn documentation(&self) -> Option<&Documentation> { self.doc() } } -fn _build_format_options<'a>( +fn build_format_options<'a>( data_type: &DataType, format: Option<&'a str>, ) -> Result, Result> { @@ -178,7 +179,9 @@ fn _build_format_options<'a>( return Ok(FormatOptions::new()); }; let format_options = match data_type { - Date32 => FormatOptions::new().with_date_format(Some(format)), + Date32 => FormatOptions::new() + .with_date_format(Some(format)) + .with_datetime_format(Some(format)), Date64 => FormatOptions::new().with_datetime_format(Some(format)), Time32(_) => FormatOptions::new().with_time_format(Some(format)), Time64(_) => FormatOptions::new().with_time_format(Some(format)), @@ -202,7 +205,7 @@ fn _build_format_options<'a>( } /// Special version when arg\[1] is a scalar -fn _to_char_scalar( +fn to_char_scalar( expression: ColumnarValue, format: Option<&str>, ) -> Result { @@ -210,17 +213,17 @@ fn _to_char_scalar( // of the implementation in arrow-rs we need to convert it to an array let data_type = &expression.data_type(); let is_scalar_expression = matches!(&expression, ColumnarValue::Scalar(_)); - let array = expression.into_array(1)?; + let array = expression.clone().into_array(1)?; if format.is_none() { - if is_scalar_expression { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + return if is_scalar_expression { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) } else { - return Ok(ColumnarValue::Array(new_null_array(&Utf8, array.len()))); - } + Ok(ColumnarValue::Array(new_null_array(&Utf8, array.len()))) + }; } - let format_options = match _build_format_options(data_type, format) { + let format_options = match build_format_options(data_type, format) { Ok(value) => value, Err(value) => return value, }; @@ -247,11 +250,17 @@ fn _to_char_scalar( )) } } else { + // if the data type was a Date32, formatting could have failed because the format string + // contained datetime specifiers, so we'll retry by casting the date array as a timestamp array + if data_type == &Date32 { + return to_char_scalar(expression.clone().cast_to(&Date64, None)?, format); + } + exec_err!("{}", formatted.unwrap_err()) } } -fn _to_char_array(args: &[ColumnarValue]) -> Result { +fn to_char_array(args: &[ColumnarValue]) -> Result { let arrays = ColumnarValue::values_to_arrays(args)?; let mut results: Vec> = vec![]; let format_array = arrays[1].as_string::(); @@ -267,7 +276,7 @@ fn _to_char_array(args: &[ColumnarValue]) -> Result { results.push(None); continue; } - let format_options = match _build_format_options(data_type, format) { + let format_options = match build_format_options(data_type, format) { Ok(value) => value, Err(value) => return value, }; @@ -277,7 +286,25 @@ fn _to_char_array(args: &[ColumnarValue]) -> Result { let result = formatter.value(idx).try_to_string(); match result { Ok(value) => results.push(Some(value)), - Err(e) => return exec_err!("{}", e), + Err(e) => { + // if the data type was a Date32, formatting could have failed because the format string + // contained datetime specifiers, so we'll treat this specific date element as a timestamp + if data_type == &Date32 { + let failed_date_value = arrays[0].slice(idx, 1); + + match retry_date_as_timestamp(failed_date_value, &format_options) { + Ok(value) => { + results.push(Some(value)); + continue; + } + Err(e) => { + return exec_err!("{}", e); + } + } + } + + return exec_err!("{}", e); + } } } @@ -294,6 +321,19 @@ fn _to_char_array(args: &[ColumnarValue]) -> Result { } } +fn retry_date_as_timestamp( + array_ref: ArrayRef, + format_options: &FormatOptions, +) -> Result { + let target_data_type = Date64; + + let date_value = cast(&array_ref, &target_data_type)?; + let formatter = ArrayFormatter::try_new(date_value.as_ref(), format_options)?; + let result = formatter.value(0).try_to_string()?; + + Ok(result) +} + #[cfg(test)] mod tests { use crate::datetime::to_char::ToCharFunc; @@ -303,12 +343,52 @@ mod tests { TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field, TimeUnit}; use chrono::{NaiveDateTime, Timelike}; + use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use std::sync::Arc; + #[test] + fn test_array_array() { + let array_array_data = vec![( + Arc::new(Date32Array::from(vec![18506, 18507])) as ArrayRef, + StringArray::from(vec!["%Y::%m::%d", "%Y::%m::%d %S::%M::%H %f"]), + StringArray::from(vec!["2020::09::01", "2020::09::02 00::00::00 000000000"]), + )]; + + for (value, format, expected) in array_array_data { + let batch_len = value.len(); + let value_data_type = value.data_type().clone(); + let format_data_type = format.data_type().clone(); + + let args = datafusion_expr::ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(value), + ColumnarValue::Array(Arc::new(format) as ArrayRef), + ], + arg_fields: vec![ + Field::new("a", value_data_type, true).into(), + Field::new("b", format_data_type, true).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&Arc::new(ConfigOptions::default())), + }; + let result = ToCharFunc::new() + .invoke_with_args(args) + .expect("that to_char parsed values without error"); + + if let ColumnarValue::Array(result) = result { + assert_eq!(result.len(), 2); + assert_eq!(&expected as &dyn Array, result.as_ref()); + } else { + panic!("Expected an array value") + } + } + } + #[test] fn test_to_char() { let date = "2020-01-02T03:04:05" @@ -328,6 +408,11 @@ mod tests { ScalarValue::Utf8(Some("%Y::%m::%d".to_string())), "2020::09::01".to_string(), ), + ( + ScalarValue::Date32(Some(18506)), + ScalarValue::Utf8(Some("%Y::%m::%d %S::%M::%H %f".to_string())), + "2020::09::01 00::00::00 000000000".to_string(), + ), ( ScalarValue::Date64(Some(date.and_utc().timestamp_millis())), ScalarValue::Utf8(Some("%Y::%m::%d".to_string())), @@ -385,10 +470,16 @@ mod tests { ]; for (value, format, expected) in scalar_data { + let arg_fields = vec![ + Field::new("a", value.data_type(), false).into(), + Field::new("a", format.data_type(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)], + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -407,6 +498,11 @@ mod tests { StringArray::from(vec!["%Y::%m::%d".to_string()]), "2020::09::01".to_string(), ), + ( + ScalarValue::Date32(Some(18506)), + StringArray::from(vec!["%Y::%m::%d %S::%M::%H %f".to_string()]), + "2020::09::01 00::00::00 000000000".to_string(), + ), ( ScalarValue::Date64(Some(date.and_utc().timestamp_millis())), StringArray::from(vec!["%Y::%m::%d".to_string()]), @@ -465,13 +561,19 @@ mod tests { for (value, format, expected) in scalar_array_data { let batch_len = format.len(); + let arg_fields = vec![ + Field::new("a", value.data_type(), false).into(), + Field::new("a", format.data_type().to_owned(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], + arg_fields, number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -490,6 +592,14 @@ mod tests { ScalarValue::Utf8(Some("%Y::%m::%d".to_string())), StringArray::from(vec!["2020::09::01", "2020::09::02"]), ), + ( + Arc::new(Date32Array::from(vec![18506, 18507])) as ArrayRef, + ScalarValue::Utf8(Some("%Y::%m::%d %S::%M::%H %f".to_string())), + StringArray::from(vec![ + "2020::09::01 00::00::00 000000000", + "2020::09::02 00::00::00 000000000", + ]), + ), ( Arc::new(Date64Array::from(vec![ date.and_utc().timestamp_millis(), @@ -506,6 +616,25 @@ mod tests { StringArray::from(vec!["%Y::%m::%d", "%d::%m::%Y"]), StringArray::from(vec!["2020::09::01", "02::09::2020"]), ), + ( + Arc::new(Date32Array::from(vec![18506, 18507])) as ArrayRef, + StringArray::from(vec![ + "%Y::%m::%d %S::%M::%H %f", + "%Y::%m::%d %S::%M::%H %f", + ]), + StringArray::from(vec![ + "2020::09::01 00::00::00 000000000", + "2020::09::02 00::00::00 000000000", + ]), + ), + ( + Arc::new(Date32Array::from(vec![18506, 18507])) as ArrayRef, + StringArray::from(vec!["%Y::%m::%d", "%Y::%m::%d %S::%M::%H %f"]), + StringArray::from(vec![ + "2020::09::01", + "2020::09::02 00::00::00 000000000", + ]), + ), ( Arc::new(Date64Array::from(vec![ date.and_utc().timestamp_millis(), @@ -596,13 +725,19 @@ mod tests { for (value, format, expected) in array_scalar_data { let batch_len = value.len(); + let arg_fields = vec![ + Field::new("a", value.data_type().clone(), false).into(), + Field::new("a", format.data_type(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value as ArrayRef), ColumnarValue::Scalar(format), ], + arg_fields, number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -618,13 +753,19 @@ mod tests { for (value, format, expected) in array_array_data { let batch_len = value.len(); + let arg_fields = vec![ + Field::new("a", value.data_type().clone(), false).into(), + Field::new("a", format.data_type().clone(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], + arg_fields, number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -643,10 +784,13 @@ mod tests { // // invalid number of arguments + let arg_field = Field::new("a", DataType::Int32, true).into(); let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( @@ -655,13 +799,19 @@ mod tests { ); // invalid type + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("a", DataType::Timestamp(TimeUnit::Nanosecond, None), true).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index 91740b2c31c11..3840c8d8bbb94 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -39,7 +39,7 @@ Returns the corresponding date. Note: `to_date` returns Date32, which represents its values as the number of days since unix epoch(`1970-01-01`) stored as signed 32 bit value. The largest supported date value is `9999-12-31`.", syntax_example = "to_date('2017-05-31', '%Y-%m-%d')", sql_example = r#"```sql -> select to_date('2023-01-31'); +> select to_date('2023-01-31'); +-------------------------------+ | to_date(Utf8("2023-01-31")) | +-------------------------------+ @@ -53,7 +53,7 @@ Note: `to_date` returns Date32, which represents its values as the number of day +---------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) "#, standard_argument(name = "expression", prefix = "String"), argument( @@ -63,7 +63,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo an error will be returned." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ToDateFunc { signature: Signature, } @@ -150,7 +150,7 @@ impl ScalarUDFImpl for ToDateFunc { } Utf8View | LargeUtf8 | Utf8 => self.to_date(&args), other => { - exec_err!("Unsupported data type {:?} for function to_date", other) + exec_err!("Unsupported data type {} for function to_date", other) } } } @@ -162,14 +162,33 @@ impl ScalarUDFImpl for ToDateFunc { #[cfg(test)] mod tests { + use super::ToDateFunc; use arrow::array::{Array, Date32Array, GenericStringArray, StringViewArray}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use arrow::{compute::kernels::cast_utils::Parser, datatypes::Date32Type}; - use datafusion_common::ScalarValue; + use datafusion_common::config::ConfigOptions; + use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use std::sync::Arc; - use super::ToDateFunc; + fn invoke_to_date_with_args( + args: Vec, + number_rows: usize, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true).into()) + .collect::>(); + + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Date32, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + ToDateFunc::new().invoke_with_args(args) + } #[test] fn test_to_date_without_format() { @@ -208,12 +227,8 @@ mod tests { } fn test_scalar(sv: ScalarValue, tc: &TestCase) { - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(sv)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(sv)], 1); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -234,12 +249,10 @@ mod tests { { let date_array = A::from(vec![tc.date_str]); let batch_len = date_array.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Array(Arc::new(date_array))], - number_rows: batch_len, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = invoke_to_date_with_args( + vec![ColumnarValue::Array(Arc::new(date_array))], + batch_len, + ); match to_date_result { Ok(ColumnarValue::Array(a)) => { @@ -328,15 +341,13 @@ mod tests { fn test_scalar(sv: ScalarValue, tc: &TestCase) { let format_scalar = ScalarValue::Utf8(Some(tc.format_str.to_string())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Scalar(sv), ColumnarValue::Scalar(format_scalar), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -358,15 +369,13 @@ mod tests { let format_array = A::from(vec![tc.format_str]); let batch_len = date_array.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Array(Arc::new(date_array)), ColumnarValue::Array(Arc::new(format_array)), ], - number_rows: batch_len, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + batch_len, + ); match to_date_result { Ok(ColumnarValue::Array(a)) => { @@ -398,16 +407,14 @@ mod tests { let format1_scalar = ScalarValue::Utf8(Some("%Y-%m-%d".into())); let format2_scalar = ScalarValue::Utf8(Some("%Y/%m/%d".into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Scalar(formatted_date_scalar), ColumnarValue::Scalar(format1_scalar), ColumnarValue::Scalar(format2_scalar), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -431,19 +438,17 @@ mod tests { for date_str in test_cases { let formatted_date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(formatted_date_scalar)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = invoke_to_date_with_args( + vec![ColumnarValue::Scalar(formatted_date_scalar)], + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { let expected = Date32Type::parse_formatted("2020-09-08", "%Y-%m-%d"); assert_eq!(date_val, expected, "to_date created wrong value"); } - _ => panic!("Conversion of {} failed", date_str), + _ => panic!("Conversion of {date_str} failed"), } } } @@ -453,23 +458,18 @@ mod tests { let date_str = "20241231"; let date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(date_scalar)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(date_scalar)], 1); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { let expected = Date32Type::parse_formatted("2024-12-31", "%Y-%m-%d"); assert_eq!( date_val, expected, - "to_date created wrong value for {}", - date_str + "to_date created wrong value for {date_str}" ); } - _ => panic!("Conversion of {} failed", date_str), + _ => panic!("Conversion of {date_str} failed"), } } @@ -478,18 +478,11 @@ mod tests { let date_str = "202412311"; let date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(date_scalar)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(date_scalar)], 1); if let Ok(ColumnarValue::Scalar(ScalarValue::Date32(_))) = to_date_result { - panic!( - "Conversion of {} succeeded, but should have failed, ", - date_str - ); + panic!("Conversion of {date_str} succeeded, but should have failed. "); } } } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 8dbef90cdc3f3..a2a54398a33bf 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -31,7 +31,8 @@ use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc}; use datafusion_common::cast::as_primitive_array; use datafusion_common::{ - exec_err, plan_err, utils::take_function_args, DataFusionError, Result, ScalarValue, + exec_err, internal_datafusion_err, plan_err, utils::take_function_args, Result, + ScalarValue, }; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -96,7 +97,7 @@ FROM ( description = "Time expression to operate on. Can be a constant, column, or function." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ToLocalTimeFunc { signature: Signature, } @@ -326,15 +327,15 @@ fn adjust_to_local_time(ts: i64, tz: Tz) -> Result { // This should not fail under normal circumstances as the // maximum possible offset is 26 hours (93,600 seconds) TimeDelta::try_seconds(offset_seconds) - .ok_or(DataFusionError::Internal("Offset seconds should be less than i64::MAX / 1_000 or greater than -i64::MAX / 1_000".to_string()))?, + .ok_or_else(|| internal_datafusion_err!("Offset seconds should be less than i64::MAX / 1_000 or greater than -i64::MAX / 1_000"))?, ); // convert the naive datetime back to i64 match T::UNIT { - Nanosecond => adjusted_date_time.timestamp_nanos_opt().ok_or( - DataFusionError::Internal( - "Failed to convert DateTime to timestamp in nanosecond. This error may occur if the date is out of range. The supported date ranges are between 1677-09-21T00:12:43.145224192 and 2262-04-11T23:47:16.854775807".to_string(), - ), + Nanosecond => adjusted_date_time.timestamp_nanos_opt().ok_or_else(|| + internal_datafusion_err!( + "Failed to convert DateTime to timestamp in nanosecond. This error may occur if the date is out of range. The supported date ranges are between 1677-09-21T00:12:43.145224192 and 2262-04-11T23:47:16.854775807" + ) ), Microsecond => Ok(adjusted_date_time.timestamp_micros()), Millisecond => Ok(adjusted_date_time.timestamp_millis()), @@ -372,7 +373,7 @@ impl ScalarUDFImpl for ToLocalTimeFunc { ) -> Result { let [time_value] = take_function_args(self.name(), args.args)?; - self.to_local_time(&[time_value.clone()]) + self.to_local_time(std::slice::from_ref(&time_value)) } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { @@ -385,6 +386,7 @@ impl ScalarUDFImpl for ToLocalTimeFunc { let first_arg = arg_types[0].clone(); match &first_arg { + DataType::Null => Ok(vec![Timestamp(Nanosecond, None)]), Timestamp(Nanosecond, timezone) => { Ok(vec![Timestamp(Nanosecond, timezone.clone())]) } @@ -407,10 +409,11 @@ impl ScalarUDFImpl for ToLocalTimeFunc { mod tests { use std::sync::Arc; - use arrow::array::{types::TimestampNanosecondType, TimestampNanosecondArray}; + use arrow::array::{types::TimestampNanosecondType, Array, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, TimeUnit}; use chrono::NaiveDateTime; + use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; @@ -538,11 +541,14 @@ mod tests { } fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { + let arg_field = Field::new("a", input.data_type(), true).into(); let res = ToLocalTimeFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(input)], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &expected.data_type(), + return_field: Field::new("f", expected.data_type(), true).into(), + config_options: Arc::new(ConfigOptions::default()), }) .unwrap(); match res { @@ -602,10 +608,18 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::(); let batch_size = input.len(); + let arg_field = Field::new("a", input.data_type().clone(), true).into(); let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::new(input))], + arg_fields: vec![arg_field], number_rows: batch_size, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = ToLocalTimeFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 52c86733f3327..dcd52aa07be38 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -19,12 +19,14 @@ use std::any::Any; use std::sync::Arc; use crate::datetime::common::*; +use arrow::array::Float64Array; use arrow::datatypes::DataType::*; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{ ArrowTimestampType, DataType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; +use datafusion_common::format::DEFAULT_CAST_OPTIONS; use datafusion_common::{exec_err, Result, ScalarType, ScalarValue}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -53,7 +55,7 @@ Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for in | 2023-05-17T03:59:00.123456789 | +--------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) "#, argument( name = "expression", @@ -64,7 +66,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ToTimestampFunc { signature: Signature, } @@ -87,7 +89,7 @@ pub struct ToTimestampFunc { | 2023-05-17T03:59:00 | +----------------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) "#, argument( name = "expression", @@ -98,7 +100,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ToTimestampSecondsFunc { signature: Signature, } @@ -121,7 +123,7 @@ pub struct ToTimestampSecondsFunc { | 2023-05-17T03:59:00.123 | +---------------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) "#, argument( name = "expression", @@ -132,7 +134,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ToTimestampMillisFunc { signature: Signature, } @@ -155,7 +157,7 @@ pub struct ToTimestampMillisFunc { | 2023-05-17T03:59:00.123456 | +---------------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) "#, argument( name = "expression", @@ -166,7 +168,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ToTimestampMicrosFunc { signature: Signature, } @@ -189,7 +191,7 @@ pub struct ToTimestampMicrosFunc { | 2023-05-17T03:59:00.123456789 | +---------------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) "#, argument( name = "expression", @@ -200,7 +202,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ToTimestampNanosFunc { signature: Signature, } @@ -319,9 +321,22 @@ impl ScalarUDFImpl for ToTimestampFunc { Int32 | Int64 => args[0] .cast_to(&Timestamp(Second, None), None)? .cast_to(&Timestamp(Nanosecond, None), None), - Null | Float64 | Timestamp(_, None) => { + Null | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Nanosecond, None), None) } + Float64 => { + let rescaled = arrow::compute::kernels::numeric::mul( + &args[0].to_array(1)?, + &arrow::array::Scalar::new(Float64Array::from(vec![ + 1_000_000_000f64, + ])), + )?; + Ok(ColumnarValue::Array(arrow::compute::cast_with_options( + &rescaled, + &Timestamp(Nanosecond, None), + &DEFAULT_CAST_OPTIONS, + )?)) + } Timestamp(_, Some(tz)) => { args[0].cast_to(&Timestamp(Nanosecond, Some(tz)), None) } @@ -353,10 +368,7 @@ impl ScalarUDFImpl for ToTimestampFunc { } } other => { - exec_err!( - "Unsupported data type {:?} for function to_timestamp", - other - ) + exec_err!("Unsupported data type {other} for function to_timestamp") } } } @@ -409,7 +421,7 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { } other => { exec_err!( - "Unsupported data type {:?} for function to_timestamp_seconds", + "Unsupported data type {} for function to_timestamp_seconds", other ) } @@ -467,7 +479,7 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { ), other => { exec_err!( - "Unsupported data type {:?} for function to_timestamp_millis", + "Unsupported data type {} for function to_timestamp_millis", other ) } @@ -525,7 +537,7 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { ), other => { exec_err!( - "Unsupported data type {:?} for function to_timestamp_micros", + "Unsupported data type {} for function to_timestamp_micros", other ) } @@ -582,7 +594,7 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { } other => { exec_err!( - "Unsupported data type {:?} for function to_timestamp_nanos", + "Unsupported data type {} for function to_timestamp_nanos", other ) } @@ -639,8 +651,9 @@ mod tests { TimestampNanosecondArray, TimestampSecondArray, }; use arrow::array::{ArrayRef, Int64Array, StringBuilder}; - use arrow::datatypes::TimeUnit; + use arrow::datatypes::{Field, TimeUnit}; use chrono::Utc; + use datafusion_common::config::ConfigOptions; use datafusion_common::{assert_contains, DataFusionError, ScalarValue}; use datafusion_expr::ScalarFunctionImplementation; @@ -788,7 +801,7 @@ mod tests { } #[test] - fn to_timestamp_with_unparseable_data() -> Result<()> { + fn to_timestamp_with_unparsable_data() -> Result<()> { let mut date_string_builder = StringBuilder::with_capacity(2, 1024); date_string_builder.append_null(); @@ -940,7 +953,7 @@ mod tests { let expected = format!("Execution error: Error parsing timestamp from '{s}' using format '{f}': {ctx}"); let actual = string_to_datetime_formatted(&Utc, s, f) .unwrap_err() - .to_string(); + .strip_backtrace(); assert_eq!(actual, expected) } } @@ -968,7 +981,7 @@ mod tests { let expected = format!("Execution error: Error parsing timestamp from '{s}' using format '{f}': {ctx}"); let actual = string_to_datetime_formatted(&Utc, s, f) .unwrap_err() - .to_string(); + .strip_backtrace(); assert_eq!(actual, expected) } } @@ -1012,11 +1025,14 @@ mod tests { for udf in &udfs { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); + let arg_field = Field::new("arg", array.data_type().clone(), true).into(); assert!(matches!(rt, Timestamp(_, Some(_)))); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], + arg_fields: vec![arg_field], number_rows: 4, - return_type: &rt, + return_field: Field::new("f", rt, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let res = udf .invoke_with_args(args) @@ -1060,10 +1076,13 @@ mod tests { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); assert!(matches!(rt, Timestamp(_, None))); + let arg_field = Field::new("arg", array.data_type().clone(), true).into(); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], + arg_fields: vec![arg_field], number_rows: 5, - return_type: &rt, + return_field: Field::new("f", rt, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let res = udf .invoke_with_args(args) diff --git a/datafusion/functions/src/datetime/to_unixtime.rs b/datafusion/functions/src/datetime/to_unixtime.rs index 653ec10851695..42651cd537162 100644 --- a/datafusion/functions/src/datetime/to_unixtime.rs +++ b/datafusion/functions/src/datetime/to_unixtime.rs @@ -54,7 +54,7 @@ use std::any::Any; description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ToUnixtimeFunc { signature: Signature, } @@ -118,7 +118,7 @@ impl ScalarUDFImpl for ToUnixtimeFunc { .invoke_with_args(args)? .cast_to(&DataType::Int64, None), other => { - exec_err!("Unsupported data type {:?} for function to_unixtime", other) + exec_err!("Unsupported data type {} for function to_unixtime", other) } } } diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index 51e8c6968866a..5baa91936320d 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -30,7 +30,7 @@ use datafusion_common::{ not_impl_err, plan_err, utils::take_function_args, }; -use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::{exec_err, internal_datafusion_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ColumnarValue, Documentation}; use std::sync::Arc; @@ -54,7 +54,7 @@ use std::any::Any; ), related_udf(name = "decode") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct EncodeFunc { signature: Signature, } @@ -147,7 +147,7 @@ impl ScalarUDFImpl for EncodeFunc { argument(name = "format", description = "Same arguments as [encode](#encode)"), related_udf(name = "encode") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct DecodeFunc { signature: Signature, } @@ -309,18 +309,15 @@ fn hex_decode(input: &[u8], buf: &mut [u8]) -> Result { // only write input / 2 bytes to buf let out_len = input.len() / 2; let buf = &mut buf[..out_len]; - hex::decode_to_slice(input, buf).map_err(|e| { - DataFusionError::Internal(format!("Failed to decode from hex: {}", e)) - })?; + hex::decode_to_slice(input, buf) + .map_err(|e| internal_datafusion_err!("Failed to decode from hex: {e}"))?; Ok(out_len) } fn base64_decode(input: &[u8], buf: &mut [u8]) -> Result { general_purpose::STANDARD_NO_PAD .decode_slice(input, buf) - .map_err(|e| { - DataFusionError::Internal(format!("Failed to decode from base64: {}", e)) - }) + .map_err(|e| internal_datafusion_err!("Failed to decode from base64: {e}")) } macro_rules! encode_to_array { @@ -418,17 +415,13 @@ impl Encoding { general_purpose::STANDARD_NO_PAD .decode(value) .map_err(|e| { - DataFusionError::Internal(format!( - "Failed to decode value using base64: {}", - e - )) + internal_datafusion_err!( + "Failed to decode value using base64: {e}" + ) })? } Self::Hex => hex::decode(value).map_err(|e| { - DataFusionError::Internal(format!( - "Failed to decode value using hex: {}", - e - )) + internal_datafusion_err!("Failed to decode value using hex: {e}") })?, }; @@ -446,17 +439,13 @@ impl Encoding { general_purpose::STANDARD_NO_PAD .decode(value) .map_err(|e| { - DataFusionError::Internal(format!( - "Failed to decode value using base64: {}", - e - )) + internal_datafusion_err!( + "Failed to decode value using base64: {e}" + ) })? } Self::Hex => hex::decode(value).map_err(|e| { - DataFusionError::Internal(format!( - "Failed to decode value using hex: {}", - e - )) + internal_datafusion_err!("Failed to decode value using hex: {e}") })?, }; diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index b65c4c5432427..7eb32b7ed795b 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -191,6 +191,13 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { Ok(()) } +#[cfg(test)] +#[ctor::ctor] +fn init() { + // Enable RUST_LOG logging configuration for test + let _ = env_logger::try_init(); +} + #[cfg(test)] mod tests { use crate::all_default_functions; @@ -209,8 +216,7 @@ mod tests { for alias in func.aliases() { assert!( names.insert(alias.to_string().to_lowercase()), - "duplicate function name: {}", - alias + "duplicate function name: {alias}" ); } } diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index d2849c3abba0d..228d704e29cb5 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -40,11 +40,12 @@ /// Exported functions accept: /// - `Vec` argument (single argument followed by a comma) /// - Variable number of `Expr` arguments (zero or more arguments, must be without commas) +#[macro_export] macro_rules! export_functions { ($(($FUNC:ident, $DOC:expr, $($arg:tt)*)),*) => { $( // switch to single-function cases below - export_functions!(single $FUNC, $DOC, $($arg)*); + $crate::export_functions!(single $FUNC, $DOC, $($arg)*); )* }; @@ -69,8 +70,10 @@ macro_rules! export_functions { /// named `$NAME` which returns that singleton. /// /// This is used to ensure creating the list of `ScalarUDF` only happens once. +#[macro_export] macro_rules! make_udf_function { ($UDF:ty, $NAME:ident) => { + #[allow(rustdoc::redundant_explicit_links)] #[doc = concat!("Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of ", stringify!($NAME))] pub fn $NAME() -> std::sync::Arc { // Singleton instance of the function @@ -119,7 +122,7 @@ macro_rules! make_stub_package { macro_rules! downcast_named_arg { ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - internal_datafusion_err!( + datafusion_common::internal_datafusion_err!( "could not cast {} to {}", $NAME, std::any::type_name::<$ARRAY_TYPE>() @@ -136,7 +139,7 @@ macro_rules! downcast_named_arg { #[macro_export] macro_rules! downcast_arg { ($ARG:expr, $ARRAY_TYPE:ident) => {{ - downcast_named_arg!($ARG, "", $ARRAY_TYPE) + $crate::downcast_named_arg!($ARG, "", $ARRAY_TYPE) }}; } @@ -152,7 +155,7 @@ macro_rules! downcast_arg { /// $GET_DOC: the function to get the documentation of the UDF macro_rules! make_math_unary_udf { ($UDF:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr) => { - make_udf_function!($NAME::$UDF, $NAME); + $crate::make_udf_function!($NAME::$UDF, $NAME); mod $NAME { use std::any::Any; @@ -168,7 +171,7 @@ macro_rules! make_math_unary_udf { Signature, Volatility, }; - #[derive(Debug)] + #[derive(Debug, PartialEq, Eq, Hash)] pub struct $UDF { signature: Signature, } @@ -266,7 +269,7 @@ macro_rules! make_math_unary_udf { /// $GET_DOC: the function to get the documentation of the UDF macro_rules! make_math_binary_udf { ($UDF:ident, $NAME:ident, $BINARY_FUNC:ident, $OUTPUT_ORDERING:expr, $GET_DOC:expr) => { - make_udf_function!($NAME::$UDF, $NAME); + $crate::make_udf_function!($NAME::$UDF, $NAME); mod $NAME { use std::any::Any; @@ -282,7 +285,7 @@ macro_rules! make_math_binary_udf { Signature, Volatility, }; - #[derive(Debug)] + #[derive(Debug, PartialEq, Eq, Hash)] pub struct $UDF { signature: Signature, } diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 0c686a59016ac..040f13c014493 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -21,14 +21,12 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ - ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, + ArrayRef, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, + Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, }; use arrow::datatypes::DataType; use arrow::error::ArrowError; -use datafusion_common::{ - internal_datafusion_err, not_impl_err, utils::take_function_args, Result, -}; +use datafusion_common::{not_impl_err, utils::take_function_args, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ @@ -100,6 +98,8 @@ fn create_abs_function(input_data_type: &DataType) -> Result | DataType::UInt64 => Ok(|input: &ArrayRef| Ok(Arc::clone(input))), // Decimal types + DataType::Decimal32(_, _) => Ok(make_decimal_abs_function!(Decimal32Array)), + DataType::Decimal64(_, _) => Ok(make_decimal_abs_function!(Decimal64Array)), DataType::Decimal128(_, _) => Ok(make_decimal_abs_function!(Decimal128Array)), DataType::Decimal256(_, _) => Ok(make_decimal_abs_function!(Decimal256Array)), @@ -110,9 +110,17 @@ fn create_abs_function(input_data_type: &DataType) -> Result doc_section(label = "Math Functions"), description = "Returns the absolute value of a number.", syntax_example = "abs(numeric_expression)", + sql_example = r#"```sql +> SELECT abs(-5); ++----------+ +| abs(-5) | ++----------+ +| 5 | ++----------+ +```"#, standard_argument(name = "numeric_expression", prefix = "Numeric") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct AbsFunc { signature: Signature, } @@ -156,6 +164,12 @@ impl ScalarUDFImpl for AbsFunc { DataType::UInt16 => Ok(DataType::UInt16), DataType::UInt32 => Ok(DataType::UInt32), DataType::UInt64 => Ok(DataType::UInt64), + DataType::Decimal32(precision, scale) => { + Ok(DataType::Decimal32(precision, scale)) + } + DataType::Decimal64(precision, scale) => { + Ok(DataType::Decimal64(precision, scale)) + } DataType::Decimal128(precision, scale) => { Ok(DataType::Decimal128(precision, scale)) } diff --git a/datafusion/functions/src/math/cot.rs b/datafusion/functions/src/math/cot.rs index 4e56212ddbee8..43f2012d073dd 100644 --- a/datafusion/functions/src/math/cot.rs +++ b/datafusion/functions/src/math/cot.rs @@ -32,9 +32,17 @@ use datafusion_macros::user_doc; doc_section(label = "Math Functions"), description = "Returns the cotangent of a number.", syntax_example = r#"cot(numeric_expression)"#, + sql_example = r#"```sql +> SELECT cot(1); ++---------+ +| cot(1) | ++---------+ +| 0.64209 | ++---------+ +```"#, standard_argument(name = "numeric_expression", prefix = "Numeric") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct CotFunc { signature: Signature, } diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index c2ac21b78f212..79f6da94dd0e1 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -26,9 +26,7 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use crate::utils::make_scalar_function; -use datafusion_common::{ - arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result, -}; +use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -39,9 +37,17 @@ use datafusion_macros::user_doc; doc_section(label = "Math Functions"), description = "Factorial. Returns 1 if value is less than 2.", syntax_example = "factorial(numeric_expression)", + sql_example = r#"```sql +> SELECT factorial(5); ++---------------+ +| factorial(5) | ++---------------+ +| 120 | ++---------------+ +```"#, standard_argument(name = "numeric_expression", prefix = "Numeric") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct FactorialFunc { signature: Signature, } diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index 7fe253b4afbc0..0b85e7b54a782 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -34,10 +34,18 @@ use datafusion_macros::user_doc; doc_section(label = "Math Functions"), description = "Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero.", syntax_example = "gcd(expression_x, expression_y)", + sql_example = r#"```sql +> SELECT gcd(48, 18); ++------------+ +| gcd(48,18) | ++------------+ +| 6 | ++------------+ +```"#, standard_argument(name = "expression_x", prefix = "First numeric"), standard_argument(name = "expression_y", prefix = "Second numeric") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct GcdFunc { signature: Signature, } diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs index bc12dfb7898e8..68cd3aca28fdc 100644 --- a/datafusion/functions/src/math/iszero.rs +++ b/datafusion/functions/src/math/iszero.rs @@ -36,9 +36,17 @@ use crate::utils::make_scalar_function; doc_section(label = "Math Functions"), description = "Returns true if a given number is +0.0 or -0.0 otherwise returns false.", syntax_example = "iszero(numeric_expression)", + sql_example = r#"```sql +> SELECT iszero(0); ++------------+ +| iszero(0) | ++------------+ +| true | ++------------+ +```"#, standard_argument(name = "numeric_expression", prefix = "Numeric") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct IsZeroFunc { signature: Signature, } diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs index fc6bf9461f283..bfb20dfd5ce41 100644 --- a/datafusion/functions/src/math/lcm.rs +++ b/datafusion/functions/src/math/lcm.rs @@ -23,9 +23,7 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use arrow::error::ArrowError; -use datafusion_common::{ - arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result, -}; +use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -39,10 +37,18 @@ use crate::utils::make_scalar_function; doc_section(label = "Math Functions"), description = "Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero.", syntax_example = "lcm(expression_x, expression_y)", + sql_example = r#"```sql +> SELECT lcm(4, 5); ++----------+ +| lcm(4,5) | ++----------+ +| 20 | ++----------+ +```"#, standard_argument(name = "expression_x", prefix = "First numeric"), standard_argument(name = "expression_y", prefix = "Second numeric") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct LcmFunc { signature: Signature, } diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index fd135f4c5ec02..ff1fd0cd4b37a 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -22,8 +22,14 @@ use std::sync::Arc; use super::power::PowerFunc; -use arrow::array::{ArrayRef, AsArray}; -use arrow::datatypes::{DataType, Float32Type, Float64Type}; +use crate::utils::{calculate_binary_math, decimal128_to_i128}; +use arrow::array::{Array, ArrayRef}; +use arrow::datatypes::{ + DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int32Type, + Int64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, +}; +use arrow::error::ArrowError; +use arrow_buffer::i256; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue, }; @@ -42,10 +48,18 @@ use datafusion_macros::user_doc; description = "Returns the base-x logarithm of a number. Can either provide a specified base, or if omitted then takes the base-10 of a number.", syntax_example = r#"log(base, numeric_expression) log(numeric_expression)"#, + sql_example = r#"```sql +> SELECT log(10); ++---------+ +| log(10) | ++---------+ +| 1.0 | ++---------+ +```"#, standard_argument(name = "base", prefix = "Base numeric"), standard_argument(name = "numeric_expression", prefix = "Numeric") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct LogFunc { signature: Signature, } @@ -58,14 +72,37 @@ impl Default for LogFunc { impl LogFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( vec![ - Exact(vec![Float32]), - Exact(vec![Float64]), - Exact(vec![Float32, Float32]), - Exact(vec![Float64, Float64]), + Numeric(1), + Numeric(2), + Exact(vec![DataType::Float32, DataType::Float32]), + Exact(vec![DataType::Float64, DataType::Float64]), + Exact(vec![ + DataType::Int64, + DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), + ]), + Exact(vec![ + DataType::Float32, + DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), + ]), + Exact(vec![ + DataType::Float64, + DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), + ]), + Exact(vec![ + DataType::Int64, + DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0), + ]), + Exact(vec![ + DataType::Float32, + DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0), + ]), + Exact(vec![ + DataType::Float64, + DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0), + ]), ], Volatility::Immutable, ), @@ -73,6 +110,41 @@ impl LogFunc { } } +/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base +/// Returns error if base is invalid +fn log_decimal128(value: i128, scale: i8, base: f64) -> Result { + if !base.is_finite() || base.trunc() != base { + return Err(ArrowError::ComputeError(format!( + "Log cannot use non-integer base: {base}" + ))); + } + if (base as u32) < 2 { + return Err(ArrowError::ComputeError(format!( + "Log base must be greater than 1: {base}" + ))); + } + + let unscaled_value = decimal128_to_i128(value, scale)?; + if unscaled_value > 0 { + let log_value: u32 = unscaled_value.ilog(base as i128); + Ok(log_value as f64) + } else { + // Reflect f64::log behaviour + Ok(f64::NAN) + } +} + +/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base +/// Returns error if base is invalid or if value is out of bounds of Decimal128 +fn log_decimal256(value: i256, scale: i8, base: f64) -> Result { + match value.to_i128() { + Some(value) => log_decimal128(value, scale, base), + None => Err(ArrowError::NotYetImplemented(format!( + "Log of Decimal256 larger than Decimal128 is not yet supported: {value}" + ))), + } +} + impl ScalarUDFImpl for LogFunc { fn as_any(&self) -> &dyn Any { self @@ -86,7 +158,8 @@ impl ScalarUDFImpl for LogFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match &arg_types[0] { + // Check last argument (value) + match &arg_types.last().ok_or(plan_datafusion_err!("No args"))? { DataType::Float32 => Ok(DataType::Float32), _ => Ok(DataType::Float64), } @@ -121,55 +194,68 @@ impl ScalarUDFImpl for LogFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = ColumnarValue::values_to_arrays(&args.args)?; - let mut base = ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))); - - let mut x = &args[0]; - if args.len() == 2 { - x = &args[1]; - base = ColumnarValue::Array(Arc::clone(&args[0])); - } - // note in f64::log params order is different than in sql. e.g in sql log(base, x) == f64::log(x, base) - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(x.as_primitive::().unary::<_, Float64Type>( - |value: f64| f64::log(value, base as f64), - )) - } - ColumnarValue::Array(base) => { - let x = x.as_primitive::(); - let base = base.as_primitive::(); - let result = arrow::compute::binary::<_, _, _, Float64Type>( - x, - base, - f64::log, - )?; - Arc::new(result) as _ - } - _ => { - return exec_err!("log function requires a scalar or array for base") - } - }, - - DataType::Float32 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => Arc::new( - x.as_primitive::() - .unary::<_, Float32Type>(|value: f32| f32::log(value, base)), + let (base, value) = if args.len() == 2 { + // note in f64::log params order is different than in sql. e.g in sql log(base, x) == f64::log(x, base) + (ColumnarValue::Array(Arc::clone(&args[0])), &args[1]) + } else { + // log(num) - assume base is 10 + let ret_type = if args[0].data_type().is_null() { + &DataType::Float64 + } else { + args[0].data_type() + }; + ( + ColumnarValue::Array( + ScalarValue::new_ten(ret_type)?.to_array_of_size(args[0].len())?, ), - ColumnarValue::Array(base) => { - let x = x.as_primitive::(); - let base = base.as_primitive::(); - let result = arrow::compute::binary::<_, _, _, Float32Type>( - x, - base, - f32::log, - )?; - Arc::new(result) as _ - } - _ => { - return exec_err!("log function requires a scalar or array for base") - } - }, + &args[0], + ) + }; + + // All log functors have format 'log(value, base)' + // Therefore, for `calculate_binary_math` the first type means a type of main array + // The second type is the type of the base array (even if derived from main) + let arr: ArrayRef = match value.data_type() { + DataType::Float32 => calculate_binary_math::< + Float32Type, + Float32Type, + Float32Type, + _, + >(value, &base, |x, b| Ok(f32::log(x, b)))?, + DataType::Float64 => calculate_binary_math::< + Float64Type, + Float64Type, + Float64Type, + _, + >(value, &base, |x, b| Ok(f64::log(x, b)))?, + DataType::Int32 => { + calculate_binary_math::( + value, + &base, + |x, b| Ok(f64::log(x as f64, b)), + )? + } + DataType::Int64 => { + calculate_binary_math::( + value, + &base, + |x, b| Ok(f64::log(x as f64, b)), + )? + } + DataType::Decimal128(_precision, scale) => { + calculate_binary_math::( + value, + &base, + |x, b| log_decimal128(x, *scale, b), + )? + } + DataType::Decimal256(_precision, scale) => { + calculate_binary_math::( + value, + &base, + |x, b| log_decimal256(x, *scale, b), + )? + } other => { return exec_err!("Unsupported data type {other:?} for function log") } @@ -210,7 +296,9 @@ impl ScalarUDFImpl for LogFunc { }; match number { - Expr::Literal(value) if value == ScalarValue::new_one(&number_datatype)? => { + Expr::Literal(value, _) + if value == ScalarValue::new_one(&number_datatype)? => + { Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero( &info.get_data_type(&base)?, )?))) @@ -254,37 +342,54 @@ mod tests { use super::*; - use arrow::array::{Float32Array, Float64Array, Int64Array}; + use arrow::array::{ + Date32Array, Decimal128Array, Decimal256Array, Float32Array, Float64Array, + }; use arrow::compute::SortOptions; + use arrow::datatypes::{Field, DECIMAL256_MAX_PRECISION}; use datafusion_common::cast::{as_float32_array, as_float64_array}; + use datafusion_common::config::ConfigOptions; use datafusion_common::DFSchema; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::simplify::SimplifyContext; #[test] - #[should_panic] fn test_log_invalid_base_type() { + let arg_fields = vec![ + Field::new("b", DataType::Date32, false).into(), + Field::new("n", DataType::Float64, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ + ColumnarValue::Array(Arc::new(Date32Array::from(vec![5, 10, 15, 20]))), // base ColumnarValue::Array(Arc::new(Float64Array::from(vec![ 10.0, 100.0, 1000.0, 10000.0, ]))), // num - ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), ], + arg_fields, number_rows: 4, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; - let _ = LogFunc::new().invoke_with_args(args); + let result = LogFunc::new().invoke_with_args(args); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string().lines().next().unwrap(), + "Arrow error: Cast error: Casting from Date32 to Float64 not supported" + ); } #[test] fn test_log_invalid_value() { + let arg_field = Field::new("a", DataType::Date32, false).into(); let args = ScalarFunctionArgs { args: vec![ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num + ColumnarValue::Array(Arc::new(Date32Array::from(vec![10]))), // num ], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = LogFunc::new().invoke_with_args(args); @@ -293,12 +398,15 @@ mod tests { #[test] fn test_log_scalar_f32_unary() { + let arg_field = Field::new("a", DataType::Float32, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num ], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = LogFunc::new() .invoke_with_args(args) @@ -320,12 +428,15 @@ mod tests { #[test] fn test_log_scalar_f64_unary() { + let arg_field = Field::new("a", DataType::Float64, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num ], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = LogFunc::new() .invoke_with_args(args) @@ -347,13 +458,19 @@ mod tests { #[test] fn test_log_scalar_f32() { + let arg_fields = vec![ + Field::new("a", DataType::Float32, false).into(), + Field::new("a", DataType::Float32, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ - ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num + ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // base ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num ], + arg_fields, number_rows: 1, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = LogFunc::new() .invoke_with_args(args) @@ -375,13 +492,19 @@ mod tests { #[test] fn test_log_scalar_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("a", DataType::Float64, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ - ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num + ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // base ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num ], + arg_fields, number_rows: 1, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = LogFunc::new() .invoke_with_args(args) @@ -403,14 +526,17 @@ mod tests { #[test] fn test_log_f64_unary() { + let arg_field = Field::new("a", DataType::Float64, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], + arg_fields: vec![arg_field], number_rows: 4, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = LogFunc::new() .invoke_with_args(args) @@ -435,14 +561,17 @@ mod tests { #[test] fn test_log_f32_unary() { + let arg_field = Field::new("a", DataType::Float32, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float32Array::from(vec![ 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], + arg_fields: vec![arg_field], number_rows: 4, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = LogFunc::new() .invoke_with_args(args) @@ -467,17 +596,23 @@ mod tests { #[test] fn test_log_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("a", DataType::Float64, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 2.0, 2.0, 3.0, 5.0, + 2.0, 2.0, 3.0, 5.0, 5.0, ]))), // base ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 8.0, 4.0, 81.0, 625.0, + 8.0, 4.0, 81.0, 625.0, -123.0, ]))), // num ], - number_rows: 4, - return_type: &DataType::Float64, + arg_fields, + number_rows: 5, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = LogFunc::new() .invoke_with_args(args) @@ -488,11 +623,12 @@ mod tests { let floats = as_float64_array(&arr) .expect("failed to convert result to a Float64Array"); - assert_eq!(floats.len(), 4); + assert_eq!(floats.len(), 5); assert!((floats.value(0) - 3.0).abs() < 1e-10); assert!((floats.value(1) - 2.0).abs() < 1e-10); assert!((floats.value(2) - 4.0).abs() < 1e-10); assert!((floats.value(3) - 4.0).abs() < 1e-10); + assert!(floats.value(4).is_nan()); } ColumnarValue::Scalar(_) => { panic!("Expected an array value") @@ -502,6 +638,10 @@ mod tests { #[test] fn test_log_f32() { + let arg_fields = vec![ + Field::new("a", DataType::Float32, false).into(), + Field::new("a", DataType::Float32, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float32Array::from(vec![ @@ -511,8 +651,10 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ], + arg_fields, number_rows: 4, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = LogFunc::new() .invoke_with_args(args) @@ -599,7 +741,7 @@ mod tests { // Test log(num) for order in orders.iter().cloned() { - let result = log.output_ordering(&[order.clone()]).unwrap(); + let result = log.output_ordering(std::slice::from_ref(&order)).unwrap(); assert_eq!(result, order.sort_properties); } @@ -682,4 +824,288 @@ mod tests { SortProperties::Unordered ); } + + #[test] + fn test_log_scalar_decimal128_unary() { + let arg_field = Field::new("a", DataType::Decimal128(38, 0), false).into(); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(10), 38, 0)), // num + ], + arg_fields: vec![arg_field], + number_rows: 1, + return_field: Field::new("f", DataType::Decimal128(38, 0), true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Decimal128Array"); + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_scalar_decimal128() { + let arg_fields = vec![ + Field::new("b", DataType::Float64, false).into(), + Field::new("x", DataType::Decimal128(38, 0), false).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // base + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(64), 38, 0)), // num + ], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 6.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_decimal128_unary() { + let arg_field = Field::new("a", DataType::Decimal128(38, 0), false).into(); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new( + Decimal128Array::from(vec![10, 100, 1000, 10000, 12600, -123]) + .with_precision_and_scale(38, 0) + .unwrap(), + )), // num + ], + arg_fields: vec![arg_field], + number_rows: 6, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 6); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + assert!((floats.value(1) - 2.0).abs() < 1e-10); + assert!((floats.value(2) - 3.0).abs() < 1e-10); + assert!((floats.value(3) - 4.0).abs() < 1e-10); + assert!((floats.value(4) - 4.0).abs() < 1e-10); // Integer rounding + assert!(floats.value(5).is_nan()); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_decimal128_base_decimal() { + // Base stays 2 despite scaling + for base in [ + ScalarValue::Decimal128(Some(i128::from(2)), 38, 0), + ScalarValue::Decimal128(Some(i128::from(2000)), 38, 3), + ] { + let arg_fields = vec![ + Field::new("b", DataType::Decimal128(38, 0), false).into(), + Field::new("x", DataType::Decimal128(38, 0), false).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(base), // base + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(64), 38, 0)), // num + ], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 6.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + } + + #[test] + fn test_log_decimal128_value_scale() { + // Value stays 1000 despite scaling + for value in [ + ScalarValue::Decimal128(Some(i128::from(1000)), 38, 0), + ScalarValue::Decimal128(Some(i128::from(10000)), 38, 1), + ScalarValue::Decimal128(Some(i128::from(1000000)), 38, 3), + ] { + let arg_fields = vec![ + Field::new("b", DataType::Decimal128(38, 0), false).into(), + Field::new("x", DataType::Decimal128(38, 0), false).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(value), // base + ], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 3.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + } + + #[test] + fn test_log_decimal256_unary() { + let arg_field = Field::new( + "a", + DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0), + false, + ) + .into(); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new( + Decimal256Array::from(vec![ + Some(i256::from(10)), + Some(i256::from(100)), + Some(i256::from(1000)), + Some(i256::from(10000)), + Some(i256::from(12600)), + // Slightly lower than i128 max - can calculate + Some(i256::from_i128(i128::MAX) - i256::from(1000)), + // Give NaN for incorrect inputs, as in f64::log + Some(i256::from(-123)), + ]) + .with_precision_and_scale(DECIMAL256_MAX_PRECISION, 0) + .unwrap(), + )), // num + ], + arg_fields: vec![arg_field], + number_rows: 7, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 7); + eprintln!("floats {:?}", &floats); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + assert!((floats.value(1) - 2.0).abs() < 1e-10); + assert!((floats.value(2) - 3.0).abs() < 1e-10); + assert!((floats.value(3) - 4.0).abs() < 1e-10); + assert!((floats.value(4) - 4.0).abs() < 1e-10); // Integer rounding for float log + assert!((floats.value(5) - 38.0).abs() < 1e-10); + assert!(floats.value(6).is_nan()); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_decimal128_wrong_base() { + let arg_fields = vec![ + Field::new("b", DataType::Float64, false).into(), + Field::new("x", DataType::Decimal128(38, 0), false).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(-2.0))), // base + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(64), 38, 0)), // num + ], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new().invoke_with_args(args); + assert!(result.is_err()); + assert_eq!( + "Arrow error: Compute error: Log base must be greater than 1: -2", + result.unwrap_err().to_string().lines().next().unwrap() + ); + } + + #[test] + fn test_log_decimal256_error() { + let arg_field = Field::new("a", DataType::Decimal256(38, 0), false).into(); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Decimal256Array::from(vec![ + // Slightly larger than i128 + Some(i256::from_i128(i128::MAX) + i256::from(1000)), + ]))), // num + ], + arg_fields: vec![arg_field], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new().invoke_with_args(args); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string().lines().next().unwrap(), + "Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported: 170141183460469231731687303715884106727" + ); + } } diff --git a/datafusion/functions/src/math/monotonicity.rs b/datafusion/functions/src/math/monotonicity.rs index baa3147f6258d..5b8252137be11 100644 --- a/datafusion/functions/src/math/monotonicity.rs +++ b/datafusion/functions/src/math/monotonicity.rs @@ -18,8 +18,8 @@ use std::sync::LazyLock; use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_doc::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::interval_arithmetic::Interval; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::Documentation; @@ -45,6 +45,16 @@ static DOCUMENTATION_ACOS: LazyLock = LazyLock::new(|| { "acos(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql +> SELECT acos(1); ++----------+ +| acos(1) | ++----------+ +| 0.0 | ++----------+ +```"#, + ) .build() }); @@ -69,15 +79,24 @@ pub fn acosh_order(input: &[ExprProperties]) -> Result { } } -static DOCUMENTATION_ACOSH: LazyLock = LazyLock::new(|| { - Documentation::builder( +static DOCUMENTATION_ACOSH: LazyLock = + LazyLock::new(|| { + Documentation::builder( DOC_SECTION_MATH, "Returns the area hyperbolic cosine or inverse hyperbolic cosine of a number.", "acosh(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example(r#"```sql +> SELECT acosh(2); ++------------+ +| acosh(2) | ++------------+ +| 1.31696 | ++------------+ +```"#) .build() -}); + }); pub fn get_acosh_doc() -> &'static Documentation { &DOCUMENTATION_ACOSH @@ -105,6 +124,16 @@ static DOCUMENTATION_ASIN: LazyLock = LazyLock::new(|| { "asin(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql +> SELECT asin(0.5); ++------------+ +| asin(0.5) | ++------------+ +| 0.5235988 | ++------------+ +```"#, + ) .build() }); @@ -124,6 +153,16 @@ static DOCUMENTATION_ASINH: LazyLock = LazyLock::new(|| { "asinh(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#" ```sql +> SELECT asinh(1); ++------------+ +| asinh(1) | ++------------+ +| 0.8813736 | ++------------+ +```"#, + ) .build() }); @@ -143,6 +182,16 @@ static DOCUMENTATION_ATAN: LazyLock = LazyLock::new(|| { "atan(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql + > SELECT atan(1); ++-----------+ +| atan(1) | ++-----------+ +| 0.7853982 | ++-----------+ +```"#, + ) .build() }); @@ -165,15 +214,24 @@ pub fn atanh_order(input: &[ExprProperties]) -> Result { } } -static DOCUMENTATION_ATANH: LazyLock = LazyLock::new(|| { - Documentation::builder( +static DOCUMENTATION_ATANH: LazyLock = + LazyLock::new(|| { + Documentation::builder( DOC_SECTION_MATH, "Returns the area hyperbolic tangent or inverse hyperbolic tangent of a number.", "atanh(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example(r#"```sql + > SELECT atanh(0.5); ++-------------+ +| atanh(0.5) | ++-------------+ +| 0.5493061 | ++-------------+ +```"#) .build() -}); + }); pub fn get_atanh_doc() -> &'static Documentation { &DOCUMENTATION_ATANH @@ -185,8 +243,9 @@ pub fn atan2_order(_input: &[ExprProperties]) -> Result { Ok(SortProperties::Unordered) } -static DOCUMENTATION_ATANH2: LazyLock = LazyLock::new(|| { - Documentation::builder( +static DOCUMENTATION_ATANH2: LazyLock = + LazyLock::new(|| { + Documentation::builder( DOC_SECTION_MATH, "Returns the arc tangent or inverse tangent of `expression_y / expression_x`.", "atan2(expression_y, expression_x)", @@ -201,8 +260,16 @@ Can be a constant, column, or function, and any combination of arithmetic operat r#"Second numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators."#, ) + .with_sql_example(r#"```sql +> SELECT atan2(1, 1); ++------------+ +| atan2(1,1) | ++------------+ +| 0.7853982 | ++------------+ +```"#) .build() -}); + }); pub fn get_atan2_doc() -> &'static Documentation { &DOCUMENTATION_ATANH2 @@ -220,6 +287,16 @@ static DOCUMENTATION_CBRT: LazyLock = LazyLock::new(|| { "cbrt(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql +> SELECT cbrt(27); ++-----------+ +| cbrt(27) | ++-----------+ +| 3.0 | ++-----------+ +```"#, + ) .build() }); @@ -239,6 +316,16 @@ static DOCUMENTATION_CEIL: LazyLock = LazyLock::new(|| { "ceil(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql + > SELECT ceil(3.14); ++------------+ +| ceil(3.14) | ++------------+ +| 4.0 | ++------------+ +```"#, + ) .build() }); @@ -260,6 +347,16 @@ static DOCUMENTATION_COS: LazyLock = LazyLock::new(|| { "cos(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql +> SELECT cos(0); ++--------+ +| cos(0) | ++--------+ +| 1.0 | ++--------+ +```"#, + ) .build() }); @@ -290,6 +387,16 @@ static DOCUMENTATION_COSH: LazyLock = LazyLock::new(|| { "cosh(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql +> SELECT cosh(1); ++-----------+ +| cosh(1) | ++-----------+ +| 1.5430806 | ++-----------+ +```"#, + ) .build() }); @@ -309,6 +416,16 @@ static DOCUMENTATION_DEGREES: LazyLock = LazyLock::new(|| { "degrees(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql + > SELECT degrees(pi()); ++------------+ +| degrees(0) | ++------------+ +| 180.0 | ++------------+ +```"#, + ) .build() }); @@ -328,6 +445,16 @@ static DOCUMENTATION_EXP: LazyLock = LazyLock::new(|| { "exp(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql +> SELECT exp(1); ++---------+ +| exp(1) | ++---------+ +| 2.71828 | ++---------+ +```"#, + ) .build() }); @@ -347,6 +474,16 @@ static DOCUMENTATION_FLOOR: LazyLock = LazyLock::new(|| { "floor(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql +> SELECT floor(3.14); ++-------------+ +| floor(3.14) | ++-------------+ +| 3.0 | ++-------------+ +```"#, + ) .build() }); @@ -375,6 +512,16 @@ static DOCUMENTATION_LN: LazyLock = LazyLock::new(|| { "ln(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql +> SELECT ln(2.71828); ++-------------+ +| ln(2.71828) | ++-------------+ +| 1.0 | ++-------------+ +```"#, + ) .build() }); @@ -403,6 +550,16 @@ static DOCUMENTATION_LOG2: LazyLock = LazyLock::new(|| { "log2(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql +> SELECT log2(8); ++-----------+ +| log2(8) | ++-----------+ +| 3.0 | ++-----------+ +```"#, + ) .build() }); @@ -431,6 +588,16 @@ static DOCUMENTATION_LOG10: LazyLock = LazyLock::new(|| { "log10(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql +> SELECT log10(100); ++-------------+ +| log10(100) | ++-------------+ +| 2.0 | ++-------------+ +```"#, + ) .build() }); @@ -443,18 +610,28 @@ pub fn radians_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } -static DOCUMENTATION_RADIONS: LazyLock = LazyLock::new(|| { +static DOCUMENTATION_RADIANS: LazyLock = LazyLock::new(|| { Documentation::builder( DOC_SECTION_MATH, "Converts degrees to radians.", "radians(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql +> SELECT radians(180); ++----------------+ +| radians(180) | ++----------------+ +| 3.14159265359 | ++----------------+ +```"#, + ) .build() }); pub fn get_radians_doc() -> &'static Documentation { - &DOCUMENTATION_RADIONS + &DOCUMENTATION_RADIANS } /// Non-decreasing on \[0, π\] and then non-increasing on \[π, 2π\]. @@ -471,6 +648,16 @@ static DOCUMENTATION_SIN: LazyLock = LazyLock::new(|| { "sin(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql +> SELECT sin(0); ++----------+ +| sin(0) | ++----------+ +| 0.0 | ++----------+ +```"#, + ) .build() }); @@ -490,6 +677,16 @@ static DOCUMENTATION_SINH: LazyLock = LazyLock::new(|| { "sinh(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql +> SELECT sinh(1); ++-----------+ +| sinh(1) | ++-----------+ +| 1.1752012 | ++-----------+ +```"#, + ) .build() }); @@ -539,6 +736,16 @@ static DOCUMENTATION_TAN: LazyLock = LazyLock::new(|| { "tan(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql +> SELECT tan(pi()/4); ++--------------+ +| tan(PI()/4) | ++--------------+ +| 1.0 | ++--------------+ +```"#, + ) .build() }); @@ -558,6 +765,16 @@ static DOCUMENTATION_TANH: LazyLock = LazyLock::new(|| { "tanh(numeric_expression)", ) .with_standard_argument("numeric_expression", Some("Numeric")) + .with_sql_example( + r#"```sql + > SELECT tanh(20); + +----------+ + | tanh(20) | + +----------+ + | 1.0 | + +----------+ + ```"#, + ) .build() }); diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs index 34a5c2a1c16bb..759b0f5fd50ac 100644 --- a/datafusion/functions/src/math/nans.rs +++ b/datafusion/functions/src/math/nans.rs @@ -31,9 +31,17 @@ use std::sync::Arc; doc_section(label = "Math Functions"), description = "Returns true if a given number is +NaN or -NaN otherwise returns false.", syntax_example = "isnan(numeric_expression)", + sql_example = r#"```sql +> SELECT isnan(1); ++----------+ +| isnan(1) | ++----------+ +| false | ++----------+ +```"#, standard_argument(name = "numeric_expression", prefix = "Numeric") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct IsNanFunc { signature: Signature, } diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index 9effb82896ee0..f0835b4d48a0c 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -36,6 +36,14 @@ use datafusion_macros::user_doc; description = r#"Returns the first argument if it's not _NaN_. Returns the second argument otherwise."#, syntax_example = "nanvl(expression_x, expression_y)", + sql_example = r#"```sql +> SELECT nanvl(0, 5); ++------------+ +| nanvl(0,5) | ++------------+ +| 0 | ++------------+ +```"#, argument( name = "expression_x", description = "Numeric expression to return if it's not _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators." @@ -45,7 +53,7 @@ Returns the second argument otherwise."#, description = "Numeric expression to return if the first expression is _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct NanvlFunc { signature: Signature, } diff --git a/datafusion/functions/src/math/pi.rs b/datafusion/functions/src/math/pi.rs index 5339a9b14a283..71a8e21a52f26 100644 --- a/datafusion/functions/src/math/pi.rs +++ b/datafusion/functions/src/math/pi.rs @@ -32,7 +32,7 @@ use datafusion_macros::user_doc; description = "Returns an approximate value of π.", syntax_example = "pi()" )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct PiFunc { signature: Signature, } diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 028ec2fef7937..ad2e795d086e9 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -24,8 +24,8 @@ use super::log::LogFunc; use arrow::array::{ArrayRef, AsArray, Int64Array}; use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type}; use datafusion_common::{ - arrow_datafusion_err, exec_datafusion_err, exec_err, internal_datafusion_err, - plan_datafusion_err, DataFusionError, Result, ScalarValue, + arrow_datafusion_err, exec_datafusion_err, exec_err, plan_datafusion_err, + DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; @@ -39,10 +39,18 @@ use datafusion_macros::user_doc; doc_section(label = "Math Functions"), description = "Returns a base expression raised to the power of an exponent.", syntax_example = "power(base, exponent)", + sql_example = r#"```sql +> SELECT power(2, 3); ++-------------+ +| power(2,3) | ++-------------+ +| 8 | ++-------------+ +```"#, standard_argument(name = "base", prefix = "Numeric"), standard_argument(name = "exponent", prefix = "Exponent numeric") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct PowerFunc { signature: Signature, aliases: Vec, @@ -156,12 +164,15 @@ impl ScalarUDFImpl for PowerFunc { let exponent_type = info.get_data_type(&exponent)?; match exponent { - Expr::Literal(value) if value == ScalarValue::new_zero(&exponent_type)? => { + Expr::Literal(value, _) + if value == ScalarValue::new_zero(&exponent_type)? => + { Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::new_one(&info.get_data_type(&base)?)?, + None, ))) } - Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => { + Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => { Ok(ExprSimplifyResult::Simplified(base)) } Expr::ScalarFunction(ScalarFunction { func, mut args }) @@ -186,13 +197,18 @@ fn is_log(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { + use super::*; use arrow::array::Float64Array; + use arrow::datatypes::Field; use datafusion_common::cast::{as_float64_array, as_int64_array}; - - use super::*; + use datafusion_common::config::ConfigOptions; #[test] fn test_power_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, true).into(), + Field::new("a", DataType::Float64, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ @@ -202,8 +218,10 @@ mod tests { 3.0, 2.0, 4.0, 4.0, ]))), // exponent ], + arg_fields, number_rows: 4, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = PowerFunc::new() .invoke_with_args(args) @@ -227,13 +245,19 @@ mod tests { #[test] fn test_power_i64() { + let arg_fields = vec![ + Field::new("a", DataType::Int64, true).into(), + Field::new("a", DataType::Int64, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent ], + arg_fields, number_rows: 4, - return_type: &DataType::Int64, + return_field: Field::new("f", DataType::Int64, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = PowerFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/random.rs b/datafusion/functions/src/math/random.rs index 607f9fb09f2ae..d63e76a06d011 100644 --- a/datafusion/functions/src/math/random.rs +++ b/datafusion/functions/src/math/random.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use arrow::array::Float64Array; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Float64; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use datafusion_common::{internal_err, Result}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -32,9 +32,17 @@ use datafusion_macros::user_doc; doc_section(label = "Math Functions"), description = r#"Returns a random float value in the range [0, 1). The random seed is unique to each row."#, - syntax_example = "random()" + syntax_example = "random()", + sql_example = r#"```sql +> SELECT random(); ++------------------+ +| random() | ++------------------+ +| 0.7389238902938 | ++------------------+ +```"# )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct RandomFunc { signature: Signature, } @@ -74,9 +82,9 @@ impl ScalarUDFImpl for RandomFunc { if !args.args.is_empty() { return internal_err!("{} function does not accept arguments", self.name()); } - let mut rng = thread_rng(); + let mut rng = rng(); let mut values = vec![0.0; args.number_rows]; - // Equivalent to set each element with rng.gen_range(0.0..1.0), but more efficient + // Equivalent to set each element with rng.random_range(0.0..1.0), but more efficient rng.fill(&mut values[..]); let array = Float64Array::from(values); diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index fc87b7e63a62f..837f0be432403 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -41,9 +41,17 @@ use datafusion_macros::user_doc; argument( name = "decimal_places", description = "Optional. The number of decimal places to round to. Defaults to 0." - ) + ), + sql_example = r#"```sql +> SELECT round(3.14159); ++--------------+ +| round(3.14159)| ++--------------+ +| 3.0 | ++--------------+ +```"# )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct RoundFunc { signature: Signature, } @@ -305,6 +313,6 @@ mod test { let result = round(&args); assert!(result.is_err()); - assert!(matches!(result, Err(DataFusionError::Execution { .. }))); + assert!(matches!(result, Err(DataFusionError::Execution(_)))); } } diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index ba5422afa7686..bbe6178f39b79 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -38,9 +38,17 @@ use crate::utils::make_scalar_function; Negative numbers return `-1`. Zero and positive numbers return `1`."#, syntax_example = "signum(numeric_expression)", - standard_argument(name = "numeric_expression", prefix = "Numeric") + standard_argument(name = "numeric_expression", prefix = "Numeric"), + sql_example = r#"```sql +> SELECT signum(-42); ++-------------+ +| signum(-42) | ++-------------+ +| -1 | ++-------------+ +```"# )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct SignumFunc { signature: Signature, } @@ -138,8 +146,9 @@ mod test { use std::sync::Arc; use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::cast::{as_float32_array, as_float64_array}; + use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use crate::math::signum::SignumFunc; @@ -157,10 +166,13 @@ mod test { f32::INFINITY, f32::NEG_INFINITY, ])); + let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_fields, number_rows: array.len(), - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = SignumFunc::new() .invoke_with_args(args) @@ -201,10 +213,13 @@ mod test { f64::INFINITY, f64::NEG_INFINITY, ])); + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_fields, number_rows: array.len(), - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = SignumFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index 2ac291204a0bc..9d1b4336f6389 100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -45,9 +45,18 @@ use datafusion_macros::user_doc; `decimal_places` is a positive integer, truncates digits to the right of the decimal point. If `decimal_places` is a negative integer, replaces digits to the left of the decimal point with `0`."# - ) + ), + sql_example = r#" + ```sql + > SELECT trunc(42.738); + +----------------+ + | trunc(42.738) | + +----------------+ + | 42 | + +----------------+ + ```"# )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct TruncFunc { signature: Signature, } diff --git a/datafusion/functions/src/planner.rs b/datafusion/functions/src/planner.rs index 93edec7ece307..7228cdc07e727 100644 --- a/datafusion/functions/src/planner.rs +++ b/datafusion/functions/src/planner.rs @@ -24,9 +24,14 @@ use datafusion_expr::{ Expr, }; +#[deprecated( + since = "0.50.0", + note = "Use UnicodeFunctionPlanner and DateTimeFunctionPlanner instead" +)] #[derive(Default, Debug)] pub struct UserDefinedFunctionPlanner; +#[expect(deprecated)] impl ExprPlanner for UserDefinedFunctionPlanner { #[cfg(feature = "datetime_expressions")] fn plan_extract(&self, args: Vec) -> Result>> { diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 13fbc049af582..da4e23f91de7d 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -17,15 +17,20 @@ //! "regex" DataFusion functions +use arrow::error::ArrowError; +use regex::Regex; +use std::collections::hash_map::Entry; +use std::collections::HashMap; use std::sync::Arc; - pub mod regexpcount; +pub mod regexpinstr; pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; // create UDFs make_udf_function!(regexpcount::RegexpCountFunc, regexp_count); +make_udf_function!(regexpinstr::RegexpInstrFunc, regexp_instr); make_udf_function!(regexpmatch::RegexpMatchFunc, regexp_match); make_udf_function!(regexplike::RegexpLikeFunc, regexp_like); make_udf_function!(regexpreplace::RegexpReplaceFunc, regexp_replace); @@ -60,7 +65,35 @@ pub mod expr_fn { super::regexp_match().call(args) } - /// Returns true if a has at least one match in a string, false otherwise. + /// Returns index of regular expression matches in a string. + pub fn regexp_instr( + values: Expr, + regex: Expr, + start: Option, + n: Option, + endoption: Option, + flags: Option, + subexpr: Option, + ) -> Expr { + let mut args = vec![values, regex]; + if let Some(start) = start { + args.push(start); + }; + if let Some(n) = n { + args.push(n); + }; + if let Some(endoption) = endoption { + args.push(endoption); + }; + if let Some(flags) = flags { + args.push(flags); + }; + if let Some(subexpr) = subexpr { + args.push(subexpr); + }; + super::regexp_instr().call(args) + } + /// Returns true if a regex has at least one match in a string, false otherwise. pub fn regexp_like(values: Expr, regex: Expr, flags: Option) -> Expr { let mut args = vec![values, regex]; if let Some(flags) = flags { @@ -89,7 +122,45 @@ pub fn functions() -> Vec> { vec![ regexp_count(), regexp_match(), + regexp_instr(), regexp_like(), regexp_replace(), ] } + +pub fn compile_and_cache_regex<'strings, 'cache>( + regex: &'strings str, + flags: Option<&'strings str>, + regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>, +) -> Result<&'cache Regex, ArrowError> +where + 'strings: 'cache, +{ + let result = match regex_cache.entry((regex, flags)) { + Entry::Occupied(occupied_entry) => occupied_entry.into_mut(), + Entry::Vacant(vacant_entry) => { + let compiled = compile_regex(regex, flags)?; + vacant_entry.insert(compiled) + } + }; + Ok(result) +} + +pub fn compile_regex(regex: &str, flags: Option<&str>) -> Result { + let pattern = match flags { + None | Some("") => regex.to_string(), + Some(flags) => { + if flags.contains("g") { + return Err(ArrowError::ComputeError( + "regexp_count()/regexp_instr() does not support the global flag" + .to_string(), + )); + } + format!("(?{flags}){regex}") + } + }; + + Regex::new(&pattern).map_err(|_| { + ArrowError::ComputeError(format!("Regular expression did not compile: {pattern}")) + }) +} diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 8cb1a4ff3d606..8bad506217aa5 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::regex::{compile_and_cache_regex, compile_regex}; use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType}; use arrow::datatypes::{DataType, Int64Type}; use arrow::datatypes::{ @@ -29,7 +30,6 @@ use datafusion_expr::{ use datafusion_macros::user_doc; use itertools::izip; use regex::Regex; -use std::collections::hash_map::Entry; use std::collections::HashMap; use std::sync::Arc; @@ -61,7 +61,7 @@ use std::sync::Arc; - **U**: swap the meaning of x* and x*?"# ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct RegexpCountFunc { signature: Signature, } @@ -550,45 +550,6 @@ where } } -fn compile_and_cache_regex<'strings, 'cache>( - regex: &'strings str, - flags: Option<&'strings str>, - regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>, -) -> Result<&'cache Regex, ArrowError> -where - 'strings: 'cache, -{ - let result = match regex_cache.entry((regex, flags)) { - Entry::Occupied(occupied_entry) => occupied_entry.into_mut(), - Entry::Vacant(vacant_entry) => { - let compiled = compile_regex(regex, flags)?; - vacant_entry.insert(compiled) - } - }; - Ok(result) -} - -fn compile_regex(regex: &str, flags: Option<&str>) -> Result { - let pattern = match flags { - None | Some("") => regex.to_string(), - Some(flags) => { - if flags.contains("g") { - return Err(ArrowError::ComputeError( - "regexp_count() does not support global flag".to_string(), - )); - } - format!("(?{}){}", flags, regex) - } - }; - - Regex::new(&pattern).map_err(|_| { - ArrowError::ComputeError(format!( - "Regular expression did not compile: {}", - pattern - )) - }) -} - fn count_matches( value: Option<&str>, pattern: &Regex, @@ -619,6 +580,8 @@ fn count_matches( mod tests { use super::*; use arrow::array::{GenericStringArray, StringViewArray}; + use arrow::datatypes::Field; + use datafusion_common::config::ConfigOptions; use datafusion_expr::ScalarFunctionArgs; #[test] @@ -647,6 +610,27 @@ mod tests { test_case_regexp_count_cache_check::>(); } + fn regexp_count_with_scalar_values(args: &[ScalarValue]) -> Result { + let args_values = args + .iter() + .map(|sv| ColumnarValue::Scalar(sv.clone())) + .collect(); + + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true).into()) + .collect::>(); + + RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: args_values, + arg_fields, + number_rows: args.len(), + return_field: Field::new("f", Int64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + } + fn test_case_sensitive_regexp_count_scalar() { let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; let regex = "abc"; @@ -657,11 +641,7 @@ mod tests { let v_sv = ScalarValue::Utf8(Some(v.to_string())); let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - number_rows: 2, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -672,11 +652,7 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - number_rows: 2, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -687,11 +663,7 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - number_rows: 2, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -713,15 +685,7 @@ mod tests { let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let start_sv = ScalarValue::Int64(Some(start)); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - number_rows: 3, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -732,15 +696,7 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - number_rows: 3, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -751,15 +707,7 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - number_rows: 3, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -783,16 +731,13 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(Some(flags.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -804,16 +749,13 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -825,16 +767,13 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); let flags_sv = ScalarValue::Utf8View(Some(flags.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -903,20 +842,16 @@ mod tests { values.iter().enumerate().for_each(|(pos, &v)| { // utf8 let v_sv = ScalarValue::Utf8(Some(v.to_string())); - let regex_sv = ScalarValue::Utf8(regex.get(pos).map(|s| s.to_string())); + let regex_sv = ScalarValue::Utf8(regex.get(pos).map(|s| (*s).to_string())); let start_sv = ScalarValue::Int64(Some(start)); - let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string())); + let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| (*f).to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -926,18 +861,16 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); - let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string())); - let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + let regex_sv = + ScalarValue::LargeUtf8(regex.get(pos).map(|s| (*s).to_string())); + let flags_sv = + ScalarValue::LargeUtf8(flags.get(pos).map(|f| (*f).to_string())); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -947,18 +880,16 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); - let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string())); - let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + let regex_sv = + ScalarValue::Utf8View(regex.get(pos).map(|s| (*s).to_string())); + let flags_sv = + ScalarValue::Utf8View(flags.get(pos).map(|f| (*f).to_string())); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs new file mode 100644 index 0000000000000..851c182a90dd0 --- /dev/null +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -0,0 +1,824 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, AsArray, Datum, Int64Array, PrimitiveArray, StringArrayType, +}; +use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::{ + DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View, +}; +use arrow::error::ArrowError; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact, + TypeSignature::Uniform, Volatility, +}; +use datafusion_macros::user_doc; +use itertools::izip; +use regex::Regex; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::regex::compile_and_cache_regex; + +#[user_doc( + doc_section(label = "Regular Expression Functions"), + description = "Returns the position in a string where the specified occurrence of a POSIX regular expression is located.", + syntax_example = "regexp_instr(str, regexp[, start[, N[, flags[, subexpr]]]])", + sql_example = r#"```sql +> SELECT regexp_instr('ABCDEF', 'C(.)(..)'); ++---------------------------------------------------------------+ +| regexp_instr(Utf8("ABCDEF"),Utf8("C(.)(..)")) | ++---------------------------------------------------------------+ +| 3 | ++---------------------------------------------------------------+ +```"#, + standard_argument(name = "str", prefix = "String"), + standard_argument(name = "regexp", prefix = "Regular"), + argument( + name = "start", + description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. Defaults to 1" + ), + argument( + name = "N", + description = "- **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function." + ), + argument( + name = "flags", + description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"# + ), + argument( + name = "subexpr", + description = "Optional Specifies which capture group (subexpression) to return the position for. Defaults to 0, which returns the position of the entire match." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct RegexpInstrFunc { + signature: Signature, +} + +impl Default for RegexpInstrFunc { + fn default() -> Self { + Self::new() + } +} + +impl RegexpInstrFunc { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + Uniform(2, vec![Utf8View, LargeUtf8, Utf8]), + Exact(vec![Utf8View, Utf8View, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + Exact(vec![Utf8, Utf8, Int64, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8]), + Exact(vec![Utf8, Utf8, Int64, Int64, Utf8]), + Exact(vec![Utf8View, Utf8View, Int64, Int64, Utf8View, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64, LargeUtf8, Int64]), + Exact(vec![Utf8, Utf8, Int64, Int64, Utf8, Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RegexpInstrFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "regexp_instr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int64) + } + + fn invoke_with_args( + &self, + args: datafusion_expr::ScalarFunctionArgs, + ) -> Result { + let args = &args.args; + + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .map(|arg| arg.to_array(inferred_length)) + .collect::>>()?; + + let result = regexp_instr_func(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +pub fn regexp_instr_func(args: &[ArrayRef]) -> Result { + let args_len = args.len(); + if !(2..=6).contains(&args_len) { + return exec_err!("regexp_instr was called with {args_len} arguments. It requires at least 2 and at most 6."); + } + + let values = &args[0]; + match values.data_type() { + Utf8 | LargeUtf8 | Utf8View => (), + other => { + return internal_err!( + "Unsupported data type {other:?} for function regexp_instr" + ); + } + } + + regexp_instr( + values, + &args[1], + if args_len > 2 { Some(&args[2]) } else { None }, + if args_len > 3 { Some(&args[3]) } else { None }, + if args_len > 4 { Some(&args[4]) } else { None }, + if args_len > 5 { Some(&args[5]) } else { None }, + ) + .map_err(|e| e.into()) +} + +/// `arrow-rs` style implementation of `regexp_instr` function. +/// This function `regexp_instr` is responsible for returning the index of a regular expression pattern +/// within a string array. It supports optional start positions and flags for case insensitivity. +/// +/// The function accepts a variable number of arguments: +/// - `values`: The array of strings to search within. +/// - `regex_array`: The array of regular expression patterns to search for. +/// - `start_array` (optional): The array of start positions for the search. +/// - `nth_array` (optional): The array of start nth for the search. +/// - `endoption_array` (optional): The array of endoption positions for the search. +/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity). +/// - `subexpr_array` (optional): The array of subexpr positions for the search. +/// +/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions, +/// and flags. It uses a cache to store compiled regular expressions for efficiency. +/// +/// # Errors +/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile. +pub fn regexp_instr( + values: &dyn Array, + regex_array: &dyn Datum, + start_array: Option<&dyn Datum>, + nth_array: Option<&dyn Datum>, + flags_array: Option<&dyn Datum>, + subexpr_array: Option<&dyn Datum>, +) -> Result { + let (regex_array, _) = regex_array.get(); + let start_array = start_array.map(|start| { + let (start, _) = start.get(); + start + }); + let nth_array = nth_array.map(|nth| { + let (nth, _) = nth.get(); + nth + }); + let flags_array = flags_array.map(|flags| { + let (flags, _) = flags.get(); + flags + }); + let subexpr_array = subexpr_array.map(|subexpr| { + let (subexpr, _) = subexpr.get(); + subexpr + }); + + match (values.data_type(), regex_array.data_type(), flags_array) { + (Utf8, Utf8, None) => regexp_instr_inner( + values.as_string::(), + regex_array.as_string::(), + start_array.map(|start| start.as_primitive::()), + nth_array.map(|nth| nth.as_primitive::()), + None, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + ), + (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_instr_inner( + values.as_string::(), + regex_array.as_string::(), + start_array.map(|start| start.as_primitive::()), + nth_array.map(|nth| nth.as_primitive::()), + Some(flags_array.as_string::()), + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + ), + (LargeUtf8, LargeUtf8, None) => regexp_instr_inner( + values.as_string::(), + regex_array.as_string::(), + start_array.map(|start| start.as_primitive::()), + nth_array.map(|nth| nth.as_primitive::()), + None, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + ), + (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_instr_inner( + values.as_string::(), + regex_array.as_string::(), + start_array.map(|start| start.as_primitive::()), + nth_array.map(|nth| nth.as_primitive::()), + Some(flags_array.as_string::()), + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + ), + (Utf8View, Utf8View, None) => regexp_instr_inner( + values.as_string_view(), + regex_array.as_string_view(), + start_array.map(|start| start.as_primitive::()), + nth_array.map(|nth| nth.as_primitive::()), + None, + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + ), + (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_instr_inner( + values.as_string_view(), + regex_array.as_string_view(), + start_array.map(|start| start.as_primitive::()), + nth_array.map(|nth| nth.as_primitive::()), + Some(flags_array.as_string_view()), + subexpr_array.map(|subexpr| subexpr.as_primitive::()), + ), + _ => Err(ArrowError::ComputeError( + "regexp_instr() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(), + )), + } +} + +#[allow(clippy::too_many_arguments)] +pub fn regexp_instr_inner<'a, S>( + values: S, + regex_array: S, + start_array: Option<&Int64Array>, + nth_array: Option<&Int64Array>, + flags_array: Option, + subexp_array: Option<&Int64Array>, +) -> Result +where + S: StringArrayType<'a>, +{ + let len = values.len(); + + let default_start_array = PrimitiveArray::::from(vec![1; len]); + let start_array = start_array.unwrap_or(&default_start_array); + let start_input: Vec = (0..start_array.len()) + .map(|i| start_array.value(i)) // handle nulls as 0 + .collect(); + + let default_nth_array = PrimitiveArray::::from(vec![1; len]); + let nth_array = nth_array.unwrap_or(&default_nth_array); + let nth_input: Vec = (0..nth_array.len()) + .map(|i| nth_array.value(i)) // handle nulls as 0 + .collect(); + + let flags_input = match flags_array { + Some(flags) => flags.iter().collect(), + None => vec![None; len], + }; + + let default_subexp_array = PrimitiveArray::::from(vec![0; len]); + let subexp_array = subexp_array.unwrap_or(&default_subexp_array); + let subexp_input: Vec = (0..subexp_array.len()) + .map(|i| subexp_array.value(i)) // handle nulls as 0 + .collect(); + + let mut regex_cache = HashMap::new(); + + let result: Result>, ArrowError> = izip!( + values.iter(), + regex_array.iter(), + start_input.iter(), + nth_input.iter(), + flags_input.iter(), + subexp_input.iter() + ) + .map(|(value, regex, start, nth, flags, subexp)| match regex { + None => Ok(None), + Some("") => Ok(Some(0)), + Some(regex) => get_index( + value, + regex, + *start, + *nth, + *subexp, + *flags, + &mut regex_cache, + ), + }) + .collect(); + Ok(Arc::new(Int64Array::from(result?))) +} + +fn handle_subexp( + pattern: &Regex, + search_slice: &str, + subexpr: i64, + value: &str, + byte_start_offset: usize, +) -> Result, ArrowError> { + if let Some(captures) = pattern.captures(search_slice) { + if let Some(matched) = captures.get(subexpr as usize) { + // Convert byte offset relative to search_slice back to 1-based character offset + // relative to the original `value` string. + let start_char_offset = + value[..byte_start_offset + matched.start()].chars().count() as i64 + 1; + return Ok(Some(start_char_offset)); + } + } + Ok(Some(0)) // Return 0 if the subexpression was not found +} + +fn get_nth_match( + pattern: &Regex, + search_slice: &str, + n: i64, + byte_start_offset: usize, + value: &str, +) -> Result, ArrowError> { + if let Some(mat) = pattern.find_iter(search_slice).nth((n - 1) as usize) { + // Convert byte offset relative to search_slice back to 1-based character offset + // relative to the original `value` string. + let match_start_byte_offset = byte_start_offset + mat.start(); + let match_start_char_offset = + value[..match_start_byte_offset].chars().count() as i64 + 1; + Ok(Some(match_start_char_offset)) + } else { + Ok(Some(0)) // Return 0 if the N-th match was not found + } +} +fn get_index<'strings, 'cache>( + value: Option<&str>, + pattern: &'strings str, + start: i64, + n: i64, + subexpr: i64, + flags: Option<&'strings str>, + regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>, +) -> Result, ArrowError> +where + 'strings: 'cache, +{ + let value = match value { + None => return Ok(None), + Some("") => return Ok(Some(0)), + Some(value) => value, + }; + let pattern: &Regex = compile_and_cache_regex(pattern, flags, regex_cache)?; + // println!("get_index: value = {}, pattern = {}, start = {}, n = {}, subexpr = {}, flags = {:?}", value, pattern, start, n, subexpr, flags); + if start < 1 { + return Err(ArrowError::ComputeError( + "regexp_instr() requires start to be 1-based".to_string(), + )); + } + + if n < 1 { + return Err(ArrowError::ComputeError( + "N must be 1 or greater".to_string(), + )); + } + + // --- Simplified byte_start_offset calculation --- + let total_chars = value.chars().count() as i64; + let byte_start_offset: usize = if start > total_chars { + // If start is beyond the total characters, it means we start searching + // after the string effectively. No matches possible. + return Ok(Some(0)); + } else { + // Get the byte offset for the (start - 1)-th character (0-based) + value + .char_indices() + .nth((start - 1) as usize) + .map(|(idx, _)| idx) + .unwrap_or(0) // Should not happen if start is valid and <= total_chars + }; + // --- End simplified calculation --- + + let search_slice = &value[byte_start_offset..]; + + // Handle subexpression capturing first, as it takes precedence + if subexpr > 0 { + return handle_subexp(pattern, search_slice, subexpr, value, byte_start_offset); + } + + // Use nth to get the N-th match (n is 1-based, nth is 0-based) + get_nth_match(pattern, search_slice, n, byte_start_offset, value) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int64Array; + use arrow::array::{GenericStringArray, StringViewArray}; + use arrow::datatypes::Field; + use datafusion_common::config::ConfigOptions; + use datafusion_expr::ScalarFunctionArgs; + #[test] + fn test_regexp_instr() { + test_case_sensitive_regexp_instr_nulls(); + test_case_sensitive_regexp_instr_scalar(); + test_case_sensitive_regexp_instr_scalar_start(); + test_case_sensitive_regexp_instr_scalar_nth(); + test_case_sensitive_regexp_instr_scalar_subexp(); + + test_case_sensitive_regexp_instr_array::>(); + test_case_sensitive_regexp_instr_array::>(); + test_case_sensitive_regexp_instr_array::(); + + test_case_sensitive_regexp_instr_array_start::>(); + test_case_sensitive_regexp_instr_array_start::>(); + test_case_sensitive_regexp_instr_array_start::(); + + test_case_sensitive_regexp_instr_array_nth::>(); + test_case_sensitive_regexp_instr_array_nth::>(); + test_case_sensitive_regexp_instr_array_nth::(); + } + + fn regexp_instr_with_scalar_values(args: &[ScalarValue]) -> Result { + let args_values: Vec = args + .iter() + .map(|sv| ColumnarValue::Scalar(sv.clone())) + .collect(); + + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, a)| { + Arc::new(Field::new(format!("arg_{idx}"), a.data_type(), true)) + }) + .collect::>(); + + RegexpInstrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: args_values, + arg_fields, + number_rows: args.len(), + return_field: Arc::new(Field::new("f", Int64, true)), + config_options: Arc::new(ConfigOptions::default()), + }) + } + + fn test_case_sensitive_regexp_instr_nulls() { + let v = ""; + let r = ""; + let expected = 0; + let regex_sv = ScalarValue::Utf8(Some(r.to_string())); + let re = regexp_instr_with_scalar_values(&[v.to_string().into(), regex_sv]); + // let res_exp = re.unwrap(); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, Some(expected), "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + } + fn test_case_sensitive_regexp_instr_scalar() { + let values = [ + "hello world", + "abcdefg", + "xyz123xyz", + "no match here", + "abc", + "ДатаФусион数据融合📊🔥", + ]; + let regex = ["o", "d", "123", "z", "gg", "📊"]; + + let expected: Vec = vec![5, 4, 4, 0, 0, 15]; + + izip!(values.iter(), regex.iter()) + .enumerate() + .for_each(|(pos, (&v, &r))| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(r.to_string())); + let expected = expected.get(pos).cloned(); + let re = regexp_instr_with_scalar_values(&[v_sv, regex_sv]); + // let res_exp = re.unwrap(); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(r.to_string())); + let re = regexp_instr_with_scalar_values(&[v_sv, regex_sv]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(r.to_string())); + let re = regexp_instr_with_scalar_values(&[v_sv, regex_sv]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_instr_scalar_start() { + let values = ["abcabcabc", "abcabcabc", ""]; + let regex = ["abc", "abc", "gg"]; + let start = [4, 5, 5]; + let expected: Vec = vec![4, 7, 0]; + + izip!(values.iter(), regex.iter(), start.iter()) + .enumerate() + .for_each(|(pos, (&v, &r, &s))| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let expected = expected.get(pos).cloned(); + let re = + regexp_instr_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let re = + regexp_instr_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let re = + regexp_instr_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_instr_scalar_nth() { + let values = ["abcabcabc", "abcabcabc", "abcabcabc", "abcabcabc"]; + let regex = ["abc", "abc", "abc", "abc"]; + let start = [1, 1, 1, 1]; + let nth = [1, 2, 3, 4]; + let expected: Vec = vec![1, 4, 7, 0]; + + izip!(values.iter(), regex.iter(), start.iter(), nth.iter()) + .enumerate() + .for_each(|(pos, (&v, &r, &s, &n))| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let expected = expected.get(pos).cloned(); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_instr_scalar_subexp() { + let values = ["12 abc def ghi 34"]; + let regex = ["(abc) (def) (ghi)"]; + let start = [1]; + let nth = [1]; + let flags = ["i"]; + let subexps = [2]; + let expected: Vec = vec![8]; + + izip!( + values.iter(), + regex.iter(), + start.iter(), + nth.iter(), + flags.iter(), + subexps.iter() + ) + .enumerate() + .for_each(|(pos, (&v, &r, &s, &n, &flag, &subexp))| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let flags_sv = ScalarValue::Utf8(Some(flag.to_string())); + let subexp_sv = ScalarValue::Int64(Some(subexp)); + let expected = expected.get(pos).cloned(); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + flags_sv, + subexp_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let flags_sv = ScalarValue::LargeUtf8(Some(flag.to_string())); + let subexp_sv = ScalarValue::Int64(Some(subexp)); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + flags_sv, + subexp_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(r.to_string())); + let start_sv = ScalarValue::Int64(Some(s)); + let nth_sv = ScalarValue::Int64(Some(n)); + let flags_sv = ScalarValue::Utf8View(Some(flag.to_string())); + let subexp_sv = ScalarValue::Int64(Some(subexp)); + let re = regexp_instr_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + nth_sv.clone(), + flags_sv, + subexp_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_instr scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_instr_array() + where + A: From> + Array + 'static, + { + let values = A::from(vec![ + "hello world", + "abcdefg", + "xyz123xyz", + "no match here", + "", + ]); + let regex = A::from(vec!["o", "d", "123", "z", "gg"]); + + let expected = Int64Array::from(vec![5, 4, 4, 0, 0]); + let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex)]).unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_instr_array_start() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["abcabcabc", "abcabcabc", ""]); + let regex = A::from(vec!["abc", "abc", "gg"]); + let start = Int64Array::from(vec![4, 5, 5]); + let expected = Int64Array::from(vec![4, 7, 0]); + + let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_instr_array_nth() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["abcabcabc", "abcabcabc", "abcabcabc", "abcabcabc"]); + let regex = A::from(vec!["abc", "abc", "abc", "abc"]); + let start = Int64Array::from(vec![1, 1, 1, 1]); + let nth = Int64Array::from(vec![1, 2, 3, 4]); + let expected = Int64Array::from(vec![1, 4, 7, 0]); + + let re = regexp_instr_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(nth), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } +} diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index 2080bb9fe818f..d75eb9141c056 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -27,11 +27,14 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::{ - Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, - TypeSignatureClass, Volatility, + binary_expr, cast, Coercion, ColumnarValue, Documentation, Expr, ScalarUDFImpl, + Signature, TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr_common::operator::Operator; +use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer; use std::any::Any; use std::sync::Arc; @@ -67,7 +70,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo - **U**: swap the meaning of x* and x*?"# ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct RegexpLikeFunc { signature: Signature, } @@ -153,11 +156,76 @@ impl ScalarUDFImpl for RegexpLikeFunc { } } + fn simplify( + &self, + mut args: Vec, + info: &dyn SimplifyInfo, + ) -> Result { + // Try to simplify regexp_like usage to one of the builtin operators since those have + // optimized code paths for the case where the regular expression pattern is a scalar. + // Additionally, the expression simplification optimization pass will attempt to further + // simplify regular expression patterns used in operator expressions. + let Some(op) = derive_operator(&args) else { + return Ok(ExprSimplifyResult::Original(args)); + }; + + let string_type = info.get_data_type(&args[0])?; + let regexp_type = info.get_data_type(&args[1])?; + let binary_type_coercer = BinaryTypeCoercer::new(&string_type, &op, ®exp_type); + let Ok((coerced_string_type, coerced_regexp_type)) = + binary_type_coercer.get_input_types() + else { + return Ok(ExprSimplifyResult::Original(args)); + }; + + // regexp_like(str, regexp [, flags]) + let regexp = args.swap_remove(1); + let string = args.swap_remove(0); + + Ok(ExprSimplifyResult::Simplified(binary_expr( + if string_type != coerced_string_type { + cast(string, coerced_string_type) + } else { + string + }, + op, + if regexp_type != coerced_regexp_type { + cast(regexp, coerced_regexp_type) + } else { + regexp + }, + ))) + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } } +fn derive_operator(args: &[Expr]) -> Option { + match args.len() { + // regexp_like(str, regexp, flags) + 3 => { + match &args[2] { + Expr::Literal(ScalarValue::Utf8(Some(flags)), _) => { + match flags.as_str() { + "i" => Some(Operator::RegexIMatch), + "" => Some(Operator::RegexMatch), + // Any flags besides 'i' have no operator equivalent + _ => None, + } + } + // `flags` is not a literal, so we can't derive the correct operator statically + _ => None, + } + } + // regexp_like(str, regexp) + 2 => Some(Operator::RegexMatch), + // Should never happen, but just in case + _ => None, + } +} + /// Tests a string using a regular expression returning true if at /// least one match, false otherwise. /// diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index 1119e66398d1d..ba52822a02f8c 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -66,7 +66,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo - **U**: swap the meaning of x* and x*?"# ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct RegexpMatchFunc { signature: Signature, } diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index 3a83564ff11fe..ca3d19822e137 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -24,7 +24,9 @@ use arrow::array::{new_null_array, ArrayIter, AsArray}; use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; use arrow::array::{ArrayAccessor, StringViewArray}; use arrow::datatypes::DataType; -use datafusion_common::cast::as_string_view_array; +use datafusion_common::cast::{ + as_large_string_array, as_string_array, as_string_view_array, +}; use datafusion_common::exec_err; use datafusion_common::plan_err; use datafusion_common::ScalarValue; @@ -82,7 +84,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo - **U**: swap the meaning of x* and x*?"# ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct RegexpReplaceFunc { signature: Signature, } @@ -95,13 +97,12 @@ impl Default for RegexpReplaceFunc { impl RegexpReplaceFunc { pub fn new() -> Self { use DataType::*; + use TypeSignature::*; Self { signature: Signature::one_of( vec![ - TypeSignature::Exact(vec![Utf8, Utf8, Utf8]), - TypeSignature::Exact(vec![Utf8View, Utf8, Utf8]), - TypeSignature::Exact(vec![Utf8, Utf8, Utf8, Utf8]), - TypeSignature::Exact(vec![Utf8View, Utf8, Utf8, Utf8]), + Uniform(3, vec![Utf8View, LargeUtf8, Utf8]), + Uniform(4, vec![Utf8View, LargeUtf8, Utf8]), ], Volatility::Immutable, ), @@ -238,15 +239,14 @@ fn regex_replace_posix_groups(replacement: &str) -> String { /// # Ok(()) /// # } /// ``` -pub fn regexp_replace<'a, T: OffsetSizeTrait, V, B>( - string_array: V, - pattern_array: B, - replacement_array: B, - flags: Option<&ArrayRef>, +pub fn regexp_replace<'a, T: OffsetSizeTrait, U>( + string_array: U, + pattern_array: U, + replacement_array: U, + flags_array: Option, ) -> Result where - V: ArrayAccessor, - B: ArrayAccessor, + U: ArrayAccessor, { // Default implementation for regexp_replace, assumes all args are arrays // and args is a sequence of 3 or 4 elements. @@ -260,7 +260,7 @@ where let pattern_array_iter = ArrayIter::new(pattern_array); let replacement_array_iter = ArrayIter::new(replacement_array); - match flags { + match flags_array { None => { let result_iter = string_array_iter .zip(pattern_array_iter) @@ -307,13 +307,13 @@ where } } } - Some(flags) => { - let flags_array = as_generic_string_array::(flags)?; + Some(flags_array) => { + let flags_array_iter = ArrayIter::new(flags_array); let result_iter = string_array_iter .zip(pattern_array_iter) .zip(replacement_array_iter) - .zip(flags_array.iter()) + .zip(flags_array_iter) .map(|(((string, pattern), replacement), flags)| { match (string, pattern, replacement, flags) { (Some(string), Some(pattern), Some(replacement), Some(flags)) => { @@ -398,12 +398,37 @@ fn _regexp_replace_early_abort( /// Note: If the array is empty or the first argument is null, /// then calls the given early abort function. macro_rules! fetch_string_arg { - ($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident, $ARRAY_SIZE:expr) => {{ - let array = as_generic_string_array::<$T>($ARG)?; - if array.len() == 0 || array.is_null(0) { - return $EARLY_ABORT(array, $ARRAY_SIZE); - } else { - array.value(0) + ($ARG:expr, $NAME:expr, $EARLY_ABORT:ident, $ARRAY_SIZE:expr) => {{ + let string_array_type = ($ARG).data_type(); + match string_array_type { + DataType::Utf8 => { + let array = as_string_array($ARG)?; + if array.len() == 0 || array.is_null(0) { + return $EARLY_ABORT(array, $ARRAY_SIZE); + } else { + array.value(0) + } + } + DataType::LargeUtf8 => { + let array = as_large_string_array($ARG)?; + if array.len() == 0 || array.is_null(0) { + return $EARLY_ABORT(array, $ARRAY_SIZE); + } else { + array.value(0) + } + } + DataType::Utf8View => { + let array = as_string_view_array($ARG)?; + if array.len() == 0 || array.is_null(0) { + return $EARLY_ABORT(array, $ARRAY_SIZE); + } else { + array.value(0) + } + } + _ => unreachable!( + "Invalid data type for regexp_replace: {}", + string_array_type + ), } }}; } @@ -417,23 +442,17 @@ fn _regexp_replace_static_pattern_replace( args: &[ArrayRef], ) -> Result { let array_size = args[0].len(); - let pattern = fetch_string_arg!( - &args[1], - "pattern", - i32, - _regexp_replace_early_abort, - array_size - ); + let pattern = + fetch_string_arg!(&args[1], "pattern", _regexp_replace_early_abort, array_size); let replacement = fetch_string_arg!( &args[2], "replacement", - i32, _regexp_replace_early_abort, array_size ); let flags = match args.len() { 3 => None, - 4 => Some(fetch_string_arg!(&args[3], "flags", i32, _regexp_replace_early_abort, array_size)), + 4 => Some(fetch_string_arg!(&args[3], "flags", _regexp_replace_early_abort, array_size)), other => { return exec_err!( "regexp_replace was called with {other} arguments. It requires at least 3 and at most 4." @@ -590,38 +609,61 @@ pub fn specialize_regexp_replace( .map(|arg| arg.to_array(inferred_length)) .collect::>>()?; - match args[0].data_type() { - DataType::Utf8View => { - let string_array = args[0].as_string_view(); + match ( + args[0].data_type(), + args[1].data_type(), + args[2].data_type(), + args.get(3).map(|a| a.data_type()), + ) { + ( + DataType::Utf8, + DataType::Utf8, + DataType::Utf8, + Some(DataType::Utf8) | None, + ) => { + let string_array = args[0].as_string::(); let pattern_array = args[1].as_string::(); let replacement_array = args[2].as_string::(); - regexp_replace::( + let flags_array = args.get(3).map(|a| a.as_string::()); + regexp_replace::( string_array, pattern_array, replacement_array, - args.get(3), + flags_array, ) } - DataType::Utf8 => { - let string_array = args[0].as_string::(); - let pattern_array = args[1].as_string::(); - let replacement_array = args[2].as_string::(); - regexp_replace::( + ( + DataType::Utf8View, + DataType::Utf8View, + DataType::Utf8View, + Some(DataType::Utf8View) | None, + ) => { + let string_array = args[0].as_string_view(); + let pattern_array = args[1].as_string_view(); + let replacement_array = args[2].as_string_view(); + let flags_array = args.get(3).map(|a| a.as_string_view()); + regexp_replace::( string_array, pattern_array, replacement_array, - args.get(3), + flags_array, ) } - DataType::LargeUtf8 => { + ( + DataType::LargeUtf8, + DataType::LargeUtf8, + DataType::LargeUtf8, + Some(DataType::LargeUtf8) | None, + ) => { let string_array = args[0].as_string::(); let pattern_array = args[1].as_string::(); let replacement_array = args[2].as_string::(); - regexp_replace::( + let flags_array = args.get(3).map(|a| a.as_string::()); + regexp_replace::( string_array, pattern_array, replacement_array, - args.get(3), + flags_array, ) } other => { @@ -650,8 +692,8 @@ mod tests { vec!["afooc", "acd", "afoocd1234567890123", "123456789012afooc"]; let values = <$T>::from(values); - let patterns = StringArray::from(patterns); - let replacements = StringArray::from(replacement); + let patterns = <$T>::from(patterns); + let replacements = <$T>::from(replacement); let expected = <$T>::from(expected); let re = _regexp_replace_static_pattern_replace::<$O>(&[ diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 006492a0e07a1..bdf30833127a2 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -16,7 +16,7 @@ // under the License. use crate::utils::make_scalar_function; -use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array}; +use arrow::array::{ArrayRef, AsArray, Int32Array, StringArrayType}; use arrow::datatypes::DataType; use arrow::error::ArrowError; use datafusion_common::types::logical_string; @@ -30,7 +30,7 @@ use std::sync::Arc; #[user_doc( doc_section(label = "String Functions"), - description = "Returns the Unicode character code of the first character in a string.", + description = "Returns the first Unicode scalar value of a string.", syntax_example = "ascii(str)", sql_example = r#"```sql > select ascii('abc'); @@ -49,7 +49,7 @@ use std::sync::Arc; standard_argument(name = "str", prefix = "String"), related_udf(name = "chr") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct AsciiFunc { signature: Signature, } @@ -87,9 +87,7 @@ impl ScalarUDFImpl for AsciiFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - use DataType::*; - - Ok(Int32) + Ok(DataType::Int32) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -103,19 +101,22 @@ impl ScalarUDFImpl for AsciiFunc { fn calculate_ascii<'a, V>(array: V) -> Result where - V: ArrayAccessor, + V: StringArrayType<'a, Item = &'a str>, { - let iter = ArrayIter::new(array); - let result = iter - .map(|string| { - string.map(|s| { - let mut chars = s.chars(); - chars.next().map_or(0, |v| v as i32) - }) + let values: Vec<_> = (0..array.len()) + .map(|i| { + if array.is_null(i) { + 0 + } else { + let s = array.value(i); + s.chars().next().map_or(0, |c| c as i32) + } }) - .collect::(); + .collect(); + + let array = Int32Array::new(values.into(), array.nulls().cloned()); - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(array)) } /// Returns the numeric code of the first character of the argument. @@ -182,6 +183,9 @@ mod tests { test_ascii!(Some(String::from("x")), Ok(Some(120))); test_ascii!(Some(String::from("a")), Ok(Some(97))); test_ascii!(Some(String::from("")), Ok(Some(0))); + test_ascii!(Some(String::from("🚀")), Ok(Some(128640))); + test_ascii!(Some(String::from("\n")), Ok(Some(10))); + test_ascii!(Some(String::from("\t")), Ok(Some(9))); test_ascii!(None, Ok(None)); Ok(()) } diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index f8740aa4178b4..1578331e57f89 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -45,7 +45,7 @@ use datafusion_macros::user_doc; related_udf(name = "length"), related_udf(name = "octet_length") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct BitLengthFunc { signature: Signature, } diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index 2f1711c9962ad..a7fbdb3c69213 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -65,7 +65,7 @@ fn btrim(args: &[ArrayRef]) -> Result { related_udf(name = "ltrim"), related_udf(name = "rtrim") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct BTrimFunc { signature: Signature, aliases: Vec, diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs index a811de7fccf06..4d2beafbae53a 100644 --- a/datafusion/functions/src/string/chr.rs +++ b/datafusion/functions/src/string/chr.rs @@ -31,7 +31,7 @@ use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; -/// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. +/// Returns the character with the given code. /// chr(65) = 'A' pub fn chr(args: &[ArrayRef]) -> Result { let integer_array = as_int64_array(&args[0])?; @@ -47,20 +47,14 @@ pub fn chr(args: &[ArrayRef]) -> Result { for integer in integer_array { match integer { Some(integer) => { - if integer == 0 { - return exec_err!("null character not permitted."); - } else { - match core::char::from_u32(integer as u32) { - Some(c) => { - builder.append_value(c.encode_utf8(&mut buf)); - } - None => { - return exec_err!( - "requested character too large for encoding." - ); - } + if let Ok(u) = u32::try_from(integer) { + if let Some(c) = core::char::from_u32(u) { + builder.append_value(c.encode_utf8(&mut buf)); + continue; } } + + return exec_err!("invalid Unicode scalar value: {integer}"); } None => { builder.append_null(); @@ -75,7 +69,7 @@ pub fn chr(args: &[ArrayRef]) -> Result { #[user_doc( doc_section(label = "String Functions"), - description = "Returns the character with the specified ASCII or Unicode code value.", + description = "Returns a string containing the character with the specified Unicode scalar value.", syntax_example = "chr(expression)", sql_example = r#"```sql > select chr(128640); @@ -88,7 +82,7 @@ pub fn chr(args: &[ArrayRef]) -> Result { standard_argument(name = "expression", prefix = "String"), related_udf(name = "ascii") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ChrFunc { signature: Signature, } @@ -132,3 +126,116 @@ impl ScalarUDFImpl for ChrFunc { self.doc() } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Array, Int64Array, StringArray}; + use datafusion_common::assert_contains; + + #[test] + fn test_chr_normal() { + let input = Arc::new(Int64Array::from(vec![ + Some(0), // null + Some(65), // A + Some(66), // B + Some(67), // C + Some(128640), // 🚀 + Some(8364), // € + Some(945), // α + None, // NULL + Some(32), // space + Some(10), // newline + Some(9), // tab + Some(0x10FFFF), // 0x10FFFF, the largest Unicode code point + ])); + let result = chr(&[input]).unwrap(); + let string_array = result.as_any().downcast_ref::().unwrap(); + let expected = [ + "\u{0000}", + "A", + "B", + "C", + "🚀", + "€", + "α", + "", + " ", + "\n", + "\t", + "\u{10ffff}", + ]; + + assert_eq!(string_array.len(), expected.len()); + for (i, e) in expected.iter().enumerate() { + assert_eq!(string_array.value(i), *e); + } + } + + #[test] + fn test_chr_error() { + // invalid Unicode code points (too large) + let input = Arc::new(Int64Array::from(vec![i64::MAX])); + let result = chr(&[input]); + assert!(result.is_err()); + assert_contains!( + result.err().unwrap().to_string(), + "invalid Unicode scalar value: 9223372036854775807" + ); + + // invalid Unicode code points (too large) case 2 + let input = Arc::new(Int64Array::from(vec![0x10FFFF + 1])); + let result = chr(&[input]); + assert!(result.is_err()); + assert_contains!( + result.err().unwrap().to_string(), + "invalid Unicode scalar value: 1114112" + ); + + // invalid Unicode code points (surrogate code point) + // link: + let input = Arc::new(Int64Array::from(vec![0xD800 + 1])); + let result = chr(&[input]); + assert!(result.is_err()); + assert_contains!( + result.err().unwrap().to_string(), + "invalid Unicode scalar value: 55297" + ); + + // negative input + let input = Arc::new(Int64Array::from(vec![i64::MIN + 2i64])); // will be 2 if cast to u32 + let result = chr(&[input]); + assert!(result.is_err()); + assert_contains!( + result.err().unwrap().to_string(), + "invalid Unicode scalar value: -9223372036854775806" + ); + + // negative input case 2 + let input = Arc::new(Int64Array::from(vec![-1])); + let result = chr(&[input]); + assert!(result.is_err()); + assert_contains!( + result.err().unwrap().to_string(), + "invalid Unicode scalar value: -1" + ); + + // one error with valid values after + let input = Arc::new(Int64Array::from(vec![65, -1, 66])); // A, -1, B + let result = chr(&[input]); + assert!(result.is_err()); + assert_contains!( + result.err().unwrap().to_string(), + "invalid Unicode scalar value: -1" + ); + } + + #[test] + fn test_chr_empty() { + // empty input array + let input = Arc::new(Int64Array::from(Vec::::new())); + let result = chr(&[input]).unwrap(); + let string_array = result.as_any().downcast_ref::().unwrap(); + assert_eq!(string_array.len(), 0); + } +} diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index c47d08d579e4b..a93e70e714e8b 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -52,7 +52,7 @@ use datafusion_macros::user_doc; ), related_udf(name = "concat_ws") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ConcatFunc { signature: Signature, } @@ -140,7 +140,7 @@ impl ScalarUDFImpl for ConcatFunc { Some(Some(v)) => result.push_str(v), Some(None) => {} // null literal None => plan_err!( - "Concat function does not support scalar type {:?}", + "Concat function does not support scalar type {}", scalar )?, } @@ -295,7 +295,7 @@ pub fn simplify_concat(args: Vec) -> Result { let data_types: Vec<_> = args .iter() .filter_map(|expr| match expr { - Expr::Literal(l) => Some(l.data_type()), + Expr::Literal(l, _) => Some(l.data_type()), _ => None, }) .collect(); @@ -304,25 +304,25 @@ pub fn simplify_concat(args: Vec) -> Result { for arg in args.clone() { match arg { - Expr::Literal(ScalarValue::Utf8(None)) => {} - Expr::Literal(ScalarValue::LargeUtf8(None)) => { + Expr::Literal(ScalarValue::Utf8(None), _) => {} + Expr::Literal(ScalarValue::LargeUtf8(None), _) => { } - Expr::Literal(ScalarValue::Utf8View(None)) => { } + Expr::Literal(ScalarValue::Utf8View(None), _) => { } // filter out `null` args // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. - Expr::Literal(ScalarValue::Utf8(Some(v))) => { + Expr::Literal(ScalarValue::Utf8(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => { + Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(ScalarValue::Utf8View(Some(v))) => { + Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(x) => { + Expr::Literal(x, _) => { return internal_err!( "The scalar {x} should be casted to string type during the type coercion." ) @@ -376,6 +376,8 @@ mod tests { use crate::utils::test::test_function; use arrow::array::{Array, LargeStringArray, StringViewArray}; use arrow::array::{ArrayRef, StringArray}; + use arrow::datatypes::Field; + use datafusion_common::config::ConfigOptions; use DataType::*; #[test] @@ -468,11 +470,23 @@ mod tests { None, Some("b"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + Field::new("a", Utf8View, true), + Field::new("a", Utf8View, true), + ] + .into_iter() + .map(Arc::new) + .collect::>(); let args = ScalarFunctionArgs { args: vec![c0, c1, c2, c3, c4], + arg_fields, number_rows: 3, - return_type: &Utf8, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = ConcatFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index c2bad206db152..cdd30ac8755ab 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -59,7 +59,7 @@ use datafusion_macros::user_doc; ), related_udf(name = "concat") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ConcatWsFunc { signature: Signature, } @@ -312,6 +312,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { match delimiter { // when the delimiter is an empty string, @@ -336,8 +337,8 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result {} - Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v))) => { + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None), _) => {} + Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), _) => { match contiguous_scalar { None => contiguous_scalar = Some(v.to_string()), Some(mut pre) => { @@ -347,7 +348,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result return internal_err!("The scalar {s} should be casted to string type during the type coercion."), + Expr::Literal(s, _) => return internal_err!("The scalar {s} should be casted to string type during the type coercion."), // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` and reset it to None. // Then pushing this arg to the `new_args`. @@ -374,10 +375,11 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::Utf8(None), + None, ))), } } - Expr::Literal(d) => internal_err!( + Expr::Literal(d, _) => internal_err!( "The scalar {d} should be casted to string type during the type coercion." ), _ => { @@ -394,7 +396,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result bool { match expr { - Expr::Literal(v) => v.is_null(), + Expr::Literal(v, _) => v.is_null(), _ => false, } } @@ -403,10 +405,11 @@ fn is_null(expr: &Expr) -> bool { mod tests { use std::sync::Arc; + use crate::string::concat_ws::ConcatWsFunc; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::datatypes::DataType::Utf8; - - use crate::string::concat_ws::ConcatWsFunc; + use arrow::datatypes::Field; + use datafusion_common::config::ConfigOptions; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; @@ -481,10 +484,17 @@ mod tests { Some("z"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![c0, c1, c2], + arg_fields, number_rows: 3, - return_type: &Utf8, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = ConcatWsFunc::new().invoke_with_args(args)?; @@ -511,10 +521,17 @@ mod tests { Some("z"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![c0, c1, c2], + arg_fields, number_rows: 3, - return_type: &Utf8, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = ConcatWsFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 05a3edf61c5ac..7e50676933c8d 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -46,7 +46,7 @@ use std::sync::Arc; standard_argument(name = "str", prefix = "String"), argument(name = "search_str", description = "The string to search for in str.") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ContainsFunc { signature: Signature, } @@ -140,7 +140,7 @@ fn contains(args: &[ArrayRef]) -> Result { } } else { exec_err!( - "Unsupported data type {:?}, {:?} for function `contains`.", + "Unsupported data type {}, {:?} for function `contains`.", args[0].data_type(), args[1].data_type() ) @@ -150,10 +150,12 @@ fn contains(args: &[ArrayRef]) -> Result { #[cfg(test)] mod test { use super::ContainsFunc; + use crate::expr_fn::contains; use arrow::array::{BooleanArray, StringArray}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; #[test] @@ -164,11 +166,17 @@ mod test { Some("yyy?()"), ]))); let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string()))); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("a", DataType::Utf8, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![array, scalar], + arg_fields, number_rows: 2, - return_type: &DataType::Boolean, + return_field: Field::new("f", DataType::Boolean, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let actual = udf.invoke_with_args(args).unwrap(); @@ -181,4 +189,19 @@ mod test { *expect.into_array(2).unwrap() ); } + + #[test] + fn test_contains_api() { + let expr = contains( + Expr::Literal( + ScalarValue::Utf8(Some("the quick brown fox".to_string())), + None, + ), + Expr::Literal(ScalarValue::Utf8(Some("row".to_string())), None), + ); + assert_eq!( + expr.to_string(), + "contains(Utf8(\"the quick brown fox\"), Utf8(\"row\"))" + ); + } } diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs index eafc310236ee3..6090d9c84d4cd 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -52,7 +52,7 @@ use datafusion_macros::user_doc; standard_argument(name = "str", prefix = "String"), argument(name = "substr", description = "Substring to test for.") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct EndsWithFunc { signature: Signature, } diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs index a1a486c7d3cf4..2f7894df903d6 100644 --- a/datafusion/functions/src/string/levenshtein.rs +++ b/datafusion/functions/src/string/levenshtein.rs @@ -57,7 +57,7 @@ use datafusion_macros::user_doc; description = "String expression to compute Levenshtein distance with str1." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct LevenshteinFunc { signature: Signature, } diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index 226275b139991..ee56a6a549857 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -44,7 +44,7 @@ use datafusion_macros::user_doc; related_udf(name = "initcap"), related_udf(name = "upper") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct LowerFunc { signature: Signature, } @@ -98,15 +98,21 @@ impl ScalarUDFImpl for LowerFunc { mod tests { use super::*; use arrow::array::{Array, ArrayRef, StringArray}; + use arrow::datatypes::DataType::Utf8; + use arrow::datatypes::Field; + use datafusion_common::config::ConfigOptions; use std::sync::Arc; fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = LowerFunc::new(); + let arg_fields = vec![Field::new("a", input.data_type().clone(), true).into()]; let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - return_type: &DataType::Utf8, + arg_fields, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = match func.invoke_with_args(args)? { diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index 65849202efc66..dc6d30d38188c 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -71,7 +71,7 @@ fn ltrim(args: &[ArrayRef]) -> Result { related_udf(name = "btrim"), related_udf(name = "rtrim") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct LtrimFunc { signature: Signature, } diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 4c59e2644456e..b4a026db9f894 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -140,7 +140,8 @@ pub mod expr_fn { "returns uuid v4 as a string value", ), ( contains, - "Return true if search_string is found within string.", + "Return true if `search_string` is found within `string`.", + string search_string )); #[doc = "Removes all characters, spaces by default, from both sides of a string"] diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index 17ea2726b071e..aa8257ef8fc53 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -45,7 +45,7 @@ use datafusion_macros::user_doc; related_udf(name = "bit_length"), related_udf(name = "length") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct OctetLengthFunc { signature: Signature, } diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 2d36cb8356a00..3f6128b6516b9 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -51,7 +51,7 @@ use datafusion_macros::user_doc; description = "Number of times to repeat the input string." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct RepeatFunc { signature: Signature, } diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index de70215c49c77..f127b452b2d34 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -52,7 +52,7 @@ use datafusion_macros::user_doc; ), standard_argument(name = "replacement", prefix = "Replacement substring") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ReplaceFunc { signature: Signature, } @@ -145,7 +145,7 @@ impl ScalarUDFImpl for ReplaceFunc { } } else { exec_err!( - "Unsupported data type {:?}, {:?}, {:?} for function replace.", + "Unsupported data type {}, {:?}, {:?} for function replace.", data_types[0], data_types[1], data_types[2] diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index bb33274978daf..be0595f65542a 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -71,7 +71,7 @@ fn rtrim(args: &[ArrayRef]) -> Result { related_udf(name = "btrim"), related_udf(name = "ltrim") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct RtrimFunc { signature: Signature, } diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 724d9c278cca5..8462dd5149cbf 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -47,7 +47,7 @@ use std::sync::Arc; argument(name = "delimiter", description = "String or character to split on."), argument(name = "pos", description = "Position of the part to return.") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct SplitPartFunc { signature: Signature, } diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 71df83352f96c..c4159cba86f34 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -74,7 +74,7 @@ fn starts_with(args: &[ArrayRef]) -> Result { standard_argument(name = "str", prefix = "String"), argument(name = "substr", description = "Substring to test for.") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct StartsWithFunc { signature: Signature, } @@ -130,7 +130,7 @@ impl ScalarUDFImpl for StartsWithFunc { args: Vec, info: &dyn SimplifyInfo, ) -> Result { - if let Expr::Literal(scalar_value) = &args[1] { + if let Expr::Literal(scalar_value, _) = &args[1] { // Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping // Example: starts_with(col, 'ja%') -> col LIKE 'ja\%%' // 1. 'ja%' (input pattern) @@ -141,8 +141,8 @@ impl ScalarUDFImpl for StartsWithFunc { | ScalarValue::LargeUtf8(Some(pattern)) | ScalarValue::Utf8View(Some(pattern)) => { let escaped_pattern = pattern.replace("%", "\\%"); - let like_pattern = format!("{}%", escaped_pattern); - Expr::Literal(ScalarValue::Utf8(Some(like_pattern))) + let like_pattern = format!("{escaped_pattern}%"); + Expr::Literal(ScalarValue::Utf8(Some(like_pattern)), None) } _ => return Ok(ExprSimplifyResult::Original(args)), }; diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index a3a1acfcf1f05..26be0066c2df3 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -19,25 +19,29 @@ use std::any::Any; use std::fmt::Write; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringBuilder, OffsetSizeTrait}; +use crate::utils::make_scalar_function; +use arrow::array::{ArrayRef, GenericStringBuilder}; +use arrow::datatypes::DataType::{ + Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, Utf8, +}; use arrow::datatypes::{ - ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type, + ArrowNativeType, ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int64Type, + Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; - -use crate::utils::make_scalar_function; use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr_common::signature::TypeSignature::Exact; use datafusion_macros::user_doc; /// Converts the number to its equivalent hexadecimal representation. /// to_hex(2147483647) = '7fffffff' pub fn to_hex(args: &[ArrayRef]) -> Result where - T::Native: OffsetSizeTrait, + T::Native: std::fmt::LowerHex, { let integer_array = as_primitive_array::(&args[0])?; @@ -83,7 +87,7 @@ where ```"#, standard_argument(name = "int", prefix = "Integer") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ToHexFunc { signature: Signature, } @@ -96,9 +100,20 @@ impl Default for ToHexFunc { impl ToHexFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + signature: Signature::one_of( + vec![ + Exact(vec![Int8]), + Exact(vec![Int16]), + Exact(vec![Int32]), + Exact(vec![Int64]), + Exact(vec![UInt8]), + Exact(vec![UInt16]), + Exact(vec![UInt32]), + Exact(vec![UInt64]), + ], + Volatility::Immutable, + ), } } } @@ -117,10 +132,8 @@ impl ScalarUDFImpl for ToHexFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - Int8 | Int16 | Int32 | Int64 => Utf8, + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => Utf8, _ => { return plan_err!("The to_hex function can only accept integers."); } @@ -129,12 +142,14 @@ impl ScalarUDFImpl for ToHexFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { match args.args[0].data_type() { - DataType::Int32 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::Int64 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } + Int64 => make_scalar_function(to_hex::, vec![])(&args.args), + UInt64 => make_scalar_function(to_hex::, vec![])(&args.args), + Int32 => make_scalar_function(to_hex::, vec![])(&args.args), + UInt32 => make_scalar_function(to_hex::, vec![])(&args.args), + Int16 => make_scalar_function(to_hex::, vec![])(&args.args), + UInt16 => make_scalar_function(to_hex::, vec![])(&args.args), + Int8 => make_scalar_function(to_hex::, vec![])(&args.args), + UInt8 => make_scalar_function(to_hex::, vec![])(&args.args), other => exec_err!("Unsupported data type {other:?} for function to_hex"), } } @@ -146,48 +161,92 @@ impl ScalarUDFImpl for ToHexFunc { #[cfg(test)] mod tests { - use arrow::array::{Int32Array, StringArray}; - + use arrow::array::{ + Int16Array, Int32Array, Int64Array, Int8Array, StringArray, UInt16Array, + UInt32Array, UInt64Array, UInt8Array, + }; use datafusion_common::cast::as_string_array; use super::*; - #[test] - // Test to_hex function for zero - fn to_hex_zero() -> Result<()> { - let array = vec![0].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("0")]); - assert_eq!(&expected, hex_value); - - Ok(()) + macro_rules! test_to_hex_type { + // Default test with standard input/output + ($name:ident, $arrow_type:ty, $array_type:ty) => { + test_to_hex_type!( + $name, + $arrow_type, + $array_type, + vec![Some(100), Some(0), None], + vec![Some("64"), Some("0"), None] + ); + }; + + // Custom test with custom input/output (eg: positive number) + ($name:ident, $arrow_type:ty, $array_type:ty, $input:expr, $expected:expr) => { + #[test] + fn $name() -> Result<()> { + let input = $input; + let expected = $expected; + + let array = <$array_type>::from(input); + let array_ref = Arc::new(array); + let hex_result = to_hex::<$arrow_type>(&[array_ref])?; + let hex_array = as_string_array(&hex_result)?; + let expected_array = StringArray::from(expected); + + assert_eq!(&expected_array, hex_array); + Ok(()) + } + }; } - #[test] - // Test to_hex function for positive number - fn to_hex_positive_number() -> Result<()> { - let array = vec![100].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("64")]); - assert_eq!(&expected, hex_value); - - Ok(()) - } + test_to_hex_type!( + to_hex_int8, + Int8Type, + Int8Array, + vec![Some(100), Some(0), None, Some(-1)], + vec![Some("64"), Some("0"), None, Some("ffffffffffffffff")] + ); + test_to_hex_type!( + to_hex_int16, + Int16Type, + Int16Array, + vec![Some(100), Some(0), None, Some(-1)], + vec![Some("64"), Some("0"), None, Some("ffffffffffffffff")] + ); + test_to_hex_type!( + to_hex_int32, + Int32Type, + Int32Array, + vec![Some(100), Some(0), None, Some(-1)], + vec![Some("64"), Some("0"), None, Some("ffffffffffffffff")] + ); + test_to_hex_type!( + to_hex_int64, + Int64Type, + Int64Array, + vec![Some(100), Some(0), None, Some(-1)], + vec![Some("64"), Some("0"), None, Some("ffffffffffffffff")] + ); - #[test] - // Test to_hex function for negative number - fn to_hex_negative_number() -> Result<()> { - let array = vec![-1].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("ffffffffffffffff")]); - assert_eq!(&expected, hex_value); - - Ok(()) - } + test_to_hex_type!(to_hex_uint8, UInt8Type, UInt8Array); + test_to_hex_type!(to_hex_uint16, UInt16Type, UInt16Array); + test_to_hex_type!(to_hex_uint32, UInt32Type, UInt32Array); + test_to_hex_type!(to_hex_uint64, UInt64Type, UInt64Array); + + test_to_hex_type!( + to_hex_large_signed, + Int64Type, + Int64Array, + vec![Some(i64::MAX), Some(i64::MIN)], + vec![Some("7fffffffffffffff"), Some("8000000000000000")] + ); + + test_to_hex_type!( + to_hex_large_unsigned, + UInt64Type, + UInt64Array, + vec![Some(u64::MAX), Some(u64::MIN)], + vec![Some("ffffffffffffffff"), Some("0")] + ); } diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 2fec7305d1839..8bb2ec1d511cd 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -43,7 +43,7 @@ use std::any::Any; related_udf(name = "initcap"), related_udf(name = "lower") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct UpperFunc { signature: Signature, } @@ -97,15 +97,21 @@ impl ScalarUDFImpl for UpperFunc { mod tests { use super::*; use arrow::array::{Array, ArrayRef, StringArray}; + use arrow::datatypes::DataType::Utf8; + use arrow::datatypes::Field; + use datafusion_common::config::ConfigOptions; use std::sync::Arc; fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = UpperFunc::new(); + let arg_field = Field::new("a", input.data_type().clone(), true).into(); let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - return_type: &DataType::Utf8, + arg_fields: vec![arg_field], + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), }; let result = match func.invoke_with_args(args)? { diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index d1f43d5480660..a5ad6db5354f3 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -42,7 +42,7 @@ use datafusion_macros::user_doc; +--------------------------------------+ ```"# )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct UuidFunc { signature: Signature, } @@ -86,7 +86,7 @@ impl ScalarUDFImpl for UuidFunc { } // Generate random u128 values - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let mut randoms = vec![0u128; args.number_rows]; rng.fill(&mut randoms[..]); diff --git a/datafusion/functions/src/strings.rs b/datafusion/functions/src/strings.rs index 6299b353d57a9..108c20e136670 100644 --- a/datafusion/functions/src/strings.rs +++ b/datafusion/functions/src/strings.rs @@ -18,47 +18,12 @@ use std::mem::size_of; use arrow::array::{ - make_view, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, ByteView, - GenericStringArray, LargeStringArray, NullBufferBuilder, OffsetSizeTrait, - StringArray, StringViewArray, StringViewBuilder, + make_view, Array, ArrayAccessor, ArrayDataBuilder, ByteView, LargeStringArray, + NullBufferBuilder, StringArray, StringViewArray, StringViewBuilder, }; use arrow::buffer::{MutableBuffer, NullBuffer}; use arrow::datatypes::DataType; -/// Abstracts iteration over different types of string arrays. -#[deprecated(since = "45.0.0", note = "Use arrow::array::StringArrayType instead")] -pub trait StringArrayType<'a>: ArrayAccessor + Sized { - /// Return an [`ArrayIter`] over the values of the array. - /// - /// This iterator iterates returns `Option<&str>` for each item in the array. - fn iter(&self) -> ArrayIter; - - /// Check if the array is ASCII only. - fn is_ascii(&self) -> bool; -} - -#[allow(deprecated)] -impl<'a, T: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray { - fn iter(&self) -> ArrayIter { - GenericStringArray::::iter(self) - } - - fn is_ascii(&self) -> bool { - GenericStringArray::::is_ascii(self) - } -} - -#[allow(deprecated)] -impl<'a> StringArrayType<'a> for &'a StringViewArray { - fn iter(&self) -> ArrayIter { - StringViewArray::iter(self) - } - - fn is_ascii(&self) -> bool { - StringViewArray::is_ascii(self) - } -} - /// Optimized version of the StringBuilder in Arrow that: /// 1. Precalculating the expected length of the result, avoiding reallocations. /// 2. Avoids creating / incrementally creating a `NullBufferBuilder` diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index c2db253dc7419..85fe0956a951b 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -17,7 +17,7 @@ use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ - Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveBuilder, + Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveArray, StringArrayType, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; @@ -45,7 +45,7 @@ use std::sync::Arc; related_udf(name = "bit_length"), related_udf(name = "octet_length") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct CharacterLengthFunc { signature: Signature, aliases: Vec, @@ -131,46 +131,45 @@ where T::Native: OffsetSizeTrait, V: StringArrayType<'a>, { - let mut builder = PrimitiveBuilder::::with_capacity(array.len()); - // String characters are variable length encoded in UTF-8, counting the // number of chars requires expensive decoding, however checking if the // string is ASCII only is relatively cheap. // If strings are ASCII only, count bytes instead. let is_array_ascii_only = array.is_ascii(); - if array.null_count() == 0 { + let nulls = array.nulls().cloned(); + let array = { if is_array_ascii_only { - for i in 0..array.len() { - let value = array.value(i); - builder.append_value(T::Native::usize_as(value.len())); - } + let values: Vec<_> = (0..array.len()) + .map(|i| { + // Safety: we are iterating with array.len() so the index is always valid + let value = unsafe { array.value_unchecked(i) }; + T::Native::usize_as(value.len()) + }) + .collect(); + PrimitiveArray::::new(values.into(), nulls) } else { - for i in 0..array.len() { - let value = array.value(i); - builder.append_value(T::Native::usize_as(value.chars().count())); - } - } - } else if is_array_ascii_only { - for i in 0..array.len() { - if array.is_null(i) { - builder.append_null(); - } else { - let value = array.value(i); - builder.append_value(T::Native::usize_as(value.len())); - } + let values: Vec<_> = (0..array.len()) + .map(|i| { + // Safety: we are iterating with array.len() so the index is always valid + if array.is_null(i) { + T::default_value() + } else { + let value = unsafe { array.value_unchecked(i) }; + if value.is_empty() { + T::default_value() + } else if value.is_ascii() { + T::Native::usize_as(value.len()) + } else { + T::Native::usize_as(value.chars().count()) + } + } + }) + .collect(); + PrimitiveArray::::new(values.into(), nulls) } - } else { - for i in 0..array.len() { - if array.is_null(i) { - builder.append_null(); - } else { - let value = array.value(i); - builder.append_value(T::Native::usize_as(value.chars().count())); - } - } - } + }; - Ok(Arc::new(builder.finish()) as ArrayRef) + Ok(Arc::new(array)) } #[cfg(test)] diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index c4a9f067e9f4f..fa68e539600b0 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -53,7 +53,7 @@ use datafusion_macros::user_doc; description = "A string list is a string composed of substrings separated by , characters." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct FindInSetFunc { signature: Signature, } @@ -348,7 +348,8 @@ mod tests { use crate::unicode::find_in_set::FindInSetFunc; use crate::utils::test::test_function; use arrow::array::{Array, Int32Array, StringArray}; - use arrow::datatypes::DataType::Int32; + use arrow::datatypes::{DataType::Int32, Field}; + use datafusion_common::config::ConfigOptions; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; @@ -471,10 +472,19 @@ mod tests { }) .unwrap_or(1); let return_type = fis.return_type(&type_array)?; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, a)| { + Field::new(format!("arg_{idx}"), a.data_type(), true).into() + }) + .collect::>(); let result = fis.invoke_with_args(ScalarFunctionArgs { args, + arg_fields, number_rows: cardinality, - return_type: &return_type, + return_field: Field::new("f", return_type, true).into(), + config_options: Arc::new(ConfigOptions::default()), }); assert!(result.is_ok()); diff --git a/datafusion/functions/src/unicode/initcap.rs b/datafusion/functions/src/unicode/initcap.rs index c9b0cb77b0969..62862fbe78980 100644 --- a/datafusion/functions/src/unicode/initcap.rs +++ b/datafusion/functions/src/unicode/initcap.rs @@ -50,7 +50,7 @@ use datafusion_macros::user_doc; related_udf(name = "lower"), related_udf(name = "upper") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct InitcapFunc { signature: Signature, } @@ -131,10 +131,11 @@ fn initcap(args: &[ArrayRef]) -> Result { string_array.value_data().len(), ); + let mut container = String::new(); string_array.iter().for_each(|str| match str { Some(s) => { - let initcap_str = initcap_string(s); - builder.append_value(initcap_str); + initcap_string(s, &mut container); + builder.append_value(&container); } None => builder.append_null(), }); @@ -147,10 +148,11 @@ fn initcap_utf8view(args: &[ArrayRef]) -> Result { let mut builder = StringViewBuilder::with_capacity(string_view_array.len()); + let mut container = String::new(); string_view_array.iter().for_each(|str| match str { Some(s) => { - let initcap_str = initcap_string(s); - builder.append_value(initcap_str); + initcap_string(s, &mut container); + builder.append_value(&container); } None => builder.append_null(), }); @@ -158,31 +160,29 @@ fn initcap_utf8view(args: &[ArrayRef]) -> Result { Ok(Arc::new(builder.finish()) as ArrayRef) } -fn initcap_string(input: &str) -> String { - let mut result = String::with_capacity(input.len()); +fn initcap_string(input: &str, container: &mut String) { + container.clear(); let mut prev_is_alphanumeric = false; if input.is_ascii() { for c in input.chars() { if prev_is_alphanumeric { - result.push(c.to_ascii_lowercase()); + container.push(c.to_ascii_lowercase()); } else { - result.push(c.to_ascii_uppercase()); + container.push(c.to_ascii_uppercase()); }; prev_is_alphanumeric = c.is_ascii_alphanumeric(); } } else { for c in input.chars() { if prev_is_alphanumeric { - result.extend(c.to_lowercase()); + container.extend(c.to_lowercase()); } else { - result.extend(c.to_uppercase()); + container.extend(c.to_uppercase()); } prev_is_alphanumeric = c.is_alphanumeric(); } } - - result } #[cfg(test)] diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs index f99f0de67ebb2..fceb2a131a2b0 100644 --- a/datafusion/functions/src/unicode/left.rs +++ b/datafusion/functions/src/unicode/left.rs @@ -53,7 +53,7 @@ use datafusion_macros::user_doc; argument(name = "n", description = "Number of characters to return."), related_udf(name = "right") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct LeftFunc { signature: Signature, } diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index ea57dbd2bed51..621dbd4970f26 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -56,7 +56,7 @@ use datafusion_macros::user_doc; ), related_udf(name = "rpad") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct LPadFunc { signature: Signature, } @@ -204,11 +204,15 @@ where V2: StringArrayType<'a>, T: OffsetSizeTrait, { - let array = if fill_array.is_none() { + let array = if let Some(fill_array) = fill_array { let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - for (string, length) in string_array.iter().zip(length_array.iter()) { - if let (Some(string), Some(length)) = (string, length) { + for ((string, length), fill) in string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.iter()) + { + if let (Some(string), Some(length), Some(fill)) = (string, length, fill) { if length > i32::MAX as i64 { return exec_err!("lpad requested length {length} too large"); } @@ -220,10 +224,17 @@ where } let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); + if length < graphemes.len() { builder.append_value(graphemes[..length].concat()); + } else if fill_chars.is_empty() { + builder.append_value(string); } else { - builder.write_str(" ".repeat(length - graphemes.len()).as_str())?; + for l in 0..length - graphemes.len() { + let c = *fill_chars.get(l % fill_chars.len()).unwrap(); + builder.write_char(c)?; + } builder.write_str(string)?; builder.append_value(""); } @@ -236,12 +247,8 @@ where } else { let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - for ((string, length), fill) in string_array - .iter() - .zip(length_array.iter()) - .zip(fill_array.unwrap().iter()) - { - if let (Some(string), Some(length), Some(fill)) = (string, length, fill) { + for (string, length) in string_array.iter().zip(length_array.iter()) { + if let (Some(string), Some(length)) = (string, length) { if length > i32::MAX as i64 { return exec_err!("lpad requested length {length} too large"); } @@ -253,17 +260,10 @@ where } let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - if length < graphemes.len() { builder.append_value(graphemes[..length].concat()); - } else if fill_chars.is_empty() { - builder.append_value(string); } else { - for l in 0..length - graphemes.len() { - let c = *fill_chars.get(l % fill_chars.len()).unwrap(); - builder.write_char(c)?; - } + builder.write_str(" ".repeat(length - graphemes.len()).as_str())?; builder.write_str(string)?; builder.append_value(""); } diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index 3c5cde3789ea2..4a0dd21d749af 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -26,6 +26,7 @@ pub mod find_in_set; pub mod initcap; pub mod left; pub mod lpad; +pub mod planner; pub mod reverse; pub mod right; pub mod rpad; diff --git a/datafusion/functions/src/unicode/planner.rs b/datafusion/functions/src/unicode/planner.rs new file mode 100644 index 0000000000000..e4f29be3d13dc --- /dev/null +++ b/datafusion/functions/src/unicode/planner.rs @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SQL planning extensions like [`UnicodeFunctionPlanner`] + +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::planner::{ExprPlanner, PlannerResult}; +use datafusion_expr::Expr; + +#[derive(Default, Debug)] +pub struct UnicodeFunctionPlanner; + +impl ExprPlanner for UnicodeFunctionPlanner { + fn plan_position( + &self, + args: Vec, + ) -> datafusion_common::Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::unicode::strpos(), args), + ))) + } + + fn plan_substring( + &self, + args: Vec, + ) -> datafusion_common::Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::unicode::substr(), args), + ))) + } +} diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs index 311e9e81a8be9..500e762ec250b 100644 --- a/datafusion/functions/src/unicode/reverse.rs +++ b/datafusion/functions/src/unicode/reverse.rs @@ -44,7 +44,7 @@ use DataType::{LargeUtf8, Utf8, Utf8View}; ```"#, standard_argument(name = "str", prefix = "String") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ReverseFunc { signature: Signature, } diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs index 1ceaf69983311..c492f606e9c5b 100644 --- a/datafusion/functions/src/unicode/right.rs +++ b/datafusion/functions/src/unicode/right.rs @@ -53,7 +53,7 @@ use datafusion_macros::user_doc; argument(name = "n", description = "Number of characters to return."), related_udf(name = "left") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct RightFunc { signature: Signature, } diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index c68c4d329c74d..6ec78b07980b8 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -55,7 +55,7 @@ use DataType::{LargeUtf8, Utf8, Utf8View}; ), related_udf(name = "lpad") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct RPadFunc { signature: Signature, } diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index b3bc73a295852..4f238b2644bdf 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -22,7 +22,9 @@ use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType, }; -use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type, +}; use datafusion_common::types::logical_string; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::{ @@ -47,7 +49,7 @@ use datafusion_macros::user_doc; standard_argument(name = "str", prefix = "String"), argument(name = "substr", description = "Substring expression to search for.") )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct StrposFunc { signature: Signature, aliases: Vec, @@ -88,16 +90,23 @@ impl ScalarUDFImpl for StrposFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be used instead") + internal_err!("return_field_from_args should be used instead") } - fn return_type_from_args( + fn return_field_from_args( &self, - args: datafusion_expr::ReturnTypeArgs, - ) -> Result { - utf8_to_int_type(&args.arg_types[0], "strpos/instr/position").map(|data_type| { - datafusion_expr::ReturnInfo::new(data_type, args.nullables.iter().any(|x| *x)) - }) + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + utf8_to_int_type(args.arg_fields[0].data_type(), "strpos/instr/position").map( + |data_type| { + Field::new( + self.name(), + data_type, + args.arg_fields.iter().any(|x| x.is_nullable()), + ) + .into() + }, + ) } fn invoke_with_args( @@ -228,7 +237,7 @@ mod tests { use arrow::array::{Array, Int32Array, Int64Array}; use arrow::datatypes::DataType::{Int32, Int64}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -321,15 +330,15 @@ mod tests { fn nullable_return_type() { fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool { let strpos = StrposFunc::new(); - let args = datafusion_expr::ReturnTypeArgs { - arg_types: &[DataType::Utf8, DataType::Utf8], - nullables: &[string_array_nullable, substring_nullable], + let args = datafusion_expr::ReturnFieldArgs { + arg_fields: &[ + Field::new("f1", DataType::Utf8, string_array_nullable).into(), + Field::new("f2", DataType::Utf8, substring_nullable).into(), + ], scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>], }; - let (_, nullable) = strpos.return_type_from_args(args).unwrap().into_parts(); - - nullable + strpos.return_field_from_args(args).unwrap().is_nullable() } assert!(!get_nullable(false, false)); diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 4dcbea4807f44..0b35f664532d4 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -56,7 +56,7 @@ use datafusion_macros::user_doc; description = "Number of characters to extract. If not specified, returns the rest of the string after the start position." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct SubstrFunc { signature: Signature, aliases: Vec, diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index 9a18b5d23c5ee..a7ee7388f9013 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -62,7 +62,7 @@ If count is negative, everything to the right of the final delimiter (counting f description = "The number of times to search for the delimiter. Can be either a positive or negative number." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct SubstrIndexFunc { signature: Signature, aliases: Vec, diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index 8b4894643a7a3..911b8d311996e 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -52,7 +52,7 @@ use datafusion_macros::user_doc; description = "Translation characters. Translation characters replace only characters at the same position in the **chars** string." ) )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct TranslateFunc { signature: Signature, } diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 47f3121ba2ce9..932d61e8007cd 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::ArrayRef; +use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray}; +use arrow::compute::try_binary; use arrow::datatypes::DataType; - -use datafusion_common::{Result, ScalarValue}; +use arrow::error::ArrowError; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::function::Hint; use datafusion_expr::ColumnarValue; +use std::sync::Arc; /// Creates a function to identify the optimal return type of a string function given /// the type of its first argument. @@ -75,7 +77,7 @@ get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); /// Creates a scalar function implementation for the given function. /// * `inner` - the function to be executed /// * `hints` - hints to be used when expanding scalars to arrays -pub(super) fn make_scalar_function( +pub fn make_scalar_function( inner: F, hints: Vec, ) -> impl Fn(&[ColumnarValue]) -> Result @@ -120,6 +122,76 @@ where } } +/// Computes a binary math function for input arrays using a specified function. +/// Generic types: +/// - `L`: Left array primitive type +/// - `R`: Right array primitive type +/// - `O`: Output array primitive type +/// - `F`: Functor computing `fun(l: L, r: R) -> Result` +pub fn calculate_binary_math( + left: &dyn Array, + right: &ColumnarValue, + fun: F, +) -> Result>> +where + R: ArrowPrimitiveType, + L: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(L::Native, R::Native) -> Result, + R::Native: TryFrom, +{ + Ok(match right { + ColumnarValue::Scalar(scalar) => { + let right_value: R::Native = + R::Native::try_from(scalar.clone()).map_err(|_| { + DataFusionError::NotImplemented(format!( + "Cannot convert scalar value {} to {}", + &scalar, + R::DATA_TYPE + )) + })?; + let left_array = left.as_primitive::(); + // Bind right value + let result = + left_array.try_unary::<_, O, _>(|lvalue| fun(lvalue, right_value))?; + Arc::new(result) as _ + } + ColumnarValue::Array(right) => { + let right_casted = arrow::compute::cast(&right, &R::DATA_TYPE)?; + let right_array = right_casted.as_primitive::(); + + // Types are compatible even they are decimals with different scale or precision + let result = if PrimitiveArray::::is_compatible(&L::DATA_TYPE) { + let left_array = left.as_primitive::(); + try_binary::<_, _, _, O>(left_array, right_array, &fun)? + } else { + let left_casted = arrow::compute::cast(left, &L::DATA_TYPE)?; + let left_array = left_casted.as_primitive::(); + try_binary::<_, _, _, O>(left_array, right_array, &fun)? + }; + Arc::new(result) as _ + } + }) +} + +/// Converts Decimal128 components (value and scale) to an unscaled i128 +pub fn decimal128_to_i128(value: i128, scale: i8) -> Result { + if scale < 0 { + Err(ArrowError::ComputeError( + "Negative scale is not supported".into(), + )) + } else if scale == 0 { + Ok(value) + } else { + match i128::from(10).checked_pow(scale as u32) { + Some(divisor) => Ok(value / divisor), + None => Err(ArrowError::ComputeError(format!( + "Cannot get a power of {scale}" + ))), + } + } +} + #[cfg(test)] pub mod test { /// $FUNC ScalarUDFImpl to test @@ -128,19 +200,20 @@ pub mod test { /// $EXPECTED_TYPE is the expected value type /// $EXPECTED_DATA_TYPE is the expected result type /// $ARRAY_TYPE is the column type after function applied + /// $CONFIG_OPTIONS config options to pass to function macro_rules! test_function { - ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => { - let expected: Result> = $EXPECTED; - let func = $FUNC; - - let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); - let cardinality = $ARGS - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }) - .unwrap_or(1); + ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident, $CONFIG_OPTIONS:expr) => { + let expected: Result> = $EXPECTED; + let func = $FUNC; + + let data_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let cardinality = $ARGS + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }) + .unwrap_or(1); let scalar_arguments = $ARGS.iter().map(|arg| match arg { ColumnarValue::Scalar(scalar) => Some(scalar.clone()), @@ -153,51 +226,83 @@ pub mod test { ColumnarValue::Array(a) => a.null_count() > 0, }).collect::>(); - let return_info = func.return_type_from_args(datafusion_expr::ReturnTypeArgs { - arg_types: &type_array, - scalar_arguments: &scalar_arguments_refs, - nullables: &nullables - }); + let field_array = data_array.into_iter().zip(nullables).enumerate() + .map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable)) + .map(std::sync::Arc::new) + .collect::>(); + + let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { + arg_fields: &field_array, + scalar_arguments: &scalar_arguments_refs, + }); + let arg_fields = $ARGS.iter() + .enumerate() + .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .collect::>(); - match expected { - Ok(expected) => { - assert_eq!(return_info.is_ok(), true); - let (return_type, _nullable) = return_info.unwrap().into_parts(); - assert_eq!(return_type, $EXPECTED_DATA_TYPE); + match expected { + Ok(expected) => { + assert_eq!(return_field.is_ok(), true); + let return_field = return_field.unwrap(); + let return_type = return_field.data_type(); + assert_eq!(return_type, &$EXPECTED_DATA_TYPE); - let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}); + let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{ + args: $ARGS, + arg_fields, + number_rows: cardinality, + return_field, + config_options: $CONFIG_OPTIONS + }); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array"); let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); - assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE); + assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE); - // value is correct - match expected { - Some(v) => assert_eq!(result.value(0), v), - None => assert!(result.is_null(0)), - }; - } - Err(expected_error) => { - if return_info.is_err() { - match return_info { - Ok(_) => assert!(false, "expected error"), - Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } - } - } - else { - let (return_type, _nullable) = return_info.unwrap().into_parts(); - - // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}) { - Ok(_) => assert!(false, "expected error"), - Err(error) => { - assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); - } + // value is correct + match expected { + Some(v) => assert_eq!(result.value(0), v), + None => assert!(result.is_null(0)), + }; + } + Err(expected_error) => { + if let Ok(return_field) = return_field { + // invoke is expected error - cannot use .expect_err() due to Debug not being implemented + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs { + args: $ARGS, + arg_fields, + number_rows: cardinality, + return_field, + config_options: $CONFIG_OPTIONS, + }) { + Ok(_) => assert!(false, "expected error"), + Err(error) => { + assert!(expected_error + .strip_backtrace() + .starts_with(&error.strip_backtrace())); } } + } else if let Err(error) = return_field { + datafusion_common::assert_contains!( + expected_error.strip_backtrace(), + error.strip_backtrace() + ); } - }; + } + }; + }; + + ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => { + test_function!( + $FUNC, + $ARGS, + $EXPECTED, + $EXPECTED_TYPE, + $EXPECTED_DATA_TYPE, + $ARRAY_TYPE, + std::sync::Arc::new(datafusion_common::config::ConfigOptions::default()) + ) }; } @@ -218,4 +323,31 @@ pub mod test { let v = utf8_to_int_type(&DataType::LargeUtf8, "test").unwrap(); assert_eq!(v, DataType::Int64); } + + #[test] + fn test_decimal128_to_i128() { + let cases = [ + (123, 0, Some(123)), + (1230, 1, Some(123)), + (123000, 3, Some(123)), + (1, 0, Some(1)), + (123, -3, None), + (123, i8::MAX, None), + (i128::MAX, 0, Some(i128::MAX)), + (i128::MAX, 3, Some(i128::MAX / 1000)), + ]; + + for (value, scale, expected) in cases { + match decimal128_to_i128(value, scale) { + Ok(actual) => { + assert_eq!( + actual, + expected.expect("Got value but expected none"), + "{value} and {scale} vs {expected:?}" + ); + } + Err(_) => assert!(expected.is_none()), + } + } + } } diff --git a/datafusion/macros/Cargo.toml b/datafusion/macros/Cargo.toml index c6532aa046810..fe979720bc566 100644 --- a/datafusion/macros/Cargo.toml +++ b/datafusion/macros/Cargo.toml @@ -19,6 +19,7 @@ name = "datafusion-macros" description = "Procedural macros for DataFusion query engine" keywords = ["datafusion", "query", "sql"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } homepage = { workspace = true } @@ -40,6 +41,6 @@ path = "src/user_doc.rs" proc-macro = true [dependencies] -datafusion-expr = { workspace = true } -quote = "1.0.40" -syn = { version = "2.0.100", features = ["full"] } +datafusion-doc = { workspace = true } +quote = "1.0.41" +syn = { version = "2.0.106", features = ["full"] } diff --git a/datafusion/macros/README.md b/datafusion/macros/README.md new file mode 100644 index 0000000000000..c45bba1423fc2 --- /dev/null +++ b/datafusion/macros/README.md @@ -0,0 +1,30 @@ + + +# Apache DataFusion Macros + +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. + +This crate contains common macros used in DataFusion + +Most projects should use the [`datafusion`] crate directly. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/macros/src/user_doc.rs b/datafusion/macros/src/user_doc.rs index c6510c1564232..71ce381ec4318 100644 --- a/datafusion/macros/src/user_doc.rs +++ b/datafusion/macros/src/user_doc.rs @@ -19,10 +19,10 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] extern crate proc_macro; -use datafusion_expr::scalar_doc_sections::doc_sections_const; +use datafusion_doc::scalar_doc_sections::doc_sections_const; use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, DeriveInput, LitStr}; @@ -206,7 +206,7 @@ pub fn user_doc(args: TokenStream, input: TokenStream) -> TokenStream { }; let doc_section_description = doc_section_desc .map(|desc| quote! { Some(#desc)}) - .unwrap_or(quote! { None }); + .unwrap_or_else(|| quote! { None }); let sql_example = sql_example.map(|ex| { quote! { diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 61d101aab3f8e..f10510e0973c3 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -45,16 +45,18 @@ arrow = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-expr-common = { workspace = true } datafusion-physical-expr = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } log = { workspace = true } recursive = { workspace = true, optional = true } regex = { workspace = true } -regex-syntax = "0.8.0" +regex-syntax = "0.8.6" [dev-dependencies] async-trait = { workspace = true } +criterion = { workspace = true } ctor = { workspace = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-window = { workspace = true } @@ -62,3 +64,7 @@ datafusion-functions-window-common = { workspace = true } datafusion-sql = { workspace = true } env_logger = { workspace = true } insta = { workspace = true } + +[[bench]] +name = "projection_unnecessary" +harness = false diff --git a/datafusion/optimizer/README.md b/datafusion/optimizer/README.md index 61bc1cd70145b..a95ec4828b35e 100644 --- a/datafusion/optimizer/README.md +++ b/datafusion/optimizer/README.md @@ -17,6 +17,18 @@ under the License. --> -Please see [Query Optimizer] in the Library User Guide +# Apache DataFusion Optimizer +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. + +This crate contains the DataFusion logical optimizer. +Please see [Query Optimizer] in the Library User Guide for more information. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion [query optimizer]: https://datafusion.apache.org/library-user-guide/query-optimizer.html diff --git a/datafusion/optimizer/benches/projection_unnecessary.rs b/datafusion/optimizer/benches/projection_unnecessary.rs new file mode 100644 index 0000000000000..c9f248fe49b5a --- /dev/null +++ b/datafusion/optimizer/benches/projection_unnecessary.rs @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::ToDFSchema; +use datafusion_common::{Column, TableReference}; +use datafusion_expr::{logical_plan::LogicalPlan, projection_schema, Expr}; +use datafusion_optimizer::optimize_projections::is_projection_unnecessary; +use std::sync::Arc; + +fn is_projection_unnecessary_old( + input: &LogicalPlan, + proj_exprs: &[Expr], +) -> datafusion_common::Result { + // First check if all expressions are trivial (cheaper operation than `projection_schema`) + if !proj_exprs + .iter() + .all(|expr| matches!(expr, Expr::Column(_) | Expr::Literal(_, _))) + { + return Ok(false); + } + let proj_schema = projection_schema(input, proj_exprs)?; + Ok(&proj_schema == input.schema()) +} + +fn create_plan_with_many_exprs(num_exprs: usize) -> (LogicalPlan, Vec) { + // Create schema with many fields + let fields = (0..num_exprs) + .map(|i| Field::new(format!("col{i}"), DataType::Int32, false)) + .collect::>(); + let schema = Schema::new(fields); + + // Create table scan + let table_scan = LogicalPlan::EmptyRelation(datafusion_expr::EmptyRelation { + produce_one_row: true, + schema: Arc::new(schema.clone().to_dfschema().unwrap()), + }); + + // Create projection expressions (just column references) + let exprs = (0..num_exprs) + .map(|i| Expr::Column(Column::new(None::, format!("col{i}")))) + .collect(); + + (table_scan, exprs) +} + +fn benchmark_is_projection_unnecessary(c: &mut Criterion) { + let (plan, exprs) = create_plan_with_many_exprs(1000); + + let mut group = c.benchmark_group("projection_unnecessary_comparison"); + + group.bench_function("is_projection_unnecessary_new", |b| { + b.iter(|| black_box(is_projection_unnecessary(&plan, &exprs).unwrap())) + }); + + group.bench_function("is_projection_unnecessary_old", |b| { + b.iter(|| black_box(is_projection_unnecessary_old(&plan, &exprs).unwrap())) + }); + + group.finish(); +} + +criterion_group!(benches, benchmark_is_projection_unnecessary); +criterion_main!(benches); diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 2517e3c3a4006..272692f983683 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -38,14 +38,6 @@ pub mod function_rewrite; pub mod resolve_grouping_function; pub mod type_coercion; -pub mod subquery { - #[deprecated( - since = "44.0.0", - note = "please use `datafusion_expr::check_subquery_expr` instead" - )] - pub use datafusion_expr::check_subquery_expr; -} - /// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make /// the plan valid prior to the rest of the DataFusion optimization process. /// diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index f8a8185636090..fa7ff1b8b19d6 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -189,19 +189,19 @@ fn grouping_function_on_id( // Postgres allows grouping function for group by without grouping sets, the result is then // always 0 if !is_grouping_set { - return Ok(Expr::Literal(ScalarValue::from(0i32))); + return Ok(Expr::Literal(ScalarValue::from(0i32), None)); } let group_by_expr_count = group_by_expr.len(); let literal = |value: usize| { if group_by_expr_count < 8 { - Expr::Literal(ScalarValue::from(value as u8)) + Expr::Literal(ScalarValue::from(value as u8), None) } else if group_by_expr_count < 16 { - Expr::Literal(ScalarValue::from(value as u16)) + Expr::Literal(ScalarValue::from(value as u16), None) } else if group_by_expr_count < 32 { - Expr::Literal(ScalarValue::from(value as u32)) + Expr::Literal(ScalarValue::from(value as u32), None) } else { - Expr::Literal(ScalarValue::from(value as u64)) + Expr::Literal(ScalarValue::from(value as u64), None) } }; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index d47f7ea6ce68c..3d5dee3a72559 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use datafusion_expr::binary::BinaryTypeCoercer; -use itertools::izip; +use itertools::{izip, Itertools as _}; use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -29,8 +29,9 @@ use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{ - exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, - DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, + exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, + plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, + TableReference, }; use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList, @@ -41,7 +42,7 @@ use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion}; use datafusion_expr::type_coercion::functions::{ - data_types_with_aggregate_udf, data_types_with_scalar_udf, + data_types_with_scalar_udf, fields_with_aggregate_udf, }; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_expression, get_coerce_type_for_list, @@ -50,9 +51,8 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_u use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, Limit, LogicalPlan, - Operator, Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, - WindowFrameUnits, + AggregateUDF, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection, + ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits, }; /// Performs type coercion by determining the schema @@ -253,7 +253,7 @@ impl<'a> TypeCoercionRewriter<'a> { if dt.is_integer() || dt.is_null() { expr.cast_to(&DataType::Int64, schema) } else { - plan_err!("Expected {expr_name} to be an integer or null, but got {dt:?}") + plan_err!("Expected {expr_name} to be an integer or null, but got {dt}") } } @@ -352,9 +352,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { .data; let expr_type = expr.get_type(self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); - let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!( - "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery" - ), + let common_type = comparison_coercion(&expr_type, subquery_type).ok_or( + plan_datafusion_err!( + "expr type {expr_type} can't cast to {subquery_type} in InSubquery" + ), )?; let new_subquery = Subquery { subquery: Arc::new(new_plan), @@ -440,23 +441,23 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { let low_type = low.get_type(self.schema)?; let low_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { - DataFusionError::Internal(format!( + internal_datafusion_err!( "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression" - )) + ) })?; let high_type = high.get_type(self.schema)?; let high_coerced_type = comparison_coercion(&expr_type, &high_type) .ok_or_else(|| { - DataFusionError::Internal(format!( + internal_datafusion_err!( "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" - )) + ) })?; let coercion_type = comparison_coercion(&low_coerced_type, &high_coerced_type) .ok_or_else(|| { - DataFusionError::Internal(format!( + internal_datafusion_err!( "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" - )) + ) })?; Ok(Transformed::yes(Expr::Between(Between::new( Box::new(expr.cast_to(&coercion_type, self.schema)?), @@ -479,7 +480,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { get_coerce_type_for_list(&expr_data_type, &list_data_types); match result_type { None => plan_err!( - "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}" + "Can not find compatible types to compare {expr_data_type} with [{}]", list_data_types.iter().join(", ") ), Some(coerced_type) => { // find the coerced type @@ -539,17 +540,20 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { ), ))) } - Expr::WindowFunction(WindowFunction { - fun, - params: - expr::WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + expr::WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + filter, + null_treatment, + distinct, + }, + } = *window_fun; let window_frame = coerce_window_frame(window_frame, self.schema, &order_by)?; @@ -564,21 +568,26 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { _ => args, }; - Ok(Transformed::yes( - Expr::WindowFunction(WindowFunction::new(fun, args)) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build()?, - )) + let new_expr = Expr::from(WindowFunction { + fun, + params: expr::WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + filter, + null_treatment, + distinct, + }, + }); + Ok(Transformed::yes(new_expr)) } // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::Alias(_) | Expr::Column(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::SimilarTo(_) | Expr::IsNotNull(_) | Expr::IsNull(_) @@ -678,7 +687,7 @@ fn coerce_scalar_range_aware( // If type coercion fails, check if the largest type in family works: if let Some(largest_type) = get_widest_type_in_family(target_type) { coerce_scalar(largest_type, value).map_or_else( - |_| exec_err!("Cannot cast {value:?} to {target_type:?}"), + |_| exec_err!("Cannot cast {value:?} to {target_type}"), |_| ScalarValue::try_from(target_type), ) } else { @@ -718,6 +727,9 @@ fn coerce_frame_bound( fn extract_window_frame_target_type(col_type: &DataType) -> Result { if col_type.is_numeric() || is_utf8_or_utf8view_or_large_utf8(col_type) + || matches!(col_type, DataType::List(_)) + || matches!(col_type, DataType::LargeList(_)) + || matches!(col_type, DataType::FixedSizeList(_, _)) || matches!(col_type, DataType::Null) || matches!(col_type, DataType::Boolean) { @@ -727,7 +739,7 @@ fn extract_window_frame_target_type(col_type: &DataType) -> Result { } else if let DataType::Dictionary(_, value_type) = col_type { extract_window_frame_target_type(value_type) } else { - return internal_err!("Cannot run range queries on datatype: {col_type:?}"); + internal_err!("Cannot run range queries on datatype: {col_type}") } } @@ -808,12 +820,15 @@ fn coerce_arguments_for_signature_with_aggregate_udf( return Ok(expressions); } - let current_types = expressions + let current_fields = expressions .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; - let new_types = data_types_with_aggregate_udf(¤t_types, func)?; + let new_types = fields_with_aggregate_udf(¤t_fields, func)? + .into_iter() + .map(|f| f.data_type().clone()) + .collect::>(); expressions .into_iter() @@ -883,8 +898,9 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { get_coerce_type_for_case_expression(&when_types, Some(case_type)); coerced_type.ok_or_else(|| { plan_datafusion_err!( - "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \ - to common types in CASE WHEN expression" + "Failed to coerce case ({case_type}) and when ({}) \ + to common types in CASE WHEN expression", + when_types.iter().join(", ") ) }) }) @@ -892,10 +908,19 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { let then_else_coerce_type = get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else( || { - plan_datafusion_err!( - "Failed to coerce then ({then_types:?}) and else ({else_type:?}) \ - to common types in CASE WHEN expression" - ) + if let Some(else_type) = else_type { + plan_datafusion_err!( + "Failed to coerce then ({}) and else ({else_type}) \ + to common types in CASE WHEN expression", + then_types.iter().join(", ") + ) + } else { + plan_datafusion_err!( + "Failed to coerce then ({}) and else (None) \ + to common types in CASE WHEN expression", + then_types.iter().join(", ") + ) + } }, )?; @@ -937,6 +962,43 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { /// /// This method presumes that the wildcard expansion is unneeded, or has already /// been applied. +/// +/// ## Schema and Field Handling in Union Coercion +/// +/// **Processing order**: The function starts with the base schema (first input) and then +/// processes remaining inputs sequentially, with later inputs taking precedence in merging. +/// +/// **Schema-level metadata merging**: Later schemas take precedence for duplicate keys. +/// +/// **Field-level metadata merging**: Later fields take precedence for duplicate metadata keys. +/// +/// **Type coercion precedence**: The coerced type is determined by iteratively applying +/// `comparison_coercion()` between the accumulated type and each new input's type. The +/// result depends on type coercion rules, not input order. +/// +/// **Nullability merging**: Nullability is accumulated using logical OR (`||`). +/// Once any input field is nullable, the result field becomes nullable permanently. +/// Later inputs can make a field nullable but cannot make it non-nullable. +/// +/// **Field precedence**: Field names come from the first (base) schema, but the field properties +/// (nullability and field-level metadata) have later schemas taking precedence. +/// +/// **Example**: +/// ```sql +/// SELECT a, b FROM table1 -- a: Int32, metadata {"source": "t1"}, nullable=false +/// UNION +/// SELECT a, b FROM table2 -- a: Int64, metadata {"source": "t2"}, nullable=true +/// UNION +/// SELECT a, b FROM table3 -- a: Int32, metadata {"encoding": "utf8"}, nullable=false +/// -- Result: +/// -- a: Int64 (from type coercion), nullable=true (from table2), +/// -- metadata: {"source": "t2", "encoding": "utf8"} (later inputs take precedence) +/// ``` +/// +/// **Precedence Summary**: +/// - **Datatypes**: Determined by `comparison_coercion()` rules, not input order +/// - **Nullability**: Later inputs can add nullability but cannot remove it (logical OR) +/// - **Metadata**: Later inputs take precedence for same keys (HashMap::extend semantics) pub fn coerce_union_schema(inputs: &[Arc]) -> Result { coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema()) } @@ -1055,12 +1117,13 @@ mod test { use arrow::datatypes::DataType::Utf8; use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit}; + use insta::assert_snapshot; use crate::analyzer::type_coercion::{ coerce_case_expression, TypeCoercion, TypeCoercionRewriter, }; use crate::analyzer::Analyzer; - use crate::test::{assert_analyzed_plan_eq, assert_analyzed_plan_with_config_eq}; + use crate::assert_analyzed_plan_with_config_eq_snapshot; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; @@ -1096,13 +1159,80 @@ mod test { })) } + macro_rules! assert_analyzed_plan_eq { + ( + $plan: expr, + @ $expected: literal $(,)? + ) => {{ + let options = ConfigOptions::default(); + let rule = Arc::new(TypeCoercion::new()); + assert_analyzed_plan_with_config_eq_snapshot!( + options, + rule, + $plan, + @ $expected, + ) + }}; + } + + macro_rules! coerce_on_output_if_viewtype { + ( + $is_viewtype: expr, + $plan: expr, + @ $expected: literal $(,)? + ) => {{ + let mut options = ConfigOptions::default(); + // coerce on output + if $is_viewtype {options.optimizer.expand_views_at_output = true;} + let rule = Arc::new(TypeCoercion::new()); + + assert_analyzed_plan_with_config_eq_snapshot!( + options, + rule, + $plan, + @ $expected, + ) + }}; + } + + fn assert_type_coercion_error( + plan: LogicalPlan, + expected_substr: &str, + ) -> Result<()> { + let options = ConfigOptions::default(); + let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]); + + match analyzer.execute_and_check(plan, &options, |_, _| {}) { + Ok(succeeded_plan) => { + panic!( + "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}" + ); + } + Err(e) => { + let msg = e.to_string(); + assert!( + msg.contains(expected_substr), + "Error did not contain expected substring.\n expected to find: `{expected_substr}`\n actual error: `{msg}`" + ); + } + } + + Ok(()) + } + #[test] fn simple_case() -> Result<()> { let expr = col("a").lt(lit(2_u32)); let empty = empty_with_type(DataType::Float64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a < CAST(UInt32(2) AS Float64)\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a < CAST(UInt32(2) AS Float64) + EmptyRelation: rows=0 + " + ) } #[test] @@ -1137,28 +1267,15 @@ mod test { Arc::new(analyzed_union), )?); - let expected = "Projection: a\n Union\n Projection: CAST(datafusion.test.foo.a AS Int64) AS a\n EmptyRelation\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), top_level_plan, expected) - } - - fn coerce_on_output_if_viewtype(plan: LogicalPlan, expected: &str) -> Result<()> { - let mut options = ConfigOptions::default(); - options.optimizer.expand_views_at_output = true; - - assert_analyzed_plan_with_config_eq( - options, - Arc::new(TypeCoercion::new()), - plan.clone(), - expected, - ) - } - - fn do_not_coerce_on_output(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_analyzed_plan_with_config_eq( - ConfigOptions::default(), - Arc::new(TypeCoercion::new()), - plan.clone(), - expected, + assert_analyzed_plan_eq!( + top_level_plan, + @r" + Projection: a + Union + Projection: CAST(datafusion.test.foo.a AS Int64) AS a + EmptyRelation: rows=0 + EmptyRelation: rows=0 + " ) } @@ -1172,12 +1289,26 @@ mod test { vec![expr.clone()], Arc::clone(&empty), )?); + // Plan A: no coerce - let if_not_coerced = "Projection: a\n EmptyRelation"; - do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + Projection: a + EmptyRelation: rows=0 + " + )?; + // Plan A: coerce requested: Utf8View => LargeUtf8 - let if_coerced = "Projection: CAST(a AS LargeUtf8)\n EmptyRelation"; - coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeUtf8) + EmptyRelation: rows=0 + " + )?; // Plan B // scenario: outermost bool projection @@ -1187,12 +1318,33 @@ mod test { Arc::clone(&empty), )?); // Plan B: no coerce - let if_not_coerced = - "Projection: a < CAST(Utf8(\"foo\") AS Utf8View)\n EmptyRelation"; - do_not_coerce_on_output(bool_plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + bool_plan.clone(), + @r#" + Projection: a < CAST(Utf8("foo") AS Utf8View) + EmptyRelation: rows=0 + "# + )?; + + coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + Projection: a + EmptyRelation: rows=0 + " + )?; + // Plan B: coerce requested: no coercion applied - let if_coerced = if_not_coerced; - coerce_on_output_if_viewtype(bool_plan, if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeUtf8) + EmptyRelation: rows=0 + " + )?; // Plan C // scenario: with a non-projection root logical plan node @@ -1202,13 +1354,29 @@ mod test { input: Arc::new(plan), fetch: None, }); + // Plan C: no coerce - let if_not_coerced = - "Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - do_not_coerce_on_output(sort_plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + sort_plan.clone(), + @r" + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation: rows=0 + " + )?; + // Plan C: coerce requested: Utf8View => LargeUtf8 - let if_coerced = "Projection: CAST(a AS LargeUtf8)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - coerce_on_output_if_viewtype(sort_plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + sort_plan.clone(), + @r" + Projection: CAST(a AS LargeUtf8) + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation: rows=0 + " + )?; // Plan D // scenario: two layers of projections with view types @@ -1217,11 +1385,27 @@ mod test { Arc::new(sort_plan), )?); // Plan D: no coerce - let if_not_coerced = "Projection: a\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + Projection: a + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation: rows=0 + " + )?; // Plan B: coerce requested: Utf8View => LargeUtf8 only on outermost - let if_coerced = "Projection: CAST(a AS LargeUtf8)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeUtf8) + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation: rows=0 + " + )?; Ok(()) } @@ -1236,12 +1420,26 @@ mod test { vec![expr.clone()], Arc::clone(&empty), )?); + // Plan A: no coerce - let if_not_coerced = "Projection: a\n EmptyRelation"; - do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + Projection: a + EmptyRelation: rows=0 + " + )?; + // Plan A: coerce requested: BinaryView => LargeBinary - let if_coerced = "Projection: CAST(a AS LargeBinary)\n EmptyRelation"; - coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeBinary) + EmptyRelation: rows=0 + " + )?; // Plan B // scenario: outermost bool projection @@ -1250,13 +1448,26 @@ mod test { vec![bool_expr], Arc::clone(&empty), )?); + // Plan B: no coerce - let if_not_coerced = - "Projection: a < CAST(Binary(\"8,1,8,1\") AS BinaryView)\n EmptyRelation"; - do_not_coerce_on_output(bool_plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + bool_plan.clone(), + @r#" + Projection: a < CAST(Binary("8,1,8,1") AS BinaryView) + EmptyRelation: rows=0 + "# + )?; + // Plan B: coerce requested: no coercion applied - let if_coerced = if_not_coerced; - coerce_on_output_if_viewtype(bool_plan, if_coerced)?; + coerce_on_output_if_viewtype!( + true, + bool_plan.clone(), + @r#" + Projection: a < CAST(Binary("8,1,8,1") AS BinaryView) + EmptyRelation: rows=0 + "# + )?; // Plan C // scenario: with a non-projection root logical plan node @@ -1266,13 +1477,28 @@ mod test { input: Arc::new(plan), fetch: None, }); + // Plan C: no coerce - let if_not_coerced = - "Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - do_not_coerce_on_output(sort_plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + sort_plan.clone(), + @r" + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation: rows=0 + " + )?; // Plan C: coerce requested: BinaryView => LargeBinary - let if_coerced = "Projection: CAST(a AS LargeBinary)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - coerce_on_output_if_viewtype(sort_plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + sort_plan.clone(), + @r" + Projection: CAST(a AS LargeBinary) + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation: rows=0 + " + )?; // Plan D // scenario: two layers of projections with view types @@ -1280,12 +1506,30 @@ mod test { vec![col("a")], Arc::new(sort_plan), )?); + // Plan D: no coerce - let if_not_coerced = "Projection: a\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + Projection: a + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation: rows=0 + " + )?; + // Plan B: coerce requested: BinaryView => LargeBinary only on outermost - let if_coerced = "Projection: CAST(a AS LargeBinary)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeBinary) + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation: rows=0 + " + )?; Ok(()) } @@ -1299,12 +1543,17 @@ mod test { vec![expr.clone().or(expr)], empty, )?); - let expected = "Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64) + EmptyRelation: rows=0 + " + ) } - #[derive(Debug, Clone)] + #[derive(Debug, PartialEq, Eq, Hash)] struct TestScalarUDF { signature: Signature, } @@ -1340,9 +1589,14 @@ mod test { }) .call(vec![lit(123_i32)]); let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); - let expected = - "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: TestScalarUDF(CAST(Int32(123) AS Float32)) + EmptyRelation: rows=0 + " + ) } #[test] @@ -1372,9 +1626,14 @@ mod test { vec![scalar_function_expr], empty, )?); - let expected = - "Projection: TestScalarUDF(CAST(Int64(10) AS Float32))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: TestScalarUDF(CAST(Int64(10) AS Float32)) + EmptyRelation: rows=0 + " + ) } #[test] @@ -1393,12 +1652,18 @@ mod test { vec![lit(10i64)], false, None, - None, + vec![], None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); - let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: MY_AVG(CAST(Int64(10) AS Float64)) + EmptyRelation: rows=0 + " + ) } #[test] @@ -1413,8 +1678,8 @@ mod test { return_type, accumulator, vec![ - Field::new("count", DataType::UInt64, true), - Field::new("avg", DataType::Float64, true), + Field::new("count", DataType::UInt64, true).into(), + Field::new("avg", DataType::Float64, true).into(), ], )); let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( @@ -1422,13 +1687,13 @@ mod test { vec![lit("10")], false, None, - None, + vec![], None, )); let err = Projection::try_new(vec![udaf], empty).err().unwrap(); assert!( - err.strip_backtrace().starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from [Utf8] to the signature Uniform(1, [Float64]) failed") + err.strip_backtrace().starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from Utf8 to the signature Uniform(1, [Float64]) failed") ); Ok(()) } @@ -1441,12 +1706,18 @@ mod test { vec![lit(12f64)], false, None, - None, + vec![], None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); - let expected = "Projection: avg(Float64(12))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: avg(Float64(12)) + EmptyRelation: rows=0 + " + )?; let empty = empty_with_type(DataType::Int32); let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( @@ -1454,13 +1725,18 @@ mod test { vec![cast(col("a"), DataType::Float64)], false, None, - None, + vec![], None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); - let expected = "Projection: avg(CAST(a AS Float64))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: avg(CAST(a AS Float64)) + EmptyRelation: rows=0 + " + ) } #[test] @@ -1471,14 +1747,14 @@ mod test { vec![lit("1")], false, None, - None, + vec![], None, )); let err = Projection::try_new(vec![agg_expr], empty) .err() .unwrap() .strip_backtrace(); - assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from [Utf8] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed")); + assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed")); Ok(()) } @@ -1489,10 +1765,14 @@ mod test { + lit(ScalarValue::new_interval_dt(123, 456)); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = - "Projection: CAST(Utf8(\"1998-03-18\") AS Date32) + IntervalDayTime(\"IntervalDayTime { days: 123, milliseconds: 456 }\")\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }") + EmptyRelation: rows=0 + "# + ) } #[test] @@ -1501,8 +1781,12 @@ mod test { let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false); let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) + EmptyRelation: rows=0 + ")?; // a in (1,4,8), a is decimal let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false); @@ -1514,8 +1798,12 @@ mod test { )?), })); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + assert_analyzed_plan_eq!( + plan, + @r" + Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))]) + EmptyRelation: rows=0 + ") } #[test] @@ -1528,10 +1816,14 @@ mod test { ); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); - let expected = - "Filter: CAST(a AS Date32) BETWEEN CAST(Utf8(\"2002-05-08\") AS Date32) AND CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\")\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r#" + Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") + EmptyRelation: rows=0 + "# + ) } #[test] @@ -1544,11 +1836,15 @@ mod test { ); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); + // TODO: we should cast col(a). - let expected = - "Filter: CAST(a AS Date32) BETWEEN CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AND CAST(Utf8(\"2002-12-08\") AS Date32)\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + assert_analyzed_plan_eq!( + plan, + @r#" + Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32) + EmptyRelation: rows=0 + "# + ) } #[test] @@ -1556,10 +1852,14 @@ mod test { let expr = lit(ScalarValue::Null).between(lit(ScalarValue::Null), lit(2i64)); let empty = empty(); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); - let expected = - "Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2) + EmptyRelation: rows=0 + " + ) } #[test] @@ -1569,37 +1869,60 @@ mod test { let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?); - let expected = "Projection: a IS TRUE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS TRUE + EmptyRelation: rows=0 + " + )?; let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, ""); - let err = ret.unwrap_err().to_string(); - assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}"); + assert_type_coercion_error( + plan, + "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean" + )?; // is not true let expr = col("a").is_not_true(); let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IS NOT TRUE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS NOT TRUE + EmptyRelation: rows=0 + " + )?; // is false let expr = col("a").is_false(); let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IS FALSE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS FALSE + EmptyRelation: rows=0 + " + )?; // is not false let expr = col("a").is_not_false(); let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IS NOT FALSE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS NOT FALSE + EmptyRelation: rows=0 + " + ) } #[test] @@ -1610,27 +1933,38 @@ mod test { let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); - let expected = "Projection: a LIKE Utf8(\"abc\")\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: a LIKE Utf8("abc") + EmptyRelation: rows=0 + "# + )?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); - let expected = "Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a LIKE CAST(NULL AS Utf8) + EmptyRelation: rows=0 + " + )?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); - assert!(err.is_err()); - assert!(err.unwrap_err().to_string().contains( - "There isn't a common type to coerce Int64 and Utf8 in LIKE expression" - )); + assert_type_coercion_error( + plan, + "There isn't a common type to coerce Int64 and Utf8 in LIKE expression", + )?; // ilike let expr = Box::new(col("a")); @@ -1638,27 +1972,39 @@ mod test { let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); - let expected = "Projection: a ILIKE Utf8(\"abc\")\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: a ILIKE Utf8("abc") + EmptyRelation: rows=0 + "# + )?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); - let expected = "Projection: a ILIKE CAST(NULL AS Utf8)\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a ILIKE CAST(NULL AS Utf8) + EmptyRelation: rows=0 + " + )?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); - assert!(err.is_err()); - assert!(err.unwrap_err().to_string().contains( - "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression" - )); + assert_type_coercion_error( + plan, + "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression", + )?; + Ok(()) } @@ -1669,23 +2015,34 @@ mod test { let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?); - let expected = "Projection: a IS UNKNOWN\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS UNKNOWN + EmptyRelation: rows=0 + " + )?; let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); - let err = ret.unwrap_err().to_string(); - assert!(err.contains("Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"), "{err}"); + assert_type_coercion_error( + plan, + "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean" + )?; // is not unknown let expr = col("a").is_not_unknown(); let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IS NOT UNKNOWN\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS NOT UNKNOWN + EmptyRelation: rows=0 + " + ) } #[test] @@ -1694,21 +2051,19 @@ mod test { let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)]; // concat-type signature - { - let expr = ScalarUDF::new_from_impl(TestScalarUDF { - signature: Signature::variadic(vec![Utf8], Volatility::Immutable), - }) - .call(args.to_vec()); - let plan = LogicalPlan::Projection(Projection::try_new( - vec![expr], - Arc::clone(&empty), - )?); - let expected = - "Projection: TestScalarUDF(a, Utf8(\"b\"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - } - - Ok(()) + let expr = ScalarUDF::new_from_impl(TestScalarUDF { + signature: Signature::variadic(vec![Utf8], Volatility::Immutable), + }) + .call(args.to_vec()); + let plan = + LogicalPlan::Projection(Projection::try_new(vec![expr], Arc::clone(&empty))?); + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8)) + EmptyRelation: rows=0 + "# + ) } #[test] @@ -1758,10 +2113,14 @@ mod test { .eq(cast(lit("1998-03-18"), DataType::Date32)); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = - "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8(\"1998-03-18\") AS Date32) AS Timestamp(Nanosecond, None))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(Nanosecond, None)) + EmptyRelation: rows=0 + "# + ) } fn cast_if_not_same_type( @@ -1882,12 +2241,9 @@ mod test { else_expr: Some(Box::new(col("string"))), }; let err = coerce_case_expression(case, &schema).unwrap_err(); - assert_eq!( + assert_snapshot!( err.strip_backtrace(), - "Error during planning: \ - Failed to coerce case (Interval(MonthDayNano)) and \ - when ([Float32, Binary, Utf8]) to common types in \ - CASE WHEN expression" + @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when (Float32, Binary, Utf8) to common types in CASE WHEN expression" ); let case = Case { @@ -1900,12 +2256,9 @@ mod test { else_expr: Some(Box::new(col("timestamp"))), }; let err = coerce_case_expression(case, &schema).unwrap_err(); - assert_eq!( + assert_snapshot!( err.strip_backtrace(), - "Error during planning: \ - Failed to coerce then ([Date32, Float32, Binary]) and \ - else (Some(Timestamp(Nanosecond, None))) to common types \ - in CASE WHEN expression" + @"Error during planning: Failed to coerce then (Date32, Float32, Binary) and else (Timestamp(Nanosecond, None)) to common types in CASE WHEN expression" ); Ok(()) @@ -2103,17 +2456,19 @@ mod test { let map_type_entries = DataType::Map(Arc::new(fields), false); let fields = Field::new("key_value", DataType::Struct(struct_fields), false); - let may_type_cutsom = DataType::Map(Arc::new(fields), false); + let may_type_custom = DataType::Map(Arc::new(fields), false); - let expr = col("a").eq(cast(col("a"), may_type_cutsom)); + let expr = col("a").eq(cast(col("a"), may_type_custom)); let empty = empty_with_type(map_type_entries); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a = CAST(CAST(a AS Map(Field { name: \"key_value\", data_type: Struct([Field { name: \"key\", data_type: Utf8, \ - nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"value\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), \ - nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) AS Map(Field { name: \"entries\", data_type: Struct([Field { name: \"key\", data_type: Utf8, nullable: false, \ - dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"value\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false))\n \ - EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: a = CAST(CAST(a AS Map(Field { name: "key_value", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) AS Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) + EmptyRelation: rows=0 + "# + ) } #[test] @@ -2129,9 +2484,14 @@ mod test { )); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: IntervalYearMonth(\"12\") + CAST(Utf8(\"2000-01-01T00:00:00\") AS Timestamp(Nanosecond, None))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(Nanosecond, None)) + EmptyRelation: rows=0 + "# + ) } #[test] @@ -2149,10 +2509,14 @@ mod test { )); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = - "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) - CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) - CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) + EmptyRelation: rows=0 + "# + ) } #[test] @@ -2171,14 +2535,17 @@ mod test { )); let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?); // add cast for subquery - let expected = "\ - Filter: a IN ()\ - \n Subquery:\ - \n Projection: CAST(a AS Int64)\ - \n EmptyRelation\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r" + Filter: a IN () + Subquery: + Projection: CAST(a AS Int64) + EmptyRelation: rows=0 + EmptyRelation: rows=0 + " + ) } #[test] @@ -2196,14 +2563,17 @@ mod test { false, )); let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?); + // add cast for subquery - let expected = "\ - Filter: CAST(a AS Int64) IN ()\ - \n Subquery:\ - \n EmptyRelation\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + assert_analyzed_plan_eq!( + plan, + @r" + Filter: CAST(a AS Int64) IN () + Subquery: + EmptyRelation: rows=0 + EmptyRelation: rows=0 + " + ) } #[test] @@ -2221,13 +2591,17 @@ mod test { false, )); let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?); + // add cast for subquery - let expected = "Filter: CAST(a AS Decimal128(13, 8)) IN ()\ - \n Subquery:\ - \n Projection: CAST(a AS Decimal128(13, 8))\ - \n EmptyRelation\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + assert_analyzed_plan_eq!( + plan, + @r" + Filter: CAST(a AS Decimal128(13, 8)) IN () + Subquery: + Projection: CAST(a AS Decimal128(13, 8)) + EmptyRelation: rows=0 + EmptyRelation: rows=0 + " + ) } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 69b5fbb9f8c0f..ec1f8f991a8ee 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -316,6 +316,19 @@ impl CommonSubexprEliminate { } => { let rewritten_aggr_expr = new_exprs_list.pop().unwrap(); let new_aggr_expr = original_exprs_list.pop().unwrap(); + let saved_names = if let Some(aggr_expr) = aggr_expr { + let name_preserver = NamePreserver::new_for_projection(); + aggr_expr + .iter() + .map(|expr| Some(name_preserver.save(expr))) + .collect::>() + } else { + new_aggr_expr + .clone() + .into_iter() + .map(|_| None) + .collect::>() + }; let mut agg_exprs = common_exprs .into_iter() @@ -326,10 +339,19 @@ impl CommonSubexprEliminate { for expr in &new_group_expr { extract_expressions(expr, &mut proj_exprs) } - for (expr_rewritten, expr_orig) in - rewritten_aggr_expr.into_iter().zip(new_aggr_expr) + for ((expr_rewritten, expr_orig), saved_name) in + rewritten_aggr_expr + .into_iter() + .zip(new_aggr_expr) + .zip(saved_names) { if expr_rewritten == expr_orig { + let expr_rewritten = if let Some(saved_name) = saved_name + { + saved_name.restore(expr_rewritten) + } else { + expr_rewritten + }; if let Expr::Alias(Alias { expr, name, .. }) = expr_rewritten { @@ -803,23 +825,39 @@ mod test { use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::optimizer::OptimizerContext; use crate::test::*; - use crate::Optimizer; use datafusion_expr::test::function_stub::{avg, sum}; - fn assert_optimized_plan_eq( - expected: &str, - plan: LogicalPlan, - config: Option<&dyn OptimizerConfig>, - ) { - let optimizer = - Optimizer::with_rules(vec![Arc::new(CommonSubexprEliminate::new())]); - let default_config = OptimizerContext::new(); - let config = config.unwrap_or(&default_config); - let optimized_plan = optimizer.optimize(plan, config, |_, _| ()).unwrap(); - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(expected, formatted_plan); + macro_rules! assert_optimized_plan_equal { + ( + $config:expr, + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rules: Vec> = vec![Arc::new(CommonSubexprEliminate::new())]; + assert_optimized_plan_eq_snapshot!( + $config, + rules, + $plan, + @ $expected, + ) + }}; + + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rules: Vec> = vec![Arc::new(CommonSubexprEliminate::new())]; + let optimizer_ctx = OptimizerContext::new(); + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -844,13 +882,14 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\ - \n Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]] + Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -864,13 +903,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -886,12 +926,12 @@ mod test { Signature::exact(vec![DataType::UInt32], Volatility::Stable), return_type.clone(), Arc::clone(&accumulator), - vec![Field::new("value", DataType::UInt32, true)], + vec![Field::new("value", DataType::UInt32, true).into()], ))), vec![inner], false, None, - None, + vec![], None, )) }; @@ -917,11 +957,14 @@ mod test { )? .build()?; - let expected = "Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)\ - \n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c) + Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]] + TableScan: test + " + )?; // test: trafo after aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -936,11 +979,14 @@ mod test { )? .build()?; - let expected = "Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)\ - \n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a) + Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]] + TableScan: test + " + )?; // test: transformation before aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -953,11 +999,14 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ - \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]] + Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // test: common between agg and group let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -970,11 +1019,14 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ - \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]] + Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // test: all mixed let plan = LogicalPlanBuilder::from(table_scan) @@ -991,14 +1043,15 @@ mod test { )? .build()?; - let expected = "Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)\ - \n Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]\ - \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a) + Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]] + Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1018,14 +1071,15 @@ mod test { )? .build()?; - let expected = "Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)\ - \n Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]\ - \n Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a\ - \n TableScan: table.test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a) + Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]] + Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a + TableScan: table.test + " + ) } #[test] @@ -1039,13 +1093,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS first, __common_expr_1 AS second\ - \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS first, __common_expr_1 AS second + Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1056,13 +1111,14 @@ mod test { .project(vec![lit(1) + col("a"), col("a") + lit(1)])? .build()?; - let expected = "Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)\ - \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1) + Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1074,12 +1130,14 @@ mod test { .project(vec![lit(1) + col("a")])? .build()?; - let expected = "Projection: Int32(1) + test.a\ - \n Projection: Int32(1) + test.a, test.a\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: Int32(1) + test.a + Projection: Int32(1) + test.a, test.a + TableScan: test + " + ) } #[test] @@ -1193,14 +1251,15 @@ mod test { .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))? .build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 - Int32(10) > __common_expr_1\ - \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 - Int32(10) > __common_expr_1 + Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1226,7 +1285,7 @@ mod test { fn test_alias_collision() -> Result<()> { let table_scan = test_table_scan()?; - let config = &OptimizerContext::new(); + let config = OptimizerContext::new(); let common_expr_1 = config.alias_generator().next(CSE_PREFIX); let plan = LogicalPlanBuilder::from(table_scan.clone()) .project(vec![ @@ -1241,14 +1300,18 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4\ - \n Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c\ - \n Projection: test.a + test.b AS __common_expr_1, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, Some(config)); - - let config = &OptimizerContext::new(); + assert_optimized_plan_equal!( + config, + plan, + @ r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4 + Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c + Projection: test.a + test.b AS __common_expr_1, test.c + TableScan: test + " + )?; + + let config = OptimizerContext::new(); let _common_expr_1 = config.alias_generator().next(CSE_PREFIX); let common_expr_2 = config.alias_generator().next(CSE_PREFIX); let plan = LogicalPlanBuilder::from(table_scan) @@ -1264,12 +1327,16 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4\ - \n Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c\ - \n Projection: test.a + test.b AS __common_expr_2, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, Some(config)); + assert_optimized_plan_equal!( + config, + plan, + @ r" + Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4 + Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c + Projection: test.a + test.b AS __common_expr_2, test.c + TableScan: test + " + )?; Ok(()) } @@ -1308,13 +1375,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5\ - \n Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5 + Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1331,13 +1399,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2 + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1360,13 +1429,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4\ - \n Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4 + Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1382,14 +1452,15 @@ mod test { .project(vec![col("c1"), col("c2")])? .build()?; - let expected = "Projection: c1, c2\ - \n Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: c1, c2 + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2 + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1405,14 +1476,15 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\ - \n Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c\ - \n Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2 + Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c + Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1422,13 +1494,15 @@ mod test { let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 * __common_expr_1 = Int32(30) + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1438,13 +1512,15 @@ mod test { let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1454,13 +1530,15 @@ mod test { let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1470,13 +1548,15 @@ mod test { let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1486,13 +1566,15 @@ mod test { let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1502,13 +1584,15 @@ mod test { let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a"))); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 AND __common_expr_1\ - \n Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 AND __common_expr_1 + Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1518,13 +1602,15 @@ mod test { let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a"))); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 AND __common_expr_1\ - \n Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 AND __common_expr_1 + Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1535,11 +1621,15 @@ mod test { .eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 - __common_expr_1 = Int32(30)\ - \n Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 - __common_expr_1 = Int32(30) + Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // ((c1 + c2 / c3) * c3 <=> c3 * (c2 / c3 + c1)) let table_scan = test_table_scan()?; @@ -1548,11 +1638,16 @@ mod test { + col("a")) .eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)\ - \n Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30) + Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // c2 / (c1 + c3) <=> c2 / (c3 + c1) let table_scan = test_table_scan()?; @@ -1560,16 +1655,20 @@ mod test { * (col("b") / (col("c") + col("a")))) .eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\ - \n Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 * __common_expr_1 = Int32(30) + Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; Ok(()) } - #[derive(Debug)] + #[derive(Debug, PartialEq, Eq, Hash)] pub struct TestUdf { signature: Signature, } @@ -1612,10 +1711,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a\ - \n Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a + Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // is_null(a == b) <=> is_null(b == a) let table_scan = test_table_scan()?; @@ -1624,10 +1727,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL\ - \n Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL + Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // a + b between 0 and 10 <=> b + a between 0 and 10 let table_scan = test_table_scan()?; @@ -1636,10 +1743,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)\ - \n Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10) + Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // c between a + b and 10 <=> c between b + a and 10 let table_scan = test_table_scan()?; @@ -1648,10 +1759,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)\ - \n Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10) + Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // function call with argument <=> function call with argument let udf = ScalarUDF::from(TestUdf::new()); @@ -1661,11 +1776,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)\ - \n Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a) + Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } /// returns a "random" function that is marked volatile (aka each invocation @@ -1677,7 +1795,7 @@ mod test { ScalarUDF::new_from_impl(RandomStub::new()) } - #[derive(Debug)] + #[derive(Debug, PartialEq, Eq, Hash)] struct RandomStub { signature: Signature, } diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 71ff863b51a18..63236787743a4 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -71,6 +71,9 @@ pub struct PullUpCorrelatedExpr { pub collected_count_expr_map: HashMap, /// pull up having expr, which must be evaluated after the Join pub pull_up_having_expr: Option, + /// whether we have converted a scalar aggregation into a group aggregation. When unnesting + /// lateral joins, we need to produce a left outer join in such cases. + pub pulled_up_scalar_agg: bool, } impl Default for PullUpCorrelatedExpr { @@ -91,6 +94,7 @@ impl PullUpCorrelatedExpr { need_handle_count_bug: false, collected_count_expr_map: HashMap::new(), pull_up_having_expr: None, + pulled_up_scalar_agg: false, } } @@ -313,6 +317,11 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { missing_exprs.push(un_matched_row); } } + if aggregate.group_expr.is_empty() { + // TODO: how do we handle the case where we have pulled multiple aggregations? For example, + // a group agg with a scalar agg as child. + self.pulled_up_scalar_agg = true; + } let new_plan = LogicalPlanBuilder::from((*aggregate.input).clone()) .aggregate(missing_exprs, aggregate.aggr_expr.to_vec())? .build()?; @@ -485,9 +494,12 @@ fn agg_exprs_evaluation_result_on_empty_batch( let new_expr = match expr { Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => { if func.name() == "count" { - Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(0)))) + Transformed::yes(Expr::Literal( + ScalarValue::Int64(Some(0)), + None, + )) } else { - Transformed::yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null, None)) } } _ => Transformed::no(expr), @@ -501,10 +513,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema)); let simplifier = ExprSimplifier::new(info); let result_expr = simplifier.simplify(result_expr)?; - if matches!(result_expr, Expr::Literal(ScalarValue::Int64(_))) { - expr_result_map_for_count_bug - .insert(e.schema_name().to_string(), result_expr); - } + expr_result_map_for_count_bug.insert(e.schema_name().to_string(), result_expr); } Ok(()) } @@ -581,10 +590,10 @@ fn filter_exprs_evaluation_result_on_empty_batch( let result_expr = simplifier.simplify(result_expr)?; match &result_expr { // evaluate to false or null on empty batch, no need to pull up - Expr::Literal(ScalarValue::Null) - | Expr::Literal(ScalarValue::Boolean(Some(false))) => None, + Expr::Literal(ScalarValue::Null, _) + | Expr::Literal(ScalarValue::Boolean(Some(false)), _) => None, // evaluate to true on empty batch, need to pull up the expr - Expr::Literal(ScalarValue::Boolean(Some(true))) => { + Expr::Literal(ScalarValue::Boolean(Some(true)), _) => { for (name, exprs) in input_expr_result_map_for_count_bug { expr_result_map_for_count_bug.insert(name.clone(), exprs.clone()); } @@ -599,7 +608,7 @@ fn filter_exprs_evaluation_result_on_empty_batch( Box::new(result_expr.clone()), Box::new(input_expr.clone()), )], - else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null))), + else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null, None))), }); let expr_key = new_expr.schema_name().to_string(); expr_result_map_for_count_bug.insert(expr_key, new_expr); diff --git a/datafusion/optimizer/src/decorrelate_lateral_join.rs b/datafusion/optimizer/src/decorrelate_lateral_join.rs new file mode 100644 index 0000000000000..7d2072ad1ce99 --- /dev/null +++ b/datafusion/optimizer/src/decorrelate_lateral_join.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`DecorrelateLateralJoin`] decorrelates logical plans produced by lateral joins. + +use std::collections::BTreeSet; + +use crate::decorrelate::PullUpCorrelatedExpr; +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_expr::{lit, Join}; + +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; +use datafusion_common::Result; +use datafusion_expr::logical_plan::JoinType; +use datafusion_expr::utils::conjunction; +use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; + +/// Optimizer rule for rewriting lateral joins to joins +#[derive(Default, Debug)] +pub struct DecorrelateLateralJoin {} + +impl DecorrelateLateralJoin { + #[allow(missing_docs)] + pub fn new() -> Self { + Self::default() + } +} + +impl OptimizerRule for DecorrelateLateralJoin { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + // Find cross joins with outer column references on the right side (i.e., the apply operator). + let LogicalPlan::Join(join) = plan else { + return Ok(Transformed::no(plan)); + }; + + rewrite_internal(join) + } + + fn name(&self) -> &str { + "decorrelate_lateral_join" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } +} + +// Build the decorrelated join based on the original lateral join query. For now, we only support cross/inner +// lateral joins. +fn rewrite_internal(join: Join) -> Result> { + if join.join_type != JoinType::Inner { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + + match join.right.apply_with_subqueries(|p| { + // TODO: support outer joins + if p.contains_outer_reference() { + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + })? { + TreeNodeRecursion::Stop => {} + TreeNodeRecursion::Continue => { + // The left side contains outer references, we need to decorrelate it. + return Ok(Transformed::new( + LogicalPlan::Join(join), + false, + TreeNodeRecursion::Jump, + )); + } + TreeNodeRecursion::Jump => { + unreachable!("") + } + } + + let LogicalPlan::Subquery(subquery) = join.right.as_ref() else { + return Ok(Transformed::no(LogicalPlan::Join(join))); + }; + + if join.join_type != JoinType::Inner { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + let subquery_plan = subquery.subquery.as_ref(); + let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true); + let rewritten_subquery = subquery_plan.clone().rewrite(&mut pull_up).data()?; + if !pull_up.can_pull_up { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + + let mut all_correlated_cols = BTreeSet::new(); + pull_up + .correlated_subquery_cols_map + .values() + .for_each(|cols| all_correlated_cols.extend(cols.clone())); + let join_filter_opt = conjunction(pull_up.join_filters); + let join_filter = match join_filter_opt { + Some(join_filter) => join_filter, + None => lit(true), + }; + // -- inner join but the right side always has one row, we need to rewrite it to a left join + // SELECT * FROM t0, LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0); + // -- inner join but the right side number of rows is related to the filter (join) condition, so keep inner join. + // SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0); + let new_plan = LogicalPlanBuilder::from(join.left) + .join_on( + rewritten_subquery, + if pull_up.pulled_up_scalar_agg { + JoinType::Left + } else { + JoinType::Inner + }, + Some(join_filter), + )? + .build()?; + // TODO: handle count(*) bug + Ok(Transformed::new(new_plan, true, TreeNodeRecursion::Jump)) +} diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index c18c48251daa2..c8be689fc5a42 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -31,7 +31,7 @@ use datafusion_common::{internal_err, plan_err, Column, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; -use datafusion_expr::utils::{conjunction, split_conjunction_owned}; +use datafusion_expr::utils::{conjunction, expr_to_columns, split_conjunction_owned}; use datafusion_expr::{ exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, @@ -342,7 +342,7 @@ fn build_join( replace_qualified_name(filter, &all_correlated_cols, &alias).map(Some) })?; - let join_filter = match (join_filter_opt, in_predicate_opt) { + let join_filter = match (join_filter_opt, in_predicate_opt.clone()) { ( Some(join_filter), Some(Expr::BinaryExpr(BinaryExpr { @@ -371,6 +371,51 @@ fn build_join( (None, None) => lit(true), _ => return Ok(None), }; + + if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) { + let right_schema = sub_query_alias.schema(); + + // Gather all columns needed for the join filter + predicates + let mut needed = std::collections::HashSet::new(); + expr_to_columns(&join_filter, &mut needed)?; + if let Some(ref in_pred) = in_predicate_opt { + expr_to_columns(in_pred, &mut needed)?; + } + + // Keep only columns that actually belong to the RIGHT child, and sort by their + // position in the right schema for deterministic order. + let mut right_cols_idx_and_col: Vec<(usize, Column)> = needed + .into_iter() + .filter_map(|c| right_schema.index_of_column(&c).ok().map(|idx| (idx, c))) + .collect(); + + right_cols_idx_and_col.sort_by_key(|(idx, _)| *idx); + + let right_proj_exprs: Vec = right_cols_idx_and_col + .into_iter() + .map(|(_, c)| Expr::Column(c)) + .collect(); + + let right_projected = if !right_proj_exprs.is_empty() { + LogicalPlanBuilder::from(sub_query_alias.clone()) + .project(right_proj_exprs)? + .build()? + } else { + // Degenerate case: no right columns referenced by the predicate(s) + sub_query_alias.clone() + }; + let new_plan = LogicalPlanBuilder::from(left.clone()) + .join_on(right_projected, join_type, Some(join_filter))? + .build()?; + + debug!( + "predicate subquery optimized:\n{}", + new_plan.display_indent() + ); + + return Ok(Some(new_plan)); + } + // join our sub query into the main plan let new_plan = LogicalPlanBuilder::from(left.clone()) .join_on(sub_query_alias, join_type, Some(join_filter))? @@ -427,17 +472,23 @@ mod tests { use super::*; use crate::test::*; + use crate::assert_optimized_plan_eq_display_indent_snapshot; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::builder::table_source; use datafusion_expr::{and, binary_expr, col, lit, not, out_ref_col, table_scan}; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), - plan, - expected, - ); - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(DecorrelatePredicateSubquery::new()); + assert_optimized_plan_eq_display_indent_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } fn test_subquery_with_name(name: &str) -> Result> { @@ -461,17 +512,21 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq_1.c [c:UInt32]\ - \n TableScan: sq_1 [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ - \n Projection: sq_2.c [c:UInt32]\ - \n TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq_1.c [c:UInt32] + TableScan: sq_1 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_2 [c:UInt32] + Projection: sq_2.c [c:UInt32] + TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for IN subquery with additional AND filter @@ -489,15 +544,18 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.a = UInt32(1) AND test.b < UInt32(30) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + Filter: test.a = UInt32(1) AND test.b < UInt32(30) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for nested IN subqueries @@ -515,18 +573,21 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [a:UInt32]\ - \n Projection: sq.a [a:UInt32]\ - \n LeftSemi Join: Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq_nested.c [c:UInt32]\ - \n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_2 [a:UInt32] + Projection: sq.a [a:UInt32] + LeftSemi Join: Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq_nested.c [c:UInt32] + TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test multiple correlated subqueries @@ -551,23 +612,21 @@ mod tests { .build()?; debug!("plan to optimize:\n{}", plan.display_indent()); - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), - plan, - expected, - ); - Ok(()) + assert_optimized_plan_equal!( + plan, + @r###" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + "### + ) } /// Test recursive correlated subqueries @@ -601,23 +660,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\ - \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ - \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64] + Projection: lineitem.l_orderkey [l_orderkey:Int64] + TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64] + " + ) } /// Test for correlated IN subquery filter with additional subquery filters @@ -639,20 +696,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery with no columns in schema @@ -673,19 +728,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for IN subquery with both columns in schema @@ -703,20 +756,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery not equal @@ -737,19 +788,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery less than @@ -770,19 +819,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery filter with subquery disjunction @@ -804,20 +851,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]\ - \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64] + Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN without projection @@ -861,19 +905,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN expressions @@ -894,19 +936,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.orders.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ - \n Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.orders.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] + Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery multiple projected columns @@ -959,20 +999,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery filter @@ -990,19 +1028,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ - \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] + Projection: sq.c, sq.a [c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for single IN subquery filter @@ -1014,19 +1050,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for single NOT IN subquery filter @@ -1038,19 +1072,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1061,19 +1093,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1087,19 +1117,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1116,19 +1144,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32]\ - \n Projection: sq.c * UInt32(2) [sq.c * UInt32(2):UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32] + Projection: sq.c * UInt32(2) [sq.c * UInt32(2):UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1150,20 +1176,18 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32]\ - \n Projection: sq.c * UInt32(2), sq.a [sq.c * UInt32(2):UInt32, a:UInt32]\ - \n Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32] + Projection: sq.c * UInt32(2), sq.a [sq.c * UInt32(2):UInt32, a:UInt32] + Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1186,20 +1210,18 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32]\ - \n Projection: sq.c * UInt32(2), sq.a, sq.b [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32]\ - \n Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32] + Projection: sq.c * UInt32(2), sq.a, sq.b [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32] + Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1228,24 +1250,22 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c * UInt32(2) = __correlated_sq_2.sq2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq1.c * UInt32(2):UInt32, a:UInt32]\ - \n Projection: sq1.c * UInt32(2), sq1.a [sq1.c * UInt32(2):UInt32, a:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [sq2.c * UInt32(2):UInt32, a:UInt32]\ - \n Projection: sq2.c * UInt32(2), sq2.a [sq2.c * UInt32(2):UInt32, a:UInt32]\ - \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.c * UInt32(2) = __correlated_sq_2.sq2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq1.c * UInt32(2):UInt32, a:UInt32] + Projection: sq1.c * UInt32(2), sq1.a [sq1.c * UInt32(2):UInt32, a:UInt32] + TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_2 [sq2.c * UInt32(2):UInt32, a:UInt32] + Projection: sq2.c * UInt32(2), sq2.a [sq2.c * UInt32(2):UInt32, a:UInt32] + TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1263,20 +1283,18 @@ mod tests { .build()?; // Subquery and outer query refer to the same table. - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: test.c [c:UInt32]\ - \n Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: test.c [c:UInt32] + Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for multiple exists subqueries in the same filter expression @@ -1297,17 +1315,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join: Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test recursive correlated subqueries @@ -1340,17 +1362,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\ - \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ - \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + LeftSemi Join: Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64] + Projection: lineitem.l_orderkey [l_orderkey:Int64] + TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64] + " + ) } /// Test for correlated exists subquery filter with additional subquery filters @@ -1372,15 +1398,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -1398,14 +1427,17 @@ mod tests { .build()?; // Other rule will pushdown `customer.c_custkey = 1`, - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for exists subquery with both columns in schema @@ -1423,14 +1455,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery not equal @@ -1451,14 +1487,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery less than @@ -1479,14 +1518,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery filter with subquery disjunction @@ -1508,14 +1550,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]\ - \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64] + Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists without projection @@ -1535,13 +1580,16 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists expressions @@ -1562,14 +1610,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ - \n Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] + Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery filter with additional filters @@ -1589,15 +1640,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery filter with disjunctions @@ -1615,16 +1669,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]\ - \n LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean] + LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated EXISTS subquery filter @@ -1642,14 +1699,17 @@ mod tests { .project(vec![col("test.c")])? .build()?; - let expected = "Projection: test.c [c:UInt32]\ - \n LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ - \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c [c:UInt32] + LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] + Projection: sq.c, sq.a [c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for single exists subquery filter @@ -1661,13 +1721,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for single NOT exists subquery filter @@ -1679,13 +1743,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftAnti Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftAnti Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1712,19 +1780,22 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.a = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ - \n Projection: sq1.c, sq1.a [c:UInt32, a:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [c:UInt32, a:UInt32]\ - \n Projection: sq2.c, sq2.a [c:UInt32, a:UInt32]\ - \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.a = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] + Projection: sq1.c, sq1.a [c:UInt32, a:UInt32] + TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_2 [c:UInt32, a:UInt32] + Projection: sq2.c, sq2.a [c:UInt32, a:UInt32] + TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1743,14 +1814,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, a:UInt32]\ - \n Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, a:UInt32] + Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1768,15 +1842,18 @@ mod tests { .build()?; // Subquery and outer query refer to the same table. - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: test.c [c:UInt32]\ - \n Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: test.c [c:UInt32] + Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1796,15 +1873,18 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ - \n Distinct: [c:UInt32, a:UInt32]\ - \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] + Distinct: [c:UInt32, a:UInt32] + Projection: sq.c, sq.a [c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1824,15 +1904,18 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq.b + sq.c:UInt32, a:UInt32]\ - \n Distinct: [sq.b + sq.c:UInt32, a:UInt32]\ - \n Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq.b + sq.c:UInt32, a:UInt32] + Distinct: [sq.b + sq.c:UInt32, a:UInt32] + Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1852,15 +1935,18 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, c:UInt32, a:UInt32]\ - \n Distinct: [UInt32(1):UInt32, c:UInt32, a:UInt32]\ - \n Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, c:UInt32, a:UInt32] + Distinct: [UInt32(1):UInt32, c:UInt32, a:UInt32] + Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1884,13 +1970,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [arr:Int32;N]\ - \n Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N]\ - \n TableScan: sq [arr:List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [arr:Int32;N] + Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N] + TableScan: sq [arr:List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N] + "# + ) } #[test] @@ -1915,14 +2005,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [a:UInt32;N]\ - \n Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N]\ - \n TableScan: sq [a:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [a:UInt32;N] + Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N] + TableScan: sq [a:List(Field { name: "item", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N] + "# + ) } #[test] @@ -1946,13 +2039,16 @@ mod tests { .project(vec![col("\"TEST_A\".\"B\"")])? .build()?; - let expected = "Projection: TEST_A.B [B:UInt32]\ - \n LeftSemi Join: Filter: __correlated_sq_1.A = TEST_A.A [A:UInt32, B:UInt32]\ - \n TableScan: TEST_A [A:UInt32, B:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [Int32(1):Int32, A:UInt32]\ - \n Projection: Int32(1), TEST_B.A [Int32(1):Int32, A:UInt32]\ - \n TableScan: TEST_B [A:UInt32, B:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: TEST_A.B [B:UInt32] + LeftSemi Join: Filter: __correlated_sq_1.A = TEST_A.A [A:UInt32, B:UInt32] + TableScan: TEST_A [A:UInt32, B:UInt32] + SubqueryAlias: __correlated_sq_1 [Int32(1):Int32, A:UInt32] + Projection: Int32(1), TEST_B.A [Int32(1):Int32, A:UInt32] + TableScan: TEST_B [A:UInt32, B:UInt32] + " + ) } } diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index d35572e6d34a3..ae1d7df46d52e 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use crate::join_key_set::JoinKeySet; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::Result; +use datafusion_common::{NullEquality, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, @@ -89,6 +89,7 @@ impl OptimizerRule for EliminateCrossJoin { let mut possible_join_keys = JoinKeySet::new(); let mut all_inputs: Vec = vec![]; let mut all_filters: Vec = vec![]; + let mut null_equality = NullEquality::NullEqualsNothing; let parent_predicate = if let LogicalPlan::Filter(filter) = plan { // if input isn't a join that can potentially be rewritten @@ -113,6 +114,12 @@ impl OptimizerRule for EliminateCrossJoin { let Filter { input, predicate, .. } = filter; + + // Extract null_equality setting from the input join + if let LogicalPlan::Join(join) = input.as_ref() { + null_equality = join.null_equality; + } + flatten_join_inputs( Arc::unwrap_or_clone(input), &mut possible_join_keys, @@ -122,26 +129,30 @@ impl OptimizerRule for EliminateCrossJoin { extract_possible_join_keys(&predicate, &mut possible_join_keys); Some(predicate) - } else if matches!( - plan, - LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) - ) { - if !can_flatten_join_inputs(&plan) { - return Ok(Transformed::no(plan)); - } - flatten_join_inputs( - plan, - &mut possible_join_keys, - &mut all_inputs, - &mut all_filters, - )?; - None } else { - // recursively try to rewrite children - return rewrite_children(self, plan, config); + match plan { + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + null_equality: original_null_equality, + .. + }) => { + if !can_flatten_join_inputs(&plan) { + return Ok(Transformed::no(plan)); + } + flatten_join_inputs( + plan, + &mut possible_join_keys, + &mut all_inputs, + &mut all_filters, + )?; + null_equality = original_null_equality; + None + } + _ => { + // recursively try to rewrite children + return rewrite_children(self, plan, config); + } + } }; // Join keys are handled locally: @@ -153,6 +164,7 @@ impl OptimizerRule for EliminateCrossJoin { &mut all_inputs, &possible_join_keys, &mut all_join_keys, + null_equality, )?; } @@ -290,6 +302,7 @@ fn find_inner_join( rights: &mut Vec, possible_join_keys: &JoinKeySet, all_join_keys: &mut JoinKeySet, + null_equality: NullEquality, ) -> Result { for (i, right_input) in rights.iter().enumerate() { let mut join_keys = vec![]; @@ -328,7 +341,7 @@ fn find_inner_join( on: join_keys, filter: None, schema: join_schema, - null_equals_null: false, + null_equality, })); } } @@ -350,7 +363,7 @@ fn find_inner_join( filter: None, join_type: JoinType::Inner, join_constraint: JoinConstraint::On, - null_equals_null: false, + null_equality, })) } @@ -440,22 +453,28 @@ mod tests { logical_plan::builder::LogicalPlanBuilder, Operator::{And, Or}, }; + use insta::assert_snapshot; + + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let starting_schema = Arc::clone($plan.schema()); + let rule = EliminateCrossJoin::new(); + let Transformed {transformed: is_plan_transformed, data: optimized_plan, ..} = rule.rewrite($plan, &OptimizerContext::new()).unwrap(); + let formatted_plan = optimized_plan.display_indent_schema(); + // Ensure the rule was actually applied + assert!(is_plan_transformed, "failed to optimize plan"); + // Verify the schema remains unchanged + assert_eq!(&starting_schema, optimized_plan.schema()); + assert_snapshot!( + formatted_plan, + @ $expected, + ); - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: Vec<&str>) { - let starting_schema = Arc::clone(plan.schema()); - let rule = EliminateCrossJoin::new(); - let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); - assert!(transformed_plan.transformed, "failed to optimize plan"); - let optimized_plan = transformed_plan.data; - let formatted = optimized_plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - assert_eq!(&starting_schema, optimized_plan.schema()) + Ok(()) + }}; } #[test] @@ -473,16 +492,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -501,16 +519,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -528,16 +545,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -559,15 +575,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -589,15 +605,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -615,15 +631,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -644,19 +660,18 @@ mod tests { .filter(col("t1.a").gt(lit(15u32)))? .build()?; - let expected = vec![ - "Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]" - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -691,19 +706,18 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -765,22 +779,21 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -840,22 +853,21 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -915,22 +927,21 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -994,22 +1005,21 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1083,21 +1093,20 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1177,20 +1186,19 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1208,15 +1216,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1235,16 +1243,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1263,16 +1270,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1291,16 +1297,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1328,17 +1333,81 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) + } + + #[test] + fn preserve_null_equality_setting() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Create an inner join with NullEquality::NullEqualsNull + let join_schema = Arc::new(build_join_schema( + t1.schema(), + t2.schema(), + &JoinType::Inner, + )?); + + let inner_join = LogicalPlan::Join(Join { + left: Arc::new(t1), + right: Arc::new(t2), + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + on: vec![], + filter: None, + schema: join_schema, + null_equality: NullEquality::NullEqualsNull, // Test preservation + }); + + // Apply filter that can create join conditions + let plan = LogicalPlanBuilder::from(inner_join) + .filter(binary_expr( + col("t1.a").eq(col("t2.a")), + And, + col("t2.c").lt(lit(20u32)), + ))? + .build()?; + + let rule = EliminateCrossJoin::new(); + let optimized_plan = rule.rewrite(plan, &OptimizerContext::new())?.data; + + // Verify that null_equality is preserved in the optimized plan + fn check_null_equality_preserved(plan: &LogicalPlan) -> bool { + match plan { + LogicalPlan::Join(join) => { + // All joins in the optimized plan should preserve null equality + if join.null_equality == NullEquality::NullEqualsNothing { + return false; + } + // Recursively check child plans + plan.inputs() + .iter() + .all(|input| check_null_equality_preserved(input)) + } + _ => { + // Recursively check child plans for non-join nodes + plan.inputs() + .iter() + .all(|input| check_null_equality_preserved(input)) + } + } + } + + assert!( + check_null_equality_preserved(&optimized_plan), + "null_equality setting should be preserved after optimization" + ); Ok(()) } diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 4669500920956..a6651df938a70 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -118,16 +118,26 @@ impl OptimizerRule for EliminateDuplicatedExpr { #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - crate::test::assert_optimized_plan_eq( - Arc::new(EliminateDuplicatedExpr::new()), - plan, - expected, - ) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateDuplicatedExpr::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -137,10 +147,12 @@ mod tests { .sort_by(vec![col("a"), col("a"), col("b"), col("c")])? .limit(5, Some(10))? .build()?; - let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Limit: skip=5, fetch=10 + Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST + TableScan: test + ") } #[test] @@ -156,9 +168,11 @@ mod tests { .sort(sort_exprs)? .limit(5, Some(10))? .build()?; - let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Limit: skip=5, fetch=10 + Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST + TableScan: test + ") } } diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 4ed2ac8ba1a4e..1b763d6f8957b 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -60,7 +60,7 @@ impl OptimizerRule for EliminateFilter { ) -> Result> { match plan { LogicalPlan::Filter(Filter { - predicate: Expr::Literal(ScalarValue::Boolean(v)), + predicate: Expr::Literal(ScalarValue::Boolean(v), _), input, .. }) => match v { @@ -81,17 +81,29 @@ impl OptimizerRule for EliminateFilter { mod tests { use std::sync::Arc; + use crate::assert_optimized_plan_eq_snapshot; + use crate::OptimizerContext; use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{ - col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan, - }; + use datafusion_expr::{col, lit, logical_plan::builder::LogicalPlanBuilder, Expr}; use crate::eliminate_filter::EliminateFilter; use crate::test::*; use datafusion_expr::test::function_stub::sum; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateFilter::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -105,13 +117,12 @@ mod tests { .build()?; // No aggregate / scan / limit - let expected = "EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @"EmptyRelation: rows=0") } #[test] fn filter_null() -> Result<()> { - let filter_expr = Expr::Literal(ScalarValue::Boolean(None)); + let filter_expr = Expr::Literal(ScalarValue::Boolean(None), None); let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) @@ -120,8 +131,7 @@ mod tests { .build()?; // No aggregate / scan / limit - let expected = "EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @"EmptyRelation: rows=0") } #[test] @@ -139,11 +149,12 @@ mod tests { .build()?; // Left side is removed - let expected = "Union\ - \n EmptyRelation\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + EmptyRelation: rows=0 + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + ") } #[test] @@ -156,9 +167,10 @@ mod tests { .filter(filter_expr)? .build()?; - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + ") } #[test] @@ -176,12 +188,13 @@ mod tests { .build()?; // Filter is removed - let expected = "Union\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + ") } #[test] @@ -202,8 +215,9 @@ mod tests { .build()?; // Filter is removed - let expected = "Projection: test.a\ - \n EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Projection: test.a + EmptyRelation: rows=0 + ") } } diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index 7e252d6dcea0e..4e16fc0aa159c 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -101,7 +101,7 @@ fn is_constant_expression(expr: &Expr) -> bool { Expr::BinaryExpr(e) => { is_constant_expression(&e.left) && is_constant_expression(&e.right) } - Expr::Literal(_) => true, + Expr::Literal(_, _) => true, Expr::ScalarFunction(e) => { matches!( e.func.signature().volatility, @@ -115,7 +115,9 @@ fn is_constant_expression(expr: &Expr) -> bool { #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use arrow::datatypes::DataType; use datafusion_common::Result; @@ -129,7 +131,23 @@ mod tests { use std::sync::Arc; - #[derive(Debug)] + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateGroupByConstant::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; + } + + #[derive(Debug, PartialEq, Eq, Hash)] struct ScalarUDFMock { signature: Signature, } @@ -167,17 +185,11 @@ mod tests { .aggregate(vec![col("a"), lit(1u32)], vec![count(col("c"))])? .build()?; - let expected = "\ - Projection: test.a, UInt32(1), count(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Projection: test.a, UInt32(1), count(test.c) + Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -187,17 +199,11 @@ mod tests { .aggregate(vec![lit("test"), lit(123u32)], vec![count(col("c"))])? .build()?; - let expected = "\ - Projection: Utf8(\"test\"), UInt32(123), count(test.c)\ - \n Aggregate: groupBy=[[]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r#" + Projection: Utf8("test"), UInt32(123), count(test.c) + Aggregate: groupBy=[[]], aggr=[[count(test.c)]] + TableScan: test + "#) } #[test] @@ -207,16 +213,10 @@ mod tests { .aggregate(vec![col("a"), col("b")], vec![count(col("c"))])? .build()?; - let expected = "\ - Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -226,16 +226,10 @@ mod tests { .aggregate(vec![lit(123u32)], Vec::::new())? .build()?; - let expected = "\ - Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[UInt32(123)]], aggr=[[]] + TableScan: test + ") } #[test] @@ -248,17 +242,11 @@ mod tests { )? .build()?; - let expected = "\ - Projection: UInt32(123) AS const, test.a, count(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Projection: UInt32(123) AS const, test.a, count(test.c) + Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -273,17 +261,11 @@ mod tests { .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? .build()?; - let expected = "\ - Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c) + Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -298,15 +280,9 @@ mod tests { .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? .build()?; - let expected = "\ - Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } } diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 789235595dabf..412bbea2ae92c 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -54,7 +54,7 @@ impl OptimizerRule for EliminateJoin { match plan { LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => { match join.filter { - Some(Expr::Literal(ScalarValue::Boolean(Some(false)))) => Ok( + Some(Expr::Literal(ScalarValue::Boolean(Some(false)), _)) => Ok( Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: join.schema, @@ -74,15 +74,28 @@ impl OptimizerRule for EliminateJoin { #[cfg(test)] mod tests { + use crate::assert_optimized_plan_eq_snapshot; use crate::eliminate_join::EliminateJoin; - use crate::test::*; + use crate::OptimizerContext; use datafusion_common::Result; use datafusion_expr::JoinType::Inner; - use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan}; + use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(EliminateJoin::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateJoin::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -95,7 +108,6 @@ mod tests { )? .build()?; - let expected = "EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @"EmptyRelation: rows=0") } } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 5d3a1b223b7a7..8e25d3246f6c2 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -90,7 +90,6 @@ impl OptimizerRule for EliminateLimit { #[cfg(test)] mod tests { use super::*; - use crate::optimizer::Optimizer; use crate::test::*; use crate::OptimizerContext; use datafusion_common::Column; @@ -100,36 +99,43 @@ mod tests { }; use std::sync::Arc; + use crate::assert_optimized_plan_eq_snapshot; use crate::push_down_limit::PushDownLimit; use datafusion_expr::test::function_stub::sum; - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]); - let optimized_plan = - optimizer.optimize(plan, &OptimizerContext::new(), observe)?; - - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rules: Vec> = vec![Arc::new(EliminateLimit::new())]; + let optimizer_ctx = OptimizerContext::new(); + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } - fn assert_optimized_plan_eq_with_pushdown( - plan: LogicalPlan, - expected: &str, - ) -> Result<()> { - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - let config = OptimizerContext::new().with_max_passes(1); - let optimizer = Optimizer::with_rules(vec![ - Arc::new(PushDownLimit::new()), - Arc::new(EliminateLimit::new()), - ]); - let optimized_plan = optimizer - .optimize(plan, &config, observe) - .expect("failed to optimize plan"); - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); - Ok(()) + macro_rules! assert_optimized_plan_eq_with_pushdown { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![ + Arc::new(PushDownLimit::new()), + Arc::new(EliminateLimit::new()) + ]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -140,8 +146,10 @@ mod tests { .limit(0, Some(0))? .build()?; // No aggregate / scan / limit - let expected = "EmptyRelation"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ "EmptyRelation: rows=0" + ) } #[test] @@ -157,11 +165,15 @@ mod tests { .build()?; // Left side is removed - let expected = "Union\ - \n EmptyRelation\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Union + EmptyRelation: rows=0 + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } #[test] @@ -174,8 +186,10 @@ mod tests { .build()?; // No aggregate / scan / limit - let expected = "EmptyRelation"; - assert_optimized_plan_eq_with_pushdown(plan, expected) + assert_optimized_plan_eq_with_pushdown!( + plan, + @ "EmptyRelation: rows=0" + ) } #[test] @@ -190,12 +204,16 @@ mod tests { // After remove global-state, we don't record the parent // So, bottom don't know parent info, so can't eliminate. - let expected = "Limit: skip=2, fetch=1\ - \n Sort: test.a ASC NULLS LAST, fetch=3\ - \n Limit: skip=0, fetch=2\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq_with_pushdown(plan, expected) + assert_optimized_plan_eq_with_pushdown!( + plan, + @ r" + Limit: skip=2, fetch=1 + Sort: test.a ASC NULLS LAST, fetch=3 + Limit: skip=0, fetch=2 + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } #[test] @@ -208,12 +226,16 @@ mod tests { .limit(0, Some(1))? .build()?; - let expected = "Limit: skip=0, fetch=1\ - \n Sort: test.a ASC NULLS LAST\ - \n Limit: skip=0, fetch=2\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Limit: skip=0, fetch=1 + Sort: test.a ASC NULLS LAST + Limit: skip=0, fetch=2 + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } #[test] @@ -226,12 +248,16 @@ mod tests { .limit(3, Some(1))? .build()?; - let expected = "Limit: skip=3, fetch=1\ - \n Sort: test.a ASC NULLS LAST\ - \n Limit: skip=2, fetch=1\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Limit: skip=3, fetch=1 + Sort: test.a ASC NULLS LAST + Limit: skip=2, fetch=1 + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } #[test] @@ -248,12 +274,16 @@ mod tests { .limit(3, Some(1))? .build()?; - let expected = "Limit: skip=3, fetch=1\ - \n Inner Join: Using test.a = test1.a\ - \n Limit: skip=2, fetch=1\ - \n TableScan: test\ - \n TableScan: test1"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Limit: skip=3, fetch=1 + Inner Join: Using test.a = test1.a + Limit: skip=2, fetch=1 + TableScan: test + TableScan: test1 + " + ) } #[test] @@ -264,8 +294,12 @@ mod tests { .limit(0, None)? .build()?; - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 94da08243d78f..f8f93727cd9ba 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -116,7 +116,8 @@ mod tests { use super::*; use crate::analyzer::type_coercion::TypeCoercion; use crate::analyzer::Analyzer; - use crate::test::*; + use crate::assert_optimized_plan_eq_snapshot; + use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{col, logical_plan::table_scan}; @@ -129,15 +130,23 @@ mod tests { ]) } - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - let options = ConfigOptions::default(); - let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]) - .execute_and_check(plan, &options, |_, _| {})?; - assert_optimized_plan_eq( - Arc::new(EliminateNestedUnion::new()), - analyzed_plan, - expected, - ) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let options = ConfigOptions::default(); + let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]) + .execute_and_check($plan, &options, |_, _| {})?; + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateNestedUnion::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + analyzed_plan, + @ $expected, + ) + }}; } #[test] @@ -146,11 +155,11 @@ mod tests { let plan = plan_builder.clone().union(plan_builder.build()?)?.build()?; - let expected = "\ - Union\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table + TableScan: table + ") } #[test] @@ -162,11 +171,12 @@ mod tests { .union_distinct(plan_builder.build()?)? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + TableScan: table + ") } #[test] @@ -180,13 +190,13 @@ mod tests { .union(plan_builder.build()?)? .build()?; - let expected = "\ - Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } #[test] @@ -200,14 +210,15 @@ mod tests { .union(plan_builder.build()?)? .build()?; - let expected = "Union\ - \n Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + Distinct: + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } #[test] @@ -222,14 +233,15 @@ mod tests { .union_distinct(plan_builder.build()?)? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } #[test] @@ -243,13 +255,14 @@ mod tests { .union_distinct(plan_builder.build()?)? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } // We don't need to use project_with_column_index in logical optimizer, @@ -273,13 +286,14 @@ mod tests { )? .build()?; - let expected = "Union\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + ") } #[test] @@ -301,14 +315,15 @@ mod tests { )? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + ") } #[test] @@ -348,13 +363,14 @@ mod tests { .union(table_3.build()?)? .build()?; - let expected = "Union\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + ") } #[test] @@ -394,13 +410,14 @@ mod tests { .union_distinct(table_3.build()?)? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + ") } } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 1ecb32ca2a435..45877642f2766 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -118,7 +118,7 @@ impl OptimizerRule for EliminateOuterJoin { on: join.on.clone(), filter: join.filter.clone(), schema: Arc::clone(&join.schema), - null_equals_null: join.null_equals_null, + null_equality: join.null_equality, })); Filter::try_new(filter.predicate, new_join) .map(|f| Transformed::yes(LogicalPlan::Filter(f))) @@ -304,7 +304,9 @@ fn extract_non_nullable_columns( #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use arrow::datatypes::DataType; use datafusion_expr::{ binary_expr, cast, col, lit, @@ -313,8 +315,20 @@ mod tests { Operator::{And, Or}, }; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateOuterJoin::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -332,12 +346,13 @@ mod tests { )? .filter(col("t2.b").is_null())? .build()?; - let expected = "\ - Filter: t2.b IS NULL\ - \n Left Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IS NULL + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -355,12 +370,13 @@ mod tests { )? .filter(col("t2.b").is_not_null())? .build()?; - let expected = "\ - Filter: t2.b IS NOT NULL\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IS NOT NULL + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -382,12 +398,13 @@ mod tests { col("t1.c").lt(lit(20u32)), ))? .build()?; - let expected = "\ - Filter: t1.b > UInt32(10) OR t1.c < UInt32(20)\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b > UInt32(10) OR t1.c < UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -409,12 +426,13 @@ mod tests { col("t2.c").lt(lit(20u32)), ))? .build()?; - let expected = "\ - Filter: t1.b > UInt32(10) AND t2.c < UInt32(20)\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b > UInt32(10) AND t2.c < UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -436,11 +454,12 @@ mod tests { try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)), ))? .build()?; - let expected = "\ - Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20)\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 48191ec206313..c76de942de805 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -19,8 +19,8 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::DFSchema; -use datafusion_common::Result; +use datafusion_common::{internal_err, DFSchema}; +use datafusion_common::{NullEquality, Result}; use datafusion_expr::utils::split_conjunction_owned; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator}; @@ -75,13 +75,52 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) => { let left_schema = left.schema(); let right_schema = right.schema(); let (equijoin_predicates, non_equijoin_expr) = split_eq_and_noneq_join_predicate(expr, left_schema, right_schema)?; + // Equi-join operators like HashJoin support a special behavior + // that evaluates `NULL = NULL` as true instead of NULL. Therefore, + // we transform `t1.c1 IS NOT DISTINCT FROM t2.c1` into an equi-join + // and set the `NullEquality` configuration in the join operator. + // This allows certain queries to use Hash Join instead of + // Nested Loop Join, resulting in better performance. + // + // Only convert when there are NO equijoin predicates, to be conservative. + if on.is_empty() + && equijoin_predicates.is_empty() + && non_equijoin_expr.is_some() + { + // SAFETY: checked in the outer `if` + let expr = non_equijoin_expr.clone().unwrap(); + let (equijoin_predicates, non_equijoin_expr) = + split_is_not_distinct_from_and_other_join_predicate( + expr, + left_schema, + right_schema, + )?; + + if !equijoin_predicates.is_empty() { + on.extend(equijoin_predicates); + + return Ok(Transformed::yes(LogicalPlan::Join(Join { + left, + right, + on, + filter: non_equijoin_expr, + join_type, + join_constraint, + schema, + // According to `is not distinct from`'s semantics, it's + // safe to override it + null_equality: NullEquality::NullEqualsNull, + }))); + } + } + if !equijoin_predicates.is_empty() { on.extend(equijoin_predicates); Ok(Transformed::yes(LogicalPlan::Join(Join { @@ -92,7 +131,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_type, join_constraint, schema, - null_equals_null, + null_equality, }))) } else { Ok(Transformed::no(LogicalPlan::Join(Join { @@ -103,7 +142,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_type, join_constraint, schema, - null_equals_null, + null_equality, }))) } } @@ -112,22 +151,98 @@ impl OptimizerRule for ExtractEquijoinPredicate { } } +/// Splits an ANDed filter expression into equijoin predicates and remaining filters. +/// Returns all equijoin predicates and the remaining filters combined with AND. +/// +/// # Example +/// +/// For the expression `a.id = b.id AND a.x > 10 AND b.x > b.id`, this function will extract `a.id = b.id` as an equijoin predicate. +/// +/// It first splits the ANDed sub-expressions: +/// - expr1: a.id = b.id +/// - expr2: a.x > 10 +/// - expr3: b.x > b.id +/// +/// Then, it filters out the equijoin predicates and collects the non-equality expressions. +/// The equijoin condition is: +/// - It is an equality expression like `lhs == rhs` +/// - All column references in `lhs` are from the left schema, and all in `rhs` are from the right schema +/// +/// According to the above rule, `expr1` is the equijoin predicate, while `expr2` and `expr3` are not. +/// The function returns Ok(\[expr1\], Some(expr2 AND expr3)) fn split_eq_and_noneq_join_predicate( filter: Expr, left_schema: &DFSchema, right_schema: &DFSchema, ) -> Result<(Vec, Option)> { + split_op_and_other_join_predicates(filter, left_schema, right_schema, Operator::Eq) +} + +/// See `split_eq_and_noneq_join_predicate`'s comment for the idea. This function +/// is splitting out `is not distinct from` expressions instead of equal exprs. +/// The `is not distinct from` exprs will be return as `EquijoinPredicate`. +/// +/// # Example +/// - Input: `a.id IS NOT DISTINCT FROM b.id AND a.x > 10 AND b.x > b.id` +/// - Output from this splitter: `Ok([a.id, b.id], Some((a.x > 10) AND (b.x > b.id)))` +/// +/// # Note +/// Caller should be cautious -- `is not distinct from` is not equivalent to an +/// equal expression; the caller is responsible for correctly setting the +/// `nulls equals nulls` property in the join operator (if it supports it) to +/// make the transformation valid. +/// +/// For the above example: in downstream, a valid plan that uses the extracted +/// equijoin keys should look like: +/// +/// HashJoin +/// - on: `a.id = b.id` (equality) +/// - join_filter: `(a.x > 10) AND (b.x > b.id)` +/// - nulls_equals_null: `true` +/// +/// This reflects that `IS NOT DISTINCT FROM` treats `NULL = NULL` as true and +/// thus requires setting `NullEquality::NullEqualsNull` in the join operator to +/// preserve semantics while enabling an equi-join implementation (e.g., HashJoin). +fn split_is_not_distinct_from_and_other_join_predicate( + filter: Expr, + left_schema: &DFSchema, + right_schema: &DFSchema, +) -> Result<(Vec, Option)> { + split_op_and_other_join_predicates( + filter, + left_schema, + right_schema, + Operator::IsNotDistinctFrom, + ) +} + +/// See comments in `split_eq_and_noneq_join_predicate` for details. +fn split_op_and_other_join_predicates( + filter: Expr, + left_schema: &DFSchema, + right_schema: &DFSchema, + operator: Operator, +) -> Result<(Vec, Option)> { + if !matches!(operator, Operator::Eq | Operator::IsNotDistinctFrom) { + return internal_err!( + "split_op_and_other_join_predicates only supports 'Eq' or 'IsNotDistinctFrom' operators, \ + but received: {:?}", + operator + ); + } + let exprs = split_conjunction_owned(filter); + // Treat 'is not distinct from' comparison as join key in equal joins let mut accum_join_keys: Vec<(Expr, Expr)> = vec![]; let mut accum_filters: Vec = vec![]; for expr in exprs { match expr { Expr::BinaryExpr(BinaryExpr { ref left, - op: Operator::Eq, + ref op, ref right, - }) => { + }) if *op == operator => { let join_key_pair = find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?; @@ -155,6 +270,7 @@ fn split_eq_and_noneq_join_predicate( #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_display_indent_snapshot; use crate::test::*; use arrow::datatypes::DataType; use datafusion_expr::{ @@ -162,14 +278,18 @@ mod tests { }; use std::sync::Arc; - fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq_display_indent( - Arc::new(ExtractEquijoinPredicate {}), - plan, - expected, - ); - - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(ExtractEquijoinPredicate {}); + assert_optimized_plan_eq_display_indent_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } #[test] @@ -180,11 +300,15 @@ mod tests { let plan = LogicalPlanBuilder::from(t1) .join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))? .build()?; - let expected = "Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -199,11 +323,15 @@ mod tests { Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))), )? .build()?; - let expected = "Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -222,11 +350,15 @@ mod tests { ), )? .build()?; - let expected = "Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -249,11 +381,15 @@ mod tests { ), )? .build()?; - let expected = "Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -275,11 +411,15 @@ mod tests { ), )? .build()?; - let expected = "Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -310,13 +450,17 @@ mod tests { ), )? .build()?; - let expected = "Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -343,13 +487,17 @@ mod tests { Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))), )? .build()?; - let expected = "Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -369,10 +517,14 @@ mod tests { let plan = LogicalPlanBuilder::from(t1) .join_on(t2, JoinType::Left, Some(filter))? .build()?; - let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } } diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 2e7a751ca4c57..8ad7fa53c0e33 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -21,7 +21,7 @@ use crate::optimizer::ApplyOrder; use crate::push_down_filter::on_lr_is_preserved; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::Result; +use datafusion_common::{NullEquality, Result}; use datafusion_expr::utils::conjunction; use datafusion_expr::{logical_plan::Filter, Expr, ExprSchemable, LogicalPlan}; use std::sync::Arc; @@ -51,7 +51,8 @@ impl OptimizerRule for FilterNullJoinKeys { } match plan { LogicalPlan::Join(mut join) - if !join.on.is_empty() && !join.null_equals_null => + if !join.on.is_empty() + && join.null_equality == NullEquality::NullEqualsNothing => { let (left_preserved, right_preserved) = on_lr_is_preserved(join.join_type); @@ -107,35 +108,52 @@ fn create_not_null_predicate(filters: Vec) -> Expr { #[cfg(test)] mod tests { use super::*; - use crate::test::assert_optimized_plan_eq; + use crate::assert_optimized_plan_eq_snapshot; + use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Column; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{col, lit, JoinType, LogicalPlanBuilder}; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(FilterNullJoinKeys {})]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] fn left_nullable() -> Result<()> { let (t1, t2) = test_tables()?; let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?; - let expected = "Inner Join: t1.optional_id = t2.id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id = t2.id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] fn left_nullable_left_join() -> Result<()> { let (t1, t2) = test_tables()?; let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Left)?; - let expected = "Left Join: t1.optional_id = t2.id\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Left Join: t1.optional_id = t2.id + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -144,22 +162,26 @@ mod tests { // Note: order of tables is reversed let plan = build_plan(t_right, t_left, "t2.id", "t1.optional_id", JoinType::Left)?; - let expected = "Left Join: t2.id = t1.optional_id\ - \n TableScan: t2\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Left Join: t2.id = t1.optional_id + TableScan: t2 + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + ") } #[test] fn left_nullable_on_condition_reversed() -> Result<()> { let (t1, t2) = test_tables()?; let plan = build_plan(t1, t2, "t2.id", "t1.optional_id", JoinType::Inner)?; - let expected = "Inner Join: t1.optional_id = t2.id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id = t2.id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -189,14 +211,16 @@ mod tests { None, )? .build()?; - let expected = "Inner Join: t3.t1_id = t1.id, t3.t2_id = t2.id\ - \n Filter: t3.t1_id IS NOT NULL AND t3.t2_id IS NOT NULL\ - \n TableScan: t3\ - \n Inner Join: t1.optional_id = t2.id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t3.t1_id = t1.id, t3.t2_id = t2.id + Filter: t3.t1_id IS NOT NULL AND t3.t2_id IS NOT NULL + TableScan: t3 + Inner Join: t1.optional_id = t2.id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -213,11 +237,13 @@ mod tests { None, )? .build()?; - let expected = "Inner Join: t1.optional_id + UInt32(1) = t2.id + UInt32(1)\ - \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id + UInt32(1) = t2.id + UInt32(1) + Filter: t1.optional_id + UInt32(1) IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -234,11 +260,13 @@ mod tests { None, )? .build()?; - let expected = "Inner Join: t1.id + UInt32(1) = t2.optional_id + UInt32(1)\ - \n TableScan: t1\ - \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.id + UInt32(1) = t2.optional_id + UInt32(1) + TableScan: t1 + Filter: t2.optional_id + UInt32(1) IS NOT NULL + TableScan: t2 + ") } #[test] @@ -255,13 +283,14 @@ mod tests { None, )? .build()?; - let expected = - "Inner Join: t1.optional_id + UInt32(1) = t2.optional_id + UInt32(1)\ - \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t1\ - \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id + UInt32(1) = t2.optional_id + UInt32(1) + Filter: t1.optional_id + UInt32(1) IS NOT NULL + TableScan: t1 + Filter: t2.optional_id + UInt32(1) IS NOT NULL + TableScan: t2 + ") } #[test] @@ -283,13 +312,22 @@ mod tests { None, )? .build()?; - let expected = "Inner Join: t1.optional_id = t2.optional_id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n Filter: t2.optional_id IS NOT NULL\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan_from_cols, expected)?; - assert_optimized_plan_equal(plan_from_exprs, expected) + + assert_optimized_plan_equal!(plan_from_cols, @r" + Inner Join: t1.optional_id = t2.optional_id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + Filter: t2.optional_id IS NOT NULL + TableScan: t2 + ")?; + + assert_optimized_plan_equal!(plan_from_exprs, @r" + Inner Join: t1.optional_id = t2.optional_id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + Filter: t2.optional_id IS NOT NULL + TableScan: t2 + ") } fn build_plan( diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 893cb249a2a86..85fa9493f449d 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -40,6 +40,7 @@ pub mod analyzer; pub mod common_subexpr_eliminate; pub mod decorrelate; +pub mod decorrelate_lateral_join; pub mod decorrelate_predicate_subquery; pub mod eliminate_cross_join; pub mod eliminate_duplicated_expr; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index b3a09e2dcbcc7..5db71417bc8fd 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -26,13 +26,12 @@ use std::sync::Arc; use datafusion_common::{ get_required_group_by_exprs_indices, internal_datafusion_err, internal_err, Column, - HashMap, JoinType, Result, + DFSchema, HashMap, JoinType, Result, }; use datafusion_expr::expr::Alias; -use datafusion_expr::Unnest; use datafusion_expr::{ - logical_plan::LogicalPlan, projection_schema, Aggregate, Distinct, Expr, Projection, - TableScan, Window, + logical_plan::LogicalPlan, Aggregate, Distinct, EmptyRelation, Expr, Projection, + TableScan, Unnest, Window, }; use crate::optimize_projections::required_indices::RequiredIndices; @@ -56,6 +55,24 @@ use datafusion_common::tree_node::{ /// The rule analyzes the input logical plan, determines the necessary column /// indices, and then removes any unnecessary columns. It also removes any /// unnecessary projections from the plan tree. +/// +/// ## Schema, Field Properties, and Metadata Handling +/// +/// The `OptimizeProjections` rule preserves schema and field metadata in most optimization scenarios: +/// +/// **Schema-level metadata preservation by plan type**: +/// - **Window and Aggregate plans**: Schema metadata is preserved +/// - **Projection plans**: Schema metadata is preserved per [`projection_schema`](datafusion_expr::logical_plan::projection_schema). +/// - **Other logical plans**: Schema metadata is preserved unless [`LogicalPlan::recompute_schema`] +/// is called on plan types that drop metadata +/// +/// **Field-level properties and metadata**: Individual field properties are preserved when fields +/// are retained in the optimized plan, determined by [`exprlist_to_fields`](datafusion_expr::utils::exprlist_to_fields) +/// and [`ExprSchemable::to_field`](datafusion_expr::expr_schema::ExprSchemable::to_field). +/// +/// **Field precedence**: When the same field appears multiple times, the optimizer +/// maintains one occurrence and removes duplicates (refer to `RequiredIndices::compact()`), +/// preserving the properties and metadata of that occurrence. #[derive(Default, Debug)] pub struct OptimizeProjections {} @@ -154,23 +171,16 @@ fn optimize_projections( // Only use the absolutely necessary aggregate expressions required // by the parent: - let mut new_aggr_expr = aggregate_reqs.get_at_indices(&aggregate.aggr_expr); - - // Aggregations always need at least one aggregate expression. - // With a nested count, we don't require any column as input, but - // still need to create a correct aggregate, which may be optimized - // out later. As an example, consider the following query: - // - // SELECT count(*) FROM (SELECT count(*) FROM [...]) - // - // which always returns 1. - if new_aggr_expr.is_empty() - && new_group_bys.is_empty() - && !aggregate.aggr_expr.is_empty() - { - // take the old, first aggregate expression - new_aggr_expr = aggregate.aggr_expr; - new_aggr_expr.resize_with(1, || unreachable!()); + let new_aggr_expr = aggregate_reqs.get_at_indices(&aggregate.aggr_expr); + + if new_group_bys.is_empty() && new_aggr_expr.is_empty() { + // Global aggregation with no aggregate functions always produces 1 row and no columns. + return Ok(Transformed::yes(LogicalPlan::EmptyRelation( + EmptyRelation { + produce_one_row: true, + schema: Arc::new(DFSchema::empty()), + }, + ))); } let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); @@ -346,12 +356,35 @@ fn optimize_projections( .collect::>>()? } LogicalPlan::EmptyRelation(_) - | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Values(_) | LogicalPlan::DescribeTable(_) => { // These operators have no inputs, so stop the optimization process. return Ok(Transformed::no(plan)); } + LogicalPlan::RecursiveQuery(recursive) => { + // Only allow subqueries that reference the current CTE; nested subqueries are not yet + // supported for projection pushdown for simplicity. + // TODO: be able to do projection pushdown on recursive CTEs with subqueries + if plan_contains_other_subqueries( + recursive.static_term.as_ref(), + &recursive.name, + ) || plan_contains_other_subqueries( + recursive.recursive_term.as_ref(), + &recursive.name, + ) { + return Ok(Transformed::no(plan)); + } + + plan.inputs() + .into_iter() + .map(|input| { + indices + .clone() + .with_projection_beneficial() + .with_plan_exprs(&plan, input.schema()) + }) + .collect::>>()? + } LogicalPlan::Join(join) => { let left_len = join.left.schema().fields().len(); let (left_req_indices, right_req_indices) = @@ -377,11 +410,22 @@ fn optimize_projections( ); } LogicalPlan::Unnest(Unnest { - dependency_indices, .. + input, + dependency_indices, + .. }) => { - vec![RequiredIndices::new_from_indices( - dependency_indices.clone(), - )] + // at least provide the indices for the exec-columns as a starting point + let required_indices = + RequiredIndices::new().with_plan_exprs(&plan, input.schema())?; + + // Add additional required indices from the parent + let mut additional_necessary_child_indices = Vec::new(); + indices.indices().iter().for_each(|idx| { + if let Some(index) = dependency_indices.get(*idx) { + additional_necessary_child_indices.push(*index); + } + }); + vec![required_indices.append(&additional_necessary_child_indices)] } }; @@ -432,6 +476,18 @@ fn optimize_projections( /// appear more than once in its input fields. This can act as a caching mechanism /// for non-trivial computations. /// +/// ## Metadata Handling During Projection Merging +/// +/// **Alias metadata preservation**: When merging projections, alias metadata from both +/// the current and previous projections is carefully preserved. The presence of metadata +/// precludes alias trimming. +/// +/// **Schema, Fields, and metadata**: If a projection is rewritten, the schema and metadata +/// are preserved. Individual field properties and metadata flows through expression rewriting +/// and are preserved when fields are referenced in the merged projection. +/// Refer to [`projection_schema`](datafusion_expr::logical_plan::projection_schema) +/// for more details. +/// /// # Parameters /// /// * `proj` - A reference to the `Projection` to be merged. @@ -455,6 +511,17 @@ fn merge_consecutive_projections(proj: Projection) -> Result::new(); expr.iter() @@ -523,7 +590,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_)) + matches!(expr, Expr::Column(_) | Expr::Literal(_, _)) } /// Rewrites a projection expression using the projection before it (i.e. its input) @@ -544,7 +611,8 @@ fn is_expr_trivial(expr: &Expr) -> bool { /// - `Err(error)`: An error occurred during the function call. /// /// # Notes -/// This rewrite also removes any unnecessary layers of aliasing. +/// This rewrite also removes any unnecessary layers of aliasing. "Unnecessary" is +/// defined as not contributing new information, such as metadata. /// /// Without trimming, we can end up with unnecessary indirections inside expressions /// during projection merges. @@ -573,8 +641,18 @@ fn is_expr_trivial(expr: &Expr) -> bool { fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { expr.transform_up(|expr| { match expr { - // remove any intermediate aliases - Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)), + // remove any intermediate aliases if they do not carry metadata + Expr::Alias(alias) => { + match alias + .metadata + .as_ref() + .map(|h| h.is_empty()) + .unwrap_or(true) + { + true => Ok(Transformed::yes(*alias.expr)), + false => Ok(Transformed::no(Expr::Alias(alias))), + } + } Expr::Column(col) => { // Find index of column: let idx = input.schema.index_of_column(&col)?; @@ -652,10 +730,10 @@ fn outer_columns_helper_multi<'a, 'b>( /// Depending on the join type, it divides the requirement indices into those /// that apply to the left child and those that apply to the right child. /// -/// - For `INNER`, `LEFT`, `RIGHT` and `FULL` joins, the requirements are split -/// between left and right children. The right child indices are adjusted to -/// point to valid positions within the right child by subtracting the length -/// of the left child. +/// - For `INNER`, `LEFT`, `RIGHT`, `FULL`, `LEFTMARK`, and `RIGHTMARK` joins, +/// the requirements are split between left and right children. The right +/// child indices are adjusted to point to valid positions within the right +/// child by subtracting the length of the left child. /// /// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all /// requirements are re-routed to either the left child or the right child @@ -684,7 +762,8 @@ fn split_join_requirements( | JoinType::Left | JoinType::Right | JoinType::Full - | JoinType::LeftMark => { + | JoinType::LeftMark + | JoinType::RightMark => { // Decrease right side indices by `left_len` so that they point to valid // positions within the right child: indices.split_off(left_len) @@ -774,9 +853,83 @@ fn rewrite_projection_given_requirements( /// Projection is unnecessary, when /// - input schema of the projection, output schema of the projection are same, and /// - all projection expressions are either Column or Literal -fn is_projection_unnecessary(input: &LogicalPlan, proj_exprs: &[Expr]) -> Result { - let proj_schema = projection_schema(input, proj_exprs)?; - Ok(&proj_schema == input.schema() && proj_exprs.iter().all(is_expr_trivial)) +pub fn is_projection_unnecessary( + input: &LogicalPlan, + proj_exprs: &[Expr], +) -> Result { + // First check if the number of expressions is equal to the number of fields in the input schema. + if proj_exprs.len() != input.schema().fields().len() { + return Ok(false); + } + Ok(input.schema().iter().zip(proj_exprs.iter()).all( + |((field_relation, field_name), expr)| { + // Check if the expression is a column and if it matches the field name + if let Expr::Column(col) = expr { + col.relation.as_ref() == field_relation && col.name.eq(field_name.name()) + } else { + false + } + }, + )) +} + +/// Returns true if the plan subtree contains any subqueries that are not the +/// CTE reference itself. This treats any non-CTE [`LogicalPlan::SubqueryAlias`] +/// node (including aliased relations) as a blocker, along with expression-level +/// subqueries like scalar, EXISTS, or IN. These cases prevent projection +/// pushdown for now because we cannot safely reason about their column usage. +fn plan_contains_other_subqueries(plan: &LogicalPlan, cte_name: &str) -> bool { + if let LogicalPlan::SubqueryAlias(alias) = plan { + if alias.alias.table() != cte_name + && !subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name) + { + return true; + } + } + + let mut found = false; + plan.apply_expressions(|expr| { + if expr_contains_subquery(expr) { + found = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + .expect("expression traversal never fails"); + if found { + return true; + } + + plan.inputs() + .into_iter() + .any(|child| plan_contains_other_subqueries(child, cte_name)) +} + +fn expr_contains_subquery(expr: &Expr) -> bool { + expr.exists(|e| match e { + Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_) => Ok(true), + _ => Ok(false), + }) + // Safe unwrap since we are doing a simple boolean check + .unwrap() +} + +fn subquery_alias_targets_recursive_cte(plan: &LogicalPlan, cte_name: &str) -> bool { + match plan { + LogicalPlan::TableScan(scan) => scan.table_name.table() == cte_name, + LogicalPlan::SubqueryAlias(alias) => { + subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name) + } + _ => { + let inputs = plan.inputs(); + if inputs.len() == 1 { + subquery_alias_targets_recursive_cte(inputs[0], cte_name) + } else { + false + } + } + } } #[cfg(test)] @@ -791,8 +944,8 @@ mod tests { use crate::optimize_projections::OptimizeProjections; use crate::optimizer::Optimizer; use crate::test::{ - assert_fields_eq, assert_optimized_plan_eq, scan_empty, test_table_scan, - test_table_scan_fields, test_table_scan_with_name, + assert_fields_eq, scan_empty, test_table_scan, test_table_scan_fields, + test_table_scan_with_name, }; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; @@ -810,13 +963,27 @@ mod tests { not, try_cast, when, BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition, }; + use insta::assert_snapshot; + use crate::assert_optimized_plan_eq_snapshot; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::{count, max, min}; use datafusion_functions_aggregate::min_max::max_udaf; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(OptimizeProjections::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[derive(Debug, Hash, PartialEq, Eq)] @@ -848,6 +1015,8 @@ mod tests { Some(Ordering::Equal) => self.input.partial_cmp(&other.input), cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -935,6 +1104,8 @@ mod tests { } cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -1005,9 +1176,13 @@ mod tests { .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? .build()?; - let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int32(1) + test.a + TableScan: test projection=[a] + " + ) } #[test] @@ -1019,9 +1194,13 @@ mod tests { .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? .build()?; - let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int32(1) + test.a + TableScan: test projection=[a] + " + ) } #[test] @@ -1032,9 +1211,13 @@ mod tests { .project(vec![col("a").alias("alias")])? .build()?; - let expected = "Projection: test.a AS alias\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS alias + TableScan: test projection=[a] + " + ) } #[test] @@ -1045,9 +1228,13 @@ mod tests { .project(vec![col("alias2").alias("alias")])? .build()?; - let expected = "Projection: test.a AS alias\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS alias + TableScan: test projection=[a] + " + ) } #[test] @@ -1065,11 +1252,13 @@ mod tests { .build() .unwrap(); - let expected = "Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]\ - \n Projection: \ - \n Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]\ - \n TableScan: ?table? projection=[]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] + EmptyRelation: rows=1 + " + ) } #[test] @@ -1079,9 +1268,13 @@ mod tests { .project(vec![-col("a")])? .build()?; - let expected = "Projection: (- test.a)\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: (- test.a) + TableScan: test projection=[a] + " + ) } #[test] @@ -1091,9 +1284,13 @@ mod tests { .project(vec![col("a").is_null()])? .build()?; - let expected = "Projection: test.a IS NULL\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NULL + TableScan: test projection=[a] + " + ) } #[test] @@ -1103,9 +1300,13 @@ mod tests { .project(vec![col("a").is_not_null()])? .build()?; - let expected = "Projection: test.a IS NOT NULL\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NOT NULL + TableScan: test projection=[a] + " + ) } #[test] @@ -1115,9 +1316,13 @@ mod tests { .project(vec![col("a").is_true()])? .build()?; - let expected = "Projection: test.a IS TRUE\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS TRUE + TableScan: test projection=[a] + " + ) } #[test] @@ -1127,9 +1332,13 @@ mod tests { .project(vec![col("a").is_not_true()])? .build()?; - let expected = "Projection: test.a IS NOT TRUE\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NOT TRUE + TableScan: test projection=[a] + " + ) } #[test] @@ -1139,9 +1348,13 @@ mod tests { .project(vec![col("a").is_false()])? .build()?; - let expected = "Projection: test.a IS FALSE\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS FALSE + TableScan: test projection=[a] + " + ) } #[test] @@ -1151,9 +1364,13 @@ mod tests { .project(vec![col("a").is_not_false()])? .build()?; - let expected = "Projection: test.a IS NOT FALSE\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NOT FALSE + TableScan: test projection=[a] + " + ) } #[test] @@ -1163,9 +1380,13 @@ mod tests { .project(vec![col("a").is_unknown()])? .build()?; - let expected = "Projection: test.a IS UNKNOWN\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS UNKNOWN + TableScan: test projection=[a] + " + ) } #[test] @@ -1175,9 +1396,13 @@ mod tests { .project(vec![col("a").is_not_unknown()])? .build()?; - let expected = "Projection: test.a IS NOT UNKNOWN\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NOT UNKNOWN + TableScan: test projection=[a] + " + ) } #[test] @@ -1187,9 +1412,13 @@ mod tests { .project(vec![not(col("a"))])? .build()?; - let expected = "Projection: NOT test.a\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: NOT test.a + TableScan: test projection=[a] + " + ) } #[test] @@ -1199,9 +1428,13 @@ mod tests { .project(vec![try_cast(col("a"), DataType::Float64)])? .build()?; - let expected = "Projection: TRY_CAST(test.a AS Float64)\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: TRY_CAST(test.a AS Float64) + TableScan: test projection=[a] + " + ) } #[test] @@ -1215,9 +1448,13 @@ mod tests { .project(vec![similar_to_expr])? .build()?; - let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.a SIMILAR TO Utf8("[0-9]") + TableScan: test projection=[a] + "# + ) } #[test] @@ -1227,9 +1464,13 @@ mod tests { .project(vec![col("a").between(lit(1), lit(3))])? .build()?; - let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a BETWEEN Int32(1) AND Int32(3) + TableScan: test projection=[a] + " + ) } // Test Case expression @@ -1246,9 +1487,13 @@ mod tests { ])? .build()?; - let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d + TableScan: test projection=[a] + " + ) } // Test outer projection isn't discarded despite the same schema as inner @@ -1266,11 +1511,14 @@ mod tests { ])? .build()?; - let expected = - "Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d\ - \n Projection: test.a + Int32(1) AS a, Int32(0) AS d\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d + Projection: test.a + Int32(1) AS a, Int32(0) AS d + TableScan: test projection=[a] + " + ) } // Since only column `a` is referred at the output. Scan should only contain projection=[a]. @@ -1288,10 +1536,14 @@ mod tests { .project(vec![col("a"), lit(0).alias("d")])? .build()?; - let expected = "Projection: test.a, Int32(0) AS d\ - \n NoOpUserDefined\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, Int32(0) AS d + NoOpUserDefined + TableScan: test projection=[a] + " + ) } // Only column `a` is referred at the output. However, User defined node itself uses column `b` @@ -1315,10 +1567,14 @@ mod tests { .project(vec![col("a"), lit(0).alias("d")])? .build()?; - let expected = "Projection: test.a, Int32(0) AS d\ - \n NoOpUserDefined\ - \n TableScan: test projection=[a, b]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, Int32(0) AS d + NoOpUserDefined + TableScan: test projection=[a, b] + " + ) } // Only column `a` is referred at the output. However, User defined node itself uses expression `b+c` @@ -1350,10 +1606,14 @@ mod tests { .project(vec![col("a"), lit(0).alias("d")])? .build()?; - let expected = "Projection: test.a, Int32(0) AS d\ - \n NoOpUserDefined\ - \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, Int32(0) AS d + NoOpUserDefined + TableScan: test projection=[a, b, c] + " + ) } // Columns `l.a`, `l.c`, `r.a` is referred at the output. @@ -1374,11 +1634,15 @@ mod tests { .project(vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")])? .build()?; - let expected = "Projection: l.a, l.c, r.a, Int32(0) AS d\ - \n UserDefinedCrossJoin\ - \n TableScan: l projection=[a, c]\ - \n TableScan: r projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: l.a, l.c, r.a, Int32(0) AS d + UserDefinedCrossJoin + TableScan: l projection=[a, c] + TableScan: r projection=[a] + " + ) } #[test] @@ -1389,10 +1653,13 @@ mod tests { .aggregate(Vec::::new(), vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[max(test.b)]]\ - \n TableScan: test projection=[b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[max(test.b)]] + TableScan: test projection=[b] + " + ) } #[test] @@ -1403,10 +1670,13 @@ mod tests { .aggregate(vec![col("c")], vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]]\ - \n TableScan: test projection=[b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]] + TableScan: test projection=[b, c] + " + ) } #[test] @@ -1418,11 +1688,14 @@ mod tests { .aggregate(vec![col("c")], vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]]\ - \n SubqueryAlias: a\ - \n TableScan: test projection=[b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]] + SubqueryAlias: a + TableScan: test projection=[b, c] + " + ) } #[test] @@ -1434,12 +1707,15 @@ mod tests { .aggregate(Vec::::new(), vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[max(test.b)]]\ - \n Projection: test.b\ - \n Filter: test.c > Int32(1)\ - \n TableScan: test projection=[b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[max(test.b)]] + Projection: test.b + Filter: test.c > Int32(1) + TableScan: test projection=[b, c] + " + ) } #[test] @@ -1460,11 +1736,13 @@ mod tests { .project([col(Column::new_unqualified("tag.one"))])? .build()?; - let expected = "\ - Aggregate: groupBy=[[]], aggr=[[max(m4.tag.one) AS tag.one]]\ - \n TableScan: m4 projection=[tag.one]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[max(m4.tag.one) AS tag.one]] + TableScan: m4 projection=[tag.one] + " + ) } #[test] @@ -1475,10 +1753,13 @@ mod tests { .project(vec![col("a"), col("b"), col("c")])? .project(vec![col("a"), col("c"), col("b")])? .build()?; - let expected = "Projection: test.a, test.c, test.b\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.c, test.b + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1486,9 +1767,10 @@ mod tests { let schema = Schema::new(test_table_scan_fields()); let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?; - let expected = "TableScan: test projection=[b, a, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test projection=[b, a, c]" + ) } #[test] @@ -1498,10 +1780,13 @@ mod tests { let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))? .project(vec![col("a"), col("b")])? .build()?; - let expected = "Projection: test.a, test.b\ - \n TableScan: test projection=[b, a]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + TableScan: test projection=[b, a] + " + ) } #[test] @@ -1511,10 +1796,13 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![col("c"), col("b"), col("a")])? .build()?; - let expected = "Projection: test.c, test.b, test.a\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c, test.b, test.a + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1529,14 +1817,18 @@ mod tests { .filter(col("a").gt(lit(1)))? .project(vec![col("a"), col("c"), col("b")])? .build()?; - let expected = "Projection: test.a, test.c, test.b\ - \n Filter: test.a > Int32(1)\ - \n Filter: test.b > Int32(1)\ - \n Projection: test.c, test.a, test.b\ - \n Filter: test.c > Int32(1)\ - \n Projection: test.c, test.b, test.a\ - \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.c, test.b + Filter: test.a > Int32(1) + Filter: test.b > Int32(1) + Projection: test.c, test.a, test.b + Filter: test.c > Int32(1) + Projection: test.c, test.b, test.a + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1551,14 +1843,17 @@ mod tests { .project(vec![col("a"), col("b"), col("c1")])? .build()?; - // make sure projections are pushed down to both table scans - let expected = "Left Join: test.a = test2.c1\ - \n TableScan: test projection=[a, b]\ - \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); + + // make sure projections are pushed down to both table scans + assert_snapshot!( + optimized_plan.clone(), + @r" + Left Join: test.a = test2.c1 + TableScan: test projection=[a, b] + TableScan: test2 projection=[c1] + " + ); // make sure schema for join node include both join columns let optimized_join = optimized_plan; @@ -1602,15 +1897,18 @@ mod tests { .project(vec![col("a"), col("b")])? .build()?; - // make sure projections are pushed down to both table scans - let expected = "Projection: test.a, test.b\ - \n Left Join: test.a = test2.c1\ - \n TableScan: test projection=[a, b]\ - \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); + + // make sure projections are pushed down to both table scans + assert_snapshot!( + optimized_plan.clone(), + @r" + Projection: test.a, test.b + Left Join: test.a = test2.c1 + TableScan: test projection=[a, b] + TableScan: test2 projection=[c1] + " + ); // make sure schema for join node include both join columns let optimized_join = optimized_plan.inputs()[0]; @@ -1648,19 +1946,22 @@ mod tests { let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?; let plan = LogicalPlanBuilder::from(table_scan) - .join_using(table2_scan, JoinType::Left, vec!["a"])? + .join_using(table2_scan, JoinType::Left, vec!["a".into()])? .project(vec![col("a"), col("b")])? .build()?; - // make sure projections are pushed down to table scan - let expected = "Projection: test.a, test.b\ - \n Left Join: Using test.a = test2.a\ - \n TableScan: test projection=[a, b]\ - \n TableScan: test2 projection=[a]"; - let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); + + // make sure projections are pushed down to table scan + assert_snapshot!( + optimized_plan.clone(), + @r" + Projection: test.a, test.b + Left Join: Using test.a = test2.a + TableScan: test projection=[a, b] + TableScan: test2 projection=[a] + " + ); // make sure schema for join node include both join columns let optimized_join = optimized_plan.inputs()[0]; @@ -1692,17 +1993,20 @@ mod tests { fn cast() -> Result<()> { let table_scan = test_table_scan()?; - let projection = LogicalPlanBuilder::from(table_scan) + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![Expr::Cast(Cast::new( Box::new(col("c")), DataType::Float64, ))])? .build()?; - let expected = "Projection: CAST(test.c AS Float64)\ - \n TableScan: test projection=[c]"; - - assert_optimized_plan_equal(projection, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: CAST(test.c AS Float64) + TableScan: test projection=[c] + " + ) } #[test] @@ -1716,9 +2020,10 @@ mod tests { assert_fields_eq(&table_scan, vec!["a", "b", "c"]); assert_fields_eq(&plan, vec!["a", "b"]); - let expected = "TableScan: test projection=[a, b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test projection=[a, b]" + ) } #[test] @@ -1737,9 +2042,10 @@ mod tests { assert_fields_eq(&plan, vec!["a", "b"]); - let expected = "TableScan: test projection=[a, b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test projection=[a, b]" + ) } #[test] @@ -1755,11 +2061,14 @@ mod tests { assert_fields_eq(&plan, vec!["c", "a"]); - let expected = "Limit: skip=0, fetch=5\ - \n Projection: test.c, test.a\ - \n TableScan: test projection=[a, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=5 + Projection: test.c, test.a + TableScan: test projection=[a, c] + " + ) } #[test] @@ -1767,8 +2076,10 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan).build()?; // should expand projection to all columns without projection - let expected = "TableScan: test projection=[a, b, c]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test projection=[a, b, c]" + ) } #[test] @@ -1777,9 +2088,13 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![lit(1_i64), lit(2_i64)])? .build()?; - let expected = "Projection: Int64(1), Int64(2)\ - \n TableScan: test projection=[]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int64(1), Int64(2) + TableScan: test projection=[] + " + ) } /// tests that it removes unused columns in projections @@ -1799,13 +2114,15 @@ mod tests { assert_fields_eq(&plan, vec!["c", "max(test.a)"]); let plan = optimize(plan).expect("failed to optimize plan"); - let expected = "\ - Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]]\ - \n Filter: test.c > Int32(1)\ - \n Projection: test.c, test.a\ - \n TableScan: test projection=[a, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]] + Filter: test.c > Int32(1) + Projection: test.c, test.a + TableScan: test projection=[a, c] + " + ) } /// tests that it removes un-needed projections @@ -1823,11 +2140,13 @@ mod tests { assert_fields_eq(&plan, vec!["a"]); - let expected = "\ - Projection: Int32(1) AS a\ - \n TableScan: test projection=[]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int32(1) AS a + TableScan: test projection=[] + " + ) } #[test] @@ -1852,11 +2171,13 @@ mod tests { assert_fields_eq(&plan, vec!["a"]); - let expected = "\ - Projection: Int32(1) AS a\ - \n TableScan: test projection=[], full_filters=[b = Int32(1)]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int32(1) AS a + TableScan: test projection=[], full_filters=[b = Int32(1)] + " + ) } /// tests that optimizing twice yields same plan @@ -1895,12 +2216,15 @@ mod tests { assert_fields_eq(&plan, vec!["c", "a", "max(test.b)"]); - let expected = "Projection: test.c, test.a, max(test.b)\ - \n Filter: test.c > Int32(1)\ - \n Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b)]]\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c, test.a, max(test.b) + Filter: test.c > Int32(1) + Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b)]] + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1917,10 +2241,13 @@ mod tests { )? .build()?; - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]] + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1933,18 +2260,21 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "Projection: test.a\ - \n Distinct:\ - \n TableScan: test projection=[a, b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Distinct: + TableScan: test projection=[a, b] + " + ) } #[test] fn test_window() -> Result<()> { let table_scan = test_table_scan()?; - let max1 = Expr::WindowFunction(expr::WindowFunction::new( + let max1 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.a")], )) @@ -1952,7 +2282,7 @@ mod tests { .build() .unwrap(); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( + let max2 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.b")], )); @@ -1965,13 +2295,16 @@ mod tests { .project(vec![col1, col2])? .build()?; - let expected = "Projection: max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n Projection: test.b, max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: test projection=[a, b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + Projection: test.b, max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: test projection=[a, b] + " + ) } fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index ffbb95cb7f74e..084152d40e92c 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -33,6 +33,7 @@ use datafusion_common::{internal_err, DFSchema, DataFusionError, HashSet, Result use datafusion_expr::logical_plan::LogicalPlan; use crate::common_subexpr_eliminate::CommonSubexprEliminate; +use crate::decorrelate_lateral_join::DecorrelateLateralJoin; use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; use crate::eliminate_duplicated_expr::EliminateDuplicatedExpr; @@ -106,7 +107,7 @@ pub trait OptimizerConfig { /// Return alias generator used to generate unique aliases for subqueries fn alias_generator(&self) -> &Arc; - fn options(&self) -> &ConfigOptions; + fn options(&self) -> Arc; fn function_registry(&self) -> Option<&dyn FunctionRegistry> { None @@ -124,7 +125,7 @@ pub struct OptimizerContext { /// Alias generator used to generate unique aliases for subqueries alias_generator: Arc, - options: ConfigOptions, + options: Arc, } impl OptimizerContext { @@ -133,6 +134,11 @@ impl OptimizerContext { let mut options = ConfigOptions::default(); options.optimizer.filter_null_join_keys = true; + Self::new_with_config_options(Arc::new(options)) + } + + /// Create a optimizer config with provided [ConfigOptions]. + pub fn new_with_config_options(options: Arc) -> Self { Self { query_execution_start_time: Utc::now(), alias_generator: Arc::new(AliasGenerator::new()), @@ -142,7 +148,9 @@ impl OptimizerContext { /// Specify whether to enable the filter_null_keys rule pub fn filter_null_keys(mut self, filter_null_keys: bool) -> Self { - self.options.optimizer.filter_null_join_keys = filter_null_keys; + Arc::make_mut(&mut self.options) + .optimizer + .filter_null_join_keys = filter_null_keys; self } @@ -159,13 +167,13 @@ impl OptimizerContext { /// Specify whether the optimizer should skip rules that produce /// errors, or fail the query pub fn with_skip_failing_rules(mut self, b: bool) -> Self { - self.options.optimizer.skip_failed_rules = b; + Arc::make_mut(&mut self.options).optimizer.skip_failed_rules = b; self } /// Specify how many times to attempt to optimize the plan pub fn with_max_passes(mut self, v: u8) -> Self { - self.options.optimizer.max_passes = v as usize; + Arc::make_mut(&mut self.options).optimizer.max_passes = v as usize; self } } @@ -186,8 +194,8 @@ impl OptimizerConfig for OptimizerContext { &self.alias_generator } - fn options(&self) -> &ConfigOptions { - &self.options + fn options(&self) -> Arc { + Arc::clone(&self.options) } } @@ -226,6 +234,7 @@ impl Optimizer { Arc::new(EliminateJoin::new()), Arc::new(DecorrelatePredicateSubquery::new()), Arc::new(ScalarSubqueryToJoin::new()), + Arc::new(DecorrelateLateralJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), Arc::new(EliminateDuplicatedExpr::new()), Arc::new(EliminateFilter::new()), @@ -413,7 +422,7 @@ impl Optimizer { previous_plans.insert(LogicalPlanSignature::new(&new_plan)); if !plan_is_fresh { // plan did not change, so no need to continue trying to optimize - debug!("optimizer pass {} did not make changes", i); + debug!("optimizer pass {i} did not make changes"); break; } i += 1; @@ -506,8 +515,11 @@ mod tests { }); let err = opt.optimize(plan, &config, &observe).unwrap_err(); - // Simplify assert to check the error message contains the expected message, which is only the schema length mismatch - assert_contains!(err.strip_backtrace(), "Schema mismatch: the schema length are not same Expected schema length: 3, got: 0"); + // Simplify assert to check the error message contains the expected message + assert_contains!( + err.strip_backtrace(), + "Failed due to a difference in schemas: original schema: DFSchema" + ); } #[test] diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 344707ae8dbe3..4db3215dfb76a 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -242,17 +242,31 @@ mod tests { binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Operator, }; + use crate::assert_optimized_plan_eq_snapshot; use crate::eliminate_filter::EliminateFilter; use crate::eliminate_nested_union::EliminateNestedUnion; use crate::test::{ - assert_optimized_plan_eq, assert_optimized_plan_with_rules, test_table_scan, - test_table_scan_fields, test_table_scan_with_name, + assert_optimized_plan_with_rules, test_table_scan, test_table_scan_fields, + test_table_scan_with_name, }; + use crate::OptimizerContext; use super::*; - fn assert_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(PropagateEmptyRelation::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(PropagateEmptyRelation::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } fn assert_together_optimized_plan( @@ -280,8 +294,7 @@ mod tests { .project(vec![binary_expr(lit(1), Operator::Plus, lit(1))])? .build()?; - let expected = "EmptyRelation"; - assert_eq(plan, expected) + assert_optimized_plan_equal!(plan, @"EmptyRelation: rows=0") } #[test] @@ -303,7 +316,7 @@ mod tests { .filter(col("a").lt_eq(lit(1i64)))? .build()?; - let expected = "EmptyRelation"; + let expected = "EmptyRelation: rows=0"; assert_together_optimized_plan(plan, expected, true) } @@ -366,7 +379,7 @@ mod tests { .union(four)? .build()?; - let expected = "EmptyRelation"; + let expected = "EmptyRelation: rows=0"; assert_together_optimized_plan(plan, expected, true) } @@ -421,7 +434,7 @@ mod tests { .filter(col("a").lt_eq(lit(1i64)))? .build()?; - let expected = "EmptyRelation"; + let expected = "EmptyRelation: rows=0"; assert_together_optimized_plan(plan, expected, true) } @@ -461,7 +474,7 @@ mod tests { )? .build()?; - let expected = "EmptyRelation"; + let expected = "EmptyRelation: rows=0"; assert_together_optimized_plan(plan, expected, eq) } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index c9617514e4539..a8251d6690022 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -20,6 +20,7 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use arrow::datatypes::DataType; use indexmap::IndexSet; use itertools::Itertools; @@ -40,6 +41,7 @@ use datafusion_expr::{ }; use crate::optimizer::ApplyOrder; +use crate::simplify_expressions::simplify_predicates; use crate::utils::{has_all_column_refs, is_restrict_null_predicate}; use crate::{OptimizerConfig, OptimizerRule}; @@ -168,7 +170,7 @@ pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) { JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false), // No columns from the left side of the join can be referenced in output // predicates for semi/anti joins, so whether we specify t/f doesn't matter. - JoinType::RightSemi | JoinType::RightAnti => (false, true), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true), } } @@ -191,6 +193,7 @@ pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) { JoinType::LeftAnti => (false, true), JoinType::RightAnti => (true, false), JoinType::LeftMark => (false, true), + JoinType::RightMark => (true, false), } } @@ -254,7 +257,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { let mut is_evaluate = true; predicate.apply(|expr| match expr { Expr::Column(_) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::Placeholder(_) | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump), Expr::Exists { .. } @@ -691,7 +694,7 @@ fn infer_join_predicates_from_on_filters( inferred_predicates, ) } - JoinType::Right | JoinType::RightSemi => { + JoinType::Right | JoinType::RightSemi | JoinType::RightMark => { infer_join_predicates_impl::( join_col_keys, on_filters, @@ -778,6 +781,18 @@ impl OptimizerRule for PushDownFilter { return Ok(Transformed::no(plan)); }; + let predicate = split_conjunction_owned(filter.predicate.clone()); + let old_predicate_len = predicate.len(); + let new_predicates = simplify_predicates(predicate)?; + if old_predicate_len != new_predicates.len() { + let Some(new_predicate) = conjunction(new_predicates) else { + // new_predicates is empty - remove the filter entirely + // Return the child plan without the filter + return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input))); + }; + filter.predicate = new_predicate; + } + match Arc::unwrap_or_clone(filter.input) { LogicalPlan::Filter(child_filter) => { let parents_predicates = split_conjunction_owned(filter.predicate); @@ -861,14 +876,37 @@ impl OptimizerRule for PushDownFilter { let predicates = split_conjunction_owned(filter.predicate.clone()); let mut non_unnest_predicates = vec![]; let mut unnest_predicates = vec![]; + let mut unnest_struct_columns = vec![]; + + for idx in &unnest.struct_type_columns { + let (sub_qualifier, field) = + unnest.input.schema().qualified_field(*idx); + let field_name = field.name().clone(); + + if let DataType::Struct(children) = field.data_type() { + for child in children { + let child_name = child.name().clone(); + unnest_struct_columns.push(Column::new( + sub_qualifier.cloned(), + format!("{field_name}.{child_name}"), + )); + } + } + } + for predicate in predicates { // collect all the Expr::Column in predicate recursively let mut accum: HashSet = HashSet::new(); expr_to_columns(&predicate, &mut accum)?; - if unnest.list_type_columns.iter().any(|(_, unnest_list)| { - accum.contains(&unnest_list.output_column) - }) { + let contains_list_columns = + unnest.list_type_columns.iter().any(|(_, unnest_list)| { + accum.contains(&unnest_list.output_column) + }); + let contains_struct_columns = + unnest_struct_columns.iter().any(|c| accum.contains(c)); + + if contains_list_columns || contains_struct_columns { unnest_predicates.push(predicate); } else { non_unnest_predicates.push(predicate); @@ -940,8 +978,11 @@ impl OptimizerRule for PushDownFilter { let group_expr_columns = agg .group_expr .iter() - .map(|e| Ok(Column::from_qualified_name(e.schema_name().to_string()))) - .collect::>>()?; + .map(|e| { + let (relation, name) = e.qualified_name(); + Column::new(relation, name) + }) + .collect::>(); let predicates = split_conjunction_owned(filter.predicate); @@ -1009,7 +1050,10 @@ impl OptimizerRule for PushDownFilter { func.params .partition_by .iter() - .map(|c| Column::from_qualified_name(c.schema_name().to_string())) + .map(|c| { + let (relation, name) = c.qualified_name(); + Column::new(relation, name) + }) .collect::>() }; let potential_partition_keys = window @@ -1391,7 +1435,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; - use datafusion_common::{DFSchemaRef, ScalarValue}; + use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; use datafusion_expr::expr::{ScalarFunction, WindowFunction}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ @@ -1401,38 +1445,47 @@ mod tests { WindowFunctionDefinition, }; + use crate::assert_optimized_plan_eq_snapshot; use crate::optimizer::Optimizer; use crate::simplify_expressions::SimplifyExpressions; use crate::test::*; use crate::OptimizerContext; use datafusion_expr::test::function_stub::sum; + use insta::assert_snapshot; use super::*; fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - crate::test::assert_optimized_plan_eq( - Arc::new(PushDownFilter::new()), - plan, - expected, - ) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(PushDownFilter::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } - fn assert_optimized_plan_eq_with_rewrite_predicate( - plan: LogicalPlan, - expected: &str, - ) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![ - Arc::new(SimplifyExpressions::new()), - Arc::new(PushDownFilter::new()), - ]); - let optimized_plan = - optimizer.optimize(plan, &OptimizerContext::new(), observe)?; - - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(expected, formatted_plan); - Ok(()) + macro_rules! assert_optimized_plan_eq_with_rewrite_predicate { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer = Optimizer::with_rules(vec![ + Arc::new(SimplifyExpressions::new()), + Arc::new(PushDownFilter::new()), + ]); + let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?; + assert_snapshot!(optimized_plan, @ $expected); + Ok::<(), DataFusionError>(()) + }}; } #[test] @@ -1443,10 +1496,13 @@ mod tests { .filter(col("a").eq(lit(1i64)))? .build()?; // filter is before projection - let expected = "\ - Projection: test.a, test.b\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } #[test] @@ -1458,12 +1514,15 @@ mod tests { .filter(col("a").eq(lit(1i64)))? .build()?; // filter is before single projection - let expected = "\ - Filter: test.a = Int64(1)\ - \n Limit: skip=0, fetch=10\ - \n Projection: test.a, test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a = Int64(1) + Limit: skip=0, fetch=10 + Projection: test.a, test.b + TableScan: test + " + ) } #[test] @@ -1472,8 +1531,10 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(lit(0i64).eq(lit(1i64)))? .build()?; - let expected = "TableScan: test, full_filters=[Int64(0) = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test, full_filters=[Int64(0) = Int64(1)]" + ) } #[test] @@ -1485,11 +1546,14 @@ mod tests { .filter(col("a").eq(lit(1i64)))? .build()?; // filter is before double projection - let expected = "\ - Projection: test.c, test.b\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c, test.b + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } #[test] @@ -1500,10 +1564,37 @@ mod tests { .filter(col("a").gt(lit(10i64)))? .build()?; // filter of key aggregation is commutative - let expected = "\ - Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]] + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) + } + + /// verifies that filters with unusual column names are pushed down through aggregate operators + #[test] + fn filter_move_agg_special() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("$a", DataType::UInt32, false), + Field::new("$b", DataType::UInt32, false), + Field::new("$c", DataType::UInt32, false), + ]); + let table_scan = table_scan(Some("test"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("$a")], vec![sum(col("$b")).alias("total_salary")])? + .filter(col("$a").gt(lit(10i64)))? + .build()?; + // filter of key aggregation is commutative + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]] + TableScan: test, full_filters=[test.$a > Int64(10)] + " + ) } #[test] @@ -1513,10 +1604,14 @@ mod tests { .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])? .filter(col("b").gt(lit(10i64)))? .build()?; - let expected = "Filter: test.b > Int64(10)\ - \n Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.b > Int64(10) + Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]] + TableScan: test + " + ) } #[test] @@ -1525,10 +1620,13 @@ mod tests { .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])? .filter(col("test.b + test.a").gt(lit(10i64)))? .build()?; - let expected = - "Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]\ - \n TableScan: test, full_filters=[test.b + test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]] + TableScan: test, full_filters=[test.b + test.a > Int64(10)] + " + ) } #[test] @@ -1539,11 +1637,14 @@ mod tests { .filter(col("b").gt(lit(10i64)))? .build()?; // filter of aggregate is after aggregation since they are non-commutative - let expected = "\ - Filter: b > Int64(10)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: b > Int64(10) + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]] + TableScan: test + " + ) } /// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed @@ -1551,7 +1652,7 @@ mod tests { fn filter_move_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1567,10 +1668,48 @@ mod tests { .filter(col("b").gt(lit(10i64)))? .build()?; - let expected = "\ - WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test, full_filters=[test.b > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test, full_filters=[test.b > Int64(10)] + " + ) + } + + /// verifies that filters with unusual identifier names are pushed down through window functions + #[test] + fn filter_window_special_identifier() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("$a", DataType::UInt32, false), + Field::new("$b", DataType::UInt32, false), + Field::new("$c", DataType::UInt32, false), + ]); + let table_scan = table_scan(Some("test"), &schema, None)?.build()?; + + let window = Expr::from(WindowFunction::new( + WindowFunctionDefinition::WindowUDF( + datafusion_functions_window::rank::rank_udwf(), + ), + vec![], + )) + .partition_by(vec![col("$a"), col("$b")]) + .order_by(vec![col("$c").sort(true, true)]) + .build() + .unwrap(); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![window])? + .filter(col("$b").gt(lit(10i64)))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test, full_filters=[test.$b > Int64(10)] + " + ) } /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and @@ -1579,7 +1718,7 @@ mod tests { fn filter_move_complex_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1595,10 +1734,13 @@ mod tests { .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))? .build()?; - let expected = "\ - WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)] + " + ) } /// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed @@ -1606,7 +1748,7 @@ mod tests { fn filter_move_partial_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1622,11 +1764,14 @@ mod tests { .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))? .build()?; - let expected = "\ - Filter: test.b = Int64(1)\ - \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.b = Int64(1) + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) } /// verifies that filters on partition expressions are not pushed, as the single expression @@ -1635,7 +1780,7 @@ mod tests { fn filter_expression_keep_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1653,11 +1798,14 @@ mod tests { .filter(add(col("a"), col("b")).gt(lit(10i64)))? .build()?; - let expected = "\ - Filter: test.a + test.b > Int64(10)\ - \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a + test.b > Int64(10) + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test + " + ) } /// verifies that filters are not pushed on order by columns (that are not used in partitioning) @@ -1665,7 +1813,7 @@ mod tests { fn filter_order_keep_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1681,11 +1829,14 @@ mod tests { .filter(col("c").gt(lit(10i64)))? .build()?; - let expected = "\ - Filter: test.c > Int64(10)\ - \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.c > Int64(10) + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test + " + ) } /// verifies that when we use multiple window functions with a common partition key, the filter @@ -1694,7 +1845,7 @@ mod tests { fn filter_multiple_windows_common_partitions() -> Result<()> { let table_scan = test_table_scan()?; - let window1 = Expr::WindowFunction(WindowFunction::new( + let window1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1705,7 +1856,7 @@ mod tests { .build() .unwrap(); - let window2 = Expr::WindowFunction(WindowFunction::new( + let window2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1721,10 +1872,13 @@ mod tests { .filter(col("a").gt(lit(10i64)))? // a appears in both window functions .build()?; - let expected = "\ - WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) } /// verifies that when we use multiple window functions with different partitions keys, the @@ -1733,7 +1887,7 @@ mod tests { fn filter_multiple_windows_disjoint_partitions() -> Result<()> { let table_scan = test_table_scan()?; - let window1 = Expr::WindowFunction(WindowFunction::new( + let window1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1744,7 +1898,7 @@ mod tests { .build() .unwrap(); - let window2 = Expr::WindowFunction(WindowFunction::new( + let window2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1760,11 +1914,14 @@ mod tests { .filter(col("b").gt(lit(10i64)))? // b only appears in one window function .build()?; - let expected = "\ - Filter: test.b > Int64(10)\ - \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.b > Int64(10) + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test + " + ) } /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written @@ -1776,10 +1933,13 @@ mod tests { .filter(col("b").eq(lit(1i64)))? .build()?; // filter is before projection - let expected = "\ - Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } fn add(left: Expr, right: Expr) -> Expr { @@ -1811,19 +1971,21 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "\ - Filter: b = Int64(1)\ - \n Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: b = Int64(1) + Projection: test.a * Int32(2) + test.c AS b, test.c + TableScan: test + ", ); - // filter is before projection - let expected = "\ - Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a * Int32(2) + test.c AS b, test.c + TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)] + " + ) } /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written @@ -1841,21 +2003,23 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "\ - Filter: a = Int64(1)\ - \n Projection: b * Int32(3) AS a, test.c\ - \n Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: a = Int64(1) + Projection: b * Int32(3) AS a, test.c + Projection: test.a * Int32(2) + test.c AS b, test.c + TableScan: test + ", ); - // filter is before the projections - let expected = "\ - Projection: b * Int32(3) AS a, test.c\ - \n Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: b * Int32(3) AS a, test.c + Projection: test.a * Int32(2) + test.c AS b, test.c + TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)] + " + ) } #[derive(Debug, PartialEq, Eq, Hash)] @@ -1867,7 +2031,10 @@ mod tests { // Manual implementation needed because of `schema` field. Comparison excludes this field. impl PartialOrd for NoopPlan { fn partial_cmp(&self, other: &Self) -> Option { - self.input.partial_cmp(&other.input) + self.input + .partial_cmp(&other.input) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -1930,10 +2097,13 @@ mod tests { .build()?; // Push filter below NoopPlan - let expected = "\ - NoopPlan\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @r" + NoopPlan + TableScan: test, full_filters=[test.a = Int64(1)] + " + )?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1946,11 +2116,14 @@ mod tests { .build()?; // Push only predicate on `a` below NoopPlan - let expected = "\ - Filter: test.c = Int64(2)\ - \n NoopPlan\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.c = Int64(2) + NoopPlan + TableScan: test, full_filters=[test.a = Int64(1)] + " + )?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1963,11 +2136,14 @@ mod tests { .build()?; // Push filter below NoopPlan for each child branch - let expected = "\ - NoopPlan\ - \n TableScan: test, full_filters=[test.a = Int64(1)]\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @r" + NoopPlan + TableScan: test, full_filters=[test.a = Int64(1)] + TableScan: test, full_filters=[test.a = Int64(1)] + " + )?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1980,12 +2156,15 @@ mod tests { .build()?; // Push only predicate on `a` below NoopPlan - let expected = "\ - Filter: test.c = Int64(2)\ - \n NoopPlan\ - \n TableScan: test, full_filters=[test.a = Int64(1)]\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.c = Int64(2) + NoopPlan + TableScan: test, full_filters=[test.a = Int64(1)] + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed @@ -2002,23 +2181,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "\ - Filter: sum(test.c) > Int64(10)\ - \n Filter: b > Int64(10)\ - \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: sum(test.c) > Int64(10) + Filter: b > Int64(10) + Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]] + Projection: test.a AS b, test.c + TableScan: test + ", ); - // filter is before the projections - let expected = "\ - Filter: sum(test.c) > Int64(10)\ - \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: sum(test.c) > Int64(10) + Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]] + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) } /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed @@ -2037,22 +2218,24 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "\ - Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)\ - \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20) + Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]] + Projection: test.a AS b, test.c + TableScan: test + ", ); - // filter is before the projections - let expected = "\ - Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)\ - \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20) + Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]] + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) } /// verifies that when two limits are in place, we jump neither @@ -2067,14 +2250,17 @@ mod tests { .filter(col("a").eq(lit(1i64)))? .build()?; // filter does not just any of the limits - let expected = "\ - Projection: test.a, test.b\ - \n Filter: test.a = Int64(1)\ - \n Limit: skip=0, fetch=10\ - \n Limit: skip=0, fetch=20\ - \n Projection: test.a, test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + Filter: test.a = Int64(1) + Limit: skip=0, fetch=10 + Limit: skip=0, fetch=20 + Projection: test.a, test.b + TableScan: test + " + ) } #[test] @@ -2086,10 +2272,14 @@ mod tests { .filter(col("a").eq(lit(1i64)))? .build()?; // filter appears below Union - let expected = "Union\ - \n TableScan: test, full_filters=[test.a = Int64(1)]\ - \n TableScan: test2, full_filters=[test2.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Union + TableScan: test, full_filters=[test.a = Int64(1)] + TableScan: test2, full_filters=[test2.a = Int64(1)] + " + ) } #[test] @@ -2106,13 +2296,18 @@ mod tests { .build()?; // filter appears below Union - let expected = "Union\n SubqueryAlias: test2\ - \n Projection: test.a AS b\ - \n TableScan: test, full_filters=[test.a = Int64(1)]\ - \n SubqueryAlias: test2\ - \n Projection: test.a AS b\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Union + SubqueryAlias: test2 + Projection: test.a AS b + TableScan: test, full_filters=[test.a = Int64(1)] + SubqueryAlias: test2 + Projection: test.a AS b + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } #[test] @@ -2136,14 +2331,17 @@ mod tests { .filter(filter)? .build()?; - let expected = "Projection: test.a, test1.d\ - \n Cross Join: \ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.a = Int32(1)]\ - \n Projection: test1.d, test1.e, test1.f\ - \n TableScan: test1, full_filters=[test1.d > Int32(2)]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test1.d + Cross Join: + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.a = Int32(1)] + Projection: test1.d, test1.e, test1.f + TableScan: test1, full_filters=[test1.d > Int32(2)] + " + ) } #[test] @@ -2163,13 +2361,17 @@ mod tests { .filter(filter)? .build()?; - let expected = "Projection: test.a, test1.a\ - \n Cross Join: \ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.a = Int32(1)]\ - \n Projection: test1.a, test1.b, test1.c\ - \n TableScan: test1, full_filters=[test1.a > Int32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test1.a + Cross Join: + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.a = Int32(1)] + Projection: test1.a, test1.b, test1.c + TableScan: test1, full_filters=[test1.a > Int32(2)] + " + ) } /// verifies that filters with the same columns are correctly placed @@ -2186,24 +2388,26 @@ mod tests { // Should be able to move both filters below the projections // not part of the test - assert_eq!( - format!("{plan}"), - "Filter: test.a >= Int64(1)\ - \n Projection: test.a\ - \n Limit: skip=0, fetch=1\ - \n Filter: test.a <= Int64(1)\ - \n Projection: test.a\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: test.a >= Int64(1) + Projection: test.a + Limit: skip=0, fetch=1 + Filter: test.a <= Int64(1) + Projection: test.a + TableScan: test + ", ); - - let expected = "\ - Projection: test.a\ - \n Filter: test.a >= Int64(1)\ - \n Limit: skip=0, fetch=1\ - \n Projection: test.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Filter: test.a >= Int64(1) + Limit: skip=0, fetch=1 + Projection: test.a + TableScan: test, full_filters=[test.a <= Int64(1)] + " + ) } /// verifies that filters to be placed on the same depth are ANDed @@ -2218,22 +2422,24 @@ mod tests { .build()?; // not part of the test - assert_eq!( - format!("{plan}"), - "Projection: test.a\ - \n Filter: test.a >= Int64(1)\ - \n Filter: test.a <= Int64(1)\ - \n Limit: skip=0, fetch=1\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Projection: test.a + Filter: test.a >= Int64(1) + Filter: test.a <= Int64(1) + Limit: skip=0, fetch=1 + TableScan: test + ", ); - - let expected = "\ - Projection: test.a\ - \n Filter: test.a >= Int64(1) AND test.a <= Int64(1)\ - \n Limit: skip=0, fetch=1\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Filter: test.a >= Int64(1) AND test.a <= Int64(1) + Limit: skip=0, fetch=1 + TableScan: test + " + ) } /// verifies that filters on a plan with user nodes are not lost @@ -2247,19 +2453,21 @@ mod tests { let plan = user_defined::new(plan); - let expected = "\ - TestUserDefined\ - \n Filter: test.a <= Int64(1)\ - \n TableScan: test"; - // not part of the test - assert_eq!(format!("{plan}"), expected); - - let expected = "\ - TestUserDefined\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - - assert_optimized_plan_eq(plan, expected) + assert_snapshot!(plan, + @r" + TestUserDefined + Filter: test.a <= Int64(1) + TableScan: test + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + TestUserDefined + TableScan: test, full_filters=[test.a <= Int64(1)] + " + ) } /// post-on-join predicates on a column common to both sides is pushed to both sides @@ -2282,22 +2490,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.a <= Int64(1)\ - \n Inner Join: test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.a <= Int64(1) + Inner Join: test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter sent to side before the join - let expected = "\ - Inner Join: test.a = test2.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]\ - \n Projection: test2.a\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.a + TableScan: test, full_filters=[test.a <= Int64(1)] + Projection: test2.a + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } /// post-using-join predicates on a column common to both sides is pushed to both sides @@ -2319,22 +2530,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.a <= Int64(1)\ - \n Inner Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.a <= Int64(1) + Inner Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter sent to side before the join - let expected = "\ - Inner Join: Using test.a = test2.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]\ - \n Projection: test2.a\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: Using test.a = test2.a + TableScan: test, full_filters=[test.a <= Int64(1)] + Projection: test2.a + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } /// post-join predicates with columns from both sides are converted to join filters @@ -2359,24 +2573,27 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.c <= test2.b\ - \n Inner Join: test.a = test2.a\ - \n Projection: test.a, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.c <= test2.b + Inner Join: test.a = test2.a + Projection: test.a, test.c + TableScan: test + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Filter is converted to Join Filter - let expected = "\ - Inner Join: test.a = test2.a Filter: test.c <= test2.b\ - \n Projection: test.a, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.a Filter: test.c <= test2.b + Projection: test.a, test.c + TableScan: test + Projection: test2.a, test2.b + TableScan: test2 + " + ) } /// post-join predicates with columns from one side of a join are pushed only to that side @@ -2402,23 +2619,26 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.b <= Int64(1)\ - \n Inner Join: test.a = test2.a\ - \n Projection: test.a, test.b\ - \n TableScan: test\ - \n Projection: test2.a, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.b <= Int64(1) + Inner Join: test.a = test2.a + Projection: test.a, test.b + TableScan: test + Projection: test2.a, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Inner Join: test.a = test2.a\ - \n Projection: test.a, test.b\ - \n TableScan: test, full_filters=[test.b <= Int64(1)]\ - \n Projection: test2.a, test2.c\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.a + Projection: test.a, test.b + TableScan: test, full_filters=[test.b <= Int64(1)] + Projection: test2.a, test2.c + TableScan: test2 + " + ) } /// post-join predicates on the right side of a left join are not duplicated @@ -2441,23 +2661,26 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test2.a <= Int64(1)\ - \n Left Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test2.a <= Int64(1) + Left Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter not duplicated nor pushed down - i.e. noop - let expected = "\ - Filter: test2.a <= Int64(1)\ - \n Left Join: Using test.a = test2.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]\ - \n Projection: test2.a\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test2.a <= Int64(1) + Left Join: Using test.a = test2.a + TableScan: test, full_filters=[test.a <= Int64(1)] + Projection: test2.a + TableScan: test2 + " + ) } /// post-join predicates on the left side of a right join are not duplicated @@ -2479,23 +2702,26 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.a <= Int64(1)\ - \n Right Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.a <= Int64(1) + Right Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter not duplicated nor pushed down - i.e. noop - let expected = "\ - Filter: test.a <= Int64(1)\ - \n Right Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a <= Int64(1) + Right Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } /// post-left-join predicate on a column common to both sides is only pushed to the left side @@ -2518,22 +2744,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.a <= Int64(1)\ - \n Left Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.a <= Int64(1) + Left Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter sent to left side of the join, not the right - let expected = "\ - Left Join: Using test.a = test2.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]\ - \n Projection: test2.a\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: Using test.a = test2.a + TableScan: test, full_filters=[test.a <= Int64(1)] + Projection: test2.a + TableScan: test2 + " + ) } /// post-right-join predicate on a column common to both sides is only pushed to the right side @@ -2556,22 +2785,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test2.a <= Int64(1)\ - \n Right Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test2.a <= Int64(1) + Right Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter sent to right side of join, not duplicated to the left - let expected = "\ - Right Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Right Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } /// single table predicate parts of ON condition should be pushed to both inputs @@ -2599,22 +2831,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Inner Join: test.a = test2.a Filter: test.b < test2.b\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.c > UInt32(1)]\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.a Filter: test.b < test2.b + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.c > UInt32(1)] + Projection: test2.a, test2.b, test2.c + TableScan: test2, full_filters=[test2.c > UInt32(4)] + " + ) } /// join filter should be completely removed after pushdown @@ -2641,22 +2876,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Inner Join: test.a = test2.a\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.b > UInt32(1)]\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.a + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.b > UInt32(1)] + Projection: test2.a, test2.b, test2.c + TableScan: test2, full_filters=[test2.c > UInt32(4)] + " + ) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -2681,22 +2919,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Inner Join: test.a = test2.b Filter: test.a > UInt32(1)\ - \n Projection: test.a\ - \n TableScan: test\ - \n Projection: test2.b\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Inner Join: test.a = test2.b Filter: test.a > UInt32(1) + Projection: test.a + TableScan: test + Projection: test2.b + TableScan: test2 + ", ); - - let expected = "\ - Inner Join: test.a = test2.b\ - \n Projection: test.a\ - \n TableScan: test, full_filters=[test.a > UInt32(1)]\ - \n Projection: test2.b\ - \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.b + Projection: test.a + TableScan: test, full_filters=[test.a > UInt32(1)] + Projection: test2.b + TableScan: test2, full_filters=[test2.b > UInt32(1)] + " + ) } /// single table predicate parts of ON condition should be pushed to right input @@ -2724,22 +2965,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2, full_filters=[test2.c > UInt32(4)] + " + ) } /// single table predicate parts of ON condition should be pushed to left input @@ -2767,22 +3011,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.a > UInt32(1)]\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.a > UInt32(1)] + Projection: test2.a, test2.b, test2.c + TableScan: test2 + " + ) } /// single table predicate parts of ON condition should not be pushed @@ -2810,17 +3057,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = &format!("{plan}"); - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + " + ) } struct PushDownProvider { @@ -2864,9 +3119,7 @@ mod tests { let table_scan = LogicalPlan::TableScan(TableScan { table_name: "test".into(), filters, - projected_schema: Arc::new(DFSchema::try_from( - (*test_provider.schema()).clone(), - )?), + projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?), projection, source: Arc::new(test_provider), fetch: None, @@ -2887,9 +3140,10 @@ mod tests { fn filter_with_table_provider_exact() -> Result<()> { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?; - let expected = "\ - TableScan: test, full_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test, full_filters=[a = Int64(1)]" + ) } #[test] @@ -2897,10 +3151,13 @@ mod tests { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?; - let expected = "\ - Filter: a = Int64(1)\ - \n TableScan: test, partial_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: a = Int64(1) + TableScan: test, partial_filters=[a = Int64(1)] + " + ) } #[test] @@ -2913,13 +3170,15 @@ mod tests { .expect("failed to optimize plan") .data; - let expected = "\ - Filter: a = Int64(1)\ - \n TableScan: test, partial_filters=[a = Int64(1)]"; - // Optimizing the same plan multiple times should produce the same plan // each time. - assert_optimized_plan_eq(optimized_plan, expected) + assert_optimized_plan_equal!( + optimized_plan, + @r" + Filter: a = Int64(1) + TableScan: test, partial_filters=[a = Int64(1)] + " + ) } #[test] @@ -2927,10 +3186,13 @@ mod tests { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?; - let expected = "\ - Filter: a = Int64(1)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: a = Int64(1) + TableScan: test + " + ) } #[test] @@ -2944,11 +3206,14 @@ mod tests { .project(vec![col("a"), col("b")])? .build()?; - let expected = "Projection: a, b\ - \n Filter: a = Int64(10) AND b > Int64(11)\ - \n TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, b + Filter: a = Int64(10) AND b > Int64(11) + TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)] + " + ) } #[test] @@ -2962,13 +3227,13 @@ mod tests { .project(vec![col("a"), col("b")])? .build()?; - let expected = r#" -Projection: a, b - TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)] - "# - .trim(); - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, b + TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)] + " + ) } #[test] @@ -2983,20 +3248,21 @@ Projection: a, b .build()?; // filter on col b - assert_eq!( - format!("{plan}"), - "Filter: b > Int64(10) AND test.c > Int64(10)\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: b > Int64(10) AND test.c > Int64(10) + Projection: test.a AS b, test.c + TableScan: test + ", ); - // rewrite filter col b to test.a - let expected = "\ - Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ - "; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)] + " + ) } #[test] @@ -3012,23 +3278,23 @@ Projection: a, b .build()?; // filter on col b - assert_eq!( - format!("{plan}"), - "Filter: b > Int64(10) AND test.c > Int64(10)\ - \n Projection: b, test.c\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test\ - " + assert_snapshot!(plan, + @r" + Filter: b > Int64(10) AND test.c > Int64(10) + Projection: b, test.c + Projection: test.a AS b, test.c + TableScan: test + ", ); - // rewrite filter col b to test.a - let expected = "\ - Projection: b, test.c\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ - "; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: b, test.c + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)] + " + ) } #[test] @@ -3040,20 +3306,21 @@ Projection: a, b .build()?; // filter on col b and d - assert_eq!( - format!("{plan}"), - "Filter: b > Int64(10) AND d > Int64(10)\ - \n Projection: test.a AS b, test.c AS d\ - \n TableScan: test\ - " + assert_snapshot!(plan, + @r" + Filter: b > Int64(10) AND d > Int64(10) + Projection: test.a AS b, test.c AS d + TableScan: test + ", ); - // rewrite filter col b to test.a, col d to test.c - let expected = "\ - Projection: test.a AS b, test.c AS d\ - \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c AS d + TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)] + " + ) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -3077,23 +3344,26 @@ Projection: a, b )? .build()?; - assert_eq!( - format!("{plan}"), - "Inner Join: c = d Filter: c > UInt32(1)\ - \n Projection: test.a AS c\ - \n TableScan: test\ - \n Projection: test2.b AS d\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Inner Join: c = d Filter: c > UInt32(1) + Projection: test.a AS c + TableScan: test + Projection: test2.b AS d + TableScan: test2 + ", ); - // Change filter on col `c`, 'd' to `test.a`, 'test.b' - let expected = "\ - Inner Join: c = d\ - \n Projection: test.a AS c\ - \n TableScan: test, full_filters=[test.a > UInt32(1)]\ - \n Projection: test2.b AS d\ - \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: c = d + Projection: test.a AS c + TableScan: test, full_filters=[test.a > UInt32(1)] + Projection: test2.b AS d + TableScan: test2, full_filters=[test2.b > UInt32(1)] + " + ) } #[test] @@ -3109,20 +3379,21 @@ Projection: a, b .build()?; // filter on col b - assert_eq!( - format!("{plan}"), - "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test\ - " + assert_snapshot!(plan, + @r" + Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)]) + Projection: test.a AS b, test.c + TableScan: test + ", ); - // rewrite filter col b to test.a - let expected = "\ - Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])] + " + ) } #[test] @@ -3139,22 +3410,23 @@ Projection: a, b .build()?; // filter on col b - assert_eq!( - format!("{plan}"), - "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\ - \n Projection: b, test.c\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test\ - " + assert_snapshot!(plan, + @r" + Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)]) + Projection: b, test.c + Projection: test.a AS b, test.c + TableScan: test + ", ); - // rewrite filter col b to test.a - let expected = "\ - Projection: b, test.c\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: b, test.c + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])] + " + ) } #[test] @@ -3174,23 +3446,27 @@ Projection: a, b .build()?; // filter on col b in subquery - let expected_before = "\ - Filter: b IN ()\ - \n Subquery:\ - \n Projection: sq.c\ - \n TableScan: sq\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test"; - assert_eq!(format!("{plan}"), expected_before); - + assert_snapshot!(plan, + @r" + Filter: b IN () + Subquery: + Projection: sq.c + TableScan: sq + Projection: test.a AS b, test.c + TableScan: test + ", + ); // rewrite filter col b to test.a - let expected_after = "\ - Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a IN ()]\ - \n Subquery:\ - \n Projection: sq.c\ - \n TableScan: sq"; - assert_optimized_plan_eq(plan, expected_after) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a IN ()] + Subquery: + Projection: sq.c + TableScan: sq + " + ) } #[test] @@ -3205,25 +3481,31 @@ Projection: a, b .project(vec![col("b.a")])? .build()?; - let expected_before = "Projection: b.a\ - \n Filter: b.a = Int64(1)\ - \n SubqueryAlias: b\ - \n Projection: b.a\ - \n SubqueryAlias: b\ - \n Projection: Int64(0) AS a\ - \n EmptyRelation"; - assert_eq!(format!("{plan}"), expected_before); - + assert_snapshot!(plan, + @r" + Projection: b.a + Filter: b.a = Int64(1) + SubqueryAlias: b + Projection: b.a + SubqueryAlias: b + Projection: Int64(0) AS a + EmptyRelation: rows=1 + ", + ); // Ensure that the predicate without any columns (0 = 1) is // still there. - let expected_after = "Projection: b.a\ - \n SubqueryAlias: b\ - \n Projection: b.a\ - \n SubqueryAlias: b\ - \n Projection: Int64(0) AS a\ - \n Filter: Int64(0) = Int64(1)\ - \n EmptyRelation"; - assert_optimized_plan_eq(plan, expected_after) + assert_optimized_plan_equal!( + plan, + @r" + Projection: b.a + SubqueryAlias: b + Projection: b.a + SubqueryAlias: b + Projection: Int64(0) AS a + Filter: Int64(0) = Int64(1) + EmptyRelation: rows=1 + " + ) } #[test] @@ -3245,13 +3527,14 @@ Projection: a, b .cross_join(right)? .filter(filter)? .build()?; - let expected = "\ - Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ - \n Projection: test1.a AS d, test1.a AS e\ - \n TableScan: test1"; - assert_optimized_plan_eq_with_rewrite_predicate(plan.clone(), expected)?; + + assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r" + Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10) + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)] + Projection: test1.a AS d, test1.a AS e + TableScan: test1 + ")?; // Originally global state which can help to avoid duplicate Filters been generated and pushed down. // Now the global state is removed. Need to double confirm that avoid duplicate Filters. @@ -3259,7 +3542,16 @@ Projection: a, b .rewrite(plan, &OptimizerContext::new()) .expect("failed to optimize plan") .data; - assert_optimized_plan_eq(optimized_plan, expected) + assert_optimized_plan_equal!( + optimized_plan, + @r" + Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10) + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)] + Projection: test1.a AS d, test1.a AS e + TableScan: test1 + " + ) } #[test] @@ -3283,23 +3575,26 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test2.a <= Int64(1)\ - \n LeftSemi Join: test1.a = test2.a\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test2.a <= Int64(1) + LeftSemi Join: test1.a = test2.a + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side. - let expected = "\ - Filter: test2.a <= Int64(1)\ - \n LeftSemi Join: test1.a = test2.a\ - \n TableScan: test1, full_filters=[test1.a <= Int64(1)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test2.a <= Int64(1) + LeftSemi Join: test1.a = test2.a + TableScan: test1, full_filters=[test1.a <= Int64(1)] + Projection: test2.a, test2.b + TableScan: test2 + " + ) } #[test] @@ -3326,21 +3621,24 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Both side will be pushed down. - let expected = "\ - LeftSemi Join: test1.a = test2.a\ - \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + LeftSemi Join: test1.a = test2.a + TableScan: test1, full_filters=[test1.b > UInt32(1)] + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.b > UInt32(2)] + " + ) } #[test] @@ -3364,23 +3662,26 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test1.a <= Int64(1)\ - \n RightSemi Join: test1.a = test2.a\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + Filter: test1.a <= Int64(1) + RightSemi Join: test1.a = test2.a + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side. - let expected = "\ - Filter: test1.a <= Int64(1)\ - \n RightSemi Join: test1.a = test2.a\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test1.a <= Int64(1) + RightSemi Join: test1.a = test2.a + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } #[test] @@ -3407,21 +3708,24 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Both side will be pushed down. - let expected = "\ - RightSemi Join: test1.a = test2.a\ - \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + RightSemi Join: test1.a = test2.a + TableScan: test1, full_filters=[test1.b > UInt32(1)] + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.b > UInt32(2)] + " + ) } #[test] @@ -3448,25 +3752,28 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test2.a > UInt32(2)\ - \n LeftAnti Join: test1.a = test2.a\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + Filter: test2.a > UInt32(2) + LeftAnti Join: test1.a = test2.a + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // For left anti, filter of the right side filter can be pushed down. - let expected = "\ - Filter: test2.a > UInt32(2)\ - \n LeftAnti Join: test1.a = test2.a\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1, full_filters=[test1.a > UInt32(2)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test2.a > UInt32(2) + LeftAnti Join: test1.a = test2.a + Projection: test1.a, test1.b + TableScan: test1, full_filters=[test1.a > UInt32(2)] + Projection: test2.a, test2.b + TableScan: test2 + " + ) } #[test] @@ -3496,23 +3803,26 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // For left anti, filter of the right side filter can be pushed down. - let expected = "\ - LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.b > UInt32(2)] + " + ) } #[test] @@ -3539,25 +3849,28 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test1.a > UInt32(2)\ - \n RightAnti Join: test1.a = test2.a\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + Filter: test1.a > UInt32(2) + RightAnti Join: test1.a = test2.a + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // For right anti, filter of the left side can be pushed down. - let expected = "\ - Filter: test1.a > UInt32(2)\ - \n RightAnti Join: test1.a = test2.a\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.a > UInt32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test1.a > UInt32(2) + RightAnti Join: test1.a = test2.a + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.a > UInt32(2)] + " + ) } #[test] @@ -3587,25 +3900,29 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // For right anti, filter of the left side can be pushed down. - let expected = "RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2) + Projection: test1.a, test1.b + TableScan: test1, full_filters=[test1.b > UInt32(1)] + Projection: test2.a, test2.b + TableScan: test2 + " + ) } - #[derive(Debug)] + #[derive(Debug, PartialEq, Eq, Hash)] struct TestScalarUDF { signature: Signature, } @@ -3648,21 +3965,27 @@ Projection: a, b .project(vec![col("t.a"), col("t.r")])? .build()?; - let expected_before = "Projection: t.a, t.r\ - \n Filter: t.a > Int32(5) AND t.r > Float64(0.5)\ - \n SubqueryAlias: t\ - \n Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r\ - \n Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]\ - \n TableScan: test1"; - assert_eq!(format!("{plan}"), expected_before); - - let expected_after = "Projection: t.a, t.r\ - \n SubqueryAlias: t\ - \n Filter: r > Float64(0.5)\ - \n Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r\ - \n Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]\ - \n TableScan: test1, full_filters=[test1.a > Int32(5)]"; - assert_optimized_plan_eq(plan, expected_after) + assert_snapshot!(plan, + @r" + Projection: t.a, t.r + Filter: t.a > Int32(5) AND t.r > Float64(0.5) + SubqueryAlias: t + Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r + Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]] + TableScan: test1 + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: t.a, t.r + SubqueryAlias: t + Filter: r > Float64(0.5) + Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r + Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]] + TableScan: test1, full_filters=[test1.a > Int32(5)] + " + ) } #[test] @@ -3692,23 +4015,29 @@ Projection: a, b .project(vec![col("t.a"), col("t.r")])? .build()?; - let expected_before = "Projection: t.a, t.r\ - \n Filter: t.r > Float64(0.8)\ - \n SubqueryAlias: t\ - \n Projection: test1.a AS a, TestScalarUDF() AS r\ - \n Inner Join: test1.a = test2.a\ - \n TableScan: test1\ - \n TableScan: test2"; - assert_eq!(format!("{plan}"), expected_before); - - let expected = "Projection: t.a, t.r\ - \n SubqueryAlias: t\ - \n Filter: r > Float64(0.8)\ - \n Projection: test1.a AS a, TestScalarUDF() AS r\ - \n Inner Join: test1.a = test2.a\ - \n TableScan: test1\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_snapshot!(plan, + @r" + Projection: t.a, t.r + Filter: t.r > Float64(0.8) + SubqueryAlias: t + Projection: test1.a AS a, TestScalarUDF() AS r + Inner Join: test1.a = test2.a + TableScan: test1 + TableScan: test2 + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: t.a, t.r + SubqueryAlias: t + Filter: r > Float64(0.8) + Projection: test1.a AS a, TestScalarUDF() AS r + Inner Join: test1.a = test2.a + TableScan: test1 + TableScan: test2 + " + ) } #[test] @@ -3724,15 +4053,21 @@ Projection: a, b .filter(expr.gt(lit(0.1)))? .build()?; - let expected_before = "Filter: TestScalarUDF() > Float64(0.1)\ - \n Projection: test.a, test.b\ - \n TableScan: test"; - assert_eq!(format!("{plan}"), expected_before); - - let expected_after = "Projection: test.a, test.b\ - \n Filter: TestScalarUDF() > Float64(0.1)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected_after) + assert_snapshot!(plan, + @r" + Filter: TestScalarUDF() > Float64(0.1) + Projection: test.a, test.b + TableScan: test + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + Filter: TestScalarUDF() > Float64(0.1) + TableScan: test + " + ) } #[test] @@ -3752,15 +4087,21 @@ Projection: a, b )? .build()?; - let expected_before = "Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)\ - \n Projection: test.a, test.b\ - \n TableScan: test"; - assert_eq!(format!("{plan}"), expected_before); - - let expected_after = "Projection: test.a, test.b\ - \n Filter: TestScalarUDF() > Float64(0.1)\ - \n TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]"; - assert_optimized_plan_eq(plan, expected_after) + assert_snapshot!(plan, + @r" + Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10) + Projection: test.a, test.b + TableScan: test + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + Filter: TestScalarUDF() > Float64(0.1) + TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)] + " + ) } #[test] @@ -3783,15 +4124,21 @@ Projection: a, b )? .build()?; - let expected_before = "Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)\ - \n Projection: a, b\ - \n TableScan: test"; - assert_eq!(format!("{plan}"), expected_before); - - let expected_after = "Projection: a, b\ - \n Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected_after) + assert_snapshot!(plan, + @r" + Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10) + Projection: a, b + TableScan: test + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, b + Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1) + TableScan: test + " + ) } #[test] @@ -3864,12 +4211,19 @@ Projection: a, b let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?; // Check the original plan format (not part of the test assertions) - let expected_before = "Filter: Boolean(false)\ - \n TestUserNode"; - assert_eq!(format!("{plan}"), expected_before); - + assert_snapshot!(plan, + @r" + Filter: Boolean(false) + TestUserNode + ", + ); // Check that the filter is pushed down to the user-defined node - let expected_after = "Filter: Boolean(false)\n TestUserNode"; - assert_optimized_plan_eq(plan, expected_after) + assert_optimized_plan_equal!( + plan, + @r" + Filter: Boolean(false) + TestUserNode + " + ) } } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 1e9ef16bde675..c5a2e65788051 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -276,8 +276,10 @@ mod test { use std::vec; use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use datafusion_common::DFSchemaRef; use datafusion_expr::{ col, exists, logical_plan::builder::LogicalPlanBuilder, Expr, Extension, @@ -285,8 +287,20 @@ mod test { }; use datafusion_functions_aggregate::expr_fn::max; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(PushDownLimit::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[derive(Debug, PartialEq, Eq, Hash)] @@ -298,7 +312,10 @@ mod test { // Manual implementation needed because of `schema` field. Comparison excludes this field. impl PartialOrd for NoopPlan { fn partial_cmp(&self, other: &Self) -> Option { - self.input.partial_cmp(&other.input) + self.input + .partial_cmp(&other.input) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -351,7 +368,10 @@ mod test { // Manual implementation needed because of `schema` field. Comparison excludes this field. impl PartialOrd for NoLimitNoopPlan { fn partial_cmp(&self, other: &Self) -> Option { - self.input.partial_cmp(&other.input) + self.input + .partial_cmp(&other.input) + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -408,12 +428,15 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n NoopPlan\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + NoopPlan + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -430,12 +453,15 @@ mod test { .limit(10, Some(1000))? .build()?; - let expected = "Limit: skip=10, fetch=1000\ - \n NoopPlan\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + NoopPlan + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -453,12 +479,15 @@ mod test { .limit(20, Some(500))? .build()?; - let expected = "Limit: skip=30, fetch=500\ - \n NoopPlan\ - \n Limit: skip=0, fetch=530\ - \n TableScan: test, fetch=530"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=30, fetch=500 + NoopPlan + Limit: skip=0, fetch=530 + TableScan: test, fetch=530 + " + ) } #[test] @@ -475,14 +504,17 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n NoopPlan\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + NoopPlan + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -499,11 +531,14 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n NoLimitNoopPlan\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + NoLimitNoopPlan + TableScan: test + " + ) } #[test] @@ -517,11 +552,14 @@ mod test { // Should push the limit down to table provider // When it has a select - let expected = "Projection: test.a\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -536,10 +574,13 @@ mod test { // Should push down the smallest limit // Towards table scan // This rule doesn't replace multiple limits - let expected = "Limit: skip=0, fetch=10\ - \n TableScan: test, fetch=10"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=10 + TableScan: test, fetch=10 + " + ) } #[test] @@ -552,11 +593,14 @@ mod test { .build()?; // Limit should *not* push down aggregate node - let expected = "Limit: skip=0, fetch=1000\ - \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] + TableScan: test + " + ) } #[test] @@ -569,14 +613,17 @@ mod test { .build()?; // Limit should push down through union - let expected = "Limit: skip=0, fetch=1000\ - \n Union\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Union + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -589,11 +636,14 @@ mod test { .build()?; // Should push down limit to sort - let expected = "Limit: skip=0, fetch=10\ - \n Sort: test.a ASC NULLS LAST, fetch=10\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=10 + Sort: test.a ASC NULLS LAST, fetch=10 + TableScan: test + " + ) } #[test] @@ -606,11 +656,14 @@ mod test { .build()?; // Should push down limit to sort - let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a ASC NULLS LAST, fetch=15\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=5, fetch=10 + Sort: test.a ASC NULLS LAST, fetch=15 + TableScan: test + " + ) } #[test] @@ -624,12 +677,15 @@ mod test { .build()?; // Limit should use deeper LIMIT 1000, but Limit 10 shouldn't push down aggregation - let expected = "Limit: skip=0, fetch=10\ - \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=10 + Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -641,10 +697,13 @@ mod test { // Should not push any limit down to table provider // When it has a select - let expected = "Limit: skip=10, fetch=None\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=None + TableScan: test + " + ) } #[test] @@ -658,11 +717,14 @@ mod test { // Should push the limit down to table provider // When it has a select - let expected = "Projection: test.a\ - \n Limit: skip=10, fetch=1000\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=10, fetch=1000 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -675,11 +737,14 @@ mod test { .limit(10, None)? .build()?; - let expected = "Projection: test.a\ - \n Limit: skip=10, fetch=990\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=10, fetch=990 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -692,11 +757,14 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Projection: test.a\ - \n Limit: skip=10, fetch=1000\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=10, fetch=1000 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -709,10 +777,13 @@ mod test { .limit(0, Some(10))? .build()?; - let expected = "Limit: skip=10, fetch=10\ - \n TableScan: test, fetch=20"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=10 + TableScan: test, fetch=20 + " + ) } #[test] @@ -725,11 +796,14 @@ mod test { .build()?; // Limit should *not* push down aggregate node - let expected = "Limit: skip=10, fetch=1000\ - \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] + TableScan: test + " + ) } #[test] @@ -742,14 +816,17 @@ mod test { .build()?; // Limit should push down through union - let expected = "Limit: skip=10, fetch=1000\ - \n Union\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Union + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -768,12 +845,15 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=10, fetch=1000\ - \n Inner Join: test.a = test2.a\ - \n TableScan: test\ - \n TableScan: test2"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Inner Join: test.a = test2.a + TableScan: test + TableScan: test2 + " + ) } #[test] @@ -792,12 +872,15 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=10, fetch=1000\ - \n Inner Join: test.a = test2.a\ - \n TableScan: test\ - \n TableScan: test2"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Inner Join: test.a = test2.a + TableScan: test + TableScan: test2 + " + ) } #[test] @@ -817,16 +900,19 @@ mod test { .build()?; // Limit pushdown Not supported in sub_query - let expected = "Limit: skip=10, fetch=100\ - \n Filter: EXISTS ()\ - \n Subquery:\ - \n Filter: test1.a = test1.a\ - \n Projection: test1.a\ - \n TableScan: test1\ - \n Projection: test2.a\ - \n TableScan: test2"; - - assert_optimized_plan_equal(outer_query, expected) + assert_optimized_plan_equal!( + outer_query, + @r" + Limit: skip=10, fetch=100 + Filter: EXISTS () + Subquery: + Filter: test1.a = test1.a + Projection: test1.a + TableScan: test1 + Projection: test2.a + TableScan: test2 + " + ) } #[test] @@ -846,16 +932,19 @@ mod test { .build()?; // Limit pushdown Not supported in sub_query - let expected = "Limit: skip=10, fetch=100\ - \n Filter: EXISTS ()\ - \n Subquery:\ - \n Filter: test1.a = test1.a\ - \n Projection: test1.a\ - \n TableScan: test1\ - \n Projection: test2.a\ - \n TableScan: test2"; - - assert_optimized_plan_equal(outer_query, expected) + assert_optimized_plan_equal!( + outer_query, + @r" + Limit: skip=10, fetch=100 + Filter: EXISTS () + Subquery: + Filter: test1.a = test1.a + Projection: test1.a + TableScan: test1 + Projection: test2.a + TableScan: test2 + " + ) } #[test] @@ -874,13 +963,16 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=10, fetch=1000\ - \n Left Join: test.a = test2.a\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010\ - \n TableScan: test2"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Left Join: test.a = test2.a + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + TableScan: test2 + " + ) } #[test] @@ -899,13 +991,16 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=0, fetch=1000\ - \n Right Join: test.a = test2.a\ - \n TableScan: test\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test2, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Right Join: test.a = test2.a + TableScan: test + Limit: skip=0, fetch=1000 + TableScan: test2, fetch=1000 + " + ) } #[test] @@ -924,13 +1019,16 @@ mod test { .build()?; // Limit pushdown with offset supported in right outer join - let expected = "Limit: skip=10, fetch=1000\ - \n Right Join: test.a = test2.a\ - \n TableScan: test\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test2, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Right Join: test.a = test2.a + TableScan: test + Limit: skip=0, fetch=1010 + TableScan: test2, fetch=1010 + " + ) } #[test] @@ -943,14 +1041,17 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n Cross Join: \ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test2, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Cross Join: + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + Limit: skip=0, fetch=1000 + TableScan: test2, fetch=1000 + " + ) } #[test] @@ -963,14 +1064,17 @@ mod test { .limit(1000, Some(1000))? .build()?; - let expected = "Limit: skip=1000, fetch=1000\ - \n Cross Join: \ - \n Limit: skip=0, fetch=2000\ - \n TableScan: test, fetch=2000\ - \n Limit: skip=0, fetch=2000\ - \n TableScan: test2, fetch=2000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=1000, fetch=1000 + Cross Join: + Limit: skip=0, fetch=2000 + TableScan: test, fetch=2000 + Limit: skip=0, fetch=2000 + TableScan: test2, fetch=2000 + " + ) } #[test] @@ -982,10 +1086,13 @@ mod test { .limit(1000, None)? .build()?; - let expected = "Limit: skip=1000, fetch=0\ - \n TableScan: test, fetch=0"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=1000, fetch=0 + TableScan: test, fetch=0 + " + ) } #[test] @@ -997,10 +1104,13 @@ mod test { .limit(1000, None)? .build()?; - let expected = "Limit: skip=1000, fetch=0\ - \n TableScan: test, fetch=0"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=1000, fetch=0 + TableScan: test, fetch=0 + " + ) } #[test] @@ -1013,10 +1123,13 @@ mod test { .limit(1000, None)? .build()?; - let expected = "SubqueryAlias: a\ - \n Limit: skip=1000, fetch=0\ - \n TableScan: test, fetch=0"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + SubqueryAlias: a + Limit: skip=1000, fetch=0 + TableScan: test, fetch=0 + " + ) } } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 48b2828faf452..2383787fa0e8a 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -186,21 +186,29 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { mod tests { use std::sync::Arc; + use crate::assert_optimized_plan_eq_snapshot; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use crate::test::*; + use crate::OptimizerContext; use datafusion_common::Result; - use datafusion_expr::{ - col, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan, - }; + use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder, Expr}; use datafusion_functions_aggregate::sum::sum; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq( - Arc::new(ReplaceDistinctWithAggregate::new()), - plan.clone(), - expected, - ) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(ReplaceDistinctWithAggregate::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -212,8 +220,11 @@ mod tests { .distinct()? .build()?; - let expected = "Projection: test.c\n Aggregate: groupBy=[[test.c]], aggr=[[]]\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @r" + Projection: test.c + Aggregate: groupBy=[[test.c]], aggr=[[]] + TableScan: test + ") } #[test] @@ -225,9 +236,11 @@ mod tests { .distinct()? .build()?; - let expected = - "Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @r" + Projection: test.a, test.b + Aggregate: groupBy=[[test.a, test.b]], aggr=[[]] + TableScan: test + ") } #[test] @@ -238,8 +251,11 @@ mod tests { .distinct()? .build()?; - let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a, test.b]], aggr=[[]] + Projection: test.a, test.b + TableScan: test + ") } #[test] @@ -251,8 +267,11 @@ mod tests { .distinct()? .build()?; - let expected = - "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a, test.b]], aggr=[[]] + Projection: test.a, test.b + Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]] + TableScan: test + ") } } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 499447861a58b..48d1182527013 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -22,9 +22,10 @@ use std::sync::Arc; use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; use crate::optimizer::ApplyOrder; -use crate::utils::replace_qualified_name; +use crate::utils::{evaluates_to_null, replace_qualified_name}; use crate::{OptimizerConfig, OptimizerRule}; +use crate::analyzer::type_coercion::TypeCoercionRewriter; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, @@ -36,6 +37,8 @@ use datafusion_expr::utils::conjunction; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; /// Optimizer rule for rewriting subquery filters to joins +/// and places additional projection on top of the filter, to preserve +/// original schema. #[derive(Default, Debug)] pub struct ScalarSubqueryToJoin {} @@ -122,8 +125,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { return Ok(Transformed::no(LogicalPlan::Filter(filter))); } } + + // Preserve original schema as new Join might have more fields than what Filter & parents expect. + let projection = + filter.input.schema().columns().into_iter().map(Expr::from); let new_plan = LogicalPlanBuilder::from(cur_input) .filter(rewrite_expr)? + .project(projection)? .build()?; Ok(Transformed::yes(new_plan)) } @@ -334,7 +342,7 @@ fn build_join( .join_on( sub_query_alias, JoinType::Left, - vec![Expr::Literal(ScalarValue::Boolean(Some(true)))], + vec![Expr::Literal(ScalarValue::Boolean(Some(true)), None)], )? .build()? } @@ -348,6 +356,10 @@ fn build_join( let mut computation_project_expr = HashMap::new(); if let Some(expr_map) = collected_count_expr_map { for (name, result) in expr_map { + if evaluates_to_null(result.clone(), result.column_refs())? { + // If expr always returns null when column is null, skip processing + continue; + } let computer_expr = if let Some(filter) = &pull_up.pull_up_having_expr { Expr::Case(expr::Case { expr: None, @@ -360,7 +372,7 @@ fn build_join( ), ( Box::new(Expr::Not(Box::new(filter.clone()))), - Box::new(Expr::Literal(ScalarValue::Null)), + Box::new(Expr::Literal(ScalarValue::Null, None)), ), ], else_expr: Some(Box::new(Expr::Column(Column::new_unqualified( @@ -381,7 +393,11 @@ fn build_join( )))), }) }; - computation_project_expr.insert(name, computer_expr); + let mut expr_rewrite = TypeCoercionRewriter { + schema: new_plan.schema(), + }; + computation_project_expr + .insert(name, computer_expr.rewrite(&mut expr_rewrite).data()?); } } @@ -398,9 +414,24 @@ mod tests { use arrow::datatypes::DataType; use datafusion_expr::test::function_stub::sum; + use crate::assert_optimized_plan_eq_display_indent_snapshot; use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between}; use datafusion_functions_aggregate::min_max::{max, min}; + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(ScalarSubqueryToJoin::new()); + assert_optimized_plan_eq_display_indent_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; + } + /// Test multiple correlated subqueries #[test] fn multiple_subqueries() -> Result<()> { @@ -424,25 +455,25 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] + Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test recursive correlated subqueries @@ -479,26 +510,27 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64]\ - \n Projection: sum(orders.o_totalprice), orders.o_custkey [sum(orders.o_totalprice):Float64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, sum(orders.o_totalprice):Float64;N]\ - \n Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N]\ - \n Left Join: Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\ - \n Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\ - \n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, sum(lineitem.l_extendedprice):Float64;N]\ - \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] + Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean] + Projection: sum(orders.o_totalprice), orders.o_custkey, __always_true [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, sum(orders.o_totalprice):Float64;N] + Projection: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean] + Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[lineitem.l_orderkey, Boolean(true) AS __always_true]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, __always_true:Boolean, sum(lineitem.l_extendedprice):Float64;N] + TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64] + " + ) } /// Test for correlated scalar subquery filter with additional subquery filters @@ -521,22 +553,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ - \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated scalar subquery with no columns in schema @@ -559,20 +590,20 @@ mod tests { .build()?; // it will optimize, but fail for the same reason the unoptimized query would - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for scalar subquery with both columns in schema @@ -591,22 +622,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated scalar subquery not equal @@ -629,21 +659,19 @@ mod tests { .build()?; // Unsupported predicate, subquery should not be decorrelated - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ - \n Subquery: [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + " + ) } /// Test for correlated scalar subquery less than @@ -666,21 +694,19 @@ mod tests { .build()?; // Unsupported predicate, subquery should not be decorrelated - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ - \n Subquery: [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + " + ) } /// Test for correlated scalar subquery filter with subquery disjunction @@ -704,21 +730,19 @@ mod tests { .build()?; // Unsupported predicate, subquery should not be decorrelated - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ - \n Subquery: [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + " + ) } /// Test for correlated scalar without projection @@ -759,21 +783,61 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64]\ - \n Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) + } + + /// Test for correlated scalar subquery with non-strong project + #[test] + fn scalar_subquery_with_non_strong_project() -> Result<()> { + let case = Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(col("max(orders.o_totalprice)")), + Box::new(lit("a")), + )], + else_expr: Some(Box::new(lit("b"))), + }); + + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter( + out_ref_col(DataType::Int64, "customer.c_custkey") + .eq(col("orders.o_custkey")), + )? + .aggregate(Vec::::new(), vec![max(col("orders.o_totalprice"))])? + .project(vec![case])? + .build()?, ); - Ok(()) + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .project(vec![col("customer.c_custkey"), scalar_subquery(sq)])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r#" + Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8("a") ELSE Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean] + Projection: CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END, orders.o_custkey, __always_true [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_totalprice):Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + "# + ) } /// Test for correlated scalar subquery multiple projected columns @@ -823,21 +887,20 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] + Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -862,21 +925,20 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated scalar subquery filter with disjunctions @@ -902,21 +964,20 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated scalar subquery filter @@ -935,21 +996,20 @@ mod tests { .project(vec![col("test.c")])? .build()?; - let expected = "Projection: test.c [c:UInt32]\ - \n Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N]\ - \n Left Join: Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32]\ - \n Projection: min(sq.c), sq.a [min(sq.c):UInt32;N, a:UInt32]\ - \n Aggregate: groupBy=[[sq.a]], aggr=[[min(sq.c)]] [a:UInt32, min(sq.c):UInt32;N]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.c [c:UInt32] + Projection: test.a, test.b, test.c [a:UInt32, b:UInt32, c:UInt32] + Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N] + Left Join: Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean] + Projection: min(sq.c), sq.a, __always_true [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean] + Aggregate: groupBy=[[sq.a, Boolean(true) AS __always_true]], aggr=[[min(sq.c)]] [a:UInt32, __always_true:Boolean, min(sq.c):UInt32;N] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for non-correlated scalar subquery with no filters @@ -967,21 +1027,20 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] + Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -998,21 +1057,20 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -1050,26 +1108,25 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: min(orders.o_custkey), orders.o_custkey [min(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, min(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] + Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: min(orders.o_custkey), orders.o_custkey, __always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, min(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -1099,25 +1156,24 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N]\ - \n Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] + Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N] + Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 9003467703df2..c40906239073a 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -17,28 +17,30 @@ //! Expression simplification API -use std::borrow::Cow; -use std::collections::HashSet; -use std::ops::Not; - use arrow::{ array::{new_null_array, AsArray}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; +use std::borrow::Cow; +use std::collections::HashSet; +use std::ops::Not; +use std::sync::Arc; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; -use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::{ - and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, - WindowFunctionDefinition, + and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, + Operator, Volatility, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ - expr::{InList, InSubquery, WindowFunction}, + expr::{InList, InSubquery}, utils::{iter_conjunction, iter_conjunction_owned}, }; use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast}; @@ -46,6 +48,7 @@ use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionP use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; +use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::unwrap_cast::{ @@ -54,10 +57,8 @@ use crate::simplify_expressions::unwrap_cast::{ unwrap_cast_in_comparison_for_binary, }; use crate::simplify_expressions::SimplifyInfo; -use crate::{ - analyzer::type_coercion::TypeCoercionRewriter, - simplify_expressions::unwrap_cast::try_cast_literal_to_type, -}; +use datafusion_expr::expr::FieldMetadata; +use datafusion_expr_common::casts::try_cast_literal_to_type; use indexmap::IndexSet; use regex::Regex; @@ -188,7 +189,7 @@ impl ExprSimplifier { /// assert_eq!(expr, b_lt_2); /// ``` pub fn simplify(&self, expr: Expr) -> Result { - Ok(self.simplify_with_cycle_count(expr)?.0) + Ok(self.simplify_with_cycle_count_transformed(expr)?.0.data) } /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating @@ -198,7 +199,34 @@ impl ExprSimplifier { /// /// See [Self::simplify] for details and usage examples. /// + #[deprecated( + since = "48.0.0", + note = "Use `simplify_with_cycle_count_transformed` instead" + )] + #[allow(unused_mut)] pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> { + let (transformed, cycle_count) = + self.simplify_with_cycle_count_transformed(expr)?; + Ok((transformed.data, cycle_count)) + } + + /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating + /// constants and applying algebraic simplifications. Additionally returns a `u32` + /// representing the number of simplification cycles performed, which can be useful for testing + /// optimizations. + /// + /// # Returns + /// + /// A tuple containing: + /// - The simplified expression wrapped in a `Transformed` indicating if changes were made + /// - The number of simplification cycles that were performed + /// + /// See [Self::simplify] for details and usage examples. + /// + pub fn simplify_with_cycle_count_transformed( + &self, + mut expr: Expr, + ) -> Result<(Transformed, u32)> { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); @@ -212,6 +240,7 @@ impl ExprSimplifier { // simplifications can enable new constant evaluation // see `Self::with_max_cycles` let mut num_cycles = 0; + let mut has_transformed = false; loop { let Transformed { data, transformed, .. @@ -221,13 +250,18 @@ impl ExprSimplifier { .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?; expr = data; num_cycles += 1; + // Track if any transformation occurred + has_transformed = has_transformed || transformed; if !transformed || num_cycles >= self.max_simplifier_cycles { break; } } // shorten inlist should be started after other inlist rules are applied expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?; - Ok((expr, num_cycles)) + Ok(( + Transformed::new_transformed(expr, has_transformed), + num_cycles, + )) } /// Apply type coercion to an [`Expr`] so that it can be @@ -392,15 +426,15 @@ impl ExprSimplifier { /// let expr = col("a").is_not_null(); /// /// // When using default maximum cycles, 2 cycles will be performed. - /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count(expr.clone()).unwrap(); - /// assert_eq!(simplified_expr, lit(true)); + /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count_transformed(expr.clone()).unwrap(); + /// assert_eq!(simplified_expr.data, lit(true)); /// // 2 cycles were executed, but only 1 was needed /// assert_eq!(count, 2); /// /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1. - /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap(); + /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count_transformed(expr.clone()).unwrap(); /// // Expression has been rewritten to: (c = a AND b = 1) - /// assert_eq!(simplified_expr, lit(true)); + /// assert_eq!(simplified_expr.data, lit(true)); /// // Only 1 cycle was executed /// assert_eq!(count, 1); /// @@ -444,7 +478,7 @@ impl TreeNodeRewriter for Canonicalizer { }))) } // - (Expr::Literal(_a), Expr::Column(_b), Some(swapped_op)) => { + (Expr::Literal(_a, _), Expr::Column(_b), Some(swapped_op)) => { Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, @@ -489,9 +523,9 @@ struct ConstEvaluator<'a> { /// The simplify result of ConstEvaluator enum ConstSimplifyResult { // Expr was simplified and contains the new expression - Simplified(ScalarValue), + Simplified(ScalarValue, Option), // Expr was not simplified and original value is returned - NotSimplified(ScalarValue), + NotSimplified(ScalarValue, Option), // Evaluation encountered an error, contains the original expression SimplifyRuntimeError(DataFusionError, Expr), } @@ -533,11 +567,11 @@ impl TreeNodeRewriter for ConstEvaluator<'_> { // any error is countered during simplification, return the original // so that normal evaluation can occur Some(true) => match self.evaluate_to_scalar(expr) { - ConstSimplifyResult::Simplified(s) => { - Ok(Transformed::yes(Expr::Literal(s))) + ConstSimplifyResult::Simplified(s, m) => { + Ok(Transformed::yes(Expr::Literal(s, m))) } - ConstSimplifyResult::NotSimplified(s) => { - Ok(Transformed::no(Expr::Literal(s))) + ConstSimplifyResult::NotSimplified(s, m) => { + Ok(Transformed::no(Expr::Literal(s, m))) } ConstSimplifyResult::SimplifyRuntimeError(_, expr) => { Ok(Transformed::yes(expr)) @@ -557,11 +591,15 @@ impl<'a> ConstEvaluator<'a> { // The dummy column name is unused and doesn't matter as only // expressions without column references can be evaluated static DUMMY_COL_NAME: &str = "."; - let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]); - let input_schema = DFSchema::try_from(schema.clone())?; + let schema = Arc::new(Schema::new(vec![Field::new( + DUMMY_COL_NAME, + DataType::Null, + true, + )])); + let input_schema = DFSchema::try_from(Arc::clone(&schema))?; // Need a single "input" row to produce a single output row let col = new_null_array(&DataType::Null, 1); - let input_batch = RecordBatch::try_new(std::sync::Arc::new(schema), vec![col])?; + let input_batch = RecordBatch::try_new(schema, vec![col])?; Ok(Self { can_evaluate: vec![], @@ -606,7 +644,7 @@ impl<'a> ConstEvaluator<'a> { Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } - Expr::Literal(_) + Expr::Literal(_, _) | Expr::Alias(..) | Expr::Unnest(_) | Expr::BinaryExpr { .. } @@ -632,8 +670,8 @@ impl<'a> ConstEvaluator<'a> { /// Internal helper to evaluates an Expr pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult { - if let Expr::Literal(s) = expr { - return ConstSimplifyResult::NotSimplified(s); + if let Expr::Literal(s, m) = expr { + return ConstSimplifyResult::NotSimplified(s, m); } let phys_expr = @@ -641,6 +679,16 @@ impl<'a> ConstEvaluator<'a> { Ok(e) => e, Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), }; + let metadata = phys_expr + .return_field(self.input_batch.schema_ref()) + .ok() + .and_then(|f| { + let m = f.metadata(); + match m.is_empty() { + true => None, + false => Some(FieldMetadata::from(m)), + } + }); let col_val = match phys_expr.evaluate(&self.input_batch) { Ok(v) => v, Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), @@ -649,17 +697,19 @@ impl<'a> ConstEvaluator<'a> { ColumnarValue::Array(a) => { if a.len() != 1 { ConstSimplifyResult::SimplifyRuntimeError( - DataFusionError::Execution(format!("Could not evaluate the expression, found a result of length {}", a.len())), + exec_datafusion_err!("Could not evaluate the expression, found a result of length {}", a.len()), expr, ) } else if as_list_array(&a).is_ok() { - ConstSimplifyResult::Simplified(ScalarValue::List( - a.as_list::().to_owned().into(), - )) + ConstSimplifyResult::Simplified( + ScalarValue::List(a.as_list::().to_owned().into()), + metadata, + ) } else if as_large_list_array(&a).is_ok() { - ConstSimplifyResult::Simplified(ScalarValue::LargeList( - a.as_list::().to_owned().into(), - )) + ConstSimplifyResult::Simplified( + ScalarValue::LargeList(a.as_list::().to_owned().into()), + metadata, + ) } else { // Non-ListArray match ScalarValue::try_from_array(&a, 0) { @@ -671,7 +721,7 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else { - ConstSimplifyResult::Simplified(s) + ConstSimplifyResult::Simplified(s, metadata) } } Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr), @@ -689,7 +739,7 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else { - ConstSimplifyResult::Simplified(s) + ConstSimplifyResult::Simplified(s, metadata) } } } @@ -728,6 +778,29 @@ impl TreeNodeRewriter for Simplifier<'_, S> { let info = self.info; Ok(match expr { + // `value op NULL` -> `NULL` + // `NULL op value` -> `NULL` + // except for few operators that can return non-null value even when one of the operands is NULL + ref expr @ Expr::BinaryExpr(BinaryExpr { + ref left, + ref op, + ref right, + }) if op.returns_null_on_null() + && (is_null(left.as_ref()) || is_null(right.as_ref())) => + { + Transformed::yes(Expr::Literal( + ScalarValue::try_new_null(&info.get_data_type(expr)?)?, + None, + )) + } + + // `NULL {AND, OR} NULL` -> `NULL` + Expr::BinaryExpr(BinaryExpr { + left, + op: And | Or, + right, + }) if is_null(&left) && is_null(&right) => Transformed::yes(lit_bool_null()), + // // Rules for Eq // @@ -760,6 +833,25 @@ impl TreeNodeRewriter for Simplifier<'_, S> { None => lit_bool_null(), }) } + // According to SQL's null semantics, NULL = NULL evaluates to NULL + // Both sides are the same expression (A = A) and A is non-volatile expression + // A = A --> A IS NOT NULL OR NULL + // A = A --> true (if A not nullable) + Expr::BinaryExpr(BinaryExpr { + left, + op: Eq, + right, + }) if (left == right) & !left.is_volatile() => { + Transformed::yes(match !info.nullable(&left)? { + true => lit(true), + false => Expr::BinaryExpr(BinaryExpr { + left: Box::new(Expr::IsNotNull(left)), + op: Or, + right: Box::new(lit_bool_null()), + }), + }) + } + // Rules for NotEq // @@ -976,30 +1068,23 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // Rules for Multiply // - // A * 1 --> A + // A * 1 --> A (with type coercion if needed) Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right, - }) if is_one(&right) => Transformed::yes(*left), + }) if is_one(&right) => { + simplify_right_is_one_case(info, left, &Multiply, &right)? + } // 1 * A --> A Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right, - }) if is_one(&left) => Transformed::yes(*right), - // A * null --> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: Multiply, - right, - }) if is_null(&right) => Transformed::yes(*right), - // null * A --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: Multiply, - right: _, - }) if is_null(&left) => Transformed::yes(*left), + }) if is_one(&left) => { + // 1 * A is equivalent to A * 1 + simplify_right_is_one_case(info, right, &Multiply, &left)? + } // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN) Expr::BinaryExpr(BinaryExpr { @@ -1033,36 +1118,14 @@ impl TreeNodeRewriter for Simplifier<'_, S> { left, op: Divide, right, - }) if is_one(&right) => Transformed::yes(*left), - // null / A --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: Divide, - right: _, - }) if is_null(&left) => Transformed::yes(*left), - // A / null --> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: Divide, - right, - }) if is_null(&right) => Transformed::yes(*right), + }) if is_one(&right) => { + simplify_right_is_one_case(info, left, &Divide, &right)? + } // // Rules for Modulo // - // A % null --> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: Modulo, - right, - }) if is_null(&right) => Transformed::yes(*right), - // null % A --> null - Expr::BinaryExpr(BinaryExpr { - left, - op: Modulo, - right: _, - }) if is_null(&left) => Transformed::yes(*left), // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN) Expr::BinaryExpr(BinaryExpr { left, @@ -1072,29 +1135,16 @@ impl TreeNodeRewriter for Simplifier<'_, S> { && !info.get_data_type(&left)?.is_floating() && is_one(&right) => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + )) } // // Rules for BitwiseAnd // - // A & null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseAnd, - right, - }) if is_null(&right) => Transformed::yes(*right), - - // null & A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseAnd, - right: _, - }) if is_null(&left) => Transformed::yes(*left), - // A & 0 -> 0 (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, @@ -1115,9 +1165,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseAnd, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + )) } // A & !A -> 0 (if A not nullable) @@ -1126,9 +1177,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseAnd, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + )) } // (..A..) & A --> (..A..) @@ -1167,20 +1219,6 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // Rules for BitwiseOr // - // A | null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseOr, - right, - }) if is_null(&right) => Transformed::yes(*right), - - // null | A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseOr, - right: _, - }) if is_null(&left) => Transformed::yes(*left), - // A | 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, @@ -1201,9 +1239,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseOr, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // A | !A -> -1 (if A not nullable) @@ -1212,9 +1251,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseOr, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // (..A..) | A --> (..A..) @@ -1253,20 +1293,6 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // Rules for BitwiseXor // - // A ^ null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseXor, - right, - }) if is_null(&right) => Transformed::yes(*right), - - // null ^ A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseXor, - right: _, - }) if is_null(&left) => Transformed::yes(*left), - // A ^ 0 -> A (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, @@ -1287,9 +1313,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseXor, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // A ^ !A -> -1 (if A not nullable) @@ -1298,9 +1325,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseXor, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A) @@ -1311,7 +1339,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { }) if expr_contains(&left, &right, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&left, &right, false); Transformed::yes(if expr == *right { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) + Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&right)?)?, + None, + ) } else { expr }) @@ -1325,7 +1356,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { }) if expr_contains(&right, &left, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&right, &left, true); Transformed::yes(if expr == *left { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + ) } else { expr }) @@ -1335,20 +1369,6 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // Rules for BitwiseShiftRight // - // A >> null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseShiftRight, - right, - }) if is_null(&right) => Transformed::yes(*right), - - // null >> A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseShiftRight, - right: _, - }) if is_null(&left) => Transformed::yes(*left), - // A >> 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, @@ -1360,20 +1380,6 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // Rules for BitwiseShiftRight // - // A << null -> null - Expr::BinaryExpr(BinaryExpr { - left: _, - op: BitwiseShiftLeft, - right, - }) if is_null(&right) => Transformed::yes(*right), - - // null << A -> null - Expr::BinaryExpr(BinaryExpr { - left, - op: BitwiseShiftLeft, - right: _, - }) if is_null(&left) => Transformed::yes(*left), - // A << 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, @@ -1395,6 +1401,89 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // Rules for Case // + // Inline a comparison to a literal with the case statement into the `THEN` clauses. + // which can enable further simplifications + // CASE WHEN X THEN "a" WHEN Y THEN "b" ... END = "a" --> CASE WHEN X THEN "a" = "a" WHEN Y THEN "b" = "a" END + Expr::BinaryExpr(BinaryExpr { + left, + op: op @ (Eq | NotEq), + right, + }) if is_case_with_literal_outputs(&left) && is_lit(&right) => { + let case = into_case(*left)?; + Transformed::yes(Expr::Case(Case { + expr: None, + when_then_expr: case + .when_then_expr + .into_iter() + .map(|(when, then)| { + ( + when, + Box::new(Expr::BinaryExpr(BinaryExpr { + left: then, + op, + right: right.clone(), + })), + ) + }) + .collect(), + else_expr: case.else_expr.map(|els| { + Box::new(Expr::BinaryExpr(BinaryExpr { + left: els, + op, + right, + })) + }), + })) + } + + // CASE WHEN true THEN A ... END --> A + // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END + // CASE WHEN false THEN A END --> NULL + // CASE WHEN false THEN A ELSE B END --> B + // CASE WHEN X THEN A WHEN false THEN B END --> CASE WHEN X THEN A ELSE B END + Expr::Case(Case { + expr: None, + when_then_expr, + mut else_expr, + }) if when_then_expr + .iter() + .any(|(when, _)| is_true(when.as_ref()) || is_false(when.as_ref())) => + { + let out_type = info.get_data_type(&when_then_expr[0].1)?; + let mut new_when_then_expr = Vec::with_capacity(when_then_expr.len()); + + for (when, then) in when_then_expr.into_iter() { + if is_true(when.as_ref()) { + // Skip adding the rest of the when-then expressions after WHEN true + // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END + else_expr = Some(then); + break; + } else if !is_false(when.as_ref()) { + new_when_then_expr.push((when, then)); + } + // else: skip WHEN false cases + } + + // Exclude CASE statement altogether if there are no when-then expressions left + if new_when_then_expr.is_empty() { + // CASE WHEN false THEN A ELSE B END --> B + if let Some(else_expr) = else_expr { + return Ok(Transformed::yes(*else_expr)); + // CASE WHEN false THEN A END --> NULL + } else { + let null = + Expr::Literal(ScalarValue::try_new_null(&out_type)?, None); + return Ok(Transformed::yes(null)); + } + } + + Transformed::yes(Expr::Case(Case { + expr: None, + when_then_expr: new_when_then_expr, + else_expr, + })) + } + // CASE // WHEN X THEN A // WHEN Y THEN B @@ -1411,7 +1500,11 @@ impl TreeNodeRewriter for Simplifier<'_, S> { when_then_expr, else_expr, }) if !when_then_expr.is_empty() - && when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number + // The rewrite is O(n²) in general so limit to small number of when-thens that can be true + && (when_then_expr.len() < 3 // small number of input whens + // or all thens are literal bools and a small number of them are true + || (when_then_expr.iter().all(|(_, then)| is_bool_lit(then)) + && when_then_expr.iter().filter(|(_, then)| is_true(then)).count() < 3)) && info.is_boolean_type(&when_then_expr[0].1)? => { // String disjunction of all the when predicates encountered so far. Not nullable. @@ -1435,6 +1528,55 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // Do a first pass at simplification out_expr.rewrite(self)? } + // CASE + // WHEN X THEN true + // WHEN Y THEN true + // WHEN Z THEN false + // ... + // ELSE true + // END + // + // ---> + // + // NOT(CASE + // WHEN X THEN false + // WHEN Y THEN false + // WHEN Z THEN true + // ... + // ELSE false + // END) + // + // Note: the rationale for this rewrite is that the case can then be further + // simplified into a small number of ANDs and ORs + Expr::Case(Case { + expr: None, + when_then_expr, + else_expr, + }) if !when_then_expr.is_empty() + && when_then_expr + .iter() + .all(|(_, then)| is_bool_lit(then)) // all thens are literal bools + // This simplification is only helpful if we end up with a small number of true thens + && when_then_expr + .iter() + .filter(|(_, then)| is_false(then)) + .count() + < 3 + && else_expr.as_deref().is_none_or(is_bool_lit) => + { + Transformed::yes( + Expr::Case(Case { + expr: None, + when_then_expr: when_then_expr + .into_iter() + .map(|(when, then)| (when, Box::new(Expr::Not(then)))) + .collect(), + else_expr: else_expr + .map(|else_expr| Box::new(Expr::Not(else_expr))), + }) + .not(), + ) + } Expr::ScalarFunction(ScalarFunction { func: udf, args }) => { match udf.simplify(args, info)? { ExprSimplifyResult::Original(args) => { @@ -1457,12 +1599,9 @@ impl TreeNodeRewriter for Simplifier<'_, S> { (_, expr) => Transformed::no(expr), }, - Expr::WindowFunction(WindowFunction { - fun: WindowFunctionDefinition::WindowUDF(ref udwf), - .. - }) => match (udwf.simplify(), expr) { + Expr::WindowFunction(ref window_fun) => match (window_fun.simplify(), expr) { (Some(simplify_function), Expr::WindowFunction(wf)) => { - Transformed::yes(simplify_function(wf, info)?) + Transformed::yes(simplify_function(*wf, info)?) } (_, expr) => Transformed::no(expr), }, @@ -1541,8 +1680,9 @@ impl TreeNodeRewriter for Simplifier<'_, S> { })) } Some(pattern_str) - if !pattern_str - .contains(['%', '_', escape_char].as_ref()) => + if !like.case_insensitive + && !pattern_str + .contains(['%', '_', escape_char].as_ref()) => { // If the pattern does not contain any wildcards, we can simplify the like expression to an equality expression // TODO: handle escape characters @@ -1575,20 +1715,20 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // expr IN () --> false // expr NOT IN () --> true Expr::InList(InList { - expr, + expr: _, list, negated, - }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => { - Transformed::yes(lit(negated)) - } + }) if list.is_empty() => Transformed::yes(lit(negated)), // null in (x, y, z) --> null // null not in (x, y, z) --> null Expr::InList(InList { expr, - list: _, + list, negated: _, - }) if is_null(expr.as_ref()) => Transformed::yes(lit_bool_null()), + }) if is_null(expr.as_ref()) && !list.is_empty() => { + Transformed::yes(lit_bool_null()) + } // expr IN ((subquery)) -> expr IN (subquery), see ##5529 Expr::InList(InList { @@ -1761,7 +1901,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { info, &left, op, &right, ) && op.supports_propagation() => { - unwrap_cast_in_comparison_for_binary(info, left, right, op)? + unwrap_cast_in_comparison_for_binary(info, *left, *right, op)? } // literal op try_cast/cast(expr as data_type) // --> @@ -1774,8 +1914,8 @@ impl TreeNodeRewriter for Simplifier<'_, S> { { unwrap_cast_in_comparison_for_binary( info, - right, - left, + *right, + *left, op.swap().unwrap(), )? } @@ -1804,12 +1944,12 @@ impl TreeNodeRewriter for Simplifier<'_, S> { .into_iter() .map(|right| { match right { - Expr::Literal(right_lit_value) => { + Expr::Literal(right_lit_value, _) => { // if the right_lit_value can be casted to the type of internal_left_expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else { internal_err!( - "Can't cast the list expr {:?} to type {:?}", + "Can't cast the list expr {:?} to type {}", right_lit_value, &expr_type )? }; @@ -1838,18 +1978,18 @@ impl TreeNodeRewriter for Simplifier<'_, S> { fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option)> { match expr { - Expr::Literal(ScalarValue::Utf8(s)) => Some((DataType::Utf8, s)), - Expr::Literal(ScalarValue::LargeUtf8(s)) => Some((DataType::LargeUtf8, s)), - Expr::Literal(ScalarValue::Utf8View(s)) => Some((DataType::Utf8View, s)), + Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)), + Expr::Literal(ScalarValue::LargeUtf8(s), _) => Some((DataType::LargeUtf8, s)), + Expr::Literal(ScalarValue::Utf8View(s), _) => Some((DataType::Utf8View, s)), _ => None, } } fn to_string_scalar(data_type: DataType, value: Option) -> Expr { match data_type { - DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value)), - DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value)), - DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value)), + DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value), None), + DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value), None), + DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value), None), _ => unreachable!(), } } @@ -1890,17 +2030,17 @@ fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool { } /// Try to convert an expression to an in-list expression -fn as_inlist(expr: &Expr) -> Option> { +fn as_inlist(expr: &'_ Expr) -> Option> { match expr { Expr::InList(inlist) => Some(Cow::Borrowed(inlist)), Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_)) => Some(Cow::Owned(InList { + (Expr::Column(_), Expr::Literal(_, _)) => Some(Cow::Owned(InList { expr: left.clone(), list: vec![*right.clone()], negated: false, })), - (Expr::Literal(_), Expr::Column(_)) => Some(Cow::Owned(InList { + (Expr::Literal(_, _), Expr::Column(_)) => Some(Cow::Owned(InList { expr: right.clone(), list: vec![*left.clone()], negated: false, @@ -1920,12 +2060,12 @@ fn to_inlist(expr: Expr) -> Option { op: Operator::Eq, right, }) => match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_)) => Some(InList { + (Expr::Column(_), Expr::Literal(_, _)) => Some(InList { expr: left, list: vec![*right], negated: false, }), - (Expr::Literal(_), Expr::Column(_)) => Some(InList { + (Expr::Literal(_, _), Expr::Column(_)) => Some(InList { expr: right, list: vec![*left], negated: false, @@ -1997,12 +2137,41 @@ fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result { } } +// A * 1 -> A +// A / 1 -> A +// +// Move this function body out of the large match branch avoid stack overflow +fn simplify_right_is_one_case( + info: &S, + left: Box, + op: &Operator, + right: &Expr, +) -> Result> { + // Check if resulting type would be different due to coercion + let left_type = info.get_data_type(&left)?; + let right_type = info.get_data_type(right)?; + match BinaryTypeCoercer::new(&left_type, op, &right_type).get_result_type() { + Ok(result_type) => { + // Only cast if the types differ + if left_type != result_type { + Ok(Transformed::yes(Expr::Cast(Cast::new(left, result_type)))) + } else { + Ok(Transformed::yes(*left)) + } + } + Err(_) => Ok(Transformed::yes(*left)), + } +} + #[cfg(test)] mod tests { + use super::*; use crate::simplify_expressions::SimplifyContext; use crate::test::test_table_scan_with_name; + use arrow::datatypes::FieldRef; use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ + expr::WindowFunction, function::{ AccumulatorArgs, AggregateFunctionSimplification, WindowFunctionSimplification, @@ -2012,14 +2181,15 @@ mod tests { }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; + use datafusion_physical_expr::PhysicalExpr; + use std::hash::Hash; + use std::sync::LazyLock; use std::{ collections::HashMap, ops::{BitAnd, BitOr, BitXor}, sync::Arc, }; - use super::*; - // ------------------------------ // --- ExprSimplifier tests ----- // ------------------------------ @@ -2054,12 +2224,15 @@ mod tests { } fn test_schema() -> DFSchemaRef { - Schema::new(vec![ - Field::new("i", DataType::Int64, false), - Field::new("b", DataType::Boolean, true), - ]) - .to_dfschema_ref() - .unwrap() + static TEST_SCHEMA: LazyLock = LazyLock::new(|| { + Schema::new(vec![ + Field::new("i", DataType::Int64, false), + Field::new("b", DataType::Boolean, true), + ]) + .to_dfschema_ref() + .unwrap() + }); + Arc::clone(&TEST_SCHEMA) } #[test] @@ -2152,6 +2325,21 @@ mod tests { } } + #[test] + fn test_simplify_eq_not_self() { + // `expr_a`: column `c2` is nullable, so `c2 = c2` simplifies to `c2 IS NOT NULL OR NULL` + // This ensures the expression is only true when `c2` is not NULL, accounting for SQL's NULL semantics. + let expr_a = col("c2").eq(col("c2")); + let expected_a = col("c2").is_not_null().or(lit_bool_null()); + + // `expr_b`: column `c2_non_null` is explicitly non-nullable, so `c2_non_null = c2_non_null` is always true + let expr_b = col("c2_non_null").eq(col("c2_non_null")); + let expected_b = lit(true); + + assert_eq!(simplify(expr_a), expected_a); + assert_eq!(simplify(expr_b), expected_b); + } + #[test] fn test_simplify_or_true() { let expr_a = col("c2").or(lit(true)); @@ -2250,15 +2438,15 @@ mod tests { #[test] fn test_simplify_multiply_by_null() { - let null = Expr::Literal(ScalarValue::Null); + let null = lit(ScalarValue::Int64(None)); // A * null --> null { - let expr = col("c2") * null.clone(); + let expr = col("c3") * null.clone(); assert_eq!(simplify(expr), null); } // null * A --> null { - let expr = null.clone() * col("c2"); + let expr = null.clone() * col("c3"); assert_eq!(simplify(expr), null); } } @@ -2314,14 +2502,14 @@ mod tests { #[test] fn test_simplify_divide_null() { // A / null --> null - let null = lit(ScalarValue::Null); + let null = lit(ScalarValue::Int64(None)); { - let expr = col("c") / null.clone(); + let expr = col("c3") / null.clone(); assert_eq!(simplify(expr), null); } // null / A --> null { - let expr = null.clone() / col("c"); + let expr = null.clone() / col("c3"); assert_eq!(simplify(expr), null); } } @@ -2337,15 +2525,15 @@ mod tests { #[test] fn test_simplify_modulo_by_null() { - let null = lit(ScalarValue::Null); + let null = lit(ScalarValue::Int64(None)); // A % null --> null { - let expr = col("c2") % null.clone(); + let expr = col("c3") % null.clone(); assert_eq!(simplify(expr), null); } // null % A --> null { - let expr = null.clone() % col("c2"); + let expr = null.clone() % col("c3"); assert_eq!(simplify(expr), null); } } @@ -2391,45 +2579,45 @@ mod tests { #[test] fn test_simplify_bitwise_xor_by_null() { - let null = lit(ScalarValue::Null); + let null = lit(ScalarValue::Int64(None)); // A ^ null --> null { - let expr = col("c2") ^ null.clone(); + let expr = col("c3") ^ null.clone(); assert_eq!(simplify(expr), null); } // null ^ A --> null { - let expr = null.clone() ^ col("c2"); + let expr = null.clone() ^ col("c3"); assert_eq!(simplify(expr), null); } } #[test] fn test_simplify_bitwise_shift_right_by_null() { - let null = lit(ScalarValue::Null); + let null = lit(ScalarValue::Int64(None)); // A >> null --> null { - let expr = col("c2") >> null.clone(); + let expr = col("c3") >> null.clone(); assert_eq!(simplify(expr), null); } // null >> A --> null { - let expr = null.clone() >> col("c2"); + let expr = null.clone() >> col("c3"); assert_eq!(simplify(expr), null); } } #[test] fn test_simplify_bitwise_shift_left_by_null() { - let null = lit(ScalarValue::Null); + let null = lit(ScalarValue::Int64(None)); // A << null --> null { - let expr = col("c2") << null.clone(); + let expr = col("c3") << null.clone(); assert_eq!(simplify(expr), null); } // null << A --> null { - let expr = null.clone() << col("c2"); + let expr = null.clone() << col("c3"); assert_eq!(simplify(expr), null); } } @@ -2496,15 +2684,15 @@ mod tests { #[test] fn test_simplify_bitwise_and_by_null() { - let null = lit(ScalarValue::Null); + let null = Expr::Literal(ScalarValue::Int64(None), None); // A & null --> null { - let expr = col("c2") & null.clone(); + let expr = col("c3") & null.clone(); assert_eq!(simplify(expr), null); } // null & A --> null { - let expr = null.clone() & col("c2"); + let expr = null.clone() & col("c3"); assert_eq!(simplify(expr), null); } } @@ -3185,6 +3373,15 @@ mod tests { simplifier.simplify(expr) } + fn coerce(expr: Expr) -> Expr { + let schema = expr_test_schema(); + let execution_props = ExecutionProps::new(); + let simplifier = ExprSimplifier::new( + SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)), + ); + simplifier.coerce(expr, schema.as_ref()).unwrap() + } + fn simplify(expr: Expr) -> Expr { try_simplify(expr).unwrap() } @@ -3195,7 +3392,8 @@ mod tests { let simplifier = ExprSimplifier::new( SimplifyContext::new(&execution_props).with_schema(schema), ); - simplifier.simplify_with_cycle_count(expr) + let (expr, count) = simplifier.simplify_with_cycle_count_transformed(expr)?; + Ok((expr.data, count)) } fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) { @@ -3216,23 +3414,27 @@ mod tests { } fn expr_test_schema() -> DFSchemaRef { - Arc::new( - DFSchema::from_unqualified_fields( - vec![ - Field::new("c1", DataType::Utf8, true), - Field::new("c2", DataType::Boolean, true), - Field::new("c3", DataType::Int64, true), - Field::new("c4", DataType::UInt32, true), - Field::new("c1_non_null", DataType::Utf8, false), - Field::new("c2_non_null", DataType::Boolean, false), - Field::new("c3_non_null", DataType::Int64, false), - Field::new("c4_non_null", DataType::UInt32, false), - ] - .into(), - HashMap::new(), + static EXPR_TEST_SCHEMA: LazyLock = LazyLock::new(|| { + Arc::new( + DFSchema::from_unqualified_fields( + vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Boolean, true), + Field::new("c3", DataType::Int64, true), + Field::new("c4", DataType::UInt32, true), + Field::new("c1_non_null", DataType::Utf8, false), + Field::new("c2_non_null", DataType::Boolean, false), + Field::new("c3_non_null", DataType::Int64, false), + Field::new("c4_non_null", DataType::UInt32, false), + Field::new("c5", DataType::FixedSizeBinary(3), true), + ] + .into(), + HashMap::new(), + ) + .unwrap(), ) - .unwrap(), - ) + }); + Arc::clone(&EXPR_TEST_SCHEMA) } #[test] @@ -3377,6 +3579,142 @@ mod tests { ); } + #[test] + fn simplify_literal_case_equality() { + // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" + let simple_case = Expr::Case(Case::new( + None, + vec![( + Box::new(col("c2_non_null").not_eq(lit(false))), + Box::new(lit("ok")), + )], + Some(Box::new(lit("not_ok"))), + )); + + // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" == "ok" + // --> + // CASE WHEN c2 != false THEN "ok" == "ok" ELSE "not_ok" == "ok" + // --> + // CASE WHEN c2 != false THEN true ELSE false + // --> + // c2 + assert_eq!( + simplify(binary_expr(simple_case.clone(), Operator::Eq, lit("ok"),)), + col("c2_non_null"), + ); + + // CASE WHEN c2 != false THEN "ok" ELSE "not_ok" != "ok" + // --> + // NOT(CASE WHEN c2 != false THEN "ok" == "ok" ELSE "not_ok" == "ok") + // --> + // NOT(CASE WHEN c2 != false THEN true ELSE false) + // --> + // NOT(c2) + assert_eq!( + simplify(binary_expr(simple_case, Operator::NotEq, lit("ok"),)), + not(col("c2_non_null")), + ); + + let complex_case = Expr::Case(Case::new( + None, + vec![ + ( + Box::new(col("c1").eq(lit("inboxed"))), + Box::new(lit("pending")), + ), + ( + Box::new(col("c1").eq(lit("scheduled"))), + Box::new(lit("pending")), + ), + ( + Box::new(col("c1").eq(lit("completed"))), + Box::new(lit("completed")), + ), + ( + Box::new(col("c1").eq(lit("paused"))), + Box::new(lit("paused")), + ), + (Box::new(col("c2")), Box::new(lit("running"))), + ( + Box::new(col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0)))), + Box::new(lit("backing-off")), + ), + ], + Some(Box::new(lit("ready"))), + )); + + assert_eq!( + simplify(binary_expr( + complex_case.clone(), + Operator::Eq, + lit("completed"), + )), + not_distinct_from(col("c1").eq(lit("completed")), lit(true)).and( + distinct_from(col("c1").eq(lit("inboxed")), lit(true)) + .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true))) + ) + ); + + assert_eq!( + simplify(binary_expr( + complex_case.clone(), + Operator::NotEq, + lit("completed"), + )), + distinct_from(col("c1").eq(lit("completed")), lit(true)) + .or(not_distinct_from(col("c1").eq(lit("inboxed")), lit(true)) + .or(not_distinct_from(col("c1").eq(lit("scheduled")), lit(true)))) + ); + + assert_eq!( + simplify(binary_expr( + complex_case.clone(), + Operator::Eq, + lit("running"), + )), + not_distinct_from(col("c2"), lit(true)).and( + distinct_from(col("c1").eq(lit("inboxed")), lit(true)) + .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true))) + .and(distinct_from(col("c1").eq(lit("completed")), lit(true))) + .and(distinct_from(col("c1").eq(lit("paused")), lit(true))) + ) + ); + + assert_eq!( + simplify(binary_expr( + complex_case.clone(), + Operator::Eq, + lit("ready"), + )), + distinct_from(col("c1").eq(lit("inboxed")), lit(true)) + .and(distinct_from(col("c1").eq(lit("scheduled")), lit(true))) + .and(distinct_from(col("c1").eq(lit("completed")), lit(true))) + .and(distinct_from(col("c1").eq(lit("paused")), lit(true))) + .and(distinct_from(col("c2"), lit(true))) + .and(distinct_from( + col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0))), + lit(true) + )) + ); + + assert_eq!( + simplify(binary_expr( + complex_case.clone(), + Operator::NotEq, + lit("ready"), + )), + not_distinct_from(col("c1").eq(lit("inboxed")), lit(true)) + .or(not_distinct_from(col("c1").eq(lit("scheduled")), lit(true))) + .or(not_distinct_from(col("c1").eq(lit("completed")), lit(true))) + .or(not_distinct_from(col("c1").eq(lit("paused")), lit(true))) + .or(not_distinct_from(col("c2"), lit(true))) + .or(not_distinct_from( + col("c1").eq(lit("invoked")).and(col("c3").gt(lit(0))), + lit(true) + )) + ); + } + #[test] fn simplify_expr_case_when_then_else() { // CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true @@ -3496,6 +3834,200 @@ mod tests { ); } + #[test] + fn simplify_expr_case_when_first_true() { + // CASE WHEN true THEN 1 ELSE c1 END --> 1 + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![(Box::new(lit(true)), Box::new(lit(1)),)], + Some(Box::new(col("c1"))), + ))), + lit(1) + ); + + // CASE WHEN true THEN col('a') ELSE col('b') END --> col('a') + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![(Box::new(lit(true)), Box::new(lit("a")),)], + Some(Box::new(lit("b"))), + ))), + lit("a") + ); + + // CASE WHEN true THEN col('a') WHEN col('x') > 5 THEN col('b') ELSE col('c') END --> col('a') + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![ + (Box::new(lit(true)), Box::new(lit("a"))), + (Box::new(lit("x").gt(lit(5))), Box::new(lit("b"))), + ], + Some(Box::new(lit("c"))), + ))), + lit("a") + ); + + // CASE WHEN true THEN col('a') END --> col('a') (no else clause) + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![(Box::new(lit(true)), Box::new(lit("a")),)], + None, + ))), + lit("a") + ); + + // Negative test: CASE WHEN c2 THEN 1 ELSE 2 END should not be simplified + let expr = Expr::Case(Case::new( + None, + vec![(Box::new(col("c2")), Box::new(lit(1)))], + Some(Box::new(lit(2))), + )); + assert_eq!(simplify(expr.clone()), expr); + + // Negative test: CASE WHEN false THEN 1 ELSE 2 END should not use this rule + let expr = Expr::Case(Case::new( + None, + vec![(Box::new(lit(false)), Box::new(lit(1)))], + Some(Box::new(lit(2))), + )); + assert_ne!(simplify(expr), lit(1)); + + // Negative test: CASE WHEN col('c1') > 5 THEN 1 ELSE 2 END should not be simplified + let expr = Expr::Case(Case::new( + None, + vec![(Box::new(col("c1").gt(lit(5))), Box::new(lit(1)))], + Some(Box::new(lit(2))), + )); + assert_eq!(simplify(expr.clone()), expr); + } + + #[test] + fn simplify_expr_case_when_any_true() { + // CASE WHEN c3 > 0 THEN 'a' WHEN true THEN 'b' ELSE 'c' END --> CASE WHEN c3 > 0 THEN 'a' ELSE 'b' END + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![ + (Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))), + (Box::new(lit(true)), Box::new(lit("b"))), + ], + Some(Box::new(lit("c"))), + ))), + Expr::Case(Case::new( + None, + vec![(Box::new(col("c3").gt(lit(0))), Box::new(lit("a")))], + Some(Box::new(lit("b"))), + )) + ); + + // CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' WHEN true THEN 'c' WHEN c3 = 0 THEN 'd' ELSE 'e' END + // --> CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' ELSE 'c' END + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![ + (Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))), + (Box::new(col("c4").lt(lit(0))), Box::new(lit("b"))), + (Box::new(lit(true)), Box::new(lit("c"))), + (Box::new(col("c3").eq(lit(0))), Box::new(lit("d"))), + ], + Some(Box::new(lit("e"))), + ))), + Expr::Case(Case::new( + None, + vec![ + (Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))), + (Box::new(col("c4").lt(lit(0))), Box::new(lit("b"))), + ], + Some(Box::new(lit("c"))), + )) + ); + + // CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 WHEN true THEN 3 END (no else) + // --> CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 ELSE 3 END + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![ + (Box::new(col("c3").gt(lit(0))), Box::new(lit(1))), + (Box::new(col("c4").lt(lit(0))), Box::new(lit(2))), + (Box::new(lit(true)), Box::new(lit(3))), + ], + None, + ))), + Expr::Case(Case::new( + None, + vec![ + (Box::new(col("c3").gt(lit(0))), Box::new(lit(1))), + (Box::new(col("c4").lt(lit(0))), Box::new(lit(2))), + ], + Some(Box::new(lit(3))), + )) + ); + + // Negative test: CASE WHEN c3 > 0 THEN c3 WHEN c4 < 0 THEN 2 ELSE 3 END should not be simplified + let expr = Expr::Case(Case::new( + None, + vec![ + (Box::new(col("c3").gt(lit(0))), Box::new(col("c3"))), + (Box::new(col("c4").lt(lit(0))), Box::new(lit(2))), + ], + Some(Box::new(lit(3))), + )); + assert_eq!(simplify(expr.clone()), expr); + } + + #[test] + fn simplify_expr_case_when_any_false() { + // CASE WHEN false THEN 'a' END --> NULL + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![(Box::new(lit(false)), Box::new(lit("a")))], + None, + ))), + Expr::Literal(ScalarValue::Utf8(None), None) + ); + + // CASE WHEN false THEN 2 ELSE 1 END --> 1 + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![(Box::new(lit(false)), Box::new(lit(2)))], + Some(Box::new(lit(1))), + ))), + lit(1), + ); + + // CASE WHEN c3 < 10 THEN 'b' WHEN false then c3 ELSE c4 END --> CASE WHEN c3 < 10 THEN b ELSE c4 END + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![ + (Box::new(col("c3").lt(lit(10))), Box::new(lit("b"))), + (Box::new(lit(false)), Box::new(col("c3"))), + ], + Some(Box::new(col("c4"))), + ))), + Expr::Case(Case::new( + None, + vec![(Box::new(col("c3").lt(lit(10))), Box::new(lit("b")))], + Some(Box::new(col("c4"))), + )) + ); + + // Negative test: CASE WHEN c3 = 4 THEN 1 ELSE 2 END should not be simplified + let expr = Expr::Case(Case::new( + None, + vec![(Box::new(col("c3").eq(lit(4))), Box::new(lit(1)))], + Some(Box::new(lit(2))), + )); + assert_eq!(simplify(expr.clone()), expr); + } + fn distinct_from(left: impl Into, right: impl Into) -> Expr { Expr::BinaryExpr(BinaryExpr { left: Box::new(left.into()), @@ -3755,6 +4287,56 @@ mod tests { assert_eq!(simplify(expr.clone()), expr); } + #[test] + fn simplify_null_in_empty_inlist() { + // `NULL::boolean IN ()` == `NULL::boolean IN (SELECT foo FROM empty)` == false + let expr = in_list(lit_bool_null(), vec![], false); + assert_eq!(simplify(expr), lit(false)); + + // `NULL::boolean NOT IN ()` == `NULL::boolean NOT IN (SELECT foo FROM empty)` == true + let expr = in_list(lit_bool_null(), vec![], true); + assert_eq!(simplify(expr), lit(true)); + + // `NULL IN ()` == `NULL IN (SELECT foo FROM empty)` == false + let null_null = || Expr::Literal(ScalarValue::Null, None); + let expr = in_list(null_null(), vec![], false); + assert_eq!(simplify(expr), lit(false)); + + // `NULL NOT IN ()` == `NULL NOT IN (SELECT foo FROM empty)` == true + let expr = in_list(null_null(), vec![], true); + assert_eq!(simplify(expr), lit(true)); + } + + #[test] + fn just_simplifier_simplify_null_in_empty_inlist() { + let simplify = |expr: Expr| -> Expr { + let schema = expr_test_schema(); + let execution_props = ExecutionProps::new(); + let info = SimplifyContext::new(&execution_props).with_schema(schema); + let simplifier = &mut Simplifier::new(&info); + expr.rewrite(simplifier) + .expect("Failed to simplify expression") + .data + }; + + // `NULL::boolean IN ()` == `NULL::boolean IN (SELECT foo FROM empty)` == false + let expr = in_list(lit_bool_null(), vec![], false); + assert_eq!(simplify(expr), lit(false)); + + // `NULL::boolean NOT IN ()` == `NULL::boolean NOT IN (SELECT foo FROM empty)` == true + let expr = in_list(lit_bool_null(), vec![], true); + assert_eq!(simplify(expr), lit(true)); + + // `NULL IN ()` == `NULL IN (SELECT foo FROM empty)` == false + let null_null = || Expr::Literal(ScalarValue::Null, None); + let expr = in_list(null_null(), vec![], false); + assert_eq!(simplify(expr), lit(false)); + + // `NULL NOT IN ()` == `NULL NOT IN (SELECT foo FROM empty)` == true + let expr = in_list(null_null(), vec![], true); + assert_eq!(simplify(expr), lit(true)); + } + #[test] fn simplify_large_or() { let expr = (0..5) @@ -3933,6 +4515,11 @@ mod tests { assert_eq!(simplify(expr), col("c1").like(lit("a_"))); let expr = col("c1").not_like(lit("a_")); assert_eq!(simplify(expr), col("c1").not_like(lit("a_"))); + + let expr = col("c1").ilike(lit("a")); + assert_eq!(simplify(expr), col("c1").ilike(lit("a"))); + let expr = col("c1").not_ilike(lit("a")); + assert_eq!(simplify(expr), col("c1").not_ilike(lit("a"))); } #[test] @@ -4076,14 +4663,17 @@ mod tests { } fn boolean_test_schema() -> DFSchemaRef { - Schema::new(vec![ - Field::new("A", DataType::Boolean, false), - Field::new("B", DataType::Boolean, false), - Field::new("C", DataType::Boolean, false), - Field::new("D", DataType::Boolean, false), - ]) - .to_dfschema_ref() - .unwrap() + static BOOLEAN_TEST_SCHEMA: LazyLock = LazyLock::new(|| { + Schema::new(vec![ + Field::new("A", DataType::Boolean, false), + Field::new("B", DataType::Boolean, false), + Field::new("C", DataType::Boolean, false), + Field::new("D", DataType::Boolean, false), + ]) + .to_dfschema_ref() + .unwrap() + }); + Arc::clone(&BOOLEAN_TEST_SCHEMA) } #[test] @@ -4124,7 +4714,7 @@ mod tests { vec![], false, None, - None, + vec![], None, )); @@ -4138,7 +4728,7 @@ mod tests { vec![], false, None, - None, + vec![], None, )); @@ -4148,7 +4738,7 @@ mod tests { /// A Mock UDAF which defines `simplify` to be used in tests /// related to UDAF simplification - #[derive(Debug, Clone)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimplifyMockUdaf { simplify: bool, } @@ -4213,8 +4803,7 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![])); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -4222,8 +4811,7 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![])); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); @@ -4231,7 +4819,7 @@ mod tests { /// A Mock UDWF which defines `simplify` to be used in tests /// related to UDWF simplification - #[derive(Debug, Clone)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimplifyMockUdwf { simplify: bool, } @@ -4275,11 +4863,15 @@ mod tests { unimplemented!("not needed for tests") } - fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { unimplemented!("not needed for tests") } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } } - #[derive(Debug)] + #[derive(Debug, PartialEq, Eq, Hash)] struct VolatileUdf { signature: Signature, } @@ -4350,6 +4942,34 @@ mod tests { } } + #[test] + fn simplify_fixed_size_binary_eq_lit() { + let bytes = [1u8, 2, 3].as_slice(); + + // The expression starts simple. + let expr = col("c5").eq(lit(bytes)); + + // The type coercer introduces a cast. + let coerced = coerce(expr.clone()); + let schema = expr_test_schema(); + assert_eq!( + coerced, + col("c5") + .cast_to(&DataType::Binary, schema.as_ref()) + .unwrap() + .eq(lit(bytes)) + ); + + // The simplifier removes the cast. + assert_eq!( + simplify(coerced), + col("c5").eq(Expr::Literal( + ScalarValue::FixedSizeBinary(3, Some(bytes.to_vec()),), + None + )) + ); + } + fn if_not_null(expr: Expr, then: bool) -> Expr { Expr::Case(Case { expr: Some(expr.is_not_null().into()), diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 4700ab97b5f39..bbb023cfbad9f 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -84,7 +84,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { low, high, }) => { - if let (Some(interval), Expr::Literal(low), Expr::Literal(high)) = ( + if let (Some(interval), Expr::Literal(low, _), Expr::Literal(high, _)) = ( self.guarantees.get(inner.as_ref()), low.as_ref(), high.as_ref(), @@ -115,7 +115,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { .get(left.as_ref()) .map(|interval| Cow::Borrowed(*interval)) .or_else(|| { - if let Expr::Literal(value) = left.as_ref() { + if let Expr::Literal(value, _) = left.as_ref() { Some(Cow::Owned(value.clone().into())) } else { None @@ -126,7 +126,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { .get(right.as_ref()) .map(|interval| Cow::Borrowed(*interval)) .or_else(|| { - if let Expr::Literal(value) = right.as_ref() { + if let Expr::Literal(value, _) = right.as_ref() { Some(Cow::Owned(value.clone().into())) } else { None @@ -168,7 +168,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { let new_list: Vec = list .iter() .filter_map(|expr| { - if let Expr::Literal(item) = expr { + if let Expr::Literal(item, _) = expr { match interval .contains(NullableInterval::from(item.clone())) { @@ -244,8 +244,7 @@ mod tests { let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, - "{} simplified to {}, but expected {}", - expr, output, expected + "{expr} simplified to {output}, but expected {expected}" ); } } @@ -255,8 +254,7 @@ mod tests { let output = expr.clone().rewrite(rewriter).data().unwrap(); assert_eq!( &output, expr, - "{} was simplified to {}, but expected it to be unchanged", - expr, output + "{expr} was simplified to {output}, but expected it to be unchanged" ); } } @@ -417,7 +415,7 @@ mod tests { let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); let output = col("x").rewrite(&mut rewriter).data().unwrap(); - assert_eq!(output, Expr::Literal(scalar.clone())); + assert_eq!(output, Expr::Literal(scalar.clone(), None)); } } diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index c8638eb723955..a1c1dc17d2945 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -39,10 +39,10 @@ impl TreeNodeRewriter for ShortenInListSimplifier { // if expr is a single column reference: // expr IN (A, B, ...) --> (expr = A) OR (expr = B) OR (expr = C) if let Expr::InList(InList { - expr, - list, + ref expr, + ref list, negated, - }) = expr.clone() + }) = expr { if !list.is_empty() && ( @@ -57,7 +57,7 @@ impl TreeNodeRewriter for ShortenInListSimplifier { { let first_val = list[0].clone(); if negated { - return Ok(Transformed::yes(list.into_iter().skip(1).fold( + return Ok(Transformed::yes(list.iter().skip(1).cloned().fold( (*expr.clone()).not_eq(first_val), |acc, y| { // Note that `A and B and C and D` is a left-deep tree structure @@ -81,7 +81,7 @@ impl TreeNodeRewriter for ShortenInListSimplifier { }, ))); } else { - return Ok(Transformed::yes(list.into_iter().skip(1).fold( + return Ok(Transformed::yes(list.iter().skip(1).cloned().fold( (*expr.clone()).eq(first_val), |acc, y| { // Same reasoning as above diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index 5fbee02e3909e..7ae38eec9a3ad 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -23,6 +23,7 @@ mod guarantees; mod inlist_simplifier; mod regex; pub mod simplify_exprs; +mod simplify_predicates; mod unwrap_cast; mod utils; @@ -31,6 +32,7 @@ pub use datafusion_expr::simplify::{SimplifyContext, SimplifyInfo}; pub use expr_simplifier::*; pub use simplify_exprs::*; +pub use simplify_predicates::simplify_predicates; // Export for test in datafusion/core/tests/optimizer_integration.rs pub use guarantees::GuaranteeRewriter; diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index 0b47cdee212f2..82c5ea3d8d820 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -46,7 +46,7 @@ pub fn simplify_regex_expr( ) -> Result { let mode = OperatorMode::new(&op); - if let Expr::Literal(ScalarValue::Utf8(Some(pattern))) = right.as_ref() { + if let Expr::Literal(ScalarValue::Utf8(Some(pattern)), _) = right.as_ref() { // Handle the special case for ".*" pattern if pattern == ANY_CHAR_REGEX_PATTERN { let new_expr = if mode.not { @@ -121,7 +121,7 @@ impl OperatorMode { let like = Like { negated: self.not, expr, - pattern: Box::new(Expr::Literal(ScalarValue::from(pattern))), + pattern: Box::new(Expr::Literal(ScalarValue::from(pattern), None)), escape_char: None, case_insensitive: self.i, }; @@ -255,9 +255,9 @@ fn partial_anchored_literal_to_like(v: &[Hir]) -> Option { }; if match_begin { - Some(format!("{}%", lit)) + Some(format!("{lit}%")) } else { - Some(format!("%{}", lit)) + Some(format!("%{lit}")) } } @@ -331,7 +331,7 @@ fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { } HirKind::Concat(inner) => { if let Some(pattern) = partial_anchored_literal_to_like(inner) - .or(collect_concat_to_like_string(inner)) + .or_else(|| collect_concat_to_like_string(inner)) { return Some(mode.expr(Box::new(left.clone()), pattern)); } diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index e33869ca2b636..4faf9389cfac4 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -69,6 +69,7 @@ impl OptimizerRule for SimplifyExpressions { ) -> Result, DataFusionError> { let mut execution_props = ExecutionProps::new(); execution_props.query_execution_start_time = config.query_execution_start_time(); + execution_props.config_options = Some(config.options()); Self::optimize_internal(plan, &execution_props) } } @@ -123,10 +124,11 @@ impl SimplifyExpressions { let name_preserver = NamePreserver::new(&plan); let mut rewrite_expr = |expr: Expr| { let name = name_preserver.save(&expr); - let expr = simplifier.simplify(expr)?; - // TODO it would be nice to have a way to know if the expression was simplified - // or not. For now conservatively return Transformed::yes - Ok(Transformed::yes(name.restore(expr))) + let expr = simplifier.simplify_with_cycle_count_transformed(expr)?.0; + Ok(Transformed::new_transformed( + name.restore(expr.data), + expr.transformed, + )) }; plan.map_expressions(|expr| { @@ -154,12 +156,12 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use chrono::{DateTime, Utc}; - use crate::optimizer::Optimizer; use datafusion_expr::logical_plan::builder::table_scan_with_filters; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::*; use datafusion_functions_aggregate::expr_fn::{max, min}; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::{assert_fields_eq, test_table_scan_with_name}; use crate::OptimizerContext; @@ -179,15 +181,20 @@ mod tests { .expect("building plan") } - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - // Use Optimizer to do plan traversal - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - let optimizer = Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())]); - let optimized_plan = - optimizer.optimize(plan, &OptimizerContext::new(), observe)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rules: Vec> = vec![Arc::new(SimplifyExpressions::new())]; + let optimizer_ctx = OptimizerContext::new(); + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -210,9 +217,10 @@ mod tests { assert_eq!(1, table_scan.schema().fields().len()); assert_fields_eq(&table_scan, vec!["a"]); - let expected = "TableScan: test projection=[a], full_filters=[Boolean(true)]"; - - assert_optimized_plan_eq(table_scan, expected) + assert_optimized_plan_equal!( + table_scan, + @ r"TableScan: test projection=[a], full_filters=[Boolean(true)]" + ) } #[test] @@ -223,12 +231,13 @@ mod tests { .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))? .build()?; - assert_optimized_plan_eq( + assert_optimized_plan_equal!( plan, - "\ - Filter: test.b > Int32(1)\ - \n Projection: test.a\ - \n TableScan: test", + @ r" + Filter: test.b > Int32(1) + Projection: test.a + TableScan: test + " ) } @@ -240,12 +249,13 @@ mod tests { .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))? .build()?; - assert_optimized_plan_eq( + assert_optimized_plan_equal!( plan, - "\ - Filter: test.b > Int32(1)\ - \n Projection: test.a\ - \n TableScan: test", + @ r" + Filter: test.b > Int32(1) + Projection: test.a + TableScan: test + " ) } @@ -257,12 +267,13 @@ mod tests { .filter(or(col("b").gt(lit(1)), col("b").gt(lit(1))))? .build()?; - assert_optimized_plan_eq( + assert_optimized_plan_equal!( plan, - "\ - Filter: test.b > Int32(1)\ - \n Projection: test.a\ - \n TableScan: test", + @ r" + Filter: test.b > Int32(1) + Projection: test.a + TableScan: test + " ) } @@ -278,12 +289,13 @@ mod tests { ))? .build()?; - assert_optimized_plan_eq( + assert_optimized_plan_equal!( plan, - "\ - Filter: test.a > Int32(5) AND test.b < Int32(6)\ - \n Projection: test.a, test.b\ - \n TableScan: test", + @ r" + Filter: test.a > Int32(5) AND test.b < Int32(6) + Projection: test.a, test.b + TableScan: test + " ) } @@ -296,13 +308,15 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "\ - Projection: test.a\ - \n Filter: NOT test.c\ - \n Filter: test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Filter: NOT test.c + Filter: test.b + TableScan: test + " + ) } #[test] @@ -315,14 +329,16 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "\ - Projection: test.a\ - \n Limit: skip=0, fetch=1\ - \n Filter: test.c\ - \n Filter: NOT test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Limit: skip=0, fetch=1 + Filter: test.c + Filter: NOT test.b + TableScan: test + " + ) } #[test] @@ -333,12 +349,14 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "\ - Projection: test.a\ - \n Filter: NOT test.b AND test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Filter: NOT test.b AND test.c + TableScan: test + " + ) } #[test] @@ -349,12 +367,14 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "\ - Projection: test.a\ - \n Filter: NOT test.b OR NOT test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Filter: NOT test.b OR NOT test.c + TableScan: test + " + ) } #[test] @@ -365,12 +385,14 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "\ - Projection: test.a\ - \n Filter: test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Filter: test.b + TableScan: test + " + ) } #[test] @@ -380,11 +402,13 @@ mod tests { .project(vec![col("a"), col("d"), col("b").eq(lit(false))])? .build()?; - let expected = "\ - Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false)\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false) + TableScan: test + " + ) } #[test] @@ -398,12 +422,14 @@ mod tests { )? .build()?; - let expected = "\ - Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b) AS max(test.b = Boolean(true)), min(test.b)]]\ - \n Projection: test.a, test.c, test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b) AS max(test.b = Boolean(true)), min(test.b)]] + Projection: test.a, test.c, test.b + TableScan: test + " + ) } #[test] @@ -421,10 +447,10 @@ mod tests { let values = vec![vec![expr1, expr2]]; let plan = LogicalPlanBuilder::values(values)?.build()?; - let expected = "\ - Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ "Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))" + ) } fn get_optimized_plan_formatted( @@ -481,10 +507,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").gt(lit(10)).not())? .build()?; - let expected = "Filter: test.d <= Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d <= Int32(10) + TableScan: test + " + ) } #[test] @@ -494,10 +524,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").gt(lit(10)).and(col("d").lt(lit(100))).not())? .build()?; - let expected = "Filter: test.d <= Int32(10) OR test.d >= Int32(100)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d <= Int32(10) OR test.d >= Int32(100) + TableScan: test + " + ) } #[test] @@ -507,10 +541,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").gt(lit(10)).or(col("d").lt(lit(100))).not())? .build()?; - let expected = "Filter: test.d <= Int32(10) AND test.d >= Int32(100)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d <= Int32(10) AND test.d >= Int32(100) + TableScan: test + " + ) } #[test] @@ -520,10 +558,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").gt(lit(10)).not().not())? .build()?; - let expected = "Filter: test.d > Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d > Int32(10) + TableScan: test + " + ) } #[test] @@ -533,10 +575,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("e").is_null().not())? .build()?; - let expected = "Filter: test.e IS NOT NULL\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.e IS NOT NULL + TableScan: test + " + ) } #[test] @@ -546,10 +592,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("e").is_not_null().not())? .build()?; - let expected = "Filter: test.e IS NULL\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.e IS NULL + TableScan: test + " + ) } #[test] @@ -559,11 +609,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], false).not())? .build()?; - let expected = - "Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3) + TableScan: test + " + ) } #[test] @@ -573,11 +626,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], true).not())? .build()?; - let expected = - "Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3) + TableScan: test + " + ) } #[test] @@ -588,10 +644,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(qual.not())? .build()?; - let expected = "Filter: test.d < Int32(1) OR test.d > Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d < Int32(1) OR test.d > Int32(10) + TableScan: test + " + ) } #[test] @@ -602,10 +662,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(qual.not())? .build()?; - let expected = "Filter: test.d >= Int32(1) AND test.d <= Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d >= Int32(1) AND test.d <= Int32(10) + TableScan: test + " + ) } #[test] @@ -622,10 +686,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("a").like(col("b")).not())? .build()?; - let expected = "Filter: test.a NOT LIKE test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.a NOT LIKE test.b + TableScan: test + " + ) } #[test] @@ -642,10 +710,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("a").not_like(col("b")).not())? .build()?; - let expected = "Filter: test.a LIKE test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.a LIKE test.b + TableScan: test + " + ) } #[test] @@ -662,10 +734,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("a").ilike(col("b")).not())? .build()?; - let expected = "Filter: test.a NOT ILIKE test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.a NOT ILIKE test.b + TableScan: test + " + ) } #[test] @@ -675,10 +751,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(binary_expr(col("d"), Operator::IsDistinctFrom, lit(10)).not())? .build()?; - let expected = "Filter: test.d IS NOT DISTINCT FROM Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d IS NOT DISTINCT FROM Int32(10) + TableScan: test + " + ) } #[test] @@ -688,10 +768,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(binary_expr(col("d"), Operator::IsNotDistinctFrom, lit(10)).not())? .build()?; - let expected = "Filter: test.d IS DISTINCT FROM Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d IS DISTINCT FROM Int32(10) + TableScan: test + " + ) } #[test] @@ -713,11 +797,14 @@ mod tests { // before simplify: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int64(2), UInt32) // after simplify: t1.a + UInt32(1) = t2.a + UInt32(2) AS t1.a + Int64(1) = t2.a + Int64(2) - let expected = "Inner Join: t1.a + UInt32(1) = t2.a + UInt32(2)\ - \n TableScan: t1\ - \n TableScan: t2"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Inner Join: t1.a + UInt32(1) = t2.a + UInt32(2) + TableScan: t1 + TableScan: t2 + " + ) } #[test] @@ -727,10 +814,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").is_not_null())? .build()?; - let expected = "Filter: Boolean(true)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: Boolean(true) + TableScan: test + " + ) } #[test] @@ -740,10 +831,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").is_null())? .build()?; - let expected = "Filter: Boolean(false)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: Boolean(false) + TableScan: test + " + ) } #[test] @@ -760,10 +855,13 @@ mod tests { )? .build()?; - let expected = "Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]]\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]] + TableScan: test + " + ) } #[test] @@ -778,19 +876,27 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("a"), Operator::RegexMatch, lit(".*")))? .build()?; - let expected = "Filter: test.a IS NOT NULL\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.a IS NOT NULL + TableScan: test + " + )?; // Test `!= ".*"` transforms to checking if the column is empty let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("a"), Operator::RegexNotMatch, lit(".*")))? .build()?; - let expected = "Filter: test.a = Utf8(\"\")\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @ r#" + Filter: test.a = Utf8("") + TableScan: test + "# + )?; // Test case-insensitive versions @@ -798,18 +904,174 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("b"), Operator::RegexIMatch, lit(".*")))? .build()?; - let expected = "Filter: Boolean(true)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @ r" + Filter: Boolean(true) + TableScan: test + " + )?; // Test `!~ ".*"` (case-insensitive) transforms to checking if the column is empty let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("a"), Operator::RegexNotIMatch, lit(".*")))? .build()?; - let expected = "Filter: test.a = Utf8(\"\")\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r#" + Filter: test.a = Utf8("") + TableScan: test + "# + ) + } + + #[test] + fn simplify_not_in_list() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + let table_scan = table_scan(Some("test"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("a").in_list(vec![lit("a"), lit("b")], false).not())? + .build()?; + + assert_optimized_plan_equal!( + plan, + @ r#" + Filter: test.a != Utf8("a") AND test.a != Utf8("b") + TableScan: test + "# + ) + } + + #[test] + fn simplify_not_not_in_list() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + let table_scan = table_scan(Some("test"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter( + col("a") + .in_list(vec![lit("a"), lit("b")], false) + .not() + .not(), + )? + .build()?; + + assert_optimized_plan_equal!( + plan, + @ r#" + Filter: test.a = Utf8("a") OR test.a = Utf8("b") + TableScan: test + "# + ) + } + + #[test] + fn simplify_not_exists() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + let table_scan = table_scan(Some("test"), &schema, None)?.build()?; + let table_scan2 = + datafusion_expr::table_scan(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter( + exists(Arc::new(LogicalPlanBuilder::from(table_scan2).build()?)).not(), + )? + .build()?; + + assert_optimized_plan_equal!( + plan, + @ r" + Filter: NOT EXISTS () + Subquery: + TableScan: test2 + TableScan: test + " + ) + } + + #[test] + fn simplify_not_not_exists() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + let table_scan = table_scan(Some("test"), &schema, None)?.build()?; + let table_scan2 = + datafusion_expr::table_scan(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter( + exists(Arc::new(LogicalPlanBuilder::from(table_scan2).build()?)) + .not() + .not(), + )? + .build()?; + + assert_optimized_plan_equal!( + plan, + @ r" + Filter: EXISTS () + Subquery: + TableScan: test2 + TableScan: test + " + ) + } + + #[test] + fn simplify_not_in_subquery() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + let table_scan = table_scan(Some("test"), &schema, None)?.build()?; + let table_scan2 = + datafusion_expr::table_scan(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter( + in_subquery( + col("a"), + Arc::new(LogicalPlanBuilder::from(table_scan2).build()?), + ) + .not(), + )? + .build()?; + + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.a NOT IN () + Subquery: + TableScan: test2 + TableScan: test + " + ) + } + + #[test] + fn simplify_not_not_in_subquery() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + let table_scan = table_scan(Some("test"), &schema, None)?.build()?; + let table_scan2 = + datafusion_expr::table_scan(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter( + in_subquery( + col("a"), + Arc::new(LogicalPlanBuilder::from(table_scan2).build()?), + ) + .not() + .not(), + )? + .build()?; + + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.a IN () + Subquery: + TableScan: test2 + TableScan: test + " + ) } } diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs b/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs new file mode 100644 index 0000000000000..131404e607060 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs @@ -0,0 +1,337 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Simplifies predicates by reducing redundant or overlapping conditions. +//! +//! This module provides functionality to optimize logical predicates used in query planning +//! by eliminating redundant conditions, thus reducing the number of predicates to evaluate. +//! Unlike the simplifier in `simplify_expressions/simplify_exprs.rs`, which focuses on +//! general expression simplification (e.g., constant folding and algebraic simplifications), +//! this module specifically targets predicate optimization by handling containment relationships. +//! For example, it can simplify `x > 5 AND x > 6` to just `x > 6`, as the latter condition +//! encompasses the former, resulting in fewer checks during query execution. + +use datafusion_common::{Column, Result, ScalarValue}; +use datafusion_expr::{BinaryExpr, Expr, Operator}; +use std::collections::BTreeMap; + +/// Simplifies a list of predicates by removing redundancies. +/// +/// This function takes a vector of predicate expressions and groups them by the column they reference. +/// Predicates that reference a single column and are comparison operations (e.g., >, >=, <, <=, =) +/// are analyzed to remove redundant conditions. For instance, `x > 5 AND x > 6` is simplified to +/// `x > 6`. Other predicates that do not fit this pattern are retained as-is. +/// +/// # Arguments +/// * `predicates` - A vector of `Expr` representing the predicates to simplify. +/// +/// # Returns +/// A `Result` containing a vector of simplified `Expr` predicates. +pub fn simplify_predicates(predicates: Vec) -> Result> { + // Early return for simple cases + if predicates.len() <= 1 { + return Ok(predicates); + } + + // Group predicates by their column reference + let mut column_predicates: BTreeMap> = BTreeMap::new(); + let mut other_predicates = Vec::new(); + + for pred in predicates { + match &pred { + Expr::BinaryExpr(BinaryExpr { + left, + op: + Operator::Gt + | Operator::GtEq + | Operator::Lt + | Operator::LtEq + | Operator::Eq, + right, + }) => { + let left_col = extract_column_from_expr(left); + let right_col = extract_column_from_expr(right); + if let (Some(col), Some(_)) = (&left_col, right.as_literal()) { + column_predicates.entry(col.clone()).or_default().push(pred); + } else if let (Some(_), Some(col)) = (left.as_literal(), &right_col) { + column_predicates.entry(col.clone()).or_default().push(pred); + } else { + other_predicates.push(pred); + } + } + _ => other_predicates.push(pred), + } + } + + // Process each column's predicates to remove redundancies + let mut result = other_predicates; + for (_, preds) in column_predicates { + let simplified = simplify_column_predicates(preds)?; + result.extend(simplified); + } + + Ok(result) +} + +/// Simplifies predicates related to a single column. +/// +/// This function processes a list of predicates that all reference the same column and +/// simplifies them based on their operators. It groups predicates into greater-than (>, >=), +/// less-than (<, <=), and equality (=) categories, then selects the most restrictive condition +/// in each category to reduce redundancy. For example, among `x > 5` and `x > 6`, only `x > 6` +/// is retained as it is more restrictive. +/// +/// # Arguments +/// * `predicates` - A vector of `Expr` representing predicates for a single column. +/// +/// # Returns +/// A `Result` containing a vector of simplified `Expr` predicates for the column. +fn simplify_column_predicates(predicates: Vec) -> Result> { + if predicates.len() <= 1 { + return Ok(predicates); + } + + // Group by operator type, but combining similar operators + let mut greater_predicates = Vec::new(); // Combines > and >= + let mut less_predicates = Vec::new(); // Combines < and <= + let mut eq_predicates = Vec::new(); + + for pred in predicates { + match &pred { + Expr::BinaryExpr(BinaryExpr { left: _, op, right }) => { + match (op, right.as_literal().is_some()) { + (Operator::Gt, true) + | (Operator::Lt, false) + | (Operator::GtEq, true) + | (Operator::LtEq, false) => greater_predicates.push(pred), + (Operator::Lt, true) + | (Operator::Gt, false) + | (Operator::LtEq, true) + | (Operator::GtEq, false) => less_predicates.push(pred), + (Operator::Eq, _) => eq_predicates.push(pred), + _ => unreachable!("Unexpected operator: {}", op), + } + } + _ => unreachable!("Unexpected predicate {}", pred.to_string()), + } + } + + let mut result = Vec::new(); + + if !eq_predicates.is_empty() { + // If there are many equality predicates, we can only keep one if they are all the same + if eq_predicates.len() == 1 + || eq_predicates.iter().all(|e| e == &eq_predicates[0]) + { + result.push(eq_predicates.pop().unwrap()); + } else { + // If they are not the same, add a false predicate + result.push(Expr::Literal(ScalarValue::Boolean(Some(false)), None)); + } + } + + // Handle all greater-than-style predicates (keep the most restrictive - highest value) + if !greater_predicates.is_empty() { + if let Some(most_restrictive) = + find_most_restrictive_predicate(&greater_predicates, true)? + { + result.push(most_restrictive); + } else { + result.extend(greater_predicates); + } + } + + // Handle all less-than-style predicates (keep the most restrictive - lowest value) + if !less_predicates.is_empty() { + if let Some(most_restrictive) = + find_most_restrictive_predicate(&less_predicates, false)? + { + result.push(most_restrictive); + } else { + result.extend(less_predicates); + } + } + + Ok(result) +} + +/// Finds the most restrictive predicate from a list based on literal values. +/// +/// This function iterates through a list of predicates to identify the most restrictive one +/// by comparing their literal values. For greater-than predicates, the highest value is most +/// restrictive, while for less-than predicates, the lowest value is most restrictive. +/// +/// # Arguments +/// * `predicates` - A slice of `Expr` representing predicates to compare. +/// * `find_greater` - A boolean indicating whether to find the highest value (true for >, >=) +/// or the lowest value (false for <, <=). +/// +/// # Returns +/// A `Result` containing an `Option` with the most restrictive predicate, if any. +fn find_most_restrictive_predicate( + predicates: &[Expr], + find_greater: bool, +) -> Result> { + if predicates.is_empty() { + return Ok(None); + } + + let mut most_restrictive_idx = 0; + let mut best_value: Option<&ScalarValue> = None; + + for (idx, pred) in predicates.iter().enumerate() { + if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = pred { + // Extract the literal value based on which side has it + let scalar_value = match (right.as_literal(), left.as_literal()) { + (Some(scalar), _) => Some(scalar), + (_, Some(scalar)) => Some(scalar), + _ => None, + }; + + if let Some(scalar) = scalar_value { + if let Some(current_best) = best_value { + let comparison = scalar.try_cmp(current_best)?; + let is_better = if find_greater { + comparison == std::cmp::Ordering::Greater + } else { + comparison == std::cmp::Ordering::Less + }; + + if is_better { + best_value = Some(scalar); + most_restrictive_idx = idx; + } + } else { + best_value = Some(scalar); + most_restrictive_idx = idx; + } + } + } + } + + Ok(Some(predicates[most_restrictive_idx].clone())) +} + +/// Extracts a column reference from an expression, if present. +/// +/// This function checks if the given expression is a column reference or contains one, +/// such as within a cast operation. It returns the `Column` if found. +/// +/// # Arguments +/// * `expr` - A reference to an `Expr` to inspect for a column reference. +/// +/// # Returns +/// An `Option` containing the column reference if found, otherwise `None`. +fn extract_column_from_expr(expr: &Expr) -> Option { + match expr { + Expr::Column(col) => Some(col.clone()), + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::DataType; + use datafusion_expr::{cast, col, lit}; + + #[test] + fn test_simplify_predicates_with_cast() { + // Test that predicates on cast expressions are not grouped with predicates on the raw column + // a < 5 AND CAST(a AS varchar) < 'abc' AND a < 6 + // Should simplify to: + // a < 5 AND CAST(a AS varchar) < 'abc' + + let predicates = vec![ + col("a").lt(lit(5i32)), + cast(col("a"), DataType::Utf8).lt(lit("abc")), + col("a").lt(lit(6i32)), + ]; + + let result = simplify_predicates(predicates).unwrap(); + + // Should have 2 predicates: a < 5 and CAST(a AS varchar) < 'abc' + assert_eq!(result.len(), 2); + + // Check that the cast predicate is preserved + let has_cast_predicate = result.iter().any(|p| { + matches!(p, Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Lt, + right + }) if matches!(left.as_ref(), Expr::Cast(_)) && right == &Box::new(lit("abc"))) + }); + assert!(has_cast_predicate, "Cast predicate should be preserved"); + + // Check that we have the more restrictive column predicate (a < 5) + let has_column_predicate = result.iter().any(|p| { + matches!(p, Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Lt, + right + }) if left == &Box::new(col("a")) && right == &Box::new(lit(5i32))) + }); + assert!(has_column_predicate, "Should have a < 5 predicate"); + } + + #[test] + fn test_extract_column_ignores_cast() { + // Test that extract_column_from_expr does not extract columns from cast expressions + let cast_expr = cast(col("a"), DataType::Utf8); + assert_eq!(extract_column_from_expr(&cast_expr), None); + + // Test that it still extracts from direct column references + let col_expr = col("a"); + assert_eq!(extract_column_from_expr(&col_expr), Some(Column::from("a"))); + } + + #[test] + fn test_simplify_predicates_direct_columns_only() { + // Test that only predicates on direct columns are simplified together + let predicates = vec![ + col("a").lt(lit(5i32)), + col("a").lt(lit(3i32)), + col("b").gt(lit(10i32)), + col("b").gt(lit(20i32)), + ]; + + let result = simplify_predicates(predicates).unwrap(); + + // Should have 2 predicates: a < 3 and b > 20 (most restrictive for each column) + assert_eq!(result.len(), 2); + + // Check for a < 3 + let has_a_predicate = result.iter().any(|p| { + matches!(p, Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Lt, + right + }) if left == &Box::new(col("a")) && right == &Box::new(lit(3i32))) + }); + assert!(has_a_predicate, "Should have a < 3 predicate"); + + // Check for b > 20 + let has_b_predicate = result.iter().any(|p| { + matches!(p, Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Gt, + right + }) if left == &Box::new(col("b")) && right == &Box::new(lit(20i32))) + }); + assert!(has_b_predicate, "Should have b > 20 predicate"); + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index be71a8cd19b00..5286cbd7bdf64 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -55,28 +55,23 @@ //! ``` //! -use std::cmp::Ordering; - -use arrow::datatypes::{ - DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION, - MIN_DECIMAL128_FOR_EACH_PRECISION, -}; -use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; +use arrow::datatypes::DataType; use datafusion_common::{internal_err, tree_node::Transformed}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{lit, BinaryExpr}; use datafusion_expr::{simplify::SimplifyInfo, Cast, Expr, Operator, TryCast}; +use datafusion_expr_common::casts::{is_supported_type, try_cast_literal_to_type}; pub(super) fn unwrap_cast_in_comparison_for_binary( info: &S, - cast_expr: Box, - literal: Box, + cast_expr: Expr, + literal: Expr, op: Operator, ) -> Result> { - match (*cast_expr, *literal) { + match (cast_expr, literal) { ( Expr::TryCast(TryCast { expr, .. }) | Expr::Cast(Cast { expr, .. }), - Expr::Literal(lit_value), + Expr::Literal(lit_value, _), ) => { let Ok(expr_type) = info.get_data_type(&expr) else { return internal_err!("Can't get the data type of the expr {:?}", &expr); @@ -95,7 +90,7 @@ pub(super) fn unwrap_cast_in_comparison_for_binary( // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type) else { return internal_err!( - "Can't cast the literal expr {:?} to type {:?}", + "Can't cast the literal expr {:?} to type {}", &lit_value, &expr_type ); @@ -126,7 +121,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< | Expr::Cast(Cast { expr: left_expr, .. }), - Expr::Literal(lit_val), + Expr::Literal(lit_val, _), ) => { let Ok(expr_type) = info.get_data_type(left_expr) else { return false; @@ -183,7 +178,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist< } match right { - Expr::Literal(lit_val) + Expr::Literal(lit_val, _) if try_cast_literal_to_type(lit_val, &expr_type).is_some() => {} _ => return false, } @@ -192,44 +187,6 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist< true } -/// Returns true if unwrap_cast_in_comparison supports this data type -fn is_supported_type(data_type: &DataType) -> bool { - is_supported_numeric_type(data_type) - || is_supported_string_type(data_type) - || is_supported_dictionary_type(data_type) -} - -/// Returns true if unwrap_cast_in_comparison support this numeric type -fn is_supported_numeric_type(data_type: &DataType) -> bool { - matches!( - data_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Decimal128(_, _) - | DataType::Timestamp(_, _) - ) -} - -/// Returns true if unwrap_cast_in_comparison supports casting this value as a string -fn is_supported_string_type(data_type: &DataType) -> bool { - matches!( - data_type, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View - ) -} - -/// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary -fn is_supported_dictionary_type(data_type: &DataType) -> bool { - matches!(data_type, - DataType::Dictionary(_, inner) if is_supported_type(inner)) -} - ///// Tries to move a cast from an expression (such as column) to the literal other side of a comparison operator./ /// /// Specifically, rewrites @@ -276,231 +233,6 @@ fn cast_literal_to_type_with_op( } } -/// Convert a literal value from one data type to another -pub(super) fn try_cast_literal_to_type( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let lit_data_type = lit_value.data_type(); - if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) { - return None; - } - if lit_value.is_null() { - // null value can be cast to any type of null value - return ScalarValue::try_from(target_type).ok(); - } - try_cast_numeric_literal(lit_value, target_type) - .or_else(|| try_cast_string_literal(lit_value, target_type)) - .or_else(|| try_cast_dictionary(lit_value, target_type)) -} - -/// Convert a numeric value from one numeric data type to another -fn try_cast_numeric_literal( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let lit_data_type = lit_value.data_type(); - if !is_supported_numeric_type(&lit_data_type) - || !is_supported_numeric_type(target_type) - { - return None; - } - - let mul = match target_type { - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 => 1_i128, - DataType::Timestamp(_, _) => 1_i128, - DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), - _ => return None, - }; - let (target_min, target_max) = match target_type { - DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128), - DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128), - DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128), - DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128), - DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), - DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), - DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), - DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), - DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128), - DataType::Decimal128(precision, _) => ( - // Different precision for decimal128 can store different range of value. - // For example, the precision is 3, the max of value is `999` and the min - // value is `-999` - MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize], - MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize], - ), - _ => return None, - }; - let lit_value_target_type = match lit_value { - ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul), - ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul), - ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul), - ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul), - ScalarValue::Decimal128(Some(v), _, scale) => { - let lit_scale_mul = 10_i128.pow(*scale as u32); - if mul >= lit_scale_mul { - // Example: - // lit is decimal(123,3,2) - // target type is decimal(5,3) - // the lit can be converted to the decimal(1230,5,3) - (*v).checked_mul(mul / lit_scale_mul) - } else if (*v) % (lit_scale_mul / mul) == 0 { - // Example: - // lit is decimal(123000,10,3) - // target type is int32: the lit can be converted to INT32(123) - // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2) - Some(*v / (lit_scale_mul / mul)) - } else { - // can't convert the lit decimal to the target data type - None - } - } - _ => None, - }; - - match lit_value_target_type { - None => None, - Some(value) => { - if value >= target_min && value <= target_max { - // the value casted from lit to the target type is in the range of target type. - // return the target type of scalar value - let result_scalar = match target_type { - DataType::Int8 => ScalarValue::Int8(Some(value as i8)), - DataType::Int16 => ScalarValue::Int16(Some(value as i16)), - DataType::Int32 => ScalarValue::Int32(Some(value as i32)), - DataType::Int64 => ScalarValue::Int64(Some(value as i64)), - DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)), - DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)), - DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)), - DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)), - DataType::Timestamp(TimeUnit::Second, tz) => { - let value = cast_between_timestamp( - &lit_data_type, - &DataType::Timestamp(TimeUnit::Second, tz.clone()), - value, - ); - ScalarValue::TimestampSecond(value, tz.clone()) - } - DataType::Timestamp(TimeUnit::Millisecond, tz) => { - let value = cast_between_timestamp( - &lit_data_type, - &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), - value, - ); - ScalarValue::TimestampMillisecond(value, tz.clone()) - } - DataType::Timestamp(TimeUnit::Microsecond, tz) => { - let value = cast_between_timestamp( - &lit_data_type, - &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), - value, - ); - ScalarValue::TimestampMicrosecond(value, tz.clone()) - } - DataType::Timestamp(TimeUnit::Nanosecond, tz) => { - let value = cast_between_timestamp( - &lit_data_type, - &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), - value, - ); - ScalarValue::TimestampNanosecond(value, tz.clone()) - } - DataType::Decimal128(p, s) => { - ScalarValue::Decimal128(Some(value), *p, *s) - } - _ => { - return None; - } - }; - Some(result_scalar) - } else { - None - } - } - } -} - -fn try_cast_string_literal( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let string_value = lit_value.try_as_str()?.map(|s| s.to_string()); - let scalar_value = match target_type { - DataType::Utf8 => ScalarValue::Utf8(string_value), - DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value), - DataType::Utf8View => ScalarValue::Utf8View(string_value), - _ => return None, - }; - Some(scalar_value) -} - -/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary -fn try_cast_dictionary( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let lit_value_type = lit_value.data_type(); - let result_scalar = match (lit_value, target_type) { - // Unwrap dictionary when inner type matches target type - (ScalarValue::Dictionary(_, inner_value), _) - if inner_value.data_type() == *target_type => - { - (**inner_value).clone() - } - // Wrap type when target type is dictionary - (_, DataType::Dictionary(index_type, inner_type)) - if **inner_type == lit_value_type => - { - ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone())) - } - _ => { - return None; - } - }; - Some(result_scalar) -} - -/// Cast a timestamp value from one unit to another -fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option { - let value = value as i64; - let from_scale = match from { - DataType::Timestamp(TimeUnit::Second, _) => 1, - DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, - DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, - DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, - _ => return Some(value), - }; - - let to_scale = match to { - DataType::Timestamp(TimeUnit::Second, _) => 1, - DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, - DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, - DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, - _ => return Some(value), - }; - - match from_scale.cmp(&to_scale) { - Ordering::Less => value.checked_mul(to_scale / from_scale), - Ordering::Greater => Some(value / (from_scale / to_scale)), - Ordering::Equal => Some(value), - } -} - #[cfg(test)] mod tests { use super::*; @@ -508,8 +240,7 @@ mod tests { use std::sync::Arc; use crate::simplify_expressions::ExprSimplifier; - use arrow::compute::{cast_with_options, CastOptions}; - use arrow::datatypes::Field; + use arrow::datatypes::{Field, TimeUnit}; use datafusion_common::{DFSchema, DFSchemaRef}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::simplify::SimplifyContext; @@ -566,9 +297,9 @@ mod tests { let expected = col("c2").eq(lit(16i64)); assert_eq!(optimize_test(c2_eq_lit, &schema), expected); - // cast(c1, INT64) < INT64(NULL) => INT32(c1) < INT32(NULL) + // cast(c1, INT64) < INT64(NULL) => NULL let c1_lt_lit_null = cast(col("c1"), DataType::Int64).lt(null_i64()); - let expected = col("c1").lt(null_i32()); + let expected = null_bool(); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) => BOOL(NULL) @@ -586,9 +317,9 @@ mod tests { let expected = col("c1").not_eq(lit(123i32)); assert_eq!(optimize_test(expr_input, &schema), expected); - // cast(c1, UTF8) = NULL => c1 = NULL + // cast(c1, UTF8) = NULL => NULL let expr_input = cast(col("c1"), DataType::Utf8).eq(lit(ScalarValue::Utf8(None))); - let expected = col("c1").eq(lit(ScalarValue::Int32(None))); + let expected = null_bool(); assert_eq!(optimize_test(expr_input, &schema), expected); } @@ -691,7 +422,7 @@ mod tests { // c3 < INT64(NULL) let c1_lt_lit_null = cast(col("c3"), DataType::Int64).lt(null_i64()); - let expected = col("c3").lt(null_decimal(18, 2)); + let expected = null_bool(); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); // decimal to decimal @@ -922,10 +653,6 @@ mod tests { lit(ScalarValue::TimestampNanosecond(Some(ts), utc)) } - fn null_decimal(precision: u8, scale: i8) -> Expr { - lit(ScalarValue::Decimal128(None, precision, scale)) - } - fn timestamp_nano_none_type() -> DataType { DataType::Timestamp(TimeUnit::Nanosecond, None) } @@ -940,514 +667,4 @@ mod tests { fn dictionary_tag_type() -> DataType { DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) } - - #[test] - fn test_try_cast_to_type_nulls() { - // test that nulls can be cast to/from all integer types - let scalars = vec![ - ScalarValue::Int8(None), - ScalarValue::Int16(None), - ScalarValue::Int32(None), - ScalarValue::Int64(None), - ScalarValue::UInt8(None), - ScalarValue::UInt16(None), - ScalarValue::UInt32(None), - ScalarValue::UInt64(None), - ScalarValue::Decimal128(None, 3, 0), - ScalarValue::Decimal128(None, 8, 2), - ScalarValue::Utf8(None), - ScalarValue::LargeUtf8(None), - ]; - - for s1 in &scalars { - for s2 in &scalars { - let expected_value = ExpectedCast::Value(s2.clone()); - - expect_cast(s1.clone(), s2.data_type(), expected_value); - } - } - } - - #[test] - fn test_try_cast_to_type_int_in_range() { - // test values that can be cast to/from all integer types - let scalars = vec![ - ScalarValue::Int8(Some(123)), - ScalarValue::Int16(Some(123)), - ScalarValue::Int32(Some(123)), - ScalarValue::Int64(Some(123)), - ScalarValue::UInt8(Some(123)), - ScalarValue::UInt16(Some(123)), - ScalarValue::UInt32(Some(123)), - ScalarValue::UInt64(Some(123)), - ScalarValue::Decimal128(Some(123), 3, 0), - ScalarValue::Decimal128(Some(12300), 8, 2), - ]; - - for s1 in &scalars { - for s2 in &scalars { - let expected_value = ExpectedCast::Value(s2.clone()); - - expect_cast(s1.clone(), s2.data_type(), expected_value); - } - } - - let max_i32 = ScalarValue::Int32(Some(i32::MAX)); - expect_cast( - max_i32, - DataType::UInt64, - ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))), - ); - - let min_i32 = ScalarValue::Int32(Some(i32::MIN)); - expect_cast( - min_i32, - DataType::Int64, - ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))), - ); - - let max_i64 = ScalarValue::Int64(Some(i64::MAX)); - expect_cast( - max_i64, - DataType::UInt64, - ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))), - ); - } - - #[test] - fn test_try_cast_to_type_int_out_of_range() { - let min_i32 = ScalarValue::Int32(Some(i32::MIN)); - let min_i64 = ScalarValue::Int64(Some(i64::MIN)); - let max_i64 = ScalarValue::Int64(Some(i64::MAX)); - let max_u64 = ScalarValue::UInt64(Some(u64::MAX)); - - expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue); - - expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue); - - expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue); - - expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue); - - expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue); - - expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue); - - // decimal out of range - expect_cast( - ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0), - DataType::Int64, - ExpectedCast::NoValue, - ); - - expect_cast( - ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 37, 1), - DataType::Int64, - ExpectedCast::NoValue, - ); - } - - #[test] - fn test_try_decimal_cast_in_range() { - expect_cast( - ScalarValue::Decimal128(Some(12300), 5, 2), - DataType::Decimal128(3, 0), - ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)), - ); - - expect_cast( - ScalarValue::Decimal128(Some(12300), 5, 2), - DataType::Decimal128(8, 0), - ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)), - ); - - expect_cast( - ScalarValue::Decimal128(Some(12300), 5, 2), - DataType::Decimal128(8, 5), - ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)), - ); - } - - #[test] - fn test_try_decimal_cast_out_of_range() { - // decimal would lose precision - expect_cast( - ScalarValue::Decimal128(Some(12345), 5, 2), - DataType::Decimal128(3, 0), - ExpectedCast::NoValue, - ); - - // decimal would lose precision - expect_cast( - ScalarValue::Decimal128(Some(12300), 5, 2), - DataType::Decimal128(2, 0), - ExpectedCast::NoValue, - ); - } - - #[test] - fn test_try_cast_to_type_timestamps() { - for time_unit in [ - TimeUnit::Second, - TimeUnit::Millisecond, - TimeUnit::Microsecond, - TimeUnit::Nanosecond, - ] { - let utc = Some("+00:00".into()); - // No timezone, utc timezone - let (lit_tz_none, lit_tz_utc) = match time_unit { - TimeUnit::Second => ( - ScalarValue::TimestampSecond(Some(12345), None), - ScalarValue::TimestampSecond(Some(12345), utc), - ), - - TimeUnit::Millisecond => ( - ScalarValue::TimestampMillisecond(Some(12345), None), - ScalarValue::TimestampMillisecond(Some(12345), utc), - ), - - TimeUnit::Microsecond => ( - ScalarValue::TimestampMicrosecond(Some(12345), None), - ScalarValue::TimestampMicrosecond(Some(12345), utc), - ), - - TimeUnit::Nanosecond => ( - ScalarValue::TimestampNanosecond(Some(12345), None), - ScalarValue::TimestampNanosecond(Some(12345), utc), - ), - }; - - // DataFusion ignores timezones for comparisons of ScalarValue - // so double check it here - assert_eq!(lit_tz_none, lit_tz_utc); - - // e.g. DataType::Timestamp(_, None) - let dt_tz_none = lit_tz_none.data_type(); - - // e.g. DataType::Timestamp(_, Some(utc)) - let dt_tz_utc = lit_tz_utc.data_type(); - - // None <--> None - expect_cast( - lit_tz_none.clone(), - dt_tz_none.clone(), - ExpectedCast::Value(lit_tz_none.clone()), - ); - - // None <--> Utc - expect_cast( - lit_tz_none.clone(), - dt_tz_utc.clone(), - ExpectedCast::Value(lit_tz_utc.clone()), - ); - - // Utc <--> None - expect_cast( - lit_tz_utc.clone(), - dt_tz_none.clone(), - ExpectedCast::Value(lit_tz_none.clone()), - ); - - // Utc <--> Utc - expect_cast( - lit_tz_utc.clone(), - dt_tz_utc.clone(), - ExpectedCast::Value(lit_tz_utc.clone()), - ); - - // timestamp to int64 - expect_cast( - lit_tz_utc.clone(), - DataType::Int64, - ExpectedCast::Value(ScalarValue::Int64(Some(12345))), - ); - - // int64 to timestamp - expect_cast( - ScalarValue::Int64(Some(12345)), - dt_tz_none.clone(), - ExpectedCast::Value(lit_tz_none.clone()), - ); - - // int64 to timestamp - expect_cast( - ScalarValue::Int64(Some(12345)), - dt_tz_utc.clone(), - ExpectedCast::Value(lit_tz_utc.clone()), - ); - - // timestamp to string (not supported yet) - expect_cast( - lit_tz_utc.clone(), - DataType::LargeUtf8, - ExpectedCast::NoValue, - ); - } - } - - #[test] - fn test_try_cast_to_type_unsupported() { - // int64 to list - expect_cast( - ScalarValue::Int64(Some(12345)), - DataType::List(Arc::new(Field::new("f", DataType::Int32, true))), - ExpectedCast::NoValue, - ); - } - - #[derive(Debug, Clone)] - enum ExpectedCast { - /// test successfully cast value and it is as specified - Value(ScalarValue), - /// test returned OK, but could not cast the value - NoValue, - } - - /// Runs try_cast_literal_to_type with the specified inputs and - /// ensure it computes the expected output, and ensures the - /// casting is consistent with the Arrow kernels - fn expect_cast( - literal: ScalarValue, - target_type: DataType, - expected_result: ExpectedCast, - ) { - let actual_value = try_cast_literal_to_type(&literal, &target_type); - - println!("expect_cast: "); - println!(" {literal:?} --> {target_type:?}"); - println!(" expected_result: {expected_result:?}"); - println!(" actual_result: {actual_value:?}"); - - match expected_result { - ExpectedCast::Value(expected_value) => { - let actual_value = - actual_value.expect("Expected cast value but got None"); - - assert_eq!(actual_value, expected_value); - - // Verify that calling the arrow - // cast kernel yields the same results - // input array - let literal_array = literal - .to_array_of_size(1) - .expect("Failed to convert to array of size"); - let expected_array = expected_value - .to_array_of_size(1) - .expect("Failed to convert to array of size"); - let cast_array = cast_with_options( - &literal_array, - &target_type, - &CastOptions::default(), - ) - .expect("Expected to be cast array with arrow cast kernel"); - - assert_eq!( - &expected_array, &cast_array, - "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}" - ); - - // Verify that for timestamp types the timezones are the same - // (ScalarValue::cmp doesn't account for timezones); - if let ( - DataType::Timestamp(left_unit, left_tz), - DataType::Timestamp(right_unit, right_tz), - ) = (actual_value.data_type(), expected_value.data_type()) - { - assert_eq!(left_unit, right_unit); - assert_eq!(left_tz, right_tz); - } - } - ExpectedCast::NoValue => { - assert!( - actual_value.is_none(), - "Expected no cast value, but got {actual_value:?}" - ); - } - } - } - - #[test] - fn test_try_cast_literal_to_timestamp() { - // same timestamp - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampNanosecond(Some(123456), None), - &DataType::Timestamp(TimeUnit::Nanosecond, None), - ) - .unwrap(); - - assert_eq!( - new_scalar, - ScalarValue::TimestampNanosecond(Some(123456), None) - ); - - // TimestampNanosecond to TimestampMicrosecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampNanosecond(Some(123456), None), - &DataType::Timestamp(TimeUnit::Microsecond, None), - ) - .unwrap(); - - assert_eq!( - new_scalar, - ScalarValue::TimestampMicrosecond(Some(123), None) - ); - - // TimestampNanosecond to TimestampMillisecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampNanosecond(Some(123456), None), - &DataType::Timestamp(TimeUnit::Millisecond, None), - ) - .unwrap(); - - assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); - - // TimestampNanosecond to TimestampSecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampNanosecond(Some(123456), None), - &DataType::Timestamp(TimeUnit::Second, None), - ) - .unwrap(); - - assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None)); - - // TimestampMicrosecond to TimestampNanosecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMicrosecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Nanosecond, None), - ) - .unwrap(); - - assert_eq!( - new_scalar, - ScalarValue::TimestampNanosecond(Some(123000), None) - ); - - // TimestampMicrosecond to TimestampMillisecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMicrosecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Millisecond, None), - ) - .unwrap(); - - assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); - - // TimestampMicrosecond to TimestampSecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMicrosecond(Some(123456789), None), - &DataType::Timestamp(TimeUnit::Second, None), - ) - .unwrap(); - assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None)); - - // TimestampMillisecond to TimestampNanosecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMillisecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Nanosecond, None), - ) - .unwrap(); - assert_eq!( - new_scalar, - ScalarValue::TimestampNanosecond(Some(123000000), None) - ); - - // TimestampMillisecond to TimestampMicrosecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMillisecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Microsecond, None), - ) - .unwrap(); - assert_eq!( - new_scalar, - ScalarValue::TimestampMicrosecond(Some(123000), None) - ); - // TimestampMillisecond to TimestampSecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampMillisecond(Some(123456789), None), - &DataType::Timestamp(TimeUnit::Second, None), - ) - .unwrap(); - assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None)); - - // TimestampSecond to TimestampNanosecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampSecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Nanosecond, None), - ) - .unwrap(); - assert_eq!( - new_scalar, - ScalarValue::TimestampNanosecond(Some(123000000000), None) - ); - - // TimestampSecond to TimestampMicrosecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampSecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Microsecond, None), - ) - .unwrap(); - assert_eq!( - new_scalar, - ScalarValue::TimestampMicrosecond(Some(123000000), None) - ); - - // TimestampSecond to TimestampMillisecond - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampSecond(Some(123), None), - &DataType::Timestamp(TimeUnit::Millisecond, None), - ) - .unwrap(); - assert_eq!( - new_scalar, - ScalarValue::TimestampMillisecond(Some(123000), None) - ); - - // overflow - let new_scalar = try_cast_literal_to_type( - &ScalarValue::TimestampSecond(Some(i64::MAX), None), - &DataType::Timestamp(TimeUnit::Millisecond, None), - ) - .unwrap(); - assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None)); - } - - #[test] - fn test_try_cast_to_string_type() { - let scalars = vec![ - ScalarValue::from("string"), - ScalarValue::LargeUtf8(Some("string".to_owned())), - ]; - - for s1 in &scalars { - for s2 in &scalars { - let expected_value = ExpectedCast::Value(s2.clone()); - - expect_cast(s1.clone(), s2.data_type(), expected_value); - } - } - } - #[test] - fn test_try_cast_to_dictionary_type() { - fn dictionary_type(t: DataType) -> DataType { - DataType::Dictionary(Box::new(DataType::Int32), Box::new(t)) - } - fn dictionary_value(value: ScalarValue) -> ScalarValue { - ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value)) - } - let scalars = vec![ - ScalarValue::from("string"), - ScalarValue::LargeUtf8(Some("string".to_owned())), - ]; - for s in &scalars { - expect_cast( - s.clone(), - dictionary_type(s.data_type()), - ExpectedCast::Value(dictionary_value(s.clone())), - ); - expect_cast( - dictionary_value(s.clone()), - s.data_type(), - ExpectedCast::Value(s.clone()), - ) - } - } } diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index cf182175e48ee..35e256f3064e3 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -17,11 +17,12 @@ //! Utility functions for expression simplification +use arrow::datatypes::i256; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::{ expr::{Between, BinaryExpr, InList}, expr_fn::{and, bitwise_and, bitwise_or, or}, - Expr, Like, Operator, + Case, Expr, Like, Operator, }; pub static POWS_OF_TEN: [i128; 38] = [ @@ -139,47 +140,59 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> pub fn is_zero(s: &Expr) -> bool { match s { - Expr::Literal(ScalarValue::Int8(Some(0))) - | Expr::Literal(ScalarValue::Int16(Some(0))) - | Expr::Literal(ScalarValue::Int32(Some(0))) - | Expr::Literal(ScalarValue::Int64(Some(0))) - | Expr::Literal(ScalarValue::UInt8(Some(0))) - | Expr::Literal(ScalarValue::UInt16(Some(0))) - | Expr::Literal(ScalarValue::UInt32(Some(0))) - | Expr::Literal(ScalarValue::UInt64(Some(0))) => true, - Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 0. => true, - Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 0. => true, - Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s)) if *v == 0 => true, + Expr::Literal(ScalarValue::Int8(Some(0)), _) + | Expr::Literal(ScalarValue::Int16(Some(0)), _) + | Expr::Literal(ScalarValue::Int32(Some(0)), _) + | Expr::Literal(ScalarValue::Int64(Some(0)), _) + | Expr::Literal(ScalarValue::UInt8(Some(0)), _) + | Expr::Literal(ScalarValue::UInt16(Some(0)), _) + | Expr::Literal(ScalarValue::UInt32(Some(0)), _) + | Expr::Literal(ScalarValue::UInt64(Some(0)), _) => true, + Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 0. => true, + Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 0. => true, + Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s), _) if *v == 0 => true, + Expr::Literal(ScalarValue::Decimal256(Some(v), _p, _s), _) + if *v == i256::ZERO => + { + true + } _ => false, } } pub fn is_one(s: &Expr) -> bool { match s { - Expr::Literal(ScalarValue::Int8(Some(1))) - | Expr::Literal(ScalarValue::Int16(Some(1))) - | Expr::Literal(ScalarValue::Int32(Some(1))) - | Expr::Literal(ScalarValue::Int64(Some(1))) - | Expr::Literal(ScalarValue::UInt8(Some(1))) - | Expr::Literal(ScalarValue::UInt16(Some(1))) - | Expr::Literal(ScalarValue::UInt32(Some(1))) - | Expr::Literal(ScalarValue::UInt64(Some(1))) => true, - Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 1. => true, - Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 1. => true, - Expr::Literal(ScalarValue::Decimal128(Some(v), _p, s)) => { + Expr::Literal(ScalarValue::Int8(Some(1)), _) + | Expr::Literal(ScalarValue::Int16(Some(1)), _) + | Expr::Literal(ScalarValue::Int32(Some(1)), _) + | Expr::Literal(ScalarValue::Int64(Some(1)), _) + | Expr::Literal(ScalarValue::UInt8(Some(1)), _) + | Expr::Literal(ScalarValue::UInt16(Some(1)), _) + | Expr::Literal(ScalarValue::UInt32(Some(1)), _) + | Expr::Literal(ScalarValue::UInt64(Some(1)), _) => true, + Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 1. => true, + Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 1. => true, + Expr::Literal(ScalarValue::Decimal128(Some(v), _p, s), _) => { *s >= 0 && POWS_OF_TEN .get(*s as usize) .map(|x| x == v) .unwrap_or_default() } + Expr::Literal(ScalarValue::Decimal256(Some(v), _p, s), _) => { + *s >= 0 + && match i256::from(10).checked_pow(*s as u32) { + Some(res) => res == *v, + None => false, + } + } _ => false, } } pub fn is_true(expr: &Expr) -> bool { match expr { - Expr::Literal(ScalarValue::Boolean(Some(v))) => *v, + Expr::Literal(ScalarValue::Boolean(Some(v)), _) => *v, _ => false, } } @@ -187,24 +200,24 @@ pub fn is_true(expr: &Expr) -> bool { /// returns true if expr is a /// `Expr::Literal(ScalarValue::Boolean(v))` , false otherwise pub fn is_bool_lit(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(ScalarValue::Boolean(_))) + matches!(expr, Expr::Literal(ScalarValue::Boolean(_), _)) } /// Return a literal NULL value of Boolean data type pub fn lit_bool_null() -> Expr { - Expr::Literal(ScalarValue::Boolean(None)) + Expr::Literal(ScalarValue::Boolean(None), None) } pub fn is_null(expr: &Expr) -> bool { match expr { - Expr::Literal(v) => v.is_null(), + Expr::Literal(v, _) => v.is_null(), _ => false, } } pub fn is_false(expr: &Expr) -> bool { match expr { - Expr::Literal(ScalarValue::Boolean(Some(v))) => !(*v), + Expr::Literal(ScalarValue::Boolean(Some(v)), _) => !(*v), _ => false, } } @@ -247,11 +260,36 @@ pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool { /// `Expr::Literal(ScalarValue::Boolean(v))`. pub fn as_bool_lit(expr: &Expr) -> Result> { match expr { - Expr::Literal(ScalarValue::Boolean(v)) => Ok(*v), + Expr::Literal(ScalarValue::Boolean(v), _) => Ok(*v), _ => internal_err!("Expected boolean literal, got {expr:?}"), } } +pub fn is_case_with_literal_outputs(expr: &Expr) -> bool { + match expr { + Expr::Case(Case { + expr: None, + when_then_expr, + else_expr, + }) => { + when_then_expr.iter().all(|(_, then)| is_lit(then)) + && else_expr.as_deref().is_none_or(is_lit) + } + _ => false, + } +} + +pub fn into_case(expr: Expr) -> Result { + match expr { + Expr::Case(case) => Ok(case), + _ => internal_err!("Expected case, got {expr:?}"), + } +} + +pub fn is_lit(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(_, _)) +} + /// negate a Not clause /// input is the clause to be negated.(args of Not clause) /// For BinaryExpr, use the negation of op instead. @@ -365,3 +403,78 @@ pub fn distribute_negation(expr: Expr) -> Expr { _ => Expr::Negative(Box::new(expr)), } } + +#[cfg(test)] +mod tests { + use super::{is_one, is_zero}; + use arrow::datatypes::i256; + use datafusion_common::ScalarValue; + use datafusion_expr::lit; + + #[test] + fn test_is_zero() { + assert!(is_zero(&lit(ScalarValue::Int8(Some(0))))); + assert!(is_zero(&lit(ScalarValue::Float32(Some(0.0))))); + assert!(is_zero(&lit(ScalarValue::Decimal128( + Some(i128::from(0)), + 9, + 0 + )))); + assert!(is_zero(&lit(ScalarValue::Decimal128( + Some(i128::from(0)), + 9, + 5 + )))); + assert!(is_zero(&lit(ScalarValue::Decimal256( + Some(i256::ZERO), + 9, + 0 + )))); + assert!(is_zero(&lit(ScalarValue::Decimal256( + Some(i256::ZERO), + 9, + 5 + )))); + } + + #[test] + fn test_is_one() { + assert!(is_one(&lit(ScalarValue::Int8(Some(1))))); + assert!(is_one(&lit(ScalarValue::Float32(Some(1.0))))); + assert!(is_one(&lit(ScalarValue::Decimal128( + Some(i128::from(1)), + 9, + 0 + )))); + assert!(is_one(&lit(ScalarValue::Decimal128( + Some(i128::from(10)), + 9, + 1 + )))); + assert!(is_one(&lit(ScalarValue::Decimal128( + Some(i128::from(100)), + 9, + 2 + )))); + assert!(is_one(&lit(ScalarValue::Decimal256( + Some(i256::from(1)), + 9, + 0 + )))); + assert!(is_one(&lit(ScalarValue::Decimal256( + Some(i256::from(10)), + 9, + 1 + )))); + assert!(is_one(&lit(ScalarValue::Decimal256( + Some(i256::from(100)), + 9, + 2 + )))); + assert!(!is_one(&lit(ScalarValue::Decimal256( + Some(i256::from(100)), + 9, + -1 + )))); + } +} diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 7337d2ffce5c3..e9a23c7c4dc50 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -79,7 +79,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { }, }) = expr { - if filter.is_some() || order_by.is_some() { + if filter.is_some() || !order_by.is_empty() { return Ok(false); } aggregate_count += 1; @@ -200,20 +200,20 @@ impl OptimizerRule for SingleDistinctToGroupBy { vec![col(SINGLE_DISTINCT_ALIAS)], false, // intentional to remove distinct here None, - None, + vec![], None, ))) // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation } else { index += 1; - let alias_str = format!("alias{}", index); + let alias_str = format!("alias{index}"); inner_aggr_exprs.push( Expr::AggregateFunction(AggregateFunction::new_udf( Arc::clone(&func), args, false, None, - None, + vec![], None, )) .alias(&alias_str), @@ -223,7 +223,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { vec![col(&alias_str)], false, None, - None, + vec![], None, ))) } @@ -280,6 +280,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_display_indent_snapshot; use crate::test::*; use datafusion_expr::expr::GroupingSet; use datafusion_expr::ExprFunctionExt; @@ -295,18 +296,23 @@ mod tests { vec![expr], true, None, - None, + vec![], None, )) } - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq_display_indent( - Arc::new(SingleDistinctToGroupBy::new()), - plan, - expected, - ); - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(SingleDistinctToGroupBy::new()); + assert_optimized_plan_eq_display_indent_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } #[test] @@ -318,11 +324,13 @@ mod tests { .build()?; // Do nothing - let expected = - "Aggregate: groupBy=[[]], aggr=[[max(test.b)]] [max(test.b):UInt32;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[max(test.b)]] [max(test.b):UInt32;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -334,12 +342,15 @@ mod tests { .build()?; // Should work - let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64] + Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64] + Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -357,10 +368,13 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -375,10 +389,13 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -394,10 +411,13 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -408,12 +428,15 @@ mod tests { .aggregate(Vec::::new(), vec![count_distinct(lit(2) * col("b"))])? .build()?; - let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]\ - \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ - \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64] + Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64] + Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -425,12 +448,15 @@ mod tests { .build()?; // Should work - let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64] + Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64] + Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -445,10 +471,13 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -461,13 +490,17 @@ mod tests { vec![count_distinct(col("b")), max_distinct(col("b"))], )? .build()?; - // Should work - let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), max(alias1)]] [a:UInt32, count(alias1):Int64, max(alias1):UInt32;N]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Should work + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N] + Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), max(alias1)]] [a:UInt32, count(alias1):Int64, max(alias1):UInt32;N] + Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -482,10 +515,13 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -497,12 +533,15 @@ mod tests { .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int64, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64] + Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int64, count(alias1):Int64] + Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -519,13 +558,17 @@ mod tests { ], )? .build()?; - // Should work - let expected = "Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), max(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, max(alias1):UInt32;N]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Should work + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N] + Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), max(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, max(alias1):UInt32;N] + Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -538,13 +581,17 @@ mod tests { vec![sum(col("c")), max(col("c")), count_distinct(col("b"))], )? .build()?; - // Should work - let expected = "Projection: test.a, sum(alias2) AS sum(test.c), max(alias3) AS max(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), max(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, max(alias3):UInt32;N, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, max(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Should work + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, sum(alias2) AS sum(test.c), max(alias3) AS max(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64] + Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), max(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, max(alias3):UInt32;N, count(alias1):Int64] + Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, max(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -557,13 +604,17 @@ mod tests { vec![min(col("a")), count_distinct(col("b"))], )? .build()?; - // Should work - let expected = "Projection: test.c, min(alias2) AS min(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[test.c]], aggr=[[min(alias2), count(alias1)]] [c:UInt32, min(alias2):UInt32;N, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[min(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Should work + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c, min(alias2) AS min(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64] + Aggregate: groupBy=[[test.c]], aggr=[[min(alias2), count(alias1)]] [c:UInt32, min(alias2):UInt32;N, count(alias1):Int64] + Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[min(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -576,17 +627,21 @@ mod tests { vec![col("a")], false, Some(Box::new(col("a").gt(lit(5)))), - None, + vec![], None, )); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? .build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -602,11 +657,15 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -619,17 +678,21 @@ mod tests { vec![col("a")], false, None, - Some(vec![col("a").sort(true, false)]), + vec![col("a").sort(true, false)], None, )); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? .build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a ASC NULLS LAST], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a ASC NULLS LAST]:UInt64;N, count(DISTINCT test.b):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a ASC NULLS LAST], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a ASC NULLS LAST]:UInt64;N, count(DISTINCT test.b):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -645,11 +708,15 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -666,10 +733,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 94d07a0791b3b..6e0b734bb9280 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -21,7 +21,7 @@ use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::{assert_contains, Result}; -use datafusion_expr::{col, logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder}; use std::sync::Arc; pub mod user_defined; @@ -64,15 +64,6 @@ pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { assert_eq!(actual, expected); } -pub fn test_subquery_with_name(name: &str) -> Result> { - let table_scan = test_table_scan_with_name(name)?; - Ok(Arc::new( - LogicalPlanBuilder::from(table_scan) - .project(vec![col("c")])? - .build()?, - )) -} - pub fn scan_tpch_table(table: &str) -> LogicalPlan { let schema = Arc::new(get_tpch_table_schema(table)); table_scan(Some(table), &schema, None) @@ -108,43 +99,20 @@ pub fn get_tpch_table_schema(table: &str) -> Schema { } } -pub fn assert_analyzed_plan_eq( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let options = ConfigOptions::default(); - assert_analyzed_plan_with_config_eq(options, rule, plan, expected)?; - - Ok(()) -} +#[macro_export] +macro_rules! assert_analyzed_plan_with_config_eq_snapshot { + ( + $options:expr, + $rule:expr, + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let analyzed_plan = $crate::Analyzer::with_rules(vec![$rule]).execute_and_check($plan, &$options, |_, _| {})?; -pub fn assert_analyzed_plan_with_config_eq( - options: ConfigOptions, - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let analyzed_plan = - Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; - let formatted_plan = format!("{analyzed_plan}"); - assert_eq!(formatted_plan, expected); + insta::assert_snapshot!(analyzed_plan, @ $expected); - Ok(()) -} - -pub fn assert_analyzed_plan_eq_display_indent( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let options = ConfigOptions::default(); - let analyzed_plan = - Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; - let formatted_plan = analyzed_plan.display_indent_schema().to_string(); - assert_eq!(formatted_plan, expected); - - Ok(()) + Ok::<(), datafusion_common::DataFusionError>(()) + }}; } pub fn assert_analyzer_check_err( @@ -165,27 +133,26 @@ pub fn assert_analyzer_check_err( fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} -pub fn assert_optimized_plan_eq( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - // Apply the rule once - let opt_context = OptimizerContext::new().with_max_passes(1); +#[macro_export] +macro_rules! assert_optimized_plan_eq_snapshot { + ( + $optimizer_context:expr, + $rules:expr, + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer = $crate::Optimizer::with_rules($rules); + let optimized_plan = optimizer.optimize($plan, &$optimizer_context, |_, _| {})?; + insta::assert_snapshot!(optimized_plan, @ $expected); - let optimizer = Optimizer::with_rules(vec![Arc::clone(&rule)]); - let optimized_plan = optimizer.optimize(plan, &opt_context, observe)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); - - Ok(()) + Ok::<(), datafusion_common::DataFusionError>(()) + }}; } fn generate_optimized_plan_with_rules( rules: Vec>, plan: LogicalPlan, ) -> LogicalPlan { - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} let config = &mut OptimizerContext::new() .with_max_passes(1) .with_skip_failing_rules(false); @@ -211,60 +178,20 @@ pub fn assert_optimized_plan_with_rules( Ok(()) } -pub fn assert_optimized_plan_eq_display_indent( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) { - let optimizer = Optimizer::with_rules(vec![rule]); - let optimized_plan = optimizer - .optimize(plan, &OptimizerContext::new(), observe) - .expect("failed to optimize plan"); - let formatted_plan = optimized_plan.display_indent_schema().to_string(); - assert_eq!(formatted_plan, expected); -} - -pub fn assert_multi_rules_optimized_plan_eq_display_indent( - rules: Vec>, - plan: LogicalPlan, - expected: &str, -) { - let optimizer = Optimizer::with_rules(rules); - let optimized_plan = optimizer - .optimize(plan, &OptimizerContext::new(), observe) - .expect("failed to optimize plan"); - let formatted_plan = optimized_plan.display_indent_schema().to_string(); - assert_eq!(formatted_plan, expected); -} - -pub fn assert_optimizer_err( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) { - let optimizer = Optimizer::with_rules(vec![rule]); - let res = optimizer.optimize(plan, &OptimizerContext::new(), observe); - match res { - Ok(plan) => assert_eq!(format!("{}", plan.display_indent()), "An error"), - Err(ref e) => { - let actual = format!("{e}"); - if expected.is_empty() || !actual.contains(expected) { - assert_eq!(actual, expected) - } - } - } -} - -pub fn assert_optimization_skipped( - rule: Arc, - plan: LogicalPlan, -) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![rule]); - let new_plan = optimizer.optimize(plan.clone(), &OptimizerContext::new(), observe)?; - - assert_eq!( - format!("{}", plan.display_indent()), - format!("{}", new_plan.display_indent()) - ); - Ok(()) +#[macro_export] +macro_rules! assert_optimized_plan_eq_display_indent_snapshot { + ( + $rule:expr, + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer = $crate::Optimizer::with_rules(vec![$rule]); + let optimized_plan = optimizer + .optimize($plan, &$crate::OptimizerContext::new(), |_, _| {}) + .expect("failed to optimize plan"); + let formatted_plan = optimized_plan.display_indent_schema(); + insta::assert_snapshot!(formatted_plan, @ $expected); + + Ok::<(), datafusion_common::DataFusionError>(()) + }}; } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index c734d908f6d6c..81763fa0552fb 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -79,45 +79,71 @@ pub fn is_restrict_null_predicate<'a>( return Ok(true); } + // If result is single `true`, return false; + // If result is single `NULL` or `false`, return true; + Ok( + match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? { + ColumnarValue::Array(array) => { + if array.len() == 1 { + let boolean_array = as_boolean_array(&array)?; + boolean_array.is_null(0) || !boolean_array.value(0) + } else { + false + } + } + ColumnarValue::Scalar(scalar) => matches!( + scalar, + ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false)) + ), + }, + ) +} + +/// Determines if an expression will always evaluate to null. +/// `c0 + 8` return true +/// `c0 IS NULL` return false +/// `CASE WHEN c0 > 1 then 0 else 1` return false +pub fn evaluates_to_null<'a>( + predicate: Expr, + null_columns: impl IntoIterator, +) -> Result { + if matches!(predicate, Expr::Column(_)) { + return Ok(true); + } + + Ok( + match evaluate_expr_with_null_column(predicate, null_columns)? { + ColumnarValue::Array(_) => false, + ColumnarValue::Scalar(scalar) => scalar.is_null(), + }, + ) +} + +fn evaluate_expr_with_null_column<'a>( + predicate: Expr, + null_columns: impl IntoIterator, +) -> Result { static DUMMY_COL_NAME: &str = "?"; - let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]); - let input_schema = DFSchema::try_from(schema.clone())?; + let schema = Arc::new(Schema::new(vec![Field::new( + DUMMY_COL_NAME, + DataType::Null, + true, + )])); + let input_schema = DFSchema::try_from(Arc::clone(&schema))?; let column = new_null_array(&DataType::Null, 1); - let input_batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?; + let input_batch = RecordBatch::try_new(schema, vec![column])?; let execution_props = ExecutionProps::default(); let null_column = Column::from_name(DUMMY_COL_NAME); - let join_cols_to_replace = join_cols_of_predicate + let join_cols_to_replace = null_columns .into_iter() .map(|column| (column, &null_column)) .collect::>(); let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?; let coerced_predicate = coerce(replaced_predicate, &input_schema)?; - let phys_expr = - create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?; - - let result_type = phys_expr.data_type(&schema)?; - if !matches!(&result_type, DataType::Boolean) { - return Ok(false); - } - - // If result is single `true`, return false; - // If result is single `NULL` or `false`, return true; - Ok(match phys_expr.evaluate(&input_batch)? { - ColumnarValue::Array(array) => { - if array.len() == 1 { - let boolean_array = as_boolean_array(&array)?; - boolean_array.is_null(0) || !boolean_array.value(0) - } else { - false - } - } - ColumnarValue::Scalar(scalar) => matches!( - scalar, - ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false)) - ), - }) + create_physical_expr(&coerced_predicate, &input_schema, &execution_props)? + .evaluate(&input_batch) } fn coerce(expr: Expr, schema: &DFSchema) -> Result { @@ -141,7 +167,11 @@ mod tests { (Expr::IsNotNull(Box::new(col("a"))), true), // a = NULL ( - binary_expr(col("a"), Operator::Eq, Expr::Literal(ScalarValue::Null)), + binary_expr( + col("a"), + Operator::Eq, + Expr::Literal(ScalarValue::Null, None), + ), true, ), // a > 8 @@ -204,12 +234,16 @@ mod tests { ), // a IN (NULL) ( - in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], false), + in_list( + col("a"), + vec![Expr::Literal(ScalarValue::Null, None)], + false, + ), true, ), // a NOT IN (NULL) ( - in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], true), + in_list(col("a"), vec![Expr::Literal(ScalarValue::Null, None)], true), true, ), ]; @@ -219,7 +253,7 @@ mod tests { let join_cols_of_predicate = std::iter::once(&column_a); let actual = is_restrict_null_predicate(predicate.clone(), join_cols_of_predicate)?; - assert_eq!(actual, expected, "{}", predicate); + assert_eq!(actual, expected, "{predicate}"); } Ok(()) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 941e5bd7b4d77..c0f48b8ebfc40 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -46,6 +46,48 @@ fn init() { let _ = env_logger::try_init(); } +#[test] +fn recursive_cte_with_nested_subquery() -> Result<()> { + // Covers bailout path in `plan_contains_other_subqueries`, ensuring nested subqueries + // within recursive CTE branches prevent projection pushdown. + let sql = r#" + WITH RECURSIVE numbers(id, level) AS ( + SELECT sub.id, sub.level FROM ( + SELECT col_int32 AS id, 1 AS level FROM test + ) sub + UNION ALL + SELECT t.col_int32, numbers.level + 1 + FROM test t + JOIN numbers ON t.col_int32 = numbers.id + 1 + ) + SELECT id, level FROM numbers + "#; + + let plan = test_sql(sql)?; + + assert_snapshot!( + format!("{plan}"), + @r#" + SubqueryAlias: numbers + Projection: sub.id AS id, sub.level AS level + RecursiveQuery: is_distinct=false + Projection: sub.id, sub.level + SubqueryAlias: sub + Projection: test.col_int32 AS id, Int64(1) AS level + TableScan: test + Projection: t.col_int32, numbers.level + Int64(1) + Inner Join: CAST(t.col_int32 AS Int64) = CAST(numbers.id AS Int64) + Int64(1) + SubqueryAlias: t + Filter: CAST(test.col_int32 AS Int64) IS NOT NULL + TableScan: test + Filter: CAST(numbers.id AS Int64) + Int64(1) IS NOT NULL + TableScan: numbers + "# + ); + + Ok(()) +} + #[test] fn case_when() -> Result<()> { let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test"; @@ -250,7 +292,7 @@ fn between_date32_plus_interval() -> Result<()> { format!("{plan}"), @r#" Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] - Projection: + Projection: Filter: test.col_date32 >= Date32("1998-03-18") AND test.col_date32 <= Date32("1998-06-16") TableScan: test projection=[col_date32] "# @@ -268,7 +310,7 @@ fn between_date64_plus_interval() -> Result<()> { format!("{plan}"), @r#" Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] - Projection: + Projection: Filter: test.col_date64 >= Date64("1998-03-18") AND test.col_date64 <= Date64("1998-06-16") TableScan: test projection=[col_date64] "# @@ -284,9 +326,7 @@ fn propagate_empty_relation() { assert_snapshot!( format!("{plan}"), - @r#" - EmptyRelation - "# + @"EmptyRelation: rows=0" ); } @@ -480,6 +520,117 @@ fn select_correlated_predicate_subquery_with_uppercase_ident() { ); } +#[test] +fn recursive_cte_projection_pushdown() -> Result<()> { + // Test that projection pushdown works with recursive CTEs by ensuring + // only the required columns are projected from the base table, even when + // the CTE definition includes unused columns + let sql = "WITH RECURSIVE nodes AS (\ + SELECT col_int32 AS id, col_utf8 AS name, col_uint32 AS extra FROM test \ + UNION ALL \ + SELECT id + 1, name, extra FROM nodes WHERE id < 3\ + ) SELECT id FROM nodes"; + let plan = test_sql(sql)?; + + // The optimizer successfully performs projection pushdown by only selecting the needed + // columns from the base table and recursive table, eliminating unused columns + assert_snapshot!( + format!("{plan}"), + @r#"SubqueryAlias: nodes + RecursiveQuery: is_distinct=false + Projection: test.col_int32 AS id + TableScan: test projection=[col_int32] + Projection: CAST(CAST(nodes.id AS Int64) + Int64(1) AS Int32) + Filter: nodes.id < Int32(3) + TableScan: nodes projection=[id] +"# + ); + Ok(()) +} + +#[test] +fn recursive_cte_with_aliased_self_reference() -> Result<()> { + let sql = "WITH RECURSIVE nodes AS (\ + SELECT col_int32 AS id, col_utf8 AS name FROM test \ + UNION ALL \ + SELECT child.id + 1, child.name FROM nodes AS child WHERE child.id < 3\ + ) SELECT id FROM nodes"; + let plan = test_sql(sql)?; + + assert_snapshot!( + format!("{plan}"), + @r#"SubqueryAlias: nodes + RecursiveQuery: is_distinct=false + Projection: test.col_int32 AS id + TableScan: test projection=[col_int32] + Projection: CAST(CAST(child.id AS Int64) + Int64(1) AS Int32) + SubqueryAlias: child + Filter: nodes.id < Int32(3) + TableScan: nodes projection=[id]"#, + ); + Ok(()) +} + +#[test] +fn recursive_cte_with_unused_columns() -> Result<()> { + // Test projection pushdown with a recursive CTE where the base case + // includes columns that are never used in the recursive part or final result + let sql = "WITH RECURSIVE series AS (\ + SELECT 1 AS n, col_utf8, col_uint32, col_date32 FROM test WHERE col_int32 = 1 \ + UNION ALL \ + SELECT n + 1, col_utf8, col_uint32, col_date32 FROM series WHERE n < 3\ + ) SELECT n FROM series"; + let plan = test_sql(sql)?; + + // The optimizer successfully performs projection pushdown by eliminating unused columns + // even when they're defined in the CTE but not actually needed + assert_snapshot!( + format!("{plan}"), + @r#"SubqueryAlias: series + RecursiveQuery: is_distinct=false + Projection: Int64(1) AS n + Filter: test.col_int32 = Int32(1) + TableScan: test projection=[col_int32] + Projection: series.n + Int64(1) + Filter: series.n < Int64(3) + TableScan: series projection=[n] +"# + ); + Ok(()) +} + +#[test] +/// Asserts the minimal plan shape once projection pushdown succeeds for a recursive CTE. +/// Unlike the previous two tests that retain extra columns in either the base or recursive +/// branches, this baseline shows the optimizer trimming everything down to the single +/// column required by the final projection. +fn recursive_cte_projection_pushdown_baseline() -> Result<()> { + // Test case that truly demonstrates projection pushdown working: + // The base case only selects needed columns + let sql = "WITH RECURSIVE countdown AS (\ + SELECT col_int32 AS n FROM test WHERE col_int32 = 5 \ + UNION ALL \ + SELECT n - 1 FROM countdown WHERE n > 1\ + ) SELECT n FROM countdown"; + let plan = test_sql(sql)?; + + // This demonstrates optimal projection pushdown where only col_int32 is projected from the base table, + // and only the needed column is selected from the recursive table + assert_snapshot!( + format!("{plan}"), + @r#"SubqueryAlias: countdown + RecursiveQuery: is_distinct=false + Projection: test.col_int32 AS n + Filter: test.col_int32 = Int32(5) + TableScan: test projection=[col_int32] + Projection: CAST(CAST(countdown.n AS Int64) - Int64(1) AS Int32) + Filter: countdown.n > Int32(1) + TableScan: countdown projection=[n] +"# + ); + Ok(()) +} + fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... @@ -492,7 +643,32 @@ fn test_sql(sql: &str) -> Result { .with_expr_planners(vec![ Arc::new(AggregateFunctionPlanner), Arc::new(WindowFunctionPlanner), - ]); + ]) + .with_schema( + "test", + Schema::new_with_metadata( + vec![ + Field::new("col_int32", DataType::Int32, true), + Field::new("col_uint32", DataType::UInt32, true), + Field::new("col_utf8", DataType::Utf8, true), + Field::new("col_date32", DataType::Date32, true), + Field::new("col_date64", DataType::Date64, true), + // timestamp with no timezone + Field::new( + "col_ts_nano_none", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + // timestamp with UTC timezone + Field::new( + "col_ts_nano_utc", + DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), + true, + ), + ], + HashMap::new(), + ), + ); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone())?; @@ -501,7 +677,7 @@ fn test_sql(sql: &str) -> Result { let analyzer = Analyzer::new(); let optimizer = Optimizer::new(); // analyze and optimize the logical plan - let plan = analyzer.execute_and_check(plan, config.options(), |_, _| {})?; + let plan = analyzer.execute_and_check(plan, &config.options(), |_, _| {})?; optimizer.optimize(plan, &config, observe) } @@ -510,6 +686,7 @@ fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} #[derive(Default)] struct MyContextProvider { options: ConfigOptions, + tables: HashMap>, udafs: HashMap>, expr_planners: Vec>, } @@ -525,38 +702,23 @@ impl MyContextProvider { self.expr_planners = expr_planners; self } + + fn with_schema(mut self, name: impl Into, schema: Schema) -> Self { + self.tables.insert( + name.into(), + Arc::new(MyTableSource { + schema: Arc::new(schema), + }), + ); + self + } } impl ContextProvider for MyContextProvider { fn get_table_source(&self, name: TableReference) -> Result> { let table_name = name.table(); - if table_name.starts_with("test") { - let schema = Schema::new_with_metadata( - vec![ - Field::new("col_int32", DataType::Int32, true), - Field::new("col_uint32", DataType::UInt32, true), - Field::new("col_utf8", DataType::Utf8, true), - Field::new("col_date32", DataType::Date32, true), - Field::new("col_date64", DataType::Date64, true), - // timestamp with no timezone - Field::new( - "col_ts_nano_none", - DataType::Timestamp(TimeUnit::Nanosecond, None), - true, - ), - // timestamp with UTC timezone - Field::new( - "col_ts_nano_utc", - DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), - true, - ), - ], - HashMap::new(), - ); - - Ok(Arc::new(MyTableSource { - schema: Arc::new(schema), - })) + if let Some(table) = self.tables.get(table_name) { + Ok(table.clone()) } else { plan_err!("table does not exist") } @@ -578,6 +740,14 @@ impl ContextProvider for MyContextProvider { None } + fn create_cte_work_table( + &self, + _name: &str, + schema: SchemaRef, + ) -> Result> { + Ok(Arc::new(MyTableSource { schema })) + } + fn options(&self) -> &ConfigOptions { &self.options } diff --git a/datafusion/physical-expr-adapter/Cargo.toml b/datafusion/physical-expr-adapter/Cargo.toml new file mode 100644 index 0000000000000..03e1b1f06578d --- /dev/null +++ b/datafusion/physical-expr-adapter/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "datafusion-physical-expr-adapter" +description = "Physical expression schema adaptation utilities for DataFusion" +keywords = ["datafusion", "query", "sql"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "datafusion_physical_expr_adapter" +path = "src/lib.rs" + +[dependencies] +arrow = { workspace = true } +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true } +datafusion-physical-expr = { workspace = true } +datafusion-physical-expr-common = { workspace = true } +itertools = { workspace = true } + +[dev-dependencies] diff --git a/datafusion/physical-expr-adapter/README.md b/datafusion/physical-expr-adapter/README.md new file mode 100644 index 0000000000000..02bc144c16f34 --- /dev/null +++ b/datafusion/physical-expr-adapter/README.md @@ -0,0 +1,38 @@ + + +# Apache DataFusion Physical Expression Adapter + +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. + +This crate provides utilities for adapting physical expressions to different schemas in DataFusion. + +It handles schema differences in file scans by rewriting expressions to match the physical schema, +including type casting, missing columns, and partition values. + +For detailed documentation, see the [`PhysicalExprAdapter`] trait documentation. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion +[`physicalexpradapter`]: https://docs.rs/datafusion/latest/datafusion/physical_expr_adapter/trait.PhysicalExprAdapter.html diff --git a/datafusion/physical-expr-adapter/src/lib.rs b/datafusion/physical-expr-adapter/src/lib.rs new file mode 100644 index 0000000000000..12ea0025e2667 --- /dev/null +++ b/datafusion/physical-expr-adapter/src/lib.rs @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#![doc( + html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", + html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" +)] +#![cfg_attr(docsrs, feature(doc_cfg))] + +//! Physical expression schema adaptation utilities for DataFusion + +pub mod schema_rewriter; + +pub use schema_rewriter::{ + DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, + PhysicalExprAdapterFactory, +}; diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs new file mode 100644 index 0000000000000..61cc97dae300e --- /dev/null +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -0,0 +1,862 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical expression schema rewriting utilities + +use std::sync::Arc; + +use arrow::compute::can_cast_types; +use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; +use datafusion_common::{ + exec_err, + tree_node::{Transformed, TransformedResult, TreeNode}, + Result, ScalarValue, +}; +use datafusion_functions::core::getfield::GetFieldFunc; +use datafusion_physical_expr::{ + expressions::{self, CastExpr, Column}, + ScalarFunctionExpr, +}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +/// Trait for adapting physical expressions to match a target schema. +/// +/// This is used in file scans to rewrite expressions so that they can be evaluated +/// against the physical schema of the file being scanned. It allows for handling +/// differences between logical and physical schemas, such as type mismatches or missing columns. +/// +/// ## Overview +/// +/// The `PhysicalExprAdapter` allows rewriting physical expressions to match different schemas, including: +/// +/// - **Type casting**: When logical and physical schemas have different types, expressions are +/// automatically wrapped with cast operations. For example, `lit(ScalarValue::Int32(123)) = int64_column` +/// gets rewritten to `lit(ScalarValue::Int32(123)) = cast(int64_column, 'Int32')`. +/// Note that this does not attempt to simplify such expressions - that is done by shared simplifiers. +/// +/// - **Missing columns**: When a column exists in the logical schema but not in the physical schema, +/// references to it are replaced with null literals. +/// +/// - **Struct field access**: Expressions like `struct_column.field_that_is_missing_in_schema` are +/// rewritten to `null` when the field doesn't exist in the physical schema. +/// +/// - **Partition columns**: Partition column references can be replaced with their literal values +/// when scanning specific partitions. +/// +/// ## Custom Implementations +/// +/// You can create a custom implementation of this trait to handle specific rewriting logic. +/// For example, to fill in missing columns with default values instead of nulls: +/// +/// ```rust +/// use datafusion_physical_expr_adapter::{PhysicalExprAdapter, PhysicalExprAdapterFactory}; +/// use arrow::datatypes::{Schema, Field, DataType, FieldRef, SchemaRef}; +/// use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +/// use datafusion_common::{Result, ScalarValue, tree_node::{Transformed, TransformedResult, TreeNode}}; +/// use datafusion_physical_expr::expressions::{self, Column}; +/// use std::sync::Arc; +/// +/// #[derive(Debug)] +/// pub struct CustomPhysicalExprAdapter { +/// logical_file_schema: SchemaRef, +/// physical_file_schema: SchemaRef, +/// } +/// +/// impl PhysicalExprAdapter for CustomPhysicalExprAdapter { +/// fn rewrite(&self, expr: Arc) -> Result> { +/// expr.transform(|expr| { +/// if let Some(column) = expr.as_any().downcast_ref::() { +/// // Check if the column exists in the physical schema +/// if self.physical_file_schema.index_of(column.name()).is_err() { +/// // If the column is missing, fill it with a default value instead of null +/// // The default value could be stored in the table schema's column metadata for example. +/// let default_value = ScalarValue::Int32(Some(0)); +/// return Ok(Transformed::yes(expressions::lit(default_value))); +/// } +/// } +/// // If the column exists, return it as is +/// Ok(Transformed::no(expr)) +/// }).data() +/// } +/// +/// fn with_partition_values( +/// &self, +/// partition_values: Vec<(FieldRef, ScalarValue)>, +/// ) -> Arc { +/// // For simplicity, this example ignores partition values +/// Arc::new(CustomPhysicalExprAdapter { +/// logical_file_schema: self.logical_file_schema.clone(), +/// physical_file_schema: self.physical_file_schema.clone(), +/// }) +/// } +/// } +/// +/// #[derive(Debug)] +/// pub struct CustomPhysicalExprAdapterFactory; +/// +/// impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory { +/// fn create( +/// &self, +/// logical_file_schema: SchemaRef, +/// physical_file_schema: SchemaRef, +/// ) -> Arc { +/// Arc::new(CustomPhysicalExprAdapter { +/// logical_file_schema, +/// physical_file_schema, +/// }) +/// } +/// } +/// ``` +pub trait PhysicalExprAdapter: Send + Sync + std::fmt::Debug { + /// Rewrite a physical expression to match the target schema. + /// + /// This method should return a transformed expression that matches the target schema. + /// + /// Arguments: + /// - `expr`: The physical expression to rewrite. + /// - `logical_file_schema`: The logical schema of the table being queried, excluding any partition columns. + /// - `physical_file_schema`: The physical schema of the file being scanned. + /// - `partition_values`: Optional partition values to use for rewriting partition column references. + /// These are handled as if they were columns appended onto the logical file schema. + /// + /// Returns: + /// - `Arc`: The rewritten physical expression that can be evaluated against the physical schema. + fn rewrite(&self, expr: Arc) -> Result>; + + fn with_partition_values( + &self, + partition_values: Vec<(FieldRef, ScalarValue)>, + ) -> Arc; +} + +pub trait PhysicalExprAdapterFactory: Send + Sync + std::fmt::Debug { + /// Create a new instance of the physical expression adapter. + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Arc; +} + +#[derive(Debug, Clone)] +pub struct DefaultPhysicalExprAdapterFactory; + +impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Arc { + Arc::new(DefaultPhysicalExprAdapter { + logical_file_schema, + physical_file_schema, + partition_values: Vec::new(), + }) + } +} + +/// Default implementation for rewriting physical expressions to match different schemas. +/// +/// # Example +/// +/// ```rust +/// use datafusion_physical_expr_adapter::{DefaultPhysicalExprAdapterFactory, PhysicalExprAdapterFactory}; +/// use arrow::datatypes::Schema; +/// use std::sync::Arc; +/// +/// # fn example( +/// # predicate: std::sync::Arc, +/// # physical_file_schema: &Schema, +/// # logical_file_schema: &Schema, +/// # ) -> datafusion_common::Result<()> { +/// let factory = DefaultPhysicalExprAdapterFactory; +/// let adapter = factory.create(Arc::new(logical_file_schema.clone()), Arc::new(physical_file_schema.clone())); +/// let adapted_predicate = adapter.rewrite(predicate)?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, Clone)] +pub struct DefaultPhysicalExprAdapter { + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + partition_values: Vec<(FieldRef, ScalarValue)>, +} + +impl DefaultPhysicalExprAdapter { + /// Create a new instance of the default physical expression adapter. + /// + /// This adapter rewrites expressions to match the physical schema of the file being scanned, + /// handling type mismatches and missing columns by filling them with default values. + pub fn new(logical_file_schema: SchemaRef, physical_file_schema: SchemaRef) -> Self { + Self { + logical_file_schema, + physical_file_schema, + partition_values: Vec::new(), + } + } +} + +impl PhysicalExprAdapter for DefaultPhysicalExprAdapter { + fn rewrite(&self, expr: Arc) -> Result> { + let rewriter = DefaultPhysicalExprAdapterRewriter { + logical_file_schema: &self.logical_file_schema, + physical_file_schema: &self.physical_file_schema, + partition_fields: &self.partition_values, + }; + expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr))) + .data() + } + + fn with_partition_values( + &self, + partition_values: Vec<(FieldRef, ScalarValue)>, + ) -> Arc { + Arc::new(DefaultPhysicalExprAdapter { + partition_values, + ..self.clone() + }) + } +} + +struct DefaultPhysicalExprAdapterRewriter<'a> { + logical_file_schema: &'a Schema, + physical_file_schema: &'a Schema, + partition_fields: &'a [(FieldRef, ScalarValue)], +} + +impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { + fn rewrite_expr( + &self, + expr: Arc, + ) -> Result>> { + if let Some(transformed) = self.try_rewrite_struct_field_access(&expr)? { + return Ok(Transformed::yes(transformed)); + } + + if let Some(column) = expr.as_any().downcast_ref::() { + return self.rewrite_column(Arc::clone(&expr), column); + } + + Ok(Transformed::no(expr)) + } + + /// Attempt to rewrite struct field access expressions to return null if the field does not exist in the physical schema. + /// Note that this does *not* handle nested struct fields, only top-level struct field access. + /// See for more details. + fn try_rewrite_struct_field_access( + &self, + expr: &Arc, + ) -> Result>> { + let get_field_expr = + match ScalarFunctionExpr::try_downcast_func::(expr.as_ref()) { + Some(expr) => expr, + None => return Ok(None), + }; + + let source_expr = match get_field_expr.args().first() { + Some(expr) => expr, + None => return Ok(None), + }; + + let field_name_expr = match get_field_expr.args().get(1) { + Some(expr) => expr, + None => return Ok(None), + }; + + let lit = match field_name_expr + .as_any() + .downcast_ref::() + { + Some(lit) => lit, + None => return Ok(None), + }; + + let field_name = match lit.value().try_as_str().flatten() { + Some(name) => name, + None => return Ok(None), + }; + + let column = match source_expr.as_any().downcast_ref::() { + Some(column) => column, + None => return Ok(None), + }; + + let physical_field = + match self.physical_file_schema.field_with_name(column.name()) { + Ok(field) => field, + Err(_) => return Ok(None), + }; + + let physical_struct_fields = match physical_field.data_type() { + DataType::Struct(fields) => fields, + _ => return Ok(None), + }; + + if physical_struct_fields + .iter() + .any(|f| f.name() == field_name) + { + return Ok(None); + } + + let logical_field = match self.logical_file_schema.field_with_name(column.name()) + { + Ok(field) => field, + Err(_) => return Ok(None), + }; + + let logical_struct_fields = match logical_field.data_type() { + DataType::Struct(fields) => fields, + _ => return Ok(None), + }; + + let logical_struct_field = match logical_struct_fields + .iter() + .find(|f| f.name() == field_name) + { + Some(field) => field, + None => return Ok(None), + }; + + let null_value = ScalarValue::Null.cast_to(logical_struct_field.data_type())?; + Ok(Some(expressions::lit(null_value))) + } + + fn rewrite_column( + &self, + expr: Arc, + column: &Column, + ) -> Result>> { + // Get the logical field for this column if it exists in the logical schema + let logical_field = match self.logical_file_schema.field_with_name(column.name()) + { + Ok(field) => field, + Err(e) => { + // If the column is a partition field, we can use the partition value + if let Some(partition_value) = self.get_partition_value(column.name()) { + return Ok(Transformed::yes(expressions::lit(partition_value))); + } + // This can be hit if a custom rewrite injected a reference to a column that doesn't exist in the logical schema. + // For example, a pre-computed column that is kept only in the physical schema. + // If the column exists in the physical schema, we can still use it. + if let Ok(physical_field) = + self.physical_file_schema.field_with_name(column.name()) + { + // If the column exists in the physical schema, we can use it in place of the logical column. + // This is nice to users because if they do a rewrite that results in something like `physical_int32_col = 123u64` + // we'll at least handle the casts for them. + physical_field + } else { + // A completely unknown column that doesn't exist in either schema! + // This should probably never be hit unless something upstream broke, but nonetheless it's better + // for us to return a handleable error than to panic / do something unexpected. + return Err(e.into()); + } + } + }; + + // Check if the column exists in the physical schema + let physical_column_index = + match self.physical_file_schema.index_of(column.name()) { + Ok(index) => index, + Err(_) => { + if !logical_field.is_nullable() { + return exec_err!( + "Non-nullable column '{}' is missing from the physical schema", + column.name() + ); + } + // If the column is missing from the physical schema fill it in with nulls as `SchemaAdapter` would do. + // TODO: do we need to sync this with what the `SchemaAdapter` actually does? + // While the default implementation fills in nulls in theory a custom `SchemaAdapter` could do something else! + // See https://github.com/apache/datafusion/issues/16527 + let null_value = + ScalarValue::Null.cast_to(logical_field.data_type())?; + return Ok(Transformed::yes(expressions::lit(null_value))); + } + }; + let physical_field = self.physical_file_schema.field(physical_column_index); + + let column = match ( + column.index() == physical_column_index, + logical_field.data_type() == physical_field.data_type(), + ) { + // If the column index matches and the data types match, we can use the column as is + (true, true) => return Ok(Transformed::no(expr)), + // If the indexes or data types do not match, we need to create a new column expression + (true, _) => column.clone(), + (false, _) => { + Column::new_with_schema(logical_field.name(), self.physical_file_schema)? + } + }; + + if logical_field.data_type() == physical_field.data_type() { + // If the data types match, we can use the column as is + return Ok(Transformed::yes(Arc::new(column))); + } + + // We need to cast the column to the logical data type + // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123` + // since that's much cheaper to evalaute. + // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928 + let is_compatible = + can_cast_types(physical_field.data_type(), logical_field.data_type()); + if !is_compatible { + return exec_err!( + "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)", + column.name(), + physical_field.data_type(), + logical_field.data_type() + ); + } + + let cast_expr = Arc::new(CastExpr::new( + Arc::new(column), + logical_field.data_type().clone(), + None, + )); + + Ok(Transformed::yes(cast_expr)) + } + + fn get_partition_value(&self, column_name: &str) -> Option { + self.partition_fields + .iter() + .find(|(field, _)| field.name() == column_name) + .map(|(_, value)| value.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{RecordBatch, RecordBatchOptions}; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::{assert_contains, record_batch, Result, ScalarValue}; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{col, lit, CastExpr, Column, Literal}; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use itertools::Itertools; + use std::sync::Arc; + + fn create_test_schema() -> (Schema, Schema) { + let physical_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ]); + + let logical_schema = Schema::new(vec![ + Field::new("a", DataType::Int64, false), // Different type + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Float64, true), // Missing from physical + ]); + + (physical_schema, logical_schema) + } + + #[test] + fn test_rewrite_column_with_type_cast() { + let (physical_schema, logical_schema) = create_test_schema(); + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let column_expr = Arc::new(Column::new("a", 0)); + + let result = adapter.rewrite(column_expr).unwrap(); + + // Should be wrapped in a cast expression + assert!(result.as_any().downcast_ref::().is_some()); + } + + #[test] + fn test_rewrite_multi_column_expr_with_type_cast() { + let (physical_schema, logical_schema) = create_test_schema(); + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + + // Create a complex expression: (a + 5) OR (c > 0.0) that tests the recursive case of the rewriter + let column_a = Arc::new(Column::new("a", 0)) as Arc; + let column_c = Arc::new(Column::new("c", 2)) as Arc; + let expr = expressions::BinaryExpr::new( + Arc::clone(&column_a), + Operator::Plus, + Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))), + ); + let expr = expressions::BinaryExpr::new( + Arc::new(expr), + Operator::Or, + Arc::new(expressions::BinaryExpr::new( + Arc::clone(&column_c), + Operator::Gt, + Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))), + )), + ); + + let result = adapter.rewrite(Arc::new(expr)).unwrap(); + println!("Rewritten expression: {result}"); + + let expected = expressions::BinaryExpr::new( + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 0)), + DataType::Int64, + None, + )), + Operator::Plus, + Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))), + ); + let expected = Arc::new(expressions::BinaryExpr::new( + Arc::new(expected), + Operator::Or, + Arc::new(expressions::BinaryExpr::new( + lit(ScalarValue::Float64(None)), // c is missing, so it becomes null + Operator::Gt, + Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))), + )), + )) as Arc; + + assert_eq!( + result.to_string(), + expected.to_string(), + "The rewritten expression did not match the expected output" + ); + } + + #[test] + fn test_rewrite_struct_column_incompatible() { + let physical_schema = Schema::new(vec![Field::new( + "data", + DataType::Struct(vec![Field::new("field1", DataType::Binary, true)].into()), + true, + )]); + + let logical_schema = Schema::new(vec![Field::new( + "data", + DataType::Struct(vec![Field::new("field1", DataType::Int32, true)].into()), + true, + )]); + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let column_expr = Arc::new(Column::new("data", 0)); + + let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string(); + assert_contains!(error_msg, "Cannot cast column 'data'"); + } + + #[test] + fn test_rewrite_struct_compatible_cast() { + let physical_schema = Schema::new(vec![Field::new( + "data", + DataType::Struct( + vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ] + .into(), + ), + false, + )]); + + let logical_schema = Schema::new(vec![Field::new( + "data", + DataType::Struct( + vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8View, true), + ] + .into(), + ), + false, + )]); + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let column_expr = Arc::new(Column::new("data", 0)); + + let result = adapter.rewrite(column_expr).unwrap(); + + let expected = Arc::new(CastExpr::new( + Arc::new(Column::new("data", 0)), + DataType::Struct( + vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8View, true), + ] + .into(), + ), + None, + )) as Arc; + + assert_eq!(result.to_string(), expected.to_string()); + } + + #[test] + fn test_rewrite_missing_column() -> Result<()> { + let (physical_schema, logical_schema) = create_test_schema(); + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let column_expr = Arc::new(Column::new("c", 2)); + + let result = adapter.rewrite(column_expr)?; + + // Should be replaced with a literal null + if let Some(literal) = result.as_any().downcast_ref::() { + assert_eq!(*literal.value(), ScalarValue::Float64(None)); + } else { + panic!("Expected literal expression"); + } + + Ok(()) + } + + #[test] + fn test_rewrite_missing_column_non_nullable_error() { + let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let logical_schema = Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, false), // Missing and non-nullable + ]); + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let column_expr = Arc::new(Column::new("b", 1)); + + let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string(); + assert_contains!(error_msg, "Non-nullable column 'b' is missing"); + } + + #[test] + fn test_rewrite_missing_column_nullable() { + let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let logical_schema = Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, true), // Missing but nullable + ]); + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let column_expr = Arc::new(Column::new("b", 1)); + + let result = adapter.rewrite(column_expr).unwrap(); + + let expected = + Arc::new(Literal::new(ScalarValue::Utf8(None))) as Arc; + + assert_eq!(result.to_string(), expected.to_string()); + } + + #[test] + fn test_rewrite_partition_column() -> Result<()> { + let (physical_schema, logical_schema) = create_test_schema(); + + let partition_field = + Arc::new(Field::new("partition_col", DataType::Utf8, false)); + let partition_value = ScalarValue::Utf8(Some("test_value".to_string())); + let partition_values = vec![(partition_field, partition_value)]; + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = adapter.with_partition_values(partition_values); + + let column_expr = Arc::new(Column::new("partition_col", 0)); + let result = adapter.rewrite(column_expr)?; + + // Should be replaced with the partition value + if let Some(literal) = result.as_any().downcast_ref::() { + assert_eq!( + *literal.value(), + ScalarValue::Utf8(Some("test_value".to_string())) + ); + } else { + panic!("Expected literal expression"); + } + + Ok(()) + } + + #[test] + fn test_rewrite_no_change_needed() -> Result<()> { + let (physical_schema, logical_schema) = create_test_schema(); + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let column_expr = Arc::new(Column::new("b", 1)) as Arc; + + let result = adapter.rewrite(Arc::clone(&column_expr))?; + + // Should be the same expression (no transformation needed) + // We compare the underlying pointer through the trait object + assert!(std::ptr::eq( + column_expr.as_ref() as *const dyn PhysicalExpr, + result.as_ref() as *const dyn PhysicalExpr + )); + + Ok(()) + } + + #[test] + fn test_non_nullable_missing_column_error() { + let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let logical_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), // Non-nullable missing column + ]); + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let column_expr = Arc::new(Column::new("b", 1)); + + let result = adapter.rewrite(column_expr); + assert!(result.is_err()); + assert_contains!( + result.unwrap_err().to_string(), + "Non-nullable column 'b' is missing from the physical schema" + ); + } + + /// Helper function to project expressions onto a RecordBatch + fn batch_project( + expr: Vec>, + batch: &RecordBatch, + schema: SchemaRef, + ) -> Result { + let arrays = expr + .iter() + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) + .collect::>>()?; + + if arrays.is_empty() { + let options = + RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + RecordBatch::try_new_with_options(Arc::clone(&schema), arrays, &options) + .map_err(Into::into) + } else { + RecordBatch::try_new(Arc::clone(&schema), arrays).map_err(Into::into) + } + } + + /// Example showing how we can use the `DefaultPhysicalExprAdapter` to adapt RecordBatches during a scan + /// to apply projections, type conversions and handling of missing columns all at once. + #[test] + fn test_adapt_batches() { + let physical_batch = record_batch!( + ("a", Int32, vec![Some(1), None, Some(3)]), + ("extra", Utf8, vec![Some("x"), Some("y"), None]) + ) + .unwrap(); + + let physical_schema = physical_batch.schema(); + + let logical_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), // Different type + Field::new("b", DataType::Utf8, true), // Missing from physical + ])); + + let projection = vec![ + col("b", &logical_schema).unwrap(), + col("a", &logical_schema).unwrap(), + ]; + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = + factory.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema)); + + let adapted_projection = projection + .into_iter() + .map(|expr| adapter.rewrite(expr).unwrap()) + .collect_vec(); + + let adapted_schema = Arc::new(Schema::new( + adapted_projection + .iter() + .map(|expr| expr.return_field(&physical_schema).unwrap()) + .collect_vec(), + )); + + let res = batch_project( + adapted_projection, + &physical_batch, + Arc::clone(&adapted_schema), + ) + .unwrap(); + + assert_eq!(res.num_columns(), 2); + assert_eq!(res.column(0).data_type(), &DataType::Utf8); + assert_eq!(res.column(1).data_type(), &DataType::Int64); + assert_eq!( + res.column(0) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect_vec(), + vec![None, None, None] + ); + assert_eq!( + res.column(1) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect_vec(), + vec![Some(1), None, Some(3)] + ); + } + + #[test] + fn test_try_rewrite_struct_field_access() { + // Test the core logic of try_rewrite_struct_field_access + let physical_schema = Schema::new(vec![Field::new( + "struct_col", + DataType::Struct( + vec![Field::new("existing_field", DataType::Int32, true)].into(), + ), + true, + )]); + + let logical_schema = Schema::new(vec![Field::new( + "struct_col", + DataType::Struct( + vec![ + Field::new("existing_field", DataType::Int32, true), + Field::new("missing_field", DataType::Utf8, true), + ] + .into(), + ), + true, + )]); + + let rewriter = DefaultPhysicalExprAdapterRewriter { + logical_file_schema: &logical_schema, + physical_file_schema: &physical_schema, + partition_fields: &[], + }; + + // Test that when a field exists in physical schema, it returns None + let column = Arc::new(Column::new("struct_col", 0)) as Arc; + let result = rewriter.try_rewrite_struct_field_access(&column).unwrap(); + assert!(result.is_none()); + + // The actual test for the get_field expression would require creating a proper ScalarFunctionExpr + // with ScalarUDF, which is complex to set up in a unit test. The integration tests in + // datafusion/core/tests/parquet/schema_adapter.rs provide better coverage for this functionality. + } +} diff --git a/datafusion/physical-expr-common/Cargo.toml b/datafusion/physical-expr-common/Cargo.toml index a5a12b5527b7d..58dc767dbad2a 100644 --- a/datafusion/physical-expr-common/Cargo.toml +++ b/datafusion/physical-expr-common/Cargo.toml @@ -40,7 +40,7 @@ name = "datafusion_physical_expr_common" [dependencies] ahash = { workspace = true } arrow = { workspace = true } -datafusion-common = { workspace = true, default-features = true } +datafusion-common = { workspace = true } datafusion-expr-common = { workspace = true } hashbrown = { workspace = true } itertools = { workspace = true } diff --git a/datafusion/physical-expr-common/README.md b/datafusion/physical-expr-common/README.md index 7a1eff77d3b4f..c318e7468183f 100644 --- a/datafusion/physical-expr-common/README.md +++ b/datafusion/physical-expr-common/README.md @@ -17,11 +17,19 @@ under the License. --> -# DataFusion Core Physical Expressions +# Apache DataFusion Core Physical Expressions -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion that provides shared APIs for implementing -physical expressions such as `PhysicalExpr` and `PhysicalSortExpr`. +physical expressions such as [`PhysicalExpr`] and [`PhysicalSortExpr`]. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion +[`physicalexpr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/trait.PhysicalExpr.html +[`physicalsortexpr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/struct.PhysicalSortExpr.html diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs index 233deff758c7b..7084bc440e86b 100644 --- a/datafusion/physical-expr-common/src/datum.rs +++ b/datafusion/physical-expr-common/src/datum.rs @@ -154,9 +154,26 @@ pub fn compare_op_for_nested( if matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) { Ok(BooleanArray::new(values, None)) } else { - // If one of the side is NULL, we returns NULL + // If one of the side is NULL, we return NULL // i.e. NULL eq NULL -> NULL - let nulls = NullBuffer::union(l.nulls(), r.nulls()); + // For nested comparisons, we need to ensure the null buffer matches the result length + let nulls = match (is_l_scalar, is_r_scalar) { + (false, false) | (true, true) => NullBuffer::union(l.nulls(), r.nulls()), + (true, false) => { + // When left is null-scalar and right is array, expand left nulls to match result length + match l.nulls().filter(|nulls| !nulls.is_valid(0)) { + Some(_) => Some(NullBuffer::new_null(len)), // Left scalar is null + None => r.nulls().cloned(), // Left scalar is non-null + } + } + (false, true) => { + // When right is null-scalar and left is array, expand right nulls to match result length + match r.nulls().filter(|nulls| !nulls.is_valid(0)) { + Some(_) => Some(NullBuffer::new_null(len)), // Right scalar is null + None => l.nulls().cloned(), // Right scalar is non-null + } + } + }; Ok(BooleanArray::new(values, nulls)) } } diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index 86d4487f4c126..e21206d906422 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 43f214607f9fc..e5e7d6c00f08d 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -23,11 +23,14 @@ use std::sync::Arc; use crate::utils::scatter; -use arrow::array::BooleanArray; +use arrow::array::{new_empty_array, ArrayRef, BooleanArray}; use arrow::compute::filter_record_batch; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; +use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::ExprProperties; @@ -65,34 +68,91 @@ pub type PhysicalExprRef = Arc; /// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html /// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html /// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html -pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { +pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash { /// Returns the physical expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; /// Get the data type of this expression, given the schema of the input - fn data_type(&self, input_schema: &Schema) -> Result; + fn data_type(&self, input_schema: &Schema) -> Result { + Ok(self.return_field(input_schema)?.data_type().to_owned()) + } /// Determine whether this expression is nullable, given the schema of the input - fn nullable(&self, input_schema: &Schema) -> Result; + fn nullable(&self, input_schema: &Schema) -> Result { + Ok(self.return_field(input_schema)?.is_nullable()) + } /// Evaluate an expression against a RecordBatch fn evaluate(&self, batch: &RecordBatch) -> Result; - /// Evaluate an expression against a RecordBatch after first applying a - /// validity array + /// The output field associated with this expression + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(Arc::new(Field::new( + format!("{self}"), + self.data_type(input_schema)?, + self.nullable(input_schema)?, + ))) + } + /// Evaluate an expression against a RecordBatch after first applying a validity array + /// + /// # Errors + /// + /// Returns an `Err` if the expression could not be evaluated or if the length of the + /// `selection` validity array and the number of row in `batch` is not equal. fn evaluate_selection( &self, batch: &RecordBatch, selection: &BooleanArray, ) -> Result { - let tmp_batch = filter_record_batch(batch, selection)?; + let row_count = batch.num_rows(); + if row_count != selection.len() { + return exec_err!("Selection array length does not match batch row count: {} != {row_count}", selection.len()); + } - let tmp_result = self.evaluate(&tmp_batch)?; + let selection_count = selection.true_count(); + + // First, check if we can avoid filtering altogether. + if selection_count == row_count { + // All values from the `selection` filter are true and match the input batch. + // No need to perform any filtering. + return self.evaluate(batch); + } - if batch.num_rows() == tmp_batch.num_rows() { - // All values from the `selection` filter are true. - Ok(tmp_result) - } else if let ColumnarValue::Array(a) = tmp_result { - scatter(selection, a.as_ref()).map(ColumnarValue::Array) + // Next, prepare the result array for each 'true' row in the selection vector. + let filtered_result = if selection_count == 0 { + // Do not call `evaluate` when the selection is empty. + // `evaluate_selection` is used to conditionally evaluate expressions. + // When the expression in question is fallible, evaluating it with an empty + // record batch may trigger a runtime error (e.g. division by zero). + // + // Instead, create an empty array matching the expected return type. + let datatype = self.data_type(batch.schema_ref().as_ref())?; + ColumnarValue::Array(new_empty_array(&datatype)) } else { - Ok(tmp_result) + // If we reach this point, there's no other option than to filter the batch. + // This is a fairly costly operation since it requires creating partial copies + // (worst case of length `row_count - 1`) of all the arrays in the record batch. + // The resulting `filtered_batch` will contain `selection_count` rows. + let filtered_batch = filter_record_batch(batch, selection)?; + self.evaluate(&filtered_batch)? + }; + + // Finally, scatter the filtered result array so that the indices match the input rows again. + match &filtered_result { + ColumnarValue::Array(a) => { + scatter(selection, a.as_ref()).map(ColumnarValue::Array) + } + ColumnarValue::Scalar(ScalarValue::Boolean(value)) => { + // When the scalar is true or false, skip the scatter process + if let Some(v) = value { + if *v { + Ok(ColumnarValue::from(Arc::new(selection.clone()) as ArrayRef)) + } else { + Ok(filtered_result) + } + } else { + let array = BooleanArray::from(vec![None; row_count]); + scatter(selection, &array).map(ColumnarValue::Array) + } + } + ColumnarValue::Scalar(_) => Ok(filtered_result), } } @@ -283,42 +343,104 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { /// See the [`fmt_sql`] function for an example of printing `PhysicalExpr`s as SQL. /// fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result; -} -/// [`PhysicalExpr`] can't be constrained by [`Eq`] directly because it must remain object -/// safe. To ease implementation, blanket implementation is provided for [`Eq`] types. -pub trait DynEq { - fn dyn_eq(&self, other: &dyn Any) -> bool; -} + /// Take a snapshot of this `PhysicalExpr`, if it is dynamic. + /// + /// "Dynamic" in this case means containing references to structures that may change + /// during plan execution, such as hash tables. + /// + /// This method is used to capture the current state of `PhysicalExpr`s that may contain + /// dynamic references to other operators in order to serialize it over the wire + /// or treat it via downcast matching. + /// + /// You should not call this method directly as it does not handle recursion. + /// Instead use [`snapshot_physical_expr`] to handle recursion and capture the + /// full state of the `PhysicalExpr`. + /// + /// This is expected to return "simple" expressions that do not have mutable state + /// and are composed of DataFusion's built-in `PhysicalExpr` implementations. + /// Callers however should *not* assume anything about the returned expressions + /// since callers and implementers may not agree on what "simple" or "built-in" + /// means. + /// In other words, if you need to serialize a `PhysicalExpr` across the wire + /// you should call this method and then try to serialize the result, + /// but you should handle unknown or unexpected `PhysicalExpr` implementations gracefully + /// just as if you had not called this method at all. + /// + /// In particular, consider: + /// * A `PhysicalExpr` that references the current state of a `datafusion::physical_plan::TopK` + /// that is involved in a query with `SELECT * FROM t1 ORDER BY a LIMIT 10`. + /// This function may return something like `a >= 12`. + /// * A `PhysicalExpr` that references the current state of a `datafusion::physical_plan::joins::HashJoinExec` + /// from a query such as `SELECT * FROM t1 JOIN t2 ON t1.a = t2.b`. + /// This function may return something like `t2.b IN (1, 5, 7)`. + /// + /// A system or function that can only deal with a hardcoded set of `PhysicalExpr` implementations + /// or needs to serialize this state to bytes may not be able to handle these dynamic references. + /// In such cases, we should return a simplified version of the `PhysicalExpr` that does not + /// contain these dynamic references. + /// + /// Systems that implement remote execution of plans, e.g. serialize a portion of the query plan + /// and send it across the wire to a remote executor may want to call this method after + /// every batch on the source side and broadcast / update the current snapshot to the remote executor. + /// + /// Note for implementers: this method should *not* handle recursion. + /// Recursion is handled in [`snapshot_physical_expr`]. + fn snapshot(&self) -> Result>> { + // By default, we return None to indicate that this PhysicalExpr does not + // have any dynamic references or state. + // This is a safe default behavior. + Ok(None) + } -impl DynEq for T { - fn dyn_eq(&self, other: &dyn Any) -> bool { - other.downcast_ref::() == Some(self) + /// Returns the generation of this `PhysicalExpr` for snapshotting purposes. + /// The generation is an arbitrary u64 that can be used to track changes + /// in the state of the `PhysicalExpr` over time without having to do an exhaustive comparison. + /// This is useful to avoid unnecessary computation or serialization if there are no changes to the expression. + /// In particular, dynamic expressions that may change over time; this allows cheap checks for changes. + /// Static expressions that do not change over time should return 0, as does the default implementation. + /// You should not call this method directly as it does not handle recursion. + /// Instead use [`snapshot_generation`] to handle recursion and capture the + /// full state of the `PhysicalExpr`. + fn snapshot_generation(&self) -> u64 { + // By default, we return 0 to indicate that this PhysicalExpr does not + // have any dynamic references or state. + // Since the recursive algorithm XORs the generations of all children the overall + // generation will be 0 if no children have a non-zero generation, meaning that + // static expressions will always return 0. + 0 + } + + /// Returns true if the expression node is volatile, i.e. whether it can return + /// different results when evaluated multiple times with the same input. + /// + /// Note: unlike [`is_volatile`], this function does not consider inputs: + /// - `random()` returns `true`, + /// - `a + random()` returns `false` (because the operation `+` itself is not volatile.) + /// + /// The default to this function was set to `false` when it was created + /// to avoid imposing API churn on implementers, but this is not a safe default in general. + /// It is highly recommended that volatile expressions implement this method and return `true`. + /// This default may be removed in the future if it causes problems or we decide to + /// eat the cost of the breaking change and require all implementers to make a choice. + fn is_volatile_node(&self) -> bool { + false } } +#[deprecated( + since = "50.0.0", + note = "Use `datafusion_expr_common::dyn_eq` instead" +)] +pub use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; + impl PartialEq for dyn PhysicalExpr { fn eq(&self, other: &Self) -> bool { self.dyn_eq(other.as_any()) } } - impl Eq for dyn PhysicalExpr {} -/// [`PhysicalExpr`] can't be constrained by [`Hash`] directly because it must remain -/// object safe. To ease implementation blanket implementation is provided for [`Hash`] -/// types. -pub trait DynHash { - fn dyn_hash(&self, _state: &mut dyn Hasher); -} - -impl DynHash for T { - fn dyn_hash(&self, mut state: &mut dyn Hasher) { - self.type_id().hash(&mut state); - self.hash(&mut state) - } -} - impl Hash for dyn PhysicalExpr { fn hash(&self, state: &mut H) { self.dyn_hash(state); @@ -346,21 +468,6 @@ pub fn with_new_children_if_necessary( } } -#[deprecated(since = "44.0.0")] -pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { - if any.is::>() { - any.downcast_ref::>() - .unwrap() - .as_any() - } else if any.is::>() { - any.downcast_ref::>() - .unwrap() - .as_any() - } else { - any - } -} - /// Returns [`Display`] able a list of [`PhysicalExpr`] /// /// Example output: `[a + 1, b]` @@ -384,10 +491,10 @@ where let mut iter = self.0.clone(); write!(f, "[")?; if let Some(expr) = iter.next() { - write!(f, "{}", expr)?; + write!(f, "{expr}")?; } for expr in iter { - write!(f, ", {}", expr)?; + write!(f, ", {expr}")?; } write!(f, "]")?; Ok(()) @@ -401,27 +508,28 @@ where /// /// # Example /// ``` -/// # // The boiler plate needed to create a `PhysicalExpr` for the example +/// # // The boilerplate needed to create a `PhysicalExpr` for the example /// # use std::any::Any; +/// use std::collections::HashMap; /// # use std::fmt::Formatter; /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; -/// # use arrow::datatypes::{DataType, Schema}; +/// # use arrow::datatypes::{DataType, Field, FieldRef, Schema}; /// # use datafusion_common::Result; /// # use datafusion_expr_common::columnar_value::ColumnarValue; /// # use datafusion_physical_expr_common::physical_expr::{fmt_sql, DynEq, PhysicalExpr}; -/// # #[derive(Debug, Hash, PartialOrd, PartialEq)] -/// # struct MyExpr {}; +/// # #[derive(Debug, PartialEq, Eq, Hash)] +/// # struct MyExpr {} /// # impl PhysicalExpr for MyExpr {fn as_any(&self) -> &dyn Any { unimplemented!() } /// # fn data_type(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn nullable(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result { unimplemented!() } +/// # fn return_field(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn children(&self) -> Vec<&Arc>{ unimplemented!() } /// # fn with_new_children(self: Arc, children: Vec>) -> Result> { unimplemented!() } /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE a > b THEN 1 ELSE 0 END") } /// # } /// # impl std::fmt::Display for MyExpr {fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { unimplemented!() } } -/// # impl DynEq for MyExpr {fn dyn_eq(&self, other: &dyn Any) -> bool { unimplemented!() } } /// # fn make_physical_expr() -> Arc { Arc::new(MyExpr{}) } /// let expr: Arc = make_physical_expr(); /// // wrap the expression in `sql_fmt` which can be used with @@ -446,3 +554,270 @@ pub fn fmt_sql(expr: &dyn PhysicalExpr) -> impl Display + '_ { Wrapper { expr } } + +/// Take a snapshot of the given `PhysicalExpr` if it is dynamic. +/// +/// Take a snapshot of this `PhysicalExpr` if it is dynamic. +/// This is used to capture the current state of `PhysicalExpr`s that may contain +/// dynamic references to other operators in order to serialize it over the wire +/// or treat it via downcast matching. +/// +/// See the documentation of [`PhysicalExpr::snapshot`] for more details. +/// +/// # Returns +/// +/// Returns an `Option>` which is the snapshot of the +/// `PhysicalExpr` if it is dynamic. If the `PhysicalExpr` does not have +/// any dynamic references or state, it returns `None`. +pub fn snapshot_physical_expr( + expr: Arc, +) -> Result> { + expr.transform_up(|e| { + if let Some(snapshot) = e.snapshot()? { + Ok(Transformed::yes(snapshot)) + } else { + Ok(Transformed::no(Arc::clone(&e))) + } + }) + .data() +} + +/// Check the generation of this `PhysicalExpr`. +/// Dynamic `PhysicalExpr`s may have a generation that is incremented +/// every time the state of the `PhysicalExpr` changes. +/// If the generation changes that means this `PhysicalExpr` or one of its children +/// has changed since the last time it was evaluated. +/// +/// This algorithm will not produce collisions as long as the structure of the +/// `PhysicalExpr` does not change and no `PhysicalExpr` decrements its own generation. +pub fn snapshot_generation(expr: &Arc) -> u64 { + let mut generation = 0u64; + expr.apply(|e| { + // Add the current generation of the `PhysicalExpr` to our global generation. + generation = generation.wrapping_add(e.snapshot_generation()); + Ok(TreeNodeRecursion::Continue) + }) + .expect("this traversal is infallible"); + + generation +} + +/// Check if the given `PhysicalExpr` is dynamic. +/// Internally this calls [`snapshot_generation`] to check if the generation is non-zero, +/// any dynamic `PhysicalExpr` should have a non-zero generation. +pub fn is_dynamic_physical_expr(expr: &Arc) -> bool { + // If the generation is non-zero, then this `PhysicalExpr` is dynamic. + snapshot_generation(expr) != 0 +} + +/// Returns true if the expression is volatile, i.e. whether it can return different +/// results when evaluated multiple times with the same input. +/// +/// For example the function call `RANDOM()` is volatile as each call will +/// return a different value. +/// +/// This method recursively checks if any sub-expression is volatile, for example +/// `1 + RANDOM()` will return `true`. +pub fn is_volatile(expr: &Arc) -> bool { + if expr.is_volatile_node() { + return true; + } + let mut is_volatile = false; + expr.apply(|e| { + if e.is_volatile_node() { + is_volatile = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + .expect("infallible closure should not fail"); + is_volatile +} + +#[cfg(test)] +mod test { + use crate::physical_expr::PhysicalExpr; + use arrow::array::{Array, BooleanArray, Int64Array, RecordBatch}; + use arrow::datatypes::{DataType, Schema}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use std::fmt::{Display, Formatter}; + use std::sync::Arc; + + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestExpr {} + + impl PhysicalExpr for TestExpr { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn data_type(&self, _schema: &Schema) -> datafusion_common::Result { + Ok(DataType::Int64) + } + + fn nullable(&self, _schema: &Schema) -> datafusion_common::Result { + Ok(false) + } + + fn evaluate( + &self, + batch: &RecordBatch, + ) -> datafusion_common::Result { + let data = vec![1; batch.num_rows()]; + Ok(ColumnarValue::Array(Arc::new(Int64Array::from(data)))) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(Self {})) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str("TestExpr") + } + } + + impl Display for TestExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.fmt_sql(f) + } + } + + macro_rules! assert_arrays_eq { + ($EXPECTED: expr, $ACTUAL: expr, $MESSAGE: expr) => { + let expected = $EXPECTED.to_array(1).unwrap(); + let actual = $ACTUAL; + + let actual_array = actual.to_array(expected.len()).unwrap(); + let actual_ref = actual_array.as_ref(); + let expected_ref = expected.as_ref(); + assert!( + actual_ref == expected_ref, + "{}: expected: {:?}, actual: {:?}", + $MESSAGE, + $EXPECTED, + actual_ref + ); + }; + } + + fn test_evaluate_selection( + batch: &RecordBatch, + selection: &BooleanArray, + expected: &ColumnarValue, + ) { + let expr = TestExpr {}; + + // First check that the `evaluate_selection` is the expected one + let selection_result = expr.evaluate_selection(batch, selection).unwrap(); + assert_eq!( + expected.to_array(1).unwrap().len(), + selection_result.to_array(1).unwrap().len(), + "evaluate_selection should output row count should match input record batch" + ); + assert_arrays_eq!( + expected, + &selection_result, + "evaluate_selection returned unexpected value" + ); + + // If we're selecting all rows, the result should be the same as calling `evaluate` + // with the full record batch. + if (0..batch.num_rows()) + .all(|row_idx| row_idx < selection.len() && selection.value(row_idx)) + { + let empty_result = expr.evaluate(batch).unwrap(); + + assert_arrays_eq!( + empty_result, + &selection_result, + "evaluate_selection does not match unfiltered evaluate result" + ); + } + } + + fn test_evaluate_selection_error(batch: &RecordBatch, selection: &BooleanArray) { + let expr = TestExpr {}; + + // First check that the `evaluate_selection` is the expected one + let selection_result = expr.evaluate_selection(batch, selection); + assert!(selection_result.is_err(), "evaluate_selection should fail"); + } + + #[test] + pub fn test_evaluate_selection_with_empty_record_batch() { + test_evaluate_selection( + &RecordBatch::new_empty(Arc::new(Schema::empty())), + &BooleanArray::from(vec![false; 0]), + &ColumnarValue::Array(Arc::new(Int64Array::new_null(0))), + ); + } + + #[test] + pub fn test_evaluate_selection_with_empty_record_batch_with_larger_false_selection() { + test_evaluate_selection_error( + &RecordBatch::new_empty(Arc::new(Schema::empty())), + &BooleanArray::from(vec![false; 10]), + ); + } + + #[test] + pub fn test_evaluate_selection_with_empty_record_batch_with_larger_true_selection() { + test_evaluate_selection_error( + &RecordBatch::new_empty(Arc::new(Schema::empty())), + &BooleanArray::from(vec![true; 10]), + ); + } + + #[test] + pub fn test_evaluate_selection_with_non_empty_record_batch() { + test_evaluate_selection( + unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) }, + &BooleanArray::from(vec![true; 10]), + &ColumnarValue::Array(Arc::new(Int64Array::from(vec![1; 10]))), + ); + } + + #[test] + pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_false_selection( + ) { + test_evaluate_selection_error( + unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) }, + &BooleanArray::from(vec![false; 20]), + ); + } + + #[test] + pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_true_selection( + ) { + test_evaluate_selection_error( + unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) }, + &BooleanArray::from(vec![true; 20]), + ); + } + + #[test] + pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_false_selection( + ) { + test_evaluate_selection_error( + unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) }, + &BooleanArray::from(vec![false; 5]), + ); + } + + #[test] + pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_true_selection( + ) { + test_evaluate_selection_error( + unsafe { &RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 10) }, + &BooleanArray::from(vec![true; 5]), + ); + } +} diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 3a54b5b403995..d19d7024a516e 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -17,33 +17,34 @@ //! Sort expressions -use crate::physical_expr::{fmt_sql, PhysicalExpr}; -use std::fmt; -use std::fmt::{Display, Formatter}; +use std::cmp::Ordering; +use std::fmt::{self, Display, Formatter}; use std::hash::{Hash, Hasher}; -use std::ops::{Deref, Index, Range, RangeFrom, RangeTo}; -use std::sync::{Arc, LazyLock}; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; use std::vec::IntoIter; +use crate::physical_expr::{fmt_sql, PhysicalExpr}; + use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; +use datafusion_common::{HashSet, Result}; use datafusion_expr_common::columnar_value::ColumnarValue; -use itertools::Itertools; /// Represents Sort operation for a column in a RecordBatch /// /// Example: /// ``` /// # use std::any::Any; +/// # use std::collections::HashMap; /// # use std::fmt::{Display, Formatter}; /// # use std::hash::Hasher; /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; /// # use datafusion_common::Result; /// # use arrow::compute::SortOptions; -/// # use arrow::datatypes::{DataType, Schema}; +/// # use arrow::datatypes::{DataType, Field, FieldRef, Schema}; /// # use datafusion_expr_common::columnar_value::ColumnarValue; /// # use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// # use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; @@ -56,6 +57,7 @@ use itertools::Itertools; /// # fn data_type(&self, input_schema: &Schema) -> Result {todo!()} /// # fn nullable(&self, input_schema: &Schema) -> Result {todo!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result {todo!() } +/// # fn return_field(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn children(&self) -> Vec<&Arc> {todo!()} /// # fn with_new_children(self: Arc, children: Vec>) -> Result> {todo!()} /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { todo!() } @@ -75,7 +77,7 @@ use itertools::Itertools; /// .nulls_last(); /// assert_eq!(sort_expr.to_string(), "a DESC NULLS LAST"); /// ``` -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq)] pub struct PhysicalSortExpr { /// Physical expression representing the column to sort pub expr: Arc, @@ -94,6 +96,15 @@ impl PhysicalSortExpr { Self::new(expr, SortOptions::default()) } + /// Reverses the sort expression. For instance, `[a ASC NULLS LAST]` turns + /// into `[a DESC NULLS FIRST]`. Such reversals are useful in planning, e.g. + /// when constructing equivalent window expressions. + pub fn reverse(&self) -> Self { + let mut result = self.clone(); + result.options = !result.options; + result + } + /// Set the sort sort options to ASC pub fn asc(mut self) -> Self { self.options.descending = false; @@ -127,23 +138,58 @@ impl PhysicalSortExpr { to_str(&self.options) ) } -} -/// Access the PhysicalSortExpr as a PhysicalExpr -impl AsRef for PhysicalSortExpr { - fn as_ref(&self) -> &(dyn PhysicalExpr + 'static) { - self.expr.as_ref() + /// Evaluates the sort expression into a `SortColumn` that can be passed + /// into the arrow sort kernel. + pub fn evaluate_to_sort_column(&self, batch: &RecordBatch) -> Result { + let array_to_sort = match self.expr.evaluate(batch)? { + ColumnarValue::Array(array) => array, + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(batch.num_rows())?, + }; + Ok(SortColumn { + values: array_to_sort, + options: Some(self.options), + }) + } + + /// Checks whether this sort expression satisfies the given `requirement`. + /// If sort options are unspecified in `requirement`, only expressions are + /// compared for inequality. See [`options_compatible`] for details on + /// how sort options compare with one another. + pub fn satisfy( + &self, + requirement: &PhysicalSortRequirement, + schema: &Schema, + ) -> bool { + self.expr.eq(&requirement.expr) + && requirement.options.is_none_or(|opts| { + options_compatible( + &self.options, + &opts, + self.expr.nullable(schema).unwrap_or(true), + ) + }) + } + + /// Checks whether this sort expression satisfies the given `sort_expr`. + /// See [`options_compatible`] for details on how sort options compare with + /// one another. + pub fn satisfy_expr(&self, sort_expr: &Self, schema: &Schema) -> bool { + self.expr.eq(&sort_expr.expr) + && options_compatible( + &self.options, + &sort_expr.options, + self.expr.nullable(schema).unwrap_or(true), + ) } } impl PartialEq for PhysicalSortExpr { - fn eq(&self, other: &PhysicalSortExpr) -> bool { + fn eq(&self, other: &Self) -> bool { self.options == other.options && self.expr.eq(&other.expr) } } -impl Eq for PhysicalSortExpr {} - impl Hash for PhysicalSortExpr { fn hash(&self, state: &mut H) { self.expr.hash(state); @@ -157,38 +203,20 @@ impl Display for PhysicalSortExpr { } } -impl PhysicalSortExpr { - /// evaluate the sort expression into SortColumn that can be passed into arrow sort kernel - pub fn evaluate_to_sort_column(&self, batch: &RecordBatch) -> Result { - let value_to_sort = self.expr.evaluate(batch)?; - let array_to_sort = match value_to_sort { - ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(batch.num_rows())?, - }; - Ok(SortColumn { - values: array_to_sort, - options: Some(self.options), - }) - } - - /// Checks whether this sort expression satisfies the given `requirement`. - /// If sort options are unspecified in `requirement`, only expressions are - /// compared for inequality. - pub fn satisfy( - &self, - requirement: &PhysicalSortRequirement, - schema: &Schema, - ) -> bool { +/// Returns whether the given two [`SortOptions`] are compatible. Here, +/// compatibility means that they are either exactly equal, or they differ only +/// in whether NULL values come in first/last, which is immaterial because the +/// column in question is not nullable (specified by the `nullable` parameter). +pub fn options_compatible( + options_lhs: &SortOptions, + options_rhs: &SortOptions, + nullable: bool, +) -> bool { + if nullable { + options_lhs == options_rhs + } else { // If the column is not nullable, NULLS FIRST/LAST is not important. - let nullable = self.expr.nullable(schema).unwrap_or(true); - self.expr.eq(&requirement.expr) - && if nullable { - requirement.options.is_none_or(|opts| self.options == opts) - } else { - requirement - .options - .is_none_or(|opts| self.options.descending == opts.descending) - } + options_lhs.descending == options_rhs.descending } } @@ -220,28 +248,8 @@ pub struct PhysicalSortRequirement { pub options: Option, } -impl From for PhysicalSortExpr { - /// If options is `None`, the default sort options `ASC, NULLS LAST` is used. - /// - /// The default is picked to be consistent with - /// PostgreSQL: - fn from(value: PhysicalSortRequirement) -> Self { - let options = value.options.unwrap_or(SortOptions { - descending: false, - nulls_first: false, - }); - PhysicalSortExpr::new(value.expr, options) - } -} - -impl From for PhysicalSortRequirement { - fn from(value: PhysicalSortExpr) -> Self { - PhysicalSortRequirement::new(value.expr, Some(value.options)) - } -} - impl PartialEq for PhysicalSortRequirement { - fn eq(&self, other: &PhysicalSortRequirement) -> bool { + fn eq(&self, other: &Self) -> bool { self.options == other.options && self.expr.eq(&other.expr) } } @@ -265,10 +273,10 @@ pub fn format_physical_sort_requirement_list( let mut iter = self.0.iter(); write!(f, "[")?; if let Some(expr) = iter.next() { - write!(f, "{}", expr)?; + write!(f, "{expr}")?; } for expr in iter { - write!(f, ", {}", expr)?; + write!(f, ", {expr}")?; } write!(f, "]")?; Ok(()) @@ -291,37 +299,16 @@ impl PhysicalSortRequirement { Self { expr, options } } - /// Replace the required expression for this requirement with the new one - pub fn with_expr(mut self, expr: Arc) -> Self { - self.expr = expr; - self - } - /// Returns whether this requirement is equal or more specific than `other`. - pub fn compatible(&self, other: &PhysicalSortRequirement) -> bool { + pub fn compatible(&self, other: &Self) -> bool { self.expr.eq(&other.expr) && other .options .is_none_or(|other_opts| self.options == Some(other_opts)) } - - #[deprecated(since = "43.0.0", note = "use LexRequirement::from_lex_ordering")] - pub fn from_sort_exprs<'a>( - ordering: impl IntoIterator, - ) -> LexRequirement { - let ordering = ordering.into_iter().cloned().collect(); - LexRequirement::from_lex_ordering(ordering) - } - #[deprecated(since = "43.0.0", note = "use LexOrdering::from_lex_requirement")] - pub fn to_sort_exprs( - requirements: impl IntoIterator, - ) -> LexOrdering { - let requirements = requirements.into_iter().collect(); - LexOrdering::from_lex_requirement(requirements) - } } -/// Returns the SQL string representation of the given [SortOptions] object. +/// Returns the SQL string representation of the given [`SortOptions`] object. #[inline] fn to_str(options: &SortOptions) -> &str { match (options.descending, options.nulls_first) { @@ -332,162 +319,147 @@ fn to_str(options: &SortOptions) -> &str { } } -///`LexOrdering` contains a `Vec`, which represents -/// a lexicographical ordering. +// Cross-conversion utilities between `PhysicalSortExpr` and `PhysicalSortRequirement` +impl From for PhysicalSortRequirement { + fn from(value: PhysicalSortExpr) -> Self { + Self::new(value.expr, Some(value.options)) + } +} + +impl From for PhysicalSortExpr { + /// The default sort options `ASC, NULLS LAST` when the requirement does + /// not specify sort options. This default is consistent with PostgreSQL. + /// + /// Reference: + fn from(value: PhysicalSortRequirement) -> Self { + let options = value + .options + .unwrap_or_else(|| SortOptions::new(false, false)); + Self::new(value.expr, options) + } +} + +/// This object represents a lexicographical ordering and contains a vector +/// of `PhysicalSortExpr` objects. /// -/// For example, `vec![a ASC, b DESC]` represents a lexicographical ordering +/// For example, a `vec![a ASC, b DESC]` represents a lexicographical ordering /// that first sorts by column `a` in ascending order, then by column `b` in /// descending order. -#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +/// +/// # Invariants +/// +/// The following always hold true for a `LexOrdering`: +/// +/// 1. It is non-degenerate, meaning it contains at least one element. +/// 2. It is duplicate-free, meaning it does not contain multiple entries for +/// the same column. +#[derive(Debug, Clone)] pub struct LexOrdering { - inner: Vec, -} - -impl AsRef for LexOrdering { - fn as_ref(&self) -> &LexOrdering { - self - } + /// Vector of sort expressions representing the lexicographical ordering. + exprs: Vec, + /// Set of expressions in the lexicographical ordering, used to ensure + /// that the ordering is duplicate-free. Note that the elements in this + /// set are the same underlying physical expressions as in `exprs`. + set: HashSet>, } impl LexOrdering { - /// Creates a new [`LexOrdering`] from a vector - pub fn new(inner: Vec) -> Self { - Self { inner } - } - - /// Return an empty LexOrdering (no expressions) - pub fn empty() -> &'static LexOrdering { - static EMPTY_ORDER: LazyLock = LazyLock::new(LexOrdering::default); - &EMPTY_ORDER - } - - /// Returns the number of elements that can be stored in the LexOrdering - /// without reallocating. - pub fn capacity(&self) -> usize { - self.inner.capacity() - } - - /// Clears the LexOrdering, removing all elements. - pub fn clear(&mut self) { - self.inner.clear() - } - - /// Takes ownership of the actual vector of `PhysicalSortExpr`s in the LexOrdering. - pub fn take_exprs(self) -> Vec { - self.inner - } - - /// Returns `true` if the LexOrdering contains `expr` - pub fn contains(&self, expr: &PhysicalSortExpr) -> bool { - self.inner.contains(expr) - } - - /// Add all elements from `iter` to the LexOrdering. - pub fn extend>(&mut self, iter: I) { - self.inner.extend(iter) - } - - /// Remove all elements from the LexOrdering where `f` evaluates to `false`. - pub fn retain(&mut self, f: F) - where - F: FnMut(&PhysicalSortExpr) -> bool, - { - self.inner.retain(f) - } - - /// Returns `true` if the LexOrdering contains no elements. - pub fn is_empty(&self) -> bool { - self.inner.is_empty() - } - - /// Returns an iterator over each `&PhysicalSortExpr` in the LexOrdering. - pub fn iter(&self) -> core::slice::Iter { - self.inner.iter() - } - - /// Returns the number of elements in the LexOrdering. - pub fn len(&self) -> usize { - self.inner.len() - } - - /// Removes the last element from the LexOrdering and returns it, or `None` if it is empty. - pub fn pop(&mut self) -> Option { - self.inner.pop() + /// Creates a new [`LexOrdering`] from the given vector of sort expressions. + /// If the vector is empty, returns `None`. + pub fn new(exprs: impl IntoIterator) -> Option { + let exprs = exprs.into_iter(); + let mut candidate = Self { + // not valid yet; valid publicly-returned instance must be non-empty + exprs: Vec::new(), + set: HashSet::new(), + }; + for expr in exprs { + candidate.push(expr); + } + if candidate.exprs.is_empty() { + None + } else { + Some(candidate) + } } - /// Appends an element to the back of the LexOrdering. - pub fn push(&mut self, physical_sort_expr: PhysicalSortExpr) { - self.inner.push(physical_sort_expr) + /// Appends an element to the back of the `LexOrdering`. + pub fn push(&mut self, sort_expr: PhysicalSortExpr) { + if self.set.insert(Arc::clone(&sort_expr.expr)) { + self.exprs.push(sort_expr); + } } - /// Truncates the LexOrdering, keeping only the first `len` elements. - pub fn truncate(&mut self, len: usize) { - self.inner.truncate(len) + /// Add all elements from `iter` to the `LexOrdering`. + pub fn extend(&mut self, sort_exprs: impl IntoIterator) { + for sort_expr in sort_exprs { + self.push(sort_expr); + } } - /// Merge the contents of `other` into `self`, removing duplicates. - pub fn merge(mut self, other: LexOrdering) -> Self { - self.inner = self.inner.into_iter().chain(other).unique().collect(); - self + /// Returns the leading `PhysicalSortExpr` of the `LexOrdering`. Note that + /// this function does not return an `Option`, as a `LexOrdering` is always + /// non-degenerate (i.e. it contains at least one element). + pub fn first(&self) -> &PhysicalSortExpr { + // Can safely `unwrap` because `LexOrdering` is non-degenerate: + self.exprs.first().unwrap() } - /// Converts a `LexRequirement` into a `LexOrdering`. - /// - /// This function converts [`PhysicalSortRequirement`] to [`PhysicalSortExpr`] - /// for each entry in the input. - /// - /// If the required ordering is `None` for an entry in `requirement`, the - /// default ordering `ASC, NULLS LAST` is used (see - /// [`PhysicalSortExpr::from`]). - pub fn from_lex_requirement(requirement: LexRequirement) -> LexOrdering { - requirement - .into_iter() - .map(PhysicalSortExpr::from) - .collect() + /// Returns the number of elements that can be stored in the `LexOrdering` + /// without reallocating. + pub fn capacity(&self) -> usize { + self.exprs.capacity() } - /// Collapse a `LexOrdering` into a new duplicate-free `LexOrdering` based on expression. - /// - /// This function filters duplicate entries that have same physical - /// expression inside, ignoring [`SortOptions`]. For example: - /// - /// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. - pub fn collapse(self) -> Self { - let mut output = LexOrdering::default(); - for item in self { - if !output.iter().any(|req| req.expr.eq(&item.expr)) { - output.push(item); - } + /// Truncates the `LexOrdering`, keeping only the first `len` elements. + /// Returns `true` if truncation made a change, `false` otherwise. Negative + /// cases happen in two scenarios: (1) When `len` is greater than or equal + /// to the number of expressions inside this `LexOrdering`, making truncation + /// a no-op, or (2) when `len` is `0`, making truncation impossible. + pub fn truncate(&mut self, len: usize) -> bool { + if len == 0 || len >= self.exprs.len() { + return false; } - output - } - - /// Transforms each `PhysicalSortExpr` in the `LexOrdering` - /// in place using the provided closure `f`. - pub fn transform(&mut self, f: F) - where - F: FnMut(&mut PhysicalSortExpr), - { - self.inner.iter_mut().for_each(f); + for PhysicalSortExpr { expr, .. } in self.exprs[len..].iter() { + self.set.remove(expr); + } + self.exprs.truncate(len); + true } } -impl From> for LexOrdering { - fn from(value: Vec) -> Self { - Self::new(value) +impl PartialEq for LexOrdering { + fn eq(&self, other: &Self) -> bool { + let Self { + exprs, + set: _, // derived from `exprs` + } = self; + // PartialEq must be consistent with PartialOrd + exprs == &other.exprs } } - -impl From for LexOrdering { - fn from(value: LexRequirement) -> Self { - Self::from_lex_requirement(value) +impl Eq for LexOrdering {} +impl PartialOrd for LexOrdering { + /// There is a partial ordering among `LexOrdering` objects. For example, the + /// ordering `[a ASC]` is coarser (less) than ordering `[a ASC, b ASC]`. + /// If two orderings do not share a prefix, they are incomparable. + fn partial_cmp(&self, other: &Self) -> Option { + // PartialEq must be consistent with PartialOrd + self.exprs + .iter() + .zip(other.exprs.iter()) + .all(|(lhs, rhs)| lhs == rhs) + .then(|| self.len().cmp(&other.len())) } } -/// Convert a `LexOrdering` into a `Arc[]` for fast copies -impl From for Arc<[PhysicalSortExpr]> { - fn from(value: LexOrdering) -> Self { - value.inner.into() +impl From<[PhysicalSortExpr; N]> for LexOrdering { + fn from(value: [PhysicalSortExpr; N]) -> Self { + // TODO: Replace this assertion with a condition on the generic parameter + // when Rust supports it. + assert!(N > 0); + Self::new(value) + .expect("A LexOrdering from non-empty array must be non-degenerate") } } @@ -495,181 +467,268 @@ impl Deref for LexOrdering { type Target = [PhysicalSortExpr]; fn deref(&self) -> &Self::Target { - self.inner.as_slice() + self.exprs.as_slice() } } impl Display for LexOrdering { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut first = true; - for sort_expr in &self.inner { + for sort_expr in &self.exprs { if first { first = false; } else { write!(f, ", ")?; } - write!(f, "{}", sort_expr)?; + write!(f, "{sort_expr}")?; } Ok(()) } } -impl FromIterator for LexOrdering { - fn from_iter>(iter: T) -> Self { - let mut lex_ordering = LexOrdering::default(); - - for i in iter { - lex_ordering.push(i); - } +impl IntoIterator for LexOrdering { + type Item = PhysicalSortExpr; + type IntoIter = IntoIter; - lex_ordering + fn into_iter(self) -> Self::IntoIter { + self.exprs.into_iter() } } -impl Index for LexOrdering { - type Output = PhysicalSortExpr; +impl<'a> IntoIterator for &'a LexOrdering { + type Item = &'a PhysicalSortExpr; + type IntoIter = std::slice::Iter<'a, PhysicalSortExpr>; - fn index(&self, index: usize) -> &Self::Output { - &self.inner[index] + fn into_iter(self) -> Self::IntoIter { + self.exprs.iter() } } -impl Index> for LexOrdering { - type Output = [PhysicalSortExpr]; - - fn index(&self, range: Range) -> &Self::Output { - &self.inner[range] +impl From for Vec { + fn from(ordering: LexOrdering) -> Self { + ordering.exprs } } -impl Index> for LexOrdering { - type Output = [PhysicalSortExpr]; +/// This object represents a lexicographical ordering requirement and contains +/// a vector of `PhysicalSortRequirement` objects. +/// +/// For example, a `vec![a Some(ASC), b None]` represents a lexicographical +/// requirement that firsts imposes an ordering by column `a` in ascending +/// order, then by column `b` in *any* (ascending or descending) order. The +/// ordering is non-degenerate, meaning it contains at least one element, and +/// it is duplicate-free, meaning it does not contain multiple entries for the +/// same column. +/// +/// Note that a `LexRequirement` need not enforce the uniqueness of its sort +/// expressions after construction like a `LexOrdering` does, because it provides +/// no mutation methods. If such methods become necessary, we will need to +/// enforce uniqueness like the latter object. +#[derive(Debug, Clone, PartialEq)] +pub struct LexRequirement { + reqs: Vec, +} - fn index(&self, range_from: RangeFrom) -> &Self::Output { - &self.inner[range_from] +impl LexRequirement { + /// Creates a new [`LexRequirement`] from the given vector of sort expressions. + /// If the vector is empty, returns `None`. + pub fn new(reqs: impl IntoIterator) -> Option { + let (non_empty, requirements) = Self::construct(reqs); + non_empty.then_some(requirements) + } + + /// Returns the leading `PhysicalSortRequirement` of the `LexRequirement`. + /// Note that this function does not return an `Option`, as a `LexRequirement` + /// is always non-degenerate (i.e. it contains at least one element). + pub fn first(&self) -> &PhysicalSortRequirement { + // Can safely `unwrap` because `LexRequirement` is non-degenerate: + self.reqs.first().unwrap() + } + + /// Constructs a new `LexRequirement` from the given sort requirements w/o + /// enforcing non-degeneracy. This function is used internally and is not + /// meant (or safe) for external use. + fn construct( + reqs: impl IntoIterator, + ) -> (bool, Self) { + let mut set = HashSet::new(); + let reqs = reqs + .into_iter() + .filter_map(|r| set.insert(Arc::clone(&r.expr)).then_some(r)) + .collect(); + (!set.is_empty(), Self { reqs }) } } -impl Index> for LexOrdering { - type Output = [PhysicalSortExpr]; - - fn index(&self, range_to: RangeTo) -> &Self::Output { - &self.inner[range_to] +impl From<[PhysicalSortRequirement; N]> for LexRequirement { + fn from(value: [PhysicalSortRequirement; N]) -> Self { + // TODO: Replace this assertion with a condition on the generic parameter + // when Rust supports it. + assert!(N > 0); + let (non_empty, requirement) = Self::construct(value); + debug_assert!(non_empty); + requirement } } -impl IntoIterator for LexOrdering { - type Item = PhysicalSortExpr; - type IntoIter = IntoIter; +impl Deref for LexRequirement { + type Target = [PhysicalSortRequirement]; - fn into_iter(self) -> Self::IntoIter { - self.inner.into_iter() + fn deref(&self) -> &Self::Target { + self.reqs.as_slice() } } -///`LexOrderingRef` is an alias for the type &`[PhysicalSortExpr]`, which represents -/// a reference to a lexicographical ordering. -#[deprecated(since = "43.0.0", note = "use &LexOrdering instead")] -pub type LexOrderingRef<'a> = &'a [PhysicalSortExpr]; +impl IntoIterator for LexRequirement { + type Item = PhysicalSortRequirement; + type IntoIter = IntoIter; -///`LexRequirement` is an struct containing a `Vec`, which -/// represents a lexicographical ordering requirement. -#[derive(Debug, Default, Clone, PartialEq)] -pub struct LexRequirement { - pub inner: Vec, + fn into_iter(self) -> Self::IntoIter { + self.reqs.into_iter() + } } -impl LexRequirement { - pub fn new(inner: Vec) -> Self { - Self { inner } - } +impl<'a> IntoIterator for &'a LexRequirement { + type Item = &'a PhysicalSortRequirement; + type IntoIter = std::slice::Iter<'a, PhysicalSortRequirement>; - pub fn is_empty(&self) -> bool { - self.inner.is_empty() + fn into_iter(self) -> Self::IntoIter { + self.reqs.iter() } +} - pub fn iter(&self) -> impl Iterator { - self.inner.iter() +impl From for Vec { + fn from(requirement: LexRequirement) -> Self { + requirement.reqs } +} - pub fn push(&mut self, physical_sort_requirement: PhysicalSortRequirement) { - self.inner.push(physical_sort_requirement) +// Cross-conversion utilities between `LexOrdering` and `LexRequirement` +impl From for LexRequirement { + fn from(value: LexOrdering) -> Self { + // Can construct directly as `value` is non-degenerate: + let (non_empty, requirements) = + Self::construct(value.into_iter().map(Into::into)); + debug_assert!(non_empty); + requirements } +} - /// Create a new [`LexRequirement`] from a [`LexOrdering`] - /// - /// Returns [`LexRequirement`] that requires the exact - /// sort of the [`PhysicalSortExpr`]s in `ordering` - pub fn from_lex_ordering(ordering: LexOrdering) -> Self { - Self::new( - ordering - .into_iter() - .map(PhysicalSortRequirement::from) - .collect(), - ) +impl From for LexOrdering { + fn from(value: LexRequirement) -> Self { + // Can construct directly as `value` is non-degenerate + Self::new(value.into_iter().map(Into::into)) + .expect("A LexOrdering from LexRequirement must be non-degenerate") } +} - /// Constructs a duplicate-free `LexOrderingReq` by filtering out - /// duplicate entries that have same physical expression inside. - /// - /// For example, `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a - /// Some(ASC)]`. - pub fn collapse(self) -> Self { - let mut output = Vec::::new(); - for item in self { - if !output.iter().any(|req| req.expr.eq(&item.expr)) { - output.push(item); +/// Represents a plan's input ordering requirements. Vector elements represent +/// alternative ordering requirements in the order of preference. The list of +/// alternatives can be either hard or soft, depending on whether the operator +/// can work without an input ordering. +/// +/// # Invariants +/// +/// The following always hold true for a `OrderingRequirements`: +/// +/// 1. It is non-degenerate, meaning it contains at least one ordering. The +/// absence of an input ordering requirement is represented by a `None` value +/// in `ExecutionPlan` APIs, which return an `Option`. +#[derive(Debug, Clone, PartialEq)] +pub enum OrderingRequirements { + /// The operator is not able to work without one of these requirements. + Hard(Vec), + /// The operator can benefit from these input orderings when available, + /// but can still work in the absence of any input ordering. + Soft(Vec), +} + +impl OrderingRequirements { + /// Creates a new instance from the given alternatives. If an empty list of + /// alternatives are given, returns `None`. + pub fn new_alternatives( + alternatives: impl IntoIterator, + soft: bool, + ) -> Option { + let alternatives = alternatives.into_iter().collect::>(); + (!alternatives.is_empty()).then(|| { + if soft { + Self::Soft(alternatives) + } else { + Self::Hard(alternatives) } - } - LexRequirement::new(output) + }) } -} -impl From for LexRequirement { - fn from(value: LexOrdering) -> Self { - Self::from_lex_ordering(value) + /// Creates a new instance with a single hard requirement. + pub fn new(requirement: LexRequirement) -> Self { + Self::Hard(vec![requirement]) } -} -impl Deref for LexRequirement { - type Target = [PhysicalSortRequirement]; + /// Creates a new instance with a single soft requirement. + pub fn new_soft(requirement: LexRequirement) -> Self { + Self::Soft(vec![requirement]) + } - fn deref(&self) -> &Self::Target { - self.inner.as_slice() + /// Adds an alternative requirement to the list of alternatives. + pub fn add_alternative(&mut self, requirement: LexRequirement) { + match self { + Self::Hard(alts) | Self::Soft(alts) => alts.push(requirement), + } } -} -impl FromIterator for LexRequirement { - fn from_iter>(iter: T) -> Self { - let mut lex_requirement = LexRequirement::new(vec![]); + /// Returns the first (i.e. most preferred) `LexRequirement` among + /// alternative requirements. + pub fn into_single(self) -> LexRequirement { + match self { + Self::Hard(mut alts) | Self::Soft(mut alts) => alts.swap_remove(0), + } + } - for i in iter { - lex_requirement.inner.push(i); + /// Returns a reference to the first (i.e. most preferred) `LexRequirement` + /// among alternative requirements. + pub fn first(&self) -> &LexRequirement { + match self { + Self::Hard(alts) | Self::Soft(alts) => &alts[0], } + } - lex_requirement + /// Returns all alternatives as a vector of `LexRequirement` objects and a + /// boolean value indicating softness/hardness of the requirements. + pub fn into_alternatives(self) -> (Vec, bool) { + match self { + Self::Hard(alts) => (alts, false), + Self::Soft(alts) => (alts, true), + } } } -impl IntoIterator for LexRequirement { - type Item = PhysicalSortRequirement; - type IntoIter = IntoIter; +impl From for OrderingRequirements { + fn from(requirement: LexRequirement) -> Self { + Self::new(requirement) + } +} - fn into_iter(self) -> Self::IntoIter { - self.inner.into_iter() +impl From for OrderingRequirements { + fn from(ordering: LexOrdering) -> Self { + Self::new(ordering.into()) } } -impl<'a> IntoIterator for &'a LexOrdering { - type Item = &'a PhysicalSortExpr; - type IntoIter = std::slice::Iter<'a, PhysicalSortExpr>; +impl Deref for OrderingRequirements { + type Target = [LexRequirement]; - fn into_iter(self) -> Self::IntoIter { - self.inner.iter() + fn deref(&self) -> &Self::Target { + match &self { + Self::Hard(alts) | Self::Soft(alts) => alts.as_slice(), + } } } -///`LexRequirementRef` is an alias for the type &`[PhysicalSortRequirement]`, which -/// represents a reference to a lexicographical ordering requirement. -/// #[deprecated(since = "43.0.0", note = "use &LexRequirement instead")] -pub type LexRequirementRef<'a> = &'a [PhysicalSortRequirement]; +impl DerefMut for OrderingRequirements { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + Self::Hard(alts) | Self::Soft(alts) => alts.as_mut_slice(), + } + } +} diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index 114007bfa6afb..05b216ab75ebc 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -17,16 +17,14 @@ use std::sync::Arc; +use crate::physical_expr::PhysicalExpr; +use crate::tree_node::ExprContext; + use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; - use datafusion_common::Result; use datafusion_expr_common::sort_properties::ExprProperties; -use crate::physical_expr::PhysicalExpr; -use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; -use crate::tree_node::ExprContext; - /// Represents a [`PhysicalExpr`] node with associated properties (order and /// range) in a context where properties are tracked. pub type ExprPropertiesNode = ExprContext; @@ -93,16 +91,6 @@ pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { Ok(make_array(data)) } -/// Reverses the ORDER BY expression, which is useful during equivalent window -/// expression construction. For instance, 'ORDER BY a ASC, NULLS LAST' turns into -/// 'ORDER BY a DESC, NULLS FIRST'. -pub fn reverse_order_bys(order_bys: &LexOrdering) -> LexOrdering { - order_bys - .iter() - .map(|e| PhysicalSortExpr::new(Arc::clone(&e.expr), !e.options)) - .collect() -} - #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 72baa0db00a21..b7654a0f6f603 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -40,7 +40,7 @@ name = "datafusion_physical_expr" [dependencies] ahash = { workspace = true } arrow = { workspace = true } -datafusion-common = { workspace = true, default-features = true } +datafusion-common = { workspace = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } @@ -49,14 +49,15 @@ half = { workspace = true } hashbrown = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true, features = ["use_std"] } -log = { workspace = true } +parking_lot = { workspace = true } paste = "^1.0" -petgraph = "0.7.1" +petgraph = "0.8.3" [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } criterion = { workspace = true } datafusion-functions = { workspace = true } +insta = { workspace = true } rand = { workspace = true } rstest = { workspace = true } @@ -71,3 +72,7 @@ name = "case_when" [[bench]] harness = false name = "is_null" + +[[bench]] +harness = false +name = "binary_op" diff --git a/datafusion/physical-expr/README.md b/datafusion/physical-expr/README.md index 424256c77e7e2..4c79223b09b8c 100644 --- a/datafusion/physical-expr/README.md +++ b/datafusion/physical-expr/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion Physical Expressions +# Apache DataFusion Physical Expressions -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion that provides data types and utilities for physical expressions. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/physical-expr/benches/binary_op.rs b/datafusion/physical-expr/benches/binary_op.rs new file mode 100644 index 0000000000000..5b0f700fdb8aa --- /dev/null +++ b/datafusion/physical-expr/benches/binary_op.rs @@ -0,0 +1,312 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::BooleanArray, + datatypes::{DataType, Field, Schema}, +}; +use arrow::{array::StringArray, record_batch::RecordBatch}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::{and, binary_expr, col, lit, or, Operator}; +use datafusion_physical_expr::{ + expressions::{BinaryExpr, Column}, + planner::logical2physical, + PhysicalExpr, +}; +use std::sync::Arc; + +/// Generates BooleanArrays with different true/false distributions for benchmarking. +/// +/// Returns a vector of tuples containing scenario name and corresponding BooleanArray. +/// +/// # Arguments +/// - `TEST_ALL_FALSE` - Used to generate what kind of test data +/// - `len` - Length of the BooleanArray to generate +fn generate_boolean_cases( + len: usize, +) -> Vec<(String, BooleanArray)> { + let mut cases = Vec::with_capacity(6); + + // Scenario 1: All elements false or all elements true + if TEST_ALL_FALSE { + let all_false = BooleanArray::from(vec![false; len]); + cases.push(("all_false".to_string(), all_false)); + } else { + let all_true = BooleanArray::from(vec![true; len]); + cases.push(("all_true".to_string(), all_true)); + } + + // Scenario 2: Single true at first position or single false at first position + if TEST_ALL_FALSE { + let mut first_true = vec![false; len]; + first_true[0] = true; + cases.push(("one_true_first".to_string(), BooleanArray::from(first_true))); + } else { + let mut first_false = vec![true; len]; + first_false[0] = false; + cases.push(( + "one_false_first".to_string(), + BooleanArray::from(first_false), + )); + } + + // Scenario 3: Single true at last position or single false at last position + if TEST_ALL_FALSE { + let mut last_true = vec![false; len]; + last_true[len - 1] = true; + cases.push(("one_true_last".to_string(), BooleanArray::from(last_true))); + } else { + let mut last_false = vec![true; len]; + last_false[len - 1] = false; + cases.push(("one_false_last".to_string(), BooleanArray::from(last_false))); + } + + // Scenario 4: Single true at exact middle or single false at exact middle + let mid = len / 2; + if TEST_ALL_FALSE { + let mut mid_true = vec![false; len]; + mid_true[mid] = true; + cases.push(("one_true_middle".to_string(), BooleanArray::from(mid_true))); + } else { + let mut mid_false = vec![true; len]; + mid_false[mid] = false; + cases.push(( + "one_false_middle".to_string(), + BooleanArray::from(mid_false), + )); + } + + // Scenario 5: Single true at 25% position or single false at 25% position + let mid_left = len / 4; + if TEST_ALL_FALSE { + let mut mid_left_true = vec![false; len]; + mid_left_true[mid_left] = true; + cases.push(( + "one_true_middle_left".to_string(), + BooleanArray::from(mid_left_true), + )); + } else { + let mut mid_left_false = vec![true; len]; + mid_left_false[mid_left] = false; + cases.push(( + "one_false_middle_left".to_string(), + BooleanArray::from(mid_left_false), + )); + } + + // Scenario 6: Single true at 75% position or single false at 75% position + let mid_right = (3 * len) / 4; + if TEST_ALL_FALSE { + let mut mid_right_true = vec![false; len]; + mid_right_true[mid_right] = true; + cases.push(( + "one_true_middle_right".to_string(), + BooleanArray::from(mid_right_true), + )); + } else { + let mut mid_right_false = vec![true; len]; + mid_right_false[mid_right] = false; + cases.push(( + "one_false_middle_right".to_string(), + BooleanArray::from(mid_right_false), + )); + } + + // Scenario 7: Test all true or all false in AND/OR + // This situation won't cause a short circuit, but it can skip the bool calculation + if TEST_ALL_FALSE { + let all_true = vec![true; len]; + cases.push(("all_true_in_and".to_string(), BooleanArray::from(all_true))); + } else { + let all_false = vec![false; len]; + cases.push(("all_false_in_or".to_string(), BooleanArray::from(all_false))); + } + + cases +} + +/// Benchmarks AND/OR operator short-circuiting by evaluating complex regex conditions. +/// +/// Creates 7 test scenarios per operator: +/// 1. All values enable short-circuit (all_true/all_false) +/// 2. 2-6 Single true/false value at different positions to measure early exit +/// 3. Test all true or all false in AND/OR +/// +/// You can run this benchmark with: +/// ```sh +/// cargo bench --bench binary_op -- short_circuit +/// ``` +fn benchmark_binary_op_in_short_circuit(c: &mut Criterion) { + // Create schema with three columns + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Utf8, false), + ])); + + // Generate test data with extended content + let (b_values, c_values) = generate_test_strings(8192); + + let batches_and = + create_record_batch::(schema.clone(), &b_values, &c_values).unwrap(); + let batches_or = + create_record_batch::(schema.clone(), &b_values, &c_values).unwrap(); + + // Build complex string matching conditions + let right_condition_and = and( + // Check for API endpoint pattern in URLs + binary_expr( + col("b"), + Operator::RegexMatch, + lit(r#"^https://(\w+\.)?example\.(com|org)/"#), + ), + // Check for markdown code blocks and summary section + binary_expr( + col("c"), + Operator::RegexMatch, + lit("```(rust|python|go)\nfn? main$$"), + ), + ); + + let right_condition_or = or( + // Check for secure HTTPS protocol + binary_expr( + col("b"), + Operator::RegexMatch, + lit(r#"^https://(\w+\.)?example\.(com|org)/"#), + ), + // Check for Rust code examples + binary_expr( + col("c"), + Operator::RegexMatch, + lit("```(rust|python|go)\nfn? main$$"), + ), + ); + + // Create physical binary expressions + // a AND ((b ~ regex) AND (c ~ regex)) + let expr_and = BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::And, + logical2physical(&right_condition_and, &schema), + ); + + // a OR ((b ~ regex) OR (c ~ regex)) + let expr_or = BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Or, + logical2physical(&right_condition_or, &schema), + ); + + // Each scenario when the test operator is `and` + { + for (name, batch) in batches_and.into_iter() { + c.bench_function(&format!("short_circuit/and/{name}"), |b| { + b.iter(|| expr_and.evaluate(black_box(&batch)).unwrap()) + }); + } + } + // Each scenario when the test operator is `or` + { + for (name, batch) in batches_or.into_iter() { + c.bench_function(&format!("short_circuit/or/{name}"), |b| { + b.iter(|| expr_or.evaluate(black_box(&batch)).unwrap()) + }); + } + } +} + +/// Generate test data with computationally expensive patterns +fn generate_test_strings(num_rows: usize) -> (Vec, Vec) { + // Extended URL patterns with query parameters and paths + let base_urls = [ + "https://api.example.com/v2/users/12345/posts?category=tech&sort=date&lang=en-US", + "https://cdn.example.net/assets/images/2023/08/15/sample-image-highres.jpg?width=1920&quality=85", + "http://service.demo.org:8080/api/data/transactions/20230815123456.csv", + "ftp://legacy.archive.example/backups/2023/Q3/database-dump.sql.gz", + "https://docs.example.co.uk/reference/advanced-topics/concurrency/parallel-processing.md#implementation-details", + ]; + + // Extended markdown content with code blocks and structure + let base_markdowns = [ + concat!( + "# Advanced Topics in Computer Science\n\n", + "## Summary\nThis article explores complex system design patterns and...\n\n", + "```rust\nfn process_data(data: &mut [i32]) {\n // Parallel processing example\n data.par_iter_mut().for_each(|x| *x *= 2);\n}\n```\n\n", + "## Performance Considerations\nWhen implementing concurrent systems...\n" + ), + concat!( + "## API Documentation\n\n", + "```json\n{\n \"endpoint\": \"/api/v2/users\",\n \"methods\": [\"GET\", \"POST\"],\n \"parameters\": {\n \"page\": \"number\"\n }\n}\n```\n\n", + "# Authentication Guide\nSecure your API access using OAuth 2.0...\n" + ), + concat!( + "# Data Processing Pipeline\n\n", + "```python\nfrom multiprocessing import Pool\n\ndef main():\n with Pool(8) as p:\n results = p.map(process_item, data)\n```\n\n", + "## Summary of Optimizations\n1. Batch processing\n2. Memory pooling\n3. Concurrent I/O operations\n" + ), + concat!( + "# System Architecture Overview\n\n", + "## Components\n- Load Balancer\n- Database Cluster\n- Cache Service\n\n", + "```go\nfunc main() {\n router := gin.Default()\n router.GET(\"/api/health\", healthCheck)\n router.Run(\":8080\")\n}\n```\n" + ), + concat!( + "## Configuration Reference\n\n", + "```yaml\nserver:\n port: 8080\n max_threads: 32\n\ndatabase:\n url: postgres://user@prod-db:5432/main\n```\n\n", + "# Deployment Strategies\nBlue-green deployment patterns with...\n" + ), + ]; + + let mut urls = Vec::with_capacity(num_rows); + let mut markdowns = Vec::with_capacity(num_rows); + + for i in 0..num_rows { + urls.push(base_urls[i % 5].to_string()); + markdowns.push(base_markdowns[i % 5].to_string()); + } + + (urls, markdowns) +} + +/// Creates record batches with boolean arrays that test different short-circuit scenarios. +/// When TEST_ALL_FALSE = true: creates data for AND operator benchmarks (needs early false exit) +/// When TEST_ALL_FALSE = false: creates data for OR operator benchmarks (needs early true exit) +fn create_record_batch( + schema: Arc, + b_values: &[String], + c_values: &[String], +) -> arrow::error::Result> { + // Generate data for six scenarios, but only the data for the "all_false" and "all_true" cases can be optimized through short-circuiting + let boolean_array = generate_boolean_cases::(b_values.len()); + let mut rbs = Vec::with_capacity(boolean_array.len()); + for (name, a_array) in boolean_array { + let b_array = StringArray::from(b_values.to_vec()); + let c_array = StringArray::from(c_values.to_vec()); + rbs.push(( + name, + RecordBatch::try_new( + schema.clone(), + vec![Arc::new(a_array), Arc::new(b_array), Arc::new(c_array)], + )?, + )); + } + Ok(rbs) +} + +criterion_group!(benches, benchmark_binary_op_in_short_circuit); + +criterion_main!(benches); diff --git a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index 90bfc5efb61e8..e91e8d1f137c1 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -21,7 +21,7 @@ use arrow::record_batch::RecordBatch; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; use datafusion_physical_expr::expressions::{col, in_list, lit}; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::prelude::*; use std::sync::Arc; @@ -51,7 +51,7 @@ fn do_benches( for string_length in [5, 10, 20] { let values: StringArray = (0..array_length) .map(|_| { - rng.gen_bool(null_percent) + rng.random_bool(null_percent) .then(|| random_string(&mut rng, string_length)) }) .collect(); @@ -71,11 +71,11 @@ fn do_benches( } let values: Float32Array = (0..array_length) - .map(|_| rng.gen_bool(null_percent).then(|| rng.gen())) + .map(|_| rng.random_bool(null_percent).then(|| rng.random())) .collect(); let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Float32(Some(rng.gen()))) + .map(|_| ScalarValue::Float32(Some(rng.random()))) .collect(); do_bench( @@ -86,11 +86,11 @@ fn do_benches( ); let values: Int32Array = (0..array_length) - .map(|_| rng.gen_bool(null_percent).then(|| rng.gen())) + .map(|_| rng.random_bool(null_percent).then(|| rng.random())) .collect(); let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Int32(Some(rng.gen()))) + .map(|_| ScalarValue::Int32(Some(rng.random()))) .collect(); do_bench( diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index ae3d9050fa628..19d2ecc924ddc 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -28,10 +28,9 @@ pub(crate) mod stats { pub use datafusion_functions_aggregate_common::stats::StatsType; } pub mod utils { - #[allow(deprecated)] // allow adjust_output_array pub use datafusion_functions_aggregate_common::utils::{ - adjust_output_array, get_accum_scalar_values_as_arrays, get_sort_options, - ordering_fields, DecimalAverager, Hashable, + get_accum_scalar_values_as_arrays, get_sort_options, ordering_fields, + DecimalAverager, Hashable, }; } @@ -41,7 +40,7 @@ use std::sync::Arc; use crate::expressions::Column; use arrow::compute::SortOptions; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr::{AggregateUDF, ReversedUDAF, SetMonotonicity}; use datafusion_expr_common::accumulator::Accumulator; @@ -52,8 +51,7 @@ use datafusion_functions_aggregate_common::accumulator::{ }; use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_expr_common::utils::reverse_order_bys; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; /// Builder for physical [`AggregateFunctionExpr`] /// @@ -70,7 +68,7 @@ pub struct AggregateExprBuilder { /// Arrow Schema for the aggregate function schema: SchemaRef, /// The physical order by expressions - ordering_req: LexOrdering, + order_bys: Vec, /// Whether to ignore null values ignore_nulls: bool, /// Whether is distinct aggregate function @@ -87,7 +85,7 @@ impl AggregateExprBuilder { alias: None, human_display: String::default(), schema: Arc::new(Schema::empty()), - ordering_req: LexOrdering::default(), + order_bys: vec![], ignore_nulls: false, is_distinct: false, is_reversed: false, @@ -97,6 +95,97 @@ impl AggregateExprBuilder { /// Constructs an `AggregateFunctionExpr` from the builder /// /// Note that an [`Self::alias`] must be provided before calling this method. + /// + /// # Example: Create an [`AggregateUDF`] + /// + /// In the following example, [`AggregateFunctionExpr`] will be built using [`AggregateExprBuilder`] + /// which provides a build function. Full example could be accessed from the source file. + /// + /// ``` + /// # use std::any::Any; + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, FieldRef}; + /// # use datafusion_common::{Result, ScalarValue}; + /// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility, Expr}; + /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}}; + /// # use arrow::datatypes::Field; + /// # + /// # #[derive(Debug, Clone, PartialEq, Eq, Hash)] + /// # struct FirstValueUdf { + /// # signature: Signature, + /// # } + /// # + /// # impl FirstValueUdf { + /// # fn new() -> Self { + /// # Self { + /// # signature: Signature::any(1, Volatility::Immutable), + /// # } + /// # } + /// # } + /// # + /// # impl AggregateUDFImpl for FirstValueUdf { + /// # fn as_any(&self) -> &dyn Any { + /// # unimplemented!() + /// # } + /// # + /// # fn name(&self) -> &str { + /// # unimplemented!() + /// # } + /// # + /// # fn signature(&self) -> &Signature { + /// # unimplemented!() + /// # } + /// # + /// # fn return_type(&self, args: &[DataType]) -> Result { + /// # unimplemented!() + /// # } + /// # + /// # fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + /// # unimplemented!() + /// # } + /// # + /// # fn state_fields(&self, args: StateFieldsArgs) -> Result> { + /// # unimplemented!() + /// # } + /// # + /// # fn documentation(&self) -> Option<&Documentation> { + /// # unimplemented!() + /// # } + /// # } + /// # + /// # let first_value = AggregateUDF::from(FirstValueUdf::new()); + /// # let expr = first_value.call(vec![col("a")]); + /// # + /// # use datafusion_physical_expr::expressions::Column; + /// # use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + /// # use datafusion_physical_expr::aggregate::AggregateExprBuilder; + /// # use datafusion_physical_expr::expressions::PhysicalSortExpr; + /// # use datafusion_physical_expr::PhysicalSortRequirement; + /// # + /// fn build_aggregate_expr() -> Result<()> { + /// let args = vec![Arc::new(Column::new("a", 0)) as Arc]; + /// let order_by = vec![PhysicalSortExpr { + /// expr: Arc::new(Column::new("x", 1)) as Arc, + /// options: Default::default(), + /// }]; + /// + /// let first_value = AggregateUDF::from(FirstValueUdf::new()); + /// + /// let aggregate_expr = AggregateExprBuilder::new( + /// Arc::new(first_value), + /// args + /// ) + /// .order_by(order_by) + /// .alias("first_a_by_x") + /// .ignore_nulls() + /// .build()?; + /// + /// Ok(()) + /// } + /// ``` + /// + /// This creates a physical expression equivalent to SQL: + /// `first_value(a ORDER BY x) IGNORE NULLS AS first_a_by_x` pub fn build(self) -> Result { let Self { fun, @@ -104,7 +193,7 @@ impl AggregateExprBuilder { alias, human_display, schema, - ordering_req, + order_bys, ignore_nulls, is_distinct, is_reversed, @@ -113,30 +202,25 @@ impl AggregateExprBuilder { return internal_err!("args should not be empty"); } - let mut ordering_fields = vec![]; - - if !ordering_req.is_empty() { - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(&schema)) - .collect::>>()?; + let ordering_types = order_bys + .iter() + .map(|e| e.expr.data_type(&schema)) + .collect::>>()?; - ordering_fields = - utils::ordering_fields(ordering_req.as_ref(), &ordering_types); - } + let ordering_fields = utils::ordering_fields(&order_bys, &ordering_types); - let input_exprs_types = args + let input_exprs_fields = args .iter() - .map(|arg| arg.data_type(&schema)) + .map(|arg| arg.return_field(&schema)) .collect::>>()?; check_arg_count( fun.name(), - &input_exprs_types, + &input_exprs_fields, &fun.signature().type_signature, )?; - let data_type = fun.return_type(&input_exprs_types)?; + let return_field = fun.return_field(&input_exprs_fields)?; let is_nullable = fun.is_nullable(); let name = match alias { None => { @@ -150,15 +234,15 @@ impl AggregateExprBuilder { Ok(AggregateFunctionExpr { fun: Arc::unwrap_or_clone(fun), args, - data_type, + return_field, name, human_display, schema: Arc::unwrap_or_clone(schema), - ordering_req, + order_bys, ignore_nulls, ordering_fields, is_distinct, - input_types: input_exprs_types, + input_fields: input_exprs_fields, is_reversed, is_nullable, }) @@ -179,8 +263,8 @@ impl AggregateExprBuilder { self } - pub fn order_by(mut self, order_by: LexOrdering) -> Self { - self.ordering_req = order_by; + pub fn order_by(mut self, order_bys: Vec) -> Self { + self.order_bys = order_bys; self } @@ -222,22 +306,22 @@ impl AggregateExprBuilder { pub struct AggregateFunctionExpr { fun: AggregateUDF, args: Vec>, - /// Output / return type of this aggregate - data_type: DataType, + /// Output / return field of this aggregate + return_field: FieldRef, /// Output column name that this expression creates name: String, /// Simplified name for `tree` explain. human_display: String, schema: Schema, // The physical order by expressions - ordering_req: LexOrdering, + order_bys: Vec, // Whether to ignore null values ignore_nulls: bool, // fields used for order sensitive aggregation functions - ordering_fields: Vec, + ordering_fields: Vec, is_distinct: bool, is_reversed: bool, - input_types: Vec, + input_fields: Vec, is_nullable: bool, } @@ -284,8 +368,12 @@ impl AggregateFunctionExpr { } /// the field of the final result of this aggregation. - pub fn field(&self) -> Field { - Field::new(&self.name, self.data_type.clone(), self.is_nullable) + pub fn field(&self) -> FieldRef { + self.return_field + .as_ref() + .clone() + .with_name(&self.name) + .into() } /// the accumulator used to accumulate values from the expressions. @@ -293,10 +381,10 @@ impl AggregateFunctionExpr { /// return states with the same description as `state_fields` pub fn create_accumulator(&self) -> Result> { let acc_args = AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: self.ordering_req.as_ref(), + order_bys: self.order_bys.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -307,11 +395,11 @@ impl AggregateFunctionExpr { } /// the field of the final result of this aggregation. - pub fn state_fields(&self) -> Result> { + pub fn state_fields(&self) -> Result> { let args = StateFieldsArgs { name: &self.name, - input_types: &self.input_types, - return_type: &self.data_type, + input_fields: &self.input_fields, + return_field: Arc::clone(&self.return_field), ordering_fields: &self.ordering_fields, is_distinct: self.is_distinct, }; @@ -319,31 +407,24 @@ impl AggregateFunctionExpr { self.fun.state_fields(args) } - /// Order by requirements for the aggregate function - /// By default it is `None` (there is no requirement) - /// Order-sensitive aggregators, such as `FIRST_VALUE(x ORDER BY y)` should implement this - pub fn order_bys(&self) -> Option<&LexOrdering> { - if self.ordering_req.is_empty() { - return None; - } - - if !self.order_sensitivity().is_insensitive() { - return Some(self.ordering_req.as_ref()); + /// Returns the ORDER BY expressions for the aggregate function. + pub fn order_bys(&self) -> &[PhysicalSortExpr] { + if self.order_sensitivity().is_insensitive() { + &[] + } else { + &self.order_bys } - - None } /// Indicates whether aggregator can produce the correct result with any /// arbitrary input ordering. By default, we assume that aggregate expressions /// are order insensitive. pub fn order_sensitivity(&self) -> AggregateOrderSensitivity { - if !self.ordering_req.is_empty() { - // If there is requirement, use the sensitivity of the implementation - self.fun.order_sensitivity() - } else { - // If no requirement, aggregator is order insensitive + if self.order_bys.is_empty() { AggregateOrderSensitivity::Insensitive + } else { + // If there is an ORDER BY clause, use the sensitivity of the implementation: + self.fun.order_sensitivity() } } @@ -371,7 +452,7 @@ impl AggregateFunctionExpr { }; AggregateExprBuilder::new(Arc::new(updated_fn), self.args.to_vec()) - .order_by(self.ordering_req.clone()) + .order_by(self.order_bys.clone()) .schema(Arc::new(self.schema.clone())) .alias(self.name().to_string()) .with_ignore_nulls(self.ignore_nulls) @@ -384,10 +465,10 @@ impl AggregateFunctionExpr { /// Creates accumulator implementation that supports retract pub fn create_sliding_accumulator(&self) -> Result> { let args = AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: self.ordering_req.as_ref(), + order_bys: self.order_bys.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -453,10 +534,10 @@ impl AggregateFunctionExpr { /// `[Self::create_groups_accumulator`] will be called. pub fn groups_accumulator_supported(&self) -> bool { let args = AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: self.ordering_req.as_ref(), + order_bys: self.order_bys.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -472,10 +553,10 @@ impl AggregateFunctionExpr { /// implemented in addition to [`Accumulator`]. pub fn create_groups_accumulator(&self) -> Result> { let args = AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: self.ordering_req.as_ref(), + order_bys: self.order_bys.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -493,18 +574,16 @@ impl AggregateFunctionExpr { ReversedUDAF::NotSupported => None, ReversedUDAF::Identical => Some(self.clone()), ReversedUDAF::Reversed(reverse_udf) => { - let reverse_ordering_req = reverse_order_bys(self.ordering_req.as_ref()); let mut name = self.name().to_string(); // If the function is changed, we need to reverse order_by clause as well // i.e. First(a order by b asc null first) -> Last(a order by b desc null last) - if self.fun().name() == reverse_udf.name() { - } else { + if self.fun().name() != reverse_udf.name() { replace_order_by_clause(&mut name); } replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name()); AggregateExprBuilder::new(reverse_udf, self.args.to_vec()) - .order_by(reverse_ordering_req) + .order_by(self.order_bys.iter().map(|e| e.reverse()).collect()) .schema(Arc::new(self.schema.clone())) .alias(name) .with_ignore_nulls(self.ignore_nulls) @@ -520,14 +599,11 @@ impl AggregateFunctionExpr { /// These expressions are (1)function arguments, (2) order by expressions. pub fn all_expressions(&self) -> AggregatePhysicalExpressions { let args = self.expressions(); - let order_bys = self + let order_by_exprs = self .order_bys() - .cloned() - .unwrap_or_else(LexOrdering::default); - let order_by_exprs = order_bys .iter() .map(|sort_expr| Arc::clone(&sort_expr.expr)) - .collect::>(); + .collect(); AggregatePhysicalExpressions { args, order_by_exprs, @@ -539,10 +615,42 @@ impl AggregateFunctionExpr { /// Returns `Some(Arc)` if re-write is supported, otherwise returns `None`. pub fn with_new_expressions( &self, - _args: Vec>, - _order_by_exprs: Vec>, + args: Vec>, + order_by_exprs: Vec>, ) -> Option { - None + if args.len() != self.args.len() + || (self.order_sensitivity() != AggregateOrderSensitivity::Insensitive + && order_by_exprs.len() != self.order_bys.len()) + { + return None; + } + + let new_order_bys = self + .order_bys + .iter() + .zip(order_by_exprs) + .map(|(req, new_expr)| PhysicalSortExpr { + expr: new_expr, + options: req.options, + }) + .collect(); + + Some(AggregateFunctionExpr { + fun: self.fun.clone(), + args, + return_field: Arc::clone(&self.return_field), + name: self.name.clone(), + // TODO: Human name should be updated after re-write to not mislead + human_display: self.human_display.clone(), + schema: self.schema.clone(), + order_bys: new_order_bys, + ignore_nulls: self.ignore_nulls, + ordering_fields: self.ordering_fields.clone(), + is_distinct: self.is_distinct, + is_reversed: false, + input_fields: self.input_fields.clone(), + is_nullable: self.is_nullable, + }) } /// If this function is max, return (output_field, true) @@ -552,7 +660,7 @@ impl AggregateFunctionExpr { /// output_field is the name of the column produced by this aggregate /// /// Note: this is used to use special aggregate implementations in certain conditions - pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { + pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> { self.fun.is_descending().map(|flag| (self.field(), flag)) } @@ -597,7 +705,7 @@ pub struct AggregatePhysicalExpressions { impl PartialEq for AggregateFunctionExpr { fn eq(&self, other: &Self) -> bool { self.name == other.name - && self.data_type == other.data_type + && self.return_field == other.return_field && self.fun == other.fun && self.args.len() == other.args.len() && self diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index 5abd50f6d1b4f..1d59dab8fd6dd 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -100,7 +100,7 @@ impl ExprBoundaries { ) -> Result { let field = schema.fields().get(col_index).ok_or_else(|| { internal_datafusion_err!( - "Could not create `ExprBoundaries`: in `try_from_column` `col_index` + "Could not create `ExprBoundaries`: in `try_from_column` `col_index` has gone out of bounds with a value of {col_index}, the schema has {} columns.", schema.fields.len() ) @@ -112,7 +112,7 @@ impl ExprBoundaries { .min_value .get_value() .cloned() - .unwrap_or(empty_field.clone()), + .unwrap_or_else(|| empty_field.clone()), col_stats .max_value .get_value() @@ -425,7 +425,7 @@ mod tests { fn test_analyze_invalid_boundary_exprs() { let schema = Arc::new(Schema::new(vec![make_field("a", DataType::Int32)])); let expr = col("a").lt(lit(10)).or(col("a").gt(lit(20))); - let expected_error = "Interval arithmetic does not support the operator OR"; + let expected_error = "OR operator cannot yet propagate true intervals"; let boundaries = ExprBoundaries::try_new_unbounded(&schema).unwrap(); let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap(); let physical_expr = diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs new file mode 100644 index 0000000000000..b434694a20cc8 --- /dev/null +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -0,0 +1,246 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::ScalarFunctionExpr; +use arrow::array::{make_array, MutableArrayData, RecordBatch}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::Result; +use datafusion_common::{internal_err, not_impl_err}; +use datafusion_expr::async_udf::AsyncScalarUDF; +use datafusion_expr::ScalarFunctionArgs; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::any::Any; +use std::fmt::Display; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +/// Wrapper around a scalar function that can be evaluated asynchronously +#[derive(Debug, Clone, Eq)] +pub struct AsyncFuncExpr { + /// The name of the output column this function will generate + pub name: String, + /// The actual function (always `ScalarFunctionExpr`) + pub func: Arc, + /// The field that this function will return + return_field: FieldRef, +} + +impl Display for AsyncFuncExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "async_expr(name={}, expr={})", self.name, self.func) + } +} + +impl PartialEq for AsyncFuncExpr { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.func == Arc::clone(&other.func) + } +} + +impl Hash for AsyncFuncExpr { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.func.as_ref().hash(state); + } +} + +impl AsyncFuncExpr { + /// create a new AsyncFuncExpr + pub fn try_new( + name: impl Into, + func: Arc, + schema: &Schema, + ) -> Result { + let Some(_) = func.as_any().downcast_ref::() else { + return internal_err!( + "unexpected function type, expected ScalarFunctionExpr, got: {:?}", + func + ); + }; + + let return_field = func.return_field(schema)?; + Ok(Self { + name: name.into(), + func, + return_field, + }) + } + + /// return the name of the output column + pub fn name(&self) -> &str { + &self.name + } + + /// Return the output field generated by evaluating this function + pub fn field(&self, input_schema: &Schema) -> Result { + Ok(Field::new( + &self.name, + self.func.data_type(input_schema)?, + self.func.nullable(input_schema)?, + )) + } + + /// Return the ideal batch size for this function + pub fn ideal_batch_size(&self) -> Result> { + if let Some(expr) = self.func.as_any().downcast_ref::() { + if let Some(udf) = + expr.fun().inner().as_any().downcast_ref::() + { + return Ok(udf.ideal_batch_size()); + } + } + not_impl_err!("Can't get ideal_batch_size from {:?}", self.func) + } + + /// This (async) function is called for each record batch to evaluate the LLM expressions + /// + /// The output is the output of evaluating the async expression and the input record batch + pub async fn invoke_with_args( + &self, + batch: &RecordBatch, + config_options: Arc, + ) -> Result { + let Some(scalar_function_expr) = + self.func.as_any().downcast_ref::() + else { + return internal_err!( + "unexpected function type, expected ScalarFunctionExpr, got: {:?}", + self.func + ); + }; + + let Some(async_udf) = scalar_function_expr + .fun() + .inner() + .as_any() + .downcast_ref::() + else { + return not_impl_err!( + "Don't know how to evaluate async function: {:?}", + scalar_function_expr + ); + }; + + let arg_fields = scalar_function_expr + .args() + .iter() + .map(|e| e.return_field(batch.schema_ref())) + .collect::>>()?; + + let mut result_batches = vec![]; + if let Some(ideal_batch_size) = self.ideal_batch_size()? { + let mut remainder = batch.clone(); + while remainder.num_rows() > 0 { + let size = if ideal_batch_size > remainder.num_rows() { + remainder.num_rows() + } else { + ideal_batch_size + }; + + let current_batch = remainder.slice(0, size); // get next 10 rows + remainder = remainder.slice(size, remainder.num_rows() - size); + let args = scalar_function_expr + .args() + .iter() + .map(|e| e.evaluate(¤t_batch)) + .collect::>>()?; + result_batches.push( + async_udf + .invoke_async_with_args(ScalarFunctionArgs { + args, + arg_fields: arg_fields.clone(), + number_rows: current_batch.num_rows(), + return_field: Arc::clone(&self.return_field), + config_options: Arc::clone(&config_options), + }) + .await?, + ); + } + } else { + let args = scalar_function_expr + .args() + .iter() + .map(|e| e.evaluate(batch)) + .collect::>>()?; + + result_batches.push( + async_udf + .invoke_async_with_args(ScalarFunctionArgs { + args: args.to_vec(), + arg_fields, + number_rows: batch.num_rows(), + return_field: Arc::clone(&self.return_field), + config_options: Arc::clone(&config_options), + }) + .await?, + ); + } + + let datas = ColumnarValue::values_to_arrays(&result_batches)? + .iter() + .map(|b| b.to_data()) + .collect::>(); + let total_len = datas.iter().map(|d| d.len()).sum(); + let mut mutable = MutableArrayData::new(datas.iter().collect(), false, total_len); + datas.iter().enumerate().for_each(|(i, data)| { + mutable.extend(i, 0, data.len()); + }); + let array_ref = make_array(mutable.freeze()); + Ok(ColumnarValue::Array(array_ref)) + } +} + +impl PhysicalExpr for AsyncFuncExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> Result { + self.func.data_type(input_schema) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + self.func.nullable(input_schema) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + // TODO: implement this for scalar value input + not_impl_err!("AsyncFuncExpr.evaluate") + } + + fn children(&self) -> Vec<&Arc> { + self.func.children() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + let new_func = Arc::clone(&self.func).with_new_children(children)?; + Ok(Arc::new(AsyncFuncExpr { + name: self.name.clone(), + func: new_func, + return_field: Arc::clone(&self.return_field), + })) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.func) + } +} diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 13a3c79a47a2f..66ce77ef415ef 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -15,30 +15,61 @@ // specific language governing permissions and limitations // under the License. -use super::{add_offset_to_expr, ProjectionMapping}; -use crate::{ - expressions::Column, LexOrdering, LexRequirement, PhysicalExpr, PhysicalExprRef, - PhysicalSortExpr, PhysicalSortRequirement, -}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{JoinType, ScalarValue}; -use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; use std::fmt::Display; +use std::ops::Deref; use std::sync::Arc; use std::vec::IntoIter; +use super::projection::ProjectionTargets; +use super::ProjectionMapping; +use crate::expressions::Literal; +use crate::physical_expr::add_offset_to_expr; +use crate::{PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement}; + +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{HashMap, JoinType, Result, ScalarValue}; +use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; + use indexmap::{IndexMap, IndexSet}; -/// A structure representing a expression known to be constant in a physical execution plan. +/// Represents whether a constant expression's value is uniform or varies across +/// partitions. Has two variants: +/// - `Heterogeneous`: The constant expression may have different values for +/// different partitions. +/// - `Uniform(Option)`: The constant expression has the same value +/// across all partitions, or is `None` if the value is unknown. +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub enum AcrossPartitions { + #[default] + Heterogeneous, + Uniform(Option), +} + +impl Display for AcrossPartitions { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AcrossPartitions::Heterogeneous => write!(f, "(heterogeneous)"), + AcrossPartitions::Uniform(value) => { + if let Some(val) = value { + write!(f, "(uniform: {val})") + } else { + write!(f, "(uniform: unknown)") + } + } + } + } +} + +/// A structure representing a expression known to be constant in a physical +/// execution plan. /// -/// The `ConstExpr` struct encapsulates an expression that is constant during the execution -/// of a query. For example if a predicate like `A = 5` applied earlier in the plan `A` would -/// be known constant +/// The `ConstExpr` struct encapsulates an expression that is constant during +/// the execution of a query. For example if a filter like `A = 5` appears +/// earlier in the plan, `A` would become a constant in subsequent operations. /// /// # Fields /// /// - `expr`: Constant expression for a node in the physical plan. -/// /// - `across_partitions`: A boolean flag indicating whether the constant /// expression is the same across partitions. If set to `true`, the constant /// expression has same value for all partitions. If set to `false`, the @@ -50,108 +81,37 @@ use indexmap::{IndexMap, IndexSet}; /// # use datafusion_physical_expr::ConstExpr; /// # use datafusion_physical_expr::expressions::lit; /// let col = lit(5); -/// // Create a constant expression from a physical expression ref -/// let const_expr = ConstExpr::from(&col); -/// // create a constant expression from a physical expression +/// // Create a constant expression from a physical expression: /// let const_expr = ConstExpr::from(col); /// ``` -// TODO: Consider refactoring the `across_partitions` and `value` fields into an enum: -// -// ``` -// enum PartitionValues { -// Uniform(Option), // Same value across all partitions -// Heterogeneous(Vec>) // Different values per partition -// } -// ``` -// -// This would provide more flexible representation of partition values. -// Note: This is a breaking change for the equivalence API and should be -// addressed in a separate issue/PR. -#[derive(Debug, Clone)] +#[derive(Clone, Debug)] pub struct ConstExpr { - /// The expression that is known to be constant (e.g. a `Column`) - expr: Arc, - /// Does the constant have the same value across all partitions? See - /// struct docs for more details - across_partitions: AcrossPartitions, -} - -#[derive(PartialEq, Clone, Debug)] -/// Represents whether a constant expression's value is uniform or varies across partitions. -/// -/// The `AcrossPartitions` enum is used to describe the nature of a constant expression -/// in a physical execution plan: -/// -/// - `Heterogeneous`: The constant expression may have different values for different partitions. -/// - `Uniform(Option)`: The constant expression has the same value across all partitions, -/// or is `None` if the value is not specified. -pub enum AcrossPartitions { - Heterogeneous, - Uniform(Option), -} - -impl Default for AcrossPartitions { - fn default() -> Self { - Self::Heterogeneous - } -} - -impl PartialEq for ConstExpr { - fn eq(&self, other: &Self) -> bool { - self.across_partitions == other.across_partitions && self.expr.eq(&other.expr) - } + /// The expression that is known to be constant (e.g. a `Column`). + pub expr: Arc, + /// Indicates whether the constant have the same value across all partitions. + pub across_partitions: AcrossPartitions, } +// TODO: The `ConstExpr` definition above can be in an inconsistent state where +// `expr` is a literal but `across_partitions` is not `Uniform`. Consider +// a refactor to ensure that `ConstExpr` is always in a consistent state +// (either by changing type definition, or by API constraints). impl ConstExpr { - /// Create a new constant expression from a physical expression. + /// Create a new constant expression from a physical expression, specifying + /// whether the constant expression is the same across partitions. /// - /// Note you can also use `ConstExpr::from` to create a constant expression - /// from a reference as well - pub fn new(expr: Arc) -> Self { - Self { - expr, - // By default, assume constant expressions are not same across partitions. - across_partitions: Default::default(), + /// Note that you can also use `ConstExpr::from` to create a constant + /// expression from just a physical expression, with the *safe* assumption + /// of heterogenous values across partitions unless the expression is a + /// literal. + pub fn new(expr: Arc, across_partitions: AcrossPartitions) -> Self { + let mut result = ConstExpr::from(expr); + // Override the across partitions specification if the expression is not + // a literal. + if result.across_partitions == AcrossPartitions::Heterogeneous { + result.across_partitions = across_partitions; } - } - - /// Set the `across_partitions` flag - /// - /// See struct docs for more details - pub fn with_across_partitions(mut self, across_partitions: AcrossPartitions) -> Self { - self.across_partitions = across_partitions; - self - } - - /// Is the expression the same across all partitions? - /// - /// See struct docs for more details - pub fn across_partitions(&self) -> AcrossPartitions { - self.across_partitions.clone() - } - - pub fn expr(&self) -> &Arc { - &self.expr - } - - pub fn owned_expr(self) -> Arc { - self.expr - } - - pub fn map(&self, f: F) -> Option - where - F: Fn(&Arc) -> Option>, - { - let maybe_expr = f(&self.expr); - maybe_expr.map(|expr| Self { - expr, - across_partitions: self.across_partitions.clone(), - }) - } - - /// Returns true if this constant expression is equal to the given expression - pub fn eq_expr(&self, other: impl AsRef) -> bool { - self.expr.as_ref() == other.as_ref() + result } /// Returns a [`Display`]able list of `ConstExpr`. @@ -166,7 +126,7 @@ impl ConstExpr { } else { write!(f, ",")?; } - write!(f, "{}", const_expr)?; + write!(f, "{const_expr}")?; } Ok(()) } @@ -175,47 +135,36 @@ impl ConstExpr { } } +impl PartialEq for ConstExpr { + fn eq(&self, other: &Self) -> bool { + self.across_partitions == other.across_partitions && self.expr.eq(&other.expr) + } +} + impl Display for ConstExpr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.expr)?; - match &self.across_partitions { - AcrossPartitions::Heterogeneous => { - write!(f, "(heterogeneous)")?; - } - AcrossPartitions::Uniform(value) => { - if let Some(val) = value { - write!(f, "(uniform: {})", val)?; - } else { - write!(f, "(uniform: unknown)")?; - } - } - } - Ok(()) + write!(f, "{}", self.across_partitions) } } impl From> for ConstExpr { fn from(expr: Arc) -> Self { - Self::new(expr) - } -} - -impl From<&Arc> for ConstExpr { - fn from(expr: &Arc) -> Self { - Self::new(Arc::clone(expr)) + // By default, assume constant expressions are not same across partitions. + // However, if we have a literal, it will have a single value that is the + // same across all partitions. + let across = if let Some(lit) = expr.as_any().downcast_ref::() { + AcrossPartitions::Uniform(Some(lit.value().clone())) + } else { + AcrossPartitions::Heterogeneous + }; + Self { + expr, + across_partitions: across, + } } } -/// Checks whether `expr` is among in the `const_exprs`. -pub fn const_exprs_contains( - const_exprs: &[ConstExpr], - expr: &Arc, -) -> bool { - const_exprs - .iter() - .any(|const_expr| const_expr.expr.eq(expr)) -} - /// An `EquivalenceClass` is a set of [`Arc`]s that are known /// to have the same value for all tuples in a relation. These are generated by /// equality predicates (e.g. `a = b`), typically equi-join conditions and @@ -223,259 +172,361 @@ pub fn const_exprs_contains( /// /// Two `EquivalenceClass`es are equal if they contains the same expressions in /// without any ordering. -#[derive(Debug, Clone)] +#[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct EquivalenceClass { - /// The expressions in this equivalence class. The order doesn't - /// matter for equivalence purposes - /// - exprs: IndexSet>, -} - -impl PartialEq for EquivalenceClass { - /// Returns true if other is equal in the sense - /// of bags (multi-sets), disregarding their orderings. - fn eq(&self, other: &Self) -> bool { - self.exprs.eq(&other.exprs) - } + /// The expressions in this equivalence class. The order doesn't matter for + /// equivalence purposes. + pub(crate) exprs: IndexSet>, + /// Indicates whether the expressions in this equivalence class have a + /// constant value. A `Some` value indicates constant-ness. + pub(crate) constant: Option, } impl EquivalenceClass { - /// Create a new empty equivalence class - pub fn new_empty() -> Self { - Self { - exprs: IndexSet::new(), - } - } - - // Create a new equivalence class from a pre-existing `Vec` - pub fn new(exprs: Vec>) -> Self { - Self { - exprs: exprs.into_iter().collect(), + // Create a new equivalence class from a pre-existing collection. + pub fn new(exprs: impl IntoIterator>) -> Self { + let mut class = Self::default(); + for expr in exprs { + class.push(expr); } - } - - /// Return the inner vector of expressions - pub fn into_vec(self) -> Vec> { - self.exprs.into_iter().collect() + class } /// Return the "canonical" expression for this class (the first element) - /// if any - fn canonical_expr(&self) -> Option> { - self.exprs.iter().next().cloned() + /// if non-empty. + pub fn canonical_expr(&self) -> Option<&Arc> { + self.exprs.iter().next() } /// Insert the expression into this class, meaning it is known to be equal to - /// all other expressions in this class + /// all other expressions in this class. pub fn push(&mut self, expr: Arc) { + if let Some(lit) = expr.as_any().downcast_ref::() { + let expr_across = AcrossPartitions::Uniform(Some(lit.value().clone())); + if let Some(across) = self.constant.as_mut() { + // TODO: Return an error if constant values do not agree. + if *across == AcrossPartitions::Heterogeneous { + *across = expr_across; + } + } else { + self.constant = Some(expr_across); + } + } self.exprs.insert(expr); } - /// Inserts all the expressions from other into this class + /// Inserts all the expressions from other into this class. pub fn extend(&mut self, other: Self) { - for expr in other.exprs { - // use push so entries are deduplicated - self.push(expr); + self.exprs.extend(other.exprs); + match (&self.constant, &other.constant) { + (Some(across), Some(_)) => { + // TODO: Return an error if constant values do not agree. + if across == &AcrossPartitions::Heterogeneous { + self.constant = other.constant; + } + } + (None, Some(_)) => self.constant = other.constant, + (_, None) => {} } } - /// Returns true if this equivalence class contains t expression - pub fn contains(&self, expr: &Arc) -> bool { - self.exprs.contains(expr) - } - - /// Returns true if this equivalence class has any entries in common with `other` + /// Returns whether this equivalence class has any entries in common with + /// `other`. pub fn contains_any(&self, other: &Self) -> bool { - self.exprs.iter().any(|e| other.contains(e)) - } - - /// return the number of items in this class - pub fn len(&self) -> usize { - self.exprs.len() + self.exprs.intersection(&other.exprs).next().is_some() } - /// return true if this class is empty - pub fn is_empty(&self) -> bool { - self.exprs.is_empty() + /// Returns whether this equivalence class is trivial, meaning that it is + /// either empty, or contains a single expression that is not a constant. + /// Such classes are not useful, and can be removed from equivalence groups. + pub fn is_trivial(&self) -> bool { + self.exprs.is_empty() || (self.exprs.len() == 1 && self.constant.is_none()) } - /// Iterate over all elements in this class, in some arbitrary order - pub fn iter(&self) -> impl Iterator> { - self.exprs.iter() - } - - /// Return a new equivalence class that have the specified offset added to - /// each expression (used when schemas are appended such as in joins) - pub fn with_offset(&self, offset: usize) -> Self { - let new_exprs = self + /// Adds the given offset to all columns in the expressions inside this + /// class. This is used when schemas are appended, e.g. in joins. + pub fn try_with_offset(&self, offset: isize) -> Result { + let mut cls = Self::default(); + for expr_result in self .exprs .iter() .cloned() .map(|e| add_offset_to_expr(e, offset)) - .collect(); - Self::new(new_exprs) + { + cls.push(expr_result?); + } + Ok(cls) + } +} + +impl Deref for EquivalenceClass { + type Target = IndexSet>; + + fn deref(&self) -> &Self::Target { + &self.exprs + } +} + +impl IntoIterator for EquivalenceClass { + type Item = Arc; + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.exprs.into_iter() } } impl Display for EquivalenceClass { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "[{}]", format_physical_expr_list(&self.exprs)) + write!(f, "{{")?; + write!(f, "members: {}", format_physical_expr_list(&self.exprs))?; + if let Some(across) = &self.constant { + write!(f, ", constant: {across}")?; + } + write!(f, "}}") + } +} + +impl From for Vec> { + fn from(cls: EquivalenceClass) -> Self { + cls.exprs.into_iter().collect() } } -/// A collection of distinct `EquivalenceClass`es -#[derive(Debug, Clone)] +type AugmentedMapping<'a> = IndexMap< + &'a Arc, + (&'a ProjectionTargets, Option<&'a EquivalenceClass>), +>; + +/// A collection of distinct `EquivalenceClass`es. This object supports fast +/// lookups of expressions and their equivalence classes. +#[derive(Clone, Debug, Default)] pub struct EquivalenceGroup { + /// A mapping from expressions to their equivalence class key. + map: HashMap, usize>, + /// The equivalence classes in this group. classes: Vec, } impl EquivalenceGroup { - /// Creates an empty equivalence group. - pub fn empty() -> Self { - Self { classes: vec![] } - } - /// Creates an equivalence group from the given equivalence classes. - pub fn new(classes: Vec) -> Self { - let mut result = Self { classes }; - result.remove_redundant_entries(); - result - } - - /// Returns how many equivalence classes there are in this group. - pub fn len(&self) -> usize { - self.classes.len() + pub fn new(classes: impl IntoIterator) -> Self { + classes.into_iter().collect::>().into() } - /// Checks whether this equivalence group is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 + /// Adds `expr` as a constant expression to this equivalence group. + pub fn add_constant(&mut self, const_expr: ConstExpr) { + // If the expression is already in an equivalence class, we should + // adjust the constant-ness of the class if necessary: + if let Some(idx) = self.map.get(&const_expr.expr) { + let cls = &mut self.classes[*idx]; + if let Some(across) = cls.constant.as_mut() { + // TODO: Return an error if constant values do not agree. + if *across == AcrossPartitions::Heterogeneous { + *across = const_expr.across_partitions; + } + } else { + cls.constant = Some(const_expr.across_partitions); + } + return; + } + // If the expression is not in any equivalence class, but has the same + // constant value with some class, add it to that class: + if let AcrossPartitions::Uniform(_) = &const_expr.across_partitions { + for (idx, cls) in self.classes.iter_mut().enumerate() { + if cls + .constant + .as_ref() + .is_some_and(|across| const_expr.across_partitions.eq(across)) + { + self.map.insert(Arc::clone(&const_expr.expr), idx); + cls.push(const_expr.expr); + return; + } + } + } + // Otherwise, create a new class with the expression as the only member: + let mut new_class = EquivalenceClass::new(std::iter::once(const_expr.expr)); + if new_class.constant.is_none() { + new_class.constant = Some(const_expr.across_partitions); + } + Self::update_lookup_table(&mut self.map, &new_class, self.classes.len()); + self.classes.push(new_class); } - /// Returns an iterator over the equivalence classes in this group. - pub fn iter(&self) -> impl Iterator { - self.classes.iter() + /// Removes constant expressions that may change across partitions. + /// This method should be used when merging data from different partitions. + /// Returns whether any change was made to the equivalence group. + pub fn clear_per_partition_constants(&mut self) -> bool { + let (mut idx, mut change) = (0, false); + while idx < self.classes.len() { + let cls = &mut self.classes[idx]; + if let Some(AcrossPartitions::Heterogeneous) = cls.constant { + change = true; + if cls.len() == 1 { + // If this class becomes trivial, remove it entirely: + self.remove_class_at_idx(idx); + continue; + } else { + cls.constant = None; + } + } + idx += 1; + } + change } - /// Adds the equality `left` = `right` to this equivalence group. - /// New equality conditions often arise after steps like `Filter(a = b)`, - /// `Alias(a, a as b)` etc. + /// Adds the equality `left` = `right` to this equivalence group. New + /// equality conditions often arise after steps like `Filter(a = b)`, + /// `Alias(a, a as b)` etc. Returns whether the given equality defines + /// a new equivalence class. pub fn add_equal_conditions( &mut self, - left: &Arc, - right: &Arc, - ) { - let mut first_class = None; - let mut second_class = None; - for (idx, cls) in self.classes.iter().enumerate() { - if cls.contains(left) { - first_class = Some(idx); - } - if cls.contains(right) { - second_class = Some(idx); - } - } + left: Arc, + right: Arc, + ) -> bool { + let first_class = self.map.get(&left).copied(); + let second_class = self.map.get(&right).copied(); match (first_class, second_class) { (Some(mut first_idx), Some(mut second_idx)) => { // If the given left and right sides belong to different classes, // we should unify/bridge these classes. - if first_idx != second_idx { - // By convention, make sure `second_idx` is larger than `first_idx`. - if first_idx > second_idx { - (first_idx, second_idx) = (second_idx, first_idx); + match first_idx.cmp(&second_idx) { + // The equality is already known, return and signal this: + std::cmp::Ordering::Equal => return false, + // Swap indices to ensure `first_idx` is the lesser index. + std::cmp::Ordering::Greater => { + std::mem::swap(&mut first_idx, &mut second_idx); } - // Remove the class at `second_idx` and merge its values with - // the class at `first_idx`. The convention above makes sure - // that `first_idx` is still valid after removing `second_idx`. - let other_class = self.classes.swap_remove(second_idx); - self.classes[first_idx].extend(other_class); + _ => {} } + // Remove the class at `second_idx` and merge its values with + // the class at `first_idx`. The convention above makes sure + // that `first_idx` is still valid after removing `second_idx`. + let other_class = self.remove_class_at_idx(second_idx); + // Update the lookup table for the second class: + Self::update_lookup_table(&mut self.map, &other_class, first_idx); + self.classes[first_idx].extend(other_class); } (Some(group_idx), None) => { // Right side is new, extend left side's class: - self.classes[group_idx].push(Arc::clone(right)); + self.map.insert(Arc::clone(&right), group_idx); + self.classes[group_idx].push(right); } (None, Some(group_idx)) => { // Left side is new, extend right side's class: - self.classes[group_idx].push(Arc::clone(left)); + self.map.insert(Arc::clone(&left), group_idx); + self.classes[group_idx].push(left); } (None, None) => { // None of the expressions is among existing classes. // Create a new equivalence class and extend the group. - self.classes.push(EquivalenceClass::new(vec![ - Arc::clone(left), - Arc::clone(right), - ])); + let class = EquivalenceClass::new([left, right]); + Self::update_lookup_table(&mut self.map, &class, self.classes.len()); + self.classes.push(class); + return true; } } + false } - /// Removes redundant entries from this group. - fn remove_redundant_entries(&mut self) { - // Remove duplicate entries from each equivalence class: - self.classes.retain_mut(|cls| { - // Keep groups that have at least two entries as singleton class is - // meaningless (i.e. it contains no non-trivial information): - cls.len() > 1 - }); - // Unify/bridge groups that have common expressions: - self.bridge_classes() + /// Removes the equivalence class at the given index from this group. + fn remove_class_at_idx(&mut self, idx: usize) -> EquivalenceClass { + // Remove the class at the given index: + let cls = self.classes.swap_remove(idx); + // Remove its entries from the lookup table: + for expr in cls.iter() { + self.map.remove(expr); + } + // Update the lookup table for the moved class: + if idx < self.classes.len() { + Self::update_lookup_table(&mut self.map, &self.classes[idx], idx); + } + cls + } + + /// Updates the entry in lookup table for the given equivalence class with + /// the given index. + fn update_lookup_table( + map: &mut HashMap, usize>, + cls: &EquivalenceClass, + idx: usize, + ) { + for expr in cls.iter() { + map.insert(Arc::clone(expr), idx); + } + } + + /// Removes redundant entries from this group. Returns whether any change + /// was made to the equivalence group. + fn remove_redundant_entries(&mut self) -> bool { + // First, remove trivial equivalence classes: + let mut change = false; + for idx in (0..self.classes.len()).rev() { + if self.classes[idx].is_trivial() { + self.remove_class_at_idx(idx); + change = true; + } + } + // Then, unify/bridge groups that have common expressions: + self.bridge_classes() || change } /// This utility function unifies/bridges classes that have common expressions. /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`. /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all /// equal and belong to one class. This utility converts merges such classes. - fn bridge_classes(&mut self) { - let mut idx = 0; - while idx < self.classes.len() { - let mut next_idx = idx + 1; - let start_size = self.classes[idx].len(); - while next_idx < self.classes.len() { - if self.classes[idx].contains_any(&self.classes[next_idx]) { - let extension = self.classes.swap_remove(next_idx); + /// Returns whether any change was made to the equivalence group. + fn bridge_classes(&mut self) -> bool { + let (mut idx, mut change) = (0, false); + 'scan: while idx < self.classes.len() { + for other_idx in (idx + 1..self.classes.len()).rev() { + if self.classes[idx].contains_any(&self.classes[other_idx]) { + let extension = self.remove_class_at_idx(other_idx); + Self::update_lookup_table(&mut self.map, &extension, idx); self.classes[idx].extend(extension); - } else { - next_idx += 1; + change = true; + continue 'scan; } } - if self.classes[idx].len() > start_size { - continue; - } idx += 1; } + change } /// Extends this equivalence group with the `other` equivalence group. - pub fn extend(&mut self, other: Self) { + /// Returns whether any equivalence classes were unified/bridged as a + /// result of the extension process. + pub fn extend(&mut self, other: Self) -> bool { + for (idx, cls) in other.classes.iter().enumerate() { + // Update the lookup table for the new class: + Self::update_lookup_table(&mut self.map, cls, idx); + } self.classes.extend(other.classes); - self.remove_redundant_entries(); + self.bridge_classes() } - /// Normalizes the given physical expression according to this group. - /// The expression is replaced with the first expression in the equivalence - /// class it matches with (if any). + /// Normalizes the given physical expression according to this group. The + /// expression is replaced with the first (canonical) expression in the + /// equivalence class it matches with (if any). pub fn normalize_expr(&self, expr: Arc) -> Arc { expr.transform(|expr| { - for cls in self.iter() { - if cls.contains(&expr) { - // The unwrap below is safe because the guard above ensures - // that the class is not empty. - return Ok(Transformed::yes(cls.canonical_expr().unwrap())); - } - } - Ok(Transformed::no(expr)) + let cls = self.get_equivalence_class(&expr); + let Some(canonical) = cls.and_then(|cls| cls.canonical_expr()) else { + return Ok(Transformed::no(expr)); + }; + Ok(Transformed::yes(Arc::clone(canonical))) }) .data() .unwrap() // The unwrap above is safe because the closure always returns `Ok`. } - /// Normalizes the given sort expression according to this group. - /// The underlying physical expression is replaced with the first expression - /// in the equivalence class it matches with (if any). If the underlying - /// expression does not belong to any equivalence class in this group, returns - /// the sort expression as is. + /// Normalizes the given sort expression according to this group. The + /// underlying physical expression is replaced with the first expression in + /// the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, + /// returns the sort expression as is. pub fn normalize_sort_expr( &self, mut sort_expr: PhysicalSortExpr, @@ -484,11 +535,29 @@ impl EquivalenceGroup { sort_expr } - /// Normalizes the given sort requirement according to this group. - /// The underlying physical expression is replaced with the first expression - /// in the equivalence class it matches with (if any). If the underlying - /// expression does not belong to any equivalence class in this group, returns - /// the given sort requirement as is. + /// Normalizes the given sort expressions (i.e. `sort_exprs`) by: + /// - Replacing sections that belong to some equivalence class in the + /// with the first entry in the matching equivalence class. + /// - Removing expressions that have a constant value. + /// + /// If columns `a` and `b` are known to be equal, `d` is known to be a + /// constant, and `sort_exprs` is `[b ASC, d DESC, c ASC, a ASC]`, this + /// function would return `[a ASC, c ASC, a ASC]`. + pub fn normalize_sort_exprs<'a>( + &'a self, + sort_exprs: impl IntoIterator + 'a, + ) -> impl Iterator + 'a { + sort_exprs + .into_iter() + .map(|sort_expr| self.normalize_sort_expr(sort_expr)) + .filter(|sort_expr| self.is_expr_constant(&sort_expr.expr).is_none()) + } + + /// Normalizes the given sort requirement according to this group. The + /// underlying physical expression is replaced with the first expression in + /// the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, + /// returns the given sort requirement as is. pub fn normalize_sort_requirement( &self, mut sort_requirement: PhysicalSortRequirement, @@ -497,44 +566,81 @@ impl EquivalenceGroup { sort_requirement } - /// This function applies the `normalize_expr` function for all expressions - /// in `exprs` and returns the corresponding normalized physical expressions. - pub fn normalize_exprs( - &self, - exprs: impl IntoIterator>, - ) -> Vec> { - exprs + /// Normalizes the given sort requirements (i.e. `sort_reqs`) by: + /// - Replacing sections that belong to some equivalence class in the + /// with the first entry in the matching equivalence class. + /// - Removing expressions that have a constant value. + /// + /// If columns `a` and `b` are known to be equal, `d` is known to be a + /// constant, and `sort_reqs` is `[b ASC, d DESC, c ASC, a ASC]`, this + /// function would return `[a ASC, c ASC, a ASC]`. + pub fn normalize_sort_requirements<'a>( + &'a self, + sort_reqs: impl IntoIterator + 'a, + ) -> impl Iterator + 'a { + sort_reqs .into_iter() - .map(|expr| self.normalize_expr(expr)) - .collect() + .map(|req| self.normalize_sort_requirement(req)) + .filter(|req| self.is_expr_constant(&req.expr).is_none()) } - /// This function applies the `normalize_sort_expr` function for all sort - /// expressions in `sort_exprs` and returns the corresponding normalized - /// sort expressions. - pub fn normalize_sort_exprs(&self, sort_exprs: &LexOrdering) -> LexOrdering { - // Convert sort expressions to sort requirements: - let sort_reqs = LexRequirement::from(sort_exprs.clone()); - // Normalize the requirements: - let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); - // Convert sort requirements back to sort expressions: - LexOrdering::from(normalized_sort_reqs) + /// Perform an indirect projection of `expr` by consulting the equivalence + /// classes. + fn project_expr_indirect( + aug_mapping: &AugmentedMapping, + expr: &Arc, + ) -> Option> { + // Literals don't need to be projected + if expr.as_any().downcast_ref::().is_some() { + return Some(Arc::clone(expr)); + } + + // The given expression is not inside the mapping, so we try to project + // indirectly using equivalence classes. + for (targets, eq_class) in aug_mapping.values() { + // If we match an equivalent expression to a source expression in + // the mapping, then we can project. For example, if we have the + // mapping `(a as a1, a + c)` and the equivalence `a == b`, + // expression `b` projects to `a1`. + if eq_class.as_ref().is_some_and(|cls| cls.contains(expr)) { + let (target, _) = targets.first(); + return Some(Arc::clone(target)); + } + } + // Project a non-leaf expression by projecting its children. + let children = expr.children(); + if children.is_empty() { + // A leaf expression should be inside the mapping. + return None; + } + children + .into_iter() + .map(|child| { + // First, we try to project children with an exact match. If + // we are unable to do this, we consult equivalence classes. + if let Some((targets, _)) = aug_mapping.get(child) { + // If we match the source, we can project directly: + let (target, _) = targets.first(); + Some(Arc::clone(target)) + } else { + Self::project_expr_indirect(aug_mapping, child) + } + }) + .collect::>>() + .map(|children| Arc::clone(expr).with_new_children(children).unwrap()) } - /// This function applies the `normalize_sort_requirement` function for all - /// requirements in `sort_reqs` and returns the corresponding normalized - /// sort requirements. - pub fn normalize_sort_requirements( - &self, - sort_reqs: &LexRequirement, - ) -> LexRequirement { - LexRequirement::new( - sort_reqs - .iter() - .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) - .collect(), - ) - .collapse() + fn augment_projection_mapping<'a>( + &'a self, + mapping: &'a ProjectionMapping, + ) -> AugmentedMapping<'a> { + mapping + .iter() + .map(|(k, v)| { + let eq_class = self.get_equivalence_class(k); + (k, (v, eq_class)) + }) + .collect() } /// Projects `expr` according to the given projection mapping. @@ -544,81 +650,118 @@ impl EquivalenceGroup { mapping: &ProjectionMapping, expr: &Arc, ) -> Option> { - // First, we try to project expressions with an exact match. If we are - // unable to do this, we consult equivalence classes. - if let Some(target) = mapping.target_expr(expr) { + if let Some(targets) = mapping.get(expr) { // If we match the source, we can project directly: - return Some(target); + let (target, _) = targets.first(); + Some(Arc::clone(target)) } else { - // If the given expression is not inside the mapping, try to project - // expressions considering the equivalence classes. - for (source, target) in mapping.iter() { - // If we match an equivalent expression to `source`, then we can - // project. For example, if we have the mapping `(a as a1, a + c)` - // and the equivalence class `(a, b)`, expression `b` projects to `a1`. - if self - .get_equivalence_class(source) - .is_some_and(|group| group.contains(expr)) - { - return Some(Arc::clone(target)); - } - } - } - // Project a non-leaf expression by projecting its children. - let children = expr.children(); - if children.is_empty() { - // Leaf expression should be inside mapping. - return None; + let aug_mapping = self.augment_projection_mapping(mapping); + Self::project_expr_indirect(&aug_mapping, expr) } - children - .into_iter() - .map(|child| self.project_expr(mapping, child)) - .collect::>>() - .map(|children| Arc::clone(expr).with_new_children(children).unwrap()) + } + + /// Projects `expressions` according to the given projection mapping. + /// This function is similar to [`Self::project_expr`], but projects multiple + /// expressions at once more efficiently than calling `project_expr` for each + /// expression. + pub fn project_expressions<'a>( + &'a self, + mapping: &'a ProjectionMapping, + expressions: impl IntoIterator> + 'a, + ) -> impl Iterator>> + 'a { + let mut aug_mapping = None; + expressions.into_iter().map(move |expr| { + if let Some(targets) = mapping.get(expr) { + // If we match the source, we can project directly: + let (target, _) = targets.first(); + Some(Arc::clone(target)) + } else { + let aug_mapping = aug_mapping + .get_or_insert_with(|| self.augment_projection_mapping(mapping)); + Self::project_expr_indirect(aug_mapping, expr) + } + }) } /// Projects this equivalence group according to the given projection mapping. pub fn project(&self, mapping: &ProjectionMapping) -> Self { - let projected_classes = self.iter().filter_map(|cls| { - let new_class = cls - .iter() - .filter_map(|expr| self.project_expr(mapping, expr)) - .collect::>(); - (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) + let projected_classes = self.iter().map(|cls| { + let new_exprs = self.project_expressions(mapping, cls.iter()); + EquivalenceClass::new(new_exprs.flatten()) }); // The key is the source expression, and the value is the equivalence // class that contains the corresponding target expression. - let mut new_classes: IndexMap<_, _> = IndexMap::new(); - for (source, target) in mapping.iter() { + let mut new_constants = vec![]; + let mut new_classes = IndexMap::<_, EquivalenceClass>::new(); + for (source, targets) in mapping.iter() { // We need to find equivalent projected expressions. For example, // consider a table with columns `[a, b, c]` with `a` == `b`, and // projection `[a + c, b + c]`. To conclude that `a + c == b + c`, // we first normalize all source expressions in the mapping, then // merge all equivalent expressions into the classes. let normalized_expr = self.normalize_expr(Arc::clone(source)); - new_classes - .entry(normalized_expr) - .or_insert_with(EquivalenceClass::new_empty) - .push(Arc::clone(target)); + let cls = new_classes.entry(normalized_expr).or_default(); + for (target, _) in targets.iter() { + cls.push(Arc::clone(target)); + } + // Save new constants arising from the projection: + if let Some(across) = self.is_expr_constant(source) { + for (target, _) in targets.iter() { + let const_expr = ConstExpr::new(Arc::clone(target), across.clone()); + new_constants.push(const_expr); + } + } } - // Only add equivalence classes with at least two members as singleton - // equivalence classes are meaningless. - let new_classes = new_classes - .into_iter() - .filter_map(|(_, cls)| (cls.len() > 1).then_some(cls)); - let classes = projected_classes.chain(new_classes).collect(); - Self::new(classes) + // Union projected classes with new classes to make up the result: + let classes = projected_classes + .chain(new_classes.into_values()) + .filter(|cls| !cls.is_trivial()); + let mut result = Self::new(classes); + // Add new constants arising from the projection to the equivalence group: + for constant in new_constants { + result.add_constant(constant); + } + result + } + + /// Returns a `Some` value if the expression is constant according to + /// equivalence group, and `None` otherwise. The `Some` variant contains + /// an `AcrossPartitions` value indicating whether the expression is + /// constant across partitions, and its actual value (if available). + pub fn is_expr_constant( + &self, + expr: &Arc, + ) -> Option { + if let Some(lit) = expr.as_any().downcast_ref::() { + return Some(AcrossPartitions::Uniform(Some(lit.value().clone()))); + } + if let Some(cls) = self.get_equivalence_class(expr) { + if cls.constant.is_some() { + return cls.constant.clone(); + } + } + // TODO: This function should be able to return values of non-literal + // complex constants as well; e.g. it should return `8` for the + // expression `3 + 5`, not an unknown `heterogenous` value. + let children = expr.children(); + if children.is_empty() { + return None; + } + for child in children { + self.is_expr_constant(child)?; + } + Some(AcrossPartitions::Heterogeneous) } /// Returns the equivalence class containing `expr`. If no equivalence class /// contains `expr`, returns `None`. - fn get_equivalence_class( + pub fn get_equivalence_class( &self, expr: &Arc, ) -> Option<&EquivalenceClass> { - self.iter().find(|cls| cls.contains(expr)) + self.map.get(expr).map(|idx| &self.classes[*idx]) } /// Combine equivalence groups of the given join children. @@ -628,18 +771,16 @@ impl EquivalenceGroup { join_type: &JoinType, left_size: usize, on: &[(PhysicalExprRef, PhysicalExprRef)], - ) -> Self { - match join_type { + ) -> Result { + let group = match join_type { JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { let mut result = Self::new( - self.iter() - .cloned() - .chain( - right_equivalences - .iter() - .map(|cls| cls.with_offset(left_size)), - ) - .collect(), + self.iter().cloned().chain( + right_equivalences + .iter() + .map(|cls| cls.try_with_offset(left_size as _)) + .collect::>>()?, + ), ); // In we have an inner join, expressions in the "on" condition // are equal in the resulting table. @@ -647,36 +788,25 @@ impl EquivalenceGroup { for (lhs, rhs) in on.iter() { let new_lhs = Arc::clone(lhs); // Rewrite rhs to point to the right side of the join: - let new_rhs = Arc::clone(rhs) - .transform(|expr| { - if let Some(column) = - expr.as_any().downcast_ref::() - { - let new_column = Arc::new(Column::new( - column.name(), - column.index() + left_size, - )) - as _; - return Ok(Transformed::yes(new_column)); - } - - Ok(Transformed::no(expr)) - }) - .data() - .unwrap(); - result.add_equal_conditions(&new_lhs, &new_rhs); + let new_rhs = + add_offset_to_expr(Arc::clone(rhs), left_size as _)?; + result.add_equal_conditions(new_lhs, new_rhs); } } result } JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(), - JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), - } + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right_equivalences.clone() + } + }; + Ok(group) } - /// Checks if two expressions are equal either directly or through equivalence classes. - /// For complex expressions (e.g. a + b), checks that the expression trees are structurally - /// identical and their leaf nodes are equivalent either directly or through equivalence classes. + /// Checks if two expressions are equal directly or through equivalence + /// classes. For complex expressions (e.g. `a + b`), checks that the + /// expression trees are structurally identical and their leaf nodes are + /// equivalent either directly or through equivalence classes. pub fn exprs_equal( &self, left: &Arc, @@ -726,16 +856,19 @@ impl EquivalenceGroup { .zip(right_children) .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child)) } +} + +impl Deref for EquivalenceGroup { + type Target = [EquivalenceClass]; - /// Return the inner classes of this equivalence group. - pub fn into_inner(self) -> Vec { - self.classes + fn deref(&self) -> &Self::Target { + &self.classes } } impl IntoIterator for EquivalenceGroup { type Item = EquivalenceClass; - type IntoIter = IntoIter; + type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { self.classes.into_iter() @@ -747,20 +880,37 @@ impl Display for EquivalenceGroup { write!(f, "[")?; let mut iter = self.iter(); if let Some(cls) = iter.next() { - write!(f, "{}", cls)?; + write!(f, "{cls}")?; } for cls in iter { - write!(f, ", {}", cls)?; + write!(f, ", {cls}")?; } write!(f, "]") } } +impl From> for EquivalenceGroup { + fn from(classes: Vec) -> Self { + let mut result = Self { + map: classes + .iter() + .enumerate() + .flat_map(|(idx, cls)| { + cls.iter().map(move |expr| (Arc::clone(expr), idx)) + }) + .collect(), + classes, + }; + result.remove_redundant_entries(); + result + } +} + #[cfg(test)] mod tests { use super::*; use crate::equivalence::tests::create_test_params; - use crate::expressions::{binary, col, lit, BinaryExpr, Literal}; + use crate::expressions::{binary, col, lit, BinaryExpr, Column, Literal}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Result, ScalarValue}; @@ -786,24 +936,32 @@ mod tests { for (entries, expected) in test_cases { let entries = entries .into_iter() - .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(|entry| { + entry.into_iter().map(|idx| { + let c = Column::new(format!("col_{idx}").as_str(), idx); + Arc::new(c) as _ + }) + }) .map(EquivalenceClass::new) .collect::>(); let expected = expected .into_iter() - .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(|entry| { + entry.into_iter().map(|idx| { + let c = Column::new(format!("col_{idx}").as_str(), idx); + Arc::new(c) as _ + }) + }) .map(EquivalenceClass::new) .collect::>(); - let mut eq_groups = EquivalenceGroup::new(entries.clone()); - eq_groups.bridge_classes(); + let eq_groups: EquivalenceGroup = entries.clone().into(); let eq_groups = eq_groups.classes; let err_msg = format!( - "error in test entries: {:?}, expected: {:?}, actual:{:?}", - entries, expected, eq_groups + "error in test entries: {entries:?}, expected: {expected:?}, actual:{eq_groups:?}" ); - assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg); + assert_eq!(eq_groups.len(), expected.len(), "{err_msg}"); for idx in 0..eq_groups.len() { - assert_eq!(&eq_groups[idx], &expected[idx], "{}", err_msg); + assert_eq!(&eq_groups[idx], &expected[idx], "{err_msg}"); } } Ok(()) @@ -811,58 +969,45 @@ mod tests { #[test] fn test_remove_redundant_entries_eq_group() -> Result<()> { + let c = |idx| Arc::new(Column::new(format!("col_{idx}").as_str(), idx)) as _; let entries = [ - EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), - // This group is meaningless should be removed - EquivalenceClass::new(vec![lit(3), lit(3)]), - EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + EquivalenceClass::new([c(1), c(1), lit(20)]), + EquivalenceClass::new([lit(30), lit(30)]), + EquivalenceClass::new([c(2), c(3), c(4)]), ]; // Given equivalences classes are not in succinct form. // Expected form is the most plain representation that is functionally same. let expected = [ - EquivalenceClass::new(vec![lit(1), lit(2)]), - EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + EquivalenceClass::new([c(1), lit(20)]), + EquivalenceClass::new([lit(30)]), + EquivalenceClass::new([c(2), c(3), c(4)]), ]; - let mut eq_groups = EquivalenceGroup::new(entries.to_vec()); - eq_groups.remove_redundant_entries(); - - let eq_groups = eq_groups.classes; - assert_eq!(eq_groups.len(), expected.len()); - assert_eq!(eq_groups.len(), 2); - - assert_eq!(eq_groups[0], expected[0]); - assert_eq!(eq_groups[1], expected[1]); + let eq_groups = EquivalenceGroup::new(entries); + assert_eq!(eq_groups.classes, expected); Ok(()) } #[test] fn test_schema_normalize_expr_with_equivalence() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); + let col_a = Arc::new(Column::new("a", 0)) as Arc; + let col_b = Arc::new(Column::new("b", 1)) as _; + let col_c = Arc::new(Column::new("c", 2)) as _; // Assume that column a and c are aliases. - let (_test_schema, eq_properties) = create_test_params()?; - - let col_a_expr = Arc::new(col_a.clone()) as Arc; - let col_b_expr = Arc::new(col_b.clone()) as Arc; - let col_c_expr = Arc::new(col_c.clone()) as Arc; - // Test cases for equivalence normalization, - // First entry in the tuple is argument, second entry is expected result after normalization. + let (_, eq_properties) = create_test_params()?; + // Test cases for equivalence normalization. First entry in the tuple is + // the argument, second entry is expected result after normalization. let expressions = vec![ // Normalized version of the column a and c should go to a // (by convention all the expressions inside equivalence class are mapped to the first entry // in this case a is the first entry in the equivalence class.) - (&col_a_expr, &col_a_expr), - (&col_c_expr, &col_a_expr), + (Arc::clone(&col_a), Arc::clone(&col_a)), + (col_c, col_a), // Cannot normalize column b - (&col_b_expr, &col_b_expr), + (Arc::clone(&col_b), Arc::clone(&col_b)), ]; let eq_group = eq_properties.eq_group(); for (expr, expected_eq) in expressions { - assert!( - expected_eq.eq(&eq_group.normalize_expr(Arc::clone(expr))), - "error in test: expr: {expr:?}" - ); + assert!(expected_eq.eq(&eq_group.normalize_expr(expr))); } Ok(()) @@ -870,21 +1015,15 @@ mod tests { #[test] fn test_contains_any() { - let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) - as Arc; - let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) - as Arc; - let lit2 = - Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; - let lit1 = - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - - let cls1 = - EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]); - let cls2 = - EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]); - let cls3 = EquivalenceClass::new(vec![Arc::clone(&lit2), Arc::clone(&lit1)]); + let lit_true = Arc::new(Literal::new(ScalarValue::from(true))) as _; + let lit_false = Arc::new(Literal::new(ScalarValue::from(false))) as _; + let col_a_expr = Arc::new(Column::new("a", 0)) as _; + let col_b_expr = Arc::new(Column::new("b", 1)) as _; + let col_c_expr = Arc::new(Column::new("c", 2)) as _; + + let cls1 = EquivalenceClass::new([Arc::clone(&lit_true), col_a_expr]); + let cls2 = EquivalenceClass::new([lit_true, col_b_expr]); + let cls3 = EquivalenceClass::new([col_c_expr, lit_false]); // lit_true is common assert!(cls1.contains_any(&cls2)); @@ -903,21 +1042,19 @@ mod tests { } // Create test columns - let col_a = Arc::new(Column::new("a", 0)) as Arc; - let col_b = Arc::new(Column::new("b", 1)) as Arc; - let col_x = Arc::new(Column::new("x", 2)) as Arc; - let col_y = Arc::new(Column::new("y", 3)) as Arc; + let col_a = Arc::new(Column::new("a", 0)) as _; + let col_b = Arc::new(Column::new("b", 1)) as _; + let col_x = Arc::new(Column::new("x", 2)) as _; + let col_y = Arc::new(Column::new("y", 3)) as _; // Create test literals - let lit_1 = - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; - let lit_2 = - Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let lit_1 = Arc::new(Literal::new(ScalarValue::from(1))) as _; + let lit_2 = Arc::new(Literal::new(ScalarValue::from(2))) as _; // Create equivalence group with classes (a = x) and (b = y) - let eq_group = EquivalenceGroup::new(vec![ - EquivalenceClass::new(vec![Arc::clone(&col_a), Arc::clone(&col_x)]), - EquivalenceClass::new(vec![Arc::clone(&col_b), Arc::clone(&col_y)]), + let eq_group = EquivalenceGroup::new([ + EquivalenceClass::new([Arc::clone(&col_a), Arc::clone(&col_x)]), + EquivalenceClass::new([Arc::clone(&col_b), Arc::clone(&col_y)]), ]); let test_cases = vec![ @@ -967,12 +1104,12 @@ mod tests { Arc::clone(&col_a), Operator::Plus, Arc::clone(&col_b), - )) as Arc, + )) as _, right: Arc::new(BinaryExpr::new( Arc::clone(&col_x), Operator::Plus, Arc::clone(&col_y), - )) as Arc, + )) as _, expected: true, description: "Binary expressions with equivalent operands should be equal", @@ -982,12 +1119,12 @@ mod tests { Arc::clone(&col_a), Operator::Plus, Arc::clone(&col_b), - )) as Arc, + )) as _, right: Arc::new(BinaryExpr::new( Arc::clone(&col_x), Operator::Plus, Arc::clone(&col_a), - )) as Arc, + )) as _, expected: false, description: "Binary expressions with non-equivalent operands should not be equal", @@ -997,12 +1134,12 @@ mod tests { Arc::clone(&col_a), Operator::Plus, Arc::clone(&lit_1), - )) as Arc, + )) as _, right: Arc::new(BinaryExpr::new( Arc::clone(&col_x), Operator::Plus, Arc::clone(&lit_1), - )) as Arc, + )) as _, expected: true, description: "Binary expressions with equivalent column and same literal should be equal", }, @@ -1015,7 +1152,7 @@ mod tests { )), Operator::Multiply, Arc::clone(&lit_1), - )) as Arc, + )) as _, right: Arc::new(BinaryExpr::new( Arc::new(BinaryExpr::new( Arc::clone(&col_x), @@ -1024,7 +1161,7 @@ mod tests { )), Operator::Multiply, Arc::clone(&lit_1), - )) as Arc, + )) as _, expected: true, description: "Nested binary expressions with equivalent operands should be equal", }, @@ -1040,8 +1177,7 @@ mod tests { let actual = eq_group.exprs_equal(&left, &right); assert_eq!( actual, expected, - "{}: Failed comparing {:?} and {:?}, expected {}, got {}", - description, left, right, expected, actual + "{description}: Failed comparing {left:?} and {right:?}, expected {expected}, got {actual}" ); } @@ -1059,36 +1195,36 @@ mod tests { Field::new("b", DataType::Int32, false), Field::new("c", DataType::Int32, false), ])); - let mut group = EquivalenceGroup::empty(); - group.add_equal_conditions(&col("a", &schema)?, &col("b", &schema)?); + let mut group = EquivalenceGroup::default(); + group.add_equal_conditions(col("a", &schema)?, col("b", &schema)?); let projected_schema = Arc::new(Schema::new(vec![ Field::new("a+c", DataType::Int32, false), Field::new("b+c", DataType::Int32, false), ])); - let mapping = ProjectionMapping { - map: vec![ - ( - binary( - col("a", &schema)?, - Operator::Plus, - col("c", &schema)?, - &schema, - )?, - col("a+c", &projected_schema)?, - ), - ( - binary( - col("b", &schema)?, - Operator::Plus, - col("c", &schema)?, - &schema, - )?, - col("b+c", &projected_schema)?, - ), - ], - }; + let mapping = [ + ( + binary( + col("a", &schema)?, + Operator::Plus, + col("c", &schema)?, + &schema, + )?, + vec![(col("a+c", &projected_schema)?, 0)].into(), + ), + ( + binary( + col("b", &schema)?, + Operator::Plus, + col("c", &schema)?, + &schema, + )?, + vec![(col("b+c", &projected_schema)?, 1)].into(), + ), + ] + .into_iter() + .collect::(); let projected = group.project(&mapping); diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index e94d2bad57126..bcc6835e2f6c7 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. +use std::borrow::Borrow; use std::sync::Arc; -use crate::expressions::Column; -use crate::{LexRequirement, PhysicalExpr}; +use crate::PhysicalExpr; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use arrow::compute::SortOptions; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; mod class; mod ordering; @@ -29,55 +30,39 @@ mod properties; pub use class::{AcrossPartitions, ConstExpr, EquivalenceClass, EquivalenceGroup}; pub use ordering::OrderingEquivalenceClass; -pub use projection::ProjectionMapping; +pub use projection::{project_ordering, project_orderings, ProjectionMapping}; pub use properties::{ calculate_union, join_equivalence_properties, EquivalenceProperties, }; -/// This function constructs a duplicate-free `LexOrderingReq` by filtering out -/// duplicate entries that have same physical expression inside. For example, -/// `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a Some(ASC)]`. -/// -/// It will also filter out entries that are ordered if the next entry is; -/// for instance, `vec![floor(a) Some(ASC), a Some(ASC)]` will be collapsed to -/// `vec![a Some(ASC)]`. -#[deprecated(since = "45.0.0", note = "Use LexRequirement::collapse")] -pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { - input.collapse() +// Convert each tuple to a `PhysicalSortExpr` and construct a vector. +pub fn convert_to_sort_exprs>>( + args: &[(T, SortOptions)], +) -> Vec { + args.iter() + .map(|(expr, options)| PhysicalSortExpr::new(Arc::clone(expr.borrow()), *options)) + .collect() } -/// Adds the `offset` value to `Column` indices inside `expr`. This function is -/// generally used during the update of the right table schema in join operations. -pub fn add_offset_to_expr( - expr: Arc, - offset: usize, -) -> Arc { - expr.transform_down(|e| match e.as_any().downcast_ref::() { - Some(col) => Ok(Transformed::yes(Arc::new(Column::new( - col.name(), - offset + col.index(), - )))), - None => Ok(Transformed::no(e)), - }) - .data() - .unwrap() - // Note that we can safely unwrap here since our transform always returns - // an `Ok` value. +// Convert each vector of tuples to a `LexOrdering`. +pub fn convert_to_orderings>>( + args: &[Vec<(T, SortOptions)>], +) -> Vec { + args.iter() + .filter_map(|sort_exprs| LexOrdering::new(convert_to_sort_exprs(sort_exprs))) + .collect() } #[cfg(test)] mod tests { - use super::*; - use crate::expressions::col; - use crate::PhysicalSortExpr; + use crate::expressions::{col, Column}; + use crate::{LexRequirement, PhysicalSortExpr}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::{plan_datafusion_err, Result}; - use datafusion_physical_expr_common::sort_expr::{ - LexOrdering, PhysicalSortRequirement, - }; + use datafusion_common::{plan_err, Result}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortRequirement; /// Converts a string to a physical sort expression /// @@ -97,8 +82,7 @@ mod tests { "ASC" => sort_expr.asc(), "DESC" => sort_expr.desc(), _ => panic!( - "unknown sort options. Expected 'ASC' or 'DESC', got {}", - options + "unknown sort options. Expected 'ASC' or 'DESC', got {options}" ), } } @@ -115,27 +99,21 @@ mod tests { mapping: &ProjectionMapping, input_schema: &Arc, ) -> Result { - // Calculate output schema - let fields: Result> = mapping - .iter() - .map(|(source, target)| { - let name = target - .as_any() - .downcast_ref::() - .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? - .name(); - let field = Field::new( - name, - source.data_type(input_schema)?, - source.nullable(input_schema)?, - ); - - Ok(field) - }) - .collect(); + // Calculate output schema: + let mut fields = vec![]; + for (source, targets) in mapping.iter() { + let data_type = source.data_type(input_schema)?; + let nullable = source.nullable(input_schema)?; + for (target, _) in targets.iter() { + let Some(column) = target.as_any().downcast_ref::() else { + return plan_err!("Expects to have column"); + }; + fields.push(Field::new(column.name(), data_type.clone(), nullable)); + } + } let output_schema = Arc::new(Schema::new_with_metadata( - fields?, + fields, input_schema.metadata().clone(), )); @@ -164,15 +142,15 @@ mod tests { /// Column [a=c] (e.g they are aliases). pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { let test_schema = create_test_schema()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_g = &col("g", &test_schema)?; + let col_a = col("a", &test_schema)?; + let col_b = col("b", &test_schema)?; + let col_c = col("c", &test_schema)?; + let col_d = col("d", &test_schema)?; + let col_e = col("e", &test_schema)?; + let col_f = col("f", &test_schema)?; + let col_g = col("g", &test_schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); - eq_properties.add_equal_conditions(col_a, col_c)?; + eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_c))?; let option_asc = SortOptions { descending: false, @@ -195,68 +173,19 @@ mod tests { ], ]; let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); Ok((test_schema, eq_properties)) } - // Convert each tuple to PhysicalSortRequirement + // Convert each tuple to a `PhysicalSortRequirement` and construct a + // a `LexRequirement` from them. pub fn convert_to_sort_reqs( - in_data: &[(&Arc, Option)], + args: &[(&Arc, Option)], ) -> LexRequirement { - in_data - .iter() - .map(|(expr, options)| { - PhysicalSortRequirement::new(Arc::clone(*expr), *options) - }) - .collect() - } - - // Convert each tuple to PhysicalSortExpr - pub fn convert_to_sort_exprs( - in_data: &[(&Arc, SortOptions)], - ) -> LexOrdering { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(*expr), - options: *options, - }) - .collect() - } - - // Convert each inner tuple to PhysicalSortExpr - pub fn convert_to_orderings( - orderings: &[Vec<(&Arc, SortOptions)>], - ) -> Vec { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) - .collect() - } - - // Convert each tuple to PhysicalSortExpr - pub fn convert_to_sort_exprs_owned( - in_data: &[(Arc, SortOptions)], - ) -> LexOrdering { - LexOrdering::new( - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(expr), - options: *options, - }) - .collect(), - ) - } - - // Convert each inner tuple to PhysicalSortExpr - pub fn convert_to_orderings_owned( - orderings: &[Vec<(Arc, SortOptions)>], - ) -> Vec { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs_owned(sort_exprs)) - .collect() + let exprs = args.iter().map(|(expr, options)| { + PhysicalSortRequirement::new(Arc::clone(*expr), *options) + }); + LexRequirement::new(exprs).unwrap() } #[test] @@ -270,49 +199,49 @@ mod tests { ])); let mut eq_properties = EquivalenceProperties::new(schema); - let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; - let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; - let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + let col_a = Arc::new(Column::new("a", 0)) as _; + let col_b = Arc::new(Column::new("b", 1)) as _; + let col_c = Arc::new(Column::new("c", 2)) as _; + let col_x = Arc::new(Column::new("x", 3)) as _; + let col_y = Arc::new(Column::new("y", 4)) as _; // a and b are aliases - eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_b))?; assert_eq!(eq_properties.eq_group().len(), 1); // This new entry is redundant, size shouldn't increase - eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_a))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 2); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); // b and c are aliases. Existing equivalence class should expand, // however there shouldn't be any new equivalence class - eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_c))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 3); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); + assert!(eq_groups.contains(&col_c)); // This is a new set of equality. Hence equivalent class count should be 2. - eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_y))?; assert_eq!(eq_properties.eq_group().len(), 2); // This equality bridges distinct equality sets. // Hence equivalent class count should decrease from 2 to 1. - eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_a))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 5); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); - assert!(eq_groups.contains(&col_x_expr)); - assert!(eq_groups.contains(&col_y_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); + assert!(eq_groups.contains(&col_c)); + assert!(eq_groups.contains(&col_x)); + assert!(eq_groups.contains(&col_y)); Ok(()) } diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index 0efd46ad912e9..aa65c4a80ae9a 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -16,115 +16,83 @@ // under the License. use std::fmt::Display; -use std::hash::Hash; +use std::ops::Deref; use std::sync::Arc; use std::vec::IntoIter; -use crate::equivalence::add_offset_to_expr; -use crate::{LexOrdering, PhysicalExpr}; +use crate::expressions::with_new_schema; +use crate::{add_offset_to_physical_sort_exprs, LexOrdering, PhysicalExpr}; use arrow::compute::SortOptions; -use datafusion_common::HashSet; +use arrow::datatypes::SchemaRef; +use datafusion_common::{HashSet, Result}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; -/// An `OrderingEquivalenceClass` object keeps track of different alternative -/// orderings than can describe a schema. For example, consider the following table: +/// An `OrderingEquivalenceClass` keeps track of distinct alternative orderings +/// than can describe a table. For example, consider the following table: /// /// ```text -/// |a|b|c|d| -/// |1|4|3|1| -/// |2|3|3|2| -/// |3|1|2|2| -/// |3|2|1|3| +/// ┌───┬───┬───┬───┐ +/// │ a │ b │ c │ d │ +/// ├───┼───┼───┼───┤ +/// │ 1 │ 4 │ 3 │ 1 │ +/// │ 2 │ 3 │ 3 │ 2 │ +/// │ 3 │ 1 │ 2 │ 2 │ +/// │ 3 │ 2 │ 1 │ 3 │ +/// └───┴───┴───┴───┘ /// ``` /// -/// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table +/// Here, both `[a ASC, b ASC]` and `[c DESC, d ASC]` describe the table /// ordering. In this case, we say that these orderings are equivalent. -#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)] +/// +/// An `OrderingEquivalenceClass` is a set of such equivalent orderings, which +/// is represented by a vector of `LexOrdering`s. The set does not store any +/// redundant information by enforcing the invariant that no suffix of an +/// ordering in the equivalence class is a prefix of another ordering in the +/// equivalence class. The set can be empty, which means that there are no +/// orderings that describe the table. +#[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct OrderingEquivalenceClass { orderings: Vec, } impl OrderingEquivalenceClass { - /// Creates new empty ordering equivalence class. - pub fn empty() -> Self { - Default::default() - } - /// Clears (empties) this ordering equivalence class. pub fn clear(&mut self) { self.orderings.clear(); } - /// Creates new ordering equivalence class from the given orderings - /// - /// Any redundant entries are removed - pub fn new(orderings: Vec) -> Self { - let mut result = Self { orderings }; + /// Creates a new ordering equivalence class from the given orderings + /// and removes any redundant entries (if given). + pub fn new( + orderings: impl IntoIterator>, + ) -> Self { + let mut result = Self { + orderings: orderings.into_iter().filter_map(LexOrdering::new).collect(), + }; result.remove_redundant_entries(); result } - /// Converts this OrderingEquivalenceClass to a vector of orderings. - pub fn into_inner(self) -> Vec { - self.orderings - } - - /// Checks whether `ordering` is a member of this equivalence class. - pub fn contains(&self, ordering: &LexOrdering) -> bool { - self.orderings.contains(ordering) - } - - /// Adds `ordering` to this equivalence class. - #[allow(dead_code)] - #[deprecated( - since = "45.0.0", - note = "use OrderingEquivalenceClass::add_new_ordering instead" - )] - fn push(&mut self, ordering: LexOrdering) { - self.add_new_ordering(ordering) - } - - /// Checks whether this ordering equivalence class is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns an iterator over the equivalent orderings in this class. - /// - /// Note this class also implements [`IntoIterator`] to return an iterator - /// over owned [`LexOrdering`]s. - pub fn iter(&self) -> impl Iterator { - self.orderings.iter() - } - - /// Returns how many equivalent orderings there are in this class. - pub fn len(&self) -> usize { - self.orderings.len() - } - - /// Extend this ordering equivalence class with the `other` class. - pub fn extend(&mut self, other: Self) { - self.orderings.extend(other.orderings); + /// Extend this ordering equivalence class with the given orderings. + pub fn extend(&mut self, orderings: impl IntoIterator) { + self.orderings.extend(orderings); // Make sure that there are no redundant orderings: self.remove_redundant_entries(); } - /// Adds new orderings into this ordering equivalence class - pub fn add_new_orderings( + /// Adds new orderings into this ordering equivalence class. + pub fn add_orderings( &mut self, - orderings: impl IntoIterator, + sort_exprs: impl IntoIterator>, ) { - self.orderings.extend(orderings); + self.orderings + .extend(sort_exprs.into_iter().filter_map(LexOrdering::new)); // Make sure that there are no redundant orderings: self.remove_redundant_entries(); } - /// Adds a single ordering to the existing ordering equivalence class. - pub fn add_new_ordering(&mut self, ordering: LexOrdering) { - self.add_new_orderings([ordering]); - } - - /// Removes redundant orderings from this equivalence class. + /// Removes redundant orderings from this ordering equivalence class. /// /// For instance, if we already have the ordering `[a ASC, b ASC, c DESC]`, /// then there is no need to keep ordering `[a ASC, b ASC]` in the state. @@ -133,82 +101,72 @@ impl OrderingEquivalenceClass { while work { work = false; let mut idx = 0; - while idx < self.orderings.len() { + 'outer: while idx < self.orderings.len() { let mut ordering_idx = idx + 1; - let mut removal = self.orderings[idx].is_empty(); while ordering_idx < self.orderings.len() { - work |= self.resolve_overlap(idx, ordering_idx); - if self.orderings[idx].is_empty() { - removal = true; - break; + if let Some(remove) = self.resolve_overlap(idx, ordering_idx) { + work = true; + if remove { + self.orderings.swap_remove(idx); + continue 'outer; + } } - work |= self.resolve_overlap(ordering_idx, idx); - if self.orderings[ordering_idx].is_empty() { - self.orderings.swap_remove(ordering_idx); - } else { - ordering_idx += 1; + if let Some(remove) = self.resolve_overlap(ordering_idx, idx) { + work = true; + if remove { + self.orderings.swap_remove(ordering_idx); + continue; + } } + ordering_idx += 1; } - if removal { - self.orderings.swap_remove(idx); - } else { - idx += 1; - } + idx += 1; } } } /// Trims `orderings[idx]` if some suffix of it overlaps with a prefix of - /// `orderings[pre_idx]`. Returns `true` if there is any overlap, `false` otherwise. + /// `orderings[pre_idx]`. If there is any overlap, returns a `Some(true)` + /// if any trimming took place, and `Some(false)` otherwise. If there is + /// no overlap, returns `None`. /// /// For example, if `orderings[idx]` is `[a ASC, b ASC, c DESC]` and /// `orderings[pre_idx]` is `[b ASC, c DESC]`, then the function will trim /// `orderings[idx]` to `[a ASC]`. - fn resolve_overlap(&mut self, idx: usize, pre_idx: usize) -> bool { + fn resolve_overlap(&mut self, idx: usize, pre_idx: usize) -> Option { let length = self.orderings[idx].len(); let other_length = self.orderings[pre_idx].len(); for overlap in 1..=length.min(other_length) { if self.orderings[idx][length - overlap..] == self.orderings[pre_idx][..overlap] { - self.orderings[idx].truncate(length - overlap); - return true; + return Some(!self.orderings[idx].truncate(length - overlap)); } } - false + None } /// Returns the concatenation of all the orderings. This enables merge /// operations to preserve all equivalent orderings simultaneously. pub fn output_ordering(&self) -> Option { - let output_ordering = self - .orderings - .iter() - .flatten() - .cloned() - .collect::() - .collapse(); - (!output_ordering.is_empty()).then_some(output_ordering) + self.orderings.iter().cloned().reduce(|mut cat, o| { + cat.extend(o); + cat + }) } - // Append orderings in `other` to all existing orderings in this equivalence - // class. + // Append orderings in `other` to all existing orderings in this ordering + // equivalence class. pub fn join_suffix(mut self, other: &Self) -> Self { let n_ordering = self.orderings.len(); - // Replicate entries before cross product + // Replicate entries before cross product: let n_cross = std::cmp::max(n_ordering, other.len() * n_ordering); - self.orderings = self - .orderings - .iter() - .cloned() - .cycle() - .take(n_cross) - .collect(); - // Suffix orderings of other to the current orderings. + self.orderings = self.orderings.into_iter().cycle().take(n_cross).collect(); + // Append sort expressions of `other` to the current orderings: for (outer_idx, ordering) in other.iter().enumerate() { - for idx in 0..n_ordering { - // Calculate cross product index - let idx = outer_idx * n_ordering + idx; + let base = outer_idx * n_ordering; + // Use the cross product index: + for idx in base..(base + n_ordering) { self.orderings[idx].extend(ordering.iter().cloned()); } } @@ -217,12 +175,40 @@ impl OrderingEquivalenceClass { /// Adds `offset` value to the index of each expression inside this /// ordering equivalence class. - pub fn add_offset(&mut self, offset: usize) { - for ordering in self.orderings.iter_mut() { - ordering.transform(|sort_expr| { - sort_expr.expr = add_offset_to_expr(Arc::clone(&sort_expr.expr), offset); - }) + pub fn add_offset(&mut self, offset: isize) -> Result<()> { + let orderings = std::mem::take(&mut self.orderings); + for ordering_result in orderings + .into_iter() + .map(|o| add_offset_to_physical_sort_exprs(o, offset)) + { + self.orderings.extend(LexOrdering::new(ordering_result?)); } + Ok(()) + } + + /// Transforms this `OrderingEquivalenceClass` by mapping columns in the + /// original schema to columns in the new schema by index. The new schema + /// and the original schema needs to be aligned; i.e. they should have the + /// same number of columns, and fields at the same index have the same type + /// in both schemas. + pub fn with_new_schema(mut self, schema: &SchemaRef) -> Result { + self.orderings = self + .orderings + .into_iter() + .map(|ordering| { + ordering + .into_iter() + .map(|mut sort_expr| { + sort_expr.expr = with_new_schema(sort_expr.expr, schema)?; + Ok(sort_expr) + }) + .collect::>>() + // The following `unwrap` is safe because the vector will always + // be non-empty. + .map(|v| LexOrdering::new(v).unwrap()) + }) + .collect::>()?; + Ok(self) } /// Gets sort options associated with this expression if it is a leading @@ -257,31 +243,6 @@ impl OrderingEquivalenceClass { /// added as a constant during `ordering_satisfy_requirement()` iterations /// after the corresponding prefix requirement is satisfied. /// - /// ### Example Scenarios - /// - /// In these scenarios, we assume that all expressions share the same sort - /// properties. - /// - /// #### Case 1: Sort Requirement `[a, c]` - /// - /// **Existing Orderings:** `[[a, b, c], [a, d]]`, **Constants:** `[]` - /// 1. `ordering_satisfy_single()` returns `true` because the requirement - /// `a` is satisfied by `[a, b, c].first()`. - /// 2. `a` is added as a constant for the next iteration. - /// 3. The normalized orderings become `[[b, c], [d]]`. - /// 4. `ordering_satisfy_single()` returns `false` for `c`, as neither - /// `[b, c]` nor `[d]` satisfies `c`. - /// - /// #### Case 2: Sort Requirement `[a, d]` - /// - /// **Existing Orderings:** `[[a, b, c], [a, d]]`, **Constants:** `[]` - /// 1. `ordering_satisfy_single()` returns `true` because the requirement - /// `a` is satisfied by `[a, b, c].first()`. - /// 2. `a` is added as a constant for the next iteration. - /// 3. The normalized orderings become `[[b, c], [d]]`. - /// 4. `ordering_satisfy_single()` returns `true` for `d`, as `[d]` satisfies - /// `d`. - /// /// ### Future Improvements /// /// This function may become unnecessary if any of the following improvements @@ -296,15 +257,14 @@ impl OrderingEquivalenceClass { ]; for ordering in self.iter() { - if let Some(leading_ordering) = ordering.first() { - if leading_ordering.expr.eq(expr) { - let opt = ( - leading_ordering.options.descending, - leading_ordering.options.nulls_first, - ); - constantness_defining_pairs[0].remove(&opt); - constantness_defining_pairs[1].remove(&opt); - } + let leading_ordering = ordering.first(); + if leading_ordering.expr.eq(expr) { + let opt = ( + leading_ordering.options.descending, + leading_ordering.options.nulls_first, + ); + constantness_defining_pairs[0].remove(&opt); + constantness_defining_pairs[1].remove(&opt); } } @@ -314,10 +274,26 @@ impl OrderingEquivalenceClass { } } -/// Convert the `OrderingEquivalenceClass` into an iterator of LexOrderings +impl Deref for OrderingEquivalenceClass { + type Target = [LexOrdering]; + + fn deref(&self) -> &Self::Target { + self.orderings.as_slice() + } +} + +impl From> for OrderingEquivalenceClass { + fn from(orderings: Vec) -> Self { + let mut result = Self { orderings }; + result.remove_redundant_entries(); + result + } +} + +/// Convert the `OrderingEquivalenceClass` into an iterator of `LexOrdering`s. impl IntoIterator for OrderingEquivalenceClass { type Item = LexOrdering; - type IntoIter = IntoIter; + type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { self.orderings.into_iter() @@ -329,13 +305,18 @@ impl Display for OrderingEquivalenceClass { write!(f, "[")?; let mut iter = self.orderings.iter(); if let Some(ordering) = iter.next() { - write!(f, "[{}]", ordering)?; + write!(f, "[{ordering}]")?; } for ordering in iter { - write!(f, ", [{}]", ordering)?; + write!(f, ", [{ordering}]")?; } - write!(f, "]")?; - Ok(()) + write!(f, "]") + } +} + +impl From for Vec { + fn from(oeq_class: OrderingEquivalenceClass) -> Self { + oeq_class.orderings } } @@ -343,12 +324,10 @@ impl Display for OrderingEquivalenceClass { mod tests { use std::sync::Arc; - use crate::equivalence::tests::{ - convert_to_orderings, convert_to_sort_exprs, create_test_schema, - }; + use crate::equivalence::tests::create_test_schema; use crate::equivalence::{ - EquivalenceClass, EquivalenceGroup, EquivalenceProperties, - OrderingEquivalenceClass, + convert_to_orderings, convert_to_sort_exprs, EquivalenceClass, EquivalenceGroup, + EquivalenceProperties, OrderingEquivalenceClass, }; use crate::expressions::{col, BinaryExpr, Column}; use crate::utils::tests::TestScalarUDF; @@ -359,9 +338,9 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::config::ConfigOptions; use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; - use datafusion_physical_expr_common::sort_expr::LexOrdering; #[test] fn test_ordering_satisfy() -> Result<()> { @@ -369,11 +348,11 @@ mod tests { Field::new("a", DataType::Int64, true), Field::new("b", DataType::Int64, true), ])); - let crude = LexOrdering::new(vec![PhysicalSortExpr { + let crude = vec![PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), - }]); - let finer = LexOrdering::new(vec![ + }]; + let finer = vec![ PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), @@ -382,20 +361,18 @@ mod tests { expr: Arc::new(Column::new("b", 1)), options: SortOptions::default(), }, - ]); + ]; // finer ordering satisfies, crude ordering should return true let eq_properties_finer = EquivalenceProperties::new_with_orderings( Arc::clone(&input_schema), - &[finer.clone()], + [finer.clone()], ); - assert!(eq_properties_finer.ordering_satisfy(crude.as_ref())); + assert!(eq_properties_finer.ordering_satisfy(crude.clone())?); // Crude ordering doesn't satisfy finer ordering. should return false - let eq_properties_crude = EquivalenceProperties::new_with_orderings( - Arc::clone(&input_schema), - &[crude.clone()], - ); - assert!(!eq_properties_crude.ordering_satisfy(finer.as_ref())); + let eq_properties_crude = + EquivalenceProperties::new_with_orderings(Arc::clone(&input_schema), [crude]); + assert!(!eq_properties_crude.ordering_satisfy(finer)?); Ok(()) } @@ -414,16 +391,19 @@ mod tests { Arc::clone(&test_fun), vec![Arc::clone(col_a)], &test_schema, + Arc::new(ConfigOptions::default()), )?) as PhysicalExprRef; let floor_f = Arc::new(ScalarFunctionExpr::try_new( Arc::clone(&test_fun), vec![Arc::clone(col_f)], &test_schema, + Arc::new(ConfigOptions::default()), )?) as PhysicalExprRef; let exp_a = Arc::new(ScalarFunctionExpr::try_new( Arc::clone(&test_fun), vec![Arc::clone(col_a)], &test_schema, + Arc::new(ConfigOptions::default()), )?) as PhysicalExprRef; let a_plus_b = Arc::new(BinaryExpr::new( @@ -663,30 +643,20 @@ mod tests { format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); - let eq_group = eq_group + eq_properties.add_orderings(orderings); + let classes = eq_group .into_iter() - .map(|eq_class| { - let eq_classes = eq_class.into_iter().cloned().collect::>(); - EquivalenceClass::new(eq_classes) - }) - .collect::>(); - let eq_group = EquivalenceGroup::new(eq_group); - eq_properties.add_equivalence_group(eq_group); + .map(|eq_class| EquivalenceClass::new(eq_class.into_iter().cloned())); + let eq_group = EquivalenceGroup::new(classes); + eq_properties.add_equivalence_group(eq_group)?; let constants = constants.into_iter().map(|expr| { - ConstExpr::from(expr) - .with_across_partitions(AcrossPartitions::Uniform(None)) + ConstExpr::new(Arc::clone(expr), AcrossPartitions::Uniform(None)) }); - eq_properties = eq_properties.with_constants(constants); + eq_properties.add_constants(constants)?; let reqs = convert_to_sort_exprs(&reqs); - assert_eq!( - eq_properties.ordering_satisfy(reqs.as_ref()), - expected, - "{}", - err_msg - ); + assert_eq!(eq_properties.ordering_satisfy(reqs)?, expected, "{err_msg}"); } Ok(()) @@ -707,7 +677,7 @@ mod tests { }; // a=c (e.g they are aliases). let mut eq_properties = EquivalenceProperties::new(test_schema); - eq_properties.add_equal_conditions(col_a, col_c)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_c))?; let orderings = vec![ vec![(col_a, options)], @@ -717,7 +687,7 @@ mod tests { let orderings = convert_to_orderings(&orderings); // Column [a ASC], [e ASC], [d ASC, f ASC] are all valid orderings for the schema. - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); // First entry in the tuple is required ordering, second entry is the expected flag // that indicates whether this required ordering is satisfied. @@ -739,14 +709,9 @@ mod tests { for (reqs, expected) in test_cases { let err_msg = - format!("error in test reqs: {:?}, expected: {:?}", reqs, expected,); + format!("error in test reqs: {reqs:?}, expected: {expected:?}",); let reqs = convert_to_sort_exprs(&reqs); - assert_eq!( - eq_properties.ordering_satisfy(reqs.as_ref()), - expected, - "{}", - err_msg - ); + assert_eq!(eq_properties.ordering_satisfy(reqs)?, expected, "{err_msg}"); } Ok(()) @@ -856,7 +821,7 @@ mod tests { // ------- TEST CASE 5 --------- // Empty ordering ( - vec![vec![]], + vec![], // No ordering in the state (empty ordering is ignored). vec![], ), @@ -975,13 +940,11 @@ mod tests { for (orderings, expected) in test_cases { let orderings = convert_to_orderings(&orderings); let expected = convert_to_orderings(&expected); - let actual = OrderingEquivalenceClass::new(orderings.clone()); - let actual = actual.orderings; + let actual = OrderingEquivalenceClass::from(orderings.clone()); let err_msg = format!( - "orderings: {:?}, expected: {:?}, actual :{:?}", - orderings, expected, actual + "orderings: {orderings:?}, expected: {expected:?}, actual :{actual:?}" ); - assert_eq!(actual.len(), expected.len(), "{}", err_msg); + assert_eq!(actual.len(), expected.len(), "{err_msg}"); for elem in actual { assert!(expected.contains(&elem), "{}", err_msg); } diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 035678fbf1f39..a4ed8187cfadd 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::ops::Deref; use std::sync::Arc; use crate::expressions::Column; @@ -22,15 +23,55 @@ use crate::PhysicalExpr; use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{internal_err, plan_err, Result}; + +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use indexmap::IndexMap; + +/// Stores target expressions, along with their indices, that associate with a +/// source expression in a projection mapping. +#[derive(Clone, Debug, Default)] +pub struct ProjectionTargets { + /// A non-empty vector of pairs of target expressions and their indices. + /// Consider using a special non-empty collection type in the future (e.g. + /// if Rust provides one in the standard library). + exprs_indices: Vec<(Arc, usize)>, +} + +impl ProjectionTargets { + /// Returns the first target expression and its index. + pub fn first(&self) -> &(Arc, usize) { + // Since the vector is non-empty, we can safely unwrap: + self.exprs_indices.first().unwrap() + } + + /// Adds a target expression and its index to the list of targets. + pub fn push(&mut self, target: (Arc, usize)) { + self.exprs_indices.push(target); + } +} + +impl Deref for ProjectionTargets { + type Target = [(Arc, usize)]; + + fn deref(&self) -> &Self::Target { + &self.exprs_indices + } +} + +impl From, usize)>> for ProjectionTargets { + fn from(exprs_indices: Vec<(Arc, usize)>) -> Self { + Self { exprs_indices } + } +} /// Stores the mapping between source expressions and target expressions for a /// projection. -#[derive(Debug, Clone)] +#[derive(Clone, Debug)] pub struct ProjectionMapping { /// Mapping between source expressions and target expressions. /// Vector indices correspond to the indices after projection. - pub map: Vec<(Arc, Arc)>, + map: IndexMap, ProjectionTargets>, } impl ProjectionMapping { @@ -42,44 +83,46 @@ impl ProjectionMapping { /// projection mapping would be: /// /// ```text - /// [0]: (c + d, col("c + d")) - /// [1]: (a + b, col("a + b")) + /// [0]: (c + d, [(col("c + d"), 0)]) + /// [1]: (a + b, [(col("a + b"), 1)]) /// ``` /// /// where `col("c + d")` means the column named `"c + d"`. pub fn try_new( - expr: &[(Arc, String)], + expr: impl IntoIterator, String)>, input_schema: &SchemaRef, ) -> Result { // Construct a map from the input expressions to the output expression of the projection: - expr.iter() - .enumerate() - .map(|(expr_idx, (expression, name))| { - let target_expr = Arc::new(Column::new(name, expr_idx)) as _; - Arc::clone(expression) - .transform_down(|e| match e.as_any().downcast_ref::() { - Some(col) => { - // Sometimes, an expression and its name in the input_schema - // doesn't match. This can cause problems, so we make sure - // that the expression name matches with the name in `input_schema`. - // Conceptually, `source_expr` and `expression` should be the same. - let idx = col.index(); - let matching_input_field = input_schema.field(idx); - if col.name() != matching_input_field.name() { - return internal_err!("Input field name {} does not match with the projection expression {}", - matching_input_field.name(),col.name()) - } - let matching_input_column = - Column::new(matching_input_field.name(), idx); - Ok(Transformed::yes(Arc::new(matching_input_column))) - } - None => Ok(Transformed::no(e)), - }) - .data() - .map(|source_expr| (source_expr, target_expr)) + let mut map = IndexMap::<_, ProjectionTargets>::new(); + for (expr_idx, (expr, name)) in expr.into_iter().enumerate() { + let target_expr = Arc::new(Column::new(&name, expr_idx)) as _; + let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::() { + Some(col) => { + // Sometimes, an expression and its name in the input_schema + // doesn't match. This can cause problems, so we make sure + // that the expression name matches with the name in `input_schema`. + // Conceptually, `source_expr` and `expression` should be the same. + let idx = col.index(); + let matching_field = input_schema.field(idx); + let matching_name = matching_field.name(); + if col.name() != matching_name { + return internal_err!( + "Input field name {} does not match with the projection expression {}", + matching_name, + col.name() + ); + } + let matching_column = Column::new(matching_name, idx); + Ok(Transformed::yes(Arc::new(matching_column))) + } + None => Ok(Transformed::no(e)), }) - .collect::>>() - .map(|map| Self { map }) + .data()?; + map.entry(source_expr) + .or_default() + .push((target_expr, expr_idx)); + } + Ok(Self { map }) } /// Constructs a subset mapping using the provided indices. @@ -87,67 +130,136 @@ impl ProjectionMapping { /// This is used when the output is a subset of the input without any /// other transformations. The indices are for columns in the schema. pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Result { - let projection_exprs = project_index_to_exprs(indices, schema); - ProjectionMapping::try_new(&projection_exprs, schema) + let projection_exprs = indices.iter().map(|index| { + let field = schema.field(*index); + let column = Arc::new(Column::new(field.name(), *index)); + (column as _, field.name().clone()) + }); + ProjectionMapping::try_new(projection_exprs, schema) } +} + +impl Deref for ProjectionMapping { + type Target = IndexMap, ProjectionTargets>; - /// Iterate over pairs of (source, target) expressions - pub fn iter( - &self, - ) -> impl Iterator, Arc)> + '_ { - self.map.iter() + fn deref(&self) -> &Self::Target { + &self.map } +} - /// This function returns the target expression for a given source expression. - /// - /// # Arguments - /// - /// * `expr` - Source physical expression. - /// - /// # Returns - /// - /// An `Option` containing the target for the given source expression, - /// where a `None` value means that `expr` is not inside the mapping. - pub fn target_expr( - &self, - expr: &Arc, - ) -> Option> { - self.map - .iter() - .find(|(source, _)| source.eq(expr)) - .map(|(_, target)| Arc::clone(target)) +impl FromIterator<(Arc, ProjectionTargets)> for ProjectionMapping { + fn from_iter, ProjectionTargets)>>( + iter: T, + ) -> Self { + Self { + map: IndexMap::from_iter(iter), + } } } -fn project_index_to_exprs( - projection_index: &[usize], +/// Projects a slice of [LexOrdering]s onto the given schema. +/// +/// This is a convenience wrapper that applies [project_ordering] to each +/// input ordering and collects the successful projections: +/// - For each input ordering, the result of [project_ordering] is appended to +/// the output if it is `Some(...)`. +/// - Order is preserved and no deduplication is attempted. +/// - If none of the input orderings can be projected, an empty `Vec` is +/// returned. +/// +/// See [project_ordering] for the semantics of projecting a single +/// [LexOrdering]. +pub fn project_orderings( + orderings: &[LexOrdering], schema: &SchemaRef, -) -> Vec<(Arc, String)> { - projection_index - .iter() - .map(|index| { - let field = schema.field(*index); - ( - Arc::new(Column::new(field.name(), *index)) as Arc, - field.name().to_owned(), - ) - }) - .collect::>() +) -> Vec { + let mut projected_orderings = vec![]; + + for ordering in orderings { + projected_orderings.extend(project_ordering(ordering, schema)); + } + + projected_orderings +} + +/// Projects a single [LexOrdering] onto the given schema. +/// +/// This function attempts to rewrite every [PhysicalSortExpr] in the provided +/// [LexOrdering] so that any [Column] expressions point at the correct field +/// indices in `schema`. +/// +/// Key details: +/// - Columns are matched by name, not by index. The index of each matched +/// column is looked up with [Schema::column_with_name](arrow::datatypes::Schema::column_with_name) and a new +/// [Column] with the correct [index](Column::index) is substituted. +/// - If an expression references a column name that does not exist in +/// `schema`, projection of the current ordering stops and only the already +/// rewritten prefix is kept. This models the fact that a lexicographical +/// ordering remains valid for any leading prefix whose expressions are +/// present in the projected schema. +/// - If no expressions can be projected (i.e. the first one is missing), the +/// function returns `None`. +/// +/// Return value: +/// - `Some(LexOrdering)` if at least one sort expression could be projected. +/// The returned ordering may be a strict prefix of the input ordering. +/// - `None` if no part of the ordering can be projected onto `schema`. +/// +/// Example +/// +/// Suppose we have an input ordering `[col("a@0"), col("b@1")]` but the projected +/// schema only contains b and not a. The result will be `Some([col("a@0")])`. In other +/// words, the column reference is reindexed to match the projected schema. +/// If neither a nor b is present, the result will be None. +pub fn project_ordering( + ordering: &LexOrdering, + schema: &SchemaRef, +) -> Option { + let mut projected_exprs = vec![]; + for PhysicalSortExpr { expr, options } in ordering.iter() { + let transformed = Arc::clone(expr).transform_up(|expr| { + let Some(col) = expr.as_any().downcast_ref::() else { + return Ok(Transformed::no(expr)); + }; + + let name = col.name(); + if let Some((idx, _)) = schema.column_with_name(name) { + // Compute the new column expression (with correct index) after projection: + Ok(Transformed::yes(Arc::new(Column::new(name, idx)))) + } else { + // Cannot find expression in the projected_schema, + // signal this using an Err result + plan_err!("") + } + }); + + match transformed { + Ok(transformed) => { + projected_exprs.push(PhysicalSortExpr::new(transformed.data, *options)); + } + Err(_) => { + // Err result indicates an expression could not be found in the + // projected_schema, stop iterating since rest of the orderings are violated + break; + } + } + } + + LexOrdering::new(projected_exprs) } #[cfg(test)] mod tests { use super::*; - use crate::equivalence::tests::{ - convert_to_orderings, convert_to_orderings_owned, output_schema, - }; - use crate::equivalence::EquivalenceProperties; + use crate::equivalence::tests::output_schema; + use crate::equivalence::{convert_to_orderings, EquivalenceProperties}; use crate::expressions::{col, BinaryExpr}; use crate::utils::tests::TestScalarUDF; use crate::{PhysicalExprRef, ScalarFunctionExpr}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; + use datafusion_common::config::ConfigOptions; use datafusion_expr::{Operator, ScalarUDF}; #[test] @@ -608,13 +720,12 @@ mod tests { let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + .map(|(expr, name)| (Arc::clone(expr), name)); + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; let expected = expected @@ -628,17 +739,16 @@ mod tests { .collect::>() }) .collect::>(); - let expected = convert_to_orderings_owned(&expected); + let expected = convert_to_orderings(&expected); let projected_eq = eq_properties.project(&projection_mapping, output_schema); let orderings = projected_eq.oeq_class(); let err_msg = format!( - "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", - idx, orderings, expected, projection_mapping + "test_idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}" ); - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + assert_eq!(orderings.len(), expected.len(), "{err_msg}"); for expected_ordering in &expected { assert!(orderings.contains(expected_ordering), "{}", err_msg) } @@ -672,6 +782,7 @@ mod tests { test_fun, vec![Arc::clone(col_c)], &schema, + Arc::new(ConfigOptions::default()), )?) as PhysicalExprRef; let option_asc = SortOptions { @@ -687,9 +798,8 @@ mod tests { ]; let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + .map(|(expr, name)| (Arc::clone(expr), name)); + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; let col_a_new = &col("a_new", &output_schema)?; @@ -813,7 +923,7 @@ mod tests { let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let orderings = convert_to_orderings(orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); let expected = convert_to_orderings(expected); @@ -822,11 +932,10 @@ mod tests { let orderings = projected_eq.oeq_class(); let err_msg = format!( - "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", - idx, orderings, expected, projection_mapping + "test idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}" ); - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + assert_eq!(orderings.len(), expected.len(), "{err_msg}"); for expected_ordering in &expected { assert!(orderings.contains(expected_ordering), "{}", err_msg) } @@ -868,9 +977,8 @@ mod tests { ]; let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + .map(|(expr, name)| (Arc::clone(expr), name)); + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; let col_a_plus_b_new = &col("a+b", &output_schema)?; @@ -955,11 +1063,11 @@ mod tests { for (orderings, equal_columns, expected) in test_cases { let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); for (lhs, rhs) in equal_columns { - eq_properties.add_equal_conditions(lhs, rhs)?; + eq_properties.add_equal_conditions(Arc::clone(lhs), Arc::clone(rhs))?; } let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); let expected = convert_to_orderings(&expected); @@ -968,11 +1076,10 @@ mod tests { let orderings = projected_eq.oeq_class(); let err_msg = format!( - "actual: {:?}, expected: {:?}, projection_mapping: {:?}", - orderings, expected, projection_mapping + "actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}" ); - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + assert_eq!(orderings.len(), expected.len(), "{err_msg}"); for expected_ordering in &expected { assert!(orderings.contains(expected_ordering), "{}", err_msg) } diff --git a/datafusion/physical-expr/src/equivalence/properties/dependency.rs b/datafusion/physical-expr/src/equivalence/properties/dependency.rs index 9eba295e562e2..26d5d32c65121 100644 --- a/datafusion/physical-expr/src/equivalence/properties/dependency.rs +++ b/datafusion/physical-expr/src/equivalence/properties/dependency.rs @@ -16,71 +16,67 @@ // under the License. use std::fmt::{self, Display}; +use std::ops::{Deref, DerefMut}; use std::sync::Arc; +use super::expr_refers; use crate::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use indexmap::IndexSet; -use indexmap::IndexMap; +use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; -use super::{expr_refers, ExprWrapper}; - // A list of sort expressions that can be calculated from a known set of /// dependencies. #[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct Dependencies { - inner: IndexSet, + sort_exprs: IndexSet, } impl Display for Dependencies { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "[")?; - let mut iter = self.inner.iter(); + let mut iter = self.sort_exprs.iter(); if let Some(dep) = iter.next() { - write!(f, "{}", dep)?; + write!(f, "{dep}")?; } for dep in iter { - write!(f, ", {}", dep)?; + write!(f, ", {dep}")?; } write!(f, "]") } } impl Dependencies { - /// Create a new empty `Dependencies` instance. - fn new() -> Self { + // Creates a new `Dependencies` instance from the given sort expressions. + pub fn new(sort_exprs: impl IntoIterator) -> Self { Self { - inner: IndexSet::new(), + sort_exprs: sort_exprs.into_iter().collect(), } } +} - /// Create a new `Dependencies` from an iterator of `PhysicalSortExpr`. - pub fn new_from_iter(iter: impl IntoIterator) -> Self { - Self { - inner: iter.into_iter().collect(), - } - } +impl Deref for Dependencies { + type Target = IndexSet; - /// Insert a new dependency into the set. - pub fn insert(&mut self, sort_expr: PhysicalSortExpr) { - self.inner.insert(sort_expr); + fn deref(&self) -> &Self::Target { + &self.sort_exprs } +} - /// Iterator over dependencies in the set - pub fn iter(&self) -> impl Iterator + Clone { - self.inner.iter() +impl DerefMut for Dependencies { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.sort_exprs } +} - /// Return the inner set of dependencies - pub fn into_inner(self) -> IndexSet { - self.inner - } +impl IntoIterator for Dependencies { + type Item = PhysicalSortExpr; + type IntoIter = as IntoIterator>::IntoIter; - /// Returns true if there are no dependencies - fn is_empty(&self) -> bool { - self.inner.is_empty() + fn into_iter(self) -> Self::IntoIter { + self.sort_exprs.into_iter() } } @@ -133,26 +129,25 @@ impl<'a> DependencyEnumerator<'a> { let node = dependency_map .get(referred_sort_expr) .expect("`referred_sort_expr` should be inside `dependency_map`"); - // Since we work on intermediate nodes, we are sure `val.target_sort_expr` - // exists. - let target_sort_expr = node.target_sort_expr.as_ref().unwrap(); + // Since we work on intermediate nodes, we are sure `node.target` exists. + let target = node.target.as_ref().unwrap(); // An empty dependency means the referred_sort_expr represents a global ordering. // Return its projected version, which is the target_expression. if node.dependencies.is_empty() { - return vec![LexOrdering::new(vec![target_sort_expr.clone()])]; + return vec![[target.clone()].into()]; }; node.dependencies .iter() .flat_map(|dep| { - let mut orderings = if self.insert(target_sort_expr, dep) { + let mut orderings = if self.insert(target, dep) { self.construct_orderings(dep, dependency_map) } else { vec![] }; for ordering in orderings.iter_mut() { - ordering.push(target_sort_expr.clone()) + ordering.push(target.clone()); } orderings }) @@ -178,70 +173,55 @@ impl<'a> DependencyEnumerator<'a> { /// # Note on IndexMap Rationale /// /// Using `IndexMap` (which preserves insert order) to ensure consistent results -/// across different executions for the same query. We could have used -/// `HashSet`, `HashMap` in place of them without any loss of functionality. +/// across different executions for the same query. We could have used `HashSet` +/// and `HashMap` instead without any loss of functionality. /// /// As an example, if existing orderings are /// 1. `[a ASC, b ASC]` -/// 2. `[c ASC]` for +/// 2. `[c ASC]` /// /// Then both the following output orderings are valid /// 1. `[a ASC, b ASC, c ASC]` /// 2. `[c ASC, a ASC, b ASC]` /// -/// (this are both valid as they are concatenated versions of the alternative -/// orderings). When using `HashSet`, `HashMap` it is not guaranteed to generate -/// consistent result, among the possible 2 results in the example above. -#[derive(Debug)] +/// These are both valid as they are concatenated versions of the alternative +/// orderings. Had we used `HashSet`/`HashMap`, we couldn't guarantee to generate +/// the same result among the possible two results in the example above. +#[derive(Debug, Default)] pub struct DependencyMap { - inner: IndexMap, + map: IndexMap, } impl DependencyMap { - pub fn new() -> Self { - Self { - inner: IndexMap::new(), - } - } - - /// Insert a new dependency `sort_expr` --> `dependency` into the map. - /// - /// If `target_sort_expr` is none, a new entry is created with empty dependencies. + /// Insert a new dependency of `sort_expr` (i.e. `dependency`) into the map + /// along with its target sort expression. pub fn insert( &mut self, - sort_expr: &PhysicalSortExpr, - target_sort_expr: Option<&PhysicalSortExpr>, - dependency: Option<&PhysicalSortExpr>, + sort_expr: PhysicalSortExpr, + target_sort_expr: Option, + dependency: Option, ) { - self.inner - .entry(sort_expr.clone()) - .or_insert_with(|| DependencyNode { - target_sort_expr: target_sort_expr.cloned(), - dependencies: Dependencies::new(), - }) - .insert_dependency(dependency) - } - - /// Iterator over (sort_expr, DependencyNode) pairs - pub fn iter(&self) -> impl Iterator { - self.inner.iter() + let entry = self.map.entry(sort_expr); + let node = entry.or_insert_with(|| DependencyNode { + target: target_sort_expr, + dependencies: Dependencies::default(), + }); + node.dependencies.extend(dependency); } +} - /// iterator over all sort exprs - pub fn sort_exprs(&self) -> impl Iterator { - self.inner.keys() - } +impl Deref for DependencyMap { + type Target = IndexMap; - /// Return the dependency node for the given sort expression, if any - pub fn get(&self, sort_expr: &PhysicalSortExpr) -> Option<&DependencyNode> { - self.inner.get(sort_expr) + fn deref(&self) -> &Self::Target { + &self.map } } impl Display for DependencyMap { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "DependencyMap: {{")?; - for (sort_expr, node) in self.inner.iter() { + for (sort_expr, node) in self.map.iter() { writeln!(f, " {sort_expr} --> {node}")?; } writeln!(f, "}}") @@ -256,30 +236,21 @@ impl Display for DependencyMap { /// /// # Fields /// -/// - `target_sort_expr`: An optional `PhysicalSortExpr` representing the target -/// sort expression associated with the node. It is `None` if the sort expression +/// - `target`: An optional `PhysicalSortExpr` representing the target sort +/// expression associated with the node. It is `None` if the sort expression /// cannot be projected. /// - `dependencies`: A [`Dependencies`] containing dependencies on other sort /// expressions that are referred to by the target sort expression. #[derive(Debug, Clone, PartialEq, Eq)] pub struct DependencyNode { - pub target_sort_expr: Option, - pub dependencies: Dependencies, -} - -impl DependencyNode { - /// Insert dependency to the state (if exists). - fn insert_dependency(&mut self, dependency: Option<&PhysicalSortExpr>) { - if let Some(dep) = dependency { - self.dependencies.insert(dep.clone()); - } - } + pub(crate) target: Option, + pub(crate) dependencies: Dependencies, } impl Display for DependencyNode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let Some(target) = &self.target_sort_expr { - write!(f, "(target: {}, ", target)?; + if let Some(target) = &self.target { + write!(f, "(target: {target}, ")?; } else { write!(f, "(")?; } @@ -307,12 +278,12 @@ pub fn referred_dependencies( source: &Arc, ) -> Vec { // Associate `PhysicalExpr`s with `PhysicalSortExpr`s that contain them: - let mut expr_to_sort_exprs = IndexMap::::new(); + let mut expr_to_sort_exprs = IndexMap::<_, Dependencies>::new(); for sort_expr in dependency_map - .sort_exprs() + .keys() .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) { - let key = ExprWrapper(Arc::clone(&sort_expr.expr)); + let key = Arc::clone(&sort_expr.expr); expr_to_sort_exprs .entry(key) .or_default() @@ -322,16 +293,10 @@ pub fn referred_dependencies( // Generate all valid dependencies for the source. For example, if the source // is `a + b` and the map is `[a -> (a ASC, a DESC), b -> (b ASC)]`, we get // `vec![HashSet(a ASC, b ASC), HashSet(a DESC, b ASC)]`. - let dependencies = expr_to_sort_exprs + expr_to_sort_exprs .into_values() - .map(Dependencies::into_inner) - .collect::>(); - dependencies - .iter() .multi_cartesian_product() - .map(|referred_deps| { - Dependencies::new_from_iter(referred_deps.into_iter().cloned()) - }) + .map(Dependencies::new) .collect() } @@ -378,46 +343,39 @@ pub fn construct_prefix_orderings( /// # Parameters /// /// * `dependencies` - Set of relevant expressions. -/// * `dependency_map` - Map of dependencies for expressions that may appear in `dependencies` +/// * `dependency_map` - Map of dependencies for expressions that may appear in +/// `dependencies`. /// /// # Returns /// -/// A vector of lexical orderings (`Vec`) representing all valid orderings -/// based on the given dependencies. +/// A vector of lexical orderings (`Vec`) representing all valid +/// orderings based on the given dependencies. pub fn generate_dependency_orderings( dependencies: &Dependencies, dependency_map: &DependencyMap, ) -> Vec { // Construct all the valid prefix orderings for each expression appearing - // in the projection: - let relevant_prefixes = dependencies + // in the projection. Note that if relevant prefixes are empty, there is no + // dependency, meaning that dependent is a leading ordering. + dependencies .iter() - .flat_map(|dep| { + .filter_map(|dep| { let prefixes = construct_prefix_orderings(dep, dependency_map); (!prefixes.is_empty()).then_some(prefixes) }) - .collect::>(); - - // No dependency, dependent is a leading ordering. - if relevant_prefixes.is_empty() { - // Return an empty ordering: - return vec![LexOrdering::default()]; - } - - relevant_prefixes - .into_iter() + // Generate all possible valid orderings: .multi_cartesian_product() .flat_map(|prefix_orderings| { + let length = prefix_orderings.len(); prefix_orderings - .iter() - .permutations(prefix_orderings.len()) - .map(|prefixes| { - prefixes - .into_iter() - .flat_map(|ordering| ordering.clone()) - .collect() + .into_iter() + .permutations(length) + .filter_map(|prefixes| { + prefixes.into_iter().reduce(|mut acc, ordering| { + acc.extend(ordering); + acc + }) }) - .collect::>() }) .collect() } @@ -429,21 +387,24 @@ mod tests { use super::*; use crate::equivalence::tests::{ - convert_to_sort_exprs, convert_to_sort_reqs, create_test_params, - create_test_schema, output_schema, parse_sort_expr, + convert_to_sort_reqs, create_test_params, create_test_schema, output_schema, + parse_sort_expr, }; - use crate::equivalence::ProjectionMapping; + use crate::equivalence::{convert_to_sort_exprs, ProjectionMapping}; use crate::expressions::{col, BinaryExpr, CastExpr, Column}; use crate::{ConstExpr, EquivalenceProperties, ScalarFunctionExpr}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; + use datafusion_common::config::ConfigOptions; use datafusion_common::{Constraint, Constraints, Result}; use datafusion_expr::sort_properties::SortProperties; use datafusion_expr::Operator; - use datafusion_functions::string::concat; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::sort_expr::{ + LexRequirement, PhysicalSortRequirement, + }; #[test] fn project_equivalence_properties_test() -> Result<()> { @@ -463,7 +424,7 @@ mod tests { (Arc::clone(&col_a), "a3".to_string()), (Arc::clone(&col_a), "a4".to_string()), ]; - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; let out_schema = output_schema(&projection_mapping, &input_schema)?; // a as a1, a as a2, a as a3, a as a3 @@ -473,7 +434,7 @@ mod tests { (Arc::clone(&col_a), "a3".to_string()), (Arc::clone(&col_a), "a4".to_string()), ]; - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; // a as a1, a as a2, a as a3, a as a3 let col_a1 = &col("a1", &out_schema)?; @@ -506,20 +467,20 @@ mod tests { let mut input_properties = EquivalenceProperties::new(Arc::clone(&input_schema)); // add equivalent ordering [a, b, c, d] - input_properties.add_new_ordering(LexOrdering::new(vec![ + input_properties.add_ordering([ parse_sort_expr("a", &input_schema), parse_sort_expr("b", &input_schema), parse_sort_expr("c", &input_schema), parse_sort_expr("d", &input_schema), - ])); + ]); // add equivalent ordering [a, c, b, d] - input_properties.add_new_ordering(LexOrdering::new(vec![ + input_properties.add_ordering([ parse_sort_expr("a", &input_schema), parse_sort_expr("c", &input_schema), parse_sort_expr("b", &input_schema), // NB b and c are swapped parse_sort_expr("d", &input_schema), - ])); + ]); // simply project all the columns in order let proj_exprs = vec![ @@ -528,7 +489,7 @@ mod tests { (col("c", &input_schema)?, "c".to_string()), (col("d", &input_schema)?, "d".to_string()), ]; - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; let out_properties = input_properties.project(&projection_mapping, input_schema); assert_eq!( @@ -541,8 +502,6 @@ mod tests { #[test] fn test_normalize_ordering_equivalence_classes() -> Result<()> { - let sort_options = SortOptions::default(); - let schema = Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), @@ -553,35 +512,19 @@ mod tests { let col_c_expr = col("c", &schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr)?; - let others = vec![ - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_b_expr), - options: sort_options, - }]), - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_c_expr), - options: sort_options, - }]), - ]; - eq_properties.add_new_orderings(others); + eq_properties.add_equal_conditions(col_a_expr, Arc::clone(&col_c_expr))?; + eq_properties.add_orderings([ + vec![PhysicalSortExpr::new_default(Arc::clone(&col_b_expr))], + vec![PhysicalSortExpr::new_default(Arc::clone(&col_c_expr))], + ]); let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); - expected_eqs.add_new_orderings([ - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_b_expr), - options: sort_options, - }]), - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_c_expr), - options: sort_options, - }]), + expected_eqs.add_orderings([ + vec![PhysicalSortExpr::new_default(col_b_expr)], + vec![PhysicalSortExpr::new_default(col_c_expr)], ]); - let oeq_class = eq_properties.oeq_class().clone(); - let expected = expected_eqs.oeq_class(); - assert!(oeq_class.eq(expected)); - + assert!(eq_properties.oeq_class().eq(expected_eqs.oeq_class())); Ok(()) } @@ -594,34 +537,22 @@ mod tests { Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), ]); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let required_columns = [Arc::clone(col_b), Arc::clone(col_a)]; + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let required_columns = [Arc::clone(&col_b), Arc::clone(&col_a)]; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ])]); - let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::new(Column::new("b", 1)), sort_options_not), + PhysicalSortExpr::new(Arc::new(Column::new("a", 0)), sort_options), + ]); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns)?; assert_eq!(idxs, vec![0, 1]); assert_eq!( result, - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(col_b), - options: sort_options_not - }, - PhysicalSortExpr { - expr: Arc::clone(col_a), - options: sort_options - } - ]) + vec![ + PhysicalSortExpr::new(col_b, sort_options_not), + PhysicalSortExpr::new(col_a, sort_options), + ] ); let schema = Schema::new(vec![ @@ -629,40 +560,28 @@ mod tests { Field::new("b", DataType::Int32, true), Field::new("c", DataType::Int32, true), ]); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let required_columns = [Arc::clone(col_b), Arc::clone(col_a)]; + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let required_columns = [Arc::clone(&col_b), Arc::clone(&col_a)]; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - eq_properties.add_new_orderings([ - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }]), - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ]), + eq_properties.add_orderings([ + vec![PhysicalSortExpr::new( + Arc::new(Column::new("c", 2)), + sort_options, + )], + vec![ + PhysicalSortExpr::new(Arc::new(Column::new("b", 1)), sort_options_not), + PhysicalSortExpr::new(Arc::new(Column::new("a", 0)), sort_options), + ], ]); - let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns)?; assert_eq!(idxs, vec![0, 1]); assert_eq!( result, - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(col_b), - options: sort_options_not - }, - PhysicalSortExpr { - expr: Arc::clone(col_a), - options: sort_options - } - ]) + vec![ + PhysicalSortExpr::new(col_b, sort_options_not), + PhysicalSortExpr::new(col_a, sort_options), + ] ); let required_columns = [ @@ -677,21 +596,12 @@ mod tests { let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); // not satisfied orders - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ])]); - let (_, idxs) = eq_properties.find_longest_permutation(&required_columns); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::new(Column::new("b", 1)), sort_options_not), + PhysicalSortExpr::new(Arc::new(Column::new("c", 2)), sort_options), + PhysicalSortExpr::new(Arc::new(Column::new("a", 0)), sort_options), + ]); + let (_, idxs) = eq_properties.find_longest_permutation(&required_columns)?; assert_eq!(idxs, vec![0]); Ok(()) @@ -707,49 +617,35 @@ mod tests { ]); let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_d = &col("d", &schema)?; + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + let col_d = col("d", &schema)?; let option_asc = SortOptions { descending: false, nulls_first: false, }; // b=a (e.g they are aliases) - eq_properties.add_equal_conditions(col_b, col_a)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_a))?; // [b ASC], [d ASC] - eq_properties.add_new_orderings(vec![ - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(col_b), - options: option_asc, - }]), - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(col_d), - options: option_asc, - }]), + eq_properties.add_orderings([ + vec![PhysicalSortExpr::new(Arc::clone(&col_b), option_asc)], + vec![PhysicalSortExpr::new(Arc::clone(&col_d), option_asc)], ]); let test_cases = vec![ // d + b ( - Arc::new(BinaryExpr::new( - Arc::clone(col_d), - Operator::Plus, - Arc::clone(col_b), - )) as Arc, + Arc::new(BinaryExpr::new(col_d, Operator::Plus, Arc::clone(&col_b))) as _, SortProperties::Ordered(option_asc), ), // b - (Arc::clone(col_b), SortProperties::Ordered(option_asc)), + (col_b, SortProperties::Ordered(option_asc)), // a - (Arc::clone(col_a), SortProperties::Ordered(option_asc)), + (Arc::clone(&col_a), SortProperties::Ordered(option_asc)), // a + c ( - Arc::new(BinaryExpr::new( - Arc::clone(col_a), - Operator::Plus, - Arc::clone(col_c), - )), + Arc::new(BinaryExpr::new(col_a, Operator::Plus, col_c)), SortProperties::Unordered, ), ]; @@ -757,14 +653,14 @@ mod tests { let leading_orderings = eq_properties .oeq_class() .iter() - .flat_map(|ordering| ordering.first().cloned()) + .map(|ordering| ordering.first().clone()) .collect::>(); let expr_props = eq_properties.get_expr_properties(Arc::clone(&expr)); let err_msg = format!( "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", expr, expected, expr_props.sort_properties ); - assert_eq!(expr_props.sort_properties, expected, "{}", err_msg); + assert_eq!(expr_props.sort_properties, expected, "{err_msg}"); } Ok(()) @@ -790,7 +686,7 @@ mod tests { Arc::clone(col_a), Operator::Plus, Arc::clone(col_d), - )) as Arc; + )) as _; let option_asc = SortOptions { descending: false, @@ -801,16 +697,10 @@ mod tests { nulls_first: true, }; // [d ASC, h DESC] also satisfies schema. - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(col_d), - options: option_asc, - }, - PhysicalSortExpr { - expr: Arc::clone(col_h), - options: option_desc, - }, - ])]); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::clone(col_d), option_asc), + PhysicalSortExpr::new(Arc::clone(col_h), option_desc), + ]); let test_cases = vec![ // TEST CASE 1 (vec![col_a], vec![(col_a, option_asc)]), @@ -878,7 +768,7 @@ mod tests { for (exprs, expected) in test_cases { let exprs = exprs.into_iter().cloned().collect::>(); let expected = convert_to_sort_exprs(&expected); - let (actual, _) = eq_properties.find_longest_permutation(&exprs); + let (actual, _) = eq_properties.find_longest_permutation(&exprs)?; assert_eq!(actual, expected); } @@ -896,7 +786,7 @@ mod tests { let col_h = &col("h", &test_schema)?; // Add column h as constant - eq_properties = eq_properties.with_constants(vec![ConstExpr::from(col_h)]); + eq_properties.add_constants(vec![ConstExpr::from(Arc::clone(col_h))])?; let test_cases = vec![ // TEST CASE 1 @@ -907,72 +797,13 @@ mod tests { for (exprs, expected) in test_cases { let exprs = exprs.into_iter().cloned().collect::>(); let expected = convert_to_sort_exprs(&expected); - let (actual, _) = eq_properties.find_longest_permutation(&exprs); + let (actual, _) = eq_properties.find_longest_permutation(&exprs)?; assert_eq!(actual, expected); } Ok(()) } - #[test] - fn test_get_finer() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let eq_properties = EquivalenceProperties::new(schema); - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - // First entry, and second entry are the physical sort requirement that are argument for get_finer_requirement. - // Third entry is the expected result. - let tests_cases = vec![ - // Get finer requirement between [a Some(ASC)] and [a None, b Some(ASC)] - // result should be [a Some(ASC), b Some(ASC)] - ( - vec![(col_a, Some(option_asc))], - vec![(col_a, None), (col_b, Some(option_asc))], - Some(vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))]), - ), - // Get finer requirement between [a Some(ASC), b Some(ASC), c Some(ASC)] and [a Some(ASC), b Some(ASC)] - // result should be [a Some(ASC), b Some(ASC), c Some(ASC)] - ( - vec![ - (col_a, Some(option_asc)), - (col_b, Some(option_asc)), - (col_c, Some(option_asc)), - ], - vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], - Some(vec![ - (col_a, Some(option_asc)), - (col_b, Some(option_asc)), - (col_c, Some(option_asc)), - ]), - ), - // Get finer requirement between [a Some(ASC), b Some(ASC)] and [a Some(ASC), b Some(DESC)] - // result should be None - ( - vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], - vec![(col_a, Some(option_asc)), (col_b, Some(option_desc))], - None, - ), - ]; - for (lhs, rhs, expected) in tests_cases { - let lhs = convert_to_sort_reqs(&lhs); - let rhs = convert_to_sort_reqs(&rhs); - let expected = expected.map(|expected| convert_to_sort_reqs(&expected)); - let finer = eq_properties.get_finer_requirement(&lhs, &rhs); - assert_eq!(finer, expected) - } - - Ok(()) - } - #[test] fn test_normalize_sort_reqs() -> Result<()> { // Schema satisfies following properties @@ -1040,7 +871,7 @@ mod tests { let expected_normalized = convert_to_sort_reqs(&expected_normalized); assert_eq!( - eq_properties.normalize_sort_requirements(&req), + eq_properties.normalize_sort_requirements(req).unwrap(), expected_normalized ); } @@ -1073,8 +904,9 @@ mod tests { for (reqs, expected) in test_cases.into_iter() { let reqs = convert_to_sort_reqs(&reqs); let expected = convert_to_sort_reqs(&expected); - - let normalized = eq_properties.normalize_sort_requirements(&reqs); + let normalized = eq_properties + .normalize_sort_requirements(reqs.clone()) + .unwrap(); assert!( expected.eq(&normalized), "error in test: reqs: {reqs:?}, expected: {expected:?}, normalized: {normalized:?}" @@ -1091,21 +923,12 @@ mod tests { Field::new("b", DataType::Utf8, true), Field::new("c", DataType::Timestamp(TimeUnit::Nanosecond, None), true), ])); - let base_properties = EquivalenceProperties::new(Arc::clone(&schema)) - .with_reorder(LexOrdering::new( - ["a", "b", "c"] - .into_iter() - .map(|c| { - col(c, schema.as_ref()).map(|expr| PhysicalSortExpr { - expr, - options: SortOptions { - descending: false, - nulls_first: true, - }, - }) - }) - .collect::>>()?, - )); + let mut base_properties = EquivalenceProperties::new(Arc::clone(&schema)); + base_properties.reorder( + ["a", "b", "c"] + .into_iter() + .map(|c| PhysicalSortExpr::new_default(col(c, schema.as_ref()).unwrap())), + )?; struct TestCase { name: &'static str, @@ -1118,17 +941,14 @@ mod tests { let col_a = col("a", schema.as_ref())?; let col_b = col("b", schema.as_ref())?; let col_c = col("c", schema.as_ref())?; - let cast_c = Arc::new(CastExpr::new(col_c, DataType::Date32, None)); + let cast_c = Arc::new(CastExpr::new(col_c, DataType::Date32, None)) as _; let cases = vec![ TestCase { name: "(a, b, c) -> (c)", // b is constant, so it should be removed from the sort order constants: vec![Arc::clone(&col_b)], - equal_conditions: vec![[ - Arc::clone(&cast_c) as Arc, - Arc::clone(&col_a), - ]], + equal_conditions: vec![[Arc::clone(&cast_c), Arc::clone(&col_a)]], sort_columns: &["c"], should_satisfy_ordering: true, }, @@ -1138,10 +958,7 @@ mod tests { name: "(a, b, c) -> (c)", // b is constant, so it should be removed from the sort order constants: vec![col_b], - equal_conditions: vec![[ - Arc::clone(&col_a), - Arc::clone(&cast_c) as Arc, - ]], + equal_conditions: vec![[Arc::clone(&col_a), Arc::clone(&cast_c)]], sort_columns: &["c"], should_satisfy_ordering: true, }, @@ -1150,10 +967,7 @@ mod tests { // b is not constant anymore constants: vec![], // a and c are still compatible, but this is irrelevant since the original ordering is (a, b, c) - equal_conditions: vec![[ - Arc::clone(&cast_c) as Arc, - Arc::clone(&col_a), - ]], + equal_conditions: vec![[Arc::clone(&cast_c), Arc::clone(&col_a)]], sort_columns: &["c"], should_satisfy_ordering: false, }, @@ -1167,19 +981,21 @@ mod tests { // Equal conditions before constants { let mut properties = base_properties.clone(); - for [left, right] in &case.equal_conditions { + for [left, right] in case.equal_conditions.clone() { properties.add_equal_conditions(left, right)? } - properties.with_constants( + properties.add_constants( case.constants.iter().cloned().map(ConstExpr::from), - ) + )?; + properties }, // Constants before equal conditions { - let mut properties = base_properties.clone().with_constants( + let mut properties = base_properties.clone(); + properties.add_constants( case.constants.iter().cloned().map(ConstExpr::from), - ); - for [left, right] in &case.equal_conditions { + )?; + for [left, right] in case.equal_conditions { properties.add_equal_conditions(left, right)? } properties @@ -1188,16 +1004,11 @@ mod tests { let sort = case .sort_columns .iter() - .map(|&name| { - col(name, &schema).map(|col| PhysicalSortExpr { - expr: col, - options: SortOptions::default(), - }) - }) - .collect::>()?; + .map(|&name| col(name, &schema).map(PhysicalSortExpr::new_default)) + .collect::>>()?; assert_eq!( - properties.ordering_satisfy(sort.as_ref()), + properties.ordering_satisfy(sort)?, case.should_satisfy_ordering, "failed test '{}'", case.name @@ -1224,31 +1035,30 @@ mod tests { "concat", concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], - DataType::Utf8, + Field::new("f", DataType::Utf8, true).into(), + Arc::new(ConfigOptions::default()), )); // Assume existing ordering is [c ASC, a ASC, b ASC] let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); - eq_properties.add_new_ordering(LexOrdering::from(vec![ + eq_properties.add_ordering([ PhysicalSortExpr::new_default(Arc::clone(&col_c)).asc(), PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ])); + ]); // Add equality condition c = concat(a, b) - eq_properties.add_equal_conditions(&col_c, &a_concat_b)?; + eq_properties.add_equal_conditions(Arc::clone(&col_c), a_concat_b)?; let orderings = eq_properties.oeq_class(); - let expected_ordering1 = - LexOrdering::from(vec![ - PhysicalSortExpr::new_default(Arc::clone(&col_c)).asc() - ]); - let expected_ordering2 = LexOrdering::from(vec![ - PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), - PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ]); + let expected_ordering1 = [PhysicalSortExpr::new_default(col_c).asc()].into(); + let expected_ordering2 = [ + PhysicalSortExpr::new_default(col_a).asc(), + PhysicalSortExpr::new_default(col_b).asc(), + ] + .into(); // The ordering should be [c ASC] and [a ASC, b ASC] assert_eq!(orderings.len(), 2); @@ -1270,25 +1080,26 @@ mod tests { let col_b = col("b", &schema)?; let col_c = col("c", &schema)?; - let a_times_b: Arc = Arc::new(BinaryExpr::new( + let a_times_b = Arc::new(BinaryExpr::new( Arc::clone(&col_a), Operator::Multiply, Arc::clone(&col_b), - )); + )) as _; // Assume existing ordering is [c ASC, a ASC, b ASC] let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); - let initial_ordering = LexOrdering::from(vec![ + let initial_ordering: LexOrdering = [ PhysicalSortExpr::new_default(Arc::clone(&col_c)).asc(), - PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), - PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ]); + PhysicalSortExpr::new_default(col_a).asc(), + PhysicalSortExpr::new_default(col_b).asc(), + ] + .into(); - eq_properties.add_new_ordering(initial_ordering.clone()); + eq_properties.add_ordering(initial_ordering.clone()); // Add equality condition c = a * b - eq_properties.add_equal_conditions(&col_c, &a_times_b)?; + eq_properties.add_equal_conditions(col_c, a_times_b)?; let orderings = eq_properties.oeq_class(); @@ -1311,37 +1122,36 @@ mod tests { let col_b = col("b", &schema)?; let col_c = col("c", &schema)?; - let a_concat_b: Arc = Arc::new(ScalarFunctionExpr::new( + let a_concat_b = Arc::new(ScalarFunctionExpr::new( "concat", concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], - DataType::Utf8, - )); + Field::new("f", DataType::Utf8, true).into(), + Arc::new(ConfigOptions::default()), + )) as _; // Assume existing ordering is [concat(a, b) ASC, a ASC, b ASC] let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); - eq_properties.add_new_ordering(LexOrdering::from(vec![ + eq_properties.add_ordering([ PhysicalSortExpr::new_default(Arc::clone(&a_concat_b)).asc(), PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ])); + ]); // Add equality condition c = concat(a, b) - eq_properties.add_equal_conditions(&col_c, &a_concat_b)?; + eq_properties.add_equal_conditions(col_c, Arc::clone(&a_concat_b))?; let orderings = eq_properties.oeq_class(); - let expected_ordering1 = LexOrdering::from(vec![PhysicalSortExpr::new_default( - Arc::clone(&a_concat_b), - ) - .asc()]); - let expected_ordering2 = LexOrdering::from(vec![ - PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), - PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ]); + let expected_ordering1 = [PhysicalSortExpr::new_default(a_concat_b).asc()].into(); + let expected_ordering2 = [ + PhysicalSortExpr::new_default(col_a).asc(), + PhysicalSortExpr::new_default(col_b).asc(), + ] + .into(); - // The ordering should be [concat(a, b) ASC] and [a ASC, b ASC] + // The ordering should be [c ASC] and [a ASC, b ASC] assert_eq!(orderings.len(), 2); assert!(orderings.contains(&expected_ordering1)); assert!(orderings.contains(&expected_ordering2)); @@ -1349,6 +1159,35 @@ mod tests { Ok(()) } + #[test] + fn test_requirements_compatible() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ])); + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + + let eq_properties = EquivalenceProperties::new(schema); + let lex_a: LexRequirement = + [PhysicalSortRequirement::new(Arc::clone(&col_a), None)].into(); + let lex_a_b: LexRequirement = [ + PhysicalSortRequirement::new(col_a, None), + PhysicalSortRequirement::new(col_b, None), + ] + .into(); + let lex_c = [PhysicalSortRequirement::new(col_c, None)].into(); + + assert!(eq_properties.requirements_compatible(lex_a.clone(), lex_a.clone())); + assert!(!eq_properties.requirements_compatible(lex_a.clone(), lex_a_b.clone())); + assert!(eq_properties.requirements_compatible(lex_a_b, lex_a.clone())); + assert!(!eq_properties.requirements_compatible(lex_c, lex_a)); + + Ok(()) + } + #[test] fn test_with_reorder_constant_filtering() -> Result<()> { let schema = create_test_schema()?; @@ -1357,26 +1196,21 @@ mod tests { // Setup constant columns let col_a = col("a", &schema)?; let col_b = col("b", &schema)?; - eq_properties = eq_properties.with_constants([ConstExpr::from(&col_a)]); + eq_properties.add_constants([ConstExpr::from(Arc::clone(&col_a))])?; - let sort_exprs = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::clone(&col_b), - options: SortOptions::default(), - }, - ]); + let sort_exprs = vec![ + PhysicalSortExpr::new_default(Arc::clone(&col_a)), + PhysicalSortExpr::new_default(Arc::clone(&col_b)), + ]; - let result = eq_properties.with_reorder(sort_exprs); + let change = eq_properties.reorder(sort_exprs)?; + assert!(change); - // Should only contain b since a is constant - assert_eq!(result.oeq_class().len(), 1); - let ordering = result.oeq_class().iter().next().unwrap(); - assert_eq!(ordering.len(), 1); - assert!(ordering[0].expr.eq(&col_b)); + assert_eq!(eq_properties.oeq_class().len(), 1); + let ordering = eq_properties.oeq_class().iter().next().unwrap(); + assert_eq!(ordering.len(), 2); + assert!(ordering[0].expr.eq(&col_a)); + assert!(ordering[1].expr.eq(&col_b)); Ok(()) } @@ -1397,32 +1231,21 @@ mod tests { }; // Initial ordering: [a ASC, b DESC, c ASC] - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_b), - options: desc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_c), - options: asc, - }, - ])]); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::clone(&col_a), asc), + PhysicalSortExpr::new(Arc::clone(&col_b), desc), + PhysicalSortExpr::new(Arc::clone(&col_c), asc), + ]); // New ordering: [a ASC] - let new_order = LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }]); + let new_order = vec![PhysicalSortExpr::new(Arc::clone(&col_a), asc)]; - let result = eq_properties.with_reorder(new_order); + let change = eq_properties.reorder(new_order)?; + assert!(!change); // Should only contain [a ASC, b DESC, c ASC] - assert_eq!(result.oeq_class().len(), 1); - let ordering = result.oeq_class().iter().next().unwrap(); + assert_eq!(eq_properties.oeq_class().len(), 1); + let ordering = eq_properties.oeq_class().iter().next().unwrap(); assert_eq!(ordering.len(), 3); assert!(ordering[0].expr.eq(&col_a)); assert!(ordering[0].options.eq(&asc)); @@ -1444,37 +1267,28 @@ mod tests { let col_c = col("c", &schema)?; // Make a and b equivalent - eq_properties.add_equal_conditions(&col_a, &col_b)?; - - let asc = SortOptions::default(); + eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_b))?; // Initial ordering: [a ASC, c ASC] - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_c), - options: asc, - }, - ])]); + eq_properties.add_ordering([ + PhysicalSortExpr::new_default(Arc::clone(&col_a)), + PhysicalSortExpr::new_default(Arc::clone(&col_c)), + ]); // New ordering: [b ASC] - let new_order = LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_b), - options: asc, - }]); + let new_order = vec![PhysicalSortExpr::new_default(Arc::clone(&col_b))]; - let result = eq_properties.with_reorder(new_order); + let change = eq_properties.reorder(new_order)?; - // Should only contain [b ASC, c ASC] - assert_eq!(result.oeq_class().len(), 1); + assert!(!change); + // Should only contain [a/b ASC, c ASC] + assert_eq!(eq_properties.oeq_class().len(), 1); // Verify orderings - let ordering = result.oeq_class().iter().next().unwrap(); + let asc = SortOptions::default(); + let ordering = eq_properties.oeq_class().iter().next().unwrap(); assert_eq!(ordering.len(), 2); - assert!(ordering[0].expr.eq(&col_b)); + assert!(ordering[0].expr.eq(&col_a) || ordering[0].expr.eq(&col_b)); assert!(ordering[0].options.eq(&asc)); assert!(ordering[1].expr.eq(&col_c)); assert!(ordering[1].options.eq(&asc)); @@ -1497,29 +1311,21 @@ mod tests { }; // Initial ordering: [a ASC, b DESC] - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_b), - options: desc, - }, - ])]); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::clone(&col_a), asc), + PhysicalSortExpr::new(Arc::clone(&col_b), desc), + ]); // New ordering: [a DESC] - let new_order = LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: desc, - }]); + let new_order = vec![PhysicalSortExpr::new(Arc::clone(&col_a), desc)]; - let result = eq_properties.with_reorder(new_order.clone()); + let change = eq_properties.reorder(new_order.clone())?; + assert!(change); // Should only contain the new ordering since options don't match - assert_eq!(result.oeq_class().len(), 1); - let ordering = result.oeq_class().iter().next().unwrap(); - assert_eq!(ordering, &new_order); + assert_eq!(eq_properties.oeq_class().len(), 1); + let ordering = eq_properties.oeq_class().iter().next().unwrap(); + assert_eq!(ordering.to_vec(), new_order); Ok(()) } @@ -1535,62 +1341,32 @@ mod tests { let col_d = col("d", &schema)?; let col_e = col("e", &schema)?; - let asc = SortOptions::default(); - // Constants: c is constant - eq_properties = eq_properties.with_constants([ConstExpr::from(&col_c)]); + eq_properties.add_constants([ConstExpr::from(Arc::clone(&col_c))])?; // Equality: b = d - eq_properties.add_equal_conditions(&col_b, &col_d)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_d))?; // Orderings: [d ASC, a ASC], [e ASC] - eq_properties.add_new_orderings([ - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_d), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }, - ]), - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_e), - options: asc, - }]), + eq_properties.add_orderings([ + vec![ + PhysicalSortExpr::new_default(Arc::clone(&col_d)), + PhysicalSortExpr::new_default(Arc::clone(&col_a)), + ], + vec![PhysicalSortExpr::new_default(Arc::clone(&col_e))], ]); - // Initial ordering: [b ASC, c ASC] - let new_order = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_b), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_c), - options: asc, - }, - ]); - - let result = eq_properties.with_reorder(new_order); - - // Should preserve the original [d ASC, a ASC] ordering - assert_eq!(result.oeq_class().len(), 1); - let ordering = result.oeq_class().iter().next().unwrap(); - assert_eq!(ordering.len(), 2); - - // First expression should be either b or d (they're equivalent) - assert!( - ordering[0].expr.eq(&col_b) || ordering[0].expr.eq(&col_d), - "Expected b or d as first expression, got {:?}", - ordering[0].expr - ); - assert!(ordering[0].options.eq(&asc)); + // New ordering: [b ASC, c ASC] + let new_order = vec![ + PhysicalSortExpr::new_default(Arc::clone(&col_b)), + PhysicalSortExpr::new_default(Arc::clone(&col_c)), + ]; - // Second expression should be a - assert!(ordering[1].expr.eq(&col_a)); - assert!(ordering[1].options.eq(&asc)); + let old_orderings = eq_properties.oeq_class().clone(); + let change = eq_properties.reorder(new_order)?; + // Original orderings should be preserved: + assert!(!change); + assert_eq!(eq_properties.oeq_class, old_orderings); Ok(()) } @@ -1691,81 +1467,62 @@ mod tests { { let mut eq_properties = EquivalenceProperties::new(Arc::clone(schema)); - // Convert base ordering - let base_ordering = LexOrdering::new( - base_order - .iter() - .map(|col_name| PhysicalSortExpr { - expr: col(col_name, schema).unwrap(), - options: SortOptions::default(), - }) - .collect(), - ); - // Convert string column names to orderings - let satisfied_orderings: Vec = satisfied_orders + let satisfied_orderings: Vec<_> = satisfied_orders .iter() .map(|cols| { - LexOrdering::new( - cols.iter() - .map(|col_name| PhysicalSortExpr { - expr: col(col_name, schema).unwrap(), - options: SortOptions::default(), - }) - .collect(), - ) + cols.iter() + .map(|col_name| { + PhysicalSortExpr::new_default(col(col_name, schema).unwrap()) + }) + .collect::>() }) .collect(); - let unsatisfied_orderings: Vec = unsatisfied_orders + let unsatisfied_orderings: Vec<_> = unsatisfied_orders .iter() .map(|cols| { - LexOrdering::new( - cols.iter() - .map(|col_name| PhysicalSortExpr { - expr: col(col_name, schema).unwrap(), - options: SortOptions::default(), - }) - .collect(), - ) + cols.iter() + .map(|col_name| { + PhysicalSortExpr::new_default(col(col_name, schema).unwrap()) + }) + .collect::>() }) .collect(); // Test that orderings are not satisfied before adding constraints - for ordering in &satisfied_orderings { - assert!( - !eq_properties.ordering_satisfy(ordering), - "{}: ordering {:?} should not be satisfied before adding constraints", - name, - ordering + for ordering in satisfied_orderings.clone() { + let err_msg = format!( + "{name}: ordering {ordering:?} should not be satisfied before adding constraints", ); + assert!(!eq_properties.ordering_satisfy(ordering)?, "{err_msg}"); } // Add base ordering - eq_properties.add_new_ordering(base_ordering); + let base_ordering = base_order.iter().map(|col_name| PhysicalSortExpr { + expr: col(col_name, schema).unwrap(), + options: SortOptions::default(), + }); + eq_properties.add_ordering(base_ordering); // Add constraints eq_properties = eq_properties.with_constraints(Constraints::new_unverified(constraints)); // Test that expected orderings are now satisfied - for ordering in &satisfied_orderings { - assert!( - eq_properties.ordering_satisfy(ordering), - "{}: ordering {:?} should be satisfied after adding constraints", - name, - ordering + for ordering in satisfied_orderings { + let err_msg = format!( + "{name}: ordering {ordering:?} should be satisfied after adding constraints", ); + assert!(eq_properties.ordering_satisfy(ordering)?, "{err_msg}"); } // Test that unsatisfied orderings remain unsatisfied - for ordering in &unsatisfied_orderings { - assert!( - !eq_properties.ordering_satisfy(ordering), - "{}: ordering {:?} should not be satisfied after adding constraints", - name, - ordering + for ordering in unsatisfied_orderings { + let err_msg = format!( + "{name}: ordering {ordering:?} should not be satisfied after adding constraints", ); + assert!(!eq_properties.ordering_satisfy(ordering)?, "{err_msg}"); } } diff --git a/datafusion/physical-expr/src/equivalence/properties/joins.rs b/datafusion/physical-expr/src/equivalence/properties/joins.rs index 7944e89d0305a..485b11d586397 100644 --- a/datafusion/physical-expr/src/equivalence/properties/joins.rs +++ b/datafusion/physical-expr/src/equivalence/properties/joins.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. +use super::EquivalenceProperties; use crate::{equivalence::OrderingEquivalenceClass, PhysicalExprRef}; -use arrow::datatypes::SchemaRef; -use datafusion_common::{JoinSide, JoinType}; -use super::EquivalenceProperties; +use arrow::datatypes::SchemaRef; +use datafusion_common::{JoinSide, JoinType, Result}; /// Calculate ordering equivalence properties for the given join operation. pub fn join_equivalence_properties( @@ -30,7 +30,7 @@ pub fn join_equivalence_properties( maintains_input_order: &[bool], probe_side: Option, on: &[(PhysicalExprRef, PhysicalExprRef)], -) -> EquivalenceProperties { +) -> Result { let left_size = left.schema.fields.len(); let mut result = EquivalenceProperties::new(join_schema); result.add_equivalence_group(left.eq_group().join( @@ -38,15 +38,13 @@ pub fn join_equivalence_properties( join_type, left_size, on, - )); + )?)?; let EquivalenceProperties { - constants: left_constants, oeq_class: left_oeq_class, .. } = left; let EquivalenceProperties { - constants: right_constants, oeq_class: mut right_oeq_class, .. } = right; @@ -54,12 +52,14 @@ pub fn join_equivalence_properties( [true, false] => { // In this special case, right side ordering can be prefixed with // the left side ordering. - if let (Some(JoinSide::Left), JoinType::Inner) = (probe_side, join_type) { + if matches!(join_type, JoinType::Inner | JoinType::Left) + && probe_side == Some(JoinSide::Left) + { updated_right_ordering_equivalence_class( &mut right_oeq_class, join_type, left_size, - ); + )?; // Right side ordering equivalence properties should be prepended // with those of the left side while constructing output ordering @@ -70,9 +70,9 @@ pub fn join_equivalence_properties( // then we should add `a ASC, b ASC` to the ordering equivalences // of the join output. let out_oeq_class = left_oeq_class.join_suffix(&right_oeq_class); - result.add_ordering_equivalence_class(out_oeq_class); + result.add_orderings(out_oeq_class); } else { - result.add_ordering_equivalence_class(left_oeq_class); + result.add_orderings(left_oeq_class); } } [false, true] => { @@ -80,10 +80,12 @@ pub fn join_equivalence_properties( &mut right_oeq_class, join_type, left_size, - ); + )?; // In this special case, left side ordering can be prefixed with // the right side ordering. - if let (Some(JoinSide::Right), JoinType::Inner) = (probe_side, join_type) { + if matches!(join_type, JoinType::Inner | JoinType::Right) + && probe_side == Some(JoinSide::Right) + { // Left side ordering equivalence properties should be prepended // with those of the right side while constructing output ordering // equivalence properties since stream side is the right side. @@ -93,25 +95,16 @@ pub fn join_equivalence_properties( // then we should add `b ASC, a ASC` to the ordering equivalences // of the join output. let out_oeq_class = right_oeq_class.join_suffix(&left_oeq_class); - result.add_ordering_equivalence_class(out_oeq_class); + result.add_orderings(out_oeq_class); } else { - result.add_ordering_equivalence_class(right_oeq_class); + result.add_orderings(right_oeq_class); } } [false, false] => {} [true, true] => unreachable!("Cannot maintain ordering of both sides"), _ => unreachable!("Join operators can not have more than two children"), } - match join_type { - JoinType::LeftAnti | JoinType::LeftSemi => { - result = result.with_constants(left_constants); - } - JoinType::RightAnti | JoinType::RightSemi => { - result = result.with_constants(right_constants); - } - _ => {} - } - result + Ok(result) } /// In the context of a join, update the right side `OrderingEquivalenceClass` @@ -125,28 +118,29 @@ pub fn updated_right_ordering_equivalence_class( right_oeq_class: &mut OrderingEquivalenceClass, join_type: &JoinType, left_size: usize, -) { +) -> Result<()> { if matches!( join_type, JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right ) { - right_oeq_class.add_offset(left_size); + right_oeq_class.add_offset(left_size as _)?; } + Ok(()) } #[cfg(test)] mod tests { - use std::sync::Arc; use super::*; - use crate::equivalence::add_offset_to_expr; - use crate::equivalence::tests::{convert_to_orderings, create_test_schema}; + use crate::equivalence::convert_to_orderings; + use crate::equivalence::tests::create_test_schema; use crate::expressions::col; - use datafusion_common::Result; + use crate::physical_expr::add_offset_to_expr; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Fields, Schema}; + use datafusion_common::Result; #[test] fn test_join_equivalence_properties() -> Result<()> { @@ -154,9 +148,9 @@ mod tests { let col_a = &col("a", &schema)?; let col_b = &col("b", &schema)?; let col_c = &col("c", &schema)?; - let offset = schema.fields.len(); - let col_a2 = &add_offset_to_expr(Arc::clone(col_a), offset); - let col_b2 = &add_offset_to_expr(Arc::clone(col_b), offset); + let offset = schema.fields.len() as _; + let col_a2 = &add_offset_to_expr(Arc::clone(col_a), offset)?; + let col_b2 = &add_offset_to_expr(Arc::clone(col_b), offset)?; let option_asc = SortOptions { descending: false, nulls_first: false, @@ -205,8 +199,8 @@ mod tests { let left_orderings = convert_to_orderings(&left_orderings); let right_orderings = convert_to_orderings(&right_orderings); let expected = convert_to_orderings(&expected); - left_eq_properties.add_new_orderings(left_orderings); - right_eq_properties.add_new_orderings(right_orderings); + left_eq_properties.add_orderings(left_orderings); + right_eq_properties.add_orderings(right_orderings); let join_eq = join_equivalence_properties( left_eq_properties, right_eq_properties, @@ -215,16 +209,14 @@ mod tests { &[true, false], Some(JoinSide::Left), &[], - ); + )?; let err_msg = format!("expected: {:?}, actual:{:?}", expected, &join_eq.oeq_class); - assert_eq!(join_eq.oeq_class.len(), expected.len(), "{}", err_msg); + assert_eq!(join_eq.oeq_class.len(), expected.len(), "{err_msg}"); for ordering in join_eq.oeq_class { assert!( expected.contains(&ordering), - "{}, ordering: {:?}", - err_msg, - ordering + "{err_msg}, ordering: {ordering:?}" ); } } @@ -255,7 +247,7 @@ mod tests { ]; let orderings = convert_to_orderings(&orderings); // Right child ordering equivalences - let mut right_oeq_class = OrderingEquivalenceClass::new(orderings); + let mut right_oeq_class = OrderingEquivalenceClass::from(orderings); let left_columns_len = 4; @@ -266,24 +258,24 @@ mod tests { // Join Schema let schema = Schema::new(fields); - let col_a = &col("a", &schema)?; - let col_d = &col("d", &schema)?; - let col_x = &col("x", &schema)?; - let col_y = &col("y", &schema)?; - let col_z = &col("z", &schema)?; - let col_w = &col("w", &schema)?; + let col_a = col("a", &schema)?; + let col_d = col("d", &schema)?; + let col_x = col("x", &schema)?; + let col_y = col("y", &schema)?; + let col_z = col("z", &schema)?; + let col_w = col("w", &schema)?; let mut join_eq_properties = EquivalenceProperties::new(Arc::new(schema)); // a=x and d=w - join_eq_properties.add_equal_conditions(col_a, col_x)?; - join_eq_properties.add_equal_conditions(col_d, col_w)?; + join_eq_properties.add_equal_conditions(col_a, Arc::clone(&col_x))?; + join_eq_properties.add_equal_conditions(col_d, Arc::clone(&col_w))?; updated_right_ordering_equivalence_class( &mut right_oeq_class, &join_type, left_columns_len, - ); - join_eq_properties.add_ordering_equivalence_class(right_oeq_class); + )?; + join_eq_properties.add_orderings(right_oeq_class); let result = join_eq_properties.oeq_class().clone(); // [x ASC, y ASC], [z ASC, w ASC] @@ -292,7 +284,7 @@ mod tests { vec![(col_z, option_asc), (col_w, option_asc)], ]; let orderings = convert_to_orderings(&orderings); - let expected = OrderingEquivalenceClass::new(orderings); + let expected = OrderingEquivalenceClass::from(orderings); assert_eq!(result, expected); diff --git a/datafusion/physical-expr/src/equivalence/properties/mod.rs b/datafusion/physical-expr/src/equivalence/properties/mod.rs index c7c33ba5b2ba5..2404b8f0dd3eb 100644 --- a/datafusion/physical-expr/src/equivalence/properties/mod.rs +++ b/datafusion/physical-expr/src/equivalence/properties/mod.rs @@ -19,47 +19,43 @@ mod dependency; // Submodule containing DependencyMap and Dependencies mod joins; // Submodule containing join_equivalence_properties mod union; // Submodule containing calculate_union -use dependency::{ - construct_prefix_orderings, generate_dependency_orderings, referred_dependencies, - Dependencies, DependencyMap, -}; pub use joins::*; pub use union::*; -use std::fmt::Display; -use std::hash::{Hash, Hasher}; +use std::fmt::{self, Display}; +use std::mem; use std::sync::Arc; -use std::{fmt, mem}; -use crate::equivalence::class::{const_exprs_contains, AcrossPartitions}; +use self::dependency::{ + construct_prefix_orderings, generate_dependency_orderings, referred_dependencies, + Dependencies, DependencyMap, +}; use crate::equivalence::{ - EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, + AcrossPartitions, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; use crate::expressions::{with_new_schema, CastExpr, Column, Literal}; use crate::{ - physical_exprs_contains, ConstExpr, LexOrdering, LexRequirement, PhysicalExpr, - PhysicalSortExpr, PhysicalSortRequirement, + ConstExpr, LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, }; -use arrow::compute::SortOptions; use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_err, Constraint, Constraints, HashMap, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_physical_expr_common::sort_expr::options_compatible; use datafusion_physical_expr_common::utils::ExprPropertiesNode; use indexmap::IndexSet; use itertools::Itertools; -/// `EquivalenceProperties` stores information about the output -/// of a plan node, that can be used to optimize the plan. -/// -/// Currently, it keeps track of: -/// - Sort expressions (orderings) -/// - Equivalent expressions: expressions that are known to have same value. -/// - Constants expressions: expressions that are known to contain a single -/// constant value. +/// `EquivalenceProperties` stores information about the output of a plan node +/// that can be used to optimize the plan. Currently, it keeps track of: +/// - Sort expressions (orderings), +/// - Equivalent expressions; i.e. expressions known to have the same value. +/// - Constants expressions; i.e. expressions known to contain a single constant +/// value. /// /// Please see the [Using Ordering for Better Plans] blog for more details. /// @@ -81,8 +77,8 @@ use itertools::Itertools; /// ``` /// /// In this case, both `a ASC` and `b DESC` can describe the table ordering. -/// `EquivalenceProperties`, tracks these different valid sort expressions and -/// treat `a ASC` and `b DESC` on an equal footing. For example if the query +/// `EquivalenceProperties` tracks these different valid sort expressions and +/// treat `a ASC` and `b DESC` on an equal footing. For example, if the query /// specifies the output sorted by EITHER `a ASC` or `b DESC`, the sort can be /// avoided. /// @@ -101,12 +97,11 @@ use itertools::Itertools; /// └---┴---┘ /// ``` /// -/// In this case, columns `a` and `b` always have the same value, which can of -/// such equivalences inside this object. With this information, Datafusion can -/// optimize operations such as. For example, if the partition requirement is -/// `Hash(a)` and output partitioning is `Hash(b)`, then DataFusion avoids -/// repartitioning the data as the existing partitioning satisfies the -/// requirement. +/// In this case, columns `a` and `b` always have the same value. With this +/// information, Datafusion can optimize various operations. For example, if +/// the partition requirement is `Hash(a)` and output partitioning is +/// `Hash(b)`, then DataFusion avoids repartitioning the data as the existing +/// partitioning satisfies the requirement. /// /// # Code Example /// ``` @@ -125,40 +120,85 @@ use itertools::Itertools; /// # let col_c = col("c", &schema).unwrap(); /// // This object represents data that is sorted by a ASC, c DESC /// // with a single constant value of b -/// let mut eq_properties = EquivalenceProperties::new(schema) -/// .with_constants(vec![ConstExpr::from(col_b)]); -/// eq_properties.add_new_ordering(LexOrdering::new(vec![ +/// let mut eq_properties = EquivalenceProperties::new(schema); +/// eq_properties.add_constants(vec![ConstExpr::from(col_b)]); +/// eq_properties.add_ordering([ /// PhysicalSortExpr::new_default(col_a).asc(), /// PhysicalSortExpr::new_default(col_c).desc(), -/// ])); +/// ]); /// -/// assert_eq!(eq_properties.to_string(), "order: [[a@0 ASC, c@2 DESC]], const: [b@1(heterogeneous)]") +/// assert_eq!(eq_properties.to_string(), "order: [[a@0 ASC, c@2 DESC]], eq: [{members: [b@1], constant: (heterogeneous)}]"); /// ``` -#[derive(Debug, Clone)] +#[derive(Clone, Debug)] pub struct EquivalenceProperties { - /// Distinct equivalence classes (exprs known to have the same expressions) + /// Distinct equivalence classes (i.e. expressions with the same value). eq_group: EquivalenceGroup, - /// Equivalent sort expressions + /// Equivalent sort expressions (i.e. those define the same ordering). oeq_class: OrderingEquivalenceClass, - /// Expressions whose values are constant - /// - /// TODO: We do not need to track constants separately, they can be tracked - /// inside `eq_group` as `Literal` expressions. - constants: Vec, - /// Table constraints + /// Cache storing equivalent sort expressions in normal form (i.e. without + /// constants/duplicates and in standard form) and a map associating leading + /// terms with full sort expressions. + oeq_cache: OrderingEquivalenceCache, + /// Table constraints that factor in equivalence calculations. constraints: Constraints, /// Schema associated with this object. schema: SchemaRef, } +/// This object serves as a cache for storing equivalent sort expressions +/// in normal form, and a map associating leading sort expressions with +/// full lexicographical orderings. With this information, DataFusion can +/// efficiently determine whether a given ordering is satisfied by the +/// existing orderings, and discover new orderings based on the existing +/// equivalence properties. +#[derive(Clone, Debug, Default)] +struct OrderingEquivalenceCache { + /// Equivalent sort expressions in normal form. + normal_cls: OrderingEquivalenceClass, + /// Map associating leading sort expressions with full lexicographical + /// orderings. Values are indices into `normal_cls`. + leading_map: HashMap, Vec>, +} + +impl OrderingEquivalenceCache { + /// Creates a new `OrderingEquivalenceCache` object with the given + /// equivalent orderings, which should be in normal form. + pub fn new( + orderings: impl IntoIterator>, + ) -> Self { + let mut cache = Self { + normal_cls: OrderingEquivalenceClass::new(orderings), + leading_map: HashMap::new(), + }; + cache.update_map(); + cache + } + + /// Updates/reconstructs the leading expression map according to the normal + /// ordering equivalence class within. + pub fn update_map(&mut self) { + self.leading_map.clear(); + for (idx, ordering) in self.normal_cls.iter().enumerate() { + let expr = Arc::clone(&ordering.first().expr); + self.leading_map.entry(expr).or_default().push(idx); + } + } + + /// Clears the cache, removing all orderings and leading expressions. + pub fn clear(&mut self) { + self.normal_cls.clear(); + self.leading_map.clear(); + } +} + impl EquivalenceProperties { /// Creates an empty `EquivalenceProperties` object. pub fn new(schema: SchemaRef) -> Self { Self { - eq_group: EquivalenceGroup::empty(), - oeq_class: OrderingEquivalenceClass::empty(), - constants: vec![], - constraints: Constraints::empty(), + eq_group: EquivalenceGroup::default(), + oeq_class: OrderingEquivalenceClass::default(), + oeq_cache: OrderingEquivalenceCache::default(), + constraints: Constraints::default(), schema, } } @@ -170,12 +210,23 @@ impl EquivalenceProperties { } /// Creates a new `EquivalenceProperties` object with the given orderings. - pub fn new_with_orderings(schema: SchemaRef, orderings: &[LexOrdering]) -> Self { + pub fn new_with_orderings( + schema: SchemaRef, + orderings: impl IntoIterator>, + ) -> Self { + let eq_group = EquivalenceGroup::default(); + let oeq_class = OrderingEquivalenceClass::new(orderings); + // Here, we can avoid performing a full normalization, and get by with + // only removing constants because the equivalence group is empty. + let normal_orderings = oeq_class.iter().cloned().map(|o| { + o.into_iter() + .filter(|sort_expr| eq_group.is_expr_constant(&sort_expr.expr).is_none()) + }); Self { - eq_group: EquivalenceGroup::empty(), - oeq_class: OrderingEquivalenceClass::new(orderings.to_vec()), - constants: vec![], - constraints: Constraints::empty(), + oeq_cache: OrderingEquivalenceCache::new(normal_orderings), + oeq_class, + eq_group, + constraints: Constraints::default(), schema, } } @@ -190,91 +241,131 @@ impl EquivalenceProperties { &self.oeq_class } - /// Return the inner OrderingEquivalenceClass, consuming self - pub fn into_oeq_class(self) -> OrderingEquivalenceClass { - self.oeq_class - } - /// Returns a reference to the equivalence group within. pub fn eq_group(&self) -> &EquivalenceGroup { &self.eq_group } - /// Returns a reference to the constant expressions - pub fn constants(&self) -> &[ConstExpr] { - &self.constants - } - + /// Returns a reference to the constraints within. pub fn constraints(&self) -> &Constraints { &self.constraints } - /// Returns the output ordering of the properties. - pub fn output_ordering(&self) -> Option { - let constants = self.constants(); - let mut output_ordering = self.oeq_class().output_ordering().unwrap_or_default(); - // Prune out constant expressions - output_ordering - .retain(|sort_expr| !const_exprs_contains(constants, &sort_expr.expr)); - (!output_ordering.is_empty()).then_some(output_ordering) + /// Returns all the known constants expressions. + pub fn constants(&self) -> Vec { + self.eq_group + .iter() + .flat_map(|c| { + c.iter().filter_map(|expr| { + c.constant + .as_ref() + .map(|across| ConstExpr::new(Arc::clone(expr), across.clone())) + }) + }) + .collect() } - /// Returns the normalized version of the ordering equivalence class within. - /// Normalization removes constants and duplicates as well as standardizing - /// expressions according to the equivalence group within. - pub fn normalized_oeq_class(&self) -> OrderingEquivalenceClass { - OrderingEquivalenceClass::new( - self.oeq_class - .iter() - .map(|ordering| self.normalize_sort_exprs(ordering)) - .collect(), - ) + /// Returns the output ordering of the properties. + pub fn output_ordering(&self) -> Option { + let concat = self.oeq_class.iter().flat_map(|o| o.iter().cloned()); + self.normalize_sort_exprs(concat) } /// Extends this `EquivalenceProperties` with the `other` object. - pub fn extend(mut self, other: Self) -> Self { - self.eq_group.extend(other.eq_group); - self.oeq_class.extend(other.oeq_class); - self.with_constants(other.constants) + pub fn extend(mut self, other: Self) -> Result { + self.constraints.extend(other.constraints); + self.add_equivalence_group(other.eq_group)?; + self.add_orderings(other.oeq_class); + Ok(self) } /// Clears (empties) the ordering equivalence class within this object. /// Call this method when existing orderings are invalidated. pub fn clear_orderings(&mut self) { self.oeq_class.clear(); + self.oeq_cache.clear(); } /// Removes constant expressions that may change across partitions. - /// This method should be used when data from different partitions are merged. + /// This method should be used when merging data from different partitions. pub fn clear_per_partition_constants(&mut self) { - self.constants.retain(|item| { - matches!(item.across_partitions(), AcrossPartitions::Uniform(_)) - }) - } - - /// Extends this `EquivalenceProperties` by adding the orderings inside the - /// ordering equivalence class `other`. - pub fn add_ordering_equivalence_class(&mut self, other: OrderingEquivalenceClass) { - self.oeq_class.extend(other); + if self.eq_group.clear_per_partition_constants() { + // Renormalize orderings if the equivalence group changes: + let normal_orderings = self + .oeq_class + .iter() + .cloned() + .map(|o| self.eq_group.normalize_sort_exprs(o)); + self.oeq_cache = OrderingEquivalenceCache::new(normal_orderings); + } } /// Adds new orderings into the existing ordering equivalence class. - pub fn add_new_orderings( + pub fn add_orderings( &mut self, - orderings: impl IntoIterator, + orderings: impl IntoIterator>, ) { - self.oeq_class.add_new_orderings(orderings); + let orderings: Vec<_> = + orderings.into_iter().filter_map(LexOrdering::new).collect(); + let normal_orderings: Vec<_> = orderings + .iter() + .cloned() + .filter_map(|o| self.normalize_sort_exprs(o)) + .collect(); + if !normal_orderings.is_empty() { + self.oeq_class.extend(orderings); + // Normalize given orderings to update the cache: + self.oeq_cache.normal_cls.extend(normal_orderings); + // TODO: If no ordering is found to be redundant during extension, we + // can use a shortcut algorithm to update the leading map. + self.oeq_cache.update_map(); + } } /// Adds a single ordering to the existing ordering equivalence class. - pub fn add_new_ordering(&mut self, ordering: LexOrdering) { - self.add_new_orderings([ordering]); + pub fn add_ordering(&mut self, ordering: impl IntoIterator) { + self.add_orderings(std::iter::once(ordering)); + } + + fn update_oeq_cache(&mut self) -> Result<()> { + // Renormalize orderings if the equivalence group changes: + let normal_cls = mem::take(&mut self.oeq_cache.normal_cls); + let normal_orderings = normal_cls + .into_iter() + .map(|o| self.eq_group.normalize_sort_exprs(o)); + self.oeq_cache.normal_cls = OrderingEquivalenceClass::new(normal_orderings); + self.oeq_cache.update_map(); + // Discover any new orderings based on the new equivalence classes: + let leading_exprs: Vec<_> = self.oeq_cache.leading_map.keys().cloned().collect(); + for expr in leading_exprs { + self.discover_new_orderings(expr)?; + } + Ok(()) } /// Incorporates the given equivalence group to into the existing /// equivalence group within. - pub fn add_equivalence_group(&mut self, other_eq_group: EquivalenceGroup) { - self.eq_group.extend(other_eq_group); + pub fn add_equivalence_group( + &mut self, + other_eq_group: EquivalenceGroup, + ) -> Result<()> { + if !other_eq_group.is_empty() { + self.eq_group.extend(other_eq_group); + self.update_oeq_cache()?; + } + Ok(()) + } + + /// Returns the ordering equivalence class within in normal form. + /// Normalization standardizes expressions according to the equivalence + /// group within, and removes constants/duplicates. + pub fn normalized_oeq_class(&self) -> OrderingEquivalenceClass { + self.oeq_class + .iter() + .cloned() + .filter_map(|ordering| self.normalize_sort_exprs(ordering)) + .collect::>() + .into() } /// Adds a new equality condition into the existing equivalence group. @@ -282,286 +373,229 @@ impl EquivalenceProperties { /// equivalence class to the equivalence group. pub fn add_equal_conditions( &mut self, - left: &Arc, - right: &Arc, + left: Arc, + right: Arc, ) -> Result<()> { - // Discover new constants in light of new the equality: - if self.is_expr_constant(left) { - // Left expression is constant, add right as constant - if !const_exprs_contains(&self.constants, right) { - let const_expr = ConstExpr::from(right) - .with_across_partitions(self.get_expr_constant_value(left)); - self.constants.push(const_expr); - } - } else if self.is_expr_constant(right) { - // Right expression is constant, add left as constant - if !const_exprs_contains(&self.constants, left) { - let const_expr = ConstExpr::from(left) - .with_across_partitions(self.get_expr_constant_value(right)); - self.constants.push(const_expr); - } + // Add equal expressions to the state: + if self.eq_group.add_equal_conditions(Arc::clone(&left), right) { + self.update_oeq_cache()?; } - - // Add equal expressions to the state - self.eq_group.add_equal_conditions(left, right); - - // Discover any new orderings - self.discover_new_orderings(left)?; + self.update_oeq_cache()?; Ok(()) } /// Track/register physical expressions with constant values. - #[deprecated(since = "43.0.0", note = "Use [`with_constants`] instead")] - pub fn add_constants(self, constants: impl IntoIterator) -> Self { - self.with_constants(constants) - } - - /// Remove the specified constant - pub fn remove_constant(mut self, c: &ConstExpr) -> Self { - self.constants.retain(|existing| existing != c); - self - } - - /// Track/register physical expressions with constant values. - pub fn with_constants( - mut self, + pub fn add_constants( + &mut self, constants: impl IntoIterator, - ) -> Self { - let normalized_constants = constants - .into_iter() - .filter_map(|c| { - let across_partitions = c.across_partitions(); - let expr = c.owned_expr(); - let normalized_expr = self.eq_group.normalize_expr(expr); - - if const_exprs_contains(&self.constants, &normalized_expr) { - return None; - } - - let const_expr = ConstExpr::from(normalized_expr) - .with_across_partitions(across_partitions); - - Some(const_expr) + ) -> Result<()> { + // Add the new constant to the equivalence group: + for constant in constants { + self.eq_group.add_constant(constant); + } + // Renormalize the orderings after adding new constants by removing + // the constants from existing orderings: + let normal_cls = mem::take(&mut self.oeq_cache.normal_cls); + let normal_orderings = normal_cls.into_iter().map(|ordering| { + ordering.into_iter().filter(|sort_expr| { + self.eq_group.is_expr_constant(&sort_expr.expr).is_none() }) - .collect::>(); - - // Add all new normalized constants - self.constants.extend(normalized_constants); - - // Discover any new orderings based on the constants - for ordering in self.normalized_oeq_class().iter() { - if let Err(e) = self.discover_new_orderings(&ordering[0].expr) { - log::debug!("error discovering new orderings: {e}"); - } + }); + self.oeq_cache.normal_cls = OrderingEquivalenceClass::new(normal_orderings); + self.oeq_cache.update_map(); + // Discover any new orderings based on the constants: + let leading_exprs: Vec<_> = self.oeq_cache.leading_map.keys().cloned().collect(); + for expr in leading_exprs { + self.discover_new_orderings(expr)?; } - - self + Ok(()) } - // Discover new valid orderings in light of a new equality. - // Accepts a single argument (`expr`) which is used to determine - // which orderings should be updated. - // When constants or equivalence classes are changed, there may be new orderings - // that can be discovered with the new equivalence properties. - // For a discussion, see: https://github.com/apache/datafusion/issues/9812 - fn discover_new_orderings(&mut self, expr: &Arc) -> Result<()> { - let normalized_expr = self.eq_group().normalize_expr(Arc::clone(expr)); + /// Discover new valid orderings in light of a new equality. Accepts a single + /// argument (`expr`) which is used to determine the orderings to update. + /// When constants or equivalence classes change, there may be new orderings + /// that can be discovered with the new equivalence properties. + /// For a discussion, see: + fn discover_new_orderings( + &mut self, + normal_expr: Arc, + ) -> Result<()> { + let Some(ordering_idxs) = self.oeq_cache.leading_map.get(&normal_expr) else { + return Ok(()); + }; let eq_class = self .eq_group - .iter() - .find_map(|class| { - class - .contains(&normalized_expr) - .then(|| class.clone().into_vec()) - }) - .unwrap_or_else(|| vec![Arc::clone(&normalized_expr)]); - - let mut new_orderings: Vec = vec![]; - for ordering in self.normalized_oeq_class().iter() { - if !ordering[0].expr.eq(&normalized_expr) { - continue; - } + .get_equivalence_class(&normal_expr) + .map_or_else(|| vec![normal_expr], |class| class.clone().into()); + let mut new_orderings = vec![]; + for idx in ordering_idxs { + let ordering = &self.oeq_cache.normal_cls[*idx]; let leading_ordering_options = ordering[0].options; - for equivalent_expr in &eq_class { + 'exprs: for equivalent_expr in &eq_class { let children = equivalent_expr.children(); if children.is_empty() { continue; } - - // Check if all children match the next expressions in the ordering - let mut all_children_match = true; + // Check if all children match the next expressions in the ordering: let mut child_properties = vec![]; - - // Build properties for each child based on the next expressions - for (i, child) in children.iter().enumerate() { - if let Some(next) = ordering.get(i + 1) { - if !child.as_ref().eq(next.expr.as_ref()) { - all_children_match = false; - break; - } - child_properties.push(ExprProperties { - sort_properties: SortProperties::Ordered(next.options), - range: Interval::make_unbounded( - &child.data_type(&self.schema)?, - )?, - preserves_lex_ordering: true, - }); - } else { - all_children_match = false; - break; + // Build properties for each child based on the next expression: + for (i, child) in children.into_iter().enumerate() { + let Some(next) = ordering.get(i + 1) else { + break 'exprs; + }; + if !next.expr.eq(child) { + break 'exprs; } + let data_type = child.data_type(&self.schema)?; + child_properties.push(ExprProperties { + sort_properties: SortProperties::Ordered(next.options), + range: Interval::make_unbounded(&data_type)?, + preserves_lex_ordering: true, + }); } - - if all_children_match { - // Check if the expression is monotonic in all arguments - if let Ok(expr_properties) = - equivalent_expr.get_properties(&child_properties) - { - if expr_properties.preserves_lex_ordering - && SortProperties::Ordered(leading_ordering_options) - == expr_properties.sort_properties - { - // Assume existing ordering is [c ASC, a ASC, b ASC] - // When equality c = f(a,b) is given, if we know that given ordering `[a ASC, b ASC]`, - // ordering `[f(a,b) ASC]` is valid, then we can deduce that ordering `[a ASC, b ASC]` is also valid. - // Hence, ordering `[a ASC, b ASC]` can be added to the state as a valid ordering. - // (e.g. existing ordering where leading ordering is removed) - new_orderings.push(LexOrdering::new(ordering[1..].to_vec())); - break; - } - } + // Check if the expression is monotonic in all arguments: + let expr_properties = + equivalent_expr.get_properties(&child_properties)?; + if expr_properties.preserves_lex_ordering + && expr_properties.sort_properties + == SortProperties::Ordered(leading_ordering_options) + { + // Assume that `[c ASC, a ASC, b ASC]` is among existing + // orderings. If equality `c = f(a, b)` is given, ordering + // `[a ASC, b ASC]` implies the ordering `[c ASC]`. Thus, + // ordering `[a ASC, b ASC]` is also a valid ordering. + new_orderings.push(ordering[1..].to_vec()); + break; } } } - self.oeq_class.add_new_orderings(new_orderings); - Ok(()) - } - - /// Updates the ordering equivalence group within assuming that the table - /// is re-sorted according to the argument `sort_exprs`. Note that constants - /// and equivalence classes are unchanged as they are unaffected by a re-sort. - /// If the given ordering is already satisfied, the function does nothing. - pub fn with_reorder(mut self, sort_exprs: LexOrdering) -> Self { - // Filter out constant expressions as they don't affect ordering - let filtered_exprs = LexOrdering::new( - sort_exprs - .into_iter() - .filter(|expr| !self.is_expr_constant(&expr.expr)) - .collect(), - ); - - if filtered_exprs.is_empty() { - return self; - } - - let mut new_orderings = vec![filtered_exprs.clone()]; - - // Preserve valid suffixes from existing orderings - let oeq_class = mem::take(&mut self.oeq_class); - for existing in oeq_class { - if self.is_prefix_of(&filtered_exprs, &existing) { - let mut extended = filtered_exprs.clone(); - extended.extend(existing.into_iter().skip(filtered_exprs.len())); - new_orderings.push(extended); - } + if !new_orderings.is_empty() { + self.add_orderings(new_orderings); } - - self.oeq_class = OrderingEquivalenceClass::new(new_orderings); - self + Ok(()) } - /// Checks if the new ordering matches a prefix of the existing ordering - /// (considering expression equivalences) - fn is_prefix_of(&self, new_order: &LexOrdering, existing: &LexOrdering) -> bool { - // Check if new order is longer than existing - can't be a prefix - if new_order.len() > existing.len() { - return false; + /// Updates the ordering equivalence class within assuming that the table + /// is re-sorted according to the argument `ordering`, and returns whether + /// this operation resulted in any change. Note that equivalence classes + /// (and constants) do not change as they are unaffected by a re-sort. If + /// the given ordering is already satisfied, the function does nothing. + pub fn reorder( + &mut self, + ordering: impl IntoIterator, + ) -> Result { + let (ordering, ordering_tee) = ordering.into_iter().tee(); + // First, standardize the given ordering: + let Some(normal_ordering) = self.normalize_sort_exprs(ordering) else { + // If the ordering vanishes after normalization, it is satisfied: + return Ok(false); + }; + if normal_ordering.len() != self.common_sort_prefix_length(&normal_ordering)? { + // If the ordering is unsatisfied, replace existing orderings: + self.clear_orderings(); + self.add_ordering(ordering_tee); + return Ok(true); } - - // Check if new order matches existing prefix (considering equivalences) - new_order.iter().zip(existing).all(|(new, existing)| { - self.eq_group.exprs_equal(&new.expr, &existing.expr) - && new.options == existing.options - }) + Ok(false) } /// Normalizes the given sort expressions (i.e. `sort_exprs`) using the - /// equivalence group and the ordering equivalence class within. - /// - /// Assume that `self.eq_group` states column `a` and `b` are aliases. - /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` - /// are equivalent (in the sense that both describe the ordering of the table). - /// If the `sort_exprs` argument were `vec![b ASC, c ASC, a ASC]`, then this - /// function would return `vec![a ASC, c ASC]`. Internally, it would first - /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result - /// after deduplication. - fn normalize_sort_exprs(&self, sort_exprs: &LexOrdering) -> LexOrdering { - // Convert sort expressions to sort requirements: - let sort_reqs = LexRequirement::from(sort_exprs.clone()); - // Normalize the requirements: - let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); - // Convert sort requirements back to sort expressions: - LexOrdering::from(normalized_sort_reqs) + /// equivalence group within. Returns a `LexOrdering` instance if the + /// expressions define a proper lexicographical ordering. For more details, + /// see [`EquivalenceGroup::normalize_sort_exprs`]. + pub fn normalize_sort_exprs( + &self, + sort_exprs: impl IntoIterator, + ) -> Option { + LexOrdering::new(self.eq_group.normalize_sort_exprs(sort_exprs)) } /// Normalizes the given sort requirements (i.e. `sort_reqs`) using the - /// equivalence group and the ordering equivalence class within. It works by: - /// - Removing expressions that have a constant value from the given requirement. - /// - Replacing sections that belong to some equivalence class in the equivalence - /// group with the first entry in the matching equivalence class. - /// - /// Assume that `self.eq_group` states column `a` and `b` are aliases. - /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` - /// are equivalent (in the sense that both describe the ordering of the table). - /// If the `sort_reqs` argument were `vec![b ASC, c ASC, a ASC]`, then this - /// function would return `vec![a ASC, c ASC]`. Internally, it would first - /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result - /// after deduplication. - fn normalize_sort_requirements(&self, sort_reqs: &LexRequirement) -> LexRequirement { - let normalized_sort_reqs = self.eq_group.normalize_sort_requirements(sort_reqs); - let mut constant_exprs = vec![]; - constant_exprs.extend( - self.constants - .iter() - .map(|const_expr| Arc::clone(const_expr.expr())), - ); - let constants_normalized = self.eq_group.normalize_exprs(constant_exprs); - // Prune redundant sections in the requirement: - normalized_sort_reqs - .iter() - .filter(|&order| !physical_exprs_contains(&constants_normalized, &order.expr)) - .cloned() - .collect::() - .collapse() + /// equivalence group within. Returns a `LexRequirement` instance if the + /// expressions define a proper lexicographical requirement. For more + /// details, see [`EquivalenceGroup::normalize_sort_exprs`]. + pub fn normalize_sort_requirements( + &self, + sort_reqs: impl IntoIterator, + ) -> Option { + LexRequirement::new(self.eq_group.normalize_sort_requirements(sort_reqs)) } - /// Checks whether the given ordering is satisfied by any of the existing - /// orderings. - pub fn ordering_satisfy(&self, given: &LexOrdering) -> bool { - // Convert the given sort expressions to sort requirements: - let sort_requirements = LexRequirement::from(given.clone()); - self.ordering_satisfy_requirement(&sort_requirements) + /// Iteratively checks whether the given ordering is satisfied by any of + /// the existing orderings. See [`Self::ordering_satisfy_requirement`] for + /// more details and examples. + pub fn ordering_satisfy( + &self, + given: impl IntoIterator, + ) -> Result { + // First, standardize the given ordering: + let Some(normal_ordering) = self.normalize_sort_exprs(given) else { + // If the ordering vanishes after normalization, it is satisfied: + return Ok(true); + }; + Ok(normal_ordering.len() == self.common_sort_prefix_length(&normal_ordering)?) } - /// Checks whether the given sort requirements are satisfied by any of the - /// existing orderings. - pub fn ordering_satisfy_requirement(&self, reqs: &LexRequirement) -> bool { - let mut eq_properties = self.clone(); + /// Iteratively checks whether the given sort requirement is satisfied by + /// any of the existing orderings. + /// + /// ### Example Scenarios + /// + /// In these scenarios, assume that all expressions share the same sort + /// properties. + /// + /// #### Case 1: Sort Requirement `[a, c]` + /// + /// **Existing orderings:** `[[a, b, c], [a, d]]`, **constants:** `[]` + /// 1. The function first checks the leading requirement `a`, which is + /// satisfied by `[a, b, c].first()`. + /// 2. `a` is added as a constant for the next iteration. + /// 3. Normal orderings become `[[b, c], [d]]`. + /// 4. The function fails for `c` in the second iteration, as neither + /// `[b, c]` nor `[d]` satisfies `c`. + /// + /// #### Case 2: Sort Requirement `[a, d]` + /// + /// **Existing orderings:** `[[a, b, c], [a, d]]`, **constants:** `[]` + /// 1. The function first checks the leading requirement `a`, which is + /// satisfied by `[a, b, c].first()`. + /// 2. `a` is added as a constant for the next iteration. + /// 3. Normal orderings become `[[b, c], [d]]`. + /// 4. The function returns `true` as `[d]` satisfies `d`. + pub fn ordering_satisfy_requirement( + &self, + given: impl IntoIterator, + ) -> Result { // First, standardize the given requirement: - let normalized_reqs = eq_properties.normalize_sort_requirements(reqs); - - // Check whether given ordering is satisfied by constraints first - if self.satisfied_by_constraints(&normalized_reqs) { - return true; + let Some(normal_reqs) = self.normalize_sort_requirements(given) else { + // If the requirement vanishes after normalization, it is satisfied: + return Ok(true); + }; + // Then, check whether given requirement is satisfied by constraints: + if self.satisfied_by_constraints(&normal_reqs) { + return Ok(true); } - - for normalized_req in normalized_reqs { - // Check whether given ordering is satisfied - if !eq_properties.ordering_satisfy_single(&normalized_req) { - return false; + let schema = self.schema(); + let mut eq_properties = self.clone(); + for element in normal_reqs { + // Check whether given requirement is satisfied: + let ExprProperties { + sort_properties, .. + } = eq_properties.get_expr_properties(Arc::clone(&element.expr)); + let satisfy = match sort_properties { + SortProperties::Ordered(options) => element.options.is_none_or(|opts| { + let nullable = element.expr.nullable(schema).unwrap_or(true); + options_compatible(&options, &opts, nullable) + }), + // Singleton expressions satisfy any requirement. + SortProperties::Singleton => true, + SortProperties::Unordered => false, + }; + if !satisfy { + return Ok(false); } // Treat satisfied keys as constants in subsequent iterations. We // can do this because the "next" key only matters in a lexicographical @@ -575,263 +609,263 @@ impl EquivalenceProperties { // From the analysis above, we know that `[a ASC]` is satisfied. Then, // we add column `a` as constant to the algorithm state. This enables us // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. - eq_properties = eq_properties - .with_constants(std::iter::once(ConstExpr::from(normalized_req.expr))); + let const_expr = ConstExpr::from(element.expr); + eq_properties.add_constants(std::iter::once(const_expr))?; } - true + Ok(true) } - /// Checks if the sort requirements are satisfied by any of the table constraints (primary key or unique). - /// Returns true if any constraint fully satisfies the requirements. - fn satisfied_by_constraints( - &self, - normalized_reqs: &[PhysicalSortRequirement], - ) -> bool { - self.constraints.iter().any(|constraint| match constraint { - Constraint::PrimaryKey(indices) | Constraint::Unique(indices) => self - .satisfied_by_constraint( - normalized_reqs, - indices, - matches!(constraint, Constraint::Unique(_)), + /// Returns the number of consecutive sort expressions (starting from the + /// left) that are satisfied by the existing ordering. + fn common_sort_prefix_length(&self, normal_ordering: &LexOrdering) -> Result { + let full_length = normal_ordering.len(); + // Check whether the given ordering is satisfied by constraints: + if self.satisfied_by_constraints_ordering(normal_ordering) { + // If constraints satisfy all sort expressions, return the full + // length: + return Ok(full_length); + } + let schema = self.schema(); + let mut eq_properties = self.clone(); + for (idx, element) in normal_ordering.into_iter().enumerate() { + // Check whether given ordering is satisfied: + let ExprProperties { + sort_properties, .. + } = eq_properties.get_expr_properties(Arc::clone(&element.expr)); + let satisfy = match sort_properties { + SortProperties::Ordered(options) => options_compatible( + &options, + &element.options, + element.expr.nullable(schema).unwrap_or(true), ), - }) + // Singleton expressions satisfy any ordering. + SortProperties::Singleton => true, + SortProperties::Unordered => false, + }; + if !satisfy { + // As soon as one sort expression is unsatisfied, return how + // many we've satisfied so far: + return Ok(idx); + } + // Treat satisfied keys as constants in subsequent iterations. We + // can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + // + // For example, assume that the requirement is `[a ASC, (b + c) ASC]`, + // and existing equivalent orderings are `[a ASC, b ASC]` and `[c ASC]`. + // From the analysis above, we know that `[a ASC]` is satisfied. Then, + // we add column `a` as constant to the algorithm state. This enables us + // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. + let const_expr = ConstExpr::from(Arc::clone(&element.expr)); + eq_properties.add_constants(std::iter::once(const_expr))? + } + // All sort expressions are satisfied, return full length: + Ok(full_length) } - /// Checks if sort requirements are satisfied by a constraint (primary key or unique). - /// Returns true if the constraint indices form a valid prefix of an existing ordering - /// that matches the requirements. For unique constraints, also verifies nullable columns. - fn satisfied_by_constraint( + /// Determines the longest normal prefix of `ordering` satisfied by the + /// existing ordering. Returns that prefix as a new `LexOrdering`, and a + /// boolean indicating whether all the sort expressions are satisfied. + pub fn extract_common_sort_prefix( &self, - normalized_reqs: &[PhysicalSortRequirement], - indices: &[usize], - check_null: bool, - ) -> bool { - // Requirements must contain indices - if indices.len() > normalized_reqs.len() { - return false; + ordering: LexOrdering, + ) -> Result<(Vec, bool)> { + // First, standardize the given ordering: + let Some(normal_ordering) = self.normalize_sort_exprs(ordering) else { + // If the ordering vanishes after normalization, it is satisfied: + return Ok((vec![], true)); + }; + let prefix_len = self.common_sort_prefix_length(&normal_ordering)?; + let flag = prefix_len == normal_ordering.len(); + let mut sort_exprs: Vec<_> = normal_ordering.into(); + if !flag { + sort_exprs.truncate(prefix_len); } + Ok((sort_exprs, flag)) + } - // Iterate over all orderings - self.oeq_class.iter().any(|ordering| { - if indices.len() > ordering.len() { - return false; - } - - // Build a map of column positions in the ordering - let mut col_positions = HashMap::with_capacity(ordering.len()); - for (pos, req) in ordering.iter().enumerate() { - if let Some(col) = req.expr.as_any().downcast_ref::() { - col_positions.insert( - col.index(), - (pos, col.nullable(&self.schema).unwrap_or(true)), - ); - } - } - - // Check if all constraint indices appear in valid positions - if !indices.iter().all(|&idx| { - col_positions - .get(&idx) - .map(|&(pos, nullable)| { - // For unique constraints, verify column is not nullable if it's first/last - !check_null - || (pos != 0 && pos != ordering.len() - 1) - || !nullable + /// Checks if the sort expressions are satisfied by any of the table + /// constraints (primary key or unique). Returns true if any constraint + /// fully satisfies the expressions (i.e. constraint indices form a valid + /// prefix of an existing ordering that matches the expressions). For + /// unique constraints, also verifies nullable columns. + fn satisfied_by_constraints_ordering( + &self, + normal_exprs: &[PhysicalSortExpr], + ) -> bool { + self.constraints.iter().any(|constraint| match constraint { + Constraint::PrimaryKey(indices) | Constraint::Unique(indices) => { + let check_null = matches!(constraint, Constraint::Unique(_)); + let normalized_size = normal_exprs.len(); + indices.len() <= normalized_size + && self.oeq_class.iter().any(|ordering| { + let length = ordering.len(); + if indices.len() > length || normalized_size < length { + return false; + } + // Build a map of column positions in the ordering: + let mut col_positions = HashMap::with_capacity(length); + for (pos, req) in ordering.iter().enumerate() { + if let Some(col) = req.expr.as_any().downcast_ref::() + { + let nullable = col.nullable(&self.schema).unwrap_or(true); + col_positions.insert(col.index(), (pos, nullable)); + } + } + // Check if all constraint indices appear in valid positions: + if !indices.iter().all(|idx| { + col_positions.get(idx).is_some_and(|&(pos, nullable)| { + // For unique constraints, verify column is not nullable if it's first/last: + !check_null + || !nullable + || (pos != 0 && pos != length - 1) + }) + }) { + return false; + } + // Check if this ordering matches the prefix: + normal_exprs.iter().zip(ordering).all(|(given, existing)| { + existing.satisfy_expr(given, &self.schema) + }) }) - .unwrap_or(false) - }) { - return false; } - - // Check if this ordering matches requirements prefix - let ordering_len = ordering.len(); - normalized_reqs.len() >= ordering_len - && normalized_reqs[..ordering_len].iter().zip(ordering).all( - |(req, existing)| { - req.expr.eq(&existing.expr) - && req - .options - .is_none_or(|req_opts| req_opts == existing.options) - }, - ) }) } - /// Determines whether the ordering specified by the given sort requirement - /// is satisfied based on the orderings within, equivalence classes, and - /// constant expressions. - /// - /// # Parameters - /// - /// - `req`: A reference to a `PhysicalSortRequirement` for which the ordering - /// satisfaction check will be done. - /// - /// # Returns - /// - /// Returns `true` if the specified ordering is satisfied, `false` otherwise. - fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { - let ExprProperties { - sort_properties, .. - } = self.get_expr_properties(Arc::clone(&req.expr)); - match sort_properties { - SortProperties::Ordered(options) => { - let sort_expr = PhysicalSortExpr { - expr: Arc::clone(&req.expr), - options, - }; - sort_expr.satisfy(req, self.schema()) + /// Checks if the sort requirements are satisfied by any of the table + /// constraints (primary key or unique). Returns true if any constraint + /// fully satisfies the requirements (i.e. constraint indices form a valid + /// prefix of an existing ordering that matches the requirements). For + /// unique constraints, also verifies nullable columns. + fn satisfied_by_constraints(&self, normal_reqs: &[PhysicalSortRequirement]) -> bool { + self.constraints.iter().any(|constraint| match constraint { + Constraint::PrimaryKey(indices) | Constraint::Unique(indices) => { + let check_null = matches!(constraint, Constraint::Unique(_)); + let normalized_size = normal_reqs.len(); + indices.len() <= normalized_size + && self.oeq_class.iter().any(|ordering| { + let length = ordering.len(); + if indices.len() > length || normalized_size < length { + return false; + } + // Build a map of column positions in the ordering: + let mut col_positions = HashMap::with_capacity(length); + for (pos, req) in ordering.iter().enumerate() { + if let Some(col) = req.expr.as_any().downcast_ref::() + { + let nullable = col.nullable(&self.schema).unwrap_or(true); + col_positions.insert(col.index(), (pos, nullable)); + } + } + // Check if all constraint indices appear in valid positions: + if !indices.iter().all(|idx| { + col_positions.get(idx).is_some_and(|&(pos, nullable)| { + // For unique constraints, verify column is not nullable if it's first/last: + !check_null + || !nullable + || (pos != 0 && pos != length - 1) + }) + }) { + return false; + } + // Check if this ordering matches the prefix: + normal_reqs.iter().zip(ordering).all(|(given, existing)| { + existing.satisfy(given, &self.schema) + }) + }) } - // Singleton expressions satisfies any ordering. - SortProperties::Singleton => true, - SortProperties::Unordered => false, - } + }) } /// Checks whether the `given` sort requirements are equal or more specific /// than the `reference` sort requirements. pub fn requirements_compatible( &self, - given: &LexRequirement, - reference: &LexRequirement, + given: LexRequirement, + reference: LexRequirement, ) -> bool { - let normalized_given = self.normalize_sort_requirements(given); - let normalized_reference = self.normalize_sort_requirements(reference); + let Some(normal_given) = self.normalize_sort_requirements(given) else { + return true; + }; + let Some(normal_reference) = self.normalize_sort_requirements(reference) else { + return true; + }; - (normalized_reference.len() <= normalized_given.len()) - && normalized_reference + (normal_reference.len() <= normal_given.len()) + && normal_reference .into_iter() - .zip(normalized_given) + .zip(normal_given) .all(|(reference, given)| given.compatible(&reference)) } - /// Returns the finer ordering among the orderings `lhs` and `rhs`, breaking - /// any ties by choosing `lhs`. - /// - /// The finer ordering is the ordering that satisfies both of the orderings. - /// If the orderings are incomparable, returns `None`. - /// - /// For example, the finer ordering among `[a ASC]` and `[a ASC, b ASC]` is - /// the latter. - pub fn get_finer_ordering( - &self, - lhs: &LexOrdering, - rhs: &LexOrdering, - ) -> Option { - // Convert the given sort expressions to sort requirements: - let lhs = LexRequirement::from(lhs.clone()); - let rhs = LexRequirement::from(rhs.clone()); - let finer = self.get_finer_requirement(&lhs, &rhs); - // Convert the chosen sort requirements back to sort expressions: - finer.map(LexOrdering::from) - } - - /// Returns the finer ordering among the requirements `lhs` and `rhs`, - /// breaking any ties by choosing `lhs`. + /// Modify existing orderings by substituting sort expressions with appropriate + /// targets from the projection mapping. We substitute a sort expression when + /// its physical expression has a one-to-one functional relationship with a + /// target expression in the mapping. /// - /// The finer requirements are the ones that satisfy both of the given - /// requirements. If the requirements are incomparable, returns `None`. + /// After substitution, we may generate more than one `LexOrdering` for each + /// existing equivalent ordering. For example, `[a ASC, b ASC]` will turn + /// into `[CAST(a) ASC, b ASC]` and `[a ASC, b ASC]` when applying projection + /// expressions `a, b, CAST(a)`. /// - /// For example, the finer requirements among `[a ASC]` and `[a ASC, b ASC]` - /// is the latter. - pub fn get_finer_requirement( - &self, - req1: &LexRequirement, - req2: &LexRequirement, - ) -> Option { - let mut lhs = self.normalize_sort_requirements(req1); - let mut rhs = self.normalize_sort_requirements(req2); - lhs.inner - .iter_mut() - .zip(rhs.inner.iter_mut()) - .all(|(lhs, rhs)| { - lhs.expr.eq(&rhs.expr) - && match (lhs.options, rhs.options) { - (Some(lhs_opt), Some(rhs_opt)) => lhs_opt == rhs_opt, - (Some(options), None) => { - rhs.options = Some(options); - true - } - (None, Some(options)) => { - lhs.options = Some(options); - true - } - (None, None) => true, - } - }) - .then_some(if lhs.len() >= rhs.len() { lhs } else { rhs }) - } - - /// we substitute the ordering according to input expression type, this is a simplified version - /// In this case, we just substitute when the expression satisfy the following condition: - /// I. just have one column and is a CAST expression - /// TODO: Add one-to-ones analysis for monotonic ScalarFunctions. - /// TODO: we could precompute all the scenario that is computable, for example: atan(x + 1000) should also be substituted if - /// x is DESC or ASC - /// After substitution, we may generate more than 1 `LexOrdering`. As an example, - /// `[a ASC, b ASC]` will turn into `[a ASC, b ASC], [CAST(a) ASC, b ASC]` when projection expressions `a, b, CAST(a)` is applied. - pub fn substitute_ordering_component( - &self, + /// TODO: Handle all scenarios that allow substitution; e.g. when `x` is + /// sorted, `atan(x + 1000)` should also be substituted. For now, we + /// only consider single-column `CAST` expressions. + fn substitute_oeq_class( + schema: &SchemaRef, mapping: &ProjectionMapping, - sort_expr: &LexOrdering, - ) -> Result> { - let new_orderings = sort_expr - .iter() - .map(|sort_expr| { - let referring_exprs: Vec<_> = mapping - .iter() - .map(|(source, _target)| source) - .filter(|source| expr_refers(source, &sort_expr.expr)) - .cloned() - .collect(); - let mut res = LexOrdering::new(vec![sort_expr.clone()]); - // TODO: Add one-to-ones analysis for ScalarFunctions. - for r_expr in referring_exprs { - // we check whether this expression is substitutable or not - if let Some(cast_expr) = r_expr.as_any().downcast_ref::() { - // we need to know whether the Cast Expr matches or not - let expr_type = sort_expr.expr.data_type(&self.schema)?; - if cast_expr.expr.eq(&sort_expr.expr) - && cast_expr.is_bigger_cast(expr_type) + oeq_class: OrderingEquivalenceClass, + ) -> OrderingEquivalenceClass { + let new_orderings = oeq_class.into_iter().flat_map(|order| { + // Modify/expand existing orderings by substituting sort + // expressions with appropriate targets from the mapping: + order + .into_iter() + .map(|sort_expr| { + let referring_exprs = mapping + .iter() + .map(|(source, _target)| source) + .filter(|source| expr_refers(source, &sort_expr.expr)) + .cloned(); + let mut result = vec![]; + // The sort expression comes from this schema, so the + // following call to `unwrap` is safe. + let expr_type = sort_expr.expr.data_type(schema).unwrap(); + // TODO: Add one-to-one analysis for ScalarFunctions. + for r_expr in referring_exprs { + // We check whether this expression is substitutable. + if let Some(cast_expr) = + r_expr.as_any().downcast_ref::() { - res.push(PhysicalSortExpr { - expr: Arc::clone(&r_expr), - options: sort_expr.options, - }); + // For casts, we need to know whether the cast + // expression matches: + if cast_expr.expr.eq(&sort_expr.expr) + && cast_expr.is_bigger_cast(&expr_type) + { + result.push(PhysicalSortExpr::new( + r_expr, + sort_expr.options, + )); + } } } - } - Ok(res) - }) - .collect::>>()?; - // Generate all valid orderings, given substituted expressions. - let res = new_orderings - .into_iter() - .multi_cartesian_product() - .map(LexOrdering::new) - .collect::>(); - Ok(res) + result.push(sort_expr); + result + }) + // Generate all valid orderings given substituted expressions: + .multi_cartesian_product() + }); + OrderingEquivalenceClass::new(new_orderings) } - /// In projection, supposed we have a input function 'A DESC B DESC' and the output shares the same expression - /// with A and B, we could surely use the ordering of the original ordering, However, if the A has been changed, - /// for example, A-> Cast(A, Int64) or any other form, it is invalid if we continue using the original ordering - /// Since it would cause bug in dependency constructions, we should substitute the input order in order to get correct - /// dependency map, happen in issue 8838: - pub fn substitute_oeq_class(&mut self, mapping: &ProjectionMapping) -> Result<()> { - let new_order = self - .oeq_class - .iter() - .map(|order| self.substitute_ordering_component(mapping, order)) - .collect::>>()?; - let new_order = new_order.into_iter().flatten().collect(); - self.oeq_class = OrderingEquivalenceClass::new(new_order); - Ok(()) - } - /// Projects argument `expr` according to `projection_mapping`, taking - /// equivalences into account. + /// Projects argument `expr` according to the projection described by + /// `mapping`, taking equivalences into account. /// /// For example, assume that columns `a` and `c` are always equal, and that - /// `projection_mapping` encodes following mapping: + /// the projection described by `mapping` encodes the following: /// /// ```text /// a -> a1 @@ -839,13 +873,25 @@ impl EquivalenceProperties { /// ``` /// /// Then, this function projects `a + b` to `Some(a1 + b1)`, `c + b` to - /// `Some(a1 + b1)` and `d` to `None`, meaning that it cannot be projected. + /// `Some(a1 + b1)` and `d` to `None`, meaning that it is not projectable. pub fn project_expr( &self, expr: &Arc, - projection_mapping: &ProjectionMapping, + mapping: &ProjectionMapping, ) -> Option> { - self.eq_group.project_expr(projection_mapping, expr) + self.eq_group.project_expr(mapping, expr) + } + + /// Projects the given `expressions` according to the projection described + /// by `mapping`, taking equivalences into account. This function is similar + /// to [`Self::project_expr`], but projects multiple expressions at once + /// more efficiently than calling `project_expr` for each expression. + pub fn project_expressions<'a>( + &'a self, + expressions: impl IntoIterator> + 'a, + mapping: &'a ProjectionMapping, + ) -> impl Iterator>> + 'a { + self.eq_group.project_expressions(mapping, expressions) } /// Constructs a dependency map based on existing orderings referred to in @@ -877,71 +923,85 @@ impl EquivalenceProperties { /// b ASC: Node {Some(b_new ASC), HashSet{a ASC}} /// c ASC: Node {None, HashSet{a ASC}} /// ``` - fn construct_dependency_map(&self, mapping: &ProjectionMapping) -> DependencyMap { - let mut dependency_map = DependencyMap::new(); - for ordering in self.normalized_oeq_class().iter() { - for (idx, sort_expr) in ordering.iter().enumerate() { - let target_sort_expr = - self.project_expr(&sort_expr.expr, mapping).map(|expr| { - PhysicalSortExpr { - expr, - options: sort_expr.options, - } - }); - let is_projected = target_sort_expr.is_some(); - if is_projected - || mapping - .iter() - .any(|(source, _)| expr_refers(source, &sort_expr.expr)) - { - // Previous ordering is a dependency. Note that there is no, - // dependency for a leading ordering (i.e. the first sort - // expression). - let dependency = idx.checked_sub(1).map(|a| &ordering[a]); - // Add sort expressions that can be projected or referred to - // by any of the projection expressions to the dependency map: - dependency_map.insert( - sort_expr, - target_sort_expr.as_ref(), - dependency, - ); - } - if !is_projected { - // If we can not project, stop constructing the dependency - // map as remaining dependencies will be invalid after projection. + fn construct_dependency_map( + &self, + oeq_class: OrderingEquivalenceClass, + mapping: &ProjectionMapping, + ) -> DependencyMap { + let mut map = DependencyMap::default(); + for ordering in oeq_class.into_iter() { + // Previous expression is a dependency. Note that there is no + // dependency for the leading expression. + if !self.insert_to_dependency_map( + mapping, + ordering[0].clone(), + None, + &mut map, + ) { + continue; + } + for (dependency, sort_expr) in ordering.into_iter().tuple_windows() { + if !self.insert_to_dependency_map( + mapping, + sort_expr, + Some(dependency), + &mut map, + ) { + // If we can't project, stop constructing the dependency map + // as remaining dependencies will be invalid post projection. break; } } } - dependency_map + map } - /// Returns a new `ProjectionMapping` where source expressions are normalized. - /// - /// This normalization ensures that source expressions are transformed into a - /// consistent representation. This is beneficial for algorithms that rely on - /// exact equalities, as it allows for more precise and reliable comparisons. + /// Projects the sort expression according to the projection mapping and + /// inserts it into the dependency map with the given dependency. Returns + /// a boolean flag indicating whether the given expression is projectable. + fn insert_to_dependency_map( + &self, + mapping: &ProjectionMapping, + sort_expr: PhysicalSortExpr, + dependency: Option, + map: &mut DependencyMap, + ) -> bool { + let target_sort_expr = self + .project_expr(&sort_expr.expr, mapping) + .map(|expr| PhysicalSortExpr::new(expr, sort_expr.options)); + let projectable = target_sort_expr.is_some(); + if projectable + || mapping + .iter() + .any(|(source, _)| expr_refers(source, &sort_expr.expr)) + { + // Add sort expressions that can be projected or referred to + // by any of the projection expressions to the dependency map: + map.insert(sort_expr, target_sort_expr, dependency); + } + projectable + } + + /// Returns a new `ProjectionMapping` where source expressions are in normal + /// form. Normalization ensures that source expressions are transformed into + /// a consistent representation, which is beneficial for algorithms that rely + /// on exact equalities, as it allows for more precise and reliable comparisons. /// /// # Parameters /// - /// - `mapping`: A reference to the original `ProjectionMapping` to be normalized. + /// - `mapping`: A reference to the original `ProjectionMapping` to normalize. /// /// # Returns /// - /// A new `ProjectionMapping` with normalized source expressions. - fn normalized_mapping(&self, mapping: &ProjectionMapping) -> ProjectionMapping { - // Construct the mapping where source expressions are normalized. In this way - // In the algorithms below we can work on exact equalities - ProjectionMapping { - map: mapping - .iter() - .map(|(source, target)| { - let normalized_source = - self.eq_group.normalize_expr(Arc::clone(source)); - (normalized_source, Arc::clone(target)) - }) - .collect(), - } + /// A new `ProjectionMapping` with source expressions in normal form. + fn normalize_mapping(&self, mapping: &ProjectionMapping) -> ProjectionMapping { + mapping + .iter() + .map(|(source, target)| { + let normal_source = self.eq_group.normalize_expr(Arc::clone(source)); + (normal_source, target.clone()) + }) + .collect() } /// Computes projected orderings based on a given projection mapping. @@ -955,42 +1015,55 @@ impl EquivalenceProperties { /// /// - `mapping`: A reference to the `ProjectionMapping` that defines the /// relationship between source and target expressions. + /// - `oeq_class`: The `OrderingEquivalenceClass` containing the orderings + /// to project. /// /// # Returns /// - /// A vector of `LexOrdering` containing all valid orderings after projection. - fn projected_orderings(&self, mapping: &ProjectionMapping) -> Vec { - let mapping = self.normalized_mapping(mapping); - + /// A vector of all valid (but not in normal form) orderings after projection. + fn projected_orderings( + &self, + mapping: &ProjectionMapping, + mut oeq_class: OrderingEquivalenceClass, + ) -> Vec { + // Normalize source expressions in the mapping: + let mapping = self.normalize_mapping(mapping); // Get dependency map for existing orderings: - let dependency_map = self.construct_dependency_map(&mapping); - let orderings = mapping.iter().flat_map(|(source, target)| { + oeq_class = Self::substitute_oeq_class(&self.schema, &mapping, oeq_class); + let dependency_map = self.construct_dependency_map(oeq_class, &mapping); + let orderings = mapping.iter().flat_map(|(source, targets)| { referred_dependencies(&dependency_map, source) .into_iter() - .filter_map(|relevant_deps| { - if let Ok(SortProperties::Ordered(options)) = - get_expr_properties(source, &relevant_deps, &self.schema) - .map(|prop| prop.sort_properties) - { - Some((options, relevant_deps)) + .filter_map(|deps| { + let ep = get_expr_properties(source, &deps, &self.schema); + let sort_properties = ep.map(|prop| prop.sort_properties); + if let Ok(SortProperties::Ordered(options)) = sort_properties { + Some((options, deps)) } else { - // Do not consider unordered cases + // Do not consider unordered cases. None } }) .flat_map(|(options, relevant_deps)| { - let sort_expr = PhysicalSortExpr { - expr: Arc::clone(target), - options, - }; - // Generate dependent orderings (i.e. prefixes for `sort_expr`): - let mut dependency_orderings = + // Generate dependent orderings (i.e. prefixes for targets): + let dependency_orderings = generate_dependency_orderings(&relevant_deps, &dependency_map); - // Append `sort_expr` to the dependent orderings: - for ordering in dependency_orderings.iter_mut() { - ordering.push(sort_expr.clone()); + let sort_exprs = targets.iter().map(|(target, _)| { + PhysicalSortExpr::new(Arc::clone(target), options) + }); + if dependency_orderings.is_empty() { + sort_exprs.map(|sort_expr| [sort_expr].into()).collect() + } else { + sort_exprs + .flat_map(|sort_expr| { + let mut result = dependency_orderings.clone(); + for ordering in result.iter_mut() { + ordering.push(sort_expr.clone()); + } + result + }) + .collect::>() } - dependency_orderings }) }); @@ -1004,116 +1077,67 @@ impl EquivalenceProperties { if prefixes.is_empty() { // If prefix is empty, there is no dependency. Insert // empty ordering: - prefixes = vec![LexOrdering::default()]; - } - // Append current ordering on top its dependencies: - for ordering in prefixes.iter_mut() { - if let Some(target) = &node.target_sort_expr { - ordering.push(target.clone()) + if let Some(target) = &node.target { + prefixes.push([target.clone()].into()); + } + } else { + // Append current ordering on top its dependencies: + for ordering in prefixes.iter_mut() { + if let Some(target) = &node.target { + ordering.push(target.clone()); + } } } prefixes }); // Simplify each ordering by removing redundant sections: - orderings - .chain(projected_orderings) - .map(|lex_ordering| lex_ordering.collapse()) - .collect() - } - - /// Projects constants based on the provided `ProjectionMapping`. - /// - /// This function takes a `ProjectionMapping` and identifies/projects - /// constants based on the existing constants and the mapping. It ensures - /// that constants are appropriately propagated through the projection. - /// - /// # Parameters - /// - /// - `mapping`: A reference to a `ProjectionMapping` representing the - /// mapping of source expressions to target expressions in the projection. - /// - /// # Returns - /// - /// Returns a `Vec>` containing the projected constants. - fn projected_constants(&self, mapping: &ProjectionMapping) -> Vec { - // First, project existing constants. For example, assume that `a + b` - // is known to be constant. If the projection were `a as a_new`, `b as b_new`, - // then we would project constant `a + b` as `a_new + b_new`. - let mut projected_constants = self - .constants - .iter() - .flat_map(|const_expr| { - const_expr - .map(|expr| self.eq_group.project_expr(mapping, expr)) - .map(|projected_expr| { - projected_expr - .with_across_partitions(const_expr.across_partitions()) - }) - }) - .collect::>(); - - // Add projection expressions that are known to be constant: - for (source, target) in mapping.iter() { - if self.is_expr_constant(source) - && !const_exprs_contains(&projected_constants, target) - { - if self.is_expr_constant_across_partitions(source) { - projected_constants.push( - ConstExpr::from(target) - .with_across_partitions(self.get_expr_constant_value(source)), - ) - } else { - projected_constants.push( - ConstExpr::from(target) - .with_across_partitions(AcrossPartitions::Heterogeneous), - ) - } - } - } - projected_constants + orderings.chain(projected_orderings).collect() } /// Projects constraints according to the given projection mapping. /// - /// This function takes a projection mapping and extracts the column indices of the target columns. - /// It then projects the constraints to only include relationships between - /// columns that exist in the projected output. + /// This function takes a projection mapping and extracts column indices of + /// target columns. It then projects the constraints to only include + /// relationships between columns that exist in the projected output. /// - /// # Arguments + /// # Parameters /// - /// * `mapping` - A reference to `ProjectionMapping` that defines how expressions are mapped - /// in the projection operation + /// * `mapping` - A reference to the `ProjectionMapping` that defines the + /// projection operation. /// /// # Returns /// - /// Returns a new `Constraints` object containing only the constraints - /// that are valid for the projected columns. + /// Returns an optional `Constraints` object containing only the constraints + /// that are valid for the projected columns (if any exists). fn projected_constraints(&self, mapping: &ProjectionMapping) -> Option { let indices = mapping .iter() - .filter_map(|(_, target)| target.as_any().downcast_ref::()) - .map(|col| col.index()) + .flat_map(|(_, targets)| { + targets.iter().flat_map(|(target, _)| { + target.as_any().downcast_ref::().map(|c| c.index()) + }) + }) .collect::>(); - debug_assert_eq!(mapping.map.len(), indices.len()); self.constraints.project(&indices) } - /// Projects the equivalences within according to `mapping` - /// and `output_schema`. + /// Projects the equivalences within according to `mapping` and + /// `output_schema`. pub fn project(&self, mapping: &ProjectionMapping, output_schema: SchemaRef) -> Self { let eq_group = self.eq_group.project(mapping); - let oeq_class = OrderingEquivalenceClass::new(self.projected_orderings(mapping)); - let constants = self.projected_constants(mapping); - let constraints = self - .projected_constraints(mapping) - .unwrap_or_else(Constraints::empty); + let orderings = + self.projected_orderings(mapping, self.oeq_cache.normal_cls.clone()); + let normal_orderings = orderings + .iter() + .cloned() + .map(|o| eq_group.normalize_sort_exprs(o)); Self { + oeq_cache: OrderingEquivalenceCache::new(normal_orderings), + oeq_class: OrderingEquivalenceClass::new(orderings), + constraints: self.projected_constraints(mapping).unwrap_or_default(), schema: output_schema, eq_group, - oeq_class, - constants, - constraints, } } @@ -1130,7 +1154,7 @@ impl EquivalenceProperties { pub fn find_longest_permutation( &self, exprs: &[Arc], - ) -> (LexOrdering, Vec) { + ) -> Result<(Vec, Vec)> { let mut eq_properties = self.clone(); let mut result = vec![]; // The algorithm is as follows: @@ -1143,32 +1167,23 @@ impl EquivalenceProperties { // This algorithm should reach a fixed point in at most `exprs.len()` // iterations. let mut search_indices = (0..exprs.len()).collect::>(); - for _idx in 0..exprs.len() { + for _ in 0..exprs.len() { // Get ordered expressions with their indices. let ordered_exprs = search_indices .iter() - .flat_map(|&idx| { + .filter_map(|&idx| { let ExprProperties { sort_properties, .. } = eq_properties.get_expr_properties(Arc::clone(&exprs[idx])); match sort_properties { - SortProperties::Ordered(options) => Some(( - PhysicalSortExpr { - expr: Arc::clone(&exprs[idx]), - options, - }, - idx, - )), + SortProperties::Ordered(options) => { + let expr = Arc::clone(&exprs[idx]); + Some((PhysicalSortExpr::new(expr, options), idx)) + } SortProperties::Singleton => { - // Assign default ordering to constant expressions - let options = SortOptions::default(); - Some(( - PhysicalSortExpr { - expr: Arc::clone(&exprs[idx]), - options, - }, - idx, - )) + // Assign default ordering to constant expressions: + let expr = Arc::clone(&exprs[idx]); + Some((PhysicalSortExpr::new_default(expr), idx)) } SortProperties::Unordered => None, } @@ -1186,44 +1201,20 @@ impl EquivalenceProperties { // Note that these expressions are not properly "constants". This is just // an implementation strategy confined to this function. for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { - eq_properties = - eq_properties.with_constants(std::iter::once(ConstExpr::from(expr))); + let const_expr = ConstExpr::from(Arc::clone(expr)); + eq_properties.add_constants(std::iter::once(const_expr))?; search_indices.shift_remove(idx); } // Add new ordered section to the state. result.extend(ordered_exprs); } - let (left, right) = result.into_iter().unzip(); - (LexOrdering::new(left), right) - } - - /// This function determines whether the provided expression is constant - /// based on the known constants. - /// - /// # Parameters - /// - /// - `expr`: A reference to a `Arc` representing the - /// expression to be checked. - /// - /// # Returns - /// - /// Returns `true` if the expression is constant according to equivalence - /// group, `false` otherwise. - pub fn is_expr_constant(&self, expr: &Arc) -> bool { - // As an example, assume that we know columns `a` and `b` are constant. - // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will - // return `false`. - let const_exprs = self - .constants - .iter() - .map(|const_expr| Arc::clone(const_expr.expr())); - let normalized_constants = self.eq_group.normalize_exprs(const_exprs); - let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); - is_constant_recurse(&normalized_constants, &normalized_expr) + Ok(result.into_iter().unzip()) } /// This function determines whether the provided expression is constant - /// across partitions based on the known constants. + /// based on the known constants. For example, if columns `a` and `b` are + /// constant, then expressions `a`, `b` and `a + b` will all return `true` + /// whereas expression `c` will return `false`. /// /// # Parameters /// @@ -1232,87 +1223,15 @@ impl EquivalenceProperties { /// /// # Returns /// - /// Returns `true` if the expression is constant across all partitions according - /// to equivalence group, `false` otherwise - #[deprecated( - since = "45.0.0", - note = "Use [`is_expr_constant_across_partitions`] instead" - )] - pub fn is_expr_constant_accross_partitions( + /// Returns a `Some` value if the expression is constant according to + /// equivalence group, and `None` otherwise. The `Some` variant contains + /// an `AcrossPartitions` value indicating whether the expression is + /// constant across partitions, and its actual value (if available). + pub fn is_expr_constant( &self, expr: &Arc, - ) -> bool { - self.is_expr_constant_across_partitions(expr) - } - - /// This function determines whether the provided expression is constant - /// across partitions based on the known constants. - /// - /// # Parameters - /// - /// - `expr`: A reference to a `Arc` representing the - /// expression to be checked. - /// - /// # Returns - /// - /// Returns `true` if the expression is constant across all partitions according - /// to equivalence group, `false` otherwise. - pub fn is_expr_constant_across_partitions( - &self, - expr: &Arc, - ) -> bool { - // As an example, assume that we know columns `a` and `b` are constant. - // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will - // return `false`. - let const_exprs = self - .constants - .iter() - .filter_map(|const_expr| { - if matches!( - const_expr.across_partitions(), - AcrossPartitions::Uniform { .. } - ) { - Some(Arc::clone(const_expr.expr())) - } else { - None - } - }) - .collect::>(); - let normalized_constants = self.eq_group.normalize_exprs(const_exprs); - let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); - is_constant_recurse(&normalized_constants, &normalized_expr) - } - - /// Retrieves the constant value of a given physical expression, if it exists. - /// - /// Normalizes the input expression and checks if it matches any known constants - /// in the current context. Returns whether the expression has a uniform value, - /// varies across partitions, or is not constant. - /// - /// # Parameters - /// - `expr`: A reference to the physical expression to evaluate. - /// - /// # Returns - /// - `AcrossPartitions::Uniform(value)`: If the expression has the same value across partitions. - /// - `AcrossPartitions::Heterogeneous`: If the expression varies across partitions. - /// - `None`: If the expression is not recognized as constant. - pub fn get_expr_constant_value( - &self, - expr: &Arc, - ) -> AcrossPartitions { - let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); - - if let Some(lit) = normalized_expr.as_any().downcast_ref::() { - return AcrossPartitions::Uniform(Some(lit.value().clone())); - } - - for const_expr in self.constants.iter() { - if normalized_expr.eq(const_expr.expr()) { - return const_expr.across_partitions(); - } - } - - AcrossPartitions::Heterogeneous + ) -> Option { + self.eq_group.is_expr_constant(expr) } /// Retrieves the properties for a given physical expression. @@ -1335,13 +1254,12 @@ impl EquivalenceProperties { .transform_up(|expr| update_properties(expr, self)) .data() .map(|node| node.data) - .unwrap_or(ExprProperties::new_unknown()) + .unwrap_or_else(|_| ExprProperties::new_unknown()) } - /// Transforms this `EquivalenceProperties` into a new `EquivalenceProperties` - /// by mapping columns in the original schema to columns in the new schema - /// by index. - pub fn with_new_schema(self, schema: SchemaRef) -> Result { + /// Transforms this `EquivalenceProperties` by mapping columns in the + /// original schema to columns in the new schema by index. + pub fn with_new_schema(mut self, schema: SchemaRef) -> Result { // The new schema and the original schema is aligned when they have the // same number of columns, and fields at the same index have the same // type in both schemas. @@ -1356,54 +1274,49 @@ impl EquivalenceProperties { // Rewriting equivalence properties in terms of new schema is not // safe when schemas are not aligned: return plan_err!( - "Cannot rewrite old_schema:{:?} with new schema: {:?}", + "Schemas have to be aligned to rewrite equivalences:\n Old schema: {:?}\n New schema: {:?}", self.schema, schema ); } - // Rewrite constants according to new schema: - let new_constants = self - .constants - .into_iter() - .map(|const_expr| { - let across_partitions = const_expr.across_partitions(); - let new_const_expr = with_new_schema(const_expr.owned_expr(), &schema)?; - Ok(ConstExpr::new(new_const_expr) - .with_across_partitions(across_partitions)) - }) - .collect::>>()?; - - // Rewrite orderings according to new schema: - let mut new_orderings = vec![]; - for ordering in self.oeq_class { - let new_ordering = ordering - .into_iter() - .map(|mut sort_expr| { - sort_expr.expr = with_new_schema(sort_expr.expr, &schema)?; - Ok(sort_expr) - }) - .collect::>()?; - new_orderings.push(new_ordering); - } // Rewrite equivalence classes according to the new schema: let mut eq_classes = vec![]; - for eq_class in self.eq_group { - let new_eq_exprs = eq_class - .into_vec() + for mut eq_class in self.eq_group { + // Rewrite the expressions in the equivalence class: + eq_class.exprs = eq_class + .exprs .into_iter() .map(|expr| with_new_schema(expr, &schema)) .collect::>()?; - eq_classes.push(EquivalenceClass::new(new_eq_exprs)); + // Rewrite the constant value (if available and known): + let data_type = eq_class + .canonical_expr() + .map(|e| e.data_type(&schema)) + .transpose()?; + if let (Some(data_type), Some(AcrossPartitions::Uniform(Some(value)))) = + (data_type, &mut eq_class.constant) + { + *value = value.cast_to(&data_type)?; + } + eq_classes.push(eq_class); } + self.eq_group = eq_classes.into(); + + // Rewrite orderings according to new schema: + self.oeq_class = self.oeq_class.with_new_schema(&schema)?; + self.oeq_cache.normal_cls = self.oeq_cache.normal_cls.with_new_schema(&schema)?; + + // Update the schema: + self.schema = schema; - // Construct the resulting equivalence properties: - let mut result = EquivalenceProperties::new(schema); - result.constants = new_constants; - result.add_new_orderings(new_orderings); - result.add_equivalence_group(EquivalenceGroup::new(eq_classes)); + Ok(self) + } +} - Ok(result) +impl From for OrderingEquivalenceClass { + fn from(eq_properties: EquivalenceProperties) -> Self { + eq_properties.oeq_class } } @@ -1411,24 +1324,21 @@ impl EquivalenceProperties { /// /// Format: /// ```text -/// order: [[a ASC, b ASC], [a ASC, c ASC]], eq: [[a = b], [a = c]], const: [a = 1] +/// order: [[b@1 ASC NULLS LAST]], eq: [{members: [a@0], constant: (heterogeneous)}] /// ``` impl Display for EquivalenceProperties { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.eq_group.is_empty() - && self.oeq_class.is_empty() - && self.constants.is_empty() - { - return write!(f, "No properties"); - } - if !self.oeq_class.is_empty() { + let empty_eq_group = self.eq_group.is_empty(); + let empty_oeq_class = self.oeq_class.is_empty(); + if empty_oeq_class && empty_eq_group { + write!(f, "No properties")?; + } else if !empty_oeq_class { write!(f, "order: {}", self.oeq_class)?; - } - if !self.eq_group.is_empty() { - write!(f, ", eq: {}", self.eq_group)?; - } - if !self.constants.is_empty() { - write!(f, ", const: [{}]", ConstExpr::format_list(&self.constants))?; + if !empty_eq_group { + write!(f, ", eq: {}", self.eq_group)?; + } + } else { + write!(f, "eq: {}", self.eq_group)?; } Ok(()) } @@ -1472,45 +1382,20 @@ fn update_properties( Interval::make_unbounded(&node.expr.data_type(eq_properties.schema())?)? } // Now, check what we know about orderings: - let normalized_expr = eq_properties + let normal_expr = eq_properties .eq_group .normalize_expr(Arc::clone(&node.expr)); - let oeq_class = eq_properties.normalized_oeq_class(); - if eq_properties.is_expr_constant(&normalized_expr) - || oeq_class.is_expr_partial_const(&normalized_expr) + let oeq_class = &eq_properties.oeq_cache.normal_cls; + if eq_properties.is_expr_constant(&normal_expr).is_some() + || oeq_class.is_expr_partial_const(&normal_expr) { node.data.sort_properties = SortProperties::Singleton; - } else if let Some(options) = oeq_class.get_options(&normalized_expr) { + } else if let Some(options) = oeq_class.get_options(&normal_expr) { node.data.sort_properties = SortProperties::Ordered(options); } Ok(Transformed::yes(node)) } -/// This function determines whether the provided expression is constant -/// based on the known constants. -/// -/// # Parameters -/// -/// - `constants`: A `&[Arc]` containing expressions known to -/// be a constant. -/// - `expr`: A reference to a `Arc` representing the expression -/// to check. -/// -/// # Returns -/// -/// Returns `true` if the expression is constant according to equivalence -/// group, `false` otherwise. -fn is_constant_recurse( - constants: &[Arc], - expr: &Arc, -) -> bool { - if physical_exprs_contains(constants, expr) || expr.as_any().is::() { - return true; - } - let children = expr.children(); - !children.is_empty() && children.iter().all(|c| is_constant_recurse(constants, c)) -} - /// This function examines whether a referring expression directly refers to a /// given referred expression or if any of its children in the expression tree /// refer to the specified expression. @@ -1571,7 +1456,7 @@ fn get_expr_properties( } else if let Some(literal) = expr.as_any().downcast_ref::() { Ok(ExprProperties { sort_properties: SortProperties::Singleton, - range: Interval::try_new(literal.value().clone(), literal.value().clone())?, + range: literal.value().into(), preserves_lex_ordering: true, }) } else { @@ -1585,59 +1470,3 @@ fn get_expr_properties( expr.get_properties(&child_states) } } - -/// Wrapper struct for `Arc` to use them as keys in a hash map. -#[derive(Debug, Clone)] -struct ExprWrapper(Arc); - -impl PartialEq for ExprWrapper { - fn eq(&self, other: &Self) -> bool { - self.0.eq(&other.0) - } -} - -impl Eq for ExprWrapper {} - -impl Hash for ExprWrapper { - fn hash(&self, state: &mut H) { - self.0.hash(state); - } -} - -#[cfg(test)] -mod tests { - - use super::*; - use crate::expressions::{col, BinaryExpr}; - - use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; - use datafusion_expr::Operator; - - #[test] - fn test_expr_consists_of_constants() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - Field::new("d", DataType::Int32, true), - Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), - ])); - let col_a = col("a", &schema)?; - let col_b = col("b", &schema)?; - let col_d = col("d", &schema)?; - let b_plus_d = Arc::new(BinaryExpr::new( - Arc::clone(&col_b), - Operator::Plus, - Arc::clone(&col_d), - )) as Arc; - - let constants = vec![Arc::clone(&col_a), Arc::clone(&col_b)]; - let expr = Arc::clone(&b_plus_d); - assert!(!is_constant_recurse(&constants, &expr)); - - let constants = vec![Arc::clone(&col_a), Arc::clone(&col_b), Arc::clone(&col_d)]; - let expr = Arc::clone(&b_plus_d); - assert!(is_constant_recurse(&constants, &expr)); - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/equivalence/properties/union.rs b/datafusion/physical-expr/src/equivalence/properties/union.rs index 64ef9278e248b..efbefd0d39bfb 100644 --- a/datafusion/physical-expr/src/equivalence/properties/union.rs +++ b/datafusion/physical-expr/src/equivalence/properties/union.rs @@ -15,28 +15,26 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{internal_err, Result}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; use std::iter::Peekable; use std::sync::Arc; +use super::EquivalenceProperties; use crate::equivalence::class::AcrossPartitions; -use crate::ConstExpr; +use crate::{ConstExpr, PhysicalSortExpr}; -use super::EquivalenceProperties; -use crate::PhysicalSortExpr; use arrow::datatypes::SchemaRef; -use std::slice::Iter; +use datafusion_common::{internal_err, Result}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; -/// Calculates the union (in the sense of `UnionExec`) `EquivalenceProperties` -/// of `lhs` and `rhs` according to the schema of `lhs`. +/// Computes the union (in the sense of `UnionExec`) `EquivalenceProperties` +/// of `lhs` and `rhs` according to the schema of `lhs`. /// -/// Rules: The UnionExec does not interleave its inputs: instead it passes each -/// input partition from the children as its own output. +/// Rules: The `UnionExec` does not interleave its inputs, instead it passes +/// each input partition from the children as its own output. /// /// Since the output equivalence properties are properties that are true for /// *all* output partitions, that is the same as being true for all *input* -/// partitions +/// partitions. fn calculate_union_binary( lhs: EquivalenceProperties, mut rhs: EquivalenceProperties, @@ -48,28 +46,21 @@ fn calculate_union_binary( // First, calculate valid constants for the union. An expression is constant // at the output of the union if it is constant in both sides with matching values. + let rhs_constants = rhs.constants(); let constants = lhs .constants() - .iter() + .into_iter() .filter_map(|lhs_const| { // Find matching constant expression in RHS - rhs.constants() + rhs_constants .iter() - .find(|rhs_const| rhs_const.expr().eq(lhs_const.expr())) + .find(|rhs_const| rhs_const.expr.eq(&lhs_const.expr)) .map(|rhs_const| { - let mut const_expr = ConstExpr::new(Arc::clone(lhs_const.expr())); - - // If both sides have matching constant values, preserve the value and set across_partitions=true - if let ( - AcrossPartitions::Uniform(Some(lhs_val)), - AcrossPartitions::Uniform(Some(rhs_val)), - ) = (lhs_const.across_partitions(), rhs_const.across_partitions()) - { - if lhs_val == rhs_val { - const_expr = const_expr.with_across_partitions( - AcrossPartitions::Uniform(Some(lhs_val)), - ) - } + let mut const_expr = lhs_const.clone(); + // If both sides have matching constant values, preserve it. + // Otherwise, set fall back to heterogeneous values. + if lhs_const.across_partitions != rhs_const.across_partitions { + const_expr.across_partitions = AcrossPartitions::Heterogeneous; } const_expr }) @@ -79,14 +70,13 @@ fn calculate_union_binary( // Next, calculate valid orderings for the union by searching for prefixes // in both sides. let mut orderings = UnionEquivalentOrderingBuilder::new(); - orderings.add_satisfied_orderings(lhs.normalized_oeq_class(), lhs.constants(), &rhs); - orderings.add_satisfied_orderings(rhs.normalized_oeq_class(), rhs.constants(), &lhs); + orderings.add_satisfied_orderings(&lhs, &rhs)?; + orderings.add_satisfied_orderings(&rhs, &lhs)?; let orderings = orderings.build(); - let mut eq_properties = - EquivalenceProperties::new(lhs.schema).with_constants(constants); - - eq_properties.add_new_orderings(orderings); + let mut eq_properties = EquivalenceProperties::new(lhs.schema); + eq_properties.add_constants(constants)?; + eq_properties.add_orderings(orderings); Ok(eq_properties) } @@ -137,135 +127,139 @@ impl UnionEquivalentOrderingBuilder { Self { orderings: vec![] } } - /// Add all orderings from `orderings` that satisfy `properties`, - /// potentially augmented with`constants`. + /// Add all orderings from `source` that satisfy `properties`, + /// potentially augmented with the constants in `source`. /// - /// Note: any column that is known to be constant can be inserted into the - /// ordering without changing its meaning + /// Note: Any column that is known to be constant can be inserted into the + /// ordering without changing its meaning. /// /// For example: - /// * `orderings` contains `[a ASC, c ASC]` and `constants` contains `b` - /// * `properties` has required ordering `[a ASC, b ASC]` + /// * Orderings in `source` contains `[a ASC, c ASC]` and constants contains + /// `b`, + /// * `properties` has the ordering `[a ASC, b ASC]`. /// /// Then this will add `[a ASC, b ASC]` to the `orderings` list (as `a` was /// in the sort order and `b` was a constant). fn add_satisfied_orderings( &mut self, - orderings: impl IntoIterator, - constants: &[ConstExpr], + source: &EquivalenceProperties, properties: &EquivalenceProperties, - ) { - for mut ordering in orderings.into_iter() { + ) -> Result<()> { + let constants = source.constants(); + let properties_constants = properties.constants(); + for mut ordering in source.oeq_cache.normal_cls.clone() { // Progressively shorten the ordering to search for a satisfied prefix: loop { - match self.try_add_ordering(ordering, constants, properties) { + ordering = match self.try_add_ordering( + ordering, + &constants, + properties, + &properties_constants, + )? { AddedOrdering::Yes => break, - AddedOrdering::No(o) => { - ordering = o; - ordering.pop(); + AddedOrdering::No(ordering) => { + let mut sort_exprs: Vec<_> = ordering.into(); + sort_exprs.pop(); + if let Some(ordering) = LexOrdering::new(sort_exprs) { + ordering + } else { + break; + } } } } } + Ok(()) } - /// Adds `ordering`, potentially augmented with constants, if it satisfies - /// the target `properties` properties. + /// Adds `ordering`, potentially augmented with `constants`, if it satisfies + /// the given `properties`. /// - /// Returns + /// # Returns /// - /// * [`AddedOrdering::Yes`] if the ordering was added (either directly or - /// augmented), or was empty. - /// - /// * [`AddedOrdering::No`] if the ordering was not added + /// An [`AddedOrdering::Yes`] instance if the ordering was added (either + /// directly or augmented), or was empty. An [`AddedOrdering::No`] instance + /// otherwise. fn try_add_ordering( &mut self, ordering: LexOrdering, constants: &[ConstExpr], properties: &EquivalenceProperties, - ) -> AddedOrdering { - if ordering.is_empty() { - AddedOrdering::Yes - } else if properties.ordering_satisfy(ordering.as_ref()) { + properties_constants: &[ConstExpr], + ) -> Result { + if properties.ordering_satisfy(ordering.clone())? { // If the ordering satisfies the target properties, no need to // augment it with constants. self.orderings.push(ordering); - AddedOrdering::Yes + Ok(AddedOrdering::Yes) + } else if self.try_find_augmented_ordering( + &ordering, + constants, + properties, + properties_constants, + ) { + // Augmented with constants to match the properties. + Ok(AddedOrdering::Yes) } else { - // Did not satisfy target properties, try and augment with constants - // to match the properties - if self.try_find_augmented_ordering(&ordering, constants, properties) { - AddedOrdering::Yes - } else { - AddedOrdering::No(ordering) - } + Ok(AddedOrdering::No(ordering)) } } /// Attempts to add `constants` to `ordering` to satisfy the properties. - /// - /// returns true if any orderings were added, false otherwise + /// Returns `true` if augmentation took place, `false` otherwise. fn try_find_augmented_ordering( &mut self, ordering: &LexOrdering, constants: &[ConstExpr], properties: &EquivalenceProperties, + properties_constants: &[ConstExpr], ) -> bool { - // can't augment if there is nothing to augment with - if constants.is_empty() { - return false; - } - let start_num_orderings = self.orderings.len(); - - // for each equivalent ordering in properties, try and augment - // `ordering` it with the constants to match - for existing_ordering in properties.oeq_class.iter() { - if let Some(augmented_ordering) = self.augment_ordering( - ordering, - constants, - existing_ordering, - &properties.constants, - ) { - if !augmented_ordering.is_empty() { - assert!(properties.ordering_satisfy(augmented_ordering.as_ref())); + let mut result = false; + // Can only augment if there are constants. + if !constants.is_empty() { + // For each equivalent ordering in properties, try and augment + // `ordering` with the constants to match `existing_ordering`: + for existing_ordering in properties.oeq_class.iter() { + if let Some(augmented_ordering) = Self::augment_ordering( + ordering, + constants, + existing_ordering, + properties_constants, + ) { self.orderings.push(augmented_ordering); + result = true; } } } - - self.orderings.len() > start_num_orderings + result } - /// Attempts to augment the ordering with constants to match the - /// `existing_ordering` - /// - /// Returns Some(ordering) if an augmented ordering was found, None otherwise + /// Attempts to augment the ordering with constants to match `existing_ordering`. + /// Returns `Some(ordering)` if an augmented ordering was found, `None` otherwise. fn augment_ordering( - &mut self, ordering: &LexOrdering, constants: &[ConstExpr], existing_ordering: &LexOrdering, existing_constants: &[ConstExpr], ) -> Option { - let mut augmented_ordering = LexOrdering::default(); - let mut sort_expr_iter = ordering.iter().peekable(); - let mut existing_sort_expr_iter = existing_ordering.iter().peekable(); - - // walk in parallel down the two orderings, trying to match them up - while sort_expr_iter.peek().is_some() || existing_sort_expr_iter.peek().is_some() - { - // If the next expressions are equal, add the next match - // otherwise try and match with a constant + let mut augmented_ordering = vec![]; + let mut sort_exprs = ordering.iter().peekable(); + let mut existing_sort_exprs = existing_ordering.iter().peekable(); + + // Walk in parallel down the two orderings, trying to match them up: + while sort_exprs.peek().is_some() || existing_sort_exprs.peek().is_some() { + // If the next expressions are equal, add the next match. Otherwise, + // try and match with a constant. if let Some(expr) = - advance_if_match(&mut sort_expr_iter, &mut existing_sort_expr_iter) + advance_if_match(&mut sort_exprs, &mut existing_sort_exprs) { augmented_ordering.push(expr); } else if let Some(expr) = - advance_if_matches_constant(&mut sort_expr_iter, existing_constants) + advance_if_matches_constant(&mut sort_exprs, existing_constants) { augmented_ordering.push(expr); } else if let Some(expr) = - advance_if_matches_constant(&mut existing_sort_expr_iter, constants) + advance_if_matches_constant(&mut existing_sort_exprs, constants) { augmented_ordering.push(expr); } else { @@ -274,7 +268,7 @@ impl UnionEquivalentOrderingBuilder { } } - Some(augmented_ordering) + LexOrdering::new(augmented_ordering) } fn build(self) -> Vec { @@ -282,134 +276,135 @@ impl UnionEquivalentOrderingBuilder { } } -/// Advances two iterators in parallel -/// -/// If the next expressions are equal, the iterators are advanced and returns -/// the matched expression . -/// -/// Otherwise, the iterators are left unchanged and return `None` -fn advance_if_match( - iter1: &mut Peekable>, - iter2: &mut Peekable>, +/// Advances two iterators in parallel if the next expressions are equal. +/// Otherwise, the iterators are left unchanged and returns `None`. +fn advance_if_match<'a>( + iter1: &mut Peekable>, + iter2: &mut Peekable>, ) -> Option { - if matches!((iter1.peek(), iter2.peek()), (Some(expr1), Some(expr2)) if expr1.eq(expr2)) - { - iter1.next().unwrap(); + let (expr1, expr2) = (iter1.peek()?, iter2.peek()?); + if expr1.eq(expr2) { + iter1.next(); iter2.next().cloned() } else { None } } -/// Advances the iterator with a constant -/// -/// If the next expression matches one of the constants, advances the iterator -/// returning the matched expression -/// -/// Otherwise, the iterator is left unchanged and returns `None` -fn advance_if_matches_constant( - iter: &mut Peekable>, +/// Advances the iterator with a constant if the next expression matches one of +/// the constants. Otherwise, the iterator is left unchanged and returns `None`. +fn advance_if_matches_constant<'a>( + iter: &mut Peekable>, constants: &[ConstExpr], ) -> Option { let expr = iter.peek()?; - let const_expr = constants.iter().find(|c| c.eq_expr(expr))?; - let found_expr = PhysicalSortExpr::new(Arc::clone(const_expr.expr()), expr.options); + let const_expr = constants.iter().find(|c| expr.expr.eq(&c.expr))?; + let found_expr = PhysicalSortExpr::new(Arc::clone(&const_expr.expr), expr.options); iter.next(); Some(found_expr) } #[cfg(test)] mod tests { - use super::*; - use crate::equivalence::class::const_exprs_contains; use crate::equivalence::tests::{create_test_schema, parse_sort_expr}; use crate::expressions::col; + use crate::PhysicalExpr; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::ScalarValue; use itertools::Itertools; + /// Checks whether `expr` is among in the `const_exprs`. + fn const_exprs_contains( + const_exprs: &[ConstExpr], + expr: &Arc, + ) -> bool { + const_exprs + .iter() + .any(|const_expr| const_expr.expr.eq(expr)) + } + #[test] - fn test_union_equivalence_properties_multi_children_1() { + fn test_union_equivalence_properties_multi_children_1() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); let schema3 = append_fields(&schema, "2"); UnionEquivalenceTest::new(&schema) // Children 1 - .with_child_sort(vec![vec!["a", "b", "c"]], &schema) + .with_child_sort(vec![vec!["a", "b", "c"]], &schema)? // Children 2 - .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)? // Children 3 - .with_child_sort(vec![vec!["a2", "b2"]], &schema3) - .with_expected_sort(vec![vec!["a", "b"]]) + .with_child_sort(vec![vec!["a2", "b2"]], &schema3)? + .with_expected_sort(vec![vec!["a", "b"]])? .run() } #[test] - fn test_union_equivalence_properties_multi_children_2() { + fn test_union_equivalence_properties_multi_children_2() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); let schema3 = append_fields(&schema, "2"); UnionEquivalenceTest::new(&schema) // Children 1 - .with_child_sort(vec![vec!["a", "b", "c"]], &schema) + .with_child_sort(vec![vec!["a", "b", "c"]], &schema)? // Children 2 - .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)? // Children 3 - .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3) - .with_expected_sort(vec![vec!["a", "b", "c"]]) + .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3)? + .with_expected_sort(vec![vec!["a", "b", "c"]])? .run() } #[test] - fn test_union_equivalence_properties_multi_children_3() { + fn test_union_equivalence_properties_multi_children_3() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); let schema3 = append_fields(&schema, "2"); UnionEquivalenceTest::new(&schema) // Children 1 - .with_child_sort(vec![vec!["a", "b"]], &schema) + .with_child_sort(vec![vec!["a", "b"]], &schema)? // Children 2 - .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)? // Children 3 - .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3) - .with_expected_sort(vec![vec!["a", "b"]]) + .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3)? + .with_expected_sort(vec![vec!["a", "b"]])? .run() } #[test] - fn test_union_equivalence_properties_multi_children_4() { + fn test_union_equivalence_properties_multi_children_4() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); let schema3 = append_fields(&schema, "2"); UnionEquivalenceTest::new(&schema) // Children 1 - .with_child_sort(vec![vec!["a", "b"]], &schema) + .with_child_sort(vec![vec!["a", "b"]], &schema)? // Children 2 - .with_child_sort(vec![vec!["a1", "b1"]], &schema2) + .with_child_sort(vec![vec!["a1", "b1"]], &schema2)? // Children 3 - .with_child_sort(vec![vec!["b2", "c2"]], &schema3) - .with_expected_sort(vec![]) + .with_child_sort(vec![vec!["b2", "c2"]], &schema3)? + .with_expected_sort(vec![])? .run() } #[test] - fn test_union_equivalence_properties_multi_children_5() { + fn test_union_equivalence_properties_multi_children_5() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); UnionEquivalenceTest::new(&schema) // Children 1 - .with_child_sort(vec![vec!["a", "b"], vec!["c"]], &schema) + .with_child_sort(vec![vec!["a", "b"], vec!["c"]], &schema)? // Children 2 - .with_child_sort(vec![vec!["a1", "b1"], vec!["c1"]], &schema2) - .with_expected_sort(vec![vec!["a", "b"], vec!["c"]]) + .with_child_sort(vec![vec!["a1", "b1"], vec!["c1"]], &schema2)? + .with_expected_sort(vec![vec!["a", "b"], vec!["c"]])? .run() } #[test] - fn test_union_equivalence_properties_constants_common_constants() { + fn test_union_equivalence_properties_constants_common_constants() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -417,23 +412,23 @@ mod tests { vec![vec!["a"]], vec!["b", "c"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [b ASC], const [a, c] vec![vec!["b"]], vec!["a", "c"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union expected orderings: [[a ASC], [b ASC]], const [c] vec![vec!["a"], vec!["b"]], vec!["c"], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_prefix() { + fn test_union_equivalence_properties_constants_prefix() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -441,23 +436,23 @@ mod tests { vec![vec!["a"]], vec![], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [a ASC, b ASC], const [] vec![vec!["a", "b"]], vec![], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [a ASC], const [] vec![vec!["a"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_asc_desc_mismatch() { + fn test_union_equivalence_properties_constants_asc_desc_mismatch() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -465,23 +460,23 @@ mod tests { vec![vec!["a"]], vec![], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [a DESC], const [] vec![vec!["a DESC"]], vec![], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union doesn't have any ordering or constant vec![], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_different_schemas() { + fn test_union_equivalence_properties_constants_different_schemas() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); UnionEquivalenceTest::new(&schema) @@ -490,13 +485,13 @@ mod tests { vec![vec!["a"]], vec![], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [a1 ASC, b1 ASC], const [] vec![vec!["a1", "b1"]], vec![], &schema2, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [a ASC] // @@ -504,12 +499,12 @@ mod tests { // corresponding schemas. vec![vec!["a"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_fill_gaps() { + fn test_union_equivalence_properties_constants_fill_gaps() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -517,13 +512,13 @@ mod tests { vec![vec!["a", "c"]], vec!["b"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [b ASC, c ASC], const [a] vec![vec!["b", "c"]], vec!["a"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [ // [a ASC, b ASC, c ASC], @@ -531,12 +526,12 @@ mod tests { // ], const [] vec![vec!["a", "b", "c"], vec!["b", "a", "c"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_no_fill_gaps() { + fn test_union_equivalence_properties_constants_no_fill_gaps() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -544,23 +539,23 @@ mod tests { vec![vec!["a", "c"]], vec!["d"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [b ASC, c ASC], const [a] vec![vec!["b", "c"]], vec!["a"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [[a]] (only a is constant) vec![vec!["a"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_fill_some_gaps() { + fn test_union_equivalence_properties_constants_fill_some_gaps() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -568,23 +563,24 @@ mod tests { vec![vec!["c"]], vec!["a", "b"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [a DESC, b], const [] vec![vec!["a DESC", "b"]], vec![], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [[a, b]] (can fill in the a/b with constants) vec![vec!["a DESC", "b"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() { + fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() -> Result<()> + { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -592,13 +588,13 @@ mod tests { vec![vec!["a", "c"]], vec!["b"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [b ASC, c ASC], const [a] vec![vec!["b DESC", "c"]], vec!["a"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [ // [a ASC, b ASC, c ASC], @@ -606,12 +602,12 @@ mod tests { // ], const [] vec![vec!["a", "b DESC", "c"], vec!["b DESC", "a", "c"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_gap_fill_symmetric() { + fn test_union_equivalence_properties_constants_gap_fill_symmetric() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -619,25 +615,25 @@ mod tests { vec![vec!["a", "b", "d"]], vec!["c"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [a ASC, c ASC, d ASC], const [b] vec![vec!["a", "c", "d"]], vec!["b"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: // [a, b, c, d] // [a, c, b, d] vec![vec!["a", "c", "b", "d"], vec!["a", "b", "c", "d"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_gap_fill_and_common() { + fn test_union_equivalence_properties_constants_gap_fill_and_common() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -645,24 +641,24 @@ mod tests { vec![vec!["a DESC", "d"]], vec!["b", "c"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [a DESC, c ASC, d ASC], const [b] vec![vec!["a DESC", "c", "d"]], vec!["b"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: // [a DESC, c, d] [b] vec![vec!["a DESC", "c", "d"]], vec!["b"], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_middle_desc() { + fn test_union_equivalence_properties_constants_middle_desc() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -672,20 +668,20 @@ mod tests { vec![vec!["a", "b DESC", "d"]], vec!["c"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [a ASC, c ASC, d ASC], const [b] vec![vec!["a", "c", "d"]], vec!["b"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: // [a, b, d] (c constant) // [a, c, d] (b constant) vec![vec!["a", "c", "b DESC", "d"], vec!["a", "b DESC", "c", "d"]], vec![], - ) + )? .run() } @@ -718,10 +714,10 @@ mod tests { mut self, orderings: Vec>, schema: &SchemaRef, - ) -> Self { - let properties = self.make_props(orderings, vec![], schema); + ) -> Result { + let properties = self.make_props(orderings, vec![], schema)?; self.child_properties.push(properties); - self + Ok(self) } /// Add a union input with the specified orderings and constant @@ -734,19 +730,19 @@ mod tests { orderings: Vec>, constants: Vec<&str>, schema: &SchemaRef, - ) -> Self { - let properties = self.make_props(orderings, constants, schema); + ) -> Result { + let properties = self.make_props(orderings, constants, schema)?; self.child_properties.push(properties); - self + Ok(self) } /// Set the expected output sort order for the union of the children /// /// See [`Self::make_props`] for the format of the strings in `orderings` - fn with_expected_sort(mut self, orderings: Vec>) -> Self { - let properties = self.make_props(orderings, vec![], &self.output_schema); + fn with_expected_sort(mut self, orderings: Vec>) -> Result { + let properties = self.make_props(orderings, vec![], &self.output_schema)?; self.expected_properties = Some(properties); - self + Ok(self) } /// Set the expected output sort order and constant expressions for the @@ -758,15 +754,16 @@ mod tests { mut self, orderings: Vec>, constants: Vec<&str>, - ) -> Self { - let properties = self.make_props(orderings, constants, &self.output_schema); + ) -> Result { + let properties = + self.make_props(orderings, constants, &self.output_schema)?; self.expected_properties = Some(properties); - self + Ok(self) } /// compute the union's output equivalence properties from the child /// properties, and compare them to the expected properties - fn run(self) { + fn run(self) -> Result<()> { let Self { output_schema, child_properties, @@ -798,6 +795,7 @@ mod tests { ), ); } + Ok(()) } fn assert_eq_properties_same( @@ -808,9 +806,9 @@ mod tests { // Check whether constants are same let lhs_constants = lhs.constants(); let rhs_constants = rhs.constants(); - for rhs_constant in rhs_constants { + for rhs_constant in &rhs_constants { assert!( - const_exprs_contains(lhs_constants, rhs_constant.expr()), + const_exprs_contains(&lhs_constants, &rhs_constant.expr), "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" ); } @@ -845,24 +843,19 @@ mod tests { orderings: Vec>, constants: Vec<&str>, schema: &SchemaRef, - ) -> EquivalenceProperties { - let orderings = orderings - .iter() - .map(|ordering| { - ordering - .iter() - .map(|name| parse_sort_expr(name, schema)) - .collect::() - }) - .collect::>(); + ) -> Result { + let orderings = orderings.iter().map(|ordering| { + ordering.iter().map(|name| parse_sort_expr(name, schema)) + }); let constants = constants .iter() - .map(|col_name| ConstExpr::new(col(col_name, schema).unwrap())) - .collect::>(); + .map(|col_name| ConstExpr::from(col(col_name, schema).unwrap())); - EquivalenceProperties::new_with_orderings(Arc::clone(schema), &orderings) - .with_constants(constants) + let mut props = + EquivalenceProperties::new_with_orderings(Arc::clone(schema), orderings); + props.add_constants(constants)?; + Ok(props) } } @@ -877,25 +870,29 @@ mod tests { let literal_10 = ScalarValue::Int32(Some(10)); // Create first input with a=10 - let const_expr1 = ConstExpr::new(Arc::clone(&col_a)) - .with_across_partitions(AcrossPartitions::Uniform(Some(literal_10.clone()))); - let input1 = EquivalenceProperties::new(Arc::clone(&schema)) - .with_constants(vec![const_expr1]); + let const_expr1 = ConstExpr::new( + Arc::clone(&col_a), + AcrossPartitions::Uniform(Some(literal_10.clone())), + ); + let mut input1 = EquivalenceProperties::new(Arc::clone(&schema)); + input1.add_constants(vec![const_expr1])?; // Create second input with a=10 - let const_expr2 = ConstExpr::new(Arc::clone(&col_a)) - .with_across_partitions(AcrossPartitions::Uniform(Some(literal_10.clone()))); - let input2 = EquivalenceProperties::new(Arc::clone(&schema)) - .with_constants(vec![const_expr2]); + let const_expr2 = ConstExpr::new( + Arc::clone(&col_a), + AcrossPartitions::Uniform(Some(literal_10.clone())), + ); + let mut input2 = EquivalenceProperties::new(Arc::clone(&schema)); + input2.add_constants(vec![const_expr2])?; // Calculate union properties let union_props = calculate_union(vec![input1, input2], schema)?; // Verify column 'a' remains constant with value 10 let const_a = &union_props.constants()[0]; - assert!(const_a.expr().eq(&col_a)); + assert!(const_a.expr.eq(&col_a)); assert_eq!( - const_a.across_partitions(), + const_a.across_partitions, AcrossPartitions::Uniform(Some(literal_10)) ); @@ -924,4 +921,63 @@ mod tests { .collect::>(), )) } + + #[test] + fn test_constants_share_values() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("const_1", DataType::Utf8, false), + Field::new("const_2", DataType::Utf8, false), + ])); + + let col_const_1 = col("const_1", &schema)?; + let col_const_2 = col("const_2", &schema)?; + + let literal_foo = ScalarValue::Utf8(Some("foo".to_owned())); + let literal_bar = ScalarValue::Utf8(Some("bar".to_owned())); + + let const_expr_1_foo = ConstExpr::new( + Arc::clone(&col_const_1), + AcrossPartitions::Uniform(Some(literal_foo.clone())), + ); + let const_expr_2_foo = ConstExpr::new( + Arc::clone(&col_const_2), + AcrossPartitions::Uniform(Some(literal_foo.clone())), + ); + let const_expr_2_bar = ConstExpr::new( + Arc::clone(&col_const_2), + AcrossPartitions::Uniform(Some(literal_bar.clone())), + ); + + let mut input1 = EquivalenceProperties::new(Arc::clone(&schema)); + let mut input2 = EquivalenceProperties::new(Arc::clone(&schema)); + + // | Input | Const_1 | Const_2 | + // | ----- | ------- | ------- | + // | 1 | foo | foo | + // | 2 | foo | bar | + input1.add_constants(vec![const_expr_1_foo.clone(), const_expr_2_foo.clone()])?; + input2.add_constants(vec![const_expr_1_foo.clone(), const_expr_2_bar.clone()])?; + + // Calculate union properties + let union_props = calculate_union(vec![input1, input2], schema)?; + + // This should result in: + // const_1 = Uniform("foo") + // const_2 = Heterogeneous + assert_eq!(union_props.constants().len(), 2); + let union_const_1 = &union_props.constants()[0]; + assert!(union_const_1.expr.eq(&col_const_1)); + assert_eq!( + union_const_1.across_partitions, + AcrossPartitions::Uniform(Some(literal_foo)), + ); + let union_const_2 = &union_props.constants()[1]; + assert!(union_const_2.expr.eq(&col_const_2)); + assert_eq!( + union_const_2.across_partitions, + AcrossPartitions::Heterogeneous, + ); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index f21d3e7652cdc..ce3d4ced4e3a2 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -17,19 +17,18 @@ mod kernels; -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - -use crate::expressions::binary::kernels::concat_elements_utf8view; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::PhysicalExpr; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; use arrow::array::*; -use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; +use arrow::compute::kernels::boolean::{and_kleene, or_kleene}; use arrow::compute::kernels::cmp::*; -use arrow::compute::kernels::comparison::{regexp_is_match, regexp_is_match_scalar}; use arrow::compute::kernels::concat_elements::concat_elements_utf8; -use arrow::compute::{cast, ilike, like, nilike, nlike}; +use arrow::compute::{ + cast, filter_record_batch, ilike, like, nilike, nlike, SlicesIterator, +}; use arrow::datatypes::*; use arrow::error::ArrowError; use datafusion_common::cast::as_boolean_array; @@ -49,6 +48,7 @@ use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn, bitwise_shift_right_dyn_scalar, bitwise_xor_dyn, bitwise_xor_dyn_scalar, + concat_elements_utf8view, regex_match_dyn, regex_match_dyn_scalar, }; /// Binary expression @@ -159,183 +159,12 @@ fn boolean_op( left: &dyn Array, right: &dyn Array, op: impl FnOnce(&BooleanArray, &BooleanArray) -> Result, -) -> Result, ArrowError> { +) -> Result, ArrowError> { let ll = as_boolean_array(left).expect("boolean_op failed to downcast left array"); let rr = as_boolean_array(right).expect("boolean_op failed to downcast right array"); op(ll, rr).map(|t| Arc::new(t) as _) } -macro_rules! binary_string_array_flag_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{ - match $LEFT.data_type() { - DataType::Utf8 => { - compute_utf8_flag_op!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG) - }, - DataType::Utf8View => { - compute_utf8view_flag_op!($LEFT, $RIGHT, $OP, StringViewArray, $NOT, $FLAG) - } - DataType::LargeUtf8 => { - compute_utf8_flag_op!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG) - }, - other => internal_err!( - "Data type {:?} not supported for binary_string_array_flag_op operation '{}' on string array", - other, stringify!($OP) - ), - } - }}; -} - -/// Invoke a compute kernel on a pair of binary data arrays with flags -macro_rules! compute_utf8_flag_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$ARRAYTYPE>() - .expect("compute_utf8_flag_op failed to downcast array"); - let rr = $RIGHT - .as_any() - .downcast_ref::<$ARRAYTYPE>() - .expect("compute_utf8_flag_op failed to downcast array"); - - let flag = if $FLAG { - Some($ARRAYTYPE::from(vec!["i"; ll.len()])) - } else { - None - }; - let mut array = $OP(ll, rr, flag.as_ref())?; - if $NOT { - array = not(&array).unwrap(); - } - Ok(Arc::new(array)) - }}; -} - -/// Invoke a compute kernel on a pair of binary data arrays with flags -macro_rules! compute_utf8view_flag_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$ARRAYTYPE>() - .expect("compute_utf8view_flag_op failed to downcast array"); - let rr = $RIGHT - .as_any() - .downcast_ref::<$ARRAYTYPE>() - .expect("compute_utf8view_flag_op failed to downcast array"); - - let flag = if $FLAG { - Some($ARRAYTYPE::from(vec!["i"; ll.len()])) - } else { - None - }; - let mut array = $OP(ll, rr, flag.as_ref())?; - if $NOT { - array = not(&array).unwrap(); - } - Ok(Arc::new(array)) - }}; -} - -macro_rules! binary_string_array_flag_op_scalar { - ($LEFT:ident, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{ - // This macro is slightly different from binary_string_array_flag_op because, when comparing with a scalar value, - // the query can be optimized in such a way that operands will be dicts, so we need to support it here - let result: Result> = match $LEFT.data_type() { - DataType::Utf8 => { - compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG) - }, - DataType::Utf8View => { - compute_utf8view_flag_op_scalar!($LEFT, $RIGHT, $OP, StringViewArray, $NOT, $FLAG) - } - DataType::LargeUtf8 => { - compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG) - }, - DataType::Dictionary(_, _) => { - let values = $LEFT.as_any_dictionary().values(); - - match values.data_type() { - DataType::Utf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, StringArray, $NOT, $FLAG), - DataType::Utf8View => compute_utf8view_flag_op_scalar!(values, $RIGHT, $OP, StringViewArray, $NOT, $FLAG), - DataType::LargeUtf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG), - other => internal_err!( - "Data type {:?} not supported as a dictionary value type for binary_string_array_flag_op_scalar operation '{}' on string array", - other, stringify!($OP) - ), - }.map( - // downcast_dictionary_array duplicates code per possible key type, so we aim to do all prep work before - |evaluated_values| downcast_dictionary_array! { - $LEFT => { - let unpacked_dict = evaluated_values.take_iter($LEFT.keys().iter().map(|opt| opt.map(|v| v as _))).collect::(); - Arc::new(unpacked_dict) as _ - }, - _ => unreachable!(), - } - ) - }, - other => internal_err!( - "Data type {:?} not supported for binary_string_array_flag_op_scalar operation '{}' on string array", - other, stringify!($OP) - ), - }; - Some(result) - }}; -} - -/// Invoke a compute kernel on a data array and a scalar value with flag -macro_rules! compute_utf8_flag_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$ARRAYTYPE>() - .expect("compute_utf8_flag_op_scalar failed to downcast array"); - - let string_value = match $RIGHT.try_as_str() { - Some(Some(string_value)) => string_value, - // null literal or non string - _ => return internal_err!( - "compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'", - $RIGHT, stringify!($OP) - ) - }; - - let flag = $FLAG.then_some("i"); - let mut array = - paste::expr! {[<$OP _scalar>]}(ll, &string_value, flag)?; - if $NOT { - array = not(&array).unwrap(); - } - - Ok(Arc::new(array)) - }}; -} - -/// Invoke a compute kernel on a data array and a scalar value with flag -macro_rules! compute_utf8view_flag_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$ARRAYTYPE>() - .expect("compute_utf8view_flag_op_scalar failed to downcast array"); - - let string_value = match $RIGHT.try_as_str() { - Some(Some(string_value)) => string_value, - // null literal or non string - _ => return internal_err!( - "compute_utf8view_flag_op_scalar failed to cast literal value {} for operation '{}'", - $RIGHT, stringify!($OP) - ) - }; - - let flag = $FLAG.then_some("i"); - let mut array = - paste::expr! {[<$OP _scalar>]}(ll, &string_value, flag)?; - if $NOT { - array = not(&array).unwrap(); - } - - Ok(Arc::new(array)) - }}; -} - impl PhysicalExpr for BinaryExpr { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -358,7 +187,63 @@ impl PhysicalExpr for BinaryExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { use arrow::compute::kernels::numeric::*; + // Evaluate left-hand side expression. let lhs = self.left.evaluate(batch)?; + + // Check if we can apply short-circuit evaluation. + match check_short_circuit(&lhs, &self.op) { + ShortCircuitStrategy::None => {} + ShortCircuitStrategy::ReturnLeft => return Ok(lhs), + ShortCircuitStrategy::ReturnRight => { + let rhs = self.right.evaluate(batch)?; + return Ok(rhs); + } + ShortCircuitStrategy::PreSelection(selection) => { + // The function `evaluate_selection` was not called for filtering and calculation, + // as it takes into account cases where the selection contains null values. + let batch = filter_record_batch(batch, selection)?; + let right_ret = self.right.evaluate(&batch)?; + + match &right_ret { + ColumnarValue::Array(array) => { + // When the array on the right is all true or all false, skip the scatter process + let boolean_array = array.as_boolean(); + let true_count = boolean_array.true_count(); + let length = boolean_array.len(); + if true_count == length { + return Ok(lhs); + } else if true_count == 0 && boolean_array.null_count() == 0 { + // If the right-hand array is returned at this point,the lengths will be inconsistent; + // returning a scalar can avoid this issue + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean( + Some(false), + ))); + } + + return pre_selection_scatter(selection, Some(boolean_array)); + } + ColumnarValue::Scalar(scalar) => { + if let ScalarValue::Boolean(v) = scalar { + // When the scalar is true or false, skip the scatter process + if let Some(v) = v { + if *v { + return Ok(lhs); + } else { + return Ok(right_ret); + } + } else { + return pre_selection_scatter(selection, None); + } + } else { + return internal_err!( + "Expected boolean scalar value, found: {right_ret:?}" + ); + } + } + } + } + } + let rhs = self.right.evaluate(batch)?; let left_data_type = lhs.data_type(); let right_data_type = rhs.data_type(); @@ -367,8 +252,8 @@ impl PhysicalExpr for BinaryExpr { let input_schema = schema.as_ref(); if left_data_type.is_nested() { - if right_data_type != left_data_type { - return internal_err!("type mismatch"); + if !left_data_type.equals_datatype(&right_data_type) { + return internal_err!("Cannot evaluate binary expression because of type mismatch: left {}, right {} ", left_data_type, right_data_type); } return apply_cmp_for_nested(self.op, &lhs, &rhs); } @@ -399,23 +284,19 @@ impl PhysicalExpr for BinaryExpr { let result_type = self.data_type(input_schema)?; - // Attempt to use special kernels if one input is scalar and the other is an array - let scalar_result = match (&lhs, &rhs) { - (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => { - // if left is array and right is literal(not NULL) - use scalar operations - if scalar.is_null() { - None - } else { - self.evaluate_array_scalar(array, scalar.clone())?.map(|r| { - r.and_then(|a| to_result_type_array(&self.op, a, &result_type)) - }) + // If the left-hand side is an array and the right-hand side is a non-null scalar, try the optimized kernel. + if let (ColumnarValue::Array(array), ColumnarValue::Scalar(ref scalar)) = + (&lhs, &rhs) + { + if !scalar.is_null() { + if let Some(result_array) = + self.evaluate_array_scalar(array, scalar.clone())? + { + let final_array = result_array + .and_then(|a| to_result_type_array(&self.op, a, &result_type)); + return final_array.map(ColumnarValue::Array); } } - (_, _) => None, // default to array implementation - }; - - if let Some(result) = scalar_result { - return result.map(ColumnarValue::Array); } // if both arrays or both literals - extract arrays and continue execution @@ -500,7 +381,7 @@ impl PhysicalExpr for BinaryExpr { } } else if self.op.eq(&Operator::Or) { if interval.eq(&Interval::CERTAINLY_FALSE) { - // A certainly false logical conjunction can only derive from certainly + // A certainly false logical disjunction can only derive from certainly // false operands. Otherwise, we prove infeasibility. Ok((!left_interval.eq(&Interval::CERTAINLY_TRUE) && !right_interval.eq(&Interval::CERTAINLY_TRUE)) @@ -678,7 +559,7 @@ fn to_result_type_array( Ok(cast(&array, result_type)?) } else { internal_err!( - "Incompatible Dictionary value type {value_type:?} with result type {result_type:?} of Binary operator {op:?}" + "Incompatible Dictionary value type {value_type} with result type {result_type} of Binary operator {op:?}" ) } } @@ -699,34 +580,10 @@ impl BinaryExpr { ) -> Result>> { use Operator::*; let scalar_result = match &self.op { - RegexMatch => binary_string_array_flag_op_scalar!( - array, - scalar, - regexp_is_match, - false, - false - ), - RegexIMatch => binary_string_array_flag_op_scalar!( - array, - scalar, - regexp_is_match, - false, - true - ), - RegexNotMatch => binary_string_array_flag_op_scalar!( - array, - scalar, - regexp_is_match, - true, - false - ), - RegexNotIMatch => binary_string_array_flag_op_scalar!( - array, - scalar, - regexp_is_match, - true, - true - ), + RegexMatch => regex_match_dyn_scalar(array, scalar, false, false), + RegexIMatch => regex_match_dyn_scalar(array, scalar, false, true), + RegexNotMatch => regex_match_dyn_scalar(array, scalar, true, false), + RegexNotIMatch => regex_match_dyn_scalar(array, scalar, true, true), BitwiseAnd => bitwise_and_dyn_scalar(array, scalar), BitwiseOr => bitwise_or_dyn_scalar(array, scalar), BitwiseXor => bitwise_xor_dyn_scalar(array, scalar), @@ -775,18 +632,10 @@ impl BinaryExpr { ) } } - RegexMatch => { - binary_string_array_flag_op!(left, right, regexp_is_match, false, false) - } - RegexIMatch => { - binary_string_array_flag_op!(left, right, regexp_is_match, false, true) - } - RegexNotMatch => { - binary_string_array_flag_op!(left, right, regexp_is_match, true, false) - } - RegexNotIMatch => { - binary_string_array_flag_op!(left, right, regexp_is_match, true, true) - } + RegexMatch => regex_match_dyn(left, right, false, false), + RegexIMatch => regex_match_dyn(left, right, false, true), + RegexNotMatch => regex_match_dyn(left, right, true, false), + RegexNotIMatch => regex_match_dyn(left, right, true, true), BitwiseAnd => bitwise_and_dyn(left, right), BitwiseOr => bitwise_or_dyn(left, right), BitwiseXor => bitwise_xor_dyn(left, right), @@ -805,6 +654,213 @@ impl BinaryExpr { } } +enum ShortCircuitStrategy<'a> { + None, + ReturnLeft, + ReturnRight, + PreSelection(&'a BooleanArray), +} + +/// Based on the results calculated from the left side of the short-circuit operation, +/// if the proportion of `true` is less than 0.2 and the current operation is an `and`, +/// the `RecordBatch` will be filtered in advance. +const PRE_SELECTION_THRESHOLD: f32 = 0.2; + +/// Checks if a logical operator (`AND`/`OR`) can short-circuit evaluation based on the left-hand side (lhs) result. +/// +/// Short-circuiting occurs under these circumstances: +/// - For `AND`: +/// - if LHS is all false => short-circuit → return LHS +/// - if LHS is all true => short-circuit → return RHS +/// - if LHS is mixed and true_count/sum_count <= [`PRE_SELECTION_THRESHOLD`] -> pre-selection +/// - For `OR`: +/// - if LHS is all true => short-circuit → return LHS +/// - if LHS is all false => short-circuit → return RHS +/// # Arguments +/// * `lhs` - The left-hand side (lhs) columnar value (array or scalar) +/// * `lhs` - The left-hand side (lhs) columnar value (array or scalar) +/// * `op` - The logical operator (`AND` or `OR`) +/// +/// # Implementation Notes +/// 1. Only works with Boolean-typed arguments (other types automatically return `false`) +/// 2. Handles both scalar values and array values +/// 3. For arrays, uses optimized bit counting techniques for boolean arrays +fn check_short_circuit<'a>( + lhs: &'a ColumnarValue, + op: &Operator, +) -> ShortCircuitStrategy<'a> { + // Quick reject for non-logical operators,and quick judgment when op is and + let is_and = match op { + Operator::And => true, + Operator::Or => false, + _ => return ShortCircuitStrategy::None, + }; + + // Non-boolean types can't be short-circuited + if lhs.data_type() != DataType::Boolean { + return ShortCircuitStrategy::None; + } + + match lhs { + ColumnarValue::Array(array) => { + // Fast path for arrays - try to downcast to boolean array + if let Ok(bool_array) = as_boolean_array(array) { + // Arrays with nulls can't be short-circuited + if bool_array.null_count() > 0 { + return ShortCircuitStrategy::None; + } + + let len = bool_array.len(); + if len == 0 { + return ShortCircuitStrategy::None; + } + + let true_count = bool_array.values().count_set_bits(); + if is_and { + // For AND, prioritize checking for all-false (short circuit case) + // Uses optimized false_count() method provided by Arrow + + // Short circuit if all values are false + if true_count == 0 { + return ShortCircuitStrategy::ReturnLeft; + } + + // If no false values, then all must be true + if true_count == len { + return ShortCircuitStrategy::ReturnRight; + } + + // determine if we can pre-selection + if true_count as f32 / len as f32 <= PRE_SELECTION_THRESHOLD { + return ShortCircuitStrategy::PreSelection(bool_array); + } + } else { + // For OR, prioritize checking for all-true (short circuit case) + // Uses optimized true_count() method provided by Arrow + + // Short circuit if all values are true + if true_count == len { + return ShortCircuitStrategy::ReturnLeft; + } + + // If no true values, then all must be false + if true_count == 0 { + return ShortCircuitStrategy::ReturnRight; + } + } + } + } + ColumnarValue::Scalar(scalar) => { + // Fast path for scalar values + if let ScalarValue::Boolean(Some(is_true)) = scalar { + // Return Left for: + // - AND with false value + // - OR with true value + if (is_and && !is_true) || (!is_and && *is_true) { + return ShortCircuitStrategy::ReturnLeft; + } else { + return ShortCircuitStrategy::ReturnRight; + } + } + } + } + + // If we can't short-circuit, indicate that normal evaluation should continue + ShortCircuitStrategy::None +} + +/// Creates a new boolean array based on the evaluation of the right expression, +/// but only for positions where the left_result is true. +/// +/// This function is used for short-circuit evaluation optimization of logical AND operations: +/// - When left_result has few true values, we only evaluate the right expression for those positions +/// - Values are copied from right_array where left_result is true +/// - All other positions are filled with false values +/// +/// # Parameters +/// - `left_result` Boolean array with selection mask (typically from left side of AND) +/// - `right_result` Result of evaluating right side of expression (only for selected positions) +/// +/// # Returns +/// A combined ColumnarValue with values from right_result where left_result is true +/// +/// # Example +/// Initial Data: { 1, 2, 3, 4, 5 } +/// Left Evaluation +/// (Condition: Equal to 2 or 3) +/// ↓ +/// Filtered Data: {2, 3} +/// Left Bitmap: { 0, 1, 1, 0, 0 } +/// ↓ +/// Right Evaluation +/// (Condition: Even numbers) +/// ↓ +/// Right Data: { 2 } +/// Right Bitmap: { 1, 0 } +/// ↓ +/// Combine Results +/// Final Bitmap: { 0, 1, 0, 0, 0 } +/// +/// # Note +/// Perhaps it would be better to modify `left_result` directly without creating a copy? +/// In practice, `left_result` should have only one owner, so making changes should be safe. +/// However, this is difficult to achieve under the immutable constraints of [`Arc`] and [`BooleanArray`]. +fn pre_selection_scatter( + left_result: &BooleanArray, + right_result: Option<&BooleanArray>, +) -> Result { + let result_len = left_result.len(); + + let mut result_array_builder = BooleanArray::builder(result_len); + + // keep track of current position we have in right boolean array + let mut right_array_pos = 0; + + // keep track of how much is filled + let mut last_end = 0; + // reduce if condition in for_each + match right_result { + Some(right_result) => { + SlicesIterator::new(left_result).for_each(|(start, end)| { + // the gap needs to be filled with false + if start > last_end { + result_array_builder.append_n(start - last_end, false); + } + + // copy values from right array for this slice + let len = end - start; + right_result + .slice(right_array_pos, len) + .iter() + .for_each(|v| result_array_builder.append_option(v)); + + right_array_pos += len; + last_end = end; + }); + } + None => SlicesIterator::new(left_result).for_each(|(start, end)| { + // the gap needs to be filled with false + if start > last_end { + result_array_builder.append_n(start - last_end, false); + } + + // append nulls for this slice derictly + let len = end - start; + result_array_builder.append_nulls(len); + + last_end = end; + }), + } + + // Fill any remaining positions with false + if last_end < result_len { + result_array_builder.append_n(result_len - last_end, false); + } + let boolean_result = result_array_builder.finish(); + + Ok(ColumnarValue::Array(Arc::new(boolean_result))) +} + fn concat_elements(left: Arc, right: Arc) -> Result { Ok(match left.data_type() { DataType::Utf8 => Arc::new(concat_elements_utf8( @@ -859,10 +915,14 @@ pub fn similar_to( mod tests { use super::*; use crate::expressions::{col, lit, try_cast, Column, Literal}; + use datafusion_expr::lit as expr_lit; use datafusion_common::plan_datafusion_err; use datafusion_physical_expr_common::physical_expr::fmt_sql; + use crate::planner::logical2physical; + use arrow::array::BooleanArray; + use datafusion_expr::col as logical_col; /// Performs a binary operation, applying any type coercion necessary fn binary_op( left: Arc, @@ -4832,4 +4892,425 @@ mod tests { Ok(()) } + + #[test] + fn test_check_short_circuit() { + // Test with non-nullable arrays + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let a_array = Int32Array::from(vec![1, 3, 4, 5, 6]); + let b_array = Int32Array::from(vec![1, 2, 3, 4, 5]); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(a_array), Arc::new(b_array)], + ) + .unwrap(); + + // op: AND left: all false + let left_expr = logical2physical(&logical_col("a").eq(expr_lit(2)), &schema); + let left_value = left_expr.evaluate(&batch).unwrap(); + assert!(matches!( + check_short_circuit(&left_value, &Operator::And), + ShortCircuitStrategy::ReturnLeft + )); + + // op: AND left: not all false + let left_expr = logical2physical(&logical_col("a").eq(expr_lit(3)), &schema); + let left_value = left_expr.evaluate(&batch).unwrap(); + let ColumnarValue::Array(array) = &left_value else { + panic!("Expected ColumnarValue::Array"); + }; + let ShortCircuitStrategy::PreSelection(value) = + check_short_circuit(&left_value, &Operator::And) + else { + panic!("Expected ShortCircuitStrategy::PreSelection"); + }; + let expected_boolean_arr: Vec<_> = + as_boolean_array(array).unwrap().iter().collect(); + let boolean_arr: Vec<_> = value.iter().collect(); + assert_eq!(expected_boolean_arr, boolean_arr); + + // op: OR left: all true + let left_expr = logical2physical(&logical_col("a").gt(expr_lit(0)), &schema); + let left_value = left_expr.evaluate(&batch).unwrap(); + assert!(matches!( + check_short_circuit(&left_value, &Operator::Or), + ShortCircuitStrategy::ReturnLeft + )); + + // op: OR left: not all true + let left_expr: Arc = + logical2physical(&logical_col("a").gt(expr_lit(2)), &schema); + let left_value = left_expr.evaluate(&batch).unwrap(); + assert!(matches!( + check_short_circuit(&left_value, &Operator::Or), + ShortCircuitStrategy::None + )); + + // Test with nullable arrays and null values + let schema_nullable = Arc::new(Schema::new(vec![ + Field::new("c", DataType::Boolean, true), + Field::new("d", DataType::Boolean, true), + ])); + + // Create arrays with null values + let c_array = Arc::new(BooleanArray::from(vec![ + Some(true), + Some(false), + None, + Some(true), + None, + ])) as ArrayRef; + let d_array = Arc::new(BooleanArray::from(vec![ + Some(false), + Some(true), + Some(false), + None, + Some(true), + ])) as ArrayRef; + + let batch_nullable = RecordBatch::try_new( + Arc::clone(&schema_nullable), + vec![Arc::clone(&c_array), Arc::clone(&d_array)], + ) + .unwrap(); + + // Case: Mixed values with nulls - shouldn't short-circuit for AND + let mixed_nulls = logical2physical(&logical_col("c"), &schema_nullable); + let mixed_nulls_value = mixed_nulls.evaluate(&batch_nullable).unwrap(); + assert!(matches!( + check_short_circuit(&mixed_nulls_value, &Operator::And), + ShortCircuitStrategy::None + )); + + // Case: Mixed values with nulls - shouldn't short-circuit for OR + assert!(matches!( + check_short_circuit(&mixed_nulls_value, &Operator::Or), + ShortCircuitStrategy::None + )); + + // Test with all nulls + let all_nulls = Arc::new(BooleanArray::from(vec![None, None, None])) as ArrayRef; + let null_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("e", DataType::Boolean, true)])), + vec![all_nulls], + ) + .unwrap(); + + let null_expr = logical2physical(&logical_col("e"), &null_batch.schema()); + let null_value = null_expr.evaluate(&null_batch).unwrap(); + + // All nulls shouldn't short-circuit for AND or OR + assert!(matches!( + check_short_circuit(&null_value, &Operator::And), + ShortCircuitStrategy::None + )); + assert!(matches!( + check_short_circuit(&null_value, &Operator::Or), + ShortCircuitStrategy::None + )); + + // Test with scalar values + // Scalar true + let scalar_true = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); + assert!(matches!( + check_short_circuit(&scalar_true, &Operator::Or), + ShortCircuitStrategy::ReturnLeft + )); // Should short-circuit OR + assert!(matches!( + check_short_circuit(&scalar_true, &Operator::And), + ShortCircuitStrategy::ReturnRight + )); // Should return the RHS for AND + + // Scalar false + let scalar_false = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); + assert!(matches!( + check_short_circuit(&scalar_false, &Operator::And), + ShortCircuitStrategy::ReturnLeft + )); // Should short-circuit AND + assert!(matches!( + check_short_circuit(&scalar_false, &Operator::Or), + ShortCircuitStrategy::ReturnRight + )); // Should return the RHS for OR + + // Scalar null + let scalar_null = ColumnarValue::Scalar(ScalarValue::Boolean(None)); + assert!(matches!( + check_short_circuit(&scalar_null, &Operator::And), + ShortCircuitStrategy::None + )); + assert!(matches!( + check_short_circuit(&scalar_null, &Operator::Or), + ShortCircuitStrategy::None + )); + } + + /// Test for [pre_selection_scatter] + /// Since [check_short_circuit] ensures that the left side does not contain null and is neither all_true nor all_false, as well as not being empty, + /// the following tests have been designed: + /// 1. Test sparse left with interleaved true/false + /// 2. Test multiple consecutive true blocks + /// 3. Test multiple consecutive true blocks + /// 4. Test single true at first position + /// 5. Test single true at last position + /// 6. Test nulls in right array + #[test] + fn test_pre_selection_scatter() { + fn create_bool_array(bools: Vec) -> BooleanArray { + BooleanArray::from(bools.into_iter().map(Some).collect::>()) + } + // Test sparse left with interleaved true/false + { + // Left: [T, F, T, F, T] + // Right: [F, T, F] (values for 3 true positions) + let left = create_bool_array(vec![true, false, true, false, true]); + let right = create_bool_array(vec![false, true, false]); + + let result = pre_selection_scatter(&left, Some(&right)).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = create_bool_array(vec![false, false, true, false, false]); + assert_eq!(&expected, result_arr.as_boolean()); + } + // Test multiple consecutive true blocks + { + // Left: [F, T, T, F, T, T, T] + // Right: [T, F, F, T, F] + let left = + create_bool_array(vec![false, true, true, false, true, true, true]); + let right = create_bool_array(vec![true, false, false, true, false]); + + let result = pre_selection_scatter(&left, Some(&right)).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = + create_bool_array(vec![false, true, false, false, false, true, false]); + assert_eq!(&expected, result_arr.as_boolean()); + } + // Test single true at first position + { + // Left: [T, F, F] + // Right: [F] + let left = create_bool_array(vec![true, false, false]); + let right = create_bool_array(vec![false]); + + let result = pre_selection_scatter(&left, Some(&right)).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = create_bool_array(vec![false, false, false]); + assert_eq!(&expected, result_arr.as_boolean()); + } + // Test single true at last position + { + // Left: [F, F, T] + // Right: [F] + let left = create_bool_array(vec![false, false, true]); + let right = create_bool_array(vec![false]); + + let result = pre_selection_scatter(&left, Some(&right)).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = create_bool_array(vec![false, false, false]); + assert_eq!(&expected, result_arr.as_boolean()); + } + // Test nulls in right array + { + // Left: [F, T, F, T] + // Right: [None, Some(false)] (with null at first position) + let left = create_bool_array(vec![false, true, false, true]); + let right = BooleanArray::from(vec![None, Some(false)]); + + let result = pre_selection_scatter(&left, Some(&right)).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = BooleanArray::from(vec![ + Some(false), + None, // null from right + Some(false), + Some(false), + ]); + assert_eq!(&expected, result_arr.as_boolean()); + } + } + + #[test] + fn test_and_true_preselection_returns_lhs() { + let schema = + Arc::new(Schema::new(vec![Field::new("c", DataType::Boolean, false)])); + let c_array = Arc::new(BooleanArray::from(vec![false, true, false, false, false])) + as ArrayRef; + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::clone(&c_array)]) + .unwrap(); + + let expr = logical2physical(&logical_col("c").and(expr_lit(true)), &schema); + + let result = expr.evaluate(&batch).unwrap(); + let ColumnarValue::Array(result_arr) = result else { + panic!("Expected ColumnarValue::Array"); + }; + + let expected: Vec<_> = c_array.as_boolean().iter().collect(); + let actual: Vec<_> = result_arr.as_boolean().iter().collect(); + assert_eq!( + expected, actual, + "AND with TRUE must equal LHS even with PreSelection" + ); + } + + #[test] + fn test_evaluate_bounds_int32() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let a = Arc::new(Column::new("a", 0)) as _; + let b = Arc::new(Column::new("b", 1)) as _; + + // Test addition bounds + let add_expr = + binary_expr(Arc::clone(&a), Operator::Plus, Arc::clone(&b), &schema).unwrap(); + let add_bounds = add_expr + .evaluate_bounds(&[ + &Interval::make(Some(1), Some(10)).unwrap(), + &Interval::make(Some(5), Some(15)).unwrap(), + ]) + .unwrap(); + assert_eq!(add_bounds, Interval::make(Some(6), Some(25)).unwrap()); + + // Test subtraction bounds + let sub_expr = + binary_expr(Arc::clone(&a), Operator::Minus, Arc::clone(&b), &schema) + .unwrap(); + let sub_bounds = sub_expr + .evaluate_bounds(&[ + &Interval::make(Some(1), Some(10)).unwrap(), + &Interval::make(Some(5), Some(15)).unwrap(), + ]) + .unwrap(); + assert_eq!(sub_bounds, Interval::make(Some(-14), Some(5)).unwrap()); + + // Test multiplication bounds + let mul_expr = + binary_expr(Arc::clone(&a), Operator::Multiply, Arc::clone(&b), &schema) + .unwrap(); + let mul_bounds = mul_expr + .evaluate_bounds(&[ + &Interval::make(Some(1), Some(10)).unwrap(), + &Interval::make(Some(5), Some(15)).unwrap(), + ]) + .unwrap(); + assert_eq!(mul_bounds, Interval::make(Some(5), Some(150)).unwrap()); + + // Test division bounds + let div_expr = + binary_expr(Arc::clone(&a), Operator::Divide, Arc::clone(&b), &schema) + .unwrap(); + let div_bounds = div_expr + .evaluate_bounds(&[ + &Interval::make(Some(10), Some(20)).unwrap(), + &Interval::make(Some(2), Some(5)).unwrap(), + ]) + .unwrap(); + assert_eq!(div_bounds, Interval::make(Some(2), Some(10)).unwrap()); + } + + #[test] + fn test_evaluate_bounds_bool() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Boolean, false), + ]); + + let a = Arc::new(Column::new("a", 0)) as _; + let b = Arc::new(Column::new("b", 1)) as _; + + // Test OR bounds + let or_expr = + binary_expr(Arc::clone(&a), Operator::Or, Arc::clone(&b), &schema).unwrap(); + let or_bounds = or_expr + .evaluate_bounds(&[ + &Interval::make(Some(true), Some(true)).unwrap(), + &Interval::make(Some(false), Some(false)).unwrap(), + ]) + .unwrap(); + assert_eq!(or_bounds, Interval::make(Some(true), Some(true)).unwrap()); + + // Test AND bounds + let and_expr = + binary_expr(Arc::clone(&a), Operator::And, Arc::clone(&b), &schema).unwrap(); + let and_bounds = and_expr + .evaluate_bounds(&[ + &Interval::make(Some(true), Some(true)).unwrap(), + &Interval::make(Some(false), Some(false)).unwrap(), + ]) + .unwrap(); + assert_eq!( + and_bounds, + Interval::make(Some(false), Some(false)).unwrap() + ); + } + + #[test] + fn test_evaluate_nested_type() { + let batch_schema = Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), + true, + ), + Field::new( + "b", + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), + true, + ), + ])); + + let mut list_builder_a = ListBuilder::new(Int32Builder::new()); + + list_builder_a.append_value([Some(1)]); + list_builder_a.append_value([Some(2)]); + list_builder_a.append_value([]); + list_builder_a.append_value([None]); + + let list_array_a: ArrayRef = Arc::new(list_builder_a.finish()); + + let mut list_builder_b = ListBuilder::new(Int32Builder::new()); + + list_builder_b.append_value([Some(1)]); + list_builder_b.append_value([Some(2)]); + list_builder_b.append_value([]); + list_builder_b.append_value([None]); + + let list_array_b: ArrayRef = Arc::new(list_builder_b.finish()); + + let batch = + RecordBatch::try_new(batch_schema, vec![list_array_a, list_array_b]).unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::List(Arc::new(Field::new("foo", DataType::Int32, true))), + true, + ), + Field::new( + "b", + DataType::List(Arc::new(Field::new("bar", DataType::Int32, true))), + true, + ), + ])); + + let a = Arc::new(Column::new("a", 0)) as _; + let b = Arc::new(Column::new("b", 1)) as _; + + let eq_expr = + binary_expr(Arc::clone(&a), Operator::Eq, Arc::clone(&b), &schema).unwrap(); + + let eq_result = eq_expr.evaluate(&batch).unwrap(); + let expected = + BooleanArray::from_iter(vec![Some(true), Some(true), Some(true), Some(true)]); + assert_eq!(eq_result.into_array(4).unwrap().as_boolean(), &expected); + } } diff --git a/datafusion/physical-expr/src/expressions/binary/kernels.rs b/datafusion/physical-expr/src/expressions/binary/kernels.rs index ae26f3e842418..71d1242eea85c 100644 --- a/datafusion/physical-expr/src/expressions/binary/kernels.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels.rs @@ -23,15 +23,17 @@ use arrow::compute::kernels::bitwise::{ bitwise_shift_left_scalar, bitwise_shift_right, bitwise_shift_right_scalar, bitwise_xor, bitwise_xor_scalar, }; +use arrow::compute::kernels::boolean::not; +use arrow::compute::kernels::comparison::{regexp_is_match, regexp_is_match_scalar}; use arrow::datatypes::DataType; -use datafusion_common::plan_err; +use arrow::error::ArrowError; +use datafusion_common::{internal_err, plan_err}; use datafusion_common::{Result, ScalarValue}; -use arrow::error::ArrowError; use std::sync::Arc; /// Downcasts $LEFT and $RIGHT to $ARRAY_TYPE and then calls $KERNEL($LEFT, $RIGHT) -macro_rules! call_bitwise_kernel { +macro_rules! call_kernel { ($LEFT:expr, $RIGHT:expr, $KERNEL:expr, $ARRAY_TYPE:ident) => {{ let left = $LEFT.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); let right = $RIGHT.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); @@ -42,36 +44,36 @@ macro_rules! call_bitwise_kernel { /// Creates a $FUNC(left: ArrayRef, right: ArrayRef) that /// downcasts left / right to the appropriate integral type and calls the kernel -macro_rules! create_dyn_kernel { +macro_rules! create_left_integral_dyn_kernel { ($FUNC:ident, $KERNEL:ident) => { pub(crate) fn $FUNC(left: ArrayRef, right: ArrayRef) -> Result { match &left.data_type() { DataType::Int8 => { - call_bitwise_kernel!(left, right, $KERNEL, Int8Array) + call_kernel!(left, right, $KERNEL, Int8Array) } DataType::Int16 => { - call_bitwise_kernel!(left, right, $KERNEL, Int16Array) + call_kernel!(left, right, $KERNEL, Int16Array) } DataType::Int32 => { - call_bitwise_kernel!(left, right, $KERNEL, Int32Array) + call_kernel!(left, right, $KERNEL, Int32Array) } DataType::Int64 => { - call_bitwise_kernel!(left, right, $KERNEL, Int64Array) + call_kernel!(left, right, $KERNEL, Int64Array) } DataType::UInt8 => { - call_bitwise_kernel!(left, right, $KERNEL, UInt8Array) + call_kernel!(left, right, $KERNEL, UInt8Array) } DataType::UInt16 => { - call_bitwise_kernel!(left, right, $KERNEL, UInt16Array) + call_kernel!(left, right, $KERNEL, UInt16Array) } DataType::UInt32 => { - call_bitwise_kernel!(left, right, $KERNEL, UInt32Array) + call_kernel!(left, right, $KERNEL, UInt32Array) } DataType::UInt64 => { - call_bitwise_kernel!(left, right, $KERNEL, UInt64Array) + call_kernel!(left, right, $KERNEL, UInt64Array) } other => plan_err!( - "Data type {:?} not supported for binary operation '{}' on dyn arrays", + "Data type {} not supported for binary operation '{}' on dyn arrays", other, stringify!($KERNEL) ), @@ -80,14 +82,14 @@ macro_rules! create_dyn_kernel { }; } -create_dyn_kernel!(bitwise_or_dyn, bitwise_or); -create_dyn_kernel!(bitwise_xor_dyn, bitwise_xor); -create_dyn_kernel!(bitwise_and_dyn, bitwise_and); -create_dyn_kernel!(bitwise_shift_right_dyn, bitwise_shift_right); -create_dyn_kernel!(bitwise_shift_left_dyn, bitwise_shift_left); +create_left_integral_dyn_kernel!(bitwise_or_dyn, bitwise_or); +create_left_integral_dyn_kernel!(bitwise_xor_dyn, bitwise_xor); +create_left_integral_dyn_kernel!(bitwise_and_dyn, bitwise_and); +create_left_integral_dyn_kernel!(bitwise_shift_right_dyn, bitwise_shift_right); +create_left_integral_dyn_kernel!(bitwise_shift_left_dyn, bitwise_shift_left); /// Downcasts $LEFT as $ARRAY_TYPE and $RIGHT as TYPE and calls $KERNEL($LEFT, $RIGHT) -macro_rules! call_bitwise_scalar_kernel { +macro_rules! call_scalar_kernel { ($LEFT:expr, $RIGHT:expr, $KERNEL:ident, $ARRAY_TYPE:ident, $TYPE:ty) => {{ let len = $LEFT.len(); let array = $LEFT.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); @@ -104,20 +106,20 @@ macro_rules! call_bitwise_scalar_kernel { /// Creates a $FUNC(left: ArrayRef, right: ScalarValue) that /// downcasts left / right to the appropriate integral type and calls the kernel -macro_rules! create_dyn_scalar_kernel { +macro_rules! create_left_integral_dyn_scalar_kernel { ($FUNC:ident, $KERNEL:ident) => { pub(crate) fn $FUNC(array: &dyn Array, scalar: ScalarValue) -> Option> { let result = match array.data_type() { - DataType::Int8 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, Int8Array, i8), - DataType::Int16 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, Int16Array, i16), - DataType::Int32 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, Int32Array, i32), - DataType::Int64 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, Int64Array, i64), - DataType::UInt8 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt8Array, u8), - DataType::UInt16 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt16Array, u16), - DataType::UInt32 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt32Array, u32), - DataType::UInt64 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt64Array, u64), + DataType::Int8 => call_scalar_kernel!(array, scalar, $KERNEL, Int8Array, i8), + DataType::Int16 => call_scalar_kernel!(array, scalar, $KERNEL, Int16Array, i16), + DataType::Int32 => call_scalar_kernel!(array, scalar, $KERNEL, Int32Array, i32), + DataType::Int64 => call_scalar_kernel!(array, scalar, $KERNEL, Int64Array, i64), + DataType::UInt8 => call_scalar_kernel!(array, scalar, $KERNEL, UInt8Array, u8), + DataType::UInt16 => call_scalar_kernel!(array, scalar, $KERNEL, UInt16Array, u16), + DataType::UInt32 => call_scalar_kernel!(array, scalar, $KERNEL, UInt32Array, u32), + DataType::UInt64 => call_scalar_kernel!(array, scalar, $KERNEL, UInt64Array, u64), other => plan_err!( - "Data type {:?} not supported for binary operation '{}' on dyn arrays", + "Data type {} not supported for binary operation '{}' on dyn arrays", other, stringify!($KERNEL) ), @@ -127,11 +129,17 @@ macro_rules! create_dyn_scalar_kernel { }; } -create_dyn_scalar_kernel!(bitwise_and_dyn_scalar, bitwise_and_scalar); -create_dyn_scalar_kernel!(bitwise_or_dyn_scalar, bitwise_or_scalar); -create_dyn_scalar_kernel!(bitwise_xor_dyn_scalar, bitwise_xor_scalar); -create_dyn_scalar_kernel!(bitwise_shift_right_dyn_scalar, bitwise_shift_right_scalar); -create_dyn_scalar_kernel!(bitwise_shift_left_dyn_scalar, bitwise_shift_left_scalar); +create_left_integral_dyn_scalar_kernel!(bitwise_and_dyn_scalar, bitwise_and_scalar); +create_left_integral_dyn_scalar_kernel!(bitwise_or_dyn_scalar, bitwise_or_scalar); +create_left_integral_dyn_scalar_kernel!(bitwise_xor_dyn_scalar, bitwise_xor_scalar); +create_left_integral_dyn_scalar_kernel!( + bitwise_shift_right_dyn_scalar, + bitwise_shift_right_scalar +); +create_left_integral_dyn_scalar_kernel!( + bitwise_shift_left_dyn_scalar, + bitwise_shift_left_scalar +); pub fn concat_elements_utf8view( left: &StringViewArray, @@ -164,3 +172,125 @@ pub fn concat_elements_utf8view( } Ok(result.finish()) } + +/// Invoke a compute kernel on a pair of binary data arrays with flags +macro_rules! regexp_is_match_flag { + ($LEFT:expr, $RIGHT:expr, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .expect("failed to downcast array"); + let rr = $RIGHT + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .expect("failed to downcast array"); + + let flag = if $FLAG { + Some($ARRAYTYPE::from(vec!["i"; ll.len()])) + } else { + None + }; + let mut array = regexp_is_match(ll, rr, flag.as_ref())?; + if $NOT { + array = not(&array).unwrap(); + } + Ok(Arc::new(array)) + }}; +} + +pub(crate) fn regex_match_dyn( + left: ArrayRef, + right: ArrayRef, + not_match: bool, + flag: bool, +) -> Result { + match left.data_type() { + DataType::Utf8 => { + regexp_is_match_flag!(left, right, StringArray, not_match, flag) + } + DataType::Utf8View => { + regexp_is_match_flag!(left, right, StringViewArray, not_match, flag) + } + DataType::LargeUtf8 => { + regexp_is_match_flag!(left, right, LargeStringArray, not_match, flag) + } + other => internal_err!( + "Data type {} not supported for regex_match_dyn on string array", + other + ), + } +} + +/// Invoke a compute kernel on a data array and a scalar value with flag +macro_rules! regexp_is_match_flag_scalar { + ($LEFT:expr, $RIGHT:expr, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .expect("failed to downcast array"); + + if let Some(Some(string_value)) = $RIGHT.try_as_str() { + let flag = $FLAG.then_some("i"); + match regexp_is_match_scalar(ll, &string_value, flag) { + Ok(mut array) => { + if $NOT { + array = not(&array).unwrap(); + } + Ok(Arc::new(array)) + } + Err(e) => internal_err!("failed to call 'regex_match_dyn_scalar' {}", e), + } + } else { + internal_err!( + "failed to cast literal value {} for operation 'regex_match_dyn_scalar'", + $RIGHT + ) + } + }}; +} + +pub(crate) fn regex_match_dyn_scalar( + left: &dyn Array, + right: ScalarValue, + not_match: bool, + flag: bool, +) -> Option> { + let result: Result = match left.data_type() { + DataType::Utf8 => { + regexp_is_match_flag_scalar!(left, right, StringArray, not_match, flag) + }, + DataType::Utf8View => { + regexp_is_match_flag_scalar!(left, right, StringViewArray, not_match, flag) + } + DataType::LargeUtf8 => { + regexp_is_match_flag_scalar!(left, right, LargeStringArray, not_match, flag) + }, + DataType::Dictionary(_, _) => { + let values = left.as_any_dictionary().values(); + + match values.data_type() { + DataType::Utf8 => regexp_is_match_flag_scalar!(values, right, StringArray, not_match, flag), + DataType::Utf8View => regexp_is_match_flag_scalar!(values, right, StringViewArray, not_match, flag), + DataType::LargeUtf8 => regexp_is_match_flag_scalar!(values, right, LargeStringArray, not_match, flag), + other => internal_err!( + "Data type {} not supported as a dictionary value type for operation 'regex_match_dyn_scalar' on string array", + other + ), + }.map( + // downcast_dictionary_array duplicates code per possible key type, so we aim to do all prep work before + |evaluated_values| downcast_dictionary_array! { + left => { + let unpacked_dict = evaluated_values.take_iter(left.keys().iter().map(|opt| opt.map(|v| v as _))).collect::(); + Arc::new(unpacked_dict) as ArrayRef + }, + _ => unreachable!(), + } + ) + }, + other => internal_err!( + "Data type {} not supported for operation 'regex_match_dyn_scalar' on string array", + other + ), + }; + Some(result) +} diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 854c715eb0a25..d14146a20d8bd 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::expressions::try_cast; +use crate::PhysicalExpr; use std::borrow::Cow; use std::hash::Hash; use std::{any::Any, sync::Arc}; -use crate::expressions::try_cast; -use crate::PhysicalExpr; - use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; @@ -156,10 +155,7 @@ impl CaseExpr { && else_expr.as_ref().unwrap().as_any().is::() { EvalMethod::ScalarOrScalar - } else if when_then_expr.len() == 1 - && is_cheap_and_infallible(&(when_then_expr[0].1)) - && else_expr.as_ref().is_some_and(is_cheap_and_infallible) - { + } else if when_then_expr.len() == 1 && else_expr.is_some() { EvalMethod::ExpressionOrExpression } else { EvalMethod::NoExpression @@ -324,12 +320,14 @@ impl CaseExpr { } if let Some(e) = self.else_expr() { - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; - let else_ = expr - .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows())?; - current_value = zip(&remainder, &else_, ¤t_value)?; + if remainder.true_count() > 0 { + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + let else_ = expr + .evaluate_selection(batch, &remainder)? + .into_array(batch.num_rows())?; + current_value = zip(&remainder, &else_, ¤t_value)?; + } } Ok(ColumnarValue::Array(current_value)) @@ -414,7 +412,7 @@ impl CaseExpr { fn expr_or_expr(&self, batch: &RecordBatch) -> Result { let return_type = self.data_type(&batch.schema())?; - // evalute when condition on batch + // evaluate when condition on batch let when_value = self.when_then_expr[0].0.evaluate(batch)?; let when_value = when_value.into_array(batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|e| { @@ -424,6 +422,16 @@ impl CaseExpr { ) })?; + // For the true and false/null selection vectors, bypass `evaluate_selection` and merging + // results. This avoids materializing the array for the other branch which we will discard + // entirely anyway. + let true_count = when_value.true_count(); + if true_count == batch.num_rows() { + return self.when_then_expr[0].1.evaluate(batch); + } else if true_count == 0 { + return self.else_expr.as_ref().unwrap().evaluate(batch); + } + // Treat 'NULL' as false value let when_value = match when_value.null_count() { 0 => Cow::Borrowed(when_value), @@ -603,7 +611,7 @@ mod tests { use crate::expressions::{binary, cast, col, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; - use arrow::datatypes::*; + use arrow::datatypes::Field; use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -1069,7 +1077,6 @@ mod tests { .into_iter() .collect(); - //let valid_array = vec![true, false, false, true, false, tru let null_buffer = Buffer::from([0b00101001u8]); let load4 = load4 .into_data() diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index a6766687a881a..407e3e6a9d294 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::physical_expr::PhysicalExpr; use arrow::compute::{can_cast_types, CastOptions}; -use arrow::datatypes::{DataType, DataType::*, Schema}; +use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result}; @@ -97,8 +97,10 @@ impl CastExpr { pub fn cast_options(&self) -> &CastOptions<'static> { &self.cast_options } - pub fn is_bigger_cast(&self, src: DataType) -> bool { - if src == self.cast_type { + + /// Check if the cast is a widening cast (e.g. from `Int8` to `Int16`). + pub fn is_bigger_cast(&self, src: &DataType) -> bool { + if self.cast_type.eq(src) { return true; } matches!( @@ -144,6 +146,16 @@ impl PhysicalExpr for CastExpr { value.cast_to(&self.cast_type, Some(&self.cast_options)) } + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(self + .expr + .return_field(input_schema)? + .as_ref() + .clone() + .with_data_type(self.cast_type.clone()) + .into()) + } + fn children(&self) -> Vec<&Arc> { vec![&self.expr] } @@ -220,7 +232,7 @@ pub fn cast_with_options( } else if can_cast_types(&expr_type, &cast_type) { Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) } else { - not_impl_err!("Unsupported CAST from {expr_type:?} to {cast_type:?}") + not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}") } } @@ -250,8 +262,8 @@ mod tests { }, datatypes::*, }; - use datafusion_common::assert_contains; use datafusion_physical_expr_common::physical_expr::fmt_sql; + use insta::assert_snapshot; // runs an end-to-end test of physical type cast // 1. construct a record batch with a column "a" of type A @@ -426,11 +438,8 @@ mod tests { )?; let expression = cast_with_options(col("a", &schema)?, &schema, Decimal128(6, 2), None)?; - let e = expression.evaluate(&batch).unwrap_err(); // panics on OK - assert_contains!( - e.to_string(), - "Arrow error: Invalid argument error: 12345679 is too large to store in a Decimal128 of precision 6. Max is 999999" - ); + let e = expression.evaluate(&batch).unwrap_err().strip_backtrace(); // panics on OK + assert_snapshot!(e, @"Arrow error: Invalid argument error: 12345679 is too large to store in a Decimal128 of precision 6. Max is 999999"); let expression_safe = cast_with_options( col("a", &schema)?, diff --git a/datafusion/physical-expr/src/expressions/cast_column.rs b/datafusion/physical-expr/src/expressions/cast_column.rs new file mode 100644 index 0000000000000..80d71c3def408 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/cast_column.rs @@ -0,0 +1,409 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical expression for struct-aware casting of columns. + +use crate::physical_expr::PhysicalExpr; +use arrow::{ + compute::CastOptions, + datatypes::{DataType, FieldRef, Schema}, + record_batch::RecordBatch, +}; +use datafusion_common::{ + format::DEFAULT_CAST_OPTIONS, nested_struct::cast_column, Result, ScalarValue, +}; +use datafusion_expr_common::columnar_value::ColumnarValue; +use std::{ + any::Any, + fmt::{self, Display}, + hash::Hash, + sync::Arc, +}; +/// A physical expression that applies [`cast_column`] to its input. +/// +/// [`CastColumnExpr`] extends the regular [`CastExpr`](super::CastExpr) by +/// retaining schema metadata for both the input and output fields. This allows +/// the evaluator to perform struct-aware casts that honour nested field +/// ordering, preserve nullability, and fill missing fields with null values. +/// +/// This expression is intended for schema rewriting scenarios where the +/// planner already resolved the input column but needs to adapt its physical +/// representation to a new [`arrow::datatypes::Field`]. It mirrors the behaviour of the +/// [`datafusion_common::nested_struct::cast_column`] helper while integrating +/// with the `PhysicalExpr` trait so it can participate in the execution plan +/// like any other column expression. +#[derive(Debug, Clone, Eq)] +pub struct CastColumnExpr { + /// The physical expression producing the value to cast. + expr: Arc, + /// The logical field of the input column. + input_field: FieldRef, + /// The field metadata describing the desired output column. + target_field: FieldRef, + /// Options forwarded to [`cast_column`]. + cast_options: CastOptions<'static>, +} + +// Manually derive `PartialEq`/`Hash` as `Arc` does not +// implement these traits by default for the trait object. +impl PartialEq for CastColumnExpr { + fn eq(&self, other: &Self) -> bool { + self.expr.eq(&other.expr) + && self.input_field.eq(&other.input_field) + && self.target_field.eq(&other.target_field) + && self.cast_options.eq(&other.cast_options) + } +} + +impl Hash for CastColumnExpr { + fn hash(&self, state: &mut H) { + self.expr.hash(state); + self.input_field.hash(state); + self.target_field.hash(state); + self.cast_options.hash(state); + } +} + +impl CastColumnExpr { + /// Create a new [`CastColumnExpr`]. + pub fn new( + expr: Arc, + input_field: FieldRef, + target_field: FieldRef, + cast_options: Option>, + ) -> Self { + Self { + expr, + input_field, + target_field, + cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS), + } + } + + /// The expression that produces the value to be cast. + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Field metadata describing the resolved input column. + pub fn input_field(&self) -> &FieldRef { + &self.input_field + } + + /// Field metadata describing the output column after casting. + pub fn target_field(&self) -> &FieldRef { + &self.target_field + } +} + +impl Display for CastColumnExpr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "CAST_COLUMN({} AS {:?})", + self.expr, + self.target_field.data_type() + ) + } +} + +impl PhysicalExpr for CastColumnExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.target_field.data_type().clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(self.target_field.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let value = self.expr.evaluate(batch)?; + match value { + ColumnarValue::Array(array) => { + let casted = + cast_column(&array, self.target_field.as_ref(), &self.cast_options)?; + Ok(ColumnarValue::Array(casted)) + } + ColumnarValue::Scalar(scalar) => { + let as_array = scalar.to_array_of_size(1)?; + let casted = cast_column( + &as_array, + self.target_field.as_ref(), + &self.cast_options, + )?; + let result = ScalarValue::try_from_array(casted.as_ref(), 0)?; + Ok(ColumnarValue::Scalar(result)) + } + } + } + + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.target_field)) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.expr] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + assert_eq!(children.len(), 1); + let child = children.pop().expect("CastColumnExpr child"); + Ok(Arc::new(Self::new( + child, + Arc::clone(&self.input_field), + Arc::clone(&self.target_field), + Some(self.cast_options.clone()), + ))) + } + + fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Display::fmt(self, f) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::expressions::{Column, Literal}; + use arrow::{ + array::{Array, ArrayRef, BooleanArray, Int32Array, StringArray, StructArray}, + datatypes::{DataType, Field, Fields, SchemaRef}, + }; + use datafusion_common::{ + cast::{as_int64_array, as_string_array, as_struct_array, as_uint8_array}, + Result as DFResult, ScalarValue, + }; + + fn make_schema(field: &Field) -> SchemaRef { + Arc::new(Schema::new(vec![field.clone()])) + } + + fn make_struct_array(fields: Fields, arrays: Vec) -> StructArray { + StructArray::new(fields, arrays, None) + } + + #[test] + fn cast_primitive_array() -> DFResult<()> { + let input_field = Field::new("a", DataType::Int32, true); + let target_field = Field::new("a", DataType::Int64, true); + let schema = make_schema(&input_field); + + let values = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])); + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![values])?; + + let column = Arc::new(Column::new_with_schema("a", schema.as_ref())?); + let expr = CastColumnExpr::new( + column, + Arc::new(input_field.clone()), + Arc::new(target_field.clone()), + None, + ); + + let result = expr.evaluate(&batch)?; + let ColumnarValue::Array(array) = result else { + panic!("expected array"); + }; + let casted = as_int64_array(array.as_ref())?; + assert_eq!(casted.value(0), 1); + assert!(casted.is_null(1)); + assert_eq!(casted.value(2), 3); + Ok(()) + } + + #[test] + fn cast_struct_array_missing_child() -> DFResult<()> { + let source_a = Field::new("a", DataType::Int32, true); + let source_b = Field::new("b", DataType::Utf8, true); + let input_field = Field::new( + "s", + DataType::Struct( + vec![Arc::new(source_a.clone()), Arc::new(source_b.clone())].into(), + ), + true, + ); + let target_a = Field::new("a", DataType::Int64, true); + let target_c = Field::new("c", DataType::Utf8, true); + let target_field = Field::new( + "s", + DataType::Struct( + vec![Arc::new(target_a.clone()), Arc::new(target_c.clone())].into(), + ), + true, + ); + + let schema = make_schema(&input_field); + let struct_array = make_struct_array( + vec![Arc::new(source_a.clone()), Arc::new(source_b.clone())].into(), + vec![ + Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef, + Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])) + as ArrayRef, + ], + ); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(struct_array) as Arc<_>], + )?; + + let column = Arc::new(Column::new_with_schema("s", schema.as_ref())?); + let expr = CastColumnExpr::new( + column, + Arc::new(input_field.clone()), + Arc::new(target_field.clone()), + None, + ); + + let result = expr.evaluate(&batch)?; + let ColumnarValue::Array(array) = result else { + panic!("expected array"); + }; + let struct_array = as_struct_array(array.as_ref())?; + let cast_a = as_int64_array(struct_array.column_by_name("a").unwrap().as_ref())?; + assert_eq!(cast_a.value(0), 1); + assert!(cast_a.is_null(1)); + + let cast_c = as_string_array(struct_array.column_by_name("c").unwrap().as_ref())?; + assert!(cast_c.is_null(0)); + assert!(cast_c.is_null(1)); + Ok(()) + } + + #[test] + fn cast_nested_struct_array() -> DFResult<()> { + let inner_source = Field::new( + "inner", + DataType::Struct( + vec![Arc::new(Field::new("x", DataType::Int32, true))].into(), + ), + true, + ); + let outer_field = Field::new( + "root", + DataType::Struct(vec![Arc::new(inner_source.clone())].into()), + true, + ); + + let inner_target = Field::new( + "inner", + DataType::Struct( + vec![ + Arc::new(Field::new("x", DataType::Int64, true)), + Arc::new(Field::new("y", DataType::Boolean, true)), + ] + .into(), + ), + true, + ); + let target_field = Field::new( + "root", + DataType::Struct(vec![Arc::new(inner_target.clone())].into()), + true, + ); + + let schema = make_schema(&outer_field); + + let inner_struct = make_struct_array( + vec![Arc::new(Field::new("x", DataType::Int32, true))].into(), + vec![Arc::new(Int32Array::from(vec![Some(7), None])) as ArrayRef], + ); + let outer_struct = make_struct_array( + vec![Arc::new(inner_source.clone())].into(), + vec![Arc::new(inner_struct) as ArrayRef], + ); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(outer_struct) as ArrayRef], + )?; + + let column = Arc::new(Column::new_with_schema("root", schema.as_ref())?); + let expr = CastColumnExpr::new( + column, + Arc::new(outer_field.clone()), + Arc::new(target_field.clone()), + None, + ); + + let result = expr.evaluate(&batch)?; + let ColumnarValue::Array(array) = result else { + panic!("expected array"); + }; + let struct_array = as_struct_array(array.as_ref())?; + let inner = + as_struct_array(struct_array.column_by_name("inner").unwrap().as_ref())?; + let x = as_int64_array(inner.column_by_name("x").unwrap().as_ref())?; + assert_eq!(x.value(0), 7); + assert!(x.is_null(1)); + let y = inner.column_by_name("y").unwrap(); + let y = y + .as_any() + .downcast_ref::() + .expect("boolean array"); + assert!(y.is_null(0)); + assert!(y.is_null(1)); + Ok(()) + } + + #[test] + fn cast_struct_scalar() -> DFResult<()> { + let source_field = Field::new("a", DataType::Int32, true); + let input_field = Field::new( + "s", + DataType::Struct(vec![Arc::new(source_field.clone())].into()), + true, + ); + let target_field = Field::new( + "s", + DataType::Struct( + vec![Arc::new(Field::new("a", DataType::UInt8, true))].into(), + ), + true, + ); + + let schema = make_schema(&input_field); + let scalar_struct = StructArray::new( + vec![Arc::new(source_field.clone())].into(), + vec![Arc::new(Int32Array::from(vec![Some(9)])) as ArrayRef], + None, + ); + let literal = + Arc::new(Literal::new(ScalarValue::Struct(Arc::new(scalar_struct)))); + let expr = CastColumnExpr::new( + literal, + Arc::new(input_field.clone()), + Arc::new(target_field.clone()), + None, + ); + + let batch = RecordBatch::new_empty(Arc::clone(&schema)); + let result = expr.evaluate(&batch)?; + let ColumnarValue::Scalar(ScalarValue::Struct(array)) = result else { + panic!("expected struct scalar"); + }; + let casted = array.column_by_name("a").unwrap(); + let casted = as_uint8_array(casted.as_ref())?; + assert_eq!(casted.value(0), 9); + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index ab5b359847535..c9f3fb00f019e 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -22,6 +22,7 @@ use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; +use arrow::datatypes::FieldRef; use arrow::{ datatypes::{DataType, Schema, SchemaRef}, record_batch::RecordBatch, @@ -127,6 +128,10 @@ impl PhysicalExpr for Column { Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(input_schema.field(self.index).clone().into()) + } + fn children(&self) -> Vec<&Arc> { vec![] } @@ -199,7 +204,6 @@ mod test { use arrow::array::StringArray; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; - use datafusion_common::Result; use std::sync::Arc; @@ -209,8 +213,9 @@ mod test { let col = Column::new("id", 9); let error = col.data_type(&schema).expect_err("error").strip_backtrace(); assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ - but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ - DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)) + but input schema only has 1 columns: [\"foo\"].\nThis issue was likely caused by a bug \ + in DataFusion's code. Please help us to resolve this by filing a bug report \ + in our issue tracker: https://github.com/apache/datafusion/issues".starts_with(&error)) } #[test] @@ -219,20 +224,21 @@ mod test { let col = Column::new("id", 9); let error = col.nullable(&schema).expect_err("error").strip_backtrace(); assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ - but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ - DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)) + but input schema only has 1 columns: [\"foo\"].\nThis issue was likely caused by a bug \ + in DataFusion's code. Please help us to resolve this by filing a bug report \ + in our issue tracker: https://github.com/apache/datafusion/issues".starts_with(&error)); } #[test] - fn out_of_bounds_evaluate() -> Result<()> { + fn out_of_bounds_evaluate() { let schema = Schema::new(vec![Field::new("foo", DataType::Utf8, true)]); let data: StringArray = vec!["data"].into(); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?; + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)]).unwrap(); let col = Column::new("id", 9); let error = col.evaluate(&batch).expect_err("error").strip_backtrace(); assert!("Internal error: PhysicalExpr Column references column 'id' at index 9 (zero-based) \ - but input schema only has 1 columns: [\"foo\"].\nThis was likely caused by a bug in \ - DataFusion's code and we would welcome that you file an bug report in our issue tracker".starts_with(&error)); - Ok(()) + but input schema only has 1 columns: [\"foo\"].\nThis issue was likely caused by a bug \ + in DataFusion's code. Please help us to resolve this by filing a bug report \ + in our issue tracker: https://github.com/apache/datafusion/issues".starts_with(&error)); } } diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters.rs b/datafusion/physical-expr/src/expressions/dynamic_filters.rs new file mode 100644 index 0000000000000..a53b32c976893 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/dynamic_filters.rs @@ -0,0 +1,512 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use parking_lot::RwLock; +use std::{any::Any, fmt::Display, hash::Hash, sync::Arc}; + +use crate::PhysicalExpr; +use arrow::datatypes::{DataType, Schema}; +use datafusion_common::{ + tree_node::{Transformed, TransformedResult, TreeNode}, + Result, +}; +use datafusion_expr::ColumnarValue; +use datafusion_physical_expr_common::physical_expr::{DynEq, DynHash}; + +/// A dynamic [`PhysicalExpr`] that can be updated by anyone with a reference to it. +/// +/// Any `ExecutionPlan` that uses this expression and holds a reference to it internally should probably also +/// implement `ExecutionPlan::reset_state` to remain compatible with recursive queries and other situations where +/// the same `ExecutionPlan` is reused with different data. +#[derive(Debug)] +pub struct DynamicFilterPhysicalExpr { + /// The original children of this PhysicalExpr, if any. + /// This is necessary because the dynamic filter may be initialized with a placeholder (e.g. `lit(true)`) + /// and later remapped to the actual expressions that are being filtered. + /// But we need to know the children (e.g. columns referenced in the expression) ahead of time to evaluate the expression correctly. + children: Vec>, + /// If any of the children were remapped / modified (e.g. to adjust for projections) we need to keep track of the new children + /// so that when we update `current()` in subsequent iterations we can re-apply the replacements. + remapped_children: Option>>, + /// The source of dynamic filters. + inner: Arc>, + /// For testing purposes track the data type and nullability to make sure they don't change. + /// If they do, there's a bug in the implementation. + /// But this can have overhead in production, so it's only included in our tests. + data_type: Arc>>, + nullable: Arc>>, +} + +#[derive(Debug)] +struct Inner { + /// A counter that gets incremented every time the expression is updated so that we can track changes cheaply. + /// This is used for [`PhysicalExpr::snapshot_generation`] to have a cheap check for changes. + generation: u64, + expr: Arc, +} + +impl Inner { + fn new(expr: Arc) -> Self { + Self { + // Start with generation 1 which gives us a different result for [`PhysicalExpr::generation`] than the default 0. + // This is not currently used anywhere but it seems useful to have this simple distinction. + generation: 1, + expr, + } + } + + /// Clone the inner expression. + fn expr(&self) -> &Arc { + &self.expr + } +} + +impl Hash for DynamicFilterPhysicalExpr { + fn hash(&self, state: &mut H) { + let inner = self.current().expect("Failed to get current expression"); + inner.dyn_hash(state); + self.children.dyn_hash(state); + self.remapped_children.dyn_hash(state); + } +} + +impl PartialEq for DynamicFilterPhysicalExpr { + fn eq(&self, other: &Self) -> bool { + let inner = self.current().expect("Failed to get current expression"); + let our_children = self.remapped_children.as_ref().unwrap_or(&self.children); + let other_children = other.remapped_children.as_ref().unwrap_or(&other.children); + let other = other.current().expect("Failed to get current expression"); + inner.dyn_eq(other.as_any()) && our_children == other_children + } +} + +impl Eq for DynamicFilterPhysicalExpr {} + +impl Display for DynamicFilterPhysicalExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.render(f, |expr, f| write!(f, "{expr}")) + } +} + +impl DynamicFilterPhysicalExpr { + /// Create a new [`DynamicFilterPhysicalExpr`] + /// from an initial expression and a list of children. + /// The list of children is provided separately because + /// the initial expression may not have the same children. + /// For example, if the initial expression is just `true` + /// it will not reference any columns, but we may know that + /// we are going to replace this expression with a real one + /// that does reference certain columns. + /// In this case you **must** pass in the columns that will be + /// used in the final expression as children to this function + /// since DataFusion is generally not compatible with dynamic + /// *children* in expressions. + /// + /// To determine the children you can: + /// + /// - Use [`collect_columns`] to collect the columns from the expression. + /// - Use existing information, such as the sort columns in a `SortExec`. + /// + /// Generally the important bit is that the *leaf children that reference columns + /// do not change* since those will be used to determine what columns need to read or projected + /// when evaluating the expression. + /// + /// Any `ExecutionPlan` that uses this expression and holds a reference to it internally should probably also + /// implement `ExecutionPlan::reset_state` to remain compatible with recursive queries and other situations where + /// the same `ExecutionPlan` is reused with different data. + /// + /// [`collect_columns`]: crate::utils::collect_columns + pub fn new( + children: Vec>, + inner: Arc, + ) -> Self { + Self { + children, + remapped_children: None, // Initially no remapped children + inner: Arc::new(RwLock::new(Inner::new(inner))), + data_type: Arc::new(RwLock::new(None)), + nullable: Arc::new(RwLock::new(None)), + } + } + + fn remap_children( + children: &[Arc], + remapped_children: Option<&Vec>>, + expr: Arc, + ) -> Result> { + if let Some(remapped_children) = remapped_children { + // Remap the children to the new children + // of the expression. + expr.transform_up(|child| { + // Check if this is any of our original children + if let Some(pos) = + children.iter().position(|c| c.as_ref() == child.as_ref()) + { + // If so, remap it to the current children + // of the expression. + let new_child = Arc::clone(&remapped_children[pos]); + Ok(Transformed::yes(new_child)) + } else { + // Otherwise, just return the expression + Ok(Transformed::no(child)) + } + }) + .data() + } else { + // If we don't have any remapped children, just return the expression + Ok(Arc::clone(&expr)) + } + } + + /// Get the current generation of the expression. + fn current_generation(&self) -> u64 { + self.inner.read().generation + } + + /// Get the current expression. + /// This will return the current expression with any children + /// remapped to match calls to [`PhysicalExpr::with_new_children`]. + pub fn current(&self) -> Result> { + let expr = Arc::clone(self.inner.read().expr()); + Self::remap_children(&self.children, self.remapped_children.as_ref(), expr) + } + + /// Update the current expression. + /// Any children of this expression must be a subset of the original children + /// passed to the constructor. + /// This should be called e.g.: + /// - When we've computed the probe side's hash table in a HashJoinExec + /// - After every batch is processed if we update the TopK heap in a SortExec using a TopK approach. + pub fn update(&self, new_expr: Arc) -> Result<()> { + // Remap the children of the new expression to match the original children + // We still do this again in `current()` but doing it preventively here + // reduces the work needed in some cases if `current()` is called multiple times + // and the same externally facing `PhysicalExpr` is used for both `with_new_children` and `update()`.` + let new_expr = Self::remap_children( + &self.children, + self.remapped_children.as_ref(), + new_expr, + )?; + + // Load the current inner, increment generation, and store the new one + let mut current = self.inner.write(); + *current = Inner { + generation: current.generation + 1, + expr: new_expr, + }; + Ok(()) + } + + fn render( + &self, + f: &mut std::fmt::Formatter<'_>, + render_expr: impl FnOnce( + Arc, + &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result, + ) -> std::fmt::Result { + let inner = self.current().map_err(|_| std::fmt::Error)?; + let current_generation = self.current_generation(); + write!(f, "DynamicFilter [ ")?; + if current_generation == 1 { + write!(f, "empty")?; + } else { + render_expr(inner, f)?; + } + + write!(f, " ]") + } +} + +impl PhysicalExpr for DynamicFilterPhysicalExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn children(&self) -> Vec<&Arc> { + self.remapped_children + .as_ref() + .unwrap_or(&self.children) + .iter() + .collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(Self { + children: self.children.clone(), + remapped_children: Some(children), + inner: Arc::clone(&self.inner), + data_type: Arc::clone(&self.data_type), + nullable: Arc::clone(&self.nullable), + })) + } + + fn data_type(&self, input_schema: &Schema) -> Result { + let res = self.current()?.data_type(input_schema)?; + #[cfg(test)] + { + use datafusion_common::internal_err; + // Check if the data type has changed. + let mut data_type_lock = self.data_type.write(); + + if let Some(existing) = &*data_type_lock { + if existing != &res { + // If the data type has changed, we have a bug. + return internal_err!( + "DynamicFilterPhysicalExpr data type has changed unexpectedly. \ + Expected: {existing:?}, Actual: {res:?}" + ); + } + } else { + *data_type_lock = Some(res.clone()); + } + } + Ok(res) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + let res = self.current()?.nullable(input_schema)?; + #[cfg(test)] + { + use datafusion_common::internal_err; + // Check if the nullability has changed. + let mut nullable_lock = self.nullable.write(); + if let Some(existing) = *nullable_lock { + if existing != res { + // If the nullability has changed, we have a bug. + return internal_err!( + "DynamicFilterPhysicalExpr nullability has changed unexpectedly. \ + Expected: {existing}, Actual: {res}" + ); + } + } else { + *nullable_lock = Some(res); + } + } + Ok(res) + } + + fn evaluate( + &self, + batch: &arrow::record_batch::RecordBatch, + ) -> Result { + let current = self.current()?; + #[cfg(test)] + { + // Ensure that we are not evaluating after the expression has changed. + let schema = batch.schema(); + self.nullable(&schema)?; + self.data_type(&schema)?; + }; + current.evaluate(batch) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.render(f, |expr, f| expr.fmt_sql(f)) + } + + fn snapshot(&self) -> Result>> { + // Return the current expression as a snapshot. + Ok(Some(self.current()?)) + } + + fn snapshot_generation(&self) -> u64 { + // Return the current generation of the expression. + self.inner.read().generation + } +} + +#[cfg(test)] +mod test { + use crate::{ + expressions::{col, lit, BinaryExpr}, + utils::reassign_expr_columns, + }; + use arrow::{ + array::RecordBatch, + datatypes::{DataType, Field, Schema}, + }; + use datafusion_common::ScalarValue; + + use super::*; + + #[test] + fn test_remap_children() { + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let expr = Arc::new(BinaryExpr::new( + col("a", &table_schema).unwrap(), + datafusion_expr::Operator::Eq, + lit(42) as Arc, + )); + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![col("a", &table_schema).unwrap()], + expr as Arc, + )); + // Simulate two `ParquetSource` files with different filter schemas + // Both of these should hit the same inner `PhysicalExpr` even after `update()` is called + // and be able to remap children independently. + let filter_schema_1 = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let filter_schema_2 = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + // Each ParquetExec calls `with_new_children` on the DynamicFilterPhysicalExpr + // and remaps the children to the file schema. + let dynamic_filter_1 = reassign_expr_columns( + Arc::clone(&dynamic_filter) as Arc, + &filter_schema_1, + ) + .unwrap(); + let snap = dynamic_filter_1.snapshot().unwrap().unwrap(); + insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 0 }, op: Eq, right: Literal { value: Int32(42), field: Field { name: "lit", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#); + let dynamic_filter_2 = reassign_expr_columns( + Arc::clone(&dynamic_filter) as Arc, + &filter_schema_2, + ) + .unwrap(); + let snap = dynamic_filter_2.snapshot().unwrap().unwrap(); + insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42), field: Field { name: "lit", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#); + // Both filters allow evaluating the same expression + let batch_1 = RecordBatch::try_new( + Arc::clone(&filter_schema_1), + vec![ + // a + ScalarValue::Int32(Some(42)).to_array_of_size(1).unwrap(), + // b + ScalarValue::Int32(Some(43)).to_array_of_size(1).unwrap(), + ], + ) + .unwrap(); + let batch_2 = RecordBatch::try_new( + Arc::clone(&filter_schema_2), + vec![ + // b + ScalarValue::Int32(Some(43)).to_array_of_size(1).unwrap(), + // a + ScalarValue::Int32(Some(42)).to_array_of_size(1).unwrap(), + ], + ) + .unwrap(); + // Evaluate the expression on both batches + let result_1 = dynamic_filter_1.evaluate(&batch_1).unwrap(); + let result_2 = dynamic_filter_2.evaluate(&batch_2).unwrap(); + // Check that the results are the same + let ColumnarValue::Array(arr_1) = result_1 else { + panic!("Expected ColumnarValue::Array"); + }; + let ColumnarValue::Array(arr_2) = result_2 else { + panic!("Expected ColumnarValue::Array"); + }; + assert!(arr_1.eq(&arr_2)); + let expected = ScalarValue::Boolean(Some(true)) + .to_array_of_size(1) + .unwrap(); + assert!(arr_1.eq(&expected)); + // Now lets update the expression + // Note that we update the *original* expression and that should be reflected in both the derived expressions + let new_expr = Arc::new(BinaryExpr::new( + col("a", &table_schema).unwrap(), + datafusion_expr::Operator::Gt, + lit(43) as Arc, + )); + dynamic_filter + .update(Arc::clone(&new_expr) as Arc) + .expect("Failed to update expression"); + // Now we should be able to evaluate the new expression on both batches + let result_1 = dynamic_filter_1.evaluate(&batch_1).unwrap(); + let result_2 = dynamic_filter_2.evaluate(&batch_2).unwrap(); + // Check that the results are the same + let ColumnarValue::Array(arr_1) = result_1 else { + panic!("Expected ColumnarValue::Array"); + }; + let ColumnarValue::Array(arr_2) = result_2 else { + panic!("Expected ColumnarValue::Array"); + }; + assert!(arr_1.eq(&arr_2)); + let expected = ScalarValue::Boolean(Some(false)) + .to_array_of_size(1) + .unwrap(); + assert!(arr_1.eq(&expected)); + } + + #[test] + fn test_snapshot() { + let expr = lit(42) as Arc; + let dynamic_filter = DynamicFilterPhysicalExpr::new(vec![], Arc::clone(&expr)); + + // Take a snapshot of the current expression + let snapshot = dynamic_filter.snapshot().unwrap(); + assert_eq!(snapshot, Some(expr)); + + // Update the current expression + let new_expr = lit(100) as Arc; + dynamic_filter.update(Arc::clone(&new_expr)).unwrap(); + // Take another snapshot + let snapshot = dynamic_filter.snapshot().unwrap(); + assert_eq!(snapshot, Some(new_expr)); + } + + #[test] + fn test_dynamic_filter_physical_expr_misbehaves_data_type_nullable() { + let dynamic_filter = + DynamicFilterPhysicalExpr::new(vec![], lit(42) as Arc); + + // First call to data_type and nullable should set the initial values. + let initial_data_type = dynamic_filter.data_type(&Schema::empty()).unwrap(); + let initial_nullable = dynamic_filter.nullable(&Schema::empty()).unwrap(); + + // Call again and expect no change. + let second_data_type = dynamic_filter.data_type(&Schema::empty()).unwrap(); + let second_nullable = dynamic_filter.nullable(&Schema::empty()).unwrap(); + assert_eq!( + initial_data_type, second_data_type, + "Data type should not change on second call." + ); + assert_eq!( + initial_nullable, second_nullable, + "Nullability should not change on second call." + ); + + // Now change the current expression to something else. + dynamic_filter + .update(lit(ScalarValue::Utf8(None)) as Arc) + .expect("Failed to update expression"); + // Check that we error if we call data_type, nullable or evaluate after changing the expression. + assert!( + dynamic_filter.data_type(&Schema::empty()).is_err(), + "Expected err when data_type is called after changing the expression." + ); + assert!( + dynamic_filter.nullable(&Schema::empty()).is_err(), + "Expected err when nullable is called after changing the expression." + ); + let batch = RecordBatch::new_empty(Arc::new(Schema::empty())); + assert!( + dynamic_filter.evaluate(&batch).is_err(), + "Expected err when evaluate is called after changing the expression." + ); + } +} diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 469f7bbee3173..fa91635d9bfd9 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -306,18 +306,31 @@ impl InListExpr { } } +#[macro_export] +macro_rules! expr_vec_fmt { + ( $ARRAY:expr ) => {{ + $ARRAY + .iter() + .map(|e| format!("{e}")) + .collect::>() + .join(", ") + }}; +} + impl std::fmt::Display for InListExpr { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let list = expr_vec_fmt!(self.list); + if self.negated { if self.static_filter.is_some() { - write!(f, "{} NOT IN (SET) ({:?})", self.expr, self.list) + write!(f, "{} NOT IN (SET) ([{list}])", self.expr) } else { - write!(f, "{} NOT IN ({:?})", self.expr, self.list) + write!(f, "{} NOT IN ([{list}])", self.expr) } } else if self.static_filter.is_some() { - write!(f, "Use {} IN (SET) ({:?})", self.expr, self.list) + write!(f, "{} IN (SET) ([{list}])", self.expr) } else { - write!(f, "{} IN ({:?})", self.expr, self.list) + write!(f, "{} IN ([{list}])", self.expr) } } } @@ -463,13 +476,14 @@ pub fn in_list( #[cfg(test)] mod tests { - use super::*; use crate::expressions; use crate::expressions::{col, lit, try_cast}; use datafusion_common::plan_err; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_physical_expr_common::physical_expr::fmt_sql; + use insta::assert_snapshot; + use itertools::Itertools as _; type InListCastResult = (Arc, Vec>); @@ -488,7 +502,8 @@ mod tests { let result_type = get_coerce_type(expr_type, &list_types); match result_type { None => plan_err!( - "Can not find compatible types to compare {expr_type:?} with {list_types:?}" + "Can not find compatible types to compare {expr_type} with [{}]", + list_types.iter().join(", ") ), Some(data_type) => { // find the coerced type @@ -1441,7 +1456,7 @@ mod tests { } #[test] - fn test_fmt_sql() -> Result<()> { + fn test_fmt_sql_1() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); let col_a = col("a", &schema)?; @@ -1450,33 +1465,53 @@ mod tests { let expr = in_list(Arc::clone(&col_a), list, &false, &schema)?; let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); - assert_eq!(sql_string, "a IN (a, b)"); - assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }])"); + assert_snapshot!(sql_string, @"a IN (a, b)"); + assert_snapshot!(display_string, @"a@0 IN (SET) ([a, b])"); + Ok(()) + } + + #[test] + fn test_fmt_sql_2() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let col_a = col("a", &schema)?; // Test: a NOT IN ('a', 'b') let list = vec![lit("a"), lit("b")]; let expr = in_list(Arc::clone(&col_a), list, &true, &schema)?; let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); - assert_eq!(sql_string, "a NOT IN (a, b)"); - assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }])"); + assert_snapshot!(sql_string, @"a NOT IN (a, b)"); + assert_snapshot!(display_string, @"a@0 NOT IN (SET) ([a, b])"); + Ok(()) + } + + #[test] + fn test_fmt_sql_3() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let col_a = col("a", &schema)?; // Test: a IN ('a', 'b', NULL) let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; let expr = in_list(Arc::clone(&col_a), list, &false, &schema)?; let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); - assert_eq!(sql_string, "a IN (a, b, NULL)"); - assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }, Literal { value: Utf8(NULL) }])"); + assert_snapshot!(sql_string, @"a IN (a, b, NULL)"); + assert_snapshot!(display_string, @"a@0 IN (SET) ([a, b, NULL])"); + Ok(()) + } + + #[test] + fn test_fmt_sql_4() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let col_a = col("a", &schema)?; // Test: a NOT IN ('a', 'b', NULL) let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; let expr = in_list(Arc::clone(&col_a), list, &true, &schema)?; let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); - assert_eq!(sql_string, "a NOT IN (a, b, NULL)"); - assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }, Literal { value: Utf8(NULL) }])"); - + assert_snapshot!(sql_string, @"a NOT IN (a, b, NULL)"); + assert_snapshot!(display_string, @"a@0 NOT IN (SET) ([a, b, NULL])"); Ok(()) } } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 0619e72488581..62be8ebbc13e3 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -17,9 +17,6 @@ //! IS NOT NULL expression -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; use arrow::{ datatypes::{DataType, Schema}, @@ -28,6 +25,8 @@ use arrow::{ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; /// IS NOT NULL expression #[derive(Debug, Eq)] diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 4c6081f35cad7..356fe2a866672 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -17,9 +17,6 @@ //! IS NULL expression -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; use arrow::{ datatypes::{DataType, Schema}, @@ -28,6 +25,8 @@ use arrow::{ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; /// IS NULL expression #[derive(Debug, Eq)] diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index ebf9882665ba0..e86c778d51619 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -15,15 +15,14 @@ // specific language governing permissions and limitations // under the License. -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::apply_cmp; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; // Like expression #[derive(Debug, Eq)] diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 0d0c0ecc62c79..6e425ee439d69 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -23,26 +23,59 @@ use std::sync::Arc; use crate::physical_expr::PhysicalExpr; +use arrow::datatypes::{Field, FieldRef}; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::expr::FieldMetadata; use datafusion_expr::Expr; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; /// Represents a literal value -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct Literal { value: ScalarValue, + field: FieldRef, +} + +impl Hash for Literal { + fn hash(&self, state: &mut H) { + self.value.hash(state); + let metadata = self.field.metadata(); + let mut keys = metadata.keys().collect::>(); + keys.sort(); + for key in keys { + key.hash(state); + metadata.get(key).unwrap().hash(state); + } + } } impl Literal { /// Create a literal value expression pub fn new(value: ScalarValue) -> Self { - Self { value } + Self::new_with_metadata(value, None) + } + + /// Create a literal value expression + pub fn new_with_metadata( + value: ScalarValue, + metadata: Option, + ) -> Self { + let mut field = Field::new("lit".to_string(), value.data_type(), value.is_null()); + + if let Some(metadata) = metadata { + field = metadata.add_to_field(field); + } + + Self { + value, + field: field.into(), + } } /// Get the scalar value @@ -71,6 +104,10 @@ impl PhysicalExpr for Literal { Ok(self.value.is_null()) } + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.field)) + } + fn evaluate(&self, _batch: &RecordBatch) -> Result { Ok(ColumnarValue::Scalar(self.value.clone())) } @@ -102,7 +139,7 @@ impl PhysicalExpr for Literal { /// Create a literal expression pub fn lit(value: T) -> Arc { match value.lit() { - Expr::Literal(v) => Arc::new(Literal::new(v)), + Expr::Literal(v, _) => Arc::new(Literal::new(v)), _ => unreachable!(), } } @@ -112,7 +149,7 @@ mod tests { use super::*; use arrow::array::Int32Array; - use arrow::datatypes::*; + use arrow::datatypes::Field; use datafusion_common::cast::as_int32_array; use datafusion_physical_expr_common::physical_expr::fmt_sql; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index f00b49f503141..59d675753d985 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -21,7 +21,9 @@ mod binary; mod case; mod cast; +mod cast_column; mod column; +mod dynamic_filters; mod in_list; mod is_not_null; mod is_null; @@ -40,8 +42,10 @@ pub use crate::PhysicalSortExpr; pub use binary::{binary, similar_to, BinaryExpr}; pub use case::{case, CaseExpr}; pub use cast::{cast, CastExpr}; +pub use cast_column::CastColumnExpr; pub use column::{col, with_new_schema, Column}; pub use datafusion_expr::utils::format_state_name; +pub use dynamic_filters::DynamicFilterPhysicalExpr; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 33a1bae14d420..fa7224768a777 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -23,6 +23,7 @@ use std::sync::Arc; use crate::PhysicalExpr; +use arrow::datatypes::FieldRef; use arrow::{ compute::kernels::numeric::neg_wrapping, datatypes::{DataType, Schema}, @@ -103,6 +104,10 @@ impl PhysicalExpr for NegativeExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index 24d2f4d9e074d..94610996c6b00 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -21,12 +21,11 @@ use std::any::Any; use std::hash::Hash; use std::sync::Arc; +use crate::PhysicalExpr; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; - -use crate::PhysicalExpr; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 8a3348b43d20b..8184ef601e543 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::PhysicalExpr; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{cast::as_boolean_array, internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; @@ -101,6 +101,10 @@ impl PhysicalExpr for NotExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index e49815cd8b644..b32aabbe5b006 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -22,12 +22,12 @@ use std::sync::Arc; use crate::PhysicalExpr; use arrow::compute; -use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::{DataType, Schema}; +use arrow::compute::CastOptions; +use arrow::datatypes::{DataType, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use compute::can_cast_types; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; -use datafusion_common::{not_impl_err, Result, ScalarValue}; +use datafusion_common::{not_impl_err, Result}; use datafusion_expr::ColumnarValue; /// TRY_CAST expression casts an expression to a specific data type and returns NULL on invalid cast @@ -96,18 +96,14 @@ impl PhysicalExpr for TryCastExpr { safe: true, format_options: DEFAULT_FORMAT_OPTIONS, }; - match value { - ColumnarValue::Array(array) => { - let cast = cast_with_options(&array, &self.cast_type, &options)?; - Ok(ColumnarValue::Array(cast)) - } - ColumnarValue::Scalar(scalar) => { - let array = scalar.to_array()?; - let cast_array = cast_with_options(&array, &self.cast_type, &options)?; - let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; - Ok(ColumnarValue::Scalar(cast_scalar)) - } - } + value.cast_to(&self.cast_type, Some(&options)) + } + + fn return_field(&self, input_schema: &Schema) -> Result { + self.expr + .return_field(input_schema) + .map(|f| f.as_ref().clone().with_data_type(self.cast_type.clone())) + .map(Arc::new) } fn children(&self) -> Vec<&Arc> { @@ -146,7 +142,7 @@ pub fn try_cast( } else if can_cast_types(&expr_type, &cast_type) { Ok(Arc::new(TryCastExpr::new(expr, cast_type))) } else { - not_impl_err!("Unsupported TRY_CAST from {expr_type:?} to {cast_type:?}") + not_impl_err!("Unsupported TRY_CAST from {expr_type} to {cast_type}") } } diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index a53814c3ad2b8..c44197bbbe6fc 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -42,7 +42,7 @@ //! //! In order to use interval arithmetic to compute bounds for this expression, //! one would first determine intervals that represent the possible values of -//! `x` and `y`` Let's say that the interval for `x` is `[1, 2]` and the interval +//! `x` and `y` Let's say that the interval for `x` is `[1, 2]` and the interval //! for `y` is `[-3, 1]`. In the chart below, you can see how the computation //! takes place. //! @@ -148,12 +148,12 @@ use std::sync::Arc; use super::utils::{ convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op, }; -use crate::expressions::Literal; +use crate::expressions::{BinaryExpr, Literal}; use crate::utils::{build_dag, ExprTreeNode}; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_expr::interval_arithmetic::{apply_operator, satisfy_greater, Interval}; use datafusion_expr::Operator; @@ -645,6 +645,17 @@ impl ExprIntervalGraph { .map(|child| self.graph[*child].interval()) .collect::>(); let node_interval = self.graph[node].interval(); + // Special case: true OR could in principle be propagated by 3 interval sets, + // (i.e. left true, or right true, or both true) however we do not support this yet. + if node_interval == &Interval::CERTAINLY_TRUE + && self.graph[node] + .expr + .as_any() + .downcast_ref::() + .is_some_and(|expr| expr.op() == &Operator::Or) + { + return not_impl_err!("OR operator cannot yet propagate true intervals"); + } let propagated_intervals = self.graph[node] .expr .propagate_constraints(node_interval, &children_intervals)?; @@ -857,8 +868,8 @@ mod tests { let mut r = StdRng::seed_from_u64(seed); let (left_given, right_given, left_expected, right_expected) = if ASC { - let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); - let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); + let left = r.random_range((0 as $TYPE)..(1000 as $TYPE)); + let right = r.random_range((0 as $TYPE)..(1000 as $TYPE)); ( (Some(left), None), (Some(right), None), @@ -866,8 +877,8 @@ mod tests { (Some(<$TYPE>::max(right, left + expr_right)), None), ) } else { - let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); - let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); + let left = r.random_range((0 as $TYPE)..(1000 as $TYPE)); + let right = r.random_range((0 as $TYPE)..(1000 as $TYPE)); ( (None, Some(left)), (None, Some(right)), diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs index 910631ef4a43f..22752a00e9259 100644 --- a/datafusion/physical-expr/src/intervals/utils.rs +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -45,13 +45,13 @@ pub fn check_support(expr: &Arc, schema: &SchemaRef) -> bool { if let Ok(field) = schema.field_with_name(column.name()) { is_datatype_supported(field.data_type()) } else { - return false; + false } } else if let Some(literal) = expr_any.downcast_ref::() { if let Ok(dt) = literal.data_type(schema) { is_datatype_supported(&dt) } else { - return false; + false } } else if let Some(cast) = expr_any.downcast_ref::() { check_support(cast.expr(), schema) diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 93ced2eb628d8..468591d34d71f 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -30,6 +30,7 @@ pub mod analysis; pub mod binary_map { pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; } +pub mod async_scalar_function; pub mod equivalence; pub mod expressions; pub mod intervals; @@ -37,6 +38,7 @@ mod partitioning; mod physical_expr; pub mod planner; mod scalar_function; +pub mod simplifier; pub mod statistics; pub mod utils; pub mod window; @@ -54,21 +56,21 @@ pub use equivalence::{ }; pub use partitioning::{Distribution, Partitioning}; pub use physical_expr::{ + add_offset_to_expr, add_offset_to_physical_sort_exprs, create_lex_ordering, create_ordering, create_physical_sort_expr, create_physical_sort_exprs, physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, - PhysicalExprRef, }; -pub use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +pub use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, PhysicalExprRef}; pub use datafusion_physical_expr_common::sort_expr::{ - LexOrdering, LexRequirement, PhysicalSortExpr, PhysicalSortRequirement, + LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr, + PhysicalSortRequirement, }; pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; - -pub use datafusion_physical_expr_common::utils::reverse_order_bys; -pub use utils::split_conjunction; +pub use simplifier::PhysicalExprSimplifier; +pub use utils::{conjunction, conjunction_opt, split_conjunction}; // For backwards compatibility pub mod tree_node { diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index eb7e1ea6282bb..d6b2b1b046f75 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -199,18 +199,17 @@ impl Partitioning { /// Calculate the output partitioning after applying the given projection. pub fn project( &self, - projection_mapping: &ProjectionMapping, + mapping: &ProjectionMapping, input_eq_properties: &EquivalenceProperties, ) -> Self { if let Partitioning::Hash(exprs, part) = self { - let normalized_exprs = exprs - .iter() - .map(|expr| { - input_eq_properties - .project_expr(expr, projection_mapping) - .unwrap_or_else(|| { - Arc::new(UnKnownColumn::new(&expr.to_string())) - }) + let normalized_exprs = input_eq_properties + .project_expressions(exprs, mapping) + .zip(exprs) + .map(|(proj_expr, expr)| { + proj_expr.unwrap_or_else(|| { + Arc::new(UnKnownColumn::new(&expr.to_string())) + }) }) .collect(); Partitioning::Hash(normalized_exprs, *part) diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 63c4ccbb4b385..2cc484ec6a62e 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -17,12 +17,38 @@ use std::sync::Arc; -use crate::create_physical_expr; +use crate::expressions::{self, Column}; +use crate::{create_physical_expr, LexOrdering, PhysicalSortExpr}; + +use arrow::compute::SortOptions; +use arrow::datatypes::{Schema, SchemaRef}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{plan_err, Result}; use datafusion_common::{DFSchema, HashMap}; use datafusion_expr::execution_props::ExecutionProps; -pub(crate) use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -pub use datafusion_physical_expr_common::physical_expr::PhysicalExprRef; +use datafusion_expr::{Expr, SortExpr}; + use itertools::izip; +// Exports: +pub(crate) use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +/// Adds the `offset` value to `Column` indices inside `expr`. This function is +/// generally used during the update of the right table schema in join operations. +pub fn add_offset_to_expr( + expr: Arc, + offset: isize, +) -> Result> { + expr.transform_down(|e| match e.as_any().downcast_ref::() { + Some(col) => { + let Some(idx) = col.index().checked_add_signed(offset) else { + return plan_err!("Column index overflow"); + }; + Ok(Transformed::yes(Arc::new(Column::new(col.name(), idx)))) + } + None => Ok(Transformed::no(e)), + }) + .data() +} /// This function is similar to the `contains` method of `Vec`. It finds /// whether `expr` is among `physical_exprs`. @@ -60,26 +86,21 @@ pub fn physical_exprs_bag_equal( multi_set_lhs == multi_set_rhs } -use crate::{expressions, LexOrdering, PhysicalSortExpr}; -use arrow::compute::SortOptions; -use arrow::datatypes::Schema; -use datafusion_common::plan_err; -use datafusion_common::Result; -use datafusion_expr::{Expr, SortExpr}; - -/// Converts logical sort expressions to physical sort expressions +/// Converts logical sort expressions to physical sort expressions. /// -/// This function transforms a collection of logical sort expressions into their physical -/// representation that can be used during query execution. +/// This function transforms a collection of logical sort expressions into their +/// physical representation that can be used during query execution. /// /// # Arguments /// -/// * `schema` - The schema containing column definitions -/// * `sort_order` - A collection of logical sort expressions grouped into lexicographic orderings +/// * `schema` - The schema containing column definitions. +/// * `sort_order` - A collection of logical sort expressions grouped into +/// lexicographic orderings. /// /// # Returns /// -/// A vector of lexicographic orderings for physical execution, or an error if the transformation fails +/// A vector of lexicographic orderings for physical execution, or an error if +/// the transformation fails. /// /// # Examples /// @@ -114,18 +135,13 @@ pub fn create_ordering( for (group_idx, exprs) in sort_order.iter().enumerate() { // Construct PhysicalSortExpr objects from Expr objects: - let mut sort_exprs = LexOrdering::default(); + let mut sort_exprs = vec![]; for (expr_idx, sort) in exprs.iter().enumerate() { match &sort.expr { Expr::Column(col) => match expressions::col(&col.name, schema) { Ok(expr) => { - sort_exprs.push(PhysicalSortExpr { - expr, - options: SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - }); + let opts = SortOptions::new(!sort.asc, sort.nulls_first); + sort_exprs.push(PhysicalSortExpr::new(expr, opts)); } // Cannot find expression in the projected_schema, stop iterating // since rest of the orderings are violated @@ -141,9 +157,33 @@ pub fn create_ordering( } } } - if !sort_exprs.is_empty() { - all_sort_orders.push(sort_exprs); - } + all_sort_orders.extend(LexOrdering::new(sort_exprs)); + } + Ok(all_sort_orders) +} + +/// Creates a vector of [LexOrdering] from a vector of logical expression +pub fn create_lex_ordering( + schema: &SchemaRef, + sort_order: &[Vec], + execution_props: &ExecutionProps, +) -> Result> { + // Try the fast path that only supports column references first + // This avoids creating a DFSchema + if let Ok(ordering) = create_ordering(schema, sort_order) { + return Ok(ordering); + } + + let df_schema = DFSchema::try_from(Arc::clone(schema))?; + + let mut all_sort_orders = vec![]; + + for exprs in sort_order.iter() { + all_sort_orders.extend(LexOrdering::new(create_physical_sort_exprs( + exprs, + &df_schema, + execution_props, + )?)); } Ok(all_sort_orders) } @@ -154,17 +194,9 @@ pub fn create_physical_sort_expr( input_dfschema: &DFSchema, execution_props: &ExecutionProps, ) -> Result { - let SortExpr { - expr, - asc, - nulls_first, - } = e; - Ok(PhysicalSortExpr { - expr: create_physical_expr(expr, input_dfschema, execution_props)?, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, + create_physical_expr(&e.expr, input_dfschema, execution_props).map(|expr| { + let options = SortOptions::new(!e.asc, e.nulls_first); + PhysicalSortExpr::new(expr, options) }) } @@ -173,23 +205,43 @@ pub fn create_physical_sort_exprs( exprs: &[SortExpr], input_dfschema: &DFSchema, execution_props: &ExecutionProps, -) -> Result { +) -> Result> { exprs .iter() - .map(|expr| create_physical_sort_expr(expr, input_dfschema, execution_props)) - .collect::>() + .map(|e| create_physical_sort_expr(e, input_dfschema, execution_props)) + .collect() +} + +pub fn add_offset_to_physical_sort_exprs( + sort_exprs: impl IntoIterator, + offset: isize, +) -> Result> { + sort_exprs + .into_iter() + .map(|mut sort_expr| { + sort_expr.expr = add_offset_to_expr(sort_expr.expr, offset)?; + Ok(sort_expr) + }) + .collect() } #[cfg(test)] mod tests { use super::*; - use crate::expressions::{Column, Literal}; + use crate::expressions::{BinaryExpr, Column, Literal}; use crate::physical_expr::{ physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, }; + use datafusion_physical_expr_common::physical_expr::is_volatile; - use datafusion_common::ScalarValue; + use arrow::datatypes::{DataType, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::ColumnarValue; + use datafusion_expr::Operator; + use std::any::Any; + use std::fmt; #[test] fn test_physical_exprs_contains() { @@ -302,4 +354,120 @@ mod tests { assert!(physical_exprs_bag_equal(list3.as_slice(), list3.as_slice())); assert!(physical_exprs_bag_equal(list4.as_slice(), list4.as_slice())); } + + #[test] + fn test_is_volatile_default_behavior() { + // Test that default PhysicalExpr implementations are not volatile + let literal = + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))) as Arc; + let column = Arc::new(Column::new("test", 0)) as Arc; + + // Test is_volatile_node() - should return false by default + assert!(!literal.is_volatile_node()); + assert!(!column.is_volatile_node()); + + // Test is_volatile() - should return false for non-volatile expressions + assert!(!is_volatile(&literal)); + assert!(!is_volatile(&column)); + } + + /// Mock volatile PhysicalExpr for testing purposes + #[derive(Debug, Clone, PartialEq, Eq, Hash)] + struct MockVolatileExpr { + volatile: bool, + } + + impl MockVolatileExpr { + fn new(volatile: bool) -> Self { + Self { volatile } + } + } + + impl fmt::Display for MockVolatileExpr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "MockVolatile({})", self.volatile) + } + } + + impl PhysicalExpr for MockVolatileExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::Boolean) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(false) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( + self.volatile, + )))) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn is_volatile_node(&self) -> bool { + self.volatile + } + + fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "mock_volatile({})", self.volatile) + } + } + + #[test] + fn test_nested_expression_volatility() { + // Test that is_volatile() recursively detects volatility in expression trees + + // Create a volatile mock expression + let volatile_expr = + Arc::new(MockVolatileExpr::new(true)) as Arc; + assert!(volatile_expr.is_volatile_node()); + assert!(is_volatile(&volatile_expr)); + + // Create a non-volatile mock expression + let stable_expr = Arc::new(MockVolatileExpr::new(false)) as Arc; + assert!(!stable_expr.is_volatile_node()); + assert!(!is_volatile(&stable_expr)); + + // Create a literal (non-volatile) + let literal = + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))) as Arc; + assert!(!literal.is_volatile_node()); + assert!(!is_volatile(&literal)); + + // Test composite expression: volatile_expr AND literal + // The BinaryExpr itself is not volatile, but contains a volatile child + let composite_expr = Arc::new(BinaryExpr::new( + Arc::clone(&volatile_expr), + Operator::And, + Arc::clone(&literal), + )) as Arc; + + assert!(!composite_expr.is_volatile_node()); // BinaryExpr itself is not volatile + assert!(is_volatile(&composite_expr)); // But it contains a volatile child + + // Test composite expression with all non-volatile children + let stable_composite = Arc::new(BinaryExpr::new( + Arc::clone(&stable_expr), + Operator::And, + Arc::clone(&literal), + )) as Arc; + + assert!(!stable_composite.is_volatile_node()); + assert!(!is_volatile(&stable_composite)); // No volatile children + } } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index fac83dfc45247..73df60c42e963 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -24,11 +24,14 @@ use crate::{ }; use arrow::datatypes::Schema; +use datafusion_common::config::ConfigOptions; use datafusion_common::{ exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema, }; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; +use datafusion_expr::expr::{ + Alias, Cast, FieldMetadata, InList, Placeholder, ScalarFunction, +}; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ @@ -102,23 +105,37 @@ use datafusion_expr::{ /// /// * `e` - The logical expression /// * `input_dfschema` - The DataFusion schema for the input, used to resolve `Column` references -/// to qualified or unqualified fields by name. +/// to qualified or unqualified fields by name. pub fn create_physical_expr( e: &Expr, input_dfschema: &DFSchema, execution_props: &ExecutionProps, ) -> Result> { - let input_schema: &Schema = &input_dfschema.into(); + let input_schema = input_dfschema.as_arrow(); match e { - Expr::Alias(Alias { expr, .. }) => { - Ok(create_physical_expr(expr, input_dfschema, execution_props)?) + Expr::Alias(Alias { expr, metadata, .. }) => { + if let Expr::Literal(v, prior_metadata) = expr.as_ref() { + let new_metadata = FieldMetadata::merge_options( + prior_metadata.as_ref(), + metadata.as_ref(), + ); + Ok(Arc::new(Literal::new_with_metadata( + v.clone(), + new_metadata, + ))) + } else { + Ok(create_physical_expr(expr, input_dfschema, execution_props)?) + } } Expr::Column(c) => { let idx = input_dfschema.index_of_column(c)?; Ok(Arc::new(Column::new(&c.name, idx))) } - Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), + Expr::Literal(value, metadata) => Ok(Arc::new(Literal::new_with_metadata( + value.clone(), + metadata.clone(), + ))), Expr::ScalarVariable(_, variable_names) => { if is_system_variables(variable_names) { match execution_props.get_var_provider(VarType::System) { @@ -168,7 +185,7 @@ pub fn create_physical_expr( let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsNotDistinctFrom, - Expr::Literal(ScalarValue::Boolean(None)), + Expr::Literal(ScalarValue::Boolean(None), None), ); create_physical_expr(&binary_op, input_dfschema, execution_props) } @@ -176,7 +193,7 @@ pub fn create_physical_expr( let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsDistinctFrom, - Expr::Literal(ScalarValue::Boolean(None)), + Expr::Literal(ScalarValue::Boolean(None), None), ); create_physical_expr(&binary_op, input_dfschema, execution_props) } @@ -301,11 +318,16 @@ pub fn create_physical_expr( Expr::ScalarFunction(ScalarFunction { func, args }) => { let physical_args = create_physical_exprs(args, input_dfschema, execution_props)?; + let config_options = match execution_props.config_options.as_ref() { + Some(config_options) => Arc::clone(config_options), + None => Arc::new(ConfigOptions::default()), + }; Ok(Arc::new(ScalarFunctionExpr::try_new( Arc::clone(func), physical_args, input_schema, + config_options, )?)) } Expr::Between(Between { @@ -347,7 +369,7 @@ pub fn create_physical_expr( list, negated, }) => match expr.as_ref() { - Expr::Literal(ScalarValue::Utf8(None)) => { + Expr::Literal(ScalarValue::Utf8(None), _) => { Ok(expressions::lit(ScalarValue::Boolean(None))) } _ => { @@ -380,11 +402,12 @@ where exprs .into_iter() .map(|expr| create_physical_expr(expr, input_dfschema, execution_props)) - .collect::>>() + .collect() } /// Convert a logical expression to a physical expression (without any simplification, etc) pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { + // TODO this makes a deep copy of the Schema. Should take SchemaRef instead and avoid deep copy let df_schema = schema.clone().to_dfschema().unwrap(); let execution_props = ExecutionProps::new(); create_physical_expr(expr, &df_schema, &execution_props).unwrap() diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 44bbcc4928c68..743d5b99cde95 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -31,30 +31,31 @@ use std::any::Any; use std::fmt::{self, Debug, Formatter}; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::expressions::Literal; use crate::PhysicalExpr; use arrow::array::{Array, RecordBatch}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, FieldRef, Schema}; +use datafusion_common::config::{ConfigEntry, ConfigOptions}; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{ - expr_vec_fmt, ColumnarValue, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, + expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, + Volatility, }; /// Physical expression of a scalar function -#[derive(Eq, PartialEq, Hash)] pub struct ScalarFunctionExpr { fun: Arc, name: String, args: Vec>, - return_type: DataType, - nullable: bool, + return_field: FieldRef, + config_options: Arc, } impl Debug for ScalarFunctionExpr { @@ -63,7 +64,7 @@ impl Debug for ScalarFunctionExpr { .field("fun", &"") .field("name", &self.name) .field("args", &self.args) - .field("return_type", &self.return_type) + .field("return_field", &self.return_field) .finish() } } @@ -74,14 +75,15 @@ impl ScalarFunctionExpr { name: &str, fun: Arc, args: Vec>, - return_type: DataType, + return_field: FieldRef, + config_options: Arc, ) -> Self { Self { fun, name: name.to_owned(), args, - return_type, - nullable: true, + return_field, + config_options, } } @@ -90,20 +92,20 @@ impl ScalarFunctionExpr { fun: Arc, args: Vec>, schema: &Schema, + config_options: Arc, ) -> Result { let name = fun.name().to_string(); - let arg_types = args + let arg_fields = args .iter() - .map(|e| e.data_type(schema)) + .map(|e| e.return_field(schema)) .collect::>>()?; // verify that input data types is consistent with function's `TypeSignature` - data_types_with_scalar_udf(&arg_types, &fun)?; - - let nullables = args + let arg_types = arg_fields .iter() - .map(|e| e.nullable(schema)) - .collect::>>()?; + .map(|f| f.data_type().clone()) + .collect::>(); + data_types_with_scalar_udf(&arg_types, &fun)?; let arguments = args .iter() @@ -113,18 +115,17 @@ impl ScalarFunctionExpr { .map(|literal| literal.value()) }) .collect::>(); - let ret_args = ReturnTypeArgs { - arg_types: &arg_types, + let ret_args = ReturnFieldArgs { + arg_fields: &arg_fields, scalar_arguments: &arguments, - nullables: &nullables, }; - let (return_type, nullable) = fun.return_type_from_args(ret_args)?.into_parts(); + let return_field = fun.return_field_from_args(ret_args)?; Ok(Self { fun, name, args, - return_type, - nullable, + return_field, + config_options, }) } @@ -145,16 +146,48 @@ impl ScalarFunctionExpr { /// Data type produced by this expression pub fn return_type(&self) -> &DataType { - &self.return_type + self.return_field.data_type() } pub fn with_nullable(mut self, nullable: bool) -> Self { - self.nullable = nullable; + self.return_field = self + .return_field + .as_ref() + .clone() + .with_nullable(nullable) + .into(); self } pub fn nullable(&self) -> bool { - self.nullable + self.return_field.is_nullable() + } + + pub fn config_options(&self) -> &ConfigOptions { + &self.config_options + } + + /// Given an arbitrary PhysicalExpr attempt to downcast it to a ScalarFunctionExpr + /// and verify that its inner function is of type T. + /// If the downcast fails, or the function is not of type T, returns `None`. + /// Otherwise returns `Some(ScalarFunctionExpr)`. + pub fn try_downcast_func(expr: &dyn PhysicalExpr) -> Option<&ScalarFunctionExpr> + where + T: 'static, + { + match expr.as_any().downcast_ref::() { + Some(scalar_expr) + if scalar_expr + .fun() + .inner() + .as_any() + .downcast_ref::() + .is_some() => + { + Some(scalar_expr) + } + _ => None, + } } } @@ -164,6 +197,51 @@ impl fmt::Display for ScalarFunctionExpr { } } +impl PartialEq for ScalarFunctionExpr { + fn eq(&self, o: &Self) -> bool { + if std::ptr::eq(self, o) { + // The equality implementation is somewhat expensive, so let's short-circuit when possible. + return true; + } + let Self { + fun, + name, + args, + return_field, + config_options, + } = self; + fun.eq(&o.fun) + && name.eq(&o.name) + && args.eq(&o.args) + && return_field.eq(&o.return_field) + && (Arc::ptr_eq(config_options, &o.config_options) + || sorted_config_entries(config_options) + == sorted_config_entries(&o.config_options)) + } +} +impl Eq for ScalarFunctionExpr {} +impl Hash for ScalarFunctionExpr { + fn hash(&self, state: &mut H) { + let Self { + fun, + name, + args, + return_field, + config_options: _, // expensive to hash, and often equal + } = self; + fun.hash(state); + name.hash(state); + args.hash(state); + return_field.hash(state); + } +} + +fn sorted_config_entries(config_options: &ConfigOptions) -> Vec { + let mut entries = config_options.entries(); + entries.sort_by(|l, r| l.key.cmp(&r.key)); + entries +} + impl PhysicalExpr for ScalarFunctionExpr { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -171,11 +249,11 @@ impl PhysicalExpr for ScalarFunctionExpr { } fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(self.return_type.clone()) + Ok(self.return_field.data_type().clone()) } fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(self.nullable) + Ok(self.return_field.is_nullable()) } fn evaluate(&self, batch: &RecordBatch) -> Result { @@ -185,6 +263,12 @@ impl PhysicalExpr for ScalarFunctionExpr { .map(|e| e.evaluate(batch)) .collect::>>()?; + let arg_fields = self + .args + .iter() + .map(|e| e.return_field(batch.schema_ref())) + .collect::>>()?; + let input_empty = args.is_empty(); let input_all_scalar = args .iter() @@ -193,8 +277,10 @@ impl PhysicalExpr for ScalarFunctionExpr { // evaluate the function let output = self.fun.invoke_with_args(ScalarFunctionArgs { args, + arg_fields, number_rows: batch.num_rows(), - return_type: &self.return_type, + return_field: Arc::clone(&self.return_field), + config_options: Arc::clone(&self.config_options), })?; if let ColumnarValue::Array(array) = &output { @@ -214,6 +300,10 @@ impl PhysicalExpr for ScalarFunctionExpr { Ok(output) } + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.return_field)) + } + fn children(&self) -> Vec<&Arc> { self.args.iter().collect() } @@ -222,15 +312,13 @@ impl PhysicalExpr for ScalarFunctionExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new( - ScalarFunctionExpr::new( - &self.name, - Arc::clone(&self.fun), - children, - self.return_type().clone(), - ) - .with_nullable(self.nullable), - )) + Ok(Arc::new(ScalarFunctionExpr::new( + &self.name, + Arc::clone(&self.fun), + children, + Arc::clone(&self.return_field), + Arc::clone(&self.config_options), + ))) } fn evaluate_bounds(&self, children: &[&Interval]) -> Result { @@ -271,4 +359,89 @@ impl PhysicalExpr for ScalarFunctionExpr { } write!(f, ")") } + + fn is_volatile_node(&self) -> bool { + self.fun.signature().volatility == Volatility::Volatile + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::Column; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature}; + use datafusion_physical_expr_common::physical_expr::is_volatile; + use std::any::Any; + + /// Test helper to create a mock UDF with a specific volatility + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockScalarUDF { + signature: Signature, + } + + impl ScalarUDFImpl for MockScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "mock_function" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42)))) + } + } + + #[test] + fn test_scalar_function_volatile_node() { + // Create a volatile UDF + let volatile_udf = Arc::new(ScalarUDF::from(MockScalarUDF { + signature: Signature::uniform( + 1, + vec![DataType::Float32], + Volatility::Volatile, + ), + })); + + // Create a non-volatile UDF + let stable_udf = Arc::new(ScalarUDF::from(MockScalarUDF { + signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + })); + + let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); + let args = vec![Arc::new(Column::new("a", 0)) as Arc]; + let config_options = Arc::new(ConfigOptions::new()); + + // Test volatile function + let volatile_expr = ScalarFunctionExpr::try_new( + volatile_udf, + args.clone(), + &schema, + Arc::clone(&config_options), + ) + .unwrap(); + + assert!(volatile_expr.is_volatile_node()); + let volatile_arc: Arc = Arc::new(volatile_expr); + assert!(is_volatile(&volatile_arc)); + + // Test non-volatile function + let stable_expr = + ScalarFunctionExpr::try_new(stable_udf, args, &schema, config_options) + .unwrap(); + + assert!(!stable_expr.is_volatile_node()); + let stable_arc: Arc = Arc::new(stable_expr); + assert!(!is_volatile(&stable_arc)); + } } diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs new file mode 100644 index 0000000000000..80d6ee0a7b914 --- /dev/null +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Simplifier for Physical Expressions + +use arrow::datatypes::Schema; +use datafusion_common::{ + tree_node::{Transformed, TreeNode, TreeNodeRewriter}, + Result, +}; +use std::sync::Arc; + +use crate::PhysicalExpr; + +pub mod unwrap_cast; + +/// Simplifies physical expressions by applying various optimizations +/// +/// This can be useful after adapting expressions from a table schema +/// to a file schema. For example, casts added to match the types may +/// potentially be unwrapped. +pub struct PhysicalExprSimplifier<'a> { + schema: &'a Schema, +} + +impl<'a> PhysicalExprSimplifier<'a> { + /// Create a new physical expression simplifier + pub fn new(schema: &'a Schema) -> Self { + Self { schema } + } + + /// Simplify a physical expression + pub fn simplify( + &mut self, + expr: Arc, + ) -> Result> { + Ok(expr.rewrite(self)?.data) + } +} + +impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> { + type Node = Arc; + + fn f_up(&mut self, node: Self::Node) -> Result> { + // Apply unwrap cast optimization + #[cfg(test)] + let original_type = node.data_type(self.schema).unwrap(); + let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, self.schema)?; + #[cfg(test)] + assert_eq!( + unwrapped.data.data_type(self.schema).unwrap(), + original_type, + "Simplified expression should have the same data type as the original" + ); + Ok(unwrapped) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::{col, lit, BinaryExpr, CastExpr, Literal, TryCastExpr}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::ScalarValue; + use datafusion_expr::Operator; + + fn test_schema() -> Schema { + Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int64, false), + Field::new("c3", DataType::Utf8, false), + ]) + } + + #[test] + fn test_simplify() { + let schema = test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // Create: cast(c2 as INT32) != INT32(99) + let column_expr = col("c2", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int32, None)); + let literal_expr = lit(ScalarValue::Int32(Some(99))); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::NotEq, literal_expr)); + + // Apply full simplification (uses TreeNodeRewriter) + let optimized = simplifier.simplify(binary_expr).unwrap(); + + let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Should be optimized to: c2 != INT64(99) (c2 is INT64, literal cast to match) + let left_expr = optimized_binary.left(); + assert!( + left_expr.as_any().downcast_ref::().is_none() + && left_expr.as_any().downcast_ref::().is_none() + ); + let right_literal = optimized_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(99))); + } + + #[test] + fn test_nested_expression_simplification() { + let schema = test_schema(); + let mut simplifier = PhysicalExprSimplifier::new(&schema); + + // Create nested expression: (cast(c1 as INT64) > INT64(5)) OR (cast(c2 as INT32) <= INT32(10)) + let c1_expr = col("c1", &schema).unwrap(); + let c1_cast = Arc::new(CastExpr::new(c1_expr, DataType::Int64, None)); + let c1_literal = lit(ScalarValue::Int64(Some(5))); + let c1_binary = Arc::new(BinaryExpr::new(c1_cast, Operator::Gt, c1_literal)); + + let c2_expr = col("c2", &schema).unwrap(); + let c2_cast = Arc::new(CastExpr::new(c2_expr, DataType::Int32, None)); + let c2_literal = lit(ScalarValue::Int32(Some(10))); + let c2_binary = Arc::new(BinaryExpr::new(c2_cast, Operator::LtEq, c2_literal)); + + let or_expr = Arc::new(BinaryExpr::new(c1_binary, Operator::Or, c2_binary)); + + // Apply simplification + let optimized = simplifier.simplify(or_expr).unwrap(); + + let or_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Verify left side: c1 > INT32(5) + let left_binary = or_binary + .left() + .as_any() + .downcast_ref::() + .unwrap(); + let left_left_expr = left_binary.left(); + assert!( + left_left_expr.as_any().downcast_ref::().is_none() + && left_left_expr + .as_any() + .downcast_ref::() + .is_none() + ); + let left_literal = left_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(5))); + + // Verify right side: c2 <= INT64(10) + let right_binary = or_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + let right_left_expr = right_binary.left(); + assert!( + right_left_expr + .as_any() + .downcast_ref::() + .is_none() + && right_left_expr + .as_any() + .downcast_ref::() + .is_none() + ); + let right_literal = right_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(10))); + } +} diff --git a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs new file mode 100644 index 0000000000000..d409ce9cb5bf2 --- /dev/null +++ b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs @@ -0,0 +1,646 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Unwrap casts in binary comparisons for physical expressions +//! +//! This module provides optimization for physical expressions similar to the logical +//! optimizer's unwrap_cast module. It attempts to remove casts from comparisons to +//! literals by applying the casts to the literals if possible. +//! +//! The optimization improves performance by: +//! 1. Reducing runtime cast operations on column data +//! 2. Enabling better predicate pushdown opportunities +//! 3. Optimizing filter expressions in physical plans +//! +//! # Example +//! +//! Physical expression: `cast(column as INT64) > INT64(10)` +//! Optimized to: `column > INT32(10)` (assuming column is INT32) + +use std::sync::Arc; + +use arrow::datatypes::{DataType, Schema}; +use datafusion_common::{ + tree_node::{Transformed, TreeNode}, + Result, ScalarValue, +}; +use datafusion_expr::Operator; +use datafusion_expr_common::casts::try_cast_literal_to_type; + +use crate::expressions::{lit, BinaryExpr, CastExpr, Literal, TryCastExpr}; +use crate::PhysicalExpr; + +/// Attempts to unwrap casts in comparison expressions. +pub(crate) fn unwrap_cast_in_comparison( + expr: Arc, + schema: &Schema, +) -> Result>> { + expr.transform_down(|e| { + if let Some(binary) = e.as_any().downcast_ref::() { + if let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? { + return Ok(Transformed::yes(unwrapped)); + } + } + Ok(Transformed::no(e)) + }) +} + +/// Try to unwrap casts in binary expressions +fn try_unwrap_cast_binary( + binary: &BinaryExpr, + schema: &Schema, +) -> Result>> { + // Case 1: cast(left_expr) op literal + if let (Some((inner_expr, _cast_type)), Some(literal)) = ( + extract_cast_info(binary.left()), + binary.right().as_any().downcast_ref::(), + ) { + if binary.op().supports_propagation() { + if let Some(unwrapped) = try_unwrap_cast_comparison( + Arc::clone(inner_expr), + literal.value(), + *binary.op(), + schema, + )? { + return Ok(Some(unwrapped)); + } + } + } + + // Case 2: literal op cast(right_expr) + if let (Some(literal), Some((inner_expr, _cast_type))) = ( + binary.left().as_any().downcast_ref::(), + extract_cast_info(binary.right()), + ) { + // For literal op cast(expr), we need to swap the operator + if let Some(swapped_op) = binary.op().swap() { + if binary.op().supports_propagation() { + if let Some(unwrapped) = try_unwrap_cast_comparison( + Arc::clone(inner_expr), + literal.value(), + swapped_op, + schema, + )? { + return Ok(Some(unwrapped)); + } + } + } + // If the operator cannot be swapped, we skip this optimization case + // but don't prevent other optimizations + } + + Ok(None) +} + +/// Extract cast information from a physical expression +/// +/// If the expression is a CAST(expr, datatype) or TRY_CAST(expr, datatype), +/// returns Some((inner_expr, target_datatype)). Otherwise returns None. +fn extract_cast_info( + expr: &Arc, +) -> Option<(&Arc, &DataType)> { + if let Some(cast) = expr.as_any().downcast_ref::() { + Some((cast.expr(), cast.cast_type())) + } else if let Some(try_cast) = expr.as_any().downcast_ref::() { + Some((try_cast.expr(), try_cast.cast_type())) + } else { + None + } +} + +/// Try to unwrap a cast in comparison by moving the cast to the literal +fn try_unwrap_cast_comparison( + inner_expr: Arc, + literal_value: &ScalarValue, + op: Operator, + schema: &Schema, +) -> Result>> { + // Get the data type of the inner expression + let inner_type = inner_expr.data_type(schema)?; + + // Try to cast the literal to the inner expression's type + if let Some(casted_literal) = try_cast_literal_to_type(literal_value, &inner_type) { + let literal_expr = lit(casted_literal); + let binary_expr = BinaryExpr::new(inner_expr, op, literal_expr); + return Ok(Some(Arc::new(binary_expr))); + } + + Ok(None) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::{col, lit}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::ScalarValue; + use datafusion_expr::Operator; + + /// Check if an expression is a cast expression + fn is_cast_expr(expr: &Arc) -> bool { + expr.as_any().downcast_ref::().is_some() + || expr.as_any().downcast_ref::().is_some() + } + + /// Check if a binary expression is suitable for cast unwrapping + fn is_binary_expr_with_cast_and_literal(binary: &BinaryExpr) -> bool { + // Check if left is cast and right is literal + let left_cast_right_literal = is_cast_expr(binary.left()) + && binary.right().as_any().downcast_ref::().is_some(); + + // Check if left is literal and right is cast + let left_literal_right_cast = + binary.left().as_any().downcast_ref::().is_some() + && is_cast_expr(binary.right()); + + left_cast_right_literal || left_literal_right_cast + } + + fn test_schema() -> Schema { + Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int64, false), + Field::new("c3", DataType::Utf8, false), + ]) + } + + #[test] + fn test_unwrap_cast_in_binary_comparison() { + let schema = test_schema(); + + // Create: cast(c1 as INT64) > INT64(10) + let column_expr = col("c1", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let literal_expr = lit(10i64); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + // The result should be: c1 > INT32(10) + let optimized = result.data; + let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Check that left side is no longer a cast + assert!(!is_cast_expr(optimized_binary.left())); + + // Check that right side is a literal with the correct type and value + let right_literal = optimized_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(10))); + } + + #[test] + fn test_unwrap_cast_with_literal_on_left() { + let schema = test_schema(); + + // Create: INT64(10) < cast(c1 as INT64) + let column_expr = col("c1", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let literal_expr = lit(10i64); + let binary_expr = + Arc::new(BinaryExpr::new(literal_expr, Operator::Lt, cast_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + // The result should be equivalent to: c1 > INT32(10) + let optimized = result.data; + let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Check the operator was swapped + assert_eq!(*optimized_binary.op(), Operator::Gt); + } + + #[test] + fn test_no_unwrap_when_types_unsupported() { + let schema = Schema::new(vec![Field::new("f1", DataType::Float32, false)]); + + // Create: cast(f1 as FLOAT64) > FLOAT64(10.5) + let column_expr = col("f1", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Float64, None)); + let literal_expr = lit(10.5f64); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should NOT be transformed (floating point types not supported) + assert!(!result.transformed); + } + + #[test] + fn test_is_binary_expr_with_cast_and_literal() { + let schema = test_schema(); + + let column_expr = col("c1", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let literal_expr = lit(10i64); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr)); + let binary_ref = binary_expr.as_any().downcast_ref::().unwrap(); + + assert!(is_binary_expr_with_cast_and_literal(binary_ref)); + } + + #[test] + fn test_unwrap_cast_literal_on_left_side() { + // Test case for: literal <= cast(column) + // This was the specific case that caused the bug + let schema = Schema::new(vec![Field::new( + "decimal_col", + DataType::Decimal128(9, 2), + true, + )]); + + // Create: Decimal128(400) <= cast(decimal_col as Decimal128(22, 2)) + let column_expr = col("decimal_col", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new( + column_expr, + DataType::Decimal128(22, 2), + None, + )); + let literal_expr = lit(ScalarValue::Decimal128(Some(400), 22, 2)); + let binary_expr = + Arc::new(BinaryExpr::new(literal_expr, Operator::LtEq, cast_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + // The result should be: decimal_col >= Decimal128(400, 9, 2) + let optimized = result.data; + let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Check operator was swapped correctly + assert_eq!(*optimized_binary.op(), Operator::GtEq); + + // Check that left side is the column without cast + assert!(!is_cast_expr(optimized_binary.left())); + + // Check that right side is a literal with the correct type + let right_literal = optimized_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + right_literal.value().data_type(), + DataType::Decimal128(9, 2) + ); + } + + #[test] + fn test_unwrap_cast_with_different_comparison_operators() { + let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]); + + // Test all comparison operators with literal on the left + let operators = vec![ + (Operator::Lt, Operator::Gt), + (Operator::LtEq, Operator::GtEq), + (Operator::Gt, Operator::Lt), + (Operator::GtEq, Operator::LtEq), + (Operator::Eq, Operator::Eq), + (Operator::NotEq, Operator::NotEq), + ]; + + for (original_op, expected_op) in operators { + // Create: INT64(100) op cast(int_col as INT64) + let column_expr = col("int_col", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let literal_expr = lit(100i64); + let binary_expr = + Arc::new(BinaryExpr::new(literal_expr, original_op, cast_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + let optimized = result.data; + let optimized_binary = + optimized.as_any().downcast_ref::().unwrap(); + + // Check the operator was swapped correctly + assert_eq!( + *optimized_binary.op(), + expected_op, + "Failed for operator {original_op:?} -> {expected_op:?}" + ); + + // Check that left side has no cast + assert!(!is_cast_expr(optimized_binary.left())); + + // Check that the literal was cast to the column type + let right_literal = optimized_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(100))); + } + } + + #[test] + fn test_unwrap_cast_with_decimal_types() { + // Test various decimal precision/scale combinations + let test_cases = vec![ + // (column_precision, column_scale, cast_precision, cast_scale, value) + (9, 2, 22, 2, 400), + (10, 3, 20, 3, 1000), + (5, 1, 10, 1, 99), + ]; + + for (col_p, col_s, cast_p, cast_s, value) in test_cases { + let schema = Schema::new(vec![Field::new( + "decimal_col", + DataType::Decimal128(col_p, col_s), + true, + )]); + + // Test both: cast(column) op literal AND literal op cast(column) + + // Case 1: cast(column) > literal + let column_expr = col("decimal_col", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new( + Arc::clone(&column_expr), + DataType::Decimal128(cast_p, cast_s), + None, + )); + let literal_expr = lit(ScalarValue::Decimal128(Some(value), cast_p, cast_s)); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr)); + + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + assert!(result.transformed); + + // Case 2: literal < cast(column) + let cast_expr = Arc::new(CastExpr::new( + column_expr, + DataType::Decimal128(cast_p, cast_s), + None, + )); + let literal_expr = lit(ScalarValue::Decimal128(Some(value), cast_p, cast_s)); + let binary_expr = + Arc::new(BinaryExpr::new(literal_expr, Operator::Lt, cast_expr)); + + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + assert!(result.transformed); + } + } + + #[test] + fn test_unwrap_cast_with_null_literals() { + // Test with NULL literals to ensure they're handled correctly + let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, true)]); + + // Create: cast(int_col as INT64) = NULL + let column_expr = col("int_col", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let null_literal = lit(ScalarValue::Int64(None)); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, null_literal)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + // Verify the NULL was cast to the column type + let optimized = result.data; + let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + let right_literal = optimized_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int32(None)); + } + + #[test] + fn test_unwrap_cast_with_try_cast() { + // Test that TryCast expressions are also unwrapped correctly + let schema = Schema::new(vec![Field::new("str_col", DataType::Utf8, true)]); + + // Create: try_cast(str_col as INT64) > INT64(100) + let column_expr = col("str_col", &schema).unwrap(); + let try_cast_expr = Arc::new(TryCastExpr::new(column_expr, DataType::Int64)); + let literal_expr = lit(100i64); + let binary_expr = + Arc::new(BinaryExpr::new(try_cast_expr, Operator::Gt, literal_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should NOT be transformed (string to int cast not supported) + assert!(!result.transformed); + } + + #[test] + fn test_unwrap_cast_preserves_non_comparison_operators() { + // Test that non-comparison operators in AND/OR expressions are preserved + let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]); + + // Create: cast(int_col as INT64) > INT64(10) AND cast(int_col as INT64) < INT64(20) + let column_expr = col("int_col", &schema).unwrap(); + + let cast1 = Arc::new(CastExpr::new( + Arc::clone(&column_expr), + DataType::Int64, + None, + )); + let lit1 = lit(10i64); + let compare1 = Arc::new(BinaryExpr::new(cast1, Operator::Gt, lit1)); + + let cast2 = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let lit2 = lit(20i64); + let compare2 = Arc::new(BinaryExpr::new(cast2, Operator::Lt, lit2)); + + let and_expr = Arc::new(BinaryExpr::new(compare1, Operator::And, compare2)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(and_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + // Verify the AND operator is preserved + let optimized = result.data; + let and_binary = optimized.as_any().downcast_ref::().unwrap(); + assert_eq!(*and_binary.op(), Operator::And); + + // Both sides should have their casts unwrapped + let left_binary = and_binary + .left() + .as_any() + .downcast_ref::() + .unwrap(); + let right_binary = and_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + + assert!(!is_cast_expr(left_binary.left())); + assert!(!is_cast_expr(right_binary.left())); + } + + #[test] + fn test_try_cast_unwrapping() { + let schema = test_schema(); + + // Create: try_cast(c1 as INT64) <= INT64(100) + let column_expr = col("c1", &schema).unwrap(); + let try_cast_expr = Arc::new(TryCastExpr::new(column_expr, DataType::Int64)); + let literal_expr = lit(100i64); + let binary_expr = + Arc::new(BinaryExpr::new(try_cast_expr, Operator::LtEq, literal_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should be transformed to: c1 <= INT32(100) + assert!(result.transformed); + + let optimized = result.data; + let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Verify the try_cast was removed + assert!(!is_cast_expr(optimized_binary.left())); + + // Verify the literal was converted + let right_literal = optimized_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(100))); + } + + #[test] + fn test_non_swappable_operator() { + // Test case with an operator that cannot be swapped + let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]); + + // Create: INT64(10) + cast(int_col as INT64) + // The Plus operator cannot be swapped, so this should not be transformed + let column_expr = col("int_col", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let literal_expr = lit(10i64); + let binary_expr = + Arc::new(BinaryExpr::new(literal_expr, Operator::Plus, cast_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should NOT be transformed because Plus cannot be swapped + assert!(!result.transformed); + } + + #[test] + fn test_cast_that_cannot_be_unwrapped_overflow() { + // Test case where the literal value would overflow the target type + let schema = Schema::new(vec![Field::new("small_int", DataType::Int8, false)]); + + // Create: cast(small_int as INT64) > INT64(1000) + // This should NOT be unwrapped because 1000 cannot fit in Int8 (max value is 127) + let column_expr = col("small_int", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let literal_expr = lit(1000i64); // Value too large for Int8 + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should NOT be transformed due to overflow + assert!(!result.transformed); + } + + #[test] + fn test_complex_nested_expression() { + let schema = test_schema(); + + // Create a more complex expression with nested casts + // (cast(c1 as INT64) > INT64(10)) AND (cast(c2 as INT32) = INT32(20)) + let c1_expr = col("c1", &schema).unwrap(); + let c1_cast = Arc::new(CastExpr::new(c1_expr, DataType::Int64, None)); + let c1_literal = lit(10i64); + let c1_binary = Arc::new(BinaryExpr::new(c1_cast, Operator::Gt, c1_literal)); + + let c2_expr = col("c2", &schema).unwrap(); + let c2_cast = Arc::new(CastExpr::new(c2_expr, DataType::Int32, None)); + let c2_literal = lit(20i32); + let c2_binary = Arc::new(BinaryExpr::new(c2_cast, Operator::Eq, c2_literal)); + + // Create AND expression + let and_expr = Arc::new(BinaryExpr::new(c1_binary, Operator::And, c2_binary)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(and_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + // Verify both sides of the AND were optimized + let optimized = result.data; + let and_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Left side should be: c1 > INT32(10) + let left_binary = and_binary + .left() + .as_any() + .downcast_ref::() + .unwrap(); + assert!(!is_cast_expr(left_binary.left())); + let left_literal = left_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(10))); + + // Right side should be: c2 = INT64(20) (c2 is already INT64, literal cast to match) + let right_binary = and_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert!(!is_cast_expr(right_binary.left())); + let right_literal = right_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(20))); + } +} diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 8092dc3c1a614..8a57cc7b7c154 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -129,35 +129,15 @@ impl LiteralGuarantee { .as_any() .downcast_ref::() { - // Only support single-column inlist currently, multi-column inlist is not supported - let col = inlist - .expr() - .as_any() - .downcast_ref::(); - let Some(col) = col else { - return builder; - }; - - let literals = inlist - .list() - .iter() - .map(|e| e.as_any().downcast_ref::()) - .collect::>>(); - let Some(literals) = literals else { - return builder; - }; - - let guarantee = if inlist.negated() { - Guarantee::NotIn + if let Some(inlist) = ColInList::try_new(inlist) { + builder.aggregate_multi_conjunct( + inlist.col, + inlist.guarantee, + inlist.list.iter().map(|lit| lit.value()), + ) } else { - Guarantee::In - }; - - builder.aggregate_multi_conjunct( - col, - guarantee, - literals.iter().map(|e| e.value()), - ) + builder + } } else { // split disjunction: OR OR ... let disjunctions = split_disjunction(expr); @@ -184,16 +164,6 @@ impl LiteralGuarantee { .filter_map(|expr| ColOpLit::try_new(expr)) .collect::>(); - if terms.is_empty() { - return builder; - } - - // if not all terms are of the form (col literal), - // can't infer any guarantees - if terms.len() != disjunctions.len() { - return builder; - } - // if all terms are 'col literal' with the same column // and operation we can infer any guarantees // @@ -203,18 +173,70 @@ impl LiteralGuarantee { // foo is required for the expression to be true. // So we can only create a multi value guarantee for `=` // (or a single value). (e.g. ignore `a != foo OR a != bar`) - let first_term = &terms[0]; - if terms.iter().all(|term| { - term.col.name() == first_term.col.name() - && term.guarantee == Guarantee::In - }) { + let first_term = terms.first(); + if !terms.is_empty() + && terms.len() == disjunctions.len() + && terms.iter().all(|term| { + term.col.name() == first_term.unwrap().col.name() + && term.guarantee == Guarantee::In + }) + { builder.aggregate_multi_conjunct( - first_term.col, + first_term.unwrap().col, Guarantee::In, terms.iter().map(|term| term.lit.value()), ) } else { - // can't infer anything + // Handle disjunctions with conjunctions like (a = 1 AND b = 2) OR (a = 2 AND b = 3) + // Extract termsets from each disjunction + // if in each termset, they have same column, and the guarantee is In, + // we can infer a guarantee for the column + // e.g. (a = 1 AND b = 2) OR (a = 2 AND b = 3) is `a IN (1, 2) AND b IN (2, 3)` + // otherwise, we can't infer a guarantee + let termsets: Vec> = disjunctions + .iter() + .map(|expr| { + split_conjunction(expr) + .into_iter() + .filter_map(ColOpLitOrInList::try_new) + .filter(|term| term.guarantee() == Guarantee::In) + .collect() + }) + .collect(); + + // Early return if any termset is empty (can't infer guarantees) + if termsets.iter().any(|terms| terms.is_empty()) { + return builder; + } + + // Find columns that appear in all termsets + let common_cols = find_common_columns(&termsets); + if common_cols.is_empty() { + return builder; + } + + // Build guarantees for common columns + let mut builder = builder; + for col in common_cols { + let literals: Vec<_> = termsets + .iter() + .filter_map(|terms| { + terms.iter().find(|term| term.col() == col).map( + |term| { + term.lits().into_iter().map(|lit| lit.value()) + }, + ) + }) + .flatten() + .collect(); + + builder = builder.aggregate_multi_conjunct( + col, + Guarantee::In, + literals.into_iter(), + ); + } + builder } } @@ -362,7 +384,7 @@ struct ColOpLit<'a> { } impl<'a> ColOpLit<'a> { - /// Returns Some(ColEqLit) if the expression is either: + /// Returns Some(ColOpLit) if the expression is either: /// 1. `col literal` /// 2. `literal col` /// 3. operator is `=` or `!=` @@ -410,6 +432,115 @@ impl<'a> ColOpLit<'a> { } } +/// Represents a single `col [not]in literal` expression +struct ColInList<'a> { + col: &'a crate::expressions::Column, + guarantee: Guarantee, + list: Vec<&'a crate::expressions::Literal>, +} + +impl<'a> ColInList<'a> { + /// Returns Some(ColInList) if the expression is either: + /// 1. `col (literal1, literal2, ...)` + /// 3. operator is `in` or `not in` + /// + /// Returns None otherwise + fn try_new(inlist: &'a crate::expressions::InListExpr) -> Option { + // Only support single-column inlist currently, multi-column inlist is not supported + let col = inlist + .expr() + .as_any() + .downcast_ref::()?; + + let literals = inlist + .list() + .iter() + .map(|e| e.as_any().downcast_ref::()) + .collect::>>()?; + + let guarantee = if inlist.negated() { + Guarantee::NotIn + } else { + Guarantee::In + }; + + Some(Self { + col, + guarantee, + list: literals, + }) + } +} + +/// Represents a single `col [not]in literal` expression or a single `col literal` expression +enum ColOpLitOrInList<'a> { + ColOpLit(ColOpLit<'a>), + ColInList(ColInList<'a>), +} + +impl<'a> ColOpLitOrInList<'a> { + fn try_new(expr: &'a Arc) -> Option { + match expr + .as_any() + .downcast_ref::() + { + Some(inlist) => Some(Self::ColInList(ColInList::try_new(inlist)?)), + None => ColOpLit::try_new(expr).map(Self::ColOpLit), + } + } + + fn guarantee(&self) -> Guarantee { + match self { + Self::ColOpLit(col_op_lit) => col_op_lit.guarantee, + Self::ColInList(col_in_list) => col_in_list.guarantee, + } + } + + fn col(&self) -> &'a crate::expressions::Column { + match self { + Self::ColOpLit(col_op_lit) => col_op_lit.col, + Self::ColInList(col_in_list) => col_in_list.col, + } + } + + fn lits(&self) -> Vec<&'a crate::expressions::Literal> { + match self { + Self::ColOpLit(col_op_lit) => vec![col_op_lit.lit], + Self::ColInList(col_in_list) => col_in_list.list.clone(), + } + } +} + +/// Find columns that appear in all termsets +fn find_common_columns<'a>( + termsets: &[Vec>], +) -> Vec<&'a crate::expressions::Column> { + if termsets.is_empty() { + return Vec::new(); + } + + // Start with columns from the first termset + let mut common_cols: HashSet<_> = termsets[0].iter().map(|term| term.col()).collect(); + + // check if any common_col in one termset occur many times + // e.g. (a = 1 AND a = 2) OR (a = 2 AND b = 3), should not infer a guarantee + // TODO: for above case, we can infer a IN (2) AND b IN (3) + if common_cols.len() != termsets[0].len() { + return Vec::new(); + } + + // Intersect with columns from remaining termsets + for termset in termsets.iter().skip(1) { + let termset_cols: HashSet<_> = termset.iter().map(|term| term.col()).collect(); + if termset_cols.len() != termset.len() { + return Vec::new(); + } + common_cols = common_cols.intersection(&termset_cols).cloned().collect(); + } + + common_cols.into_iter().collect() +} + #[cfg(test)] mod test { use std::sync::LazyLock; @@ -808,12 +939,11 @@ mod test { vec![not_in_guarantee("b", [1, 2, 3]), in_guarantee("b", [3, 4])], ); // b IN (1, 2, 3) OR b = 2 - // TODO this should be in_guarantee("b", [1, 2, 3]) but currently we don't support to analyze this kind of disjunction. Only `ColOpLit OR ColOpLit` is supported. test_analyze( col("b") .in_list(vec![lit(1), lit(2), lit(3)], false) .or(col("b").eq(lit(2))), - vec![], + vec![in_guarantee("b", [1, 2, 3])], ); // b IN (1, 2, 3) OR b != 3 test_analyze( @@ -824,13 +954,123 @@ mod test { ); } + #[test] + fn test_disjunction_and_conjunction_multi_column() { + // (a = "foo" AND b = 1) OR (a = "bar" AND b = 2) + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(1)))) + .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2)))), + vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [1, 2])], + ); + // (a = "foo" AND b = 1) OR (a = "bar" AND b = 2) OR (b = 3) + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(1)))) + .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2)))) + .or(col("b").eq(lit(3))), + vec![in_guarantee("b", [1, 2, 3])], + ); + // (a = "foo" AND b = 1) OR (a = "bar" AND b = 2) OR (c = 3) + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(1)))) + .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2)))) + .or(col("c").eq(lit(3))), + vec![], + ); + // (a = "foo" AND b > 1) OR (a = "bar" AND b = 2) + test_analyze( + (col("a").eq(lit("foo")).and(col("b").gt(lit(1)))) + .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2)))), + vec![in_guarantee("a", ["foo", "bar"])], + ); + // (a = "foo" AND b = 1) OR (b = 1 AND c = 2) OR (c = 3 AND a = "bar") + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(1)))) + .or(col("b").eq(lit(1)).and(col("c").eq(lit(2)))) + .or(col("c").eq(lit(3)).and(col("a").eq(lit("bar")))), + vec![], + ); + // (a = "foo" AND a = "bar") OR (a = "good" AND b = 1) + // TODO: this should be `a IN ("good") AND b IN (1)` + test_analyze( + (col("a").eq(lit("foo")).and(col("a").eq(lit("bar")))) + .or(col("a").eq(lit("good")).and(col("b").eq(lit(1)))), + vec![], + ); + // (a = "foo" AND a = "foo") OR (a = "good" AND b = 1) + // TODO: this should be `a IN ("foo", "good")` + test_analyze( + (col("a").eq(lit("foo")).and(col("a").eq(lit("foo")))) + .or(col("a").eq(lit("good")).and(col("b").eq(lit(1)))), + vec![], + ); + // (a = "foo" AND b = 3) OR (b = 4 AND b = 1) OR (b = 2 AND a = "bar") + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(3)))) + .or(col("b").eq(lit(4)).and(col("b").eq(lit(1)))) + .or(col("b").eq(lit(2)).and(col("a").eq(lit("bar")))), + vec![], + ); + // (b = 1 AND b > 3) OR (a = "foo" AND b = 4) + test_analyze( + (col("b").eq(lit(1)).and(col("b").gt(lit(3)))) + .or(col("a").eq(lit("foo")).and(col("b").eq(lit(4)))), + // if b isn't 1 or 4, it can not be true (though the expression actually can never be true) + vec![in_guarantee("b", [1, 4])], + ); + // (a = "foo" AND b = 1) OR (a != "bar" AND b = 2) + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(1)))) + .or(col("a").not_eq(lit("bar")).and(col("b").eq(lit(2)))), + vec![in_guarantee("b", [1, 2])], + ); + // (a = "foo" AND b = 1) OR (a LIKE "%bar" AND b = 2) + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(1)))) + .or(col("a").like(lit("%bar")).and(col("b").eq(lit(2)))), + vec![in_guarantee("b", [1, 2])], + ); + // (a IN ("foo", "bar") AND b = 5) OR (a IN ("foo", "bar") AND b = 6) + test_analyze( + (col("a") + .in_list(vec![lit("foo"), lit("bar")], false) + .and(col("b").eq(lit(5)))) + .or(col("a") + .in_list(vec![lit("foo"), lit("bar")], false) + .and(col("b").eq(lit(6)))), + vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [5, 6])], + ); + // (a IN ("foo", "bar") AND b = 5) OR (a IN ("foo") AND b = 6) + test_analyze( + (col("a") + .in_list(vec![lit("foo"), lit("bar")], false) + .and(col("b").eq(lit(5)))) + .or(col("a") + .in_list(vec![lit("foo")], false) + .and(col("b").eq(lit(6)))), + vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [5, 6])], + ); + // (a NOT IN ("foo", "bar") AND b = 5) OR (a NOT IN ("foo") AND b = 6) + test_analyze( + (col("a") + .in_list(vec![lit("foo"), lit("bar")], true) + .and(col("b").eq(lit(5)))) + .or(col("a") + .in_list(vec![lit("foo")], true) + .and(col("b").eq(lit(6)))), + vec![in_guarantee("b", [5, 6])], + ); + } + /// Tests that analyzing expr results in the expected guarantees fn test_analyze(expr: Expr, expected: Vec) { println!("Begin analyze of {expr}"); let schema = schema(); let physical_expr = logical2physical(&expr, &schema); - let actual = LiteralGuarantee::analyze(&physical_expr); + let actual = LiteralGuarantee::analyze(&physical_expr) + .into_iter() + .sorted_by_key(|g| g.column.name().to_string()) + .collect::>(); assert_eq!( expected, actual, "expr: {expr}\ @@ -867,6 +1107,7 @@ mod test { Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), ])) }); Arc::clone(&SCHEMA) diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 7e4c7f0e10ba8..745ae855efee2 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -26,15 +26,13 @@ use crate::tree_node::ExprContext; use crate::PhysicalExpr; use crate::PhysicalSortExpr; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::Schema; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{HashMap, HashSet, Result}; use datafusion_expr::Operator; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use itertools::Itertools; use petgraph::graph::NodeIndex; use petgraph::stable_graph::StableGraph; @@ -47,6 +45,31 @@ pub fn split_conjunction( split_impl(Operator::And, predicate, vec![]) } +/// Create a conjunction of the given predicates. +/// If the input is empty, return a literal true. +/// If the input contains a single predicate, return the predicate. +/// Otherwise, return a conjunction of the predicates (e.g. `a AND b AND c`). +pub fn conjunction( + predicates: impl IntoIterator>, +) -> Arc { + conjunction_opt(predicates).unwrap_or_else(|| crate::expressions::lit(true)) +} + +/// Create a conjunction of the given predicates. +/// If the input is empty or the return None. +/// If the input contains a single predicate, return Some(predicate). +/// Otherwise, return a Some(..) of a conjunction of the predicates (e.g. `Some(a AND b AND c)`). +pub fn conjunction_opt( + predicates: impl IntoIterator>, +) -> Option> { + predicates + .into_iter() + .fold(None, |acc, predicate| match acc { + None => Some(predicate), + Some(acc) => Some(Arc::new(BinaryExpr::new(acc, Operator::And, predicate))), + }) +} + /// Assume the predicate is in the form of DNF, split the predicate to a Vec of PhysicalExprs. /// /// For example, split "a1 = a2 OR b1 <= b2 OR c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"] @@ -215,22 +238,23 @@ pub fn collect_columns(expr: &Arc) -> HashSet { columns } -/// Re-assign column indices referenced in predicate according to given schema. -/// This may be helpful when dealing with projections. -pub fn reassign_predicate_columns( - pred: Arc, - schema: &SchemaRef, - ignore_not_found: bool, +/// Re-assign indices of [`Column`]s within the given [`PhysicalExpr`] according to +/// the provided [`Schema`]. +/// +/// This can be useful when attempting to map an expression onto a different schema. +/// +/// # Errors +/// +/// This function will return an error if any column in the expression cannot be found +/// in the provided schema. +pub fn reassign_expr_columns( + expr: Arc, + schema: &Schema, ) -> Result> { - pred.transform_down(|expr| { - let expr_any = expr.as_any(); - - if let Some(column) = expr_any.downcast_ref::() { - let index = match schema.index_of(column.name()) { - Ok(idx) => idx, - Err(_) if ignore_not_found => usize::MAX, - Err(e) => return Err(e.into()), - }; + expr.transform_down(|expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + let index = schema.index_of(column.name())?; + return Ok(Transformed::yes(Arc::new(Column::new( column.name(), index, @@ -241,15 +265,6 @@ pub fn reassign_predicate_columns( .data() } -/// Merge left and right sort expressions, checking for duplicates. -pub fn merge_vectors(left: &LexOrdering, right: &LexOrdering) -> LexOrdering { - left.iter() - .cloned() - .chain(right.iter().cloned()) - .unique() - .collect() -} - #[cfg(test)] pub(crate) mod tests { use std::any::Any; @@ -260,7 +275,7 @@ pub(crate) mod tests { use arrow::array::{ArrayRef, Float32Array, Float64Array}; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{exec_err, DataFusionError, ScalarValue}; + use datafusion_common::{exec_err, internal_datafusion_err, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -268,7 +283,7 @@ pub(crate) mod tests { use petgraph::visit::Bfs; - #[derive(Debug, Clone)] + #[derive(Debug, PartialEq, Eq, Hash)] pub struct TestScalarUDF { pub(crate) signature: Signature, } @@ -320,11 +335,11 @@ pub(crate) mod tests { .as_any() .downcast_ref::() .ok_or_else(|| { - DataFusionError::Internal(format!( + internal_datafusion_err!( "could not cast {} to {}", self.name(), std::any::type_name::() - )) + ) })?; arg.iter() @@ -336,11 +351,11 @@ pub(crate) mod tests { .as_any() .downcast_ref::() .ok_or_else(|| { - DataFusionError::Internal(format!( + internal_datafusion_err!( "could not cast {} to {}", self.name(), std::any::type_name::() - )) + ) })?; arg.iter() @@ -492,7 +507,7 @@ pub(crate) mod tests { } #[test] - fn test_reassign_predicate_columns_in_list() { + fn test_reassign_expr_columns_in_list() { let int_field = Field::new("should_not_matter", DataType::Int64, true); let dict_field = Field::new( "id", @@ -512,7 +527,7 @@ pub(crate) mod tests { ) .unwrap(); - let actual = reassign_predicate_columns(pred, &schema_small, false).unwrap(); + let actual = reassign_expr_columns(pred, &schema_small).unwrap(); let expected = in_list( Arc::new(Column::new_with_schema("id", &schema_small).unwrap()), diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index a94d5b1212f52..2ed9770902d58 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -23,18 +23,19 @@ use std::sync::Arc; use crate::aggregate::AggregateFunctionExpr; use crate::window::standard::add_new_ordering_expr_with_partition_by; -use crate::window::window_expr::AggregateWindowExpr; +use crate::window::window_expr::{filter_array, AggregateWindowExpr, WindowFn}; use crate::window::{ PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr, }; -use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr}; +use crate::{EquivalenceProperties, PhysicalExpr}; -use arrow::array::Array; +use arrow::array::ArrayRef; +use arrow::array::BooleanArray; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; -use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::{Accumulator, WindowFrame}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_common::{exec_datafusion_err, Result, ScalarValue}; +use datafusion_expr::{Accumulator, WindowFrame, WindowFrameBound, WindowFrameUnits}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; /// A window expr that takes the form of an aggregate function. /// @@ -43,8 +44,10 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering; pub struct PlainAggregateWindowExpr { aggregate: Arc, partition_by: Vec>, - order_by: LexOrdering, + order_by: Vec, window_frame: Arc, + is_constant_in_partition: bool, + filter: Option>, } impl PlainAggregateWindowExpr { @@ -52,14 +55,19 @@ impl PlainAggregateWindowExpr { pub fn new( aggregate: Arc, partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, + filter: Option>, ) -> Self { + let is_constant_in_partition = + Self::is_window_constant_in_partition(order_by, &window_frame); Self { aggregate, partition_by: partition_by.to_vec(), - order_by: order_by.clone(), + order_by: order_by.to_vec(), window_frame, + is_constant_in_partition, + filter, } } @@ -72,7 +80,7 @@ impl PlainAggregateWindowExpr { &self, eq_properties: &mut EquivalenceProperties, window_expr_index: usize, - ) { + ) -> Result<()> { if let Some(expr) = self .get_aggregate_expr() .get_result_ordering(window_expr_index) @@ -81,8 +89,33 @@ impl PlainAggregateWindowExpr { eq_properties, expr, &self.partition_by, - ); + )?; } + Ok(()) + } + + // Returns true if every row in the partition has the same window frame. This allows + // for preventing bound + function calculation for every row due to the values being the + // same. + // + // This occurs when both bounds fall under either condition below: + // 1. Bound is unbounded (`Preceding` or `Following`) + // 2. Bound is `CurrentRow` while using `Range` units with no order by clause + // This results in an invalid range specification. Following PostgreSQL’s convention, + // we interpret this as the entire partition being used for the current window frame. + fn is_window_constant_in_partition( + order_by: &[PhysicalSortExpr], + window_frame: &WindowFrame, + ) -> bool { + let is_constant_bound = |bound: &WindowFrameBound| match bound { + WindowFrameBound::CurrentRow => { + window_frame.units == WindowFrameUnits::Range && order_by.is_empty() + } + _ => bound.is_unbounded(), + }; + + is_constant_bound(&window_frame.start_bound) + && is_constant_bound(&window_frame.end_bound) } } @@ -95,7 +128,7 @@ impl WindowExpr for PlainAggregateWindowExpr { self } - fn field(&self) -> Result { + fn field(&self) -> Result { Ok(self.aggregate.field()) } @@ -124,10 +157,9 @@ impl WindowExpr for PlainAggregateWindowExpr { // This enables us to run queries involving UNBOUNDED PRECEDING frames // using bounded memory for suitable aggregations. for partition_row in partition_batches.keys() { - let window_state = - window_agg_state.get_mut(partition_row).ok_or_else(|| { - DataFusionError::Execution("Cannot find state".to_string()) - })?; + let window_state = window_agg_state + .get_mut(partition_row) + .ok_or_else(|| exec_datafusion_err!("Cannot find state"))?; let state = &mut window_state.state; if self.window_frame.start_bound.is_unbounded() { state.window_frame_range.start = @@ -141,8 +173,8 @@ impl WindowExpr for PlainAggregateWindowExpr { &self.partition_by } - fn order_by(&self) -> &LexOrdering { - self.order_by.as_ref() + fn order_by(&self) -> &[PhysicalSortExpr] { + &self.order_by } fn get_window_frame(&self) -> &Arc { @@ -156,15 +188,25 @@ impl WindowExpr for PlainAggregateWindowExpr { Arc::new(PlainAggregateWindowExpr::new( Arc::new(reverse_expr), &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), + self.filter.clone(), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( Arc::new(reverse_expr), &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), + self.filter.clone(), )) as _ } }) @@ -173,6 +215,10 @@ impl WindowExpr for PlainAggregateWindowExpr { fn uses_bounded_memory(&self) -> bool { !self.window_frame.end_bound.is_unbounded() } + + fn create_window_fn(&self) -> Result { + Ok(WindowFn::Aggregate(self.get_accumulator()?)) + } } impl AggregateWindowExpr for PlainAggregateWindowExpr { @@ -180,6 +226,10 @@ impl AggregateWindowExpr for PlainAggregateWindowExpr { self.aggregate.create_accumulator() } + fn filter_expr(&self) -> Option<&Arc> { + self.filter.as_ref() + } + /// For a given range, calculate accumulation result inside the range on /// `value_slice` and update accumulator state. // We assume that `cur_range` contains `last_range` and their start points @@ -191,6 +241,7 @@ impl AggregateWindowExpr for PlainAggregateWindowExpr { cur_range: &Range, value_slice: &[ArrayRef], accumulator: &mut Box, + filter_mask: Option<&BooleanArray>, ) -> Result { if cur_range.start == cur_range.end { self.aggregate @@ -203,13 +254,23 @@ impl AggregateWindowExpr for PlainAggregateWindowExpr { // same point (i.e. the beginning of the table/frame). Hence, we // do not call `retract_batch`. if update_bound > 0 { + let slice_mask = + filter_mask.map(|m| m.slice(last_range.end, update_bound)); let update: Vec = value_slice .iter() .map(|v| v.slice(last_range.end, update_bound)) - .collect(); + .map(|arr| match &slice_mask { + Some(m) => filter_array(&arr, m), + None => Ok(arr), + }) + .collect::>>()?; accumulator.update_batch(&update)? } accumulator.evaluate() } } + + fn is_constant_in_partition(&self) -> bool { + self.is_constant_in_partition + } } diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index bc7c716783bdc..b45e35440ac20 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -21,12 +21,6 @@ mod standard; mod standard_window_function_expr; mod window_expr; -#[deprecated(since = "44.0.0", note = "use StandardWindowExpr")] -pub type BuiltInWindowExpr = StandardWindowExpr; - -#[deprecated(since = "44.0.0", note = "use StandardWindowFunctionExpr")] -pub type BuiltInWindowFunctionExpr = dyn StandardWindowFunctionExpr; - pub use aggregate::PlainAggregateWindowExpr; pub use sliding_aggregate::SlidingAggregateWindowExpr; pub use standard::StandardWindowExpr; diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 23967e78f07a7..f93b13fef4dff 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -22,18 +22,17 @@ use std::ops::Range; use std::sync::Arc; use crate::aggregate::AggregateFunctionExpr; -use crate::window::window_expr::AggregateWindowExpr; +use crate::window::window_expr::{filter_array, AggregateWindowExpr, WindowFn}; use crate::window::{ PartitionBatches, PartitionWindowAggStates, PlainAggregateWindowExpr, WindowExpr, }; -use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; +use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; -use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::Field; +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{Accumulator, WindowFrame}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; /// A window expr that takes the form of an aggregate function that /// can be incrementally computed over sliding windows. @@ -43,8 +42,9 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering; pub struct SlidingAggregateWindowExpr { aggregate: Arc, partition_by: Vec>, - order_by: LexOrdering, + order_by: Vec, window_frame: Arc, + filter: Option>, } impl SlidingAggregateWindowExpr { @@ -52,14 +52,16 @@ impl SlidingAggregateWindowExpr { pub fn new( aggregate: Arc, partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, + filter: Option>, ) -> Self { Self { aggregate, partition_by: partition_by.to_vec(), - order_by: order_by.clone(), + order_by: order_by.to_vec(), window_frame, + filter, } } @@ -80,7 +82,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { self } - fn field(&self) -> Result { + fn field(&self) -> Result { Ok(self.aggregate.field()) } @@ -108,8 +110,8 @@ impl WindowExpr for SlidingAggregateWindowExpr { &self.partition_by } - fn order_by(&self) -> &LexOrdering { - self.order_by.as_ref() + fn order_by(&self) -> &[PhysicalSortExpr] { + &self.order_by } fn get_window_frame(&self) -> &Arc { @@ -123,15 +125,25 @@ impl WindowExpr for SlidingAggregateWindowExpr { Arc::new(PlainAggregateWindowExpr::new( Arc::new(reverse_expr), &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), + self.filter.clone(), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( Arc::new(reverse_expr), &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), + self.filter.clone(), )) as _ } }) @@ -157,7 +169,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { expr: new_expr, options: req.options, }) - .collect::(); + .collect(); Some(Arc::new(SlidingAggregateWindowExpr { aggregate: self .aggregate @@ -166,8 +178,13 @@ impl WindowExpr for SlidingAggregateWindowExpr { partition_by: partition_bys, order_by: new_order_by, window_frame: Arc::clone(&self.window_frame), + filter: self.filter.clone(), })) } + + fn create_window_fn(&self) -> Result { + Ok(WindowFn::Aggregate(self.get_accumulator()?)) + } } impl AggregateWindowExpr for SlidingAggregateWindowExpr { @@ -175,6 +192,10 @@ impl AggregateWindowExpr for SlidingAggregateWindowExpr { self.aggregate.create_sliding_accumulator() } + fn filter_expr(&self) -> Option<&Arc> { + self.filter.as_ref() + } + /// Given current range and the last range, calculates the accumulator /// result for the range of interest. fn get_aggregate_result_inside_range( @@ -183,6 +204,7 @@ impl AggregateWindowExpr for SlidingAggregateWindowExpr { cur_range: &Range, value_slice: &[ArrayRef], accumulator: &mut Box, + filter_mask: Option<&BooleanArray>, ) -> Result { if cur_range.start == cur_range.end { self.aggregate @@ -191,23 +213,39 @@ impl AggregateWindowExpr for SlidingAggregateWindowExpr { // Accumulate any new rows that have entered the window: let update_bound = cur_range.end - last_range.end; if update_bound > 0 { + let slice_mask = + filter_mask.map(|m| m.slice(last_range.end, update_bound)); let update: Vec = value_slice .iter() .map(|v| v.slice(last_range.end, update_bound)) - .collect(); + .map(|arr| match &slice_mask { + Some(m) => filter_array(&arr, m), + None => Ok(arr), + }) + .collect::>>()?; accumulator.update_batch(&update)? } // Remove rows that have now left the window: let retract_bound = cur_range.start - last_range.start; if retract_bound > 0 { + let slice_mask = + filter_mask.map(|m| m.slice(last_range.start, retract_bound)); let retract: Vec = value_slice .iter() .map(|v| v.slice(last_range.start, retract_bound)) - .collect(); + .map(|arr| match &slice_mask { + Some(m) => filter_array(&arr, m), + None => Ok(arr), + }) + .collect::>>()?; accumulator.retract_batch(&retract)? } accumulator.evaluate() } } + + fn is_constant_in_partition(&self) -> bool { + false + } } diff --git a/datafusion/physical-expr/src/window/standard.rs b/datafusion/physical-expr/src/window/standard.rs index 22e8aea83fe78..e9e7f6abf6368 100644 --- a/datafusion/physical-expr/src/window/standard.rs +++ b/datafusion/physical-expr/src/window/standard.rs @@ -24,23 +24,23 @@ use std::sync::Arc; use super::{StandardWindowFunctionExpr, WindowExpr}; use crate::window::window_expr::{get_orderby_values, WindowFn}; use crate::window::{PartitionBatches, PartitionWindowAggStates, WindowState}; -use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr}; +use crate::{EquivalenceProperties, PhysicalExpr}; + use arrow::array::{new_empty_array, ArrayRef}; -use arrow::compute::SortOptions; -use arrow::datatypes::Field; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::utils::evaluate_partition_ranges; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::window_state::{WindowAggState, WindowFrameContext}; use datafusion_expr::WindowFrame; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; /// A window expr that takes the form of a [`StandardWindowFunctionExpr`]. #[derive(Debug)] pub struct StandardWindowExpr { expr: Arc, partition_by: Vec>, - order_by: LexOrdering, + order_by: Vec, window_frame: Arc, } @@ -49,13 +49,13 @@ impl StandardWindowExpr { pub fn new( expr: Arc, partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, ) -> Self { Self { expr, partition_by: partition_by.to_vec(), - order_by: order_by.clone(), + order_by: order_by.to_vec(), window_frame, } } @@ -70,15 +70,19 @@ impl StandardWindowExpr { /// If `self.expr` doesn't have an ordering, ordering equivalence properties /// are not updated. Otherwise, ordering equivalence properties are updated /// by the ordering of `self.expr`. - pub fn add_equal_orderings(&self, eq_properties: &mut EquivalenceProperties) { + pub fn add_equal_orderings( + &self, + eq_properties: &mut EquivalenceProperties, + ) -> Result<()> { let schema = eq_properties.schema(); if let Some(fn_res_ordering) = self.expr.get_result_ordering(schema) { add_new_ordering_expr_with_partition_by( eq_properties, fn_res_ordering, &self.partition_by, - ); + )?; } + Ok(()) } } @@ -92,7 +96,7 @@ impl WindowExpr for StandardWindowExpr { self.expr.name() } - fn field(&self) -> Result { + fn field(&self) -> Result { self.expr.field() } @@ -104,16 +108,15 @@ impl WindowExpr for StandardWindowExpr { &self.partition_by } - fn order_by(&self) -> &LexOrdering { - self.order_by.as_ref() + fn order_by(&self) -> &[PhysicalSortExpr] { + &self.order_by } fn evaluate(&self, batch: &RecordBatch) -> Result { let mut evaluator = self.expr.create_evaluator()?; let num_rows = batch.num_rows(); if evaluator.uses_window_frame() { - let sort_options: Vec = - self.order_by.iter().map(|o| o.options).collect(); + let sort_options = self.order_by.iter().map(|o| o.options).collect(); let mut row_wise_results = vec![]; let mut values = self.evaluate_args(batch)?; @@ -158,6 +161,9 @@ impl WindowExpr for StandardWindowExpr { let field = self.expr.field()?; let out_type = field.data_type(); let sort_options = self.order_by.iter().map(|o| o.options).collect::>(); + // create a WindowAggState to clone when `window_agg_state` does not contain the respective + // group, which is faster than potentially creating a new one at every iteration + let new_state = WindowAggState::new(out_type)?; for (partition_row, partition_batch_state) in partition_batches.iter() { let window_state = if let Some(window_state) = window_agg_state.get_mut(partition_row) { @@ -167,7 +173,7 @@ impl WindowExpr for StandardWindowExpr { window_agg_state .entry(partition_row.clone()) .or_insert(WindowState { - state: WindowAggState::new(out_type)?, + state: new_state.clone(), window_fn: WindowFn::Builtin(evaluator), }) }; @@ -232,6 +238,9 @@ impl WindowExpr for StandardWindowExpr { } let out_col = if row_wise_results.is_empty() { new_empty_array(out_type) + } else if row_wise_results.len() == 1 { + // fast path when the result only has a single row + row_wise_results[0].to_array()? } else { ScalarValue::iter_to_array(row_wise_results.into_iter())? }; @@ -253,7 +262,11 @@ impl WindowExpr for StandardWindowExpr { Arc::new(StandardWindowExpr::new( reverse_expr, &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), )) as _ }) @@ -268,6 +281,10 @@ impl WindowExpr for StandardWindowExpr { false } } + + fn create_window_fn(&self) -> Result { + Ok(WindowFn::Builtin(self.expr.create_evaluator()?)) + } } /// Adds a new ordering expression into existing ordering equivalence class(es) based on @@ -276,10 +293,10 @@ pub(crate) fn add_new_ordering_expr_with_partition_by( eqp: &mut EquivalenceProperties, expr: PhysicalSortExpr, partition_by: &[Arc], -) { +) -> Result<()> { if partition_by.is_empty() { // In the absence of a PARTITION BY, ordering of `self.expr` is global: - eqp.add_new_orderings([LexOrdering::new(vec![expr])]); + eqp.add_ordering([expr]); } else { // If we have a PARTITION BY, standard functions can not introduce // a global ordering unless the existing ordering is compatible @@ -287,10 +304,11 @@ pub(crate) fn add_new_ordering_expr_with_partition_by( // expressions and existing ordering expressions are equal (w.r.t. // set equality), we can prefix the ordering of `self.expr` with // the existing ordering. - let (mut ordering, _) = eqp.find_longest_permutation(partition_by); + let (mut ordering, _) = eqp.find_longest_permutation(partition_by)?; if ordering.len() == partition_by.len() { ordering.push(expr); - eqp.add_new_orderings([ordering]); + eqp.add_ordering(ordering); } } + Ok(()) } diff --git a/datafusion/physical-expr/src/window/standard_window_function_expr.rs b/datafusion/physical-expr/src/window/standard_window_function_expr.rs index 624b747d93f9a..ca7c3a4db3d4f 100644 --- a/datafusion/physical-expr/src/window/standard_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/standard_window_function_expr.rs @@ -18,10 +18,10 @@ use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::array::ArrayRef; -use arrow::datatypes::{Field, SchemaRef}; +use arrow::datatypes::{FieldRef, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::Result; -use datafusion_expr::PartitionEvaluator; +use datafusion_expr::{LimitEffect, PartitionEvaluator}; use std::any::Any; use std::sync::Arc; @@ -41,7 +41,7 @@ pub trait StandardWindowFunctionExpr: Send + Sync + std::fmt::Debug { fn as_any(&self) -> &dyn Any; /// The field of the final result of evaluating this window function. - fn field(&self) -> Result; + fn field(&self) -> Result; /// Expressions that are passed to the [`PartitionEvaluator`]. fn expressions(&self) -> Vec>; @@ -90,4 +90,6 @@ pub trait StandardWindowFunctionExpr: Send + Sync + std::fmt::Debug { fn get_result_ordering(&self, _schema: &SchemaRef) -> Option { None } + + fn limit_effect(&self) -> LimitEffect; } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 793f2e5ee5867..a6b5bf1871161 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -20,19 +20,26 @@ use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; -use crate::{LexOrdering, PhysicalExpr}; +use crate::PhysicalExpr; +use arrow::array::BooleanArray; use arrow::array::{new_empty_array, Array, ArrayRef}; +use arrow::compute::filter as arrow_filter; use arrow::compute::kernels::sort::SortColumn; use arrow::compute::SortOptions; -use arrow::datatypes::Field; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; +use datafusion_common::cast::as_boolean_array; use datafusion_common::utils::compare_rows; -use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + arrow_datafusion_err, exec_datafusion_err, internal_err, DataFusionError, Result, + ScalarValue, +}; use datafusion_expr::window_state::{ PartitionBatchState, WindowAggState, WindowFrameContext, WindowFrameStateGroups, }; use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame, WindowFrameBound}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use indexmap::IndexMap; @@ -67,7 +74,7 @@ pub trait WindowExpr: Send + Sync + Debug { fn as_any(&self) -> &dyn Any; /// The field of the final result of this window function. - fn field(&self) -> Result; + fn field(&self) -> Result; /// Human readable name such as `"MIN(c2)"` or `"RANK()"`. The default /// implementation returns placeholder text. @@ -109,14 +116,14 @@ pub trait WindowExpr: Send + Sync + Debug { fn partition_by(&self) -> &[Arc]; /// Expressions that's from the window function's order by clause, empty if absent - fn order_by(&self) -> &LexOrdering; + fn order_by(&self) -> &[PhysicalSortExpr]; /// Get order by columns, empty if absent fn order_by_columns(&self, batch: &RecordBatch) -> Result> { self.order_by() .iter() .map(|e| e.evaluate_to_sort_column(batch)) - .collect::>>() + .collect() } /// Get the window frame of this [WindowExpr]. @@ -129,6 +136,12 @@ pub trait WindowExpr: Send + Sync + Debug { /// Get the reverse expression of this [WindowExpr]. fn get_reverse_expr(&self) -> Option>; + /// Creates a new instance of the window function evaluator. + /// + /// Returns `WindowFn::Builtin` for built-in window functions (e.g., ROW_NUMBER, RANK) + /// or `WindowFn::Aggregate` for aggregate window functions (e.g., SUM, AVG). + fn create_window_fn(&self) -> Result; + /// Returns all expressions used in the [`WindowExpr`]. /// These expressions are (1) function arguments, (2) partition by expressions, (3) order by expressions. fn all_expressions(&self) -> WindowPhysicalExpressions { @@ -138,7 +151,7 @@ pub trait WindowExpr: Send + Sync + Debug { .order_by() .iter() .map(|sort_expr| Arc::clone(&sort_expr.expr)) - .collect::>(); + .collect(); WindowPhysicalExpressions { args, partition_by_exprs, @@ -176,6 +189,9 @@ pub trait AggregateWindowExpr: WindowExpr { /// (non-sliding) expressions will return sliding (normal) accumulators. fn get_accumulator(&self) -> Result>; + /// Optional FILTER (WHERE ...) predicate for this window aggregate. + fn filter_expr(&self) -> Option<&Arc>; + /// Given current range and the last range, calculates the accumulator /// result for the range of interest. fn get_aggregate_result_inside_range( @@ -184,14 +200,18 @@ pub trait AggregateWindowExpr: WindowExpr { cur_range: &Range, value_slice: &[ArrayRef], accumulator: &mut Box, + filter_mask: Option<&BooleanArray>, ) -> Result; + /// Indicates whether this window function always produces the same result + /// for all rows in the partition. + fn is_constant_in_partition(&self) -> bool; + /// Evaluates the window function against the batch. fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result { let mut accumulator = self.get_accumulator()?; let mut last_range = Range { start: 0, end: 0 }; - let sort_options: Vec = - self.order_by().iter().map(|o| o.options).collect(); + let sort_options = self.order_by().iter().map(|o| o.options).collect(); let mut window_frame_ctx = WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options); self.get_result_column( @@ -225,10 +245,9 @@ pub trait AggregateWindowExpr: WindowExpr { }, ); }; - let window_state = - window_agg_state.get_mut(partition_row).ok_or_else(|| { - DataFusionError::Execution("Cannot find state".to_string()) - })?; + let window_state = window_agg_state + .get_mut(partition_row) + .ok_or_else(|| exec_datafusion_err!("Cannot find state"))?; let accumulator = match &mut window_state.window_fn { WindowFn::Aggregate(accumulator) => accumulator, _ => unreachable!(), @@ -239,8 +258,7 @@ pub trait AggregateWindowExpr: WindowExpr { // If there is no window state context, initialize it. let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| { - let sort_options: Vec = - self.order_by().iter().map(|o| o.options).collect(); + let sort_options = self.order_by().iter().map(|o| o.options).collect(); WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options) }); let out_col = self.get_result_column( @@ -260,6 +278,15 @@ pub trait AggregateWindowExpr: WindowExpr { /// Calculates the window expression result for the given record batch. /// Assumes that `record_batch` belongs to a single partition. + /// + /// # Arguments + /// * `accumulator`: The accumulator to use for the calculation. + /// * `record_batch`: batch belonging to the current partition (see [`PartitionBatchState`]). + /// * `most_recent_row`: the batch that contains the most recent row, if available (see [`PartitionBatchState`]). + /// * `last_range`: The last range of rows that were processed (see [`WindowAggState`]). + /// * `window_frame_ctx`: Details about the window frame (see [`WindowFrameContext`]). + /// * `idx`: The index of the current row in the record batch. + /// * `not_end`: is the current row not the end of the partition (see [`PartitionBatchState`]). #[allow(clippy::too_many_arguments)] fn get_result_column( &self, @@ -272,8 +299,39 @@ pub trait AggregateWindowExpr: WindowExpr { not_end: bool, ) -> Result { let values = self.evaluate_args(record_batch)?; - let order_bys = get_orderby_values(self.order_by_columns(record_batch)?); + // Evaluate filter mask once per record batch if present + let filter_mask_arr: Option = match self.filter_expr() { + Some(expr) => { + let value = expr.evaluate(record_batch)?; + Some(value.into_array(record_batch.num_rows())?) + } + None => None, + }; + + // Borrow boolean view from the owned array + let filter_mask: Option<&BooleanArray> = match filter_mask_arr.as_deref() { + Some(arr) => Some(as_boolean_array(arr)?), + None => None, + }; + + if self.is_constant_in_partition() { + if not_end { + let field = self.field()?; + let out_type = field.data_type(); + return Ok(new_empty_array(out_type)); + } + let values = if let Some(mask) = filter_mask { + // Apply mask to all argument arrays before a single update + filter_arrays(&values, mask)? + } else { + values + }; + accumulator.update_batch(&values)?; + let value = accumulator.evaluate()?; + return value.to_array_of_size(record_batch.num_rows()); + } + let order_bys = get_orderby_values(self.order_by_columns(record_batch)?); let most_recent_row_order_bys = most_recent_row .map(|batch| self.order_by_columns(batch)) .transpose()? @@ -306,6 +364,7 @@ pub trait AggregateWindowExpr: WindowExpr { &cur_range, &values, accumulator, + filter_mask, )?; // Update last range *last_range = cur_range; @@ -323,6 +382,21 @@ pub trait AggregateWindowExpr: WindowExpr { } } +/// Filters a single array with the provided boolean mask. +pub(crate) fn filter_array(array: &ArrayRef, mask: &BooleanArray) -> Result { + arrow_filter(array.as_ref(), mask) + .map(|a| a as ArrayRef) + .map_err(|e| arrow_datafusion_err!(e)) +} + +/// Filters a list of arrays with the provided boolean mask. +pub(crate) fn filter_arrays( + arrays: &[ArrayRef], + mask: &BooleanArray, +) -> Result> { + arrays.iter().map(|arr| filter_array(arr, mask)).collect() +} + /// Determines whether the end bound calculation for a window frame context is /// safe, meaning that the end bound stays the same, regardless of future data, /// based on the current sort expressions and ORDER BY columns. This function @@ -344,13 +418,13 @@ pub(crate) fn is_end_bound_safe( window_frame_ctx: &WindowFrameContext, order_bys: &[ArrayRef], most_recent_order_bys: Option<&[ArrayRef]>, - sort_exprs: &LexOrdering, + sort_exprs: &[PhysicalSortExpr], idx: usize, ) -> Result { if sort_exprs.is_empty() { // Early return if no sort expressions are present: return Ok(false); - } + }; match window_frame_ctx { WindowFrameContext::Rows(window_frame) => { diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index aaadb09bcc98a..15466cd86bb04 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -39,18 +39,18 @@ recursive_protection = ["dep:recursive"] [dependencies] arrow = { workspace = true } -datafusion-common = { workspace = true, default-features = true } +datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } +datafusion-pruning = { workspace = true } itertools = { workspace = true } -log = { workspace = true } recursive = { workspace = true, optional = true } [dev-dependencies] datafusion-expr = { workspace = true } -datafusion-functions-nested = { workspace = true } insta = { workspace = true } +tokio = { workspace = true } diff --git a/datafusion/physical-optimizer/README.md b/datafusion/physical-optimizer/README.md index eb361d3f67792..3efbc19d2e724 100644 --- a/datafusion/physical-optimizer/README.md +++ b/datafusion/physical-optimizer/README.md @@ -17,9 +17,16 @@ under the License. --> -# DataFusion Physical Optimizer +# Apache DataFusion Physical Optimizer -DataFusion is an extensible query execution framework, written in Rust, -that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate contains the physical optimizer for DataFusion. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index 0d3d83c58373f..672317060d902 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -22,7 +22,7 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; -use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs}; use datafusion_physical_plan::{expressions, ExecutionPlan}; use std::sync::Arc; @@ -42,6 +42,7 @@ impl AggregateStatistics { impl PhysicalOptimizerRule for AggregateStatistics { #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + #[allow(clippy::only_used_in_recursion)] // See https://github.com/rust-lang/rust-clippy/issues/14566 fn optimize( &self, plan: Arc, @@ -52,7 +53,7 @@ impl PhysicalOptimizerRule for AggregateStatistics { .as_any() .downcast_ref::() .expect("take_optimizable() ensures that this is a AggregateExec"); - let stats = partial_agg_exec.input().statistics()?; + let stats = partial_agg_exec.input().partition_statistics(None)?; let mut projections = vec![]; for expr in partial_agg_exec.aggr_expr() { let field = expr.field(); @@ -66,8 +67,10 @@ impl PhysicalOptimizerRule for AggregateStatistics { if let Some((optimizable_statistic, name)) = take_optimizable_value_from_statistics(&statistics_args, expr) { - projections - .push((expressions::lit(optimizable_statistic), name.to_owned())); + projections.push(ProjectionExpr { + expr: expressions::lit(optimizable_statistic), + alias: name.to_owned(), + }); } else { // TODO: we need all aggr_expr to be resolved (cf TODO fullres) break; diff --git a/datafusion/physical-optimizer/src/coalesce_async_exec_input.rs b/datafusion/physical-optimizer/src/coalesce_async_exec_input.rs new file mode 100644 index 0000000000000..0b46c68f2daed --- /dev/null +++ b/datafusion/physical-optimizer/src/coalesce_async_exec_input.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::PhysicalOptimizerRule; +use datafusion_common::config::ConfigOptions; +use datafusion_common::internal_err; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_physical_plan::async_func::AsyncFuncExec; +use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_physical_plan::ExecutionPlan; +use std::sync::Arc; + +/// Optimizer rule that introduces CoalesceAsyncExec to reduce the number of async executions. +#[derive(Default, Debug)] +pub struct CoalesceAsyncExecInput {} + +impl CoalesceAsyncExecInput { + #[allow(missing_docs)] + pub fn new() -> Self { + Self::default() + } +} + +impl PhysicalOptimizerRule for CoalesceAsyncExecInput { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> datafusion_common::Result> { + let target_batch_size = config.execution.batch_size; + plan.transform(|plan| { + if let Some(async_exec) = plan.as_any().downcast_ref::() { + if async_exec.children().len() != 1 { + return internal_err!( + "Expected AsyncFuncExec to have exactly one child" + ); + } + let child = Arc::clone(async_exec.children()[0]); + let coalesce_exec = + Arc::new(CoalesceBatchesExec::new(child, target_batch_size)); + let coalesce_async_exec = plan.with_new_children(vec![coalesce_exec])?; + Ok(Transformed::yes(coalesce_async_exec)) + } else { + Ok(Transformed::no(plan)) + } + }) + .data() + } + + fn name(&self) -> &str { + "coalesce_async_exec_input" + } + + fn schema_check(&self) -> bool { + true + } +} diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index 5e76edad1f569..898386e2f9880 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -42,7 +42,6 @@ use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ physical_exprs_equal, EquivalenceProperties, PhysicalExpr, PhysicalExprRef, }; -use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; @@ -51,7 +50,7 @@ use datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::joins::{ CrossJoinExec, HashJoinExec, PartitionMode, SortMergeJoinExec, }; -use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::tree_node::PlanContext; @@ -296,7 +295,7 @@ pub fn adjust_input_keys_ordering( join_type, projection, mode, - null_equals_null, + null_equality, .. }) = plan.as_any().downcast_ref::() { @@ -315,7 +314,7 @@ pub fn adjust_input_keys_ordering( // TODO: although projection is not used in the join here, because projection pushdown is after enforce_distribution. Maybe we need to handle it later. Same as filter. projection.clone(), PartitionMode::Partitioned, - *null_equals_null, + *null_equality, ) .map(|e| Arc::new(e) as _) }; @@ -335,7 +334,7 @@ pub fn adjust_input_keys_ordering( left.schema().fields().len(), ) .unwrap_or_default(), - JoinType::RightSemi | JoinType::RightAnti => { + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { requirements.data.clone() } JoinType::Left @@ -365,7 +364,7 @@ pub fn adjust_input_keys_ordering( filter, join_type, sort_options, - null_equals_null, + null_equality, .. }) = plan.as_any().downcast_ref::() { @@ -380,7 +379,7 @@ pub fn adjust_input_keys_ordering( filter.clone(), *join_type, new_conditions.1, - *null_equals_null, + *null_equality, ) .map(|e| Arc::new(e) as _) }; @@ -408,7 +407,11 @@ pub fn adjust_input_keys_ordering( // For Projection, we need to transform the requirements to the columns before the Projection // And then to push down the requirements // Construct a mapping from new name to the original Column - let new_required = map_columns_before_projection(&requirements.data, expr); + let proj_exprs: Vec<_> = expr + .iter() + .map(|p| (Arc::clone(&p.expr), p.alias.clone())) + .collect(); + let new_required = map_columns_before_projection(&requirements.data, &proj_exprs); if new_required.len() == requirements.data.len() { requirements.children[0].data = new_required; } else { @@ -545,7 +548,10 @@ pub fn reorder_aggregate_keys( .map(|col| { let name = col.name(); let index = agg_schema.index_of(name)?; - Ok((Arc::new(Column::new(name, index)) as _, name.to_owned())) + Ok(ProjectionExpr { + expr: Arc::new(Column::new(name, index)) as _, + alias: name.to_owned(), + }) }) .collect::>>()?; let agg_fields = agg_schema.fields(); @@ -554,7 +560,10 @@ pub fn reorder_aggregate_keys( { let name = field.name(); let plan = Arc::new(Column::new(name, idx)) as _; - proj_exprs.push((plan, name.clone())) + proj_exprs.push(ProjectionExpr { + expr: plan, + alias: name.clone(), + }) } return ProjectionExec::try_new(proj_exprs, new_final_agg).map(|p| { PlanWithKeyRequirements::new(Arc::new(p), vec![], vec![agg_node]) @@ -617,7 +626,7 @@ pub fn reorder_join_keys_to_inputs( join_type, projection, mode, - null_equals_null, + null_equality, .. }) = plan_any.downcast_ref::() { @@ -643,7 +652,7 @@ pub fn reorder_join_keys_to_inputs( join_type, projection.clone(), PartitionMode::Partitioned, - *null_equals_null, + *null_equality, )?)); } } @@ -654,7 +663,7 @@ pub fn reorder_join_keys_to_inputs( filter, join_type, sort_options, - null_equals_null, + null_equality, .. }) = plan_any.downcast_ref::() { @@ -682,7 +691,7 @@ pub fn reorder_join_keys_to_inputs( filter.clone(), *join_type, new_sort_options, - *null_equals_null, + *null_equality, ) .map(|smj| Arc::new(smj) as _); } @@ -837,7 +846,7 @@ fn new_join_conditions( /// /// * `input`: Current node. /// * `n_target`: desired target partition number, if partition number of the -/// current executor is less than this value. Partition number will be increased. +/// current executor is less than this value. Partition number will be increased. /// /// # Returns /// @@ -880,7 +889,7 @@ fn add_roundrobin_on_top( /// * `input`: Current node. /// * `hash_exprs`: Stores Physical Exprs that are used during hashing. /// * `n_target`: desired target partition number, if partition number of the -/// current executor is less than this value. Partition number will be increased. +/// current executor is less than this value. Partition number will be increased. /// /// # Returns /// @@ -926,38 +935,34 @@ fn add_hash_on_top( Ok(input) } -/// Adds a [`SortPreservingMergeExec`] operator on top of input executor -/// to satisfy single distribution requirement. +/// Adds a [`SortPreservingMergeExec`] or a [`CoalescePartitionsExec`] operator +/// on top of the given plan node to satisfy a single partition requirement +/// while preserving ordering constraints. /// -/// # Arguments +/// # Parameters /// /// * `input`: Current node. /// /// # Returns /// -/// Updated node with an execution plan, where desired single -/// distribution is satisfied by adding [`SortPreservingMergeExec`]. -fn add_spm_on_top(input: DistributionContext) -> DistributionContext { - // Add SortPreservingMerge only when partition count is larger than 1. +/// Updated node with an execution plan, where the desired single distribution +/// requirement is satisfied. +fn add_merge_on_top(input: DistributionContext) -> DistributionContext { + // Apply only when the partition count is larger than one. if input.plan.output_partitioning().partition_count() > 1 { // When there is an existing ordering, we preserve ordering // when decreasing partitions. This will be un-done in the future // if any of the following conditions is true // - Preserving ordering is not helpful in terms of satisfying ordering requirements // - Usage of order preserving variants is not desirable - // (determined by flag `config.optimizer.bounded_order_preserving_variants`) - let should_preserve_ordering = input.plan.output_ordering().is_some(); - - let new_plan = if should_preserve_ordering { + // (determined by flag `config.optimizer.prefer_existing_sort`) + let new_plan = if let Some(req) = input.plan.output_ordering() { Arc::new(SortPreservingMergeExec::new( - input - .plan - .output_ordering() - .unwrap_or(&LexOrdering::default()) - .clone(), + req.clone(), Arc::clone(&input.plan), )) as _ } else { + // If there is no input order, we can simply coalesce partitions: Arc::new(CoalescePartitionsExec::new(Arc::clone(&input.plan))) as _ }; @@ -1018,7 +1023,7 @@ fn remove_dist_changing_operators( /// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", /// " DataSourceExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC], file_type=parquet", /// ``` -fn replace_order_preserving_variants( +pub fn replace_order_preserving_variants( mut context: DistributionContext, ) -> Result { context.children = context @@ -1035,7 +1040,9 @@ fn replace_order_preserving_variants( if is_sort_preserving_merge(&context.plan) { let child_plan = Arc::clone(&context.children[0].plan); - context.plan = Arc::new(CoalescePartitionsExec::new(child_plan)); + context.plan = Arc::new( + CoalescePartitionsExec::new(child_plan).with_fetch(context.plan.fetch()), + ); return Ok(context); } else if let Some(repartition) = context.plan.as_any().downcast_ref::() @@ -1112,7 +1119,8 @@ fn get_repartition_requirement_status( { // Decide whether adding a round robin is beneficial depending on // the statistical information we have on the number of rows: - let roundrobin_beneficial_stats = match child.statistics()?.num_rows { + let roundrobin_beneficial_stats = match child.partition_statistics(None)?.num_rows + { Precision::Exact(n_rows) => n_rows > batch_size, Precision::Inexact(n_rows) => !should_use_estimates || (n_rows > batch_size), Precision::Absent => true, @@ -1155,6 +1163,10 @@ fn get_repartition_requirement_status( /// operators to satisfy distribution requirements. Since this function /// takes care of such requirements, we should avoid manually adding data /// exchange operators in other places. +/// +/// This function is intended to be used in a bottom up traversal, as it +/// can first repartition (or newly partition) at the datasources -- these +/// source partitions may be later repartitioned with additional data exchange operators. pub fn ensure_distribution( dist_context: DistributionContext, config: &ConfigOptions, @@ -1244,6 +1256,10 @@ pub fn ensure_distribution( // When `repartition_file_scans` is set, attempt to increase // parallelism at the source. + // + // If repartitioning is not possible (a.k.a. None is returned from `ExecutionPlan::repartitioned`) + // then no repartitioning will have occurred. As the default implementation returns None, it is only + // specific physical plan nodes, such as certain datasources, which are repartitioned. if repartition_file_scans && roundrobin_beneficial_stats { if let Some(new_child) = child.plan.repartitioned(target_partitions, config)? @@ -1255,7 +1271,7 @@ pub fn ensure_distribution( // Satisfy the distribution requirement if it is unmet. match &requirement { Distribution::SinglePartition => { - child = add_spm_on_top(child); + child = add_merge_on_top(child); } Distribution::HashPartitioned(exprs) => { if add_roundrobin { @@ -1283,10 +1299,12 @@ pub fn ensure_distribution( // Either: // - Ordering requirement cannot be satisfied by preserving ordering through repartitions, or // - using order preserving variant is not desirable. + let sort_req = required_input_ordering.into_single(); let ordering_satisfied = child .plan .equivalence_properties() - .ordering_satisfy_requirement(&required_input_ordering); + .ordering_satisfy_requirement(sort_req.clone())?; + if (!ordering_satisfied || !order_preserving_variants_desirable) && child.data { @@ -1297,9 +1315,12 @@ pub fn ensure_distribution( // Make sure to satisfy ordering requirement: child = add_sort_above_with_check( child, - required_input_ordering.clone(), - None, - ); + sort_req, + plan.as_any() + .downcast_ref::() + .map(|output| output.fetch()) + .unwrap_or(None), + )?; } } // Stop tracking distribution changing operators diff --git a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs index 20733b65692fc..8a71b28486a2a 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs @@ -46,6 +46,7 @@ use crate::enforce_sorting::replace_with_order_preserving_variants::{ use crate::enforce_sorting::sort_pushdown::{ assign_initial_requirements, pushdown_sorts, SortPushDown, }; +use crate::output_requirements::OutputRequirementExec; use crate::utils::{ add_sort_above, add_sort_above_with_check, is_coalesce_partitions, is_limit, is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, @@ -191,14 +192,20 @@ fn update_coalesce_ctx_children( } /// Performs optimizations based upon a series of subrules. -/// /// Refer to each subrule for detailed descriptions of the optimizations performed: -/// [`ensure_sorting`], [`parallelize_sorts`], [`replace_with_order_preserving_variants()`], -/// and [`pushdown_sorts`]. -/// /// Subrule application is ordering dependent. /// -/// The subrule `parallelize_sorts` is only applied if `repartition_sorts` is enabled. +/// Optimizer consists of 5 main parts which work sequentially +/// 1. [`ensure_sorting`] Works down-to-top to be able to remove unnecessary [`SortExec`]s, [`SortPreservingMergeExec`]s +/// add [`SortExec`]s if necessary by a requirement and adjusts window operators. +/// 2. [`parallelize_sorts`] (Optional, depends on the `repartition_sorts` configuration) +/// Responsible to identify and remove unnecessary partition unifier operators +/// such as [`SortPreservingMergeExec`], [`CoalescePartitionsExec`] follows [`SortExec`]s does possible simplifications. +/// 3. [`replace_with_order_preserving_variants()`] Replaces with alternative operators, for example can merge +/// a [`SortExec`] and a [`CoalescePartitionsExec`] into one [`SortPreservingMergeExec`] +/// or a [`SortExec`] + [`RepartitionExec`] combination into an order preserving [`RepartitionExec`] +/// 4. [`sort_pushdown`] Works top-down. Responsible to push down sort operators as deep as possible in the plan. +/// 5. `replace_with_partial_sort` Checks if it's possible to replace [`SortExec`]s with [`PartialSortExec`] operators impl PhysicalOptimizerRule for EnforceSorting { fn optimize( &self, @@ -251,87 +258,93 @@ impl PhysicalOptimizerRule for EnforceSorting { } } +/// Only interested with [`SortExec`]s and their unbounded children. +/// If the plan is not a [`SortExec`] or its child is not unbounded, returns the original plan. +/// Otherwise, by checking the requirement satisfaction searches for a replacement chance. +/// If there's one replaces the [`SortExec`] plan with a [`PartialSortExec`] fn replace_with_partial_sort( plan: Arc, ) -> Result> { let plan_any = plan.as_any(); - if let Some(sort_plan) = plan_any.downcast_ref::() { - let child = Arc::clone(sort_plan.children()[0]); - if !child.boundedness().is_unbounded() { - return Ok(plan); - } + let Some(sort_plan) = plan_any.downcast_ref::() else { + return Ok(plan); + }; - // here we're trying to find the common prefix for sorted columns that is required for the - // sort and already satisfied by the given ordering - let child_eq_properties = child.equivalence_properties(); - let sort_req = LexRequirement::from(sort_plan.expr().clone()); + // It's safe to get first child of the SortExec + let child = Arc::clone(sort_plan.children()[0]); + if !child.boundedness().is_unbounded() { + return Ok(plan); + } - let mut common_prefix_length = 0; - while child_eq_properties.ordering_satisfy_requirement(&LexRequirement { - inner: sort_req[0..common_prefix_length + 1].to_vec(), - }) { - common_prefix_length += 1; - } - if common_prefix_length > 0 { - return Ok(Arc::new( - PartialSortExec::new( - LexOrdering::new(sort_plan.expr().to_vec()), - Arc::clone(sort_plan.input()), - common_prefix_length, - ) - .with_preserve_partitioning(sort_plan.preserve_partitioning()) - .with_fetch(sort_plan.fetch()), - )); - } + // Here we're trying to find the common prefix for sorted columns that is required for the + // sort and already satisfied by the given ordering + let child_eq_properties = child.equivalence_properties(); + let sort_exprs = sort_plan.expr().clone(); + + let mut common_prefix_length = 0; + while child_eq_properties + .ordering_satisfy(sort_exprs[0..common_prefix_length + 1].to_vec())? + { + common_prefix_length += 1; + } + if common_prefix_length > 0 { + return Ok(Arc::new( + PartialSortExec::new( + sort_exprs, + Arc::clone(sort_plan.input()), + common_prefix_length, + ) + .with_preserve_partitioning(sort_plan.preserve_partitioning()) + .with_fetch(sort_plan.fetch()), + )); } Ok(plan) } -/// Transform [`CoalescePartitionsExec`] + [`SortExec`] into -/// [`SortExec`] + [`SortPreservingMergeExec`] as illustrated below: +/// Transform [`CoalescePartitionsExec`] + [`SortExec`] cascades into [`SortExec`] +/// + [`SortPreservingMergeExec`] cascades, as illustrated below. /// -/// The [`CoalescePartitionsExec`] + [`SortExec`] cascades -/// combine the partitions first, and then sort: +/// A [`CoalescePartitionsExec`] + [`SortExec`] cascade combines partitions +/// first, and then sorts: /// ```text -/// ┌ ─ ─ ─ ─ ─ ┐ -/// ┌─┬─┬─┐ -/// ││B│A│D│... ├──┐ -/// └─┴─┴─┘ │ +/// ┌ ─ ─ ─ ─ ─ ┐ +/// ┌─┬─┬─┐ +/// ││B│A│D│... ├──┐ +/// └─┴─┴─┘ │ /// └ ─ ─ ─ ─ ─ ┘ │ ┌────────────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ┐ -/// Partition 1 │ │ Coalesce │ ┌─┬─┬─┬─┬─┐ │ │ ┌─┬─┬─┬─┬─┐ +/// Partition 1 │ │ Coalesce │ ┌─┬─┬─┬─┬─┐ │ │ ┌─┬─┬─┬─┬─┐ /// ├──▶(no ordering guarantees)│──▶││B│E│A│D│C│...───▶ Sort ├───▶││A│B│C│D│E│... │ -/// │ │ │ └─┴─┴─┴─┴─┘ │ │ └─┴─┴─┴─┴─┘ +/// │ │ │ └─┴─┴─┴─┴─┘ │ │ └─┴─┴─┴─┴─┘ /// ┌ ─ ─ ─ ─ ─ ┐ │ └────────────────────────┘ └ ─ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ─ ─ ┘ -/// ┌─┬─┐ │ Partition Partition -/// ││E│C│ ... ├──┘ -/// └─┴─┘ -/// └ ─ ─ ─ ─ ─ ┘ -/// Partition 2 -/// ``` +/// ┌─┬─┐ │ Partition Partition +/// ││E│C│ ... ├──┘ +/// └─┴─┘ +/// └ ─ ─ ─ ─ ─ ┘ +/// Partition 2 +/// ``` /// /// -/// The [`SortExec`] + [`SortPreservingMergeExec`] cascades -/// sorts each partition first, then merge partitions while retaining the sort: +/// A [`SortExec`] + [`SortPreservingMergeExec`] cascade sorts each partition +/// first, then merges partitions while preserving the sort: /// ```text -/// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐ -/// ┌─┬─┬─┐ │ │ ┌─┬─┬─┐ -/// ││B│A│D│... │──▶│ Sort │──▶││A│B│D│... │──┐ -/// └─┴─┴─┘ │ │ └─┴─┴─┘ │ +/// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐ +/// ┌─┬─┬─┐ │ │ ┌─┬─┬─┐ +/// ││B│A│D│... │──▶│ Sort │──▶││A│B│D│... │──┐ +/// └─┴─┴─┘ │ │ └─┴─┴─┘ │ /// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘ │ ┌─────────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ┐ -/// Partition 1 Partition 1 │ │ │ ┌─┬─┬─┬─┬─┐ +/// Partition 1 Partition 1 │ │ │ ┌─┬─┬─┬─┬─┐ /// ├──▶ SortPreservingMerge ├───▶││A│B│C│D│E│... │ -/// │ │ │ └─┴─┴─┴─┴─┘ +/// │ │ │ └─┴─┴─┴─┴─┘ /// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐ │ └─────────────────────┘ └ ─ ─ ─ ─ ─ ─ ─ ┘ -/// ┌─┬─┐ │ │ ┌─┬─┐ │ Partition -/// ││E│C│ ... │──▶│ Sort ├──▶││C│E│ ... │──┘ -/// └─┴─┘ │ │ └─┴─┘ -/// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘ -/// Partition 2 Partition 2 +/// ┌─┬─┐ │ │ ┌─┬─┐ │ Partition +/// ││E│C│ ... │──▶│ Sort ├──▶││C│E│ ... │──┘ +/// └─┴─┘ │ │ └─┴─┘ +/// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘ +/// Partition 2 Partition 2 /// ``` /// -/// The latter [`SortExec`] + [`SortPreservingMergeExec`] cascade performs the -/// sort first on a per-partition basis, thereby parallelizing the sort. -/// +/// The latter [`SortExec`] + [`SortPreservingMergeExec`] cascade performs +/// sorting first on a per-partition basis, thereby parallelizing the sort. /// /// The outcome is that plans of the form /// ```text @@ -348,16 +361,32 @@ fn replace_with_partial_sort( /// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", /// ``` /// by following connections from [`CoalescePartitionsExec`]s to [`SortExec`]s. -/// By performing sorting in parallel, we can increase performance in some scenarios. +/// By performing sorting in parallel, we can increase performance in some +/// scenarios. /// -/// This requires that there are no nodes between the [`SortExec`] and [`CoalescePartitionsExec`] -/// which require single partitioning. Do not parallelize when the following scenario occurs: +/// This optimization requires that there are no nodes between the [`SortExec`] +/// and the [`CoalescePartitionsExec`], which requires single partitioning. Do +/// not parallelize when the following scenario occurs: /// ```text /// "SortExec: expr=\[a@0 ASC\]", /// " ...nodes requiring single partitioning..." /// " CoalescePartitionsExec", /// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", /// ``` +/// +/// **Steps** +/// 1. Checks if the plan is either a [`SortExec`], a [`SortPreservingMergeExec`], +/// or a [`CoalescePartitionsExec`]. Otherwise, does nothing. +/// 2. If the plan is a [`SortExec`] or a final [`SortPreservingMergeExec`] +/// (i.e. output partitioning is 1): +/// - Check for [`CoalescePartitionsExec`] in children. If found, check if +/// it can be removed (with possible [`RepartitionExec`]s). If so, remove +/// (see `remove_bottleneck_in_subplan`). +/// - If the plan is satisfying the ordering requirements, add a `SortExec`. +/// - Add an SPM above the plan and return. +/// 3. If the plan is a [`CoalescePartitionsExec`]: +/// - Check if it can be removed (with possible [`RepartitionExec`]s). +/// If so, remove (see `remove_bottleneck_in_subplan`). pub fn parallelize_sorts( mut requirements: PlanWithCorrespondingCoalescePartitions, ) -> Result> { @@ -388,7 +417,7 @@ pub fn parallelize_sorts( // deals with the children and their children and so on. requirements = requirements.children.swap_remove(0); - requirements = add_sort_above_with_check(requirements, sort_reqs, fetch); + requirements = add_sort_above_with_check(requirements, sort_reqs, fetch)?; let spm = SortPreservingMergeExec::new(sort_exprs, Arc::clone(&requirements.plan)); @@ -400,6 +429,7 @@ pub fn parallelize_sorts( ), )) } else if is_coalesce_partitions(&requirements.plan) { + let fetch = requirements.plan.fetch(); // There is an unnecessary `CoalescePartitionsExec` in the plan. // This will handle the recursive `CoalescePartitionsExec` plans. requirements = remove_bottleneck_in_subplan(requirements)?; @@ -408,7 +438,10 @@ pub fn parallelize_sorts( Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( - Arc::new(CoalescePartitionsExec::new(Arc::clone(&requirements.plan))), + Arc::new( + CoalescePartitionsExec::new(Arc::clone(&requirements.plan)) + .with_fetch(fetch), + ), false, vec![requirements], ), @@ -420,6 +453,25 @@ pub fn parallelize_sorts( /// This function enforces sorting requirements and makes optimizations without /// violating these requirements whenever possible. Requires a bottom-up traversal. +/// +/// **Steps** +/// 1. Analyze if there are any immediate removals of [`SortExec`]s. If so, +/// removes them (see `analyze_immediate_sort_removal`). +/// 2. For each child of the plan, if the plan requires an input ordering: +/// - Checks if ordering is satisfied with the child. If not: +/// - If the child has an output ordering, removes the unnecessary +/// `SortExec`. +/// - Adds sort above the child plan. +/// - (Plan not requires input ordering) +/// - Checks if the `SortExec` is neutralized in the plan. If so, +/// removes it. +/// 3. Check and modify window operator: +/// - Checks if the plan is a window operator, and connected with a sort. +/// If so, either tries to update the window definition or removes +/// unnecessary [`SortExec`]s (see `adjust_window_sort_removal`). +/// 4. Check and remove possibly unnecessary SPM: +/// - Checks if the plan is SPM and child 1 output partitions, if so +/// decides this SPM is unnecessary and removes it from the plan. pub fn ensure_sorting( mut requirements: PlanWithCorrespondingSort, ) -> Result> { @@ -429,7 +481,7 @@ pub fn ensure_sorting( if requirements.children.is_empty() { return Ok(Transformed::no(requirements)); } - let maybe_requirements = analyze_immediate_sort_removal(requirements); + let maybe_requirements = analyze_immediate_sort_removal(requirements)?; requirements = if !maybe_requirements.transformed { maybe_requirements.data } else { @@ -448,12 +500,20 @@ pub fn ensure_sorting( if let Some(required) = required_ordering { let eq_properties = child.plan.equivalence_properties(); - if !eq_properties.ordering_satisfy_requirement(&required) { + let req = required.into_single(); + if !eq_properties.ordering_satisfy_requirement(req.clone())? { // Make sure we preserve the ordering requirements: if physical_ordering.is_some() { child = update_child_to_remove_unnecessary_sort(idx, child, plan)?; } - child = add_sort_above(child, required, None); + child = add_sort_above( + child, + req, + plan.as_any() + .downcast_ref::() + .map(|output| output.fetch()) + .unwrap_or(None), + ); child = update_sort_ctx_children_data(child, true)?; } } else if physical_ordering.is_none() @@ -489,60 +549,56 @@ pub fn ensure_sorting( update_sort_ctx_children_data(requirements, false).map(Transformed::yes) } -/// Analyzes a given [`SortExec`] (`plan`) to determine whether its input -/// already has a finer ordering than it enforces. +/// Analyzes if there are any immediate sort removals by checking the `SortExec`s +/// and their ordering requirement satisfactions with children +/// If the sort is unnecessary, either replaces it with [`SortPreservingMergeExec`]/`LimitExec` +/// or removes the [`SortExec`]. +/// Otherwise, returns the original plan fn analyze_immediate_sort_removal( mut node: PlanWithCorrespondingSort, -) -> Transformed { - if let Some(sort_exec) = node.plan.as_any().downcast_ref::() { - let sort_input = sort_exec.input(); - // If this sort is unnecessary, we should remove it: - if sort_input.equivalence_properties().ordering_satisfy( - sort_exec - .properties() - .output_ordering() - .unwrap_or(LexOrdering::empty()), - ) { - node.plan = if !sort_exec.preserve_partitioning() - && sort_input.output_partitioning().partition_count() > 1 - { - // Replace the sort with a sort-preserving merge: - let expr = LexOrdering::new(sort_exec.expr().to_vec()); - Arc::new( - SortPreservingMergeExec::new(expr, Arc::clone(sort_input)) - .with_fetch(sort_exec.fetch()), - ) as _ +) -> Result> { + let Some(sort_exec) = node.plan.as_any().downcast_ref::() else { + return Ok(Transformed::no(node)); + }; + let sort_input = sort_exec.input(); + // Check if the sort is unnecessary: + let properties = sort_exec.properties(); + if let Some(ordering) = properties.output_ordering().cloned() { + let eqp = sort_input.equivalence_properties(); + if !eqp.ordering_satisfy(ordering)? { + return Ok(Transformed::no(node)); + } + } + node.plan = if !sort_exec.preserve_partitioning() + && sort_input.output_partitioning().partition_count() > 1 + { + // Replace the sort with a sort-preserving merge: + Arc::new( + SortPreservingMergeExec::new( + sort_exec.expr().clone(), + Arc::clone(sort_input), + ) + .with_fetch(sort_exec.fetch()), + ) as _ + } else { + // Remove the sort: + node.children = node.children.swap_remove(0).children; + if let Some(fetch) = sort_exec.fetch() { + // If the sort has a fetch, we need to add a limit: + if properties.output_partitioning().partition_count() == 1 { + Arc::new(GlobalLimitExec::new(Arc::clone(sort_input), 0, Some(fetch))) } else { - // Remove the sort: - node.children = node.children.swap_remove(0).children; - if let Some(fetch) = sort_exec.fetch() { - // If the sort has a fetch, we need to add a limit: - if sort_exec - .properties() - .output_partitioning() - .partition_count() - == 1 - { - Arc::new(GlobalLimitExec::new( - Arc::clone(sort_input), - 0, - Some(fetch), - )) - } else { - Arc::new(LocalLimitExec::new(Arc::clone(sort_input), fetch)) - } - } else { - Arc::clone(sort_input) - } - }; - for child in node.children.iter_mut() { - child.data = false; + Arc::new(LocalLimitExec::new(Arc::clone(sort_input), fetch)) } - node.data = false; - return Transformed::yes(node); + } else { + Arc::clone(sort_input) } + }; + for child in node.children.iter_mut() { + child.data = false; } - Transformed::no(node) + node.data = false; + Ok(Transformed::yes(node)) } /// Adjusts a [`WindowAggExec`] or a [`BoundedWindowAggExec`] to determine @@ -583,15 +639,13 @@ fn adjust_window_sort_removal( } else { // We were unable to change the window to accommodate the input, so we // will insert a sort. - let reqs = window_tree - .plan - .required_input_ordering() - .swap_remove(0) - .unwrap_or_default(); + let reqs = window_tree.plan.required_input_ordering().swap_remove(0); // Satisfy the ordering requirement so that the window can run: let mut child_node = window_tree.children.swap_remove(0); - child_node = add_sort_above(child_node, reqs, None); + if let Some(reqs) = reqs { + child_node = add_sort_above(child_node, reqs.into_single(), None); + } let child_plan = Arc::clone(&child_node.plan); window_tree.children.push(child_node); @@ -738,8 +792,7 @@ fn remove_corresponding_sort_from_sub_plan( let fetch = plan.fetch(); let plan = if let Some(ordering) = plan.output_ordering() { Arc::new( - SortPreservingMergeExec::new(LexOrdering::new(ordering.to_vec()), plan) - .with_fetch(fetch), + SortPreservingMergeExec::new(ordering.clone(), plan).with_fetch(fetch), ) as _ } else { Arc::new(CoalescePartitionsExec::new(plan)) as _ diff --git a/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs b/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs index 2c5c0d4d510ec..b536e7960208e 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs @@ -27,8 +27,7 @@ use crate::utils::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::Transformed; -use datafusion_common::Result; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_common::{internal_err, Result}; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::repartition::RepartitionExec; @@ -93,7 +92,7 @@ pub fn update_order_preservation_ctx_children_data(opc: &mut OrderPreservationCo /// inside `sort_input` with their order-preserving variants. This will /// generate an alternative plan, which will be accepted or rejected later on /// depending on whether it helps us remove a `SortExec`. -fn plan_with_order_preserving_variants( +pub fn plan_with_order_preserving_variants( mut sort_input: OrderPreservationContext, // Flag indicating that it is desirable to replace `RepartitionExec`s with // `SortPreservingRepartitionExec`s: @@ -138,6 +137,19 @@ fn plan_with_order_preserving_variants( } else if is_coalesce_partitions(&sort_input.plan) && is_spm_better { let child = &sort_input.children[0].plan; if let Some(ordering) = child.output_ordering() { + let mut fetch = fetch; + if let Some(coalesce_fetch) = sort_input.plan.fetch() { + if let Some(sort_fetch) = fetch { + if coalesce_fetch < sort_fetch { + return internal_err!( + "CoalescePartitionsExec fetch [{:?}] should be greater than or equal to SortExec fetch [{:?}]", coalesce_fetch, sort_fetch + ); + } + } else { + // If the sort node does not have a fetch, we need to keep the coalesce node's fetch. + fetch = Some(coalesce_fetch); + } + }; // When the input of a `CoalescePartitionsExec` has an ordering, // replace it with a `SortPreservingMergeExec` if appropriate: let spm = SortPreservingMergeExec::new(ordering.clone(), Arc::clone(child)) @@ -154,7 +166,7 @@ fn plan_with_order_preserving_variants( /// Calculates the updated plan by replacing operators that preserve ordering /// inside `sort_input` with their order-breaking variants. This will restore /// the original plan modified by [`plan_with_order_preserving_variants`]. -fn plan_with_order_breaking_variants( +pub fn plan_with_order_breaking_variants( mut sort_input: OrderPreservationContext, ) -> Result { let plan = &sort_input.plan; @@ -166,18 +178,17 @@ fn plan_with_order_breaking_variants( .map(|(node, maintains, required_ordering)| { // Replace with non-order preserving variants as long as ordering is // not required by intermediate operators: - if maintains - && (is_sort_preserving_merge(plan) - || !required_ordering.is_some_and(|required_ordering| { - node.plan - .equivalence_properties() - .ordering_satisfy_requirement(&required_ordering) - })) - { - plan_with_order_breaking_variants(node) - } else { - Ok(node) + if !maintains { + return Ok(node); + } else if is_sort_preserving_merge(plan) { + return plan_with_order_breaking_variants(node); + } else if let Some(required_ordering) = required_ordering { + let eqp = node.plan.equivalence_properties(); + if eqp.ordering_satisfy_requirement(required_ordering.into_single())? { + return Ok(node); + } } + plan_with_order_breaking_variants(node) }) .collect::>()?; sort_input.data = false; @@ -189,10 +200,12 @@ fn plan_with_order_breaking_variants( let partitioning = plan.output_partitioning().clone(); sort_input.plan = Arc::new(RepartitionExec::try_new(child, partitioning)?) as _; } else if is_sort_preserving_merge(plan) { - // Replace `SortPreservingMergeExec` with a `CoalescePartitionsExec`: + // Replace `SortPreservingMergeExec` with a `CoalescePartitionsExec` + // SPM may have `fetch`, so pass it to the `CoalescePartitionsExec` let child = Arc::clone(&sort_input.children[0].plan); - let coalesce = CoalescePartitionsExec::new(child); - sort_input.plan = Arc::new(coalesce) as _; + let coalesce = + Arc::new(CoalescePartitionsExec::new(child).with_fetch(plan.fetch())); + sort_input.plan = coalesce; } else { return sort_input.update_plan_from_children(); } @@ -264,25 +277,18 @@ pub fn replace_with_order_preserving_variants( )?; // If the alternate plan makes this sort unnecessary, accept the alternate: - if alternate_plan - .plan - .equivalence_properties() - .ordering_satisfy( - requirements - .plan - .output_ordering() - .unwrap_or(LexOrdering::empty()), - ) - { - for child in alternate_plan.children.iter_mut() { - child.data = false; + if let Some(ordering) = requirements.plan.output_ordering() { + let eqp = alternate_plan.plan.equivalence_properties(); + if !eqp.ordering_satisfy(ordering.clone())? { + // The alternate plan does not help, use faster order-breaking variants: + alternate_plan = plan_with_order_breaking_variants(alternate_plan)?; + alternate_plan.data = false; + requirements.children = vec![alternate_plan]; + return Ok(Transformed::yes(requirements)); } - Ok(Transformed::yes(alternate_plan)) - } else { - // The alternate plan does not help, use faster order-breaking variants: - alternate_plan = plan_with_order_breaking_variants(alternate_plan)?; - alternate_plan.data = false; - requirements.children = vec![alternate_plan]; - Ok(Transformed::yes(requirements)) } + for child in alternate_plan.children.iter_mut() { + child.data = false; + } + Ok(Transformed::yes(alternate_plan)) } diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs index 2e20608d0e9ed..6e4e784866129 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs @@ -24,12 +24,18 @@ use crate::utils::{ use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{plan_err, HashSet, JoinSide, Result}; +use datafusion_common::{internal_err, HashSet, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr::PhysicalSortRequirement; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr::{ + add_offset_to_physical_sort_exprs, EquivalenceProperties, +}; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr, + PhysicalSortRequirement, +}; +use datafusion_physical_plan::execution_plan::CardinalityEffect; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::{ calculate_join_output_ordering, ColumnIndex, @@ -50,7 +56,7 @@ use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; /// [`EnforceSorting`]: crate::enforce_sorting::EnforceSorting #[derive(Default, Clone, Debug)] pub struct ParentRequirements { - ordering_requirement: Option, + ordering_requirement: Option, fetch: Option, } @@ -69,6 +75,7 @@ pub fn assign_initial_requirements(sort_push_down: &mut SortPushDown) { } } +/// Tries to push down the sort requirements as far as possible, if decides a `SortExec` is unnecessary removes it. pub fn pushdown_sorts(sort_push_down: SortPushDown) -> Result { sort_push_down .transform_down(pushdown_sorts_helper) @@ -87,91 +94,108 @@ fn min_fetch(f1: Option, f2: Option) -> Option { fn pushdown_sorts_helper( mut sort_push_down: SortPushDown, ) -> Result> { - let plan = &sort_push_down.plan; - let parent_reqs = sort_push_down - .data - .ordering_requirement - .clone() - .unwrap_or_default(); - let satisfy_parent = plan - .equivalence_properties() - .ordering_satisfy_requirement(&parent_reqs); - - if is_sort(plan) { - let current_sort_fetch = plan.fetch(); - let parent_req_fetch = sort_push_down.data.fetch; - - let current_plan_reqs = plan - .output_ordering() - .cloned() - .map(LexRequirement::from) - .unwrap_or_default(); - let parent_is_stricter = plan - .equivalence_properties() - .requirements_compatible(&parent_reqs, ¤t_plan_reqs); - let current_is_stricter = plan - .equivalence_properties() - .requirements_compatible(¤t_plan_reqs, &parent_reqs); + let plan = sort_push_down.plan; + let parent_fetch = sort_push_down.data.fetch; - if !satisfy_parent && !parent_is_stricter { - // This new sort has different requirements than the ordering being pushed down. - // 1. add a `SortExec` here for the pushed down ordering (parent reqs). - // 2. continue sort pushdown, but with the new ordering of the new sort. + let Some(parent_requirement) = sort_push_down.data.ordering_requirement.clone() + else { + // If there are no ordering requirements from the parent, nothing to do + // unless we have a sort. + if is_sort(&plan) { + let Some(sort_ordering) = plan.output_ordering().cloned() else { + return internal_err!("SortExec should have output ordering"); + }; + // The sort is unnecessary, just propagate the stricter fetch and + // ordering requirements. + let fetch = min_fetch(plan.fetch(), parent_fetch); + sort_push_down = sort_push_down + .children + .swap_remove(0) + .update_plan_from_children()?; + sort_push_down.data.fetch = fetch; + sort_push_down.data.ordering_requirement = + Some(OrderingRequirements::from(sort_ordering)); + // Recursive call to helper, so it doesn't transform_down and miss + // the new node (previous child of sort): + return pushdown_sorts_helper(sort_push_down); + } + sort_push_down.plan = plan; + return Ok(Transformed::no(sort_push_down)); + }; + + let eqp = plan.equivalence_properties(); + let satisfy_parent = + eqp.ordering_satisfy_requirement(parent_requirement.first().clone())?; - // remove current sort (which will be the new ordering to pushdown) - let new_reqs = current_plan_reqs; - sort_push_down = sort_push_down.children.swap_remove(0); - sort_push_down = sort_push_down.update_plan_from_children()?; // changed plan + if is_sort(&plan) { + let Some(sort_ordering) = plan.output_ordering().cloned() else { + return internal_err!("SortExec should have output ordering"); + }; - // add back sort exec matching parent - sort_push_down = - add_sort_above(sort_push_down, parent_reqs, parent_req_fetch); + let sort_fetch = plan.fetch(); + let parent_is_stricter = eqp.requirements_compatible( + parent_requirement.first().clone(), + sort_ordering.clone().into(), + ); - // make pushdown requirements be the new ones. + // Remove the current sort as we are either going to prove that it is + // unnecessary, or replace it with a stricter sort. + sort_push_down = sort_push_down + .children + .swap_remove(0) + .update_plan_from_children()?; + if !satisfy_parent && !parent_is_stricter { + // The sort was imposing a different ordering than the one being + // pushed down. Replace it with a sort that matches the pushed-down + // ordering, and continue the pushdown. + // Add back the sort: + sort_push_down = add_sort_above( + sort_push_down, + parent_requirement.into_single(), + parent_fetch, + ); + // Update pushdown requirements: sort_push_down.children[0].data = ParentRequirements { - ordering_requirement: Some(new_reqs), - fetch: current_sort_fetch, + ordering_requirement: Some(OrderingRequirements::from(sort_ordering)), + fetch: sort_fetch, }; + return Ok(Transformed::yes(sort_push_down)); } else { - // Don't add a SortExec - // Do update what sort requirements to keep pushing down - - // remove current sort, and get the sort's child - sort_push_down = sort_push_down.children.swap_remove(0); - sort_push_down = sort_push_down.update_plan_from_children()?; // changed plan - - // set the stricter fetch - sort_push_down.data.fetch = min_fetch(current_sort_fetch, parent_req_fetch); - - // set the stricter ordering - if current_is_stricter { - sort_push_down.data.ordering_requirement = Some(current_plan_reqs); + // Sort was unnecessary, just propagate the stricter fetch and + // ordering requirements: + sort_push_down.data.fetch = min_fetch(sort_fetch, parent_fetch); + let current_is_stricter = eqp.requirements_compatible( + sort_ordering.clone().into(), + parent_requirement.first().clone(), + ); + sort_push_down.data.ordering_requirement = if current_is_stricter { + Some(OrderingRequirements::from(sort_ordering)) } else { - sort_push_down.data.ordering_requirement = Some(parent_reqs); - } - - // recursive call to helper, so it doesn't transform_down and miss the new node (previous child of sort) + Some(parent_requirement) + }; + // Recursive call to helper, so it doesn't transform_down and miss + // the new node (previous child of sort): return pushdown_sorts_helper(sort_push_down); } - } else if parent_reqs.is_empty() { - // note: this `satisfy_parent`, but we don't want to push down anything. - // Nothing to do. - return Ok(Transformed::no(sort_push_down)); - } else if satisfy_parent { + } + + sort_push_down.plan = plan; + if satisfy_parent { // For non-sort operators which satisfy ordering: - let reqs = plan.required_input_ordering(); - let parent_req_fetch = sort_push_down.data.fetch; + let reqs = sort_push_down.plan.required_input_ordering(); for (child, order) in sort_push_down.children.iter_mut().zip(reqs) { child.data.ordering_requirement = order; - child.data.fetch = min_fetch(parent_req_fetch, child.data.fetch); + child.data.fetch = min_fetch(parent_fetch, child.data.fetch); } - } else if let Some(adjusted) = pushdown_requirement_to_children(plan, &parent_reqs)? { - // For operators that can take a sort pushdown. - - // Continue pushdown, with updated requirements: - let parent_fetch = sort_push_down.data.fetch; - let current_fetch = plan.fetch(); + } else if let Some(adjusted) = pushdown_requirement_to_children( + &sort_push_down.plan, + parent_requirement.clone(), + parent_fetch, + )? { + // For operators that can take a sort pushdown, continue with updated + // requirements: + let current_fetch = sort_push_down.plan.fetch(); for (child, order) in sort_push_down.children.iter_mut().zip(adjusted) { child.data.ordering_requirement = order; child.data.fetch = min_fetch(current_fetch, parent_fetch); @@ -179,16 +203,13 @@ fn pushdown_sorts_helper( sort_push_down.data.ordering_requirement = None; } else { // Can not push down requirements, add new `SortExec`: - let sort_reqs = sort_push_down - .data - .ordering_requirement - .clone() - .unwrap_or_default(); - let fetch = sort_push_down.data.fetch; - sort_push_down = add_sort_above(sort_push_down, sort_reqs, fetch); + sort_push_down = add_sort_above( + sort_push_down, + parent_requirement.into_single(), + parent_fetch, + ); assign_initial_requirements(&mut sort_push_down); } - Ok(Transformed::yes(sort_push_down)) } @@ -196,21 +217,52 @@ fn pushdown_sorts_helper( /// If sort cannot be pushed down, return None. fn pushdown_requirement_to_children( plan: &Arc, - parent_required: &LexRequirement, -) -> Result>>> { + parent_required: OrderingRequirements, + parent_fetch: Option, +) -> Result>>> { + // If there is a limit on the parent plan we cannot push it down through operators that change the cardinality. + // E.g. consider if LIMIT 2 is applied below a FilteExec that filters out 1/2 of the rows we'll end up with 1 row instead of 2. + // If the LIMIT is applied after the FilterExec and the FilterExec returns > 2 rows we'll end up with 2 rows (correct). + if parent_fetch.is_some() && !plan.supports_limit_pushdown() { + return Ok(None); + } + // Note: we still need to check the cardinality effect of the plan here, because the + // limit pushdown is not always safe, even if the plan supports it. Here's an example: + // + // UnionExec advertises `supports_limit_pushdown() == true` because it can + // forward a LIMIT k to each of its children—i.e. apply “LIMIT k” separately + // on each branch before merging them together. + // + // However, UnionExec’s `cardinality_effect() == GreaterEqual` (it sums up + // all child row counts), so pushing a global TopK/LIMIT through it would + // break the semantics of “take the first k rows of the combined result.” + // + // For example, with two branches A and B and k = 3: + // — Global LIMIT: take the first 3 rows from (A ∪ B) after merging. + // — Pushed down: take 3 from A, 3 from B, then merge → up to 6 rows! + // + // That’s why we still block on cardinality: even though UnionExec can + // push a LIMIT to its children, its GreaterEqual effect means it cannot + // preserve the global TopK semantics. + if parent_fetch.is_some() { + match plan.cardinality_effect() { + CardinalityEffect::Equal => { + // safe: only true sources (e.g. CoalesceBatchesExec, ProjectionExec) pass + } + _ => return Ok(None), + } + } + let maintains_input_order = plan.maintains_input_order(); if is_window(plan) { - let required_input_ordering = plan.required_input_ordering(); - let request_child = required_input_ordering[0].clone().unwrap_or_default(); + let mut required_input_ordering = plan.required_input_ordering(); + let maybe_child_requirement = required_input_ordering.swap_remove(0); let child_plan = plan.children().swap_remove(0); - - match determine_children_requirement(parent_required, &request_child, child_plan) - { - RequirementsCompatibility::Satisfy => { - let req = (!request_child.is_empty()) - .then(|| LexRequirement::new(request_child.to_vec())); - Ok(Some(vec![req])) - } + let Some(child_req) = maybe_child_requirement else { + return Ok(None); + }; + match determine_children_requirement(&parent_required, &child_req, child_plan) { + RequirementsCompatibility::Satisfy => Ok(Some(vec![Some(child_req)])), RequirementsCompatibility::Compatible(adjusted) => { // If parent requirements are more specific than output ordering // of the window plan, then we can deduce that the parent expects @@ -218,7 +270,7 @@ fn pushdown_requirement_to_children( // that's the case, we block the pushdown of sort operation. if !plan .equivalence_properties() - .ordering_satisfy_requirement(parent_required) + .ordering_satisfy_requirement(parent_required.into_single())? { return Ok(None); } @@ -228,82 +280,71 @@ fn pushdown_requirement_to_children( RequirementsCompatibility::NonCompatible => Ok(None), } } else if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let sort_req = LexRequirement::from( - sort_exec - .properties() - .output_ordering() - .cloned() - .unwrap_or(LexOrdering::default()), - ); - if sort_exec + let Some(sort_ordering) = sort_exec.properties().output_ordering().cloned() + else { + return internal_err!("SortExec should have output ordering"); + }; + sort_exec .properties() .eq_properties - .requirements_compatible(parent_required, &sort_req) - { - debug_assert!(!parent_required.is_empty()); - Ok(Some(vec![Some(LexRequirement::new( - parent_required.to_vec(), - ))])) - } else { - Ok(None) - } + .requirements_compatible( + parent_required.first().clone(), + sort_ordering.into(), + ) + .then(|| Ok(vec![Some(parent_required)])) + .transpose() } else if plan.fetch().is_some() && plan.supports_limit_pushdown() && plan .maintains_input_order() - .iter() - .all(|maintain| *maintain) + .into_iter() + .all(|maintain| maintain) { - let output_req = LexRequirement::from( - plan.properties() - .output_ordering() - .cloned() - .unwrap_or(LexOrdering::default()), - ); // Push down through operator with fetch when: // - requirement is aligned with output ordering // - it preserves ordering during execution - if plan - .properties() - .eq_properties - .requirements_compatible(parent_required, &output_req) - { - let req = (!parent_required.is_empty()) - .then(|| LexRequirement::new(parent_required.to_vec())); - Ok(Some(vec![req])) + let Some(ordering) = plan.properties().output_ordering() else { + return Ok(Some(vec![Some(parent_required)])); + }; + if plan.properties().eq_properties.requirements_compatible( + parent_required.first().clone(), + ordering.clone().into(), + ) { + Ok(Some(vec![Some(parent_required)])) } else { Ok(None) } } else if is_union(plan) { - // UnionExec does not have real sort requirements for its input. Here we change the adjusted_request_ordering to UnionExec's output ordering and - // propagate the sort requirements down to correct the unnecessary descendant SortExec under the UnionExec - let req = (!parent_required.is_empty()).then(|| parent_required.clone()); - Ok(Some(vec![req; plan.children().len()])) + // `UnionExec` does not have real sort requirements for its input, we + // just propagate the sort requirements down: + Ok(Some(vec![Some(parent_required); plan.children().len()])) } else if let Some(smj) = plan.as_any().downcast_ref::() { - // If the current plan is SortMergeJoinExec let left_columns_len = smj.left().schema().fields().len(); - let parent_required_expr = LexOrdering::from(parent_required.clone()); - match expr_source_side( - parent_required_expr.as_ref(), - smj.join_type(), - left_columns_len, - ) { - Some(JoinSide::Left) => try_pushdown_requirements_to_join( + let parent_ordering: Vec = parent_required + .first() + .iter() + .cloned() + .map(Into::into) + .collect(); + let eqp = smj.properties().equivalence_properties(); + match expr_source_side(eqp, parent_ordering, smj.join_type(), left_columns_len) { + Some((JoinSide::Left, ordering)) => try_pushdown_requirements_to_join( smj, - parent_required, - parent_required_expr.as_ref(), + parent_required.into_single(), + ordering, JoinSide::Left, ), - Some(JoinSide::Right) => { + Some((JoinSide::Right, ordering)) => { let right_offset = smj.schema().fields.len() - smj.right().schema().fields.len(); - let new_right_required = - shift_right_required(parent_required, right_offset)?; - let new_right_required_expr = LexOrdering::from(new_right_required); + let ordering = add_offset_to_physical_sort_exprs( + ordering, + -(right_offset as isize), + )?; try_pushdown_requirements_to_join( smj, - parent_required, - new_right_required_expr.as_ref(), + parent_required.into_single(), + ordering, JoinSide::Right, ) } @@ -318,28 +359,26 @@ fn pushdown_requirement_to_children( || plan.as_any().is::() // TODO: Add support for Projection push down || plan.as_any().is::() - || pushdown_would_violate_requirements(parent_required, plan.as_ref()) + || pushdown_would_violate_requirements(&parent_required, plan.as_ref()) { // If the current plan is a leaf node or can not maintain any of the input ordering, can not pushed down requirements. // For RepartitionExec, we always choose to not push down the sort requirements even the RepartitionExec(input_partition=1) could maintain input ordering. // Pushing down is not beneficial Ok(None) } else if is_sort_preserving_merge(plan) { - let new_ordering = LexOrdering::from(parent_required.clone()); + let new_ordering = LexOrdering::from(parent_required.first().clone()); let mut spm_eqs = plan.equivalence_properties().clone(); + let old_ordering = spm_eqs.output_ordering().unwrap(); // Sort preserving merge will have new ordering, one requirement above is pushed down to its below. - spm_eqs = spm_eqs.with_reorder(new_ordering); - // Do not push-down through SortPreservingMergeExec when - // ordering requirement invalidates requirement of sort preserving merge exec. - if !spm_eqs.ordering_satisfy(&plan.output_ordering().cloned().unwrap_or_default()) - { - Ok(None) - } else { + let change = spm_eqs.reorder(new_ordering)?; + if !change || spm_eqs.ordering_satisfy(old_ordering)? { // Can push-down through SortPreservingMergeExec, because parent requirement is finer // than SortPreservingMergeExec output ordering. - let req = (!parent_required.is_empty()) - .then(|| LexRequirement::new(parent_required.to_vec())); - Ok(Some(vec![req])) + Ok(Some(vec![Some(parent_required)])) + } else { + // Do not push-down through SortPreservingMergeExec when + // ordering requirement invalidates requirement of sort preserving merge exec. + Ok(None) } } else if let Some(hash_join) = plan.as_any().downcast_ref::() { handle_hash_join(hash_join, parent_required) @@ -352,22 +391,21 @@ fn pushdown_requirement_to_children( /// Return true if pushing the sort requirements through a node would violate /// the input sorting requirements for the plan fn pushdown_would_violate_requirements( - parent_required: &LexRequirement, + parent_required: &OrderingRequirements, child: &dyn ExecutionPlan, ) -> bool { child .required_input_ordering() - .iter() + .into_iter() + // If there is no requirement, pushing down would not violate anything. + .flatten() .any(|child_required| { - let Some(child_required) = child_required.as_ref() else { - // no requirements, so pushing down would not violate anything - return false; - }; - // check if the plan's requirements would still e satisfied if we pushed - // down the parent requirements + // Check if the plan's requirements would still be satisfied if we + // pushed down the parent requirements: child_required + .into_single() .iter() - .zip(parent_required.iter()) + .zip(parent_required.first().iter()) .all(|(c, p)| !c.compatible(p)) }) } @@ -378,25 +416,24 @@ fn pushdown_would_violate_requirements( /// - If parent requirements are more specific, push down parent requirements. /// - If they are not compatible, need to add a sort. fn determine_children_requirement( - parent_required: &LexRequirement, - request_child: &LexRequirement, + parent_required: &OrderingRequirements, + child_requirement: &OrderingRequirements, child_plan: &Arc, ) -> RequirementsCompatibility { - if child_plan - .equivalence_properties() - .requirements_compatible(request_child, parent_required) - { + let eqp = child_plan.equivalence_properties(); + if eqp.requirements_compatible( + child_requirement.first().clone(), + parent_required.first().clone(), + ) { // Child requirements are more specific, no need to push down. RequirementsCompatibility::Satisfy - } else if child_plan - .equivalence_properties() - .requirements_compatible(parent_required, request_child) - { + } else if eqp.requirements_compatible( + parent_required.first().clone(), + child_requirement.first().clone(), + ) { // Parent requirements are more specific, adjust child's requirements // and push down the new requirements: - let adjusted = (!parent_required.is_empty()) - .then(|| LexRequirement::new(parent_required.to_vec())); - RequirementsCompatibility::Compatible(adjusted) + RequirementsCompatibility::Compatible(Some(parent_required.clone())) } else { RequirementsCompatibility::NonCompatible } @@ -404,42 +441,41 @@ fn determine_children_requirement( fn try_pushdown_requirements_to_join( smj: &SortMergeJoinExec, - parent_required: &LexRequirement, - sort_expr: &LexOrdering, + parent_required: LexRequirement, + sort_exprs: Vec, push_side: JoinSide, -) -> Result>>> { - let left_eq_properties = smj.left().equivalence_properties(); - let right_eq_properties = smj.right().equivalence_properties(); +) -> Result>>> { let mut smj_required_orderings = smj.required_input_ordering(); - let right_requirement = smj_required_orderings.swap_remove(1); - let left_requirement = smj_required_orderings.swap_remove(0); - let left_ordering = &smj.left().output_ordering().cloned().unwrap_or_default(); - let right_ordering = &smj.right().output_ordering().cloned().unwrap_or_default(); + let ordering = LexOrdering::new(sort_exprs.clone()); let (new_left_ordering, new_right_ordering) = match push_side { JoinSide::Left => { - let left_eq_properties = - left_eq_properties.clone().with_reorder(sort_expr.clone()); - if left_eq_properties - .ordering_satisfy_requirement(&left_requirement.unwrap_or_default()) + let mut left_eq_properties = smj.left().equivalence_properties().clone(); + left_eq_properties.reorder(sort_exprs)?; + let Some(left_requirement) = smj_required_orderings.swap_remove(0) else { + return Ok(None); + }; + if !left_eq_properties + .ordering_satisfy_requirement(left_requirement.into_single())? { - // After re-ordering requirement is still satisfied - (sort_expr, right_ordering) - } else { return Ok(None); } + // After re-ordering, requirement is still satisfied: + (ordering.as_ref(), smj.right().output_ordering()) } JoinSide::Right => { - let right_eq_properties = - right_eq_properties.clone().with_reorder(sort_expr.clone()); - if right_eq_properties - .ordering_satisfy_requirement(&right_requirement.unwrap_or_default()) + let mut right_eq_properties = smj.right().equivalence_properties().clone(); + right_eq_properties.reorder(sort_exprs)?; + let Some(right_requirement) = smj_required_orderings.swap_remove(1) else { + return Ok(None); + }; + if !right_eq_properties + .ordering_satisfy_requirement(right_requirement.into_single())? { - // After re-ordering requirement is still satisfied - (left_ordering, sort_expr) - } else { return Ok(None); } + // After re-ordering, requirement is still satisfied: + (smj.left().output_ordering(), ordering.as_ref()) } JoinSide::None => return Ok(None), }; @@ -449,18 +485,19 @@ fn try_pushdown_requirements_to_join( new_left_ordering, new_right_ordering, join_type, - smj.on(), smj.left().schema().fields.len(), &smj.maintains_input_order(), Some(probe_side), - ); + )?; let mut smj_eqs = smj.properties().equivalence_properties().clone(); - // smj will have this ordering when its input changes. - smj_eqs = smj_eqs.with_reorder(new_output_ordering.unwrap_or_default()); - let should_pushdown = smj_eqs.ordering_satisfy_requirement(parent_required); + if let Some(new_output_ordering) = new_output_ordering { + // smj will have this ordering when its input changes. + smj_eqs.reorder(new_output_ordering)?; + } + let should_pushdown = smj_eqs.ordering_satisfy_requirement(parent_required)?; Ok(should_pushdown.then(|| { let mut required_input_ordering = smj.required_input_ordering(); - let new_req = Some(LexRequirement::from(sort_expr.clone())); + let new_req = ordering.map(Into::into); match push_side { JoinSide::Left => { required_input_ordering[0] = new_req; @@ -475,77 +512,78 @@ fn try_pushdown_requirements_to_join( } fn expr_source_side( - required_exprs: &LexOrdering, + eqp: &EquivalenceProperties, + mut ordering: Vec, join_type: JoinType, left_columns_len: usize, -) -> Option { +) -> Option<(JoinSide, Vec)> { + // TODO: Handle the case where a prefix of the ordering comes from the left + // and a suffix from the right. match join_type { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full - | JoinType::LeftMark => { - let all_column_sides = required_exprs - .iter() - .filter_map(|r| { - r.expr.as_any().downcast_ref::().map(|col| { - if col.index() < left_columns_len { - JoinSide::Left - } else { - JoinSide::Right + | JoinType::LeftMark + | JoinType::RightMark => { + let eq_group = eqp.eq_group(); + let mut right_ordering = ordering.clone(); + let (mut valid_left, mut valid_right) = (true, true); + for (left, right) in ordering.iter_mut().zip(right_ordering.iter_mut()) { + let col = left.expr.as_any().downcast_ref::()?; + let eq_class = eq_group.get_equivalence_class(&left.expr); + if col.index() < left_columns_len { + if valid_right { + valid_right = eq_class.is_some_and(|cls| { + for expr in cls.iter() { + if expr + .as_any() + .downcast_ref::() + .is_some_and(|c| c.index() >= left_columns_len) + { + right.expr = Arc::clone(expr); + return true; + } + } + false + }); + } + } else if valid_left { + valid_left = eq_class.is_some_and(|cls| { + for expr in cls.iter() { + if expr + .as_any() + .downcast_ref::() + .is_some_and(|c| c.index() < left_columns_len) + { + left.expr = Arc::clone(expr); + return true; + } } - }) - }) - .collect::>(); - - // If the exprs are all coming from one side, the requirements can be pushed down - if all_column_sides.len() != required_exprs.len() { - None - } else if all_column_sides - .iter() - .all(|side| matches!(side, JoinSide::Left)) - { - Some(JoinSide::Left) - } else if all_column_sides - .iter() - .all(|side| matches!(side, JoinSide::Right)) - { - Some(JoinSide::Right) + false + }); + }; + if !(valid_left || valid_right) { + return None; + } + } + if valid_left { + Some((JoinSide::Left, ordering)) + } else if valid_right { + Some((JoinSide::Right, right_ordering)) } else { + // TODO: Handle the case where we can push down to both sides. None } } - JoinType::LeftSemi | JoinType::LeftAnti => required_exprs + JoinType::LeftSemi | JoinType::LeftAnti => ordering .iter() - .all(|e| e.expr.as_any().downcast_ref::().is_some()) - .then_some(JoinSide::Left), - JoinType::RightSemi | JoinType::RightAnti => required_exprs + .all(|e| e.expr.as_any().is::()) + .then_some((JoinSide::Left, ordering)), + JoinType::RightSemi | JoinType::RightAnti => ordering .iter() - .all(|e| e.expr.as_any().downcast_ref::().is_some()) - .then_some(JoinSide::Right), - } -} - -fn shift_right_required( - parent_required: &LexRequirement, - left_columns_len: usize, -) -> Result { - let new_right_required = parent_required - .iter() - .filter_map(|r| { - let col = r.expr.as_any().downcast_ref::()?; - col.index().checked_sub(left_columns_len).map(|offset| { - r.clone() - .with_expr(Arc::new(Column::new(col.name(), offset))) - }) - }) - .collect::>(); - if new_right_required.len() == parent_required.len() { - Ok(LexRequirement::new(new_right_required)) - } else { - plan_err!( - "Expect to shift all the parent required column indexes for SortMergeJoin" - ) + .all(|e| e.expr.as_any().is::()) + .then_some((JoinSide::Right, ordering)), } } @@ -565,16 +603,18 @@ fn shift_right_required( /// pushed down, `Ok(None)` if not. On error, returns a `Result::Err`. fn handle_custom_pushdown( plan: &Arc, - parent_required: &LexRequirement, + parent_required: OrderingRequirements, maintains_input_order: Vec, -) -> Result>>> { - // If there's no requirement from the parent or the plan has no children, return early - if parent_required.is_empty() || plan.children().is_empty() { +) -> Result>>> { + // If the plan has no children, return early: + if plan.children().is_empty() { return Ok(None); } - // Collect all unique column indices used in the parent-required sorting expression - let all_indices: HashSet = parent_required + // Collect all unique column indices used in the parent-required sorting + // expression: + let requirement = parent_required.into_single(); + let all_indices: HashSet = requirement .iter() .flat_map(|order| { collect_columns(&order.expr) @@ -584,14 +624,14 @@ fn handle_custom_pushdown( }) .collect(); - // Get the number of fields in each child's schema - let len_of_child_schemas: Vec = plan + // Get the number of fields in each child's schema: + let children_schema_lengths: Vec = plan .children() .iter() .map(|c| c.schema().fields().len()) .collect(); - // Find the index of the child that maintains input order + // Find the index of the order-maintaining child: let Some(maintained_child_idx) = maintains_input_order .iter() .enumerate() @@ -601,26 +641,28 @@ fn handle_custom_pushdown( return Ok(None); }; - // Check if all required columns come from the child that maintains input order - let start_idx = len_of_child_schemas[..maintained_child_idx] + // Check if all required columns come from the order-maintaining child: + let start_idx = children_schema_lengths[..maintained_child_idx] .iter() .sum::(); - let end_idx = start_idx + len_of_child_schemas[maintained_child_idx]; + let end_idx = start_idx + children_schema_lengths[maintained_child_idx]; let all_from_maintained_child = all_indices.iter().all(|i| i >= &start_idx && i < &end_idx); - // If all columns are from the maintained child, update the parent requirements + // If all columns are from the maintained child, update the parent requirements: if all_from_maintained_child { - let sub_offset = len_of_child_schemas + let sub_offset = children_schema_lengths .iter() .take(maintained_child_idx) .sum::(); - // Transform the parent-required expression for the child schema by adjusting columns - let updated_parent_req = parent_required - .iter() + // Transform the parent-required expression for the child schema by + // adjusting columns: + let updated_parent_req = requirement + .into_iter() .map(|req| { let child_schema = plan.children()[maintained_child_idx].schema(); - let updated_columns = Arc::clone(&req.expr) + let updated_columns = req + .expr .transform_up(|expr| { if let Some(col) = expr.as_any().downcast_ref::() { let new_index = col.index() - sub_offset; @@ -642,7 +684,8 @@ fn handle_custom_pushdown( .iter() .map(|&maintains_order| { if maintains_order { - Some(LexRequirement::new(updated_parent_req.clone())) + LexRequirement::new(updated_parent_req.clone()) + .map(OrderingRequirements::new) } else { None } @@ -659,16 +702,17 @@ fn handle_custom_pushdown( // for join type: Inner, Right, RightSemi, RightAnti fn handle_hash_join( plan: &HashJoinExec, - parent_required: &LexRequirement, -) -> Result>>> { - // If there's no requirement from the parent or the plan has no children - // or the join type is not Inner, Right, RightSemi, RightAnti, return early - if parent_required.is_empty() || !plan.maintains_input_order()[1] { + parent_required: OrderingRequirements, +) -> Result>>> { + // If the plan has no children or does not maintain the right side ordering, + // return early: + if !plan.maintains_input_order()[1] { return Ok(None); } // Collect all unique column indices used in the parent-required sorting expression - let all_indices: HashSet = parent_required + let requirement = parent_required.into_single(); + let all_indices: HashSet<_> = requirement .iter() .flat_map(|order| { collect_columns(&order.expr) @@ -694,11 +738,12 @@ fn handle_hash_join( // If all columns are from the right child, update the parent requirements if all_from_right_child { // Transform the parent-required expression for the child schema by adjusting columns - let updated_parent_req = parent_required - .iter() + let updated_parent_req = requirement + .into_iter() .map(|req| { let child_schema = plan.children()[1].schema(); - let updated_columns = Arc::clone(&req.expr) + let updated_columns = req + .expr .transform_up(|expr| { if let Some(col) = expr.as_any().downcast_ref::() { let index = projected_indices[col.index()].index; @@ -718,7 +763,7 @@ fn handle_hash_join( // Populating with the updated requirements for children that maintain order Ok(Some(vec![ None, - Some(LexRequirement::new(updated_parent_req)), + LexRequirement::new(updated_parent_req).map(OrderingRequirements::new), ])) } else { Ok(None) @@ -757,7 +802,7 @@ enum RequirementsCompatibility { /// Requirements satisfy Satisfy, /// Requirements compatible - Compatible(Option), + Compatible(Option), /// Requirements not compatible NonCompatible, } diff --git a/datafusion/physical-optimizer/src/ensure_coop.rs b/datafusion/physical-optimizer/src/ensure_coop.rs new file mode 100644 index 0000000000000..0c0b63c0b3e79 --- /dev/null +++ b/datafusion/physical-optimizer/src/ensure_coop.rs @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! The [`EnsureCooperative`] optimizer rule inspects the physical plan to find all +//! portions of the plan that will not yield cooperatively. +//! It will insert `CooperativeExec` nodes where appropriate to ensure execution plans +//! always yield cooperatively. + +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +use crate::PhysicalOptimizerRule; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::Result; +use datafusion_physical_plan::coop::CooperativeExec; +use datafusion_physical_plan::execution_plan::{EvaluationType, SchedulingType}; +use datafusion_physical_plan::ExecutionPlan; + +/// `EnsureCooperative` is a [`PhysicalOptimizerRule`] that inspects the physical plan for +/// sub plans that do not participate in cooperative scheduling. The plan is subdivided into sub +/// plans on eager evaluation boundaries. Leaf nodes and eager evaluation roots are checked +/// to see if they participate in cooperative scheduling. Those that do no are wrapped in +/// a [`CooperativeExec`] parent. +pub struct EnsureCooperative {} + +impl EnsureCooperative { + pub fn new() -> Self { + Self {} + } +} + +impl Default for EnsureCooperative { + fn default() -> Self { + Self::new() + } +} + +impl Debug for EnsureCooperative { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct(self.name()).finish() + } +} + +impl PhysicalOptimizerRule for EnsureCooperative { + fn name(&self) -> &str { + "EnsureCooperative" + } + + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_up(|plan| { + let is_leaf = plan.children().is_empty(); + let is_exchange = plan.properties().evaluation_type == EvaluationType::Eager; + if (is_leaf || is_exchange) + && plan.properties().scheduling_type != SchedulingType::Cooperative + { + // Wrap non-cooperative leaves or eager evaluation roots in a cooperative exec to + // ensure the plans they participate in are properly cooperative. + Ok(Transformed::new( + Arc::new(CooperativeExec::new(Arc::clone(&plan))), + true, + TreeNodeRecursion::Continue, + )) + } else { + Ok(Transformed::no(plan)) + } + }) + .map(|t| t.data) + } + + fn schema_check(&self) -> bool { + // Wrapping a leaf in YieldStreamExec preserves the schema, so it is safe. + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::config::ConfigOptions; + use datafusion_physical_plan::{displayable, test::scan_partitioned}; + use insta::assert_snapshot; + + #[tokio::test] + async fn test_cooperative_exec_for_custom_exec() { + let test_custom_exec = scan_partitioned(1); + let config = ConfigOptions::new(); + let optimized = EnsureCooperative::new() + .optimize(test_custom_exec, &config) + .unwrap(); + + let display = displayable(optimized.as_ref()).indent(true).to_string(); + // Use insta snapshot to ensure full plan structure + assert_snapshot!(display, @r###" + CooperativeExec + DataSourceExec: partitions=1, partition_sizes=[1] + "###); + } +} diff --git a/datafusion/physical-optimizer/src/filter_pushdown.rs b/datafusion/physical-optimizer/src/filter_pushdown.rs new file mode 100644 index 0000000000000..5ee7023ff6ee2 --- /dev/null +++ b/datafusion/physical-optimizer/src/filter_pushdown.rs @@ -0,0 +1,867 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Filter Pushdown Optimization Process +//! +//! The filter pushdown mechanism involves four key steps: +//! 1. **Optimizer Asks Parent for a Filter Pushdown Plan**: The optimizer calls [`ExecutionPlan::gather_filters_for_pushdown`] +//! on the parent node, passing in parent predicates and phase. The parent node creates a [`FilterDescription`] +//! by inspecting its logic and children's schemas, determining which filters can be pushed to each child. +//! 2. **Optimizer Executes Pushdown**: The optimizer recursively calls `push_down_filters` in this module on each child, +//! passing the appropriate filters (`Vec>`) for that child. +//! 3. **Optimizer Gathers Results**: The optimizer collects [`FilterPushdownPropagation`] results from children, +//! containing information about which filters were successfully pushed down vs. unsupported. +//! 4. **Parent Responds**: The optimizer calls [`ExecutionPlan::handle_child_pushdown_result`] on the parent, +//! passing a [`ChildPushdownResult`] containing the aggregated pushdown outcomes. The parent decides +//! how to handle filters that couldn't be pushed down (e.g., keep them as FilterExec nodes). +//! +//! [`FilterDescription`]: datafusion_physical_plan::filter_pushdown::FilterDescription + +use std::sync::Arc; + +use crate::PhysicalOptimizerRule; + +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::{config::ConfigOptions, internal_err, Result}; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::physical_expr::is_volatile; +use datafusion_physical_plan::filter_pushdown::{ + ChildFilterPushdownResult, ChildPushdownResult, FilterPushdownPhase, + FilterPushdownPropagation, PushedDown, +}; +use datafusion_physical_plan::{with_new_children_if_necessary, ExecutionPlan}; + +use itertools::{izip, Itertools}; + +/// Attempts to recursively push given filters from the top of the tree into leaves. +/// +/// # Default Implementation +/// +/// The default implementation in [`ExecutionPlan::gather_filters_for_pushdown`] +/// and [`ExecutionPlan::handle_child_pushdown_result`] assumes that: +/// +/// * Parent filters can't be passed onto children (determined by [`ExecutionPlan::gather_filters_for_pushdown`]) +/// * This node has no filters to contribute (determined by [`ExecutionPlan::gather_filters_for_pushdown`]). +/// * Any filters that could not be pushed down to the children are marked as unsupported (determined by [`ExecutionPlan::handle_child_pushdown_result`]). +/// +/// # Example: Push filter into a `DataSourceExec` +/// +/// For example, consider the following plan: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = [ id=1] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// ``` +/// +/// Our goal is to move the `id = 1` filter from the [`FilterExec`] node to the `DataSourceExec` node. +/// +/// If this filter is selective pushing it into the scan can avoid massive +/// amounts of data being read from the source (the projection is `*` so all +/// matching columns are read). +/// +/// The new plan looks like: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = [ id=1] │ +/// └──────────────────────┘ +/// ``` +/// +/// # Example: Push filters with `ProjectionExec` +/// +/// Let's consider a more complex example involving a [`ProjectionExec`] +/// node in between the [`FilterExec`] and `DataSourceExec` nodes that +/// creates a new column that the filter depends on. +/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = │ +/// │ [cost>50,id=1] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ cost = price * 1.2 │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// ``` +/// +/// We want to push down the filters `[id=1]` to the `DataSourceExec` node, +/// but can't push down `cost>50` because it requires the [`ProjectionExec`] +/// node to be executed first. A simple thing to do would be to split up the +/// filter into two separate filters and push down the first one: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = │ +/// │ [cost>50] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ cost = price * 1.2 │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = [ id=1] │ +/// └──────────────────────┘ +/// ``` +/// +/// We can actually however do better by pushing down `price * 1.2 > 50` +/// instead of `cost > 50`: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ cost = price * 1.2 │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = [id=1, │ +/// │ price * 1.2 > 50] │ +/// └──────────────────────┘ +/// ``` +/// +/// # Example: Push filters within a subtree +/// +/// There are also cases where we may be able to push down filters within a +/// subtree but not the entire tree. A good example of this is aggregation +/// nodes: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = [sum > 10] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌───────────────────────┐ +/// │ AggregateExec │ +/// │ group by = [id] │ +/// │ aggregate = │ +/// │ [sum(price)] │ +/// └───────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = [id=1] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// ``` +/// +/// The transformation here is to push down the `id=1` filter to the +/// `DataSourceExec` node: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = [sum > 10] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌───────────────────────┐ +/// │ AggregateExec │ +/// │ group by = [id] │ +/// │ aggregate = │ +/// │ [sum(price)] │ +/// └───────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = [id=1] │ +/// └──────────────────────┘ +/// ``` +/// +/// The point here is that: +/// 1. We cannot push down `sum > 10` through the [`AggregateExec`] node into the `DataSourceExec` node. +/// Any filters above the [`AggregateExec`] node are not pushed down. +/// This is determined by calling [`ExecutionPlan::gather_filters_for_pushdown`] on the [`AggregateExec`] node. +/// 2. We need to keep recursing into the tree so that we can discover the other [`FilterExec`] node and push +/// down the `id=1` filter. +/// +/// # Example: Push filters through Joins +/// +/// It is also possible to push down filters through joins and filters that +/// originate from joins. For example, a hash join where we build a hash +/// table of the left side and probe the right side (ignoring why we would +/// choose this order, typically it depends on the size of each table, +/// etc.). +/// +/// ```text +/// ┌─────────────────────┐ +/// │ FilterExec │ +/// │ filters = │ +/// │ [d.size > 100] │ +/// └─────────────────────┘ +/// │ +/// │ +/// ┌──────────▼──────────┐ +/// │ │ +/// │ HashJoinExec │ +/// │ [u.dept@hash(d.id)] │ +/// │ │ +/// └─────────────────────┘ +/// │ +/// ┌────────────┴────────────┐ +/// ┌──────────▼──────────┐ ┌──────────▼──────────┐ +/// │ DataSourceExec │ │ DataSourceExec │ +/// │ alias [users as u] │ │ alias [dept as d] │ +/// │ │ │ │ +/// └─────────────────────┘ └─────────────────────┘ +/// ``` +/// +/// There are two pushdowns we can do here: +/// 1. Push down the `d.size > 100` filter through the `HashJoinExec` node to the `DataSourceExec` +/// node for the `departments` table. +/// 2. Push down the hash table state from the `HashJoinExec` node to the `DataSourceExec` node to avoid reading +/// rows from the `users` table that will be eliminated by the join. +/// This can be done via a bloom filter or similar and is not (yet) supported +/// in DataFusion. See . +/// +/// ```text +/// ┌─────────────────────┐ +/// │ │ +/// │ HashJoinExec │ +/// │ [u.dept@hash(d.id)] │ +/// │ │ +/// └─────────────────────┘ +/// │ +/// ┌────────────┴────────────┐ +/// ┌──────────▼──────────┐ ┌──────────▼──────────┐ +/// │ DataSourceExec │ │ DataSourceExec │ +/// │ alias [users as u] │ │ alias [dept as d] │ +/// │ filters = │ │ filters = │ +/// │ [depg@hash(d.id)] │ │ [ d.size > 100] │ +/// └─────────────────────┘ └─────────────────────┘ +/// ``` +/// +/// You may notice in this case that the filter is *dynamic*: the hash table +/// is built _after_ the `departments` table is read and at runtime. We +/// don't have a concrete `InList` filter or similar to push down at +/// optimization time. These sorts of dynamic filters are handled by +/// building a specialized [`PhysicalExpr`] that can be evaluated at runtime +/// and internally maintains a reference to the hash table or other state. +/// +/// To make working with these sorts of dynamic filters more tractable we have the method [`PhysicalExpr::snapshot`] +/// which attempts to simplify a dynamic filter into a "basic" non-dynamic filter. +/// For a join this could mean converting it to an `InList` filter or a min/max filter for example. +/// See `datafusion/physical-plan/src/dynamic_filters.rs` for more details. +/// +/// # Example: Push TopK filters into Scans +/// +/// Another form of dynamic filter is pushing down the state of a `TopK` +/// operator for queries like `SELECT * FROM t ORDER BY id LIMIT 10`: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ TopK │ +/// │ limit = 10 │ +/// │ order by = [id] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// ``` +/// +/// We can avoid large amounts of data processing by transforming this into: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ TopK │ +/// │ limit = 10 │ +/// │ order by = [id] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = │ +/// │ [id < @ TopKHeap] │ +/// └──────────────────────┘ +/// ``` +/// +/// Now as we fill our `TopK` heap we can push down the state of the heap to +/// the `DataSourceExec` node to avoid reading files / row groups / pages / +/// rows that could not possibly be in the top 10. +/// +/// This is not yet implemented in DataFusion. See +/// +/// +/// [`PhysicalExpr`]: datafusion_physical_plan::PhysicalExpr +/// [`PhysicalExpr::snapshot`]: datafusion_physical_plan::PhysicalExpr::snapshot +/// [`FilterExec`]: datafusion_physical_plan::filter::FilterExec +/// [`ProjectionExec`]: datafusion_physical_plan::projection::ProjectionExec +/// [`AggregateExec`]: datafusion_physical_plan::aggregates::AggregateExec +#[derive(Debug)] +pub struct FilterPushdown { + phase: FilterPushdownPhase, + name: String, +} + +impl FilterPushdown { + fn new_with_phase(phase: FilterPushdownPhase) -> Self { + let name = match phase { + FilterPushdownPhase::Pre => "FilterPushdown", + FilterPushdownPhase::Post => "FilterPushdown(Post)", + } + .to_string(); + Self { phase, name } + } + + /// Create a new [`FilterPushdown`] optimizer rule that runs in the pre-optimization phase. + /// See [`FilterPushdownPhase`] for more details. + pub fn new() -> Self { + Self::new_with_phase(FilterPushdownPhase::Pre) + } + + /// Create a new [`FilterPushdown`] optimizer rule that runs in the post-optimization phase. + /// See [`FilterPushdownPhase`] for more details. + pub fn new_post_optimization() -> Self { + Self::new_with_phase(FilterPushdownPhase::Post) + } +} + +impl Default for FilterPushdown { + fn default() -> Self { + Self::new() + } +} + +impl PhysicalOptimizerRule for FilterPushdown { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + Ok( + push_down_filters(Arc::clone(&plan), vec![], config, self.phase)? + .updated_node + .unwrap_or(plan), + ) + } + + fn name(&self) -> &str { + &self.name + } + + fn schema_check(&self) -> bool { + true // Filter pushdown does not change the schema of the plan + } +} + +fn push_down_filters( + node: Arc, + parent_predicates: Vec>, + config: &ConfigOptions, + phase: FilterPushdownPhase, +) -> Result>> { + let mut parent_filter_pushdown_supports: Vec> = + vec![vec![]; parent_predicates.len()]; + let mut self_filters_pushdown_supports = vec![]; + let mut new_children = Vec::with_capacity(node.children().len()); + + let children = node.children(); + + // Filter out expressions that are not allowed for pushdown + let parent_filtered = FilteredVec::new(&parent_predicates, allow_pushdown_for_expr); + + let filter_description = node.gather_filters_for_pushdown( + phase, + parent_filtered.items().to_vec(), + config, + )?; + + let filter_description_parent_filters = filter_description.parent_filters(); + let filter_description_self_filters = filter_description.self_filters(); + if filter_description_parent_filters.len() != children.len() { + return internal_err!( + "Filter pushdown expected FilterDescription to have parent filters for {}, but got {} for node {}", + children.len(), + filter_description_parent_filters.len(), + node.name() + ); + } + if filter_description_self_filters.len() != children.len() { + return internal_err!( + "Filter pushdown expected FilterDescription to have self filters for {}, but got {} for node {}", + children.len(), + filter_description_self_filters.len(), + node.name() + ); + } + + for (child_idx, (child, parent_filters, self_filters)) in izip!( + children, + filter_description.parent_filters(), + filter_description.self_filters() + ) + .enumerate() + { + // Here, `parent_filters` are the predicates which are provided by the parent node of + // the current node, and tried to be pushed down over the child which the loop points + // currently. `self_filters` are the predicates which are provided by the current node, + // and tried to be pushed down over the child similarly. + + // Filter out self_filters that contain volatile expressions and track indices + let self_filtered = FilteredVec::new(&self_filters, allow_pushdown_for_expr); + + let num_self_filters = self_filtered.len(); + let mut all_predicates = self_filtered.items().to_vec(); + + // Apply second filter pass: collect indices of parent filters that can be pushed down + let parent_filters_for_child = parent_filtered + .chain_filter_slice(&parent_filters, |filter| { + matches!(filter.discriminant, PushedDown::Yes) + }); + + // Add the filtered parent predicates to all_predicates + for filter in parent_filters_for_child.items() { + all_predicates.push(Arc::clone(&filter.predicate)); + } + + let num_parent_filters = all_predicates.len() - num_self_filters; + + // Any filters that could not be pushed down to a child are marked as not-supported to our parents + let result = push_down_filters(Arc::clone(child), all_predicates, config, phase)?; + + if let Some(new_child) = result.updated_node { + // If we have a filter pushdown result, we need to update our children + new_children.push(new_child); + } else { + // If we don't have a filter pushdown result, we need to update our children + new_children.push(Arc::clone(child)); + } + + // Our child doesn't know the difference between filters that were passed down + // from our parents and filters that the current node injected. We need to de-entangle + // this since we do need to distinguish between them. + let mut all_filters = result.filters.into_iter().collect_vec(); + if all_filters.len() != num_self_filters + num_parent_filters { + return internal_err!( + "Filter pushdown did not return the expected number of filters: expected {} self filters and {} parent filters, but got {}. Likely culprit is {}", + num_self_filters, + num_parent_filters, + all_filters.len(), + child.name() + ); + } + let parent_filters = all_filters + .split_off(num_self_filters) + .into_iter() + .collect_vec(); + // Map the results from filtered self filters back to their original positions using FilteredVec + let mapped_self_results = + self_filtered.map_results_to_original(all_filters, PushedDown::No); + + // Wrap each result with its corresponding expression + let self_filter_results: Vec<_> = mapped_self_results + .into_iter() + .zip(self_filters) + .map(|(support, filter)| support.wrap_expression(filter)) + .collect(); + + self_filters_pushdown_supports.push(self_filter_results); + + // Start by marking all parent filters as unsupported for this child + for parent_filter_pushdown_support in parent_filter_pushdown_supports.iter_mut() { + parent_filter_pushdown_support.push(PushedDown::No); + assert_eq!( + parent_filter_pushdown_support.len(), + child_idx + 1, + "Parent filter pushdown supports should have the same length as the number of children" + ); + } + // Map results from pushed-down filters back to original parent filter indices + let mapped_parent_results = parent_filters_for_child + .map_results_to_original(parent_filters, PushedDown::No); + + // Update parent_filter_pushdown_supports with the mapped results + // mapped_parent_results already has the results at their original indices + for (idx, support) in parent_filter_pushdown_supports.iter_mut().enumerate() { + support[child_idx] = mapped_parent_results[idx]; + } + } + + // Re-create this node with new children + let updated_node = with_new_children_if_necessary(Arc::clone(&node), new_children)?; + + // TODO: by calling `handle_child_pushdown_result` we are assuming that the + // `ExecutionPlan` implementation will not change the plan itself. + // Should we have a separate method for dynamic pushdown that does not allow modifying the plan? + let mut res = updated_node.handle_child_pushdown_result( + phase, + ChildPushdownResult { + parent_filters: parent_predicates + .into_iter() + .enumerate() + .map( + |(parent_filter_idx, parent_filter)| ChildFilterPushdownResult { + filter: parent_filter, + child_results: parent_filter_pushdown_supports[parent_filter_idx] + .clone(), + }, + ) + .collect(), + self_filters: self_filters_pushdown_supports, + }, + config, + )?; + // Compare pointers for new_node and node, if they are different we must replace + // ourselves because of changes in our children. + if res.updated_node.is_none() && !Arc::ptr_eq(&updated_node, &node) { + res.updated_node = Some(updated_node) + } + Ok(res) +} + +/// A helper structure for filtering elements from a vector through multiple passes while +/// tracking their original indices, allowing results to be mapped back to the original positions. +struct FilteredVec { + items: Vec, + // Chain of index mappings: each Vec maps from current level to previous level + // index_mappings[0] maps from first filter to original indices + // index_mappings[1] maps from second filter to first filter indices, etc. + index_mappings: Vec>, + original_len: usize, +} + +impl FilteredVec { + /// Creates a new FilteredVec by filtering items based on the given predicate + fn new(items: &[T], predicate: F) -> Self + where + F: Fn(&T) -> bool, + { + let mut filtered_items = Vec::new(); + let mut original_indices = Vec::new(); + + for (idx, item) in items.iter().enumerate() { + if predicate(item) { + filtered_items.push(item.clone()); + original_indices.push(idx); + } + } + + Self { + items: filtered_items, + index_mappings: vec![original_indices], + original_len: items.len(), + } + } + + /// Returns a reference to the filtered items + fn items(&self) -> &[T] { + &self.items + } + + /// Returns the number of filtered items + fn len(&self) -> usize { + self.items.len() + } + + /// Maps results from the filtered items back to their original positions + /// Returns a vector with the same length as the original input, filled with default_value + /// and updated with results at their original positions + fn map_results_to_original( + &self, + results: Vec, + default_value: R, + ) -> Vec { + let mut mapped_results = vec![default_value; self.original_len]; + + for (result_idx, result) in results.into_iter().enumerate() { + let original_idx = self.trace_to_original_index(result_idx); + mapped_results[original_idx] = result; + } + + mapped_results + } + + /// Traces a filtered index back to its original index through all filter passes + fn trace_to_original_index(&self, mut current_idx: usize) -> usize { + // Work backwards through the chain of index mappings + for mapping in self.index_mappings.iter().rev() { + current_idx = mapping[current_idx]; + } + current_idx + } + + /// Apply a filter to a new set of items while chaining the index mapping from self (parent) + /// This is useful when you have filtered items and then get a transformed slice + /// (e.g., from gather_filters_for_pushdown) that you need to filter again + fn chain_filter_slice(&self, items: &[U], predicate: F) -> FilteredVec + where + F: Fn(&U) -> bool, + { + let mut filtered_items = Vec::new(); + let mut filtered_indices = Vec::new(); + + for (idx, item) in items.iter().enumerate() { + if predicate(item) { + filtered_items.push(item.clone()); + filtered_indices.push(idx); + } + } + + // Chain the index mappings from parent (self) + let mut index_mappings = self.index_mappings.clone(); + index_mappings.push(filtered_indices); + + FilteredVec { + items: filtered_items, + index_mappings, + original_len: self.original_len, + } + } +} + +fn allow_pushdown_for_expr(expr: &Arc) -> bool { + let mut allow_pushdown = true; + expr.apply(|e| { + allow_pushdown = allow_pushdown && !is_volatile(e); + if allow_pushdown { + Ok(TreeNodeRecursion::Continue) + } else { + Ok(TreeNodeRecursion::Stop) + } + }) + .expect("Infallible traversal of PhysicalExpr tree failed"); + allow_pushdown +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_filtered_vec_single_pass() { + let items = vec![1, 2, 3, 4, 5, 6]; + let filtered = FilteredVec::new(&items, |&x| x % 2 == 0); + + // Check filtered items + assert_eq!(filtered.items(), &[2, 4, 6]); + assert_eq!(filtered.len(), 3); + + // Check index mapping + let results = vec!["a", "b", "c"]; + let mapped = filtered.map_results_to_original(results, "default"); + assert_eq!(mapped, vec!["default", "a", "default", "b", "default", "c"]); + } + + #[test] + fn test_filtered_vec_empty_filter() { + let items = vec![1, 3, 5]; + let filtered = FilteredVec::new(&items, |&x| x % 2 == 0); + + assert_eq!(filtered.items(), &[] as &[i32]); + assert_eq!(filtered.len(), 0); + + let results: Vec<&str> = vec![]; + let mapped = filtered.map_results_to_original(results, "default"); + assert_eq!(mapped, vec!["default", "default", "default"]); + } + + #[test] + fn test_filtered_vec_all_pass() { + let items = vec![2, 4, 6]; + let filtered = FilteredVec::new(&items, |&x| x % 2 == 0); + + assert_eq!(filtered.items(), &[2, 4, 6]); + assert_eq!(filtered.len(), 3); + + let results = vec!["a", "b", "c"]; + let mapped = filtered.map_results_to_original(results, "default"); + assert_eq!(mapped, vec!["a", "b", "c"]); + } + + #[test] + fn test_chain_filter_slice_different_types() { + // First pass: filter numbers + let numbers = vec![1, 2, 3, 4, 5, 6]; + let first_pass = FilteredVec::new(&numbers, |&x| x > 3); + assert_eq!(first_pass.items(), &[4, 5, 6]); + + // Transform to strings (simulating gather_filters_for_pushdown transformation) + let strings = vec!["four", "five", "six"]; + + // Second pass: filter strings that contain 'i' + let second_pass = first_pass.chain_filter_slice(&strings, |s| s.contains('i')); + assert_eq!(second_pass.items(), &["five", "six"]); + + // Map results back to original indices + let results = vec![100, 200]; + let mapped = second_pass.map_results_to_original(results, 0); + // "five" was at index 4 (1-based: 5), "six" was at index 5 (1-based: 6) + assert_eq!(mapped, vec![0, 0, 0, 0, 100, 200]); + } + + #[test] + fn test_chain_filter_slice_complex_scenario() { + // Simulating the filter pushdown scenario + // Parent predicates: [A, B, C, D, E] + let parent_predicates = vec!["A", "B", "C", "D", "E"]; + + // First pass: filter out some predicates (simulating allow_pushdown_for_expr) + let first_pass = FilteredVec::new(&parent_predicates, |s| *s != "B" && *s != "D"); + assert_eq!(first_pass.items(), &["A", "C", "E"]); + + // After gather_filters_for_pushdown, we get transformed results for a specific child + // Let's say child gets [A_transformed, C_transformed, E_transformed] + // but only C and E can be pushed down + #[derive(Clone, Debug, PartialEq)] + struct TransformedPredicate { + name: String, + can_push: bool, + } + + let child_predicates = vec![ + TransformedPredicate { + name: "A_transformed".to_string(), + can_push: false, + }, + TransformedPredicate { + name: "C_transformed".to_string(), + can_push: true, + }, + TransformedPredicate { + name: "E_transformed".to_string(), + can_push: true, + }, + ]; + + // Second pass: filter based on can_push + let second_pass = + first_pass.chain_filter_slice(&child_predicates, |p| p.can_push); + assert_eq!(second_pass.len(), 2); + assert_eq!(second_pass.items()[0].name, "C_transformed"); + assert_eq!(second_pass.items()[1].name, "E_transformed"); + + // Simulate getting results back from child + let child_results = vec!["C_result", "E_result"]; + let mapped = second_pass.map_results_to_original(child_results, "no_result"); + + // Results should be at original positions: C was at index 2, E was at index 4 + assert_eq!( + mapped, + vec![ + "no_result", + "no_result", + "C_result", + "no_result", + "E_result" + ] + ); + } + + #[test] + fn test_trace_to_original_index() { + let items = vec![10, 20, 30, 40, 50]; + let filtered = FilteredVec::new(&items, |&x| x != 20 && x != 40); + + // filtered items are [10, 30, 50] at original indices [0, 2, 4] + assert_eq!(filtered.trace_to_original_index(0), 0); // 10 was at index 0 + assert_eq!(filtered.trace_to_original_index(1), 2); // 30 was at index 2 + assert_eq!(filtered.trace_to_original_index(2), 4); // 50 was at index 4 + } + + #[test] + fn test_chain_filter_preserves_original_len() { + let items = vec![1, 2, 3, 4, 5]; + let first = FilteredVec::new(&items, |&x| x > 2); + + let strings = vec!["three", "four", "five"]; + let second = first.chain_filter_slice(&strings, |s| s.len() == 4); + + // Original length should still be 5 + let results = vec!["x", "y"]; + let mapped = second.map_results_to_original(results, "-"); + assert_eq!(mapped.len(), 5); + } +} diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index 5a772ccdd249f..1db4d7b30565e 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -23,10 +23,7 @@ //! pipeline-friendly ones. To achieve the second goal, it selects the proper //! `PartitionMode` and the build side using the available statistics for hash joins. -use std::sync::Arc; - use crate::PhysicalOptimizerRule; - use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -35,12 +32,13 @@ use datafusion_expr_common::sort_properties::SortProperties; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::LexOrdering; use datafusion_physical_plan::execution_plan::EmissionType; -use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; +use datafusion_physical_plan::joins::utils::ColumnIndex; use datafusion_physical_plan::joins::{ CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, }; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use std::sync::Arc; /// The [`JoinSelection`] rule tries to modify a given plan so that it can /// accommodate infinite sources and optimize joins in the plan according to @@ -65,8 +63,8 @@ pub(crate) fn should_swap_join_order( // Get the left and right table's total bytes // If both the left and right tables contain total_byte_size statistics, // use `total_byte_size` to determine `should_swap_join_order`, else use `num_rows` - let left_stats = left.statistics()?; - let right_stats = right.statistics()?; + let left_stats = left.partition_statistics(None)?; + let right_stats = right.partition_statistics(None)?; // First compare `total_byte_size` of left and right side, // if information in this field is insufficient fallback to the `num_rows` match ( @@ -91,7 +89,7 @@ fn supports_collect_by_thresholds( ) -> bool { // Currently we do not trust the 0 value from stats, due to stats collection might have bug // TODO check the logic in datasource::get_statistics_with_limit() - let Ok(stats) = plan.statistics() else { + let Ok(stats) = plan.partition_statistics(None) else { return false; }; @@ -104,52 +102,6 @@ fn supports_collect_by_thresholds( } } -/// Predicate that checks whether the given join type supports input swapping. -#[deprecated(since = "45.0.0", note = "use JoinType::supports_swap instead")] -#[allow(dead_code)] -pub(crate) fn supports_swap(join_type: JoinType) -> bool { - join_type.supports_swap() -} - -/// This function returns the new join type we get after swapping the given -/// join's inputs. -#[deprecated(since = "45.0.0", note = "use datafusion-functions-nested instead")] -#[allow(dead_code)] -pub(crate) fn swap_join_type(join_type: JoinType) -> JoinType { - join_type.swap() -} - -/// This function swaps the inputs of the given join operator. -/// This function is public so other downstream projects can use it -/// to construct `HashJoinExec` with right side as the build side. -#[deprecated(since = "45.0.0", note = "use HashJoinExec::swap_inputs instead")] -pub fn swap_hash_join( - hash_join: &HashJoinExec, - partition_mode: PartitionMode, -) -> Result> { - hash_join.swap_inputs(partition_mode) -} - -/// Swaps inputs of `NestedLoopJoinExec` and wraps it into `ProjectionExec` is required -#[deprecated(since = "45.0.0", note = "use NestedLoopJoinExec::swap_inputs")] -#[allow(dead_code)] -pub(crate) fn swap_nl_join(join: &NestedLoopJoinExec) -> Result> { - join.swap_inputs() -} - -/// Swaps join sides for filter column indices and produces new `JoinFilter` (if exists). -#[deprecated(since = "45.0.0", note = "use filter.map(JoinFilter::swap) instead")] -#[allow(dead_code)] -fn swap_join_filter(filter: Option<&JoinFilter>) -> Option { - filter.map(JoinFilter::swap) -} - -#[deprecated(since = "45.0.0", note = "use JoinFilter::swap instead")] -#[allow(dead_code)] -pub(crate) fn swap_filter(filter: &JoinFilter) -> JoinFilter { - filter.swap() -} - impl PhysicalOptimizerRule for JoinSelection { fn optimize( &self, @@ -245,7 +197,7 @@ pub(crate) fn try_collect_left( hash_join.join_type(), hash_join.projection.clone(), PartitionMode::CollectLeft, - hash_join.null_equals_null(), + hash_join.null_equality(), )?))) } } @@ -257,7 +209,7 @@ pub(crate) fn try_collect_left( hash_join.join_type(), hash_join.projection.clone(), PartitionMode::CollectLeft, - hash_join.null_equals_null(), + hash_join.null_equality(), )?))), (false, true) => { if hash_join.join_type().supports_swap() { @@ -292,7 +244,7 @@ pub(crate) fn partitioned_hash_join( hash_join.join_type(), hash_join.projection.clone(), PartitionMode::Partitioned, - hash_join.null_equals_null(), + hash_join.null_equality(), )?)) } } @@ -459,7 +411,7 @@ fn hash_join_convert_symmetric_subrule( JoinSide::Right => hash_join.right().output_ordering(), JoinSide::None => unreachable!(), } - .map(|p| LexOrdering::new(p.to_vec())) + .cloned() }) .flatten() }; @@ -474,7 +426,7 @@ fn hash_join_convert_symmetric_subrule( hash_join.on().to_vec(), hash_join.filter().cloned(), hash_join.join_type(), - hash_join.null_equals_null(), + hash_join.null_equality(), left_order, right_order, mode, @@ -562,7 +514,11 @@ pub(crate) fn swap_join_according_to_unboundedness( match (*partition_mode, *join_type) { ( _, - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full, + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark + | JoinType::Full, ) => internal_err!("{join_type} join cannot be swapped for unbounded input."), (PartitionMode::Partitioned, _) => { hash_join.swap_inputs(PartitionMode::Partitioned) diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index 35503f3b0b5f9..79db43c1cbe94 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -19,23 +19,27 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] pub mod aggregate_statistics; +pub mod coalesce_async_exec_input; pub mod coalesce_batches; pub mod combine_partial_final_agg; pub mod enforce_distribution; pub mod enforce_sorting; +pub mod ensure_coop; +pub mod filter_pushdown; pub mod join_selection; pub mod limit_pushdown; +pub mod limit_pushdown_past_window; pub mod limited_distinct_aggregation; pub mod optimizer; pub mod output_requirements; pub mod projection_pushdown; -pub mod pruning; +pub use datafusion_pruning as pruning; pub mod sanity_checker; pub mod topk_aggregation; pub mod update_aggr_exprs; diff --git a/datafusion/physical-optimizer/src/limit_pushdown.rs b/datafusion/physical-optimizer/src/limit_pushdown.rs index 5887cb51a727b..7469c3af9344c 100644 --- a/datafusion/physical-optimizer/src/limit_pushdown.rs +++ b/datafusion/physical-optimizer/src/limit_pushdown.rs @@ -246,16 +246,7 @@ pub fn pushdown_limit_helper( Ok((Transformed::no(pushdown_plan), global_state)) } } else { - // Add fetch or a `LimitExec`: - // If the plan's children have limit and the child's limit < parent's limit, we shouldn't change the global state to true, - // because the children limit will be overridden if the global state is changed. - if !pushdown_plan - .children() - .iter() - .any(|&child| extract_limit(child).is_some()) - { - global_state.satisfied = true; - } + global_state.satisfied = true; pushdown_plan = if let Some(plan_with_fetch) = maybe_fetchable { if global_skip > 0 { add_global_limit(plan_with_fetch, global_skip, Some(global_fetch)) diff --git a/datafusion/physical-optimizer/src/limit_pushdown_past_window.rs b/datafusion/physical-optimizer/src/limit_pushdown_past_window.rs new file mode 100644 index 0000000000000..1c671cd074886 --- /dev/null +++ b/datafusion/physical-optimizer/src/limit_pushdown_past_window.rs @@ -0,0 +1,256 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::PhysicalOptimizerRule; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::ScalarValue; +use datafusion_expr::{LimitEffect, WindowFrameBound, WindowFrameUnits}; +use datafusion_physical_expr::window::{ + PlainAggregateWindowExpr, SlidingAggregateWindowExpr, StandardWindowExpr, + StandardWindowFunctionExpr, WindowExpr, +}; +use datafusion_physical_plan::execution_plan::CardinalityEffect; +use datafusion_physical_plan::limit::GlobalLimitExec; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion_physical_plan::windows::{BoundedWindowAggExec, WindowUDFExpr}; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use std::cmp; +use std::sync::Arc; + +/// This rule inspects [`ExecutionPlan`]'s attempting to find fetch limits that were not pushed +/// down by `LimitPushdown` because [BoundedWindowAggExec]s were "in the way". If the window is +/// bounded by [WindowFrameUnits::Rows] then we calculate the adjustment needed to grow the limit +/// and continue pushdown. +#[derive(Default, Clone, Debug)] +pub struct LimitPushPastWindows; + +impl LimitPushPastWindows { + pub fn new() -> Self { + Self + } +} + +#[derive(Eq, PartialEq)] +enum Phase { + FindOrGrow, + Apply, +} + +#[derive(Default)] +struct TraverseState { + pub limit: Option, + pub lookahead: usize, +} + +impl TraverseState { + pub fn reset_limit(&mut self, limit: Option) { + self.limit = limit; + self.lookahead = 0; + } + + pub fn max_lookahead(&mut self, new_val: usize) { + self.lookahead = self.lookahead.max(new_val); + } +} + +impl PhysicalOptimizerRule for LimitPushPastWindows { + fn optimize( + &self, + original: Arc, + config: &ConfigOptions, + ) -> datafusion_common::Result> { + if !config.optimizer.enable_window_limits { + return Ok(original); + } + let mut ctx = TraverseState::default(); + let mut phase = Phase::FindOrGrow; + let result = original.transform_down(|node| { + // helper closure to DRY out most the early return cases + let reset = |node, + ctx: &mut TraverseState| + -> datafusion_common::Result< + Transformed>, + > { + ctx.limit = None; + ctx.lookahead = 0; + Ok(Transformed::no(node)) + }; + + // traversing sides of joins will require more thought + if node.children().len() > 1 { + return reset(node, &mut ctx); + } + + // grab the latest limit we see + if phase == Phase::FindOrGrow && get_limit(&node, &mut ctx) { + return Ok(Transformed::no(node)); + } + + // grow the limit if we hit a window function + if let Some(window) = node.as_any().downcast_ref::() { + phase = Phase::Apply; + if !grow_limit(window, &mut ctx) { + return reset(node, &mut ctx); + } + return Ok(Transformed::no(node)); + } + + // Apply the limit if we hit a sortpreservingmerge node + if phase == Phase::Apply { + if let Some(out) = apply_limit(&node, &mut ctx) { + return Ok(out); + } + } + + // nodes along the way + if !node.supports_limit_pushdown() { + return reset(node, &mut ctx); + } + if let Some(part) = node.as_any().downcast_ref::() { + let output = part.partitioning().partition_count(); + let input = part.input().output_partitioning().partition_count(); + if output < input { + return reset(node, &mut ctx); + } + } + match node.cardinality_effect() { + CardinalityEffect::Unknown => return reset(node, &mut ctx), + CardinalityEffect::LowerEqual => return reset(node, &mut ctx), + CardinalityEffect::Equal => {} + CardinalityEffect::GreaterEqual => {} + } + + Ok(Transformed::no(node)) + })?; + Ok(result.data) + } + + fn name(&self) -> &str { + "LimitPushPastWindows" + } + + fn schema_check(&self) -> bool { + false // we don't change the schema + } +} + +fn grow_limit(window: &BoundedWindowAggExec, ctx: &mut TraverseState) -> bool { + let mut max_rel = 0; + for expr in window.window_expr().iter() { + // grow based on function requirements + match get_limit_effect(expr) { + LimitEffect::None => {} + LimitEffect::Unknown => return false, + LimitEffect::Relative(rel) => max_rel = max_rel.max(rel), + LimitEffect::Absolute(val) => { + let cur = ctx.limit.unwrap_or(0); + ctx.limit = Some(cur.max(val)) + } + } + + // grow based on frames + let frame = expr.get_window_frame(); + if frame.units != WindowFrameUnits::Rows { + return false; // expression-based limits not statically evaluatable + } + let Some(end_bound) = bound_to_usize(&frame.end_bound) else { + return false; // can't optimize unbounded window expressions + }; + ctx.max_lookahead(end_bound); + } + + // finish grow + ctx.max_lookahead(ctx.lookahead + max_rel); + true +} + +fn apply_limit( + node: &Arc, + ctx: &mut TraverseState, +) -> Option>> { + if !node.as_any().is::() && !node.as_any().is::() { + return None; + } + let latest = ctx.limit.take(); + let Some(fetch) = latest else { + ctx.limit = None; + ctx.lookahead = 0; + return Some(Transformed::no(Arc::clone(node))); + }; + let fetch = match node.fetch() { + None => fetch + ctx.lookahead, + Some(existing) => cmp::min(existing, fetch + ctx.lookahead), + }; + Some(Transformed::complete(node.with_fetch(Some(fetch)).unwrap())) +} + +fn get_limit(node: &Arc, ctx: &mut TraverseState) -> bool { + if let Some(limit) = node.as_any().downcast_ref::() { + ctx.reset_limit(limit.fetch().map(|fetch| fetch + limit.skip())); + return true; + } + if let Some(limit) = node.as_any().downcast_ref::() { + ctx.reset_limit(limit.fetch()); + return true; + } + false +} + +/// Examines the `WindowExpr` and decides: +/// 1. The expression does not change the window size +/// 2. The expression grows it by X amount +/// 3. We don't know +/// +/// # Arguments +/// +/// * `expr` the expression to examine +/// +/// # Returns +/// +/// The effect on the limit +fn get_limit_effect(expr: &Arc) -> LimitEffect { + // White list aggregates + if expr.as_any().is::() + || expr.as_any().is::() + { + return LimitEffect::None; + } + + // Grab the window function + let Some(swe) = expr.as_any().downcast_ref::() else { + return LimitEffect::Unknown; // should be only remaining type + }; + let swfe = swe.get_standard_func_expr(); + let Some(udf) = swfe.as_any().downcast_ref::() else { + return LimitEffect::Unknown; // should be only remaining type + }; + udf.limit_effect() +} + +fn bound_to_usize(bound: &WindowFrameBound) -> Option { + match bound { + WindowFrameBound::Preceding(_) => Some(0), + WindowFrameBound::CurrentRow => Some(0), + WindowFrameBound::Following(ScalarValue::UInt64(Some(scalar))) => { + Some(*scalar as usize) + } + _ => None, + } +} diff --git a/datafusion/physical-optimizer/src/optimizer.rs b/datafusion/physical-optimizer/src/optimizer.rs index bab31150e2508..4d00f1029db71 100644 --- a/datafusion/physical-optimizer/src/optimizer.rs +++ b/datafusion/physical-optimizer/src/optimizer.rs @@ -25,6 +25,8 @@ use crate::coalesce_batches::CoalesceBatches; use crate::combine_partial_final_agg::CombinePartialFinalAggregate; use crate::enforce_distribution::EnforceDistribution; use crate::enforce_sorting::EnforceSorting; +use crate::ensure_coop::EnsureCooperative; +use crate::filter_pushdown::FilterPushdown; use crate::join_selection::JoinSelection; use crate::limit_pushdown::LimitPushdown; use crate::limited_distinct_aggregation::LimitedDistinctAggregation; @@ -34,6 +36,8 @@ use crate::sanity_checker::SanityCheckPlan; use crate::topk_aggregation::TopKAggregation; use crate::update_aggr_exprs::OptimizeAggregateOrder; +use crate::coalesce_async_exec_input::CoalesceAsyncExecInput; +use crate::limit_pushdown_past_window::LimitPushPastWindows; use datafusion_common::config::ConfigOptions; use datafusion_common::Result; use datafusion_physical_plan::ExecutionPlan; @@ -56,7 +60,7 @@ pub trait PhysicalOptimizerRule: Debug { /// A human readable name for this optimizer rule fn name(&self) -> &str; - /// A flag to indicate whether the physical planner should valid the rule will not + /// A flag to indicate whether the physical planner should validate that the rule will not /// change the schema of the plan after the rewriting. /// Some of the optimization rules might change the nullable properties of the schema /// and should disable the schema check. @@ -94,6 +98,12 @@ impl PhysicalOptimizer { // as that rule may inject other operations in between the different AggregateExecs. // Applying the rule early means only directly-connected AggregateExecs must be examined. Arc::new(LimitedDistinctAggregation::new()), + // The FilterPushdown rule tries to push down filters as far as it can. + // For example, it will push down filtering from a `FilterExec` to `DataSourceExec`. + // Note that this does not push down dynamic filters (such as those created by a `SortExec` operator in TopK mode), + // those are handled by the later `FilterPushdown` rule. + // See `FilterPushdownPhase` for more details. + Arc::new(FilterPushdown::new()), // The EnforceDistribution rule is for adding essential repartitioning to satisfy distribution // requirements. Please make sure that the whole plan tree is determined before this rule. // This rule increases parallelism if doing so is beneficial to the physical plan; i.e. at @@ -113,6 +123,7 @@ impl PhysicalOptimizer { // The CoalesceBatches rule will not influence the distribution and ordering of the // whole plan tree. Therefore, to avoid influencing other rules, it should run last. Arc::new(CoalesceBatches::new()), + Arc::new(CoalesceAsyncExecInput::new()), // Remove the ancillary output requirement operator since we are done with the planning // phase. Arc::new(OutputRequirements::new_remove_mode()), @@ -121,6 +132,10 @@ impl PhysicalOptimizer { // into an `order by max(x) limit y`. In this case it will copy the limit value down // to the aggregation, allowing it to use only y number of accumulators. Arc::new(TopKAggregation::new()), + // Tries to push limits down through window functions, growing as appropriate + // This can possibly be combined with [LimitPushdown] + // It needs to come after [EnforceSorting] + Arc::new(LimitPushPastWindows::new()), // The LimitPushdown rule tries to push limits down as far as possible, // replacing operators with fetching variants, or adding limits // past operators that support limit pushdown. @@ -132,6 +147,11 @@ impl PhysicalOptimizer { // are not present, the load of executors such as join or union will be // reduced by narrowing their input tables. Arc::new(ProjectionPushdown::new()), + Arc::new(EnsureCooperative::new()), + // This FilterPushdown handles dynamic filters that may have references to the source ExecutionPlan. + // Therefore it should be run at the end of the optimization process since any changes to the plan may break the dynamic filter's references. + // See `FilterPushdownPhase` for more details. + Arc::new(FilterPushdown::new_post_optimization()), // The SanityCheckPlan rule checks whether the order and // distribution requirements of each node in the plan // is satisfied. It will also reject non-runnable query diff --git a/datafusion/physical-optimizer/src/output_requirements.rs b/datafusion/physical-optimizer/src/output_requirements.rs index 3ca0547aa11d8..9e5e980219767 100644 --- a/datafusion/physical-optimizer/src/output_requirements.rs +++ b/datafusion/physical-optimizer/src/output_requirements.rs @@ -30,16 +30,18 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Result, Statistics}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{Distribution, LexRequirement, PhysicalSortRequirement}; +use datafusion_physical_expr::Distribution; +use datafusion_physical_expr_common::sort_expr::OrderingRequirements; +use datafusion_physical_plan::execution_plan::Boundedness; use datafusion_physical_plan::projection::{ - make_with_child, update_expr, ProjectionExec, + make_with_child, update_expr, update_ordering_requirement, ProjectionExec, }; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, SendableRecordBatchStream, + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + SendableRecordBatchStream, }; -use datafusion_physical_plan::{ExecutionPlanProperties, PlanProperties}; /// This rule either adds or removes [`OutputRequirements`]s to/from the physical /// plan according to its `mode` attribute, which is set by the constructors @@ -94,23 +96,26 @@ enum RuleMode { #[derive(Debug)] pub struct OutputRequirementExec { input: Arc, - order_requirement: Option, + order_requirement: Option, dist_requirement: Distribution, cache: PlanProperties, + fetch: Option, } impl OutputRequirementExec { pub fn new( input: Arc, - requirements: Option, + requirements: Option, dist_requirement: Distribution, + fetch: Option, ) -> Self { - let cache = Self::compute_properties(&input); + let cache = Self::compute_properties(&input, &fetch); Self { input, order_requirement: requirements, dist_requirement, cache, + fetch, } } @@ -119,14 +124,28 @@ impl OutputRequirementExec { } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties(input: &Arc) -> PlanProperties { + fn compute_properties( + input: &Arc, + fetch: &Option, + ) -> PlanProperties { + let boundedness = if fetch.is_some() { + Boundedness::Bounded + } else { + input.boundedness() + }; + PlanProperties::new( input.equivalence_properties().clone(), // Equivalence Properties input.output_partitioning().clone(), // Output Partitioning input.pipeline_behavior(), // Pipeline Behavior - input.boundedness(), // Boundedness + boundedness, // Boundedness ) } + + /// Get fetch + pub fn fetch(&self) -> Option { + self.fetch + } } impl DisplayAs for OutputRequirementExec { @@ -137,10 +156,35 @@ impl DisplayAs for OutputRequirementExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "OutputRequirementExec") + let order_cols = self + .order_requirement + .as_ref() + .map(|reqs| reqs.first()) + .map(|lex| { + let pairs: Vec = lex + .iter() + .map(|req| { + let direction = req + .options + .as_ref() + .map( + |opt| if opt.descending { "desc" } else { "asc" }, + ) + .unwrap_or("unspecified"); + format!("({}, {direction})", req.expr) + }) + .collect(); + format!("[{}]", pairs.join(", ")) + }) + .unwrap_or_else(|| "[]".to_string()); + + write!( + f, + "OutputRequirementExec: order_by={}, dist_by={}", + order_cols, self.dist_requirement + ) } DisplayFormatType::TreeRender => { - // TODO: collect info write!(f, "") } } @@ -176,7 +220,7 @@ impl ExecutionPlan for OutputRequirementExec { vec![&self.input] } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { vec![self.order_requirement.clone()] } @@ -188,6 +232,7 @@ impl ExecutionPlan for OutputRequirementExec { children.remove(0), // has a single child self.order_requirement.clone(), self.dist_requirement.clone(), + self.fetch, ))) } @@ -200,7 +245,11 @@ impl ExecutionPlan for OutputRequirementExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition) } fn try_swapping_with_projection( @@ -208,23 +257,23 @@ impl ExecutionPlan for OutputRequirementExec { projection: &ProjectionExec, ) -> Result>> { // If the projection does not narrow the schema, we should not try to push it down: - if projection.expr().len() >= projection.input().schema().fields().len() { + let proj_exprs = projection.expr(); + if proj_exprs.len() >= projection.input().schema().fields().len() { return Ok(None); } - let mut updated_sort_reqs = LexRequirement::new(vec![]); - // None or empty_vec can be treated in the same way. - if let Some(reqs) = &self.required_input_ordering()[0] { - for req in &reqs.inner { - let Some(new_expr) = update_expr(&req.expr, projection.expr(), false)? + let mut requirements = self.required_input_ordering().swap_remove(0); + if let Some(reqs) = requirements { + let mut updated_reqs = vec![]; + let (lexes, soft) = reqs.into_alternatives(); + for lex in lexes.into_iter() { + let Some(updated_lex) = update_ordering_requirement(lex, proj_exprs)? else { return Ok(None); }; - updated_sort_reqs.push(PhysicalSortRequirement { - expr: new_expr, - options: req.options, - }); + updated_reqs.push(updated_lex); } + requirements = OrderingRequirements::new_alternatives(updated_reqs, soft); } let dist_req = match &self.required_input_distribution()[0] { @@ -242,15 +291,14 @@ impl ExecutionPlan for OutputRequirementExec { dist => dist.clone(), }; - make_with_child(projection, &self.input()) - .map(|input| { - OutputRequirementExec::new( - input, - (!updated_sort_reqs.is_empty()).then_some(updated_sort_reqs), - dist_req, - ) - }) - .map(|e| Some(Arc::new(e) as _)) + make_with_child(projection, &self.input()).map(|input| { + let e = OutputRequirementExec::new(input, requirements, dist_req, self.fetch); + Some(Arc::new(e) as _) + }) + } + + fn fetch(&self) -> Option { + self.fetch } } @@ -298,6 +346,7 @@ fn require_top_ordering(plan: Arc) -> Result() { - // In case of constant columns, output ordering of SortExec would give an empty set. - // Therefore; we check the sort expression field of the SortExec to assign the requirements. + // In case of constant columns, output ordering of the `SortExec` would + // be an empty set. Therefore; we check the sort expression field to + // assign the requirements. + let req_dist = sort_exec.required_input_distribution().swap_remove(0); let req_ordering = sort_exec.expr(); - let req_dist = sort_exec.required_input_distribution()[0].clone(); - let reqs = LexRequirement::from(req_ordering.clone()); + let reqs = OrderingRequirements::from(req_ordering.clone()); + let fetch = sort_exec.fetch(); + Ok(( - Arc::new(OutputRequirementExec::new(plan, Some(reqs), req_dist)) as _, + Arc::new(OutputRequirementExec::new( + plan, + Some(reqs), + req_dist, + fetch, + )) as _, true, )) } else if let Some(spm) = plan.as_any().downcast_ref::() { - let reqs = LexRequirement::from(spm.expr().clone()); + let reqs = OrderingRequirements::from(spm.expr().clone()); + let fetch = spm.fetch(); Ok(( Arc::new(OutputRequirementExec::new( plan, Some(reqs), Distribution::SinglePartition, + fetch, )) as _, true, )) } else if plan.maintains_input_order()[0] - && plan.required_input_ordering()[0].is_none() + && (plan.required_input_ordering()[0] + .as_ref() + .is_none_or(|o| matches!(o, OrderingRequirements::Soft(_)))) { // Keep searching for a `SortExec` as long as ordering is maintained, // and on-the-way operators do not themselves require an ordering. diff --git a/datafusion/physical-optimizer/src/sanity_checker.rs b/datafusion/physical-optimizer/src/sanity_checker.rs index 8edbb0f091140..acc70d39f057b 100644 --- a/datafusion/physical-optimizer/src/sanity_checker.rs +++ b/datafusion/physical-optimizer/src/sanity_checker.rs @@ -137,7 +137,8 @@ pub fn check_plan_sanity( ) { let child_eq_props = child.equivalence_properties(); if let Some(sort_req) = sort_req { - if !child_eq_props.ordering_satisfy_requirement(&sort_req) { + let sort_req = sort_req.into_single(); + if !child_eq_props.ordering_satisfy_requirement(sort_req.clone())? { let plan_str = get_plan_string(&plan); return plan_err!( "Plan: {:?} does not satisfy order requirements: {}. Child-{} order: {}", diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index faedea55ca150..b7505f0df4edb 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -25,7 +25,6 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::LexOrdering; use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::execution_plan::CardinalityEffect; use datafusion_physical_plan::projection::ProjectionExec; @@ -111,11 +110,12 @@ impl TopKAggregation { } } else if let Some(proj) = plan.as_any().downcast_ref::() { // track renames due to successive projections - for (src_expr, proj_name) in proj.expr() { - let Some(src_col) = src_expr.as_any().downcast_ref::() else { + for proj_expr in proj.expr() { + let Some(src_col) = proj_expr.expr.as_any().downcast_ref::() + else { continue; }; - if *proj_name == cur_col_name { + if proj_expr.alias == cur_col_name { cur_col_name = src_col.name().to_string(); } } @@ -131,7 +131,7 @@ impl TopKAggregation { Ok(Transformed::no(plan)) }; let child = Arc::clone(child).transform_down(closure).data().ok()?; - let sort = SortExec::new(LexOrdering::new(sort.expr().to_vec()), child) + let sort = SortExec::new(sort.expr().clone(), child) .with_fetch(sort.fetch()) .with_preserve_partitioning(sort.preserve_partitioning()); Some(Arc::new(sort)) diff --git a/datafusion/physical-optimizer/src/update_aggr_exprs.rs b/datafusion/physical-optimizer/src/update_aggr_exprs.rs index 6228ed10ec341..61bc715592af6 100644 --- a/datafusion/physical-optimizer/src/update_aggr_exprs.rs +++ b/datafusion/physical-optimizer/src/update_aggr_exprs.rs @@ -24,15 +24,10 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_datafusion_err, Result}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; -use datafusion_physical_expr::{ - reverse_order_bys, EquivalenceProperties, PhysicalSortRequirement, -}; -use datafusion_physical_expr::{LexOrdering, LexRequirement}; -use datafusion_physical_plan::aggregates::concat_slices; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; +use datafusion_physical_plan::aggregates::{concat_slices, AggregateExec}; use datafusion_physical_plan::windows::get_ordered_partition_by_indices; -use datafusion_physical_plan::{ - aggregates::AggregateExec, ExecutionPlan, ExecutionPlanProperties, -}; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use crate::PhysicalOptimizerRule; @@ -90,32 +85,30 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder { return Ok(Transformed::no(plan)); } let input = aggr_exec.input(); - let mut aggr_expr = aggr_exec.aggr_expr().to_vec(); + let mut aggr_exprs = aggr_exec.aggr_expr().to_vec(); let groupby_exprs = aggr_exec.group_expr().input_exprs(); // If the existing ordering satisfies a prefix of the GROUP BY // expressions, prefix requirements with this section. In this // case, aggregation will work more efficiently. - let indices = get_ordered_partition_by_indices(&groupby_exprs, input); + let indices = get_ordered_partition_by_indices(&groupby_exprs, input)?; let requirement = indices .iter() .map(|&idx| { PhysicalSortRequirement::new( - Arc::::clone( - &groupby_exprs[idx], - ), + Arc::clone(&groupby_exprs[idx]), None, ) }) .collect::>(); - aggr_expr = try_convert_aggregate_if_better( - aggr_expr, + aggr_exprs = try_convert_aggregate_if_better( + aggr_exprs, &requirement, input.equivalence_properties(), )?; - let aggr_exec = aggr_exec.with_new_aggr_exprs(aggr_expr); + let aggr_exec = aggr_exec.with_new_aggr_exprs(aggr_exprs); Ok(Transformed::yes(Arc::new(aggr_exec) as _)) } else { @@ -159,31 +152,30 @@ fn try_convert_aggregate_if_better( aggr_exprs .into_iter() .map(|aggr_expr| { - let aggr_sort_exprs = aggr_expr.order_bys().unwrap_or(LexOrdering::empty()); - let reverse_aggr_sort_exprs = reverse_order_bys(aggr_sort_exprs); - let aggr_sort_reqs = LexRequirement::from(aggr_sort_exprs.clone()); - let reverse_aggr_req = LexRequirement::from(reverse_aggr_sort_exprs); - + let order_bys = aggr_expr.order_bys(); // If the aggregate expression benefits from input ordering, and // there is an actual ordering enabling this, try to update the // aggregate expression to benefit from the existing ordering. // Otherwise, leave it as is. - if aggr_expr.order_sensitivity().is_beneficial() && !aggr_sort_reqs.is_empty() - { - let reqs = LexRequirement { - inner: concat_slices(prefix_requirement, &aggr_sort_reqs), - }; - - let prefix_requirement = LexRequirement { - inner: prefix_requirement.to_vec(), - }; - - if eq_properties.ordering_satisfy_requirement(&reqs) { + if !aggr_expr.order_sensitivity().is_beneficial() { + Ok(aggr_expr) + } else if !order_bys.is_empty() { + if eq_properties.ordering_satisfy_requirement(concat_slices( + prefix_requirement, + &order_bys + .iter() + .map(|e| e.clone().into()) + .collect::>(), + ))? { // Existing ordering satisfies the aggregator requirements: aggr_expr.with_beneficial_ordering(true)?.map(Arc::new) - } else if eq_properties.ordering_satisfy_requirement(&LexRequirement { - inner: concat_slices(&prefix_requirement, &reverse_aggr_req), - }) { + } else if eq_properties.ordering_satisfy_requirement(concat_slices( + prefix_requirement, + &order_bys + .iter() + .map(|e| e.reverse().into()) + .collect::>(), + ))? { // Converting to reverse enables more efficient execution // given the existing ordering (if possible): aggr_expr diff --git a/datafusion/physical-optimizer/src/utils.rs b/datafusion/physical-optimizer/src/utils.rs index 57a193315a5c3..3655e555a7440 100644 --- a/datafusion/physical-optimizer/src/utils.rs +++ b/datafusion/physical-optimizer/src/utils.rs @@ -17,8 +17,8 @@ use std::sync::Arc; -use datafusion_physical_expr::LexRequirement; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_common::Result; +use datafusion_physical_expr::{LexOrdering, LexRequirement}; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -40,14 +40,18 @@ pub fn add_sort_above( sort_requirements: LexRequirement, fetch: Option, ) -> PlanContext { - let mut sort_expr = LexOrdering::from(sort_requirements); - sort_expr.retain(|sort_expr| { - !node - .plan + let mut sort_reqs: Vec<_> = sort_requirements.into(); + sort_reqs.retain(|sort_expr| { + node.plan .equivalence_properties() .is_expr_constant(&sort_expr.expr) + .is_none() }); - let mut new_sort = SortExec::new(sort_expr, Arc::clone(&node.plan)).with_fetch(fetch); + let sort_exprs = sort_reqs.into_iter().map(Into::into).collect::>(); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + return node; + }; + let mut new_sort = SortExec::new(ordering, Arc::clone(&node.plan)).with_fetch(fetch); if node.plan.output_partitioning().partition_count() > 1 { new_sort = new_sort.with_preserve_partitioning(true); } @@ -61,15 +65,15 @@ pub fn add_sort_above_with_check( node: PlanContext, sort_requirements: LexRequirement, fetch: Option, -) -> PlanContext { +) -> Result> { if !node .plan .equivalence_properties() - .ordering_satisfy_requirement(&sort_requirements) + .ordering_satisfy_requirement(sort_requirements.clone())? { - add_sort_above(node, sort_requirements, fetch) + Ok(add_sort_above(node, sort_requirements, fetch)) } else { - node + Ok(node) } } diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 1f38e2ed31263..607224782fc46 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -36,6 +36,8 @@ workspace = true [features] force_hash_collisions = [] +tokio_coop = [] +tokio_coop_fallback = [] [lib] name = "datafusion_physical_plan" @@ -47,10 +49,11 @@ arrow-ord = { workspace = true } arrow-schema = { workspace = true } async-trait = { workspace = true } chrono = { workspace = true } -datafusion-common = { workspace = true, default-features = true } +datafusion-common = { workspace = true } datafusion-common-runtime = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } datafusion-physical-expr-common = { workspace = true } @@ -81,3 +84,15 @@ tokio = { workspace = true, features = [ [[bench]] harness = false name = "partial_ordering" + +[[bench]] +harness = false +name = "spill_io" + +[[bench]] +harness = false +name = "sort_preserving_merge" + +[[bench]] +harness = false +name = "aggregate_vectorized" diff --git a/datafusion/physical-plan/README.md b/datafusion/physical-plan/README.md index ec604253fd2e5..3a33100f2f350 100644 --- a/datafusion/physical-plan/README.md +++ b/datafusion/physical-plan/README.md @@ -17,11 +17,17 @@ under the License. --> -# DataFusion Physical Plan +# Apache DataFusion Physical Plan -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate is a submodule of DataFusion that contains the `ExecutionPlan` trait and the various implementations of that trait for built in operators such as filters, projections, joins, aggregations, etc. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/physical-plan/benches/aggregate_vectorized.rs b/datafusion/physical-plan/benches/aggregate_vectorized.rs new file mode 100644 index 0000000000000..5c28fcc20440d --- /dev/null +++ b/datafusion/physical-plan/benches/aggregate_vectorized.rs @@ -0,0 +1,302 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::datatypes::{Int32Type, StringViewType}; +use arrow::util::bench_util::{ + create_primitive_array, create_string_view_array_with_len, + create_string_view_array_with_max_len, +}; +use arrow::util::test_util::seedable_rng; +use arrow_schema::DataType; +use criterion::measurement::WallTime; +use criterion::{ + criterion_group, criterion_main, BenchmarkGroup, BenchmarkId, Criterion, +}; +use datafusion_physical_plan::aggregates::group_values::multi_group_by::bytes_view::ByteViewGroupValueBuilder; +use datafusion_physical_plan::aggregates::group_values::multi_group_by::primitive::PrimitiveGroupValueBuilder; +use datafusion_physical_plan::aggregates::group_values::multi_group_by::GroupColumn; +use rand::distr::{Bernoulli, Distribution}; +use std::sync::Arc; + +const SIZES: [usize; 3] = [1_000, 10_000, 100_000]; +const NULL_DENSITIES: [f32; 3] = [0.0, 0.1, 0.5]; + +fn bench_vectorized_append(c: &mut Criterion) { + byte_view_vectorized_append(c); + primitive_vectorized_append(c); +} + +fn byte_view_vectorized_append(c: &mut Criterion) { + let mut group = c.benchmark_group("ByteViewGroupValueBuilder_vectorized_append"); + + for &size in &SIZES { + let rows: Vec = (0..size).collect(); + + for &null_density in &NULL_DENSITIES { + let input = create_string_view_array_with_len(size, null_density, 8, false); + let input: ArrayRef = Arc::new(input); + + bytes_bench(&mut group, "inline", size, &rows, null_density, &input); + } + } + + for &size in &SIZES { + let rows: Vec = (0..size).collect(); + + for &null_density in &NULL_DENSITIES { + let input = create_string_view_array_with_len(size, null_density, 64, true); + let input: ArrayRef = Arc::new(input); + + bytes_bench(&mut group, "scenario", size, &rows, null_density, &input); + } + } + + for &size in &SIZES { + let rows: Vec = (0..size).collect(); + + for &null_density in &NULL_DENSITIES { + let input = create_string_view_array_with_max_len(size, null_density, 400); + let input: ArrayRef = Arc::new(input); + + bytes_bench(&mut group, "random", size, &rows, null_density, &input); + } + } + + group.finish(); +} + +fn bytes_bench( + group: &mut BenchmarkGroup, + bench_prefix: &str, + size: usize, + rows: &Vec, + null_density: f32, + input: &ArrayRef, +) { + // vectorized_append + let function_name = format!("{bench_prefix}_null_{null_density:.1}_size_{size}"); + let id = BenchmarkId::new(&function_name, "vectorized_append"); + group.bench_function(id, |b| { + b.iter(|| { + let mut builder = ByteViewGroupValueBuilder::::new(); + builder.vectorized_append(input, rows).unwrap(); + }); + }); + + // append_val + let id = BenchmarkId::new(&function_name, "append_val"); + group.bench_function(id, |b| { + b.iter(|| { + let mut builder = ByteViewGroupValueBuilder::::new(); + for &i in rows { + builder.append_val(input, i).unwrap(); + } + }); + }); + + // vectorized_equal_to + vectorized_equal_to( + group, + ByteViewGroupValueBuilder::::new(), + &function_name, + rows, + input, + "all_true", + vec![true; size], + ); + vectorized_equal_to( + group, + ByteViewGroupValueBuilder::::new(), + &function_name, + rows, + input, + "0.75 true", + { + let mut rng = seedable_rng(); + let d = Bernoulli::new(0.75).unwrap(); + (0..size).map(|_| d.sample(&mut rng)).collect::>() + }, + ); + vectorized_equal_to( + group, + ByteViewGroupValueBuilder::::new(), + &function_name, + rows, + input, + "0.5 true", + { + let mut rng = seedable_rng(); + let d = Bernoulli::new(0.5).unwrap(); + (0..size).map(|_| d.sample(&mut rng)).collect::>() + }, + ); + vectorized_equal_to( + group, + ByteViewGroupValueBuilder::::new(), + &function_name, + rows, + input, + "0.25 true", + { + let mut rng = seedable_rng(); + let d = Bernoulli::new(0.25).unwrap(); + (0..size).map(|_| d.sample(&mut rng)).collect::>() + }, + ); + // Not adding 0 true case here as if we optimize for 0 true cases the caller should avoid calling this method at all +} + +fn primitive_vectorized_append(c: &mut Criterion) { + let mut group = c.benchmark_group("PrimitiveGroupValueBuilder_vectorized_append"); + + for &size in &SIZES { + let rows: Vec = (0..size).collect(); + + for &null_density in &NULL_DENSITIES { + if null_density == 0.0 { + bench_single_primitive::(&mut group, size, &rows, null_density) + } + bench_single_primitive::(&mut group, size, &rows, null_density); + } + } + + group.finish(); +} + +fn bench_single_primitive( + group: &mut BenchmarkGroup, + size: usize, + rows: &Vec, + null_density: f32, +) { + if !NULLABLE { + assert_eq!( + null_density, 0.0, + "non-nullable case must have null_density 0" + ); + } + + let input = create_primitive_array::(size, null_density); + let input: ArrayRef = Arc::new(input); + let function_name = format!("null_{null_density:.1}_nullable_{NULLABLE}_size_{size}"); + + // vectorized_append + let id = BenchmarkId::new(&function_name, "vectorized_append"); + group.bench_function(id, |b| { + b.iter(|| { + let mut builder = + PrimitiveGroupValueBuilder::::new(DataType::Int32); + builder.vectorized_append(&input, rows).unwrap(); + }); + }); + + // append_val + let id = BenchmarkId::new(&function_name, "append_val"); + group.bench_function(id, |b| { + b.iter(|| { + let mut builder = + PrimitiveGroupValueBuilder::::new(DataType::Int32); + for &i in rows { + builder.append_val(&input, i).unwrap(); + } + }); + }); + + // vectorized_equal_to + vectorized_equal_to( + group, + PrimitiveGroupValueBuilder::::new(DataType::Int32), + &function_name, + rows, + &input, + "all_true", + vec![true; size], + ); + vectorized_equal_to( + group, + PrimitiveGroupValueBuilder::::new(DataType::Int32), + &function_name, + rows, + &input, + "0.75 true", + { + let mut rng = seedable_rng(); + let d = Bernoulli::new(0.75).unwrap(); + (0..size).map(|_| d.sample(&mut rng)).collect::>() + }, + ); + vectorized_equal_to( + group, + PrimitiveGroupValueBuilder::::new(DataType::Int32), + &function_name, + rows, + &input, + "0.5 true", + { + let mut rng = seedable_rng(); + let d = Bernoulli::new(0.5).unwrap(); + (0..size).map(|_| d.sample(&mut rng)).collect::>() + }, + ); + vectorized_equal_to( + group, + PrimitiveGroupValueBuilder::::new(DataType::Int32), + &function_name, + rows, + &input, + "0.25 true", + { + let mut rng = seedable_rng(); + let d = Bernoulli::new(0.25).unwrap(); + (0..size).map(|_| d.sample(&mut rng)).collect::>() + }, + ); + // Not adding 0 true case here as if we optimize for 0 true cases the caller should avoid calling this method at all +} + +/// Test `vectorized_equal_to` with different number of true in the initial results +fn vectorized_equal_to( + group: &mut BenchmarkGroup, + mut builder: GroupColumnBuilder, + function_name: &str, + rows: &[usize], + input: &ArrayRef, + equal_to_result_description: &str, + equal_to_results: Vec, +) { + let id = BenchmarkId::new( + function_name, + format!("vectorized_equal_to_{equal_to_result_description}"), + ); + group.bench_function(id, |b| { + builder.vectorized_append(input, rows).unwrap(); + + b.iter(|| { + // Cloning is a must as `vectorized_equal_to` will modify the input vec + // and without cloning all benchmarks after the first one won't be meaningful + let mut equal_to_results = equal_to_results.clone(); + builder.vectorized_equal_to(rows, input, rows, &mut equal_to_results); + + // Make sure that the compiler does not optimize away the call + criterion::black_box(equal_to_results); + }); + }); +} + +criterion_group!(benches, bench_vectorized_append); +criterion_main!(benches); diff --git a/datafusion/physical-plan/benches/partial_ordering.rs b/datafusion/physical-plan/benches/partial_ordering.rs index 422826abcc8ba..e1a9d0b583e98 100644 --- a/datafusion/physical-plan/benches/partial_ordering.rs +++ b/datafusion/physical-plan/benches/partial_ordering.rs @@ -18,11 +18,10 @@ use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array}; -use arrow_schema::{DataType, Field, Schema, SortOptions}; -use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; use datafusion_physical_plan::aggregates::order::GroupOrderingPartial; +use criterion::{criterion_group, criterion_main, Criterion}; + const BATCH_SIZE: usize = 8192; fn create_test_arrays(num_columns: usize) -> Vec { @@ -39,31 +38,15 @@ fn bench_new_groups(c: &mut Criterion) { // Test with 1, 2, 4, and 8 order indices for num_columns in [1, 2, 4, 8] { - let fields: Vec = (0..num_columns) - .map(|i| Field::new(format!("col{}", i), DataType::Int32, false)) - .collect(); - let schema = Schema::new(fields); - let order_indices: Vec = (0..num_columns).collect(); - let ordering = LexOrdering::new( - (0..num_columns) - .map(|i| { - PhysicalSortExpr::new( - col(&format!("col{}", i), &schema).unwrap(), - SortOptions::default(), - ) - }) - .collect(), - ); - group.bench_function(format!("order_indices_{}", num_columns), |b| { + group.bench_function(format!("order_indices_{num_columns}"), |b| { let batch_group_values = create_test_arrays(num_columns); let group_indices: Vec = (0..BATCH_SIZE).collect(); b.iter(|| { let mut ordering = - GroupOrderingPartial::try_new(&schema, &order_indices, &ordering) - .unwrap(); + GroupOrderingPartial::try_new(order_indices.clone()).unwrap(); ordering .new_groups(&batch_group_values, &group_indices, BATCH_SIZE) .unwrap(); diff --git a/datafusion/physical-plan/benches/sort_preserving_merge.rs b/datafusion/physical-plan/benches/sort_preserving_merge.rs new file mode 100644 index 0000000000000..f223fd806b694 --- /dev/null +++ b/datafusion/physical-plan/benches/sort_preserving_merge.rs @@ -0,0 +1,197 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{ArrayRef, StringArray, UInt64Array}, + record_batch::RecordBatch, +}; +use arrow_schema::{SchemaRef, SortOptions}; +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::test::TestMemoryExec; +use datafusion_physical_plan::{ + collect, sorts::sort_preserving_merge::SortPreservingMergeExec, +}; + +use std::sync::Arc; + +const BENCH_ROWS: usize = 1_000_000; // 1 million rows + +fn get_large_string(idx: usize) -> String { + let base_content = [ + concat!( + "# Advanced Topics in Computer Science\n\n", + "## Summary\nThis article explores complex system design patterns and...\n\n", + "```rust\nfn process_data(data: &mut [i32]) {\n // Parallel processing example\n data.par_iter_mut().for_each(|x| *x *= 2);\n}\n```\n\n", + "## Performance Considerations\nWhen implementing concurrent systems...\n" + ), + concat!( + "## API Documentation\n\n", + "```json\n{\n \"endpoint\": \"/api/v2/users\",\n \"methods\": [\"GET\", \"POST\"],\n \"parameters\": {\n \"page\": \"number\"\n }\n}\n```\n\n", + "# Authentication Guide\nSecure your API access using OAuth 2.0...\n" + ), + concat!( + "# Data Processing Pipeline\n\n", + "```python\nfrom multiprocessing import Pool\n\ndef main():\n with Pool(8) as p:\n results = p.map(process_item, data)\n```\n\n", + "## Summary of Optimizations\n1. Batch processing\n2. Memory pooling\n3. Concurrent I/O operations\n" + ), + concat!( + "# System Architecture Overview\n\n", + "## Components\n- Load Balancer\n- Database Cluster\n- Cache Service\n\n", + "```go\nfunc main() {\n router := gin.Default()\n router.GET(\"/api/health\", healthCheck)\n router.Run(\":8080\")\n}\n```\n" + ), + concat!( + "## Configuration Reference\n\n", + "```yaml\nserver:\n port: 8080\n max_threads: 32\n\ndatabase:\n url: postgres://user@prod-db:5432/main\n```\n\n", + "# Deployment Strategies\nBlue-green deployment patterns with...\n" + ), + ]; + base_content[idx % base_content.len()].to_string() +} + +fn generate_sorted_string_column(rows: usize) -> ArrayRef { + let mut values = Vec::with_capacity(rows); + for i in 0..rows { + values.push(get_large_string(i)); + } + values.sort(); + Arc::new(StringArray::from(values)) +} + +fn generate_sorted_u64_column(rows: usize) -> ArrayRef { + Arc::new(UInt64Array::from((0_u64..rows as u64).collect::>())) +} + +fn create_partitions( + num_partitions: usize, + num_columns: usize, + num_rows: usize, +) -> Vec> { + (0..num_partitions) + .map(|_| { + let rows = (0..num_columns) + .map(|i| { + ( + format!("col-{i}"), + if IS_LARGE_COLUMN_TYPE { + generate_sorted_string_column(num_rows) + } else { + generate_sorted_u64_column(num_rows) + }, + ) + }) + .collect::>(); + + let batch = RecordBatch::try_from_iter(rows).unwrap(); + vec![batch] + }) + .collect() +} + +struct BenchData { + bench_name: String, + partitions: Vec>, + schema: SchemaRef, + sort_order: LexOrdering, +} + +fn get_bench_data() -> Vec { + let mut ret = Vec::new(); + let mut push_bench_data = |bench_name: &str, partitions: Vec>| { + let schema = partitions[0][0].schema(); + // Define sort order (col1 ASC, col2 ASC, col3 ASC) + let sort_order = LexOrdering::new(schema.fields().iter().map(|field| { + PhysicalSortExpr::new( + col(field.name(), &schema).unwrap(), + SortOptions::default(), + ) + })) + .unwrap(); + ret.push(BenchData { + bench_name: bench_name.to_string(), + partitions, + schema, + sort_order, + }); + }; + // 1. single large string column + { + let partitions = create_partitions::(3, 1, BENCH_ROWS); + push_bench_data("single_large_string_column_with_1m_rows", partitions); + } + // 2. single u64 column + { + let partitions = create_partitions::(3, 1, BENCH_ROWS); + push_bench_data("single_u64_column_with_1m_rows", partitions); + } + // 3. multiple large string columns + { + let partitions = create_partitions::(3, 3, BENCH_ROWS); + push_bench_data("multiple_large_string_columns_with_1m_rows", partitions); + } + // 4. multiple u64 columns + { + let partitions = create_partitions::(3, 3, BENCH_ROWS); + push_bench_data("multiple_u64_columns_with_1m_rows", partitions); + } + ret +} + +/// Add a benchmark to test the optimization effect of reusing Rows. +/// Run this benchmark with: +/// ```sh +/// cargo bench --features="bench" --bench sort_preserving_merge -- --sample-size=10 +/// ``` +fn bench_merge_sorted_preserving(c: &mut Criterion) { + let task_ctx = Arc::new(TaskContext::default()); + let bench_data = get_bench_data(); + for data in bench_data.into_iter() { + let BenchData { + bench_name, + partitions, + schema, + sort_order, + } = data; + c.bench_function( + &format!("bench_merge_sorted_preserving/{bench_name}"), + |b| { + b.iter_batched( + || { + let exec = TestMemoryExec::try_new_exec( + &partitions, + schema.clone(), + None, + ) + .unwrap(); + Arc::new(SortPreservingMergeExec::new(sort_order.clone(), exec)) + }, + |merge_exec| { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + collect(merge_exec, task_ctx.clone()).await.unwrap(); + }); + }, + BatchSize::LargeInput, + ) + }, + ); + } +} + +criterion_group!(benches, bench_merge_sorted_preserving); +criterion_main!(benches); diff --git a/datafusion/physical-plan/benches/spill_io.rs b/datafusion/physical-plan/benches/spill_io.rs new file mode 100644 index 0000000000000..40c8f7634c8c4 --- /dev/null +++ b/datafusion/physical-plan/benches/spill_io.rs @@ -0,0 +1,580 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Date32Builder, Decimal128Builder, Int32Builder, Int64Builder, RecordBatch, + StringBuilder, +}; +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::measurement::WallTime; +use criterion::{ + criterion_group, criterion_main, BatchSize, BenchmarkGroup, BenchmarkId, Criterion, +}; +use datafusion_common::config::SpillCompression; +use datafusion_common::instant::Instant; +use datafusion_execution::memory_pool::human_readable_size; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_physical_plan::common::collect; +use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, SpillMetrics}; +use datafusion_physical_plan::SpillManager; +use rand::{Rng, SeedableRng}; +use std::sync::Arc; +use tokio::runtime::Runtime; + +pub fn create_batch(num_rows: usize, allow_nulls: bool) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("c0", DataType::Int32, true), + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Date32, true), + Field::new("c3", DataType::Decimal128(11, 2), true), + ])); + + let mut a = Int32Builder::new(); + let mut b = StringBuilder::new(); + let mut c = Date32Builder::new(); + let mut d = Decimal128Builder::new() + .with_precision_and_scale(11, 2) + .unwrap(); + + for i in 0..num_rows { + a.append_value(i as i32); + c.append_value(i as i32); + d.append_value((i * 1000000) as i128); + if allow_nulls && i % 10 == 0 { + b.append_null(); + } else { + b.append_value(format!("this is string number {i}")); + } + } + + let a = a.finish(); + let b = b.finish(); + let c = c.finish(); + let d = d.finish(); + + RecordBatch::try_new( + schema.clone(), + vec![Arc::new(a), Arc::new(b), Arc::new(c), Arc::new(d)], + ) + .unwrap() +} + +// BENCHMARK: REVALIDATION OVERHEAD COMPARISON +// --------------------------------------------------------- +// To compare performance with/without Arrow IPC validation: +// +// 1. Locate the function `read_spill` +// 2. Modify the `skip_validation` flag: +// - Set to `false` to enable validation +// 3. Rerun `cargo bench --bench spill_io` +fn bench_spill_io(c: &mut Criterion) { + let env = Arc::new(RuntimeEnv::default()); + let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let schema = Arc::new(Schema::new(vec![ + Field::new("c0", DataType::Int32, true), + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Date32, true), + Field::new("c3", DataType::Decimal128(11, 2), true), + ])); + let spill_manager = SpillManager::new(env, metrics, schema); + + let mut group = c.benchmark_group("spill_io"); + let rt = Runtime::new().unwrap(); + + group.bench_with_input( + BenchmarkId::new("StreamReader/read_100", ""), + &spill_manager, + |b, spill_manager| { + b.iter_batched( + // Setup phase: Create fresh state for each benchmark iteration. + // - generate an ipc file. + // This ensures each iteration starts with clean resources. + || { + let batch = create_batch(8192, true); + spill_manager + .spill_record_batch_and_finish(&vec![batch; 100], "Test") + .unwrap() + .unwrap() + }, + // Benchmark phase: + // - Execute the read operation via SpillManager + // - Wait for the consumer to finish processing + |spill_file| { + rt.block_on(async { + let stream = spill_manager + .read_spill_as_stream(spill_file, None) + .unwrap(); + let _ = collect(stream).await.unwrap(); + }) + }, + BatchSize::LargeInput, + ) + }, + ); + group.finish(); +} + +// Generate `num_batches` RecordBatches mimicking TPC-H Q2's partial aggregate result: +// GROUP BY ps_partkey -> MIN(ps_supplycost) +fn create_q2_like_batches( + num_batches: usize, + num_rows: usize, +) -> (Arc, Vec) { + // use fixed seed + let seed = 2; + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + let mut batches = Vec::with_capacity(num_batches); + + let mut current_key = 400000_i64; + + let schema = Arc::new(Schema::new(vec![ + Field::new("ps_partkey", DataType::Int64, false), + Field::new("min_ps_supplycost", DataType::Decimal128(15, 2), true), + ])); + + for _ in 0..num_batches { + let mut partkey_builder = Int64Builder::new(); + let mut cost_builder = Decimal128Builder::new() + .with_precision_and_scale(15, 2) + .unwrap(); + + for _ in 0..num_rows { + // Occasionally skip a few partkey values to simulate sparsity + let jump = if rng.random_bool(0.05) { + rng.random_range(2..10) + } else { + 1 + }; + current_key += jump; + + let supply_cost = rng.random_range(10_00..100_000) as i128; + + partkey_builder.append_value(current_key); + cost_builder.append_value(supply_cost); + } + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(partkey_builder.finish()), + Arc::new(cost_builder.finish()), + ], + ) + .unwrap(); + + batches.push(batch); + } + + (schema, batches) +} + +/// Generate `num_batches` RecordBatches mimicking TPC-H Q16's partial aggregate result: +/// GROUP BY (p_brand, p_type, p_size) -> COUNT(DISTINCT ps_suppkey) +pub fn create_q16_like_batches( + num_batches: usize, + num_rows: usize, +) -> (Arc, Vec) { + let seed = 16; + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + let mut batches = Vec::with_capacity(num_batches); + + let schema = Arc::new(Schema::new(vec![ + Field::new("p_brand", DataType::Utf8, false), + Field::new("p_type", DataType::Utf8, false), + Field::new("p_size", DataType::Int32, false), + Field::new("alias1", DataType::Int64, false), // COUNT(DISTINCT ps_suppkey) + ])); + + // Representative string pools + let brands = ["Brand#32", "Brand#33", "Brand#41", "Brand#42", "Brand#55"]; + let types = [ + "PROMO ANODIZED NICKEL", + "STANDARD BRUSHED NICKEL", + "PROMO POLISHED COPPER", + "ECONOMY ANODIZED BRASS", + "LARGE BURNISHED COPPER", + "STANDARD POLISHED TIN", + "SMALL PLATED STEEL", + "MEDIUM POLISHED COPPER", + ]; + let sizes = [3, 9, 14, 19, 23, 36, 45, 49]; + + for _ in 0..num_batches { + let mut brand_builder = StringBuilder::new(); + let mut type_builder = StringBuilder::new(); + let mut size_builder = Int32Builder::new(); + let mut count_builder = Int64Builder::new(); + + for _ in 0..num_rows { + let brand = brands[rng.random_range(0..brands.len())]; + let ptype = types[rng.random_range(0..types.len())]; + let size = sizes[rng.random_range(0..sizes.len())]; + let count = rng.random_range(1000..100_000); + + brand_builder.append_value(brand); + type_builder.append_value(ptype); + size_builder.append_value(size); + count_builder.append_value(count); + } + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(brand_builder.finish()), + Arc::new(type_builder.finish()), + Arc::new(size_builder.finish()), + Arc::new(count_builder.finish()), + ], + ) + .unwrap(); + + batches.push(batch); + } + + (schema, batches) +} + +// Generate `num_batches` RecordBatches mimicking TPC-H Q20's partial aggregate result: +// GROUP BY (l_partkey, l_suppkey) -> SUM(l_quantity) +fn create_q20_like_batches( + num_batches: usize, + num_rows: usize, +) -> (Arc, Vec) { + let seed = 20; + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + let mut batches = Vec::with_capacity(num_batches); + + let mut current_partkey = 400000_i64; + + let schema = Arc::new(Schema::new(vec![ + Field::new("l_partkey", DataType::Int64, false), + Field::new("l_suppkey", DataType::Int64, false), + Field::new("sum_l_quantity", DataType::Decimal128(25, 2), true), + ])); + + for _ in 0..num_batches { + let mut partkey_builder = Int64Builder::new(); + let mut suppkey_builder = Int64Builder::new(); + let mut quantity_builder = Decimal128Builder::new() + .with_precision_and_scale(25, 2) + .unwrap(); + + for _ in 0..num_rows { + // Occasionally skip a few partkey values to simulate sparsity + let partkey_jump = if rng.random_bool(0.03) { + rng.random_range(2..6) + } else { + 1 + }; + current_partkey += partkey_jump; + + let suppkey = rng.random_range(10_000..99_999); + let quantity = rng.random_range(500..20_000) as i128; + + partkey_builder.append_value(current_partkey); + suppkey_builder.append_value(suppkey); + quantity_builder.append_value(quantity); + } + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(partkey_builder.finish()), + Arc::new(suppkey_builder.finish()), + Arc::new(quantity_builder.finish()), + ], + ) + .unwrap(); + + batches.push(batch); + } + + (schema, batches) +} + +/// Generate `num_batches` wide RecordBatches resembling sort-tpch Q10 for benchmarking. +/// This includes multiple numeric, date, and Utf8View columns (15 total). +pub fn create_wide_batches( + num_batches: usize, + num_rows: usize, +) -> (Arc, Vec) { + let seed = 10; + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + let mut batches = Vec::with_capacity(num_batches); + + let schema = Arc::new(Schema::new(vec![ + Field::new("l_linenumber", DataType::Int32, false), + Field::new("l_suppkey", DataType::Int64, false), + Field::new("l_orderkey", DataType::Int64, false), + Field::new("l_partkey", DataType::Int64, false), + Field::new("l_quantity", DataType::Decimal128(15, 2), false), + Field::new("l_extendedprice", DataType::Decimal128(15, 2), false), + Field::new("l_discount", DataType::Decimal128(15, 2), false), + Field::new("l_tax", DataType::Decimal128(15, 2), false), + Field::new("l_returnflag", DataType::Utf8, false), + Field::new("l_linestatus", DataType::Utf8, false), + Field::new("l_shipdate", DataType::Date32, false), + Field::new("l_commitdate", DataType::Date32, false), + Field::new("l_receiptdate", DataType::Date32, false), + Field::new("l_shipinstruct", DataType::Utf8, false), + Field::new("l_shipmode", DataType::Utf8, false), + ])); + + for _ in 0..num_batches { + let mut linenum = Int32Builder::new(); + let mut suppkey = Int64Builder::new(); + let mut orderkey = Int64Builder::new(); + let mut partkey = Int64Builder::new(); + let mut quantity = Decimal128Builder::new() + .with_precision_and_scale(15, 2) + .unwrap(); + let mut extprice = Decimal128Builder::new() + .with_precision_and_scale(15, 2) + .unwrap(); + let mut discount = Decimal128Builder::new() + .with_precision_and_scale(15, 2) + .unwrap(); + let mut tax = Decimal128Builder::new() + .with_precision_and_scale(15, 2) + .unwrap(); + let mut retflag = StringBuilder::new(); + let mut linestatus = StringBuilder::new(); + let mut shipdate = Date32Builder::new(); + let mut commitdate = Date32Builder::new(); + let mut receiptdate = Date32Builder::new(); + let mut shipinstruct = StringBuilder::new(); + let mut shipmode = StringBuilder::new(); + + let return_flags = ["A", "N", "R"]; + let statuses = ["F", "O"]; + let instructs = ["DELIVER IN PERSON", "COLLECT COD", "NONE"]; + let modes = ["TRUCK", "MAIL", "SHIP", "RAIL", "AIR"]; + + for i in 0..num_rows { + linenum.append_value((i % 7) as i32); + suppkey.append_value(rng.random_range(0..100_000)); + orderkey.append_value(1_000_000 + i as i64); + partkey.append_value(rng.random_range(0..200_000)); + + quantity.append_value(rng.random_range(100..10000) as i128); + extprice.append_value(rng.random_range(1_000..1_000_000) as i128); + discount.append_value(rng.random_range(0..10000) as i128); + tax.append_value(rng.random_range(0..5000) as i128); + + retflag.append_value(return_flags[rng.random_range(0..return_flags.len())]); + linestatus.append_value(statuses[rng.random_range(0..statuses.len())]); + + let base_date = 10_000; + shipdate.append_value(base_date + (i % 1000) as i32); + commitdate.append_value(base_date + (i % 1000) as i32 + 1); + receiptdate.append_value(base_date + (i % 1000) as i32 + 2); + + shipinstruct.append_value(instructs[rng.random_range(0..instructs.len())]); + shipmode.append_value(modes[rng.random_range(0..modes.len())]); + } + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(linenum.finish()), + Arc::new(suppkey.finish()), + Arc::new(orderkey.finish()), + Arc::new(partkey.finish()), + Arc::new(quantity.finish()), + Arc::new(extprice.finish()), + Arc::new(discount.finish()), + Arc::new(tax.finish()), + Arc::new(retflag.finish()), + Arc::new(linestatus.finish()), + Arc::new(shipdate.finish()), + Arc::new(commitdate.finish()), + Arc::new(receiptdate.finish()), + Arc::new(shipinstruct.finish()), + Arc::new(shipmode.finish()), + ], + ) + .unwrap(); + batches.push(batch); + } + (schema, batches) +} + +// Benchmarks spill write + read performance across multiple compression codecs +// using realistic input data inspired by TPC-H aggregate spill scenarios. +// +// This function prepares synthetic RecordBatches that mimic the schema and distribution +// of intermediate aggregate results from representative TPC-H queries (Q2, Q16, Q20) and sort-tpch Q10. +// The schemas of these batches are: +// Q2 [Int64, Decimal128] +// Q16 [Utf8, Utf8, Int32, Int64] +// Q20 [Int64, Int64, Decimal128] +// sort-tpch Q10 (wide batch) [Int32, Int64 * 3, Decimal128 * 4, Date * 3, Utf8 * 4] +// For each dataset: +// - It evaluates spill performance under different compression codecs (e.g., Uncompressed, Zstd, LZ4). +// - It measures end-to-end spill write + read performance using Criterion. +// - It prints the observed memory-to-disk compression ratio for each codec. +// +// This helps evaluate the tradeoffs between compression ratio and runtime overhead for various codecs. +fn bench_spill_compression(c: &mut Criterion) { + let env = Arc::new(RuntimeEnv::default()); + let mut group = c.benchmark_group("spill_compression"); + let rt = Runtime::new().unwrap(); + let compressions = vec![ + SpillCompression::Uncompressed, + SpillCompression::Zstd, + SpillCompression::Lz4Frame, + ]; + + // Modify these values to change data volume. Note that each batch contains `num_rows` rows. + let num_batches = 50; + let num_rows = 8192; + + // Q2 [Int64, Decimal128] + let (schema, batches) = create_q2_like_batches(num_batches, num_rows); + benchmark_spill_batches_for_all_codec( + &mut group, + "q2", + batches, + &compressions, + &rt, + env.clone(), + schema, + ); + // Q16 [Utf8, Utf8, Int32, Int64] + let (schema, batches) = create_q16_like_batches(num_batches, num_rows); + benchmark_spill_batches_for_all_codec( + &mut group, + "q16", + batches, + &compressions, + &rt, + env.clone(), + schema, + ); + // Q20 [Int64, Int64, Decimal128] + let (schema, batches) = create_q20_like_batches(num_batches, num_rows); + benchmark_spill_batches_for_all_codec( + &mut group, + "q20", + batches, + &compressions, + &rt, + env.clone(), + schema, + ); + // sort-tpch Q10 (wide batch) [Int32, Int64 * 3, Decimal128 * 4, Date * 3, Utf8 * 4] + let (schema, batches) = create_wide_batches(num_batches, num_rows); + benchmark_spill_batches_for_all_codec( + &mut group, + "wide", + batches, + &compressions, + &rt, + env, + schema, + ); + group.finish(); +} + +fn benchmark_spill_batches_for_all_codec( + group: &mut BenchmarkGroup<'_, WallTime>, + batch_label: &str, + batches: Vec, + compressions: &[SpillCompression], + rt: &Runtime, + env: Arc, + schema: Arc, +) { + let mem_bytes: usize = batches.iter().map(|b| b.get_array_memory_size()).sum(); + + for &compression in compressions { + let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let spill_manager = + SpillManager::new(Arc::clone(&env), metrics.clone(), Arc::clone(&schema)) + .with_compression_type(compression); + + let bench_id = BenchmarkId::new(batch_label, compression.to_string()); + group.bench_with_input(bench_id, &spill_manager, |b, spill_manager| { + b.iter_batched( + || batches.clone(), + |batches| { + rt.block_on(async { + let spill_file = spill_manager + .spill_record_batch_and_finish( + &batches, + &format!("{batch_label}_{compression}"), + ) + .unwrap() + .unwrap(); + let stream = spill_manager + .read_spill_as_stream(spill_file, None) + .unwrap(); + let _ = collect(stream).await.unwrap(); + }) + }, + BatchSize::LargeInput, + ) + }); + + // Run Spilling Read & Write once more to read file size & calculate bandwidth + let start = Instant::now(); + + let spill_file = spill_manager + .spill_record_batch_and_finish( + &batches, + &format!("{batch_label}_{compression}"), + ) + .unwrap() + .unwrap(); + + // calculate write_throughput (includes both compression and I/O time) based on in memory batch size + let write_time = start.elapsed(); + let write_throughput = (mem_bytes as u128 / write_time.as_millis().max(1)) * 1000; + + // calculate compression ratio + let disk_bytes = std::fs::metadata(spill_file.path()) + .expect("metadata read fail") + .len() as usize; + let ratio = mem_bytes as f64 / disk_bytes.max(1) as f64; + + // calculate read_throughput (includes both compression and I/O time) based on in memory batch size + let rt = Runtime::new().unwrap(); + let start = Instant::now(); + rt.block_on(async { + let stream = spill_manager + .read_spill_as_stream(spill_file, None) + .unwrap(); + let _ = collect(stream).await.unwrap(); + }); + let read_time = start.elapsed(); + let read_throughput = (mem_bytes as u128 / read_time.as_millis().max(1)) * 1000; + + println!( + "[{} | {:?}] mem: {}| disk: {}| compression ratio: {:.3}x| throughput: (w) {}/s (r) {}/s", + batch_label, + compression, + human_readable_size(mem_bytes), + human_readable_size(disk_bytes), + ratio, + human_readable_size(write_throughput as usize), + human_readable_size(read_throughput as usize), + ); + } +} + +criterion_group!(benches, bench_spill_io, bench_spill_compression); +criterion_main!(benches); diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index ce56ca4f7dfd7..316fbe11ae313 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -28,7 +28,7 @@ use datafusion_common::Result; use datafusion_expr::EmitTo; -pub(crate) mod multi_group_by; +pub mod multi_group_by; mod row; mod single_group_by; @@ -40,8 +40,8 @@ pub(crate) use single_group_by::primitive::HashValue; use crate::aggregates::{ group_values::single_group_by::{ - bytes::GroupValuesByes, bytes_view::GroupValuesBytesView, - primitive::GroupValuesPrimitive, + boolean::GroupValuesBoolean, bytes::GroupValuesBytes, + bytes_view::GroupValuesBytesView, primitive::GroupValuesPrimitive, }, order::GroupOrdering, }; @@ -84,7 +84,7 @@ mod null_builder; /// Each distinct group in a hash aggregation is identified by a unique group id /// (usize) which is assigned by instances of this trait. Group ids are /// continuous without gaps, starting from 0. -pub(crate) trait GroupValues: Send { +pub trait GroupValues: Send { /// Calculates the group id for each input row of `cols`, assigning new /// group ids as necessary. /// @@ -119,15 +119,17 @@ pub(crate) trait GroupValues: Send { /// - If group by single column, and type of this column has /// the specific [`GroupValues`] implementation, such implementation /// will be chosen. -/// +/// /// - If group by multiple columns, and all column types have the specific -/// [`GroupColumn`] implementations, [`GroupValuesColumn`] will be chosen. +/// `GroupColumn` implementations, `GroupValuesColumn` will be chosen. /// -/// - Otherwise, the general implementation [`GroupValuesRows`] will be chosen. +/// - Otherwise, the general implementation `GroupValuesRows` will be chosen. /// -/// [`GroupColumn`]: crate::aggregates::group_values::multi_group_by::GroupColumn +/// `GroupColumn`: crate::aggregates::group_values::multi_group_by::GroupColumn +/// `GroupValuesColumn`: crate::aggregates::group_values::multi_group_by::GroupValuesColumn +/// `GroupValuesRows`: crate::aggregates::group_values::row::GroupValuesRows /// -pub(crate) fn new_group_values( +pub fn new_group_values( schema: SchemaRef, group_ordering: &GroupOrdering, ) -> Result> { @@ -172,23 +174,26 @@ pub(crate) fn new_group_values( downcast_helper!(Decimal128Type, d); } DataType::Utf8 => { - return Ok(Box::new(GroupValuesByes::::new(OutputType::Utf8))); + return Ok(Box::new(GroupValuesBytes::::new(OutputType::Utf8))); } DataType::LargeUtf8 => { - return Ok(Box::new(GroupValuesByes::::new(OutputType::Utf8))); + return Ok(Box::new(GroupValuesBytes::::new(OutputType::Utf8))); } DataType::Utf8View => { return Ok(Box::new(GroupValuesBytesView::new(OutputType::Utf8View))); } DataType::Binary => { - return Ok(Box::new(GroupValuesByes::::new(OutputType::Binary))); + return Ok(Box::new(GroupValuesBytes::::new(OutputType::Binary))); } DataType::LargeBinary => { - return Ok(Box::new(GroupValuesByes::::new(OutputType::Binary))); + return Ok(Box::new(GroupValuesBytes::::new(OutputType::Binary))); } DataType::BinaryView => { return Ok(Box::new(GroupValuesBytesView::new(OutputType::BinaryView))); } + DataType::Boolean => { + return Ok(Box::new(GroupValuesBoolean::new())); + } _ => {} } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/boolean.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/boolean.rs new file mode 100644 index 0000000000000..03e26446f5751 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/boolean.rs @@ -0,0 +1,475 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::aggregates::group_values::multi_group_by::Nulls; +use crate::aggregates::group_values::multi_group_by::{nulls_equal_to, GroupColumn}; +use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; +use arrow::array::{Array as _, ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder}; +use datafusion_common::Result; +use itertools::izip; + +/// An implementation of [`GroupColumn`] for booleans +/// +/// Optimized to skip null buffer construction if the input is known to be non nullable +/// +/// # Template parameters +/// +/// `NULLABLE`: if the data can contain any nulls +#[derive(Debug)] +pub struct BooleanGroupValueBuilder { + buffer: BooleanBufferBuilder, + nulls: MaybeNullBufferBuilder, +} + +impl BooleanGroupValueBuilder { + /// Create a new `BooleanGroupValueBuilder` + pub fn new() -> Self { + Self { + buffer: BooleanBufferBuilder::new(0), + nulls: MaybeNullBufferBuilder::new(), + } + } +} + +impl GroupColumn for BooleanGroupValueBuilder { + fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { + if NULLABLE { + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; + } + } + + self.buffer.get_bit(lhs_row) == array.as_boolean().value(rhs_row) + } + + fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()> { + if NULLABLE { + if array.is_null(row) { + self.nulls.append(true); + self.buffer.append(bool::default()); + } else { + self.nulls.append(false); + self.buffer.append(array.as_boolean().value(row)); + } + } else { + self.buffer.append(array.as_boolean().value(row)); + } + + Ok(()) + } + + fn vectorized_equal_to( + &self, + lhs_rows: &[usize], + array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut [bool], + ) { + let array = array.as_boolean(); + + let iter = izip!( + lhs_rows.iter(), + rhs_rows.iter(), + equal_to_results.iter_mut(), + ); + + for (&lhs_row, &rhs_row, equal_to_result) in iter { + // Has found not equal to in previous column, don't need to check + if !*equal_to_result { + continue; + } + + if NULLABLE { + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + *equal_to_result = result; + continue; + } + } + + *equal_to_result = self.buffer.get_bit(lhs_row) == array.value(rhs_row); + } + } + + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()> { + let arr = array.as_boolean(); + + let null_count = array.null_count(); + let num_rows = array.len(); + let all_null_or_non_null = if null_count == 0 { + Nulls::None + } else if null_count == num_rows { + Nulls::All + } else { + Nulls::Some + }; + + match (NULLABLE, all_null_or_non_null) { + (true, Nulls::Some) => { + for &row in rows { + if array.is_null(row) { + self.nulls.append(true); + self.buffer.append(bool::default()); + } else { + self.nulls.append(false); + self.buffer.append(arr.value(row)); + } + } + } + + (true, Nulls::None) => { + self.nulls.append_n(rows.len(), false); + for &row in rows { + self.buffer.append(arr.value(row)); + } + } + + (true, Nulls::All) => { + self.nulls.append_n(rows.len(), true); + self.buffer.append_n(rows.len(), bool::default()); + } + + (false, _) => { + for &row in rows { + self.buffer.append(arr.value(row)); + } + } + } + + Ok(()) + } + + fn len(&self) -> usize { + self.buffer.len() + } + + fn size(&self) -> usize { + self.buffer.capacity() / 8 + self.nulls.allocated_size() + } + + fn build(self: Box) -> ArrayRef { + let Self { mut buffer, nulls } = *self; + + let nulls = nulls.build(); + if !NULLABLE { + assert!(nulls.is_none(), "unexpected nulls in non nullable input"); + } + + let arr = BooleanArray::new(buffer.finish(), nulls); + + Arc::new(arr) + } + + fn take_n(&mut self, n: usize) -> ArrayRef { + let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None }; + + let mut new_builder = BooleanBufferBuilder::new(self.buffer.len()); + new_builder.append_packed_range(n..self.buffer.len(), self.buffer.as_slice()); + std::mem::swap(&mut new_builder, &mut self.buffer); + + // take only first n values from the original builder + new_builder.truncate(n); + + Arc::new(BooleanArray::new(new_builder.finish(), first_n_nulls)) + } +} + +#[cfg(test)] +mod tests { + use arrow::array::NullBufferBuilder; + + use super::*; + + #[test] + fn test_nullable_boolean_equal_to() { + let append = |builder: &mut BooleanGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + for &index in append_rows { + builder.append_val(builder_array, index).unwrap(); + } + }; + + let equal_to = |builder: &BooleanGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); + } + }; + + test_nullable_boolean_equal_to_internal(append, equal_to); + } + + #[test] + fn test_nullable_primitive_vectorized_equal_to() { + let append = |builder: &mut BooleanGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + builder + .vectorized_append(builder_array, append_rows) + .unwrap(); + }; + + let equal_to = |builder: &BooleanGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; + + test_nullable_boolean_equal_to_internal(append, equal_to); + } + + fn test_nullable_boolean_equal_to_internal(mut append: A, mut equal_to: E) + where + A: FnMut(&mut BooleanGroupValueBuilder, &ArrayRef, &[usize]), + E: FnMut( + &BooleanGroupValueBuilder, + &[usize], + &ArrayRef, + &[usize], + &mut Vec, + ), + { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; values not equal + // - exist not null, input not null; values equal + + // Define PrimitiveGroupValueBuilder + let mut builder = BooleanGroupValueBuilder::::new(); + let builder_array = Arc::new(BooleanArray::from(vec![ + None, + None, + None, + Some(true), + Some(false), + Some(true), + ])) as ArrayRef; + append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5]); + + // Define input array + let (values, _nulls) = BooleanArray::from(vec![ + Some(true), + Some(false), + None, + None, + Some(true), + Some(true), + ]) + .into_parts(); + + // explicitly build a null buffer where one of the null values also happens to match + let mut nulls = NullBufferBuilder::new(6); + nulls.append_non_null(); + nulls.append_null(); // this sets Some(false) to null above + nulls.append_null(); + nulls.append_null(); + nulls.append_non_null(); + nulls.append_non_null(); + let input_array = Arc::new(BooleanArray::new(values, nulls.finish())) as ArrayRef; + + // Check + let mut equal_to_results = vec![true; builder.len()]; + equal_to( + &builder, + &[0, 1, 2, 3, 4, 5], + &input_array, + &[0, 1, 2, 3, 4, 5], + &mut equal_to_results, + ); + + assert!(!equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(!equal_to_results[3]); + assert!(!equal_to_results[4]); + assert!(equal_to_results[5]); + } + + #[test] + fn test_not_nullable_primitive_equal_to() { + let append = |builder: &mut BooleanGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + for &index in append_rows { + builder.append_val(builder_array, index).unwrap(); + } + }; + + let equal_to = |builder: &BooleanGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); + } + }; + + test_not_nullable_boolean_equal_to_internal(append, equal_to); + } + + #[test] + fn test_not_nullable_primitive_vectorized_equal_to() { + let append = |builder: &mut BooleanGroupValueBuilder, + builder_array: &ArrayRef, + append_rows: &[usize]| { + builder + .vectorized_append(builder_array, append_rows) + .unwrap(); + }; + + let equal_to = |builder: &BooleanGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut Vec| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; + + test_not_nullable_boolean_equal_to_internal(append, equal_to); + } + + fn test_not_nullable_boolean_equal_to_internal(mut append: A, mut equal_to: E) + where + A: FnMut(&mut BooleanGroupValueBuilder, &ArrayRef, &[usize]), + E: FnMut( + &BooleanGroupValueBuilder, + &[usize], + &ArrayRef, + &[usize], + &mut Vec, + ), + { + // Will cover such cases: + // - values equal + // - values not equal + + // Define PrimitiveGroupValueBuilder + let mut builder = BooleanGroupValueBuilder::::new(); + let builder_array = Arc::new(BooleanArray::from(vec![ + Some(false), + Some(true), + Some(false), + Some(true), + ])) as ArrayRef; + append(&mut builder, &builder_array, &[0, 1, 2, 3]); + + // Define input array + let input_array = Arc::new(BooleanArray::from(vec![ + Some(false), + Some(false), + Some(true), + Some(true), + ])) as ArrayRef; + + // Check + let mut equal_to_results = vec![true; builder.len()]; + equal_to( + &builder, + &[0, 1, 2, 3], + &input_array, + &[0, 1, 2, 3], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(!equal_to_results[1]); + assert!(!equal_to_results[2]); + assert!(equal_to_results[3]); + } + + #[test] + fn test_nullable_boolean_vectorized_operation_special_case() { + // Test the special `all nulls` or `not nulls` input array case + // for vectorized append and equal to + + let mut builder = BooleanGroupValueBuilder::::new(); + + // All nulls input array + let all_nulls_input_array = + Arc::new(BooleanArray::from(vec![None, None, None, None, None])) as _; + builder + .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); + + let mut equal_to_results = vec![true; all_nulls_input_array.len()]; + builder.vectorized_equal_to( + &[0, 1, 2, 3, 4], + &all_nulls_input_array, + &[0, 1, 2, 3, 4], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(equal_to_results[3]); + assert!(equal_to_results[4]); + + // All not nulls input array + let all_not_nulls_input_array = Arc::new(BooleanArray::from(vec![ + Some(false), + Some(true), + Some(false), + Some(true), + Some(true), + ])) as _; + builder + .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); + + let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; + builder.vectorized_equal_to( + &[5, 6, 7, 8, 9], + &all_not_nulls_input_array, + &[0, 1, 2, 3, 4], + &mut equal_to_results, + ); + + assert!(equal_to_results[0]); + assert!(equal_to_results[1]); + assert!(equal_to_results[2]); + assert!(equal_to_results[3]); + assert!(equal_to_results[4]); + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs index c4525256dbae2..d52721c2ee6c3 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::aggregates::group_values::multi_group_by::{nulls_equal_to, GroupColumn}; +use crate::aggregates::group_values::multi_group_by::{ + nulls_equal_to, GroupColumn, Nulls, +}; use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; use arrow::array::{ types::GenericStringType, Array, ArrayRef, AsArray, BufferBuilder, @@ -24,6 +26,7 @@ use arrow::array::{ use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::datatypes::{ByteArrayType, DataType, GenericBinaryType}; use datafusion_common::utils::proxy::VecAllocExt; +use datafusion_common::{exec_datafusion_err, Result}; use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY}; use itertools::izip; use std::mem::size_of; @@ -50,6 +53,8 @@ where offsets: Vec, /// Nulls nulls: MaybeNullBufferBuilder, + /// The maximum size of the buffer for `0` + max_buffer_size: usize, } impl ByteGroupValueBuilder @@ -62,6 +67,11 @@ where buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), offsets: vec![O::default()], nulls: MaybeNullBufferBuilder::new(), + max_buffer_size: if O::IS_LARGE { + i64::MAX as usize + } else { + i32::MAX as usize + }, } } @@ -73,7 +83,7 @@ where self.do_equal_to_inner(lhs_row, array, rhs_row) } - fn append_val_inner(&mut self, array: &ArrayRef, row: usize) + fn append_val_inner(&mut self, array: &ArrayRef, row: usize) -> Result<()> where B: ByteArrayType, { @@ -85,8 +95,10 @@ where self.offsets.push(O::usize_as(offset)); } else { self.nulls.append(false); - self.do_append_val_inner(arr, row); + self.do_append_val_inner(arr, row)?; } + + Ok(()) } fn vectorized_equal_to_inner( @@ -116,7 +128,11 @@ where } } - fn vectorized_append_inner(&mut self, array: &ArrayRef, rows: &[usize]) + fn vectorized_append_inner( + &mut self, + array: &ArrayRef, + rows: &[usize], + ) -> Result<()> where B: ByteArrayType, { @@ -124,36 +140,28 @@ where let null_count = array.null_count(); let num_rows = array.len(); let all_null_or_non_null = if null_count == 0 { - Some(true) + Nulls::None } else if null_count == num_rows { - Some(false) + Nulls::All } else { - None + Nulls::Some }; match all_null_or_non_null { - None => { + Nulls::Some => { for &row in rows { - if arr.is_null(row) { - self.nulls.append(true); - // nulls need a zero length in the offset buffer - let offset = self.buffer.len(); - self.offsets.push(O::usize_as(offset)); - } else { - self.nulls.append(false); - self.do_append_val_inner(arr, row); - } + self.append_val_inner::(array, row)? } } - Some(true) => { + Nulls::None => { self.nulls.append_n(rows.len(), false); for &row in rows { - self.do_append_val_inner(arr, row); + self.do_append_val_inner(arr, row)?; } } - Some(false) => { + Nulls::All => { self.nulls.append_n(rows.len(), true); let new_len = self.offsets.len() + rows.len(); @@ -161,6 +169,8 @@ where self.offsets.resize(new_len, O::usize_as(offset)); } } + + Ok(()) } fn do_equal_to_inner( @@ -181,13 +191,26 @@ where self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8]) } - fn do_append_val_inner(&mut self, array: &GenericByteArray, row: usize) + fn do_append_val_inner( + &mut self, + array: &GenericByteArray, + row: usize, + ) -> Result<()> where B: ByteArrayType, { let value: &[u8] = array.value(row).as_ref(); self.buffer.append_slice(value); + + if self.buffer.len() > self.max_buffer_size { + return Err(exec_datafusion_err!( + "offset overflow, buffer size > {}", + self.max_buffer_size + )); + } + self.offsets.push(O::usize_as(self.buffer.len())); + Ok(()) } /// return the current value of the specified row irrespective of null @@ -224,7 +247,7 @@ where } } - fn append_val(&mut self, column: &ArrayRef, row: usize) { + fn append_val(&mut self, column: &ArrayRef, row: usize) -> Result<()> { // Sanity array type match self.output_type { OutputType::Binary => { @@ -232,17 +255,19 @@ where column.data_type(), DataType::Binary | DataType::LargeBinary )); - self.append_val_inner::>(column, row) + self.append_val_inner::>(column, row)? } OutputType::Utf8 => { debug_assert!(matches!( column.data_type(), DataType::Utf8 | DataType::LargeUtf8 )); - self.append_val_inner::>(column, row) + self.append_val_inner::>(column, row)? } _ => unreachable!("View types should use `ArrowBytesViewMap`"), }; + + Ok(()) } fn vectorized_equal_to( @@ -282,24 +307,26 @@ where } } - fn vectorized_append(&mut self, column: &ArrayRef, rows: &[usize]) { + fn vectorized_append(&mut self, column: &ArrayRef, rows: &[usize]) -> Result<()> { match self.output_type { OutputType::Binary => { debug_assert!(matches!( column.data_type(), DataType::Binary | DataType::LargeBinary )); - self.vectorized_append_inner::>(column, rows) + self.vectorized_append_inner::>(column, rows)? } OutputType::Utf8 => { debug_assert!(matches!( column.data_type(), DataType::Utf8 | DataType::LargeUtf8 )); - self.vectorized_append_inner::>(column, rows) + self.vectorized_append_inner::>(column, rows)? } _ => unreachable!("View types should use `ArrowBytesViewMap`"), }; + + Ok(()) } fn len(&self) -> usize { @@ -318,6 +345,7 @@ where mut buffer, offsets, nulls, + .. } = *self; let null_buffer = nulls.build(); @@ -406,27 +434,50 @@ mod tests { use crate::aggregates::group_values::multi_group_by::bytes::ByteGroupValueBuilder; use arrow::array::{ArrayRef, NullBufferBuilder, StringArray}; + use datafusion_common::DataFusionError; use datafusion_physical_expr::binary_map::OutputType; use super::GroupColumn; + #[test] + fn test_byte_group_value_builder_overflow() { + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + + let large_string = "a".repeat(1024 * 1024); + + let array = + Arc::new(StringArray::from(vec![Some(large_string.as_str())])) as ArrayRef; + + // Append items until our buffer length is i32::MAX as usize + for _ in 0..2047 { + builder.append_val(&array, 0).unwrap(); + } + + assert!(matches!( + builder.append_val(&array, 0), + Err(DataFusionError::Execution(e)) if e.contains("offset overflow") + )); + + assert_eq!(builder.value(2046), large_string.as_bytes()); + } + #[test] fn test_byte_take_n() { let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); let array = Arc::new(StringArray::from(vec![Some("a"), None])) as ArrayRef; // a, null, null - builder.append_val(&array, 0); - builder.append_val(&array, 1); - builder.append_val(&array, 1); + builder.append_val(&array, 0).unwrap(); + builder.append_val(&array, 1).unwrap(); + builder.append_val(&array, 1).unwrap(); // (a, null) remaining: null let output = builder.take_n(2); assert_eq!(&output, &array); // null, a, null, a - builder.append_val(&array, 0); - builder.append_val(&array, 1); - builder.append_val(&array, 0); + builder.append_val(&array, 0).unwrap(); + builder.append_val(&array, 1).unwrap(); + builder.append_val(&array, 0).unwrap(); // (null, a) remaining: (null, a) let output = builder.take_n(2); @@ -440,9 +491,9 @@ mod tests { ])) as ArrayRef; // null, a, longstringfortest, null, null - builder.append_val(&array, 2); - builder.append_val(&array, 1); - builder.append_val(&array, 1); + builder.append_val(&array, 2).unwrap(); + builder.append_val(&array, 1).unwrap(); + builder.append_val(&array, 1).unwrap(); // (null, a, longstringfortest, null) remaining: (null) let output = builder.take_n(4); @@ -461,7 +512,7 @@ mod tests { builder_array: &ArrayRef, append_rows: &[usize]| { for &index in append_rows { - builder.append_val(builder_array, index); + builder.append_val(builder_array, index).unwrap(); } }; @@ -484,7 +535,9 @@ mod tests { let append = |builder: &mut ByteGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); + builder + .vectorized_append(builder_array, append_rows) + .unwrap(); }; let equal_to = |builder: &ByteGroupValueBuilder, @@ -518,7 +571,9 @@ mod tests { None, None, ])) as _; - builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_nulls_input_array.len()]; builder.vectorized_equal_to( @@ -542,7 +597,9 @@ mod tests { Some("string4"), Some("string5"), ])) as _; - builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; builder.vectorized_equal_to( @@ -578,7 +635,7 @@ mod tests { // - exist not null, input not null; values not equal // - exist not null, input not null; values equal - // Define PrimitiveGroupValueBuilder + // Define ByteGroupValueBuilder let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); let builder_array = Arc::new(StringArray::from(vec![ None, diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs index b6d97b5d788da..fde477c2cf7b5 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs @@ -15,11 +15,14 @@ // specific language governing permissions and limitations // under the License. -use crate::aggregates::group_values::multi_group_by::{nulls_equal_to, GroupColumn}; +use crate::aggregates::group_values::multi_group_by::{ + nulls_equal_to, GroupColumn, Nulls, +}; use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; use arrow::array::{make_view, Array, ArrayRef, AsArray, ByteView, GenericByteViewArray}; use arrow::buffer::{Buffer, ScalarBuffer}; use arrow::datatypes::ByteViewType; +use datafusion_common::Result; use itertools::izip; use std::marker::PhantomData; use std::mem::{replace, size_of}; @@ -70,6 +73,12 @@ pub struct ByteViewGroupValueBuilder { _phantom: PhantomData, } +impl Default for ByteViewGroupValueBuilder { + fn default() -> Self { + Self::new() + } +} + impl ByteViewGroupValueBuilder { pub fn new() -> Self { Self { @@ -138,35 +147,28 @@ impl ByteViewGroupValueBuilder { let null_count = array.null_count(); let num_rows = array.len(); let all_null_or_non_null = if null_count == 0 { - Some(true) + Nulls::None } else if null_count == num_rows { - Some(false) + Nulls::All } else { - None + Nulls::Some }; match all_null_or_non_null { - None => { + Nulls::Some => { for &row in rows { - // Null row case, set and return - if arr.is_valid(row) { - self.nulls.append(false); - self.do_append_val_inner(arr, row); - } else { - self.nulls.append(true); - self.views.push(0); - } + self.append_val_inner(array, row); } } - Some(true) => { + Nulls::None => { self.nulls.append_n(rows.len(), false); for &row in rows { self.do_append_val_inner(arr, row); } } - Some(false) => { + Nulls::All => { self.nulls.append_n(rows.len(), true); let new_len = self.views.len() + rows.len(); self.views.resize(new_len, 0); @@ -493,8 +495,9 @@ impl GroupColumn for ByteViewGroupValueBuilder { self.equal_to_inner(lhs_row, array, rhs_row) } - fn append_val(&mut self, array: &ArrayRef, row: usize) { - self.append_val_inner(array, row) + fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()> { + self.append_val_inner(array, row); + Ok(()) } fn vectorized_equal_to( @@ -507,8 +510,9 @@ impl GroupColumn for ByteViewGroupValueBuilder { self.vectorized_equal_to_inner(group_indices, array, rows, equal_to_results); } - fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) { + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()> { self.vectorized_append_inner(array, rows); + Ok(()) } fn len(&self) -> usize { @@ -563,7 +567,7 @@ mod tests { ]); let builder_array: ArrayRef = Arc::new(builder_array); for row in 0..builder_array.len() { - builder.append_val(&builder_array, row); + builder.append_val(&builder_array, row).unwrap(); } let output = Box::new(builder).build(); @@ -578,7 +582,7 @@ mod tests { builder_array: &ArrayRef, append_rows: &[usize]| { for &index in append_rows { - builder.append_val(builder_array, index); + builder.append_val(builder_array, index).unwrap(); } }; @@ -601,7 +605,9 @@ mod tests { let append = |builder: &mut ByteViewGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); + builder + .vectorized_append(builder_array, append_rows) + .unwrap(); }; let equal_to = |builder: &ByteViewGroupValueBuilder, @@ -636,7 +642,9 @@ mod tests { None, None, ])) as _; - builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_nulls_input_array.len()]; builder.vectorized_equal_to( @@ -660,7 +668,9 @@ mod tests { Some("stringview4"), Some("stringview5"), ])) as _; - builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; builder.vectorized_equal_to( @@ -841,7 +851,7 @@ mod tests { // ####### Test situation 1~5 ####### for row in 0..first_ones_to_append { - builder.append_val(&input_array, row); + builder.append_val(&input_array, row).unwrap(); } assert_eq!(builder.completed.len(), 2); @@ -879,7 +889,7 @@ mod tests { assert!(builder.views.is_empty()); for row in first_ones_to_append..first_ones_to_append + second_ones_to_append { - builder.append_val(&input_array, row); + builder.append_val(&input_array, row).unwrap(); } assert!(builder.completed.is_empty()); @@ -894,7 +904,7 @@ mod tests { ByteViewGroupValueBuilder::::new().with_max_block_size(60); for row in 0..final_ones_to_append { - builder.append_val(&input_array, row); + builder.append_val(&input_array, row).unwrap(); } assert_eq!(builder.completed.len(), 3); diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index ac96a98edfe11..58bd35d640c39 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -17,15 +17,16 @@ //! `GroupValues` implementations for multi group by cases +mod boolean; mod bytes; -mod bytes_view; -mod primitive; +pub mod bytes_view; +pub mod primitive; use std::mem::{self, size_of}; use crate::aggregates::group_values::multi_group_by::{ - bytes::ByteGroupValueBuilder, bytes_view::ByteViewGroupValueBuilder, - primitive::PrimitiveGroupValueBuilder, + boolean::BooleanGroupValueBuilder, bytes::ByteGroupValueBuilder, + bytes_view::ByteViewGroupValueBuilder, primitive::PrimitiveGroupValueBuilder, }; use crate::aggregates::group_values::GroupValues; use ahash::RandomState; @@ -40,7 +41,7 @@ use arrow::datatypes::{ UInt8Type, }; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_common::{internal_datafusion_err, not_impl_err, Result}; use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; @@ -65,7 +66,7 @@ pub trait GroupColumn: Send + Sync { fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool; /// Appends the row at `row` in `array` to this builder - fn append_val(&mut self, array: &ArrayRef, row: usize); + fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()>; /// The vectorized version equal to /// @@ -86,11 +87,16 @@ pub trait GroupColumn: Send + Sync { ); /// The vectorized version `append_val` - fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]); + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()>; /// Returns the number of rows stored in this builder fn len(&self) -> usize; + /// true if len == 0 + fn is_empty(&self) -> bool { + self.len() == 0 + } + /// Returns the number of bytes used by this [`GroupColumn`] fn size(&self) -> usize; @@ -270,7 +276,7 @@ impl GroupValuesColumn { map_size: 0, group_values: vec![], hashes_buffer: Default::default(), - random_state: Default::default(), + random_state: crate::aggregates::AGGREGATION_HASH_SEED, }) } @@ -384,7 +390,7 @@ impl GroupValuesColumn { let mut checklen = 0; let group_idx = self.group_values[0].len(); for (i, group_value) in self.group_values.iter_mut().enumerate() { - group_value.append_val(&cols[i], row); + group_value.append_val(&cols[i], row)?; let len = group_value.len(); if i == 0 { checklen = len; @@ -460,14 +466,14 @@ impl GroupValuesColumn { self.collect_vectorized_process_context(&batch_hashes, groups); // 2. Perform `vectorized_append` - self.vectorized_append(cols); + self.vectorized_append(cols)?; // 3. Perform `vectorized_equal_to` self.vectorized_equal_to(cols, groups); // 4. Perform scalarized inter for remaining rows // (about remaining rows, can see comments for `remaining_row_indices`) - self.scalarized_intern_remaining(cols, &batch_hashes, groups); + self.scalarized_intern_remaining(cols, &batch_hashes, groups)?; self.hashes_buffer = batch_hashes; @@ -563,13 +569,13 @@ impl GroupValuesColumn { } /// Perform `vectorized_append`` for `rows` in `vectorized_append_row_indices` - fn vectorized_append(&mut self, cols: &[ArrayRef]) { + fn vectorized_append(&mut self, cols: &[ArrayRef]) -> Result<()> { if self .vectorized_operation_buffers .append_row_indices .is_empty() { - return; + return Ok(()); } let iter = self.group_values.iter_mut().zip(cols.iter()); @@ -577,8 +583,10 @@ impl GroupValuesColumn { group_column.vectorized_append( col, &self.vectorized_operation_buffers.append_row_indices, - ); + )?; } + + Ok(()) } /// Perform `vectorized_equal_to` @@ -719,13 +727,13 @@ impl GroupValuesColumn { cols: &[ArrayRef], batch_hashes: &[u64], groups: &mut [usize], - ) { + ) -> Result<()> { if self .vectorized_operation_buffers .remaining_row_indices .is_empty() { - return; + return Ok(()); } let mut map = mem::take(&mut self.map); @@ -758,7 +766,7 @@ impl GroupValuesColumn { let group_idx = self.group_values[0].len(); let mut checklen = 0; for (i, group_value) in self.group_values.iter_mut().enumerate() { - group_value.append_val(&cols[i], row); + group_value.append_val(&cols[i], row)?; let len = group_value.len(); if i == 0 { checklen = len; @@ -795,6 +803,7 @@ impl GroupValuesColumn { } self.map = map; + Ok(()) } fn scalarized_equal_to_remaining( @@ -1039,6 +1048,15 @@ impl GroupValues for GroupValuesColumn { let b = ByteViewGroupValueBuilder::::new(); v.push(Box::new(b) as _) } + &DataType::Boolean => { + if nullable { + let b = BooleanGroupValueBuilder::::new(); + v.push(Box::new(b) as _) + } else { + let b = BooleanGroupValueBuilder::::new(); + v.push(Box::new(b) as _) + } + } dt => { return not_impl_err!("{dt} not supported in GroupValuesColumn") } @@ -1162,9 +1180,9 @@ impl GroupValues for GroupValuesColumn { if let DataType::Dictionary(_, v) = expected { let actual = array.data_type(); if v.as_ref() != actual { - return Err(DataFusionError::Internal(format!( + return Err(internal_datafusion_err!( "Converted group rows expected dictionary of {v} got {actual}" - ))); + )); } *array = cast(array.as_ref(), expected)?; } @@ -1228,9 +1246,20 @@ fn supported_type(data_type: &DataType) -> bool { | DataType::Timestamp(_, _) | DataType::Utf8View | DataType::BinaryView + | DataType::Boolean ) } +///Shows how many `null`s there are in an array +enum Nulls { + /// All array items are `null`s + All, + /// There are both `null`s and non-`null`s in the array items + Some, + /// There are no `null`s in the array items + None, +} + #[cfg(test)] mod tests { use std::{collections::HashMap, sync::Arc}; @@ -1733,16 +1762,19 @@ mod tests { } fn check_result(actual_batch: &RecordBatch, expected_batch: &RecordBatch) { - let formatted_actual_batch = pretty_format_batches(&[actual_batch.clone()]) - .unwrap() - .to_string(); + let formatted_actual_batch = + pretty_format_batches(std::slice::from_ref(actual_batch)) + .unwrap() + .to_string(); let mut formatted_actual_batch_sorted: Vec<&str> = formatted_actual_batch.trim().lines().collect(); formatted_actual_batch_sorted.sort_unstable(); - let formatted_expected_batch = pretty_format_batches(&[expected_batch.clone()]) - .unwrap() - .to_string(); + let formatted_expected_batch = + pretty_format_batches(std::slice::from_ref(expected_batch)) + .unwrap() + .to_string(); + let mut formatted_expected_batch_sorted: Vec<&str> = formatted_expected_batch.trim().lines().collect(); formatted_expected_batch_sorted.sort_unstable(); @@ -1756,11 +1788,9 @@ mod tests { (i, actual_line), (i, expected_line), "Inconsistent result\n\n\ - Actual batch:\n{}\n\ - Expected batch:\n{}\n\ + Actual batch:\n{formatted_actual_batch}\n\ + Expected batch:\n{formatted_expected_batch}\n\ ", - formatted_actual_batch, - formatted_expected_batch, ); } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs index 005dcc8da3863..a586197e50341 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs @@ -15,11 +15,15 @@ // specific language governing permissions and limitations // under the License. -use crate::aggregates::group_values::multi_group_by::{nulls_equal_to, GroupColumn}; +use crate::aggregates::group_values::multi_group_by::{ + nulls_equal_to, GroupColumn, Nulls, +}; use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; +use arrow::array::ArrowNativeTypeOp; use arrow::array::{cast::AsArray, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; use arrow::buffer::ScalarBuffer; use arrow::datatypes::DataType; +use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; use itertools::izip; use std::iter; @@ -68,10 +72,10 @@ impl GroupColumn // Otherwise, we need to check their values } - self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) + self.group_values[lhs_row].is_eq(array.as_primitive::().value(rhs_row)) } - fn append_val(&mut self, array: &ArrayRef, row: usize) { + fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()> { // Perf: skip null check if input can't have nulls if NULLABLE { if array.is_null(row) { @@ -84,6 +88,8 @@ impl GroupColumn } else { self.group_values.push(array.as_primitive::().value(row)); } + + Ok(()) } fn vectorized_equal_to( @@ -118,25 +124,25 @@ impl GroupColumn // Otherwise, we need to check their values } - *equal_to_result = self.group_values[lhs_row] == array.value(rhs_row); + *equal_to_result = self.group_values[lhs_row].is_eq(array.value(rhs_row)); } } - fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) { + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()> { let arr = array.as_primitive::(); let null_count = array.null_count(); let num_rows = array.len(); let all_null_or_non_null = if null_count == 0 { - Some(true) + Nulls::None } else if null_count == num_rows { - Some(false) + Nulls::All } else { - None + Nulls::Some }; match (NULLABLE, all_null_or_non_null) { - (true, None) => { + (true, Nulls::Some) => { for &row in rows { if array.is_null(row) { self.nulls.append(true); @@ -148,17 +154,17 @@ impl GroupColumn } } - (true, Some(true)) => { + (true, Nulls::None) => { self.nulls.append_n(rows.len(), false); for &row in rows { self.group_values.push(arr.value(row)); } } - (true, Some(false)) => { + (true, Nulls::All) => { self.nulls.append_n(rows.len(), true); self.group_values - .extend(iter::repeat(T::default_value()).take(rows.len())); + .extend(iter::repeat_n(T::default_value(), rows.len())); } (false, _) => { @@ -167,6 +173,8 @@ impl GroupColumn } } } + + Ok(()) } fn len(&self) -> usize { @@ -211,22 +219,22 @@ mod tests { use std::sync::Arc; use crate::aggregates::group_values::multi_group_by::primitive::PrimitiveGroupValueBuilder; - use arrow::array::{ArrayRef, Int64Array, NullBufferBuilder}; - use arrow::datatypes::{DataType, Int64Type}; + use arrow::array::{ArrayRef, Float32Array, Int64Array, NullBufferBuilder}; + use arrow::datatypes::{DataType, Float32Type, Int64Type}; use super::GroupColumn; #[test] fn test_nullable_primitive_equal_to() { - let append = |builder: &mut PrimitiveGroupValueBuilder, + let append = |builder: &mut PrimitiveGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { for &index in append_rows { - builder.append_val(builder_array, index); + builder.append_val(builder_array, index).unwrap(); } }; - let equal_to = |builder: &PrimitiveGroupValueBuilder, + let equal_to = |builder: &PrimitiveGroupValueBuilder, lhs_rows: &[usize], input_array: &ArrayRef, rhs_rows: &[usize], @@ -242,13 +250,15 @@ mod tests { #[test] fn test_nullable_primitive_vectorized_equal_to() { - let append = |builder: &mut PrimitiveGroupValueBuilder, + let append = |builder: &mut PrimitiveGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); + builder + .vectorized_append(builder_array, append_rows) + .unwrap(); }; - let equal_to = |builder: &PrimitiveGroupValueBuilder, + let equal_to = |builder: &PrimitiveGroupValueBuilder, lhs_rows: &[usize], input_array: &ArrayRef, rhs_rows: &[usize], @@ -266,9 +276,9 @@ mod tests { fn test_nullable_primitive_equal_to_internal(mut append: A, mut equal_to: E) where - A: FnMut(&mut PrimitiveGroupValueBuilder, &ArrayRef, &[usize]), + A: FnMut(&mut PrimitiveGroupValueBuilder, &ArrayRef, &[usize]), E: FnMut( - &PrimitiveGroupValueBuilder, + &PrimitiveGroupValueBuilder, &[usize], &ArrayRef, &[usize], @@ -285,48 +295,58 @@ mod tests { // Define PrimitiveGroupValueBuilder let mut builder = - PrimitiveGroupValueBuilder::::new(DataType::Int64); - let builder_array = Arc::new(Int64Array::from(vec![ + PrimitiveGroupValueBuilder::::new(DataType::Float32); + let builder_array = Arc::new(Float32Array::from(vec![ None, None, None, - Some(1), - Some(2), - Some(3), + Some(1.0), + Some(2.0), + Some(f32::NAN), + Some(3.0), ])) as ArrayRef; - append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5]); + append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5, 6]); // Define input array - let (_nulls, values, _) = - Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)]) - .into_parts(); + let (_, values, _nulls) = Float32Array::from(vec![ + Some(1.0), + Some(2.0), + None, + Some(1.0), + None, + Some(f32::NAN), + None, + ]) + .into_parts(); // explicitly build a null buffer where one of the null values also happens to match let mut nulls = NullBufferBuilder::new(6); nulls.append_non_null(); nulls.append_null(); // this sets Some(2) to null above nulls.append_null(); - nulls.append_null(); nulls.append_non_null(); + nulls.append_null(); nulls.append_non_null(); - let input_array = Arc::new(Int64Array::new(values, nulls.finish())) as ArrayRef; + nulls.append_null(); + let input_array = Arc::new(Float32Array::new(values, nulls.finish())) as ArrayRef; // Check let mut equal_to_results = vec![true; builder.len()]; equal_to( &builder, - &[0, 1, 2, 3, 4, 5], + &[0, 1, 2, 3, 4, 5, 6], &input_array, - &[0, 1, 2, 3, 4, 5], + &[0, 1, 2, 3, 4, 5, 6], &mut equal_to_results, ); assert!(!equal_to_results[0]); assert!(equal_to_results[1]); assert!(equal_to_results[2]); - assert!(!equal_to_results[3]); + assert!(equal_to_results[3]); assert!(!equal_to_results[4]); assert!(equal_to_results[5]); + assert!(!equal_to_results[6]); } #[test] @@ -335,7 +355,7 @@ mod tests { builder_array: &ArrayRef, append_rows: &[usize]| { for &index in append_rows { - builder.append_val(builder_array, index); + builder.append_val(builder_array, index).unwrap(); } }; @@ -358,7 +378,9 @@ mod tests { let append = |builder: &mut PrimitiveGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); + builder + .vectorized_append(builder_array, append_rows) + .unwrap(); }; let equal_to = |builder: &PrimitiveGroupValueBuilder, @@ -432,7 +454,9 @@ mod tests { None, None, ])) as _; - builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_nulls_input_array.len()]; builder.vectorized_equal_to( @@ -456,7 +480,9 @@ mod tests { Some(4), Some(5), ])) as _; - builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; builder.vectorized_equal_to( diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 63751d4703135..34893fcc4ed98 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -82,7 +82,7 @@ impl GroupValuesRows { pub fn try_new(schema: SchemaRef) -> Result { // Print a debugging message, so it is clear when the (slower) fallback // GroupValuesRows is used. - debug!("Creating GroupValuesRows for schema: {}", schema); + debug!("Creating GroupValuesRows for schema: {schema}"); let row_converter = RowConverter::new( schema .fields() @@ -106,7 +106,7 @@ impl GroupValuesRows { group_values: None, hashes_buffer: Default::default(), rows_buffer, - random_state: Default::default(), + random_state: crate::aggregates::AGGREGATION_HASH_SEED, }) } } @@ -202,6 +202,7 @@ impl GroupValues for GroupValuesRows { EmitTo::All => { let output = self.row_converter.convert_rows(&group_values)?; group_values.clear(); + self.map.clear(); output } EmitTo::First(n) => { diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs new file mode 100644 index 0000000000000..44b763a91f523 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregates::group_values::GroupValues; + +use arrow::array::{ + ArrayRef, AsArray as _, BooleanArray, BooleanBufferBuilder, NullBufferBuilder, + RecordBatch, +}; +use datafusion_common::Result; +use datafusion_expr::EmitTo; +use std::{mem::size_of, sync::Arc}; + +#[derive(Debug)] +pub struct GroupValuesBoolean { + false_group: Option, + true_group: Option, + null_group: Option, +} + +impl GroupValuesBoolean { + pub fn new() -> Self { + Self { + false_group: None, + true_group: None, + null_group: None, + } + } +} + +impl GroupValues for GroupValuesBoolean { + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + let array = cols[0].as_boolean(); + groups.clear(); + + for value in array.iter() { + let index = match value { + Some(false) => { + if let Some(index) = self.false_group { + index + } else { + let index = self.len(); + self.false_group = Some(index); + index + } + } + Some(true) => { + if let Some(index) = self.true_group { + index + } else { + let index = self.len(); + self.true_group = Some(index); + index + } + } + None => { + if let Some(index) = self.null_group { + index + } else { + let index = self.len(); + self.null_group = Some(index); + index + } + } + }; + + groups.push(index); + } + + Ok(()) + } + + fn size(&self) -> usize { + size_of::() + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn len(&self) -> usize { + self.false_group.is_some() as usize + + self.true_group.is_some() as usize + + self.null_group.is_some() as usize + } + + fn emit(&mut self, emit_to: EmitTo) -> Result> { + let len = self.len(); + let mut builder = BooleanBufferBuilder::new(len); + let emit_count = match emit_to { + EmitTo::All => len, + EmitTo::First(n) => n, + }; + builder.append_n(emit_count, false); + if let Some(idx) = self.true_group.as_mut() { + if *idx < emit_count { + builder.set_bit(*idx, true); + self.true_group = None; + } else { + *idx -= emit_count; + } + } + + if let Some(idx) = self.false_group.as_mut() { + if *idx < emit_count { + // already false, no need to set + self.false_group = None; + } else { + *idx -= emit_count; + } + } + + let values = builder.finish(); + + let nulls = if let Some(idx) = self.null_group.as_mut() { + if *idx < emit_count { + let mut buffer = NullBufferBuilder::new(len); + buffer.append_n_non_nulls(*idx); + buffer.append_null(); + buffer.append_n_non_nulls(emit_count - *idx - 1); + + self.null_group = None; + Some(buffer.finish().unwrap()) + } else { + *idx -= emit_count; + None + } + } else { + None + }; + + Ok(vec![Arc::new(BooleanArray::new(values, nulls)) as _]) + } + + fn clear_shrink(&mut self, _batch: &RecordBatch) { + self.false_group = None; + self.true_group = None; + self.null_group = None; + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs index 9686b8c3521d2..b901aee313fb7 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs @@ -15,24 +15,27 @@ // specific language governing permissions and limitations // under the License. +use std::mem::size_of; + use crate::aggregates::group_values::GroupValues; + use arrow::array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch}; +use datafusion_common::Result; use datafusion_expr::EmitTo; use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; -use std::mem::size_of; /// A [`GroupValues`] storing single column of Utf8/LargeUtf8/Binary/LargeBinary values /// /// This specialization is significantly faster than using the more general /// purpose `Row`s format -pub struct GroupValuesByes { +pub struct GroupValuesBytes { /// Map string/binary values to group index map: ArrowBytesMap, /// The total number of groups so far (used to assign group_index) num_groups: usize, } -impl GroupValuesByes { +impl GroupValuesBytes { pub fn new(output_type: OutputType) -> Self { Self { map: ArrowBytesMap::new(output_type), @@ -41,12 +44,8 @@ impl GroupValuesByes { } } -impl GroupValues for GroupValuesByes { - fn intern( - &mut self, - cols: &[ArrayRef], - groups: &mut Vec, - ) -> datafusion_common::Result<()> { +impl GroupValues for GroupValuesBytes { + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { assert_eq!(cols.len(), 1); // look up / add entries in the table @@ -85,7 +84,7 @@ impl GroupValues for GroupValuesByes { self.num_groups } - fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { + fn emit(&mut self, emit_to: EmitTo) -> Result> { // Reset the map to default, and convert it into a single array let map_contents = self.map.take().into_state(); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/mod.rs index 417618ba66af4..89c6b624e8e0a 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/mod.rs @@ -17,6 +17,7 @@ //! `GroupValues` implementations for single group by cases +pub(crate) mod boolean; pub(crate) mod bytes; pub(crate) mod bytes_view; pub(crate) mod primitive; diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index d945d3ddcbf5c..8b1905e540416 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -81,11 +81,14 @@ hash_float!(f16, f32, f64); pub struct GroupValuesPrimitive { /// The data type of the output array data_type: DataType, - /// Stores the group index based on the hash of its value + /// Stores the `(group_index, hash)` based on the hash of its value /// - /// We don't store the hashes as hashing fixed width primitives - /// is fast enough for this not to benefit performance - map: HashTable, + /// We also store `hash` is for reducing cost of rehashing. Such cost + /// is obvious in high cardinality group by situation. + /// More details can see: + /// + /// + map: HashTable<(usize, u64)>, /// The group index of the null value if any null_group: Option, /// The values for each group index @@ -102,7 +105,7 @@ impl GroupValuesPrimitive { map: HashTable::with_capacity(128), values: Vec::with_capacity(128), null_group: None, - random_state: Default::default(), + random_state: crate::aggregates::AGGREGATION_HASH_SEED, } } } @@ -127,15 +130,15 @@ where let hash = key.hash(state); let insert = self.map.entry( hash, - |g| unsafe { self.values.get_unchecked(*g).is_eq(key) }, - |g| unsafe { self.values.get_unchecked(*g).hash(state) }, + |&(g, _)| unsafe { self.values.get_unchecked(g).is_eq(key) }, + |&(_, h)| h, ); match insert { - hashbrown::hash_table::Entry::Occupied(o) => *o.get(), + hashbrown::hash_table::Entry::Occupied(o) => o.get().0, hashbrown::hash_table::Entry::Vacant(v) => { let g = self.values.len(); - v.insert(g); + v.insert((g, hash)); self.values.push(key); g } @@ -148,7 +151,7 @@ where } fn size(&self) -> usize { - self.map.capacity() * size_of::() + self.values.allocated_size() + self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size() } fn is_empty(&self) -> bool { @@ -181,12 +184,13 @@ where build_primitive(std::mem::take(&mut self.values), self.null_group.take()) } EmitTo::First(n) => { - self.map.retain(|group_idx| { + self.map.retain(|entry| { // Decrement group index by n + let group_idx = entry.0; match group_idx.checked_sub(n) { // Group index was >= n, shift value down Some(sub) => { - *group_idx = sub; + entry.0 = sub; true } // Group index was < n, so remove from table diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 8906468f68db2..878bccc1d1778 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -27,7 +27,6 @@ use crate::aggregates::{ }; use crate::execution_plan::{CardinalityEffect, EmissionType}; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::projection::get_field_metadata; use crate::windows::get_ordered_partition_by_indices; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, @@ -37,27 +36,36 @@ use crate::{ use arrow::array::{ArrayRef, UInt16Array, UInt32Array, UInt64Array, UInt8Array}; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_schema::FieldRef; use datafusion_common::stats::Precision; use datafusion_common::{internal_err, not_impl_err, Constraint, Constraints, Result}; use datafusion_execution::TaskContext; use datafusion_expr::{Accumulator, Aggregate}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{ - equivalence::ProjectionMapping, expressions::Column, physical_exprs_contains, - ConstExpr, EquivalenceProperties, LexOrdering, LexRequirement, PhysicalExpr, - PhysicalSortRequirement, + physical_exprs_contains, ConstExpr, EquivalenceProperties, +}; +use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExpr}; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement, }; -use datafusion_physical_expr_common::physical_expr::fmt_sql; +use datafusion_expr::utils::AggregateOrderSensitivity; use itertools::Itertools; -pub(crate) mod group_values; +pub mod group_values; mod no_grouping; pub mod order; mod row_hash; mod topk; mod topk_stream; +/// Hard-coded seed for aggregations to ensure hash values differ from `RepartitionExec`, avoiding collisions. +const AGGREGATION_HASH_SEED: ahash::RandomState = + ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64); + /// Aggregation modes /// /// See [`Accumulator::state`] for background information on multi-phase @@ -261,7 +269,7 @@ impl PhysicalGroupBy { } /// Returns the number expression as grouping keys. - fn num_group_exprs(&self) -> usize { + pub fn num_group_exprs(&self) -> usize { if self.is_single() { self.expr.len() } else { @@ -274,7 +282,7 @@ impl PhysicalGroupBy { } /// Returns the fields that are used as the grouping keys. - fn group_fields(&self, input_schema: &Schema) -> Result> { + fn group_fields(&self, input_schema: &Schema) -> Result> { let mut fields = Vec::with_capacity(self.num_group_exprs()); for ((expr, name), group_expr_nullable) in self.expr.iter().zip(self.exprs_nullable().into_iter()) @@ -285,17 +293,19 @@ impl PhysicalGroupBy { expr.data_type(input_schema)?, group_expr_nullable || expr.nullable(input_schema)?, ) - .with_metadata( - get_field_metadata(expr, input_schema).unwrap_or_default(), - ), + .with_metadata(expr.return_field(input_schema)?.metadata().clone()) + .into(), ); } if !self.is_single() { - fields.push(Field::new( - Aggregate::INTERNAL_GROUPING_ID, - Aggregate::grouping_id_type(self.expr.len()), - false, - )); + fields.push( + Field::new( + Aggregate::INTERNAL_GROUPING_ID, + Aggregate::grouping_id_type(self.expr.len()), + false, + ) + .into(), + ); } Ok(fields) } @@ -304,7 +314,7 @@ impl PhysicalGroupBy { /// /// This might be different from the `group_fields` that might contain internal expressions that /// should not be part of the output schema. - fn output_fields(&self, input_schema: &Schema) -> Result> { + fn output_fields(&self, input_schema: &Schema) -> Result> { let mut fields = self.group_fields(input_schema)?; fields.truncate(self.num_output_exprs()); Ok(fields) @@ -323,10 +333,17 @@ impl PhysicalGroupBy { ) .collect(); let num_exprs = expr.len(); + let groups = if self.expr.is_empty() { + // No GROUP BY expressions - should have no groups + vec![] + } else { + // Has GROUP BY expressions - create a single group + vec![vec![false; num_exprs]] + }; Self { expr, null_expr: vec![], - groups: vec![vec![false; num_exprs]], + groups, } } } @@ -349,6 +366,7 @@ impl PartialEq for PhysicalGroupBy { } } +#[allow(clippy::large_enum_variant)] enum StreamType { AggregateStream(AggregateStream), GroupedHash(GroupedHashAggregateStream), @@ -390,7 +408,7 @@ pub struct AggregateExec { pub input_schema: SchemaRef, /// Execution metrics metrics: ExecutionPlanMetricsSet, - required_input_ordering: Option, + required_input_ordering: Option, /// Describes how the input is ordered relative to the group by columns input_order_mode: InputOrderMode, cache: PlanProperties, @@ -477,16 +495,13 @@ impl AggregateExec { // If existing ordering satisfies a prefix of the GROUP BY expressions, // prefix requirements with this section. In this case, aggregation will // work more efficiently. - let indices = get_ordered_partition_by_indices(&groupby_exprs, &input); - let mut new_requirement = LexRequirement::new( - indices - .iter() - .map(|&idx| PhysicalSortRequirement { - expr: Arc::clone(&groupby_exprs[idx]), - options: None, - }) - .collect::>(), - ); + let indices = get_ordered_partition_by_indices(&groupby_exprs, &input)?; + let mut new_requirements = indices + .iter() + .map(|&idx| { + PhysicalSortRequirement::new(Arc::clone(&groupby_exprs[idx]), None) + }) + .collect::>(); let req = get_finer_aggregate_exprs_requirement( &mut aggr_expr, @@ -494,8 +509,10 @@ impl AggregateExec { input_eq_properties, &mode, )?; - new_requirement.inner.extend(req); - new_requirement = new_requirement.collapse(); + new_requirements.extend(req); + + let required_input_ordering = + LexRequirement::new(new_requirements).map(OrderingRequirements::new_soft); // If our aggregation has grouping sets then our base grouping exprs will // be expanded based on the flags in `group_by.groups` where for each @@ -520,10 +537,7 @@ impl AggregateExec { // construct a map from the input expression to the output expression of the Aggregation group by let group_expr_mapping = - ProjectionMapping::try_new(&group_by.expr, &input.schema())?; - - let required_input_ordering = - (!new_requirement.is_empty()).then_some(new_requirement); + ProjectionMapping::try_new(group_by.expr.clone(), &input.schema())?; let cache = Self::compute_properties( &input, @@ -532,7 +546,7 @@ impl AggregateExec { &mode, &input_order_mode, aggr_expr.as_slice(), - ); + )?; Ok(AggregateExec { mode, @@ -623,7 +637,7 @@ impl AggregateExec { } /// Finds the DataType and SortDirection for this Aggregate, if there is one - pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { + pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> { let agg_expr = self.aggr_expr.iter().exactly_one().ok()?; agg_expr.get_minmax_desc() } @@ -647,7 +661,7 @@ impl AggregateExec { return false; } // ensure there are no order by expressions - if self.aggr_expr().iter().any(|e| e.order_bys().is_some()) { + if !self.aggr_expr().iter().all(|e| e.order_bys().is_empty()) { return false; } // ensure there is no output ordering; can this rule be relaxed? @@ -655,8 +669,8 @@ impl AggregateExec { return false; } // ensure no ordering is required on the input - if self.required_input_ordering()[0].is_some() { - return false; + if let Some(requirement) = self.required_input_ordering().swap_remove(0) { + return matches!(requirement, OrderingRequirements::Hard(_)); } true } @@ -669,7 +683,7 @@ impl AggregateExec { mode: &AggregateMode, input_order_mode: &InputOrderMode, aggr_exprs: &[Arc], - ) -> PlanProperties { + ) -> Result { // Construct equivalence properties: let mut eq_properties = input .equivalence_properties() @@ -677,13 +691,12 @@ impl AggregateExec { // If the group by is empty, then we ensure that the operator will produce // only one row, and mark the generated result as a constant value. - if group_expr_mapping.map.is_empty() { - let mut constants = eq_properties.constants().to_vec(); + if group_expr_mapping.is_empty() { let new_constants = aggr_exprs.iter().enumerate().map(|(idx, func)| { - ConstExpr::new(Arc::new(Column::new(func.name(), idx))) + let column = Arc::new(Column::new(func.name(), idx)); + ConstExpr::from(column as Arc) }); - constants.extend(new_constants); - eq_properties = eq_properties.with_constants(constants); + eq_properties.add_constants(new_constants)?; } // Group by expression will be a distinct value after the aggregation. @@ -691,13 +704,11 @@ impl AggregateExec { let mut constraints = eq_properties.constraints().to_vec(); let new_constraint = Constraint::Unique( group_expr_mapping - .map .iter() - .filter_map(|(_, target_col)| { - target_col - .as_any() - .downcast_ref::() - .map(|c| c.index()) + .flat_map(|(_, target_cols)| { + target_cols.iter().flat_map(|(expr, _)| { + expr.as_any().downcast_ref::().map(|c| c.index()) + }) }) .collect(), ); @@ -724,17 +735,80 @@ impl AggregateExec { input.pipeline_behavior() }; - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, emission_type, input.boundedness(), - ) + )) } pub fn input_order_mode(&self) -> &InputOrderMode { &self.input_order_mode } + + fn statistics_inner(&self, child_statistics: Statistics) -> Result { + // TODO stats: group expressions: + // - once expressions will be able to compute their own stats, use it here + // - case where we group by on a column for which with have the `distinct` stat + // TODO stats: aggr expression: + // - aggregations sometimes also preserve invariants such as min, max... + + let column_statistics = { + // self.schema: [, ] + let mut column_statistics = Statistics::unknown_column(&self.schema()); + + for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() { + if let Some(col) = expr.as_any().downcast_ref::() { + column_statistics[idx].max_value = child_statistics.column_statistics + [col.index()] + .max_value + .clone(); + + column_statistics[idx].min_value = child_statistics.column_statistics + [col.index()] + .min_value + .clone(); + } + } + + column_statistics + }; + match self.mode { + AggregateMode::Final | AggregateMode::FinalPartitioned + if self.group_by.expr.is_empty() => + { + Ok(Statistics { + num_rows: Precision::Exact(1), + column_statistics, + total_byte_size: Precision::Absent, + }) + } + _ => { + // When the input row count is 1, we can adopt that statistic keeping its reliability. + // When it is larger than 1, we degrade the precision since it may decrease after aggregation. + let num_rows = if let Some(value) = child_statistics.num_rows.get_value() + { + if *value > 1 { + child_statistics.num_rows.to_inexact() + } else if *value == 0 { + child_statistics.num_rows + } else { + // num_rows = 1 case + let grouping_set_num = self.group_by.groups.len(); + child_statistics.num_rows.map(|x| x * grouping_set_num) + } + } else { + Precision::Absent + }; + Ok(Statistics { + num_rows, + column_statistics, + total_byte_size: Precision::Absent, + }) + } + } + } } impl DisplayAs for AggregateExec { @@ -888,7 +962,7 @@ impl ExecutionPlan for AggregateExec { } } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { vec![self.required_input_ordering.clone()] } @@ -941,50 +1015,11 @@ impl ExecutionPlan for AggregateExec { } fn statistics(&self) -> Result { - // TODO stats: group expressions: - // - once expressions will be able to compute their own stats, use it here - // - case where we group by on a column for which with have the `distinct` stat - // TODO stats: aggr expression: - // - aggregations sometimes also preserve invariants such as min, max... - let column_statistics = Statistics::unknown_column(&self.schema()); - match self.mode { - AggregateMode::Final | AggregateMode::FinalPartitioned - if self.group_by.expr.is_empty() => - { - Ok(Statistics { - num_rows: Precision::Exact(1), - column_statistics, - total_byte_size: Precision::Absent, - }) - } - _ => { - // When the input row count is 0 or 1, we can adopt that statistic keeping its reliability. - // When it is larger than 1, we degrade the precision since it may decrease after aggregation. - let num_rows = if let Some(value) = - self.input().statistics()?.num_rows.get_value() - { - if *value > 1 { - self.input().statistics()?.num_rows.to_inexact() - } else if *value == 0 { - // Aggregation on an empty table creates a null row. - self.input() - .statistics()? - .num_rows - .add(&Precision::Exact(1)) - } else { - // num_rows = 1 case - self.input().statistics()?.num_rows - } - } else { - Precision::Absent - }; - Ok(Statistics { - num_rows, - column_statistics, - total_byte_size: Precision::Absent, - }) - } - } + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.statistics_inner(self.input().partition_statistics(partition)?) } fn cardinality_effect(&self) -> CardinalityEffect { @@ -1035,6 +1070,11 @@ fn create_schema( /// physical GROUP BY expression. /// - `agg_mode`: A reference to an `AggregateMode` instance representing the /// mode of aggregation. +/// - `include_soft_requirement`: When `false`, only hard requirements are +/// considered, as indicated by [`AggregateFunctionExpr::order_sensitivity`] +/// returning [`AggregateOrderSensitivity::HardRequirement`]. +/// Otherwise, also soft requirements ([`AggregateOrderSensitivity::SoftRequirement`]) +/// are considered. /// /// # Returns /// @@ -1044,16 +1084,27 @@ fn get_aggregate_expr_req( aggr_expr: &AggregateFunctionExpr, group_by: &PhysicalGroupBy, agg_mode: &AggregateMode, -) -> LexOrdering { - // If the aggregation function is ordering requirement is not absolutely - // necessary, or the aggregation is performing a "second stage" calculation, - // then ignore the ordering requirement. - if !aggr_expr.order_sensitivity().hard_requires() || !agg_mode.is_first_stage() { - return LexOrdering::default(); + include_soft_requirement: bool, +) -> Option { + // If the aggregation is performing a "second stage" calculation, + // then ignore the ordering requirement. Ordering requirement applies + // only to the aggregation input data. + if !agg_mode.is_first_stage() { + return None; + } + + match aggr_expr.order_sensitivity() { + AggregateOrderSensitivity::Insensitive => return None, + AggregateOrderSensitivity::HardRequirement => {} + AggregateOrderSensitivity::SoftRequirement => { + if !include_soft_requirement { + return None; + } + } + AggregateOrderSensitivity::Beneficial => return None, } - let mut req = aggr_expr.order_bys().cloned().unwrap_or_default(); - + let mut sort_exprs = aggr_expr.order_bys().to_vec(); // In non-first stage modes, we accumulate data (using `merge_batch`) from // different partitions (i.e. merge partial results). During this merge, we // consider the ordering of each partial result. Hence, we do not need to @@ -1064,38 +1115,11 @@ fn get_aggregate_expr_req( // will definitely be satisfied -- Each group by expression will have // distinct values per group, hence all requirements are satisfied. let physical_exprs = group_by.input_exprs(); - req.retain(|sort_expr| { + sort_exprs.retain(|sort_expr| { !physical_exprs_contains(&physical_exprs, &sort_expr.expr) }); } - req -} - -/// Computes the finer ordering for between given existing ordering requirement -/// of aggregate expression. -/// -/// # Parameters -/// -/// * `existing_req` - The existing lexical ordering that needs refinement. -/// * `aggr_expr` - A reference to an aggregate expression trait object. -/// * `group_by` - Information about the physical grouping (e.g group by expression). -/// * `eq_properties` - Equivalence properties relevant to the computation. -/// * `agg_mode` - The mode of aggregation (e.g., Partial, Final, etc.). -/// -/// # Returns -/// -/// An `Option` representing the computed finer lexical ordering, -/// or `None` if there is no finer ordering; e.g. the existing requirement and -/// the aggregator requirement is incompatible. -fn finer_ordering( - existing_req: &LexOrdering, - aggr_expr: &AggregateFunctionExpr, - group_by: &PhysicalGroupBy, - eq_properties: &EquivalenceProperties, - agg_mode: &AggregateMode, -) -> Option { - let aggr_req = get_aggregate_expr_req(aggr_expr, group_by, agg_mode); - eq_properties.get_finer_ordering(existing_req, aggr_req.as_ref()) + LexOrdering::new(sort_exprs) } /// Concatenates the given slices. @@ -1103,7 +1127,23 @@ pub fn concat_slices(lhs: &[T], rhs: &[T]) -> Vec { [lhs, rhs].concat() } -/// Get the common requirement that satisfies all the aggregate expressions. +// Determines if the candidate ordering is finer than the current ordering. +// Returns `None` if they are incomparable, `Some(true)` if there is no current +// ordering or candidate ordering is finer, and `Some(false)` otherwise. +fn determine_finer( + current: &Option, + candidate: &LexOrdering, +) -> Option { + if let Some(ordering) = current { + candidate.partial_cmp(ordering).map(|cmp| cmp.is_gt()) + } else { + Some(true) + } +} + +/// Gets the common requirement that satisfies all the aggregate expressions. +/// When possible, chooses the requirement that is already satisfied by the +/// equivalence properties. /// /// # Parameters /// @@ -1118,75 +1158,91 @@ pub fn concat_slices(lhs: &[T], rhs: &[T]) -> Vec { /// /// # Returns /// -/// A `LexRequirement` instance, which is the requirement that satisfies all the -/// aggregate requirements. Returns an error in case of conflicting requirements. +/// A `Result>` instance, which is the requirement +/// that satisfies all the aggregate requirements. Returns an error in case of +/// conflicting requirements. pub fn get_finer_aggregate_exprs_requirement( aggr_exprs: &mut [Arc], group_by: &PhysicalGroupBy, eq_properties: &EquivalenceProperties, agg_mode: &AggregateMode, -) -> Result { - let mut requirement = LexOrdering::default(); - for aggr_expr in aggr_exprs.iter_mut() { - if let Some(finer_ordering) = - finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) - { - if eq_properties.ordering_satisfy(finer_ordering.as_ref()) { - // Requirement is satisfied by existing ordering - requirement = finer_ordering; - continue; - } - } - if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() { - if let Some(finer_ordering) = finer_ordering( - &requirement, - &reverse_aggr_expr, +) -> Result> { + let mut requirement = None; + + // First try and find a match for all hard and soft requirements. + // If a match can't be found, try a second time just matching hard + // requirements. + for include_soft_requirement in [false, true] { + for aggr_expr in aggr_exprs.iter_mut() { + let Some(aggr_req) = get_aggregate_expr_req( + aggr_expr, group_by, - eq_properties, agg_mode, - ) { - if eq_properties.ordering_satisfy(finer_ordering.as_ref()) { - // Reverse requirement is satisfied by exiting ordering. - // Hence reverse the aggregator - requirement = finer_ordering; - *aggr_expr = Arc::new(reverse_aggr_expr); + include_soft_requirement, + ) + .and_then(|o| eq_properties.normalize_sort_exprs(o)) else { + // There is no aggregate ordering requirement, or it is trivially + // satisfied -- we can skip this expression. + continue; + }; + // If the common requirement is finer than the current expression's, + // we can skip this expression. If the latter is finer than the former, + // adopt it if it is satisfied by the equivalence properties. Otherwise, + // defer the analysis to the reverse expression. + let forward_finer = determine_finer(&requirement, &aggr_req); + if let Some(finer) = forward_finer { + if !finer { + continue; + } else if eq_properties.ordering_satisfy(aggr_req.clone())? { + requirement = Some(aggr_req); continue; } } - } - if let Some(finer_ordering) = - finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) - { - // There is a requirement that both satisfies existing requirement and current - // aggregate requirement. Use updated requirement - requirement = finer_ordering; - continue; - } - if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() { - if let Some(finer_ordering) = finer_ordering( - &requirement, - &reverse_aggr_expr, - group_by, - eq_properties, - agg_mode, - ) { - // There is a requirement that both satisfies existing requirement and reverse - // aggregate requirement. Use updated requirement - requirement = finer_ordering; - *aggr_expr = Arc::new(reverse_aggr_expr); - continue; + if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() { + let Some(rev_aggr_req) = get_aggregate_expr_req( + &reverse_aggr_expr, + group_by, + agg_mode, + include_soft_requirement, + ) + .and_then(|o| eq_properties.normalize_sort_exprs(o)) else { + // The reverse requirement is trivially satisfied -- just reverse + // the expression and continue with the next one: + *aggr_expr = Arc::new(reverse_aggr_expr); + continue; + }; + // If the common requirement is finer than the reverse expression's, + // just reverse it and continue the loop with the next aggregate + // expression. If the latter is finer than the former, adopt it if + // it is satisfied by the equivalence properties. Otherwise, adopt + // the forward expression. + if let Some(finer) = determine_finer(&requirement, &rev_aggr_req) { + if !finer { + *aggr_expr = Arc::new(reverse_aggr_expr); + } else if eq_properties.ordering_satisfy(rev_aggr_req.clone())? { + *aggr_expr = Arc::new(reverse_aggr_expr); + requirement = Some(rev_aggr_req); + } else { + requirement = Some(aggr_req); + } + } else if forward_finer.is_some() { + requirement = Some(aggr_req); + } else { + // Neither the existing requirement nor the current aggregate + // requirement satisfy the other (forward or reverse), this + // means they are conflicting. This is a problem only for hard + // requirements. Unsatisfied soft requirements can be ignored. + if !include_soft_requirement { + return not_impl_err!( + "Conflicting ordering requirements in aggregate functions is not supported" + ); + } + } } } - - // Neither the existing requirement and current aggregate requirement satisfy the other, this means - // requirements are conflicting. Currently, we do not support - // conflicting requirements. - return not_impl_err!( - "Conflicting ordering requirements in aggregate functions is not supported" - ); } - Ok(LexRequirement::from(requirement)) + Ok(requirement.map_or_else(Vec::new, |o| o.into_iter().map(Into::into).collect())) } /// Returns physical expressions for arguments to evaluate against a batch. @@ -1209,9 +1265,7 @@ pub fn aggregate_expressions( // Append ordering requirements to expressions' results. This // way order sensitive aggregators can satisfy requirement // themselves. - if let Some(ordering_req) = agg.order_bys() { - result.extend(ordering_req.iter().map(|item| Arc::clone(&item.expr))); - } + result.extend(agg.order_bys().iter().map(|item| Arc::clone(&item.expr))); result }) .collect()), @@ -1306,7 +1360,7 @@ fn evaluate( } /// Evaluates expressions against a record batch. -pub(crate) fn evaluate_many( +pub fn evaluate_many( expr: &[Vec>], batch: &RecordBatch, ) -> Result>> { @@ -1360,7 +1414,7 @@ fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result { /// The outer Vec appears to be for grouping sets /// The inner Vec contains the results per expression /// The inner-inner Array contains the results per row -pub(crate) fn evaluate_group_by( +pub fn evaluate_group_by( group_by: &PhysicalGroupBy, batch: &RecordBatch, ) -> Result>> { @@ -1924,6 +1978,13 @@ mod tests { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(self.schema().as_ref())); + } let (_, batches) = some_data(); Ok(common::compute_record_batch_statistics( &[batches], @@ -2211,14 +2272,14 @@ mod tests { schema: &Schema, sort_options: SortOptions, ) -> Result> { - let ordering_req = [PhysicalSortExpr { + let order_bys = vec![PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, }]; let args = [col("b", schema)?]; AggregateExprBuilder::new(first_value_udaf(), args.to_vec()) - .order_by(LexOrdering::new(ordering_req.to_vec())) + .order_by(order_bys) .schema(Arc::new(schema.clone())) .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]")) .build() @@ -2230,13 +2291,13 @@ mod tests { schema: &Schema, sort_options: SortOptions, ) -> Result> { - let ordering_req = [PhysicalSortExpr { + let order_bys = vec![PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, }]; let args = [col("b", schema)?]; AggregateExprBuilder::new(last_value_udaf(), args.to_vec()) - .order_by(LexOrdering::new(ordering_req.to_vec())) + .order_by(order_bys) .schema(Arc::new(schema.clone())) .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]")) .build() @@ -2358,9 +2419,7 @@ mod tests { async fn test_get_finest_requirements() -> Result<()> { let test_schema = create_test_schema()?; - // Assume column a and b are aliases - // Assume also that a ASC and c DESC describe the same global ordering for the table. (Since they are ordering equivalent). - let options1 = SortOptions { + let options = SortOptions { descending: false, nulls_first: false, }; @@ -2369,58 +2428,51 @@ mod tests { let col_c = &col("c", &test_schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); // Columns a and b are equal. - eq_properties.add_equal_conditions(col_a, col_b)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_b))?; // Aggregate requirements are // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively let order_by_exprs = vec![ - None, - Some(vec![PhysicalSortExpr { + vec![], + vec![PhysicalSortExpr { expr: Arc::clone(col_a), - options: options1, - }]), - Some(vec![ + options, + }], + vec![ PhysicalSortExpr { expr: Arc::clone(col_a), - options: options1, + options, }, PhysicalSortExpr { expr: Arc::clone(col_b), - options: options1, + options, }, PhysicalSortExpr { expr: Arc::clone(col_c), - options: options1, + options, }, - ]), - Some(vec![ + ], + vec![ PhysicalSortExpr { expr: Arc::clone(col_a), - options: options1, + options, }, PhysicalSortExpr { expr: Arc::clone(col_b), - options: options1, + options, }, - ]), + ], ]; - let common_requirement = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(col_a), - options: options1, - }, - PhysicalSortExpr { - expr: Arc::clone(col_c), - options: options1, - }, - ]); + let common_requirement = vec![ + PhysicalSortRequirement::new(Arc::clone(col_a), Some(options)), + PhysicalSortRequirement::new(Arc::clone(col_c), Some(options)), + ]; let mut aggr_exprs = order_by_exprs .into_iter() .map(|order_by_expr| { - let ordering_req = order_by_expr.unwrap_or_default(); AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)]) .alias("a") - .order_by(LexOrdering::new(ordering_req.to_vec())) + .order_by(order_by_expr) .schema(Arc::clone(&test_schema)) .build() .map(Arc::new) @@ -2428,14 +2480,13 @@ mod tests { }) .collect::>(); let group_by = PhysicalGroupBy::new_single(vec![]); - let res = get_finer_aggregate_exprs_requirement( + let result = get_finer_aggregate_exprs_requirement( &mut aggr_exprs, &group_by, &eq_properties, &AggregateMode::Partial, )?; - let res = LexOrdering::from(res); - assert_eq!(res, common_requirement); + assert_eq!(result, common_requirement); Ok(()) } diff --git a/datafusion/physical-plan/src/aggregates/order/full.rs b/datafusion/physical-plan/src/aggregates/order/full.rs index 218855459b1e2..eb98611f79dfb 100644 --- a/datafusion/physical-plan/src/aggregates/order/full.rs +++ b/datafusion/physical-plan/src/aggregates/order/full.rs @@ -92,7 +92,7 @@ impl GroupOrderingFull { Some(EmitTo::First(*current)) } } - State::Complete { .. } => Some(EmitTo::All), + State::Complete => Some(EmitTo::All), } } @@ -106,7 +106,7 @@ impl GroupOrderingFull { assert!(*current >= n); *current -= n; } - State::Complete { .. } => panic!("invalid state: complete"), + State::Complete => panic!("invalid state: complete"), } } @@ -133,7 +133,7 @@ impl GroupOrderingFull { current: max_group_index, } } - State::Complete { .. } => { + State::Complete => { panic!("Saw new group after input was complete"); } }; diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index 0b742b3d20fdc..bbcb30d877cf0 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::mem::size_of; + use arrow::array::ArrayRef; -use arrow::datatypes::Schema; use datafusion_common::Result; use datafusion_expr::EmitTo; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use std::mem::size_of; mod full; mod partial; @@ -42,15 +41,11 @@ pub enum GroupOrdering { impl GroupOrdering { /// Create a `GroupOrdering` for the specified ordering - pub fn try_new( - input_schema: &Schema, - mode: &InputOrderMode, - ordering: &LexOrdering, - ) -> Result { + pub fn try_new(mode: &InputOrderMode) -> Result { match mode { InputOrderMode::Linear => Ok(GroupOrdering::None), InputOrderMode::PartiallySorted(order_indices) => { - GroupOrderingPartial::try_new(input_schema, order_indices, ordering) + GroupOrderingPartial::try_new(order_indices.clone()) .map(GroupOrdering::Partial) } InputOrderMode::Sorted => Ok(GroupOrdering::Full(GroupOrderingFull::new())), diff --git a/datafusion/physical-plan/src/aggregates/order/partial.rs b/datafusion/physical-plan/src/aggregates/order/partial.rs index aff69277a4cef..3e495900f77a1 100644 --- a/datafusion/physical-plan/src/aggregates/order/partial.rs +++ b/datafusion/physical-plan/src/aggregates/order/partial.rs @@ -15,18 +15,17 @@ // specific language governing permissions and limitations // under the License. +use std::cmp::Ordering; +use std::mem::size_of; +use std::sync::Arc; + use arrow::array::ArrayRef; use arrow::compute::SortOptions; -use arrow::datatypes::Schema; use arrow_ord::partition::partition; use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{Result, ScalarValue}; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use std::cmp::Ordering; -use std::mem::size_of; -use std::sync::Arc; /// Tracks grouping state when the data is ordered by some subset of /// the group keys. @@ -118,17 +117,11 @@ impl State { impl GroupOrderingPartial { /// TODO: Remove unnecessary `input_schema` parameter. - pub fn try_new( - _input_schema: &Schema, - order_indices: &[usize], - ordering: &LexOrdering, - ) -> Result { - assert!(!order_indices.is_empty()); - assert!(order_indices.len() <= ordering.len()); - + pub fn try_new(order_indices: Vec) -> Result { + debug_assert!(!order_indices.is_empty()); Ok(Self { state: State::Start, - order_indices: order_indices.to_vec(), + order_indices, }) } @@ -181,7 +174,7 @@ impl GroupOrderingPartial { assert!(*current_sort >= n); *current_sort -= n; } - State::Complete { .. } => panic!("invalid state: complete"), + State::Complete => panic!("invalid state: complete"), } } @@ -276,29 +269,15 @@ impl GroupOrderingPartial { #[cfg(test)] mod tests { - use arrow::array::Int32Array; - use arrow_schema::{DataType, Field}; - use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; - use super::*; + use arrow::array::Int32Array; + #[test] fn test_group_ordering_partial() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ]); - // Ordered on column a let order_indices = vec![0]; - - let ordering = LexOrdering::new(vec![PhysicalSortExpr::new( - col("a", &schema)?, - SortOptions::default(), - )]); - - let mut group_ordering = - GroupOrderingPartial::try_new(&schema, &order_indices, &ordering)?; + let mut group_ordering = GroupOrderingPartial::try_new(order_indices)?; let batch_group_values: Vec = vec![ Arc::new(Int32Array::from(vec![1, 2, 3])), diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 077f18d510339..6132a8b0add52 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -21,6 +21,8 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::vec; +use super::order::GroupOrdering; +use super::AggregateExec; use crate::aggregates::group_values::{new_group_values, GroupValues}; use crate::aggregates::order::GroupOrderingFull; use crate::aggregates::{ @@ -29,28 +31,24 @@ use crate::aggregates::{ }; use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput}; use crate::sorts::sort::sort_batch; -use crate::sorts::streaming_merge::StreamingMergeBuilder; +use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::spill::spill_manager::SpillManager; use crate::stream::RecordBatchStreamAdapter; -use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr}; +use crate::{aggregates, metrics, PhysicalExpr}; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; -use arrow::compute::SortOptions; use arrow::datatypes::SchemaRef; use datafusion_common::{internal_err, DataFusionError, Result}; -use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr}; - -use super::order::GroupOrdering; -use super::AggregateExec; -use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr_common::sort_expr::LexOrdering; + use futures::ready; use futures::stream::{Stream, StreamExt}; use log::debug; @@ -100,7 +98,7 @@ struct SpillState { // ======================================================================== /// If data has previously been spilled, the locations of the /// spill files (in Arrow IPC format) - spills: Vec, + spills: Vec, /// true when streaming merge is in progress is_stream_merging: bool, @@ -507,32 +505,43 @@ impl GroupedHashAggregateStream { AggregateMode::Partial, )?; - let partial_agg_schema = Arc::new(partial_agg_schema); - - let spill_expr = group_schema - .fields - .into_iter() + // Need to update the GROUP BY expressions to point to the correct column after schema change + let merging_group_by_expr = agg_group_by + .expr + .iter() .enumerate() - .map(|(idx, field)| PhysicalSortExpr { - expr: Arc::new(Column::new(field.name().as_str(), idx)) as _, - options: SortOptions::default(), + .map(|(idx, (_, name))| { + (Arc::new(Column::new(name.as_str(), idx)) as _, name.clone()) }) .collect(); - let name = format!("GroupedHashAggregateStream[{partition}]"); + let partial_agg_schema = Arc::new(partial_agg_schema); + + let spill_expr = + group_schema + .fields + .into_iter() + .enumerate() + .map(|(idx, field)| { + PhysicalSortExpr::new_default(Arc::new(Column::new( + field.name().as_str(), + idx, + )) as _) + }); + let Some(spill_expr) = LexOrdering::new(spill_expr) else { + return internal_err!("Spill expression is empty"); + }; + + let agg_fn_names = aggregate_exprs + .iter() + .map(|expr| expr.human_display()) + .collect::>() + .join(", "); + let name = format!("GroupedHashAggregateStream[{partition}] ({agg_fn_names})"); let reservation = MemoryConsumer::new(name) .with_can_spill(true) .register(context.memory_pool()); - let (ordering, _) = agg - .properties() - .equivalence_properties() - .find_longest_permutation(&agg_group_by.output_exprs()); - let group_ordering = GroupOrdering::try_new( - &group_schema, - &agg.input_order_mode, - ordering.as_ref(), - )?; - + let group_ordering = GroupOrdering::try_new(&agg.input_order_mode)?; let group_values = new_group_values(group_schema, &group_ordering)?; timer.done(); @@ -542,7 +551,8 @@ impl GroupedHashAggregateStream { context.runtime_env(), metrics::SpillMetrics::new(&agg.metrics, partition), Arc::clone(&partial_agg_schema), - ); + ) + .with_compression_type(context.session_config().spill_compression()); let spill_state = SpillState { spills: vec![], @@ -550,7 +560,7 @@ impl GroupedHashAggregateStream { spill_schema: partial_agg_schema, is_stream_merging: false, merging_aggregate_arguments, - merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()), + merging_group_by: PhysicalGroupBy::new_single(merging_group_by_expr), peak_mem_used: MetricBuilder::new(&agg.metrics) .gauge("peak_mem_used", partition), spill_manager, @@ -965,7 +975,7 @@ impl GroupedHashAggregateStream { /// memory. Currently only [`GroupOrdering::None`] is supported for spilling. fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) -> Result<()> { // TODO: support group_ordering for spilling - if self.group_values.len() > 0 + if !self.group_values.is_empty() && batch.num_rows() > 0 && matches!(self.group_ordering, GroupOrdering::None) && !self.spill_state.is_stream_merging @@ -986,16 +996,24 @@ impl GroupedHashAggregateStream { let Some(emit) = self.emit(EmitTo::All, true)? else { return Ok(()); }; - let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), None)?; + let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?; // Spill sorted state to disk - let spillfile = self.spill_state.spill_manager.spill_record_batch_by_size( - &sorted, - "HashAggSpill", - self.batch_size, - )?; + let spillfile = self + .spill_state + .spill_manager + .spill_record_batch_by_size_and_return_max_batch_memory( + &sorted, + "HashAggSpill", + self.batch_size, + )?; match spillfile { - Some(spillfile) => self.spill_state.spills.push(spillfile), + Some((spillfile, max_record_batch_memory)) => { + self.spill_state.spills.push(SortedSpillFile { + file: spillfile, + max_record_batch_memory, + }) + } None => { return internal_err!( "Calling spill with no intermediate batch to spill" @@ -1053,18 +1071,17 @@ impl GroupedHashAggregateStream { streams.push(Box::pin(RecordBatchStreamAdapter::new( Arc::clone(&schema), futures::stream::once(futures::future::lazy(move |_| { - sort_batch(&batch, expr.as_ref(), None) + sort_batch(&batch, &expr, None) })), ))); - for spill in self.spill_state.spills.drain(..) { - let stream = self.spill_state.spill_manager.read_spill_as_stream(spill)?; - streams.push(stream); - } + self.spill_state.is_stream_merging = true; self.input = StreamingMergeBuilder::new() .with_streams(streams) .with_schema(schema) - .with_expressions(self.spill_state.spill_expr.as_ref()) + .with_spill_manager(self.spill_state.spill_manager.clone()) + .with_sorted_spill_files(std::mem::take(&mut self.spill_state.spills)) + .with_expressions(&self.spill_state.spill_expr) .with_metrics(self.baseline_metrics.clone()) .with_batch_size(self.batch_size) .with_reservation(self.reservation.new_empty()) diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs index ae44eb35e6d04..974aea3b6292c 100644 --- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -26,7 +26,7 @@ use arrow::array::{ ArrowPrimitiveType, LargeStringArray, PrimitiveArray, StringArray, StringViewArray, }; use arrow::datatypes::{i256, DataType}; -use datafusion_common::DataFusionError; +use datafusion_common::exec_datafusion_err; use datafusion_common::Result; use half::f16; use hashbrown::raw::RawTable; @@ -99,6 +99,7 @@ where owned: ArrayRef, map: TopKHashTable>, rnd: RandomState, + kt: DataType, } impl StringHashTable { @@ -216,12 +217,17 @@ where Option<::Native>: Comparable, Option<::Native>: HashValue, { - pub fn new(limit: usize) -> Self { - let owned = Arc::new(PrimitiveArray::::builder(0).finish()); + pub fn new(limit: usize, kt: DataType) -> Self { + let owned = Arc::new( + PrimitiveArray::::builder(0) + .with_data_type(kt.clone()) + .finish(), + ); Self { owned, map: TopKHashTable::new(limit, limit * 10), rnd: RandomState::default(), + kt, } } } @@ -249,7 +255,8 @@ where unsafe fn take_all(&mut self, indexes: Vec) -> ArrayRef { let ids = self.map.take_all(indexes); - let mut builder: PrimitiveBuilder = PrimitiveArray::builder(ids.len()); + let mut builder: PrimitiveBuilder = + PrimitiveArray::builder(ids.len()).with_data_type(self.kt.clone()); for id in ids.into_iter() { match id { None => builder.append_null(), @@ -413,7 +420,7 @@ pub fn new_hash_table( ) -> Result> { macro_rules! downcast_helper { ($kt:ty, $d:ident) => { - return Ok(Box::new(PrimitiveHashTable::<$kt>::new(limit))) + return Ok(Box::new(PrimitiveHashTable::<$kt>::new(limit, kt))) }; } @@ -425,16 +432,35 @@ pub fn new_hash_table( _ => {} } - Err(DataFusionError::Execution(format!( + Err(exec_datafusion_err!( "Can't create HashTable for type: {kt:?}" - ))) + )) } #[cfg(test)] mod tests { use super::*; + use arrow::array::TimestampMillisecondArray; + use arrow_schema::TimeUnit; use std::collections::BTreeMap; + #[test] + fn should_emit_correct_type() -> Result<()> { + let ids = + TimestampMillisecondArray::from(vec![1000]).with_timezone("UTC".to_string()); + let dt = DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())); + let mut ht = new_hash_table(1, dt.clone())?; + ht.set_batch(Arc::new(ids)); + let mut mapper = vec![]; + let ids = unsafe { + ht.find_or_insert(0, 0, &mut mapper); + ht.take_all(vec![0]) + }; + assert_eq!(ids.data_type(), &dt); + + Ok(()) + } + #[test] fn should_resize_properly() -> Result<()> { let mut heap_to_map = BTreeMap::::new(); @@ -461,7 +487,7 @@ mod tests { let (_heap_idxs, map_idxs): (Vec<_>, Vec<_>) = heap_to_map.into_iter().unzip(); let ids = unsafe { map.take_all(map_idxs) }; assert_eq!( - format!("{:?}", ids), + format!("{ids:?}"), r#"[Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]"# ); assert_eq!(map.len(), 0, "Map should have been cleared!"); diff --git a/datafusion/physical-plan/src/aggregates/topk/heap.rs b/datafusion/physical-plan/src/aggregates/topk/heap.rs index 8b4b07d211a0e..23ccf5e17ef69 100644 --- a/datafusion/physical-plan/src/aggregates/topk/heap.rs +++ b/datafusion/physical-plan/src/aggregates/topk/heap.rs @@ -24,7 +24,7 @@ use arrow::array::{ use arrow::array::{downcast_primitive, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; use arrow::buffer::ScalarBuffer; use arrow::datatypes::{i256, DataType}; -use datafusion_common::DataFusionError; +use datafusion_common::exec_datafusion_err; use datafusion_common::Result; use half::f16; @@ -348,7 +348,7 @@ impl TopKHeap { prefix, connector, hi.val, idx, hi.map_idx )); let new_prefix = if is_tail { "" } else { "│ " }; - let child_prefix = format!("{}{}", prefix, new_prefix); + let child_prefix = format!("{prefix}{new_prefix}"); let left_idx = idx * 2 + 1; let right_idx = idx * 2 + 2; @@ -372,7 +372,7 @@ impl Display for TopKHeap { if !self.heap.is_empty() { self._tree_print(0, String::new(), true, &mut output); } - write!(f, "{}", output) + write!(f, "{output}") } } @@ -478,9 +478,7 @@ pub fn new_heap( _ => {} } - Err(DataFusionError::Execution(format!( - "Can't group type: {vt:?}" - ))) + Err(exec_datafusion_err!("Can't group type: {vt:?}")) } #[cfg(test)] diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs index bf02692486cc6..9aaadfd52b96b 100644 --- a/datafusion/physical-plan/src/aggregates/topk_stream.rs +++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs @@ -26,7 +26,7 @@ use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::{Array, ArrayRef, RecordBatch}; use arrow::datatypes::SchemaRef; use arrow::util::pretty::print_batches; -use datafusion_common::DataFusionError; +use datafusion_common::internal_datafusion_err; use datafusion_common::Result; use datafusion_execution::TaskContext; use datafusion_physical_expr::PhysicalExpr; @@ -61,7 +61,7 @@ impl GroupedTopKAggregateStream { aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?; let (val_field, desc) = aggr .get_minmax_desc() - .ok_or_else(|| DataFusionError::Internal("Min/max required".to_string()))?; + .ok_or_else(|| internal_datafusion_err!("Min/max required"))?; let (expr, _) = &aggr.group_expr().expr()[0]; let kt = expr.data_type(&aggr.input().schema())?; diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index ea14ce676c1a6..c095afe5e716e 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -134,9 +134,8 @@ impl ExecutionPlan for AnalyzeExec { vec![&self.input] } - /// AnalyzeExec is handled specially so this value is ignored fn required_input_distribution(&self) -> Vec { - vec![] + vec![Distribution::UnspecifiedDistribution] } fn with_new_children( diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs new file mode 100644 index 0000000000000..54a76e0ebb971 --- /dev/null +++ b/datafusion/physical-plan/src/async_func.rs @@ -0,0 +1,300 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::stream::RecordBatchStreamAdapter; +use crate::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, +}; +use arrow::array::RecordBatch; +use arrow_schema::{Fields, Schema, SchemaRef}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::{internal_err, Result}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::ScalarFunctionExpr; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use futures::stream::StreamExt; +use log::trace; +use std::any::Any; +use std::sync::Arc; + +/// This structure evaluates a set of async expressions on a record +/// batch producing a new record batch +/// +/// The schema of the output of the AsyncFuncExec is: +/// Input columns followed by one column for each async expression +#[derive(Debug)] +pub struct AsyncFuncExec { + /// The async expressions to evaluate + async_exprs: Vec>, + input: Arc, + cache: PlanProperties, + metrics: ExecutionPlanMetricsSet, +} + +impl AsyncFuncExec { + pub fn try_new( + async_exprs: Vec>, + input: Arc, + ) -> Result { + let async_fields = async_exprs + .iter() + .map(|async_expr| async_expr.field(input.schema().as_ref())) + .collect::>>()?; + + // compute the output schema: input schema then async expressions + let fields: Fields = input + .schema() + .fields() + .iter() + .cloned() + .chain(async_fields.into_iter().map(Arc::new)) + .collect(); + + let schema = Arc::new(Schema::new(fields)); + let tuples = async_exprs + .iter() + .map(|expr| (Arc::clone(&expr.func), expr.name().to_string())) + .collect::>(); + let async_expr_mapping = ProjectionMapping::try_new(tuples, &input.schema())?; + let cache = + AsyncFuncExec::compute_properties(&input, schema, &async_expr_mapping)?; + Ok(Self { + input, + async_exprs, + cache, + metrics: ExecutionPlanMetricsSet::new(), + }) + } + + /// This function creates the cache object that stores the plan properties + /// such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties( + input: &Arc, + schema: SchemaRef, + async_expr_mapping: &ProjectionMapping, + ) -> Result { + Ok(PlanProperties::new( + input + .equivalence_properties() + .project(async_expr_mapping, schema), + input.output_partitioning().clone(), + input.pipeline_behavior(), + input.boundedness(), + )) + } +} + +impl DisplayAs for AsyncFuncExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + let expr: Vec = self + .async_exprs + .iter() + .map(|async_expr| async_expr.to_string()) + .collect(); + let exprs = expr.join(", "); + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "AsyncFuncExec: async_expr=[{exprs}]") + } + DisplayFormatType::TreeRender => { + writeln!(f, "format=async_expr")?; + writeln!(f, "async_expr={exprs}")?; + Ok(()) + } + } + } +} + +impl ExecutionPlan for AsyncFuncExec { + fn name(&self) -> &str { + "async_func" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() != 1 { + return internal_err!("AsyncFuncExec wrong number of children"); + } + Ok(Arc::new(AsyncFuncExec::try_new( + self.async_exprs.clone(), + Arc::clone(&children[0]), + )?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + trace!( + "Start AsyncFuncExpr::execute for partition {} of context session_id {} and task_id {:?}", + partition, + context.session_id(), + context.task_id() + ); + // TODO figure out how to record metrics + + // first execute the input stream + let input_stream = self.input.execute(partition, Arc::clone(&context))?; + + // now, for each record batch, evaluate the async expressions and add the columns to the result + let async_exprs_captured = Arc::new(self.async_exprs.clone()); + let schema_captured = self.schema(); + let config_options_ref = Arc::clone(context.session_config().options()); + + let stream_with_async_functions = input_stream.then(move |batch| { + // need to clone *again* to capture the async_exprs and schema in the + // stream and satisfy lifetime requirements. + let async_exprs_captured = Arc::clone(&async_exprs_captured); + let schema_captured = Arc::clone(&schema_captured); + let config_options = Arc::clone(&config_options_ref); + + async move { + let batch = batch?; + // append the result of evaluating the async expressions to the output + let mut output_arrays = batch.columns().to_vec(); + for async_expr in async_exprs_captured.iter() { + let output = async_expr + .invoke_with_args(&batch, Arc::clone(&config_options)) + .await?; + output_arrays.push(output.to_array(batch.num_rows())?); + } + let batch = RecordBatch::try_new(schema_captured, output_arrays)?; + Ok(batch) + } + }); + + // Adapt the stream with the output schema + let adapter = + RecordBatchStreamAdapter::new(self.schema(), stream_with_async_functions); + Ok(Box::pin(adapter)) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +const ASYNC_FN_PREFIX: &str = "__async_fn_"; + +/// Maps async_expressions to new columns +/// +/// The output of the async functions are appended, in order, to the end of the input schema +#[derive(Debug)] +pub struct AsyncMapper { + /// the number of columns in the input plan + /// used to generate the output column names. + /// the first async expr is `__async_fn_0`, the second is `__async_fn_1`, etc + num_input_columns: usize, + /// the expressions to map + pub async_exprs: Vec>, +} + +impl AsyncMapper { + pub fn new(num_input_columns: usize) -> Self { + Self { + num_input_columns, + async_exprs: Vec::new(), + } + } + + pub fn is_empty(&self) -> bool { + self.async_exprs.is_empty() + } + + pub fn next_column_name(&self) -> String { + format!("{}{}", ASYNC_FN_PREFIX, self.async_exprs.len()) + } + + /// Finds any references to async functions in the expression and adds them to the map + pub fn find_references( + &mut self, + physical_expr: &Arc, + schema: &Schema, + ) -> Result<()> { + // recursively look for references to async functions + physical_expr.apply(|expr| { + if let Some(scalar_func_expr) = + expr.as_any().downcast_ref::() + { + if scalar_func_expr.fun().as_async().is_some() { + let next_name = self.next_column_name(); + self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new( + next_name, + Arc::clone(expr), + schema, + )?)); + } + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(()) + } + + /// If the expression matches any of the async functions, return the new column + pub fn map_expr( + &self, + expr: Arc, + ) -> Transformed> { + // find the first matching async function if any + let Some(idx) = + self.async_exprs + .iter() + .enumerate() + .find_map(|(idx, async_expr)| { + if async_expr.func == Arc::clone(&expr) { + Some(idx) + } else { + None + } + }) + else { + return Transformed::no(expr); + }; + // rewrite in terms of the output column + Transformed::yes(self.output_column(idx)) + } + + /// return the output column for the async function at index idx + pub fn output_column(&self, idx: usize) -> Arc { + let async_expr = &self.async_exprs[idx]; + let output_idx = self.num_input_columns + idx; + Arc::new(Column::new(async_expr.name(), output_idx)) + } +} diff --git a/datafusion/physical-plan/src/coalesce/mod.rs b/datafusion/physical-plan/src/coalesce/mod.rs index eb4a7d875c95a..5962362d76810 100644 --- a/datafusion/physical-plan/src/coalesce/mod.rs +++ b/datafusion/physical-plan/src/coalesce/mod.rs @@ -90,7 +90,7 @@ impl BatchCoalescer { /// # Arguments /// - `schema` - the schema of the output batches /// - `target_batch_size` - the minimum number of rows for each - /// output batch (until limit reached) + /// output batch (until limit reached) /// - `fetch` - the maximum number of rows to fetch, `None` means fetch all rows pub fn new( schema: SchemaRef, @@ -228,6 +228,12 @@ fn gc_string_view_batch(batch: &RecordBatch) -> RecordBatch { let Some(s) = c.as_string_view_opt() else { return Arc::clone(c); }; + + // Fast path: if the data buffers are empty, we can return the original array + if s.data_buffers().is_empty() { + return Arc::clone(c); + } + let ideal_buffer_size: usize = s .views() .iter() @@ -240,7 +246,11 @@ fn gc_string_view_batch(batch: &RecordBatch) -> RecordBatch { } }) .sum(); - let actual_buffer_size = s.get_buffer_memory_size(); + + // We don't use get_buffer_memory_size here, because gc is for the contents of the + // data buffers, not views and nulls. + let actual_buffer_size = + s.data_buffers().iter().map(|b| b.capacity()).sum::(); // Re-creating the array copies data and can be time consuming. // We only do it if the array is sparse @@ -285,7 +295,7 @@ mod tests { fn test_coalesce() { let batch = uint32_batch(0..8); Test::new() - .with_batches(std::iter::repeat(batch).take(10)) + .with_batches(std::iter::repeat_n(batch, 10)) // expected output is batches of at least 20 rows (except for the final batch) .with_target_batch_size(21) .with_expected_output_sizes(vec![24, 24, 24, 8]) @@ -296,7 +306,7 @@ mod tests { fn test_coalesce_with_fetch_larger_than_input_size() { let batch = uint32_batch(0..8); Test::new() - .with_batches(std::iter::repeat(batch).take(10)) + .with_batches(std::iter::repeat_n(batch, 10)) // input is 10 batches x 8 rows (80 rows) with fetch limit of 100 // expected to behave the same as `test_concat_batches` .with_target_batch_size(21) @@ -309,7 +319,7 @@ mod tests { fn test_coalesce_with_fetch_less_than_input_size() { let batch = uint32_batch(0..8); Test::new() - .with_batches(std::iter::repeat(batch).take(10)) + .with_batches(std::iter::repeat_n(batch, 10)) // input is 10 batches x 8 rows (80 rows) with fetch limit of 50 .with_target_batch_size(21) .with_fetch(Some(50)) @@ -321,7 +331,7 @@ mod tests { fn test_coalesce_with_fetch_less_than_target_and_no_remaining_rows() { let batch = uint32_batch(0..8); Test::new() - .with_batches(std::iter::repeat(batch).take(10)) + .with_batches(std::iter::repeat_n(batch, 10)) // input is 10 batches x 8 rows (80 rows) with fetch limit of 48 .with_target_batch_size(21) .with_fetch(Some(48)) @@ -333,7 +343,7 @@ mod tests { fn test_coalesce_with_fetch_less_target_batch_size() { let batch = uint32_batch(0..8); Test::new() - .with_batches(std::iter::repeat(batch).take(10)) + .with_batches(std::iter::repeat_n(batch, 10)) // input is 10 batches x 8 rows (80 rows) with fetch limit of 10 .with_target_batch_size(21) .with_fetch(Some(10)) @@ -593,7 +603,7 @@ mod tests { } } fn batch_to_pretty_strings(batch: &RecordBatch) -> String { - arrow::util::pretty::pretty_format_batches(&[batch.clone()]) + arrow::util::pretty::pretty_format_batches(std::slice::from_ref(batch)) .unwrap() .to_string() } diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index 5244038b9ae27..397bd9a377c35 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -32,9 +32,15 @@ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::TaskContext; +use datafusion_physical_expr::PhysicalExpr; use crate::coalesce::{BatchCoalescer, CoalescerState}; use crate::execution_plan::CardinalityEffect; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, +}; +use datafusion_common::config::ConfigOptions; use futures::ready; use futures::stream::{Stream, StreamExt}; @@ -192,7 +198,13 @@ impl ExecutionPlan for CoalesceBatchesExec { } fn statistics(&self) -> Result { - Statistics::with_fetch(self.input.statistics()?, self.schema(), self.fetch, 0, 1) + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input + .partition_statistics(partition)? + .with_fetch(self.fetch, 0, 1) } fn with_fetch(&self, limit: Option) -> Option> { @@ -212,6 +224,24 @@ impl ExecutionPlan for CoalesceBatchesExec { fn cardinality_effect(&self) -> CardinalityEffect { CardinalityEffect::Equal } + + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + FilterDescription::from_children(parent_filters, &self.children()) + } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::if_all(child_pushdown_result)) + } } /// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details. @@ -321,6 +351,7 @@ impl CoalesceBatchesStream { } } CoalesceBatchesStreamState::ReturnBuffer => { + let _timer = cloned_time.timer(); // Combine buffered batches into one batch and return it. let batch = self.coalescer.finish_batch()?; // Set to pull state for the next iteration. @@ -333,6 +364,7 @@ impl CoalesceBatchesStream { // If buffer is empty, return None indicating the stream is fully consumed. Poll::Ready(None) } else { + let _timer = cloned_time.timer(); // If the buffer still contains batches, prepare to return them. let batch = self.coalescer.finish_batch()?; Poll::Ready(Some(Ok(batch))) diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 95a0c8f6ce833..5869c51b26b8d 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -27,12 +27,15 @@ use super::{ DisplayAs, ExecutionPlanProperties, PlanProperties, SendableRecordBatchStream, Statistics, }; -use crate::execution_plan::CardinalityEffect; +use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType}; +use crate::filter_pushdown::{FilterDescription, FilterPushdownPhase}; use crate::projection::{make_with_child, ProjectionExec}; use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; +use datafusion_common::config::ConfigOptions; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::PhysicalExpr; /// Merge execution plan executes partitions in parallel and combines them into a single /// partition. No guarantees are made about the order of the resulting partition. @@ -59,6 +62,12 @@ impl CoalescePartitionsExec { } } + /// Update fetch with the argument + pub fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } + /// Input execution plan pub fn input(&self) -> &Arc { &self.input @@ -66,6 +75,16 @@ impl CoalescePartitionsExec { /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. fn compute_properties(input: &Arc) -> PlanProperties { + let input_partitions = input.output_partitioning().partition_count(); + let (drive, scheduling) = if input_partitions > 1 { + (EvaluationType::Eager, SchedulingType::Cooperative) + } else { + ( + input.properties().evaluation_type, + input.properties().scheduling_type, + ) + }; + // Coalescing partitions loses existing orderings: let mut eq_properties = input.equivalence_properties().clone(); eq_properties.clear_orderings(); @@ -76,6 +95,8 @@ impl CoalescePartitionsExec { input.pipeline_behavior(), input.boundedness(), ) + .with_evaluation_type(drive) + .with_scheduling_type(scheduling) } } @@ -190,7 +211,13 @@ impl ExecutionPlan for CoalescePartitionsExec { } fn statistics(&self) -> Result { - Statistics::with_fetch(self.input.statistics()?, self.schema(), self.fetch, 0, 1) + self.partition_statistics(None) + } + + fn partition_statistics(&self, _partition: Option) -> Result { + self.input + .partition_statistics(None)? + .with_fetch(self.fetch, 0, 1) } fn supports_limit_pushdown(&self) -> bool { @@ -236,6 +263,15 @@ impl ExecutionPlan for CoalescePartitionsExec { cache: self.cache.clone(), })) } + + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + FilterDescription::from_children(parent_filters, &self.children()) + } } #[cfg(test)] diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index a8d4a3ddf3d1a..e9a8499a7c9ac 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -18,8 +18,7 @@ //! Defines common code used in execution plans use std::fs; -use std::fs::{metadata, File}; -use std::path::{Path, PathBuf}; +use std::fs::metadata; use std::sync::Arc; use super::SendableRecordBatchStream; @@ -28,10 +27,9 @@ use crate::{ColumnStatistics, Statistics}; use arrow::array::Array; use arrow::datatypes::Schema; -use arrow::ipc::writer::{FileWriter, IpcWriteOptions}; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::{plan_err, Result}; use datafusion_execution::memory_pool::MemoryReservation; use futures::{StreamExt, TryStreamExt}; @@ -93,7 +91,7 @@ fn build_file_list_recurse( /// If running in a tokio context spawns the execution of `stream` to a separate task /// allowing it to execute in parallel with an intermediate buffer of size `buffer` -pub(crate) fn spawn_buffered( +pub fn spawn_buffered( mut input: SendableRecordBatchStream, buffer: usize, ) -> SendableRecordBatchStream { @@ -180,77 +178,6 @@ pub fn compute_record_batch_statistics( } } -/// Write in Arrow IPC File format. -pub struct IPCWriter { - /// Path - pub path: PathBuf, - /// Inner writer - pub writer: FileWriter, - /// Batches written - pub num_batches: usize, - /// Rows written - pub num_rows: usize, - /// Bytes written - pub num_bytes: usize, -} - -impl IPCWriter { - /// Create new writer - pub fn new(path: &Path, schema: &Schema) -> Result { - let file = File::create(path).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to create partition file at {path:?}: {e:?}" - )) - })?; - Ok(Self { - num_batches: 0, - num_rows: 0, - num_bytes: 0, - path: path.into(), - writer: FileWriter::try_new(file, schema)?, - }) - } - - /// Create new writer with IPC write options - pub fn new_with_options( - path: &Path, - schema: &Schema, - write_options: IpcWriteOptions, - ) -> Result { - let file = File::create(path).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to create partition file at {path:?}: {e:?}" - )) - })?; - Ok(Self { - num_batches: 0, - num_rows: 0, - num_bytes: 0, - path: path.into(), - writer: FileWriter::try_new_with_options(file, schema, write_options)?, - }) - } - /// Write one single batch - pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { - self.writer.write(batch)?; - self.num_batches += 1; - self.num_rows += batch.num_rows(); - let num_bytes: usize = batch.get_array_memory_size(); - self.num_bytes += num_bytes; - Ok(()) - } - - /// Finish the writer - pub fn finish(&mut self) -> Result<()> { - self.writer.finish().map_err(Into::into) - } - - /// Path write to - pub fn path(&self) -> &Path { - &self.path - } -} - /// Checks if the given projection is valid for the given schema. pub fn can_project( schema: &arrow::datatypes::SchemaRef, diff --git a/datafusion/physical-plan/src/coop.rs b/datafusion/physical-plan/src/coop.rs new file mode 100644 index 0000000000000..b62d15e6d2f17 --- /dev/null +++ b/datafusion/physical-plan/src/coop.rs @@ -0,0 +1,392 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for improved cooperative scheduling. +//! +//! # Cooperative scheduling +//! +//! A single call to `poll_next` on a top-level [`Stream`] may potentially perform a lot of work +//! before it returns a `Poll::Pending`. Think for instance of calculating an aggregation over a +//! large dataset. +//! If a `Stream` runs for a long period of time without yielding back to the Tokio executor, +//! it can starve other tasks waiting on that executor to execute them. +//! Additionally, this prevents the query execution from being cancelled. +//! +//! To ensure that `Stream` implementations yield regularly, operators can insert explicit yield +//! points using the utilities in this module. For most operators this is **not** necessary. The +//! `Stream`s of the built-in DataFusion operators that generate (rather than manipulate) +//! `RecordBatch`es such as `DataSourceExec` and those that eagerly consume `RecordBatch`es +//! (for instance, `RepartitionExec`) contain yield points that will make most query `Stream`s yield +//! periodically. +//! +//! There are a couple of types of operators that _should_ insert yield points: +//! - New source operators that do not make use of Tokio resources +//! - Exchange like operators that do not use Tokio's `Channel` implementation to pass data between +//! tasks +//! +//! ## Adding yield points +//! +//! Yield points can be inserted manually using the facilities provided by the +//! [Tokio coop module](https://docs.rs/tokio/latest/tokio/task/coop/index.html) such as +//! [`tokio::task::coop::consume_budget`](https://docs.rs/tokio/latest/tokio/task/coop/fn.consume_budget.html). +//! +//! Another option is to use the wrapper `Stream` implementation provided by this module which will +//! consume a unit of task budget every time a `RecordBatch` is produced. +//! Wrapper `Stream`s can be created using the [`cooperative`] and [`make_cooperative`] functions. +//! +//! [`cooperative`] is a generic function that takes ownership of the wrapped [`RecordBatchStream`]. +//! This function has the benefit of not requiring an additional heap allocation and can avoid +//! dynamic dispatch. +//! +//! [`make_cooperative`] is a non-generic function that wraps a [`SendableRecordBatchStream`]. This +//! can be used to wrap dynamically typed, heap allocated [`RecordBatchStream`]s. +//! +//! ## Automatic cooperation +//! +//! The `EnsureCooperative` physical optimizer rule, which is included in the default set of +//! optimizer rules, inspects query plans for potential cooperative scheduling issues. +//! It injects the [`CooperativeExec`] wrapper `ExecutionPlan` into the query plan where necessary. +//! This `ExecutionPlan` uses [`make_cooperative`] to wrap the `Stream` of its input. +//! +//! The optimizer rule currently checks the plan for exchange-like operators and leave operators +//! that report [`SchedulingType::NonCooperative`] in their [plan properties](ExecutionPlan::properties). + +use datafusion_common::config::ConfigOptions; +use datafusion_physical_expr::PhysicalExpr; +#[cfg(datafusion_coop = "tokio_fallback")] +use futures::Future; +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::execution_plan::CardinalityEffect::{self, Equal}; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, +}; +use crate::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, +}; +use arrow::record_batch::RecordBatch; +use arrow_schema::Schema; +use datafusion_common::{internal_err, Result, Statistics}; +use datafusion_execution::TaskContext; + +use crate::execution_plan::SchedulingType; +use crate::stream::RecordBatchStreamAdapter; +use futures::{Stream, StreamExt}; + +/// A stream that passes record batches through unchanged while cooperating with the Tokio runtime. +/// It consumes cooperative scheduling budget for each returned [`RecordBatch`], +/// allowing other tasks to execute when the budget is exhausted. +/// +/// See the [module level documentation](crate::coop) for an in-depth discussion. +pub struct CooperativeStream +where + T: RecordBatchStream + Unpin, +{ + inner: T, + #[cfg(datafusion_coop = "per_stream")] + budget: u8, +} + +#[cfg(datafusion_coop = "per_stream")] +// Magic value that matches Tokio's task budget value +const YIELD_FREQUENCY: u8 = 128; + +impl CooperativeStream +where + T: RecordBatchStream + Unpin, +{ + /// Creates a new `CooperativeStream` that wraps the provided stream. + /// The resulting stream will cooperate with the Tokio scheduler by consuming a unit of + /// scheduling budget when the wrapped `Stream` returns a record batch. + pub fn new(inner: T) -> Self { + Self { + inner, + #[cfg(datafusion_coop = "per_stream")] + budget: YIELD_FREQUENCY, + } + } +} + +impl Stream for CooperativeStream +where + T: RecordBatchStream + Unpin, +{ + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + #[cfg(any( + datafusion_coop = "tokio", + not(any( + datafusion_coop = "tokio_fallback", + datafusion_coop = "per_stream" + )) + ))] + { + let coop = std::task::ready!(tokio::task::coop::poll_proceed(cx)); + let value = self.inner.poll_next_unpin(cx); + if value.is_ready() { + coop.made_progress(); + } + value + } + + #[cfg(datafusion_coop = "tokio_fallback")] + { + // This is a temporary placeholder implementation that may have slightly + // worse performance compared to `poll_proceed` + if !tokio::task::coop::has_budget_remaining() { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + let value = self.inner.poll_next_unpin(cx); + if value.is_ready() { + // In contrast to `poll_proceed` we are not able to consume + // budget before proceeding to do work. Instead, we try to consume budget + // after the work has been done and just assume that that succeeded. + // The poll result is ignored because we don't want to discard + // or buffer the Ready result we got from the inner stream. + let consume = tokio::task::coop::consume_budget(); + let consume_ref = std::pin::pin!(consume); + let _ = consume_ref.poll(cx); + } + value + } + + #[cfg(datafusion_coop = "per_stream")] + { + if self.budget == 0 { + self.budget = YIELD_FREQUENCY; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + let value = { self.inner.poll_next_unpin(cx) }; + + if value.is_ready() { + self.budget -= 1; + } else { + self.budget = YIELD_FREQUENCY; + } + value + } + } +} + +impl RecordBatchStream for CooperativeStream +where + T: RecordBatchStream + Unpin, +{ + fn schema(&self) -> Arc { + self.inner.schema() + } +} + +/// An execution plan decorator that enables cooperative multitasking. +/// It wraps the streams produced by its input execution plan using the [`make_cooperative`] function, +/// which makes the stream participate in Tokio cooperative scheduling. +#[derive(Debug)] +pub struct CooperativeExec { + input: Arc, + properties: PlanProperties, +} + +impl CooperativeExec { + /// Creates a new `CooperativeExec` operator that wraps the given input execution plan. + pub fn new(input: Arc) -> Self { + let properties = input + .properties() + .clone() + .with_scheduling_type(SchedulingType::Cooperative); + + Self { input, properties } + } + + /// Returns a reference to the wrapped input execution plan. + pub fn input(&self) -> &Arc { + &self.input + } +} + +impl DisplayAs for CooperativeExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "CooperativeExec") + } +} + +impl ExecutionPlan for CooperativeExec { + fn name(&self) -> &str { + "CooperativeExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> Arc { + self.input.schema() + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn maintains_input_order(&self) -> Vec { + vec![true; self.children().len()] + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + if children.len() != 1 { + return internal_err!("CooperativeExec requires exactly one child"); + } + Ok(Arc::new(CooperativeExec::new(children.swap_remove(0)))) + } + + fn execute( + &self, + partition: usize, + task_ctx: Arc, + ) -> Result { + let child_stream = self.input.execute(partition, task_ctx)?; + Ok(make_cooperative(child_stream)) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition) + } + + fn supports_limit_pushdown(&self) -> bool { + true + } + + fn cardinality_effect(&self) -> CardinalityEffect { + Equal + } + + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + FilterDescription::from_children(parent_filters, &self.children()) + } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::if_all(child_pushdown_result)) + } +} + +/// Creates a [`CooperativeStream`] wrapper around the given [`RecordBatchStream`]. +/// This wrapper collaborates with the Tokio cooperative scheduler by consuming a unit of +/// scheduling budget for each returned record batch. +pub fn cooperative(stream: T) -> CooperativeStream +where + T: RecordBatchStream + Unpin + Send + 'static, +{ + CooperativeStream::new(stream) +} + +/// Wraps a `SendableRecordBatchStream` inside a [`CooperativeStream`] to enable cooperative multitasking. +/// Since `SendableRecordBatchStream` is a `dyn RecordBatchStream` this requires the use of dynamic +/// method dispatch. +/// When the stream type is statically known, consider use the generic [`cooperative`] function +/// to allow static method dispatch. +pub fn make_cooperative(stream: SendableRecordBatchStream) -> SendableRecordBatchStream { + // TODO is there a more elegant way to overload cooperative + Box::pin(cooperative(RecordBatchStreamAdapter::new( + stream.schema(), + stream, + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::stream::RecordBatchStreamAdapter; + + use arrow_schema::SchemaRef; + + use futures::{stream, StreamExt}; + + // This is the hardcoded value Tokio uses + const TASK_BUDGET: usize = 128; + + /// Helper: construct a SendableRecordBatchStream containing `n` empty batches + fn make_empty_batches(n: usize) -> SendableRecordBatchStream { + let schema: SchemaRef = Arc::new(Schema::empty()); + let schema_for_stream = Arc::clone(&schema); + + let s = + stream::iter((0..n).map(move |_| { + Ok(RecordBatch::new_empty(Arc::clone(&schema_for_stream))) + })); + + Box::pin(RecordBatchStreamAdapter::new(schema, s)) + } + + #[tokio::test] + async fn yield_less_than_threshold() -> Result<()> { + let count = TASK_BUDGET - 10; + let inner = make_empty_batches(count); + let out = make_cooperative(inner).collect::>().await; + assert_eq!(out.len(), count); + Ok(()) + } + + #[tokio::test] + async fn yield_equal_to_threshold() -> Result<()> { + let count = TASK_BUDGET; + let inner = make_empty_batches(count); + let out = make_cooperative(inner).collect::>().await; + assert_eq!(out.len(), count); + Ok(()) + } + + #[tokio::test] + async fn yield_more_than_threshold() -> Result<()> { + let count = TASK_BUDGET + 20; + let inner = make_empty_batches(count); + let out = make_cooperative(inner).collect::>().await; + assert_eq!(out.len(), count); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index f437295a35551..2420edfc743da 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -120,6 +120,8 @@ pub struct DisplayableExecutionPlan<'a> { show_statistics: bool, /// If schema should be displayed. See [`Self::set_show_schema`] show_schema: bool, + // (TreeRender) Maximum total width of the rendered tree + tree_maximum_render_width: usize, } impl<'a> DisplayableExecutionPlan<'a> { @@ -131,6 +133,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: ShowMetrics::None, show_statistics: false, show_schema: false, + tree_maximum_render_width: 240, } } @@ -143,6 +146,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: ShowMetrics::Aggregated, show_statistics: false, show_schema: false, + tree_maximum_render_width: 240, } } @@ -155,6 +159,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: ShowMetrics::Full, show_statistics: false, show_schema: false, + tree_maximum_render_width: 240, } } @@ -173,6 +178,12 @@ impl<'a> DisplayableExecutionPlan<'a> { self } + /// Set the maximum render width for the tree format + pub fn set_tree_maximum_render_width(mut self, width: usize) -> Self { + self.tree_maximum_render_width = width; + self + } + /// Return a `format`able structure that produces a single line /// per node. /// @@ -270,14 +281,21 @@ impl<'a> DisplayableExecutionPlan<'a> { pub fn tree_render(&self) -> impl fmt::Display + 'a { struct Wrapper<'a> { plan: &'a dyn ExecutionPlan, + maximum_render_width: usize, } impl fmt::Display for Wrapper<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - let mut visitor = TreeRenderVisitor { f }; + let mut visitor = TreeRenderVisitor { + f, + maximum_render_width: self.maximum_render_width, + }; visitor.visit(self.plan) } } - Wrapper { plan: self.inner } + Wrapper { + plan: self.inner, + maximum_render_width: self.tree_maximum_render_width, + } } /// Return a single-line summary of the root of the plan @@ -394,8 +412,8 @@ impl ExecutionPlanVisitor for IndentVisitor<'_, '_> { } } if self.show_statistics { - let stats = plan.statistics().map_err(|_e| fmt::Error)?; - write!(self.f, ", statistics=[{}]", stats)?; + let stats = plan.partition_statistics(None).map_err(|_e| fmt::Error)?; + write!(self.f, ", statistics=[{stats}]")?; } if self.show_schema { write!( @@ -479,8 +497,8 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { }; let statistics = if self.show_statistics { - let stats = plan.statistics().map_err(|_e| fmt::Error)?; - format!("statistics=[{}]", stats) + let stats = plan.partition_statistics(None).map_err(|_e| fmt::Error)?; + format!("statistics=[{stats}]") } else { "".to_string() }; @@ -495,7 +513,7 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { self.f, id, &label, - Some(&format!("{}{}{}", metrics, delimiter, statistics)), + Some(&format!("{metrics}{delimiter}{statistics}")), )?; if let Some(parent_node_id) = self.parents.last() { @@ -540,6 +558,8 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { struct TreeRenderVisitor<'a, 'b> { /// Write to this formatter f: &'a mut Formatter<'b>, + /// Maximum total width of the rendered tree + maximum_render_width: usize, } impl TreeRenderVisitor<'_, '_> { @@ -557,7 +577,6 @@ impl TreeRenderVisitor<'_, '_> { const HORIZONTAL: &'static str = "─"; // Horizontal line // TODO: Make these variables configurable. - const MAXIMUM_RENDER_WIDTH: usize = 240; // Maximum total width of the rendered tree const NODE_RENDER_WIDTH: usize = 29; // Width of each node's box const MAX_EXTRA_LINES: usize = 30; // Maximum number of extra info lines per node @@ -592,6 +611,12 @@ impl TreeRenderVisitor<'_, '_> { y: usize, ) -> Result<(), fmt::Error> { for x in 0..root.width { + if self.maximum_render_width > 0 + && x * Self::NODE_RENDER_WIDTH >= self.maximum_render_width + { + break; + } + if root.has_node(x, y) { write!(self.f, "{}", Self::LTCORNER)?; write!( @@ -657,12 +682,14 @@ impl TreeRenderVisitor<'_, '_> { } } - let halfway_point = (extra_height + 1) / 2; + let halfway_point = extra_height.div_ceil(2); // Render the actual node. for render_y in 0..=extra_height { for (x, _) in root.nodes.iter().enumerate().take(root.width) { - if x * Self::NODE_RENDER_WIDTH >= Self::MAXIMUM_RENDER_WIDTH { + if self.maximum_render_width > 0 + && x * Self::NODE_RENDER_WIDTH >= self.maximum_render_width + { break; } @@ -686,7 +713,7 @@ impl TreeRenderVisitor<'_, '_> { &render_text, Self::NODE_RENDER_WIDTH - 2, ); - write!(self.f, "{}", render_text)?; + write!(self.f, "{render_text}")?; if render_y == halfway_point && node.child_positions.len() > 1 { write!(self.f, "{}", Self::LMIDDLE)?; @@ -780,7 +807,9 @@ impl TreeRenderVisitor<'_, '_> { y: usize, ) -> Result<(), fmt::Error> { for x in 0..=root.width { - if x * Self::NODE_RENDER_WIDTH >= Self::MAXIMUM_RENDER_WIDTH { + if self.maximum_render_width > 0 + && x * Self::NODE_RENDER_WIDTH >= self.maximum_render_width + { break; } let mut has_adjacent_nodes = false; @@ -856,10 +885,10 @@ impl TreeRenderVisitor<'_, '_> { if str.is_empty() { str = key.to_string(); } else if !is_multiline && total_size < available_width { - str = format!("{}: {}", key, str); + str = format!("{key}: {str}"); is_inlined = true; } else { - str = format!("{}:\n{}", key, str); + str = format!("{key}:\n{str}"); } if is_inlined && was_inlined { @@ -902,11 +931,11 @@ impl TreeRenderVisitor<'_, '_> { let render_width = source.chars().count(); if render_width > max_render_width { let truncated = &source[..max_render_width - 3]; - format!("{}...", truncated) + format!("{truncated}...") } else { let total_spaces = max_render_width - render_width; let half_spaces = total_spaces / 2; - let extra_left_space = if total_spaces % 2 == 0 { 0 } else { 1 }; + let extra_left_space = if total_spaces.is_multiple_of(2) { 0 } else { 1 }; format!( "{}{}{}", " ".repeat(half_spaces + extra_left_space), @@ -1034,27 +1063,22 @@ impl fmt::Display for ProjectSchemaDisplay<'_> { } pub fn display_orderings(f: &mut Formatter, orderings: &[LexOrdering]) -> fmt::Result { - if let Some(ordering) = orderings.first() { - if !ordering.is_empty() { - let start = if orderings.len() == 1 { - ", output_ordering=" - } else { - ", output_orderings=[" - }; - write!(f, "{}", start)?; - for (idx, ordering) in - orderings.iter().enumerate().filter(|(_, o)| !o.is_empty()) - { - match idx { - 0 => write!(f, "[{}]", ordering)?, - _ => write!(f, ", [{}]", ordering)?, - } + if !orderings.is_empty() { + let start = if orderings.len() == 1 { + ", output_ordering=" + } else { + ", output_orderings=[" + }; + write!(f, "{start}")?; + for (idx, ordering) in orderings.iter().enumerate() { + match idx { + 0 => write!(f, "[{ordering}]")?, + _ => write!(f, ", [{ordering}]")?, } - let end = if orderings.len() == 1 { "" } else { "]" }; - write!(f, "{}", end)?; } + let end = if orderings.len() == 1 { "" } else { "]" }; + write!(f, "{end}")?; } - Ok(()) } @@ -1063,7 +1087,7 @@ mod tests { use std::fmt::Write; use std::sync::Arc; - use datafusion_common::{DataFusionError, Result, Statistics}; + use datafusion_common::{internal_datafusion_err, Result, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use crate::{DisplayAs, ExecutionPlan, PlanProperties}; @@ -1120,11 +1144,16 @@ mod tests { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(self.schema().as_ref())); + } match self { Self::Panic => panic!("expected panic"), - Self::Error => { - Err(DataFusionError::Internal("expected error".to_string())) - } + Self::Error => Err(internal_datafusion_err!("expected error")), Self::Ok => Ok(Statistics::new_unknown(self.schema().as_ref())), } } diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index 3fdde39df6f11..40b4ec61dc102 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -33,6 +33,7 @@ use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; +use crate::execution_plan::SchedulingType; use log::trace; /// Execution plan for empty relation with produce_one_row=false @@ -81,6 +82,7 @@ impl EmptyExec { EmissionType::Incremental, Boundedness::Bounded, ) + .with_scheduling_type(SchedulingType::Cooperative) } } @@ -150,6 +152,20 @@ impl ExecutionPlan for EmptyExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if let Some(partition) = partition { + if partition >= self.partitions { + return internal_err!( + "EmptyExec invalid partition {} (expected less than {})", + partition, + self.partitions + ); + } + } + let batch = self .data() .expect("Create empty RecordBatch should not fail"); diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs index 2bc5706ee0e18..a70cd9cb0d64d 100644 --- a/datafusion/physical-plan/src/execution_plan.rs +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -16,6 +16,10 @@ // under the License. pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, +}; pub use crate::metrics::Metric; pub use crate::ordering::InputOrderMode; pub use crate::stream::EmptyRecordBatchStream; @@ -38,18 +42,16 @@ use crate::coalesce_partitions::CoalescePartitionsExec; use crate::display::DisplayableExecutionPlan; use crate::metrics::MetricsSet; use crate::projection::ProjectionExec; -use crate::repartition::RepartitionExec; -use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::stream::RecordBatchStreamAdapter; use arrow::array::{Array, RecordBatch}; use arrow::datatypes::SchemaRef; use datafusion_common::config::ConfigOptions; -use datafusion_common::{exec_err, Constraints, Result}; +use datafusion_common::{exec_err, Constraints, DataFusionError, Result}; use datafusion_common_runtime::JoinSet; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; use futures::stream::{StreamExt, TryStreamExt}; @@ -72,6 +74,15 @@ use futures::stream::{StreamExt, TryStreamExt}; /// [`execute`]: ExecutionPlan::execute /// [`required_input_distribution`]: ExecutionPlan::required_input_distribution /// [`required_input_ordering`]: ExecutionPlan::required_input_ordering +/// +/// # Examples +/// +/// See [`datafusion-examples`] for examples, including +/// [`memory_pool_execution_plan.rs`] which shows how to implement a custom +/// `ExecutionPlan` with memory tracking and spilling support. +/// +/// [`datafusion-examples`]: https://github.com/apache/datafusion/tree/main/datafusion-examples +/// [`memory_pool_execution_plan.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/memory_pool_execution_plan.rs pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// Short name for the ExecutionPlan, such as 'DataSourceExec'. /// @@ -115,10 +126,11 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// Returns an error if this individual node does not conform to its invariants. /// These invariants are typically only checked in debug mode. /// - /// A default set of invariants is provided in the default implementation. + /// A default set of invariants is provided in the [check_default_invariants] function. + /// The default implementation of `check_invariants` calls this function. /// Extension nodes can provide their own invariants. - fn check_invariants(&self, _check: InvariantLevel) -> Result<()> { - Ok(()) + fn check_invariants(&self, check: InvariantLevel) -> Result<()> { + check_default_invariants(self, check) } /// Specifies the data distribution requirements for all the @@ -136,7 +148,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// NOTE that checking `!is_empty()` does **not** check for a /// required input ordering. Instead, the correct check is that at /// least one entry must be `Some` - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { vec![None; self.children().len()] } @@ -192,6 +204,31 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { children: Vec>, ) -> Result>; + /// Reset any internal state within this [`ExecutionPlan`]. + /// + /// This method is called when an [`ExecutionPlan`] needs to be re-executed, + /// such as in recursive queries. Unlike [`ExecutionPlan::with_new_children`], this method + /// ensures that any stateful components (e.g., [`DynamicFilterPhysicalExpr`]) + /// are reset to their initial state. + /// + /// The default implementation simply calls [`ExecutionPlan::with_new_children`] with the existing children, + /// effectively creating a new instance of the [`ExecutionPlan`] with the same children but without + /// necessarily resetting any internal state. Implementations that require resetting of some + /// internal state should override this method to provide the necessary logic. + /// + /// This method should *not* reset state recursively for children, as it is expected that + /// it will be called from within a walk of the execution plan tree so that it will be called on each child later + /// or was already called on each child. + /// + /// Note to implementers: unlike [`ExecutionPlan::with_new_children`] this method does not accept new children as an argument, + /// thus it is expected that any cached plan properties will remain valid after the reset. + /// + /// [`DynamicFilterPhysicalExpr`]: datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr + fn reset_state(self: Arc) -> Result> { + let children = self.children().into_iter().cloned().collect(); + self.with_new_children(children) + } + /// If supported, attempt to increase the partitioning of this `ExecutionPlan` to /// produce `target_partitions` partitions. /// @@ -267,11 +304,13 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// batch is superlinear. See this [general guideline][async-guideline] for more context /// on this point, which explains why one should avoid spending a long time without /// reaching an `await`/yield point in asynchronous runtimes. - /// This can be achieved by manually returning [`Poll::Pending`] and setting up wakers - /// appropriately, or the use of [`tokio::task::yield_now()`] when appropriate. + /// This can be achieved by using the utilities from the [`coop`](crate::coop) module, by + /// manually returning [`Poll::Pending`] and setting up wakers appropriately, or by calling + /// [`tokio::task::yield_now()`] when appropriate. /// In special cases that warrant manual yielding, determination for "regularly" may be - /// made using a timer (being careful with the overhead-heavy system call needed to - /// take the time), or by counting rows or batches. + /// made using the [Tokio task budget](https://docs.rs/tokio/latest/tokio/task/coop/index.html), + /// a timer (being careful with the overhead-heavy system call needed to take the time), or by + /// counting rows or batches. /// /// The [cancellation benchmark] tracks some cases of how quickly queries can /// be cancelled. @@ -423,10 +462,30 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// /// For TableScan executors, which supports filter pushdown, special attention /// needs to be paid to whether the stats returned by this method are exact or not + #[deprecated(since = "48.0.0", note = "Use `partition_statistics` method instead")] fn statistics(&self) -> Result { Ok(Statistics::new_unknown(&self.schema())) } + /// Returns statistics for a specific partition of this `ExecutionPlan` node. + /// If statistics are not available, should return [`Statistics::new_unknown`] + /// (the default), not an error. + /// If `partition` is `None`, it returns statistics for the entire plan. + fn partition_statistics(&self, partition: Option) -> Result { + if let Some(idx) = partition { + // Validate partition index + let partition_count = self.properties().partitioning.partition_count(); + if idx >= partition_count { + return internal_err!( + "Invalid partition index: {}, the partition count is {}", + idx, + partition_count + ); + } + } + Ok(Statistics::new_unknown(&self.schema())) + } + /// Returns `true` if a limit can be safely pushed down through this /// `ExecutionPlan` node. /// @@ -467,6 +526,151 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { ) -> Result>> { Ok(None) } + + /// Collect filters that this node can push down to its children. + /// Filters that are being pushed down from parents are passed in, + /// and the node may generate additional filters to push down. + /// For example, given the plan FilterExec -> HashJoinExec -> DataSourceExec, + /// what will happen is that we recurse down the plan calling `ExecutionPlan::gather_filters_for_pushdown`: + /// 1. `FilterExec::gather_filters_for_pushdown` is called with no parent + /// filters so it only returns that `FilterExec` wants to push down its own predicate. + /// 2. `HashJoinExec::gather_filters_for_pushdown` is called with the filter from + /// `FilterExec`, which it only allows to push down to one side of the join (unless it's on the join key) + /// but it also adds its own filters (e.g. pushing down a bloom filter of the hash table to the scan side of the join). + /// 3. `DataSourceExec::gather_filters_for_pushdown` is called with both filters from `HashJoinExec` + /// and `FilterExec`, however `DataSourceExec::gather_filters_for_pushdown` doesn't actually do anything + /// since it has no children and no additional filters to push down. + /// It's only once [`ExecutionPlan::handle_child_pushdown_result`] is called on `DataSourceExec` as we recurse + /// up the plan that `DataSourceExec` can actually bind the filters. + /// + /// The default implementation bars all parent filters from being pushed down and adds no new filters. + /// This is the safest option, making filter pushdown opt-in on a per-node pasis. + /// + /// There are two different phases in filter pushdown, which some operators may handle the same and some differently. + /// Depending on the phase the operator may or may not be allowed to modify the plan. + /// See [`FilterPushdownPhase`] for more details. + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + Ok(FilterDescription::all_unsupported( + &parent_filters, + &self.children(), + )) + } + + /// Handle the result of a child pushdown. + /// This method is called as we recurse back up the plan tree after pushing + /// filters down to child nodes via [`ExecutionPlan::gather_filters_for_pushdown`]. + /// It allows the current node to process the results of filter pushdown from + /// its children, deciding whether to absorb filters, modify the plan, or pass + /// filters back up to its parent. + /// + /// **Purpose and Context:** + /// Filter pushdown is a critical optimization in DataFusion that aims to + /// reduce the amount of data processed by applying filters as early as + /// possible in the query plan. This method is part of the second phase of + /// filter pushdown, where results are propagated back up the tree after + /// being pushed down. Each node can inspect the pushdown results from its + /// children and decide how to handle any unapplied filters, potentially + /// optimizing the plan structure or filter application. + /// + /// **Behavior in Different Nodes:** + /// - For a `DataSourceExec`, this often means absorbing the filters to apply + /// them during the scan phase (late materialization), reducing the data + /// read from the source. + /// - A `FilterExec` may absorb any filters its children could not handle, + /// combining them with its own predicate. If no filters remain (i.e., the + /// predicate becomes trivially true), it may remove itself from the plan + /// altogether. It typically marks parent filters as supported, indicating + /// they have been handled. + /// - A `HashJoinExec` might ignore the pushdown result if filters need to + /// be applied during the join operation. It passes the parent filters back + /// up wrapped in [`FilterPushdownPropagation::if_any`], discarding + /// any self-filters from children. + /// + /// **Example Walkthrough:** + /// Consider a query plan: `FilterExec (f1) -> HashJoinExec -> DataSourceExec`. + /// 1. **Downward Phase (`gather_filters_for_pushdown`):** Starting at + /// `FilterExec`, the filter `f1` is gathered and pushed down to + /// `HashJoinExec`. `HashJoinExec` may allow `f1` to pass to one side of + /// the join or add its own filters (e.g., a min-max filter from the build side), + /// then pushes filters to `DataSourceExec`. `DataSourceExec`, being a leaf node, + /// has no children to push to, so it prepares to handle filters in the + /// upward phase. + /// 2. **Upward Phase (`handle_child_pushdown_result`):** Starting at + /// `DataSourceExec`, it absorbs applicable filters from `HashJoinExec` + /// for late materialization during scanning, marking them as supported. + /// `HashJoinExec` receives the result, decides whether to apply any + /// remaining filters during the join, and passes unhandled filters back + /// up to `FilterExec`. `FilterExec` absorbs any unhandled filters, + /// updates its predicate if necessary, or removes itself if the predicate + /// becomes trivial (e.g., `lit(true)`), and marks filters as supported + /// for its parent. + /// + /// The default implementation is a no-op that passes the result of pushdown + /// from the children to its parent transparently, ensuring no filters are + /// lost if a node does not override this behavior. + /// + /// **Notes for Implementation:** + /// When returning filters via [`FilterPushdownPropagation`], the order of + /// filters need not match the order they were passed in via + /// `child_pushdown_result`. However, preserving the order is recommended for + /// debugging and ease of reasoning about the resulting plans. + /// + /// **Helper Methods for Customization:** + /// There are various helper methods to simplify implementing this method: + /// - [`FilterPushdownPropagation::if_any`]: Marks all parent filters as + /// supported as long as at least one child supports them. + /// - [`FilterPushdownPropagation::if_all`]: Marks all parent filters as + /// supported as long as all children support them. + /// - [`FilterPushdownPropagation::with_parent_pushdown_result`]: Allows adding filters + /// to the propagation result, indicating which filters are supported by + /// the current node. + /// - [`FilterPushdownPropagation::with_updated_node`]: Allows updating the + /// current node in the propagation result, used if the node + /// has modified its plan based on the pushdown results. + /// + /// **Filter Pushdown Phases:** + /// There are two different phases in filter pushdown (`Pre` and others), + /// which some operators may handle differently. Depending on the phase, the + /// operator may or may not be allowed to modify the plan. See + /// [`FilterPushdownPhase`] for more details on phase-specific behavior. + /// + /// [`PushedDownPredicate::supported`]: crate::filter_pushdown::PushedDownPredicate::supported + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::if_all(child_pushdown_result)) + } + + /// Injects arbitrary run-time state into this execution plan, returning a new plan + /// instance that incorporates that state *if* it is relevant to the concrete + /// node implementation. + /// + /// This is a generic entry point: the `state` can be any type wrapped in + /// `Arc`. A node that cares about the state should + /// down-cast it to the concrete type it expects and, if successful, return a + /// modified copy of itself that captures the provided value. If the state is + /// not applicable, the default behaviour is to return `None` so that parent + /// nodes can continue propagating the attempt further down the plan tree. + /// + /// For example, [`WorkTableExec`](crate::work_table::WorkTableExec) + /// down-casts the supplied state to an `Arc` + /// in order to wire up the working table used during recursive-CTE execution. + /// Similar patterns can be followed by custom nodes that need late-bound + /// dependencies or shared state. + fn with_new_state( + &self, + _state: Arc, + ) -> Option> { + None + } } /// [`ExecutionPlan`] Invariant Level @@ -519,13 +723,15 @@ pub trait ExecutionPlanProperties { /// If this ExecutionPlan makes no changes to the schema of the rows flowing /// through it or how columns within each row relate to each other, it /// should return the equivalence properties of its input. For - /// example, since `FilterExec` may remove rows from its input, but does not + /// example, since [`FilterExec`] may remove rows from its input, but does not /// otherwise modify them, it preserves its input equivalence properties. /// However, since `ProjectionExec` may calculate derived expressions, it /// needs special handling. /// /// See also [`ExecutionPlan::maintains_input_order`] and [`Self::output_ordering`] /// for related concepts. + /// + /// [`FilterExec`]: crate::filter::FilterExec fn equivalence_properties(&self) -> &EquivalenceProperties; } @@ -639,6 +845,49 @@ pub enum EmissionType { Both, } +/// Represents whether an operator's `Stream` has been implemented to actively cooperate with the +/// Tokio scheduler or not. Please refer to the [`coop`](crate::coop) module for more details. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SchedulingType { + /// The stream generated by [`execute`](ExecutionPlan::execute) does not actively participate in + /// cooperative scheduling. This means the implementation of the `Stream` returned by + /// [`ExecutionPlan::execute`] does not contain explicit task budget consumption such as + /// [`tokio::task::coop::consume_budget`]. + /// + /// `NonCooperative` is the default value and is acceptable for most operators. Please refer to + /// the [`coop`](crate::coop) module for details on when it may be useful to use + /// `Cooperative` instead. + NonCooperative, + /// The stream generated by [`execute`](ExecutionPlan::execute) actively participates in + /// cooperative scheduling by consuming task budget when it was able to produce a + /// [`RecordBatch`]. + Cooperative, +} + +/// Represents how an operator's `Stream` implementation generates `RecordBatch`es. +/// +/// Most operators in DataFusion generate `RecordBatch`es when asked to do so by a call to +/// `Stream::poll_next`. This is known as demand-driven or lazy evaluation. +/// +/// Some operators like `Repartition` need to drive `RecordBatch` generation themselves though. This +/// is known as data-driven or eager evaluation. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EvaluationType { + /// The stream generated by [`execute`](ExecutionPlan::execute) only generates `RecordBatch` + /// instances when it is demanded by invoking `Stream::poll_next`. + /// Filter, projection, and join are examples of such lazy operators. + /// + /// Lazy operators are also known as demand-driven operators. + Lazy, + /// The stream generated by [`execute`](ExecutionPlan::execute) eagerly generates `RecordBatch` + /// in one or more spawned Tokio tasks. Eager evaluation is only started the first time + /// `Stream::poll_next` is called. + /// Examples of eager operators are repartition, coalesce partitions, and sort preserving merge. + /// + /// Eager operators are also known as a data-driven operators. + Eager, +} + /// Utility to determine an operator's boundedness based on its children's boundedness. /// /// Assumes boundedness can be inferred from child operators: @@ -727,6 +976,8 @@ pub struct PlanProperties { pub emission_type: EmissionType, /// See [ExecutionPlanProperties::boundedness] pub boundedness: Boundedness, + pub evaluation_type: EvaluationType, + pub scheduling_type: SchedulingType, /// See [ExecutionPlanProperties::output_ordering] output_ordering: Option, } @@ -746,6 +997,8 @@ impl PlanProperties { partitioning, emission_type, boundedness, + evaluation_type: EvaluationType::Lazy, + scheduling_type: SchedulingType::NonCooperative, output_ordering, } } @@ -777,6 +1030,22 @@ impl PlanProperties { self } + /// Set the [`SchedulingType`]. + /// + /// Defaults to [`SchedulingType::NonCooperative`] + pub fn with_scheduling_type(mut self, scheduling_type: SchedulingType) -> Self { + self.scheduling_type = scheduling_type; + self + } + + /// Set the [`EvaluationType`]. + /// + /// Defaults to [`EvaluationType::Lazy`] + pub fn with_evaluation_type(mut self, drive_type: EvaluationType) -> Self { + self.evaluation_type = drive_type; + self + } + /// Overwrite constraints with its new value. pub fn with_constraints(mut self, constraints: Constraints) -> Self { self.eq_properties = self.eq_properties.with_constraints(constraints); @@ -801,32 +1070,45 @@ impl PlanProperties { } } +macro_rules! check_len { + ($target:expr, $func_name:ident, $expected_len:expr) => { + let actual_len = $target.$func_name().len(); + if actual_len != $expected_len { + return internal_err!( + "{}::{} returned Vec with incorrect size: {} != {}", + $target.name(), + stringify!($func_name), + actual_len, + $expected_len + ); + } + }; +} + +/// Checks a set of invariants that apply to all ExecutionPlan implementations. +/// Returns an error if the given node does not conform. +pub fn check_default_invariants( + plan: &P, + _check: InvariantLevel, +) -> Result<(), DataFusionError> { + let children_len = plan.children().len(); + + check_len!(plan, maintains_input_order, children_len); + check_len!(plan, required_input_ordering, children_len); + check_len!(plan, required_input_distribution, children_len); + check_len!(plan, benefits_from_input_partitioning, children_len); + + Ok(()) +} + /// Indicate whether a data exchange is needed for the input of `plan`, which will be very helpful /// especially for the distributed engine to judge whether need to deal with shuffling. -/// Currently there are 3 kinds of execution plan which needs data exchange +/// Currently, there are 3 kinds of execution plan which needs data exchange /// 1. RepartitionExec for changing the partition number between two `ExecutionPlan`s /// 2. CoalescePartitionsExec for collapsing all of the partitions into one without ordering guarantee /// 3. SortPreservingMergeExec for collapsing all of the sorted partitions into one with ordering guarantee pub fn need_data_exchange(plan: Arc) -> bool { - if let Some(repartition) = plan.as_any().downcast_ref::() { - !matches!( - repartition.properties().output_partitioning(), - Partitioning::RoundRobinBatch(_) - ) - } else if let Some(coalesce) = plan.as_any().downcast_ref::() - { - coalesce.input().output_partitioning().partition_count() > 1 - } else if let Some(sort_preserving_merge) = - plan.as_any().downcast_ref::() - { - sort_preserving_merge - .input() - .output_partitioning() - .partition_count() - > 1 - } else { - false - } + plan.properties().evaluation_type == EvaluationType::Eager } /// Returns a copy of this plan if we change any child according to the pointer comparison. @@ -1053,7 +1335,7 @@ pub fn check_not_null_constraints( pub fn get_plan_string(plan: &Arc) -> Vec { let formatted = displayable(plan.as_ref()).indent(true).to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); - actual.iter().map(|elem| elem.to_string()).collect() + actual.iter().map(|elem| (*elem).to_string()).collect() } /// Indicates the effect an execution plan operator will have on the cardinality @@ -1072,17 +1354,17 @@ pub enum CardinalityEffect { #[cfg(test)] mod tests { - use super::*; - use arrow::array::{DictionaryArray, Int32Array, NullArray, RunArray}; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use std::any::Any; use std::sync::Arc; + use super::*; + use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; + + use arrow::array::{DictionaryArray, Int32Array, NullArray, RunArray}; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{Result, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; - #[derive(Debug)] pub struct EmptyExec; @@ -1137,6 +1419,10 @@ mod tests { fn statistics(&self) -> Result { unimplemented!() } + + fn partition_statistics(&self, _partition: Option) -> Result { + unimplemented!() + } } #[derive(Debug)] @@ -1200,6 +1486,10 @@ mod tests { fn statistics(&self) -> Result { unimplemented!() } + + fn partition_statistics(&self, _partition: Option) -> Result { + unimplemented!() + } } #[test] diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index a8a9973ea0434..047c72076e4c6 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -20,15 +20,21 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{ready, Context, Poll}; +use itertools::Itertools; + use super::{ ColumnStatistics, DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; use crate::common::can_project; use crate::execution_plan::CardinalityEffect; +use crate::filter_pushdown::{ + ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, PushedDown, PushedDownPredicate, +}; use crate::projection::{ make_with_child, try_embed_projection, update_expr, EmbeddedProjection, - ProjectionExec, + ProjectionExec, ProjectionExpr, }; use crate::{ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, @@ -39,6 +45,7 @@ use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; +use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; use datafusion_common::{ internal_err, plan_err, project_schema, DataFusionError, Result, ScalarValue, @@ -46,18 +53,20 @@ use datafusion_common::{ use datafusion_execution::TaskContext; use datafusion_expr::Operator; use datafusion_physical_expr::equivalence::ProjectionMapping; -use datafusion_physical_expr::expressions::BinaryExpr; +use datafusion_physical_expr::expressions::{lit, BinaryExpr, Column}; use datafusion_physical_expr::intervals::utils::check_support; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - analyze, split_conjunction, AcrossPartitions, AnalysisContext, ConstExpr, - ExprBoundaries, PhysicalExpr, + analyze, conjunction, split_conjunction, AcrossPartitions, AnalysisContext, + ConstExpr, ExprBoundaries, PhysicalExpr, }; use datafusion_physical_expr_common::physical_expr::fmt_sql; use futures::stream::{Stream, StreamExt}; use log::trace; +const FILTER_EXEC_DEFAULT_SELECTIVITY: u8 = 20; + /// FilterExec evaluates a boolean predicate against all input batches to determine which rows to /// include in its output batches. #[derive(Debug, Clone)] @@ -84,7 +93,7 @@ impl FilterExec { ) -> Result { match predicate.data_type(input.schema().as_ref())? { DataType::Boolean => { - let default_selectivity = 20; + let default_selectivity = FILTER_EXEC_DEFAULT_SELECTIVITY; let cache = Self::compute_properties( &input, &predicate, @@ -170,12 +179,11 @@ impl FilterExec { /// Calculates `Statistics` for `FilterExec`, by applying selectivity (either default, or estimated) to input statistics. fn statistics_helper( - input: &Arc, + schema: SchemaRef, + input_stats: Statistics, predicate: &Arc, default_selectivity: u8, ) -> Result { - let input_stats = input.statistics()?; - let schema = input.schema(); if !check_support(predicate, &schema) { let selectivity = default_selectivity as f64 / 100.0; let mut stats = input_stats.to_inexact(); @@ -189,7 +197,7 @@ impl FilterExec { let num_rows = input_stats.num_rows; let total_byte_size = input_stats.total_byte_size; let input_analysis_ctx = AnalysisContext::try_from_statistics( - &input.schema(), + &schema, &input_stats.column_statistics, )?; @@ -223,24 +231,18 @@ impl FilterExec { if let Some(binary) = conjunction.as_any().downcast_ref::() { if binary.op() == &Operator::Eq { // Filter evaluates to single value for all partitions - if input_eqs.is_expr_constant(binary.left()) { - let (expr, across_parts) = ( - binary.right(), - input_eqs.get_expr_constant_value(binary.right()), - ); - res_constants.push( - ConstExpr::new(Arc::clone(expr)) - .with_across_partitions(across_parts), - ); - } else if input_eqs.is_expr_constant(binary.right()) { - let (expr, across_parts) = ( - binary.left(), - input_eqs.get_expr_constant_value(binary.left()), - ); - res_constants.push( - ConstExpr::new(Arc::clone(expr)) - .with_across_partitions(across_parts), - ); + if input_eqs.is_expr_constant(binary.left()).is_some() { + let across = input_eqs + .is_expr_constant(binary.right()) + .unwrap_or_default(); + res_constants + .push(ConstExpr::new(Arc::clone(binary.right()), across)); + } else if input_eqs.is_expr_constant(binary.right()).is_some() { + let across = input_eqs + .is_expr_constant(binary.left()) + .unwrap_or_default(); + res_constants + .push(ConstExpr::new(Arc::clone(binary.left()), across)); } } } @@ -256,11 +258,16 @@ impl FilterExec { ) -> Result { // Combine the equal predicates with the input equivalence properties // to construct the equivalence properties: - let stats = Self::statistics_helper(input, predicate, default_selectivity)?; + let stats = Self::statistics_helper( + input.schema(), + input.partition_statistics(None)?, + predicate, + default_selectivity, + )?; let mut eq_properties = input.equivalence_properties().clone(); - let (equal_pairs, _) = collect_columns_from_predicate(predicate); + let (equal_pairs, _) = collect_columns_from_predicate_inner(predicate); for (lhs, rhs) in equal_pairs { - eq_properties.add_equal_conditions(lhs, rhs)? + eq_properties.add_equal_conditions(Arc::clone(lhs), Arc::clone(rhs))? } // Add the columns that have only one viable value (singleton) after // filtering to constants. @@ -272,15 +279,13 @@ impl FilterExec { .min_value .get_value(); let expr = Arc::new(column) as _; - ConstExpr::new(expr) - .with_across_partitions(AcrossPartitions::Uniform(value.cloned())) + ConstExpr::new(expr, AcrossPartitions::Uniform(value.cloned())) }); // This is for statistics - eq_properties = eq_properties.with_constants(constants); + eq_properties.add_constants(constants)?; // This is for logical constant (for example: a = '1', then a could be marked as a constant) // to do: how to deal with multiple situation to represent = (for example c1 between 0 and 0) - eq_properties = - eq_properties.with_constants(Self::extend_constants(input, predicate)); + eq_properties.add_constants(Self::extend_constants(input, predicate))?; let mut output_partitioning = input.output_partitioning().clone(); // If contains projection, update the PlanProperties. @@ -396,8 +401,14 @@ impl ExecutionPlan for FilterExec { /// The output statistics of a filtering operation can be estimated if the /// predicate's selectivity value can be determined for the incoming data. fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + let input_stats = self.input.partition_statistics(partition)?; let stats = Self::statistics_helper( - &self.input, + self.schema(), + input_stats, self.predicate(), self.default_selectivity, )?; @@ -433,6 +444,119 @@ impl ExecutionPlan for FilterExec { } try_embed_projection(projection, self) } + + fn gather_filters_for_pushdown( + &self, + phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + if !matches!(phase, FilterPushdownPhase::Pre) { + // For non-pre phase, filters pass through unchanged + let filter_supports = parent_filters + .into_iter() + .map(PushedDownPredicate::supported) + .collect(); + return Ok(FilterDescription::new().with_child(ChildFilterDescription { + parent_filters: filter_supports, + self_filters: vec![], + })); + } + + let child = ChildFilterDescription::from_child(&parent_filters, self.input())? + .with_self_filters( + split_conjunction(&self.predicate) + .into_iter() + .cloned() + .collect(), + ); + + Ok(FilterDescription::new().with_child(child)) + } + + fn handle_child_pushdown_result( + &self, + phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + if !matches!(phase, FilterPushdownPhase::Pre) { + return Ok(FilterPushdownPropagation::if_all(child_pushdown_result)); + } + // We absorb any parent filters that were not handled by our children + let unsupported_parent_filters = + child_pushdown_result.parent_filters.iter().filter_map(|f| { + matches!(f.all(), PushedDown::No).then_some(Arc::clone(&f.filter)) + }); + let unsupported_self_filters = child_pushdown_result + .self_filters + .first() + .expect("we have exactly one child") + .iter() + .filter_map(|f| match f.discriminant { + PushedDown::Yes => None, + PushedDown::No => Some(&f.predicate), + }) + .cloned(); + + let unhandled_filters = unsupported_parent_filters + .into_iter() + .chain(unsupported_self_filters) + .collect_vec(); + + // If we have unhandled filters, we need to create a new FilterExec + let filter_input = Arc::clone(self.input()); + let new_predicate = conjunction(unhandled_filters); + let updated_node = if new_predicate.eq(&lit(true)) { + // FilterExec is no longer needed, but we may need to leave a projection in place + match self.projection() { + Some(projection_indices) => { + let filter_child_schema = filter_input.schema(); + let proj_exprs = projection_indices + .iter() + .map(|p| { + let field = filter_child_schema.field(*p).clone(); + ProjectionExpr { + expr: Arc::new(Column::new(field.name(), *p)) + as Arc, + alias: field.name().to_string(), + } + }) + .collect::>(); + Some(Arc::new(ProjectionExec::try_new(proj_exprs, filter_input)?) + as Arc) + } + None => { + // No projection needed, just return the input + Some(filter_input) + } + } + } else if new_predicate.eq(&self.predicate) { + // The new predicate is the same as our current predicate + None + } else { + // Create a new FilterExec with the new predicate + let new = FilterExec { + predicate: Arc::clone(&new_predicate), + input: Arc::clone(&filter_input), + metrics: self.metrics.clone(), + default_selectivity: self.default_selectivity, + cache: Self::compute_properties( + &filter_input, + &new_predicate, + self.default_selectivity, + self.projection.as_ref(), + )?, + projection: None, + }; + Some(Arc::new(new) as _) + }; + + Ok(FilterPushdownPropagation { + filters: vec![PushedDown::Yes; child_pushdown_result.parent_filters.len()], + updated_node, + }) + } } impl EmbeddedProjection for FilterExec { @@ -592,7 +716,19 @@ impl RecordBatchStream for FilterExecStream { } /// Return the equals Column-Pairs and Non-equals Column-Pairs -fn collect_columns_from_predicate(predicate: &Arc) -> EqualAndNonEqual { +#[deprecated( + since = "51.0.0", + note = "This function will be internal in the future" +)] +pub fn collect_columns_from_predicate( + predicate: &'_ Arc, +) -> EqualAndNonEqual<'_> { + collect_columns_from_predicate_inner(predicate) +} + +fn collect_columns_from_predicate_inner( + predicate: &'_ Arc, +) -> EqualAndNonEqual<'_> { let mut eq_predicate_columns = Vec::::new(); let mut ne_predicate_columns = Vec::::new(); @@ -661,7 +797,7 @@ mod tests { &schema, )?; - let (equal_pairs, ne_pairs) = collect_columns_from_predicate(&predicate); + let (equal_pairs, ne_pairs) = collect_columns_from_predicate_inner(&predicate); assert_eq!(2, equal_pairs.len()); assert!(equal_pairs[0].0.eq(&col("c2", &schema)?)); assert!(equal_pairs[0].1.eq(&lit(4u32))); @@ -703,7 +839,7 @@ mod tests { let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(25)); assert_eq!( statistics.total_byte_size, @@ -753,7 +889,7 @@ mod tests { sub_filter, )?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(16)); assert_eq!( statistics.column_statistics, @@ -813,7 +949,7 @@ mod tests { binary(col("a", &schema)?, Operator::GtEq, lit(10i32), &schema)?, b_gt_5, )?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; // On a uniform distribution, only fifteen rows will satisfy the // filter that 'a' proposed (a >= 10 AND a <= 25) (15/100) and only // 5 rows will satisfy the filter that 'b' proposed (b > 45) (5/50). @@ -858,7 +994,7 @@ mod tests { let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Absent); Ok(()) @@ -931,7 +1067,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; // 0.5 (from a) * 0.333333... (from b) * 0.798387... (from c) ≈ 0.1330... // num_rows after ceil => 133.0... => 134 // total_byte_size after ceil => 532.0... => 533 @@ -1027,10 +1163,10 @@ mod tests { )), )); // Since filter predicate passes all entries, statistics after filter shouldn't change. - let expected = input.statistics()?.column_statistics; + let expected = input.partition_statistics(None)?.column_statistics; let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(1000)); assert_eq!(statistics.total_byte_size, Precision::Inexact(4000)); @@ -1083,7 +1219,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(0)); assert_eq!(statistics.total_byte_size, Precision::Inexact(0)); @@ -1143,7 +1279,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(490)); assert_eq!(statistics.total_byte_size, Precision::Inexact(1960)); @@ -1193,7 +1329,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let filter_statistics = filter.statistics()?; + let filter_statistics = filter.partition_statistics(None)?; let expected_filter_statistics = Statistics { num_rows: Precision::Absent, @@ -1227,7 +1363,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let filter_statistics = filter.statistics()?; + let filter_statistics = filter.partition_statistics(None)?; // First column is "a", and it is a column with only one value after the filter. assert!(filter_statistics.column_statistics[0].is_singleton()); @@ -1274,11 +1410,11 @@ mod tests { Arc::new(Literal::new(ScalarValue::Decimal128(Some(10), 10, 10))), )); let filter = FilterExec::try_new(predicate, input)?; - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(200)); assert_eq!(statistics.total_byte_size, Precision::Inexact(800)); let filter = filter.with_default_selectivity(40)?; - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(400)); assert_eq!(statistics.total_byte_size, Precision::Inexact(1600)); Ok(()) @@ -1312,7 +1448,7 @@ mod tests { Arc::new(EmptyExec::new(Arc::clone(&schema))), )?; - exec.statistics().unwrap(); + exec.partition_statistics(None).unwrap(); Ok(()) } diff --git a/datafusion/physical-plan/src/filter_pushdown.rs b/datafusion/physical-plan/src/filter_pushdown.rs new file mode 100644 index 0000000000000..f6b1b7448f885 --- /dev/null +++ b/datafusion/physical-plan/src/filter_pushdown.rs @@ -0,0 +1,464 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Filter Pushdown Optimization Process +//! +//! The filter pushdown mechanism involves four key steps: +//! 1. **Optimizer Asks Parent for a Filter Pushdown Plan**: The optimizer calls [`ExecutionPlan::gather_filters_for_pushdown`] +//! on the parent node, passing in parent predicates and phase. The parent node creates a [`FilterDescription`] +//! by inspecting its logic and children's schemas, determining which filters can be pushed to each child. +//! 2. **Optimizer Executes Pushdown**: The optimizer recursively pushes down filters for each child, +//! passing the appropriate filters (`Vec>`) for that child. +//! 3. **Optimizer Gathers Results**: The optimizer collects [`FilterPushdownPropagation`] results from children, +//! containing information about which filters were successfully pushed down vs. unsupported. +//! 4. **Parent Responds**: The optimizer calls [`ExecutionPlan::handle_child_pushdown_result`] on the parent, +//! passing a [`ChildPushdownResult`] containing the aggregated pushdown outcomes. The parent decides +//! how to handle filters that couldn't be pushed down (e.g., keep them as FilterExec nodes). +//! +//! [`ExecutionPlan::gather_filters_for_pushdown`]: crate::ExecutionPlan::gather_filters_for_pushdown +//! [`ExecutionPlan::handle_child_pushdown_result`]: crate::ExecutionPlan::handle_child_pushdown_result +//! +//! See also datafusion/physical-optimizer/src/filter_pushdown.rs. + +use std::collections::HashSet; +use std::sync::Arc; + +use datafusion_common::Result; +use datafusion_physical_expr::utils::{collect_columns, reassign_expr_columns}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use itertools::Itertools; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FilterPushdownPhase { + /// Pushdown that happens before most other optimizations. + /// This pushdown allows static filters that do not reference any [`ExecutionPlan`]s to be pushed down. + /// Filters that reference an [`ExecutionPlan`] cannot be pushed down at this stage since the whole plan tree may be rewritten + /// by other optimizations. + /// Implementers are however allowed to modify the execution plan themselves during this phase, for example by returning a completely + /// different [`ExecutionPlan`] from [`ExecutionPlan::handle_child_pushdown_result`]. + /// + /// Pushdown of [`FilterExec`] into `DataSourceExec` is an example of a pre-pushdown. + /// Unlike filter pushdown in the logical phase, which operates on the logical plan to push filters into the logical table scan, + /// the `Pre` phase in the physical plan targets the actual physical scan, pushing filters down to specific data source implementations. + /// For example, Parquet supports filter pushdown to reduce data read during scanning, while CSV typically does not. + /// + /// [`ExecutionPlan`]: crate::ExecutionPlan + /// [`FilterExec`]: crate::filter::FilterExec + /// [`ExecutionPlan::handle_child_pushdown_result`]: crate::ExecutionPlan::handle_child_pushdown_result + Pre, + /// Pushdown that happens after most other optimizations. + /// This stage of filter pushdown allows filters that reference an [`ExecutionPlan`] to be pushed down. + /// Since subsequent optimizations should not change the structure of the plan tree except for calling [`ExecutionPlan::with_new_children`] + /// (which generally preserves internal references) it is safe for references between [`ExecutionPlan`]s to be established at this stage. + /// + /// This phase is used to link a [`SortExec`] (with a TopK operator) or a [`HashJoinExec`] to a `DataSourceExec`. + /// + /// [`ExecutionPlan`]: crate::ExecutionPlan + /// [`ExecutionPlan::with_new_children`]: crate::ExecutionPlan::with_new_children + /// [`SortExec`]: crate::sorts::sort::SortExec + /// [`HashJoinExec`]: crate::joins::HashJoinExec + /// [`ExecutionPlan::handle_child_pushdown_result`]: crate::ExecutionPlan::handle_child_pushdown_result + Post, +} + +impl std::fmt::Display for FilterPushdownPhase { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FilterPushdownPhase::Pre => write!(f, "Pre"), + FilterPushdownPhase::Post => write!(f, "Post"), + } + } +} + +/// The result of a plan for pushing down a filter into a child node. +/// This contains references to filters so that nodes can mutate a filter +/// before pushing it down to a child node (e.g. to adjust a projection) +/// or can directly take ownership of filters that their children +/// could not handle. +#[derive(Debug, Clone)] +pub struct PushedDownPredicate { + pub discriminant: PushedDown, + pub predicate: Arc, +} + +impl PushedDownPredicate { + /// Return the wrapped [`PhysicalExpr`], discarding whether it is supported or unsupported. + pub fn into_inner(self) -> Arc { + self.predicate + } + + /// Create a new [`PushedDownPredicate`] with supported pushdown. + pub fn supported(predicate: Arc) -> Self { + Self { + discriminant: PushedDown::Yes, + predicate, + } + } + + /// Create a new [`PushedDownPredicate`] with unsupported pushdown. + pub fn unsupported(predicate: Arc) -> Self { + Self { + discriminant: PushedDown::No, + predicate, + } + } +} + +/// Discriminant for the result of pushing down a filter into a child node. +#[derive(Debug, Clone, Copy)] +pub enum PushedDown { + /// The predicate was successfully pushed down into the child node. + Yes, + /// The predicate could not be pushed down into the child node. + No, +} + +impl PushedDown { + /// Logical AND operation: returns `Yes` only if both operands are `Yes`. + pub fn and(self, other: PushedDown) -> PushedDown { + match (self, other) { + (PushedDown::Yes, PushedDown::Yes) => PushedDown::Yes, + _ => PushedDown::No, + } + } + + /// Logical OR operation: returns `Yes` if either operand is `Yes`. + pub fn or(self, other: PushedDown) -> PushedDown { + match (self, other) { + (PushedDown::Yes, _) | (_, PushedDown::Yes) => PushedDown::Yes, + (PushedDown::No, PushedDown::No) => PushedDown::No, + } + } + + /// Wrap a [`PhysicalExpr`] with this pushdown result. + pub fn wrap_expression(self, expr: Arc) -> PushedDownPredicate { + PushedDownPredicate { + discriminant: self, + predicate: expr, + } + } +} + +/// The result of pushing down a single parent filter into all children. +#[derive(Debug, Clone)] +pub struct ChildFilterPushdownResult { + pub filter: Arc, + pub child_results: Vec, +} + +impl ChildFilterPushdownResult { + /// Combine all child results using OR logic. + /// Returns `Yes` if **any** child supports the filter. + /// Returns `No` if **all** children reject the filter or if there are no children. + pub fn any(&self) -> PushedDown { + if self.child_results.is_empty() { + // If there are no children, filters cannot be supported + PushedDown::No + } else { + self.child_results + .iter() + .fold(PushedDown::No, |acc, result| acc.or(*result)) + } + } + + /// Combine all child results using AND logic. + /// Returns `Yes` if **all** children support the filter. + /// Returns `No` if **any** child rejects the filter or if there are no children. + pub fn all(&self) -> PushedDown { + if self.child_results.is_empty() { + // If there are no children, filters cannot be supported + PushedDown::No + } else { + self.child_results + .iter() + .fold(PushedDown::Yes, |acc, result| acc.and(*result)) + } + } +} + +/// The result of pushing down filters into a child node. +/// +/// This is the result provided to nodes in [`ExecutionPlan::handle_child_pushdown_result`]. +/// Nodes process this result and convert it into a [`FilterPushdownPropagation`] +/// that is returned to their parent. +/// +/// [`ExecutionPlan::handle_child_pushdown_result`]: crate::ExecutionPlan::handle_child_pushdown_result +#[derive(Debug, Clone)] +pub struct ChildPushdownResult { + /// The parent filters that were pushed down as received by the current node when [`ExecutionPlan::gather_filters_for_pushdown`](crate::ExecutionPlan::handle_child_pushdown_result) was called. + /// Note that this may *not* be the same as the filters that were passed to the children as the current node may have modified them + /// (e.g. by reassigning column indices) when it returned them from [`ExecutionPlan::gather_filters_for_pushdown`](crate::ExecutionPlan::handle_child_pushdown_result) in a [`FilterDescription`]. + /// Attached to each filter is a [`PushedDown`] *per child* that indicates whether the filter was supported or unsupported by each child. + /// To get combined results see [`ChildFilterPushdownResult::any`] and [`ChildFilterPushdownResult::all`]. + pub parent_filters: Vec, + /// The result of pushing down each filter this node provided into each of it's children. + /// The outer vector corresponds to each child, and the inner vector corresponds to each filter. + /// Since this node may have generated a different filter for each child the inner vector may have different lengths or the expressions may not match at all. + /// It is up to each node to interpret this result based on the filters it provided for each child in [`ExecutionPlan::gather_filters_for_pushdown`](crate::ExecutionPlan::handle_child_pushdown_result). + pub self_filters: Vec>, +} + +/// The result of pushing down filters into a node. +/// +/// Returned from [`ExecutionPlan::handle_child_pushdown_result`] to communicate +/// to the optimizer: +/// +/// 1. What to do with any parent filters that were could not be pushed down into the children. +/// 2. If the node needs to be replaced in the execution plan with a new node or not. +/// +/// [`ExecutionPlan::handle_child_pushdown_result`]: crate::ExecutionPlan::handle_child_pushdown_result +#[derive(Debug, Clone)] +pub struct FilterPushdownPropagation { + /// What filters were pushed into the parent node. + pub filters: Vec, + /// The updated node, if it was updated during pushdown + pub updated_node: Option, +} + +impl FilterPushdownPropagation { + /// Create a new [`FilterPushdownPropagation`] that tells the parent node that each parent filter + /// is supported if it was supported by *all* children. + pub fn if_all(child_pushdown_result: ChildPushdownResult) -> Self { + let filters = child_pushdown_result + .parent_filters + .into_iter() + .map(|result| result.all()) + .collect(); + Self { + filters, + updated_node: None, + } + } + + /// Create a new [`FilterPushdownPropagation`] that tells the parent node that each parent filter + /// is supported if it was supported by *any* child. + pub fn if_any(child_pushdown_result: ChildPushdownResult) -> Self { + let filters = child_pushdown_result + .parent_filters + .into_iter() + .map(|result| result.any()) + .collect(); + Self { + filters, + updated_node: None, + } + } + + /// Create a new [`FilterPushdownPropagation`] that tells the parent node that no filters were pushed down regardless of the child results. + pub fn all_unsupported(child_pushdown_result: ChildPushdownResult) -> Self { + let filters = child_pushdown_result + .parent_filters + .into_iter() + .map(|_| PushedDown::No) + .collect(); + Self { + filters, + updated_node: None, + } + } + + /// Create a new [`FilterPushdownPropagation`] with the specified filter support. + /// This transmits up to our parent node what the result of pushing down the filters into our node and possibly our subtree was. + pub fn with_parent_pushdown_result(filters: Vec) -> Self { + Self { + filters, + updated_node: None, + } + } + + /// Bind an updated node to the [`FilterPushdownPropagation`]. + /// Use this when the current node wants to update itself in the tree or replace itself with a new node (e.g. one of it's children). + /// You do not need to call this if one of the children of the current node may have updated itself, that is handled by the optimizer. + pub fn with_updated_node(mut self, updated_node: T) -> Self { + self.updated_node = Some(updated_node); + self + } +} + +/// Describes filter pushdown for a single child node. +/// +/// This structure contains two types of filters: +/// - **Parent filters**: Filters received from the parent node, marked as supported or unsupported +/// - **Self filters**: Filters generated by the current node to be pushed down to this child +#[derive(Debug, Clone)] +pub struct ChildFilterDescription { + /// Description of which parent filters can be pushed down into this node. + /// Since we need to transmit filter pushdown results back to this node's parent + /// we need to track each parent filter for each child, even those that are unsupported / won't be pushed down. + pub(crate) parent_filters: Vec, + /// Description of which filters this node is pushing down to its children. + /// Since this is not transmitted back to the parents we can have variable sized inner arrays + /// instead of having to track supported/unsupported. + pub(crate) self_filters: Vec>, +} + +impl ChildFilterDescription { + /// Build a child filter description by analyzing which parent filters can be pushed to a specific child. + /// + /// This method performs column analysis to determine which filters can be pushed down: + /// - If all columns referenced by a filter exist in the child's schema, it can be pushed down + /// - Otherwise, it cannot be pushed down to that child + /// + /// See [`FilterDescription::from_children`] for more details + pub fn from_child( + parent_filters: &[Arc], + child: &Arc, + ) -> Result { + let child_schema = child.schema(); + + // Get column names from child schema for quick lookup + let child_column_names: HashSet<&str> = child_schema + .fields() + .iter() + .map(|f| f.name().as_str()) + .collect(); + + // Analyze each parent filter + let mut child_parent_filters = Vec::with_capacity(parent_filters.len()); + + for filter in parent_filters { + // Check which columns the filter references + let referenced_columns = collect_columns(filter); + + // Check if all referenced columns exist in the child schema + let all_columns_exist = referenced_columns + .iter() + .all(|col| child_column_names.contains(col.name())); + + if all_columns_exist { + // All columns exist in child - we can push down + // Need to reassign column indices to match child schema + let reassigned_filter = + reassign_expr_columns(Arc::clone(filter), &child_schema)?; + child_parent_filters + .push(PushedDownPredicate::supported(reassigned_filter)); + } else { + // Some columns don't exist in child - cannot push down + child_parent_filters + .push(PushedDownPredicate::unsupported(Arc::clone(filter))); + } + } + + Ok(Self { + parent_filters: child_parent_filters, + self_filters: vec![], + }) + } + + /// Add a self filter (from the current node) to be pushed down to this child. + pub fn with_self_filter(mut self, filter: Arc) -> Self { + self.self_filters.push(filter); + self + } + + /// Add multiple self filters. + pub fn with_self_filters(mut self, filters: Vec>) -> Self { + self.self_filters.extend(filters); + self + } +} + +/// Describes how filters should be pushed down to children. +/// +/// This structure contains filter descriptions for each child node, specifying: +/// - Which parent filters can be pushed down to each child +/// - Which self-generated filters should be pushed down to each child +/// +/// The filter routing is determined by column analysis - filters can only be pushed +/// to children whose schemas contain all the referenced columns. +#[derive(Debug, Clone)] +pub struct FilterDescription { + /// A filter description for each child. + /// This includes which parent filters and which self filters (from the node in question) + /// will get pushed down to each child. + child_filter_descriptions: Vec, +} + +impl Default for FilterDescription { + fn default() -> Self { + Self::new() + } +} + +impl FilterDescription { + /// Create a new empty FilterDescription + pub fn new() -> Self { + Self { + child_filter_descriptions: vec![], + } + } + + /// Add a child filter description + pub fn with_child(mut self, child: ChildFilterDescription) -> Self { + self.child_filter_descriptions.push(child); + self + } + + /// Build a filter description by analyzing which parent filters can be pushed to each child. + /// This method automatically determines filter routing based on column analysis: + /// - If all columns referenced by a filter exist in a child's schema, it can be pushed down + /// - Otherwise, it cannot be pushed down to that child + pub fn from_children( + parent_filters: Vec>, + children: &[&Arc], + ) -> Result { + let mut desc = Self::new(); + + // For each child, create a ChildFilterDescription + for child in children { + desc = desc + .with_child(ChildFilterDescription::from_child(&parent_filters, child)?); + } + + Ok(desc) + } + + /// Mark all parent filters as unsupported for all children. + pub fn all_unsupported( + parent_filters: &[Arc], + children: &[&Arc], + ) -> Self { + let mut desc = Self::new(); + let child_filters = parent_filters + .iter() + .map(|f| PushedDownPredicate::unsupported(Arc::clone(f))) + .collect_vec(); + for _ in 0..children.len() { + desc = desc.with_child(ChildFilterDescription { + parent_filters: child_filters.clone(), + self_filters: vec![], + }); + } + desc + } + + pub fn parent_filters(&self) -> Vec> { + self.child_filter_descriptions + .iter() + .map(|d| &d.parent_filters) + .cloned() + .collect() + } + + pub fn self_filters(&self) -> Vec>> { + self.child_filter_descriptions + .iter() + .map(|d| &d.self_filters) + .cloned() + .collect() + } +} diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 639fae7615af0..949c4e784bc3e 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -25,7 +25,6 @@ use super::utils::{ BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut, StatefulStreamResult, }; -use crate::coalesce_partitions::CoalescePartitionsExec; use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::projection::{ @@ -116,7 +115,7 @@ impl CrossJoinExec { }; let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata)); - let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)); + let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)).unwrap(); CrossJoinExec { left, @@ -143,7 +142,7 @@ impl CrossJoinExec { left: &Arc, right: &Arc, schema: SchemaRef, - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties // TODO: Check equivalence properties of cross join, it may preserve // ordering in some cases. @@ -155,7 +154,7 @@ impl CrossJoinExec { &[false, false], None, &[], - ); + )?; // Get output partitioning: // TODO: Optimize the cross join implementation to generate M * N @@ -163,19 +162,25 @@ impl CrossJoinExec { let output_partitioning = adjust_right_output_partitioning( right.output_partitioning(), left.schema().fields.len(), - ); + )?; - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, EmissionType::Final, boundedness_from_children([left, right]), - ) + )) } /// Returns a new `ExecutionPlan` that computes the same join as this one, /// with the left and right inputs swapped using the specified /// `partition_mode`. + /// + /// # Notes: + /// + /// This function should be called BEFORE inserting any repartitioning + /// operators on the join's children. Check [`super::HashJoinExec::swap_inputs`] + /// for more details. pub fn swap_inputs(&self) -> Result> { let new_join = CrossJoinExec::new(Arc::clone(&self.right), Arc::clone(&self.left)); @@ -189,19 +194,11 @@ impl CrossJoinExec { /// Asynchronously collect the result of the left child async fn load_left_input( - left: Arc, - context: Arc, + stream: SendableRecordBatchStream, metrics: BuildProbeJoinMetrics, reservation: MemoryReservation, ) -> Result { - // merge all left parts into a single stream - let left_schema = left.schema(); - let merge = if left.output_partitioning().partition_count() != 1 { - Arc::new(CoalescePartitionsExec::new(left)) - } else { - left - }; - let stream = merge.execute(0, context)?; + let left_schema = stream.schema(); // Load all batches and count the rows let (batches, _metrics, reservation) = stream @@ -279,6 +276,18 @@ impl ExecutionPlan for CrossJoinExec { ))) } + fn reset_state(self: Arc) -> Result> { + let new_exec = CrossJoinExec { + left: Arc::clone(&self.left), + right: Arc::clone(&self.right), + schema: Arc::clone(&self.schema), + left_fut: Default::default(), // reset the build side! + metrics: ExecutionPlanMetricsSet::default(), + cache: self.cache.clone(), + }; + Ok(Arc::new(new_exec)) + } + fn required_input_distribution(&self) -> Vec { vec![ Distribution::SinglePartition, @@ -291,6 +300,13 @@ impl ExecutionPlan for CrossJoinExec { partition: usize, context: Arc, ) -> Result { + if self.left.output_partitioning().partition_count() != 1 { + return internal_err!( + "Invalid CrossJoinExec, the output partition count of the left child must be 1,\ + consider using CoalescePartitionsExec or the EnforceDistribution rule" + ); + } + let stream = self.right.execute(partition, Arc::clone(&context))?; let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); @@ -303,14 +319,15 @@ impl ExecutionPlan for CrossJoinExec { let enforce_batch_size_in_joins = context.session_config().enforce_batch_size_in_joins(); - let left_fut = self.left_fut.once(|| { - load_left_input( - Arc::clone(&self.left), - context, + let left_fut = self.left_fut.try_once(|| { + let left_stream = self.left.execute(0, context)?; + + Ok(load_left_input( + left_stream, join_metrics.clone(), reservation, - ) - }); + )) + })?; if enforce_batch_size_in_joins { Ok(Box::pin(CrossJoinStream { @@ -338,10 +355,15 @@ impl ExecutionPlan for CrossJoinExec { } fn statistics(&self) -> Result { - Ok(stats_cartesian_product( - self.left.statistics()?, - self.right.statistics()?, - )) + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + // Get the all partitions statistics of the left + let left_stats = self.left.partition_statistics(None)?; + let right_stats = self.right.partition_statistics(partition)?; + + Ok(stats_cartesian_product(left_stats, right_stats)) } /// Tries to swap the projection with its input [`CrossJoinExec`]. If it can be done, @@ -555,7 +577,8 @@ impl CrossJoinStream { handle_state!(ready!(self.fetch_probe_batch(cx))) } CrossJoinStreamState::BuildBatches(_) => { - handle_state!(self.build_batches()) + let poll = handle_state!(self.build_batches()); + self.join_metrics.baseline.record_poll(poll) } }; } @@ -628,7 +651,6 @@ impl CrossJoinStream { } self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); return Ok(StatefulStreamResult::Ready(Some(batch))); } } @@ -643,7 +665,7 @@ impl CrossJoinStream { mod tests { use super::*; use crate::common; - use crate::test::build_table_scan_i32; + use crate::test::{assert_join_metrics, build_table_scan_i32}; use datafusion_common::{assert_contains, test_util::batches_to_sort_string}; use datafusion_execution::runtime_env::RuntimeEnvBuilder; @@ -653,14 +675,15 @@ mod tests { left: Arc, right: Arc, context: Arc, - ) -> Result<(Vec, Vec)> { + ) -> Result<(Vec, Vec, MetricsSet)> { let join = CrossJoinExec::new(left, right); let columns_header = columns(&join.schema()); let stream = join.execute(0, context)?; let batches = common::collect(stream).await?; + let metrics = join.metrics().unwrap(); - Ok((columns_header, batches)) + Ok((columns_header, batches, metrics)) } #[tokio::test] @@ -827,7 +850,7 @@ mod tests { ("c2", &vec![14, 15]), ); - let (columns, batches) = join_collect(left, right, task_ctx).await?; + let (columns, batches, metrics) = join_collect(left, right, task_ctx).await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -844,6 +867,8 @@ mod tests { +----+----+----+----+----+----+ "#); + assert_join_metrics!(metrics, 6); + Ok(()) } @@ -870,7 +895,7 @@ mod tests { assert_contains!( err.to_string(), - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: CrossJoinExec" + "Resources exhausted: Additional allocation failed for CrossJoinExec with top memory consumers (across reservations) as:\n CrossJoinExec" ); Ok(()) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs similarity index 75% rename from datafusion/physical-plan/src/joins/hash_join.rs rename to datafusion/physical-plan/src/joins/hash_join/exec.rs index c2a313edd1564..4c293b0498e77 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -15,24 +15,27 @@ // specific language governing permissions and limitations // under the License. -//! [`HashJoinExec`] Partitioned Hash Join Operator - use std::fmt; use std::mem::size_of; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::task::Poll; +use std::sync::{Arc, OnceLock}; use std::{any::Any, vec}; -use super::utils::{ - asymmetric_join_output_partitioning, get_final_indices_from_shared_bitmap, - reorder_output_after_swap, swap_join_projection, +use crate::execution_plan::{boundedness_from_children, EmissionType}; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, }; -use super::{ - utils::{OnceAsync, OnceFut}, - PartitionMode, SharedBitmapBuilder, +use crate::joins::hash_join::shared_bounds::{ColumnBounds, SharedBoundsAccumulator}; +use crate::joins::hash_join::stream::{ + BuildSide, BuildSideInitialState, HashJoinStream, HashJoinStreamState, }; -use crate::execution_plan::{boundedness_from_children, EmissionType}; +use crate::joins::join_hash_map::{JoinHashMapU32, JoinHashMapU64}; +use crate::joins::utils::{ + asymmetric_join_output_partitioning, reorder_output_after_swap, swap_join_projection, + update_hash, OnceAsync, OnceFut, +}; +use crate::joins::{JoinOn, JoinOnRef, PartitionMode, SharedBitmapBuilder}; use crate::projection::{ try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, ProjectionExec, @@ -40,57 +43,51 @@ use crate::projection::{ use crate::spill::get_record_batch_memory_size; use crate::ExecutionPlanProperties; use crate::{ - coalesce_partitions::CoalescePartitionsExec, common::can_project, - handle_state, - hash_utils::create_hashes, - joins::join_hash_map::JoinHashMapOffset, joins::utils::{ - adjust_indices_by_join_type, apply_join_filter_to_indices, - build_batch_from_indices, build_join_schema, check_join_is_valid, - estimate_join_statistics, need_produce_result_in_final, - symmetric_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex, - JoinFilter, JoinHashMap, JoinHashMapType, JoinOn, JoinOnRef, - StatefulStreamResult, + build_join_schema, check_join_is_valid, estimate_join_statistics, + need_produce_result_in_final, symmetric_join_output_partitioning, + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMapType, }, metrics::{ExecutionPlanMetricsSet, MetricsSet}, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, - PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, + PlanProperties, SendableRecordBatchStream, Statistics, }; -use arrow::array::{ - cast::downcast_array, Array, ArrayRef, BooleanArray, BooleanBufferBuilder, - UInt32Array, UInt64Array, -}; -use arrow::compute::kernels::cmp::{eq, not_distinct}; -use arrow::compute::{and, concat_batches, take, FilterBuilder}; -use arrow::datatypes::{Schema, SchemaRef}; -use arrow::error::ArrowError; +use arrow::array::{ArrayRef, BooleanBufferBuilder}; +use arrow::compute::concat_batches; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; +use arrow_schema::DataType; +use datafusion_common::config::ConfigOptions; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ - internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError, - JoinSide, JoinType, Result, + internal_err, plan_err, project_schema, JoinSide, JoinType, NullEquality, Result, }; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; -use datafusion_expr::Operator; +use datafusion_expr::Accumulator; +use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumulator}; use datafusion_physical_expr::equivalence::{ join_equivalence_properties, ProjectionMapping, }; -use datafusion_physical_expr::PhysicalExprRef; -use datafusion_physical_expr_common::datum::compare_op_for_nested; +use datafusion_physical_expr::expressions::{lit, DynamicFilterPhysicalExpr}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; use ahash::RandomState; use datafusion_physical_expr_common::physical_expr::fmt_sql; -use futures::{ready, Stream, StreamExt, TryStreamExt}; +use futures::TryStreamExt; use parking_lot::Mutex; +/// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions. +const HASH_JOIN_SEED: RandomState = + RandomState::with_seeds('J' as u64, 'O' as u64, 'I' as u64, 'N' as u64); + /// HashTable and input data for the left (build side) of a join -struct JoinLeftData { +pub(super) struct JoinLeftData { /// The hash table with indices into `batch` - hash_map: JoinHashMap, + pub(super) hash_map: Box, /// The input rows for the build side batch: RecordBatch, /// The build side on expressions values @@ -105,17 +102,20 @@ struct JoinLeftData { /// This could hide potential out-of-memory issues, especially when upstream operators increase their memory consumption. /// The MemoryReservation ensures proper tracking of memory resources throughout the join operation's lifecycle. _reservation: MemoryReservation, + /// Bounds computed from the build side for dynamic filter pushdown + pub(super) bounds: Option>, } impl JoinLeftData { /// Create a new `JoinLeftData` from its parts - fn new( - hash_map: JoinHashMap, + pub(super) fn new( + hash_map: Box, batch: RecordBatch, values: Vec, visited_indices_bitmap: SharedBitmapBuilder, probe_threads_counter: AtomicUsize, reservation: MemoryReservation, + bounds: Option>, ) -> Self { Self { hash_map, @@ -124,32 +124,33 @@ impl JoinLeftData { visited_indices_bitmap, probe_threads_counter, _reservation: reservation, + bounds, } } /// return a reference to the hash map - fn hash_map(&self) -> &JoinHashMap { - &self.hash_map + pub(super) fn hash_map(&self) -> &dyn JoinHashMapType { + &*self.hash_map } /// returns a reference to the build side batch - fn batch(&self) -> &RecordBatch { + pub(super) fn batch(&self) -> &RecordBatch { &self.batch } /// returns a reference to the build side expressions values - fn values(&self) -> &[ArrayRef] { + pub(super) fn values(&self) -> &[ArrayRef] { &self.values } /// returns a reference to the visited indices bitmap - fn visited_indices_bitmap(&self) -> &SharedBitmapBuilder { + pub(super) fn visited_indices_bitmap(&self) -> &SharedBitmapBuilder { &self.visited_indices_bitmap } /// Decrements the counter of running threads, and returns `true` /// if caller is the last running thread - fn report_probe_completed(&self) -> bool { + pub(super) fn report_probe_completed(&self) -> bool { self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1 } } @@ -318,7 +319,6 @@ impl JoinLeftData { /// Note this structure includes a [`OnceAsync`] that is used to coordinate the /// loading of the left side with the processing in each output stream. /// Therefore it can not be [`Clone`] -#[derive(Debug)] pub struct HashJoinExec { /// left (build) side which gets hashed pub left: Arc, @@ -339,7 +339,7 @@ pub struct HashJoinExec { /// /// Each output stream waits on the `OnceAsync` to signal the completion of /// the hash table creation. - left_fut: OnceAsync, + left_fut: Arc>, /// Shared the `RandomState` for the hashing algorithm random_state: RandomState, /// Partitioning mode to use @@ -350,13 +350,51 @@ pub struct HashJoinExec { pub projection: Option>, /// Information of index and left / right placement of columns column_indices: Vec, - /// Null matching behavior: If `null_equals_null` is true, rows that have - /// `null`s in both left and right equijoin columns will be matched. - /// Otherwise, rows that have `null`s in the join columns will not be - /// matched and thus will not appear in the output. - pub null_equals_null: bool, + /// The equality null-handling behavior of the join algorithm. + pub null_equality: NullEquality, /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, + /// Dynamic filter for pushing down to the probe side + /// Set when dynamic filter pushdown is detected in handle_child_pushdown_result. + /// HashJoinExec also needs to keep a shared bounds accumulator for coordinating updates. + dynamic_filter: Option, +} + +#[derive(Clone)] +struct HashJoinExecDynamicFilter { + /// Dynamic filter that we'll update with the results of the build side once that is done. + filter: Arc, + /// Bounds accumulator to keep track of the min/max bounds on the join keys for each partition. + /// It is lazily initialized during execution to make sure we use the actual execution time partition counts. + bounds_accumulator: OnceLock>, +} + +impl fmt::Debug for HashJoinExec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HashJoinExec") + .field("left", &self.left) + .field("right", &self.right) + .field("on", &self.on) + .field("filter", &self.filter) + .field("join_type", &self.join_type) + .field("join_schema", &self.join_schema) + .field("left_fut", &self.left_fut) + .field("random_state", &self.random_state) + .field("mode", &self.mode) + .field("metrics", &self.metrics) + .field("projection", &self.projection) + .field("column_indices", &self.column_indices) + .field("null_equality", &self.null_equality) + .field("cache", &self.cache) + // Explicitly exclude dynamic_filter to avoid runtime state differences in tests + .finish() + } +} + +impl EmbeddedProjection for HashJoinExec { + fn with_projection(&self, projection: Option>) -> Result { + self.with_projection(projection) + } } impl HashJoinExec { @@ -373,7 +411,7 @@ impl HashJoinExec { join_type: &JoinType, projection: Option>, partition_mode: PartitionMode, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result { let left_schema = left.schema(); let right_schema = right.schema(); @@ -386,7 +424,7 @@ impl HashJoinExec { let (join_schema, column_indices) = build_join_schema(&left_schema, &right_schema, join_type); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = HASH_JOIN_SEED; let join_schema = Arc::new(join_schema); @@ -403,6 +441,9 @@ impl HashJoinExec { projection.as_ref(), )?; + // Initialize both dynamic filter and bounds accumulator to None + // They will be set later if dynamic filtering is enabled + Ok(HashJoinExec { left, right, @@ -416,11 +457,20 @@ impl HashJoinExec { metrics: ExecutionPlanMetricsSet::new(), projection, column_indices, - null_equals_null, + null_equality, cache, + dynamic_filter: None, }) } + fn create_dynamic_filter(on: &JoinOn) -> Arc { + // Extract the right-side keys (probe side keys) from the `on` clauses + // Dynamic filter will be created from build side values (left side) and applied to probe side (right side) + let right_keys: Vec<_> = on.iter().map(|(_, r)| Arc::clone(r)).collect(); + // Initialize with a placeholder expression (true) that will be updated when the hash table is built + Arc::new(DynamicFilterPhysicalExpr::new(right_keys, lit(true))) + } + /// left (build) side which gets hashed pub fn left(&self) -> &Arc { &self.left @@ -457,9 +507,9 @@ impl HashJoinExec { &self.mode } - /// Get null_equals_null - pub fn null_equals_null(&self) -> bool { - self.null_equals_null + /// Get null_equality + pub fn null_equality(&self) -> NullEquality { + self.null_equality } /// Calculate order preservation flags for this hash join. @@ -472,6 +522,7 @@ impl HashJoinExec { | JoinType::Right | JoinType::RightAnti | JoinType::RightSemi + | JoinType::RightMark ), ] } @@ -506,7 +557,7 @@ impl HashJoinExec { &self.join_type, projection, self.mode, - self.null_equals_null, + self.null_equality, ) } @@ -529,17 +580,17 @@ impl HashJoinExec { &Self::maintains_input_order(join_type), Some(Self::probe_side()), on, - ); + )?; let mut output_partitioning = match mode { PartitionMode::CollectLeft => { - asymmetric_join_output_partitioning(left, right, &join_type) + asymmetric_join_output_partitioning(left, right, &join_type)? } PartitionMode::Auto => Partitioning::UnknownPartitioning( right.output_partitioning().partition_count(), ), PartitionMode::Partitioned => { - symmetric_join_output_partitioning(left, right, &join_type) + symmetric_join_output_partitioning(left, right, &join_type)? } }; @@ -553,7 +604,8 @@ impl HashJoinExec { | JoinType::LeftSemi | JoinType::RightSemi | JoinType::Right - | JoinType::RightAnti => EmissionType::Incremental, + | JoinType::RightAnti + | JoinType::RightMark => EmissionType::Incremental, // If we need to generate unmatched rows from the *build side*, // we need to emit them at the end. JoinType::Left @@ -592,6 +644,21 @@ impl HashJoinExec { /// /// This function is public so other downstream projects can use it to /// construct `HashJoinExec` with right side as the build side. + /// + /// For using this interface directly, please refer to below: + /// + /// Hash join execution may require specific input partitioning (for example, + /// the left child may have a single partition while the right child has multiple). + /// + /// Calling this function on join nodes whose children have already been repartitioned + /// (e.g., after a `RepartitionExec` has been inserted) may break the partitioning + /// requirements of the hash join. Therefore, ensure you call this function + /// before inserting any repartitioning operators on the join's children. + /// + /// In DataFusion's default SQL interface, this function is used by the `JoinSelection` + /// physical optimizer rule to determine a good join order, which is + /// executed before the `EnforceDistribution` rule (the rule that may + /// insert `RepartitionExec` operators). pub fn swap_inputs( &self, partition_mode: PartitionMode, @@ -614,7 +681,7 @@ impl HashJoinExec { self.join_type(), ), partition_mode, - self.null_equals_null(), + self.null_equality(), )?; // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again if matches!( @@ -623,6 +690,8 @@ impl HashJoinExec { | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark ) || self.projection.is_some() { Ok(Arc::new(new_join)) @@ -658,16 +727,27 @@ impl DisplayAs for HashJoinExec { } else { "".to_string() }; + let display_null_equality = + if matches!(self.null_equality(), NullEquality::NullEqualsNull) { + ", NullsEqual: true" + } else { + "" + }; let on = self .on .iter() - .map(|(c1, c2)| format!("({}, {})", c1, c2)) + .map(|(c1, c2)| format!("({c1}, {c2})")) .collect::>() .join(", "); write!( f, - "HashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}{}", - self.mode, self.join_type, on, display_filter, display_projections + "HashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}{}{}", + self.mode, + self.join_type, + on, + display_filter, + display_projections, + display_null_equality, ) } DisplayFormatType::TreeRender => { @@ -683,7 +763,18 @@ impl DisplayAs for HashJoinExec { if *self.join_type() != JoinType::Inner { writeln!(f, "join_type={:?}", self.join_type)?; } - writeln!(f, "on={}", on) + + writeln!(f, "on={on}")?; + + if matches!(self.null_equality(), NullEquality::NullEqualsNull) { + writeln!(f, "NullsEqual: true")?; + } + + if let Some(filter) = self.filter.as_ref() { + writeln!(f, "filter={filter}")?; + } + + Ok(()) } } } @@ -750,20 +841,63 @@ impl ExecutionPlan for HashJoinExec { vec![&self.left, &self.right] } + /// Creates a new HashJoinExec with different children while preserving configuration. + /// + /// This method is called during query optimization when the optimizer creates new + /// plan nodes. Importantly, it creates a fresh bounds_accumulator via `try_new` + /// rather than cloning the existing one because partitioning may have changed. fn with_new_children( self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(HashJoinExec::try_new( - Arc::clone(&children[0]), - Arc::clone(&children[1]), - self.on.clone(), - self.filter.clone(), - &self.join_type, - self.projection.clone(), - self.mode, - self.null_equals_null, - )?)) + Ok(Arc::new(HashJoinExec { + left: Arc::clone(&children[0]), + right: Arc::clone(&children[1]), + on: self.on.clone(), + filter: self.filter.clone(), + join_type: self.join_type, + join_schema: Arc::clone(&self.join_schema), + left_fut: Arc::clone(&self.left_fut), + random_state: self.random_state.clone(), + mode: self.mode, + metrics: ExecutionPlanMetricsSet::new(), + projection: self.projection.clone(), + column_indices: self.column_indices.clone(), + null_equality: self.null_equality, + cache: Self::compute_properties( + &children[0], + &children[1], + Arc::clone(&self.join_schema), + self.join_type, + &self.on, + self.mode, + self.projection.as_ref(), + )?, + // Keep the dynamic filter, bounds accumulator will be reset + dynamic_filter: self.dynamic_filter.clone(), + })) + } + + fn reset_state(self: Arc) -> Result> { + Ok(Arc::new(HashJoinExec { + left: Arc::clone(&self.left), + right: Arc::clone(&self.right), + on: self.on.clone(), + filter: self.filter.clone(), + join_type: self.join_type, + join_schema: Arc::clone(&self.join_schema), + // Reset the left_fut to allow re-execution + left_fut: Arc::new(OnceAsync::default()), + random_state: self.random_state.clone(), + mode: self.mode, + metrics: ExecutionPlanMetricsSet::new(), + projection: self.projection.clone(), + column_indices: self.column_indices.clone(), + null_equality: self.null_equality, + cache: self.cache.clone(), + // Reset dynamic filter and bounds accumulator to initial state + dynamic_filter: None, + })) } fn execute( @@ -776,11 +910,6 @@ impl ExecutionPlan for HashJoinExec { .iter() .map(|on| Arc::clone(&on.0)) .collect::>(); - let on_right = self - .on - .iter() - .map(|on| Arc::clone(&on.1)) - .collect::>(); let left_partitions = self.left.output_partitioning().partition_count(); let right_partitions = self.right.output_partitioning().partition_count(); @@ -792,38 +921,50 @@ impl ExecutionPlan for HashJoinExec { ); } + if self.mode == PartitionMode::CollectLeft && left_partitions != 1 { + return internal_err!( + "Invalid HashJoinExec, the output partition count of the left child must be 1 in CollectLeft mode,\ + consider using CoalescePartitionsExec or the EnforceDistribution rule" + ); + } + + let enable_dynamic_filter_pushdown = self.dynamic_filter.is_some(); + let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); let left_fut = match self.mode { - PartitionMode::CollectLeft => self.left_fut.once(|| { + PartitionMode::CollectLeft => self.left_fut.try_once(|| { + let left_stream = self.left.execute(0, Arc::clone(&context))?; + let reservation = MemoryConsumer::new("HashJoinInput").register(context.memory_pool()); - collect_left_input( - None, + + Ok(collect_left_input( self.random_state.clone(), - Arc::clone(&self.left), + left_stream, on_left.clone(), - Arc::clone(&context), join_metrics.clone(), reservation, need_produce_result_in_final(self.join_type), self.right().output_partitioning().partition_count(), - ) - }), + enable_dynamic_filter_pushdown, + )) + })?, PartitionMode::Partitioned => { + let left_stream = self.left.execute(partition, Arc::clone(&context))?; + let reservation = MemoryConsumer::new(format!("HashJoinInput[{partition}]")) .register(context.memory_pool()); OnceFut::new(collect_left_input( - Some(partition), self.random_state.clone(), - Arc::clone(&self.left), + left_stream, on_left.clone(), - Arc::clone(&context), join_metrics.clone(), reservation, need_produce_result_in_final(self.join_type), 1, + enable_dynamic_filter_pushdown, )) } PartitionMode::Auto => { @@ -836,6 +977,30 @@ impl ExecutionPlan for HashJoinExec { let batch_size = context.session_config().batch_size(); + // Initialize bounds_accumulator lazily with runtime partition counts (only if enabled) + let bounds_accumulator = enable_dynamic_filter_pushdown + .then(|| { + self.dynamic_filter.as_ref().map(|df| { + let filter = Arc::clone(&df.filter); + let on_right = self + .on + .iter() + .map(|(_, right_expr)| Arc::clone(right_expr)) + .collect::>(); + Some(Arc::clone(df.bounds_accumulator.get_or_init(|| { + Arc::new(SharedBoundsAccumulator::new_from_partition_mode( + self.mode, + self.left.as_ref(), + self.right.as_ref(), + filter, + on_right, + )) + }))) + }) + }) + .flatten() + .flatten(); + // we have the batches and the hash map with their keys. We can how create a stream // over the right that uses this information to issue new batches. let right_stream = self.right.execute(partition, context)?; @@ -849,22 +1014,31 @@ impl ExecutionPlan for HashJoinExec { None => self.column_indices.clone(), }; - Ok(Box::pin(HashJoinStream { - schema: self.schema(), + let on_right = self + .on + .iter() + .map(|(_, right_expr)| Arc::clone(right_expr)) + .collect::>(); + + Ok(Box::pin(HashJoinStream::new( + partition, + self.schema(), on_right, - filter: self.filter.clone(), - join_type: self.join_type, - right: right_stream, - column_indices: column_indices_after_projection, - random_state: self.random_state.clone(), + self.filter.clone(), + self.join_type, + right_stream, + self.random_state.clone(), join_metrics, - null_equals_null: self.null_equals_null, - state: HashJoinStreamState::WaitBuildSide, - build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), + column_indices_after_projection, + self.null_equality, + HashJoinStreamState::WaitBuildSide, + BuildSide::Initial(BuildSideInitialState { left_fut }), batch_size, - hashes_buffer: vec![], - right_side_ordered: self.right.output_ordering().is_some(), - })) + vec![], + self.right.output_ordering().is_some(), + bounds_accumulator, + self.mode, + ))) } fn metrics(&self) -> Option { @@ -872,12 +1046,19 @@ impl ExecutionPlan for HashJoinExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } // TODO stats: it is not possible in general to know the output size of joins // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` let stats = estimate_join_statistics( - Arc::clone(&self.left), - Arc::clone(&self.right), + self.left.partition_statistics(None)?, + self.right.partition_statistics(None)?, self.on.clone(), &self.join_type, &self.join_schema, @@ -920,72 +1101,336 @@ impl ExecutionPlan for HashJoinExec { // Returned early if projection is not None None, *self.partition_mode(), - self.null_equals_null, + self.null_equality, )?))) } else { try_embed_projection(projection, self) } } + + fn gather_filters_for_pushdown( + &self, + phase: FilterPushdownPhase, + parent_filters: Vec>, + config: &ConfigOptions, + ) -> Result { + // Other types of joins can support *some* filters, but restrictions are complex and error prone. + // For now we don't support them. + // See the logical optimizer rules for more details: datafusion/optimizer/src/push_down_filter.rs + // See https://github.com/apache/datafusion/issues/16973 for tracking. + if self.join_type != JoinType::Inner { + return Ok(FilterDescription::all_unsupported( + &parent_filters, + &self.children(), + )); + } + + // Get basic filter descriptions for both children + let left_child = crate::filter_pushdown::ChildFilterDescription::from_child( + &parent_filters, + self.left(), + )?; + let mut right_child = crate::filter_pushdown::ChildFilterDescription::from_child( + &parent_filters, + self.right(), + )?; + + // Add dynamic filters in Post phase if enabled + if matches!(phase, FilterPushdownPhase::Post) + && config.optimizer.enable_dynamic_filter_pushdown + { + // Add actual dynamic filter to right side (probe side) + let dynamic_filter = Self::create_dynamic_filter(&self.on); + right_child = right_child.with_self_filter(dynamic_filter); + } + + Ok(FilterDescription::new() + .with_child(left_child) + .with_child(right_child)) + } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + // Note: this check shouldn't be necessary because we already marked all parent filters as unsupported for + // non-inner joins in `gather_filters_for_pushdown`. + // However it's a cheap check and serves to inform future devs touching this function that they need to be really + // careful pushing down filters through non-inner joins. + if self.join_type != JoinType::Inner { + // Other types of joins can support *some* filters, but restrictions are complex and error prone. + // For now we don't support them. + // See the logical optimizer rules for more details: datafusion/optimizer/src/push_down_filter.rs + return Ok(FilterPushdownPropagation::all_unsupported( + child_pushdown_result, + )); + } + + let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone()); + assert_eq!(child_pushdown_result.self_filters.len(), 2); // Should always be 2, we have 2 children + let right_child_self_filters = &child_pushdown_result.self_filters[1]; // We only push down filters to the right child + // We expect 0 or 1 self filters + if let Some(filter) = right_child_self_filters.first() { + // Note that we don't check PushdDownPredicate::discrimnant because even if nothing said + // "yes, I can fully evaluate this filter" things might still use it for statistics -> it's worth updating + let predicate = Arc::clone(&filter.predicate); + if let Ok(dynamic_filter) = + Arc::downcast::(predicate) + { + // We successfully pushed down our self filter - we need to make a new node with the dynamic filter + let new_node = Arc::new(HashJoinExec { + left: Arc::clone(&self.left), + right: Arc::clone(&self.right), + on: self.on.clone(), + filter: self.filter.clone(), + join_type: self.join_type, + join_schema: Arc::clone(&self.join_schema), + left_fut: Arc::clone(&self.left_fut), + random_state: self.random_state.clone(), + mode: self.mode, + metrics: ExecutionPlanMetricsSet::new(), + projection: self.projection.clone(), + column_indices: self.column_indices.clone(), + null_equality: self.null_equality, + cache: self.cache.clone(), + dynamic_filter: Some(HashJoinExecDynamicFilter { + filter: dynamic_filter, + bounds_accumulator: OnceLock::new(), + }), + }); + result = result.with_updated_node(new_node as Arc); + } + } + Ok(result) + } +} + +/// Accumulator for collecting min/max bounds from build-side data during hash join. +/// +/// This struct encapsulates the logic for progressively computing column bounds +/// (minimum and maximum values) for a specific join key expression as batches +/// are processed during the build phase of a hash join. +/// +/// The bounds are used for dynamic filter pushdown optimization, where filters +/// based on the actual data ranges can be pushed down to the probe side to +/// eliminate unnecessary data early. +struct CollectLeftAccumulator { + /// The physical expression to evaluate for each batch + expr: Arc, + /// Accumulator for tracking the minimum value across all batches + min: MinAccumulator, + /// Accumulator for tracking the maximum value across all batches + max: MaxAccumulator, +} + +impl CollectLeftAccumulator { + /// Creates a new accumulator for tracking bounds of a join key expression. + /// + /// # Arguments + /// * `expr` - The physical expression to track bounds for + /// * `schema` - The schema of the input data + /// + /// # Returns + /// A new `CollectLeftAccumulator` instance configured for the expression's data type + fn try_new(expr: Arc, schema: &SchemaRef) -> Result { + /// Recursively unwraps dictionary types to get the underlying value type. + fn dictionary_value_type(data_type: &DataType) -> DataType { + match data_type { + DataType::Dictionary(_, value_type) => { + dictionary_value_type(value_type.as_ref()) + } + _ => data_type.clone(), + } + } + + let data_type = expr + .data_type(schema) + // Min/Max can operate on dictionary data but expect to be initialized with the underlying value type + .map(|dt| dictionary_value_type(&dt))?; + Ok(Self { + expr, + min: MinAccumulator::try_new(&data_type)?, + max: MaxAccumulator::try_new(&data_type)?, + }) + } + + /// Updates the accumulators with values from a new batch. + /// + /// Evaluates the expression on the batch and updates both min and max + /// accumulators with the resulting values. + /// + /// # Arguments + /// * `batch` - The record batch to process + /// + /// # Returns + /// Ok(()) if the update succeeds, or an error if expression evaluation fails + fn update_batch(&mut self, batch: &RecordBatch) -> Result<()> { + let array = self.expr.evaluate(batch)?.into_array(batch.num_rows())?; + self.min.update_batch(std::slice::from_ref(&array))?; + self.max.update_batch(std::slice::from_ref(&array))?; + Ok(()) + } + + /// Finalizes the accumulation and returns the computed bounds. + /// + /// Consumes self to extract the final min and max values from the accumulators. + /// + /// # Returns + /// The `ColumnBounds` containing the minimum and maximum values observed + fn evaluate(mut self) -> Result { + Ok(ColumnBounds::new( + self.min.evaluate()?, + self.max.evaluate()?, + )) + } } -/// Reads the left (build) side of the input, buffering it in memory, to build a -/// hash table (`LeftJoinData`) +/// State for collecting the build-side data during hash join +struct BuildSideState { + batches: Vec, + num_rows: usize, + metrics: BuildProbeJoinMetrics, + reservation: MemoryReservation, + bounds_accumulators: Option>, +} + +impl BuildSideState { + /// Create a new BuildSideState with optional accumulators for bounds computation + fn try_new( + metrics: BuildProbeJoinMetrics, + reservation: MemoryReservation, + on_left: Vec>, + schema: &SchemaRef, + should_compute_bounds: bool, + ) -> Result { + Ok(Self { + batches: Vec::new(), + num_rows: 0, + metrics, + reservation, + bounds_accumulators: should_compute_bounds + .then(|| { + on_left + .iter() + .map(|expr| { + CollectLeftAccumulator::try_new(Arc::clone(expr), schema) + }) + .collect::>>() + }) + .transpose()?, + }) + } +} + +/// Collects all batches from the left (build) side stream and creates a hash map for joining. +/// +/// This function is responsible for: +/// 1. Consuming the entire left stream and collecting all batches into memory +/// 2. Building a hash map from the join key columns for efficient probe operations +/// 3. Computing bounds for dynamic filter pushdown (if enabled) +/// 4. Preparing visited indices bitmap for certain join types +/// +/// # Parameters +/// * `random_state` - Random state for consistent hashing across partitions +/// * `left_stream` - Stream of record batches from the build side +/// * `on_left` - Physical expressions for the left side join keys +/// * `metrics` - Metrics collector for tracking memory usage and row counts +/// * `reservation` - Memory reservation tracker for the hash table and data +/// * `with_visited_indices_bitmap` - Whether to track visited indices (for outer joins) +/// * `probe_threads_count` - Number of threads that will probe this hash table +/// * `should_compute_bounds` - Whether to compute min/max bounds for dynamic filtering +/// +/// # Dynamic Filter Coordination +/// When `should_compute_bounds` is true, this function computes the min/max bounds +/// for each join key column but does NOT update the dynamic filter. Instead, the +/// bounds are stored in the returned `JoinLeftData` and later coordinated by +/// `SharedBoundsAccumulator` to ensure all partitions contribute their bounds +/// before updating the filter exactly once. +/// +/// # Returns +/// `JoinLeftData` containing the hash map, consolidated batch, join key values, +/// visited indices bitmap, and computed bounds (if requested). #[allow(clippy::too_many_arguments)] async fn collect_left_input( - partition: Option, random_state: RandomState, - left: Arc, + left_stream: SendableRecordBatchStream, on_left: Vec, - context: Arc, metrics: BuildProbeJoinMetrics, reservation: MemoryReservation, with_visited_indices_bitmap: bool, probe_threads_count: usize, + should_compute_bounds: bool, ) -> Result { - let schema = left.schema(); - - let (left_input, left_input_partition) = if let Some(partition) = partition { - (left, partition) - } else if left.output_partitioning().partition_count() != 1 { - (Arc::new(CoalescePartitionsExec::new(left)) as _, 0) - } else { - (left, 0) - }; - - // Depending on partition argument load single partition or whole left side in memory - let stream = left_input.execute(left_input_partition, Arc::clone(&context))?; + let schema = left_stream.schema(); // This operation performs 2 steps at once: // 1. creates a [JoinHashMap] of all batches from the stream // 2. stores the batches in a vector. - let initial = (Vec::new(), 0, metrics, reservation); - let (batches, num_rows, metrics, mut reservation) = stream - .try_fold(initial, |mut acc, batch| async { + let initial = BuildSideState::try_new( + metrics, + reservation, + on_left.clone(), + &schema, + should_compute_bounds, + )?; + + let state = left_stream + .try_fold(initial, |mut state, batch| async move { + // Update accumulators if computing bounds + if let Some(ref mut accumulators) = state.bounds_accumulators { + for accumulator in accumulators { + accumulator.update_batch(&batch)?; + } + } + + // Decide if we spill or not let batch_size = get_record_batch_memory_size(&batch); // Reserve memory for incoming batch - acc.3.try_grow(batch_size)?; + state.reservation.try_grow(batch_size)?; // Update metrics - acc.2.build_mem_used.add(batch_size); - acc.2.build_input_batches.add(1); - acc.2.build_input_rows.add(batch.num_rows()); + state.metrics.build_mem_used.add(batch_size); + state.metrics.build_input_batches.add(1); + state.metrics.build_input_rows.add(batch.num_rows()); // Update row count - acc.1 += batch.num_rows(); + state.num_rows += batch.num_rows(); // Push batch to output - acc.0.push(batch); - Ok(acc) + state.batches.push(batch); + Ok(state) }) .await?; + // Extract fields from state + let BuildSideState { + batches, + num_rows, + metrics, + mut reservation, + bounds_accumulators, + } = state; + // Estimation of memory size, required for hashtable, prior to allocation. // Final result can be verified using `RawTable.allocation_info()` - let fixed_size = size_of::(); - let estimated_hashtable_size = - estimate_memory_size::<(u64, u64)>(num_rows, fixed_size)?; - - reservation.try_grow(estimated_hashtable_size)?; - metrics.build_mem_used.add(estimated_hashtable_size); + let fixed_size_u32 = size_of::(); + let fixed_size_u64 = size_of::(); + + // Use `u32` indices for the JoinHashMap when num_rows ≤ u32::MAX, otherwise use the + // `u64` indice variant + let mut hashmap: Box = if num_rows > u32::MAX as usize { + let estimated_hashtable_size = + estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); + Box::new(JoinHashMapU64::with_capacity(num_rows)) + } else { + let estimated_hashtable_size = + estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); + Box::new(JoinHashMapU32::with_capacity(num_rows)) + }; - let mut hashmap = JoinHashMap::with_capacity(num_rows); let mut hashes_buffer = Vec::new(); let mut offset = 0; @@ -997,7 +1442,7 @@ async fn collect_left_input( update_hash( &on_left, batch, - &mut hashmap, + &mut *hashmap, offset, &random_state, &mut hashes_buffer, @@ -1030,640 +1475,47 @@ async fn collect_left_input( }) .collect::>>()?; + // Compute bounds for dynamic filter if enabled + let bounds = match bounds_accumulators { + Some(accumulators) if num_rows > 0 => { + let bounds = accumulators + .into_iter() + .map(CollectLeftAccumulator::evaluate) + .collect::>>()?; + Some(bounds) + } + _ => None, + }; + let data = JoinLeftData::new( hashmap, single_batch, - left_values, + left_values.clone(), Mutex::new(visited_indices_bitmap), AtomicUsize::new(probe_threads_count), reservation, + bounds, ); Ok(data) } -/// Updates `hash_map` with new entries from `batch` evaluated against the expressions `on` -/// using `offset` as a start value for `batch` row indices. -/// -/// `fifo_hashmap` sets the order of iteration over `batch` rows while updating hashmap, -/// which allows to keep either first (if set to true) or last (if set to false) row index -/// as a chain head for rows with equal hash values. -#[allow(clippy::too_many_arguments)] -pub fn update_hash( - on: &[PhysicalExprRef], - batch: &RecordBatch, - hash_map: &mut T, - offset: usize, - random_state: &RandomState, - hashes_buffer: &mut Vec, - deleted_offset: usize, - fifo_hashmap: bool, -) -> Result<()> -where - T: JoinHashMapType, -{ - // evaluate the keys - let keys_values = on - .iter() - .map(|c| c.evaluate(batch)?.into_array(batch.num_rows())) - .collect::>>()?; - - // calculate the hash values - let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - - // For usual JoinHashmap, the implementation is void. - hash_map.extend_zero(batch.num_rows()); - - // Updating JoinHashMap from hash values iterator - let hash_values_iter = hash_values - .iter() - .enumerate() - .map(|(i, val)| (i + offset, val)); - - if fifo_hashmap { - hash_map.update_from_iter(hash_values_iter.rev(), deleted_offset); - } else { - hash_map.update_from_iter(hash_values_iter, deleted_offset); - } - - Ok(()) -} - -/// Represents build-side of hash join. -enum BuildSide { - /// Indicates that build-side not collected yet - Initial(BuildSideInitialState), - /// Indicates that build-side data has been collected - Ready(BuildSideReadyState), -} - -/// Container for BuildSide::Initial related data -struct BuildSideInitialState { - /// Future for building hash table from build-side input - left_fut: OnceFut, -} - -/// Container for BuildSide::Ready related data -struct BuildSideReadyState { - /// Collected build-side data - left_data: Arc, -} - -impl BuildSide { - /// Tries to extract BuildSideInitialState from BuildSide enum. - /// Returns an error if state is not Initial. - fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> { - match self { - BuildSide::Initial(state) => Ok(state), - _ => internal_err!("Expected build side in initial state"), - } - } - - /// Tries to extract BuildSideReadyState from BuildSide enum. - /// Returns an error if state is not Ready. - fn try_as_ready(&self) -> Result<&BuildSideReadyState> { - match self { - BuildSide::Ready(state) => Ok(state), - _ => internal_err!("Expected build side in ready state"), - } - } - - /// Tries to extract BuildSideReadyState from BuildSide enum. - /// Returns an error if state is not Ready. - fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> { - match self { - BuildSide::Ready(state) => Ok(state), - _ => internal_err!("Expected build side in ready state"), - } - } -} - -/// Represents state of HashJoinStream -/// -/// Expected state transitions performed by HashJoinStream are: -/// -/// ```text -/// -/// WaitBuildSide -/// │ -/// ▼ -/// ┌─► FetchProbeBatch ───► ExhaustedProbeSide ───► Completed -/// │ │ -/// │ ▼ -/// └─ ProcessProbeBatch -/// -/// ``` -#[derive(Debug, Clone)] -enum HashJoinStreamState { - /// Initial state for HashJoinStream indicating that build-side data not collected yet - WaitBuildSide, - /// Indicates that build-side has been collected, and stream is ready for fetching probe-side - FetchProbeBatch, - /// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed - ProcessProbeBatch(ProcessProbeBatchState), - /// Indicates that probe-side has been fully processed - ExhaustedProbeSide, - /// Indicates that HashJoinStream execution is completed - Completed, -} - -impl HashJoinStreamState { - /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum. - /// Returns an error if state is not ProcessProbeBatchState. - fn try_as_process_probe_batch_mut(&mut self) -> Result<&mut ProcessProbeBatchState> { - match self { - HashJoinStreamState::ProcessProbeBatch(state) => Ok(state), - _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"), - } - } -} - -/// Container for HashJoinStreamState::ProcessProbeBatch related data -#[derive(Debug, Clone)] -struct ProcessProbeBatchState { - /// Current probe-side batch - batch: RecordBatch, - /// Probe-side on expressions values - values: Vec, - /// Starting offset for JoinHashMap lookups - offset: JoinHashMapOffset, - /// Max joined probe-side index from current batch - joined_probe_idx: Option, -} - -impl ProcessProbeBatchState { - fn advance(&mut self, offset: JoinHashMapOffset, joined_probe_idx: Option) { - self.offset = offset; - if joined_probe_idx.is_some() { - self.joined_probe_idx = joined_probe_idx; - } - } -} - -/// [`Stream`] for [`HashJoinExec`] that does the actual join. -/// -/// This stream: -/// -/// 1. Reads the entire left input (build) and constructs a hash table -/// -/// 2. Streams [RecordBatch]es as they arrive from the right input (probe) and joins -/// them with the contents of the hash table -struct HashJoinStream { - /// Input schema - schema: Arc, - /// equijoin columns from the right (probe side) - on_right: Vec, - /// optional join filter - filter: Option, - /// type of the join (left, right, semi, etc) - join_type: JoinType, - /// right (probe) input - right: SendableRecordBatchStream, - /// Random state used for hashing initialization - random_state: RandomState, - /// Metrics - join_metrics: BuildProbeJoinMetrics, - /// Information of index and left / right placement of columns - column_indices: Vec, - /// If null_equals_null is true, null == null else null != null - null_equals_null: bool, - /// State of the stream - state: HashJoinStreamState, - /// Build side - build_side: BuildSide, - /// Maximum output batch size - batch_size: usize, - /// Scratch space for computing hashes - hashes_buffer: Vec, - /// Specifies whether the right side has an ordering to potentially preserve - right_side_ordered: bool, -} - -impl RecordBatchStream for HashJoinStream { - fn schema(&self) -> SchemaRef { - Arc::clone(&self.schema) - } -} - -/// Executes lookups by hash against JoinHashMap and resolves potential -/// hash collisions. -/// Returns build/probe indices satisfying the equality condition, along with -/// (optional) starting point for next iteration. -/// -/// # Example -/// -/// For `LEFT.b1 = RIGHT.b2`: -/// LEFT (build) Table: -/// ```text -/// a1 b1 c1 -/// 1 1 10 -/// 3 3 30 -/// 5 5 50 -/// 7 7 70 -/// 9 8 90 -/// 11 8 110 -/// 13 10 130 -/// ``` -/// -/// RIGHT (probe) Table: -/// ```text -/// a2 b2 c2 -/// 2 2 20 -/// 4 4 40 -/// 6 6 60 -/// 8 8 80 -/// 10 10 100 -/// 12 10 120 -/// ``` -/// -/// The result is -/// ```text -/// "+----+----+-----+----+----+-----+", -/// "| a1 | b1 | c1 | a2 | b2 | c2 |", -/// "+----+----+-----+----+----+-----+", -/// "| 9 | 8 | 90 | 8 | 8 | 80 |", -/// "| 11 | 8 | 110 | 8 | 8 | 80 |", -/// "| 13 | 10 | 130 | 10 | 10 | 100 |", -/// "| 13 | 10 | 130 | 12 | 10 | 120 |", -/// "+----+----+-----+----+----+-----+" -/// ``` -/// -/// And the result of build and probe indices are: -/// ```text -/// Build indices: 4, 5, 6, 6 -/// Probe indices: 3, 3, 4, 5 -/// ``` -#[allow(clippy::too_many_arguments)] -fn lookup_join_hashmap( - build_hashmap: &JoinHashMap, - build_side_values: &[ArrayRef], - probe_side_values: &[ArrayRef], - null_equals_null: bool, - hashes_buffer: &[u64], - limit: usize, - offset: JoinHashMapOffset, -) -> Result<(UInt64Array, UInt32Array, Option)> { - let (probe_indices, build_indices, next_offset) = build_hashmap - .get_matched_indices_with_limit_offset(hashes_buffer, None, limit, offset); - - let build_indices: UInt64Array = build_indices.into(); - let probe_indices: UInt32Array = probe_indices.into(); - - let (build_indices, probe_indices) = equal_rows_arr( - &build_indices, - &probe_indices, - build_side_values, - probe_side_values, - null_equals_null, - )?; - - Ok((build_indices, probe_indices, next_offset)) -} - -// version of eq_dyn supporting equality on null arrays -fn eq_dyn_null( - left: &dyn Array, - right: &dyn Array, - null_equals_null: bool, -) -> Result { - // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special - // implementation - // - if left.data_type().is_nested() { - let op = if null_equals_null { - Operator::IsNotDistinctFrom - } else { - Operator::Eq - }; - return Ok(compare_op_for_nested(op, &left, &right)?); - } - match (left.data_type(), right.data_type()) { - _ if null_equals_null => not_distinct(&left, &right), - _ => eq(&left, &right), - } -} - -pub fn equal_rows_arr( - indices_left: &UInt64Array, - indices_right: &UInt32Array, - left_arrays: &[ArrayRef], - right_arrays: &[ArrayRef], - null_equals_null: bool, -) -> Result<(UInt64Array, UInt32Array)> { - let mut iter = left_arrays.iter().zip(right_arrays.iter()); - - let (first_left, first_right) = iter.next().ok_or_else(|| { - DataFusionError::Internal( - "At least one array should be provided for both left and right".to_string(), - ) - })?; - - let arr_left = take(first_left.as_ref(), indices_left, None)?; - let arr_right = take(first_right.as_ref(), indices_right, None)?; - - let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equals_null)?; - - // Use map and try_fold to iterate over the remaining pairs of arrays. - // In each iteration, take is used on the pair of arrays and their equality is determined. - // The results are then folded (combined) using the and function to get a final equality result. - equal = iter - .map(|(left, right)| { - let arr_left = take(left.as_ref(), indices_left, None)?; - let arr_right = take(right.as_ref(), indices_right, None)?; - eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equals_null) - }) - .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?; - - let filter_builder = FilterBuilder::new(&equal).optimize().build(); - - let left_filtered = filter_builder.filter(indices_left)?; - let right_filtered = filter_builder.filter(indices_right)?; - - Ok(( - downcast_array(left_filtered.as_ref()), - downcast_array(right_filtered.as_ref()), - )) -} - -impl HashJoinStream { - /// Separate implementation function that unpins the [`HashJoinStream`] so - /// that partial borrows work correctly - fn poll_next_impl( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll>> { - loop { - return match self.state { - HashJoinStreamState::WaitBuildSide => { - handle_state!(ready!(self.collect_build_side(cx))) - } - HashJoinStreamState::FetchProbeBatch => { - handle_state!(ready!(self.fetch_probe_batch(cx))) - } - HashJoinStreamState::ProcessProbeBatch(_) => { - handle_state!(self.process_probe_batch()) - } - HashJoinStreamState::ExhaustedProbeSide => { - handle_state!(self.process_unmatched_build_batch()) - } - HashJoinStreamState::Completed => Poll::Ready(None), - }; - } - } - - /// Collects build-side data by polling `OnceFut` future from initialized build-side - /// - /// Updates build-side to `Ready`, and state to `FetchProbeSide` - fn collect_build_side( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll>>> { - let build_timer = self.join_metrics.build_time.timer(); - // build hash table from left (build) side, if not yet done - let left_data = ready!(self - .build_side - .try_as_initial_mut()? - .left_fut - .get_shared(cx))?; - build_timer.done(); - - self.state = HashJoinStreamState::FetchProbeBatch; - self.build_side = BuildSide::Ready(BuildSideReadyState { left_data }); - - Poll::Ready(Ok(StatefulStreamResult::Continue)) - } - - /// Fetches next batch from probe-side - /// - /// If non-empty batch has been fetched, updates state to `ProcessProbeBatchState`, - /// otherwise updates state to `ExhaustedProbeSide` - fn fetch_probe_batch( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll>>> { - match ready!(self.right.poll_next_unpin(cx)) { - None => { - self.state = HashJoinStreamState::ExhaustedProbeSide; - } - Some(Ok(batch)) => { - // Precalculate hash values for fetched batch - let keys_values = self - .on_right - .iter() - .map(|c| c.evaluate(&batch)?.into_array(batch.num_rows())) - .collect::>>()?; - - self.hashes_buffer.clear(); - self.hashes_buffer.resize(batch.num_rows(), 0); - create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?; - - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - - self.state = - HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { - batch, - values: keys_values, - offset: (0, None), - joined_probe_idx: None, - }); - } - Some(Err(err)) => return Poll::Ready(Err(err)), - }; - - Poll::Ready(Ok(StatefulStreamResult::Continue)) - } - - /// Joins current probe batch with build-side data and produces batch with matched output - /// - /// Updates state to `FetchProbeBatch` - fn process_probe_batch( - &mut self, - ) -> Result>> { - let state = self.state.try_as_process_probe_batch_mut()?; - let build_side = self.build_side.try_as_ready_mut()?; - - let timer = self.join_metrics.join_time.timer(); - - // get the matched by join keys indices - let (left_indices, right_indices, next_offset) = lookup_join_hashmap( - build_side.left_data.hash_map(), - build_side.left_data.values(), - &state.values, - self.null_equals_null, - &self.hashes_buffer, - self.batch_size, - state.offset, - )?; - - // apply join filter if exists - let (left_indices, right_indices) = if let Some(filter) = &self.filter { - apply_join_filter_to_indices( - build_side.left_data.batch(), - &state.batch, - left_indices, - right_indices, - filter, - JoinSide::Left, - )? - } else { - (left_indices, right_indices) - }; - - // mark joined left-side indices as visited, if required by join type - if need_produce_result_in_final(self.join_type) { - let mut bitmap = build_side.left_data.visited_indices_bitmap().lock(); - left_indices.iter().flatten().for_each(|x| { - bitmap.set_bit(x as usize, true); - }); - } - - // The goals of index alignment for different join types are: - // - // 1) Right & FullJoin -- to append all missing probe-side indices between - // previous (excluding) and current joined indices. - // 2) SemiJoin -- deduplicate probe indices in range between previous - // (excluding) and current joined indices. - // 3) AntiJoin -- return only missing indices in range between - // previous and current joined indices. - // Inclusion/exclusion of the indices themselves don't matter - // - // As a summary -- alignment range can be produced based only on - // joined (matched with filters applied) probe side indices, excluding starting one - // (left from previous iteration). - - // if any rows have been joined -- get last joined probe-side (right) row - // it's important that index counts as "joined" after hash collisions checks - // and join filters applied. - let last_joined_right_idx = match right_indices.len() { - 0 => None, - n => Some(right_indices.value(n - 1) as usize), - }; - - // Calculate range and perform alignment. - // In case probe batch has been processed -- align all remaining rows. - let index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1); - let index_alignment_range_end = if next_offset.is_none() { - state.batch.num_rows() - } else { - last_joined_right_idx.map_or(0, |v| v + 1) - }; - - let (left_indices, right_indices) = adjust_indices_by_join_type( - left_indices, - right_indices, - index_alignment_range_start..index_alignment_range_end, - self.join_type, - self.right_side_ordered, - )?; - - let result = build_batch_from_indices( - &self.schema, - build_side.left_data.batch(), - &state.batch, - &left_indices, - &right_indices, - &self.column_indices, - JoinSide::Left, - )?; - - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(result.num_rows()); - timer.done(); - - if next_offset.is_none() { - self.state = HashJoinStreamState::FetchProbeBatch; - } else { - state.advance( - next_offset - .ok_or_else(|| internal_datafusion_err!("unexpected None offset"))?, - last_joined_right_idx, - ) - }; - - Ok(StatefulStreamResult::Ready(Some(result))) - } - - /// Processes unmatched build-side rows for certain join types and produces output batch - /// - /// Updates state to `Completed` - fn process_unmatched_build_batch( - &mut self, - ) -> Result>> { - let timer = self.join_metrics.join_time.timer(); - - if !need_produce_result_in_final(self.join_type) { - self.state = HashJoinStreamState::Completed; - return Ok(StatefulStreamResult::Continue); - } - - let build_side = self.build_side.try_as_ready()?; - if !build_side.left_data.report_probe_completed() { - self.state = HashJoinStreamState::Completed; - return Ok(StatefulStreamResult::Continue); - } - - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = get_final_indices_from_shared_bitmap( - build_side.left_data.visited_indices_bitmap(), - self.join_type, - ); - let empty_right_batch = RecordBatch::new_empty(self.right.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - build_side.left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - - if let Ok(ref batch) = result { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - timer.done(); - - self.state = HashJoinStreamState::Completed; - - Ok(StatefulStreamResult::Ready(Some(result?))) - } -} - -impl Stream for HashJoinStream { - type Item = Result; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.poll_next_impl(cx) - } -} - -impl EmbeddedProjection for HashJoinExec { - fn with_projection(&self, projection: Option>) -> Result { - self.with_projection(projection) - } -} - #[cfg(test)] mod tests { use super::*; - use crate::test::TestMemoryExec; + use crate::coalesce_partitions::CoalescePartitionsExec; + use crate::joins::hash_join::stream::lookup_join_hashmap; + use crate::test::{assert_join_metrics, TestMemoryExec}; use crate::{ common, expressions::Column, repartition::RepartitionExec, test::build_table_i32, test::exec::MockExec, }; - use arrow::array::{Date32Array, Int32Array, StructArray}; + use arrow::array::{Date32Array, Int32Array, StructArray, UInt32Array, UInt64Array}; use arrow::buffer::NullBuffer; use arrow::datatypes::{DataType, Field}; + use arrow_schema::Schema; + use datafusion_common::hash_utils::create_hashes; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, @@ -1707,7 +1559,7 @@ mod tests { right: Arc, on: JoinOn, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result { HashJoinExec::try_new( left, @@ -1717,7 +1569,7 @@ mod tests { join_type, None, PartitionMode::CollectLeft, - null_equals_null, + null_equality, ) } @@ -1727,7 +1579,7 @@ mod tests { on: JoinOn, filter: JoinFilter, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result { HashJoinExec::try_new( left, @@ -1737,7 +1589,7 @@ mod tests { join_type, None, PartitionMode::CollectLeft, - null_equals_null, + null_equality, ) } @@ -1746,16 +1598,17 @@ mod tests { right: Arc, on: JoinOn, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, context: Arc, - ) -> Result<(Vec, Vec)> { - let join = join(left, right, on, join_type, null_equals_null)?; + ) -> Result<(Vec, Vec, MetricsSet)> { + let join = join(left, right, on, join_type, null_equality)?; let columns_header = columns(&join.schema()); let stream = join.execute(0, context)?; let batches = common::collect(stream).await?; + let metrics = join.metrics().unwrap(); - Ok((columns_header, batches)) + Ok((columns_header, batches, metrics)) } async fn partitioned_join_collect( @@ -1763,16 +1616,16 @@ mod tests { right: Arc, on: JoinOn, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, context: Arc, - ) -> Result<(Vec, Vec)> { + ) -> Result<(Vec, Vec, MetricsSet)> { join_collect_with_partition_mode( left, right, on, join_type, PartitionMode::Partitioned, - null_equals_null, + null_equality, context, ) .await @@ -1784,9 +1637,9 @@ mod tests { on: JoinOn, join_type: &JoinType, partition_mode: PartitionMode, - null_equals_null: bool, + null_equality: NullEquality, context: Arc, - ) -> Result<(Vec, Vec)> { + ) -> Result<(Vec, Vec, MetricsSet)> { let partition_count = 4; let (left_expr, right_expr) = on @@ -1834,7 +1687,7 @@ mod tests { join_type, None, partition_mode, - null_equals_null, + null_equality, )?; let columns = columns(&join.schema()); @@ -1850,8 +1703,9 @@ mod tests { .collect::>(), ); } + let metrics = join.metrics().unwrap(); - Ok((columns, batches)) + Ok((columns, batches, metrics)) } #[apply(batch_sizes)] @@ -1874,12 +1728,12 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (columns, batches) = join_collect( + let (columns, batches, metrics) = join_collect( Arc::clone(&left), Arc::clone(&right), on.clone(), &JoinType::Inner, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; @@ -1899,6 +1753,8 @@ mod tests { "#); } + assert_join_metrics!(metrics, 3); + Ok(()) } @@ -1921,12 +1777,12 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (columns, batches) = partitioned_join_collect( + let (columns, batches, metrics) = partitioned_join_collect( Arc::clone(&left), Arc::clone(&right), on.clone(), &JoinType::Inner, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; @@ -1945,6 +1801,8 @@ mod tests { "#); } + assert_join_metrics!(metrics, 3); + Ok(()) } @@ -1966,8 +1824,15 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches, metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -1984,6 +1849,8 @@ mod tests { "#); } + assert_join_metrics!(metrics, 3); + Ok(()) } @@ -2005,8 +1872,15 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches, metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -2024,6 +1898,8 @@ mod tests { "#); } + assert_join_metrics!(metrics, 4); + Ok(()) } @@ -2052,8 +1928,15 @@ mod tests { ), ]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches, metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -2086,6 +1969,8 @@ mod tests { "#); } + assert_join_metrics!(metrics, 3); + Ok(()) } @@ -2105,6 +1990,7 @@ mod tests { let left = TestMemoryExec::try_new_exec(&[vec![batch1], vec![batch2]], schema, None) .unwrap(); + let left = Arc::new(CoalescePartitionsExec::new(left)); let right = build_table( ("a1", &vec![1, 2, 3]), @@ -2122,8 +2008,15 @@ mod tests { ), ]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches, metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -2156,6 +2049,8 @@ mod tests { "#); } + assert_join_metrics!(metrics, 3); + Ok(()) } @@ -2177,6 +2072,7 @@ mod tests { let left = TestMemoryExec::try_new_exec(&[vec![batch1], vec![batch2]], schema, None) .unwrap(); + let left = Arc::new(CoalescePartitionsExec::new(left)); let right = build_table( ("a2", &vec![20, 30, 10]), ("b2", &vec![5, 6, 4]), @@ -2187,8 +2083,15 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches, metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -2206,6 +2109,8 @@ mod tests { "#); } + assert_join_metrics!(metrics, 4); + Ok(()) } @@ -2237,7 +2142,13 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let join = join(left, right, on, &JoinType::Inner, false)?; + let join = join( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -2330,7 +2241,14 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Left, false).unwrap(); + let join = join( + left, + right, + on, + &JoinType::Left, + NullEquality::NullEqualsNothing, + ) + .unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -2373,7 +2291,14 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Full, false).unwrap(); + let join = join( + left, + right, + on, + &JoinType::Full, + NullEquality::NullEqualsNothing, + ) + .unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -2414,7 +2339,14 @@ mod tests { )]; let schema = right.schema(); let right = TestMemoryExec::try_new_exec(&[vec![right]], schema, None).unwrap(); - let join = join(left, right, on, &JoinType::Left, false).unwrap(); + let join = join( + left, + right, + on, + &JoinType::Left, + NullEquality::NullEqualsNothing, + ) + .unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -2451,7 +2383,14 @@ mod tests { )]; let schema = right.schema(); let right = TestMemoryExec::try_new_exec(&[vec![right]], schema, None).unwrap(); - let join = join(left, right, on, &JoinType::Full, false).unwrap(); + let join = join( + left, + right, + on, + &JoinType::Full, + NullEquality::NullEqualsNothing, + ) + .unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -2491,15 +2430,16 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (columns, batches) = join_collect( + let (columns, batches, metrics) = join_collect( Arc::clone(&left), Arc::clone(&right), on.clone(), &JoinType::Left, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); allow_duplicates! { @@ -2514,6 +2454,8 @@ mod tests { "#); } + assert_join_metrics!(metrics, 3); + Ok(()) } @@ -2536,15 +2478,16 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (columns, batches) = partitioned_join_collect( + let (columns, batches, metrics) = partitioned_join_collect( Arc::clone(&left), Arc::clone(&right), on.clone(), &JoinType::Left, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); allow_duplicates! { @@ -2559,6 +2502,8 @@ mod tests { "#); } + assert_join_metrics!(metrics, 3); + Ok(()) } @@ -2594,7 +2539,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let join = join(left, right, on, &JoinType::LeftSemi, false)?; + let join = join( + left, + right, + on, + &JoinType::LeftSemi, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); @@ -2656,7 +2607,7 @@ mod tests { on.clone(), filter, &JoinType::LeftSemi, - false, + NullEquality::NullEqualsNothing, )?; let columns_header = columns(&join.schema()); @@ -2689,7 +2640,14 @@ mod tests { Arc::new(intermediate_schema), ); - let join = join_with_filter(left, right, on, filter, &JoinType::LeftSemi, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::LeftSemi, + NullEquality::NullEqualsNothing, + )?; let columns_header = columns(&join.schema()); assert_eq!(columns_header, vec!["a1", "b1", "c1"]); @@ -2723,7 +2681,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let join = join(left, right, on, &JoinType::RightSemi, false)?; + let join = join( + left, + right, + on, + &JoinType::RightSemi, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a2", "b2", "c2"]); @@ -2785,7 +2749,7 @@ mod tests { on.clone(), filter, &JoinType::RightSemi, - false, + NullEquality::NullEqualsNothing, )?; let columns = columns(&join.schema()); @@ -2820,8 +2784,14 @@ mod tests { Arc::new(intermediate_schema.clone()), ); - let join = - join_with_filter(left, right, on, filter, &JoinType::RightSemi, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::RightSemi, + NullEquality::NullEqualsNothing, + )?; let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; @@ -2852,7 +2822,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let join = join(left, right, on, &JoinType::LeftAnti, false)?; + let join = join( + left, + right, + on, + &JoinType::LeftAnti, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); @@ -2911,7 +2887,7 @@ mod tests { on.clone(), filter, &JoinType::LeftAnti, - false, + NullEquality::NullEqualsNothing, )?; let columns_header = columns(&join.schema()); @@ -2948,7 +2924,14 @@ mod tests { Arc::new(intermediate_schema), ); - let join = join_with_filter(left, right, on, filter, &JoinType::LeftAnti, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::LeftAnti, + NullEquality::NullEqualsNothing, + )?; let columns_header = columns(&join.schema()); assert_eq!(columns_header, vec!["a1", "b1", "c1"]); @@ -2985,7 +2968,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let join = join(left, right, on, &JoinType::RightAnti, false)?; + let join = join( + left, + right, + on, + &JoinType::RightAnti, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a2", "b2", "c2"]); @@ -3045,7 +3034,7 @@ mod tests { on.clone(), filter, &JoinType::RightAnti, - false, + NullEquality::NullEqualsNothing, )?; let columns_header = columns(&join.schema()); @@ -3086,8 +3075,14 @@ mod tests { Arc::new(intermediate_schema), ); - let join = - join_with_filter(left, right, on, filter, &JoinType::RightAnti, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::RightAnti, + NullEquality::NullEqualsNothing, + )?; let columns_header = columns(&join.schema()); assert_eq!(columns_header, vec!["a2", "b2", "c2"]); @@ -3131,8 +3126,15 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Right, false, task_ctx).await?; + let (columns, batches, metrics) = join_collect( + left, + right, + on, + &JoinType::Right, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -3148,6 +3150,8 @@ mod tests { "#); } + assert_join_metrics!(metrics, 3); + Ok(()) } @@ -3170,9 +3174,15 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (columns, batches) = - partitioned_join_collect(left, right, on, &JoinType::Right, false, task_ctx) - .await?; + let (columns, batches, metrics) = partitioned_join_collect( + left, + right, + on, + &JoinType::Right, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -3188,6 +3198,8 @@ mod tests { "#); } + assert_join_metrics!(metrics, 3); + Ok(()) } @@ -3210,7 +3222,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Full, false)?; + let join = join( + left, + right, + on, + &JoinType::Full, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -3253,15 +3271,16 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (columns, batches) = join_collect( + let (columns, batches, metrics) = join_collect( Arc::clone(&left), Arc::clone(&right), on.clone(), &JoinType::LeftMark, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); allow_duplicates! { @@ -3276,6 +3295,8 @@ mod tests { "#); } + assert_join_metrics!(metrics, 3); + Ok(()) } @@ -3298,15 +3319,16 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (columns, batches) = partitioned_join_collect( + let (columns, batches, metrics) = partitioned_join_collect( Arc::clone(&left), Arc::clone(&right), on.clone(), &JoinType::LeftMark, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); allow_duplicates! { @@ -3321,12 +3343,109 @@ mod tests { "#); } + assert_join_metrics!(metrics, 3); + + Ok(()) + } + + #[apply(batch_sizes)] + #[tokio::test] + async fn join_right_mark(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (columns, batches, metrics) = join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::RightMark, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; + + assert_eq!(columns, vec!["a2", "b1", "c2", "mark"]); + + let expected = [ + "+----+----+----+-------+", + "| a2 | b1 | c2 | mark |", + "+----+----+----+-------+", + "| 10 | 4 | 70 | true |", + "| 20 | 5 | 80 | true |", + "| 30 | 6 | 90 | false |", + "+----+----+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + assert_join_metrics!(metrics, 3); + + Ok(()) + } + + #[apply(batch_sizes)] + #[tokio::test] + async fn partitioned_join_right_mark(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (columns, batches, metrics) = partitioned_join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::RightMark, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; + + assert_eq!(columns, vec!["a2", "b1", "c2", "mark"]); + + let expected = [ + "+----+----+----+-------+", + "| a2 | b1 | c2 | mark |", + "+----+----+----+-------+", + "| 10 | 4 | 60 | true |", + "| 20 | 4 | 70 | true |", + "| 30 | 5 | 80 | true |", + "| 40 | 6 | 90 | false |", + "+----+----+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + assert_join_metrics!(metrics, 4); + Ok(()) } #[test] - fn join_with_hash_collision() -> Result<()> { - let mut hashmap_left = HashTable::with_capacity(2); + fn join_with_hash_collisions_64() -> Result<()> { + let mut hashmap_left = HashTable::with_capacity(4); let left = build_table_i32( ("a", &vec![10, 20]), ("x", &vec![100, 200]), @@ -3341,9 +3460,15 @@ mod tests { hashes_buff, )?; - // Create hash collisions (same hashes) + // Maps both values to both indices (1 and 2, representing input 0 and 1) + // 0 -> (0, 1) + // 1 -> (0, 2) + // The equality check will make sure only hashes[0] maps to 0 and hashes[1] maps to 1 hashmap_left.insert_unique(hashes[0], (hashes[0], 1), |(h, _)| *h); + hashmap_left.insert_unique(hashes[0], (hashes[0], 2), |(h, _)| *h); + hashmap_left.insert_unique(hashes[1], (hashes[1], 1), |(h, _)| *h); + hashmap_left.insert_unique(hashes[1], (hashes[1], 2), |(h, _)| *h); let next = vec![2, 0]; @@ -3356,7 +3481,7 @@ mod tests { // Join key column for both join sides let key_column: PhysicalExprRef = Arc::new(Column::new("a", 0)) as _; - let join_hash_map = JoinHashMap::new(hashmap_left, next); + let join_hash_map = JoinHashMapU64::new(hashmap_left, next); let left_keys_values = key_column.evaluate(&left)?.into_array(left.num_rows())?; let right_keys_values = @@ -3372,7 +3497,7 @@ mod tests { &join_hash_map, &[left_keys_values], &[right_keys_values], - false, + NullEquality::NullEqualsNothing, &hashes_buffer, 8192, (0, None), @@ -3389,6 +3514,70 @@ mod tests { Ok(()) } + #[test] + fn join_with_hash_collisions_u32() -> Result<()> { + let mut hashmap_left = HashTable::with_capacity(4); + let left = build_table_i32( + ("a", &vec![10, 20]), + ("x", &vec![100, 200]), + ("y", &vec![200, 300]), + ); + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let hashes_buff = &mut vec![0; left.num_rows()]; + let hashes = create_hashes( + &[Arc::clone(&left.columns()[0])], + &random_state, + hashes_buff, + )?; + + hashmap_left.insert_unique(hashes[0], (hashes[0], 1u32), |(h, _)| *h); + hashmap_left.insert_unique(hashes[0], (hashes[0], 2u32), |(h, _)| *h); + hashmap_left.insert_unique(hashes[1], (hashes[1], 1u32), |(h, _)| *h); + hashmap_left.insert_unique(hashes[1], (hashes[1], 2u32), |(h, _)| *h); + + let next: Vec = vec![2, 0]; + + let right = build_table_i32( + ("a", &vec![10, 20]), + ("b", &vec![0, 0]), + ("c", &vec![30, 40]), + ); + + let key_column: PhysicalExprRef = Arc::new(Column::new("a", 0)) as _; + + let join_hash_map = JoinHashMapU32::new(hashmap_left, next); + + let left_keys_values = key_column.evaluate(&left)?.into_array(left.num_rows())?; + let right_keys_values = + key_column.evaluate(&right)?.into_array(right.num_rows())?; + let mut hashes_buffer = vec![0; right.num_rows()]; + create_hashes( + &[Arc::clone(&right_keys_values)], + &random_state, + &mut hashes_buffer, + )?; + + let (l, r, _) = lookup_join_hashmap( + &join_hash_map, + &[left_keys_values], + &[right_keys_values], + NullEquality::NullEqualsNothing, + &hashes_buffer, + 8192, + (0, None), + )?; + + // We still expect to match rows 0 and 1 on both sides + let left_ids: UInt64Array = vec![0, 1].into(); + let right_ids: UInt32Array = vec![0, 1].into(); + + assert_eq!(left_ids, l); + assert_eq!(right_ids, r); + + Ok(()) + } + #[tokio::test] async fn join_with_duplicated_column_names() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); @@ -3408,7 +3597,13 @@ mod tests { Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Inner, false)?; + let join = join( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3478,7 +3673,14 @@ mod tests { )]; let filter = prepare_join_filter(); - let join = join_with_filter(left, right, on, filter, &JoinType::Inner, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3520,7 +3722,14 @@ mod tests { )]; let filter = prepare_join_filter(); - let join = join_with_filter(left, right, on, filter, &JoinType::Left, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::Left, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3565,7 +3774,14 @@ mod tests { )]; let filter = prepare_join_filter(); - let join = join_with_filter(left, right, on, filter, &JoinType::Right, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::Right, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3609,7 +3825,14 @@ mod tests { )]; let filter = prepare_join_filter(); - let join = join_with_filter(left, right, on, filter, &JoinType::Full, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::Full, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3632,7 +3855,7 @@ mod tests { ]; assert_batches_sorted_eq!(expected, &batches); - // THIS MIGRATION HAULTED DUE TO ISSUE #15312 + // THIS MIGRATION HALTED DUE TO ISSUE #15312 //allow_duplicates! { // assert_snapshot!(batches_to_sort_string(&batches), @r#" // +---+---+---+----+---+---+ @@ -3746,6 +3969,15 @@ mod tests { "| 3 | 7 | 9 | false |", "+----+----+----+-------+", ]; + let expected_right_mark = vec![ + "+----+----+----+-------+", + "| a2 | b2 | c2 | mark |", + "+----+----+----+-------+", + "| 10 | 4 | 70 | true |", + "| 20 | 5 | 80 | true |", + "| 30 | 6 | 90 | false |", + "+----+----+----+-------+", + ]; let test_cases = vec![ (JoinType::Inner, expected_inner), @@ -3757,20 +3989,22 @@ mod tests { (JoinType::RightSemi, expected_right_semi), (JoinType::RightAnti, expected_right_anti), (JoinType::LeftMark, expected_left_mark), + (JoinType::RightMark, expected_right_mark), ]; for (join_type, expected) in test_cases { - let (_, batches) = join_collect_with_partition_mode( + let (_, batches, metrics) = join_collect_with_partition_mode( Arc::clone(&left), Arc::clone(&right), on.clone(), &join_type, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, Arc::clone(&task_ctx), ) .await?; assert_batches_sorted_eq!(expected, &batches); + assert_join_metrics!(metrics, expected.len() - 4); } Ok(()) @@ -3798,7 +4032,13 @@ mod tests { Arc::new(Column::new_with_schema("date", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Inner, false)?; + let join = join( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + )?; let task_ctx = Arc::new(TaskContext::default()); let stream = join.execute(0, task_ctx)?; @@ -3857,7 +4097,7 @@ mod tests { Arc::clone(&right_input) as Arc, on.clone(), &join_type, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); let task_ctx = Arc::new(TaskContext::default()); @@ -3874,7 +4114,7 @@ mod tests { } #[tokio::test] - async fn join_splitted_batch() { + async fn join_split_batch() { let left = build_table( ("a1", &vec![1, 2, 3, 4]), ("b1", &vec![1, 1, 1, 1]), @@ -3971,7 +4211,7 @@ mod tests { Arc::clone(&right), on.clone(), &join_type, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); @@ -3994,10 +4234,7 @@ mod tests { assert_eq!( batches.len(), expected_batch_count, - "expected {} output batches for {} join with batch_size = {}", - expected_batch_count, - join_type, - batch_size + "expected {expected_batch_count} output batches for {join_type} join with batch_size = {batch_size}" ); let expected = match join_type { @@ -4039,6 +4276,7 @@ mod tests { JoinType::RightSemi, JoinType::RightAnti, JoinType::LeftMark, + JoinType::RightMark, ]; for join_type in join_types { @@ -4053,7 +4291,7 @@ mod tests { Arc::clone(&right), on.clone(), &join_type, - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx)?; @@ -4062,12 +4300,12 @@ mod tests { // Asserting that operator-level reservation attempting to overallocate assert_contains!( err.to_string(), - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput" + "Resources exhausted: Additional allocation failed for HashJoinInput with top memory consumers (across reservations) as:\n HashJoinInput" ); assert_contains!( err.to_string(), - "Failed to allocate additional 120 bytes for HashJoinInput" + "Failed to allocate additional 120.0 B for HashJoinInput" ); } @@ -4134,7 +4372,7 @@ mod tests { &join_type, None, PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(1, task_ctx)?; @@ -4143,13 +4381,13 @@ mod tests { // Asserting that stream-level reservation attempting to overallocate assert_contains!( err.to_string(), - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput[1]" + "Resources exhausted: Additional allocation failed for HashJoinInput[1] with top memory consumers (across reservations) as:\n HashJoinInput[1]" ); assert_contains!( err.to_string(), - "Failed to allocate additional 120 bytes for HashJoinInput[1]" + "Failed to allocate additional 120.0 B for HashJoinInput[1]" ); } @@ -4194,8 +4432,15 @@ mod tests { Arc::new(Column::new_with_schema("n2", &right.schema())?) as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches, metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["n1", "n2"]); @@ -4211,6 +4456,8 @@ mod tests { "#); } + assert_join_metrics!(metrics, 3); + Ok(()) } @@ -4226,12 +4473,12 @@ mod tests { Arc::new(Column::new_with_schema("n2", &right.schema())?) as _, )]; - let (_, batches_null_eq) = join_collect( + let (_, batches_null_eq, metrics) = join_collect( Arc::clone(&left), Arc::clone(&right), on.clone(), &JoinType::Inner, - true, + NullEquality::NullEqualsNull, Arc::clone(&task_ctx), ) .await?; @@ -4246,8 +4493,19 @@ mod tests { "#); } - let (_, batches_null_neq) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + assert_join_metrics!(metrics, 1); + + let (_, batches_null_neq, metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; + + assert_join_metrics!(metrics, 0); let expected_null_neq = ["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"]; diff --git a/datafusion/physical-plan/src/joins/hash_join/mod.rs b/datafusion/physical-plan/src/joins/hash_join/mod.rs new file mode 100644 index 0000000000000..7f1e5cae13a3e --- /dev/null +++ b/datafusion/physical-plan/src/joins/hash_join/mod.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`HashJoinExec`] Partitioned Hash Join Operator + +pub use exec::HashJoinExec; + +mod exec; +mod shared_bounds; +mod stream; diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs new file mode 100644 index 0000000000000..25f7a0de31acd --- /dev/null +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -0,0 +1,313 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for shared bounds. Used in dynamic filter pushdown in Hash Joins. +// TODO: include the link to the Dynamic Filter blog post. + +use std::fmt; +use std::sync::Arc; + +use crate::joins::PartitionMode; +use crate::ExecutionPlan; +use crate::ExecutionPlanProperties; + +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::{lit, BinaryExpr, DynamicFilterPhysicalExpr}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; + +use itertools::Itertools; +use parking_lot::Mutex; +use tokio::sync::Barrier; + +/// Represents the minimum and maximum values for a specific column. +/// Used in dynamic filter pushdown to establish value boundaries. +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct ColumnBounds { + /// The minimum value observed for this column + min: ScalarValue, + /// The maximum value observed for this column + max: ScalarValue, +} + +impl ColumnBounds { + pub(crate) fn new(min: ScalarValue, max: ScalarValue) -> Self { + Self { min, max } + } +} + +/// Represents the bounds for all join key columns from a single partition. +/// This contains the min/max values computed from one partition's build-side data. +#[derive(Debug, Clone)] +pub(crate) struct PartitionBounds { + /// Partition identifier for debugging and determinism (not strictly necessary) + partition: usize, + /// Min/max bounds for each join key column in this partition. + /// Index corresponds to the join key expression index. + column_bounds: Vec, +} + +impl PartitionBounds { + pub(crate) fn new(partition: usize, column_bounds: Vec) -> Self { + Self { + partition, + column_bounds, + } + } + + pub(crate) fn len(&self) -> usize { + self.column_bounds.len() + } + + pub(crate) fn get_column_bounds(&self, index: usize) -> Option<&ColumnBounds> { + self.column_bounds.get(index) + } +} + +/// Coordinates dynamic filter bounds collection across multiple partitions +/// +/// This structure ensures that dynamic filters are built with complete information from all +/// relevant partitions before being applied to probe-side scans. Incomplete filters would +/// incorrectly eliminate valid join results. +/// +/// ## Synchronization Strategy +/// +/// 1. Each partition computes bounds from its build-side data +/// 2. Bounds are stored in the shared vector +/// 3. A barrier tracks how many partitions have reported their bounds +/// 4. When the last partition reports, bounds are merged and the filter is updated exactly once +/// +/// ## Partition Counting +/// +/// The `total_partitions` count represents how many times `collect_build_side` will be called: +/// - **CollectLeft**: Number of output partitions (each accesses shared build data) +/// - **Partitioned**: Number of input partitions (each builds independently) +/// +/// ## Thread Safety +/// +/// All fields use a single mutex to ensure correct coordination between concurrent +/// partition executions. +pub(crate) struct SharedBoundsAccumulator { + /// Shared state protected by a single mutex to avoid ordering concerns + inner: Mutex, + barrier: Barrier, + /// Dynamic filter for pushdown to probe side + dynamic_filter: Arc, + /// Right side join expressions needed for creating filter bounds + on_right: Vec, +} + +/// State protected by SharedBoundsAccumulator's mutex +struct SharedBoundsState { + /// Bounds from completed partitions. + /// Each element represents the column bounds computed by one partition. + bounds: Vec, +} + +impl SharedBoundsAccumulator { + /// Creates a new SharedBoundsAccumulator configured for the given partition mode + /// + /// This method calculates how many times `collect_build_side` will be called based on the + /// partition mode's execution pattern. This count is critical for determining when we have + /// complete information from all partitions to build the dynamic filter. + /// + /// ## Partition Mode Execution Patterns + /// + /// - **CollectLeft**: Build side is collected ONCE from partition 0 and shared via `OnceFut` + /// across all output partitions. Each output partition calls `collect_build_side` to access the shared build data. + /// Although this results in multiple invocations, the `report_partition_bounds` function contains deduplication logic to handle them safely. + /// Expected calls = number of output partitions. + /// + /// + /// - **Partitioned**: Each partition independently builds its own hash table by calling + /// `collect_build_side` once. Expected calls = number of build partitions. + /// + /// - **Auto**: Placeholder mode resolved during optimization. Uses 1 as safe default since + /// the actual mode will be determined and a new bounds_accumulator created before execution. + /// + /// ## Why This Matters + /// + /// We cannot build a partial filter from some partitions - it would incorrectly eliminate + /// valid join results. We must wait until we have complete bounds information from ALL + /// relevant partitions before updating the dynamic filter. + pub(crate) fn new_from_partition_mode( + partition_mode: PartitionMode, + left_child: &dyn ExecutionPlan, + right_child: &dyn ExecutionPlan, + dynamic_filter: Arc, + on_right: Vec, + ) -> Self { + // Troubleshooting: If partition counts are incorrect, verify this logic matches + // the actual execution pattern in collect_build_side() + let expected_calls = match partition_mode { + // Each output partition accesses shared build data + PartitionMode::CollectLeft => { + right_child.output_partitioning().partition_count() + } + // Each partition builds its own data + PartitionMode::Partitioned => { + left_child.output_partitioning().partition_count() + } + // Default value, will be resolved during optimization (does not exist once `execute()` is called; will be replaced by one of the other two) + PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), + }; + Self { + inner: Mutex::new(SharedBoundsState { + bounds: Vec::with_capacity(expected_calls), + }), + barrier: Barrier::new(expected_calls), + dynamic_filter, + on_right, + } + } + + /// Create a filter expression from individual partition bounds using OR logic. + /// + /// This creates a filter where each partition's bounds form a conjunction (AND) + /// of column range predicates, and all partitions are combined with OR. + /// + /// For example, with 2 partitions and 2 columns: + /// ((col0 >= p0_min0 AND col0 <= p0_max0 AND col1 >= p0_min1 AND col1 <= p0_max1) + /// OR + /// (col0 >= p1_min0 AND col0 <= p1_max0 AND col1 >= p1_min1 AND col1 <= p1_max1)) + pub(crate) fn create_filter_from_partition_bounds( + &self, + bounds: &[PartitionBounds], + ) -> Result> { + if bounds.is_empty() { + return Ok(lit(true)); + } + + // Create a predicate for each partition + let mut partition_predicates = Vec::with_capacity(bounds.len()); + + for partition_bounds in bounds.iter().sorted_by_key(|b| b.partition) { + // Create range predicates for each join key in this partition + let mut column_predicates = Vec::with_capacity(partition_bounds.len()); + + for (col_idx, right_expr) in self.on_right.iter().enumerate() { + if let Some(column_bounds) = partition_bounds.get_column_bounds(col_idx) { + // Create predicate: col >= min AND col <= max + let min_expr = Arc::new(BinaryExpr::new( + Arc::clone(right_expr), + Operator::GtEq, + lit(column_bounds.min.clone()), + )) as Arc; + let max_expr = Arc::new(BinaryExpr::new( + Arc::clone(right_expr), + Operator::LtEq, + lit(column_bounds.max.clone()), + )) as Arc; + let range_expr = + Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr)) + as Arc; + column_predicates.push(range_expr); + } + } + + // Combine all column predicates for this partition with AND + if !column_predicates.is_empty() { + let partition_predicate = column_predicates + .into_iter() + .reduce(|acc, pred| { + Arc::new(BinaryExpr::new(acc, Operator::And, pred)) + as Arc + }) + .unwrap(); + partition_predicates.push(partition_predicate); + } + } + + // Combine all partition predicates with OR + let combined_predicate = partition_predicates + .into_iter() + .reduce(|acc, pred| { + Arc::new(BinaryExpr::new(acc, Operator::Or, pred)) + as Arc + }) + .unwrap_or_else(|| lit(true)); + + Ok(combined_predicate) + } + + /// Report bounds from a completed partition and update dynamic filter if all partitions are done + /// + /// This method coordinates the dynamic filter updates across all partitions. It stores the + /// bounds from the current partition, increments the completion counter, and when all + /// partitions have reported, creates an OR'd filter from individual partition bounds. + /// + /// This method is async and uses a [`tokio::sync::Barrier`] to wait for all partitions + /// to report their bounds. Once that occurs, the method will resolve for all callers and the + /// dynamic filter will be updated exactly once. + /// + /// # Note + /// + /// As barriers are reusable, it is likely an error to call this method more times than the + /// total number of partitions - as it can lead to pending futures that never resolve. We rely + /// on correct usage from the caller rather than imposing additional checks here. If this is a concern, + /// consider making the resulting future shared so the ready result can be reused. + /// + /// # Arguments + /// * `left_side_partition_id` - The identifier for the **left-side** partition reporting its bounds + /// * `partition_bounds` - The bounds computed by this partition (if any) + /// + /// # Returns + /// * `Result<()>` - Ok if successful, Err if filter update failed + pub(crate) async fn report_partition_bounds( + &self, + left_side_partition_id: usize, + partition_bounds: Option>, + ) -> Result<()> { + // Store bounds in the accumulator - this runs once per partition + if let Some(bounds) = partition_bounds { + let mut guard = self.inner.lock(); + + let should_push = if let Some(last_bound) = guard.bounds.last() { + // In `PartitionMode::CollectLeft`, all streams on the left side share the same partition id (0). + // Since this function can be called multiple times for that same partition, we must deduplicate + // by checking against the last recorded bound. + last_bound.partition != left_side_partition_id + } else { + true + }; + + if should_push { + guard + .bounds + .push(PartitionBounds::new(left_side_partition_id, bounds)); + } + } + + if self.barrier.wait().await.is_leader() { + // All partitions have reported, so we can update the filter + let inner = self.inner.lock(); + if !inner.bounds.is_empty() { + let filter_expr = + self.create_filter_from_partition_bounds(&inner.bounds)?; + self.dynamic_filter.update(filter_expr)?; + } + } + + Ok(()) + } +} + +impl fmt::Debug for SharedBoundsAccumulator { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SharedBoundsAccumulator") + } +} diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs new file mode 100644 index 0000000000000..adc00d9fe75ec --- /dev/null +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -0,0 +1,676 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Stream implementation for Hash Join +//! +//! This module implements [`HashJoinStream`], the streaming engine for +//! [`super::HashJoinExec`]. See comments in [`HashJoinStream`] for more details. + +use std::sync::Arc; +use std::task::Poll; + +use crate::joins::hash_join::exec::JoinLeftData; +use crate::joins::hash_join::shared_bounds::SharedBoundsAccumulator; +use crate::joins::utils::{ + equal_rows_arr, get_final_indices_from_shared_bitmap, OnceFut, +}; +use crate::joins::PartitionMode; +use crate::{ + handle_state, + hash_utils::create_hashes, + joins::join_hash_map::JoinHashMapOffset, + joins::utils::{ + adjust_indices_by_join_type, apply_join_filter_to_indices, + build_batch_empty_build_side, build_batch_from_indices, + need_produce_result_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, + JoinHashMapType, StatefulStreamResult, + }, + RecordBatchStream, SendableRecordBatchStream, +}; + +use arrow::array::{ArrayRef, UInt32Array, UInt64Array}; +use arrow::datatypes::{Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion_common::{ + internal_datafusion_err, internal_err, JoinSide, JoinType, NullEquality, Result, +}; +use datafusion_physical_expr::PhysicalExprRef; + +use ahash::RandomState; +use futures::{ready, Stream, StreamExt}; + +/// Represents build-side of hash join. +pub(super) enum BuildSide { + /// Indicates that build-side not collected yet + Initial(BuildSideInitialState), + /// Indicates that build-side data has been collected + Ready(BuildSideReadyState), +} + +/// Container for BuildSide::Initial related data +pub(super) struct BuildSideInitialState { + /// Future for building hash table from build-side input + pub(super) left_fut: OnceFut, +} + +/// Container for BuildSide::Ready related data +pub(super) struct BuildSideReadyState { + /// Collected build-side data + left_data: Arc, +} + +impl BuildSide { + /// Tries to extract BuildSideInitialState from BuildSide enum. + /// Returns an error if state is not Initial. + fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> { + match self { + BuildSide::Initial(state) => Ok(state), + _ => internal_err!("Expected build side in initial state"), + } + } + + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. + fn try_as_ready(&self) -> Result<&BuildSideReadyState> { + match self { + BuildSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), + } + } + + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. + fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> { + match self { + BuildSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), + } + } +} + +/// Represents state of HashJoinStream +/// +/// Expected state transitions performed by HashJoinStream are: +/// +/// ```text +/// +/// WaitBuildSide +/// │ +/// ▼ +/// ┌─► FetchProbeBatch ───► ExhaustedProbeSide ───► Completed +/// │ │ +/// │ ▼ +/// └─ ProcessProbeBatch +/// +/// ``` +#[derive(Debug, Clone)] +pub(super) enum HashJoinStreamState { + /// Initial state for HashJoinStream indicating that build-side data not collected yet + WaitBuildSide, + /// Waiting for bounds to be reported by all partitions + WaitPartitionBoundsReport, + /// Indicates that build-side has been collected, and stream is ready for fetching probe-side + FetchProbeBatch, + /// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed + ProcessProbeBatch(ProcessProbeBatchState), + /// Indicates that probe-side has been fully processed + ExhaustedProbeSide, + /// Indicates that HashJoinStream execution is completed + Completed, +} + +impl HashJoinStreamState { + /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum. + /// Returns an error if state is not ProcessProbeBatchState. + fn try_as_process_probe_batch_mut(&mut self) -> Result<&mut ProcessProbeBatchState> { + match self { + HashJoinStreamState::ProcessProbeBatch(state) => Ok(state), + _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"), + } + } +} + +/// Container for HashJoinStreamState::ProcessProbeBatch related data +#[derive(Debug, Clone)] +pub(super) struct ProcessProbeBatchState { + /// Current probe-side batch + batch: RecordBatch, + /// Probe-side on expressions values + values: Vec, + /// Starting offset for JoinHashMap lookups + offset: JoinHashMapOffset, + /// Max joined probe-side index from current batch + joined_probe_idx: Option, +} + +impl ProcessProbeBatchState { + fn advance(&mut self, offset: JoinHashMapOffset, joined_probe_idx: Option) { + self.offset = offset; + if joined_probe_idx.is_some() { + self.joined_probe_idx = joined_probe_idx; + } + } +} + +/// [`Stream`] for [`super::HashJoinExec`] that does the actual join. +/// +/// This stream: +/// +/// - Collecting the build side (left input) into a hash map +/// - Iterating over the probe side (right input) in streaming fashion +/// - Looking up matches against the hash table and applying join filters +/// - Producing joined [`RecordBatch`]es incrementally +/// - Emitting unmatched rows for outer/semi/anti joins in the final stage +pub(super) struct HashJoinStream { + /// Partition identifier for debugging and determinism + partition: usize, + /// Input schema + schema: Arc, + /// equijoin columns from the right (probe side) + on_right: Vec, + /// optional join filter + filter: Option, + /// type of the join (left, right, semi, etc) + join_type: JoinType, + /// right (probe) input + right: SendableRecordBatchStream, + /// Random state used for hashing initialization + random_state: RandomState, + /// Metrics + join_metrics: BuildProbeJoinMetrics, + /// Information of index and left / right placement of columns + column_indices: Vec, + /// Defines the null equality for the join. + null_equality: NullEquality, + /// State of the stream + state: HashJoinStreamState, + /// Build side + build_side: BuildSide, + /// Maximum output batch size + batch_size: usize, + /// Scratch space for computing hashes + hashes_buffer: Vec, + /// Specifies whether the right side has an ordering to potentially preserve + right_side_ordered: bool, + /// Shared bounds accumulator for coordinating dynamic filter updates (optional) + bounds_accumulator: Option>, + /// Optional future to signal when bounds have been reported by all partitions + /// and the dynamic filter has been updated + bounds_waiter: Option>, + + /// Partitioning mode to use + mode: PartitionMode, +} + +impl RecordBatchStream for HashJoinStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +/// Executes lookups by hash against JoinHashMap and resolves potential +/// hash collisions. +/// Returns build/probe indices satisfying the equality condition, along with +/// (optional) starting point for next iteration. +/// +/// # Example +/// +/// For `LEFT.b1 = RIGHT.b2`: +/// LEFT (build) Table: +/// ```text +/// a1 b1 c1 +/// 1 1 10 +/// 3 3 30 +/// 5 5 50 +/// 7 7 70 +/// 9 8 90 +/// 11 8 110 +/// 13 10 130 +/// ``` +/// +/// RIGHT (probe) Table: +/// ```text +/// a2 b2 c2 +/// 2 2 20 +/// 4 4 40 +/// 6 6 60 +/// 8 8 80 +/// 10 10 100 +/// 12 10 120 +/// ``` +/// +/// The result is +/// ```text +/// "+----+----+-----+----+----+-----+", +/// "| a1 | b1 | c1 | a2 | b2 | c2 |", +/// "+----+----+-----+----+----+-----+", +/// "| 9 | 8 | 90 | 8 | 8 | 80 |", +/// "| 11 | 8 | 110 | 8 | 8 | 80 |", +/// "| 13 | 10 | 130 | 10 | 10 | 100 |", +/// "| 13 | 10 | 130 | 12 | 10 | 120 |", +/// "+----+----+-----+----+----+-----+" +/// ``` +/// +/// And the result of build and probe indices are: +/// ```text +/// Build indices: 4, 5, 6, 6 +/// Probe indices: 3, 3, 4, 5 +/// ``` +#[allow(clippy::too_many_arguments)] +pub(super) fn lookup_join_hashmap( + build_hashmap: &dyn JoinHashMapType, + build_side_values: &[ArrayRef], + probe_side_values: &[ArrayRef], + null_equality: NullEquality, + hashes_buffer: &[u64], + limit: usize, + offset: JoinHashMapOffset, +) -> Result<(UInt64Array, UInt32Array, Option)> { + let (probe_indices, build_indices, next_offset) = + build_hashmap.get_matched_indices_with_limit_offset(hashes_buffer, limit, offset); + + let build_indices: UInt64Array = build_indices.into(); + let probe_indices: UInt32Array = probe_indices.into(); + + let (build_indices, probe_indices) = equal_rows_arr( + &build_indices, + &probe_indices, + build_side_values, + probe_side_values, + null_equality, + )?; + + Ok((build_indices, probe_indices, next_offset)) +} + +impl HashJoinStream { + #[allow(clippy::too_many_arguments)] + pub(super) fn new( + partition: usize, + schema: Arc, + on_right: Vec, + filter: Option, + join_type: JoinType, + right: SendableRecordBatchStream, + random_state: RandomState, + join_metrics: BuildProbeJoinMetrics, + column_indices: Vec, + null_equality: NullEquality, + state: HashJoinStreamState, + build_side: BuildSide, + batch_size: usize, + hashes_buffer: Vec, + right_side_ordered: bool, + bounds_accumulator: Option>, + mode: PartitionMode, + ) -> Self { + Self { + partition, + schema, + on_right, + filter, + join_type, + right, + random_state, + join_metrics, + column_indices, + null_equality, + state, + build_side, + batch_size, + hashes_buffer, + right_side_ordered, + bounds_accumulator, + bounds_waiter: None, + mode, + } + } + + /// Separate implementation function that unpins the [`HashJoinStream`] so + /// that partial borrows work correctly + fn poll_next_impl( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + loop { + return match self.state { + HashJoinStreamState::WaitBuildSide => { + handle_state!(ready!(self.collect_build_side(cx))) + } + HashJoinStreamState::WaitPartitionBoundsReport => { + handle_state!(ready!(self.wait_for_partition_bounds_report(cx))) + } + HashJoinStreamState::FetchProbeBatch => { + handle_state!(ready!(self.fetch_probe_batch(cx))) + } + HashJoinStreamState::ProcessProbeBatch(_) => { + let poll = handle_state!(self.process_probe_batch()); + self.join_metrics.baseline.record_poll(poll) + } + HashJoinStreamState::ExhaustedProbeSide => { + let poll = handle_state!(self.process_unmatched_build_batch()); + self.join_metrics.baseline.record_poll(poll) + } + HashJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + /// Optional step to wait until bounds have been reported by all partitions. + /// This state is only entered if a bounds accumulator is present. + /// + /// ## Why wait? + /// + /// The dynamic filter is only built once all partitions have reported their bounds. + /// If we do not wait here, the probe-side scan may start before the filter is ready. + /// This can lead to the probe-side scan missing the opportunity to apply the filter + /// and skip reading unnecessary data. + fn wait_for_partition_bounds_report( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + if let Some(ref mut fut) = self.bounds_waiter { + ready!(fut.get_shared(cx))?; + } + self.state = HashJoinStreamState::FetchProbeBatch; + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Collects build-side data by polling `OnceFut` future from initialized build-side + /// + /// Updates build-side to `Ready`, and state to `FetchProbeSide` + fn collect_build_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + let build_timer = self.join_metrics.build_time.timer(); + // build hash table from left (build) side, if not yet done + let left_data = ready!(self + .build_side + .try_as_initial_mut()? + .left_fut + .get_shared(cx))?; + build_timer.done(); + + // Handle dynamic filter bounds accumulation + // + // Dynamic filter coordination between partitions: + // Report bounds to the accumulator which will handle synchronization and filter updates + if let Some(ref bounds_accumulator) = self.bounds_accumulator { + let bounds_accumulator = Arc::clone(bounds_accumulator); + + let left_side_partition_id = match self.mode { + PartitionMode::Partitioned => self.partition, + PartitionMode::CollectLeft => 0, + PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), + }; + + let left_data_bounds = left_data.bounds.clone(); + self.bounds_waiter = Some(OnceFut::new(async move { + bounds_accumulator + .report_partition_bounds(left_side_partition_id, left_data_bounds) + .await + })); + self.state = HashJoinStreamState::WaitPartitionBoundsReport; + } else { + self.state = HashJoinStreamState::FetchProbeBatch; + } + + self.build_side = BuildSide::Ready(BuildSideReadyState { left_data }); + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Fetches next batch from probe-side + /// + /// If non-empty batch has been fetched, updates state to `ProcessProbeBatchState`, + /// otherwise updates state to `ExhaustedProbeSide` + fn fetch_probe_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.right.poll_next_unpin(cx)) { + None => { + self.state = HashJoinStreamState::ExhaustedProbeSide; + } + Some(Ok(batch)) => { + // Precalculate hash values for fetched batch + let keys_values = self + .on_right + .iter() + .map(|c| c.evaluate(&batch)?.into_array(batch.num_rows())) + .collect::>>()?; + + self.hashes_buffer.clear(); + self.hashes_buffer.resize(batch.num_rows(), 0); + create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?; + + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + self.state = + HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { + batch, + values: keys_values, + offset: (0, None), + joined_probe_idx: None, + }); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Joins current probe batch with build-side data and produces batch with matched output + /// + /// Updates state to `FetchProbeBatch` + fn process_probe_batch( + &mut self, + ) -> Result>> { + let state = self.state.try_as_process_probe_batch_mut()?; + let build_side = self.build_side.try_as_ready_mut()?; + + let timer = self.join_metrics.join_time.timer(); + + // if the left side is empty, we can skip the (potentially expensive) join operation + if build_side.left_data.hash_map.is_empty() && self.filter.is_none() { + let result = build_batch_empty_build_side( + &self.schema, + build_side.left_data.batch(), + &state.batch, + &self.column_indices, + self.join_type, + )?; + self.join_metrics.output_batches.add(1); + timer.done(); + + self.state = HashJoinStreamState::FetchProbeBatch; + + return Ok(StatefulStreamResult::Ready(Some(result))); + } + + // get the matched by join keys indices + let (left_indices, right_indices, next_offset) = lookup_join_hashmap( + build_side.left_data.hash_map(), + build_side.left_data.values(), + &state.values, + self.null_equality, + &self.hashes_buffer, + self.batch_size, + state.offset, + )?; + + // apply join filter if exists + let (left_indices, right_indices) = if let Some(filter) = &self.filter { + apply_join_filter_to_indices( + build_side.left_data.batch(), + &state.batch, + left_indices, + right_indices, + filter, + JoinSide::Left, + None, + )? + } else { + (left_indices, right_indices) + }; + + // mark joined left-side indices as visited, if required by join type + if need_produce_result_in_final(self.join_type) { + let mut bitmap = build_side.left_data.visited_indices_bitmap().lock(); + left_indices.iter().flatten().for_each(|x| { + bitmap.set_bit(x as usize, true); + }); + } + + // The goals of index alignment for different join types are: + // + // 1) Right & FullJoin -- to append all missing probe-side indices between + // previous (excluding) and current joined indices. + // 2) SemiJoin -- deduplicate probe indices in range between previous + // (excluding) and current joined indices. + // 3) AntiJoin -- return only missing indices in range between + // previous and current joined indices. + // Inclusion/exclusion of the indices themselves don't matter + // + // As a summary -- alignment range can be produced based only on + // joined (matched with filters applied) probe side indices, excluding starting one + // (left from previous iteration). + + // if any rows have been joined -- get last joined probe-side (right) row + // it's important that index counts as "joined" after hash collisions checks + // and join filters applied. + let last_joined_right_idx = match right_indices.len() { + 0 => None, + n => Some(right_indices.value(n - 1) as usize), + }; + + // Calculate range and perform alignment. + // In case probe batch has been processed -- align all remaining rows. + let index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1); + let index_alignment_range_end = if next_offset.is_none() { + state.batch.num_rows() + } else { + last_joined_right_idx.map_or(0, |v| v + 1) + }; + + let (left_indices, right_indices) = adjust_indices_by_join_type( + left_indices, + right_indices, + index_alignment_range_start..index_alignment_range_end, + self.join_type, + self.right_side_ordered, + )?; + + let result = if self.join_type == JoinType::RightMark { + build_batch_from_indices( + &self.schema, + &state.batch, + build_side.left_data.batch(), + &left_indices, + &right_indices, + &self.column_indices, + JoinSide::Right, + )? + } else { + build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &state.batch, + &left_indices, + &right_indices, + &self.column_indices, + JoinSide::Left, + )? + }; + + self.join_metrics.output_batches.add(1); + timer.done(); + + if next_offset.is_none() { + self.state = HashJoinStreamState::FetchProbeBatch; + } else { + state.advance( + next_offset + .ok_or_else(|| internal_datafusion_err!("unexpected None offset"))?, + last_joined_right_idx, + ) + }; + + Ok(StatefulStreamResult::Ready(Some(result))) + } + + /// Processes unmatched build-side rows for certain join types and produces output batch + /// + /// Updates state to `Completed` + fn process_unmatched_build_batch( + &mut self, + ) -> Result>> { + let timer = self.join_metrics.join_time.timer(); + + if !need_produce_result_in_final(self.join_type) { + self.state = HashJoinStreamState::Completed; + return Ok(StatefulStreamResult::Continue); + } + + let build_side = self.build_side.try_as_ready()?; + if !build_side.left_data.report_probe_completed() { + self.state = HashJoinStreamState::Completed; + return Ok(StatefulStreamResult::Continue); + } + + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = get_final_indices_from_shared_bitmap( + build_side.left_data.visited_indices_bitmap(), + self.join_type, + ); + let empty_right_batch = RecordBatch::new_empty(self.right.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + + if let Ok(ref batch) = result { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + self.join_metrics.output_batches.add(1); + } + timer.done(); + + self.state = HashJoinStreamState::Completed; + + Ok(StatefulStreamResult::Ready(Some(result?))) + } +} + +impl Stream for HashJoinStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.poll_next_impl(cx) + } +} diff --git a/datafusion/physical-plan/src/joins/join_filter.rs b/datafusion/physical-plan/src/joins/join_filter.rs index 0e46a971d90bb..de5df2be55650 100644 --- a/datafusion/physical-plan/src/joins/join_filter.rs +++ b/datafusion/physical-plan/src/joins/join_filter.rs @@ -19,7 +19,7 @@ use crate::joins::utils::ColumnIndex; use arrow::datatypes::SchemaRef; use datafusion_common::JoinSide; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use std::sync::Arc; +use std::{fmt::Display, sync::Arc}; /// Filter applied before join output. Fields are crate-public to allow /// downstream implementations to experiment with custom joins. @@ -33,6 +33,14 @@ pub struct JoinFilter { pub(crate) schema: SchemaRef, } +/// For display in `EXPLAIN` plans, only expression with column names is needed, +/// it output expression like `(col1 + col2) = 0` +impl Display for JoinFilter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.expression.fmt_sql(f) + } +} + impl JoinFilter { /// Creates new JoinFilter pub fn new( diff --git a/datafusion/physical-plan/src/joins/join_hash_map.rs b/datafusion/physical-plan/src/joins/join_hash_map.rs index 7af0aeca0fd68..bdd4bfeeb0fbe 100644 --- a/datafusion/physical-plan/src/joins/join_hash_map.rs +++ b/datafusion/physical-plan/src/joins/join_hash_map.rs @@ -20,7 +20,7 @@ //! ["on" values] to a list of indices with this key's value. use std::fmt::{self, Debug}; -use std::ops::IndexMut; +use std::ops::Sub; use hashbrown::hash_table::Entry::{Occupied, Vacant}; use hashbrown::HashTable; @@ -35,7 +35,7 @@ use hashbrown::HashTable; /// During this stage it might be the case that a row is contained the same hashmap value, /// but the values don't match. Those are checked in the `equal_rows_arr` method. /// -/// The indices (values) are stored in a separate chained list stored in the `Vec`. +/// The indices (values) are stored in a separate chained list stored as `Vec` or `Vec`. /// /// The first value (+1) is stored in the hashmap, whereas the next value is stored in array at the position value. /// @@ -87,27 +87,170 @@ use hashbrown::HashTable; /// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1) /// --------------------- /// ``` -pub struct JoinHashMap { +/// +/// Here we have an option between creating a `JoinHashMapType` using `u32` or `u64` indices +/// based on how many rows were being used for indices. +/// +/// At runtime we choose between using `JoinHashMapU32` and `JoinHashMapU64` which oth implement +/// `JoinHashMapType`. +pub trait JoinHashMapType: Send + Sync { + fn extend_zero(&mut self, len: usize); + + fn update_from_iter<'a>( + &mut self, + iter: Box + Send + 'a>, + deleted_offset: usize, + ); + + fn get_matched_indices<'a>( + &self, + iter: Box + 'a>, + deleted_offset: Option, + ) -> (Vec, Vec); + + fn get_matched_indices_with_limit_offset( + &self, + hash_values: &[u64], + limit: usize, + offset: JoinHashMapOffset, + ) -> (Vec, Vec, Option); + + /// Returns `true` if the join hash map contains no entries. + fn is_empty(&self) -> bool; +} + +pub struct JoinHashMapU32 { + // Stores hash value to last row index + map: HashTable<(u64, u32)>, + // Stores indices in chained list data structure + next: Vec, +} + +impl JoinHashMapU32 { + #[cfg(test)] + pub(crate) fn new(map: HashTable<(u64, u32)>, next: Vec) -> Self { + Self { map, next } + } + + pub fn with_capacity(cap: usize) -> Self { + Self { + map: HashTable::with_capacity(cap), + next: vec![0; cap], + } + } +} + +impl Debug for JoinHashMapU32 { + fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { + Ok(()) + } +} + +impl JoinHashMapType for JoinHashMapU32 { + fn extend_zero(&mut self, _: usize) {} + + fn update_from_iter<'a>( + &mut self, + iter: Box + Send + 'a>, + deleted_offset: usize, + ) { + update_from_iter::(&mut self.map, &mut self.next, iter, deleted_offset); + } + + fn get_matched_indices<'a>( + &self, + iter: Box + 'a>, + deleted_offset: Option, + ) -> (Vec, Vec) { + get_matched_indices::(&self.map, &self.next, iter, deleted_offset) + } + + fn get_matched_indices_with_limit_offset( + &self, + hash_values: &[u64], + limit: usize, + offset: JoinHashMapOffset, + ) -> (Vec, Vec, Option) { + get_matched_indices_with_limit_offset::( + &self.map, + &self.next, + hash_values, + limit, + offset, + ) + } + + fn is_empty(&self) -> bool { + self.map.is_empty() + } +} + +pub struct JoinHashMapU64 { // Stores hash value to last row index map: HashTable<(u64, u64)>, // Stores indices in chained list data structure next: Vec, } -impl JoinHashMap { +impl JoinHashMapU64 { #[cfg(test)] pub(crate) fn new(map: HashTable<(u64, u64)>, next: Vec) -> Self { Self { map, next } } - pub(crate) fn with_capacity(capacity: usize) -> Self { - JoinHashMap { - map: HashTable::with_capacity(capacity), - next: vec![0; capacity], + pub fn with_capacity(cap: usize) -> Self { + Self { + map: HashTable::with_capacity(cap), + next: vec![0; cap], } } } +impl Debug for JoinHashMapU64 { + fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { + Ok(()) + } +} + +impl JoinHashMapType for JoinHashMapU64 { + fn extend_zero(&mut self, _: usize) {} + + fn update_from_iter<'a>( + &mut self, + iter: Box + Send + 'a>, + deleted_offset: usize, + ) { + update_from_iter::(&mut self.map, &mut self.next, iter, deleted_offset); + } + + fn get_matched_indices<'a>( + &self, + iter: Box + 'a>, + deleted_offset: Option, + ) -> (Vec, Vec) { + get_matched_indices::(&self.map, &self.next, iter, deleted_offset) + } + + fn get_matched_indices_with_limit_offset( + &self, + hash_values: &[u64], + limit: usize, + offset: JoinHashMapOffset, + ) -> (Vec, Vec, Option) { + get_matched_indices_with_limit_offset::( + &self.map, + &self.next, + hash_values, + limit, + offset, + ) + } + + fn is_empty(&self) -> bool { + self.map.is_empty() + } +} + // Type of offsets for obtaining indices from JoinHashMap. pub(crate) type JoinHashMapOffset = (usize, Option); @@ -115,233 +258,198 @@ pub(crate) type JoinHashMapOffset = (usize, Option); // Early returns in case of reaching output tuples limit. macro_rules! chain_traverse { ( - $input_indices:ident, $match_indices:ident, $hash_values:ident, $next_chain:ident, - $input_idx:ident, $chain_idx:ident, $deleted_offset:ident, $remaining_output:ident - ) => { - let mut i = $chain_idx - 1; + $input_indices:ident, $match_indices:ident, + $hash_values:ident, $next_chain:ident, + $input_idx:ident, $chain_idx:ident, $remaining_output:ident, $one:ident, $zero:ident + ) => {{ + // now `one` and `zero` are in scope from the outer function + let mut match_row_idx = $chain_idx - $one; loop { - let match_row_idx = if let Some(offset) = $deleted_offset { - // This arguments means that we prune the next index way before here. - if i < offset as u64 { - // End of the list due to pruning - break; - } - i - offset as u64 - } else { - i - }; - $match_indices.push(match_row_idx); + $match_indices.push(match_row_idx.into()); $input_indices.push($input_idx as u32); $remaining_output -= 1; - // Follow the chain to get the next index value - let next = $next_chain[match_row_idx as usize]; + + let next = $next_chain[match_row_idx.into() as usize]; if $remaining_output == 0 { - // In case current input index is the last, and no more chain values left - // returning None as whole input has been scanned - let next_offset = if $input_idx == $hash_values.len() - 1 && next == 0 { + // we compare against `zero` (of type T) here too + let next_offset = if $input_idx == $hash_values.len() - 1 && next == $zero + { None } else { - Some(($input_idx, Some(next))) + Some(($input_idx, Some(next.into()))) }; return ($input_indices, $match_indices, next_offset); } - if next == 0 { - // end of list + if next == $zero { break; } - i = next - 1; + match_row_idx = next - $one; } - }; + }}; } -// Trait defining methods that must be implemented by a hash map type to be used for joins. -pub trait JoinHashMapType { - /// The type of list used to store the next list - type NextType: IndexMut; - /// Extend with zero - fn extend_zero(&mut self, len: usize); - /// Returns mutable references to the hash map and the next. - fn get_mut(&mut self) -> (&mut HashTable<(u64, u64)>, &mut Self::NextType); - /// Returns a reference to the hash map. - fn get_map(&self) -> &HashTable<(u64, u64)>; - /// Returns a reference to the next. - fn get_list(&self) -> &Self::NextType; - - /// Updates hashmap from iterator of row indices & row hashes pairs. - fn update_from_iter<'a>( - &mut self, - iter: impl Iterator, - deleted_offset: usize, - ) { - let (mut_map, mut_list) = self.get_mut(); - for (row, &hash_value) in iter { - let entry = mut_map.entry( - hash_value, - |&(hash, _)| hash_value == hash, - |&(hash, _)| hash, - ); +pub fn update_from_iter<'a, T>( + map: &mut HashTable<(u64, T)>, + next: &mut [T], + iter: Box + Send + 'a>, + deleted_offset: usize, +) where + T: Copy + TryFrom + PartialOrd, + >::Error: Debug, +{ + for (row, &hash_value) in iter { + let entry = map.entry( + hash_value, + |&(hash, _)| hash_value == hash, + |&(hash, _)| hash, + ); - match entry { - Occupied(mut occupied_entry) => { - // Already exists: add index to next array - let (_, index) = occupied_entry.get_mut(); - let prev_index = *index; - // Store new value inside hashmap - *index = (row + 1) as u64; - // Update chained Vec at `row` with previous value - mut_list[row - deleted_offset] = prev_index; - } - Vacant(vacant_entry) => { - vacant_entry.insert((hash_value, (row + 1) as u64)); - // chained list at `row` is already initialized with 0 - // meaning end of list - } + match entry { + Occupied(mut occupied_entry) => { + // Already exists: add index to next array + let (_, index) = occupied_entry.get_mut(); + let prev_index = *index; + // Store new value inside hashmap + *index = T::try_from(row + 1).unwrap(); + // Update chained Vec at `row` with previous value + next[row - deleted_offset] = prev_index; + } + Vacant(vacant_entry) => { + vacant_entry.insert((hash_value, T::try_from(row + 1).unwrap())); } } } +} - /// Returns all pairs of row indices matched by hash. - /// - /// This method only compares hashes, so additional further check for actual values - /// equality may be required. - fn get_matched_indices<'a>( - &self, - iter: impl Iterator, - deleted_offset: Option, - ) -> (Vec, Vec) { - let mut input_indices = vec![]; - let mut match_indices = vec![]; - - let hash_map = self.get_map(); - let next_chain = self.get_list(); - for (row_idx, hash_value) in iter { - // Get the hash and find it in the index - if let Some((_, index)) = - hash_map.find(*hash_value, |(hash, _)| *hash_value == *hash) - { - let mut i = *index - 1; - loop { - let match_row_idx = if let Some(offset) = deleted_offset { - // This arguments means that we prune the next index way before here. - if i < offset as u64 { - // End of the list due to pruning - break; - } - i - offset as u64 - } else { - i - }; - match_indices.push(match_row_idx); - input_indices.push(row_idx as u32); - // Follow the chain to get the next index value - let next = next_chain[match_row_idx as usize]; - if next == 0 { - // end of list +pub fn get_matched_indices<'a, T>( + map: &HashTable<(u64, T)>, + next: &[T], + iter: Box + 'a>, + deleted_offset: Option, +) -> (Vec, Vec) +where + T: Copy + TryFrom + PartialOrd + Into + Sub, + >::Error: Debug, +{ + let mut input_indices = vec![]; + let mut match_indices = vec![]; + let zero = T::try_from(0).unwrap(); + let one = T::try_from(1).unwrap(); + + for (row_idx, hash_value) in iter { + // Get the hash and find it in the index + if let Some((_, index)) = map.find(*hash_value, |(hash, _)| *hash_value == *hash) + { + let mut i = *index - one; + loop { + let match_row_idx = if let Some(offset) = deleted_offset { + let offset = T::try_from(offset).unwrap(); + // This arguments means that we prune the next index way before here. + if i < offset { + // End of the list due to pruning break; } - i = next - 1; + i - offset + } else { + i + }; + match_indices.push(match_row_idx.into()); + input_indices.push(row_idx as u32); + // Follow the chain to get the next index value + let next_chain = next[match_row_idx.into() as usize]; + if next_chain == zero { + // end of list + break; } + i = next_chain - one; } } - - (input_indices, match_indices) } - /// Matches hashes with taking limit and offset into account. - /// Returns pairs of matched indices along with the starting point for next - /// matching iteration (`None` if limit has not been reached). - /// - /// This method only compares hashes, so additional further check for actual values - /// equality may be required. - fn get_matched_indices_with_limit_offset( - &self, - hash_values: &[u64], - deleted_offset: Option, - limit: usize, - offset: JoinHashMapOffset, - ) -> (Vec, Vec, Option) { - let mut input_indices = vec![]; - let mut match_indices = vec![]; - - let mut remaining_output = limit; - - let hash_map: &HashTable<(u64, u64)> = self.get_map(); - let next_chain = self.get_list(); - - // Calculate initial `hash_values` index before iterating - let to_skip = match offset { - // None `initial_next_idx` indicates that `initial_idx` processing has'n been started - (initial_idx, None) => initial_idx, - // Zero `initial_next_idx` indicates that `initial_idx` has been processed during - // previous iteration, and it should be skipped - (initial_idx, Some(0)) => initial_idx + 1, - // Otherwise, process remaining `initial_idx` matches by traversing `next_chain`, - // to start with the next index - (initial_idx, Some(initial_next_idx)) => { - chain_traverse!( - input_indices, - match_indices, - hash_values, - next_chain, - initial_idx, - initial_next_idx, - deleted_offset, - remaining_output - ); - - initial_idx + 1 - } - }; - - let mut row_idx = to_skip; - for hash_value in &hash_values[to_skip..] { - if let Some((_, index)) = - hash_map.find(*hash_value, |(hash, _)| *hash_value == *hash) - { - chain_traverse!( - input_indices, - match_indices, - hash_values, - next_chain, - row_idx, - index, - deleted_offset, - remaining_output - ); - } - row_idx += 1; - } - - (input_indices, match_indices, None) - } + (input_indices, match_indices) } -/// Implementation of `JoinHashMapType` for `JoinHashMap`. -impl JoinHashMapType for JoinHashMap { - type NextType = Vec; - - // Void implementation - fn extend_zero(&mut self, _: usize) {} +pub fn get_matched_indices_with_limit_offset( + map: &HashTable<(u64, T)>, + next_chain: &[T], + hash_values: &[u64], + limit: usize, + offset: JoinHashMapOffset, +) -> (Vec, Vec, Option) +where + T: Copy + TryFrom + PartialOrd + Into + Sub, + >::Error: Debug, +{ + let mut input_indices = Vec::with_capacity(limit); + let mut match_indices = Vec::with_capacity(limit); + let zero = T::try_from(0).unwrap(); + let one = T::try_from(1).unwrap(); - /// Get mutable references to the hash map and the next. - fn get_mut(&mut self) -> (&mut HashTable<(u64, u64)>, &mut Self::NextType) { - (&mut self.map, &mut self.next) + // Check if hashmap consists of unique values + // If so, we can skip the chain traversal + if map.len() == next_chain.len() { + let start = offset.0; + let end = (start + limit).min(hash_values.len()); + for (i, &hash) in hash_values[start..end].iter().enumerate() { + if let Some((_, idx)) = map.find(hash, |(h, _)| hash == *h) { + input_indices.push(start as u32 + i as u32); + match_indices.push((*idx - one).into()); + } + } + let next_off = if end == hash_values.len() { + None + } else { + Some((end, None)) + }; + return (input_indices, match_indices, next_off); } - /// Get a reference to the hash map. - fn get_map(&self) -> &HashTable<(u64, u64)> { - &self.map - } + let mut remaining_output = limit; - /// Get a reference to the next. - fn get_list(&self) -> &Self::NextType { - &self.next - } -} + // Calculate initial `hash_values` index before iterating + let to_skip = match offset { + // None `initial_next_idx` indicates that `initial_idx` processing has'n been started + (idx, None) => idx, + // Zero `initial_next_idx` indicates that `initial_idx` has been processed during + // previous iteration, and it should be skipped + (idx, Some(0)) => idx + 1, + // Otherwise, process remaining `initial_idx` matches by traversing `next_chain`, + // to start with the next index + (idx, Some(next_idx)) => { + let next_idx: T = T::try_from(next_idx as usize).unwrap(); + chain_traverse!( + input_indices, + match_indices, + hash_values, + next_chain, + idx, + next_idx, + remaining_output, + one, + zero + ); + idx + 1 + } + }; -impl Debug for JoinHashMap { - fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { - Ok(()) + let mut row_idx = to_skip; + for &hash in &hash_values[to_skip..] { + if let Some((_, idx)) = map.find(hash, |(h, _)| hash == *h) { + let idx: T = *idx; + chain_traverse!( + input_indices, + match_indices, + hash_values, + next_chain, + row_idx, + idx, + remaining_output, + one, + zero + ); + } + row_idx += 1; } + (input_indices, match_indices, None) } diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 22a8c0bc798c8..1d36db996434e 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -19,6 +19,7 @@ use arrow::array::BooleanBufferBuilder; pub use cross_join::CrossJoinExec; +use datafusion_physical_expr::PhysicalExprRef; pub use hash_join::HashJoinExec; pub use nested_loop_join::NestedLoopJoinExec; use parking_lot::Mutex; @@ -39,6 +40,11 @@ mod join_hash_map; #[cfg(test)] pub mod test_utils; +/// The on clause of the join, as vector of (left, right) columns. +pub type JoinOn = Vec<(PhysicalExprRef, PhysicalExprRef)>; +/// Reference for JoinOn. +pub type JoinOnRef<'a> = &'a [(PhysicalExprRef, PhysicalExprRef)]; + #[derive(Clone, Copy, Debug, PartialEq, Eq)] /// Hash join Partitioning mode pub enum PartitionMode { diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index cdd2eaeca8997..0974b3a9114ef 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -19,41 +19,44 @@ use std::any::Any; use std::fmt::Formatter; +use std::ops::{BitOr, ControlFlow}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; use super::utils::{ - asymmetric_join_output_partitioning, get_final_indices_from_shared_bitmap, - need_produce_result_in_final, reorder_output_after_swap, swap_join_projection, - BatchSplitter, BatchTransformer, NoopBatchTransformer, StatefulStreamResult, + asymmetric_join_output_partitioning, need_produce_result_in_final, + reorder_output_after_swap, swap_join_projection, }; -use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common::can_project; use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::joins::utils::{ - adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, check_join_is_valid, estimate_join_statistics, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceAsync, OnceFut, + need_produce_right_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, + OnceAsync, OnceFut, }; use crate::joins::SharedBitmapBuilder; -use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricsSet}; use crate::projection::{ try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, ProjectionExec, }; use crate::{ - handle_state, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, - ExecutionPlanProperties, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, + PlanProperties, RecordBatchStream, SendableRecordBatchStream, }; -use arrow::array::{BooleanBufferBuilder, UInt32Array, UInt64Array}; -use arrow::compute::concat_batches; +use arrow::array::{ + new_null_array, Array, BooleanArray, BooleanBufferBuilder, RecordBatchOptions, +}; +use arrow::buffer::BooleanBuffer; +use arrow::compute::{concat_batches, filter, filter_record_batch, not, BatchCoalescer}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use datafusion_common::cast::as_boolean_array; use datafusion_common::{ - exec_datafusion_err, internal_err, project_schema, JoinSide, Result, Statistics, + arrow_err, internal_datafusion_err, internal_err, project_schema, + unwrap_or_internal_err, DataFusionError, JoinSide, Result, ScalarValue, Statistics, }; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; @@ -62,92 +65,101 @@ use datafusion_physical_expr::equivalence::{ join_equivalence_properties, ProjectionMapping, }; -use futures::{ready, Stream, StreamExt, TryStreamExt}; +use futures::{Stream, StreamExt, TryStreamExt}; +use log::debug; use parking_lot::Mutex; -/// Left (build-side) data -struct JoinLeftData { - /// Build-side data collected to single batch - batch: RecordBatch, - /// Shared bitmap builder for visited left indices - bitmap: SharedBitmapBuilder, - /// Counter of running probe-threads, potentially able to update `bitmap` - probe_threads_counter: AtomicUsize, - /// Memory reservation for tracking batch and bitmap - /// Cleared on `JoinLeftData` drop - /// reservation is cleared on Drop - #[expect(dead_code)] - reservation: MemoryReservation, -} - -impl JoinLeftData { - fn new( - batch: RecordBatch, - bitmap: SharedBitmapBuilder, - probe_threads_counter: AtomicUsize, - reservation: MemoryReservation, - ) -> Self { - Self { - batch, - bitmap, - probe_threads_counter, - reservation, - } - } - - fn batch(&self) -> &RecordBatch { - &self.batch - } - - fn bitmap(&self) -> &SharedBitmapBuilder { - &self.bitmap - } - - /// Decrements counter of running threads, and returns `true` - /// if caller is the last running thread - fn report_probe_completed(&self) -> bool { - self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1 - } -} - #[allow(rustdoc::private_intra_doc_links)] -/// NestedLoopJoinExec is build-probe join operator, whose main task is to -/// perform joins without any equijoin conditions in `ON` clause. +/// NestedLoopJoinExec is a build-probe join operator designed for joins that +/// do not have equijoin keys in their `ON` clause. /// -/// Execution consists of following phases: +/// # Execution Flow /// -/// #### 1. Build phase -/// Collecting build-side data in memory, by polling all available data from build-side input. -/// Due to the absence of equijoin conditions, it's not possible to partition build-side data -/// across multiple threads of the operator, so build-side is always collected in a single -/// batch shared across all threads. -/// The operator always considers LEFT input as build-side input, so it's crucial to adjust -/// smaller input to be the LEFT one. Normally this selection is handled by physical optimizer. +/// ```text +/// Incoming right batch +/// Left Side Buffered Batches +/// ┌───────────┐ ┌───────────────┐ +/// │ ┌───────┐ │ │ │ +/// │ │ │ │ │ │ +/// Current Left Row ───▶│ ├───────├─┤──────────┐ │ │ +/// │ │ │ │ │ └───────────────┘ +/// │ │ │ │ │ │ +/// │ │ │ │ │ │ +/// │ └───────┘ │ │ │ +/// │ ┌───────┐ │ │ │ +/// │ │ │ │ │ ┌─────┘ +/// │ │ │ │ │ │ +/// │ │ │ │ │ │ +/// │ │ │ │ │ │ +/// │ │ │ │ │ │ +/// │ └───────┘ │ ▼ ▼ +/// │ ...... │ ┌──────────────────────┐ +/// │ │ │X (Cartesian Product) │ +/// │ │ └──────────┬───────────┘ +/// └───────────┘ │ +/// │ +/// ▼ +/// ┌───────┬───────────────┐ +/// │ │ │ +/// │ │ │ +/// │ │ │ +/// └───────┴───────────────┘ +/// Intermediate Batch +/// (For join predicate evaluation) +/// ``` /// -/// #### 2. Probe phase -/// Sequentially polling batches from the probe-side input and processing them according to the -/// following logic: -/// - apply join filter (`ON` clause) to Cartesian product of probe batch and build side data -/// -- filter evaluation is executed once per build-side data row -/// - update shared bitmap of joined ("visited") build-side row indices, if required -- allows -/// to produce unmatched build-side data in case of e.g. LEFT/FULL JOIN after probing phase -/// completed -/// - perform join index alignment is required -- depending on `JoinType` -/// - produce output join batch +/// The execution follows a two-phase design: /// -/// Probing phase is executed in parallel, according to probe-side input partitioning -- one -/// thread per partition. After probe input is exhausted, each thread **ATTEMPTS** to produce -/// unmatched build-side data. +/// ## 1. Buffering Left Input +/// - The operator eagerly buffers all left-side input batches into memory, +/// util a memory limit is reached. +/// Currently, an out-of-memory error will be thrown if all the left-side input batches +/// cannot fit into memory at once. +/// In the future, it's possible to make this case finish execution. (see +/// 'Memory-limited Execution' section) +/// - The rationale for buffering the left side is that scanning the right side +/// can be expensive (e.g., decoding Parquet files), so buffering more left +/// rows reduces the number of right-side scan passes required. /// -/// #### 3. Producing unmatched build-side data -/// Producing unmatched build-side data as an output batch, after probe input is exhausted. -/// This step is also executed in parallel (once per probe input partition), and to avoid -/// duplicate output of unmatched data (due to shared nature build-side data), each thread -/// "reports" about probe phase completion (which means that "visited" bitmap won't be -/// updated anymore), and only the last thread, reporting about completion, will return output. +/// ## 2. Probing Right Input +/// - Right-side input is streamed batch by batch. +/// - For each right-side batch: +/// - It evaluates the join filter against the full buffered left input. +/// This results in a Cartesian product between the right batch and each +/// left row -- with the join predicate/filter applied -- for each inner +/// loop iteration. +/// - Matched results are accumulated into an output buffer. (see more in +/// `Output Buffering Strategy` section) +/// - This process continues until all right-side input is consumed. /// -/// # Clone / Shared State +/// # Producing unmatched build-side data +/// - For special join types like left/full joins, it's required to also output +/// unmatched pairs. During execution, bitmaps are kept for both left and right +/// sides of the input; they'll be handled by dedicated states in `NLJStream`. +/// - The final output of the left side unmatched rows is handled by a single +/// partition for simplicity, since it only counts a small portion of the +/// execution time. (e.g. if probe side has 10k rows, the final output of +/// unmatched build side only roughly counts for 1/10k of the total time) +/// +/// # Output Buffering Strategy +/// The operator uses an intermediate output buffer to accumulate results. Once +/// the output threshold is reached (currently set to the same value as +/// `batch_size` in the configuration), the results will be eagerly output. +/// +/// # Extra Notes +/// - The operator always considers the **left** side as the build (buffered) side. +/// Therefore, the physical optimizer should assign the smaller input to the left. +/// - The design try to minimize the intermediate data size to approximately +/// 1 batch, for better cache locality and memory efficiency. /// +/// # TODO: Memory-limited Execution +/// If the memory budget is exceeded during left-side buffering, fallback +/// strategies such as streaming left batches and re-scanning the right side +/// may be implemented in the future. +/// +/// Tracking issue: +/// +/// # Clone / Shared State /// Note this structure includes a [`OnceAsync`] that is used to coordinate the /// loading of the left side with the processing in each output stream. /// Therefore it can not be [`Clone`] @@ -161,15 +173,16 @@ pub struct NestedLoopJoinExec { pub(crate) filter: Option, /// How the join is performed pub(crate) join_type: JoinType, - /// The schema once the join is applied + /// The full concatenated schema of left and right children should be distinct from + /// the output schema of the operator join_schema: SchemaRef, /// Future that consumes left input and buffers it in memory /// /// This structure is *shared* across all output streams. /// /// Each output stream waits on the `OnceAsync` to signal the completion of - /// the hash table creation. - inner_table: OnceAsync, + /// the build(left) side data, and buffer them all for later joining. + build_side_data: OnceAsync, /// Information of index and left / right placement of columns column_indices: Vec, /// Projection to apply to the output of the join @@ -210,7 +223,7 @@ impl NestedLoopJoinExec { filter, join_type: *join_type, join_schema, - inner_table: Default::default(), + build_side_data: Default::default(), column_indices, projection, metrics: Default::default(), @@ -260,10 +273,10 @@ impl NestedLoopJoinExec { None, // No on columns in nested loop join &[], - ); + )?; let mut output_partitioning = - asymmetric_join_output_partitioning(left, right, &join_type); + asymmetric_join_output_partitioning(left, right, &join_type)?; let emission_type = if left.boundedness().is_unbounded() { EmissionType::Final @@ -275,7 +288,8 @@ impl NestedLoopJoinExec { | JoinType::LeftSemi | JoinType::RightSemi | JoinType::Right - | JoinType::RightAnti => EmissionType::Incremental, + | JoinType::RightAnti + | JoinType::RightMark => EmissionType::Incremental, // If we need to generate unmatched rows from the *build side*, // we need to emit them at the end. JoinType::Left @@ -305,29 +319,9 @@ impl NestedLoopJoinExec { )) } - /// Returns a vector indicating whether the left and right inputs maintain their order. - /// The first element corresponds to the left input, and the second to the right. - /// - /// The left (build-side) input's order may change, but the right (probe-side) input's - /// order is maintained for INNER, RIGHT, RIGHT ANTI, and RIGHT SEMI joins. - /// - /// Maintaining the right input's order helps optimize the nodes down the pipeline - /// (See [`ExecutionPlan::maintains_input_order`]). - /// - /// This is a separate method because it is also called when computing properties, before - /// a [`NestedLoopJoinExec`] is created. It also takes [`JoinType`] as an argument, as - /// opposed to `Self`, for the same reason. - fn maintains_input_order(join_type: JoinType) -> Vec { - vec![ - false, - matches!( - join_type, - JoinType::Inner - | JoinType::Right - | JoinType::RightAnti - | JoinType::RightSemi - ), - ] + /// This join implementation does not preserve the input order of either side. + fn maintains_input_order(_join_type: JoinType) -> Vec { + vec![false, false] } pub fn contains_projection(&self) -> bool { @@ -355,6 +349,12 @@ impl NestedLoopJoinExec { /// Returns a new `ExecutionPlan` that runs NestedLoopsJoins with the left /// and right inputs swapped. + /// + /// # Notes: + /// + /// This function should be called BEFORE inserting any repartitioning + /// operators on the join's children. Check [`super::HashJoinExec::swap_inputs`] + /// for more details. pub fn swap_inputs(&self) -> Result> { let left = self.left(); let right = self.right(); @@ -379,6 +379,8 @@ impl NestedLoopJoinExec { | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark ) || self.projection.is_some() { Arc::new(new_join) @@ -483,6 +485,13 @@ impl ExecutionPlan for NestedLoopJoinExec { partition: usize, context: Arc, ) -> Result { + if self.left.output_partitioning().partition_count() != 1 { + return internal_err!( + "Invalid NestedLoopJoinExec, the output partition count of the left child must be 1,\ + consider using CoalescePartitionsExec or the EnforceDistribution rule" + ); + } + let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); // Initialization reservation for load of inner table @@ -490,28 +499,21 @@ impl ExecutionPlan for NestedLoopJoinExec { MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]")) .register(context.memory_pool()); - let inner_table = self.inner_table.once(|| { - collect_left_input( - Arc::clone(&self.left), - Arc::clone(&context), + let build_side_data = self.build_side_data.try_once(|| { + let stream = self.left.execute(0, Arc::clone(&context))?; + + Ok(collect_left_input( + stream, join_metrics.clone(), load_reservation, need_produce_result_in_final(self.join_type), self.right().output_partitioning().partition_count(), - ) - }); + )) + })?; let batch_size = context.session_config().batch_size(); - let enforce_batch_size_in_joins = - context.session_config().enforce_batch_size_in_joins(); - - let outer_table = self.right.execute(partition, context)?; - - let indices_cache = (UInt64Array::new_null(0), UInt32Array::new_null(0)); - // Right side has an order and it is maintained during operation. - let right_side_ordered = - self.maintains_input_order()[1] && self.right.output_ordering().is_some(); + let probe_side_data = self.right.execute(partition, context)?; // update column indices to reflect the projection let column_indices_after_projection = match &self.projection { @@ -522,37 +524,16 @@ impl ExecutionPlan for NestedLoopJoinExec { None => self.column_indices.clone(), }; - if enforce_batch_size_in_joins { - Ok(Box::pin(NestedLoopJoinStream { - schema: self.schema(), - filter: self.filter.clone(), - join_type: self.join_type, - outer_table, - inner_table, - column_indices: column_indices_after_projection, - join_metrics, - indices_cache, - right_side_ordered, - state: NestedLoopJoinStreamState::WaitBuildSide, - batch_transformer: BatchSplitter::new(batch_size), - left_data: None, - })) - } else { - Ok(Box::pin(NestedLoopJoinStream { - schema: self.schema(), - filter: self.filter.clone(), - join_type: self.join_type, - outer_table, - inner_table, - column_indices: column_indices_after_projection, - join_metrics, - indices_cache, - right_side_ordered, - state: NestedLoopJoinStreamState::WaitBuildSide, - batch_transformer: NoopBatchTransformer::new(), - left_data: None, - })) - } + Ok(Box::pin(NestedLoopJoinStream::new( + self.schema(), + self.filter.clone(), + self.join_type, + probe_side_data, + build_side_data, + column_indices_after_projection, + join_metrics, + batch_size, + ))) } fn metrics(&self) -> Option { @@ -560,12 +541,19 @@ impl ExecutionPlan for NestedLoopJoinExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } estimate_join_statistics( - Arc::clone(&self.left), - Arc::clone(&self.right), + self.left.partition_statistics(None)?, + self.right.partition_statistics(None)?, vec![], &self.join_type, - &self.join_schema, + &self.schema(), ) } @@ -608,22 +596,66 @@ impl ExecutionPlan for NestedLoopJoinExec { } } +impl EmbeddedProjection for NestedLoopJoinExec { + fn with_projection(&self, projection: Option>) -> Result { + self.with_projection(projection) + } +} + +/// Left (build-side) data +pub(crate) struct JoinLeftData { + /// Build-side data collected to single batch + batch: RecordBatch, + /// Shared bitmap builder for visited left indices + bitmap: SharedBitmapBuilder, + /// Counter of running probe-threads, potentially able to update `bitmap` + probe_threads_counter: AtomicUsize, + /// Memory reservation for tracking batch and bitmap + /// Cleared on `JoinLeftData` drop + /// reservation is cleared on Drop + #[expect(dead_code)] + reservation: MemoryReservation, +} + +impl JoinLeftData { + pub(crate) fn new( + batch: RecordBatch, + bitmap: SharedBitmapBuilder, + probe_threads_counter: AtomicUsize, + reservation: MemoryReservation, + ) -> Self { + Self { + batch, + bitmap, + probe_threads_counter, + reservation, + } + } + + pub(crate) fn batch(&self) -> &RecordBatch { + &self.batch + } + + pub(crate) fn bitmap(&self) -> &SharedBitmapBuilder { + &self.bitmap + } + + /// Decrements counter of running threads, and returns `true` + /// if caller is the last running thread + pub(crate) fn report_probe_completed(&self) -> bool { + self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1 + } +} + /// Asynchronously collect input into a single batch, and creates `JoinLeftData` from it async fn collect_left_input( - input: Arc, - context: Arc, + stream: SendableRecordBatchStream, join_metrics: BuildProbeJoinMetrics, reservation: MemoryReservation, with_visited_left_side: bool, probe_threads_count: usize, ) -> Result { - let schema = input.schema(); - let merge = if input.output_partitioning().partition_count() != 1 { - Arc::new(CoalescePartitionsExec::new(input)) - } else { - input - }; - let stream = merge.execute(0, context)?; + let schema = stream.schema(); // Load all batches and count the rows let (batches, metrics, mut reservation) = stream @@ -668,383 +700,1215 @@ async fn collect_left_input( )) } -/// This enumeration represents various states of the nested loop join algorithm. -#[derive(Debug, Clone)] -enum NestedLoopJoinStreamState { - /// The initial state, indicating that build-side data not collected yet - WaitBuildSide, - /// Indicates that build-side has been collected, and stream is ready for - /// fetching probe-side - FetchProbeBatch, - /// Indicates that a non-empty batch has been fetched from probe-side, and - /// is ready to be processed - ProcessProbeBatch(RecordBatch), - /// Indicates that probe-side has been fully processed - ExhaustedProbeSide, - /// Indicates that NestedLoopJoinStream execution is completed - Completed, -} - -impl NestedLoopJoinStreamState { - /// Tries to extract a `ProcessProbeBatchState` from the - /// `NestedLoopJoinStreamState` enum. Returns an error if state is not - /// `ProcessProbeBatchState`. - fn try_as_process_probe_batch(&mut self) -> Result<&RecordBatch> { - match self { - NestedLoopJoinStreamState::ProcessProbeBatch(state) => Ok(state), - _ => internal_err!("Expected join stream in ProcessProbeBatch state"), - } - } +/// States for join processing. See `poll_next()` comment for more details about +/// state transitions. +#[derive(Debug, Clone, Copy)] +enum NLJState { + BufferingLeft, + FetchingRight, + ProbeRight, + EmitRightUnmatched, + EmitLeftUnmatched, + Done, } - -/// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct NestedLoopJoinStream { - /// Input schema - schema: Arc, +pub(crate) struct NestedLoopJoinStream { + // ======================================================================== + // PROPERTIES: + // Operator's properties that remain constant + // + // Note: The implementation uses the terms left/build-side table and + // right/probe-side table interchangeably. Treating the left side as the + // build side is a convention in DataFusion: the planner always tries to + // swap the smaller table to the left side. + // ======================================================================== + /// Output schema + pub(crate) output_schema: Arc, /// join filter - filter: Option, + pub(crate) join_filter: Option, /// type of the join - join_type: JoinType, - /// the outer table data of the nested loop join - outer_table: SendableRecordBatchStream, - /// the inner table data of the nested loop join - inner_table: OnceFut, - /// Information of index and left / right placement of columns - column_indices: Vec, - // TODO: support null aware equal - // null_equals_null: bool + pub(crate) join_type: JoinType, + /// the probe-side(right) table data of the nested loop join + pub(crate) right_data: SendableRecordBatchStream, + /// the build-side table data of the nested loop join + pub(crate) left_data: OnceFut, + /// Projection to construct the output schema from the left and right tables. + /// Example: + /// - output_schema: ['a', 'c'] + /// - left_schema: ['a', 'b'] + /// - right_schema: ['c'] + /// + /// The column indices would be [(left, 0), (right, 0)] -- taking the left + /// 0th column and right 0th column can construct the output schema. + /// + /// Note there are other columns ('b' in the example) still kept after + /// projection pushdown; this is because they might be used to evaluate + /// the join filter (e.g., `JOIN ON (b+c)>0`). + pub(crate) column_indices: Vec, /// Join execution metrics - join_metrics: BuildProbeJoinMetrics, - /// Cache for join indices calculations - indices_cache: (UInt64Array, UInt32Array), - /// Whether the right side is ordered - right_side_ordered: bool, - /// Current state of the stream - state: NestedLoopJoinStreamState, - /// Transforms the output batch before returning. - batch_transformer: T, - /// Result of the left data future - left_data: Option>, + pub(crate) join_metrics: BuildProbeJoinMetrics, + + /// `batch_size` from configuration + batch_size: usize, + + /// See comments in [`need_produce_right_in_final`] for more detail + should_track_unmatched_right: bool, + + // ======================================================================== + // STATE FLAGS/BUFFERS: + // Fields that hold intermediate data/flags during execution + // ======================================================================== + /// State Tracking + state: NLJState, + /// Output buffer holds the join result to output. It will emit eagerly when + /// the threshold is reached. + output_buffer: Box, + /// See comments in [`NLJState::Done`] for its purpose + handled_empty_output: bool, + + // Buffer(left) side + // ----------------- + /// The current buffered left data to join + buffered_left_data: Option>, + /// Index into the left buffered batch. Used in `ProbeRight` state + left_probe_idx: usize, + /// Index into the left buffered batch. Used in `EmitLeftUnmatched` state + left_emit_idx: usize, + /// Should we go back to `BufferingLeft` state again after `EmitLeftUnmatched` + /// state is over. + left_exhausted: bool, + /// If we can buffer all left data in one pass + /// TODO(now): this is for the (unimplemented) memory-limited execution + #[allow(dead_code)] + left_buffered_in_one_pass: bool, + + // Probe(right) side + // ----------------- + /// The current probe batch to process + current_right_batch: Option, + // For right join, keep track of matched rows in `current_right_batch` + // Constructed when fetching each new incoming right batch in `FetchingRight` state. + current_right_batch_matched: Option, } -/// Creates a Cartesian product of two input batches, preserving the order of the right batch, -/// and applying a join filter if provided. -/// -/// # Example -/// Input: -/// left = [0, 1], right = [0, 1, 2] -/// -/// Output: -/// left_indices = [0, 1, 0, 1, 0, 1], right_indices = [0, 0, 1, 1, 2, 2] -/// -/// Input: -/// left = [0, 1, 2], right = [0, 1, 2, 3], filter = left.a != right.a -/// -/// Output: -/// left_indices = [1, 2, 0, 2, 0, 1, 0, 1, 2], right_indices = [0, 0, 1, 1, 2, 2, 3, 3, 3] -fn build_join_indices( - left_batch: &RecordBatch, - right_batch: &RecordBatch, - filter: Option<&JoinFilter>, - indices_cache: &mut (UInt64Array, UInt32Array), -) -> Result<(UInt64Array, UInt32Array)> { - let left_row_count = left_batch.num_rows(); - let right_row_count = right_batch.num_rows(); - let output_row_count = left_row_count * right_row_count; - - // We always use the same indices before applying the filter, so we can cache them - let (left_indices_cache, right_indices_cache) = indices_cache; - let cached_output_row_count = left_indices_cache.len(); - - let (left_indices, right_indices) = - match output_row_count.cmp(&cached_output_row_count) { - std::cmp::Ordering::Equal => { - // Reuse the cached indices - (left_indices_cache.clone(), right_indices_cache.clone()) - } - std::cmp::Ordering::Less => { - // Left_row_count never changes because it's the build side. The changes to the - // right_row_count can be handled trivially by taking the first output_row_count - // elements of the cache because of how the indices are generated. - // (See the Ordering::Greater match arm) - ( - left_indices_cache.slice(0, output_row_count), - right_indices_cache.slice(0, output_row_count), - ) - } - std::cmp::Ordering::Greater => { - // Rebuild the indices cache +impl Stream for NestedLoopJoinStream { + type Item = Result; - // Produces 0, 1, 2, 0, 1, 2, 0, 1, 2, ... - *left_indices_cache = UInt64Array::from_iter_values( - (0..output_row_count as u64).map(|i| i % left_row_count as u64), - ); + /// See the comments [`NestedLoopJoinExec`] for high-level design ideas. + /// + /// # Implementation + /// + /// This function is the entry point of NLJ operator's state machine + /// transitions. The rough state transition graph is as follow, for more + /// details see the comment in each state's matching arm. + /// + /// ============================ + /// State transition graph: + /// ============================ + /// + /// (start) --> BufferingLeft + /// ---------------------------- + /// BufferingLeft → FetchingRight + /// + /// FetchingRight → ProbeRight (if right batch available) + /// FetchingRight → EmitLeftUnmatched (if right exhausted) + /// + /// ProbeRight → ProbeRight (next left row or after yielding output) + /// ProbeRight → EmitRightUnmatched (for special join types like right join) + /// ProbeRight → FetchingRight (done with the current right batch) + /// + /// EmitRightUnmatched → FetchingRight + /// + /// EmitLeftUnmatched → EmitLeftUnmatched (only process 1 chunk for each + /// iteration) + /// EmitLeftUnmatched → Done (if finished) + /// ---------------------------- + /// Done → (end) + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + loop { + match self.state { + // # NLJState transitions + // --> FetchingRight + // This state will prepare the left side batches, next state + // `FetchingRight` is responsible for preparing a single probe + // side batch, before start joining. + NLJState::BufferingLeft => { + debug!("[NLJState] Entering: {:?}", self.state); + // inside `collect_left_input` (the routine to buffer build + // -side batches), related metrics except build time will be + // updated. + // stop on drop + let build_metric = self.join_metrics.build_time.clone(); + let _build_timer = build_metric.timer(); + + match self.handle_buffering_left(cx) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(poll) => return poll, + } + } - // Produces 0, 0, 0, 1, 1, 1, 2, 2, 2, ... - *right_indices_cache = UInt32Array::from_iter_values( - (0..output_row_count as u32).map(|i| i / left_row_count as u32), - ); + // # NLJState transitions: + // 1. --> ProbeRight + // Start processing the join for the newly fetched right + // batch. + // 2. --> EmitLeftUnmatched: When the right side input is exhausted, (maybe) emit + // unmatched left side rows. + // + // After fetching a new batch from the right side, it will + // process all rows from the buffered left data: + // ```text + // for batch in right_side: + // for row in left_buffer: + // join(batch, row) + // ``` + // Note: the implementation does this step incrementally, + // instead of materializing all intermediate Cartesian products + // at once in memory. + // + // So after the right side input is exhausted, the join phase + // for the current buffered left data is finished. We can go to + // the next `EmitLeftUnmatched` phase to check if there is any + // special handling (e.g., in cases like left join). + NLJState::FetchingRight => { + debug!("[NLJState] Entering: {:?}", self.state); + // stop on drop + let join_metric = self.join_metrics.join_time.clone(); + let _join_timer = join_metric.timer(); + + match self.handle_fetching_right(cx) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(poll) => return poll, + } + } + + // NLJState transitions: + // 1. --> ProbeRight(1) + // If we have already buffered enough output to yield, it + // will first give back control to the parent state machine, + // then resume at the same place. + // 2. --> ProbeRight(2) + // After probing one right batch, and evaluating the + // join filter on (left-row x right-batch), it will advance + // to the next left row, then re-enter the current state and + // continue joining. + // 3. --> FetchRight + // After it has done with the current right batch (to join + // with all rows in the left buffer), it will go to + // FetchRight state to check what to do next. + NLJState::ProbeRight => { + debug!("[NLJState] Entering: {:?}", self.state); + + // stop on drop + let join_metric = self.join_metrics.join_time.clone(); + let _join_timer = join_metric.timer(); + + match self.handle_probe_right() { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(poll) => { + return self.join_metrics.baseline.record_poll(poll) + } + } + } + + // In the `current_right_batch_matched` bitmap, all trues mean + // it has been output by the join. In this state we have to + // output unmatched rows for current right batch (with null + // padding for left relation) + // Precondition: we have checked the join type so that it's + // possible to output right unmatched (e.g. it's right join) + NLJState::EmitRightUnmatched => { + debug!("[NLJState] Entering: {:?}", self.state); + + // stop on drop + let join_metric = self.join_metrics.join_time.clone(); + let _join_timer = join_metric.timer(); + + match self.handle_emit_right_unmatched() { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(poll) => { + return self.join_metrics.baseline.record_poll(poll) + } + } + } - (left_indices_cache.clone(), right_indices_cache.clone()) + // NLJState transitions: + // 1. --> EmitLeftUnmatched(1) + // If we have already buffered enough output to yield, it + // will first give back control to the parent state machine, + // then resume at the same place. + // 2. --> EmitLeftUnmatched(2) + // After processing some unmatched rows, it will re-enter + // the same state, to check if there are any more final + // results to output. + // 3. --> Done + // It has processed all data, go to the final state and ready + // to exit. + // + // TODO: For memory-limited case, go back to `BufferingLeft` + // state again. + NLJState::EmitLeftUnmatched => { + debug!("[NLJState] Entering: {:?}", self.state); + + // stop on drop + let join_metric = self.join_metrics.join_time.clone(); + let _join_timer = join_metric.timer(); + + match self.handle_emit_left_unmatched() { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(poll) => { + return self.join_metrics.baseline.record_poll(poll) + } + } + } + + // The final state and the exit point + NLJState::Done => { + debug!("[NLJState] Entering: {:?}", self.state); + + // stop on drop + let join_metric = self.join_metrics.join_time.clone(); + let _join_timer = join_metric.timer(); + // counting it in join timer due to there might be some + // final resout batches to output in this state + + let poll = self.handle_done(); + return self.join_metrics.baseline.record_poll(poll); + } } - }; + } + } +} - if let Some(filter) = filter { - apply_join_filter_to_indices( - left_batch, - right_batch, - left_indices, - right_indices, - filter, - JoinSide::Left, - ) - } else { - Ok((left_indices, right_indices)) +impl RecordBatchStream for NestedLoopJoinStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.output_schema) } } -impl NestedLoopJoinStream { - fn poll_next_impl( +impl NestedLoopJoinStream { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + schema: Arc, + filter: Option, + join_type: JoinType, + right_data: SendableRecordBatchStream, + left_data: OnceFut, + column_indices: Vec, + join_metrics: BuildProbeJoinMetrics, + batch_size: usize, + ) -> Self { + Self { + output_schema: Arc::clone(&schema), + join_filter: filter, + join_type, + right_data, + column_indices, + left_data, + join_metrics, + buffered_left_data: None, + output_buffer: Box::new(BatchCoalescer::new(schema, batch_size)), + batch_size, + current_right_batch: None, + current_right_batch_matched: None, + state: NLJState::BufferingLeft, + left_probe_idx: 0, + left_emit_idx: 0, + left_exhausted: false, + left_buffered_in_one_pass: true, + handled_empty_output: false, + should_track_unmatched_right: need_produce_right_in_final(join_type), + } + } + + // ==== State handler functions ==== + + /// Handle BufferingLeft state - prepare left side batches + fn handle_buffering_left( &mut self, cx: &mut std::task::Context<'_>, - ) -> Poll>> { - loop { - return match self.state { - NestedLoopJoinStreamState::WaitBuildSide => { - handle_state!(ready!(self.collect_build_side(cx))) - } - NestedLoopJoinStreamState::FetchProbeBatch => { - handle_state!(ready!(self.fetch_probe_batch(cx))) - } - NestedLoopJoinStreamState::ProcessProbeBatch(_) => { - handle_state!(self.process_probe_batch()) - } - NestedLoopJoinStreamState::ExhaustedProbeSide => { - handle_state!(self.process_unmatched_build_batch()) - } - NestedLoopJoinStreamState::Completed => Poll::Ready(None), - }; + ) -> ControlFlow>>> { + match self.left_data.get_shared(cx) { + Poll::Ready(Ok(left_data)) => { + self.buffered_left_data = Some(left_data); + // TODO: implement memory-limited case + self.left_exhausted = true; + self.state = NLJState::FetchingRight; + // Continue to next state immediately + ControlFlow::Continue(()) + } + Poll::Ready(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), + Poll::Pending => ControlFlow::Break(Poll::Pending), } } - fn collect_build_side( + /// Handle FetchingRight state - fetch next right batch and prepare for processing + fn handle_fetching_right( &mut self, cx: &mut std::task::Context<'_>, - ) -> Poll>>> { - let build_timer = self.join_metrics.build_time.timer(); - // build hash table from left (build) side, if not yet done - self.left_data = Some(ready!(self.inner_table.get_shared(cx))?); - build_timer.done(); + ) -> ControlFlow>>> { + match self.right_data.poll_next_unpin(cx) { + Poll::Ready(result) => match result { + Some(Ok(right_batch)) => { + // Update metrics + let right_batch_size = right_batch.num_rows(); + self.join_metrics.input_rows.add(right_batch_size); + self.join_metrics.input_batches.add(1); + + // Skip the empty batch + if right_batch_size == 0 { + return ControlFlow::Continue(()); + } + + self.current_right_batch = Some(right_batch); + + // Prepare right bitmap + if self.should_track_unmatched_right { + let zeroed_buf = BooleanBuffer::new_unset(right_batch_size); + self.current_right_batch_matched = + Some(BooleanArray::new(zeroed_buf, None)); + } + + self.left_probe_idx = 0; + self.state = NLJState::ProbeRight; + ControlFlow::Continue(()) + } + Some(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), + None => { + // Right stream exhausted + self.state = NLJState::EmitLeftUnmatched; + ControlFlow::Continue(()) + } + }, + Poll::Pending => ControlFlow::Break(Poll::Pending), + } + } - self.state = NestedLoopJoinStreamState::FetchProbeBatch; + /// Handle ProbeRight state - process current probe batch + fn handle_probe_right(&mut self) -> ControlFlow>>> { + // Return any completed batches first + if let Some(poll) = self.maybe_flush_ready_batch() { + return ControlFlow::Break(poll); + } - Poll::Ready(Ok(StatefulStreamResult::Continue)) + // Process current probe state + match self.process_probe_batch() { + // State unchanged (ProbeRight) + // Continue probing until we have done joining the + // current right batch with all buffered left rows. + Ok(true) => ControlFlow::Continue(()), + // To next FetchRightState + // We have finished joining + // (cur_right_batch x buffered_left_batches) + Ok(false) => { + // Left exhausted, transition to FetchingRight + self.left_probe_idx = 0; + if self.should_track_unmatched_right { + debug_assert!( + self.current_right_batch_matched.is_some(), + "If it's required to track matched rows in the right input, the right bitmap must be present" + ); + self.state = NLJState::EmitRightUnmatched; + } else { + self.current_right_batch = None; + self.state = NLJState::FetchingRight; + } + ControlFlow::Continue(()) + } + Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), + } } - /// Fetches next batch from probe-side - /// - /// If a non-empty batch has been fetched, updates state to - /// `ProcessProbeBatchState`, otherwise updates state to `ExhaustedProbeSide`. - fn fetch_probe_batch( + /// Handle EmitRightUnmatched state - emit unmatched right rows + fn handle_emit_right_unmatched( &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll>>> { - match ready!(self.outer_table.poll_next_unpin(cx)) { - None => { - self.state = NestedLoopJoinStreamState::ExhaustedProbeSide; + ) -> ControlFlow>>> { + // Return any completed batches first + if let Some(poll) = self.maybe_flush_ready_batch() { + return ControlFlow::Break(poll); + } + + debug_assert!( + self.current_right_batch_matched.is_some() + && self.current_right_batch.is_some(), + "This state is yielding output for unmatched rows in the current right batch, so both the right batch and the bitmap must be present" + ); + + // Construct the result batch for unmatched right rows using a utility function + match self.process_right_unmatched() { + Ok(Some(batch)) => { + match self.output_buffer.push_batch(batch) { + Ok(()) => { + // Processed all in one pass + // cleared inside `process_right_unmatched` + debug_assert!(self.current_right_batch.is_none()); + self.state = NLJState::FetchingRight; + ControlFlow::Continue(()) + } + Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))), + } } - Some(Ok(right_batch)) => { - self.state = NestedLoopJoinStreamState::ProcessProbeBatch(right_batch); + Ok(None) => { + // Processed all in one pass + // cleared inside `process_right_unmatched` + debug_assert!(self.current_right_batch.is_none()); + self.state = NLJState::FetchingRight; + ControlFlow::Continue(()) } - Some(Err(err)) => return Poll::Ready(Err(err)), - }; - - Poll::Ready(Ok(StatefulStreamResult::Continue)) + Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), + } } - /// Joins current probe batch with build-side data and produces batch with - /// matched output, updates state to `FetchProbeBatch`. - fn process_probe_batch( + /// Handle EmitLeftUnmatched state - emit unmatched left rows + fn handle_emit_left_unmatched( &mut self, - ) -> Result>> { - let Some(left_data) = self.left_data.clone() else { - return internal_err!( - "Expected left_data to be Some in ProcessProbeBatch state" - ); - }; - let visited_left_side = left_data.bitmap(); - let batch = self.state.try_as_process_probe_batch()?; - - match self.batch_transformer.next() { - None => { - // Setting up timer & updating input metrics - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - let result = join_left_and_right_batch( - left_data.batch(), - batch, - self.join_type, - self.filter.as_ref(), - &self.column_indices, - &self.schema, - visited_left_side, - &mut self.indices_cache, - self.right_side_ordered, - ); - timer.done(); + ) -> ControlFlow>>> { + // Return any completed batches first + if let Some(poll) = self.maybe_flush_ready_batch() { + return ControlFlow::Break(poll); + } - self.batch_transformer.set_batch(result?); - Ok(StatefulStreamResult::Continue) - } - Some((batch, last)) => { - if last { - self.state = NestedLoopJoinStreamState::FetchProbeBatch; + // Process current unmatched state + match self.process_left_unmatched() { + // State unchanged (EmitLeftUnmatched) + // Continue processing until we have processed all unmatched rows + Ok(true) => ControlFlow::Continue(()), + // To Done state + // We have finished processing all unmatched rows + Ok(false) => match self.output_buffer.finish_buffered_batch() { + Ok(()) => { + self.state = NLJState::Done; + ControlFlow::Continue(()) } + Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))), + }, + Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), + } + } - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - Ok(StatefulStreamResult::Ready(Some(batch))) + /// Handle Done state - final state processing + fn handle_done(&mut self) -> Poll>> { + // Return any remaining completed batches before final termination + if let Some(poll) = self.maybe_flush_ready_batch() { + return poll; + } + + // HACK for the doc test in https://github.com/apache/datafusion/blob/main/datafusion/core/src/dataframe/mod.rs#L1265 + // If this operator directly return `Poll::Ready(None)` + // for empty result, the final result will become an empty + // batch with empty schema, however the expected result + // should be with the expected schema for this operator + if !self.handled_empty_output { + let zero_count = Count::new(); + if *self.join_metrics.baseline.output_rows() == zero_count { + let empty_batch = RecordBatch::new_empty(Arc::clone(&self.output_schema)); + self.handled_empty_output = true; + return Poll::Ready(Some(Ok(empty_batch))); } } + + Poll::Ready(None) } - /// Processes unmatched build-side rows for certain join types and produces - /// output batch, updates state to `Completed`. - fn process_unmatched_build_batch( - &mut self, - ) -> Result>> { - let Some(left_data) = self.left_data.clone() else { - return internal_err!( - "Expected left_data to be Some in ExhaustedProbeSide state" - ); - }; - let visited_left_side = left_data.bitmap(); - if need_produce_result_in_final(self.join_type) { - // At this stage `visited_left_side` won't be updated, so it's - // safe to report about probe completion. - // - // Setting `is_exhausted` / returning None will prevent from - // multiple calls of `report_probe_completed()` - if !left_data.report_probe_completed() { - self.state = NestedLoopJoinStreamState::Completed; - return Ok(StatefulStreamResult::Ready(None)); - }; + // ==== Core logic handling for each state ==== + + /// Returns bool to indicate should it continue probing + /// true -> continue in the same ProbeRight state + /// false -> It has done with the (buffered_left x cur_right_batch), go to + /// next state (ProbeRight) + fn process_probe_batch(&mut self) -> Result { + let left_data = Arc::clone(self.get_left_data()?); + let right_batch = self + .current_right_batch + .as_ref() + .ok_or_else(|| internal_datafusion_err!("Right batch should be available"))? + .clone(); + + // stop probing, the caller will go to the next state + if self.left_probe_idx >= left_data.batch().num_rows() { + return Ok(false); + } - // Only setting up timer, input is exhausted - let timer = self.join_metrics.join_time.timer(); - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = - get_final_indices_from_shared_bitmap(visited_left_side, self.join_type); - let empty_right_batch = RecordBatch::new_empty(self.outer_table.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.state = NestedLoopJoinStreamState::Completed; + // ======== + // Join (l_row x right_batch) + // and push the result into output_buffer + // ======== - // Recording time - if result.is_ok() { - timer.done(); - } + let l_idx = self.left_probe_idx; + let join_batch = + self.process_single_left_row_join(&left_data, &right_batch, l_idx)?; - Ok(StatefulStreamResult::Ready(Some(result?))) - } else { - // end of the join loop - self.state = NestedLoopJoinStreamState::Completed; - Ok(StatefulStreamResult::Ready(None)) + if let Some(batch) = join_batch { + self.output_buffer.push_batch(batch)?; } - } -} -#[allow(clippy::too_many_arguments)] -fn join_left_and_right_batch( - left_batch: &RecordBatch, - right_batch: &RecordBatch, - join_type: JoinType, - filter: Option<&JoinFilter>, - column_indices: &[ColumnIndex], - schema: &Schema, - visited_left_side: &SharedBitmapBuilder, - indices_cache: &mut (UInt64Array, UInt32Array), - right_side_ordered: bool, -) -> Result { - let (left_side, right_side) = - build_join_indices(left_batch, right_batch, filter, indices_cache).map_err( - |e| { - exec_datafusion_err!( - "Fail to build join indices in NestedLoopJoinExec, error: {e}" - ) - }, - )?; + // ==== Prepare for the next iteration ==== - // set the left bitmap - // and only full join need the left bitmap - if need_produce_result_in_final(join_type) { - let mut bitmap = visited_left_side.lock(); - left_side.values().iter().for_each(|x| { - bitmap.set_bit(*x as usize, true); - }); - } - // adjust the two side indices base on the join type - let (left_side, right_side) = adjust_indices_by_join_type( - left_side, - right_side, - 0..right_batch.num_rows(), - join_type, - right_side_ordered, - )?; - - build_batch_from_indices( - schema, - left_batch, - right_batch, - &left_side, - &right_side, - column_indices, - JoinSide::Left, - ) -} + // Advance left cursor + self.left_probe_idx += 1; -impl Stream for NestedLoopJoinStream { - type Item = Result; + // Return true to continue probing + Ok(true) + } - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.poll_next_impl(cx) + /// Process a single left row join with the current right batch. + /// Returns a RecordBatch containing the join results (None if empty) + fn process_single_left_row_join( + &mut self, + left_data: &JoinLeftData, + right_batch: &RecordBatch, + l_index: usize, + ) -> Result> { + let right_row_count = right_batch.num_rows(); + if right_row_count == 0 { + return Ok(None); + } + + let cur_right_bitmap = if let Some(filter) = &self.join_filter { + apply_filter_to_row_join_batch( + left_data.batch(), + l_index, + right_batch, + filter, + )? + } else { + BooleanArray::from(vec![true; right_row_count]) + }; + + self.update_matched_bitmap(l_index, &cur_right_bitmap)?; + + // For the following join types: here we only have to set the left/right + // bitmap, and no need to output result + if matches!( + self.join_type, + JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::LeftMark + | JoinType::RightAnti + | JoinType::RightMark + | JoinType::RightSemi + ) { + return Ok(None); + } + + if cur_right_bitmap.true_count() == 0 { + // If none of the pairs has passed the join predicate/filter + Ok(None) + } else { + // Use the optimized approach similar to build_intermediate_batch_for_single_left_row + let join_batch = build_row_join_batch( + &self.output_schema, + left_data.batch(), + l_index, + right_batch, + Some(cur_right_bitmap), + &self.column_indices, + JoinSide::Left, + )?; + Ok(join_batch) + } + } + + /// Returns bool to indicate should it continue processing unmatched rows + /// true -> continue in the same EmitLeftUnmatched state + /// false -> next state (Done) + fn process_left_unmatched(&mut self) -> Result { + let left_data = self.get_left_data()?; + let left_batch = left_data.batch(); + + // ======== + // Check early return conditions + // ======== + + // Early return if join type can't have unmatched rows + let join_type_no_produce_left = !need_produce_result_in_final(self.join_type); + // Early return if another thread is already processing unmatched rows + let handled_by_other_partition = + self.left_emit_idx == 0 && !left_data.report_probe_completed(); + // Stop processing unmatched rows, the caller will go to the next state + let finished = self.left_emit_idx >= left_batch.num_rows(); + + if join_type_no_produce_left || handled_by_other_partition || finished { + return Ok(false); + } + + // ======== + // Process unmatched rows and push the result into output_buffer + // Each time, the number to process is up to batch size + // ======== + let start_idx = self.left_emit_idx; + let end_idx = std::cmp::min(start_idx + self.batch_size, left_batch.num_rows()); + + if let Some(batch) = + self.process_left_unmatched_range(left_data, start_idx, end_idx)? + { + self.output_buffer.push_batch(batch)?; + } + + // ==== Prepare for the next iteration ==== + self.left_emit_idx = end_idx; + + // Return true to continue processing unmatched rows + Ok(true) + } + + /// Process unmatched rows from the left data within the specified range. + /// Returns a RecordBatch containing the unmatched rows (None if empty). + /// + /// # Arguments + /// * `left_data` - The left side data containing the batch and bitmap + /// * `start_idx` - Start index (inclusive) of the range to process + /// * `end_idx` - End index (exclusive) of the range to process + /// + /// # Safety + /// The caller is responsible for ensuring that `start_idx` and `end_idx` are + /// within valid bounds of the left batch. This function does not perform + /// bounds checking. + fn process_left_unmatched_range( + &self, + left_data: &JoinLeftData, + start_idx: usize, + end_idx: usize, + ) -> Result> { + if start_idx == end_idx { + return Ok(None); + } + + // Slice both left batch, and bitmap to range [start_idx, end_idx) + // The range is bit index (not byte) + let left_batch = left_data.batch(); + let left_batch_sliced = left_batch.slice(start_idx, end_idx - start_idx); + + // Can this be more efficient? + let mut bitmap_sliced = BooleanBufferBuilder::new(end_idx - start_idx); + bitmap_sliced.append_n(end_idx - start_idx, false); + let bitmap = left_data.bitmap().lock(); + for i in start_idx..end_idx { + assert!( + i - start_idx < bitmap_sliced.capacity(), + "DBG: {start_idx}, {end_idx}" + ); + bitmap_sliced.set_bit(i - start_idx, bitmap.get_bit(i)); + } + let bitmap_sliced = BooleanArray::new(bitmap_sliced.finish(), None); + + build_unmatched_batch( + Arc::clone(&self.output_schema), + &left_batch_sliced, + bitmap_sliced, + self.right_data.schema(), + &self.column_indices, + self.join_type, + JoinSide::Left, + ) + } + + /// Process unmatched rows from the current right batch and reset the bitmap. + /// Returns a RecordBatch containing the unmatched right rows (None if empty). + fn process_right_unmatched(&mut self) -> Result> { + // ==== Take current right batch and its bitmap ==== + let right_batch_bitmap: BooleanArray = + std::mem::take(&mut self.current_right_batch_matched).ok_or_else(|| { + internal_datafusion_err!("right bitmap should be available") + })?; + + let right_batch = self.current_right_batch.take(); + let cur_right_batch = unwrap_or_internal_err!(right_batch); + + let left_data = self.get_left_data()?; + let left_schema = left_data.batch().schema(); + + let res = build_unmatched_batch( + Arc::clone(&self.output_schema), + &cur_right_batch, + right_batch_bitmap, + left_schema, + &self.column_indices, + self.join_type, + JoinSide::Right, + ); + + // ==== Clean-up ==== + self.current_right_batch_matched = None; + + res + } + + // ==== Utilities ==== + + /// Get the build-side data of the left input, errors if it's None + fn get_left_data(&self) -> Result<&Arc> { + self.buffered_left_data + .as_ref() + .ok_or_else(|| internal_datafusion_err!("LeftData should be available")) + } + + /// Flush the `output_buffer` if there are batches ready to output + /// None if no result batch ready. + fn maybe_flush_ready_batch(&mut self) -> Option>>> { + if self.output_buffer.has_completed_batch() { + if let Some(batch) = self.output_buffer.next_completed_batch() { + // HACK: this is not part of `BaselineMetrics` yet, so update it + // manually + self.join_metrics.output_batches.add(1); + + return Some(Poll::Ready(Some(Ok(batch)))); + } + } + + None + } + + /// After joining (l_index@left_buffer x current_right_batch), it will result + /// in a bitmap (the same length as current_right_batch) as the join match + /// result. Use this bitmap to update the global bitmap, for special join + /// types like full joins. + /// + /// Example: + /// After joining l_index=1 (1-indexed row in the left buffer), and the + /// current right batch with 3 elements, this function will be called with + /// arguments: l_index = 1, r_matched = [false, false, true] + /// - If the join type is FullJoin, the 1-index in the left bitmap will be + /// set to true, and also the right bitmap will be bitwise-ORed with the + /// input r_matched bitmap. + /// - For join types that don't require output unmatched rows, this + /// function can be a no-op. For inner joins, this function is a no-op; for left + /// joins, only the left bitmap may be updated. + fn update_matched_bitmap( + &mut self, + l_index: usize, + r_matched_bitmap: &BooleanArray, + ) -> Result<()> { + let left_data = self.get_left_data()?; + + // number of successfully joined pairs from (l_index x cur_right_batch) + let joined_len = r_matched_bitmap.true_count(); + + // 1. Maybe update the left bitmap + if need_produce_result_in_final(self.join_type) && (joined_len > 0) { + let mut bitmap = left_data.bitmap().lock(); + bitmap.set_bit(l_index, true); + } + + // 2. Maybe updateh the right bitmap + if self.should_track_unmatched_right { + debug_assert!(self.current_right_batch_matched.is_some()); + // after bit-wise or, it will be put back + let right_bitmap = std::mem::take(&mut self.current_right_batch_matched) + .ok_or_else(|| { + internal_datafusion_err!("right batch's bitmap should be present") + })?; + let (buf, nulls) = right_bitmap.into_parts(); + debug_assert!(nulls.is_none()); + let updated_right_bitmap = buf.bitor(r_matched_bitmap.values()); + + self.current_right_batch_matched = + Some(BooleanArray::new(updated_right_bitmap, None)); + } + + Ok(()) } } -impl RecordBatchStream for NestedLoopJoinStream { - fn schema(&self) -> SchemaRef { - Arc::clone(&self.schema) +// ==== Utilities ==== + +/// Apply the join filter between: +/// (l_index th row in left buffer) x (right batch) +/// Returns a bitmap, with successfully joined indices set to true +fn apply_filter_to_row_join_batch( + left_batch: &RecordBatch, + l_index: usize, + right_batch: &RecordBatch, + filter: &JoinFilter, +) -> Result { + debug_assert!(left_batch.num_rows() != 0 && right_batch.num_rows() != 0); + + let intermediate_batch = if filter.schema.fields().is_empty() { + // If filter is constant (e.g. literal `true`), empty batch can be used + // in the later filter step. + create_record_batch_with_empty_schema( + Arc::new((*filter.schema).clone()), + right_batch.num_rows(), + )? + } else { + build_row_join_batch( + &filter.schema, + left_batch, + l_index, + right_batch, + None, + &filter.column_indices, + JoinSide::Left, + )? + .ok_or_else(|| internal_datafusion_err!("This function assume input batch is not empty, so the intermediate batch can't be empty too"))? + }; + + let filter_result = filter + .expression() + .evaluate(&intermediate_batch)? + .into_array(intermediate_batch.num_rows())?; + let filter_arr = as_boolean_array(&filter_result)?; + + // [Caution] This step has previously introduced bugs + // The filter result is NOT a bitmap; it contains true/false/null values. + // For example, 1 < NULL is evaluated to NULL. Therefore, we must combine (AND) + // the boolean array with its null bitmap to construct a unified bitmap. + let (is_filtered, nulls) = filter_arr.clone().into_parts(); + let bitmap_combined = match nulls { + Some(nulls) => { + let combined = nulls.inner() & &is_filtered; + BooleanArray::new(combined, None) + } + None => BooleanArray::new(is_filtered, None), + }; + + Ok(bitmap_combined) +} + +/// This function performs the following steps: +/// 1. Apply filter to probe-side batch +/// 2. Broadcast the left row (build_side_batch\[build_side_index\]) to the +/// filtered probe-side batch +/// 3. Concat them together according to `col_indices`, and return the result +/// (None if the result is empty) +/// +/// Example: +/// build_side_batch: +/// a +/// ---- +/// 1 +/// 2 +/// 3 +/// +/// # 0 index element in the build_side_batch (that is `1`) will be used +/// build_side_index: 0 +/// +/// probe_side_batch: +/// b +/// ---- +/// 10 +/// 20 +/// 30 +/// 40 +/// +/// # After applying it, only index 1 and 3 elements in probe_side_batch will be +/// # kept +/// probe_side_filter: +/// false +/// true +/// false +/// true +/// +/// +/// # Projections to the build/probe side batch, to construct the output batch +/// col_indices: +/// [(left, 0), (right, 0)] +/// +/// build_side: left +/// +/// ==== +/// Result batch: +/// a b +/// ---- +/// 1 20 +/// 1 40 +fn build_row_join_batch( + output_schema: &Schema, + build_side_batch: &RecordBatch, + build_side_index: usize, + probe_side_batch: &RecordBatch, + probe_side_filter: Option, + // See [`NLJStream`] struct's `column_indices` field for more detail + col_indices: &[ColumnIndex], + // If the build side is left or right, used to interpret the side information + // in `col_indices` + build_side: JoinSide, +) -> Result> { + debug_assert!(build_side != JoinSide::None); + + // TODO(perf): since the output might be projection of right batch, this + // filtering step is more efficient to be done inside the column_index loop + let filtered_probe_batch = if let Some(filter) = probe_side_filter { + &filter_record_batch(probe_side_batch, &filter)? + } else { + probe_side_batch + }; + + if filtered_probe_batch.num_rows() == 0 { + return Ok(None); + } + + // Edge case: downstream operator does not require any columns from this NLJ, + // so allow an empty projection. + // Example: + // SELECT DISTINCT 32 AS col2 + // FROM tab0 AS cor0 + // LEFT OUTER JOIN tab2 AS cor1 + // ON ( NULL ) IS NULL; + if output_schema.fields.is_empty() { + return Ok(Some(create_record_batch_with_empty_schema( + Arc::new(output_schema.clone()), + filtered_probe_batch.num_rows(), + )?)); + } + + let mut columns: Vec> = + Vec::with_capacity(output_schema.fields().len()); + + for column_index in col_indices { + let array = if column_index.side == build_side { + // Broadcast the single build-side row to match the filtered + // probe-side batch length + let original_left_array = build_side_batch.column(column_index.index); + let scalar_value = ScalarValue::try_from_array( + original_left_array.as_ref(), + build_side_index, + )?; + scalar_value.to_array_of_size(filtered_probe_batch.num_rows())? + } else { + // Take the filtered probe-side column using compute::take + Arc::clone(filtered_probe_batch.column(column_index.index)) + }; + + columns.push(array); } + + Ok(Some(RecordBatch::try_new( + Arc::new(output_schema.clone()), + columns, + )?)) } -impl EmbeddedProjection for NestedLoopJoinExec { - fn with_projection(&self, projection: Option>) -> Result { - self.with_projection(projection) +/// Special case for `PlaceHolderRowExec` +/// Minimal example: SELECT 1 WHERE EXISTS (SELECT 1); +// +/// # Return +/// If Some, that's the result batch +/// If None, it's not for this special case. Continue execution. +fn build_unmatched_batch_empty_schema( + output_schema: SchemaRef, + batch_bitmap: &BooleanArray, + // For left/right/full joins, it needs to fill nulls for another side + join_type: JoinType, +) -> Result> { + let result_size = match join_type { + JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftAnti + | JoinType::RightAnti => batch_bitmap.false_count(), + JoinType::LeftSemi | JoinType::RightSemi => batch_bitmap.true_count(), + JoinType::LeftMark | JoinType::RightMark => batch_bitmap.len(), + _ => unreachable!(), + }; + + if output_schema.fields().is_empty() { + Ok(Some(create_record_batch_with_empty_schema( + Arc::clone(&output_schema), + result_size, + )?)) + } else { + Ok(None) + } +} + +/// Creates an empty RecordBatch with a specific row count. +/// This is useful for cases where we need a batch with the correct schema and row count +/// but no actual data columns (e.g., for constant filters). +fn create_record_batch_with_empty_schema( + schema: SchemaRef, + row_count: usize, +) -> Result { + let options = RecordBatchOptions::new() + .with_match_field_names(true) + .with_row_count(Some(row_count)); + + RecordBatch::try_new_with_options(schema, vec![], &options).map_err(|e| { + internal_datafusion_err!("Failed to create empty record batch: {}", e) + }) +} + +/// # Example: +/// batch: +/// a +/// ---- +/// 1 +/// 2 +/// 3 +/// +/// batch_bitmap: +/// ---- +/// false +/// true +/// false +/// +/// another_side_schema: +/// [(b, bool), (c, int32)] +/// +/// join_type: JoinType::Left +/// +/// col_indices: ...(please refer to the comment in `NLJStream::column_indices``) +/// +/// batch_side: right +/// +/// # Walkthrough: +/// +/// This executor is performing a right join, and the currently processed right +/// batch is as above. After joining it with all buffered left rows, the joined +/// entries are marked by the `batch_bitmap`. +/// This method will keep the unmatched indices on the batch side (right), and pad +/// the left side with nulls. The result would be: +/// +/// b c a +/// ------------------------ +/// Null(bool) Null(Int32) 1 +/// Null(bool) Null(Int32) 3 +fn build_unmatched_batch( + output_schema: SchemaRef, + batch: &RecordBatch, + batch_bitmap: BooleanArray, + // For left/right/full joins, it needs to fill nulls for another side + another_side_schema: SchemaRef, + col_indices: &[ColumnIndex], + join_type: JoinType, + batch_side: JoinSide, +) -> Result> { + // Should not call it for inner joins + debug_assert_ne!(join_type, JoinType::Inner); + debug_assert_ne!(batch_side, JoinSide::None); + + // Handle special case (see function comment) + if let Some(batch) = build_unmatched_batch_empty_schema( + Arc::clone(&output_schema), + &batch_bitmap, + join_type, + )? { + return Ok(Some(batch)); + } + + match join_type { + JoinType::Full | JoinType::Right | JoinType::Left => { + if join_type == JoinType::Right { + debug_assert_eq!(batch_side, JoinSide::Right); + } + if join_type == JoinType::Left { + debug_assert_eq!(batch_side, JoinSide::Left); + } + + // 1. Filter the batch with *flipped* bitmap + // 2. Fill left side with nulls + let flipped_bitmap = not(&batch_bitmap)?; + + // create a recordbatch, with left_schema, of only one row of all nulls + let left_null_columns: Vec> = another_side_schema + .fields() + .iter() + .map(|field| new_null_array(field.data_type(), 1)) + .collect(); + + // Hack: If the left schema is not nullable, the full join result + // might contain null, this is only a temporary batch to construct + // such full join result. + let nullable_left_schema = Arc::new(Schema::new( + another_side_schema + .fields() + .iter() + .map(|field| { + (**field).clone().with_nullable(true) + }) + .collect::>(), + )); + let left_null_batch = if nullable_left_schema.fields.is_empty() { + // Left input can be an empty relation, in this case left relation + // won't be used to construct the result batch (i.e. not in `col_indices`) + create_record_batch_with_empty_schema(nullable_left_schema, 0)? + } else { + RecordBatch::try_new(nullable_left_schema, left_null_columns)? + }; + + debug_assert_ne!(batch_side, JoinSide::None); + let opposite_side = batch_side.negate(); + + build_row_join_batch(&output_schema, &left_null_batch, 0, batch, Some(flipped_bitmap), col_indices, opposite_side) + + }, + JoinType::RightSemi | JoinType::RightAnti | JoinType::LeftSemi | JoinType::LeftAnti => { + if matches!(join_type, JoinType::RightSemi | JoinType::RightAnti) { + debug_assert_eq!(batch_side, JoinSide::Right); + } + if matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti) { + debug_assert_eq!(batch_side, JoinSide::Left); + } + + let bitmap = if matches!(join_type, JoinType::LeftSemi | JoinType::RightSemi) { + batch_bitmap.clone() + } else { + not(&batch_bitmap)? + }; + + if bitmap.true_count() == 0 { + return Ok(None); + } + + let mut columns: Vec> = + Vec::with_capacity(output_schema.fields().len()); + + for column_index in col_indices { + debug_assert!(column_index.side == batch_side); + + let col = batch.column(column_index.index); + let filtered_col = filter(col, &bitmap)?; + + columns.push(filtered_col); + } + + Ok(Some(RecordBatch::try_new(Arc::clone(&output_schema), columns)?)) + }, + JoinType::RightMark | JoinType::LeftMark => { + if join_type == JoinType::RightMark { + debug_assert_eq!(batch_side, JoinSide::Right); + } + if join_type == JoinType::LeftMark { + debug_assert_eq!(batch_side, JoinSide::Left); + } + + let mut columns: Vec> = + Vec::with_capacity(output_schema.fields().len()); + + // Hack to deal with the borrow checker + let mut right_batch_bitmap_opt = Some(batch_bitmap); + + for column_index in col_indices { + if column_index.side == batch_side { + let col = batch.column(column_index.index); + + columns.push(Arc::clone(col)); + } else if column_index.side == JoinSide::None { + let right_batch_bitmap = std::mem::take(&mut right_batch_bitmap_opt); + match right_batch_bitmap { + Some(right_batch_bitmap) => {columns.push(Arc::new(right_batch_bitmap))}, + None => unreachable!("Should only be one mark column"), + } + } else { + return internal_err!("Not possible to have this join side for RightMark join"); + } + } + + Ok(Some(RecordBatch::try_new(Arc::clone(&output_schema), columns)?)) + } + _ => internal_err!("If batch is at right side, this function must be handling Full/Right/RightSemi/RightAnti/RightMark joins"), } } #[cfg(test)] pub(crate) mod tests { use super::*; - use crate::test::TestMemoryExec; + use crate::test::{assert_join_metrics, TestMemoryExec}; use crate::{ common, expressions::Column, repartition::RepartitionExec, test::build_table_i32, }; - use arrow::array::Int32Array; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field}; use datafusion_common::test_util::batches_to_sort_string; @@ -1055,6 +1919,7 @@ pub(crate) mod tests { use datafusion_physical_expr::{Partitioning, PhysicalExpr}; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use insta::allow_duplicates; use insta::assert_snapshot; use rstest::rstest; @@ -1081,22 +1946,18 @@ pub(crate) mod tests { vec![batch] }; - let mut source = - TestMemoryExec::try_new(&[batches], Arc::clone(&schema), None).unwrap(); - if !sorted_column_names.is_empty() { - let mut sort_info = LexOrdering::default(); - for name in sorted_column_names { - let index = schema.index_of(name).unwrap(); - let sort_expr = PhysicalSortExpr { - expr: Arc::new(Column::new(name, index)), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }; - sort_info.push(sort_expr); - } - source = source.try_with_sort_information(vec![sort_info]).unwrap(); + let mut sort_info = vec![]; + for name in sorted_column_names { + let index = schema.index_of(name).unwrap(); + let sort_expr = PhysicalSortExpr::new( + Arc::new(Column::new(name, index)), + SortOptions::new(false, false), + ); + sort_info.push(sort_expr); + } + let mut source = TestMemoryExec::try_new(&[batches], schema, None).unwrap(); + if let Some(ordering) = LexOrdering::new(sort_info) { + source = source.try_with_sort_information(vec![ordering]).unwrap(); } Arc::new(TestMemoryExec::update_cache(Arc::new(source))) @@ -1176,7 +2037,7 @@ pub(crate) mod tests { join_type: &JoinType, join_filter: Option, context: Arc, - ) -> Result<(Vec, Vec)> { + ) -> Result<(Vec, Vec, MetricsSet)> { let partition_count = 4; // Redistributing right input @@ -1196,20 +2057,35 @@ pub(crate) mod tests { batches.extend( more_batches .into_iter() + .inspect(|b| { + assert!(b.num_rows() <= context.session_config().batch_size()) + }) .filter(|b| b.num_rows() > 0) .collect::>(), ); } - Ok((columns, batches)) + + let metrics = nested_loop_join.metrics().unwrap(); + + Ok((columns, batches, metrics)) + } + + fn new_task_ctx(batch_size: usize) -> Arc { + let base = TaskContext::default(); + // limit max size of intermediate batch used in nlj to 1 + let cfg = base.session_config().clone().with_batch_size(batch_size); + Arc::new(base.with_session_config(cfg)) } + #[rstest] #[tokio::test] - async fn join_inner_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_inner_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> { + let task_ctx = new_task_ctx(batch_size); + dbg!(&batch_size); let left = build_left_table(); let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches, metrics) = multi_partitioned_join_collect( left, right, &JoinType::Inner, @@ -1217,26 +2093,30 @@ pub(crate) mod tests { task_ctx, ) .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - assert_snapshot!(batches_to_sort_string(&batches), @r#" + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b2 | c2 | +----+----+----+----+----+----+ | 5 | 5 | 50 | 2 | 2 | 80 | +----+----+----+----+----+----+ - "#); + "#)); + + assert_join_metrics!(metrics, 1); Ok(()) } + #[rstest] #[tokio::test] - async fn join_left_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> { + let task_ctx = new_task_ctx(batch_size); let left = build_left_table(); let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches, metrics) = multi_partitioned_join_collect( left, right, &JoinType::Left, @@ -1245,7 +2125,7 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - assert_snapshot!(batches_to_sort_string(&batches), @r#" + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" +----+----+-----+----+----+----+ | a1 | b1 | c1 | a2 | b2 | c2 | +----+----+-----+----+----+----+ @@ -1253,19 +2133,22 @@ pub(crate) mod tests { | 5 | 5 | 50 | 2 | 2 | 80 | | 9 | 8 | 90 | | | | +----+----+-----+----+----+----+ - "#); + "#)); + + assert_join_metrics!(metrics, 3); Ok(()) } + #[rstest] #[tokio::test] - async fn join_right_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> { + let task_ctx = new_task_ctx(batch_size); let left = build_left_table(); let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches, metrics) = multi_partitioned_join_collect( left, right, &JoinType::Right, @@ -1274,7 +2157,7 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - assert_snapshot!(batches_to_sort_string(&batches), @r#" + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" +----+----+----+----+----+-----+ | a1 | b1 | c1 | a2 | b2 | c2 | +----+----+----+----+----+-----+ @@ -1282,19 +2165,22 @@ pub(crate) mod tests { | | | | 12 | 10 | 40 | | 5 | 5 | 50 | 2 | 2 | 80 | +----+----+----+----+----+-----+ - "#); + "#)); + + assert_join_metrics!(metrics, 3); Ok(()) } + #[rstest] #[tokio::test] - async fn join_full_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_full_with_filter(#[values(1, 2, 16)] batch_size: usize) -> Result<()> { + let task_ctx = new_task_ctx(batch_size); let left = build_left_table(); let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches, metrics) = multi_partitioned_join_collect( left, right, &JoinType::Full, @@ -1303,7 +2189,7 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - assert_snapshot!(batches_to_sort_string(&batches), @r#" + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" +----+----+-----+----+----+-----+ | a1 | b1 | c1 | a2 | b2 | c2 | +----+----+-----+----+----+-----+ @@ -1313,19 +2199,24 @@ pub(crate) mod tests { | 5 | 5 | 50 | 2 | 2 | 80 | | 9 | 8 | 90 | | | | +----+----+-----+----+----+-----+ - "#); + "#)); + + assert_join_metrics!(metrics, 5); Ok(()) } + #[rstest] #[tokio::test] - async fn join_left_semi_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_semi_with_filter( + #[values(1, 2, 16)] batch_size: usize, + ) -> Result<()> { + let task_ctx = new_task_ctx(batch_size); let left = build_left_table(); let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches, metrics) = multi_partitioned_join_collect( left, right, &JoinType::LeftSemi, @@ -1334,25 +2225,30 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1"]); - assert_snapshot!(batches_to_sort_string(&batches), @r#" + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" +----+----+----+ | a1 | b1 | c1 | +----+----+----+ | 5 | 5 | 50 | +----+----+----+ - "#); + "#)); + + assert_join_metrics!(metrics, 1); Ok(()) } + #[rstest] #[tokio::test] - async fn join_left_anti_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_anti_with_filter( + #[values(1, 2, 16)] batch_size: usize, + ) -> Result<()> { + let task_ctx = new_task_ctx(batch_size); let left = build_left_table(); let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches, metrics) = multi_partitioned_join_collect( left, right, &JoinType::LeftAnti, @@ -1361,26 +2257,51 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1"]); - assert_snapshot!(batches_to_sort_string(&batches), @r#" + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" +----+----+-----+ | a1 | b1 | c1 | +----+----+-----+ | 11 | 8 | 110 | | 9 | 8 | 90 | +----+----+-----+ - "#); + "#)); + + assert_join_metrics!(metrics, 2); + + Ok(()) + } + #[tokio::test] + async fn join_has_correct_stats() -> Result<()> { + let left = build_left_table(); + let right = build_right_table(); + let nested_loop_join = NestedLoopJoinExec::try_new( + left, + right, + None, + &JoinType::Left, + Some(vec![1, 2]), + )?; + let stats = nested_loop_join.partition_statistics(None)?; + assert_eq!( + nested_loop_join.schema().fields().len(), + stats.column_statistics.len(), + ); + assert_eq!(2, stats.column_statistics.len()); Ok(()) } + #[rstest] #[tokio::test] - async fn join_right_semi_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_semi_with_filter( + #[values(1, 2, 16)] batch_size: usize, + ) -> Result<()> { + let task_ctx = new_task_ctx(batch_size); let left = build_left_table(); let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches, metrics) = multi_partitioned_join_collect( left, right, &JoinType::RightSemi, @@ -1389,25 +2310,30 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a2", "b2", "c2"]); - assert_snapshot!(batches_to_sort_string(&batches), @r#" + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" +----+----+----+ | a2 | b2 | c2 | +----+----+----+ | 2 | 2 | 80 | +----+----+----+ - "#); + "#)); + + assert_join_metrics!(metrics, 1); Ok(()) } + #[rstest] #[tokio::test] - async fn join_right_anti_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_right_anti_with_filter( + #[values(1, 2, 16)] batch_size: usize, + ) -> Result<()> { + let task_ctx = new_task_ctx(batch_size); let left = build_left_table(); let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches, metrics) = multi_partitioned_join_collect( left, right, &JoinType::RightAnti, @@ -1416,26 +2342,31 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a2", "b2", "c2"]); - assert_snapshot!(batches_to_sort_string(&batches), @r#" + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" +----+----+-----+ | a2 | b2 | c2 | +----+----+-----+ | 10 | 10 | 100 | | 12 | 10 | 40 | +----+----+-----+ - "#); + "#)); + + assert_join_metrics!(metrics, 2); Ok(()) } + #[rstest] #[tokio::test] - async fn join_left_mark_with_filter() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); + async fn join_left_mark_with_filter( + #[values(1, 2, 16)] batch_size: usize, + ) -> Result<()> { + let task_ctx = new_task_ctx(batch_size); let left = build_left_table(); let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches, metrics) = multi_partitioned_join_collect( left, right, &JoinType::LeftMark, @@ -1444,7 +2375,7 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); - assert_snapshot!(batches_to_sort_string(&batches), @r#" + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" +----+----+-----+-------+ | a1 | b1 | c1 | mark | +----+----+-----+-------+ @@ -1452,7 +2383,44 @@ pub(crate) mod tests { | 5 | 5 | 50 | true | | 9 | 8 | 90 | false | +----+----+-----+-------+ - "#); + "#)); + + assert_join_metrics!(metrics, 3); + + Ok(()) + } + + #[rstest] + #[tokio::test] + async fn join_right_mark_with_filter( + #[values(1, 2, 16)] batch_size: usize, + ) -> Result<()> { + let task_ctx = new_task_ctx(batch_size); + let left = build_left_table(); + let right = build_right_table(); + + let filter = prepare_join_filter(); + let (columns, batches, metrics) = multi_partitioned_join_collect( + left, + right, + &JoinType::RightMark, + Some(filter), + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]); + + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" + +----+----+-----+-------+ + | a2 | b2 | c2 | mark | + +----+----+-----+-------+ + | 10 | 10 | 100 | false | + | 12 | 10 | 40 | false | + | 2 | 2 | 80 | true | + +----+----+-----+-------+ + "#)); + + assert_join_metrics!(metrics, 3); Ok(()) } @@ -1485,6 +2453,7 @@ pub(crate) mod tests { JoinType::LeftMark, JoinType::RightSemi, JoinType::RightAnti, + JoinType::RightMark, ]; for join_type in join_types { @@ -1506,176 +2475,13 @@ pub(crate) mod tests { assert_contains!( err.to_string(), - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]" + "Resources exhausted: Additional allocation failed for NestedLoopJoinLoad[0] with top memory consumers (across reservations) as:\n NestedLoopJoinLoad[0]" ); } Ok(()) } - fn prepare_mod_join_filter() -> JoinFilter { - let column_indices = vec![ - ColumnIndex { - index: 1, - side: JoinSide::Left, - }, - ColumnIndex { - index: 1, - side: JoinSide::Right, - }, - ]; - let intermediate_schema = Schema::new(vec![ - Field::new("x", DataType::Int32, true), - Field::new("x", DataType::Int32, true), - ]); - - // left.b1 % 3 - let left_mod = Arc::new(BinaryExpr::new( - Arc::new(Column::new("x", 0)), - Operator::Modulo, - Arc::new(Literal::new(ScalarValue::Int32(Some(3)))), - )) as Arc; - // left.b1 % 3 != 0 - let left_filter = Arc::new(BinaryExpr::new( - left_mod, - Operator::NotEq, - Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), - )) as Arc; - - // right.b2 % 5 - let right_mod = Arc::new(BinaryExpr::new( - Arc::new(Column::new("x", 1)), - Operator::Modulo, - Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), - )) as Arc; - // right.b2 % 5 != 0 - let right_filter = Arc::new(BinaryExpr::new( - right_mod, - Operator::NotEq, - Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), - )) as Arc; - // filter = left.b1 % 3 != 0 and right.b2 % 5 != 0 - let filter_expression = - Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter)) - as Arc; - - JoinFilter::new( - filter_expression, - column_indices, - Arc::new(intermediate_schema), - ) - } - - fn generate_columns(num_columns: usize, num_rows: usize) -> Vec> { - let column = (1..=num_rows).map(|x| x as i32).collect(); - vec![column; num_columns] - } - - #[rstest] - #[tokio::test] - async fn join_maintains_right_order( - #[values( - JoinType::Inner, - JoinType::Right, - JoinType::RightAnti, - JoinType::RightSemi - )] - join_type: JoinType, - #[values(1, 100, 1000)] left_batch_size: usize, - #[values(1, 100, 1000)] right_batch_size: usize, - ) -> Result<()> { - let left_columns = generate_columns(3, 1000); - let left = build_table( - ("a1", &left_columns[0]), - ("b1", &left_columns[1]), - ("c1", &left_columns[2]), - Some(left_batch_size), - Vec::new(), - ); - - let right_columns = generate_columns(3, 1000); - let right = build_table( - ("a2", &right_columns[0]), - ("b2", &right_columns[1]), - ("c2", &right_columns[2]), - Some(right_batch_size), - vec!["a2", "b2", "c2"], - ); - - let filter = prepare_mod_join_filter(); - - let nested_loop_join = Arc::new(NestedLoopJoinExec::try_new( - left, - Arc::clone(&right), - Some(filter), - &join_type, - None, - )?) as Arc; - assert_eq!(nested_loop_join.maintains_input_order(), vec![false, true]); - - let right_column_indices = match join_type { - JoinType::Inner | JoinType::Right => vec![3, 4, 5], - JoinType::RightAnti | JoinType::RightSemi => vec![0, 1, 2], - _ => unreachable!(), - }; - - let right_ordering = right.output_ordering().unwrap(); - let join_ordering = nested_loop_join.output_ordering().unwrap(); - for (right, join) in right_ordering.iter().zip(join_ordering.iter()) { - let right_column = right.expr.as_any().downcast_ref::().unwrap(); - let join_column = join.expr.as_any().downcast_ref::().unwrap(); - assert_eq!(join_column.name(), join_column.name()); - assert_eq!( - right_column_indices[right_column.index()], - join_column.index() - ); - assert_eq!(right.options, join.options); - } - - let batches = nested_loop_join - .execute(0, Arc::new(TaskContext::default()))? - .try_collect::>() - .await?; - - // Make sure that the order of the right side is maintained - let mut prev_values = [i32::MIN, i32::MIN, i32::MIN]; - - for (batch_index, batch) in batches.iter().enumerate() { - let columns: Vec<_> = right_column_indices - .iter() - .map(|&i| { - batch - .column(i) - .as_any() - .downcast_ref::() - .unwrap() - }) - .collect(); - - for row in 0..batch.num_rows() { - let current_values = [ - columns[0].value(row), - columns[1].value(row), - columns[2].value(row), - ]; - assert!( - current_values - .into_iter() - .zip(prev_values) - .all(|(current, prev)| current >= prev), - "batch_index: {} row: {} current: {:?}, prev: {:?}", - batch_index, - row, - current_values, - prev_values - ); - prev_values = current_values; - } - } - - Ok(()) - } - /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs deleted file mode 100644 index 716cff939f663..0000000000000 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ /dev/null @@ -1,4794 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines the Sort-Merge join execution plan. -//! A Sort-Merge join plan consumes two sorted children plan and produces -//! joined output by given join type and other options. -//! Sort-Merge join feature is currently experimental. - -use std::any::Any; -use std::cmp::Ordering; -use std::collections::{HashMap, VecDeque}; -use std::fmt::Formatter; -use std::fs::File; -use std::io::BufReader; -use std::mem::size_of; -use std::ops::Range; -use std::pin::Pin; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering::Relaxed; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use crate::execution_plan::{boundedness_from_children, EmissionType}; -use crate::expressions::PhysicalSortExpr; -use crate::joins::utils::{ - build_join_schema, check_join_is_valid, estimate_join_statistics, - reorder_output_after_swap, symmetric_join_output_partitioning, JoinFilter, JoinOn, - JoinOnRef, -}; -use crate::metrics::{ - Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, SpillMetrics, -}; -use crate::projection::{ - join_allows_pushdown, join_table_borders, new_join_children, - physical_to_column_exprs, update_join_on, ProjectionExec, -}; -use crate::spill::spill_manager::SpillManager; -use crate::{ - metrics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, - ExecutionPlanProperties, PhysicalExpr, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, Statistics, -}; - -use arrow::array::{types::UInt64Type, *}; -use arrow::compute::{ - self, concat_batches, filter_record_batch, is_not_null, take, SortOptions, -}; -use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; -use arrow::error::ArrowError; -use arrow::ipc::reader::StreamReader; -use datafusion_common::{ - exec_err, internal_err, not_impl_err, plan_err, DataFusionError, HashSet, JoinSide, - JoinType, Result, -}; -use datafusion_execution::disk_manager::RefCountedTempFile; -use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; -use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_execution::TaskContext; -use datafusion_physical_expr::equivalence::join_equivalence_properties; -use datafusion_physical_expr::PhysicalExprRef; -use datafusion_physical_expr_common::physical_expr::fmt_sql; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; - -use futures::{Stream, StreamExt}; - -/// Join execution plan that executes equi-join predicates on multiple partitions using Sort-Merge -/// join algorithm and applies an optional filter post join. Can be used to join arbitrarily large -/// inputs where one or both of the inputs don't fit in the available memory. -/// -/// # Join Expressions -/// -/// Equi-join predicate (e.g. ` = `) expressions are represented by [`Self::on`]. -/// -/// Non-equality predicates, which can not be pushed down to join inputs (e.g. -/// ` != `) are known as "filter expressions" and are evaluated -/// after the equijoin predicates. They are represented by [`Self::filter`]. These are optional -/// expressions. -/// -/// # Sorting -/// -/// Assumes that both the left and right input to the join are pre-sorted. It is not the -/// responsibility of this execution plan to sort the inputs. -/// -/// # "Streamed" vs "Buffered" -/// -/// The number of record batches of streamed input currently present in the memory will depend -/// on the output batch size of the execution plan. There is no spilling support for streamed input. -/// The comparisons are performed from values of join keys in streamed input with the values of -/// join keys in buffered input. One row in streamed record batch could be matched with multiple rows in -/// buffered input batches. The streamed input is managed through the states in `StreamedState` -/// and streamed input batches are represented by `StreamedBatch`. -/// -/// Buffered input is buffered for all record batches having the same value of join key. -/// If the memory limit increases beyond the specified value and spilling is enabled, -/// buffered batches could be spilled to disk. If spilling is disabled, the execution -/// will fail under the same conditions. Multiple record batches of buffered could currently reside -/// in memory/disk during the execution. The number of buffered batches residing in -/// memory/disk depends on the number of rows of buffered input having the same value -/// of join key as that of streamed input rows currently present in memory. Due to pre-sorted inputs, -/// the algorithm understands when it is not needed anymore, and releases the buffered batches -/// from memory/disk. The buffered input is managed through the states in `BufferedState` -/// and buffered input batches are represented by `BufferedBatch`. -/// -/// Depending on the type of join, left or right input may be selected as streamed or buffered -/// respectively. For example, in a left-outer join, the left execution plan will be selected as -/// streamed input while in a right-outer join, the right execution plan will be selected as the -/// streamed input. -/// -/// Reference for the algorithm: -/// . -/// -/// Helpful short video demonstration: -/// . -#[derive(Debug, Clone)] -pub struct SortMergeJoinExec { - /// Left sorted joining execution plan - pub left: Arc, - /// Right sorting joining execution plan - pub right: Arc, - /// Set of common columns used to join on - pub on: JoinOn, - /// Filters which are applied while finding matching rows - pub filter: Option, - /// How the join is performed - pub join_type: JoinType, - /// The schema once the join is applied - schema: SchemaRef, - /// Execution metrics - metrics: ExecutionPlanMetricsSet, - /// The left SortExpr - left_sort_exprs: LexOrdering, - /// The right SortExpr - right_sort_exprs: LexOrdering, - /// Sort options of join columns used in sorting left and right execution plans - pub sort_options: Vec, - /// If null_equals_null is true, null == null else null != null - pub null_equals_null: bool, - /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, -} - -impl SortMergeJoinExec { - /// Tries to create a new [SortMergeJoinExec]. - /// The inputs are sorted using `sort_options` are applied to the columns in the `on` - /// # Error - /// This function errors when it is not possible to join the left and right sides on keys `on`. - pub fn try_new( - left: Arc, - right: Arc, - on: JoinOn, - filter: Option, - join_type: JoinType, - sort_options: Vec, - null_equals_null: bool, - ) -> Result { - let left_schema = left.schema(); - let right_schema = right.schema(); - - if join_type == JoinType::RightSemi { - return not_impl_err!( - "SortMergeJoinExec does not support JoinType::RightSemi" - ); - } - - check_join_is_valid(&left_schema, &right_schema, &on)?; - if sort_options.len() != on.len() { - return plan_err!( - "Expected number of sort options: {}, actual: {}", - on.len(), - sort_options.len() - ); - } - - let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on - .iter() - .zip(sort_options.iter()) - .map(|((l, r), sort_op)| { - let left = PhysicalSortExpr { - expr: Arc::clone(l), - options: *sort_op, - }; - let right = PhysicalSortExpr { - expr: Arc::clone(r), - options: *sort_op, - }; - (left, right) - }) - .unzip(); - - let schema = - Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); - let cache = - Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on); - Ok(Self { - left, - right, - on, - filter, - join_type, - schema, - metrics: ExecutionPlanMetricsSet::new(), - left_sort_exprs: LexOrdering::new(left_sort_exprs), - right_sort_exprs: LexOrdering::new(right_sort_exprs), - sort_options, - null_equals_null, - cache, - }) - } - - /// Get probe side (e.g streaming side) information for this sort merge join. - /// In current implementation, probe side is determined according to join type. - pub fn probe_side(join_type: &JoinType) -> JoinSide { - // When output schema contains only the right side, probe side is right. - // Otherwise probe side is the left side. - match join_type { - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { - JoinSide::Right - } - JoinType::Inner - | JoinType::Left - | JoinType::Full - | JoinType::LeftAnti - | JoinType::LeftSemi - | JoinType::LeftMark => JoinSide::Left, - } - } - - /// Calculate order preservation flags for this sort merge join. - fn maintains_input_order(join_type: JoinType) -> Vec { - match join_type { - JoinType::Inner => vec![true, false], - JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti - | JoinType::LeftMark => vec![true, false], - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { - vec![false, true] - } - _ => vec![false, false], - } - } - - /// Set of common columns used to join on - pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] { - &self.on - } - - /// Ref to right execution plan - pub fn right(&self) -> &Arc { - &self.right - } - - /// Join type - pub fn join_type(&self) -> JoinType { - self.join_type - } - - /// Ref to left execution plan - pub fn left(&self) -> &Arc { - &self.left - } - - /// Ref to join filter - pub fn filter(&self) -> &Option { - &self.filter - } - - /// Ref to sort options - pub fn sort_options(&self) -> &[SortOptions] { - &self.sort_options - } - - /// Null equals null - pub fn null_equals_null(&self) -> bool { - self.null_equals_null - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - left: &Arc, - right: &Arc, - schema: SchemaRef, - join_type: JoinType, - join_on: JoinOnRef, - ) -> PlanProperties { - // Calculate equivalence properties: - let eq_properties = join_equivalence_properties( - left.equivalence_properties().clone(), - right.equivalence_properties().clone(), - &join_type, - schema, - &Self::maintains_input_order(join_type), - Some(Self::probe_side(&join_type)), - join_on, - ); - - let output_partitioning = - symmetric_join_output_partitioning(left, right, &join_type); - - PlanProperties::new( - eq_properties, - output_partitioning, - EmissionType::Incremental, - boundedness_from_children([left, right]), - ) - } - - pub fn swap_inputs(&self) -> Result> { - let left = self.left(); - let right = self.right(); - let new_join = SortMergeJoinExec::try_new( - Arc::clone(right), - Arc::clone(left), - self.on() - .iter() - .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) - .collect::>(), - self.filter().as_ref().map(JoinFilter::swap), - self.join_type().swap(), - self.sort_options.clone(), - self.null_equals_null, - )?; - - // TODO: OR this condition with having a built-in projection (like - // ordinary hash join) when we support it. - if matches!( - self.join_type(), - JoinType::LeftSemi - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::RightAnti - ) { - Ok(Arc::new(new_join)) - } else { - reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema()) - } - } -} - -impl DisplayAs for SortMergeJoinExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - let on = self - .on - .iter() - .map(|(c1, c2)| format!("({}, {})", c1, c2)) - .collect::>() - .join(", "); - write!( - f, - "SortMergeJoin: join_type={:?}, on=[{}]{}", - self.join_type, - on, - self.filter.as_ref().map_or("".to_string(), |f| format!( - ", filter={}", - f.expression() - )) - ) - } - DisplayFormatType::TreeRender => { - let on = self - .on - .iter() - .map(|(c1, c2)| { - format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref())) - }) - .collect::>() - .join(", "); - - if self.join_type() != JoinType::Inner { - writeln!(f, "join_type={:?}", self.join_type)?; - } - writeln!(f, "on={}", on) - } - } - } -} - -impl ExecutionPlan for SortMergeJoinExec { - fn name(&self) -> &'static str { - "SortMergeJoinExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - &self.cache - } - - fn required_input_distribution(&self) -> Vec { - let (left_expr, right_expr) = self - .on - .iter() - .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) - .unzip(); - vec![ - Distribution::HashPartitioned(left_expr), - Distribution::HashPartitioned(right_expr), - ] - } - - fn required_input_ordering(&self) -> Vec> { - vec![ - Some(LexRequirement::from(self.left_sort_exprs.clone())), - Some(LexRequirement::from(self.right_sort_exprs.clone())), - ] - } - - fn maintains_input_order(&self) -> Vec { - Self::maintains_input_order(self.join_type) - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.left, &self.right] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - match &children[..] { - [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - self.on.clone(), - self.filter.clone(), - self.join_type, - self.sort_options.clone(), - self.null_equals_null, - )?)), - _ => internal_err!("SortMergeJoin wrong number of children"), - } - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - let left_partitions = self.left.output_partitioning().partition_count(); - let right_partitions = self.right.output_partitioning().partition_count(); - if left_partitions != right_partitions { - return internal_err!( - "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\ - consider using RepartitionExec" - ); - } - let (on_left, on_right) = self.on.iter().cloned().unzip(); - let (streamed, buffered, on_streamed, on_buffered) = - if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left { - ( - Arc::clone(&self.left), - Arc::clone(&self.right), - on_left, - on_right, - ) - } else { - ( - Arc::clone(&self.right), - Arc::clone(&self.left), - on_right, - on_left, - ) - }; - - // execute children plans - let streamed = streamed.execute(partition, Arc::clone(&context))?; - let buffered = buffered.execute(partition, Arc::clone(&context))?; - - // create output buffer - let batch_size = context.session_config().batch_size(); - - // create memory reservation - let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]")) - .register(context.memory_pool()); - - // create join stream - Ok(Box::pin(SortMergeJoinStream::try_new( - Arc::clone(&self.schema), - self.sort_options.clone(), - self.null_equals_null, - streamed, - buffered, - on_streamed, - on_buffered, - self.filter.clone(), - self.join_type, - batch_size, - SortMergeJoinMetrics::new(partition, &self.metrics), - reservation, - context.runtime_env(), - )?)) - } - - fn metrics(&self) -> Option { - Some(self.metrics.clone_inner()) - } - - fn statistics(&self) -> Result { - // TODO stats: it is not possible in general to know the output size of joins - // There are some special cases though, for example: - // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` - estimate_join_statistics( - Arc::clone(&self.left), - Arc::clone(&self.right), - self.on.clone(), - &self.join_type, - &self.schema, - ) - } - - /// Tries to swap the projection with its input [`SortMergeJoinExec`]. If it can be done, - /// it returns the new swapped version having the [`SortMergeJoinExec`] as the top plan. - /// Otherwise, it returns None. - fn try_swapping_with_projection( - &self, - projection: &ProjectionExec, - ) -> Result>> { - // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. - let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) - else { - return Ok(None); - }; - - let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( - self.left().schema().fields().len(), - &projection_as_columns, - ); - - if !join_allows_pushdown( - &projection_as_columns, - &self.schema(), - far_right_left_col_ind, - far_left_right_col_ind, - ) { - return Ok(None); - } - - let Some(new_on) = update_join_on( - &projection_as_columns[0..=far_right_left_col_ind as _], - &projection_as_columns[far_left_right_col_ind as _..], - self.on(), - self.left().schema().fields().len(), - ) else { - return Ok(None); - }; - - let (new_left, new_right) = new_join_children( - &projection_as_columns, - far_right_left_col_ind, - far_left_right_col_ind, - self.children()[0], - self.children()[1], - )?; - - Ok(Some(Arc::new(SortMergeJoinExec::try_new( - Arc::new(new_left), - Arc::new(new_right), - new_on, - self.filter.clone(), - self.join_type, - self.sort_options.clone(), - self.null_equals_null, - )?))) - } -} - -/// Metrics for SortMergeJoinExec -#[allow(dead_code)] -struct SortMergeJoinMetrics { - /// Total time for joining probe-side batches to the build-side batches - join_time: metrics::Time, - /// Number of batches consumed by this operator - input_batches: Count, - /// Number of rows consumed by this operator - input_rows: Count, - /// Number of batches produced by this operator - output_batches: Count, - /// Number of rows produced by this operator - output_rows: Count, - /// Peak memory used for buffered data. - /// Calculated as sum of peak memory values across partitions - peak_mem_used: metrics::Gauge, - /// Metrics related to spilling - spill_metrics: SpillMetrics, -} - -impl SortMergeJoinMetrics { - #[allow(dead_code)] - pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { - let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition); - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); - let output_batches = - MetricBuilder::new(metrics).counter("output_batches", partition); - let output_rows = MetricBuilder::new(metrics).output_rows(partition); - let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition); - let spill_metrics = SpillMetrics::new(metrics, partition); - - Self { - join_time, - input_batches, - input_rows, - output_batches, - output_rows, - peak_mem_used, - spill_metrics, - } - } -} - -/// State of SMJ stream -#[derive(Debug, PartialEq, Eq)] -enum SortMergeJoinState { - /// Init joining with a new streamed row or a new buffered batches - Init, - /// Polling one streamed row or one buffered batch, or both - Polling, - /// Joining polled data and making output - JoinOutput, - /// No more output - Exhausted, -} - -/// State of streamed data stream -#[derive(Debug, PartialEq, Eq)] -enum StreamedState { - /// Init polling - Init, - /// Polling one streamed row - Polling, - /// Ready to produce one streamed row - Ready, - /// No more streamed row - Exhausted, -} - -/// State of buffered data stream -#[derive(Debug, PartialEq, Eq)] -enum BufferedState { - /// Init polling - Init, - /// Polling first row in the next batch - PollingFirst, - /// Polling rest rows in the next batch - PollingRest, - /// Ready to produce one batch - Ready, - /// No more buffered batches - Exhausted, -} - -/// Represents a chunk of joined data from streamed and buffered side -struct StreamedJoinedChunk { - /// Index of batch in buffered_data - buffered_batch_idx: Option, - /// Array builder for streamed indices - streamed_indices: UInt64Builder, - /// Array builder for buffered indices - /// This could contain nulls if the join is null-joined - buffered_indices: UInt64Builder, -} - -/// Represents a record batch from streamed input. -/// -/// Also stores information of matching rows from buffered batches. -struct StreamedBatch { - /// The streamed record batch - pub batch: RecordBatch, - /// The index of row in the streamed batch to compare with buffered batches - pub idx: usize, - /// The join key arrays of streamed batch which are used to compare with buffered batches - /// and to produce output. They are produced by evaluating `on` expressions. - pub join_arrays: Vec, - /// Chunks of indices from buffered side (may be nulls) joined to streamed - pub output_indices: Vec, - /// Index of currently scanned batch from buffered data - pub buffered_batch_idx: Option, - /// Indices that found a match for the given join filter - /// Used for semi joins to keep track the streaming index which got a join filter match - /// and already emitted to the output. - pub join_filter_matched_idxs: HashSet, -} - -impl StreamedBatch { - fn new(batch: RecordBatch, on_column: &[Arc]) -> Self { - let join_arrays = join_arrays(&batch, on_column); - StreamedBatch { - batch, - idx: 0, - join_arrays, - output_indices: vec![], - buffered_batch_idx: None, - join_filter_matched_idxs: HashSet::new(), - } - } - - fn new_empty(schema: SchemaRef) -> Self { - StreamedBatch { - batch: RecordBatch::new_empty(schema), - idx: 0, - join_arrays: vec![], - output_indices: vec![], - buffered_batch_idx: None, - join_filter_matched_idxs: HashSet::new(), - } - } - - /// Appends new pair consisting of current streamed index and `buffered_idx` - /// index of buffered batch with `buffered_batch_idx` index. - fn append_output_pair( - &mut self, - buffered_batch_idx: Option, - buffered_idx: Option, - ) { - // If no current chunk exists or current chunk is not for current buffered batch, - // create a new chunk - if self.output_indices.is_empty() || self.buffered_batch_idx != buffered_batch_idx - { - self.output_indices.push(StreamedJoinedChunk { - buffered_batch_idx, - streamed_indices: UInt64Builder::with_capacity(1), - buffered_indices: UInt64Builder::with_capacity(1), - }); - self.buffered_batch_idx = buffered_batch_idx; - }; - let current_chunk = self.output_indices.last_mut().unwrap(); - - // Append index of streamed batch and index of buffered batch into current chunk - current_chunk.streamed_indices.append_value(self.idx as u64); - if let Some(idx) = buffered_idx { - current_chunk.buffered_indices.append_value(idx as u64); - } else { - current_chunk.buffered_indices.append_null(); - } - } -} - -/// A buffered batch that contains contiguous rows with same join key -#[derive(Debug)] -struct BufferedBatch { - /// The buffered record batch - /// None if the batch spilled to disk th - pub batch: Option, - /// The range in which the rows share the same join key - pub range: Range, - /// Array refs of the join key - pub join_arrays: Vec, - /// Buffered joined index (null joining buffered) - pub null_joined: Vec, - /// Size estimation used for reserving / releasing memory - pub size_estimation: usize, - /// The indices of buffered batch that the join filter doesn't satisfy. - /// This is a map between right row index and a boolean value indicating whether all joined row - /// of the right row does not satisfy the filter . - /// When dequeuing the buffered batch, we need to produce null joined rows for these indices. - pub join_filter_not_matched_map: HashMap, - /// Current buffered batch number of rows. Equal to batch.num_rows() - /// but if batch is spilled to disk this property is preferable - /// and less expensive - pub num_rows: usize, - /// An optional temp spill file name on the disk if the batch spilled - /// None by default - /// Some(fileName) if the batch spilled to the disk - pub spill_file: Option, -} - -impl BufferedBatch { - fn new( - batch: RecordBatch, - range: Range, - on_column: &[PhysicalExprRef], - ) -> Self { - let join_arrays = join_arrays(&batch, on_column); - - // Estimation is calculated as - // inner batch size - // + join keys size - // + worst case null_joined (as vector capacity * element size) - // + Range size - // + size of this estimation - let size_estimation = batch.get_array_memory_size() - + join_arrays - .iter() - .map(|arr| arr.get_array_memory_size()) - .sum::() - + batch.num_rows().next_power_of_two() * size_of::() - + size_of::>() - + size_of::(); - - let num_rows = batch.num_rows(); - BufferedBatch { - batch: Some(batch), - range, - join_arrays, - null_joined: vec![], - size_estimation, - join_filter_not_matched_map: HashMap::new(), - num_rows, - spill_file: None, - } - } -} - -/// Sort-Merge join stream that consumes streamed and buffered data streams -/// and produces joined output stream. -struct SortMergeJoinStream { - /// Current state of the stream - pub state: SortMergeJoinState, - /// Output schema - pub schema: SchemaRef, - /// Sort options of join columns used to sort streamed and buffered data stream - pub sort_options: Vec, - /// null == null? - pub null_equals_null: bool, - /// Input schema of streamed - pub streamed_schema: SchemaRef, - /// Input schema of buffered - pub buffered_schema: SchemaRef, - /// Streamed data stream - pub streamed: SendableRecordBatchStream, - /// Buffered data stream - pub buffered: SendableRecordBatchStream, - /// Current processing record batch of streamed - pub streamed_batch: StreamedBatch, - /// Current buffered data - pub buffered_data: BufferedData, - /// (used in outer join) Is current streamed row joined at least once? - pub streamed_joined: bool, - /// (used in outer join) Is current buffered batches joined at least once? - pub buffered_joined: bool, - /// State of streamed - pub streamed_state: StreamedState, - /// State of buffered - pub buffered_state: BufferedState, - /// The comparison result of current streamed row and buffered batches - pub current_ordering: Ordering, - /// Join key columns of streamed - pub on_streamed: Vec, - /// Join key columns of buffered - pub on_buffered: Vec, - /// optional join filter - pub filter: Option, - /// Staging output array builders - pub staging_output_record_batches: JoinedRecordBatches, - /// Output buffer. Currently used by filtering as it requires double buffering - /// to avoid small/empty batches. Non-filtered join outputs directly from `staging_output_record_batches.batches` - pub output: RecordBatch, - /// Staging output size, including output batches and staging joined results. - /// Increased when we put rows into buffer and decreased after we actually output batches. - /// Used to trigger output when sufficient rows are ready - pub output_size: usize, - /// Target output batch size - pub batch_size: usize, - /// How the join is performed - pub join_type: JoinType, - /// Metrics - pub join_metrics: SortMergeJoinMetrics, - /// Memory reservation - pub reservation: MemoryReservation, - /// Runtime env - pub runtime_env: Arc, - /// Manages the process of spilling and reading back intermediate data - pub spill_manager: SpillManager, - /// A unique number for each batch - pub streamed_batch_counter: AtomicUsize, -} - -/// Joined batches with attached join filter information -struct JoinedRecordBatches { - /// Joined batches. Each batch is already joined columns from left and right sources - pub batches: Vec, - /// Filter match mask for each row(matched/non-matched) - pub filter_mask: BooleanBuilder, - /// Row indices to glue together rows in `batches` and `filter_mask` - pub row_indices: UInt64Builder, - /// Which unique batch id the row belongs to - /// It is necessary to differentiate rows that are distributed the way when they point to the same - /// row index but in not the same batches - pub batch_ids: Vec, -} - -impl JoinedRecordBatches { - fn clear(&mut self) { - self.batches.clear(); - self.batch_ids.clear(); - self.filter_mask = BooleanBuilder::new(); - self.row_indices = UInt64Builder::new(); - } -} -impl RecordBatchStream for SortMergeJoinStream { - fn schema(&self) -> SchemaRef { - Arc::clone(&self.schema) - } -} - -/// True if next index refers to either: -/// - another batch id -/// - another row index within same batch id -/// - end of row indices -#[inline(always)] -fn last_index_for_row( - row_index: usize, - indices: &UInt64Array, - batch_ids: &[usize], - indices_len: usize, -) -> bool { - row_index == indices_len - 1 - || batch_ids[row_index] != batch_ids[row_index + 1] - || indices.value(row_index) != indices.value(row_index + 1) -} - -// Returns a corrected boolean bitmask for the given join type -// Values in the corrected bitmask can be: true, false, null -// `true` - the row found its match and sent to the output -// `null` - the row ignored, no output -// `false` - the row sent as NULL joined row -fn get_corrected_filter_mask( - join_type: JoinType, - row_indices: &UInt64Array, - batch_ids: &[usize], - filter_mask: &BooleanArray, - expected_size: usize, -) -> Option { - let row_indices_length = row_indices.len(); - let mut corrected_mask: BooleanBuilder = - BooleanBuilder::with_capacity(row_indices_length); - let mut seen_true = false; - - match join_type { - JoinType::Left | JoinType::Right => { - for i in 0..row_indices_length { - let last_index = - last_index_for_row(i, row_indices, batch_ids, row_indices_length); - if filter_mask.value(i) { - seen_true = true; - corrected_mask.append_value(true); - } else if seen_true || !filter_mask.value(i) && !last_index { - corrected_mask.append_null(); // to be ignored and not set to output - } else { - corrected_mask.append_value(false); // to be converted to null joined row - } - - if last_index { - seen_true = false; - } - } - - // Generate null joined rows for records which have no matching join key - corrected_mask.append_n(expected_size - corrected_mask.len(), false); - Some(corrected_mask.finish()) - } - JoinType::LeftMark => { - for i in 0..row_indices_length { - let last_index = - last_index_for_row(i, row_indices, batch_ids, row_indices_length); - if filter_mask.value(i) && !seen_true { - seen_true = true; - corrected_mask.append_value(true); - } else if seen_true || !filter_mask.value(i) && !last_index { - corrected_mask.append_null(); // to be ignored and not set to output - } else { - corrected_mask.append_value(false); // to be converted to null joined row - } - - if last_index { - seen_true = false; - } - } - - // Generate null joined rows for records which have no matching join key - corrected_mask.append_n(expected_size - corrected_mask.len(), false); - Some(corrected_mask.finish()) - } - JoinType::LeftSemi => { - for i in 0..row_indices_length { - let last_index = - last_index_for_row(i, row_indices, batch_ids, row_indices_length); - if filter_mask.value(i) && !seen_true { - seen_true = true; - corrected_mask.append_value(true); - } else { - corrected_mask.append_null(); // to be ignored and not set to output - } - - if last_index { - seen_true = false; - } - } - - Some(corrected_mask.finish()) - } - JoinType::LeftAnti | JoinType::RightAnti => { - for i in 0..row_indices_length { - let last_index = - last_index_for_row(i, row_indices, batch_ids, row_indices_length); - - if filter_mask.value(i) { - seen_true = true; - } - - if last_index { - if !seen_true { - corrected_mask.append_value(true); - } else { - corrected_mask.append_null(); - } - - seen_true = false; - } else { - corrected_mask.append_null(); - } - } - // Generate null joined rows for records which have no matching join key, - // for LeftAnti non-matched considered as true - corrected_mask.append_n(expected_size - corrected_mask.len(), true); - Some(corrected_mask.finish()) - } - JoinType::Full => { - let mut mask: Vec> = vec![Some(true); row_indices_length]; - let mut last_true_idx = 0; - let mut first_row_idx = 0; - let mut seen_false = false; - - for i in 0..row_indices_length { - let last_index = - last_index_for_row(i, row_indices, batch_ids, row_indices_length); - let val = filter_mask.value(i); - let is_null = filter_mask.is_null(i); - - if val { - // memoize the first seen matched row - if !seen_true { - last_true_idx = i; - } - seen_true = true; - } - - if is_null || val { - mask[i] = Some(true); - } else if !is_null && !val && (seen_true || seen_false) { - mask[i] = None; - } else { - mask[i] = Some(false); - } - - if !is_null && !val { - seen_false = true; - } - - if last_index { - // If the left row seen as true its needed to output it once - // To do that we mark all other matches for same row as null to avoid the output - if seen_true { - #[allow(clippy::needless_range_loop)] - for j in first_row_idx..last_true_idx { - mask[j] = None; - } - } - - seen_true = false; - seen_false = false; - last_true_idx = 0; - first_row_idx = i + 1; - } - } - - Some(BooleanArray::from(mask)) - } - // Only outer joins needs to keep track of processed rows and apply corrected filter mask - _ => None, - } -} - -impl Stream for SortMergeJoinStream { - type Item = Result; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let join_time = self.join_metrics.join_time.clone(); - let _timer = join_time.timer(); - loop { - match &self.state { - SortMergeJoinState::Init => { - let streamed_exhausted = - self.streamed_state == StreamedState::Exhausted; - let buffered_exhausted = - self.buffered_state == BufferedState::Exhausted; - self.state = if streamed_exhausted && buffered_exhausted { - SortMergeJoinState::Exhausted - } else { - match self.current_ordering { - Ordering::Less | Ordering::Equal => { - if !streamed_exhausted { - if self.filter.is_some() - && matches!( - self.join_type, - JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftMark - | JoinType::Right - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::Full - ) - { - self.freeze_all()?; - - // If join is filtered and there is joined tuples waiting - // to be filtered - if !self - .staging_output_record_batches - .batches - .is_empty() - { - // Apply filter on joined tuples and get filtered batch - let out_filtered_batch = - self.filter_joined_batch()?; - - // Append filtered batch to the output buffer - self.output = concat_batches( - &self.schema(), - vec![&self.output, &out_filtered_batch], - )?; - - // Send to output if the output buffer surpassed the `batch_size` - if self.output.num_rows() >= self.batch_size { - let record_batch = std::mem::replace( - &mut self.output, - RecordBatch::new_empty( - out_filtered_batch.schema(), - ), - ); - return Poll::Ready(Some(Ok( - record_batch, - ))); - } - } - } - - self.streamed_joined = false; - self.streamed_state = StreamedState::Init; - } - } - Ordering::Greater => { - if !buffered_exhausted { - self.buffered_joined = false; - self.buffered_state = BufferedState::Init; - } - } - } - SortMergeJoinState::Polling - }; - } - SortMergeJoinState::Polling => { - if ![StreamedState::Exhausted, StreamedState::Ready] - .contains(&self.streamed_state) - { - match self.poll_streamed_row(cx)? { - Poll::Ready(_) => {} - Poll::Pending => return Poll::Pending, - } - } - - if ![BufferedState::Exhausted, BufferedState::Ready] - .contains(&self.buffered_state) - { - match self.poll_buffered_batches(cx)? { - Poll::Ready(_) => {} - Poll::Pending => return Poll::Pending, - } - } - let streamed_exhausted = - self.streamed_state == StreamedState::Exhausted; - let buffered_exhausted = - self.buffered_state == BufferedState::Exhausted; - if streamed_exhausted && buffered_exhausted { - self.state = SortMergeJoinState::Exhausted; - continue; - } - self.current_ordering = self.compare_streamed_buffered()?; - self.state = SortMergeJoinState::JoinOutput; - } - SortMergeJoinState::JoinOutput => { - self.join_partial()?; - - if self.output_size < self.batch_size { - if self.buffered_data.scanning_finished() { - self.buffered_data.scanning_reset(); - self.state = SortMergeJoinState::Init; - } - } else { - self.freeze_all()?; - if !self.staging_output_record_batches.batches.is_empty() { - let record_batch = self.output_record_batch_and_reset()?; - // For non-filtered join output whenever the target output batch size - // is hit. For filtered join its needed to output on later phase - // because target output batch size can be hit in the middle of - // filtering causing the filtering to be incomplete and causing - // correctness issues - if self.filter.is_some() - && matches!( - self.join_type, - JoinType::Left - | JoinType::LeftSemi - | JoinType::Right - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::LeftMark - | JoinType::Full - ) - { - continue; - } - - return Poll::Ready(Some(Ok(record_batch))); - } - return Poll::Pending; - } - } - SortMergeJoinState::Exhausted => { - self.freeze_all()?; - - // if there is still something not processed - if !self.staging_output_record_batches.batches.is_empty() { - if self.filter.is_some() - && matches!( - self.join_type, - JoinType::Left - | JoinType::LeftSemi - | JoinType::Right - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::Full - | JoinType::LeftMark - ) - { - let record_batch = self.filter_joined_batch()?; - return Poll::Ready(Some(Ok(record_batch))); - } else { - let record_batch = self.output_record_batch_and_reset()?; - return Poll::Ready(Some(Ok(record_batch))); - } - } else if self.output.num_rows() > 0 { - // if processed but still not outputted because it didn't hit batch size before - let schema = self.output.schema(); - let record_batch = std::mem::replace( - &mut self.output, - RecordBatch::new_empty(schema), - ); - return Poll::Ready(Some(Ok(record_batch))); - } else { - return Poll::Ready(None); - } - } - } - } - } -} - -impl SortMergeJoinStream { - #[allow(clippy::too_many_arguments)] - pub fn try_new( - schema: SchemaRef, - sort_options: Vec, - null_equals_null: bool, - streamed: SendableRecordBatchStream, - buffered: SendableRecordBatchStream, - on_streamed: Vec>, - on_buffered: Vec>, - filter: Option, - join_type: JoinType, - batch_size: usize, - join_metrics: SortMergeJoinMetrics, - reservation: MemoryReservation, - runtime_env: Arc, - ) -> Result { - let streamed_schema = streamed.schema(); - let buffered_schema = buffered.schema(); - let spill_manager = SpillManager::new( - Arc::clone(&runtime_env), - join_metrics.spill_metrics.clone(), - Arc::clone(&buffered_schema), - ); - Ok(Self { - state: SortMergeJoinState::Init, - sort_options, - null_equals_null, - schema: Arc::clone(&schema), - streamed_schema: Arc::clone(&streamed_schema), - buffered_schema, - streamed, - buffered, - streamed_batch: StreamedBatch::new_empty(streamed_schema), - buffered_data: BufferedData::default(), - streamed_joined: false, - buffered_joined: false, - streamed_state: StreamedState::Init, - buffered_state: BufferedState::Init, - current_ordering: Ordering::Equal, - on_streamed, - on_buffered, - filter, - staging_output_record_batches: JoinedRecordBatches { - batches: vec![], - filter_mask: BooleanBuilder::new(), - row_indices: UInt64Builder::new(), - batch_ids: vec![], - }, - output: RecordBatch::new_empty(schema), - output_size: 0, - batch_size, - join_type, - join_metrics, - reservation, - runtime_env, - spill_manager, - streamed_batch_counter: AtomicUsize::new(0), - }) - } - - /// Poll next streamed row - fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll>> { - loop { - match &self.streamed_state { - StreamedState::Init => { - if self.streamed_batch.idx + 1 < self.streamed_batch.batch.num_rows() - { - self.streamed_batch.idx += 1; - self.streamed_state = StreamedState::Ready; - return Poll::Ready(Some(Ok(()))); - } else { - self.streamed_state = StreamedState::Polling; - } - } - StreamedState::Polling => match self.streamed.poll_next_unpin(cx)? { - Poll::Pending => { - return Poll::Pending; - } - Poll::Ready(None) => { - self.streamed_state = StreamedState::Exhausted; - } - Poll::Ready(Some(batch)) => { - if batch.num_rows() > 0 { - self.freeze_streamed()?; - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - self.streamed_batch = - StreamedBatch::new(batch, &self.on_streamed); - // Every incoming streaming batch should have its unique id - // Check `JoinedRecordBatches.self.streamed_batch_counter` documentation - self.streamed_batch_counter - .fetch_add(1, std::sync::atomic::Ordering::SeqCst); - self.streamed_state = StreamedState::Ready; - } - } - }, - StreamedState::Ready => { - return Poll::Ready(Some(Ok(()))); - } - StreamedState::Exhausted => { - return Poll::Ready(None); - } - } - } - } - - fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> Result<()> { - // Shrink memory usage for in-memory batches only - if buffered_batch.spill_file.is_none() && buffered_batch.batch.is_some() { - self.reservation - .try_shrink(buffered_batch.size_estimation)?; - } - - Ok(()) - } - - fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> { - match self.reservation.try_grow(buffered_batch.size_estimation) { - Ok(_) => { - self.join_metrics - .peak_mem_used - .set_max(self.reservation.size()); - Ok(()) - } - Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => { - // Spill buffered batch to disk - if let Some(batch) = buffered_batch.batch { - let spill_file = self - .spill_manager - .spill_record_batch_and_finish( - &[batch], - "sort_merge_join_buffered_spill", - )? - .unwrap(); // Operation only return None if no batches are spilled, here we ensure that at least one batch is spilled - - buffered_batch.spill_file = Some(spill_file); - buffered_batch.batch = None; - - Ok(()) - } else { - internal_err!("Buffered batch has empty body") - } - } - Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()), - }?; - - self.buffered_data.batches.push_back(buffered_batch); - Ok(()) - } - - /// Poll next buffered batches - fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll>> { - loop { - match &self.buffered_state { - BufferedState::Init => { - // pop previous buffered batches - while !self.buffered_data.batches.is_empty() { - let head_batch = self.buffered_data.head_batch(); - // If the head batch is fully processed, dequeue it and produce output of it. - if head_batch.range.end == head_batch.num_rows { - self.freeze_dequeuing_buffered()?; - if let Some(mut buffered_batch) = - self.buffered_data.batches.pop_front() - { - self.produce_buffered_not_matched(&mut buffered_batch)?; - self.free_reservation(buffered_batch)?; - } - } else { - // If the head batch is not fully processed, break the loop. - // Streamed batch will be joined with the head batch in the next step. - break; - } - } - if self.buffered_data.batches.is_empty() { - self.buffered_state = BufferedState::PollingFirst; - } else { - let tail_batch = self.buffered_data.tail_batch_mut(); - tail_batch.range.start = tail_batch.range.end; - tail_batch.range.end += 1; - self.buffered_state = BufferedState::PollingRest; - } - } - BufferedState::PollingFirst => match self.buffered.poll_next_unpin(cx)? { - Poll::Pending => { - return Poll::Pending; - } - Poll::Ready(None) => { - self.buffered_state = BufferedState::Exhausted; - return Poll::Ready(None); - } - Poll::Ready(Some(batch)) => { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - - if batch.num_rows() > 0 { - let buffered_batch = - BufferedBatch::new(batch, 0..1, &self.on_buffered); - - self.allocate_reservation(buffered_batch)?; - self.buffered_state = BufferedState::PollingRest; - } - } - }, - BufferedState::PollingRest => { - if self.buffered_data.tail_batch().range.end - < self.buffered_data.tail_batch().num_rows - { - while self.buffered_data.tail_batch().range.end - < self.buffered_data.tail_batch().num_rows - { - if is_join_arrays_equal( - &self.buffered_data.head_batch().join_arrays, - self.buffered_data.head_batch().range.start, - &self.buffered_data.tail_batch().join_arrays, - self.buffered_data.tail_batch().range.end, - )? { - self.buffered_data.tail_batch_mut().range.end += 1; - } else { - self.buffered_state = BufferedState::Ready; - return Poll::Ready(Some(Ok(()))); - } - } - } else { - match self.buffered.poll_next_unpin(cx)? { - Poll::Pending => { - return Poll::Pending; - } - Poll::Ready(None) => { - self.buffered_state = BufferedState::Ready; - } - Poll::Ready(Some(batch)) => { - // Polling batches coming concurrently as multiple partitions - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - if batch.num_rows() > 0 { - let buffered_batch = BufferedBatch::new( - batch, - 0..0, - &self.on_buffered, - ); - self.allocate_reservation(buffered_batch)?; - } - } - } - } - } - BufferedState::Ready => { - return Poll::Ready(Some(Ok(()))); - } - BufferedState::Exhausted => { - return Poll::Ready(None); - } - } - } - } - - /// Get comparison result of streamed row and buffered batches - fn compare_streamed_buffered(&self) -> Result { - if self.streamed_state == StreamedState::Exhausted { - return Ok(Ordering::Greater); - } - if !self.buffered_data.has_buffered_rows() { - return Ok(Ordering::Less); - } - - compare_join_arrays( - &self.streamed_batch.join_arrays, - self.streamed_batch.idx, - &self.buffered_data.head_batch().join_arrays, - self.buffered_data.head_batch().range.start, - &self.sort_options, - self.null_equals_null, - ) - } - - /// Produce join and fill output buffer until reaching target batch size - /// or the join is finished - fn join_partial(&mut self) -> Result<()> { - // Whether to join streamed rows - let mut join_streamed = false; - // Whether to join buffered rows - let mut join_buffered = false; - // For Mark join we store a dummy id to indicate the the row has a match - let mut mark_row_as_match = false; - - // determine whether we need to join streamed/buffered rows - match self.current_ordering { - Ordering::Less => { - if matches!( - self.join_type, - JoinType::Left - | JoinType::Right - | JoinType::RightSemi - | JoinType::Full - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::LeftMark - ) { - join_streamed = !self.streamed_joined; - } - } - Ordering::Equal => { - if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftMark) { - mark_row_as_match = matches!(self.join_type, JoinType::LeftMark); - // if the join filter is specified then its needed to output the streamed index - // only if it has not been emitted before - // the `join_filter_matched_idxs` keeps track on if streamed index has a successful - // filter match and prevents the same index to go into output more than once - if self.filter.is_some() { - join_streamed = !self - .streamed_batch - .join_filter_matched_idxs - .contains(&(self.streamed_batch.idx as u64)) - && !self.streamed_joined; - // if the join filter specified there can be references to buffered columns - // so buffered columns are needed to access them - join_buffered = join_streamed; - } else { - join_streamed = !self.streamed_joined; - } - } - if matches!( - self.join_type, - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full - ) { - join_streamed = true; - join_buffered = true; - }; - - if matches!(self.join_type, JoinType::LeftAnti | JoinType::RightAnti) - && self.filter.is_some() - { - join_streamed = !self.streamed_joined; - join_buffered = join_streamed; - } - } - Ordering::Greater => { - if matches!(self.join_type, JoinType::Full) { - join_buffered = !self.buffered_joined; - }; - } - } - if !join_streamed && !join_buffered { - // no joined data - self.buffered_data.scanning_finish(); - return Ok(()); - } - - if join_buffered { - // joining streamed/nulls and buffered - while !self.buffered_data.scanning_finished() - && self.output_size < self.batch_size - { - let scanning_idx = self.buffered_data.scanning_idx(); - if join_streamed { - // Join streamed row and buffered row - self.streamed_batch.append_output_pair( - Some(self.buffered_data.scanning_batch_idx), - Some(scanning_idx), - ); - } else { - // Join nulls and buffered row for FULL join - self.buffered_data - .scanning_batch_mut() - .null_joined - .push(scanning_idx); - } - self.output_size += 1; - self.buffered_data.scanning_advance(); - - if self.buffered_data.scanning_finished() { - self.streamed_joined = join_streamed; - self.buffered_joined = true; - } - } - } else { - // joining streamed and nulls - let scanning_batch_idx = if self.buffered_data.scanning_finished() { - None - } else { - Some(self.buffered_data.scanning_batch_idx) - }; - // For Mark join we store a dummy id to indicate the the row has a match - let scanning_idx = mark_row_as_match.then_some(0); - - self.streamed_batch - .append_output_pair(scanning_batch_idx, scanning_idx); - self.output_size += 1; - self.buffered_data.scanning_finish(); - self.streamed_joined = true; - } - Ok(()) - } - - fn freeze_all(&mut self) -> Result<()> { - self.freeze_buffered(self.buffered_data.batches.len())?; - self.freeze_streamed()?; - Ok(()) - } - - // Produces and stages record batches to ensure dequeued buffered batch - // no longer needed: - // 1. freezes all indices joined to streamed side - // 2. freezes NULLs joined to dequeued buffered batch to "release" it - fn freeze_dequeuing_buffered(&mut self) -> Result<()> { - self.freeze_streamed()?; - // Only freeze and produce the first batch in buffered_data as the batch is fully processed - self.freeze_buffered(1)?; - Ok(()) - } - - // Produces and stages record batch from buffered indices with corresponding - // NULLs on streamed side. - // - // Applicable only in case of Full join. - // - fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> { - if !matches!(self.join_type, JoinType::Full) { - return Ok(()); - } - for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) { - let buffered_indices = UInt64Array::from_iter_values( - buffered_batch.null_joined.iter().map(|&index| index as u64), - ); - if let Some(record_batch) = produce_buffered_null_batch( - &self.schema, - &self.streamed_schema, - &buffered_indices, - buffered_batch, - )? { - let num_rows = record_batch.num_rows(); - self.staging_output_record_batches - .filter_mask - .append_nulls(num_rows); - self.staging_output_record_batches - .row_indices - .append_nulls(num_rows); - self.staging_output_record_batches.batch_ids.resize( - self.staging_output_record_batches.batch_ids.len() + num_rows, - 0, - ); - - self.staging_output_record_batches - .batches - .push(record_batch); - } - buffered_batch.null_joined.clear(); - } - Ok(()) - } - - fn produce_buffered_not_matched( - &mut self, - buffered_batch: &mut BufferedBatch, - ) -> Result<()> { - if !matches!(self.join_type, JoinType::Full) { - return Ok(()); - } - - // For buffered row which is joined with streamed side rows but all joined rows - // don't satisfy the join filter - let not_matched_buffered_indices = buffered_batch - .join_filter_not_matched_map - .iter() - .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None }) - .collect::>(); - - let buffered_indices = - UInt64Array::from_iter_values(not_matched_buffered_indices.iter().copied()); - - if let Some(record_batch) = produce_buffered_null_batch( - &self.schema, - &self.streamed_schema, - &buffered_indices, - buffered_batch, - )? { - let num_rows = record_batch.num_rows(); - - self.staging_output_record_batches - .filter_mask - .append_nulls(num_rows); - self.staging_output_record_batches - .row_indices - .append_nulls(num_rows); - self.staging_output_record_batches.batch_ids.resize( - self.staging_output_record_batches.batch_ids.len() + num_rows, - 0, - ); - self.staging_output_record_batches - .batches - .push(record_batch); - } - buffered_batch.join_filter_not_matched_map.clear(); - - Ok(()) - } - - // Produces and stages record batch for all output indices found - // for current streamed batch and clears staged output indices. - fn freeze_streamed(&mut self) -> Result<()> { - for chunk in self.streamed_batch.output_indices.iter_mut() { - // The row indices of joined streamed batch - let left_indices = chunk.streamed_indices.finish(); - - if left_indices.is_empty() { - continue; - } - - let mut left_columns = self - .streamed_batch - .batch - .columns() - .iter() - .map(|column| take(column, &left_indices, None)) - .collect::, ArrowError>>()?; - - // The row indices of joined buffered batch - let right_indices: UInt64Array = chunk.buffered_indices.finish(); - let mut right_columns = if matches!(self.join_type, JoinType::LeftMark) { - vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef] - } else if matches!( - self.join_type, - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightAnti - ) { - vec![] - } else if let Some(buffered_idx) = chunk.buffered_batch_idx { - fetch_right_columns_by_idxs( - &self.buffered_data, - buffered_idx, - &right_indices, - )? - } else { - // If buffered batch none, meaning it is null joined batch. - // We need to create null arrays for buffered columns to join with streamed rows. - create_unmatched_columns( - self.join_type, - &self.buffered_schema, - right_indices.len(), - ) - }; - - // Prepare the columns we apply join filter on later. - // Only for joined rows between streamed and buffered. - let filter_columns = if chunk.buffered_batch_idx.is_some() { - if !matches!(self.join_type, JoinType::Right) { - if matches!( - self.join_type, - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark - ) { - let right_cols = fetch_right_columns_by_idxs( - &self.buffered_data, - chunk.buffered_batch_idx.unwrap(), - &right_indices, - )?; - - get_filter_column(&self.filter, &left_columns, &right_cols) - } else if matches!(self.join_type, JoinType::RightAnti) { - let right_cols = fetch_right_columns_by_idxs( - &self.buffered_data, - chunk.buffered_batch_idx.unwrap(), - &right_indices, - )?; - - get_filter_column(&self.filter, &right_cols, &left_columns) - } else { - get_filter_column(&self.filter, &left_columns, &right_columns) - } - } else { - get_filter_column(&self.filter, &right_columns, &left_columns) - } - } else { - // This chunk is totally for null joined rows (outer join), we don't need to apply join filter. - // Any join filter applied only on either streamed or buffered side will be pushed already. - vec![] - }; - - let columns = if !matches!(self.join_type, JoinType::Right) { - left_columns.extend(right_columns); - left_columns - } else { - right_columns.extend(left_columns); - right_columns - }; - - let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; - // Apply join filter if any - if !filter_columns.is_empty() { - if let Some(f) = &self.filter { - // Construct batch with only filter columns - let filter_batch = - RecordBatch::try_new(Arc::clone(f.schema()), filter_columns)?; - - let filter_result = f - .expression() - .evaluate(&filter_batch)? - .into_array(filter_batch.num_rows())?; - - // The boolean selection mask of the join filter result - let pre_mask = - datafusion_common::cast::as_boolean_array(&filter_result)?; - - // If there are nulls in join filter result, exclude them from selecting - // the rows to output. - let mask = if pre_mask.null_count() > 0 { - compute::prep_null_mask_filter( - datafusion_common::cast::as_boolean_array(&filter_result)?, - ) - } else { - pre_mask.clone() - }; - - // Push the filtered batch which contains rows passing join filter to the output - if matches!( - self.join_type, - JoinType::Left - | JoinType::LeftSemi - | JoinType::Right - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::LeftMark - | JoinType::Full - ) { - self.staging_output_record_batches - .batches - .push(output_batch); - } else { - let filtered_batch = filter_record_batch(&output_batch, &mask)?; - self.staging_output_record_batches - .batches - .push(filtered_batch); - } - - if !matches!(self.join_type, JoinType::Full) { - self.staging_output_record_batches.filter_mask.extend(&mask); - } else { - self.staging_output_record_batches - .filter_mask - .extend(pre_mask); - } - self.staging_output_record_batches - .row_indices - .extend(&left_indices); - self.staging_output_record_batches.batch_ids.resize( - self.staging_output_record_batches.batch_ids.len() - + left_indices.len(), - self.streamed_batch_counter.load(Relaxed), - ); - - // For outer joins, we need to push the null joined rows to the output if - // all joined rows are failed on the join filter. - // I.e., if all rows joined from a streamed row are failed with the join filter, - // we need to join it with nulls as buffered side. - if matches!(self.join_type, JoinType::Full) { - let buffered_batch = &mut self.buffered_data.batches - [chunk.buffered_batch_idx.unwrap()]; - - for i in 0..pre_mask.len() { - // If the buffered row is not joined with streamed side, - // skip it. - if right_indices.is_null(i) { - continue; - } - - let buffered_index = right_indices.value(i); - - buffered_batch.join_filter_not_matched_map.insert( - buffered_index, - *buffered_batch - .join_filter_not_matched_map - .get(&buffered_index) - .unwrap_or(&true) - && !pre_mask.value(i), - ); - } - } - } else { - self.staging_output_record_batches - .batches - .push(output_batch); - } - } else { - self.staging_output_record_batches - .batches - .push(output_batch); - } - } - - self.streamed_batch.output_indices.clear(); - - Ok(()) - } - - fn output_record_batch_and_reset(&mut self) -> Result { - let record_batch = - concat_batches(&self.schema, &self.staging_output_record_batches.batches)?; - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(record_batch.num_rows()); - // If join filter exists, `self.output_size` is not accurate as we don't know the exact - // number of rows in the output record batch. If streamed row joined with buffered rows, - // once join filter is applied, the number of output rows may be more than 1. - // If `record_batch` is empty, we should reset `self.output_size` to 0. It could be happened - // when the join filter is applied and all rows are filtered out. - if record_batch.num_rows() == 0 || record_batch.num_rows() > self.output_size { - self.output_size = 0; - } else { - self.output_size -= record_batch.num_rows(); - } - - if !(self.filter.is_some() - && matches!( - self.join_type, - JoinType::Left - | JoinType::LeftSemi - | JoinType::Right - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::LeftMark - | JoinType::Full - )) - { - self.staging_output_record_batches.batches.clear(); - } - Ok(record_batch) - } - - fn filter_joined_batch(&mut self) -> Result { - let record_batch = - concat_batches(&self.schema, &self.staging_output_record_batches.batches)?; - let mut out_indices = self.staging_output_record_batches.row_indices.finish(); - let mut out_mask = self.staging_output_record_batches.filter_mask.finish(); - let mut batch_ids = &self.staging_output_record_batches.batch_ids; - let default_batch_ids = vec![0; record_batch.num_rows()]; - - // If only nulls come in and indices sizes doesn't match with expected record batch count - // generate missing indices - // Happens for null joined batches for Full Join - if out_indices.null_count() == out_indices.len() - && out_indices.len() != record_batch.num_rows() - { - out_mask = BooleanArray::from(vec![None; record_batch.num_rows()]); - out_indices = UInt64Array::from(vec![None; record_batch.num_rows()]); - batch_ids = &default_batch_ids; - } - - if out_mask.is_empty() { - self.staging_output_record_batches.batches.clear(); - return Ok(record_batch); - } - - let maybe_corrected_mask = get_corrected_filter_mask( - self.join_type, - &out_indices, - batch_ids, - &out_mask, - record_batch.num_rows(), - ); - - let corrected_mask = if let Some(ref filtered_join_mask) = maybe_corrected_mask { - filtered_join_mask - } else { - &out_mask - }; - - self.filter_record_batch_by_join_type(record_batch, corrected_mask) - } - - fn filter_record_batch_by_join_type( - &mut self, - record_batch: RecordBatch, - corrected_mask: &BooleanArray, - ) -> Result { - let mut filtered_record_batch = - filter_record_batch(&record_batch, corrected_mask)?; - let left_columns_length = self.streamed_schema.fields.len(); - let right_columns_length = self.buffered_schema.fields.len(); - - if matches!( - self.join_type, - JoinType::Left | JoinType::LeftMark | JoinType::Right - ) { - let null_mask = compute::not(corrected_mask)?; - let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?; - - let mut right_columns = create_unmatched_columns( - self.join_type, - &self.buffered_schema, - null_joined_batch.num_rows(), - ); - - let columns = if !matches!(self.join_type, JoinType::Right) { - let mut left_columns = null_joined_batch - .columns() - .iter() - .take(right_columns_length) - .cloned() - .collect::>(); - - left_columns.extend(right_columns); - left_columns - } else { - let left_columns = null_joined_batch - .columns() - .iter() - .skip(left_columns_length) - .cloned() - .collect::>(); - - right_columns.extend(left_columns); - right_columns - }; - - // Push the streamed/buffered batch joined nulls to the output - let null_joined_streamed_batch = - RecordBatch::try_new(Arc::clone(&self.schema), columns)?; - - filtered_record_batch = concat_batches( - &self.schema, - &[filtered_record_batch, null_joined_streamed_batch], - )?; - } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { - let output_column_indices = (0..left_columns_length).collect::>(); - filtered_record_batch = - filtered_record_batch.project(&output_column_indices)?; - } else if matches!(self.join_type, JoinType::RightAnti) { - let output_column_indices = (0..right_columns_length).collect::>(); - filtered_record_batch = - filtered_record_batch.project(&output_column_indices)?; - } else if matches!(self.join_type, JoinType::Full) - && corrected_mask.false_count() > 0 - { - // Find rows which joined by key but Filter predicate evaluated as false - let joined_filter_not_matched_mask = compute::not(corrected_mask)?; - let joined_filter_not_matched_batch = - filter_record_batch(&record_batch, &joined_filter_not_matched_mask)?; - - // Add left unmatched rows adding the right side as nulls - let right_null_columns = self - .buffered_schema - .fields() - .iter() - .map(|f| { - new_null_array( - f.data_type(), - joined_filter_not_matched_batch.num_rows(), - ) - }) - .collect::>(); - - let mut result_joined = joined_filter_not_matched_batch - .columns() - .iter() - .take(left_columns_length) - .cloned() - .collect::>(); - - result_joined.extend(right_null_columns); - - let left_null_joined_batch = - RecordBatch::try_new(Arc::clone(&self.schema), result_joined)?; - - // Add right unmatched rows adding the left side as nulls - let mut result_joined = self - .streamed_schema - .fields() - .iter() - .map(|f| { - new_null_array( - f.data_type(), - joined_filter_not_matched_batch.num_rows(), - ) - }) - .collect::>(); - - let right_data = joined_filter_not_matched_batch - .columns() - .iter() - .skip(left_columns_length) - .cloned() - .collect::>(); - - result_joined.extend(right_data); - - filtered_record_batch = concat_batches( - &self.schema, - &[filtered_record_batch, left_null_joined_batch], - )?; - } - - self.staging_output_record_batches.clear(); - - Ok(filtered_record_batch) - } -} - -fn create_unmatched_columns( - join_type: JoinType, - schema: &SchemaRef, - size: usize, -) -> Vec { - if matches!(join_type, JoinType::LeftMark) { - vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef] - } else { - schema - .fields() - .iter() - .map(|f| new_null_array(f.data_type(), size)) - .collect::>() - } -} - -/// Gets the arrays which join filters are applied on. -fn get_filter_column( - join_filter: &Option, - streamed_columns: &[ArrayRef], - buffered_columns: &[ArrayRef], -) -> Vec { - let mut filter_columns = vec![]; - - if let Some(f) = join_filter { - let left_columns = f - .column_indices() - .iter() - .filter(|col_index| col_index.side == JoinSide::Left) - .map(|i| Arc::clone(&streamed_columns[i.index])) - .collect::>(); - - let right_columns = f - .column_indices() - .iter() - .filter(|col_index| col_index.side == JoinSide::Right) - .map(|i| Arc::clone(&buffered_columns[i.index])) - .collect::>(); - - filter_columns.extend(left_columns); - filter_columns.extend(right_columns); - } - - filter_columns -} - -fn produce_buffered_null_batch( - schema: &SchemaRef, - streamed_schema: &SchemaRef, - buffered_indices: &PrimitiveArray, - buffered_batch: &BufferedBatch, -) -> Result> { - if buffered_indices.is_empty() { - return Ok(None); - } - - // Take buffered (right) columns - let right_columns = - fetch_right_columns_from_batch_by_idxs(buffered_batch, buffered_indices)?; - - // Create null streamed (left) columns - let mut left_columns = streamed_schema - .fields() - .iter() - .map(|f| new_null_array(f.data_type(), buffered_indices.len())) - .collect::>(); - - left_columns.extend(right_columns); - - Ok(Some(RecordBatch::try_new( - Arc::clone(schema), - left_columns, - )?)) -} - -/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` by specific column indices -#[inline(always)] -fn fetch_right_columns_by_idxs( - buffered_data: &BufferedData, - buffered_batch_idx: usize, - buffered_indices: &UInt64Array, -) -> Result> { - fetch_right_columns_from_batch_by_idxs( - &buffered_data.batches[buffered_batch_idx], - buffered_indices, - ) -} - -#[inline(always)] -fn fetch_right_columns_from_batch_by_idxs( - buffered_batch: &BufferedBatch, - buffered_indices: &UInt64Array, -) -> Result> { - match (&buffered_batch.spill_file, &buffered_batch.batch) { - // In memory batch - (None, Some(batch)) => Ok(batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>() - .map_err(Into::::into)?), - // If the batch was spilled to disk, less likely - (Some(spill_file), None) => { - let mut buffered_cols: Vec = - Vec::with_capacity(buffered_indices.len()); - - let file = BufReader::new(File::open(spill_file.path())?); - let reader = StreamReader::try_new(file, None)?; - - for batch in reader { - batch?.columns().iter().for_each(|column| { - buffered_cols.extend(take(column, &buffered_indices, None)) - }); - } - - Ok(buffered_cols) - } - // Invalid combination - (spill, batch) => internal_err!("Unexpected buffered batch spill status. Spill exists: {}. In-memory exists: {}", spill.is_some(), batch.is_some()), - } -} - -/// Buffered data contains all buffered batches with one unique join key -#[derive(Debug, Default)] -struct BufferedData { - /// Buffered batches with the same key - pub batches: VecDeque, - /// current scanning batch index used in join_partial() - pub scanning_batch_idx: usize, - /// current scanning offset used in join_partial() - pub scanning_offset: usize, -} - -impl BufferedData { - pub fn head_batch(&self) -> &BufferedBatch { - self.batches.front().unwrap() - } - - pub fn tail_batch(&self) -> &BufferedBatch { - self.batches.back().unwrap() - } - - pub fn tail_batch_mut(&mut self) -> &mut BufferedBatch { - self.batches.back_mut().unwrap() - } - - pub fn has_buffered_rows(&self) -> bool { - self.batches.iter().any(|batch| !batch.range.is_empty()) - } - - pub fn scanning_reset(&mut self) { - self.scanning_batch_idx = 0; - self.scanning_offset = 0; - } - - pub fn scanning_advance(&mut self) { - self.scanning_offset += 1; - while !self.scanning_finished() && self.scanning_batch_finished() { - self.scanning_batch_idx += 1; - self.scanning_offset = 0; - } - } - - pub fn scanning_batch(&self) -> &BufferedBatch { - &self.batches[self.scanning_batch_idx] - } - - pub fn scanning_batch_mut(&mut self) -> &mut BufferedBatch { - &mut self.batches[self.scanning_batch_idx] - } - - pub fn scanning_idx(&self) -> usize { - self.scanning_batch().range.start + self.scanning_offset - } - - pub fn scanning_batch_finished(&self) -> bool { - self.scanning_offset == self.scanning_batch().range.len() - } - - pub fn scanning_finished(&self) -> bool { - self.scanning_batch_idx == self.batches.len() - } - - pub fn scanning_finish(&mut self) { - self.scanning_batch_idx = self.batches.len(); - self.scanning_offset = 0; - } -} - -/// Get join array refs of given batch and join columns -fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec { - on_column - .iter() - .map(|c| { - let num_rows = batch.num_rows(); - let c = c.evaluate(batch).unwrap(); - c.into_array(num_rows).unwrap() - }) - .collect() -} - -/// Get comparison result of two rows of join arrays -fn compare_join_arrays( - left_arrays: &[ArrayRef], - left: usize, - right_arrays: &[ArrayRef], - right: usize, - sort_options: &[SortOptions], - null_equals_null: bool, -) -> Result { - let mut res = Ordering::Equal; - for ((left_array, right_array), sort_options) in - left_arrays.iter().zip(right_arrays).zip(sort_options) - { - macro_rules! compare_value { - ($T:ty) => {{ - let left_array = left_array.as_any().downcast_ref::<$T>().unwrap(); - let right_array = right_array.as_any().downcast_ref::<$T>().unwrap(); - match (left_array.is_null(left), right_array.is_null(right)) { - (false, false) => { - let left_value = &left_array.value(left); - let right_value = &right_array.value(right); - res = left_value.partial_cmp(right_value).unwrap(); - if sort_options.descending { - res = res.reverse(); - } - } - (true, false) => { - res = if sort_options.nulls_first { - Ordering::Less - } else { - Ordering::Greater - }; - } - (false, true) => { - res = if sort_options.nulls_first { - Ordering::Greater - } else { - Ordering::Less - }; - } - _ => { - res = if null_equals_null { - Ordering::Equal - } else { - Ordering::Less - }; - } - } - }}; - } - - match left_array.data_type() { - DataType::Null => {} - DataType::Boolean => compare_value!(BooleanArray), - DataType::Int8 => compare_value!(Int8Array), - DataType::Int16 => compare_value!(Int16Array), - DataType::Int32 => compare_value!(Int32Array), - DataType::Int64 => compare_value!(Int64Array), - DataType::UInt8 => compare_value!(UInt8Array), - DataType::UInt16 => compare_value!(UInt16Array), - DataType::UInt32 => compare_value!(UInt32Array), - DataType::UInt64 => compare_value!(UInt64Array), - DataType::Float32 => compare_value!(Float32Array), - DataType::Float64 => compare_value!(Float64Array), - DataType::Utf8 => compare_value!(StringArray), - DataType::LargeUtf8 => compare_value!(LargeStringArray), - DataType::Decimal128(..) => compare_value!(Decimal128Array), - DataType::Timestamp(time_unit, None) => match time_unit { - TimeUnit::Second => compare_value!(TimestampSecondArray), - TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), - TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), - TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), - }, - DataType::Date32 => compare_value!(Date32Array), - DataType::Date64 => compare_value!(Date64Array), - dt => { - return not_impl_err!( - "Unsupported data type in sort merge join comparator: {}", - dt - ); - } - } - if !res.is_eq() { - break; - } - } - Ok(res) -} - -/// A faster version of compare_join_arrays() that only output whether -/// the given two rows are equal -fn is_join_arrays_equal( - left_arrays: &[ArrayRef], - left: usize, - right_arrays: &[ArrayRef], - right: usize, -) -> Result { - let mut is_equal = true; - for (left_array, right_array) in left_arrays.iter().zip(right_arrays) { - macro_rules! compare_value { - ($T:ty) => {{ - match (left_array.is_null(left), right_array.is_null(right)) { - (false, false) => { - let left_array = - left_array.as_any().downcast_ref::<$T>().unwrap(); - let right_array = - right_array.as_any().downcast_ref::<$T>().unwrap(); - if left_array.value(left) != right_array.value(right) { - is_equal = false; - } - } - (true, false) => is_equal = false, - (false, true) => is_equal = false, - _ => {} - } - }}; - } - - match left_array.data_type() { - DataType::Null => {} - DataType::Boolean => compare_value!(BooleanArray), - DataType::Int8 => compare_value!(Int8Array), - DataType::Int16 => compare_value!(Int16Array), - DataType::Int32 => compare_value!(Int32Array), - DataType::Int64 => compare_value!(Int64Array), - DataType::UInt8 => compare_value!(UInt8Array), - DataType::UInt16 => compare_value!(UInt16Array), - DataType::UInt32 => compare_value!(UInt32Array), - DataType::UInt64 => compare_value!(UInt64Array), - DataType::Float32 => compare_value!(Float32Array), - DataType::Float64 => compare_value!(Float64Array), - DataType::Utf8 => compare_value!(StringArray), - DataType::LargeUtf8 => compare_value!(LargeStringArray), - DataType::Decimal128(..) => compare_value!(Decimal128Array), - DataType::Timestamp(time_unit, None) => match time_unit { - TimeUnit::Second => compare_value!(TimestampSecondArray), - TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), - TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), - TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), - }, - DataType::Date32 => compare_value!(Date32Array), - DataType::Date64 => compare_value!(Date64Array), - dt => { - return not_impl_err!( - "Unsupported data type in sort merge join comparator: {}", - dt - ); - } - } - if !is_equal { - return Ok(false); - } - } - Ok(true) -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::array::{ - builder::{BooleanBuilder, UInt64Builder}, - BooleanArray, Date32Array, Date64Array, Int32Array, RecordBatch, UInt64Array, - }; - use arrow::compute::{concat_batches, filter_record_batch, SortOptions}; - use arrow::datatypes::{DataType, Field, Schema}; - - use datafusion_common::JoinType::*; - use datafusion_common::{assert_batches_eq, assert_contains, JoinType, Result}; - use datafusion_common::{ - test_util::{batches_to_sort_string, batches_to_string}, - JoinSide, - }; - use datafusion_execution::config::SessionConfig; - use datafusion_execution::disk_manager::DiskManagerConfig; - use datafusion_execution::runtime_env::RuntimeEnvBuilder; - use datafusion_execution::TaskContext; - use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::BinaryExpr; - use insta::{allow_duplicates, assert_snapshot}; - - use crate::expressions::Column; - use crate::joins::sort_merge_join::{get_corrected_filter_mask, JoinedRecordBatches}; - use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn}; - use crate::joins::SortMergeJoinExec; - use crate::test::TestMemoryExec; - use crate::test::{build_table_i32, build_table_i32_two_cols}; - use crate::{common, ExecutionPlan}; - - fn build_table( - a: (&str, &Vec), - b: (&str, &Vec), - c: (&str, &Vec), - ) -> Arc { - let batch = build_table_i32(a, b, c); - let schema = batch.schema(); - TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() - } - - fn build_table_from_batches(batches: Vec) -> Arc { - let schema = batches.first().unwrap().schema(); - TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap() - } - - fn build_date_table( - a: (&str, &Vec), - b: (&str, &Vec), - c: (&str, &Vec), - ) -> Arc { - let schema = Schema::new(vec![ - Field::new(a.0, DataType::Date32, false), - Field::new(b.0, DataType::Date32, false), - Field::new(c.0, DataType::Date32, false), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(Date32Array::from(a.1.clone())), - Arc::new(Date32Array::from(b.1.clone())), - Arc::new(Date32Array::from(c.1.clone())), - ], - ) - .unwrap(); - - let schema = batch.schema(); - TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() - } - - fn build_date64_table( - a: (&str, &Vec), - b: (&str, &Vec), - c: (&str, &Vec), - ) -> Arc { - let schema = Schema::new(vec![ - Field::new(a.0, DataType::Date64, false), - Field::new(b.0, DataType::Date64, false), - Field::new(c.0, DataType::Date64, false), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(Date64Array::from(a.1.clone())), - Arc::new(Date64Array::from(b.1.clone())), - Arc::new(Date64Array::from(c.1.clone())), - ], - ) - .unwrap(); - - let schema = batch.schema(); - TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() - } - - /// returns a table with 3 columns of i32 in memory - pub fn build_table_i32_nullable( - a: (&str, &Vec>), - b: (&str, &Vec>), - c: (&str, &Vec>), - ) -> Arc { - let schema = Arc::new(Schema::new(vec![ - Field::new(a.0, DataType::Int32, true), - Field::new(b.0, DataType::Int32, true), - Field::new(c.0, DataType::Int32, true), - ])); - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(a.1.clone())), - Arc::new(Int32Array::from(b.1.clone())), - Arc::new(Int32Array::from(c.1.clone())), - ], - ) - .unwrap(); - TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() - } - - pub fn build_table_two_cols( - a: (&str, &Vec), - b: (&str, &Vec), - ) -> Arc { - let batch = build_table_i32_two_cols(a, b); - let schema = batch.schema(); - TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() - } - - fn join( - left: Arc, - right: Arc, - on: JoinOn, - join_type: JoinType, - ) -> Result { - let sort_options = vec![SortOptions::default(); on.len()]; - SortMergeJoinExec::try_new(left, right, on, None, join_type, sort_options, false) - } - - fn join_with_options( - left: Arc, - right: Arc, - on: JoinOn, - join_type: JoinType, - sort_options: Vec, - null_equals_null: bool, - ) -> Result { - SortMergeJoinExec::try_new( - left, - right, - on, - None, - join_type, - sort_options, - null_equals_null, - ) - } - - fn join_with_filter( - left: Arc, - right: Arc, - on: JoinOn, - filter: JoinFilter, - join_type: JoinType, - sort_options: Vec, - null_equals_null: bool, - ) -> Result { - SortMergeJoinExec::try_new( - left, - right, - on, - Some(filter), - join_type, - sort_options, - null_equals_null, - ) - } - - async fn join_collect( - left: Arc, - right: Arc, - on: JoinOn, - join_type: JoinType, - ) -> Result<(Vec, Vec)> { - let sort_options = vec![SortOptions::default(); on.len()]; - join_collect_with_options(left, right, on, join_type, sort_options, false).await - } - - async fn join_collect_with_filter( - left: Arc, - right: Arc, - on: JoinOn, - filter: JoinFilter, - join_type: JoinType, - ) -> Result<(Vec, Vec)> { - let sort_options = vec![SortOptions::default(); on.len()]; - - let task_ctx = Arc::new(TaskContext::default()); - let join = - join_with_filter(left, right, on, filter, join_type, sort_options, false)?; - let columns = columns(&join.schema()); - - let stream = join.execute(0, task_ctx)?; - let batches = common::collect(stream).await?; - Ok((columns, batches)) - } - - async fn join_collect_with_options( - left: Arc, - right: Arc, - on: JoinOn, - join_type: JoinType, - sort_options: Vec, - null_equals_null: bool, - ) -> Result<(Vec, Vec)> { - let task_ctx = Arc::new(TaskContext::default()); - let join = join_with_options( - left, - right, - on, - join_type, - sort_options, - null_equals_null, - )?; - let columns = columns(&join.schema()); - - let stream = join.execute(0, task_ctx)?; - let batches = common::collect(stream).await?; - Ok((columns, batches)) - } - - async fn join_collect_batch_size_equals_two( - left: Arc, - right: Arc, - on: JoinOn, - join_type: JoinType, - ) -> Result<(Vec, Vec)> { - let task_ctx = TaskContext::default() - .with_session_config(SessionConfig::new().with_batch_size(2)); - let task_ctx = Arc::new(task_ctx); - let join = join(left, right, on, join_type)?; - let columns = columns(&join.schema()); - - let stream = join.execute(0, task_ctx)?; - let batches = common::collect(stream).await?; - Ok((columns, batches)) - } - - #[tokio::test] - async fn join_inner_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 5]), // this has a repetition - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 1 | 4 | 7 | 10 | 4 | 70 | - | 2 | 5 | 8 | 20 | 5 | 80 | - | 3 | 5 | 9 | 20 | 5 | 80 | - +----+----+----+----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_inner_two() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2]), - ("b2", &vec![1, 2, 2]), - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a1", &vec![1, 2, 3]), - ("b2", &vec![1, 2, 2]), - ("c2", &vec![70, 80, 90]), - ); - let on = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - ), - ]; - - let (_columns, batches) = join_collect(left, right, on, Inner).await?; - - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b2 | c1 | a1 | b2 | c2 | - +----+----+----+----+----+----+ - | 1 | 1 | 7 | 1 | 1 | 70 | - | 2 | 2 | 8 | 2 | 2 | 80 | - | 2 | 2 | 9 | 2 | 2 | 80 | - +----+----+----+----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_inner_two_two() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 1, 2]), - ("b2", &vec![1, 1, 2]), - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a1", &vec![1, 1, 3]), - ("b2", &vec![1, 1, 2]), - ("c2", &vec![70, 80, 90]), - ); - let on = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - ), - ]; - - let (_columns, batches) = join_collect(left, right, on, Inner).await?; - - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b2 | c1 | a1 | b2 | c2 | - +----+----+----+----+----+----+ - | 1 | 1 | 7 | 1 | 1 | 70 | - | 1 | 1 | 7 | 1 | 1 | 80 | - | 1 | 1 | 8 | 1 | 1 | 70 | - | 1 | 1 | 8 | 1 | 1 | 80 | - +----+----+----+----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_inner_with_nulls() -> Result<()> { - let left = build_table_i32_nullable( - ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]), - ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field - ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field - ); - let right = build_table_i32_nullable( - ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]), - ("b2", &vec![None, Some(1), Some(2), Some(2)]), - ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]), - ); - let on = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - ), - ]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b2 | c1 | a1 | b2 | c2 | - +----+----+----+----+----+----+ - | 1 | 1 | | 1 | 1 | 70 | - | 2 | 2 | 8 | 2 | 2 | 80 | - | 2 | 2 | 9 | 2 | 2 | 80 | - +----+----+----+----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_inner_with_nulls_with_options() -> Result<()> { - let left = build_table_i32_nullable( - ("a1", &vec![Some(2), Some(2), Some(1), Some(1)]), - ("b2", &vec![Some(2), Some(2), Some(1), None]), // null in key field - ("c1", &vec![Some(9), Some(8), None, Some(1)]), // null in non-key field - ); - let right = build_table_i32_nullable( - ("a1", &vec![Some(3), Some(2), Some(1), Some(1)]), - ("b2", &vec![Some(2), Some(2), Some(1), None]), - ("c2", &vec![Some(90), Some(80), Some(70), Some(10)]), - ); - let on = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - ), - ]; - let (_, batches) = join_collect_with_options( - left, - right, - on, - Inner, - vec![ - SortOptions { - descending: true, - nulls_first: false, - }; - 2 - ], - true, - ) - .await?; - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b2 | c1 | a1 | b2 | c2 | - +----+----+----+----+----+----+ - | 2 | 2 | 9 | 2 | 2 | 80 | - | 2 | 2 | 8 | 2 | 2 | 80 | - | 1 | 1 | | 1 | 1 | 70 | - | 1 | | 1 | 1 | | 10 | - +----+----+----+----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_inner_output_two_batches() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2]), - ("b2", &vec![1, 2, 2]), - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a1", &vec![1, 2, 3]), - ("b2", &vec![1, 2, 2]), - ("c2", &vec![70, 80, 90]), - ); - let on = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - ), - ]; - - let (_, batches) = - join_collect_batch_size_equals_two(left, right, on, Inner).await?; - assert_eq!(batches.len(), 2); - assert_eq!(batches[0].num_rows(), 2); - assert_eq!(batches[1].num_rows(), 1); - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b2 | c1 | a1 | b2 | c2 | - +----+----+----+----+----+----+ - | 1 | 1 | 7 | 1 | 1 | 70 | - | 2 | 2 | 8 | 2 | 2 | 80 | - | 2 | 2 | 9 | 2 | 2 | 80 | - +----+----+----+----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_left_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, Left).await?; - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 1 | 4 | 7 | 10 | 4 | 70 | - | 2 | 5 | 8 | 20 | 5 | 80 | - | 3 | 7 | 9 | | | | - +----+----+----+----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_right_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 7]), - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), // 6 does not exist on the left - ("c2", &vec![70, 80, 90]), - ); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, Right).await?; - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 1 | 4 | 7 | 10 | 4 | 70 | - | 2 | 5 | 8 | 20 | 5 | 80 | - | | | | 30 | 6 | 90 | - +----+----+----+----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_full_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b2", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, - )]; - - let (_, batches) = join_collect(left, right, on, Full).await?; - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+----+ - | | | | 30 | 6 | 90 | - | 1 | 4 | 7 | 10 | 4 | 70 | - | 2 | 5 | 8 | 20 | 5 | 80 | - | 3 | 7 | 9 | | | | - +----+----+----+----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_left_anti() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2, 3, 5]), - ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 8, 9, 11]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, LeftAnti).await?; - - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c1 | - +----+----+----+ - | 3 | 7 | 9 | - | 5 | 7 | 11 | - +----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_right_anti_one_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2]), - ("b1", &vec![4, 5, 5]), - ("c1", &vec![7, 8, 8]), - ); - let right = - build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6])); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, RightAnti).await?; - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+ - | a2 | b1 | - +----+----+ - | 30 | 6 | - +----+----+ - "#); - - let left2 = build_table( - ("a1", &vec![1, 2, 2]), - ("b1", &vec![4, 5, 5]), - ("c1", &vec![7, 8, 8]), - ); - let right2 = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left2.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right2.schema())?) as _, - )]; - - let (_, batches2) = join_collect(left2, right2, on, RightAnti).await?; - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches2), @r#" - +----+----+----+ - | a2 | b1 | c2 | - +----+----+----+ - | 30 | 6 | 90 | - +----+----+----+ - "#); - - Ok(()) - } - - #[tokio::test] - async fn join_right_anti_two_two() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2]), - ("b1", &vec![4, 5, 5]), - ("c1", &vec![7, 8, 8]), - ); - let right = - build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6])); - let on = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, - ), - ( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - ), - ]; - - let (_, batches) = join_collect(left, right, on, RightAnti).await?; - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+ - | a2 | b1 | - +----+----+ - | 10 | 4 | - | 20 | 5 | - | 30 | 6 | - +----+----+ - "#); - - let left = build_table( - ("a1", &vec![1, 2, 2]), - ("b1", &vec![4, 5, 5]), - ("c1", &vec![7, 8, 8]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - - let on = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, - ), - ( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - ), - ]; - - let (_, batches) = join_collect(left, right, on, RightAnti).await?; - let expected = [ - "+----+----+----+", - "| a2 | b1 | c2 |", - "+----+----+----+", - "| 10 | 4 | 70 |", - "| 20 | 5 | 80 |", - "| 30 | 6 | 90 |", - "+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_eq!(expected, &batches); - - Ok(()) - } - - #[tokio::test] - async fn join_right_anti_two_with_filter() -> Result<()> { - let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", &vec![30])); - let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", &vec![20])); - let on = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, - ), - ( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - ), - ]; - let filter = JoinFilter::new( - Arc::new(BinaryExpr::new( - Arc::new(Column::new("c2", 1)), - Operator::Gt, - Arc::new(Column::new("c1", 0)), - )), - vec![ - ColumnIndex { - index: 2, - side: JoinSide::Left, - }, - ColumnIndex { - index: 2, - side: JoinSide::Right, - }, - ], - Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Int32, true), - ])), - ); - let (_, batches) = - join_collect_with_filter(left, right, on, filter, RightAnti).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c2 | - +----+----+----+ - | 1 | 10 | 20 | - +----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_right_anti_with_nulls() -> Result<()> { - let left = build_table_i32_nullable( - ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]), - ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]), - ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]), - ); - let right = build_table_i32_nullable( - ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]), - ("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key field - ("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key field - ); - let on = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, - ), - ( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - ), - ]; - - let (_, batches) = join_collect(left, right, on, RightAnti).await?; - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c2 | - +----+----+----+ - | 2 | | 8 | - +----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_right_anti_with_nulls_with_options() -> Result<()> { - let left = build_table_i32_nullable( - ("a1", &vec![Some(1), Some(2), Some(1), Some(0), Some(2)]), - ("b1", &vec![Some(4), Some(5), Some(5), None, Some(5)]), - ("c1", &vec![Some(7), Some(8), Some(8), Some(60), None]), - ); - let right = build_table_i32_nullable( - ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]), - ("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key field - ("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key field - ); - let on = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, - ), - ( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - ), - ]; - - let (_, batches) = join_collect_with_options( - left, - right, - on, - RightAnti, - vec![ - SortOptions { - descending: true, - nulls_first: false, - }; - 2 - ], - true, - ) - .await?; - - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c2 | - +----+----+----+ - | 3 | | 9 | - | 2 | 5 | | - | 2 | 5 | 8 | - +----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_right_anti_output_two_batches() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2]), - ("b1", &vec![4, 5, 5]), - ("c1", &vec![7, 8, 8]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - let on = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, - ), - ( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - ), - ]; - - let (_, batches) = - join_collect_batch_size_equals_two(left, right, on, LeftAnti).await?; - assert_eq!(batches.len(), 2); - assert_eq!(batches[0].num_rows(), 2); - assert_eq!(batches[1].num_rows(), 1); - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c1 | - +----+----+----+ - | 1 | 4 | 7 | - | 2 | 5 | 8 | - | 2 | 5 | 8 | - +----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_semi() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2, 3]), - ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), // 5 is double on the right - ("c2", &vec![70, 80, 90]), - ); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, LeftSemi).await?; - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c1 | - +----+----+----+ - | 1 | 4 | 7 | - | 2 | 5 | 8 | - | 2 | 5 | 8 | - +----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_left_mark() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2, 3]), - ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30, 40]), - ("b1", &vec![4, 4, 5, 6]), // 5 is double on the right - ("c2", &vec![60, 70, 80, 90]), - ); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, LeftMark).await?; - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+-------+ - | a1 | b1 | c1 | mark | - +----+----+----+-------+ - | 1 | 4 | 7 | true | - | 2 | 5 | 8 | true | - | 2 | 5 | 8 | true | - | 3 | 7 | 9 | false | - +----+----+----+-------+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_with_duplicated_column_names() -> Result<()> { - let left = build_table( - ("a", &vec![1, 2, 3]), - ("b", &vec![4, 5, 7]), - ("c", &vec![7, 8, 9]), - ); - let right = build_table( - ("a", &vec![10, 20, 30]), - ("b", &vec![1, 2, 7]), - ("c", &vec![70, 80, 90]), - ); - let on = vec![( - // join on a=b so there are duplicate column names on unjoined columns - Arc::new(Column::new_with_schema("a", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +---+---+---+----+---+----+ - | a | b | c | a | b | c | - +---+---+---+----+---+----+ - | 1 | 4 | 7 | 10 | 1 | 70 | - | 2 | 5 | 8 | 20 | 2 | 80 | - +---+---+---+----+---+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_date32() -> Result<()> { - let left = build_date_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![19107, 19108, 19108]), // this has a repetition - ("c1", &vec![7, 8, 9]), - ); - let right = build_date_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![19107, 19108, 19109]), - ("c2", &vec![70, 80, 90]), - ); - - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +------------+------------+------------+------------+------------+------------+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +------------+------------+------------+------------+------------+------------+ - | 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 | - | 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 | - | 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 | - +------------+------------+------------+------------+------------+------------+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_date64() -> Result<()> { - let left = build_date64_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), // this has a repetition - ("c1", &vec![7, 8, 9]), - ); - let right = build_date64_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![1650703441000, 1650503441000, 1650903441000]), - ("c2", &vec![70, 80, 90]), - ); - - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - - // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ - | 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 | - | 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | - | 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | - +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_left_sort_order() -> Result<()> { - let left = build_table( - ("a1", &vec![0, 1, 2, 3, 4, 5]), - ("b1", &vec![3, 4, 5, 6, 6, 7]), - ("c1", &vec![4, 5, 6, 7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![0, 10, 20, 30, 40]), - ("b2", &vec![2, 4, 6, 6, 8]), - ("c2", &vec![50, 60, 70, 80, 90]), - ); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, Left).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+----+ - | 0 | 3 | 4 | | | | - | 1 | 4 | 5 | 10 | 4 | 60 | - | 2 | 5 | 6 | | | | - | 3 | 6 | 7 | 20 | 6 | 70 | - | 3 | 6 | 7 | 30 | 6 | 80 | - | 4 | 6 | 8 | 20 | 6 | 70 | - | 4 | 6 | 8 | 30 | 6 | 80 | - | 5 | 7 | 9 | | | | - +----+----+----+----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_right_sort_order() -> Result<()> { - let left = build_table( - ("a1", &vec![0, 1, 2, 3]), - ("b1", &vec![3, 4, 5, 7]), - ("c1", &vec![6, 7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![0, 10, 20, 30]), - ("b2", &vec![2, 4, 5, 6]), - ("c2", &vec![60, 70, 80, 90]), - ); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, Right).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+----+ - | | | | 0 | 2 | 60 | - | 1 | 4 | 7 | 10 | 4 | 70 | - | 2 | 5 | 8 | 20 | 5 | 80 | - | | | | 30 | 6 | 90 | - +----+----+----+----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_left_multiple_batches() -> Result<()> { - let left_batch_1 = build_table_i32( - ("a1", &vec![0, 1, 2]), - ("b1", &vec![3, 4, 5]), - ("c1", &vec![4, 5, 6]), - ); - let left_batch_2 = build_table_i32( - ("a1", &vec![3, 4, 5, 6]), - ("b1", &vec![6, 6, 7, 9]), - ("c1", &vec![7, 8, 9, 9]), - ); - let right_batch_1 = build_table_i32( - ("a2", &vec![0, 10, 20]), - ("b2", &vec![2, 4, 6]), - ("c2", &vec![50, 60, 70]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![30, 40]), - ("b2", &vec![6, 8]), - ("c2", &vec![80, 90]), - ); - let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); - let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, Left).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+----+ - | 0 | 3 | 4 | | | | - | 1 | 4 | 5 | 10 | 4 | 60 | - | 2 | 5 | 6 | | | | - | 3 | 6 | 7 | 20 | 6 | 70 | - | 3 | 6 | 7 | 30 | 6 | 80 | - | 4 | 6 | 8 | 20 | 6 | 70 | - | 4 | 6 | 8 | 30 | 6 | 80 | - | 5 | 7 | 9 | | | | - | 6 | 9 | 9 | | | | - +----+----+----+----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_right_multiple_batches() -> Result<()> { - let right_batch_1 = build_table_i32( - ("a2", &vec![0, 1, 2]), - ("b2", &vec![3, 4, 5]), - ("c2", &vec![4, 5, 6]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![3, 4, 5, 6]), - ("b2", &vec![6, 6, 7, 9]), - ("c2", &vec![7, 8, 9, 9]), - ); - let left_batch_1 = build_table_i32( - ("a1", &vec![0, 10, 20]), - ("b1", &vec![2, 4, 6]), - ("c1", &vec![50, 60, 70]), - ); - let left_batch_2 = build_table_i32( - ("a1", &vec![30, 40]), - ("b1", &vec![6, 8]), - ("c1", &vec![80, 90]), - ); - let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); - let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, Right).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+----+ - | | | | 0 | 3 | 4 | - | 10 | 4 | 60 | 1 | 4 | 5 | - | | | | 2 | 5 | 6 | - | 20 | 6 | 70 | 3 | 6 | 7 | - | 30 | 6 | 80 | 3 | 6 | 7 | - | 20 | 6 | 70 | 4 | 6 | 8 | - | 30 | 6 | 80 | 4 | 6 | 8 | - | | | | 5 | 7 | 9 | - | | | | 6 | 9 | 9 | - +----+----+----+----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn join_full_multiple_batches() -> Result<()> { - let left_batch_1 = build_table_i32( - ("a1", &vec![0, 1, 2]), - ("b1", &vec![3, 4, 5]), - ("c1", &vec![4, 5, 6]), - ); - let left_batch_2 = build_table_i32( - ("a1", &vec![3, 4, 5, 6]), - ("b1", &vec![6, 6, 7, 9]), - ("c1", &vec![7, 8, 9, 9]), - ); - let right_batch_1 = build_table_i32( - ("a2", &vec![0, 10, 20]), - ("b2", &vec![2, 4, 6]), - ("c2", &vec![50, 60, 70]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![30, 40]), - ("b2", &vec![6, 8]), - ("c2", &vec![80, 90]), - ); - let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); - let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - )]; - - let (_, batches) = join_collect(left, right, on, Full).await?; - assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+----+ - | | | | 0 | 2 | 50 | - | | | | 40 | 8 | 90 | - | 0 | 3 | 4 | | | | - | 1 | 4 | 5 | 10 | 4 | 60 | - | 2 | 5 | 6 | | | | - | 3 | 6 | 7 | 20 | 6 | 70 | - | 3 | 6 | 7 | 30 | 6 | 80 | - | 4 | 6 | 8 | 20 | 6 | 70 | - | 4 | 6 | 8 | 30 | 6 | 80 | - | 5 | 7 | 9 | | | | - | 6 | 9 | 9 | | | | - +----+----+----+----+----+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn overallocation_single_batch_no_spill() -> Result<()> { - let left = build_table( - ("a1", &vec![0, 1, 2, 3, 4, 5]), - ("b1", &vec![1, 2, 3, 4, 5, 6]), - ("c1", &vec![4, 5, 6, 7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![0, 10, 20, 30, 40]), - ("b2", &vec![1, 3, 4, 6, 8]), - ("c2", &vec![50, 60, 70, 80, 90]), - ); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - )]; - let sort_options = vec![SortOptions::default(); on.len()]; - - let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; - - // Disable DiskManager to prevent spilling - let runtime = RuntimeEnvBuilder::new() - .with_memory_limit(100, 1.0) - .with_disk_manager(DiskManagerConfig::Disabled) - .build_arc()?; - let session_config = SessionConfig::default().with_batch_size(50); - - for join_type in join_types { - let task_ctx = TaskContext::default() - .with_session_config(session_config.clone()) - .with_runtime(Arc::clone(&runtime)); - let task_ctx = Arc::new(task_ctx); - - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - join_type, - sort_options.clone(), - false, - )?; - - let stream = join.execute(0, task_ctx)?; - let err = common::collect(stream).await.unwrap_err(); - - assert_contains!(err.to_string(), "Failed to allocate additional"); - assert_contains!(err.to_string(), "SMJStream[0]"); - assert_contains!(err.to_string(), "Disk spilling disabled"); - assert!(join.metrics().is_some()); - assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); - } - - Ok(()) - } - - #[tokio::test] - async fn overallocation_multi_batch_no_spill() -> Result<()> { - let left_batch_1 = build_table_i32( - ("a1", &vec![0, 1]), - ("b1", &vec![1, 1]), - ("c1", &vec![4, 5]), - ); - let left_batch_2 = build_table_i32( - ("a1", &vec![2, 3]), - ("b1", &vec![1, 1]), - ("c1", &vec![6, 7]), - ); - let left_batch_3 = build_table_i32( - ("a1", &vec![4, 5]), - ("b1", &vec![1, 1]), - ("c1", &vec![8, 9]), - ); - let right_batch_1 = build_table_i32( - ("a2", &vec![0, 10]), - ("b2", &vec![1, 1]), - ("c2", &vec![50, 60]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![20, 30]), - ("b2", &vec![1, 1]), - ("c2", &vec![70, 80]), - ); - let right_batch_3 = - build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90])); - let left = - build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); - let right = - build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - )]; - let sort_options = vec![SortOptions::default(); on.len()]; - - let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; - - // Disable DiskManager to prevent spilling - let runtime = RuntimeEnvBuilder::new() - .with_memory_limit(100, 1.0) - .with_disk_manager(DiskManagerConfig::Disabled) - .build_arc()?; - let session_config = SessionConfig::default().with_batch_size(50); - - for join_type in join_types { - let task_ctx = TaskContext::default() - .with_session_config(session_config.clone()) - .with_runtime(Arc::clone(&runtime)); - let task_ctx = Arc::new(task_ctx); - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - join_type, - sort_options.clone(), - false, - )?; - - let stream = join.execute(0, task_ctx)?; - let err = common::collect(stream).await.unwrap_err(); - - assert_contains!(err.to_string(), "Failed to allocate additional"); - assert_contains!(err.to_string(), "SMJStream[0]"); - assert_contains!(err.to_string(), "Disk spilling disabled"); - assert!(join.metrics().is_some()); - assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); - } - - Ok(()) - } - - #[tokio::test] - async fn overallocation_single_batch_spill() -> Result<()> { - let left = build_table( - ("a1", &vec![0, 1, 2, 3, 4, 5]), - ("b1", &vec![1, 2, 3, 4, 5, 6]), - ("c1", &vec![4, 5, 6, 7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![0, 10, 20, 30, 40]), - ("b2", &vec![1, 3, 4, 6, 8]), - ("c2", &vec![50, 60, 70, 80, 90]), - ); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - )]; - let sort_options = vec![SortOptions::default(); on.len()]; - - let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; - - // Enable DiskManager to allow spilling - let runtime = RuntimeEnvBuilder::new() - .with_memory_limit(100, 1.0) - .with_disk_manager(DiskManagerConfig::NewOs) - .build_arc()?; - - for batch_size in [1, 50] { - let session_config = SessionConfig::default().with_batch_size(batch_size); - - for join_type in &join_types { - let task_ctx = TaskContext::default() - .with_session_config(session_config.clone()) - .with_runtime(Arc::clone(&runtime)); - let task_ctx = Arc::new(task_ctx); - - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - *join_type, - sort_options.clone(), - false, - )?; - - let stream = join.execute(0, task_ctx)?; - let spilled_join_result = common::collect(stream).await.unwrap(); - - assert!(join.metrics().is_some()); - assert!(join.metrics().unwrap().spill_count().unwrap() > 0); - assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); - assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); - - // Run the test with no spill configuration as - let task_ctx_no_spill = - TaskContext::default().with_session_config(session_config.clone()); - let task_ctx_no_spill = Arc::new(task_ctx_no_spill); - - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - *join_type, - sort_options.clone(), - false, - )?; - let stream = join.execute(0, task_ctx_no_spill)?; - let no_spilled_join_result = common::collect(stream).await.unwrap(); - - assert!(join.metrics().is_some()); - assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); - // Compare spilled and non spilled data to check spill logic doesn't corrupt the data - assert_eq!(spilled_join_result, no_spilled_join_result); - } - } - - Ok(()) - } - - #[tokio::test] - async fn overallocation_multi_batch_spill() -> Result<()> { - let left_batch_1 = build_table_i32( - ("a1", &vec![0, 1]), - ("b1", &vec![1, 1]), - ("c1", &vec![4, 5]), - ); - let left_batch_2 = build_table_i32( - ("a1", &vec![2, 3]), - ("b1", &vec![1, 1]), - ("c1", &vec![6, 7]), - ); - let left_batch_3 = build_table_i32( - ("a1", &vec![4, 5]), - ("b1", &vec![1, 1]), - ("c1", &vec![8, 9]), - ); - let right_batch_1 = build_table_i32( - ("a2", &vec![0, 10]), - ("b2", &vec![1, 1]), - ("c2", &vec![50, 60]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![20, 30]), - ("b2", &vec![1, 1]), - ("c2", &vec![70, 80]), - ); - let right_batch_3 = - build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90])); - let left = - build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); - let right = - build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - )]; - let sort_options = vec![SortOptions::default(); on.len()]; - - let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; - - // Enable DiskManager to allow spilling - let runtime = RuntimeEnvBuilder::new() - .with_memory_limit(500, 1.0) - .with_disk_manager(DiskManagerConfig::NewOs) - .build_arc()?; - - for batch_size in [1, 50] { - let session_config = SessionConfig::default().with_batch_size(batch_size); - - for join_type in &join_types { - let task_ctx = TaskContext::default() - .with_session_config(session_config.clone()) - .with_runtime(Arc::clone(&runtime)); - let task_ctx = Arc::new(task_ctx); - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - *join_type, - sort_options.clone(), - false, - )?; - - let stream = join.execute(0, task_ctx)?; - let spilled_join_result = common::collect(stream).await.unwrap(); - assert!(join.metrics().is_some()); - assert!(join.metrics().unwrap().spill_count().unwrap() > 0); - assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); - assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); - - // Run the test with no spill configuration as - let task_ctx_no_spill = - TaskContext::default().with_session_config(session_config.clone()); - let task_ctx_no_spill = Arc::new(task_ctx_no_spill); - - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - *join_type, - sort_options.clone(), - false, - )?; - let stream = join.execute(0, task_ctx_no_spill)?; - let no_spilled_join_result = common::collect(stream).await.unwrap(); - - assert!(join.metrics().is_some()); - assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); - // Compare spilled and non spilled data to check spill logic doesn't corrupt the data - assert_eq!(spilled_join_result, no_spilled_join_result); - } - } - - Ok(()) - } - - fn build_joined_record_batches() -> Result { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("x", DataType::Int32, true), - Field::new("y", DataType::Int32, true), - ])); - - let mut batches = JoinedRecordBatches { - batches: vec![], - filter_mask: BooleanBuilder::new(), - row_indices: UInt64Builder::new(), - batch_ids: vec![], - }; - - // Insert already prejoined non-filtered rows - batches.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![10, 10])), - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![11, 9])), - ], - )?); - - batches.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![11])), - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![12])), - ], - )?); - - batches.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![12, 12])), - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![11, 13])), - ], - )?); - - batches.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![13])), - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![12])), - ], - )?); - - batches.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![14, 14])), - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![12, 11])), - ], - )?); - - let streamed_indices = vec![0, 0]; - batches.batch_ids.extend(vec![0; streamed_indices.len()]); - batches - .row_indices - .extend(&UInt64Array::from(streamed_indices)); - - let streamed_indices = vec![1]; - batches.batch_ids.extend(vec![0; streamed_indices.len()]); - batches - .row_indices - .extend(&UInt64Array::from(streamed_indices)); - - let streamed_indices = vec![0, 0]; - batches.batch_ids.extend(vec![1; streamed_indices.len()]); - batches - .row_indices - .extend(&UInt64Array::from(streamed_indices)); - - let streamed_indices = vec![0]; - batches.batch_ids.extend(vec![2; streamed_indices.len()]); - batches - .row_indices - .extend(&UInt64Array::from(streamed_indices)); - - let streamed_indices = vec![0, 0]; - batches.batch_ids.extend(vec![3; streamed_indices.len()]); - batches - .row_indices - .extend(&UInt64Array::from(streamed_indices)); - - batches - .filter_mask - .extend(&BooleanArray::from(vec![true, false])); - batches.filter_mask.extend(&BooleanArray::from(vec![true])); - batches - .filter_mask - .extend(&BooleanArray::from(vec![false, true])); - batches.filter_mask.extend(&BooleanArray::from(vec![false])); - batches - .filter_mask - .extend(&BooleanArray::from(vec![false, false])); - - Ok(batches) - } - - #[tokio::test] - async fn test_left_outer_join_filtered_mask() -> Result<()> { - let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); - - let output = concat_batches(&schema, &joined_batches.batches)?; - let out_mask = joined_batches.filter_mask.finish(); - let out_indices = joined_batches.row_indices.finish(); - - assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![ - true, false, false, false, false, false, false, false - ]) - ); - - assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![false]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![ - false, false, false, false, false, false, false, false - ]) - ); - - assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0, 0]), - &[0usize; 2], - &BooleanArray::from(vec![true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![ - true, true, false, false, false, false, false, false - ]) - ); - - assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![true, true, true, false, false, false, false, false]) - ); - - assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, false, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![ - Some(true), - None, - Some(true), - Some(false), - Some(false), - Some(false), - Some(false), - Some(false) - ]) - ); - - assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![ - None, - None, - Some(true), - Some(false), - Some(false), - Some(false), - Some(false), - Some(false) - ]) - ); - - assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![ - None, - Some(true), - Some(true), - Some(false), - Some(false), - Some(false), - Some(false), - Some(false) - ]) - ); - - assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, false]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![ - None, - None, - Some(false), - Some(false), - Some(false), - Some(false), - Some(false), - Some(false) - ]) - ); - - let corrected_mask = get_corrected_filter_mask( - Left, - &out_indices, - &joined_batches.batch_ids, - &out_mask, - output.num_rows(), - ) - .unwrap(); - - assert_eq!( - corrected_mask, - BooleanArray::from(vec![ - Some(true), - None, - Some(true), - None, - Some(true), - Some(false), - None, - Some(false) - ]) - ); - - let filtered_rb = filter_record_batch(&output, &corrected_mask)?; - - assert_snapshot!(batches_to_string(&[filtered_rb]), @r#" - +---+----+---+----+ - | a | b | x | y | - +---+----+---+----+ - | 1 | 10 | 1 | 11 | - | 1 | 11 | 1 | 12 | - | 1 | 12 | 1 | 13 | - +---+----+---+----+ - "#); - - // output null rows - - let null_mask = arrow::compute::not(&corrected_mask)?; - assert_eq!( - null_mask, - BooleanArray::from(vec![ - Some(false), - None, - Some(false), - None, - Some(false), - Some(true), - None, - Some(true) - ]) - ); - - let null_joined_batch = filter_record_batch(&output, &null_mask)?; - - assert_snapshot!(batches_to_string(&[null_joined_batch]), @r#" - +---+----+---+----+ - | a | b | x | y | - +---+----+---+----+ - | 1 | 13 | 1 | 12 | - | 1 | 14 | 1 | 11 | - +---+----+---+----+ - "#); - Ok(()) - } - - #[tokio::test] - async fn test_left_semi_join_filtered_mask() -> Result<()> { - let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); - - let output = concat_batches(&schema, &joined_batches.batches)?; - let out_mask = joined_batches.filter_mask.finish(); - let out_indices = joined_batches.row_indices.finish(); - - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![true]) - ); - - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![false]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None]) - ); - - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0]), - &[0usize; 2], - &BooleanArray::from(vec![true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![Some(true), None]) - ); - - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![Some(true), None, None]) - ); - - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, false, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![Some(true), None, None]) - ); - - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None, None, Some(true),]) - ); - - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None, Some(true), None]) - ); - - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, false]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); - - let corrected_mask = get_corrected_filter_mask( - LeftSemi, - &out_indices, - &joined_batches.batch_ids, - &out_mask, - output.num_rows(), - ) - .unwrap(); - - assert_eq!( - corrected_mask, - BooleanArray::from(vec![ - Some(true), - None, - Some(true), - None, - Some(true), - None, - None, - None - ]) - ); - - let filtered_rb = filter_record_batch(&output, &corrected_mask)?; - - assert_snapshot!(batches_to_string(&[filtered_rb]), @r#" - +---+----+---+----+ - | a | b | x | y | - +---+----+---+----+ - | 1 | 10 | 1 | 11 | - | 1 | 11 | 1 | 12 | - | 1 | 12 | 1 | 13 | - +---+----+---+----+ - "#); - - // output null rows - let null_mask = arrow::compute::not(&corrected_mask)?; - assert_eq!( - null_mask, - BooleanArray::from(vec![ - Some(false), - None, - Some(false), - None, - Some(false), - None, - None, - None - ]) - ); - - let null_joined_batch = filter_record_batch(&output, &null_mask)?; - - assert_snapshot!(batches_to_string(&[null_joined_batch]), @r#" - +---+---+---+---+ - | a | b | x | y | - +---+---+---+---+ - +---+---+---+---+ - "#); - Ok(()) - } - - #[tokio::test] - async fn test_anti_join_filtered_mask() -> Result<()> { - for join_type in [LeftAnti, RightAnti] { - let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); - - let output = concat_batches(&schema, &joined_batches.batches)?; - let out_mask = joined_batches.filter_mask.finish(); - let out_indices = joined_batches.row_indices.finish(); - - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![true]), - 1 - ) - .unwrap(), - BooleanArray::from(vec![None]) - ); - - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![false]), - 1 - ) - .unwrap(), - BooleanArray::from(vec![Some(true)]) - ); - - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0]), - &[0usize; 2], - &BooleanArray::from(vec![true, true]), - 2 - ) - .unwrap(), - BooleanArray::from(vec![None, None]) - ); - - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, true, true]), - 3 - ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); - - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, false, true]), - 3 - ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); - - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, true]), - 3 - ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); - - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, true, true]), - 3 - ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); - - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, false]), - 3 - ) - .unwrap(), - BooleanArray::from(vec![None, None, Some(true)]) - ); - - let corrected_mask = get_corrected_filter_mask( - join_type, - &out_indices, - &joined_batches.batch_ids, - &out_mask, - output.num_rows(), - ) - .unwrap(); - - assert_eq!( - corrected_mask, - BooleanArray::from(vec![ - None, - None, - None, - None, - None, - Some(true), - None, - Some(true) - ]) - ); - - let filtered_rb = filter_record_batch(&output, &corrected_mask)?; - - allow_duplicates! { - assert_snapshot!(batches_to_string(&[filtered_rb]), @r#" - +---+----+---+----+ - | a | b | x | y | - +---+----+---+----+ - | 1 | 13 | 1 | 12 | - | 1 | 14 | 1 | 11 | - +---+----+---+----+ - "#); - } - - // output null rows - let null_mask = arrow::compute::not(&corrected_mask)?; - assert_eq!( - null_mask, - BooleanArray::from(vec![ - None, - None, - None, - None, - None, - Some(false), - None, - Some(false), - ]) - ); - - let null_joined_batch = filter_record_batch(&output, &null_mask)?; - - allow_duplicates! { - assert_snapshot!(batches_to_string(&[null_joined_batch]), @r#" - +---+---+---+---+ - | a | b | x | y | - +---+---+---+---+ - +---+---+---+---+ - "#); - } - } - Ok(()) - } - - /// Returns the column names on the schema - fn columns(schema: &Schema) -> Vec { - schema.fields().iter().map(|f| f.name().clone()).collect() - } -} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs new file mode 100644 index 0000000000000..592878a3bb1c5 --- /dev/null +++ b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs @@ -0,0 +1,594 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the Sort-Merge join execution plan. +//! A Sort-Merge join plan consumes two sorted children plans and produces +//! joined output by given join type and other options. + +use std::any::Any; +use std::fmt::Formatter; +use std::sync::Arc; + +use crate::execution_plan::{boundedness_from_children, EmissionType}; +use crate::expressions::PhysicalSortExpr; +use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics; +use crate::joins::sort_merge_join::stream::SortMergeJoinStream; +use crate::joins::utils::{ + build_join_schema, check_join_is_valid, estimate_join_statistics, + reorder_output_after_swap, symmetric_join_output_partitioning, JoinFilter, JoinOn, + JoinOnRef, +}; +use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::projection::{ + join_allows_pushdown, join_table_borders, new_join_children, + physical_to_column_exprs, update_join_on, ProjectionExec, +}; +use crate::{ + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, + PlanProperties, SendableRecordBatchStream, Statistics, +}; + +use arrow::compute::SortOptions; +use arrow::datatypes::SchemaRef; +use datafusion_common::{ + internal_err, plan_err, JoinSide, JoinType, NullEquality, Result, +}; +use datafusion_execution::memory_pool::MemoryConsumer; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::join_equivalence_properties; +use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; + +/// Join execution plan that executes equi-join predicates on multiple partitions using Sort-Merge +/// join algorithm and applies an optional filter post join. Can be used to join arbitrarily large +/// inputs where one or both of the inputs don't fit in the available memory. +/// +/// # Join Expressions +/// +/// Equi-join predicate (e.g. ` = `) expressions are represented by [`Self::on`]. +/// +/// Non-equality predicates, which can not be pushed down to join inputs (e.g. +/// ` != `) are known as "filter expressions" and are evaluated +/// after the equijoin predicates. They are represented by [`Self::filter`]. These are optional +/// expressions. +/// +/// # Sorting +/// +/// Assumes that both the left and right input to the join are pre-sorted. It is not the +/// responsibility of this execution plan to sort the inputs. +/// +/// # "Streamed" vs "Buffered" +/// +/// The number of record batches of streamed input currently present in the memory will depend +/// on the output batch size of the execution plan. There is no spilling support for streamed input. +/// The comparisons are performed from values of join keys in streamed input with the values of +/// join keys in buffered input. One row in streamed record batch could be matched with multiple rows in +/// buffered input batches. The streamed input is managed through the states in `StreamedState` +/// and streamed input batches are represented by `StreamedBatch`. +/// +/// Buffered input is buffered for all record batches having the same value of join key. +/// If the memory limit increases beyond the specified value and spilling is enabled, +/// buffered batches could be spilled to disk. If spilling is disabled, the execution +/// will fail under the same conditions. Multiple record batches of buffered could currently reside +/// in memory/disk during the execution. The number of buffered batches residing in +/// memory/disk depends on the number of rows of buffered input having the same value +/// of join key as that of streamed input rows currently present in memory. Due to pre-sorted inputs, +/// the algorithm understands when it is not needed anymore, and releases the buffered batches +/// from memory/disk. The buffered input is managed through the states in `BufferedState` +/// and buffered input batches are represented by `BufferedBatch`. +/// +/// Depending on the type of join, left or right input may be selected as streamed or buffered +/// respectively. For example, in a left-outer join, the left execution plan will be selected as +/// streamed input while in a right-outer join, the right execution plan will be selected as the +/// streamed input. +/// +/// Reference for the algorithm: +/// . +/// +/// Helpful short video demonstration: +/// . +#[derive(Debug, Clone)] +pub struct SortMergeJoinExec { + /// Left sorted joining execution plan + pub left: Arc, + /// Right sorting joining execution plan + pub right: Arc, + /// Set of common columns used to join on + pub on: JoinOn, + /// Filters which are applied while finding matching rows + pub filter: Option, + /// How the join is performed + pub join_type: JoinType, + /// The schema once the join is applied + schema: SchemaRef, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// The left SortExpr + left_sort_exprs: LexOrdering, + /// The right SortExpr + right_sort_exprs: LexOrdering, + /// Sort options of join columns used in sorting left and right execution plans + pub sort_options: Vec, + /// Defines the null equality for the join. + pub null_equality: NullEquality, + /// Cache holding plan properties like equivalences, output partitioning etc. + cache: PlanProperties, +} + +impl SortMergeJoinExec { + /// Tries to create a new [SortMergeJoinExec]. + /// The inputs are sorted using `sort_options` are applied to the columns in the `on` + /// # Error + /// This function errors when it is not possible to join the left and right sides on keys `on`. + pub fn try_new( + left: Arc, + right: Arc, + on: JoinOn, + filter: Option, + join_type: JoinType, + sort_options: Vec, + null_equality: NullEquality, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + + check_join_is_valid(&left_schema, &right_schema, &on)?; + if sort_options.len() != on.len() { + return plan_err!( + "Expected number of sort options: {}, actual: {}", + on.len(), + sort_options.len() + ); + } + + let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on + .iter() + .zip(sort_options.iter()) + .map(|((l, r), sort_op)| { + let left = PhysicalSortExpr { + expr: Arc::clone(l), + options: *sort_op, + }; + let right = PhysicalSortExpr { + expr: Arc::clone(r), + options: *sort_op, + }; + (left, right) + }) + .unzip(); + let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else { + return plan_err!( + "SortMergeJoinExec requires valid sort expressions for its left side" + ); + }; + let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else { + return plan_err!( + "SortMergeJoinExec requires valid sort expressions for its right side" + ); + }; + + let schema = + Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); + let cache = + Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on)?; + Ok(Self { + left, + right, + on, + filter, + join_type, + schema, + metrics: ExecutionPlanMetricsSet::new(), + left_sort_exprs, + right_sort_exprs, + sort_options, + null_equality, + cache, + }) + } + + /// Get probe side (e.g streaming side) information for this sort merge join. + /// In current implementation, probe side is determined according to join type. + pub fn probe_side(join_type: &JoinType) -> JoinSide { + // When output schema contains only the right side, probe side is right. + // Otherwise probe side is the left side. + match join_type { + // TODO: sort merge support for right mark (tracked here: https://github.com/apache/datafusion/issues/16226) + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => JoinSide::Right, + JoinType::Inner + | JoinType::Left + | JoinType::Full + | JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::LeftMark => JoinSide::Left, + } + } + + /// Calculate order preservation flags for this sort merge join. + fn maintains_input_order(join_type: JoinType) -> Vec { + match join_type { + JoinType::Inner => vec![true, false], + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark => vec![true, false], + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => { + vec![false, true] + } + _ => vec![false, false], + } + } + + /// Set of common columns used to join on + pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] { + &self.on + } + + /// Ref to right execution plan + pub fn right(&self) -> &Arc { + &self.right + } + + /// Join type + pub fn join_type(&self) -> JoinType { + self.join_type + } + + /// Ref to left execution plan + pub fn left(&self) -> &Arc { + &self.left + } + + /// Ref to join filter + pub fn filter(&self) -> &Option { + &self.filter + } + + /// Ref to sort options + pub fn sort_options(&self) -> &[SortOptions] { + &self.sort_options + } + + /// Null equality + pub fn null_equality(&self) -> NullEquality { + self.null_equality + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties( + left: &Arc, + right: &Arc, + schema: SchemaRef, + join_type: JoinType, + join_on: JoinOnRef, + ) -> Result { + // Calculate equivalence properties: + let eq_properties = join_equivalence_properties( + left.equivalence_properties().clone(), + right.equivalence_properties().clone(), + &join_type, + schema, + &Self::maintains_input_order(join_type), + Some(Self::probe_side(&join_type)), + join_on, + )?; + + let output_partitioning = + symmetric_join_output_partitioning(left, right, &join_type)?; + + Ok(PlanProperties::new( + eq_properties, + output_partitioning, + EmissionType::Incremental, + boundedness_from_children([left, right]), + )) + } + + /// # Notes: + /// + /// This function should be called BEFORE inserting any repartitioning + /// operators on the join's children. Check [`super::super::HashJoinExec::swap_inputs`] + /// for more details. + pub fn swap_inputs(&self) -> Result> { + let left = self.left(); + let right = self.right(); + let new_join = SortMergeJoinExec::try_new( + Arc::clone(right), + Arc::clone(left), + self.on() + .iter() + .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) + .collect::>(), + self.filter().as_ref().map(JoinFilter::swap), + self.join_type().swap(), + self.sort_options.clone(), + self.null_equality, + )?; + + // TODO: OR this condition with having a built-in projection (like + // ordinary hash join) when we support it. + if matches!( + self.join_type(), + JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + ) { + Ok(Arc::new(new_join)) + } else { + reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema()) + } + } +} + +impl DisplayAs for SortMergeJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let on = self + .on + .iter() + .map(|(c1, c2)| format!("({c1}, {c2})")) + .collect::>() + .join(", "); + let display_null_equality = + if matches!(self.null_equality(), NullEquality::NullEqualsNull) { + ", NullsEqual: true" + } else { + "" + }; + write!( + f, + "SortMergeJoin: join_type={:?}, on=[{}]{}{}", + self.join_type, + on, + self.filter.as_ref().map_or_else( + || "".to_string(), + |f| format!(", filter={}", f.expression()) + ), + display_null_equality, + ) + } + DisplayFormatType::TreeRender => { + let on = self + .on + .iter() + .map(|(c1, c2)| { + format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref())) + }) + .collect::>() + .join(", "); + + if self.join_type() != JoinType::Inner { + writeln!(f, "join_type={:?}", self.join_type)?; + } + writeln!(f, "on={on}")?; + + if matches!(self.null_equality(), NullEquality::NullEqualsNull) { + writeln!(f, "NullsEqual: true")?; + } + + Ok(()) + } + } + } +} + +impl ExecutionPlan for SortMergeJoinExec { + fn name(&self) -> &'static str { + "SortMergeJoinExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn required_input_distribution(&self) -> Vec { + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); + vec![ + Distribution::HashPartitioned(left_expr), + Distribution::HashPartitioned(right_expr), + ] + } + + fn required_input_ordering(&self) -> Vec> { + vec![ + Some(OrderingRequirements::from(self.left_sort_exprs.clone())), + Some(OrderingRequirements::from(self.right_sort_exprs.clone())), + ] + } + + fn maintains_input_order(&self) -> Vec { + Self::maintains_input_order(self.join_type) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + match &children[..] { + [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + self.on.clone(), + self.filter.clone(), + self.join_type, + self.sort_options.clone(), + self.null_equality, + )?)), + _ => internal_err!("SortMergeJoin wrong number of children"), + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let left_partitions = self.left.output_partitioning().partition_count(); + let right_partitions = self.right.output_partitioning().partition_count(); + if left_partitions != right_partitions { + return internal_err!( + "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\ + consider using RepartitionExec" + ); + } + let (on_left, on_right) = self.on.iter().cloned().unzip(); + let (streamed, buffered, on_streamed, on_buffered) = + if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left { + ( + Arc::clone(&self.left), + Arc::clone(&self.right), + on_left, + on_right, + ) + } else { + ( + Arc::clone(&self.right), + Arc::clone(&self.left), + on_right, + on_left, + ) + }; + + // execute children plans + let streamed = streamed.execute(partition, Arc::clone(&context))?; + let buffered = buffered.execute(partition, Arc::clone(&context))?; + + // create output buffer + let batch_size = context.session_config().batch_size(); + + // create memory reservation + let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]")) + .register(context.memory_pool()); + + // create join stream + Ok(Box::pin(SortMergeJoinStream::try_new( + context.session_config().spill_compression(), + Arc::clone(&self.schema), + self.sort_options.clone(), + self.null_equality, + streamed, + buffered, + on_streamed, + on_buffered, + self.filter.clone(), + self.join_type, + batch_size, + SortMergeJoinMetrics::new(partition, &self.metrics), + reservation, + context.runtime_env(), + )?)) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } + // TODO stats: it is not possible in general to know the output size of joins + // There are some special cases though, for example: + // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` + estimate_join_statistics( + self.left.partition_statistics(None)?, + self.right.partition_statistics(None)?, + self.on.clone(), + &self.join_type, + &self.schema, + ) + } + + /// Tries to swap the projection with its input [`SortMergeJoinExec`]. If it can be done, + /// it returns the new swapped version having the [`SortMergeJoinExec`] as the top plan. + /// Otherwise, it returns None. + fn try_swapping_with_projection( + &self, + projection: &ProjectionExec, + ) -> Result>> { + // Convert projected PhysicalExpr's to columns. If not possible, we cannot proceed. + let Some(projection_as_columns) = physical_to_column_exprs(projection.expr()) + else { + return Ok(None); + }; + + let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders( + self.left().schema().fields().len(), + &projection_as_columns, + ); + + if !join_allows_pushdown( + &projection_as_columns, + &self.schema(), + far_right_left_col_ind, + far_left_right_col_ind, + ) { + return Ok(None); + } + + let Some(new_on) = update_join_on( + &projection_as_columns[0..=far_right_left_col_ind as _], + &projection_as_columns[far_left_right_col_ind as _..], + self.on(), + self.left().schema().fields().len(), + ) else { + return Ok(None); + }; + + let (new_left, new_right) = new_join_children( + &projection_as_columns, + far_right_left_col_ind, + far_left_right_col_ind, + self.children()[0], + self.children()[1], + )?; + + Ok(Some(Arc::new(SortMergeJoinExec::try_new( + Arc::new(new_left), + Arc::new(new_right), + new_on, + self.filter.clone(), + self.join_type, + self.sort_options.clone(), + self.null_equality, + )?))) + } +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/metrics.rs b/datafusion/physical-plan/src/joins/sort_merge_join/metrics.rs new file mode 100644 index 0000000000000..5920cd663a775 --- /dev/null +++ b/datafusion/physical-plan/src/joins/sort_merge_join/metrics.rs @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Module for tracking Sort Merge Join metrics + +use crate::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, SpillMetrics, + Time, +}; + +/// Metrics for SortMergeJoinExec +#[allow(dead_code)] +pub(super) struct SortMergeJoinMetrics { + /// Total time for joining probe-side batches to the build-side batches + join_time: Time, + /// Number of batches consumed by this operator + input_batches: Count, + /// Number of rows consumed by this operator + input_rows: Count, + /// Number of batches produced by this operator + output_batches: Count, + /// Execution metrics + baseline_metrics: BaselineMetrics, + /// Peak memory used for buffered data. + /// Calculated as sum of peak memory values across partitions + peak_mem_used: Gauge, + /// Metrics related to spilling + spill_metrics: SpillMetrics, +} + +impl SortMergeJoinMetrics { + #[allow(dead_code)] + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition); + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition); + let spill_metrics = SpillMetrics::new(metrics, partition); + + let baseline_metrics = BaselineMetrics::new(metrics, partition); + + Self { + join_time, + input_batches, + input_rows, + output_batches, + baseline_metrics, + peak_mem_used, + spill_metrics, + } + } + + pub fn join_time(&self) -> Time { + self.join_time.clone() + } + + pub fn baseline_metrics(&self) -> BaselineMetrics { + self.baseline_metrics.clone() + } + + pub fn input_batches(&self) -> Count { + self.input_batches.clone() + } + + pub fn input_rows(&self) -> Count { + self.input_rows.clone() + } + pub fn output_batches(&self) -> Count { + self.output_batches.clone() + } + + pub fn peak_mem_used(&self) -> Gauge { + self.peak_mem_used.clone() + } + + pub fn spill_metrics(&self) -> SpillMetrics { + self.spill_metrics.clone() + } +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs b/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs new file mode 100644 index 0000000000000..82f18e7414095 --- /dev/null +++ b/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sort Merge Join Execution Plan Operator + +pub use exec::SortMergeJoinExec; + +mod exec; +mod metrics; +mod stream; + +#[cfg(test)] +mod tests; diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs new file mode 100644 index 0000000000000..879f47638d2c4 --- /dev/null +++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs @@ -0,0 +1,2036 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sort-Merge Join execution +//! +//! This module implements the runtime state machine for the Sort-Merge Join +//! operator. It drives two sorted input streams (the *streamed* side and the +//! *buffered* side), compares join keys, and produces joined `RecordBatch`es. + +use std::cmp::Ordering; +use std::collections::{HashMap, VecDeque}; +use std::fs::File; +use std::io::BufReader; +use std::mem::size_of; +use std::ops::Range; +use std::pin::Pin; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::Relaxed; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics; +use crate::joins::utils::JoinFilter; +use crate::spill::spill_manager::SpillManager; +use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream}; + +use arrow::array::{types::UInt64Type, *}; +use arrow::compute::{ + self, concat_batches, filter_record_batch, is_not_null, take, SortOptions, +}; +use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; +use arrow::error::ArrowError; +use arrow::ipc::reader::StreamReader; +use datafusion_common::config::SpillCompression; +use datafusion_common::{ + exec_err, internal_err, not_impl_err, DataFusionError, HashSet, JoinSide, JoinType, + NullEquality, Result, +}; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_physical_expr_common::physical_expr::PhysicalExprRef; + +use futures::{Stream, StreamExt}; + +/// State of SMJ stream +#[derive(Debug, PartialEq, Eq)] +pub(super) enum SortMergeJoinState { + /// Init joining with a new streamed row or a new buffered batches + Init, + /// Polling one streamed row or one buffered batch, or both + Polling, + /// Joining polled data and making output + JoinOutput, + /// No more output + Exhausted, +} + +/// State of streamed data stream +#[derive(Debug, PartialEq, Eq)] +pub(super) enum StreamedState { + /// Init polling + Init, + /// Polling one streamed row + Polling, + /// Ready to produce one streamed row + Ready, + /// No more streamed row + Exhausted, +} + +/// State of buffered data stream +#[derive(Debug, PartialEq, Eq)] +pub(super) enum BufferedState { + /// Init polling + Init, + /// Polling first row in the next batch + PollingFirst, + /// Polling rest rows in the next batch + PollingRest, + /// Ready to produce one batch + Ready, + /// No more buffered batches + Exhausted, +} + +/// Represents a chunk of joined data from streamed and buffered side +pub(super) struct StreamedJoinedChunk { + /// Index of batch in buffered_data + buffered_batch_idx: Option, + /// Array builder for streamed indices + streamed_indices: UInt64Builder, + /// Array builder for buffered indices + /// This could contain nulls if the join is null-joined + buffered_indices: UInt64Builder, +} + +/// Represents a record batch from streamed input. +/// +/// Also stores information of matching rows from buffered batches. +pub(super) struct StreamedBatch { + /// The streamed record batch + pub batch: RecordBatch, + /// The index of row in the streamed batch to compare with buffered batches + pub idx: usize, + /// The join key arrays of streamed batch which are used to compare with buffered batches + /// and to produce output. They are produced by evaluating `on` expressions. + pub join_arrays: Vec, + /// Chunks of indices from buffered side (may be nulls) joined to streamed + pub output_indices: Vec, + /// Index of currently scanned batch from buffered data + pub buffered_batch_idx: Option, + /// Indices that found a match for the given join filter + /// Used for semi joins to keep track the streaming index which got a join filter match + /// and already emitted to the output. + pub join_filter_matched_idxs: HashSet, +} + +impl StreamedBatch { + fn new(batch: RecordBatch, on_column: &[Arc]) -> Self { + let join_arrays = join_arrays(&batch, on_column); + StreamedBatch { + batch, + idx: 0, + join_arrays, + output_indices: vec![], + buffered_batch_idx: None, + join_filter_matched_idxs: HashSet::new(), + } + } + + fn new_empty(schema: SchemaRef) -> Self { + StreamedBatch { + batch: RecordBatch::new_empty(schema), + idx: 0, + join_arrays: vec![], + output_indices: vec![], + buffered_batch_idx: None, + join_filter_matched_idxs: HashSet::new(), + } + } + + /// Appends new pair consisting of current streamed index and `buffered_idx` + /// index of buffered batch with `buffered_batch_idx` index. + fn append_output_pair( + &mut self, + buffered_batch_idx: Option, + buffered_idx: Option, + ) { + // If no current chunk exists or current chunk is not for current buffered batch, + // create a new chunk + if self.output_indices.is_empty() || self.buffered_batch_idx != buffered_batch_idx + { + self.output_indices.push(StreamedJoinedChunk { + buffered_batch_idx, + streamed_indices: UInt64Builder::with_capacity(1), + buffered_indices: UInt64Builder::with_capacity(1), + }); + self.buffered_batch_idx = buffered_batch_idx; + }; + let current_chunk = self.output_indices.last_mut().unwrap(); + + // Append index of streamed batch and index of buffered batch into current chunk + current_chunk.streamed_indices.append_value(self.idx as u64); + if let Some(idx) = buffered_idx { + current_chunk.buffered_indices.append_value(idx as u64); + } else { + current_chunk.buffered_indices.append_null(); + } + } +} + +/// A buffered batch that contains contiguous rows with same join key +/// +/// `BufferedBatch` can exist as either an in-memory `RecordBatch` or a `RefCountedTempFile` on disk. +#[derive(Debug)] +pub(super) struct BufferedBatch { + /// Represents in memory or spilled record batch + pub batch: BufferedBatchState, + /// The range in which the rows share the same join key + pub range: Range, + /// Array refs of the join key + pub join_arrays: Vec, + /// Buffered joined index (null joining buffered) + pub null_joined: Vec, + /// Size estimation used for reserving / releasing memory + pub size_estimation: usize, + /// The indices of buffered batch that the join filter doesn't satisfy. + /// This is a map between right row index and a boolean value indicating whether all joined row + /// of the right row does not satisfy the filter . + /// When dequeuing the buffered batch, we need to produce null joined rows for these indices. + pub join_filter_not_matched_map: HashMap, + /// Current buffered batch number of rows. Equal to batch.num_rows() + /// but if batch is spilled to disk this property is preferable + /// and less expensive + pub num_rows: usize, +} + +impl BufferedBatch { + fn new( + batch: RecordBatch, + range: Range, + on_column: &[PhysicalExprRef], + ) -> Self { + let join_arrays = join_arrays(&batch, on_column); + + // Estimation is calculated as + // inner batch size + // + join keys size + // + worst case null_joined (as vector capacity * element size) + // + Range size + // + size of this estimation + let size_estimation = batch.get_array_memory_size() + + join_arrays + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::() + + batch.num_rows().next_power_of_two() * size_of::() + + size_of::>() + + size_of::(); + + let num_rows = batch.num_rows(); + BufferedBatch { + batch: BufferedBatchState::InMemory(batch), + range, + join_arrays, + null_joined: vec![], + size_estimation, + join_filter_not_matched_map: HashMap::new(), + num_rows, + } + } +} + +// TODO: Spill join arrays (https://github.com/apache/datafusion/pull/17429) +// Used to represent whether the buffered data is currently in memory or written to disk +#[derive(Debug)] +pub(super) enum BufferedBatchState { + // In memory record batch + InMemory(RecordBatch), + // Spilled temp file + Spilled(RefCountedTempFile), +} + +/// Sort-Merge join stream that consumes streamed and buffered data streams +/// and produces joined output stream. +pub(super) struct SortMergeJoinStream { + // ======================================================================== + // PROPERTIES: + // These fields are initialized at the start and remain constant throughout + // the execution. + // ======================================================================== + /// Output schema + pub schema: SchemaRef, + /// Defines the null equality for the join. + pub null_equality: NullEquality, + /// Sort options of join columns used to sort streamed and buffered data stream + pub sort_options: Vec, + /// optional join filter + pub filter: Option, + /// How the join is performed + pub join_type: JoinType, + /// Target output batch size + pub batch_size: usize, + + // ======================================================================== + // STREAMED FIELDS: + // These fields manage the properties and state of the streamed input. + // ======================================================================== + /// Input schema of streamed + pub streamed_schema: SchemaRef, + /// Streamed data stream + pub streamed: SendableRecordBatchStream, + /// Current processing record batch of streamed + pub streamed_batch: StreamedBatch, + /// (used in outer join) Is current streamed row joined at least once? + pub streamed_joined: bool, + /// State of streamed + pub streamed_state: StreamedState, + /// Join key columns of streamed + pub on_streamed: Vec, + + // ======================================================================== + // BUFFERED FIELDS: + // These fields manage the properties and state of the buffered input. + // ======================================================================== + /// Input schema of buffered + pub buffered_schema: SchemaRef, + /// Buffered data stream + pub buffered: SendableRecordBatchStream, + /// Current buffered data + pub buffered_data: BufferedData, + /// (used in outer join) Is current buffered batches joined at least once? + pub buffered_joined: bool, + /// State of buffered + pub buffered_state: BufferedState, + /// Join key columns of buffered + pub on_buffered: Vec, + + // ======================================================================== + // MERGE JOIN STATES: + // These fields track the execution state of merge join and are updated + // during the execution. + // ======================================================================== + /// Current state of the stream + pub state: SortMergeJoinState, + /// Staging output array builders + pub staging_output_record_batches: JoinedRecordBatches, + /// Output buffer. Currently used by filtering as it requires double buffering + /// to avoid small/empty batches. Non-filtered join outputs directly from `staging_output_record_batches.batches` + pub output: RecordBatch, + /// Staging output size, including output batches and staging joined results. + /// Increased when we put rows into buffer and decreased after we actually output batches. + /// Used to trigger output when sufficient rows are ready + pub output_size: usize, + /// The comparison result of current streamed row and buffered batches + pub current_ordering: Ordering, + /// Manages the process of spilling and reading back intermediate data + pub spill_manager: SpillManager, + + // ======================================================================== + // EXECUTION RESOURCES: + // Fields related to managing execution resources and monitoring performance. + // ======================================================================== + /// Metrics + pub join_metrics: SortMergeJoinMetrics, + /// Memory reservation + pub reservation: MemoryReservation, + /// Runtime env + pub runtime_env: Arc, + /// A unique number for each batch + pub streamed_batch_counter: AtomicUsize, +} + +/// Joined batches with attached join filter information +pub(super) struct JoinedRecordBatches { + /// Joined batches. Each batch is already joined columns from left and right sources + pub batches: Vec, + /// Filter match mask for each row(matched/non-matched) + pub filter_mask: BooleanBuilder, + /// Left row indices to glue together rows in `batches` and `filter_mask` + pub row_indices: UInt64Builder, + /// Which unique batch id the row belongs to + /// It is necessary to differentiate rows that are distributed the way when they point to the same + /// row index but in not the same batches + pub batch_ids: Vec, +} + +impl JoinedRecordBatches { + fn clear(&mut self) { + self.batches.clear(); + self.batch_ids.clear(); + self.filter_mask = BooleanBuilder::new(); + self.row_indices = UInt64Builder::new(); + } +} +impl RecordBatchStream for SortMergeJoinStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +/// True if next index refers to either: +/// - another batch id +/// - another row index within same batch id +/// - end of row indices +#[inline(always)] +fn last_index_for_row( + row_index: usize, + indices: &UInt64Array, + batch_ids: &[usize], + indices_len: usize, +) -> bool { + row_index == indices_len - 1 + || batch_ids[row_index] != batch_ids[row_index + 1] + || indices.value(row_index) != indices.value(row_index + 1) +} + +// Returns a corrected boolean bitmask for the given join type +// Values in the corrected bitmask can be: true, false, null +// `true` - the row found its match and sent to the output +// `null` - the row ignored, no output +// `false` - the row sent as NULL joined row +pub(super) fn get_corrected_filter_mask( + join_type: JoinType, + row_indices: &UInt64Array, + batch_ids: &[usize], + filter_mask: &BooleanArray, + expected_size: usize, +) -> Option { + let row_indices_length = row_indices.len(); + let mut corrected_mask: BooleanBuilder = + BooleanBuilder::with_capacity(row_indices_length); + let mut seen_true = false; + + match join_type { + JoinType::Left | JoinType::Right => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) { + seen_true = true; + corrected_mask.append_value(true); + } else if seen_true || !filter_mask.value(i) && !last_index { + corrected_mask.append_null(); // to be ignored and not set to output + } else { + corrected_mask.append_value(false); // to be converted to null joined row + } + + if last_index { + seen_true = false; + } + } + + // Generate null joined rows for records which have no matching join key + corrected_mask.append_n(expected_size - corrected_mask.len(), false); + Some(corrected_mask.finish()) + } + JoinType::LeftMark | JoinType::RightMark => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) && !seen_true { + seen_true = true; + corrected_mask.append_value(true); + } else if seen_true || !filter_mask.value(i) && !last_index { + corrected_mask.append_null(); // to be ignored and not set to output + } else { + corrected_mask.append_value(false); // to be converted to null joined row + } + + if last_index { + seen_true = false; + } + } + + // Generate null joined rows for records which have no matching join key + corrected_mask.append_n(expected_size - corrected_mask.len(), false); + Some(corrected_mask.finish()) + } + JoinType::LeftSemi | JoinType::RightSemi => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) && !seen_true { + seen_true = true; + corrected_mask.append_value(true); + } else { + corrected_mask.append_null(); // to be ignored and not set to output + } + + if last_index { + seen_true = false; + } + } + + Some(corrected_mask.finish()) + } + JoinType::LeftAnti | JoinType::RightAnti => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + + if filter_mask.value(i) { + seen_true = true; + } + + if last_index { + if !seen_true { + corrected_mask.append_value(true); + } else { + corrected_mask.append_null(); + } + + seen_true = false; + } else { + corrected_mask.append_null(); + } + } + // Generate null joined rows for records which have no matching join key, + // for LeftAnti non-matched considered as true + corrected_mask.append_n(expected_size - corrected_mask.len(), true); + Some(corrected_mask.finish()) + } + JoinType::Full => { + let mut mask: Vec> = vec![Some(true); row_indices_length]; + let mut last_true_idx = 0; + let mut first_row_idx = 0; + let mut seen_false = false; + + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + let val = filter_mask.value(i); + let is_null = filter_mask.is_null(i); + + if val { + // memoize the first seen matched row + if !seen_true { + last_true_idx = i; + } + seen_true = true; + } + + if is_null || val { + mask[i] = Some(true); + } else if !is_null && !val && (seen_true || seen_false) { + mask[i] = None; + } else { + mask[i] = Some(false); + } + + if !is_null && !val { + seen_false = true; + } + + if last_index { + // If the left row seen as true its needed to output it once + // To do that we mark all other matches for same row as null to avoid the output + if seen_true { + #[allow(clippy::needless_range_loop)] + for j in first_row_idx..last_true_idx { + mask[j] = None; + } + } + + seen_true = false; + seen_false = false; + last_true_idx = 0; + first_row_idx = i + 1; + } + } + + Some(BooleanArray::from(mask)) + } + // Only outer joins needs to keep track of processed rows and apply corrected filter mask + _ => None, + } +} + +impl Stream for SortMergeJoinStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let join_time = self.join_metrics.join_time().clone(); + let _timer = join_time.timer(); + loop { + match &self.state { + SortMergeJoinState::Init => { + let streamed_exhausted = + self.streamed_state == StreamedState::Exhausted; + let buffered_exhausted = + self.buffered_state == BufferedState::Exhausted; + self.state = if streamed_exhausted && buffered_exhausted { + SortMergeJoinState::Exhausted + } else { + match self.current_ordering { + Ordering::Less | Ordering::Equal => { + if !streamed_exhausted { + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftMark + | JoinType::Right + | JoinType::RightSemi + | JoinType::RightMark + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::Full + ) + { + self.freeze_all()?; + + // If join is filtered and there is joined tuples waiting + // to be filtered + if !self + .staging_output_record_batches + .batches + .is_empty() + { + // Apply filter on joined tuples and get filtered batch + let out_filtered_batch = + self.filter_joined_batch()?; + + // Append filtered batch to the output buffer + self.output = concat_batches( + &self.schema(), + vec![&self.output, &out_filtered_batch], + )?; + + // Send to output if the output buffer surpassed the `batch_size` + if self.output.num_rows() >= self.batch_size { + let record_batch = std::mem::replace( + &mut self.output, + RecordBatch::new_empty( + out_filtered_batch.schema(), + ), + ); + return Poll::Ready(Some(Ok( + record_batch, + ))); + } + } + } + + self.streamed_joined = false; + self.streamed_state = StreamedState::Init; + } + } + Ordering::Greater => { + if !buffered_exhausted { + self.buffered_joined = false; + self.buffered_state = BufferedState::Init; + } + } + } + SortMergeJoinState::Polling + }; + } + SortMergeJoinState::Polling => { + if ![StreamedState::Exhausted, StreamedState::Ready] + .contains(&self.streamed_state) + { + match self.poll_streamed_row(cx)? { + Poll::Ready(_) => {} + Poll::Pending => return Poll::Pending, + } + } + + if ![BufferedState::Exhausted, BufferedState::Ready] + .contains(&self.buffered_state) + { + match self.poll_buffered_batches(cx)? { + Poll::Ready(_) => {} + Poll::Pending => return Poll::Pending, + } + } + let streamed_exhausted = + self.streamed_state == StreamedState::Exhausted; + let buffered_exhausted = + self.buffered_state == BufferedState::Exhausted; + if streamed_exhausted && buffered_exhausted { + self.state = SortMergeJoinState::Exhausted; + continue; + } + self.current_ordering = self.compare_streamed_buffered()?; + self.state = SortMergeJoinState::JoinOutput; + } + SortMergeJoinState::JoinOutput => { + self.join_partial()?; + + if self.output_size < self.batch_size { + if self.buffered_data.scanning_finished() { + self.buffered_data.scanning_reset(); + self.state = SortMergeJoinState::Init; + } + } else { + self.freeze_all()?; + if !self.staging_output_record_batches.batches.is_empty() { + let record_batch = self.output_record_batch_and_reset()?; + // For non-filtered join output whenever the target output batch size + // is hit. For filtered join its needed to output on later phase + // because target output batch size can be hit in the middle of + // filtering causing the filtering to be incomplete and causing + // correctness issues + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark + | JoinType::Full + ) + { + continue; + } + + return Poll::Ready(Some(Ok(record_batch))); + } + return Poll::Pending; + } + } + SortMergeJoinState::Exhausted => { + self.freeze_all()?; + + // if there is still something not processed + if !self.staging_output_record_batches.batches.is_empty() { + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::Full + | JoinType::LeftMark + | JoinType::RightMark + ) + { + let record_batch = self.filter_joined_batch()?; + return Poll::Ready(Some(Ok(record_batch))); + } else { + let record_batch = self.output_record_batch_and_reset()?; + return Poll::Ready(Some(Ok(record_batch))); + } + } else if self.output.num_rows() > 0 { + // if processed but still not outputted because it didn't hit batch size before + let schema = self.output.schema(); + let record_batch = std::mem::replace( + &mut self.output, + RecordBatch::new_empty(schema), + ); + return Poll::Ready(Some(Ok(record_batch))); + } else { + return Poll::Ready(None); + } + } + } + } + } +} + +impl SortMergeJoinStream { + #[allow(clippy::too_many_arguments)] + pub fn try_new( + // Configured via `datafusion.execution.spill_compression`. + spill_compression: SpillCompression, + schema: SchemaRef, + sort_options: Vec, + null_equality: NullEquality, + streamed: SendableRecordBatchStream, + buffered: SendableRecordBatchStream, + on_streamed: Vec>, + on_buffered: Vec>, + filter: Option, + join_type: JoinType, + batch_size: usize, + join_metrics: SortMergeJoinMetrics, + reservation: MemoryReservation, + runtime_env: Arc, + ) -> Result { + let streamed_schema = streamed.schema(); + let buffered_schema = buffered.schema(); + let spill_manager = SpillManager::new( + Arc::clone(&runtime_env), + join_metrics.spill_metrics().clone(), + Arc::clone(&buffered_schema), + ) + .with_compression_type(spill_compression); + Ok(Self { + state: SortMergeJoinState::Init, + sort_options, + null_equality, + schema: Arc::clone(&schema), + streamed_schema: Arc::clone(&streamed_schema), + buffered_schema, + streamed, + buffered, + streamed_batch: StreamedBatch::new_empty(streamed_schema), + buffered_data: BufferedData::default(), + streamed_joined: false, + buffered_joined: false, + streamed_state: StreamedState::Init, + buffered_state: BufferedState::Init, + current_ordering: Ordering::Equal, + on_streamed, + on_buffered, + filter, + staging_output_record_batches: JoinedRecordBatches { + batches: vec![], + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + }, + output: RecordBatch::new_empty(schema), + output_size: 0, + batch_size, + join_type, + join_metrics, + reservation, + runtime_env, + spill_manager, + streamed_batch_counter: AtomicUsize::new(0), + }) + } + + /// Poll next streamed row + fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll>> { + loop { + match &self.streamed_state { + StreamedState::Init => { + if self.streamed_batch.idx + 1 < self.streamed_batch.batch.num_rows() + { + self.streamed_batch.idx += 1; + self.streamed_state = StreamedState::Ready; + return Poll::Ready(Some(Ok(()))); + } else { + self.streamed_state = StreamedState::Polling; + } + } + StreamedState::Polling => match self.streamed.poll_next_unpin(cx)? { + Poll::Pending => { + return Poll::Pending; + } + Poll::Ready(None) => { + self.streamed_state = StreamedState::Exhausted; + } + Poll::Ready(Some(batch)) => { + if batch.num_rows() > 0 { + self.freeze_streamed()?; + self.join_metrics.input_batches().add(1); + self.join_metrics.input_rows().add(batch.num_rows()); + self.streamed_batch = + StreamedBatch::new(batch, &self.on_streamed); + // Every incoming streaming batch should have its unique id + // Check `JoinedRecordBatches.self.streamed_batch_counter` documentation + self.streamed_batch_counter + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + self.streamed_state = StreamedState::Ready; + } + } + }, + StreamedState::Ready => { + return Poll::Ready(Some(Ok(()))); + } + StreamedState::Exhausted => { + return Poll::Ready(None); + } + } + } + } + + fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> Result<()> { + // Shrink memory usage for in-memory batches only + if let BufferedBatchState::InMemory(_) = buffered_batch.batch { + self.reservation + .try_shrink(buffered_batch.size_estimation)?; + } + Ok(()) + } + + fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> { + match self.reservation.try_grow(buffered_batch.size_estimation) { + Ok(_) => { + self.join_metrics + .peak_mem_used() + .set_max(self.reservation.size()); + Ok(()) + } + Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => { + // Spill buffered batch to disk + + match buffered_batch.batch { + BufferedBatchState::InMemory(batch) => { + let spill_file = self + .spill_manager + .spill_record_batch_and_finish( + &[batch], + "sort_merge_join_buffered_spill", + )? + .unwrap(); // Operation only return None if no batches are spilled, here we ensure that at least one batch is spilled + + buffered_batch.batch = BufferedBatchState::Spilled(spill_file); + Ok(()) + } + _ => internal_err!("Buffered batch has empty body"), + } + } + Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()), + }?; + + self.buffered_data.batches.push_back(buffered_batch); + Ok(()) + } + + /// Poll next buffered batches + fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll>> { + loop { + match &self.buffered_state { + BufferedState::Init => { + // pop previous buffered batches + while !self.buffered_data.batches.is_empty() { + let head_batch = self.buffered_data.head_batch(); + // If the head batch is fully processed, dequeue it and produce output of it. + if head_batch.range.end == head_batch.num_rows { + self.freeze_dequeuing_buffered()?; + if let Some(mut buffered_batch) = + self.buffered_data.batches.pop_front() + { + self.produce_buffered_not_matched(&mut buffered_batch)?; + self.free_reservation(buffered_batch)?; + } + } else { + // If the head batch is not fully processed, break the loop. + // Streamed batch will be joined with the head batch in the next step. + break; + } + } + if self.buffered_data.batches.is_empty() { + self.buffered_state = BufferedState::PollingFirst; + } else { + let tail_batch = self.buffered_data.tail_batch_mut(); + tail_batch.range.start = tail_batch.range.end; + tail_batch.range.end += 1; + self.buffered_state = BufferedState::PollingRest; + } + } + BufferedState::PollingFirst => match self.buffered.poll_next_unpin(cx)? { + Poll::Pending => { + return Poll::Pending; + } + Poll::Ready(None) => { + self.buffered_state = BufferedState::Exhausted; + return Poll::Ready(None); + } + Poll::Ready(Some(batch)) => { + self.join_metrics.input_batches().add(1); + self.join_metrics.input_rows().add(batch.num_rows()); + + if batch.num_rows() > 0 { + let buffered_batch = + BufferedBatch::new(batch, 0..1, &self.on_buffered); + + self.allocate_reservation(buffered_batch)?; + self.buffered_state = BufferedState::PollingRest; + } + } + }, + BufferedState::PollingRest => { + if self.buffered_data.tail_batch().range.end + < self.buffered_data.tail_batch().num_rows + { + while self.buffered_data.tail_batch().range.end + < self.buffered_data.tail_batch().num_rows + { + if is_join_arrays_equal( + &self.buffered_data.head_batch().join_arrays, + self.buffered_data.head_batch().range.start, + &self.buffered_data.tail_batch().join_arrays, + self.buffered_data.tail_batch().range.end, + )? { + self.buffered_data.tail_batch_mut().range.end += 1; + } else { + self.buffered_state = BufferedState::Ready; + return Poll::Ready(Some(Ok(()))); + } + } + } else { + match self.buffered.poll_next_unpin(cx)? { + Poll::Pending => { + return Poll::Pending; + } + Poll::Ready(None) => { + self.buffered_state = BufferedState::Ready; + } + Poll::Ready(Some(batch)) => { + // Polling batches coming concurrently as multiple partitions + self.join_metrics.input_batches().add(1); + self.join_metrics.input_rows().add(batch.num_rows()); + if batch.num_rows() > 0 { + let buffered_batch = BufferedBatch::new( + batch, + 0..0, + &self.on_buffered, + ); + self.allocate_reservation(buffered_batch)?; + } + } + } + } + } + BufferedState::Ready => { + return Poll::Ready(Some(Ok(()))); + } + BufferedState::Exhausted => { + return Poll::Ready(None); + } + } + } + } + + /// Get comparison result of streamed row and buffered batches + fn compare_streamed_buffered(&self) -> Result { + if self.streamed_state == StreamedState::Exhausted { + return Ok(Ordering::Greater); + } + if !self.buffered_data.has_buffered_rows() { + return Ok(Ordering::Less); + } + + compare_join_arrays( + &self.streamed_batch.join_arrays, + self.streamed_batch.idx, + &self.buffered_data.head_batch().join_arrays, + self.buffered_data.head_batch().range.start, + &self.sort_options, + self.null_equality, + ) + } + + /// Produce join and fill output buffer until reaching target batch size + /// or the join is finished + fn join_partial(&mut self) -> Result<()> { + // Whether to join streamed rows + let mut join_streamed = false; + // Whether to join buffered rows + let mut join_buffered = false; + // For Mark join we store a dummy id to indicate the the row has a match + let mut mark_row_as_match = false; + + // determine whether we need to join streamed/buffered rows + match self.current_ordering { + Ordering::Less => { + if matches!( + self.join_type, + JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark + ) { + join_streamed = !self.streamed_joined; + } + } + Ordering::Equal => { + if matches!( + self.join_type, + JoinType::LeftSemi + | JoinType::LeftMark + | JoinType::RightSemi + | JoinType::RightMark + ) { + mark_row_as_match = matches!( + self.join_type, + JoinType::LeftMark | JoinType::RightMark + ); + // if the join filter is specified then its needed to output the streamed index + // only if it has not been emitted before + // the `join_filter_matched_idxs` keeps track on if streamed index has a successful + // filter match and prevents the same index to go into output more than once + if self.filter.is_some() { + join_streamed = !self + .streamed_batch + .join_filter_matched_idxs + .contains(&(self.streamed_batch.idx as u64)) + && !self.streamed_joined; + // if the join filter specified there can be references to buffered columns + // so buffered columns are needed to access them + join_buffered = join_streamed; + } else { + join_streamed = !self.streamed_joined; + } + } + if matches!( + self.join_type, + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full + ) { + join_streamed = true; + join_buffered = true; + }; + + if matches!(self.join_type, JoinType::LeftAnti | JoinType::RightAnti) + && self.filter.is_some() + { + join_streamed = !self.streamed_joined; + join_buffered = join_streamed; + } + } + Ordering::Greater => { + if matches!(self.join_type, JoinType::Full) { + join_buffered = !self.buffered_joined; + }; + } + } + if !join_streamed && !join_buffered { + // no joined data + self.buffered_data.scanning_finish(); + return Ok(()); + } + + if join_buffered { + // joining streamed/nulls and buffered + while !self.buffered_data.scanning_finished() + && self.output_size < self.batch_size + { + let scanning_idx = self.buffered_data.scanning_idx(); + if join_streamed { + // Join streamed row and buffered row + self.streamed_batch.append_output_pair( + Some(self.buffered_data.scanning_batch_idx), + Some(scanning_idx), + ); + } else { + // Join nulls and buffered row for FULL join + self.buffered_data + .scanning_batch_mut() + .null_joined + .push(scanning_idx); + } + self.output_size += 1; + self.buffered_data.scanning_advance(); + + if self.buffered_data.scanning_finished() { + self.streamed_joined = join_streamed; + self.buffered_joined = true; + } + } + } else { + // joining streamed and nulls + let scanning_batch_idx = if self.buffered_data.scanning_finished() { + None + } else { + Some(self.buffered_data.scanning_batch_idx) + }; + // For Mark join we store a dummy id to indicate the the row has a match + let scanning_idx = mark_row_as_match.then_some(0); + + self.streamed_batch + .append_output_pair(scanning_batch_idx, scanning_idx); + self.output_size += 1; + self.buffered_data.scanning_finish(); + self.streamed_joined = true; + } + Ok(()) + } + + fn freeze_all(&mut self) -> Result<()> { + self.freeze_buffered(self.buffered_data.batches.len())?; + self.freeze_streamed()?; + Ok(()) + } + + // Produces and stages record batches to ensure dequeued buffered batch + // no longer needed: + // 1. freezes all indices joined to streamed side + // 2. freezes NULLs joined to dequeued buffered batch to "release" it + fn freeze_dequeuing_buffered(&mut self) -> Result<()> { + self.freeze_streamed()?; + // Only freeze and produce the first batch in buffered_data as the batch is fully processed + self.freeze_buffered(1)?; + Ok(()) + } + + // Produces and stages record batch from buffered indices with corresponding + // NULLs on streamed side. + // + // Applicable only in case of Full join. + // + fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> { + if !matches!(self.join_type, JoinType::Full) { + return Ok(()); + } + for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) { + let buffered_indices = UInt64Array::from_iter_values( + buffered_batch.null_joined.iter().map(|&index| index as u64), + ); + if let Some(record_batch) = produce_buffered_null_batch( + &self.schema, + &self.streamed_schema, + &buffered_indices, + buffered_batch, + )? { + let num_rows = record_batch.num_rows(); + self.staging_output_record_batches + .filter_mask + .append_nulls(num_rows); + self.staging_output_record_batches + .row_indices + .append_nulls(num_rows); + self.staging_output_record_batches.batch_ids.resize( + self.staging_output_record_batches.batch_ids.len() + num_rows, + 0, + ); + + self.staging_output_record_batches + .batches + .push(record_batch); + } + buffered_batch.null_joined.clear(); + } + Ok(()) + } + + fn produce_buffered_not_matched( + &mut self, + buffered_batch: &mut BufferedBatch, + ) -> Result<()> { + if !matches!(self.join_type, JoinType::Full) { + return Ok(()); + } + + // For buffered row which is joined with streamed side rows but all joined rows + // don't satisfy the join filter + let not_matched_buffered_indices = buffered_batch + .join_filter_not_matched_map + .iter() + .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None }) + .collect::>(); + + let buffered_indices = + UInt64Array::from_iter_values(not_matched_buffered_indices.iter().copied()); + + if let Some(record_batch) = produce_buffered_null_batch( + &self.schema, + &self.streamed_schema, + &buffered_indices, + buffered_batch, + )? { + let num_rows = record_batch.num_rows(); + + self.staging_output_record_batches + .filter_mask + .append_nulls(num_rows); + self.staging_output_record_batches + .row_indices + .append_nulls(num_rows); + self.staging_output_record_batches.batch_ids.resize( + self.staging_output_record_batches.batch_ids.len() + num_rows, + 0, + ); + self.staging_output_record_batches + .batches + .push(record_batch); + } + buffered_batch.join_filter_not_matched_map.clear(); + + Ok(()) + } + + // Produces and stages record batch for all output indices found + // for current streamed batch and clears staged output indices. + fn freeze_streamed(&mut self) -> Result<()> { + for chunk in self.streamed_batch.output_indices.iter_mut() { + // The row indices of joined streamed batch + let left_indices = chunk.streamed_indices.finish(); + + if left_indices.is_empty() { + continue; + } + + let mut left_columns = self + .streamed_batch + .batch + .columns() + .iter() + .map(|column| take(column, &left_indices, None)) + .collect::, ArrowError>>()?; + + // The row indices of joined buffered batch + let right_indices: UInt64Array = chunk.buffered_indices.finish(); + let mut right_columns = + if matches!(self.join_type, JoinType::LeftMark | JoinType::RightMark) { + vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef] + } else if matches!( + self.join_type, + JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::RightSemi + ) { + vec![] + } else if let Some(buffered_idx) = chunk.buffered_batch_idx { + fetch_right_columns_by_idxs( + &self.buffered_data, + buffered_idx, + &right_indices, + )? + } else { + // If buffered batch none, meaning it is null joined batch. + // We need to create null arrays for buffered columns to join with streamed rows. + create_unmatched_columns( + self.join_type, + &self.buffered_schema, + right_indices.len(), + ) + }; + + // Prepare the columns we apply join filter on later. + // Only for joined rows between streamed and buffered. + let filter_columns = if chunk.buffered_batch_idx.is_some() { + if !matches!(self.join_type, JoinType::Right) { + if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark + ) { + let right_cols = fetch_right_columns_by_idxs( + &self.buffered_data, + chunk.buffered_batch_idx.unwrap(), + &right_indices, + )?; + + get_filter_column(&self.filter, &left_columns, &right_cols) + } else if matches!( + self.join_type, + JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark + ) { + let right_cols = fetch_right_columns_by_idxs( + &self.buffered_data, + chunk.buffered_batch_idx.unwrap(), + &right_indices, + )?; + + get_filter_column(&self.filter, &right_cols, &left_columns) + } else { + get_filter_column(&self.filter, &left_columns, &right_columns) + } + } else { + get_filter_column(&self.filter, &right_columns, &left_columns) + } + } else { + // This chunk is totally for null joined rows (outer join), we don't need to apply join filter. + // Any join filter applied only on either streamed or buffered side will be pushed already. + vec![] + }; + + let columns = if !matches!(self.join_type, JoinType::Right) { + left_columns.extend(right_columns); + left_columns + } else { + right_columns.extend(left_columns); + right_columns + }; + + let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + // Apply join filter if any + if !filter_columns.is_empty() { + if let Some(f) = &self.filter { + // Construct batch with only filter columns + let filter_batch = + RecordBatch::try_new(Arc::clone(f.schema()), filter_columns)?; + + let filter_result = f + .expression() + .evaluate(&filter_batch)? + .into_array(filter_batch.num_rows())?; + + // The boolean selection mask of the join filter result + let pre_mask = + datafusion_common::cast::as_boolean_array(&filter_result)?; + + // If there are nulls in join filter result, exclude them from selecting + // the rows to output. + let mask = if pre_mask.null_count() > 0 { + compute::prep_null_mask_filter( + datafusion_common::cast::as_boolean_array(&filter_result)?, + ) + } else { + pre_mask.clone() + }; + + // Push the filtered batch which contains rows passing join filter to the output + if matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark + | JoinType::Full + ) { + self.staging_output_record_batches + .batches + .push(output_batch); + } else { + let filtered_batch = filter_record_batch(&output_batch, &mask)?; + self.staging_output_record_batches + .batches + .push(filtered_batch); + } + + if !matches!(self.join_type, JoinType::Full) { + self.staging_output_record_batches.filter_mask.extend(&mask); + } else { + self.staging_output_record_batches + .filter_mask + .extend(pre_mask); + } + self.staging_output_record_batches + .row_indices + .extend(&left_indices); + self.staging_output_record_batches.batch_ids.resize( + self.staging_output_record_batches.batch_ids.len() + + left_indices.len(), + self.streamed_batch_counter.load(Relaxed), + ); + + // For outer joins, we need to push the null joined rows to the output if + // all joined rows are failed on the join filter. + // I.e., if all rows joined from a streamed row are failed with the join filter, + // we need to join it with nulls as buffered side. + if matches!(self.join_type, JoinType::Full) { + let buffered_batch = &mut self.buffered_data.batches + [chunk.buffered_batch_idx.unwrap()]; + + for i in 0..pre_mask.len() { + // If the buffered row is not joined with streamed side, + // skip it. + if right_indices.is_null(i) { + continue; + } + + let buffered_index = right_indices.value(i); + + buffered_batch.join_filter_not_matched_map.insert( + buffered_index, + *buffered_batch + .join_filter_not_matched_map + .get(&buffered_index) + .unwrap_or(&true) + && !pre_mask.value(i), + ); + } + } + } else { + self.staging_output_record_batches + .batches + .push(output_batch); + } + } else { + self.staging_output_record_batches + .batches + .push(output_batch); + } + } + + self.streamed_batch.output_indices.clear(); + + Ok(()) + } + + fn output_record_batch_and_reset(&mut self) -> Result { + let record_batch = + concat_batches(&self.schema, &self.staging_output_record_batches.batches)?; + self.join_metrics.output_batches().add(1); + self.join_metrics + .baseline_metrics() + .record_output(record_batch.num_rows()); + // If join filter exists, `self.output_size` is not accurate as we don't know the exact + // number of rows in the output record batch. If streamed row joined with buffered rows, + // once join filter is applied, the number of output rows may be more than 1. + // If `record_batch` is empty, we should reset `self.output_size` to 0. It could be happened + // when the join filter is applied and all rows are filtered out. + if record_batch.num_rows() == 0 || record_batch.num_rows() > self.output_size { + self.output_size = 0; + } else { + self.output_size -= record_batch.num_rows(); + } + + if !(self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark + | JoinType::Full + )) + { + self.staging_output_record_batches.batches.clear(); + } + + Ok(record_batch) + } + + fn filter_joined_batch(&mut self) -> Result { + let record_batch = + concat_batches(&self.schema, &self.staging_output_record_batches.batches)?; + let mut out_indices = self.staging_output_record_batches.row_indices.finish(); + let mut out_mask = self.staging_output_record_batches.filter_mask.finish(); + let mut batch_ids = &self.staging_output_record_batches.batch_ids; + let default_batch_ids = vec![0; record_batch.num_rows()]; + + // If only nulls come in and indices sizes doesn't match with expected record batch count + // generate missing indices + // Happens for null joined batches for Full Join + if out_indices.null_count() == out_indices.len() + && out_indices.len() != record_batch.num_rows() + { + out_mask = BooleanArray::from(vec![None; record_batch.num_rows()]); + out_indices = UInt64Array::from(vec![None; record_batch.num_rows()]); + batch_ids = &default_batch_ids; + } + + if out_mask.is_empty() { + self.staging_output_record_batches.batches.clear(); + return Ok(record_batch); + } + + let maybe_corrected_mask = get_corrected_filter_mask( + self.join_type, + &out_indices, + batch_ids, + &out_mask, + record_batch.num_rows(), + ); + + let corrected_mask = if let Some(ref filtered_join_mask) = maybe_corrected_mask { + filtered_join_mask + } else { + &out_mask + }; + + self.filter_record_batch_by_join_type(record_batch, corrected_mask) + } + + fn filter_record_batch_by_join_type( + &mut self, + record_batch: RecordBatch, + corrected_mask: &BooleanArray, + ) -> Result { + let mut filtered_record_batch = + filter_record_batch(&record_batch, corrected_mask)?; + let left_columns_length = self.streamed_schema.fields.len(); + let right_columns_length = self.buffered_schema.fields.len(); + + if matches!( + self.join_type, + JoinType::Left | JoinType::LeftMark | JoinType::Right | JoinType::RightMark + ) { + let null_mask = compute::not(corrected_mask)?; + let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?; + + let mut right_columns = create_unmatched_columns( + self.join_type, + &self.buffered_schema, + null_joined_batch.num_rows(), + ); + + let columns = if !matches!(self.join_type, JoinType::Right) { + let mut left_columns = null_joined_batch + .columns() + .iter() + .take(right_columns_length) + .cloned() + .collect::>(); + + left_columns.extend(right_columns); + left_columns + } else { + let left_columns = null_joined_batch + .columns() + .iter() + .skip(left_columns_length) + .cloned() + .collect::>(); + + right_columns.extend(left_columns); + right_columns + }; + + // Push the streamed/buffered batch joined nulls to the output + let null_joined_streamed_batch = + RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + + filtered_record_batch = concat_batches( + &self.schema, + &[filtered_record_batch, null_joined_streamed_batch], + )?; + } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { + let output_column_indices = (0..left_columns_length).collect::>(); + filtered_record_batch = + filtered_record_batch.project(&output_column_indices)?; + } else if matches!(self.join_type, JoinType::RightAnti | JoinType::RightSemi) { + let output_column_indices = (0..right_columns_length).collect::>(); + filtered_record_batch = + filtered_record_batch.project(&output_column_indices)?; + } else if matches!(self.join_type, JoinType::Full) + && corrected_mask.false_count() > 0 + { + // Find rows which joined by key but Filter predicate evaluated as false + let joined_filter_not_matched_mask = compute::not(corrected_mask)?; + let joined_filter_not_matched_batch = + filter_record_batch(&record_batch, &joined_filter_not_matched_mask)?; + + // Add left unmatched rows adding the right side as nulls + let right_null_columns = self + .buffered_schema + .fields() + .iter() + .map(|f| { + new_null_array( + f.data_type(), + joined_filter_not_matched_batch.num_rows(), + ) + }) + .collect::>(); + + let mut result_joined = joined_filter_not_matched_batch + .columns() + .iter() + .take(left_columns_length) + .cloned() + .collect::>(); + + result_joined.extend(right_null_columns); + + let left_null_joined_batch = + RecordBatch::try_new(Arc::clone(&self.schema), result_joined)?; + + // Add right unmatched rows adding the left side as nulls + let mut result_joined = self + .streamed_schema + .fields() + .iter() + .map(|f| { + new_null_array( + f.data_type(), + joined_filter_not_matched_batch.num_rows(), + ) + }) + .collect::>(); + + let right_data = joined_filter_not_matched_batch + .columns() + .iter() + .skip(left_columns_length) + .cloned() + .collect::>(); + + result_joined.extend(right_data); + + filtered_record_batch = concat_batches( + &self.schema, + &[filtered_record_batch, left_null_joined_batch], + )?; + } + + self.staging_output_record_batches.clear(); + + Ok(filtered_record_batch) + } +} + +fn create_unmatched_columns( + join_type: JoinType, + schema: &SchemaRef, + size: usize, +) -> Vec { + if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) { + vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef] + } else { + schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), size)) + .collect::>() + } +} + +/// Gets the arrays which join filters are applied on. +fn get_filter_column( + join_filter: &Option, + streamed_columns: &[ArrayRef], + buffered_columns: &[ArrayRef], +) -> Vec { + let mut filter_columns = vec![]; + + if let Some(f) = join_filter { + let left_columns = f + .column_indices() + .iter() + .filter(|col_index| col_index.side == JoinSide::Left) + .map(|i| Arc::clone(&streamed_columns[i.index])) + .collect::>(); + + let right_columns = f + .column_indices() + .iter() + .filter(|col_index| col_index.side == JoinSide::Right) + .map(|i| Arc::clone(&buffered_columns[i.index])) + .collect::>(); + + filter_columns.extend(left_columns); + filter_columns.extend(right_columns); + } + + filter_columns +} + +fn produce_buffered_null_batch( + schema: &SchemaRef, + streamed_schema: &SchemaRef, + buffered_indices: &PrimitiveArray, + buffered_batch: &BufferedBatch, +) -> Result> { + if buffered_indices.is_empty() { + return Ok(None); + } + + // Take buffered (right) columns + let right_columns = + fetch_right_columns_from_batch_by_idxs(buffered_batch, buffered_indices)?; + + // Create null streamed (left) columns + let mut left_columns = streamed_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), buffered_indices.len())) + .collect::>(); + + left_columns.extend(right_columns); + + Ok(Some(RecordBatch::try_new( + Arc::clone(schema), + left_columns, + )?)) +} + +/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` by specific column indices +#[inline(always)] +fn fetch_right_columns_by_idxs( + buffered_data: &BufferedData, + buffered_batch_idx: usize, + buffered_indices: &UInt64Array, +) -> Result> { + fetch_right_columns_from_batch_by_idxs( + &buffered_data.batches[buffered_batch_idx], + buffered_indices, + ) +} + +#[inline(always)] +fn fetch_right_columns_from_batch_by_idxs( + buffered_batch: &BufferedBatch, + buffered_indices: &UInt64Array, +) -> Result> { + match &buffered_batch.batch { + // In memory batch + BufferedBatchState::InMemory(batch) => Ok(batch + .columns() + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::, ArrowError>>() + .map_err(Into::::into)?), + // If the batch was spilled to disk, less likely + BufferedBatchState::Spilled(spill_file) => { + let mut buffered_cols: Vec = + Vec::with_capacity(buffered_indices.len()); + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = StreamReader::try_new(file, None)?; + + for batch in reader { + batch?.columns().iter().for_each(|column| { + buffered_cols.extend(take(column, &buffered_indices, None)) + }); + } + + Ok(buffered_cols) + } + } +} + +/// Buffered data contains all buffered batches with one unique join key +#[derive(Debug, Default)] +pub(super) struct BufferedData { + /// Buffered batches with the same key + pub batches: VecDeque, + /// current scanning batch index used in join_partial() + pub scanning_batch_idx: usize, + /// current scanning offset used in join_partial() + pub scanning_offset: usize, +} + +impl BufferedData { + pub fn head_batch(&self) -> &BufferedBatch { + self.batches.front().unwrap() + } + + pub fn tail_batch(&self) -> &BufferedBatch { + self.batches.back().unwrap() + } + + pub fn tail_batch_mut(&mut self) -> &mut BufferedBatch { + self.batches.back_mut().unwrap() + } + + pub fn has_buffered_rows(&self) -> bool { + self.batches.iter().any(|batch| !batch.range.is_empty()) + } + + pub fn scanning_reset(&mut self) { + self.scanning_batch_idx = 0; + self.scanning_offset = 0; + } + + pub fn scanning_advance(&mut self) { + self.scanning_offset += 1; + while !self.scanning_finished() && self.scanning_batch_finished() { + self.scanning_batch_idx += 1; + self.scanning_offset = 0; + } + } + + pub fn scanning_batch(&self) -> &BufferedBatch { + &self.batches[self.scanning_batch_idx] + } + + pub fn scanning_batch_mut(&mut self) -> &mut BufferedBatch { + &mut self.batches[self.scanning_batch_idx] + } + + pub fn scanning_idx(&self) -> usize { + self.scanning_batch().range.start + self.scanning_offset + } + + pub fn scanning_batch_finished(&self) -> bool { + self.scanning_offset == self.scanning_batch().range.len() + } + + pub fn scanning_finished(&self) -> bool { + self.scanning_batch_idx == self.batches.len() + } + + pub fn scanning_finish(&mut self) { + self.scanning_batch_idx = self.batches.len(); + self.scanning_offset = 0; + } +} + +/// Get join array refs of given batch and join columns +fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec { + on_column + .iter() + .map(|c| { + let num_rows = batch.num_rows(); + let c = c.evaluate(batch).unwrap(); + c.into_array(num_rows).unwrap() + }) + .collect() +} + +/// Get comparison result of two rows of join arrays +fn compare_join_arrays( + left_arrays: &[ArrayRef], + left: usize, + right_arrays: &[ArrayRef], + right: usize, + sort_options: &[SortOptions], + null_equality: NullEquality, +) -> Result { + let mut res = Ordering::Equal; + for ((left_array, right_array), sort_options) in + left_arrays.iter().zip(right_arrays).zip(sort_options) + { + macro_rules! compare_value { + ($T:ty) => {{ + let left_array = left_array.as_any().downcast_ref::<$T>().unwrap(); + let right_array = right_array.as_any().downcast_ref::<$T>().unwrap(); + match (left_array.is_null(left), right_array.is_null(right)) { + (false, false) => { + let left_value = &left_array.value(left); + let right_value = &right_array.value(right); + res = left_value.partial_cmp(right_value).unwrap(); + if sort_options.descending { + res = res.reverse(); + } + } + (true, false) => { + res = if sort_options.nulls_first { + Ordering::Less + } else { + Ordering::Greater + }; + } + (false, true) => { + res = if sort_options.nulls_first { + Ordering::Greater + } else { + Ordering::Less + }; + } + _ => { + res = match null_equality { + NullEquality::NullEqualsNothing => Ordering::Less, + NullEquality::NullEqualsNull => Ordering::Equal, + }; + } + } + }}; + } + + match left_array.data_type() { + DataType::Null => {} + DataType::Boolean => compare_value!(BooleanArray), + DataType::Int8 => compare_value!(Int8Array), + DataType::Int16 => compare_value!(Int16Array), + DataType::Int32 => compare_value!(Int32Array), + DataType::Int64 => compare_value!(Int64Array), + DataType::UInt8 => compare_value!(UInt8Array), + DataType::UInt16 => compare_value!(UInt16Array), + DataType::UInt32 => compare_value!(UInt32Array), + DataType::UInt64 => compare_value!(UInt64Array), + DataType::Float32 => compare_value!(Float32Array), + DataType::Float64 => compare_value!(Float64Array), + DataType::Utf8 => compare_value!(StringArray), + DataType::Utf8View => compare_value!(StringViewArray), + DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Binary => compare_value!(BinaryArray), + DataType::BinaryView => compare_value!(BinaryViewArray), + DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), + DataType::LargeBinary => compare_value!(LargeBinaryArray), + DataType::Decimal32(..) => compare_value!(Decimal32Array), + DataType::Decimal64(..) => compare_value!(Decimal64Array), + DataType::Decimal128(..) => compare_value!(Decimal128Array), + DataType::Timestamp(time_unit, None) => match time_unit { + TimeUnit::Second => compare_value!(TimestampSecondArray), + TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), + TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), + TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), + }, + DataType::Date32 => compare_value!(Date32Array), + DataType::Date64 => compare_value!(Date64Array), + dt => { + return not_impl_err!( + "Unsupported data type in sort merge join comparator: {}", + dt + ); + } + } + if !res.is_eq() { + break; + } + } + Ok(res) +} + +/// A faster version of compare_join_arrays() that only output whether +/// the given two rows are equal +fn is_join_arrays_equal( + left_arrays: &[ArrayRef], + left: usize, + right_arrays: &[ArrayRef], + right: usize, +) -> Result { + let mut is_equal = true; + for (left_array, right_array) in left_arrays.iter().zip(right_arrays) { + macro_rules! compare_value { + ($T:ty) => {{ + match (left_array.is_null(left), right_array.is_null(right)) { + (false, false) => { + let left_array = + left_array.as_any().downcast_ref::<$T>().unwrap(); + let right_array = + right_array.as_any().downcast_ref::<$T>().unwrap(); + if left_array.value(left) != right_array.value(right) { + is_equal = false; + } + } + (true, false) => is_equal = false, + (false, true) => is_equal = false, + _ => {} + } + }}; + } + + match left_array.data_type() { + DataType::Null => {} + DataType::Boolean => compare_value!(BooleanArray), + DataType::Int8 => compare_value!(Int8Array), + DataType::Int16 => compare_value!(Int16Array), + DataType::Int32 => compare_value!(Int32Array), + DataType::Int64 => compare_value!(Int64Array), + DataType::UInt8 => compare_value!(UInt8Array), + DataType::UInt16 => compare_value!(UInt16Array), + DataType::UInt32 => compare_value!(UInt32Array), + DataType::UInt64 => compare_value!(UInt64Array), + DataType::Float32 => compare_value!(Float32Array), + DataType::Float64 => compare_value!(Float64Array), + DataType::Utf8 => compare_value!(StringArray), + DataType::Utf8View => compare_value!(StringViewArray), + DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Binary => compare_value!(BinaryArray), + DataType::BinaryView => compare_value!(BinaryViewArray), + DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), + DataType::LargeBinary => compare_value!(LargeBinaryArray), + DataType::Decimal32(..) => compare_value!(Decimal32Array), + DataType::Decimal64(..) => compare_value!(Decimal64Array), + DataType::Decimal128(..) => compare_value!(Decimal128Array), + DataType::Decimal256(..) => compare_value!(Decimal256Array), + DataType::Timestamp(time_unit, None) => match time_unit { + TimeUnit::Second => compare_value!(TimestampSecondArray), + TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), + TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), + TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), + }, + DataType::Date32 => compare_value!(Date32Array), + DataType::Date64 => compare_value!(Date64Array), + dt => { + return not_impl_err!( + "Unsupported data type in sort merge join comparator: {}", + dt + ); + } + } + if !is_equal { + return Ok(false); + } + } + Ok(true) +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs new file mode 100644 index 0000000000000..83a5c4041cc03 --- /dev/null +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -0,0 +1,2747 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SortMergeJoin Testing Module +//! +//! This module currently contains the following test types in this order: +//! - Join behaviour (left, right, full, inner, semi, anti, mark) +//! - Batch spilling +//! - Filter mask +//! +//! Add relevant tests under the specified sections. + +use std::sync::Arc; + +use arrow::array::{ + builder::{BooleanBuilder, UInt64Builder}, + BinaryArray, BooleanArray, Date32Array, Date64Array, FixedSizeBinaryArray, + Int32Array, RecordBatch, UInt64Array, +}; +use arrow::compute::{concat_batches, filter_record_batch, SortOptions}; +use arrow::datatypes::{DataType, Field, Schema}; + +use datafusion_common::JoinType::*; +use datafusion_common::{ + assert_batches_eq, assert_contains, JoinType, NullEquality, Result, +}; +use datafusion_common::{ + test_util::{batches_to_sort_string, batches_to_string}, + JoinSide, +}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; +use datafusion_execution::runtime_env::RuntimeEnvBuilder; +use datafusion_execution::TaskContext; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::BinaryExpr; +use insta::{allow_duplicates, assert_snapshot}; + +use crate::{ + expressions::Column, + joins::sort_merge_join::stream::{get_corrected_filter_mask, JoinedRecordBatches}, +}; + +use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn}; +use crate::joins::SortMergeJoinExec; +use crate::test::TestMemoryExec; +use crate::test::{build_table_i32, build_table_i32_two_cols}; +use crate::{common, ExecutionPlan}; + +fn build_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), +) -> Arc { + let batch = build_table_i32(a, b, c); + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() +} + +fn build_table_from_batches(batches: Vec) -> Arc { + let schema = batches.first().unwrap().schema(); + TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap() +} + +fn build_date_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), +) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date32, false), + Field::new(b.0, DataType::Date32, false), + Field::new(c.0, DataType::Date32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date32Array::from(a.1.clone())), + Arc::new(Date32Array::from(b.1.clone())), + Arc::new(Date32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() +} + +fn build_date64_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), +) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date64, false), + Field::new(b.0, DataType::Date64, false), + Field::new(c.0, DataType::Date64, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date64Array::from(a.1.clone())), + Arc::new(Date64Array::from(b.1.clone())), + Arc::new(Date64Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() +} + +fn build_binary_table( + a: (&str, &Vec<&[u8]>), + b: (&str, &Vec), + c: (&str, &Vec), +) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Binary, false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(BinaryArray::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() +} + +fn build_fixed_size_binary_table( + a: (&str, &Vec<&[u8]>), + b: (&str, &Vec), + c: (&str, &Vec), +) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::FixedSizeBinary(3), false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(FixedSizeBinaryArray::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() +} + +/// returns a table with 3 columns of i32 in memory +pub fn build_table_i32_nullable( + a: (&str, &Vec>), + b: (&str, &Vec>), + c: (&str, &Vec>), +) -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new(a.0, DataType::Int32, true), + Field::new(b.0, DataType::Int32, true), + Field::new(c.0, DataType::Int32, true), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() +} + +pub fn build_table_two_cols( + a: (&str, &Vec), + b: (&str, &Vec), +) -> Arc { + let batch = build_table_i32_two_cols(a, b); + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() +} + +fn join( + left: Arc, + right: Arc, + on: JoinOn, + join_type: JoinType, +) -> Result { + let sort_options = vec![SortOptions::default(); on.len()]; + SortMergeJoinExec::try_new( + left, + right, + on, + None, + join_type, + sort_options, + NullEquality::NullEqualsNothing, + ) +} + +fn join_with_options( + left: Arc, + right: Arc, + on: JoinOn, + join_type: JoinType, + sort_options: Vec, + null_equality: NullEquality, +) -> Result { + SortMergeJoinExec::try_new( + left, + right, + on, + None, + join_type, + sort_options, + null_equality, + ) +} + +fn join_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: JoinType, + sort_options: Vec, + null_equality: NullEquality, +) -> Result { + SortMergeJoinExec::try_new( + left, + right, + on, + Some(filter), + join_type, + sort_options, + null_equality, + ) +} + +async fn join_collect( + left: Arc, + right: Arc, + on: JoinOn, + join_type: JoinType, +) -> Result<(Vec, Vec)> { + let sort_options = vec![SortOptions::default(); on.len()]; + join_collect_with_options( + left, + right, + on, + join_type, + sort_options, + NullEquality::NullEqualsNothing, + ) + .await +} + +async fn join_collect_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: JoinType, +) -> Result<(Vec, Vec)> { + let sort_options = vec![SortOptions::default(); on.len()]; + + let task_ctx = Arc::new(TaskContext::default()); + let join = join_with_filter( + left, + right, + on, + filter, + join_type, + sort_options, + NullEquality::NullEqualsNothing, + )?; + let columns = columns(&join.schema()); + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) +} + +async fn join_collect_with_options( + left: Arc, + right: Arc, + on: JoinOn, + join_type: JoinType, + sort_options: Vec, + null_equality: NullEquality, +) -> Result<(Vec, Vec)> { + let task_ctx = Arc::new(TaskContext::default()); + let join = + join_with_options(left, right, on, join_type, sort_options, null_equality)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) +} + +async fn join_collect_batch_size_equals_two( + left: Arc, + right: Arc, + on: JoinOn, + join_type: JoinType, +) -> Result<(Vec, Vec)> { + let task_ctx = TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(2)); + let task_ctx = Arc::new(task_ctx); + let join = join(left, right, on, join_type)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) +} + +#[tokio::test] +async fn join_inner_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 5 | 8 | 20 | 5 | 80 | + | 3 | 5 | 9 | 20 | 5 | 80 | + +----+----+----+----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_inner_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b2", &vec![1, 2, 2]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + ), + ]; + + let (_columns, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b2 | c1 | a1 | b2 | c2 | + +----+----+----+----+----+----+ + | 1 | 1 | 7 | 1 | 1 | 70 | + | 2 | 2 | 8 | 2 | 2 | 80 | + | 2 | 2 | 9 | 2 | 2 | 80 | + +----+----+----+----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_inner_two_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 1, 2]), + ("b2", &vec![1, 1, 2]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a1", &vec![1, 1, 3]), + ("b2", &vec![1, 1, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + ), + ]; + + let (_columns, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b2 | c1 | a1 | b2 | c2 | + +----+----+----+----+----+----+ + | 1 | 1 | 7 | 1 | 1 | 70 | + | 1 | 1 | 7 | 1 | 1 | 80 | + | 1 | 1 | 8 | 1 | 1 | 70 | + | 1 | 1 | 8 | 1 | 1 | 80 | + +----+----+----+----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_inner_with_nulls() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]), + ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field + ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]), + ("b2", &vec![None, Some(1), Some(2), Some(2)]), + ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b2 | c1 | a1 | b2 | c2 | + +----+----+----+----+----+----+ + | 1 | 1 | | 1 | 1 | 70 | + | 2 | 2 | 8 | 2 | 2 | 80 | + | 2 | 2 | 9 | 2 | 2 | 80 | + +----+----+----+----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_inner_with_nulls_with_options() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(2), Some(2), Some(1), Some(1)]), + ("b2", &vec![Some(2), Some(2), Some(1), None]), // null in key field + ("c1", &vec![Some(9), Some(8), None, Some(1)]), // null in non-key field + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(3), Some(2), Some(1), Some(1)]), + ("b2", &vec![Some(2), Some(2), Some(1), None]), + ("c2", &vec![Some(90), Some(80), Some(70), Some(10)]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + ), + ]; + let (_, batches) = join_collect_with_options( + left, + right, + on, + Inner, + vec![ + SortOptions { + descending: true, + nulls_first: false, + }; + 2 + ], + NullEquality::NullEqualsNull, + ) + .await?; + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b2 | c1 | a1 | b2 | c2 | + +----+----+----+----+----+----+ + | 2 | 2 | 9 | 2 | 2 | 80 | + | 2 | 2 | 8 | 2 | 2 | 80 | + | 1 | 1 | | 1 | 1 | 70 | + | 1 | | 1 | 1 | | 10 | + +----+----+----+----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_inner_output_two_batches() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b2", &vec![1, 2, 2]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect_batch_size_equals_two(left, right, on, Inner).await?; + assert_eq!(batches.len(), 2); + assert_eq!(batches[0].num_rows(), 2); + assert_eq!(batches[1].num_rows(), 1); + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b2 | c1 | a1 | b2 | c2 | + +----+----+----+----+----+----+ + | 1 | 1 | 7 | 1 | 1 | 70 | + | 2 | 2 | 8 | 2 | 2 | 80 | + | 2 | 2 | 9 | 2 | 2 | 80 | + +----+----+----+----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_left_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Left).await?; + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 5 | 8 | 20 | 5 | 80 | + | 3 | 7 | 9 | | | | + +----+----+----+----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_right_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Right).await?; + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 5 | 8 | 20 | 5 | 80 | + | | | | 30 | 6 | 90 | + +----+----+----+----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_full_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Full).await?; + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_sort_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | | | | 30 | 6 | 90 | + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 5 | 8 | 20 | 5 | 80 | + | 3 | 7 | 9 | | | | + +----+----+----+----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_left_anti() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3, 5]), + ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9, 11]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, LeftAnti).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 3 | 7 | 9 | + | 5 | 7 | 11 | + +----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_right_anti_one_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6])); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+ + | a2 | b1 | + +----+----+ + | 30 | 6 | + +----+----+ + "#); + + let left2 = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right2 = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left2.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right2.schema())?) as _, + )]; + + let (_, batches2) = join_collect(left2, right2, on, RightAnti).await?; + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches2), @r#" + +----+----+----+ + | a2 | b1 | c2 | + +----+----+----+ + | 30 | 6 | 90 | + +----+----+----+ + "#); + + Ok(()) +} + +#[tokio::test] +async fn join_right_anti_two_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6])); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+ + | a2 | b1 | + +----+----+ + | 10 | 4 | + | 20 | 5 | + | 30 | 6 | + +----+----+ + "#); + + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + let expected = [ + "+----+----+----+", + "| a2 | b1 | c2 |", + "+----+----+----+", + "| 10 | 4 | 70 |", + "| 20 | 5 | 80 |", + "| 30 | 6 | 90 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + + Ok(()) +} + +#[tokio::test] +async fn join_right_anti_two_with_filter() -> Result<()> { + let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", &vec![30])); + let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", &vec![20])); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c2", 1)), + Operator::Gt, + Arc::new(Column::new("c1", 0)), + )), + vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ])), + ); + let (_, batches) = + join_collect_with_filter(left, right, on, filter, RightAnti).await?; + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c2 | + +----+----+----+ + | 1 | 10 | 20 | + +----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_right_anti_with_nulls() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]), + ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key field + ("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightAnti).await?; + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c2 | + +----+----+----+ + | 2 | | 8 | + +----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_right_anti_with_nulls_with_options() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(1), Some(0), Some(2)]), + ("b1", &vec![Some(4), Some(5), Some(5), None, Some(5)]), + ("c1", &vec![Some(7), Some(8), Some(8), Some(60), None]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]), + ("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key field + ("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect_with_options( + left, + right, + on, + RightAnti, + vec![ + SortOptions { + descending: true, + nulls_first: false, + }; + 2 + ], + NullEquality::NullEqualsNull, + ) + .await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c2 | + +----+----+----+ + | 3 | | 9 | + | 2 | 5 | | + | 2 | 5 | 8 | + +----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_right_anti_output_two_batches() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 8]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = + join_collect_batch_size_equals_two(left, right, on, LeftAnti).await?; + assert_eq!(batches.len(), 2); + assert_eq!(batches[0].num_rows(), 2); + assert_eq!(batches[1].num_rows(), 1); + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 4 | 7 | + | 2 | 5 | 8 | + | 2 | 5 | 8 | + +----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_left_semi() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 5 is double on the right + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, LeftSemi).await?; + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 4 | 7 | + | 2 | 5 | 8 | + | 2 | 5 | 8 | + +----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_right_semi_one() -> Result<()> { + let left = build_table( + ("a1", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 5, 5, 6]), + ("c1", &vec![70, 80, 90, 100]), + ); + let right = build_table( + ("a2", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), + ("c2", &vec![7, 8, 8, 9]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a2 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) +} + +#[tokio::test] +async fn join_right_semi_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 6]), + ("c1", &vec![70, 80, 90, 100]), + ); + let right = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), + ("c2", &vec![7, 8, 8, 9]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) +} + +#[tokio::test] +async fn join_right_semi_two_with_filter() -> Result<()> { + let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", &vec![30])); + let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", &vec![20])); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c2", 1)), + Operator::Lt, + Arc::new(Column::new("c1", 0)), + )), + vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ])), + ); + let (_, batches) = + join_collect_with_filter(left, right, on, filter, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 10 | 20 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) +} + +#[tokio::test] +async fn join_right_semi_with_nulls() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]), + ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key field + ("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 3 | 6 | |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) +} + +#[tokio::test] +async fn join_right_semi_with_nulls_with_options() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(3), Some(2), Some(1), Some(0), Some(2)]), + ("b1", &vec![None, Some(5), Some(4), None, Some(5)]), + ("c2", &vec![Some(90), Some(80), Some(70), Some(60), None]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]), + ("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key field + ("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect_with_options( + left, + right, + on, + RightSemi, + vec![ + SortOptions { + descending: true, + nulls_first: false, + }; + 2 + ], + NullEquality::NullEqualsNull, + ) + .await?; + + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 3 | | 9 |", + "| 2 | 5 | |", + "| 2 | 5 | 8 |", + "| 1 | 4 | 7 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) +} + +#[tokio::test] +async fn join_right_semi_output_two_batches() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 6]), + ("c1", &vec![70, 80, 90, 100]), + ); + let right = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), + ("c2", &vec![7, 8, 8, 9]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = + join_collect_batch_size_equals_two(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_eq!(batches.len(), 2); + assert_eq!(batches[0].num_rows(), 2); + assert_eq!(batches[1].num_rows(), 1); + assert_batches_eq!(expected, &batches); + Ok(()) +} + +#[tokio::test] +async fn join_left_mark() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 4, 5, 6]), // 5 is double on the right + ("c2", &vec![60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, LeftMark).await?; + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+-------+ + | a1 | b1 | c1 | mark | + +----+----+----+-------+ + | 1 | 4 | 7 | true | + | 2 | 5 | 8 | true | + | 2 | 5 | 8 | true | + | 3 | 7 | 9 | false | + +----+----+----+-------+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_right_mark() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 4, 5, 6]), // 5 is double on the left + ("c2", &vec![60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, RightMark).await?; + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+-------+ + | a2 | b1 | c2 | mark | + +----+----+----+-------+ + | 10 | 4 | 60 | true | + | 20 | 4 | 70 | true | + | 30 | 5 | 80 | true | + | 40 | 6 | 90 | false | + +----+----+----+-------+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_with_duplicated_column_names() -> Result<()> { + let left = build_table( + ("a", &vec![1, 2, 3]), + ("b", &vec![4, 5, 7]), + ("c", &vec![7, 8, 9]), + ); + let right = build_table( + ("a", &vec![10, 20, 30]), + ("b", &vec![1, 2, 7]), + ("c", &vec![70, 80, 90]), + ); + let on = vec![( + // join on a=b so there are duplicate column names on unjoined columns + Arc::new(Column::new_with_schema("a", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +---+---+---+----+---+----+ + | a | b | c | a | b | c | + +---+---+---+----+---+----+ + | 1 | 4 | 7 | 10 | 1 | 70 | + | 2 | 5 | 8 | 20 | 2 | 80 | + +---+---+---+----+---+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_date32() -> Result<()> { + let left = build_date_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![19107, 19108, 19108]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_date_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![19107, 19108, 19109]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +------------+------------+------------+------------+------------+------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +------------+------------+------------+------------+------------+------------+ + | 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 | + | 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 | + | 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 | + +------------+------------+------------+------------+------------+------------+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_date64() -> Result<()> { + let left = build_date64_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_date64_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![1650703441000, 1650503441000, 1650903441000]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 | + | 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + | 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_binary() -> Result<()> { + let left = build_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b1", &vec![5, 10, 15]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b2", &vec![105, 110, 115]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +--------+----+----+--------+-----+----+ + | a1 | b1 | c1 | a1 | b2 | c2 | + +--------+----+----+--------+-----+----+ + | c0ffee | 5 | 7 | c0ffee | 105 | 70 | + | decade | 10 | 8 | decade | 110 | 80 | + | facade | 15 | 9 | facade | 115 | 90 | + +--------+----+----+--------+-----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_fixed_size_binary() -> Result<()> { + let left = build_fixed_size_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b1", &vec![5, 10, 15]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_fixed_size_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b2", &vec![105, 110, 115]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +--------+----+----+--------+-----+----+ + | a1 | b1 | c1 | a1 | b2 | c2 | + +--------+----+----+--------+-----+----+ + | c0ffee | 5 | 7 | c0ffee | 105 | 70 | + | decade | 10 | 8 | decade | 110 | 80 | + | facade | 15 | 9 | facade | 115 | 90 | + +--------+----+----+--------+-----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_left_sort_order() -> Result<()> { + let left = build_table( + ("a1", &vec![0, 1, 2, 3, 4, 5]), + ("b1", &vec![3, 4, 5, 6, 6, 7]), + ("c1", &vec![4, 5, 6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30, 40]), + ("b2", &vec![2, 4, 6, 6, 8]), + ("c2", &vec![50, 60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Left).await?; + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | 0 | 3 | 4 | | | | + | 1 | 4 | 5 | 10 | 4 | 60 | + | 2 | 5 | 6 | | | | + | 3 | 6 | 7 | 20 | 6 | 70 | + | 3 | 6 | 7 | 30 | 6 | 80 | + | 4 | 6 | 8 | 20 | 6 | 70 | + | 4 | 6 | 8 | 30 | 6 | 80 | + | 5 | 7 | 9 | | | | + +----+----+----+----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_right_sort_order() -> Result<()> { + let left = build_table( + ("a1", &vec![0, 1, 2, 3]), + ("b1", &vec![3, 4, 5, 7]), + ("c1", &vec![6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30]), + ("b2", &vec![2, 4, 5, 6]), + ("c2", &vec![60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Right).await?; + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | | | | 0 | 2 | 60 | + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 5 | 8 | 20 | 5 | 80 | + | | | | 30 | 6 | 90 | + +----+----+----+----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_left_multiple_batches() -> Result<()> { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1, 2]), + ("b1", &vec![3, 4, 5]), + ("c1", &vec![4, 5, 6]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![3, 4, 5, 6]), + ("b1", &vec![6, 6, 7, 9]), + ("c1", &vec![7, 8, 9, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10, 20]), + ("b2", &vec![2, 4, 6]), + ("c2", &vec![50, 60, 70]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![30, 40]), + ("b2", &vec![6, 8]), + ("c2", &vec![80, 90]), + ); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); + let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Left).await?; + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | 0 | 3 | 4 | | | | + | 1 | 4 | 5 | 10 | 4 | 60 | + | 2 | 5 | 6 | | | | + | 3 | 6 | 7 | 20 | 6 | 70 | + | 3 | 6 | 7 | 30 | 6 | 80 | + | 4 | 6 | 8 | 20 | 6 | 70 | + | 4 | 6 | 8 | 30 | 6 | 80 | + | 5 | 7 | 9 | | | | + | 6 | 9 | 9 | | | | + +----+----+----+----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_right_multiple_batches() -> Result<()> { + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 1, 2]), + ("b2", &vec![3, 4, 5]), + ("c2", &vec![4, 5, 6]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![3, 4, 5, 6]), + ("b2", &vec![6, 6, 7, 9]), + ("c2", &vec![7, 8, 9, 9]), + ); + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 10, 20]), + ("b1", &vec![2, 4, 6]), + ("c1", &vec![50, 60, 70]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![30, 40]), + ("b1", &vec![6, 8]), + ("c1", &vec![80, 90]), + ); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); + let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Right).await?; + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | | | | 0 | 3 | 4 | + | 10 | 4 | 60 | 1 | 4 | 5 | + | | | | 2 | 5 | 6 | + | 20 | 6 | 70 | 3 | 6 | 7 | + | 30 | 6 | 80 | 3 | 6 | 7 | + | 20 | 6 | 70 | 4 | 6 | 8 | + | 30 | 6 | 80 | 4 | 6 | 8 | + | | | | 5 | 7 | 9 | + | | | | 6 | 9 | 9 | + +----+----+----+----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn join_full_multiple_batches() -> Result<()> { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1, 2]), + ("b1", &vec![3, 4, 5]), + ("c1", &vec![4, 5, 6]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![3, 4, 5, 6]), + ("b1", &vec![6, 6, 7, 9]), + ("c1", &vec![7, 8, 9, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10, 20]), + ("b2", &vec![2, 4, 6]), + ("c2", &vec![50, 60, 70]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![30, 40]), + ("b2", &vec![6, 8]), + ("c2", &vec![80, 90]), + ); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); + let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Full).await?; + assert_snapshot!(batches_to_sort_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | | | | 0 | 2 | 50 | + | | | | 40 | 8 | 90 | + | 0 | 3 | 4 | | | | + | 1 | 4 | 5 | 10 | 4 | 60 | + | 2 | 5 | 6 | | | | + | 3 | 6 | 7 | 20 | 6 | 70 | + | 3 | 6 | 7 | 30 | 6 | 80 | + | 4 | 6 | 8 | 20 | 6 | 70 | + | 4 | 6 | 8 | 30 | 6 | 80 | + | 5 | 7 | 9 | | | | + | 6 | 9 | 9 | | | | + +----+----+----+----+----+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn overallocation_single_batch_no_spill() -> Result<()> { + let left = build_table( + ("a1", &vec![0, 1, 2, 3, 4, 5]), + ("b1", &vec![1, 2, 3, 4, 5, 6]), + ("c1", &vec![4, 5, 6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30, 40]), + ("b2", &vec![1, 3, 4, 6, 8]), + ("c2", &vec![50, 60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = vec![ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark, + ]; + + // Disable DiskManager to prevent spilling + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled), + ) + .build_arc()?; + let session_config = SessionConfig::default().with_batch_size(50); + + for join_type in join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let err = common::collect(stream).await.unwrap_err(); + + assert_contains!(err.to_string(), "Failed to allocate additional"); + assert_contains!(err.to_string(), "SMJStream[0]"); + assert_contains!(err.to_string(), "Disk spilling disabled"); + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + } + + Ok(()) +} + +#[tokio::test] +async fn overallocation_multi_batch_no_spill() -> Result<()> { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![2, 3]), + ("b1", &vec![1, 1]), + ("c1", &vec![6, 7]), + ); + let left_batch_3 = build_table_i32( + ("a1", &vec![4, 5]), + ("b1", &vec![1, 1]), + ("c1", &vec![8, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10]), + ("b2", &vec![1, 1]), + ("c2", &vec![50, 60]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![20, 30]), + ("b2", &vec![1, 1]), + ("c2", &vec![70, 80]), + ); + let right_batch_3 = + build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90])); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); + let right = + build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = vec![ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark, + ]; + + // Disable DiskManager to prevent spilling + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled), + ) + .build_arc()?; + let session_config = SessionConfig::default().with_batch_size(50); + + for join_type in join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let err = common::collect(stream).await.unwrap_err(); + + assert_contains!(err.to_string(), "Failed to allocate additional"); + assert_contains!(err.to_string(), "SMJStream[0]"); + assert_contains!(err.to_string(), "Disk spilling disabled"); + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + } + + Ok(()) +} + +#[tokio::test] +async fn overallocation_single_batch_spill() -> Result<()> { + let left = build_table( + ("a1", &vec![0, 1, 2, 3, 4, 5]), + ("b1", &vec![1, 2, 3, 4, 5, 6]), + ("c1", &vec![4, 5, 6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30, 40]), + ("b2", &vec![1, 3, 4, 6, 8]), + ("c2", &vec![50, 60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = [ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark, + ]; + + // Enable DiskManager to allow spilling + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Run the test with no spill configuration as + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + // Compare spilled and non spilled data to check spill logic doesn't corrupt the data + assert_eq!(spilled_join_result, no_spilled_join_result); + } + } + + Ok(()) +} + +#[tokio::test] +async fn overallocation_multi_batch_spill() -> Result<()> { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![2, 3]), + ("b1", &vec![1, 1]), + ("c1", &vec![6, 7]), + ); + let left_batch_3 = build_table_i32( + ("a1", &vec![4, 5]), + ("b1", &vec![1, 1]), + ("c1", &vec![8, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10]), + ("b2", &vec![1, 1]), + ("c2", &vec![50, 60]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![20, 30]), + ("b2", &vec![1, 1]), + ("c2", &vec![70, 80]), + ); + let right_batch_3 = + build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90])); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); + let right = + build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = [ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark, + ]; + + // Enable DiskManager to allow spilling + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(500, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let spilled_join_result = common::collect(stream).await.unwrap(); + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Run the test with no spill configuration as + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + // Compare spilled and non spilled data to check spill logic doesn't corrupt the data + assert_eq!(spilled_join_result, no_spilled_join_result); + } + } + + Ok(()) +} + +fn build_joined_record_batches() -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + ])); + + let mut batches = JoinedRecordBatches { + batches: vec![], + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + }; + + // Insert already prejoined non-filtered rows + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![10, 10])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 9])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![11])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 12])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 13])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![13])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![14, 14])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 11])), + ], + )?); + + let streamed_indices = vec![0, 0]; + batches.batch_ids.extend(vec![0; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![1]; + batches.batch_ids.extend(vec![0; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + batches.batch_ids.extend(vec![1; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0]; + batches.batch_ids.extend(vec![2; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + batches.batch_ids.extend(vec![3; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + batches + .filter_mask + .extend(&BooleanArray::from(vec![true, false])); + batches.filter_mask.extend(&BooleanArray::from(vec![true])); + batches + .filter_mask + .extend(&BooleanArray::from(vec![false, true])); + batches.filter_mask.extend(&BooleanArray::from(vec![false])); + batches + .filter_mask + .extend(&BooleanArray::from(vec![false, false])); + + Ok(batches) +} + +#[tokio::test] +async fn test_left_outer_join_filtered_mask() -> Result<()> { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true, false, false, false, false, false, false, false]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![false, false, false, false, false, false, false, false]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true, true, false, false, false, false, false, false]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true, true, true, false, false, false, false, false]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + None, + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + Some(true), + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + None, + Some(false), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + let corrected_mask = get_corrected_filter_mask( + Left, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + Some(false), + None, + Some(false) + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_snapshot!(batches_to_string(&[filtered_rb]), @r#" + +---+----+---+----+ + | a | b | x | y | + +---+----+---+----+ + | 1 | 10 | 1 | 11 | + | 1 | 11 | 1 | 12 | + | 1 | 12 | 1 | 13 | + +---+----+---+----+ + "#); + + // output null rows + + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + Some(true), + None, + Some(true) + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_snapshot!(batches_to_string(&[null_joined_batch]), @r#" + +---+----+---+----+ + | a | b | x | y | + +---+----+---+----+ + | 1 | 13 | 1 | 12 | + | 1 | 14 | 1 | 11 | + +---+----+---+----+ + "#); + Ok(()) +} + +#[tokio::test] +async fn test_semi_join_filtered_mask() -> Result<()> { + for join_type in [LeftSemi, RightSemi] { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true),]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, Some(true), None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + let corrected_mask = get_corrected_filter_mask( + join_type, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + None, + None, + None + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 10 | 1 | 11 |", + "| 1 | 11 | 1 | 12 |", + "| 1 | 12 | 1 | 13 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); + + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + None, + None, + None + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); + } + Ok(()) +} + +#[tokio::test] +async fn test_anti_join_filtered_mask() -> Result<()> { + for join_type in [LeftAnti, RightAnti] { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + 1 + ) + .unwrap(), + BooleanArray::from(vec![None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + 1 + ) + .unwrap(), + BooleanArray::from(vec![Some(true)]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + 2 + ) + .unwrap(), + BooleanArray::from(vec![None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true)]) + ); + + let corrected_mask = get_corrected_filter_mask( + join_type, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + None, + None, + None, + None, + None, + Some(true), + None, + Some(true) + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + allow_duplicates! { + assert_snapshot!(batches_to_string(&[filtered_rb]), @r#" + +---+----+---+----+ + | a | b | x | y | + +---+----+---+----+ + | 1 | 13 | 1 | 12 | + | 1 | 14 | 1 | 11 | + +---+----+---+----+ + "#); + } + + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + None, + None, + None, + None, + None, + Some(false), + None, + Some(false), + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + allow_duplicates! { + assert_snapshot!(batches_to_string(&[null_joined_batch]), @r#" + +---+---+---+---+ + | a | b | x | y | + +---+---+---+---+ + +---+---+---+---+ + "#); + } + } + + Ok(()) +} + +/// Returns the column names on the schema +fn columns(schema: &Schema) -> Vec { + schema.fields().iter().map(|f| f.name().clone()).collect() +} diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 677601a12845f..9f5485ee93bde 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -22,8 +22,12 @@ use std::collections::{HashMap, VecDeque}; use std::mem::size_of; use std::sync::Arc; +use crate::joins::join_hash_map::{ + get_matched_indices, get_matched_indices_with_limit_offset, update_from_iter, + JoinHashMapOffset, +}; use crate::joins::utils::{JoinFilter, JoinHashMapType}; -use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; +use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder}; use crate::{metrics, ExecutionPlan}; use arrow::array::{ @@ -47,26 +51,49 @@ use hashbrown::HashTable; /// Implementation of `JoinHashMapType` for `PruningJoinHashMap`. impl JoinHashMapType for PruningJoinHashMap { - type NextType = VecDeque; - // Extend with zero fn extend_zero(&mut self, len: usize) { self.next.resize(self.next.len() + len, 0) } - /// Get mutable references to the hash map and the next. - fn get_mut(&mut self) -> (&mut HashTable<(u64, u64)>, &mut Self::NextType) { - (&mut self.map, &mut self.next) + fn update_from_iter<'a>( + &mut self, + iter: Box + Send + 'a>, + deleted_offset: usize, + ) { + let slice: &mut [u64] = self.next.make_contiguous(); + update_from_iter::(&mut self.map, slice, iter, deleted_offset); } - /// Get a reference to the hash map. - fn get_map(&self) -> &HashTable<(u64, u64)> { - &self.map + fn get_matched_indices<'a>( + &self, + iter: Box + 'a>, + deleted_offset: Option, + ) -> (Vec, Vec) { + // Flatten the deque + let next: Vec = self.next.iter().copied().collect(); + get_matched_indices::(&self.map, &next, iter, deleted_offset) } - /// Get a reference to the next. - fn get_list(&self) -> &Self::NextType { - &self.next + fn get_matched_indices_with_limit_offset( + &self, + hash_values: &[u64], + limit: usize, + offset: JoinHashMapOffset, + ) -> (Vec, Vec, Option) { + // Flatten the deque + let next: Vec = self.next.iter().copied().collect(); + get_matched_indices_with_limit_offset::( + &self.map, + &next, + hash_values, + limit, + offset, + ) + } + + fn is_empty(&self) -> bool { + self.map.is_empty() } } @@ -659,7 +686,7 @@ pub struct StreamJoinMetrics { /// Number of batches produced by this operator pub(crate) output_batches: metrics::Count, /// Number of rows produced by this operator - pub(crate) output_rows: metrics::Count, + pub(crate) baseline_metrics: BaselineMetrics, } impl StreamJoinMetrics { @@ -686,14 +713,12 @@ impl StreamJoinMetrics { let output_batches = MetricBuilder::new(metrics).counter("output_batches", partition); - let output_rows = MetricBuilder::new(metrics).output_rows(partition); - Self { left, right, output_batches, stream_memory_usage, - output_rows, + baseline_metrics: BaselineMetrics::new(metrics, partition), } } } diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 0dcb42169e00a..b55b7e15f194c 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -34,7 +34,6 @@ use std::vec; use crate::common::SharedMemoryReservation; use crate::execution_plan::{boundedness_from_children, emission_type_from_children}; -use crate::joins::hash_join::{equal_rows_arr, update_hash}; use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, convert_sort_expr_with_filter_schema, get_pruning_anti_indices, @@ -43,9 +42,9 @@ use crate::joins::stream_join_utils::{ }; use crate::joins::utils::{ apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, symmetric_join_output_partitioning, BatchSplitter, - BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn, JoinOnRef, - NoopBatchTransformer, StatefulStreamResult, + check_join_is_valid, equal_rows_arr, symmetric_join_output_partitioning, update_hash, + BatchSplitter, BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn, + JoinOnRef, NoopBatchTransformer, StatefulStreamResult, }; use crate::projection::{ join_allows_pushdown, join_table_borders, new_join_children, @@ -67,15 +66,16 @@ use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::bisect; -use datafusion_common::{internal_err, plan_err, HashSet, JoinSide, JoinType, Result}; +use datafusion_common::{ + internal_err, plan_err, HashSet, JoinSide, JoinType, NullEquality, Result, +}; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; -use datafusion_physical_expr::PhysicalExprRef; -use datafusion_physical_expr_common::physical_expr::fmt_sql; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; use ahash::RandomState; use futures::{ready, Stream, StreamExt}; @@ -186,8 +186,8 @@ pub struct SymmetricHashJoinExec { metrics: ExecutionPlanMetricsSet, /// Information of index and left / right placement of columns column_indices: Vec, - /// If null_equals_null is true, null == null else null != null - pub(crate) null_equals_null: bool, + /// Defines the null equality for the join. + pub(crate) null_equality: NullEquality, /// Left side sort expression(s) pub(crate) left_sort_exprs: Option, /// Right side sort expression(s) @@ -212,7 +212,7 @@ impl SymmetricHashJoinExec { on: JoinOn, filter: Option, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, left_sort_exprs: Option, right_sort_exprs: Option, mode: StreamJoinPartitionMode, @@ -237,8 +237,7 @@ impl SymmetricHashJoinExec { // Initialize the random state for the join operation: let random_state = RandomState::with_seeds(0, 0, 0, 0); let schema = Arc::new(schema); - let cache = - Self::compute_properties(&left, &right, Arc::clone(&schema), *join_type, &on); + let cache = Self::compute_properties(&left, &right, schema, *join_type, &on)?; Ok(SymmetricHashJoinExec { left, right, @@ -248,7 +247,7 @@ impl SymmetricHashJoinExec { random_state, metrics: ExecutionPlanMetricsSet::new(), column_indices, - null_equals_null, + null_equality, left_sort_exprs, right_sort_exprs, mode, @@ -263,7 +262,7 @@ impl SymmetricHashJoinExec { schema: SchemaRef, join_type: JoinType, join_on: JoinOnRef, - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties: let eq_properties = join_equivalence_properties( left.equivalence_properties().clone(), @@ -274,17 +273,17 @@ impl SymmetricHashJoinExec { // Has alternating probe side None, join_on, - ); + )?; let output_partitioning = - symmetric_join_output_partitioning(left, right, &join_type); + symmetric_join_output_partitioning(left, right, &join_type)?; - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, emission_type_from_children([left, right]), boundedness_from_children([left, right]), - ) + )) } /// left stream @@ -312,9 +311,9 @@ impl SymmetricHashJoinExec { &self.join_type } - /// Get null_equals_null - pub fn null_equals_null(&self) -> bool { - self.null_equals_null + /// Get null_equality + pub fn null_equality(&self) -> NullEquality { + self.null_equality } /// Get partition mode @@ -372,7 +371,7 @@ impl DisplayAs for SymmetricHashJoinExec { let on = self .on .iter() - .map(|(c1, c2)| format!("({}, {})", c1, c2)) + .map(|(c1, c2)| format!("({c1}, {c2})")) .collect::>() .join(", "); write!( @@ -395,7 +394,7 @@ impl DisplayAs for SymmetricHashJoinExec { if *self.join_type() != JoinType::Inner { writeln!(f, "join_type={:?}", self.join_type)?; } - writeln!(f, "on={}", on) + writeln!(f, "on={on}") } } } @@ -433,16 +432,14 @@ impl ExecutionPlan for SymmetricHashJoinExec { } } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { vec![ self.left_sort_exprs .as_ref() - .cloned() - .map(LexRequirement::from), + .map(|e| OrderingRequirements::from(e.clone())), self.right_sort_exprs .as_ref() - .cloned() - .map(LexRequirement::from), + .map(|e| OrderingRequirements::from(e.clone())), ] } @@ -460,7 +457,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { self.on.clone(), self.filter.clone(), &self.join_type, - self.null_equals_null, + self.null_equality, self.left_sort_exprs.clone(), self.right_sort_exprs.clone(), self.mode, @@ -549,7 +546,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { graph, left_sorted_filter_expr, right_sorted_filter_expr, - null_equals_null: self.null_equals_null, + null_equality: self.null_equality, state: SHJStreamState::PullRight, reservation, batch_transformer: BatchSplitter::new(batch_size), @@ -569,7 +566,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { graph, left_sorted_filter_expr, right_sorted_filter_expr, - null_equals_null: self.null_equals_null, + null_equality: self.null_equality, state: SHJStreamState::PullRight, reservation, batch_transformer: NoopBatchTransformer::new(), @@ -635,21 +632,18 @@ impl ExecutionPlan for SymmetricHashJoinExec { self.right(), )?; - Ok(Some(Arc::new(SymmetricHashJoinExec::try_new( + SymmetricHashJoinExec::try_new( Arc::new(new_left), Arc::new(new_right), new_on, new_filter, self.join_type(), - self.null_equals_null(), - self.right() - .output_ordering() - .map(|p| LexOrdering::new(p.to_vec())), - self.left() - .output_ordering() - .map(|p| LexOrdering::new(p.to_vec())), + self.null_equality(), + self.right().output_ordering().cloned(), + self.left().output_ordering().cloned(), self.partition_mode(), - )?))) + ) + .map(|e| Some(Arc::new(e) as _)) } } @@ -678,8 +672,8 @@ struct SymmetricHashJoinStream { right_sorted_filter_expr: Option, /// Random state used for hashing initialization random_state: RandomState, - /// If null_equals_null is true, null == null else null != null - null_equals_null: bool, + /// Defines the null equality for the join. + null_equality: NullEquality, /// Metrics metrics: StreamJoinMetrics, /// Memory reservation @@ -777,7 +771,11 @@ fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> } else { matches!( join_type, - JoinType::Right | JoinType::RightAnti | JoinType::Full | JoinType::RightSemi + JoinType::Right + | JoinType::RightAnti + | JoinType::Full + | JoinType::RightSemi + | JoinType::RightMark ) } } @@ -811,6 +809,21 @@ where { // Store the result in a tuple let result = match (build_side, join_type) { + // For a mark join we “mark” each build‐side row with a dummy 0 in the probe‐side index + // if it ever matched. For example, if + // + // prune_length = 5 + // deleted_offset = 0 + // visited_rows = {1, 3} + // + // then we produce: + // + // build_indices = [0, 1, 2, 3, 4] + // probe_indices = [None, Some(0), None, Some(0), None] + // + // Example: for each build row i in [0..5): + // – We always output its own index i in `build_indices` + // – We output `Some(0)` in `probe_indices[i]` if row i was ever visited, else `None` (JoinSide::Left, JoinType::LeftMark) => { let build_indices = (0..prune_length) .map(L::Native::from_usize) @@ -825,6 +838,20 @@ where .collect(); (build_indices, probe_indices) } + (JoinSide::Right, JoinType::RightMark) => { + let build_indices = (0..prune_length) + .map(L::Native::from_usize) + .collect::>(); + let probe_indices = (0..prune_length) + .map(|idx| { + // For mark join we output a dummy index 0 to indicate the row had a match + visited_rows + .contains(&(idx + deleted_offset)) + .then_some(R::Native::from_usize(0).unwrap()) + }) + .collect(); + (build_indices, probe_indices) + } // In the case of `Left` or `Right` join, or `Full` join, get the anti indices (JoinSide::Left, JoinType::Left | JoinType::LeftAnti) | (JoinSide::Right, JoinType::Right | JoinType::RightAnti) @@ -923,7 +950,7 @@ pub(crate) fn build_side_determined_results( /// * `probe_batch` - The second record batch to be joined. /// * `column_indices` - An array of columns to be selected for the result of the join. /// * `random_state` - The random state for the join. -/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. +/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining. /// /// # Returns /// @@ -939,7 +966,7 @@ pub(crate) fn join_with_probe_batch( probe_batch: &RecordBatch, column_indices: &[ColumnIndex], random_state: &RandomState, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result> { if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { return Ok(None); @@ -951,7 +978,7 @@ pub(crate) fn join_with_probe_batch( &build_hash_joiner.on, &probe_hash_joiner.on, random_state, - null_equals_null, + null_equality, &mut build_hash_joiner.hashes_buffer, Some(build_hash_joiner.deleted_offset), )?; @@ -964,6 +991,7 @@ pub(crate) fn join_with_probe_batch( probe_indices, filter, build_hash_joiner.build_side, + None, )? } else { (build_indices, probe_indices) @@ -990,6 +1018,7 @@ pub(crate) fn join_with_probe_batch( | JoinType::LeftSemi | JoinType::LeftMark | JoinType::RightSemi + | JoinType::RightMark ) { Ok(None) } else { @@ -1017,7 +1046,7 @@ pub(crate) fn join_with_probe_batch( /// * `build_on` - An array of columns on which the join will be performed. The columns are from the build side of the join. /// * `probe_on` - An array of columns on which the join will be performed. The columns are from the probe side of the join. /// * `random_state` - The random state for the join. -/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. +/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining. /// * `hashes_buffer` - Buffer used for probe side keys hash calculation. /// * `deleted_offset` - deleted offset for build side data. /// @@ -1033,7 +1062,7 @@ fn lookup_join_hashmap( build_on: &[PhysicalExprRef], probe_on: &[PhysicalExprRef], random_state: &RandomState, - null_equals_null: bool, + null_equality: NullEquality, hashes_buffer: &mut Vec, deleted_offset: Option, ) -> Result<(UInt64Array, UInt32Array)> { @@ -1080,8 +1109,10 @@ fn lookup_join_hashmap( // (5,1) // // With this approach, the lexicographic order on both the probe side and the build side is preserved. - let (mut matched_probe, mut matched_build) = build_hashmap - .get_matched_indices(hash_values.iter().enumerate().rev(), deleted_offset); + let (mut matched_probe, mut matched_build) = build_hashmap.get_matched_indices( + Box::new(hash_values.iter().enumerate().rev()), + deleted_offset, + ); matched_probe.reverse(); matched_build.reverse(); @@ -1094,7 +1125,7 @@ fn lookup_join_hashmap( &probe_indices, &build_join_values, &keys_values, - null_equals_null, + null_equality, )?; Ok((build_indices, probe_indices)) @@ -1347,8 +1378,10 @@ impl SymmetricHashJoinStream { } Some((batch, _)) => { self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - return Poll::Ready(Some(Ok(batch))); + return self + .metrics + .baseline_metrics + .record_poll(Poll::Ready(Some(Ok(batch)))); } } } @@ -1591,7 +1624,7 @@ impl SymmetricHashJoinStream { size += size_of_val(&self.left_sorted_filter_expr); size += size_of_val(&self.right_sorted_filter_expr); size += size_of_val(&self.random_state); - size += size_of_val(&self.null_equals_null); + size += size_of_val(&self.null_equality); size += size_of_val(&self.metrics); size } @@ -1646,7 +1679,7 @@ impl SymmetricHashJoinStream { &probe_batch, &self.column_indices, &self.random_state, - self.null_equals_null, + self.null_equality, )?; // Increment the offset for the probe hash joiner: probe_hash_joiner.offset += probe_batch.num_rows(); @@ -1743,7 +1776,7 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{binary, col, lit, Column}; - use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use rstest::*; @@ -1802,12 +1835,18 @@ mod tests { on.clone(), filter.clone(), &join_type, - false, + NullEquality::NullEqualsNothing, Arc::clone(&task_ctx), ) .await?; let second_batches = partitioned_hash_join_with_filter( - left, right, on, filter, &join_type, false, task_ctx, + left, + right, + on, + filter, + &join_type, + NullEquality::NullEqualsNothing, + task_ctx, ) .await?; compare_batches(&first_batches, &second_batches); @@ -1826,6 +1865,7 @@ mod tests { JoinType::LeftAnti, JoinType::LeftMark, JoinType::RightAnti, + JoinType::RightMark, JoinType::Full )] join_type: JoinType, @@ -1843,7 +1883,7 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: binary( col("la1", left_schema)?, Operator::Plus, @@ -1851,11 +1891,13 @@ mod tests { left_schema, )?, options: SortOptions::default(), - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1912,6 +1954,7 @@ mod tests { JoinType::LeftAnti, JoinType::LeftMark, JoinType::RightAnti, + JoinType::RightMark, JoinType::Full )] join_type: JoinType, @@ -1923,14 +1966,16 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1978,6 +2023,7 @@ mod tests { JoinType::LeftAnti, JoinType::LeftMark, JoinType::RightAnti, + JoinType::RightMark, JoinType::Full )] join_type: JoinType, @@ -2031,6 +2077,7 @@ mod tests { JoinType::LeftAnti, JoinType::LeftMark, JoinType::RightAnti, + JoinType::RightMark, JoinType::Full )] join_type: JoinType, @@ -2059,6 +2106,7 @@ mod tests { JoinType::LeftAnti, JoinType::LeftMark, JoinType::RightAnti, + JoinType::RightMark, JoinType::Full )] join_type: JoinType, @@ -2068,20 +2116,22 @@ mod tests { let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?; let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("la1_des", left_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("ra1_des", right_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2127,20 +2177,22 @@ mod tests { let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?; let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("l_asc_null_first", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("r_asc_null_first", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2186,20 +2238,22 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("l_asc_null_last", left_schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("r_asc_null_last", right_schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2247,20 +2301,22 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("l_desc_null_first", left_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("r_desc_null_first", right_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2309,15 +2365,16 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }]); - - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2368,20 +2425,23 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); let left_sorted = vec![ - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }]), - LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(), + [PhysicalSortExpr { expr: col("la2", left_schema)?, options: SortOptions::default(), - }]), + }] + .into(), ]; - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let right_sorted = [PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, @@ -2431,6 +2491,7 @@ mod tests { JoinType::LeftAnti, JoinType::LeftMark, JoinType::RightAnti, + JoinType::RightMark, JoinType::Full )] join_type: JoinType, @@ -2449,20 +2510,22 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("lt1", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("rt1", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2515,6 +2578,7 @@ mod tests { JoinType::LeftAnti, JoinType::LeftMark, JoinType::RightAnti, + JoinType::RightMark, JoinType::Full )] join_type: JoinType, @@ -2532,20 +2596,22 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("li1", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("ri1", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2591,6 +2657,7 @@ mod tests { JoinType::LeftAnti, JoinType::LeftMark, JoinType::RightAnti, + JoinType::RightMark, JoinType::Full )] join_type: JoinType, @@ -2608,14 +2675,16 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("l_float", left_schema)?, options: SortOptions::default(), - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("r_float", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index e70007aa651f7..de288724c446e 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -33,7 +33,7 @@ use arrow::array::{ }; use arrow::datatypes::{DataType, Schema}; use arrow::util::pretty::pretty_format_batches; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{NullEquality, Result, ScalarValue}; use datafusion_execution::TaskContext; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{binary, cast, col, lit}; @@ -74,7 +74,7 @@ pub async fn partitioned_sym_join_with_filter( on: JoinOn, filter: Option, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, context: Arc, ) -> Result> { let partition_count = 4; @@ -101,11 +101,9 @@ pub async fn partitioned_sym_join_with_filter( on, filter, join_type, - null_equals_null, - left.output_ordering().map(|p| LexOrdering::new(p.to_vec())), - right - .output_ordering() - .map(|p| LexOrdering::new(p.to_vec())), + null_equality, + left.output_ordering().cloned(), + right.output_ordering().cloned(), StreamJoinPartitionMode::Partitioned, )?; @@ -130,7 +128,7 @@ pub async fn partitioned_hash_join_with_filter( on: JoinOn, filter: Option, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, context: Arc, ) -> Result> { let partition_count = 4; @@ -153,7 +151,7 @@ pub async fn partitioned_hash_join_with_filter( join_type, None, PartitionMode::Partitioned, - null_equals_null, + null_equality, )?); let mut batches = vec![]; @@ -195,7 +193,7 @@ struct AscendingRandomFloatIterator { impl AscendingRandomFloatIterator { fn new(min: f64, max: f64) -> Self { let mut rng = StdRng::seed_from_u64(42); - let initial = rng.gen_range(min..max); + let initial = rng.random_range(min..max); AscendingRandomFloatIterator { prev: initial, max, @@ -208,7 +206,7 @@ impl Iterator for AscendingRandomFloatIterator { type Item = f64; fn next(&mut self) -> Option { - let value = self.rng.gen_range(self.prev..self.max); + let value = self.rng.random_range(self.prev..self.max); self.prev = value; Some(value) } @@ -419,12 +417,14 @@ pub fn build_sides_record_batches( key_cardinality: (i32, i32), ) -> Result<(RecordBatch, RecordBatch)> { let null_ratio: f64 = 0.4; + let duplicate_ratio = 0.4; let initial_range = 0..table_size; let index = (table_size as f64 * null_ratio).round() as i32; let rest_of = index..table_size; let ordered: ArrayRef = Arc::new(Int32Array::from_iter( initial_range.clone().collect::>(), )); + let random_ordered = generate_ordered_array(table_size, duplicate_ratio); let ordered_des = Arc::new(Int32Array::from_iter( initial_range.clone().rev().collect::>(), )); @@ -444,8 +444,7 @@ pub fn build_sides_record_batches( .collect::>(), )); let ordered_asc_null_first = Arc::new(Int32Array::from_iter({ - std::iter::repeat(None) - .take(index as usize) + std::iter::repeat_n(None, index as usize) .chain(rest_of.clone().map(Some)) .collect::>>() })); @@ -453,13 +452,12 @@ pub fn build_sides_record_batches( rest_of .clone() .map(Some) - .chain(std::iter::repeat(None).take(index as usize)) + .chain(std::iter::repeat_n(None, index as usize)) .collect::>>() })); let ordered_desc_null_first = Arc::new(Int32Array::from_iter({ - std::iter::repeat(None) - .take(index as usize) + std::iter::repeat_n(None, index as usize) .chain(rest_of.rev().map(Some)) .collect::>>() })); @@ -505,6 +503,7 @@ pub fn build_sides_record_batches( ), ("li1", Arc::clone(&interval_time)), ("l_float", Arc::clone(&float_asc) as ArrayRef), + ("l_random_ordered", Arc::clone(&random_ordered) as ArrayRef), ])?; let right = RecordBatch::try_from_iter(vec![ ("ra1", Arc::clone(&ordered)), @@ -518,6 +517,7 @@ pub fn build_sides_record_batches( ("r_desc_null_first", ordered_desc_null_first), ("ri1", interval_time), ("r_float", float_asc), + ("r_random_ordered", random_ordered), ])?; Ok((left, right)) } @@ -587,3 +587,24 @@ pub(crate) fn complicated_filter( )?; binary(left_expr, Operator::And, right_expr, filter_schema) } + +fn generate_ordered_array(size: i32, duplicate_ratio: f32) -> Arc { + let mut rng = StdRng::seed_from_u64(42); + let unique_count = (size as f32 * (1.0 - duplicate_ratio)) as i32; + + // Generate unique random values + let mut values: Vec = (0..unique_count) + .map(|_| rng.random_range(1..500)) // Modify as per your range + .collect(); + + // Duplicate the values according to the duplicate ratio + for _ in 0..(size - unique_count) { + let index = rng.random_range(0..unique_count); + values.push(values[index as usize]); + } + + // Sort the values to ensure they are ordered + values.sort(); + + Arc::new(Int32Array::from_iter(values)) +} diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index f6c720dbb707a..c50bfce93a2d5 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -17,6 +17,7 @@ //! Join related functionality used both on logical and physical plans +use std::cmp::min; use std::collections::HashSet; use std::fmt::{self, Debug}; use std::future::Future; @@ -25,48 +26,52 @@ use std::ops::Range; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; +use crate::joins::SharedBitmapBuilder; +use crate::metrics::{self, BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder}; +use crate::projection::{ProjectionExec, ProjectionExpr}; use crate::{ ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics, }; // compatibility pub use super::join_filter::JoinFilter; -pub use super::join_hash_map::{JoinHashMap, JoinHashMapType}; +pub use super::join_hash_map::JoinHashMapType; +pub use crate::joins::{JoinOn, JoinOnRef}; +use ahash::RandomState; use arrow::array::{ builder::UInt64Builder, downcast_array, new_null_array, Array, ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, RecordBatchOptions, UInt32Array, UInt32Builder, UInt64Array, }; -use arrow::compute; +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::compute::kernels::cmp::eq; +use arrow::compute::{self, and, take, FilterBuilder}; use arrow::datatypes::{ ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type, }; +use arrow_ord::cmp::not_distinct; +use arrow_schema::ArrowError; use datafusion_common::cast::as_boolean_array; +use datafusion_common::hash_utils::create_hashes; use datafusion_common::stats::Precision; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ - plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult, + plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result, SharedResult, }; use datafusion_expr::interval_arithmetic::Interval; -use datafusion_physical_expr::equivalence::add_offset_to_expr; +use datafusion_expr::Operator; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::utils::{collect_columns, merge_vectors}; +use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - LexOrdering, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, + add_offset_to_expr, add_offset_to_physical_sort_exprs, LexOrdering, PhysicalExpr, + PhysicalExprRef, }; -use crate::joins::SharedBitmapBuilder; -use crate::projection::ProjectionExec; +use datafusion_physical_expr_common::datum::compare_op_for_nested; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; use parking_lot::Mutex; -/// The on clause of the join, as vector of (left, right) columns. -pub type JoinOn = Vec<(PhysicalExprRef, PhysicalExprRef)>; -/// Reference for JoinOn. -pub type JoinOnRef<'a> = &'a [(PhysicalExprRef, PhysicalExprRef)]; - /// Checks whether the schemas "left" and "right" and columns "on" represent a valid join. /// They are valid whenever their columns' intersection equals the set `on` pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> { @@ -118,113 +123,84 @@ fn check_join_set_is_valid( pub fn adjust_right_output_partitioning( right_partitioning: &Partitioning, left_columns_len: usize, -) -> Partitioning { - match right_partitioning { +) -> Result { + let result = match right_partitioning { Partitioning::Hash(exprs, size) => { let new_exprs = exprs .iter() - .map(|expr| add_offset_to_expr(Arc::clone(expr), left_columns_len)) - .collect(); + .map(|expr| add_offset_to_expr(Arc::clone(expr), left_columns_len as _)) + .collect::>()?; Partitioning::Hash(new_exprs, *size) } result => result.clone(), - } -} - -/// Replaces the right column (first index in the `on_column` tuple) with -/// the left column (zeroth index in the tuple) inside `right_ordering`. -fn replace_on_columns_of_right_ordering( - on_columns: &[(PhysicalExprRef, PhysicalExprRef)], - right_ordering: &mut LexOrdering, -) -> Result<()> { - for (left_col, right_col) in on_columns { - right_ordering.transform(|item| { - let new_expr = Arc::clone(&item.expr) - .transform(|e| { - if e.eq(right_col) { - Ok(Transformed::yes(Arc::clone(left_col))) - } else { - Ok(Transformed::no(e)) - } - }) - .data() - .expect("closure is infallible"); - item.expr = new_expr; - }); - } - Ok(()) -} - -fn offset_ordering( - ordering: &LexOrdering, - join_type: &JoinType, - offset: usize, -) -> LexOrdering { - match join_type { - // In the case below, right ordering should be offsetted with the left - // side length, since we append the right table to the left table. - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => ordering - .iter() - .map(|sort_expr| PhysicalSortExpr { - expr: add_offset_to_expr(Arc::clone(&sort_expr.expr), offset), - options: sort_expr.options, - }) - .collect(), - _ => ordering.clone(), - } + }; + Ok(result) } /// Calculate the output ordering of a given join operation. pub fn calculate_join_output_ordering( - left_ordering: &LexOrdering, - right_ordering: &LexOrdering, + left_ordering: Option<&LexOrdering>, + right_ordering: Option<&LexOrdering>, join_type: JoinType, - on_columns: &[(PhysicalExprRef, PhysicalExprRef)], left_columns_len: usize, maintains_input_order: &[bool], probe_side: Option, -) -> Option { - let output_ordering = match maintains_input_order { +) -> Result> { + match maintains_input_order { [true, false] => { // Special case, we can prefix ordering of right side with the ordering of left side. if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) { - replace_on_columns_of_right_ordering( - on_columns, - &mut right_ordering.clone(), - ) - .ok()?; - merge_vectors( - left_ordering, - offset_ordering(right_ordering, &join_type, left_columns_len) - .as_ref(), - ) - } else { - left_ordering.clone() + if let Some(right_ordering) = right_ordering.cloned() { + let right_offset = add_offset_to_physical_sort_exprs( + right_ordering, + left_columns_len as _, + )?; + return if let Some(left_ordering) = left_ordering { + let mut result = left_ordering.clone(); + result.extend(right_offset); + Ok(Some(result)) + } else { + Ok(LexOrdering::new(right_offset)) + }; + } } + Ok(left_ordering.cloned()) } [false, true] => { // Special case, we can prefix ordering of left side with the ordering of right side. if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) { - replace_on_columns_of_right_ordering( - on_columns, - &mut right_ordering.clone(), - ) - .ok()?; - merge_vectors( - offset_ordering(right_ordering, &join_type, left_columns_len) - .as_ref(), - left_ordering, - ) - } else { - offset_ordering(right_ordering, &join_type, left_columns_len) + return if let Some(right_ordering) = right_ordering.cloned() { + let mut right_offset = add_offset_to_physical_sort_exprs( + right_ordering, + left_columns_len as _, + )?; + if let Some(left_ordering) = left_ordering { + right_offset.extend(left_ordering.clone()); + } + Ok(LexOrdering::new(right_offset)) + } else { + Ok(left_ordering.cloned()) + }; + } + let Some(right_ordering) = right_ordering else { + return Ok(None); + }; + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + add_offset_to_physical_sort_exprs( + right_ordering.clone(), + left_columns_len as _, + ) + .map(LexOrdering::new) + } + _ => Ok(Some(right_ordering.clone())), } } // Doesn't maintain ordering, output ordering is None. - [false, false] => return None, + [false, false] => Ok(None), [true, true] => unreachable!("Cannot maintain ordering of both sides"), _ => unreachable!("Join operators can not have more than two children"), - }; - (!output_ordering.is_empty()).then_some(output_ordering) + } } /// Information about the index and placement (left or right) of the columns @@ -250,6 +226,7 @@ fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> JoinType::LeftAnti => false, // doesn't introduce nulls (or can it??) JoinType::RightAnti => false, // doesn't introduce nulls (or can it??) JoinType::LeftMark => false, + JoinType::RightMark => false, }; if force_nullable { @@ -316,19 +293,38 @@ pub fn build_join_schema( left_fields().chain(right_field).unzip() } JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(), + JoinType::RightMark => { + let left_field = once(( + Field::new("mark", arrow_schema::DataType::Boolean, false), + ColumnIndex { + index: 0, + side: JoinSide::None, + }, + )); + right_fields().chain(left_field).unzip() + } }; - let metadata = left + let (schema1, schema2) = match join_type { + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => (left, right), + _ => (right, left), + }; + + let metadata = schema1 .metadata() .clone() .into_iter() - .chain(right.metadata().clone()) + .chain(schema2.metadata().clone()) .collect(); + (fields.finish().with_metadata(metadata), column_indices) } /// A [`OnceAsync`] runs an `async` closure once, where multiple calls to -/// [`OnceAsync::once`] return a [`OnceFut`] that resolves to the result of the +/// [`OnceAsync::try_once`] return a [`OnceFut`] that resolves to the result of the /// same computation. /// /// This is useful for joins where the results of one child are needed to proceed @@ -341,7 +337,7 @@ pub fn build_join_schema( /// /// Each output partition waits on the same `OnceAsync` before proceeding. pub(crate) struct OnceAsync { - fut: Mutex>>, + fut: Mutex>>>, } impl Default for OnceAsync { @@ -360,19 +356,22 @@ impl Debug for OnceAsync { impl OnceAsync { /// If this is the first call to this function on this object, will invoke - /// `f` to obtain a future and return a [`OnceFut`] referring to this + /// `f` to obtain a future and return a [`OnceFut`] referring to this. `f` + /// may fail, in which case its error is returned. /// /// If this is not the first call, will return a [`OnceFut`] referring - /// to the same future as was returned by the first call - pub(crate) fn once(&self, f: F) -> OnceFut + /// to the same future as was returned by the first call - or the same + /// error if the initial call to `f` failed. + pub(crate) fn try_once(&self, f: F) -> Result> where - F: FnOnce() -> Fut, + F: FnOnce() -> Result, Fut: Future> + Send + 'static, { self.fut .lock() - .get_or_insert_with(|| OnceFut::new(f())) + .get_or_insert_with(|| f().map(OnceFut::new).map_err(Arc::new)) .clone() + .map_err(DataFusionError::Shared) } } @@ -404,15 +403,12 @@ struct PartialJoinStatistics { /// Estimate the statistics for the given join's output. pub(crate) fn estimate_join_statistics( - left: Arc, - right: Arc, + left_stats: Statistics, + right_stats: Statistics, on: JoinOn, join_type: &JoinType, schema: &Schema, ) -> Result { - let left_stats = left.statistics()?; - let right_stats = right.statistics()?; - let join_stats = estimate_join_cardinality(join_type, left_stats, right_stats, &on); let (num_rows, column_statistics) = match join_stats { Some(stats) => (Precision::Inexact(stats.num_rows), stats.column_statistics), @@ -537,6 +533,15 @@ fn estimate_join_cardinality( column_statistics, }) } + JoinType::RightMark => { + let num_rows = *right_stats.num_rows.get_value()?; + let mut column_statistics = right_stats.column_statistics; + column_statistics.push(ColumnStatistics::new_unknown()); + Some(PartialJoinStatistics { + num_rows, + column_statistics, + }) + } } } @@ -561,15 +566,6 @@ fn estimate_inner_join_cardinality( .iter() .zip(right_stats.column_statistics.iter()) { - // Break if any of statistics bounds are undefined - if left_stat.min_value.get_value().is_none() - || left_stat.max_value.get_value().is_none() - || right_stat.min_value.get_value().is_none() - || right_stat.max_value.get_value().is_none() - { - return None; - } - let left_max_distinct = max_distinct_count(&left_stats.num_rows, left_stat); let right_max_distinct = max_distinct_count(&right_stats.num_rows, right_stat); let max_distinct = left_max_distinct.max(&right_max_distinct); @@ -656,7 +652,8 @@ fn estimate_disjoint_inputs( /// Estimate the number of maximum distinct values that can be present in the /// given column from its statistics. If distinct_count is available, uses it /// directly. Otherwise, if the column is numeric and has min/max values, it -/// estimates the maximum distinct count from those. +/// estimates the maximum distinct count from those. Otherwise, the num_rows +/// is used. fn max_distinct_count( num_rows: &Precision, stats: &ColumnStatistics, @@ -779,6 +776,23 @@ impl OnceFut { } } +/// Should we use a bitmap to track each incoming right batch's each row's +/// 'joined' status. +/// +/// For example in right joins, we have to use a bit map to track matched +/// right side rows, and later enter a `EmitRightUnmatched` stage to emit +/// unmatched right rows. +pub(crate) fn need_produce_right_in_final(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::Full + | JoinType::Right + | JoinType::RightAnti + | JoinType::RightMark + | JoinType::RightSemi + ) +} + /// Some type `join_type` of join need to maintain the matched indices bit map for the left side, and /// use the bit map to generate the part of result of the join. /// @@ -850,24 +864,56 @@ pub(crate) fn apply_join_filter_to_indices( probe_indices: UInt32Array, filter: &JoinFilter, build_side: JoinSide, + max_intermediate_size: Option, ) -> Result<(UInt64Array, UInt32Array)> { if build_indices.is_empty() && probe_indices.is_empty() { return Ok((build_indices, probe_indices)); }; - let intermediate_batch = build_batch_from_indices( - filter.schema(), - build_input_buffer, - probe_batch, - &build_indices, - &probe_indices, - filter.column_indices(), - build_side, - )?; - let filter_result = filter - .expression() - .evaluate(&intermediate_batch)? - .into_array(intermediate_batch.num_rows())?; + let filter_result = if let Some(max_size) = max_intermediate_size { + let mut filter_results = + Vec::with_capacity(build_indices.len().div_ceil(max_size)); + + for i in (0..build_indices.len()).step_by(max_size) { + let end = min(build_indices.len(), i + max_size); + let len = end - i; + let intermediate_batch = build_batch_from_indices( + filter.schema(), + build_input_buffer, + probe_batch, + &build_indices.slice(i, len), + &probe_indices.slice(i, len), + filter.column_indices(), + build_side, + )?; + let filter_result = filter + .expression() + .evaluate(&intermediate_batch)? + .into_array(intermediate_batch.num_rows())?; + filter_results.push(filter_result); + } + + let filter_refs: Vec<&dyn Array> = + filter_results.iter().map(|a| a.as_ref()).collect(); + + compute::concat(&filter_refs)? + } else { + let intermediate_batch = build_batch_from_indices( + filter.schema(), + build_input_buffer, + probe_batch, + &build_indices, + &probe_indices, + filter.column_indices(), + build_side, + )?; + + filter + .expression() + .evaluate(&intermediate_batch)? + .into_array(intermediate_batch.num_rows())? + }; + let mask = as_boolean_array(&filter_result)?; let left_filtered = compute::filter(&build_indices, mask)?; @@ -908,7 +954,7 @@ pub(crate) fn build_batch_from_indices( for column_index in column_indices { let array = if column_index.side == JoinSide::None { - // LeftMark join, the mark column is a true if the indices is not null, otherwise it will be false + // For mark joins, the mark column is a true if the indices is not null, otherwise it will be false Arc::new(compute::is_not_null(probe_indices)?) } else if column_index.side == build_side { let array = build_input_buffer.column(column_index.index); @@ -919,7 +965,7 @@ pub(crate) fn build_batch_from_indices( assert_eq!(build_indices.null_count(), build_indices.len()); new_null_array(array.data_type(), build_indices.len()) } else { - compute::take(array.as_ref(), build_indices, None)? + take(array.as_ref(), build_indices, None)? } } else { let array = probe_batch.column(column_index.index); @@ -927,14 +973,64 @@ pub(crate) fn build_batch_from_indices( assert_eq!(probe_indices.null_count(), probe_indices.len()); new_null_array(array.data_type(), probe_indices.len()) } else { - compute::take(array.as_ref(), probe_indices, None)? + take(array.as_ref(), probe_indices, None)? } }; + columns.push(array); } Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?) } +/// Returns a new [RecordBatch] resulting of a join where the build/left side is empty. +/// The resulting batch has [Schema] `schema`. +pub(crate) fn build_batch_empty_build_side( + schema: &Schema, + build_batch: &RecordBatch, + probe_batch: &RecordBatch, + column_indices: &[ColumnIndex], + join_type: JoinType, +) -> Result { + match join_type { + // these join types only return data if the left side is not empty, so we return an + // empty RecordBatch + JoinType::Inner + | JoinType::Left + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::LeftMark => Ok(RecordBatch::new_empty(Arc::new(schema.clone()))), + + // the remaining joins will return data for the right columns and null for the left ones + JoinType::Right | JoinType::Full | JoinType::RightAnti | JoinType::RightMark => { + let num_rows = probe_batch.num_rows(); + let mut columns: Vec> = + Vec::with_capacity(schema.fields().len()); + + for column_index in column_indices { + let array = match column_index.side { + // left -> null array + JoinSide::Left => new_null_array( + build_batch.column(column_index.index).data_type(), + num_rows, + ), + // right -> respective right array + JoinSide::Right => Arc::clone(probe_batch.column(column_index.index)), + // right mark -> unset boolean array as there are no matches on the left side + JoinSide::None => Arc::new(BooleanArray::new( + BooleanBuffer::new_unset(num_rows), + None, + )), + }; + + columns.push(array); + } + + Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?) + } + } +} + /// The input is the matched indices for left and right and /// adjust the indices according to the join type pub(crate) fn adjust_indices_by_join_type( @@ -979,6 +1075,12 @@ pub(crate) fn adjust_indices_by_join_type( // the left_indices will not be used later for the `right anti` join Ok((left_indices, right_indices)) } + JoinType::RightMark => { + let right_indices = get_mark_indices(&adjust_range, &right_indices); + let left_indices_vec: Vec = adjust_range.map(|i| i as u64).collect(); + let left_indices = UInt64Array::from(left_indices_vec); + Ok((left_indices, right_indices)) + } JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { // matched or unmatched left row will be produced in the end of loop // When visit the right batch, we can output the matched left row and don't need to wait the end of loop @@ -1080,17 +1182,7 @@ pub(crate) fn get_anti_indices( where NativeAdapter: From<::Native>, { - let mut bitmap = BooleanBufferBuilder::new(range.len()); - bitmap.append_n(range.len(), false); - input_indices - .iter() - .flatten() - .map(|v| v.as_usize()) - .filter(|v| range.contains(v)) - .for_each(|v| { - bitmap.set_bit(v - range.start, true); - }); - + let bitmap = build_range_bitmap(&range, input_indices); let offset = range.start; // get the anti index @@ -1109,19 +1201,8 @@ pub(crate) fn get_semi_indices( where NativeAdapter: From<::Native>, { - let mut bitmap = BooleanBufferBuilder::new(range.len()); - bitmap.append_n(range.len(), false); - input_indices - .iter() - .flatten() - .map(|v| v.as_usize()) - .filter(|v| range.contains(v)) - .for_each(|v| { - bitmap.set_bit(v - range.start, true); - }); - + let bitmap = build_range_bitmap(&range, input_indices); let offset = range.start; - // get the semi index (range) .filter_map(|idx| { @@ -1130,6 +1211,37 @@ where .collect() } +pub(crate) fn get_mark_indices( + range: &Range, + input_indices: &PrimitiveArray, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ + let mut bitmap = build_range_bitmap(range, input_indices); + PrimitiveArray::new( + vec![0; range.len()].into(), + Some(NullBuffer::new(bitmap.finish())), + ) +} + +fn build_range_bitmap( + range: &Range, + input: &PrimitiveArray, +) -> BooleanBufferBuilder { + let mut builder = BooleanBufferBuilder::new(range.len()); + builder.append_n(range.len(), false); + + input.iter().flatten().for_each(|v| { + let idx = v.as_usize(); + if range.contains(&idx) { + builder.set_bit(idx - range.start, true); + } + }); + + builder +} + /// Appends probe indices in order by considering the given build indices. /// /// This function constructs new build and probe indices by iterating through @@ -1187,6 +1299,7 @@ fn append_probe_indices_in_order( /// Metrics for build & probe joins #[derive(Clone, Debug)] pub(crate) struct BuildProbeJoinMetrics { + pub(crate) baseline: BaselineMetrics, /// Total time for collecting build-side of join pub(crate) build_time: metrics::Time, /// Number of batches consumed by build-side @@ -1203,12 +1316,31 @@ pub(crate) struct BuildProbeJoinMetrics { pub(crate) input_rows: metrics::Count, /// Number of batches produced by this operator pub(crate) output_batches: metrics::Count, - /// Number of rows produced by this operator - pub(crate) output_rows: metrics::Count, +} + +// This Drop implementation updates the elapsed compute part of the metrics. +// +// Why is this in a Drop? +// - We keep track of build_time and join_time separately, but baseline metrics have +// a total elapsed_compute time. Instead of remembering to update both the metrics +// at the same time, we chose to update elapsed_compute once at the end - summing up +// both the parts. +// +// How does this work? +// - The elapsed_compute `Time` is represented by an `Arc`. So even when +// this `BuildProbeJoinMetrics` is dropped, the elapsed_compute is usable through the +// Arc reference. +impl Drop for BuildProbeJoinMetrics { + fn drop(&mut self) { + self.baseline.elapsed_compute().add(&self.build_time); + self.baseline.elapsed_compute().add(&self.join_time); + } } impl BuildProbeJoinMetrics { pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let baseline = BaselineMetrics::new(metrics, partition); + let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition); let build_time = MetricBuilder::new(metrics).subset_time("build_time", partition); @@ -1230,8 +1362,6 @@ impl BuildProbeJoinMetrics { let output_batches = MetricBuilder::new(metrics).counter("output_batches", partition); - let output_rows = MetricBuilder::new(metrics).output_rows(partition); - Self { build_time, build_input_batches, @@ -1241,7 +1371,7 @@ impl BuildProbeJoinMetrics { input_batches, input_rows, output_batches, - output_rows, + baseline, } } } @@ -1297,36 +1427,41 @@ pub(crate) fn symmetric_join_output_partitioning( left: &Arc, right: &Arc, join_type: &JoinType, -) -> Partitioning { +) -> Result { let left_columns_len = left.schema().fields.len(); let left_partitioning = left.output_partitioning(); let right_partitioning = right.output_partitioning(); - match join_type { + let result = match join_type { JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { left_partitioning.clone() } - JoinType::RightSemi | JoinType::RightAnti => right_partitioning.clone(), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right_partitioning.clone() + } JoinType::Inner | JoinType::Right => { - adjust_right_output_partitioning(right_partitioning, left_columns_len) + adjust_right_output_partitioning(right_partitioning, left_columns_len)? } JoinType::Full => { // We could also use left partition count as they are necessarily equal. Partitioning::UnknownPartitioning(right_partitioning.partition_count()) } - } + }; + Ok(result) } pub(crate) fn asymmetric_join_output_partitioning( left: &Arc, right: &Arc, join_type: &JoinType, -) -> Partitioning { - match join_type { +) -> Result { + let result = match join_type { JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( right.output_partitioning(), left.schema().fields().len(), - ), - JoinType::RightSemi | JoinType::RightAnti => right.output_partitioning().clone(), + )?, + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right.output_partitioning().clone() + } JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti @@ -1334,7 +1469,8 @@ pub(crate) fn asymmetric_join_output_partitioning( | JoinType::LeftMark => Partitioning::UnknownPartitioning( right.output_partitioning().partition_count(), ), - } + }; + Ok(result) } /// Trait for incrementally generating Join output. @@ -1428,7 +1564,7 @@ impl BatchTransformer for BatchSplitter { /// Joins output columns from their left input followed by their right input. /// Thus if the inputs are reordered, the output columns must be reordered to /// match the original order. -pub(crate) fn reorder_output_after_swap( +pub fn reorder_output_after_swap( plan: Arc, left_schema: &Schema, right_schema: &Schema, @@ -1448,26 +1584,33 @@ pub(crate) fn reorder_output_after_swap( fn swap_reverting_projection( left_schema: &Schema, right_schema: &Schema, -) -> Vec<(Arc, String)> { - let right_cols = right_schema.fields().iter().enumerate().map(|(i, f)| { - ( - Arc::new(Column::new(f.name(), i)) as Arc, - f.name().to_owned(), - ) - }); +) -> Vec { + let right_cols = + right_schema + .fields() + .iter() + .enumerate() + .map(|(i, f)| ProjectionExpr { + expr: Arc::new(Column::new(f.name(), i)) as Arc, + alias: f.name().to_owned(), + }); let right_len = right_cols.len(); - let left_cols = left_schema.fields().iter().enumerate().map(|(i, f)| { - ( - Arc::new(Column::new(f.name(), right_len + i)) as Arc, - f.name().to_owned(), - ) - }); + let left_cols = + left_schema + .fields() + .iter() + .enumerate() + .map(|(i, f)| ProjectionExpr { + expr: Arc::new(Column::new(f.name(), right_len + i)) + as Arc, + alias: f.name().to_owned(), + }); left_cols.chain(right_cols).collect() } /// This function swaps the given join's projection. -pub(super) fn swap_join_projection( +pub fn swap_join_projection( left_schema_len: usize, right_schema_len: usize, projection: Option<&Vec>, @@ -1479,8 +1622,9 @@ pub(super) fn swap_join_projection( JoinType::LeftAnti | JoinType::LeftSemi | JoinType::RightAnti - | JoinType::RightSemi => projection.cloned(), - + | JoinType::RightSemi + | JoinType::LeftMark + | JoinType::RightMark => projection.cloned(), _ => projection.map(|p| { p.iter() .map(|i| { @@ -1499,17 +1643,125 @@ pub(super) fn swap_join_projection( } } +/// Updates `hash_map` with new entries from `batch` evaluated against the expressions `on` +/// using `offset` as a start value for `batch` row indices. +/// +/// `fifo_hashmap` sets the order of iteration over `batch` rows while updating hashmap, +/// which allows to keep either first (if set to true) or last (if set to false) row index +/// as a chain head for rows with equal hash values. +#[allow(clippy::too_many_arguments)] +pub fn update_hash( + on: &[PhysicalExprRef], + batch: &RecordBatch, + hash_map: &mut dyn JoinHashMapType, + offset: usize, + random_state: &RandomState, + hashes_buffer: &mut Vec, + deleted_offset: usize, + fifo_hashmap: bool, +) -> Result<()> { + // evaluate the keys + let keys_values = on + .iter() + .map(|c| c.evaluate(batch)?.into_array(batch.num_rows())) + .collect::>>()?; + + // calculate the hash values + let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; + + // For usual JoinHashmap, the implementation is void. + hash_map.extend_zero(batch.num_rows()); + + // Updating JoinHashMap from hash values iterator + let hash_values_iter = hash_values + .iter() + .enumerate() + .map(|(i, val)| (i + offset, val)); + + if fifo_hashmap { + hash_map.update_from_iter(Box::new(hash_values_iter.rev()), deleted_offset); + } else { + hash_map.update_from_iter(Box::new(hash_values_iter), deleted_offset); + } + + Ok(()) +} + +pub(super) fn equal_rows_arr( + indices_left: &UInt64Array, + indices_right: &UInt32Array, + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], + null_equality: NullEquality, +) -> Result<(UInt64Array, UInt32Array)> { + let mut iter = left_arrays.iter().zip(right_arrays.iter()); + + let Some((first_left, first_right)) = iter.next() else { + return Ok((Vec::::new().into(), Vec::::new().into())); + }; + + let arr_left = take(first_left.as_ref(), indices_left, None)?; + let arr_right = take(first_right.as_ref(), indices_right, None)?; + + let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equality)?; + + // Use map and try_fold to iterate over the remaining pairs of arrays. + // In each iteration, take is used on the pair of arrays and their equality is determined. + // The results are then folded (combined) using the and function to get a final equality result. + equal = iter + .map(|(left, right)| { + let arr_left = take(left.as_ref(), indices_left, None)?; + let arr_right = take(right.as_ref(), indices_right, None)?; + eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equality) + }) + .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?; + + let filter_builder = FilterBuilder::new(&equal).optimize().build(); + + let left_filtered = filter_builder.filter(indices_left)?; + let right_filtered = filter_builder.filter(indices_right)?; + + Ok(( + downcast_array(left_filtered.as_ref()), + downcast_array(right_filtered.as_ref()), + )) +} + +// version of eq_dyn supporting equality on null arrays +fn eq_dyn_null( + left: &dyn Array, + right: &dyn Array, + null_equality: NullEquality, +) -> Result { + // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special + // implementation + // + if left.data_type().is_nested() { + let op = match null_equality { + NullEquality::NullEqualsNothing => Operator::Eq, + NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom, + }; + return Ok(compare_op_for_nested(op, &left, &right)?); + } + match null_equality { + NullEquality::NullEqualsNothing => eq(&left, &right), + NullEquality::NullEqualsNull => not_distinct(&left, &right), + } +} + #[cfg(test)] mod tests { - use super::*; + use std::collections::HashMap; use std::pin::Pin; + use super::*; + use arrow::array::Int32Array; - use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Fields}; use arrow::error::{ArrowError, Result as ArrowResult}; use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; + use datafusion_physical_expr::PhysicalSortExpr; use rstest::rstest; @@ -1758,12 +2010,18 @@ mod tests { (20, Inexact(1), Inexact(40), Absent, Absent), Some(Inexact(10)), ), - // When we have distinct count. + // Distinct count matches the range ( (10, Inexact(1), Inexact(10), Inexact(10), Absent), (10, Inexact(1), Inexact(10), Inexact(10), Absent), Some(Inexact(10)), ), + // Distinct count takes precedence over the range + ( + (10, Inexact(1), Inexact(3), Inexact(10), Absent), + (10, Inexact(1), Inexact(3), Inexact(10), Absent), + Some(Inexact(10)), + ), // distinct(left) > distinct(right) ( (10, Inexact(1), Inexact(10), Inexact(5), Absent), @@ -1807,32 +2065,33 @@ mod tests { // Edge cases // ========== // - // No column level stats. + // No column level stats, fall back to row count. ( (10, Absent, Absent, Absent, Absent), (10, Absent, Absent, Absent, Absent), - None, + Some(Inexact(10)), ), - // No min or max (or both). + // No min or max (or both), but distinct available. ( (10, Absent, Absent, Inexact(3), Absent), (10, Absent, Absent, Inexact(3), Absent), - None, + Some(Inexact(33)), ), ( (10, Inexact(2), Absent, Inexact(3), Absent), (10, Absent, Inexact(5), Inexact(3), Absent), - None, + Some(Inexact(33)), ), ( (10, Absent, Inexact(3), Inexact(3), Absent), (10, Inexact(1), Absent, Inexact(3), Absent), - None, + Some(Inexact(33)), ), + // No min or max, fall back to row count ( (10, Absent, Inexact(3), Absent, Absent), (10, Inexact(1), Absent, Absent, Absent), - None, + Some(Inexact(10)), ), // Non overlapping min/max (when exact=False). ( @@ -2246,8 +2505,7 @@ mod tests { assert_eq!( output_cardinality, expected, - "failure for join_type: {}", - join_type + "failure for join_type: {join_type}" ); } @@ -2320,85 +2578,35 @@ mod tests { #[test] fn test_calculate_join_output_ordering() -> Result<()> { - let options = SortOptions::default(); let left_ordering = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 3)), - options, - }, + PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))), + PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))), + PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))), ]); let right_ordering = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 2)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 1)), - options, - }, + PhysicalSortExpr::new_default(Arc::new(Column::new("z", 2))), + PhysicalSortExpr::new_default(Arc::new(Column::new("y", 1))), ]); let join_type = JoinType::Inner; - let on_columns = [( - Arc::new(Column::new("b", 1)) as _, - Arc::new(Column::new("x", 0)) as _, - )]; let left_columns_len = 5; let maintains_input_orders = [[true, false], [false, true]]; let probe_sides = [Some(JoinSide::Left), Some(JoinSide::Right)]; let expected = [ - Some(LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 3)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 7)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 6)), - options, - }, - ])), - Some(LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 7)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 6)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 3)), - options, - }, - ])), + LexOrdering::new(vec![ + PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))), + PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))), + PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))), + PhysicalSortExpr::new_default(Arc::new(Column::new("z", 7))), + PhysicalSortExpr::new_default(Arc::new(Column::new("y", 6))), + ]), + LexOrdering::new(vec![ + PhysicalSortExpr::new_default(Arc::new(Column::new("z", 7))), + PhysicalSortExpr::new_default(Arc::new(Column::new("y", 6))), + PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))), + PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))), + PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))), + ]), ]; for (i, (maintains_input_order, probe_side)) in @@ -2409,11 +2617,10 @@ mod tests { left_ordering.as_ref(), right_ordering.as_ref(), join_type, - &on_columns, left_columns_len, maintains_input_order, probe_side, - ), + )?, expected[i] ); } @@ -2479,17 +2686,17 @@ mod tests { assert_eq!(proj.len(), 3); - let (col, name) = &proj[0]; - assert_eq!(name, "a"); - assert_col_expr(col, "a", 1); + let proj_expr = &proj[0]; + assert_eq!(proj_expr.alias, "a"); + assert_col_expr(&proj_expr.expr, "a", 1); - let (col, name) = &proj[1]; - assert_eq!(name, "b"); - assert_col_expr(col, "b", 2); + let proj_expr = &proj[1]; + assert_eq!(proj_expr.alias, "b"); + assert_col_expr(&proj_expr.expr, "b", 2); - let (col, name) = &proj[2]; - assert_eq!(name, "c"); - assert_col_expr(col, "c", 0); + let proj_expr = &proj[2]; + assert_eq!(proj_expr.alias, "c"); + assert_col_expr(&proj_expr.expr, "c", 0); } fn assert_col_expr(expr: &Arc, name: &str, index: usize) { @@ -2500,4 +2707,28 @@ mod tests { assert_eq!(col.name(), name); assert_eq!(col.index(), index); } + + #[test] + fn test_join_metadata() -> Result<()> { + let left_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]) + .with_metadata(HashMap::from([("key".to_string(), "left".to_string())])); + + let right_schema = Schema::new(vec![Field::new("b", DataType::Int32, false)]) + .with_metadata(HashMap::from([("key".to_string(), "right".to_string())])); + + let (join_schema, _) = + build_join_schema(&left_schema, &right_schema, &JoinType::Left); + assert_eq!( + join_schema.metadata(), + &HashMap::from([("key".to_string(), "left".to_string())]) + ); + let (join_schema, _) = + build_join_schema(&left_schema, &right_schema, &JoinType::Right); + assert_eq!( + join_schema.metadata(), + &HashMap::from([("key".to_string(), "right".to_string())]) + ); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 04fbd06fabcde..17628fd8ad1d2 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -50,6 +50,8 @@ pub use crate::ordering::InputOrderMode; pub use crate::stream::EmptyRecordBatchStream; pub use crate::topk::TopK; pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; +pub use crate::work_table::WorkTable; +pub use spill::spill_manager::SpillManager; mod ordering; mod render_tree; @@ -58,14 +60,18 @@ mod visitor; pub mod aggregates; pub mod analyze; +pub mod async_func; +pub mod coalesce; pub mod coalesce_batches; pub mod coalesce_partitions; pub mod common; +pub mod coop; pub mod display; pub mod empty; pub mod execution_plan; pub mod explain; pub mod filter; +pub mod filter_pushdown; pub mod joins; pub mod limit; pub mod memory; @@ -81,7 +87,6 @@ pub mod streaming; pub mod tree_node; pub mod union; pub mod unnest; -pub mod values; pub mod windows; pub mod work_table; pub mod udaf { @@ -89,6 +94,4 @@ pub mod udaf { pub use datafusion_physical_expr::aggregate::AggregateFunctionExpr; } -pub mod coalesce; -#[cfg(test)] pub mod test; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 89cf47a6d6508..6a0cae20e5aa6 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -105,12 +105,13 @@ impl DisplayAs for GlobalLimitExec { f, "GlobalLimitExec: skip={}, fetch={}", self.skip, - self.fetch.map_or("None".to_string(), |x| x.to_string()) + self.fetch + .map_or_else(|| "None".to_string(), |x| x.to_string()) ) } DisplayFormatType::TreeRender => { if let Some(fetch) = self.fetch { - writeln!(f, "limit={}", fetch)?; + writeln!(f, "limit={fetch}")?; } write!(f, "skip={}", self.skip) } @@ -164,10 +165,7 @@ impl ExecutionPlan for GlobalLimitExec { partition: usize, context: Arc, ) -> Result { - trace!( - "Start GlobalLimitExec::execute for partition: {}", - partition - ); + trace!("Start GlobalLimitExec::execute for partition: {partition}"); // GlobalLimitExec has a single output partition if 0 != partition { return internal_err!("GlobalLimitExec invalid partition {partition}"); @@ -193,13 +191,13 @@ impl ExecutionPlan for GlobalLimitExec { } fn statistics(&self) -> Result { - Statistics::with_fetch( - self.input.statistics()?, - self.schema(), - self.fetch, - self.skip, - 1, - ) + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input + .partition_statistics(partition)? + .with_fetch(self.fetch, self.skip, 1) } fn fetch(&self) -> Option { @@ -334,13 +332,13 @@ impl ExecutionPlan for LocalLimitExec { } fn statistics(&self) -> Result { - Statistics::with_fetch( - self.input.statistics()?, - self.schema(), - Some(self.fetch), - 0, - 1, - ) + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input + .partition_statistics(partition)? + .with_fetch(Some(self.fetch), 0, 1) } fn fetch(&self) -> Option { @@ -765,7 +763,7 @@ mod tests { let offset = GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch); - Ok(offset.statistics()?.num_rows) + Ok(offset.partition_statistics(None)?.num_rows) } pub fn build_group_by( @@ -805,7 +803,7 @@ mod tests { fetch, ); - Ok(offset.statistics()?.num_rows) + Ok(offset.partition_statistics(None)?.num_rows) } async fn row_number_statistics_for_local_limit( @@ -818,7 +816,7 @@ mod tests { let offset = LocalLimitExec::new(csv, fetch); - Ok(offset.statistics()?.num_rows) + Ok(offset.partition_statistics(None)?.num_rows) } /// Return a RecordBatch with a single array with row_count sz diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 1bc872a56e763..1bf1e04efb53b 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -22,7 +22,9 @@ use std::fmt; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::execution_plan::{Boundedness, EmissionType}; +use crate::coop::cooperative; +use crate::execution_plan::{Boundedness, EmissionType, SchedulingType}; +use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, @@ -35,6 +37,7 @@ use datafusion_execution::memory_pool::MemoryReservation; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use futures::Stream; use parking_lot::RwLock; @@ -131,6 +134,14 @@ impl RecordBatchStream for MemoryStream { } pub trait LazyBatchGenerator: Send + Sync + fmt::Debug + fmt::Display { + /// Returns the generator as [`Any`] so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + fn boundedness(&self) -> Boundedness { + Boundedness::Bounded + } + /// Generate the next batch, return `None` when no more batches are available fn generate_next_batch(&mut self) -> Result>; } @@ -146,6 +157,8 @@ pub struct LazyMemoryExec { batch_generators: Vec>>, /// Plan properties cache storing equivalence properties, partitioning, and execution mode cache: PlanProperties, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, } impl LazyMemoryExec { @@ -154,18 +167,67 @@ impl LazyMemoryExec { schema: SchemaRef, generators: Vec>>, ) -> Result { + let boundedness = generators + .iter() + .map(|g| g.read().boundedness()) + .reduce(|acc, b| match acc { + Boundedness::Bounded => b, + Boundedness::Unbounded { + requires_infinite_memory, + } => { + let acc_infinite_memory = requires_infinite_memory; + match b { + Boundedness::Bounded => acc, + Boundedness::Unbounded { + requires_infinite_memory, + } => Boundedness::Unbounded { + requires_infinite_memory: requires_infinite_memory + || acc_infinite_memory, + }, + } + } + }) + .unwrap_or(Boundedness::Bounded); + let cache = PlanProperties::new( EquivalenceProperties::new(Arc::clone(&schema)), Partitioning::RoundRobinBatch(generators.len()), EmissionType::Incremental, - Boundedness::Bounded, - ); + boundedness, + ) + .with_scheduling_type(SchedulingType::Cooperative); + Ok(Self { schema, batch_generators: generators, cache, + metrics: ExecutionPlanMetricsSet::new(), }) } + + pub fn try_set_partitioning(&mut self, partitioning: Partitioning) -> Result<()> { + if partitioning.partition_count() != self.batch_generators.len() { + internal_err!( + "Partition count must match generator count: {} != {}", + partitioning.partition_count(), + self.batch_generators.len() + ) + } else { + self.cache.partitioning = partitioning; + Ok(()) + } + } + + pub fn add_ordering(&mut self, ordering: impl IntoIterator) { + self.cache + .eq_properties + .add_orderings(std::iter::once(ordering)); + } + + /// Get the batch generators + pub fn generators(&self) -> &Vec>> { + &self.batch_generators + } } impl fmt::Debug for LazyMemoryExec { @@ -254,10 +316,18 @@ impl ExecutionPlan for LazyMemoryExec { ); } - Ok(Box::pin(LazyMemoryStream { + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + + let stream = LazyMemoryStream { schema: Arc::clone(&self.schema), generator: Arc::clone(&self.batch_generators[partition]), - })) + baseline_metrics, + }; + Ok(Box::pin(cooperative(stream))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) } fn statistics(&self) -> Result { @@ -276,6 +346,8 @@ pub struct LazyMemoryStream { /// parallel execution. /// Sharing generators between streams should be used with caution. generator: Arc>, + /// Execution metrics + baseline_metrics: BaselineMetrics, } impl Stream for LazyMemoryStream { @@ -285,13 +357,16 @@ impl Stream for LazyMemoryStream { self: std::pin::Pin<&mut Self>, _: &mut Context<'_>, ) -> Poll> { + let _timer_guard = self.baseline_metrics.elapsed_compute().timer(); let batch = self.generator.write().generate_next_batch(); - match batch { + let poll = match batch { Ok(Some(batch)) => Poll::Ready(Some(Ok(batch))), Ok(None) => Poll::Ready(None), Err(e) => Poll::Ready(Some(Err(e))), - } + }; + + self.baseline_metrics.record_poll(poll) } } @@ -304,6 +379,7 @@ impl RecordBatchStream for LazyMemoryStream { #[cfg(test)] mod lazy_memory_tests { use super::*; + use crate::common::collect; use arrow::array::Int64Array; use arrow::datatypes::{DataType, Field, Schema}; use futures::StreamExt; @@ -327,6 +403,10 @@ mod lazy_memory_tests { } impl LazyBatchGenerator for TestGenerator { + fn as_any(&self) -> &dyn Any { + self + } + fn generate_next_batch(&mut self) -> Result> { if self.counter >= self.max_batches { return Ok(None); @@ -419,4 +499,45 @@ mod lazy_memory_tests { Ok(()) } + + #[tokio::test] + async fn test_generate_series_metrics_integration() -> Result<()> { + // Test LazyMemoryExec metrics with different configurations + let test_cases = vec![ + (10, 2, 10), // 10 rows, batch size 2, expected 10 rows + (100, 10, 100), // 100 rows, batch size 10, expected 100 rows + (5, 1, 5), // 5 rows, batch size 1, expected 5 rows + ]; + + for (total_rows, batch_size, expected_rows) in test_cases { + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let generator = TestGenerator { + counter: 0, + max_batches: (total_rows + batch_size - 1) / batch_size, // ceiling division + batch_size: batch_size as usize, + schema: Arc::clone(&schema), + }; + + let exec = + LazyMemoryExec::try_new(schema, vec![Arc::new(RwLock::new(generator))])?; + let task_ctx = Arc::new(TaskContext::default()); + + let stream = exec.execute(0, task_ctx)?; + let batches = collect(stream).await?; + + // Verify metrics exist with actual expected numbers + let metrics = exec.metrics().unwrap(); + + // Count actual rows returned + let actual_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(actual_rows, expected_rows); + + // Verify metrics match actual output + assert_eq!(metrics.output_rows().unwrap(), expected_rows); + assert!(metrics.elapsed_compute().unwrap() > 0); + } + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/metrics/baseline.rs b/datafusion/physical-plan/src/metrics/baseline.rs index a4a83b84b6555..15efb8f90aa20 100644 --- a/datafusion/physical-plan/src/metrics/baseline.rs +++ b/datafusion/physical-plan/src/metrics/baseline.rs @@ -45,7 +45,7 @@ use datafusion_common::Result; /// ``` #[derive(Debug, Clone)] pub struct BaselineMetrics { - /// end_time is set when `ExecutionMetrics::done()` is called + /// end_time is set when `BaselineMetrics::done()` is called end_time: Timestamp, /// amount of time the operator was actively trying to use the CPU @@ -117,9 +117,10 @@ impl BaselineMetrics { } } - /// Process a poll result of a stream producing output for an - /// operator, recording the output rows and stream done time and - /// returning the same poll result + /// Process a poll result of a stream producing output for an operator. + /// + /// Note: this method only updates `output_rows` and `end_time` metrics. + /// Remember to update `elapsed_compute` and other metrics manually. pub fn record_poll( &self, poll: Poll>>, @@ -150,7 +151,7 @@ pub struct SpillMetrics { /// count of spills during the execution of the operator pub spill_file_count: Count, - /// total spilled bytes during the execution of the operator + /// total bytes actually written to disk during the execution of the operator pub spilled_bytes: Count, /// total spilled rows during the execution of the operator @@ -168,6 +169,23 @@ impl SpillMetrics { } } +/// Metrics for tracking [`crate::stream::BatchSplitStream`] activity +#[derive(Debug, Clone)] +pub struct SplitMetrics { + /// Number of times an input [`RecordBatch`] was split + pub batches_split: Count, +} + +impl SplitMetrics { + /// Create a new [`SplitMetrics`] + pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + Self { + batches_split: MetricBuilder::new(metrics) + .counter("batches_split", partition), + } + } +} + /// Trait for things that produce output rows as a result of execution. pub trait RecordOutput { /// Record that some number of output rows have been produced diff --git a/datafusion/physical-plan/src/metrics/custom.rs b/datafusion/physical-plan/src/metrics/custom.rs new file mode 100644 index 0000000000000..546af6f3335e9 --- /dev/null +++ b/datafusion/physical-plan/src/metrics/custom.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Custom metric value type. + +use std::{any::Any, fmt::Debug, fmt::Display, sync::Arc}; + +/// A trait for implementing custom metric values. +/// +/// This trait enables defining application- or operator-specific metric types +/// that can be aggregated and displayed alongside standard metrics. These +/// custom metrics integrate with [`MetricValue::Custom`] and support +/// aggregation logic, introspection, and optional numeric representation. +/// +/// # Requirements +/// Implementations of `CustomMetricValue` must satisfy the following: +/// +/// 1. [`Self::aggregate`]: Defines how two metric values are combined +/// 2. [`Self::new_empty`]: Returns a new, zero-value instance for accumulation +/// 3. [`Self::as_any`]: Enables dynamic downcasting for type-specific operations +/// 4. [`Self::as_usize`]: Optionally maps the value to a `usize` (for sorting, display, etc.) +/// 5. [`Self::is_eq`]: Implements comparison between two values, this isn't reusing the std +/// PartialEq trait because this trait is used dynamically in the context of +/// [`MetricValue::Custom`] +/// +/// # Examples +/// ``` +/// # use std::sync::Arc; +/// # use std::fmt::{Debug, Display}; +/// # use std::any::Any; +/// # use std::sync::atomic::{AtomicUsize, Ordering}; +/// +/// # use datafusion_physical_plan::metrics::CustomMetricValue; +/// +/// #[derive(Debug, Default)] +/// struct MyCounter { +/// count: AtomicUsize, +/// } +/// +/// impl Display for MyCounter { +/// fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +/// write!(f, "count: {}", self.count.load(Ordering::Relaxed)) +/// } +/// } +/// +/// impl CustomMetricValue for MyCounter { +/// fn new_empty(&self) -> Arc { +/// Arc::new(Self::default()) +/// } +/// +/// fn aggregate(&self, other: Arc) { +/// let other = other.as_any().downcast_ref::().unwrap(); +/// self.count.fetch_add(other.count.load(Ordering::Relaxed), Ordering::Relaxed); +/// } +/// +/// fn as_any(&self) -> &dyn Any { +/// self +/// } +/// +/// fn as_usize(&self) -> usize { +/// self.count.load(Ordering::Relaxed) +/// } +/// +/// fn is_eq(&self, other: &Arc) -> bool { +/// let Some(other) = other.as_any().downcast_ref::() else { +/// return false; +/// }; +/// +/// self.count.load(Ordering::Relaxed) == other.count.load(Ordering::Relaxed) +/// } +/// } +/// ``` +/// +/// [`MetricValue::Custom`]: super::MetricValue::Custom +pub trait CustomMetricValue: Display + Debug + Send + Sync { + /// Returns a new, zero-initialized version of this metric value. + /// + /// This value is used during metric aggregation to accumulate results. + fn new_empty(&self) -> Arc; + + /// Merges another metric value into this one. + /// + /// The type of `other` could be of a different custom type as long as it's aggregatable into self. + fn aggregate(&self, other: Arc); + + /// Returns this value as a [`Any`] to support dynamic downcasting. + fn as_any(&self) -> &dyn Any; + + /// Optionally returns a numeric representation of the value, if meaningful. + /// Otherwise will default to zero. + /// + /// This is used for sorting and summarizing metrics. + fn as_usize(&self) -> usize { + 0 + } + + /// Compares this value with another custom value. + fn is_eq(&self, other: &Arc) -> bool; +} diff --git a/datafusion/physical-plan/src/metrics/mod.rs b/datafusion/physical-plan/src/metrics/mod.rs index 2ac7ac1299a0a..0b9b4bed856b8 100644 --- a/datafusion/physical-plan/src/metrics/mod.rs +++ b/datafusion/physical-plan/src/metrics/mod.rs @@ -19,6 +19,7 @@ mod baseline; mod builder; +mod custom; mod value; use parking_lot::Mutex; @@ -31,8 +32,9 @@ use std::{ use datafusion_common::HashMap; // public exports -pub use baseline::{BaselineMetrics, RecordOutput, SpillMetrics}; +pub use baseline::{BaselineMetrics, RecordOutput, SpillMetrics, SplitMetrics}; pub use builder::MetricBuilder; +pub use custom::CustomMetricValue; pub use value::{Count, Gauge, MetricValue, ScopedTimerGuard, Time, Timestamp}; /// Something that tracks a value of interest (metric) of a DataFusion @@ -263,6 +265,7 @@ impl MetricsSet { MetricValue::Gauge { name, .. } => name == metric_name, MetricValue::StartTimestamp(_) => false, MetricValue::EndTimestamp(_) => false, + MetricValue::Custom { .. } => false, }) } @@ -384,7 +387,7 @@ impl ExecutionPlanMetricsSet { /// "tags" in /// [InfluxDB](https://docs.influxdata.com/influxdb/v1.8/write_protocols/line_protocol_tutorial/) /// , "attributes" in [open -/// telemetry], +/// telemetry], /// etc. /// /// As the name and value are expected to mostly be constant strings, diff --git a/datafusion/physical-plan/src/metrics/value.rs b/datafusion/physical-plan/src/metrics/value.rs index decf77369db4f..3149fca95ba84 100644 --- a/datafusion/physical-plan/src/metrics/value.rs +++ b/datafusion/physical-plan/src/metrics/value.rs @@ -17,9 +17,14 @@ //! Value representation of metrics +use super::CustomMetricValue; +use chrono::{DateTime, Utc}; +use datafusion_common::instant::Instant; +use datafusion_execution::memory_pool::human_readable_size; +use parking_lot::Mutex; use std::{ borrow::{Borrow, Cow}, - fmt::Display, + fmt::{Debug, Display}, sync::{ atomic::{AtomicUsize, Ordering}, Arc, @@ -27,10 +32,6 @@ use std::{ time::Duration, }; -use chrono::{DateTime, Utc}; -use datafusion_common::instant::Instant; -use parking_lot::Mutex; - /// A counter to record things such as number of input or output rows /// /// Note `clone`ing counters update the same underlying metrics @@ -221,6 +222,15 @@ impl Time { pub fn value(&self) -> usize { self.nanos.load(Ordering::Relaxed) } + + /// Return a scoped guard that adds the amount of time elapsed between the + /// given instant and its drop (or the call to `stop`) to the underlying metric + pub fn timer_with(&self, now: Instant) -> ScopedTimerGuard<'_> { + ScopedTimerGuard { + inner: self, + start: Some(now), + } + } } /// Stores a single timestamp, stored as the number of nanoseconds @@ -330,6 +340,20 @@ impl ScopedTimerGuard<'_> { pub fn done(mut self) { self.stop() } + + /// Stop the timer timing and record the time taken since the given endpoint. + pub fn stop_with(&mut self, end_time: Instant) { + if let Some(start) = self.start.take() { + let elapsed = end_time - start; + self.inner.add_duration(elapsed) + } + } + + /// Stop the timer, record the time taken since `end_time` endpoint, and + /// consume self. + pub fn done_with(mut self, end_time: Instant) { + self.stop_with(end_time) + } } impl Drop for ScopedTimerGuard<'_> { @@ -343,7 +367,7 @@ impl Drop for ScopedTimerGuard<'_> { /// Among other differences, the metric types have different ways to /// logically interpret their underlying values and some metrics are /// so common they are given special treatment. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub enum MetricValue { /// Number of output rows produced: "output_rows" metric OutputRows(Count), @@ -400,6 +424,78 @@ pub enum MetricValue { StartTimestamp(Timestamp), /// The time at which execution ended EndTimestamp(Timestamp), + Custom { + /// The provided name of this metric + name: Cow<'static, str>, + /// A custom implementation of the metric value. + value: Arc, + }, +} + +// Manually implement PartialEq for `MetricValue` because it contains CustomMetricValue in its +// definition which is a dyn trait. This wouldn't allow us to just derive PartialEq. +impl PartialEq for MetricValue { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (MetricValue::OutputRows(count), MetricValue::OutputRows(other)) => { + count == other + } + (MetricValue::ElapsedCompute(time), MetricValue::ElapsedCompute(other)) => { + time == other + } + (MetricValue::SpillCount(count), MetricValue::SpillCount(other)) => { + count == other + } + (MetricValue::SpilledBytes(count), MetricValue::SpilledBytes(other)) => { + count == other + } + (MetricValue::SpilledRows(count), MetricValue::SpilledRows(other)) => { + count == other + } + ( + MetricValue::CurrentMemoryUsage(gauge), + MetricValue::CurrentMemoryUsage(other), + ) => gauge == other, + ( + MetricValue::Count { name, count }, + MetricValue::Count { + name: other_name, + count: other_count, + }, + ) => name == other_name && count == other_count, + ( + MetricValue::Gauge { name, gauge }, + MetricValue::Gauge { + name: other_name, + gauge: other_gauge, + }, + ) => name == other_name && gauge == other_gauge, + ( + MetricValue::Time { name, time }, + MetricValue::Time { + name: other_name, + time: other_time, + }, + ) => name == other_name && time == other_time, + + ( + MetricValue::StartTimestamp(timestamp), + MetricValue::StartTimestamp(other), + ) => timestamp == other, + (MetricValue::EndTimestamp(timestamp), MetricValue::EndTimestamp(other)) => { + timestamp == other + } + ( + MetricValue::Custom { name, value }, + MetricValue::Custom { + name: other_name, + value: other_value, + }, + ) => name == other_name && value.is_eq(other_value), + // Default case when the two sides do not have the same type. + _ => false, + } + } } impl MetricValue { @@ -417,6 +513,7 @@ impl MetricValue { Self::Time { name, .. } => name.borrow(), Self::StartTimestamp(_) => "start_timestamp", Self::EndTimestamp(_) => "end_timestamp", + Self::Custom { name, .. } => name.borrow(), } } @@ -442,6 +539,7 @@ impl MetricValue { .and_then(|ts| ts.timestamp_nanos_opt()) .map(|nanos| nanos as usize) .unwrap_or(0), + Self::Custom { value, .. } => value.as_usize(), } } @@ -469,6 +567,10 @@ impl MetricValue { }, Self::StartTimestamp(_) => Self::StartTimestamp(Timestamp::new()), Self::EndTimestamp(_) => Self::EndTimestamp(Timestamp::new()), + Self::Custom { name, value } => Self::Custom { + name: name.clone(), + value: value.new_empty(), + }, } } @@ -515,6 +617,14 @@ impl MetricValue { (Self::EndTimestamp(timestamp), Self::EndTimestamp(other_timestamp)) => { timestamp.update_to_max(other_timestamp); } + ( + Self::Custom { value, .. }, + Self::Custom { + value: other_value, .. + }, + ) => { + value.aggregate(Arc::clone(other_value)); + } m @ (_, _) => { panic!( "Mismatched metric types. Can not aggregate {:?} with value {:?}", @@ -539,6 +649,7 @@ impl MetricValue { Self::Time { .. } => 8, Self::StartTimestamp(_) => 9, // show timestamps last Self::EndTimestamp(_) => 10, + Self::Custom { .. } => 11, } } @@ -554,11 +665,14 @@ impl Display for MetricValue { match self { Self::OutputRows(count) | Self::SpillCount(count) - | Self::SpilledBytes(count) | Self::SpilledRows(count) | Self::Count { count, .. } => { write!(f, "{count}") } + Self::SpilledBytes(count) => { + let readable_count = human_readable_size(count.value()); + write!(f, "{readable_count}") + } Self::CurrentMemoryUsage(gauge) | Self::Gauge { gauge, .. } => { write!(f, "{gauge}") } @@ -574,16 +688,103 @@ impl Display for MetricValue { Self::StartTimestamp(timestamp) | Self::EndTimestamp(timestamp) => { write!(f, "{timestamp}") } + Self::Custom { name, value } => { + write!(f, "name:{name} {value}") + } } } } #[cfg(test)] mod tests { + use std::any::Any; + use chrono::TimeZone; + use datafusion_execution::memory_pool::units::MB; use super::*; + #[derive(Debug, Default)] + pub struct CustomCounter { + count: AtomicUsize, + } + + impl Display for CustomCounter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "count: {}", self.count.load(Ordering::Relaxed)) + } + } + + impl CustomMetricValue for CustomCounter { + fn new_empty(&self) -> Arc { + Arc::new(CustomCounter::default()) + } + + fn aggregate(&self, other: Arc) { + let other = other.as_any().downcast_ref::().unwrap(); + self.count + .fetch_add(other.count.load(Ordering::Relaxed), Ordering::Relaxed); + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn is_eq(&self, other: &Arc) -> bool { + let Some(other) = other.as_any().downcast_ref::() else { + return false; + }; + + self.count.load(Ordering::Relaxed) == other.count.load(Ordering::Relaxed) + } + } + + fn new_custom_counter(name: &'static str, value: usize) -> MetricValue { + let custom_counter = CustomCounter::default(); + custom_counter.count.fetch_add(value, Ordering::Relaxed); + let custom_val = MetricValue::Custom { + name: Cow::Borrowed(name), + value: Arc::new(custom_counter), + }; + + custom_val + } + + #[test] + fn test_custom_metric_with_mismatching_names() { + let mut custom_val = new_custom_counter("Hi", 1); + let other_custom_val = new_custom_counter("Hello", 1); + + // Not equal since the name differs. + assert!(other_custom_val != custom_val); + + // Should work even though the name differs + custom_val.aggregate(&other_custom_val); + + let expected_val = new_custom_counter("Hi", 2); + assert!(expected_val == custom_val); + } + + #[test] + fn test_custom_metric() { + let mut custom_val = new_custom_counter("hi", 11); + let other_custom_val = new_custom_counter("hi", 20); + + custom_val.aggregate(&other_custom_val); + + assert!(custom_val != other_custom_val); + + if let MetricValue::Custom { value, .. } = custom_val { + let counter = value + .as_any() + .downcast_ref::() + .expect("Expected CustomCounter"); + assert_eq!(counter.count.load(Ordering::Relaxed), 31); + } else { + panic!("Unexpected value"); + } + } + #[test] fn test_display_output_rows() { let count = Count::new(); @@ -605,6 +806,20 @@ mod tests { } } + #[test] + fn test_display_spilled_bytes() { + let count = Count::new(); + let spilled_byte = MetricValue::SpilledBytes(count.clone()); + + assert_eq!("0.0 B", spilled_byte.to_string()); + + count.add((100 * MB) as usize); + assert_eq!("100.0 MB", spilled_byte.to_string()); + + count.add((0.5 * MB as f64) as usize); + assert_eq!("100.5 MB", spilled_byte.to_string()); + } + #[test] fn test_display_time() { let time = Time::new(); @@ -649,4 +864,99 @@ mod tests { ); } } + + #[test] + fn test_timer_with_custom_instant() { + let time = Time::new(); + let start_time = Instant::now(); + + // Sleep a bit to ensure some time passes + std::thread::sleep(Duration::from_millis(1)); + + // Create timer with the earlier start time + let mut timer = time.timer_with(start_time); + + // Sleep a bit more + std::thread::sleep(Duration::from_millis(1)); + + // Stop the timer + timer.stop(); + + // The recorded time should be at least 20ms (both sleeps) + assert!( + time.value() >= 2_000_000, + "Expected at least 2ms, got {} ns", + time.value() + ); + } + + #[test] + fn test_stop_with_custom_endpoint() { + let time = Time::new(); + let start = Instant::now(); + let mut timer = time.timer_with(start); + + // Simulate exactly 10ms passing + let end = start + Duration::from_millis(10); + + // Stop with custom endpoint + timer.stop_with(end); + + // Should record exactly 10ms (10_000_000 nanoseconds) + // Allow for small variations due to timer resolution + let recorded = time.value(); + assert!( + (10_000_000..=10_100_000).contains(&recorded), + "Expected ~10ms, got {recorded} ns" + ); + + // Calling stop_with again should not add more time + timer.stop_with(end); + assert_eq!( + recorded, + time.value(), + "Time should not change after second stop" + ); + } + + #[test] + fn test_done_with_custom_endpoint() { + let time = Time::new(); + let start = Instant::now(); + + // Create a new scope for the timer + { + let timer = time.timer_with(start); + + // Simulate 50ms passing + let end = start + Duration::from_millis(5); + + // Call done_with to stop and consume the timer + timer.done_with(end); + + // Timer is consumed, can't use it anymore + } + + // Should record exactly 5ms + let recorded = time.value(); + assert!( + (5_000_000..=5_100_000).contains(&recorded), + "Expected ~5ms, got {recorded} ns", + ); + + // Test that done_with prevents drop from recording time again + { + let timer2 = time.timer_with(start); + let end2 = start + Duration::from_millis(5); + timer2.done_with(end2); + // drop happens here but should not record additional time + } + + // Should have added only 5ms more + let new_recorded = time.value(); + assert!( + (10_000_000..=10_100_000).contains(&new_recorded), + "Expected ~10ms total, got {new_recorded} ns", + ); + } } diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index eecd980d09f8a..e7df79f867d70 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -20,12 +20,15 @@ use std::any::Any; use std::sync::Arc; -use crate::execution_plan::{Boundedness, EmissionType}; +use crate::coop::cooperative; +use crate::execution_plan::{Boundedness, EmissionType, SchedulingType}; use crate::memory::MemoryStream; -use crate::{common, DisplayAs, PlanProperties, SendableRecordBatchStream, Statistics}; -use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; -use arrow::array::{ArrayRef, NullArray}; -use arrow::array::{RecordBatch, RecordBatchOptions}; +use crate::{ + common, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, Statistics, +}; + +use arrow::array::{ArrayRef, NullArray, RecordBatch, RecordBatchOptions}; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; @@ -99,6 +102,7 @@ impl PlaceholderRowExec { EmissionType::Incremental, Boundedness::Bounded, ) + .with_scheduling_type(SchedulingType::Cooperative) } } @@ -158,19 +162,27 @@ impl ExecutionPlan for PlaceholderRowExec { ); } - Ok(Box::pin(MemoryStream::try_new( - self.data()?, - Arc::clone(&self.schema), - None, - )?)) + let ms = MemoryStream::try_new(self.data()?, Arc::clone(&self.schema), None)?; + Ok(Box::pin(cooperative(ms))) } fn statistics(&self) -> Result { - let batch = self + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + let batches = self .data() .expect("Create single row placeholder RecordBatch should not fail"); + + let batches = match partition { + Some(_) => vec![batches], + // entire plan + None => vec![batches; self.partitions], + }; + Ok(common::compute_record_batch_statistics( - &[batch], + &batches, &self.schema, None, )) diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 1d3e23ea90974..6eea70e1176d3 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -20,24 +20,28 @@ //! of a projection on table `t1` where the expressions `a`, `b`, and `a+b` are the //! projection expressions. `SELECT` without `FROM` will only evaluate expressions. -use std::any::Any; -use std::collections::HashMap; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use super::expressions::{CastExpr, Column, Literal}; +use super::expressions::{Column, Literal}; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; use crate::execution_plan::CardinalityEffect; -use crate::joins::utils::{ColumnIndex, JoinFilter}; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, +}; +use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn, JoinOnRef}; use crate::{ColumnStatistics, DisplayFormatType, ExecutionPlan, PhysicalExpr}; +use std::any::Any; +use std::collections::HashMap; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, @@ -46,18 +50,20 @@ use datafusion_common::{internal_err, JoinSide, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr::PhysicalExprRef; +use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; -use datafusion_physical_expr_common::physical_expr::fmt_sql; use futures::stream::{Stream, StreamExt}; -use itertools::Itertools; use log::trace; -/// Execution plan for a projection +/// [`ExecutionPlan`] for a projection +/// +/// Computes a set of scalar value expressions for each input row, producing one +/// output row for each input row. #[derive(Debug, Clone)] pub struct ProjectionExec { /// The projection expressions stored as tuples of (expression, output column name) - pub(crate) expr: Vec<(Arc, String)>, + pub(crate) expr: Vec, /// The schema once the projection has been applied to the input schema: SchemaRef, /// The input plan @@ -70,23 +76,74 @@ pub struct ProjectionExec { impl ProjectionExec { /// Create a projection on an input - pub fn try_new( - expr: Vec<(Arc, String)>, - input: Arc, - ) -> Result { + /// + /// # Example: + /// Create a `ProjectionExec` to crate `SELECT a, a+b AS sum_ab FROM t1`: + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_schema::{Schema, Field, DataType}; + /// # use datafusion_expr::Operator; + /// # use datafusion_physical_plan::ExecutionPlan; + /// # use datafusion_physical_expr::expressions::{col, binary}; + /// # use datafusion_physical_plan::empty::EmptyExec; + /// # use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; + /// # fn schema() -> Arc { + /// # Arc::new(Schema::new(vec![ + /// # Field::new("a", DataType::Int32, false), + /// # Field::new("b", DataType::Int32, false), + /// # ])) + /// # } + /// # + /// # fn input() -> Arc { + /// # Arc::new(EmptyExec::new(schema())) + /// # } + /// # + /// # fn main() { + /// let schema = schema(); + /// // Create PhysicalExprs + /// let a = col("a", &schema).unwrap(); + /// let b = col("b", &schema).unwrap(); + /// let a_plus_b = binary(Arc::clone(&a), Operator::Plus, b, &schema).unwrap(); + /// // create ProjectionExec + /// let proj = ProjectionExec::try_new([ + /// ProjectionExpr { + /// // expr a produces the column named "a" + /// expr: a, + /// alias: "a".to_string(), + /// }, + /// ProjectionExpr { + /// // expr: a + b produces the column named "sum_ab" + /// expr: a_plus_b, + /// alias: "sum_ab".to_string(), + /// } + /// ], input()).unwrap(); + /// # } + /// ``` + pub fn try_new(expr: I, input: Arc) -> Result + where + I: IntoIterator, + E: Into, + { let input_schema = input.schema(); + // convert argument to Vec + let expr = expr.into_iter().map(Into::into).collect::>(); let fields: Result> = expr .iter() - .map(|(e, name)| { - let mut field = Field::new( - name, - e.data_type(&input_schema)?, - e.nullable(&input_schema)?, - ); - field.set_metadata( - get_field_metadata(e, &input_schema).unwrap_or_default(), - ); + .map(|proj_expr| { + let metadata = proj_expr + .expr + .return_field(&input_schema)? + .metadata() + .clone(); + + let field = Field::new( + &proj_expr.alias, + proj_expr.expr.data_type(&input_schema)?, + proj_expr.expr.nullable(&input_schema)?, + ) + .with_metadata(metadata); Ok(field) }) @@ -98,7 +155,10 @@ impl ProjectionExec { )); // Construct a map from the input expressions to the output expression of the Projection - let projection_mapping = ProjectionMapping::try_new(&expr, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new( + expr.iter().map(|p| (Arc::clone(&p.expr), p.alias.clone())), + &input_schema, + )?; let cache = Self::compute_properties(&input, &projection_mapping, Arc::clone(&schema))?; Ok(Self { @@ -111,7 +171,7 @@ impl ProjectionExec { } /// The projection expressions stored as tuples of (expression, output column name) - pub fn expr(&self) -> &[(Arc, String)] { + pub fn expr(&self) -> &[ProjectionExpr] { &self.expr } @@ -127,14 +187,12 @@ impl ProjectionExec { schema: SchemaRef, ) -> Result { // Calculate equivalence properties: - let mut input_eq_properties = input.equivalence_properties().clone(); - input_eq_properties.substitute_oeq_class(projection_mapping)?; + let input_eq_properties = input.equivalence_properties(); let eq_properties = input_eq_properties.project(projection_mapping, schema); - // Calculate output partitioning, which needs to respect aliases: - let input_partition = input.output_partitioning(); - let output_partitioning = - input_partition.project(projection_mapping, &input_eq_properties); + let output_partitioning = input + .output_partitioning() + .project(projection_mapping, input_eq_properties); Ok(PlanProperties::new( eq_properties, @@ -145,6 +203,35 @@ impl ProjectionExec { } } +/// A projection expression that is created by [`ProjectionExec`] +/// +/// The expression is evaluated and the result is stored in a column +/// with the name specified by `alias`. +/// +/// For example, the SQL expression `a + b AS sum_ab` would be represented +/// as a `ProjectionExpr` where `expr` is the expression `a + b` +/// and `alias` is the string `sum_ab`. +#[derive(Debug, Clone)] +pub struct ProjectionExpr { + /// The expression that will be evaluated. + pub expr: Arc, + /// The name of the output column for use an output schema. + pub alias: String, +} + +impl ProjectionExpr { + /// Create a new projection expression + pub fn new(expr: Arc, alias: String) -> Self { + Self { expr, alias } + } +} + +impl From<(Arc, String)> for ProjectionExpr { + fn from(value: (Arc, String)) -> Self { + Self::new(value.0, value.1) + } +} + impl DisplayAs for ProjectionExec { fn fmt_as( &self, @@ -156,10 +243,10 @@ impl DisplayAs for ProjectionExec { let expr: Vec = self .expr .iter() - .map(|(e, alias)| { - let e = e.to_string(); - if &e != alias { - format!("{e} as {alias}") + .map(|proj_expr| { + let e = proj_expr.expr.to_string(); + if e != proj_expr.alias { + format!("{e} as {}", proj_expr.alias) } else { e } @@ -169,12 +256,12 @@ impl DisplayAs for ProjectionExec { write!(f, "ProjectionExec: expr=[{}]", expr.join(", ")) } DisplayFormatType::TreeRender => { - for (i, (e, alias)) in self.expr().iter().enumerate() { - let expr_sql = fmt_sql(e.as_ref()); - if &e.to_string() == alias { + for (i, proj_expr) in self.expr().iter().enumerate() { + let expr_sql = fmt_sql(proj_expr.expr.as_ref()); + if proj_expr.expr.to_string() == proj_expr.alias { writeln!(f, "expr{i}={expr_sql}")?; } else { - writeln!(f, "{alias}={expr_sql}")?; + writeln!(f, "{}={expr_sql}", proj_expr.alias)?; } } @@ -198,15 +285,25 @@ impl ExecutionPlan for ProjectionExec { &self.cache } - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - fn maintains_input_order(&self) -> Vec { // Tell optimizer this operator doesn't reorder its input vec![true] } + fn benefits_from_input_partitioning(&self) -> Vec { + let all_simple_exprs = self.expr.iter().all(|proj_expr| { + proj_expr.expr.as_any().is::() + || proj_expr.expr.as_any().is::() + }); + // If expressions are all either column_expr or Literal, then all computations in this projection are reorder or rename, + // and projection would not benefit from the repartition, benefits_from_input_partitioning will return false. + vec![!all_simple_exprs] + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + fn with_new_children( self: Arc, mut children: Vec>, @@ -215,16 +312,6 @@ impl ExecutionPlan for ProjectionExec { .map(|p| Arc::new(p) as _) } - fn benefits_from_input_partitioning(&self) -> Vec { - let all_simple_exprs = self - .expr - .iter() - .all(|(e, _)| e.as_any().is::() || e.as_any().is::()); - // If expressions are all either column_expr or Literal, then all computations in this projection are reorder or rename, - // and projection would not benefit from the repartition, benefits_from_input_partitioning will return false. - vec![!all_simple_exprs] - } - fn execute( &self, partition: usize, @@ -233,7 +320,7 @@ impl ExecutionPlan for ProjectionExec { trace!("Start ProjectionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); Ok(Box::pin(ProjectionStream { schema: Arc::clone(&self.schema), - expr: self.expr.iter().map(|x| Arc::clone(&x.0)).collect(), + expr: self.expr.iter().map(|x| Arc::clone(&x.expr)).collect(), input: self.input.execute(partition, context)?, baseline_metrics: BaselineMetrics::new(&self.metrics, partition), })) @@ -244,11 +331,18 @@ impl ExecutionPlan for ProjectionExec { } fn statistics(&self) -> Result { - Ok(stats_projection( - self.input.statistics()?, - self.expr.iter().map(|(e, _)| Arc::clone(e)), - Arc::clone(&self.schema), - )) + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + let input_stats = self.input.partition_statistics(partition)?; + stats_projection( + input_stats, + self.expr + .iter() + .map(|proj_expr| Arc::clone(&proj_expr.expr)), + Arc::clone(&self.input.schema()), + ) } fn supports_limit_pushdown(&self) -> bool { @@ -271,31 +365,34 @@ impl ExecutionPlan for ProjectionExec { Ok(Some(Arc::new(projection.clone()))) } } -} -/// If 'e' is a direct column reference, returns the field level -/// metadata for that field, if any. Otherwise returns None -pub(crate) fn get_field_metadata( - e: &Arc, - input_schema: &Schema, -) -> Option> { - if let Some(cast) = e.as_any().downcast_ref::() { - return get_field_metadata(cast.expr(), input_schema); + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + // TODO: In future, we can try to handle inverting aliases here. + // For the time being, we pass through untransformed filters, so filters on aliases are not handled. + // https://github.com/apache/datafusion/issues/17246 + FilterDescription::from_children(parent_filters, &self.children()) } - // Look up field by index in schema (not NAME as there can be more than one - // column with the same name) - e.as_any() - .downcast_ref::() - .map(|column| input_schema.field(column.index()).metadata()) - .cloned() + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::if_all(child_pushdown_result)) + } } fn stats_projection( mut stats: Statistics, exprs: impl Iterator>, schema: SchemaRef, -) -> Statistics { +) -> Result { let mut primitive_row_size = 0; let mut primitive_row_size_possible = true; let mut column_statistics = vec![]; @@ -308,11 +405,10 @@ fn stats_projection( ColumnStatistics::new_unknown() }; column_statistics.push(col_stats); - if let Ok(data_type) = expr.data_type(&schema) { - if let Some(value) = data_type.primitive_width() { - primitive_row_size += value; - continue; - } + let data_type = expr.data_type(&schema)?; + if let Some(value) = data_type.primitive_width() { + primitive_row_size += value; + continue; } primitive_row_size_possible = false; } @@ -322,7 +418,7 @@ fn stats_projection( Precision::Exact(primitive_row_size).multiply(&stats.num_rows); } stats.column_statistics = column_statistics; - stats + Ok(stats) } impl ProjectionStream { @@ -417,22 +513,25 @@ pub fn try_embed_projection( let embed_project_exprs = projection_index .iter() .zip(new_execution_plan.schema().fields()) - .map(|(index, field)| { - ( - Arc::new(Column::new(field.name(), *index)) as Arc, - field.name().to_owned(), - ) + .map(|(index, field)| ProjectionExpr { + expr: Arc::new(Column::new(field.name(), *index)) as Arc, + alias: field.name().to_owned(), }) .collect::>(); let mut new_projection_exprs = Vec::with_capacity(projection.expr().len()); - for (expr, alias) in projection.expr() { + for proj_expr in projection.expr() { // update column index for projection expression since the input schema has been changed. - let Some(expr) = update_expr(expr, embed_project_exprs.as_slice(), false)? else { + let Some(expr) = + update_expr(&proj_expr.expr, embed_project_exprs.as_slice(), false)? + else { return Ok(None); }; - new_projection_exprs.push((expr, alias.clone())); + new_projection_exprs.push(ProjectionExpr { + expr, + alias: proj_expr.alias.clone(), + }); } // Old projection may contain some alias or expression such as `a + 1` and `CAST('true' AS BOOLEAN)`, but our projection_exprs in hash join just contain column, so we need to create the new projection to keep the original projection. let new_projection = Arc::new(ProjectionExec::try_new( @@ -446,11 +545,6 @@ pub fn try_embed_projection( } } -/// The on clause of the join, as vector of (left, right) columns. -pub type JoinOn = Vec<(PhysicalExprRef, PhysicalExprRef)>; -/// Reference for JoinOn. -pub type JoinOnRef<'a> = &'a [(PhysicalExprRef, PhysicalExprRef)]; - pub struct JoinData { pub projected_left_child: ProjectionExec, pub projected_right_child: ProjectionExec, @@ -543,7 +637,7 @@ pub fn remove_unnecessary_projections( } else { return Ok(Transformed::no(plan)); }; - Ok(maybe_modified.map_or(Transformed::no(plan), Transformed::yes)) + Ok(maybe_modified.map_or_else(|| Transformed::no(plan), Transformed::yes)) } /// Compare the inputs and outputs of the projection. All expressions must be @@ -552,21 +646,23 @@ pub fn remove_unnecessary_projections( /// but `SELECT b, a` and `SELECT a+1, b` and `SELECT a AS c, b` are not. fn is_projection_removable(projection: &ProjectionExec) -> bool { let exprs = projection.expr(); - exprs.iter().enumerate().all(|(idx, (expr, alias))| { - let Some(col) = expr.as_any().downcast_ref::() else { + exprs.iter().enumerate().all(|(idx, proj_expr)| { + let Some(col) = proj_expr.expr.as_any().downcast_ref::() else { return false; }; - col.name() == alias && col.index() == idx + col.name() == proj_expr.alias && col.index() == idx }) && exprs.len() == projection.input().schema().fields().len() } /// Given the expression set of a projection, checks if the projection causes /// any renaming or constructs a non-`Column` physical expression. -pub fn all_alias_free_columns(exprs: &[(Arc, String)]) -> bool { - exprs.iter().all(|(expr, alias)| { - expr.as_any() +pub fn all_alias_free_columns(exprs: &[ProjectionExpr]) -> bool { + exprs.iter().all(|proj_expr| { + proj_expr + .expr + .as_any() .downcast_ref::() - .map(|column| column.name() == alias) + .map(|column| column.name() == proj_expr.alias) .unwrap_or(false) }) } @@ -575,14 +671,15 @@ pub fn all_alias_free_columns(exprs: &[(Arc, String)]) -> bool /// projection operator's expressions. To use this function safely, one must /// ensure that all expressions are `Column` expressions without aliases. pub fn new_projections_for_columns( - projection: &ProjectionExec, + projection: &[ProjectionExpr], source: &[usize], ) -> Vec { projection - .expr() .iter() - .filter_map(|(expr, _)| { - expr.as_any() + .filter_map(|proj_expr| { + proj_expr + .expr + .as_any() .downcast_ref::() .map(|expr| source[expr.index()]) }) @@ -600,8 +697,10 @@ pub fn make_with_child( } /// Returns `true` if all the expressions in the argument are `Column`s. -pub fn all_columns(exprs: &[(Arc, String)]) -> bool { - exprs.iter().all(|(expr, _)| expr.as_any().is::()) +pub fn all_columns(exprs: &[ProjectionExpr]) -> bool { + exprs + .iter() + .all(|proj_expr| proj_expr.expr.as_any().is::()) } /// The function operates in two modes: @@ -623,7 +722,7 @@ pub fn all_columns(exprs: &[(Arc, String)]) -> bool { /// `a@0`, but `b@2` results in `None` since the projection does not include `b`. pub fn update_expr( expr: &Arc, - projected_exprs: &[(Arc, String)], + projected_exprs: &[ProjectionExpr], sync_with_child: bool, ) -> Result>> { #[derive(Debug, PartialEq)] @@ -640,7 +739,7 @@ pub fn update_expr( let mut state = RewriteState::Unchanged; let new_expr = Arc::clone(expr) - .transform_up(|expr: Arc| { + .transform_up(|expr| { if state == RewriteState::RewrittenInvalid { return Ok(Transformed::no(expr)); } @@ -652,7 +751,7 @@ pub fn update_expr( state = RewriteState::RewrittenValid; // Update the index of `column`: Ok(Transformed::yes(Arc::clone( - &projected_exprs[column.index()].0, + &projected_exprs[column.index()].expr, ))) } else { // default to invalid, in case we can't find the relevant column @@ -661,14 +760,14 @@ pub fn update_expr( projected_exprs .iter() .enumerate() - .find_map(|(index, (projected_expr, alias))| { - projected_expr.as_any().downcast_ref::().and_then( + .find_map(|(index, proj_expr)| { + proj_expr.expr.as_any().downcast_ref::().and_then( |projected_column| { (column.name().eq(projected_column.name()) && column.index() == projected_column.index()) .then(|| { state = RewriteState::RewrittenValid; - Arc::new(Column::new(alias, index)) as _ + Arc::new(Column::new(&proj_expr.alias, index)) as _ }) }, ) @@ -684,17 +783,55 @@ pub fn update_expr( new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e)) } +/// Updates the given lexicographic ordering according to given projected +/// expressions using the [`update_expr`] function. +pub fn update_ordering( + ordering: LexOrdering, + projected_exprs: &[ProjectionExpr], +) -> Result> { + let mut updated_exprs = vec![]; + for mut sort_expr in ordering.into_iter() { + let Some(updated_expr) = update_expr(&sort_expr.expr, projected_exprs, false)? + else { + return Ok(None); + }; + sort_expr.expr = updated_expr; + updated_exprs.push(sort_expr); + } + Ok(LexOrdering::new(updated_exprs)) +} + +/// Updates the given lexicographic requirement according to given projected +/// expressions using the [`update_expr`] function. +pub fn update_ordering_requirement( + reqs: LexRequirement, + projected_exprs: &[ProjectionExpr], +) -> Result> { + let mut updated_exprs = vec![]; + for mut sort_expr in reqs.into_iter() { + let Some(updated_expr) = update_expr(&sort_expr.expr, projected_exprs, false)? + else { + return Ok(None); + }; + sort_expr.expr = updated_expr; + updated_exprs.push(sort_expr); + } + Ok(LexRequirement::new(updated_exprs)) +} + /// Downcasts all the expressions in `exprs` to `Column`s. If any of the given /// expressions is not a `Column`, returns `None`. pub fn physical_to_column_exprs( - exprs: &[(Arc, String)], + exprs: &[ProjectionExpr], ) -> Option> { exprs .iter() - .map(|(expr, alias)| { - expr.as_any() + .map(|proj_expr| { + proj_expr + .expr + .as_any() .downcast_ref::() - .map(|col| (col.clone(), alias.clone())) + .map(|col| (col.clone(), proj_expr.alias.clone())) }) .collect() } @@ -712,13 +849,10 @@ pub fn new_join_children( let new_left = ProjectionExec::try_new( projection_as_columns[0..=far_right_left_col_ind as _] .iter() - .map(|(col, alias)| { - ( - Arc::new(Column::new(col.name(), col.index())) as _, - alias.clone(), - ) - }) - .collect_vec(), + .map(|(col, alias)| ProjectionExpr { + expr: Arc::new(Column::new(col.name(), col.index())) as _, + alias: alias.clone(), + }), Arc::clone(left_child), )?; let left_size = left_child.schema().fields().len() as i32; @@ -726,17 +860,16 @@ pub fn new_join_children( projection_as_columns[far_left_right_col_ind as _..] .iter() .map(|(col, alias)| { - ( - Arc::new(Column::new( + ProjectionExpr { + expr: Arc::new(Column::new( col.name(), // Align projected expressions coming from the right // table with the new right child projection: (col.index() as i32 - left_size) as _, )) as _, - alias.clone(), - ) - }) - .collect_vec(), + alias: alias.clone(), + } + }), Arc::clone(right_child), )?; @@ -878,45 +1011,50 @@ fn try_unifying_projections( let mut column_ref_map: HashMap = HashMap::new(); // Collect the column references usage in the outer projection. - projection.expr().iter().for_each(|(expr, _)| { - expr.apply(|expr| { - Ok({ - if let Some(column) = expr.as_any().downcast_ref::() { - *column_ref_map.entry(column.clone()).or_default() += 1; - } - TreeNodeRecursion::Continue + projection.expr().iter().for_each(|proj_expr| { + proj_expr + .expr + .apply(|expr| { + Ok({ + if let Some(column) = expr.as_any().downcast_ref::() { + *column_ref_map.entry(column.clone()).or_default() += 1; + } + TreeNodeRecursion::Continue + }) }) - }) - .unwrap(); + .unwrap(); }); // Merging these projections is not beneficial, e.g // If an expression is not trivial and it is referred more than 1, unifies projections will be // beneficial as caching mechanism for non-trivial computations. // See discussion in: https://github.com/apache/datafusion/issues/8296 if column_ref_map.iter().any(|(column, count)| { - *count > 1 && !is_expr_trivial(&Arc::clone(&child.expr()[column.index()].0)) + *count > 1 && !is_expr_trivial(&Arc::clone(&child.expr()[column.index()].expr)) }) { return Ok(None); } - for (expr, alias) in projection.expr() { + for proj_expr in projection.expr() { // If there is no match in the input projection, we cannot unify these // projections. This case will arise if the projection expression contains // a `PhysicalExpr` variant `update_expr` doesn't support. - let Some(expr) = update_expr(expr, child.expr(), true)? else { + let Some(expr) = update_expr(&proj_expr.expr, child.expr(), true)? else { return Ok(None); }; - projected_exprs.push((expr, alias.clone())); + projected_exprs.push(ProjectionExpr { + expr, + alias: proj_expr.alias.clone(), + }); } ProjectionExec::try_new(projected_exprs, Arc::clone(child.input())) .map(|e| Some(Arc::new(e) as _)) } /// Collect all column indices from the given projection expressions. -fn collect_column_indices(exprs: &[(Arc, String)]) -> Vec { +fn collect_column_indices(exprs: &[ProjectionExpr]) -> Vec { // Collect indices and remove duplicates. let mut indices = exprs .iter() - .flat_map(|(expr, _)| collect_columns(expr)) + .flat_map(|proj_expr| collect_columns(&proj_expr.expr)) .map(|x| x.index()) .collect::>() .into_iter() @@ -1015,12 +1153,14 @@ mod tests { use crate::common::collect; use crate::test; + use crate::test::exec::StatisticsExec; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::stats::{ColumnStatistics, Precision, Statistics}; use datafusion_common::ScalarValue; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; + use datafusion_physical_expr::expressions::{col, BinaryExpr, Column, Literal}; #[test] fn test_collect_column_indices() -> Result<()> { @@ -1033,7 +1173,10 @@ mod tests { Arc::new(Column::new("a", 1)), )), )); - let column_indices = collect_column_indices(&[(expr, "b-(1+a)".to_string())]); + let column_indices = collect_column_indices(&[ProjectionExpr { + expr, + alias: "b-(1+a)".to_string(), + }]); assert_eq!(column_indices, vec![1, 7]); Ok(()) } @@ -1098,18 +1241,33 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let exec = test::scan_partitioned(1); - let expected = collect(exec.execute(0, Arc::clone(&task_ctx))?) - .await - .unwrap(); + let expected = collect(exec.execute(0, Arc::clone(&task_ctx))?).await?; - let projection = ProjectionExec::try_new(vec![], exec)?; + let projection = ProjectionExec::try_new(vec![] as Vec, exec)?; let stream = projection.execute(0, Arc::clone(&task_ctx))?; - let output = collect(stream).await.unwrap(); + let output = collect(stream).await?; assert_eq!(output.len(), expected.len()); Ok(()) } + #[tokio::test] + async fn project_old_syntax() { + let exec = test::scan_partitioned(1); + let schema = exec.schema(); + let expr = col("i", &schema).unwrap(); + ProjectionExec::try_new( + vec![ + // use From impl of ProjectionExpr to create ProjectionExpr + // to test old syntax + (expr, "c".to_string()), + ], + exec, + ) + // expect this to succeed + .unwrap(); + } + fn get_stats() -> Statistics { Statistics { num_rows: Precision::Exact(5), @@ -1156,7 +1314,8 @@ mod tests { Arc::new(Column::new("col0", 0)), ]; - let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)); + let result = + stats_projection(source, exprs.into_iter(), Arc::new(schema)).unwrap(); let expected = Statistics { num_rows: Precision::Exact(5), @@ -1192,7 +1351,8 @@ mod tests { Arc::new(Column::new("col0", 0)), ]; - let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)); + let result = + stats_projection(source, exprs.into_iter(), Arc::new(schema)).unwrap(); let expected = Statistics { num_rows: Precision::Exact(5), @@ -1217,4 +1377,86 @@ mod tests { assert_eq!(result, expected); } + + #[test] + fn test_projection_statistics_uses_input_schema() { + let input_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + Field::new("d", DataType::Int32, false), + Field::new("e", DataType::Int32, false), + Field::new("f", DataType::Int32, false), + ]); + + let input_statistics = Statistics { + num_rows: Precision::Exact(10), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Exact(ScalarValue::Int32(Some(5))), + max_value: Precision::Exact(ScalarValue::Int32(Some(50))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Exact(ScalarValue::Int32(Some(10))), + max_value: Precision::Exact(ScalarValue::Int32(Some(40))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Exact(ScalarValue::Int32(Some(20))), + max_value: Precision::Exact(ScalarValue::Int32(Some(30))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Exact(ScalarValue::Int32(Some(21))), + max_value: Precision::Exact(ScalarValue::Int32(Some(29))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Exact(ScalarValue::Int32(Some(24))), + max_value: Precision::Exact(ScalarValue::Int32(Some(26))), + ..Default::default() + }, + ], + ..Default::default() + }; + + let input = Arc::new(StatisticsExec::new(input_statistics, input_schema)); + + // Create projection expressions that reference columns from the input schema and the length + // of output schema columns < input schema columns and hence if we use the last few columns + // from the input schema in the expressions here, bounds_check would fail on them if output + // schema is supplied to the partitions_statistics method. + let exprs: Vec = vec![ + ProjectionExpr { + expr: Arc::new(Column::new("c", 2)) as Arc, + alias: "c_renamed".to_string(), + }, + ProjectionExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("e", 4)), + Operator::Plus, + Arc::new(Column::new("f", 5)), + )) as Arc, + alias: "e_plus_f".to_string(), + }, + ]; + + let projection = ProjectionExec::try_new(exprs, input).unwrap(); + + let stats = projection.partition_statistics(None).unwrap(); + + assert_eq!(stats.num_rows, Precision::Exact(10)); + assert_eq!( + stats.column_statistics.len(), + 2, + "Expected 2 columns in projection statistics" + ); + assert!(stats.total_byte_size.is_exact().unwrap_or(false)); + } } diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 7268735ea4576..b4cdf2dff2bfb 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -32,7 +32,7 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_common::{internal_datafusion_err, not_impl_err, Result}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; @@ -183,10 +183,9 @@ impl ExecutionPlan for RecursiveQueryExec { ) -> Result { // TODO: we might be able to handle multiple partitions in the future. if partition != 0 { - return Err(DataFusionError::Internal(format!( - "RecursiveQueryExec got an invalid partition {} (expected 0)", - partition - ))); + return Err(internal_datafusion_err!( + "RecursiveQueryExec got an invalid partition {partition} (expected 0)" + )); } let static_stream = self.static_term.execute(partition, Arc::clone(&context))?; @@ -352,16 +351,16 @@ fn assign_work_table( ) -> Result> { let mut work_table_refs = 0; plan.transform_down(|plan| { - if let Some(exec) = plan.as_any().downcast_ref::() { + if let Some(new_plan) = + plan.with_new_state(Arc::clone(&work_table) as Arc) + { if work_table_refs > 0 { not_impl_err!( "Multiple recursive references to the same CTE are not supported" ) } else { work_table_refs += 1; - Ok(Transformed::yes(Arc::new( - exec.with_work_table(Arc::clone(&work_table)), - ))) + Ok(Transformed::yes(new_plan)) } } else if plan.as_any().is::() { not_impl_err!("Recursive queries cannot be nested") @@ -373,7 +372,7 @@ fn assign_work_table( } /// Some plans will change their internal states after execution, making them unable to be executed again. -/// This function uses `ExecutionPlan::with_new_children` to fork a new plan with initial states. +/// This function uses [`ExecutionPlan::reset_state`] to reset any internal state within the plan. /// /// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan. /// However, if the data of the left table is derived from the work table, it will become outdated @@ -384,8 +383,7 @@ fn reset_plan_states(plan: Arc) -> Result() { Ok(Transformed::no(plan)) } else { - let new_plan = Arc::clone(&plan) - .with_new_children(plan.children().into_iter().cloned().collect())?; + let new_plan = Arc::clone(&plan).reset_state()?; Ok(Transformed::yes(new_plan)) } }) diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index ebc751201378b..a5bf68a63c387 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -19,6 +19,7 @@ //! partitions to M output partitions based on a partitioning scheme, optionally //! maintaining the order of the input rows in the output. +use std::fmt::{Debug, Formatter}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -29,7 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream, }; -use crate::execution_plan::CardinalityEffect; +use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType}; use crate::hash_utils::create_hashes; use crate::metrics::BaselineMetrics; use crate::projection::{all_columns, make_with_child, update_expr, ProjectionExec}; @@ -43,8 +44,10 @@ use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Stat use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions}; use arrow::compute::take_arrays; use arrow::datatypes::{SchemaRef, UInt32Type}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::stats::Precision; use datafusion_common::utils::transpose; -use datafusion_common::HashMap; +use datafusion_common::{internal_err, ColumnStatistics, HashMap}; use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::memory_pool::MemoryConsumer; @@ -52,6 +55,10 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, +}; use futures::stream::Stream; use futures::{FutureExt, StreamExt, TryStreamExt}; use log::trace; @@ -63,9 +70,8 @@ type MaybeBatch = Option>; type InputPartitionsToCurrentPartitionSender = Vec>; type InputPartitionsToCurrentPartitionReceiver = Vec>; -/// Inner state of [`RepartitionExec`]. #[derive(Debug)] -struct RepartitionExecState { +struct ConsumingInputStreamsState { /// Channels for sending batches from input partitions to output partitions. /// Key is the partition number. channels: HashMap< @@ -81,16 +87,97 @@ struct RepartitionExecState { abort_helper: Arc>>, } +/// Inner state of [`RepartitionExec`]. +enum RepartitionExecState { + /// Not initialized yet. This is the default state stored in the RepartitionExec node + /// upon instantiation. + NotInitialized, + /// Input streams are initialized, but they are still not being consumed. The node + /// transitions to this state when the arrow's RecordBatch stream is created in + /// RepartitionExec::execute(), but before any message is polled. + InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>), + /// The input streams are being consumed. The node transitions to this state when + /// the first message in the arrow's RecordBatch stream is consumed. + ConsumingInputStreams(ConsumingInputStreamsState), +} + +impl Default for RepartitionExecState { + fn default() -> Self { + Self::NotInitialized + } +} + +impl Debug for RepartitionExecState { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + RepartitionExecState::NotInitialized => write!(f, "NotInitialized"), + RepartitionExecState::InputStreamsInitialized(v) => { + write!(f, "InputStreamsInitialized({:?})", v.len()) + } + RepartitionExecState::ConsumingInputStreams(v) => { + write!(f, "ConsumingInputStreams({v:?})") + } + } + } +} + impl RepartitionExecState { - fn new( + fn ensure_input_streams_initialized( + &mut self, + input: Arc, + metrics: ExecutionPlanMetricsSet, + output_partitions: usize, + ctx: Arc, + ) -> Result<()> { + if !matches!(self, RepartitionExecState::NotInitialized) { + return Ok(()); + } + + let num_input_partitions = input.output_partitioning().partition_count(); + let mut streams_and_metrics = Vec::with_capacity(num_input_partitions); + + for i in 0..num_input_partitions { + let metrics = RepartitionMetrics::new(i, output_partitions, &metrics); + + let timer = metrics.fetch_time.timer(); + let stream = input.execute(i, Arc::clone(&ctx))?; + timer.done(); + + streams_and_metrics.push((stream, metrics)); + } + *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics); + Ok(()) + } + + fn consume_input_streams( + &mut self, input: Arc, - partitioning: Partitioning, metrics: ExecutionPlanMetricsSet, + partitioning: Partitioning, preserve_order: bool, name: String, context: Arc, - ) -> Self { - let num_input_partitions = input.output_partitioning().partition_count(); + ) -> Result<&mut ConsumingInputStreamsState> { + let streams_and_metrics = match self { + RepartitionExecState::NotInitialized => { + self.ensure_input_streams_initialized( + input, + metrics, + partitioning.partition_count(), + Arc::clone(&context), + )?; + let RepartitionExecState::InputStreamsInitialized(value) = self else { + // This cannot happen, as ensure_input_streams_initialized() was just called, + // but the compiler does not know. + return internal_err!("Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized"); + }; + value + } + RepartitionExecState::ConsumingInputStreams(value) => return Ok(value), + RepartitionExecState::InputStreamsInitialized(value) => value, + }; + + let num_input_partitions = streams_and_metrics.len(); let num_output_partitions = partitioning.partition_count(); let (txs, rxs) = if preserve_order { @@ -117,7 +204,7 @@ impl RepartitionExecState { let mut channels = HashMap::with_capacity(txs.len()); for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() { let reservation = Arc::new(Mutex::new( - MemoryConsumer::new(format!("{}[{partition}]", name)) + MemoryConsumer::new(format!("{name}[{partition}]")) .register(context.memory_pool()), )); channels.insert(partition, (tx, rx, reservation)); @@ -125,7 +212,9 @@ impl RepartitionExecState { // launch one async task per *input* partition let mut spawned_tasks = Vec::with_capacity(num_input_partitions); - for i in 0..num_input_partitions { + for (i, (stream, metrics)) in + std::mem::take(streams_and_metrics).into_iter().enumerate() + { let txs: HashMap<_, _> = channels .iter() .map(|(partition, (tx, _rx, reservation))| { @@ -133,15 +222,11 @@ impl RepartitionExecState { }) .collect(); - let r_metrics = RepartitionMetrics::new(i, num_output_partitions, &metrics); - let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input( - Arc::clone(&input), - i, + stream, txs.clone(), partitioning.clone(), - r_metrics, - Arc::clone(&context), + metrics, )); // In a separate task, wait for each input to be done @@ -154,28 +239,17 @@ impl RepartitionExecState { )); spawned_tasks.push(wait_for_task); } - - Self { + *self = Self::ConsumingInputStreams(ConsumingInputStreamsState { channels, abort_helper: Arc::new(spawned_tasks), + }); + match self { + RepartitionExecState::ConsumingInputStreams(value) => Ok(value), + _ => unreachable!(), } } } -/// Lazily initialized state -/// -/// Note that the state is initialized ONCE for all partitions by a single task(thread). -/// This may take a short while. It is also like that multiple threads -/// call execute at the same time, because we have just started "target partitions" tasks -/// which is commonly set to the number of CPU cores and all call execute at the same time. -/// -/// Thus, use a **tokio** `OnceCell` for this initialization so as not to waste CPU cycles -/// in a mutex lock but instead allow other threads to do something useful. -/// -/// Uses a parking_lot `Mutex` to control other accesses as they are very short duration -/// (e.g. removing channels on completion) where the overhead of `await` is not warranted. -type LazyState = Arc>>; - /// A utility that can be used to partition batches based on [`Partitioning`] pub struct BatchPartitioner { state: BatchPartitionerState, @@ -402,8 +476,9 @@ impl BatchPartitioner { pub struct RepartitionExec { /// Input execution plan input: Arc, - /// Inner state that is initialized when the first output stream is created. - state: LazyState, + /// Inner state that is initialized when the parent calls .execute() on this node + /// and consumed as soon as the parent starts consuming this node. + state: Arc>, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Boolean flag to decide whether to preserve ordering. If true means @@ -482,11 +557,7 @@ impl RepartitionExec { } impl DisplayAs for RepartitionExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!( @@ -508,11 +579,17 @@ impl DisplayAs for RepartitionExec { } DisplayFormatType::TreeRender => { writeln!(f, "partitioning_scheme={}", self.partitioning(),)?; + + let input_partition_count = + self.input.output_partitioning().partition_count(); + let output_partition_count = self.partitioning().partition_count(); + let input_to_output_partition_str = + format!("{input_partition_count} -> {output_partition_count}"); writeln!( f, - "output_partition_count={}", - self.input.output_partitioning().partition_count() + "partition_count(in->out)={input_to_output_partition_str}" )?; + if self.preserve_order { writeln!(f, "preserve_order={}", self.preserve_order)?; } @@ -573,42 +650,42 @@ impl ExecutionPlan for RepartitionExec { partition ); - let lazy_state = Arc::clone(&self.state); let input = Arc::clone(&self.input); let partitioning = self.partitioning().clone(); let metrics = self.metrics.clone(); - let preserve_order = self.preserve_order; + let preserve_order = self.sort_exprs().is_some(); let name = self.name().to_owned(); let schema = self.schema(); let schema_captured = Arc::clone(&schema); // Get existing ordering to use for merging - let sort_exprs = self.sort_exprs().cloned().unwrap_or_default(); + let sort_exprs = self.sort_exprs().cloned(); + + let state = Arc::clone(&self.state); + if let Some(mut state) = state.try_lock() { + state.ensure_input_streams_initialized( + Arc::clone(&input), + metrics.clone(), + partitioning.partition_count(), + Arc::clone(&context), + )?; + } let stream = futures::stream::once(async move { let num_input_partitions = input.output_partitioning().partition_count(); - let input_captured = Arc::clone(&input); - let metrics_captured = metrics.clone(); - let name_captured = name.clone(); - let context_captured = Arc::clone(&context); - let state = lazy_state - .get_or_init(|| async move { - Mutex::new(RepartitionExecState::new( - input_captured, - partitioning, - metrics_captured, - preserve_order, - name_captured, - context_captured, - )) - }) - .await; - // lock scope let (mut rx, reservation, abort_helper) = { // lock mutexes let mut state = state.lock(); + let state = state.consume_input_streams( + Arc::clone(&input), + metrics.clone(), + partitioning, + preserve_order, + name.clone(), + Arc::clone(&context), + )?; // now return stream for the specified *output* partition which will // read from the channel @@ -621,9 +698,7 @@ impl ExecutionPlan for RepartitionExec { }; trace!( - "Before returning stream in {}::execute for partition: {}", - name, - partition + "Before returning stream in {name}::execute for partition: {partition}" ); if preserve_order { @@ -645,12 +720,12 @@ impl ExecutionPlan for RepartitionExec { // input partitions to this partition: let fetch = None; let merge_reservation = - MemoryConsumer::new(format!("{}[Merge {partition}]", name)) + MemoryConsumer::new(format!("{name}[Merge {partition}]")) .register(context.memory_pool()); StreamingMergeBuilder::new() .with_streams(input_streams) .with_schema(schema_captured) - .with_expressions(&sort_exprs) + .with_expressions(&sort_exprs.unwrap()) .with_metrics(BaselineMetrics::new(&metrics, partition)) .with_batch_size(context.session_config().batch_size()) .with_fetch(fetch) @@ -677,7 +752,49 @@ impl ExecutionPlan for RepartitionExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if let Some(partition) = partition { + let partition_count = self.partitioning().partition_count(); + if partition_count == 0 { + return Ok(Statistics::new_unknown(&self.schema())); + } + + if partition >= partition_count { + return internal_err!( + "RepartitionExec invalid partition {} (expected less than {})", + partition, + self.partitioning().partition_count() + ); + } + + let mut stats = self.input.partition_statistics(None)?; + + // Distribute statistics across partitions + stats.num_rows = stats + .num_rows + .get_value() + .map(|rows| Precision::Inexact(rows / partition_count)) + .unwrap_or(Precision::Absent); + stats.total_byte_size = stats + .total_byte_size + .get_value() + .map(|bytes| Precision::Inexact(bytes / partition_count)) + .unwrap_or(Precision::Absent); + + // Make all column stats unknown + stats.column_statistics = stats + .column_statistics + .iter() + .map(|_| ColumnStatistics::new_unknown()) + .collect(); + + Ok(stats) + } else { + self.input.partition_statistics(None) + } } fn cardinality_effect(&self) -> CardinalityEffect { @@ -723,6 +840,45 @@ impl ExecutionPlan for RepartitionExec { new_partitioning, )?))) } + + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + FilterDescription::from_children(parent_filters, &self.children()) + } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::if_all(child_pushdown_result)) + } + + fn repartitioned( + &self, + target_partitions: usize, + _config: &ConfigOptions, + ) -> Result>> { + use Partitioning::*; + let mut new_properties = self.cache.clone(); + new_properties.partitioning = match new_properties.partitioning { + RoundRobinBatch(_) => RoundRobinBatch(target_partitions), + Hash(hash, _) => Hash(hash, target_partitions), + UnknownPartitioning(_) => UnknownPartitioning(target_partitions), + }; + Ok(Some(Arc::new(Self { + input: Arc::clone(&self.input), + state: Arc::clone(&self.state), + metrics: self.metrics.clone(), + preserve_order: self.preserve_order, + cache: new_properties, + }))) + } } impl RepartitionExec { @@ -783,6 +939,8 @@ impl RepartitionExec { input.pipeline_behavior(), input.boundedness(), ) + .with_scheduling_type(SchedulingType::Cooperative) + .with_evaluation_type(EvaluationType::Eager) } /// Specify if this repartitioning operation should preserve the order of @@ -818,24 +976,17 @@ impl RepartitionExec { /// /// txs hold the output sending channels for each output partition async fn pull_from_input( - input: Arc, - partition: usize, + mut stream: SendableRecordBatchStream, mut output_channels: HashMap< usize, (DistributionSender, SharedMemoryReservation), >, partitioning: Partitioning, metrics: RepartitionMetrics, - context: Arc, ) -> Result<()> { let mut partitioner = BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?; - // execute the child operator - let timer = metrics.fetch_time.timer(); - let mut stream = input.execute(partition, context)?; - timer.done(); - // While there are still outputs to send to, keep pulling inputs let mut batches_until_yield = partitioner.num_partitions(); while !output_channels.is_empty() { @@ -1083,6 +1234,7 @@ mod tests { use datafusion_common_runtime::JoinSet; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use insta::assert_snapshot; + use itertools::Itertools; #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { @@ -1263,15 +1415,9 @@ mod tests { let partitioning = Partitioning::RoundRobinBatch(1); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); - // Note: this should pass (the stream can be created) but the - // error when the input is executed should get passed back - let output_stream = exec.execute(0, task_ctx).unwrap(); - // Expect that an error is returned - let result_string = crate::common::collect(output_stream) - .await - .unwrap_err() - .to_string(); + let result_string = exec.execute(0, task_ctx).err().unwrap().to_string(); + assert!( result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"), "actual: {result_string}" @@ -1461,7 +1607,14 @@ mod tests { }); let batches_with_drop = crate::common::collect(output_stream1).await.unwrap(); - assert_eq!(batches_without_drop, batches_with_drop); + fn sort(batch: Vec) -> Vec { + batch + .into_iter() + .sorted_by_key(|b| format!("{b:?}")) + .collect() + } + + assert_eq!(sort(batches_without_drop), sort(batches_with_drop)); } fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> { @@ -1630,8 +1783,7 @@ mod test { /// macro_rules! assert_plan { ($EXPECTED_PLAN_LINES: expr, $PLAN: expr) => { - let physical_plan = $PLAN; - let formatted = crate::displayable(&physical_plan).indent(true).to_string(); + let formatted = crate::displayable($PLAN).indent(true).to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES @@ -1651,11 +1803,9 @@ mod test { let source1 = sorted_memory_exec(&schema, sort_exprs.clone()); let source2 = sorted_memory_exec(&schema, sort_exprs); // output has multiple partitions, and is sorted - let union = UnionExec::new(vec![source1, source2]); - let exec = - RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10)) - .unwrap() - .with_preserve_order(); + let union = UnionExec::try_new(vec![source1, source2])?; + let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))? + .with_preserve_order(); // Repartition should preserve order let expected_plan = [ @@ -1664,7 +1814,7 @@ mod test { " DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", " DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", ]; - assert_plan!(expected_plan, exec); + assert_plan!(expected_plan, &exec); Ok(()) } @@ -1683,7 +1833,7 @@ mod test { "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", ]; - assert_plan!(expected_plan, exec); + assert_plan!(expected_plan, &exec); Ok(()) } @@ -1693,11 +1843,9 @@ mod test { let source1 = memory_exec(&schema); let source2 = memory_exec(&schema); // output has multiple partitions, but is not sorted - let union = UnionExec::new(vec![source1, source2]); - let exec = - RepartitionExec::try_new(Arc::new(union), Partitioning::RoundRobinBatch(10)) - .unwrap() - .with_preserve_order(); + let union = UnionExec::try_new(vec![source1, source2])?; + let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))? + .with_preserve_order(); // Repartition should not preserve order, as there is no order to preserve let expected_plan = [ @@ -1706,7 +1854,26 @@ mod test { " DataSourceExec: partitions=1, partition_sizes=[0]", " DataSourceExec: partitions=1, partition_sizes=[0]", ]; - assert_plan!(expected_plan, exec); + assert_plan!(expected_plan, &exec); + Ok(()) + } + + #[tokio::test] + async fn test_repartition() -> Result<()> { + let schema = test_schema(); + let sort_exprs = sort_exprs(&schema); + let source = sorted_memory_exec(&schema, sort_exprs); + // output is sorted, but has only a single partition, so no need to sort + let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))? + .repartitioned(20, &Default::default())? + .unwrap(); + + // Repartition should not preserve order + let expected_plan = [ + "RepartitionExec: partitioning=RoundRobinBatch(20), input_partitions=1", + " DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC", + ]; + assert_plan!(expected_plan, exec.as_ref()); Ok(()) } @@ -1715,11 +1882,11 @@ mod test { } fn sort_exprs(schema: &Schema) -> LexOrdering { - let options = SortOptions::default(); - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("c0", schema).unwrap(), - options, - }]) + options: SortOptions::default(), + }] + .into() } fn memory_exec(schema: &SchemaRef) -> Arc { diff --git a/datafusion/physical-plan/src/sorts/cursor.rs b/datafusion/physical-plan/src/sorts/cursor.rs index efb9c0a47bf58..54dc2414e4f08 100644 --- a/datafusion/physical-plan/src/sorts/cursor.rs +++ b/datafusion/physical-plan/src/sorts/cursor.rs @@ -16,6 +16,7 @@ // under the License. use std::cmp::Ordering; +use std::sync::Arc; use arrow::array::{ types::ByteArrayType, Array, ArrowPrimitiveType, GenericByteArray, @@ -151,7 +152,7 @@ impl Ord for Cursor { /// Used for sorting when there are multiple columns in the sort key #[derive(Debug)] pub struct RowValues { - rows: Rows, + rows: Arc, /// Tracks for the memory used by in the `Rows` of this /// cursor. Freed on drop @@ -164,7 +165,7 @@ impl RowValues { /// /// Panics if the reservation is not for exactly `rows.size()` /// bytes or if `rows` is empty. - pub fn new(rows: Rows, reservation: MemoryReservation) -> Self { + pub fn new(rows: Arc, reservation: MemoryReservation) -> Self { assert_eq!( rows.size(), reservation.size(), @@ -293,14 +294,19 @@ impl CursorValues for StringViewArray { self.views().len() } + #[inline(always)] fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool { // SAFETY: Both l_idx and r_idx are guaranteed to be within bounds, // and any null-checks are handled in the outer layers. // Fast path: Compare the lengths before full byte comparison. - let l_view = unsafe { l.views().get_unchecked(l_idx) }; - let l_len = *l_view as u32; let r_view = unsafe { r.views().get_unchecked(r_idx) }; + + if l.data_buffers().is_empty() && r.data_buffers().is_empty() { + return l_view == r_view; + } + + let l_len = *l_view as u32; let r_len = *r_view as u32; if l_len != r_len { return false; @@ -309,14 +315,20 @@ impl CursorValues for StringViewArray { unsafe { GenericByteViewArray::compare_unchecked(l, l_idx, r, r_idx).is_eq() } } + #[inline(always)] fn eq_to_previous(cursor: &Self, idx: usize) -> bool { // SAFETY: The caller guarantees that idx > 0 and the indices are valid. // Already checked it in is_eq_to_prev_one function // Fast path: Compare the lengths of the current and previous views. let l_view = unsafe { cursor.views().get_unchecked(idx) }; - let l_len = *l_view as u32; let r_view = unsafe { cursor.views().get_unchecked(idx - 1) }; + if cursor.data_buffers().is_empty() { + return l_view == r_view; + } + + let l_len = *l_view as u32; let r_len = *r_view as u32; + if l_len != r_len { return false; } @@ -326,10 +338,18 @@ impl CursorValues for StringViewArray { } } + #[inline(always)] fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { // SAFETY: Prior assertions guarantee that l_idx and r_idx are valid indices. // Null-checks are assumed to have been handled in the wrapper (e.g., ArrayValues). // And the bound is checked in is_finished, it is safe to call get_unchecked + if l.data_buffers().is_empty() && r.data_buffers().is_empty() { + let l_view = unsafe { l.views().get_unchecked(l_idx) }; + let r_view = unsafe { r.views().get_unchecked(r_idx) }; + return StringViewArray::inline_key_fast(*l_view) + .cmp(&StringViewArray::inline_key_fast(*r_view)); + } + unsafe { GenericByteViewArray::compare_unchecked(l, l_idx, r, r_idx) } } } @@ -422,11 +442,10 @@ impl CursorValues for ArrayValues { #[cfg(test)] mod tests { - use std::sync::Arc; - use datafusion_execution::memory_pool::{ GreedyMemoryPool, MemoryConsumer, MemoryPool, }; + use std::sync::Arc; use super::*; diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs index 1c2b8cd0c91b7..0b0136cd12ced 100644 --- a/datafusion/physical-plan/src/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -18,7 +18,6 @@ //! Merge that deals with an arbitrary size of streaming inputs. //! This is an order-preserving merge. -use std::collections::VecDeque; use std::pin::Pin; use std::sync::Arc; use std::task::{ready, Context, Poll}; @@ -50,8 +49,9 @@ pub(crate) struct SortPreservingMergeStream { /// used to record execution metrics metrics: BaselineMetrics, - /// If the stream has encountered an error - aborted: bool, + /// If the stream has encountered an error or reaches the + /// `fetch` limit. + done: bool, /// A loser tree that always produces the minimum cursor /// @@ -143,11 +143,8 @@ pub(crate) struct SortPreservingMergeStream { /// number of rows produced produced: usize, - /// This queue contains partition indices in order. When a partition is polled and returns `Poll::Ready`, - /// it is removed from the vector. If a partition returns `Poll::Pending`, it is moved to the end of the - /// vector to ensure the next iteration starts with a different partition, preventing the same partition - /// from being continuously polled. - uninitiated_partitions: VecDeque, + /// This vector contains the indices of the partitions that have not started emitting yet. + uninitiated_partitions: Vec, } impl SortPreservingMergeStream { @@ -166,7 +163,7 @@ impl SortPreservingMergeStream { in_progress: BatchBuilder::new(schema, stream_count, batch_size, reservation), streams, metrics, - aborted: false, + done: false, cursors: (0..stream_count).map(|_| None).collect(), prev_cursors: (0..stream_count).map(|_| None).collect(), round_robin_tie_breaker_mode: false, @@ -210,42 +207,56 @@ impl SortPreservingMergeStream { &mut self, cx: &mut Context<'_>, ) -> Poll>> { - if self.aborted { + if self.done { return Poll::Ready(None); } // Once all partitions have set their corresponding cursors for the loser tree, // we skip the following block. Until then, this function may be called multiple // times and can return Poll::Pending if any partition returns Poll::Pending. + if self.loser_tree.is_empty() { - let remaining_partitions = self.uninitiated_partitions.clone(); - for i in remaining_partitions { - match self.maybe_poll_stream(cx, i) { + // Manual indexing since we're iterating over the vector and shrinking it in the loop + let mut idx = 0; + while idx < self.uninitiated_partitions.len() { + let partition_idx = self.uninitiated_partitions[idx]; + match self.maybe_poll_stream(cx, partition_idx) { Poll::Ready(Err(e)) => { - self.aborted = true; + self.done = true; return Poll::Ready(Some(Err(e))); } Poll::Pending => { - // If a partition returns Poll::Pending, to avoid continuously polling it - // and potentially increasing upstream buffer sizes, we move it to the - // back of the polling queue. - if let Some(front) = self.uninitiated_partitions.pop_front() { - // This pop_front can never return `None`. - self.uninitiated_partitions.push_back(front); - } - // This function could remain in a pending state, so we manually wake it here. - // However, this approach can be investigated further to find a more natural way - // to avoid disrupting the runtime scheduler. - cx.waker().wake_by_ref(); - return Poll::Pending; + // The polled stream is pending which means we're already set up to + // be woken when necessary + // Try the next stream + idx += 1; } _ => { - // If the polling result is Poll::Ready(Some(batch)) or Poll::Ready(None), - // we remove this partition from the queue so it is not polled again. - self.uninitiated_partitions.retain(|idx| *idx != i); + // The polled stream is ready + // Remove it from uninitiated_partitions + // Don't bump idx here, since a new element will have taken its + // place which we'll try in the next loop iteration + // swap_remove will change the partition poll order, but that shouldn't + // make a difference since we're waiting for all streams to be ready. + self.uninitiated_partitions.swap_remove(idx); } } } - self.init_loser_tree(); + + if self.uninitiated_partitions.is_empty() { + // If there are no more uninitiated partitions, set up the loser tree and continue + // to the next phase. + + // Claim the memory for the uninitiated partitions + self.uninitiated_partitions.shrink_to_fit(); + self.init_loser_tree(); + } else { + // There are still uninitiated partitions so return pending. + // We only get here if we've polled all uninitiated streams and at least one of them + // returned pending itself. That means we will be woken as soon as one of the + // streams would like to be polled again. + // There is no need to reschedule ourselves eagerly. + return Poll::Pending; + } } // NB timer records time taken on drop, so there are no @@ -258,7 +269,7 @@ impl SortPreservingMergeStream { if !self.loser_tree_adjusted { let winner = self.loser_tree[0]; if let Err(e) = ready!(self.maybe_poll_stream(cx, winner)) { - self.aborted = true; + self.done = true; return Poll::Ready(Some(Err(e))); } self.update_loser_tree(); @@ -271,7 +282,7 @@ impl SortPreservingMergeStream { // stop sorting if fetch has been reached if self.fetch_reached() { - self.aborted = true; + self.done = true; } else if self.in_progress.len() < self.batch_size { continue; } @@ -483,13 +494,12 @@ impl SortPreservingMergeStream { if self.enable_round_robin_tie_breaker && cmp_node == 1 { match (&self.cursors[winner], &self.cursors[challenger]) { (Some(ac), Some(bc)) => { - let ord = ac.cmp(bc); - if ord.is_eq() { + if ac == bc { self.handle_tie(cmp_node, &mut winner, challenger); } else { // Ends of tie breaker self.round_robin_tie_breaker_mode = false; - if ord.is_gt() { + if ac > bc { self.update_winner(cmp_node, &mut winner, challenger); } } diff --git a/datafusion/physical-plan/src/sorts/mod.rs b/datafusion/physical-plan/src/sorts/mod.rs index c7ffae4061c0e..9c72e34fe343e 100644 --- a/datafusion/physical-plan/src/sorts/mod.rs +++ b/datafusion/physical-plan/src/sorts/mod.rs @@ -20,6 +20,7 @@ mod builder; mod cursor; mod merge; +mod multi_level_merge; pub mod partial_sort; pub mod sort; pub mod sort_preserving_merge; diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs new file mode 100644 index 0000000000000..58d046cc90911 --- /dev/null +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -0,0 +1,452 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Create a stream that do a multi level merge stream + +use crate::metrics::BaselineMetrics; +use crate::{EmptyRecordBatchStream, SpillManager}; +use arrow::array::RecordBatch; +use std::fmt::{Debug, Formatter}; +use std::mem; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::datatypes::SchemaRef; +use datafusion_common::Result; +use datafusion_execution::memory_pool::MemoryReservation; + +use crate::sorts::sort::get_reserved_byte_for_record_batch_size; +use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; +use crate::stream::RecordBatchStreamAdapter; +use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use futures::TryStreamExt; +use futures::{Stream, StreamExt}; + +/// Merges a stream of sorted cursors and record batches into a single sorted stream +/// +/// This is a wrapper around [`SortPreservingMergeStream`](crate::sorts::merge::SortPreservingMergeStream) +/// that provide it the sorted streams/files to merge while making sure we can merge them in memory. +/// In case we can't merge all of them in a single pass we will spill the intermediate results to disk +/// and repeat the process. +/// +/// ## High level Algorithm +/// 1. Get the maximum amount of sorted in-memory streams and spill files we can merge with the available memory +/// 2. Sort them to a sorted stream +/// 3. Do we have more spill files to merge? +/// - Yes: write that sorted stream to a spill file, +/// add that spill file back to the spill files to merge and +/// repeat the process +/// +/// - No: return that sorted stream as the final output stream +/// +/// ```text +/// Initial State: Multiple sorted streams + spill files +/// ┌───────────┐ +/// │ Phase 1 │ +/// └───────────┘ +/// ┌──Can hold in memory─┐ +/// │ ┌──────────────┐ │ +/// │ │ In-memory │ +/// │ │sorted stream │──┼────────┐ +/// │ │ 1 │ │ │ +/// └──────────────┘ │ │ +/// │ ┌──────────────┐ │ │ +/// │ │ In-memory │ │ +/// │ │sorted stream │──┼────────┤ +/// │ │ 2 │ │ │ +/// └──────────────┘ │ │ +/// │ ┌──────────────┐ │ │ +/// │ │ In-memory │ │ +/// │ │sorted stream │──┼────────┤ +/// │ │ 3 │ │ │ +/// └──────────────┘ │ │ +/// │ ┌──────────────┐ │ │ ┌───────────┐ +/// │ │ Sorted Spill │ │ │ Phase 2 │ +/// │ │ file 1 │──┼────────┤ └───────────┘ +/// │ └──────────────┘ │ │ +/// ──── ──── ──── ──── ─┘ │ ┌──Can hold in memory─┐ +/// │ │ │ +/// ┌──────────────┐ │ │ ┌──────────────┐ +/// │ Sorted Spill │ │ │ │ Sorted Spill │ │ +/// │ file 2 │──────────────────────▶│ file 2 │──┼─────┐ +/// └──────────────┘ │ └──────────────┘ │ │ +/// ┌──────────────┐ │ │ ┌──────────────┐ │ │ +/// │ Sorted Spill │ │ │ │ Sorted Spill │ │ +/// │ file 3 │──────────────────────▶│ file 3 │──┼─────┤ +/// └──────────────┘ │ │ └──────────────┘ │ │ +/// ┌──────────────┐ │ ┌──────────────┐ │ │ +/// │ Sorted Spill │ │ │ │ Sorted Spill │ │ │ +/// │ file 4 │──────────────────────▶│ file 4 │────────┤ ┌───────────┐ +/// └──────────────┘ │ │ └──────────────┘ │ │ │ Phase 3 │ +/// │ │ │ │ └───────────┘ +/// │ ──── ──── ──── ──── ─┘ │ ┌──Can hold in memory─┐ +/// │ │ │ │ +/// ┌──────────────┐ │ ┌──────────────┐ │ │ ┌──────────────┐ +/// │ Sorted Spill │ │ │ Sorted Spill │ │ │ │ Sorted Spill │ │ +/// │ file 5 │──────────────────────▶│ file 5 │────────────────▶│ file 5 │───┼───┐ +/// └──────────────┘ │ └──────────────┘ │ │ └──────────────┘ │ │ +/// │ │ │ │ │ +/// │ ┌──────────────┐ │ │ ┌──────────────┐ │ +/// │ │ Sorted Spill │ │ │ │ Sorted Spill │ │ │ ┌── ─── ─── ─── ─── ─── ─── ──┐ +/// └──────────▶│ file 6 │────────────────▶│ file 6 │───┼───┼──────▶ Output Stream +/// └──────────────┘ │ │ └──────────────┘ │ │ └── ─── ─── ─── ─── ─── ─── ──┘ +/// │ │ │ │ +/// │ │ ┌──────────────┐ │ +/// │ │ │ Sorted Spill │ │ │ +/// └───────▶│ file 7 │───┼───┘ +/// │ └──────────────┘ │ +/// │ │ +/// └─ ──── ──── ──── ──── +/// ``` +/// +/// ## Memory Management Strategy +/// +/// This multi-level merge make sure that we can handle any amount of data to sort as long as +/// we have enough memory to merge at least 2 streams at a time. +/// +/// 1. **Worst-Case Memory Reservation**: Reserves memory based on the largest +/// batch size encountered in each spill file to merge, ensuring sufficient memory is always +/// available during merge operations. +/// 2. **Adaptive Buffer Sizing**: Reduces buffer sizes when memory is constrained +/// 3. **Spill-to-Disk**: Spill to disk when we cannot merge all files in memory +/// +pub(crate) struct MultiLevelMergeBuilder { + spill_manager: SpillManager, + schema: SchemaRef, + sorted_spill_files: Vec, + sorted_streams: Vec, + expr: LexOrdering, + metrics: BaselineMetrics, + batch_size: usize, + reservation: MemoryReservation, + fetch: Option, + enable_round_robin_tie_breaker: bool, +} + +impl Debug for MultiLevelMergeBuilder { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "MultiLevelMergeBuilder") + } +} + +impl MultiLevelMergeBuilder { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + spill_manager: SpillManager, + schema: SchemaRef, + sorted_spill_files: Vec, + sorted_streams: Vec, + expr: LexOrdering, + metrics: BaselineMetrics, + batch_size: usize, + reservation: MemoryReservation, + fetch: Option, + enable_round_robin_tie_breaker: bool, + ) -> Self { + Self { + spill_manager, + schema, + sorted_spill_files, + sorted_streams, + expr, + metrics, + batch_size, + reservation, + enable_round_robin_tie_breaker, + fetch, + } + } + + pub(crate) fn create_spillable_merge_stream(self) -> SendableRecordBatchStream { + Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&self.schema), + futures::stream::once(self.create_stream()).try_flatten(), + )) + } + + async fn create_stream(mut self) -> Result { + loop { + let mut stream = self.merge_sorted_runs_within_mem_limit()?; + + // TODO - add a threshold for number of files to disk even if empty and reading from disk so + // we can avoid the memory reservation + + // If no spill files are left, we can return the stream as this is the last sorted run + // TODO - We can write to disk before reading it back to avoid having multiple streams in memory + if self.sorted_spill_files.is_empty() { + assert!( + self.sorted_streams.is_empty(), + "We should not have any sorted streams left" + ); + + return Ok(stream); + } + + // Need to sort to a spill file + let Some((spill_file, max_record_batch_memory)) = self + .spill_manager + .spill_record_batch_stream_and_return_max_batch_memory( + &mut stream, + "MultiLevelMergeBuilder intermediate spill", + ) + .await? + else { + continue; + }; + + // Add the spill file + self.sorted_spill_files.push(SortedSpillFile { + file: spill_file, + max_record_batch_memory, + }); + } + } + + /// This tries to create a stream that merges the most sorted streams and sorted spill files + /// as possible within the memory limit. + fn merge_sorted_runs_within_mem_limit( + &mut self, + ) -> Result { + match (self.sorted_spill_files.len(), self.sorted_streams.len()) { + // No data so empty batch + (0, 0) => Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone( + &self.schema, + )))), + + // Only in-memory stream, return that + (0, 1) => Ok(self.sorted_streams.remove(0)), + + // Only single sorted spill file so return it + (1, 0) => { + let spill_file = self.sorted_spill_files.remove(0); + + // Not reserving any memory for this disk as we are not holding it in memory + self.spill_manager + .read_spill_as_stream(spill_file.file, None) + } + + // Only in memory streams, so merge them all in a single pass + (0, _) => { + let sorted_stream = mem::take(&mut self.sorted_streams); + self.create_new_merge_sort( + sorted_stream, + // If we have no sorted spill files left, this is the last run + true, + true, + ) + } + + // Need to merge multiple streams + (_, _) => { + let mut memory_reservation = self.reservation.new_empty(); + + // Don't account for existing streams memory + // as we are not holding the memory for them + let mut sorted_streams = mem::take(&mut self.sorted_streams); + + let (sorted_spill_files, buffer_size) = self + .get_sorted_spill_files_to_merge( + 2, + // we must have at least 2 streams to merge + 2_usize.saturating_sub(sorted_streams.len()), + &mut memory_reservation, + )?; + + let is_only_merging_memory_streams = sorted_spill_files.is_empty(); + + for spill in sorted_spill_files { + let stream = self + .spill_manager + .clone() + .with_batch_read_buffer_capacity(buffer_size) + .read_spill_as_stream( + spill.file, + Some(spill.max_record_batch_memory), + )?; + sorted_streams.push(stream); + } + let merge_sort_stream = self.create_new_merge_sort( + sorted_streams, + // If we have no sorted spill files left, this is the last run + self.sorted_spill_files.is_empty(), + is_only_merging_memory_streams, + )?; + + // If we're only merging memory streams, we don't need to attach the memory reservation + // as it's empty + if is_only_merging_memory_streams { + assert_eq!(memory_reservation.size(), 0, "when only merging memory streams, we should not have any memory reservation and let the merge sort handle the memory"); + + Ok(merge_sort_stream) + } else { + // Attach the memory reservation to the stream to make sure we have enough memory + // throughout the merge process as we bypassed the memory pool for the merge sort stream + Ok(Box::pin(StreamAttachedReservation::new( + merge_sort_stream, + memory_reservation, + ))) + } + } + } + } + + fn create_new_merge_sort( + &mut self, + streams: Vec, + is_output: bool, + all_in_memory: bool, + ) -> Result { + let mut builder = StreamingMergeBuilder::new() + .with_schema(Arc::clone(&self.schema)) + .with_expressions(&self.expr) + .with_batch_size(self.batch_size) + .with_fetch(self.fetch) + .with_metrics(if is_output { + // Only add the metrics to the last run + self.metrics.clone() + } else { + self.metrics.intermediate() + }) + .with_round_robin_tie_breaker(self.enable_round_robin_tie_breaker) + .with_streams(streams); + + if !all_in_memory { + // Don't track memory used by this stream as we reserve that memory by worst case sceneries + // (reserving memory for the biggest batch in each stream) + // TODO - avoid this hack as this can be broken easily when `SortPreservingMergeStream` + // changes the implementation to use more/less memory + builder = builder.with_bypass_mempool(); + } else { + // If we are only merging in-memory streams, we need to use the memory reservation + // because we don't know the maximum size of the batches in the streams + builder = builder.with_reservation(self.reservation.new_empty()); + } + + builder.build() + } + + /// Return the sorted spill files to use for the next phase, and the buffer size + /// This will try to get as many spill files as possible to merge, and if we don't have enough streams + /// it will try to reduce the buffer size until we have enough streams to merge + /// otherwise it will return an error + fn get_sorted_spill_files_to_merge( + &mut self, + buffer_len: usize, + minimum_number_of_required_streams: usize, + reservation: &mut MemoryReservation, + ) -> Result<(Vec, usize)> { + assert_ne!(buffer_len, 0, "Buffer length must be greater than 0"); + let mut number_of_spills_to_read_for_current_phase = 0; + + for spill in &self.sorted_spill_files { + // For memory pools that are not shared this is good, for other this is not + // and there should be some upper limit to memory reservation so we won't starve the system + match reservation.try_grow(get_reserved_byte_for_record_batch_size( + spill.max_record_batch_memory * buffer_len, + )) { + Ok(_) => { + number_of_spills_to_read_for_current_phase += 1; + } + // If we can't grow the reservation, we need to stop + Err(err) => { + // We must have at least 2 streams to merge, so if we don't have enough memory + // fail + if minimum_number_of_required_streams + > number_of_spills_to_read_for_current_phase + { + // Free the memory we reserved for this merge as we either try again or fail + reservation.free(); + if buffer_len > 1 { + // Try again with smaller buffer size, it will be slower but at least we can merge + return self.get_sorted_spill_files_to_merge( + buffer_len - 1, + minimum_number_of_required_streams, + reservation, + ); + } + + return Err(err); + } + + // We reached the maximum amount of memory we can use + // for this merge + break; + } + } + } + + let spills = self + .sorted_spill_files + .drain(..number_of_spills_to_read_for_current_phase) + .collect::>(); + + Ok((spills, buffer_len)) + } +} + +struct StreamAttachedReservation { + stream: SendableRecordBatchStream, + reservation: MemoryReservation, +} + +impl StreamAttachedReservation { + fn new(stream: SendableRecordBatchStream, reservation: MemoryReservation) -> Self { + Self { + stream, + reservation, + } + } +} + +impl Stream for StreamAttachedReservation { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let res = self.stream.poll_next_unpin(cx); + + match res { + Poll::Ready(res) => { + match res { + Some(Ok(batch)) => Poll::Ready(Some(Ok(batch))), + Some(Err(err)) => { + // Had an error so drop the data + self.reservation.free(); + Poll::Ready(Some(Err(err))) + } + None => { + // Stream is done so free the memory + self.reservation.free(); + + Poll::Ready(None) + } + } + } + Poll::Pending => Poll::Pending, + } + } +} + +impl RecordBatchStream for StreamAttachedReservation { + fn schema(&self) -> SchemaRef { + self.stream.schema() + } +} diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs index 320fa21c86656..513081e627e1a 100644 --- a/datafusion/physical-plan/src/sorts/partial_sort.rs +++ b/datafusion/physical-plan/src/sorts/partial_sort.rs @@ -105,7 +105,8 @@ impl PartialSortExec { ) -> Self { debug_assert!(common_prefix_length > 0); let preserve_partitioning = false; - let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning); + let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning) + .unwrap(); Self { input, expr, @@ -159,7 +160,7 @@ impl PartialSortExec { /// Sort expressions pub fn expr(&self) -> &LexOrdering { - self.expr.as_ref() + &self.expr } /// If `Some(fetch)`, limits output to only the first "fetch" items @@ -189,24 +190,22 @@ impl PartialSortExec { input: &Arc, sort_exprs: LexOrdering, preserve_partitioning: bool, - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties; i.e. reset the ordering equivalence // class with the new ordering: - let eq_properties = input - .equivalence_properties() - .clone() - .with_reorder(sort_exprs); + let mut eq_properties = input.equivalence_properties().clone(); + eq_properties.reorder(sort_exprs)?; // Get output partitioning: let output_partitioning = Self::output_partitioning_helper(input, preserve_partitioning); - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, input.pipeline_behavior(), input.boundedness(), - ) + )) } } @@ -296,10 +295,7 @@ impl ExecutionPlan for PartialSortExec { let input = self.input.execute(partition, Arc::clone(&context))?; - trace!( - "End PartialSortExec's input.execute for partition: {}", - partition - ); + trace!("End PartialSortExec's input.execute for partition: {partition}"); // Make sure common prefix length is larger than 0 // Otherwise, we should use SortExec. @@ -309,7 +305,7 @@ impl ExecutionPlan for PartialSortExec { input, expr: self.expr.clone(), common_prefix_length: self.common_prefix_length, - in_mem_batches: vec![], + in_mem_batch: RecordBatch::new_empty(Arc::clone(&self.schema())), fetch: self.fetch, is_closed: false, baseline_metrics: BaselineMetrics::new(&self.metrics_set, partition), @@ -321,7 +317,11 @@ impl ExecutionPlan for PartialSortExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition) } } @@ -334,7 +334,7 @@ struct PartialSortStream { /// should be more than 0 otherwise PartialSort is not applicable common_prefix_length: usize, /// Used as a buffer for part of the input not ready for sort - in_mem_batches: Vec, + in_mem_batch: RecordBatch, /// Fetch top N results fetch: Option, /// Whether the stream has finished returning all of its data or not @@ -375,52 +375,62 @@ impl PartialSortStream { return Poll::Ready(None); } loop { - return Poll::Ready(match ready!(self.input.poll_next_unpin(cx)) { + // Check if we've already reached the fetch limit + if self.fetch == Some(0) { + self.is_closed = true; + return Poll::Ready(None); + } + + match ready!(self.input.poll_next_unpin(cx)) { Some(Ok(batch)) => { - if let Some(slice_point) = - self.get_slice_point(self.common_prefix_length, &batch)? + // Merge new batch into in_mem_batch + self.in_mem_batch = concat_batches( + &self.schema(), + &[self.in_mem_batch.clone(), batch], + )?; + + // Check if we have a slice point, otherwise keep accumulating in `self.in_mem_batch`. + if let Some(slice_point) = self + .get_slice_point(self.common_prefix_length, &self.in_mem_batch)? { - self.in_mem_batches.push(batch.slice(0, slice_point)); - let remaining_batch = - batch.slice(slice_point, batch.num_rows() - slice_point); - // Extract the sorted batch - let sorted_batch = self.sort_in_mem_batches(); - // Refill with the remaining batch - self.in_mem_batches.push(remaining_batch); - - debug_assert!(sorted_batch - .as_ref() - .map(|batch| batch.num_rows() > 0) - .unwrap_or(true)); - Some(sorted_batch) - } else { - self.in_mem_batches.push(batch); - continue; + let sorted = self.in_mem_batch.slice(0, slice_point); + self.in_mem_batch = self.in_mem_batch.slice( + slice_point, + self.in_mem_batch.num_rows() - slice_point, + ); + let sorted_batch = sort_batch(&sorted, &self.expr, self.fetch)?; + if let Some(fetch) = self.fetch.as_mut() { + *fetch -= sorted_batch.num_rows(); + } + + if sorted_batch.num_rows() > 0 { + return Poll::Ready(Some(Ok(sorted_batch))); + } } } - Some(Err(e)) => Some(Err(e)), + Some(Err(e)) => return Poll::Ready(Some(Err(e))), None => { self.is_closed = true; - // once input is consumed, sort the rest of the inserted batches - let remaining_batch = self.sort_in_mem_batches()?; - if remaining_batch.num_rows() > 0 { - Some(Ok(remaining_batch)) + // Once input is consumed, sort the rest of the inserted batches + let remaining_batch = self.sort_in_mem_batch()?; + return if remaining_batch.num_rows() > 0 { + Poll::Ready(Some(Ok(remaining_batch))) } else { - None - } + Poll::Ready(None) + }; } - }); + }; } } /// Returns a sorted RecordBatch from in_mem_batches and clears in_mem_batches /// - /// If fetch is specified for PartialSortStream `sort_in_mem_batches` will limit + /// If fetch is specified for PartialSortStream `sort_in_mem_batch` will limit /// the last RecordBatch returned and will mark the stream as closed - fn sort_in_mem_batches(self: &mut Pin<&mut Self>) -> Result { - let input_batch = concat_batches(&self.schema(), &self.in_mem_batches)?; - self.in_mem_batches.clear(); - let result = sort_batch(&input_batch, self.expr.as_ref(), self.fetch)?; + fn sort_in_mem_batch(self: &mut Pin<&mut Self>) -> Result { + let input_batch = self.in_mem_batch.clone(); + self.in_mem_batch = RecordBatch::new_empty(self.schema()); + let result = sort_batch(&input_batch, &self.expr, self.fetch)?; if let Some(remaining_fetch) = self.fetch { // remaining_fetch - result.num_rows() is always be >= 0 // because result length of sort_batch with limit cannot be @@ -503,7 +513,7 @@ mod tests { }; let partial_sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -516,10 +526,11 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&source), 2, - )) as Arc; + )); let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; @@ -568,7 +579,7 @@ mod tests { for common_prefix_length in [1, 2] { let partial_sort_exec = Arc::new( PartialSortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -581,12 +592,13 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&source), common_prefix_length, ) .with_fetch(Some(4)), - ) as Arc; + ); let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; @@ -641,7 +653,7 @@ mod tests { [(1, &source_tables[0]), (2, &source_tables[1])] { let partial_sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -654,7 +666,8 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(source), common_prefix_length, )); @@ -730,8 +743,8 @@ mod tests { nulls_first: false, }; let schema = mem_exec.schema(); - let partial_sort_executor = PartialSortExec::new( - LexOrdering::new(vec![ + let partial_sort_exec = PartialSortExec::new( + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -744,17 +757,16 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&mem_exec), 1, ); - let partial_sort_exec = - Arc::new(partial_sort_executor.clone()) as Arc; let sort_exec = Arc::new(SortExec::new( - partial_sort_executor.expr, - partial_sort_executor.input, - )) as Arc; - let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; + partial_sort_exec.expr.clone(), + Arc::clone(&partial_sort_exec.input), + )); + let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?; assert_eq!( result.iter().map(|r| r.num_rows()).collect_vec(), [125, 125, 150] @@ -791,8 +803,8 @@ mod tests { (Some(150), vec![125, 25]), (Some(250), vec![125, 125]), ] { - let partial_sort_executor = PartialSortExec::new( - LexOrdering::new(vec![ + let partial_sort_exec = PartialSortExec::new( + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -805,19 +817,22 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&mem_exec), 1, ) .with_fetch(fetch_size); - let partial_sort_exec = - Arc::new(partial_sort_executor.clone()) as Arc; let sort_exec = Arc::new( - SortExec::new(partial_sort_executor.expr, partial_sort_executor.input) - .with_fetch(fetch_size), - ) as Arc; - let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; + SortExec::new( + partial_sort_exec.expr.clone(), + Arc::clone(&partial_sort_exec.input), + ) + .with_fetch(fetch_size), + ); + let result = + collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?; assert_eq!( result.iter().map(|r| r.num_rows()).collect_vec(), expected_batch_num_rows @@ -846,8 +861,8 @@ mod tests { nulls_first: false, }; let fetch_size = Some(250); - let partial_sort_executor = PartialSortExec::new( - LexOrdering::new(vec![ + let partial_sort_exec = PartialSortExec::new( + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -856,15 +871,14 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&mem_exec), 1, ) .with_fetch(fetch_size); - let partial_sort_exec = - Arc::new(partial_sort_executor.clone()) as Arc; - let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; + let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?; for rb in result { assert!(rb.num_rows() > 0); } @@ -897,10 +911,11 @@ mod tests { TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?; let partial_sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("field_name", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), input, 1, )); @@ -986,7 +1001,7 @@ mod tests { )?; let partial_sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -999,7 +1014,8 @@ mod tests { expr: col("c", &schema)?, options: option_desc, }, - ]), + ] + .into(), TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?, 2, )); @@ -1061,10 +1077,11 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); let sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), blocking_exec, 1, )); @@ -1084,4 +1101,87 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_partial_sort_with_homogeneous_batches() -> Result<()> { + // Test case for the bug where batches with homogeneous sort keys + // (e.g., [1,1,1], [2,2,2]) would not be properly detected as having + // slice points between batches. + let task_ctx = Arc::new(TaskContext::default()); + + // Create batches where each batch has homogeneous values for sort keys + let batch1 = test::build_table_i32( + ("a", &vec![1; 3]), + ("b", &vec![1; 3]), + ("c", &vec![3, 2, 1]), + ); + let batch2 = test::build_table_i32( + ("a", &vec![2; 3]), + ("b", &vec![2; 3]), + ("c", &vec![4, 6, 4]), + ); + let batch3 = test::build_table_i32( + ("a", &vec![3; 3]), + ("b", &vec![3; 3]), + ("c", &vec![9, 7, 8]), + ); + + let schema = batch1.schema(); + let mem_exec = TestMemoryExec::try_new_exec( + &[vec![batch1, batch2, batch3]], + Arc::clone(&schema), + None, + )?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + // Partial sort with common prefix of 2 (sorting by a, b, c) + let partial_sort_exec = Arc::new(PartialSortExec::new( + [ + PhysicalSortExpr { + expr: col("a", &schema)?, + options: option_asc, + }, + PhysicalSortExpr { + expr: col("b", &schema)?, + options: option_asc, + }, + PhysicalSortExpr { + expr: col("c", &schema)?, + options: option_asc, + }, + ] + .into(), + mem_exec, + 2, + )); + + let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; + + assert_eq!(result.len(), 3,); + + allow_duplicates! { + assert_snapshot!(batches_to_string(&result), @r#" + +---+---+---+ + | a | b | c | + +---+---+---+ + | 1 | 1 | 1 | + | 1 | 1 | 2 | + | 1 | 1 | 3 | + | 2 | 2 | 4 | + | 2 | 2 | 4 | + | 2 | 2 | 6 | + | 3 | 3 | 7 | + | 3 | 3 | 8 | + | 3 | 3 | 9 | + +---+---+---+ + "#); + } + + assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0,); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 1072e9abf437e..7f47d60c735a3 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -24,41 +24,47 @@ use std::fmt; use std::fmt::{Debug, Formatter}; use std::sync::Arc; +use parking_lot::RwLock; + use crate::common::spawn_buffered; use crate::execution_plan::{Boundedness, CardinalityEffect, EmissionType}; use crate::expressions::PhysicalSortExpr; +use crate::filter_pushdown::{ + ChildFilterDescription, FilterDescription, FilterPushdownPhase, +}; use crate::limit::LimitStream; use crate::metrics::{ - BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics, + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics, SplitMetrics, }; -use crate::projection::{make_with_child, update_expr, ProjectionExec}; -use crate::sorts::streaming_merge::StreamingMergeBuilder; +use crate::projection::{make_with_child, update_ordering, ProjectionExec}; +use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::spill::get_record_batch_memory_size; use crate::spill::in_progress_spill_file::InProgressSpillFile; -use crate::spill::spill_manager::SpillManager; +use crate::spill::spill_manager::{GetSlicedSize, SpillManager}; +use crate::stream::BatchSplitStream; use crate::stream::RecordBatchStreamAdapter; use crate::topk::TopK; +use crate::topk::TopKDynamicFilters; use crate::{ DisplayAs, DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan, ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, }; -use arrow::array::{ - Array, RecordBatch, RecordBatchOptions, StringViewArray, UInt32Array, -}; -use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays, SortColumn}; -use arrow::datatypes::{DataType, SchemaRef}; -use arrow::row::{RowConverter, Rows, SortField}; +use arrow::array::{Array, RecordBatch, RecordBatchOptions, StringViewArray}; +use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays}; +use arrow::datatypes::SchemaRef; +use datafusion_common::config::SpillCompression; use datafusion_common::{ - exec_datafusion_err, internal_datafusion_err, internal_err, Result, + internal_datafusion_err, internal_err, unwrap_or_internal_err, DataFusionError, + Result, }; -use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; +use datafusion_physical_expr::expressions::{lit, DynamicFilterPhysicalExpr}; use datafusion_physical_expr::LexOrdering; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_physical_expr::PhysicalExpr; use futures::{StreamExt, TryStreamExt}; use log::{debug, trace}; @@ -68,6 +74,8 @@ struct ExternalSorterMetrics { baseline: BaselineMetrics, spill_metrics: SpillMetrics, + + split_metrics: SplitMetrics, } impl ExternalSorterMetrics { @@ -75,6 +83,7 @@ impl ExternalSorterMetrics { Self { baseline: BaselineMetrics::new(metrics, partition), spill_metrics: SpillMetrics::new(metrics, partition), + split_metrics: SplitMetrics::new(metrics, partition), } } } @@ -89,8 +98,9 @@ impl ExternalSorterMetrics { /// 1. get a non-empty new batch from input /// /// 2. check with the memory manager there is sufficient space to -/// buffer the batch in memory 2.1 if memory sufficient, buffer -/// batch in memory, go to 1. +/// buffer the batch in memory. +/// +/// 2.1 if memory is sufficient, buffer batch in memory, go to 1. /// /// 2.2 if no more memory is available, sort all buffered batches and /// spill to file. buffer the next batch in memory, go to 1. @@ -204,11 +214,7 @@ struct ExternalSorter { /// Schema of the output (and the input) schema: SchemaRef, /// Sort expressions - expr: Arc<[PhysicalSortExpr]>, - /// RowConverter corresponding to the sort expressions - sort_keys_row_converter: Arc, - /// If Some, the maximum number of output rows that will be produced - fetch: Option, + expr: LexOrdering, /// The target number of rows for output batches batch_size: usize, /// If the in size of buffered memory batches is below this size, @@ -225,12 +231,16 @@ struct ExternalSorter { /// During external sorting, in-memory intermediate data will be appended to /// this file incrementally. Once finished, this file will be moved to [`Self::finished_spill_files`]. - in_progress_spill_file: Option, + /// + /// this is a tuple of: + /// 1. `InProgressSpillFile` - the file that is being written to + /// 2. `max_record_batch_memory` - the maximum memory usage of a single batch in this spill file. + in_progress_spill_file: Option<(InProgressSpillFile, usize)>, /// If data has previously been spilled, the locations of the spill files (in /// Arrow IPC format) /// Within the same spill file, the data might be chunked into multiple batches, /// and ordered by sort keys. - finished_spill_files: Vec, + finished_spill_files: Vec, // ======================================================================== // EXECUTION RESOURCES: @@ -262,9 +272,10 @@ impl ExternalSorter { schema: SchemaRef, expr: LexOrdering, batch_size: usize, - fetch: Option, sort_spill_reservation_bytes: usize, sort_in_place_threshold_bytes: usize, + // Configured via `datafusion.execution.spill_compression`. + spill_compression: SpillCompression, metrics: &ExecutionPlanMetricsSet, runtime: Arc, ) -> Result { @@ -277,37 +288,20 @@ impl ExternalSorter { MemoryConsumer::new(format!("ExternalSorterMerge[{partition_id}]")) .register(&runtime.memory_pool); - // Construct RowConverter for sort keys - let sort_fields = expr - .iter() - .map(|e| { - let data_type = e - .expr - .data_type(&schema) - .map_err(|e| e.context("Resolving sort expression data type"))?; - Ok(SortField::new_with_options(data_type, e.options)) - }) - .collect::>>()?; - - let converter = RowConverter::new(sort_fields).map_err(|e| { - exec_datafusion_err!("Failed to create RowConverter: {:?}", e) - })?; - let spill_manager = SpillManager::new( Arc::clone(&runtime), metrics.spill_metrics.clone(), Arc::clone(&schema), - ); + ) + .with_compression_type(spill_compression); Ok(Self { schema, in_mem_batches: vec![], in_progress_spill_file: None, finished_spill_files: vec![], - expr: expr.into(), - sort_keys_row_converter: Arc::new(converter), + expr, metrics, - fetch, reservation, spill_manager, merge_reservation, @@ -327,15 +321,8 @@ impl ExternalSorter { } self.reserve_memory_for_merge()?; - - let size = get_reserved_byte_for_record_batch(&input); - if self.reservation.try_grow(size).is_err() { - self.sort_or_spill_in_mem_batches(false).await?; - // We've already freed more than half of reserved memory, - // so we can grow the reservation again. There's nothing we can do - // if this try_grow fails. - self.reservation.try_grow(size)?; - } + self.reserve_memory_for_batch_and_maybe_spill(&input) + .await?; self.in_mem_batches.push(input); Ok(()) @@ -361,32 +348,21 @@ impl ExternalSorter { self.merge_reservation.free(); if self.spilled_before() { - let mut streams = vec![]; - // Sort `in_mem_batches` and spill it first. If there are many // `in_mem_batches` and the memory limit is almost reached, merging // them with the spilled files at the same time might cause OOM. if !self.in_mem_batches.is_empty() { - self.sort_or_spill_in_mem_batches(true).await?; + self.sort_and_spill_in_mem_batches().await?; } - for spill in self.finished_spill_files.drain(..) { - if !spill.path().exists() { - return internal_err!("Spill file {:?} does not exist", spill.path()); - } - let stream = self.spill_manager.read_spill_as_stream(spill)?; - streams.push(stream); - } - - let expressions: LexOrdering = self.expr.iter().cloned().collect(); - StreamingMergeBuilder::new() - .with_streams(streams) + .with_sorted_spill_files(std::mem::take(&mut self.finished_spill_files)) + .with_spill_manager(self.spill_manager.clone()) .with_schema(Arc::clone(&self.schema)) - .with_expressions(expressions.as_ref()) + .with_expressions(&self.expr.clone()) .with_metrics(self.metrics.baseline.clone()) .with_batch_size(self.batch_size) - .with_fetch(self.fetch) + .with_fetch(None) .with_reservation(self.merge_reservation.new_empty()) .build() } else { @@ -427,7 +403,7 @@ impl ExternalSorter { // Lazily initialize the in-progress spill file if self.in_progress_spill_file.is_none() { self.in_progress_spill_file = - Some(self.spill_manager.create_in_progress_file("Sorting")?); + Some((self.spill_manager.create_in_progress_file("Sorting")?, 0)); } Self::organize_stringview_arrays(globally_sorted_batches)?; @@ -437,12 +413,16 @@ impl ExternalSorter { let batches_to_spill = std::mem::take(globally_sorted_batches); self.reservation.free(); - let in_progress_file = self.in_progress_spill_file.as_mut().ok_or_else(|| { - internal_datafusion_err!("In-progress spill file should be initialized") - })?; + let (in_progress_file, max_record_batch_size) = + self.in_progress_spill_file.as_mut().ok_or_else(|| { + internal_datafusion_err!("In-progress spill file should be initialized") + })?; for batch in batches_to_spill { in_progress_file.append_batch(&batch)?; + + *max_record_batch_size = + (*max_record_batch_size).max(batch.get_sliced_size()?); } if !globally_sorted_batches.is_empty() { @@ -454,14 +434,17 @@ impl ExternalSorter { /// Finishes the in-progress spill file and moves it to the finished spill files. async fn spill_finish(&mut self) -> Result<()> { - let mut in_progress_file = + let (mut in_progress_file, max_record_batch_memory) = self.in_progress_spill_file.take().ok_or_else(|| { internal_datafusion_err!("Should be called after `spill_append`") })?; let spill_file = in_progress_file.finish()?; if let Some(spill_file) = spill_file { - self.finished_spill_files.push(spill_file); + self.finished_spill_files.push(SortedSpillFile { + file: spill_file, + max_record_batch_memory, + }); } Ok(()) @@ -532,28 +515,21 @@ impl ExternalSorter { Ok(()) } - /// Sorts the in_mem_batches and potentially spill the sorted batches. - /// - /// If the memory usage has dropped by a factor of 2, it might be a sort with - /// fetch (e.g. sorting 1M rows but only keep the top 100), so we keep the - /// sorted entries inside `in_mem_batches` to be sorted in the next iteration. - /// Otherwise, we spill the sorted run to free up memory for inserting more batches. - /// - /// # Arguments - /// - /// * `force_spill` - If true, the method will spill the in-memory batches - /// even if the memory usage has not dropped by a factor of 2. Otherwise it will - /// only spill when the memory usage has dropped by the pre-defined factor. - /// - async fn sort_or_spill_in_mem_batches(&mut self, force_spill: bool) -> Result<()> { + /// Sorts the in-memory batches and merges them into a single sorted run, then writes + /// the result to spill files. + async fn sort_and_spill_in_mem_batches(&mut self) -> Result<()> { + if self.in_mem_batches.is_empty() { + return internal_err!( + "in_mem_batches must not be empty when attempting to sort and spill" + ); + } + // Release the memory reserved for merge back to the pool so // there is some left when `in_mem_sort_stream` requests an // allocation. At the end of this function, memory will be // reserved again for the next spill. self.merge_reservation.free(); - let before = self.reservation.size(); - let mut sorted_stream = self.in_mem_sort_stream(self.metrics.baseline.intermediate())?; // After `in_mem_sort_stream()` is constructed, all `in_mem_batches` is taken @@ -568,7 +544,6 @@ impl ExternalSorter { // sort-preserving merge and incrementally append to spill files. let mut globally_sorted_batches: Vec = vec![]; - let mut spilled = false; while let Some(batch) = sorted_stream.next().await { let batch = batch?; let sorted_size = get_reserved_byte_for_record_batch(&batch); @@ -579,7 +554,6 @@ impl ExternalSorter { globally_sorted_batches.push(batch); self.consume_and_spill_append(&mut globally_sorted_batches) .await?; // reservation is freed in spill() - spilled = true; } else { globally_sorted_batches.push(batch); } @@ -589,33 +563,17 @@ impl ExternalSorter { // upcoming `self.reserve_memory_for_merge()` may fail due to insufficient memory. drop(sorted_stream); - // Sorting may free up some memory especially when fetch is `Some`. If we have - // not freed more than 50% of the memory, then we have to spill to free up more - // memory for inserting more batches. - if (self.reservation.size() > before / 2) || force_spill { - // We have not freed more than 50% of the memory, so we have to spill to - // free up more memory - self.consume_and_spill_append(&mut globally_sorted_batches) - .await?; - spilled = true; - } - - if spilled { - // There might be some buffered batches that haven't trigger a spill yet. - self.consume_and_spill_append(&mut globally_sorted_batches) - .await?; - self.spill_finish().await?; - } else { - // If the memory limit has reached before calling this function, and it - // didn't spill anything, it means this is a sorting with fetch top K - // element: after sorting only the top K elements will be kept in memory. - // For simplicity, those sorted top K entries are put back to unsorted - // `in_mem_batches` to be consumed by the next sort/merge. - if !self.in_mem_batches.is_empty() { - return internal_err!("in_mem_batches should be cleared before"); - } + self.consume_and_spill_append(&mut globally_sorted_batches) + .await?; + self.spill_finish().await?; - self.in_mem_batches = std::mem::take(&mut globally_sorted_batches); + // Sanity check after spilling + let buffers_cleared_property = + self.in_mem_batches.is_empty() && globally_sorted_batches.is_empty(); + if !buffers_cleared_property { + return internal_err!( + "in_mem_batches and globally_sorted_batches should be cleared before" + ); } // Reserve headroom for next sort/merge @@ -706,7 +664,7 @@ impl ExternalSorter { if self.in_mem_batches.len() == 1 { let batch = self.in_mem_batches.swap_remove(0); let reservation = self.reservation.take(); - return self.sort_batch_stream(batch, metrics, reservation); + return self.sort_batch_stream(batch, metrics, reservation, true); } // If less than sort_in_place_threshold_bytes, concatenate and sort in place @@ -715,9 +673,10 @@ impl ExternalSorter { let batch = concat_batches(&self.schema, &self.in_mem_batches)?; self.in_mem_batches.clear(); self.reservation - .try_resize(get_reserved_byte_for_record_batch(&batch))?; + .try_resize(get_reserved_byte_for_record_batch(&batch)) + .map_err(Self::err_with_oom_context)?; let reservation = self.reservation.take(); - return self.sort_batch_stream(batch, metrics, reservation); + return self.sort_batch_stream(batch, metrics, reservation, true); } let streams = std::mem::take(&mut self.in_mem_batches) @@ -727,20 +686,25 @@ impl ExternalSorter { let reservation = self .reservation .split(get_reserved_byte_for_record_batch(&batch)); - let input = self.sort_batch_stream(batch, metrics, reservation)?; + let input = self.sort_batch_stream( + batch, + metrics, + reservation, + // Passing false as `StreamingMergeBuilder` will split the + // stream into batches of `self.batch_size` rows. + false, + )?; Ok(spawn_buffered(input, 1)) }) .collect::>()?; - let expressions: LexOrdering = self.expr.iter().cloned().collect(); - StreamingMergeBuilder::new() .with_streams(streams) .with_schema(Arc::clone(&self.schema)) - .with_expressions(expressions.as_ref()) + .with_expressions(&self.expr.clone()) .with_metrics(metrics) .with_batch_size(self.batch_size) - .with_fetch(self.fetch) + .with_fetch(None) .with_reservation(self.merge_reservation.new_empty()) .build() } @@ -749,36 +713,31 @@ impl ExternalSorter { /// /// `reservation` accounts for the memory used by this batch and /// is released when the sort is complete + /// + /// passing `split` true will return a [`BatchSplitStream`] where each batch maximum row count + /// will be `self.batch_size`. + /// If `split` is false, the stream will return a single batch fn sort_batch_stream( &self, batch: RecordBatch, metrics: BaselineMetrics, reservation: MemoryReservation, + mut split: bool, ) -> Result { assert_eq!( get_reserved_byte_for_record_batch(&batch), reservation.size() ); + + split = split && batch.num_rows() > self.batch_size; + let schema = batch.schema(); - let fetch = self.fetch; - let expressions: LexOrdering = self.expr.iter().cloned().collect(); - let row_converter = Arc::clone(&self.sort_keys_row_converter); + let expressions = self.expr.clone(); let stream = futures::stream::once(async move { let _timer = metrics.elapsed_compute().timer(); - let sort_columns = expressions - .iter() - .map(|expr| expr.evaluate_to_sort_column(&batch)) - .collect::>>()?; - - let sorted = if is_multi_column_with_lists(&sort_columns) { - // lex_sort_to_indices doesn't support List with more than one column - // https://github.com/apache/arrow-rs/issues/5454 - sort_batch_row_based(&batch, &expressions, row_converter, fetch)? - } else { - sort_batch(&batch, &expressions, fetch)? - }; + let sorted = sort_batch(&batch, &expressions, None)?; metrics.record_output(sorted.num_rows()); drop(batch); @@ -786,7 +745,18 @@ impl ExternalSorter { Ok(sorted) }); - Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) + let mut output: SendableRecordBatchStream = + Box::pin(RecordBatchStreamAdapter::new(schema, stream)); + + if split { + output = Box::pin(BatchSplitStream::new( + output, + self.batch_size, + self.metrics.split_metrics.clone(), + )); + } + + Ok(output) } /// If this sort may spill, pre-allocates @@ -797,12 +767,51 @@ impl ExternalSorter { if self.runtime.disk_manager.tmp_files_enabled() { let size = self.sort_spill_reservation_bytes; if self.merge_reservation.size() != size { - self.merge_reservation.try_resize(size)?; + self.merge_reservation + .try_resize(size) + .map_err(Self::err_with_oom_context)?; } } Ok(()) } + + /// Reserves memory to be able to accommodate the given batch. + /// If memory is scarce, tries to spill current in-memory batches to disk first. + async fn reserve_memory_for_batch_and_maybe_spill( + &mut self, + input: &RecordBatch, + ) -> Result<()> { + let size = get_reserved_byte_for_record_batch(input); + + match self.reservation.try_grow(size) { + Ok(_) => Ok(()), + Err(e) => { + if self.in_mem_batches.is_empty() { + return Err(Self::err_with_oom_context(e)); + } + + // Spill and try again. + self.sort_and_spill_in_mem_batches().await?; + self.reservation + .try_grow(size) + .map_err(Self::err_with_oom_context) + } + } + } + + /// Wraps the error with a context message suggesting settings to tweak. + /// This is meant to be used with DataFusionError::ResourcesExhausted only. + fn err_with_oom_context(e: DataFusionError) -> DataFusionError { + match e { + DataFusionError::ResourcesExhausted(_) => e.context( + "Not enough memory to continue external sort. \ + Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes" + ), + // This is not an OOM error, so just return it as is. + _ => e, + } + } } /// Estimate how much memory is needed to sort a `RecordBatch`. @@ -812,11 +821,16 @@ impl ExternalSorter { /// in sorting and merging. The sorted copies are in either row format or array format. /// Please refer to cursor.rs and stream.rs for more details. No matter what format the /// sorted copies are, they will use more memory than the original record batch. -fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> usize { +pub(crate) fn get_reserved_byte_for_record_batch_size(record_batch_size: usize) -> usize { // 2x may not be enough for some cases, but it's a good start. // If 2x is not enough, user can set a larger value for `sort_spill_reservation_bytes` // to compensate for the extra memory needed. - get_record_batch_memory_size(batch) * 2 + record_batch_size * 2 +} + +/// Estimate how much memory is needed to sort a `RecordBatch`. +fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> usize { + get_reserved_byte_for_record_batch_size(get_record_batch_memory_size(batch)) } impl Debug for ExternalSorter { @@ -830,45 +844,6 @@ impl Debug for ExternalSorter { } } -/// Converts rows into a sorted array of indices based on their order. -/// This function returns the indices that represent the sorted order of the rows. -fn rows_to_indices(rows: Rows, limit: Option) -> Result { - let mut sort: Vec<_> = rows.iter().enumerate().collect(); - sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); - - let mut len = rows.num_rows(); - if let Some(limit) = limit { - len = limit.min(len); - } - let indices = - UInt32Array::from_iter_values(sort.iter().take(len).map(|(i, _)| *i as u32)); - Ok(indices) -} - -/// Sorts a `RecordBatch` by converting its sort columns into Arrow Row Format for faster comparison. -fn sort_batch_row_based( - batch: &RecordBatch, - expressions: &LexOrdering, - row_converter: Arc, - fetch: Option, -) -> Result { - let sort_columns = expressions - .iter() - .map(|expr| expr.evaluate_to_sort_column(batch).map(|col| col.values)) - .collect::>>()?; - let rows = row_converter.convert_columns(&sort_columns)?; - let indices = rows_to_indices(rows, fetch)?; - let columns = take_arrays(batch.columns(), &indices, None)?; - - let options = RecordBatchOptions::new().with_row_count(Some(indices.len())); - - Ok(RecordBatch::try_new_with_options( - batch.schema(), - columns, - &options, - )?) -} - pub fn sort_batch( batch: &RecordBatch, expressions: &LexOrdering, @@ -879,14 +854,7 @@ pub fn sort_batch( .map(|expr| expr.evaluate_to_sort_column(batch)) .collect::>>()?; - let indices = if is_multi_column_with_lists(&sort_columns) { - // lex_sort_to_indices doesn't support List with more than one column - // https://github.com/apache/arrow-rs/issues/5454 - lexsort_to_indices_multi_columns(sort_columns, fetch)? - } else { - lexsort_to_indices(&sort_columns, fetch)? - }; - + let indices = lexsort_to_indices(&sort_columns, fetch)?; let mut columns = take_arrays(batch.columns(), &indices, None)?; // The columns may be larger than the unsorted columns in `batch` especially for variable length @@ -905,50 +873,6 @@ pub fn sort_batch( )?) } -#[inline] -fn is_multi_column_with_lists(sort_columns: &[SortColumn]) -> bool { - sort_columns.iter().any(|c| { - matches!( - c.values.data_type(), - DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) - ) - }) -} - -pub(crate) fn lexsort_to_indices_multi_columns( - sort_columns: Vec, - limit: Option, -) -> Result { - let (fields, columns) = sort_columns.into_iter().fold( - (vec![], vec![]), - |(mut fields, mut columns), sort_column| { - fields.push(SortField::new_with_options( - sort_column.values.data_type().clone(), - sort_column.options.unwrap_or_default(), - )); - columns.push(sort_column.values); - (fields, columns) - }, - ); - - // Note: row converter is reused through `sort_batch_row_based()`, this function - // is not used during normal sort execution, but it's kept temporarily because - // it's inside a public interface `sort_batch()`. - let converter = RowConverter::new(fields)?; - let rows = converter.convert_columns(&columns)?; - let mut sort: Vec<_> = rows.iter().enumerate().collect(); - sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); - - let mut len = rows.num_rows(); - if let Some(limit) = limit { - len = limit.min(len); - } - let indices = - UInt32Array::from_iter_values(sort.iter().take(len).map(|(i, _)| *i as u32)); - - Ok(indices) -} - /// Sort execution plan. /// /// Support sorting datasets that are larger than the memory allotted @@ -966,8 +890,14 @@ pub struct SortExec { preserve_partitioning: bool, /// Fetch highest/lowest n results fetch: Option, + /// Normalized common sort prefix between the input and the sort expressions (only used with fetch) + common_sort_prefix: Vec, /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, + /// Filter matching the state of the sort for dynamic filter pushdown. + /// If `fetch` is `Some`, this will also be set and a TopK operator may be used. + /// If `fetch` is `None`, this will be `None`. + filter: Option>>, } impl SortExec { @@ -975,14 +905,18 @@ impl SortExec { /// sorted output partition. pub fn new(expr: LexOrdering, input: Arc) -> Self { let preserve_partitioning = false; - let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning); + let (cache, sort_prefix) = + Self::compute_properties(&input, expr.clone(), preserve_partitioning) + .unwrap(); Self { expr, input, metrics_set: ExecutionPlanMetricsSet::new(), preserve_partitioning, fetch: None, + common_sort_prefix: sort_prefix, cache, + filter: None, } } @@ -1009,6 +943,31 @@ impl SortExec { self } + /// Add or reset `self.filter` to a new `TopKDynamicFilters`. + fn create_filter(&self) -> Arc> { + let children = self + .expr + .iter() + .map(|sort_expr| Arc::clone(&sort_expr.expr)) + .collect::>(); + Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new( + DynamicFilterPhysicalExpr::new(children, lit(true)), + )))) + } + + fn cloned(&self) -> Self { + SortExec { + input: Arc::clone(&self.input), + expr: self.expr.clone(), + metrics_set: self.metrics_set.clone(), + preserve_partitioning: self.preserve_partitioning, + common_sort_prefix: self.common_sort_prefix.clone(), + fetch: self.fetch, + cache: self.cache.clone(), + filter: self.filter.clone(), + } + } + /// Modify how many rows to include in the result /// /// If None, then all rows will be returned, in sorted order. @@ -1028,14 +987,15 @@ impl SortExec { if fetch.is_some() && is_pipeline_friendly { cache = cache.with_boundedness(Boundedness::Bounded); } - SortExec { - input: Arc::clone(&self.input), - expr: self.expr.clone(), - metrics_set: self.metrics_set.clone(), - preserve_partitioning: self.preserve_partitioning, - fetch, - cache, - } + let filter = fetch.is_some().then(|| { + // If we already have a filter, keep it. Otherwise, create a new one. + self.filter.clone().unwrap_or_else(|| self.create_filter()) + }); + let mut new_sort = self.cloned(); + new_sort.fetch = fetch; + new_sort.cache = cache; + new_sort.filter = filter; + new_sort } /// Input schema @@ -1066,19 +1026,18 @@ impl SortExec { } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. + /// It also returns the common sort prefix between the input and the sort expressions. fn compute_properties( input: &Arc, sort_exprs: LexOrdering, preserve_partitioning: bool, - ) -> PlanProperties { - // Determine execution mode: - let requirement = LexRequirement::from(sort_exprs); - let sort_satisfied = input + ) -> Result<(PlanProperties, Vec)> { + let (sort_prefix, sort_satisfied) = input .equivalence_properties() - .ordering_satisfy_requirement(&requirement); + .extract_common_sort_prefix(sort_exprs.clone())?; // The emission type depends on whether the input is already sorted: - // - If already sorted, we can emit results in the same way as the input + // - If already fully sorted, we can emit results in the same way as the input // - If not sorted, we must wait until all data is processed to emit results (Final) let emission_type = if sort_satisfied { input.pipeline_behavior() @@ -1104,22 +1063,22 @@ impl SortExec { // Calculate equivalence properties; i.e. reset the ordering equivalence // class with the new ordering: - let sort_exprs = LexOrdering::from(requirement); - let eq_properties = input - .equivalence_properties() - .clone() - .with_reorder(sort_exprs); + let mut eq_properties = input.equivalence_properties().clone(); + eq_properties.reorder(sort_exprs)?; // Get output partitioning: let output_partitioning = Self::output_partitioning_helper(input, preserve_partitioning); - PlanProperties::new( - eq_properties, - output_partitioning, - emission_type, - boundedness, - ) + Ok(( + PlanProperties::new( + eq_properties, + output_partitioning, + emission_type, + boundedness, + ), + sort_prefix, + )) } } @@ -1130,7 +1089,29 @@ impl DisplayAs for SortExec { let preserve_partitioning = self.preserve_partitioning; match self.fetch { Some(fetch) => { - write!(f, "SortExec: TopK(fetch={fetch}), expr=[{}], preserve_partitioning=[{preserve_partitioning}]", self.expr) + write!(f, "SortExec: TopK(fetch={fetch}), expr=[{}], preserve_partitioning=[{preserve_partitioning}]", self.expr)?; + if let Some(filter) = &self.filter { + if let Ok(current) = filter.read().expr().current() { + if !current.eq(&lit(true)) { + write!(f, ", filter=[{current}]")?; + } + } + } + if !self.common_sort_prefix.is_empty() { + write!(f, ", sort_prefix=[")?; + let mut first = true; + for sort_expr in &self.common_sort_prefix { + if first { + first = false; + } else { + write!(f, ", ")?; + } + write!(f, "{sort_expr}")?; + } + write!(f, "]") + } else { + Ok(()) + } } None => write!(f, "SortExec: expr=[{}], preserve_partitioning=[{preserve_partitioning}]", self.expr), } @@ -1150,7 +1131,10 @@ impl DisplayAs for SortExec { impl ExecutionPlan for SortExec { fn name(&self) -> &'static str { - "SortExec" + match self.fetch { + Some(_) => "SortExec(TopK)", + None => "SortExec", + } } fn as_any(&self) -> &dyn Any { @@ -1183,9 +1167,35 @@ impl ExecutionPlan for SortExec { self: Arc, children: Vec>, ) -> Result> { - let new_sort = SortExec::new(self.expr.clone(), Arc::clone(&children[0])) - .with_fetch(self.fetch) - .with_preserve_partitioning(self.preserve_partitioning); + let mut new_sort = self.cloned(); + assert!( + children.len() == 1, + "SortExec should have exactly one child" + ); + new_sort.input = Arc::clone(&children[0]); + // Recompute the properties based on the new input since they may have changed + let (cache, sort_prefix) = Self::compute_properties( + &new_sort.input, + new_sort.expr.clone(), + new_sort.preserve_partitioning, + )?; + new_sort.cache = cache; + new_sort.common_sort_prefix = sort_prefix; + + Ok(Arc::new(new_sort)) + } + + fn reset_state(self: Arc) -> Result> { + let children = self.children().into_iter().cloned().collect(); + let new_sort = self.with_new_children(children)?; + let mut new_sort = new_sort + .as_any() + .downcast_ref::() + .expect("cloned 1 lines above this line, we know the type") + .clone(); + // Our dynamic filter and execution metrics are the state we need to reset. + new_sort.filter = Some(new_sort.create_filter()); + new_sort.metrics_set = ExecutionPlanMetricsSet::new(); Ok(Arc::new(new_sort)) } @@ -1201,12 +1211,12 @@ impl ExecutionPlan for SortExec { let execution_options = &context.session_config().options().execution; - trace!("End SortExec's input.execute for partition: {}", partition); + trace!("End SortExec's input.execute for partition: {partition}"); let sort_satisfied = self .input .equivalence_properties() - .ordering_satisfy_requirement(&LexRequirement::from(self.expr.clone())); + .ordering_satisfy(self.expr.clone())?; match (sort_satisfied, self.fetch.as_ref()) { (true, Some(fetch)) => Ok(Box::pin(LimitStream::new( @@ -1217,14 +1227,17 @@ impl ExecutionPlan for SortExec { ))), (true, None) => Ok(input), (false, Some(fetch)) => { + let filter = self.filter.clone(); let mut topk = TopK::try_new( partition, input.schema(), + self.common_sort_prefix.clone(), self.expr.clone(), *fetch, context.session_config().batch_size(), context.runtime_env(), &self.metrics_set, + Arc::clone(&unwrap_or_internal_err!(filter)), )?; Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), @@ -1232,6 +1245,9 @@ impl ExecutionPlan for SortExec { while let Some(batch) = input.next().await { let batch = batch?; topk.insert_batch(batch)?; + if topk.finished { + break; + } } topk.emit() }) @@ -1244,9 +1260,9 @@ impl ExecutionPlan for SortExec { input.schema(), self.expr.clone(), context.session_config().batch_size(), - self.fetch, execution_options.sort_spill_reservation_bytes, execution_options.sort_in_place_threshold_bytes, + context.session_config().spill_compression(), &self.metrics_set, context.runtime_env(), )?; @@ -1270,7 +1286,19 @@ impl ExecutionPlan for SortExec { } fn statistics(&self) -> Result { - Statistics::with_fetch(self.input.statistics()?, self.schema(), self.fetch, 0, 1) + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if !self.preserve_partitioning() { + return self + .input + .partition_statistics(None)? + .with_fetch(self.fetch, 0, 1); + } + self.input + .partition_statistics(partition)? + .with_fetch(self.fetch, 0, 1) } fn with_fetch(&self, limit: Option) -> Option> { @@ -1301,17 +1329,10 @@ impl ExecutionPlan for SortExec { return Ok(None); } - let mut updated_exprs = LexOrdering::default(); - for sort in self.expr() { - let Some(new_expr) = update_expr(&sort.expr, projection.expr(), false)? - else { - return Ok(None); - }; - updated_exprs.push(PhysicalSortExpr { - expr: new_expr, - options: sort.options, - }); - } + let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())? + else { + return Ok(None); + }; Ok(Some(Arc::new( SortExec::new(updated_exprs, make_with_child(projection, self.input())?) @@ -1319,6 +1340,28 @@ impl ExecutionPlan for SortExec { .with_preserve_partitioning(self.preserve_partitioning()), ))) } + + fn gather_filters_for_pushdown( + &self, + phase: FilterPushdownPhase, + parent_filters: Vec>, + config: &datafusion_common::config::ConfigOptions, + ) -> Result { + if !matches!(phase, FilterPushdownPhase::Post) { + return FilterDescription::from_children(parent_filters, &self.children()); + } + + let mut child = + ChildFilterDescription::from_child(&parent_filters, self.input())?; + + if let Some(filter) = &self.filter { + if config.optimizer.enable_dynamic_filter_pushdown { + child = child.with_self_filter(filter.read().expr()); + } + } + + Ok(FilterDescription::new().with_child(child)) + } } #[cfg(test)] @@ -1333,16 +1376,16 @@ mod tests { use crate::execution_plan::Boundedness; use crate::expressions::col; use crate::test; - use crate::test::assert_is_pending; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::TestMemoryExec; + use crate::test::{assert_is_pending, make_partition}; use arrow::array::*; use arrow::compute::SortOptions; use arrow::datatypes::*; use datafusion_common::cast::as_primitive_array; use datafusion_common::test_util::batches_to_string; - use datafusion_common::{Result, ScalarValue}; + use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::RecordBatchStream; @@ -1373,9 +1416,9 @@ mod tests { impl SortedUnboundedExec { fn compute_properties(schema: SchemaRef) -> PlanProperties { let mut eq_properties = EquivalenceProperties::new(schema); - eq_properties.add_new_orderings(vec![LexOrdering::new(vec![ - PhysicalSortExpr::new_default(Arc::new(Column::new("c1", 0))), - ])]); + eq_properties.add_ordering([PhysicalSortExpr::new_default(Arc::new( + Column::new("c1", 0), + ))]); PlanProperties::new( eq_properties, Partitioning::UnknownPartitioning(1), @@ -1475,10 +1518,11 @@ mod tests { let schema = csv.schema(); let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), Arc::new(CoalescePartitionsExec::new(csv)), )); @@ -1486,7 +1530,6 @@ mod tests { assert_eq!(result.len(), 1); assert_eq!(result[0].num_rows(), 400); - assert_eq!( task_ctx.runtime_env().memory_pool.reserved(), 0, @@ -1521,10 +1564,11 @@ mod tests { let schema = input.schema(); let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), Arc::new(CoalescePartitionsExec::new(input)), )); @@ -1550,14 +1594,13 @@ mod tests { // bytes. We leave a little wiggle room for the actual numbers. assert!((3..=10).contains(&spill_count)); assert!((9000..=10000).contains(&spilled_rows)); - assert!((38000..=42000).contains(&spilled_bytes)); + assert!((38000..=44000).contains(&spilled_bytes)); let columns = result[0].columns(); let i = as_primitive_array::(&columns[0])?; assert_eq!(i.value(0), 0); assert_eq!(i.value(i.len() - 1), 81); - assert_eq!( task_ctx.runtime_env().memory_pool.reserved(), 0, @@ -1567,6 +1610,60 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_batch_reservation_error() -> Result<()> { + // Pick a memory limit and sort_spill_reservation that make the first batch reservation fail. + // These values assume that the ExternalSorter will reserve 800 bytes for the first batch. + let expected_batch_reservation = 800; + let merge_reservation: usize = 0; // Set to 0 for simplicity + let memory_limit: usize = expected_batch_reservation + merge_reservation - 1; // Just short of what we need + + let session_config = + SessionConfig::new().with_sort_spill_reservation_bytes(merge_reservation); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .build_arc()?; + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ); + + let plan = test::scan_partitioned(1); + + // Read the first record batch to assert that our memory limit and sort_spill_reservation + // settings trigger the test scenario. + { + let mut stream = plan.execute(0, Arc::clone(&task_ctx))?; + let first_batch = stream.next().await.unwrap()?; + let batch_reservation = get_reserved_byte_for_record_batch(&first_batch); + + assert_eq!(batch_reservation, expected_batch_reservation); + assert!(memory_limit < (merge_reservation + batch_reservation)); + } + + let sort_exec = Arc::new(SortExec::new( + [PhysicalSortExpr::new_default(col("i", &plan.schema())?)].into(), + plan, + )); + + let result = collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await; + + let err = result.unwrap_err(); + assert!( + matches!(err, DataFusionError::Context(..)), + "Assertion failed: expected a Context error, but got: {err:?}" + ); + + // Assert that the context error is wrapping a resources exhausted error. + assert!( + matches!(err.find_root(), DataFusionError::ResourcesExhausted(_)), + "Assertion failed: expected a ResourcesExhausted error, but got: {err:?}" + ); + + Ok(()) + } + #[tokio::test] async fn test_sort_spill_utf8_strings() -> Result<()> { let session_config = SessionConfig::new() @@ -1589,18 +1686,15 @@ mod tests { let schema = input.schema(); let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), Arc::new(CoalescePartitionsExec::new(input)), )); - let result = collect( - Arc::clone(&sort_exec) as Arc, - Arc::clone(&task_ctx), - ) - .await?; + let result = collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await?; let num_rows = result.iter().map(|batch| batch.num_rows()).sum::(); assert_eq!(num_rows, 20000); @@ -1689,20 +1783,18 @@ mod tests { let sort_exec = Arc::new( SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), Arc::new(CoalescePartitionsExec::new(csv)), ) .with_fetch(fetch), ); - let result = collect( - Arc::clone(&sort_exec) as Arc, - Arc::clone(&task_ctx), - ) - .await?; + let result = + collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await?; assert_eq!(result.len(), 1); let metrics = sort_exec.metrics().unwrap(); @@ -1732,16 +1824,16 @@ mod tests { let data: ArrayRef = Arc::new(vec![3, 2, 1].into_iter().map(Some).collect::()); - let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data]).unwrap(); + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?; let input = - TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None) - .unwrap(); + TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?; let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("field_name", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), input, )); @@ -1750,7 +1842,7 @@ mod tests { let expected_data: ArrayRef = Arc::new(vec![1, 2, 3].into_iter().map(Some).collect::()); let expected_batch = - RecordBatch::try_new(Arc::clone(&schema), vec![expected_data]).unwrap(); + RecordBatch::try_new(Arc::clone(&schema), vec![expected_data])?; // Data is correct assert_eq!(&vec![expected_batch], &result); @@ -1789,7 +1881,7 @@ mod tests { )?; let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -1804,7 +1896,8 @@ mod tests { nulls_first: false, }, }, - ]), + ] + .into(), TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?, )); @@ -1875,7 +1968,7 @@ mod tests { )?; let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -1890,7 +1983,8 @@ mod tests { nulls_first: false, }, }, - ]), + ] + .into(), TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?, )); @@ -1954,10 +2048,11 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), blocking_exec, )); @@ -1985,12 +2080,13 @@ mod tests { RecordBatch::try_new_with_options(Arc::clone(&schema), vec![], &options) .unwrap(); - let expressions = LexOrdering::new(vec![PhysicalSortExpr { + let expressions = [PhysicalSortExpr { expr: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), options: SortOptions::default(), - }]); + }] + .into(); - let result = sort_batch(&batch, expressions.as_ref(), None).unwrap(); + let result = sort_batch(&batch, &expressions, None).unwrap(); assert_eq!(result.num_rows(), 1); } @@ -2004,9 +2100,10 @@ mod tests { cache: SortedUnboundedExec::compute_properties(Arc::new(schema.clone())), }; let mut plan = SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr::new_default(Arc::new(Column::new( + [PhysicalSortExpr::new_default(Arc::new(Column::new( "c1", 0, - )))]), + )))] + .into(), Arc::new(source), ); plan = plan.with_fetch(Some(9)); @@ -2029,4 +2126,270 @@ mod tests { "#); Ok(()) } + + #[tokio::test] + async fn should_return_stream_with_batches_in_the_requested_size() -> Result<()> { + let batch_size = 100; + + let create_task_ctx = |_: &[RecordBatch]| { + TaskContext::default().with_session_config( + SessionConfig::new() + .with_batch_size(batch_size) + .with_sort_in_place_threshold_bytes(usize::MAX), + ) + }; + + // Smaller than batch size and require more than a single batch to get the requested batch size + test_sort_output_batch_size(10, batch_size / 4, create_task_ctx).await?; + + // Not evenly divisible by batch size + test_sort_output_batch_size(10, batch_size + 7, create_task_ctx).await?; + + // Evenly divisible by batch size and is larger than 2 output batches + test_sort_output_batch_size(10, batch_size * 3, create_task_ctx).await?; + + Ok(()) + } + + #[tokio::test] + async fn should_return_stream_with_batches_in_the_requested_size_when_sorting_in_place( + ) -> Result<()> { + let batch_size = 100; + + let create_task_ctx = |_: &[RecordBatch]| { + TaskContext::default().with_session_config( + SessionConfig::new() + .with_batch_size(batch_size) + .with_sort_in_place_threshold_bytes(usize::MAX - 1), + ) + }; + + // Smaller than batch size and require more than a single batch to get the requested batch size + { + let metrics = + test_sort_output_batch_size(10, batch_size / 4, create_task_ctx).await?; + + assert_eq!( + metrics.spill_count(), + Some(0), + "Expected no spills when sorting in place" + ); + } + + // Not evenly divisible by batch size + { + let metrics = + test_sort_output_batch_size(10, batch_size + 7, create_task_ctx).await?; + + assert_eq!( + metrics.spill_count(), + Some(0), + "Expected no spills when sorting in place" + ); + } + + // Evenly divisible by batch size and is larger than 2 output batches + { + let metrics = + test_sort_output_batch_size(10, batch_size * 3, create_task_ctx).await?; + + assert_eq!( + metrics.spill_count(), + Some(0), + "Expected no spills when sorting in place" + ); + } + + Ok(()) + } + + #[tokio::test] + async fn should_return_stream_with_batches_in_the_requested_size_when_having_a_single_batch( + ) -> Result<()> { + let batch_size = 100; + + let create_task_ctx = |_: &[RecordBatch]| { + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(batch_size)) + }; + + // Smaller than batch size and require more than a single batch to get the requested batch size + { + let metrics = test_sort_output_batch_size( + // Single batch + 1, + batch_size / 4, + create_task_ctx, + ) + .await?; + + assert_eq!( + metrics.spill_count(), + Some(0), + "Expected no spills when sorting in place" + ); + } + + // Not evenly divisible by batch size + { + let metrics = test_sort_output_batch_size( + // Single batch + 1, + batch_size + 7, + create_task_ctx, + ) + .await?; + + assert_eq!( + metrics.spill_count(), + Some(0), + "Expected no spills when sorting in place" + ); + } + + // Evenly divisible by batch size and is larger than 2 output batches + { + let metrics = test_sort_output_batch_size( + // Single batch + 1, + batch_size * 3, + create_task_ctx, + ) + .await?; + + assert_eq!( + metrics.spill_count(), + Some(0), + "Expected no spills when sorting in place" + ); + } + + Ok(()) + } + + #[tokio::test] + async fn should_return_stream_with_batches_in_the_requested_size_when_having_to_spill( + ) -> Result<()> { + let batch_size = 100; + + let create_task_ctx = |generated_batches: &[RecordBatch]| { + let batches_memory = generated_batches + .iter() + .map(|b| b.get_array_memory_size()) + .sum::(); + + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(batch_size) + // To make sure there is no in place sorting + .with_sort_in_place_threshold_bytes(1) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime( + RuntimeEnvBuilder::default() + .with_memory_limit(batches_memory, 1.0) + .build_arc() + .unwrap(), + ) + }; + + // Smaller than batch size and require more than a single batch to get the requested batch size + { + let metrics = + test_sort_output_batch_size(10, batch_size / 4, create_task_ctx).await?; + + assert_ne!(metrics.spill_count().unwrap(), 0, "expected to spill"); + } + + // Not evenly divisible by batch size + { + let metrics = + test_sort_output_batch_size(10, batch_size + 7, create_task_ctx).await?; + + assert_ne!(metrics.spill_count().unwrap(), 0, "expected to spill"); + } + + // Evenly divisible by batch size and is larger than 2 batches + { + let metrics = + test_sort_output_batch_size(10, batch_size * 3, create_task_ctx).await?; + + assert_ne!(metrics.spill_count().unwrap(), 0, "expected to spill"); + } + + Ok(()) + } + + async fn test_sort_output_batch_size( + number_of_batches: usize, + batch_size_to_generate: usize, + create_task_ctx: impl Fn(&[RecordBatch]) -> TaskContext, + ) -> Result { + let batches = (0..number_of_batches) + .map(|_| make_partition(batch_size_to_generate as i32)) + .collect::>(); + let task_ctx = create_task_ctx(batches.as_slice()); + + let expected_batch_size = task_ctx.session_config().batch_size(); + + let (mut output_batches, metrics) = + run_sort_on_input(task_ctx, "i", batches).await?; + + let last_batch = output_batches.pop().unwrap(); + + for batch in output_batches { + assert_eq!(batch.num_rows(), expected_batch_size); + } + + let mut last_expected_batch_size = + (batch_size_to_generate * number_of_batches) % expected_batch_size; + if last_expected_batch_size == 0 { + last_expected_batch_size = expected_batch_size; + } + assert_eq!(last_batch.num_rows(), last_expected_batch_size); + + Ok(metrics) + } + + async fn run_sort_on_input( + task_ctx: TaskContext, + order_by_col: &str, + batches: Vec, + ) -> Result<(Vec, MetricsSet)> { + let task_ctx = Arc::new(task_ctx); + + // let task_ctx = env. + let schema = batches[0].schema(); + let ordering: LexOrdering = [PhysicalSortExpr { + expr: col(order_by_col, &schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }] + .into(); + let sort_exec: Arc = Arc::new(SortExec::new( + ordering.clone(), + TestMemoryExec::try_new_exec(std::slice::from_ref(&batches), schema, None)?, + )); + + let sorted_batches = + collect(Arc::clone(&sort_exec), Arc::clone(&task_ctx)).await?; + + let metrics = sort_exec.metrics().expect("sort have metrics"); + + // assert output + { + let input_batches_concat = concat_batches(batches[0].schema_ref(), &batches)?; + let sorted_input_batch = sort_batch(&input_batches_concat, &ordering, None)?; + + let sorted_batches_concat = + concat_batches(sorted_batches[0].schema_ref(), &sorted_batches)?; + + assert_eq!(sorted_input_batch, sorted_batches_concat); + } + + Ok((sorted_batches, metrics)) + } } diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index b987dff36441d..3a94f156fa9b3 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::common::spawn_buffered; use crate::limit::LimitStream; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use crate::projection::{make_with_child, update_expr, ProjectionExec}; +use crate::projection::{make_with_child, update_ordering, ProjectionExec}; use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, @@ -33,9 +33,9 @@ use crate::{ use datafusion_common::{internal_err, Result}; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; -use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; +use crate::execution_plan::{EvaluationType, SchedulingType}; use log::{debug, trace}; /// Sort preserving merge execution plan @@ -144,7 +144,7 @@ impl SortPreservingMergeExec { /// Sort expressions pub fn expr(&self) -> &LexOrdering { - self.expr.as_ref() + &self.expr } /// Fetch @@ -158,15 +158,27 @@ impl SortPreservingMergeExec { input: &Arc, ordering: LexOrdering, ) -> PlanProperties { + let input_partitions = input.output_partitioning().partition_count(); + let (drive, scheduling) = if input_partitions > 1 { + (EvaluationType::Eager, SchedulingType::Cooperative) + } else { + ( + input.properties().evaluation_type, + input.properties().scheduling_type, + ) + }; + let mut eq_properties = input.equivalence_properties().clone(); eq_properties.clear_per_partition_constants(); - eq_properties.add_new_orderings(vec![ordering]); + eq_properties.add_ordering(ordering); PlanProperties::new( eq_properties, // Equivalence Properties Partitioning::UnknownPartitioning(1), // Output Partitioning input.pipeline_behavior(), // Pipeline Behavior input.boundedness(), // Boundedness ) + .with_evaluation_type(drive) + .with_scheduling_type(scheduling) } } @@ -186,15 +198,16 @@ impl DisplayAs for SortPreservingMergeExec { Ok(()) } DisplayFormatType::TreeRender => { + if let Some(fetch) = self.fetch { + writeln!(f, "limit={fetch}")?; + }; + for (i, e) in self.expr().iter().enumerate() { e.fmt_sql(f)?; if i != self.expr().len() - 1 { write!(f, ", ")?; } } - if let Some(fetch) = self.fetch { - writeln!(f, "limit={fetch}")?; - }; Ok(()) } @@ -240,8 +253,8 @@ impl ExecutionPlan for SortPreservingMergeExec { vec![false] } - fn required_input_ordering(&self) -> Vec> { - vec![Some(LexRequirement::from(self.expr.clone()))] + fn required_input_ordering(&self) -> Vec> { + vec![Some(OrderingRequirements::from(self.expr.clone()))] } fn maintains_input_order(&self) -> Vec { @@ -267,10 +280,7 @@ impl ExecutionPlan for SortPreservingMergeExec { partition: usize, context: Arc, ) -> Result { - trace!( - "Start SortPreservingMergeExec::execute for partition: {}", - partition - ); + trace!("Start SortPreservingMergeExec::execute for partition: {partition}"); if 0 != partition { return internal_err!( "SortPreservingMergeExec invalid partition {partition}" @@ -279,8 +289,7 @@ impl ExecutionPlan for SortPreservingMergeExec { let input_partitions = self.input.output_partitioning().partition_count(); trace!( - "Number of input partitions of SortPreservingMergeExec::execute: {}", - input_partitions + "Number of input partitions of SortPreservingMergeExec::execute: {input_partitions}" ); let schema = self.schema(); @@ -323,7 +332,7 @@ impl ExecutionPlan for SortPreservingMergeExec { let result = StreamingMergeBuilder::new() .with_streams(receivers) .with_schema(schema) - .with_expressions(self.expr.as_ref()) + .with_expressions(&self.expr) .with_metrics(BaselineMetrics::new(&self.metrics, partition)) .with_batch_size(context.session_config().batch_size()) .with_fetch(self.fetch) @@ -343,7 +352,11 @@ impl ExecutionPlan for SortPreservingMergeExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) + } + + fn partition_statistics(&self, _partition: Option) -> Result { + self.input.partition_statistics(None) } fn supports_limit_pushdown(&self) -> bool { @@ -362,17 +375,10 @@ impl ExecutionPlan for SortPreservingMergeExec { return Ok(None); } - let mut updated_exprs = LexOrdering::default(); - for sort in self.expr() { - let Some(updated_expr) = update_expr(&sort.expr, projection.expr(), false)? - else { - return Ok(None); - }; - updated_exprs.push(PhysicalSortExpr { - expr: updated_expr, - options: sort.options, - }); - } + let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())? + else { + return Ok(None); + }; Ok(Some(Arc::new( SortPreservingMergeExec::new( @@ -386,10 +392,11 @@ impl ExecutionPlan for SortPreservingMergeExec { #[cfg(test)] mod tests { + use std::collections::HashSet; use std::fmt::Formatter; use std::pin::Pin; use std::sync::Mutex; - use std::task::{Context, Poll}; + use std::task::{ready, Context, Poll, Waker}; use std::time::Duration; use super::*; @@ -413,7 +420,7 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::test_util::batches_to_string; - use datafusion_common::{assert_batches_eq, assert_contains, DataFusionError}; + use datafusion_common::{assert_batches_eq, exec_err}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; @@ -421,8 +428,8 @@ mod tests { use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; - use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use futures::{FutureExt, Stream, StreamExt}; use insta::assert_snapshot; use tokio::time::timeout; @@ -449,24 +456,25 @@ mod tests { let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size])); let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size])); - let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?; let rbs = (0..1024).map(|_| rb.clone()).collect::>(); let schema = rb.schema(); - let sort = LexOrdering::new(vec![ + let sort = [ PhysicalSortExpr { - expr: col("b", &schema).unwrap(), + expr: col("b", &schema)?, options: Default::default(), }, PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + expr: col("c", &schema)?, options: Default::default(), }, - ]); + ] + .into(); let repartition_exec = RepartitionExec::try_new( - TestMemoryExec::try_new_exec(&[rbs], schema, None).unwrap(), + TestMemoryExec::try_new_exec(&[rbs], schema, None)?, Partitioning::RoundRobinBatch(2), )?; let coalesce_batches_exec = @@ -485,7 +493,7 @@ mod tests { async fn test_round_robin_tie_breaker_success() -> Result<()> { let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?; let spm = generate_spm_for_round_robin_tie_breaker(true)?; - let _collected = collect(spm, task_ctx).await.unwrap(); + let _collected = collect(spm, task_ctx).await?; Ok(()) } @@ -550,30 +558,6 @@ mod tests { .await; } - #[tokio::test] - async fn test_merge_no_exprs() { - let task_ctx = Arc::new(TaskContext::default()); - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); - let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap(); - - let schema = batch.schema(); - let sort = LexOrdering::default(); // no sort expressions - let exec = TestMemoryExec::try_new_exec( - &[vec![batch.clone()], vec![batch]], - schema, - None, - ) - .unwrap(); - - let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); - - let res = collect(merge, task_ctx).await.unwrap_err(); - assert_contains!( - res.to_string(), - "Internal error: Sort expressions cannot be empty for streaming merge" - ); - } - #[tokio::test] async fn test_merge_some_overlap() { let task_ctx = Arc::new(TaskContext::default()); @@ -741,7 +725,7 @@ mod tests { context: Arc, ) { let schema = partitions[0][0].schema(); - let sort = LexOrdering::new(vec![ + let sort = [ PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: Default::default(), @@ -750,7 +734,8 @@ mod tests { expr: col("c", &schema).unwrap(), options: Default::default(), }, - ]); + ] + .into(); let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -798,13 +783,14 @@ mod tests { let csv = test::scan_partitioned(partitions); let schema = csv.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("i", &schema).unwrap(), + let sort: LexOrdering = [PhysicalSortExpr { + expr: col("i", &schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); + }] + .into(); let basic = basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await; @@ -859,17 +845,18 @@ mod tests { let sorted = basic_sort(csv, sort, context).await; let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect(); - Ok(TestMemoryExec::try_new_exec(&split, sorted.schema(), None).unwrap()) + TestMemoryExec::try_new_exec(&split, sorted.schema(), None).map(|e| e as _) } #[tokio::test] async fn test_partition_sort_streaming_input() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = make_partition(11).schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("i", &schema).unwrap(), + let sort: LexOrdering = [PhysicalSortExpr { + expr: col("i", &schema)?, options: Default::default(), - }]); + }] + .into(); let input = sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx)) @@ -881,12 +868,9 @@ mod tests { assert_eq!(basic.num_rows(), 1200); assert_eq!(partition.num_rows(), 1200); - let basic = arrow::util::pretty::pretty_format_batches(&[basic]) - .unwrap() - .to_string(); - let partition = arrow::util::pretty::pretty_format_batches(&[partition]) - .unwrap() - .to_string(); + let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string(); + let partition = + arrow::util::pretty::pretty_format_batches(&[partition])?.to_string(); assert_eq!(basic, partition); @@ -896,10 +880,11 @@ mod tests { #[tokio::test] async fn test_partition_sort_streaming_input_output() -> Result<()> { let schema = make_partition(11).schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("i", &schema).unwrap(), + let sort: LexOrdering = [PhysicalSortExpr { + expr: col("i", &schema)?, options: Default::default(), - }]); + }] + .into(); // Test streaming with default batch size let task_ctx = Arc::new(TaskContext::default()); @@ -914,19 +899,14 @@ mod tests { let task_ctx = Arc::new(task_ctx); let merge = Arc::new(SortPreservingMergeExec::new(sort, input)); - let merged = collect(merge, task_ctx).await.unwrap(); + let merged = collect(merge, task_ctx).await?; assert_eq!(merged.len(), 53); - assert_eq!(basic.num_rows(), 1200); assert_eq!(merged.iter().map(|x| x.num_rows()).sum::(), 1200); - let basic = arrow::util::pretty::pretty_format_batches(&[basic]) - .unwrap() - .to_string(); - let partition = arrow::util::pretty::pretty_format_batches(merged.as_slice()) - .unwrap() - .to_string(); + let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string(); + let partition = arrow::util::pretty::pretty_format_batches(&merged)?.to_string(); assert_eq!(basic, partition); @@ -971,7 +951,7 @@ mod tests { let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); let schema = b1.schema(); - let sort = LexOrdering::new(vec![ + let sort = [ PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: SortOptions { @@ -986,7 +966,8 @@ mod tests { nulls_first: false, }, }, - ]); + ] + .into(); let exec = TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -1020,13 +1001,14 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); let schema = batch.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(2))); @@ -1052,13 +1034,14 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); let schema = batch.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -1082,10 +1065,11 @@ mod tests { async fn test_async() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = make_partition(11).schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort: LexOrdering = [PhysicalSortExpr { expr: col("i", &schema).unwrap(), options: SortOptions::default(), - }]); + }] + .into(); let batches = sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx)) @@ -1121,7 +1105,7 @@ mod tests { let merge_stream = StreamingMergeBuilder::new() .with_streams(streams) .with_schema(batches.schema()) - .with_expressions(sort.as_ref()) + .with_expressions(&sort) .with_metrics(BaselineMetrics::new(&metrics, 0)) .with_batch_size(task_ctx.session_config().batch_size()) .with_fetch(fetch) @@ -1161,10 +1145,11 @@ mod tests { let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); let schema = b1.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: Default::default(), - }]); + }] + .into(); let exec = TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -1220,10 +1205,11 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2)); let refs = blocking_exec.refs(); let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), blocking_exec, )); @@ -1268,13 +1254,14 @@ mod tests { let schema = partitions[0][0].schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("value", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let exec = TestMemoryExec::try_new_exec(&partitions, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -1313,13 +1300,50 @@ mod tests { "#); } + #[derive(Debug)] + struct CongestionState { + wakers: Vec, + unpolled_partitions: HashSet, + } + + #[derive(Debug)] + struct Congestion { + congestion_state: Mutex, + } + + impl Congestion { + fn new(partition_count: usize) -> Self { + Congestion { + congestion_state: Mutex::new(CongestionState { + wakers: vec![], + unpolled_partitions: (0usize..partition_count).collect(), + }), + } + } + + fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> { + let mut state = self.congestion_state.lock().unwrap(); + + state.unpolled_partitions.remove(&partition); + + if state.unpolled_partitions.is_empty() { + state.wakers.iter().for_each(|w| w.wake_by_ref()); + state.wakers.clear(); + Poll::Ready(()) + } else { + state.wakers.push(cx.waker().clone()); + Poll::Pending + } + } + } + /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st /// partition is exhausted from the start, and if it is polled more than one, it panics. #[derive(Debug, Clone)] struct CongestedExec { schema: Schema, cache: PlanProperties, - congestion_cleared: Arc>, + congestion: Arc, } impl CongestedExec { @@ -1331,10 +1355,11 @@ mod tests { .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc) .collect::>(); let mut eq_properties = EquivalenceProperties::new(schema); - eq_properties.add_new_orderings(vec![columns - .iter() - .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))) - .collect::()]); + eq_properties.add_ordering( + columns + .iter() + .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))), + ); PlanProperties::new( eq_properties, Partitioning::Hash(columns, 3), @@ -1373,7 +1398,7 @@ mod tests { Ok(Box::pin(CongestedStream { schema: Arc::new(self.schema.clone()), none_polled_once: false, - congestion_cleared: Arc::clone(&self.congestion_cleared), + congestion: Arc::clone(&self.congestion), partition, })) } @@ -1400,7 +1425,7 @@ mod tests { pub struct CongestedStream { schema: SchemaRef, none_polled_once: bool, - congestion_cleared: Arc>, + congestion: Arc, partition: usize, } @@ -1408,31 +1433,22 @@ mod tests { type Item = Result; fn poll_next( mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, + cx: &mut Context<'_>, ) -> Poll> { match self.partition { 0 => { + let _ = self.congestion.check_congested(self.partition, cx); if self.none_polled_once { - panic!("Exhausted stream is polled more than one") + panic!("Exhausted stream is polled more than once") } else { self.none_polled_once = true; Poll::Ready(None) } } - 1 => { - let cleared = self.congestion_cleared.lock().unwrap(); - if *cleared { - Poll::Ready(None) - } else { - Poll::Pending - } - } - 2 => { - let mut cleared = self.congestion_cleared.lock().unwrap(); - *cleared = true; + _ => { + ready!(self.congestion.check_congested(self.partition, cx)); Poll::Ready(None) } - _ => unreachable!(), } } } @@ -1447,15 +1463,22 @@ mod tests { async fn test_spm_congestion() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]); + let properties = CongestedExec::compute_properties(Arc::new(schema.clone())); + let &partition_count = match properties.output_partitioning() { + Partitioning::RoundRobinBatch(partitions) => partitions, + Partitioning::Hash(_, partitions) => partitions, + Partitioning::UnknownPartitioning(partitions) => partitions, + }; let source = CongestedExec { schema: schema.clone(), - cache: CongestedExec::compute_properties(Arc::new(schema.clone())), - congestion_cleared: Arc::new(Mutex::new(false)), + cache: properties, + congestion: Arc::new(Congestion::new(partition_count)), }; let spm = SortPreservingMergeExec::new( - LexOrdering::new(vec![PhysicalSortExpr::new_default(Arc::new(Column::new( + [PhysicalSortExpr::new_default(Arc::new(Column::new( "c1", 0, - )))]), + )))] + .into(), Arc::new(source), ); let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx)); @@ -1464,12 +1487,8 @@ mod tests { match result { Ok(Ok(Ok(_batches))) => Ok(()), Ok(Ok(Err(e))) => Err(e), - Ok(Err(_)) => Err(DataFusionError::Execution( - "SortPreservingMerge task panicked or was cancelled".to_string(), - )), - Err(_) => Err(DataFusionError::Execution( - "SortPreservingMerge caused a deadlock".to_string(), - )), + Ok(Err(_)) => exec_err!("SortPreservingMerge task panicked or was cancelled"), + Err(_) => exec_err!("SortPreservingMerge caused a deadlock"), } } } diff --git a/datafusion/physical-plan/src/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs index e029c60b285b6..97dd1761b14cf 100644 --- a/datafusion/physical-plan/src/sorts/stream.rs +++ b/datafusion/physical-plan/src/sorts/stream.rs @@ -21,8 +21,8 @@ use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::array::Array; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; -use arrow::row::{RowConverter, SortField}; -use datafusion_common::Result; +use arrow::row::{RowConverter, Rows, SortField}; +use datafusion_common::{internal_datafusion_err, Result}; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_physical_expr_common::sort_expr::LexOrdering; use futures::stream::{Fuse, StreamExt}; @@ -76,8 +76,40 @@ impl FusedStreams { } } +/// A pair of `Arc` that can be reused +#[derive(Debug)] +struct ReusableRows { + // inner[stream_idx] holds a two Arcs: + // at start of a new poll + // .0 is the rows from the previous poll (at start), + // .1 is the one that is being written to + // at end of a poll, .0 will be swapped with .1, + inner: Vec<[Option>; 2]>, +} + +impl ReusableRows { + // return a Rows for writing, + // does not clone if the existing rows can be reused + fn take_next(&mut self, stream_idx: usize) -> Result { + Arc::try_unwrap(self.inner[stream_idx][1].take().unwrap()).map_err(|_| { + internal_datafusion_err!( + "Rows from RowCursorStream is still in use by consumer" + ) + }) + } + // save the Rows + fn save(&mut self, stream_idx: usize, rows: Arc) { + self.inner[stream_idx][1] = Some(Arc::clone(&rows)); + // swap the current with the previous one, so that the next poll can reuse the Rows from the previous poll + let [a, b] = &mut self.inner[stream_idx]; + std::mem::swap(a, b); + } +} + /// A [`PartitionedStream`] that wraps a set of [`SendableRecordBatchStream`] /// and computes [`RowValues`] based on the provided [`PhysicalSortExpr`] +/// Note: the stream returns an error if the consumer buffers more than one RowValues (i.e. holds on to two RowValues +/// from the same partition at the same time). #[derive(Debug)] pub struct RowCursorStream { /// Converter to convert output of physical expressions @@ -88,6 +120,9 @@ pub struct RowCursorStream { streams: FusedStreams, /// Tracks the memory used by `converter` reservation: MemoryReservation, + /// Allocated rows for each partition, we keep two to allow for buffering one + /// in the consumer of the stream + rows: ReusableRows, } impl RowCursorStream { @@ -105,26 +140,48 @@ impl RowCursorStream { }) .collect::>>()?; - let streams = streams.into_iter().map(|s| s.fuse()).collect(); + let streams: Vec<_> = streams.into_iter().map(|s| s.fuse()).collect(); let converter = RowConverter::new(sort_fields)?; + let mut rows = Vec::with_capacity(streams.len()); + for _ in &streams { + // Initialize each stream with an empty Rows + rows.push([ + Some(Arc::new(converter.empty_rows(0, 0))), + Some(Arc::new(converter.empty_rows(0, 0))), + ]); + } Ok(Self { converter, reservation, column_expressions: expressions.iter().map(|x| Arc::clone(&x.expr)).collect(), streams: FusedStreams(streams), + rows: ReusableRows { inner: rows }, }) } - fn convert_batch(&mut self, batch: &RecordBatch) -> Result { + fn convert_batch( + &mut self, + batch: &RecordBatch, + stream_idx: usize, + ) -> Result { let cols = self .column_expressions .iter() .map(|expr| expr.evaluate(batch)?.into_array(batch.num_rows())) .collect::>>()?; - let rows = self.converter.convert_columns(&cols)?; + // At this point, ownership should of this Rows should be unique + let mut rows = self.rows.take_next(stream_idx)?; + + rows.clear(); + + self.converter.append(&mut rows, &cols)?; self.reservation.try_resize(self.converter.size())?; + let rows = Arc::new(rows); + + self.rows.save(stream_idx, Arc::clone(&rows)); + // track the memory in the newly created Rows. let mut rows_reservation = self.reservation.new_empty(); rows_reservation.try_grow(rows.size())?; @@ -146,7 +203,7 @@ impl PartitionedStream for RowCursorStream { ) -> Poll> { Poll::Ready(ready!(self.streams.poll_next(cx, stream_idx)).map(|r| { r.and_then(|batch| { - let cursor = self.convert_batch(&batch)?; + let cursor = self.convert_batch(&batch, stream_idx)?; Ok((cursor, batch)) }) })) diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs index 3f022ec6095ae..191b135753412 100644 --- a/datafusion/physical-plan/src/sorts/streaming_merge.rs +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -19,16 +19,22 @@ //! This is an order-preserving merge. use crate::metrics::BaselineMetrics; +use crate::sorts::multi_level_merge::MultiLevelMergeBuilder; use crate::sorts::{ merge::SortPreservingMergeStream, stream::{FieldCursorStream, RowCursorStream}, }; -use crate::SendableRecordBatchStream; +use crate::{SendableRecordBatchStream, SpillManager}; use arrow::array::*; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::{internal_err, Result}; -use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::{ + human_readable_size, MemoryConsumer, MemoryPool, MemoryReservation, + UnboundedMemoryPool, +}; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use std::sync::Arc; macro_rules! primitive_merge_helper { ($t:ty, $($v:ident),+) => { @@ -52,10 +58,31 @@ macro_rules! merge_helper { }}; } +pub struct SortedSpillFile { + pub file: RefCountedTempFile, + + /// how much memory the largest memory batch is taking + pub max_record_batch_memory: usize, +} + +impl std::fmt::Debug for SortedSpillFile { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "SortedSpillFile({:?}) takes {}", + self.file.path(), + human_readable_size(self.max_record_batch_memory) + ) + } +} + +#[derive(Default)] pub struct StreamingMergeBuilder<'a> { streams: Vec, + sorted_spill_files: Vec, + spill_manager: Option, schema: Option, - expressions: &'a LexOrdering, + expressions: Option<&'a LexOrdering>, metrics: Option, batch_size: Option, fetch: Option, @@ -63,21 +90,6 @@ pub struct StreamingMergeBuilder<'a> { enable_round_robin_tie_breaker: bool, } -impl Default for StreamingMergeBuilder<'_> { - fn default() -> Self { - Self { - streams: vec![], - schema: None, - expressions: LexOrdering::empty(), - metrics: None, - batch_size: None, - fetch: None, - reservation: None, - enable_round_robin_tie_breaker: false, - } - } -} - impl<'a> StreamingMergeBuilder<'a> { pub fn new() -> Self { Self { @@ -91,13 +103,26 @@ impl<'a> StreamingMergeBuilder<'a> { self } + pub fn with_sorted_spill_files( + mut self, + sorted_spill_files: Vec, + ) -> Self { + self.sorted_spill_files = sorted_spill_files; + self + } + + pub fn with_spill_manager(mut self, spill_manager: SpillManager) -> Self { + self.spill_manager = Some(spill_manager); + self + } + pub fn with_schema(mut self, schema: SchemaRef) -> Self { self.schema = Some(schema); self } pub fn with_expressions(mut self, expressions: &'a LexOrdering) -> Self { - self.expressions = expressions; + self.expressions = Some(expressions); self } @@ -133,9 +158,22 @@ impl<'a> StreamingMergeBuilder<'a> { self } + /// Bypass the mempool and avoid using the memory reservation. + /// + /// This is not marked as `pub` because it is not recommended to use this method + pub(super) fn with_bypass_mempool(self) -> Self { + let mem_pool: Arc = Arc::new(UnboundedMemoryPool::default()); + + self.with_reservation( + MemoryConsumer::new("merge stream mock memory").register(&mem_pool), + ) + } + pub fn build(self) -> Result { let Self { streams, + sorted_spill_files, + spill_manager, schema, metrics, batch_size, @@ -145,21 +183,40 @@ impl<'a> StreamingMergeBuilder<'a> { enable_round_robin_tie_breaker, } = self; - // Early return if streams or expressions are empty - let checks = [ - ( - streams.is_empty(), - "Streams cannot be empty for streaming merge", - ), - ( - expressions.is_empty(), - "Sort expressions cannot be empty for streaming merge", - ), - ]; - - if let Some((_, error_message)) = checks.iter().find(|(condition, _)| *condition) - { - return internal_err!("{}", error_message); + // Early return if expressions are empty: + let Some(expressions) = expressions else { + return internal_err!("Sort expressions cannot be empty for streaming merge"); + }; + + if !sorted_spill_files.is_empty() { + // Unwrapping mandatory fields + let schema = schema.expect("Schema cannot be empty for streaming merge"); + let metrics = metrics.expect("Metrics cannot be empty for streaming merge"); + let batch_size = + batch_size.expect("Batch size cannot be empty for streaming merge"); + let reservation = + reservation.expect("Reservation cannot be empty for streaming merge"); + + return Ok(MultiLevelMergeBuilder::new( + spill_manager.expect("spill_manager should exist"), + schema, + sorted_spill_files, + streams, + expressions.clone(), + metrics, + batch_size, + reservation, + fetch, + enable_round_robin_tie_breaker, + ) + .create_spillable_merge_stream()); + } + + // Early return if streams are empty: + if streams.is_empty() { + return internal_err!( + "Streams/sorted spill files cannot be empty for streaming merge" + ); } // Unwrapping mandatory fields diff --git a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs index 8c1ed77559078..14917e23b7921 100644 --- a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs +++ b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs @@ -49,7 +49,12 @@ impl InProgressSpillFile { } } - /// Appends a `RecordBatch` to the file, initializing the writer if necessary. + /// Appends a `RecordBatch` to the spill file, initializing the writer if necessary. + /// + /// # Errors + /// - Returns an error if the file is not active (has been finalized) + /// - Returns an error if appending would exceed the disk usage limit configured + /// by `max_temp_directory_size` in `DiskManager` pub fn append_batch(&mut self, batch: &RecordBatch) -> Result<()> { if self.in_progress_file.is_none() { return Err(exec_datafusion_err!( @@ -62,6 +67,7 @@ impl InProgressSpillFile { self.writer = Some(IPCStreamWriter::new( in_progress_file.path(), schema.as_ref(), + self.spill_writer.compression, )?); // Update metrics @@ -69,10 +75,14 @@ impl InProgressSpillFile { } } if let Some(writer) = &mut self.writer { - let (spilled_rows, spilled_bytes) = writer.write(batch)?; + let (spilled_rows, _) = writer.write(batch)?; + if let Some(in_progress_file) = &mut self.in_progress_file { + in_progress_file.update_disk_usage()?; + } else { + unreachable!() // Already checked inside current function + } // Update metrics - self.spill_writer.metrics.spilled_bytes.add(spilled_bytes); self.spill_writer.metrics.spilled_rows.add(spilled_rows); } Ok(()) @@ -87,6 +97,14 @@ impl InProgressSpillFile { return Ok(None); } + // Since spill files are append-only, add the file size to spilled_bytes + if let Some(in_progress_file) = &mut self.in_progress_file { + // Since writer.finish() writes continuation marker and message length at the end + in_progress_file.update_disk_usage()?; + let size = in_progress_file.current_disk_usage(); + self.spill_writer.metrics.spilled_bytes.add(size as usize); + } + Ok(self.in_progress_file.take()) } } diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index 88bf7953daeb4..5b9a91e781b16 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -23,25 +23,197 @@ pub(crate) mod spill_manager; use std::fs::File; use std::io::BufReader; use std::path::{Path, PathBuf}; +use std::pin::Pin; use std::ptr::NonNull; +use std::sync::Arc; +use std::task::{Context, Poll}; -use arrow::array::ArrayData; +use arrow::array::{layout, ArrayData, BufferSpec}; use arrow::datatypes::{Schema, SchemaRef}; -use arrow::ipc::{reader::StreamReader, writer::StreamWriter}; +use arrow::ipc::{ + reader::StreamReader, + writer::{IpcWriteOptions, StreamWriter}, + MetadataVersion, +}; use arrow::record_batch::RecordBatch; -use tokio::sync::mpsc::Sender; -use datafusion_common::{exec_datafusion_err, HashSet, Result}; +use datafusion_common::config::SpillCompression; +use datafusion_common::{exec_datafusion_err, DataFusionError, HashSet, Result}; +use datafusion_common_runtime::SpawnedTask; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::RecordBatchStream; +use futures::{FutureExt as _, Stream}; +use log::warn; -fn read_spill(sender: Sender>, path: &Path) -> Result<()> { - let file = BufReader::new(File::open(path)?); - let reader = StreamReader::try_new(file, None)?; - for batch in reader { - sender - .blocking_send(batch.map_err(Into::into)) - .map_err(|e| exec_datafusion_err!("{e}"))?; +/// Stream that reads spill files from disk where each batch is read in a spawned blocking task +/// It will read one batch at a time and will not do any buffering, to buffer data use [`crate::common::spawn_buffered`] +/// +/// A simpler solution would be spawning a long-running blocking task for each +/// file read (instead of each batch). This approach does not work because when +/// the number of concurrent reads exceeds the Tokio thread pool limit, +/// deadlocks can occur and block progress. +struct SpillReaderStream { + schema: SchemaRef, + state: SpillReaderStreamState, + /// Maximum memory size observed among spilling sorted record batches. + /// This is used for validation purposes during reading each RecordBatch from spill. + /// For context on why this value is recorded and validated, + /// see `physical_plan/sort/multi_level_merge.rs`. + max_record_batch_memory: Option, +} + +// Small margin allowed to accommodate slight memory accounting variation +const SPILL_BATCH_MEMORY_MARGIN: usize = 4096; + +/// When we poll for the next batch, we will get back both the batch and the reader, +/// so we can call `next` again. +type NextRecordBatchResult = Result<(StreamReader>, Option)>; + +enum SpillReaderStreamState { + /// Initial state: the stream was not initialized yet + /// and the file was not opened + Uninitialized(RefCountedTempFile), + + /// A read is in progress in a spawned blocking task for which we hold the handle. + ReadInProgress(SpawnedTask), + + /// A read has finished and we wait for being polled again in order to start reading the next batch. + Waiting(StreamReader>), + + /// The stream has finished, successfully or not. + Done, +} + +impl SpillReaderStream { + fn new( + schema: SchemaRef, + spill_file: RefCountedTempFile, + max_record_batch_memory: Option, + ) -> Self { + Self { + schema, + state: SpillReaderStreamState::Uninitialized(spill_file), + max_record_batch_memory, + } + } + + fn poll_next_inner( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + match &mut self.state { + SpillReaderStreamState::Uninitialized(_) => { + // Temporarily replace with `Done` to be able to pass the file to the task. + let SpillReaderStreamState::Uninitialized(spill_file) = + std::mem::replace(&mut self.state, SpillReaderStreamState::Done) + else { + unreachable!() + }; + + let task = SpawnedTask::spawn_blocking(move || { + let file = BufReader::new(File::open(spill_file.path())?); + // SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications + // with validated schemas and buffers. Skip redundant validation during read + // to speedup read operation. This is safe for DataFusion as input guaranteed to be correct when written. + let mut reader = unsafe { + StreamReader::try_new(file, None)?.with_skip_validation(true) + }; + + let next_batch = reader.next().transpose()?; + + Ok((reader, next_batch)) + }); + + self.state = SpillReaderStreamState::ReadInProgress(task); + + // Poll again immediately so the inner task is polled and the waker is + // registered. + self.poll_next_inner(cx) + } + + SpillReaderStreamState::ReadInProgress(task) => { + let result = futures::ready!(task.poll_unpin(cx)) + .unwrap_or_else(|err| Err(DataFusionError::External(Box::new(err)))); + + match result { + Ok((reader, batch)) => { + match batch { + Some(batch) => { + if let Some(max_record_batch_memory) = + self.max_record_batch_memory + { + let actual_size = + get_record_batch_memory_size(&batch); + if actual_size + > max_record_batch_memory + + SPILL_BATCH_MEMORY_MARGIN + { + warn!( + "Record batch memory usage ({actual_size} bytes) exceeds the expected limit ({max_record_batch_memory} bytes) \n\ + by more than the allowed tolerance ({SPILL_BATCH_MEMORY_MARGIN} bytes).\n\ + This likely indicates a bug in memory accounting during spilling.\n\ + Please report this issue in https://github.com/apache/datafusion/issues/17340." + ); + } + } + self.state = SpillReaderStreamState::Waiting(reader); + + Poll::Ready(Some(Ok(batch))) + } + None => { + // Stream is done + self.state = SpillReaderStreamState::Done; + + Poll::Ready(None) + } + } + } + Err(err) => { + self.state = SpillReaderStreamState::Done; + + Poll::Ready(Some(Err(err))) + } + } + } + + SpillReaderStreamState::Waiting(_) => { + // Temporarily replace with `Done` to be able to pass the file to the task. + let SpillReaderStreamState::Waiting(mut reader) = + std::mem::replace(&mut self.state, SpillReaderStreamState::Done) + else { + unreachable!() + }; + + let task = SpawnedTask::spawn_blocking(move || { + let next_batch = reader.next().transpose()?; + + Ok((reader, next_batch)) + }); + + self.state = SpillReaderStreamState::ReadInProgress(task); + + // Poll again immediately so the inner task is polled and the waker is + // registered. + self.poll_next_inner(cx) + } + + SpillReaderStreamState::Done => Poll::Ready(None), + } + } +} + +impl Stream for SpillReaderStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().poll_next_inner(cx) + } +} + +impl RecordBatchStream for SpillReaderStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) } - Ok(()) } /// Spill the `RecordBatch` to disk as smaller batches @@ -58,7 +230,8 @@ pub fn spill_record_batch_by_size( ) -> Result<()> { let mut offset = 0; let total_rows = batch.num_rows(); - let mut writer = IPCStreamWriter::new(&path, schema.as_ref())?; + let mut writer = + IPCStreamWriter::new(&path, schema.as_ref(), SpillCompression::Uncompressed)?; while offset < total_rows { let length = std::cmp::min(total_rows - offset, batch_size_rows); @@ -156,15 +329,32 @@ struct IPCStreamWriter { impl IPCStreamWriter { /// Create new writer - pub fn new(path: &Path, schema: &Schema) -> Result { + pub fn new( + path: &Path, + schema: &Schema, + compression_type: SpillCompression, + ) -> Result { let file = File::create(path).map_err(|e| { - exec_datafusion_err!("Failed to create partition file at {path:?}: {e:?}") + exec_datafusion_err!("(Hint: you may increase the file descriptor limit with shell command 'ulimit -n 4096') Failed to create partition file at {path:?}: {e:?}") })?; + + let metadata_version = MetadataVersion::V5; + // Depending on the schema, some array types such as StringViewArray require larger (16 byte in this case) alignment. + // If the actual buffer layout after IPC read does not satisfy the alignment requirement, + // Arrow ArrayBuilder will copy the buffer into a newly allocated, properly aligned buffer. + // This copying may lead to memory blowup during IPC read due to duplicated buffers. + // To avoid this, we compute the maximum required alignment based on the schema and configure the IPCStreamWriter accordingly. + let alignment = get_max_alignment_for_schema(schema); + let mut write_options = + IpcWriteOptions::try_new(alignment, false, metadata_version)?; + write_options = write_options.try_with_compression(compression_type.into())?; + + let writer = StreamWriter::try_new_with_options(file, schema, write_options)?; Ok(Self { num_batches: 0, num_rows: 0, num_bytes: 0, - writer: StreamWriter::try_new(file, schema)?, + writer, }) } @@ -187,6 +377,29 @@ impl IPCStreamWriter { } } +// Returns the maximum byte alignment required by any field in the schema (>= 8), derived from Arrow buffer layouts. +fn get_max_alignment_for_schema(schema: &Schema) -> usize { + let minimum_alignment = 8; + let mut max_alignment = minimum_alignment; + for field in schema.fields() { + let layout = layout(field.data_type()); + let required_alignment = layout + .buffers + .iter() + .map(|buffer_spec| { + if let BufferSpec::FixedWidth { alignment, .. } = buffer_spec { + *alignment + } else { + minimum_alignment + } + }) + .max() + .unwrap_or(minimum_alignment); + max_alignment = std::cmp::max(max_alignment, required_alignment); + } + max_alignment +} + #[cfg(test)] mod tests { use super::in_progress_spill_file::InProgressSpillFile; @@ -196,12 +409,13 @@ mod tests { use crate::metrics::SpillMetrics; use crate::spill::spill_manager::SpillManager; use crate::test::build_table_i32; - use arrow::array::{Float64Array, Int32Array, ListArray, StringArray}; + use arrow::array::{ArrayRef, Float64Array, Int32Array, ListArray, StringArray}; use arrow::compute::cast; use arrow::datatypes::{DataType, Field, Int32Type, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::runtime_env::RuntimeEnv; + use futures::StreamExt as _; use std::sync::Arc; @@ -234,7 +448,7 @@ mod tests { let spilled_rows = spill_manager.metrics.spilled_rows.value(); assert_eq!(spilled_rows, num_rows); - let stream = spill_manager.read_spill_as_stream(spill_file)?; + let stream = spill_manager.read_spill_as_stream(spill_file, None)?; assert_eq!(stream.schema(), schema); let batches = collect(stream).await?; @@ -298,7 +512,7 @@ mod tests { let spilled_rows = spill_manager.metrics.spilled_rows.value(); assert_eq!(spilled_rows, num_rows); - let stream = spill_manager.read_spill_as_stream(spill_file)?; + let stream = spill_manager.read_spill_as_stream(spill_file, None)?; assert_eq!(stream.schema(), dict_schema); let batches = collect(stream).await?; assert_eq!(batches.len(), 2); @@ -319,12 +533,17 @@ mod tests { let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); - let spill_file = spill_manager - .spill_record_batch_by_size(&batch1, "Test Spill", 1)? + let (spill_file, max_batch_mem) = spill_manager + .spill_record_batch_by_size_and_return_max_batch_memory( + &batch1, + "Test Spill", + 1, + )? .unwrap(); assert!(spill_file.path().exists()); + assert!(max_batch_mem > 0); - let stream = spill_manager.read_spill_as_stream(spill_file)?; + let stream = spill_manager.read_spill_as_stream(spill_file, None)?; assert_eq!(stream.schema(), schema); let batches = collect(stream).await?; @@ -333,6 +552,113 @@ mod tests { Ok(()) } + fn build_compressible_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, true), + ])); + + let a: ArrayRef = Arc::new(StringArray::from_iter_values(std::iter::repeat_n( + "repeated", 100, + ))); + let b: ArrayRef = Arc::new(Int32Array::from(vec![1; 100])); + let c: ArrayRef = Arc::new(Int32Array::from(vec![2; 100])); + + RecordBatch::try_new(schema, vec![a, b, c]).unwrap() + } + + async fn validate( + spill_manager: &SpillManager, + spill_file: RefCountedTempFile, + num_rows: usize, + schema: SchemaRef, + batch_count: usize, + ) -> Result<()> { + let spilled_rows = spill_manager.metrics.spilled_rows.value(); + assert_eq!(spilled_rows, num_rows); + + let stream = spill_manager.read_spill_as_stream(spill_file, None)?; + assert_eq!(stream.schema(), schema); + + let batches = collect(stream).await?; + assert_eq!(batches.len(), batch_count); + + Ok(()) + } + + #[tokio::test] + async fn test_spill_compression() -> Result<()> { + let batch = build_compressible_batch(); + let num_rows = batch.num_rows(); + let schema = batch.schema(); + let batch_count = 1; + let batches = [batch]; + + // Construct SpillManager + let env = Arc::new(RuntimeEnv::default()); + let uncompressed_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let lz4_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let zstd_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let uncompressed_spill_manager = SpillManager::new( + Arc::clone(&env), + uncompressed_metrics, + Arc::clone(&schema), + ); + let lz4_spill_manager = + SpillManager::new(Arc::clone(&env), lz4_metrics, Arc::clone(&schema)) + .with_compression_type(SpillCompression::Lz4Frame); + let zstd_spill_manager = + SpillManager::new(env, zstd_metrics, Arc::clone(&schema)) + .with_compression_type(SpillCompression::Zstd); + let uncompressed_spill_file = uncompressed_spill_manager + .spill_record_batch_and_finish(&batches, "Test")? + .unwrap(); + let lz4_spill_file = lz4_spill_manager + .spill_record_batch_and_finish(&batches, "Lz4_Test")? + .unwrap(); + let zstd_spill_file = zstd_spill_manager + .spill_record_batch_and_finish(&batches, "ZSTD_Test")? + .unwrap(); + assert!(uncompressed_spill_file.path().exists()); + assert!(lz4_spill_file.path().exists()); + assert!(zstd_spill_file.path().exists()); + + let lz4_spill_size = std::fs::metadata(lz4_spill_file.path())?.len(); + let zstd_spill_size = std::fs::metadata(zstd_spill_file.path())?.len(); + let uncompressed_spill_size = + std::fs::metadata(uncompressed_spill_file.path())?.len(); + + assert!(uncompressed_spill_size > lz4_spill_size); + assert!(uncompressed_spill_size > zstd_spill_size); + + validate( + &lz4_spill_manager, + lz4_spill_file, + num_rows, + Arc::clone(&schema), + batch_count, + ) + .await?; + validate( + &zstd_spill_manager, + zstd_spill_file, + num_rows, + Arc::clone(&schema), + batch_count, + ) + .await?; + validate( + &uncompressed_spill_manager, + uncompressed_spill_file, + num_rows, + schema, + batch_count, + ) + .await?; + Ok(()) + } + #[test] fn test_get_record_batch_memory_size() { // Create a simple record batch with two columns @@ -457,7 +783,7 @@ mod tests { .unwrap(); let size = get_record_batch_memory_size(&batch); - assert_eq!(size, 8320); + assert_eq!(size, 8208); } // ==== Spill manager tests ==== @@ -547,12 +873,13 @@ mod tests { Arc::new(StringArray::from(vec!["d", "e", "f"])), ], )?; - + // After appending each batch, spilled_rows should increase, while spill_file_count and + // spilled_bytes remain the same (spilled_bytes is updated only after finish() is called) in_progress_file.append_batch(&batch1)?; - verify_metrics(&in_progress_file, 1, 356, 3)?; + verify_metrics(&in_progress_file, 1, 0, 3)?; in_progress_file.append_batch(&batch2)?; - verify_metrics(&in_progress_file, 1, 712, 6)?; + verify_metrics(&in_progress_file, 1, 0, 6)?; let completed_file = in_progress_file.finish()?; assert!(completed_file.is_some()); @@ -587,7 +914,7 @@ mod tests { let completed_file = spill_manager.spill_record_batch_and_finish(&[], "Test")?; assert!(completed_file.is_none()); - // Test write empty batch with interface `spill_record_batch_by_size()` + // Test write empty batch with interface `spill_record_batch_by_size_and_return_max_batch_memory()` let empty_batch = RecordBatch::try_new( Arc::clone(&schema), vec![ @@ -595,10 +922,69 @@ mod tests { Arc::new(StringArray::from(Vec::>::new())), ], )?; - let completed_file = - spill_manager.spill_record_batch_by_size(&empty_batch, "Test", 1)?; + let completed_file = spill_manager + .spill_record_batch_by_size_and_return_max_batch_memory( + &empty_batch, + "Test", + 1, + )?; assert!(completed_file.is_none()); Ok(()) } + + #[test] + fn test_reading_more_spills_than_tokio_blocking_threads() -> Result<()> { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .max_blocking_threads(1) + .build() + .unwrap() + .block_on(async { + let batch = build_table_i32( + ("a2", &vec![0, 1, 2]), + ("b2", &vec![3, 4, 5]), + ("c2", &vec![4, 5, 6]), + ); + + let schema = batch.schema(); + + // Construct SpillManager + let env = Arc::new(RuntimeEnv::default()); + let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); + let batches: [_; 10] = std::array::from_fn(|_| batch.clone()); + + let spill_file_1 = spill_manager + .spill_record_batch_and_finish(&batches, "Test1")? + .unwrap(); + let spill_file_2 = spill_manager + .spill_record_batch_and_finish(&batches, "Test2")? + .unwrap(); + + let mut stream_1 = + spill_manager.read_spill_as_stream(spill_file_1, None)?; + let mut stream_2 = + spill_manager.read_spill_as_stream(spill_file_2, None)?; + stream_1.next().await; + stream_2.next().await; + + Ok(()) + }) + } + + #[test] + fn test_alignment_for_schema() -> Result<()> { + let schema = Schema::new(vec![Field::new("strings", DataType::Utf8View, false)]); + let alignment = get_max_alignment_for_schema(&schema); + assert_eq!(alignment, 16); + + let schema = Schema::new(vec![ + Field::new("int32", DataType::Int32, false), + Field::new("int64", DataType::Int64, false), + ]); + let alignment = get_max_alignment_for_schema(&schema); + assert_eq!(alignment, 8); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 4a8e293323f02..cc39102d89819 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -17,20 +17,19 @@ //! Define the `SpillManager` struct, which is responsible for reading and writing `RecordBatch`es to raw files based on the provided configurations. -use std::sync::Arc; - +use arrow::array::StringViewArray; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_execution::runtime_env::RuntimeEnv; +use std::sync::Arc; -use datafusion_common::Result; +use datafusion_common::{config::SpillCompression, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::SendableRecordBatchStream; -use crate::metrics::SpillMetrics; -use crate::stream::RecordBatchReceiverStream; - -use super::{in_progress_spill_file::InProgressSpillFile, read_spill}; +use super::{in_progress_spill_file::InProgressSpillFile, SpillReaderStream}; +use crate::coop::cooperative; +use crate::{common::spawn_buffered, metrics::SpillMetrics}; /// The `SpillManager` is responsible for the following tasks: /// - Reading and writing `RecordBatch`es to raw files based on the provided configurations. @@ -45,7 +44,8 @@ pub struct SpillManager { schema: SchemaRef, /// Number of batches to buffer in memory during disk reads batch_read_buffer_capacity: usize, - // TODO: Add general-purpose compression options + /// general-purpose compression options + pub(crate) compression: SpillCompression, } impl SpillManager { @@ -55,9 +55,23 @@ impl SpillManager { metrics, schema, batch_read_buffer_capacity: 2, + compression: SpillCompression::default(), } } + pub fn with_batch_read_buffer_capacity( + mut self, + batch_read_buffer_capacity: usize, + ) -> Self { + self.batch_read_buffer_capacity = batch_read_buffer_capacity; + self + } + + pub fn with_compression_type(mut self, spill_compression: SpillCompression) -> Self { + self.compression = spill_compression; + self + } + /// Creates a temporary file for in-progress operations, returning an error /// message if file creation fails. The file can be used to append batches /// incrementally and then finish the file when done. @@ -73,7 +87,10 @@ impl SpillManager { /// intended to incrementally write in-memory batches into the same spill file, /// use [`Self::create_in_progress_file`] instead. /// None is returned if no batches are spilled. - #[allow(dead_code)] // TODO: remove after change SMJ to use SpillManager + /// + /// # Errors + /// - Returns an error if spilling would exceed the disk usage limit configured + /// by `max_temp_directory_size` in `DiskManager` pub fn spill_record_batch_and_finish( &self, batches: &[RecordBatch], @@ -90,13 +107,16 @@ impl SpillManager { /// Refer to the documentation for [`Self::spill_record_batch_and_finish`]. This method /// additionally spills the `RecordBatch` into smaller batches, divided by `row_limit`. - #[allow(dead_code)] // TODO: remove after change aggregate to use SpillManager - pub fn spill_record_batch_by_size( + /// + /// # Errors + /// - Returns an error if spilling would exceed the disk usage limit configured + /// by `max_temp_directory_size` in `DiskManager` + pub(crate) fn spill_record_batch_by_size_and_return_max_batch_memory( &self, batch: &RecordBatch, request_description: &str, row_limit: usize, - ) -> Result> { + ) -> Result> { let total_rows = batch.num_rows(); let mut batches = Vec::new(); let mut offset = 0; @@ -109,8 +129,43 @@ impl SpillManager { offset += length; } - // Spill the sliced batches to disk - self.spill_record_batch_and_finish(&batches, request_description) + let mut in_progress_file = self.create_in_progress_file(request_description)?; + + let mut max_record_batch_size = 0; + + for batch in batches { + in_progress_file.append_batch(&batch)?; + + max_record_batch_size = max_record_batch_size.max(batch.get_sliced_size()?); + } + + let file = in_progress_file.finish()?; + + Ok(file.map(|f| (f, max_record_batch_size))) + } + + /// Spill a stream of `RecordBatch`es to disk and return the spill file and the size of the largest batch in memory + pub(crate) async fn spill_record_batch_stream_and_return_max_batch_memory( + &self, + stream: &mut SendableRecordBatchStream, + request_description: &str, + ) -> Result> { + use futures::StreamExt; + + let mut in_progress_file = self.create_in_progress_file(request_description)?; + + let mut max_record_batch_size = 0; + + while let Some(batch) = stream.next().await { + let batch = batch?; + in_progress_file.append_batch(&batch)?; + + max_record_batch_size = max_record_batch_size.max(batch.get_sliced_size()?); + } + + let file = in_progress_file.finish()?; + + Ok(file.map(|f| (f, max_record_batch_size))) } /// Reads a spill file as a stream. The file must be created by the current `SpillManager`. @@ -119,15 +174,108 @@ impl SpillManager { pub fn read_spill_as_stream( &self, spill_file_path: RefCountedTempFile, + max_record_batch_memory: Option, ) -> Result { - let mut builder = RecordBatchReceiverStream::builder( + let stream = Box::pin(cooperative(SpillReaderStream::new( Arc::clone(&self.schema), - self.batch_read_buffer_capacity, + spill_file_path, + max_record_batch_memory, + ))); + + Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) + } +} + +pub(crate) trait GetSlicedSize { + /// Returns the size of the `RecordBatch` when sliced. + /// Note: if multiple arrays or even a single array share the same data buffers, we may double count each buffer. + /// Therefore, make sure we call gc() or organize_stringview_arrays() before using this method. + fn get_sliced_size(&self) -> Result; +} + +impl GetSlicedSize for RecordBatch { + fn get_sliced_size(&self) -> Result { + let mut total = 0; + for array in self.columns() { + let data = array.to_data(); + total += data.get_slice_memory_size()?; + + // While StringViewArray holds large data buffer for non inlined string, the Arrow layout (BufferSpec) + // does not include any data buffers. Currently, ArrayData::get_slice_memory_size() + // under-counts memory size by accounting only views buffer although data buffer is cloned during slice() + // + // Therefore, we manually add the sum of the lengths used by all non inlined views + // on top of the sliced size for views buffer. This matches the intended semantics of + // "bytes needed if we materialized exactly this slice into fresh buffers". + // This is a workaround until https://github.com/apache/arrow-rs/issues/8230 + if let Some(sv) = array.as_any().downcast_ref::() { + for buffer in sv.data_buffers() { + total += buffer.capacity(); + } + } + } + Ok(total) + } +} + +#[cfg(test)] +mod tests { + use crate::spill::{get_record_batch_memory_size, spill_manager::GetSlicedSize}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::{ + array::{ArrayRef, StringViewArray}, + record_batch::RecordBatch, + }; + use datafusion_common::Result; + use std::sync::Arc; + + #[test] + fn check_sliced_size_for_string_view_array() -> Result<()> { + let array_length = 50; + let short_len = 8; + let long_len = 25; + + // Build StringViewArray that includes both inline strings and non inlined strings + let strings: Vec = (0..array_length) + .map(|i| { + if i % 2 == 0 { + "a".repeat(short_len) + } else { + "b".repeat(long_len) + } + }) + .collect(); + + let string_array = StringViewArray::from(strings); + let array_ref: ArrayRef = Arc::new(string_array); + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "strings", + DataType::Utf8View, + false, + )])), + vec![array_ref], + ) + .unwrap(); + + // We did not slice the batch, so these two memory size should be equal + assert_eq!( + batch.get_sliced_size().unwrap(), + get_record_batch_memory_size(&batch) ); - let sender = builder.tx(); - builder.spawn_blocking(move || read_spill(sender, spill_file_path.path())); + // Slice the batch into half + let half_batch = batch.slice(0, array_length / 2); + // Now sliced_size is smaller because the views buffer is sliced + assert!( + half_batch.get_sliced_size().unwrap() + < get_record_batch_memory_size(&half_batch) + ); + let data = arrow::array::Array::to_data(&half_batch.column(0)); + let views_sliced_size = data.get_slice_memory_size()?; + // The sliced size should be larger than sliced views buffer size + assert!(views_sliced_size < half_batch.get_sliced_size().unwrap()); - Ok(builder.build()) + Ok(()) } } diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index 338ac7d048a33..100a6a7ffcc08 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -22,7 +22,9 @@ use std::sync::Arc; use std::task::Context; use std::task::Poll; -use super::metrics::BaselineMetrics; +#[cfg(test)] +use super::metrics::ExecutionPlanMetricsSet; +use super::metrics::{BaselineMetrics, SplitMetrics}; use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; use crate::displayable; @@ -31,10 +33,12 @@ use datafusion_common::{exec_err, Result}; use datafusion_common_runtime::JoinSet; use datafusion_execution::TaskContext; +use futures::ready; use futures::stream::BoxStream; use futures::{Future, Stream, StreamExt}; use log::debug; use pin_project_lite::pin_project; +use tokio::runtime::Handle; use tokio::sync::mpsc::{Receiver, Sender}; /// Creates a stream from a collection of producing tasks, routing panics to the stream. @@ -81,6 +85,15 @@ impl ReceiverStreamBuilder { self.join_set.spawn(task); } + /// Same as [`Self::spawn`] but it spawns the task on the provided runtime + pub fn spawn_on(&mut self, task: F, handle: &Handle) + where + F: Future>, + F: Send + 'static, + { + self.join_set.spawn_on(task, handle); + } + /// Spawn a blocking task that will be aborted if this builder (or the stream /// built from it) are dropped. /// @@ -94,6 +107,15 @@ impl ReceiverStreamBuilder { self.join_set.spawn_blocking(f); } + /// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime + pub fn spawn_blocking_on(&mut self, f: F, handle: &Handle) + where + F: FnOnce() -> Result<()>, + F: Send + 'static, + { + self.join_set.spawn_blocking_on(f, handle); + } + /// Create a stream of all data written to `tx` pub fn build(self) -> BoxStream<'static, Result> { let Self { @@ -245,6 +267,15 @@ impl RecordBatchReceiverStreamBuilder { self.inner.spawn(task) } + /// Same as [`Self::spawn`] but it spawns the task on the provided runtime. + pub fn spawn_on(&mut self, task: F, handle: &Handle) + where + F: Future>, + F: Send + 'static, + { + self.inner.spawn_on(task, handle) + } + /// Spawn a blocking task tied to the builder and stream. /// /// # Drop / Cancel Behavior @@ -272,6 +303,15 @@ impl RecordBatchReceiverStreamBuilder { self.inner.spawn_blocking(f) } + /// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime. + pub fn spawn_blocking_on(&mut self, f: F, handle: &Handle) + where + F: FnOnce() -> Result<()>, + F: Send + 'static, + { + self.inner.spawn_blocking_on(f, handle) + } + /// Runs the `partition` of the `input` ExecutionPlan on the /// tokio thread pool and writes its outputs to this stream /// @@ -522,6 +562,138 @@ impl Stream for ObservedStream { } } +pin_project! { + /// Stream wrapper that splits large [`RecordBatch`]es into smaller batches. + /// + /// This ensures upstream operators receive batches no larger than + /// `batch_size`, which can improve parallelism when data sources + /// generate very large batches. + /// + /// # Fields + /// + /// - `current_batch`: The batch currently being split, if any + /// - `offset`: Index of the next row to split from `current_batch`. + /// This tracks our position within the current batch being split. + /// + /// # Invariants + /// + /// - `offset` is always ≤ `current_batch.num_rows()` when `current_batch` is `Some` + /// - When `current_batch` is `None`, `offset` is always 0 + /// - `batch_size` is always > 0 +pub struct BatchSplitStream { + #[pin] + input: SendableRecordBatchStream, + schema: SchemaRef, + batch_size: usize, + metrics: SplitMetrics, + current_batch: Option, + offset: usize, + } +} + +impl BatchSplitStream { + /// Create a new [`BatchSplitStream`] + pub fn new( + input: SendableRecordBatchStream, + batch_size: usize, + metrics: SplitMetrics, + ) -> Self { + let schema = input.schema(); + Self { + input, + schema, + batch_size, + metrics, + current_batch: None, + offset: 0, + } + } + + /// Attempt to produce the next sliced batch from the current batch. + /// + /// Returns `Some(batch)` if a slice was produced, `None` if the current batch + /// is exhausted and we need to poll upstream for more data. + fn next_sliced_batch(&mut self) -> Option> { + let batch = self.current_batch.take()?; + + // Assert slice boundary safety - offset should never exceed batch size + debug_assert!( + self.offset <= batch.num_rows(), + "Offset {} exceeds batch size {}", + self.offset, + batch.num_rows() + ); + + let remaining = batch.num_rows() - self.offset; + let to_take = remaining.min(self.batch_size); + let out = batch.slice(self.offset, to_take); + + self.metrics.batches_split.add(1); + self.offset += to_take; + if self.offset < batch.num_rows() { + // More data remains in this batch, store it back + self.current_batch = Some(batch); + } else { + // Batch is exhausted, reset offset + // Note: current_batch is already None since we took it at the start + self.offset = 0; + } + Some(Ok(out)) + } + + /// Poll the upstream input for the next batch. + /// + /// Returns the appropriate `Poll` result based on upstream state. + /// Small batches are passed through directly, large batches are stored + /// for slicing and return the first slice immediately. + fn poll_upstream( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + match ready!(self.input.as_mut().poll_next(cx)) { + Some(Ok(batch)) => { + if batch.num_rows() <= self.batch_size { + // Small batch, pass through directly + Poll::Ready(Some(Ok(batch))) + } else { + // Large batch, store for slicing and return first slice + self.current_batch = Some(batch); + // Immediately produce the first slice + match self.next_sliced_batch() { + Some(result) => Poll::Ready(Some(result)), + None => Poll::Ready(None), // Should not happen + } + } + } + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } +} + +impl Stream for BatchSplitStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + // First, try to produce a slice from the current batch + if let Some(result) = self.next_sliced_batch() { + return Poll::Ready(Some(result)); + } + + // No current batch or current batch exhausted, poll upstream + self.poll_upstream(cx) + } +} + +impl RecordBatchStream for BatchSplitStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + #[cfg(test)] mod test { use super::*; @@ -616,6 +788,44 @@ mod test { assert!(stream.next().await.is_none()); } + #[tokio::test] + async fn batch_split_stream_basic_functionality() { + use arrow::array::{Int32Array, RecordBatch}; + use futures::stream::{self, StreamExt}; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a large batch that should be split + let large_batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from((0..2000).collect::>()))], + ) + .unwrap(); + + // Create a stream with the large batch + let input_stream = stream::iter(vec![Ok(large_batch)]); + let adapter = RecordBatchStreamAdapter::new(Arc::clone(&schema), input_stream); + let batch_stream = Box::pin(adapter) as SendableRecordBatchStream; + + // Create a BatchSplitStream with batch_size = 500 + let metrics = ExecutionPlanMetricsSet::new(); + let split_metrics = SplitMetrics::new(&metrics, 0); + let mut split_stream = BatchSplitStream::new(batch_stream, 500, split_metrics); + + let mut total_rows = 0; + let mut batch_count = 0; + + while let Some(result) = split_stream.next().await { + let batch = result.unwrap(); + assert!(batch.num_rows() <= 500, "Batch size should not exceed 500"); + total_rows += batch.num_rows(); + batch_count += 1; + } + + assert_eq!(total_rows, 2000, "All rows should be preserved"); + assert_eq!(batch_count, 4, "Should have 4 batches of 500 rows each"); + } + /// Consumes all the input's partitions into a /// RecordBatchReceiverStream and runs it to completion /// @@ -649,4 +859,67 @@ mod test { ); } } + + #[test] + fn record_batch_receiver_stream_builder_spawn_on_runtime() { + let tokio_runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + let mut builder = + RecordBatchReceiverStreamBuilder::new(Arc::new(Schema::empty()), 10); + + let tx1 = builder.tx(); + builder.spawn_on( + async move { + tx1.send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty())))) + .await + .unwrap(); + + Ok(()) + }, + tokio_runtime.handle(), + ); + + let tx2 = builder.tx(); + builder.spawn_blocking_on( + move || { + tx2.blocking_send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty())))) + .unwrap(); + + Ok(()) + }, + tokio_runtime.handle(), + ); + + let mut stream = builder.build(); + + let mut number_of_batches = 0; + + loop { + let poll = stream.poll_next_unpin(&mut Context::from_waker( + futures::task::noop_waker_ref(), + )); + + match poll { + Poll::Ready(None) => { + break; + } + Poll::Ready(Some(Ok(batch))) => { + number_of_batches += 1; + assert_eq!(batch.num_rows(), 0); + } + Poll::Ready(Some(Err(e))) => panic!("Unexpected error: {e}"), + Poll::Pending => { + continue; + } + } + } + + assert_eq!( + number_of_batches, 2, + "Should have received exactly one empty batch" + ); + } } diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index 18c472a7e1874..f9a7feb9e726e 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -22,12 +22,13 @@ use std::fmt::Debug; use std::sync::Arc; use super::{DisplayAs, DisplayFormatType, PlanProperties}; +use crate::coop::make_cooperative; use crate::display::{display_orderings, ProjectSchemaDisplay}; -use crate::execution_plan::{Boundedness, EmissionType}; +use crate::execution_plan::{Boundedness, EmissionType, SchedulingType}; use crate::limit::LimitStream; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::projection::{ - all_alias_free_columns, new_projections_for_columns, update_expr, ProjectionExec, + all_alias_free_columns, new_projections_for_columns, update_ordering, ProjectionExec, }; use crate::stream::RecordBatchStreamAdapter; use crate::{ExecutionPlan, Partitioning, SendableRecordBatchStream}; @@ -35,7 +36,7 @@ use crate::{ExecutionPlan, Partitioning, SendableRecordBatchStream}; use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use async_trait::async_trait; use futures::stream::StreamExt; @@ -99,7 +100,7 @@ impl StreamingTableExec { projected_output_ordering.into_iter().collect::>(); let cache = Self::compute_properties( Arc::clone(&projected_schema), - &projected_output_ordering, + projected_output_ordering.clone(), &partitions, infinite, ); @@ -146,7 +147,7 @@ impl StreamingTableExec { /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. fn compute_properties( schema: SchemaRef, - orderings: &[LexOrdering], + orderings: Vec, partitions: &[Arc], infinite: bool, ) -> PlanProperties { @@ -168,6 +169,7 @@ impl StreamingTableExec { EmissionType::Incremental, boundedness, ) + .with_scheduling_type(SchedulingType::Cooperative) } } @@ -262,7 +264,7 @@ impl ExecutionPlan for StreamingTableExec { partition: usize, ctx: Arc, ) -> Result { - let stream = self.partitions[partition].execute(ctx); + let stream = self.partitions[partition].execute(Arc::clone(&ctx)); let projected_stream = match self.projection.clone() { Some(projection) => Box::pin(RecordBatchStreamAdapter::new( Arc::clone(&self.projected_schema), @@ -272,16 +274,13 @@ impl ExecutionPlan for StreamingTableExec { )), None => stream, }; + let stream = make_cooperative(projected_stream); + Ok(match self.limit { - None => projected_stream, + None => stream, Some(fetch) => { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - Box::pin(LimitStream::new( - projected_stream, - 0, - Some(fetch), - baseline_metrics, - )) + Box::pin(LimitStream::new(stream, 0, Some(fetch), baseline_metrics)) } }) } @@ -300,26 +299,17 @@ impl ExecutionPlan for StreamingTableExec { let streaming_table_projections = self.projection().as_ref().map(|i| i.as_ref().to_vec()); let new_projections = new_projections_for_columns( - projection, + projection.expr(), &streaming_table_projections - .unwrap_or((0..self.schema().fields().len()).collect()), + .unwrap_or_else(|| (0..self.schema().fields().len()).collect()), ); let mut lex_orderings = vec![]; - for lex_ordering in self.projected_output_ordering().into_iter() { - let mut orderings = LexOrdering::default(); - for order in lex_ordering { - let Some(new_ordering) = - update_expr(&order.expr, projection.expr(), false)? - else { - return Ok(None); - }; - orderings.push(PhysicalSortExpr { - expr: new_ordering, - options: order.options, - }); - } - lex_orderings.push(orderings); + for ordering in self.projected_output_ordering().into_iter() { + let Some(ordering) = update_ordering(ordering, projection.expr())? else { + return Ok(None); + }; + lex_orderings.push(ordering); } StreamingTableExec::try_new( diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index a2dc1d778436a..349f9955b6914 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -40,10 +40,12 @@ use datafusion_common::{ config::ConfigOptions, internal_err, project_schema, Result, Statistics, }; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{ - equivalence::ProjectionMapping, expressions::Column, utils::collect_columns, - EquivalenceProperties, LexOrdering, Partitioning, +use datafusion_physical_expr::equivalence::{ + OrderingEquivalenceClass, ProjectionMapping, }; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, Partitioning}; use futures::{Future, FutureExt}; @@ -87,9 +89,7 @@ impl DisplayAs for TestMemoryExec { let output_ordering = self .sort_information .first() - .map(|output_ordering| { - format!(", output_ordering={}", output_ordering) - }) + .map(|output_ordering| format!(", output_ordering={output_ordering}")) .unwrap_or_default(); let eq_properties = self.eq_properties(); @@ -97,12 +97,12 @@ impl DisplayAs for TestMemoryExec { let constraints = if constraints.is_empty() { String::new() } else { - format!(", {}", constraints) + format!(", {constraints}") }; let limit = self .fetch - .map_or(String::new(), |limit| format!(", fetch={}", limit)); + .map_or(String::new(), |limit| format!(", fetch={limit}")); if self.show_sizes { write!( f, @@ -131,7 +131,7 @@ impl ExecutionPlan for TestMemoryExec { } fn as_any(&self) -> &dyn Any { - unimplemented!() + self } fn properties(&self) -> &PlanProperties { @@ -170,7 +170,15 @@ impl ExecutionPlan for TestMemoryExec { } fn statistics(&self) -> Result { - self.statistics() + self.statistics_inner() + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + Ok(Statistics::new_unknown(&self.schema)) + } else { + self.statistics_inner() + } } fn fetch(&self) -> Option { @@ -210,11 +218,11 @@ impl TestMemoryExec { fn eq_properties(&self) -> EquivalenceProperties { EquivalenceProperties::new_with_orderings( Arc::clone(&self.projected_schema), - self.sort_information.as_slice(), + self.sort_information.clone(), ) } - fn statistics(&self) -> Result { + fn statistics_inner(&self) -> Result { Ok(common::compute_record_batch_statistics( &self.partitions, &self.schema, @@ -234,7 +242,7 @@ impl TestMemoryExec { cache: PlanProperties::new( EquivalenceProperties::new_with_orderings( Arc::clone(&projected_schema), - vec![].as_slice(), + Vec::::new(), ), Partitioning::UnknownPartitioning(partitions.len()), EmissionType::Incremental, @@ -292,7 +300,7 @@ impl TestMemoryExec { } /// refer to `try_with_sort_information` at MemorySourceConfig for more information. - /// https://github.com/apache/datafusion/tree/main/datafusion/datasource/src/memory.rs + /// pub fn try_with_sort_information( mut self, mut sort_information: Vec, @@ -318,24 +326,21 @@ impl TestMemoryExec { // If there is a projection on the source, we also need to project orderings if let Some(projection) = &self.projection { + let base_schema = self.original_schema(); + let proj_exprs = projection.iter().map(|idx| { + let name = base_schema.field(*idx).name(); + (Arc::new(Column::new(name, *idx)) as _, name.to_string()) + }); + let projection_mapping = + ProjectionMapping::try_new(proj_exprs, &base_schema)?; let base_eqp = EquivalenceProperties::new_with_orderings( - self.original_schema(), - &sort_information, + Arc::clone(&base_schema), + sort_information, ); - let proj_exprs = projection - .iter() - .map(|idx| { - let base_schema = self.original_schema(); - let name = base_schema.field(*idx).name(); - (Arc::new(Column::new(name, *idx)) as _, name.to_string()) - }) - .collect::>(); - let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?; - sort_information = base_eqp - .project(&projection_mapping, Arc::clone(&self.projected_schema)) - .into_oeq_class() - .into_inner(); + let proj_eqp = + base_eqp.project(&projection_mapping, Arc::clone(&self.projected_schema)); + let oeq_class: OrderingEquivalenceClass = proj_eqp.into(); + sort_information = oeq_class.into(); } self.sort_information = sort_information; @@ -450,7 +455,7 @@ pub fn make_partition_utf8(sz: i32) -> RecordBatch { let seq_start = 0; let seq_end = sz; let values = (seq_start..seq_end) - .map(|i| format!("test_long_string_that_is_roughly_42_bytes_{}", i)) + .map(|i| format!("test_long_string_that_is_roughly_42_bytes_{i}")) .collect::>(); let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Utf8, true)])); let mut string_array = arrow::array::StringArray::from(values); @@ -517,3 +522,33 @@ impl PartitionStream for TestPartitionStream { )) } } + +#[cfg(test)] +macro_rules! assert_join_metrics { + ($metrics:expr, $expected_rows:expr) => { + assert_eq!($metrics.output_rows().unwrap(), $expected_rows); + + let elapsed_compute = $metrics + .elapsed_compute() + .expect("did not find elapsed_compute metric"); + let join_time = $metrics + .sum_by_name("join_time") + .expect("did not find join_time metric") + .as_usize(); + let build_time = $metrics + .sum_by_name("build_time") + .expect("did not find build_time metric") + .as_usize(); + // ensure join_time and build_time are considered in elapsed_compute + assert!( + join_time + build_time <= elapsed_compute, + "join_time ({}) + build_time ({}) = {} was <= elapsed_compute = {}", + join_time, + build_time, + join_time + build_time, + elapsed_compute + ); + }; +} +#[cfg(test)] +pub(crate) use assert_join_metrics; diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index d0a0d25779cc8..12ffca871f073 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -255,6 +255,13 @@ impl ExecutionPlan for MockExec { // Panics if one of the batches is an error fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema)); + } let data: Result> = self .data .iter() @@ -405,6 +412,13 @@ impl ExecutionPlan for BarrierExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema)); + } Ok(common::compute_record_batch_statistics( &self.data, &self.schema, @@ -590,6 +604,14 @@ impl ExecutionPlan for StatisticsExec { fn statistics(&self) -> Result { Ok(self.stats.clone()) } + + fn partition_statistics(&self, partition: Option) -> Result { + Ok(if partition.is_some() { + Statistics::new_unknown(&self.schema) + } else { + self.stats.clone() + }) + } } /// Execution plan that emits streams that block forever. diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 85de1eefce2e4..9435de1cc4488 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -18,25 +18,33 @@ //! TopK: Combination of Sort / LIMIT use arrow::{ - compute::interleave, + array::{Array, AsArray}, + compute::{interleave_record_batch, prep_null_mask_filter, FilterBuilder}, row::{RowConverter, Rows, SortField}, }; +use datafusion_expr::{ColumnarValue, Operator}; use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder}; use crate::spill::get_record_batch_memory_size; use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}; -use arrow::array::{Array, ArrayRef, RecordBatch}; + +use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::SchemaRef; -use datafusion_common::HashMap; -use datafusion_common::Result; +use datafusion_common::{ + internal_datafusion_err, internal_err, HashMap, Result, ScalarValue, +}; use datafusion_execution::{ memory_pool::{MemoryConsumer, MemoryReservation}, runtime_env::RuntimeEnv, }; -use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr::{ + expressions::{is_not_null, is_null, lit, BinaryExpr, DynamicFilterPhysicalExpr}, + PhysicalExpr, +}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use parking_lot::RwLock; /// Global TopK /// @@ -70,6 +78,25 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering; /// The same answer can be produced by simply keeping track of the top /// K=3 elements, reducing the total amount of required buffer memory. /// +/// # Partial Sort Optimization +/// +/// This implementation additionally optimizes queries where the input is already +/// partially sorted by a common prefix of the requested ordering. Once the top K +/// heap is full, if subsequent rows are guaranteed to be strictly greater (in sort +/// order) on this prefix than the largest row currently stored, the operator +/// safely terminates early. +/// +/// ## Example +/// +/// For input sorted by `(day DESC)`, but not by `timestamp`, a query such as: +/// +/// ```sql +/// SELECT day, timestamp FROM sensor ORDER BY day DESC, timestamp DESC LIMIT 10; +/// ``` +/// +/// can terminate scanning early once sufficient rows from the latest days have been +/// collected, skipping older data. +/// /// # Structure /// /// This operator tracks the top K items using a `TopKHeap`. @@ -83,50 +110,101 @@ pub struct TopK { /// The target number of rows for output batches batch_size: usize, /// sort expressions - expr: Arc<[PhysicalSortExpr]>, + expr: LexOrdering, /// row converter, for sort keys row_converter: RowConverter, /// scratch space for converting rows scratch_rows: Rows, /// stores the top k values and their sort key values, in order heap: TopKHeap, + /// row converter, for common keys between the sort keys and the input ordering + common_sort_prefix_converter: Option, + /// Common sort prefix between the input and the sort expressions to allow early exit optimization + common_sort_prefix: Arc<[PhysicalSortExpr]>, + /// Filter matching the state of the `TopK` heap used for dynamic filter pushdown + filter: Arc>, + /// If true, indicates that all rows of subsequent batches are guaranteed + /// to be greater (by byte order, after row conversion) than the top K, + /// which means the top K won't change and the computation can be finished early. + pub(crate) finished: bool, +} + +#[derive(Debug, Clone)] +pub struct TopKDynamicFilters { + /// The current *global* threshold for the dynamic filter. + /// This is shared across all partitions and is updated by any of them. + /// Stored as row bytes for efficient comparison. + threshold_row: Option>, + /// The expression used to evaluate the dynamic filter + /// Only updated when lock held for the duration of the update + expr: Arc, +} + +impl TopKDynamicFilters { + /// Create a new `TopKDynamicFilters` with the given expression + pub fn new(expr: Arc) -> Self { + Self { + threshold_row: None, + expr, + } + } + + pub fn expr(&self) -> Arc { + Arc::clone(&self.expr) + } +} + +// Guesstimate for memory allocation: estimated number of bytes used per row in the RowConverter +const ESTIMATED_BYTES_PER_ROW: usize = 20; + +fn build_sort_fields( + ordering: &[PhysicalSortExpr], + schema: &SchemaRef, +) -> Result> { + ordering + .iter() + .map(|e| { + Ok(SortField::new_with_options( + e.expr.data_type(schema)?, + e.options, + )) + }) + .collect::>() } impl TopK { /// Create a new [`TopK`] that stores the top `k` values, as /// defined by the sort expressions in `expr`. // TODO: make a builder or some other nicer API + #[allow(clippy::too_many_arguments)] pub fn try_new( partition_id: usize, schema: SchemaRef, + common_sort_prefix: Vec, expr: LexOrdering, k: usize, batch_size: usize, runtime: Arc, metrics: &ExecutionPlanMetricsSet, + filter: Arc>, ) -> Result { let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]")) .register(&runtime.memory_pool); - let expr: Arc<[PhysicalSortExpr]> = expr.into(); - - let sort_fields: Vec<_> = expr - .iter() - .map(|e| { - Ok(SortField::new_with_options( - e.expr.data_type(&schema)?, - e.options, - )) - }) - .collect::>()?; + let sort_fields = build_sort_fields(&expr, &schema)?; // TODO there is potential to add special cases for single column sort fields // to improve performance let row_converter = RowConverter::new(sort_fields)?; - let scratch_rows = row_converter.empty_rows( - batch_size, - 20 * batch_size, // guesstimate 20 bytes per row - ); + let scratch_rows = + row_converter.empty_rows(batch_size, ESTIMATED_BYTES_PER_ROW * batch_size); + + let prefix_row_converter = if common_sort_prefix.is_empty() { + None + } else { + let input_sort_fields = build_sort_fields(&common_sort_prefix, &schema)?; + Some(RowConverter::new(input_sort_fields)?) + }; Ok(Self { schema: Arc::clone(&schema), @@ -136,7 +214,11 @@ impl TopK { expr, row_converter, scratch_rows, - heap: TopKHeap::new(k, batch_size, schema), + heap: TopKHeap::new(k, batch_size), + common_sort_prefix_converter: prefix_row_converter, + common_sort_prefix: Arc::from(common_sort_prefix), + finished: false, + filter, }) } @@ -144,9 +226,10 @@ impl TopK { /// the top k seen so far. pub fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> { // Updates on drop - let _timer = self.metrics.baseline.elapsed_compute().timer(); + let baseline = self.metrics.baseline.clone(); + let _timer = baseline.elapsed_compute().timer(); - let sort_keys: Vec = self + let mut sort_keys: Vec = self .expr .iter() .map(|expr| { @@ -155,34 +238,340 @@ impl TopK { }) .collect::>>()?; + let mut selected_rows = None; + + // If a filter is provided, update it with the new rows + let filter = self.filter.read().expr.current()?; + let filtered = filter.evaluate(&batch)?; + let num_rows = batch.num_rows(); + let array = filtered.into_array(num_rows)?; + let mut filter = array.as_boolean().clone(); + let true_count = filter.true_count(); + if true_count == 0 { + // nothing to filter, so no need to update + return Ok(()); + } + // only update the keys / rows if the filter does not match all rows + if true_count < num_rows { + // Indices in `set_indices` should be correct if filter contains nulls + // So we prepare the filter here. Note this is also done in the `FilterBuilder` + // so there is no overhead to do this here. + if filter.nulls().is_some() { + filter = prep_null_mask_filter(&filter); + } + + let filter_predicate = FilterBuilder::new(&filter); + let filter_predicate = if sort_keys.len() > 1 { + // Optimize filter when it has multiple sort keys + filter_predicate.optimize().build() + } else { + filter_predicate.build() + }; + selected_rows = Some(filter); + sort_keys = sort_keys + .iter() + .map(|key| filter_predicate.filter(key).map_err(|x| x.into())) + .collect::>>()?; + } // reuse existing `Rows` to avoid reallocations let rows = &mut self.scratch_rows; rows.clear(); self.row_converter.append(rows, &sort_keys)?; - // TODO make this algorithmically better?: - // Idea: filter out rows >= self.heap.max() early (before passing to `RowConverter`) - // this avoids some work and also might be better vectorizable. - let mut batch_entry = self.heap.register_batch(batch); - for (index, row) in rows.iter().enumerate() { + let mut batch_entry = self.heap.register_batch(batch.clone()); + + let replacements = match selected_rows { + Some(filter) => { + self.find_new_topk_items(filter.values().set_indices(), &mut batch_entry) + } + None => self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry), + }; + + if replacements > 0 { + self.metrics.row_replacements.add(replacements); + + self.heap.insert_batch_entry(batch_entry); + + // conserve memory + self.heap.maybe_compact()?; + + // update memory reservation + self.reservation.try_resize(self.size())?; + + // flag the topK as finished if we know that all + // subsequent batches are guaranteed to be greater (by byte order, after row conversion) than the top K, + // which means the top K won't change and the computation can be finished early. + self.attempt_early_completion(&batch)?; + + // update the filter representation of our TopK heap + self.update_filter()?; + } + + Ok(()) + } + + fn find_new_topk_items( + &mut self, + items: impl Iterator, + batch_entry: &mut RecordBatchEntry, + ) -> usize { + let mut replacements = 0; + let rows = &mut self.scratch_rows; + for (index, row) in items.zip(rows.iter()) { match self.heap.max() { // heap has k items, and the new row is greater than the // current max in the heap ==> it is not a new topk Some(max_row) if row.as_ref() >= max_row.row() => {} // don't yet have k items or new item is lower than the currently k low values None | Some(_) => { - self.heap.add(&mut batch_entry, row, index); - self.metrics.row_replacements.add(1); + self.heap.add(batch_entry, row, index); + replacements += 1; } } } - self.heap.insert_batch_entry(batch_entry); + replacements + } + + /// Update the filter representation of our TopK heap. + /// For example, given the sort expression `ORDER BY a DESC, b ASC LIMIT 3`, + /// and the current heap values `[(1, 5), (1, 4), (2, 3)]`, + /// the filter will be updated to: + /// + /// ```sql + /// (a > 1 OR (a = 1 AND b < 5)) AND + /// (a > 1 OR (a = 1 AND b < 4)) AND + /// (a > 2 OR (a = 2 AND b < 3)) + /// ``` + fn update_filter(&mut self) -> Result<()> { + // If the heap doesn't have k elements yet, we can't create thresholds + let Some(max_row) = self.heap.max() else { + return Ok(()); + }; + + let new_threshold_row = &max_row.row; + + // Fast path: check if the current value in topk is better than what is + // currently set in the filter with a read only lock + let needs_update = self + .filter + .read() + .threshold_row + .as_ref() + .map(|current_row| { + // new < current means new threshold is more selective + new_threshold_row < current_row + }) + .unwrap_or(true); // No current threshold, so we need to set one + + // exit early if the current values are better + if !needs_update { + return Ok(()); + } - // conserve memory - self.heap.maybe_compact()?; + // Extract scalar values BEFORE acquiring lock to reduce critical section + let thresholds = match self.heap.get_threshold_values(&self.expr)? { + Some(t) => t, + None => return Ok(()), + }; - // update memory reservation - self.reservation.try_resize(self.size())?; + // Build the filter expression OUTSIDE any synchronization + let predicate = Self::build_filter_expression(&self.expr, thresholds)?; + let new_threshold = new_threshold_row.to_vec(); + + // update the threshold. Since there was a lock gap, we must check if it is still the best + // may have changed while we were building the expression without the lock + let mut filter = self.filter.write(); + let old_threshold = filter.threshold_row.take(); + + // Update filter if we successfully updated the threshold + // (or if there was no previous threshold and we're the first) + match old_threshold { + Some(old_threshold) => { + // new threshold is still better than the old one + if new_threshold.as_slice() < old_threshold.as_slice() { + filter.threshold_row = Some(new_threshold); + } else { + // some other thread updated the threshold to a better + // one while we were building so there is no need to + // update the filter + filter.threshold_row = Some(old_threshold); + return Ok(()); + } + } + None => { + // No previous threshold, so we can set the new one + filter.threshold_row = Some(new_threshold); + } + }; + + // Update the filter expression + if let Some(pred) = predicate { + if !pred.eq(&lit(true)) { + filter.expr.update(pred)?; + } + } + + Ok(()) + } + + /// Build the filter expression with the given thresholds. + /// This is now called outside of any locks to reduce critical section time. + fn build_filter_expression( + sort_exprs: &[PhysicalSortExpr], + thresholds: Vec, + ) -> Result>> { + // Create filter expressions for each threshold + let mut filters: Vec> = + Vec::with_capacity(thresholds.len()); + + let mut prev_sort_expr: Option> = None; + for (sort_expr, value) in sort_exprs.iter().zip(thresholds.iter()) { + // Create the appropriate operator based on sort order + let op = if sort_expr.options.descending { + // For descending sort, we want col > threshold (exclude smaller values) + Operator::Gt + } else { + // For ascending sort, we want col < threshold (exclude larger values) + Operator::Lt + }; + + let value_null = value.is_null(); + + let comparison = Arc::new(BinaryExpr::new( + Arc::clone(&sort_expr.expr), + op, + lit(value.clone()), + )); + + let comparison_with_null = match (sort_expr.options.nulls_first, value_null) { + // For nulls first, transform to (threshold.value is not null) and (threshold.expr is null or comparison) + (true, true) => lit(false), + (true, false) => Arc::new(BinaryExpr::new( + is_null(Arc::clone(&sort_expr.expr))?, + Operator::Or, + comparison, + )), + // For nulls last, transform to (threshold.value is null and threshold.expr is not null) + // or (threshold.value is not null and comparison) + (false, true) => is_not_null(Arc::clone(&sort_expr.expr))?, + (false, false) => comparison, + }; + + let mut eq_expr = Arc::new(BinaryExpr::new( + Arc::clone(&sort_expr.expr), + Operator::Eq, + lit(value.clone()), + )); + + if value_null { + eq_expr = Arc::new(BinaryExpr::new( + is_null(Arc::clone(&sort_expr.expr))?, + Operator::Or, + eq_expr, + )); + } + + // For a query like order by a, b, the filter for column `b` is only applied if + // the condition a = threshold.value (considering null equality) is met. + // Therefore, we add equality predicates for all preceding fields to the filter logic of the current field, + // and include the current field's equality predicate in `prev_sort_expr` for use with subsequent fields. + match prev_sort_expr.take() { + None => { + prev_sort_expr = Some(eq_expr); + filters.push(comparison_with_null); + } + Some(p) => { + filters.push(Arc::new(BinaryExpr::new( + Arc::clone(&p), + Operator::And, + comparison_with_null, + ))); + + prev_sort_expr = + Some(Arc::new(BinaryExpr::new(p, Operator::And, eq_expr))); + } + } + } + + let dynamic_predicate = filters + .into_iter() + .reduce(|a, b| Arc::new(BinaryExpr::new(a, Operator::Or, b))); + + Ok(dynamic_predicate) + } + + /// If input ordering shares a common sort prefix with the TopK, and if the TopK's heap is full, + /// check if the computation can be finished early. + /// This is the case if the last row of the current batch is strictly greater than the max row in the heap, + /// comparing only on the shared prefix columns. + fn attempt_early_completion(&mut self, batch: &RecordBatch) -> Result<()> { + // Early exit if the batch is empty as there is no last row to extract from it. + if batch.num_rows() == 0 { + return Ok(()); + } + + // prefix_row_converter is only `Some` if the input ordering has a common prefix with the TopK, + // so early exit if it is `None`. + let Some(prefix_converter) = &self.common_sort_prefix_converter else { + return Ok(()); + }; + + // Early exit if the heap is not full (`heap.max()` only returns `Some` if the heap is full). + let Some(max_topk_row) = self.heap.max() else { + return Ok(()); + }; + + // Evaluate the prefix for the last row of the current batch. + let last_row_idx = batch.num_rows() - 1; + let mut batch_prefix_scratch = + prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW); // 1 row with capacity ESTIMATED_BYTES_PER_ROW + + self.compute_common_sort_prefix(batch, last_row_idx, &mut batch_prefix_scratch)?; + + // Retrieve the max row from the heap. + let store_entry = self + .heap + .store + .get(max_topk_row.batch_id) + .ok_or(internal_datafusion_err!("Invalid batch id in topK heap"))?; + let max_batch = &store_entry.batch; + let mut heap_prefix_scratch = + prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW); // 1 row with capacity ESTIMATED_BYTES_PER_ROW + self.compute_common_sort_prefix( + max_batch, + max_topk_row.index, + &mut heap_prefix_scratch, + )?; + + // If the last row's prefix is strictly greater than the max prefix, mark as finished. + if batch_prefix_scratch.row(0).as_ref() > heap_prefix_scratch.row(0).as_ref() { + self.finished = true; + } + + Ok(()) + } + + // Helper function to compute the prefix for a given batch and row index, storing the result in scratch. + fn compute_common_sort_prefix( + &self, + batch: &RecordBatch, + last_row_idx: usize, + scratch: &mut Rows, + ) -> Result<()> { + let last_row: Vec = self + .common_sort_prefix + .iter() + .map(|expr| { + expr.expr + .evaluate(&batch.slice(last_row_idx, 1))? + .into_array(1) + }) + .collect::>()?; + + self.common_sort_prefix_converter + .as_ref() + .unwrap() + .append(scratch, &last_row)?; Ok(()) } @@ -197,6 +586,10 @@ impl TopK { row_converter: _, scratch_rows: _, mut heap, + common_sort_prefix_converter: _, + common_sort_prefix: _, + finished: _, + filter: _, } = self; let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop @@ -271,13 +664,13 @@ struct TopKHeap { } impl TopKHeap { - fn new(k: usize, batch_size: usize, schema: SchemaRef) -> Self { + fn new(k: usize, batch_size: usize) -> Self { assert!(k > 0); Self { k, batch_size, inner: BinaryHeap::new(), - store: RecordBatchStore::new(schema), + store: RecordBatchStore::new(), owned_bytes: 0, } } @@ -354,8 +747,6 @@ impl TopKHeap { /// high, as a single [`RecordBatch`], and a sorted vec of the /// current heap's contents pub fn emit_with_state(&mut self) -> Result<(Option, Vec)> { - let schema = Arc::clone(self.store.schema()); - // generate sorted rows let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec(); @@ -363,37 +754,26 @@ impl TopKHeap { return Ok((None, topk_rows)); } - // Indices for each row within its respective RecordBatch + // Collect the batches into a vec and store the "batch_id -> array_pos" mapping, to then + // build the `indices` vec below. This is needed since the batch ids are not continuous. + let mut record_batches = Vec::new(); + let mut batch_id_array_pos = HashMap::new(); + for (array_pos, (batch_id, batch)) in self.store.batches.iter().enumerate() { + record_batches.push(&batch.batch); + batch_id_array_pos.insert(*batch_id, array_pos); + } + let indices: Vec<_> = topk_rows .iter() - .enumerate() - .map(|(i, k)| (i, k.index)) + .map(|k| (batch_id_array_pos[&k.batch_id], k.index)) .collect(); - let num_columns = schema.fields().len(); - - // build the output columns one at time, using the - // `interleave` kernel to pick rows from different arrays - let output_columns: Vec<_> = (0..num_columns) - .map(|col| { - let input_arrays: Vec<_> = topk_rows - .iter() - .map(|k| { - let entry = - self.store.get(k.batch_id).expect("invalid stored batch id"); - entry.batch.column(col) as &dyn Array - }) - .collect(); - - // at this point `indices` contains indexes within the - // rows and `input_arrays` contains a reference to the - // relevant Array for that index. `interleave` pulls - // them together into a single new array - Ok(interleave(&input_arrays, &indices)?) - }) - .collect::>()?; + // At this point `indices` contains indexes within the + // rows and `input_arrays` contains a reference to the + // relevant RecordBatch for that index. `interleave_record_batch` pulls + // them together into a single new batch + let new_batch = interleave_record_batch(&record_batches, &indices)?; - let new_batch = RecordBatch::try_new(schema, output_columns)?; Ok((Some(new_batch), topk_rows)) } @@ -451,6 +831,47 @@ impl TopKHeap { + self.store.size() + self.owned_bytes } + + fn get_threshold_values( + &self, + sort_exprs: &[PhysicalSortExpr], + ) -> Result>> { + // If the heap doesn't have k elements yet, we can't create thresholds + let max_row = match self.max() { + Some(row) => row, + None => return Ok(None), + }; + + // Get the batch that contains the max row + let batch_entry = match self.store.get(max_row.batch_id) { + Some(entry) => entry, + None => return internal_err!("Invalid batch ID in TopKRow"), + }; + + // Extract threshold values for each sort expression + let mut scalar_values = Vec::with_capacity(sort_exprs.len()); + for sort_expr in sort_exprs { + // Extract the value for this column from the max row + let expr = Arc::clone(&sort_expr.expr); + let value = expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?; + + // Convert to scalar value - should be a single value since we're evaluating on a single row batch + let scalar = match value { + ColumnarValue::Scalar(scalar) => scalar, + ColumnarValue::Array(array) if array.len() == 1 => { + // Extract the first (and only) value from the array + ScalarValue::try_from_array(&array, 0)? + } + array => { + return internal_err!("Expected a scalar value, got {:?}", array) + } + }; + + scalar_values.push(scalar); + } + + Ok(Some(scalar_values)) + } } /// Represents one of the top K rows held in this heap. Orders @@ -518,6 +939,7 @@ impl Eq for TopKRow {} impl PartialOrd for TopKRow { fn partial_cmp(&self, other: &Self) -> Option { + // TODO PartialOrd is not consistent with PartialEq; PartialOrd contract is violated Some(self.cmp(other)) } } @@ -548,17 +970,14 @@ struct RecordBatchStore { batches: HashMap, /// total size of all record batches tracked by this store batches_size: usize, - /// schema of the batches - schema: SchemaRef, } impl RecordBatchStore { - fn new(schema: SchemaRef) -> Self { + fn new() -> Self { Self { next_id: 0, batches: HashMap::new(), batches_size: 0, - schema, } } @@ -609,11 +1028,6 @@ impl RecordBatchStore { self.batches.is_empty() } - /// return the schema of batches stored - fn schema(&self) -> &SchemaRef { - &self.schema - } - /// remove a use from the specified batch id. If the use count /// reaches zero the batch entry is removed from the store /// @@ -649,6 +1063,10 @@ mod tests { use super::*; use arrow::array::{Float64Array, Int32Array, RecordBatch}; use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::SortOptions; + use datafusion_common::assert_batches_eq; + use datafusion_physical_expr::expressions::col; + use futures::TryStreamExt; /// This test ensures the size calculation is correct for RecordBatches with multiple columns. #[test] @@ -658,7 +1076,7 @@ mod tests { Field::new("ints", DataType::Int32, true), Field::new("float64", DataType::Float64, false), ])); - let mut record_batch_store = RecordBatchStore::new(Arc::clone(&schema)); + let mut record_batch_store = RecordBatchStore::new(); let int_array = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); // 5 * 4 = 20 let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]); // 5 * 8 = 40 @@ -681,4 +1099,101 @@ mod tests { record_batch_store.unuse(0); assert_eq!(record_batch_store.batches_size, 0); } + + /// This test validates that the `try_finish` method marks the TopK operator as finished + /// when the prefix (on column "a") of the last row in the current batch is strictly greater + /// than the max top‑k row. + /// The full sort expression is defined on both columns ("a", "b"), but the input ordering is only on "a". + #[tokio::test] + async fn test_try_finish_marks_finished_with_prefix() -> Result<()> { + // Create a schema with two columns. + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Float64, false), + ])); + + // Create sort expressions. + // Full sort: first by "a", then by "b". + let sort_expr_a = PhysicalSortExpr { + expr: col("a", schema.as_ref())?, + options: SortOptions::default(), + }; + let sort_expr_b = PhysicalSortExpr { + expr: col("b", schema.as_ref())?, + options: SortOptions::default(), + }; + + // Input ordering uses only column "a" (a prefix of the full sort). + let prefix = vec![sort_expr_a.clone()]; + let full_expr = LexOrdering::from([sort_expr_a, sort_expr_b]); + + // Create a dummy runtime environment and metrics. + let runtime = Arc::new(RuntimeEnv::default()); + let metrics = ExecutionPlanMetricsSet::new(); + + // Create a TopK instance with k = 3 and batch_size = 2. + let mut topk = TopK::try_new( + 0, + Arc::clone(&schema), + prefix, + full_expr, + 3, + 2, + runtime, + &metrics, + Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new( + DynamicFilterPhysicalExpr::new(vec![], lit(true)), + )))), + )?; + + // Create the first batch with two columns: + // Column "a": [1, 1, 2], Column "b": [20.0, 15.0, 30.0]. + let array_a1: ArrayRef = + Arc::new(Int32Array::from(vec![Some(1), Some(1), Some(2)])); + let array_b1: ArrayRef = Arc::new(Float64Array::from(vec![20.0, 15.0, 30.0])); + let batch1 = RecordBatch::try_new(Arc::clone(&schema), vec![array_a1, array_b1])?; + + // Insert the first batch. + // At this point the heap is not yet “finished” because the prefix of the last row of the batch + // is not strictly greater than the prefix of the max top‑k row (both being `2`). + topk.insert_batch(batch1)?; + assert!( + !topk.finished, + "Expected 'finished' to be false after the first batch." + ); + + // Create the second batch with two columns: + // Column "a": [2, 3], Column "b": [10.0, 20.0]. + let array_a2: ArrayRef = Arc::new(Int32Array::from(vec![Some(2), Some(3)])); + let array_b2: ArrayRef = Arc::new(Float64Array::from(vec![10.0, 20.0])); + let batch2 = RecordBatch::try_new(Arc::clone(&schema), vec![array_a2, array_b2])?; + + // Insert the second batch. + // The last row in this batch has a prefix value of `3`, + // which is strictly greater than the max top‑k row (with value `2`), + // so try_finish should mark the TopK as finished. + topk.insert_batch(batch2)?; + assert!( + topk.finished, + "Expected 'finished' to be true after the second batch." + ); + + // Verify the TopK correctly emits the top k rows from both batches + // (the value 10.0 for b is from the second batch). + let results: Vec<_> = topk.emit()?.try_collect().await?; + assert_batches_eq!( + &[ + "+---+------+", + "| a | b |", + "+---+------+", + "| 1 | 15.0 |", + "| 1 | 20.0 |", + "| 2 | 10.0 |", + "+---+------+", + ], + &results + ); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index 69b0a165315ec..78ba984ed1a58 100644 --- a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -94,7 +94,7 @@ impl PlanContext { impl Display for PlanContext { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let node_string = displayable(self.plan.as_ref()).one_line(); - write!(f, "Node plan: {}", node_string)?; + write!(f, "Node plan: {node_string}")?; write!(f, "Node data: {}", self.data)?; write!(f, "") } diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 2b666093f29e0..164f17edebd31 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -33,18 +33,21 @@ use super::{ SendableRecordBatchStream, Statistics, }; use crate::execution_plan::{ - boundedness_from_children, emission_type_from_children, InvariantLevel, + boundedness_from_children, check_default_invariants, emission_type_from_children, + InvariantLevel, }; +use crate::filter_pushdown::{FilterDescription, FilterPushdownPhase}; use crate::metrics::BaselineMetrics; use crate::projection::{make_with_child, ProjectionExec}; use crate::stream::ObservedStream; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; -use datafusion_common::{exec_err, internal_err, DataFusionError, Result}; +use datafusion_common::{exec_err, internal_datafusion_err, internal_err, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{calculate_union, EquivalenceProperties}; +use datafusion_physical_expr::{calculate_union, EquivalenceProperties, PhysicalExpr}; use futures::Stream; use itertools::Itertools; @@ -100,8 +103,10 @@ pub struct UnionExec { impl UnionExec { /// Create a new UnionExec + #[deprecated(since = "44.0.0", note = "Use UnionExec::try_new instead")] pub fn new(inputs: Vec>) -> Self { - let schema = union_schema(&inputs); + let schema = + union_schema(&inputs).expect("UnionExec::new called with empty inputs"); // The schema of the inputs and the union schema is consistent when: // - They have the same number of fields, and // - Their fields have same types at the same indices. @@ -115,6 +120,37 @@ impl UnionExec { } } + /// Try to create a new UnionExec. + /// + /// # Errors + /// Returns an error if: + /// - `inputs` is empty + /// + /// # Optimization + /// If there is only one input, returns that input directly rather than wrapping it in a UnionExec + pub fn try_new( + inputs: Vec>, + ) -> Result> { + match inputs.len() { + 0 => exec_err!("UnionExec requires at least one input"), + 1 => Ok(inputs.into_iter().next().unwrap()), + _ => { + let schema = union_schema(&inputs)?; + // The schema of the inputs and the union schema is consistent when: + // - They have the same number of fields, and + // - Their fields have same types at the same indices. + // Here, we know that schemas are consistent and the call below can + // not return an error. + let cache = Self::compute_properties(&inputs, schema).unwrap(); + Ok(Arc::new(UnionExec { + inputs, + metrics: ExecutionPlanMetricsSet::new(), + cache, + })) + } + } + } + /// Get inputs of the execution plan pub fn inputs(&self) -> &Vec> { &self.inputs @@ -176,16 +212,12 @@ impl ExecutionPlan for UnionExec { &self.cache } - fn check_invariants(&self, _check: InvariantLevel) -> Result<()> { - (self.inputs().len() >= 2) - .then_some(()) - .ok_or(DataFusionError::Internal( - "UnionExec should have at least 2 children".into(), - )) - } + fn check_invariants(&self, check: InvariantLevel) -> Result<()> { + check_default_invariants(self, check)?; - fn children(&self) -> Vec<&Arc> { - self.inputs.iter().collect() + (self.inputs().len() >= 2).then_some(()).ok_or_else(|| { + internal_datafusion_err!("UnionExec should have at least 2 children") + }) } fn maintains_input_order(&self) -> Vec { @@ -213,11 +245,19 @@ impl ExecutionPlan for UnionExec { } } + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false; self.children().len()] + } + + fn children(&self) -> Vec<&Arc> { + self.inputs.iter().collect() + } + fn with_new_children( self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(UnionExec::new(children))) + UnionExec::try_new(children) } fn execute( @@ -248,7 +288,7 @@ impl ExecutionPlan for UnionExec { } } - warn!("Error in Union: Partition {} not found", partition); + warn!("Error in Union: Partition {partition} not found"); exec_err!("Partition {partition} not found in Union") } @@ -258,20 +298,36 @@ impl ExecutionPlan for UnionExec { } fn statistics(&self) -> Result { - let stats = self - .inputs - .iter() - .map(|stat| stat.statistics()) - .collect::>>()?; - - Ok(stats - .into_iter() - .reduce(stats_union) - .unwrap_or_else(|| Statistics::new_unknown(&self.schema()))) + self.partition_statistics(None) } - fn benefits_from_input_partitioning(&self) -> Vec { - vec![false; self.children().len()] + fn partition_statistics(&self, partition: Option) -> Result { + if let Some(partition_idx) = partition { + // For a specific partition, find which input it belongs to + let mut remaining_idx = partition_idx; + for input in &self.inputs { + let input_partition_count = input.output_partitioning().partition_count(); + if remaining_idx < input_partition_count { + // This partition belongs to this input + return input.partition_statistics(Some(remaining_idx)); + } + remaining_idx -= input_partition_count; + } + // If we get here, the partition index is out of bounds + Ok(Statistics::new_unknown(&self.schema())) + } else { + // Collect statistics from all inputs + let stats = self + .inputs + .iter() + .map(|input_exec| input_exec.partition_statistics(None)) + .collect::>>()?; + + Ok(stats + .into_iter() + .reduce(stats_union) + .unwrap_or_else(|| Statistics::new_unknown(&self.schema()))) + } } fn supports_limit_pushdown(&self) -> bool { @@ -296,7 +352,16 @@ impl ExecutionPlan for UnionExec { .map(|child| make_with_child(projection, child)) .collect::>>()?; - Ok(Some(Arc::new(UnionExec::new(new_children)))) + Ok(Some(UnionExec::try_new(new_children.clone())?)) + } + + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + FilterDescription::from_children(parent_filters, &self.children()) } } @@ -350,7 +415,7 @@ impl InterleaveExec { "Not all InterleaveExec children have a consistent hash partitioning" ); } - let cache = Self::compute_properties(&inputs); + let cache = Self::compute_properties(&inputs)?; Ok(InterleaveExec { inputs, metrics: ExecutionPlanMetricsSet::new(), @@ -364,17 +429,17 @@ impl InterleaveExec { } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties(inputs: &[Arc]) -> PlanProperties { - let schema = union_schema(inputs); + fn compute_properties(inputs: &[Arc]) -> Result { + let schema = union_schema(inputs)?; let eq_properties = EquivalenceProperties::new(schema); // Get output partitioning: let output_partitioning = inputs[0].output_partitioning().clone(); - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, emission_type_from_children(inputs), boundedness_from_children(inputs), - ) + )) } } @@ -461,7 +526,7 @@ impl ExecutionPlan for InterleaveExec { ))); } - warn!("Error in InterleaveExec: Partition {} not found", partition); + warn!("Error in InterleaveExec: Partition {partition} not found"); exec_err!("Partition {partition} not found in InterleaveExec") } @@ -471,10 +536,14 @@ impl ExecutionPlan for InterleaveExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { let stats = self .inputs .iter() - .map(|stat| stat.statistics()) + .map(|stat| stat.partition_statistics(partition)) .collect::>>()?; Ok(stats @@ -508,12 +577,21 @@ pub fn can_interleave>>( .all(|partition| partition == *reference) } -fn union_schema(inputs: &[Arc]) -> SchemaRef { +fn union_schema(inputs: &[Arc]) -> Result { + if inputs.is_empty() { + return exec_err!("Cannot create union schema from empty inputs"); + } + let first_schema = inputs[0].schema(); let fields = (0..first_schema.fields().len()) .map(|i| { - inputs + // We take the name from the left side of the union to match how names are coerced during logical planning, + // which also uses the left side names. + let base_field = first_schema.field(i).clone(); + + // Coerce metadata and nullability across all inputs + let merged_field = inputs .iter() .enumerate() .map(|(input_idx, input)| { @@ -535,6 +613,9 @@ fn union_schema(inputs: &[Arc]) -> SchemaRef { // We can unwrap this because if inputs was empty, this would've already panic'ed when we // indexed into inputs[0]. .unwrap() + .with_name(base_field.name()); + + merged_field }) .collect::>(); @@ -543,7 +624,10 @@ fn union_schema(inputs: &[Arc]) -> SchemaRef { .flat_map(|i| i.schema().metadata().clone().into_iter()) .collect(); - Arc::new(Schema::new_with_metadata(fields, all_metadata_merged)) + Ok(Arc::new(Schema::new_with_metadata( + fields, + all_metadata_merged, + ))) } /// CombinedRecordBatchStream can be used to combine a Vec of SendableRecordBatchStreams into one @@ -642,15 +726,13 @@ fn stats_union(mut left: Statistics, right: Statistics) -> Statistics { mod tests { use super::*; use crate::collect; - use crate::test; - use crate::test::TestMemoryExec; + use crate::test::{self, TestMemoryExec}; use arrow::compute::SortOptions; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; + use datafusion_physical_expr::equivalence::convert_to_orderings; use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; - use datafusion_physical_expr_common::sort_expr::LexOrdering; // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g) fn create_test_schema() -> Result { @@ -666,19 +748,6 @@ mod tests { Ok(schema) } - // Convert each tuple to PhysicalSortExpr - fn convert_to_sort_exprs( - in_data: &[(&Arc, SortOptions)], - ) -> LexOrdering { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(*expr), - options: *options, - }) - .collect::() - } - #[tokio::test] async fn test_union_partitions() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); @@ -687,7 +756,7 @@ mod tests { let csv = test::scan_partitioned(4); let csv2 = test::scan_partitioned(5); - let union_exec = Arc::new(UnionExec::new(vec![csv, csv2])); + let union_exec: Arc = UnionExec::try_new(vec![csv, csv2])?; // Should have 9 partitions and 9 output batches assert_eq!( @@ -854,18 +923,9 @@ mod tests { (first_child_orderings, second_child_orderings, union_orderings), ) in test_cases.iter().enumerate() { - let first_orderings = first_child_orderings - .iter() - .map(|ordering| convert_to_sort_exprs(ordering)) - .collect::>(); - let second_orderings = second_child_orderings - .iter() - .map(|ordering| convert_to_sort_exprs(ordering)) - .collect::>(); - let union_expected_orderings = union_orderings - .iter() - .map(|ordering| convert_to_sort_exprs(ordering)) - .collect::>(); + let first_orderings = convert_to_orderings(first_child_orderings); + let second_orderings = convert_to_orderings(second_child_orderings); + let union_expected_orderings = convert_to_orderings(union_orderings); let child1 = Arc::new(TestMemoryExec::update_cache(Arc::new( TestMemoryExec::try_new(&[], Arc::clone(&schema), None)? .try_with_sort_information(first_orderings)?, @@ -876,9 +936,9 @@ mod tests { ))); let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema)); - union_expected_eq.add_new_orderings(union_expected_orderings); + union_expected_eq.add_orderings(union_expected_orderings); - let union = UnionExec::new(vec![child1, child2]); + let union: Arc = UnionExec::try_new(vec![child1, child2])?; let union_eq_properties = union.properties().equivalence_properties(); let err_msg = format!( "Error in test id: {:?}, test case: {:?}", @@ -897,9 +957,71 @@ mod tests { // Check whether orderings are same. let lhs_orderings = lhs.oeq_class(); let rhs_orderings = rhs.oeq_class(); - assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{}", err_msg); + assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{err_msg}"); for rhs_ordering in rhs_orderings.iter() { assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg); } } + + #[test] + fn test_union_empty_inputs() { + // Test that UnionExec::try_new fails with empty inputs + let result = UnionExec::try_new(vec![]); + assert!(result + .unwrap_err() + .to_string() + .contains("UnionExec requires at least one input")); + } + + #[test] + fn test_union_schema_empty_inputs() { + // Test that union_schema fails with empty inputs + let result = union_schema(&[]); + assert!(result + .unwrap_err() + .to_string() + .contains("Cannot create union schema from empty inputs")); + } + + #[test] + fn test_union_single_input() -> Result<()> { + // Test that UnionExec::try_new returns the single input directly + let schema = create_test_schema()?; + let memory_exec: Arc = + Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?); + let memory_exec_clone = Arc::clone(&memory_exec); + let result = UnionExec::try_new(vec![memory_exec])?; + + // Check that the result is the same as the input (no UnionExec wrapper) + assert_eq!(result.schema(), schema); + // Verify it's the same execution plan + assert!(Arc::ptr_eq(&result, &memory_exec_clone)); + + Ok(()) + } + + #[test] + fn test_union_schema_multiple_inputs() -> Result<()> { + // Test that existing functionality with multiple inputs still works + let schema = create_test_schema()?; + let memory_exec1 = + Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?); + let memory_exec2 = + Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?); + + let union_plan = UnionExec::try_new(vec![memory_exec1, memory_exec2])?; + + // Downcast to verify it's a UnionExec + let union = union_plan + .as_any() + .downcast_ref::() + .expect("Expected UnionExec"); + + // Check that schema is correct + assert_eq!(union.schema(), schema); + // Check that we have 2 inputs + assert_eq!(union.inputs().len(), 2); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index c06b09f2fecd5..e36cd2b6c2429 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -21,7 +21,10 @@ use std::cmp::{self, Ordering}; use std::task::{ready, Poll}; use std::{any::Any, sync::Arc}; -use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use super::metrics::{ + self, BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, + RecordOutput, +}; use super::{DisplayAs, ExecutionPlanProperties, PlanProperties}; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, RecordBatchStream, @@ -38,13 +41,12 @@ use arrow::compute::{cast, is_not_null, kernels, sum}; use arrow::datatypes::{DataType, Int64Type, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow_ord::cmp::lt; +use async_trait::async_trait; use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, HashMap, HashSet, Result, UnnestOptions, }; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; - -use async_trait::async_trait; use futures::{Stream, StreamExt}; use log::trace; @@ -203,22 +205,18 @@ impl ExecutionPlan for UnnestExec { #[derive(Clone, Debug)] struct UnnestMetrics { - /// Total time for column unnesting - elapsed_compute: metrics::Time, + /// Execution metrics + baseline_metrics: BaselineMetrics, /// Number of batches consumed input_batches: metrics::Count, /// Number of rows consumed input_rows: metrics::Count, /// Number of batches produced output_batches: metrics::Count, - /// Number of rows produced by this operator - output_rows: metrics::Count, } impl UnnestMetrics { fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { - let elapsed_compute = MetricBuilder::new(metrics).elapsed_compute(partition); - let input_batches = MetricBuilder::new(metrics).counter("input_batches", partition); @@ -227,14 +225,11 @@ impl UnnestMetrics { let output_batches = MetricBuilder::new(metrics).counter("output_batches", partition); - let output_rows = MetricBuilder::new(metrics).output_rows(partition); - Self { + baseline_metrics: BaselineMetrics::new(metrics, partition), input_batches, input_rows, output_batches, - output_rows, - elapsed_compute, } } } @@ -284,7 +279,9 @@ impl UnnestStream { loop { return Poll::Ready(match ready!(self.input.poll_next_unpin(cx)) { Some(Ok(batch)) => { - let timer = self.metrics.elapsed_compute.timer(); + let elapsed_compute = + self.metrics.baseline_metrics.elapsed_compute().clone(); + let timer = elapsed_compute.timer(); self.metrics.input_batches.add(1); self.metrics.input_rows.add(batch.num_rows()); let result = build_batch( @@ -299,7 +296,7 @@ impl UnnestStream { continue; }; self.metrics.output_batches.add(1); - self.metrics.output_rows.add(result_batch.num_rows()); + (&result_batch).record_output(&self.metrics.baseline_metrics); // Empty record batches should not be emitted. // They need to be treated as [`Option`]es and handled separately @@ -313,8 +310,8 @@ impl UnnestStream { self.metrics.input_batches, self.metrics.input_rows, self.metrics.output_batches, - self.metrics.output_rows, - self.metrics.elapsed_compute, + self.metrics.baseline_metrics.output_rows(), + self.metrics.baseline_metrics.elapsed_compute(), ); other } diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs deleted file mode 100644 index 6cb64bcb5d867..0000000000000 --- a/datafusion/physical-plan/src/values.rs +++ /dev/null @@ -1,328 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Values execution plan - -use std::any::Any; -use std::sync::Arc; - -use crate::execution_plan::{Boundedness, EmissionType}; -use crate::memory::MemoryStream; -use crate::{common, DisplayAs, PlanProperties, SendableRecordBatchStream, Statistics}; -use crate::{ - ColumnarValue, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, -}; -use arrow::datatypes::{Schema, SchemaRef}; -use arrow::record_batch::{RecordBatch, RecordBatchOptions}; -use datafusion_common::{internal_err, plan_err, Result, ScalarValue}; -use datafusion_execution::TaskContext; -use datafusion_physical_expr::EquivalenceProperties; - -/// Execution plan for values list based relation (produces constant rows) -#[deprecated( - since = "45.0.0", - note = "Use `MemorySourceConfig::try_new_as_values` instead" -)] -#[derive(Debug, Clone)] -pub struct ValuesExec { - /// The schema - schema: SchemaRef, - /// The data - data: Vec, - /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, -} - -#[allow(deprecated)] -impl ValuesExec { - /// Create a new values exec from data as expr - #[deprecated(since = "45.0.0", note = "Use `MemoryExec::try_new` instead")] - pub fn try_new( - schema: SchemaRef, - data: Vec>>, - ) -> Result { - if data.is_empty() { - return plan_err!("Values list cannot be empty"); - } - let n_row = data.len(); - let n_col = schema.fields().len(); - // We have this single row batch as a placeholder to satisfy evaluation argument - // and generate a single output row - let batch = RecordBatch::try_new_with_options( - Arc::new(Schema::empty()), - vec![], - &RecordBatchOptions::new().with_row_count(Some(1)), - )?; - - let arr = (0..n_col) - .map(|j| { - (0..n_row) - .map(|i| { - let r = data[i][j].evaluate(&batch); - - match r { - Ok(ColumnarValue::Scalar(scalar)) => Ok(scalar), - Ok(ColumnarValue::Array(a)) if a.len() == 1 => { - ScalarValue::try_from_array(&a, 0) - } - Ok(ColumnarValue::Array(a)) => { - plan_err!( - "Cannot have array values {a:?} in a values list" - ) - } - Err(err) => Err(err), - } - }) - .collect::>>() - .and_then(ScalarValue::iter_to_array) - }) - .collect::>>()?; - let batch = RecordBatch::try_new_with_options( - Arc::clone(&schema), - arr, - &RecordBatchOptions::new().with_row_count(Some(n_row)), - )?; - let data: Vec = vec![batch]; - Self::try_new_from_batches(schema, data) - } - - /// Create a new plan using the provided schema and batches. - /// - /// Errors if any of the batches don't match the provided schema, or if no - /// batches are provided. - #[deprecated( - since = "45.0.0", - note = "Use `MemoryExec::try_new_from_batches` instead" - )] - pub fn try_new_from_batches( - schema: SchemaRef, - batches: Vec, - ) -> Result { - if batches.is_empty() { - return plan_err!("Values list cannot be empty"); - } - - for batch in &batches { - let batch_schema = batch.schema(); - if batch_schema != schema { - return plan_err!( - "Batch has invalid schema. Expected: {schema}, got: {batch_schema}" - ); - } - } - - let cache = Self::compute_properties(Arc::clone(&schema)); - #[allow(deprecated)] - Ok(ValuesExec { - schema, - data: batches, - cache, - }) - } - - /// Provides the data - pub fn data(&self) -> Vec { - #[allow(deprecated)] - self.data.clone() - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties(schema: SchemaRef) -> PlanProperties { - PlanProperties::new( - EquivalenceProperties::new(schema), - Partitioning::UnknownPartitioning(1), - EmissionType::Incremental, - Boundedness::Bounded, - ) - } -} - -#[allow(deprecated)] -impl DisplayAs for ValuesExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "ValuesExec") - } - DisplayFormatType::TreeRender => { - // TODO: collect info - write!(f, "") - } - } - } -} - -#[allow(deprecated)] -impl ExecutionPlan for ValuesExec { - fn name(&self) -> &'static str { - "ValuesExec" - } - - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - #[allow(deprecated)] - &self.cache - } - - fn children(&self) -> Vec<&Arc> { - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - #[allow(deprecated)] - ValuesExec::try_new_from_batches(Arc::clone(&self.schema), self.data.clone()) - .map(|e| Arc::new(e) as _) - } - - fn execute( - &self, - partition: usize, - _context: Arc, - ) -> Result { - // ValuesExec has a single output partition - if 0 != partition { - return internal_err!( - "ValuesExec invalid partition {partition} (expected 0)" - ); - } - - Ok(Box::pin(MemoryStream::try_new( - self.data(), - #[allow(deprecated)] - Arc::clone(&self.schema), - None, - )?)) - } - - fn statistics(&self) -> Result { - let batch = self.data(); - Ok(common::compute_record_batch_statistics( - &[batch], - #[allow(deprecated)] - &self.schema, - None, - )) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::lit; - use crate::test::{self, make_partition}; - - use arrow::datatypes::{DataType, Field}; - use datafusion_common::stats::{ColumnStatistics, Precision}; - - #[tokio::test] - async fn values_empty_case() -> Result<()> { - let schema = test::aggr_test_schema(); - #[allow(deprecated)] - let empty = ValuesExec::try_new(schema, vec![]); - assert!(empty.is_err()); - Ok(()) - } - - #[test] - fn new_exec_with_batches() { - let batch = make_partition(7); - let schema = batch.schema(); - let batches = vec![batch.clone(), batch]; - #[allow(deprecated)] - let _exec = ValuesExec::try_new_from_batches(schema, batches).unwrap(); - } - - #[test] - fn new_exec_with_batches_empty() { - let batch = make_partition(7); - let schema = batch.schema(); - #[allow(deprecated)] - let _ = ValuesExec::try_new_from_batches(schema, Vec::new()).unwrap_err(); - } - - #[test] - fn new_exec_with_batches_invalid_schema() { - let batch = make_partition(7); - let batches = vec![batch.clone(), batch]; - - let invalid_schema = Arc::new(Schema::new(vec![ - Field::new("col0", DataType::UInt32, false), - Field::new("col1", DataType::Utf8, false), - ])); - #[allow(deprecated)] - let _ = ValuesExec::try_new_from_batches(invalid_schema, batches).unwrap_err(); - } - - // Test issue: https://github.com/apache/datafusion/issues/8763 - #[test] - fn new_exec_with_non_nullable_schema() { - let schema = Arc::new(Schema::new(vec![Field::new( - "col0", - DataType::UInt32, - false, - )])); - #[allow(deprecated)] - let _ = ValuesExec::try_new(Arc::clone(&schema), vec![vec![lit(1u32)]]).unwrap(); - // Test that a null value is rejected - #[allow(deprecated)] - let _ = ValuesExec::try_new(schema, vec![vec![lit(ScalarValue::UInt32(None))]]) - .unwrap_err(); - } - - #[test] - fn values_stats_with_nulls_only() -> Result<()> { - let data = vec![ - vec![lit(ScalarValue::Null)], - vec![lit(ScalarValue::Null)], - vec![lit(ScalarValue::Null)], - ]; - let rows = data.len(); - #[allow(deprecated)] - let values = ValuesExec::try_new( - Arc::new(Schema::new(vec![Field::new("col0", DataType::Null, true)])), - data, - )?; - - assert_eq!( - values.statistics()?, - Statistics { - num_rows: Precision::Exact(rows), - total_byte_size: Precision::Exact(8), // not important - column_statistics: vec![ColumnStatistics { - null_count: Precision::Exact(rows), // there are only nulls - distinct_count: Precision::Absent, - max_value: Precision::Absent, - min_value: Precision::Absent, - sum_value: Precision::Absent, - },], - } - ); - - Ok(()) - } -} diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 92138bf6a7a1a..891fd0ae48511 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -38,7 +38,7 @@ use crate::{ ExecutionPlanProperties, InputOrderMode, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, }; -use ahash::RandomState; + use arrow::compute::take_record_batch; use arrow::{ array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder}, @@ -52,7 +52,7 @@ use datafusion_common::utils::{ evaluate_partition_ranges, get_at_indices, get_row_at_idx, }; use datafusion_common::{ - arrow_datafusion_err, exec_err, DataFusionError, HashMap, Result, + arrow_datafusion_err, exec_datafusion_err, exec_err, DataFusionError, HashMap, Result, }; use datafusion_execution::TaskContext; use datafusion_expr::window_state::{PartitionBatchState, WindowAggState}; @@ -60,9 +60,12 @@ use datafusion_expr::ColumnarValue; use datafusion_physical_expr::window::{ PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState, }; -use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{ + OrderingRequirements, PhysicalSortExpr, +}; +use ahash::RandomState; use futures::stream::Stream; use futures::{ready, StreamExt}; use hashbrown::hash_table::HashTable; @@ -111,7 +114,7 @@ impl BoundedWindowAggExec { let indices = get_ordered_partition_by_indices( window_expr[0].partition_by(), &input, - ); + )?; if indices.len() == partition_by_exprs.len() { indices } else { @@ -123,7 +126,7 @@ impl BoundedWindowAggExec { vec![] } }; - let cache = Self::compute_properties(&input, &schema, &window_expr); + let cache = Self::compute_properties(&input, &schema, &window_expr)?; Ok(Self { input, window_expr, @@ -151,7 +154,7 @@ impl BoundedWindowAggExec { // We are sure that partition by columns are always at the beginning of sort_keys // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely // to calculate partition separation points - pub fn partition_by_sort_keys(&self) -> Result { + pub fn partition_by_sort_keys(&self) -> Result> { let partition_by = self.window_expr()[0].partition_by(); get_partition_by_sort_exprs( &self.input, @@ -191,9 +194,9 @@ impl BoundedWindowAggExec { input: &Arc, schema: &SchemaRef, window_exprs: &[Arc], - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties: - let eq_properties = window_equivalence_properties(schema, input, window_exprs); + let eq_properties = window_equivalence_properties(schema, input, window_exprs)?; // As we can have repartitioning using the partition keys, this can // be either one or more than one, depending on the presence of @@ -201,13 +204,13 @@ impl BoundedWindowAggExec { let output_partitioning = input.output_partitioning().clone(); // Construct properties cache - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, // TODO: Emission type and boundedness information can be enhanced here input.pipeline_behavior(), input.boundedness(), - ) + )) } pub fn partition_keys(&self) -> Vec> { @@ -226,6 +229,23 @@ impl BoundedWindowAggExec { .unwrap_or_else(Vec::new) } } + + fn statistics_helper(&self, statistics: Statistics) -> Result { + let win_cols = self.window_expr.len(); + let input_cols = self.input.schema().fields().len(); + // TODO stats: some windowing function will maintain invariants such as min, max... + let mut column_statistics = Vec::with_capacity(win_cols + input_cols); + // copy stats of the input to the beginning of the schema. + column_statistics.extend(statistics.column_statistics); + for _ in 0..win_cols { + column_statistics.push(ColumnStatistics::new_unknown()) + } + Ok(Statistics { + num_rows: statistics.num_rows, + column_statistics, + total_byte_size: Precision::Absent, + }) + } } impl DisplayAs for BoundedWindowAggExec { @@ -241,10 +261,14 @@ impl DisplayAs for BoundedWindowAggExec { .window_expr .iter() .map(|e| { + let field = match e.field() { + Ok(f) => f.to_string(), + Err(e) => format!("{e:?}"), + }; format!( - "{}: {:?}, frame: {:?}", + "{}: {}, frame: {}", e.name().to_owned(), - e.field(), + field, e.get_window_frame() ) }) @@ -261,7 +285,7 @@ impl DisplayAs for BoundedWindowAggExec { writeln!(f, "select_list={}", g.join(", "))?; let mode = &self.input_order_mode; - writeln!(f, "mode={:?}", mode)?; + writeln!(f, "mode={mode:?}")?; } } Ok(()) @@ -286,14 +310,14 @@ impl ExecutionPlan for BoundedWindowAggExec { vec![&self.input] } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); let partition_bys = self .ordered_partition_by_indices .iter() .map(|idx| &partition_bys[*idx]); - vec![calc_requirements(partition_bys, order_keys.iter())] + vec![calc_requirements(partition_bys, order_keys)] } fn required_input_distribution(&self) -> Vec { @@ -343,21 +367,12 @@ impl ExecutionPlan for BoundedWindowAggExec { } fn statistics(&self) -> Result { - let input_stat = self.input.statistics()?; - let win_cols = self.window_expr.len(); - let input_cols = self.input.schema().fields().len(); - // TODO stats: some windowing function will maintain invariants such as min, max... - let mut column_statistics = Vec::with_capacity(win_cols + input_cols); - // copy stats of the input to the beginning of the schema. - column_statistics.extend(input_stat.column_statistics); - for _ in 0..win_cols { - column_statistics.push(ColumnStatistics::new_unknown()) - } - Ok(Statistics { - num_rows: input_stat.num_rows, - column_statistics, - total_byte_size: Precision::Absent, - }) + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + let input_stat = self.input.partition_statistics(partition)?; + self.statistics_helper(input_stat) } } @@ -414,16 +429,25 @@ trait PartitionSearcher: Send { let partition_batches = self.evaluate_partition_batches(&record_batch, window_expr)?; for (partition_row, partition_batch) in partition_batches { - let partition_batch_state = partition_buffers - .entry(partition_row) + if let Some(partition_batch_state) = partition_buffers.get_mut(&partition_row) + { + partition_batch_state.extend(&partition_batch)? + } else { + let options = RecordBatchOptions::new() + .with_row_count(Some(partition_batch.num_rows())); // Use input_schema for the buffer schema, not `record_batch.schema()` // as it may not have the "correct" schema in terms of output // nullability constraints. For details, see the following issue: // https://github.com/apache/datafusion/issues/9320 - .or_insert_with(|| { - PartitionBatchState::new(Arc::clone(self.input_schema())) - }); - partition_batch_state.extend(&partition_batch)?; + let partition_batch = RecordBatch::try_new_with_options( + Arc::clone(self.input_schema()), + partition_batch.columns().to_vec(), + &options, + )?; + let partition_batch_state = + PartitionBatchState::new_with_batch(partition_batch); + partition_buffers.insert(partition_row, partition_batch_state); + } } if self.is_mode_linear() { @@ -742,7 +766,7 @@ impl LinearSearch { /// when computing partitions. pub struct SortedSearch { /// Stores partition by columns and their ordering information - partition_by_sort_keys: LexOrdering, + partition_by_sort_keys: Vec, /// Input ordering and partition by key ordering need not be the same, so /// this vector stores the mapping between them. For instance, if the input /// is ordered by a, b and the window expression contains a PARTITION BY b, a @@ -855,9 +879,11 @@ impl SortedSearch { cur_window_expr_out_result_len }); argmin(out_col_counts).map_or(0, |(min_idx, minima)| { - for (row, count) in counts.swap_remove(min_idx).into_iter() { - let partition_batch = &mut partition_buffers[row]; - partition_batch.n_out_row = count; + let mut slowest_partition = counts.swap_remove(min_idx); + for (partition_key, partition_batch) in partition_buffers.iter_mut() { + if let Some(count) = slowest_partition.remove(partition_key) { + partition_batch.n_out_row = count; + } } minima }) @@ -1161,6 +1187,7 @@ fn get_aggregate_result_out_column( ) -> Result { let mut result = None; let mut running_length = 0; + let mut batches_to_concat = vec![]; // We assume that iteration order is according to insertion order for ( _, @@ -1172,23 +1199,31 @@ fn get_aggregate_result_out_column( { if running_length < len_to_show { let n_to_use = min(len_to_show - running_length, out_col.len()); - let slice_to_use = out_col.slice(0, n_to_use); - result = Some(match result { - Some(arr) => concat(&[&arr, &slice_to_use])?, - None => slice_to_use, - }); + let slice_to_use = if n_to_use == out_col.len() { + // avoid slice when the entire column is used + Arc::clone(out_col) + } else { + out_col.slice(0, n_to_use) + }; + batches_to_concat.push(slice_to_use); running_length += n_to_use; } else { break; } } + + if !batches_to_concat.is_empty() { + let array_refs: Vec<&dyn Array> = + batches_to_concat.iter().map(|a| a.as_ref()).collect(); + result = Some(concat(&array_refs)?); + } + if running_length != len_to_show { return exec_err!( "Generated row number should be {len_to_show}, it is {running_length}" ); } - result - .ok_or_else(|| DataFusionError::Execution("Should contain something".to_string())) + result.ok_or_else(|| exec_datafusion_err!("Should contain something")) } /// Constructs a batch from the last row of batch in the argument. @@ -1208,13 +1243,13 @@ mod tests { use crate::common::collect; use crate::expressions::PhysicalSortExpr; - use crate::projection::ProjectionExec; + use crate::projection::{ProjectionExec, ProjectionExpr}; use crate::streaming::{PartitionStream, StreamingTableExec}; use crate::test::TestMemoryExec; use crate::windows::{ create_udwf_window_expr, create_window_expr, BoundedWindowAggExec, InputOrderMode, }; - use crate::{execute_stream, get_plan_string, ExecutionPlan}; + use crate::{displayable, execute_stream, ExecutionPlan}; use arrow::array::{ builder::{Int64Builder, UInt64Builder}, @@ -1339,18 +1374,17 @@ mod tests { Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc; let args = vec![col_expr]; let partitionby_exprs = vec![col(hash, &schema)?]; - let orderby_exprs = LexOrdering::new(vec![PhysicalSortExpr { + let orderby_exprs = vec![PhysicalSortExpr { expr: col(order_by, &schema)?, options: SortOptions::default(), - }]); + }]; let window_frame = WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::CurrentRow, WindowFrameBound::Following(ScalarValue::UInt64(Some(n_future_range as u64))), ); let fn_name = format!( - "{}({:?}) PARTITION BY: [{:?}], ORDER BY: [{:?}]", - window_fn, args, partitionby_exprs, orderby_exprs + "{window_fn}({args:?}) PARTITION BY: [{partitionby_exprs:?}], ORDER BY: [{orderby_exprs:?}]" ); let input_order_mode = InputOrderMode::Linear; Ok(Arc::new(BoundedWindowAggExec::try_new( @@ -1359,10 +1393,12 @@ mod tests { fn_name, &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &orderby_exprs, Arc::new(window_frame), - &input.schema(), + input.schema(), false, + false, + None, )?], input, input_order_mode, @@ -1387,7 +1423,11 @@ mod tests { (expr, name) }) .collect::>(); - Ok(Arc::new(ProjectionExec::try_new(exprs, input)?)) + let proj_exprs: Vec = exprs + .into_iter() + .map(|(expr, alias)| ProjectionExpr { expr, alias }) + .collect(); + Ok(Arc::new(ProjectionExec::try_new(proj_exprs, input)?)) } fn task_context_helper() -> TaskContext { @@ -1456,13 +1496,14 @@ mod tests { } fn schema_orders(schema: &SchemaRef) -> Result> { - let orderings = vec![LexOrdering::new(vec![PhysicalSortExpr { + let orderings = vec![[PhysicalSortExpr { expr: col("sn", schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }])]; + }] + .into()]; Ok(orderings) } @@ -1613,7 +1654,7 @@ mod tests { Arc::new(StandardWindowExpr::new( last_value_func, &[], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new_bounds( WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -1624,7 +1665,7 @@ mod tests { Arc::new(StandardWindowExpr::new( nth_value_func1, &[], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new_bounds( WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -1635,7 +1676,7 @@ mod tests { Arc::new(StandardWindowExpr::new( nth_value_func2, &[], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new_bounds( WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -1653,16 +1694,11 @@ mod tests { let batches = collect(physical_plan.execute(0, task_ctx)?).await?; - let expected = vec![ - "BoundedWindowAggExec: wdw=[last: Ok(Field { name: \"last\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }, nth_value(-1): Ok(Field { name: \"nth_value(-1)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }, nth_value(-2): Ok(Field { name: \"nth_value(-2)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", - " DataSourceExec: partitions=1, partition_sizes=[3]", - ]; // Get string representation of the plan - let actual = get_plan_string(&physical_plan); - assert_eq!( - expected, actual, - "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); + assert_snapshot!(displayable(physical_plan.as_ref()).indent(true), @r#" + BoundedWindowAggExec: wdw=[last: Field { name: "last", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, nth_value(-1): Field { name: "nth_value(-1)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, nth_value(-2): Field { name: "nth_value(-2)", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + DataSourceExec: partitions=1, partition_sizes=[3] + "#); assert_snapshot!(batches_to_string(&batches), @r#" +---+------+---------------+---------------+ @@ -1775,18 +1811,12 @@ mod tests { let plan = projection_exec(window)?; - let expected_plan = vec![ - "ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]@2 as col_2]", - " BoundedWindowAggExec: wdw=[count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]: Ok(Field { name: \"count([Column { name: \\\"sn\\\", index: 0 }]) PARTITION BY: [[Column { name: \\\"hash\\\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \\\"sn\\\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[Linear]", - " StreamingTableExec: partition_sizes=1, projection=[sn, hash], infinite_source=true, output_ordering=[sn@0 ASC NULLS LAST]", - ]; - // Get string representation of the plan - let actual = get_plan_string(&plan); - assert_eq!( - expected_plan, actual, - "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected_plan:#?}\nactual:\n\n{actual:#?}\n\n" - ); + assert_snapshot!(displayable(plan.as_ref()).indent(true), @r#" + ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, count([Column { name: "sn", index: 0 }]) PARTITION BY: [[Column { name: "hash", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: "sn", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]@2 as col_2] + BoundedWindowAggExec: wdw=[count([Column { name: "sn", index: 0 }]) PARTITION BY: [[Column { name: "hash", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: "sn", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]: Field { name: "count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING], mode=[Linear] + StreamingTableExec: partition_sizes=1, projection=[sn, hash], infinite_source=true, output_ordering=[sn@0 ASC NULLS LAST] + "#); let task_ctx = task_context(); let batches = collect_with_timeout(plan, task_ctx, timeout_duration).await?; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index d38bf2a186a87..cd35325eb3d7a 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -22,7 +22,6 @@ mod utils; mod window_agg_exec; use std::borrow::Borrow; -use std::iter; use std::sync::Arc; use crate::{ @@ -30,11 +29,11 @@ use crate::{ InputOrderMode, PhysicalExpr, }; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow_schema::SortOptions; +use arrow::datatypes::{Schema, SchemaRef}; +use arrow_schema::{FieldRef, SortOptions}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ - PartitionEvaluator, ReversedUDWF, SetMonotonicity, WindowFrame, + LimitEffect, PartitionEvaluator, ReversedUDWF, SetMonotonicity, WindowFrame, WindowFunctionDefinition, WindowUDF, }; use datafusion_functions_window_common::expr::ExpressionArgs; @@ -42,12 +41,13 @@ use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::{ - reverse_order_bys, - window::{SlidingAggregateWindowExpr, StandardWindowFunctionExpr}, - ConstExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, +use datafusion_physical_expr::window::{ + SlidingAggregateWindowExpr, StandardWindowFunctionExpr, +}; +use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement, }; -use datafusion_physical_expr_common::sort_expr::LexRequirement; use itertools::Itertools; @@ -65,16 +65,11 @@ pub fn schema_add_window_field( window_fn: &WindowFunctionDefinition, fn_name: &str, ) -> Result> { - let data_types = args - .iter() - .map(|e| Arc::clone(e).as_ref().data_type(schema)) - .collect::>>()?; - let nullability = args + let fields = args .iter() - .map(|e| Arc::clone(e).as_ref().nullable(schema)) + .map(|e| Arc::clone(e).as_ref().return_field(schema)) .collect::>>()?; - let window_expr_return_type = - window_fn.return_type(&data_types, &nullability, fn_name)?; + let window_expr_return_field = window_fn.return_field(&fields, fn_name)?; let mut window_fields = schema .fields() .iter() @@ -84,11 +79,10 @@ pub fn schema_add_window_field( if let WindowFunctionDefinition::AggregateUDF(_) = window_fn { Ok(Arc::new(Schema::new(window_fields))) } else { - window_fields.extend_from_slice(&[Field::new( - fn_name, - window_expr_return_type, - false, - )]); + window_fields.extend_from_slice(&[window_expr_return_field + .as_ref() + .clone() + .with_name(fn_name)]); Ok(Arc::new(Schema::new(window_fields))) } } @@ -100,28 +94,41 @@ pub fn create_window_expr( name: String, args: &[Arc], partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, - input_schema: &Schema, + input_schema: SchemaRef, ignore_nulls: bool, + distinct: bool, + filter: Option>, ) -> Result> { Ok(match fun { WindowFunctionDefinition::AggregateUDF(fun) => { - let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) - .schema(Arc::new(input_schema.clone())) - .alias(name) - .with_ignore_nulls(ignore_nulls) - .build() - .map(Arc::new)?; + let aggregate = if distinct { + AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) + .schema(input_schema) + .alias(name) + .with_ignore_nulls(ignore_nulls) + .distinct() + .build() + .map(Arc::new)? + } else { + AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) + .schema(input_schema) + .alias(name) + .with_ignore_nulls(ignore_nulls) + .build() + .map(Arc::new)? + }; window_expr_from_aggregate_expr( partition_by, order_by, window_frame, aggregate, + filter, ) } WindowFunctionDefinition::WindowUDF(fun) => Arc::new(StandardWindowExpr::new( - create_udwf_window_expr(fun, args, input_schema, name, ignore_nulls)?, + create_udwf_window_expr(fun, args, &input_schema, name, ignore_nulls)?, partition_by, order_by, window_frame, @@ -132,9 +139,10 @@ pub fn create_window_expr( /// Creates an appropriate [`WindowExpr`] based on the window frame and fn window_expr_from_aggregate_expr( partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, aggregate: Arc, + filter: Option>, ) -> Arc { // Is there a potentially unlimited sized window frame? let unbounded_window = window_frame.is_ever_expanding(); @@ -145,6 +153,7 @@ fn window_expr_from_aggregate_expr( partition_by, order_by, window_frame, + filter, )) } else { Arc::new(PlainAggregateWindowExpr::new( @@ -152,6 +161,7 @@ fn window_expr_from_aggregate_expr( partition_by, order_by, window_frame, + filter, )) } } @@ -165,15 +175,15 @@ pub fn create_udwf_window_expr( ignore_nulls: bool, ) -> Result> { // need to get the types into an owned vec for some reason - let input_types: Vec<_> = args + let input_fields: Vec<_> = args .iter() - .map(|arg| arg.data_type(input_schema)) + .map(|arg| arg.return_field(input_schema)) .collect::>()?; let udwf_expr = Arc::new(WindowUDFExpr { fun: Arc::clone(fun), args: args.to_vec(), - input_types, + input_fields, name, is_reversed: false, ignore_nulls, @@ -202,8 +212,8 @@ pub struct WindowUDFExpr { args: Vec>, /// Display name name: String, - /// Types of input expressions - input_types: Vec, + /// Fields of input expressions + input_fields: Vec, /// This is set to `true` only if the user-defined window function /// expression supports evaluation in reverse order, and the /// evaluation order is reversed. @@ -223,21 +233,21 @@ impl StandardWindowFunctionExpr for WindowUDFExpr { self } - fn field(&self) -> Result { + fn field(&self) -> Result { self.fun - .field(WindowUDFFieldArgs::new(&self.input_types, &self.name)) + .field(WindowUDFFieldArgs::new(&self.input_fields, &self.name)) } fn expressions(&self) -> Vec> { self.fun - .expressions(ExpressionArgs::new(&self.args, &self.input_types)) + .expressions(ExpressionArgs::new(&self.args, &self.input_fields)) } fn create_evaluator(&self) -> Result> { self.fun .partition_evaluator_factory(PartitionEvaluatorArgs::new( &self.args, - &self.input_types, + &self.input_fields, self.is_reversed, self.ignore_nulls, )) @@ -255,7 +265,7 @@ impl StandardWindowFunctionExpr for WindowUDFExpr { fun, args: self.args.clone(), name: self.name.clone(), - input_types: self.input_types.clone(), + input_fields: self.input_fields.clone(), is_reversed: !self.is_reversed, ignore_nulls: self.ignore_nulls, })), @@ -271,6 +281,10 @@ impl StandardWindowFunctionExpr for WindowUDFExpr { PhysicalSortExpr { expr, options } }) } + + fn limit_effect(&self) -> LimitEffect { + self.fun.inner().limit_effect(self.args.as_slice()) + } } pub(crate) fn calc_requirements< @@ -279,26 +293,33 @@ pub(crate) fn calc_requirements< >( partition_by_exprs: impl IntoIterator, orderby_sort_exprs: impl IntoIterator, -) -> Option { - let mut sort_reqs = LexRequirement::new( - partition_by_exprs - .into_iter() - .map(|partition_by| { - PhysicalSortRequirement::new(Arc::clone(partition_by.borrow()), None) - }) - .collect::>(), - ); +) -> Option { + let mut sort_reqs_with_partition = partition_by_exprs + .into_iter() + .map(|partition_by| { + PhysicalSortRequirement::new(Arc::clone(partition_by.borrow()), None) + }) + .collect::>(); + let mut sort_reqs = vec![]; for element in orderby_sort_exprs.into_iter() { let PhysicalSortExpr { expr, options } = element.borrow(); - if !sort_reqs.iter().any(|e| e.expr.eq(expr)) { - sort_reqs.push(PhysicalSortRequirement::new( - Arc::clone(expr), - Some(*options), - )); + let sort_req = PhysicalSortRequirement::new(Arc::clone(expr), Some(*options)); + if !sort_reqs_with_partition.iter().any(|e| e.expr.eq(expr)) { + sort_reqs_with_partition.push(sort_req.clone()); + } + if !sort_reqs + .iter() + .any(|e: &PhysicalSortRequirement| e.expr.eq(expr)) + { + sort_reqs.push(sort_req); } } - // Convert empty result to None. Otherwise wrap result inside Some() - (!sort_reqs.is_empty()).then_some(sort_reqs) + + let mut alternatives = vec![]; + alternatives.extend(LexRequirement::new(sort_reqs_with_partition)); + alternatives.extend(LexRequirement::new(sort_reqs)); + + OrderingRequirements::new_alternatives(alternatives, false) } /// This function calculates the indices such that when partition by expressions reordered with the indices @@ -309,18 +330,18 @@ pub(crate) fn calc_requirements< pub fn get_ordered_partition_by_indices( partition_by_exprs: &[Arc], input: &Arc, -) -> Vec { +) -> Result> { let (_, indices) = input .equivalence_properties() - .find_longest_permutation(partition_by_exprs); - indices + .find_longest_permutation(partition_by_exprs)?; + Ok(indices) } pub(crate) fn get_partition_by_sort_exprs( input: &Arc, partition_by_exprs: &[Arc], ordered_partition_by_indices: &[usize], -) -> Result { +) -> Result> { let ordered_partition_exprs = ordered_partition_by_indices .iter() .map(|idx| Arc::clone(&partition_by_exprs[*idx])) @@ -329,7 +350,7 @@ pub(crate) fn get_partition_by_sort_exprs( assert!(ordered_partition_by_indices.len() <= partition_by_exprs.len()); let (ordering, _) = input .equivalence_properties() - .find_longest_permutation(&ordered_partition_exprs); + .find_longest_permutation(&ordered_partition_exprs)?; if ordering.len() == ordered_partition_exprs.len() { Ok(ordering) } else { @@ -341,11 +362,11 @@ pub(crate) fn window_equivalence_properties( schema: &SchemaRef, input: &Arc, window_exprs: &[Arc], -) -> EquivalenceProperties { +) -> Result { // We need to update the schema, so we can't directly use input's equivalence // properties. let mut window_eq_properties = EquivalenceProperties::new(Arc::clone(schema)) - .extend(input.equivalence_properties().clone()); + .extend(input.equivalence_properties().clone())?; let window_schema_len = schema.fields.len(); let input_schema_len = window_schema_len - window_exprs.len(); @@ -354,25 +375,51 @@ pub(crate) fn window_equivalence_properties( for (i, expr) in window_exprs.iter().enumerate() { let partitioning_exprs = expr.partition_by(); let no_partitioning = partitioning_exprs.is_empty(); - // Collect columns defining partitioning, and construct all `SortOptions` - // variations for them. Then, we will check each one whether it satisfies - // the existing ordering provided by the input plan. - let partition_by_orders = partitioning_exprs - .iter() - .map(|pb_order| sort_options_resolving_constant(Arc::clone(pb_order))); - let all_satisfied_lexs = partition_by_orders - .multi_cartesian_product() - .map(LexOrdering::new) - .filter(|lex| window_eq_properties.ordering_satisfy(lex)) - .collect::>(); + + // Find "one" valid ordering for partition columns to avoid exponential complexity. + // see https://github.com/apache/datafusion/issues/17401 + let mut all_satisfied_lexs = vec![]; + let mut candidate_ordering = vec![]; + + for partition_expr in partitioning_exprs.iter() { + let sort_options = + sort_options_resolving_constant(Arc::clone(partition_expr), true); + + // Try each sort option and pick the first one that works + let mut found = false; + for sort_expr in sort_options.into_iter() { + candidate_ordering.push(sort_expr); + if let Some(lex) = LexOrdering::new(candidate_ordering.clone()) { + if window_eq_properties.ordering_satisfy(lex)? { + found = true; + break; + } + } + // This option didn't work, remove it and try the next one + candidate_ordering.pop(); + } + // If no sort option works for this column, we can't build a valid ordering + if !found { + candidate_ordering.clear(); + break; + } + } + + // If we successfully built an ordering for all columns, use it + // When there are no partition expressions, candidate_ordering will be empty and won't be added + if candidate_ordering.len() == partitioning_exprs.len() { + if let Some(lex) = LexOrdering::new(candidate_ordering) { + all_satisfied_lexs.push(lex); + } + } // If there is a partitioning, and no possible ordering cannot satisfy // the input plan's orderings, then we cannot further introduce any // new orderings for the window plan. if !no_partitioning && all_satisfied_lexs.is_empty() { - return window_eq_properties; + return Ok(window_eq_properties); } else if let Some(std_expr) = expr.as_any().downcast_ref::() { - std_expr.add_equal_orderings(&mut window_eq_properties); + std_expr.add_equal_orderings(&mut window_eq_properties)?; } else if let Some(plain_expr) = expr.as_any().downcast_ref::() { @@ -380,26 +427,28 @@ pub(crate) fn window_equivalence_properties( // unbounded starting point. // First, check if the frame covers the whole table: if plain_expr.get_window_frame().end_bound.is_unbounded() { - let window_col = Column::new(expr.name(), i + input_schema_len); + let window_col = + Arc::new(Column::new(expr.name(), i + input_schema_len)) as _; if no_partitioning { // Window function has a constant result across the table: - window_eq_properties = window_eq_properties - .with_constants(iter::once(ConstExpr::new(Arc::new(window_col)))) + window_eq_properties + .add_constants(std::iter::once(ConstExpr::from(window_col)))? } else { // Window function results in a partial constant value in // some ordering. Adjust the ordering equivalences accordingly: let new_lexs = all_satisfied_lexs.into_iter().flat_map(|lex| { - let orderings = lex.take_exprs(); - let new_partial_consts = - sort_options_resolving_constant(Arc::new(window_col.clone())); + let new_partial_consts = sort_options_resolving_constant( + Arc::clone(&window_col), + false, + ); new_partial_consts.into_iter().map(move |partial| { - let mut existing = orderings.clone(); + let mut existing = lex.clone(); existing.push(partial); - LexOrdering::new(existing) + existing }) }); - window_eq_properties.add_new_orderings(new_lexs); + window_eq_properties.add_orderings(new_lexs); } } else { // The window frame is ever expanding, so set monotonicity comes @@ -407,7 +456,7 @@ pub(crate) fn window_equivalence_properties( plain_expr.add_equal_orderings( &mut window_eq_properties, window_expr_indices[i], - ); + )?; } } else if let Some(sliding_expr) = expr.as_any().downcast_ref::() @@ -425,22 +474,18 @@ pub(crate) fn window_equivalence_properties( let window_col = Column::new(expr.name(), i + input_schema_len); if no_partitioning { // Reverse set-monotonic cases with no partitioning: - let new_ordering = - vec![LexOrdering::new(vec![PhysicalSortExpr::new( - Arc::new(window_col), - SortOptions::new(increasing, true), - )])]; - window_eq_properties.add_new_orderings(new_ordering); + window_eq_properties.add_ordering([PhysicalSortExpr::new( + Arc::new(window_col), + SortOptions::new(increasing, true), + )]); } else { // Reverse set-monotonic cases for all orderings: - for lex in all_satisfied_lexs.into_iter() { - let mut existing = lex.take_exprs(); - existing.push(PhysicalSortExpr::new( + for mut lex in all_satisfied_lexs.into_iter() { + lex.push(PhysicalSortExpr::new( Arc::new(window_col.clone()), SortOptions::new(increasing, true), )); - window_eq_properties - .add_new_ordering(LexOrdering::new(existing)); + window_eq_properties.add_ordering(lex); } } } @@ -451,44 +496,73 @@ pub(crate) fn window_equivalence_properties( // utilize set-monotonicity since the set shrinks as the frame // boundary starts "touching" the end of the table. else if frame.is_causal() { - let mut args_all_lexs = sliding_expr - .get_aggregate_expr() - .expressions() - .into_iter() - .map(sort_options_resolving_constant) - .multi_cartesian_product(); - + // Find one valid ordering for aggregate arguments instead of + // checking all combinations + let aggregate_exprs = sliding_expr.get_aggregate_expr().expressions(); + let mut candidate_order = vec![]; let mut asc = false; - if args_all_lexs.any(|order| { - if let Some(f) = order.first() { - asc = !f.options.descending; + + for (idx, expr) in aggregate_exprs.iter().enumerate() { + let mut found = false; + let sort_options = + sort_options_resolving_constant(Arc::clone(expr), false); + + // Try each option and pick the first that works + for sort_expr in sort_options.into_iter() { + let is_asc = !sort_expr.options.descending; + candidate_order.push(sort_expr); + + if let Some(lex) = LexOrdering::new(candidate_order.clone()) { + if window_eq_properties.ordering_satisfy(lex)? { + if idx == 0 { + // The first column's ordering direction determines the overall + // monotonicity behavior of the window result. + // - If the aggregate has increasing set monotonicity (e.g., MAX, COUNT) + // and the first arg is ascending, the window result is increasing + // - If the aggregate has decreasing set monotonicity (e.g., MIN) + // and the first arg is ascending, the window result is also increasing + // This flag is used to determine the final window column ordering. + asc = is_asc; + } + found = true; + break; + } + } + // This option didn't work, remove it and try the next one + candidate_order.pop(); + } + + // If we couldn't extend the ordering, stop trying + if !found { + break; } - window_eq_properties.ordering_satisfy(&LexOrdering::new(order)) - }) { + } + + // Check if we successfully built a complete ordering + let satisfied = candidate_order.len() == aggregate_exprs.len() + && !aggregate_exprs.is_empty(); + + if satisfied { let increasing = set_monotonicity.eq(&SetMonotonicity::Increasing); let window_col = Column::new(expr.name(), i + input_schema_len); if increasing && (asc || no_partitioning) { - let new_ordering = - LexOrdering::new(vec![PhysicalSortExpr::new( - Arc::new(window_col), - SortOptions::new(false, false), - )]); - window_eq_properties.add_new_ordering(new_ordering); + window_eq_properties.add_ordering([PhysicalSortExpr::new( + Arc::new(window_col), + SortOptions::new(false, false), + )]); } else if !increasing && (!asc || no_partitioning) { - let new_ordering = - LexOrdering::new(vec![PhysicalSortExpr::new( - Arc::new(window_col), - SortOptions::new(true, false), - )]); - window_eq_properties.add_new_ordering(new_ordering); + window_eq_properties.add_ordering([PhysicalSortExpr::new( + Arc::new(window_col), + SortOptions::new(true, false), + )]); }; } } } } } - window_eq_properties + Ok(window_eq_properties) } /// Constructs the best-fitting windowing operator (a `WindowAggExec` or a @@ -515,7 +589,7 @@ pub fn get_best_fitting_window( let orderby_keys = window_exprs[0].order_by(); let (should_reverse, input_order_mode) = if let Some((should_reverse, input_order_mode)) = - get_window_mode(partitionby_exprs, orderby_keys, input) + get_window_mode(partitionby_exprs, orderby_keys, input)? { (should_reverse, input_order_mode) } else { @@ -581,35 +655,29 @@ pub fn get_best_fitting_window( /// the mode this window operator should work in to accommodate the existing ordering. pub fn get_window_mode( partitionby_exprs: &[Arc], - orderby_keys: &LexOrdering, + orderby_keys: &[PhysicalSortExpr], input: &Arc, -) -> Option<(bool, InputOrderMode)> { - let input_eqs = input.equivalence_properties().clone(); - let mut partition_by_reqs: LexRequirement = LexRequirement::new(vec![]); - let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs); - vec![].extend(indices.iter().map(|&idx| PhysicalSortRequirement { - expr: Arc::clone(&partitionby_exprs[idx]), - options: None, - })); - partition_by_reqs - .inner - .extend(indices.iter().map(|&idx| PhysicalSortRequirement { +) -> Result> { + let mut input_eqs = input.equivalence_properties().clone(); + let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs)?; + let partition_by_reqs = indices + .iter() + .map(|&idx| PhysicalSortRequirement { expr: Arc::clone(&partitionby_exprs[idx]), options: None, - })); + }) + .collect::>(); // Treat partition by exprs as constant. During analysis of requirements are satisfied. - let const_exprs = partitionby_exprs.iter().map(ConstExpr::from); - let partition_by_eqs = input_eqs.with_constants(const_exprs); - let order_by_reqs = LexRequirement::from(orderby_keys.clone()); - let reverse_order_by_reqs = LexRequirement::from(reverse_order_bys(orderby_keys)); - for (should_swap, order_by_reqs) in - [(false, order_by_reqs), (true, reverse_order_by_reqs)] + let const_exprs = partitionby_exprs.iter().cloned().map(ConstExpr::from); + input_eqs.add_constants(const_exprs)?; + let reverse_orderby_keys = + orderby_keys.iter().map(|e| e.reverse()).collect::>(); + for (should_swap, orderbys) in + [(false, orderby_keys), (true, reverse_orderby_keys.as_ref())] { - let req = LexRequirement::new( - [partition_by_reqs.inner.clone(), order_by_reqs.inner].concat(), - ) - .collapse(); - if partition_by_eqs.ordering_satisfy_requirement(&req) { + let mut req = partition_by_reqs.clone(); + req.extend(orderbys.iter().cloned().map(Into::into)); + if req.is_empty() || input_eqs.ordering_satisfy_requirement(req)? { // Window can be run with existing ordering let mode = if indices.len() == partitionby_exprs.len() { InputOrderMode::Sorted @@ -618,17 +686,51 @@ pub fn get_window_mode( } else { InputOrderMode::PartiallySorted(indices) }; - return Some((should_swap, mode)); + return Ok(Some((should_swap, mode))); } } - None + Ok(None) } -fn sort_options_resolving_constant(expr: Arc) -> Vec { - vec![ - PhysicalSortExpr::new(Arc::clone(&expr), SortOptions::new(false, false)), - PhysicalSortExpr::new(expr, SortOptions::new(true, true)), - ] +/// Generates sort option variations for a given expression. +/// +/// This function is used to handle constant columns in window operations. Since constant +/// columns can be considered as having any ordering, we generate multiple sort options +/// to explore different ordering possibilities. +/// +/// # Parameters +/// - `expr`: The physical expression to generate sort options for +/// - `only_monotonic`: If false, generates all 4 possible sort options (ASC/DESC × NULLS FIRST/LAST). +/// If true, generates only 2 options that preserve set monotonicity. +/// +/// # When to use `only_monotonic = false`: +/// Use for PARTITION BY columns where we want to explore all possible orderings to find +/// one that matches the existing data ordering. +/// +/// # When to use `only_monotonic = true`: +/// Use for aggregate/window function arguments where set monotonicity needs to be preserved. +/// Only generates ASC NULLS LAST and DESC NULLS FIRST because: +/// - Set monotonicity is broken if data has increasing order but nulls come first +/// - Set monotonicity is broken if data has decreasing order but nulls come last +fn sort_options_resolving_constant( + expr: Arc, + only_monotonic: bool, +) -> Vec { + if only_monotonic { + // Generate only the 2 options that preserve set monotonicity + vec![ + PhysicalSortExpr::new(Arc::clone(&expr), SortOptions::new(false, false)), // ASC NULLS LAST + PhysicalSortExpr::new(expr, SortOptions::new(true, true)), // DESC NULLS FIRST + ] + } else { + // Generate all 4 possible sort options for partition columns + vec![ + PhysicalSortExpr::new(Arc::clone(&expr), SortOptions::new(false, false)), // ASC NULLS LAST + PhysicalSortExpr::new(Arc::clone(&expr), SortOptions::new(false, true)), // ASC NULLS FIRST + PhysicalSortExpr::new(Arc::clone(&expr), SortOptions::new(true, false)), // DESC NULLS LAST + PhysicalSortExpr::new(expr, SortOptions::new(true, true)), // DESC NULLS FIRST + ] + } } #[cfg(test)] @@ -641,12 +743,13 @@ mod tests { use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use arrow::compute::SortOptions; + use arrow_schema::{DataType, Field}; use datafusion_execution::TaskContext; - use datafusion_functions_aggregate::count::count_udaf; - use futures::FutureExt; use InputOrderMode::{Linear, PartiallySorted, Sorted}; + use futures::FutureExt; + fn create_test_schema() -> Result { let nullable_column = Field::new("nullable_col", DataType::Int32, true); let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false); @@ -696,16 +799,14 @@ mod tests { /// Created a sorted Streaming Table exec pub fn streaming_table_exec( schema: &SchemaRef, - sort_exprs: impl IntoIterator, + ordering: LexOrdering, infinite_source: bool, ) -> Result> { - let sort_exprs = sort_exprs.into_iter().collect(); - Ok(Arc::new(StreamingTableExec::try_new( Arc::clone(schema), vec![], None, - Some(sort_exprs), + Some(ordering), infinite_source, None, )?)) @@ -719,25 +820,38 @@ mod tests { ( vec!["a"], vec![("b", true, true)], - vec![("a", None), ("b", Some((true, true)))], + vec![ + vec![("a", None), ("b", Some((true, true)))], + vec![("b", Some((true, true)))], + ], ), // PARTITION BY a, ORDER BY a ASC NULLS FIRST - (vec!["a"], vec![("a", true, true)], vec![("a", None)]), + ( + vec!["a"], + vec![("a", true, true)], + vec![vec![("a", None)], vec![("a", Some((true, true)))]], + ), // PARTITION BY a, ORDER BY b ASC NULLS FIRST, c DESC NULLS LAST ( vec!["a"], vec![("b", true, true), ("c", false, false)], vec![ - ("a", None), - ("b", Some((true, true))), - ("c", Some((false, false))), + vec![ + ("a", None), + ("b", Some((true, true))), + ("c", Some((false, false))), + ], + vec![("b", Some((true, true))), ("c", Some((false, false)))], ], ), // PARTITION BY a, c, ORDER BY b ASC NULLS FIRST, c DESC NULLS LAST ( vec!["a", "c"], vec![("b", true, true), ("c", false, false)], - vec![("a", None), ("c", None), ("b", Some((true, true)))], + vec![ + vec![("a", None), ("c", None), ("b", Some((true, true)))], + vec![("b", Some((true, true))), ("c", Some((false, false)))], + ], ), ]; for (pb_params, ob_params, expected_params) in test_data { @@ -749,25 +863,26 @@ mod tests { let mut orderbys = vec![]; for (col_name, descending, nulls_first) in ob_params { let expr = col(col_name, &schema)?; - let options = SortOptions { - descending, - nulls_first, - }; - orderbys.push(PhysicalSortExpr { expr, options }); + let options = SortOptions::new(descending, nulls_first); + orderbys.push(PhysicalSortExpr::new(expr, options)); } - let mut expected: Option = None; - for (col_name, reqs) in expected_params { - let options = reqs.map(|(descending, nulls_first)| SortOptions { - descending, - nulls_first, - }); - let expr = col(col_name, &schema)?; - let res = PhysicalSortRequirement::new(expr, options); - if let Some(expected) = &mut expected { - expected.push(res); - } else { - expected = Some(LexRequirement::new(vec![res])); + let mut expected: Option = None; + for expected_param in expected_params.clone() { + let mut requirements = vec![]; + for (col_name, reqs) in expected_param { + let options = reqs.map(|(descending, nulls_first)| { + SortOptions::new(descending, nulls_first) + }); + let expr = col(col_name, &schema)?; + requirements.push(PhysicalSortRequirement::new(expr, options)); + } + if let Some(requirements) = LexRequirement::new(requirements) { + if let Some(alts) = expected.as_mut() { + alts.add_alternative(requirements); + } else { + expected = Some(OrderingRequirements::new(requirements)); + } } } assert_eq!(calc_requirements(partitionbys, orderbys), expected); @@ -789,10 +904,12 @@ mod tests { "count".to_owned(), &[col("a", &schema)?], &[], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new(None)), - schema.as_ref(), + schema, + false, false, + None, )?], blocking_exec, false, @@ -893,13 +1010,14 @@ mod tests { // Columns a,c are nullable whereas b,d are not nullable. // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST, d ASC NULLS FIRST // Column e is not ordered. - let sort_exprs = vec![ + let ordering = [ sort_expr("a", &test_schema), sort_expr("b", &test_schema), sort_expr("c", &test_schema), sort_expr("d", &test_schema), - ]; - let exec_unbounded = streaming_table_exec(&test_schema, sort_exprs, true)?; + ] + .into(); + let exec_unbounded = streaming_table_exec(&test_schema, ordering, true)?; // test cases consists of vector of tuples. Where each tuple represents a single test case. // First field in the tuple is Vec where each element in the vector represents PARTITION BY columns @@ -986,7 +1104,7 @@ mod tests { partition_by_exprs.push(col(col_name, &test_schema)?); } - let mut order_by_exprs = LexOrdering::default(); + let mut order_by_exprs = vec![]; for col_name in order_by_params { let expr = col(col_name, &test_schema)?; // Give default ordering, this is same with input ordering direction @@ -994,11 +1112,8 @@ mod tests { let options = SortOptions::default(); order_by_exprs.push(PhysicalSortExpr { expr, options }); } - let res = get_window_mode( - &partition_by_exprs, - order_by_exprs.as_ref(), - &exec_unbounded, - ); + let res = + get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded)?; // Since reversibility is not important in this test. Convert Option<(bool, InputOrderMode)> to Option let res = res.map(|(_, mode)| mode); assert_eq!( @@ -1016,13 +1131,14 @@ mod tests { // Columns a,c are nullable whereas b,d are not nullable. // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST, d ASC NULLS FIRST // Column e is not ordered. - let sort_exprs = vec![ + let ordering = [ sort_expr("a", &test_schema), sort_expr("b", &test_schema), sort_expr("c", &test_schema), sort_expr("d", &test_schema), - ]; - let exec_unbounded = streaming_table_exec(&test_schema, sort_exprs, true)?; + ] + .into(); + let exec_unbounded = streaming_table_exec(&test_schema, ordering, true)?; // test cases consists of vector of tuples. Where each tuple represents a single test case. // First field in the tuple is Vec where each element in the vector represents PARTITION BY columns @@ -1151,7 +1267,7 @@ mod tests { partition_by_exprs.push(col(col_name, &test_schema)?); } - let mut order_by_exprs = LexOrdering::default(); + let mut order_by_exprs = vec![]; for (col_name, descending, nulls_first) in order_by_params { let expr = col(col_name, &test_schema)?; let options = SortOptions { @@ -1162,7 +1278,7 @@ mod tests { } assert_eq!( - get_window_mode(&partition_by_exprs, order_by_exprs.as_ref(), &exec_unbounded), + get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded)?, *expected, "Unexpected result for in unbounded test case#: {case_idx:?}, case: {test_case:?}" ); diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs index 3c42d3032ed5d..1b7cb9bb76e1b 100644 --- a/datafusion/physical-plan/src/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -44,7 +44,9 @@ use datafusion_common::stats::Precision; use datafusion_common::utils::{evaluate_partition_ranges, transpose}; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr_common::sort_expr::{ + OrderingRequirements, PhysicalSortExpr, +}; use futures::{ready, Stream, StreamExt}; @@ -79,8 +81,8 @@ impl WindowAggExec { let schema = Arc::new(schema); let ordered_partition_by_indices = - get_ordered_partition_by_indices(window_expr[0].partition_by(), &input); - let cache = Self::compute_properties(Arc::clone(&schema), &input, &window_expr); + get_ordered_partition_by_indices(window_expr[0].partition_by(), &input)?; + let cache = Self::compute_properties(Arc::clone(&schema), &input, &window_expr)?; Ok(Self { input, window_expr, @@ -107,7 +109,7 @@ impl WindowAggExec { // We are sure that partition by columns are always at the beginning of sort_keys // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely // to calculate partition separation points - pub fn partition_by_sort_keys(&self) -> Result { + pub fn partition_by_sort_keys(&self) -> Result> { let partition_by = self.window_expr()[0].partition_by(); get_partition_by_sort_exprs( &self.input, @@ -121,9 +123,9 @@ impl WindowAggExec { schema: SchemaRef, input: &Arc, window_exprs: &[Arc], - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties: - let eq_properties = window_equivalence_properties(&schema, input, window_exprs); + let eq_properties = window_equivalence_properties(&schema, input, window_exprs)?; // Get output partitioning: // Because we can have repartitioning using the partition keys this @@ -131,13 +133,13 @@ impl WindowAggExec { let output_partitioning = input.output_partitioning().clone(); // Construct properties cache: - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, // TODO: Emission type and boundedness information can be enhanced here EmissionType::Final, input.boundedness(), - ) + )) } pub fn partition_keys(&self) -> Vec> { @@ -156,6 +158,24 @@ impl WindowAggExec { .unwrap_or_else(Vec::new) } } + + fn statistics_inner(&self) -> Result { + let input_stat = self.input.partition_statistics(None)?; + let win_cols = self.window_expr.len(); + let input_cols = self.input.schema().fields().len(); + // TODO stats: some windowing function will maintain invariants such as min, max... + let mut column_statistics = Vec::with_capacity(win_cols + input_cols); + // copy stats of the input to the beginning of the schema. + column_statistics.extend(input_stat.column_statistics); + for _ in 0..win_cols { + column_statistics.push(ColumnStatistics::new_unknown()) + } + Ok(Statistics { + num_rows: input_stat.num_rows, + column_statistics, + total_byte_size: Precision::Absent, + }) + } } impl DisplayAs for WindowAggExec { @@ -216,17 +236,17 @@ impl ExecutionPlan for WindowAggExec { vec![true] } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); if self.ordered_partition_by_indices.len() < partition_bys.len() { - vec![calc_requirements(partition_bys, order_keys.iter())] + vec![calc_requirements(partition_bys, order_keys)] } else { let partition_bys = self .ordered_partition_by_indices .iter() .map(|idx| &partition_bys[*idx]); - vec![calc_requirements(partition_bys, order_keys.iter())] + vec![calc_requirements(partition_bys, order_keys)] } } @@ -271,21 +291,15 @@ impl ExecutionPlan for WindowAggExec { } fn statistics(&self) -> Result { - let input_stat = self.input.statistics()?; - let win_cols = self.window_expr.len(); - let input_cols = self.input.schema().fields().len(); - // TODO stats: some windowing function will maintain invariants such as min, max... - let mut column_statistics = Vec::with_capacity(win_cols + input_cols); - // copy stats of the input to the beginning of the schema. - column_statistics.extend(input_stat.column_statistics); - for _ in 0..win_cols { - column_statistics.push(ColumnStatistics::new_unknown()) + self.statistics_inner() + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_none() { + self.statistics_inner() + } else { + Ok(Statistics::new_unknown(&self.schema())) } - Ok(Statistics { - num_rows: input_stat.num_rows, - column_statistics, - total_byte_size: Precision::Absent, - }) } } @@ -307,7 +321,7 @@ pub struct WindowAggStream { batches: Vec, finished: bool, window_expr: Vec>, - partition_by_sort_keys: LexOrdering, + partition_by_sort_keys: Vec, baseline_metrics: BaselineMetrics, ordered_partition_by_indices: Vec, } @@ -319,7 +333,7 @@ impl WindowAggStream { window_expr: Vec>, input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, - partition_by_sort_keys: LexOrdering, + partition_by_sort_keys: Vec, ordered_partition_by_indices: Vec, ) -> Result { // In WindowAggExec all partition by columns should be ordered. diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index 126a7d0bba294..40a22f94b81f6 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -20,13 +20,14 @@ use std::any::Any; use std::sync::{Arc, Mutex}; -use crate::execution_plan::{Boundedness, EmissionType}; +use crate::coop::cooperative; +use crate::execution_plan::{Boundedness, EmissionType, SchedulingType}; use crate::memory::MemoryStream; +use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ - metrics::{ExecutionPlanMetricsSet, MetricsSet}, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, Statistics, }; -use crate::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; @@ -56,7 +57,7 @@ impl ReservedBatches { /// See /// This table serves as a mirror or buffer between each iteration of a recursive query. #[derive(Debug)] -pub(super) struct WorkTable { +pub struct WorkTable { batches: Mutex>, } @@ -131,16 +132,6 @@ impl WorkTableExec { Arc::clone(&self.schema) } - pub(super) fn with_work_table(&self, work_table: Arc) -> Self { - Self { - name: self.name.clone(), - schema: Arc::clone(&self.schema), - metrics: ExecutionPlanMetricsSet::new(), - work_table, - cache: self.cache.clone(), - } - } - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. fn compute_properties(schema: SchemaRef) -> PlanProperties { PlanProperties::new( @@ -149,6 +140,7 @@ impl WorkTableExec { EmissionType::Incremental, Boundedness::Bounded, ) + .with_scheduling_type(SchedulingType::Cooperative) } } @@ -186,14 +178,6 @@ impl ExecutionPlan for WorkTableExec { vec![] } - fn maintains_input_order(&self) -> Vec { - vec![false] - } - - fn benefits_from_input_partitioning(&self) -> Vec { - vec![false] - } - fn with_new_children( self: Arc, _: Vec>, @@ -214,10 +198,11 @@ impl ExecutionPlan for WorkTableExec { ); } let batch = self.work_table.take()?; - Ok(Box::pin( + + let stream = MemoryStream::try_new(batch.batches, Arc::clone(&self.schema), None)? - .with_reservation(batch.reservation), - )) + .with_reservation(batch.reservation); + Ok(Box::pin(cooperative(stream))) } fn metrics(&self) -> Option { @@ -227,6 +212,33 @@ impl ExecutionPlan for WorkTableExec { fn statistics(&self) -> Result { Ok(Statistics::new_unknown(&self.schema())) } + + fn partition_statistics(&self, _partition: Option) -> Result { + Ok(Statistics::new_unknown(&self.schema())) + } + + /// Injects run-time state into this `WorkTableExec`. + /// + /// The only state this node currently understands is an [`Arc`]. + /// If `state` can be down-cast to that type, a new `WorkTableExec` backed + /// by the provided work table is returned. Otherwise `None` is returned + /// so that callers can attempt to propagate the state further down the + /// execution plan tree. + fn with_new_state( + &self, + state: Arc, + ) -> Option> { + // Down-cast to the expected state type; propagate `None` on failure + let work_table = state.downcast::().ok()?; + + Some(Arc::new(Self { + name: self.name.clone(), + schema: Arc::clone(&self.schema), + metrics: ExecutionPlanMetricsSet::new(), + work_table, + cache: self.cache.clone(), + })) + } } #[cfg(test)] diff --git a/datafusion/proto-common/Cargo.toml b/datafusion/proto-common/Cargo.toml index 957cbc253616b..c67c8892a3ded 100644 --- a/datafusion/proto-common/Cargo.toml +++ b/datafusion/proto-common/Cargo.toml @@ -19,9 +19,9 @@ name = "datafusion-proto-common" description = "Protobuf serialization of DataFusion common types" keywords = ["arrow", "query", "sql"] +readme = "README.md" version = { workspace = true } edition = { workspace = true } -readme = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } @@ -39,7 +39,7 @@ name = "datafusion_proto_common" [features] default = [] -json = ["serde", "serde_json", "pbjson"] +json = ["serde", "pbjson"] [dependencies] arrow = { workspace = true } @@ -47,7 +47,6 @@ datafusion-common = { workspace = true } pbjson = { workspace = true, optional = true } prost = { workspace = true } serde = { version = "1.0", optional = true } -serde_json = { workspace = true, optional = true } [dev-dependencies] doc-comment = { workspace = true } diff --git a/datafusion/proto-common/README.md b/datafusion/proto-common/README.md index c8b46424f701e..9c4aa707b0ea6 100644 --- a/datafusion/proto-common/README.md +++ b/datafusion/proto-common/README.md @@ -17,12 +17,21 @@ under the License. --> -# `datafusion-proto-common`: Apache DataFusion Protobuf Serialization / Deserialization +# Apache DataFusion Protobuf Common Serialization / Deserialization -This crate contains code to convert Apache [DataFusion] primitive types to and from -bytes, which can be useful for sending data over the network. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. + +This crate contains code to convert DataFusion primitive types to and from +bytes using [Protocol Buffers], which can be useful for sending data over the network. See [API Docs] for details and examples. -[datafusion]: https://datafusion.apache.org +Most projects should use the [`datafusion-proto`] crate directly, which re-exports +this module. If you are already using the [`datafusion-proto`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[protocol buffers]: https://protobuf.dev/ +[`datafusion-proto`]: https://crates.io/crates/datafusion-proto [api docs]: http://docs.rs/datafusion-proto/latest diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index cfd3368b0c5ee..ef56d2697d818 100644 --- a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -34,5 +34,5 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic -pbjson-build = "=0.7.0" -prost-build = "=0.13.5" +pbjson-build = "=0.8.0" +prost-build = "=0.14.1" diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index bbeea5e1ec237..267953556b166 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -55,6 +55,8 @@ message NdJsonFormat { JsonOptions options = 1; } +message ArrowFormat {} + message PrimaryKeyConstraint{ repeated uint64 indices = 1; @@ -85,6 +87,7 @@ enum JoinType { RIGHTSEMI = 6; RIGHTANTI = 7; LEFTMARK = 8; + RIGHTMARK = 9; } enum JoinConstraint { @@ -92,6 +95,11 @@ enum JoinConstraint { USING = 1; } +enum NullEquality { + NULL_EQUALS_NOTHING = 0; + NULL_EQUALS_NULL = 1; +} + message AvroOptions {} message ArrowOptions {} @@ -108,7 +116,6 @@ message Field { // for complex data types like structs, unions repeated Field children = 4; map metadata = 5; - bool dict_ordered = 6; } message Timestamp{ @@ -129,7 +136,19 @@ enum IntervalUnit{ MonthDayNano = 2; } -message Decimal{ +message Decimal32Type { + reserved 1, 2; + uint32 precision = 3; + int32 scale = 4; +} + +message Decimal64Type { + reserved 1, 2; + uint32 precision = 3; + int32 scale = 4; +} + +message Decimal128Type { reserved 1, 2; uint32 precision = 3; int32 scale = 4; @@ -279,6 +298,8 @@ message ScalarValue{ ScalarNestedValue struct_value = 32; ScalarNestedValue map_value = 41; + Decimal32 decimal32_value = 43; + Decimal64 decimal64_value = 44; Decimal128 decimal128_value = 20; Decimal256 decimal256_value = 39; @@ -303,6 +324,18 @@ message ScalarValue{ } } +message Decimal32{ + bytes value = 1; + int64 p = 2; + int64 s = 3; +} + +message Decimal64{ + bytes value = 1; + int64 p = 2; + int64 s = 3; +} + message Decimal128{ bytes value = 1; int64 p = 2; @@ -345,7 +378,9 @@ message ArrowType{ TimeUnit TIME32 = 21 ; TimeUnit TIME64 = 22 ; IntervalUnit INTERVAL = 23 ; - Decimal DECIMAL = 24 ; + Decimal32Type DECIMAL32 = 40; + Decimal64Type DECIMAL64 = 41; + Decimal128Type DECIMAL128 = 24; Decimal256Type DECIMAL256 = 36; List LIST = 25; List LARGE_LIST = 26; @@ -425,6 +460,7 @@ message CsvOptions { bytes double_quote = 15; // Indicates if quotes are doubled bytes newlines_in_values = 16; // Indicates if newlines are supported in values bytes terminator = 17; // Optional terminator character as a byte + bytes truncated_rows = 18; // Indicates if truncated rows are allowed } // Options controlling CSV format @@ -473,9 +509,7 @@ message ParquetColumnOptions { uint64 bloom_filter_ndv = 7; } - oneof max_statistics_size_opt { - uint32 max_statistics_size = 8; - } + reserved 8; // used to be uint32 max_statistics_size = 8; } message ParquetOptions { @@ -514,9 +548,7 @@ message ParquetOptions { string statistics_enabled = 13; } - oneof max_statistics_size_opt { - uint64 max_statistics_size = 14; - } + reserved 14; // used to be uint32 max_statistics_size = 20; oneof column_index_truncate_length_opt { uint64 column_index_truncate_length = 17; @@ -545,6 +577,14 @@ message ParquetOptions { uint64 max_row_group_size = 15; string created_by = 16; + + oneof coerce_int96_opt { + string coerce_int96 = 32; + } + + oneof max_predicate_cache_size_opt { + uint64 max_predicate_cache_size = 33; + } } enum JoinSide { diff --git a/datafusion/proto-common/src/common.rs b/datafusion/proto-common/src/common.rs index 61711dcf8e088..9af63e3b07365 100644 --- a/datafusion/proto-common/src/common.rs +++ b/datafusion/proto-common/src/common.rs @@ -17,6 +17,7 @@ use datafusion_common::{internal_datafusion_err, DataFusionError}; +/// Return a `DataFusionError::Internal` with the given message pub fn proto_error>(message: S) -> DataFusionError { internal_datafusion_err!("{}", message.into()) } diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index da43a97899565..2d07fb8410210 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -257,7 +257,15 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { arrow_type::ArrowTypeEnum::Interval(interval_unit) => { DataType::Interval(parse_i32_to_interval_unit(interval_unit)?) } - arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal { + arrow_type::ArrowTypeEnum::Decimal32(protobuf::Decimal32Type { + precision, + scale, + }) => DataType::Decimal32(*precision as u8, *scale as i8), + arrow_type::ArrowTypeEnum::Decimal64(protobuf::Decimal64Type { + precision, + scale, + }) => DataType::Decimal64(*precision as u8, *scale as i8), + arrow_type::ArrowTypeEnum::Decimal128(protobuf::Decimal128Type { precision, scale, }) => DataType::Decimal128(*precision as u8, *scale as i8), @@ -469,6 +477,14 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { let null_type: DataType = v.try_into()?; null_type.try_into().map_err(Error::DataFusionError)? } + Value::Decimal32Value(val) => { + let array = vec_to_array(val.value.clone()); + Self::Decimal32(Some(i32::from_be_bytes(array)), val.p as u8, val.s as i8) + } + Value::Decimal64Value(val) => { + let array = vec_to_array(val.value.clone()); + Self::Decimal64(Some(i64::from_be_bytes(array)), val.p as u8, val.s as i8) + } Value::Decimal128Value(val) => { let array = vec_to_array(val.value.clone()); Self::Decimal128( @@ -900,6 +916,7 @@ impl TryFrom<&protobuf::CsvOptions> for CsvOptions { null_regex: (!proto_opts.null_regex.is_empty()) .then(|| proto_opts.null_regex.clone()), comment: proto_opts.comment.first().copied(), + truncated_rows: proto_opts.truncated_rows.first().map(|h| *h != 0), }) } } @@ -938,12 +955,6 @@ impl TryFrom<&protobuf::ParquetOptions> for ParquetOptions { protobuf::parquet_options::StatisticsEnabledOpt::StatisticsEnabled(v) => Some(v), }) .unwrap_or(None), - max_statistics_size: value - .max_statistics_size_opt.as_ref() - .map(|opt| match opt { - protobuf::parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => Some(*v as usize), - }) - .unwrap_or(None), max_row_group_size: value.max_row_group_size as usize, created_by: value.created_by.clone(), column_index_truncate_length: value @@ -984,7 +995,13 @@ impl TryFrom<&protobuf::ParquetOptions> for ParquetOptions { maximum_buffered_record_batches_per_stream: value.maximum_buffered_record_batches_per_stream as usize, schema_force_view_types: value.schema_force_view_types, binary_as_string: value.binary_as_string, + coerce_int96: value.coerce_int96_opt.clone().map(|opt| match opt { + protobuf::parquet_options::CoerceInt96Opt::CoerceInt96(v) => Some(v), + }).unwrap_or(None), skip_arrow_metadata: value.skip_arrow_metadata, + max_predicate_cache_size: value.max_predicate_cache_size_opt.map(|opt| match opt { + protobuf::parquet_options::MaxPredicateCacheSizeOpt::MaxPredicateCacheSize(v) => Some(v as usize), + }).unwrap_or(None), }) } } @@ -1006,12 +1023,6 @@ impl TryFrom<&protobuf::ParquetColumnOptions> for ParquetColumnOptions { protobuf::parquet_column_options::StatisticsEnabledOpt::StatisticsEnabled(v) => Some(v), }) .unwrap_or(None), - max_statistics_size: value - .max_statistics_size_opt - .map(|opt| match opt { - protobuf::parquet_column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => Some(v as usize), - }) - .unwrap_or(None), encoding: value .encoding_opt.clone() .map(|opt| match opt { @@ -1063,6 +1074,7 @@ impl TryFrom<&protobuf::TableParquetOptions> for TableParquetOptions { .unwrap(), column_specific_options, key_value_metadata: Default::default(), + crypto: Default::default(), }) } } diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index b0241fd47a26f..e63f345459b8f 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -1,3 +1,74 @@ +impl serde::Serialize for ArrowFormat { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let len = 0; + let struct_ser = serializer.serialize_struct("datafusion_common.ArrowFormat", len)?; + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ArrowFormat { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + Err(serde::de::Error::unknown_field(value, FIELDS)) + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ArrowFormat; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ArrowFormat") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + while map_.next_key::()?.is_some() { + let _ = map_.next_value::()?; + } + Ok(ArrowFormat { + }) + } + } + deserializer.deserialize_struct("datafusion_common.ArrowFormat", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ArrowOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -172,8 +243,14 @@ impl serde::Serialize for ArrowType { .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; struct_ser.serialize_field("INTERVAL", &v)?; } - arrow_type::ArrowTypeEnum::Decimal(v) => { - struct_ser.serialize_field("DECIMAL", v)?; + arrow_type::ArrowTypeEnum::Decimal32(v) => { + struct_ser.serialize_field("DECIMAL32", v)?; + } + arrow_type::ArrowTypeEnum::Decimal64(v) => { + struct_ser.serialize_field("DECIMAL64", v)?; + } + arrow_type::ArrowTypeEnum::Decimal128(v) => { + struct_ser.serialize_field("DECIMAL128", v)?; } arrow_type::ArrowTypeEnum::Decimal256(v) => { struct_ser.serialize_field("DECIMAL256", v)?; @@ -243,7 +320,9 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "TIME32", "TIME64", "INTERVAL", - "DECIMAL", + "DECIMAL32", + "DECIMAL64", + "DECIMAL128", "DECIMAL256", "LIST", "LARGE_LIST", @@ -285,7 +364,9 @@ impl<'de> serde::Deserialize<'de> for ArrowType { Time32, Time64, Interval, - Decimal, + Decimal32, + Decimal64, + Decimal128, Decimal256, List, LargeList, @@ -342,7 +423,9 @@ impl<'de> serde::Deserialize<'de> for ArrowType { "TIME32" => Ok(GeneratedField::Time32), "TIME64" => Ok(GeneratedField::Time64), "INTERVAL" => Ok(GeneratedField::Interval), - "DECIMAL" => Ok(GeneratedField::Decimal), + "DECIMAL32" => Ok(GeneratedField::Decimal32), + "DECIMAL64" => Ok(GeneratedField::Decimal64), + "DECIMAL128" => Ok(GeneratedField::Decimal128), "DECIMAL256" => Ok(GeneratedField::Decimal256), "LIST" => Ok(GeneratedField::List), "LARGELIST" | "LARGE_LIST" => Ok(GeneratedField::LargeList), @@ -557,11 +640,25 @@ impl<'de> serde::Deserialize<'de> for ArrowType { } arrow_type_enum__ = map_.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Interval(x as i32)); } - GeneratedField::Decimal => { + GeneratedField::Decimal32 => { if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("DECIMAL")); + return Err(serde::de::Error::duplicate_field("DECIMAL32")); } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal) + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal32) +; + } + GeneratedField::Decimal64 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("DECIMAL64")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal64) +; + } + GeneratedField::Decimal128 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("DECIMAL128")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal128) ; } GeneratedField::Decimal256 => { @@ -1566,6 +1663,9 @@ impl serde::Serialize for CsvOptions { if !self.terminator.is_empty() { len += 1; } + if !self.truncated_rows.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvOptions", len)?; if !self.has_header.is_empty() { #[allow(clippy::needless_borrow)] @@ -1638,6 +1738,11 @@ impl serde::Serialize for CsvOptions { #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("terminator", pbjson::private::base64::encode(&self.terminator).as_str())?; } + if !self.truncated_rows.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("truncatedRows", pbjson::private::base64::encode(&self.truncated_rows).as_str())?; + } struct_ser.end() } } @@ -1676,6 +1781,8 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { "newlines_in_values", "newlinesInValues", "terminator", + "truncated_rows", + "truncatedRows", ]; #[allow(clippy::enum_variant_names)] @@ -1697,6 +1804,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { DoubleQuote, NewlinesInValues, Terminator, + TruncatedRows, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -1735,6 +1843,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { "doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote), "newlinesInValues" | "newlines_in_values" => Ok(GeneratedField::NewlinesInValues), "terminator" => Ok(GeneratedField::Terminator), + "truncatedRows" | "truncated_rows" => Ok(GeneratedField::TruncatedRows), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -1771,6 +1880,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { let mut double_quote__ = None; let mut newlines_in_values__ = None; let mut terminator__ = None; + let mut truncated_rows__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::HasHeader => { @@ -1893,6 +2003,14 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } + GeneratedField::TruncatedRows => { + if truncated_rows__.is_some() { + return Err(serde::de::Error::duplicate_field("truncatedRows")); + } + truncated_rows__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } } } Ok(CsvOptions { @@ -1913,6 +2031,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { double_quote: double_quote__.unwrap_or_default(), newlines_in_values: newlines_in_values__.unwrap_or_default(), terminator: terminator__.unwrap_or_default(), + truncated_rows: truncated_rows__.unwrap_or_default(), }) } } @@ -2186,10 +2305,396 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { }) } } - deserializer.deserialize_struct("datafusion_common.CsvWriterOptions", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion_common.CsvWriterOptions", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Decimal128 { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.value.is_empty() { + len += 1; + } + if self.p != 0 { + len += 1; + } + if self.s != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal128", len)?; + if !self.value.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; + } + if self.p != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; + } + if self.s != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Decimal128 { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "value", + "p", + "s", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Value, + P, + S, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "value" => Ok(GeneratedField::Value), + "p" => Ok(GeneratedField::P), + "s" => Ok(GeneratedField::S), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Decimal128; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Decimal128") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value__ = None; + let mut p__ = None; + let mut s__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::P => { + if p__.is_some() { + return Err(serde::de::Error::duplicate_field("p")); + } + p__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::S => { + if s__.is_some() { + return Err(serde::de::Error::duplicate_field("s")); + } + s__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(Decimal128 { + value: value__.unwrap_or_default(), + p: p__.unwrap_or_default(), + s: s__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Decimal128", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Decimal128Type { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.precision != 0 { + len += 1; + } + if self.scale != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal128Type", len)?; + if self.precision != 0 { + struct_ser.serialize_field("precision", &self.precision)?; + } + if self.scale != 0 { + struct_ser.serialize_field("scale", &self.scale)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Decimal128Type { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "precision", + "scale", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Precision, + Scale, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "precision" => Ok(GeneratedField::Precision), + "scale" => Ok(GeneratedField::Scale), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Decimal128Type; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Decimal128Type") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut precision__ = None; + let mut scale__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Precision => { + if precision__.is_some() { + return Err(serde::de::Error::duplicate_field("precision")); + } + precision__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Scale => { + if scale__.is_some() { + return Err(serde::de::Error::duplicate_field("scale")); + } + scale__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(Decimal128Type { + precision: precision__.unwrap_or_default(), + scale: scale__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Decimal128Type", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Decimal256 { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.value.is_empty() { + len += 1; + } + if self.p != 0 { + len += 1; + } + if self.s != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256", len)?; + if !self.value.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; + } + if self.p != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; + } + if self.s != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Decimal256 { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "value", + "p", + "s", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Value, + P, + S, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "value" => Ok(GeneratedField::Value), + "p" => Ok(GeneratedField::P), + "s" => Ok(GeneratedField::S), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Decimal256; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Decimal256") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value__ = None; + let mut p__ = None; + let mut s__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::P => { + if p__.is_some() { + return Err(serde::de::Error::duplicate_field("p")); + } + p__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::S => { + if s__.is_some() { + return Err(serde::de::Error::duplicate_field("s")); + } + s__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(Decimal256 { + value: value__.unwrap_or_default(), + p: p__.unwrap_or_default(), + s: s__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Decimal256", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Decimal { +impl serde::Serialize for Decimal256Type { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -2203,7 +2708,7 @@ impl serde::Serialize for Decimal { if self.scale != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256Type", len)?; if self.precision != 0 { struct_ser.serialize_field("precision", &self.precision)?; } @@ -2213,7 +2718,7 @@ impl serde::Serialize for Decimal { struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Decimal { +impl<'de> serde::Deserialize<'de> for Decimal256Type { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -2260,13 +2765,13 @@ impl<'de> serde::Deserialize<'de> for Decimal { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Decimal; + type Value = Decimal256Type; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion_common.Decimal") + formatter.write_str("struct datafusion_common.Decimal256Type") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -2292,16 +2797,16 @@ impl<'de> serde::Deserialize<'de> for Decimal { } } } - Ok(Decimal { + Ok(Decimal256Type { precision: precision__.unwrap_or_default(), scale: scale__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion_common.Decimal", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion_common.Decimal256Type", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Decimal128 { +impl serde::Serialize for Decimal32 { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -2318,7 +2823,7 @@ impl serde::Serialize for Decimal128 { if self.s != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal128", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal32", len)?; if !self.value.is_empty() { #[allow(clippy::needless_borrow)] #[allow(clippy::needless_borrows_for_generic_args)] @@ -2337,7 +2842,7 @@ impl serde::Serialize for Decimal128 { struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Decimal128 { +impl<'de> serde::Deserialize<'de> for Decimal32 { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -2387,13 +2892,13 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Decimal128; + type Value = Decimal32; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion_common.Decimal128") + formatter.write_str("struct datafusion_common.Decimal32") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -2428,17 +2933,129 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { } } } - Ok(Decimal128 { + Ok(Decimal32 { value: value__.unwrap_or_default(), p: p__.unwrap_or_default(), s: s__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion_common.Decimal128", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion_common.Decimal32", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Decimal256 { +impl serde::Serialize for Decimal32Type { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.precision != 0 { + len += 1; + } + if self.scale != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal32Type", len)?; + if self.precision != 0 { + struct_ser.serialize_field("precision", &self.precision)?; + } + if self.scale != 0 { + struct_ser.serialize_field("scale", &self.scale)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Decimal32Type { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "precision", + "scale", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Precision, + Scale, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "precision" => Ok(GeneratedField::Precision), + "scale" => Ok(GeneratedField::Scale), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Decimal32Type; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Decimal32Type") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut precision__ = None; + let mut scale__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Precision => { + if precision__.is_some() { + return Err(serde::de::Error::duplicate_field("precision")); + } + precision__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Scale => { + if scale__.is_some() { + return Err(serde::de::Error::duplicate_field("scale")); + } + scale__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(Decimal32Type { + precision: precision__.unwrap_or_default(), + scale: scale__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Decimal32Type", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Decimal64 { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -2455,7 +3072,7 @@ impl serde::Serialize for Decimal256 { if self.s != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal64", len)?; if !self.value.is_empty() { #[allow(clippy::needless_borrow)] #[allow(clippy::needless_borrows_for_generic_args)] @@ -2474,7 +3091,7 @@ impl serde::Serialize for Decimal256 { struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Decimal256 { +impl<'de> serde::Deserialize<'de> for Decimal64 { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -2524,13 +3141,13 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Decimal256; + type Value = Decimal64; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion_common.Decimal256") + formatter.write_str("struct datafusion_common.Decimal64") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -2565,17 +3182,17 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { } } } - Ok(Decimal256 { + Ok(Decimal64 { value: value__.unwrap_or_default(), p: p__.unwrap_or_default(), s: s__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion_common.Decimal256", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion_common.Decimal64", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Decimal256Type { +impl serde::Serialize for Decimal64Type { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -2589,7 +3206,7 @@ impl serde::Serialize for Decimal256Type { if self.scale != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256Type", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal64Type", len)?; if self.precision != 0 { struct_ser.serialize_field("precision", &self.precision)?; } @@ -2599,7 +3216,7 @@ impl serde::Serialize for Decimal256Type { struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Decimal256Type { +impl<'de> serde::Deserialize<'de> for Decimal64Type { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -2646,13 +3263,13 @@ impl<'de> serde::Deserialize<'de> for Decimal256Type { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Decimal256Type; + type Value = Decimal64Type; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion_common.Decimal256Type") + formatter.write_str("struct datafusion_common.Decimal64Type") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -2678,13 +3295,13 @@ impl<'de> serde::Deserialize<'de> for Decimal256Type { } } } - Ok(Decimal256Type { + Ok(Decimal64Type { precision: precision__.unwrap_or_default(), scale: scale__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion_common.Decimal256Type", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion_common.Decimal64Type", FIELDS, GeneratedVisitor) } } impl serde::Serialize for DfField { @@ -3107,9 +3724,6 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { len += 1; } - if self.dict_ordered { - len += 1; - } let mut struct_ser = serializer.serialize_struct("datafusion_common.Field", len)?; if !self.name.is_empty() { struct_ser.serialize_field("name", &self.name)?; @@ -3126,9 +3740,6 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { struct_ser.serialize_field("metadata", &self.metadata)?; } - if self.dict_ordered { - struct_ser.serialize_field("dictOrdered", &self.dict_ordered)?; - } struct_ser.end() } } @@ -3145,8 +3756,6 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable", "children", "metadata", - "dict_ordered", - "dictOrdered", ]; #[allow(clippy::enum_variant_names)] @@ -3156,7 +3765,6 @@ impl<'de> serde::Deserialize<'de> for Field { Nullable, Children, Metadata, - DictOrdered, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3183,7 +3791,6 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable" => Ok(GeneratedField::Nullable), "children" => Ok(GeneratedField::Children), "metadata" => Ok(GeneratedField::Metadata), - "dictOrdered" | "dict_ordered" => Ok(GeneratedField::DictOrdered), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3208,7 +3815,6 @@ impl<'de> serde::Deserialize<'de> for Field { let mut nullable__ = None; let mut children__ = None; let mut metadata__ = None; - let mut dict_ordered__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { @@ -3243,12 +3849,6 @@ impl<'de> serde::Deserialize<'de> for Field { map_.next_value::>()? ); } - GeneratedField::DictOrdered => { - if dict_ordered__.is_some() { - return Err(serde::de::Error::duplicate_field("dictOrdered")); - } - dict_ordered__ = Some(map_.next_value()?); - } } } Ok(Field { @@ -3257,7 +3857,6 @@ impl<'de> serde::Deserialize<'de> for Field { nullable: nullable__.unwrap_or_default(), children: children__.unwrap_or_default(), metadata: metadata__.unwrap_or_default(), - dict_ordered: dict_ordered__.unwrap_or_default(), }) } } @@ -3856,6 +4455,7 @@ impl serde::Serialize for JoinType { Self::Rightsemi => "RIGHTSEMI", Self::Rightanti => "RIGHTANTI", Self::Leftmark => "LEFTMARK", + Self::Rightmark => "RIGHTMARK", }; serializer.serialize_str(variant) } @@ -3876,6 +4476,7 @@ impl<'de> serde::Deserialize<'de> for JoinType { "RIGHTSEMI", "RIGHTANTI", "LEFTMARK", + "RIGHTMARK", ]; struct GeneratedVisitor; @@ -3925,6 +4526,7 @@ impl<'de> serde::Deserialize<'de> for JoinType { "RIGHTSEMI" => Ok(JoinType::Rightsemi), "RIGHTANTI" => Ok(JoinType::Rightanti), "LEFTMARK" => Ok(JoinType::Leftmark), + "RIGHTMARK" => Ok(JoinType::Rightmark), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -4433,6 +5035,77 @@ impl<'de> serde::Deserialize<'de> for NdJsonFormat { deserializer.deserialize_struct("datafusion_common.NdJsonFormat", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for NullEquality { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::NullEqualsNothing => "NULL_EQUALS_NOTHING", + Self::NullEqualsNull => "NULL_EQUALS_NULL", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for NullEquality { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "NULL_EQUALS_NOTHING", + "NULL_EQUALS_NULL", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = NullEquality; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "NULL_EQUALS_NOTHING" => Ok(NullEquality::NullEqualsNothing), + "NULL_EQUALS_NULL" => Ok(NullEquality::NullEqualsNull), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for ParquetColumnOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -4462,9 +5135,6 @@ impl serde::Serialize for ParquetColumnOptions { if self.bloom_filter_ndv_opt.is_some() { len += 1; } - if self.max_statistics_size_opt.is_some() { - len += 1; - } let mut struct_ser = serializer.serialize_struct("datafusion_common.ParquetColumnOptions", len)?; if let Some(v) = self.bloom_filter_enabled_opt.as_ref() { match v { @@ -4517,13 +5187,6 @@ impl serde::Serialize for ParquetColumnOptions { } } } - if let Some(v) = self.max_statistics_size_opt.as_ref() { - match v { - parquet_column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => { - struct_ser.serialize_field("maxStatisticsSize", v)?; - } - } - } struct_ser.end() } } @@ -4546,8 +5209,6 @@ impl<'de> serde::Deserialize<'de> for ParquetColumnOptions { "bloomFilterFpp", "bloom_filter_ndv", "bloomFilterNdv", - "max_statistics_size", - "maxStatisticsSize", ]; #[allow(clippy::enum_variant_names)] @@ -4559,7 +5220,6 @@ impl<'de> serde::Deserialize<'de> for ParquetColumnOptions { StatisticsEnabled, BloomFilterFpp, BloomFilterNdv, - MaxStatisticsSize, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4588,7 +5248,6 @@ impl<'de> serde::Deserialize<'de> for ParquetColumnOptions { "statisticsEnabled" | "statistics_enabled" => Ok(GeneratedField::StatisticsEnabled), "bloomFilterFpp" | "bloom_filter_fpp" => Ok(GeneratedField::BloomFilterFpp), "bloomFilterNdv" | "bloom_filter_ndv" => Ok(GeneratedField::BloomFilterNdv), - "maxStatisticsSize" | "max_statistics_size" => Ok(GeneratedField::MaxStatisticsSize), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4615,7 +5274,6 @@ impl<'de> serde::Deserialize<'de> for ParquetColumnOptions { let mut statistics_enabled_opt__ = None; let mut bloom_filter_fpp_opt__ = None; let mut bloom_filter_ndv_opt__ = None; - let mut max_statistics_size_opt__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::BloomFilterEnabled => { @@ -4660,12 +5318,6 @@ impl<'de> serde::Deserialize<'de> for ParquetColumnOptions { } bloom_filter_ndv_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_column_options::BloomFilterNdvOpt::BloomFilterNdv(x.0)); } - GeneratedField::MaxStatisticsSize => { - if max_statistics_size_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("maxStatisticsSize")); - } - max_statistics_size_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(x.0)); - } } } Ok(ParquetColumnOptions { @@ -4676,7 +5328,6 @@ impl<'de> serde::Deserialize<'de> for ParquetColumnOptions { statistics_enabled_opt: statistics_enabled_opt__, bloom_filter_fpp_opt: bloom_filter_fpp_opt__, bloom_filter_ndv_opt: bloom_filter_ndv_opt__, - max_statistics_size_opt: max_statistics_size_opt__, }) } } @@ -4963,9 +5614,6 @@ impl serde::Serialize for ParquetOptions { if self.statistics_enabled_opt.is_some() { len += 1; } - if self.max_statistics_size_opt.is_some() { - len += 1; - } if self.column_index_truncate_length_opt.is_some() { len += 1; } @@ -4981,6 +5629,12 @@ impl serde::Serialize for ParquetOptions { if self.bloom_filter_ndv_opt.is_some() { len += 1; } + if self.coerce_int96_opt.is_some() { + len += 1; + } + if self.max_predicate_cache_size_opt.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion_common.ParquetOptions", len)?; if self.enable_page_index { struct_ser.serialize_field("enablePageIndex", &self.enable_page_index)?; @@ -5086,15 +5740,6 @@ impl serde::Serialize for ParquetOptions { } } } - if let Some(v) = self.max_statistics_size_opt.as_ref() { - match v { - parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => { - #[allow(clippy::needless_borrow)] - #[allow(clippy::needless_borrows_for_generic_args)] - struct_ser.serialize_field("maxStatisticsSize", ToString::to_string(&v).as_str())?; - } - } - } if let Some(v) = self.column_index_truncate_length_opt.as_ref() { match v { parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(v) => { @@ -5136,6 +5781,22 @@ impl serde::Serialize for ParquetOptions { } } } + if let Some(v) = self.coerce_int96_opt.as_ref() { + match v { + parquet_options::CoerceInt96Opt::CoerceInt96(v) => { + struct_ser.serialize_field("coerceInt96", v)?; + } + } + } + if let Some(v) = self.max_predicate_cache_size_opt.as_ref() { + match v { + parquet_options::MaxPredicateCacheSizeOpt::MaxPredicateCacheSize(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("maxPredicateCacheSize", ToString::to_string(&v).as_str())?; + } + } + } struct_ser.end() } } @@ -5192,8 +5853,6 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { "dictionaryEnabled", "statistics_enabled", "statisticsEnabled", - "max_statistics_size", - "maxStatisticsSize", "column_index_truncate_length", "columnIndexTruncateLength", "statistics_truncate_length", @@ -5203,6 +5862,10 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { "bloomFilterFpp", "bloom_filter_ndv", "bloomFilterNdv", + "coerce_int96", + "coerceInt96", + "max_predicate_cache_size", + "maxPredicateCacheSize", ]; #[allow(clippy::enum_variant_names)] @@ -5231,12 +5894,13 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { Compression, DictionaryEnabled, StatisticsEnabled, - MaxStatisticsSize, ColumnIndexTruncateLength, StatisticsTruncateLength, Encoding, BloomFilterFpp, BloomFilterNdv, + CoerceInt96, + MaxPredicateCacheSize, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5282,12 +5946,13 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { "compression" => Ok(GeneratedField::Compression), "dictionaryEnabled" | "dictionary_enabled" => Ok(GeneratedField::DictionaryEnabled), "statisticsEnabled" | "statistics_enabled" => Ok(GeneratedField::StatisticsEnabled), - "maxStatisticsSize" | "max_statistics_size" => Ok(GeneratedField::MaxStatisticsSize), "columnIndexTruncateLength" | "column_index_truncate_length" => Ok(GeneratedField::ColumnIndexTruncateLength), "statisticsTruncateLength" | "statistics_truncate_length" => Ok(GeneratedField::StatisticsTruncateLength), "encoding" => Ok(GeneratedField::Encoding), "bloomFilterFpp" | "bloom_filter_fpp" => Ok(GeneratedField::BloomFilterFpp), "bloomFilterNdv" | "bloom_filter_ndv" => Ok(GeneratedField::BloomFilterNdv), + "coerceInt96" | "coerce_int96" => Ok(GeneratedField::CoerceInt96), + "maxPredicateCacheSize" | "max_predicate_cache_size" => Ok(GeneratedField::MaxPredicateCacheSize), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5331,12 +5996,13 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { let mut compression_opt__ = None; let mut dictionary_enabled_opt__ = None; let mut statistics_enabled_opt__ = None; - let mut max_statistics_size_opt__ = None; let mut column_index_truncate_length_opt__ = None; let mut statistics_truncate_length_opt__ = None; let mut encoding_opt__ = None; let mut bloom_filter_fpp_opt__ = None; let mut bloom_filter_ndv_opt__ = None; + let mut coerce_int96_opt__ = None; + let mut max_predicate_cache_size_opt__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::EnablePageIndex => { @@ -5497,12 +6163,6 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { } statistics_enabled_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_options::StatisticsEnabledOpt::StatisticsEnabled); } - GeneratedField::MaxStatisticsSize => { - if max_statistics_size_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("maxStatisticsSize")); - } - max_statistics_size_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(x.0)); - } GeneratedField::ColumnIndexTruncateLength => { if column_index_truncate_length_opt__.is_some() { return Err(serde::de::Error::duplicate_field("columnIndexTruncateLength")); @@ -5533,6 +6193,18 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { } bloom_filter_ndv_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_options::BloomFilterNdvOpt::BloomFilterNdv(x.0)); } + GeneratedField::CoerceInt96 => { + if coerce_int96_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("coerceInt96")); + } + coerce_int96_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_options::CoerceInt96Opt::CoerceInt96); + } + GeneratedField::MaxPredicateCacheSize => { + if max_predicate_cache_size_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("maxPredicateCacheSize")); + } + max_predicate_cache_size_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_options::MaxPredicateCacheSizeOpt::MaxPredicateCacheSize(x.0)); + } } } Ok(ParquetOptions { @@ -5560,12 +6232,13 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { compression_opt: compression_opt__, dictionary_enabled_opt: dictionary_enabled_opt__, statistics_enabled_opt: statistics_enabled_opt__, - max_statistics_size_opt: max_statistics_size_opt__, column_index_truncate_length_opt: column_index_truncate_length_opt__, statistics_truncate_length_opt: statistics_truncate_length_opt__, encoding_opt: encoding_opt__, bloom_filter_fpp_opt: bloom_filter_fpp_opt__, bloom_filter_ndv_opt: bloom_filter_ndv_opt__, + coerce_int96_opt: coerce_int96_opt__, + max_predicate_cache_size_opt: max_predicate_cache_size_opt__, }) } } @@ -6810,6 +7483,12 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::MapValue(v) => { struct_ser.serialize_field("mapValue", v)?; } + scalar_value::Value::Decimal32Value(v) => { + struct_ser.serialize_field("decimal32Value", v)?; + } + scalar_value::Value::Decimal64Value(v) => { + struct_ser.serialize_field("decimal64Value", v)?; + } scalar_value::Value::Decimal128Value(v) => { struct_ser.serialize_field("decimal128Value", v)?; } @@ -6936,6 +7615,10 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "structValue", "map_value", "mapValue", + "decimal32_value", + "decimal32Value", + "decimal64_value", + "decimal64Value", "decimal128_value", "decimal128Value", "decimal256_value", @@ -6998,6 +7681,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { FixedSizeListValue, StructValue, MapValue, + Decimal32Value, + Decimal64Value, Decimal128Value, Decimal256Value, Date64Value, @@ -7059,6 +7744,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "fixedSizeListValue" | "fixed_size_list_value" => Ok(GeneratedField::FixedSizeListValue), "structValue" | "struct_value" => Ok(GeneratedField::StructValue), "mapValue" | "map_value" => Ok(GeneratedField::MapValue), + "decimal32Value" | "decimal32_value" => Ok(GeneratedField::Decimal32Value), + "decimal64Value" | "decimal64_value" => Ok(GeneratedField::Decimal64Value), "decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value), "decimal256Value" | "decimal256_value" => Ok(GeneratedField::Decimal256Value), "date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value), @@ -7236,6 +7923,20 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("mapValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::MapValue) +; + } + GeneratedField::Decimal32Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("decimal32Value")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal32Value) +; + } + GeneratedField::Decimal64Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("decimal64Value")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal64Value) ; } GeneratedField::Decimal128Value => { diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index b6e9bc1379832..aa7c3d51a9d6d 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -1,10 +1,10 @@ // This file is @generated by prost-build. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ColumnRelation { #[prost(string, tag = "1")] pub relation: ::prost::alloc::string::String, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Column { #[prost(string, tag = "1")] pub name: ::prost::alloc::string::String, @@ -28,7 +28,7 @@ pub struct DfSchema { ::prost::alloc::string::String, >, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CsvFormat { #[prost(message, optional, tag = "5")] pub options: ::core::option::Option, @@ -38,31 +38,33 @@ pub struct ParquetFormat { #[prost(message, optional, tag = "2")] pub options: ::core::option::Option, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct AvroFormat {} -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct NdJsonFormat { #[prost(message, optional, tag = "1")] pub options: ::core::option::Option, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct ArrowFormat {} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct PrimaryKeyConstraint { #[prost(uint64, repeated, tag = "1")] pub indices: ::prost::alloc::vec::Vec, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct UniqueConstraint { #[prost(uint64, repeated, tag = "1")] pub indices: ::prost::alloc::vec::Vec, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Constraint { #[prost(oneof = "constraint::ConstraintMode", tags = "1, 2")] pub constraint_mode: ::core::option::Option, } /// Nested message and enum types in `Constraint`. pub mod constraint { - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum ConstraintMode { #[prost(message, tag = "1")] PrimaryKey(super::PrimaryKeyConstraint), @@ -75,9 +77,9 @@ pub struct Constraints { #[prost(message, repeated, tag = "1")] pub constraints: ::prost::alloc::vec::Vec, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct AvroOptions {} -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ArrowOptions {} #[derive(Clone, PartialEq, ::prost::Message)] pub struct Schema { @@ -106,24 +108,36 @@ pub struct Field { ::prost::alloc::string::String, ::prost::alloc::string::String, >, - #[prost(bool, tag = "6")] - pub dict_ordered: bool, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Timestamp { #[prost(enumeration = "TimeUnit", tag = "1")] pub time_unit: i32, #[prost(string, tag = "2")] pub timezone: ::prost::alloc::string::String, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] -pub struct Decimal { +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct Decimal32Type { + #[prost(uint32, tag = "3")] + pub precision: u32, + #[prost(int32, tag = "4")] + pub scale: i32, +} +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct Decimal64Type { #[prost(uint32, tag = "3")] pub precision: u32, #[prost(int32, tag = "4")] pub scale: i32, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct Decimal128Type { + #[prost(uint32, tag = "3")] + pub precision: u32, + #[prost(int32, tag = "4")] + pub scale: i32, +} +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct Decimal256Type { #[prost(uint32, tag = "3")] pub precision: u32, @@ -184,7 +198,7 @@ pub struct ScalarNestedValue { } /// Nested message and enum types in `ScalarNestedValue`. pub mod scalar_nested_value { - #[derive(Clone, PartialEq, ::prost::Message)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Dictionary { #[prost(bytes = "vec", tag = "1")] pub ipc_message: ::prost::alloc::vec::Vec, @@ -192,14 +206,14 @@ pub mod scalar_nested_value { pub arrow_data: ::prost::alloc::vec::Vec, } } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ScalarTime32Value { #[prost(oneof = "scalar_time32_value::Value", tags = "1, 2")] pub value: ::core::option::Option, } /// Nested message and enum types in `ScalarTime32Value`. pub mod scalar_time32_value { - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum Value { #[prost(int32, tag = "1")] Time32SecondValue(i32), @@ -207,14 +221,14 @@ pub mod scalar_time32_value { Time32MillisecondValue(i32), } } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ScalarTime64Value { #[prost(oneof = "scalar_time64_value::Value", tags = "1, 2")] pub value: ::core::option::Option, } /// Nested message and enum types in `ScalarTime64Value`. pub mod scalar_time64_value { - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum Value { #[prost(int64, tag = "1")] Time64MicrosecondValue(i64), @@ -222,7 +236,7 @@ pub mod scalar_time64_value { Time64NanosecondValue(i64), } } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ScalarTimestampValue { #[prost(string, tag = "5")] pub timezone: ::prost::alloc::string::String, @@ -231,7 +245,7 @@ pub struct ScalarTimestampValue { } /// Nested message and enum types in `ScalarTimestampValue`. pub mod scalar_timestamp_value { - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum Value { #[prost(int64, tag = "1")] TimeMicrosecondValue(i64), @@ -250,14 +264,14 @@ pub struct ScalarDictionaryValue { #[prost(message, optional, boxed, tag = "2")] pub value: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct IntervalDayTimeValue { #[prost(int32, tag = "1")] pub days: i32, #[prost(int32, tag = "2")] pub milliseconds: i32, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct IntervalMonthDayNanoValue { #[prost(int32, tag = "1")] pub months: i32, @@ -286,7 +300,7 @@ pub struct UnionValue { #[prost(enumeration = "UnionMode", tag = "4")] pub mode: i32, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ScalarFixedSizeBinary { #[prost(bytes = "vec", tag = "1")] pub values: ::prost::alloc::vec::Vec, @@ -297,7 +311,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 43, 44, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" )] pub value: ::core::option::Option, } @@ -352,6 +366,10 @@ pub mod scalar_value { StructValue(super::ScalarNestedValue), #[prost(message, tag = "41")] MapValue(super::ScalarNestedValue), + #[prost(message, tag = "43")] + Decimal32Value(super::Decimal32), + #[prost(message, tag = "44")] + Decimal64Value(super::Decimal64), #[prost(message, tag = "20")] Decimal128Value(super::Decimal128), #[prost(message, tag = "39")] @@ -390,7 +408,25 @@ pub mod scalar_value { UnionValue(::prost::alloc::boxed::Box), } } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct Decimal32 { + #[prost(bytes = "vec", tag = "1")] + pub value: ::prost::alloc::vec::Vec, + #[prost(int64, tag = "2")] + pub p: i64, + #[prost(int64, tag = "3")] + pub s: i64, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct Decimal64 { + #[prost(bytes = "vec", tag = "1")] + pub value: ::prost::alloc::vec::Vec, + #[prost(int64, tag = "2")] + pub p: i64, + #[prost(int64, tag = "3")] + pub s: i64, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Decimal128 { #[prost(bytes = "vec", tag = "1")] pub value: ::prost::alloc::vec::Vec, @@ -399,7 +435,7 @@ pub struct Decimal128 { #[prost(int64, tag = "3")] pub s: i64, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Decimal256 { #[prost(bytes = "vec", tag = "1")] pub value: ::prost::alloc::vec::Vec, @@ -413,7 +449,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 40, 41, 24, 36, 25, 26, 27, 28, 29, 30, 33" )] pub arrow_type_enum: ::core::option::Option, } @@ -480,8 +516,12 @@ pub mod arrow_type { Time64(i32), #[prost(enumeration = "super::IntervalUnit", tag = "23")] Interval(i32), + #[prost(message, tag = "40")] + Decimal32(super::Decimal32Type), + #[prost(message, tag = "41")] + Decimal64(super::Decimal64Type), #[prost(message, tag = "24")] - Decimal(super::Decimal), + Decimal128(super::Decimal128Type), #[prost(message, tag = "36")] Decimal256(super::Decimal256Type), #[prost(message, tag = "25")] @@ -509,14 +549,14 @@ pub mod arrow_type { /// i32 Two = 2; /// } /// } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct EmptyMessage {} -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct JsonWriterOptions { #[prost(enumeration = "CompressionTypeVariant", tag = "1")] pub compression: i32, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CsvWriterOptions { /// Compression type #[prost(enumeration = "CompressionTypeVariant", tag = "1")] @@ -553,7 +593,7 @@ pub struct CsvWriterOptions { pub double_quote: bool, } /// Options controlling CSV format -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CsvOptions { /// Indicates if the CSV has a header row #[prost(bytes = "vec", tag = "1")] @@ -606,9 +646,12 @@ pub struct CsvOptions { /// Optional terminator character as a byte #[prost(bytes = "vec", tag = "17")] pub terminator: ::prost::alloc::vec::Vec, + /// Indicates if truncated rows are allowed + #[prost(bytes = "vec", tag = "18")] + pub truncated_rows: ::prost::alloc::vec::Vec, } /// Options controlling CSV format -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct JsonOptions { /// Compression type #[prost(enumeration = "CompressionTypeVariant", tag = "1")] @@ -662,34 +705,30 @@ pub struct ParquetColumnOptions { pub bloom_filter_ndv_opt: ::core::option::Option< parquet_column_options::BloomFilterNdvOpt, >, - #[prost(oneof = "parquet_column_options::MaxStatisticsSizeOpt", tags = "8")] - pub max_statistics_size_opt: ::core::option::Option< - parquet_column_options::MaxStatisticsSizeOpt, - >, } /// Nested message and enum types in `ParquetColumnOptions`. pub mod parquet_column_options { - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum BloomFilterEnabledOpt { #[prost(bool, tag = "1")] BloomFilterEnabled(bool), } - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum EncodingOpt { #[prost(string, tag = "2")] Encoding(::prost::alloc::string::String), } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum DictionaryEnabledOpt { #[prost(bool, tag = "3")] DictionaryEnabled(bool), } - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum CompressionOpt { #[prost(string, tag = "4")] Compression(::prost::alloc::string::String), } - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum StatisticsEnabledOpt { #[prost(string, tag = "5")] StatisticsEnabled(::prost::alloc::string::String), @@ -699,16 +738,11 @@ pub mod parquet_column_options { #[prost(double, tag = "6")] BloomFilterFpp(f64), } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum BloomFilterNdvOpt { #[prost(uint64, tag = "7")] BloomFilterNdv(u64), } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] - pub enum MaxStatisticsSizeOpt { - #[prost(uint32, tag = "8")] - MaxStatisticsSize(u32), - } } #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetOptions { @@ -786,10 +820,6 @@ pub struct ParquetOptions { pub statistics_enabled_opt: ::core::option::Option< parquet_options::StatisticsEnabledOpt, >, - #[prost(oneof = "parquet_options::MaxStatisticsSizeOpt", tags = "14")] - pub max_statistics_size_opt: ::core::option::Option< - parquet_options::MaxStatisticsSizeOpt, - >, #[prost(oneof = "parquet_options::ColumnIndexTruncateLengthOpt", tags = "17")] pub column_index_truncate_length_opt: ::core::option::Option< parquet_options::ColumnIndexTruncateLengthOpt, @@ -804,45 +834,46 @@ pub struct ParquetOptions { pub bloom_filter_fpp_opt: ::core::option::Option, #[prost(oneof = "parquet_options::BloomFilterNdvOpt", tags = "22")] pub bloom_filter_ndv_opt: ::core::option::Option, + #[prost(oneof = "parquet_options::CoerceInt96Opt", tags = "32")] + pub coerce_int96_opt: ::core::option::Option, + #[prost(oneof = "parquet_options::MaxPredicateCacheSizeOpt", tags = "33")] + pub max_predicate_cache_size_opt: ::core::option::Option< + parquet_options::MaxPredicateCacheSizeOpt, + >, } /// Nested message and enum types in `ParquetOptions`. pub mod parquet_options { - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum MetadataSizeHintOpt { #[prost(uint64, tag = "4")] MetadataSizeHint(u64), } - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum CompressionOpt { #[prost(string, tag = "10")] Compression(::prost::alloc::string::String), } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum DictionaryEnabledOpt { #[prost(bool, tag = "11")] DictionaryEnabled(bool), } - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum StatisticsEnabledOpt { #[prost(string, tag = "13")] StatisticsEnabled(::prost::alloc::string::String), } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] - pub enum MaxStatisticsSizeOpt { - #[prost(uint64, tag = "14")] - MaxStatisticsSize(u64), - } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum ColumnIndexTruncateLengthOpt { #[prost(uint64, tag = "17")] ColumnIndexTruncateLength(u64), } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum StatisticsTruncateLengthOpt { #[prost(uint64, tag = "31")] StatisticsTruncateLength(u64), } - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum EncodingOpt { #[prost(string, tag = "19")] Encoding(::prost::alloc::string::String), @@ -852,11 +883,21 @@ pub mod parquet_options { #[prost(double, tag = "21")] BloomFilterFpp(f64), } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum BloomFilterNdvOpt { #[prost(uint64, tag = "22")] BloomFilterNdv(u64), } + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] + pub enum CoerceInt96Opt { + #[prost(string, tag = "32")] + CoerceInt96(::prost::alloc::string::String), + } + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] + pub enum MaxPredicateCacheSizeOpt { + #[prost(uint64, tag = "33")] + MaxPredicateCacheSize(u64), + } } #[derive(Clone, PartialEq, ::prost::Message)] pub struct Precision { @@ -899,6 +940,7 @@ pub enum JoinType { Rightsemi = 6, Rightanti = 7, Leftmark = 8, + Rightmark = 9, } impl JoinType { /// String value of the enum field names used in the ProtoBuf definition. @@ -916,6 +958,7 @@ impl JoinType { Self::Rightsemi => "RIGHTSEMI", Self::Rightanti => "RIGHTANTI", Self::Leftmark => "LEFTMARK", + Self::Rightmark => "RIGHTMARK", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -930,6 +973,7 @@ impl JoinType { "RIGHTSEMI" => Some(Self::Rightsemi), "RIGHTANTI" => Some(Self::Rightanti), "LEFTMARK" => Some(Self::Leftmark), + "RIGHTMARK" => Some(Self::Rightmark), _ => None, } } @@ -962,6 +1006,32 @@ impl JoinConstraint { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum NullEquality { + NullEqualsNothing = 0, + NullEqualsNull = 1, +} +impl NullEquality { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::NullEqualsNothing => "NULL_EQUALS_NOTHING", + Self::NullEqualsNull => "NULL_EQUALS_NULL", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "NULL_EQUALS_NOTHING" => Some(Self::NullEqualsNothing), + "NULL_EQUALS_NULL" => Some(Self::NullEqualsNull), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum TimeUnit { Second = 0, Millisecond = 1, diff --git a/datafusion/proto-common/src/lib.rs b/datafusion/proto-common/src/lib.rs index 6400e4bdc66de..9efb234e3994a 100644 --- a/datafusion/proto-common/src/lib.rs +++ b/datafusion/proto-common/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index decd0cf630388..8e4131479e506 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -65,7 +65,7 @@ impl std::fmt::Display for Error { write!(f, "{value:?} is invalid as a DataFusion scalar value") } Self::InvalidScalarType(data_type) => { - write!(f, "{data_type:?} is invalid as a DataFusion scalar type") + write!(f, "{data_type} is invalid as a DataFusion scalar type") } Self::InvalidTimeUnit(time_unit) => { write!( @@ -97,7 +97,6 @@ impl TryFrom<&Field> for protobuf::Field { nullable: field.is_nullable(), children: Vec::new(), metadata: field.metadata().clone(), - dict_ordered: field.dict_is_ordered().unwrap_or(false), }) } } @@ -190,7 +189,15 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { value: Some(Box::new(value_type.as_ref().try_into()?)), })) } - DataType::Decimal128(precision, scale) => Self::Decimal(protobuf::Decimal { + DataType::Decimal32(precision, scale) => Self::Decimal32(protobuf::Decimal32Type { + precision: *precision as u32, + scale: *scale as i32, + }), + DataType::Decimal64(precision, scale) => Self::Decimal64(protobuf::Decimal64Type { + precision: *precision as u32, + scale: *scale as i32, + }), + DataType::Decimal128(precision, scale) => Self::Decimal128(protobuf::Decimal128Type { precision: *precision as u32, scale: *scale as i32, }), @@ -398,6 +405,42 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }) }) } + ScalarValue::Decimal32(val, p, s) => match *val { + Some(v) => { + let array = v.to_be_bytes(); + let vec_val: Vec = array.to_vec(); + Ok(protobuf::ScalarValue { + value: Some(Value::Decimal32Value(protobuf::Decimal32 { + value: vec_val, + p: *p as i64, + s: *s as i64, + })), + }) + } + None => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::NullValue( + (&data_type).try_into()?, + )), + }), + }, + ScalarValue::Decimal64(val, p, s) => match *val { + Some(v) => { + let array = v.to_be_bytes(); + let vec_val: Vec = array.to_vec(); + Ok(protobuf::ScalarValue { + value: Some(Value::Decimal64Value(protobuf::Decimal64 { + value: vec_val, + p: *p as i64, + s: *s as i64, + })), + }) + } + None => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::NullValue( + (&data_type).try_into()?, + )), + }), + }, ScalarValue::Decimal128(val, p, s) => match *val { Some(v) => { let array = v.to_be_bytes(); @@ -818,8 +861,6 @@ impl TryFrom<&ParquetOptions> for protobuf::ParquetOptions { dictionary_enabled_opt: value.dictionary_enabled.map(protobuf::parquet_options::DictionaryEnabledOpt::DictionaryEnabled), dictionary_page_size_limit: value.dictionary_page_size_limit as u64, statistics_enabled_opt: value.statistics_enabled.clone().map(protobuf::parquet_options::StatisticsEnabledOpt::StatisticsEnabled), - #[allow(deprecated)] - max_statistics_size_opt: value.max_statistics_size.map(|v| protobuf::parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v as u64)), max_row_group_size: value.max_row_group_size as u64, created_by: value.created_by.clone(), column_index_truncate_length_opt: value.column_index_truncate_length.map(|v| protobuf::parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(v as u64)), @@ -836,6 +877,8 @@ impl TryFrom<&ParquetOptions> for protobuf::ParquetOptions { schema_force_view_types: value.schema_force_view_types, binary_as_string: value.binary_as_string, skip_arrow_metadata: value.skip_arrow_metadata, + coerce_int96_opt: value.coerce_int96.clone().map(protobuf::parquet_options::CoerceInt96Opt::CoerceInt96), + max_predicate_cache_size_opt: value.max_predicate_cache_size.map(|v| protobuf::parquet_options::MaxPredicateCacheSizeOpt::MaxPredicateCacheSize(v as u64)), }) } } @@ -858,12 +901,6 @@ impl TryFrom<&ParquetColumnOptions> for protobuf::ParquetColumnOptions { .statistics_enabled .clone() .map(protobuf::parquet_column_options::StatisticsEnabledOpt::StatisticsEnabled), - #[allow(deprecated)] - max_statistics_size_opt: value.max_statistics_size.map(|v| { - protobuf::parquet_column_options::MaxStatisticsSizeOpt::MaxStatisticsSize( - v as u32, - ) - }), encoding_opt: value .encoding .clone() @@ -934,6 +971,7 @@ impl TryFrom<&CsvOptions> for protobuf::CsvOptions { null_value: opts.null_value.clone().unwrap_or_default(), null_regex: opts.null_regex.clone().unwrap_or_default(), comment: opts.comment.map_or_else(Vec::new, |h| vec![h]), + truncated_rows: opts.truncated_rows.map_or_else(Vec::new, |h| vec![h as u8]), }) } } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 553fccf7d428e..c1d894a6c0629 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -46,8 +46,8 @@ avro = ["datafusion/avro", "datafusion-common/avro"] [dependencies] arrow = { workspace = true } chrono = { workspace = true } -datafusion = { workspace = true, default-features = true } -datafusion-common = { workspace = true, default-features = true } +datafusion = { workspace = true, default-features = false } +datafusion-common = { workspace = true } datafusion-expr = { workspace = true } datafusion-proto-common = { workspace = true } object_store = { workspace = true } @@ -57,9 +57,15 @@ serde = { version = "1.0", optional = true } serde_json = { workspace = true, optional = true } [dev-dependencies] +datafusion = { workspace = true, default-features = false, features = [ + "sql", + "datetime_expressions", + "nested_expressions", + "unicode_expressions", +] } datafusion-functions = { workspace = true, default-features = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-window-common = { workspace = true } doc-comment = { workspace = true } -strum = { version = "0.27.1", features = ["derive"] } +pretty_assertions = "1.4" tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/datafusion/proto/README.md b/datafusion/proto/README.md index f51e4664d5d98..c1382c5b8f8f8 100644 --- a/datafusion/proto/README.md +++ b/datafusion/proto/README.md @@ -17,13 +17,17 @@ under the License. --> -# `datafusion-proto`: Apache DataFusion Protobuf Serialization / Deserialization +# Apache DataFusion Protobuf Serialization / Deserialization -This crate contains code to convert Apache [DataFusion] plans to and from -bytes, which can be useful for sending plans over the network, for example -when building a distributed query engine. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. + +This crate contains code to convert DataFusion plans to and from bytes using [Protocol Buffers], +which can be useful for sending plans over the network, for example when building a distributed +query engine. See [API Docs] for details and examples. -[datafusion]: https://datafusion.apache.org +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[protocol buffers]: https://protobuf.dev/ [api docs]: http://docs.rs/datafusion-proto/latest diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index 467a7f487dae9..c2096b6011123 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -34,5 +34,5 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic -pbjson-build = "=0.7.0" -prost-build = "=0.13.5" +pbjson-build = "=0.8.0" +prost-build = "=0.14.1" diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 2e028eb291181..ee9ac0e7902d3 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -21,7 +21,7 @@ syntax = "proto3"; package datafusion; option java_multiple_files = true; -option java_package = "org.apache.arrow.datafusion.protobuf"; +option java_package = "org.apache.datafusion.protobuf"; option java_outer_classname = "DatafusionProto"; import "datafusion/proto-common/proto/datafusion_common.proto"; @@ -90,7 +90,7 @@ message ListingTableScanNode { ProjectionColumns projection = 4; datafusion_common.Schema schema = 5; repeated LogicalExprNode filters = 6; - repeated string table_partition_cols = 7; + repeated PartitionColumn table_partition_cols = 7; bool collect_stat = 8; uint32 target_partitions = 9; oneof FileFormatType { @@ -98,6 +98,7 @@ message ListingTableScanNode { datafusion_common.ParquetFormat parquet = 11; datafusion_common.AvroFormat avro = 12; datafusion_common.NdJsonFormat json = 15; + datafusion_common.ArrowFormat arrow = 16; } repeated SortExprNodeCollection file_sort_order = 13; } @@ -166,6 +167,7 @@ message CreateExternalTableNode { datafusion_common.DfSchema schema = 4; repeated string table_partition_cols = 5; bool if_not_exists = 6; + bool or_replace = 15; bool temporary = 14; string definition = 7; repeated SortExprNodeCollection order_exprs = 10; @@ -243,7 +245,7 @@ message JoinNode { datafusion_common.JoinConstraint join_constraint = 4; repeated LogicalExprNode left_join_key = 5; repeated LogicalExprNode right_join_key = 6; - bool null_equals_null = 7; + datafusion_common.NullEquality null_equality = 7; LogicalExprNode filter = 8; } @@ -726,6 +728,10 @@ message PhysicalPlanNode { ParquetSinkExecNode parquet_sink = 29; UnnestExecNode unnest = 30; JsonScanExecNode json_scan = 31; + CooperativeExecNode cooperative = 32; + GenerateSeriesNode generate_series = 33; + SortMergeJoinExecNode sort_merge_join = 34; + MemoryScanExecNode memory_scan = 35; } } @@ -858,6 +864,7 @@ message PhysicalScalarUdfNode { optional bytes fun_definition = 3; datafusion_common.ArrowType return_type = 4; bool nullable = 5; + string return_field_name = 6; } message PhysicalAggregateExprNode { @@ -869,6 +876,7 @@ message PhysicalAggregateExprNode { bool distinct = 3; bool ignore_nulls = 6; optional bytes fun_definition = 7; + string human_display = 8; } message PhysicalWindowExprNode { @@ -883,6 +891,8 @@ message PhysicalWindowExprNode { WindowFrame window_frame = 7; string name = 8; optional bytes fun_definition = 9; + bool ignore_nulls = 11; + bool distinct = 12; } message PhysicalIsNull { @@ -1023,6 +1033,7 @@ message CsvScanExecNode { string comment = 6; } bool newlines_in_values = 7; + bool truncate_rows = 8; } message JsonScanExecNode { @@ -1033,6 +1044,19 @@ message AvroScanExecNode { FileScanExecConf base_conf = 1; } +message MemoryScanExecNode { + repeated bytes partitions = 1; + datafusion_common.Schema schema = 2; + repeated uint32 projection = 3; + repeated PhysicalSortExprNodeCollection sort_information = 4; + bool show_sizes = 5; + optional uint32 fetch = 6; +} + +message CooperativeExecNode { + PhysicalPlanNode input = 1; +} + enum PartitionMode { COLLECT_LEFT = 0; PARTITIONED = 1; @@ -1045,7 +1069,7 @@ message HashJoinExecNode { repeated JoinOn on = 3; datafusion_common.JoinType join_type = 4; PartitionMode partition_mode = 6; - bool null_equals_null = 7; + datafusion_common.NullEquality null_equality = 7; JoinFilter filter = 8; repeated uint32 projection = 9; } @@ -1061,7 +1085,7 @@ message SymmetricHashJoinExecNode { repeated JoinOn on = 3; datafusion_common.JoinType join_type = 4; StreamPartitionMode partition_mode = 6; - bool null_equals_null = 7; + datafusion_common.NullEquality null_equality = 7; JoinFilter filter = 8; repeated PhysicalSortExprNode left_sort_exprs = 9; repeated PhysicalSortExprNode right_sort_exprs = 10; @@ -1217,6 +1241,7 @@ message CoalesceBatchesExecNode { message CoalescePartitionsExecNode { PhysicalPlanNode input = 1; + optional uint32 fetch = 2; } message PhysicalHashRepartition { @@ -1285,3 +1310,59 @@ message CteWorkTableScanNode { string name = 1; datafusion_common.Schema schema = 2; } + +enum GenerateSeriesName { + GS_GENERATE_SERIES = 0; + GS_RANGE = 1; +} + +message GenerateSeriesArgsContainsNull { + GenerateSeriesName name = 1; +} + +message GenerateSeriesArgsInt64 { + int64 start = 1; + int64 end = 2; + int64 step = 3; + bool include_end = 4; + GenerateSeriesName name = 5; +} + +message GenerateSeriesArgsTimestamp { + int64 start = 1; + int64 end = 2; + datafusion_common.IntervalMonthDayNanoValue step = 3; + optional string tz = 4; + bool include_end = 5; + GenerateSeriesName name = 6; +} + +message GenerateSeriesArgsDate { + int64 start = 1; + int64 end = 2; + datafusion_common.IntervalMonthDayNanoValue step = 3; + bool include_end = 4; + GenerateSeriesName name = 5; +} + +message GenerateSeriesNode { + datafusion_common.Schema schema = 1; + uint32 target_batch_size = 2; + + oneof args { + GenerateSeriesArgsContainsNull contains_null = 3; + GenerateSeriesArgsInt64 int64_args = 4; + GenerateSeriesArgsTimestamp timestamp_args = 5; + GenerateSeriesArgsDate date_args = 6; + } +} + +message SortMergeJoinExecNode { + PhysicalPlanNode left = 1; + PhysicalPlanNode right = 2; + repeated JoinOn on = 3; + datafusion_common.JoinType join_type = 4; + JoinFilter filter = 5; + repeated SortExprNode sort_options = 6; + datafusion_common.NullEquality null_equality = 7; +} \ No newline at end of file diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index da01d89c0c3d1..5b07e59e807f0 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -24,6 +24,7 @@ use crate::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; use crate::protobuf; +use datafusion::execution::TaskContext; use datafusion_common::{plan_datafusion_err, Result}; use datafusion_expr::{ create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LogicalPlan, Volatility, @@ -170,6 +171,14 @@ impl Serializeable for Expr { fn expr_planners(&self) -> Vec> { vec![] } + + fn udafs(&self) -> std::collections::HashSet { + std::collections::HashSet::default() + } + + fn udwfs(&self) -> std::collections::HashSet { + std::collections::HashSet::default() + } } Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?; @@ -308,13 +317,13 @@ pub fn physical_plan_from_json( let back: protobuf::PhysicalPlanNode = serde_json::from_str(json) .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; let extension_codec = DefaultPhysicalExtensionCodec {}; - back.try_into_physical_plan(ctx, &ctx.runtime_env(), &extension_codec) + back.try_into_physical_plan(&ctx.task_ctx(), &extension_codec) } /// Deserialize a PhysicalPlan from bytes pub fn physical_plan_from_bytes( bytes: &[u8], - ctx: &SessionContext, + ctx: &TaskContext, ) -> Result> { let extension_codec = DefaultPhysicalExtensionCodec {}; physical_plan_from_bytes_with_extension_codec(bytes, ctx, &extension_codec) @@ -323,10 +332,10 @@ pub fn physical_plan_from_bytes( /// Deserialize a PhysicalPlan from bytes pub fn physical_plan_from_bytes_with_extension_codec( bytes: &[u8], - ctx: &SessionContext, + ctx: &TaskContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let protobuf = protobuf::PhysicalPlanNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; - protobuf.try_into_physical_plan(ctx, &ctx.runtime_env(), extension_codec) + protobuf.try_into_physical_plan(ctx, extension_codec) } diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs index eae2425f8ac19..5d46d41f793ed 100644 --- a/datafusion/proto/src/bytes/registry.rs +++ b/datafusion/proto/src/bytes/registry.rs @@ -59,4 +59,12 @@ impl FunctionRegistry for NoRegistry { fn expr_planners(&self) -> Vec> { vec![] } + + fn udafs(&self) -> HashSet { + HashSet::new() + } + + fn udwfs(&self) -> HashSet { + HashSet::new() + } } diff --git a/datafusion/proto/src/common.rs b/datafusion/proto/src/common.rs index 2b052a31b8b76..2aa12dd3504b6 100644 --- a/datafusion/proto/src/common.rs +++ b/datafusion/proto/src/common.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_common::{internal_datafusion_err, internal_err, Result}; pub(crate) fn str_to_byte(s: &String, description: &str) -> Result { if s.len() != 1 { @@ -29,9 +29,9 @@ pub(crate) fn str_to_byte(s: &String, description: &str) -> Result { pub(crate) fn byte_to_string(b: u8, description: &str) -> Result { let b = &[b]; let b = std::str::from_utf8(b).map_err(|_| { - DataFusionError::Internal(format!( + internal_datafusion_err!( "Invalid CSV {description}: can not represent {b:0x?} as utf8" - )) + ) })?; Ok(b.to_owned()) } diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index b6e9bc1379832..aa7c3d51a9d6d 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -1,10 +1,10 @@ // This file is @generated by prost-build. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ColumnRelation { #[prost(string, tag = "1")] pub relation: ::prost::alloc::string::String, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Column { #[prost(string, tag = "1")] pub name: ::prost::alloc::string::String, @@ -28,7 +28,7 @@ pub struct DfSchema { ::prost::alloc::string::String, >, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CsvFormat { #[prost(message, optional, tag = "5")] pub options: ::core::option::Option, @@ -38,31 +38,33 @@ pub struct ParquetFormat { #[prost(message, optional, tag = "2")] pub options: ::core::option::Option, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct AvroFormat {} -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct NdJsonFormat { #[prost(message, optional, tag = "1")] pub options: ::core::option::Option, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct ArrowFormat {} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct PrimaryKeyConstraint { #[prost(uint64, repeated, tag = "1")] pub indices: ::prost::alloc::vec::Vec, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct UniqueConstraint { #[prost(uint64, repeated, tag = "1")] pub indices: ::prost::alloc::vec::Vec, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Constraint { #[prost(oneof = "constraint::ConstraintMode", tags = "1, 2")] pub constraint_mode: ::core::option::Option, } /// Nested message and enum types in `Constraint`. pub mod constraint { - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum ConstraintMode { #[prost(message, tag = "1")] PrimaryKey(super::PrimaryKeyConstraint), @@ -75,9 +77,9 @@ pub struct Constraints { #[prost(message, repeated, tag = "1")] pub constraints: ::prost::alloc::vec::Vec, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct AvroOptions {} -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ArrowOptions {} #[derive(Clone, PartialEq, ::prost::Message)] pub struct Schema { @@ -106,24 +108,36 @@ pub struct Field { ::prost::alloc::string::String, ::prost::alloc::string::String, >, - #[prost(bool, tag = "6")] - pub dict_ordered: bool, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Timestamp { #[prost(enumeration = "TimeUnit", tag = "1")] pub time_unit: i32, #[prost(string, tag = "2")] pub timezone: ::prost::alloc::string::String, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] -pub struct Decimal { +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct Decimal32Type { + #[prost(uint32, tag = "3")] + pub precision: u32, + #[prost(int32, tag = "4")] + pub scale: i32, +} +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct Decimal64Type { #[prost(uint32, tag = "3")] pub precision: u32, #[prost(int32, tag = "4")] pub scale: i32, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct Decimal128Type { + #[prost(uint32, tag = "3")] + pub precision: u32, + #[prost(int32, tag = "4")] + pub scale: i32, +} +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct Decimal256Type { #[prost(uint32, tag = "3")] pub precision: u32, @@ -184,7 +198,7 @@ pub struct ScalarNestedValue { } /// Nested message and enum types in `ScalarNestedValue`. pub mod scalar_nested_value { - #[derive(Clone, PartialEq, ::prost::Message)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Dictionary { #[prost(bytes = "vec", tag = "1")] pub ipc_message: ::prost::alloc::vec::Vec, @@ -192,14 +206,14 @@ pub mod scalar_nested_value { pub arrow_data: ::prost::alloc::vec::Vec, } } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ScalarTime32Value { #[prost(oneof = "scalar_time32_value::Value", tags = "1, 2")] pub value: ::core::option::Option, } /// Nested message and enum types in `ScalarTime32Value`. pub mod scalar_time32_value { - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum Value { #[prost(int32, tag = "1")] Time32SecondValue(i32), @@ -207,14 +221,14 @@ pub mod scalar_time32_value { Time32MillisecondValue(i32), } } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ScalarTime64Value { #[prost(oneof = "scalar_time64_value::Value", tags = "1, 2")] pub value: ::core::option::Option, } /// Nested message and enum types in `ScalarTime64Value`. pub mod scalar_time64_value { - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum Value { #[prost(int64, tag = "1")] Time64MicrosecondValue(i64), @@ -222,7 +236,7 @@ pub mod scalar_time64_value { Time64NanosecondValue(i64), } } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ScalarTimestampValue { #[prost(string, tag = "5")] pub timezone: ::prost::alloc::string::String, @@ -231,7 +245,7 @@ pub struct ScalarTimestampValue { } /// Nested message and enum types in `ScalarTimestampValue`. pub mod scalar_timestamp_value { - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum Value { #[prost(int64, tag = "1")] TimeMicrosecondValue(i64), @@ -250,14 +264,14 @@ pub struct ScalarDictionaryValue { #[prost(message, optional, boxed, tag = "2")] pub value: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct IntervalDayTimeValue { #[prost(int32, tag = "1")] pub days: i32, #[prost(int32, tag = "2")] pub milliseconds: i32, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct IntervalMonthDayNanoValue { #[prost(int32, tag = "1")] pub months: i32, @@ -286,7 +300,7 @@ pub struct UnionValue { #[prost(enumeration = "UnionMode", tag = "4")] pub mode: i32, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ScalarFixedSizeBinary { #[prost(bytes = "vec", tag = "1")] pub values: ::prost::alloc::vec::Vec, @@ -297,7 +311,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 43, 44, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" )] pub value: ::core::option::Option, } @@ -352,6 +366,10 @@ pub mod scalar_value { StructValue(super::ScalarNestedValue), #[prost(message, tag = "41")] MapValue(super::ScalarNestedValue), + #[prost(message, tag = "43")] + Decimal32Value(super::Decimal32), + #[prost(message, tag = "44")] + Decimal64Value(super::Decimal64), #[prost(message, tag = "20")] Decimal128Value(super::Decimal128), #[prost(message, tag = "39")] @@ -390,7 +408,25 @@ pub mod scalar_value { UnionValue(::prost::alloc::boxed::Box), } } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct Decimal32 { + #[prost(bytes = "vec", tag = "1")] + pub value: ::prost::alloc::vec::Vec, + #[prost(int64, tag = "2")] + pub p: i64, + #[prost(int64, tag = "3")] + pub s: i64, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct Decimal64 { + #[prost(bytes = "vec", tag = "1")] + pub value: ::prost::alloc::vec::Vec, + #[prost(int64, tag = "2")] + pub p: i64, + #[prost(int64, tag = "3")] + pub s: i64, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Decimal128 { #[prost(bytes = "vec", tag = "1")] pub value: ::prost::alloc::vec::Vec, @@ -399,7 +435,7 @@ pub struct Decimal128 { #[prost(int64, tag = "3")] pub s: i64, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Decimal256 { #[prost(bytes = "vec", tag = "1")] pub value: ::prost::alloc::vec::Vec, @@ -413,7 +449,7 @@ pub struct Decimal256 { pub struct ArrowType { #[prost( oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 40, 41, 24, 36, 25, 26, 27, 28, 29, 30, 33" )] pub arrow_type_enum: ::core::option::Option, } @@ -480,8 +516,12 @@ pub mod arrow_type { Time64(i32), #[prost(enumeration = "super::IntervalUnit", tag = "23")] Interval(i32), + #[prost(message, tag = "40")] + Decimal32(super::Decimal32Type), + #[prost(message, tag = "41")] + Decimal64(super::Decimal64Type), #[prost(message, tag = "24")] - Decimal(super::Decimal), + Decimal128(super::Decimal128Type), #[prost(message, tag = "36")] Decimal256(super::Decimal256Type), #[prost(message, tag = "25")] @@ -509,14 +549,14 @@ pub mod arrow_type { /// i32 Two = 2; /// } /// } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct EmptyMessage {} -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct JsonWriterOptions { #[prost(enumeration = "CompressionTypeVariant", tag = "1")] pub compression: i32, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CsvWriterOptions { /// Compression type #[prost(enumeration = "CompressionTypeVariant", tag = "1")] @@ -553,7 +593,7 @@ pub struct CsvWriterOptions { pub double_quote: bool, } /// Options controlling CSV format -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CsvOptions { /// Indicates if the CSV has a header row #[prost(bytes = "vec", tag = "1")] @@ -606,9 +646,12 @@ pub struct CsvOptions { /// Optional terminator character as a byte #[prost(bytes = "vec", tag = "17")] pub terminator: ::prost::alloc::vec::Vec, + /// Indicates if truncated rows are allowed + #[prost(bytes = "vec", tag = "18")] + pub truncated_rows: ::prost::alloc::vec::Vec, } /// Options controlling CSV format -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct JsonOptions { /// Compression type #[prost(enumeration = "CompressionTypeVariant", tag = "1")] @@ -662,34 +705,30 @@ pub struct ParquetColumnOptions { pub bloom_filter_ndv_opt: ::core::option::Option< parquet_column_options::BloomFilterNdvOpt, >, - #[prost(oneof = "parquet_column_options::MaxStatisticsSizeOpt", tags = "8")] - pub max_statistics_size_opt: ::core::option::Option< - parquet_column_options::MaxStatisticsSizeOpt, - >, } /// Nested message and enum types in `ParquetColumnOptions`. pub mod parquet_column_options { - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum BloomFilterEnabledOpt { #[prost(bool, tag = "1")] BloomFilterEnabled(bool), } - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum EncodingOpt { #[prost(string, tag = "2")] Encoding(::prost::alloc::string::String), } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum DictionaryEnabledOpt { #[prost(bool, tag = "3")] DictionaryEnabled(bool), } - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum CompressionOpt { #[prost(string, tag = "4")] Compression(::prost::alloc::string::String), } - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum StatisticsEnabledOpt { #[prost(string, tag = "5")] StatisticsEnabled(::prost::alloc::string::String), @@ -699,16 +738,11 @@ pub mod parquet_column_options { #[prost(double, tag = "6")] BloomFilterFpp(f64), } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum BloomFilterNdvOpt { #[prost(uint64, tag = "7")] BloomFilterNdv(u64), } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] - pub enum MaxStatisticsSizeOpt { - #[prost(uint32, tag = "8")] - MaxStatisticsSize(u32), - } } #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetOptions { @@ -786,10 +820,6 @@ pub struct ParquetOptions { pub statistics_enabled_opt: ::core::option::Option< parquet_options::StatisticsEnabledOpt, >, - #[prost(oneof = "parquet_options::MaxStatisticsSizeOpt", tags = "14")] - pub max_statistics_size_opt: ::core::option::Option< - parquet_options::MaxStatisticsSizeOpt, - >, #[prost(oneof = "parquet_options::ColumnIndexTruncateLengthOpt", tags = "17")] pub column_index_truncate_length_opt: ::core::option::Option< parquet_options::ColumnIndexTruncateLengthOpt, @@ -804,45 +834,46 @@ pub struct ParquetOptions { pub bloom_filter_fpp_opt: ::core::option::Option, #[prost(oneof = "parquet_options::BloomFilterNdvOpt", tags = "22")] pub bloom_filter_ndv_opt: ::core::option::Option, + #[prost(oneof = "parquet_options::CoerceInt96Opt", tags = "32")] + pub coerce_int96_opt: ::core::option::Option, + #[prost(oneof = "parquet_options::MaxPredicateCacheSizeOpt", tags = "33")] + pub max_predicate_cache_size_opt: ::core::option::Option< + parquet_options::MaxPredicateCacheSizeOpt, + >, } /// Nested message and enum types in `ParquetOptions`. pub mod parquet_options { - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum MetadataSizeHintOpt { #[prost(uint64, tag = "4")] MetadataSizeHint(u64), } - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum CompressionOpt { #[prost(string, tag = "10")] Compression(::prost::alloc::string::String), } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum DictionaryEnabledOpt { #[prost(bool, tag = "11")] DictionaryEnabled(bool), } - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum StatisticsEnabledOpt { #[prost(string, tag = "13")] StatisticsEnabled(::prost::alloc::string::String), } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] - pub enum MaxStatisticsSizeOpt { - #[prost(uint64, tag = "14")] - MaxStatisticsSize(u64), - } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum ColumnIndexTruncateLengthOpt { #[prost(uint64, tag = "17")] ColumnIndexTruncateLength(u64), } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum StatisticsTruncateLengthOpt { #[prost(uint64, tag = "31")] StatisticsTruncateLength(u64), } - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum EncodingOpt { #[prost(string, tag = "19")] Encoding(::prost::alloc::string::String), @@ -852,11 +883,21 @@ pub mod parquet_options { #[prost(double, tag = "21")] BloomFilterFpp(f64), } - #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum BloomFilterNdvOpt { #[prost(uint64, tag = "22")] BloomFilterNdv(u64), } + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] + pub enum CoerceInt96Opt { + #[prost(string, tag = "32")] + CoerceInt96(::prost::alloc::string::String), + } + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Oneof)] + pub enum MaxPredicateCacheSizeOpt { + #[prost(uint64, tag = "33")] + MaxPredicateCacheSize(u64), + } } #[derive(Clone, PartialEq, ::prost::Message)] pub struct Precision { @@ -899,6 +940,7 @@ pub enum JoinType { Rightsemi = 6, Rightanti = 7, Leftmark = 8, + Rightmark = 9, } impl JoinType { /// String value of the enum field names used in the ProtoBuf definition. @@ -916,6 +958,7 @@ impl JoinType { Self::Rightsemi => "RIGHTSEMI", Self::Rightanti => "RIGHTANTI", Self::Leftmark => "LEFTMARK", + Self::Rightmark => "RIGHTMARK", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -930,6 +973,7 @@ impl JoinType { "RIGHTSEMI" => Some(Self::Rightsemi), "RIGHTANTI" => Some(Self::Rightanti), "LEFTMARK" => Some(Self::Leftmark), + "RIGHTMARK" => Some(Self::Rightmark), _ => None, } } @@ -962,6 +1006,32 @@ impl JoinConstraint { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum NullEquality { + NullEqualsNothing = 0, + NullEqualsNull = 1, +} +impl NullEquality { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::NullEqualsNothing => "NULL_EQUALS_NOTHING", + Self::NullEqualsNull => "NULL_EQUALS_NULL", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "NULL_EQUALS_NOTHING" => Some(Self::NullEqualsNothing), + "NULL_EQUALS_NULL" => Some(Self::NullEqualsNull), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum TimeUnit { Second = 0, Millisecond = 1, diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 6166b6ec47961..29967d812000f 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -2050,10 +2050,16 @@ impl serde::Serialize for CoalescePartitionsExecNode { if self.input.is_some() { len += 1; } + if self.fetch.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CoalescePartitionsExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; } + if let Some(v) = self.fetch.as_ref() { + struct_ser.serialize_field("fetch", v)?; + } struct_ser.end() } } @@ -2065,11 +2071,13 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { { const FIELDS: &[&str] = &[ "input", + "fetch", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, + Fetch, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2092,6 +2100,7 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { { match value { "input" => Ok(GeneratedField::Input), + "fetch" => Ok(GeneratedField::Fetch), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2112,6 +2121,7 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { V: serde::de::MapAccess<'de>, { let mut input__ = None; + let mut fetch__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { @@ -2120,10 +2130,19 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { } input__ = map_.next_value()?; } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) + ; + } } } Ok(CoalescePartitionsExecNode { input: input__, + fetch: fetch__, }) } } @@ -2555,6 +2574,97 @@ impl<'de> serde::Deserialize<'de> for ColumnUnnestListRecursions { deserializer.deserialize_struct("datafusion.ColumnUnnestListRecursions", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CooperativeExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CooperativeExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CooperativeExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CooperativeExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CooperativeExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + } + } + Ok(CooperativeExecNode { + input: input__, + }) + } + } + deserializer.deserialize_struct("datafusion.CooperativeExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CopyToNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -2984,6 +3094,9 @@ impl serde::Serialize for CreateExternalTableNode { if self.if_not_exists { len += 1; } + if self.or_replace { + len += 1; + } if self.temporary { len += 1; } @@ -3024,6 +3137,9 @@ impl serde::Serialize for CreateExternalTableNode { if self.if_not_exists { struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; } + if self.or_replace { + struct_ser.serialize_field("orReplace", &self.or_replace)?; + } if self.temporary { struct_ser.serialize_field("temporary", &self.temporary)?; } @@ -3064,6 +3180,8 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "tablePartitionCols", "if_not_exists", "ifNotExists", + "or_replace", + "orReplace", "temporary", "definition", "order_exprs", @@ -3083,6 +3201,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { Schema, TablePartitionCols, IfNotExists, + OrReplace, Temporary, Definition, OrderExprs, @@ -3117,6 +3236,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "schema" => Ok(GeneratedField::Schema), "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), + "orReplace" | "or_replace" => Ok(GeneratedField::OrReplace), "temporary" => Ok(GeneratedField::Temporary), "definition" => Ok(GeneratedField::Definition), "orderExprs" | "order_exprs" => Ok(GeneratedField::OrderExprs), @@ -3149,6 +3269,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { let mut schema__ = None; let mut table_partition_cols__ = None; let mut if_not_exists__ = None; + let mut or_replace__ = None; let mut temporary__ = None; let mut definition__ = None; let mut order_exprs__ = None; @@ -3194,6 +3315,12 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { } if_not_exists__ = Some(map_.next_value()?); } + GeneratedField::OrReplace => { + if or_replace__.is_some() { + return Err(serde::de::Error::duplicate_field("orReplace")); + } + or_replace__ = Some(map_.next_value()?); + } GeneratedField::Temporary => { if temporary__.is_some() { return Err(serde::de::Error::duplicate_field("temporary")); @@ -3249,6 +3376,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { schema: schema__, table_partition_cols: table_partition_cols__.unwrap_or_default(), if_not_exists: if_not_exists__.unwrap_or_default(), + or_replace: or_replace__.unwrap_or_default(), temporary: temporary__.unwrap_or_default(), definition: definition__.unwrap_or_default(), order_exprs: order_exprs__.unwrap_or_default(), @@ -3661,6 +3789,9 @@ impl serde::Serialize for CsvScanExecNode { if self.newlines_in_values { len += 1; } + if self.truncate_rows { + len += 1; + } if self.optional_escape.is_some() { len += 1; } @@ -3683,6 +3814,9 @@ impl serde::Serialize for CsvScanExecNode { if self.newlines_in_values { struct_ser.serialize_field("newlinesInValues", &self.newlines_in_values)?; } + if self.truncate_rows { + struct_ser.serialize_field("truncateRows", &self.truncate_rows)?; + } if let Some(v) = self.optional_escape.as_ref() { match v { csv_scan_exec_node::OptionalEscape::Escape(v) => { @@ -3715,6 +3849,8 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { "quote", "newlines_in_values", "newlinesInValues", + "truncate_rows", + "truncateRows", "escape", "comment", ]; @@ -3726,6 +3862,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { Delimiter, Quote, NewlinesInValues, + TruncateRows, Escape, Comment, } @@ -3754,6 +3891,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { "delimiter" => Ok(GeneratedField::Delimiter), "quote" => Ok(GeneratedField::Quote), "newlinesInValues" | "newlines_in_values" => Ok(GeneratedField::NewlinesInValues), + "truncateRows" | "truncate_rows" => Ok(GeneratedField::TruncateRows), "escape" => Ok(GeneratedField::Escape), "comment" => Ok(GeneratedField::Comment), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -3780,6 +3918,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { let mut delimiter__ = None; let mut quote__ = None; let mut newlines_in_values__ = None; + let mut truncate_rows__ = None; let mut optional_escape__ = None; let mut optional_comment__ = None; while let Some(k) = map_.next_key()? { @@ -3814,6 +3953,12 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { } newlines_in_values__ = Some(map_.next_value()?); } + GeneratedField::TruncateRows => { + if truncate_rows__.is_some() { + return Err(serde::de::Error::duplicate_field("truncateRows")); + } + truncate_rows__ = Some(map_.next_value()?); + } GeneratedField::Escape => { if optional_escape__.is_some() { return Err(serde::de::Error::duplicate_field("escape")); @@ -3834,6 +3979,7 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { delimiter: delimiter__.unwrap_or_default(), quote: quote__.unwrap_or_default(), newlines_in_values: newlines_in_values__.unwrap_or_default(), + truncate_rows: truncate_rows__.unwrap_or_default(), optional_escape: optional_escape__, optional_comment: optional_comment__, }) @@ -6428,41 +6574,861 @@ impl<'de> serde::Deserialize<'de> for FixedSizeBinary { } deserializer.deserialize_identifier(GeneratedVisitor) } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FixedSizeBinary; + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = FixedSizeBinary; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.FixedSizeBinary") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut length__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Length => { + if length__.is_some() { + return Err(serde::de::Error::duplicate_field("length")); + } + length__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(FixedSizeBinary { + length: length__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.FixedSizeBinary", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for FullTableReference { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.catalog.is_empty() { + len += 1; + } + if !self.schema.is_empty() { + len += 1; + } + if !self.table.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FullTableReference", len)?; + if !self.catalog.is_empty() { + struct_ser.serialize_field("catalog", &self.catalog)?; + } + if !self.schema.is_empty() { + struct_ser.serialize_field("schema", &self.schema)?; + } + if !self.table.is_empty() { + struct_ser.serialize_field("table", &self.table)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for FullTableReference { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "catalog", + "schema", + "table", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Catalog, + Schema, + Table, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "catalog" => Ok(GeneratedField::Catalog), + "schema" => Ok(GeneratedField::Schema), + "table" => Ok(GeneratedField::Table), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = FullTableReference; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.FullTableReference") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut catalog__ = None; + let mut schema__ = None; + let mut table__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Catalog => { + if catalog__.is_some() { + return Err(serde::de::Error::duplicate_field("catalog")); + } + catalog__ = Some(map_.next_value()?); + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = Some(map_.next_value()?); + } + GeneratedField::Table => { + if table__.is_some() { + return Err(serde::de::Error::duplicate_field("table")); + } + table__ = Some(map_.next_value()?); + } + } + } + Ok(FullTableReference { + catalog: catalog__.unwrap_or_default(), + schema: schema__.unwrap_or_default(), + table: table__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.FullTableReference", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for GenerateSeriesArgsContainsNull { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.name != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.GenerateSeriesArgsContainsNull", len)?; + if self.name != 0 { + let v = GenerateSeriesName::try_from(self.name) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.name)))?; + struct_ser.serialize_field("name", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for GenerateSeriesArgsContainsNull { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "name", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Name, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "name" => Ok(GeneratedField::Name), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GenerateSeriesArgsContainsNull; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.GenerateSeriesArgsContainsNull") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut name__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map_.next_value::()? as i32); + } + } + } + Ok(GenerateSeriesArgsContainsNull { + name: name__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.GenerateSeriesArgsContainsNull", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for GenerateSeriesArgsDate { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.start != 0 { + len += 1; + } + if self.end != 0 { + len += 1; + } + if self.step.is_some() { + len += 1; + } + if self.include_end { + len += 1; + } + if self.name != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.GenerateSeriesArgsDate", len)?; + if self.start != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("start", ToString::to_string(&self.start).as_str())?; + } + if self.end != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("end", ToString::to_string(&self.end).as_str())?; + } + if let Some(v) = self.step.as_ref() { + struct_ser.serialize_field("step", v)?; + } + if self.include_end { + struct_ser.serialize_field("includeEnd", &self.include_end)?; + } + if self.name != 0 { + let v = GenerateSeriesName::try_from(self.name) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.name)))?; + struct_ser.serialize_field("name", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for GenerateSeriesArgsDate { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "start", + "end", + "step", + "include_end", + "includeEnd", + "name", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Start, + End, + Step, + IncludeEnd, + Name, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "start" => Ok(GeneratedField::Start), + "end" => Ok(GeneratedField::End), + "step" => Ok(GeneratedField::Step), + "includeEnd" | "include_end" => Ok(GeneratedField::IncludeEnd), + "name" => Ok(GeneratedField::Name), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GenerateSeriesArgsDate; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.GenerateSeriesArgsDate") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut start__ = None; + let mut end__ = None; + let mut step__ = None; + let mut include_end__ = None; + let mut name__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Start => { + if start__.is_some() { + return Err(serde::de::Error::duplicate_field("start")); + } + start__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::End => { + if end__.is_some() { + return Err(serde::de::Error::duplicate_field("end")); + } + end__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Step => { + if step__.is_some() { + return Err(serde::de::Error::duplicate_field("step")); + } + step__ = map_.next_value()?; + } + GeneratedField::IncludeEnd => { + if include_end__.is_some() { + return Err(serde::de::Error::duplicate_field("includeEnd")); + } + include_end__ = Some(map_.next_value()?); + } + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map_.next_value::()? as i32); + } + } + } + Ok(GenerateSeriesArgsDate { + start: start__.unwrap_or_default(), + end: end__.unwrap_or_default(), + step: step__, + include_end: include_end__.unwrap_or_default(), + name: name__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.GenerateSeriesArgsDate", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for GenerateSeriesArgsInt64 { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.start != 0 { + len += 1; + } + if self.end != 0 { + len += 1; + } + if self.step != 0 { + len += 1; + } + if self.include_end { + len += 1; + } + if self.name != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.GenerateSeriesArgsInt64", len)?; + if self.start != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("start", ToString::to_string(&self.start).as_str())?; + } + if self.end != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("end", ToString::to_string(&self.end).as_str())?; + } + if self.step != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("step", ToString::to_string(&self.step).as_str())?; + } + if self.include_end { + struct_ser.serialize_field("includeEnd", &self.include_end)?; + } + if self.name != 0 { + let v = GenerateSeriesName::try_from(self.name) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.name)))?; + struct_ser.serialize_field("name", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for GenerateSeriesArgsInt64 { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "start", + "end", + "step", + "include_end", + "includeEnd", + "name", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Start, + End, + Step, + IncludeEnd, + Name, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "start" => Ok(GeneratedField::Start), + "end" => Ok(GeneratedField::End), + "step" => Ok(GeneratedField::Step), + "includeEnd" | "include_end" => Ok(GeneratedField::IncludeEnd), + "name" => Ok(GeneratedField::Name), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GenerateSeriesArgsInt64; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.GenerateSeriesArgsInt64") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut start__ = None; + let mut end__ = None; + let mut step__ = None; + let mut include_end__ = None; + let mut name__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Start => { + if start__.is_some() { + return Err(serde::de::Error::duplicate_field("start")); + } + start__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::End => { + if end__.is_some() { + return Err(serde::de::Error::duplicate_field("end")); + } + end__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Step => { + if step__.is_some() { + return Err(serde::de::Error::duplicate_field("step")); + } + step__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::IncludeEnd => { + if include_end__.is_some() { + return Err(serde::de::Error::duplicate_field("includeEnd")); + } + include_end__ = Some(map_.next_value()?); + } + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map_.next_value::()? as i32); + } + } + } + Ok(GenerateSeriesArgsInt64 { + start: start__.unwrap_or_default(), + end: end__.unwrap_or_default(), + step: step__.unwrap_or_default(), + include_end: include_end__.unwrap_or_default(), + name: name__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.GenerateSeriesArgsInt64", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for GenerateSeriesArgsTimestamp { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.start != 0 { + len += 1; + } + if self.end != 0 { + len += 1; + } + if self.step.is_some() { + len += 1; + } + if self.tz.is_some() { + len += 1; + } + if self.include_end { + len += 1; + } + if self.name != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.GenerateSeriesArgsTimestamp", len)?; + if self.start != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("start", ToString::to_string(&self.start).as_str())?; + } + if self.end != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("end", ToString::to_string(&self.end).as_str())?; + } + if let Some(v) = self.step.as_ref() { + struct_ser.serialize_field("step", v)?; + } + if let Some(v) = self.tz.as_ref() { + struct_ser.serialize_field("tz", v)?; + } + if self.include_end { + struct_ser.serialize_field("includeEnd", &self.include_end)?; + } + if self.name != 0 { + let v = GenerateSeriesName::try_from(self.name) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.name)))?; + struct_ser.serialize_field("name", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for GenerateSeriesArgsTimestamp { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "start", + "end", + "step", + "tz", + "include_end", + "includeEnd", + "name", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Start, + End, + Step, + Tz, + IncludeEnd, + Name, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "start" => Ok(GeneratedField::Start), + "end" => Ok(GeneratedField::End), + "step" => Ok(GeneratedField::Step), + "tz" => Ok(GeneratedField::Tz), + "includeEnd" | "include_end" => Ok(GeneratedField::IncludeEnd), + "name" => Ok(GeneratedField::Name), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GenerateSeriesArgsTimestamp; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.GenerateSeriesArgsTimestamp") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut start__ = None; + let mut end__ = None; + let mut step__ = None; + let mut tz__ = None; + let mut include_end__ = None; + let mut name__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Start => { + if start__.is_some() { + return Err(serde::de::Error::duplicate_field("start")); + } + start__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::End => { + if end__.is_some() { + return Err(serde::de::Error::duplicate_field("end")); + } + end__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Step => { + if step__.is_some() { + return Err(serde::de::Error::duplicate_field("step")); + } + step__ = map_.next_value()?; + } + GeneratedField::Tz => { + if tz__.is_some() { + return Err(serde::de::Error::duplicate_field("tz")); + } + tz__ = map_.next_value()?; + } + GeneratedField::IncludeEnd => { + if include_end__.is_some() { + return Err(serde::de::Error::duplicate_field("includeEnd")); + } + include_end__ = Some(map_.next_value()?); + } + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map_.next_value::()? as i32); + } + } + } + Ok(GenerateSeriesArgsTimestamp { + start: start__.unwrap_or_default(), + end: end__.unwrap_or_default(), + step: step__, + tz: tz__, + include_end: include_end__.unwrap_or_default(), + name: name__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.GenerateSeriesArgsTimestamp", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for GenerateSeriesName { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::GsGenerateSeries => "GS_GENERATE_SERIES", + Self::GsRange => "GS_RANGE", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for GenerateSeriesName { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "GS_GENERATE_SERIES", + "GS_RANGE", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GenerateSeriesName; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FixedSizeBinary") + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) } - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, { - let mut length__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Length => { - if length__.is_some() { - return Err(serde::de::Error::duplicate_field("length")); - } - length__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - } + match value { + "GS_GENERATE_SERIES" => Ok(GenerateSeriesName::GsGenerateSeries), + "GS_RANGE" => Ok(GenerateSeriesName::GsRange), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } - Ok(FixedSizeBinary { - length: length__.unwrap_or_default(), - }) } } - deserializer.deserialize_struct("datafusion.FixedSizeBinary", FIELDS, GeneratedVisitor) + deserializer.deserialize_any(GeneratedVisitor) } } -impl serde::Serialize for FullTableReference { +impl serde::Serialize for GenerateSeriesNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -6470,45 +7436,69 @@ impl serde::Serialize for FullTableReference { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.catalog.is_empty() { + if self.schema.is_some() { len += 1; } - if !self.schema.is_empty() { + if self.target_batch_size != 0 { len += 1; } - if !self.table.is_empty() { + if self.args.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FullTableReference", len)?; - if !self.catalog.is_empty() { - struct_ser.serialize_field("catalog", &self.catalog)?; + let mut struct_ser = serializer.serialize_struct("datafusion.GenerateSeriesNode", len)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } - if !self.schema.is_empty() { - struct_ser.serialize_field("schema", &self.schema)?; + if self.target_batch_size != 0 { + struct_ser.serialize_field("targetBatchSize", &self.target_batch_size)?; } - if !self.table.is_empty() { - struct_ser.serialize_field("table", &self.table)?; + if let Some(v) = self.args.as_ref() { + match v { + generate_series_node::Args::ContainsNull(v) => { + struct_ser.serialize_field("containsNull", v)?; + } + generate_series_node::Args::Int64Args(v) => { + struct_ser.serialize_field("int64Args", v)?; + } + generate_series_node::Args::TimestampArgs(v) => { + struct_ser.serialize_field("timestampArgs", v)?; + } + generate_series_node::Args::DateArgs(v) => { + struct_ser.serialize_field("dateArgs", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FullTableReference { +impl<'de> serde::Deserialize<'de> for GenerateSeriesNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "catalog", "schema", - "table", + "target_batch_size", + "targetBatchSize", + "contains_null", + "containsNull", + "int64_args", + "int64Args", + "timestamp_args", + "timestampArgs", + "date_args", + "dateArgs", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Catalog, Schema, - Table, + TargetBatchSize, + ContainsNull, + Int64Args, + TimestampArgs, + DateArgs, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6530,9 +7520,12 @@ impl<'de> serde::Deserialize<'de> for FullTableReference { E: serde::de::Error, { match value { - "catalog" => Ok(GeneratedField::Catalog), "schema" => Ok(GeneratedField::Schema), - "table" => Ok(GeneratedField::Table), + "targetBatchSize" | "target_batch_size" => Ok(GeneratedField::TargetBatchSize), + "containsNull" | "contains_null" => Ok(GeneratedField::ContainsNull), + "int64Args" | "int64_args" => Ok(GeneratedField::Int64Args), + "timestampArgs" | "timestamp_args" => Ok(GeneratedField::TimestampArgs), + "dateArgs" | "date_args" => Ok(GeneratedField::DateArgs), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6542,49 +7535,73 @@ impl<'de> serde::Deserialize<'de> for FullTableReference { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FullTableReference; + type Value = GenerateSeriesNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FullTableReference") + formatter.write_str("struct datafusion.GenerateSeriesNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut catalog__ = None; let mut schema__ = None; - let mut table__ = None; + let mut target_batch_size__ = None; + let mut args__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Catalog => { - if catalog__.is_some() { - return Err(serde::de::Error::duplicate_field("catalog")); - } - catalog__ = Some(map_.next_value()?); - } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); } - schema__ = Some(map_.next_value()?); + schema__ = map_.next_value()?; } - GeneratedField::Table => { - if table__.is_some() { - return Err(serde::de::Error::duplicate_field("table")); + GeneratedField::TargetBatchSize => { + if target_batch_size__.is_some() { + return Err(serde::de::Error::duplicate_field("targetBatchSize")); } - table__ = Some(map_.next_value()?); + target_batch_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::ContainsNull => { + if args__.is_some() { + return Err(serde::de::Error::duplicate_field("containsNull")); + } + args__ = map_.next_value::<::std::option::Option<_>>()?.map(generate_series_node::Args::ContainsNull) +; + } + GeneratedField::Int64Args => { + if args__.is_some() { + return Err(serde::de::Error::duplicate_field("int64Args")); + } + args__ = map_.next_value::<::std::option::Option<_>>()?.map(generate_series_node::Args::Int64Args) +; + } + GeneratedField::TimestampArgs => { + if args__.is_some() { + return Err(serde::de::Error::duplicate_field("timestampArgs")); + } + args__ = map_.next_value::<::std::option::Option<_>>()?.map(generate_series_node::Args::TimestampArgs) +; + } + GeneratedField::DateArgs => { + if args__.is_some() { + return Err(serde::de::Error::duplicate_field("dateArgs")); + } + args__ = map_.next_value::<::std::option::Option<_>>()?.map(generate_series_node::Args::DateArgs) +; } } } - Ok(FullTableReference { - catalog: catalog__.unwrap_or_default(), - schema: schema__.unwrap_or_default(), - table: table__.unwrap_or_default(), + Ok(GenerateSeriesNode { + schema: schema__, + target_batch_size: target_batch_size__.unwrap_or_default(), + args: args__, }) } } - deserializer.deserialize_struct("datafusion.FullTableReference", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.GenerateSeriesNode", FIELDS, GeneratedVisitor) } } impl serde::Serialize for GlobalLimitExecNode { @@ -6832,7 +7849,7 @@ impl serde::Serialize for HashJoinExecNode { if self.partition_mode != 0 { len += 1; } - if self.null_equals_null { + if self.null_equality != 0 { len += 1; } if self.filter.is_some() { @@ -6861,8 +7878,10 @@ impl serde::Serialize for HashJoinExecNode { .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; struct_ser.serialize_field("partitionMode", &v)?; } - if self.null_equals_null { - struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + if self.null_equality != 0 { + let v = super::datafusion_common::NullEquality::try_from(self.null_equality) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.null_equality)))?; + struct_ser.serialize_field("nullEquality", &v)?; } if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; @@ -6887,8 +7906,8 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { "joinType", "partition_mode", "partitionMode", - "null_equals_null", - "nullEqualsNull", + "null_equality", + "nullEquality", "filter", "projection", ]; @@ -6900,7 +7919,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { On, JoinType, PartitionMode, - NullEqualsNull, + NullEquality, Filter, Projection, } @@ -6929,7 +7948,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { "on" => Ok(GeneratedField::On), "joinType" | "join_type" => Ok(GeneratedField::JoinType), "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), - "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), "filter" => Ok(GeneratedField::Filter), "projection" => Ok(GeneratedField::Projection), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -6956,7 +7975,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { let mut on__ = None; let mut join_type__ = None; let mut partition_mode__ = None; - let mut null_equals_null__ = None; + let mut null_equality__ = None; let mut filter__ = None; let mut projection__ = None; while let Some(k) = map_.next_key()? { @@ -6991,11 +8010,11 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { } partition_mode__ = Some(map_.next_value::()? as i32); } - GeneratedField::NullEqualsNull => { - if null_equals_null__.is_some() { - return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + GeneratedField::NullEquality => { + if null_equality__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEquality")); } - null_equals_null__ = Some(map_.next_value()?); + null_equality__ = Some(map_.next_value::()? as i32); } GeneratedField::Filter => { if filter__.is_some() { @@ -7020,7 +8039,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { on: on__.unwrap_or_default(), join_type: join_type__.unwrap_or_default(), partition_mode: partition_mode__.unwrap_or_default(), - null_equals_null: null_equals_null__.unwrap_or_default(), + null_equality: null_equality__.unwrap_or_default(), filter: filter__, projection: projection__.unwrap_or_default(), }) @@ -8456,7 +9475,7 @@ impl serde::Serialize for JoinNode { if !self.right_join_key.is_empty() { len += 1; } - if self.null_equals_null { + if self.null_equality != 0 { len += 1; } if self.filter.is_some() { @@ -8485,8 +9504,10 @@ impl serde::Serialize for JoinNode { if !self.right_join_key.is_empty() { struct_ser.serialize_field("rightJoinKey", &self.right_join_key)?; } - if self.null_equals_null { - struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + if self.null_equality != 0 { + let v = super::datafusion_common::NullEquality::try_from(self.null_equality) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.null_equality)))?; + struct_ser.serialize_field("nullEquality", &v)?; } if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; @@ -8511,8 +9532,8 @@ impl<'de> serde::Deserialize<'de> for JoinNode { "leftJoinKey", "right_join_key", "rightJoinKey", - "null_equals_null", - "nullEqualsNull", + "null_equality", + "nullEquality", "filter", ]; @@ -8524,7 +9545,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { JoinConstraint, LeftJoinKey, RightJoinKey, - NullEqualsNull, + NullEquality, Filter, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -8553,7 +9574,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { "joinConstraint" | "join_constraint" => Ok(GeneratedField::JoinConstraint), "leftJoinKey" | "left_join_key" => Ok(GeneratedField::LeftJoinKey), "rightJoinKey" | "right_join_key" => Ok(GeneratedField::RightJoinKey), - "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), "filter" => Ok(GeneratedField::Filter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -8580,7 +9601,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { let mut join_constraint__ = None; let mut left_join_key__ = None; let mut right_join_key__ = None; - let mut null_equals_null__ = None; + let mut null_equality__ = None; let mut filter__ = None; while let Some(k) = map_.next_key()? { match k { @@ -8620,11 +9641,11 @@ impl<'de> serde::Deserialize<'de> for JoinNode { } right_join_key__ = Some(map_.next_value()?); } - GeneratedField::NullEqualsNull => { - if null_equals_null__.is_some() { - return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + GeneratedField::NullEquality => { + if null_equality__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEquality")); } - null_equals_null__ = Some(map_.next_value()?); + null_equality__ = Some(map_.next_value::()? as i32); } GeneratedField::Filter => { if filter__.is_some() { @@ -8641,7 +9662,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { join_constraint: join_constraint__.unwrap_or_default(), left_join_key: left_join_key__.unwrap_or_default(), right_join_key: right_join_key__.unwrap_or_default(), - null_equals_null: null_equals_null__.unwrap_or_default(), + null_equality: null_equality__.unwrap_or_default(), filter: filter__, }) } @@ -9793,6 +10814,9 @@ impl serde::Serialize for ListingTableScanNode { listing_table_scan_node::FileFormatType::Json(v) => { struct_ser.serialize_field("json", v)?; } + listing_table_scan_node::FileFormatType::Arrow(v) => { + struct_ser.serialize_field("arrow", v)?; + } } } struct_ser.end() @@ -9825,6 +10849,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { "parquet", "avro", "json", + "arrow", ]; #[allow(clippy::enum_variant_names)] @@ -9843,6 +10868,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { Parquet, Avro, Json, + Arrow, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9878,6 +10904,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { "parquet" => Ok(GeneratedField::Parquet), "avro" => Ok(GeneratedField::Avro), "json" => Ok(GeneratedField::Json), + "arrow" => Ok(GeneratedField::Arrow), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9998,6 +11025,13 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { return Err(serde::de::Error::duplicate_field("json")); } file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Json) +; + } + GeneratedField::Arrow => { + if file_format_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrow")); + } + file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Arrow) ; } } @@ -11614,16 +12648,202 @@ impl<'de> serde::Deserialize<'de> for MaybePhysicalSortExprs { if sort_expr__.is_some() { return Err(serde::de::Error::duplicate_field("sortExpr")); } - sort_expr__ = Some(map_.next_value()?); + sort_expr__ = Some(map_.next_value()?); + } + } + } + Ok(MaybePhysicalSortExprs { + sort_expr: sort_expr__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.MaybePhysicalSortExprs", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for MemoryScanExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.partitions.is_empty() { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + if !self.projection.is_empty() { + len += 1; + } + if !self.sort_information.is_empty() { + len += 1; + } + if self.show_sizes { + len += 1; + } + if self.fetch.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.MemoryScanExecNode", len)?; + if !self.partitions.is_empty() { + struct_ser.serialize_field("partitions", &self.partitions.iter().map(pbjson::private::base64::encode).collect::>())?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + if !self.projection.is_empty() { + struct_ser.serialize_field("projection", &self.projection)?; + } + if !self.sort_information.is_empty() { + struct_ser.serialize_field("sortInformation", &self.sort_information)?; + } + if self.show_sizes { + struct_ser.serialize_field("showSizes", &self.show_sizes)?; + } + if let Some(v) = self.fetch.as_ref() { + struct_ser.serialize_field("fetch", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for MemoryScanExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "partitions", + "schema", + "projection", + "sort_information", + "sortInformation", + "show_sizes", + "showSizes", + "fetch", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Partitions, + Schema, + Projection, + SortInformation, + ShowSizes, + Fetch, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "partitions" => Ok(GeneratedField::Partitions), + "schema" => Ok(GeneratedField::Schema), + "projection" => Ok(GeneratedField::Projection), + "sortInformation" | "sort_information" => Ok(GeneratedField::SortInformation), + "showSizes" | "show_sizes" => Ok(GeneratedField::ShowSizes), + "fetch" => Ok(GeneratedField::Fetch), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = MemoryScanExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.MemoryScanExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut partitions__ = None; + let mut schema__ = None; + let mut projection__ = None; + let mut sort_information__ = None; + let mut show_sizes__ = None; + let mut fetch__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Partitions => { + if partitions__.is_some() { + return Err(serde::de::Error::duplicate_field("partitions")); + } + partitions__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + GeneratedField::Projection => { + if projection__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); + } + projection__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; + } + GeneratedField::SortInformation => { + if sort_information__.is_some() { + return Err(serde::de::Error::duplicate_field("sortInformation")); + } + sort_information__ = Some(map_.next_value()?); + } + GeneratedField::ShowSizes => { + if show_sizes__.is_some() { + return Err(serde::de::Error::duplicate_field("showSizes")); + } + show_sizes__ = Some(map_.next_value()?); + } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) + ; } } } - Ok(MaybePhysicalSortExprs { - sort_expr: sort_expr__.unwrap_or_default(), + Ok(MemoryScanExecNode { + partitions: partitions__.unwrap_or_default(), + schema: schema__, + projection: projection__.unwrap_or_default(), + sort_information: sort_information__.unwrap_or_default(), + show_sizes: show_sizes__.unwrap_or_default(), + fetch: fetch__, }) } } - deserializer.deserialize_struct("datafusion.MaybePhysicalSortExprs", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.MemoryScanExecNode", FIELDS, GeneratedVisitor) } } impl serde::Serialize for NamedStructField { @@ -13505,6 +14725,9 @@ impl serde::Serialize for PhysicalAggregateExprNode { if self.fun_definition.is_some() { len += 1; } + if !self.human_display.is_empty() { + len += 1; + } if self.aggregate_function.is_some() { len += 1; } @@ -13526,6 +14749,9 @@ impl serde::Serialize for PhysicalAggregateExprNode { #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; } + if !self.human_display.is_empty() { + struct_ser.serialize_field("humanDisplay", &self.human_display)?; + } if let Some(v) = self.aggregate_function.as_ref() { match v { physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(v) => { @@ -13551,6 +14777,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { "ignoreNulls", "fun_definition", "funDefinition", + "human_display", + "humanDisplay", "user_defined_aggr_function", "userDefinedAggrFunction", ]; @@ -13562,6 +14790,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { Distinct, IgnoreNulls, FunDefinition, + HumanDisplay, UserDefinedAggrFunction, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -13589,6 +14818,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { "distinct" => Ok(GeneratedField::Distinct), "ignoreNulls" | "ignore_nulls" => Ok(GeneratedField::IgnoreNulls), "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), + "humanDisplay" | "human_display" => Ok(GeneratedField::HumanDisplay), "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -13614,6 +14844,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { let mut distinct__ = None; let mut ignore_nulls__ = None; let mut fun_definition__ = None; + let mut human_display__ = None; let mut aggregate_function__ = None; while let Some(k) = map_.next_key()? { match k { @@ -13649,6 +14880,12 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) ; } + GeneratedField::HumanDisplay => { + if human_display__.is_some() { + return Err(serde::de::Error::duplicate_field("humanDisplay")); + } + human_display__ = Some(map_.next_value()?); + } GeneratedField::UserDefinedAggrFunction => { if aggregate_function__.is_some() { return Err(serde::de::Error::duplicate_field("userDefinedAggrFunction")); @@ -13663,6 +14900,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { distinct: distinct__.unwrap_or_default(), ignore_nulls: ignore_nulls__.unwrap_or_default(), fun_definition: fun_definition__, + human_display: human_display__.unwrap_or_default(), aggregate_function: aggregate_function__, }) } @@ -15777,6 +17015,18 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::JsonScan(v) => { struct_ser.serialize_field("jsonScan", v)?; } + physical_plan_node::PhysicalPlanType::Cooperative(v) => { + struct_ser.serialize_field("cooperative", v)?; + } + physical_plan_node::PhysicalPlanType::GenerateSeries(v) => { + struct_ser.serialize_field("generateSeries", v)?; + } + physical_plan_node::PhysicalPlanType::SortMergeJoin(v) => { + struct_ser.serialize_field("sortMergeJoin", v)?; + } + physical_plan_node::PhysicalPlanType::MemoryScan(v) => { + struct_ser.serialize_field("memoryScan", v)?; + } } } struct_ser.end() @@ -15835,6 +17085,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "unnest", "json_scan", "jsonScan", + "cooperative", + "generate_series", + "generateSeries", + "sort_merge_join", + "sortMergeJoin", + "memory_scan", + "memoryScan", ]; #[allow(clippy::enum_variant_names)] @@ -15869,6 +17126,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { ParquetSink, Unnest, JsonScan, + Cooperative, + GenerateSeries, + SortMergeJoin, + MemoryScan, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -15920,6 +17181,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "parquetSink" | "parquet_sink" => Ok(GeneratedField::ParquetSink), "unnest" => Ok(GeneratedField::Unnest), "jsonScan" | "json_scan" => Ok(GeneratedField::JsonScan), + "cooperative" => Ok(GeneratedField::Cooperative), + "generateSeries" | "generate_series" => Ok(GeneratedField::GenerateSeries), + "sortMergeJoin" | "sort_merge_join" => Ok(GeneratedField::SortMergeJoin), + "memoryScan" | "memory_scan" => Ok(GeneratedField::MemoryScan), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16150,6 +17415,34 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("jsonScan")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::JsonScan) +; + } + GeneratedField::Cooperative => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("cooperative")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Cooperative) +; + } + GeneratedField::GenerateSeries => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("generateSeries")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::GenerateSeries) +; + } + GeneratedField::SortMergeJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("sortMergeJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SortMergeJoin) +; + } + GeneratedField::MemoryScan => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("memoryScan")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::MemoryScan) ; } } @@ -16185,6 +17478,9 @@ impl serde::Serialize for PhysicalScalarUdfNode { if self.nullable { len += 1; } + if !self.return_field_name.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalScalarUdfNode", len)?; if !self.name.is_empty() { struct_ser.serialize_field("name", &self.name)?; @@ -16203,6 +17499,9 @@ impl serde::Serialize for PhysicalScalarUdfNode { if self.nullable { struct_ser.serialize_field("nullable", &self.nullable)?; } + if !self.return_field_name.is_empty() { + struct_ser.serialize_field("returnFieldName", &self.return_field_name)?; + } struct_ser.end() } } @@ -16220,6 +17519,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { "return_type", "returnType", "nullable", + "return_field_name", + "returnFieldName", ]; #[allow(clippy::enum_variant_names)] @@ -16229,6 +17530,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { FunDefinition, ReturnType, Nullable, + ReturnFieldName, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16255,6 +17557,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), "returnType" | "return_type" => Ok(GeneratedField::ReturnType), "nullable" => Ok(GeneratedField::Nullable), + "returnFieldName" | "return_field_name" => Ok(GeneratedField::ReturnFieldName), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16279,6 +17582,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { let mut fun_definition__ = None; let mut return_type__ = None; let mut nullable__ = None; + let mut return_field_name__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { @@ -16313,6 +17617,12 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { } nullable__ = Some(map_.next_value()?); } + GeneratedField::ReturnFieldName => { + if return_field_name__.is_some() { + return Err(serde::de::Error::duplicate_field("returnFieldName")); + } + return_field_name__ = Some(map_.next_value()?); + } } } Ok(PhysicalScalarUdfNode { @@ -16321,6 +17631,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { fun_definition: fun_definition__, return_type: return_type__, nullable: nullable__.unwrap_or_default(), + return_field_name: return_field_name__.unwrap_or_default(), }) } } @@ -16790,6 +18101,12 @@ impl serde::Serialize for PhysicalWindowExprNode { if self.fun_definition.is_some() { len += 1; } + if self.ignore_nulls { + len += 1; + } + if self.distinct { + len += 1; + } if self.window_function.is_some() { len += 1; } @@ -16814,6 +18131,12 @@ impl serde::Serialize for PhysicalWindowExprNode { #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; } + if self.ignore_nulls { + struct_ser.serialize_field("ignoreNulls", &self.ignore_nulls)?; + } + if self.distinct { + struct_ser.serialize_field("distinct", &self.distinct)?; + } if let Some(v) = self.window_function.as_ref() { match v { physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(v) => { @@ -16844,6 +18167,9 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "name", "fun_definition", "funDefinition", + "ignore_nulls", + "ignoreNulls", + "distinct", "user_defined_aggr_function", "userDefinedAggrFunction", "user_defined_window_function", @@ -16858,6 +18184,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { WindowFrame, Name, FunDefinition, + IgnoreNulls, + Distinct, UserDefinedAggrFunction, UserDefinedWindowFunction, } @@ -16887,6 +18215,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), "name" => Ok(GeneratedField::Name), "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), + "ignoreNulls" | "ignore_nulls" => Ok(GeneratedField::IgnoreNulls), + "distinct" => Ok(GeneratedField::Distinct), "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), "userDefinedWindowFunction" | "user_defined_window_function" => Ok(GeneratedField::UserDefinedWindowFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -16914,6 +18244,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { let mut window_frame__ = None; let mut name__ = None; let mut fun_definition__ = None; + let mut ignore_nulls__ = None; + let mut distinct__ = None; let mut window_function__ = None; while let Some(k) = map_.next_key()? { match k { @@ -16955,6 +18287,18 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) ; } + GeneratedField::IgnoreNulls => { + if ignore_nulls__.is_some() { + return Err(serde::de::Error::duplicate_field("ignoreNulls")); + } + ignore_nulls__ = Some(map_.next_value()?); + } + GeneratedField::Distinct => { + if distinct__.is_some() { + return Err(serde::de::Error::duplicate_field("distinct")); + } + distinct__ = Some(map_.next_value()?); + } GeneratedField::UserDefinedAggrFunction => { if window_function__.is_some() { return Err(serde::de::Error::duplicate_field("userDefinedAggrFunction")); @@ -16976,6 +18320,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { window_frame: window_frame__, name: name__.unwrap_or_default(), fun_definition: fun_definition__, + ignore_nulls: ignore_nulls__.unwrap_or_default(), + distinct: distinct__.unwrap_or_default(), window_function: window_function__, }) } @@ -19439,6 +20785,206 @@ impl<'de> serde::Deserialize<'de> for SortExprNodeCollection { deserializer.deserialize_struct("datafusion.SortExprNodeCollection", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for SortMergeJoinExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.left.is_some() { + len += 1; + } + if self.right.is_some() { + len += 1; + } + if !self.on.is_empty() { + len += 1; + } + if self.join_type != 0 { + len += 1; + } + if self.filter.is_some() { + len += 1; + } + if !self.sort_options.is_empty() { + len += 1; + } + if self.null_equality != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SortMergeJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if !self.on.is_empty() { + struct_ser.serialize_field("on", &self.on)?; + } + if self.join_type != 0 { + let v = super::datafusion_common::JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; + } + if !self.sort_options.is_empty() { + struct_ser.serialize_field("sortOptions", &self.sort_options)?; + } + if self.null_equality != 0 { + let v = super::datafusion_common::NullEquality::try_from(self.null_equality) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.null_equality)))?; + struct_ser.serialize_field("nullEquality", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SortMergeJoinExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "left", + "right", + "on", + "join_type", + "joinType", + "filter", + "sort_options", + "sortOptions", + "null_equality", + "nullEquality", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Left, + Right, + On, + JoinType, + Filter, + SortOptions, + NullEquality, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "on" => Ok(GeneratedField::On), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "filter" => Ok(GeneratedField::Filter), + "sortOptions" | "sort_options" => Ok(GeneratedField::SortOptions), + "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SortMergeJoinExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SortMergeJoinExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut left__ = None; + let mut right__ = None; + let mut on__ = None; + let mut join_type__ = None; + let mut filter__ = None; + let mut sort_options__ = None; + let mut null_equality__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); + } + left__ = map_.next_value()?; + } + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); + } + right__ = map_.next_value()?; + } + GeneratedField::On => { + if on__.is_some() { + return Err(serde::de::Error::duplicate_field("on")); + } + on__ = Some(map_.next_value()?); + } + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + filter__ = map_.next_value()?; + } + GeneratedField::SortOptions => { + if sort_options__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOptions")); + } + sort_options__ = Some(map_.next_value()?); + } + GeneratedField::NullEquality => { + if null_equality__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEquality")); + } + null_equality__ = Some(map_.next_value::()? as i32); + } + } + } + Ok(SortMergeJoinExecNode { + left: left__, + right: right__, + on: on__.unwrap_or_default(), + join_type: join_type__.unwrap_or_default(), + filter: filter__, + sort_options: sort_options__.unwrap_or_default(), + null_equality: null_equality__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SortMergeJoinExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for SortNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -20008,7 +21554,7 @@ impl serde::Serialize for SymmetricHashJoinExecNode { if self.partition_mode != 0 { len += 1; } - if self.null_equals_null { + if self.null_equality != 0 { len += 1; } if self.filter.is_some() { @@ -20040,8 +21586,10 @@ impl serde::Serialize for SymmetricHashJoinExecNode { .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; struct_ser.serialize_field("partitionMode", &v)?; } - if self.null_equals_null { - struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + if self.null_equality != 0 { + let v = super::datafusion_common::NullEquality::try_from(self.null_equality) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.null_equality)))?; + struct_ser.serialize_field("nullEquality", &v)?; } if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; @@ -20069,8 +21617,8 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { "joinType", "partition_mode", "partitionMode", - "null_equals_null", - "nullEqualsNull", + "null_equality", + "nullEquality", "filter", "left_sort_exprs", "leftSortExprs", @@ -20085,7 +21633,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { On, JoinType, PartitionMode, - NullEqualsNull, + NullEquality, Filter, LeftSortExprs, RightSortExprs, @@ -20115,7 +21663,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { "on" => Ok(GeneratedField::On), "joinType" | "join_type" => Ok(GeneratedField::JoinType), "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), - "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), "filter" => Ok(GeneratedField::Filter), "leftSortExprs" | "left_sort_exprs" => Ok(GeneratedField::LeftSortExprs), "rightSortExprs" | "right_sort_exprs" => Ok(GeneratedField::RightSortExprs), @@ -20143,7 +21691,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { let mut on__ = None; let mut join_type__ = None; let mut partition_mode__ = None; - let mut null_equals_null__ = None; + let mut null_equality__ = None; let mut filter__ = None; let mut left_sort_exprs__ = None; let mut right_sort_exprs__ = None; @@ -20179,11 +21727,11 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { } partition_mode__ = Some(map_.next_value::()? as i32); } - GeneratedField::NullEqualsNull => { - if null_equals_null__.is_some() { - return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + GeneratedField::NullEquality => { + if null_equality__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEquality")); } - null_equals_null__ = Some(map_.next_value()?); + null_equality__ = Some(map_.next_value::()? as i32); } GeneratedField::Filter => { if filter__.is_some() { @@ -20211,7 +21759,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { on: on__.unwrap_or_default(), join_type: join_type__.unwrap_or_default(), partition_mode: partition_mode__.unwrap_or_default(), - null_equals_null: null_equals_null__.unwrap_or_default(), + null_equality: null_equality__.unwrap_or_default(), filter: filter__, left_sort_exprs: left_sort_exprs__.unwrap_or_default(), right_sort_exprs: right_sort_exprs__.unwrap_or_default(), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d2165dad48501..d3b5f566e98b7 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -86,7 +86,7 @@ pub struct LogicalExtensionNode { #[prost(message, repeated, tag = "2")] pub inputs: ::prost::alloc::vec::Vec, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ProjectionColumns { #[prost(string, repeated, tag = "1")] pub columns: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, @@ -115,15 +115,18 @@ pub struct ListingTableScanNode { pub schema: ::core::option::Option, #[prost(message, repeated, tag = "6")] pub filters: ::prost::alloc::vec::Vec, - #[prost(string, repeated, tag = "7")] - pub table_partition_cols: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(message, repeated, tag = "7")] + pub table_partition_cols: ::prost::alloc::vec::Vec, #[prost(bool, tag = "8")] pub collect_stat: bool, #[prost(uint32, tag = "9")] pub target_partitions: u32, #[prost(message, repeated, tag = "13")] pub file_sort_order: ::prost::alloc::vec::Vec, - #[prost(oneof = "listing_table_scan_node::FileFormatType", tags = "10, 11, 12, 15")] + #[prost( + oneof = "listing_table_scan_node::FileFormatType", + tags = "10, 11, 12, 15, 16" + )] pub file_format_type: ::core::option::Option< listing_table_scan_node::FileFormatType, >, @@ -140,6 +143,8 @@ pub mod listing_table_scan_node { Avro(super::super::datafusion_common::AvroFormat), #[prost(message, tag = "15")] Json(super::super::datafusion_common::NdJsonFormat), + #[prost(message, tag = "16")] + Arrow(super::super::datafusion_common::ArrowFormat), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -180,7 +185,7 @@ pub struct ProjectionNode { } /// Nested message and enum types in `ProjectionNode`. pub mod projection_node { - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum OptionalAlias { #[prost(string, tag = "3")] Alias(::prost::alloc::string::String), @@ -227,7 +232,7 @@ pub struct HashRepartition { #[prost(uint64, tag = "2")] pub partition_count: u64, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct EmptyRelationNode { #[prost(bool, tag = "1")] pub produce_one_row: bool, @@ -246,6 +251,8 @@ pub struct CreateExternalTableNode { pub table_partition_cols: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, #[prost(bool, tag = "6")] pub if_not_exists: bool, + #[prost(bool, tag = "15")] + pub or_replace: bool, #[prost(bool, tag = "14")] pub temporary: bool, #[prost(string, tag = "7")] @@ -369,8 +376,8 @@ pub struct JoinNode { pub left_join_key: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "6")] pub right_join_key: ::prost::alloc::vec::Vec, - #[prost(bool, tag = "7")] - pub null_equals_null: bool, + #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] + pub null_equality: i32, #[prost(message, optional, tag = "8")] pub filter: ::core::option::Option, } @@ -480,7 +487,7 @@ pub struct UnnestNode { #[prost(message, optional, tag = "7")] pub options: ::core::option::Option, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ColumnUnnestListItem { #[prost(uint32, tag = "1")] pub input_index: u32, @@ -492,7 +499,7 @@ pub struct ColumnUnnestListRecursions { #[prost(message, repeated, tag = "2")] pub recursions: ::prost::alloc::vec::Vec, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ColumnUnnestListRecursion { #[prost(message, optional, tag = "1")] pub output_column: ::core::option::Option, @@ -506,7 +513,7 @@ pub struct UnnestOptions { #[prost(message, repeated, tag = "2")] pub recursions: ::prost::alloc::vec::Vec, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct RecursionUnnestOption { #[prost(message, optional, tag = "1")] pub output_column: ::core::option::Option, @@ -635,7 +642,7 @@ pub mod logical_expr_node { Unnest(super::Unnest), } } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Wildcard { #[prost(message, optional, tag = "1")] pub qualifier: ::core::option::Option, @@ -816,7 +823,7 @@ pub struct WindowExprNode { } /// Nested message and enum types in `WindowExprNode`. pub mod window_expr_node { - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum WindowFunction { /// BuiltInWindowFunction built_in_function = 2; #[prost(string, tag = "3")] @@ -936,27 +943,27 @@ pub struct WindowFrameBound { #[prost(message, optional, tag = "2")] pub bound_value: ::core::option::Option, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct FixedSizeBinary { #[prost(int32, tag = "1")] pub length: i32, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct AnalyzedLogicalPlanType { #[prost(string, tag = "1")] pub analyzer_name: ::prost::alloc::string::String, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct OptimizedLogicalPlanType { #[prost(string, tag = "1")] pub optimizer_name: ::prost::alloc::string::String, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct OptimizedPhysicalPlanType { #[prost(string, tag = "1")] pub optimizer_name: ::prost::alloc::string::String, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct PlanType { #[prost( oneof = "plan_type::PlanTypeEnum", @@ -966,7 +973,7 @@ pub struct PlanType { } /// Nested message and enum types in `PlanType`. pub mod plan_type { - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum PlanTypeEnum { #[prost(message, tag = "1")] InitialLogicalPlan(super::super::datafusion_common::EmptyMessage), @@ -996,26 +1003,26 @@ pub mod plan_type { PhysicalPlanError(super::super::datafusion_common::EmptyMessage), } } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct StringifiedPlan { #[prost(message, optional, tag = "1")] pub plan_type: ::core::option::Option, #[prost(string, tag = "2")] pub plan: ::prost::alloc::string::String, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct BareTableReference { #[prost(string, tag = "1")] pub table: ::prost::alloc::string::String, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct PartialTableReference { #[prost(string, tag = "1")] pub schema: ::prost::alloc::string::String, #[prost(string, tag = "2")] pub table: ::prost::alloc::string::String, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct FullTableReference { #[prost(string, tag = "1")] pub catalog: ::prost::alloc::string::String, @@ -1024,7 +1031,7 @@ pub struct FullTableReference { #[prost(string, tag = "3")] pub table: ::prost::alloc::string::String, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct TableReference { #[prost(oneof = "table_reference::TableReferenceEnum", tags = "1, 2, 3")] pub table_reference_enum: ::core::option::Option< @@ -1033,7 +1040,7 @@ pub struct TableReference { } /// Nested message and enum types in `TableReference`. pub mod table_reference { - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum TableReferenceEnum { #[prost(message, tag = "1")] Bare(super::BareTableReference), @@ -1048,7 +1055,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" )] pub physical_plan_type: ::core::option::Option, } @@ -1118,6 +1125,14 @@ pub mod physical_plan_node { Unnest(::prost::alloc::boxed::Box), #[prost(message, tag = "31")] JsonScan(super::JsonScanExecNode), + #[prost(message, tag = "32")] + Cooperative(::prost::alloc::boxed::Box), + #[prost(message, tag = "33")] + GenerateSeries(super::GenerateSeriesNode), + #[prost(message, tag = "34")] + SortMergeJoin(::prost::alloc::boxed::Box), + #[prost(message, tag = "35")] + MemoryScan(super::MemoryScanExecNode), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1219,7 +1234,7 @@ pub struct UnnestExecNode { #[prost(message, optional, tag = "5")] pub options: ::core::option::Option, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ListUnnest { #[prost(uint32, tag = "1")] pub index_in_input_schema: u32, @@ -1303,6 +1318,8 @@ pub struct PhysicalScalarUdfNode { pub return_type: ::core::option::Option, #[prost(bool, tag = "5")] pub nullable: bool, + #[prost(string, tag = "6")] + pub return_field_name: ::prost::alloc::string::String, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalAggregateExprNode { @@ -1316,6 +1333,8 @@ pub struct PhysicalAggregateExprNode { pub ignore_nulls: bool, #[prost(bytes = "vec", optional, tag = "7")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, + #[prost(string, tag = "8")] + pub human_display: ::prost::alloc::string::String, #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = "4")] pub aggregate_function: ::core::option::Option< physical_aggregate_expr_node::AggregateFunction, @@ -1323,7 +1342,7 @@ pub struct PhysicalAggregateExprNode { } /// Nested message and enum types in `PhysicalAggregateExprNode`. pub mod physical_aggregate_expr_node { - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum AggregateFunction { #[prost(string, tag = "4")] UserDefinedAggrFunction(::prost::alloc::string::String), @@ -1343,6 +1362,10 @@ pub struct PhysicalWindowExprNode { pub name: ::prost::alloc::string::String, #[prost(bytes = "vec", optional, tag = "9")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, + #[prost(bool, tag = "11")] + pub ignore_nulls: bool, + #[prost(bool, tag = "12")] + pub distinct: bool, #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "3, 10")] pub window_function: ::core::option::Option< physical_window_expr_node::WindowFunction, @@ -1350,7 +1373,7 @@ pub struct PhysicalWindowExprNode { } /// Nested message and enum types in `PhysicalWindowExprNode`. pub mod physical_window_expr_node { - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum WindowFunction { /// BuiltInWindowFunction built_in_function = 2; #[prost(string, tag = "3")] @@ -1486,7 +1509,7 @@ pub struct FileGroup { #[prost(message, repeated, tag = "1")] pub files: ::prost::alloc::vec::Vec, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ScanLimit { /// wrap into a message to make it optional #[prost(uint32, tag = "1")] @@ -1543,6 +1566,8 @@ pub struct CsvScanExecNode { pub quote: ::prost::alloc::string::String, #[prost(bool, tag = "7")] pub newlines_in_values: bool, + #[prost(bool, tag = "8")] + pub truncate_rows: bool, #[prost(oneof = "csv_scan_exec_node::OptionalEscape", tags = "5")] pub optional_escape: ::core::option::Option, #[prost(oneof = "csv_scan_exec_node::OptionalComment", tags = "6")] @@ -1550,12 +1575,12 @@ pub struct CsvScanExecNode { } /// Nested message and enum types in `CsvScanExecNode`. pub mod csv_scan_exec_node { - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum OptionalEscape { #[prost(string, tag = "5")] Escape(::prost::alloc::string::String), } - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum OptionalComment { #[prost(string, tag = "6")] Comment(::prost::alloc::string::String), @@ -1572,6 +1597,26 @@ pub struct AvroScanExecNode { pub base_conf: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct MemoryScanExecNode { + #[prost(bytes = "vec", repeated, tag = "1")] + pub partitions: ::prost::alloc::vec::Vec<::prost::alloc::vec::Vec>, + #[prost(message, optional, tag = "2")] + pub schema: ::core::option::Option, + #[prost(uint32, repeated, tag = "3")] + pub projection: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "4")] + pub sort_information: ::prost::alloc::vec::Vec, + #[prost(bool, tag = "5")] + pub show_sizes: bool, + #[prost(uint32, optional, tag = "6")] + pub fetch: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CooperativeExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct HashJoinExecNode { #[prost(message, optional, boxed, tag = "1")] pub left: ::core::option::Option<::prost::alloc::boxed::Box>, @@ -1583,8 +1628,8 @@ pub struct HashJoinExecNode { pub join_type: i32, #[prost(enumeration = "PartitionMode", tag = "6")] pub partition_mode: i32, - #[prost(bool, tag = "7")] - pub null_equals_null: bool, + #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] + pub null_equality: i32, #[prost(message, optional, tag = "8")] pub filter: ::core::option::Option, #[prost(uint32, repeated, tag = "9")] @@ -1602,8 +1647,8 @@ pub struct SymmetricHashJoinExecNode { pub join_type: i32, #[prost(enumeration = "StreamPartitionMode", tag = "6")] pub partition_mode: i32, - #[prost(bool, tag = "7")] - pub null_equals_null: bool, + #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] + pub null_equality: i32, #[prost(message, optional, tag = "8")] pub filter: ::core::option::Option, #[prost(message, repeated, tag = "9")] @@ -1648,14 +1693,14 @@ pub struct CrossJoinExecNode { #[prost(message, optional, boxed, tag = "2")] pub right: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct PhysicalColumn { #[prost(string, tag = "1")] pub name: ::prost::alloc::string::String, #[prost(uint32, tag = "2")] pub index: u32, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct UnknownColumn { #[prost(string, tag = "1")] pub name: ::prost::alloc::string::String, @@ -1686,7 +1731,7 @@ pub struct ProjectionExecNode { #[prost(string, repeated, tag = "3")] pub expr_name: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct PartiallySortedInputOrderMode { #[prost(uint64, repeated, tag = "6")] pub columns: ::prost::alloc::vec::Vec, @@ -1706,7 +1751,7 @@ pub struct WindowAggExecNode { /// Nested message and enum types in `WindowAggExecNode`. pub mod window_agg_exec_node { /// Set optional to `None` for `BoundedWindowAggExec`. - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] pub enum InputOrderMode { #[prost(message, tag = "7")] Linear(super::super::datafusion_common::EmptyMessage), @@ -1726,7 +1771,7 @@ pub struct MaybePhysicalSortExprs { #[prost(message, repeated, tag = "1")] pub sort_expr: ::prost::alloc::vec::Vec, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct AggLimit { /// wrap into a message to make it optional #[prost(uint64, tag = "1")] @@ -1824,6 +1869,8 @@ pub struct CoalesceBatchesExecNode { pub struct CoalescePartitionsExecNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(uint32, optional, tag = "2")] + pub fetch: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalHashRepartition { @@ -1870,7 +1917,7 @@ pub struct JoinFilter { #[prost(message, optional, tag = "3")] pub schema: ::core::option::Option, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ColumnIndex { #[prost(uint32, tag = "1")] pub index: u32, @@ -1894,7 +1941,7 @@ pub struct PartitionedFile { #[prost(message, optional, tag = "6")] pub statistics: ::core::option::Option, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct FileRange { #[prost(int64, tag = "1")] pub start: i64, @@ -1932,6 +1979,96 @@ pub struct CteWorkTableScanNode { #[prost(message, optional, tag = "2")] pub schema: ::core::option::Option, } +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct GenerateSeriesArgsContainsNull { + #[prost(enumeration = "GenerateSeriesName", tag = "1")] + pub name: i32, +} +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct GenerateSeriesArgsInt64 { + #[prost(int64, tag = "1")] + pub start: i64, + #[prost(int64, tag = "2")] + pub end: i64, + #[prost(int64, tag = "3")] + pub step: i64, + #[prost(bool, tag = "4")] + pub include_end: bool, + #[prost(enumeration = "GenerateSeriesName", tag = "5")] + pub name: i32, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct GenerateSeriesArgsTimestamp { + #[prost(int64, tag = "1")] + pub start: i64, + #[prost(int64, tag = "2")] + pub end: i64, + #[prost(message, optional, tag = "3")] + pub step: ::core::option::Option< + super::datafusion_common::IntervalMonthDayNanoValue, + >, + #[prost(string, optional, tag = "4")] + pub tz: ::core::option::Option<::prost::alloc::string::String>, + #[prost(bool, tag = "5")] + pub include_end: bool, + #[prost(enumeration = "GenerateSeriesName", tag = "6")] + pub name: i32, +} +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct GenerateSeriesArgsDate { + #[prost(int64, tag = "1")] + pub start: i64, + #[prost(int64, tag = "2")] + pub end: i64, + #[prost(message, optional, tag = "3")] + pub step: ::core::option::Option< + super::datafusion_common::IntervalMonthDayNanoValue, + >, + #[prost(bool, tag = "4")] + pub include_end: bool, + #[prost(enumeration = "GenerateSeriesName", tag = "5")] + pub name: i32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct GenerateSeriesNode { + #[prost(message, optional, tag = "1")] + pub schema: ::core::option::Option, + #[prost(uint32, tag = "2")] + pub target_batch_size: u32, + #[prost(oneof = "generate_series_node::Args", tags = "3, 4, 5, 6")] + pub args: ::core::option::Option, +} +/// Nested message and enum types in `GenerateSeriesNode`. +pub mod generate_series_node { + #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)] + pub enum Args { + #[prost(message, tag = "3")] + ContainsNull(super::GenerateSeriesArgsContainsNull), + #[prost(message, tag = "4")] + Int64Args(super::GenerateSeriesArgsInt64), + #[prost(message, tag = "5")] + TimestampArgs(super::GenerateSeriesArgsTimestamp), + #[prost(message, tag = "6")] + DateArgs(super::GenerateSeriesArgsDate), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SortMergeJoinExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub left: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub right: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "3")] + pub on: ::prost::alloc::vec::Vec, + #[prost(enumeration = "super::datafusion_common::JoinType", tag = "4")] + pub join_type: i32, + #[prost(message, optional, tag = "5")] + pub filter: ::core::option::Option, + #[prost(message, repeated, tag = "6")] + pub sort_options: ::prost::alloc::vec::Vec, + #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] + pub null_equality: i32, +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum WindowFrameUnits { @@ -2135,3 +2272,29 @@ impl AggregateMode { } } } +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum GenerateSeriesName { + GsGenerateSeries = 0, + GsRange = 1, +} +impl GenerateSeriesName { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::GsGenerateSeries => "GS_GENERATE_SERIES", + Self::GsRange => "GS_RANGE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "GS_GENERATE_SERIES" => Some(Self::GsGenerateSeries), + "GS_RANGE" => Some(Self::GsRange), + _ => None, + } + } +} diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 2df162f21e3a3..bb7b992f145f5 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] @@ -115,7 +115,7 @@ //! let bytes = physical_plan_to_bytes(physical_plan.clone())?; //! //! // Decode bytes from somewhere (over network, etc.) back to ExecutionPlan -//! let physical_round_trip = physical_plan_from_bytes(&bytes, &ctx)?; +//! let physical_round_trip = physical_plan_from_bytes(&bytes, &ctx.task_ctx())?; //! assert_eq!(format!("{:?}", physical_plan), format!("{:?}", physical_round_trip)); //! # Ok(()) //! # } @@ -130,8 +130,9 @@ pub mod protobuf { pub use crate::generated::datafusion::*; pub use datafusion_proto_common::common::proto_error; pub use datafusion_proto_common::protobuf_common::{ - ArrowOptions, ArrowType, AvroFormat, AvroOptions, CsvFormat, DfSchema, - EmptyMessage, Field, JoinSide, NdJsonFormat, ParquetFormat, ScalarValue, Schema, + ArrowFormat, ArrowOptions, ArrowType, AvroFormat, AvroOptions, CsvFormat, + DfSchema, EmptyMessage, Field, JoinSide, NdJsonFormat, ParquetFormat, + ScalarValue, Schema, }; pub use datafusion_proto_common::{FromProtoError, ToProtoError}; } diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs index e22738973284e..0e76e19ecb1ab 100644 --- a/datafusion/proto/src/logical_plan/file_formats.rs +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -18,28 +18,20 @@ use std::sync::Arc; use datafusion::{ - config::{ - CsvOptions, JsonOptions, ParquetColumnOptions, ParquetOptions, - TableParquetOptions, - }, + config::{CsvOptions, JsonOptions}, datasource::file_format::{ arrow::ArrowFormatFactory, csv::CsvFormatFactory, json::JsonFormatFactory, - parquet::ParquetFormatFactory, FileFormatFactory, + FileFormatFactory, }, prelude::SessionContext, }; use datafusion_common::{ - exec_err, not_impl_err, parsers::CompressionTypeVariant, DataFusionError, + exec_datafusion_err, exec_err, not_impl_err, parsers::CompressionTypeVariant, TableReference, }; use prost::Message; -use crate::protobuf::{ - parquet_column_options, parquet_options, CsvOptions as CsvOptionsProto, - JsonOptions as JsonOptionsProto, ParquetColumnOptions as ParquetColumnOptionsProto, - ParquetColumnSpecificOptions, ParquetOptions as ParquetOptionsProto, - TableParquetOptions as TableParquetOptionsProto, -}; +use crate::protobuf::{CsvOptions as CsvOptionsProto, JsonOptions as JsonOptionsProto}; use super::LogicalExtensionCodec; @@ -72,6 +64,7 @@ impl CsvOptionsProto { newlines_in_values: options .newlines_in_values .map_or(vec![], |v| vec![v as u8]), + truncated_rows: options.truncated_rows.map_or(vec![], |v| vec![v as u8]), } } else { CsvOptionsProto::default() @@ -157,6 +150,11 @@ impl From<&CsvOptionsProto> for CsvOptions { } else { Some(proto.newlines_in_values[0] != 0) }, + truncated_rows: if proto.truncated_rows.is_empty() { + None + } else { + Some(proto.truncated_rows[0] != 0) + }, } } } @@ -205,10 +203,7 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { _ctx: &SessionContext, ) -> datafusion_common::Result> { let proto = CsvOptionsProto::decode(buf).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to decode CsvOptionsProto: {:?}", - e - )) + exec_datafusion_err!("Failed to decode CsvOptionsProto: {e:?}") })?; let options: CsvOptions = (&proto).into(); Ok(Arc::new(CsvFormatFactory { @@ -232,9 +227,9 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { options: Some(options), }); - proto.encode(buf).map_err(|e| { - DataFusionError::Execution(format!("Failed to encode CsvOptions: {:?}", e)) - })?; + proto + .encode(buf) + .map_err(|e| exec_datafusion_err!("Failed to encode CsvOptions: {e:?}"))?; Ok(()) } @@ -315,10 +310,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { _ctx: &SessionContext, ) -> datafusion_common::Result> { let proto = JsonOptionsProto::decode(buf).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to decode JsonOptionsProto: {:?}", - e - )) + exec_datafusion_err!("Failed to decode JsonOptionsProto: {e:?}") })?; let options: JsonOptions = (&proto).into(); Ok(Arc::new(JsonFormatFactory { @@ -336,33 +328,47 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { { json_factory.options.clone().unwrap_or_default() } else { - return Err(DataFusionError::Execution( - "Unsupported FileFormatFactory type".to_string(), - )); + return exec_err!("Unsupported FileFormatFactory type"); }; let proto = JsonOptionsProto::from_factory(&JsonFormatFactory { options: Some(options), }); - proto.encode(buf).map_err(|e| { - DataFusionError::Execution(format!("Failed to encode JsonOptions: {:?}", e)) - })?; + proto + .encode(buf) + .map_err(|e| exec_datafusion_err!("Failed to encode JsonOptions: {e:?}"))?; Ok(()) } } -impl TableParquetOptionsProto { - fn from_factory(factory: &ParquetFormatFactory) -> Self { - let global_options = if let Some(ref options) = factory.options { - options.clone() - } else { - return TableParquetOptionsProto::default(); - }; +#[cfg(feature = "parquet")] +mod parquet { + use super::*; + + use crate::protobuf::{ + parquet_column_options, parquet_options, + ParquetColumnOptions as ParquetColumnOptionsProto, ParquetColumnSpecificOptions, + ParquetOptions as ParquetOptionsProto, + TableParquetOptions as TableParquetOptionsProto, + }; + + use datafusion::{ + config::{ParquetColumnOptions, ParquetOptions, TableParquetOptions}, + datasource::file_format::parquet::ParquetFormatFactory, + }; + + impl TableParquetOptionsProto { + fn from_factory(factory: &ParquetFormatFactory) -> Self { + let global_options = if let Some(ref options) = factory.options { + options.clone() + } else { + return TableParquetOptionsProto::default(); + }; - let column_specific_options = global_options.column_specific_options; - #[allow(deprecated)] // max_statistics_size + let column_specific_options = global_options.column_specific_options; + #[allow(deprecated)] // max_statistics_size TableParquetOptionsProto { global: Some(ParquetOptionsProto { enable_page_index: global_options.global.enable_page_index, @@ -386,9 +392,6 @@ impl TableParquetOptionsProto { statistics_enabled_opt: global_options.global.statistics_enabled.map(|enabled| { parquet_options::StatisticsEnabledOpt::StatisticsEnabled(enabled) }), - max_statistics_size_opt: global_options.global.max_statistics_size.map(|size| { - parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(size as u64) - }), max_row_group_size: global_options.global.max_row_group_size as u64, created_by: global_options.global.created_by.clone(), column_index_truncate_length_opt: global_options.global.column_index_truncate_length.map(|length| { @@ -415,6 +418,12 @@ impl TableParquetOptionsProto { schema_force_view_types: global_options.global.schema_force_view_types, binary_as_string: global_options.global.binary_as_string, skip_arrow_metadata: global_options.global.skip_arrow_metadata, + coerce_int96_opt: global_options.global.coerce_int96.map(|compression| { + parquet_options::CoerceInt96Opt::CoerceInt96(compression) + }), + max_predicate_cache_size_opt: global_options.global.max_predicate_cache_size.map(|size| { + parquet_options::MaxPredicateCacheSizeOpt::MaxPredicateCacheSize(size as u64) + }), }), column_specific_options: column_specific_options.into_iter().map(|(column_name, options)| { ParquetColumnSpecificOptions { @@ -441,9 +450,6 @@ impl TableParquetOptionsProto { bloom_filter_ndv_opt: options.bloom_filter_ndv.map(|ndv| { parquet_column_options::BloomFilterNdvOpt::BloomFilterNdv(ndv) }), - max_statistics_size_opt: options.max_statistics_size.map(|size| { - parquet_column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(size as u32) - }), }) } }).collect(), @@ -454,12 +460,12 @@ impl TableParquetOptionsProto { }) .collect(), } + } } -} -impl From<&ParquetOptionsProto> for ParquetOptions { - fn from(proto: &ParquetOptionsProto) -> Self { - #[allow(deprecated)] // max_statistics_size + impl From<&ParquetOptionsProto> for ParquetOptions { + fn from(proto: &ParquetOptionsProto) -> Self { + #[allow(deprecated)] // max_statistics_size ParquetOptions { enable_page_index: proto.enable_page_index, pruning: proto.pruning, @@ -482,9 +488,6 @@ impl From<&ParquetOptionsProto> for ParquetOptions { statistics_enabled: proto.statistics_enabled_opt.as_ref().map(|opt| match opt { parquet_options::StatisticsEnabledOpt::StatisticsEnabled(statistics) => statistics.clone(), }), - max_statistics_size: proto.max_statistics_size_opt.as_ref().map(|opt| match opt { - parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(size) => *size as usize, - }), max_row_group_size: proto.max_row_group_size as usize, created_by: proto.created_by.clone(), column_index_truncate_length: proto.column_index_truncate_length_opt.as_ref().map(|opt| match opt { @@ -511,13 +514,19 @@ impl From<&ParquetOptionsProto> for ParquetOptions { schema_force_view_types: proto.schema_force_view_types, binary_as_string: proto.binary_as_string, skip_arrow_metadata: proto.skip_arrow_metadata, + coerce_int96: proto.coerce_int96_opt.as_ref().map(|opt| match opt { + parquet_options::CoerceInt96Opt::CoerceInt96(coerce_int96) => coerce_int96.clone(), + }), + max_predicate_cache_size: proto.max_predicate_cache_size_opt.as_ref().map(|opt| match opt { + parquet_options::MaxPredicateCacheSizeOpt::MaxPredicateCacheSize(size) => *size as usize, + }), + } } } -} -impl From for ParquetColumnOptions { - fn from(proto: ParquetColumnOptionsProto) -> Self { - #[allow(deprecated)] // max_statistics_size + impl From for ParquetColumnOptions { + fn from(proto: ParquetColumnOptionsProto) -> Self { + #[allow(deprecated)] // max_statistics_size ParquetColumnOptions { bloom_filter_enabled: proto.bloom_filter_enabled_opt.map( |parquet_column_options::BloomFilterEnabledOpt::BloomFilterEnabled(v)| v, @@ -540,131 +549,130 @@ impl From for ParquetColumnOptions { bloom_filter_ndv: proto .bloom_filter_ndv_opt .map(|parquet_column_options::BloomFilterNdvOpt::BloomFilterNdv(v)| v), - max_statistics_size: proto.max_statistics_size_opt.map( - |parquet_column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v)| { - v as usize - }, - ), + } } } -} -impl From<&TableParquetOptionsProto> for TableParquetOptions { - fn from(proto: &TableParquetOptionsProto) -> Self { - TableParquetOptions { - global: proto - .global - .as_ref() - .map(ParquetOptions::from) - .unwrap_or_default(), - column_specific_options: proto - .column_specific_options - .iter() - .map(|parquet_column_options| { - ( - parquet_column_options.column_name.clone(), - ParquetColumnOptions::from( - parquet_column_options.options.clone().unwrap_or_default(), - ), - ) - }) - .collect(), - key_value_metadata: proto - .key_value_metadata - .iter() - .map(|(k, v)| (k.clone(), Some(v.clone()))) - .collect(), + impl From<&TableParquetOptionsProto> for TableParquetOptions { + fn from(proto: &TableParquetOptionsProto) -> Self { + TableParquetOptions { + global: proto + .global + .as_ref() + .map(ParquetOptions::from) + .unwrap_or_default(), + column_specific_options: proto + .column_specific_options + .iter() + .map(|parquet_column_options| { + ( + parquet_column_options.column_name.clone(), + ParquetColumnOptions::from( + parquet_column_options + .options + .clone() + .unwrap_or_default(), + ), + ) + }) + .collect(), + key_value_metadata: proto + .key_value_metadata + .iter() + .map(|(k, v)| (k.clone(), Some(v.clone()))) + .collect(), + crypto: Default::default(), + } } } -} -#[derive(Debug)] -pub struct ParquetLogicalExtensionCodec; + #[derive(Debug)] + pub struct ParquetLogicalExtensionCodec; -// TODO! This is a placeholder for now and needs to be implemented for real. -impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { - fn try_decode( - &self, - _buf: &[u8], - _inputs: &[datafusion_expr::LogicalPlan], - _ctx: &SessionContext, - ) -> datafusion_common::Result { - not_impl_err!("Method not implemented") - } + // TODO! This is a placeholder for now and needs to be implemented for real. + impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } - fn try_encode( - &self, - _node: &datafusion_expr::Extension, - _buf: &mut Vec, - ) -> datafusion_common::Result<()> { - not_impl_err!("Method not implemented") - } + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } - fn try_decode_table_provider( - &self, - _buf: &[u8], - _table_ref: &TableReference, - _schema: arrow::datatypes::SchemaRef, - _ctx: &SessionContext, - ) -> datafusion_common::Result> { - not_impl_err!("Method not implemented") - } + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: arrow::datatypes::SchemaRef, + _ctx: &SessionContext, + ) -> datafusion_common::Result> + { + not_impl_err!("Method not implemented") + } - fn try_encode_table_provider( - &self, - _table_ref: &TableReference, - _node: Arc, - _buf: &mut Vec, - ) -> datafusion_common::Result<()> { - not_impl_err!("Method not implemented") - } + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } - fn try_decode_file_format( - &self, - buf: &[u8], - _ctx: &SessionContext, - ) -> datafusion_common::Result> { - let proto = TableParquetOptionsProto::decode(buf).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to decode TableParquetOptionsProto: {:?}", - e + fn try_decode_file_format( + &self, + buf: &[u8], + _ctx: &SessionContext, + ) -> datafusion_common::Result> { + let proto = TableParquetOptionsProto::decode(buf).map_err(|e| { + exec_datafusion_err!("Failed to decode TableParquetOptionsProto: {e:?}") + })?; + let options: TableParquetOptions = (&proto).into(); + Ok(Arc::new( + datafusion::datasource::file_format::parquet::ParquetFormatFactory { + options: Some(options), + }, )) - })?; - let options: TableParquetOptions = (&proto).into(); - Ok(Arc::new(ParquetFormatFactory { - options: Some(options), - })) - } + } - fn try_encode_file_format( - &self, - buf: &mut Vec, - node: Arc, - ) -> datafusion_common::Result<()> { - let options = if let Some(parquet_factory) = - node.as_any().downcast_ref::() - { - parquet_factory.options.clone().unwrap_or_default() - } else { - return Err(DataFusionError::Execution( - "Unsupported FileFormatFactory type".to_string(), - )); - }; + fn try_encode_file_format( + &self, + buf: &mut Vec, + node: Arc, + ) -> datafusion_common::Result<()> { + use datafusion::datasource::file_format::parquet::ParquetFormatFactory; + + let options = if let Some(parquet_factory) = + node.as_any().downcast_ref::() + { + parquet_factory.options.clone().unwrap_or_default() + } else { + return exec_err!("Unsupported FileFormatFactory type"); + }; - let proto = TableParquetOptionsProto::from_factory(&ParquetFormatFactory { - options: Some(options), - }); + let proto = TableParquetOptionsProto::from_factory(&ParquetFormatFactory { + options: Some(options), + }); - proto.encode(buf).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to encode TableParquetOptionsProto: {:?}", - e - )) - })?; + proto.encode(buf).map_err(|e| { + exec_datafusion_err!("Failed to encode TableParquetOptionsProto: {e:?}") + })?; - Ok(()) + Ok(()) + } } } +#[cfg(feature = "parquet")] +pub use parquet::ParquetLogicalExtensionCodec; #[derive(Debug)] pub struct ArrowLogicalExtensionCodec; diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index cac2f9db1645b..cbfa15183b5c1 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -19,8 +19,8 @@ use std::sync::Arc; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - exec_datafusion_err, internal_err, plan_datafusion_err, RecursionUnnestOption, - Result, ScalarValue, TableReference, UnnestOptions, + exec_datafusion_err, internal_err, plan_datafusion_err, NullEquality, + RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, }; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{Alias, Placeholder, Sort}; @@ -205,6 +205,7 @@ impl From for JoinType { protobuf::JoinType::Leftanti => JoinType::LeftAnti, protobuf::JoinType::Rightanti => JoinType::RightAnti, protobuf::JoinType::Leftmark => JoinType::LeftMark, + protobuf::JoinType::Rightmark => JoinType::RightMark, } } } @@ -218,6 +219,15 @@ impl From for JoinConstraint { } } +impl From for NullEquality { + fn from(t: protobuf::NullEquality) -> Self { + match t { + protobuf::NullEquality::NullEqualsNothing => NullEquality::NullEqualsNothing, + protobuf::NullEquality::NullEqualsNull => NullEquality::NullEqualsNull, + } + } +} + impl From for WriteOp { fn from(t: protobuf::dml_node::Type) -> Self { match t { @@ -268,7 +278,7 @@ pub fn parse_expr( ExprType::Column(column) => Ok(Expr::Column(column.into())), ExprType::Literal(literal) => { let scalar_value: ScalarValue = literal.try_into()?; - Ok(Expr::Literal(scalar_value)) + Ok(Expr::Literal(scalar_value, None)) } ExprType::WindowExpr(expr) => { let window_function = expr @@ -291,16 +301,19 @@ pub fn parse_expr( exec_datafusion_err!("missing window frame during deserialization") })?; - // TODO: support proto for null treatment + // TODO: support null treatment, distinct, and filter in proto. + // See https://github.com/apache/datafusion/issues/17417 match window_function { window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, - None => registry.udaf(udaf_name)?, + None => registry + .udaf(udaf_name) + .or_else(|_| codec.try_decode_udaf(udaf_name, &[]))?, }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, )) @@ -313,11 +326,13 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, - None => registry.udwf(udwf_name)?, + None => registry + .udwf(udwf_name) + .or_else(|_| codec.try_decode_udwf(udwf_name, &[]))?, }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, )) @@ -540,7 +555,9 @@ pub fn parse_expr( }) => { let scalar_fn = match fun_definition { Some(buf) => codec.try_decode_udf(fun_name, buf)?, - None => registry.udf(fun_name.as_str())?, + None => registry + .udf(fun_name.as_str()) + .or_else(|_| codec.try_decode_udf(fun_name, &[]))?, }; Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, @@ -550,7 +567,9 @@ pub fn parse_expr( ExprType::AggregateUdfExpr(pb) => { let agg_fn = match &pb.fun_definition { Some(buf) => codec.try_decode_udaf(&pb.fun_name, buf)?, - None => registry.udaf(&pb.fun_name)?, + None => registry + .udaf(&pb.fun_name) + .or_else(|_| codec.try_decode_udaf(&pb.fun_name, &[]))?, }; Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( @@ -558,10 +577,7 @@ pub fn parse_expr( parse_exprs(&pb.args, registry, codec)?, pb.distinct, parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), - match pb.order_by.len() { - 0 => None, - _ => Some(parse_sorts(&pb.order_by, registry, codec)?), - }, + parse_sorts(&pb.order_by, registry, codec)?, None, ))) } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index c65569ef1cfbe..fd9e07914b076 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -33,8 +33,9 @@ use crate::{ }; use crate::protobuf::{proto_error, ToProtoError}; -use arrow::datatypes::{DataType, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Schema, SchemaBuilder, SchemaRef}; use datafusion::datasource::cte_worktable::CteWorkTable; +use datafusion::datasource::file_format::arrow::ArrowFormat; #[cfg(feature = "avro")] use datafusion::datasource::file_format::avro::AvroFormat; #[cfg(feature = "parquet")] @@ -56,8 +57,8 @@ use datafusion::{ }; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ - context, internal_datafusion_err, internal_err, not_impl_err, plan_err, - DataFusionError, Result, TableReference, ToDFSchema, + context, internal_datafusion_err, internal_err, not_impl_err, plan_err, Result, + TableReference, ToDFSchema, }; use datafusion_expr::{ dml, @@ -71,8 +72,7 @@ use datafusion_expr::{ Statement, WindowUDF, }; use datafusion_expr::{ - AggregateUDF, ColumnUnnestList, DmlStatement, FetchType, RecursiveQuery, SkipType, - TableSource, Unnest, + AggregateUDF, DmlStatement, FetchType, RecursiveQuery, SkipType, TableSource, Unnest, }; use self::to_proto::{serialize_expr, serialize_exprs}; @@ -229,9 +229,9 @@ fn from_table_reference( error_context: &str, ) -> Result { let table_ref = table_ref.ok_or_else(|| { - DataFusionError::Internal(format!( + internal_datafusion_err!( "Protobuf deserialization error, {error_context} was missing required field name." - )) + ) })?; Ok(table_ref.clone().try_into()?) @@ -281,9 +281,8 @@ impl AsLogicalPlan for LogicalPlanNode { where Self: Sized, { - LogicalPlanNode::decode(buf).map_err(|e| { - DataFusionError::Internal(format!("failed to decode logical plan: {e:?}")) - }) + LogicalPlanNode::decode(buf) + .map_err(|e| internal_datafusion_err!("failed to decode logical plan: {e:?}")) } fn try_encode(&self, buf: &mut B) -> Result<()> @@ -291,9 +290,8 @@ impl AsLogicalPlan for LogicalPlanNode { B: BufMut, Self: Sized, { - self.encode(buf).map_err(|e| { - DataFusionError::Internal(format!("failed to encode logical plan: {e:?}")) - }) + self.encode(buf) + .map_err(|e| internal_datafusion_err!("failed to encode logical plan: {e:?}")) } fn try_into_logical_plan( @@ -355,10 +353,7 @@ impl AsLogicalPlan for LogicalPlanNode { .as_ref() .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .transpose()? - .ok_or_else(|| { - DataFusionError::Internal("expression required".to_string()) - })?; - // .try_into()?; + .ok_or_else(|| proto_error("expression required"))?; LogicalPlanBuilder::from(input).filter(expr)?.build() } LogicalPlanType::Window(window) => { @@ -382,16 +377,6 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::ListingScan(scan) => { let schema: Schema = convert_required!(scan.schema)?; - let mut projection = None; - if let Some(columns) = &scan.projection { - let column_indices = columns - .columns - .iter() - .map(|name| schema.index_of(name)) - .collect::, _>>()?; - projection = Some(column_indices); - } - let filters = from_proto::parse_exprs(&scan.filters, ctx, extension_codec)?; @@ -443,13 +428,16 @@ impl AsLogicalPlan for LogicalPlanNode { } #[cfg_attr(not(feature = "avro"), allow(unused_variables))] FileFormatType::Avro(..) => { - #[cfg(feature = "avro")] + #[cfg(feature = "avro")] { Arc::new(AvroFormat) } #[cfg(not(feature = "avro"))] panic!("Unable to process avro file since `avro` feature is not enabled"); } + FileFormatType::Arrow(..) => { + Arc::new(ArrowFormat) + } }; let table_paths = &scan @@ -458,23 +446,25 @@ impl AsLogicalPlan for LogicalPlanNode { .map(ListingTableUrl::parse) .collect::, _>>()?; + let partition_columns = scan + .table_partition_cols + .iter() + .map(|col| { + let Some(arrow_type) = col.arrow_type.as_ref() else { + return Err(proto_error( + "Missing Arrow type in partition columns", + )); + }; + let arrow_type = DataType::try_from(arrow_type).map_err(|e| { + proto_error(format!("Received an unknown ArrowType: {e}")) + })?; + Ok((col.name.clone(), arrow_type)) + }) + .collect::>>()?; + let options = ListingOptions::new(file_format) .with_file_extension(&scan.file_extension) - .with_table_partition_cols( - scan.table_partition_cols - .iter() - .map(|col| { - ( - col.clone(), - schema - .field_with_name(col) - .unwrap() - .data_type() - .clone(), - ) - }) - .collect(), - ) + .with_table_partition_cols(partition_columns) .with_collect_stat(scan.collect_stat) .with_target_partitions(scan.target_partitions as usize) .with_file_sort_order(all_sort_orders); @@ -494,6 +484,16 @@ impl AsLogicalPlan for LogicalPlanNode { let table_name = from_table_reference(scan.table_name.as_ref(), "ListingTableScan")?; + let mut projection = None; + if let Some(columns) = &scan.projection { + let column_indices = columns + .columns + .iter() + .map(|name| provider.schema().index_of(name)) + .collect::, _>>()?; + projection = Some(column_indices); + } + LogicalPlanBuilder::scan_with_filters( table_name, provider_as_source(Arc::new(provider)), @@ -579,15 +579,15 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::CreateExternalTable(create_extern_table) => { let pb_schema = (create_extern_table.schema.clone()).ok_or_else(|| { - DataFusionError::Internal(String::from( + internal_datafusion_err!( "Protobuf deserialization error, CreateExternalTableNode was missing required field schema." - )) + ) })?; let constraints = (create_extern_table.constraints.clone()).ok_or_else(|| { - DataFusionError::Internal(String::from( - "Protobuf deserialization error, CreateExternalTableNode was missing required table constraints.", - )) + internal_datafusion_err!( + "Protobuf deserialization error, CreateExternalTableNode was missing required table constraints." + ) })?; let definition = if !create_extern_table.definition.is_empty() { Some(create_extern_table.definition.clone()) @@ -630,6 +630,7 @@ impl AsLogicalPlan for LogicalPlanNode { .clone(), order_exprs, if_not_exists: create_extern_table.if_not_exists, + or_replace: create_extern_table.or_replace, temporary: create_extern_table.temporary, definition, unbounded: create_extern_table.unbounded, @@ -641,9 +642,9 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::CreateView(create_view) => { let plan = create_view - .input.clone().ok_or_else(|| DataFusionError::Internal(String::from( - "Protobuf deserialization error, CreateViewNode has invalid LogicalPlan input.", - )))? + .input.clone().ok_or_else(|| internal_datafusion_err!( + "Protobuf deserialization error, CreateViewNode has invalid LogicalPlan input." + ))? .try_into_logical_plan(ctx, extension_codec)?; let definition = if !create_view.definition.is_empty() { Some(create_view.definition.clone()) @@ -661,9 +662,9 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::CreateCatalogSchema(create_catalog_schema) => { let pb_schema = (create_catalog_schema.schema.clone()).ok_or_else(|| { - DataFusionError::Internal(String::from( - "Protobuf deserialization error, CreateCatalogSchemaNode was missing required field schema.", - )) + internal_datafusion_err!( + "Protobuf deserialization error, CreateCatalogSchemaNode was missing required field schema." + ) })?; Ok(LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema( @@ -676,9 +677,9 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::CreateCatalog(create_catalog) => { let pb_schema = (create_catalog.schema.clone()).ok_or_else(|| { - DataFusionError::Internal(String::from( - "Protobuf deserialization error, CreateCatalogNode was missing required field schema.", - )) + internal_datafusion_err!( + "Protobuf deserialization error, CreateCatalogNode was missing required field schema." + ) })?; Ok(LogicalPlan::Ddl(DdlStatement::CreateCatalog( @@ -787,9 +788,9 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlanType::Union(union) => { if union.inputs.len() < 2 { - return Err( DataFusionError::Internal(String::from( - "Protobuf deserialization error, Union was require at least two input.", - ))); + return internal_err!( + "Protobuf deserialization error, Union was require at least two input." + ); } let (first, rest) = union.inputs.split_first().unwrap(); let mut builder = LogicalPlanBuilder::from( @@ -906,67 +907,40 @@ impl AsLogicalPlan for LogicalPlanNode { extension_codec.try_decode_file_format(©.file_type, ctx)?, ); - Ok(LogicalPlan::Copy(dml::CopyTo { - input: Arc::new(input), - output_url: copy.output_url.clone(), - partition_by: copy.partition_by.clone(), + Ok(LogicalPlan::Copy(dml::CopyTo::new( + Arc::new(input), + copy.output_url.clone(), + copy.partition_by.clone(), file_type, - options: Default::default(), - })) + Default::default(), + ))) } LogicalPlanType::Unnest(unnest) => { let input: LogicalPlan = into_logical_plan!(unnest.input, ctx, extension_codec)?; - Ok(LogicalPlan::Unnest(Unnest { - input: Arc::new(input), - exec_columns: unnest.exec_columns.iter().map(|c| c.into()).collect(), - list_type_columns: unnest - .list_type_columns - .iter() - .map(|c| { - let recursion_item = c.recursion.as_ref().unwrap(); - ( - c.input_index as _, - ColumnUnnestList { - output_column: recursion_item - .output_column - .as_ref() - .unwrap() - .into(), - depth: recursion_item.depth as _, - }, - ) - }) - .collect(), - struct_type_columns: unnest - .struct_type_columns - .iter() - .map(|c| *c as usize) - .collect(), - dependency_indices: unnest - .dependency_indices - .iter() - .map(|c| *c as usize) - .collect(), - schema: Arc::new(convert_required!(unnest.schema)?), - options: into_required!(unnest.options)?, - })) + + LogicalPlanBuilder::from(input) + .unnest_columns_with_options( + unnest.exec_columns.iter().map(|c| c.into()).collect(), + into_required!(unnest.options)?, + )? + .build() } LogicalPlanType::RecursiveQuery(recursive_query_node) => { let static_term = recursive_query_node .static_term .as_ref() - .ok_or_else(|| DataFusionError::Internal(String::from( - "Protobuf deserialization error, RecursiveQueryNode was missing required field static_term.", - )))? + .ok_or_else(|| internal_datafusion_err!( + "Protobuf deserialization error, RecursiveQueryNode was missing required field static_term." + ))? .try_into_logical_plan(ctx, extension_codec)?; let recursive_term = recursive_query_node .recursive_term .as_ref() - .ok_or_else(|| DataFusionError::Internal(String::from( - "Protobuf deserialization error, RecursiveQueryNode was missing required field recursive_term.", - )))? + .ok_or_else(|| internal_datafusion_err!( + "Protobuf deserialization error, RecursiveQueryNode was missing required field recursive_term." + ))? .try_into_logical_plan(ctx, extension_codec)?; Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { @@ -1046,7 +1020,6 @@ impl AsLogicalPlan for LogicalPlanNode { }) } }; - let schema: protobuf::Schema = schema.as_ref().try_into()?; let filters: Vec = serialize_exprs(filters, extension_codec)?; @@ -1087,18 +1060,38 @@ impl AsLogicalPlan for LogicalPlanNode { Some(FileFormatType::Avro(protobuf::AvroFormat {})) } + if any.is::() { + maybe_some_type = + Some(FileFormatType::Arrow(protobuf::ArrowFormat {})) + } + if let Some(file_format_type) = maybe_some_type { file_format_type } else { return Err(proto_error(format!( - "Error converting file format, {:?} is invalid as a datafusion format.", - listing_table.options().format - ))); + "Error deserializing unknown file format: {:?}", + listing_table.options().format + ))); } }; let options = listing_table.options(); + let mut builder = SchemaBuilder::from(schema.as_ref()); + for (idx, field) in schema.fields().iter().enumerate().rev() { + if options + .table_partition_cols + .iter() + .any(|(name, _)| name == field.name()) + { + builder.remove(idx); + } + } + + let schema = builder.finish(); + + let schema: protobuf::Schema = (&schema).try_into()?; + let mut exprs_vec: Vec = vec![]; for order in &options.file_sort_order { let expr_vec = SortExprNodeCollection { @@ -1107,6 +1100,23 @@ impl AsLogicalPlan for LogicalPlanNode { exprs_vec.push(expr_vec); } + let partition_columns = options + .table_partition_cols + .iter() + .map(|(name, arrow_type)| { + let arrow_type = protobuf::ArrowType::try_from(arrow_type) + .map_err(|e| { + proto_error(format!( + "Received an unknown ArrowType: {e}" + )) + })?; + Ok(protobuf::PartitionColumn { + name: name.clone(), + arrow_type: Some(arrow_type), + }) + }) + .collect::>>()?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ListingScan( protobuf::ListingTableScanNode { @@ -1114,11 +1124,7 @@ impl AsLogicalPlan for LogicalPlanNode { table_name: Some(table_name.clone().into()), collect_stat: options.collect_stat, file_extension: options.file_extension.clone(), - table_partition_cols: options - .table_partition_cols - .iter() - .map(|x| x.0.clone()) - .collect::>(), + table_partition_cols: partition_columns, paths: listing_table .table_paths() .iter() @@ -1133,6 +1139,7 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } else if let Some(view_table) = source.downcast_ref::() { + let schema: protobuf::Schema = schema.as_ref().try_into()?; Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ViewScan(Box::new( protobuf::ViewTableScanNode { @@ -1167,6 +1174,7 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } else { + let schema: protobuf::Schema = schema.as_ref().try_into()?; let mut bytes = vec![]; extension_codec .try_encode_table_provider(table_name, provider, &mut bytes) @@ -1299,7 +1307,7 @@ impl AsLogicalPlan for LogicalPlanNode { filter, join_type, join_constraint, - null_equals_null, + null_equality, .. }) => { let left: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( @@ -1324,6 +1332,8 @@ impl AsLogicalPlan for LogicalPlanNode { let join_type: protobuf::JoinType = join_type.to_owned().into(); let join_constraint: protobuf::JoinConstraint = join_constraint.to_owned().into(); + let null_equality: protobuf::NullEquality = + null_equality.to_owned().into(); let filter = filter .as_ref() .map(|e| serialize_expr(e, extension_codec)) @@ -1337,7 +1347,7 @@ impl AsLogicalPlan for LogicalPlanNode { join_constraint: join_constraint.into(), left_join_key, right_join_key, - null_equals_null: *null_equals_null, + null_equality: null_equality.into(), filter, }, ))), @@ -1414,7 +1424,7 @@ impl AsLogicalPlan for LogicalPlanNode { )?; // Assumed common usize field was batch size - // Used u64 to avoid any nastyness involving large values, most data clusters are probably uniformly 64 bits any ways + // Used u64 to avoid any nastiness involving large values, most data clusters are probably uniformly 64 bits any ways use protobuf::repartition_node::PartitionMethod; let pb_partition_method = match partitioning_scheme { @@ -1458,6 +1468,7 @@ impl AsLogicalPlan for LogicalPlanNode { schema: df_schema, table_partition_cols, if_not_exists, + or_replace, definition, order_exprs, unbounded, @@ -1491,6 +1502,7 @@ impl AsLogicalPlan for LogicalPlanNode { schema: Some(df_schema.try_into()?), table_partition_cols: table_partition_cols.clone(), if_not_exists: *if_not_exists, + or_replace: *or_replace, temporary: *temporary, order_exprs: converted_order_exprs, definition: definition.clone().unwrap_or_default(), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 841c31fa035f4..1be3300008c79 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -21,7 +21,7 @@ use std::collections::HashMap; -use datafusion_common::{TableReference, UnnestOptions}; +use datafusion_common::{NullEquality, TableReference, UnnestOptions}; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, @@ -211,13 +211,16 @@ pub fn serialize_expr( .map(|r| vec![r.into()]) .unwrap_or(vec![]), alias: name.to_owned(), - metadata: metadata.to_owned().unwrap_or(HashMap::new()), + metadata: metadata + .as_ref() + .map(|m| m.to_hashmap()) + .unwrap_or(HashMap::new()), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::Alias(alias)), } } - Expr::Literal(value) => { + Expr::Literal(value, _) => { let pb_value: protobuf::ScalarValue = value.try_into()?; protobuf::LogicalExprNode { expr_type: Some(ExprType::Literal(pb_value)), @@ -302,40 +305,38 @@ pub fn serialize_expr( expr_type: Some(ExprType::SimilarTo(pb)), } } - Expr::WindowFunction(expr::WindowFunction { - ref fun, - params: - expr::WindowFunctionParams { - ref args, - ref partition_by, - ref order_by, - ref window_frame, - // TODO: support null treatment in proto - null_treatment: _, - }, - }) => { - let (window_function, fun_definition) = match fun { + Expr::WindowFunction(window_fun) => { + let expr::WindowFunction { + ref fun, + params: + expr::WindowFunctionParams { + ref args, + ref partition_by, + ref order_by, + ref window_frame, + // TODO: support null treatment, distinct, and filter in proto. + // See https://github.com/apache/datafusion/issues/17417 + null_treatment: _, + distinct: _, + filter: _, + }, + } = window_fun.as_ref(); + let mut buf = Vec::new(); + let window_function = match fun { WindowFunctionDefinition::AggregateUDF(aggr_udf) => { - let mut buf = Vec::new(); let _ = codec.try_encode_udaf(aggr_udf, &mut buf); - ( - protobuf::window_expr_node::WindowFunction::Udaf( - aggr_udf.name().to_string(), - ), - (!buf.is_empty()).then_some(buf), + protobuf::window_expr_node::WindowFunction::Udaf( + aggr_udf.name().to_string(), ) } WindowFunctionDefinition::WindowUDF(window_udf) => { - let mut buf = Vec::new(); let _ = codec.try_encode_udwf(window_udf, &mut buf); - ( - protobuf::window_expr_node::WindowFunction::Udwf( - window_udf.name().to_string(), - ), - (!buf.is_empty()).then_some(buf), + protobuf::window_expr_node::WindowFunction::Udwf( + window_udf.name().to_string(), ) } }; + let fun_definition = (!buf.is_empty()).then_some(buf); let partition_by = serialize_exprs(partition_by, codec)?; let order_by = serialize_sorts(order_by, codec)?; @@ -376,10 +377,7 @@ pub fn serialize_expr( Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), None => None, }, - order_by: match order_by { - Some(e) => serialize_sorts(e, codec)?, - None => vec![], - }, + order_by: serialize_sorts(order_by, codec)?, fun_definition: (!buf.is_empty()).then_some(buf), }, ))), @@ -687,6 +685,7 @@ impl From for protobuf::JoinType { JoinType::LeftAnti => protobuf::JoinType::Leftanti, JoinType::RightAnti => protobuf::JoinType::Rightanti, JoinType::LeftMark => protobuf::JoinType::Leftmark, + JoinType::RightMark => protobuf::JoinType::Rightmark, } } } @@ -700,6 +699,15 @@ impl From for protobuf::JoinConstraint { } } +impl From for protobuf::NullEquality { + fn from(t: NullEquality) -> Self { + match t { + NullEquality::NullEqualsNothing => protobuf::NullEquality::NullEqualsNothing, + NullEquality::NullEqualsNull => protobuf::NullEquality::NullEqualsNull, + } + } +} + impl From<&WriteOp> for protobuf::dml_node::Type { fn from(t: &WriteOp) -> Self { match t { diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index c949e3c9f8cb1..e2ee1be7d7321 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -19,7 +19,10 @@ use std::sync::Arc; +use arrow::array::RecordBatch; use arrow::compute::SortOptions; +use arrow::datatypes::Field; +use arrow::ipc::reader::StreamReader; use chrono::{TimeZone, Utc}; use datafusion_expr::dml::InsertOp; use object_store::path::Path; @@ -35,7 +38,7 @@ use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{ FileGroup, FileScanConfig, FileScanConfigBuilder, FileSinkConfig, FileSource, }; -use datafusion::execution::FunctionRegistry; +use datafusion::execution::{FunctionRegistry, TaskContext}; use datafusion::logical_expr::WindowFunctionDefinition; use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ @@ -44,7 +47,7 @@ use datafusion::physical_plan::expressions::{ }; use datafusion::physical_plan::windows::{create_window_expr, schema_add_window_field}; use datafusion::physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; -use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_common::{internal_datafusion_err, not_impl_err, DataFusionError, Result}; use datafusion_proto_common::common::proto_error; use crate::convert_required; @@ -67,16 +70,16 @@ impl From<&protobuf::PhysicalColumn> for Column { /// * `proto` - Input proto with physical sort expression node /// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types -/// when performing type coercion. +/// when performing type coercion. /// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_sort_expr( proto: &protobuf::PhysicalSortExprNode, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result { if let Some(expr) = &proto.expr { - let expr = parse_physical_expr(expr.as_ref(), registry, input_schema, codec)?; + let expr = parse_physical_expr(expr.as_ref(), ctx, input_schema, codec)?; let options = SortOptions { descending: !proto.asc, nulls_first: proto.nulls_first, @@ -94,20 +97,18 @@ pub fn parse_physical_sort_expr( /// * `proto` - Input proto with vector of physical sort expression node /// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types -/// when performing type coercion. +/// when performing type coercion. /// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_sort_exprs( proto: &[protobuf::PhysicalSortExprNode], - registry: &dyn FunctionRegistry, + ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, -) -> Result { +) -> Result> { proto .iter() - .map(|sort_expr| { - parse_physical_sort_expr(sort_expr, registry, input_schema, codec) - }) - .collect::>() + .map(|sort_expr| parse_physical_sort_expr(sort_expr, ctx, input_schema, codec)) + .collect() } /// Parses a physical window expr from a protobuf. @@ -118,32 +119,28 @@ pub fn parse_physical_sort_exprs( /// * `name` - Name of the window expression. /// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types -/// when performing type coercion. +/// when performing type coercion. /// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_window_expr( proto: &protobuf::PhysicalWindowExprNode, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let window_node_expr = - parse_physical_exprs(&proto.args, registry, input_schema, codec)?; + let window_node_expr = parse_physical_exprs(&proto.args, ctx, input_schema, codec)?; let partition_by = - parse_physical_exprs(&proto.partition_by, registry, input_schema, codec)?; + parse_physical_exprs(&proto.partition_by, ctx, input_schema, codec)?; - let order_by = - parse_physical_sort_exprs(&proto.order_by, registry, input_schema, codec)?; + let order_by = parse_physical_sort_exprs(&proto.order_by, ctx, input_schema, codec)?; let window_frame = proto .window_frame .as_ref() .map(|wf| wf.clone().try_into()) .transpose() - .map_err(|e| DataFusionError::Internal(format!("{e}")))? + .map_err(|e| internal_datafusion_err!("{e}"))? .ok_or_else(|| { - DataFusionError::Internal( - "Missing required field 'window_frame' in protobuf".to_string(), - ) + internal_datafusion_err!("Missing required field 'window_frame' in protobuf") })?; let fun = if let Some(window_func) = proto.window_function.as_ref() { @@ -151,13 +148,13 @@ pub fn parse_physical_window_expr( protobuf::physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(udaf_name) => { WindowFunctionDefinition::AggregateUDF(match &proto.fun_definition { Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, - None => registry.udaf(udaf_name)? + None => ctx.udaf(udaf_name).or_else(|_| codec.try_decode_udaf(udaf_name, &[]))?, }) } protobuf::physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(udwf_name) => { WindowFunctionDefinition::WindowUDF(match &proto.fun_definition { Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, - None => registry.udwf(udwf_name)? + None => ctx.udwf(udwf_name).or_else(|_| codec.try_decode_udwf(udwf_name, &[]))? }) } } @@ -174,16 +171,18 @@ pub fn parse_physical_window_expr( name, &window_node_expr, &partition_by, - order_by.as_ref(), + &order_by, Arc::new(window_frame), - &extended_schema, - false, + extended_schema, + proto.ignore_nulls, + proto.distinct, + None, ) } pub fn parse_physical_exprs<'a, I>( protos: I, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result>> @@ -192,7 +191,7 @@ where { protos .into_iter() - .map(|p| parse_physical_expr(p, registry, input_schema, codec)) + .map(|p| parse_physical_expr(p, ctx, input_schema, codec)) .collect::>>() } @@ -203,11 +202,11 @@ where /// * `proto` - Input proto with physical expression node /// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types -/// when performing type coercion. +/// when performing type coercion. /// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_expr( proto: &protobuf::PhysicalExprNode, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result> { @@ -226,7 +225,7 @@ pub fn parse_physical_expr( ExprType::BinaryExpr(binary_expr) => Arc::new(BinaryExpr::new( parse_required_physical_expr( binary_expr.l.as_deref(), - registry, + ctx, "left", input_schema, codec, @@ -234,7 +233,7 @@ pub fn parse_physical_expr( logical_plan::from_proto::from_proto_binary_op(&binary_expr.op)?, parse_required_physical_expr( binary_expr.r.as_deref(), - registry, + ctx, "right", input_schema, codec, @@ -256,7 +255,7 @@ pub fn parse_physical_expr( ExprType::IsNullExpr(e) => { Arc::new(IsNullExpr::new(parse_required_physical_expr( e.expr.as_deref(), - registry, + ctx, "expr", input_schema, codec, @@ -265,7 +264,7 @@ pub fn parse_physical_expr( ExprType::IsNotNullExpr(e) => { Arc::new(IsNotNullExpr::new(parse_required_physical_expr( e.expr.as_deref(), - registry, + ctx, "expr", input_schema, codec, @@ -273,7 +272,7 @@ pub fn parse_physical_expr( } ExprType::NotExpr(e) => Arc::new(NotExpr::new(parse_required_physical_expr( e.expr.as_deref(), - registry, + ctx, "expr", input_schema, codec, @@ -281,7 +280,7 @@ pub fn parse_physical_expr( ExprType::Negative(e) => { Arc::new(NegativeExpr::new(parse_required_physical_expr( e.expr.as_deref(), - registry, + ctx, "expr", input_schema, codec, @@ -290,19 +289,19 @@ pub fn parse_physical_expr( ExprType::InList(e) => in_list( parse_required_physical_expr( e.expr.as_deref(), - registry, + ctx, "expr", input_schema, codec, )?, - parse_physical_exprs(&e.list, registry, input_schema, codec)?, + parse_physical_exprs(&e.list, ctx, input_schema, codec)?, &e.negated, input_schema, )?, ExprType::Case(e) => Arc::new(CaseExpr::try_new( e.expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema, codec)) + .map(|e| parse_physical_expr(e.as_ref(), ctx, input_schema, codec)) .transpose()?, e.when_then_expr .iter() @@ -310,14 +309,14 @@ pub fn parse_physical_expr( Ok(( parse_required_physical_expr( e.when_expr.as_ref(), - registry, + ctx, "when_expr", input_schema, codec, )?, parse_required_physical_expr( e.then_expr.as_ref(), - registry, + ctx, "then_expr", input_schema, codec, @@ -327,13 +326,13 @@ pub fn parse_physical_expr( .collect::>>()?, e.else_expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema, codec)) + .map(|e| parse_physical_expr(e.as_ref(), ctx, input_schema, codec)) .transpose()?, )?), ExprType::Cast(e) => Arc::new(CastExpr::new( parse_required_physical_expr( e.expr.as_deref(), - registry, + ctx, "expr", input_schema, codec, @@ -344,7 +343,7 @@ pub fn parse_physical_expr( ExprType::TryCast(e) => Arc::new(TryCastExpr::new( parse_required_physical_expr( e.expr.as_deref(), - registry, + ctx, "expr", input_schema, codec, @@ -354,18 +353,28 @@ pub fn parse_physical_expr( ExprType::ScalarUdf(e) => { let udf = match &e.fun_definition { Some(buf) => codec.try_decode_udf(&e.name, buf)?, - None => registry.udf(e.name.as_str())?, + None => ctx + .udf(e.name.as_str()) + .or_else(|_| codec.try_decode_udf(&e.name, &[]))?, }; let scalar_fun_def = Arc::clone(&udf); - let args = parse_physical_exprs(&e.args, registry, input_schema, codec)?; + let args = parse_physical_exprs(&e.args, ctx, input_schema, codec)?; + + let config_options = Arc::clone(ctx.session_config().options()); Arc::new( ScalarFunctionExpr::new( e.name.as_str(), scalar_fun_def, args, - convert_required!(e.return_type)?, + Field::new( + &e.return_field_name, + convert_required!(e.return_type)?, + true, + ) + .into(), + config_options, ) .with_nullable(e.nullable), ) @@ -375,14 +384,14 @@ pub fn parse_physical_expr( like_expr.case_insensitive, parse_required_physical_expr( like_expr.expr.as_deref(), - registry, + ctx, "expr", input_schema, codec, )?, parse_required_physical_expr( like_expr.pattern.as_deref(), - registry, + ctx, "pattern", input_schema, codec, @@ -392,7 +401,7 @@ pub fn parse_physical_expr( let inputs: Vec> = extension .inputs .iter() - .map(|e| parse_physical_expr(e, registry, input_schema, codec)) + .map(|e| parse_physical_expr(e, ctx, input_schema, codec)) .collect::>()?; (codec.try_decode_expr(extension.expr.as_slice(), &inputs)?) as _ } @@ -403,32 +412,26 @@ pub fn parse_physical_expr( fn parse_required_physical_expr( expr: Option<&protobuf::PhysicalExprNode>, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, field: &str, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result> { - expr.map(|e| parse_physical_expr(e, registry, input_schema, codec)) + expr.map(|e| parse_physical_expr(e, ctx, input_schema, codec)) .transpose()? - .ok_or_else(|| { - DataFusionError::Internal(format!("Missing required field {field:?}")) - }) + .ok_or_else(|| internal_datafusion_err!("Missing required field {field:?}")) } pub fn parse_protobuf_hash_partitioning( partitioning: Option<&protobuf::PhysicalHashRepartition>, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result> { match partitioning { Some(hash_part) => { - let expr = parse_physical_exprs( - &hash_part.hash_expr, - registry, - input_schema, - codec, - )?; + let expr = + parse_physical_exprs(&hash_part.hash_expr, ctx, input_schema, codec)?; Ok(Some(Partitioning::Hash( expr, @@ -441,7 +444,7 @@ pub fn parse_protobuf_hash_partitioning( pub fn parse_protobuf_partitioning( partitioning: Option<&protobuf::Partitioning>, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result> { @@ -455,7 +458,7 @@ pub fn parse_protobuf_partitioning( Some(protobuf::partitioning::PartitionMethod::Hash(hash_repartition)) => { parse_protobuf_hash_partitioning( Some(hash_repartition), - registry, + ctx, input_schema, codec, ) @@ -479,7 +482,7 @@ pub fn parse_protobuf_file_scan_schema( pub fn parse_protobuf_file_scan_config( proto: &protobuf::FileScanExecConf, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn PhysicalExtensionCodec, file_source: Arc, ) -> Result { @@ -514,24 +517,27 @@ pub fn parse_protobuf_file_scan_config( // Remove partition columns from the schema after recreating table_partition_cols // because the partition columns are not in the file. They are present to allow // the partition column types to be reconstructed after serde. - let file_schema = Arc::new(Schema::new( - schema - .fields() - .iter() - .filter(|field| !table_partition_cols.contains(field)) - .cloned() - .collect::>(), - )); + let file_schema = Arc::new( + Schema::new( + schema + .fields() + .iter() + .filter(|field| !table_partition_cols.contains(field)) + .cloned() + .collect::>(), + ) + .with_metadata(schema.metadata.clone()), + ); let mut output_ordering = vec![]; for node_collection in &proto.output_ordering { - let sort_expr = parse_physical_sort_exprs( + let sort_exprs = parse_physical_sort_exprs( &node_collection.physical_sort_expr_nodes, - registry, + ctx, &schema, codec, )?; - output_ordering.push(sort_expr); + output_ordering.extend(LexOrdering::new(sort_exprs)); } let config = FileScanConfigBuilder::new(object_store_url, file_schema, file_source) @@ -547,6 +553,18 @@ pub fn parse_protobuf_file_scan_config( Ok(config) } +pub fn parse_record_batches(buf: &[u8]) -> Result> { + if buf.is_empty() { + return Ok(vec![]); + } + let reader = StreamReader::try_new(buf, None)?; + let mut batches = Vec::new(); + for batch in reader { + batches.push(batch?); + } + Ok(batches) +} + impl TryFrom<&protobuf::PartitionedFile> for PartitionedFile { type Error = DataFusionError; @@ -555,7 +573,7 @@ impl TryFrom<&protobuf::PartitionedFile> for PartitionedFile { object_meta: ObjectMeta { location: Path::from(val.path.as_str()), last_modified: Utc.timestamp_nanos(val.last_modified_ns as i64), - size: val.size as usize, + size: val.size, e_tag: None, version: None, }, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 24cc0d5b3b028..d76bcc89b3db2 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -24,22 +24,25 @@ use crate::common::{byte_to_string, str_to_byte}; use crate::physical_plan::from_proto::{ parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs, parse_physical_window_expr, parse_protobuf_file_scan_config, - parse_protobuf_file_scan_schema, + parse_protobuf_file_scan_schema, parse_record_batches, }; use crate::physical_plan::to_proto::{ serialize_file_scan_config, serialize_maybe_filter, serialize_physical_aggr_expr, - serialize_physical_window_expr, + serialize_physical_sort_exprs, serialize_physical_window_expr, + serialize_record_batches, }; use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; use crate::protobuf::physical_expr_node::ExprType; use crate::protobuf::physical_plan_node::PhysicalPlanType; use crate::protobuf::{ - self, proto_error, window_agg_exec_node, ListUnnest as ProtoListUnnest, + self, proto_error, window_agg_exec_node, ListUnnest as ProtoListUnnest, SortExprNode, + SortMergeJoinExecNode, }; use crate::{convert_required, into_required}; use datafusion::arrow::compute::SortOptions; -use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::datatypes::{IntervalMonthDayNanoType, Schema, SchemaRef}; +use datafusion::catalog::memory::MemorySourceConfig; use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::file_format::json::JsonSink; @@ -50,12 +53,14 @@ use datafusion::datasource::physical_plan::AvroSource; #[cfg(feature = "parquet")] use datafusion::datasource::physical_plan::ParquetSource; use datafusion::datasource::physical_plan::{ - CsvSource, FileScanConfig, FileScanConfigBuilder, JsonSource, + CsvSource, FileScanConfig, FileScanConfigBuilder, FileSource, JsonSource, }; use datafusion::datasource::sink::DataSinkExec; -use datafusion::datasource::source::DataSourceExec; -use datafusion::execution::runtime_env::RuntimeEnv; -use datafusion::execution::FunctionRegistry; +use datafusion::datasource::source::{DataSource, DataSourceExec}; +use datafusion::execution::{FunctionRegistry, TaskContext}; +use datafusion::functions_table::generate_series::{ + Empty, GenSeriesArgs, GenerateSeriesTable, GenericSeriesState, TimestampValue, +}; use datafusion::physical_expr::aggregate::AggregateExprBuilder; use datafusion::physical_expr::aggregate::AggregateFunctionExpr; use datafusion::physical_expr::{LexOrdering, LexRequirement, PhysicalExprRef}; @@ -64,18 +69,21 @@ use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::coop::CooperativeExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::explain::ExplainExec; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use datafusion::physical_plan::joins::{ - CrossJoinExec, NestedLoopJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, + CrossJoinExec, NestedLoopJoinExec, SortMergeJoinExec, StreamJoinPartitionMode, + SymmetricHashJoinExec, }; use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion::physical_plan::memory::LazyMemoryExec; use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; -use datafusion::physical_plan::projection::ProjectionExec; +use datafusion::physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -86,7 +94,9 @@ use datafusion::physical_plan::{ ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, }; use datafusion_common::config::TableParquetOptions; -use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_common::{ + internal_datafusion_err, internal_err, not_impl_err, DataFusionError, Result, +}; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use prost::bytes::BufMut; @@ -101,7 +111,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { Self: Sized, { protobuf::PhysicalPlanNode::decode(buf).map_err(|e| { - DataFusionError::Internal(format!("failed to decode physical plan: {e:?}")) + internal_datafusion_err!("failed to decode physical plan: {e:?}") }) } @@ -111,14 +121,14 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { Self: Sized, { self.encode(buf).map_err(|e| { - DataFusionError::Internal(format!("failed to encode physical plan: {e:?}")) + internal_datafusion_err!("failed to encode physical plan: {e:?}") }) } fn try_into_physical_plan( &self, - registry: &dyn FunctionRegistry, - runtime: &RuntimeEnv, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { let plan = self.physical_plan_type.as_ref().ok_or_else(|| { @@ -127,879 +137,1412 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { )) })?; match plan { - PhysicalPlanType::Explain(explain) => Ok(Arc::new(ExplainExec::new( - Arc::new(explain.schema.as_ref().unwrap().try_into()?), - explain - .stringified_plans - .iter() - .map(|plan| plan.into()) - .collect(), - explain.verbose, - ))), + PhysicalPlanType::Explain(explain) => { + self.try_into_explain_physical_plan(explain, ctx, extension_codec) + } PhysicalPlanType::Projection(projection) => { - let input: Arc = into_physical_plan( - &projection.input, - registry, - runtime, - extension_codec, - )?; - let exprs = projection - .expr - .iter() - .zip(projection.expr_name.iter()) - .map(|(expr, name)| { - Ok(( - parse_physical_expr( - expr, - registry, - input.schema().as_ref(), - extension_codec, - )?, - name.to_string(), - )) - }) - .collect::, String)>>>()?; - Ok(Arc::new(ProjectionExec::try_new(exprs, input)?)) + self.try_into_projection_physical_plan(projection, ctx, extension_codec) } PhysicalPlanType::Filter(filter) => { - let input: Arc = into_physical_plan( - &filter.input, - registry, - runtime, - extension_codec, - )?; - let predicate = filter - .expr - .as_ref() - .map(|expr| { - parse_physical_expr( - expr, - registry, - input.schema().as_ref(), - extension_codec, - ) - }) - .transpose()? - .ok_or_else(|| { - DataFusionError::Internal( - "filter (FilterExecNode) in PhysicalPlanNode is missing." - .to_owned(), - ) - })?; - let filter_selectivity = filter.default_filter_selectivity.try_into(); - let projection = if !filter.projection.is_empty() { - Some( - filter - .projection - .iter() - .map(|i| *i as usize) - .collect::>(), - ) - } else { - None - }; - let filter = - FilterExec::try_new(predicate, input)?.with_projection(projection)?; - match filter_selectivity { - Ok(filter_selectivity) => Ok(Arc::new( - filter.with_default_selectivity(filter_selectivity)?, - )), - Err(_) => Err(DataFusionError::Internal( - "filter_selectivity in PhysicalPlanNode is invalid ".to_owned(), - )), - } + self.try_into_filter_physical_plan(filter, ctx, extension_codec) } PhysicalPlanType::CsvScan(scan) => { - let escape = if let Some( - protobuf::csv_scan_exec_node::OptionalEscape::Escape(escape), - ) = &scan.optional_escape - { - Some(str_to_byte(escape, "escape")?) - } else { - None - }; - - let comment = if let Some( - protobuf::csv_scan_exec_node::OptionalComment::Comment(comment), - ) = &scan.optional_comment - { - Some(str_to_byte(comment, "comment")?) - } else { - None - }; - - let source = Arc::new( - CsvSource::new( - scan.has_header, - str_to_byte(&scan.delimiter, "delimiter")?, - 0, - ) - .with_escape(escape) - .with_comment(comment), - ); - - let conf = FileScanConfigBuilder::from(parse_protobuf_file_scan_config( - scan.base_conf.as_ref().unwrap(), - registry, - extension_codec, - source, - )?) - .with_newlines_in_values(scan.newlines_in_values) - .with_file_compression_type(FileCompressionType::UNCOMPRESSED) - .build(); - Ok(DataSourceExec::from_data_source(conf)) + self.try_into_csv_scan_physical_plan(scan, ctx, extension_codec) } PhysicalPlanType::JsonScan(scan) => { - let scan_conf = parse_protobuf_file_scan_config( - scan.base_conf.as_ref().unwrap(), - registry, - extension_codec, - Arc::new(JsonSource::new()), - )?; - Ok(DataSourceExec::from_data_source(scan_conf)) + self.try_into_json_scan_physical_plan(scan, ctx, extension_codec) } #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] PhysicalPlanType::ParquetScan(scan) => { - #[cfg(feature = "parquet")] - { - let schema = parse_protobuf_file_scan_schema( - scan.base_conf.as_ref().unwrap(), - )?; - let predicate = scan - .predicate - .as_ref() - .map(|expr| { - parse_physical_expr( - expr, - registry, - schema.as_ref(), - extension_codec, - ) - }) - .transpose()?; - let mut options = TableParquetOptions::default(); - - if let Some(table_options) = scan.parquet_options.as_ref() { - options = table_options.try_into()?; - } - let mut source = ParquetSource::new(options); - - if let Some(predicate) = predicate { - source = source.with_predicate(Arc::clone(&schema), predicate); - } - let base_config = parse_protobuf_file_scan_config( - scan.base_conf.as_ref().unwrap(), - registry, - extension_codec, - Arc::new(source), - )?; - Ok(DataSourceExec::from_data_source(base_config)) - } - #[cfg(not(feature = "parquet"))] - panic!("Unable to process a Parquet PhysicalPlan when `parquet` feature is not enabled") + self.try_into_parquet_scan_physical_plan(scan, ctx, extension_codec) } #[cfg_attr(not(feature = "avro"), allow(unused_variables))] PhysicalPlanType::AvroScan(scan) => { - #[cfg(feature = "avro")] - { - let conf = parse_protobuf_file_scan_config( - scan.base_conf.as_ref().unwrap(), - registry, - extension_codec, - Arc::new(AvroSource::new()), - )?; - Ok(DataSourceExec::from_data_source(conf)) - } - #[cfg(not(feature = "avro"))] - panic!("Unable to process a Avro PhysicalPlan when `avro` feature is not enabled") + self.try_into_avro_scan_physical_plan(scan, ctx, extension_codec) } - PhysicalPlanType::CoalesceBatches(coalesce_batches) => { - let input: Arc = into_physical_plan( - &coalesce_batches.input, - registry, - runtime, - extension_codec, - )?; - Ok(Arc::new( - CoalesceBatchesExec::new( - input, - coalesce_batches.target_batch_size as usize, - ) - .with_fetch(coalesce_batches.fetch.map(|f| f as usize)), - )) + PhysicalPlanType::MemoryScan(scan) => { + self.try_into_memory_scan_physical_plan(scan, ctx, extension_codec) } + PhysicalPlanType::CoalesceBatches(coalesce_batches) => self + .try_into_coalesce_batches_physical_plan( + coalesce_batches, + ctx, + extension_codec, + ), PhysicalPlanType::Merge(merge) => { - let input: Arc = - into_physical_plan(&merge.input, registry, runtime, extension_codec)?; - Ok(Arc::new(CoalescePartitionsExec::new(input))) + self.try_into_merge_physical_plan(merge, ctx, extension_codec) } PhysicalPlanType::Repartition(repart) => { - let input: Arc = into_physical_plan( - &repart.input, - registry, - runtime, - extension_codec, - )?; - let partitioning = parse_protobuf_partitioning( - repart.partitioning.as_ref(), - registry, - input.schema().as_ref(), - extension_codec, - )?; - Ok(Arc::new(RepartitionExec::try_new( - input, - partitioning.unwrap(), - )?)) + self.try_into_repartition_physical_plan(repart, ctx, extension_codec) } PhysicalPlanType::GlobalLimit(limit) => { - let input: Arc = - into_physical_plan(&limit.input, registry, runtime, extension_codec)?; - let fetch = if limit.fetch >= 0 { - Some(limit.fetch as usize) - } else { - None - }; - Ok(Arc::new(GlobalLimitExec::new( - input, - limit.skip as usize, - fetch, - ))) + self.try_into_global_limit_physical_plan(limit, ctx, extension_codec) } PhysicalPlanType::LocalLimit(limit) => { - let input: Arc = - into_physical_plan(&limit.input, registry, runtime, extension_codec)?; - Ok(Arc::new(LocalLimitExec::new(input, limit.fetch as usize))) + self.try_into_local_limit_physical_plan(limit, ctx, extension_codec) } PhysicalPlanType::Window(window_agg) => { - let input: Arc = into_physical_plan( - &window_agg.input, - registry, - runtime, + self.try_into_window_physical_plan(window_agg, ctx, extension_codec) + } + PhysicalPlanType::Aggregate(hash_agg) => { + self.try_into_aggregate_physical_plan(hash_agg, ctx, extension_codec) + } + PhysicalPlanType::HashJoin(hashjoin) => { + self.try_into_hash_join_physical_plan(hashjoin, ctx, extension_codec) + } + PhysicalPlanType::SymmetricHashJoin(sym_join) => self + .try_into_symmetric_hash_join_physical_plan( + sym_join, + ctx, extension_codec, - )?; - let input_schema = input.schema(); + ), + PhysicalPlanType::Union(union) => { + self.try_into_union_physical_plan(union, ctx, extension_codec) + } + PhysicalPlanType::Interleave(interleave) => { + self.try_into_interleave_physical_plan(interleave, ctx, extension_codec) + } + PhysicalPlanType::CrossJoin(crossjoin) => { + self.try_into_cross_join_physical_plan(crossjoin, ctx, extension_codec) + } + PhysicalPlanType::Empty(empty) => { + self.try_into_empty_physical_plan(empty, ctx, extension_codec) + } + PhysicalPlanType::PlaceholderRow(placeholder) => self + .try_into_placeholder_row_physical_plan( + placeholder, + ctx, + extension_codec, + ), + PhysicalPlanType::Sort(sort) => { + self.try_into_sort_physical_plan(sort, ctx, extension_codec) + } + PhysicalPlanType::SortPreservingMerge(sort) => self + .try_into_sort_preserving_merge_physical_plan(sort, ctx, extension_codec), + PhysicalPlanType::Extension(extension) => { + self.try_into_extension_physical_plan(extension, ctx, extension_codec) + } + PhysicalPlanType::NestedLoopJoin(join) => { + self.try_into_nested_loop_join_physical_plan(join, ctx, extension_codec) + } + PhysicalPlanType::Analyze(analyze) => { + self.try_into_analyze_physical_plan(analyze, ctx, extension_codec) + } + PhysicalPlanType::JsonSink(sink) => { + self.try_into_json_sink_physical_plan(sink, ctx, extension_codec) + } + PhysicalPlanType::CsvSink(sink) => { + self.try_into_csv_sink_physical_plan(sink, ctx, extension_codec) + } + #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] + PhysicalPlanType::ParquetSink(sink) => { + self.try_into_parquet_sink_physical_plan(sink, ctx, extension_codec) + } + PhysicalPlanType::Unnest(unnest) => { + self.try_into_unnest_physical_plan(unnest, ctx, extension_codec) + } + PhysicalPlanType::Cooperative(cooperative) => { + self.try_into_cooperative_physical_plan(cooperative, ctx, extension_codec) + } + PhysicalPlanType::GenerateSeries(generate_series) => { + self.try_into_generate_series_physical_plan(generate_series) + } + PhysicalPlanType::SortMergeJoin(sort_join) => { + self.try_into_sort_join(sort_join, ctx, extension_codec) + } + } + } - let physical_window_expr: Vec> = window_agg - .window_expr - .iter() - .map(|window_expr| { - parse_physical_window_expr( - window_expr, - registry, - input_schema.as_ref(), - extension_codec, - ) - }) - .collect::, _>>()?; + fn try_from_physical_plan( + plan: Arc, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + let plan_clone = Arc::clone(&plan); + let plan = plan.as_any(); - let partition_keys = window_agg - .partition_keys - .iter() - .map(|expr| { - parse_physical_expr( - expr, - registry, - input.schema().as_ref(), - extension_codec, - ) - }) - .collect::>>>()?; + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_explain_exec( + exec, + extension_codec, + ); + } - if let Some(input_order_mode) = window_agg.input_order_mode.as_ref() { - let input_order_mode = match input_order_mode { - window_agg_exec_node::InputOrderMode::Linear(_) => { - InputOrderMode::Linear - } - window_agg_exec_node::InputOrderMode::PartiallySorted( - protobuf::PartiallySortedInputOrderMode { columns }, - ) => InputOrderMode::PartiallySorted( - columns.iter().map(|c| *c as usize).collect(), - ), - window_agg_exec_node::InputOrderMode::Sorted(_) => { - InputOrderMode::Sorted - } - }; - - Ok(Arc::new(BoundedWindowAggExec::try_new( - physical_window_expr, - input, - input_order_mode, - !partition_keys.is_empty(), - )?)) - } else { - Ok(Arc::new(WindowAggExec::try_new( - physical_window_expr, - input, - !partition_keys.is_empty(), - )?)) - } - } - PhysicalPlanType::Aggregate(hash_agg) => { - let input: Arc = into_physical_plan( - &hash_agg.input, - registry, - runtime, - extension_codec, - )?; - let mode = protobuf::AggregateMode::try_from(hash_agg.mode).map_err( - |_| { - proto_error(format!( - "Received a AggregateNode message with unknown AggregateMode {}", - hash_agg.mode - )) - }, - )?; - let agg_mode: AggregateMode = match mode { - protobuf::AggregateMode::Partial => AggregateMode::Partial, - protobuf::AggregateMode::Final => AggregateMode::Final, - protobuf::AggregateMode::FinalPartitioned => { - AggregateMode::FinalPartitioned - } - protobuf::AggregateMode::Single => AggregateMode::Single, - protobuf::AggregateMode::SinglePartitioned => { - AggregateMode::SinglePartitioned - } - }; + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_projection_exec( + exec, + extension_codec, + ); + } - let num_expr = hash_agg.group_expr.len(); + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_analyze_exec( + exec, + extension_codec, + ); + } - let group_expr = hash_agg - .group_expr - .iter() - .zip(hash_agg.group_expr_name.iter()) - .map(|(expr, name)| { - parse_physical_expr( - expr, - registry, - input.schema().as_ref(), - extension_codec, - ) - .map(|expr| (expr, name.to_string())) - }) - .collect::, _>>()?; + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_filter_exec( + exec, + extension_codec, + ); + } - let null_expr = hash_agg - .null_expr - .iter() - .zip(hash_agg.group_expr_name.iter()) - .map(|(expr, name)| { - parse_physical_expr( - expr, - registry, - input.schema().as_ref(), - extension_codec, - ) - .map(|expr| (expr, name.to_string())) - }) - .collect::, _>>()?; - - let groups: Vec> = if !hash_agg.groups.is_empty() { - hash_agg - .groups - .chunks(num_expr) - .map(|g| g.to_vec()) - .collect::>>() - } else { - vec![] - }; + if let Some(limit) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_global_limit_exec( + limit, + extension_codec, + ); + } - let input_schema = hash_agg.input_schema.as_ref().ok_or_else(|| { - DataFusionError::Internal( - "input_schema in AggregateNode is missing.".to_owned(), - ) - })?; - let physical_schema: SchemaRef = SchemaRef::new(input_schema.try_into()?); + if let Some(limit) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_local_limit_exec( + limit, + extension_codec, + ); + } - let physical_filter_expr = hash_agg - .filter_expr - .iter() - .map(|expr| { - expr.expr - .as_ref() - .map(|e| { - parse_physical_expr( - e, - registry, - &physical_schema, - extension_codec, - ) - }) - .transpose() - }) - .collect::, _>>()?; + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_hash_join_exec( + exec, + extension_codec, + ); + } - let physical_aggr_expr: Vec> = hash_agg - .aggr_expr - .iter() - .zip(hash_agg.aggr_expr_name.iter()) - .map(|(expr, name)| { - let expr_type = expr.expr_type.as_ref().ok_or_else(|| { - proto_error("Unexpected empty aggregate physical expression") - })?; - - match expr_type { - ExprType::AggregateExpr(agg_node) => { - let input_phy_expr: Vec> = agg_node.expr.iter() - .map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec)).collect::>>()?; - let ordering_req: LexOrdering = agg_node.ordering_req.iter() - .map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec)) - .collect::>()?; - agg_node.aggregate_function.as_ref().map(|func| { - match func { - AggregateFunction::UserDefinedAggrFunction(udaf_name) => { - let agg_udf = match &agg_node.fun_definition { - Some(buf) => extension_codec.try_decode_udaf(udaf_name, buf)?, - None => registry.udaf(udaf_name)? - }; - - AggregateExprBuilder::new(agg_udf, input_phy_expr) - .schema(Arc::clone(&physical_schema)) - .alias(name) - .with_ignore_nulls(agg_node.ignore_nulls) - .with_distinct(agg_node.distinct) - .order_by(ordering_req) - .build() - .map(Arc::new) - } - } - }).transpose()?.ok_or_else(|| { - proto_error("Invalid AggregateExpr, missing aggregate_function") - }) - } - _ => internal_err!( - "Invalid aggregate expression for AggregateExec" - ), - } - }) - .collect::, _>>()?; + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_symmetric_hash_join_exec( + exec, + extension_codec, + ); + } - let limit = hash_agg - .limit - .as_ref() - .map(|lit_value| lit_value.limit as usize); - - let agg = AggregateExec::try_new( - agg_mode, - PhysicalGroupBy::new(group_expr, null_expr, groups), - physical_aggr_expr, - physical_filter_expr, - input, - physical_schema, - )?; + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_sort_merge_join_exec( + exec, + extension_codec, + ); + } - let agg = agg.with_limit(limit); + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_cross_join_exec( + exec, + extension_codec, + ); + } - Ok(Arc::new(agg)) - } - PhysicalPlanType::HashJoin(hashjoin) => { - let left: Arc = into_physical_plan( - &hashjoin.left, - registry, - runtime, - extension_codec, - )?; - let right: Arc = into_physical_plan( - &hashjoin.right, - registry, - runtime, - extension_codec, - )?; - let left_schema = left.schema(); - let right_schema = right.schema(); - let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = hashjoin - .on - .iter() - .map(|col| { - let left = parse_physical_expr( - &col.left.clone().unwrap(), - registry, - left_schema.as_ref(), - extension_codec, - )?; - let right = parse_physical_expr( - &col.right.clone().unwrap(), - registry, - right_schema.as_ref(), - extension_codec, - )?; - Ok((left, right)) - }) - .collect::>()?; - let join_type = protobuf::JoinType::try_from(hashjoin.join_type) - .map_err(|_| { - proto_error(format!( - "Received a HashJoinNode message with unknown JoinType {}", - hashjoin.join_type - )) - })?; - let filter = hashjoin - .filter - .as_ref() - .map(|f| { - let schema = f - .schema - .as_ref() - .ok_or_else(|| proto_error("Missing JoinFilter schema"))? - .try_into()?; + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_aggregate_exec( + exec, + extension_codec, + ); + } - let expression = parse_physical_expr( - f.expression.as_ref().ok_or_else(|| { - proto_error("Unexpected empty filter expression") - })?, - registry, &schema, - extension_codec, - )?; - let column_indices = f.column_indices - .iter() - .map(|i| { - let side = protobuf::JoinSide::try_from(i.side) - .map_err(|_| proto_error(format!( - "Received a HashJoinNode message with JoinSide in Filter {}", - i.side)) - )?; + if let Some(empty) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_empty_exec( + empty, + extension_codec, + ); + } - Ok(ColumnIndex { - index: i.index as usize, - side: side.into(), - }) - }) - .collect::>>()?; + if let Some(empty) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_placeholder_row_exec( + empty, + extension_codec, + ); + } - Ok(JoinFilter::new(expression, column_indices, Arc::new(schema))) - }) - .map_or(Ok(None), |v: Result| v.map(Some))?; + if let Some(coalesce_batches) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_coalesce_batches_exec( + coalesce_batches, + extension_codec, + ); + } - let partition_mode = protobuf::PartitionMode::try_from( - hashjoin.partition_mode, - ) - .map_err(|_| { - proto_error(format!( - "Received a HashJoinNode message with unknown PartitionMode {}", - hashjoin.partition_mode - )) - })?; - let partition_mode = match partition_mode { - protobuf::PartitionMode::CollectLeft => PartitionMode::CollectLeft, - protobuf::PartitionMode::Partitioned => PartitionMode::Partitioned, - protobuf::PartitionMode::Auto => PartitionMode::Auto, - }; - let projection = if !hashjoin.projection.is_empty() { - Some( - hashjoin - .projection - .iter() - .map(|i| *i as usize) - .collect::>(), - ) - } else { - None - }; - Ok(Arc::new(HashJoinExec::try_new( - left, - right, - on, - filter, - &join_type.into(), - projection, - partition_mode, - hashjoin.null_equals_null, - )?)) + if let Some(data_source_exec) = plan.downcast_ref::() { + if let Some(node) = protobuf::PhysicalPlanNode::try_from_data_source_exec( + data_source_exec, + extension_codec, + )? { + return Ok(node); } - PhysicalPlanType::SymmetricHashJoin(sym_join) => { - let left = into_physical_plan( - &sym_join.left, - registry, - runtime, - extension_codec, - )?; - let right = into_physical_plan( - &sym_join.right, - registry, - runtime, - extension_codec, - )?; - let left_schema = left.schema(); - let right_schema = right.schema(); - let on = sym_join - .on - .iter() - .map(|col| { - let left = parse_physical_expr( - &col.left.clone().unwrap(), - registry, - left_schema.as_ref(), - extension_codec, - )?; - let right = parse_physical_expr( - &col.right.clone().unwrap(), - registry, - right_schema.as_ref(), + } + + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_coalesce_partitions_exec( + exec, + extension_codec, + ); + } + + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_repartition_exec( + exec, + extension_codec, + ); + } + + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_sort_exec(exec, extension_codec); + } + + if let Some(union) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_union_exec( + union, + extension_codec, + ); + } + + if let Some(interleave) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_interleave_exec( + interleave, + extension_codec, + ); + } + + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_sort_preserving_merge_exec( + exec, + extension_codec, + ); + } + + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_nested_loop_join_exec( + exec, + extension_codec, + ); + } + + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_window_agg_exec( + exec, + extension_codec, + ); + } + + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_bounded_window_agg_exec( + exec, + extension_codec, + ); + } + + if let Some(exec) = plan.downcast_ref::() { + if let Some(node) = protobuf::PhysicalPlanNode::try_from_data_sink_exec( + exec, + extension_codec, + )? { + return Ok(node); + } + } + + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_unnest_exec( + exec, + extension_codec, + ); + } + + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_cooperative_exec( + exec, + extension_codec, + ); + } + + if let Some(exec) = plan.downcast_ref::() { + if let Some(node) = + protobuf::PhysicalPlanNode::try_from_lazy_memory_exec(exec)? + { + return Ok(node); + } + } + + let mut buf: Vec = vec![]; + match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) { + Ok(_) => { + let inputs: Vec = plan_clone + .children() + .into_iter() + .cloned() + .map(|i| { + protobuf::PhysicalPlanNode::try_from_physical_plan( + i, extension_codec, - )?; - Ok((left, right)) + ) }) .collect::>()?; - let join_type = protobuf::JoinType::try_from(sym_join.join_type) - .map_err(|_| { - proto_error(format!( - "Received a SymmetricHashJoin message with unknown JoinType {}", - sym_join.join_type - )) - })?; - let filter = sym_join - .filter + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Extension( + protobuf::PhysicalExtensionNode { node: buf, inputs }, + )), + }) + } + Err(e) => internal_err!( + "Unsupported plan and extension codec failed with [{e}]. Plan: {plan_clone:?}" + ), + } + } +} + +impl protobuf::PhysicalPlanNode { + fn try_into_explain_physical_plan( + &self, + explain: &protobuf::ExplainExecNode, + _ctx: &TaskContext, + + _extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + Ok(Arc::new(ExplainExec::new( + Arc::new(explain.schema.as_ref().unwrap().try_into()?), + explain + .stringified_plans + .iter() + .map(|plan| plan.into()) + .collect(), + explain.verbose, + ))) + } + + fn try_into_projection_physical_plan( + &self, + projection: &protobuf::ProjectionExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input: Arc = + into_physical_plan(&projection.input, ctx, extension_codec)?; + let exprs = projection + .expr + .iter() + .zip(projection.expr_name.iter()) + .map(|(expr, name)| { + Ok(( + parse_physical_expr( + expr, + ctx, + input.schema().as_ref(), + extension_codec, + )?, + name.to_string(), + )) + }) + .collect::, String)>>>()?; + let proj_exprs: Vec = exprs + .into_iter() + .map(|(expr, alias)| ProjectionExpr { expr, alias }) + .collect(); + Ok(Arc::new(ProjectionExec::try_new(proj_exprs, input)?)) + } + + fn try_into_filter_physical_plan( + &self, + filter: &protobuf::FilterExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input: Arc = + into_physical_plan(&filter.input, ctx, extension_codec)?; + + let predicate = filter + .expr + .as_ref() + .map(|expr| { + parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + }) + .transpose()? + .ok_or_else(|| { + internal_datafusion_err!( + "filter (FilterExecNode) in PhysicalPlanNode is missing." + ) + })?; + + let filter_selectivity = filter.default_filter_selectivity.try_into(); + let projection = if !filter.projection.is_empty() { + Some( + filter + .projection + .iter() + .map(|i| *i as usize) + .collect::>(), + ) + } else { + None + }; + + let filter = + FilterExec::try_new(predicate, input)?.with_projection(projection)?; + match filter_selectivity { + Ok(filter_selectivity) => Ok(Arc::new( + filter.with_default_selectivity(filter_selectivity)?, + )), + Err(_) => Err(internal_datafusion_err!( + "filter_selectivity in PhysicalPlanNode is invalid " + )), + } + } + + fn try_into_csv_scan_physical_plan( + &self, + scan: &protobuf::CsvScanExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let escape = + if let Some(protobuf::csv_scan_exec_node::OptionalEscape::Escape(escape)) = + &scan.optional_escape + { + Some(str_to_byte(escape, "escape")?) + } else { + None + }; + + let comment = if let Some( + protobuf::csv_scan_exec_node::OptionalComment::Comment(comment), + ) = &scan.optional_comment + { + Some(str_to_byte(comment, "comment")?) + } else { + None + }; + + let source = Arc::new( + CsvSource::new( + scan.has_header, + str_to_byte(&scan.delimiter, "delimiter")?, + 0, + ) + .with_escape(escape) + .with_comment(comment), + ); + + let conf = FileScanConfigBuilder::from(parse_protobuf_file_scan_config( + scan.base_conf.as_ref().unwrap(), + ctx, + extension_codec, + source, + )?) + .with_newlines_in_values(scan.newlines_in_values) + .with_file_compression_type(FileCompressionType::UNCOMPRESSED) + .build(); + Ok(DataSourceExec::from_data_source(conf)) + } + + fn try_into_json_scan_physical_plan( + &self, + scan: &protobuf::JsonScanExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let scan_conf = parse_protobuf_file_scan_config( + scan.base_conf.as_ref().unwrap(), + ctx, + extension_codec, + Arc::new(JsonSource::new()), + )?; + Ok(DataSourceExec::from_data_source(scan_conf)) + } + + #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] + fn try_into_parquet_scan_physical_plan( + &self, + scan: &protobuf::ParquetScanExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + #[cfg(feature = "parquet")] + { + let schema = + parse_protobuf_file_scan_schema(scan.base_conf.as_ref().unwrap())?; + + // Check if there's a projection and use projected schema for predicate parsing + let base_conf = scan.base_conf.as_ref().unwrap(); + let predicate_schema = if !base_conf.projection.is_empty() { + // Create projected schema for parsing the predicate + let projected_fields: Vec<_> = base_conf + .projection + .iter() + .map(|&i| schema.field(i as usize).clone()) + .collect(); + Arc::new(Schema::new(projected_fields)) + } else { + schema + }; + + let predicate = scan + .predicate + .as_ref() + .map(|expr| { + parse_physical_expr( + expr, + ctx, + predicate_schema.as_ref(), + extension_codec, + ) + }) + .transpose()?; + let mut options = TableParquetOptions::default(); + + if let Some(table_options) = scan.parquet_options.as_ref() { + options = table_options.try_into()?; + } + let mut source = ParquetSource::new(options); + + if let Some(predicate) = predicate { + source = source.with_predicate(predicate); + } + let base_config = parse_protobuf_file_scan_config( + base_conf, + ctx, + extension_codec, + Arc::new(source), + )?; + Ok(DataSourceExec::from_data_source(base_config)) + } + #[cfg(not(feature = "parquet"))] + panic!("Unable to process a Parquet PhysicalPlan when `parquet` feature is not enabled") + } + + #[cfg_attr(not(feature = "avro"), allow(unused_variables))] + fn try_into_avro_scan_physical_plan( + &self, + scan: &protobuf::AvroScanExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + #[cfg(feature = "avro")] + { + let conf = parse_protobuf_file_scan_config( + scan.base_conf.as_ref().unwrap(), + ctx, + extension_codec, + Arc::new(AvroSource::new()), + )?; + Ok(DataSourceExec::from_data_source(conf)) + } + #[cfg(not(feature = "avro"))] + panic!("Unable to process a Avro PhysicalPlan when `avro` feature is not enabled") + } + + fn try_into_memory_scan_physical_plan( + &self, + scan: &protobuf::MemoryScanExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let partitions = scan + .partitions + .iter() + .map(|p| parse_record_batches(p)) + .collect::>>()?; + + let proto_schema = scan.schema.as_ref().ok_or_else(|| { + internal_datafusion_err!("schema in MemoryScanExecNode is missing.") + })?; + let schema: SchemaRef = SchemaRef::new(proto_schema.try_into()?); + + let projection = if !scan.projection.is_empty() { + Some( + scan.projection + .iter() + .map(|i| *i as usize) + .collect::>(), + ) + } else { + None + }; + + let mut sort_information = vec![]; + for ordering in &scan.sort_information { + let sort_exprs = parse_physical_sort_exprs( + &ordering.physical_sort_expr_nodes, + ctx, + &schema, + extension_codec, + )?; + sort_information.extend(LexOrdering::new(sort_exprs)); + } + + let source = MemorySourceConfig::try_new(&partitions, schema, projection)? + .with_limit(scan.fetch.map(|f| f as usize)) + .with_show_sizes(scan.show_sizes); + + let source = source.try_with_sort_information(sort_information)?; + + Ok(DataSourceExec::from_data_source(source)) + } + + fn try_into_coalesce_batches_physical_plan( + &self, + coalesce_batches: &protobuf::CoalesceBatchesExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input: Arc = + into_physical_plan(&coalesce_batches.input, ctx, extension_codec)?; + Ok(Arc::new( + CoalesceBatchesExec::new(input, coalesce_batches.target_batch_size as usize) + .with_fetch(coalesce_batches.fetch.map(|f| f as usize)), + )) + } + + fn try_into_merge_physical_plan( + &self, + merge: &protobuf::CoalescePartitionsExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input: Arc = + into_physical_plan(&merge.input, ctx, extension_codec)?; + Ok(Arc::new( + CoalescePartitionsExec::new(input) + .with_fetch(merge.fetch.map(|f| f as usize)), + )) + } + + fn try_into_repartition_physical_plan( + &self, + repart: &protobuf::RepartitionExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input: Arc = + into_physical_plan(&repart.input, ctx, extension_codec)?; + let partitioning = parse_protobuf_partitioning( + repart.partitioning.as_ref(), + ctx, + input.schema().as_ref(), + extension_codec, + )?; + Ok(Arc::new(RepartitionExec::try_new( + input, + partitioning.unwrap(), + )?)) + } + + fn try_into_global_limit_physical_plan( + &self, + limit: &protobuf::GlobalLimitExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input: Arc = + into_physical_plan(&limit.input, ctx, extension_codec)?; + let fetch = if limit.fetch >= 0 { + Some(limit.fetch as usize) + } else { + None + }; + Ok(Arc::new(GlobalLimitExec::new( + input, + limit.skip as usize, + fetch, + ))) + } + + fn try_into_local_limit_physical_plan( + &self, + limit: &protobuf::LocalLimitExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input: Arc = + into_physical_plan(&limit.input, ctx, extension_codec)?; + Ok(Arc::new(LocalLimitExec::new(input, limit.fetch as usize))) + } + + fn try_into_window_physical_plan( + &self, + window_agg: &protobuf::WindowAggExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input: Arc = + into_physical_plan(&window_agg.input, ctx, extension_codec)?; + let input_schema = input.schema(); + + let physical_window_expr: Vec> = window_agg + .window_expr + .iter() + .map(|window_expr| { + parse_physical_window_expr( + window_expr, + ctx, + input_schema.as_ref(), + extension_codec, + ) + }) + .collect::, _>>()?; + + let partition_keys = window_agg + .partition_keys + .iter() + .map(|expr| { + parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + }) + .collect::>>>()?; + + if let Some(input_order_mode) = window_agg.input_order_mode.as_ref() { + let input_order_mode = match input_order_mode { + window_agg_exec_node::InputOrderMode::Linear(_) => InputOrderMode::Linear, + window_agg_exec_node::InputOrderMode::PartiallySorted( + protobuf::PartiallySortedInputOrderMode { columns }, + ) => InputOrderMode::PartiallySorted( + columns.iter().map(|c| *c as usize).collect(), + ), + window_agg_exec_node::InputOrderMode::Sorted(_) => InputOrderMode::Sorted, + }; + + Ok(Arc::new(BoundedWindowAggExec::try_new( + physical_window_expr, + input, + input_order_mode, + !partition_keys.is_empty(), + )?)) + } else { + Ok(Arc::new(WindowAggExec::try_new( + physical_window_expr, + input, + !partition_keys.is_empty(), + )?)) + } + } + + fn try_into_aggregate_physical_plan( + &self, + hash_agg: &protobuf::AggregateExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input: Arc = + into_physical_plan(&hash_agg.input, ctx, extension_codec)?; + let mode = protobuf::AggregateMode::try_from(hash_agg.mode).map_err(|_| { + proto_error(format!( + "Received a AggregateNode message with unknown AggregateMode {}", + hash_agg.mode + )) + })?; + let agg_mode: AggregateMode = match mode { + protobuf::AggregateMode::Partial => AggregateMode::Partial, + protobuf::AggregateMode::Final => AggregateMode::Final, + protobuf::AggregateMode::FinalPartitioned => AggregateMode::FinalPartitioned, + protobuf::AggregateMode::Single => AggregateMode::Single, + protobuf::AggregateMode::SinglePartitioned => { + AggregateMode::SinglePartitioned + } + }; + + let num_expr = hash_agg.group_expr.len(); + + let group_expr = hash_agg + .group_expr + .iter() + .zip(hash_agg.group_expr_name.iter()) + .map(|(expr, name)| { + parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + .map(|expr| (expr, name.to_string())) + }) + .collect::, _>>()?; + + let null_expr = hash_agg + .null_expr + .iter() + .zip(hash_agg.group_expr_name.iter()) + .map(|(expr, name)| { + parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + .map(|expr| (expr, name.to_string())) + }) + .collect::, _>>()?; + + let groups: Vec> = if !hash_agg.groups.is_empty() { + hash_agg + .groups + .chunks(num_expr) + .map(|g| g.to_vec()) + .collect::>>() + } else { + vec![] + }; + + let input_schema = hash_agg.input_schema.as_ref().ok_or_else(|| { + internal_datafusion_err!("input_schema in AggregateNode is missing.") + })?; + let physical_schema: SchemaRef = SchemaRef::new(input_schema.try_into()?); + + let physical_filter_expr = hash_agg + .filter_expr + .iter() + .map(|expr| { + expr.expr .as_ref() - .map(|f| { - let schema = f - .schema - .as_ref() - .ok_or_else(|| proto_error("Missing JoinFilter schema"))? - .try_into()?; + .map(|e| { + parse_physical_expr(e, ctx, &physical_schema, extension_codec) + }) + .transpose() + }) + .collect::, _>>()?; + + let physical_aggr_expr: Vec> = hash_agg + .aggr_expr + .iter() + .zip(hash_agg.aggr_expr_name.iter()) + .map(|(expr, name)| { + let expr_type = expr.expr_type.as_ref().ok_or_else(|| { + proto_error("Unexpected empty aggregate physical expression") + })?; - let expression = parse_physical_expr( - f.expression.as_ref().ok_or_else(|| { - proto_error("Unexpected empty filter expression") - })?, - registry, &schema, - extension_codec, - )?; - let column_indices = f.column_indices + match expr_type { + ExprType::AggregateExpr(agg_node) => { + let input_phy_expr: Vec> = agg_node + .expr .iter() - .map(|i| { - let side = protobuf::JoinSide::try_from(i.side) - .map_err(|_| proto_error(format!( - "Received a HashJoinNode message with JoinSide in Filter {}", - i.side)) - )?; - - Ok(ColumnIndex { - index: i.index as usize, - side: side.into(), - }) + .map(|e| { + parse_physical_expr( + e, + ctx, + &physical_schema, + extension_codec, + ) + }) + .collect::>>()?; + let order_bys = agg_node + .ordering_req + .iter() + .map(|e| { + parse_physical_sort_expr( + e, + ctx, + &physical_schema, + extension_codec, + ) }) .collect::>()?; + agg_node + .aggregate_function + .as_ref() + .map(|func| match func { + AggregateFunction::UserDefinedAggrFunction(udaf_name) => { + let agg_udf = match &agg_node.fun_definition { + Some(buf) => extension_codec + .try_decode_udaf(udaf_name, buf)?, + None => ctx.udaf(udaf_name).or_else(|_| { + extension_codec + .try_decode_udaf(udaf_name, &[]) + })?, + }; + + AggregateExprBuilder::new(agg_udf, input_phy_expr) + .schema(Arc::clone(&physical_schema)) + .alias(name) + .human_display(agg_node.human_display.clone()) + .with_ignore_nulls(agg_node.ignore_nulls) + .with_distinct(agg_node.distinct) + .order_by(order_bys) + .build() + .map(Arc::new) + } + }) + .transpose()? + .ok_or_else(|| { + proto_error( + "Invalid AggregateExpr, missing aggregate_function", + ) + }) + } + _ => internal_err!("Invalid aggregate expression for AggregateExec"), + } + }) + .collect::, _>>()?; + + let limit = hash_agg + .limit + .as_ref() + .map(|lit_value| lit_value.limit as usize); + + let agg = AggregateExec::try_new( + agg_mode, + PhysicalGroupBy::new(group_expr, null_expr, groups), + physical_aggr_expr, + physical_filter_expr, + input, + physical_schema, + )?; + + let agg = agg.with_limit(limit); + + Ok(Arc::new(agg)) + } - Ok(JoinFilter::new(expression, column_indices, Arc::new(schema))) - }) - .map_or(Ok(None), |v: Result| v.map(Some))?; + fn try_into_hash_join_physical_plan( + &self, + hashjoin: &protobuf::HashJoinExecNode, + ctx: &TaskContext, - let left_sort_exprs = parse_physical_sort_exprs( - &sym_join.left_sort_exprs, - registry, - &left_schema, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let left: Arc = + into_physical_plan(&hashjoin.left, ctx, extension_codec)?; + let right: Arc = + into_physical_plan(&hashjoin.right, ctx, extension_codec)?; + let left_schema = left.schema(); + let right_schema = right.schema(); + let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = hashjoin + .on + .iter() + .map(|col| { + let left = parse_physical_expr( + &col.left.clone().unwrap(), + ctx, + left_schema.as_ref(), extension_codec, )?; - let left_sort_exprs = if left_sort_exprs.is_empty() { - None - } else { - Some(left_sort_exprs) - }; + let right = parse_physical_expr( + &col.right.clone().unwrap(), + ctx, + right_schema.as_ref(), + extension_codec, + )?; + Ok((left, right)) + }) + .collect::>()?; + let join_type = + protobuf::JoinType::try_from(hashjoin.join_type).map_err(|_| { + proto_error(format!( + "Received a HashJoinNode message with unknown JoinType {}", + hashjoin.join_type + )) + })?; + let null_equality = protobuf::NullEquality::try_from(hashjoin.null_equality) + .map_err(|_| { + proto_error(format!( + "Received a HashJoinNode message with unknown NullEquality {}", + hashjoin.null_equality + )) + })?; + let filter = hashjoin + .filter + .as_ref() + .map(|f| { + let schema = f + .schema + .as_ref() + .ok_or_else(|| proto_error("Missing JoinFilter schema"))? + .try_into()?; - let right_sort_exprs = parse_physical_sort_exprs( - &sym_join.right_sort_exprs, - registry, - &right_schema, + let expression = parse_physical_expr( + f.expression.as_ref().ok_or_else(|| { + proto_error("Unexpected empty filter expression") + })?, + ctx, &schema, extension_codec, )?; - let right_sort_exprs = if right_sort_exprs.is_empty() { - None - } else { - Some(right_sort_exprs) - }; - - let partition_mode = - protobuf::StreamPartitionMode::try_from(sym_join.partition_mode).map_err(|_| { - proto_error(format!( - "Received a SymmetricHashJoin message with unknown PartitionMode {}", - sym_join.partition_mode - )) - })?; - let partition_mode = match partition_mode { - protobuf::StreamPartitionMode::SinglePartition => { - StreamJoinPartitionMode::SinglePartition - } - protobuf::StreamPartitionMode::PartitionedExec => { - StreamJoinPartitionMode::Partitioned - } - }; - SymmetricHashJoinExec::try_new( - left, - right, - on, - filter, - &join_type.into(), - sym_join.null_equals_null, - left_sort_exprs, - right_sort_exprs, - partition_mode, - ) - .map(|e| Arc::new(e) as _) - } - PhysicalPlanType::Union(union) => { - let mut inputs: Vec> = vec![]; - for input in &union.inputs { - inputs.push(input.try_into_physical_plan( - registry, - runtime, - extension_codec, - )?); - } - Ok(Arc::new(UnionExec::new(inputs))) - } - PhysicalPlanType::Interleave(interleave) => { - let mut inputs: Vec> = vec![]; - for input in &interleave.inputs { - inputs.push(input.try_into_physical_plan( - registry, - runtime, - extension_codec, - )?); - } - Ok(Arc::new(InterleaveExec::try_new(inputs)?)) - } - PhysicalPlanType::CrossJoin(crossjoin) => { - let left: Arc = into_physical_plan( - &crossjoin.left, - registry, - runtime, + let column_indices = f.column_indices + .iter() + .map(|i| { + let side = protobuf::JoinSide::try_from(i.side) + .map_err(|_| proto_error(format!( + "Received a HashJoinNode message with JoinSide in Filter {}", + i.side)) + )?; + + Ok(ColumnIndex { + index: i.index as usize, + side: side.into(), + }) + }) + .collect::>>()?; + + Ok(JoinFilter::new(expression, column_indices, Arc::new(schema))) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + let partition_mode = protobuf::PartitionMode::try_from(hashjoin.partition_mode) + .map_err(|_| { + proto_error(format!( + "Received a HashJoinNode message with unknown PartitionMode {}", + hashjoin.partition_mode + )) + })?; + let partition_mode = match partition_mode { + protobuf::PartitionMode::CollectLeft => PartitionMode::CollectLeft, + protobuf::PartitionMode::Partitioned => PartitionMode::Partitioned, + protobuf::PartitionMode::Auto => PartitionMode::Auto, + }; + let projection = if !hashjoin.projection.is_empty() { + Some( + hashjoin + .projection + .iter() + .map(|i| *i as usize) + .collect::>(), + ) + } else { + None + }; + Ok(Arc::new(HashJoinExec::try_new( + left, + right, + on, + filter, + &join_type.into(), + projection, + partition_mode, + null_equality.into(), + )?)) + } + + fn try_into_symmetric_hash_join_physical_plan( + &self, + sym_join: &protobuf::SymmetricHashJoinExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let left = into_physical_plan(&sym_join.left, ctx, extension_codec)?; + let right = into_physical_plan(&sym_join.right, ctx, extension_codec)?; + let left_schema = left.schema(); + let right_schema = right.schema(); + let on = sym_join + .on + .iter() + .map(|col| { + let left = parse_physical_expr( + &col.left.clone().unwrap(), + ctx, + left_schema.as_ref(), extension_codec, )?; - let right: Arc = into_physical_plan( - &crossjoin.right, - registry, - runtime, + let right = parse_physical_expr( + &col.right.clone().unwrap(), + ctx, + right_schema.as_ref(), extension_codec, )?; - Ok(Arc::new(CrossJoinExec::new(left, right))) - } - PhysicalPlanType::Empty(empty) => { - let schema = Arc::new(convert_required!(empty.schema)?); - Ok(Arc::new(EmptyExec::new(schema))) + Ok((left, right)) + }) + .collect::>()?; + let join_type = + protobuf::JoinType::try_from(sym_join.join_type).map_err(|_| { + proto_error(format!( + "Received a SymmetricHashJoin message with unknown JoinType {}", + sym_join.join_type + )) + })?; + let null_equality = protobuf::NullEquality::try_from(sym_join.null_equality) + .map_err(|_| { + proto_error(format!( + "Received a SymmetricHashJoin message with unknown NullEquality {}", + sym_join.null_equality + )) + })?; + let filter = sym_join + .filter + .as_ref() + .map(|f| { + let schema = f + .schema + .as_ref() + .ok_or_else(|| proto_error("Missing JoinFilter schema"))? + .try_into()?; + + let expression = parse_physical_expr( + f.expression.as_ref().ok_or_else(|| { + proto_error("Unexpected empty filter expression") + })?, + ctx, &schema, + extension_codec, + )?; + let column_indices = f.column_indices + .iter() + .map(|i| { + let side = protobuf::JoinSide::try_from(i.side) + .map_err(|_| proto_error(format!( + "Received a HashJoinNode message with JoinSide in Filter {}", + i.side)) + )?; + + Ok(ColumnIndex { + index: i.index as usize, + side: side.into(), + }) + }) + .collect::>()?; + + Ok(JoinFilter::new(expression, column_indices, Arc::new(schema))) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + let left_sort_exprs = parse_physical_sort_exprs( + &sym_join.left_sort_exprs, + ctx, + &left_schema, + extension_codec, + )?; + let left_sort_exprs = LexOrdering::new(left_sort_exprs); + + let right_sort_exprs = parse_physical_sort_exprs( + &sym_join.right_sort_exprs, + ctx, + &right_schema, + extension_codec, + )?; + let right_sort_exprs = LexOrdering::new(right_sort_exprs); + + let partition_mode = protobuf::StreamPartitionMode::try_from( + sym_join.partition_mode, + ) + .map_err(|_| { + proto_error(format!( + "Received a SymmetricHashJoin message with unknown PartitionMode {}", + sym_join.partition_mode + )) + })?; + let partition_mode = match partition_mode { + protobuf::StreamPartitionMode::SinglePartition => { + StreamJoinPartitionMode::SinglePartition } - PhysicalPlanType::PlaceholderRow(placeholder) => { - let schema = Arc::new(convert_required!(placeholder.schema)?); - Ok(Arc::new(PlaceholderRowExec::new(schema))) + protobuf::StreamPartitionMode::PartitionedExec => { + StreamJoinPartitionMode::Partitioned } - PhysicalPlanType::Sort(sort) => { - let input: Arc = - into_physical_plan(&sort.input, registry, runtime, extension_codec)?; - let exprs = sort - .expr - .iter() - .map(|expr| { - let expr = expr.expr_type.as_ref().ok_or_else(|| { + }; + SymmetricHashJoinExec::try_new( + left, + right, + on, + filter, + &join_type.into(), + null_equality.into(), + left_sort_exprs, + right_sort_exprs, + partition_mode, + ) + .map(|e| Arc::new(e) as _) + } + + fn try_into_union_physical_plan( + &self, + union: &protobuf::UnionExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let mut inputs: Vec> = vec![]; + for input in &union.inputs { + inputs.push(input.try_into_physical_plan(ctx, extension_codec)?); + } + UnionExec::try_new(inputs) + } + + fn try_into_interleave_physical_plan( + &self, + interleave: &protobuf::InterleaveExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let mut inputs: Vec> = vec![]; + for input in &interleave.inputs { + inputs.push(input.try_into_physical_plan(ctx, extension_codec)?); + } + Ok(Arc::new(InterleaveExec::try_new(inputs)?)) + } + + fn try_into_cross_join_physical_plan( + &self, + crossjoin: &protobuf::CrossJoinExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let left: Arc = + into_physical_plan(&crossjoin.left, ctx, extension_codec)?; + let right: Arc = + into_physical_plan(&crossjoin.right, ctx, extension_codec)?; + Ok(Arc::new(CrossJoinExec::new(left, right))) + } + + fn try_into_empty_physical_plan( + &self, + empty: &protobuf::EmptyExecNode, + _ctx: &TaskContext, + + _extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let schema = Arc::new(convert_required!(empty.schema)?); + Ok(Arc::new(EmptyExec::new(schema))) + } + + fn try_into_placeholder_row_physical_plan( + &self, + placeholder: &protobuf::PlaceholderRowExecNode, + _ctx: &TaskContext, + + _extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let schema = Arc::new(convert_required!(placeholder.schema)?); + Ok(Arc::new(PlaceholderRowExec::new(schema))) + } + + fn try_into_sort_physical_plan( + &self, + sort: &protobuf::SortExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input = into_physical_plan(&sort.input, ctx, extension_codec)?; + let exprs = sort + .expr + .iter() + .map(|expr| { + let expr = expr.expr_type.as_ref().ok_or_else(|| { + proto_error(format!( + "physical_plan::from_proto() Unexpected expr {self:?}" + )) + })?; + if let ExprType::Sort(sort_expr) = expr { + let expr = sort_expr + .expr + .as_ref() + .ok_or_else(|| { proto_error(format!( - "physical_plan::from_proto() Unexpected expr {self:?}" + "physical_plan::from_proto() Unexpected sort expr {self:?}" )) - })?; - if let ExprType::Sort(sort_expr) = expr { - let expr = sort_expr - .expr - .as_ref() - .ok_or_else(|| { - proto_error(format!( - "physical_plan::from_proto() Unexpected sort expr {self:?}" - )) - })? - .as_ref(); - Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, - options: SortOptions { - descending: !sort_expr.asc, - nulls_first: sort_expr.nulls_first, - }, - }) - } else { - internal_err!( - "physical_plan::from_proto() {self:?}" - ) - } + })? + .as_ref(); + Ok(PhysicalSortExpr { + expr: parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec)?, + options: SortOptions { + descending: !sort_expr.asc, + nulls_first: sort_expr.nulls_first, + }, }) - .collect::>()?; - let fetch = if sort.fetch < 0 { - None } else { - Some(sort.fetch as usize) - }; - let new_sort = SortExec::new(exprs, input) - .with_fetch(fetch) - .with_preserve_partitioning(sort.preserve_partitioning); + internal_err!( + "physical_plan::from_proto() {self:?}" + ) + } + }) + .collect::>>()?; + let Some(ordering) = LexOrdering::new(exprs) else { + return internal_err!("SortExec requires an ordering"); + }; + let fetch = (sort.fetch >= 0).then_some(sort.fetch as _); + let new_sort = SortExec::new(ordering, input) + .with_fetch(fetch) + .with_preserve_partitioning(sort.preserve_partitioning); + + Ok(Arc::new(new_sort)) + } - Ok(Arc::new(new_sort)) - } - PhysicalPlanType::SortPreservingMerge(sort) => { - let input: Arc = - into_physical_plan(&sort.input, registry, runtime, extension_codec)?; - let exprs = sort - .expr - .iter() - .map(|expr| { - let expr = expr.expr_type.as_ref().ok_or_else(|| { + fn try_into_sort_preserving_merge_physical_plan( + &self, + sort: &protobuf::SortPreservingMergeExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input = into_physical_plan(&sort.input, ctx, extension_codec)?; + let exprs = sort + .expr + .iter() + .map(|expr| { + let expr = expr.expr_type.as_ref().ok_or_else(|| { + proto_error(format!( + "physical_plan::from_proto() Unexpected expr {self:?}" + )) + })?; + if let ExprType::Sort(sort_expr) = expr { + let expr = sort_expr + .expr + .as_ref() + .ok_or_else(|| { proto_error(format!( - "physical_plan::from_proto() Unexpected expr {self:?}" - )) - })?; - if let ExprType::Sort(sort_expr) = expr { - let expr = sort_expr - .expr - .as_ref() - .ok_or_else(|| { - proto_error(format!( - "physical_plan::from_proto() Unexpected sort expr {self:?}" - )) - })? - .as_ref(); - Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, - options: SortOptions { - descending: !sort_expr.asc, - nulls_first: sort_expr.nulls_first, - }, - }) - } else { - internal_err!( - "physical_plan::from_proto() {self:?}" - ) - } + "physical_plan::from_proto() Unexpected sort expr {self:?}" + )) + })? + .as_ref(); + Ok(PhysicalSortExpr { + expr: parse_physical_expr( + expr, + ctx, + input.schema().as_ref(), + extension_codec, + )?, + options: SortOptions { + descending: !sort_expr.asc, + nulls_first: sort_expr.nulls_first, + }, }) - .collect::>()?; - let fetch = if sort.fetch < 0 { - None } else { - Some(sort.fetch as usize) - }; - Ok(Arc::new( - SortPreservingMergeExec::new(exprs, input).with_fetch(fetch), - )) - } - PhysicalPlanType::Extension(extension) => { - let inputs: Vec> = extension - .inputs - .iter() - .map(|i| i.try_into_physical_plan(registry, runtime, extension_codec)) - .collect::>()?; + internal_err!("physical_plan::from_proto() {self:?}") + } + }) + .collect::>>()?; + let Some(ordering) = LexOrdering::new(exprs) else { + return internal_err!("SortExec requires an ordering"); + }; + let fetch = (sort.fetch >= 0).then_some(sort.fetch as _); + Ok(Arc::new( + SortPreservingMergeExec::new(ordering, input).with_fetch(fetch), + )) + } - let extension_node = extension_codec.try_decode( - extension.node.as_slice(), - &inputs, - registry, - )?; + fn try_into_extension_physical_plan( + &self, + extension: &protobuf::PhysicalExtensionNode, + ctx: &TaskContext, - Ok(extension_node) - } - PhysicalPlanType::NestedLoopJoin(join) => { - let left: Arc = - into_physical_plan(&join.left, registry, runtime, extension_codec)?; - let right: Arc = - into_physical_plan(&join.right, registry, runtime, extension_codec)?; - let join_type = - protobuf::JoinType::try_from(join.join_type).map_err(|_| { - proto_error(format!( - "Received a NestedLoopJoinExecNode message with unknown JoinType {}", - join.join_type - )) - })?; - let filter = join + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let inputs: Vec> = extension + .inputs + .iter() + .map(|i| i.try_into_physical_plan(ctx, extension_codec)) + .collect::>()?; + + let extension_node = + extension_codec.try_decode(extension.node.as_slice(), &inputs, ctx)?; + + Ok(extension_node) + } + + fn try_into_nested_loop_join_physical_plan( + &self, + join: &protobuf::NestedLoopJoinExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let left: Arc = + into_physical_plan(&join.left, ctx, extension_codec)?; + let right: Arc = + into_physical_plan(&join.right, ctx, extension_codec)?; + let join_type = protobuf::JoinType::try_from(join.join_type).map_err(|_| { + proto_error(format!( + "Received a NestedLoopJoinExecNode message with unknown JoinType {}", + join.join_type + )) + })?; + let filter = join .filter .as_ref() .map(|f| { @@ -1013,7 +1556,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - registry, &schema, + ctx, &schema, extension_codec, )?; let column_indices = f.column_indices @@ -1036,1121 +1579,1630 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { }) .map_or(Ok(None), |v: Result| v.map(Some))?; - let projection = if !join.projection.is_empty() { - Some( - join.projection - .iter() - .map(|i| *i as usize) - .collect::>(), - ) - } else { - None - }; + let projection = if !join.projection.is_empty() { + Some( + join.projection + .iter() + .map(|i| *i as usize) + .collect::>(), + ) + } else { + None + }; + + Ok(Arc::new(NestedLoopJoinExec::try_new( + left, + right, + filter, + &join_type.into(), + projection, + )?)) + } - Ok(Arc::new(NestedLoopJoinExec::try_new( - left, - right, - filter, - &join_type.into(), - projection, - )?)) - } - PhysicalPlanType::Analyze(analyze) => { - let input: Arc = into_physical_plan( - &analyze.input, - registry, - runtime, - extension_codec, - )?; - Ok(Arc::new(AnalyzeExec::new( - analyze.verbose, - analyze.show_statistics, - input, - Arc::new(convert_required!(analyze.schema)?), - ))) - } - PhysicalPlanType::JsonSink(sink) => { - let input = - into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + fn try_into_analyze_physical_plan( + &self, + analyze: &protobuf::AnalyzeExecNode, + ctx: &TaskContext, - let data_sink: JsonSink = sink - .sink - .as_ref() - .ok_or_else(|| proto_error("Missing required field in protobuf"))? - .try_into()?; - let sink_schema = input.schema(); - let sort_order = sink - .sort_order - .as_ref() - .map(|collection| { - parse_physical_sort_exprs( - &collection.physical_sort_expr_nodes, - registry, - &sink_schema, - extension_codec, - ) - .map(LexRequirement::from) - }) - .transpose()?; - Ok(Arc::new(DataSinkExec::new( - input, - Arc::new(data_sink), - sort_order, - ))) - } - PhysicalPlanType::CsvSink(sink) => { - let input = - into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input: Arc = + into_physical_plan(&analyze.input, ctx, extension_codec)?; + Ok(Arc::new(AnalyzeExec::new( + analyze.verbose, + analyze.show_statistics, + input, + Arc::new(convert_required!(analyze.schema)?), + ))) + } - let data_sink: CsvSink = sink - .sink - .as_ref() - .ok_or_else(|| proto_error("Missing required field in protobuf"))? - .try_into()?; - let sink_schema = input.schema(); - let sort_order = sink - .sort_order - .as_ref() - .map(|collection| { - parse_physical_sort_exprs( - &collection.physical_sort_expr_nodes, - registry, - &sink_schema, - extension_codec, - ) - .map(LexRequirement::from) - }) - .transpose()?; - Ok(Arc::new(DataSinkExec::new( - input, - Arc::new(data_sink), - sort_order, - ))) - } - #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] - PhysicalPlanType::ParquetSink(sink) => { - #[cfg(feature = "parquet")] - { - let input = into_physical_plan( - &sink.input, - registry, - runtime, - extension_codec, - )?; + fn try_into_json_sink_physical_plan( + &self, + sink: &protobuf::JsonSinkExecNode, + ctx: &TaskContext, - let data_sink: ParquetSink = sink - .sink - .as_ref() - .ok_or_else(|| proto_error("Missing required field in protobuf"))? - .try_into()?; - let sink_schema = input.schema(); - let sort_order = sink - .sort_order - .as_ref() - .map(|collection| { - parse_physical_sort_exprs( - &collection.physical_sort_expr_nodes, - registry, - &sink_schema, - extension_codec, - ) - .map(LexRequirement::from) - }) - .transpose()?; - Ok(Arc::new(DataSinkExec::new( - input, - Arc::new(data_sink), - sort_order, - ))) - } - #[cfg(not(feature = "parquet"))] - panic!("Trying to use ParquetSink without `parquet` feature enabled"); - } - PhysicalPlanType::Unnest(unnest) => { - let input = into_physical_plan( - &unnest.input, - registry, - runtime, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + + let data_sink: JsonSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = input.schema(); + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + parse_physical_sort_exprs( + &collection.physical_sort_expr_nodes, + ctx, + &sink_schema, extension_codec, - )?; + ) + .map(|sort_exprs| { + LexRequirement::new(sort_exprs.into_iter().map(Into::into)) + }) + }) + .transpose()? + .flatten(); + Ok(Arc::new(DataSinkExec::new( + input, + Arc::new(data_sink), + sort_order, + ))) + } - Ok(Arc::new(UnnestExec::new( - input, - unnest - .list_type_columns - .iter() - .map(|c| ListUnnest { - index_in_input_schema: c.index_in_input_schema as _, - depth: c.depth as _, - }) - .collect(), - unnest.struct_type_columns.iter().map(|c| *c as _).collect(), - Arc::new(convert_required!(unnest.schema)?), - into_required!(unnest.options)?, - ))) - } - } + fn try_into_csv_sink_physical_plan( + &self, + sink: &protobuf::CsvSinkExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + + let data_sink: CsvSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = input.schema(); + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + parse_physical_sort_exprs( + &collection.physical_sort_expr_nodes, + ctx, + &sink_schema, + extension_codec, + ) + .map(|sort_exprs| { + LexRequirement::new(sort_exprs.into_iter().map(Into::into)) + }) + }) + .transpose()? + .flatten(); + Ok(Arc::new(DataSinkExec::new( + input, + Arc::new(data_sink), + sort_order, + ))) } - fn try_from_physical_plan( - plan: Arc, + fn try_into_parquet_sink_physical_plan( + &self, + sink: &protobuf::ParquetSinkExecNode, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result - where - Self: Sized, - { - let plan_clone = Arc::clone(&plan); - let plan = plan.as_any(); + ) -> Result> { + #[cfg(feature = "parquet")] + { + let input = into_physical_plan(&sink.input, ctx, extension_codec)?; - if let Some(exec) = plan.downcast_ref::() { - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Explain( - protobuf::ExplainExecNode { - schema: Some(exec.schema().as_ref().try_into()?), - stringified_plans: exec - .stringified_plans() - .iter() - .map(|plan| plan.into()) - .collect(), - verbose: exec.verbose(), - }, - )), - }); + let data_sink: ParquetSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = input.schema(); + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + parse_physical_sort_exprs( + &collection.physical_sort_expr_nodes, + ctx, + &sink_schema, + extension_codec, + ) + .map(|sort_exprs| { + LexRequirement::new(sort_exprs.into_iter().map(Into::into)) + }) + }) + .transpose()? + .flatten(); + Ok(Arc::new(DataSinkExec::new( + input, + Arc::new(data_sink), + sort_order, + ))) } + #[cfg(not(feature = "parquet"))] + panic!("Trying to use ParquetSink without `parquet` feature enabled"); + } - if let Some(exec) = plan.downcast_ref::() { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), - extension_codec, - )?; - let expr = exec - .expr() + fn try_into_unnest_physical_plan( + &self, + unnest: &protobuf::UnnestExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input = into_physical_plan(&unnest.input, ctx, extension_codec)?; + + Ok(Arc::new(UnnestExec::new( + input, + unnest + .list_type_columns .iter() - .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) - .collect::>>()?; - let expr_name = exec.expr().iter().map(|expr| expr.1.clone()).collect(); - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Projection(Box::new( - protobuf::ProjectionExecNode { - input: Some(Box::new(input)), - expr, - expr_name, - }, - ))), - }); - } + .map(|c| ListUnnest { + index_in_input_schema: c.index_in_input_schema as _, + depth: c.depth as _, + }) + .collect(), + unnest.struct_type_columns.iter().map(|c| *c as _).collect(), + Arc::new(convert_required!(unnest.schema)?), + into_required!(unnest.options)?, + ))) + } - if let Some(exec) = plan.downcast_ref::() { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), - extension_codec, - )?; - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Analyze(Box::new( - protobuf::AnalyzeExecNode { - verbose: exec.verbose(), - show_statistics: exec.show_statistics(), - input: Some(Box::new(input)), - schema: Some(exec.schema().as_ref().try_into()?), - }, - ))), - }); + fn generate_series_name_to_str(name: protobuf::GenerateSeriesName) -> &'static str { + match name { + protobuf::GenerateSeriesName::GsGenerateSeries => "generate_series", + protobuf::GenerateSeriesName::GsRange => "range", } + } + fn try_into_sort_join( + &self, + sort_join: &SortMergeJoinExecNode, + ctx: &TaskContext, - if let Some(exec) = plan.downcast_ref::() { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), - extension_codec, - )?; - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Filter(Box::new( - protobuf::FilterExecNode { - input: Some(Box::new(input)), - expr: Some(serialize_physical_expr( - exec.predicate(), - extension_codec, - )?), - default_filter_selectivity: exec.default_selectivity() as u32, - projection: exec - .projection() - .as_ref() - .map_or_else(Vec::new, |v| { - v.iter().map(|x| *x as u32).collect::>() - }), - }, - ))), - }); - } + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let left = into_physical_plan(&sort_join.left, ctx, extension_codec)?; + let left_schema = left.schema(); + let right = into_physical_plan(&sort_join.right, ctx, extension_codec)?; + let right_schema = right.schema(); + + let filter = sort_join + .filter + .as_ref() + .map(|f| { + let schema = f + .schema + .as_ref() + .ok_or_else(|| proto_error("Missing JoinFilter schema"))? + .try_into()?; - if let Some(limit) = plan.downcast_ref::() { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - limit.input().to_owned(), - extension_codec, - )?; + let expression = parse_physical_expr( + f.expression.as_ref().ok_or_else(|| { + proto_error("Unexpected empty filter expression") + })?, + ctx, + &schema, + extension_codec, + )?; + let column_indices = f + .column_indices + .iter() + .map(|i| { + let side = + protobuf::JoinSide::try_from(i.side).map_err(|_| { + proto_error(format!( + "Received a SortMergeJoinExecNode message with JoinSide in Filter {}", + i.side + )) + })?; + + Ok(ColumnIndex { + index: i.index as usize, + side: side.into(), + }) + }) + .collect::>>()?; - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::GlobalLimit(Box::new( - protobuf::GlobalLimitExecNode { - input: Some(Box::new(input)), - skip: limit.skip() as u32, - fetch: match limit.fetch() { - Some(n) => n as i64, - _ => -1, // no limit - }, - }, - ))), - }); - } + Ok(JoinFilter::new( + expression, + column_indices, + Arc::new(schema), + )) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + let join_type = + protobuf::JoinType::try_from(sort_join.join_type).map_err(|_| { + proto_error(format!( + "Received a SortMergeJoinExecNode message with unknown JoinType {}", + sort_join.join_type + )) + })?; - if let Some(limit) = plan.downcast_ref::() { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - limit.input().to_owned(), - extension_codec, - )?; - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::LocalLimit(Box::new( - protobuf::LocalLimitExecNode { - input: Some(Box::new(input)), - fetch: limit.fetch() as u32, - }, - ))), - }); - } + let null_equality = protobuf::NullEquality::try_from(sort_join.null_equality) + .map_err(|_| { + proto_error(format!( + "Received a SortMergeJoinExecNode message with unknown NullEquality {}", + sort_join.null_equality + )) + })?; + + let sort_options = sort_join + .sort_options + .iter() + .map(|e| SortOptions { + descending: !e.asc, + nulls_first: e.nulls_first, + }) + .collect(); + let on = sort_join + .on + .iter() + .map(|col| { + let left = parse_physical_expr( + &col.left.clone().unwrap(), + ctx, + left_schema.as_ref(), + extension_codec, + )?; + let right = parse_physical_expr( + &col.right.clone().unwrap(), + ctx, + right_schema.as_ref(), + extension_codec, + )?; + Ok((left, right)) + }) + .collect::>()?; + + Ok(Arc::new(SortMergeJoinExec::try_new( + left, + right, + on, + filter, + join_type.into(), + sort_options, + null_equality.into(), + )?)) + } - if let Some(exec) = plan.downcast_ref::() { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.left().to_owned(), - extension_codec, - )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.right().to_owned(), - extension_codec, - )?; - let on: Vec = exec - .on() - .iter() - .map(|tuple| { - let l = serialize_physical_expr(&tuple.0, extension_codec)?; - let r = serialize_physical_expr(&tuple.1, extension_codec)?; - Ok::<_, DataFusionError>(protobuf::JoinOn { - left: Some(l), - right: Some(r), - }) - }) - .collect::>()?; - let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); - let filter = exec - .filter() - .as_ref() - .map(|f| { - let expression = - serialize_physical_expr(f.expression(), extension_codec)?; - let column_indices = f - .column_indices() + fn try_into_generate_series_physical_plan( + &self, + generate_series: &protobuf::GenerateSeriesNode, + ) -> Result> { + let schema: SchemaRef = Arc::new(convert_required!(generate_series.schema)?); + + let args = match &generate_series.args { + Some(protobuf::generate_series_node::Args::ContainsNull(args)) => { + GenSeriesArgs::ContainsNull { + name: Self::generate_series_name_to_str(args.name()), + } + } + Some(protobuf::generate_series_node::Args::Int64Args(args)) => { + GenSeriesArgs::Int64Args { + start: args.start, + end: args.end, + step: args.step, + include_end: args.include_end, + name: Self::generate_series_name_to_str(args.name()), + } + } + Some(protobuf::generate_series_node::Args::TimestampArgs(args)) => { + let step_proto = args.step.as_ref().ok_or_else(|| { + internal_datafusion_err!("Missing step in TimestampArgs") + })?; + let step = IntervalMonthDayNanoType::make_value( + step_proto.months, + step_proto.days, + step_proto.nanos, + ); + GenSeriesArgs::TimestampArgs { + start: args.start, + end: args.end, + step, + tz: args.tz.as_ref().map(|s| Arc::from(s.as_str())), + include_end: args.include_end, + name: Self::generate_series_name_to_str(args.name()), + } + } + Some(protobuf::generate_series_node::Args::DateArgs(args)) => { + let step_proto = args.step.as_ref().ok_or_else(|| { + internal_datafusion_err!("Missing step in DateArgs") + })?; + let step = IntervalMonthDayNanoType::make_value( + step_proto.months, + step_proto.days, + step_proto.nanos, + ); + GenSeriesArgs::DateArgs { + start: args.start, + end: args.end, + step, + include_end: args.include_end, + name: Self::generate_series_name_to_str(args.name()), + } + } + None => return internal_err!("Missing args in GenerateSeriesNode"), + }; + + let table = GenerateSeriesTable::new(Arc::clone(&schema), args); + let generator = table.as_generator(generate_series.target_batch_size as usize)?; + + Ok(Arc::new(LazyMemoryExec::try_new(schema, vec![generator])?)) + } + + fn try_into_cooperative_physical_plan( + &self, + field_stream: &protobuf::CooperativeExecNode, + ctx: &TaskContext, + + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input = into_physical_plan(&field_stream.input, ctx, extension_codec)?; + Ok(Arc::new(CooperativeExec::new(input))) + } + + fn try_from_explain_exec( + exec: &ExplainExec, + _extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Explain( + protobuf::ExplainExecNode { + schema: Some(exec.schema().as_ref().try_into()?), + stringified_plans: exec + .stringified_plans() .iter() - .map(|i| { - let side: protobuf::JoinSide = i.side.to_owned().into(); - protobuf::ColumnIndex { - index: i.index as u32, - side: side.into(), - } - }) - .collect(); - let schema = f.schema().as_ref().try_into()?; - Ok(protobuf::JoinFilter { - expression: Some(expression), - column_indices, - schema: Some(schema), - }) - }) - .map_or(Ok(None), |v: Result| v.map(Some))?; + .map(|plan| plan.into()) + .collect(), + verbose: exec.verbose(), + }, + )), + }) + } + + fn try_from_projection_exec( + exec: &ProjectionExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + let expr = exec + .expr() + .iter() + .map(|proj_expr| serialize_physical_expr(&proj_expr.expr, extension_codec)) + .collect::>>()?; + let expr_name = exec + .expr() + .iter() + .map(|proj_expr| proj_expr.alias.clone()) + .collect(); + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Projection(Box::new( + protobuf::ProjectionExecNode { + input: Some(Box::new(input)), + expr, + expr_name, + }, + ))), + }) + } - let partition_mode = match exec.partition_mode() { - PartitionMode::CollectLeft => protobuf::PartitionMode::CollectLeft, - PartitionMode::Partitioned => protobuf::PartitionMode::Partitioned, - PartitionMode::Auto => protobuf::PartitionMode::Auto, - }; + fn try_from_analyze_exec( + exec: &AnalyzeExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Analyze(Box::new( + protobuf::AnalyzeExecNode { + verbose: exec.verbose(), + show_statistics: exec.show_statistics(), + input: Some(Box::new(input)), + schema: Some(exec.schema().as_ref().try_into()?), + }, + ))), + }) + } + + fn try_from_filter_exec( + exec: &FilterExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Filter(Box::new( + protobuf::FilterExecNode { + input: Some(Box::new(input)), + expr: Some(serialize_physical_expr( + exec.predicate(), + extension_codec, + )?), + default_filter_selectivity: exec.default_selectivity() as u32, + projection: exec.projection().as_ref().map_or_else(Vec::new, |v| { + v.iter().map(|x| *x as u32).collect::>() + }), + }, + ))), + }) + } - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new( - protobuf::HashJoinExecNode { - left: Some(Box::new(left)), - right: Some(Box::new(right)), - on, - join_type: join_type.into(), - partition_mode: partition_mode.into(), - null_equals_null: exec.null_equals_null(), - filter, - projection: exec.projection.as_ref().map_or_else(Vec::new, |v| { - v.iter().map(|x| *x as u32).collect::>() - }), + fn try_from_global_limit_exec( + limit: &GlobalLimitExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + limit.input().to_owned(), + extension_codec, + )?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::GlobalLimit(Box::new( + protobuf::GlobalLimitExecNode { + input: Some(Box::new(input)), + skip: limit.skip() as u32, + fetch: match limit.fetch() { + Some(n) => n as i64, + _ => -1, // no limit }, - ))), - }); - } + }, + ))), + }) + } - if let Some(exec) = plan.downcast_ref::() { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.left().to_owned(), - extension_codec, - )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.right().to_owned(), - extension_codec, - )?; - let on = exec - .on() - .iter() - .map(|tuple| { - let l = serialize_physical_expr(&tuple.0, extension_codec)?; - let r = serialize_physical_expr(&tuple.1, extension_codec)?; - Ok::<_, DataFusionError>(protobuf::JoinOn { - left: Some(l), - right: Some(r), - }) + fn try_from_local_limit_exec( + limit: &LocalLimitExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + limit.input().to_owned(), + extension_codec, + )?; + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::LocalLimit(Box::new( + protobuf::LocalLimitExecNode { + input: Some(Box::new(input)), + fetch: limit.fetch() as u32, + }, + ))), + }) + } + + fn try_from_hash_join_exec( + exec: &HashJoinExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.left().to_owned(), + extension_codec, + )?; + let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.right().to_owned(), + extension_codec, + )?; + let on: Vec = exec + .on() + .iter() + .map(|tuple| { + let l = serialize_physical_expr(&tuple.0, extension_codec)?; + let r = serialize_physical_expr(&tuple.1, extension_codec)?; + Ok::<_, DataFusionError>(protobuf::JoinOn { + left: Some(l), + right: Some(r), }) - .collect::>()?; - let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); - let filter = exec - .filter() - .as_ref() - .map(|f| { - let expression = - serialize_physical_expr(f.expression(), extension_codec)?; - let column_indices = f - .column_indices() - .iter() - .map(|i| { - let side: protobuf::JoinSide = i.side.to_owned().into(); - protobuf::ColumnIndex { - index: i.index as u32, - side: side.into(), - } - }) - .collect(); - let schema = f.schema().as_ref().try_into()?; - Ok(protobuf::JoinFilter { - expression: Some(expression), - column_indices, - schema: Some(schema), + }) + .collect::>()?; + let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let null_equality: protobuf::NullEquality = exec.null_equality().into(); + let filter = exec + .filter() + .as_ref() + .map(|f| { + let expression = + serialize_physical_expr(f.expression(), extension_codec)?; + let column_indices = f + .column_indices() + .iter() + .map(|i| { + let side: protobuf::JoinSide = i.side.to_owned().into(); + protobuf::ColumnIndex { + index: i.index as u32, + side: side.into(), + } }) + .collect(); + let schema = f.schema().as_ref().try_into()?; + Ok(protobuf::JoinFilter { + expression: Some(expression), + column_indices, + schema: Some(schema), }) - .map_or(Ok(None), |v: Result| v.map(Some))?; - - let partition_mode = match exec.partition_mode() { - StreamJoinPartitionMode::SinglePartition => { - protobuf::StreamPartitionMode::SinglePartition - } - StreamJoinPartitionMode::Partitioned => { - protobuf::StreamPartitionMode::PartitionedExec - } - }; + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + let partition_mode = match exec.partition_mode() { + PartitionMode::CollectLeft => protobuf::PartitionMode::CollectLeft, + PartitionMode::Partitioned => protobuf::PartitionMode::Partitioned, + PartitionMode::Auto => protobuf::PartitionMode::Auto, + }; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new( + protobuf::HashJoinExecNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + on, + join_type: join_type.into(), + partition_mode: partition_mode.into(), + null_equality: null_equality.into(), + filter, + projection: exec.projection.as_ref().map_or_else(Vec::new, |v| { + v.iter().map(|x| *x as u32).collect::>() + }), + }, + ))), + }) + } - let left_sort_exprs = exec - .left_sort_exprs() - .map(|exprs| { - exprs - .iter() - .map(|expr| { - Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), - asc: !expr.options.descending, - nulls_first: expr.options.nulls_first, - }) - }) - .collect::>>() + fn try_from_symmetric_hash_join_exec( + exec: &SymmetricHashJoinExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.left().to_owned(), + extension_codec, + )?; + let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.right().to_owned(), + extension_codec, + )?; + let on = exec + .on() + .iter() + .map(|tuple| { + let l = serialize_physical_expr(&tuple.0, extension_codec)?; + let r = serialize_physical_expr(&tuple.1, extension_codec)?; + Ok::<_, DataFusionError>(protobuf::JoinOn { + left: Some(l), + right: Some(r), }) - .transpose()? - .unwrap_or(vec![]); - - let right_sort_exprs = exec - .right_sort_exprs() - .map(|exprs| { - exprs - .iter() - .map(|expr| { - Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), - asc: !expr.options.descending, - nulls_first: expr.options.nulls_first, - }) - }) - .collect::>>() + }) + .collect::>()?; + let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let null_equality: protobuf::NullEquality = exec.null_equality().into(); + let filter = exec + .filter() + .as_ref() + .map(|f| { + let expression = + serialize_physical_expr(f.expression(), extension_codec)?; + let column_indices = f + .column_indices() + .iter() + .map(|i| { + let side: protobuf::JoinSide = i.side.to_owned().into(); + protobuf::ColumnIndex { + index: i.index as u32, + side: side.into(), + } + }) + .collect(); + let schema = f.schema().as_ref().try_into()?; + Ok(protobuf::JoinFilter { + expression: Some(expression), + column_indices, + schema: Some(schema), }) - .transpose()? - .unwrap_or(vec![]); - - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::SymmetricHashJoin(Box::new( - protobuf::SymmetricHashJoinExecNode { - left: Some(Box::new(left)), - right: Some(Box::new(right)), - on, - join_type: join_type.into(), - partition_mode: partition_mode.into(), - null_equals_null: exec.null_equals_null(), - left_sort_exprs, - right_sort_exprs, - filter, - }, - ))), - }); - } - - if let Some(exec) = plan.downcast_ref::() { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.left().to_owned(), - extension_codec, - )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.right().to_owned(), - extension_codec, - )?; - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::CrossJoin(Box::new( - protobuf::CrossJoinExecNode { - left: Some(Box::new(left)), - right: Some(Box::new(right)), - }, - ))), - }); - } - if let Some(exec) = plan.downcast_ref::() { - let groups: Vec = exec - .group_expr() - .groups() - .iter() - .flatten() - .copied() - .collect(); + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; - let group_names = exec - .group_expr() - .expr() - .iter() - .map(|expr| expr.1.to_owned()) - .collect(); + let partition_mode = match exec.partition_mode() { + StreamJoinPartitionMode::SinglePartition => { + protobuf::StreamPartitionMode::SinglePartition + } + StreamJoinPartitionMode::Partitioned => { + protobuf::StreamPartitionMode::PartitionedExec + } + }; - let filter = exec - .filter_expr() - .iter() - .map(|expr| serialize_maybe_filter(expr.to_owned(), extension_codec)) - .collect::>>()?; + let left_sort_exprs = exec + .left_sort_exprs() + .map(|exprs| { + exprs + .iter() + .map(|expr| { + Ok(protobuf::PhysicalSortExprNode { + expr: Some(Box::new(serialize_physical_expr( + &expr.expr, + extension_codec, + )?)), + asc: !expr.options.descending, + nulls_first: expr.options.nulls_first, + }) + }) + .collect::>>() + }) + .transpose()? + .unwrap_or(vec![]); + + let right_sort_exprs = exec + .right_sort_exprs() + .map(|exprs| { + exprs + .iter() + .map(|expr| { + Ok(protobuf::PhysicalSortExprNode { + expr: Some(Box::new(serialize_physical_expr( + &expr.expr, + extension_codec, + )?)), + asc: !expr.options.descending, + nulls_first: expr.options.nulls_first, + }) + }) + .collect::>>() + }) + .transpose()? + .unwrap_or(vec![]); + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::SymmetricHashJoin(Box::new( + protobuf::SymmetricHashJoinExecNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + on, + join_type: join_type.into(), + partition_mode: partition_mode.into(), + null_equality: null_equality.into(), + left_sort_exprs, + right_sort_exprs, + filter, + }, + ))), + }) + } - let agg = exec - .aggr_expr() - .iter() - .map(|expr| { - serialize_physical_aggr_expr(expr.to_owned(), extension_codec) + fn try_from_sort_merge_join_exec( + exec: &SortMergeJoinExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.left().to_owned(), + extension_codec, + )?; + let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.right().to_owned(), + extension_codec, + )?; + let on = exec + .on() + .iter() + .map(|tuple| { + let l = serialize_physical_expr(&tuple.0, extension_codec)?; + let r = serialize_physical_expr(&tuple.1, extension_codec)?; + Ok::<_, DataFusionError>(protobuf::JoinOn { + left: Some(l), + right: Some(r), }) - .collect::>>()?; - - let agg_names = exec - .aggr_expr() - .iter() - .map(|expr| expr.name().to_string()) - .collect::>(); - - let agg_mode = match exec.mode() { - AggregateMode::Partial => protobuf::AggregateMode::Partial, - AggregateMode::Final => protobuf::AggregateMode::Final, - AggregateMode::FinalPartitioned => { - protobuf::AggregateMode::FinalPartitioned - } - AggregateMode::Single => protobuf::AggregateMode::Single, - AggregateMode::SinglePartitioned => { - protobuf::AggregateMode::SinglePartitioned - } - }; - let input_schema = exec.input_schema(); - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), - extension_codec, - )?; - - let null_expr = exec - .group_expr() - .null_expr() - .iter() - .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) - .collect::>>()?; - - let group_expr = exec - .group_expr() - .expr() - .iter() - .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) - .collect::>>()?; + }) + .collect::>()?; + let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let null_equality: protobuf::NullEquality = exec.null_equality().into(); + let filter = exec + .filter() + .as_ref() + .map(|f| { + let expression = + serialize_physical_expr(f.expression(), extension_codec)?; + let column_indices = f + .column_indices() + .iter() + .map(|i| { + let side: protobuf::JoinSide = i.side.to_owned().into(); + protobuf::ColumnIndex { + index: i.index as u32, + side: side.into(), + } + }) + .collect(); + let schema = f.schema().as_ref().try_into()?; + Ok(protobuf::JoinFilter { + expression: Some(expression), + column_indices, + schema: Some(schema), + }) + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + let sort_options = exec + .sort_options() + .iter() + .map( + |SortOptions { + descending, + nulls_first, + }| { + SortExprNode { + expr: None, + asc: !*descending, + nulls_first: *nulls_first, + } + }, + ) + .collect(); + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::SortMergeJoin(Box::new( + protobuf::SortMergeJoinExecNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + on, + join_type: join_type.into(), + null_equality: null_equality.into(), + filter, + sort_options, + }, + ))), + }) + } - let limit = exec.limit().map(|value| protobuf::AggLimit { - limit: value as u64, - }); + fn try_from_cross_join_exec( + exec: &CrossJoinExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.left().to_owned(), + extension_codec, + )?; + let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.right().to_owned(), + extension_codec, + )?; + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::CrossJoin(Box::new( + protobuf::CrossJoinExecNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + }, + ))), + }) + } - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Aggregate(Box::new( - protobuf::AggregateExecNode { - group_expr, - group_expr_name: group_names, - aggr_expr: agg, - filter_expr: filter, - aggr_expr_name: agg_names, - mode: agg_mode as i32, - input: Some(Box::new(input)), - input_schema: Some(input_schema.as_ref().try_into()?), - null_expr, - groups, - limit, - }, - ))), - }); - } + fn try_from_aggregate_exec( + exec: &AggregateExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let groups: Vec = exec + .group_expr() + .groups() + .iter() + .flatten() + .copied() + .collect(); + + let group_names = exec + .group_expr() + .expr() + .iter() + .map(|expr| expr.1.to_owned()) + .collect(); + + let filter = exec + .filter_expr() + .iter() + .map(|expr| serialize_maybe_filter(expr.to_owned(), extension_codec)) + .collect::>>()?; + + let agg = exec + .aggr_expr() + .iter() + .map(|expr| serialize_physical_aggr_expr(expr.to_owned(), extension_codec)) + .collect::>>()?; + + let agg_names = exec + .aggr_expr() + .iter() + .map(|expr| expr.name().to_string()) + .collect::>(); + + let agg_mode = match exec.mode() { + AggregateMode::Partial => protobuf::AggregateMode::Partial, + AggregateMode::Final => protobuf::AggregateMode::Final, + AggregateMode::FinalPartitioned => protobuf::AggregateMode::FinalPartitioned, + AggregateMode::Single => protobuf::AggregateMode::Single, + AggregateMode::SinglePartitioned => { + protobuf::AggregateMode::SinglePartitioned + } + }; + let input_schema = exec.input_schema(); + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + + let null_expr = exec + .group_expr() + .null_expr() + .iter() + .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) + .collect::>>()?; + + let group_expr = exec + .group_expr() + .expr() + .iter() + .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) + .collect::>>()?; + + let limit = exec.limit().map(|value| protobuf::AggLimit { + limit: value as u64, + }); + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Aggregate(Box::new( + protobuf::AggregateExecNode { + group_expr, + group_expr_name: group_names, + aggr_expr: agg, + filter_expr: filter, + aggr_expr_name: agg_names, + mode: agg_mode as i32, + input: Some(Box::new(input)), + input_schema: Some(input_schema.as_ref().try_into()?), + null_expr, + groups, + limit, + }, + ))), + }) + } - if let Some(empty) = plan.downcast_ref::() { - let schema = empty.schema().as_ref().try_into()?; - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Empty( - protobuf::EmptyExecNode { - schema: Some(schema), - }, - )), - }); - } + fn try_from_empty_exec( + empty: &EmptyExec, + _extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let schema = empty.schema().as_ref().try_into()?; + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Empty(protobuf::EmptyExecNode { + schema: Some(schema), + })), + }) + } - if let Some(empty) = plan.downcast_ref::() { - let schema = empty.schema().as_ref().try_into()?; - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::PlaceholderRow( - protobuf::PlaceholderRowExecNode { - schema: Some(schema), - }, - )), - }); - } + fn try_from_placeholder_row_exec( + empty: &PlaceholderRowExec, + _extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let schema = empty.schema().as_ref().try_into()?; + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::PlaceholderRow( + protobuf::PlaceholderRowExecNode { + schema: Some(schema), + }, + )), + }) + } - if let Some(coalesce_batches) = plan.downcast_ref::() { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - coalesce_batches.input().to_owned(), - extension_codec, - )?; - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::CoalesceBatches(Box::new( - protobuf::CoalesceBatchesExecNode { - input: Some(Box::new(input)), - target_batch_size: coalesce_batches.target_batch_size() as u32, - fetch: coalesce_batches.fetch().map(|n| n as u32), - }, - ))), - }); - } + fn try_from_coalesce_batches_exec( + coalesce_batches: &CoalesceBatchesExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + coalesce_batches.input().to_owned(), + extension_codec, + )?; + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::CoalesceBatches(Box::new( + protobuf::CoalesceBatchesExecNode { + input: Some(Box::new(input)), + target_batch_size: coalesce_batches.target_batch_size() as u32, + fetch: coalesce_batches.fetch().map(|n| n as u32), + }, + ))), + }) + } - if let Some(data_source_exec) = plan.downcast_ref::() { - let data_source = data_source_exec.data_source(); - if let Some(maybe_csv) = data_source.as_any().downcast_ref::() - { - let source = maybe_csv.file_source(); - if let Some(csv_config) = source.as_any().downcast_ref::() { - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::CsvScan( - protobuf::CsvScanExecNode { - base_conf: Some(serialize_file_scan_config( - maybe_csv, - extension_codec, - )?), - has_header: csv_config.has_header(), - delimiter: byte_to_string( - csv_config.delimiter(), - "delimiter", - )?, - quote: byte_to_string(csv_config.quote(), "quote")?, - optional_escape: if let Some(escape) = csv_config.escape() - { - Some( - protobuf::csv_scan_exec_node::OptionalEscape::Escape( - byte_to_string(escape, "escape")?, - ), - ) - } else { - None - }, - optional_comment: if let Some(comment) = - csv_config.comment() - { - Some(protobuf::csv_scan_exec_node::OptionalComment::Comment( + fn try_from_data_source_exec( + data_source_exec: &DataSourceExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let data_source = data_source_exec.data_source(); + if let Some(maybe_csv) = data_source.as_any().downcast_ref::() { + let source = maybe_csv.file_source(); + if let Some(csv_config) = source.as_any().downcast_ref::() { + return Ok(Some(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::CsvScan( + protobuf::CsvScanExecNode { + base_conf: Some(serialize_file_scan_config( + maybe_csv, + extension_codec, + )?), + has_header: csv_config.has_header(), + delimiter: byte_to_string( + csv_config.delimiter(), + "delimiter", + )?, + quote: byte_to_string(csv_config.quote(), "quote")?, + optional_escape: if let Some(escape) = csv_config.escape() { + Some( + protobuf::csv_scan_exec_node::OptionalEscape::Escape( + byte_to_string(escape, "escape")?, + ), + ) + } else { + None + }, + optional_comment: if let Some(comment) = csv_config.comment() + { + Some(protobuf::csv_scan_exec_node::OptionalComment::Comment( byte_to_string(comment, "comment")?, )) - } else { - None - }, - newlines_in_values: maybe_csv.newlines_in_values(), + } else { + None }, - )), - }); - } + newlines_in_values: maybe_csv.newlines_in_values(), + truncate_rows: csv_config.truncate_rows(), + }, + )), + })); } } - if let Some(data_source_exec) = plan.downcast_ref::() { - let data_source = data_source_exec.data_source(); - if let Some(scan_conf) = data_source.as_any().downcast_ref::() - { - let source = scan_conf.file_source(); - if let Some(_json_source) = source.as_any().downcast_ref::() { - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::JsonScan( - protobuf::JsonScanExecNode { - base_conf: Some(serialize_file_scan_config( - scan_conf, - extension_codec, - )?), - }, - )), - }); - } + if let Some(scan_conf) = data_source.as_any().downcast_ref::() { + let source = scan_conf.file_source(); + if let Some(_json_source) = source.as_any().downcast_ref::() { + return Ok(Some(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::JsonScan( + protobuf::JsonScanExecNode { + base_conf: Some(serialize_file_scan_config( + scan_conf, + extension_codec, + )?), + }, + )), + })); } } #[cfg(feature = "parquet")] - if let Some(exec) = plan.downcast_ref::() { - if let Some((maybe_parquet, conf)) = - exec.downcast_to_file_source::() - { - let predicate = conf - .predicate() - .map(|pred| serialize_physical_expr(pred, extension_codec)) - .transpose()?; - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::ParquetScan( - protobuf::ParquetScanExecNode { + if let Some((maybe_parquet, conf)) = + data_source_exec.downcast_to_file_source::() + { + let predicate = conf + .filter() + .map(|pred| serialize_physical_expr(&pred, extension_codec)) + .transpose()?; + return Ok(Some(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::ParquetScan( + protobuf::ParquetScanExecNode { + base_conf: Some(serialize_file_scan_config( + maybe_parquet, + extension_codec, + )?), + predicate, + parquet_options: Some(conf.table_parquet_options().try_into()?), + }, + )), + })); + } + + #[cfg(feature = "avro")] + if let Some(maybe_avro) = data_source.as_any().downcast_ref::() { + let source = maybe_avro.file_source(); + if source.as_any().downcast_ref::().is_some() { + return Ok(Some(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::AvroScan( + protobuf::AvroScanExecNode { base_conf: Some(serialize_file_scan_config( - maybe_parquet, + maybe_avro, extension_codec, )?), - predicate, - parquet_options: Some( - conf.table_parquet_options().try_into()?, - ), }, )), - }); + })); } } - #[cfg(feature = "avro")] - if let Some(data_source_exec) = plan.downcast_ref::() { - let data_source = data_source_exec.data_source(); - if let Some(maybe_avro) = - data_source.as_any().downcast_ref::() - { - let source = maybe_avro.file_source(); - if source.as_any().downcast_ref::().is_some() { - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::AvroScan( - protobuf::AvroScanExecNode { - base_conf: Some(serialize_file_scan_config( - maybe_avro, - extension_codec, - )?), - }, - )), - }); - } - } - } + if let Some(source_conf) = + data_source.as_any().downcast_ref::() + { + let proto_partitions = source_conf + .partitions() + .iter() + .map(|p| serialize_record_batches(p)) + .collect::>>()?; - if let Some(exec) = plan.downcast_ref::() { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), - extension_codec, - )?; - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Merge(Box::new( - protobuf::CoalescePartitionsExecNode { - input: Some(Box::new(input)), + let proto_schema: protobuf::Schema = + source_conf.original_schema().as_ref().try_into()?; + + let proto_projection = source_conf + .projection() + .as_ref() + .map_or_else(Vec::new, |v| { + v.iter().map(|x| *x as u32).collect::>() + }); + + let proto_sort_information = source_conf + .sort_information() + .iter() + .map(|ordering| { + let sort_exprs = serialize_physical_sort_exprs( + ordering.to_owned(), + extension_codec, + )?; + Ok::<_, DataFusionError>(protobuf::PhysicalSortExprNodeCollection { + physical_sort_expr_nodes: sort_exprs, + }) + }) + .collect::, _>>()?; + + return Ok(Some(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::MemoryScan( + protobuf::MemoryScanExecNode { + partitions: proto_partitions, + schema: Some(proto_schema), + projection: proto_projection, + sort_information: proto_sort_information, + show_sizes: source_conf.show_sizes(), + fetch: source_conf.fetch().map(|f| f as u32), }, - ))), - }); + )), + })); } - if let Some(exec) = plan.downcast_ref::() { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), - extension_codec, - )?; + Ok(None) + } + + fn try_from_coalesce_partitions_exec( + exec: &CoalescePartitionsExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Merge(Box::new( + protobuf::CoalescePartitionsExecNode { + input: Some(Box::new(input)), + fetch: exec.fetch().map(|f| f as u32), + }, + ))), + }) + } + + fn try_from_repartition_exec( + exec: &RepartitionExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + + let pb_partitioning = + serialize_partitioning(exec.partitioning(), extension_codec)?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Repartition(Box::new( + protobuf::RepartitionExecNode { + input: Some(Box::new(input)), + partitioning: Some(pb_partitioning), + }, + ))), + }) + } - let pb_partitioning = - serialize_partitioning(exec.partitioning(), extension_codec)?; + fn try_from_sort_exec( + exec: &SortExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + let expr = exec + .expr() + .iter() + .map(|expr| { + let sort_expr = Box::new(protobuf::PhysicalSortExprNode { + expr: Some(Box::new(serialize_physical_expr( + &expr.expr, + extension_codec, + )?)), + asc: !expr.options.descending, + nulls_first: expr.options.nulls_first, + }); + Ok(protobuf::PhysicalExprNode { + expr_type: Some(ExprType::Sort(sort_expr)), + }) + }) + .collect::>>()?; + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Sort(Box::new( + protobuf::SortExecNode { + input: Some(Box::new(input)), + expr, + fetch: match exec.fetch() { + Some(n) => n as i64, + _ => -1, + }, + preserve_partitioning: exec.preserve_partitioning(), + }, + ))), + }) + } - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Repartition(Box::new( - protobuf::RepartitionExecNode { - input: Some(Box::new(input)), - partitioning: Some(pb_partitioning), - }, - ))), - }); + fn try_from_union_exec( + union: &UnionExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let mut inputs: Vec = vec![]; + for input in union.inputs() { + inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( + input.to_owned(), + extension_codec, + )?); } + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Union(protobuf::UnionExecNode { + inputs, + })), + }) + } - if let Some(exec) = plan.downcast_ref::() { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), + fn try_from_interleave_exec( + interleave: &InterleaveExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let mut inputs: Vec = vec![]; + for input in interleave.inputs() { + inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( + input.to_owned(), extension_codec, - )?; - let expr = exec - .expr() - .iter() - .map(|expr| { - let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), - asc: !expr.options.descending, - nulls_first: expr.options.nulls_first, - }); - Ok(protobuf::PhysicalExprNode { - expr_type: Some(ExprType::Sort(sort_expr)), + )?); + } + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Interleave( + protobuf::InterleaveExecNode { inputs }, + )), + }) + } + + fn try_from_sort_preserving_merge_exec( + exec: &SortPreservingMergeExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + let expr = exec + .expr() + .iter() + .map(|expr| { + let sort_expr = Box::new(protobuf::PhysicalSortExprNode { + expr: Some(Box::new(serialize_physical_expr( + &expr.expr, + extension_codec, + )?)), + asc: !expr.options.descending, + nulls_first: expr.options.nulls_first, + }); + Ok(protobuf::PhysicalExprNode { + expr_type: Some(ExprType::Sort(sort_expr)), + }) + }) + .collect::>>()?; + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::SortPreservingMerge(Box::new( + protobuf::SortPreservingMergeExecNode { + input: Some(Box::new(input)), + expr, + fetch: exec.fetch().map(|f| f as i64).unwrap_or(-1), + }, + ))), + }) + } + + fn try_from_nested_loop_join_exec( + exec: &NestedLoopJoinExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.left().to_owned(), + extension_codec, + )?; + let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.right().to_owned(), + extension_codec, + )?; + + let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let filter = exec + .filter() + .as_ref() + .map(|f| { + let expression = + serialize_physical_expr(f.expression(), extension_codec)?; + let column_indices = f + .column_indices() + .iter() + .map(|i| { + let side: protobuf::JoinSide = i.side.to_owned().into(); + protobuf::ColumnIndex { + index: i.index as u32, + side: side.into(), + } }) + .collect(); + let schema = f.schema().as_ref().try_into()?; + Ok(protobuf::JoinFilter { + expression: Some(expression), + column_indices, + schema: Some(schema), }) - .collect::>>()?; - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Sort(Box::new( - protobuf::SortExecNode { - input: Some(Box::new(input)), - expr, - fetch: match exec.fetch() { - Some(n) => n as i64, - _ => -1, - }, - preserve_partitioning: exec.preserve_partitioning(), - }, - ))), - }); - } + }) + .map_or(Ok(None), |v: Result| v.map(Some))?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::NestedLoopJoin(Box::new( + protobuf::NestedLoopJoinExecNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + join_type: join_type.into(), + filter, + projection: exec.projection().map_or_else(Vec::new, |v| { + v.iter().map(|x| *x as u32).collect::>() + }), + }, + ))), + }) + } - if let Some(union) = plan.downcast_ref::() { - let mut inputs: Vec = vec![]; - for input in union.inputs() { - inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( - input.to_owned(), - extension_codec, - )?); - } - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Union( - protobuf::UnionExecNode { inputs }, - )), - }); - } + fn try_from_window_agg_exec( + exec: &WindowAggExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + + let window_expr = exec + .window_expr() + .iter() + .map(|e| serialize_physical_window_expr(e, extension_codec)) + .collect::>>()?; + + let partition_keys = exec + .partition_keys() + .iter() + .map(|e| serialize_physical_expr(e, extension_codec)) + .collect::>>()?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Window(Box::new( + protobuf::WindowAggExecNode { + input: Some(Box::new(input)), + window_expr, + partition_keys, + input_order_mode: None, + }, + ))), + }) + } - if let Some(interleave) = plan.downcast_ref::() { - let mut inputs: Vec = vec![]; - for input in interleave.inputs() { - inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( - input.to_owned(), - extension_codec, - )?); + fn try_from_bounded_window_agg_exec( + exec: &BoundedWindowAggExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + + let window_expr = exec + .window_expr() + .iter() + .map(|e| serialize_physical_window_expr(e, extension_codec)) + .collect::>>()?; + + let partition_keys = exec + .partition_keys() + .iter() + .map(|e| serialize_physical_expr(e, extension_codec)) + .collect::>>()?; + + let input_order_mode = match &exec.input_order_mode { + InputOrderMode::Linear => { + window_agg_exec_node::InputOrderMode::Linear(protobuf::EmptyMessage {}) } - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Interleave( - protobuf::InterleaveExecNode { inputs }, - )), - }); - } + InputOrderMode::PartiallySorted(columns) => { + window_agg_exec_node::InputOrderMode::PartiallySorted( + protobuf::PartiallySortedInputOrderMode { + columns: columns.iter().map(|c| *c as u64).collect(), + }, + ) + } + InputOrderMode::Sorted => { + window_agg_exec_node::InputOrderMode::Sorted(protobuf::EmptyMessage {}) + } + }; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Window(Box::new( + protobuf::WindowAggExecNode { + input: Some(Box::new(input)), + window_expr, + partition_keys, + input_order_mode: Some(input_order_mode), + }, + ))), + }) + } - if let Some(exec) = plan.downcast_ref::() { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + fn try_from_data_sink_exec( + exec: &DataSinkExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input: protobuf::PhysicalPlanNode = + protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(), extension_codec, )?; - let expr = exec - .expr() - .iter() - .map(|expr| { - let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), - asc: !expr.options.descending, - nulls_first: expr.options.nulls_first, - }); - Ok(protobuf::PhysicalExprNode { - expr_type: Some(ExprType::Sort(sort_expr)), + let sort_order = match exec.sort_order() { + Some(requirements) => { + let expr = requirements + .iter() + .map(|requirement| { + let expr: PhysicalSortExpr = requirement.to_owned().into(); + let sort_expr = protobuf::PhysicalSortExprNode { + expr: Some(Box::new(serialize_physical_expr( + &expr.expr, + extension_codec, + )?)), + asc: !expr.options.descending, + nulls_first: expr.options.nulls_first, + }; + Ok(sort_expr) }) + .collect::>>()?; + Some(protobuf::PhysicalSortExprNodeCollection { + physical_sort_expr_nodes: expr, }) - .collect::>>()?; - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::SortPreservingMerge( - Box::new(protobuf::SortPreservingMergeExecNode { + } + None => None, + }; + + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(Some(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::JsonSink(Box::new( + protobuf::JsonSinkExecNode { input: Some(Box::new(input)), - expr, - fetch: exec.fetch().map(|f| f as i64).unwrap_or(-1), - }), - )), - }); + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + })); } - if let Some(exec) = plan.downcast_ref::() { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.left().to_owned(), - extension_codec, - )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.right().to_owned(), - extension_codec, - )?; - - let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); - let filter = exec - .filter() - .as_ref() - .map(|f| { - let expression = - serialize_physical_expr(f.expression(), extension_codec)?; - let column_indices = f - .column_indices() - .iter() - .map(|i| { - let side: protobuf::JoinSide = i.side.to_owned().into(); - protobuf::ColumnIndex { - index: i.index as u32, - side: side.into(), - } - }) - .collect(); - let schema = f.schema().as_ref().try_into()?; - Ok(protobuf::JoinFilter { - expression: Some(expression), - column_indices, - schema: Some(schema), - }) - }) - .map_or(Ok(None), |v: Result| v.map(Some))?; - - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::NestedLoopJoin(Box::new( - protobuf::NestedLoopJoinExecNode { - left: Some(Box::new(left)), - right: Some(Box::new(right)), - join_type: join_type.into(), - filter, - projection: exec.projection().map_or_else(Vec::new, |v| { - v.iter().map(|x| *x as u32).collect::>() - }), + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(Some(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::CsvSink(Box::new( + protobuf::CsvSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, }, ))), - }); + })); } - if let Some(exec) = plan.downcast_ref::() { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), - extension_codec, - )?; - - let window_expr = exec - .window_expr() - .iter() - .map(|e| serialize_physical_window_expr(e, extension_codec)) - .collect::>>()?; - - let partition_keys = exec - .partition_keys() - .iter() - .map(|e| serialize_physical_expr(e, extension_codec)) - .collect::>>()?; - - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Window(Box::new( - protobuf::WindowAggExecNode { + #[cfg(feature = "parquet")] + if let Some(sink) = exec.sink().as_any().downcast_ref::() { + return Ok(Some(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::ParquetSink(Box::new( + protobuf::ParquetSinkExecNode { input: Some(Box::new(input)), - window_expr, - partition_keys, - input_order_mode: None, + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, }, ))), - }); + })); } - if let Some(exec) = plan.downcast_ref::() { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), - extension_codec, - )?; + // If unknown DataSink then let extension handle it + Ok(None) + } - let window_expr = exec - .window_expr() - .iter() - .map(|e| serialize_physical_window_expr(e, extension_codec)) - .collect::>>()?; + fn try_from_unnest_exec( + exec: &UnnestExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Unnest(Box::new( + protobuf::UnnestExecNode { + input: Some(Box::new(input)), + schema: Some(exec.schema().try_into()?), + list_type_columns: exec + .list_column_indices() + .iter() + .map(|c| ProtoListUnnest { + index_in_input_schema: c.index_in_input_schema as _, + depth: c.depth as _, + }) + .collect(), + struct_type_columns: exec + .struct_column_indices() + .iter() + .map(|c| *c as _) + .collect(), + options: Some(exec.options().into()), + }, + ))), + }) + } - let partition_keys = exec - .partition_keys() - .iter() - .map(|e| serialize_physical_expr(e, extension_codec)) - .collect::>>()?; + fn try_from_cooperative_exec( + exec: &CooperativeExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Cooperative(Box::new( + protobuf::CooperativeExecNode { + input: Some(Box::new(input)), + }, + ))), + }) + } - let input_order_mode = match &exec.input_order_mode { - InputOrderMode::Linear => window_agg_exec_node::InputOrderMode::Linear( - protobuf::EmptyMessage {}, - ), - InputOrderMode::PartiallySorted(columns) => { - window_agg_exec_node::InputOrderMode::PartiallySorted( - protobuf::PartiallySortedInputOrderMode { - columns: columns.iter().map(|c| *c as u64).collect(), - }, - ) - } - InputOrderMode::Sorted => window_agg_exec_node::InputOrderMode::Sorted( - protobuf::EmptyMessage {}, - ), - }; + fn str_to_generate_series_name(name: &str) -> Result { + match name { + "generate_series" => Ok(protobuf::GenerateSeriesName::GsGenerateSeries), + "range" => Ok(protobuf::GenerateSeriesName::GsRange), + _ => internal_err!("unknown name: {name}"), + } + } - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Window(Box::new( - protobuf::WindowAggExecNode { - input: Some(Box::new(input)), - window_expr, - partition_keys, - input_order_mode: Some(input_order_mode), + fn try_from_lazy_memory_exec(exec: &LazyMemoryExec) -> Result> { + let generators = exec.generators(); + + // ensure we only have one generator + let [generator] = generators.as_slice() else { + return Ok(None); + }; + + let generator_guard = generator.read(); + + // Try to downcast to different generate_series types + if let Some(empty_gen) = generator_guard.as_any().downcast_ref::() { + let schema = exec.schema(); + let node = protobuf::GenerateSeriesNode { + schema: Some(schema.as_ref().try_into()?), + target_batch_size: 8192, // Default batch size + args: Some(protobuf::generate_series_node::Args::ContainsNull( + protobuf::GenerateSeriesArgsContainsNull { + name: Self::str_to_generate_series_name(empty_gen.name())? as i32, }, - ))), - }); + )), + }; + + return Ok(Some(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::GenerateSeries(node)), + })); } - if let Some(exec) = plan.downcast_ref::() { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), - extension_codec, - )?; - let sort_order = match exec.sort_order() { - Some(requirements) => { - let expr = requirements - .iter() - .map(|requirement| { - let expr: PhysicalSortExpr = requirement.to_owned().into(); - let sort_expr = protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), - asc: !expr.options.descending, - nulls_first: expr.options.nulls_first, - }; - Ok(sort_expr) - }) - .collect::>>()?; - Some(protobuf::PhysicalSortExprNodeCollection { - physical_sort_expr_nodes: expr, - }) - } - None => None, + if let Some(int_64) = generator_guard + .as_any() + .downcast_ref::>() + { + let schema = exec.schema(); + let node = protobuf::GenerateSeriesNode { + schema: Some(schema.as_ref().try_into()?), + target_batch_size: int_64.batch_size() as u32, + args: Some(protobuf::generate_series_node::Args::Int64Args( + protobuf::GenerateSeriesArgsInt64 { + start: *int_64.start(), + end: *int_64.end(), + step: *int_64.step(), + include_end: int_64.include_end(), + name: Self::str_to_generate_series_name(int_64.name())? as i32, + }, + )), }; - if let Some(sink) = exec.sink().as_any().downcast_ref::() { - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::JsonSink(Box::new( - protobuf::JsonSinkExecNode { - input: Some(Box::new(input)), - sink: Some(sink.try_into()?), - sink_schema: Some(exec.schema().as_ref().try_into()?), - sort_order, - }, - ))), - }); - } - - if let Some(sink) = exec.sink().as_any().downcast_ref::() { - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::CsvSink(Box::new( - protobuf::CsvSinkExecNode { - input: Some(Box::new(input)), - sink: Some(sink.try_into()?), - sink_schema: Some(exec.schema().as_ref().try_into()?), - sort_order, - }, - ))), - }); - } + return Ok(Some(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::GenerateSeries(node)), + })); + } - #[cfg(feature = "parquet")] - if let Some(sink) = exec.sink().as_any().downcast_ref::() { - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::ParquetSink(Box::new( - protobuf::ParquetSinkExecNode { - input: Some(Box::new(input)), - sink: Some(sink.try_into()?), - sink_schema: Some(exec.schema().as_ref().try_into()?), - sort_order, - }, - ))), - }); - } + if let Some(timestamp_args) = generator_guard + .as_any() + .downcast_ref::>() + { + let schema = exec.schema(); - // If unknown DataSink then let extension handle it - } + let start = timestamp_args.start().value(); + let end = timestamp_args.end().value(); - if let Some(exec) = plan.downcast_ref::() { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), - extension_codec, - )?; + let step_value = timestamp_args.step(); - return Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Unnest(Box::new( - protobuf::UnnestExecNode { - input: Some(Box::new(input)), - schema: Some(exec.schema().try_into()?), - list_type_columns: exec - .list_column_indices() - .iter() - .map(|c| ProtoListUnnest { - index_in_input_schema: c.index_in_input_schema as _, - depth: c.depth as _, - }) - .collect(), - struct_type_columns: exec - .struct_column_indices() - .iter() - .map(|c| *c as _) - .collect(), - options: Some(exec.options().into()), - }, - ))), + let step = Some(datafusion_proto_common::IntervalMonthDayNanoValue { + months: step_value.months, + days: step_value.days, + nanos: step_value.nanoseconds, }); - } + let include_end = timestamp_args.include_end(); + let name = Self::str_to_generate_series_name(timestamp_args.name())? as i32; + + let args = match timestamp_args.current().tz_str() { + Some(tz) => protobuf::generate_series_node::Args::TimestampArgs( + protobuf::GenerateSeriesArgsTimestamp { + start, + end, + step, + include_end, + name, + tz: Some(tz.to_string()), + }, + ), + None => protobuf::generate_series_node::Args::DateArgs( + protobuf::GenerateSeriesArgsDate { + start, + end, + step, + include_end, + name, + }, + ), + }; - let mut buf: Vec = vec![]; - match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) { - Ok(_) => { - let inputs: Vec = plan_clone - .children() - .into_iter() - .cloned() - .map(|i| { - protobuf::PhysicalPlanNode::try_from_physical_plan( - i, - extension_codec, - ) - }) - .collect::>()?; + let node = protobuf::GenerateSeriesNode { + schema: Some(schema.as_ref().try_into()?), + target_batch_size: timestamp_args.batch_size() as u32, + args: Some(args), + }; - Ok(protobuf::PhysicalPlanNode { - physical_plan_type: Some(PhysicalPlanType::Extension( - protobuf::PhysicalExtensionNode { node: buf, inputs }, - )), - }) - } - Err(e) => internal_err!( - "Unsupported plan and extension codec failed with [{e}]. Plan: {plan_clone:?}" - ), + return Ok(Some(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::GenerateSeries(node)), + })); } + + Ok(None) } } @@ -2166,8 +3218,8 @@ pub trait AsExecutionPlan: Debug + Send + Sync + Clone { fn try_into_physical_plan( &self, - registry: &dyn FunctionRegistry, - runtime: &RuntimeEnv, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, ) -> Result>; @@ -2184,7 +3236,7 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync { &self, buf: &[u8], inputs: &[Arc], - registry: &dyn FunctionRegistry, + ctx: &TaskContext, ) -> Result>; fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()>; @@ -2240,7 +3292,7 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { &self, _buf: &[u8], _inputs: &[Arc], - _registry: &dyn FunctionRegistry, + _ctx: &TaskContext, ) -> Result> { not_impl_err!("PhysicalExtensionCodec is not provided") } @@ -2254,14 +3306,126 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { } } +/// DataEncoderTuple captures the position of the encoder +/// in the codec list that was used to encode the data and actual encoded data +#[derive(Clone, PartialEq, prost::Message)] +struct DataEncoderTuple { + /// The position of encoder used to encode data + /// (to be used for decoding) + #[prost(uint32, tag = 1)] + pub encoder_position: u32, + + #[prost(bytes, tag = 2)] + pub blob: Vec, +} + +/// A PhysicalExtensionCodec that tries one of multiple inner codecs +/// until one works +#[derive(Debug)] +pub struct ComposedPhysicalExtensionCodec { + codecs: Vec>, +} + +impl ComposedPhysicalExtensionCodec { + // Position in this codecs list is important as it will be used for decoding. + // If new codec is added it should go to last position. + pub fn new(codecs: Vec>) -> Self { + Self { codecs } + } + + fn decode_protobuf( + &self, + buf: &[u8], + decode: impl FnOnce(&dyn PhysicalExtensionCodec, &[u8]) -> Result, + ) -> Result { + let proto = + DataEncoderTuple::decode(buf).map_err(|e| internal_datafusion_err!("{e}"))?; + + let codec = self.codecs.get(proto.encoder_position as usize).ok_or( + internal_datafusion_err!("Can't find required codec in codec list"), + )?; + + decode(codec.as_ref(), &proto.blob) + } + + fn encode_protobuf( + &self, + buf: &mut Vec, + mut encode: impl FnMut(&dyn PhysicalExtensionCodec, &mut Vec) -> Result<()>, + ) -> Result<()> { + let mut data = vec![]; + let mut last_err = None; + let mut encoder_position = None; + + // find the encoder + for (position, codec) in self.codecs.iter().enumerate() { + match encode(codec.as_ref(), &mut data) { + Ok(_) => { + encoder_position = Some(position as u32); + break; + } + Err(err) => last_err = Some(err), + } + } + + let encoder_position = encoder_position.ok_or_else(|| { + last_err.unwrap_or_else(|| { + DataFusionError::NotImplemented( + "Empty list of composed codecs".to_owned(), + ) + }) + })?; + + // encode with encoder position + let proto = DataEncoderTuple { + encoder_position, + blob: data, + }; + proto + .encode(buf) + .map_err(|e| internal_datafusion_err!("{e}")) + } +} + +impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + ctx: &TaskContext, + ) -> Result> { + self.decode_protobuf(buf, |codec, data| codec.try_decode(data, inputs, ctx)) + } + + fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { + self.encode_protobuf(buf, |codec, data| codec.try_encode(Arc::clone(&node), data)) + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + self.decode_protobuf(buf, |codec, data| codec.try_decode_udf(name, data)) + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + self.encode_protobuf(buf, |codec, data| codec.try_encode_udf(node, data)) + } + + fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + self.decode_protobuf(buf, |codec, data| codec.try_decode_udaf(name, data)) + } + + fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + self.encode_protobuf(buf, |codec, data| codec.try_encode_udaf(node, data)) + } +} + fn into_physical_plan( node: &Option>, - registry: &dyn FunctionRegistry, - runtime: &RuntimeEnv, + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, -) -> Result, DataFusionError> { +) -> Result> { if let Some(field) = node { - field.try_into_physical_plan(registry, runtime, extension_codec) + field.try_into_physical_plan(ctx, extension_codec) } else { Err(proto_error("Missing required field in protobuf")) } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 1384e6c0c32b3..19a76de3e5b08 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -17,11 +17,16 @@ use std::sync::Arc; +use arrow::array::RecordBatch; +use arrow::datatypes::Schema; +use arrow::ipc::writer::StreamWriter; #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::datasource::physical_plan::FileSink; use datafusion::physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; -use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; +use datafusion::physical_expr::ScalarFunctionExpr; +use datafusion::physical_expr_common::physical_expr::snapshot_physical_expr; +use datafusion::physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, @@ -37,7 +42,9 @@ use datafusion::{ }, physical_plan::expressions::LikeExpr, }; -use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_common::{ + internal_datafusion_err, internal_err, not_impl_err, DataFusionError, Result, +}; use datafusion_expr::WindowFrame; use crate::protobuf::{ @@ -52,11 +59,8 @@ pub fn serialize_physical_aggr_expr( codec: &dyn PhysicalExtensionCodec, ) -> Result { let expressions = serialize_physical_exprs(&aggr_expr.expressions(), codec)?; - let ordering_req = match aggr_expr.order_bys() { - Some(order) => order.clone(), - None => LexOrdering::default(), - }; - let ordering_req = serialize_physical_sort_exprs(ordering_req, codec)?; + let order_bys = + serialize_physical_sort_exprs(aggr_expr.order_bys().iter().cloned(), codec)?; let name = aggr_expr.fun().name().to_string(); let mut buf = Vec::new(); @@ -66,10 +70,11 @@ pub fn serialize_physical_aggr_expr( protobuf::PhysicalAggregateExprNode { aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), expr: expressions, - ordering_req, + ordering_req: order_bys, distinct: aggr_expr.is_distinct(), ignore_nulls: aggr_expr.ignore_nulls(), fun_definition: (!buf.is_empty()).then_some(buf), + human_display: aggr_expr.human_display().to_string(), }, )), }) @@ -80,12 +85,7 @@ fn serialize_physical_window_aggr_expr( _window_frame: &WindowFrame, codec: &dyn PhysicalExtensionCodec, ) -> Result<(physical_window_expr_node::WindowFunction, Option>)> { - if aggr_expr.is_distinct() || aggr_expr.ignore_nulls() { - // TODO - return not_impl_err!( - "Distinct aggregate functions not supported in window expressions" - ); - } + // Distinct and ignore_nulls are now supported in window expressions let mut buf = Vec::new(); codec.try_encode_udaf(aggr_expr.fun(), &mut buf)?; @@ -105,44 +105,55 @@ pub fn serialize_physical_window_expr( let args = window_expr.expressions().to_vec(); let window_frame = window_expr.get_window_frame(); - let (window_function, fun_definition) = if let Some(plain_aggr_window_expr) = - expr.downcast_ref::() - { - serialize_physical_window_aggr_expr( - plain_aggr_window_expr.get_aggregate_expr(), - window_frame, - codec, - )? - } else if let Some(sliding_aggr_window_expr) = - expr.downcast_ref::() - { - serialize_physical_window_aggr_expr( - sliding_aggr_window_expr.get_aggregate_expr(), - window_frame, - codec, - )? - } else if let Some(udf_window_expr) = expr.downcast_ref::() { - if let Some(expr) = udf_window_expr - .get_standard_func_expr() - .as_any() - .downcast_ref::() + let (window_function, fun_definition, ignore_nulls, distinct) = + if let Some(plain_aggr_window_expr) = + expr.downcast_ref::() { - let mut buf = Vec::new(); - codec.try_encode_udwf(expr.fun(), &mut buf)?; + let aggr_expr = plain_aggr_window_expr.get_aggregate_expr(); + let (window_function, fun_definition) = + serialize_physical_window_aggr_expr(aggr_expr, window_frame, codec)?; ( - physical_window_expr_node::WindowFunction::UserDefinedWindowFunction( - expr.fun().name().to_string(), - ), - (!buf.is_empty()).then_some(buf), + window_function, + fun_definition, + aggr_expr.ignore_nulls(), + aggr_expr.is_distinct(), ) + } else if let Some(sliding_aggr_window_expr) = + expr.downcast_ref::() + { + let aggr_expr = sliding_aggr_window_expr.get_aggregate_expr(); + let (window_function, fun_definition) = + serialize_physical_window_aggr_expr(aggr_expr, window_frame, codec)?; + ( + window_function, + fun_definition, + aggr_expr.ignore_nulls(), + aggr_expr.is_distinct(), + ) + } else if let Some(udf_window_expr) = expr.downcast_ref::() { + if let Some(expr) = udf_window_expr + .get_standard_func_expr() + .as_any() + .downcast_ref::() + { + let mut buf = Vec::new(); + codec.try_encode_udwf(expr.fun(), &mut buf)?; + ( + physical_window_expr_node::WindowFunction::UserDefinedWindowFunction( + expr.fun().name().to_string(), + ), + (!buf.is_empty()).then_some(buf), + false, // WindowUDFExpr doesn't have ignore_nulls/distinct + false, + ) + } else { + return not_impl_err!( + "User-defined window function not supported: {window_expr:?}" + ); + } } else { - return not_impl_err!( - "User-defined window function not supported: {window_expr:?}" - ); - } - } else { - return not_impl_err!("WindowExpr not supported: {window_expr:?}"); - }; + return not_impl_err!("WindowExpr not supported: {window_expr:?}"); + }; let args = serialize_physical_exprs(&args, codec)?; let partition_by = serialize_physical_exprs(window_expr.partition_by(), codec)?; @@ -150,7 +161,7 @@ pub fn serialize_physical_window_expr( let window_frame: protobuf::WindowFrame = window_frame .as_ref() .try_into() - .map_err(|e| DataFusionError::Internal(format!("{e}")))?; + .map_err(|e| internal_datafusion_err!("{e}"))?; Ok(protobuf::PhysicalWindowExprNode { args, @@ -160,6 +171,8 @@ pub fn serialize_physical_window_expr( window_function: Some(window_function), name: window_expr.name().to_string(), fun_definition, + ignore_nulls, + distinct, }) } @@ -210,6 +223,9 @@ pub fn serialize_physical_expr( value: &Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { + // Snapshot the expr in case it has dynamic predicate state so + // it can be serialized + let value = snapshot_physical_expr(Arc::clone(value))?; let expr = value.as_any(); if let Some(expr) = expr.downcast_ref::() { @@ -349,6 +365,10 @@ pub fn serialize_physical_expr( fun_definition: (!buf.is_empty()).then_some(buf), return_type: Some(expr.return_type().try_into()?), nullable: expr.nullable(), + return_field_name: expr + .return_field(&Schema::empty())? + .name() + .to_string(), }, )), }) @@ -368,7 +388,7 @@ pub fn serialize_physical_expr( }) } else { let mut buf: Vec = vec![]; - match codec.try_encode_expr(value, &mut buf) { + match codec.try_encode_expr(&value, &mut buf) { Ok(_) => { let inputs: Vec = value .children() @@ -441,7 +461,7 @@ impl TryFrom<&PartitionedFile> for protobuf::PartitionedFile { })? as u64; Ok(protobuf::PartitionedFile { path: pf.object_meta.location.as_ref().to_owned(), - size: pf.object_meta.size as u64, + size: pf.object_meta.size, last_modified_ns, partition_values: pf .partition_values @@ -502,12 +522,16 @@ pub fn serialize_file_scan_config( .iter() .cloned() .collect::>(); - fields.extend(conf.table_partition_cols.iter().cloned().map(Arc::new)); - let schema = Arc::new(arrow::datatypes::Schema::new(fields.clone())); + fields.extend(conf.table_partition_cols.iter().cloned()); + + let schema = Arc::new( + arrow::datatypes::Schema::new(fields.clone()) + .with_metadata(conf.file_schema.metadata.clone()), + ); Ok(protobuf::FileScanExecConf { file_groups, - statistics: Some((&conf.statistics).into()), + statistics: Some((&conf.file_source.statistics().unwrap()).into()), limit: conf.limit.map(|l| protobuf::ScanLimit { limit: l as u32 }), projection: conf .projection @@ -546,6 +570,20 @@ pub fn serialize_maybe_filter( } } +pub fn serialize_record_batches(batches: &[RecordBatch]) -> Result> { + if batches.is_empty() { + return Ok(vec![]); + } + let schema = batches[0].schema(); + let mut buf = Vec::new(); + let mut writer = StreamWriter::try_new(&mut buf, &schema)?; + for batch in batches { + writer.write(batch)?; + } + writer.finish()?; + Ok(buf) +} + impl TryFrom<&JsonSink> for protobuf::JsonSink { type Error = DataFusionError; diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index 92d961fc75562..aec6c1de30309 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -15,18 +15,21 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr::PhysicalExpr; use datafusion_common::plan_err; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, PartitionEvaluator, ScalarFunctionArgs, ScalarUDFImpl, - Signature, Volatility, WindowUDFImpl, + Accumulator, AggregateUDFImpl, LimitEffect, PartitionEvaluator, ScalarFunctionArgs, + ScalarUDFImpl, Signature, Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use std::any::Any; use std::fmt::Debug; +use std::hash::Hash; +use std::sync::Arc; mod roundtrip_logical_plan; mod roundtrip_physical_plan; @@ -131,7 +134,7 @@ pub struct MyAggregateUdfNode { pub result: String, } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(in crate::cases) struct CustomUDWF { signature: Signature, payload: String, @@ -166,8 +169,15 @@ impl WindowUDFImpl for CustomUDWF { Ok(Box::new(CustomUDWFEvaluator {})) } - fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { - Ok(Field::new(field_args.name(), DataType::UInt64, false)) + fn field( + &self, + field_args: WindowUDFFieldArgs, + ) -> datafusion_common::Result { + Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) + } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 9fa1f74ae188a..c5d4b49092d91 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -19,15 +19,20 @@ use arrow::array::{ ArrayRef, FixedSizeListArray, Int32Builder, MapArray, MapBuilder, StringBuilder, }; use arrow::datatypes::{ - DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, - IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, - DECIMAL256_MAX_PRECISION, + DataType, Field, FieldRef, Fields, Int32Type, IntervalDayTimeType, + IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, + UnionMode, DECIMAL256_MAX_PRECISION, }; use arrow::util::pretty::pretty_format_batches; -use datafusion::datasource::file_format::json::JsonFormatFactory; +use datafusion::datasource::file_format::json::{JsonFormat, JsonFormatFactory}; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::execution::options::ArrowReadOptions; use datafusion::optimizer::eliminate_nested_union::EliminateNestedUnion; use datafusion::optimizer::Optimizer; use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_functions_aggregate::sum::sum_distinct; use prost::Message; use std::any::Any; use std::collections::HashMap; @@ -41,6 +46,7 @@ use datafusion::datasource::file_format::arrow::ArrowFormatFactory; use datafusion::datasource::file_format::csv::CsvFormatFactory; use datafusion::datasource::file_format::parquet::ParquetFormatFactory; use datafusion::datasource::file_format::{format_as_file_type, DefaultFileType}; +use datafusion::datasource::DefaultTableSource; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::count::count_udaf; @@ -56,6 +62,7 @@ use datafusion::functions_window::expr_fn::{ cume_dist, dense_rank, lag, lead, ntile, percent_rank, rank, row_number, }; use datafusion::functions_window::rank::rank_udwf; +use datafusion::physical_expr::PhysicalExpr; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::TableOptions; @@ -71,15 +78,15 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateUDF, ColumnarValue, ExprFunctionExt, ExprSchemable, Literal, - LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, - WindowUDFImpl, + Accumulator, AggregateUDF, ColumnarValue, ExprFunctionExt, ExprSchemable, + LimitEffect, Literal, LogicalPlan, LogicalPlanBuilder, Operator, PartitionEvaluator, + ScalarUDF, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::expr_fn::{ - approx_distinct, array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, - nth_value, + approx_distinct, array_agg, avg, avg_distinct, bit_and, bit_or, bit_xor, bool_and, + bool_or, corr, nth_value, }; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_functions_window_common::field::WindowUDFFieldArgs; @@ -110,15 +117,21 @@ fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) { #[cfg(not(feature = "json"))] fn roundtrip_json_test(_proto: &protobuf::LogicalExprNode) {} -// Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test -// equality. fn roundtrip_expr_test(initial_struct: Expr, ctx: SessionContext) { let extension_codec = DefaultLogicalExtensionCodec {}; - let proto: protobuf::LogicalExprNode = - serialize_expr(&initial_struct, &extension_codec) - .unwrap_or_else(|e| panic!("Error serializing expression: {:?}", e)); - let round_trip: Expr = - from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap(); + roundtrip_expr_test_with_codec(initial_struct, ctx, &extension_codec); +} + +// Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test +// equality. +fn roundtrip_expr_test_with_codec( + initial_struct: Expr, + ctx: SessionContext, + codec: &dyn LogicalExtensionCodec, +) { + let proto: protobuf::LogicalExprNode = serialize_expr(&initial_struct, codec) + .unwrap_or_else(|e| panic!("Error serializing expression: {e:?}")); + let round_trip: Expr = from_proto::parse_expr(&proto, &ctx, codec).unwrap(); assert_eq!(format!("{:?}", &initial_struct), format!("{round_trip:?}")); @@ -180,9 +193,8 @@ impl LogicalExtensionCodec for TestTableProviderCodec { schema: SchemaRef, _ctx: &SessionContext, ) -> Result> { - let msg = TestTableProto::decode(buf).map_err(|_| { - DataFusionError::Internal("Error decoding test table".to_string()) - })?; + let msg = TestTableProto::decode(buf) + .map_err(|_| internal_datafusion_err!("Error decoding test table"))?; assert_eq!(msg.table_name, table_ref.to_string()); let provider = TestTableProvider { url: msg.url, @@ -206,9 +218,8 @@ impl LogicalExtensionCodec for TestTableProviderCodec { url: table.url.clone(), table_name: table_ref.to_string(), }; - msg.encode(buf).map_err(|_| { - DataFusionError::Internal("Error encoding test table".to_string()) - }) + msg.encode(buf) + .map_err(|_| internal_datafusion_err!("Error encoding test table")) } } @@ -420,13 +431,13 @@ async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { let input = create_csv_scan(&ctx).await?; let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new())); - let plan = LogicalPlan::Copy(CopyTo { - input: Arc::new(input), - output_url: "test.csv".to_string(), - partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], + let plan = LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + "test.csv".to_string(), + vec!["a".to_string(), "b".to_string(), "c".to_string()], file_type, - options: Default::default(), - }); + Default::default(), + )); let codec = CsvLogicalExtensionCodec {}; let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; @@ -460,13 +471,13 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { ParquetFormatFactory::new_with_options(parquet_format), )); - let plan = LogicalPlan::Copy(CopyTo { - input: Arc::new(input), - output_url: "test.parquet".to_string(), + let plan = LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + "test.parquet".to_string(), + vec!["a".to_string(), "b".to_string(), "c".to_string()], file_type, - partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], - options: Default::default(), - }); + Default::default(), + )); let codec = ParquetLogicalExtensionCodec {}; let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; @@ -492,13 +503,13 @@ async fn roundtrip_logical_plan_copy_to_arrow() -> Result<()> { let file_type = format_as_file_type(Arc::new(ArrowFormatFactory::new())); - let plan = LogicalPlan::Copy(CopyTo { - input: Arc::new(input), - output_url: "test.arrow".to_string(), - partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], + let plan = LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + "test.arrow".to_string(), + vec!["a".to_string(), "b".to_string(), "c".to_string()], file_type, - options: Default::default(), - }); + Default::default(), + )); let codec = ArrowLogicalExtensionCodec {}; let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; @@ -539,13 +550,13 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { csv_format.clone(), ))); - let plan = LogicalPlan::Copy(CopyTo { - input: Arc::new(input), - output_url: "test.csv".to_string(), - partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], + let plan = LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + "test.csv".to_string(), + vec!["a".to_string(), "b".to_string(), "c".to_string()], file_type, - options: Default::default(), - }); + Default::default(), + )); let codec = CsvLogicalExtensionCodec {}; let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; @@ -605,13 +616,13 @@ async fn roundtrip_logical_plan_copy_to_json() -> Result<()> { json_format.clone(), ))); - let plan = LogicalPlan::Copy(CopyTo { - input: Arc::new(input), - output_url: "test.json".to_string(), - partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], + let plan = LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + "test.json".to_string(), + vec!["a".to_string(), "b".to_string(), "c".to_string()], file_type, - options: Default::default(), - }); + Default::default(), + )); // Assume JsonLogicalExtensionCodec is implemented similarly to CsvLogicalExtensionCodec let codec = JsonLogicalExtensionCodec {}; @@ -677,13 +688,13 @@ async fn roundtrip_logical_plan_copy_to_parquet() -> Result<()> { ParquetFormatFactory::new_with_options(parquet_format.clone()), )); - let plan = LogicalPlan::Copy(CopyTo { - input: Arc::new(input), - output_url: "test.parquet".to_string(), - partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], + let plan = LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + "test.parquet".to_string(), + vec!["a".to_string(), "b".to_string(), "c".to_string()], file_type, - options: Default::default(), - }); + Default::default(), + )); // Assume ParquetLogicalExtensionCodec is implemented similarly to JsonLogicalExtensionCodec let codec = ParquetLogicalExtensionCodec {}; @@ -951,16 +962,18 @@ async fn roundtrip_expr_api() -> Result<()> { array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), count(lit(1)), count_distinct(lit(1)), - first_value(lit(1), None), - first_value(lit(1), Some(vec![lit(2).sort(true, true)])), + first_value(lit(1), vec![]), + first_value(lit(1), vec![lit(2).sort(true, true)]), functions_window::nth_value::first_value(lit(1)), functions_window::nth_value::last_value(lit(1)), functions_window::nth_value::nth_value(lit(1), 1), avg(lit(1.5)), + avg_distinct(lit(1.5)), covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), corr(lit(1.5), lit(2.2)), sum(lit(1)), + sum_distinct(lit(1)), max(lit(1)), median(lit(2)), min(lit(2)), @@ -970,9 +983,20 @@ async fn roundtrip_expr_api() -> Result<()> { stddev_pop(lit(2.2)), approx_distinct(lit(2)), approx_median(lit(2)), - approx_percentile_cont(lit(2), lit(0.5), None), - approx_percentile_cont(lit(2), lit(0.5), Some(lit(50))), - approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), + approx_percentile_cont(lit(2).sort(true, false), lit(0.5), None), + approx_percentile_cont(lit(2).sort(true, false), lit(0.5), Some(lit(50))), + approx_percentile_cont_with_weight( + lit(2).sort(true, false), + lit(1), + lit(0.5), + None, + ), + approx_percentile_cont_with_weight( + lit(2).sort(true, false), + lit(1), + lit(0.5), + Some(lit(50)), + ), grouping(lit(1)), bit_and(lit(2)), bit_or(lit(2)), @@ -1057,6 +1081,7 @@ pub mod proto { pub expr: Option, } + #[allow(dead_code)] #[derive(Clone, PartialEq, Eq, ::prost::Message)] pub struct TopKExecProto { #[prost(uint64, tag = "1")] @@ -1139,7 +1164,7 @@ impl LogicalExtensionCodec for TopKExtensionCodec { ) -> Result { if let Some((input, _)) = inputs.split_first() { let proto = proto::TopKPlanProto::decode(buf).map_err(|e| { - DataFusionError::Internal(format!("failed to decode logical plan: {e:?}")) + internal_datafusion_err!("failed to decode logical plan: {e:?}") })?; if let Some(expr) = proto.expr.as_ref() { @@ -1168,7 +1193,7 @@ impl LogicalExtensionCodec for TopKExtensionCodec { }; proto.encode(buf).map_err(|e| { - DataFusionError::Internal(format!("failed to encode logical plan: {e:?}")) + internal_datafusion_err!("failed to encode logical plan: {e:?}") })?; Ok(()) @@ -1236,7 +1261,7 @@ impl LogicalExtensionCodec for UDFExtensionCodec { fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { if name == "regex_udf" { let proto = MyRegexUdfNode::decode(buf).map_err(|err| { - DataFusionError::Internal(format!("failed to decode regex_udf: {err}")) + internal_datafusion_err!("failed to decode regex_udf: {err}") })?; Ok(Arc::new(ScalarUDF::from(MyRegexUdf::new(proto.pattern)))) @@ -1251,18 +1276,16 @@ impl LogicalExtensionCodec for UDFExtensionCodec { let proto = MyRegexUdfNode { pattern: udf.pattern.clone(), }; - proto.encode(buf).map_err(|err| { - DataFusionError::Internal(format!("failed to encode udf: {err}")) - })?; + proto + .encode(buf) + .map_err(|err| internal_datafusion_err!("failed to encode udf: {err}"))?; Ok(()) } fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { if name == "aggregate_udf" { let proto = MyAggregateUdfNode::decode(buf).map_err(|err| { - DataFusionError::Internal(format!( - "failed to decode aggregate_udf: {err}" - )) + internal_datafusion_err!("failed to decode aggregate_udf: {err}") })?; Ok(Arc::new(AggregateUDF::from(MyAggregateUDF::new( @@ -1279,9 +1302,9 @@ impl LogicalExtensionCodec for UDFExtensionCodec { let proto = MyAggregateUdfNode { result: udf.result.clone(), }; - proto.encode(buf).map_err(|err| { - DataFusionError::Internal(format!("failed to encode udf: {err}")) - })?; + proto + .encode(buf) + .map_err(|err| internal_datafusion_err!("failed to encode udf: {err}"))?; Ok(()) } } @@ -1565,7 +1588,7 @@ fn round_trip_scalar_values_and_data_types() { assert_eq!( dt, roundtrip, "DataType was not the same after round trip!\n\n\ - Input: {dt:?}\n\nRoundtrip: {roundtrip:?}" + Input: {dt}\n\nRoundtrip: {roundtrip:?}" ); } } @@ -1959,7 +1982,7 @@ fn roundtrip_case_with_null() { let test_expr = Expr::Case(Case::new( Some(Box::new(lit(1.0_f32))), vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], - Some(Box::new(Expr::Literal(ScalarValue::Null))), + Some(Box::new(Expr::Literal(ScalarValue::Null, None))), )); let ctx = SessionContext::new(); @@ -1968,7 +1991,7 @@ fn roundtrip_case_with_null() { #[test] fn roundtrip_null_literal() { - let test_expr = Expr::Literal(ScalarValue::Null); + let test_expr = Expr::Literal(ScalarValue::Null, None); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2172,7 +2195,7 @@ fn roundtrip_aggregate_udf() { vec![lit(1.0_f64)], false, Some(Box::new(lit(true))), - None, + vec![], None, )); @@ -2182,8 +2205,7 @@ fn roundtrip_aggregate_udf() { roundtrip_expr_test(test_expr, ctx); } -#[test] -fn roundtrip_scalar_udf() { +fn dummy_udf() -> ScalarUDF { let scalar_fn = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { panic!("should be array") @@ -2191,13 +2213,18 @@ fn roundtrip_scalar_udf() { Ok(ColumnarValue::from(Arc::new(array.clone()) as ArrayRef)) }); - let udf = create_udf( + create_udf( "dummy", vec![DataType::Utf8], DataType::Utf8, Volatility::Immutable, scalar_fn, - ); + ) +} + +#[test] +fn roundtrip_scalar_udf() { + let udf = dummy_udf(); let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( Arc::new(udf.clone()), @@ -2207,7 +2234,57 @@ fn roundtrip_scalar_udf() { let ctx = SessionContext::new(); ctx.register_udf(udf); - roundtrip_expr_test(test_expr, ctx); + roundtrip_expr_test(test_expr.clone(), ctx); + + // Now test loading the UDF without registering it in the context, but rather creating it in the + // extension codec. + #[derive(Debug)] + struct DummyUDFExtensionCodec; + + impl LogicalExtensionCodec for DummyUDFExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[LogicalPlan], + _ctx: &SessionContext, + ) -> Result { + not_impl_err!("LogicalExtensionCodec is not provided") + } + + fn try_encode(&self, _node: &Extension, _buf: &mut Vec) -> Result<()> { + not_impl_err!("LogicalExtensionCodec is not provided") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: SchemaRef, + _ctx: &SessionContext, + ) -> Result> { + not_impl_err!("LogicalExtensionCodec is not provided") + } + + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + not_impl_err!("LogicalExtensionCodec is not provided") + } + + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { + if name == "dummy" { + Ok(Arc::new(dummy_udf())) + } else { + Err(internal_datafusion_err!("UDF {name} not found")) + } + } + } + + let ctx = SessionContext::new(); + roundtrip_expr_test_with_codec(test_expr, ctx, &DummyUDFExtensionCodec) } #[test] @@ -2296,7 +2373,7 @@ fn roundtrip_window() { let ctx = SessionContext::new(); // 1. without window_frame - let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr1 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2307,7 +2384,7 @@ fn roundtrip_window() { .unwrap(); // 2. with default window_frame - let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr2 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2324,7 +2401,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr3 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2341,7 +2418,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr4 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("col1")], )) @@ -2391,7 +2468,7 @@ fn roundtrip_window() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr5 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], )) @@ -2420,7 +2497,7 @@ fn roundtrip_window() { } } - #[derive(Debug, Clone)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimpleWindowUDF { signature: Signature, } @@ -2453,17 +2530,25 @@ fn roundtrip_window() { make_partition_evaluator() } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - if let Some(return_type) = field_args.get_input_type(0) { - Ok(Field::new(field_args.name(), return_type, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + if let Some(return_field) = field_args.get_input_field(0) { + Ok(return_field + .as_ref() + .clone() + .with_name(field_args.name()) + .into()) } else { plan_err!( "dummy_udwf expects 1 argument, got {}: {:?}", - field_args.input_types().len(), - field_args.input_types() + field_args.input_fields().len(), + field_args.input_fields() ) } } + + fn limit_effect(&self, _args: &[Arc]) -> LimitEffect { + LimitEffect::Unknown + } } fn make_partition_evaluator() -> Result> { @@ -2472,7 +2557,7 @@ fn roundtrip_window() { let dummy_window_udf = WindowUDF::from(SimpleWindowUDF::new()); - let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr6 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], )) @@ -2482,7 +2567,7 @@ fn roundtrip_window() { .build() .unwrap(); - let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( + let text_expr7 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), vec![col("col1")], )) @@ -2559,3 +2644,84 @@ async fn roundtrip_union_query() -> Result<()> { ); Ok(()) } + +#[tokio::test] +async fn roundtrip_custom_listing_tables_schema() -> Result<()> { + let ctx = SessionContext::new(); + // Make sure during round-trip, constraint information is preserved + let file_format = JsonFormat::default(); + let table_partition_cols = vec![("part".to_owned(), DataType::Int64)]; + let data = "../core/tests/data/partitioned_table_json"; + let listing_table_url = ListingTableUrl::parse(data)?; + let listing_options = ListingOptions::new(Arc::new(file_format)) + .with_table_partition_cols(table_partition_cols); + + let config = ListingTableConfig::new(listing_table_url) + .with_listing_options(listing_options) + .infer_schema(&ctx.state()) + .await?; + + ctx.register_table("hive_style", Arc::new(ListingTable::try_new(config)?))?; + + let plan = ctx + .sql("SELECT part, value FROM hive_style LIMIT 1") + .await? + .logical_plan() + .clone(); + + let bytes = logical_plan_to_bytes(&plan)?; + let new_plan = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(plan, new_plan); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_custom_listing_tables_schema_table_scan_projection() -> Result<()> { + let ctx = SessionContext::new(); + // Make sure during round-trip, constraint information is preserved + let file_format = JsonFormat::default(); + let table_partition_cols = vec![("part".to_owned(), DataType::Int64)]; + let data = "../core/tests/data/partitioned_table_json"; + let listing_table_url = ListingTableUrl::parse(data)?; + let listing_options = ListingOptions::new(Arc::new(file_format)) + .with_table_partition_cols(table_partition_cols); + + let config = ListingTableConfig::new(listing_table_url) + .with_listing_options(listing_options) + .infer_schema(&ctx.state()) + .await?; + + let listing_table: Arc = Arc::new(ListingTable::try_new(config)?); + + let projection = ["part", "value"] + .iter() + .map(|field_name| listing_table.schema().index_of(field_name)) + .collect::, _>>()?; + + let plan = LogicalPlanBuilder::scan( + "hive_style", + Arc::new(DefaultTableSource::new(listing_table)), + Some(projection), + )? + .limit(0, Some(1))? + .build()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let new_plan = logical_plan_from_bytes(&bytes, &ctx)?; + + assert_eq!(plan, new_plan); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_arrow_scan() -> Result<()> { + let ctx = SessionContext::new(); + let plan = ctx + .read_arrow("tests/testdata/test.arrow", ArrowReadOptions::default()) + .await? + .into_optimized_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + Ok(()) +} diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 6356b8b7b0cf4..b93d0d3c4e7cb 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -16,8 +16,9 @@ // under the License. use std::any::Any; +use std::collections::HashMap; use std::fmt::{Display, Formatter}; -use std::ops::Deref; + use std::sync::Arc; use std::vec; @@ -42,9 +43,11 @@ use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; use datafusion::datasource::empty::EmptyTable; use datafusion::datasource::file_format::csv::CsvSink; -use datafusion::datasource::file_format::json::JsonSink; +use datafusion::datasource::file_format::json::{JsonFormat, JsonSink}; use datafusion::datasource::file_format::parquet::ParquetSink; -use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile}; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, PartitionedFile, +}; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{ wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileGroup, @@ -52,7 +55,8 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::datasource::sink::DataSinkExec; use datafusion::datasource::source::DataSourceExec; -use datafusion::execution::FunctionRegistry; +use datafusion::execution::TaskContext; +use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::sum::sum_udaf; use datafusion::functions_window::nth_value::nth_value_udwf; use datafusion::functions_window::row_number::row_number_udwf; @@ -60,23 +64,25 @@ use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; use datafusion::physical_expr::expressions::Literal; use datafusion::physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion::physical_expr::{ - LexOrdering, LexRequirement, PhysicalSortRequirement, ScalarFunctionExpr, + LexOrdering, PhysicalSortRequirement, ScalarFunctionExpr, }; use datafusion::physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion::physical_plan::analyze::AnalyzeExec; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ binary, cast, col, in_list, like, lit, BinaryExpr, Column, NotExpr, PhysicalSortExpr, }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::joins::{ - HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, + HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, + StreamJoinPartitionMode, SymmetricHashJoinExec, }; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; -use datafusion::physical_plan::projection::ProjectionExec; +use datafusion::physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; @@ -90,13 +96,14 @@ use datafusion::physical_plan::{ }; use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion::scalar::ScalarValue; -use datafusion_common::config::TableParquetOptions; +use datafusion_common::config::{ConfigOptions, TableParquetOptions}; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{ - internal_err, not_impl_err, DataFusionError, Result, UnnestOptions, + internal_datafusion_err, internal_err, not_impl_err, DataFusionError, NullEquality, + Result, UnnestOptions, }; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, @@ -108,8 +115,7 @@ use datafusion_functions_aggregate::string_agg::string_agg_udaf; use datafusion_proto::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; -use datafusion_proto::protobuf; -use datafusion_proto::protobuf::PhysicalPlanNode; +use datafusion_proto::protobuf::{self, PhysicalPlanNode}; /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. Note that this often isn't sufficient to guarantee that no information is @@ -135,11 +141,14 @@ fn roundtrip_test_and_return( let proto: protobuf::PhysicalPlanNode = protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), codec) .expect("to proto"); - let runtime = ctx.runtime_env(); let result_exec_plan: Arc = proto - .try_into_physical_plan(ctx, runtime.deref(), codec) + .try_into_physical_plan(&ctx.task_ctx(), codec) .expect("from proto"); - assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); + + pretty_assertions::assert_eq!( + format!("{exec_plan:?}"), + format!("{result_exec_plan:?}") + ); Ok(result_exec_plan) } @@ -204,7 +213,10 @@ fn roundtrip_date_time_interval() -> Result<()> { let date_time_interval_expr = binary(date_expr, Operator::Plus, literal_expr, &schema)?; let plan = Arc::new(ProjectionExec::try_new( - vec![(date_time_interval_expr, "result".to_string())], + vec![ProjectionExpr { + expr: date_time_interval_expr, + alias: "result".to_string(), + }], input, )?); roundtrip_test(plan) @@ -267,7 +279,7 @@ fn roundtrip_hash_join() -> Result<()> { join_type, None, *partition_mode, - false, + NullEquality::NullEqualsNothing, )?))?; } } @@ -320,9 +332,9 @@ fn roundtrip_udwf() -> Result<()> { &[ col("a", &schema)? ], - &LexOrdering::new(vec![ - PhysicalSortExpr::new(col("b", &schema)?, SortOptions::new(true, true)), - ]), + &[ + PhysicalSortExpr::new(col("b", &schema)?, SortOptions::new(true, true)) + ], Arc::new(WindowFrame::new(None)), )); @@ -359,13 +371,13 @@ fn roundtrip_window() -> Result<()> { let udwf_expr = Arc::new(StandardWindowExpr::new( nth_value_window, &[col("b", &schema)?], - &LexOrdering::new(vec![PhysicalSortExpr { + &[PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]), + }], Arc::new(window_frame), )); @@ -379,8 +391,9 @@ fn roundtrip_window() -> Result<()> { .build() .map(Arc::new)?, &[], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new(None)), + None, )); let window_frame = WindowFrame::new_bounds( @@ -392,15 +405,16 @@ fn roundtrip_window() -> Result<()> { let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?]; let sum_expr = AggregateExprBuilder::new(sum_udaf(), args) .schema(Arc::clone(&schema)) - .alias("SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING") + .alias("SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEDING") .build() .map(Arc::new)?; let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new( sum_expr, &[], - &LexOrdering::default(), + &[], Arc::new(window_frame), + None, )); let input = Arc::new(EmptyExec::new(schema.clone())); @@ -413,7 +427,118 @@ fn roundtrip_window() -> Result<()> { } #[test] -fn rountrip_aggregate() -> Result<()> { +fn roundtrip_window_distinct() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + // Create a distinct count window expression with unbounded frame (becomes PlainAggregateWindowExpr) + let distinct_count_expr = Arc::new(PlainAggregateWindowExpr::new( + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("count(DISTINCT a)") + .distinct() // Enable distinct + .build() + .map(Arc::new)?, + &[col("b", &schema)?], // partition by b + &[], // no order by + Arc::new(WindowFrame::new(None)), // unbounded frame + None, + )); + + // Create a distinct sum window expression with bounded frame (becomes SlidingAggregateWindowExpr) + let bounded_frame = WindowFrame::new_bounds( + datafusion_expr::WindowFrameUnits::Rows, + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))), + WindowFrameBound::CurrentRow, + ); + + let distinct_sum_expr = Arc::new(SlidingAggregateWindowExpr::new( + AggregateExprBuilder::new( + sum_udaf(), + vec![cast(col("a", &schema)?, &schema, DataType::Float64)?], + ) + .schema(Arc::clone(&schema)) + .alias("sum(DISTINCT a)") + .distinct() // Enable distinct + .with_ignore_nulls(true) // Enable ignore nulls + .build() + .map(Arc::new)?, + &[], // no partition by + &[], // no order by + Arc::new(bounded_frame), // bounded frame + None, + )); + + let input = Arc::new(EmptyExec::new(schema.clone())); + + roundtrip_test(Arc::new(WindowAggExec::try_new( + vec![distinct_count_expr, distinct_sum_expr], + input, + false, + )?)) +} + +#[test] +fn test_distinct_window_serialization_end_to_end() -> Result<()> { + // Create a more comprehensive test that verifies distinct window functions + // work properly through the entire serialization/deserialization pipeline + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + // Test 1: DISTINCT COUNT with IGNORE NULLS + let distinct_count_ignore_nulls = Arc::new(PlainAggregateWindowExpr::new( + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("count_distinct_ignore_nulls") + .distinct() + .with_ignore_nulls(true) + .build() + .map(Arc::new)?, + &[col("b", &schema)?], + &[], + Arc::new(WindowFrame::new(None)), + None, + )); + + // Test 2: DISTINCT SUM (without ignore nulls) + let bounded_frame = WindowFrame::new_bounds( + datafusion_expr::WindowFrameUnits::Rows, + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), + WindowFrameBound::CurrentRow, + ); + + let distinct_sum = Arc::new(SlidingAggregateWindowExpr::new( + AggregateExprBuilder::new( + sum_udaf(), + vec![cast(col("a", &schema)?, &schema, DataType::Float64)?], + ) + .schema(Arc::clone(&schema)) + .alias("sum_distinct") + .distinct() + .build() + .map(Arc::new)?, + &[], + &[], + Arc::new(bounded_frame), + None, + )); + + let input = Arc::new(EmptyExec::new(schema.clone())); + + let window_exec = Arc::new(WindowAggExec::try_new( + vec![distinct_count_ignore_nulls, distinct_sum], + input, + false, + )?); + + // Perform the roundtrip test + roundtrip_test(window_exec) +} + +#[test] +fn roundtrip_aggregate() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); @@ -461,7 +586,7 @@ fn rountrip_aggregate() -> Result<()> { } #[test] -fn rountrip_aggregate_with_limit() -> Result<()> { +fn roundtrip_aggregate_with_limit() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); @@ -491,7 +616,7 @@ fn rountrip_aggregate_with_limit() -> Result<()> { } #[test] -fn rountrip_aggregate_with_approx_pencentile_cont() -> Result<()> { +fn roundtrip_aggregate_with_approx_pencentile_cont() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); @@ -504,7 +629,7 @@ fn rountrip_aggregate_with_approx_pencentile_cont() -> Result<()> { vec![col("b", &schema)?, lit(0.5)], ) .schema(Arc::clone(&schema)) - .alias("APPROX_PERCENTILE_CONT(b, 0.5)") + .alias("APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY b)") .build() .map(Arc::new)?]; @@ -520,20 +645,20 @@ fn rountrip_aggregate_with_approx_pencentile_cont() -> Result<()> { } #[test] -fn rountrip_aggregate_with_sort() -> Result<()> { +fn roundtrip_aggregate_with_sort() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let sort_exprs = LexOrdering::new(vec![PhysicalSortExpr { + let sort_exprs = vec![PhysicalSortExpr { expr: col("b", &schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); + }]; let aggregates = vec![ @@ -594,7 +719,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { Signature::exact(vec![DataType::Int64], Volatility::Immutable), return_type, accumulator, - vec![Field::new("value", DataType::Int64, true)], + vec![Field::new("value", DataType::Int64, true).into()], )); let ctx = SessionContext::new(); @@ -653,7 +778,7 @@ fn roundtrip_sort() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let sort_exprs = LexOrdering::new(vec![ + let sort_exprs = [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -668,7 +793,8 @@ fn roundtrip_sort() -> Result<()> { nulls_first: true, }, }, - ]); + ] + .into(); roundtrip_test(Arc::new(SortExec::new( sort_exprs, Arc::new(EmptyExec::new(schema)), @@ -680,7 +806,7 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let sort_exprs = LexOrdering::new(vec![ + let sort_exprs: LexOrdering = [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -695,7 +821,8 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { nulls_first: true, }, }, - ]); + ] + .into(); roundtrip_test(Arc::new(SortExec::new( sort_exprs.clone(), @@ -709,7 +836,7 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { } #[test] -fn roundtrip_coalesce_with_fetch() -> Result<()> { +fn roundtrip_coalesce_batches_with_fetch() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); @@ -725,6 +852,22 @@ fn roundtrip_coalesce_with_fetch() -> Result<()> { )) } +#[test] +fn roundtrip_coalesce_partitions_with_fetch() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + roundtrip_test(Arc::new(CoalescePartitionsExec::new(Arc::new( + EmptyExec::new(schema.clone()), + ))))?; + + roundtrip_test(Arc::new( + CoalescePartitionsExec::new(Arc::new(EmptyExec::new(schema))) + .with_fetch(Some(10)), + )) +} + #[test] fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { let file_schema = @@ -739,9 +882,7 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { let mut options = TableParquetOptions::new(); options.global.pushdown_filters = true; - let file_source = Arc::new( - ParquetSource::new(options).with_predicate(Arc::clone(&file_schema), predicate), - ); + let file_source = Arc::new(ParquetSource::new(options).with_predicate(predicate)); let scan_config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), @@ -800,10 +941,8 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { inner: Arc::new(Column::new("col", 1)), }); - let file_source = Arc::new( - ParquetSource::default() - .with_predicate(Arc::clone(&file_schema), custom_predicate_expr), - ); + let file_source = + Arc::new(ParquetSource::default().with_predicate(custom_predicate_expr)); let scan_config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), @@ -872,7 +1011,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { self: Arc, _children: Vec>, ) -> Result> { - todo!() + Ok(self) } fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { @@ -887,7 +1026,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { &self, _buf: &[u8], _inputs: &[Arc], - _registry: &dyn FunctionRegistry, + _ctx: &TaskContext, ) -> Result> { unreachable!() } @@ -968,11 +1107,17 @@ fn roundtrip_scalar_udf() -> Result<()> { "dummy", fun_def, vec![col("a", &schema)?], - DataType::Int64, + Field::new("f", DataType::Int64, true).into(), + Arc::new(ConfigOptions::default()), ); - let project = - ProjectionExec::try_new(vec![(Arc::new(expr), "a".to_string())], input)?; + let project = ProjectionExec::try_new( + vec![ProjectionExpr { + expr: Arc::new(expr), + alias: "a".to_string(), + }], + input, + )?; let ctx = SessionContext::new(); @@ -989,7 +1134,7 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { &self, _buf: &[u8], _inputs: &[Arc], - _registry: &dyn FunctionRegistry, + _ctx: &TaskContext, ) -> Result> { not_impl_err!("No extension codec provided") } @@ -1005,7 +1150,7 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { if name == "regex_udf" { let proto = MyRegexUdfNode::decode(buf).map_err(|err| { - DataFusionError::Internal(format!("failed to decode regex_udf: {err}")) + internal_datafusion_err!("failed to decode regex_udf: {err}") })?; Ok(Arc::new(ScalarUDF::from(MyRegexUdf::new(proto.pattern)))) @@ -1020,9 +1165,9 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { let proto = MyRegexUdfNode { pattern: udf.pattern.clone(), }; - proto.encode(buf).map_err(|err| { - DataFusionError::Internal(format!("failed to encode udf: {err}")) - })?; + proto + .encode(buf) + .map_err(|err| internal_datafusion_err!("failed to encode udf: {err}"))?; } Ok(()) } @@ -1030,9 +1175,7 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { if name == "aggregate_udf" { let proto = MyAggregateUdfNode::decode(buf).map_err(|err| { - DataFusionError::Internal(format!( - "failed to decode aggregate_udf: {err}" - )) + internal_datafusion_err!("failed to decode aggregate_udf: {err}") })?; Ok(Arc::new(AggregateUDF::from(MyAggregateUDF::new( @@ -1050,7 +1193,7 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { result: udf.result.clone(), }; proto.encode(buf).map_err(|err| { - DataFusionError::Internal(format!("failed to encode udf: {err:?}")) + internal_datafusion_err!("failed to encode udf: {err:?}") })?; } Ok(()) @@ -1059,7 +1202,7 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { if name == "custom_udwf" { let proto = CustomUDWFNode::decode(buf).map_err(|err| { - DataFusionError::Internal(format!("failed to decode custom_udwf: {err}")) + internal_datafusion_err!("failed to decode custom_udwf: {err}") })?; Ok(Arc::new(WindowUDF::from(CustomUDWF::new(proto.payload)))) @@ -1077,7 +1220,7 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { payload: udwf.payload.clone(), }; proto.encode(buf).map_err(|err| { - DataFusionError::Internal(format!("failed to encode udwf: {err:?}")) + internal_datafusion_err!("failed to encode udwf: {err:?}") })?; } Ok(()) @@ -1096,7 +1239,8 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { "regex_udf", Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], - DataType::Int64, + Field::new("f", DataType::Int64, true).into(), + Arc::new(ConfigOptions::default()), )); let filter = Arc::new(FilterExec::try_new( @@ -1118,8 +1262,9 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { vec![Arc::new(PlainAggregateWindowExpr::new( aggr_expr.clone(), &[col("author", &schema)?], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new(None)), + None, ))], filter, true, @@ -1163,13 +1308,13 @@ fn roundtrip_udwf_extension_codec() -> Result<()> { let udwf_expr = Arc::new(StandardWindowExpr::new( udwf, &[col("b", &schema)?], - &LexOrdering::new(vec![PhysicalSortExpr { + &[PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]), + }], Arc::new(window_frame), )); @@ -1198,7 +1343,8 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { "regex_udf", Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], - DataType::Int64, + Field::new("f", DataType::Int64, true).into(), + Arc::new(ConfigOptions::default()), )); let udaf = Arc::new(AggregateUDF::from(MyAggregateUDF::new( @@ -1226,8 +1372,9 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { vec![Arc::new(PlainAggregateWindowExpr::new( aggr_expr, &[col("author", &schema)?], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new(None)), + None, ))], filter, true, @@ -1270,7 +1417,10 @@ fn roundtrip_like() -> Result<()> { &schema, )?; let plan = Arc::new(ProjectionExec::try_new( - vec![(like_expr, "result".to_string())], + vec![ProjectionExpr { + expr: like_expr, + alias: "result".to_string(), + }], input, )?); roundtrip_test(plan) @@ -1322,13 +1472,14 @@ fn roundtrip_json_sink() -> Result<()> { file_sink_config, JsonWriterOptions::new(CompressionTypeVariant::UNCOMPRESSED), )); - let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new( + let sort_order = [PhysicalSortRequirement::new( Arc::new(Column::new("plan_type", 0)), Some(SortOptions { descending: true, nulls_first: false, }), - )]); + )] + .into(); roundtrip_test(Arc::new(DataSinkExec::new( input, @@ -1359,13 +1510,14 @@ fn roundtrip_csv_sink() -> Result<()> { file_sink_config, CsvWriterOptions::new(WriterBuilder::default(), CompressionTypeVariant::ZSTD), )); - let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new( + let sort_order = [PhysicalSortRequirement::new( Arc::new(Column::new("plan_type", 0)), Some(SortOptions { descending: true, nulls_first: false, }), - )]); + )] + .into(); let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; @@ -1415,13 +1567,14 @@ fn roundtrip_parquet_sink() -> Result<()> { file_sink_config, TableParquetOptions::default(), )); - let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new( + let sort_order = [PhysicalSortRequirement::new( Arc::new(Column::new("plan_type", 0)), Some(SortOptions { descending: true, nulls_first: false, }), - )]); + )] + .into(); roundtrip_test(Arc::new(DataSinkExec::new( input, @@ -1458,31 +1611,29 @@ fn roundtrip_sym_hash_join() -> Result<()> { ] { for left_order in &[ None, - Some(LexOrdering::new(vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("col", schema_left.index_of("col")?)), options: Default::default(), - }])), + }]), ] { - for right_order in &[ + for right_order in [ None, - Some(LexOrdering::new(vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("col", schema_right.index_of("col")?)), options: Default::default(), - }])), + }]), ] { - roundtrip_test(Arc::new( - datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new( - Arc::new(EmptyExec::new(schema_left.clone())), - Arc::new(EmptyExec::new(schema_right.clone())), - on.clone(), - None, - join_type, - false, - left_order.clone(), - right_order.clone(), - *partition_mode, - )?, - ))?; + roundtrip_test(Arc::new(SymmetricHashJoinExec::try_new( + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), + on.clone(), + None, + join_type, + NullEquality::NullEqualsNothing, + left_order.clone(), + right_order, + *partition_mode, + )?))?; } } } @@ -1498,8 +1649,8 @@ fn roundtrip_union() -> Result<()> { let left = EmptyExec::new(Arc::new(schema_left)); let right = EmptyExec::new(Arc::new(schema_right)); let inputs: Vec> = vec![Arc::new(left), Arc::new(right)]; - let union = UnionExec::new(inputs); - roundtrip_test(Arc::new(union)) + let union = UnionExec::try_new(inputs)?; + roundtrip_test(union) } #[test] @@ -1585,11 +1736,44 @@ async fn roundtrip_coalesce() -> Result<()> { )?; let node = PhysicalPlanNode::decode(node.encode_to_vec().as_slice()) .map_err(|e| DataFusionError::External(Box::new(e)))?; - let restored = node.try_into_physical_plan( - &ctx, - ctx.runtime_env().as_ref(), + let restored = + node.try_into_physical_plan(&ctx.task_ctx(), &DefaultPhysicalExtensionCodec {})?; + + assert_eq!( + plan.schema(), + restored.schema(), + "Schema mismatch for plans:\n>> initial:\n{}>> final: \n{}", + displayable(plan.as_ref()) + .set_show_schema(true) + .indent(true), + displayable(restored.as_ref()) + .set_show_schema(true) + .indent(true), + ); + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_generate_series() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_table( + "t", + Arc::new(EmptyTable::new(Arc::new(Schema::new(Fields::from([ + Arc::new(Field::new("f", DataType::Int64, false)), + ]))))), + )?; + let df = ctx.sql("select * from generate_series(1, 10000)").await?; + let plan = df.create_physical_plan().await?; + + let node = PhysicalPlanNode::try_from_physical_plan( + plan.clone(), &DefaultPhysicalExtensionCodec {}, )?; + let node = PhysicalPlanNode::decode(node.encode_to_vec().as_slice()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + let restored = + node.try_into_physical_plan(&ctx.task_ctx(), &DefaultPhysicalExtensionCodec {})?; assert_eq!( plan.schema(), @@ -1676,3 +1860,405 @@ async fn roundtrip_empty_projection() -> Result<()> { let sql = "select 1 from alltypes_plain"; roundtrip_test_sql_with_context(sql, &ctx).await } + +#[tokio::test] +async fn roundtrip_physical_plan_node() { + use datafusion::prelude::*; + use datafusion_proto::physical_plan::{ + AsExecutionPlan, DefaultPhysicalExtensionCodec, + }; + use datafusion_proto::protobuf::PhysicalPlanNode; + + let ctx = SessionContext::new(); + + ctx.register_parquet( + "pt", + &format!( + "{}/alltypes_plain.snappy.parquet", + datafusion_common::test_util::parquet_test_data() + ), + ParquetReadOptions::default(), + ) + .await + .unwrap(); + + let plan = ctx + .sql("select id, string_col, timestamp_col from pt where id > 4 order by string_col") + .await + .unwrap() + .create_physical_plan() + .await + .unwrap(); + + let node: PhysicalPlanNode = + PhysicalPlanNode::try_from_physical_plan(plan, &DefaultPhysicalExtensionCodec {}) + .unwrap(); + + let plan = node + .try_into_physical_plan(&ctx.task_ctx(), &DefaultPhysicalExtensionCodec {}) + .unwrap(); + + let _ = plan.execute(0, ctx.task_ctx()).unwrap(); +} + +/// Helper function to create a SessionContext with all TPC-H tables registered as external tables +async fn tpch_context() -> Result { + use datafusion_common::test_util::datafusion_test_data; + + let ctx = SessionContext::new(); + let test_data = datafusion_test_data(); + + // TPC-H table names + let tables = [ + "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", + "region", + ]; + + // Create external tables for all TPC-H tables + for table in &tables { + let table_sql = format!( + "CREATE EXTERNAL TABLE {table} STORED AS PARQUET LOCATION '{test_data}/tpch_{table}_small.parquet'" + ); + ctx.sql(&table_sql).await.map_err(|e| { + DataFusionError::External( + format!("Failed to create {table} table: {e}").into(), + ) + })?; + } + + Ok(ctx) +} + +/// Helper function to get TPC-H query SQL +fn get_tpch_query_sql(query: usize) -> Result> { + use std::fs; + + if !(1..=22).contains(&query) { + return Err(DataFusionError::External( + format!("Invalid TPC-H query number: {query}").into(), + )); + } + + let filename = format!("../../benchmarks/queries/q{query}.sql"); + let contents = fs::read_to_string(&filename).map_err(|e| { + DataFusionError::External( + format!("Failed to read query file {filename}: {e}").into(), + ) + })?; + + Ok(contents + .split(';') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + .collect()) +} + +#[tokio::test] +async fn test_serialize_deserialize_tpch_queries() -> Result<()> { + // Create context with TPC-H tables + let ctx = tpch_context().await?; + + // repeat to run all 22 queries + for query in 1..=22 { + // run all statements in the query + let sql = get_tpch_query_sql(query)?; + for stmt in sql { + let logical_plan = ctx.sql(&stmt).await?.into_unoptimized_plan(); + let optimized_plan = ctx.state().optimize(&logical_plan)?; + let physical_plan = ctx.state().create_physical_plan(&optimized_plan).await?; + + // serialize the physical plan + let codec = DefaultPhysicalExtensionCodec {}; + let proto = + PhysicalPlanNode::try_from_physical_plan(physical_plan.clone(), &codec)?; + + // deserialize the physical plan + let _deserialized_plan = + proto.try_into_physical_plan(&ctx.task_ctx(), &codec)?; + } + } + + Ok(()) +} + +// Bugs: https://github.com/apache/datafusion/issues/16772 +#[tokio::test] +async fn test_round_trip_tpch_queries() -> Result<()> { + // Create context with TPC-H tables + let ctx = tpch_context().await?; + + // repeat to run all 22 queries + for query in 1..=22 { + // run all statements in the query + let sql = get_tpch_query_sql(query)?; + for stmt in sql { + roundtrip_test_sql_with_context(&stmt, &ctx).await?; + } + } + + Ok(()) +} + +// Bug 1 of https://github.com/apache/datafusion/issues/16772 +/// Test that AggregateFunctionExpr human_display field is correctly preserved +/// during serialization/deserialization roundtrip. +/// +/// Test for issue where the human_display field (used for EXPLAIN output) +/// was not being serialized to protobuf, causing it to be lost during roundtrip +/// and resulting in empty or incorrect display strings in query plans. +#[tokio::test] +async fn test_round_trip_human_display() -> Result<()> { + // Create context with TPC-H tables + let ctx = tpch_context().await?; + + let sql = "select r_name, count(1) from region group by r_name"; + roundtrip_test_sql_with_context(sql, &ctx).await?; + + let sql = "select r_name, count(*) from region group by r_name"; + roundtrip_test_sql_with_context(sql, &ctx).await?; + + let sql = "select r_name, count(r_name) from region group by r_name"; + roundtrip_test_sql_with_context(sql, &ctx).await?; + + Ok(()) +} + +// Bug 2 of https://github.com/apache/datafusion/issues/16772 +/// Test that PhysicalGroupBy groups field is correctly serialized/deserialized +/// for simple aggregates (no GROUP BY clause). +/// +/// Test for issue where simple aggregates like "SELECT SUM(col1 * col2) FROM table" +/// would incorrectly serialize groups as [[]] instead of [] during roundtrip serialization. +/// The groups field should be empty ([]) when there are no GROUP BY expressions. +#[tokio::test] +async fn test_round_trip_groups_display() -> Result<()> { + // Create context with TPC-H tables + let ctx = tpch_context().await?; + + let sql = "select sum(l_extendedprice * l_discount) as revenue from lineitem;"; + roundtrip_test_sql_with_context(sql, &ctx).await?; + + let sql = "select sum(l_extendedprice) as revenue from lineitem;"; + roundtrip_test_sql_with_context(sql, &ctx).await?; + + Ok(()) +} + +// Bug 3 of https://github.com/apache/datafusion/issues/16772 +/// Test that ScalarFunctionExpr return_field name is correctly preserved +/// during serialization/deserialization roundtrip. +/// +/// Test for issue where the return_field.name for scalar functions +/// was not being serialized to protobuf, causing it to be lost during roundtrip +/// and defaulting to a generic name like "f" instead of the proper function name. +#[tokio::test] +async fn test_round_trip_date_part_display() -> Result<()> { + // Create context with TPC-H tables + let ctx = tpch_context().await?; + + let sql = "select extract(year from l_shipdate) as l_year from lineitem "; + roundtrip_test_sql_with_context(sql, &ctx).await?; + + let sql = "select extract(month from l_shipdate) as l_year from lineitem "; + roundtrip_test_sql_with_context(sql, &ctx).await?; + + Ok(()) +} + +#[tokio::test] +async fn test_tpch_part_in_list_query_with_real_parquet_data() -> Result<()> { + use datafusion_common::test_util::datafusion_test_data; + + let ctx = SessionContext::new(); + + // Register the TPC-H part table using the local test data + let test_data = datafusion_test_data(); + let table_sql = format!( + "CREATE EXTERNAL TABLE part STORED AS PARQUET LOCATION '{test_data}/tpch_part_small.parquet'" + ); + ctx.sql(&table_sql).await.map_err(|e| { + DataFusionError::External(format!("Failed to create part table: {e}").into()) + })?; + + // Test the exact problematic query + let sql = + "SELECT p_size FROM part WHERE p_size IN (14, 6, 5, 31) and p_partkey > 1000"; + + let logical_plan = ctx.sql(sql).await?.into_unoptimized_plan(); + let optimized_plan = ctx.state().optimize(&logical_plan)?; + let physical_plan = ctx.state().create_physical_plan(&optimized_plan).await?; + + // Serialize the physical plan - bug may happen here already but not necessarily manifests + let codec = DefaultPhysicalExtensionCodec {}; + let proto = PhysicalPlanNode::try_from_physical_plan(physical_plan.clone(), &codec)?; + + // This will fail with the bug, but should succeed when fixed + let _deserialized_plan = proto.try_into_physical_plan(&ctx.task_ctx(), &codec)?; + Ok(()) +} + +#[tokio::test] +/// Tests that we can serialize an unoptimized "analyze" plan and it will work on the other end +async fn analyze_roundtrip_unoptimized() -> Result<()> { + let ctx = SessionContext::new(); + + // No optimizations + let session_state = + datafusion::execution::SessionStateBuilder::new_from_existing(ctx.state()) + .with_physical_optimizer_rules(vec![]) + .build(); + + let logical_plan = session_state + .create_logical_plan("explain analyze select 1") + .await?; + let plan = session_state.create_physical_plan(&logical_plan).await?; + + let node = PhysicalPlanNode::try_from_physical_plan( + plan.clone(), + &DefaultPhysicalExtensionCodec {}, + )?; + + let node = PhysicalPlanNode::decode(node.encode_to_vec().as_slice()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + let unoptimized = + node.try_into_physical_plan(&ctx.task_ctx(), &DefaultPhysicalExtensionCodec {})?; + + let physical_planner = + datafusion::physical_planner::DefaultPhysicalPlanner::default(); + physical_planner.optimize_physical_plan(unoptimized, &session_state, |_, _| {})?; + Ok(()) +} + +#[test] +fn roundtrip_sort_merge_join() -> Result<()> { + let field_a = Field::new("col_a", DataType::Int64, false); + let field_b = Field::new("col_b", DataType::Int64, false); + let schema_left = Schema::new(vec![field_a.clone()]); + let schema_right = Schema::new(vec![field_b.clone()]); + let on = vec![( + Arc::new(Column::new("col_a", schema_left.index_of("col_a")?)) as _, + Arc::new(Column::new("col_b", schema_right.index_of("col_b")?)) as _, + )]; + + let filter = datafusion::physical_plan::joins::utils::JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("col_a", 1)), + Operator::Gt, + Arc::new(Column::new("col_b", 0)), + )), + vec![ + datafusion::physical_plan::joins::utils::ColumnIndex { + index: 0, + side: datafusion_common::JoinSide::Left, + }, + datafusion::physical_plan::joins::utils::ColumnIndex { + index: 0, + side: datafusion_common::JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![field_a, field_b])), + ); + + let schema_left = Arc::new(schema_left); + let schema_right = Arc::new(schema_right); + for filter in [None, Some(filter)] { + for join_type in [ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::LeftSemi, + JoinType::RightSemi, + ] { + roundtrip_test(Arc::new(SortMergeJoinExec::try_new( + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), + on.clone(), + filter.clone(), + join_type, + vec![Default::default()], + NullEquality::NullEqualsNothing, + )?))?; + } + } + Ok(()) +} + +#[tokio::test] +async fn roundtrip_logical_plan_sort_merge_join() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_csv( + "t0", + "tests/testdata/test.csv", + datafusion::prelude::CsvReadOptions::default().has_header(true), + ) + .await?; + ctx.register_csv( + "t1", + "tests/testdata/test.csv", + datafusion::prelude::CsvReadOptions::default().has_header(true), + ) + .await?; + + ctx.sql("SET datafusion.optimizer.prefer_hash_join = false") + .await? + .show() + .await?; + + let query = "SELECT t1.* FROM t0 join t1 on t0.a = t1.a"; + let plan = ctx.sql(query).await?.create_physical_plan().await?; + roundtrip_test(plan) +} + +#[tokio::test] +async fn roundtrip_memory_source() -> Result<()> { + let ctx = SessionContext::new(); + let plan = ctx + .sql("select * from values ('Tom', 18)") + .await? + .create_physical_plan() + .await?; + roundtrip_test(plan) +} + +#[tokio::test] +async fn roundtrip_listing_table_with_schema_metadata() -> Result<()> { + let ctx = SessionContext::new(); + let file_format = JsonFormat::default(); + let table_partition_cols = vec![("part".to_owned(), DataType::Int64)]; + let data = "../core/tests/data/partitioned_table_json"; + let listing_table_url = ListingTableUrl::parse(data)?; + let listing_options = ListingOptions::new(Arc::new(file_format)) + .with_table_partition_cols(table_partition_cols); + + let config = ListingTableConfig::new(listing_table_url) + .with_listing_options(listing_options) + .infer_schema(&ctx.state()) + .await?; + + // Decorate metadata onto the inferred ListingTable schema + let schema_with_meta = config + .file_schema + .clone() + .map(|s| { + let mut meta: HashMap = HashMap::new(); + meta.insert("foo.bar".to_string(), "baz".to_string()); + s.as_ref().clone().with_metadata(meta) + }) + .expect("Must decorate metadata"); + + let config = config.with_schema(Arc::new(schema_with_meta)); + ctx.register_table("hive_style", Arc::new(ListingTable::try_new(config)?))?; + + let plan = ctx + .sql("select * from hive_style limit 1") + .await? + .create_physical_plan() + .await?; + + roundtrip_test(plan) +} diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index d1b50105d053d..c9ef4377d43b1 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -83,7 +83,7 @@ fn udf_roundtrip_with_registry() { #[test] #[should_panic( - expected = "No function registry provided to deserialize, so can not deserialize User Defined Function 'dummy'" + expected = "LogicalExtensionCodec is not provided for scalar function dummy" )] fn udf_roundtrip_without_registry() { let ctx = context_with_udf(); @@ -256,11 +256,11 @@ fn test_expression_serialization_roundtrip() { use datafusion_proto::logical_plan::from_proto::parse_expr; let ctx = SessionContext::new(); - let lit = Expr::Literal(ScalarValue::Utf8(None)); + let lit = Expr::Literal(ScalarValue::Utf8(None), None); for function in string::functions() { // default to 4 args (though some exprs like substr have error checking) let num_args = 4; - let args: Vec<_> = std::iter::repeat(&lit).take(num_args).cloned().collect(); + let args: Vec<_> = std::iter::repeat_n(&lit, num_args).cloned().collect(); let expr = Expr::ScalarFunction(ScalarFunction::new_udf(function, args)); let extension_codec = DefaultLogicalExtensionCodec {}; diff --git a/datafusion/proto/tests/testdata/test.arrow b/datafusion/proto/tests/testdata/test.arrow new file mode 100644 index 0000000000000..5314d9eea1345 Binary files /dev/null and b/datafusion/proto/tests/testdata/test.arrow differ diff --git a/datafusion/pruning/Cargo.toml b/datafusion/pruning/Cargo.toml new file mode 100644 index 0000000000000..2429123bdf966 --- /dev/null +++ b/datafusion/pruning/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "datafusion-pruning" +description = "DataFusion Pruning Logic" +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } + +[lints] +workspace = true + +[dependencies] +arrow = { workspace = true } +datafusion-common = { workspace = true, default-features = true } +datafusion-datasource = { workspace = true } +datafusion-expr-common = { workspace = true, default-features = true } +datafusion-physical-expr = { workspace = true } +datafusion-physical-expr-common = { workspace = true } +datafusion-physical-plan = { workspace = true } +itertools = { workspace = true } +log = { workspace = true } + +[dev-dependencies] +datafusion-expr = { workspace = true } +datafusion-functions-nested = { workspace = true } +insta = { workspace = true } diff --git a/datafusion/pruning/README.md b/datafusion/pruning/README.md new file mode 100644 index 0000000000000..4db509193d172 --- /dev/null +++ b/datafusion/pruning/README.md @@ -0,0 +1,34 @@ + + +# Apache DataFusion Pruning Logic + +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. + +This crate is a submodule of DataFusion that contains pruning logic, to analyze filter expressions with +statistics such as min/max values and null counts, proving files / large subsections of files can be skipped +without reading the actual data. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/pruning/src/file_pruner.rs b/datafusion/pruning/src/file_pruner.rs new file mode 100644 index 0000000000000..ee86a8cc8cd58 --- /dev/null +++ b/datafusion/pruning/src/file_pruner.rs @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! File-level pruning based on partition values and file-level statistics + +use std::sync::Arc; + +use arrow::datatypes::{FieldRef, Schema, SchemaRef}; +use datafusion_common::{ + pruning::{ + CompositePruningStatistics, PartitionPruningStatistics, PrunableStatistics, + PruningStatistics, + }, + Result, +}; +use datafusion_datasource::PartitionedFile; +use datafusion_physical_expr_common::physical_expr::{snapshot_generation, PhysicalExpr}; +use datafusion_physical_plan::metrics::Count; +use itertools::Itertools; +use log::debug; + +use crate::build_pruning_predicate; + +/// Prune based on partition values and file-level statistics. +pub struct FilePruner { + predicate_generation: Option, + predicate: Arc, + /// Schema used for pruning, which combines the file schema and partition fields. + /// Partition fields are always at the end, as they are during scans. + pruning_schema: Arc, + partitioned_file: PartitionedFile, + partition_fields: Vec, + predicate_creation_errors: Count, +} + +impl FilePruner { + pub fn new( + predicate: Arc, + logical_file_schema: &SchemaRef, + partition_fields: Vec, + partitioned_file: PartitionedFile, + predicate_creation_errors: Count, + ) -> Result { + // Build a pruning schema that combines the file fields and partition fields. + // Partition fields are always at the end. + let pruning_schema = Arc::new( + Schema::new( + logical_file_schema + .fields() + .iter() + .cloned() + .chain(partition_fields.iter().cloned()) + .collect_vec(), + ) + .with_metadata(logical_file_schema.metadata().clone()), + ); + Ok(Self { + // Initialize the predicate generation to None so that the first time we call `should_prune` we actually check the predicate + // Subsequent calls will only do work if the predicate itself has changed. + // See `snapshot_generation` for more info. + predicate_generation: None, + predicate, + pruning_schema, + partitioned_file, + partition_fields, + predicate_creation_errors, + }) + } + + pub fn should_prune(&mut self) -> Result { + let new_generation = snapshot_generation(&self.predicate); + if let Some(current_generation) = self.predicate_generation.as_mut() { + if *current_generation == new_generation { + return Ok(false); + } + *current_generation = new_generation; + } else { + self.predicate_generation = Some(new_generation); + } + let pruning_predicate = build_pruning_predicate( + Arc::clone(&self.predicate), + &self.pruning_schema, + &self.predicate_creation_errors, + ); + if let Some(pruning_predicate) = pruning_predicate { + // The partition column schema is the schema of the table - the schema of the file + let mut pruning = Box::new(PartitionPruningStatistics::try_new( + vec![self.partitioned_file.partition_values.clone()], + self.partition_fields.clone(), + )?) as Box; + if let Some(stats) = &self.partitioned_file.statistics { + let stats_pruning = Box::new(PrunableStatistics::new( + vec![Arc::clone(stats)], + Arc::clone(&self.pruning_schema), + )); + pruning = Box::new(CompositePruningStatistics::new(vec![ + pruning, + stats_pruning, + ])); + } + match pruning_predicate.prune(pruning.as_ref()) { + Ok(values) => { + assert!(values.len() == 1); + // We expect a single container -> if all containers are false skip this file + if values.into_iter().all(|v| !v) { + return Ok(true); + } + } + // Stats filter array could not be built, so we can't prune + Err(e) => { + debug!("Ignoring error building pruning predicate for file: {e}"); + self.predicate_creation_errors.add(1); + } + } + } + + Ok(false) + } +} diff --git a/datafusion/pruning/src/lib.rs b/datafusion/pruning/src/lib.rs new file mode 100644 index 0000000000000..cec4fab2262f8 --- /dev/null +++ b/datafusion/pruning/src/lib.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod file_pruner; +mod pruning_predicate; + +pub use file_pruner::FilePruner; +pub use pruning_predicate::{ + build_pruning_predicate, PredicateRewriter, PruningPredicate, PruningStatistics, + RequiredColumns, UnhandledPredicateHook, +}; diff --git a/datafusion/physical-optimizer/src/pruning.rs b/datafusion/pruning/src/pruning_predicate.rs similarity index 90% rename from datafusion/physical-optimizer/src/pruning.rs rename to datafusion/pruning/src/pruning_predicate.rs index b5287f3d33f3c..fa3454ce56442 100644 --- a/datafusion/physical-optimizer/src/pruning.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -28,12 +28,16 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::{RecordBatch, RecordBatchOptions}, }; -use log::trace; +// pub use for backwards compatibility +pub use datafusion_common::pruning::PruningStatistics; +use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; +use datafusion_physical_plan::metrics::Count; +use log::{debug, trace}; -use datafusion_common::error::{DataFusionError, Result}; +use datafusion_common::error::Result; use datafusion_common::tree_node::TransformedResult; use datafusion_common::{ - internal_err, plan_datafusion_err, plan_err, + internal_datafusion_err, internal_err, plan_datafusion_err, plan_err, tree_node::{Transformed, TreeNode}, ScalarValue, }; @@ -41,108 +45,9 @@ use datafusion_common::{Column, DFSchema}; use datafusion_expr_common::operator::Operator; use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; +use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion_physical_plan::{ColumnarValue, PhysicalExpr}; -/// A source of runtime statistical information to [`PruningPredicate`]s. -/// -/// # Supported Information -/// -/// 1. Minimum and maximum values for columns -/// -/// 2. Null counts and row counts for columns -/// -/// 3. Whether the values in a column are contained in a set of literals -/// -/// # Vectorized Interface -/// -/// Information for containers / files are returned as Arrow [`ArrayRef`], so -/// the evaluation happens once on a single `RecordBatch`, which amortizes the -/// overhead of evaluating the predicate. This is important when pruning 1000s -/// of containers which often happens in analytic systems that have 1000s of -/// potential files to consider. -/// -/// For example, for the following three files with a single column `a`: -/// ```text -/// file1: column a: min=5, max=10 -/// file2: column a: No stats -/// file2: column a: min=20, max=30 -/// ``` -/// -/// PruningStatistics would return: -/// -/// ```text -/// min_values("a") -> Some([5, Null, 20]) -/// max_values("a") -> Some([10, Null, 30]) -/// min_values("X") -> None -/// ``` -pub trait PruningStatistics { - /// Return the minimum values for the named column, if known. - /// - /// If the minimum value for a particular container is not known, the - /// returned array should have `null` in that row. If the minimum value is - /// not known for any row, return `None`. - /// - /// Note: the returned array must contain [`Self::num_containers`] rows - fn min_values(&self, column: &Column) -> Option; - - /// Return the maximum values for the named column, if known. - /// - /// See [`Self::min_values`] for when to return `None` and null values. - /// - /// Note: the returned array must contain [`Self::num_containers`] rows - fn max_values(&self, column: &Column) -> Option; - - /// Return the number of containers (e.g. Row Groups) being pruned with - /// these statistics. - /// - /// This value corresponds to the size of the [`ArrayRef`] returned by - /// [`Self::min_values`], [`Self::max_values`], [`Self::null_counts`], - /// and [`Self::row_counts`]. - fn num_containers(&self) -> usize; - - /// Return the number of null values for the named column as an - /// [`UInt64Array`] - /// - /// See [`Self::min_values`] for when to return `None` and null values. - /// - /// Note: the returned array must contain [`Self::num_containers`] rows - /// - /// [`UInt64Array`]: arrow::array::UInt64Array - fn null_counts(&self, column: &Column) -> Option; - - /// Return the number of rows for the named column in each container - /// as an [`UInt64Array`]. - /// - /// See [`Self::min_values`] for when to return `None` and null values. - /// - /// Note: the returned array must contain [`Self::num_containers`] rows - /// - /// [`UInt64Array`]: arrow::array::UInt64Array - fn row_counts(&self, column: &Column) -> Option; - - /// Returns [`BooleanArray`] where each row represents information known - /// about specific literal `values` in a column. - /// - /// For example, Parquet Bloom Filters implement this API to communicate - /// that `values` are known not to be present in a Row Group. - /// - /// The returned array has one row for each container, with the following - /// meanings: - /// * `true` if the values in `column` ONLY contain values from `values` - /// * `false` if the values in `column` are NOT ANY of `values` - /// * `null` if the neither of the above holds or is unknown. - /// - /// If these statistics can not determine column membership for any - /// container, return `None` (the default). - /// - /// Note: the returned array must contain [`Self::num_containers`] rows - fn contained( - &self, - column: &Column, - values: &HashSet, - ) -> Option; -} - /// Used to prove that arbitrary predicates (boolean expression) can not /// possibly evaluate to `true` given information about a column provided by /// [`PruningStatistics`]. @@ -312,13 +217,13 @@ pub trait PruningStatistics { /// * `true`: there MAY be rows that pass the predicate, **KEEPS** the container /// /// * `NULL`: there MAY be rows that pass the predicate, **KEEPS** the container -/// Note that rewritten predicate can evaluate to NULL when some of -/// the min/max values are not known. *Note that this is different than -/// the SQL filter semantics where `NULL` means the row is filtered -/// out.* +/// Note that rewritten predicate can evaluate to NULL when some of +/// the min/max values are not known. *Note that this is different than +/// the SQL filter semantics where `NULL` means the row is filtered +/// out.* /// /// * `false`: there are no rows that could possibly match the predicate, -/// **PRUNES** the container +/// **PRUNES** the container /// /// For example, given a column `x`, the `x_min`, `x_max`, `x_null_count`, and /// `x_row_count` represent the minimum and maximum values, the null count of @@ -473,6 +378,30 @@ pub struct PruningPredicate { literal_guarantees: Vec, } +/// Build a pruning predicate from an optional predicate expression. +/// If the predicate is None or the predicate cannot be converted to a pruning +/// predicate, return None. +/// If there is an error creating the pruning predicate it is recorded by incrementing +/// the `predicate_creation_errors` counter. +pub fn build_pruning_predicate( + predicate: Arc, + file_schema: &SchemaRef, + predicate_creation_errors: &Count, +) -> Option> { + match PruningPredicate::try_new(predicate, Arc::clone(file_schema)) { + Ok(pruning_predicate) => { + if !pruning_predicate.always_true() { + return Some(Arc::new(pruning_predicate)); + } + } + Err(e) => { + debug!("Could not create pruning predicate for: {e}"); + predicate_creation_errors.add(1); + } + } + None +} + /// Rewrites predicates that [`PredicateRewriter`] can not handle, e.g. certain /// complex expressions or predicates that reference columns that are not in the /// schema. @@ -527,16 +456,23 @@ impl PruningPredicate { /// See the struct level documentation on [`PruningPredicate`] for more /// details. pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { + // Get a (simpler) snapshot of the physical expr here to use with `PruningPredicate` + // which does not handle dynamic exprs in general + let expr = snapshot_physical_expr(expr)?; let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; // build predicate expression once let mut required_columns = RequiredColumns::new(); let predicate_expr = build_predicate_expression( &expr, - schema.as_ref(), + &schema, &mut required_columns, &unhandled_hook, ); + let predicate_schema = required_columns.schema(); + // Simplify the newly created predicate to get rid of redundant casts, comparisons, etc. + let predicate_expr = + PhysicalExprSimplifier::new(&predicate_schema).simplify(predicate_expr)?; let literal_guarantees = LiteralGuarantee::analyze(&expr); @@ -563,7 +499,10 @@ impl PruningPredicate { /// simplified version `b`. See [`ExprSimplifier`] to simplify expressions. /// /// [`ExprSimplifier`]: https://docs.rs/datafusion/latest/datafusion/optimizer/simplify_expressions/struct.ExprSimplifier.html - pub fn prune(&self, statistics: &S) -> Result> { + pub fn prune( + &self, + statistics: &S, + ) -> Result> { let mut builder = BoolVecBuilder::new(statistics.num_containers()); // Try to prove the predicate can't be true for the containers based on @@ -581,9 +520,9 @@ impl PruningPredicate { // If `contained` returns false, that means the column is // not any of the values so we can prune the container Guarantee::In => builder.combine_array(&results), - // `NotIn` means the values in the column must must not be + // `NotIn` means the values in the column must not be // any of the values in the set for the predicate to - // evaluate to true. If contained returns true, it means the + // evaluate to true. If `contained` returns true, it means the // column is only in the set of values so we can prune the // container Guarantee::NotIn => { @@ -747,6 +686,13 @@ fn is_always_true(expr: &Arc) -> bool { .unwrap_or_default() } +fn is_always_false(expr: &Arc) -> bool { + expr.as_any() + .downcast_ref::() + .map(|l| matches!(l.value(), ScalarValue::Boolean(Some(false)))) + .unwrap_or_default() +} + /// Describes which columns statistics are necessary to evaluate a /// [`PruningPredicate`]. /// @@ -794,6 +740,21 @@ impl RequiredColumns { } } + /// Returns a schema that describes the columns required to evaluate this + /// pruning predicate. + /// The schema contains the fields for each column in `self.columns` with + /// the appropriate data type for the statistics. + /// Order matters, this same order is used to evaluate the + /// pruning predicate. + fn schema(&self) -> Schema { + let fields = self + .columns + .iter() + .map(|(_c, _t, f)| f.clone()) + .collect::>(); + Schema::new(fields) + } + /// Returns an iterator over items in columns (see doc on /// `self.columns` for details) pub(crate) fn iter( @@ -915,7 +876,7 @@ impl From> for RequiredColumns { /// Build a RecordBatch from a list of statistics, creating arrays, /// with one row for each PruningStatistics and columns specified in -/// in the required_columns parameter. +/// the required_columns parameter. /// /// For example, if the requested columns are /// ```text @@ -938,11 +899,10 @@ impl From> for RequiredColumns { /// -------+-------- /// 5 | 1000 /// ``` -fn build_statistics_record_batch( +fn build_statistics_record_batch( statistics: &S, required_columns: &RequiredColumns, ) -> Result { - let mut fields = Vec::::new(); let mut arrays = Vec::::new(); // For each needed statistics column: for (column, statistics_type, stat_field) in required_columns.iter() { @@ -971,20 +931,15 @@ fn build_statistics_record_batch( // provides timestamp statistics as "Int64") let array = arrow::compute::cast(&array, data_type)?; - fields.push(stat_field.clone()); arrays.push(array); } - let schema = Arc::new(Schema::new(fields)); + let schema = Arc::new(required_columns.schema()); // provide the count in case there were no needed statistics let mut options = RecordBatchOptions::default(); options.row_count = Some(statistics.num_containers()); - trace!( - "Creating statistics batch for {:#?} with {:#?}", - required_columns, - arrays - ); + trace!("Creating statistics batch for {required_columns:#?} with {arrays:#?}"); RecordBatch::try_new_with_options(schema, arrays, &options).map_err(|err| { plan_datafusion_err!("Can not create statistics record batch: {err}") @@ -1005,7 +960,7 @@ impl<'a> PruningExpressionBuilder<'a> { left: &'a Arc, right: &'a Arc, op: Operator, - schema: &'a Schema, + schema: &'a SchemaRef, required_columns: &'a mut RequiredColumns, ) -> Result { // find column name; input could be a more complicated expression @@ -1023,7 +978,7 @@ impl<'a> PruningExpressionBuilder<'a> { } }; - let df_schema = DFSchema::try_from(schema.clone())?; + let df_schema = DFSchema::try_from(Arc::clone(schema))?; let (column_expr, correct_operator, scalar_expr) = rewrite_expr_to_prunable( column_expr, correct_operator, @@ -1139,8 +1094,8 @@ fn rewrite_expr_to_prunable( Ok((Arc::clone(column_expr), op, Arc::clone(scalar_expr))) } else if let Some(cast) = column_expr_any.downcast_ref::() { // `cast(col) op lit()` - let arrow_schema: SchemaRef = schema.clone().into(); - let from_type = cast.expr().data_type(&arrow_schema)?; + let arrow_schema = schema.as_arrow(); + let from_type = cast.expr().data_type(arrow_schema)?; verify_support_type_for_prune(&from_type, cast.cast_type())?; let (left, op, right) = rewrite_expr_to_prunable(cast.expr(), op, scalar_expr, schema)?; @@ -1154,8 +1109,8 @@ fn rewrite_expr_to_prunable( column_expr_any.downcast_ref::() { // `try_cast(col) op lit()` - let arrow_schema: SchemaRef = schema.clone().into(); - let from_type = try_cast.expr().data_type(&arrow_schema)?; + let arrow_schema = schema.as_arrow(); + let from_type = try_cast.expr().data_type(arrow_schema)?; verify_support_type_for_prune(&from_type, try_cast.cast_type())?; let (left, op, right) = rewrite_expr_to_prunable(try_cast.expr(), op, scalar_expr, schema)?; @@ -1206,23 +1161,35 @@ fn is_compare_op(op: Operator) -> bool { ) } +fn is_string_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) +} + // The pruning logic is based on the comparing the min/max bounds. // Must make sure the two type has order. // For example, casts from string to numbers is not correct. // Because the "13" is less than "3" with UTF8 comparison order. fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Result<()> { - // TODO: support other data type for prunable cast or try cast - if matches!( - from_type, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Decimal128(_, _) - ) && matches!( - to_type, - DataType::Int8 | DataType::Int32 | DataType::Int64 | DataType::Decimal128(_, _) - ) { + // Dictionary casts are always supported as long as the value types are supported + let from_type = match from_type { + DataType::Dictionary(_, t) => { + return verify_support_type_for_prune(t.as_ref(), to_type) + } + _ => from_type, + }; + let to_type = match to_type { + DataType::Dictionary(_, t) => { + return verify_support_type_for_prune(from_type, t.as_ref()) + } + _ => to_type, + }; + // If both types are strings or both are not strings (number, timestamp, etc) + // then we can compare them. + // PruningPredicate does not support casting of strings to numbers and such. + if is_string_type(from_type) == is_string_type(to_type) { Ok(()) } else { plan_err!( @@ -1251,9 +1218,9 @@ fn rewrite_column_expr( fn reverse_operator(op: Operator) -> Result { op.swap().ok_or_else(|| { - DataFusionError::Internal(format!( + internal_datafusion_err!( "Could not reverse operator {op} while building pruning predicate" - )) + ) }) } @@ -1401,7 +1368,7 @@ impl PredicateRewriter { let mut required_columns = RequiredColumns::new(); build_predicate_expression( expr, - schema, + &Arc::new(schema.clone()), &mut required_columns, &self.unhandled_hook, ) @@ -1419,10 +1386,15 @@ impl PredicateRewriter { /// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook` fn build_predicate_expression( expr: &Arc, - schema: &Schema, + schema: &SchemaRef, required_columns: &mut RequiredColumns, unhandled_hook: &Arc, ) -> Arc { + if is_always_false(expr) { + // Shouldn't return `unhandled_hook.handle(expr)` + // Because it will transfer false to true. + return Arc::clone(expr); + } // predicate expression can only be a binary expression let expr_any = expr.as_any(); if let Some(is_null) = expr_any.downcast_ref::() { @@ -1522,6 +1494,11 @@ fn build_predicate_expression( build_predicate_expression(&right, schema, required_columns, unhandled_hook); // simplify boolean expression if applicable let expr = match (&left_expr, op, &right_expr) { + (left, Operator::And, right) + if is_always_false(left) || is_always_false(right) => + { + Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(false)))) + } (left, Operator::And, _) if is_always_true(left) => right_expr, (_, Operator::And, right) if is_always_true(right) => left_expr, (left, Operator::Or, right) @@ -1529,6 +1506,9 @@ fn build_predicate_expression( { Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))) } + (left, Operator::Or, _) if is_always_false(left) => right_expr, + (_, Operator::Or, right) if is_always_false(right) => left_expr, + _ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)), }; return expr; @@ -1540,7 +1520,10 @@ fn build_predicate_expression( Ok(builder) => builder, // allow partial failure in predicate expression generation // this can still produce a useful predicate when multiple conditions are joined using AND - Err(_) => return unhandled_hook.handle(expr), + Err(e) => { + debug!("Error building pruning expression: {e}"); + return unhandled_hook.handle(expr); + } }; build_statistics_expr(&mut expr_builder) @@ -1885,7 +1868,7 @@ mod tests { use super::*; use datafusion_common::test_util::batches_to_string; - use datafusion_expr::{col, lit}; + use datafusion_expr::{and, col, lit, or}; use insta::assert_snapshot; use arrow::array::Decimal128Array; @@ -2301,8 +2284,7 @@ mod tests { let was_new = fields.insert(field); if !was_new { panic!( - "Duplicate field in required schema: {:?}. Previous fields:\n{:#?}", - field, fields + "Duplicate field in required schema: {field:?}. Previous fields:\n{fields:#?}" ); } } @@ -2807,8 +2789,8 @@ mod tests { let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut required_columns); assert_eq!(predicate_expr.to_string(), expected_expr); - println!("required_columns: {:#?}", required_columns); // for debugging assertions below - // c1 < 1 should add c1_min + println!("required_columns: {required_columns:#?}"); // for debugging assertions below + // c1 < 1 should add c1_min let c1_min_field = Field::new("c1_min", DataType::Int32, false); assert_eq!( required_columns.columns[0], @@ -3002,7 +2984,7 @@ mod tests { } #[test] - fn row_group_predicate_cast() -> Result<()> { + fn row_group_predicate_cast_int_int() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)"; @@ -3039,6 +3021,291 @@ mod tests { Ok(()) } + #[test] + fn row_group_predicate_cast_string_string() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Utf8) <= 1 AND 1 <= CAST(c1_max@1 AS Utf8)"; + + // test column on the left + let expr = cast(col("c1"), DataType::Utf8) + .eq(lit(ScalarValue::Utf8(Some("1".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("1".to_string()))) + .eq(cast(col("c1"), DataType::Utf8)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_cast_string_int() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast(col("c1"), DataType::Int32).eq(lit(ScalarValue::Int32(Some(1)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Int32(Some(1))).eq(cast(col("c1"), DataType::Int32)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_cast_int_string() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast(col("c1"), DataType::Utf8) + .eq(lit(ScalarValue::Utf8(Some("1".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("1".to_string()))) + .eq(cast(col("c1"), DataType::Utf8)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_date_date() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Date64) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Date64)"; + + // test column on the left + let expr = + cast(col("c1"), DataType::Date64).eq(lit(ScalarValue::Date64(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = + lit(ScalarValue::Date64(Some(123))).eq(cast(col("c1"), DataType::Date64)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_string_date() -> Result<()> { + // Test with Dictionary for the literal + let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + ) + .eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))).eq(cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + )); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_date_dict_string() -> Result<()> { + // Test with Dictionary for the column + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + false, + )]); + let expected_expr = "true"; + + // test column on the left + let expr = + cast(col("c1"), DataType::Date32).eq(lit(ScalarValue::Date32(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = + lit(ScalarValue::Date32(Some(123))).eq(cast(col("c1"), DataType::Date32)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_dict_same_value_type() -> Result<()> { + // Test with Dictionary types that have the same value type but different key types + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + false, + )]); + + // Direct comparison with no cast + let expr = col("c1").eq(lit(ScalarValue::Utf8(Some("test".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + let expected_expr = + "c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= c1_max@1"; + assert_eq!(predicate_expr.to_string(), expected_expr); + + // Test with column cast to a dictionary with different key type + let expr = cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + ) + .eq(lit(ScalarValue::Utf8(Some("test".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Dictionary(UInt16, Utf8)) <= test AND test <= CAST(c1_max@1 AS Dictionary(UInt16, Utf8))"; + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_dict_different_value_type() -> Result<()> { + // Test with Dictionary types that have different value types + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Int32)), + false, + )]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 123 AND 123 <= CAST(c1_max@1 AS Int64)"; + + // Test with literal of a different type + let expr = + cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_nested_dict() -> Result<()> { + // Test with nested Dictionary types + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary( + Box::new(DataType::UInt8), + Box::new(DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + )), + ), + false, + )]); + let expected_expr = + "c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= c1_max@1"; + + // Test with a simple literal + let expr = col("c1").eq(lit(ScalarValue::Utf8(Some("test".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_date_dict_date() -> Result<()> { + // Test with dictionary-wrapped date types for both sides + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Date32)), + false, + )]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Dictionary(UInt16, Date64)) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Dictionary(UInt16, Date64))"; + + // Test with a cast to a different date type + let expr = cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Date64)), + ) + .eq(lit(ScalarValue::Date64(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_date_string() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = + cast(col("c1"), DataType::Date32).eq(lit(ScalarValue::Date32(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = + lit(ScalarValue::Date32(Some(123))).eq(cast(col("c1"), DataType::Date32)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_string_date() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast(col("c1"), DataType::Utf8) + .eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))) + .eq(cast(col("c1"), DataType::Utf8)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + #[test] fn row_group_predicate_cast_list() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); @@ -3281,12 +3548,10 @@ mod tests { prune_with_expr( // false - // constant literals that do NOT refer to any columns are currently not evaluated at all, hence the result is - // "all true" lit(false), &schema, &statistics, - &[true, true, true, true, true], + &[false, false, false, false, false], ); } @@ -4158,7 +4423,7 @@ mod tests { // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) true, // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. (min, max) maybe truncate - // orignal (min, max) maybe ("A\u{10ffff}\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}\u{10ffff}\u{10ffff}") + // original (min, max) maybe ("A\u{10ffff}\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}\u{10ffff}\u{10ffff}") true, ]; prune_with_expr(expr, &schema, &statistics, expected_ret); @@ -4851,7 +5116,7 @@ mod tests { statistics: &TestStatistics, expected: &[bool], ) { - println!("Pruning with expr: {}", expr); + println!("Pruning with expr: {expr}"); let expr = logical2physical(&expr, schema); let p = PruningPredicate::try_new(expr, Arc::::clone(schema)).unwrap(); let result = p.prune(statistics).unwrap(); @@ -4865,6 +5130,49 @@ mod tests { ) -> Arc { let expr = logical2physical(expr, schema); let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; - build_predicate_expression(&expr, schema, required_columns, &unhandled_hook) + build_predicate_expression( + &expr, + &Arc::new(schema.clone()), + required_columns, + &unhandled_hook, + ) + } + + #[test] + fn test_build_predicate_expression_with_false() { + let expr = lit(ScalarValue::Boolean(Some(false))); + let schema = Schema::empty(); + let res = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + let expected = logical2physical(&expr, &schema); + assert_eq!(&res, &expected); + } + + #[test] + fn test_build_predicate_expression_with_and_false() { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let expr = and( + col("c1").eq(lit("a")), + lit(ScalarValue::Boolean(Some(false))), + ); + let res = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + let expected = logical2physical(&lit(ScalarValue::Boolean(Some(false))), &schema); + assert_eq!(&res, &expected); + } + + #[test] + fn test_build_predicate_expression_with_or_false() { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let left_expr = col("c1").eq(lit("a")); + let right_expr = lit(ScalarValue::Boolean(Some(false))); + let res = test_build_predicate_expression( + &or(left_expr.clone(), right_expr.clone()), + &schema, + &mut RequiredColumns::new(), + ); + let expected = + "c1_null_count@2 != row_count@3 AND c1_min@0 <= a AND a <= c1_max@1"; + assert_eq!(res.to_string(), expected); } } diff --git a/datafusion/session/Cargo.toml b/datafusion/session/Cargo.toml index c6e268735a7b3..0489da61eed86 100644 --- a/datafusion/session/Cargo.toml +++ b/datafusion/session/Cargo.toml @@ -18,11 +18,11 @@ [package] name = "datafusion-session" description = "datafusion-session" +readme = "README.md" authors.workspace = true edition.workspace = true homepage.workspace = true license.workspace = true -readme.workspace = true repository.workspace = true rust-version.workspace = true version.workspace = true @@ -31,22 +31,12 @@ version.workspace = true all-features = true [dependencies] -arrow = { workspace = true } async-trait = { workspace = true } -dashmap = { workspace = true } datafusion-common = { workspace = true } -datafusion-common-runtime = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-physical-expr = { workspace = true } datafusion-physical-plan = { workspace = true } -datafusion-sql = { workspace = true } -futures = { workspace = true } -itertools = { workspace = true } -log = { workspace = true } -object_store = { workspace = true } parking_lot = { workspace = true } -tokio = { workspace = true } [lints] workspace = true diff --git a/datafusion/session/README.md b/datafusion/session/README.md index 019f9f8892476..4bb605b1e199c 100644 --- a/datafusion/session/README.md +++ b/datafusion/session/README.md @@ -17,10 +17,16 @@ under the License. --> -# DataFusion Session +# Apache DataFusion Session -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. This crate provides **session-related abstractions** used in the DataFusion query engine. A _session_ represents the runtime context for query execution, including configuration, runtime environment, function registry, and planning. -[df]: https://crates.io/crates/datafusion +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml new file mode 100644 index 0000000000000..b95cc31caec68 --- /dev/null +++ b/datafusion/spark/Cargo.toml @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-spark" +description = "DataFusion expressions that emulate Apache Spark's behavior" +version = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +readme = "README.md" +license = { workspace = true } +edition = { workspace = true } + +[package.metadata.docs.rs] +all-features = true + +[lints] +workspace = true + +[lib] +name = "datafusion_spark" + +[dependencies] +arrow = { workspace = true } +bigdecimal = { workspace = true } +chrono = { workspace = true } +crc32fast = "1.4" +datafusion-catalog = { workspace = true } +datafusion-common = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true, features = ["crypto_expressions"] } +log = { workspace = true } +sha1 = "0.10" +url = { workspace = true } + +[dev-dependencies] +criterion = { workspace = true } +rand = { workspace = true } + +[[bench]] +harness = false +name = "char" diff --git a/datafusion/spark/LICENSE.txt b/datafusion/spark/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/spark/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/spark/NOTICE.txt b/datafusion/spark/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/spark/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/spark/README.md b/datafusion/spark/README.md new file mode 100644 index 0000000000000..7cb24084cd228 --- /dev/null +++ b/datafusion/spark/README.md @@ -0,0 +1,49 @@ + + +# Apache DataFusion Spark-compatible Expressions + +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. + +This crate is a submodule of DataFusion that provides [Apache Spark] compatible expressions for use with DataFusion. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[apache spark]: https://spark.apache.org/ + +## Testing Guide + +When testing functions by directly invoking them (e.g., `test_scalar_function!()`), input coercion (from the `signature` +or `coerce_types`) is not applied. + +Therefore, direct invocation tests should only be used to verify that the function is correctly implemented. + +Please be sure to add additional tests beyond direct invocation. +For more detailed testing guidelines, refer to the [Spark SQLLogicTest README]. + +## Implementation References + +When implementing Spark-compatible functions, you can check if there are existing implementations in +the [Sail] or [Comet] projects first. +If you do port functionality from these sources, make sure to port over the corresponding tests too, to ensure +correctness and compatibility. + +[spark sqllogictest readme]: ../sqllogictest/test_files/spark/README.md +[sail]: https://github.com/lakehq/sail +[comet]: https://github.com/apache/datafusion-comet diff --git a/datafusion/spark/benches/char.rs b/datafusion/spark/benches/char.rs new file mode 100644 index 0000000000000..e30e21f69d183 --- /dev/null +++ b/datafusion/spark/benches/char.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::datatypes::{DataType, Field}; +use arrow::{array::PrimitiveArray, datatypes::Int64Type}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_spark::function::string::char; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::sync::Arc; + +/// Returns fixed seedable RNG +pub fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +fn criterion_benchmark(c: &mut Criterion) { + let cot_fn = char(); + let size = 1024; + let input: PrimitiveArray = { + let null_density = 0.2; + let mut rng = StdRng::seed_from_u64(42); + (0..size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range::(1i64..10_000)) + } + }) + .collect() + }; + let input = Arc::new(input); + let args = vec![ColumnarValue::Array(input)]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function("char", |b| { + b.iter(|| { + black_box( + cot_fn + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Arc::new(Field::new("f", DataType::Utf8, true)), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs new file mode 100644 index 0000000000000..a22561ba8b9ca --- /dev/null +++ b/datafusion/spark/src/function/aggregate/avg.rs @@ -0,0 +1,350 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrowNativeTypeOp; +use arrow::array::{ + builder::PrimitiveBuilder, + cast::AsArray, + types::{Float64Type, Int64Type}, + Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray, +}; +use arrow::compute::sum; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{not_impl_err, Result, ScalarValue}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::coerce_avg_type; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{ + type_coercion::aggregates::avg_return_type, Accumulator, AggregateUDFImpl, EmitTo, + GroupsAccumulator, ReversedUDAF, Signature, +}; +use std::{any::Any, sync::Arc}; +use DataType::*; + +/// AVG aggregate expression +/// Spark average aggregate expression. Differs from standard DataFusion average aggregate +/// in that it uses an `i64` for the count (DataFusion version uses `u64`); also there is ANSI mode +/// support planned in the future for Spark version. + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SparkAvg { + name: String, + signature: Signature, + input_data_type: DataType, + result_data_type: DataType, +} + +impl SparkAvg { + /// Implement AVG aggregate function + pub fn new(name: impl Into, data_type: DataType) -> Self { + let result_data_type = avg_return_type("avg", &data_type).unwrap(); + + Self { + name: name.into(), + signature: Signature::user_defined(Immutable), + input_data_type: data_type, + result_data_type, + } + } +} + +impl AggregateUDFImpl for SparkAvg { + fn as_any(&self) -> &dyn Any { + self + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + // instantiate specialized accumulator based for the type + match (&self.input_data_type, &self.result_data_type) { + (Float64, Float64) => Ok(Box::::default()), + _ => not_impl_err!( + "AvgAccumulator for ({} --> {})", + self.input_data_type, + self.result_data_type + ), + } + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(vec![ + Arc::new(Field::new( + format_state_name(&self.name, "sum"), + self.input_data_type.clone(), + true, + )), + Arc::new(Field::new( + format_state_name(&self.name, "count"), + Int64, + true, + )), + ]) + } + + fn name(&self) -> &str { + &self.name + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + // instantiate specialized accumulator based for the type + match (&self.input_data_type, &self.result_data_type) { + (Float64, Float64) => { + Ok(Box::new(AvgGroupsAccumulator::::new( + &self.input_data_type, + |sum: f64, count: i64| Ok(sum / count as f64), + ))) + } + + _ => not_impl_err!( + "AvgGroupsAccumulator for ({} --> {})", + self.input_data_type, + self.result_data_type + ), + } + } + + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::Float64(None)) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + avg_return_type(self.name(), &arg_types[0]) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [arg] = take_function_args(self.name(), arg_types)?; + coerce_avg_type(self.name(), std::slice::from_ref(arg)) + } +} + +/// An accumulator to compute the average +#[derive(Debug, Default)] +pub struct AvgAccumulator { + sum: Option, + count: i64, +} + +impl Accumulator for AvgAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::Float64(self.sum), + ScalarValue::from(self.count), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.count += (values.len() - values.null_count()) as i64; + let v = self.sum.get_or_insert(0.); + if let Some(x) = sum(values) { + *v += x; + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // counts are summed + self.count += sum(states[1].as_primitive::()).unwrap_or_default(); + + // sums are summed + if let Some(x) = sum(states[0].as_primitive::()) { + let v = self.sum.get_or_insert(0.); + *v += x; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + if self.count == 0 { + // If all input are nulls, count will be 0 and we will get null after the division. + // This is consistent with Spark Average implementation. + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64( + self.sum.map(|f| f / self.count as f64), + )) + } + } + + fn size(&self) -> usize { + size_of_val(self) + } +} + +/// An accumulator to compute the average of `[PrimitiveArray]`. +/// Stores values as native types, and does overflow checking +/// +/// F: Function that calculates the average value from a sum of +/// T::Native and a total count +#[derive(Debug)] +struct AvgGroupsAccumulator +where + T: ArrowNumericType + Send, + F: Fn(T::Native, i64) -> Result + Send, +{ + /// The type of the returned average + return_data_type: DataType, + + /// Count per group (use i64 to make Int64Array) + counts: Vec, + + /// Sums per group, stored as the native type + sums: Vec, + + /// Function that computes the final average (value / count) + avg_fn: F, +} + +impl AvgGroupsAccumulator +where + T: ArrowNumericType + Send, + F: Fn(T::Native, i64) -> Result + Send, +{ + pub fn new(return_data_type: &DataType, avg_fn: F) -> Self { + Self { + return_data_type: return_data_type.clone(), + counts: vec![], + sums: vec![], + avg_fn, + } + } +} + +impl GroupsAccumulator for AvgGroupsAccumulator +where + T: ArrowNumericType + Send, + F: Fn(T::Native, i64) -> Result + Send, +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values[0].as_primitive::(); + let data = values.values(); + + // increment counts, update sums + self.counts.resize(total_num_groups, 0); + self.sums.resize(total_num_groups, T::default_value()); + + let iter = group_indices.iter().zip(data.iter()); + if values.null_count() == 0 { + for (&group_index, &value) in iter { + let sum = &mut self.sums[group_index]; + *sum = (*sum).add_wrapping(value); + self.counts[group_index] += 1; + } + } else { + for (idx, (&group_index, &value)) in iter.enumerate() { + if values.is_null(idx) { + continue; + } + let sum = &mut self.sums[group_index]; + *sum = (*sum).add_wrapping(value); + + self.counts[group_index] += 1; + } + } + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 2, "two arguments to merge_batch"); + // first batch is partial sums, second is counts + let partial_sums = values[0].as_primitive::(); + let partial_counts = values[1].as_primitive::(); + // update counts with partial counts + self.counts.resize(total_num_groups, 0); + let iter1 = group_indices.iter().zip(partial_counts.values().iter()); + for (&group_index, &partial_count) in iter1 { + self.counts[group_index] += partial_count; + } + + // update sums + self.sums.resize(total_num_groups, T::default_value()); + let iter2 = group_indices.iter().zip(partial_sums.values().iter()); + for (&group_index, &new_value) in iter2 { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let counts = emit_to.take_needed(&mut self.counts); + let sums = emit_to.take_needed(&mut self.sums); + let mut builder = PrimitiveBuilder::::with_capacity(sums.len()); + let iter = sums.into_iter().zip(counts); + + for (sum, count) in iter { + if count != 0 { + builder.append_value((self.avg_fn)(sum, count)?) + } else { + builder.append_null(); + } + } + let array: PrimitiveArray = builder.finish(); + + Ok(Arc::new(array)) + } + + // return arrays for sums and counts + fn state(&mut self, emit_to: EmitTo) -> Result> { + let counts = emit_to.take_needed(&mut self.counts); + let counts = Int64Array::new(counts.into(), None); + + let sums = emit_to.take_needed(&mut self.sums); + let sums = PrimitiveArray::::new(sums.into(), None) + .with_data_type(self.return_data_type.clone()); + + Ok(vec![ + Arc::new(sums) as ArrayRef, + Arc::new(counts) as ArrayRef, + ]) + } + + fn size(&self) -> usize { + self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() + } +} diff --git a/datafusion/spark/src/function/aggregate/mod.rs b/datafusion/spark/src/function/aggregate/mod.rs new file mode 100644 index 0000000000000..54001d28da6b4 --- /dev/null +++ b/datafusion/spark/src/function/aggregate/mod.rs @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; +use datafusion_expr::AggregateUDF; +use std::sync::Arc; + +pub mod avg; +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!((avg, "Returns the average value of a given column", arg1)); +} + +pub fn avg() -> Arc { + Arc::new(AggregateUDF::new_from_impl(avg::SparkAvg::new( + "avg", + DataType::Float64, + ))) +} + +pub fn functions() -> Vec> { + vec![avg()] +} diff --git a/datafusion/spark/src/function/array/mod.rs b/datafusion/spark/src/function/array/mod.rs new file mode 100644 index 0000000000000..fed52a494281d --- /dev/null +++ b/datafusion/spark/src/function/array/mod.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod spark_array; + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +make_udf_function!(spark_array::SparkArray, array); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!((array, "Returns an array with the given elements.", args)); +} + +pub fn functions() -> Vec> { + vec![array()] +} diff --git a/datafusion/spark/src/function/array/spark_array.rs b/datafusion/spark/src/function/array/spark_array.rs new file mode 100644 index 0000000000000..bf5842cb5a5a6 --- /dev/null +++ b/datafusion/spark/src/function/array/spark_array.rs @@ -0,0 +1,264 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::array::{ + make_array, new_null_array, Array, ArrayData, ArrayRef, Capacities, GenericListArray, + MutableArrayData, NullArray, OffsetSizeTrait, +}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::utils::SingleRowListArrayBuilder; +use datafusion_common::{plan_datafusion_err, plan_err, Result}; +use datafusion_expr::type_coercion::binary::comparison_coercion; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, +}; + +use crate::function::functions_nested_utils::make_scalar_function; + +const ARRAY_FIELD_DEFAULT_NAME: &str = "element"; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkArray { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkArray { + fn default() -> Self { + Self::new() + } +} + +impl SparkArray { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![TypeSignature::UserDefined, TypeSignature::Nullary], + Volatility::Immutable, + ), + aliases: vec![String::from("spark_make_array")], + } + } +} + +impl ScalarUDFImpl for SparkArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let mut expr_type = DataType::Null; + for arg_type in arg_types { + if !arg_type.equals_datatype(&DataType::Null) { + expr_type = arg_type.clone(); + break; + } + } + + if expr_type.is_null() { + expr_type = DataType::Int32; + } + + Ok(DataType::List(Arc::new(Field::new( + ARRAY_FIELD_DEFAULT_NAME, + expr_type, + true, + )))) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let data_types = args + .arg_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let return_type = self.return_type(&data_types)?; + Ok(Arc::new(Field::new( + "this_field_name_is_irrelevant", + return_type, + false, + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + make_scalar_function(make_array_inner)(args.as_slice()) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let first_type = arg_types.first().ok_or_else(|| { + plan_datafusion_err!("Spark array function requires at least one argument") + })?; + let new_type = + arg_types + .iter() + .skip(1) + .try_fold(first_type.clone(), |acc, x| { + // The coerced types found by `comparison_coercion` are not guaranteed to be + // coercible for the arguments. `comparison_coercion` returns more loose + // types that can be coerced to both `acc` and `x` for comparison purpose. + // See `maybe_data_types` for the actual coercion. + let coerced_type = comparison_coercion(&acc, x); + if let Some(coerced_type) = coerced_type { + Ok(coerced_type) + } else { + plan_err!("Coercion from {acc} to {x} failed.") + } + })?; + Ok(vec![new_type; arg_types.len()]) + } +} + +/// `make_array_inner` is the implementation of the `make_array` function. +/// Constructs an array using the input `data` as `ArrayRef`. +/// Returns a reference-counted `Array` instance result. +pub fn make_array_inner(arrays: &[ArrayRef]) -> Result { + let mut data_type = DataType::Null; + for arg in arrays { + let arg_data_type = arg.data_type(); + if !arg_data_type.equals_datatype(&DataType::Null) { + data_type = arg_data_type.clone(); + break; + } + } + + match data_type { + // Either an empty array or all nulls: + DataType::Null => { + let length = arrays.iter().map(|a| a.len()).sum(); + // By default Int32 + let array = new_null_array(&DataType::Int32, length); + Ok(Arc::new( + SingleRowListArrayBuilder::new(array) + .with_nullable(true) + .with_field_name(Some(ARRAY_FIELD_DEFAULT_NAME.to_string())) + .build_list_array(), + )) + } + DataType::LargeList(..) => array_array::(arrays, data_type), + _ => array_array::(arrays, data_type), + } +} + +/// Convert one or more [`ArrayRef`] of the same type into a +/// `ListArray` or 'LargeListArray' depending on the offset size. +/// +/// # Example (non nested) +/// +/// Calling `array(col1, col2)` where col1 and col2 are non nested +/// would return a single new `ListArray`, where each row was a list +/// of 2 elements: +/// +/// ```text +/// ┌─────────┐ ┌─────────┐ ┌──────────────┐ +/// │ ┌─────┐ │ │ ┌─────┐ │ │ ┌──────────┐ │ +/// │ │ A │ │ │ │ X │ │ │ │ [A, X] │ │ +/// │ ├─────┤ │ │ ├─────┤ │ │ ├──────────┤ │ +/// │ │NULL │ │ │ │ Y │ │──────────▶│ │[NULL, Y] │ │ +/// │ ├─────┤ │ │ ├─────┤ │ │ ├──────────┤ │ +/// │ │ C │ │ │ │ Z │ │ │ │ [C, Z] │ │ +/// │ └─────┘ │ │ └─────┘ │ │ └──────────┘ │ +/// └─────────┘ └─────────┘ └──────────────┘ +/// col1 col2 output +/// ``` +/// +/// # Example (nested) +/// +/// Calling `array(col1, col2)` where col1 and col2 are lists +/// would return a single new `ListArray`, where each row was a list +/// of the corresponding elements of col1 and col2. +/// +/// ``` text +/// ┌──────────────┐ ┌──────────────┐ ┌─────────────────────────────┐ +/// │ ┌──────────┐ │ │ ┌──────────┐ │ │ ┌────────────────────────┐ │ +/// │ │ [A, X] │ │ │ │ [] │ │ │ │ [[A, X], []] │ │ +/// │ ├──────────┤ │ │ ├──────────┤ │ │ ├────────────────────────┤ │ +/// │ │[NULL, Y] │ │ │ │[Q, R, S] │ │───────▶│ │ [[NULL, Y], [Q, R, S]] │ │ +/// │ ├──────────┤ │ │ ├──────────┤ │ │ ├────────────────────────│ │ +/// │ │ [C, Z] │ │ │ │ NULL │ │ │ │ [[C, Z], NULL] │ │ +/// │ └──────────┘ │ │ └──────────┘ │ │ └────────────────────────┘ │ +/// └──────────────┘ └──────────────┘ └─────────────────────────────┘ +/// col1 col2 output +/// ``` +fn array_array( + args: &[ArrayRef], + data_type: DataType, +) -> Result { + // do not accept 0 arguments. + if args.is_empty() { + return plan_err!("Array requires at least one argument"); + } + + let mut data = vec![]; + let mut total_len = 0; + for arg in args { + let arg_data = if arg.as_any().is::() { + ArrayData::new_empty(&data_type) + } else { + arg.to_data() + }; + total_len += arg_data.len(); + data.push(arg_data); + } + + let mut offsets: Vec = Vec::with_capacity(total_len); + offsets.push(O::usize_as(0)); + + let capacity = Capacities::Array(total_len); + let data_ref = data.iter().collect::>(); + let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity); + + let num_rows = args[0].len(); + for row_idx in 0..num_rows { + for (arr_idx, arg) in args.iter().enumerate() { + if !arg.as_any().is::() + && !arg.is_null(row_idx) + && arg.is_valid(row_idx) + { + mutable.extend(arr_idx, row_idx, row_idx + 1); + } else { + mutable.extend_nulls(1); + } + } + offsets.push(O::usize_as(mutable.len())); + } + let data = mutable.freeze(); + + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new(ARRAY_FIELD_DEFAULT_NAME, data_type, true)), + OffsetBuffer::new(offsets.into()), + make_array(data), + None, + )?)) +} diff --git a/datafusion/spark/src/function/bitmap/bitmap_count.rs b/datafusion/spark/src/function/bitmap/bitmap_count.rs new file mode 100644 index 0000000000000..15bd33229a3d5 --- /dev/null +++ b/datafusion/spark/src/function/bitmap/bitmap_count.rs @@ -0,0 +1,174 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, BinaryArray, BinaryViewArray, FixedSizeBinaryArray, Int64Array, + LargeBinaryArray, +}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{ + Binary, BinaryView, FixedSizeBinary, Int64, LargeBinary, +}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; +use datafusion_functions::downcast_arg; +use datafusion_functions::utils::make_scalar_function; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct BitmapCount { + signature: Signature, +} + +impl Default for BitmapCount { + fn default() -> Self { + Self::new() + } +} + +impl BitmapCount { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Binary)], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for BitmapCount { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bitmap_count" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(bitmap_count_inner, vec![])(&args.args) + } +} + +fn binary_count_ones(opt: Option<&[u8]>) -> Option { + opt.map(|value| value.iter().map(|b| b.count_ones() as i64).sum()) +} + +macro_rules! downcast_and_count_ones { + ($input_array:expr, $array_type:ident) => {{ + let arr = downcast_arg!($input_array, $array_type); + Ok(arr.iter().map(binary_count_ones).collect::()) + }}; +} + +pub fn bitmap_count_inner(arg: &[ArrayRef]) -> Result { + let [input_array] = take_function_args("bitmap_count", arg)?; + + let res: Result = match &input_array.data_type() { + Binary => downcast_and_count_ones!(input_array, BinaryArray), + BinaryView => downcast_and_count_ones!(input_array, BinaryViewArray), + LargeBinary => downcast_and_count_ones!(input_array, LargeBinaryArray), + FixedSizeBinary(_size) => { + downcast_and_count_ones!(input_array, FixedSizeBinaryArray) + } + data_type => { + internal_err!("bitmap_count does not support {data_type}") + } + }; + + Ok(Arc::new(res?)) +} + +#[cfg(test)] +mod tests { + use crate::function::bitmap::bitmap_count::BitmapCount; + use crate::function::utils::test::test_scalar_function; + use arrow::array::{Array, Int64Array}; + use arrow::datatypes::DataType::Int64; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + macro_rules! test_bitmap_count_binary_invoke { + ($INPUT:expr, $EXPECTED:expr) => { + test_scalar_function!( + BitmapCount::new(), + vec![ColumnarValue::Scalar(ScalarValue::Binary($INPUT))], + $EXPECTED, + i64, + Int64, + Int64Array + ); + + test_scalar_function!( + BitmapCount::new(), + vec![ColumnarValue::Scalar(ScalarValue::LargeBinary($INPUT))], + $EXPECTED, + i64, + Int64, + Int64Array + ); + + test_scalar_function!( + BitmapCount::new(), + vec![ColumnarValue::Scalar(ScalarValue::BinaryView($INPUT))], + $EXPECTED, + i64, + Int64, + Int64Array + ); + + test_scalar_function!( + BitmapCount::new(), + vec![ColumnarValue::Scalar(ScalarValue::FixedSizeBinary( + $INPUT.map(|a| a.len()).unwrap_or(0) as i32, + $INPUT + ))], + $EXPECTED, + i64, + Int64, + Int64Array + ); + }; + } + + #[test] + fn test_bitmap_count_invoke() -> Result<()> { + test_bitmap_count_binary_invoke!(None::>, Ok(None)); + test_bitmap_count_binary_invoke!(Some(vec![0x0Au8]), Ok(Some(2))); + test_bitmap_count_binary_invoke!(Some(vec![0xFFu8, 0xFFu8]), Ok(Some(16))); + test_bitmap_count_binary_invoke!( + Some(vec![0x0Au8, 0xB0u8, 0xCDu8]), + Ok(Some(10)) + ); + Ok(()) + } +} diff --git a/datafusion/spark/src/function/bitmap/mod.rs b/datafusion/spark/src/function/bitmap/mod.rs new file mode 100644 index 0000000000000..8532c32ac9c5f --- /dev/null +++ b/datafusion/spark/src/function/bitmap/mod.rs @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod bitmap_count; + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +make_udf_function!(bitmap_count::BitmapCount, bitmap_count); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!(( + bitmap_count, + "Returns the number of set bits in the input bitmap.", + arg + )); +} + +pub fn functions() -> Vec> { + vec![bitmap_count()] +} diff --git a/datafusion/spark/src/function/bitwise/bit_count.rs b/datafusion/spark/src/function/bitwise/bit_count.rs new file mode 100644 index 0000000000000..ba44d3bc0a958 --- /dev/null +++ b/datafusion/spark/src/function/bitwise/bit_count.rs @@ -0,0 +1,319 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, AsArray, Int32Array}; +use arrow::datatypes::{ + DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, + UInt64Type, UInt8Type, +}; +use datafusion_common::{plan_err, Result}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkBitCount { + signature: Signature, +} + +impl Default for SparkBitCount { + fn default() -> Self { + Self::new() + } +} + +impl SparkBitCount { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Int8]), + TypeSignature::Exact(vec![DataType::Int16]), + TypeSignature::Exact(vec![DataType::Int32]), + TypeSignature::Exact(vec![DataType::Int64]), + TypeSignature::Exact(vec![DataType::UInt8]), + TypeSignature::Exact(vec![DataType::UInt16]), + TypeSignature::Exact(vec![DataType::UInt32]), + TypeSignature::Exact(vec![DataType::UInt64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkBitCount { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bit_count" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) // Spark returns int (Int32) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if args.args.len() != 1 { + return plan_err!("bit_count expects exactly 1 argument"); + } + + make_scalar_function(spark_bit_count, vec![])(&args.args) + } +} + +fn spark_bit_count(value_array: &[ArrayRef]) -> Result { + let value_array = value_array[0].as_ref(); + match value_array.data_type() { + DataType::Int8 => { + let result: Int32Array = value_array + .as_primitive::() + .unary(|v| v.count_ones() as i32); + Ok(Arc::new(result)) + } + DataType::Int16 => { + let result: Int32Array = value_array + .as_primitive::() + .unary(|v| v.count_ones() as i32); + Ok(Arc::new(result)) + } + DataType::Int32 => { + let result: Int32Array = value_array + .as_primitive::() + .unary(|v| v.count_ones() as i32); + Ok(Arc::new(result)) + } + DataType::Int64 => { + let result: Int32Array = value_array + .as_primitive::() + .unary(|v| v.count_ones() as i32); + Ok(Arc::new(result)) + } + DataType::UInt8 => { + let result: Int32Array = value_array + .as_primitive::() + .unary(|v| v.count_ones() as i32); + Ok(Arc::new(result)) + } + DataType::UInt16 => { + let result: Int32Array = value_array + .as_primitive::() + .unary(|v| v.count_ones() as i32); + Ok(Arc::new(result)) + } + DataType::UInt32 => { + let result: Int32Array = value_array + .as_primitive::() + .unary(|v| v.count_ones() as i32); + Ok(Arc::new(result)) + } + DataType::UInt64 => { + let result: Int32Array = value_array + .as_primitive::() + .unary(|v| v.count_ones() as i32); + Ok(Arc::new(result)) + } + _ => { + plan_err!( + "bit_count function does not support data type: {}", + value_array.data_type() + ) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, + }; + use arrow::datatypes::Int32Type; + + #[test] + fn test_bit_count_basic() { + // Test bit_count(0) - no bits set + let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![0]))]).unwrap(); + + assert_eq!(result.as_primitive::().value(0), 0); + + // Test bit_count(1) - 1 bit set + let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![1]))]).unwrap(); + + assert_eq!(result.as_primitive::().value(0), 1); + + // Test bit_count(7) - 7 = 111 in binary, 3 bits set + let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![7]))]).unwrap(); + + assert_eq!(result.as_primitive::().value(0), 3); + + // Test bit_count(15) - 15 = 1111 in binary, 4 bits set + let result = spark_bit_count(&[Arc::new(Int32Array::from(vec![15]))]).unwrap(); + + assert_eq!(result.as_primitive::().value(0), 4); + } + + #[test] + fn test_bit_count_int8() { + // Test bit_count on Int8Array + let result = + spark_bit_count(&[Arc::new(Int8Array::from(vec![0i8, 1, 3, 7, 15, -1]))]) + .unwrap(); + + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 0); + assert_eq!(arr.value(1), 1); + assert_eq!(arr.value(2), 2); + assert_eq!(arr.value(3), 3); + assert_eq!(arr.value(4), 4); + assert_eq!(arr.value(5), 8); + } + + #[test] + fn test_bit_count_int16() { + // Test bit_count on Int16Array + let result = + spark_bit_count(&[Arc::new(Int16Array::from(vec![0i16, 1, 255, 1023, -1]))]) + .unwrap(); + + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 0); + assert_eq!(arr.value(1), 1); + assert_eq!(arr.value(2), 8); + assert_eq!(arr.value(3), 10); + assert_eq!(arr.value(4), 16); + } + + #[test] + fn test_bit_count_int32() { + // Test bit_count on Int32Array + let result = + spark_bit_count(&[Arc::new(Int32Array::from(vec![0i32, 1, 255, 1023, -1]))]) + .unwrap(); + + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 0); // 0b00000000000000000000000000000000 = 0 + assert_eq!(arr.value(1), 1); // 0b00000000000000000000000000000001 = 1 + assert_eq!(arr.value(2), 8); // 0b00000000000000000000000011111111 = 8 + assert_eq!(arr.value(3), 10); // 0b00000000000000000000001111111111 = 10 + assert_eq!(arr.value(4), 32); // -1 in two's complement = all 32 bits set + } + + #[test] + fn test_bit_count_int64() { + // Test bit_count on Int64Array + let result = + spark_bit_count(&[Arc::new(Int64Array::from(vec![0i64, 1, 255, 1023, -1]))]) + .unwrap(); + + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 0); // 0b0000000000000000000000000000000000000000000000000000000000000000 = 0 + assert_eq!(arr.value(1), 1); // 0b0000000000000000000000000000000000000000000000000000000000000001 = 1 + assert_eq!(arr.value(2), 8); // 0b0000000000000000000000000000000000000000000000000000000011111111 = 8 + assert_eq!(arr.value(3), 10); // 0b0000000000000000000000000000000000000000000000000000001111111111 = 10 + assert_eq!(arr.value(4), 64); // -1 in two's complement = all 64 bits set + } + + #[test] + fn test_bit_count_uint8() { + // Test bit_count on UInt8Array + let result = + spark_bit_count(&[Arc::new(UInt8Array::from(vec![0u8, 1, 255]))]).unwrap(); + + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 0); // 0b00000000 = 0 + assert_eq!(arr.value(1), 1); // 0b00000001 = 1 + assert_eq!(arr.value(2), 8); // 0b11111111 = 8 + } + + #[test] + fn test_bit_count_uint16() { + // Test bit_count on UInt16Array + let result = + spark_bit_count(&[Arc::new(UInt16Array::from(vec![0u16, 1, 255, 65535]))]) + .unwrap(); + + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 0); // 0b0000000000000000 = 0 + assert_eq!(arr.value(1), 1); // 0b0000000000000001 = 1 + assert_eq!(arr.value(2), 8); // 0b0000000011111111 = 8 + assert_eq!(arr.value(3), 16); // 0b1111111111111111 = 16 + } + + #[test] + fn test_bit_count_uint32() { + // Test bit_count on UInt32Array + let result = spark_bit_count(&[Arc::new(UInt32Array::from(vec![ + 0u32, 1, 255, 4294967295, + ]))]) + .unwrap(); + + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 0); // 0b00000000000000000000000000000000 = 0 + assert_eq!(arr.value(1), 1); // 0b00000000000000000000000000000001 = 1 + assert_eq!(arr.value(2), 8); // 0b00000000000000000000000011111111 = 8 + assert_eq!(arr.value(3), 32); // 0b11111111111111111111111111111111 = 32 + } + + #[test] + fn test_bit_count_uint64() { + // Test bit_count on UInt64Array + let result = spark_bit_count(&[Arc::new(UInt64Array::from(vec![ + 0u64, + 1, + 255, + 256, + u64::MAX, + ]))]) + .unwrap(); + + let arr = result.as_primitive::(); + // 0b0 = 0 + assert_eq!(arr.value(0), 0); + // 0b1 = 1 + assert_eq!(arr.value(1), 1); + // 0b11111111 = 8 + assert_eq!(arr.value(2), 8); + // 0b100000000 = 1 + assert_eq!(arr.value(3), 1); + // u64::MAX = all 64 bits set + assert_eq!(arr.value(4), 64); + } + + #[test] + fn test_bit_count_nulls() { + // Test bit_count with nulls + let arr = Int32Array::from(vec![Some(3), None, Some(7)]); + let result = spark_bit_count(&[Arc::new(arr)]).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 0b11 + assert!(arr.is_null(1)); + assert_eq!(arr.value(2), 3); // 0b111 + } +} diff --git a/datafusion/spark/src/function/bitwise/bit_get.rs b/datafusion/spark/src/function/bitwise/bit_get.rs new file mode 100644 index 0000000000000..a8562618cb8cb --- /dev/null +++ b/datafusion/spark/src/function/bitwise/bit_get.rs @@ -0,0 +1,293 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::mem::size_of; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray}; +use arrow::compute::try_binary; +use arrow::datatypes::DataType::{ + Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, +}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, +}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +use crate::function::error_utils::{ + invalid_arg_count_exec_err, unsupported_data_type_exec_err, +}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkBitGet { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkBitGet { + fn default() -> Self { + Self::new() + } +} + +impl SparkBitGet { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["getbit".to_string()], + } + } +} + +impl ScalarUDFImpl for SparkBitGet { + fn as_any(&self) -> &dyn Any { + self + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 { + return Err(invalid_arg_count_exec_err( + "bit_get", + (2, 2), + arg_types.len(), + )); + } + if !arg_types[0].is_integer() && !arg_types[0].is_null() { + return Err(unsupported_data_type_exec_err( + "bit_get", + "Integer Type", + &arg_types[0], + )); + } + if !arg_types[1].is_integer() && !arg_types[1].is_null() { + return Err(unsupported_data_type_exec_err( + "bit_get", + "Integer Type", + &arg_types[1], + )); + } + if arg_types[0].is_null() { + return Ok(vec![Int8, Int32]); + } + Ok(vec![arg_types[0].clone(), Int32]) + } + + fn name(&self) -> &str { + "bit_get" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_bit_get, vec![])(&args.args) + } +} + +fn spark_bit_get_inner( + value: &PrimitiveArray, + pos: &PrimitiveArray, +) -> Result> { + let bit_length = (size_of::() * 8) as i32; + + let result: PrimitiveArray = try_binary(value, pos, |value, pos| { + if pos < 0 || pos >= bit_length { + return Err(arrow::error::ArrowError::ComputeError(format!( + "bit_get: position {pos} is out of bounds. Expected pos < {bit_length} and pos >= 0" + ))); + } + Ok(((value.to_i64().unwrap() >> pos) & 1) as i8) + })?; + Ok(result) +} + +pub fn spark_bit_get(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("`bit_get` expects exactly two arguments"); + } + + if args[1].data_type() != &Int32 { + return exec_err!("`bit_get` expects Int32 as the second argument"); + } + + let pos_arg = args[1].as_primitive::(); + + let ret = match &args[0].data_type() { + Int64 => { + let value_arg = args[0].as_primitive::(); + spark_bit_get_inner(value_arg, pos_arg) + } + Int32 => { + let value_arg = args[0].as_primitive::(); + spark_bit_get_inner(value_arg, pos_arg) + } + Int16 => { + let value_arg = args[0].as_primitive::(); + spark_bit_get_inner(value_arg, pos_arg) + } + Int8 => { + let value_arg = args[0].as_primitive::(); + spark_bit_get_inner(value_arg, pos_arg) + } + UInt64 => { + let value_arg = args[0].as_primitive::(); + spark_bit_get_inner(value_arg, pos_arg) + } + UInt32 => { + let value_arg = args[0].as_primitive::(); + spark_bit_get_inner(value_arg, pos_arg) + } + UInt16 => { + let value_arg = args[0].as_primitive::(); + spark_bit_get_inner(value_arg, pos_arg) + } + UInt8 => { + let value_arg = args[0].as_primitive::(); + spark_bit_get_inner(value_arg, pos_arg) + } + _ => { + exec_err!( + "`bit_get` expects Int64, Int32, Int16, or Int8 as the first argument" + ) + } + }?; + Ok(Arc::new(ret)) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Int32Array, Int64Array}; + + use super::*; + + #[test] + fn test_bit_get_basic() { + // Test bit_get(11, 0) - 11 = 1011 in binary, bit 0 = 1 + let result = spark_bit_get(&[ + Arc::new(Int64Array::from(vec![11])), + Arc::new(Int32Array::from(vec![0])), + ]) + .unwrap(); + + assert_eq!(result.as_primitive::().value(0), 1); + + // Test bit_get(11, 2) - 11 = 1011 in binary, bit 2 = 0 + let result = spark_bit_get(&[ + Arc::new(Int64Array::from(vec![11])), + Arc::new(Int32Array::from(vec![2])), + ]) + .unwrap(); + + assert_eq!(result.as_primitive::().value(0), 0); + + // Test bit_get(11, 3) - 11 = 1011 in binary, bit 3 = 1 + let result = spark_bit_get(&[ + Arc::new(Int64Array::from(vec![11])), + Arc::new(Int32Array::from(vec![3])), + ]) + .unwrap(); + + assert_eq!(result.as_primitive::().value(0), 1); + } + + #[test] + fn test_bit_get_edge_cases() { + // Test with 0 + let result = spark_bit_get(&[ + Arc::new(Int64Array::from(vec![0])), + Arc::new(Int32Array::from(vec![0])), + ]) + .unwrap(); + + assert_eq!(result.as_primitive::().value(0), 0); + + let result = spark_bit_get(&[ + Arc::new(Int64Array::from(vec![11])), + Arc::new(Int32Array::from(vec![-1])), + ]); + assert_eq!( + result.unwrap_err().message().lines().next().unwrap(), + "Compute error: bit_get: position -1 is out of bounds. Expected pos < 64 and pos >= 0" + ); + + let result = spark_bit_get(&[ + Arc::new(Int64Array::from(vec![11])), + Arc::new(Int32Array::from(vec![64])), + ]); + + assert_eq!( + result.unwrap_err().message().lines().next().unwrap(), + "Compute error: bit_get: position 64 is out of bounds. Expected pos < 64 and pos >= 0" + ); + } + + #[test] + fn test_bit_get_null_inputs() { + // Test with NULL value + let result = spark_bit_get(&[ + Arc::new(Int64Array::from(vec![None])), + Arc::new(Int32Array::from(vec![0])), + ]) + .unwrap(); + + assert_eq!(result.as_primitive::().value(0), 0); + + // Test with NULL position + let result = spark_bit_get(&[ + Arc::new(Int64Array::from(vec![11])), + Arc::new(Int32Array::from(vec![None])), + ]) + .unwrap(); + + assert_eq!(result.as_primitive::().value(0), 0); + } + + #[test] + fn test_bit_get_large_numbers() { + // Test with larger number + let result = spark_bit_get(&[ + Arc::new(Int64Array::from(vec![255])), // 11111111 in binary + Arc::new(Int32Array::from(vec![7])), // bit 7 = 1 + ]) + .unwrap(); + + assert_eq!(result.as_primitive::().value(0), 1); + + let result = spark_bit_get(&[ + Arc::new(Int64Array::from(vec![255])), // 11111111 in binary + Arc::new(Int32Array::from(vec![8])), // bit 8 = 0 + ]) + .unwrap(); + + assert_eq!(result.as_primitive::().value(0), 0); + } +} diff --git a/datafusion/spark/src/function/bitwise/bit_shift.rs b/datafusion/spark/src/function/bitwise/bit_shift.rs new file mode 100644 index 0000000000000..bb645b7660584 --- /dev/null +++ b/datafusion/spark/src/function/bitwise/bit_shift.rs @@ -0,0 +1,740 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray}; +use arrow::compute; +use arrow::datatypes::{ + ArrowNativeType, DataType, Int32Type, Int64Type, UInt32Type, UInt64Type, +}; +use datafusion_common::{plan_err, Result}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +use crate::function::error_utils::{ + invalid_arg_count_exec_err, unsupported_data_type_exec_err, +}; + +/// Performs a bitwise left shift on each element of the `value` array by the corresponding amount in the `shift` array. +/// The shift amount is normalized to the bit width of the type, matching Spark/Java semantics for negative and large shifts. +/// +/// # Arguments +/// * `value` - The array of values to shift. +/// * `shift` - The array of shift amounts (must be Int32). +/// +/// # Returns +/// A new array with the shifted values. +/// +fn shift_left( + value: &PrimitiveArray, + shift: &PrimitiveArray, +) -> Result> +where + T::Native: ArrowNativeType + std::ops::Shl, +{ + let bit_num = (T::Native::get_byte_width() * 8) as i32; + let result = compute::binary::<_, Int32Type, _, _>( + value, + shift, + |value: T::Native, shift: i32| { + let shift = ((shift % bit_num) + bit_num) % bit_num; + value << shift + }, + )?; + Ok(result) +} + +/// Performs a bitwise right shift on each element of the `value` array by the corresponding amount in the `shift` array. +/// The shift amount is normalized to the bit width of the type, matching Spark/Java semantics for negative and large shifts. +/// +/// # Arguments +/// * `value` - The array of values to shift. +/// * `shift` - The array of shift amounts (must be Int32). +/// +/// # Returns +/// A new array with the shifted values. +/// +fn shift_right( + value: &PrimitiveArray, + shift: &PrimitiveArray, +) -> Result> +where + T::Native: ArrowNativeType + std::ops::Shr, +{ + let bit_num = (T::Native::get_byte_width() * 8) as i32; + let result = compute::binary::<_, Int32Type, _, _>( + value, + shift, + |value: T::Native, shift: i32| { + let shift = ((shift % bit_num) + bit_num) % bit_num; + value >> shift + }, + )?; + Ok(result) +} + +/// Trait for performing an unsigned right shift (logical shift right). +/// This is used to mimic Java's `>>>` operator, which does not exist in Rust. +/// For unsigned types, this is just the normal right shift. +/// For signed types, this casts to the unsigned type, shifts, then casts back. +trait UShr { + fn ushr(self, rhs: Rhs) -> Self; +} + +impl UShr for u32 { + fn ushr(self, rhs: i32) -> Self { + self >> rhs + } +} + +impl UShr for u64 { + fn ushr(self, rhs: i32) -> Self { + self >> rhs + } +} + +impl UShr for i32 { + fn ushr(self, rhs: i32) -> Self { + ((self as u32) >> rhs) as i32 + } +} + +impl UShr for i64 { + fn ushr(self, rhs: i32) -> Self { + ((self as u64) >> rhs) as i64 + } +} + +/// Performs a bitwise unsigned right shift on each element of the `value` array by the corresponding amount in the `shift` array. +/// The shift amount is normalized to the bit width of the type, matching Spark/Java semantics for negative and large shifts. +/// +/// # Arguments +/// * `value` - The array of values to shift. +/// * `shift` - The array of shift amounts (must be Int32). +/// +/// # Returns +/// A new array with the shifted values. +/// +fn shift_right_unsigned( + value: &PrimitiveArray, + shift: &PrimitiveArray, +) -> Result> +where + T::Native: ArrowNativeType + UShr, +{ + let bit_num = (T::Native::get_byte_width() * 8) as i32; + let result = compute::binary::<_, Int32Type, _, _>( + value, + shift, + |value: T::Native, shift: i32| { + let shift = ((shift % bit_num) + bit_num) % bit_num; + value.ushr(shift) + }, + )?; + Ok(result) +} + +trait BitShiftUDF: ScalarUDFImpl { + fn shift( + &self, + value: &PrimitiveArray, + shift: &PrimitiveArray, + ) -> Result> + where + T::Native: ArrowNativeType + + std::ops::Shl + + std::ops::Shr + + UShr; + + fn spark_shift(&self, arrays: &[ArrayRef]) -> Result { + let value_array = arrays[0].as_ref(); + let shift_array = arrays[1].as_ref(); + + // Ensure shift array is Int32 + let shift_array = if shift_array.data_type() != &DataType::Int32 { + return plan_err!("{} shift amount must be Int32", self.name()); + } else { + shift_array.as_primitive::() + }; + + match value_array.data_type() { + DataType::Int32 => { + let value_array = value_array.as_primitive::(); + Ok(Arc::new(self.shift(value_array, shift_array)?)) + } + DataType::Int64 => { + let value_array = value_array.as_primitive::(); + Ok(Arc::new(self.shift(value_array, shift_array)?)) + } + DataType::UInt32 => { + let value_array = value_array.as_primitive::(); + Ok(Arc::new(self.shift(value_array, shift_array)?)) + } + DataType::UInt64 => { + let value_array = value_array.as_primitive::(); + Ok(Arc::new(self.shift(value_array, shift_array)?)) + } + _ => { + plan_err!( + "{} function does not support data type: {}", + self.name(), + value_array.data_type() + ) + } + } + } +} + +fn bit_shift_coerce_types(arg_types: &[DataType], func: &str) -> Result> { + if arg_types.len() != 2 { + return Err(invalid_arg_count_exec_err(func, (2, 2), arg_types.len())); + } + if !arg_types[0].is_integer() && !arg_types[0].is_null() { + return Err(unsupported_data_type_exec_err( + func, + "Integer Type", + &arg_types[0], + )); + } + if !arg_types[1].is_integer() && !arg_types[1].is_null() { + return Err(unsupported_data_type_exec_err( + func, + "Integer Type", + &arg_types[1], + )); + } + + // Coerce smaller integer types to Int32 + let coerced_first = match &arg_types[0] { + DataType::Int8 | DataType::Int16 | DataType::Null => DataType::Int32, + DataType::UInt8 | DataType::UInt16 => DataType::UInt32, + _ => arg_types[0].clone(), + }; + + Ok(vec![coerced_first, DataType::Int32]) +} + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct SparkShiftLeft { + signature: Signature, +} + +impl Default for SparkShiftLeft { + fn default() -> Self { + Self::new() + } +} + +impl SparkShiftLeft { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl BitShiftUDF for SparkShiftLeft { + fn shift( + &self, + value: &PrimitiveArray, + shift: &PrimitiveArray, + ) -> Result> + where + T::Native: ArrowNativeType + + std::ops::Shl + + std::ops::Shr + + UShr, + { + shift_left(value, shift) + } +} + +impl ScalarUDFImpl for SparkShiftLeft { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "shiftleft" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + bit_shift_coerce_types(arg_types, "shiftleft") + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 2 { + return plan_err!("shiftleft expects exactly 2 arguments"); + } + // Return type is the same as the first argument (the value to shift) + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if args.args.len() != 2 { + return plan_err!("shiftleft expects exactly 2 arguments"); + } + let inner = |arr: &[ArrayRef]| -> Result { self.spark_shift(arr) }; + make_scalar_function(inner, vec![])(&args.args) + } +} + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct SparkShiftRightUnsigned { + signature: Signature, +} + +impl Default for SparkShiftRightUnsigned { + fn default() -> Self { + Self::new() + } +} + +impl SparkShiftRightUnsigned { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl BitShiftUDF for SparkShiftRightUnsigned { + fn shift( + &self, + value: &PrimitiveArray, + shift: &PrimitiveArray, + ) -> Result> + where + T::Native: ArrowNativeType + + std::ops::Shl + + std::ops::Shr + + UShr, + { + shift_right_unsigned(value, shift) + } +} + +impl ScalarUDFImpl for SparkShiftRightUnsigned { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "shiftrightunsigned" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + bit_shift_coerce_types(arg_types, "shiftrightunsigned") + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 2 { + return plan_err!("shiftrightunsigned expects exactly 2 arguments"); + } + // Return type is the same as the first argument (the value to shift) + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if args.args.len() != 2 { + return plan_err!("shiftrightunsigned expects exactly 2 arguments"); + } + let inner = |arr: &[ArrayRef]| -> Result { self.spark_shift(arr) }; + make_scalar_function(inner, vec![])(&args.args) + } +} + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct SparkShiftRight { + signature: Signature, +} + +impl Default for SparkShiftRight { + fn default() -> Self { + Self::new() + } +} + +impl SparkShiftRight { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl BitShiftUDF for SparkShiftRight { + fn shift( + &self, + value: &PrimitiveArray, + shift: &PrimitiveArray, + ) -> Result> + where + T::Native: ArrowNativeType + + std::ops::Shl + + std::ops::Shr + + UShr, + { + shift_right(value, shift) + } +} + +impl ScalarUDFImpl for SparkShiftRight { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "shiftright" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + bit_shift_coerce_types(arg_types, "shiftright") + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 2 { + return plan_err!("shiftright expects exactly 2 arguments"); + } + // Return type is the same as the first argument (the value to shift) + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if args.args.len() != 2 { + return plan_err!("shiftright expects exactly 2 arguments"); + } + let inner = |arr: &[ArrayRef]| -> Result { self.spark_shift(arr) }; + make_scalar_function(inner, vec![])(&args.args) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Array, Int32Array, Int64Array, UInt32Array, UInt64Array}; + + #[test] + fn test_shift_right_unsigned_int32() { + let value_array = Arc::new(Int32Array::from(vec![4, 8, 16, 32])); + let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let result = SparkShiftRightUnsigned::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 4 >>> 1 = 2 + assert_eq!(arr.value(1), 2); // 8 >>> 2 = 2 + assert_eq!(arr.value(2), 2); // 16 >>> 3 = 2 + assert_eq!(arr.value(3), 2); // 32 >>> 4 = 2 + } + + #[test] + fn test_shift_right_unsigned_int64() { + let value_array = Arc::new(Int64Array::from(vec![4i64, 8, 16])); + let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let result = SparkShiftRightUnsigned::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 4 >>> 1 = 2 + assert_eq!(arr.value(1), 2); // 8 >>> 2 = 2 + assert_eq!(arr.value(2), 2); // 16 >>> 3 = 2 + } + + #[test] + fn test_shift_right_unsigned_uint32() { + let value_array = Arc::new(UInt32Array::from(vec![4u32, 8, 16])); + let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let result = SparkShiftRightUnsigned::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 4 >>> 1 = 2 + assert_eq!(arr.value(1), 2); // 8 >>> 2 = 2 + assert_eq!(arr.value(2), 2); // 16 >>> 3 = 2 + } + + #[test] + fn test_shift_right_unsigned_uint64() { + let value_array = Arc::new(UInt64Array::from(vec![4u64, 8, 16])); + let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let result = SparkShiftRightUnsigned::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 4 >>> 1 = 2 + assert_eq!(arr.value(1), 2); // 8 >>> 2 = 2 + assert_eq!(arr.value(2), 2); // 16 >>> 3 = 2 + } + + #[test] + fn test_shift_right_unsigned_nulls() { + let value_array = Arc::new(Int32Array::from(vec![Some(4), None, Some(8)])); + let shift_array = Arc::new(Int32Array::from(vec![Some(1), Some(2), None])); + let result = SparkShiftRightUnsigned::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 4 >>> 1 = 2 + assert!(arr.is_null(1)); // null >>> 2 = null + assert!(arr.is_null(2)); // 8 >>> null = null + } + + #[test] + fn test_shift_right_unsigned_negative_shift() { + let value_array = Arc::new(Int32Array::from(vec![4, 8, 16])); + let shift_array = Arc::new(Int32Array::from(vec![-1, -2, -3])); + let result = SparkShiftRightUnsigned::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 0); // 4 >>> -1 = 0 + assert_eq!(arr.value(1), 0); // 8 >>> -2 = 0 + assert_eq!(arr.value(2), 0); // 16 >>> -3 = 0 + } + + #[test] + fn test_shift_right_unsigned_negative_values() { + let value_array = Arc::new(Int32Array::from(vec![-4, -8, -16])); + let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let result = SparkShiftRightUnsigned::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + // For unsigned right shift, negative values are treated as large positive values + // -4 as u32 = 4294967292, -4 >>> 1 = 2147483646 + assert_eq!(arr.value(0), 2147483646); + // -8 as u32 = 4294967288, -8 >>> 2 = 1073741822 + assert_eq!(arr.value(1), 1073741822); + // -16 as u32 = 4294967280, -16 >>> 3 = 536870910 + assert_eq!(arr.value(2), 536870910); + } + + #[test] + fn test_shift_right_int32() { + let value_array = Arc::new(Int32Array::from(vec![4, 8, 16, 32])); + let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let result = SparkShiftRight::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 4 >> 1 = 2 + assert_eq!(arr.value(1), 2); // 8 >> 2 = 2 + assert_eq!(arr.value(2), 2); // 16 >> 3 = 2 + assert_eq!(arr.value(3), 2); // 32 >> 4 = 2 + } + + #[test] + fn test_shift_right_int64() { + let value_array = Arc::new(Int64Array::from(vec![4i64, 8, 16])); + let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let result = SparkShiftRight::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 4 >> 1 = 2 + assert_eq!(arr.value(1), 2); // 8 >> 2 = 2 + assert_eq!(arr.value(2), 2); // 16 >> 3 = 2 + } + + #[test] + fn test_shift_right_uint32() { + let value_array = Arc::new(UInt32Array::from(vec![4u32, 8, 16])); + let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let result = SparkShiftRight::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 4 >> 1 = 2 + assert_eq!(arr.value(1), 2); // 8 >> 2 = 2 + assert_eq!(arr.value(2), 2); // 16 >> 3 = 2 + } + + #[test] + fn test_shift_right_uint64() { + let value_array = Arc::new(UInt64Array::from(vec![4u64, 8, 16])); + let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let result = SparkShiftRight::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 4 >> 1 = 2 + assert_eq!(arr.value(1), 2); // 8 >> 2 = 2 + assert_eq!(arr.value(2), 2); // 16 >> 3 = 2 + } + + #[test] + fn test_shift_right_nulls() { + let value_array = Arc::new(Int32Array::from(vec![Some(4), None, Some(8)])); + let shift_array = Arc::new(Int32Array::from(vec![Some(1), Some(2), None])); + let result = SparkShiftRight::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 4 >> 1 = 2 + assert!(arr.is_null(1)); // null >> 2 = null + assert!(arr.is_null(2)); // 8 >> null = null + } + + #[test] + fn test_shift_right_large_shift() { + let value_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let shift_array = Arc::new(Int32Array::from(vec![32, 33, 64])); + let result = SparkShiftRight::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 1); // 1 >> 32 = 1 + assert_eq!(arr.value(1), 1); // 2 >> 33 = 1 + assert_eq!(arr.value(2), 3); // 3 >> 64 = 3 + } + + #[test] + fn test_shift_right_negative_shift() { + let value_array = Arc::new(Int32Array::from(vec![4, 8, 16])); + let shift_array = Arc::new(Int32Array::from(vec![-1, -2, -3])); + let result = SparkShiftRight::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 0); // 4 >> -1 = 0 + assert_eq!(arr.value(1), 0); // 8 >> -2 = 0 + assert_eq!(arr.value(2), 0); // 16 >> -3 = 0 + } + + #[test] + fn test_shift_right_negative_values() { + let value_array = Arc::new(Int32Array::from(vec![-4, -8, -16])); + let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let result = SparkShiftRight::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + // For signed integers, right shift preserves the sign bit + assert_eq!(arr.value(0), -2); // -4 >> 1 = -2 + assert_eq!(arr.value(1), -2); // -8 >> 2 = -2 + assert_eq!(arr.value(2), -2); // -16 >> 3 = -2 + } + + #[test] + fn test_shift_left_int32() { + let value_array = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + let result = SparkShiftLeft::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 1 << 1 = 2 + assert_eq!(arr.value(1), 8); // 2 << 2 = 8 + assert_eq!(arr.value(2), 24); // 3 << 3 = 24 + assert_eq!(arr.value(3), 64); // 4 << 4 = 64 + } + + #[test] + fn test_shift_left_int64() { + let value_array = Arc::new(Int64Array::from(vec![1i64, 2, 3])); + let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let result = SparkShiftLeft::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 1 << 1 = 2 + assert_eq!(arr.value(1), 8); // 2 << 2 = 8 + assert_eq!(arr.value(2), 24); // 3 << 3 = 24 + } + + #[test] + fn test_shift_left_uint32() { + let value_array = Arc::new(UInt32Array::from(vec![1u32, 2, 3])); + let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let result = SparkShiftLeft::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 1 << 1 = 2 + assert_eq!(arr.value(1), 8); // 2 << 2 = 8 + assert_eq!(arr.value(2), 24); // 3 << 3 = 24 + } + + #[test] + fn test_shift_left_uint64() { + let value_array = Arc::new(UInt64Array::from(vec![1u64, 2, 3])); + let shift_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let result = SparkShiftLeft::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2); // 1 << 1 = 2 + assert_eq!(arr.value(1), 8); // 2 << 2 = 8 + assert_eq!(arr.value(2), 24); // 3 << 3 = 24 + } + + #[test] + fn test_shift_left_nulls() { + let value_array = Arc::new(Int32Array::from(vec![Some(2), None, Some(3)])); + let shift_array = Arc::new(Int32Array::from(vec![Some(1), Some(2), None])); + let result = SparkShiftLeft::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 4); // 2 << 1 = 4 + assert!(arr.is_null(1)); // null << 2 = null + assert!(arr.is_null(2)); // 3 << null = null + } + + #[test] + fn test_shift_left_large_shift() { + let value_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let shift_array = Arc::new(Int32Array::from(vec![32, 33, 64])); + let result = SparkShiftLeft::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 1); // 1 << 32 = 0 (overflow) + assert_eq!(arr.value(1), 4); // 2 << 33 = 0 (overflow) + assert_eq!(arr.value(2), 3); // 3 << 64 = 0 (overflow) + } + + #[test] + fn test_shift_left_negative_shift() { + let value_array = Arc::new(Int32Array::from(vec![4, 8, 16])); + let shift_array = Arc::new(Int32Array::from(vec![-1, -2, -3])); + let result = SparkShiftLeft::new() + .spark_shift(&[value_array, shift_array]) + .unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 0); // 4 << -1 = 0 + assert_eq!(arr.value(1), 0); // 8 << -2 = 0 + assert_eq!(arr.value(2), 0); // 16 << -3 = 0 + } +} diff --git a/datafusion/spark/src/function/bitwise/bitwise_not.rs b/datafusion/spark/src/function/bitwise/bitwise_not.rs new file mode 100644 index 0000000000000..2f3fe227833b0 --- /dev/null +++ b/datafusion/spark/src/function/bitwise/bitwise_not.rs @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::compute::kernels::bitwise; +use arrow::datatypes::{Int16Type, Int32Type, Int64Type, Int8Type}; +use arrow::{array::*, datatypes::DataType}; +use datafusion_common::{plan_err, Result}; +use datafusion_expr::{ColumnarValue, TypeSignature, Volatility}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; +use datafusion_functions::utils::make_scalar_function; +use std::{any::Any, sync::Arc}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkBitwiseNot { + signature: Signature, +} + +impl Default for SparkBitwiseNot { + fn default() -> Self { + Self::new() + } +} + +impl SparkBitwiseNot { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Int8]), + TypeSignature::Exact(vec![DataType::Int16]), + TypeSignature::Exact(vec![DataType::Int32]), + TypeSignature::Exact(vec![DataType::Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkBitwiseNot { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bitwise_not" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if args.args.len() != 1 { + return plan_err!("bitwise_not expects exactly 1 argument"); + } + make_scalar_function(spark_bitwise_not, vec![])(&args.args) + } +} + +pub fn spark_bitwise_not(args: &[ArrayRef]) -> Result { + let array = args[0].as_ref(); + match array.data_type() { + DataType::Int8 => { + let result: Int8Array = + bitwise::bitwise_not(array.as_primitive::())?; + Ok(Arc::new(result)) + } + DataType::Int16 => { + let result: Int16Array = + bitwise::bitwise_not(array.as_primitive::())?; + Ok(Arc::new(result)) + } + DataType::Int32 => { + let result: Int32Array = + bitwise::bitwise_not(array.as_primitive::())?; + Ok(Arc::new(result)) + } + DataType::Int64 => { + let result: Int64Array = + bitwise::bitwise_not(array.as_primitive::())?; + Ok(Arc::new(result)) + } + _ => { + plan_err!( + "bitwise_not function does not support data type: {}", + array.data_type() + ) + } + } +} diff --git a/datafusion/spark/src/function/bitwise/mod.rs b/datafusion/spark/src/function/bitwise/mod.rs new file mode 100644 index 0000000000000..d729a3ddd09a1 --- /dev/null +++ b/datafusion/spark/src/function/bitwise/mod.rs @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod bit_count; +pub mod bit_get; +pub mod bit_shift; +pub mod bitwise_not; + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +make_udf_function!(bit_shift::SparkShiftLeft, shiftleft); +make_udf_function!(bit_shift::SparkShiftRight, shiftright); +make_udf_function!(bit_shift::SparkShiftRightUnsigned, shiftrightunsigned); +make_udf_function!(bit_get::SparkBitGet, bit_get); +make_udf_function!(bit_count::SparkBitCount, bit_count); +make_udf_function!(bitwise_not::SparkBitwiseNot, bitwise_not); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!((bit_get, "Returns the value of the bit (0 or 1) at the specified position.", col pos)); + export_functions!(( + bit_count, + "Returns the number of bits set in the binary representation of the argument.", + col + )); + export_functions!(( + bitwise_not, + "Returns the result of a bitwise negation operation on the argument, where each bit in the binary representation is flipped, following two's complement arithmetic for signed integers.", + col + )); + export_functions!(( + shiftleft, + "Shifts the bits of the first argument left by the number of positions specified by the second argument. If the shift amount is negative or greater than or equal to the bit width, it is normalized to the bit width (i.e., pmod(shift, bit_width)).", + value shift + )); + export_functions!(( + shiftright, + "Shifts the bits of the first argument right by the number of positions specified by the second argument (arithmetic/signed shift). If the shift amount is negative or greater than or equal to the bit width, it is normalized to the bit width (i.e., pmod(shift, bit_width)).", + value shift + )); + export_functions!(( + shiftrightunsigned, + "Shifts the bits of the first argument right by the number of positions specified by the second argument (logical/unsigned shift). If the shift amount is negative or greater than or equal to the bit width, it is normalized to the bit width (i.e., pmod(shift, bit_width)).", + value shift + )); +} + +pub fn functions() -> Vec> { + vec![ + bit_get(), + bit_count(), + bitwise_not(), + shiftleft(), + shiftright(), + shiftrightunsigned(), + ] +} diff --git a/datafusion/spark/src/function/collection/mod.rs b/datafusion/spark/src/function/collection/mod.rs new file mode 100644 index 0000000000000..a87df9a2c87a0 --- /dev/null +++ b/datafusion/spark/src/function/collection/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/conditional/if.rs b/datafusion/spark/src/function/conditional/if.rs new file mode 100644 index 0000000000000..aee43dd8d0a58 --- /dev/null +++ b/datafusion/spark/src/function/conditional/if.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; +use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_expr::{ + binary::try_type_union_resolution, simplify::ExprSimplifyResult, when, ColumnarValue, + Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkIf { + signature: Signature, +} + +impl Default for SparkIf { + fn default() -> Self { + Self::new() + } +} + +impl SparkIf { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkIf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "if" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 3 { + return plan_err!( + "Function 'if' expects 3 arguments but received {}", + arg_types.len() + ); + } + + if arg_types[0] != DataType::Boolean && arg_types[0] != DataType::Null { + return plan_err!( + "For function 'if' {} is not a boolean or null", + arg_types[0] + ); + } + + let target_types = try_type_union_resolution(&arg_types[1..])?; + let mut result = vec![DataType::Boolean]; + result.extend(target_types); + Ok(result) + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[1].clone()) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("if should have been simplified to case") + } + + fn simplify( + &self, + args: Vec, + _info: &dyn datafusion_expr::simplify::SimplifyInfo, + ) -> Result { + let condition = args[0].clone(); + let then_expr = args[1].clone(); + let else_expr = args[2].clone(); + + // Convert IF(condition, then_expr, else_expr) to + // CASE WHEN condition THEN then_expr ELSE else_expr END + let case_expr = when(condition, then_expr).otherwise(else_expr)?; + + Ok(ExprSimplifyResult::Simplified(case_expr)) + } +} diff --git a/datafusion/spark/src/function/conditional/mod.rs b/datafusion/spark/src/function/conditional/mod.rs new file mode 100644 index 0000000000000..4301d7642b41d --- /dev/null +++ b/datafusion/spark/src/function/conditional/mod.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +mod r#if; + +make_udf_function!(r#if::SparkIf, r#if); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!((r#if, "If arg1 evaluates to true, then returns arg2; otherwise returns arg3", arg1 arg2 arg3)); +} + +pub fn functions() -> Vec> { + vec![r#if()] +} diff --git a/datafusion/spark/src/function/conversion/mod.rs b/datafusion/spark/src/function/conversion/mod.rs new file mode 100644 index 0000000000000..a87df9a2c87a0 --- /dev/null +++ b/datafusion/spark/src/function/conversion/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/csv/mod.rs b/datafusion/spark/src/function/csv/mod.rs new file mode 100644 index 0000000000000..a87df9a2c87a0 --- /dev/null +++ b/datafusion/spark/src/function/csv/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/datetime/date_add.rs b/datafusion/spark/src/function/datetime/date_add.rs new file mode 100644 index 0000000000000..a00430febcdb0 --- /dev/null +++ b/datafusion/spark/src/function/datetime/date_add.rs @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::compute; +use arrow::datatypes::{DataType, Date32Type}; +use arrow::error::ArrowError; +use datafusion_common::cast::{ + as_date32_array, as_int16_array, as_int32_array, as_int8_array, +}; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkDateAdd { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkDateAdd { + fn default() -> Self { + Self::new() + } +} + +impl SparkDateAdd { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Date32, DataType::Int8]), + TypeSignature::Exact(vec![DataType::Date32, DataType::Int16]), + TypeSignature::Exact(vec![DataType::Date32, DataType::Int32]), + ], + Volatility::Immutable, + ), + aliases: vec!["dateadd".to_string()], + } + } +} + +impl ScalarUDFImpl for SparkDateAdd { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "date_add" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Date32) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_date_add, vec![])(&args.args) + } +} + +fn spark_date_add(args: &[ArrayRef]) -> Result { + let [date_arg, days_arg] = args else { + return internal_err!( + "Spark `date_add` function requires 2 arguments, got {}", + args.len() + ); + }; + let date_array = as_date32_array(date_arg)?; + let result = match days_arg.data_type() { + DataType::Int8 => { + let days_array = as_int8_array(days_arg)?; + compute::try_binary::<_, _, _, Date32Type>( + date_array, + days_array, + |date, days| { + date.checked_add(days as i32).ok_or_else(|| { + ArrowError::ArithmeticOverflow("date_add".to_string()) + }) + }, + )? + } + DataType::Int16 => { + let days_array = as_int16_array(days_arg)?; + compute::try_binary::<_, _, _, Date32Type>( + date_array, + days_array, + |date, days| { + date.checked_add(days as i32).ok_or_else(|| { + ArrowError::ArithmeticOverflow("date_add".to_string()) + }) + }, + )? + } + DataType::Int32 => { + let days_array = as_int32_array(days_arg)?; + compute::try_binary::<_, _, _, Date32Type>( + date_array, + days_array, + |date, days| { + date.checked_add(days).ok_or_else(|| { + ArrowError::ArithmeticOverflow("date_add".to_string()) + }) + }, + )? + } + _ => { + return internal_err!( + "Spark `date_add` function: argument must be int8, int16, int32, got {:?}", + days_arg.data_type() + ); + } + }; + Ok(Arc::new(result)) +} diff --git a/datafusion/spark/src/function/datetime/date_sub.rs b/datafusion/spark/src/function/datetime/date_sub.rs new file mode 100644 index 0000000000000..a3b26661d196c --- /dev/null +++ b/datafusion/spark/src/function/datetime/date_sub.rs @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::compute; +use arrow::datatypes::{DataType, Date32Type}; +use arrow::error::ArrowError; +use datafusion_common::cast::{ + as_date32_array, as_int16_array, as_int32_array, as_int8_array, +}; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkDateSub { + signature: Signature, +} + +impl Default for SparkDateSub { + fn default() -> Self { + Self::new() + } +} + +impl SparkDateSub { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Date32, DataType::Int8]), + TypeSignature::Exact(vec![DataType::Date32, DataType::Int16]), + TypeSignature::Exact(vec![DataType::Date32, DataType::Int32]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkDateSub { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "date_sub" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Date32) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_date_sub, vec![])(&args.args) + } +} + +fn spark_date_sub(args: &[ArrayRef]) -> Result { + let [date_arg, days_arg] = args else { + return internal_err!( + "Spark `date_sub` function requires 2 arguments, got {}", + args.len() + ); + }; + let date_array = as_date32_array(date_arg)?; + let result = match days_arg.data_type() { + DataType::Int8 => { + let days_array = as_int8_array(days_arg)?; + compute::try_binary::<_, _, _, Date32Type>( + date_array, + days_array, + |date, days| { + date.checked_sub(days as i32).ok_or_else(|| { + ArrowError::ArithmeticOverflow("date_sub".to_string()) + }) + }, + )? + } + DataType::Int16 => { + let days_array = as_int16_array(days_arg)?; + compute::try_binary::<_, _, _, Date32Type>( + date_array, + days_array, + |date, days| { + date.checked_sub(days as i32).ok_or_else(|| { + ArrowError::ArithmeticOverflow("date_sub".to_string()) + }) + }, + )? + } + DataType::Int32 => { + let days_array = as_int32_array(days_arg)?; + compute::try_binary::<_, _, _, Date32Type>( + date_array, + days_array, + |date, days| { + date.checked_sub(days).ok_or_else(|| { + ArrowError::ArithmeticOverflow("date_sub".to_string()) + }) + }, + )? + } + _ => { + return internal_err!( + "Spark `date_sub` function: argument must be int8, int16, int32, got {:?}", + days_arg.data_type() + ); + } + }; + Ok(Arc::new(result)) +} diff --git a/datafusion/spark/src/function/datetime/last_day.rs b/datafusion/spark/src/function/datetime/last_day.rs new file mode 100644 index 0000000000000..c01a6403649c5 --- /dev/null +++ b/datafusion/spark/src/function/datetime/last_day.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, AsArray, Date32Array}; +use arrow::datatypes::{DataType, Date32Type}; +use chrono::{Datelike, Duration, NaiveDate}; +use datafusion_common::{exec_datafusion_err, internal_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkLastDay { + signature: Signature, +} + +impl Default for SparkLastDay { + fn default() -> Self { + Self::new() + } +} + +impl SparkLastDay { + pub fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Date32], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkLastDay { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "last_day" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Date32) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + let [arg] = args.as_slice() else { + return internal_err!( + "Spark `last_day` function requires 1 argument, got {}", + args.len() + ); + }; + match arg { + ColumnarValue::Scalar(ScalarValue::Date32(days)) => { + if let Some(days) = days { + Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some( + spark_last_day(*days)?, + )))) + } else { + Ok(ColumnarValue::Scalar(ScalarValue::Date32(None))) + } + } + ColumnarValue::Array(array) => { + let result = match array.data_type() { + DataType::Date32 => { + let result: Date32Array = array + .as_primitive::() + .try_unary(spark_last_day)? + .with_data_type(DataType::Date32); + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!("Unsupported data type {other:?} for Spark function `last_day`") + } + }?; + Ok(ColumnarValue::Array(result)) + } + other => { + internal_err!("Unsupported arg {other:?} for Spark function `last_day") + } + } + } +} + +fn spark_last_day(days: i32) -> Result { + let date = Date32Type::to_naive_date(days); + + let (year, month) = (date.year(), date.month()); + let (next_year, next_month) = if month == 12 { + (year + 1, 1) + } else { + (year, month + 1) + }; + + let first_day_next_month = NaiveDate::from_ymd_opt(next_year, next_month, 1) + .ok_or_else(|| { + exec_datafusion_err!( + "Spark `last_day`: Unable to parse date from {next_year}, {next_month}, 1" + ) + })?; + + Ok(Date32Type::from_naive_date( + first_day_next_month - Duration::days(1), + )) +} diff --git a/datafusion/spark/src/function/datetime/make_dt_interval.rs b/datafusion/spark/src/function/datetime/make_dt_interval.rs new file mode 100644 index 0000000000000..bbfba44861344 --- /dev/null +++ b/datafusion/spark/src/function/datetime/make_dt_interval.rs @@ -0,0 +1,485 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, AsArray, DurationMicrosecondBuilder, PrimitiveArray, +}; +use arrow::datatypes::TimeUnit::Microsecond; +use arrow::datatypes::{DataType, Float64Type, Int32Type}; +use datafusion_common::{ + exec_err, plan_datafusion_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkMakeDtInterval { + signature: Signature, +} + +impl Default for SparkMakeDtInterval { + fn default() -> Self { + Self::new() + } +} + +impl SparkMakeDtInterval { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkMakeDtInterval { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "make_dt_interval" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + /// Note the return type is `DataType::Duration(TimeUnit::Microsecond)` and not `DataType::Interval(DayTime)` as you might expect. + /// This is because `DataType::Interval(DayTime)` has precision only to the millisecond, whilst Spark's `DayTimeIntervalType` has + /// precision to the microsecond. We use `DataType::Duration(TimeUnit::Microsecond)` in order to not lose any precision. See the + /// [Sail compatibility doc] for reference. + /// + /// [Sail compatibility doc]: https://github.com/lakehq/sail/blob/dc5368daa24d40a7758a299e1ba8fc985cb29108/docs/guide/dataframe/data-types/compatibility.md?plain=1#L260 + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Duration(Microsecond)) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if args.args.is_empty() { + return Ok(ColumnarValue::Scalar(ScalarValue::DurationMicrosecond( + Some(0), + ))); + } + make_scalar_function(make_dt_interval_kernel, vec![])(&args.args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() > 4 { + return exec_err!( + "make_dt_interval expects between 0 and 4 arguments, got {}", + arg_types.len() + ); + } + + Ok((0..arg_types.len()) + .map(|i| { + if i == 3 { + DataType::Float64 + } else { + DataType::Int32 + } + }) + .collect()) + } +} + +fn make_dt_interval_kernel(args: &[ArrayRef]) -> Result { + let n_rows = args[0].len(); + let days = args[0] + .as_primitive_opt::() + .ok_or_else(|| plan_datafusion_err!("make_dt_interval arg[0] must be Int32"))?; + let hours: Option<&PrimitiveArray> = args + .get(1) + .map(|a| { + a.as_primitive_opt::().ok_or_else(|| { + plan_datafusion_err!("make_dt_interval arg[1] must be Int32") + }) + }) + .transpose()?; + let mins: Option<&PrimitiveArray> = args + .get(2) + .map(|a| { + a.as_primitive_opt::().ok_or_else(|| { + plan_datafusion_err!("make_dt_interval arg[2] must be Int32") + }) + }) + .transpose()?; + let secs: Option<&PrimitiveArray> = args + .get(3) + .map(|a| { + a.as_primitive_opt::().ok_or_else(|| { + plan_datafusion_err!("make_dt_interval arg[3] must be Float64") + }) + }) + .transpose()?; + let mut builder = DurationMicrosecondBuilder::with_capacity(n_rows); + + for i in 0..n_rows { + // if one column is NULL → result NULL + let any_null_present = days.is_null(i) + || hours.as_ref().is_some_and(|a| a.is_null(i)) + || mins.as_ref().is_some_and(|a| a.is_null(i)) + || secs + .as_ref() + .is_some_and(|a| a.is_null(i) || !a.value(i).is_finite()); + + if any_null_present { + builder.append_null(); + continue; + } + + // default values 0 or 0.0 + let d = days.value(i); + let h = hours.as_ref().map_or(0, |a| a.value(i)); + let mi = mins.as_ref().map_or(0, |a| a.value(i)); + let s = secs.as_ref().map_or(0.0, |a| a.value(i)); + + match make_interval_dt_nano(d, h, mi, s) { + Some(v) => builder.append_value(v), + None => { + builder.append_null(); + continue; + } + } + } + + Ok(Arc::new(builder.finish())) +} +fn make_interval_dt_nano(day: i32, hour: i32, min: i32, sec: f64) -> Option { + const HOURS_PER_DAY: i32 = 24; + const MINS_PER_HOUR: i32 = 60; + const SECS_PER_MINUTE: i64 = 60; + const MICROS_PER_SEC: i64 = 1_000_000; + + let total_hours: i32 = day + .checked_mul(HOURS_PER_DAY) + .and_then(|v| v.checked_add(hour))?; + + let total_mins: i32 = total_hours + .checked_mul(MINS_PER_HOUR) + .and_then(|v| v.checked_add(min))?; + + let mut sec_whole: i64 = sec.trunc() as i64; + let sec_frac: f64 = sec - (sec_whole as f64); + let mut frac_us: i64 = (sec_frac * (MICROS_PER_SEC as f64)).round() as i64; + + if frac_us.abs() >= MICROS_PER_SEC { + if frac_us > 0 { + frac_us -= MICROS_PER_SEC; + sec_whole = sec_whole.checked_add(1)?; + } else { + frac_us += MICROS_PER_SEC; + sec_whole = sec_whole.checked_sub(1)?; + } + } + + let total_secs: i64 = (total_mins as i64) + .checked_mul(SECS_PER_MINUTE) + .and_then(|v| v.checked_add(sec_whole))?; + + let total_us = total_secs + .checked_mul(MICROS_PER_SEC) + .and_then(|v| v.checked_add(frac_us))?; + + Some(total_us) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::{DurationMicrosecondArray, Float64Array, Int32Array}; + use arrow::datatypes::DataType::Duration; + use arrow::datatypes::Field; + use arrow::datatypes::TimeUnit::Microsecond; + use datafusion_common::{internal_datafusion_err, DataFusionError, Result}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; + + use super::*; + + fn run_make_dt_interval(arrs: Vec) -> Result { + make_dt_interval_kernel(&arrs) + } + + #[test] + fn nulls_propagate_per_row() -> Result<()> { + let days = Arc::new(Int32Array::from(vec![ + None, + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + ])) as ArrayRef; + + let hours = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + ])) as ArrayRef; + + let mins = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(4), + Some(5), + Some(6), + Some(7), + ])) as ArrayRef; + + let secs = Arc::new(Float64Array::from(vec![ + Some(1.0), + Some(2.0), + Some(3.0), + None, + Some(f64::NAN), + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + ])) as ArrayRef; + + let out = run_make_dt_interval(vec![days, hours, mins, secs])?; + let out = out + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("expected DurationMicrosecondArray") + })?; + + for i in 0..out.len() { + assert!(out.is_null(i), "row {i} should be NULL"); + } + Ok(()) + } + + #[test] + fn error_months_overflow_should_be_null() -> Result<()> { + // months = year*12 + month → NULL + + let days = Arc::new(Int32Array::from(vec![Some(i32::MAX)])) as ArrayRef; + + let hours = Arc::new(Int32Array::from(vec![Some(1)])) as ArrayRef; + + let mins = Arc::new(Int32Array::from(vec![Some(1)])) as ArrayRef; + + let secs = Arc::new(Float64Array::from(vec![Some(1.0)])) as ArrayRef; + + let out = run_make_dt_interval(vec![days, hours, mins, secs])?; + let out = out + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("expected DurationMicrosecondArray") + })?; + + for i in 0..out.len() { + assert!(out.is_null(i), "row {i} should be NULL"); + } + + Ok(()) + } + + fn invoke_make_dt_interval_with_args( + args: Vec, + number_rows: usize, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true).into()) + .collect::>(); + let args = ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Field::new("f", Duration(Microsecond), true).into(), + config_options: Arc::new(Default::default()), + }; + SparkMakeDtInterval::new().invoke_with_args(args) + } + + #[test] + fn zero_args_returns_zero_duration() -> Result<()> { + let number_rows: usize = 3; + + let res: ColumnarValue = invoke_make_dt_interval_with_args(vec![], number_rows)?; + let arr = res.into_array(number_rows)?; + let arr = arr + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("expected DurationMicrosecondArray") + })?; + + assert_eq!(arr.len(), number_rows); + for i in 0..number_rows { + assert!(!arr.is_null(i)); + assert_eq!(arr.value(i), 0_i64); + } + Ok(()) + } + + #[test] + fn one_day_minus_24_hours_equals_zero() -> Result<()> { + let arr_days = Arc::new(Int32Array::from(vec![Some(1), Some(-1)])) as ArrayRef; + let arr_hours = Arc::new(Int32Array::from(vec![Some(-24), Some(24)])) as ArrayRef; + let arr_mins = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef; + let arr_secs = + Arc::new(Float64Array::from(vec![Some(0.0), Some(0.0)])) as ArrayRef; + + let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?; + let out = out + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("expected DurationMicrosecondArray") + })?; + + assert_eq!(out.len(), 2); + assert_eq!(out.null_count(), 0); + assert_eq!(out.value(0), 0_i64); + assert_eq!(out.value(1), 0_i64); + Ok(()) + } + + #[test] + fn one_hour_minus_60_mins_equals_zero() -> Result<()> { + let arr_days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef; + let arr_hours = Arc::new(Int32Array::from(vec![Some(-1), Some(1)])) as ArrayRef; + let arr_mins = Arc::new(Int32Array::from(vec![Some(60), Some(-60)])) as ArrayRef; + let arr_secs = + Arc::new(Float64Array::from(vec![Some(0.0), Some(0.0)])) as ArrayRef; + + let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?; + let out = out + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("expected DurationMicrosecondArray") + })?; + + assert_eq!(out.len(), 2); + assert_eq!(out.null_count(), 0); + assert_eq!(out.value(0), 0_i64); + assert_eq!(out.value(1), 0_i64); + Ok(()) + } + + #[test] + fn one_mins_minus_60_secs_equals_zero() -> Result<()> { + let arr_days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef; + let arr_hours = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef; + let arr_mins = Arc::new(Int32Array::from(vec![Some(-1), Some(1)])) as ArrayRef; + let arr_secs = + Arc::new(Float64Array::from(vec![Some(60.0), Some(-60.0)])) as ArrayRef; + + let out = run_make_dt_interval(vec![arr_days, arr_hours, arr_mins, arr_secs])?; + let out = out + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("expected DurationMicrosecondArray") + })?; + + assert_eq!(out.len(), 2); + assert_eq!(out.null_count(), 0); + assert_eq!(out.value(0), 0_i64); + assert_eq!(out.value(1), 0_i64); + Ok(()) + } + + #[test] + fn frac_carries_up_to_next_second_positive() -> Result<()> { + // 0.9999995s → 1_000_000 µs (carry a +1s) + let days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef; + let hours = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef; + let mins = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef; + let secs = Arc::new(Float64Array::from(vec![ + Some(0.999_999_5), + Some(0.999_999_4), + ])) as ArrayRef; + + let out = run_make_dt_interval(vec![days, hours, mins, secs])?; + let out = out + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("expected DurationMicrosecondArray") + })?; + + assert_eq!(out.len(), 2); + assert_eq!(out.value(0), 1_000_000); + assert_eq!(out.value(1), 999_999); + Ok(()) + } + + #[test] + fn frac_carries_down_to_prev_second_negative() -> Result<()> { + // -0.9999995s → -1_000_000 µs (carry a −1s) + let days = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef; + let hours = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef; + let mins = Arc::new(Int32Array::from(vec![Some(0), Some(0)])) as ArrayRef; + let secs = Arc::new(Float64Array::from(vec![ + Some(-0.999_999_5), + Some(-0.999_999_4), + ])) as ArrayRef; + + let out = run_make_dt_interval(vec![days, hours, mins, secs])?; + let out = out + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("expected DurationMicrosecondArray") + })?; + + assert_eq!(out.len(), 2); + assert_eq!(out.value(0), -1_000_000); + assert_eq!(out.value(1), -999_999); + Ok(()) + } + + #[test] + fn no_more_than_4_params() -> Result<()> { + let udf = SparkMakeDtInterval::new(); + + let arg_types = vec![ + DataType::Int32, + DataType::Int32, + DataType::Int32, + DataType::Float64, + DataType::Int32, + ]; + + let res = udf.coerce_types(&arg_types); + + assert!( + matches!(res, Err(DataFusionError::Execution(_))), + "make_dt_interval should return execution error for too many arguments" + ); + + Ok(()) + } +} diff --git a/datafusion/spark/src/function/datetime/make_interval.rs b/datafusion/spark/src/function/datetime/make_interval.rs new file mode 100644 index 0000000000000..8e3169556b95b --- /dev/null +++ b/datafusion/spark/src/function/datetime/make_interval.rs @@ -0,0 +1,573 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, IntervalMonthDayNanoBuilder, PrimitiveArray}; +use arrow::datatypes::DataType::Interval; +use arrow::datatypes::IntervalUnit::MonthDayNano; +use arrow::datatypes::{DataType, IntervalMonthDayNano}; +use datafusion_common::{ + exec_err, plan_datafusion_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkMakeInterval { + signature: Signature, +} + +impl Default for SparkMakeInterval { + fn default() -> Self { + Self::new() + } +} + +impl SparkMakeInterval { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkMakeInterval { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "make_interval" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Interval(MonthDayNano)) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if args.args.is_empty() { + return Ok(ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano( + Some(IntervalMonthDayNano::new(0, 0, 0)), + ))); + } + make_scalar_function(make_interval_kernel, vec![])(&args.args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let length = arg_types.len(); + match length { + x if x > 7 => { + exec_err!( + "make_interval expects between 0 and 7 arguments, got {}", + arg_types.len() + ) + } + _ => Ok((0..arg_types.len()) + .map(|i| { + if i == 6 { + DataType::Float64 + } else { + DataType::Int32 + } + }) + .collect()), + } + } +} + +fn make_interval_kernel(args: &[ArrayRef]) -> Result { + use arrow::array::AsArray; + use arrow::datatypes::{Float64Type, Int32Type}; + + let n_rows = args[0].len(); + + let years = args[0] + .as_primitive_opt::() + .ok_or_else(|| plan_datafusion_err!("make_interval arg[0] must be Int32"))?; + let months = args + .get(1) + .map(|a| { + a.as_primitive_opt::().ok_or_else(|| { + plan_datafusion_err!("make_dt_interval arg[1] must be Int32") + }) + }) + .transpose()?; + let weeks = args + .get(2) + .map(|a| { + a.as_primitive_opt::().ok_or_else(|| { + plan_datafusion_err!("make_dt_interval arg[2] must be Int32") + }) + }) + .transpose()?; + let days: Option<&PrimitiveArray> = args + .get(3) + .map(|a| { + a.as_primitive_opt::().ok_or_else(|| { + plan_datafusion_err!("make_dt_interval arg[3] must be Int32") + }) + }) + .transpose()?; + let hours: Option<&PrimitiveArray> = args + .get(4) + .map(|a| { + a.as_primitive_opt::().ok_or_else(|| { + plan_datafusion_err!("make_dt_interval arg[4] must be Int32") + }) + }) + .transpose()?; + let mins: Option<&PrimitiveArray> = args + .get(5) + .map(|a| { + a.as_primitive_opt::().ok_or_else(|| { + plan_datafusion_err!("make_dt_interval arg[5] must be Int32") + }) + }) + .transpose()?; + let secs: Option<&PrimitiveArray> = args + .get(6) + .map(|a| { + a.as_primitive_opt::().ok_or_else(|| { + plan_datafusion_err!("make_dt_interval arg[6] must be Float64") + }) + }) + .transpose()?; + + let mut builder = IntervalMonthDayNanoBuilder::with_capacity(n_rows); + + for i in 0..n_rows { + // if one column is NULL → result NULL + let any_null_present = years.is_null(i) + || months.as_ref().is_some_and(|a| a.is_null(i)) + || weeks.as_ref().is_some_and(|a| a.is_null(i)) + || days.as_ref().is_some_and(|a| a.is_null(i)) + || hours.as_ref().is_some_and(|a| a.is_null(i)) + || mins.as_ref().is_some_and(|a| a.is_null(i)) + || secs + .as_ref() + .is_some_and(|a| a.is_null(i) || !a.value(i).is_finite()); + + if any_null_present { + builder.append_null(); + continue; + } + + // default values 0 or 0.0 + let y = years.value(i); + let mo = months.as_ref().map_or(0, |a| a.value(i)); + let w = weeks.as_ref().map_or(0, |a| a.value(i)); + let d = days.as_ref().map_or(0, |a| a.value(i)); + let h = hours.as_ref().map_or(0, |a| a.value(i)); + let mi = mins.as_ref().map_or(0, |a| a.value(i)); + let s = secs.as_ref().map_or(0.0, |a| a.value(i)); + + match make_interval_month_day_nano(y, mo, w, d, h, mi, s) { + Some(v) => builder.append_value(v), + None => { + builder.append_null(); + continue; + } + } + } + + Ok(Arc::new(builder.finish())) +} + +fn make_interval_month_day_nano( + year: i32, + month: i32, + week: i32, + day: i32, + hour: i32, + min: i32, + sec: f64, +) -> Option { + // checks if overflow + let months = year.checked_mul(12).and_then(|v| v.checked_add(month))?; + let total_days = week.checked_mul(7).and_then(|v| v.checked_add(day))?; + + let hours_nanos = (hour as i64).checked_mul(3_600_000_000_000)?; + let mins_nanos = (min as i64).checked_mul(60_000_000_000)?; + + let sec_int = sec.trunc() as i64; + let frac = sec - sec.trunc(); + let mut frac_nanos = (frac * 1_000_000_000.0).round() as i64; + + if frac_nanos.abs() >= 1_000_000_000 { + if frac_nanos > 0 { + frac_nanos -= 1_000_000_000; + } else { + frac_nanos += 1_000_000_000; + } + } + + let secs_nanos = sec_int.checked_mul(1_000_000_000)?; + + let total_nanos = hours_nanos + .checked_add(mins_nanos) + .and_then(|v| v.checked_add(secs_nanos)) + .and_then(|v| v.checked_add(frac_nanos))?; + + Some(IntervalMonthDayNano::new(months, total_days, total_nanos)) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Float64Array, Int32Array, IntervalMonthDayNanoArray}; + use arrow::datatypes::Field; + use datafusion_common::config::ConfigOptions; + use datafusion_common::{internal_datafusion_err, internal_err, Result}; + + use super::*; + fn run_make_interval_month_day_nano(arrs: Vec) -> Result { + make_interval_kernel(&arrs) + } + + #[test] + fn nulls_propagate_per_row() { + let year = Arc::new(Int32Array::from(vec![ + None, + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ])); + let month = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ])); + let week = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ])); + let day = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ])); + let hour = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + None, + Some(6), + Some(7), + Some(8), + Some(9), + ])); + let min = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + None, + Some(7), + Some(8), + Some(9), + ])); + let sec = Arc::new(Float64Array::from(vec![ + Some(1.0), + Some(2.0), + Some(3.0), + Some(4.0), + Some(5.0), + Some(6.0), + None, + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + ])); + + let out = run_make_interval_month_day_nano(vec![ + year, month, week, day, hour, min, sec, + ]) + .unwrap(); + let out = out + .as_any() + .downcast_ref::() + .ok_or_else(|| internal_datafusion_err!("expected IntervalMonthDayNano")) + .unwrap(); + + for i in 0..out.len() { + assert!(out.is_null(i), "row {i} should be NULL"); + } + } + + #[test] + fn error_months_overflow_should_be_null() { + // months = year*12 + month → NULL + let year = Arc::new(Int32Array::from(vec![Some(i32::MAX)])) as ArrayRef; + let month = Arc::new(Int32Array::from(vec![Some(1)])) as ArrayRef; + let week = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let day = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let hour = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let min = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let sec = Arc::new(Float64Array::from(vec![Some(0.0)])) as ArrayRef; + + let out = run_make_interval_month_day_nano(vec![ + year, month, week, day, hour, min, sec, + ]) + .unwrap(); + let out = out + .as_any() + .downcast_ref::() + .ok_or_else(|| internal_datafusion_err!("expected IntervalMonthDayNano")) + .unwrap(); + + for i in 0..out.len() { + assert!(out.is_null(i), "row {i} should be NULL"); + } + } + #[test] + fn error_days_overflow_should_be_null() { + // months = year*12 + month → NULL + let year = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let month = Arc::new(Int32Array::from(vec![Some(1)])) as ArrayRef; + let week = Arc::new(Int32Array::from(vec![Some(i32::MAX)])) as ArrayRef; + let day = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let hour = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let min = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let sec = Arc::new(Float64Array::from(vec![Some(0.0)])) as ArrayRef; + + let out = run_make_interval_month_day_nano(vec![ + year, month, week, day, hour, min, sec, + ]) + .unwrap(); + let out = out + .as_any() + .downcast_ref::() + .ok_or_else(|| internal_datafusion_err!("expected IntervalMonthDayNano")) + .unwrap(); + + for i in 0..out.len() { + assert!(out.is_null(i), "row {i} should be NULL"); + } + } + #[test] + fn error_min_overflow_should_be_null() { + let year = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let month = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let week = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let day = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let hour = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let min = Arc::new(Int32Array::from(vec![Some(i32::MAX)])) as ArrayRef; + let sec = Arc::new(Float64Array::from(vec![Some(0.0)])) as ArrayRef; + + let out = run_make_interval_month_day_nano(vec![ + year, month, week, day, hour, min, sec, + ]) + .unwrap(); + let out = out + .as_any() + .downcast_ref::() + .ok_or_else(|| internal_datafusion_err!("expected IntervalMonthDayNano")) + .unwrap(); + + for i in 0..out.len() { + assert!(out.is_null(i), "row {i} should be NULL"); + } + } + #[test] + fn error_sec_overflow_should_be_null() { + let year = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let month = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let week = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let day = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let hour = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let min = Arc::new(Int32Array::from(vec![Some(0)])) as ArrayRef; + let sec = Arc::new(Float64Array::from(vec![Some(f64::MAX)])) as ArrayRef; + + let out = run_make_interval_month_day_nano(vec![ + year, month, week, day, hour, min, sec, + ]) + .unwrap(); + let out = out + .as_any() + .downcast_ref::() + .ok_or_else(|| internal_datafusion_err!("expected IntervalMonthDayNano")) + .unwrap(); + + for i in 0..out.len() { + assert!(out.is_null(i), "row {i} should be NULL"); + } + } + + #[test] + fn happy_path_all_present_single_row() { + // 1y 2m 3w 4d 5h 6m 7.25s + let year = Arc::new(Int32Array::from(vec![Some(1)])) as ArrayRef; + let month = Arc::new(Int32Array::from(vec![Some(2)])) as ArrayRef; + let week = Arc::new(Int32Array::from(vec![Some(3)])) as ArrayRef; + let day = Arc::new(Int32Array::from(vec![Some(4)])) as ArrayRef; + let hour = Arc::new(Int32Array::from(vec![Some(5)])) as ArrayRef; + let mins = Arc::new(Int32Array::from(vec![Some(6)])) as ArrayRef; + let secs = Arc::new(Float64Array::from(vec![Some(7.25)])) as ArrayRef; + + let out = run_make_interval_month_day_nano(vec![ + year, month, week, day, hour, mins, secs, + ]) + .unwrap(); + assert_eq!(out.data_type(), &Interval(MonthDayNano)); + + let out = out + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(out.len(), 1); + assert_eq!(out.null_count(), 0); + + let v: IntervalMonthDayNano = out.value(0); + assert_eq!(v.months, 12 + 2); // 14 + assert_eq!(v.days, 3 * 7 + 4); // 25 + let expected_nanos = (5_i64 * 3600 + 6 * 60 + 7) * 1_000_000_000 + 250_000_000; + assert_eq!(v.nanoseconds, expected_nanos); + } + + #[test] + fn negative_components_and_fractional_seconds() { + // -1y -2m -1w -1d -1h -1m -1.5s + let year = Arc::new(Int32Array::from(vec![Some(-1)])) as ArrayRef; + let month = Arc::new(Int32Array::from(vec![Some(-2)])) as ArrayRef; + let week = Arc::new(Int32Array::from(vec![Some(-1)])) as ArrayRef; + let day = Arc::new(Int32Array::from(vec![Some(-1)])) as ArrayRef; + let hour = Arc::new(Int32Array::from(vec![Some(-1)])) as ArrayRef; + let mins = Arc::new(Int32Array::from(vec![Some(-1)])) as ArrayRef; + let secs = Arc::new(Float64Array::from(vec![Some(-1.5)])) as ArrayRef; + + let out = run_make_interval_month_day_nano(vec![ + year, month, week, day, hour, mins, secs, + ]) + .unwrap(); + let out = out + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(out.len(), 1); + assert_eq!(out.null_count(), 0); + let v = out.value(0); + + assert_eq!(v.months, -12 + (-2)); // -14 + assert_eq!(v.days, -7 + (-1)); // -8 + + // -(1h + 1m + 1.5s) en nanos + let expected_nanos = -((3600_i64 + 60 + 1) * 1_000_000_000 + 500_000_000); + assert_eq!(v.nanoseconds, expected_nanos); + } + + fn invoke_make_interval_with_args( + args: Vec, + number_rows: usize, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true).into()) + .collect::>(); + let args = ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Field::new("f", Interval(MonthDayNano), true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + SparkMakeInterval::new().invoke_with_args(args) + } + + #[test] + fn zero_args_returns_zero_seconds() -> Result<()> { + let number_rows = 2; + let res: ColumnarValue = invoke_make_interval_with_args(vec![], number_rows)?; + + match res { + ColumnarValue::Array(arr) => { + let arr = arr + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("expected IntervalMonthDayNanoArray") + })?; + if arr.len() != number_rows { + return internal_err!( + "expected array length {number_rows}, got {}", + arr.len() + ); + } + for i in 0..number_rows { + let iv = arr.value(i); + if (iv.months, iv.days, iv.nanoseconds) != (0, 0, 0) { + return internal_err!( + "row {i}: expected (0,0,0), got ({},{},{})", + iv.months, + iv.days, + iv.nanoseconds + ); + } + } + } + ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some(iv))) => { + if (iv.months, iv.days, iv.nanoseconds) != (0, 0, 0) { + return internal_err!( + "expected scalar 0s, got ({},{},{})", + iv.months, + iv.days, + iv.nanoseconds + ); + } + } + other => { + return internal_err!( + "expected Array or Scalar IntervalMonthDayNano, got {other:?}" + ); + } + } + + Ok(()) + } +} diff --git a/datafusion/spark/src/function/datetime/mod.rs b/datafusion/spark/src/function/datetime/mod.rs new file mode 100644 index 0000000000000..a6adc99607665 --- /dev/null +++ b/datafusion/spark/src/function/datetime/mod.rs @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod date_add; +pub mod date_sub; +pub mod last_day; +pub mod make_dt_interval; +pub mod make_interval; +pub mod next_day; + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +make_udf_function!(date_add::SparkDateAdd, date_add); +make_udf_function!(date_sub::SparkDateSub, date_sub); +make_udf_function!(last_day::SparkLastDay, last_day); +make_udf_function!(make_dt_interval::SparkMakeDtInterval, make_dt_interval); +make_udf_function!(make_interval::SparkMakeInterval, make_interval); +make_udf_function!(next_day::SparkNextDay, next_day); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!(( + date_add, + "Returns the date that is days days after start. The function returns NULL if at least one of the input parameters is NULL.", + arg1 arg2 + )); + export_functions!(( + date_sub, + "Returns the date that is days days before start. The function returns NULL if at least one of the input parameters is NULL.", + arg1 arg2 + )); + export_functions!(( + last_day, + "Returns the last day of the month which the date belongs to.", + arg1 + )); + export_functions!(( + make_dt_interval, + "Make a day time interval from given days, hours, mins and secs (return type is actually a Duration(Microsecond))", + days hours mins secs + )); + export_functions!(( + make_interval, + "Make interval from years, months, weeks, days, hours, mins and secs.", + years months weeks days hours mins secs + )); + // TODO: add once ANSI support is added: + // "When both of the input parameters are not NULL and day_of_week is an invalid input, the function throws SparkIllegalArgumentException if spark.sql.ansi.enabled is set to true, otherwise NULL." + export_functions!(( + next_day, + "Returns the first date which is later than start_date and named as indicated. The function returns NULL if at least one of the input parameters is NULL.", + arg1 arg2 + )); +} + +pub fn functions() -> Vec> { + vec![ + date_add(), + date_sub(), + last_day(), + make_dt_interval(), + make_interval(), + next_day(), + ] +} diff --git a/datafusion/spark/src/function/datetime/next_day.rs b/datafusion/spark/src/function/datetime/next_day.rs new file mode 100644 index 0000000000000..32739f3e2c591 --- /dev/null +++ b/datafusion/spark/src/function/datetime/next_day.rs @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{new_null_array, ArrayRef, AsArray, Date32Array, StringArrayType}; +use arrow::datatypes::{DataType, Date32Type}; +use chrono::{Datelike, Duration, Weekday}; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkNextDay { + signature: Signature, +} + +impl Default for SparkNextDay { + fn default() -> Self { + Self::new() + } +} + +impl SparkNextDay { + pub fn new() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Date32, DataType::Utf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkNextDay { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "next_day" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Date32) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + let [date, day_of_week] = args.as_slice() else { + return exec_err!( + "Spark `next_day` function requires 2 arguments, got {}", + args.len() + ); + }; + + match (date, day_of_week) { + (ColumnarValue::Scalar(date), ColumnarValue::Scalar(day_of_week)) => { + match (date, day_of_week) { + (ScalarValue::Date32(days), ScalarValue::Utf8(day_of_week) | ScalarValue::LargeUtf8(day_of_week) | ScalarValue::Utf8View(day_of_week)) => { + if let Some(days) = days { + if let Some(day_of_week) = day_of_week { + Ok(ColumnarValue::Scalar(ScalarValue::Date32( + spark_next_day(*days, day_of_week.as_str()), + ))) + } else { + // TODO: if spark.sql.ansi.enabled is false, + // returns NULL instead of an error for a malformed dayOfWeek. + Ok(ColumnarValue::Scalar(ScalarValue::Date32(None))) + } + } else { + Ok(ColumnarValue::Scalar(ScalarValue::Date32(None))) + } + } + _ => exec_err!("Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"), + } + } + (ColumnarValue::Array(date_array), ColumnarValue::Scalar(day_of_week)) => { + match (date_array.data_type(), day_of_week) { + (DataType::Date32, ScalarValue::Utf8(day_of_week) | ScalarValue::LargeUtf8(day_of_week) | ScalarValue::Utf8View(day_of_week)) => { + if let Some(day_of_week) = day_of_week { + let result: Date32Array = date_array + .as_primitive::() + .unary_opt(|days| spark_next_day(days, day_of_week.as_str())) + .with_data_type(DataType::Date32); + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } else { + // TODO: if spark.sql.ansi.enabled is false, + // returns NULL instead of an error for a malformed dayOfWeek. + Ok(ColumnarValue::Array(Arc::new(new_null_array(&DataType::Date32, date_array.len())))) + } + } + _ => exec_err!("Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"), + } + } + ( + ColumnarValue::Array(date_array), + ColumnarValue::Array(day_of_week_array), + ) => { + let result = match (date_array.data_type(), day_of_week_array.data_type()) + { + ( + DataType::Date32, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View, + ) => { + let date_array: &Date32Array = + date_array.as_primitive::(); + match day_of_week_array.data_type() { + DataType::Utf8 => { + let day_of_week_array = + day_of_week_array.as_string::(); + process_next_day_arrays(date_array, day_of_week_array) + } + DataType::LargeUtf8 => { + let day_of_week_array = + day_of_week_array.as_string::(); + process_next_day_arrays(date_array, day_of_week_array) + } + DataType::Utf8View => { + let day_of_week_array = + day_of_week_array.as_string_view(); + process_next_day_arrays(date_array, day_of_week_array) + } + other => { + exec_err!("Spark `next_day` function: second arg must be string. Got {other:?}") + } + } + } + (left, right) => { + exec_err!( + "Spark `next_day` function: first arg must be date, second arg must be string. Got {left:?}, {right:?}" + ) + } + }?; + Ok(ColumnarValue::Array(result)) + } + _ => exec_err!("Unsupported args {args:?} for Spark function `next_day`"), + } + } +} + +fn process_next_day_arrays<'a, S>( + date_array: &Date32Array, + day_of_week_array: &'a S, +) -> Result +where + &'a S: StringArrayType<'a>, +{ + let result = date_array + .iter() + .zip(day_of_week_array.iter()) + .map(|(days, day_of_week)| { + if let Some(days) = days { + if let Some(day_of_week) = day_of_week { + spark_next_day(days, day_of_week) + } else { + // TODO: if spark.sql.ansi.enabled is false, + // returns NULL instead of an error for a malformed dayOfWeek. + None + } + } else { + None + } + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) +} + +fn spark_next_day(days: i32, day_of_week: &str) -> Option { + let date = Date32Type::to_naive_date(days); + + let day_of_week = day_of_week.trim().to_uppercase(); + let day_of_week = match day_of_week.as_str() { + "MO" | "MON" | "MONDAY" => Some("MONDAY"), + "TU" | "TUE" | "TUESDAY" => Some("TUESDAY"), + "WE" | "WED" | "WEDNESDAY" => Some("WEDNESDAY"), + "TH" | "THU" | "THURSDAY" => Some("THURSDAY"), + "FR" | "FRI" | "FRIDAY" => Some("FRIDAY"), + "SA" | "SAT" | "SATURDAY" => Some("SATURDAY"), + "SU" | "SUN" | "SUNDAY" => Some("SUNDAY"), + _ => { + // TODO: if spark.sql.ansi.enabled is false, + // returns NULL instead of an error for a malformed dayOfWeek. + None + } + }; + + if let Some(day_of_week) = day_of_week { + let day_of_week = day_of_week.parse::(); + match day_of_week { + Ok(day_of_week) => Some(Date32Type::from_naive_date( + date + Duration::days( + (7 - date.weekday().days_since(day_of_week)) as i64, + ), + )), + Err(_) => { + // TODO: if spark.sql.ansi.enabled is false, + // returns NULL instead of an error for a malformed dayOfWeek. + None + } + } + } else { + None + } +} diff --git a/datafusion/spark/src/function/error_utils.rs b/datafusion/spark/src/function/error_utils.rs new file mode 100644 index 0000000000000..b972d64ed3e9a --- /dev/null +++ b/datafusion/spark/src/function/error_utils.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// TODO: https://github.com/apache/spark/tree/master/common/utils/src/main/resources/error + +use arrow::datatypes::DataType; +use datafusion_common::{exec_datafusion_err, internal_datafusion_err, DataFusionError}; + +pub fn invalid_arg_count_exec_err( + function_name: &str, + required_range: (i32, i32), + provided: usize, +) -> DataFusionError { + let (min_required, max_required) = required_range; + let required = if min_required == max_required { + format!( + "{min_required} argument{}", + if min_required == 1 { "" } else { "s" } + ) + } else { + format!("{min_required} to {max_required} arguments") + }; + exec_datafusion_err!( + "Spark `{function_name}` function requires {required}, got {provided}" + ) +} + +pub fn unsupported_data_type_exec_err( + function_name: &str, + required: &str, + provided: &DataType, +) -> DataFusionError { + exec_datafusion_err!("Unsupported Data Type: Spark `{function_name}` function expects {required}, got {provided}") +} + +pub fn unsupported_data_types_exec_err( + function_name: &str, + required: &str, + provided: &[DataType], +) -> DataFusionError { + exec_datafusion_err!( + "Unsupported Data Type: Spark `{function_name}` function expects {required}, got {}", + provided + .iter() + .map(|dt| format!("{dt}")) + .collect::>() + .join(", ") + ) +} + +pub fn generic_exec_err(function_name: &str, message: &str) -> DataFusionError { + exec_datafusion_err!("Spark `{function_name}` function: {message}") +} + +pub fn generic_internal_err(function_name: &str, message: &str) -> DataFusionError { + internal_datafusion_err!("Spark `{function_name}` function: {message}") +} diff --git a/datafusion/spark/src/function/functions_nested_utils.rs b/datafusion/spark/src/function/functions_nested_utils.rs new file mode 100644 index 0000000000000..b455ba735d749 --- /dev/null +++ b/datafusion/spark/src/function/functions_nested_utils.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::ColumnarValue; + +/// array function wrapper that differentiates between scalar (length 1) and array. +pub(crate) fn make_scalar_function( + inner: F, +) -> impl Fn(&[ColumnarValue]) -> Result +where + F: Fn(&[ArrayRef]) -> Result, +{ + move |args: &[ColumnarValue]| { + // first, identify if any of the arguments is an Array. If yes, store its `len`, + // as any scalar will need to be converted to an array of len `len`. + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + + let args = ColumnarValue::values_to_arrays(args)?; + + let result = (inner)(&args); + + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + } +} diff --git a/datafusion/spark/src/function/generator/mod.rs b/datafusion/spark/src/function/generator/mod.rs new file mode 100644 index 0000000000000..a87df9a2c87a0 --- /dev/null +++ b/datafusion/spark/src/function/generator/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/hash/crc32.rs b/datafusion/spark/src/function/hash/crc32.rs new file mode 100644 index 0000000000000..76e31d12c6487 --- /dev/null +++ b/datafusion/spark/src/function/hash/crc32.rs @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array}; +use arrow::datatypes::DataType; +use crc32fast::Hasher; +use datafusion_common::cast::{ + as_binary_array, as_binary_view_array, as_large_binary_array, +}; +use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkCrc32 { + signature: Signature, +} + +impl Default for SparkCrc32 { + fn default() -> Self { + Self::new() + } +} + +impl SparkCrc32 { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkCrc32 { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "crc32" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_crc32, vec![])(&args.args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return exec_err!( + "`crc32` function requires 1 argument, got {}", + arg_types.len() + ); + } + match arg_types[0] { + DataType::Binary | DataType::LargeBinary | DataType::BinaryView => { + Ok(vec![arg_types[0].clone()]) + } + DataType::Utf8 | DataType::Utf8View => Ok(vec![DataType::Binary]), + DataType::LargeUtf8 => Ok(vec![DataType::LargeBinary]), + DataType::Null => Ok(vec![DataType::Binary]), + _ => exec_err!("`crc32` function does not support type {}", arg_types[0]), + } + } +} + +fn spark_crc32_digest(value: &[u8]) -> i64 { + let mut hasher = Hasher::new(); + hasher.update(value); + hasher.finalize() as i64 +} + +fn spark_crc32_impl<'a>(input: impl Iterator>) -> ArrayRef { + let result = input + .map(|value| value.map(spark_crc32_digest)) + .collect::(); + Arc::new(result) +} + +fn spark_crc32(args: &[ArrayRef]) -> Result { + let [input] = args else { + return internal_err!( + "Spark `crc32` function requires 1 argument, got {}", + args.len() + ); + }; + + match input.data_type() { + DataType::Binary => { + let input = as_binary_array(input)?; + Ok(spark_crc32_impl(input.iter())) + } + DataType::LargeBinary => { + let input = as_large_binary_array(input)?; + Ok(spark_crc32_impl(input.iter())) + } + DataType::BinaryView => { + let input = as_binary_view_array(input)?; + Ok(spark_crc32_impl(input.iter())) + } + _ => { + exec_err!( + "Spark `crc32` function: argument must be binary or large binary, got {:?}", + input.data_type() + ) + } + } +} diff --git a/datafusion/spark/src/function/hash/mod.rs b/datafusion/spark/src/function/hash/mod.rs new file mode 100644 index 0000000000000..5860596ac70a3 --- /dev/null +++ b/datafusion/spark/src/function/hash/mod.rs @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod crc32; +pub mod sha1; +pub mod sha2; + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +make_udf_function!(crc32::SparkCrc32, crc32); +make_udf_function!(sha1::SparkSha1, sha1); +make_udf_function!(sha2::SparkSha2, sha2); + +pub mod expr_fn { + use datafusion_functions::export_functions; + export_functions!( + (crc32, "crc32(expr) - Returns a cyclic redundancy check value of the expr as a bigint.", arg1), + (sha1, "sha1(expr) - Returns a SHA-1 hash value of the expr as a hex string.", arg1), + (sha2, "sha2(expr, bitLength) - Returns a checksum of SHA-2 family as a hex string of expr. SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent to 256.", arg1 arg2) + ); +} + +pub fn functions() -> Vec> { + vec![crc32(), sha1(), sha2()] +} diff --git a/datafusion/spark/src/function/hash/sha1.rs b/datafusion/spark/src/function/hash/sha1.rs new file mode 100644 index 0000000000000..25cbdd4453505 --- /dev/null +++ b/datafusion/spark/src/function/hash/sha1.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::fmt::Write; +use std::sync::Arc; + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{ + as_binary_array, as_binary_view_array, as_large_binary_array, +}; +use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use sha1::{Digest, Sha1}; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSha1 { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkSha1 { + fn default() -> Self { + Self::new() + } +} + +impl SparkSha1 { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["sha".to_string()], + } + } +} + +impl ScalarUDFImpl for SparkSha1 { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "sha1" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_sha1, vec![])(&args.args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return exec_err!( + "`sha1` function requires 1 argument, got {}", + arg_types.len() + ); + } + match arg_types[0] { + DataType::Binary | DataType::LargeBinary | DataType::BinaryView => { + Ok(vec![arg_types[0].clone()]) + } + DataType::Utf8 | DataType::Utf8View => Ok(vec![DataType::Binary]), + DataType::LargeUtf8 => Ok(vec![DataType::LargeBinary]), + DataType::Null => Ok(vec![DataType::Binary]), + _ => exec_err!("`sha1` function does not support type {}", arg_types[0]), + } + } +} + +fn spark_sha1_digest(value: &[u8]) -> String { + let result = Sha1::digest(value); + let mut s = String::with_capacity(result.len() * 2); + for b in result.as_slice() { + #[allow(clippy::unwrap_used)] + write!(&mut s, "{b:02x}").unwrap(); + } + s +} + +fn spark_sha1_impl<'a>(input: impl Iterator>) -> ArrayRef { + let result = input + .map(|value| value.map(spark_sha1_digest)) + .collect::(); + Arc::new(result) +} + +fn spark_sha1(args: &[ArrayRef]) -> Result { + let [input] = args else { + return internal_err!( + "Spark `sha1` function requires 1 argument, got {}", + args.len() + ); + }; + + match input.data_type() { + DataType::Binary => { + let input = as_binary_array(input)?; + Ok(spark_sha1_impl(input.iter())) + } + DataType::LargeBinary => { + let input = as_large_binary_array(input)?; + Ok(spark_sha1_impl(input.iter())) + } + DataType::BinaryView => { + let input = as_binary_view_array(input)?; + Ok(spark_sha1_impl(input.iter())) + } + _ => { + exec_err!( + "Spark `sha1` function: argument must be binary or large binary, got {:?}", + input.data_type() + ) + } + } +} diff --git a/datafusion/spark/src/function/hash/sha2.rs b/datafusion/spark/src/function/hash/sha2.rs new file mode 100644 index 0000000000000..b006607d3eeda --- /dev/null +++ b/datafusion/spark/src/function/hash/sha2.rs @@ -0,0 +1,220 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate datafusion_functions; + +use crate::function::error_utils::{ + invalid_arg_count_exec_err, unsupported_data_type_exec_err, +}; +use crate::function::math::hex::spark_sha2_hex; +use arrow::array::{ArrayRef, AsArray, StringArray}; +use arrow::datatypes::{DataType, Int32Type}; +use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue}; +use datafusion_expr::Signature; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; +pub use datafusion_functions::crypto::basic::{sha224, sha256, sha384, sha512}; +use std::any::Any; +use std::sync::Arc; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSha2 { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkSha2 { + fn default() -> Self { + Self::new() + } +} + +impl SparkSha2 { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkSha2 { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "sha2" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types[1].is_null() { + return Ok(DataType::Null); + } + Ok(match arg_types[0] { + DataType::Utf8View + | DataType::LargeUtf8 + | DataType::Utf8 + | DataType::Binary + | DataType::BinaryView + | DataType::LargeBinary => DataType::Utf8, + DataType::Null => DataType::Null, + _ => { + return exec_err!( + "{} function can only accept strings or binary arrays.", + self.name() + ) + } + }) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| { + internal_datafusion_err!("Expected 2 arguments for function sha2") + })?; + + sha2(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 { + return Err(invalid_arg_count_exec_err( + self.name(), + (2, 2), + arg_types.len(), + )); + } + let expr_type = match &arg_types[0] { + DataType::Utf8View + | DataType::LargeUtf8 + | DataType::Utf8 + | DataType::Binary + | DataType::BinaryView + | DataType::LargeBinary + | DataType::Null => Ok(arg_types[0].clone()), + _ => Err(unsupported_data_type_exec_err( + self.name(), + "String, Binary", + &arg_types[0], + )), + }?; + let bit_length_type = if arg_types[1].is_numeric() { + Ok(DataType::Int32) + } else if arg_types[1].is_null() { + Ok(DataType::Null) + } else { + Err(unsupported_data_type_exec_err( + self.name(), + "Numeric Type", + &arg_types[1], + )) + }?; + + Ok(vec![expr_type, bit_length_type]) + } +} + +pub fn sha2(args: [ColumnarValue; 2]) -> Result { + match args { + [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => { + compute_sha2( + bit_length_arg, + &[ColumnarValue::from(ScalarValue::Utf8(expr_arg))], + ) + } + [ColumnarValue::Array(expr_arg), ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => { + compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)]) + } + [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Array(bit_length_arg)] => + { + let arr: StringArray = bit_length_arg + .as_primitive::() + .iter() + .map(|bit_length| { + match sha2([ + ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())), + ColumnarValue::Scalar(ScalarValue::Int32(bit_length)), + ]) + .unwrap() + { + ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, + ColumnarValue::Array(arr) => arr + .as_string::() + .iter() + .map(|str| str.unwrap().to_string()) + .next(), // first element + _ => unreachable!(), + } + }) + .collect(); + Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) + } + [ColumnarValue::Array(expr_arg), ColumnarValue::Array(bit_length_arg)] => { + let expr_iter = expr_arg.as_string::().iter(); + let bit_length_iter = bit_length_arg.as_primitive::().iter(); + let arr: StringArray = expr_iter + .zip(bit_length_iter) + .map(|(expr, bit_length)| { + match sha2([ + ColumnarValue::Scalar(ScalarValue::Utf8(Some( + expr.unwrap().to_string(), + ))), + ColumnarValue::Scalar(ScalarValue::Int32(bit_length)), + ]) + .unwrap() + { + ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, + ColumnarValue::Array(arr) => arr + .as_string::() + .iter() + .map(|str| str.unwrap().to_string()) + .next(), // first element + _ => unreachable!(), + } + }) + .collect(); + Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) + } + _ => exec_err!("Unsupported argument types for sha2 function"), + } +} + +fn compute_sha2( + bit_length_arg: i32, + expr_arg: &[ColumnarValue], +) -> Result { + match bit_length_arg { + 0 | 256 => sha256(expr_arg), + 224 => sha224(expr_arg), + 384 => sha384(expr_arg), + 512 => sha512(expr_arg), + _ => { + // Return null for unsupported bit lengths instead of error, because spark sha2 does not + // error out for this. + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } + } + .map(|hashed| spark_sha2_hex(&[hashed]).unwrap()) +} diff --git a/datafusion/spark/src/function/json/mod.rs b/datafusion/spark/src/function/json/mod.rs new file mode 100644 index 0000000000000..a87df9a2c87a0 --- /dev/null +++ b/datafusion/spark/src/function/json/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/lambda/mod.rs b/datafusion/spark/src/function/lambda/mod.rs new file mode 100644 index 0000000000000..a87df9a2c87a0 --- /dev/null +++ b/datafusion/spark/src/function/lambda/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/map/map_from_arrays.rs b/datafusion/spark/src/function/map/map_from_arrays.rs new file mode 100644 index 0000000000000..987548e353e44 --- /dev/null +++ b/datafusion/spark/src/function/map/map_from_arrays.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; + +use crate::function::map::utils::{ + get_element_type, get_list_offsets, get_list_values, + map_from_keys_values_offsets_nulls, map_type_from_key_value_types, +}; +use arrow::array::{Array, ArrayRef, NullArray}; +use arrow::compute::kernels::cast; +use arrow::datatypes::DataType; +use datafusion_common::utils::take_function_args; +use datafusion_common::Result; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_functions::utils::make_scalar_function; + +/// Spark-compatible `map_from_arrays` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct MapFromArrays { + signature: Signature, +} + +impl Default for MapFromArrays { + fn default() -> Self { + Self::new() + } +} + +impl MapFromArrays { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for MapFromArrays { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "map_from_arrays" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let [key_type, value_type] = take_function_args("map_from_arrays", arg_types)?; + Ok(map_type_from_key_value_types( + get_element_type(key_type)?, + get_element_type(value_type)?, + )) + } + + fn invoke_with_args( + &self, + args: datafusion_expr::ScalarFunctionArgs, + ) -> Result { + make_scalar_function(map_from_arrays_inner, vec![])(&args.args) + } +} + +fn map_from_arrays_inner(args: &[ArrayRef]) -> Result { + let [keys, values] = take_function_args("map_from_arrays", args)?; + + if matches!(keys.data_type(), DataType::Null) + || matches!(values.data_type(), DataType::Null) + { + return Ok(cast( + &NullArray::new(keys.len()), + &map_type_from_key_value_types( + get_element_type(keys.data_type())?, + get_element_type(values.data_type())?, + ), + )?); + } + + map_from_keys_values_offsets_nulls( + get_list_values(keys)?, + get_list_values(values)?, + &get_list_offsets(keys)?, + &get_list_offsets(values)?, + keys.nulls(), + values.nulls(), + ) +} diff --git a/datafusion/spark/src/function/map/map_from_entries.rs b/datafusion/spark/src/function/map/map_from_entries.rs new file mode 100644 index 0000000000000..6648979c5dd23 --- /dev/null +++ b/datafusion/spark/src/function/map/map_from_entries.rs @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; + +use crate::function::map::utils::{ + get_element_type, get_list_offsets, get_list_values, + map_from_keys_values_offsets_nulls, map_type_from_key_value_types, +}; +use arrow::array::{Array, ArrayRef, NullBufferBuilder, StructArray}; +use arrow::buffer::NullBuffer; +use arrow::datatypes::DataType; +use datafusion_common::utils::take_function_args; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_functions::utils::make_scalar_function; + +/// Spark-compatible `map_from_entries` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct MapFromEntries { + signature: Signature, +} + +impl Default for MapFromEntries { + fn default() -> Self { + Self::new() + } +} + +impl MapFromEntries { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for MapFromEntries { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "map_from_entries" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let [entries_type] = take_function_args("map_from_entries", arg_types)?; + let entries_element_type = get_element_type(entries_type)?; + let (keys_type, values_type) = match entries_element_type { + DataType::Struct(fields) if fields.len() == 2 => { + Ok((fields[0].data_type(), fields[1].data_type())) + } + wrong_type => exec_err!( + "map_from_entries: expected array>, got {:?}", + wrong_type + ), + }?; + Ok(map_type_from_key_value_types(keys_type, values_type)) + } + + fn invoke_with_args( + &self, + args: datafusion_expr::ScalarFunctionArgs, + ) -> Result { + make_scalar_function(map_from_entries_inner, vec![])(&args.args) + } +} + +fn map_from_entries_inner(args: &[ArrayRef]) -> Result { + let [entries] = take_function_args("map_from_entries", args)?; + let entries_offsets = get_list_offsets(entries)?; + let entries_values = get_list_values(entries)?; + + let (flat_keys, flat_values) = + match entries_values.as_any().downcast_ref::() { + Some(a) => Ok((a.column(0), a.column(1))), + None => exec_err!( + "map_from_entries: expected array>, got {:?}", + entries_values.data_type() + ), + }?; + + let entries_with_nulls = entries_values.nulls().and_then(|entries_inner_nulls| { + let mut builder = NullBufferBuilder::new_with_len(0); + let mut cur_offset = entries_offsets + .first() + .map(|offset| *offset as usize) + .unwrap_or(0); + + for next_offset in entries_offsets.iter().skip(1) { + let num_entries = *next_offset as usize - cur_offset; + builder.append( + entries_inner_nulls + .slice(cur_offset, num_entries) + .null_count() + == 0, + ); + cur_offset = *next_offset as usize; + } + builder.finish() + }); + + let res_nulls = NullBuffer::union(entries.nulls(), entries_with_nulls.as_ref()); + + map_from_keys_values_offsets_nulls( + flat_keys, + flat_values, + &entries_offsets, + &entries_offsets, + None, + res_nulls.as_ref(), + ) +} diff --git a/datafusion/spark/src/function/map/mod.rs b/datafusion/spark/src/function/map/mod.rs new file mode 100644 index 0000000000000..2f596b19b422f --- /dev/null +++ b/datafusion/spark/src/function/map/mod.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod map_from_arrays; +pub mod map_from_entries; +mod utils; + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +make_udf_function!(map_from_arrays::MapFromArrays, map_from_arrays); +make_udf_function!(map_from_entries::MapFromEntries, map_from_entries); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!(( + map_from_arrays, + "Creates a map from arrays of keys and values.", + keys values + )); + + export_functions!(( + map_from_entries, + "Creates a map from array>.", + arg1 + )); +} + +pub fn functions() -> Vec> { + vec![map_from_arrays(), map_from_entries()] +} diff --git a/datafusion/spark/src/function/map/utils.rs b/datafusion/spark/src/function/map/utils.rs new file mode 100644 index 0000000000000..b568f45403c30 --- /dev/null +++ b/datafusion/spark/src/function/map/utils.rs @@ -0,0 +1,231 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::borrow::Cow; +use std::collections::HashSet; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, AsArray, BooleanBuilder, MapArray, StructArray}; +use arrow::buffer::{NullBuffer, OffsetBuffer}; +use arrow::compute::filter; +use arrow::datatypes::{DataType, Field, Fields}; +use datafusion_common::{exec_err, Result, ScalarValue}; + +/// Helper function to get element [`DataType`] +/// from [`List`](DataType::List)/[`LargeList`](DataType::LargeList)/[`FixedSizeList`](DataType::FixedSizeList)
+/// [`Null`](DataType::Null) can be coerced to `ListType`([`Null`](DataType::Null)), so [`Null`](DataType::Null) is returned
+/// For all other types [`exec_err`] is raised +pub fn get_element_type(data_type: &DataType) -> Result<&DataType> { + match data_type { + DataType::Null => Ok(data_type), + DataType::List(element) + | DataType::LargeList(element) + | DataType::FixedSizeList(element, _) => Ok(element.data_type()), + _ => exec_err!( + "get_element_type expects List/LargeList/FixedSizeList/Null as argument, got {data_type:?}" + ), + } +} + +/// Helper function to get [`values`](arrow::array::ListArray::values) +/// from [`ListArray`](arrow::array::ListArray)/[`LargeListArray`](arrow::array::LargeListArray)/[`FixedSizeListArray`](arrow::array::FixedSizeListArray)
+/// [`NullArray`](arrow::array::NullArray) can be coerced to `ListType`([`Null`](DataType::Null)), so [`NullArray`](arrow::array::NullArray) is returned
+/// For all other types [`exec_err`] is raised +pub fn get_list_values(array: &ArrayRef) -> Result<&ArrayRef> { + match array.data_type() { + DataType::Null => Ok(array), + DataType::List(_) => Ok(array.as_list::().values()), + DataType::LargeList(_) => Ok(array.as_list::().values()), + DataType::FixedSizeList(..) => Ok(array.as_fixed_size_list().values()), + wrong_type => exec_err!( + "get_list_values expects List/LargeList/FixedSizeList/Null as argument, got {wrong_type:?}" + ), + } +} + +/// Helper function to get [`offsets`](arrow::array::ListArray::offsets) +/// from [`ListArray`](arrow::array::ListArray)/[`LargeListArray`](arrow::array::LargeListArray)/[`FixedSizeListArray`](arrow::array::FixedSizeListArray)
+/// For all other types [`exec_err`] is raised +pub fn get_list_offsets(array: &ArrayRef) -> Result> { + match array.data_type() { + DataType::List(_) => Ok(Cow::Borrowed(array.as_list::().offsets().as_ref())), + DataType::LargeList(_) => Ok(Cow::Owned( + array.as_list::() + .offsets() + .iter() + .map(|i| *i as i32) + .collect::>(), + )), + DataType::FixedSizeList(_, size) => Ok(Cow::Owned( + (0..=array.len() as i32).map(|i| size * i).collect() + )), + wrong_type => exec_err!( + "get_list_offsets expects List/LargeList/FixedSizeList as argument, got {wrong_type:?}" + ), + } +} + +/// Helper function to construct [`MapType`](DataType::Map) given K and V DataTypes for keys and values +/// - Map keys are unsorted +/// - Map keys are non-nullable +/// - Map entries are non-nullable +/// - Map values can be null +pub fn map_type_from_key_value_types( + key_type: &DataType, + value_type: &DataType, +) -> DataType { + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + // the key must not be nullable + Field::new("key", key_type.clone(), false), + Field::new("value", value_type.clone(), true), + ])), + false, // the entry is not nullable + )), + false, // the keys are not sorted + ) +} + +/// Helper function to construct MapArray from flattened ListArrays and OffsetBuffer +/// +/// Logic is close to `datafusion_functions_nested::map::make_map_array_internal`
+/// But there are some core differences: +/// 1. Input arrays are not [`ListArrays`](arrow::array::ListArray) itself, but their flattened [`values`](arrow::array::ListArray::values)
+/// So the inputs can be [`ListArray`](`arrow::array::ListArray`)/[`LargeListArray`](`arrow::array::LargeListArray`)/[`FixedSizeListArray`](`arrow::array::FixedSizeListArray`)
+/// To preserve the row info, [`offsets`](arrow::array::ListArray::offsets) and [`nulls`](arrow::array::ListArray::nulls) for both keys and values need to be provided
+/// [`FixedSizeListArray`](`arrow::array::FixedSizeListArray`) has no `offsets`, so they can be generated as a cumulative sum of it's `Size` +/// 2. Spark provides [spark.sql.mapKeyDedupPolicy](https://github.com/apache/spark/blob/cf3a34e19dfcf70e2d679217ff1ba21302212472/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4961) +/// to handle duplicate keys
+/// For now, configurable functions are not supported by Datafusion
+/// So more permissive `LAST_WIN` option is used in this implementation (instead of `EXCEPTION`)
+/// `EXCEPTION` behaviour can still be achieved externally in cost of performance:
+/// `when(array_length(array_distinct(keys)) == array_length(keys), constructed_map)`
+/// `.otherwise(raise_error("duplicate keys occurred during map construction"))` +pub fn map_from_keys_values_offsets_nulls( + flat_keys: &ArrayRef, + flat_values: &ArrayRef, + keys_offsets: &[i32], + values_offsets: &[i32], + keys_nulls: Option<&NullBuffer>, + values_nulls: Option<&NullBuffer>, +) -> Result { + let (keys, values, offsets) = map_deduplicate_keys( + flat_keys, + flat_values, + keys_offsets, + values_offsets, + keys_nulls, + values_nulls, + )?; + let nulls = NullBuffer::union(keys_nulls, values_nulls); + + let fields = Fields::from(vec![ + Field::new("key", flat_keys.data_type().clone(), false), + Field::new("value", flat_values.data_type().clone(), true), + ]); + let entries = StructArray::try_new(fields.clone(), vec![keys, values], None)?; + let field = Arc::new(Field::new("entries", DataType::Struct(fields), false)); + Ok(Arc::new(MapArray::try_new( + field, offsets, entries, nulls, false, + )?)) +} + +fn map_deduplicate_keys( + flat_keys: &ArrayRef, + flat_values: &ArrayRef, + keys_offsets: &[i32], + values_offsets: &[i32], + keys_nulls: Option<&NullBuffer>, + values_nulls: Option<&NullBuffer>, +) -> Result<(ArrayRef, ArrayRef, OffsetBuffer)> { + let offsets_len = keys_offsets.len(); + let mut new_offsets = Vec::with_capacity(offsets_len); + + let mut cur_keys_offset = keys_offsets + .first() + .map(|offset| *offset as usize) + .unwrap_or(0); + let mut cur_values_offset = values_offsets + .first() + .map(|offset| *offset as usize) + .unwrap_or(0); + + let mut new_last_offset = 0; + new_offsets.push(new_last_offset); + + let mut keys_mask_builder = BooleanBuilder::new(); + let mut values_mask_builder = BooleanBuilder::new(); + for (row_idx, (next_keys_offset, next_values_offset)) in keys_offsets + .iter() + .zip(values_offsets.iter()) + .skip(1) + .enumerate() + { + let num_keys_entries = *next_keys_offset as usize - cur_keys_offset; + let num_values_entries = *next_values_offset as usize - cur_values_offset; + + let mut keys_mask_one = [false].repeat(num_keys_entries); + let mut values_mask_one = [false].repeat(num_values_entries); + + let key_is_valid = keys_nulls.is_none_or(|buf| buf.is_valid(row_idx)); + let value_is_valid = values_nulls.is_none_or(|buf| buf.is_valid(row_idx)); + + if key_is_valid && value_is_valid { + if num_keys_entries != num_values_entries { + return exec_err!("map_deduplicate_keys: keys and values lists in the same row must have equal lengths"); + } else if num_keys_entries != 0 { + let mut seen_keys = HashSet::new(); + + for cur_entry_idx in (0..num_keys_entries).rev() { + let key = ScalarValue::try_from_array( + &flat_keys, + cur_keys_offset + cur_entry_idx, + )? + .compacted(); + if seen_keys.contains(&key) { + // TODO: implement configuration and logic for spark.sql.mapKeyDedupPolicy=EXCEPTION (this is default spark-config) + // exec_err!("invalid argument: duplicate keys in map") + // https://github.com/apache/spark/blob/cf3a34e19dfcf70e2d679217ff1ba21302212472/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4961 + } else { + // This code implements deduplication logic for spark.sql.mapKeyDedupPolicy=LAST_WIN (this is NOT default spark-config) + keys_mask_one[cur_entry_idx] = true; + values_mask_one[cur_entry_idx] = true; + seen_keys.insert(key); + new_last_offset += 1; + } + } + } + } else { + // the result entry is NULL + // both current row offsets are skipped + // keys or values in the current row are marked false in the masks + } + keys_mask_builder.append_array(&keys_mask_one.into()); + values_mask_builder.append_array(&values_mask_one.into()); + new_offsets.push(new_last_offset); + cur_keys_offset += num_keys_entries; + cur_values_offset += num_values_entries; + } + let keys_mask = keys_mask_builder.finish(); + let values_mask = values_mask_builder.finish(); + let needed_keys = filter(&flat_keys, &keys_mask)?; + let needed_values = filter(&flat_values, &values_mask)?; + let offsets = OffsetBuffer::new(new_offsets.into()); + Ok((needed_keys, needed_values, offsets)) +} diff --git a/datafusion/spark/src/function/math/expm1.rs b/datafusion/spark/src/function/math/expm1.rs new file mode 100644 index 0000000000000..42eccf3a2431a --- /dev/null +++ b/datafusion/spark/src/function/math/expm1.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::function::error_utils::{ + invalid_arg_count_exec_err, unsupported_data_type_exec_err, +}; +use arrow::array::{ArrayRef, AsArray}; +use arrow::datatypes::{DataType, Float64Type}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use std::any::Any; +use std::sync::Arc; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkExpm1 { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkExpm1 { + fn default() -> Self { + Self::new() + } +} + +impl SparkExpm1 { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkExpm1 { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "expm1" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if args.args.len() != 1 { + return Err(invalid_arg_count_exec_err("expm1", (1, 1), args.args.len())); + } + match &args.args[0] { + ColumnarValue::Scalar(ScalarValue::Float64(value)) => Ok( + ColumnarValue::Scalar(ScalarValue::Float64(value.map(|x| x.exp_m1()))), + ), + ColumnarValue::Array(array) => match array.data_type() { + DataType::Float64 => Ok(ColumnarValue::Array(Arc::new( + array + .as_primitive::() + .unary::<_, Float64Type>(|x| x.exp_m1()), + ) + as ArrayRef)), + other => Err(unsupported_data_type_exec_err( + "expm1", + format!("{}", DataType::Float64).as_str(), + other, + )), + }, + other => Err(unsupported_data_type_exec_err( + "expm1", + format!("{}", DataType::Float64).as_str(), + &other.data_type(), + )), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return Err(invalid_arg_count_exec_err("expm1", (1, 1), arg_types.len())); + } + if arg_types[0].is_numeric() { + Ok(vec![DataType::Float64]) + } else { + Err(unsupported_data_type_exec_err( + "expm1", + "Numeric Type", + &arg_types[0], + )) + } + } +} + +#[cfg(test)] +mod tests { + use crate::function::math::expm1::SparkExpm1; + use crate::function::utils::test::test_scalar_function; + use arrow::array::{Array, Float64Array}; + use arrow::datatypes::DataType::Float64; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + macro_rules! test_expm1_float64_invoke { + ($INPUT:expr, $EXPECTED:expr) => { + test_scalar_function!( + SparkExpm1::new(), + vec![ColumnarValue::Scalar(ScalarValue::Float64($INPUT))], + $EXPECTED, + f64, + Float64, + Float64Array + ); + }; + } + + #[test] + fn test_expm1_invoke() -> Result<()> { + test_expm1_float64_invoke!(Some(0f64), Ok(Some(0.0f64))); + Ok(()) + } +} diff --git a/datafusion/spark/src/function/math/factorial.rs b/datafusion/spark/src/function/math/factorial.rs new file mode 100644 index 0000000000000..4921e73d262a3 --- /dev/null +++ b/datafusion/spark/src/function/math/factorial.rs @@ -0,0 +1,194 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, Int64Array}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{Int32, Int64}; +use datafusion_common::cast::as_int32_array; +use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::Signature; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkFactorial { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkFactorial { + fn default() -> Self { + Self::new() + } +} + +impl SparkFactorial { + pub fn new() -> Self { + Self { + signature: Signature::exact(vec![Int32], Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkFactorial { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "factorial" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_factorial(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +const FACTORIALS: [i64; 21] = [ + 1, + 1, + 2, + 6, + 24, + 120, + 720, + 5040, + 40320, + 362880, + 3628800, + 39916800, + 479001600, + 6227020800, + 87178291200, + 1307674368000, + 20922789888000, + 355687428096000, + 6402373705728000, + 121645100408832000, + 2432902008176640000, +]; + +pub fn spark_factorial(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return internal_err!("`factorial` expects exactly one argument"); + } + + match &args[0] { + ColumnarValue::Scalar(ScalarValue::Int32(value)) => { + let result = compute_factorial(*value); + Ok(ColumnarValue::Scalar(ScalarValue::Int64(result))) + } + ColumnarValue::Scalar(other) => { + exec_err!("`factorial` got an unexpected scalar type: {}", other) + } + ColumnarValue::Array(array) => match array.data_type() { + Int32 => { + let array = as_int32_array(array)?; + + let result: Int64Array = array.iter().map(compute_factorial).collect(); + + Ok(ColumnarValue::Array(Arc::new(result))) + } + other => { + exec_err!("`factorial` got an unexpected argument type: {}", other) + } + }, + } +} + +#[inline] +fn compute_factorial(num: Option) -> Option { + num.filter(|&v| (0..=20).contains(&v)) + .map(|v| FACTORIALS[v as usize]) +} + +#[cfg(test)] +mod test { + use crate::function::math::factorial::spark_factorial; + use arrow::array::{Int32Array, Int64Array}; + use datafusion_common::cast::as_int64_array; + use datafusion_common::ScalarValue; + use datafusion_expr::ColumnarValue; + use std::sync::Arc; + + #[test] + fn test_spark_factorial_array() { + let input = Int32Array::from(vec![ + Some(-1), + Some(0), + Some(1), + Some(2), + Some(4), + Some(20), + Some(21), + None, + ]); + + let args = ColumnarValue::Array(Arc::new(input)); + let result = spark_factorial(&[args]).unwrap(); + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let actual = as_int64_array(&result).unwrap(); + let expected = Int64Array::from(vec![ + None, + Some(1), + Some(1), + Some(2), + Some(24), + Some(2432902008176640000), + None, + None, + ]); + + assert_eq!(actual, &expected); + } + + #[test] + fn test_spark_factorial_scalar() { + let input = ScalarValue::Int32(Some(5)); + + let args = ColumnarValue::Scalar(input); + let result = spark_factorial(&[args]).unwrap(); + let result = match result { + ColumnarValue::Scalar(ScalarValue::Int64(val)) => val, + _ => panic!("Expected scalar"), + }; + let actual = result.unwrap(); + let expected = 120_i64; + + assert_eq!(actual, expected); + } +} diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs new file mode 100644 index 0000000000000..cdd13e9033265 --- /dev/null +++ b/datafusion/spark/src/function/math/hex.rs @@ -0,0 +1,427 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use crate::function::error_utils::{ + invalid_arg_count_exec_err, unsupported_data_type_exec_err, +}; +use arrow::array::{Array, StringArray}; +use arrow::datatypes::DataType; +use arrow::{ + array::{as_dictionary_array, as_largestring_array, as_string_array}, + datatypes::Int32Type, +}; +use datafusion_common::cast::as_string_view_array; +use datafusion_common::{ + cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, + exec_err, internal_err, DataFusionError, +}; +use datafusion_expr::Signature; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; +use std::fmt::Write; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkHex { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkHex { + fn default() -> Self { + Self::new() + } +} + +impl SparkHex { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkHex { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "hex" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type( + &self, + _arg_types: &[DataType], + ) -> datafusion_common::Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + spark_hex(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types( + &self, + arg_types: &[DataType], + ) -> datafusion_common::Result> { + if arg_types.len() != 1 { + return Err(invalid_arg_count_exec_err("hex", (1, 1), arg_types.len())); + } + match &arg_types[0] { + DataType::Int64 + | DataType::Utf8 + | DataType::Utf8View + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]), + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Int64 + | DataType::Utf8 + | DataType::Utf8View + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]), + other => { + if other.is_numeric() { + Ok(vec![DataType::Dictionary( + key_type.clone(), + Box::new(DataType::Int64), + )]) + } else { + Err(unsupported_data_type_exec_err( + "hex", + "Numeric, String, or Binary", + &arg_types[0], + )) + } + } + }, + other => { + if other.is_numeric() { + Ok(vec![DataType::Int64]) + } else { + Err(unsupported_data_type_exec_err( + "hex", + "Numeric, String, or Binary", + &arg_types[0], + )) + } + } + } + } +} + +fn hex_int64(num: i64) -> String { + format!("{num:X}") +} + +#[inline(always)] +fn hex_encode>(data: T, lower_case: bool) -> String { + let mut s = String::with_capacity(data.as_ref().len() * 2); + if lower_case { + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. + write!(&mut s, "{b:02x}").unwrap(); + } + } else { + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. + write!(&mut s, "{b:02X}").unwrap(); + } + } + s +} + +#[inline(always)] +fn hex_bytes>( + bytes: T, + lowercase: bool, +) -> Result { + let hex_string = hex_encode(bytes, lowercase); + Ok(hex_string) +} + +/// Spark-compatible `hex` function +pub fn spark_hex(args: &[ColumnarValue]) -> Result { + compute_hex(args, false) +} + +/// Spark-compatible `sha2` function +pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result { + compute_hex(args, true) +} + +pub fn compute_hex( + args: &[ColumnarValue], + lowercase: bool, +) -> Result { + if args.len() != 1 { + return internal_err!("hex expects exactly one argument"); + } + + let input = match &args[0] { + ColumnarValue::Scalar(value) => ColumnarValue::Array(value.to_array()?), + ColumnarValue::Array(_) => args[0].clone(), + }; + + match &input { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Int64 => { + let array = as_int64_array(array)?; + + let hexed_array: StringArray = + array.iter().map(|v| v.map(hex_int64)).collect(); + + Ok(ColumnarValue::Array(Arc::new(hexed_array))) + } + DataType::Utf8 => { + let array = as_string_array(array); + + let hexed: StringArray = array + .iter() + .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::Utf8View => { + let array = as_string_view_array(array)?; + + let hexed: StringArray = array + .iter() + .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::LargeUtf8 => { + let array = as_largestring_array(array); + + let hexed: StringArray = array + .iter() + .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::Binary => { + let array = as_binary_array(array)?; + + let hexed: StringArray = array + .iter() + .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::FixedSizeBinary(_) => { + let array = as_fixed_size_binary_array(array)?; + + let hexed: StringArray = array + .iter() + .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::Dictionary(_, value_type) => { + let dict = as_dictionary_array::(&array); + + let values = match **value_type { + DataType::Int64 => as_int64_array(dict.values())? + .iter() + .map(|v| v.map(hex_int64)) + .collect::>(), + DataType::Utf8 => as_string_array(dict.values()) + .iter() + .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) + .collect::>()?, + DataType::Binary => as_binary_array(dict.values())? + .iter() + .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) + .collect::>()?, + _ => exec_err!( + "hex got an unexpected argument type: {}", + array.data_type() + )?, + }; + + let new_values: Vec> = dict + .keys() + .iter() + .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None)) + .collect(); + + let string_array_values = StringArray::from(new_values); + + Ok(ColumnarValue::Array(Arc::new(string_array_values))) + } + _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()), + }, + _ => exec_err!("native hex does not support scalar values at this time"), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{Int64Array, StringArray}; + use arrow::{ + array::{ + as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, + StringBuilder, StringDictionaryBuilder, + }, + datatypes::{Int32Type, Int64Type}, + }; + use datafusion_expr::ColumnarValue; + + #[test] + fn test_dictionary_hex_utf8() { + let mut input_builder = StringDictionaryBuilder::::new(); + input_builder.append_value("hi"); + input_builder.append_value("bye"); + input_builder.append_null(); + input_builder.append_value("rust"); + let input = input_builder.finish(); + + let mut string_builder = StringBuilder::new(); + string_builder.append_value("6869"); + string_builder.append_value("627965"); + string_builder.append_null(); + string_builder.append_value("72757374"); + let expected = string_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_dictionary_hex_int64() { + let mut input_builder = PrimitiveDictionaryBuilder::::new(); + input_builder.append_value(1); + input_builder.append_value(2); + input_builder.append_null(); + input_builder.append_value(3); + let input = input_builder.finish(); + + let mut string_builder = StringBuilder::new(); + string_builder.append_value("1"); + string_builder.append_value("2"); + string_builder.append_null(); + string_builder.append_value("3"); + let expected = string_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_dictionary_hex_binary() { + let mut input_builder = BinaryDictionaryBuilder::::new(); + input_builder.append_value("1"); + input_builder.append_value("j"); + input_builder.append_null(); + input_builder.append_value("3"); + let input = input_builder.finish(); + + let mut expected_builder = StringBuilder::new(); + expected_builder.append_value("31"); + expected_builder.append_value("6A"); + expected_builder.append_null(); + expected_builder.append_value("33"); + let expected = expected_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_hex_int64() { + let num = 1234; + let hexed = super::hex_int64(num); + assert_eq!(hexed, "4D2".to_string()); + + let num = -1; + let hexed = super::hex_int64(num); + assert_eq!(hexed, "FFFFFFFFFFFFFFFF".to_string()); + } + + #[test] + fn test_spark_hex_int64() { + let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]); + let columnar_value = ColumnarValue::Array(Arc::new(int_array)); + + let result = super::spark_hex(&[columnar_value]).unwrap(); + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let string_array = as_string_array(&result); + let expected_array = StringArray::from(vec![ + Some("1".to_string()), + Some("2".to_string()), + None, + Some("3".to_string()), + ]); + + assert_eq!(string_array, &expected_array); + } +} diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs new file mode 100644 index 0000000000000..092335e4aa18d --- /dev/null +++ b/datafusion/spark/src/function/math/mod.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod expm1; +pub mod factorial; +pub mod hex; +pub mod modulus; +pub mod rint; +pub mod width_bucket; + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +make_udf_function!(expm1::SparkExpm1, expm1); +make_udf_function!(factorial::SparkFactorial, factorial); +make_udf_function!(hex::SparkHex, hex); +make_udf_function!(modulus::SparkMod, modulus); +make_udf_function!(modulus::SparkPmod, pmod); +make_udf_function!(rint::SparkRint, rint); +make_udf_function!(width_bucket::SparkWidthBucket, width_bucket); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!((expm1, "Returns exp(expr) - 1 as a Float64.", arg1)); + export_functions!(( + factorial, + "Returns the factorial of expr. expr is [0..20]. Otherwise, null.", + arg1 + )); + export_functions!((hex, "Computes hex value of the given column.", arg1)); + export_functions!((modulus, "Returns the remainder of division of the first argument by the second argument.", arg1 arg2)); + export_functions!((pmod, "Returns the positive remainder of division of the first argument by the second argument.", arg1 arg2)); + export_functions!((rint, "Returns the double value that is closest in value to the argument and is equal to a mathematical integer.", arg1)); + export_functions!((width_bucket, "Returns the bucket number into which the value of this expression would fall after being evaluated.", arg1 arg2 arg3 arg4)); +} + +pub fn functions() -> Vec> { + vec![ + expm1(), + factorial(), + hex(), + modulus(), + pmod(), + rint(), + width_bucket(), + ] +} diff --git a/datafusion/spark/src/function/math/modulus.rs b/datafusion/spark/src/function/math/modulus.rs new file mode 100644 index 0000000000000..fea0297a7ae94 --- /dev/null +++ b/datafusion/spark/src/function/math/modulus.rs @@ -0,0 +1,609 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::compute::kernels::numeric::add; +use arrow::compute::kernels::{cmp::lt, numeric::rem, zip::zip}; +use arrow::datatypes::DataType; +use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use std::any::Any; + +/// Spark-compatible `mod` function +/// This function directly uses Arrow's arithmetic_op function for modulo operations +pub fn spark_mod(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + return internal_err!("mod expects exactly two arguments"); + } + let args = ColumnarValue::values_to_arrays(args)?; + let result = rem(&args[0], &args[1])?; + Ok(ColumnarValue::Array(result)) +} + +/// Spark-compatible `pmod` function +/// This function directly uses Arrow's arithmetic_op function for modulo operations +pub fn spark_pmod(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + return internal_err!("pmod expects exactly two arguments"); + } + let args = ColumnarValue::values_to_arrays(args)?; + let left = &args[0]; + let right = &args[1]; + let zero = ScalarValue::new_zero(left.data_type())?.to_array_of_size(left.len())?; + let result = rem(left, right)?; + let neg = lt(&result, &zero)?; + let plus = zip(&neg, right, &zero)?; + let result = add(&plus, &result)?; + let result = rem(&result, right)?; + Ok(ColumnarValue::Array(result)) +} + +/// SparkMod implements the Spark-compatible modulo function +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkMod { + signature: Signature, +} + +impl Default for SparkMod { + fn default() -> Self { + Self::new() + } +} + +impl SparkMod { + pub fn new() -> Self { + Self { + signature: Signature::numeric(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkMod { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "mod" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 2 { + return internal_err!("mod expects exactly two arguments"); + } + + // Return the same type as the first argument for simplicity + // Arrow's rem function handles type promotion internally + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_mod(&args.args) + } +} + +/// SparkMod implements the Spark-compatible modulo function +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkPmod { + signature: Signature, +} + +impl Default for SparkPmod { + fn default() -> Self { + Self::new() + } +} + +impl SparkPmod { + pub fn new() -> Self { + Self { + signature: Signature::numeric(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkPmod { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "pmod" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 2 { + return internal_err!("pmod expects exactly two arguments"); + } + + // Return the same type as the first argument for simplicity + // Arrow's rem function handles type promotion internally + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_pmod(&args.args) + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use super::*; + use arrow::array::*; + use datafusion_common::ScalarValue; + + #[test] + fn test_mod_int32() { + let left = Int32Array::from(vec![Some(10), Some(7), Some(15), None]); + let right = Int32Array::from(vec![Some(3), Some(2), Some(4), Some(5)]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_mod(&[left_value, right_value]).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_int32 = + result_array.as_any().downcast_ref::().unwrap(); + assert_eq!(result_int32.value(0), 1); // 10 % 3 = 1 + assert_eq!(result_int32.value(1), 1); // 7 % 2 = 1 + assert_eq!(result_int32.value(2), 3); // 15 % 4 = 3 + assert!(result_int32.is_null(3)); // None % 5 = None + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_mod_int64() { + let left = Int64Array::from(vec![Some(100), Some(50), Some(200)]); + let right = Int64Array::from(vec![Some(30), Some(25), Some(60)]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_mod(&[left_value, right_value]).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_int64 = + result_array.as_any().downcast_ref::().unwrap(); + assert_eq!(result_int64.value(0), 10); // 100 % 30 = 10 + assert_eq!(result_int64.value(1), 0); // 50 % 25 = 0 + assert_eq!(result_int64.value(2), 20); // 200 % 60 = 20 + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_mod_float64() { + let left = Float64Array::from(vec![ + Some(10.5), + Some(7.2), + Some(15.8), + Some(f64::NAN), + Some(f64::INFINITY), + Some(5.0), + Some(5.0), + Some(f64::NAN), + Some(f64::INFINITY), + ]); + let right = Float64Array::from(vec![ + Some(3.0), + Some(2.5), + Some(4.2), + Some(2.0), + Some(2.0), + Some(f64::NAN), + Some(f64::INFINITY), + Some(f64::INFINITY), + Some(f64::NAN), + ]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_mod(&[left_value, right_value]).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_float64 = result_array + .as_any() + .downcast_ref::() + .unwrap(); + // Regular cases + assert!((result_float64.value(0) - 1.5).abs() < f64::EPSILON); // 10.5 % 3.0 = 1.5 + assert!((result_float64.value(1) - 2.2).abs() < f64::EPSILON); // 7.2 % 2.5 = 2.2 + assert!((result_float64.value(2) - 3.2).abs() < f64::EPSILON); // 15.8 % 4.2 = 3.2 + // nan % 2.0 = nan + assert!(result_float64.value(3).is_nan()); + // inf % 2.0 = nan (IEEE 754) + assert!(result_float64.value(4).is_nan()); + // 5.0 % nan = nan + assert!(result_float64.value(5).is_nan()); + // 5.0 % inf = 5.0 + assert!((result_float64.value(6) - 5.0).abs() < f64::EPSILON); + // nan % inf = nan + assert!(result_float64.value(7).is_nan()); + // inf % nan = nan + assert!(result_float64.value(8).is_nan()); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_mod_float32() { + let left = Float32Array::from(vec![ + Some(10.5), + Some(7.2), + Some(15.8), + Some(f32::NAN), + Some(f32::INFINITY), + Some(5.0), + Some(5.0), + Some(f32::NAN), + Some(f32::INFINITY), + ]); + let right = Float32Array::from(vec![ + Some(3.0), + Some(2.5), + Some(4.2), + Some(2.0), + Some(2.0), + Some(f32::NAN), + Some(f32::INFINITY), + Some(f32::INFINITY), + Some(f32::NAN), + ]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_mod(&[left_value, right_value]).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_float32 = result_array + .as_any() + .downcast_ref::() + .unwrap(); + // Regular cases + assert!((result_float32.value(0) - 1.5).abs() < f32::EPSILON); // 10.5 % 3.0 = 1.5 + assert!((result_float32.value(1) - 2.2).abs() < f32::EPSILON * 3.0); // 7.2 % 2.5 = 2.2 + assert!((result_float32.value(2) - 3.2).abs() < f32::EPSILON * 10.0); // 15.8 % 4.2 = 3.2 + // nan % 2.0 = nan + assert!(result_float32.value(3).is_nan()); + // inf % 2.0 = nan (IEEE 754) + assert!(result_float32.value(4).is_nan()); + // 5.0 % nan = nan + assert!(result_float32.value(5).is_nan()); + // 5.0 % inf = 5.0 + assert!((result_float32.value(6) - 5.0).abs() < f32::EPSILON); + // nan % inf = nan + assert!(result_float32.value(7).is_nan()); + // inf % nan = nan + assert!(result_float32.value(8).is_nan()); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_mod_scalar() { + let left = Int32Array::from(vec![Some(10), Some(7), Some(15)]); + let right_value = ColumnarValue::Scalar(ScalarValue::Int32(Some(3))); + + let left_value = ColumnarValue::Array(Arc::new(left)); + + let result = spark_mod(&[left_value, right_value]).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_int32 = + result_array.as_any().downcast_ref::().unwrap(); + assert_eq!(result_int32.value(0), 1); // 10 % 3 = 1 + assert_eq!(result_int32.value(1), 1); // 7 % 3 = 1 + assert_eq!(result_int32.value(2), 0); // 15 % 3 = 0 + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_mod_wrong_arg_count() { + let left = Int32Array::from(vec![Some(10)]); + let left_value = ColumnarValue::Array(Arc::new(left)); + + let result = spark_mod(&[left_value]); + assert!(result.is_err()); + } + + #[test] + fn test_mod_zero_division() { + let left = Int32Array::from(vec![Some(10), Some(7), Some(15)]); + let right = Int32Array::from(vec![Some(0), Some(2), Some(4)]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_mod(&[left_value, right_value]); + assert!(result.is_err()); // Division by zero should error + } + + // PMOD tests + #[test] + fn test_pmod_int32() { + let left = Int32Array::from(vec![Some(10), Some(-7), Some(15), Some(-15), None]); + let right = Int32Array::from(vec![Some(3), Some(3), Some(4), Some(4), Some(5)]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_pmod(&[left_value, right_value]).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_int32 = + result_array.as_any().downcast_ref::().unwrap(); + assert_eq!(result_int32.value(0), 1); // 10 pmod 3 = 1 + assert_eq!(result_int32.value(1), 2); // -7 pmod 3 = 2 (positive remainder) + assert_eq!(result_int32.value(2), 3); // 15 pmod 4 = 3 + assert_eq!(result_int32.value(3), 1); // -15 pmod 4 = 1 (positive remainder) + assert!(result_int32.is_null(4)); // None pmod 5 = None + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_pmod_int64() { + let left = Int64Array::from(vec![Some(100), Some(-50), Some(200), Some(-200)]); + let right = Int64Array::from(vec![Some(30), Some(30), Some(60), Some(60)]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_pmod(&[left_value, right_value]).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_int64 = + result_array.as_any().downcast_ref::().unwrap(); + assert_eq!(result_int64.value(0), 10); // 100 pmod 30 = 10 + assert_eq!(result_int64.value(1), 10); // -50 pmod 30 = 10 (positive remainder) + assert_eq!(result_int64.value(2), 20); // 200 pmod 60 = 20 + assert_eq!(result_int64.value(3), 40); // -200 pmod 60 = 40 (positive remainder) + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_pmod_float64() { + let left = Float64Array::from(vec![ + Some(10.5), + Some(-7.2), + Some(15.8), + Some(-15.8), + Some(f64::NAN), + Some(f64::INFINITY), + Some(5.0), + Some(-5.0), + ]); + let right = Float64Array::from(vec![ + Some(3.0), + Some(3.0), + Some(4.2), + Some(4.2), + Some(2.0), + Some(2.0), + Some(f64::INFINITY), + Some(f64::INFINITY), + ]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_pmod(&[left_value, right_value]).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_float64 = result_array + .as_any() + .downcast_ref::() + .unwrap(); + // Regular cases + assert!((result_float64.value(0) - 1.5).abs() < f64::EPSILON); // 10.5 pmod 3.0 = 1.5 + assert!((result_float64.value(1) - 1.8).abs() < f64::EPSILON * 3.0); // -7.2 pmod 3.0 = 1.8 (positive) + assert!((result_float64.value(2) - 3.2).abs() < f64::EPSILON * 3.0); // 15.8 pmod 4.2 = 3.2 + assert!((result_float64.value(3) - 1.0).abs() < f64::EPSILON * 3.0); // -15.8 pmod 4.2 = 1.0 (positive) + // nan pmod 2.0 = nan + assert!(result_float64.value(4).is_nan()); + // inf pmod 2.0 = nan (IEEE 754) + assert!(result_float64.value(5).is_nan()); + // 5.0 pmod inf = 5.0 + assert!((result_float64.value(6) - 5.0).abs() < f64::EPSILON); + // -5.0 pmod inf = NaN + assert!(result_float64.value(7).is_nan()); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_pmod_float32() { + let left = Float32Array::from(vec![ + Some(10.5), + Some(-7.2), + Some(15.8), + Some(-15.8), + Some(f32::NAN), + Some(f32::INFINITY), + Some(5.0), + Some(-5.0), + ]); + let right = Float32Array::from(vec![ + Some(3.0), + Some(3.0), + Some(4.2), + Some(4.2), + Some(2.0), + Some(2.0), + Some(f32::INFINITY), + Some(f32::INFINITY), + ]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_pmod(&[left_value, right_value]).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_float32 = result_array + .as_any() + .downcast_ref::() + .unwrap(); + // Regular cases + assert!((result_float32.value(0) - 1.5).abs() < f32::EPSILON); // 10.5 pmod 3.0 = 1.5 + assert!((result_float32.value(1) - 1.8).abs() < f32::EPSILON * 3.0); // -7.2 pmod 3.0 = 1.8 (positive) + assert!((result_float32.value(2) - 3.2).abs() < f32::EPSILON * 10.0); // 15.8 pmod 4.2 = 3.2 + assert!((result_float32.value(3) - 1.0).abs() < f32::EPSILON * 10.0); // -15.8 pmod 4.2 = 1.0 (positive) + // nan pmod 2.0 = nan + assert!(result_float32.value(4).is_nan()); + // inf pmod 2.0 = nan (IEEE 754) + assert!(result_float32.value(5).is_nan()); + // 5.0 pmod inf = 5.0 + assert!((result_float32.value(6) - 5.0).abs() < f32::EPSILON * 10.0); + // -5.0 pmod inf = NaN + assert!(result_float32.value(7).is_nan()); + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_pmod_scalar() { + let left = Int32Array::from(vec![Some(10), Some(-7), Some(15), Some(-15)]); + let right_value = ColumnarValue::Scalar(ScalarValue::Int32(Some(3))); + + let left_value = ColumnarValue::Array(Arc::new(left)); + + let result = spark_pmod(&[left_value, right_value]).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_int32 = + result_array.as_any().downcast_ref::().unwrap(); + assert_eq!(result_int32.value(0), 1); // 10 pmod 3 = 1 + assert_eq!(result_int32.value(1), 2); // -7 pmod 3 = 2 (positive remainder) + assert_eq!(result_int32.value(2), 0); // 15 pmod 3 = 0 + assert_eq!(result_int32.value(3), 0); // -15 pmod 3 = 0 (positive remainder) + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_pmod_wrong_arg_count() { + let left = Int32Array::from(vec![Some(10)]); + let left_value = ColumnarValue::Array(Arc::new(left)); + + let result = spark_pmod(&[left_value]); + assert!(result.is_err()); + } + + #[test] + fn test_pmod_zero_division() { + let left = Int32Array::from(vec![Some(10), Some(-7), Some(15)]); + let right = Int32Array::from(vec![Some(0), Some(0), Some(4)]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_pmod(&[left_value, right_value]); + assert!(result.is_err()); // Division by zero should error + } + + #[test] + fn test_pmod_negative_divisor() { + // PMOD with negative divisor should still work like regular mod + let left = Int32Array::from(vec![Some(10), Some(-7), Some(15)]); + let right = Int32Array::from(vec![Some(-3), Some(-3), Some(-4)]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_pmod(&[left_value, right_value]).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_int32 = + result_array.as_any().downcast_ref::().unwrap(); + assert_eq!(result_int32.value(0), 1); // 10 pmod -3 = 1 + assert_eq!(result_int32.value(1), -1); // -7 pmod -3 = -1 + assert_eq!(result_int32.value(2), 3); // 15 pmod -4 = 3 + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_pmod_edge_cases() { + // Test edge cases for PMOD + let left = Int32Array::from(vec![ + Some(0), // 0 pmod 5 = 0 + Some(-1), // -1 pmod 5 = 4 + Some(1), // 1 pmod 5 = 1 + Some(-5), // -5 pmod 5 = 0 + Some(5), // 5 pmod 5 = 0 + Some(-6), // -6 pmod 5 = 4 + Some(6), // 6 pmod 5 = 1 + ]); + let right = Int32Array::from(vec![ + Some(5), + Some(5), + Some(5), + Some(5), + Some(5), + Some(5), + Some(5), + ]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_pmod(&[left_value, right_value]).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_int32 = + result_array.as_any().downcast_ref::().unwrap(); + assert_eq!(result_int32.value(0), 0); // 0 pmod 5 = 0 + assert_eq!(result_int32.value(1), 4); // -1 pmod 5 = 4 + assert_eq!(result_int32.value(2), 1); // 1 pmod 5 = 1 + assert_eq!(result_int32.value(3), 0); // -5 pmod 5 = 0 + assert_eq!(result_int32.value(4), 0); // 5 pmod 5 = 0 + assert_eq!(result_int32.value(5), 4); // -6 pmod 5 = 4 + assert_eq!(result_int32.value(6), 1); // 6 pmod 5 = 1 + } else { + panic!("Expected array result"); + } + } +} diff --git a/datafusion/spark/src/function/math/rint.rs b/datafusion/spark/src/function/math/rint.rs new file mode 100644 index 0000000000000..9b61529c5bc44 --- /dev/null +++ b/datafusion/spark/src/function/math/rint.rs @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, AsArray}; +use arrow::compute::cast; +use arrow::datatypes::DataType::{ + Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, +}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkRint { + signature: Signature, +} + +impl Default for SparkRint { + fn default() -> Self { + Self::new() + } +} + +impl SparkRint { + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkRint { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "rint" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_rint, vec![])(&args.args) + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + // round preserves the order of the first argument + if input.len() == 1 { + let value = &input[0]; + Ok(value.sort_properties) + } else { + Ok(SortProperties::default()) + } + } +} + +pub fn spark_rint(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("rint expects exactly 1 argument, got {}", args.len()); + } + + let array: &dyn Array = args[0].as_ref(); + match args[0].data_type() { + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => { + Ok(cast(array, &Float64)?) + } + Float64 => { + let array = array + .as_primitive::() + .unary::<_, Float64Type>(|value: f64| value.round_ties_even()); + Ok(Arc::new(array)) + } + Float32 => { + let array = array + .as_primitive::() + .unary::<_, Float64Type>(|value: f32| value.round_ties_even() as f64); + Ok(Arc::new(array)) + } + _ => { + exec_err!( + "rint expects a numeric argument, got {}", + args[0].data_type() + ) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Float64Array; + + #[test] + fn test_rint_positive_decimals() { + // Test positive decimal rounding + let result = spark_rint(&[Arc::new(Float64Array::from(vec![12.3456]))]).unwrap(); + assert_eq!(result.as_ref(), &Float64Array::from(vec![12.0])); + + // Test rounding to nearest even (banker's rounding) + let result = spark_rint(&[Arc::new(Float64Array::from(vec![2.5]))]).unwrap(); + assert_eq!(result.as_ref(), &Float64Array::from(vec![2.0])); + + let result = spark_rint(&[Arc::new(Float64Array::from(vec![3.5]))]).unwrap(); + assert_eq!(result.as_ref(), &Float64Array::from(vec![4.0])); + } + + #[test] + fn test_rint_negative_decimals() { + // Test negative decimal rounding + let result = spark_rint(&[Arc::new(Float64Array::from(vec![-12.3456]))]).unwrap(); + assert_eq!(result.as_ref(), &Float64Array::from(vec![-12.0])); + + // Test negative rounding to nearest even + let result = spark_rint(&[Arc::new(Float64Array::from(vec![-2.5]))]).unwrap(); + assert_eq!(result.as_ref(), &Float64Array::from(vec![-2.0])); + } + + #[test] + fn test_rint_integers() { + // Test integer input (should return as float64) + let result = spark_rint(&[Arc::new(Float64Array::from(vec![42.0]))]).unwrap(); + assert_eq!(result.as_ref(), &Float64Array::from(vec![42.0])); + } + + #[test] + fn test_rint_null() { + let result = spark_rint(&[Arc::new(Float64Array::from(vec![None]))]).unwrap(); + assert_eq!(result.as_ref(), &Float64Array::from(vec![None])); + } + + #[test] + fn test_rint_zero() { + // Test zero + let result = spark_rint(&[Arc::new(Float64Array::from(vec![0.0]))]).unwrap(); + assert_eq!(result.as_ref(), &Float64Array::from(vec![0.0])); + } +} diff --git a/datafusion/spark/src/function/math/width_bucket.rs b/datafusion/spark/src/function/math/width_bucket.rs new file mode 100644 index 0000000000000..45a0d843b7ed7 --- /dev/null +++ b/datafusion/spark/src/function/math/width_bucket.rs @@ -0,0 +1,788 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use crate::function::error_utils::unsupported_data_types_exec_err; +use arrow::array::{ + Array, ArrayRef, DurationMicrosecondArray, Float64Array, IntervalMonthDayNanoArray, + IntervalYearMonthArray, +}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{Duration, Float64, Int32, Interval}; +use arrow::datatypes::IntervalUnit::{MonthDayNano, YearMonth}; +use datafusion_common::cast::{ + as_duration_microsecond_array, as_float64_array, as_int32_array, + as_interval_mdn_array, as_interval_ym_array, +}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::type_coercion::is_signed_numeric; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature}; +use datafusion_functions::utils::make_scalar_function; + +use arrow::array::{Int32Array, Int32Builder}; +use arrow::datatypes::TimeUnit::Microsecond; +use datafusion_expr::Volatility::Immutable; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkWidthBucket { + signature: Signature, +} + +impl Default for SparkWidthBucket { + fn default() -> Self { + Self::new() + } +} + +impl SparkWidthBucket { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Immutable), + } + } +} + +impl ScalarUDFImpl for SparkWidthBucket { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "width_bucket" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int32) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(width_bucket_kern, vec![])(&args.args) + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + if input.len() == 1 { + let value = &input[0]; + Ok(value.sort_properties) + } else { + Ok(SortProperties::default()) + } + } + + fn coerce_types(&self, types: &[DataType]) -> Result> { + use DataType::*; + + let (v, lo, hi, n) = (&types[0], &types[1], &types[2], &types[3]); + + match (v, lo, hi, n) { + (a, b, c, &(Int8 | Int16 | Int32 | Int64)) + if is_signed_numeric(a) + && is_signed_numeric(b) + && is_signed_numeric(c) => + { + Ok(vec![Float64, Float64, Float64, Int32]) + } + ( + &Duration(_), + &Duration(_), + &Duration(_), + &(Int8 | Int16 | Int32 | Int64), + ) => Ok(vec![ + Duration(Microsecond), + Duration(Microsecond), + Duration(Microsecond), + Int32, + ]), + ( + &Interval(MonthDayNano), + &Interval(MonthDayNano), + &Interval(MonthDayNano), + &(Int8 | Int16 | Int32 | Int64), + ) => Ok(vec![ + Interval(MonthDayNano), + Interval(MonthDayNano), + Interval(MonthDayNano), + Int32, + ]), + ( + &Interval(YearMonth), + &Interval(YearMonth), + &Interval(YearMonth), + &(Int8 | Int16 | Int32 | Int64), + ) => Ok(vec![ + Interval(YearMonth), + Interval(YearMonth), + Interval(YearMonth), + Int32, + ]), + + _ => exec_err!( + "width_bucket expects a numeric argument, got {} {} {} {}", + types[0], + types[1], + types[2], + types[3] + ), + } + } +} + +fn width_bucket_kern(args: &[ArrayRef]) -> Result { + let [v, minv, maxv, nb] = args else { + return exec_err!( + "width_bucket expects exactly 4 argument, got {}", + args.len() + ); + }; + + match v.data_type() { + Float64 => { + let v = as_float64_array(v)?; + let min = as_float64_array(minv)?; + let max = as_float64_array(maxv)?; + let n_bucket = as_int32_array(nb)?; + Ok(Arc::new(width_bucket_float64(v, min, max, n_bucket))) + } + Duration(Microsecond) => { + let v = as_duration_microsecond_array(v)?; + let min = as_duration_microsecond_array(minv)?; + let max = as_duration_microsecond_array(maxv)?; + let n_bucket = as_int32_array(nb)?; + Ok(Arc::new(width_bucket_i64_as_float(v, min, max, n_bucket))) + } + Interval(YearMonth) => { + let v = as_interval_ym_array(v)?; + let min = as_interval_ym_array(minv)?; + let max = as_interval_ym_array(maxv)?; + let n_bucket = as_int32_array(nb)?; + Ok(Arc::new(width_bucket_i32_as_float(v, min, max, n_bucket))) + } + Interval(MonthDayNano) => { + let v = as_interval_mdn_array(v)?; + let min = as_interval_mdn_array(minv)?; + let max = as_interval_mdn_array(maxv)?; + let n_bucket = as_int32_array(nb)?; + Ok(Arc::new(width_bucket_interval_mdn_exact(v, min, max, n_bucket))) + } + + + other => Err(unsupported_data_types_exec_err( + "width_bucket", + "Float/Decimal OR Duration OR Interval(YearMonth) for first 3 args; Int for 4th", + &[ + other.clone(), + minv.data_type().clone(), + maxv.data_type().clone(), + nb.data_type().clone(), + ], + )), + } +} + +macro_rules! width_bucket_kernel_impl { + ($name:ident, $arr_ty:ty, $to_f64:expr, $check_nan:expr) => { + pub(crate) fn $name( + v: &$arr_ty, + min: &$arr_ty, + max: &$arr_ty, + n_bucket: &Int32Array, + ) -> Int32Array { + let len = v.len(); + let mut b = Int32Builder::with_capacity(len); + + for i in 0..len { + if v.is_null(i) || min.is_null(i) || max.is_null(i) || n_bucket.is_null(i) + { + b.append_null(); + continue; + } + let x = ($to_f64)(v, i); + let l = ($to_f64)(min, i); + let h = ($to_f64)(max, i); + let buckets = n_bucket.value(i); + + if buckets <= 0 { + b.append_null(); + continue; + } + if $check_nan { + if !x.is_finite() || !l.is_finite() || !h.is_finite() { + b.append_null(); + continue; + } + } + + let ord = match l.partial_cmp(&h) { + Some(o) => o, + None => { + b.append_null(); + continue; + } + }; + if matches!(ord, std::cmp::Ordering::Equal) { + b.append_null(); + continue; + } + let asc = matches!(ord, std::cmp::Ordering::Less); + + if asc { + if x < l { + b.append_value(0); + continue; + } + if x >= h { + b.append_value(buckets + 1); + continue; + } + } else { + if x > l { + b.append_value(0); + continue; + } + if x <= h { + b.append_value(buckets + 1); + continue; + } + } + + let width = (h - l) / (buckets as f64); + if width == 0.0 || !width.is_finite() { + b.append_null(); + continue; + } + let mut bucket = ((x - l) / width).floor() as i32 + 1; + if bucket < 1 { + bucket = 1; + } + if bucket > buckets + 1 { + bucket = buckets + 1; + } + + b.append_value(bucket); + } + + b.finish() + } + }; +} + +width_bucket_kernel_impl!( + width_bucket_float64, + Float64Array, + |arr: &Float64Array, i: usize| arr.value(i), + true +); + +width_bucket_kernel_impl!( + width_bucket_i64_as_float, + DurationMicrosecondArray, + |arr: &DurationMicrosecondArray, i: usize| arr.value(i) as f64, + false +); + +width_bucket_kernel_impl!( + width_bucket_i32_as_float, + IntervalYearMonthArray, + |arr: &IntervalYearMonthArray, i: usize| arr.value(i) as f64, + false +); +const NS_PER_DAY_I128: i128 = 86_400_000_000_000; +pub(crate) fn width_bucket_interval_mdn_exact( + v: &IntervalMonthDayNanoArray, + lo: &IntervalMonthDayNanoArray, + hi: &IntervalMonthDayNanoArray, + n: &Int32Array, +) -> Int32Array { + let len = v.len(); + let mut b = Int32Builder::with_capacity(len); + + for i in 0..len { + if v.is_null(i) || lo.is_null(i) || hi.is_null(i) || n.is_null(i) { + b.append_null(); + continue; + } + let buckets = n.value(i); + if buckets <= 0 { + b.append_null(); + continue; + } + + let x = v.value(i); + let l = lo.value(i); + let h = hi.value(i); + + // asc/desc + // Values of IntervalMonthDayNano are compared using their binary representation, which can lead to surprising results. + let asc = (l.months, l.days, l.nanoseconds) < (h.months, h.days, h.nanoseconds); + if (l.months, l.days, l.nanoseconds) == (h.months, h.days, h.nanoseconds) { + b.append_null(); + continue; + } + + // ------------------- only month ------------------- + if l.days == h.days && l.nanoseconds == h.nanoseconds && l.months != h.months { + let x_m = x.months as f64; + let l_m = l.months as f64; + let h_m = h.months as f64; + + if asc { + if x_m < l_m { + b.append_value(0); + continue; + } + if x_m >= h_m { + b.append_value(buckets + 1); + continue; + } + } else { + if x_m > l_m { + b.append_value(0); + continue; + } + if x_m <= h_m { + b.append_value(buckets + 1); + continue; + } + } + + let width = (h_m - l_m) / (buckets as f64); + if width == 0.0 || !width.is_finite() { + b.append_null(); + continue; + } + + let mut bucket = ((x_m - l_m) / width).floor() as i32 + 1; + if bucket < 1 { + bucket = 1; + } + if bucket > buckets + 1 { + bucket = buckets + 1; + } + b.append_value(bucket); + continue; + } + + // --------------- months equals ------------------- + if l.months == h.months { + let base_days = l.days as i128; + let base_ns = l.nanoseconds as i128; + + let xf = (x.days as i128 - base_days) * NS_PER_DAY_I128 + + (x.nanoseconds as i128 - base_ns); + let hf = (h.days as i128 - base_days) * NS_PER_DAY_I128 + + (h.nanoseconds as i128 - base_ns); + + let x_f = xf as f64; + let l_f = 0.0; + let h_f = hf as f64; + + if asc { + if x_f < l_f { + b.append_value(0); + continue; + } + if x_f >= h_f { + b.append_value(buckets + 1); + continue; + } + } else { + if x_f > l_f { + b.append_value(0); + continue; + } + if x_f <= h_f { + b.append_value(buckets + 1); + continue; + } + } + + let width = (h_f - l_f) / (buckets as f64); + if width == 0.0 || !width.is_finite() { + b.append_null(); + continue; + } + + let mut bucket = ((x_f - l_f) / width).floor() as i32 + 1; + if bucket < 1 { + bucket = 1; + } + if bucket > buckets + 1 { + bucket = buckets + 1; + } + b.append_value(bucket); + continue; + } + + b.append_null(); + } + + b.finish() +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use arrow::array::{ + ArrayRef, DurationMicrosecondArray, Float64Array, Int32Array, + IntervalYearMonthArray, + }; + use arrow::datatypes::IntervalMonthDayNano; + + // --- Helpers ------------------------------------------------------------- + + fn i32_array_all(len: usize, val: i32) -> Arc { + Arc::new(Int32Array::from(vec![val; len])) + } + + fn f64_array(vals: &[f64]) -> Arc { + Arc::new(Float64Array::from(vals.to_vec())) + } + + fn f64_array_opt(vals: &[Option]) -> Arc { + Arc::new(Float64Array::from(vals.to_vec())) + } + + fn dur_us_array(vals: &[i64]) -> Arc { + Arc::new(DurationMicrosecondArray::from(vals.to_vec())) + } + + fn ym_array(vals: &[i32]) -> Arc { + Arc::new(IntervalYearMonthArray::from(vals.to_vec())) + } + + fn downcast_i32(arr: &ArrayRef) -> &Int32Array { + arr.as_any().downcast_ref::().unwrap() + } + + fn mdn_array(vals: &[(i32, i32, i64)]) -> Arc { + let data: Vec = vals + .iter() + .map(|(m, d, ns)| IntervalMonthDayNano::new(*m, *d, *ns)) + .collect(); + Arc::new(IntervalMonthDayNanoArray::from(data)) + } + + // --- Float64 ------------------------------------------------------------- + + #[test] + fn test_width_bucket_f64_basic() { + let v = f64_array(&[0.5, 1.0, 9.9, -1.0, 10.0]); + let lo = f64_array(&[0.0, 0.0, 0.0, 0.0, 0.0]); + let hi = f64_array(&[10.0, 10.0, 10.0, 10.0, 10.0]); + let n = i32_array_all(5, 10); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + assert_eq!(out.values(), &[1, 2, 10, 0, 11]); + } + + #[test] + fn test_width_bucket_f64_descending_range() { + let v = f64_array(&[9.9, 10.0, 0.0, -0.1, 10.1]); + let lo = f64_array(&[10.0; 5]); + let hi = f64_array(&[0.0; 5]); + let n = i32_array_all(5, 10); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + + assert_eq!(out.values(), &[1, 1, 11, 11, 0]); + } + #[test] + fn test_width_bucket_f64_bounds_inclusive_exclusive_asc() { + let v = f64_array(&[0.0, 9.999999999, 10.0]); + let lo = f64_array(&[0.0; 3]); + let hi = f64_array(&[10.0; 3]); + let n = i32_array_all(3, 10); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + assert_eq!(out.values(), &[1, 10, 11]); + } + + #[test] + fn test_width_bucket_f64_bounds_inclusive_exclusive_desc() { + let v = f64_array(&[10.0, 0.0, -0.000001]); + let lo = f64_array(&[10.0; 3]); + let hi = f64_array(&[0.0; 3]); + let n = i32_array_all(3, 10); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + assert_eq!(out.values(), &[1, 11, 11]); + } + + #[test] + fn test_width_bucket_f64_edge_cases() { + let v = f64_array(&[1.0, 5.0, 9.0]); + let lo = f64_array(&[0.0, 0.0, 0.0]); + let hi = f64_array(&[10.0, 10.0, 10.0]); + let n = Arc::new(Int32Array::from(vec![0, -1, 10])); + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + assert!(out.is_null(0)); + assert!(out.is_null(1)); + assert_eq!(out.value(2), 10); + + let v = f64_array(&[1.0]); + let lo = f64_array(&[5.0]); + let hi = f64_array(&[5.0]); + let n = i32_array_all(1, 10); + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + assert!(out.is_null(0)); + + let v = f64_array_opt(&[Some(f64::NAN)]); + let lo = f64_array(&[0.0]); + let hi = f64_array(&[10.0]); + let n = i32_array_all(1, 10); + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + assert!(out.is_null(0)); + } + + #[test] + fn test_width_bucket_f64_nulls_propagate() { + let v = f64_array_opt(&[None, Some(1.0), Some(2.0), Some(3.0)]); + let lo = f64_array(&[0.0; 4]); + let hi = f64_array(&[10.0; 4]); + let n = i32_array_all(4, 10); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + assert!(out.is_null(0)); + assert_eq!(out.value(1), 2); + assert_eq!(out.value(2), 3); + assert_eq!(out.value(3), 4); + + let v = f64_array(&[1.0]); + let lo = f64_array_opt(&[None]); + let hi = f64_array(&[10.0]); + let n = i32_array_all(1, 10); + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + assert!(out.is_null(0)); + } + + // --- Duration(Microsecond) ---------------------------------------------- + + #[test] + fn test_width_bucket_duration_us() { + let v = dur_us_array(&[1_000_000, 0, -1]); + let lo = dur_us_array(&[0, 0, 0]); + let hi = dur_us_array(&[2_000_000, 2_000_000, 2_000_000]); + let n = i32_array_all(3, 2); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + assert_eq!(out.values(), &[2, 1, 0]); + } + + #[test] + fn test_width_bucket_duration_us_equal_bounds() { + let v = dur_us_array(&[0]); + let lo = dur_us_array(&[1]); + let hi = dur_us_array(&[1]); + let n = i32_array_all(1, 10); + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + assert!(downcast_i32(&out).is_null(0)); + } + + // --- Interval(YearMonth) ------------------------------------------------ + + #[test] + fn test_width_bucket_interval_ym_basic() { + let v = ym_array(&[0, 5, 11, 12, 13]); + let lo = ym_array(&[0; 5]); + let hi = ym_array(&[12; 5]); + let n = i32_array_all(5, 12); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + assert_eq!(out.values(), &[1, 6, 12, 13, 13]); + } + + #[test] + fn test_width_bucket_interval_ym_desc() { + let v = ym_array(&[11, 12, 0, -1, 13]); + let lo = ym_array(&[12; 5]); + let hi = ym_array(&[0; 5]); + let n = i32_array_all(5, 12); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + assert_eq!(out.values(), &[2, 1, 13, 13, 0]); + } + + // --- Interval(MonthDayNano) -------------------------------------------- + + #[test] + fn test_width_bucket_interval_mdn_months_only_basic() { + let v = mdn_array(&[(0, 0, 0), (5, 0, 0), (11, 0, 0), (12, 0, 0), (13, 0, 0)]); + let lo = mdn_array(&[(0, 0, 0); 5]); + let hi = mdn_array(&[(12, 0, 0); 5]); + let n = i32_array_all(5, 12); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + assert_eq!(out.values(), &[1, 6, 12, 13, 13]); + } + + #[test] + fn test_width_bucket_interval_mdn_months_only_desc() { + let v = mdn_array(&[(11, 0, 0), (12, 0, 0), (0, 0, 0), (-1, 0, 0), (13, 0, 0)]); + let lo = mdn_array(&[(12, 0, 0); 5]); + let hi = mdn_array(&[(0, 0, 0); 5]); + let n = i32_array_all(5, 12); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + // Mismo patrón que YM descendente + assert_eq!(out.values(), &[2, 1, 13, 13, 0]); + } + + #[test] + fn test_width_bucket_interval_mdn_day_nano_basic() { + let v = mdn_array(&[ + (0, 0, 0), + (0, 5, 0), + (0, 9, 0), + (0, 10, 0), + (0, -1, 0), + (0, 11, 0), + ]); + let lo = mdn_array(&[(0, 0, 0); 6]); + let hi = mdn_array(&[(0, 10, 0); 6]); + let n = i32_array_all(6, 10); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + // x==hi -> n+1, x 0, x>hi -> n+1 + assert_eq!(out.values(), &[1, 6, 10, 11, 0, 11]); + } + + #[test] + fn test_width_bucket_interval_mdn_day_nano_desc() { + let v = mdn_array(&[(0, 9, 0), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]); + let lo = mdn_array(&[(0, 10, 0); 5]); + let hi = mdn_array(&[(0, 0, 0); 5]); + let n = i32_array_all(5, 10); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + + assert_eq!(out.values(), &[2, 1, 11, 11, 0]); + } + #[test] + fn test_width_bucket_interval_mdn_day_nano_desc_inside() { + let v = mdn_array(&[(0, 9, 1), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]); + let lo = mdn_array(&[(0, 10, 0); 5]); + let hi = mdn_array(&[(0, 0, 0); 5]); + let n = i32_array_all(5, 10); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + + assert_eq!(out.values(), &[1, 1, 11, 11, 0]); + } + + #[test] + fn test_width_bucket_interval_mdn_mixed_months_and_days_is_null() { + let v = mdn_array(&[(0, 1, 0)]); + let lo = mdn_array(&[(0, 0, 0)]); + let hi = mdn_array(&[(1, 1, 0)]); + let n = i32_array_all(1, 4); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + assert!(out.is_null(0)); + } + + #[test] + fn test_width_bucket_interval_mdn_equal_bounds_is_null() { + let v = mdn_array(&[(0, 0, 0)]); + let lo = mdn_array(&[(1, 2, 3)]); + let hi = mdn_array(&[(1, 2, 3)]); // lo == hi + let n = i32_array_all(1, 10); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + assert!(downcast_i32(&out).is_null(0)); + } + + #[test] + fn test_width_bucket_interval_mdn_invalid_n_is_null() { + let v = mdn_array(&[(0, 0, 0)]); + let lo = mdn_array(&[(0, 0, 0)]); + let hi = mdn_array(&[(0, 10, 0)]); + let n = Arc::new(Int32Array::from(vec![0])); // n <= 0 + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + assert!(downcast_i32(&out).is_null(0)); + } + + #[test] + fn test_width_bucket_interval_mdn_nulls_propagate() { + let v = Arc::new(IntervalMonthDayNanoArray::from(vec![ + None, + Some(IntervalMonthDayNano::new(0, 5, 0)), + ])); + let lo = mdn_array(&[(0, 0, 0), (0, 0, 0)]); + let hi = mdn_array(&[(0, 10, 0), (0, 10, 0)]); + let n = i32_array_all(2, 10); + + let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); + let out = downcast_i32(&out); + assert!(out.is_null(0)); + assert_eq!(out.value(1), 6); + } + + // --- Errores ------------------------------------------------------------- + + #[test] + fn test_width_bucket_wrong_arg_count() { + let v = f64_array(&[1.0]); + let lo = f64_array(&[0.0]); + let hi = f64_array(&[10.0]); + let err = width_bucket_kern(&[v, lo, hi]).unwrap_err(); + let msg = format!("{err}"); + assert!(msg.contains("expects exactly 4"), "unexpected error: {msg}"); + } + + #[test] + fn test_width_bucket_unsupported_type() { + let v: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let lo = f64_array(&[0.0, 0.0, 0.0]); + let hi = f64_array(&[10.0, 10.0, 10.0]); + let n = i32_array_all(3, 10); + + let err = width_bucket_kern(&[v, lo, hi, n]).unwrap_err(); + let msg = format!("{err}"); + assert!( + msg.contains("unsupported data types") + || msg.contains("Float/Decimal OR Duration OR Interval(YearMonth)"), + "unexpected error: {msg}" + ); + } +} diff --git a/datafusion/spark/src/function/misc/mod.rs b/datafusion/spark/src/function/misc/mod.rs new file mode 100644 index 0000000000000..a87df9a2c87a0 --- /dev/null +++ b/datafusion/spark/src/function/misc/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/mod.rs b/datafusion/spark/src/function/mod.rs new file mode 100644 index 0000000000000..3f4f94cfaaf8c --- /dev/null +++ b/datafusion/spark/src/function/mod.rs @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod aggregate; +pub mod array; +pub mod bitmap; +pub mod bitwise; +pub mod collection; +pub mod conditional; +pub mod conversion; +pub mod csv; +pub mod datetime; +pub mod error_utils; +pub mod functions_nested_utils; +pub mod generator; +pub mod hash; +pub mod json; +pub mod lambda; +pub mod map; +pub mod math; +pub mod misc; +pub mod predicate; +pub mod string; +pub mod r#struct; +pub mod table; +pub mod url; +pub mod utils; +pub mod window; +pub mod xml; diff --git a/datafusion/spark/src/function/predicate/mod.rs b/datafusion/spark/src/function/predicate/mod.rs new file mode 100644 index 0000000000000..a87df9a2c87a0 --- /dev/null +++ b/datafusion/spark/src/function/predicate/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/string/ascii.rs b/datafusion/spark/src/function/string/ascii.rs new file mode 100644 index 0000000000000..f14a66d4e484d --- /dev/null +++ b/datafusion/spark/src/function/string/ascii.rs @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; +use datafusion_common::Result; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_functions::string::ascii::ascii; +use datafusion_functions::utils::make_scalar_function; +use std::any::Any; + +/// Spark compatible version of the [ascii] function. Differs from the [default ascii function] +/// in that it is more permissive of input types, for example casting numeric input to string +/// before executing the function (default version doesn't allow numeric input). +/// +/// [ascii]: https://spark.apache.org/docs/latest/api/sql/index.html#ascii +/// [default ascii function]: datafusion_functions::string::ascii::AsciiFunc +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkAscii { + signature: Signature, +} + +impl Default for SparkAscii { + fn default() -> Self { + Self::new() + } +} + +impl SparkAscii { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkAscii { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ascii" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(ascii, vec![])(&args.args) + } + + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + Ok(vec![DataType::Utf8]) + } +} diff --git a/datafusion/spark/src/function/string/char.rs b/datafusion/spark/src/function/string/char.rs new file mode 100644 index 0000000000000..a1813373c65ff --- /dev/null +++ b/datafusion/spark/src/function/string/char.rs @@ -0,0 +1,132 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::array::GenericStringBuilder; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Int64; +use arrow::datatypes::DataType::Utf8; +use std::{any::Any, sync::Arc}; + +use datafusion_common::{cast::as_int64_array, exec_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +/// Spark-compatible `char` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct CharFunc { + signature: Signature, +} + +impl Default for CharFunc { + fn default() -> Self { + Self::new() + } +} + +impl CharFunc { + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for CharFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "char" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_chr(&args.args) + } +} + +/// Returns the ASCII character having the binary equivalent to the input expression. +/// E.g., chr(65) = 'A'. +/// Compatible with Apache Spark's Chr function +fn spark_chr(args: &[ColumnarValue]) -> Result { + let array = args[0].clone(); + match array { + ColumnarValue::Array(array) => { + let array = chr(&[array])?; + Ok(ColumnarValue::Array(array)) + } + ColumnarValue::Scalar(ScalarValue::Int64(Some(value))) => { + if value < 0 { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "".to_string(), + )))) + } else { + match core::char::from_u32((value % 256) as u32) { + Some(ch) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + ch.to_string(), + )))), + None => { + exec_err!("requested character was incompatible for encoding.") + } + } + } + } + _ => exec_err!("The argument must be an Int64 array or scalar."), + } +} + +fn chr(args: &[ArrayRef]) -> Result { + let integer_array = as_int64_array(&args[0])?; + + let mut builder = GenericStringBuilder::::with_capacity( + integer_array.len(), + integer_array.len(), + ); + + for integer_opt in integer_array { + match integer_opt { + Some(integer) => { + if integer < 0 { + builder.append_value(""); // empty string for negative numbers. + } else { + match core::char::from_u32((integer % 256) as u32) { + Some(ch) => builder.append_value(ch.to_string()), + None => { + return exec_err!( + "requested character not compatible for encoding." + ) + } + } + } + } + None => builder.append_null(), + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) +} diff --git a/datafusion/spark/src/function/string/elt.rs b/datafusion/spark/src/function/string/elt.rs new file mode 100644 index 0000000000000..35a22fe5edb6f --- /dev/null +++ b/datafusion/spark/src/function/string/elt.rs @@ -0,0 +1,251 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, AsArray, PrimitiveArray, StringArray, StringBuilder, +}; +use arrow::compute::{can_cast_types, cast}; +use arrow::datatypes::DataType::{Int64, Utf8}; +use arrow::datatypes::{DataType, Int64Type}; +use datafusion_common::cast::as_string_array; +use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkElt { + signature: Signature, +} + +impl Default for SparkElt { + fn default() -> Self { + SparkElt::new() + } +} + +impl SparkElt { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkElt { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "elt" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(elt, vec![])(&args.args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let length = arg_types.len(); + if length < 2 { + plan_datafusion_err!( + "ELT function expects at least 2 arguments: index, value1" + ); + } + + let idx_dt: &DataType = &arg_types[0]; + if *idx_dt != Int64 && !can_cast_types(idx_dt, &Int64) { + return Err(DataFusionError::Plan(format!( + "ELT index must be Int64 (or castable to Int64), got {idx_dt:?}" + ))); + } + let mut coerced = Vec::with_capacity(arg_types.len()); + coerced.push(Int64); + + for _ in 1..length { + coerced.push(Utf8); + } + + Ok(coerced) + } +} + +fn elt(args: &[ArrayRef]) -> Result { + let n_rows = args[0].len(); + + let idx: &PrimitiveArray = + args[0].as_primitive_opt::().ok_or_else(|| { + DataFusionError::Plan(format!( + "ELT function: first argument must be Int64 (got {:?})", + args[0].data_type() + )) + })?; + + let num_values = args.len() - 1; + let mut cols: Vec> = Vec::with_capacity(num_values); + for a in args.iter().skip(1) { + let casted = cast(a, &Utf8)?; + let sa = as_string_array(&casted)?; + cols.push(Arc::new(sa.clone())); + } + + let mut builder = StringBuilder::new(); + + for i in 0..n_rows { + if idx.is_null(i) { + builder.append_null(); + continue; + } + + let index = idx.value(i); + + // TODO: if spark.sql.ansi.enabled is true, + // throw ArrayIndexOutOfBoundsException for invalid indices; + // if false, return NULL instead (current behavior). + if index < 1 || (index as usize) > num_values { + builder.append_null(); + continue; + } + + let value_idx = (index as usize) - 1; + let col = &cols[value_idx]; + + if col.is_null(i) { + builder.append_null(); + } else { + builder.append_value(col.value(i)); + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int64Array; + use datafusion_common::Result; + + use arrow::array::{ArrayRef, StringArray}; + use datafusion_common::DataFusionError; + use std::sync::Arc; + + fn run_elt_arrays(arrs: Vec) -> Result> { + let arr = elt(&arrs)?; + let string_array = arr + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("expected Utf8".into()))?; + Ok(Arc::new(string_array.clone())) + } + + #[test] + fn elt_utf8_basic() -> Result<()> { + let idx = Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(0), + None, + ])); + let v1 = Arc::new(StringArray::from(vec![ + Some("a1"), + Some("a2"), + Some("a3"), + Some("a4"), + Some("a5"), + Some("a6"), + ])); + let v2 = Arc::new(StringArray::from(vec![ + Some("b1"), + Some("b2"), + None, + Some("b4"), + Some("b5"), + Some("b6"), + ])); + let v3 = Arc::new(StringArray::from(vec![ + Some("c1"), + Some("c2"), + Some("c3"), + None, + Some("c5"), + Some("c6"), + ])); + + let out = run_elt_arrays(vec![idx, v1, v2, v3])?; + assert_eq!(out.len(), 6); + assert_eq!(out.value(0), "a1"); + assert_eq!(out.value(1), "b2"); + assert_eq!(out.value(2), "c3"); + assert!(out.is_null(3)); + assert!(out.is_null(4)); + assert!(out.is_null(5)); + Ok(()) + } + + #[test] + fn elt_int64_basic() -> Result<()> { + let idx = Arc::new(Int64Array::from(vec![Some(2), Some(1), Some(2)])); + let v1 = Arc::new(Int64Array::from(vec![Some(10), Some(20), Some(30)])); + let v2 = Arc::new(Int64Array::from(vec![Some(100), None, Some(300)])); + + let out = run_elt_arrays(vec![idx, v1, v2])?; + assert_eq!(out.len(), 3); + assert_eq!(out.value(0), "100"); + assert_eq!(out.value(1), "20"); + assert_eq!(out.value(2), "300"); + Ok(()) + } + + #[test] + fn elt_out_of_range_all_null() -> Result<()> { + let idx = Arc::new(Int64Array::from(vec![Some(5), Some(-1), Some(0)])); + let v1 = Arc::new(StringArray::from(vec![Some("x"), Some("y"), Some("z")])); + let v2 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])); + + let out = run_elt_arrays(vec![idx, v1, v2])?; + assert!(out.is_null(0)); + assert!(out.is_null(1)); + assert!(out.is_null(2)); + Ok(()) + } + + #[test] + fn elt_utf8_returns_utf8() -> Result<()> { + let idx = Arc::new(Int64Array::from(vec![Some(1)])); + let v1 = Arc::new(StringArray::from(vec![Some("scala")])); + let v2 = Arc::new(StringArray::from(vec![Some("java")])); + + let out = run_elt_arrays(vec![idx, v1, v2])?; + assert_eq!(out.data_type(), &Utf8); + Ok(()) + } +} diff --git a/datafusion/spark/src/function/string/format_string.rs b/datafusion/spark/src/function/string/format_string.rs new file mode 100644 index 0000000000000..9809456af9a40 --- /dev/null +++ b/datafusion/spark/src/function/string/format_string.rs @@ -0,0 +1,2350 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::fmt::Write; +use std::sync::Arc; + +use core::num::FpCategory; + +use arrow::{ + array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray}, + datatypes::DataType, +}; +use bigdecimal::{ + num_bigint::{BigInt, Sign}, + BigDecimal, ToPrimitive, +}; +use chrono::{DateTime, Datelike, Timelike, Utc}; +use datafusion_common::{ + exec_datafusion_err, exec_err, plan_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; + +/// Spark-compatible `format_string` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct FormatStringFunc { + signature: Signature, + aliases: Vec, +} + +impl Default for FormatStringFunc { + fn default() -> Self { + Self::new() + } +} + +impl FormatStringFunc { + pub fn new() -> Self { + Self { + signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable), + aliases: vec![String::from("printf")], + } + } +} + +impl ScalarUDFImpl for FormatStringFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "format_string" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match arg_types[0] { + DataType::Null => Ok(DataType::Utf8), + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Ok(arg_types[0].clone()), + _ => plan_err!("The format_string function expects the first argument to be Utf8, LargeUtf8 or Utf8View") + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let len = args.args.iter().find_map(|arg| match arg { + ColumnarValue::Scalar(_) => None, + ColumnarValue::Array(a) => Some(a.len()), + }); + let is_scalar = len.is_none(); + let data_types = args.args[1..] + .iter() + .map(|arg| arg.data_type()) + .collect::>(); + let fmt_type = args.args[0].data_type(); + + match &args.args[0] { + ColumnarValue::Scalar(ScalarValue::Null) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) + } + ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) + } + ColumnarValue::Scalar(ScalarValue::Utf8View(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))) + } + ColumnarValue::Scalar(ScalarValue::Utf8(Some(fmt))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(fmt))) + | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(fmt))) => { + let formatter = Formatter::parse(fmt, &data_types)?; + let mut result = Vec::with_capacity(len.unwrap_or(1)); + for i in 0..len.unwrap_or(1) { + let scalars = args.args[1..] + .iter() + .map(|arg| try_to_scalar(arg.clone(), i)) + .collect::>>()?; + let formatted = formatter.format(&scalars)?; + result.push(formatted); + } + if is_scalar { + let scalar_result = result.pop().unwrap(); + match fmt_type { + DataType::Utf8 => Ok(ColumnarValue::Scalar(ScalarValue::Utf8( + Some(scalar_result), + ))), + DataType::LargeUtf8 => Ok(ColumnarValue::Scalar( + ScalarValue::LargeUtf8(Some(scalar_result)), + )), + DataType::Utf8View => Ok(ColumnarValue::Scalar( + ScalarValue::Utf8View(Some(scalar_result)), + )), + _ => unreachable!(), + } + } else { + let array: ArrayRef = match fmt_type { + DataType::Utf8 => Arc::new(StringArray::from(result)), + DataType::LargeUtf8 => Arc::new(LargeStringArray::from(result)), + DataType::Utf8View => Arc::new(StringViewArray::from(result)), + _ => unreachable!(), + }; + Ok(ColumnarValue::Array(array)) + } + } + ColumnarValue::Array(fmts) => { + let mut result = Vec::with_capacity(len.unwrap()); + for i in 0..len.unwrap() { + let fmt = ScalarValue::try_from_array(fmts, i)?; + match fmt.try_as_str() { + Some(Some(fmt)) => { + let formatter = Formatter::parse(fmt, &data_types)?; + let scalars = args.args[1..] + .iter() + .map(|arg| try_to_scalar(arg.clone(), i)) + .collect::>>()?; + let formatted = formatter.format(&scalars)?; + result.push(Some(formatted)); + } + Some(None) => { + result.push(None); + } + _ => unreachable!(), + } + } + let array: ArrayRef = match fmt_type { + DataType::Utf8 => Arc::new(StringArray::from(result)), + DataType::LargeUtf8 => Arc::new(LargeStringArray::from(result)), + DataType::Utf8View => Arc::new(StringViewArray::from(result)), + _ => unreachable!(), + }; + Ok(ColumnarValue::Array(array)) + } + _ => exec_err!( + "The format_string function expects the first argument to be a string" + ), + } + } +} + +fn try_to_scalar(arg: ColumnarValue, index: usize) -> Result { + match arg { + ColumnarValue::Scalar(scalar) => Ok(scalar), + ColumnarValue::Array(array) => ScalarValue::try_from_array(&array, index), + } +} + +/// Compatible with `java.util.Formatter` +#[derive(Debug)] +pub struct Formatter<'a> { + pub elements: Vec>, + pub arg_num: usize, +} + +impl<'a> Formatter<'a> { + pub fn new(elements: Vec>) -> Self { + let arg_num = elements + .iter() + .map(|element| match element { + FormatElement::Format(spec) => spec.argument_index, + _ => 0, + }) + .max() + .unwrap_or(0); + Self { elements, arg_num } + } + + /// Parses a printf-style format string into a Formatter with validation. + /// + /// This method implements a comprehensive parser for Java `java.util.Formatter` syntax, + /// processing the format string character by character to identify and validate format + /// specifiers against the provided argument types. + /// + /// # Arguments + /// + /// * `fmt` - The format string containing literal text and format specifiers + /// * `arg_types` - Array of DataFusion DataTypes corresponding to the arguments + /// + /// # Parsing Process + /// + /// The parser operates in several phases: + /// + /// 1. **String Scanning**: Iterates through the format string looking for '%' characters + /// that mark the beginning of format specifiers or special sequences. + /// + /// 2. **Special Sequence Handling**: Processes escape sequences: + /// - `%%` becomes a literal '%' character + /// - `%n` becomes a newline character + /// - `%<` indicates reuse of the previous argument with a new format specifier + /// + /// 3. **Argument Index Resolution**: Determines which argument each format specifier refers to: + /// - Sequential indexing: arguments are consumed in order (1, 2, 3, ...) + /// - Positional indexing: explicit argument position using `%n$` syntax + /// - Previous argument reuse: `%<` references the last used argument + /// + /// 4. **Format Specifier Parsing**: For each format specifier, extracts: + /// - Flags (-, +, space, #, 0, ',', '(') + /// - Width specification (minimum field width) + /// - Precision specification (decimal places or maximum characters) + /// - Conversion type (d, s, f, x, etc.) + /// + /// 5. **Type Validation**: Verifies that each format specifier's conversion type + /// is compatible with the corresponding argument's DataType. For example: + /// - Integer conversions (%d, %x, %o) require integer DataTypes + /// - String conversions (%s, %S) accept any DataType + /// - Float conversions (%f, %e, %g) require numeric DataTypes + /// + /// 6. **Element Construction**: Creates FormatElement instances for: + /// - Verbatim text sections (copied directly to output) + /// - Validated format specifiers with their parsed parameters + /// + /// # Internal State Management + /// + /// The parser maintains several state variables: + /// - `argument_index`: Tracks the current sequential argument position + /// - `prev`: Remembers the last used argument index for `%<` references + /// - `res`: Accumulates the parsed FormatElement instances + /// - `rem`: Points to the remaining unparsed portion of the format string + /// + /// # Validation and Error Handling + /// + /// The parser performs extensive validation including: + /// - Argument index bounds checking against the provided arg_types array + /// - Format specifier syntax validation + /// - Type compatibility verification between conversion types and DataTypes + /// - Detection of malformed numeric parameters and invalid flag combinations + /// + /// # Returns + /// + /// Returns a Formatter containing the parsed elements and the maximum argument + /// index encountered, enabling efficient argument validation during formatting. + pub fn parse(fmt: &'a str, arg_types: &[DataType]) -> Result { + // find the first % + let mut res = Vec::new(); + + let mut rem = fmt; + let mut argument_index = 0; + + let mut prev: Option = None; + + while !rem.is_empty() { + if let Some((verbatim_prefix, rest)) = rem.split_once('%') { + if !verbatim_prefix.is_empty() { + res.push(FormatElement::Verbatim(verbatim_prefix)); + } + if let Some(rest) = rest.strip_prefix('%') { + res.push(FormatElement::Verbatim("%")); + rem = rest; + continue; + } + if let Some(rest) = rest.strip_prefix('n') { + res.push(FormatElement::Verbatim("\n")); + rem = rest; + continue; + } + if let Some(rest) = rest.strip_prefix('<') { + // %< means reuse the previous argument + let Some(p) = prev else { + return exec_err!("No previous argument to reference"); + }; + let (spec, rest) = + take_conversion_specifier(rest, p, arg_types[p - 1].clone())?; + res.push(FormatElement::Format(spec)); + rem = rest; + continue; + } + + let (current_argument_index, rest2) = take_numeric_param(rest, false); + let (current_argument_index, rest) = + match (current_argument_index, rest2.starts_with('$')) { + (NumericParam::Literal(index), true) => { + (index as usize, &rest2[1..]) + } + (NumericParam::FromArgument, true) => { + return exec_err!("Invalid numeric parameter") + } + (_, false) => { + argument_index += 1; + (argument_index, rest) + } + }; + if current_argument_index == 0 || current_argument_index > arg_types.len() + { + return exec_err!( + "Argument index {} is out of bounds", + current_argument_index + ); + } + + let (spec, rest) = take_conversion_specifier( + rest, + current_argument_index, + arg_types[current_argument_index - 1].clone(), + ) + .map_err(|e| exec_datafusion_err!("{:?}, format string: {:?}", e, fmt))?; + res.push(FormatElement::Format(spec)); + prev = Some(spec.argument_index); + rem = rest; + } else { + res.push(FormatElement::Verbatim(rem)); + break; + } + } + + Ok(Self::new(res)) + } + + pub fn format(&self, args: &[ScalarValue]) -> Result { + if args.len() < self.arg_num { + return exec_err!( + "Expected at least {} arguments, got {}", + self.arg_num, + args.len() + ); + } + let mut string = String::new(); + for element in &self.elements { + match element { + FormatElement::Verbatim(text) => { + string.push_str(text); + } + FormatElement::Format(spec) => { + spec.format(&mut string, &args[spec.argument_index - 1])?; + } + } + } + Ok(string) + } +} + +#[derive(Debug)] +pub enum FormatElement<'a> { + /// Some characters that are copied to the output as-is + Verbatim(&'a str), + /// A format specifier + Format(ConversionSpecifier), +} + +/// Parsed printf conversion specifier +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ConversionSpecifier { + pub argument_index: usize, + /// flag `#`: use `0x`, etc? + pub alt_form: bool, + /// flag `0`: left-pad with zeros? + pub zero_pad: bool, + /// flag `-`: left-adjust (pad with spaces on the right) + pub left_adj: bool, + /// flag `' '` (space): indicate sign with a space? + pub space_sign: bool, + /// flag `+`: Always show sign? (for signed numbers) + pub force_sign: bool, + /// flag `,`: include locale-specific grouping separators + pub grouping_separator: bool, + /// flag `(`: enclose negative numbers in parentheses + pub negative_in_parentheses: bool, + /// field width + pub width: NumericParam, + /// floating point field precision + pub precision: NumericParam, + /// data type + pub conversion_type: ConversionType, +} + +/// Width / precision parameter +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum NumericParam { + /// The literal width + Literal(i32), + /// Get the width from the previous argument + FromArgument, +} + +/// Printf data type +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConversionType { + /// `B` + BooleanUpper, + /// `b` + BooleanLower, + /// Not implemented yet. Can be implemented after is merged + /// `h` + HexHashLower, + /// `H` + HexHashUpper, + /// `d` + DecInt, + /// `o` + OctInt, + /// `x` + HexIntLower, + /// `X` + HexIntUpper, + /// `e` + SciFloatLower, + /// `E` + SciFloatUpper, + /// `f` + DecFloatLower, + /// `g` + CompactFloatLower, + /// `G` + CompactFloatUpper, + /// `a` + HexFloatLower, + /// `A` + HexFloatUpper, + /// `t` + TimeLower(TimeFormat), + /// `T` + TimeUpper(TimeFormat), + /// `c` + CharLower, + /// `C` + CharUpper, + /// `s` + StringLower, + /// `S` + StringUpper, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TimeFormat { + // Hour of the day for the 24-hour clock, + // formatted as two digits with a leading zero as necessary i.e. 00 - 23. 00 corresponds to midnight. + HUpper, + // Hour for the 12-hour clock, + // formatted as two digits with a leading zero as necessary, i.e. 01 - 12. 01 corresponds to one o'clock (either morning or afternoon). + IUpper, + // Hour of the day for the 24-hour clock, + // i.e. 0 - 23. 0 corresponds to midnight. + KLower, + // Hour for the 12-hour clock, + // i.e. 1 - 12. 1 corresponds to one o'clock (either morning or afternoon). + LLower, + // Minute within the hour formatted as two digits with a leading zero as necessary, i.e. 00 - 59. + MUpper, + // Seconds within the minute, formatted as two digits with a leading zero as necessary, + // i.e. 00 - 60 ("60" is a special value required to support leap seconds). + SUpper, + // Millisecond within the second formatted as three digits with leading zeros as necessary, i.e. 000 - 999. + LUpper, + // Nanosecond within the second, formatted as nine digits with leading zeros as necessary, + // i.e. 000000000 - 999999999. The precision of this value is limited by the resolution of the underlying operating system or hardware. + NUpper, + // Locale-specific morning or afternoon marker in lower case, e.g."am" or "pm". + // Use of the conversion prefix 'T' forces this output to upper case. (Note that 'p' produces lower-case output. + // This is different from GNU date and POSIX strftime(3c) which produce upper-case output.) + PLower, + // RFC 822 style numeric time zone offset from GMT, + // e.g. -0800. This value will be adjusted as necessary for Daylight Saving Time. + // For long, Long, and Date the time zone used is the default time zone for this instance of the Java virtual machine. + ZLower, + // A string representing the abbreviation for the time zone. This value will be adjusted as necessary for Daylight Saving Time. + // For long, Long, and Date the time zone used is the default time zone for this instance of the Java virtual machine. + // The Formatter's locale will supersede the locale of the argument (if any). + ZUpper, + // Seconds since the beginning of the epoch starting at 1 January 1970 00:00:00 UTC, + // i.e. Long.MIN_VALUE/1000 to Long.MAX_VALUE/1000. + SLower, + // Milliseconds since the beginning of the epoch starting at 1 January 1970 00:00:00 UTC, + // i.e. Long.MIN_VALUE to Long.MAX_VALUE. The precision of this value is limited by the resolution of the underlying operating system or hardware. + QUpper, + // Locale-specific full month name, e.g. "January", "February". + BUpper, + // Locale-specific abbreviated month name, e.g. "Jan", "Feb". + BLower, + // Locale-specific full weekday name, e.g. "Monday", "Tuesday". + AUpper, + // Locale-specific abbreviated weekday name, e.g. "Mon", "Tue". + ALower, + // Four-digit year divided by 100, formatted as two digits with leading zero as necessary, i.e. 00 - 99 + CUpper, + // Year, formatted to at least four digits with leading zeros as necessary, e.g. 0092 equals 92 CE for the Gregorian calendar. + YUpper, + // Last two digits of the year, formatted with leading zeros as necessary, i.e. 00 - 99. + YLower, + // Day of year, formatted as three digits with leading zeros as necessary, e.g. 001 - 366 for the Gregorian calendar. 001 corresponds to the first day of the year. + JLower, + // Month, formatted as two digits with leading zeros as necessary, i.e. 01 - 13, where "01" is the first month of the year and ("13" is a special value required to support lunar calendars). + MLower, + // Day of month, formatted as two digits with leading zeros as necessary, i.e. 01 - 31, where "01" is the first day of the month. + DLower, + // Day of month, formatted as two digits, i.e. 1 - 31 where "1" is the first day of the month. + ELower, + // Time formatted for the 24-hour clock as "%tH:%tM" + RUpper, + // Time formatted for the 24-hour clock as "%tH:%tM:%tS" + TUpper, + // Time formatted for the 12-hour clock as "%tI:%tM:%tS %Tp". The location of the morning or afternoon marker ('%Tp') may be locale-dependent. + RLower, + // Date formatted as "%tm/%td/%ty" + DUpper, + // ISO 8601 complete date formatted as "%tY-%tm-%td" + FUpper, + // Date and time formatted as "%ta %tb %td %tT %tZ %tY", e.g. "Sun Jul 20 16:17:00 EDT 1969" + CLower, +} + +impl TryFrom for TimeFormat { + type Error = DataFusionError; + fn try_from(value: char) -> Result { + match value { + 'H' => Ok(TimeFormat::HUpper), + 'I' => Ok(TimeFormat::IUpper), + 'k' => Ok(TimeFormat::KLower), + 'l' => Ok(TimeFormat::LLower), + 'M' => Ok(TimeFormat::MUpper), + 'S' => Ok(TimeFormat::SUpper), + 'L' => Ok(TimeFormat::LUpper), + 'N' => Ok(TimeFormat::NUpper), + 'p' => Ok(TimeFormat::PLower), + 'z' => Ok(TimeFormat::ZLower), + 'Z' => Ok(TimeFormat::ZUpper), + 's' => Ok(TimeFormat::SLower), + 'Q' => Ok(TimeFormat::QUpper), + 'B' => Ok(TimeFormat::BUpper), + 'b' | 'h' => Ok(TimeFormat::BLower), + 'A' => Ok(TimeFormat::AUpper), + 'a' => Ok(TimeFormat::ALower), + 'C' => Ok(TimeFormat::CUpper), + 'Y' => Ok(TimeFormat::YUpper), + 'y' => Ok(TimeFormat::YLower), + 'j' => Ok(TimeFormat::JLower), + 'm' => Ok(TimeFormat::MLower), + 'd' => Ok(TimeFormat::DLower), + 'e' => Ok(TimeFormat::ELower), + 'R' => Ok(TimeFormat::RUpper), + 'T' => Ok(TimeFormat::TUpper), + 'r' => Ok(TimeFormat::RLower), + 'D' => Ok(TimeFormat::DUpper), + 'F' => Ok(TimeFormat::FUpper), + 'c' => Ok(TimeFormat::CLower), + _ => exec_err!("Invalid time format: {}", value), + } + } +} + +impl ConversionType { + pub fn validate(&self, arg_type: DataType) -> Result<()> { + match self { + ConversionType::BooleanLower | ConversionType::BooleanUpper => { + if !matches!(arg_type, DataType::Boolean) { + return exec_err!( + "Invalid argument type for boolean conversion: {:?}", + arg_type + ); + } + } + ConversionType::CharLower | ConversionType::CharUpper => { + if !matches!( + arg_type, + DataType::Int8 + | DataType::UInt8 + | DataType::Int16 + | DataType::UInt16 + | DataType::Int32 + | DataType::UInt32 + | DataType::Int64 + | DataType::UInt64 + ) { + return exec_err!( + "Invalid argument type for char conversion: {:?}", + arg_type + ); + } + } + ConversionType::DecInt + | ConversionType::OctInt + | ConversionType::HexIntLower + | ConversionType::HexIntUpper => { + if !arg_type.is_integer() { + return exec_err!( + "Invalid argument type for integer conversion: {:?}", + arg_type + ); + } + } + ConversionType::SciFloatLower + | ConversionType::SciFloatUpper + | ConversionType::DecFloatLower + | ConversionType::CompactFloatLower + | ConversionType::CompactFloatUpper + | ConversionType::HexFloatLower + | ConversionType::HexFloatUpper => { + if !arg_type.is_numeric() { + return exec_err!( + "Invalid argument type for float conversion: {:?}", + arg_type + ); + } + } + ConversionType::TimeLower(_) | ConversionType::TimeUpper(_) => { + if !arg_type.is_temporal() { + return exec_err!( + "Invalid argument type for time conversion: {:?}", + arg_type + ); + } + } + _ => {} + } + Ok(()) + } + + fn supports_integer(&self) -> bool { + matches!( + self, + ConversionType::DecInt + | ConversionType::HexIntLower + | ConversionType::HexIntUpper + | ConversionType::OctInt + | ConversionType::CharLower + | ConversionType::CharUpper + | ConversionType::StringLower + | ConversionType::StringUpper + ) + } + + fn supports_float(&self) -> bool { + matches!( + self, + ConversionType::DecFloatLower + | ConversionType::SciFloatLower + | ConversionType::SciFloatUpper + | ConversionType::CompactFloatLower + | ConversionType::CompactFloatUpper + | ConversionType::StringLower + | ConversionType::StringUpper + | ConversionType::HexFloatLower + | ConversionType::HexFloatUpper + ) + } + + fn supports_decimal(&self) -> bool { + matches!( + self, + ConversionType::DecFloatLower + | ConversionType::SciFloatLower + | ConversionType::SciFloatUpper + | ConversionType::CompactFloatLower + | ConversionType::CompactFloatUpper + | ConversionType::StringLower + | ConversionType::StringUpper + ) + } + + fn supports_time(&self) -> bool { + matches!( + self, + ConversionType::TimeLower(_) + | ConversionType::TimeUpper(_) + | ConversionType::StringLower + | ConversionType::StringUpper + ) + } + + fn is_upper(&self) -> bool { + matches!( + self, + ConversionType::BooleanUpper + | ConversionType::HexHashUpper + | ConversionType::HexIntUpper + | ConversionType::SciFloatUpper + | ConversionType::CompactFloatUpper + | ConversionType::HexFloatUpper + | ConversionType::TimeUpper(_) + | ConversionType::CharUpper + | ConversionType::StringUpper + ) + } +} + +fn take_conversion_specifier( + mut s: &str, + argument_index: usize, + arg_type: DataType, +) -> Result<(ConversionSpecifier, &str)> { + let mut spec = ConversionSpecifier { + argument_index, + alt_form: false, + zero_pad: false, + left_adj: false, + space_sign: false, + force_sign: false, + grouping_separator: false, + negative_in_parentheses: false, + width: NumericParam::Literal(0), + precision: NumericParam::FromArgument, // Placeholder - must not be returned! + // ignore length modifier + conversion_type: ConversionType::DecInt, + }; + + // parse flags + loop { + match s.chars().next() { + Some('#') => { + spec.alt_form = true; + } + Some('0') => { + if spec.left_adj { + return exec_err!("Invalid flag combination: '0' and '-'"); + } + spec.zero_pad = true; + } + Some('-') => { + spec.left_adj = true; + } + Some(' ') => { + if spec.force_sign { + return exec_err!("Invalid flag combination: '+' and ' '"); + } + spec.space_sign = true; + } + Some('+') => { + if spec.space_sign { + return exec_err!("Invalid flag combination: '+' and ' '"); + } + spec.force_sign = true; + } + Some(',') => { + spec.grouping_separator = true; + } + Some('(') => { + spec.negative_in_parentheses = true; + } + _ => { + break; + } + } + s = &s[1..]; + } + // parse width + let (w, mut s) = take_numeric_param(s, false); + spec.width = w; + // parse precision + if matches!(s.chars().next(), Some('.')) { + s = &s[1..]; + let (p, s2) = take_numeric_param(s, true); + spec.precision = p; + s = s2; + } + let mut chars = s.chars(); + let mut offset = 1; + // parse conversion type + spec.conversion_type = match chars.next() { + Some('b') => ConversionType::BooleanLower, + Some('B') => ConversionType::BooleanUpper, + Some('h') => ConversionType::HexHashLower, + Some('H') => ConversionType::HexHashUpper, + Some('s') => ConversionType::StringLower, + Some('S') => ConversionType::StringUpper, + Some('c') => ConversionType::CharLower, + Some('C') => ConversionType::CharUpper, + Some('d') => ConversionType::DecInt, + Some('o') => ConversionType::OctInt, + Some('x') => ConversionType::HexIntLower, + Some('X') => ConversionType::HexIntUpper, + Some('e') => ConversionType::SciFloatLower, + Some('E') => ConversionType::SciFloatUpper, + Some('f') => ConversionType::DecFloatLower, + Some('g') => ConversionType::CompactFloatLower, + Some('G') => ConversionType::CompactFloatUpper, + Some('a') => ConversionType::HexFloatLower, + Some('A') => ConversionType::HexFloatUpper, + Some('t') => { + let Some(chr) = chars.next() else { + return exec_err!("Invalid time format: {}", s); + }; + offset += 1; + ConversionType::TimeLower(chr.try_into()?) + } + Some('T') => { + let Some(chr) = chars.next() else { + return exec_err!("Invalid time format: {}", s); + }; + offset += 1; + ConversionType::TimeUpper(chr.try_into()?) + } + chr => { + return plan_err!("Invalid conversion type: {:?}", chr); + } + }; + + spec.conversion_type.validate(arg_type)?; + Ok((spec, &s[offset..])) +} + +fn take_numeric_param(s: &str, zero: bool) -> (NumericParam, &str) { + match s.chars().next() { + Some(digit) if (if zero { '0'..='9' } else { '1'..='9' }).contains(&digit) => { + let mut s = s; + let mut w = 0; + loop { + match s.chars().next() { + Some(digit) if digit.is_ascii_digit() => { + w = 10 * w + (digit as i32 - '0' as i32); + } + _ => { + break; + } + } + s = &s[1..]; + } + (NumericParam::Literal(w), s) + } + _ => (NumericParam::FromArgument, s), + } +} + +impl ConversionSpecifier { + pub fn format(&self, string: &mut String, value: &ScalarValue) -> Result<()> { + match value { + ScalarValue::Boolean(value) => match self.conversion_type { + ConversionType::StringLower | ConversionType::StringUpper => { + self.format_string(string, &value.unwrap_or(false).to_string()) + } + + _ => self.format_boolean(string, value), + }, + ScalarValue::Int8(value) => match (self.conversion_type, value) { + (ConversionType::DecInt, Some(value)) => { + self.format_signed(string, *value as i64) + } + ( + ConversionType::HexIntLower + | ConversionType::HexIntUpper + | ConversionType::OctInt, + Some(value), + ) => self.format_unsigned(string, (*value as u8) as u64), + (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { + self.format_char(string, *value as u8 as char) + } + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_integer() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for Int8", + self.conversion_type + ) + } + }, + ScalarValue::Int16(value) => match (self.conversion_type, value) { + (ConversionType::DecInt, Some(value)) => { + self.format_signed(string, *value as i64) + } + (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { + self.format_char( + string, + char::from_u32((*value as u16) as u32).unwrap(), + ) + } + ( + ConversionType::HexIntLower + | ConversionType::HexIntUpper + | ConversionType::OctInt, + Some(value), + ) => self.format_unsigned(string, (*value as u16) as u64), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_integer() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for Int16", + self.conversion_type + ) + } + }, + ScalarValue::Int32(value) => match (self.conversion_type, value) { + (ConversionType::DecInt, Some(value)) => { + self.format_signed(string, *value as i64) + } + ( + ConversionType::HexIntLower + | ConversionType::HexIntUpper + | ConversionType::OctInt, + Some(value), + ) => self.format_unsigned(string, (*value as u32) as u64), + (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { + self.format_char(string, char::from_u32(*value as u32).unwrap()) + } + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_integer() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for Int32", + self.conversion_type + ) + } + }, + ScalarValue::Int64(value) => match (self.conversion_type, value) { + (ConversionType::DecInt, Some(value)) => { + self.format_signed(string, *value) + } + ( + ConversionType::HexIntLower + | ConversionType::HexIntUpper + | ConversionType::OctInt, + Some(value), + ) => self.format_unsigned(string, *value as u64), + (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { + self.format_char( + string, + char::from_u32((*value as u64) as u32).unwrap(), + ) + } + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_integer() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for Int64", + self.conversion_type + ) + } + }, + ScalarValue::UInt8(value) => match (self.conversion_type, value) { + ( + ConversionType::DecInt + | ConversionType::HexIntLower + | ConversionType::HexIntUpper + | ConversionType::OctInt, + Some(value), + ) => self.format_unsigned(string, *value as u64), + (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { + self.format_char(string, *value as char) + } + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_integer() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for UInt8", + self.conversion_type + ) + } + }, + ScalarValue::UInt16(value) => match (self.conversion_type, value) { + ( + ConversionType::DecInt + | ConversionType::HexIntLower + | ConversionType::HexIntUpper + | ConversionType::OctInt, + Some(value), + ) => self.format_unsigned(string, *value as u64), + (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { + self.format_char(string, char::from_u32(*value as u32).unwrap()) + } + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_integer() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for UInt16", + self.conversion_type + ) + } + }, + ScalarValue::UInt32(value) => match (self.conversion_type, value) { + ( + ConversionType::DecInt + | ConversionType::HexIntLower + | ConversionType::HexIntUpper + | ConversionType::OctInt, + Some(value), + ) => self.format_unsigned(string, *value as u64), + (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { + self.format_char(string, char::from_u32(*value).unwrap()) + } + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_integer() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for UInt32", + self.conversion_type + ) + } + }, + ScalarValue::UInt64(value) => match (self.conversion_type, value) { + ( + ConversionType::DecInt + | ConversionType::HexIntLower + | ConversionType::HexIntUpper + | ConversionType::OctInt, + Some(value), + ) => self.format_unsigned(string, *value), + (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { + self.format_char(string, char::from_u32(*value as u32).unwrap()) + } + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_integer() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for UInt64", + self.conversion_type + ) + } + }, + ScalarValue::Float16(value) => match (self.conversion_type, value) { + ( + ConversionType::DecFloatLower + | ConversionType::SciFloatLower + | ConversionType::SciFloatUpper + | ConversionType::CompactFloatLower + | ConversionType::CompactFloatUpper, + Some(value), + ) => self.format_float(string, value.to_f64().unwrap()), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_f32().unwrap().spark_string()), + ( + ConversionType::HexFloatLower | ConversionType::HexFloatUpper, + Some(value), + ) => self.format_hex_float(string, value.to_f64().unwrap()), + (t, None) if t.supports_float() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for Float16", + self.conversion_type + ) + } + }, + ScalarValue::Float32(value) => match (self.conversion_type, value) { + ( + ConversionType::DecFloatLower + | ConversionType::SciFloatLower + | ConversionType::SciFloatUpper + | ConversionType::CompactFloatLower + | ConversionType::CompactFloatUpper, + Some(value), + ) => self.format_float(string, *value as f64), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.spark_string()), + ( + ConversionType::HexFloatLower | ConversionType::HexFloatUpper, + Some(value), + ) => self.format_hex_float(string, *value as f64), + (t, None) if t.supports_float() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for Float32", + self.conversion_type + ) + } + }, + ScalarValue::Float64(value) => match (self.conversion_type, value) { + ( + ConversionType::DecFloatLower + | ConversionType::SciFloatLower + | ConversionType::SciFloatUpper + | ConversionType::CompactFloatLower + | ConversionType::CompactFloatUpper, + Some(value), + ) => self.format_float(string, *value), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.spark_string()), + ( + ConversionType::HexFloatLower | ConversionType::HexFloatUpper, + Some(value), + ) => self.format_hex_float(string, *value), + (t, None) if t.supports_float() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for Float64", + self.conversion_type + ) + } + }, + ScalarValue::Utf8(value) => { + let value: &str = match value { + Some(value) => value.as_str(), + None => "null", + }; + if matches!( + self.conversion_type, + ConversionType::StringLower | ConversionType::StringUpper + ) { + self.format_string(string, value) + } else { + exec_err!( + "Invalid conversion type: {:?} for Utf8", + self.conversion_type + ) + } + } + ScalarValue::LargeUtf8(value) => { + let value: &str = match value { + Some(value) => value.as_str(), + None => "null", + }; + if matches!( + self.conversion_type, + ConversionType::StringLower | ConversionType::StringUpper + ) { + self.format_string(string, value) + } else { + exec_err!( + "Invalid conversion type: {:?} for LargeUtf8", + self.conversion_type + ) + } + } + ScalarValue::Utf8View(value) => { + let value: &str = match value { + Some(value) => value.as_str(), + None => "null", + }; + self.format_string(string, value) + } + ScalarValue::Decimal128(value, _, scale) => { + match (self.conversion_type, value) { + ( + ConversionType::DecFloatLower + | ConversionType::SciFloatLower + | ConversionType::SciFloatUpper + | ConversionType::CompactFloatLower + | ConversionType::CompactFloatUpper, + Some(value), + ) => self.format_decimal(string, value.to_string(), *scale as i64), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_decimal() => { + self.format_string(string, "null") + } + + _ => { + exec_err!( + "Invalid conversion type: {:?} for Decimal128", + self.conversion_type + ) + } + } + } + ScalarValue::Decimal256(value, _, scale) => { + match (self.conversion_type, value) { + ( + ConversionType::DecFloatLower + | ConversionType::SciFloatLower + | ConversionType::SciFloatUpper + | ConversionType::CompactFloatLower + | ConversionType::CompactFloatUpper, + Some(value), + ) => self.format_decimal(string, value.to_string(), *scale as i64), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_decimal() => { + self.format_string(string, "null") + } + + _ => { + exec_err!( + "Invalid conversion type: {:?} for Decimal256", + self.conversion_type + ) + } + } + } + + ScalarValue::Time32Second(value) => match (self.conversion_type, value) { + ( + ConversionType::TimeLower(_) | ConversionType::TimeUpper(_), + Some(value), + ) => self.format_time(string, *value as i64 * 1000000000, &None), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_time() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for Time32Second", + self.conversion_type + ) + } + }, + ScalarValue::Time32Millisecond(value) => { + match (self.conversion_type, value) { + ( + ConversionType::TimeLower(_) | ConversionType::TimeUpper(_), + Some(value), + ) => self.format_time(string, *value as i64 * 1000000, &None), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_time() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for Time32Millisecond", + self.conversion_type + ) + } + } + } + ScalarValue::Time64Microsecond(value) => { + match (self.conversion_type, value) { + ( + ConversionType::TimeLower(_) | ConversionType::TimeUpper(_), + Some(value), + ) => self.format_time(string, *value * 1000, &None), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_time() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for Time64Microsecond", + self.conversion_type + ) + } + } + } + ScalarValue::Time64Nanosecond(value) => match (self.conversion_type, value) { + ( + ConversionType::TimeLower(_) | ConversionType::TimeUpper(_), + Some(value), + ) => self.format_time(string, *value, &None), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_time() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for Time64Nanosecond", + self.conversion_type + ) + } + }, + ScalarValue::TimestampSecond(value, zone) => { + match (self.conversion_type, value) { + ( + ConversionType::TimeLower(_) | ConversionType::TimeUpper(_), + Some(value), + ) => self.format_time(string, value * 1000000000, zone), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_time() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for TimestampSecond", + self.conversion_type + ) + } + } + } + ScalarValue::TimestampMillisecond(value, zone) => { + match (self.conversion_type, value) { + ( + ConversionType::TimeLower(_) | ConversionType::TimeUpper(_), + Some(value), + ) => self.format_time(string, *value * 1000000, zone), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + + (t, None) if t.supports_time() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for TimestampMillisecond", + self.conversion_type + ) + } + } + } + ScalarValue::TimestampMicrosecond(value, zone) => { + match (self.conversion_type, value) { + ( + ConversionType::TimeLower(_) | ConversionType::TimeUpper(_), + Some(value), + ) => self.format_time(string, value * 1000, zone), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_time() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for timestampmicrosecond", + self.conversion_type + ) + } + } + } + + ScalarValue::TimestampNanosecond(value, zone) => { + match (self.conversion_type, value) { + ( + ConversionType::TimeLower(_) | ConversionType::TimeUpper(_), + Some(value), + ) => self.format_time(string, *value, zone), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_time() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for TimestampNanosecond", + self.conversion_type + ) + } + } + } + ScalarValue::Date32(value) => match (self.conversion_type, value) { + ( + ConversionType::TimeLower(_) | ConversionType::TimeUpper(_), + Some(value), + ) => self.format_date(string, *value as i64), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_time() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for Date32", + self.conversion_type + ) + } + }, + ScalarValue::Date64(value) => match (self.conversion_type, value) { + ( + ConversionType::TimeLower(_) | ConversionType::TimeUpper(_), + Some(value), + ) => self.format_date(string, *value), + ( + ConversionType::StringLower | ConversionType::StringUpper, + Some(value), + ) => self.format_string(string, &value.to_string()), + (t, None) if t.supports_time() => self.format_string(string, "null"), + _ => { + exec_err!( + "Invalid conversion type: {:?} for Date64", + self.conversion_type + ) + } + }, + ScalarValue::Null => { + let value = "null".to_string(); + self.format_string(string, &value) + } + _ => exec_err!("Invalid scalar value: {:?}", value), + } + } + + fn format_hex_float(&self, writer: &mut String, value: f64) -> Result<()> { + // Handle special cases first + let (sign, raw_exponent, mantissa) = value.to_parts(); + let is_subnormal = raw_exponent == 0; + + let precision = match self.precision { + NumericParam::FromArgument => None, + NumericParam::Literal(p) => Some(p), + }; + + // Determine if we need to normalize subnormal numbers + // Only normalize when precision is specified and less than full mantissa width + let mantissa_hex_digits = f64::MANTISSA_BITS.div_ceil(4); // 13 for f64 + let should_normalize = is_subnormal + && precision.is_some() + && precision.unwrap() < mantissa_hex_digits as i32; + + let (value, raw_exponent, mantissa) = if should_normalize { + let value = value * f64::SCALEUP; + let (_, raw_exponent, mantissa) = value.to_parts(); + (value, raw_exponent, mantissa) + } else { + (value, raw_exponent, mantissa) + }; + + let mut temp = String::new(); + + let sign_char = if sign { + "-" + } else if self.force_sign { + "+" + } else if self.space_sign { + " " + } else { + "" + }; + match value.category() { + FpCategory::Nan => { + write!(&mut temp, "NaN")?; + } + FpCategory::Infinite => { + write!(&mut temp, "{sign_char}Infinity")?; + } + FpCategory::Zero => { + write!(&mut temp, "{sign_char}0x0.0p0")?; + } + _ => { + let bias = i32::from(f64::EXPONENT_BIAS); + // Calculate actual exponent + // For subnormal numbers, the exponent is 1 - bias (not 0 - bias) + let exponent = if is_subnormal && !should_normalize { + 1 - bias + } else { + raw_exponent as i32 - bias + }; + + // Handle precision for rounding + let final_mantissa = if let Some(p) = precision { + if p == 0 { + // For precision 0, we still need at least 1 hex digit + // Round to the nearest integer mantissa value + let shift_distance = f64::MANTISSA_BITS as i32 - 4; // Keep 1 hex digit (4 bits) + let shifted = mantissa >> shift_distance; + let rounding_bits = mantissa & ((1u64 << shift_distance) - 1); + let round_bit = 1u64 << (shift_distance - 1); + + // Round to nearest, ties to even + if rounding_bits > round_bit + || (rounding_bits == round_bit && (shifted & 1) != 0) + { + (shifted + 1) << shift_distance + } else { + shifted << shift_distance + } + } else { + // Apply rounding based on precision + let precision_bits = p * 4; // Each hex digit is 4 bits + let keep_bits = f64::MANTISSA_BITS as i32; + let shift_distance = keep_bits - precision_bits; + + if shift_distance > 0 { + let shifted = mantissa >> shift_distance; + let rounding_bits = mantissa & ((1u64 << shift_distance) - 1); + let round_bit = 1u64 << (shift_distance - 1); + + // Round to nearest, ties to even + if rounding_bits > round_bit + || (rounding_bits == round_bit && (shifted & 1) != 0) + { + (shifted + 1) << shift_distance + } else { + shifted << shift_distance + } + } else { + mantissa + } + } + } else { + mantissa + }; + + if is_subnormal && !should_normalize { + // Original subnormal format: 0x0.xxxp-1022 + if precision.is_some() { + // precision >= 13, show as subnormal + let full_hex = format!( + "{:0width$x}", + final_mantissa, + width = mantissa_hex_digits as usize + ); + write!(&mut temp, "{sign_char}0x0.{full_hex}p{exponent}")?; + } else { + // No precision specified, show full subnormal + let hex_digits = format!( + "{:0width$x}", + final_mantissa, + width = mantissa_hex_digits as usize + ); + write!(&mut temp, "{sign_char}0x0.{hex_digits}p{exponent}")?; + } + } else { + // Normal format or normalized subnormal: 0x1.xxxpN + if let Some(p) = precision { + let p = if p == 0 { 1 } else { p }; + let hex_digits = format!("{final_mantissa:x}"); + let formatted_digits = if p as usize >= hex_digits.len() { + // Pad with zeros to match precision + format!("{:0().unwrap() - f64::SCALEUP_POWER as i32; + temp = format!("{prefix}p{iexp}"); + } + } + }; + + if self.conversion_type.is_upper() { + temp = temp.to_ascii_uppercase(); + } + + let NumericParam::Literal(width) = self.width else { + writer.push_str(&temp); + return Ok(()); + }; + if self.left_adj { + writer.push_str(&temp); + for _ in temp.len()..width as usize { + writer.push(' '); + } + } else if self.zero_pad && value.is_finite() { + let delimiter = if self.conversion_type.is_upper() { + "0X" + } else { + "0x" + }; + let (prefix, suffix) = temp.split_once(delimiter).unwrap(); + writer.push_str(prefix); + writer.push_str(delimiter); + for _ in temp.len()..width as usize { + writer.push('0'); + } + writer.push_str(suffix); + } else { + while temp.len() < width as usize { + temp = " ".to_owned() + &temp; + } + writer.push_str(&temp); + }; + Ok(()) + } + + fn format_char(&self, writer: &mut String, value: char) -> Result<()> { + let upper = self.conversion_type.is_upper(); + match self.conversion_type { + ConversionType::CharLower | ConversionType::CharUpper => { + let NumericParam::Literal(width) = self.width else { + if upper { + writer.push(value.to_ascii_uppercase()); + } else { + writer.push(value); + } + return Ok(()); + }; + + let start_len = writer.len(); + if self.left_adj { + if upper { + writer.push(value.to_ascii_uppercase()); + } else { + writer.push(value); + } + while writer.len() - start_len < width as usize { + writer.push(' '); + } + } else { + while writer.len() - start_len + value.len_utf8() < width as usize { + writer.push(' '); + } + if upper { + writer.push(value.to_ascii_uppercase()); + } else { + writer.push(value); + } + } + Ok(()) + } + _ => exec_err!( + "Invalid conversion type: {:?} for char", + self.conversion_type + ), + } + } + + fn format_boolean(&self, writer: &mut String, value: &Option) -> Result<()> { + let value = value.unwrap_or(false); + + let formatted = match self.conversion_type { + ConversionType::BooleanUpper => { + if value { + "TRUE" + } else { + "FALSE" + } + } + ConversionType::BooleanLower => { + if value { + "true" + } else { + "false" + } + } + _ => { + return exec_err!( + "Invalid conversion type: {:?} for boolean array", + self.conversion_type + ) + } + }; + self.format_str(writer, formatted) + } + + fn format_float(&self, writer: &mut String, value: f64) -> Result<()> { + let mut prefix = String::new(); + let mut suffix = String::new(); + let mut number = String::new(); + let upper = self.conversion_type.is_upper(); + + // set up the sign + if value.is_sign_negative() { + if self.negative_in_parentheses { + prefix.push('('); + suffix.push(')'); + } else { + prefix.push('-'); + } + } else if self.space_sign { + prefix.push(' '); + } else if self.force_sign { + prefix.push('+'); + } + + if value.is_finite() { + let mut use_scientific = false; + let mut strip_trailing_0s = false; + let mut abs = value.abs(); + let mut exponent = abs.log10().floor() as i32; + let mut precision = match self.precision { + NumericParam::Literal(p) => p, + _ => 6, + }; + match self.conversion_type { + ConversionType::DecFloatLower => { + // default + } + ConversionType::SciFloatLower => { + use_scientific = true; + } + ConversionType::SciFloatUpper => { + use_scientific = true; + } + ConversionType::CompactFloatLower | ConversionType::CompactFloatUpper => { + strip_trailing_0s = true; + if precision == 0 { + precision = 1; + } + // exponent signifies significant digits - we must round now + // to (re)calculate the exponent + let rounding_factor = + 10.0_f64.powf((precision - 1 - exponent) as f64); + let rounded_fixed = (abs * rounding_factor).round(); + abs = rounded_fixed / rounding_factor; + exponent = abs.log10().floor() as i32; + if exponent < -4 || exponent >= precision { + use_scientific = true; + precision -= 1; + } else { + // precision specifies the number of significant digits + precision -= 1 + exponent; + } + } + _ => { + return exec_err!( + "Invalid conversion type: {:?} for float", + self.conversion_type + ) + } + } + + if use_scientific { + // Manual scientific notation formatting for uppercase E + let mantissa = abs / 10.0_f64.powf(exponent as f64); + let exp_char = if upper { 'E' } else { 'e' }; + number = format!("{mantissa:.prec$}", prec = precision as usize); + if strip_trailing_0s { + number = trim_trailing_0s(&number).to_owned(); + } + number = format!("{number}{exp_char}{exponent:+03}"); + } else { + number = format!("{abs:.prec$}", prec = precision as usize); + if strip_trailing_0s { + number = trim_trailing_0s(&number).to_owned(); + } + } + if self.alt_form && !number.contains('.') { + number += "."; + } + } else { + // not finite + match self.conversion_type { + ConversionType::DecFloatLower + | ConversionType::SciFloatLower + | ConversionType::CompactFloatLower => { + if value.is_infinite() { + number.push_str("Infinity") + } else { + number.push_str("NaN") + } + } + ConversionType::SciFloatUpper | ConversionType::CompactFloatUpper => { + if value.is_infinite() { + number.push_str("INFINITY") + } else { + number.push_str("NAN") + } + } + _ => { + return exec_err!( + "Invalid conversion type: {:?} for float", + self.conversion_type + ) + } + } + } + // Take care of padding + let NumericParam::Literal(width) = self.width else { + writer.push_str(&prefix); + writer.push_str(&number); + writer.push_str(&suffix); + return Ok(()); + }; + if self.left_adj { + let mut full_num = prefix + &number + &suffix; + while full_num.len() < width as usize { + full_num.push(' '); + } + writer.push_str(&full_num); + } else if self.zero_pad && value.is_finite() { + while prefix.len() + number.len() + suffix.len() < width as usize { + prefix.push('0'); + } + writer.push_str(&prefix); + writer.push_str(&number); + writer.push_str(&suffix); + } else { + let mut full_num = prefix + &number + &suffix; + while full_num.len() < width as usize { + full_num = " ".to_owned() + &full_num; + } + writer.push_str(&full_num); + }; + + Ok(()) + } + + fn format_signed(&self, writer: &mut String, value: i64) -> Result<()> { + let negative = value < 0; + let abs_val = value.abs(); + + let (sign_prefix, sign_suffix) = if negative && self.negative_in_parentheses { + ("(".to_owned(), ")".to_owned()) + } else if negative { + ("-".to_owned(), "".to_owned()) + } else if self.force_sign { + ("+".to_owned(), "".to_owned()) + } else if self.space_sign { + (" ".to_owned(), "".to_owned()) + } else { + ("".to_owned(), "".to_owned()) + }; + + let mut mod_spec = *self; + mod_spec.width = match self.width { + NumericParam::Literal(w) => NumericParam::Literal( + w - sign_prefix.len() as i32 - sign_suffix.len() as i32, + ), + _ => NumericParam::FromArgument, + }; + let mut formatted = String::new(); + mod_spec.format_unsigned(&mut formatted, abs_val as u64)?; + // put the sign a after any leading spaces + let mut actual_number = &formatted[0..]; + let mut leading_spaces = &formatted[0..0]; + if let Some(first_non_space) = formatted.find(|c| c != ' ') { + actual_number = &formatted[first_non_space..]; + leading_spaces = &formatted[0..first_non_space]; + } + write!( + writer, + "{}{}{}{}", + leading_spaces.to_owned(), + sign_prefix, + actual_number, + sign_suffix + ) + .map_err(|e| exec_datafusion_err!("Write error: {}", e))?; + Ok(()) + } + + fn format_unsigned(&self, writer: &mut String, value: u64) -> Result<()> { + let mut s = String::new(); + let mut alt_prefix = ""; + match self.conversion_type { + ConversionType::DecInt => { + let num_str = format!("{value}"); + if self.grouping_separator { + // Add thousands separators + let mut result = String::new(); + let chars: Vec = num_str.chars().collect(); + for (i, c) in chars.iter().enumerate() { + if i > 0 && (chars.len() - i).is_multiple_of(3) { + result.push(','); + } + result.push(*c); + } + s = result; + } else { + s = num_str; + } + } + ConversionType::HexIntLower => { + alt_prefix = "0x"; + write!(&mut s, "{value:x}") + .map_err(|e| exec_datafusion_err!("Write error: {}", e))?; + } + ConversionType::HexIntUpper => { + alt_prefix = "0X"; + write!(&mut s, "{value:X}") + .map_err(|e| exec_datafusion_err!("Write error: {}", e))?; + } + ConversionType::OctInt => { + alt_prefix = "0"; + write!(&mut s, "{value:o}") + .map_err(|e| exec_datafusion_err!("Write error: {}", e))?; + } + _ => { + return exec_err!( + "Invalid conversion type: {:?} for u64", + self.conversion_type + ) + } + } + let mut prefix = if self.alt_form { + alt_prefix.to_owned() + } else { + String::new() + }; + + let formatted = if let NumericParam::Literal(width) = self.width { + if self.left_adj { + let mut num_str = prefix + &s; + while num_str.len() < width as usize { + num_str.push(' '); + } + num_str + } else if self.zero_pad { + while prefix.len() + s.len() < width as usize { + prefix.push('0'); + } + prefix + &s + } else { + let mut num_str = prefix + &s; + while num_str.len() < width as usize { + num_str = " ".to_owned() + &num_str; + } + num_str + } + } else { + prefix + &s + }; + write!(writer, "{formatted}") + .map_err(|e| exec_datafusion_err!("Write error: {}", e))?; + Ok(()) + } + + fn format_str(&self, writer: &mut String, value: &str) -> Result<()> { + // Take care of precision, putting the truncated string in `content` + let precision: usize = match self.precision { + NumericParam::Literal(p) => p, + _ => i32::MAX, + } + .try_into() + .unwrap_or_default(); + let content_len = { + let mut content_len = precision.min(value.len()); + while !value.is_char_boundary(content_len) { + content_len -= 1; + } + content_len + }; + let content = &value[..content_len]; + + // Pad to width if needed, putting the padded string in `s` + + if let NumericParam::Literal(width) = self.width { + let start_len = writer.len(); + if self.left_adj { + writer.push_str(content); + while writer.len() - start_len < width as usize { + writer.push(' '); + } + } else { + while writer.len() - start_len + content.len() < width as usize { + writer.push(' '); + } + writer.push_str(content); + } + } else { + writer.push_str(content); + } + Ok(()) + } + + fn format_string(&self, writer: &mut String, value: &str) -> Result<()> { + if self.conversion_type.is_upper() { + let upper = value.to_ascii_uppercase(); + self.format_str(writer, &upper) + } else { + self.format_str(writer, value) + } + } + + fn format_decimal( + &self, + writer: &mut String, + value: String, + scale: i64, + ) -> Result<()> { + let mut prefix = String::new(); + let upper = self.conversion_type.is_upper(); + + // Parse as BigDecimal + let decimal = value + .parse::() + .map_err(|e| exec_datafusion_err!("Failed to parse decimal: {}", e))?; + let decimal = BigDecimal::from_bigint(decimal, scale); + + // Handle sign + let is_negative = decimal.sign() == Sign::Minus; + let abs_decimal = decimal.abs(); + + if is_negative { + prefix.push('-'); + } else if self.space_sign { + prefix.push(' '); + } else if self.force_sign { + prefix.push('+'); + } + + let exp_symb = if upper { 'E' } else { 'e' }; + let mut strip_trailing_0s = false; + + // Get precision setting + let mut precision = match self.precision { + NumericParam::Literal(p) => p, + _ => 6, + }; + + let number = match self.conversion_type { + ConversionType::DecFloatLower => { + // Format as fixed-point decimal + self.format_decimal_fixed(&abs_decimal, precision, strip_trailing_0s)? + } + ConversionType::SciFloatLower => self.format_decimal_scientific( + &abs_decimal, + precision, + 'e', + strip_trailing_0s, + )?, + ConversionType::SciFloatUpper => self.format_decimal_scientific( + &abs_decimal, + precision, + 'E', + strip_trailing_0s, + )?, + ConversionType::CompactFloatLower | ConversionType::CompactFloatUpper => { + strip_trailing_0s = true; + if precision == 0 { + precision = 1; + } + // Determine if we should use scientific notation + let log10_val = abs_decimal.to_f64().map(|f| f.log10()).unwrap_or(0.0); + if log10_val < -4.0 || log10_val >= precision as f64 { + self.format_decimal_scientific( + &abs_decimal, + precision - 1, + exp_symb, + strip_trailing_0s, + )? + } else { + self.format_decimal_fixed( + &abs_decimal, + precision - 1 - log10_val.floor() as i32, + strip_trailing_0s, + )? + } + } + _ => { + return exec_err!( + "Invalid conversion type: {:?} for decimal", + self.conversion_type + ) + } + }; + + // Handle padding + let NumericParam::Literal(width) = self.width else { + writer.push_str(&prefix); + writer.push_str(&number); + return Ok(()); + }; + + if self.left_adj { + let mut full_num = prefix + &number; + while full_num.len() < width as usize { + full_num.push(' '); + } + writer.push_str(&full_num); + } else if self.zero_pad { + while prefix.len() + number.len() < width as usize { + prefix.push('0'); + } + writer.push_str(&prefix); + writer.push_str(&number); + } else { + let mut full_num = prefix + &number; + while full_num.len() < width as usize { + full_num = " ".to_owned() + &full_num; + } + writer.push_str(&full_num); + } + + Ok(()) + } + + fn format_decimal_fixed( + &self, + decimal: &BigDecimal, + precision: i32, + strip_trailing_0s: bool, + ) -> Result { + if precision <= 0 { + Ok(decimal.round(0).to_string()) + } else { + // Use BigDecimal's with_scale method for precise decimal formatting + let scaled = decimal.round(precision as i64); + let mut number = scaled.to_string(); + if strip_trailing_0s { + number = trim_trailing_0s(&number).to_owned(); + } + Ok(number) + } + } + + fn format_decimal_scientific( + &self, + decimal: &BigDecimal, + precision: i32, + exp_char: char, + strip_trailing_0s: bool, + ) -> Result { + // Convert to f64 for scientific notation (may lose precision for very large numbers) + let float_val = decimal.to_f64().unwrap_or(0.0); + if float_val == 0.0 { + return Ok(format!("0{exp_char}+00")); + } + + let abs_val = float_val.abs(); + let exponent = abs_val.log10().floor() as i32; + let mantissa = abs_val / 10.0_f64.powf(exponent as f64); + + let mut number = if precision <= 0 { + format!("{mantissa:.0}") + } else { + format!("{mantissa:.prec$}", prec = precision as usize) + }; + + if strip_trailing_0s { + number = trim_trailing_0s(&number).to_owned(); + } + + Ok(format!("{number}{exp_char}{exponent:+03}")) + } + + fn format_time( + &self, + writer: &mut String, + timestamp_nanos: i64, + timezone: &Option>, + ) -> Result<()> { + let upper = self.conversion_type.is_upper(); + match &self.conversion_type { + ConversionType::TimeLower(time_format) + | ConversionType::TimeUpper(time_format) => { + let formatted = + self.format_time_component(timestamp_nanos, *time_format, timezone)?; + let result = if upper { + formatted.to_uppercase() + } else { + formatted + }; + write!(writer, "{result}") + .map_err(|e| exec_datafusion_err!("Write error: {}", e))?; + Ok(()) + } + _ => exec_err!( + "Invalid conversion type for time: {:?}", + self.conversion_type + ), + } + } + + fn format_date(&self, writer: &mut String, date_days: i64) -> Result<()> { + // Convert days since epoch to timestamp in nanoseconds + let timestamp_nanos = date_days * 24 * 60 * 60 * 1_000_000_000; + self.format_time(writer, timestamp_nanos, &None) + } + + fn format_time_component( + &self, + timestamp_nanos: i64, + time_format: TimeFormat, + _timezone: &Option>, + ) -> Result { + // Convert nanoseconds to seconds and nanoseconds remainder + let secs = timestamp_nanos / 1_000_000_000; + let nanos = (timestamp_nanos % 1_000_000_000) as u32; + + // Create DateTime from timestamp + let dt = DateTime::::from_timestamp(secs, nanos).ok_or_else(|| { + exec_datafusion_err!("Invalid timestamp: {}", timestamp_nanos) + })?; + + match time_format { + TimeFormat::HUpper => Ok(format!("{:02}", dt.hour())), + TimeFormat::IUpper => { + let hour_12 = match dt.hour12() { + (true, h) => h, // PM + (false, h) => h, // AM + }; + Ok(format!("{hour_12:02}")) + } + TimeFormat::KLower => Ok(format!("{}", dt.hour())), + TimeFormat::LLower => { + let hour_12 = match dt.hour12() { + (true, h) => h, // PM + (false, h) => h, // AM + }; + Ok(format!("{hour_12}")) + } + TimeFormat::MUpper => Ok(format!("{:02}", dt.minute())), + TimeFormat::SUpper => Ok(format!("{:02}", dt.second())), + TimeFormat::LUpper => Ok(format!("{:03}", dt.timestamp_millis() % 1000)), + TimeFormat::NUpper => Ok(format!("{:09}", dt.nanosecond())), + TimeFormat::PLower => { + let (is_pm, _) = dt.hour12(); + Ok(if is_pm { + "pm".to_string() + } else { + "am".to_string() + }) + } + TimeFormat::ZLower => Ok("+0000".to_string()), // UTC timezone offset + TimeFormat::ZUpper => Ok("UTC".to_string()), // UTC timezone name + TimeFormat::SLower => Ok(format!("{}", dt.timestamp())), + TimeFormat::QUpper => Ok(format!("{}", dt.timestamp_millis())), + TimeFormat::BUpper => Ok(dt.format("%B").to_string()), // Full month name + TimeFormat::BLower => Ok(dt.format("%b").to_string()), // Abbreviated month name + TimeFormat::AUpper => Ok(dt.format("%A").to_string()), // Full weekday name + TimeFormat::ALower => Ok(dt.format("%a").to_string()), // Abbreviated weekday name + TimeFormat::CUpper => Ok(format!("{:02}", dt.year() / 100)), + TimeFormat::YUpper => Ok(format!("{:04}", dt.year())), + TimeFormat::YLower => Ok(format!("{:02}", dt.year() % 100)), + TimeFormat::JLower => Ok(format!("{:03}", dt.ordinal())), // Day of year + TimeFormat::MLower => Ok(format!("{:02}", dt.month())), + TimeFormat::DLower => Ok(format!("{:02}", dt.day())), + TimeFormat::ELower => Ok(format!("{}", dt.day())), + TimeFormat::RUpper => Ok(dt.format("%H:%M").to_string()), + TimeFormat::TUpper => Ok(dt.format("%H:%M:%S").to_string()), + TimeFormat::RLower => { + let (is_pm, hour_12) = dt.hour12(); + let am_pm = if is_pm { "PM" } else { "AM" }; + Ok(format!( + "{:02}:{:02}:{:02} {}", + hour_12, + dt.minute(), + dt.second(), + am_pm + )) + } + TimeFormat::DUpper => Ok(dt.format("%m/%d/%y").to_string()), + TimeFormat::FUpper => Ok(dt.format("%Y-%m-%d").to_string()), + TimeFormat::CLower => Ok(dt.format("%a %b %d %H:%M:%S UTC %Y").to_string()), + } + } +} + +trait FloatFormattable: std::fmt::Display { + fn category(&self) -> FpCategory; + + fn spark_string(&self) -> String { + match self.category() { + FpCategory::Nan => "NaN".to_string(), + FpCategory::Infinite => { + if self.negative() { + "-Infinity".to_string() + } else { + "Infinity".to_string() + } + } + _ => self.to_string(), + } + } + fn negative(&self) -> bool; +} + +impl FloatFormattable for f32 { + fn category(&self) -> FpCategory { + self.classify() + } + + fn negative(&self) -> bool { + self.is_sign_negative() + } +} + +impl FloatFormattable for f64 { + fn category(&self) -> FpCategory { + self.classify() + } + + fn negative(&self) -> bool { + self.is_sign_negative() + } +} + +trait FloatBits: FloatFormattable { + const MANTISSA_BITS: u8; + const EXPONENT_BIAS: u16; + const SCALEUP_POWER: u8; + const SCALEUP: Self; + + fn to_parts(&self) -> (bool, u16, u64); +} + +impl FloatBits for f64 { + const MANTISSA_BITS: u8 = 52; + const EXPONENT_BIAS: u16 = 1023; + const SCALEUP_POWER: u8 = 54; + const SCALEUP: f64 = (1_i64 << Self::SCALEUP_POWER) as f64; + + fn to_parts(&self) -> (bool, u16, u64) { + let bits = self.to_bits(); + let sign: bool = (bits >> 63) == 1; + let exponent = ((bits >> 52) & 0x7FF) as u16; + let mantissa = bits & 0x000F_FFFF_FFFF_FFFF; + (sign, exponent, mantissa) + } +} + +fn trim_trailing_0s(number: &str) -> &str { + if number.contains('.') { + for (i, c) in number.chars().rev().enumerate() { + if c != '0' { + return &number[..number.len() - i]; + } + } + } + number +} + +fn trim_trailing_0s_hex(number: &str) -> &str { + for (i, c) in number.chars().rev().enumerate() { + if c != '0' { + return &number[..number.len() - i]; + } + } + number +} diff --git a/datafusion/spark/src/function/string/ilike.rs b/datafusion/spark/src/function/string/ilike.rs new file mode 100644 index 0000000000000..a160749523f1e --- /dev/null +++ b/datafusion/spark/src/function/string/ilike.rs @@ -0,0 +1,173 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::compute::ilike; +use arrow::datatypes::DataType; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_functions::utils::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + +/// ILIKE function for case-insensitive pattern matching +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkILike { + signature: Signature, +} + +impl Default for SparkILike { + fn default() -> Self { + Self::new() + } +} + +impl SparkILike { + pub fn new() -> Self { + Self { + signature: Signature::string(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkILike { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ilike" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_ilike, vec![])(&args.args) + } +} + +/// Returns true if str matches pattern (case insensitive). +pub fn spark_ilike(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("ilike function requires exactly 2 arguments"); + } + + let result = ilike(&args[0], &args[1])?; + Ok(Arc::new(result)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::function::utils::test::test_scalar_function; + use arrow::array::{Array, BooleanArray}; + use arrow::datatypes::DataType::Boolean; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + macro_rules! test_ilike_string_invoke { + ($INPUT1:expr, $INPUT2:expr, $EXPECTED:expr) => { + test_scalar_function!( + SparkILike::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT1)), + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT2)) + ], + $EXPECTED, + bool, + Boolean, + BooleanArray + ); + + test_scalar_function!( + SparkILike::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT1)), + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT2)) + ], + $EXPECTED, + bool, + Boolean, + BooleanArray + ); + + test_scalar_function!( + SparkILike::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT1)), + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT2)) + ], + $EXPECTED, + bool, + Boolean, + BooleanArray + ); + }; + } + + #[test] + fn test_ilike_invoke() -> Result<()> { + test_ilike_string_invoke!( + Some(String::from("Spark")), + Some(String::from("_park")), + Ok(Some(true)) + ); + test_ilike_string_invoke!( + Some(String::from("Spark")), + Some(String::from("_PARK")), + Ok(Some(true)) + ); + test_ilike_string_invoke!( + Some(String::from("SPARK")), + Some(String::from("_park")), + Ok(Some(true)) + ); + test_ilike_string_invoke!( + Some(String::from("Spark")), + Some(String::from("sp%")), + Ok(Some(true)) + ); + test_ilike_string_invoke!( + Some(String::from("Spark")), + Some(String::from("SP%")), + Ok(Some(true)) + ); + test_ilike_string_invoke!( + Some(String::from("Spark")), + Some(String::from("%ARK")), + Ok(Some(true)) + ); + test_ilike_string_invoke!( + Some(String::from("Spark")), + Some(String::from("xyz")), + Ok(Some(false)) + ); + test_ilike_string_invoke!(None, Some(String::from("_park")), Ok(None)); + test_ilike_string_invoke!(Some(String::from("Spark")), None, Ok(None)); + test_ilike_string_invoke!(None, None, Ok(None)); + + Ok(()) + } +} diff --git a/datafusion/spark/src/function/string/length.rs b/datafusion/spark/src/function/string/length.rs new file mode 100644 index 0000000000000..1fa54d000effa --- /dev/null +++ b/datafusion/spark/src/function/string/length.rs @@ -0,0 +1,282 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, AsArray, BinaryArrayType, PrimitiveArray, StringArrayType, +}; +use arrow::datatypes::{DataType, Int32Type}; +use datafusion_common::exec_err; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use std::sync::Arc; + +/// Spark-compatible `length` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkLengthFunc { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkLengthFunc { + fn default() -> Self { + Self::new() + } +} + +impl SparkLengthFunc { + pub fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![ + DataType::Utf8View, + DataType::Utf8, + DataType::LargeUtf8, + DataType::Binary, + DataType::LargeBinary, + DataType::BinaryView, + ], + Volatility::Immutable, + ), + aliases: vec![ + String::from("character_length"), + String::from("char_length"), + String::from("len"), + ], + } + } +} + +impl ScalarUDFImpl for SparkLengthFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "length" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _args: &[DataType]) -> datafusion_common::Result { + // spark length always returns Int32 + Ok(DataType::Int32) + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + make_scalar_function(spark_length, vec![])(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +fn spark_length(args: &[ArrayRef]) -> datafusion_common::Result { + match args[0].data_type() { + DataType::Utf8 => { + let string_array = args[0].as_string::(); + character_length::<_>(string_array) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + character_length::<_>(string_array) + } + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + character_length::<_>(string_array) + } + DataType::Binary => { + let binary_array = args[0].as_binary::(); + byte_length::<_>(binary_array) + } + DataType::LargeBinary => { + let binary_array = args[0].as_binary::(); + byte_length::<_>(binary_array) + } + DataType::BinaryView => { + let binary_array = args[0].as_binary_view(); + byte_length::<_>(binary_array) + } + other => exec_err!("Unsupported data type {other:?} for function `length`"), + } +} + +fn character_length<'a, V>(array: V) -> datafusion_common::Result +where + V: StringArrayType<'a>, +{ + // String characters are variable length encoded in UTF-8, counting the + // number of chars requires expensive decoding, however checking if the + // string is ASCII only is relatively cheap. + // If strings are ASCII only, count bytes instead. + let is_array_ascii_only = array.is_ascii(); + let nulls = array.nulls().cloned(); + let array = { + if is_array_ascii_only { + let values: Vec<_> = (0..array.len()) + .map(|i| { + // Safety: we are iterating with array.len() so the index is always valid + let value = unsafe { array.value_unchecked(i) }; + value.len() as i32 + }) + .collect(); + PrimitiveArray::::new(values.into(), nulls) + } else { + let values: Vec<_> = (0..array.len()) + .map(|i| { + // Safety: we are iterating with array.len() so the index is always valid + if array.is_null(i) { + i32::default() + } else { + let value = unsafe { array.value_unchecked(i) }; + if value.is_empty() { + i32::default() + } else if value.is_ascii() { + value.len() as i32 + } else { + value.chars().count() as i32 + } + } + }) + .collect(); + PrimitiveArray::::new(values.into(), nulls) + } + }; + + Ok(Arc::new(array)) +} + +fn byte_length<'a, V>(array: V) -> datafusion_common::Result +where + V: BinaryArrayType<'a>, +{ + let nulls = array.nulls().cloned(); + let values: Vec<_> = (0..array.len()) + .map(|i| { + // Safety: we are iterating with array.len() so the index is always valid + let value = unsafe { array.value_unchecked(i) }; + value.len() as i32 + }) + .collect(); + Ok(Arc::new(PrimitiveArray::::new( + values.into(), + nulls, + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::function::utils::test::test_scalar_function; + use arrow::array::{Array, Int32Array}; + use arrow::datatypes::DataType::Int32; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + macro_rules! test_spark_length_string { + ($INPUT:expr, $EXPECTED:expr) => { + test_scalar_function!( + SparkLengthFunc::new(), + vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_scalar_function!( + SparkLengthFunc::new(), + vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_scalar_function!( + SparkLengthFunc::new(), + vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + }; + } + + macro_rules! test_spark_length_binary { + ($INPUT:expr, $EXPECTED:expr) => { + test_scalar_function!( + SparkLengthFunc::new(), + vec![ColumnarValue::Scalar(ScalarValue::Binary($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_scalar_function!( + SparkLengthFunc::new(), + vec![ColumnarValue::Scalar(ScalarValue::LargeBinary($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_scalar_function!( + SparkLengthFunc::new(), + vec![ColumnarValue::Scalar(ScalarValue::BinaryView($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + }; + } + + #[test] + fn test_functions() -> Result<()> { + test_spark_length_string!(Some(String::from("chars")), Ok(Some(5))); + test_spark_length_string!(Some(String::from("josé")), Ok(Some(4))); + // test long strings (more than 12 bytes for StringView) + test_spark_length_string!(Some(String::from("joséjoséjoséjosé")), Ok(Some(16))); + test_spark_length_string!(Some(String::from("")), Ok(Some(0))); + test_spark_length_string!(None, Ok(None)); + + test_spark_length_binary!(Some(String::from("chars").into_bytes()), Ok(Some(5))); + test_spark_length_binary!(Some(String::from("josé").into_bytes()), Ok(Some(5))); + // test long strings (more than 12 bytes for BinaryView) + test_spark_length_binary!( + Some(String::from("joséjoséjoséjosé").into_bytes()), + Ok(Some(20)) + ); + test_spark_length_binary!(Some(String::from("").into_bytes()), Ok(Some(0))); + test_spark_length_binary!(None, Ok(None)); + + Ok(()) + } +} diff --git a/datafusion/spark/src/function/string/like.rs b/datafusion/spark/src/function/string/like.rs new file mode 100644 index 0000000000000..df8eaef7cecbc --- /dev/null +++ b/datafusion/spark/src/function/string/like.rs @@ -0,0 +1,178 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::compute::like; +use arrow::datatypes::DataType; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_functions::utils::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + +/// LIKE function for case-sensitive pattern matching +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkLike { + signature: Signature, +} + +impl Default for SparkLike { + fn default() -> Self { + Self::new() + } +} + +impl SparkLike { + pub fn new() -> Self { + Self { + signature: Signature::string(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkLike { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "like" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_like, vec![])(&args.args) + } +} + +/// Returns true if str matches pattern (case sensitive). +pub fn spark_like(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("like function requires exactly 2 arguments"); + } + + let result = like(&args[0], &args[1])?; + Ok(Arc::new(result)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::function::utils::test::test_scalar_function; + use arrow::array::{Array, BooleanArray}; + use arrow::datatypes::DataType::Boolean; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + macro_rules! test_like_string_invoke { + ($INPUT1:expr, $INPUT2:expr, $EXPECTED:expr) => { + test_scalar_function!( + SparkLike::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT1)), + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT2)) + ], + $EXPECTED, + bool, + Boolean, + BooleanArray + ); + + test_scalar_function!( + SparkLike::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT1)), + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT2)) + ], + $EXPECTED, + bool, + Boolean, + BooleanArray + ); + + test_scalar_function!( + SparkLike::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT1)), + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT2)) + ], + $EXPECTED, + bool, + Boolean, + BooleanArray + ); + }; + } + + #[test] + fn test_like_invoke() -> Result<()> { + test_like_string_invoke!( + Some(String::from("Spark")), + Some(String::from("_park")), + Ok(Some(true)) + ); + test_like_string_invoke!( + Some(String::from("Spark")), + Some(String::from("_PARK")), + Ok(Some(false)) // case-sensitive + ); + test_like_string_invoke!( + Some(String::from("SPARK")), + Some(String::from("_park")), + Ok(Some(false)) // case-sensitive + ); + test_like_string_invoke!( + Some(String::from("Spark")), + Some(String::from("Sp%")), + Ok(Some(true)) + ); + test_like_string_invoke!( + Some(String::from("Spark")), + Some(String::from("SP%")), + Ok(Some(false)) // case-sensitive + ); + test_like_string_invoke!( + Some(String::from("Spark")), + Some(String::from("%ark")), + Ok(Some(true)) + ); + test_like_string_invoke!( + Some(String::from("Spark")), + Some(String::from("%ARK")), + Ok(Some(false)) // case-sensitive + ); + test_like_string_invoke!( + Some(String::from("Spark")), + Some(String::from("xyz")), + Ok(Some(false)) + ); + test_like_string_invoke!(None, Some(String::from("_park")), Ok(None)); + test_like_string_invoke!(Some(String::from("Spark")), None, Ok(None)); + test_like_string_invoke!(None, None, Ok(None)); + + Ok(()) + } +} diff --git a/datafusion/spark/src/function/string/luhn_check.rs b/datafusion/spark/src/function/string/luhn_check.rs new file mode 100644 index 0000000000000..090b16e34b8f1 --- /dev/null +++ b/datafusion/spark/src/function/string/luhn_check.rs @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::array::{Array, AsArray, BooleanArray}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Boolean; +use datafusion_common::utils::take_function_args; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; + +/// Spark-compatible `luhn_check` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkLuhnCheck { + signature: Signature, +} + +impl Default for SparkLuhnCheck { + fn default() -> Self { + Self::new() + } +} + +impl SparkLuhnCheck { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Utf8]), + TypeSignature::Exact(vec![DataType::Utf8View]), + TypeSignature::Exact(vec![DataType::LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkLuhnCheck { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "luhn_check" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [array] = take_function_args(self.name(), &args.args)?; + + match array { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8View => { + let str_array = array.as_string_view(); + let values = str_array + .iter() + .map(|s| s.map(luhn_check_impl)) + .collect::(); + Ok(ColumnarValue::Array(Arc::new(values))) + } + DataType::Utf8 => { + let str_array = array.as_string::(); + let values = str_array + .iter() + .map(|s| s.map(luhn_check_impl)) + .collect::(); + Ok(ColumnarValue::Array(Arc::new(values))) + } + DataType::LargeUtf8 => { + let str_array = array.as_string::(); + let values = str_array + .iter() + .map(|s| s.map(luhn_check_impl)) + .collect::(); + Ok(ColumnarValue::Array(Arc::new(values))) + } + other => { + exec_err!("Unsupported data type {other:?} for function `luhn_check`") + } + }, + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) + | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) => Ok( + ColumnarValue::Scalar(ScalarValue::Boolean(Some(luhn_check_impl(s)))), + ), + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))) + } + other => { + exec_err!("Unsupported data type {other:?} for function `luhn_check`") + } + } + } +} + +/// Validates a string using the Luhn algorithm. +/// Returns `true` if the input is a valid Luhn number. +fn luhn_check_impl(input: &str) -> bool { + let mut sum = 0u32; + let mut alt = false; + let mut digits_processed = 0; + + for b in input.as_bytes().iter().rev() { + let digit = match b { + b'0'..=b'9' => { + digits_processed += 1; + b - b'0' + } + _ => return false, + }; + + let mut val = digit as u32; + if alt { + val *= 2; + if val > 9 { + val -= 9; + } + } + sum += val; + alt = !alt; + } + + digits_processed > 0 && sum.is_multiple_of(10) +} diff --git a/datafusion/spark/src/function/string/mod.rs b/datafusion/spark/src/function/string/mod.rs new file mode 100644 index 0000000000000..3115c1e960fa8 --- /dev/null +++ b/datafusion/spark/src/function/string/mod.rs @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod ascii; +pub mod char; +pub mod elt; +pub mod format_string; +pub mod ilike; +pub mod length; +pub mod like; +pub mod luhn_check; + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +make_udf_function!(ascii::SparkAscii, ascii); +make_udf_function!(char::CharFunc, char); +make_udf_function!(ilike::SparkILike, ilike); +make_udf_function!(length::SparkLengthFunc, length); +make_udf_function!(elt::SparkElt, elt); +make_udf_function!(like::SparkLike, like); +make_udf_function!(luhn_check::SparkLuhnCheck, luhn_check); +make_udf_function!(format_string::FormatStringFunc, format_string); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!(( + ascii, + "Returns the ASCII code point of the first character of string.", + arg1 + )); + export_functions!(( + char, + "Returns the ASCII character having the binary equivalent to col. If col is larger than 256 the result is equivalent to char(col % 256).", + arg1 + )); + export_functions!(( + elt, + "Returns the n-th input (1-indexed), e.g. returns 2nd input when n is 2. The function returns NULL if the index is 0 or exceeds the length of the array.", + select_col arg1 arg2 argn + )); + export_functions!(( + ilike, + "Returns true if str matches pattern (case insensitive).", + str pattern + )); + export_functions!(( + length, + "Returns the character length of string data or number of bytes of binary data. The length of string data includes the trailing spaces. The length of binary data includes binary zeros.", + arg1 + )); + export_functions!(( + like, + "Returns true if str matches pattern (case sensitive).", + str pattern + )); + export_functions!(( + luhn_check, + "Returns whether the input string of digits is valid according to the Luhn algorithm.", + arg1 + )); + export_functions!(( + format_string, + "Returns a formatted string from printf-style format strings.", + strfmt args + )); +} + +pub fn functions() -> Vec> { + vec![ + ascii(), + char(), + elt(), + ilike(), + length(), + like(), + luhn_check(), + format_string(), + ] +} diff --git a/datafusion/spark/src/function/struct/mod.rs b/datafusion/spark/src/function/struct/mod.rs new file mode 100644 index 0000000000000..a87df9a2c87a0 --- /dev/null +++ b/datafusion/spark/src/function/struct/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/table/mod.rs b/datafusion/spark/src/function/table/mod.rs new file mode 100644 index 0000000000000..aba7b7ceb78ea --- /dev/null +++ b/datafusion/spark/src/function/table/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_catalog::TableFunction; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/url/mod.rs b/datafusion/spark/src/function/url/mod.rs new file mode 100644 index 0000000000000..82bf8a9e09616 --- /dev/null +++ b/datafusion/spark/src/function/url/mod.rs @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +pub mod parse_url; +pub mod try_parse_url; + +make_udf_function!(parse_url::ParseUrl, parse_url); +make_udf_function!(try_parse_url::TryParseUrl, try_parse_url); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!(( + parse_url, + "Extracts a part from a URL, throwing an error if an invalid URL is provided.", + args + )); + export_functions!(( + try_parse_url, + "Same as parse_url but returns NULL if an invalid URL is provided.", + args + )); +} + +pub fn functions() -> Vec> { + vec![parse_url(), try_parse_url()] +} diff --git a/datafusion/spark/src/function/url/parse_url.rs b/datafusion/spark/src/function/url/parse_url.rs new file mode 100644 index 0000000000000..d93c260b4f340 --- /dev/null +++ b/datafusion/spark/src/function/url/parse_url.rs @@ -0,0 +1,432 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, GenericStringBuilder, LargeStringArray, StringArray, + StringArrayType, StringViewArray, +}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{ + as_large_string_array, as_string_array, as_string_view_array, +}; +use datafusion_common::{exec_datafusion_err, exec_err, Result}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use url::{ParseError, Url}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ParseUrl { + signature: Signature, +} + +impl Default for ParseUrl { + fn default() -> Self { + Self::new() + } +} + +impl ParseUrl { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![TypeSignature::String(2), TypeSignature::String(3)], + Volatility::Immutable, + ), + } + } + /// Parses a URL and extracts the specified component. + /// + /// This function takes a URL string and extracts different parts of it based on the + /// `part` parameter. For query parameters, an optional `key` can be specified to + /// extract a specific query parameter value. + /// + /// # Arguments + /// + /// * `value` - The URL string to parse + /// * `part` - The component of the URL to extract. Valid values are: + /// - `"HOST"` - The hostname (e.g., "example.com") + /// - `"PATH"` - The path portion (e.g., "/path/to/resource") + /// - `"QUERY"` - The query string or a specific query parameter + /// - `"REF"` - The fragment/anchor (the part after #) + /// - `"PROTOCOL"` - The URL scheme (e.g., "https", "http") + /// - `"FILE"` - The path with query string (e.g., "/path?query=value") + /// - `"AUTHORITY"` - The authority component (host:port) + /// - `"USERINFO"` - The user information (username:password) + /// * `key` - Optional parameter used only with `"QUERY"`. When provided, extracts + /// the value of the specific query parameter with this key name. + /// + /// # Returns + /// + /// * `Ok(Some(String))` - The extracted URL component as a string + /// * `Ok(None)` - If the requested component doesn't exist or is empty + /// * `Err(DataFusionError)` - If the URL is malformed and cannot be parsed + /// + fn parse(value: &str, part: &str, key: Option<&str>) -> Result> { + let url: std::result::Result = Url::parse(value); + if let Err(ParseError::RelativeUrlWithoutBase) = url { + return if !value.contains("://") { + Ok(None) + } else { + Err(exec_datafusion_err!("The url is invalid: {value}. Use `try_parse_url` to tolerate invalid URL and return NULL instead. SQLSTATE: 22P02")) + }; + }; + url.map_err(|e| exec_datafusion_err!("{e:?}")) + .map(|url| match part { + "HOST" => url.host_str().map(String::from), + "PATH" => { + let path: String = url.path().to_string(); + let path: String = if path == "/" { "".to_string() } else { path }; + Some(path) + } + "QUERY" => match key { + None => url.query().map(String::from), + Some(key) => url + .query_pairs() + .find(|(k, _)| k == key) + .map(|(_, v)| v.into_owned()), + }, + "REF" => url.fragment().map(String::from), + "PROTOCOL" => Some(url.scheme().to_string()), + "FILE" => { + let path = url.path(); + match url.query() { + Some(query) => Some(format!("{path}?{query}")), + None => Some(path.to_string()), + } + } + "AUTHORITY" => Some(url.authority().to_string()), + "USERINFO" => { + let username = url.username(); + if username.is_empty() { + return None; + } + match url.password() { + Some(password) => Some(format!("{username}:{password}")), + None => Some(username.to_string()), + } + } + _ => None, + }) + } +} + +impl ScalarUDFImpl for ParseUrl { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "parse_url" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + make_scalar_function(spark_parse_url, vec![])(&args) + } +} + +/// Core implementation of URL parsing function. +/// +/// # Arguments +/// +/// * `args` - A slice of ArrayRef containing the input arrays: +/// - `args[0]` - URL array: The URLs to parse +/// - `args[1]` - Part array: The URL components to extract (HOST, PATH, QUERY, etc.) +/// - `args[2]` - Key array (optional): For QUERY part, the specific parameter names to extract +/// +/// # Return Value +/// +/// Returns `Result` containing: +/// - A string array with extracted URL components +/// - `None` values where extraction failed or component doesn't exist +/// - The output array type (StringArray or LargeStringArray) is determined by input types +/// +fn spark_parse_url(args: &[ArrayRef]) -> Result { + spark_handled_parse_url(args, |x| x) +} + +pub fn spark_handled_parse_url( + args: &[ArrayRef], + handler_err: impl Fn(Result>) -> Result>, +) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!( + "{} expects 2 or 3 arguments, but got {}", + "`parse_url`", + args.len() + ); + } + // Required arguments + let url = &args[0]; + let part = &args[1]; + + let result = if args.len() == 3 { + // In this case, the 'key' argument is passed + let key = &args[2]; + + match (url.data_type(), part.data_type(), key.data_type()) { + (DataType::Utf8, DataType::Utf8, DataType::Utf8) => { + process_parse_url::<_, _, _, StringArray>( + as_string_array(url)?, + as_string_array(part)?, + as_string_array(key)?, + handler_err, + ) + } + (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => { + process_parse_url::<_, _, _, StringViewArray>( + as_string_view_array(url)?, + as_string_view_array(part)?, + as_string_view_array(key)?, + handler_err, + ) + } + (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => { + process_parse_url::<_, _, _, LargeStringArray>( + as_large_string_array(url)?, + as_large_string_array(part)?, + as_large_string_array(key)?, + handler_err, + ) + } + _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args), + } + } else { + // The 'key' argument is omitted, assume all values are null + // Create 'null' string array for 'key' argument + let mut builder: GenericStringBuilder = GenericStringBuilder::new(); + for _ in 0..args[0].len() { + builder.append_null(); + } + let key = builder.finish(); + + match (url.data_type(), part.data_type()) { + (DataType::Utf8, DataType::Utf8) => { + process_parse_url::<_, _, _, StringArray>( + as_string_array(url)?, + as_string_array(part)?, + &key, + handler_err, + ) + } + (DataType::Utf8View, DataType::Utf8View) => { + process_parse_url::<_, _, _, StringViewArray>( + as_string_view_array(url)?, + as_string_view_array(part)?, + &key, + handler_err, + ) + } + (DataType::LargeUtf8, DataType::LargeUtf8) => { + process_parse_url::<_, _, _, LargeStringArray>( + as_large_string_array(url)?, + as_large_string_array(part)?, + &key, + handler_err, + ) + } + _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args), + } + }; + result +} + +fn process_parse_url<'a, A, B, C, T>( + url_array: &'a A, + part_array: &'a B, + key_array: &'a C, + handle: impl Fn(Result>) -> Result>, +) -> Result +where + &'a A: StringArrayType<'a>, + &'a B: StringArrayType<'a>, + &'a C: StringArrayType<'a>, + T: Array + FromIterator> + 'static, +{ + url_array + .iter() + .zip(part_array.iter()) + .zip(key_array.iter()) + .map(|((url, part), key)| { + if let (Some(url), Some(part), key) = (url, part, key) { + handle(ParseUrl::parse(url, part, key)) + } else { + Ok(None) + } + }) + .collect::>() + .map(|array| Arc::new(array) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ArrayRef, Int32Array, StringArray}; + use datafusion_common::Result; + use std::array::from_ref; + use std::sync::Arc; + + fn sa(vals: &[Option<&str>]) -> ArrayRef { + Arc::new(StringArray::from(vals.to_vec())) as ArrayRef + } + + #[test] + fn test_parse_host() -> Result<()> { + let got = ParseUrl::parse("https://example.com/a?x=1", "HOST", None)?; + assert_eq!(got, Some("example.com".to_string())); + Ok(()) + } + + #[test] + fn test_parse_query_no_key_vs_with_key() -> Result<()> { + let got_all = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", None)?; + assert_eq!(got_all, Some("a=1&b=2".to_string())); + + let got_a = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("a"))?; + assert_eq!(got_a, Some("1".to_string())); + + let got_c = ParseUrl::parse("https://ex.com/p?a=1&b=2", "QUERY", Some("c"))?; + assert_eq!(got_c, None); + Ok(()) + } + + #[test] + fn test_parse_ref_protocol_userinfo_file_authority() -> Result<()> { + let url = "ftp://user:pwd@ftp.example.com:21/files?x=1#frag"; + assert_eq!(ParseUrl::parse(url, "REF", None)?, Some("frag".to_string())); + assert_eq!( + ParseUrl::parse(url, "PROTOCOL", None)?, + Some("ftp".to_string()) + ); + assert_eq!( + ParseUrl::parse(url, "USERINFO", None)?, + Some("user:pwd".to_string()) + ); + assert_eq!( + ParseUrl::parse(url, "FILE", None)?, + Some("/files?x=1".to_string()) + ); + assert_eq!( + ParseUrl::parse(url, "AUTHORITY", None)?, + Some("user:pwd@ftp.example.com".to_string()) + ); + Ok(()) + } + + #[test] + fn test_parse_path_root_is_empty_string() -> Result<()> { + let got = ParseUrl::parse("https://example.com/", "PATH", None)?; + assert_eq!(got, Some("".to_string())); + Ok(()) + } + + #[test] + fn test_parse_malformed_url_returns_error() -> Result<()> { + let got = ParseUrl::parse("notaurl", "HOST", None)?; + assert_eq!(got, None); + Ok(()) + } + + #[test] + fn test_spark_utf8_two_args() -> Result<()> { + let urls = sa(&[Some("https://example.com/a?x=1"), Some("https://ex.com/")]); + let parts = sa(&[Some("HOST"), Some("PATH")]); + + let out = spark_handled_parse_url(&[urls, parts], |x| x)?; + let out_sa = out.as_any().downcast_ref::().unwrap(); + + assert_eq!(out_sa.len(), 2); + assert_eq!(out_sa.value(0), "example.com"); + assert_eq!(out_sa.value(1), ""); + Ok(()) + } + + #[test] + fn test_spark_utf8_three_args_query_key() -> Result<()> { + let urls = sa(&[ + Some("https://example.com/a?x=1&y=2"), + Some("https://ex.com/?a=1"), + ]); + let parts = sa(&[Some("QUERY"), Some("QUERY")]); + let keys = sa(&[Some("y"), Some("b")]); + + let out = spark_handled_parse_url(&[urls, parts, keys], |x| x)?; + let out_sa = out.as_any().downcast_ref::().unwrap(); + + assert_eq!(out_sa.len(), 2); + assert_eq!(out_sa.value(0), "2"); + assert!(out_sa.is_null(1)); + Ok(()) + } + + #[test] + fn test_spark_userinfo_and_nulls() -> Result<()> { + let urls = sa(&[ + Some("ftp://user:pwd@ftp.example.com:21/files"), + Some("https://example.com"), + None, + ]); + let parts = sa(&[Some("USERINFO"), Some("USERINFO"), Some("USERINFO")]); + + let out = spark_handled_parse_url(&[urls, parts], |x| x)?; + let out_sa = out.as_any().downcast_ref::().unwrap(); + + assert_eq!(out_sa.len(), 3); + assert_eq!(out_sa.value(0), "user:pwd"); + assert!(out_sa.is_null(1)); + assert!(out_sa.is_null(2)); + Ok(()) + } + + #[test] + fn test_invalid_arg_count() { + let urls = sa(&[Some("https://example.com")]); + let err = spark_handled_parse_url(from_ref(&urls), |x| x).unwrap_err(); + assert!(format!("{err}").contains("expects 2 or 3 arguments")); + + let parts = sa(&[Some("HOST")]); + let keys = sa(&[Some("x")]); + let err = + spark_handled_parse_url(&[urls, parts, keys, sa(&[Some("extra")])], |x| x) + .unwrap_err(); + assert!(format!("{err}").contains("expects 2 or 3 arguments")); + } + + #[test] + fn test_non_string_types_error() { + let urls = sa(&[Some("https://example.com")]); + let bad_part = Arc::new(Int32Array::from(vec![1])) as ArrayRef; + + let err = spark_handled_parse_url(&[urls, bad_part], |x| x).unwrap_err(); + let msg = format!("{err}"); + assert!(msg.contains("expects STRING arguments")); + } +} diff --git a/datafusion/spark/src/function/url/try_parse_url.rs b/datafusion/spark/src/function/url/try_parse_url.rs new file mode 100644 index 0000000000000..c04850f3a6bf0 --- /dev/null +++ b/datafusion/spark/src/function/url/try_parse_url.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; + +use crate::function::url::parse_url::{spark_handled_parse_url, ParseUrl}; +use arrow::array::ArrayRef; +use arrow::datatypes::DataType; +use datafusion_common::Result; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +/// TRY_PARSE_URL function for tolerant URL component extraction (never errors; returns NULL on invalid or missing parts). +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct TryParseUrl { + signature: Signature, +} + +impl Default for TryParseUrl { + fn default() -> Self { + Self::new() + } +} + +impl TryParseUrl { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![TypeSignature::String(2), TypeSignature::String(3)], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TryParseUrl { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "try_parse_url" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let parse_url: ParseUrl = ParseUrl::new(); + parse_url.return_type(arg_types) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + make_scalar_function(spark_try_parse_url, vec![])(&args) + } +} + +fn spark_try_parse_url(args: &[ArrayRef]) -> Result { + spark_handled_parse_url(args, |x| match x { + Err(_) => Ok(None), + result => result, + }) +} diff --git a/datafusion/spark/src/function/utils.rs b/datafusion/spark/src/function/utils.rs new file mode 100644 index 0000000000000..e272d91d8a70e --- /dev/null +++ b/datafusion/spark/src/function/utils.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +pub mod test { + /// $FUNC ScalarUDFImpl to test + /// $ARGS arguments (vec) to pass to function + /// $EXPECTED a Result + /// $EXPECTED_TYPE is the expected value type + /// $EXPECTED_DATA_TYPE is the expected result type + /// $ARRAY_TYPE is the column type after function applied + /// $CONFIG_OPTIONS config options to pass to function + macro_rules! test_scalar_function { + ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident, $CONFIG_OPTIONS:expr) => { + let expected: datafusion_common::Result> = $EXPECTED; + let func = $FUNC; + + let arg_fields: Vec = $ARGS + .iter() + .enumerate() + .map(|(idx, arg)| { + + let nullable = match arg { + datafusion_expr::ColumnarValue::Scalar(scalar) => scalar.is_null(), + datafusion_expr::ColumnarValue::Array(a) => a.null_count() > 0, + }; + + std::sync::Arc::new(arrow::datatypes::Field::new(format!("arg_{idx}"), arg.data_type(), nullable)) + }) + .collect::>(); + + let cardinality = $ARGS + .iter() + .fold(Option::::None, |acc, arg| match arg { + datafusion_expr::ColumnarValue::Scalar(_) => acc, + datafusion_expr::ColumnarValue::Array(a) => Some(a.len()), + }) + .unwrap_or(1); + + let scalar_arguments = $ARGS.iter().map(|arg| match arg { + datafusion_expr::ColumnarValue::Scalar(scalar) => Some(scalar.clone()), + datafusion_expr::ColumnarValue::Array(_) => None, + }).collect::>(); + let scalar_arguments_refs = scalar_arguments.iter().map(|arg| arg.as_ref()).collect::>(); + + + let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_arguments_refs + }); + + match expected { + Ok(expected) => { + if let Ok(return_field) = return_field { + assert_eq!(return_field.data_type(), &$EXPECTED_DATA_TYPE); + + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{ + args: $ARGS, + number_rows: cardinality, + return_field, + arg_fields: arg_fields.clone(), + config_options: $CONFIG_OPTIONS, + }) { + Ok(col_value) => { + match col_value.to_array(cardinality) { + Ok(array) => { + let result = array + .as_any() + .downcast_ref::<$ARRAY_TYPE>() + .expect("Failed to convert to type"); + assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE); + + // value is correct + match expected { + Some(v) => assert_eq!(result.value(0), v), + None => assert!(result.is_null(0)), + }; + } + Err(err) => { + panic!("Failed to convert to array: {err}"); + } + } + } + Err(err) => { + panic!("function returned an error: {err}"); + } + } + } else { + panic!("Expected return_field to be Ok but got Err"); + } + } + Err(expected_error) => { + if let Err(error) = &return_field { + datafusion_common::assert_contains!( + expected_error.strip_backtrace(), + error.strip_backtrace() + ); + } else if let Ok(value) = return_field { + // invoke is expected error - cannot use .expect_err() due to Debug not being implemented + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs { + args: $ARGS, + number_rows: cardinality, + return_field: value, + arg_fields, + config_options: $CONFIG_OPTIONS, + }) { + Ok(_) => assert!(false, "expected error"), + Err(error) => { + assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); + } + } + } + } + }; + }; + + ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => { + test_scalar_function!( + $FUNC, + $ARGS, + $EXPECTED, + $EXPECTED_TYPE, + $EXPECTED_DATA_TYPE, + $ARRAY_TYPE, + std::sync::Arc::new(datafusion_common::config::ConfigOptions::default()) + ) + }; + } + + pub(crate) use test_scalar_function; +} diff --git a/datafusion/spark/src/function/window/mod.rs b/datafusion/spark/src/function/window/mod.rs new file mode 100644 index 0000000000000..97ab4a9e35422 --- /dev/null +++ b/datafusion/spark/src/function/window/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::WindowUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/xml/mod.rs b/datafusion/spark/src/function/xml/mod.rs new file mode 100644 index 0000000000000..a87df9a2c87a0 --- /dev/null +++ b/datafusion/spark/src/function/xml/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/lib.rs b/datafusion/spark/src/lib.rs new file mode 100644 index 0000000000000..4d45f3c482af3 --- /dev/null +++ b/datafusion/spark/src/lib.rs @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#![doc( + html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", + html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" +)] +#![cfg_attr(docsrs, feature(doc_cfg))] +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] + +//! Spark Expression packages for [DataFusion]. +//! +//! This crate contains a collection of various Spark function packages for DataFusion, +//! implemented using the extension API. +//! +//! [DataFusion]: https://crates.io/crates/datafusion +//! +//! +//! # Available Function Packages +//! See the list of [modules](#modules) in this crate for available packages. +//! +//! # Example: using all function packages +//! +//! You can register all the functions in all packages using the [`register_all`] +//! function as shown below. Any existing functions will be overwritten, with these +//! Spark functions taking priority. +//! +//! ``` +//! # use datafusion_execution::FunctionRegistry; +//! # use datafusion_expr::{ScalarUDF, AggregateUDF, WindowUDF}; +//! # use datafusion_expr::planner::ExprPlanner; +//! # use datafusion_common::Result; +//! # use std::collections::HashSet; +//! # use std::sync::Arc; +//! # // Note: We can't use a real SessionContext here because the +//! # // `datafusion_spark` crate has no dependence on the DataFusion crate +//! # // thus use a dummy SessionContext that has enough of the implementation +//! # struct SessionContext {} +//! # impl FunctionRegistry for SessionContext { +//! # fn register_udf(&mut self, _udf: Arc) -> Result>> { Ok (None) } +//! # fn udfs(&self) -> HashSet { unimplemented!() } +//! # fn udafs(&self) -> HashSet { unimplemented!() } +//! # fn udwfs(&self) -> HashSet { unimplemented!() } +//! # fn udf(&self, _name: &str) -> Result> { unimplemented!() } +//! # fn udaf(&self, name: &str) -> Result> {unimplemented!() } +//! # fn udwf(&self, name: &str) -> Result> { unimplemented!() } +//! # fn expr_planners(&self) -> Vec> { unimplemented!() } +//! # } +//! # impl SessionContext { +//! # fn new() -> Self { SessionContext {} } +//! # async fn sql(&mut self, _query: &str) -> Result<()> { Ok(()) } +//! # } +//! # +//! # async fn stub() -> Result<()> { +//! // Create a new session context +//! let mut ctx = SessionContext::new(); +//! // Register all Spark functions with the context +//! datafusion_spark::register_all(&mut ctx)?; +//! // Run a query using the `sha2` function which is now available and has Spark semantics +//! let df = ctx.sql("SELECT sha2('The input String', 256)").await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Example: calling a specific function in Rust +//! +//! Each package also exports an `expr_fn` submodule that create [`Expr`]s for +//! invoking functions via rust using a fluent style. For example, to invoke the +//! `sha2` function, you can use the following code: +//! +//! ```rust +//! # use datafusion_expr::{col, lit}; +//! use datafusion_spark::expr_fn::sha2; +//! // Create the expression `sha2(my_data, 256)` +//! let expr = sha2(col("my_data"), lit(256)); +//!``` +//! +//![`Expr`]: datafusion_expr::Expr + +pub mod function; + +use datafusion_catalog::TableFunction; +use datafusion_common::Result; +use datafusion_execution::FunctionRegistry; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use log::debug; +use std::sync::Arc; + +/// Fluent-style API for creating `Expr`s +#[allow(unused)] +pub mod expr_fn { + pub use super::function::aggregate::expr_fn::*; + pub use super::function::array::expr_fn::*; + pub use super::function::bitmap::expr_fn::*; + pub use super::function::bitwise::expr_fn::*; + pub use super::function::collection::expr_fn::*; + pub use super::function::conditional::expr_fn::*; + pub use super::function::conversion::expr_fn::*; + pub use super::function::csv::expr_fn::*; + pub use super::function::datetime::expr_fn::*; + pub use super::function::generator::expr_fn::*; + pub use super::function::hash::expr_fn::*; + pub use super::function::json::expr_fn::*; + pub use super::function::lambda::expr_fn::*; + pub use super::function::map::expr_fn::*; + pub use super::function::math::expr_fn::*; + pub use super::function::misc::expr_fn::*; + pub use super::function::predicate::expr_fn::*; + pub use super::function::r#struct::expr_fn::*; + pub use super::function::string::expr_fn::*; + pub use super::function::table::expr_fn::*; + pub use super::function::url::expr_fn::*; + pub use super::function::window::expr_fn::*; + pub use super::function::xml::expr_fn::*; +} + +/// Returns all default scalar functions +pub fn all_default_scalar_functions() -> Vec> { + function::array::functions() + .into_iter() + .chain(function::bitmap::functions()) + .chain(function::bitwise::functions()) + .chain(function::collection::functions()) + .chain(function::conditional::functions()) + .chain(function::conversion::functions()) + .chain(function::csv::functions()) + .chain(function::datetime::functions()) + .chain(function::generator::functions()) + .chain(function::hash::functions()) + .chain(function::json::functions()) + .chain(function::lambda::functions()) + .chain(function::map::functions()) + .chain(function::math::functions()) + .chain(function::misc::functions()) + .chain(function::predicate::functions()) + .chain(function::string::functions()) + .chain(function::r#struct::functions()) + .chain(function::url::functions()) + .chain(function::xml::functions()) + .collect::>() +} + +/// Returns all default aggregate functions +pub fn all_default_aggregate_functions() -> Vec> { + function::aggregate::functions() +} + +/// Returns all default window functions +pub fn all_default_window_functions() -> Vec> { + function::window::functions() +} + +/// Returns all default table functions +pub fn all_default_table_functions() -> Vec> { + function::table::functions() +} + +/// Registers all enabled packages with a [`FunctionRegistry`], overriding any existing +/// functions if there is a name clash. +pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { + let scalar_functions: Vec> = all_default_scalar_functions(); + scalar_functions.into_iter().try_for_each(|udf| { + let existing_udf = registry.register_udf(udf)?; + if let Some(existing_udf) = existing_udf { + debug!("Overwrite existing UDF: {}", existing_udf.name()); + } + Ok(()) as Result<()> + })?; + + let aggregate_functions: Vec> = all_default_aggregate_functions(); + aggregate_functions.into_iter().try_for_each(|udf| { + let existing_udaf = registry.register_udaf(udf)?; + if let Some(existing_udaf) = existing_udaf { + debug!("Overwrite existing UDAF: {}", existing_udaf.name()); + } + Ok(()) as Result<()> + })?; + + let window_functions: Vec> = all_default_window_functions(); + window_functions.into_iter().try_for_each(|udf| { + let existing_udwf = registry.register_udwf(udf)?; + if let Some(existing_udwf) = existing_udwf { + debug!("Overwrite existing UDWF: {}", existing_udwf.name()); + } + Ok(()) as Result<()> + })?; + + Ok(()) +} diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index b778db46769d0..ea2cd6dfcc7d8 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -43,11 +43,16 @@ unicode_expressions = [] unparser = [] recursive_protection = ["dep:recursive"] +# Note the sql planner should not depend directly on the datafusion-function packages +# so that it can be used in a standalone manner with other function implementations. +# +# They are used for testing purposes only, so they are in the dev-dependencies section. [dependencies] arrow = { workspace = true } bigdecimal = { workspace = true } -datafusion-common = { workspace = true, default-features = true } -datafusion-expr = { workspace = true } +chrono = { workspace = true } +datafusion-common = { workspace = true, features = ["sql"] } +datafusion-expr = { workspace = true, features = ["sql"] } indexmap = { workspace = true } log = { workspace = true } recursive = { workspace = true, optional = true } @@ -56,11 +61,13 @@ sqlparser = { workspace = true } [dev-dependencies] ctor = { workspace = true } +# please do not move these dependencies to the main dependencies section datafusion-functions = { workspace = true, default-features = true } datafusion-functions-aggregate = { workspace = true } -datafusion-functions-nested = { workspace = true } +datafusion-functions-nested = { workspace = true, features = ["sql"] } datafusion-functions-window = { workspace = true } env_logger = { workspace = true } insta = { workspace = true } +itertools = { workspace = true } paste = "^1.0" rstest = { workspace = true } diff --git a/datafusion/sql/README.md b/datafusion/sql/README.md index 98f3c4faa2ec0..d0e5e498e514c 100644 --- a/datafusion/sql/README.md +++ b/datafusion/sql/README.md @@ -17,17 +17,24 @@ under the License. --> -# DataFusion SQL Query Planner +# Apache DataFusion SQL Query Planner This crate provides a general purpose SQL query planner that can parse SQL and translate queries into logical -plans. Although this crate is used by the [DataFusion][df] query engine, it was designed to be easily usable from any +plans. Although this crate is used by the [Apache DataFusion] query engine, it was designed to be easily usable from any project that requires a SQL query planner and does not make any assumptions about how the resulting logical plan will be translated to a physical plan. For example, there is no concept of row-based versus columnar execution in the logical plan. +Note that the [`datafusion`] crate re-exports this module. If you are already +using the [`datafusion`] crate in your project, there is no reason to use this +crate directly in your project as well. + +[apache datafusion]: https://datafusion.apache.org/ +[`datafusion`]: https://crates.io/crates/datafusion + ## Example Usage -See the [examples](examples) directory for fully working examples. +See the [examples] directory for fully working examples. Here is an example of producing a logical plan from a SQL string. @@ -62,8 +69,8 @@ fn main() { ``` This is the logical plan that is produced from this example. Note that this is an **unoptimized** -logical plan. The [datafusion-optimizer](https://crates.io/crates/datafusion-optimizer) crate provides a query -optimizer that can be applied to plans produced by this crate. +logical plan. The [datafusion-optimizer] crate provides a query optimizer that can be applied to +plans produced by this crate. ``` Sort: state_tax DESC NULLS FIRST @@ -80,4 +87,5 @@ Sort: state_tax DESC NULLS FIRST TableScan: orders ``` -[df]: https://crates.io/crates/datafusion +[examples]: examples +[datafusion-optimizer]: https://crates.io/crates/datafusion-optimizer diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index 3650aea9c3c20..aceec676761cb 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -19,7 +19,6 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use arrow::datatypes::Schema; use datafusion_common::{ not_impl_err, plan_err, tree_node::{TreeNode, TreeNodeRecursion}, @@ -135,10 +134,9 @@ impl SqlToRel<'_, S> { // ---------- Step 2: Create a temporary relation ------------------ // Step 2.1: Create a table source for the temporary relation - let work_table_source = self.context_provider.create_cte_work_table( - &cte_name, - Arc::new(Schema::from(static_plan.schema().as_ref())), - )?; + let work_table_source = self + .context_provider + .create_cte_work_table(&cte_name, Arc::clone(static_plan.schema().inner()))?; // Step 2.2: Create a temporary relation logical plan that will be used // as the input to the recursive term diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 436f4388d8a31..eabf645a5eafd 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -22,15 +22,15 @@ use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Diagnostic, Result, Span, }; -use datafusion_expr::expr::{ScalarFunction, Unnest, WildcardOptions}; -use datafusion_expr::planner::{PlannerResult, RawAggregateExpr, RawWindowExpr}; -use datafusion_expr::{ - expr, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFunctionDefinition, +use datafusion_expr::expr::{ + NullTreatment, ScalarFunction, Unnest, WildcardOptions, WindowFunction, }; +use datafusion_expr::planner::{PlannerResult, RawAggregateExpr, RawWindowExpr}; +use datafusion_expr::{expr, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition}; use sqlparser::ast::{ DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, - NullTreatment, ObjectName, OrderByExpr, Spanned, WindowType, + ObjectName, OrderByExpr, Spanned, WindowType, }; /// Suggest a valid function based on an invalid input function name @@ -74,7 +74,7 @@ fn find_closest_match(candidates: Vec, target: &str) -> Option { }) } -/// Arguments to for a function call extracted from the SQL AST +/// Arguments for a function call extracted from the SQL AST #[derive(Debug)] struct FunctionArgs { /// Function name @@ -91,6 +91,10 @@ struct FunctionArgs { null_treatment: Option, /// DISTINCT distinct: bool, + /// WITHIN GROUP clause, if any + within_group: Vec, + /// Was the function called without parenthesis, i.e. could this also be a column reference? + function_without_parentheses: bool, } impl FunctionArgs { @@ -113,8 +117,10 @@ impl FunctionArgs { order_by: vec![], over, filter, - null_treatment, + null_treatment: null_treatment.map(|v| v.into()), distinct: false, + within_group, + function_without_parentheses: matches!(args, FunctionArguments::None), }); }; @@ -144,6 +150,9 @@ impl FunctionArgs { } FunctionArgumentClause::OrderBy(oby) => { if order_by.is_some() { + if !within_group.is_empty() { + return plan_err!("ORDER BY clause is only permitted in WITHIN GROUP clause when a WITHIN GROUP is used"); + } return not_impl_err!("Calling {name}: Duplicated ORDER BY clause in function arguments"); } order_by = Some(oby); @@ -173,11 +182,18 @@ impl FunctionArgs { "Calling {name}: JSON NULL clause not supported in function arguments: {jn}" ) } + FunctionArgumentClause::JsonReturningClause(jr) => { + return not_impl_err!( + "Calling {name}: JSON RETURNING clause not supported in function arguments: {jr}" + ) + }, } } - if !within_group.is_empty() { - return not_impl_err!("WITHIN GROUP is not supported yet: {within_group:?}"); + if within_group.len() > 1 { + return not_impl_err!( + "Only a single ordering expression is permitted in a WITHIN GROUP clause" + ); } let order_by = order_by.unwrap_or_default(); @@ -188,8 +204,10 @@ impl FunctionArgs { order_by, over, filter, - null_treatment, + null_treatment: null_treatment.map(|v| v.into()), distinct, + within_group, + function_without_parentheses: false, }) } } @@ -203,31 +221,42 @@ impl SqlToRel<'_, S> { ) -> Result { let function_args = FunctionArgs::try_new(function)?; let FunctionArgs { - name, + name: object_name, args, order_by, over, filter, null_treatment, distinct, + within_group, + function_without_parentheses, } = function_args; + if over.is_some() && !within_group.is_empty() { + return plan_err!("OVER and WITHIN GROUP clause cannot be used together. \ + OVER is for window functions, whereas WITHIN GROUP is for ordered set aggregate functions"); + } + + if !order_by.is_empty() && !within_group.is_empty() { + return plan_err!("ORDER BY and WITHIN GROUP clauses cannot be used together in the same aggregate function"); + } + // If function is a window function (it has an OVER clause), // it shouldn't have ordering requirement as function argument // required ordering should be defined in OVER clause. let is_function_window = over.is_some(); - let sql_parser_span = name.0[0].span(); - let name = if name.0.len() > 1 { + let sql_parser_span = object_name.0[0].span(); + let name = if object_name.0.len() > 1 { // DF doesn't handle compound identifiers // (e.g. "foo.bar") for function names yet - name.to_string() + object_name.to_string() } else { - match name.0[0].as_ident() { + match object_name.0[0].as_ident() { Some(ident) => crate::utils::normalize_ident(ident.clone()), None => { return plan_err!( "Expected an identifier in function name, but found {:?}", - name.0[0] + object_name.0[0] ) } } @@ -246,7 +275,24 @@ impl SqlToRel<'_, S> { // User-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; - return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fm, args))); + let inner = ScalarFunction::new_udf(fm, args); + + if name.eq_ignore_ascii_case(inner.name()) { + return Ok(Expr::ScalarFunction(inner)); + } else { + // If the function is called by an alias, a verbose string representation is created + // (e.g., "my_alias(arg1, arg2)") and the expression is wrapped in an `Alias` + // to ensure the output column name matches the user's query. + let arg_names = inner + .args + .iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(","); + let verbose_alias = format!("{name}({arg_names})"); + + return Ok(Expr::ScalarFunction(inner).alias(verbose_alias)); + } } // Build Unnest expression @@ -321,13 +367,22 @@ impl SqlToRel<'_, S> { if let Ok(fun) = self.find_window_func(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; + + // Plan FILTER clause if present + let filter = filter + .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context)) + .transpose()? + .map(Box::new); + let mut window_expr = RawWindowExpr { func_def: fun, args, partition_by, order_by, window_frame, + filter, null_treatment, + distinct: function_args.distinct, }; for planner in self.context_provider.get_expr_planners().iter() { @@ -343,28 +398,89 @@ impl SqlToRel<'_, S> { partition_by, order_by, window_frame, + filter, null_treatment, + distinct, } = window_expr; - return Expr::WindowFunction(expr::WindowFunction::new(func_def, args)) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build(); + let inner = WindowFunction { + fun: func_def, + params: expr::WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + filter, + null_treatment, + distinct, + }, + }; + + if name.eq_ignore_ascii_case(inner.fun.name()) { + return Ok(Expr::WindowFunction(Box::new(inner))); + } else { + // If the function is called by an alias, a verbose string representation is created + // (e.g., "my_alias(arg1, arg2)") and the expression is wrapped in an `Alias` + // to ensure the output column name matches the user's query. + let arg_names = inner + .params + .args + .iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(","); + let verbose_alias = format!("{name}({arg_names})"); + + return Ok(Expr::WindowFunction(Box::new(inner)).alias(verbose_alias)); + } } } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { - let order_by = self.order_by_to_sort_expr( - order_by, - schema, - planner_context, - true, - None, - )?; - let order_by = (!order_by.is_empty()).then_some(order_by); - let args = self.function_args_to_expr(args, schema, planner_context)?; + if null_treatment.is_some() && !fm.supports_null_handling_clause() { + return plan_err!( + "[IGNORE | RESPECT] NULLS are not permitted for {}", + fm.name() + ); + } + + let mut args = + self.function_args_to_expr(args, schema, planner_context)?; + + let order_by = if fm.is_ordered_set_aggregate() { + let within_group = self.order_by_to_sort_expr( + within_group, + schema, + planner_context, + false, + None, + )?; + + // Add the WITHIN GROUP ordering expressions to the front of the argument list + // So function(arg) WITHIN GROUP (ORDER BY x) becomes function(x, arg) + if !within_group.is_empty() { + args = within_group + .iter() + .map(|sort| sort.expr.clone()) + .chain(args) + .collect::>(); + } + within_group + } else { + let order_by = if !order_by.is_empty() { + order_by + } else { + within_group + }; + self.order_by_to_sort_expr( + order_by, + schema, + planner_context, + true, + None, + )? + }; + let filter: Option> = filter .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context)) .transpose()? @@ -394,31 +510,69 @@ impl SqlToRel<'_, S> { null_treatment, } = aggregate_expr; - return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let inner = expr::AggregateFunction::new_udf( func, args, distinct, filter, order_by, null_treatment, - ))); + ); + + if name.eq_ignore_ascii_case(inner.func.name()) { + return Ok(Expr::AggregateFunction(inner)); + } else { + // If the function is called by an alias, a verbose string representation is created + // (e.g., "my_alias(arg1, arg2)") and the expression is wrapped in an `Alias` + // to ensure the output column name matches the user's query. + let arg_names = inner + .params + .args + .iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(","); + let verbose_alias = format!("{name}({arg_names})"); + + return Ok(Expr::AggregateFunction(inner).alias(verbose_alias)); + } } } + + // workaround for https://github.com/apache/datafusion-sqlparser-rs/issues/1909 + if function_without_parentheses { + let maybe_ids = object_name + .0 + .iter() + .map(|part| part.as_ident().cloned().ok_or(())) + .collect::, ()>>(); + if let Ok(ids) = maybe_ids { + if ids.len() == 1 { + return self.sql_identifier_to_expr( + ids.into_iter().next().unwrap(), + schema, + planner_context, + ); + } else { + return self.sql_compound_identifier_to_expr( + ids, + schema, + planner_context, + ); + } + } + } + // Could not find the relevant function, so return an error if let Some(suggested_func_name) = suggest_valid_function(&name, is_function_window, self.context_provider) { - plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?") - .map_err(|e| { - let span = Span::try_from_sqlparser_span(sql_parser_span); - let mut diagnostic = - Diagnostic::new_error(format!("Invalid function '{name}'"), span); - diagnostic.add_note( - format!("Possible function '{}'", suggested_func_name), - None, - ); - e.with_diagnostic(diagnostic) - }) + let span = Span::try_from_sqlparser_span(sql_parser_span); + let mut diagnostic = + Diagnostic::new_error(format!("Invalid function '{name}'"), span); + diagnostic + .add_note(format!("Possible function '{suggested_func_name}'"), None); + plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?"; diagnostic=diagnostic) } else { internal_err!("No functions registered with this context.") } diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 7c276ce53e35d..3c57d195ade67 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -17,12 +17,13 @@ use arrow::datatypes::Field; use datafusion_common::{ - internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, - DataFusionError, Result, Span, TableReference, + exec_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, + Column, DFSchema, Result, Span, TableReference, }; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{Case, Expr}; use sqlparser::ast::{CaseWhen, Expr as SQLExpr, Ident}; +use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_expr::UNNAMED_TABLE; @@ -75,7 +76,7 @@ impl SqlToRel<'_, S> { { // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column return Ok(Expr::OuterReferenceColumn( - field.data_type().clone(), + Arc::new(field.clone()), Column::from((qualifier, field)), )); } @@ -116,9 +117,7 @@ impl SqlToRel<'_, S> { .context_provider .get_variable_type(&var_names) .ok_or_else(|| { - DataFusionError::Execution(format!( - "variable {var_names:?} has no type information" - )) + exec_datafusion_err!("variable {var_names:?} has no type information") })?; Ok(Expr::ScalarVariable(ty, var_names)) } else { @@ -182,7 +181,7 @@ impl SqlToRel<'_, S> { Some((field, qualifier, _nested_names)) => { // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column Ok(Expr::OuterReferenceColumn( - field.data_type().clone(), + Arc::new(field.clone()), Column::from((qualifier, field)), )) } @@ -459,8 +458,8 @@ mod test { fn test_form_identifier() -> Result<()> { let err = form_identifier(&[]).expect_err("empty identifiers didn't fail"); let expected = "Internal error: Incorrect number of identifiers: 0.\n\ - This was likely caused by a bug in DataFusion's code and we would \ - welcome that you file an bug report in our issue tracker"; + This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this \ + by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues"; assert!(expected.starts_with(&err.strip_backtrace())); let ids = vec!["a".to_string()]; @@ -497,8 +496,8 @@ mod test { ]) .expect_err("too many identifiers didn't fail"); let expected = "Internal error: Incorrect number of identifiers: 5.\n\ - This was likely caused by a bug in DataFusion's code and we would \ - welcome that you file an bug report in our issue tracker"; + This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this \ + by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues"; assert!(expected.starts_with(&err.strip_backtrace())); Ok(()) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index d29ccdc6a7e9e..23426701409eb 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -22,7 +22,7 @@ use datafusion_expr::planner::{ use sqlparser::ast::{ AccessExpr, BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DictionaryField, Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry, - StructField, Subscript, TrimWhereField, Value, ValueWithSpan, + StructField, Subscript, TrimWhereField, TypedString, Value, ValueWithSpan, }; use datafusion_common::{ @@ -215,7 +215,7 @@ impl SqlToRel<'_, S> { } SQLExpr::Extract { field, expr, .. } => { let mut extract_args = vec![ - Expr::Literal(ScalarValue::from(format!("{field}"))), + Expr::Literal(ScalarValue::from(format!("{field}")), None), self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ]; @@ -254,6 +254,8 @@ impl SqlToRel<'_, S> { operand, conditions, else_result, + case_token: _, + end_token: _, } => self.sql_case_identifier_to_expr( operand, conditions, @@ -289,7 +291,11 @@ impl SqlToRel<'_, S> { ))) } - SQLExpr::TypedString { data_type, value } => Ok(Expr::Cast(Cast::new( + SQLExpr::TypedString(TypedString { + data_type, + value, + uses_odbc_syntax: _, + }) => Ok(Expr::Cast(Cast::new( Box::new(lit(value.into_string().unwrap())), self.convert_data_type(&data_type)?, ))), @@ -446,6 +452,7 @@ impl SqlToRel<'_, S> { substring_from, substring_for, special: _, + shorthand: _, } => self.sql_substring_to_expr( expr, substring_from, @@ -644,7 +651,9 @@ impl SqlToRel<'_, S> { values: Vec, ) -> Result { match values.first() { - Some(SQLExpr::Identifier(_)) | Some(SQLExpr::Value(_)) => { + Some(SQLExpr::Identifier(_)) + | Some(SQLExpr::Value(_)) + | Some(SQLExpr::CompoundIdentifier(_)) => { self.parse_struct(schema, planner_context, values, vec![]) } None => not_impl_err!("Empty tuple not supported yet"), @@ -811,7 +820,7 @@ impl SqlToRel<'_, S> { negated: bool, expr: SQLExpr, pattern: SQLExpr, - escape_char: Option, + escape_char: Option, schema: &DFSchema, planner_context: &mut PlannerContext, case_insensitive: bool, @@ -821,13 +830,12 @@ impl SqlToRel<'_, S> { return not_impl_err!("ANY in LIKE expression"); } let pattern = self.sql_expr_to_logical_expr(pattern, schema, planner_context)?; - let escape_char = if let Some(char) = escape_char { - if char.len() != 1 { - return plan_err!("Invalid escape character in LIKE expression"); + let escape_char = match escape_char { + Some(Value::SingleQuotedString(char)) if char.len() == 1 => { + Some(char.chars().next().unwrap()) } - Some(char.chars().next().unwrap()) - } else { - None + Some(value) => return plan_err!("Invalid escape character in LIKE expression. Expected a single character wrapped with single quotes, got {value}"), + None => None, }; Ok(Expr::Like(Like::new( negated, @@ -843,7 +851,7 @@ impl SqlToRel<'_, S> { negated: bool, expr: SQLExpr, pattern: SQLExpr, - escape_char: Option, + escape_char: Option, schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { @@ -852,13 +860,12 @@ impl SqlToRel<'_, S> { if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { return plan_err!("Invalid pattern in SIMILAR TO expression"); } - let escape_char = if let Some(char) = escape_char { - if char.len() != 1 { - return plan_err!("Invalid escape character in SIMILAR TO expression"); + let escape_char = match escape_char { + Some(Value::SingleQuotedString(char)) if char.len() == 1 => { + Some(char.chars().next().unwrap()) } - Some(char.chars().next().unwrap()) - } else { - None + Some(value) => return plan_err!("Invalid escape character in SIMILAR TO expression. Expected a single character wrapped with single quotes, got {value}"), + None => None, }; Ok(Expr::SimilarTo(Like::new( negated, diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index cce3f3004809b..79ebc5943ffbe 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -41,13 +41,13 @@ impl SqlToRel<'_, S> { /// If false, interpret numeric literals as constant values. pub(crate) fn order_by_to_sort_expr( &self, - exprs: Vec, + order_by_exprs: Vec, input_schema: &DFSchema, planner_context: &mut PlannerContext, literal_to_column: bool, additional_schema: Option<&DFSchema>, ) -> Result> { - if exprs.is_empty() { + if order_by_exprs.is_empty() { return Ok(vec![]); } @@ -61,13 +61,23 @@ impl SqlToRel<'_, S> { None => input_schema, }; - let mut expr_vec = vec![]; - for e in exprs { + let mut sort_expr_vec = Vec::with_capacity(order_by_exprs.len()); + + let make_sort_expr = |expr: Expr, + asc: Option, + nulls_first: Option| { + let asc = asc.unwrap_or(true); + let nulls_first = nulls_first + .unwrap_or_else(|| self.options.default_null_ordering.nulls_first(asc)); + Sort::new(expr, asc, nulls_first) + }; + + for order_by_expr in order_by_exprs { let OrderByExpr { expr, options: OrderByOptions { asc, nulls_first }, with_fill, - } = e; + } = order_by_expr; if let Some(with_fill) = with_fill { return not_impl_err!("ORDER BY WITH FILL is not supported: {with_fill}"); @@ -102,15 +112,9 @@ impl SqlToRel<'_, S> { self.sql_expr_to_logical_expr(e, order_by_schema, planner_context)? } }; - let asc = asc.unwrap_or(true); - expr_vec.push(Sort::new( - expr, - asc, - // When asc is true, by default nulls last to be consistent with postgres - // postgres rule: https://www.postgresql.org/docs/current/queries-order.html - nulls_first.unwrap_or(!asc), - )) + sort_expr_vec.push(make_sort_expr(expr, asc, nulls_first)); } - Ok(expr_vec) + + Ok(sort_expr_vec) } } diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index 225c5d74c2abd..24bb813634cc1 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -58,7 +58,7 @@ impl SqlToRel<'_, S> { planner_context.set_outer_query_schema(Some(input_schema.clone().into())); let mut spans = Spans::new(); - if let SetExpr::Select(select) = subquery.body.as_ref() { + if let SetExpr::Select(select) = &subquery.body.as_ref() { for item in &select.projection { if let SelectItem::UnnamedExpr(SQLExpr::Identifier(ident)) = item { if let Some(span) = Span::try_from_sqlparser_span(ident.span) { @@ -138,15 +138,9 @@ impl SqlToRel<'_, S> { if sub_plan.schema().fields().len() > 1 { let sub_schema = sub_plan.schema(); let field_names = sub_schema.field_names(); - - plan_err!("{}: {}", error_message, field_names.join(", ")).map_err(|err| { - let diagnostic = self.build_multi_column_diagnostic( - spans, - error_message, - help_message, - ); - err.with_diagnostic(diagnostic) - }) + let diagnostic = + self.build_multi_column_diagnostic(spans, error_message, help_message); + plan_err!("{}: {}", error_message, field_names.join(", "); diagnostic=diagnostic) } else { Ok(()) } diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs index 59c78bc713cc4..0ff361be0e206 100644 --- a/datafusion/sql/src/expr/substring.rs +++ b/datafusion/sql/src/expr/substring.rs @@ -18,8 +18,8 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{not_impl_err, plan_err}; use datafusion_common::{DFSchema, Result, ScalarValue}; -use datafusion_expr::planner::PlannerResult; -use datafusion_expr::Expr; +use datafusion_expr::{planner::PlannerResult, Expr}; + use sqlparser::ast::Expr as SQLExpr; impl SqlToRel<'_, S> { @@ -51,7 +51,7 @@ impl SqlToRel<'_, S> { (None, Some(for_expr)) => { let arg = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; - let from_logic = Expr::Literal(ScalarValue::Int64(Some(1))); + let from_logic = Expr::Literal(ScalarValue::Int64(Some(1)), None); let for_logic = self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; vec![arg, from_logic, for_logic] @@ -62,12 +62,14 @@ impl SqlToRel<'_, S> { substring_from: None, substring_for: None, special: false, + shorthand: false, }; return plan_err!("Substring without for/from is not valid {orig_sql:?}"); } }; + // Try to plan the substring expression using one of the registered planners for planner in self.context_provider.get_expr_planners() { match planner.plan_substring(substring_args)? { PlannerResult::Planned(expr) => return Ok(expr), @@ -77,8 +79,7 @@ impl SqlToRel<'_, S> { } } - not_impl_err!( - "Substring not supported by UserDefinedExtensionPlanners: {substring_args:?}" - ) + not_impl_err!("Substring could not be planned by registered expr planner. \ + Hint: Please try with `unicode_expressions` DataFusion feature enabled") } } diff --git a/datafusion/sql/src/expr/unary_op.rs b/datafusion/sql/src/expr/unary_op.rs index 626b79d6c3b65..e0c94543f6013 100644 --- a/datafusion/sql/src/expr/unary_op.rs +++ b/datafusion/sql/src/expr/unary_op.rs @@ -45,16 +45,18 @@ impl SqlToRel<'_, S> { { Ok(operand) } else { - plan_err!("Unary operator '+' only supports numeric, interval and timestamp types").map_err(|e| { - let span = operand.spans().and_then(|s| s.first()); - let mut diagnostic = Diagnostic::new_error( - format!("+ cannot be used with {data_type}"), - span - ); - diagnostic.add_note("+ can only be used with numbers, intervals, and timestamps", None); - diagnostic.add_help(format!("perhaps you need to cast {operand}"), None); - e.with_diagnostic(diagnostic) - }) + let span = operand.spans().and_then(|s| s.first()); + let mut diagnostic = Diagnostic::new_error( + format!("+ cannot be used with {data_type}"), + span, + ); + diagnostic.add_note( + "+ can only be used with numbers, intervals, and timestamps", + None, + ); + diagnostic + .add_help(format!("perhaps you need to cast {operand}"), None); + plan_err!("Unary operator '+' only supports numeric, interval and timestamp types"; diagnostic=diagnostic) } } UnaryOperator::Minus => { diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index d53691ef05d17..7075a1afd9dd0 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -50,7 +50,7 @@ impl SqlToRel<'_, S> { match value { Value::Number(n, _) => self.parse_sql_number(&n, false), Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Ok(lit(s)), - Value::Null => Ok(Expr::Literal(ScalarValue::Null)), + Value::Null => Ok(Expr::Literal(ScalarValue::Null, None)), Value::Boolean(n) => Ok(lit(n)), Value::Placeholder(param) => { Self::create_placeholder_expr(param, param_data_types) @@ -131,10 +131,7 @@ impl SqlToRel<'_, S> { // Check if the placeholder is in the parameter list let param_type = param_data_types.get(idx); // Data type of the parameter - debug!( - "type of param {} param_data_types[idx]: {:?}", - param, param_type - ); + debug!("type of param {param} param_data_types[idx]: {param_type:?}"); Ok(Expr::Placeholder(Placeholder::new( param, @@ -301,7 +298,7 @@ fn interval_literal(interval_value: SQLExpr, negative: bool) -> Result { fn try_decode_hex_literal(s: &str) -> Option> { let hex_bytes = s.as_bytes(); - let mut decoded_bytes = Vec::with_capacity((hex_bytes.len() + 1) / 2); + let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2)); let start_idx = hex_bytes.len() % 2; if start_idx > 0 { @@ -383,11 +380,10 @@ fn parse_decimal(unsigned_number: &str, negative: bool) -> Result { int_val ) })?; - Ok(Expr::Literal(ScalarValue::Decimal128( - Some(val), - precision as u8, - scale as i8, - ))) + Ok(Expr::Literal( + ScalarValue::Decimal128(Some(val), precision as u8, scale as i8), + None, + )) } else if precision <= DECIMAL256_MAX_PRECISION as u64 { let val = bigint_to_i256(&int_val).ok_or_else(|| { // Failures are unexpected here as we have already checked the precision @@ -396,11 +392,10 @@ fn parse_decimal(unsigned_number: &str, negative: bool) -> Result { int_val ) })?; - Ok(Expr::Literal(ScalarValue::Decimal256( - Some(val), - precision as u8, - scale as i8, - ))) + Ok(Expr::Literal( + ScalarValue::Decimal256(Some(val), precision as u8, scale as i8), + None, + )) } else { not_impl_err!( "Decimal precision {} exceeds the maximum supported precision: {}", @@ -486,10 +481,13 @@ mod tests { ]; for (input, expect) in cases { let output = parse_decimal(input, true).unwrap(); - assert_eq!(output, Expr::Literal(expect.arithmetic_negate().unwrap())); + assert_eq!( + output, + Expr::Literal(expect.arithmetic_negate().unwrap(), None) + ); let output = parse_decimal(input, false).unwrap(); - assert_eq!(output, Expr::Literal(expect)); + assert_eq!(output, Expr::Literal(expect, None)); } // scale < i8::MIN diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index 7e11f160a3977..da15b90d22a84 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 822b651eae864..271ad8a856b47 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -20,9 +20,9 @@ //! This parser implements DataFusion specific statements such as //! `CREATE EXTERNAL TABLE` -use std::collections::VecDeque; -use std::fmt; - +use datafusion_common::config::SqlParserOptions; +use datafusion_common::DataFusionError; +use datafusion_common::{sql_err, Diagnostic, Span}; use sqlparser::ast::{ExprWithAlias, OrderByOptions}; use sqlparser::tokenizer::TokenWithSpan; use sqlparser::{ @@ -34,15 +34,22 @@ use sqlparser::{ parser::{Parser, ParserError}, tokenizer::{Token, Tokenizer, Word}, }; +use std::collections::VecDeque; +use std::fmt; // Use `Parser::expected` instead, if possible macro_rules! parser_err { - ($MSG:expr) => { - Err(ParserError::ParserError($MSG.to_string())) - }; + ($MSG:expr $(; diagnostic = $DIAG:expr)?) => {{ + + let err = DataFusionError::from(ParserError::ParserError($MSG.to_string())); + $( + let err = err.with_diagnostic($DIAG); + )? + Err(err) + }}; } -fn parse_file_type(s: &str) -> Result { +fn parse_file_type(s: &str) -> Result { Ok(s.to_uppercase()) } @@ -140,7 +147,7 @@ impl fmt::Display for CopyToStatement { write!(f, "COPY {source} TO {target}")?; if let Some(file_type) = stored_as { - write!(f, " STORED AS {}", file_type)?; + write!(f, " STORED AS {file_type}")?; } if !partitioned_by.is_empty() { write!(f, " PARTITIONED BY ({})", partitioned_by.join(", "))?; @@ -181,7 +188,9 @@ pub(crate) type LexOrdering = Vec; /// Syntax: /// /// ```text -/// CREATE EXTERNAL TABLE +/// CREATE +/// [ OR REPLACE ] +/// EXTERNAL TABLE /// [ IF NOT EXISTS ] /// [ () ] /// STORED AS @@ -214,6 +223,8 @@ pub struct CreateExternalTable { pub order_exprs: Vec, /// Option to not error if table already exists pub if_not_exists: bool, + /// Option to replace table content if table already exists + pub or_replace: bool, /// Whether the table is a temporary table pub temporary: bool, /// Infinite streams? @@ -266,11 +277,9 @@ impl fmt::Display for Statement { } } -fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { +fn ensure_not_set(field: &Option, name: &str) -> Result<(), DataFusionError> { if field.is_some() { - return Err(ParserError::ParserError(format!( - "{name} specified more than once", - ))); + parser_err!(format!("{name} specified more than once",))? } Ok(()) } @@ -285,6 +294,7 @@ fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { /// [`Statement`] for a list of this special syntax pub struct DFParser<'a> { pub parser: Parser<'a>, + options: SqlParserOptions, } /// Same as `sqlparser` @@ -356,21 +366,28 @@ impl<'a> DFParserBuilder<'a> { self } - pub fn build(self) -> Result, ParserError> { + pub fn build(self) -> Result, DataFusionError> { let mut tokenizer = Tokenizer::new(self.dialect, self.sql); - let tokens = tokenizer.tokenize_with_location()?; + // Convert TokenizerError -> ParserError + let tokens = tokenizer + .tokenize_with_location() + .map_err(ParserError::from)?; Ok(DFParser { parser: Parser::new(self.dialect) .with_tokens_with_locations(tokens) .with_recursion_limit(self.recursion_limit), + options: SqlParserOptions { + recursion_limit: self.recursion_limit, + ..Default::default() + }, }) } } impl<'a> DFParser<'a> { #[deprecated(since = "46.0.0", note = "DFParserBuilder")] - pub fn new(sql: &'a str) -> Result { + pub fn new(sql: &'a str) -> Result { DFParserBuilder::new(sql).build() } @@ -378,13 +395,13 @@ impl<'a> DFParser<'a> { pub fn new_with_dialect( sql: &'a str, dialect: &'a dyn Dialect, - ) -> Result { + ) -> Result { DFParserBuilder::new(sql).with_dialect(dialect).build() } /// Parse a sql string into one or [`Statement`]s using the /// [`GenericDialect`]. - pub fn parse_sql(sql: &'a str) -> Result, ParserError> { + pub fn parse_sql(sql: &'a str) -> Result, DataFusionError> { let mut parser = DFParserBuilder::new(sql).build()?; parser.parse_statements() @@ -395,22 +412,27 @@ impl<'a> DFParser<'a> { pub fn parse_sql_with_dialect( sql: &str, dialect: &dyn Dialect, - ) -> Result, ParserError> { + ) -> Result, DataFusionError> { let mut parser = DFParserBuilder::new(sql).with_dialect(dialect).build()?; parser.parse_statements() } + pub fn parse_sql_into_expr(sql: &str) -> Result { + DFParserBuilder::new(sql).build()?.parse_into_expr() + } + pub fn parse_sql_into_expr_with_dialect( sql: &str, dialect: &dyn Dialect, - ) -> Result { - let mut parser = DFParserBuilder::new(sql).with_dialect(dialect).build()?; - - parser.parse_expr() + ) -> Result { + DFParserBuilder::new(sql) + .with_dialect(dialect) + .build()? + .parse_into_expr() } /// Parse a sql string into one or [`Statement`]s - pub fn parse_statements(&mut self) -> Result, ParserError> { + pub fn parse_statements(&mut self) -> Result, DataFusionError> { let mut stmts = VecDeque::new(); let mut expecting_statement_delimiter = false; loop { @@ -438,12 +460,35 @@ impl<'a> DFParser<'a> { &self, expected: &str, found: TokenWithSpan, - ) -> Result { - parser_err!(format!("Expected {expected}, found: {found}")) + ) -> Result { + let sql_parser_span = found.span; + let span = Span::try_from_sqlparser_span(sql_parser_span); + let diagnostic = Diagnostic::new_error( + format!("Expected: {expected}, found: {found}{}", found.span.start), + span, + ); + parser_err!( + format!("Expected: {expected}, found: {found}{}", found.span.start); + diagnostic= + diagnostic + ) + } + + fn expect_token( + &mut self, + expected: &str, + token: Token, + ) -> Result<(), DataFusionError> { + let next_token = self.parser.peek_token_ref(); + if next_token.token != token { + self.expected(expected, next_token.clone()) + } else { + Ok(()) + } } /// Parse a new expression - pub fn parse_statement(&mut self) -> Result { + pub fn parse_statement(&mut self) -> Result { match self.parser.peek_token().token { Token::Word(w) => { match w.keyword { @@ -455,9 +500,7 @@ impl<'a> DFParser<'a> { if let Token::Word(w) = self.parser.peek_nth_token(1).token { // use native parser for COPY INTO if w.keyword == Keyword::INTO { - return Ok(Statement::Statement(Box::from( - self.parser.parse_statement()?, - ))); + return self.parse_and_handle_statement(); } } self.parser.next_token(); // COPY @@ -469,36 +512,59 @@ impl<'a> DFParser<'a> { } _ => { // use sqlparser-rs parser - Ok(Statement::Statement(Box::from( - self.parser.parse_statement()?, - ))) + self.parse_and_handle_statement() } } } _ => { // use the native parser - Ok(Statement::Statement(Box::from( - self.parser.parse_statement()?, - ))) + self.parse_and_handle_statement() } } } - pub fn parse_expr(&mut self) -> Result { + pub fn parse_expr(&mut self) -> Result { if let Token::Word(w) = self.parser.peek_token().token { match w.keyword { Keyword::CREATE | Keyword::COPY | Keyword::EXPLAIN => { - return parser_err!("Unsupported command in expression"); + return parser_err!("Unsupported command in expression")?; } _ => {} } } - self.parser.parse_expr_with_alias() + Ok(self.parser.parse_expr_with_alias()?) + } + + /// Parses the entire SQL string into an expression. + /// + /// In contrast to [`DFParser::parse_expr`], this function will report an error if the input + /// contains any trailing, unparsed tokens. + pub fn parse_into_expr(&mut self) -> Result { + let expr = self.parse_expr()?; + self.expect_token("end of expression", Token::EOF)?; + Ok(expr) + } + + /// Helper method to parse a statement and handle errors consistently, especially for recursion limits + fn parse_and_handle_statement(&mut self) -> Result { + self.parser + .parse_statement() + .map(|stmt| Statement::Statement(Box::from(stmt))) + .map_err(|e| match e { + ParserError::RecursionLimitExceeded => DataFusionError::SQL( + Box::new(ParserError::RecursionLimitExceeded), + Some(format!( + " (current limit: {})", + self.options.recursion_limit + )), + ), + other => DataFusionError::SQL(Box::new(other), None), + }) } /// Parse a SQL `COPY TO` statement - pub fn parse_copy(&mut self) -> Result { + pub fn parse_copy(&mut self) -> Result { // parse as a query let source = if self.parser.consume_token(&Token::LParen) { let query = self.parser.parse_query()?; @@ -541,7 +607,7 @@ impl<'a> DFParser<'a> { Keyword::WITH => { self.parser.expect_keyword(Keyword::HEADER)?; self.parser.expect_keyword(Keyword::ROW)?; - return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS ('format.has_header' 'true')"); + return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS ('format.has_header' 'true')")?; } Keyword::PARTITIONED => { self.parser.expect_keyword(Keyword::BY)?; @@ -561,17 +627,13 @@ impl<'a> DFParser<'a> { if token == Token::EOF || token == Token::SemiColon { break; } else { - return Err(ParserError::ParserError(format!( - "Unexpected token {token}" - ))); + return self.expected("end of statement or ;", token)?; } } } let Some(target) = builder.target else { - return Err(ParserError::ParserError( - "Missing TO clause in COPY statement".into(), - )); + return parser_err!("Missing TO clause in COPY statement")?; }; Ok(Statement::CopyTo(CopyToStatement { @@ -589,7 +651,7 @@ impl<'a> DFParser<'a> { /// because it allows keywords as well as other non words /// /// [`parse_literal_string`]: sqlparser::parser::Parser::parse_literal_string - pub fn parse_option_key(&mut self) -> Result { + pub fn parse_option_key(&mut self) -> Result { let next_token = self.parser.next_token(); match next_token.token { Token::Word(Word { value, .. }) => { @@ -602,7 +664,7 @@ impl<'a> DFParser<'a> { // Unquoted namespaced keys have to conform to the syntax // "[\.]*". If we have a key that breaks this // pattern, error out: - return self.parser.expected("key name", next_token); + return self.expected("key name", next_token); } } Ok(parts.join(".")) @@ -610,7 +672,7 @@ impl<'a> DFParser<'a> { Token::SingleQuotedString(s) => Ok(s), Token::DoubleQuotedString(s) => Ok(s), Token::EscapedStringLiteral(s) => Ok(s), - _ => self.parser.expected("key name", next_token), + _ => self.expected("key name", next_token), } } @@ -620,7 +682,7 @@ impl<'a> DFParser<'a> { /// word or keyword in this location. /// /// [`parse_value`]: sqlparser::parser::Parser::parse_value - pub fn parse_option_value(&mut self) -> Result { + pub fn parse_option_value(&mut self) -> Result { let next_token = self.parser.next_token(); match next_token.token { // e.g. things like "snappy" or "gzip" that may be keywords @@ -629,12 +691,12 @@ impl<'a> DFParser<'a> { Token::DoubleQuotedString(s) => Ok(Value::DoubleQuotedString(s)), Token::EscapedStringLiteral(s) => Ok(Value::EscapedStringLiteral(s)), Token::Number(n, l) => Ok(Value::Number(n, l)), - _ => self.parser.expected("string or numeric value", next_token), + _ => self.expected("string or numeric value", next_token), } } /// Parse a SQL `EXPLAIN` - pub fn parse_explain(&mut self) -> Result { + pub fn parse_explain(&mut self) -> Result { let analyze = self.parser.parse_keyword(Keyword::ANALYZE); let verbose = self.parser.parse_keyword(Keyword::VERBOSE); let format = self.parse_explain_format()?; @@ -649,7 +711,7 @@ impl<'a> DFParser<'a> { })) } - pub fn parse_explain_format(&mut self) -> Result, ParserError> { + pub fn parse_explain_format(&mut self) -> Result, DataFusionError> { if !self.parser.parse_keyword(Keyword::FORMAT) { return Ok(None); } @@ -659,26 +721,39 @@ impl<'a> DFParser<'a> { Token::Word(w) => Ok(w.value), Token::SingleQuotedString(w) => Ok(w), Token::DoubleQuotedString(w) => Ok(w), - _ => self - .parser - .expected("an explain format such as TREE", next_token), + _ => self.expected("an explain format such as TREE", next_token), }?; Ok(Some(format)) } /// Parse a SQL `CREATE` statement handling `CREATE EXTERNAL TABLE` - pub fn parse_create(&mut self) -> Result { - if self.parser.parse_keyword(Keyword::EXTERNAL) { - self.parse_create_external_table(false) - } else if self.parser.parse_keyword(Keyword::UNBOUNDED) { - self.parser.expect_keyword(Keyword::EXTERNAL)?; - self.parse_create_external_table(true) + pub fn parse_create(&mut self) -> Result { + // TODO: Change sql parser to take in `or_replace: bool` inside parse_create() + if self + .parser + .parse_keywords(&[Keyword::OR, Keyword::REPLACE, Keyword::EXTERNAL]) + { + self.parse_create_external_table(false, true) + } else if self.parser.parse_keywords(&[ + Keyword::OR, + Keyword::REPLACE, + Keyword::UNBOUNDED, + Keyword::EXTERNAL, + ]) { + self.parse_create_external_table(true, true) + } else if self.parser.parse_keyword(Keyword::EXTERNAL) { + self.parse_create_external_table(false, false) + } else if self + .parser + .parse_keywords(&[Keyword::UNBOUNDED, Keyword::EXTERNAL]) + { + self.parse_create_external_table(true, false) } else { Ok(Statement::Statement(Box::from(self.parser.parse_create()?))) } } - fn parse_partitions(&mut self) -> Result, ParserError> { + fn parse_partitions(&mut self) -> Result, DataFusionError> { let mut partitions: Vec = vec![]; if !self.parser.consume_token(&Token::LParen) || self.parser.consume_token(&Token::RParen) @@ -708,7 +783,7 @@ impl<'a> DFParser<'a> { } /// Parse the ordering clause of a `CREATE EXTERNAL TABLE` SQL statement - pub fn parse_order_by_exprs(&mut self) -> Result, ParserError> { + pub fn parse_order_by_exprs(&mut self) -> Result, DataFusionError> { let mut values = vec![]; self.parser.expect_token(&Token::LParen)?; loop { @@ -721,7 +796,7 @@ impl<'a> DFParser<'a> { } /// Parse an ORDER BY sub-expression optionally followed by ASC or DESC. - pub fn parse_order_by_expr(&mut self) -> Result { + pub fn parse_order_by_expr(&mut self) -> Result { let expr = self.parser.parse_expr()?; let asc = if self.parser.parse_keyword(Keyword::ASC) { @@ -753,7 +828,7 @@ impl<'a> DFParser<'a> { // This is a copy of the equivalent implementation in sqlparser. fn parse_columns( &mut self, - ) -> Result<(Vec, Vec), ParserError> { + ) -> Result<(Vec, Vec), DataFusionError> { let mut columns = vec![]; let mut constraints = vec![]; if !self.parser.consume_token(&Token::LParen) @@ -789,7 +864,7 @@ impl<'a> DFParser<'a> { Ok((columns, constraints)) } - fn parse_column_def(&mut self) -> Result { + fn parse_column_def(&mut self) -> Result { let name = self.parser.parse_identifier()?; let data_type = self.parser.parse_data_type()?; let mut options = vec![]; @@ -820,15 +895,22 @@ impl<'a> DFParser<'a> { fn parse_create_external_table( &mut self, unbounded: bool, - ) -> Result { + or_replace: bool, + ) -> Result { let temporary = self .parser .parse_one_of_keywords(&[Keyword::TEMP, Keyword::TEMPORARY]) .is_some(); + self.parser.expect_keyword(Keyword::TABLE)?; let if_not_exists = self.parser .parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); + + if if_not_exists && or_replace { + return parser_err!("'IF NOT EXISTS' cannot coexist with 'REPLACE'"); + } + let table_name = self.parser.parse_object_name(true)?; let (mut columns, constraints) = self.parse_columns()?; @@ -868,15 +950,15 @@ impl<'a> DFParser<'a> { } else { self.parser.expect_keyword(Keyword::HEADER)?; self.parser.expect_keyword(Keyword::ROW)?; - return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS (format.has_header true)"); + return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS (format.has_header true)")?; } } Keyword::DELIMITER => { - return parser_err!("DELIMITER clause is no longer in use. Please use the OPTIONS clause with 'format.delimiter' set appropriately, e.g., OPTIONS (format.delimiter ',')"); + return parser_err!("DELIMITER clause is no longer in use. Please use the OPTIONS clause with 'format.delimiter' set appropriately, e.g., OPTIONS (format.delimiter ',')")?; } Keyword::COMPRESSION => { self.parser.expect_keyword(Keyword::TYPE)?; - return parser_err!("COMPRESSION TYPE clause is no longer in use. Please use the OPTIONS clause with 'format.compression' set appropriately, e.g., OPTIONS (format.compression gzip)"); + return parser_err!("COMPRESSION TYPE clause is no longer in use. Please use the OPTIONS clause with 'format.compression' set appropriately, e.g., OPTIONS (format.compression gzip)")?; } Keyword::PARTITIONED => { self.parser.expect_keyword(Keyword::BY)?; @@ -899,7 +981,7 @@ impl<'a> DFParser<'a> { columns.extend(cols); if !cons.is_empty() { - return Err(ParserError::ParserError( + return sql_err!(ParserError::ParserError( "Constraints on Partition Columns are not supported" .to_string(), )); @@ -919,21 +1001,19 @@ impl<'a> DFParser<'a> { if token == Token::EOF || token == Token::SemiColon { break; } else { - return Err(ParserError::ParserError(format!( - "Unexpected token {token}" - ))); + return self.expected("end of statement or ;", token)?; } } } // Validations: location and file_type are required if builder.file_type.is_none() { - return Err(ParserError::ParserError( + return sql_err!(ParserError::ParserError( "Missing STORED AS clause in CREATE EXTERNAL TABLE statement".into(), )); } if builder.location.is_none() { - return Err(ParserError::ParserError( + return sql_err!(ParserError::ParserError( "Missing LOCATION clause in CREATE EXTERNAL TABLE statement".into(), )); } @@ -946,6 +1026,7 @@ impl<'a> DFParser<'a> { table_partition_cols: builder.table_partition_cols.unwrap_or(vec![]), order_exprs: builder.order_exprs, if_not_exists, + or_replace, temporary, unbounded, options: builder.options.unwrap_or(Vec::new()), @@ -955,7 +1036,7 @@ impl<'a> DFParser<'a> { } /// Parses the set of valid formats - fn parse_file_format(&mut self) -> Result { + fn parse_file_format(&mut self) -> Result { let token = self.parser.next_token(); match &token.token { Token::Word(w) => parse_file_type(&w.value), @@ -967,7 +1048,7 @@ impl<'a> DFParser<'a> { /// /// This method supports keywords as key names as well as multiple /// value types such as Numbers as well as Strings. - fn parse_value_options(&mut self) -> Result, ParserError> { + fn parse_value_options(&mut self) -> Result, DataFusionError> { let mut options = vec![]; self.parser.expect_token(&Token::LParen)?; @@ -995,11 +1076,13 @@ mod tests { use super::*; use datafusion_common::assert_contains; use sqlparser::ast::Expr::Identifier; - use sqlparser::ast::{BinaryOperator, DataType, Expr, Ident}; + use sqlparser::ast::{ + BinaryOperator, DataType, ExactNumberInfo, Expr, Ident, ValueWithSpan, + }; use sqlparser::dialect::SnowflakeDialect; use sqlparser::tokenizer::Span; - fn expect_parse_ok(sql: &str, expected: Statement) -> Result<(), ParserError> { + fn expect_parse_ok(sql: &str, expected: Statement) -> Result<(), DataFusionError> { let statements = DFParser::parse_sql(sql)?; assert_eq!( statements.len(), @@ -1041,7 +1124,7 @@ mod tests { } #[test] - fn create_external_table() -> Result<(), ParserError> { + fn create_external_table() -> Result<(), DataFusionError> { // positive case let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv'"; let display = None; @@ -1054,6 +1137,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![], @@ -1071,6 +1155,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![], @@ -1089,6 +1174,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![], @@ -1107,6 +1193,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![( @@ -1128,6 +1215,7 @@ mod tests { table_partition_cols: vec!["p1".to_string(), "p2".to_string()], order_exprs: vec![], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![], @@ -1156,6 +1244,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![( @@ -1177,6 +1266,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![], @@ -1194,6 +1284,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![], @@ -1211,6 +1302,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![], @@ -1229,6 +1321,26 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: true, + or_replace: false, + temporary: false, + unbounded: false, + options: vec![], + constraints: vec![], + }); + expect_parse_ok(sql, expected)?; + + // positive case: or replace + let sql = + "CREATE OR REPLACE EXTERNAL TABLE t STORED AS PARQUET LOCATION 'foo.parquet'"; + let expected = Statement::CreateExternalTable(CreateExternalTable { + name: name.clone(), + columns: vec![], + file_type: "PARQUET".to_string(), + location: "foo.parquet".into(), + table_partition_cols: vec![], + order_exprs: vec![], + if_not_exists: false, + or_replace: true, temporary: false, unbounded: false, options: vec![], @@ -1250,6 +1362,7 @@ mod tests { table_partition_cols: vec!["p1".to_string()], order_exprs: vec![], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![], @@ -1262,13 +1375,13 @@ mod tests { "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1 int, c1) LOCATION 'foo.csv'"; expect_parse_error( sql, - "sql parser error: Expected: a data type name, found: )", + "SQL error: ParserError(\"Expected: a data type name, found: ) at Line: 1, Column: 73\")", ); // negative case: mixed column defs and column names in `PARTITIONED BY` clause let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (c1, p1 int) LOCATION 'foo.csv'"; - expect_parse_error(sql, "sql parser error: Expected ',' or ')' after partition definition, found: int"); + expect_parse_error(sql, "SQL error: ParserError(\"Expected: ',' or ')' after partition definition, found: int at Line: 1, Column: 70\")"); // positive case: additional options (one entry) can be specified let sql = @@ -1281,6 +1394,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![("k1".into(), Value::SingleQuotedString("v1".into()))], @@ -1299,6 +1413,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![ @@ -1347,6 +1462,7 @@ mod tests { with_fill: None, }]], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![], @@ -1394,6 +1510,7 @@ mod tests { }, ]], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![], @@ -1434,6 +1551,7 @@ mod tests { with_fill: None, }]], if_not_exists: false, + or_replace: false, temporary: false, unbounded: false, options: vec![], @@ -1441,7 +1559,7 @@ mod tests { }); expect_parse_ok(sql, expected)?; - // Most complete CREATE EXTERNAL TABLE statement possible + // Most complete CREATE EXTERNAL TABLE statement possible (using IF NOT EXISTS) let sql = " CREATE UNBOUNDED EXTERNAL TABLE IF NOT EXISTS t (c1 int, c2 float) STORED AS PARQUET @@ -1457,7 +1575,7 @@ mod tests { name: name.clone(), columns: vec![ make_column_def("c1", DataType::Int(None)), - make_column_def("c2", DataType::Float(None)), + make_column_def("c2", DataType::Float(ExactNumberInfo::None)), ], file_type: "PARQUET".to_string(), location: "foo.parquet".into(), @@ -1483,6 +1601,75 @@ mod tests { with_fill: None, }]], if_not_exists: true, + or_replace: false, + temporary: false, + unbounded: true, + options: vec![ + ( + "format.compression".into(), + Value::SingleQuotedString("zstd".into()), + ), + ( + "format.delimiter".into(), + Value::SingleQuotedString("*".into()), + ), + ( + "ROW_GROUP_SIZE".into(), + Value::SingleQuotedString("1024".into()), + ), + ("TRUNCATE".into(), Value::SingleQuotedString("NO".into())), + ( + "format.has_header".into(), + Value::SingleQuotedString("true".into()), + ), + ], + constraints: vec![], + }); + expect_parse_ok(sql, expected)?; + + // Most complete CREATE EXTERNAL TABLE statement possible (using OR REPLACE) + let sql = " + CREATE OR REPLACE UNBOUNDED EXTERNAL TABLE t (c1 int, c2 float) + STORED AS PARQUET + WITH ORDER (c1 - c2 ASC) + PARTITIONED BY (c1) + LOCATION 'foo.parquet' + OPTIONS ('format.compression' 'zstd', + 'format.delimiter' '*', + 'ROW_GROUP_SIZE' '1024', + 'TRUNCATE' 'NO', + 'format.has_header' 'true')"; + let expected = Statement::CreateExternalTable(CreateExternalTable { + name: name.clone(), + columns: vec![ + make_column_def("c1", DataType::Int(None)), + make_column_def("c2", DataType::Float(ExactNumberInfo::None)), + ], + file_type: "PARQUET".to_string(), + location: "foo.parquet".into(), + table_partition_cols: vec!["c1".into()], + order_exprs: vec![vec![OrderByExpr { + expr: Expr::BinaryOp { + left: Box::new(Identifier(Ident { + value: "c1".to_owned(), + quote_style: None, + span: Span::empty(), + })), + op: BinaryOperator::Minus, + right: Box::new(Identifier(Ident { + value: "c2".to_owned(), + quote_style: None, + span: Span::empty(), + })), + }, + options: OrderByOptions { + asc: Some(true), + nulls_first: None, + }, + with_fill: None, + }]], + if_not_exists: false, + or_replace: true, temporary: false, unbounded: true, options: vec![ @@ -1514,7 +1701,7 @@ mod tests { } #[test] - fn copy_to_table_to_table() -> Result<(), ParserError> { + fn copy_to_table_to_table() -> Result<(), DataFusionError> { // positive case let sql = "COPY foo TO bar STORED AS CSV"; let expected = Statement::CopyTo(CopyToStatement { @@ -1530,7 +1717,7 @@ mod tests { } #[test] - fn skip_copy_into_snowflake() -> Result<(), ParserError> { + fn skip_copy_into_snowflake() -> Result<(), DataFusionError> { let sql = "COPY INTO foo FROM @~/staged FILE_FORMAT = (FORMAT_NAME = 'mycsv');"; let dialect = Box::new(SnowflakeDialect); let statements = DFParser::parse_sql_with_dialect(sql, dialect.as_ref())?; @@ -1547,7 +1734,7 @@ mod tests { } #[test] - fn explain_copy_to_table_to_table() -> Result<(), ParserError> { + fn explain_copy_to_table_to_table() -> Result<(), DataFusionError> { let cases = vec![ ("EXPLAIN COPY foo TO bar STORED AS PARQUET", false, false), ( @@ -1588,7 +1775,7 @@ mod tests { } #[test] - fn copy_to_query_to_table() -> Result<(), ParserError> { + fn copy_to_query_to_table() -> Result<(), DataFusionError> { let statement = verified_stmt("SELECT 1"); // unwrap the various layers @@ -1621,7 +1808,7 @@ mod tests { } #[test] - fn copy_to_options() -> Result<(), ParserError> { + fn copy_to_options() -> Result<(), DataFusionError> { let sql = "COPY foo TO bar STORED AS CSV OPTIONS ('row_group_size' '55')"; let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), @@ -1638,7 +1825,7 @@ mod tests { } #[test] - fn copy_to_partitioned_by() -> Result<(), ParserError> { + fn copy_to_partitioned_by() -> Result<(), DataFusionError> { let sql = "COPY foo TO bar STORED AS CSV PARTITIONED BY (a) OPTIONS ('row_group_size' '55')"; let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), @@ -1655,7 +1842,7 @@ mod tests { } #[test] - fn copy_to_multi_options() -> Result<(), ParserError> { + fn copy_to_multi_options() -> Result<(), DataFusionError> { // order of options is preserved let sql = "COPY foo TO bar STORED AS parquet OPTIONS ('format.row_group_size' 55, 'format.compression' snappy, 'execution.keep_partition_by_columns' true)"; @@ -1754,7 +1941,86 @@ mod tests { assert_contains!( err.to_string(), - "sql parser error: recursion limit exceeded" + "SQL error: RecursionLimitExceeded (current limit: 1)" ); } + + fn expect_parse_expr_ok(sql: &str, expected: ExprWithAlias) { + let expr = DFParser::parse_sql_into_expr(sql).unwrap(); + assert_eq!(expr, expected, "actual:\n{expr:#?}"); + } + + /// Parses sql and asserts that the expected error message was found + fn expect_parse_expr_error(sql: &str, expected_error: &str) { + match DFParser::parse_sql_into_expr(sql) { + Ok(expr) => { + panic!("Expected parse error for '{sql}', but was successful: {expr:#?}"); + } + Err(e) => { + let error_message = e.to_string(); + assert!( + error_message.contains(expected_error), + "Expected error '{expected_error}' not found in actual error '{error_message}'" + ); + } + } + } + + #[test] + fn literal() { + expect_parse_expr_ok( + "1234", + ExprWithAlias { + expr: Expr::Value(ValueWithSpan::from(Value::Number( + "1234".to_string(), + false, + ))), + alias: None, + }, + ) + } + + #[test] + fn literal_with_alias() { + expect_parse_expr_ok( + "1234 as foo", + ExprWithAlias { + expr: Expr::Value(ValueWithSpan::from(Value::Number( + "1234".to_string(), + false, + ))), + alias: Some(Ident::from("foo")), + }, + ) + } + + #[test] + fn literal_with_alias_and_trailing_tokens() { + expect_parse_expr_error( + "1234 as foo.bar", + "Expected: end of expression, found: .", + ) + } + + #[test] + fn literal_with_alias_and_trailing_whitespace() { + expect_parse_expr_ok( + "1234 as foo ", + ExprWithAlias { + expr: Expr::Value(ValueWithSpan::from(Value::Number( + "1234".to_string(), + false, + ))), + alias: Some(Ident::from("foo")), + }, + ) + } + + #[test] + fn literal_with_alias_and_trailing_whitespace_and_token() { + expect_parse_expr_error( + "1234 as foo bar", + "Expected: end of expression, found: bar", + ) + } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 180017ee9c191..e93c5e066b662 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -17,6 +17,7 @@ //! [`SqlToRel`]: SQL Query Planner (produces [`LogicalPlan`] from SQL AST) use std::collections::HashMap; +use std::str::FromStr; use std::sync::Arc; use std::vec; @@ -52,8 +53,10 @@ pub struct ParserOptions { pub enable_options_value_normalization: bool, /// Whether to collect spans pub collect_spans: bool, - /// Whether `VARCHAR` is mapped to `Utf8View` during SQL planning. - pub map_varchar_to_utf8view: bool, + /// Whether string types (VARCHAR, CHAR, Text, and String) are mapped to `Utf8View` during SQL planning. + pub map_string_types_to_utf8view: bool, + /// Default null ordering for sorting expressions. + pub default_null_ordering: NullOrdering, } impl ParserOptions { @@ -72,9 +75,12 @@ impl ParserOptions { parse_float_as_decimal: false, enable_ident_normalization: true, support_varchar_with_length: true, - map_varchar_to_utf8view: false, + map_string_types_to_utf8view: true, enable_options_value_normalization: false, collect_spans: false, + // By default, `nulls_max` is used to follow Postgres's behavior. + // postgres rule: https://www.postgresql.org/docs/current/queries-order.html + default_null_ordering: NullOrdering::NullsMax, } } @@ -112,9 +118,9 @@ impl ParserOptions { self } - /// Sets the `map_varchar_to_utf8view` option. - pub fn with_map_varchar_to_utf8view(mut self, value: bool) -> Self { - self.map_varchar_to_utf8view = value; + /// Sets the `map_string_types_to_utf8view` option. + pub fn with_map_string_types_to_utf8view(mut self, value: bool) -> Self { + self.map_string_types_to_utf8view = value; self } @@ -129,6 +135,12 @@ impl ParserOptions { self.collect_spans = value; self } + + /// Sets the `default_null_ordering` option. + pub fn with_default_null_ordering(mut self, value: NullOrdering) -> Self { + self.default_null_ordering = value; + self + } } impl Default for ParserOptions { @@ -143,14 +155,64 @@ impl From<&SqlParserOptions> for ParserOptions { parse_float_as_decimal: options.parse_float_as_decimal, enable_ident_normalization: options.enable_ident_normalization, support_varchar_with_length: options.support_varchar_with_length, - map_varchar_to_utf8view: options.map_varchar_to_utf8view, + map_string_types_to_utf8view: options.map_string_types_to_utf8view, enable_options_value_normalization: options .enable_options_value_normalization, collect_spans: options.collect_spans, + default_null_ordering: options.default_null_ordering.as_str().into(), } } } +/// Represents the null ordering for sorting expressions. +#[derive(Debug, Clone, Copy)] +pub enum NullOrdering { + /// Nulls appear last in ascending order. + NullsMax, + /// Nulls appear first in descending order. + NullsMin, + /// Nulls appear first. + NullsFirst, + /// Nulls appear last. + NullsLast, +} + +impl NullOrdering { + /// Evaluates the null ordering based on the given ascending flag. + /// + /// # Returns + /// * `true` if nulls should appear first. + /// * `false` if nulls should appear last. + pub fn nulls_first(&self, asc: bool) -> bool { + match self { + Self::NullsMax => !asc, + Self::NullsMin => asc, + Self::NullsFirst => true, + Self::NullsLast => false, + } + } +} + +impl FromStr for NullOrdering { + type Err = DataFusionError; + + fn from_str(s: &str) -> Result { + match s { + "nulls_max" => Ok(Self::NullsMax), + "nulls_min" => Ok(Self::NullsMin), + "nulls_first" => Ok(Self::NullsFirst), + "nulls_last" => Ok(Self::NullsLast), + _ => plan_err!("Unknown null ordering: Expected one of 'nulls_first', 'nulls_last', 'nulls_min' or 'nulls_max'. Got {s}"), + } + } +} + +impl From<&str> for NullOrdering { + fn from(s: &str) -> Self { + Self::from_str(s).unwrap_or(Self::NullsMax) + } +} + /// Ident Normalizer #[derive(Debug)] pub struct IdentNormalizer { @@ -331,7 +393,7 @@ impl PlannerContext { /// /// Key interfaces are: /// * [`Self::sql_statement_to_plan`]: Convert a statement -/// (e.g. `SELECT ...`) into a [`LogicalPlan`] +/// (e.g. `SELECT ...`) into a [`LogicalPlan`] /// * [`Self::sql_to_expr`]: Convert an expression (e.g. `1 + 2`) into an [`Expr`] pub struct SqlToRel<'a, S: ContextProvider> { pub(crate) context_provider: &'a S, @@ -391,7 +453,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Default expressions are restricted, column references are not allowed let empty_schema = DFSchema::empty(); let error_desc = |e: DataFusionError| match e { - DataFusionError::SchemaError(SchemaError::FieldNotFound { .. }, _) => { + DataFusionError::SchemaError(ref err, _) + if matches!(**err, SchemaError::FieldNotFound { .. }) => + { plan_datafusion_err!( "Column reference is not allowed in the DEFAULT expression : {}", e @@ -483,13 +547,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } .map_err(|err: DataFusionError| match &err { - DataFusionError::SchemaError( - SchemaError::FieldNotFound { + DataFusionError::SchemaError(inner, _) + if matches!( + inner.as_ref(), + SchemaError::FieldNotFound { .. } + ) => + { + let SchemaError::FieldNotFound { field, valid_fields, - }, - _, - ) => { + } = inner.as_ref() + else { + unreachable!() + }; let mut diagnostic = if let Some(relation) = &col.relation { Diagnostic::new_error( format!( @@ -577,7 +647,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { please set `support_varchar_with_length` to be true" ), _ => { - if self.options.map_varchar_to_utf8view { + if self.options.map_string_types_to_utf8view { Ok(DataType::Utf8View) } else { Ok(DataType::Utf8) @@ -601,7 +671,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) } SQLDataType::Char(_) | SQLDataType::Text | SQLDataType::String(_) => { - Ok(DataType::Utf8) + if self.options.map_string_types_to_utf8view { + Ok(DataType::Utf8View) + } else { + Ok(DataType::Utf8) + } } SQLDataType::Timestamp(precision, tz_info) if precision.is_none() || [0, 3, 6, 9].contains(&precision.unwrap()) => @@ -612,7 +686,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Timestamp With Time Zone // INPUT : [SQLDataType] TimestampTz + [Config] Time Zone // OUTPUT: [ArrowDataType] Timestamp - self.context_provider.options().execution.time_zone.clone() + Some(self.context_provider.options().execution.time_zone.clone()) } else { // Timestamp Without Time zone None @@ -634,7 +708,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(DataType::Time64(TimeUnit::Nanosecond)) } else { // We don't support TIMETZ and TIME WITH TIME ZONE for now - not_impl_err!("Unsupported SQL type {sql_type:?}") + not_impl_err!("Unsupported SQL type {sql_type}") } } SQLDataType::Numeric(exact_number_info) @@ -646,10 +720,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { (Some(precision), Some(scale)) } }; - make_decimal_type(precision, scale) + make_decimal_type(precision, scale.map(|s| s as u64)) } SQLDataType::Bytea => Ok(DataType::Binary), - SQLDataType::Interval => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + SQLDataType::Interval { fields, precision } => { + if fields.is_some() || precision.is_some() { + return not_impl_err!("Unsupported SQL type {sql_type}"); + } + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } SQLDataType::Struct(fields, _) => { let fields = fields .iter() @@ -735,8 +814,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::AnyType | SQLDataType::Table(_) | SQLDataType::VarBit(_) - | SQLDataType::GeometricType(_) => { - not_impl_err!("Unsupported SQL type {sql_type:?}") + | SQLDataType::UTinyInt + | SQLDataType::USmallInt + | SQLDataType::HugeInt + | SQLDataType::UHugeInt + | SQLDataType::UBigInt + | SQLDataType::TimestampNtz + | SQLDataType::NamedTable { .. } + | SQLDataType::TsVector + | SQLDataType::TsQuery + | SQLDataType::GeometricType(_) + | SQLDataType::DecimalUnsigned(_) // deprecated mysql type + | SQLDataType::FloatUnsigned(_) // deprecated mysql type + | SQLDataType::RealUnsigned // deprecated mysql type + | SQLDataType::DecUnsigned(_) // deprecated mysql type + | SQLDataType::DoubleUnsigned(_) // deprecated mysql type + | SQLDataType::DoublePrecisionUnsigned // deprecated mysql type + => { + not_impl_err!("Unsupported SQL type {sql_type}") } } } @@ -816,7 +911,7 @@ impl std::fmt::Display for IdentTaker { if !first { write!(f, ".")?; } - write!(f, "{}", ident)?; + write!(f, "{ident}")?; first = false; } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index ea641320c01b4..d316550f4dd21 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -21,15 +21,18 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::stack::StackGuard; use datafusion_common::{not_impl_err, Constraints, DFSchema, Result}; -use datafusion_expr::expr::Sort; +use datafusion_expr::expr::{Sort, WildcardOptions}; + use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::{ - CreateMemoryTable, DdlStatement, Distinct, LogicalPlan, LogicalPlanBuilder, + CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ - Expr as SQLExpr, Offset as SQLOffset, OrderBy, OrderByExpr, OrderByKind, Query, - SelectInto, SetExpr, + Expr as SQLExpr, ExprWithAliasAndOrderBy, Ident, LimitClause, Offset, OffsetRows, + OrderBy, OrderByExpr, OrderByKind, PipeOperator, Query, SelectInto, SetExpr, + SetOperator, SetQuantifier, TableAlias, }; +use sqlparser::tokenizer::Span; impl SqlToRel<'_, S> { /// Generate a logical plan from an SQL query/subquery @@ -48,13 +51,12 @@ impl SqlToRel<'_, S> { } let set_expr = *query.body; - match set_expr { + let plan = match set_expr { SetExpr::Select(mut select) => { let select_into = select.into.take(); let plan = self.select_to_plan(*select, query.order_by, planner_context)?; - let plan = - self.limit(plan, query.offset, query.limit, planner_context)?; + let plan = self.limit(plan, query.limit_clause, planner_context)?; // Process the `SELECT INTO` after `LIMIT`. self.select_into(plan, select_into) } @@ -76,32 +78,203 @@ impl SqlToRel<'_, S> { None, )?; let plan = self.order_by(plan, order_by_rex)?; - self.limit(plan, query.offset, query.limit, planner_context) + self.limit(plan, query.limit_clause, planner_context) } + }?; + + self.pipe_operators(plan, query.pipe_operators, planner_context) + } + + /// Apply pipe operators to a plan + fn pipe_operators( + &self, + mut plan: LogicalPlan, + pipe_operators: Vec, + planner_context: &mut PlannerContext, + ) -> Result { + for pipe_operator in pipe_operators { + plan = self.pipe_operator(plan, pipe_operator, planner_context)?; } + Ok(plan) + } + + /// Apply a pipe operator to a plan + fn pipe_operator( + &self, + plan: LogicalPlan, + pipe_operator: PipeOperator, + planner_context: &mut PlannerContext, + ) -> Result { + match pipe_operator { + PipeOperator::Where { expr } => { + self.plan_selection(Some(expr), plan, planner_context) + } + PipeOperator::OrderBy { exprs } => { + let sort_exprs = self.order_by_to_sort_expr( + exprs, + plan.schema(), + planner_context, + true, + None, + )?; + self.order_by(plan, sort_exprs) + } + PipeOperator::Limit { expr, offset } => self.limit( + plan, + Some(LimitClause::LimitOffset { + limit: Some(expr), + offset: offset.map(|offset| Offset { + value: offset, + rows: OffsetRows::None, + }), + limit_by: vec![], + }), + planner_context, + ), + PipeOperator::Select { exprs } => { + let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); + let select_exprs = + self.prepare_select_exprs(&plan, exprs, empty_from, planner_context)?; + self.project(plan, select_exprs) + } + PipeOperator::Extend { exprs } => { + let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); + let extend_exprs = + self.prepare_select_exprs(&plan, exprs, empty_from, planner_context)?; + let all_exprs = + std::iter::once(SelectExpr::Wildcard(WildcardOptions::default())) + .chain(extend_exprs) + .collect(); + self.project(plan, all_exprs) + } + PipeOperator::As { alias } => self.apply_table_alias( + plan, + TableAlias { + name: alias, + // Apply to all fields + columns: vec![], + }, + ), + PipeOperator::Union { + set_quantifier, + queries, + } => self.pipe_operator_set( + plan, + SetOperator::Union, + set_quantifier, + queries, + planner_context, + ), + PipeOperator::Intersect { + set_quantifier, + queries, + } => self.pipe_operator_set( + plan, + SetOperator::Intersect, + set_quantifier, + queries, + planner_context, + ), + PipeOperator::Except { + set_quantifier, + queries, + } => self.pipe_operator_set( + plan, + SetOperator::Except, + set_quantifier, + queries, + planner_context, + ), + PipeOperator::Aggregate { + full_table_exprs, + group_by_expr, + } => self.pipe_operator_aggregate( + plan, + full_table_exprs, + group_by_expr, + planner_context, + ), + PipeOperator::Join(join) => { + self.parse_relation_join(plan, join, planner_context) + } + + x => not_impl_err!("`{x}` pipe operator is not supported yet"), + } + } + + /// Handle Union/Intersect/Except pipe operators + fn pipe_operator_set( + &self, + mut plan: LogicalPlan, + set_operator: SetOperator, + set_quantifier: SetQuantifier, + queries: Vec, + planner_context: &mut PlannerContext, + ) -> Result { + for query in queries { + let right_plan = self.query_to_plan(query, planner_context)?; + plan = self.set_operation_to_plan( + set_operator, + plan, + right_plan, + set_quantifier, + )?; + } + + Ok(plan) } /// Wrap a plan in a limit fn limit( &self, input: LogicalPlan, - skip: Option, - fetch: Option, + limit_clause: Option, planner_context: &mut PlannerContext, ) -> Result { - if skip.is_none() && fetch.is_none() { + let Some(limit_clause) = limit_clause else { return Ok(input); - } + }; - // skip and fetch expressions are not allowed to reference columns from the input plan let empty_schema = DFSchema::empty(); - let skip = skip - .map(|o| self.sql_to_expr(o.value, &empty_schema, planner_context)) - .transpose()?; - let fetch = fetch - .map(|e| self.sql_to_expr(e, &empty_schema, planner_context)) - .transpose()?; + let (skip, fetch, limit_by_exprs) = match limit_clause { + LimitClause::LimitOffset { + limit, + offset, + limit_by, + } => { + let skip = offset + .map(|o| self.sql_to_expr(o.value, &empty_schema, planner_context)) + .transpose()?; + + let fetch = limit + .map(|e| self.sql_to_expr(e, &empty_schema, planner_context)) + .transpose()?; + + let limit_by_exprs = limit_by + .into_iter() + .map(|e| self.sql_to_expr(e, &empty_schema, planner_context)) + .collect::>>()?; + + (skip, fetch, limit_by_exprs) + } + LimitClause::OffsetCommaLimit { offset, limit } => { + let skip = + Some(self.sql_to_expr(offset, &empty_schema, planner_context)?); + let fetch = + Some(self.sql_to_expr(limit, &empty_schema, planner_context)?); + (skip, fetch, vec![]) + } + }; + + if !limit_by_exprs.is_empty() { + return not_impl_err!("LIMIT BY clause is not supported yet"); + } + + if skip.is_none() && fetch.is_none() { + return Ok(input); + } + LogicalPlanBuilder::from(input) .limit_by_expr(skip, fetch)? .build() @@ -127,6 +300,45 @@ impl SqlToRel<'_, S> { } } + /// Handle AGGREGATE pipe operator + fn pipe_operator_aggregate( + &self, + plan: LogicalPlan, + full_table_exprs: Vec, + group_by_expr: Vec, + planner_context: &mut PlannerContext, + ) -> Result { + let plan_schema = plan.schema(); + let process_expr = + |expr_with_alias_and_order_by: ExprWithAliasAndOrderBy, + planner_context: &mut PlannerContext| { + let expr_with_alias = expr_with_alias_and_order_by.expr; + let sql_expr = expr_with_alias.expr; + let alias = expr_with_alias.alias; + + let df_expr = self.sql_to_expr(sql_expr, plan_schema, planner_context)?; + + match alias { + Some(alias_ident) => df_expr.alias_if_changed(alias_ident.value), + None => Ok(df_expr), + } + }; + + let aggr_exprs: Vec = full_table_exprs + .into_iter() + .map(|e| process_expr(e, planner_context)) + .collect::>>()?; + + let group_by_exprs: Vec = group_by_expr + .into_iter() + .map(|e| process_expr(e, planner_context)) + .collect::>>()?; + + LogicalPlanBuilder::from(plan) + .aggregate(group_by_exprs, aggr_exprs)? + .build() + } + /// Wrap the logical plan in a `SelectInto` fn select_into( &self, @@ -137,7 +349,7 @@ impl SqlToRel<'_, S> { Some(into) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { name: self.object_name_to_table_reference(into.name)?, - constraints: Constraints::empty(), + constraints: Constraints::default(), input: Arc::new(plan), if_not_exists: false, or_replace: false, @@ -158,7 +370,7 @@ fn to_order_by_exprs(order_by: Option) -> Result> { /// Returns the order by expressions from the query with the select expressions. pub(crate) fn to_order_by_exprs_with_select( order_by: Option, - _select_exprs: Option<&Vec>, // TODO: ORDER BY ALL + select_exprs: Option<&Vec>, ) -> Result> { let Some(OrderBy { kind, interpolate }) = order_by else { // If no order by, return an empty array. @@ -168,7 +380,30 @@ pub(crate) fn to_order_by_exprs_with_select( return not_impl_err!("ORDER BY INTERPOLATE is not supported"); } match kind { - OrderByKind::All(_) => not_impl_err!("ORDER BY ALL is not supported"), + OrderByKind::All(order_by_options) => { + let Some(exprs) = select_exprs else { + return Ok(vec![]); + }; + let order_by_exprs = exprs + .iter() + .map(|select_expr| match select_expr { + Expr::Column(column) => Ok(OrderByExpr { + expr: SQLExpr::Identifier(Ident { + value: column.name.clone(), + quote_style: None, + span: Span::empty(), + }), + options: order_by_options, + with_fill: None, + }), + // TODO: Support other types of expressions + _ => not_impl_err!( + "ORDER BY ALL is not supported for non-column expressions" + ), + }) + .collect::>>()?; + Ok(order_by_exprs) + } OrderByKind::Expressions(order_by_exprs) => Ok(order_by_exprs), } } diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs index 8a3c20e3971b8..754ded1514a63 100644 --- a/datafusion/sql/src/relation/join.rs +++ b/datafusion/sql/src/relation/join.rs @@ -43,7 +43,7 @@ impl SqlToRel<'_, S> { Ok(left) } - fn parse_relation_join( + pub(crate) fn parse_relation_join( &self, left: LogicalPlan, join: Join, @@ -95,7 +95,9 @@ impl SqlToRel<'_, S> { JoinOperator::FullOuter(constraint) => { self.parse_join(left, right, constraint, JoinType::Full, planner_context) } - JoinOperator::CrossJoin => self.parse_cross_join(left, right), + JoinOperator::CrossJoin(JoinConstraint::None) => { + self.parse_cross_join(left, right) + } other => not_impl_err!("Unsupported JOIN operator {other:?}"), } } @@ -142,7 +144,7 @@ impl SqlToRel<'_, S> { "Expected identifier in USING clause" ) }) - .map(|ident| self.ident_normalizer.normalize(ident.clone())) + .map(|ident| Column::from_name(self.ident_normalizer.normalize(ident.clone()))) } }) .collect::>>()?; diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index dee855f8c0006..9dfa078701d3d 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -57,7 +57,7 @@ impl SqlToRel<'_, S> { planner_context, ) } else { - plan_err!("Unsupported function argument type: {:?}", arg) + plan_err!("Unsupported function argument type: {}", arg) } }) .collect::>(); @@ -66,7 +66,7 @@ impl SqlToRel<'_, S> { .get_table_function_source(&tbl_func_name, args)?; let plan = LogicalPlanBuilder::scan( TableReference::Bare { - table: "tmp_table".into(), + table: format!("{tbl_func_name}()").into(), }, provider, None, @@ -92,7 +92,7 @@ impl SqlToRel<'_, S> { .build(), (None, Err(e)) => { let e = e.with_diagnostic(Diagnostic::new_error( - format!("table '{}' not found", table_ref), + format!("table '{table_ref}' not found"), Span::try_from_sqlparser_span(relation_span), )); Err(e) @@ -154,6 +154,35 @@ impl SqlToRel<'_, S> { "UNNEST table factor with offset is not supported yet" ); } + TableFactor::Function { + name, args, alias, .. + } => { + let tbl_func_ref = self.object_name_to_table_reference(name)?; + let schema = planner_context + .outer_query_schema() + .cloned() + .unwrap_or_else(DFSchema::empty); + let func_args = args + .into_iter() + .map(|arg| match arg { + FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) + | FunctionArg::Named { + arg: FunctionArgExpr::Expr(expr), + .. + } => { + self.sql_expr_to_logical_expr(expr, &schema, planner_context) + } + _ => plan_err!("Unsupported function argument: {arg:?}"), + }) + .collect::>>()?; + let provider = self + .context_provider + .get_table_function_source(tbl_func_ref.table(), func_args)?; + let plan = + LogicalPlanBuilder::scan(tbl_func_ref.table(), provider, None)? + .build()?; + (plan, alias) + } // @todo Support TableFactory::TableFunction? _ => { return not_impl_err!( diff --git a/datafusion/sql/src/resolve.rs b/datafusion/sql/src/resolve.rs index 96012a92c09ad..9e909f66fa97a 100644 --- a/datafusion/sql/src/resolve.rs +++ b/datafusion/sql/src/resolve.rs @@ -78,7 +78,7 @@ impl Visitor for RelationVisitor { if !with.recursive { // This is a bit hackish as the CTE will be visited again as part of visiting `q`, // but thankfully `insert_relation` is idempotent. - cte.visit(self); + let _ = cte.visit(self); } self.ctes_in_scope .push(ObjectName::from(vec![cte.alias.name.clone()])); @@ -143,7 +143,7 @@ fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) { visitor.insert_relation(table_name); } CopyToSource::Query(query) => { - query.visit(visitor); + let _ = query.visit(visitor); } }, DFStatement::Explain(explain) => visit_statement(&explain.statement, visitor), diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 2a2d0b3b3eb8b..42013a76a8657 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -16,6 +16,7 @@ // under the License. use std::collections::HashSet; +use std::ops::ControlFlow; use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; @@ -23,7 +24,7 @@ use crate::query::to_order_by_exprs_with_select; use crate::utils::{ check_columns_satisfy_exprs, extract_aliases, rebase_expr, resolve_aliases_to_exprs, resolve_columns, resolve_positions_to_exprs, rewrite_recursive_unnests_bottom_up, - CheckColumnsSatisfyExprsPurpose, + CheckColumnsMustReferenceAggregatePurpose, CheckColumnsSatisfyExprsPurpose, }; use datafusion_common::error::DataFusionErrorBuilder; @@ -45,8 +46,8 @@ use datafusion_expr::{ use indexmap::IndexMap; use sqlparser::ast::{ - Distinct, Expr as SQLExpr, GroupByExpr, NamedWindowExpr, OrderBy, - SelectItemQualifiedWildcardKind, WildcardAdditionalOptions, WindowType, + visit_expressions_mut, Distinct, Expr as SQLExpr, GroupByExpr, NamedWindowExpr, + OrderBy, SelectItemQualifiedWildcardKind, WildcardAdditionalOptions, WindowType, }; use sqlparser::ast::{NamedWindowDefinition, Select, SelectItem, TableWithJoins}; @@ -65,9 +66,7 @@ impl SqlToRel<'_, S> { if !select.lateral_views.is_empty() { return not_impl_err!("LATERAL VIEWS"); } - if select.qualify.is_some() { - return not_impl_err!("QUALIFY"); - } + if select.top.is_some() { return not_impl_err!("TOP"); } @@ -84,7 +83,8 @@ impl SqlToRel<'_, S> { // Handle named windows before processing the projection expression check_conflicting_windows(&select.named_window)?; - match_window_definitions(&mut select.projection, &select.named_window)?; + self.match_window_definitions(&mut select.projection, &select.named_window)?; + // Process the SELECT expressions let select_exprs = self.prepare_select_exprs( &base_plan, @@ -93,13 +93,13 @@ impl SqlToRel<'_, S> { planner_context, )?; - let order_by = - to_order_by_exprs_with_select(query_order_by, Some(&select_exprs))?; - // Having and group by clause may reference aliases defined in select projection let projected_plan = self.project(base_plan.clone(), select_exprs)?; let select_exprs = projected_plan.expressions(); + let order_by = + to_order_by_exprs_with_select(query_order_by, Some(&select_exprs))?; + // Place the fields of the base plan at the front so that when there are references // with the same name, the fields of the base plan will be searched first. // See https://github.com/apache/datafusion/issues/9162 @@ -147,12 +147,6 @@ impl SqlToRel<'_, S> { }) .transpose()?; - // The outer expressions we will search through for aggregates. - // Aggregates may be sourced from the SELECT list or from the HAVING expression. - let aggr_expr_haystack = select_exprs.iter().chain(having_expr_opt.iter()); - // All of the aggregate expressions (deduplicated). - let aggr_exprs = find_aggregate_exprs(aggr_expr_haystack); - // All of the group by expressions let group_by_exprs = if let GroupByExpr::Expressions(exprs, _) = select.group_by { exprs @@ -197,22 +191,61 @@ impl SqlToRel<'_, S> { .collect() }; + // Optionally the QUALIFY expression. + let qualify_expr_opt = select + .qualify + .map::, _>(|qualify_expr| { + let qualify_expr = self.sql_expr_to_logical_expr( + qualify_expr, + &combined_schema, + planner_context, + )?; + // This step "dereferences" any aliases in the QUALIFY clause. + // + // This is how we support queries with QUALIFY expressions that + // refer to aliased columns. + // + // For example: + // + // select row_number() over (PARTITION BY id) as rk from users qualify rk > 1; + // + // are rewritten as, respectively: + // + // select row_number() over (PARTITION BY id) as rk from users qualify row_number() over (PARTITION BY id) > 1; + // + let qualify_expr = resolve_aliases_to_exprs(qualify_expr, &alias_map)?; + normalize_col(qualify_expr, &projected_plan) + }) + .transpose()?; + + // The outer expressions we will search through for aggregates. + // Aggregates may be sourced from the SELECT list or from the HAVING expression. + let aggr_expr_haystack = select_exprs + .iter() + .chain(having_expr_opt.iter()) + .chain(qualify_expr_opt.iter()); + // All of the aggregate expressions (deduplicated). + let aggr_exprs = find_aggregate_exprs(aggr_expr_haystack); + // Process group by, aggregation or having - let (plan, mut select_exprs_post_aggr, having_expr_post_aggr) = if !group_by_exprs - .is_empty() - || !aggr_exprs.is_empty() - { + let ( + plan, + mut select_exprs_post_aggr, + having_expr_post_aggr, + qualify_expr_post_aggr, + ) = if !group_by_exprs.is_empty() || !aggr_exprs.is_empty() { self.aggregate( &base_plan, &select_exprs, having_expr_opt.as_ref(), + qualify_expr_opt.as_ref(), &group_by_exprs, &aggr_exprs, )? } else { match having_expr_opt { Some(having_expr) => return plan_err!("HAVING clause references: {having_expr} must appear in the GROUP BY clause or be used in an aggregate function"), - None => (base_plan.clone(), select_exprs.clone(), having_expr_opt) + None => (base_plan.clone(), select_exprs.clone(), having_expr_opt, qualify_expr_opt) } }; @@ -224,9 +257,17 @@ impl SqlToRel<'_, S> { plan }; - // Process window function - let window_func_exprs = find_window_exprs(&select_exprs_post_aggr); + // The outer expressions we will search through for window functions. + // Window functions may be sourced from the SELECT list or from the QUALIFY expression. + let windows_expr_haystack = select_exprs_post_aggr + .iter() + .chain(qualify_expr_post_aggr.iter()); + // All of the window expressions (deduplicated and rewritten to reference aggregates as + // columns from input). + let window_func_exprs = find_window_exprs(windows_expr_haystack); + // Process window functions after aggregation as they can reference + // aggregate functions in their body let plan = if window_func_exprs.is_empty() { plan } else { @@ -241,6 +282,39 @@ impl SqlToRel<'_, S> { plan }; + // Process QUALIFY clause after window functions + // QUALIFY filters the results of window functions, similar to how HAVING filters aggregates + let plan = if let Some(qualify_expr) = qualify_expr_post_aggr { + // Validate that QUALIFY is used with window functions + if window_func_exprs.is_empty() { + return plan_err!( + "QUALIFY clause requires window functions in the SELECT list or QUALIFY clause" + ); + } + + // now attempt to resolve columns and replace with fully-qualified columns + let windows_projection_exprs = window_func_exprs + .iter() + .map(|expr| resolve_columns(expr, &plan)) + .collect::>>()?; + + // Rewrite the qualify expression to reference columns from the window plan + let qualify_expr_post_window = + rebase_expr(&qualify_expr, &windows_projection_exprs, &plan)?; + + // Validate that the qualify expression can be resolved from the window plan schema + self.validate_schema_satisfies_exprs( + plan.schema(), + std::slice::from_ref(&qualify_expr_post_window), + )?; + + LogicalPlanBuilder::from(plan) + .filter(qualify_expr_post_window)? + .build()? + } else { + plan + }; + // Try processing unnest expression or do the final projection let plan = self.try_process_unnest(plan, select_exprs_post_aggr)?; @@ -306,6 +380,15 @@ impl SqlToRel<'_, S> { let mut intermediate_plan = input; let mut intermediate_select_exprs = select_exprs; + // Fast path: If there is are no unnests in the select_exprs, wrap the plan in a projection + if !intermediate_select_exprs + .iter() + .any(has_unnest_expr_recursively) + { + return LogicalPlanBuilder::from(intermediate_plan) + .project(intermediate_select_exprs)? + .build(); + } // Each expr in select_exprs can contains multiple unnest stage // The transformation happen bottom up, one at a time for each iteration @@ -373,6 +456,12 @@ impl SqlToRel<'_, S> { fn try_process_aggregate_unnest(&self, input: LogicalPlan) -> Result { match input { + // Fast path if there are no unnest in group by + LogicalPlan::Aggregate(ref agg) + if !&agg.group_expr.iter().any(has_unnest_expr_recursively) => + { + Ok(input) + } LogicalPlan::Aggregate(agg) => { let agg_expr = agg.aggr_expr.clone(); let (new_input, new_group_by_exprs) = @@ -496,7 +585,7 @@ impl SqlToRel<'_, S> { Ok((intermediate_plan, intermediate_select_exprs)) } - fn plan_selection( + pub(crate) fn plan_selection( &self, selection: Option, plan: LogicalPlan, @@ -577,7 +666,7 @@ impl SqlToRel<'_, S> { } /// Returns the `Expr`'s corresponding to a SQL query's SELECT expressions. - fn prepare_select_exprs( + pub(crate) fn prepare_select_exprs( &self, plan: &LogicalPlan, projection: Vec, @@ -586,6 +675,7 @@ impl SqlToRel<'_, S> { ) -> Result> { let mut prepared_select_exprs = vec![]; let mut error_builder = DataFusionErrorBuilder::new(); + for expr in projection { match self.sql_select_to_rex(expr, plan, empty_from, planner_context) { Ok(expr) => prepared_select_exprs.push(expr), @@ -736,7 +826,11 @@ impl SqlToRel<'_, S> { } /// Wrap a plan in a projection - fn project(&self, input: LogicalPlan, expr: Vec) -> Result { + pub(crate) fn project( + &self, + input: LogicalPlan, + expr: Vec, + ) -> Result { // convert to Expr for validate_schema_satisfies_exprs let exprs = expr .iter() @@ -752,36 +846,42 @@ impl SqlToRel<'_, S> { /// Create an aggregate plan. /// - /// An aggregate plan consists of grouping expressions, aggregate expressions, and an - /// optional HAVING expression (which is a filter on the output of the aggregate). + /// An aggregate plan consists of grouping expressions, aggregate expressions, an + /// optional HAVING expression (which is a filter on the output of the aggregate), + /// and an optional QUALIFY clause which may reference aggregates. /// /// # Arguments /// /// * `input` - The input plan that will be aggregated. The grouping, aggregate, and - /// "having" expressions must all be resolvable from this plan. + /// "having" expressions must all be resolvable from this plan. /// * `select_exprs` - The projection expressions from the SELECT clause. /// * `having_expr_opt` - Optional HAVING clause. + /// * `qualify_expr_opt` - Optional QUALIFY clause. /// * `group_by_exprs` - Grouping expressions from the GROUP BY clause. These can be column - /// references or more complex expressions. + /// references or more complex expressions. /// * `aggr_exprs` - Aggregate expressions, such as `SUM(a)` or `COUNT(1)`. /// /// # Return /// - /// The return value is a triplet of the following items: + /// The return value is a quadruplet of the following items: /// /// * `plan` - A [LogicalPlan::Aggregate] plan for the newly created aggregate. /// * `select_exprs_post_aggr` - The projection expressions rewritten to reference columns from - /// the aggregate + /// the aggregate /// * `having_expr_post_aggr` - The "having" expression rewritten to reference a column from - /// the aggregate + /// the aggregate + /// * `qualify_expr_post_aggr` - The "qualify" expression rewritten to reference a column from + /// the aggregate + #[allow(clippy::type_complexity)] fn aggregate( &self, input: &LogicalPlan, select_exprs: &[Expr], having_expr_opt: Option<&Expr>, + qualify_expr_opt: Option<&Expr>, group_by_exprs: &[Expr], aggr_exprs: &[Expr], - ) -> Result<(LogicalPlan, Vec, Option)> { + ) -> Result<(LogicalPlan, Vec, Option, Option)> { // create the aggregate plan let options = LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true); @@ -845,7 +945,9 @@ impl SqlToRel<'_, S> { check_columns_satisfy_exprs( &column_exprs_post_aggr, &select_exprs_post_aggr, - CheckColumnsSatisfyExprsPurpose::ProjectionMustReferenceAggregate, + CheckColumnsSatisfyExprsPurpose::Aggregate( + CheckColumnsMustReferenceAggregatePurpose::Projection, + ), )?; // Rewrite the HAVING expression to use the columns produced by the @@ -857,7 +959,9 @@ impl SqlToRel<'_, S> { check_columns_satisfy_exprs( &column_exprs_post_aggr, std::slice::from_ref(&having_expr_post_aggr), - CheckColumnsSatisfyExprsPurpose::HavingMustReferenceAggregate, + CheckColumnsSatisfyExprsPurpose::Aggregate( + CheckColumnsMustReferenceAggregatePurpose::Having, + ), )?; Some(having_expr_post_aggr) @@ -865,7 +969,86 @@ impl SqlToRel<'_, S> { None }; - Ok((plan, select_exprs_post_aggr, having_expr_post_aggr)) + // Rewrite the QUALIFY expression to use the columns produced by the + // aggregation. + let qualify_expr_post_aggr = if let Some(qualify_expr) = qualify_expr_opt { + let qualify_expr_post_aggr = + rebase_expr(qualify_expr, &aggr_projection_exprs, input)?; + + check_columns_satisfy_exprs( + &column_exprs_post_aggr, + std::slice::from_ref(&qualify_expr_post_aggr), + CheckColumnsSatisfyExprsPurpose::Aggregate( + CheckColumnsMustReferenceAggregatePurpose::Qualify, + ), + )?; + + Some(qualify_expr_post_aggr) + } else { + None + }; + + Ok(( + plan, + select_exprs_post_aggr, + having_expr_post_aggr, + qualify_expr_post_aggr, + )) + } + + // If the projection is done over a named window, that window + // name must be defined. Otherwise, it gives an error. + fn match_window_definitions( + &self, + projection: &mut [SelectItem], + named_windows: &[NamedWindowDefinition], + ) -> Result<()> { + let named_windows: Vec<(&NamedWindowDefinition, String)> = named_windows + .iter() + .map(|w| (w, self.ident_normalizer.normalize(w.0.clone()))) + .collect(); + for proj in projection.iter_mut() { + if let SelectItem::ExprWithAlias { expr, alias: _ } + | SelectItem::UnnamedExpr(expr) = proj + { + let mut err = None; + let _ = visit_expressions_mut(expr, |expr| { + if let SQLExpr::Function(f) = expr { + if let Some(WindowType::NamedWindow(ident)) = &f.over { + let normalized_ident = + self.ident_normalizer.normalize(ident.clone()); + for ( + NamedWindowDefinition(_, window_expr), + normalized_window_ident, + ) in named_windows.iter() + { + if normalized_ident.eq(normalized_window_ident) { + f.over = Some(match window_expr { + NamedWindowExpr::NamedWindow(ident) => { + WindowType::NamedWindow(ident.clone()) + } + NamedWindowExpr::WindowSpec(spec) => { + WindowType::WindowSpec(spec.clone()) + } + }) + } + } + // All named windows must be defined with a WindowSpec. + if let Some(WindowType::NamedWindow(ident)) = &f.over { + err = + Some(plan_err!("The window {ident} is not defined!")); + return ControlFlow::Break(()); + } + } + } + ControlFlow::Continue(()) + }); + if let Some(err) = err { + return err; + } + } + } + Ok(()) } } @@ -884,38 +1067,16 @@ fn check_conflicting_windows(window_defs: &[NamedWindowDefinition]) -> Result<() Ok(()) } -// If the projection is done over a named window, that window -// name must be defined. Otherwise, it gives an error. -fn match_window_definitions( - projection: &mut [SelectItem], - named_windows: &[NamedWindowDefinition], -) -> Result<()> { - for proj in projection.iter_mut() { - if let SelectItem::ExprWithAlias { - expr: SQLExpr::Function(f), - alias: _, - } - | SelectItem::UnnamedExpr(SQLExpr::Function(f)) = proj - { - for NamedWindowDefinition(window_ident, window_expr) in named_windows.iter() { - if let Some(WindowType::NamedWindow(ident)) = &f.over { - if ident.eq(window_ident) { - f.over = Some(match window_expr { - NamedWindowExpr::NamedWindow(ident) => { - WindowType::NamedWindow(ident.clone()) - } - NamedWindowExpr::WindowSpec(spec) => { - WindowType::WindowSpec(spec.clone()) - } - }) - } - } - } - // All named windows must be defined with a WindowSpec. - if let Some(WindowType::NamedWindow(ident)) = &f.over { - return plan_err!("The window {ident} is not defined!"); - } +/// Returns true if the expression recursively contains an `Expr::Unnest` expression +fn has_unnest_expr_recursively(expr: &Expr) -> bool { + let mut has_unnest = false; + let _ = expr.apply(|e| { + if let Expr::Unnest(_) = e { + has_unnest = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) } - } - Ok(()) + }); + has_unnest } diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs index 272d6f874b4d6..5b65e1c045bdc 100644 --- a/datafusion/sql/src/set_expr.rs +++ b/datafusion/sql/src/set_expr.rs @@ -95,26 +95,22 @@ impl SqlToRel<'_, S> { if left_plan.schema().fields().len() == right_plan.schema().fields().len() { return Ok(()); } - - plan_err!("{} queries have different number of columns", op).map_err(|err| { - err.with_diagnostic( - Diagnostic::new_error( - format!("{} queries have different number of columns", op), - set_expr_span, - ) - .with_note( - format!("this side has {} fields", left_plan.schema().fields().len()), - left_span, - ) - .with_note( - format!( - "this side has {} fields", - right_plan.schema().fields().len() - ), - right_span, - ), - ) - }) + let diagnostic = Diagnostic::new_error( + format!("{op} queries have different number of columns"), + set_expr_span, + ) + .with_note( + format!("this side has {} fields", left_plan.schema().fields().len()), + left_span, + ) + .with_note( + format!( + "this side has {} fields", + right_plan.schema().fields().len() + ), + right_span, + ); + plan_err!("{} queries have different number of columns", op; diagnostic =diagnostic) } pub(super) fn set_operation_to_plan( diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index fc6cb0d32feff..0e868e8c26899 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -55,16 +55,16 @@ use datafusion_expr::{ Volatility, WriteOp, }; use sqlparser::ast::{ - self, BeginTransactionKind, NullsDistinctOption, ShowStatementIn, - ShowStatementOptions, SqliteOnConflict, TableObject, UpdateTableFromKind, - ValueWithSpan, + self, BeginTransactionKind, IndexColumn, IndexType, NullsDistinctOption, OrderByExpr, + OrderByOptions, Set, ShowStatementIn, ShowStatementOptions, SqliteOnConflict, + TableObject, UpdateTableFromKind, ValueWithSpan, }; use sqlparser::ast::{ Assignment, AssignmentTarget, ColumnDef, CreateIndex, CreateTable, CreateTableOptions, Delete, DescribeAlias, Expr as SQLExpr, FromTable, Ident, Insert, - ObjectName, ObjectType, OneOrManyWithParens, Query, SchemaName, SetExpr, - ShowCreateObject, ShowStatementFilter, Statement, TableConstraint, TableFactor, - TableWithJoins, TransactionMode, UnaryOperator, Value, + ObjectName, ObjectType, Query, SchemaName, SetExpr, ShowCreateObject, + ShowStatementFilter, Statement, TableConstraint, TableFactor, TableWithJoins, + TransactionMode, UnaryOperator, Value, }; use sqlparser::parser::ParserError::ParserError; @@ -111,7 +111,17 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec constraints.push(TableConstraint::Unique { name: name.clone(), - columns: vec![column.name.clone()], + columns: vec![IndexColumn { + column: OrderByExpr { + expr: SQLExpr::Identifier(column.name.clone()), + options: OrderByOptions { + asc: None, + nulls_first: None, + }, + with_fill: None, + }, + operator_class: None, + }], characteristics: *characteristics, index_name: None, index_type_display: ast::KeyOrIndexDisplay::None, @@ -124,7 +134,17 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec constraints.push(TableConstraint::PrimaryKey { name: name.clone(), - columns: vec![column.name.clone()], + columns: vec![IndexColumn { + column: OrderByExpr { + expr: SQLExpr::Identifier(column.name.clone()), + options: OrderByOptions { + asc: None, + nulls_first: None, + }, + with_fill: None, + }, + operator_class: None, + }], characteristics: *characteristics, index_name: None, index_type: None, @@ -144,11 +164,13 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec { constraints.push(TableConstraint::Check { name: name.clone(), expr: Box::new(expr.clone()), + enforced: None, }) } // Other options are not constraint related. @@ -168,6 +190,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec {} } } @@ -215,7 +238,7 @@ impl SqlToRel<'_, S> { ) -> Result { match statement { Statement::ExplainTable { - describe_alias: DescribeAlias::Describe, // only parse 'DESCRIBE table_name' and not 'EXPLAIN table_name' + describe_alias: DescribeAlias::Describe | DescribeAlias::Desc, // only parse 'DESCRIBE table_name' or 'DESC table_name' and not 'EXPLAIN table_name' table_name, .. } => self.describe_table_to_plan(table_name), @@ -233,13 +256,7 @@ impl SqlToRel<'_, S> { } Statement::Query(query) => self.query_to_plan(*query, planner_context), Statement::ShowVariable { variable } => self.show_variable_to_plan(&variable), - Statement::SetVariable { - local, - hivevar, - variables, - value, - } => self.set_variable_to_plan(local, hivevar, &variables, value), - + Statement::Set(statement) => self.set_statement_to_plan(statement), Statement::CreateTable(CreateTable { temporary, external, @@ -254,18 +271,12 @@ impl SqlToRel<'_, S> { name, columns, constraints, - table_properties, - with_options, if_not_exists, or_replace, without_rowid, like, clone, - engine, comment, - auto_increment_offset, - default_charset, - collation, on_commit, on_cluster, primary_key, @@ -273,7 +284,6 @@ impl SqlToRel<'_, S> { partition_by, cluster_by, clustered_by, - options, strict, copy_grants, enable_schema_evolution, @@ -290,7 +300,16 @@ impl SqlToRel<'_, S> { catalog, catalog_sync, storage_serialization_policy, - }) if table_properties.is_empty() && with_options.is_empty() => { + inherits, + table_options: CreateTableOptions::None, + dynamic, + version, + target_lag, + warehouse, + refresh_mode, + initialize, + require_user, + }) => { if temporary { return not_impl_err!("Temporary tables not supported")?; } @@ -339,21 +358,9 @@ impl SqlToRel<'_, S> { if clone.is_some() { return not_impl_err!("Clone not supported")?; } - if engine.is_some() { - return not_impl_err!("Engine not supported")?; - } if comment.is_some() { return not_impl_err!("Comment not supported")?; } - if auto_increment_offset.is_some() { - return not_impl_err!("Auto increment offset not supported")?; - } - if default_charset.is_some() { - return not_impl_err!("Default charset not supported")?; - } - if collation.is_some() { - return not_impl_err!("Collation not supported")?; - } if on_commit.is_some() { return not_impl_err!("On commit not supported")?; } @@ -375,9 +382,6 @@ impl SqlToRel<'_, S> { if clustered_by.is_some() { return not_impl_err!("Clustered by not supported")?; } - if options.is_some() { - return not_impl_err!("Options not supported")?; - } if strict { return not_impl_err!("Strict not supported")?; } @@ -428,7 +432,30 @@ impl SqlToRel<'_, S> { if storage_serialization_policy.is_some() { return not_impl_err!("Storage serialization policy not supported")?; } - + if inherits.is_some() { + return not_impl_err!("Table inheritance not supported")?; + } + if dynamic { + return not_impl_err!("Dynamic tables not supported")?; + } + if version.is_some() { + return not_impl_err!("Version not supported")?; + } + if target_lag.is_some() { + return not_impl_err!("Target lag not supported")?; + } + if warehouse.is_some() { + return not_impl_err!("Warehouse not supported")?; + } + if refresh_mode.is_some() { + return not_impl_err!("Refresh mode not supported")?; + } + if initialize.is_some() { + return not_impl_err!("Initialize not supported")?; + } + if require_user { + return not_impl_err!("Require user not supported")?; + } // Merge inline constraints and existing constraints let mut all_constraints = constraints; let inline_constraints = calc_inline_constraints_from_columns(&columns); @@ -451,10 +478,10 @@ impl SqlToRel<'_, S> { let plan = if has_columns { if schema.fields().len() != input_schema.fields().len() { return plan_err!( - "Mismatch: {} columns specified, but result has {} columns", - schema.fields().len(), - input_schema.fields().len() - ); + "Mismatch: {} columns specified, but result has {} columns", + schema.fields().len(), + input_schema.fields().len() + ); } let input_fields = input_schema.fields(); let project_exprs = schema @@ -519,7 +546,6 @@ impl SqlToRel<'_, S> { } } } - Statement::CreateView { or_replace, materialized, @@ -534,6 +560,9 @@ impl SqlToRel<'_, S> { temporary, to, params, + or_alter, + secure, + name_before_not_exists, } => { if materialized { return not_impl_err!("Materialized views not supported")?; @@ -570,6 +599,9 @@ impl SqlToRel<'_, S> { temporary, to, params, + or_alter, + secure, + name_before_not_exists, }; let sql = stmt.to_string(); let Statement::CreateView { @@ -617,6 +649,7 @@ impl SqlToRel<'_, S> { Statement::CreateSchema { schema_name, if_not_exists, + .. } => Ok(LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema( CreateCatalogSchema { schema_name: get_schema_name(&schema_name), @@ -643,6 +676,7 @@ impl SqlToRel<'_, S> { restrict: _, purge: _, temporary: _, + table: _, } => { // We don't support cascade and purge for now. // nor do we support multiple object names @@ -696,7 +730,7 @@ impl SqlToRel<'_, S> { statement, } => { // Convert parser data types to DataFusion data types - let data_types: Vec = data_types + let mut data_types: Vec = data_types .into_iter() .map(|t| self.convert_data_type(&t)) .collect::>()?; @@ -710,6 +744,19 @@ impl SqlToRel<'_, S> { *statement, &mut planner_context, )?; + + if data_types.is_empty() { + let map_types = plan.get_parameter_types()?; + let param_types: Vec<_> = (1..=map_types.len()) + .filter_map(|i| { + let key = format!("${i}"); + map_types.get(&key).and_then(|opt| opt.clone()) + }) + .collect(); + data_types.extend(param_types.iter().cloned()); + planner_context.with_prepare_param_data_types(param_types); + } + Ok(LogicalPlan::Statement(PlanStatement::Prepare(Prepare { name: ident_to_string(&name), data_types, @@ -725,6 +772,8 @@ impl SqlToRel<'_, S> { has_parentheses: _, immediate, into, + output, + default, } => { // `USING` is a MySQL-specific syntax and currently not supported. if !using.is_empty() { @@ -740,6 +789,16 @@ impl SqlToRel<'_, S> { if !into.is_empty() { return not_impl_err!("Execute statement with INTO is not supported"); } + if output { + return not_impl_err!( + "Execute statement with OUTPUT is not supported" + ); + } + if default { + return not_impl_err!( + "Execute statement with DEFAULT is not supported" + ); + } let empty_schema = DFSchema::empty(); let parameters = parameters .into_iter() @@ -947,23 +1006,27 @@ impl SqlToRel<'_, S> { selection, returning, or, + limit, } => { - let froms = + let from_clauses = from.map(|update_table_from_kind| match update_table_from_kind { - UpdateTableFromKind::BeforeSet(froms) => froms, - UpdateTableFromKind::AfterSet(froms) => froms, + UpdateTableFromKind::BeforeSet(from_clauses) => from_clauses, + UpdateTableFromKind::AfterSet(from_clauses) => from_clauses, }); // TODO: support multiple tables in UPDATE SET FROM - if froms.as_ref().is_some_and(|f| f.len() > 1) { + if from_clauses.as_ref().is_some_and(|f| f.len() > 1) { plan_err!("Multiple tables in UPDATE SET FROM not yet supported")?; } - let update_from = froms.and_then(|mut f| f.pop()); + let update_from = from_clauses.and_then(|mut f| f.pop()); if returning.is_some() { plan_err!("Update-returning clause not yet supported")?; } if or.is_some() { plan_err!("ON conflict not supported")?; } + if limit.is_some() { + return not_impl_err!("Update-limit clause not supported")?; + } self.update_to_plan(table, assignments, update_from, selection) } @@ -1006,8 +1069,8 @@ impl SqlToRel<'_, S> { modifier, transaction, statements, - exception_statements, has_end_keyword, + exception, } => { if let Some(modifier) = modifier { return not_impl_err!( @@ -1019,7 +1082,7 @@ impl SqlToRel<'_, S> { "Transaction with multiple statements not supported" ); } - if exception_statements.is_some() { + if exception.is_some() { return not_impl_err!( "Transaction with exception statements not supported" ); @@ -1034,7 +1097,7 @@ impl SqlToRel<'_, S> { TransactionMode::AccessMode(_) => None, TransactionMode::IsolationLevel(level) => Some(level), }) - .last() + .next_back() .copied() .unwrap_or(ast::TransactionIsolationLevel::Serializable); let access_mode: ast::TransactionAccessMode = modes @@ -1043,7 +1106,7 @@ impl SqlToRel<'_, S> { TransactionMode::AccessMode(mode) => Some(mode), TransactionMode::IsolationLevel(_) => None, }) - .last() + .next_back() .copied() .unwrap_or(ast::TransactionAccessMode::ReadWrite); let isolation_level = match isolation_level { @@ -1169,6 +1232,17 @@ impl SqlToRel<'_, S> { ast::CreateFunctionBody::AsBeforeOptions(expr) => expr, ast::CreateFunctionBody::AsAfterOptions(expr) => expr, ast::CreateFunctionBody::Return(expr) => expr, + ast::CreateFunctionBody::AsBeginEnd(_) => { + return not_impl_err!( + "BEGIN/END enclosed function body syntax is not supported" + )?; + } + ast::CreateFunctionBody::AsReturnExpr(_) + | ast::CreateFunctionBody::AsReturnSelect(_) => { + return not_impl_err!( + "AS RETURN function syntax is not supported" + )? + } }, &DFSchema::empty(), &mut planner_context, @@ -1238,9 +1312,15 @@ impl SqlToRel<'_, S> { .get_table_source(table.clone())? .schema() .to_dfschema_ref()?; - let using: Option = using.as_ref().map(ident_to_string); + let using: Option = + using.as_ref().map(|index_type| match index_type { + IndexType::Custom(ident) => ident_to_string(ident), + _ => index_type.to_string().to_ascii_lowercase(), + }); + let order_by_exprs: Vec = + columns.into_iter().map(|col| col.column).collect(); let columns = self.order_by_to_sort_expr( - columns, + order_by_exprs, &table_schema, planner_context, false, @@ -1340,11 +1420,7 @@ impl SqlToRel<'_, S> { let options_map = self.parse_options_map(statement.options, true)?; let maybe_file_type = if let Some(stored_as) = &statement.stored_as { - if let Ok(ext_file_type) = self.context_provider.get_file_type(stored_as) { - Some(ext_file_type) - } else { - None - } + self.context_provider.get_file_type(stored_as).ok() } else { None }; @@ -1379,13 +1455,13 @@ impl SqlToRel<'_, S> { .map(|f| f.name().to_owned()) .collect(); - Ok(LogicalPlan::Copy(CopyTo { - input: Arc::new(input), - output_url: statement.target, - file_type, + Ok(LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + statement.target, partition_by, - options: options_map, - })) + file_type, + options_map, + ))) } fn build_order_by( @@ -1403,23 +1479,23 @@ impl SqlToRel<'_, S> { .map(|order_by_expr| { let ordered_expr = &order_by_expr.expr; let ordered_expr = ordered_expr.to_owned(); - let ordered_expr = self - .sql_expr_to_logical_expr( - ordered_expr, - schema, - planner_context, - ) - .unwrap(); + let ordered_expr = self.sql_expr_to_logical_expr( + ordered_expr, + schema, + planner_context, + )?; let asc = order_by_expr.options.asc.unwrap_or(true); let nulls_first = - order_by_expr.options.nulls_first.unwrap_or(!asc); + order_by_expr.options.nulls_first.unwrap_or_else(|| { + self.options.default_null_ordering.nulls_first(asc) + }); - SortExpr::new(ordered_expr, asc, nulls_first) + Ok(SortExpr::new(ordered_expr, asc, nulls_first)) }) - .collect::>(); - result + .collect::>>()?; + Ok(result) }) - .collect::>>(); + .collect::>>>()?; return Ok(results); } @@ -1462,6 +1538,7 @@ impl SqlToRel<'_, S> { unbounded, options, constraints, + or_replace, } = statement; // Merge inline constraints and existing constraints @@ -1510,6 +1587,7 @@ impl SqlToRel<'_, S> { file_type, table_partition_cols, if_not_exists, + or_replace, temporary, definition, order_exprs: ordered_exprs, @@ -1526,13 +1604,21 @@ impl SqlToRel<'_, S> { fn get_constraint_column_indices( &self, df_schema: &DFSchemaRef, - columns: &[Ident], + columns: &[IndexColumn], constraint_name: &str, ) -> Result> { let field_names = df_schema.field_names(); columns .iter() - .map(|ident| { + .map(|index_column| { + let expr = &index_column.column.expr; + let ident = if let SQLExpr::Identifier(ident) = expr { + ident + } else { + return Err(plan_datafusion_err!( + "Column name for {constraint_name} must be an identifier: {expr}" + )); + }; let column = self.ident_normalizer.normalize(ident.clone()); field_names .iter() @@ -1547,7 +1633,7 @@ impl SqlToRel<'_, S> { } /// Convert each [TableConstraint] to corresponding [Constraint] - fn new_constraint_from_table_constraints( + pub fn new_constraint_from_table_constraints( &self, constraints: &[TableConstraint], df_schema: &DFSchemaRef, @@ -1613,7 +1699,7 @@ impl SqlToRel<'_, S> { // If config does not belong to any namespace, assume it is // a format option and apply the format prefix for backwards // compatibility. - let renamed_key = format!("format.{}", key); + let renamed_key = format!("format.{key}"); options_map.insert(renamed_key.to_lowercase(), value_string); } else { options_map.insert(key.to_lowercase(), value_string); @@ -1661,10 +1747,15 @@ impl SqlToRel<'_, S> { vec![plan.to_stringified(PlanType::InitialLogicalPlan)]; // default to configuration value + // verbose mode only supports indent format let options = self.context_provider.options(); - let format = format.as_ref().unwrap_or(&options.explain.format); - - let format: ExplainFormat = format.parse()?; + let format = if verbose { + ExplainFormat::Indent + } else if let Some(format) = format { + ExplainFormat::from_str(&format)? + } else { + options.explain.format.clone() + }; Ok(LogicalPlan::Explain(Explain { verbose, @@ -1730,64 +1821,58 @@ impl SqlToRel<'_, S> { self.statement_to_plan(rewrite.pop_front().unwrap()) } - fn set_variable_to_plan( - &self, - local: bool, - hivevar: bool, - variables: &OneOrManyWithParens, - value: Vec, - ) -> Result { - if local { - return not_impl_err!("LOCAL is not supported"); - } - - if hivevar { - return not_impl_err!("HIVEVAR is not supported"); - } + fn set_statement_to_plan(&self, statement: Set) -> Result { + match statement { + Set::SingleAssignment { + scope, + hivevar, + variable, + values, + } => { + if scope.is_some() { + return not_impl_err!("SET with scope modifiers is not supported"); + } - let variable = match variables { - OneOrManyWithParens::One(v) => object_name_to_string(v), - OneOrManyWithParens::Many(vs) => { - return not_impl_err!( - "SET only supports single variable assignment: {vs:?}" - ); - } - }; - let mut variable_lower = variable.to_lowercase(); + if hivevar { + return not_impl_err!("SET HIVEVAR is not supported"); + } - if variable_lower == "timezone" || variable_lower == "time.zone" { - // We could introduce alias in OptionDefinition if this string matching thing grows - variable_lower = "datafusion.execution.time_zone".to_string(); - } + let variable = object_name_to_string(&variable); + let mut variable_lower = variable.to_lowercase(); - // Parse value string from Expr - let value_string = match &value[0] { - SQLExpr::Identifier(i) => ident_to_string(i), - SQLExpr::Value(v) => match crate::utils::value_to_string(&v.value) { - None => { - return plan_err!("Unsupported Value {}", value[0]); + if variable_lower == "timezone" || variable_lower == "time.zone" { + variable_lower = "datafusion.execution.time_zone".to_string(); } - Some(v) => v, - }, - // For capture signed number e.g. +8, -8 - SQLExpr::UnaryOp { op, expr } => match op { - UnaryOperator::Plus => format!("+{expr}"), - UnaryOperator::Minus => format!("-{expr}"), - _ => { - return plan_err!("Unsupported Value {}", value[0]); + + if values.len() != 1 { + return plan_err!("SET only supports single value assignment"); } - }, - _ => { - return plan_err!("Unsupported Value {}", value[0]); - } - }; - let statement = PlanStatement::SetVariable(SetVariable { - variable: variable_lower, - value: value_string, - }); + let value_string = match &values[0] { + SQLExpr::Identifier(i) => ident_to_string(i), + SQLExpr::Value(v) => match crate::utils::value_to_string(&v.value) { + None => { + return plan_err!("Unsupported value {:?}", v.value); + } + Some(s) => s, + }, + SQLExpr::UnaryOp { op, expr } => match op { + UnaryOperator::Plus => format!("+{expr}"), + UnaryOperator::Minus => format!("-{expr}"), + _ => return plan_err!("Unsupported unary op {:?}", op), + }, + _ => return plan_err!("Unsupported expr {:?}", values[0]), + }; - Ok(LogicalPlan::Statement(statement)) + Ok(LogicalPlan::Statement(PlanStatement::SetVariable( + SetVariable { + variable: variable_lower, + value: value_string, + }, + ))) + } + other => not_impl_err!("SET variant not implemented yet: {other:?}"), + } } fn delete_to_plan( @@ -1798,7 +1883,10 @@ impl SqlToRel<'_, S> { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(table_name.clone())?; let table_source = self.context_provider.get_table_source(table_ref.clone())?; - let schema = table_source.schema().to_dfschema_ref()?; + let schema = DFSchema::try_from_qualified_schema( + table_ref.clone(), + &table_source.schema(), + )?; let scan = LogicalPlanBuilder::scan(table_ref.clone(), Arc::clone(&table_source), None)? .build()?; @@ -1956,8 +2044,7 @@ impl SqlToRel<'_, S> { // Do a table lookup to verify the table exists let table_name = self.object_name_to_table_reference(table_name)?; let table_source = self.context_provider.get_table_source(table_name.clone())?; - let arrow_schema = (*table_source.schema()).clone(); - let table_schema = DFSchema::try_from(arrow_schema)?; + let table_schema = DFSchema::try_from(table_source.schema())?; // Get insert fields and target table's value indices // @@ -1978,9 +2065,9 @@ impl SqlToRel<'_, S> { let mut value_indices = vec![None; table_schema.fields().len()]; let fields = columns .into_iter() - .map(|c| self.ident_normalizer.normalize(c)) .enumerate() .map(|(i, c)| { + let c = self.ident_normalizer.normalize(c); let column_index = table_schema .index_of_column_by_name(None, &c) .ok_or_else(|| unqualified_field_not_found(&c, &table_schema))?; @@ -2053,7 +2140,7 @@ impl SqlToRel<'_, S> { .cloned() .unwrap_or_else(|| { // If there is no default for the column, then the default is NULL - Expr::Literal(ScalarValue::Null) + Expr::Literal(ScalarValue::Null, None) }) .cast_to(target_field.data_type(), &DFSchema::empty())?, }; diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index 6fcc203637cc3..2cf26009ac0f2 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -19,7 +19,9 @@ use core::fmt; use std::ops::ControlFlow; use sqlparser::ast::helpers::attached_token::AttachedToken; -use sqlparser::ast::{self, visit_expressions_mut, OrderByKind, SelectFlavor}; +use sqlparser::ast::{ + self, visit_expressions_mut, LimitClause, OrderByKind, SelectFlavor, +}; #[derive(Clone)] pub struct QueryBuilder { @@ -32,6 +34,8 @@ pub struct QueryBuilder { fetch: Option, locks: Vec, for_clause: Option, + // If true, we need to unparse LogicalPlan::Union as a SQL `UNION` rather than a `UNION ALL`. + distinct_union: bool, } #[allow(dead_code)] @@ -75,6 +79,13 @@ impl QueryBuilder { self.for_clause = value; self } + pub fn distinct_union(&mut self) -> &mut Self { + self.distinct_union = true; + self + } + pub fn is_distinct_union(&self) -> bool { + self.distinct_union + } pub fn build(&self) -> Result { let order_by = self .order_by_kind @@ -91,14 +102,17 @@ impl QueryBuilder { None => return Err(Into::into(UninitializedFieldError::from("body"))), }, order_by, - limit: self.limit.clone(), - limit_by: self.limit_by.clone(), - offset: self.offset.clone(), + limit_clause: Some(LimitClause::LimitOffset { + limit: self.limit.clone(), + offset: self.offset.clone(), + limit_by: self.limit_by.clone(), + }), fetch: self.fetch.clone(), locks: self.locks.clone(), for_clause: self.for_clause.clone(), settings: None, format_clause: None, + pipe_operators: vec![], }) } fn create_empty() -> Self { @@ -112,6 +126,7 @@ impl QueryBuilder { fetch: Default::default(), locks: Default::default(), for_clause: Default::default(), + distinct_union: false, } } } @@ -133,7 +148,7 @@ pub struct SelectBuilder { group_by: Option, cluster_by: Vec, distribute_by: Vec, - sort_by: Vec, + sort_by: Vec, having: Option, named_window: Vec, qualify: Option, @@ -155,6 +170,11 @@ impl SelectBuilder { self.projection = value; self } + pub fn pop_projections(&mut self) -> Vec { + let ret = self.projection.clone(); + self.projection.clear(); + ret + } pub fn already_projected(&self) -> bool { !self.projection.is_empty() } @@ -198,7 +218,7 @@ impl SelectBuilder { value: &ast::Expr, ) -> &mut Self { if let Some(selection) = &mut self.selection { - visit_expressions_mut(selection, |expr| { + let _ = visit_expressions_mut(selection, |expr| { if expr == existing_expr { *expr = value.clone(); } @@ -245,7 +265,7 @@ impl SelectBuilder { self.distribute_by = value; self } - pub fn sort_by(&mut self, value: Vec) -> &mut Self { + pub fn sort_by(&mut self, value: Vec) -> &mut Self { self.sort_by = value; self } @@ -300,6 +320,7 @@ impl SelectBuilder { Some(ref value) => value.clone(), None => return Err(Into::into(UninitializedFieldError::from("flavor"))), }, + exclude: None, }) } fn create_empty() -> Self { @@ -383,6 +404,7 @@ pub struct RelationBuilder { #[allow(dead_code)] #[derive(Clone)] +#[allow(clippy::large_enum_variant)] enum TableFactorBuilder { Table(TableRelationBuilder), Derived(DerivedRelationBuilder), @@ -690,9 +712,9 @@ impl fmt::Display for BuilderError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::UninitializedField(ref field) => { - write!(f, "`{}` must be initialized", field) + write!(f, "`{field}` must be initialized") } - Self::ValidationError(ref error) => write!(f, "{}", error), + Self::ValidationError(ref error) => write!(f, "{error}"), } } } diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 05914b98f55f0..647ad680674b0 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -17,8 +17,13 @@ use std::{collections::HashMap, sync::Arc}; -use super::{utils::character_length_to_sql, utils::date_part_to_sql, Unparser}; +use super::{ + utils::character_length_to_sql, utils::date_part_to_sql, + utils::sqlite_date_trunc_to_sql, utils::sqlite_from_unixtime_to_sql, Unparser, +}; +use arrow::array::timezone::Tz; use arrow::datatypes::TimeUnit; +use chrono::DateTime; use datafusion_common::Result; use datafusion_expr::Expr; use regex::Regex; @@ -194,6 +199,18 @@ pub trait Dialect: Send + Sync { fn unnest_as_table_factor(&self) -> bool { false } + + /// Allows the dialect to override column alias unparsing if the dialect has specific rules. + /// Returns None if the default unparsing should be used, or Some(String) if there is + /// a custom implementation for the alias. + fn col_alias_overrides(&self, _alias: &str) -> Result> { + Ok(None) + } + + /// Allows the dialect to override logic of formatting datetime with tz into string. + fn timestamp_with_tz_to_string(&self, dt: DateTime, _unit: TimeUnit) -> String { + dt.to_string() + } } /// `IntervalStyle` to use for unparsing @@ -391,6 +408,17 @@ impl Dialect for DuckDBDialect { Ok(None) } + + fn timestamp_with_tz_to_string(&self, dt: DateTime, unit: TimeUnit) -> String { + let format = match unit { + TimeUnit::Second => "%Y-%m-%d %H:%M:%S%:z", + TimeUnit::Millisecond => "%Y-%m-%d %H:%M:%S%.3f%:z", + TimeUnit::Microsecond => "%Y-%m-%d %H:%M:%S%.6f%:z", + TimeUnit::Nanosecond => "%Y-%m-%d %H:%M:%S%.9f%:z", + }; + + dt.format(format).to_string() + } } pub struct MySqlDialect {} @@ -477,6 +505,14 @@ impl Dialect for SqliteDialect { false } + fn timestamp_cast_dtype( + &self, + _time_unit: &TimeUnit, + _tz: &Option>, + ) -> ast::DataType { + ast::DataType::Text + } + fn scalar_function_to_sql_overrides( &self, unparser: &Unparser, @@ -490,11 +526,56 @@ impl Dialect for SqliteDialect { "character_length" => { character_length_to_sql(unparser, self.character_length_style(), args) } + "from_unixtime" => sqlite_from_unixtime_to_sql(unparser, args), + "date_trunc" => sqlite_date_trunc_to_sql(unparser, args), _ => Ok(None), } } } +#[derive(Default)] +pub struct BigQueryDialect {} + +impl Dialect for BigQueryDialect { + fn identifier_quote_style(&self, _: &str) -> Option { + Some('`') + } + + fn col_alias_overrides(&self, alias: &str) -> Result> { + // Check if alias contains any special characters not supported by BigQuery col names + // https://cloud.google.com/bigquery/docs/schemas#flexible-column-names + let special_chars: [char; 20] = [ + '!', '"', '$', '(', ')', '*', ',', '.', '/', ';', '?', '@', '[', '\\', ']', + '^', '`', '{', '}', '~', + ]; + + if alias.chars().any(|c| special_chars.contains(&c)) { + let mut encoded_name = String::new(); + for c in alias.chars() { + if special_chars.contains(&c) { + encoded_name.push_str(&format!("_{}", c as u32)); + } else { + encoded_name.push(c); + } + } + Ok(Some(encoded_name)) + } else { + Ok(Some(alias.to_string())) + } + } + + fn unnest_as_table_factor(&self) -> bool { + true + } +} + +impl BigQueryDialect { + #[must_use] + pub fn new() -> Self { + Self {} + } +} + pub struct CustomDialect { identifier_quote_style: Option, supports_nulls_first_in_sort: bool, @@ -548,17 +629,6 @@ impl Default for CustomDialect { } } -impl CustomDialect { - // Create a CustomDialect - #[deprecated(since = "41.0.0", note = "please use `CustomDialectBuilder` instead")] - pub fn new(identifier_quote_style: Option) -> Self { - Self { - identifier_quote_style, - ..Default::default() - } - } -} - impl Dialect for CustomDialect { fn identifier_quote_style(&self, _: &str) -> Option { self.identifier_quote_style diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 064adde55bdfd..a7fe8efa153c9 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -18,8 +18,9 @@ use datafusion_expr::expr::{AggregateFunctionParams, Unnest, WindowFunctionParams}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Array, BinaryOperator, CaseWhen, Expr as AstExpr, Function, Ident, Interval, - ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, ValueWithSpan, + self, Array, BinaryOperator, CaseWhen, DuplicateTreatment, Expr as AstExpr, Function, + Ident, Interval, ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, + ValueWithSpan, }; use std::sync::Arc; use std::vec; @@ -34,7 +35,9 @@ use arrow::array::{ }, ArrayRef, Date32Array, Date64Array, PrimitiveArray, }; -use arrow::datatypes::{DataType, Decimal128Type, Decimal256Type, DecimalType}; +use arrow::datatypes::{ + DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, DecimalType, +}; use arrow::util::display::array_value_to_string; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, Column, Result, @@ -182,24 +185,29 @@ impl Unparser<'_> { operand, conditions, else_result, + case_token: AttachedToken::empty(), + end_token: AttachedToken::empty(), }) } Expr::Cast(Cast { expr, data_type }) => { Ok(self.cast_to_sql(expr, data_type)?) } - Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), + Expr::Literal(value, _) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - .. - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + filter, + distinct, + .. + }, + } = window_fun.as_ref(); let func_name = fun.name(); let args = self.function_args_to_sql(args)?; @@ -255,11 +263,15 @@ impl Unparser<'_> { span: Span::empty(), }]), args: ast::FunctionArguments::List(ast::FunctionArgumentList { - duplicate_treatment: None, + duplicate_treatment: distinct + .then_some(DuplicateTreatment::Distinct), args, clauses: vec![], }), - filter: None, + filter: filter + .as_ref() + .map(|f| self.expr_to_sql_inner(f).map(Box::new)) + .transpose()?, null_treatment: None, over, within_group: vec![], @@ -273,26 +285,48 @@ impl Unparser<'_> { pattern, escape_char, case_insensitive: _, - }) - | Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, }) => Ok(ast::Expr::Like { negated: *negated, expr: Box::new(self.expr_to_sql_inner(expr)?), pattern: Box::new(self.expr_to_sql_inner(pattern)?), - escape_char: escape_char.map(|c| c.to_string()), + escape_char: escape_char.map(|c| SingleQuotedString(c.to_string())), any: false, }), + Expr::Like(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }) => { + if *case_insensitive { + Ok(ast::Expr::ILike { + negated: *negated, + expr: Box::new(self.expr_to_sql_inner(expr)?), + pattern: Box::new(self.expr_to_sql_inner(pattern)?), + escape_char: escape_char + .map(|c| SingleQuotedString(c.to_string())), + any: false, + }) + } else { + Ok(ast::Expr::Like { + negated: *negated, + expr: Box::new(self.expr_to_sql_inner(expr)?), + pattern: Box::new(self.expr_to_sql_inner(pattern)?), + escape_char: escape_char + .map(|c| SingleQuotedString(c.to_string())), + any: false, + }) + } + } + Expr::AggregateFunction(agg) => { let func_name = agg.func.name(); let AggregateFunctionParams { distinct, args, filter, + order_by, .. } = &agg.params; @@ -301,6 +335,15 @@ impl Unparser<'_> { Some(filter) => Some(Box::new(self.expr_to_sql_inner(filter)?)), None => None, }; + let within_group: Vec = + if agg.func.is_ordered_set_aggregate() { + order_by + .iter() + .map(|sort_expr| self.sort_to_sql(sort_expr)) + .collect::>>()? + } else { + Vec::new() + }; Ok(ast::Expr::Function(Function { name: ObjectName::from(vec![Ident { value: func_name.to_string(), @@ -309,14 +352,14 @@ impl Unparser<'_> { }]), args: ast::FunctionArguments::List(ast::FunctionArgumentList { duplicate_treatment: distinct - .then_some(ast::DuplicateTreatment::Distinct), + .then_some(DuplicateTreatment::Distinct), args, clauses: vec![], }), filter, null_treatment: None, over: None, - within_group: vec![], + within_group, parameters: ast::FunctionArguments::None, uses_odbc_syntax: false, })) @@ -563,7 +606,7 @@ impl Unparser<'_> { } fn named_struct_to_sql(&self, args: &[Expr]) -> Result { - if args.len() % 2 != 0 { + if !args.len().is_multiple_of(2) { return internal_err!("named_struct must have an even number of arguments"); } @@ -571,7 +614,7 @@ impl Unparser<'_> { .chunks_exact(2) .map(|chunk| { let key = match &chunk[0] { - Expr::Literal(ScalarValue::Utf8(Some(s))) => self.new_ident_quoted_if_needs(s.to_string()), + Expr::Literal(ScalarValue::Utf8(Some(s)), _) => self.new_ident_quoted_if_needs(s.to_string()), _ => return internal_err!("named_struct expects even arguments to be strings, but received: {:?}", &chunk[0]) }; @@ -590,27 +633,43 @@ impl Unparser<'_> { return internal_err!("get_field must have exactly 2 arguments"); } - let mut id = match &args[0] { - Expr::Column(col) => match self.col_to_sql(col)? { - ast::Expr::Identifier(ident) => vec![ident], - ast::Expr::CompoundIdentifier(idents) => idents, - other => return internal_err!("expected col_to_sql to return an Identifier or CompoundIdentifier, but received: {:?}", other), - }, - _ => return internal_err!("get_field expects first argument to be column, but received: {:?}", &args[0]), - }; - let field = match &args[1] { - Expr::Literal(lit) => self.new_ident_quoted_if_needs(lit.to_string()), + Expr::Literal(lit, _) => self.new_ident_quoted_if_needs(lit.to_string()), _ => { return internal_err!( "get_field expects second argument to be a string, but received: {:?}", - &args[0] + &args[1] ) } }; - id.push(field); - Ok(ast::Expr::CompoundIdentifier(id)) + match &args[0] { + Expr::Column(col) => { + let mut id = match self.col_to_sql(col)? { + ast::Expr::Identifier(ident) => vec![ident], + ast::Expr::CompoundIdentifier(idents) => idents, + other => return internal_err!("expected col_to_sql to return an Identifier or CompoundIdentifier, but received: {:?}", other), + }; + id.push(field); + Ok(ast::Expr::CompoundIdentifier(id)) + } + Expr::ScalarFunction(struct_expr) => { + let root = self + .scalar_function_to_sql(struct_expr.func.name(), &struct_expr.args)?; + Ok(ast::Expr::CompoundFieldAccess { + root: Box::new(root), + access_chain: vec![ast::AccessExpr::Dot(ast::Expr::Identifier( + field, + ))], + }) + } + _ => { + internal_err!( + "get_field expects first argument to be column or scalar function, but received: {:?}", + &args[0] + ) + } + } } fn map_to_sql(&self, args: &[Expr]) -> Result { @@ -679,13 +738,21 @@ impl Unparser<'_> { } pub fn col_to_sql(&self, col: &Column) -> Result { + // Replace the column name if the dialect has an override + let col_name = + if let Some(rewritten_name) = self.dialect.col_alias_overrides(&col.name)? { + rewritten_name + } else { + col.name.to_string() + }; + if let Some(table_ref) = &col.relation { let mut id = if self.dialect.full_qualified_col() { table_ref.to_vec() } else { vec![table_ref.table().to_string()] }; - id.push(col.name.to_string()); + id.push(col_name); return Ok(ast::Expr::CompoundIdentifier( id.iter() .map(|i| self.new_ident_quoted_if_needs(i.to_string())) @@ -693,7 +760,7 @@ impl Unparser<'_> { )); } Ok(ast::Expr::Identifier( - self.new_ident_quoted_if_needs(col.name.to_string()), + self.new_ident_quoted_if_needs(col_name), )) } @@ -997,8 +1064,19 @@ impl Unparser<'_> { where i64: From, { + let time_unit = match T::DATA_TYPE { + DataType::Timestamp(unit, _) => unit, + _ => { + return Err(internal_datafusion_err!( + "Expected Timestamp, got {:?}", + T::DATA_TYPE + )) + } + }; + let ts = if let Some(tz) = tz { - v.to_array()? + let dt = v + .to_array()? .as_any() .downcast_ref::>() .ok_or(internal_datafusion_err!( @@ -1007,8 +1085,8 @@ impl Unparser<'_> { .value_as_datetime_with_tz(0, tz.parse()?) .ok_or(internal_datafusion_err!( "Unable to convert {v:?} to DateTime" - ))? - .to_string() + ))?; + self.dialect.timestamp_with_tz_to_string(dt, time_unit) } else { v.to_array()? .as_any() @@ -1023,16 +1101,6 @@ impl Unparser<'_> { .to_string() }; - let time_unit = match T::DATA_TYPE { - DataType::Timestamp(unit, _) => unit, - _ => { - return Err(internal_datafusion_err!( - "Expected Timestamp, got {:?}", - T::DATA_TYPE - )) - } - }; - Ok(ast::Expr::Cast { kind: ast::CastKind::Cast, expr: Box::new(ast::Expr::value(SingleQuotedString(ts))), @@ -1103,20 +1171,34 @@ impl Unparser<'_> { ScalarValue::Float16(None) => Ok(ast::Expr::value(ast::Value::Null)), ScalarValue::Float32(Some(f)) => { let f_val = match f.fract() { - 0.0 => format!("{:.1}", f), - _ => format!("{}", f), + 0.0 => format!("{f:.1}"), + _ => format!("{f}"), }; Ok(ast::Expr::value(ast::Value::Number(f_val, false))) } ScalarValue::Float32(None) => Ok(ast::Expr::value(ast::Value::Null)), ScalarValue::Float64(Some(f)) => { let f_val = match f.fract() { - 0.0 => format!("{:.1}", f), - _ => format!("{}", f), + 0.0 => format!("{f:.1}"), + _ => format!("{f}"), }; Ok(ast::Expr::value(ast::Value::Number(f_val, false))) } ScalarValue::Float64(None) => Ok(ast::Expr::value(ast::Value::Null)), + ScalarValue::Decimal32(Some(value), precision, scale) => { + Ok(ast::Expr::value(ast::Value::Number( + Decimal32Type::format_decimal(*value, *precision, *scale), + false, + ))) + } + ScalarValue::Decimal32(None, ..) => Ok(ast::Expr::value(ast::Value::Null)), + ScalarValue::Decimal64(Some(value), precision, scale) => { + Ok(ast::Expr::value(ast::Value::Number( + Decimal64Type::format_decimal(*value, *precision, *scale), + false, + ))) + } + ScalarValue::Decimal64(None, ..) => Ok(ast::Expr::value(ast::Value::Null)), ScalarValue::Decimal128(Some(value), precision, scale) => { Ok(ast::Expr::value(ast::Value::Number( Decimal128Type::format_decimal(*value, *precision, *scale), @@ -1593,7 +1675,7 @@ impl Unparser<'_> { fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result { match data_type { DataType::Null => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::Boolean => Ok(ast::DataType::Bool), DataType::Int8 => Ok(ast::DataType::TinyInt(None)), @@ -1605,9 +1687,9 @@ impl Unparser<'_> { DataType::UInt32 => Ok(ast::DataType::IntegerUnsigned(None)), DataType::UInt64 => Ok(ast::DataType::BigIntUnsigned(None)), DataType::Float16 => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } - DataType::Float32 => Ok(ast::DataType::Float(None)), + DataType::Float32 => Ok(ast::DataType::Float(ast::ExactNumberInfo::None)), DataType::Float64 => Ok(self.dialect.float64_ast_dtype()), DataType::Timestamp(time_unit, tz) => { Ok(self.dialect.timestamp_cast_dtype(time_unit, tz)) @@ -1615,53 +1697,58 @@ impl Unparser<'_> { DataType::Date32 => Ok(self.dialect.date32_cast_dtype()), DataType::Date64 => Ok(self.ast_type_for_date64_in_cast()), DataType::Time32(_) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::Time64(_) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::Duration(_) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } - DataType::Interval(_) => Ok(ast::DataType::Interval), + DataType::Interval(_) => Ok(ast::DataType::Interval { + fields: None, + precision: None, + }), DataType::Binary => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::FixedSizeBinary(_) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::LargeBinary => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::BinaryView => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::Utf8 => Ok(self.dialect.utf8_cast_dtype()), DataType::LargeUtf8 => Ok(self.dialect.large_utf8_cast_dtype()), DataType::Utf8View => Ok(self.dialect.utf8_cast_dtype()), DataType::List(_) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::FixedSizeList(_, _) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::LargeList(_) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::ListView(_) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::LargeListView(_) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::Struct(_) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::Union(_, _) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::Dictionary(_, val) => self.arrow_dtype_to_ast_dtype(val), - DataType::Decimal128(precision, scale) + DataType::Decimal32(precision, scale) + | DataType::Decimal64(precision, scale) + | DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => { let mut new_precision = *precision as u64; let mut new_scale = *scale as u64; @@ -1671,14 +1758,17 @@ impl Unparser<'_> { } Ok(ast::DataType::Decimal( - ast::ExactNumberInfo::PrecisionAndScale(new_precision, new_scale), + ast::ExactNumberInfo::PrecisionAndScale( + new_precision, + new_scale as i64, + ), )) } DataType::Map(_, _) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } DataType::RunEndEncoded(_, _) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + not_impl_err!("Unsupported DataType: conversion: {data_type}") } } } @@ -1689,6 +1779,7 @@ mod tests { use std::ops::{Add, Sub}; use std::{any::Any, sync::Arc, vec}; + use crate::unparser::dialect::SqliteDialect; use arrow::array::{LargeListArray, ListArray}; use arrow::datatypes::{DataType::Int8, Field, Int32Type, Schema, TimeUnit}; use ast::ObjectName; @@ -1701,6 +1792,7 @@ mod tests { ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, }; use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; + use datafusion_functions::datetime::from_unixtime::FromUnixtimeFunc; use datafusion_functions::expr_fn::{get_field, named_struct}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::sum; @@ -1712,13 +1804,13 @@ mod tests { use crate::unparser::dialect::{ CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, - Dialect, DuckDBDialect, PostgreSqlDialect, ScalarFnToSqlHandler, + DefaultDialect, Dialect, DuckDBDialect, PostgreSqlDialect, ScalarFnToSqlHandler, }; use super::*; /// Mocked UDF - #[derive(Debug)] + #[derive(Debug, PartialEq, Eq, Hash)] struct DummyUDF { signature: Signature, } @@ -1853,10 +1945,20 @@ mod tests { expr: Box::new(col("a")), pattern: Box::new(lit("foo")), escape_char: Some('o'), - case_insensitive: true, + case_insensitive: false, }), r#"a NOT LIKE 'foo' ESCAPE 'o'"#, ), + ( + Expr::Like(Like { + negated: true, + expr: Box::new(col("a")), + pattern: Box::new(lit("foo")), + escape_char: Some('o'), + case_insensitive: true, + }), + r#"a NOT ILIKE 'foo' ESCAPE 'o'"#, + ), ( Expr::SimilarTo(Like { negated: false, @@ -1868,87 +1970,87 @@ mod tests { r#"a LIKE 'foo' ESCAPE 'o'"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(0))), + Expr::Literal(ScalarValue::Date64(Some(0)), None), r#"CAST('1970-01-01 00:00:00' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(10000))), + Expr::Literal(ScalarValue::Date64(Some(10000)), None), r#"CAST('1970-01-01 00:00:10' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(-10000))), + Expr::Literal(ScalarValue::Date64(Some(-10000)), None), r#"CAST('1969-12-31 23:59:50' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(0))), + Expr::Literal(ScalarValue::Date32(Some(0)), None), r#"CAST('1970-01-01' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(10))), + Expr::Literal(ScalarValue::Date32(Some(10)), None), r#"CAST('1970-01-11' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(-1))), + Expr::Literal(ScalarValue::Date32(Some(-1)), None), r#"CAST('1969-12-31' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::TimestampSecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampSecond(Some(10001), None), None), r#"CAST('1970-01-01 02:46:41' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampSecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampSecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 10:46:41 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMillisecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampMillisecond(Some(10001), None), None), r#"CAST('1970-01-01 00:00:10.001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMillisecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampMillisecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 08:00:10.001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMicrosecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampMicrosecond(Some(10001), None), None), r#"CAST('1970-01-01 00:00:00.010001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMicrosecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampMicrosecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 08:00:00.010001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampNanosecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampNanosecond(Some(10001), None), None), r#"CAST('1970-01-01 00:00:00.000010001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampNanosecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampNanosecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 08:00:00.000010001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::Time32Second(Some(10001))), + Expr::Literal(ScalarValue::Time32Second(Some(10001)), None), r#"CAST('02:46:41' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time32Millisecond(Some(10001))), + Expr::Literal(ScalarValue::Time32Millisecond(Some(10001)), None), r#"CAST('00:00:10.001' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time64Microsecond(Some(10001))), + Expr::Literal(ScalarValue::Time64Microsecond(Some(10001)), None), r#"CAST('00:00:00.010001' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time64Nanosecond(Some(10001))), + Expr::Literal(ScalarValue::Time64Nanosecond(Some(10001)), None), r#"CAST('00:00:00.000010001' AS TIME)"#, ), (sum(col("a")), r#"sum(a)"#), @@ -1977,7 +2079,7 @@ mod tests { "count(*) FILTER (WHERE true)", ), ( - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun: WindowFunctionDefinition::WindowUDF(row_number_udwf()), params: WindowFunctionParams { args: vec![col("col")], @@ -1985,13 +2087,15 @@ mod tests { order_by: vec![], window_frame: WindowFrame::new(None), null_treatment: None, + distinct: false, + filter: None, }, }), r#"row_number(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#, ), ( #[expect(deprecated)] - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), params: WindowFunctionParams { args: vec![Expr::Wildcard { @@ -2010,9 +2114,11 @@ mod tests { ), ), null_treatment: None, + distinct: false, + filter: Some(Box::new(col("a").gt(lit(100)))), }, }), - r#"count(*) OVER (ORDER BY a DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#, + r#"count(*) FILTER (WHERE (a > 100)) OVER (ORDER BY a DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#, ), (col("a").is_not_null(), r#"a IS NOT NULL"#), (col("a").is_null(), r#"a IS NULL"#), @@ -2093,19 +2199,31 @@ mod tests { (col("need quoted").eq(lit(1)), r#"("need quoted" = 1)"#), // See test_interval_scalar_to_expr for interval literals ( - (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal128( - Some(100123), - 28, - 3, - ))), + (col("a") + col("b")).gt(Expr::Literal( + ScalarValue::Decimal32(Some(1123), 4, 3), + None, + )), + r#"((a + b) > 1.123)"#, + ), + ( + (col("a") + col("b")).gt(Expr::Literal( + ScalarValue::Decimal64(Some(1123), 4, 3), + None, + )), + r#"((a + b) > 1.123)"#, + ), + ( + (col("a") + col("b")).gt(Expr::Literal( + ScalarValue::Decimal128(Some(100123), 28, 3), + None, + )), r#"((a + b) > 100.123)"#, ), ( - (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal256( - Some(100123.into()), - 28, - 3, - ))), + (col("a") + col("b")).gt(Expr::Literal( + ScalarValue::Decimal256(Some(100123.into()), 28, 3), + None, + )), r#"((a + b) > 100.123)"#, ), ( @@ -2141,28 +2259,39 @@ mod tests { "MAP {'a': 1, 'b': 2}", ), ( - Expr::Literal(ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(Some("foo".into()))), - )), + Expr::Literal( + ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Utf8(Some("foo".into()))), + ), + None, + ), "'foo'", ), ( - Expr::Literal(ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![ + Expr::Literal( + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::< + Int32Type, + _, + _, + >(vec![Some(vec![ Some(1), Some(2), Some(3), - ])]), - ))), + ])]))), + None, + ), "[1, 2, 3]", ), ( - Expr::Literal(ScalarValue::LargeList(Arc::new( - LargeListArray::from_iter_primitive::(vec![Some( - vec![Some(1), Some(2), Some(3)], - )]), - ))), + Expr::Literal( + ScalarValue::LargeList(Arc::new( + LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + ]), + )), + None, + ), "[1, 2, 3]", ), ( @@ -2186,7 +2315,7 @@ mod tests { for (expr, expected) in tests { let ast = expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); assert_eq!(actual, expected); } @@ -2204,7 +2333,7 @@ mod tests { let expr = col("a").gt(lit(4)); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"('a' > 4)"#; assert_eq!(actual, expected); @@ -2220,7 +2349,7 @@ mod tests { let expr = col("a").gt(lit(4)); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"(a > 4)"#; assert_eq!(actual, expected); @@ -2244,7 +2373,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2269,7 +2398,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2291,7 +2420,7 @@ mod tests { let unparser = Unparser::new(&dialect); let ast = unparser.sort_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); assert_eq!(actual, expected); } @@ -2467,11 +2596,17 @@ mod tests { #[test] fn test_float_scalar_to_expr() { let tests = [ - (Expr::Literal(ScalarValue::Float64(Some(3f64))), "3.0"), - (Expr::Literal(ScalarValue::Float64(Some(3.1f64))), "3.1"), - (Expr::Literal(ScalarValue::Float32(Some(-2f32))), "-2.0"), + (Expr::Literal(ScalarValue::Float64(Some(3f64)), None), "3.0"), + ( + Expr::Literal(ScalarValue::Float64(Some(3.1f64)), None), + "3.1", + ), ( - Expr::Literal(ScalarValue::Float32(Some(-2.989f32))), + Expr::Literal(ScalarValue::Float32(Some(-2f32)), None), + "-2.0", + ), + ( + Expr::Literal(ScalarValue::Float32(Some(-2.989f32)), None), "-2.989", ), ]; @@ -2491,18 +2626,20 @@ mod tests { let tests = [ ( Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( - "blah".to_string(), - )))), + expr: Box::new(Expr::Literal( + ScalarValue::Utf8(Some("blah".to_string())), + None, + )), data_type: DataType::Binary, }), "'blah'", ), ( Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( - "blah".to_string(), - )))), + expr: Box::new(Expr::Literal( + ScalarValue::Utf8(Some("blah".to_string())), + None, + )), data_type: DataType::BinaryView, }), "'blah'", @@ -2541,7 +2678,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2594,10 +2731,13 @@ mod tests { let expr = ScalarUDF::new_from_impl( datafusion_functions::datetime::date_part::DatePartFunc::new(), ) - .call(vec![Expr::Literal(ScalarValue::new_utf8(unit)), col("x")]); + .call(vec![ + Expr::Literal(ScalarValue::new_utf8(unit), None), + col("x"), + ]); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); assert_eq!(actual, expected); } @@ -2624,7 +2764,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2652,7 +2792,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2691,7 +2831,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2714,13 +2854,13 @@ mod tests { (&mysql_dialect, "DATETIME"), ] { let unparser = Unparser::new(dialect); - let expr = Expr::Literal(ScalarValue::TimestampMillisecond( - Some(1738285549123), + let expr = Expr::Literal( + ScalarValue::TimestampMillisecond(Some(1738285549123), None), None, - )); + ); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST('2025-01-31 01:05:49.123' AS {identifier})"#); assert_eq!(actual, expected); @@ -2747,7 +2887,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2773,7 +2913,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = expected.to_string(); assert_eq!(actual, expected); @@ -2785,9 +2925,10 @@ mod tests { fn test_cast_value_to_dict_expr() { let tests = [( Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( - "variation".to_string(), - )))), + expr: Box::new(Expr::Literal( + ScalarValue::Utf8(Some("variation".to_string())), + None, + )), data_type: DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), }), "'variation'", @@ -2825,12 +2966,12 @@ mod tests { expr: Box::new(col("a")), data_type: DataType::Float64, }), - Expr::Literal(ScalarValue::Int64(Some(2))), + Expr::Literal(ScalarValue::Int64(Some(2)), None), ], }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"round(CAST("a" AS {identifier}), 2)"#); assert_eq!(actual, expected); @@ -2860,7 +3001,116 @@ mod tests { let func = WindowFunctionDefinition::WindowUDF(rank_udwf()); let mut window_func = WindowFunction::new(func, vec![]); window_func.params.order_by = vec![Sort::new(col("a"), true, true)]; - let expr = Expr::WindowFunction(window_func); + let expr = Expr::from(window_func); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = ast.to_string(); + let expected = expected.to_string(); + + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn test_from_unixtime() -> Result<()> { + let default_dialect: Arc = Arc::new(DefaultDialect {}); + let sqlite_dialect: Arc = Arc::new(SqliteDialect {}); + + for (dialect, expected) in [ + (default_dialect, "from_unixtime(date_col)"), + (sqlite_dialect, "datetime(`date_col`, 'unixepoch')"), + ] { + let unparser = Unparser::new(dialect.as_ref()); + let expr = Expr::ScalarFunction(ScalarFunction { + func: Arc::new(ScalarUDF::from(FromUnixtimeFunc::new())), + args: vec![col("date_col")], + }); + + let ast = unparser.expr_to_sql(&expr)?; + + let actual = ast.to_string(); + let expected = expected.to_string(); + + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn test_date_trunc() -> Result<()> { + let default_dialect: Arc = Arc::new(DefaultDialect {}); + let sqlite_dialect: Arc = Arc::new(SqliteDialect {}); + + for (dialect, precision, expected) in [ + ( + Arc::clone(&default_dialect), + "YEAR", + "date_trunc('YEAR', date_col)", + ), + ( + Arc::clone(&sqlite_dialect), + "YEAR", + "strftime('%Y', `date_col`)", + ), + ( + Arc::clone(&default_dialect), + "MONTH", + "date_trunc('MONTH', date_col)", + ), + ( + Arc::clone(&sqlite_dialect), + "MONTH", + "strftime('%Y-%m', `date_col`)", + ), + ( + Arc::clone(&default_dialect), + "DAY", + "date_trunc('DAY', date_col)", + ), + ( + Arc::clone(&sqlite_dialect), + "DAY", + "strftime('%Y-%m-%d', `date_col`)", + ), + ( + Arc::clone(&default_dialect), + "HOUR", + "date_trunc('HOUR', date_col)", + ), + ( + Arc::clone(&sqlite_dialect), + "HOUR", + "strftime('%Y-%m-%d %H', `date_col`)", + ), + ( + Arc::clone(&default_dialect), + "MINUTE", + "date_trunc('MINUTE', date_col)", + ), + ( + Arc::clone(&sqlite_dialect), + "MINUTE", + "strftime('%Y-%m-%d %H:%M', `date_col`)", + ), + (default_dialect, "SECOND", "date_trunc('SECOND', date_col)"), + ( + sqlite_dialect, + "SECOND", + "strftime('%Y-%m-%d %H:%M:%S', `date_col`)", + ), + ] { + let unparser = Unparser::new(dialect.as_ref()); + let expr = Expr::ScalarFunction(ScalarFunction { + func: Arc::new(ScalarUDF::from( + datafusion_functions::datetime::date_trunc::DateTruncFunc::new(), + )), + args: vec![ + Expr::Literal(ScalarValue::Utf8(Some(precision.to_string())), None), + col("date_col"), + ], + }); + let ast = unparser.expr_to_sql(&expr)?; let actual = ast.to_string(); @@ -2901,7 +3151,7 @@ mod tests { let expr = cast(col("a"), DataType::Utf8View); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"CAST(a AS CHAR)"#.to_string(); assert_eq!(actual, expected); @@ -2909,7 +3159,7 @@ mod tests { let expr = col("a").eq(lit(ScalarValue::Utf8View(Some("hello".to_string())))); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"(a = 'hello')"#.to_string(); assert_eq!(actual, expected); @@ -2917,7 +3167,7 @@ mod tests { let expr = col("a").is_not_null(); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"a IS NOT NULL"#.to_string(); assert_eq!(actual, expected); @@ -2925,7 +3175,7 @@ mod tests { let expr = col("a").is_null(); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"a IS NULL"#.to_string(); assert_eq!(actual, expected); @@ -2956,4 +3206,101 @@ mod tests { Ok(()) } + + #[test] + fn test_cast_timestamp_sqlite() -> Result<()> { + let dialect: Arc = Arc::new(SqliteDialect {}); + + let unparser = Unparser::new(dialect.as_ref()); + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Timestamp(TimeUnit::Nanosecond, None), + }); + + let ast = unparser.expr_to_sql(&expr)?; + + let actual = ast.to_string(); + let expected = "CAST(`a` AS TEXT)".to_string(); + + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn test_timestamp_with_tz_format() -> Result<()> { + let default_dialect: Arc = + Arc::new(CustomDialectBuilder::new().build()); + + let duckdb_dialect: Arc = Arc::new(DuckDBDialect::new()); + + for (dialect, scalar, expected) in [ + ( + Arc::clone(&default_dialect), + ScalarValue::TimestampSecond(Some(1757934000), Some("+00:00".into())), + "CAST('2025-09-15 11:00:00 +00:00' AS TIMESTAMP)", + ), + ( + Arc::clone(&default_dialect), + ScalarValue::TimestampMillisecond( + Some(1757934000123), + Some("+01:00".into()), + ), + "CAST('2025-09-15 12:00:00.123 +01:00' AS TIMESTAMP)", + ), + ( + Arc::clone(&default_dialect), + ScalarValue::TimestampMicrosecond( + Some(1757934000123456), + Some("-01:00".into()), + ), + "CAST('2025-09-15 10:00:00.123456 -01:00' AS TIMESTAMP)", + ), + ( + Arc::clone(&default_dialect), + ScalarValue::TimestampNanosecond( + Some(1757934000123456789), + Some("+00:00".into()), + ), + "CAST('2025-09-15 11:00:00.123456789 +00:00' AS TIMESTAMP)", + ), + ( + Arc::clone(&duckdb_dialect), + ScalarValue::TimestampSecond(Some(1757934000), Some("+00:00".into())), + "CAST('2025-09-15 11:00:00+00:00' AS TIMESTAMP)", + ), + ( + Arc::clone(&duckdb_dialect), + ScalarValue::TimestampMillisecond( + Some(1757934000123), + Some("+01:00".into()), + ), + "CAST('2025-09-15 12:00:00.123+01:00' AS TIMESTAMP)", + ), + ( + Arc::clone(&duckdb_dialect), + ScalarValue::TimestampMicrosecond( + Some(1757934000123456), + Some("-01:00".into()), + ), + "CAST('2025-09-15 10:00:00.123456-01:00' AS TIMESTAMP)", + ), + ( + Arc::clone(&duckdb_dialect), + ScalarValue::TimestampNanosecond( + Some(1757934000123456789), + Some("+00:00".into()), + ), + "CAST('2025-09-15 11:00:00.123456789+00:00' AS TIMESTAMP)", + ), + ] { + let unparser = Unparser::new(dialect.as_ref()); + + let expr = Expr::Literal(scalar, None); + + let actual = format!("{}", unparser.expr_to_sql(&expr)?); + assert_eq!(actual, expected); + } + Ok(()) + } } diff --git a/datafusion/sql/src/unparser/extension_unparser.rs b/datafusion/sql/src/unparser/extension_unparser.rs index f7deabe7c9021..b778130ca5a27 100644 --- a/datafusion/sql/src/unparser/extension_unparser.rs +++ b/datafusion/sql/src/unparser/extension_unparser.rs @@ -64,6 +64,7 @@ pub enum UnparseWithinStatementResult { } /// The result of unparsing a custom logical node to a statement. +#[allow(clippy::large_enum_variant)] pub enum UnparseToStatementResult { /// If the custom logical node was successfully unparsed to a statement. Modified(Statement), diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs index f90efd103b0f5..05b472dc92a93 100644 --- a/datafusion/sql/src/unparser/mod.rs +++ b/datafusion/sql/src/unparser/mod.rs @@ -118,9 +118,9 @@ impl<'a> Unparser<'a> { /// The child unparsers are called iteratively. /// There are two methods in [`Unparser`] will be called: /// - `extension_to_statement`: This method is called when the custom logical node is a custom statement. - /// If multiple child unparsers return a non-None value, the last unparsing result will be returned. + /// If multiple child unparsers return a non-None value, the last unparsing result will be returned. /// - `extension_to_sql`: This method is called when the custom logical node is part of a statement. - /// If multiple child unparsers are registered for the same custom logical node, all of them will be called in order. + /// If multiple child unparsers are registered for the same custom logical node, all of them will be called in order. pub fn with_extension_unparsers( mut self, extension_unparsers: Vec>, diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index a6d89638ff41d..b6c65614995a9 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -50,7 +50,7 @@ use datafusion_expr::{ UserDefinedLogicalNode, }; use sqlparser::ast::{self, Ident, OrderByKind, SetExpr, TableAliasColumnDef}; -use std::sync::Arc; +use std::{sync::Arc, vec}; /// Convert a DataFusion [`LogicalPlan`] to [`ast::Statement`] /// @@ -309,12 +309,13 @@ impl Unparser<'_> { plan: &LogicalPlan, relation: &mut RelationBuilder, lateral: bool, + columns: Vec, ) -> Result<()> { - if self.dialect.requires_derived_table_alias() { + if self.dialect.requires_derived_table_alias() || !columns.is_empty() { self.derive( plan, relation, - Some(self.new_table_alias(alias.to_string(), vec![])), + Some(self.new_table_alias(alias.to_string(), columns)), lateral, ) } else { @@ -392,6 +393,18 @@ impl Unparser<'_> { } } + // If it's a unnest projection, we should provide the table column alias + // to provide a column name for the unnest relation. + let columns = if unnest_input_type.is_some() { + p.expr + .iter() + .map(|e| { + self.new_ident_quoted_if_needs(e.schema_name().to_string()) + }) + .collect() + } else { + vec![] + }; // Projection can be top-level plan for derived table if select.already_projected() { return self.derive_with_dialect_alias( @@ -401,6 +414,7 @@ impl Unparser<'_> { unnest_input_type .filter(|t| matches!(t, UnnestInputType::OuterReference)) .is_some(), + columns, ); } self.reconstruct_select_statement(plan, p, select)?; @@ -434,6 +448,7 @@ impl Unparser<'_> { plan, relation, false, + vec![], ); } if let Some(fetch) = &limit.fetch { @@ -451,6 +466,7 @@ impl Unparser<'_> { "Offset operator only valid in a statement context." ); }; + query.offset(Some(ast::Offset { rows: ast::OffsetRows::None, value: self.expr_to_sql(skip)?, @@ -472,6 +488,7 @@ impl Unparser<'_> { plan, relation, false, + vec![], ); } let Some(query_ref) = query else { @@ -493,7 +510,7 @@ impl Unparser<'_> { .expr .iter() .map(|sort_expr| { - unproject_sort_expr(sort_expr, agg, sort.input.as_ref()) + unproject_sort_expr(sort_expr.clone(), agg, sort.input.as_ref()) }) .collect::>>()?; @@ -543,8 +560,26 @@ impl Unparser<'_> { plan, relation, false, + vec![], ); } + + // If this distinct is the parent of a Union and we're in a query context, + // then we need to unparse as a `UNION` rather than a `UNION ALL`. + if let Distinct::All(input) = distinct { + if matches!(input.as_ref(), LogicalPlan::Union(_)) { + if let Some(query_mut) = query.as_mut() { + query_mut.distinct_union(); + return self.select_to_sql_recursively( + input.as_ref(), + query, + select, + relation, + ); + } + } + } + let (select_distinct, input) = match distinct { Distinct::All(input) => (ast::Distinct::Distinct, input.as_ref()), Distinct::On(on) => { @@ -582,6 +617,10 @@ impl Unparser<'_> { } _ => (&join.left, &join.right), }; + // If there's an outer projection plan, it will already set up the projection. + // In that case, we don't need to worry about setting up the projection here. + // The outer projection plan will handle projecting the correct columns. + let already_projected = select.already_projected(); let left_plan = match try_transform_to_simple_table_scan_with_filters(left_plan)? { @@ -599,6 +638,13 @@ impl Unparser<'_> { relation, )?; + let left_projection: Option> = if !already_projected + { + Some(select.pop_projections()) + } else { + None + }; + let right_plan = match try_transform_to_simple_table_scan_with_filters(right_plan)? { Some((plan, filters)) => { @@ -650,19 +696,20 @@ impl Unparser<'_> { join_filters.as_ref(), )?; - self.select_to_sql_recursively( - right_plan.as_ref(), - query, - select, - &mut right_relation, - )?; + let right_projection: Option> = if !already_projected + { + Some(select.pop_projections()) + } else { + None + }; match join.join_type { JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark | JoinType::RightSemi - | JoinType::RightAnti => { + | JoinType::RightAnti + | JoinType::RightMark => { let mut query_builder = QueryBuilder::default(); let mut from = TableWithJoinsBuilder::default(); let mut exists_select: SelectBuilder = SelectBuilder::default(); @@ -686,7 +733,8 @@ impl Unparser<'_> { let negated = match join.join_type { JoinType::LeftSemi | JoinType::RightSemi - | JoinType::LeftMark => false, + | JoinType::LeftMark + | JoinType::RightMark => false, JoinType::LeftAnti | JoinType::RightAnti => true, _ => unreachable!(), }; @@ -694,13 +742,28 @@ impl Unparser<'_> { subquery: Box::new(query_builder.build()?), negated, }; - if join.join_type == JoinType::LeftMark { - let (table_ref, _) = right_plan.schema().qualified_field(0); - let column = self - .col_to_sql(&Column::new(table_ref.cloned(), "mark"))?; - select.replace_mark(&column, &exists_expr); - } else { - select.selection(Some(exists_expr)); + + match join.join_type { + JoinType::LeftMark | JoinType::RightMark => { + let source_schema = + if join.join_type == JoinType::LeftMark { + right_plan.schema() + } else { + left_plan.schema() + }; + let (table_ref, _) = source_schema.qualified_field(0); + let column = self.col_to_sql(&Column::new( + table_ref.cloned(), + "mark", + ))?; + select.replace_mark(&column, &exists_expr); + } + _ => { + select.selection(Some(exists_expr)); + } + } + if let Some(projection) = left_projection { + select.projection(projection); } } JoinType::Inner @@ -719,6 +782,21 @@ impl Unparser<'_> { let mut from = select.pop_from().unwrap(); from.push_join(ast_join); select.push_from(from); + if !already_projected { + let Some(left_projection) = left_projection else { + return internal_err!("Left projection is missing"); + }; + + let Some(right_projection) = right_projection else { + return internal_err!("Right projection is missing"); + }; + + let projection = left_projection + .into_iter() + .chain(right_projection) + .collect(); + select.projection(projection); + } } }; @@ -780,6 +858,7 @@ impl Unparser<'_> { plan, relation, false, + vec![], ); } @@ -793,6 +872,15 @@ impl Unparser<'_> { return internal_err!("UNION operator requires at least 2 inputs"); } + let set_quantifier = + if query.as_ref().is_some_and(|q| q.is_distinct_union()) { + // Setting the SetQuantifier to None will unparse as a `UNION` + // rather than a `UNION ALL`. + ast::SetQuantifier::None + } else { + ast::SetQuantifier::All + }; + // Build the union expression tree bottom-up by reversing the order // note that we are also swapping left and right inputs because of the rev let union_expr = input_exprs @@ -800,7 +888,7 @@ impl Unparser<'_> { .rev() .reduce(|a, b| SetExpr::SetOperation { op: ast::SetOperator::Union, - set_quantifier: ast::SetQuantifier::All, + set_quantifier, left: Box::new(b), right: Box::new(a), }) @@ -888,6 +976,7 @@ impl Unparser<'_> { subquery.subquery.as_ref(), relation, true, + vec![], ) } } @@ -900,9 +989,9 @@ impl Unparser<'_> { /// Try to find the placeholder column name generated by `RecursiveUnnestRewriter`. /// /// - If the column is a placeholder column match the pattern `Expr::Alias(Expr::Column("__unnest_placeholder(...)"))`, - /// it means it is a scalar column, return [UnnestInputType::Scalar]. + /// it means it is a scalar column, return [UnnestInputType::Scalar]. /// - If the column is a placeholder column match the pattern `Expr::Alias(Expr::Column("__unnest_placeholder(outer_ref(...)))")`, - /// it means it is an outer reference column, return [UnnestInputType::OuterReference]. + /// it means it is an outer reference column, return [UnnestInputType::OuterReference]. /// - If the column is not a placeholder column, return [None]. /// /// `outer_ref` is the display result of [Expr::OuterReferenceColumn] @@ -910,8 +999,7 @@ impl Unparser<'_> { if let Expr::Alias(Alias { expr, .. }) = expr { if let Expr::Column(Column { name, .. }) = expr.as_ref() { if let Some(prefix) = name.strip_prefix(UNNEST_PLACEHOLDER) { - if prefix.starts_with(&format!("({}(", OUTER_REFERENCE_COLUMN_PREFIX)) - { + if prefix.starts_with(&format!("({OUTER_REFERENCE_COLUMN_PREFIX}(")) { return Some(UnnestInputType::OuterReference); } return Some(UnnestInputType::Scalar); @@ -998,6 +1086,7 @@ impl Unparser<'_> { if project_vec.is_empty() { builder = builder.project(vec![Expr::Literal( ScalarValue::Int64(Some(1)), + None, )])?; } else { let project_columns = project_vec @@ -1118,9 +1207,18 @@ impl Unparser<'_> { Expr::Alias(Alias { expr, name, .. }) => { let inner = self.expr_to_sql(expr)?; + // Determine the alias name to use + let col_name = if let Some(rewritten_name) = + self.dialect.col_alias_overrides(name)? + { + rewritten_name.to_string() + } else { + name.to_string() + }; + Ok(ast::SelectItem::ExprWithAlias { expr: inner, - alias: self.new_ident_quoted_if_needs(name.to_string()), + alias: self.new_ident_quoted_if_needs(col_name), }) } _ => { @@ -1153,7 +1251,7 @@ impl Unparser<'_> { ast::JoinConstraint::None => { // Inner joins with no conditions or filters are not valid SQL in most systems, // return a CROSS JOIN instead - ast::JoinOperator::CrossJoin + ast::JoinOperator::CrossJoin(constraint) } }, JoinType::Left => ast::JoinOperator::LeftOuter(constraint), @@ -1163,7 +1261,9 @@ impl Unparser<'_> { JoinType::LeftSemi => ast::JoinOperator::LeftSemi(constraint), JoinType::RightAnti => ast::JoinOperator::RightAnti(constraint), JoinType::RightSemi => ast::JoinOperator::RightSemi(constraint), - JoinType::LeftMark => unimplemented!("Unparsing of Left Mark join type"), + JoinType::LeftMark | JoinType::RightMark => { + unimplemented!("Unparsing of Mark join type") + } }) } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 75038ccc43145..8b3791017a8af 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -203,7 +203,7 @@ pub(crate) fn unproject_agg_exprs( windows.and_then(|w| find_window_expr(w, &c.name).cloned()) { // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected - return Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?)); + Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?)) } else { internal_err!( "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name @@ -270,51 +270,58 @@ fn find_window_expr<'a>( .find(|expr| expr.schema_name().to_string() == column_name) } -/// Transforms a Column expression into the actual expression from aggregation or projection if found. +/// Transforms all Column expressions in a sort expression into the actual expression from aggregation or projection if found. /// This is required because if an ORDER BY expression is present in an Aggregate or Select, it is replaced /// with a Column expression (e.g., "sum(catalog_returns.cr_net_loss)"). We need to transform it back to /// the actual expression, such as sum("catalog_returns"."cr_net_loss"). pub(crate) fn unproject_sort_expr( - sort_expr: &SortExpr, + mut sort_expr: SortExpr, agg: Option<&Aggregate>, input: &LogicalPlan, ) -> Result { - let mut sort_expr = sort_expr.clone(); - - // Remove alias if present, because ORDER BY cannot use aliases - if let Expr::Alias(alias) = &sort_expr.expr { - sort_expr.expr = *alias.expr.clone(); - } - - let Expr::Column(ref col_ref) = sort_expr.expr else { - return Ok(sort_expr); - }; + sort_expr.expr = sort_expr + .expr + .transform(|sub_expr| { + match sub_expr { + // Remove alias if present, because ORDER BY cannot use aliases + Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)), + Expr::Column(col) => { + if col.relation.is_some() { + return Ok(Transformed::no(Expr::Column(col))); + } - if col_ref.relation.is_some() { - return Ok(sort_expr); - }; + // In case of aggregation there could be columns containing aggregation functions we need to unproject + if let Some(agg) = agg { + if agg.schema.is_column_from_schema(&col) { + return Ok(Transformed::yes(unproject_agg_exprs( + Expr::Column(col), + agg, + None, + )?)); + } + } - // In case of aggregation there could be columns containing aggregation functions we need to unproject - if let Some(agg) = agg { - if agg.schema.is_column_from_schema(col_ref) { - let new_expr = unproject_agg_exprs(sort_expr.expr, agg, None)?; - sort_expr.expr = new_expr; - return Ok(sort_expr); - } - } + // If SELECT and ORDER BY contain the same expression with a scalar function, the ORDER BY expression will + // be replaced by a Column expression (e.g., "substr(customer.c_last_name, Int64(0), Int64(5))"), and we need + // to transform it back to the actual expression. + if let LogicalPlan::Projection(Projection { expr, schema, .. }) = + input + { + if let Ok(idx) = schema.index_of_column(&col) { + if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) { + return Ok(Transformed::yes(Expr::ScalarFunction( + scalar_fn.clone(), + ))); + } + } + } - // If SELECT and ORDER BY contain the same expression with a scalar function, the ORDER BY expression will - // be replaced by a Column expression (e.g., "substr(customer.c_last_name, Int64(0), Int64(5))"), and we need - // to transform it back to the actual expression. - if let LogicalPlan::Projection(Projection { expr, schema, .. }) = input { - if let Ok(idx) = schema.index_of_column(col_ref) { - if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) { - sort_expr.expr = Expr::ScalarFunction(scalar_fn.clone()); + Ok(Transformed::no(Expr::Column(col))) + } + _ => Ok(Transformed::no(sub_expr)), } - } - return Ok(sort_expr); - } - + }) + .map(|e| e.data)?; Ok(sort_expr) } @@ -385,7 +392,7 @@ pub(crate) fn try_transform_to_simple_table_scan_with_filters( let mut builder = LogicalPlanBuilder::scan( table_scan.table_name.clone(), Arc::clone(&table_scan.source), - None, + table_scan.projection.clone(), )?; if let Some(alias) = table_alias.take() { @@ -415,7 +422,7 @@ pub(crate) fn date_part_to_sql( match (style, date_part_args.len()) { (DateFieldExtractStyle::Extract, 2) => { let date_expr = unparser.expr_to_sql(&date_part_args[1])?; - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] { let field = match field.to_lowercase().as_str() { "year" => ast::DateTimeField::Year, "month" => ast::DateTimeField::Month, @@ -436,7 +443,7 @@ pub(crate) fn date_part_to_sql( (DateFieldExtractStyle::Strftime, 2) => { let column = unparser.expr_to_sql(&date_part_args[1])?; - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] { let field = match field.to_lowercase().as_str() { "year" => "%Y", "month" => "%m", @@ -500,3 +507,72 @@ pub(crate) fn character_length_to_sql( character_length_args, )?)) } + +/// SQLite does not support timestamp/date scalars like `to_timestamp`, `from_unixtime`, `date_trunc`, etc. +/// This remaps `from_unixtime` to `datetime(expr, 'unixepoch')`, expecting the input to be in seconds. +/// It supports no other arguments, so if any are supplied it will return an error. +/// +/// # Errors +/// +/// - If the number of arguments is not 1 - the column or expression to convert. +/// - If the scalar function cannot be converted to SQL. +pub(crate) fn sqlite_from_unixtime_to_sql( + unparser: &Unparser, + from_unixtime_args: &[Expr], +) -> Result> { + if from_unixtime_args.len() != 1 { + return internal_err!( + "from_unixtime for SQLite expects 1 argument, found {}", + from_unixtime_args.len() + ); + } + + Ok(Some(unparser.scalar_function_to_sql( + "datetime", + &[ + from_unixtime_args[0].clone(), + Expr::Literal(ScalarValue::Utf8(Some("unixepoch".to_string())), None), + ], + )?)) +} + +/// SQLite does not support timestamp/date scalars like `to_timestamp`, `from_unixtime`, `date_trunc`, etc. +/// This uses the `strftime` function to format the timestamp as a string depending on the truncation unit. +/// +/// # Errors +/// +/// - If the number of arguments is not 2 - truncation unit and the column or expression to convert. +/// - If the scalar function cannot be converted to SQL. +pub(crate) fn sqlite_date_trunc_to_sql( + unparser: &Unparser, + date_trunc_args: &[Expr], +) -> Result> { + if date_trunc_args.len() != 2 { + return internal_err!( + "date_trunc for SQLite expects 2 arguments, found {}", + date_trunc_args.len() + ); + } + + if let Expr::Literal(ScalarValue::Utf8(Some(unit)), _) = &date_trunc_args[0] { + let format = match unit.to_lowercase().as_str() { + "year" => "%Y", + "month" => "%Y-%m", + "day" => "%Y-%m-%d", + "hour" => "%Y-%m-%d %H", + "minute" => "%Y-%m-%d %H:%M", + "second" => "%Y-%m-%d %H:%M:%S", + _ => return Ok(None), + }; + + return Ok(Some(unparser.scalar_function_to_sql( + "strftime", + &[ + Expr::Literal(ScalarValue::Utf8(Some(format.to_string())), None), + date_trunc_args[1].clone(), + ], + )?)); + } + + Ok(None) +} diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index bc2a94cd44ff7..3c86d2d04905f 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -26,8 +26,8 @@ use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{ - exec_err, internal_err, plan_err, Column, DFSchemaRef, DataFusionError, Diagnostic, - HashMap, Result, ScalarValue, + exec_datafusion_err, exec_err, internal_err, plan_err, Column, DFSchemaRef, + Diagnostic, HashMap, Result, ScalarValue, }; use datafusion_expr::builder::get_struct_unnested_columns; use datafusion_expr::expr::{ @@ -92,21 +92,30 @@ pub(crate) fn rebase_expr( .data() } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum CheckColumnsMustReferenceAggregatePurpose { + Projection, + Having, + Qualify, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum CheckColumnsSatisfyExprsPurpose { - ProjectionMustReferenceAggregate, - HavingMustReferenceAggregate, + Aggregate(CheckColumnsMustReferenceAggregatePurpose), } impl CheckColumnsSatisfyExprsPurpose { fn message_prefix(&self) -> &'static str { match self { - CheckColumnsSatisfyExprsPurpose::ProjectionMustReferenceAggregate => { + Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Projection) => { "Column in SELECT must be in GROUP BY or an aggregate function" } - CheckColumnsSatisfyExprsPurpose::HavingMustReferenceAggregate => { + Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Having) => { "Column in HAVING must be in GROUP BY or an aggregate function" } + Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Qualify) => { + "Column in QUALIFY must be in GROUP BY or an aggregate function" + } } } @@ -158,20 +167,19 @@ fn check_column_satisfies_expr( purpose: CheckColumnsSatisfyExprsPurpose, ) -> Result<()> { if !columns.contains(expr) { + let diagnostic = Diagnostic::new_error( + purpose.diagnostic_message(expr), + expr.spans().and_then(|spans| spans.first()), + ) + .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregate function like ANY_VALUE({expr})"), None); + return plan_err!( "{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement", purpose.message_prefix(), expr, - expr_vec_fmt!(columns) - ) - .map_err(|err| { - let diagnostic = Diagnostic::new_error( - purpose.diagnostic_message(expr), - expr.spans().and_then(|spans| spans.first()), - ) - .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregare function like ANY_VALUE({expr})"), None); - err.with_diagnostic(diagnostic) - }); + expr_vec_fmt!(columns); + diagnostic=diagnostic + ); } Ok(()) } @@ -199,7 +207,7 @@ pub(crate) fn resolve_positions_to_exprs( match expr { // sql_expr_to_logical_expr maps number to i64 // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887 - Expr::Literal(ScalarValue::Int64(Some(position))) + Expr::Literal(ScalarValue::Int64(Some(position)), _) if position > 0_i64 && position <= select_exprs.len() as i64 => { let index = (position - 1) as usize; @@ -209,7 +217,7 @@ pub(crate) fn resolve_positions_to_exprs( _ => select_expr.clone(), }) } - Expr::Literal(ScalarValue::Int64(Some(position))) => plan_err!( + Expr::Literal(ScalarValue::Int64(Some(position)), _) => plan_err!( "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}", position, select_exprs.len() ), @@ -242,15 +250,21 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr let all_partition_keys = window_exprs .iter() .map(|expr| match expr { - Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { partition_by, .. }, - .. - }) => Ok(partition_by), - Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { - Expr::WindowFunction(WindowFunction { + Expr::WindowFunction(window_fun) => { + let WindowFunction { params: WindowFunctionParams { partition_by, .. }, .. - }) => Ok(partition_by), + } = window_fun.as_ref(); + Ok(partition_by) + } + Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + params: WindowFunctionParams { partition_by, .. }, + .. + } = window_fun.as_ref(); + Ok(partition_by) + } expr => exec_err!("Impossibly got non-window expr {expr:?}"), }, expr => exec_err!("Impossibly got non-window expr {expr:?}"), @@ -259,9 +273,7 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr let result = all_partition_keys .iter() .min_by_key(|s| s.len()) - .ok_or_else(|| { - DataFusionError::Execution("No window expressions found".to_owned()) - })?; + .ok_or_else(|| exec_datafusion_err!("No window expressions found"))?; Ok(result) } @@ -399,9 +411,9 @@ impl RecursiveUnnestRewriter<'_> { // Full context, we are trying to plan the execution as InnerProjection->Unnest->OuterProjection // inside unnest execution, each column inside the inner projection // will be transformed into new columns. Thus we need to keep track of these placeholding column names - let placeholder_name = format!("{UNNEST_PLACEHOLDER}({})", inner_expr_name); + let placeholder_name = format!("{UNNEST_PLACEHOLDER}({inner_expr_name})"); let post_unnest_name = - format!("{UNNEST_PLACEHOLDER}({},depth={})", inner_expr_name, level); + format!("{UNNEST_PLACEHOLDER}({inner_expr_name},depth={level})"); // This is due to the fact that unnest transformation should keep the original // column name as is, to comply with group by and order by let placeholder_column = Column::from_name(placeholder_name.clone()); @@ -543,7 +555,7 @@ impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> { let most_inner = unnest_stack.first().unwrap(); let inner_expr = most_inner.expr.as_ref(); // unnest(unnest(struct_arr_col)) is not allow to be done recursively - // it needs to be splitted into multiple unnest logical plan + // it needs to be split into multiple unnest logical plan // unnest(struct_arr) // unnest(struct_arr_col) as struct_arr // instead of unnest(struct_arr_col, depth = 2) @@ -681,13 +693,13 @@ mod tests { "{}=>[{}]", i.0, vec.iter() - .map(|i| format!("{}", i)) + .map(|i| format!("{i}")) .collect::>() .join(", ") ), }) .collect(); - let l_formatted: Vec = l.iter().map(|i| i.to_string()).collect(); + let l_formatted: Vec = l.iter().map(|i| (*i).to_string()).collect(); assert_eq!(l_formatted, r_formatted); } diff --git a/datafusion/sql/tests/cases/diagnostic.rs b/datafusion/sql/tests/cases/diagnostic.rs index ebb21e9cdef53..8648dffb50046 100644 --- a/datafusion/sql/tests/cases/diagnostic.rs +++ b/datafusion/sql/tests/cases/diagnostic.rs @@ -16,19 +16,21 @@ // under the License. use datafusion_functions::string; +use insta::assert_snapshot; use std::{collections::HashMap, sync::Arc}; use datafusion_common::{Diagnostic, Location, Result, Span}; -use datafusion_sql::planner::{ParserOptions, SqlToRel}; +use datafusion_sql::{ + parser::{DFParser, DFParserBuilder}, + planner::{ParserOptions, SqlToRel}, +}; use regex::Regex; -use sqlparser::{dialect::GenericDialect, parser::Parser}; use crate::{MockContextProvider, MockSessionState}; fn do_query(sql: &'static str) -> Diagnostic { - let dialect = GenericDialect {}; - let statement = Parser::new(&dialect) - .try_with_sql(sql) + let statement = DFParserBuilder::new(sql) + .build() .expect("unable to create parser") .parse_statement() .expect("unable to parse query"); @@ -40,7 +42,7 @@ fn do_query(sql: &'static str) -> Diagnostic { .with_scalar_function(Arc::new(string::concat().as_ref().clone())); let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new_with_options(&context, options); - match sql_to_rel.sql_statement_to_plan(statement) { + match sql_to_rel.statement_to_plan(statement) { Ok(_) => panic!("expected error"), Err(err) => match err.diagnostic() { Some(diag) => diag.clone(), @@ -136,7 +138,7 @@ fn test_table_not_found() -> Result<()> { let query = "SELECT * FROM /*a*/personx/*a*/"; let spans = get_spans(query); let diag = do_query(query); - assert_eq!(diag.message, "table 'personx' not found"); + assert_snapshot!(diag.message, @"table 'personx' not found"); assert_eq!(diag.span, Some(spans["a"])); Ok(()) } @@ -146,7 +148,7 @@ fn test_unqualified_column_not_found() -> Result<()> { let query = "SELECT /*a*/first_namex/*a*/ FROM person"; let spans = get_spans(query); let diag = do_query(query); - assert_eq!(diag.message, "column 'first_namex' not found"); + assert_snapshot!(diag.message, @"column 'first_namex' not found"); assert_eq!(diag.span, Some(spans["a"])); Ok(()) } @@ -156,7 +158,7 @@ fn test_qualified_column_not_found() -> Result<()> { let query = "SELECT /*a*/person.first_namex/*a*/ FROM person"; let spans = get_spans(query); let diag = do_query(query); - assert_eq!(diag.message, "column 'first_namex' not found in 'person'"); + assert_snapshot!(diag.message, @"column 'first_namex' not found in 'person'"); assert_eq!(diag.span, Some(spans["a"])); Ok(()) } @@ -166,14 +168,11 @@ fn test_union_wrong_number_of_columns() -> Result<()> { let query = "/*whole+left*/SELECT first_name FROM person/*left*/ UNION ALL /*right*/SELECT first_name, last_name FROM person/*right+whole*/"; let spans = get_spans(query); let diag = do_query(query); - assert_eq!( - diag.message, - "UNION queries have different number of columns" - ); + assert_snapshot!(diag.message, @"UNION queries have different number of columns"); assert_eq!(diag.span, Some(spans["whole"])); - assert_eq!(diag.notes[0].message, "this side has 1 fields"); + assert_snapshot!(diag.notes[0].message, @"this side has 1 fields"); assert_eq!(diag.notes[0].span, Some(spans["left"])); - assert_eq!(diag.notes[1].message, "this side has 2 fields"); + assert_snapshot!(diag.notes[1].message, @"this side has 2 fields"); assert_eq!(diag.notes[1].span, Some(spans["right"])); Ok(()) } @@ -183,15 +182,9 @@ fn test_missing_non_aggregate_in_group_by() -> Result<()> { let query = "SELECT id, /*a*/first_name/*a*/ FROM person GROUP BY id"; let spans = get_spans(query); let diag = do_query(query); - assert_eq!( - diag.message, - "'person.first_name' must appear in GROUP BY clause because it's not an aggregate expression" - ); + assert_snapshot!(diag.message, @"'person.first_name' must appear in GROUP BY clause because it's not an aggregate expression"); assert_eq!(diag.span, Some(spans["a"])); - assert_eq!( - diag.helps[0].message, - "Either add 'person.first_name' to GROUP BY clause, or use an aggregare function like ANY_VALUE(person.first_name)" - ); + assert_snapshot!(diag.helps[0].message, @"Either add 'person.first_name' to GROUP BY clause, or use an aggregate function like ANY_VALUE(person.first_name)"); Ok(()) } @@ -200,10 +193,10 @@ fn test_ambiguous_reference() -> Result<()> { let query = "SELECT /*a*/first_name/*a*/ FROM person a, person b"; let spans = get_spans(query); let diag = do_query(query); - assert_eq!(diag.message, "column 'first_name' is ambiguous"); + assert_snapshot!(diag.message, @"column 'first_name' is ambiguous"); assert_eq!(diag.span, Some(spans["a"])); - assert_eq!(diag.notes[0].message, "possible column a.first_name"); - assert_eq!(diag.notes[1].message, "possible column b.first_name"); + assert_snapshot!(diag.notes[0].message, @"possible column a.first_name"); + assert_snapshot!(diag.notes[1].message, @"possible column b.first_name"); Ok(()) } @@ -213,11 +206,11 @@ fn test_incompatible_types_binary_arithmetic() -> Result<()> { "SELECT /*whole+left*/id/*left*/ + /*right*/first_name/*right+whole*/ FROM person"; let spans = get_spans(query); let diag = do_query(query); - assert_eq!(diag.message, "expressions have incompatible types"); + assert_snapshot!(diag.message, @"expressions have incompatible types"); assert_eq!(diag.span, Some(spans["whole"])); - assert_eq!(diag.notes[0].message, "has type UInt32"); + assert_snapshot!(diag.notes[0].message, @"has type UInt32"); assert_eq!(diag.notes[0].span, Some(spans["left"])); - assert_eq!(diag.notes[1].message, "has type Utf8"); + assert_snapshot!(diag.notes[1].message, @"has type Utf8"); assert_eq!(diag.notes[1].span, Some(spans["right"])); Ok(()) } @@ -227,7 +220,7 @@ fn test_field_not_found_suggestion() -> Result<()> { let query = "SELECT /*whole*/first_na/*whole*/ FROM person"; let spans = get_spans(query); let diag = do_query(query); - assert_eq!(diag.message, "column 'first_na' not found"); + assert_snapshot!(diag.message, @"column 'first_na' not found"); assert_eq!(diag.span, Some(spans["whole"])); assert_eq!(diag.notes.len(), 1); @@ -243,7 +236,7 @@ fn test_field_not_found_suggestion() -> Result<()> { }) .collect(); suggested_fields.sort(); - assert_eq!(suggested_fields[0], "person.first_name"); + assert_snapshot!(suggested_fields[0], @"person.first_name"); Ok(()) } @@ -253,7 +246,7 @@ fn test_ambiguous_column_suggestion() -> Result<()> { let spans = get_spans(query); let diag = do_query(query); - assert_eq!(diag.message, "column 'id' is ambiguous"); + assert_snapshot!(diag.message, @"column 'id' is ambiguous"); assert_eq!(diag.span, Some(spans["whole"])); assert_eq!(diag.notes.len(), 2); @@ -281,8 +274,8 @@ fn test_invalid_function() -> Result<()> { let query = "SELECT /*whole*/concat_not_exist/*whole*/()"; let spans = get_spans(query); let diag = do_query(query); - assert_eq!(diag.message, "Invalid function 'concat_not_exist'"); - assert_eq!(diag.notes[0].message, "Possible function 'concat'"); + assert_snapshot!(diag.message, @"Invalid function 'concat_not_exist'"); + assert_snapshot!(diag.notes[0].message, @"Possible function 'concat'"); assert_eq!(diag.span, Some(spans["whole"])); Ok(()) } @@ -292,10 +285,7 @@ fn test_scalar_subquery_multiple_columns() -> Result<(), Box Result<(), Box> let spans = get_spans(query); let diag = do_query(query); - assert_eq!( - diag.message, - "Too many columns! The subquery should only return one column" - ); + assert_snapshot!(diag.message, @"Too many columns! The subquery should only return one column"); let expected_span = Some(Span { start: spans["id"].start, @@ -360,16 +347,10 @@ fn test_unary_op_plus_with_column() -> Result<()> { let query = "SELECT +/*whole*/first_name/*whole*/ FROM person"; let spans = get_spans(query); let diag = do_query(query); - assert_eq!(diag.message, "+ cannot be used with Utf8"); + assert_snapshot!(diag.message, @"+ cannot be used with Utf8"); assert_eq!(diag.span, Some(spans["whole"])); - assert_eq!( - diag.notes[0].message, - "+ can only be used with numbers, intervals, and timestamps" - ); - assert_eq!( - diag.helps[0].message, - "perhaps you need to cast person.first_name" - ); + assert_snapshot!(diag.notes[0].message, @"+ can only be used with numbers, intervals, and timestamps"); + assert_snapshot!(diag.helps[0].message, @"perhaps you need to cast person.first_name"); Ok(()) } @@ -379,16 +360,32 @@ fn test_unary_op_plus_with_non_column() -> Result<()> { let query = "SELECT +'a'"; let diag = do_query(query); assert_eq!(diag.message, "+ cannot be used with Utf8"); - assert_eq!( - diag.notes[0].message, - "+ can only be used with numbers, intervals, and timestamps" - ); + assert_snapshot!(diag.notes[0].message, @"+ can only be used with numbers, intervals, and timestamps"); assert_eq!(diag.notes[0].span, None); - assert_eq!( - diag.helps[0].message, - "perhaps you need to cast Utf8(\"a\")" - ); + assert_snapshot!(diag.helps[0].message, @r#"perhaps you need to cast Utf8("a")"#); assert_eq!(diag.helps[0].span, None); assert_eq!(diag.span, None); Ok(()) } + +#[test] +fn test_syntax_error() -> Result<()> { + // create a table with a column of type varchar + let query = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (c1, p1 /*int*/int/*int*/) LOCATION 'foo.csv'"; + let spans = get_spans(query); + match DFParser::parse_sql(query) { + Ok(_) => panic!("expected error"), + Err(err) => match err.diagnostic() { + Some(diag) => { + let diag = diag.clone(); + assert_snapshot!(diag.message, @"Expected: ',' or ')' after partition definition, found: int at Line: 1, Column: 77"); + println!("{spans:?}"); + assert_eq!(diag.span, Some(spans["int"])); + Ok(()) + } + None => { + panic!("expected diagnostic") + } + }, + } +} diff --git a/datafusion/sql/tests/cases/mod.rs b/datafusion/sql/tests/cases/mod.rs index b3eedcdc41e35..426d188f633c0 100644 --- a/datafusion/sql/tests/cases/mod.rs +++ b/datafusion/sql/tests/cases/mod.rs @@ -17,4 +17,5 @@ mod collection; mod diagnostic; +mod params; mod plan_to_sql; diff --git a/datafusion/sql/tests/cases/params.rs b/datafusion/sql/tests/cases/params.rs new file mode 100644 index 0000000000000..343a90af3efb1 --- /dev/null +++ b/datafusion/sql/tests/cases/params.rs @@ -0,0 +1,887 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan; +use arrow::datatypes::DataType; +use datafusion_common::{assert_contains, ParamValues, ScalarValue}; +use datafusion_expr::{LogicalPlan, Prepare, Statement}; +use insta::assert_snapshot; +use itertools::Itertools as _; +use std::collections::HashMap; + +pub struct ParameterTest<'a> { + pub sql: &'a str, + pub expected_types: Vec<(&'a str, Option)>, + pub param_values: Vec, +} + +impl ParameterTest<'_> { + pub fn run(&self) -> String { + let plan = logical_plan(self.sql).unwrap(); + + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types: HashMap> = self + .expected_types + .iter() + .map(|(k, v)| ((*k).to_string(), v.clone())) + .collect(); + + assert_eq!(actual_types, expected_types); + + let plan_with_params = plan + .clone() + .with_param_values(self.param_values.clone()) + .unwrap(); + + format!("** Initial Plan:\n{plan}\n** Final Plan:\n{plan_with_params}") + } +} + +fn generate_prepare_stmt_and_data_types(sql: &str) -> (LogicalPlan, String) { + let plan = logical_plan(sql).unwrap(); + let data_types = match &plan { + LogicalPlan::Statement(Statement::Prepare(Prepare { data_types, .. })) => { + data_types.iter().join(", ").to_string() + } + _ => panic!("Expected a Prepare statement"), + }; + (plan, data_types) +} + +#[test] +fn test_prepare_statement_to_plan_panic_param_format() { + // param is not number following the $ sign + // panic due to error returned from the parser + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo"; + + assert_snapshot!( + logical_plan(sql).unwrap_err().strip_backtrace(), + @r###" + Error during planning: Invalid placeholder, not a number: $foo + "### + ); +} + +#[test] +fn test_prepare_statement_to_plan_panic_param_zero() { + // param is zero following the $ sign + // panic due to error returned from the parser + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $0"; + + assert_snapshot!( + logical_plan(sql).unwrap_err().strip_backtrace(), + @r###" + Error during planning: Invalid placeholder, zero is not a valid index: $0 + "### + ); +} + +#[test] +fn test_prepare_statement_to_plan_panic_prepare_wrong_syntax() { + // param is not number following the $ sign + // panic due to error returned from the parser + let sql = "PREPARE AS SELECT id, age FROM person WHERE age = $foo"; + assert!(logical_plan(sql) + .unwrap_err() + .strip_backtrace() + .contains("Expected: AS, found: SELECT")) +} + +#[test] +fn test_prepare_statement_to_plan_panic_no_relation_and_constant_param() { + let sql = "PREPARE my_plan(INT) AS SELECT id + $1"; + + let plan = logical_plan(sql).unwrap_err().strip_backtrace(); + assert_snapshot!( + plan, + @r"Schema error: No field named id." + ); +} + +#[test] +fn test_prepare_statement_should_infer_types() { + // only provide 1 data type while using 2 params + let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1 + $2"; + let plan = logical_plan(sql).unwrap(); + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types = HashMap::from([ + ("$1".to_string(), Some(DataType::Int32)), + ("$2".to_string(), Some(DataType::Int64)), + ]); + assert_eq!(actual_types, expected_types); +} + +#[test] +fn test_non_prepare_statement_should_infer_types() { + // Non prepared statements (like SELECT) should also have their parameter types inferred + let sql = "SELECT 1 + $1"; + let plan = logical_plan(sql).unwrap(); + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types = HashMap::from([ + // constant 1 is inferred to be int64 + ("$1".to_string(), Some(DataType::Int64)), + ]); + assert_eq!(actual_types, expected_types); +} + +#[test] +#[should_panic( + expected = "Expected: [NOT] NULL | TRUE | FALSE | DISTINCT | [form] NORMALIZED FROM after IS, found: $1" +)] +fn test_prepare_statement_to_plan_panic_is_param() { + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age is $1"; + logical_plan(sql).unwrap(); +} + +#[test] +fn test_prepare_statement_to_plan_no_param() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: person.id, person.age + Filter: person.age = Int64(10) + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"Int32"#); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: person.id, person.age + Filter: person.age = Int64(10) + TableScan: person + " + ); + + ////////////////////////////////////////// + // no embedded parameter and no declare it + let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [] + Projection: person.id, person.age + Filter: person.age = Int64(10) + TableScan: person + "# + ); + assert_snapshot!(dt, @r#""#); + + /////////////////// + // replace params with values + let param_values: Vec = vec![]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: person.id, person.age + Filter: person.age = Int64(10) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_to_plan_one_param_no_value_panic() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; + let plan = logical_plan(sql).unwrap(); + // declare 1 param but provide 0 + let param_values: Vec = vec![]; + + assert_snapshot!( + plan.with_param_values(param_values) + .unwrap_err() + .strip_backtrace(), + @r###" + Error during planning: Expected 1 parameters, got 0 + "###); +} + +#[test] +fn test_prepare_statement_to_plan_one_param_one_value_different_type_panic() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; + let plan = logical_plan(sql).unwrap(); + // declare 1 param but provide 0 + let param_values = vec![ScalarValue::Float64(Some(20.0))]; + + assert_snapshot!( + plan.with_param_values(param_values) + .unwrap_err() + .strip_backtrace(), + @r###" + Error during planning: Expected parameter of type Int32, got Float64 at index 0 + "### + ); +} + +#[test] +fn test_prepare_statement_to_plan_no_param_on_value_panic() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; + let plan = logical_plan(sql).unwrap(); + // declare 1 param but provide 0 + let param_values = vec![ScalarValue::Int32(Some(10))]; + + assert_snapshot!( + plan.with_param_values(param_values) + .unwrap_err() + .strip_backtrace(), + @r###" + Error during planning: Expected 0 parameters, got 1 + "### + ); +} + +#[test] +fn test_prepare_statement_to_plan_params_as_constants() { + let sql = "PREPARE my_plan(INT) AS SELECT $1"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: $1 + EmptyRelation: rows=1 + "# + ); + assert_snapshot!(dt, @r#"Int32"#); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: Int32(10) AS $1 + EmptyRelation: rows=1 + " + ); + + /////////////////////////////////////// + let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: Int64(1) + $1 + EmptyRelation: rows=1 + "# + ); + assert_snapshot!(dt, @r#"Int32"#); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: Int64(1) + Int32(10) AS Int64(1) + $1 + EmptyRelation: rows=1 + " + ); + + /////////////////////////////////////// + let sql = "PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32, Float64] + Projection: Int64(1) + $1 + $2 + EmptyRelation: rows=1 + "# + ); + assert_snapshot!(dt, @r#"Int32, Float64"#); + + /////////////////// + // replace params with values + let param_values = vec![ + ScalarValue::Int32(Some(10)), + ScalarValue::Float64(Some(10.0)), + ]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: Int64(1) + Int32(10) + Float64(10) AS Int64(1) + $1 + $2 + EmptyRelation: rows=1 + " + ); +} + +#[test] +fn test_infer_types_from_join() { + let test = ParameterTest { + sql: + "SELECT id, order_id FROM person JOIN orders ON id = customer_id and age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))], + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND person.age = $1 + TableScan: person + TableScan: orders + ** Final Plan: + Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) + TableScan: person + TableScan: orders + " + ); +} + +#[test] +fn test_prepare_statement_infer_types_from_join() { + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, order_id FROM person JOIN orders ON id = customer_id and age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))] + }; + + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [Int32] + Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND person.age = $1 + TableScan: person + TableScan: orders + ** Final Plan: + Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) + TableScan: person + TableScan: orders + "# + ); +} + +#[test] +fn test_infer_types_from_predicate() { + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))], + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age = Int32(10) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_infer_types_from_predicate() { + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, age FROM person WHERE age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))], + }; + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [Int32] + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age = Int32(10) + TableScan: person + "# + ); +} + +#[test] +fn test_infer_types_from_between_predicate() { + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::Int32)), + ], + param_values: vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))], + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Projection: person.id, person.age + Filter: person.age BETWEEN $1 AND $2 + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age BETWEEN Int32(10) AND Int32(30) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_infer_types_from_between_predicate() { + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, age FROM person WHERE age BETWEEN $1 AND $2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::Int32)), + ], + param_values: vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))], + }; + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [Int32, Int32] + Projection: person.id, person.age + Filter: person.age BETWEEN $1 AND $2 + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age BETWEEN Int32(10) AND Int32(30) + TableScan: person + "# + ); +} + +#[test] +fn test_infer_types_subquery() { + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)", + expected_types: vec![("$1", Some(DataType::UInt32))], + param_values: vec![ScalarValue::UInt32(Some(10))] + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Projection: person.id, person.age + Filter: person.age = () + Subquery: + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + Filter: person.id = $1 + TableScan: person + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age = () + Subquery: + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + Filter: person.id = UInt32(10) + TableScan: person + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_infer_types_subquery() { + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)", + expected_types: vec![("$1", Some(DataType::UInt32))], + param_values: vec![ScalarValue::UInt32(Some(10))] + }; + + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [UInt32] + Projection: person.id, person.age + Filter: person.age = () + Subquery: + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + Filter: person.id = $1 + TableScan: person + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age = () + Subquery: + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + Filter: person.id = UInt32(10) + TableScan: person + TableScan: person + "# + ); +} + +#[test] +fn test_update_infer() { + let test = ParameterTest { + sql: "update person set age=$1 where id=$2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::UInt32)), + ], + param_values: vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))], + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Dml: op=[Update] table=[person] + Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 + Filter: person.id = $2 + TableScan: person + ** Final Plan: + Dml: op=[Update] table=[person] + Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 + Filter: person.id = UInt32(1) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_update_infer() { + let test = ParameterTest { + sql: "PREPARE my_plan AS update person set age=$1 where id=$2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::UInt32)), + ], + param_values: vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))], + }; + + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [Int32, UInt32] + Dml: op=[Update] table=[person] + Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 + Filter: person.id = $2 + TableScan: person + ** Final Plan: + Dml: op=[Update] table=[person] + Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 + Filter: person.id = UInt32(1) + TableScan: person + "# + ); +} + +#[test] +fn test_insert_infer() { + let test = ParameterTest { + sql: "insert into person (id, first_name, last_name) values ($1, $2, $3)", + expected_types: vec![ + ("$1", Some(DataType::UInt32)), + ("$2", Some(DataType::Utf8)), + ("$3", Some(DataType::Utf8)), + ], + param_values: vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), + ], + }; + + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Dml: op=[Insert Into] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 + Values: ($1, $2, $3) + ** Final Plan: + Dml: op=[Insert Into] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 + Values: (UInt32(1) AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3) + "# + ); +} + +#[test] +fn test_prepare_statement_insert_infer() { + let test = ParameterTest { + sql: "PREPARE my_plan AS insert into person (id, first_name, last_name) values ($1, $2, $3)", + expected_types: vec![ + ("$1", Some(DataType::UInt32)), + ("$2", Some(DataType::Utf8)), + ("$3", Some(DataType::Utf8)), + ], + param_values: vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), + ] + }; + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [UInt32, Utf8, Utf8] + Dml: op=[Insert Into] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 + Values: ($1, $2, $3) + ** Final Plan: + Dml: op=[Insert Into] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 + Values: (UInt32(1) AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3) + "# + ); +} + +#[test] +fn test_prepare_statement_to_plan_one_param() { + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $1"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"Int32"#); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: person.id, person.age + Filter: person.age = Int32(10) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_to_plan_data_type() { + let sql = "PREPARE my_plan(DOUBLE) AS SELECT id, age FROM person WHERE age = $1"; + + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + // age is defined as Int32 but prepare statement declares it as DOUBLE/Float64 + // Prepare statement and its logical plan should be created successfully + @r#" + Prepare: "my_plan" [Float64] + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"Float64"#); + + /////////////////// + // replace params with values still succeed and use Float64 + let param_values = vec![ScalarValue::Float64(Some(10.0))]; + + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: person.id, person.age + Filter: person.age = Float64(10) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_to_plan_multi_params() { + let sql = "PREPARE my_plan(INT, STRING, DOUBLE, INT, DOUBLE, STRING) AS + SELECT id, age, $6 + FROM person + WHERE age IN ($1, $4) AND salary > $3 and salary < $5 OR first_name < $2"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32, Utf8View, Float64, Int32, Float64, Utf8View] + Projection: person.id, person.age, $6 + Filter: person.age IN ([$1, $4]) AND person.salary > $3 AND person.salary < $5 OR person.first_name < $2 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"Int32, Utf8View, Float64, Int32, Float64, Utf8View"#); + + /////////////////// + // replace params with values + let param_values = vec![ + ScalarValue::Int32(Some(10)), + ScalarValue::Utf8View(Some("abc".into())), + ScalarValue::Float64(Some(100.0)), + ScalarValue::Int32(Some(20)), + ScalarValue::Float64(Some(200.0)), + ScalarValue::Utf8View(Some("xyz".into())), + ]; + + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r#" + Projection: person.id, person.age, Utf8View("xyz") AS $6 + Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary > Float64(100) AND person.salary < Float64(200) OR person.first_name < Utf8View("abc") + TableScan: person + "# + ); +} + +#[test] +fn test_prepare_statement_to_plan_having() { + let sql = "PREPARE my_plan(INT, DOUBLE, DOUBLE, DOUBLE) AS + SELECT id, sum(age) + FROM person \ + WHERE salary > $2 + GROUP BY id + HAVING sum(age) < $1 AND sum(age) > 10 OR sum(age) in ($3, $4)\ + "; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32, Float64, Float64, Float64] + Projection: person.id, sum(person.age) + Filter: sum(person.age) < $1 AND sum(person.age) > Int64(10) OR sum(person.age) IN ([$3, $4]) + Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]] + Filter: person.salary > $2 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"Int32, Float64, Float64, Float64"#); + + /////////////////// + // replace params with values + let param_values = vec![ + ScalarValue::Int32(Some(10)), + ScalarValue::Float64(Some(100.0)), + ScalarValue::Float64(Some(200.0)), + ScalarValue::Float64(Some(300.0)), + ]; + + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r#" + Projection: person.id, sum(person.age) + Filter: sum(person.age) < Int32(10) AND sum(person.age) > Int64(10) OR sum(person.age) IN ([Float64(200), Float64(300)]) + Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]] + Filter: person.salary > Float64(100) + TableScan: person + "# + ); +} + +#[test] +fn test_prepare_statement_to_plan_limit() { + let sql = "PREPARE my_plan(BIGINT, BIGINT) AS + SELECT id FROM person \ + OFFSET $1 LIMIT $2"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int64, Int64] + Limit: skip=$1, fetch=$2 + Projection: person.id + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"Int64, Int64"#); + + // replace params with values + let param_values = vec![ScalarValue::Int64(Some(10)), ScalarValue::Int64(Some(200))]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r#" + Limit: skip=10, fetch=200 + Projection: person.id + TableScan: person + "# + ); +} + +#[test] +fn test_prepare_statement_unknown_list_param() { + let sql = "SELECT id from person where id = $2"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::List(vec![]); + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!( + err.to_string(), + "Error during planning: No value found for placeholder with id $2" + ); +} + +#[test] +fn test_prepare_statement_unknown_hash_param() { + let sql = "SELECT id from person where id = $bar"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::Map(HashMap::new()); + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!( + err.to_string(), + "Error during planning: No value found for placeholder with name $bar" + ); +} + +#[test] +fn test_prepare_statement_bad_list_idx() { + let sql = "SELECT id from person where id = $foo"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::List(vec![]); + + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!(err.to_string(), "Error during planning: Failed to parse placeholder id: invalid digit found in string"); +} diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index b7185c2d503df..7aa982dcf3dd9 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -16,8 +16,10 @@ // under the License. use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{ - assert_contains, Column, DFSchema, DFSchemaRef, Result, TableReference, + assert_contains, Column, DFSchema, DFSchemaRef, DataFusionError, Result, + TableReference, }; use datafusion_expr::test::function_stub::{ count_udaf, max_udaf, min_udaf, sum, sum_udaf, @@ -33,11 +35,12 @@ use datafusion_functions_nested::map::map_udf; use datafusion_functions_window::rank::rank_udwf; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ - CustomDialectBuilder, DefaultDialect as UnparserDefaultDialect, DefaultDialect, - Dialect as UnparserDialect, MySqlDialect as UnparserMySqlDialect, + BigQueryDialect, CustomDialectBuilder, DefaultDialect as UnparserDefaultDialect, + DefaultDialect, Dialect as UnparserDialect, MySqlDialect as UnparserMySqlDialect, PostgreSqlDialect as UnparserPostgreSqlDialect, SqliteDialect, }; use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser}; +use insta::assert_snapshot; use sqlparser::ast::Statement; use std::hash::Hash; use std::ops::Add; @@ -49,6 +52,7 @@ use datafusion_expr::builder::{ project, subquery_alias, table_scan_with_filter_and_fetch, table_scan_with_filters, }; use datafusion_functions::core::planner::CoreFunctionPlanner; +use datafusion_functions::unicode::planner::UnicodeFunctionPlanner; use datafusion_functions_nested::extract::array_element_udf; use datafusion_functions_nested::planner::{FieldAccessPlanner, NestedFunctionPlanner}; use datafusion_sql::unparser::ast::{ @@ -62,46 +66,44 @@ use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; use sqlparser::parser::Parser; #[test] -fn roundtrip_expr() { - let tests: Vec<(TableReference, &str, &str)> = vec![ - (TableReference::bare("person"), "age > 35", r#"(age > 35)"#), - ( - TableReference::bare("person"), - "id = '10'", - r#"(id = '10')"#, - ), - ( - TableReference::bare("person"), - "CAST(id AS VARCHAR)", - r#"CAST(id AS VARCHAR)"#, - ), - ( - TableReference::bare("person"), - "sum((age * 2))", - r#"sum((age * 2))"#, - ), - ]; +fn test_roundtrip_expr_1() { + let expr = roundtrip_expr(TableReference::bare("person"), "age > 35").unwrap(); + assert_snapshot!(expr, @r#"(age > 35)"#); +} - let roundtrip = |table, sql: &str| -> Result { - let dialect = GenericDialect {}; - let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?; - let state = MockSessionState::default().with_aggregate_function(sum_udaf()); - let context = MockContextProvider { state }; - let schema = context.get_table_source(table)?.schema(); - let df_schema = DFSchema::try_from(schema.as_ref().clone())?; - let sql_to_rel = SqlToRel::new(&context); - let expr = - sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; +#[test] +fn test_roundtrip_expr_2() { + let expr = roundtrip_expr(TableReference::bare("person"), "id = '10'").unwrap(); + assert_snapshot!(expr, @r#"(id = '10')"#); +} + +#[test] +fn test_roundtrip_expr_3() { + let expr = + roundtrip_expr(TableReference::bare("person"), "CAST(id AS VARCHAR)").unwrap(); + assert_snapshot!(expr, @r#"CAST(id AS VARCHAR)"#); +} - let ast = expr_to_sql(&expr)?; +#[test] +fn test_roundtrip_expr_4() { + let expr = roundtrip_expr(TableReference::bare("person"), "sum((age * 2))").unwrap(); + assert_snapshot!(expr, @r#"sum((age * 2))"#); +} - Ok(ast.to_string()) - }; +fn roundtrip_expr(table: TableReference, sql: &str) -> Result { + let dialect = GenericDialect {}; + let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?; + let state = MockSessionState::default().with_aggregate_function(sum_udaf()); + let context = MockContextProvider { state }; + let schema = context.get_table_source(table)?.schema(); + let df_schema = DFSchema::try_from(schema)?; + let sql_to_rel = SqlToRel::new(&context); + let expr = + sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; - for (table, query, expected) in tests { - let actual = roundtrip(table, query).unwrap(); - assert_eq!(actual, expected); - } + let ast = expr_to_sql(&expr)?; + + Ok(ast.to_string()) } #[test] @@ -170,6 +172,13 @@ fn roundtrip_statement() -> Result<()> { UNION ALL SELECT j3_string AS col1, j3_id AS id FROM j3 ) AS subquery GROUP BY col1, id ORDER BY col1 ASC, id ASC"#, + r#"SELECT col1, id FROM ( + SELECT j1_string AS col1, j1_id AS id FROM j1 + UNION + SELECT j2_string AS col1, j2_id AS id FROM j2 + UNION + SELECT j3_string AS col1, j3_id AS id FROM j3 + ) AS subquery ORDER BY col1 ASC, id ASC"#, "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), first_name from person", @@ -236,10 +245,6 @@ fn roundtrip_statement() -> Result<()> { let roundtrip_statement = plan_to_sql(&plan)?; - let actual = &roundtrip_statement.to_string(); - println!("roundtrip sql: {actual}"); - println!("plan {}", plan.display_indent()); - let plan_roundtrip = sql_to_rel .sql_statement_to_plan(roundtrip_statement.clone()) .unwrap(); @@ -275,109 +280,186 @@ fn roundtrip_crossjoin() -> Result<()> { let plan_roundtrip = sql_to_rel .sql_statement_to_plan(roundtrip_statement) .unwrap(); + assert_snapshot!( + plan_roundtrip, + @r" + Projection: j1.j1_id, j2.j2_string + Cross Join: + TableScan: j1 + TableScan: j2 + " + ); - let expected = "Projection: j1.j1_id, j2.j2_string\ - \n Cross Join: \ - \n TableScan: j1\ - \n TableScan: j2"; + Ok(()) +} + +#[macro_export] +macro_rules! roundtrip_statement_with_dialect_helper { + ( + sql: $sql:expr, + parser_dialect: $parser_dialect:expr, + unparser_dialect: $unparser_dialect:expr, + expected: @ $expected:literal $(,)? + ) => {{ + let statement = Parser::new(&$parser_dialect) + .try_with_sql($sql)? + .parse_statement()?; + + let state = MockSessionState::default() + .with_aggregate_function(max_udaf()) + .with_aggregate_function(min_udaf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())) + .with_expr_planner(Arc::new(NestedFunctionPlanner)) + .with_expr_planner(Arc::new(FieldAccessPlanner)); - assert_eq!(plan_roundtrip.to_string(), expected); + let context = MockContextProvider { state }; + let sql_to_rel = SqlToRel::new(&context); + let plan = sql_to_rel + .sql_statement_to_plan(statement) + .unwrap_or_else(|e| panic!("Failed to parse sql: {}\n{e}", $sql)); + let unparser = Unparser::new(&$unparser_dialect); + let roundtrip_statement = unparser.plan_to_sql(&plan)?; + + let actual = &roundtrip_statement.to_string(); + insta::assert_snapshot!(actual, @ $expected); + }}; +} + +#[test] +fn roundtrip_statement_with_dialect_1() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;", + parser_dialect: MySqlDialect {}, + unparser_dialect: UnparserMySqlDialect {}, + // top projection sort gets derived into a subquery + // for MySQL, this subquery needs an alias + expected: @"SELECT `j1_min` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, min(`ta`.`j1_id`) FROM `j1` AS `ta` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", + ); Ok(()) } #[test] -fn roundtrip_statement_with_dialect() -> Result<()> { - struct TestStatementWithDialect { - sql: &'static str, - expected: &'static str, - parser_dialect: Box, - unparser_dialect: Box, - } - let tests: Vec = vec![ - TestStatementWithDialect { - sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;", - expected: - // top projection sort gets derived into a subquery - // for MySQL, this subquery needs an alias - "SELECT `j1_min` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, min(`ta`.`j1_id`) FROM `j1` AS `ta` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", - parser_dialect: Box::new(MySqlDialect {}), - unparser_dialect: Box::new(UnparserMySqlDialect {}), - }, - TestStatementWithDialect { - sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;", - expected: - // top projection sort still gets derived into a subquery in default dialect - // except for the default dialect, the subquery is left non-aliased - "SELECT j1_min FROM (SELECT min(ta.j1_id) AS j1_min, min(ta.j1_id) FROM j1 AS ta ORDER BY min(ta.j1_id) ASC NULLS LAST) LIMIT 10", - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "select min(ta.j1_id) as j1_min, max(tb.j1_max) from j1 ta, (select distinct max(ta.j1_id) as j1_max from j1 ta order by max(ta.j1_id)) tb order by min(ta.j1_id) limit 10;", - expected: - "SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` CROSS JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", - parser_dialect: Box::new(MySqlDialect {}), - unparser_dialect: Box::new(UnparserMySqlDialect {}), - }, - TestStatementWithDialect { - sql: "select j1_id from (select 1 as j1_id);", - expected: - "SELECT `j1_id` FROM (SELECT 1 AS `j1_id`) AS `derived_projection`", - parser_dialect: Box::new(MySqlDialect {}), - unparser_dialect: Box::new(UnparserMySqlDialect {}), - }, - TestStatementWithDialect { - sql: "select j1_id from (select j1_id from j1 limit 10);", - expected: - "SELECT `j1`.`j1_id` FROM (SELECT `j1`.`j1_id` FROM `j1` LIMIT 10) AS `derived_limit`", - parser_dialect: Box::new(MySqlDialect {}), - unparser_dialect: Box::new(UnparserMySqlDialect {}), - }, - TestStatementWithDialect { - sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", - expected: - "SELECT `ta`.`j1_id` FROM `j1` AS `ta` ORDER BY `ta`.`j1_id` ASC LIMIT 10", - parser_dialect: Box::new(MySqlDialect {}), - unparser_dialect: Box::new(UnparserMySqlDialect {}), - }, - TestStatementWithDialect { - sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", - expected: r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC NULLS LAST LIMIT 10"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT j1_id FROM j1 +fn roundtrip_statement_with_dialect_2() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + // top projection sort still gets derived into a subquery in default dialect + // except for the default dialect, the subquery is left non-aliased + expected: @"SELECT j1_min FROM (SELECT min(ta.j1_id) AS j1_min, min(ta.j1_id) FROM j1 AS ta ORDER BY min(ta.j1_id) ASC NULLS LAST) LIMIT 10", + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_3() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "select min(ta.j1_id) as j1_min, max(tb.j1_max) from j1 ta, (select distinct max(ta.j1_id) as j1_max from j1 ta order by max(ta.j1_id)) tb order by min(ta.j1_id) limit 10;", + parser_dialect: MySqlDialect {}, + unparser_dialect: UnparserMySqlDialect {}, + expected: @"SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` CROSS JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_4() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "select j1_id from (select 1 as j1_id);", + parser_dialect: MySqlDialect {}, + unparser_dialect: UnparserMySqlDialect {}, + expected: @"SELECT `j1_id` FROM (SELECT 1 AS `j1_id`) AS `derived_projection`", + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_5() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "select j1_id from (select j1_id from j1 limit 10);", + parser_dialect: MySqlDialect {}, + unparser_dialect: UnparserMySqlDialect {}, + expected: @"SELECT `j1`.`j1_id` FROM (SELECT `j1`.`j1_id` FROM `j1` LIMIT 10) AS `derived_limit`", + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_6() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", + parser_dialect: MySqlDialect {}, + unparser_dialect: UnparserMySqlDialect {}, + expected: @"SELECT `ta`.`j1_id` FROM `j1` AS `ta` ORDER BY `ta`.`j1_id` ASC LIMIT 10", + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_7() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC NULLS LAST LIMIT 10"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_8() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT j1_id FROM j1 UNION ALL SELECT tb.j2_id as j1_id FROM j2 tb ORDER BY j1_id LIMIT 10;", - expected: r#"SELECT j1.j1_id FROM j1 UNION ALL SELECT tb.j2_id AS j1_id FROM j2 AS tb ORDER BY j1_id ASC NULLS LAST LIMIT 10"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - // Test query with derived tables that put distinct,sort,limit on the wrong level - TestStatementWithDialect { - sql: "SELECT j1_string from j1 order by j1_id", - expected: r#"SELECT j1.j1_string FROM j1 ORDER BY j1.j1_id ASC NULLS LAST"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT j1_string AS a from j1 order by j1_id", - expected: r#"SELECT j1.j1_string AS a FROM j1 ORDER BY j1.j1_id ASC NULLS LAST"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT j1_string from j1 join j2 on j1.j1_id = j2.j2_id order by j1_id", - expected: r#"SELECT j1.j1_string FROM j1 INNER JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id ASC NULLS LAST"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: " + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT j1.j1_id FROM j1 UNION ALL SELECT tb.j2_id AS j1_id FROM j2 AS tb ORDER BY j1_id ASC NULLS LAST LIMIT 10"#, + ); + Ok(()) +} + +// Test query with derived tables that put distinct,sort,limit on the wrong level +#[test] +fn roundtrip_statement_with_dialect_9() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT j1_string from j1 order by j1_id", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT j1.j1_string FROM j1 ORDER BY j1.j1_id ASC NULLS LAST"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_10() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT j1_string AS a from j1 order by j1_id", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT j1.j1_string AS a FROM j1 ORDER BY j1.j1_id ASC NULLS LAST"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_11() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT j1_string from j1 join j2 on j1.j1_id = j2.j2_id order by j1_id", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT j1.j1_string FROM j1 INNER JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id ASC NULLS LAST"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_12() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: " SELECT j1_string, j2_string @@ -397,13 +479,18 @@ fn roundtrip_statement_with_dialect() -> Result<()> { ) abc ORDER BY abc.j2_string", - expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT DISTINCT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 INNER JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - // more tests around subquery/derived table roundtrip - TestStatementWithDialect { - sql: "SELECT string_count FROM ( + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT DISTINCT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 INNER JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, + ); + Ok(()) +} + +// more tests around subquery/derived table roundtrip +#[test] +fn roundtrip_statement_with_dialect_13() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT string_count FROM ( SELECT j1_id, min(j2_string) @@ -414,12 +501,17 @@ fn roundtrip_statement_with_dialect() -> Result<()> { j1_id ) AS agg (id, string_count) ", - expected: r#"SELECT agg.string_count FROM (SELECT j1.j1_id, min(j2.j2_string) FROM j1 LEFT OUTER JOIN j2 ON (j1.j1_id = j2.j2_id) GROUP BY j1.j1_id) AS agg (id, string_count)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: " + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT agg.string_count FROM (SELECT j1.j1_id, min(j2.j2_string) FROM j1 LEFT OUTER JOIN j2 ON (j1.j1_id = j2.j2_id) GROUP BY j1.j1_id) AS agg (id, string_count)"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_14() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: " SELECT j1_string, j2_string @@ -443,13 +535,18 @@ fn roundtrip_statement_with_dialect() -> Result<()> { ) abc ORDER BY abc.j2_string", - expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 INNER JOIN j2 ON (j1.j1_id = j2.j2_id) GROUP BY j1.j1_id, j1.j1_string, j2.j2_string ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - // Test query that order by columns are not in select columns - TestStatementWithDialect { - sql: " + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 INNER JOIN j2 ON (j1.j1_id = j2.j2_id) GROUP BY j1.j1_id, j1.j1_string, j2.j2_string ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, + ); + Ok(()) +} + +// Test query that order by columns are not in select columns +#[test] +fn roundtrip_statement_with_dialect_15() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: " SELECT j1_string FROM @@ -468,221 +565,399 @@ fn roundtrip_statement_with_dialect() -> Result<()> { ) abc ORDER BY j2_string", - expected: r#"SELECT abc.j1_string FROM (SELECT j1.j1_string, j2.j2_string FROM j1 INNER JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT id FROM (SELECT j1_id from j1) AS c (id)", - expected: r#"SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT id FROM (SELECT j1_id as id from j1) AS c", - expected: r#"SELECT c.id FROM (SELECT j1.j1_id AS id FROM j1) AS c"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - // Test query that has calculation in derived table with columns - TestStatementWithDialect { - sql: "SELECT id FROM (SELECT j1_id + 1 * 3 from j1) AS c (id)", - expected: r#"SELECT c.id FROM (SELECT (j1.j1_id + (1 * 3)) FROM j1) AS c (id)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - // Test query that has limit/distinct/order in derived table with columns - TestStatementWithDialect { - sql: "SELECT id FROM (SELECT distinct (j1_id + 1 * 3) FROM j1 LIMIT 1) AS c (id)", - expected: r#"SELECT c.id FROM (SELECT DISTINCT (j1.j1_id + (1 * 3)) FROM j1 LIMIT 1) AS c (id)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT id FROM (SELECT j1_id + 1 FROM j1 ORDER BY j1_id DESC LIMIT 1) AS c (id)", - expected: r#"SELECT c.id FROM (SELECT (j1.j1_id + 1) FROM j1 ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 1) AS c (id)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT id FROM (SELECT CAST((CAST(j1_id as BIGINT) + 1) as int) * 10 FROM j1 LIMIT 1) AS c (id)", - expected: r#"SELECT c.id FROM (SELECT (CAST((CAST(j1.j1_id AS BIGINT) + 1) AS INTEGER) * 10) FROM j1 LIMIT 1) AS c (id)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT id FROM (SELECT CAST(j1_id as BIGINT) + 1 FROM j1 ORDER BY j1_id LIMIT 1) AS c (id)", - expected: r#"SELECT c.id FROM (SELECT (CAST(j1.j1_id AS BIGINT) + 1) FROM j1 ORDER BY j1.j1_id ASC NULLS LAST LIMIT 1) AS c (id)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)", - expected: r#"SELECT temp_j.id2 FROM (SELECT j1.j1_id, j1.j1_string FROM j1) AS temp_j (id2, string2)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)", - expected: r#"SELECT `temp_j`.`id2` FROM (SELECT `j1`.`j1_id` AS `id2`, `j1`.`j1_string` AS `string2` FROM `j1`) AS `temp_j`"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(SqliteDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT * FROM (SELECT j1_id + 1 FROM j1) AS temp_j(id2)", - expected: r#"SELECT `temp_j`.`id2` FROM (SELECT (`j1`.`j1_id` + 1) AS `id2` FROM `j1`) AS `temp_j`"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(SqliteDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT * FROM (SELECT j1_id FROM j1 LIMIT 1) AS temp_j(id2)", - expected: r#"SELECT `temp_j`.`id2` FROM (SELECT `j1`.`j1_id` AS `id2` FROM `j1` LIMIT 1) AS `temp_j`"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(SqliteDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT * FROM UNNEST([1,2,3])", - expected: r#"SELECT "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))" FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))")"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", - expected: r#"SELECT t1.c1 FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS t1 (c1)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT * FROM UNNEST([1,2,3]), j1", - expected: r#"SELECT "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))", j1.j1_id, j1.j1_string FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") CROSS JOIN j1"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id", - expected: r#"SELECT u.c1, j1.j1_id, j1.j1_string FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) INNER JOIN j1 ON (u.c1 = j1.j1_id)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)", - expected: r#"SELECT u.c1 FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) UNION ALL SELECT u.c1 FROM (SELECT UNNEST([4, 5, 6]) AS "UNNEST(make_array(Int64(4),Int64(5),Int64(6)))") AS u (c1)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT * FROM UNNEST([1,2,3])", - expected: r#"SELECT UNNEST(make_array(Int64(1),Int64(2),Int64(3))) FROM UNNEST([1, 2, 3])"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - }, - TestStatementWithDialect { - sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", - expected: r#"SELECT t1.c1 FROM UNNEST([1, 2, 3]) AS t1 (c1)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - }, - TestStatementWithDialect { - sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", - expected: r#"SELECT t1.c1 FROM UNNEST([1, 2, 3]) AS t1 (c1)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - }, - TestStatementWithDialect { - sql: "SELECT * FROM UNNEST([1,2,3]), j1", - expected: r#"SELECT UNNEST(make_array(Int64(1),Int64(2),Int64(3))), j1.j1_id, j1.j1_string FROM UNNEST([1, 2, 3]) CROSS JOIN j1"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - }, - TestStatementWithDialect { - sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id", - expected: r#"SELECT u.c1, j1.j1_id, j1.j1_string FROM UNNEST([1, 2, 3]) AS u (c1) INNER JOIN j1 ON (u.c1 = j1.j1_id)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - }, - TestStatementWithDialect { - sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)", - expected: r#"SELECT u.c1 FROM UNNEST([1, 2, 3]) AS u (c1) UNION ALL SELECT u.c1 FROM UNNEST([4, 5, 6]) AS u (c1)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - }, - TestStatementWithDialect { - sql: "SELECT UNNEST([1,2,3])", - expected: r#"SELECT * FROM UNNEST([1, 2, 3])"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - }, - TestStatementWithDialect { - sql: "SELECT UNNEST([1,2,3]) as c1", - expected: r#"SELECT UNNEST([1, 2, 3]) AS c1"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - }, - TestStatementWithDialect { - sql: "SELECT UNNEST([1,2,3]), 1", - expected: r#"SELECT UNNEST([1, 2, 3]) AS UNNEST(make_array(Int64(1),Int64(2),Int64(3))), Int64(1)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - }, - TestStatementWithDialect { - sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col)", - expected: r#"SELECT u.array_col, u.struct_col, UNNEST(outer_ref(u.array_col)) FROM unnest_table AS u CROSS JOIN UNNEST(u.array_col)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - }, - TestStatementWithDialect { - sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col) AS t1 (c1)", - expected: r#"SELECT u.array_col, u.struct_col, t1.c1 FROM unnest_table AS u CROSS JOIN UNNEST(u.array_col) AS t1 (c1)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - }, - TestStatementWithDialect { - sql: "SELECT unnest([1, 2, 3, 4]) from unnest([1, 2, 3]);", - expected: r#"SELECT UNNEST([1, 2, 3, 4]) AS UNNEST(make_array(Int64(1),Int64(2),Int64(3),Int64(4))) FROM UNNEST([1, 2, 3])"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), - }, - TestStatementWithDialect { - sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col)", - expected: r#"SELECT u.array_col, u.struct_col, "UNNEST(outer_ref(u.array_col))" FROM unnest_table AS u CROSS JOIN LATERAL (SELECT UNNEST(u.array_col) AS "UNNEST(outer_ref(u.array_col))")"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - TestStatementWithDialect { - sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col) AS t1 (c1)", - expected: r#"SELECT u.array_col, u.struct_col, t1.c1 FROM unnest_table AS u CROSS JOIN LATERAL (SELECT UNNEST(u.array_col) AS "UNNEST(outer_ref(u.array_col))") AS t1 (c1)"#, - parser_dialect: Box::new(GenericDialect {}), - unparser_dialect: Box::new(UnparserDefaultDialect {}), - }, - ]; + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT abc.j1_string FROM (SELECT j1.j1_string, j2.j2_string FROM j1 INNER JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, + ); + Ok(()) +} - for query in tests { - let statement = Parser::new(&*query.parser_dialect) - .try_with_sql(query.sql)? - .parse_statement()?; +#[test] +fn roundtrip_statement_with_dialect_16() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT id FROM (SELECT j1_id from j1) AS c (id)", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)"#, + ); + Ok(()) +} - let state = MockSessionState::default() - .with_aggregate_function(max_udaf()) - .with_aggregate_function(min_udaf()) - .with_expr_planner(Arc::new(CoreFunctionPlanner::default())) - .with_expr_planner(Arc::new(NestedFunctionPlanner)); +#[test] +fn roundtrip_statement_with_dialect_17() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT id FROM (SELECT j1_id as id from j1) AS c", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT c.id FROM (SELECT j1.j1_id AS id FROM j1) AS c"#, + ); + Ok(()) +} - let context = MockContextProvider { state }; - let sql_to_rel = SqlToRel::new(&context); - let plan = sql_to_rel - .sql_statement_to_plan(statement) - .unwrap_or_else(|e| panic!("Failed to parse sql: {}\n{e}", query.sql)); +// Test query that has calculation in derived table with columns +#[test] +fn roundtrip_statement_with_dialect_18() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT id FROM (SELECT j1_id + 1 * 3 from j1) AS c (id)", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT c.id FROM (SELECT (j1.j1_id + (1 * 3)) FROM j1) AS c (id)"#, + ); + Ok(()) +} - let unparser = Unparser::new(&*query.unparser_dialect); - let roundtrip_statement = unparser.plan_to_sql(&plan)?; +// Test query that has limit/distinct/order in derived table with columns +#[test] +fn roundtrip_statement_with_dialect_19() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT id FROM (SELECT distinct (j1_id + 1 * 3) FROM j1 LIMIT 1) AS c (id)", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT c.id FROM (SELECT DISTINCT (j1.j1_id + (1 * 3)) FROM j1 LIMIT 1) AS c (id)"#, + ); + Ok(()) +} - let actual = &roundtrip_statement.to_string(); - println!("roundtrip sql: {actual}"); - println!("plan {}", plan.display_indent()); +#[test] +fn roundtrip_statement_with_dialect_20() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT id FROM (SELECT j1_id + 1 FROM j1 ORDER BY j1_id DESC LIMIT 1) AS c (id)", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT c.id FROM (SELECT (j1.j1_id + 1) FROM j1 ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 1) AS c (id)"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_21() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT id FROM (SELECT CAST((CAST(j1_id as BIGINT) + 1) as int) * 10 FROM j1 LIMIT 1) AS c (id)", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT c.id FROM (SELECT (CAST((CAST(j1.j1_id AS BIGINT) + 1) AS INTEGER) * 10) FROM j1 LIMIT 1) AS c (id)"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_22() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT id FROM (SELECT CAST(j1_id as BIGINT) + 1 FROM j1 ORDER BY j1_id LIMIT 1) AS c (id)", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT c.id FROM (SELECT (CAST(j1.j1_id AS BIGINT) + 1) FROM j1 ORDER BY j1.j1_id ASC NULLS LAST LIMIT 1) AS c (id)"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_23() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT temp_j.id2 FROM (SELECT j1.j1_id, j1.j1_string FROM j1) AS temp_j (id2, string2)"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_24() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)", + parser_dialect: GenericDialect {}, + unparser_dialect: SqliteDialect {}, + expected: @r#"SELECT `temp_j`.`id2` FROM (SELECT `j1`.`j1_id` AS `id2`, `j1`.`j1_string` AS `string2` FROM `j1`) AS `temp_j`"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_25() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM (SELECT j1_id + 1 FROM j1) AS temp_j(id2)", + parser_dialect: GenericDialect {}, + unparser_dialect: SqliteDialect {}, + expected: @r#"SELECT `temp_j`.`id2` FROM (SELECT (`j1`.`j1_id` + 1) AS `id2` FROM `j1`) AS `temp_j`"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_26() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM (SELECT j1_id FROM j1 LIMIT 1) AS temp_j(id2)", + parser_dialect: GenericDialect {}, + unparser_dialect: SqliteDialect {}, + expected: @r#"SELECT `temp_j`.`id2` FROM (SELECT `j1`.`j1_id` AS `id2` FROM `j1` LIMIT 1) AS `temp_j`"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_27() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM UNNEST([1,2,3])", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))" FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS derived_projection ("UNNEST(make_array(Int64(1),Int64(2),Int64(3)))")"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_28() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT t1.c1 FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS t1 (c1)"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_29() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM UNNEST([1,2,3]), j1", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))", j1.j1_id, j1.j1_string FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS derived_projection ("UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") CROSS JOIN j1"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_30() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT u.c1, j1.j1_id, j1.j1_string FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) INNER JOIN j1 ON (u.c1 = j1.j1_id)"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_31() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT u.c1 FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) UNION ALL SELECT u.c1 FROM (SELECT UNNEST([4, 5, 6]) AS "UNNEST(make_array(Int64(4),Int64(5),Int64(6)))") AS u (c1)"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_32() -> Result<(), DataFusionError> { + let unparser = CustomDialectBuilder::default() + .with_unnest_as_table_factor(true) + .build(); + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM UNNEST([1,2,3])", + parser_dialect: GenericDialect {}, + unparser_dialect: unparser, + expected: @r#"SELECT UNNEST(make_array(Int64(1),Int64(2),Int64(3))) FROM UNNEST([1, 2, 3])"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_33() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col)", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT u.array_col, u.struct_col, "UNNEST(outer_ref(u.array_col))" FROM unnest_table AS u CROSS JOIN LATERAL (SELECT UNNEST(u.array_col) AS "UNNEST(outer_ref(u.array_col))")"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_34() -> Result<(), DataFusionError> { + let unparser = CustomDialectBuilder::default() + .with_unnest_as_table_factor(true) + .build(); + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", + parser_dialect: GenericDialect {}, + unparser_dialect: unparser, + expected: @r#"SELECT t1.c1 FROM UNNEST([1, 2, 3]) AS t1 (c1)"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_35() -> Result<(), DataFusionError> { + let unparser = CustomDialectBuilder::default() + .with_unnest_as_table_factor(true) + .build(); + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM UNNEST([1,2,3]), j1", + parser_dialect: GenericDialect {}, + unparser_dialect: unparser, + expected: @r#"SELECT UNNEST(make_array(Int64(1),Int64(2),Int64(3))), j1.j1_id, j1.j1_string FROM UNNEST([1, 2, 3]) CROSS JOIN j1"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_36() -> Result<(), DataFusionError> { + let unparser = CustomDialectBuilder::default() + .with_unnest_as_table_factor(true) + .build(); + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id", + parser_dialect: GenericDialect {}, + unparser_dialect: unparser, + expected: @r#"SELECT u.c1, j1.j1_id, j1.j1_string FROM UNNEST([1, 2, 3]) AS u (c1) INNER JOIN j1 ON (u.c1 = j1.j1_id)"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_37() -> Result<(), DataFusionError> { + let unparser = CustomDialectBuilder::default() + .with_unnest_as_table_factor(true) + .build(); + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)", + parser_dialect: GenericDialect {}, + unparser_dialect: unparser, + expected: @r#"SELECT u.c1 FROM UNNEST([1, 2, 3]) AS u (c1) UNION ALL SELECT u.c1 FROM UNNEST([4, 5, 6]) AS u (c1)"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_38() -> Result<(), DataFusionError> { + let unparser = CustomDialectBuilder::default() + .with_unnest_as_table_factor(true) + .build(); + roundtrip_statement_with_dialect_helper!( + sql: "SELECT UNNEST([1,2,3])", + parser_dialect: GenericDialect {}, + unparser_dialect: unparser, + expected: @r#"SELECT * FROM UNNEST([1, 2, 3])"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_39() -> Result<(), DataFusionError> { + let unparser = CustomDialectBuilder::default() + .with_unnest_as_table_factor(true) + .build(); + roundtrip_statement_with_dialect_helper!( + sql: "SELECT UNNEST([1,2,3]) as c1", + parser_dialect: GenericDialect {}, + unparser_dialect: unparser, + expected: @r#"SELECT UNNEST([1, 2, 3]) AS c1"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_40() -> Result<(), DataFusionError> { + let unparser = CustomDialectBuilder::default() + .with_unnest_as_table_factor(true) + .build(); + roundtrip_statement_with_dialect_helper!( + sql: "SELECT UNNEST([1,2,3]), 1", + parser_dialect: GenericDialect {}, + unparser_dialect: unparser, + expected: @r#"SELECT UNNEST([1, 2, 3]) AS UNNEST(make_array(Int64(1),Int64(2),Int64(3))), Int64(1)"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_41() -> Result<(), DataFusionError> { + let unparser = CustomDialectBuilder::default() + .with_unnest_as_table_factor(true) + .build(); + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col)", + parser_dialect: GenericDialect {}, + unparser_dialect: unparser, + expected: @r#"SELECT u.array_col, u.struct_col, UNNEST(outer_ref(u.array_col)) FROM unnest_table AS u CROSS JOIN UNNEST(u.array_col)"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_42() -> Result<(), DataFusionError> { + let unparser = CustomDialectBuilder::default() + .with_unnest_as_table_factor(true) + .build(); + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col) AS t1 (c1)", + parser_dialect: GenericDialect {}, + unparser_dialect: unparser, + expected: @r#"SELECT u.array_col, u.struct_col, t1.c1 FROM unnest_table AS u CROSS JOIN UNNEST(u.array_col) AS t1 (c1)"#, + ); + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect_43() -> Result<(), DataFusionError> { + let unparser = CustomDialectBuilder::default() + .with_unnest_as_table_factor(true) + .build(); + roundtrip_statement_with_dialect_helper!( + sql: "SELECT unnest([1, 2, 3, 4]) from unnest([1, 2, 3]);", + parser_dialect: GenericDialect {}, + unparser_dialect: unparser, + expected: @r#"SELECT UNNEST([1, 2, 3, 4]) AS UNNEST(make_array(Int64(1),Int64(2),Int64(3),Int64(4))) FROM UNNEST([1, 2, 3])"#, + ); + Ok(()) +} - assert_eq!(query.expected, actual); - } +#[test] +fn roundtrip_statement_with_dialect_45() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col) AS t1 (c1)", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT u.array_col, u.struct_col, t1.c1 FROM unnest_table AS u CROSS JOIN LATERAL (SELECT UNNEST(u.array_col) AS "UNNEST(outer_ref(u.array_col))") AS t1 (c1)"#, + ); + Ok(()) +} +#[test] +fn roundtrip_statement_with_dialect_special_char_alias() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "select min(a) as \"min(a)\" from (select 1 as a)", + parser_dialect: GenericDialect {}, + unparser_dialect: BigQueryDialect {}, + expected: @r#"SELECT min(`a`) AS `min_40a_41` FROM (SELECT 1 AS `a`)"#, + ); + roundtrip_statement_with_dialect_helper!( + sql: "select a as \"a*\", b as \"b@\" from (select 1 as a , 2 as b)", + parser_dialect: GenericDialect {}, + unparser_dialect: BigQueryDialect {}, + expected: @r#"SELECT `a` AS `a_42`, `b` AS `b_64` FROM (SELECT 1 AS `a`, 2 AS `b`)"#, + ); + roundtrip_statement_with_dialect_helper!( + sql: "select a as \"a*\", b , c as \"c@\" from (select 1 as a , 2 as b, 3 as c)", + parser_dialect: GenericDialect {}, + unparser_dialect: BigQueryDialect {}, + expected: @r#"SELECT `a` AS `a_42`, `b`, `c` AS `c_64` FROM (SELECT 1 AS `a`, 2 AS `b`, 3 AS `c`)"#, + ); + roundtrip_statement_with_dialect_helper!( + sql: "select * from (select a as \"a*\", b as \"b@\" from (select 1 as a , 2 as b)) where \"a*\" = 1", + parser_dialect: GenericDialect {}, + unparser_dialect: BigQueryDialect {}, + expected: @r#"SELECT `a_42`, `b_64` FROM (SELECT `a` AS `a_42`, `b` AS `b_64` FROM (SELECT 1 AS `a`, 2 AS `b`)) WHERE (`a_42` = 1)"#, + ); + roundtrip_statement_with_dialect_helper!( + sql: "select * from (select a as \"a*\", b as \"b@\" from (select 1 as a , 2 as b)) where \"a*\" = 1", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT "a*", "b@" FROM (SELECT a AS "a*", b AS "b@" FROM (SELECT 1 AS a, 2 AS b)) WHERE ("a*" = 1)"#, + ); Ok(()) } @@ -700,13 +975,14 @@ fn test_unnest_logical_plan() -> Result<()> { }; let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); - let expected = r#" + assert_snapshot!( + plan, + @r#" Projection: __unnest_placeholder(unnest_table.struct_col).field1, __unnest_placeholder(unnest_table.struct_col).field2, __unnest_placeholder(unnest_table.array_col,depth=1) AS UNNEST(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col Unnest: lists[__unnest_placeholder(unnest_table.array_col)|depth=1] structs[__unnest_placeholder(unnest_table.struct_col)] Projection: unnest_table.struct_col AS __unnest_placeholder(unnest_table.struct_col), unnest_table.array_col AS __unnest_placeholder(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col - TableScan: unnest_table"#.trim_start(); - - assert_eq!(plan.to_string(), expected); + TableScan: unnest_table"# + ); Ok(()) } @@ -726,121 +1002,248 @@ fn test_aggregation_without_projection() -> Result<()> { let unparser = Unparser::default(); let statement = unparser.plan_to_sql(&plan)?; - - let actual = &statement.to_string(); - - assert_eq!( - actual, - r#"SELECT sum(users.age), users."name" FROM users GROUP BY users."name""# + assert_snapshot!( + statement, + @r#"SELECT sum(users.age), users."name" FROM users GROUP BY users."name""# ); Ok(()) } -#[test] -fn test_table_references_in_plan_to_sql() { - fn test(table_name: &str, expected_sql: &str, dialect: &impl UnparserDialect) { - let schema = Schema::new(vec![ - Field::new("id", DataType::Utf8, false), - Field::new("value", DataType::Utf8, false), - ]); - let plan = table_scan(Some(table_name), &schema, None) - .unwrap() - .project(vec![col("id"), col("value")]) - .unwrap() - .build() - .unwrap(); - - let unparser = Unparser::new(dialect); - let sql = unparser.plan_to_sql(&plan).unwrap(); - - assert_eq!(sql.to_string(), expected_sql) - } +/// return a schema with two string columns: "id" and "value" +fn test_schema() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("value", DataType::Utf8, false), + ]) +} - test( - "catalog.schema.table", - r#"SELECT "table".id, "table"."value" FROM "catalog"."schema"."table""#, +#[test] +fn test_table_references_in_plan_to_sql_1() { + let table_name = "catalog.schema.table"; + let schema = test_schema(); + let sql = table_references_in_plan_helper( + table_name, + schema, + vec![col("id"), col("value")], &DefaultDialect {}, ); - test( - "schema.table", - r#"SELECT "table".id, "table"."value" FROM "schema"."table""#, + assert_snapshot!( + sql, + @r#"SELECT "table".id, "table"."value" FROM "catalog"."schema"."table""# + ); +} + +#[test] +fn test_table_references_in_plan_to_sql_2() { + let table_name = "schema.table"; + let schema = test_schema(); + let sql = table_references_in_plan_helper( + table_name, + schema, + vec![col("id"), col("value")], &DefaultDialect {}, ); - test( - "table", - r#"SELECT "table".id, "table"."value" FROM "table""#, + assert_snapshot!( + sql, + @r#"SELECT "table".id, "table"."value" FROM "schema"."table""# + ); +} + +#[test] +fn test_table_references_in_plan_to_sql_3() { + let table_name = "table"; + let schema = test_schema(); + let sql = table_references_in_plan_helper( + table_name, + schema, + vec![col("id"), col("value")], &DefaultDialect {}, ); + assert_snapshot!( + sql, + @r#"SELECT "table".id, "table"."value" FROM "table""# + ); +} +#[test] +fn test_table_references_in_plan_to_sql_4() { + let table_name = "catalog.schema.table"; + let schema = test_schema(); let custom_dialect = CustomDialectBuilder::default() .with_full_qualified_col(true) .with_identifier_quote_style('"') .build(); - test( - "catalog.schema.table", - r#"SELECT "catalog"."schema"."table"."id", "catalog"."schema"."table"."value" FROM "catalog"."schema"."table""#, + let sql = table_references_in_plan_helper( + table_name, + schema, + vec![col("id"), col("value")], &custom_dialect, ); - test( - "schema.table", - r#"SELECT "schema"."table"."id", "schema"."table"."value" FROM "schema"."table""#, + assert_snapshot!( + sql, + @r#"SELECT "catalog"."schema"."table"."id", "catalog"."schema"."table"."value" FROM "catalog"."schema"."table""# + ); +} + +#[test] +fn test_table_references_in_plan_to_sql_5() { + let table_name = "schema.table"; + let schema = test_schema(); + let custom_dialect = CustomDialectBuilder::default() + .with_full_qualified_col(true) + .with_identifier_quote_style('"') + .build(); + + let sql = table_references_in_plan_helper( + table_name, + schema, + vec![col("id"), col("value")], &custom_dialect, ); - test( - "table", - r#"SELECT "table"."id", "table"."value" FROM "table""#, + assert_snapshot!( + sql, + @r#"SELECT "schema"."table"."id", "schema"."table"."value" FROM "schema"."table""# + ); +} + +#[test] +fn test_table_references_in_plan_to_sql_6() { + let table_name = "table"; + let schema = test_schema(); + let custom_dialect = CustomDialectBuilder::default() + .with_full_qualified_col(true) + .with_identifier_quote_style('"') + .build(); + + let sql = table_references_in_plan_helper( + table_name, + schema, + vec![col("id"), col("value")], &custom_dialect, ); + assert_snapshot!( + sql, + @r#"SELECT "table"."id", "table"."value" FROM "table""# + ); +} + +fn table_references_in_plan_helper( + table_name: &str, + table_schema: Schema, + expr: impl IntoIterator>, + dialect: &impl UnparserDialect, +) -> Statement { + let plan = table_scan(Some(table_name), &table_schema, None) + .unwrap() + .project(expr) + .unwrap() + .build() + .unwrap(); + let unparser = Unparser::new(dialect); + unparser.plan_to_sql(&plan).unwrap() } #[test] -fn test_table_scan_with_none_projection_in_plan_to_sql() { - fn test(table_name: &str, expected_sql: &str) { - let schema = Schema::new(vec![ - Field::new("id", DataType::Utf8, false), - Field::new("value", DataType::Utf8, false), - ]); +fn test_table_scan_with_none_projection_in_plan_to_sql_1() { + let schema = test_schema(); + let table_name = "catalog.schema.table"; + let plan = table_scan_with_empty_projection_and_none_projection_helper( + table_name, schema, None, + ); + let sql = plan_to_sql(&plan).unwrap(); + assert_snapshot!( + sql, + @r#"SELECT * FROM "catalog"."schema"."table""# + ); +} - let plan = table_scan(Some(table_name), &schema, None) - .unwrap() - .build() - .unwrap(); - let sql = plan_to_sql(&plan).unwrap(); - assert_eq!(sql.to_string(), expected_sql) - } +#[test] +fn test_table_scan_with_none_projection_in_plan_to_sql_2() { + let schema = test_schema(); + let table_name = "schema.table"; + let plan = table_scan_with_empty_projection_and_none_projection_helper( + table_name, schema, None, + ); + let sql = plan_to_sql(&plan).unwrap(); + assert_snapshot!( + sql, + @r#"SELECT * FROM "schema"."table""# + ); +} - test( - "catalog.schema.table", - r#"SELECT * FROM "catalog"."schema"."table""#, +#[test] +fn test_table_scan_with_none_projection_in_plan_to_sql_3() { + let schema = test_schema(); + let table_name = "table"; + let plan = table_scan_with_empty_projection_and_none_projection_helper( + table_name, schema, None, + ); + let sql = plan_to_sql(&plan).unwrap(); + assert_snapshot!( + sql, + @r#"SELECT * FROM "table""# ); - test("schema.table", r#"SELECT * FROM "schema"."table""#); - test("table", r#"SELECT * FROM "table""#); } #[test] -fn test_table_scan_with_empty_projection_in_plan_to_sql() { - fn test(table_name: &str, expected_sql: &str) { - let schema = Schema::new(vec![ - Field::new("id", DataType::Utf8, false), - Field::new("value", DataType::Utf8, false), - ]); +fn test_table_scan_with_empty_projection_in_plan_to_sql_1() { + let schema = test_schema(); + let table_name = "catalog.schema.table"; + let plan = table_scan_with_empty_projection_and_none_projection_helper( + table_name, + schema, + Some(vec![]), + ); + let sql = plan_to_sql(&plan).unwrap(); + assert_snapshot!( + sql, + @r#"SELECT 1 FROM "catalog"."schema"."table""# + ); +} - let plan = table_scan(Some(table_name), &schema, Some(vec![])) - .unwrap() - .build() - .unwrap(); - let sql = plan_to_sql(&plan).unwrap(); - assert_eq!(sql.to_string(), expected_sql) - } +#[test] +fn test_table_scan_with_empty_projection_in_plan_to_sql_2() { + let schema = test_schema(); + let table_name = "schema.table"; + let plan = table_scan_with_empty_projection_and_none_projection_helper( + table_name, + schema, + Some(vec![]), + ); + let sql = plan_to_sql(&plan).unwrap(); + assert_snapshot!( + sql, + @r#"SELECT 1 FROM "schema"."table""# + ); +} - test( - "catalog.schema.table", - r#"SELECT 1 FROM "catalog"."schema"."table""#, +#[test] +fn test_table_scan_with_empty_projection_in_plan_to_sql_3() { + let schema = test_schema(); + let table_name = "table"; + let plan = table_scan_with_empty_projection_and_none_projection_helper( + table_name, + schema, + Some(vec![]), + ); + let sql = plan_to_sql(&plan).unwrap(); + assert_snapshot!( + sql, + @r#"SELECT 1 FROM "table""# ); - test("schema.table", r#"SELECT 1 FROM "schema"."table""#); - test("table", r#"SELECT 1 FROM "table""#); +} + +fn table_scan_with_empty_projection_and_none_projection_helper( + table_name: &str, + table_schema: Schema, + projection: Option>, +) -> LogicalPlan { + table_scan(Some(table_name), &table_schema, projection) + .unwrap() + .build() + .unwrap() } #[test] @@ -901,7 +1304,7 @@ fn test_pretty_roundtrip() -> Result<()> { let expr = sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; let round_trip_sql = unparser.expr_to_sql(&expr)?.to_string(); - assert_eq!(pretty.to_string(), round_trip_sql); + assert_eq!((*pretty).to_string(), round_trip_sql); // verify that the pretty string parses to the same underlying Expr let pretty_sql_expr = Parser::new(&GenericDialect {}) @@ -920,12 +1323,12 @@ fn test_pretty_roundtrip() -> Result<()> { Ok(()) } -fn sql_round_trip(dialect: D, query: &str, expect: &str) +fn generate_round_trip_statement(dialect: D, sql: &str) -> Statement where D: Dialect, { let statement = Parser::new(&dialect) - .try_with_sql(query) + .try_with_sql(sql) .unwrap() .parse_statement() .unwrap(); @@ -937,13 +1340,16 @@ where .with_aggregate_function(grouping_udaf()) .with_window_function(rank_udwf()) .with_scalar_function(Arc::new(unicode::substr().as_ref().clone())) - .with_scalar_function(make_array_udf()), + .with_scalar_function(make_array_udf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())) + .with_expr_planner(Arc::new(UnicodeFunctionPlanner)) + .with_expr_planner(Arc::new(NestedFunctionPlanner)) + .with_expr_planner(Arc::new(FieldAccessPlanner)), }; let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); - let roundtrip_statement = plan_to_sql(&plan).unwrap(); - assert_eq!(roundtrip_statement.to_string(), expect); + plan_to_sql(&plan).unwrap() } #[test] @@ -958,7 +1364,10 @@ fn test_table_scan_alias() -> Result<()> { .alias("a")? .build()?; let sql = plan_to_sql(&plan)?; - assert_eq!(sql.to_string(), "SELECT * FROM (SELECT t1.id FROM t1) AS a"); + assert_snapshot!( + sql, + @"SELECT * FROM (SELECT t1.id FROM t1) AS a" + ); let plan = table_scan(Some("t1"), &schema, None)? .project(vec![col("id")])? @@ -966,7 +1375,10 @@ fn test_table_scan_alias() -> Result<()> { .build()?; let sql = plan_to_sql(&plan)?; - assert_eq!(sql.to_string(), "SELECT * FROM (SELECT t1.id FROM t1) AS a"); + assert_snapshot!( + sql, + @"SELECT * FROM (SELECT t1.id FROM t1) AS a" + ); let plan = table_scan(Some("t1"), &schema, None)? .filter(col("id").gt(lit(5)))? @@ -974,9 +1386,9 @@ fn test_table_scan_alias() -> Result<()> { .alias("a")? .build()?; let sql = plan_to_sql(&plan)?; - assert_eq!( - sql.to_string(), - "SELECT * FROM (SELECT t1.id FROM t1 WHERE (t1.id > 5)) AS a" + assert_snapshot!( + sql, + @r#"SELECT * FROM (SELECT t1.id FROM t1 WHERE (t1.id > 5)) AS a"# ); let table_scan_with_two_filter = table_scan_with_filters( @@ -989,9 +1401,9 @@ fn test_table_scan_alias() -> Result<()> { .alias("a")? .build()?; let table_scan_with_two_filter = plan_to_sql(&table_scan_with_two_filter)?; - assert_eq!( - table_scan_with_two_filter.to_string(), - "SELECT a.id FROM t1 AS a WHERE ((a.id > 1) AND (a.age < 2))" + assert_snapshot!( + table_scan_with_two_filter, + @r#"SELECT a.id FROM t1 AS a WHERE ((a.id > 1) AND (a.age < 2))"# ); let table_scan_with_fetch = @@ -1000,9 +1412,9 @@ fn test_table_scan_alias() -> Result<()> { .alias("a")? .build()?; let table_scan_with_fetch = plan_to_sql(&table_scan_with_fetch)?; - assert_eq!( - table_scan_with_fetch.to_string(), - "SELECT a.id FROM (SELECT * FROM t1 LIMIT 10) AS a" + assert_snapshot!( + table_scan_with_fetch, + @r#"SELECT a.id FROM (SELECT * FROM t1 LIMIT 10) AS a"# ); let table_scan_with_pushdown_all = table_scan_with_filter_and_fetch( @@ -1016,9 +1428,9 @@ fn test_table_scan_alias() -> Result<()> { .alias("a")? .build()?; let table_scan_with_pushdown_all = plan_to_sql(&table_scan_with_pushdown_all)?; - assert_eq!( - table_scan_with_pushdown_all.to_string(), - "SELECT a.id FROM (SELECT a.id, a.age FROM t1 AS a WHERE (a.id > 1) LIMIT 10) AS a" + assert_snapshot!( + table_scan_with_pushdown_all, + @r#"SELECT a.id FROM (SELECT a.id, a.age FROM t1 AS a WHERE (a.id > 1) LIMIT 10) AS a"# ); Ok(()) } @@ -1032,18 +1444,24 @@ fn test_table_scan_pushdown() -> Result<()> { let scan_with_projection = table_scan(Some("t1"), &schema, Some(vec![0, 1]))?.build()?; let scan_with_projection = plan_to_sql(&scan_with_projection)?; - assert_eq!( - scan_with_projection.to_string(), - "SELECT t1.id, t1.age FROM t1" + assert_snapshot!( + scan_with_projection, + @r#"SELECT t1.id, t1.age FROM t1"# ); let scan_with_projection = table_scan(Some("t1"), &schema, Some(vec![1]))?.build()?; let scan_with_projection = plan_to_sql(&scan_with_projection)?; - assert_eq!(scan_with_projection.to_string(), "SELECT t1.age FROM t1"); + assert_snapshot!( + scan_with_projection, + @r#"SELECT t1.age FROM t1"# + ); let scan_with_no_projection = table_scan(Some("t1"), &schema, None)?.build()?; let scan_with_no_projection = plan_to_sql(&scan_with_no_projection)?; - assert_eq!(scan_with_no_projection.to_string(), "SELECT * FROM t1"); + assert_snapshot!( + scan_with_no_projection, + @r#"SELECT * FROM t1"# + ); let table_scan_with_projection_alias = table_scan(Some("t1"), &schema, Some(vec![0, 1]))? @@ -1051,9 +1469,9 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_projection_alias = plan_to_sql(&table_scan_with_projection_alias)?; - assert_eq!( - table_scan_with_projection_alias.to_string(), - "SELECT ta.id, ta.age FROM t1 AS ta" + assert_snapshot!( + table_scan_with_projection_alias, + @r#"SELECT ta.id, ta.age FROM t1 AS ta"# ); let table_scan_with_projection_alias = @@ -1062,9 +1480,9 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_projection_alias = plan_to_sql(&table_scan_with_projection_alias)?; - assert_eq!( - table_scan_with_projection_alias.to_string(), - "SELECT ta.age FROM t1 AS ta" + assert_snapshot!( + table_scan_with_projection_alias, + @r#"SELECT ta.age FROM t1 AS ta"# ); let table_scan_with_no_projection_alias = table_scan(Some("t1"), &schema, None)? @@ -1072,9 +1490,9 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_no_projection_alias = plan_to_sql(&table_scan_with_no_projection_alias)?; - assert_eq!( - table_scan_with_no_projection_alias.to_string(), - "SELECT * FROM t1 AS ta" + assert_snapshot!( + table_scan_with_no_projection_alias, + @r#"SELECT * FROM t1 AS ta"# ); let query_from_table_scan_with_projection = LogicalPlanBuilder::from( @@ -1084,9 +1502,9 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let query_from_table_scan_with_projection = plan_to_sql(&query_from_table_scan_with_projection)?; - assert_eq!( - query_from_table_scan_with_projection.to_string(), - "SELECT t1.id, t1.age FROM t1" + assert_snapshot!( + query_from_table_scan_with_projection, + @r#"SELECT t1.id, t1.age FROM t1"# ); let query_from_table_scan_with_two_projections = LogicalPlanBuilder::from( @@ -1097,9 +1515,9 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let query_from_table_scan_with_two_projections = plan_to_sql(&query_from_table_scan_with_two_projections)?; - assert_eq!( - query_from_table_scan_with_two_projections.to_string(), - "SELECT t1.id, t1.age FROM (SELECT t1.id, t1.age FROM t1)" + assert_snapshot!( + query_from_table_scan_with_two_projections, + @r#"SELECT t1.id, t1.age FROM (SELECT t1.id, t1.age FROM t1)"# ); let table_scan_with_filter = table_scan_with_filters( @@ -1110,9 +1528,9 @@ fn test_table_scan_pushdown() -> Result<()> { )? .build()?; let table_scan_with_filter = plan_to_sql(&table_scan_with_filter)?; - assert_eq!( - table_scan_with_filter.to_string(), - "SELECT * FROM t1 WHERE (t1.id > t1.age)" + assert_snapshot!( + table_scan_with_filter, + @r#"SELECT * FROM t1 WHERE (t1.id > t1.age)"# ); let table_scan_with_two_filter = table_scan_with_filters( @@ -1123,9 +1541,9 @@ fn test_table_scan_pushdown() -> Result<()> { )? .build()?; let table_scan_with_two_filter = plan_to_sql(&table_scan_with_two_filter)?; - assert_eq!( - table_scan_with_two_filter.to_string(), - "SELECT * FROM t1 WHERE ((t1.id > 1) AND (t1.age < 2))" + assert_snapshot!( + table_scan_with_two_filter, + @r#"SELECT * FROM t1 WHERE ((t1.id > 1) AND (t1.age < 2))"# ); let table_scan_with_filter_alias = table_scan_with_filters( @@ -1137,9 +1555,9 @@ fn test_table_scan_pushdown() -> Result<()> { .alias("ta")? .build()?; let table_scan_with_filter_alias = plan_to_sql(&table_scan_with_filter_alias)?; - assert_eq!( - table_scan_with_filter_alias.to_string(), - "SELECT * FROM t1 AS ta WHERE (ta.id > ta.age)" + assert_snapshot!( + table_scan_with_filter_alias, + @r#"SELECT * FROM t1 AS ta WHERE (ta.id > ta.age)"# ); let table_scan_with_projection_and_filter = table_scan_with_filters( @@ -1151,9 +1569,9 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_projection_and_filter = plan_to_sql(&table_scan_with_projection_and_filter)?; - assert_eq!( - table_scan_with_projection_and_filter.to_string(), - "SELECT t1.id, t1.age FROM t1 WHERE (t1.id > t1.age)" + assert_snapshot!( + table_scan_with_projection_and_filter, + @r#"SELECT t1.id, t1.age FROM t1 WHERE (t1.id > t1.age)"# ); let table_scan_with_projection_and_filter = table_scan_with_filters( @@ -1165,18 +1583,18 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_projection_and_filter = plan_to_sql(&table_scan_with_projection_and_filter)?; - assert_eq!( - table_scan_with_projection_and_filter.to_string(), - "SELECT t1.age FROM t1 WHERE (t1.id > t1.age)" + assert_snapshot!( + table_scan_with_projection_and_filter, + @r#"SELECT t1.age FROM t1 WHERE (t1.id > t1.age)"# ); let table_scan_with_inline_fetch = table_scan_with_filter_and_fetch(Some("t1"), &schema, None, vec![], Some(10))? .build()?; let table_scan_with_inline_fetch = plan_to_sql(&table_scan_with_inline_fetch)?; - assert_eq!( - table_scan_with_inline_fetch.to_string(), - "SELECT * FROM t1 LIMIT 10" + assert_snapshot!( + table_scan_with_inline_fetch, + @r#"SELECT * FROM t1 LIMIT 10"# ); let table_scan_with_projection_and_inline_fetch = table_scan_with_filter_and_fetch( @@ -1189,9 +1607,9 @@ fn test_table_scan_pushdown() -> Result<()> { .build()?; let table_scan_with_projection_and_inline_fetch = plan_to_sql(&table_scan_with_projection_and_inline_fetch)?; - assert_eq!( - table_scan_with_projection_and_inline_fetch.to_string(), - "SELECT t1.id, t1.age FROM t1 LIMIT 10" + assert_snapshot!( + table_scan_with_projection_and_inline_fetch, + @r#"SELECT t1.id, t1.age FROM t1 LIMIT 10"# ); let table_scan_with_all = table_scan_with_filter_and_fetch( @@ -1203,9 +1621,9 @@ fn test_table_scan_pushdown() -> Result<()> { )? .build()?; let table_scan_with_all = plan_to_sql(&table_scan_with_all)?; - assert_eq!( - table_scan_with_all.to_string(), - "SELECT t1.id, t1.age FROM t1 WHERE (t1.id > t1.age) LIMIT 10" + assert_snapshot!( + table_scan_with_all, + @r#"SELECT t1.id, t1.age FROM t1 WHERE (t1.id > t1.age) LIMIT 10"# ); let table_scan_with_additional_filter = table_scan_with_filters( @@ -1217,9 +1635,9 @@ fn test_table_scan_pushdown() -> Result<()> { .filter(col("id").eq(lit(5)))? .build()?; let table_scan_with_filter = plan_to_sql(&table_scan_with_additional_filter)?; - assert_eq!( - table_scan_with_filter.to_string(), - "SELECT * FROM t1 WHERE (t1.id = 5) AND (t1.id > t1.age)" + assert_snapshot!( + table_scan_with_filter, + @r#"SELECT * FROM t1 WHERE (t1.id = 5) AND (t1.id > t1.age)"# ); Ok(()) @@ -1238,9 +1656,9 @@ fn test_sort_with_push_down_fetch() -> Result<()> { .build()?; let sql = plan_to_sql(&plan)?; - assert_eq!( - format!("{}", sql), - "SELECT t1.id, t1.age FROM t1 ORDER BY t1.age ASC NULLS FIRST LIMIT 10" + assert_snapshot!( + sql, + @r#"SELECT t1.id, t1.age FROM t1 ORDER BY t1.age ASC NULLS FIRST LIMIT 10"# ); Ok(()) } @@ -1284,10 +1702,10 @@ fn test_join_with_table_scan_filters() -> Result<()> { .build()?; let sql = plan_to_sql(&join_plan_with_filter)?; - - let expected_sql = r#"SELECT * FROM left_table AS "left" INNER JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND ("left"."name" LIKE 'some_name' AND (age > 10)))"#; - - assert_eq!(sql.to_string(), expected_sql); + assert_snapshot!( + sql, + @r#"SELECT * FROM left_table AS "left" INNER JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND ("left"."name" LIKE 'some_name' AND (age > 10)))"# + ); let join_plan_no_filter = LogicalPlanBuilder::from(left_plan.clone()) .join( @@ -1299,10 +1717,10 @@ fn test_join_with_table_scan_filters() -> Result<()> { .build()?; let sql = plan_to_sql(&join_plan_no_filter)?; - - let expected_sql = r#"SELECT * FROM left_table AS "left" INNER JOIN right_table ON "left".id = right_table.id AND ("left"."name" LIKE 'some_name' AND (age > 10))"#; - - assert_eq!(sql.to_string(), expected_sql); + assert_snapshot!( + sql, + @r#"SELECT * FROM left_table AS "left" INNER JOIN right_table ON "left".id = right_table.id AND ("left"."name" LIKE 'some_name' AND (age > 10))"# + ); let right_plan_with_filter = table_scan_with_filters( Some("right_table"), @@ -1324,10 +1742,10 @@ fn test_join_with_table_scan_filters() -> Result<()> { .build()?; let sql = plan_to_sql(&join_plan_multiple_filters)?; - - let expected_sql = r#"SELECT * FROM left_table AS "left" INNER JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND (("left"."name" LIKE 'some_name' AND (right_table."name" = 'before_join_filter_val')) AND (age > 10))) WHERE ("left"."name" = 'after_join_filter_val')"#; - - assert_eq!(sql.to_string(), expected_sql); + assert_snapshot!( + sql, + @r#"SELECT * FROM left_table AS "left" INNER JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND (("left"."name" LIKE 'some_name' AND (right_table."name" = 'before_join_filter_val')) AND (age > 10))) WHERE ("left"."name" = 'after_join_filter_val')"# + ); let right_plan_with_filter_schema = table_scan_with_filters( Some("right_table"), @@ -1354,114 +1772,199 @@ fn test_join_with_table_scan_filters() -> Result<()> { .build()?; let sql = plan_to_sql(&join_plan_duplicated_filter)?; - - let expected_sql = r#"SELECT * FROM left_table AS "left" INNER JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND (("left"."name" LIKE 'some_name' AND (right_table.age > 10)) AND (right_table.age < 11)))"#; - - assert_eq!(sql.to_string(), expected_sql); + assert_snapshot!( + sql, + @r#"SELECT * FROM left_table AS "left" INNER JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND (("left"."name" LIKE 'some_name' AND (right_table.age > 10)) AND (right_table.age < 11)))"# + ); Ok(()) } #[test] fn test_interval_lhs_eq() { - sql_round_trip( + let statement = generate_round_trip_statement( GenericDialect {}, "select interval '2 seconds' = interval '2 seconds'", - "SELECT (INTERVAL '2.000000000 SECS' = INTERVAL '2.000000000 SECS')", ); + assert_snapshot!( + statement, + @r#"SELECT (INTERVAL '2.000000000 SECS' = INTERVAL '2.000000000 SECS')"# + ) } #[test] fn test_interval_lhs_lt() { - sql_round_trip( + let statement = generate_round_trip_statement( GenericDialect {}, "select interval '2 seconds' < interval '2 seconds'", - "SELECT (INTERVAL '2.000000000 SECS' < INTERVAL '2.000000000 SECS')", ); + assert_snapshot!( + statement, + @r#"SELECT (INTERVAL '2.000000000 SECS' < INTERVAL '2.000000000 SECS')"# + ) } #[test] fn test_without_offset() { - sql_round_trip(MySqlDialect {}, "select 1", "SELECT 1"); + let statement = generate_round_trip_statement(MySqlDialect {}, "select 1"); + assert_snapshot!( + statement, + @r#"SELECT 1"# + ) } #[test] fn test_with_offset0() { - sql_round_trip(MySqlDialect {}, "select 1 offset 0", "SELECT 1 OFFSET 0"); + let statement = generate_round_trip_statement(MySqlDialect {}, "select 1 offset 0"); + assert_snapshot!( + statement, + @r#"SELECT 1 OFFSET 0"# + ) } #[test] fn test_with_offset95() { - sql_round_trip(MySqlDialect {}, "select 1 offset 95", "SELECT 1 OFFSET 95"); + let statement = generate_round_trip_statement(MySqlDialect {}, "select 1 offset 95"); + assert_snapshot!( + statement, + @r#"SELECT 1 OFFSET 95"# + ) } #[test] -fn test_order_by_to_sql() { +fn test_order_by_to_sql_1() { // order by aggregation function - sql_round_trip( + let statement = generate_round_trip_statement( GenericDialect {}, r#"SELECT id, first_name, SUM(id) FROM person GROUP BY id, first_name ORDER BY SUM(id) ASC, first_name DESC, id, first_name LIMIT 10"#, - r#"SELECT person.id, person.first_name, sum(person.id) FROM person GROUP BY person.id, person.first_name ORDER BY sum(person.id) ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#, ); + assert_snapshot!( + statement, + @r#"SELECT person.id, person.first_name, sum(person.id) FROM person GROUP BY person.id, person.first_name ORDER BY sum(person.id) ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"# + ); +} +#[test] +fn test_order_by_to_sql_2() { // order by aggregation function alias - sql_round_trip( + let statement = generate_round_trip_statement( GenericDialect {}, r#"SELECT id, first_name, SUM(id) as total_sum FROM person GROUP BY id, first_name ORDER BY total_sum ASC, first_name DESC, id, first_name LIMIT 10"#, - r#"SELECT person.id, person.first_name, sum(person.id) AS total_sum FROM person GROUP BY person.id, person.first_name ORDER BY total_sum ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#, ); + assert_snapshot!( + statement, + @r#"SELECT person.id, person.first_name, sum(person.id) AS total_sum FROM person GROUP BY person.id, person.first_name ORDER BY total_sum ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"# + ); +} - // order by scalar function from projection - sql_round_trip( +#[test] +fn test_order_by_to_sql_3() { + let statement = generate_round_trip_statement( GenericDialect {}, r#"SELECT id, first_name, substr(first_name,0,5) FROM person ORDER BY id, substr(first_name,0,5)"#, - r#"SELECT person.id, person.first_name, substr(person.first_name, 0, 5) FROM person ORDER BY person.id ASC NULLS LAST, substr(person.first_name, 0, 5) ASC NULLS LAST"#, ); + assert_snapshot!( + statement, + @r#"SELECT person.id, person.first_name, substr(person.first_name, 0, 5) FROM person ORDER BY person.id ASC NULLS LAST, substr(person.first_name, 0, 5) ASC NULLS LAST"# + ); +} + +#[test] +fn test_complex_order_by_with_grouping() -> Result<()> { + let state = MockSessionState::default().with_aggregate_function(grouping_udaf()); + + let context = MockContextProvider { state }; + let sql_to_rel = SqlToRel::new(&context); + + // This SQL is based on a simplified version of the TPC-DS query 36. + let statement = Parser::new(&GenericDialect {}) + .try_with_sql( + r#"SELECT + j1_id, + j1_string, + grouping(j1_id) + grouping(j1_string) as lochierarchy + FROM + j1 + GROUP BY + ROLLUP (j1_id, j1_string) + ORDER BY + grouping(j1_id) + grouping(j1_string) DESC, + CASE + WHEN grouping(j1_id) + grouping(j1_string) = 0 THEN j1_id + END + LIMIT 100"#, + )? + .parse_statement()?; + + let plan = sql_to_rel.sql_statement_to_plan(statement)?; + let unparser = Unparser::default(); + let sql = unparser.plan_to_sql(&plan)?; + insta::with_settings!({ + filters => vec![ + // Force a deterministic order for the grouping pairs + (r#"grouping\(j1\.(?:j1_id|j1_string)\),\s*grouping\(j1\.(?:j1_id|j1_string)\)"#, "grouping(j1.j1_string), grouping(j1.j1_id)") + ], + }, { + assert_snapshot!( + sql, + @r#"SELECT j1.j1_id, j1.j1_string, lochierarchy FROM (SELECT j1.j1_id, j1.j1_string, (grouping(j1.j1_id) + grouping(j1.j1_string)) AS lochierarchy, grouping(j1.j1_string), grouping(j1.j1_id) FROM j1 GROUP BY ROLLUP (j1.j1_id, j1.j1_string) ORDER BY (grouping(j1.j1_id) + grouping(j1.j1_string)) DESC NULLS FIRST, CASE WHEN ((grouping(j1.j1_id) + grouping(j1.j1_string)) = 0) THEN j1.j1_id END ASC NULLS LAST) LIMIT 100"# + ); + }); + + Ok(()) } #[test] fn test_aggregation_to_sql() { - sql_round_trip( - GenericDialect {}, - r#"SELECT id, first_name, + let sql = r#"SELECT id, first_name, SUM(id) AS total_sum, SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, + SUM(id) FILTER (WHERE id > 50 AND first_name = 'John') OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS filtered_sum, MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total, rank() OVER (PARTITION BY grouping(id) + grouping(age), CASE WHEN grouping(age) = 0 THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_1, rank() OVER (PARTITION BY grouping(age) + grouping(id), CASE WHEN (CAST(grouping(age) AS BIGINT) = 0) THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_2 FROM person - GROUP BY id, first_name;"#, - r#"SELECT person.id, person.first_name, -sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, -max(sum(person.id)) OVER (PARTITION BY person.first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total, -rank() OVER (PARTITION BY (grouping(person.id) + grouping(person.age)), CASE WHEN (grouping(person.age) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_1, -rank() OVER (PARTITION BY (grouping(person.age) + grouping(person.id)), CASE WHEN (CAST(grouping(person.age) AS BIGINT) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_2 -FROM person -GROUP BY person.id, person.first_name"#.replace("\n", " ").as_str(), + GROUP BY id, first_name"#; + let statement = generate_round_trip_statement(GenericDialect {}, sql); + assert_snapshot!( + statement, + @"SELECT person.id, person.first_name, sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, sum(person.id) FILTER (WHERE ((person.id > 50) AND (person.first_name = 'John'))) OVER (PARTITION BY person.first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS filtered_sum, max(sum(person.id)) OVER (PARTITION BY person.first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total, rank() OVER (PARTITION BY (grouping(person.id) + grouping(person.age)), CASE WHEN (grouping(person.age) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_1, rank() OVER (PARTITION BY (grouping(person.age) + grouping(person.id)), CASE WHEN (CAST(grouping(person.age) AS BIGINT) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_2 FROM person GROUP BY person.id, person.first_name", ); } #[test] -fn test_unnest_to_sql() { - sql_round_trip( +fn test_unnest_to_sql_1() { + let statement = generate_round_trip_statement( GenericDialect {}, r#"SELECT unnest(array_col) as u1, struct_col, array_col FROM unnest_table WHERE array_col != NULL ORDER BY struct_col, array_col"#, - r#"SELECT UNNEST(unnest_table.array_col) AS u1, unnest_table.struct_col, unnest_table.array_col FROM unnest_table WHERE (unnest_table.array_col <> NULL) ORDER BY unnest_table.struct_col ASC NULLS LAST, unnest_table.array_col ASC NULLS LAST"#, ); + assert_snapshot!( + statement, + @r#"SELECT UNNEST(unnest_table.array_col) AS u1, unnest_table.struct_col, unnest_table.array_col FROM unnest_table WHERE (unnest_table.array_col <> NULL) ORDER BY unnest_table.struct_col ASC NULLS LAST, unnest_table.array_col ASC NULLS LAST"# + ); +} - sql_round_trip( +#[test] +fn test_unnest_to_sql_2() { + let statement = generate_round_trip_statement( GenericDialect {}, r#"SELECT unnest(make_array(1, 2, 2, 5, NULL)) as u1"#, - r#"SELECT UNNEST([1, 2, 2, 5, NULL]) AS u1"#, + ); + assert_snapshot!( + statement, + @r#"SELECT UNNEST([1, 2, 2, 5, NULL]) AS u1"# ); } #[test] fn test_join_with_no_conditions() { - sql_round_trip( + let statement = generate_round_trip_statement( GenericDialect {}, "SELECT j1.j1_id, j1.j1_string FROM j1 CROSS JOIN j2", - "SELECT j1.j1_id, j1.j1_string FROM j1 CROSS JOIN j2", + ); + assert_snapshot!( + statement, + @r#"SELECT j1.j1_id, j1.j1_string FROM j1 CROSS JOIN j2"# ); } @@ -1562,8 +2065,10 @@ fn test_unparse_extension_to_statement() -> Result<()> { Arc::new(UnusedUnparser {}), ]); let sql = unparser.plan_to_sql(&extension)?; - let expected = "SELECT j1.j1_id, j1.j1_string FROM j1"; - assert_eq!(sql.to_string(), expected); + assert_snapshot!( + sql, + @r#"SELECT j1.j1_id, j1.j1_string FROM j1"# + ); if let Some(err) = plan_to_sql(&extension).err() { assert_contains!( @@ -1625,9 +2130,10 @@ fn test_unparse_extension_to_sql() -> Result<()> { Arc::new(UnusedUnparser {}), ]); let sql = unparser.plan_to_sql(&plan)?; - let expected = - "SELECT j1.j1_id AS user_id FROM (SELECT j1.j1_id, j1.j1_string FROM j1)"; - assert_eq!(sql.to_string(), expected); + assert_snapshot!( + sql, + @r#"SELECT j1.j1_id AS user_id FROM (SELECT j1.j1_id, j1.j1_string FROM j1)"# + ); if let Some(err) = plan_to_sql(&plan).err() { assert_contains!( @@ -1665,10 +2171,10 @@ fn test_unparse_optimized_multi_union() -> Result<()> { ], schema: dfschema.clone(), }); - - let sql = "SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y"; - - assert_eq!(unparser.plan_to_sql(&plan)?.to_string(), sql); + assert_snapshot!( + unparser.plan_to_sql(&plan)?, + @r#"SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y"# + ); let plan = LogicalPlan::Union(Union { inputs: vec![project( @@ -1746,8 +2252,10 @@ fn test_unparse_subquery_alias_with_table_pushdown() -> Result<()> { let unparser = Unparser::default(); let sql = unparser.plan_to_sql(&plan)?; - let expected = "SELECT customer_view.c_custkey, customer_view.c_name, customer_view.custkey_plus FROM (SELECT customer.c_custkey, (CAST(customer.c_custkey AS BIGINT) + 1) AS custkey_plus, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM customer AS customer) AS customer) AS customer_view"; - assert_eq!(sql.to_string(), expected); + assert_snapshot!( + sql, + @r#"SELECT customer_view.c_custkey, customer_view.c_name, customer_view.custkey_plus FROM (SELECT customer.c_custkey, (CAST(customer.c_custkey AS BIGINT) + 1) AS custkey_plus, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM customer AS customer) AS customer) AS customer_view"# + ); Ok(()) } @@ -1778,7 +2286,10 @@ fn test_unparse_left_anti_join() -> Result<()> { let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); let sql = unparser.plan_to_sql(&plan)?; - assert_eq!("SELECT \"t1\".\"d\" FROM \"t1\" WHERE NOT EXISTS (SELECT 1 FROM \"t2\" AS \"__correlated_sq_1\" WHERE (\"t1\".\"c\" = \"__correlated_sq_1\".\"c\"))", sql.to_string()); + assert_snapshot!( + sql, + @r#"SELECT "t1"."d" FROM "t1" WHERE NOT EXISTS (SELECT 1 FROM "t2" AS "__correlated_sq_1" WHERE ("t1"."c" = "__correlated_sq_1"."c"))"# + ); Ok(()) } @@ -1809,7 +2320,10 @@ fn test_unparse_left_semi_join() -> Result<()> { let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); let sql = unparser.plan_to_sql(&plan)?; - assert_eq!("SELECT \"t1\".\"d\" FROM \"t1\" WHERE EXISTS (SELECT 1 FROM \"t2\" AS \"__correlated_sq_1\" WHERE (\"t1\".\"c\" = \"__correlated_sq_1\".\"c\"))", sql.to_string()); + assert_snapshot!( + sql, + @r#"SELECT "t1"."d" FROM "t1" WHERE EXISTS (SELECT 1 FROM "t2" AS "__correlated_sq_1" WHERE ("t1"."c" = "__correlated_sq_1"."c"))"# + ); Ok(()) } @@ -1841,7 +2355,10 @@ fn test_unparse_left_mark_join() -> Result<()> { let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); let sql = unparser.plan_to_sql(&plan)?; - assert_eq!("SELECT \"t1\".\"d\" FROM \"t1\" WHERE (EXISTS (SELECT 1 FROM \"t2\" AS \"__correlated_sq_1\" WHERE (\"t1\".\"c\" = \"__correlated_sq_1\".\"c\")) OR (\"t1\".\"d\" < 0))", sql.to_string()); + assert_snapshot!( + sql, + @r#"SELECT "t1"."d" FROM "t1" WHERE (EXISTS (SELECT 1 FROM "t2" AS "__correlated_sq_1" WHERE ("t1"."c" = "__correlated_sq_1"."c")) OR ("t1"."d" < 0))"# + ); Ok(()) } @@ -1876,7 +2393,10 @@ fn test_unparse_right_semi_join() -> Result<()> { .build()?; let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); let sql = unparser.plan_to_sql(&plan)?; - assert_eq!("SELECT \"t2\".\"c\", \"t2\".\"d\" FROM \"t2\" WHERE (\"t2\".\"c\" <= 1) AND EXISTS (SELECT 1 FROM \"t1\" WHERE (\"t1\".\"c\" = \"t2\".\"c\"))", sql.to_string()); + assert_snapshot!( + sql, + @r#"SELECT "t2"."c", "t2"."d" FROM "t2" WHERE ("t2"."c" <= 1) AND EXISTS (SELECT 1 FROM "t1" WHERE ("t1"."c" = "t2"."c"))"# + ); Ok(()) } @@ -1911,6 +2431,233 @@ fn test_unparse_right_anti_join() -> Result<()> { .build()?; let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); let sql = unparser.plan_to_sql(&plan)?; - assert_eq!("SELECT \"t2\".\"c\", \"t2\".\"d\" FROM \"t2\" WHERE (\"t2\".\"c\" <= 1) AND NOT EXISTS (SELECT 1 FROM \"t1\" WHERE (\"t1\".\"c\" = \"t2\".\"c\"))", sql.to_string()); + assert_snapshot!( + sql, + @r#"SELECT "t2"."c", "t2"."d" FROM "t2" WHERE ("t2"."c" <= 1) AND NOT EXISTS (SELECT 1 FROM "t1" WHERE ("t1"."c" = "t2"."c"))"# + ); + Ok(()) +} + +#[test] +fn test_unparse_cross_join_with_table_scan_projection() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("k", DataType::Int32, false), + Field::new("v", DataType::Int32, false), + ]); + // Cross Join: + // SubqueryAlias: t1 + // TableScan: test projection=[v] + // SubqueryAlias: t2 + // TableScan: test projection=[v] + let table_scan1 = table_scan(Some("test"), &schema, Some(vec![1]))?.build()?; + let table_scan2 = table_scan(Some("test"), &schema, Some(vec![1]))?.build()?; + let plan = LogicalPlanBuilder::from(subquery_alias(table_scan1, "t1")?) + .cross_join(subquery_alias(table_scan2, "t2")?)? + .build()?; + let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); + let sql = unparser.plan_to_sql(&plan)?; + assert_snapshot!( + sql, + @r#"SELECT "t1"."v", "t2"."v" FROM "test" AS "t1" CROSS JOIN "test" AS "t2""# + ); + Ok(()) +} + +#[test] +fn test_unparse_inner_join_with_table_scan_projection() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("k", DataType::Int32, false), + Field::new("v", DataType::Int32, false), + ]); + // Inner Join: + // SubqueryAlias: t1 + // TableScan: test projection=[v] + // SubqueryAlias: t2 + // TableScan: test projection=[v] + let table_scan1 = table_scan(Some("test"), &schema, Some(vec![1]))?.build()?; + let table_scan2 = table_scan(Some("test"), &schema, Some(vec![1]))?.build()?; + let plan = LogicalPlanBuilder::from(subquery_alias(table_scan1, "t1")?) + .join_on( + subquery_alias(table_scan2, "t2")?, + datafusion_expr::JoinType::Inner, + vec![col("t1.v").eq(col("t2.v"))], + )? + .build()?; + let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); + let sql = unparser.plan_to_sql(&plan)?; + assert_snapshot!( + sql, + @r#"SELECT "t1"."v", "t2"."v" FROM "test" AS "t1" INNER JOIN "test" AS "t2" ON ("t1"."v" = "t2"."v")"# + ); + Ok(()) +} + +#[test] +fn test_unparse_left_semi_join_with_table_scan_projection() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("k", DataType::Int32, false), + Field::new("v", DataType::Int32, false), + ]); + // LeftSemi Join: + // SubqueryAlias: t1 + // TableScan: test projection=[v] + // SubqueryAlias: t2 + // TableScan: test projection=[v] + let table_scan1 = table_scan(Some("test"), &schema, Some(vec![1]))?.build()?; + let table_scan2 = table_scan(Some("test"), &schema, Some(vec![1]))?.build()?; + let plan = LogicalPlanBuilder::from(subquery_alias(table_scan1, "t1")?) + .join_on( + subquery_alias(table_scan2, "t2")?, + datafusion_expr::JoinType::LeftSemi, + vec![col("t1.v").eq(col("t2.v"))], + )? + .build()?; + let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); + let sql = unparser.plan_to_sql(&plan)?; + assert_snapshot!( + sql, + @r#"SELECT "t1"."v" FROM "test" AS "t1" WHERE EXISTS (SELECT 1 FROM "test" AS "t2" WHERE ("t1"."v" = "t2"."v"))"# + ); Ok(()) } + +#[test] +fn test_like_filter() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT first_name FROM person WHERE first_name LIKE '%John%'"#, + ); + assert_snapshot!( + statement, + @"SELECT person.first_name FROM person WHERE person.first_name LIKE '%John%'" + ); +} + +#[test] +fn test_ilike_filter() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT first_name FROM person WHERE first_name ILIKE '%john%'"#, + ); + assert_snapshot!( + statement, + @"SELECT person.first_name FROM person WHERE person.first_name ILIKE '%john%'" + ); +} + +#[test] +fn test_not_like_filter() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT first_name FROM person WHERE first_name NOT LIKE 'A%'"#, + ); + assert_snapshot!( + statement, + @"SELECT person.first_name FROM person WHERE person.first_name NOT LIKE 'A%'" + ); +} + +#[test] +fn test_not_ilike_filter() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT first_name FROM person WHERE first_name NOT ILIKE 'a%'"#, + ); + assert_snapshot!( + statement, + @"SELECT person.first_name FROM person WHERE person.first_name NOT ILIKE 'a%'" + ); +} + +#[test] +fn test_like_filter_with_escape() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT first_name FROM person WHERE first_name LIKE 'A!_%' ESCAPE '!'"#, + ); + assert_snapshot!( + statement, + @"SELECT person.first_name FROM person WHERE person.first_name LIKE 'A!_%' ESCAPE '!'" + ); +} + +#[test] +fn test_not_like_filter_with_escape() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT first_name FROM person WHERE first_name NOT LIKE 'A!_%' ESCAPE '!'"#, + ); + assert_snapshot!( + statement, + @"SELECT person.first_name FROM person WHERE person.first_name NOT LIKE 'A!_%' ESCAPE '!'" + ); +} + +#[test] +fn test_not_ilike_filter_with_escape() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT first_name FROM person WHERE first_name NOT ILIKE 'A!_%' ESCAPE '!'"#, + ); + assert_snapshot!( + statement, + @"SELECT person.first_name FROM person WHERE person.first_name NOT ILIKE 'A!_%' ESCAPE '!'" + ); +} + +#[test] +fn test_struct_expr() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"WITH test AS (SELECT STRUCT(STRUCT('Product Name' as name) as product) AS metadata) SELECT metadata.product FROM test WHERE metadata.product.name = 'Product Name'"#, + ); + assert_snapshot!( + statement, + @r#"SELECT test."metadata".product FROM (SELECT {product: {"name": 'Product Name'}} AS "metadata") AS test WHERE (test."metadata".product."name" = 'Product Name')"# + ); + + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"WITH test AS (SELECT STRUCT(STRUCT('Product Name' as name) as product) AS metadata) SELECT metadata.product FROM test WHERE metadata['product']['name'] = 'Product Name'"#, + ); + assert_snapshot!( + statement, + @r#"SELECT test."metadata".product FROM (SELECT {product: {"name": 'Product Name'}} AS "metadata") AS test WHERE (test."metadata".product."name" = 'Product Name')"# + ); +} + +#[test] +fn test_struct_expr2() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT STRUCT(STRUCT('Product Name' as name) as product)['product']['name'] = 'Product Name';"#, + ); + assert_snapshot!( + statement, + @r#"SELECT ({product: {"name": 'Product Name'}}.product."name" = 'Product Name')"# + ); +} + +#[test] +fn test_struct_expr3() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"WITH + test AS ( + SELECT + STRUCT ( + STRUCT ( + STRUCT ('Product Name' as name) as product + ) AS metadata + ) AS c1 + ) + SELECT + c1.metadata.product.name + FROM + test"#, + ); + assert_snapshot!( + statement, + @r#"SELECT test.c1."metadata".product."name" FROM (SELECT {"metadata": {product: {"name": 'Product Name'}}} AS c1) AS test"# + ); +} diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 10e5b3b1f1267..f66af28f436e6 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -16,27 +16,23 @@ // under the License. use std::any::Any; +use std::hash::Hash; #[cfg(test)] -use std::collections::HashMap; use std::sync::Arc; use std::vec; use arrow::datatypes::{TimeUnit::Nanosecond, *}; use common::MockContextProvider; -use datafusion_common::{ - assert_contains, DataFusionError, ParamValues, Result, ScalarValue, -}; +use datafusion_common::{assert_contains, DataFusionError, Result}; use datafusion_expr::{ - col, - logical_plan::{LogicalPlan, Prepare}, - test::function_stub::sum_udaf, - ColumnarValue, CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Statement, Volatility, + col, logical_plan::LogicalPlan, test::function_stub::sum_udaf, ColumnarValue, + CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_functions::{string, unicode}; use datafusion_sql::{ parser::DFParser, - planner::{ParserOptions, SqlToRel}, + planner::{NullOrdering, ParserOptions, SqlToRel}, }; use crate::common::{CustomExprPlanner, CustomTypePlanner, MockSessionState}; @@ -47,7 +43,7 @@ use datafusion_functions_aggregate::{ }; use datafusion_functions_aggregate::{average::avg_udaf, grouping::grouping_udaf}; use datafusion_functions_nested::make_array::make_array_udf; -use datafusion_functions_window::rank::rank_udwf; +use datafusion_functions_window::{rank::rank_udwf, row_number::row_number_udwf}; use insta::{allow_duplicates, assert_snapshot}; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; @@ -56,322 +52,510 @@ mod cases; mod common; #[test] -fn parse_decimals() { - let test_data = [ - ("1", "Int64(1)"), - ("001", "Int64(1)"), - ("0.1", "Decimal128(Some(1),1,1)"), - ("0.01", "Decimal128(Some(1),2,2)"), - ("1.0", "Decimal128(Some(10),2,1)"), - ("10.01", "Decimal128(Some(1001),4,2)"), - ( - "10000000000000000000.00", - "Decimal128(Some(1000000000000000000000),22,2)", - ), - ("18446744073709551615", "UInt64(18446744073709551615)"), - ( - "18446744073709551616", - "Decimal128(Some(18446744073709551616),20,0)", - ), - ]; - for (a, b) in test_data { - let sql = format!("SELECT {a}"); - let expected = format!("Projection: {b}\n EmptyRelation"); - quick_test_with_options( - &sql, - &expected, - ParserOptions { - parse_float_as_decimal: true, - enable_ident_normalization: false, - support_varchar_with_length: false, - map_varchar_to_utf8view: false, - enable_options_value_normalization: false, - collect_spans: false, - }, - ); - } +fn parse_decimals_1() { + let sql = "SELECT 1"; + let options = parse_decimals_parser_options(); + let plan = logical_plan_with_options(sql, options).unwrap(); + assert_snapshot!( + plan, + @r" + Projection: Int64(1) + EmptyRelation: rows=1 + " + ); } #[test] -fn parse_ident_normalization() { - let test_data = [ - ( - "SELECT CHARACTER_LENGTH('str')", - "Ok(Projection: character_length(Utf8(\"str\"))\n EmptyRelation)", - false, - ), - ( - "SELECT CONCAT('Hello', 'World')", - "Ok(Projection: concat(Utf8(\"Hello\"), Utf8(\"World\"))\n EmptyRelation)", - false, - ), - ( - "SELECT age FROM person", - "Ok(Projection: person.age\n TableScan: person)", - true, - ), - ( - "SELECT AGE FROM PERSON", - "Ok(Projection: person.age\n TableScan: person)", - true, - ), - ( - "SELECT AGE FROM PERSON", - "Error during planning: No table named: PERSON found", - false, - ), - ( - "SELECT Id FROM UPPERCASE_test", - "Ok(Projection: UPPERCASE_test.Id\ - \n TableScan: UPPERCASE_test)", - false, - ), - ( - "SELECT \"Id\", lower FROM \"UPPERCASE_test\"", - "Ok(Projection: UPPERCASE_test.Id, UPPERCASE_test.lower\ - \n TableScan: UPPERCASE_test)", - true, - ), - ]; +fn parse_decimals_2() { + let sql = "SELECT 001"; + let options = parse_decimals_parser_options(); + let plan = logical_plan_with_options(sql, options).unwrap(); + assert_snapshot!( + plan, + @r" + Projection: Int64(1) + EmptyRelation: rows=1 + " + ); +} - for (sql, expected, enable_ident_normalization) in test_data { - let plan = logical_plan_with_options( - sql, - ParserOptions { - parse_float_as_decimal: false, - enable_ident_normalization, - support_varchar_with_length: false, - map_varchar_to_utf8view: false, - enable_options_value_normalization: false, - collect_spans: false, - }, - ); - if plan.is_ok() { - let plan = plan.unwrap(); - assert_eq!(expected, format!("Ok({plan})")); - } else { - assert_eq!(expected, plan.unwrap_err().strip_backtrace()); - } - } +#[test] +fn parse_decimals_3() { + let sql = "SELECT 0.1"; + let options = parse_decimals_parser_options(); + let plan = logical_plan_with_options(sql, options).unwrap(); + assert_snapshot!( + plan, + @r" + Projection: Decimal128(Some(1),1,1) + EmptyRelation: rows=1 + " + ); +} + +#[test] +fn parse_decimals_4() { + let sql = "SELECT 0.01"; + let options = parse_decimals_parser_options(); + let plan = logical_plan_with_options(sql, options).unwrap(); + assert_snapshot!( + plan, + @r" + Projection: Decimal128(Some(1),2,2) + EmptyRelation: rows=1 + " + ); +} + +#[test] +fn parse_decimals_5() { + let sql = "SELECT 1.0"; + let options = parse_decimals_parser_options(); + let plan = logical_plan_with_options(sql, options).unwrap(); + assert_snapshot!( + plan, + @r" + Projection: Decimal128(Some(10),2,1) + EmptyRelation: rows=1 + " + ); +} + +#[test] +fn parse_decimals_6() { + let sql = "SELECT 10.01"; + let options = parse_decimals_parser_options(); + let plan = logical_plan_with_options(sql, options).unwrap(); + assert_snapshot!( + plan, + @r" + Projection: Decimal128(Some(1001),4,2) + EmptyRelation: rows=1 + " + ); +} + +#[test] +fn parse_decimals_7() { + let sql = "SELECT 10000000000000000000.00"; + let options = parse_decimals_parser_options(); + let plan = logical_plan_with_options(sql, options).unwrap(); + assert_snapshot!( + plan, + @r" + Projection: Decimal128(Some(1000000000000000000000),22,2) + EmptyRelation: rows=1 + " + ); +} + +#[test] +fn parse_decimals_8() { + let sql = "SELECT 18446744073709551615"; + let options = parse_decimals_parser_options(); + let plan = logical_plan_with_options(sql, options).unwrap(); + assert_snapshot!( + plan, + @r" + Projection: UInt64(18446744073709551615) + EmptyRelation: rows=1 + " + ); +} + +#[test] +fn parse_decimals_9() { + let sql = "SELECT 18446744073709551616"; + let options = parse_decimals_parser_options(); + let plan = logical_plan_with_options(sql, options).unwrap(); + assert_snapshot!( + plan, + @r" + Projection: Decimal128(Some(18446744073709551616),20,0) + EmptyRelation: rows=1 + " + ); +} + +#[test] +fn parse_ident_normalization_1() { + let sql = "SELECT CHARACTER_LENGTH('str')"; + let parser_option = ident_normalization_parser_options_no_ident_normalization(); + let plan = logical_plan_with_options(sql, parser_option).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: character_length(Utf8("str")) + EmptyRelation: rows=1 + "# + ); +} + +#[test] +fn parse_ident_normalization_2() { + let sql = "SELECT CONCAT('Hello', 'World')"; + let parser_option = ident_normalization_parser_options_no_ident_normalization(); + let plan = logical_plan_with_options(sql, parser_option).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: concat(Utf8("Hello"), Utf8("World")) + EmptyRelation: rows=1 + "# + ); +} + +#[test] +fn parse_ident_normalization_3() { + let sql = "SELECT age FROM person"; + let parser_option = ident_normalization_parser_options_ident_normalization(); + let plan = logical_plan_with_options(sql, parser_option).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.age + TableScan: person + "# + ); +} + +#[test] +fn parse_ident_normalization_4() { + let sql = "SELECT AGE FROM PERSON"; + let parser_option = ident_normalization_parser_options_ident_normalization(); + let plan = logical_plan_with_options(sql, parser_option).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.age + TableScan: person + "# + ); +} + +#[test] +fn parse_ident_normalization_5() { + let sql = "SELECT AGE FROM PERSON"; + let parser_option = ident_normalization_parser_options_no_ident_normalization(); + let plan = logical_plan_with_options(sql, parser_option) + .unwrap_err() + .strip_backtrace(); + assert_snapshot!( + plan, + @r#" + Error during planning: No table named: PERSON found + "# + ); +} + +#[test] +fn parse_ident_normalization_6() { + let sql = "SELECT Id FROM UPPERCASE_test"; + let parser_option = ident_normalization_parser_options_no_ident_normalization(); + let plan = logical_plan_with_options(sql, parser_option).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: UPPERCASE_test.Id + TableScan: UPPERCASE_test + "# + ); +} + +#[test] +fn parse_ident_normalization_7() { + let sql = r#"SELECT "Id", lower FROM "UPPERCASE_test""#; + let parser_option = ident_normalization_parser_options_ident_normalization(); + let plan = logical_plan_with_options(sql, parser_option).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: UPPERCASE_test.Id, UPPERCASE_test.lower + TableScan: UPPERCASE_test + "# + ); } #[test] fn select_no_relation() { - quick_test( - "SELECT 1", - "Projection: Int64(1)\ - \n EmptyRelation", + let plan = logical_plan("SELECT 1").unwrap(); + assert_snapshot!( + plan, + @r" + Projection: Int64(1) + EmptyRelation: rows=1 + " ); } #[test] fn test_real_f32() { - quick_test( - "SELECT CAST(1.1 AS REAL)", - "Projection: CAST(Float64(1.1) AS Float32)\ - \n EmptyRelation", + let plan = logical_plan("SELECT CAST(1.1 AS REAL)").unwrap(); + assert_snapshot!( + plan, + @r" + Projection: CAST(Float64(1.1) AS Float32) + EmptyRelation: rows=1 + " ); } #[test] fn test_int_decimal_default() { - quick_test( - "SELECT CAST(10 AS DECIMAL)", - "Projection: CAST(Int64(10) AS Decimal128(38, 10))\ - \n EmptyRelation", + let plan = logical_plan("SELECT CAST(10 AS DECIMAL)").unwrap(); + assert_snapshot!( + plan, + @r" + Projection: CAST(Int64(10) AS Decimal128(38, 10)) + EmptyRelation: rows=1 + " ); } #[test] fn test_int_decimal_no_scale() { - quick_test( - "SELECT CAST(10 AS DECIMAL(5))", - "Projection: CAST(Int64(10) AS Decimal128(5, 0))\ - \n EmptyRelation", + let plan = logical_plan("SELECT CAST(10 AS DECIMAL(5))").unwrap(); + assert_snapshot!( + plan, + @r" + Projection: CAST(Int64(10) AS Decimal128(5, 0)) + EmptyRelation: rows=1 + " ); } #[test] fn test_tinyint() { - quick_test( - "SELECT CAST(6 AS TINYINT)", - "Projection: CAST(Int64(6) AS Int8)\ - \n EmptyRelation", + let plan = logical_plan("SELECT CAST(6 AS TINYINT)").unwrap(); + assert_snapshot!( + plan, + @r" + Projection: CAST(Int64(6) AS Int8) + EmptyRelation: rows=1 + " ); } #[test] fn cast_from_subquery() { - quick_test( - "SELECT CAST (a AS FLOAT) FROM (SELECT 1 AS a)", - "Projection: CAST(a AS Float32)\ - \n Projection: Int64(1) AS a\ - \n EmptyRelation", + let plan = logical_plan("SELECT CAST (a AS FLOAT) FROM (SELECT 1 AS a)").unwrap(); + assert_snapshot!( + plan, + @r" + Projection: CAST(a AS Float32) + Projection: Int64(1) AS a + EmptyRelation: rows=1 + " ); } #[test] fn try_cast_from_aggregation() { - quick_test( - "SELECT TRY_CAST(sum(age) AS FLOAT) FROM person", - "Projection: TRY_CAST(sum(person.age) AS Float32)\ - \n Aggregate: groupBy=[[]], aggr=[[sum(person.age)]]\ - \n TableScan: person", + let plan = logical_plan("SELECT TRY_CAST(sum(age) AS FLOAT) FROM person").unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: TRY_CAST(sum(person.age) AS Float32) + Aggregate: groupBy=[[]], aggr=[[sum(person.age)]] + TableScan: person + "# ); } #[test] fn cast_to_invalid_decimal_type_precision_0() { // precision == 0 - { - let sql = "SELECT CAST(10 AS DECIMAL(0))"; - let err = logical_plan(sql).expect_err("query should have failed"); - - assert_snapshot!( - err.strip_backtrace(), - @r"Error during planning: Decimal(precision = 0, scale = 0) should satisfy `0 < precision <= 76`, and `scale <= precision`." - ); - } + let sql = "SELECT CAST(10 AS DECIMAL(0))"; + let err = logical_plan(sql).expect_err("query should have failed"); + + assert_snapshot!( + err.strip_backtrace(), + @r"Error during planning: Decimal(precision = 0, scale = 0) should satisfy `0 < precision <= 76`, and `scale <= precision`." + ); } #[test] fn cast_to_invalid_decimal_type_precision_gt_38() { // precision > 38 - { - let sql = "SELECT CAST(10 AS DECIMAL(39))"; - let plan = "Projection: CAST(Int64(10) AS Decimal256(39, 0))\n EmptyRelation"; - quick_test(sql, plan); - } + let sql = "SELECT CAST(10 AS DECIMAL(39))"; + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r" + Projection: CAST(Int64(10) AS Decimal256(39, 0)) + EmptyRelation: rows=1 + " + ); } #[test] fn cast_to_invalid_decimal_type_precision_gt_76() { // precision > 76 - { - let sql = "SELECT CAST(10 AS DECIMAL(79))"; - let err = logical_plan(sql).expect_err("query should have failed"); - - assert_snapshot!( - err.strip_backtrace(), - @r"Error during planning: Decimal(precision = 79, scale = 0) should satisfy `0 < precision <= 76`, and `scale <= precision`." - ); - } + let sql = "SELECT CAST(10 AS DECIMAL(79))"; + let err = logical_plan(sql).expect_err("query should have failed"); + + assert_snapshot!( + err.strip_backtrace(), + @r"Error during planning: Decimal(precision = 79, scale = 0) should satisfy `0 < precision <= 76`, and `scale <= precision`." + ); } #[test] fn cast_to_invalid_decimal_type_precision_lt_scale() { // precision < scale - { - let sql = "SELECT CAST(10 AS DECIMAL(5, 10))"; - let err = logical_plan(sql).expect_err("query should have failed"); - - assert_snapshot!( - err.strip_backtrace(), - @r"Error during planning: Decimal(precision = 5, scale = 10) should satisfy `0 < precision <= 76`, and `scale <= precision`." - ); - } + let sql = "SELECT CAST(10 AS DECIMAL(5, 10))"; + let err = logical_plan(sql).expect_err("query should have failed"); + + assert_snapshot!( + err.strip_backtrace(), + @r"Error during planning: Decimal(precision = 5, scale = 10) should satisfy `0 < precision <= 76`, and `scale <= precision`." + ); } #[test] fn plan_create_table_with_pk() { let sql = "create table person (id int, name string, primary key(id))"; - let plan = r#" -CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([0])] - EmptyRelation + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([0])] + EmptyRelation: rows=0 "# - .trim(); - quick_test(sql, plan); + ); let sql = "create table person (id int primary key, name string)"; - let plan = r#" -CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([0])] - EmptyRelation + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([0])] + EmptyRelation: rows=0 "# - .trim(); - quick_test(sql, plan); + ); let sql = "create table person (id int, name string unique not null, primary key(id))"; - let plan = r#" -CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([0]), Unique([1])] - EmptyRelation + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([0]), Unique([1])] + EmptyRelation: rows=0 "# - .trim(); - quick_test(sql, plan); + ); let sql = "create table person (id int, name varchar, primary key(name, id));"; - let plan = r#" -CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([1, 0])] - EmptyRelation + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([1, 0])] + EmptyRelation: rows=0 "# - .trim(); - quick_test(sql, plan); + ); } #[test] fn plan_create_table_with_multi_pk() { let sql = "create table person (id int, name string primary key, primary key(id))"; - let plan = r#" -CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([0]), PrimaryKey([1])] - EmptyRelation + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + CreateMemoryTable: Bare { table: "person" } constraints=[PrimaryKey([0]), PrimaryKey([1])] + EmptyRelation: rows=0 "# - .trim(); - quick_test(sql, plan); + ); } #[test] fn plan_create_table_with_unique() { let sql = "create table person (id int unique, name string)"; - let plan = "CreateMemoryTable: Bare { table: \"person\" } constraints=[Unique([0])]\n EmptyRelation"; - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + CreateMemoryTable: Bare { table: "person" } constraints=[Unique([0])] + EmptyRelation: rows=0 + "# + ); } #[test] fn plan_create_table_no_pk() { let sql = "create table person (id int, name string)"; - let plan = r#" -CreateMemoryTable: Bare { table: "person" } - EmptyRelation + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + CreateMemoryTable: Bare { table: "person" } + EmptyRelation: rows=0 "# - .trim(); - quick_test(sql, plan); + ); } #[test] fn plan_create_table_check_constraint() { let sql = "create table person (id int, name string, unique(id))"; - let plan = "CreateMemoryTable: Bare { table: \"person\" } constraints=[Unique([0])]\n EmptyRelation"; - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + CreateMemoryTable: Bare { table: "person" } constraints=[Unique([0])] + EmptyRelation: rows=0 + "# + ); } #[test] fn plan_start_transaction() { let sql = "start transaction"; - let plan = "TransactionStart: ReadWrite Serializable"; - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + TransactionStart: ReadWrite Serializable + "# + ); } #[test] fn plan_start_transaction_isolation() { let sql = "start transaction isolation level read committed"; - let plan = "TransactionStart: ReadWrite ReadCommitted"; - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + TransactionStart: ReadWrite ReadCommitted + "# + ); } #[test] fn plan_start_transaction_read_only() { let sql = "start transaction read only"; - let plan = "TransactionStart: ReadOnly Serializable"; - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + TransactionStart: ReadOnly Serializable + "# + ); } #[test] fn plan_start_transaction_fully_qualified() { let sql = "start transaction isolation level read committed read only"; - let plan = "TransactionStart: ReadOnly ReadCommitted"; - quick_test(sql, plan); -} - + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + TransactionStart: ReadOnly ReadCommitted + "# + ); +} + #[test] fn plan_start_transaction_overly_qualified() { let sql = r#"start transaction @@ -379,95 +563,131 @@ isolation level read committed read only isolation level repeatable read "#; - let plan = "TransactionStart: ReadOnly RepeatableRead"; - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + TransactionStart: ReadOnly RepeatableRead + "# + ); } #[test] fn plan_commit_transaction() { let sql = "commit transaction"; - let plan = "TransactionEnd: Commit chain:=false"; - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + TransactionEnd: Commit chain:=false + "# + ); } #[test] fn plan_commit_transaction_chained() { let sql = "commit transaction and chain"; - let plan = "TransactionEnd: Commit chain:=true"; - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + TransactionEnd: Commit chain:=true + "# + ); } #[test] fn plan_rollback_transaction() { let sql = "rollback transaction"; - let plan = "TransactionEnd: Rollback chain:=false"; - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + TransactionEnd: Rollback chain:=false + "# + ); } #[test] fn plan_rollback_transaction_chained() { let sql = "rollback transaction and chain"; - let plan = "TransactionEnd: Rollback chain:=true"; - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + TransactionEnd: Rollback chain:=true + "# + ); } #[test] fn plan_copy_to() { let sql = "COPY test_decimal to 'output.csv' STORED AS CSV"; - let plan = r#" -CopyTo: format=csv output_url=output.csv options: () - TableScan: test_decimal - "# - .trim(); - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + CopyTo: format=csv output_url=output.csv options: () + TableScan: test_decimal + "# + ); } #[test] fn plan_explain_copy_to() { let sql = "EXPLAIN COPY test_decimal to 'output.csv'"; - let plan = r#" -Explain - CopyTo: format=csv output_url=output.csv options: () - TableScan: test_decimal - "# - .trim(); - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Explain + CopyTo: format=csv output_url=output.csv options: () + TableScan: test_decimal + "# + ); } #[test] fn plan_explain_copy_to_format() { let sql = "EXPLAIN COPY test_decimal to 'output.tbl' STORED AS CSV"; - let plan = r#" -Explain - CopyTo: format=csv output_url=output.tbl options: () - TableScan: test_decimal - "# - .trim(); - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Explain + CopyTo: format=csv output_url=output.tbl options: () + TableScan: test_decimal + "# + ); } #[test] fn plan_insert() { let sql = "insert into person (id, first_name, last_name) values (1, 'Alan', 'Turing')"; - let plan = "Dml: op=[Insert Into] table=[person]\ - \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ - CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ - CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ - \n Values: (CAST(Int64(1) AS UInt32), Utf8(\"Alan\"), Utf8(\"Turing\"))"; - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Dml: op=[Insert Into] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 + Values: (CAST(Int64(1) AS UInt32), Utf8("Alan"), Utf8("Turing")) + "# + ); } #[test] fn plan_insert_no_target_columns() { let sql = "INSERT INTO test_decimal VALUES (1, 2), (3, 4)"; - let plan = r#" -Dml: op=[Insert Into] table=[test_decimal] - Projection: column1 AS id, column2 AS price - Values: (CAST(Int64(1) AS Int32), CAST(Int64(2) AS Decimal128(10, 2))), (CAST(Int64(3) AS Int32), CAST(Int64(4) AS Decimal128(10, 2))) - "# - .trim(); - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Dml: op=[Insert Into] table=[test_decimal] + Projection: column1 AS id, column2 AS price + Values: (CAST(Int64(1) AS Int32), CAST(Int64(2) AS Decimal128(10, 2))), (CAST(Int64(3) AS Int32), CAST(Int64(4) AS Decimal128(10, 2))) + "# + ); } #[rstest] @@ -505,19 +725,21 @@ fn test_insert_schema_errors(#[case] sql: &str, #[case] error: &str) { #[test] fn plan_update() { let sql = "update person set last_name='Kay' where id=1"; - let plan = r#" -Dml: op=[Update] table=[person] - Projection: person.id AS id, person.first_name AS first_name, Utf8("Kay") AS last_name, person.age AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 - Filter: person.id = Int64(1) - TableScan: person - "# - .trim(); - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Dml: op=[Update] table=[person] + Projection: person.id AS id, person.first_name AS first_name, Utf8("Kay") AS last_name, person.age AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 + Filter: person.id = Int64(1) + TableScan: person + "# + ); } #[rstest] -#[case::missing_assignement_target("UPDATE person SET doesnotexist = true")] -#[case::missing_assignement_expression("UPDATE person SET age = doesnotexist + 42")] +#[case::missing_assignment_target("UPDATE person SET doesnotexist = true")] +#[case::missing_assignment_expression("UPDATE person SET age = doesnotexist + 42")] #[case::missing_selection_expression( "UPDATE person SET age = 42 WHERE doesnotexist = true" )] @@ -530,26 +752,30 @@ fn update_column_does_not_exist(#[case] sql: &str) { #[test] fn plan_delete() { let sql = "delete from person where id=1"; - let plan = r#" -Dml: op=[Delete] table=[person] - Filter: id = Int64(1) - TableScan: person - "# - .trim(); - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Dml: op=[Delete] table=[person] + Filter: person.id = Int64(1) + TableScan: person + "# + ); } #[test] fn plan_delete_quoted_identifier_case_sensitive() { let sql = "DELETE FROM \"SomeCatalog\".\"SomeSchema\".\"UPPERCASE_test\" WHERE \"Id\" = 1"; - let plan = r#" -Dml: op=[Delete] table=[SomeCatalog.SomeSchema.UPPERCASE_test] - Filter: Id = Int64(1) - TableScan: SomeCatalog.SomeSchema.UPPERCASE_test - "# - .trim(); - quick_test(sql, plan); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Dml: op=[Delete] table=[SomeCatalog.SomeSchema.UPPERCASE_test] + Filter: SomeCatalog.SomeSchema.UPPERCASE_test.Id = Int64(1) + TableScan: SomeCatalog.SomeSchema.UPPERCASE_test + "# + ); } #[test] @@ -574,10 +800,13 @@ fn select_repeated_column() { #[test] fn select_scalar_func_with_literal_no_relation() { - quick_test( - "SELECT sqrt(9)", - "Projection: sqrt(Int64(9))\ - \n EmptyRelation", + let plan = logical_plan("SELECT sqrt(9)").unwrap(); + assert_snapshot!( + plan, + @r" + Projection: sqrt(Int64(9)) + EmptyRelation: rows=1 + " ); } @@ -585,10 +814,15 @@ fn select_scalar_func_with_literal_no_relation() { fn select_simple_filter() { let sql = "SELECT id, first_name, last_name \ FROM person WHERE state = 'CO'"; - let expected = "Projection: person.id, person.first_name, person.last_name\ - \n Filter: person.state = Utf8(\"CO\")\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.id, person.first_name, person.last_name + Filter: person.state = Utf8("CO") + TableScan: person + "# + ); } #[test] @@ -609,40 +843,58 @@ fn select_filter_cannot_use_alias() { fn select_neg_filter() { let sql = "SELECT id, first_name, last_name \ FROM person WHERE NOT state"; - let expected = "Projection: person.id, person.first_name, person.last_name\ - \n Filter: NOT person.state\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.id, person.first_name, person.last_name + Filter: NOT person.state + TableScan: person + "# + ); } #[test] fn select_compound_filter() { let sql = "SELECT id, first_name, last_name \ FROM person WHERE state = 'CO' AND age >= 21 AND age <= 65"; - let expected = "Projection: person.id, person.first_name, person.last_name\ - \n Filter: person.state = Utf8(\"CO\") AND person.age >= Int64(21) AND person.age <= Int64(65)\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.id, person.first_name, person.last_name + Filter: person.state = Utf8("CO") AND person.age >= Int64(21) AND person.age <= Int64(65) + TableScan: person + "# + ); } #[test] fn test_timestamp_filter() { let sql = "SELECT state FROM person WHERE birth_date < CAST (158412331400600000 as timestamp)"; - let expected = "Projection: person.state\ - \n Filter: person.birth_date < CAST(CAST(Int64(158412331400600000) AS Timestamp(Second, None)) AS Timestamp(Nanosecond, None))\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.state + Filter: person.birth_date < CAST(CAST(Int64(158412331400600000) AS Timestamp(Second, None)) AS Timestamp(Nanosecond, None)) + TableScan: person + "# + ); } #[test] fn test_date_filter() { let sql = "SELECT state FROM person WHERE birth_date < CAST ('2020-01-01' as date)"; - - let expected = "Projection: person.state\ - \n Filter: person.birth_date < CAST(Utf8(\"2020-01-01\") AS Date32)\ - \n TableScan: person"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.state + Filter: person.birth_date < CAST(Utf8("2020-01-01") AS Date32) + TableScan: person + "# + ); } #[test] @@ -655,35 +907,43 @@ fn select_all_boolean_operators() { AND age >= 21 \ AND age < 65 \ AND age <= 65"; - let expected = "Projection: person.age, person.first_name, person.last_name\ - \n Filter: person.age = Int64(21) \ - AND person.age != Int64(21) \ - AND person.age > Int64(21) \ - AND person.age >= Int64(21) \ - AND person.age < Int64(65) \ - AND person.age <= Int64(65)\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.age, person.first_name, person.last_name + Filter: person.age = Int64(21) AND person.age != Int64(21) AND person.age > Int64(21) AND person.age >= Int64(21) AND person.age < Int64(65) AND person.age <= Int64(65) + TableScan: person + "# + ); } #[test] fn select_between() { let sql = "SELECT state FROM person WHERE age BETWEEN 21 AND 65"; - let expected = "Projection: person.state\ - \n Filter: person.age BETWEEN Int64(21) AND Int64(65)\ - \n TableScan: person"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.state + Filter: person.age BETWEEN Int64(21) AND Int64(65) + TableScan: person + "# + ); } #[test] fn select_between_negated() { let sql = "SELECT state FROM person WHERE age NOT BETWEEN 21 AND 65"; - let expected = "Projection: person.state\ - \n Filter: person.age NOT BETWEEN Int64(21) AND Int64(65)\ - \n TableScan: person"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.state + Filter: person.age NOT BETWEEN Int64(21) AND Int64(65) + TableScan: person + "# + ); } #[test] @@ -696,13 +956,18 @@ fn select_nested() { FROM person ) AS a ) AS b"; - let expected = "Projection: b.fn2, b.last_name\ - \n SubqueryAlias: b\ - \n Projection: a.fn1 AS fn2, a.last_name, a.birth_date\ - \n SubqueryAlias: a\ - \n Projection: person.first_name AS fn1, person.last_name, person.birth_date, person.age\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: b.fn2, b.last_name + SubqueryAlias: b + Projection: a.fn1 AS fn2, a.last_name, a.birth_date + SubqueryAlias: a + Projection: person.first_name AS fn1, person.last_name, person.birth_date, person.age + TableScan: person + "# + ); } #[test] @@ -714,27 +979,34 @@ fn select_nested_with_filters() { WHERE age > 20 ) AS a WHERE fn1 = 'X' AND age < 30"; - - let expected = "Projection: a.fn1, a.age\ - \n Filter: a.fn1 = Utf8(\"X\") AND a.age < Int64(30)\ - \n SubqueryAlias: a\ - \n Projection: person.first_name AS fn1, person.age\ - \n Filter: person.age > Int64(20)\ - \n TableScan: person"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: a.fn1, a.age + Filter: a.fn1 = Utf8("X") AND a.age < Int64(30) + SubqueryAlias: a + Projection: person.first_name AS fn1, person.age + Filter: person.age > Int64(20) + TableScan: person + "# + ); } #[test] fn table_with_column_alias() { let sql = "SELECT a, b, c FROM lineitem l (a, b, c)"; - let expected = "Projection: l.a, l.b, l.c\ - \n SubqueryAlias: l\ - \n Projection: lineitem.l_item_id AS a, lineitem.l_description AS b, lineitem.price AS c\ - \n TableScan: lineitem"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: l.a, l.b, l.c + SubqueryAlias: l + Projection: lineitem.l_item_id AS a, lineitem.l_description AS b, lineitem.price AS c + TableScan: lineitem + "# + ); } #[test] @@ -764,37 +1036,52 @@ fn select_with_ambiguous_column() { fn join_with_ambiguous_column() { // This is legal. let sql = "SELECT id FROM person a join person b using(id)"; - let expected = "Projection: a.id\ - \n Inner Join: Using a.id = b.id\ - \n SubqueryAlias: a\ - \n TableScan: person\ - \n SubqueryAlias: b\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: a.id + Inner Join: Using a.id = b.id + SubqueryAlias: a + TableScan: person + SubqueryAlias: b + TableScan: person + "# + ); } #[test] fn natural_left_join() { let sql = "SELECT l_item_id FROM lineitem a NATURAL LEFT JOIN lineitem b"; - let expected = "Projection: a.l_item_id\ - \n Left Join: Using a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.price = b.price\ - \n SubqueryAlias: a\ - \n TableScan: lineitem\ - \n SubqueryAlias: b\ - \n TableScan: lineitem"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: a.l_item_id + Left Join: Using a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.price = b.price + SubqueryAlias: a + TableScan: lineitem + SubqueryAlias: b + TableScan: lineitem + "# + ); } #[test] fn natural_right_join() { let sql = "SELECT l_item_id FROM lineitem a NATURAL RIGHT JOIN lineitem b"; - let expected = "Projection: a.l_item_id\ - \n Right Join: Using a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.price = b.price\ - \n SubqueryAlias: a\ - \n TableScan: lineitem\ - \n SubqueryAlias: b\ - \n TableScan: lineitem"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: a.l_item_id + Right Join: Using a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.price = b.price + SubqueryAlias: a + TableScan: lineitem + SubqueryAlias: b + TableScan: lineitem + "# + ); } #[test] @@ -874,11 +1161,16 @@ fn select_aggregate_with_having_that_reuses_aggregate() { let sql = "SELECT MAX(age) FROM person HAVING MAX(age) < 30"; - let expected = "Projection: max(person.age)\ - \n Filter: max(person.age) < Int64(30)\ - \n Aggregate: groupBy=[[]], aggr=[[max(person.age)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: max(person.age) + Filter: max(person.age) < Int64(30) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + TableScan: person + "# + ); } #[test] @@ -886,11 +1178,16 @@ fn select_aggregate_with_having_with_aggregate_not_in_select() { let sql = "SELECT max(age) FROM person HAVING max(first_name) > 'M'"; - let expected = "Projection: max(person.age)\ - \n Filter: max(person.first_name) > Utf8(\"M\")\ - \n Aggregate: groupBy=[[]], aggr=[[max(person.age), max(person.first_name)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: max(person.age) + Filter: max(person.first_name) > Utf8("M") + Aggregate: groupBy=[[]], aggr=[[max(person.age), max(person.first_name)]] + TableScan: person + "# + ); } #[test] @@ -914,11 +1211,16 @@ fn select_aggregate_aliased_with_having_referencing_aggregate_by_its_alias() { FROM person HAVING max_age < 30"; // FIXME: add test for having in execution - let expected = "Projection: max(person.age) AS max_age\ - \n Filter: max(person.age) < Int64(30)\ - \n Aggregate: groupBy=[[]], aggr=[[max(person.age)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: max(person.age) AS max_age + Filter: max(person.age) < Int64(30) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + TableScan: person + "# + ); } #[test] @@ -926,11 +1228,16 @@ fn select_aggregate_aliased_with_having_that_reuses_aggregate_but_not_by_its_ali let sql = "SELECT max(age) as max_age FROM person HAVING max(age) < 30"; - let expected = "Projection: max(person.age) AS max_age\ - \n Filter: max(person.age) < Int64(30)\ - \n Aggregate: groupBy=[[]], aggr=[[max(person.age)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: max(person.age) AS max_age + Filter: max(person.age) < Int64(30) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + TableScan: person + "# + ); } #[test] @@ -939,11 +1246,16 @@ fn select_aggregate_with_group_by_with_having() { FROM person GROUP BY first_name HAVING first_name = 'M'"; - let expected = "Projection: person.first_name, max(person.age)\ - \n Filter: person.first_name = Utf8(\"M\")\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.first_name, max(person.age) + Filter: person.first_name = Utf8("M") + Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]] + TableScan: person + "# + ); } #[test] @@ -953,12 +1265,17 @@ fn select_aggregate_with_group_by_with_having_and_where() { WHERE id > 5 GROUP BY first_name HAVING MAX(age) < 100"; - let expected = "Projection: person.first_name, max(person.age)\ - \n Filter: max(person.age) < Int64(100)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ - \n Filter: person.id > Int64(5)\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.first_name, max(person.age) + Filter: max(person.age) < Int64(100) + Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]] + Filter: person.id > Int64(5) + TableScan: person + "# + ); } #[test] @@ -968,12 +1285,17 @@ fn select_aggregate_with_group_by_with_having_and_where_filtering_on_aggregate_c WHERE id > 5 AND age > 18 GROUP BY first_name HAVING MAX(age) < 100"; - let expected = "Projection: person.first_name, max(person.age)\ - \n Filter: max(person.age) < Int64(100)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ - \n Filter: person.id > Int64(5) AND person.age > Int64(18)\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.first_name, max(person.age) + Filter: max(person.age) < Int64(100) + Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]] + Filter: person.id > Int64(5) AND person.age > Int64(18) + TableScan: person + "# + ); } #[test] @@ -982,11 +1304,16 @@ fn select_aggregate_with_group_by_with_having_using_column_by_alias() { FROM person GROUP BY first_name HAVING MAX(age) > 2 AND fn = 'M'"; - let expected = "Projection: person.first_name AS fn, max(person.age)\ - \n Filter: max(person.age) > Int64(2) AND person.first_name = Utf8(\"M\")\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.first_name AS fn, max(person.age) + Filter: max(person.age) > Int64(2) AND person.first_name = Utf8("M") + Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]] + TableScan: person + "# + ); } #[test] @@ -996,24 +1323,34 @@ fn select_aggregate_with_group_by_with_having_using_columns_with_and_without_the FROM person GROUP BY first_name HAVING MAX(age) > 2 AND max_age < 5 AND first_name = 'M' AND fn = 'N'"; - let expected = "Projection: person.first_name AS fn, max(person.age) AS max_age\ - \n Filter: max(person.age) > Int64(2) AND max(person.age) < Int64(5) AND person.first_name = Utf8(\"M\") AND person.first_name = Utf8(\"N\")\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ - \n TableScan: person"; - quick_test(sql, expected); -} - -#[test] -fn select_aggregate_with_group_by_with_having_that_reuses_aggregate() { + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.first_name AS fn, max(person.age) AS max_age + Filter: max(person.age) > Int64(2) AND max(person.age) < Int64(5) AND person.first_name = Utf8("M") AND person.first_name = Utf8("N") + Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]] + TableScan: person + "# + ); +} + +#[test] +fn select_aggregate_with_group_by_with_having_that_reuses_aggregate() { let sql = "SELECT first_name, MAX(age) FROM person GROUP BY first_name HAVING MAX(age) > 100"; - let expected = "Projection: person.first_name, max(person.age)\ - \n Filter: max(person.age) > Int64(100)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.first_name, max(person.age) + Filter: max(person.age) > Int64(100) + Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]] + TableScan: person + "# + ); } #[test] @@ -1038,11 +1375,16 @@ fn select_aggregate_with_group_by_with_having_that_reuses_aggregate_multiple_tim FROM person GROUP BY first_name HAVING MAX(age) > 100 AND MAX(age) < 200"; - let expected = "Projection: person.first_name, max(person.age)\ - \n Filter: max(person.age) > Int64(100) AND max(person.age) < Int64(200)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.first_name, max(person.age) + Filter: max(person.age) > Int64(100) AND max(person.age) < Int64(200) + Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]] + TableScan: person + "# + ); } #[test] @@ -1051,11 +1393,16 @@ fn select_aggregate_with_group_by_with_having_using_aggregate_not_in_select() { FROM person GROUP BY first_name HAVING MAX(age) > 100 AND MIN(id) < 50"; - let expected = "Projection: person.first_name, max(person.age)\ - \n Filter: max(person.age) > Int64(100) AND min(person.id) < Int64(50)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age), min(person.id)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.first_name, max(person.age) + Filter: max(person.age) > Int64(100) AND min(person.id) < Int64(50) + Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age), min(person.id)]] + TableScan: person + "# + ); } #[test] @@ -1065,11 +1412,16 @@ fn select_aggregate_aliased_with_group_by_with_having_referencing_aggregate_by_i FROM person GROUP BY first_name HAVING max_age > 100"; - let expected = "Projection: person.first_name, max(person.age) AS max_age\ - \n Filter: max(person.age) > Int64(100)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.first_name, max(person.age) AS max_age + Filter: max(person.age) > Int64(100) + Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]] + TableScan: person + "# + ); } #[test] @@ -1079,11 +1431,16 @@ fn select_aggregate_compound_aliased_with_group_by_with_having_referencing_compo FROM person GROUP BY first_name HAVING max_age_plus_one > 100"; - let expected = "Projection: person.first_name, max(person.age) + Int64(1) AS max_age_plus_one\ - \n Filter: max(person.age) + Int64(1) > Int64(100)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.first_name, max(person.age) + Int64(1) AS max_age_plus_one + Filter: max(person.age) + Int64(1) > Int64(100) + Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]] + TableScan: person + "# + ); } #[test] @@ -1093,11 +1450,16 @@ fn select_aggregate_with_group_by_with_having_using_derived_column_aggregate_not FROM person GROUP BY first_name HAVING MAX(age) > 100 AND MIN(id - 2) < 50"; - let expected = "Projection: person.first_name, max(person.age)\ - \n Filter: max(person.age) > Int64(100) AND min(person.id - Int64(2)) < Int64(50)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age), min(person.id - Int64(2))]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.first_name, max(person.age) + Filter: max(person.age) > Int64(100) AND min(person.id - Int64(2)) < Int64(50) + Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age), min(person.id - Int64(2))]] + TableScan: person + "# + ); } #[test] @@ -1106,46 +1468,67 @@ fn select_aggregate_with_group_by_with_having_using_count_star_not_in_select() { FROM person GROUP BY first_name HAVING MAX(age) > 100 AND count(*) < 50"; - let expected = "Projection: person.first_name, max(person.age)\ - \n Filter: max(person.age) > Int64(100) AND count(*) < Int64(50)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age), count(*)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.first_name, max(person.age) + Filter: max(person.age) > Int64(100) AND count(*) < Int64(50) + Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age), count(*)]] + TableScan: person + "# + ); } #[test] fn select_binary_expr() { let sql = "SELECT age + salary from person"; - let expected = "Projection: person.age + person.salary\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.age + person.salary + TableScan: person + "# + ); } #[test] fn select_binary_expr_nested() { let sql = "SELECT (age + salary)/2 from person"; - let expected = "Projection: (person.age + person.salary) / Int64(2)\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: (person.age + person.salary) / Int64(2) + TableScan: person + "# + ); } #[test] fn select_simple_aggregate() { - quick_test( - "SELECT MIN(age) FROM person", - "Projection: min(person.age)\ - \n Aggregate: groupBy=[[]], aggr=[[min(person.age)]]\ - \n TableScan: person", + let plan = logical_plan("SELECT MIN(age) FROM person").unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: min(person.age) + Aggregate: groupBy=[[]], aggr=[[min(person.age)]] + TableScan: person + "# ); } #[test] fn test_sum_aggregate() { - quick_test( - "SELECT sum(age) from person", - "Projection: sum(person.age)\ - \n Aggregate: groupBy=[[]], aggr=[[sum(person.age)]]\ - \n TableScan: person", + let plan = logical_plan("SELECT sum(age) from person").unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: sum(person.age) + Aggregate: groupBy=[[]], aggr=[[sum(person.age)]] + TableScan: person + "# ); } @@ -1171,33 +1554,44 @@ fn select_simple_aggregate_repeated_aggregate() { #[test] fn select_simple_aggregate_repeated_aggregate_with_single_alias() { - quick_test( - "SELECT MIN(age), MIN(age) AS a FROM person", - "Projection: min(person.age), min(person.age) AS a\ - \n Aggregate: groupBy=[[]], aggr=[[min(person.age)]]\ - \n TableScan: person", + let plan = logical_plan("SELECT MIN(age), MIN(age) AS a FROM person").unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: min(person.age), min(person.age) AS a + Aggregate: groupBy=[[]], aggr=[[min(person.age)]] + TableScan: person + "# ); } #[test] fn select_simple_aggregate_repeated_aggregate_with_unique_aliases() { - quick_test( - "SELECT MIN(age) AS a, MIN(age) AS b FROM person", - "Projection: min(person.age) AS a, min(person.age) AS b\ - \n Aggregate: groupBy=[[]], aggr=[[min(person.age)]]\ - \n TableScan: person", + let plan = logical_plan("SELECT MIN(age) AS a, MIN(age) AS b FROM person").unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: min(person.age) AS a, min(person.age) AS b + Aggregate: groupBy=[[]], aggr=[[min(person.age)]] + TableScan: person + "# ); } #[test] fn select_from_typed_string_values() { - quick_test( - "SELECT col1, col2 FROM (VALUES (TIMESTAMP '2021-06-10 17:01:00Z', DATE '2004-04-09')) as t (col1, col2)", - "Projection: t.col1, t.col2\ - \n SubqueryAlias: t\ - \n Projection: column1 AS col1, column2 AS col2\ - \n Values: (CAST(Utf8(\"2021-06-10 17:01:00Z\") AS Timestamp(Nanosecond, None)), CAST(Utf8(\"2004-04-09\") AS Date32))", - ); + let plan = logical_plan( + "SELECT col1, col2 FROM (VALUES (TIMESTAMP '2021-06-10 17:01:00Z', DATE '2004-04-09')) as t (col1, col2)", + ).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: t.col1, t.col2 + SubqueryAlias: t + Projection: column1 AS col1, column2 AS col2 + Values: (CAST(Utf8("2021-06-10 17:01:00Z") AS Timestamp(Nanosecond, None)), CAST(Utf8("2004-04-09") AS Date32)) + "# + ); } #[test] @@ -1215,21 +1609,31 @@ fn select_simple_aggregate_repeated_aggregate_with_repeated_aliases() { #[test] fn select_simple_aggregate_with_groupby() { - quick_test( - "SELECT state, MIN(age), MAX(age) FROM person GROUP BY state", - "Projection: person.state, min(person.age), max(person.age)\ - \n Aggregate: groupBy=[[person.state]], aggr=[[min(person.age), max(person.age)]]\ - \n TableScan: person", + let plan = + logical_plan("SELECT state, MIN(age), MAX(age) FROM person GROUP BY state") + .unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.state, min(person.age), max(person.age) + Aggregate: groupBy=[[person.state]], aggr=[[min(person.age), max(person.age)]] + TableScan: person + "# ); } #[test] fn select_simple_aggregate_with_groupby_with_aliases() { - quick_test( - "SELECT state AS a, MIN(age) AS b FROM person GROUP BY state", - "Projection: person.state AS a, min(person.age) AS b\ - \n Aggregate: groupBy=[[person.state]], aggr=[[min(person.age)]]\ - \n TableScan: person", + let plan = + logical_plan("SELECT state AS a, MIN(age) AS b FROM person GROUP BY state") + .unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.state AS a, min(person.age) AS b + Aggregate: groupBy=[[person.state]], aggr=[[min(person.age)]] + TableScan: person + "# ); } @@ -1248,11 +1652,15 @@ fn select_simple_aggregate_with_groupby_with_aliases_repeated() { #[test] fn select_simple_aggregate_with_groupby_column_unselected() { - quick_test( - "SELECT MIN(age), MAX(age) FROM person GROUP BY state", - "Projection: min(person.age), max(person.age)\ - \n Aggregate: groupBy=[[person.state]], aggr=[[min(person.age), max(person.age)]]\ - \n TableScan: person", + let plan = + logical_plan("SELECT MIN(age), MAX(age) FROM person GROUP BY state").unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: min(person.age), max(person.age) + Aggregate: groupBy=[[person.state]], aggr=[[min(person.age), max(person.age)]] + TableScan: person + "# ); } @@ -1291,27 +1699,39 @@ fn select_interval_out_of_range() { #[test] fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() { - quick_test( - "SELECT MAX(first_name) FROM person GROUP BY first_name", - "Projection: max(person.first_name)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.first_name)]]\ - \n TableScan: person", + let plan = + logical_plan("SELECT MAX(first_name) FROM person GROUP BY first_name").unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: max(person.first_name) + Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.first_name)]] + TableScan: person + "# ); } #[test] fn select_simple_aggregate_with_groupby_can_use_positions() { - quick_test( - "SELECT state, age AS b, count(1) FROM person GROUP BY 1, 2", - "Projection: person.state, person.age AS b, count(Int64(1))\ - \n Aggregate: groupBy=[[person.state, person.age]], aggr=[[count(Int64(1))]]\ - \n TableScan: person", + let plan = logical_plan("SELECT state, age AS b, count(1) FROM person GROUP BY 1, 2") + .unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.state, person.age AS b, count(Int64(1)) + Aggregate: groupBy=[[person.state, person.age]], aggr=[[count(Int64(1))]] + TableScan: person + "# ); - quick_test( - "SELECT state, age AS b, count(1) FROM person GROUP BY 2, 1", - "Projection: person.state, person.age AS b, count(Int64(1))\ - \n Aggregate: groupBy=[[person.age, person.state]], aggr=[[count(Int64(1))]]\ - \n TableScan: person", + let plan = logical_plan("SELECT state, age AS b, count(1) FROM person GROUP BY 2, 1") + .unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.state, person.age AS b, count(Int64(1)) + Aggregate: groupBy=[[person.age, person.state]], aggr=[[count(Int64(1))]] + TableScan: person + "# ); } @@ -1340,11 +1760,15 @@ fn select_simple_aggregate_with_groupby_position_out_of_range() { #[test] fn select_simple_aggregate_with_groupby_can_use_alias() { - quick_test( - "SELECT state AS a, MIN(age) AS b FROM person GROUP BY a", - "Projection: person.state AS a, min(person.age) AS b\ - \n Aggregate: groupBy=[[person.state]], aggr=[[min(person.age)]]\ - \n TableScan: person", + let plan = + logical_plan("SELECT state AS a, MIN(age) AS b FROM person GROUP BY a").unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.state AS a, min(person.age) AS b + Aggregate: groupBy=[[person.state]], aggr=[[min(person.age)]] + TableScan: person + "# ); } @@ -1363,48 +1787,72 @@ fn select_simple_aggregate_with_groupby_aggregate_repeated() { #[test] fn select_simple_aggregate_with_groupby_aggregate_repeated_and_one_has_alias() { - quick_test( - "SELECT state, MIN(age), MIN(age) AS ma FROM person GROUP BY state", - "Projection: person.state, min(person.age), min(person.age) AS ma\ - \n Aggregate: groupBy=[[person.state]], aggr=[[min(person.age)]]\ - \n TableScan: person", - ) + let plan = + logical_plan("SELECT state, MIN(age), MIN(age) AS ma FROM person GROUP BY state") + .unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.state, min(person.age), min(person.age) AS ma + Aggregate: groupBy=[[person.state]], aggr=[[min(person.age)]] + TableScan: person + "# + ); } #[test] fn select_simple_aggregate_with_groupby_non_column_expression_unselected() { - quick_test( - "SELECT MIN(first_name) FROM person GROUP BY age + 1", - "Projection: min(person.first_name)\ - \n Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[min(person.first_name)]]\ - \n TableScan: person", + let plan = + logical_plan("SELECT MIN(first_name) FROM person GROUP BY age + 1").unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: min(person.first_name) + Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[min(person.first_name)]] + TableScan: person + "# ); } #[test] fn select_simple_aggregate_with_groupby_non_column_expression_selected_and_resolvable() { - quick_test( - "SELECT age + 1, MIN(first_name) FROM person GROUP BY age + 1", - "Projection: person.age + Int64(1), min(person.first_name)\ - \n Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[min(person.first_name)]]\ - \n TableScan: person", + let plan = + logical_plan("SELECT age + 1, MIN(first_name) FROM person GROUP BY age + 1") + .unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.age + Int64(1), min(person.first_name) + Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[min(person.first_name)]] + TableScan: person + "# ); - quick_test( - "SELECT MIN(first_name), age + 1 FROM person GROUP BY age + 1", - "Projection: min(person.first_name), person.age + Int64(1)\ - \n Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[min(person.first_name)]]\ - \n TableScan: person", + let plan = + logical_plan("SELECT MIN(first_name), age + 1 FROM person GROUP BY age + 1") + .unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: min(person.first_name), person.age + Int64(1) + Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[min(person.first_name)]] + TableScan: person + "# ); } #[test] fn select_simple_aggregate_with_groupby_non_column_expression_nested_and_resolvable() { - quick_test( - "SELECT ((age + 1) / 2) * (age + 1), MIN(first_name) FROM person GROUP BY age + 1", - "Projection: person.age + Int64(1) / Int64(2) * person.age + Int64(1), min(person.first_name)\ - \n Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[min(person.first_name)]]\ - \n TableScan: person", - ); + let plan = logical_plan( + "SELECT ((age + 1) / 2) * (age + 1), MIN(first_name) FROM person GROUP BY age + 1" + ).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.age + Int64(1) / Int64(2) * person.age + Int64(1), min(person.first_name) + Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[min(person.first_name)]] + TableScan: person + "# + ); } #[test] @@ -1437,113 +1885,168 @@ fn select_simple_aggregate_with_groupby_non_column_expression_and_its_column_sel #[test] fn select_simple_aggregate_nested_in_binary_expr_with_groupby() { - quick_test( - "SELECT state, MIN(age) < 10 FROM person GROUP BY state", - "Projection: person.state, min(person.age) < Int64(10)\ - \n Aggregate: groupBy=[[person.state]], aggr=[[min(person.age)]]\ - \n TableScan: person", + let plan = + logical_plan("SELECT state, MIN(age) < 10 FROM person GROUP BY state").unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.state, min(person.age) < Int64(10) + Aggregate: groupBy=[[person.state]], aggr=[[min(person.age)]] + TableScan: person + "# ); } #[test] fn select_simple_aggregate_and_nested_groupby_column() { - quick_test( - "SELECT age + 1, MAX(first_name) FROM person GROUP BY age", - "Projection: person.age + Int64(1), max(person.first_name)\ - \n Aggregate: groupBy=[[person.age]], aggr=[[max(person.first_name)]]\ - \n TableScan: person", + let plan = + logical_plan("SELECT MAX(first_name), age + 1 FROM person GROUP BY age").unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: max(person.first_name), person.age + Int64(1) + Aggregate: groupBy=[[person.age]], aggr=[[max(person.first_name)]] + TableScan: person + "# ); } #[test] fn select_aggregate_compounded_with_groupby_column() { - quick_test( - "SELECT age + MIN(salary) FROM person GROUP BY age", - "Projection: person.age + min(person.salary)\ - \n Aggregate: groupBy=[[person.age]], aggr=[[min(person.salary)]]\ - \n TableScan: person", + let plan = logical_plan("SELECT age + MIN(salary) FROM person GROUP BY age").unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.age + min(person.salary) + Aggregate: groupBy=[[person.age]], aggr=[[min(person.salary)]] + TableScan: person + "# ); } #[test] fn select_aggregate_with_non_column_inner_expression_with_groupby() { - quick_test( - "SELECT state, MIN(age + 1) FROM person GROUP BY state", - "Projection: person.state, min(person.age + Int64(1))\ - \n Aggregate: groupBy=[[person.state]], aggr=[[min(person.age + Int64(1))]]\ - \n TableScan: person", + let plan = + logical_plan("SELECT state, MIN(age + 1) FROM person GROUP BY state").unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: person.state, min(person.age + Int64(1)) + Aggregate: groupBy=[[person.state]], aggr=[[min(person.age + Int64(1))]] + TableScan: person + "# ); } #[test] fn select_count_one() { let sql = "SELECT count(1) FROM person"; - let expected = "Projection: count(Int64(1))\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: count(Int64(1)) + Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] + TableScan: person +"# + ); } #[test] fn select_count_column() { let sql = "SELECT count(id) FROM person"; - let expected = "Projection: count(person.id)\ - \n Aggregate: groupBy=[[]], aggr=[[count(person.id)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: count(person.id) + Aggregate: groupBy=[[]], aggr=[[count(person.id)]] + TableScan: person +"# + ); } #[test] fn select_approx_median() { let sql = "SELECT approx_median(age) FROM person"; - let expected = "Projection: approx_median(person.age)\ - \n Aggregate: groupBy=[[]], aggr=[[approx_median(person.age)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: approx_median(person.age) + Aggregate: groupBy=[[]], aggr=[[approx_median(person.age)]] + TableScan: person +"# + ); } #[test] fn select_scalar_func() { let sql = "SELECT sqrt(age) FROM person"; - let expected = "Projection: sqrt(person.age)\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: sqrt(person.age) + TableScan: person +"# + ); } #[test] fn select_aliased_scalar_func() { let sql = "SELECT sqrt(person.age) AS square_people FROM person"; - let expected = "Projection: sqrt(person.age) AS square_people\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: sqrt(person.age) AS square_people + TableScan: person +"# + ); } #[test] fn select_where_nullif_division() { let sql = "SELECT c3/(c4+c5) \ FROM aggregate_test_100 WHERE c3/nullif(c4+c5, 0) > 0.1"; - let expected = "Projection: aggregate_test_100.c3 / (aggregate_test_100.c4 + aggregate_test_100.c5)\ - \n Filter: aggregate_test_100.c3 / nullif(aggregate_test_100.c4 + aggregate_test_100.c5, Int64(0)) > Float64(0.1)\ - \n TableScan: aggregate_test_100"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: aggregate_test_100.c3 / (aggregate_test_100.c4 + aggregate_test_100.c5) + Filter: aggregate_test_100.c3 / nullif(aggregate_test_100.c4 + aggregate_test_100.c5, Int64(0)) > Float64(0.1) + TableScan: aggregate_test_100 +"# + ); } #[test] fn select_where_with_negative_operator() { let sql = "SELECT c3 FROM aggregate_test_100 WHERE c3 > -0.1 AND -c4 > 0"; - let expected = "Projection: aggregate_test_100.c3\ - \n Filter: aggregate_test_100.c3 > Float64(-0.1) AND (- aggregate_test_100.c4) > Int64(0)\ - \n TableScan: aggregate_test_100"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: aggregate_test_100.c3 + Filter: aggregate_test_100.c3 > Float64(-0.1) AND (- aggregate_test_100.c4) > Int64(0) + TableScan: aggregate_test_100 +"# + ); } #[test] fn select_where_with_positive_operator() { let sql = "SELECT c3 FROM aggregate_test_100 WHERE c3 > +0.1 AND +c4 > 0"; - let expected = "Projection: aggregate_test_100.c3\ - \n Filter: aggregate_test_100.c3 > Float64(0.1) AND aggregate_test_100.c4 > Int64(0)\ - \n TableScan: aggregate_test_100"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: aggregate_test_100.c3 + Filter: aggregate_test_100.c3 > Float64(0.1) AND aggregate_test_100.c4 > Int64(0) + TableScan: aggregate_test_100 +"# + ); } #[test] @@ -1551,30 +2054,43 @@ fn select_where_compound_identifiers() { let sql = "SELECT aggregate_test_100.c3 \ FROM public.aggregate_test_100 \ WHERE aggregate_test_100.c3 > 0.1"; - let expected = "Projection: public.aggregate_test_100.c3\ - \n Filter: public.aggregate_test_100.c3 > Float64(0.1)\ - \n TableScan: public.aggregate_test_100"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: public.aggregate_test_100.c3 + Filter: public.aggregate_test_100.c3 > Float64(0.1) + TableScan: public.aggregate_test_100 +"# + ); } #[test] fn select_order_by_index() { let sql = "SELECT id FROM person ORDER BY 1"; - let expected = "Sort: person.id ASC NULLS LAST\ - \n Projection: person.id\ - \n TableScan: person"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Sort: person.id ASC NULLS LAST + Projection: person.id + TableScan: person +"# + ); } #[test] fn select_order_by_multiple_index() { let sql = "SELECT id, state, age FROM person ORDER BY 1, 3"; - let expected = "Sort: person.id ASC NULLS LAST, person.age ASC NULLS LAST\ - \n Projection: person.id, person.state, person.age\ - \n TableScan: person"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Sort: person.id ASC NULLS LAST, person.age ASC NULLS LAST + Projection: person.id, person.state, person.age + TableScan: person +"# + ); } #[test] @@ -1608,88 +2124,124 @@ fn select_order_by_index_oob() { } #[test] -fn select_order_by() { +fn select_with_order_by() { let sql = "SELECT id FROM person ORDER BY id"; - let expected = "Sort: person.id ASC NULLS LAST\ - \n Projection: person.id\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Sort: person.id ASC NULLS LAST + Projection: person.id + TableScan: person +"# + ); } #[test] fn select_order_by_desc() { let sql = "SELECT id FROM person ORDER BY id DESC"; - let expected = "Sort: person.id DESC NULLS FIRST\ - \n Projection: person.id\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Sort: person.id DESC NULLS FIRST + Projection: person.id + TableScan: person +"# + ); } #[test] fn select_order_by_nulls_last() { - quick_test( - "SELECT id FROM person ORDER BY id DESC NULLS LAST", - "Sort: person.id DESC NULLS LAST\ - \n Projection: person.id\ - \n TableScan: person", + let plan = logical_plan("SELECT id FROM person ORDER BY id DESC NULLS LAST").unwrap(); + assert_snapshot!( + plan, + @r#" +Sort: person.id DESC NULLS LAST + Projection: person.id + TableScan: person +"# ); - quick_test( - "SELECT id FROM person ORDER BY id NULLS LAST", - "Sort: person.id ASC NULLS LAST\ - \n Projection: person.id\ - \n TableScan: person", + let plan = logical_plan("SELECT id FROM person ORDER BY id NULLS LAST").unwrap(); + assert_snapshot!( + plan, + @r#" +Sort: person.id ASC NULLS LAST + Projection: person.id + TableScan: person +"# ); } #[test] fn select_group_by() { let sql = "SELECT state FROM person GROUP BY state"; - let expected = "Projection: person.state\ - \n Aggregate: groupBy=[[person.state]], aggr=[[]]\ - \n TableScan: person"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.state + Aggregate: groupBy=[[person.state]], aggr=[[]] + TableScan: person +"# + ); } #[test] fn select_group_by_columns_not_in_select() { let sql = "SELECT MAX(age) FROM person GROUP BY state"; - let expected = "Projection: max(person.age)\ - \n Aggregate: groupBy=[[person.state]], aggr=[[max(person.age)]]\ - \n TableScan: person"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: max(person.age) + Aggregate: groupBy=[[person.state]], aggr=[[max(person.age)]] + TableScan: person +"# + ); } #[test] fn select_group_by_count_star() { let sql = "SELECT state, count(*) FROM person GROUP BY state"; - let expected = "Projection: person.state, count(*)\ - \n Aggregate: groupBy=[[person.state]], aggr=[[count(*)]]\ - \n TableScan: person"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.state, count(*) + Aggregate: groupBy=[[person.state]], aggr=[[count(*)]] + TableScan: person +"# + ); } #[test] fn select_group_by_needs_projection() { let sql = "SELECT count(state), state FROM person GROUP BY state"; - let expected = "\ - Projection: count(person.state), person.state\ - \n Aggregate: groupBy=[[person.state]], aggr=[[count(person.state)]]\ - \n TableScan: person"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: count(person.state), person.state + Aggregate: groupBy=[[person.state]], aggr=[[count(person.state)]] + TableScan: person + "# + ); } #[test] fn select_7480_1() { let sql = "SELECT c1, MIN(c12) FROM aggregate_test_100 GROUP BY c1, c13"; - let expected = "Projection: aggregate_test_100.c1, min(aggregate_test_100.c12)\ - \n Aggregate: groupBy=[[aggregate_test_100.c1, aggregate_test_100.c13]], aggr=[[min(aggregate_test_100.c12)]]\ - \n TableScan: aggregate_test_100"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: aggregate_test_100.c1, min(aggregate_test_100.c12) + Aggregate: groupBy=[[aggregate_test_100.c1, aggregate_test_100.c13]], aggr=[[min(aggregate_test_100.c12)]] + TableScan: aggregate_test_100 +"# + ); } #[test] @@ -1708,58 +2260,97 @@ fn select_7480_2() { #[test] fn create_external_table_csv() { let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv'"; - let expected = "CreateExternalTable: Bare { table: \"t\" }"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +CreateExternalTable: Bare { table: "t" } +"# + ); } #[test] fn create_external_table_with_pk() { let sql = "CREATE EXTERNAL TABLE t(c1 int, primary key(c1)) STORED AS CSV LOCATION 'foo.csv'"; - let expected = - "CreateExternalTable: Bare { table: \"t\" } constraints=[PrimaryKey([0])]"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +CreateExternalTable: Bare { table: "t" } constraints=[PrimaryKey([0])] + "# + ); } #[test] fn create_external_table_wih_schema() { let sql = "CREATE EXTERNAL TABLE staging.foo STORED AS CSV LOCATION 'foo.csv'"; - let expected = "CreateExternalTable: Partial { schema: \"staging\", table: \"foo\" }"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +CreateExternalTable: Partial { schema: "staging", table: "foo" } +"# + ); } #[test] fn create_schema_with_quoted_name() { let sql = "CREATE SCHEMA \"quoted_schema_name\""; - let expected = "CreateCatalogSchema: \"quoted_schema_name\""; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +CreateCatalogSchema: "quoted_schema_name" +"# + ); } #[test] fn create_schema_with_quoted_unnormalized_name() { let sql = "CREATE SCHEMA \"Foo\""; - let expected = "CreateCatalogSchema: \"Foo\""; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +CreateCatalogSchema: "Foo" +"# + ); } #[test] fn create_schema_with_unquoted_normalized_name() { let sql = "CREATE SCHEMA Foo"; - let expected = "CreateCatalogSchema: \"foo\""; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +CreateCatalogSchema: "foo" +"# + ); } #[test] fn create_external_table_custom() { let sql = "CREATE EXTERNAL TABLE dt STORED AS DELTATABLE LOCATION 's3://bucket/schema/table';"; - let expected = r#"CreateExternalTable: Bare { table: "dt" }"#; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +CreateExternalTable: Bare { table: "dt" } +"# + ); } #[test] fn create_external_table_csv_no_schema() { let sql = "CREATE EXTERNAL TABLE t STORED AS CSV LOCATION 'foo.csv'"; - let expected = "CreateExternalTable: Bare { table: \"t\" }"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +CreateExternalTable: Bare { table: "t" } +"# + ); } #[test] @@ -1772,9 +2363,18 @@ fn create_external_table_with_compression_type() { "CREATE EXTERNAL TABLE t(c1 int) STORED AS JSON LOCATION 'foo.json.bz2' OPTIONS ('format.compression' 'bzip2')", "CREATE EXTERNAL TABLE t(c1 int) STORED AS NONSTANDARD LOCATION 'foo.unk' OPTIONS ('format.compression' 'gzip')", ]; - for sql in sqls { - let expected = "CreateExternalTable: Bare { table: \"t\" }"; - quick_test(sql, expected); + + allow_duplicates! { + for sql in sqls { + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + CreateExternalTable: Bare { table: "t" } + "# + ); + } + } // negative case @@ -1805,29 +2405,47 @@ fn create_external_table_with_compression_type() { #[test] fn create_external_table_parquet() { let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS PARQUET LOCATION 'foo.parquet'"; - let expected = "CreateExternalTable: Bare { table: \"t\" }"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +CreateExternalTable: Bare { table: "t" } +"# + ); } #[test] fn create_external_table_parquet_sort_order() { let sql = "create external table foo(a varchar, b varchar, c timestamp) stored as parquet location '/tmp/foo' with order (c)"; - let expected = "CreateExternalTable: Bare { table: \"foo\" }"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +CreateExternalTable: Bare { table: "foo" } +"# + ); } #[test] fn create_external_table_parquet_no_schema() { let sql = "CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION 'foo.parquet'"; - let expected = "CreateExternalTable: Bare { table: \"t\" }"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#"CreateExternalTable: Bare { table: "t" }"# + ); } #[test] fn create_external_table_parquet_no_schema_sort_order() { let sql = "CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION 'foo.parquet' WITH ORDER (id)"; - let expected = "CreateExternalTable: Bare { table: \"t\" }"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +CreateExternalTable: Bare { table: "t" } +"# + ); } #[test] @@ -1836,11 +2454,16 @@ fn equijoin_explicit_syntax() { FROM person \ JOIN orders \ ON id = customer_id"; - let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: Filter: person.id = orders.customer_id\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -1849,12 +2472,16 @@ fn equijoin_with_condition() { FROM person \ JOIN orders \ ON id = customer_id AND order_id > 1 "; - let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: Filter: person.id = orders.customer_id AND orders.order_id > Int64(1)\ - \n TableScan: person\ - \n TableScan: orders"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND orders.order_id > Int64(1) + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -1863,11 +2490,16 @@ fn left_equijoin_with_conditions() { FROM person \ LEFT JOIN orders \ ON id = customer_id AND order_id > 1 AND age < 30"; - let expected = "Projection: person.id, orders.order_id\ - \n Left Join: Filter: person.id = orders.customer_id AND orders.order_id > Int64(1) AND person.age < Int64(30)\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id + Left Join: Filter: person.id = orders.customer_id AND orders.order_id > Int64(1) AND person.age < Int64(30) + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -1876,12 +2508,16 @@ fn right_equijoin_with_conditions() { FROM person \ RIGHT JOIN orders \ ON id = customer_id AND id > 1 AND order_id < 100"; - - let expected = "Projection: person.id, orders.order_id\ - \n Right Join: Filter: person.id = orders.customer_id AND person.id > Int64(1) AND orders.order_id < Int64(100)\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id + Right Join: Filter: person.id = orders.customer_id AND person.id > Int64(1) AND orders.order_id < Int64(100) + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -1890,11 +2526,16 @@ fn full_equijoin_with_conditions() { FROM person \ FULL JOIN orders \ ON id = customer_id AND id > 1 AND order_id < 100"; - let expected = "Projection: person.id, orders.order_id\ - \n Full Join: Filter: person.id = orders.customer_id AND person.id > Int64(1) AND orders.order_id < Int64(100)\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id + Full Join: Filter: person.id = orders.customer_id AND person.id > Int64(1) AND orders.order_id < Int64(100) + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -1903,11 +2544,16 @@ fn join_with_table_name() { FROM person \ JOIN orders \ ON person.id = orders.customer_id"; - let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: Filter: person.id = orders.customer_id\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -1916,12 +2562,17 @@ fn join_with_using() { FROM person \ JOIN person as person2 \ USING (id)"; - let expected = "Projection: person.first_name, person.id\ - \n Inner Join: Using person.id = person2.id\ - \n TableScan: person\ - \n SubqueryAlias: person2\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.first_name, person.id + Inner Join: Using person.id = person2.id + TableScan: person + SubqueryAlias: person2 + TableScan: person +"# + ); } #[test] @@ -1930,13 +2581,18 @@ fn equijoin_explicit_syntax_3_tables() { FROM person \ JOIN orders ON id = customer_id \ JOIN lineitem ON o_item_id = l_item_id"; - let expected = "Projection: person.id, orders.order_id, lineitem.l_description\ - \n Inner Join: Filter: orders.o_item_id = lineitem.l_item_id\ - \n Inner Join: Filter: person.id = orders.customer_id\ - \n TableScan: person\ - \n TableScan: orders\ - \n TableScan: lineitem"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id, lineitem.l_description + Inner Join: Filter: orders.o_item_id = lineitem.l_item_id + Inner Join: Filter: person.id = orders.customer_id + TableScan: person + TableScan: orders + TableScan: lineitem +"# + ); } #[test] @@ -1944,152 +2600,206 @@ fn boolean_literal_in_condition_expression() { let sql = "SELECT order_id \ FROM orders \ WHERE delivered = false OR delivered = true"; - let expected = "Projection: orders.order_id\ - \n Filter: orders.delivered = Boolean(false) OR orders.delivered = Boolean(true)\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id + Filter: orders.delivered = Boolean(false) OR orders.delivered = Boolean(true) + TableScan: orders +"# + ); } #[test] fn union() { let sql = "SELECT order_id from orders UNION SELECT order_id FROM orders"; - let expected = "\ - Distinct:\ - \n Union\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Distinct: + Union + Projection: orders.order_id + TableScan: orders + Projection: orders.order_id + TableScan: orders +"# + ); } #[test] fn union_by_name_different_columns() { let sql = "SELECT order_id from orders UNION BY NAME SELECT order_id, 1 FROM orders"; - let expected = "\ - Distinct:\ - \n Union\ - \n Projection: order_id, NULL AS Int64(1)\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: order_id, Int64(1)\ - \n Projection: orders.order_id, Int64(1)\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Distinct: + Union + Projection: order_id, NULL AS Int64(1) + Projection: orders.order_id + TableScan: orders + Projection: order_id, Int64(1) + Projection: orders.order_id, Int64(1) + TableScan: orders +"# + ); } #[test] fn union_by_name_same_column_names() { let sql = "SELECT order_id from orders UNION SELECT order_id FROM orders"; - let expected = "\ - Distinct:\ - \n Union\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Distinct: + Union + Projection: orders.order_id + TableScan: orders + Projection: orders.order_id + TableScan: orders +"# + ); } #[test] fn union_all() { let sql = "SELECT order_id from orders UNION ALL SELECT order_id FROM orders"; - let expected = "Union\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.order_id\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Union + Projection: orders.order_id + TableScan: orders + Projection: orders.order_id + TableScan: orders +"# + ); } #[test] fn union_all_by_name_different_columns() { let sql = "SELECT order_id from orders UNION ALL BY NAME SELECT order_id, 1 FROM orders"; - let expected = "\ - Union\ - \n Projection: order_id, NULL AS Int64(1)\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: order_id, Int64(1)\ - \n Projection: orders.order_id, Int64(1)\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Union + Projection: order_id, NULL AS Int64(1) + Projection: orders.order_id + TableScan: orders + Projection: order_id, Int64(1) + Projection: orders.order_id, Int64(1) + TableScan: orders +"# + ); } #[test] fn union_all_by_name_same_column_names() { let sql = "SELECT order_id from orders UNION ALL BY NAME SELECT order_id FROM orders"; - let expected = "\ - Union\ - \n Projection: order_id\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: order_id\ - \n Projection: orders.order_id\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Union + Projection: order_id + Projection: orders.order_id + TableScan: orders + Projection: order_id + Projection: orders.order_id + TableScan: orders +"# + ); } #[test] fn empty_over() { let sql = "SELECT order_id, MAX(order_id) OVER () from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: orders +"# + ); } #[test] fn empty_over_with_alias() { let sql = "SELECT order_id oid, MAX(order_id) OVER () max_oid from orders"; - let expected = "\ - Projection: orders.order_id AS oid, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max_oid\ - \n WindowAggr: windowExpr=[[max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id AS oid, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max_oid + WindowAggr: windowExpr=[[max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: orders +"# + ); } #[test] fn empty_over_dup_with_alias() { let sql = "SELECT order_id oid, MAX(order_id) OVER () max_oid, MAX(order_id) OVER () max_oid_dup from orders"; - let expected = "\ - Projection: orders.order_id AS oid, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max_oid, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max_oid_dup\ - \n WindowAggr: windowExpr=[[max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id AS oid, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max_oid, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max_oid_dup + WindowAggr: windowExpr=[[max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: orders +"# + ); } #[test] fn empty_over_dup_with_different_sort() { let sql = "SELECT order_id oid, MAX(order_id) OVER (), MAX(order_id) OVER (ORDER BY order_id) from orders"; - let expected = "\ - Projection: orders.order_id AS oid, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(orders.order_id) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n WindowAggr: windowExpr=[[max(orders.order_id) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id AS oid, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(orders.order_id) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + WindowAggr: windowExpr=[[max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + WindowAggr: windowExpr=[[max(orders.order_id) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: orders +"# + ); } #[test] fn empty_over_plus() { let sql = "SELECT order_id, MAX(qty * 1.1) OVER () from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.qty * Float64(1.1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[max(orders.qty * Float64(1.1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty * Float64(1.1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[max(orders.qty * Float64(1.1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: orders +"# + ); } #[test] fn empty_over_multiple() { let sql = "SELECT order_id, MAX(qty) OVER (), min(qty) over (), avg(qty) OVER () from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, avg(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[max(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, avg(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, avg(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[max(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, avg(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: orders +"# + ); } /// psql result @@ -2104,11 +2814,15 @@ fn empty_over_multiple() { #[test] fn over_partition_by() { let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id) from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: orders +"# + ); } /// psql result @@ -2126,45 +2840,61 @@ fn over_partition_by() { #[test] fn over_order_by() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: orders +"# + ); } #[test] fn over_order_by_with_window_frame_double_end() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id ROWS BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING, min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING]]\ - \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: orders"; - quick_test(sql, expected); -} - -#[test] -fn over_order_by_with_window_frame_single_end() { + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING, min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING]] + WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: orders +"# + ); +} + +#[test] +fn over_order_by_with_window_frame_single_end() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id ROWS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND CURRENT ROW, min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND CURRENT ROW, min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND CURRENT ROW]] + WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: orders +"# + ); } #[test] fn over_order_by_with_window_frame_single_end_groups() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW, min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW, min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]] + WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: orders +"# + ); } /// psql result @@ -2182,12 +2912,16 @@ fn over_order_by_with_window_frame_single_end_groups() { #[test] fn over_order_by_two_sort_keys() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY (order_id + 1)) from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(orders.qty) ORDER BY [orders.order_id + Int64(1) ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id + Int64(1) ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(orders.qty) ORDER BY [orders.order_id + Int64(1) ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id + Int64(1) ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: orders +"# + ); } /// psql result @@ -2206,13 +2940,17 @@ fn over_order_by_two_sort_keys() { #[test] fn over_order_by_sort_keys_sorting() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), sum(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + WindowAggr: windowExpr=[[sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: orders +"# + ); } /// psql result @@ -2229,13 +2967,17 @@ fn over_order_by_sort_keys_sorting() { #[test] fn over_order_by_sort_keys_sorting_prefix_compacting() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), sum(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + WindowAggr: windowExpr=[[sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: orders +"# + ); } /// psql result @@ -2257,14 +2999,18 @@ fn over_order_by_sort_keys_sorting_prefix_compacting() { #[test] fn over_order_by_sort_keys_sorting_global_order_compacting() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), sum(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders ORDER BY order_id"; - let expected = "\ - Sort: orders.order_id ASC NULLS LAST\ - \n Projection: orders.order_id, max(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Sort: orders.order_id ASC NULLS LAST + Projection: orders.order_id, max(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + WindowAggr: windowExpr=[[sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: orders +"# + ); } /// psql result @@ -2280,11 +3026,15 @@ fn over_order_by_sort_keys_sorting_global_order_compacting() { fn over_partition_by_order_by() { let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id ORDER BY qty) from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: orders +"# + ); } /// psql result @@ -2300,11 +3050,15 @@ fn over_partition_by_order_by() { fn over_partition_by_order_by_no_dup() { let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id, qty ORDER BY qty) from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: orders +"# + ); } /// psql result @@ -2323,12 +3077,16 @@ fn over_partition_by_order_by_no_dup() { fn over_partition_by_order_by_mix_up() { let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id, qty ORDER BY qty), MIN(qty) OVER (PARTITION BY qty ORDER BY order_id) from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(orders.qty) PARTITION BY [orders.qty] ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[min(orders.qty) PARTITION BY [orders.qty] ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(orders.qty) PARTITION BY [orders.qty] ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + WindowAggr: windowExpr=[[min(orders.qty) PARTITION BY [orders.qty] ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: orders +"# + ); } /// psql result @@ -2346,90 +3104,121 @@ fn over_partition_by_order_by_mix_up() { fn over_partition_by_order_by_mix_up_prefix() { let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id ORDER BY qty), MIN(qty) OVER (PARTITION BY order_id, qty ORDER BY price) from orders"; - let expected = "\ - Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.price ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[min(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.price ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.price ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + WindowAggr: windowExpr=[[min(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.price ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: orders +"# + ); } #[test] fn approx_median_window() { let sql = "SELECT order_id, APPROX_MEDIAN(qty) OVER(PARTITION BY order_id) from orders"; - let expected = "\ - Projection: orders.order_id, approx_median(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[approx_median(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, approx_median(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[approx_median(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: orders +"# + ); } #[test] fn select_typed_date_string() { let sql = "SELECT date '2020-12-10' AS date"; - let expected = "Projection: CAST(Utf8(\"2020-12-10\") AS Date32) AS date\ - \n EmptyRelation"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: CAST(Utf8("2020-12-10") AS Date32) AS date + EmptyRelation: rows=1 + "# + ); } #[test] fn select_typed_time_string() { let sql = "SELECT TIME '08:09:10.123' AS time"; - let expected = - "Projection: CAST(Utf8(\"08:09:10.123\") AS Time64(Nanosecond)) AS time\ - \n EmptyRelation"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: CAST(Utf8("08:09:10.123") AS Time64(Nanosecond)) AS time + EmptyRelation: rows=1 + "# + ); } #[test] fn select_multibyte_column() { let sql = r#"SELECT "😀" FROM person"#; - let expected = "Projection: person.😀\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.😀 + TableScan: person +"# + ); } #[test] fn select_groupby_orderby() { // ensure that references are correctly resolved in the order by clause // see https://github.com/apache/datafusion/issues/4854 - let sql = r#"SELECT - avg(age) AS "value", - date_trunc('month', birth_date) AS "birth_date" - FROM person GROUP BY birth_date ORDER BY birth_date; -"#; - // expect that this is not an ambiguous reference - let expected = - "Sort: birth_date ASC NULLS LAST\ - \n Projection: avg(person.age) AS value, date_trunc(Utf8(\"month\"), person.birth_date) AS birth_date\ - \n Aggregate: groupBy=[[person.birth_date]], aggr=[[avg(person.age)]]\ - \n TableScan: person"; - quick_test(sql, expected); - - // Use fully qualified `person.birth_date` as argument to date_trunc, plan should be the same - let sql = r#"SELECT - avg(age) AS "value", - date_trunc('month', person.birth_date) AS "birth_date" - FROM person GROUP BY birth_date ORDER BY birth_date; -"#; - quick_test(sql, expected); - - // Use fully qualified `person.birth_date` as group by, plan should be the same - let sql = r#"SELECT - avg(age) AS "value", - date_trunc('month', birth_date) AS "birth_date" - FROM person GROUP BY person.birth_date ORDER BY birth_date; -"#; - quick_test(sql, expected); - // Use fully qualified `person.birth_date` in both group and date_trunc, plan should be the same - let sql = r#"SELECT - avg(age) AS "value", - date_trunc('month', person.birth_date) AS "birth_date" - FROM person GROUP BY person.birth_date ORDER BY birth_date; -"#; - quick_test(sql, expected); + let sqls = vec![ + r#" + SELECT + avg(age) AS "value", + date_trunc('month', birth_date) AS "birth_date" + FROM person GROUP BY birth_date ORDER BY birth_date; + "#, + // Use fully qualified `person.birth_date` as argument to date_trunc, plan should be the same + r#" + SELECT + avg(age) AS "value", + date_trunc('month', person.birth_date) AS "birth_date" + FROM person GROUP BY birth_date ORDER BY birth_date; + "#, + // Use fully qualified `person.birth_date` as group by, plan should be the same + r#" + SELECT + avg(age) AS "value", + date_trunc('month', birth_date) AS "birth_date" + FROM person GROUP BY person.birth_date ORDER BY birth_date; + "#, + // Use fully qualified `person.birth_date` in both group and date_trunc, plan should be the same + r#" + SELECT + avg(age) AS "value", + date_trunc('month', person.birth_date) AS "birth_date" + FROM person GROUP BY person.birth_date ORDER BY birth_date; + "#, + ]; + for sql in sqls { + let plan = logical_plan(sql).unwrap(); + allow_duplicates! { + assert_snapshot!( + plan, + // expect that this is not an ambiguous reference + @r#" + Sort: birth_date ASC NULLS LAST + Projection: avg(person.age) AS value, date_trunc(Utf8("month"), person.birth_date) AS birth_date + Aggregate: groupBy=[[person.birth_date]], aggr=[[avg(person.age)]] + TableScan: person + "# + ); + } + } // Use columnized `avg(age)` in the order by let sql = r#"SELECT @@ -2438,13 +3227,16 @@ fn select_groupby_orderby() { FROM person GROUP BY person.birth_date ORDER BY avg(age) + avg(age); "#; - let expected = - "Sort: avg(person.age) + avg(person.age) ASC NULLS LAST\ - \n Projection: avg(person.age) + avg(person.age), date_trunc(Utf8(\"month\"), person.birth_date) AS birth_date\ - \n Aggregate: groupBy=[[person.birth_date]], aggr=[[avg(person.age)]]\ - \n TableScan: person"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Sort: avg(person.age) + avg(person.age) ASC NULLS LAST + Projection: avg(person.age) + avg(person.age), date_trunc(Utf8("month"), person.birth_date) AS birth_date + Aggregate: groupBy=[[person.birth_date]], aggr=[[avg(person.age)]] + TableScan: person +"# + ); } fn logical_plan(sql: &str) -> Result { @@ -2506,6 +3298,7 @@ fn logical_plan_with_dialect_and_options( .with_aggregate_function(max_udaf()) .with_aggregate_function(grouping_udaf()) .with_window_function(rank_udwf()) + .with_window_function(row_number_udwf()) .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let context = MockContextProvider { state }; @@ -2520,7 +3313,7 @@ fn make_udf(name: &'static str, args: Vec, return_type: DataType) -> S } /// Mocked UDF -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct DummyUDF { name: &'static str, signature: Signature, @@ -2559,102 +3352,141 @@ impl ScalarUDFImpl for DummyUDF { } } -/// Create logical plan, write with formatter, compare to expected output -fn quick_test(sql: &str, expected: &str) { - quick_test_with_options(sql, expected, ParserOptions::default()) +fn parse_decimals_parser_options() -> ParserOptions { + ParserOptions { + parse_float_as_decimal: true, + enable_ident_normalization: false, + support_varchar_with_length: false, + map_string_types_to_utf8view: true, + enable_options_value_normalization: false, + collect_spans: false, + default_null_ordering: NullOrdering::NullsMax, + } } -fn quick_test_with_options(sql: &str, expected: &str, options: ParserOptions) { - let plan = logical_plan_with_options(sql, options).unwrap(); - assert_eq!(format!("{plan}"), expected); +fn ident_normalization_parser_options_no_ident_normalization() -> ParserOptions { + ParserOptions { + parse_float_as_decimal: true, + enable_ident_normalization: false, + support_varchar_with_length: false, + map_string_types_to_utf8view: true, + enable_options_value_normalization: false, + collect_spans: false, + default_null_ordering: NullOrdering::NullsMax, + } } -fn prepare_stmt_quick_test( - sql: &str, - expected_plan: &str, - expected_data_types: &str, -) -> LogicalPlan { - let plan = logical_plan(sql).unwrap(); - - let assert_plan = plan.clone(); - // verify plan - assert_eq!(format!("{assert_plan}"), expected_plan); - - // verify data types - if let LogicalPlan::Statement(Statement::Prepare(Prepare { data_types, .. })) = - assert_plan - { - let dt = format!("{data_types:?}"); - assert_eq!(dt, expected_data_types); +fn ident_normalization_parser_options_ident_normalization() -> ParserOptions { + ParserOptions { + parse_float_as_decimal: true, + enable_ident_normalization: true, + support_varchar_with_length: false, + map_string_types_to_utf8view: true, + enable_options_value_normalization: false, + collect_spans: false, + default_null_ordering: NullOrdering::NullsMax, } - - plan } #[test] fn select_partially_qualified_column() { - let sql = r#"SELECT person.first_name FROM public.person"#; - let expected = "Projection: public.person.first_name\ - \n TableScan: public.person"; - quick_test(sql, expected); + let sql = "SELECT person.first_name FROM public.person"; + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: public.person.first_name + TableScan: public.person +"# + ); } #[test] fn cross_join_not_to_inner_join() { let sql = "select person.id from person, orders, lineitem where person.id = person.age;"; - let expected = "Projection: person.id\ - \n Filter: person.id = person.age\ - \n Cross Join: \ - \n Cross Join: \ - \n TableScan: person\ - \n TableScan: orders\ - \n TableScan: lineitem"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id + Filter: person.id = person.age + Cross Join: + Cross Join: + TableScan: person + TableScan: orders + TableScan: lineitem +"# + ); } #[test] fn join_with_aliases() { let sql = "select peeps.id, folks.first_name from person as peeps join person as folks on peeps.id = folks.id"; - let expected = "Projection: peeps.id, folks.first_name\ - \n Inner Join: Filter: peeps.id = folks.id\ - \n SubqueryAlias: peeps\ - \n TableScan: person\ - \n SubqueryAlias: folks\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: peeps.id, folks.first_name + Inner Join: Filter: peeps.id = folks.id + SubqueryAlias: peeps + TableScan: person + SubqueryAlias: folks + TableScan: person +"# + ); } #[test] fn negative_interval_plus_interval_in_projection() { let sql = "select -interval '2 days' + interval '5 days';"; - let expected = - "Projection: IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: -2, nanoseconds: 0 }\") + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }\")\n EmptyRelation"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: -2, nanoseconds: 0 }") + IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }") + EmptyRelation: rows=1 + "# + ); } #[test] fn complex_interval_expression_in_projection() { let sql = "select -interval '2 days' + interval '5 days'+ (-interval '3 days' + interval '5 days');"; - let expected = - "Projection: IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: -2, nanoseconds: 0 }\") + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }\") + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: -3, nanoseconds: 0 }\") + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }\")\n EmptyRelation"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: -2, nanoseconds: 0 }") + IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }") + IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: -3, nanoseconds: 0 }") + IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }") + EmptyRelation: rows=1 + "# + ); } #[test] fn negative_sum_intervals_in_projection() { let sql = "select -((interval '2 days' + interval '5 days') + -(interval '4 days' + interval '7 days'));"; - let expected = - "Projection: (- IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 2, nanoseconds: 0 }\") + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }\") + (- IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 4, nanoseconds: 0 }\") + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 7, nanoseconds: 0 }\")))\n EmptyRelation"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: (- IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 2, nanoseconds: 0 }") + IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }") + (- IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 4, nanoseconds: 0 }") + IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 7, nanoseconds: 0 }"))) + EmptyRelation: rows=1 + "# + ); } #[test] fn date_plus_interval_in_projection() { let sql = "select t_date32 + interval '5 days' FROM test"; - let expected = - "Projection: test.t_date32 + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }\")\n TableScan: test"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: test.t_date32 + IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }") + TableScan: test +"# + ); } #[test] @@ -2663,11 +3495,15 @@ fn date_plus_interval_in_filter() { WHERE t_date64 \ BETWEEN cast('1999-12-31' as date) \ AND cast('1999-12-31' as date) + interval '30 days'"; - let expected = - "Projection: test.t_date64\ - \n Filter: test.t_date64 BETWEEN CAST(Utf8(\"1999-12-31\") AS Date32) AND CAST(Utf8(\"1999-12-31\") AS Date32) + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 30, nanoseconds: 0 }\")\ - \n TableScan: test"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: test.t_date64 + Filter: test.t_date64 BETWEEN CAST(Utf8("1999-12-31") AS Date32) AND CAST(Utf8("1999-12-31") AS Date32) + IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 30, nanoseconds: 0 }") + TableScan: test +"# + ); } #[test] @@ -2676,16 +3512,20 @@ fn exists_subquery() { (SELECT first_name FROM person \ WHERE last_name = p.last_name \ AND state = p.state)"; - - let expected = "Projection: p.id\ - \n Filter: EXISTS ()\ - \n Subquery:\ - \n Projection: person.first_name\ - \n Filter: person.last_name = outer_ref(p.last_name) AND person.state = outer_ref(p.state)\ - \n TableScan: person\ - \n SubqueryAlias: p\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: p.id + Filter: EXISTS () + Subquery: + Projection: person.first_name + Filter: person.last_name = outer_ref(p.last_name) AND person.state = outer_ref(p.state) + TableScan: person + SubqueryAlias: p + TableScan: person +"# + ); } #[test] @@ -2697,68 +3537,84 @@ fn exists_subquery_schema_outer_schema_overlap() { WHERE person.id = p2.id \ AND person.last_name = p.last_name \ AND person.state = p.state)"; - - let expected = "Projection: person.id\ - \n Filter: person.id = p.id AND EXISTS ()\ - \n Subquery:\ - \n Projection: person.first_name\ - \n Filter: person.id = p2.id AND person.last_name = outer_ref(p.last_name) AND person.state = outer_ref(p.state)\ - \n Cross Join: \ - \n TableScan: person\ - \n SubqueryAlias: p2\ - \n TableScan: person\ - \n Cross Join: \ - \n TableScan: person\ - \n SubqueryAlias: p\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id + Filter: person.id = p.id AND EXISTS () + Subquery: + Projection: person.first_name + Filter: person.id = p2.id AND person.last_name = outer_ref(p.last_name) AND person.state = outer_ref(p.state) + Cross Join: + TableScan: person + SubqueryAlias: p2 + TableScan: person + Cross Join: + TableScan: person + SubqueryAlias: p + TableScan: person +"# + ); } #[test] fn in_subquery_uncorrelated() { let sql = "SELECT id FROM person p WHERE id IN \ (SELECT id FROM person)"; - - let expected = "Projection: p.id\ - \n Filter: p.id IN ()\ - \n Subquery:\ - \n Projection: person.id\ - \n TableScan: person\ - \n SubqueryAlias: p\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: p.id + Filter: p.id IN () + Subquery: + Projection: person.id + TableScan: person + SubqueryAlias: p + TableScan: person +"# + ); } #[test] fn not_in_subquery_correlated() { let sql = "SELECT id FROM person p WHERE id NOT IN \ (SELECT id FROM person WHERE last_name = p.last_name AND state = 'CO')"; - - let expected = "Projection: p.id\ - \n Filter: p.id NOT IN ()\ - \n Subquery:\ - \n Projection: person.id\ - \n Filter: person.last_name = outer_ref(p.last_name) AND person.state = Utf8(\"CO\")\ - \n TableScan: person\ - \n SubqueryAlias: p\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: p.id + Filter: p.id NOT IN () + Subquery: + Projection: person.id + Filter: person.last_name = outer_ref(p.last_name) AND person.state = Utf8("CO") + TableScan: person + SubqueryAlias: p + TableScan: person +"# + ); } #[test] fn scalar_subquery() { let sql = "SELECT p.id, (SELECT MAX(id) FROM person WHERE last_name = p.last_name) FROM person p"; - - let expected = "Projection: p.id, ()\ - \n Subquery:\ - \n Projection: max(person.id)\ - \n Aggregate: groupBy=[[]], aggr=[[max(person.id)]]\ - \n Filter: person.last_name = outer_ref(p.last_name)\ - \n TableScan: person\ - \n SubqueryAlias: p\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: p.id, () + Subquery: + Projection: max(person.id) + Aggregate: groupBy=[[]], aggr=[[max(person.id)]] + Filter: person.last_name = outer_ref(p.last_name) + TableScan: person + SubqueryAlias: p + TableScan: person +"# + ); } #[test] @@ -2770,41 +3626,54 @@ fn scalar_subquery_reference_outer_field() { FROM j1, j3 \ WHERE j2_id = j1_id \ AND j1_id = j3_id)"; - - let expected = "Projection: j1.j1_string, j2.j2_string\ - \n Filter: j1.j1_id = j2.j2_id - Int64(1) AND j2.j2_id < ()\ - \n Subquery:\ - \n Projection: count(*)\ - \n Aggregate: groupBy=[[]], aggr=[[count(*)]]\ - \n Filter: outer_ref(j2.j2_id) = j1.j1_id AND j1.j1_id = j3.j3_id\ - \n Cross Join: \ - \n TableScan: j1\ - \n TableScan: j3\ - \n Cross Join: \ - \n TableScan: j1\ - \n TableScan: j2"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: j1.j1_string, j2.j2_string + Filter: j1.j1_id = j2.j2_id - Int64(1) AND j2.j2_id < () + Subquery: + Projection: count(*) + Aggregate: groupBy=[[]], aggr=[[count(*)]] + Filter: outer_ref(j2.j2_id) = j1.j1_id AND j1.j1_id = j3.j3_id + Cross Join: + TableScan: j1 + TableScan: j3 + Cross Join: + TableScan: j1 + TableScan: j2 +"# + ); } #[test] fn aggregate_with_rollup() { let sql = "SELECT id, state, age, count(*) FROM person GROUP BY id, ROLLUP (state, age)"; - let expected = "Projection: person.id, person.state, person.age, count(*)\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[count(*)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, person.state, person.age, count(*) + Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[count(*)]] + TableScan: person +"# + ); } #[test] fn aggregate_with_rollup_with_grouping() { let sql = "SELECT id, state, age, grouping(state), grouping(age), grouping(state) + grouping(age), count(*) \ FROM person GROUP BY id, ROLLUP (state, age)"; - let expected = "Projection: person.id, person.state, person.age, grouping(person.state), grouping(person.age), grouping(person.state) + grouping(person.age), count(*)\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[grouping(person.state), grouping(person.age), count(*)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, person.state, person.age, grouping(person.state), grouping(person.age), grouping(person.state) + grouping(person.age), count(*) + Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[grouping(person.state), grouping(person.age), count(*)]] + TableScan: person +"# + ); } #[test] @@ -2822,50 +3691,75 @@ fn rank_partition_grouping() { from person group by rollup(state, last_name)"; - let expected = "Projection: sum(person.age) AS total_sum, person.state, person.last_name, grouping(person.state) + grouping(person.last_name) AS x, rank() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS the_rank\ - \n WindowAggr: windowExpr=[[rank() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n Aggregate: groupBy=[[ROLLUP (person.state, person.last_name)]], aggr=[[sum(person.age), grouping(person.state), grouping(person.last_name)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: sum(person.age) AS total_sum, person.state, person.last_name, grouping(person.state) + grouping(person.last_name) AS x, rank() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS the_rank + WindowAggr: windowExpr=[[rank() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + Aggregate: groupBy=[[ROLLUP (person.state, person.last_name)]], aggr=[[sum(person.age), grouping(person.state), grouping(person.last_name)]] + TableScan: person +"# + ); } #[test] fn aggregate_with_cube() { let sql = "SELECT id, state, age, count(*) FROM person GROUP BY id, CUBE (state, age)"; - let expected = "Projection: person.id, person.state, person.age, count(*)\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.age), (person.id, person.state, person.age))]], aggr=[[count(*)]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, person.state, person.age, count(*) + Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.age), (person.id, person.state, person.age))]], aggr=[[count(*)]] + TableScan: person +"# + ); } #[test] fn round_decimal() { let sql = "SELECT round(price/3, 2) FROM test_decimal"; - let expected = "Projection: round(test_decimal.price / Int64(3), Int64(2))\ - \n TableScan: test_decimal"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: round(test_decimal.price / Int64(3), Int64(2)) + TableScan: test_decimal +"# + ); } #[test] fn aggregate_with_grouping_sets() { let sql = "SELECT id, state, age, count(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))"; - let expected = "Projection: person.id, person.state, person.age, count(*)\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id, person.state), (person.id, person.state, person.age), (person.id, person.id, person.state))]], aggr=[[count(*)]]\ - \n TableScan: person"; - quick_test(sql, expected); -} - -#[test] + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, person.state, person.age, count(*) + Aggregate: groupBy=[[GROUPING SETS ((person.id, person.state), (person.id, person.state, person.age), (person.id, person.id, person.state))]], aggr=[[count(*)]] + TableScan: person +"# + ); +} + +#[test] fn join_on_disjunction_condition() { let sql = "SELECT id, order_id \ FROM person \ JOIN orders ON id = customer_id OR person.age > 30"; - let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: Filter: person.id = orders.customer_id OR person.age > Int64(30)\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id OR person.age > Int64(30) + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -2873,11 +3767,16 @@ fn join_on_complex_condition() { let sql = "SELECT id, order_id \ FROM person \ JOIN orders ON id = customer_id AND (person.age > 30 OR person.last_name = 'X')"; - let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: Filter: person.id = orders.customer_id AND (person.age > Int64(30) OR person.last_name = Utf8(\"X\"))\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND (person.age > Int64(30) OR person.last_name = Utf8("X")) + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -2888,11 +3787,11 @@ fn hive_aggregate_with_filter() -> Result<()> { assert_snapshot!( plan, - @r###" + @r##" Projection: sum(person.age) FILTER (WHERE person.age > Int64(4)) Aggregate: groupBy=[[]], aggr=[[sum(person.age) FILTER (WHERE person.age > Int64(4))]] TableScan: person - "### + "## ); Ok(()) @@ -2905,13 +3804,18 @@ fn order_by_unaliased_name() { // SchemaError(FieldNotFound { qualifier: Some("p"), name: "state", valid_fields: ["z", "q"] }) let sql = "select p.state z, sum(age) q from person p group by p.state order by p.state"; - let expected = "Projection: z, q\ - \n Sort: p.state ASC NULLS LAST\ - \n Projection: p.state AS z, sum(p.age) AS q, p.state\ - \n Aggregate: groupBy=[[p.state]], aggr=[[sum(p.age)]]\ - \n SubqueryAlias: p\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: z, q + Sort: p.state ASC NULLS LAST + Projection: p.state AS z, sum(p.age) AS q, p.state + Aggregate: groupBy=[[p.state]], aggr=[[sum(p.age)]] + SubqueryAlias: p + TableScan: person +"# + ); } #[test] @@ -2943,54 +3847,87 @@ fn group_by_ambiguous_name() { #[test] fn test_zero_offset_with_limit() { let sql = "select id from person where person.id > 100 LIMIT 5 OFFSET 0;"; - let expected = "Limit: skip=0, fetch=5\ - \n Projection: person.id\ - \n Filter: person.id > Int64(100)\ - \n TableScan: person"; - quick_test(sql, expected); - + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Limit: skip=0, fetch=5 + Projection: person.id + Filter: person.id > Int64(100) + TableScan: person +"# + ); // Flip the order of LIMIT and OFFSET in the query. Plan should remain the same. let sql = "SELECT id FROM person WHERE person.id > 100 OFFSET 0 LIMIT 5;"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Limit: skip=0, fetch=5 + Projection: person.id + Filter: person.id > Int64(100) + TableScan: person +"# + ); } #[test] fn test_offset_no_limit() { let sql = "SELECT id FROM person WHERE person.id > 100 OFFSET 5;"; - let expected = "Limit: skip=5, fetch=None\ - \n Projection: person.id\ - \n Filter: person.id > Int64(100)\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Limit: skip=5, fetch=None + Projection: person.id + Filter: person.id > Int64(100) + TableScan: person +"# + ); } #[test] fn test_offset_after_limit() { let sql = "select id from person where person.id > 100 LIMIT 5 OFFSET 3;"; - let expected = "Limit: skip=3, fetch=5\ - \n Projection: person.id\ - \n Filter: person.id > Int64(100)\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Limit: skip=3, fetch=5 + Projection: person.id + Filter: person.id > Int64(100) + TableScan: person +"# + ); } #[test] fn test_offset_before_limit() { let sql = "select id from person where person.id > 100 OFFSET 3 LIMIT 5;"; - let expected = "Limit: skip=3, fetch=5\ - \n Projection: person.id\ - \n Filter: person.id > Int64(100)\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Limit: skip=3, fetch=5 + Projection: person.id + Filter: person.id > Int64(100) + TableScan: person +"# + ); } #[test] fn test_distribute_by() { let sql = "select id from person distribute by state"; - let expected = "Repartition: DistributeBy(person.state)\ - \n Projection: person.id\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Repartition: DistributeBy(person.state) + Projection: person.id + TableScan: person +"# + ); } #[test] @@ -3018,12 +3955,16 @@ fn test_constant_expr_eq_join() { FROM person \ INNER JOIN orders \ ON person.id = 10"; - - let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: Filter: person.id = Int64(10)\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id + Inner Join: Filter: person.id = Int64(10) + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -3032,13 +3973,16 @@ fn test_right_left_expr_eq_join() { FROM person \ INNER JOIN orders \ ON orders.customer_id * 2 = person.id + 10"; - - let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: Filter: orders.customer_id * Int64(2) = person.id + Int64(10)\ - \n TableScan: person\ - \n TableScan: orders"; - - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id + Inner Join: Filter: orders.customer_id * Int64(2) = person.id + Int64(10) + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -3047,12 +3991,16 @@ fn test_single_column_expr_eq_join() { FROM person \ INNER JOIN orders \ ON person.id + 10 = orders.customer_id * 2"; - - let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: Filter: person.id + Int64(10) = orders.customer_id * Int64(2)\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id + Inner Join: Filter: person.id + Int64(10) = orders.customer_id * Int64(2) + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -3061,12 +4009,16 @@ fn test_multiple_column_expr_eq_join() { FROM person \ INNER JOIN orders \ ON person.id + person.age + 10 = orders.customer_id * 2 - orders.price"; - - let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: Filter: person.id + person.age + Int64(10) = orders.customer_id * Int64(2) - orders.price\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id + Inner Join: Filter: person.id + person.age + Int64(10) = orders.customer_id * Int64(2) - orders.price + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -3075,12 +4027,16 @@ fn test_left_expr_eq_join() { FROM person \ INNER JOIN orders \ ON person.id + person.age + 10 = orders.customer_id"; - - let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: Filter: person.id + person.age + Int64(10) = orders.customer_id\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id + Inner Join: Filter: person.id + person.age + Int64(10) = orders.customer_id + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -3089,12 +4045,16 @@ fn test_right_expr_eq_join() { FROM person \ INNER JOIN orders \ ON person.id = orders.customer_id * 2 - orders.price"; - - let expected = "Projection: person.id, orders.order_id\ - \n Inner Join: Filter: person.id = orders.customer_id * Int64(2) - orders.price\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id * Int64(2) - orders.price + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -3103,38 +4063,58 @@ fn test_noneq_with_filter_join() { let sql = "SELECT person.id, person.first_name \ FROM person INNER JOIN orders \ ON person.age > 10"; - let expected = "Projection: person.id, person.first_name\ - \n Inner Join: Filter: person.age > Int64(10)\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, person.first_name + Inner Join: Filter: person.age > Int64(10) + TableScan: person + TableScan: orders +"# + ); // left join let sql = "SELECT person.id, person.first_name \ FROM person LEFT JOIN orders \ ON person.age > 10"; - let expected = "Projection: person.id, person.first_name\ - \n Left Join: Filter: person.age > Int64(10)\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, person.first_name + Left Join: Filter: person.age > Int64(10) + TableScan: person + TableScan: orders +"# + ); // right join let sql = "SELECT person.id, person.first_name \ FROM person RIGHT JOIN orders \ ON person.age > 10"; - let expected = "Projection: person.id, person.first_name\ - \n Right Join: Filter: person.age > Int64(10)\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, person.first_name + Right Join: Filter: person.age > Int64(10) + TableScan: person + TableScan: orders +"# + ); // full join let sql = "SELECT person.id, person.first_name \ FROM person FULL JOIN orders \ ON person.age > 10"; - let expected = "Projection: person.id, person.first_name\ - \n Full Join: Filter: person.age > Int64(10)\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, person.first_name + Full Join: Filter: person.age > Int64(10) + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -3145,842 +4125,296 @@ fn test_one_side_constant_full_join() { FROM person \ FULL OUTER JOIN orders \ ON person.id = 10"; - - let expected = "Projection: person.id, orders.order_id\ - \n Full Join: Filter: person.id = Int64(10)\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); -} - -#[test] -fn test_select_join_key_inner_join() { - let sql = "SELECT orders.customer_id * 2, person.id + 10 - FROM person - INNER JOIN orders - ON orders.customer_id * 2 = person.id + 10"; - - let expected = "Projection: orders.customer_id * Int64(2), person.id + Int64(10)\ - \n Inner Join: Filter: orders.customer_id * Int64(2) = person.id + Int64(10)\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); -} - -#[test] -fn test_select_order_by() { - let sql = "SELECT '1' from person order by id"; - - let expected = "Projection: Utf8(\"1\")\n Sort: person.id ASC NULLS LAST\n Projection: Utf8(\"1\"), person.id\n TableScan: person"; - quick_test(sql, expected); -} - -#[test] -fn test_select_distinct_order_by() { - let sql = "SELECT distinct '1' from person order by id"; - - // It should return error. - let result = logical_plan(sql); - assert!(result.is_err()); - let err = result.err().unwrap().strip_backtrace(); - - assert_snapshot!( - err, - @r###" - Error during planning: For SELECT DISTINCT, ORDER BY expressions person.id must appear in select list - "### - ); -} - -#[rstest] -#[case::select_cluster_by_unsupported( - "SELECT customer_name, sum(order_total) as total_order_amount FROM orders CLUSTER BY customer_name", - "This feature is not implemented: CLUSTER BY" -)] -#[case::select_lateral_view_unsupported( - "SELECT id, number FROM person LATERAL VIEW explode(numbers) exploded_table AS number", - "This feature is not implemented: LATERAL VIEWS" -)] -#[case::select_qualify_unsupported( - "SELECT i, p, o FROM person QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1", - "This feature is not implemented: QUALIFY" -)] -#[case::select_top_unsupported( - "SELECT TOP (5) * FROM person", - "This feature is not implemented: TOP" -)] -#[case::select_sort_by_unsupported( - "SELECT * FROM person SORT BY id", - "This feature is not implemented: SORT BY" -)] -#[test] -fn test_select_unsupported_syntax_errors(#[case] sql: &str, #[case] error: &str) { - let err = logical_plan(sql).unwrap_err(); - assert_eq!(err.strip_backtrace(), error) -} - -#[test] -fn select_order_by_with_cast() { - let sql = - "SELECT first_name AS first_name FROM (SELECT first_name AS first_name FROM person) ORDER BY CAST(first_name as INT)"; - let expected = "Sort: CAST(person.first_name AS Int32) ASC NULLS LAST\ - \n Projection: person.first_name\ - \n Projection: person.first_name\ - \n TableScan: person"; - quick_test(sql, expected); -} - -#[test] -fn test_avoid_add_alias() { - // avoiding adding an alias if the column name is the same. - // plan1 = plan2 - let sql = "select person.id as id from person order by person.id"; - let plan1 = logical_plan(sql).unwrap(); - let sql = "select id from person order by id"; - let plan2 = logical_plan(sql).unwrap(); - assert_eq!(format!("{plan1:?}"), format!("{plan2:?}")); -} - -#[test] -fn test_duplicated_left_join_key_inner_join() { - // person.id * 2 happen twice in left side. - let sql = "SELECT person.id, person.age - FROM person - INNER JOIN orders - ON person.id * 2 = orders.customer_id + 10 and person.id * 2 = orders.order_id"; - - let expected = "Projection: person.id, person.age\ - \n Inner Join: Filter: person.id * Int64(2) = orders.customer_id + Int64(10) AND person.id * Int64(2) = orders.order_id\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); -} - -#[test] -fn test_duplicated_right_join_key_inner_join() { - // orders.customer_id + 10 happen twice in right side. - let sql = "SELECT person.id, person.age - FROM person - INNER JOIN orders - ON person.id * 2 = orders.customer_id + 10 and person.id = orders.customer_id + 10"; - - let expected = "Projection: person.id, person.age\ - \n Inner Join: Filter: person.id * Int64(2) = orders.customer_id + Int64(10) AND person.id = orders.customer_id + Int64(10)\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); -} - -#[test] -fn test_ambiguous_column_references_in_on_join() { - let sql = "select p1.id, p1.age, p2.id - from person as p1 - INNER JOIN person as p2 - ON id = 1"; - - // It should return error. - let result = logical_plan(sql); - assert!(result.is_err()); - let err = result.err().unwrap().strip_backtrace(); - - assert_snapshot!( - err, - @r###" - Schema error: Ambiguous reference to unqualified field id - "### - ); -} - -#[test] -fn test_ambiguous_column_references_with_in_using_join() { - let sql = "select p1.id, p1.age, p2.id - from person as p1 - INNER JOIN person as p2 - using(id)"; - - let expected = "Projection: p1.id, p1.age, p2.id\ - \n Inner Join: Using p1.id = p2.id\ - \n SubqueryAlias: p1\ - \n TableScan: person\ - \n SubqueryAlias: p2\ - \n TableScan: person"; - quick_test(sql, expected); -} - -#[test] -fn test_prepare_statement_to_plan_panic_param_format() { - // param is not number following the $ sign - // panic due to error returned from the parser - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo"; - - assert_snapshot!( - logical_plan(sql).unwrap_err().strip_backtrace(), - @r###" - Error during planning: Invalid placeholder, not a number: $foo - "### - ); -} - -#[test] -fn test_prepare_statement_to_plan_panic_param_zero() { - // param is zero following the $ sign - // panic due to error returned from the parser - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $0"; - - assert_snapshot!( - logical_plan(sql).unwrap_err().strip_backtrace(), - @r###" - Error during planning: Invalid placeholder, zero is not a valid index: $0 - "### - ); -} - -#[test] -fn test_prepare_statement_to_plan_panic_prepare_wrong_syntax() { - // param is not number following the $ sign - // panic due to error returned from the parser - let sql = "PREPARE AS SELECT id, age FROM person WHERE age = $foo"; - assert!(logical_plan(sql) - .unwrap_err() - .strip_backtrace() - .contains("Expected: AS, found: SELECT")) -} - -#[test] -fn test_prepare_statement_to_plan_panic_no_relation_and_constant_param() { - let sql = "PREPARE my_plan(INT) AS SELECT id + $1"; - let expected = "Schema error: No field named id."; - assert_eq!(logical_plan(sql).unwrap_err().strip_backtrace(), expected); -} - -#[test] -fn test_prepare_statement_should_infer_types() { - // only provide 1 data type while using 2 params - let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1 + $2"; - let plan = logical_plan(sql).unwrap(); - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::Int64)), - ]); - assert_eq!(actual_types, expected_types); -} - -#[test] -fn test_non_prepare_statement_should_infer_types() { - // Non prepared statements (like SELECT) should also have their parameter types inferred - let sql = "SELECT 1 + $1"; let plan = logical_plan(sql).unwrap(); - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - // constant 1 is inferred to be int64 - ("$1".to_string(), Some(DataType::Int64)), - ]); - assert_eq!(actual_types, expected_types); -} - -#[test] -#[should_panic( - expected = "Expected: [NOT] NULL | TRUE | FALSE | DISTINCT | [form] NORMALIZED FROM after IS, found: $1" -)] -fn test_prepare_statement_to_plan_panic_is_param() { - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age is $1"; - logical_plan(sql).unwrap(); -} - -#[test] -fn test_prepare_statement_to_plan_no_param() { - // no embedded parameter but still declare it - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; - - let expected_plan = "Prepare: \"my_plan\" [Int32] \ - \n Projection: person.id, person.age\ - \n Filter: person.age = Int64(10)\ - \n TableScan: person"; - - let expected_dt = "[Int32]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - - /////////////////// - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Int64(10) - TableScan: person - " - ); - - ////////////////////////////////////////// - // no embedded parameter and no declare it - let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; - - let expected_plan = "Prepare: \"my_plan\" [] \ - \n Projection: person.id, person.age\ - \n Filter: person.age = Int64(10)\ - \n TableScan: person"; - - let expected_dt = "[]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - - /////////////////// - // replace params with values - let param_values: Vec = vec![]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Int64(10) - TableScan: person - " - ); -} - -#[test] -fn test_prepare_statement_to_plan_one_param_no_value_panic() { - // no embedded parameter but still declare it - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; - let plan = logical_plan(sql).unwrap(); - // declare 1 param but provide 0 - let param_values: Vec = vec![]; - - assert_snapshot!( - plan.with_param_values(param_values) - .unwrap_err() - .strip_backtrace(), - @r###" - Error during planning: Expected 1 parameters, got 0 - "###); -} - -#[test] -fn test_prepare_statement_to_plan_one_param_one_value_different_type_panic() { - // no embedded parameter but still declare it - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; - let plan = logical_plan(sql).unwrap(); - // declare 1 param but provide 0 - let param_values = vec![ScalarValue::Float64(Some(20.0))]; - - assert_snapshot!( - plan.with_param_values(param_values) - .unwrap_err() - .strip_backtrace(), - @r###" - Error during planning: Expected parameter of type Int32, got Float64 at index 0 - "### - ); -} - -#[test] -fn test_prepare_statement_to_plan_no_param_on_value_panic() { - // no embedded parameter but still declare it - let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; - let plan = logical_plan(sql).unwrap(); - // declare 1 param but provide 0 - let param_values = vec![ScalarValue::Int32(Some(10))]; - - assert_snapshot!( - plan.with_param_values(param_values) - .unwrap_err() - .strip_backtrace(), - @r###" - Error during planning: Expected 0 parameters, got 1 - "### - ); -} - -#[test] -fn test_prepare_statement_to_plan_params_as_constants() { - let sql = "PREPARE my_plan(INT) AS SELECT $1"; - - let expected_plan = "Prepare: \"my_plan\" [Int32] \ - \n Projection: $1\n EmptyRelation"; - let expected_dt = "[Int32]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - - /////////////////// - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: Int32(10) AS $1 - EmptyRelation - " - ); - - /////////////////////////////////////// - let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1"; - - let expected_plan = "Prepare: \"my_plan\" [Int32] \ - \n Projection: Int64(1) + $1\n EmptyRelation"; - let expected_dt = "[Int32]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - - /////////////////// - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: Int64(1) + Int32(10) AS Int64(1) + $1 - EmptyRelation - " - ); - - /////////////////////////////////////// - let sql = "PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2"; - - let expected_plan = "Prepare: \"my_plan\" [Int32, Float64] \ - \n Projection: Int64(1) + $1 + $2\n EmptyRelation"; - let expected_dt = "[Int32, Float64]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - - /////////////////// - // replace params with values - let param_values = vec![ - ScalarValue::Int32(Some(10)), - ScalarValue::Float64(Some(10.0)), - ]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: Int64(1) + Int32(10) + Float64(10) AS Int64(1) + $1 + $2 - EmptyRelation - " - ); -} - -#[test] -fn test_prepare_statement_infer_types_from_join() { - let sql = - "SELECT id, order_id FROM person JOIN orders ON id = customer_id and age = $1"; - - let expected_plan = r#" + plan, + @r#" Projection: person.id, orders.order_id - Inner Join: Filter: person.id = orders.customer_id AND person.age = $1 + Full Join: Filter: person.id = Int64(10) TableScan: person TableScan: orders - "# - .trim(); - - let expected_dt = "[Int32]"; - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, orders.order_id - Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) - TableScan: person - TableScan: orders - " - ); -} - -#[test] -fn test_prepare_statement_infer_types_from_predicate() { - let sql = "SELECT id, age FROM person WHERE age = $1"; - - let expected_plan = r#" -Projection: person.id, person.age - Filter: person.age = $1 - TableScan: person - "# - .trim(); - - let expected_dt = "[Int32]"; - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Int32(10) - TableScan: person - " - ); -} - -#[test] -fn test_prepare_statement_infer_types_from_between_predicate() { - let sql = "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2"; - - let expected_plan = r#" -Projection: person.id, person.age - Filter: person.age BETWEEN $1 AND $2 - TableScan: person - "# - .trim(); - - let expected_dt = "[Int32]"; - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::Int32)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age BETWEEN Int32(10) AND Int32(30) - TableScan: person - " +"# ); } #[test] -fn test_prepare_statement_infer_types_subquery() { - let sql = "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)"; - - let expected_plan = r#" -Projection: person.id, person.age - Filter: person.age = () - Subquery: - Projection: max(person.age) - Aggregate: groupBy=[[]], aggr=[[max(person.age)]] - Filter: person.id = $1 - TableScan: person - TableScan: person - "# - .trim(); - - let expected_dt = "[Int32]"; - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::UInt32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::UInt32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - +fn test_select_join_key_inner_join() { + let sql = "SELECT orders.customer_id * 2, person.id + 10 + FROM person + INNER JOIN orders + ON orders.customer_id * 2 = person.id + 10"; + let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = () - Subquery: - Projection: max(person.age) - Aggregate: groupBy=[[]], aggr=[[max(person.age)]] - Filter: person.id = UInt32(10) - TableScan: person - TableScan: person - " + plan, + @r#" +Projection: orders.customer_id * Int64(2), person.id + Int64(10) + Inner Join: Filter: orders.customer_id * Int64(2) = person.id + Int64(10) + TableScan: person + TableScan: orders +"# ); } #[test] -fn test_prepare_statement_update_infer() { - let sql = "update person set age=$1 where id=$2"; - - let expected_plan = r#" -Dml: op=[Update] table=[person] - Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 - Filter: person.id = $2 +fn test_select_order_by() { + let sql = "SELECT '1' from person order by id"; + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: Utf8("1") + Sort: person.id ASC NULLS LAST + Projection: Utf8("1"), person.id TableScan: person - "# - .trim(); - - let expected_dt = "[Int32]"; - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); +"# + ); +} - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::UInt32)), - ]); - assert_eq!(actual_types, expected_types); +#[test] +fn test_select_distinct_order_by() { + let sql = "SELECT distinct '1' from person order by id"; - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); + // It should return error. + let result = logical_plan(sql); + assert!(result.is_err()); + let err = result.err().unwrap().strip_backtrace(); assert_snapshot!( - plan_with_params, - @r" - Dml: op=[Update] table=[person] - Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 - Filter: person.id = UInt32(1) - TableScan: person - " + err, + @r###" + Error during planning: For SELECT DISTINCT, ORDER BY expressions person.id must appear in select list + "### ); } #[test] -fn test_prepare_statement_insert_infer() { - let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)"; - - let expected_plan = "Dml: op=[Insert Into] table=[person]\ - \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ - CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ - CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ - \n Values: ($1, $2, $3)"; - - let expected_dt = "[Int32]"; - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::UInt32)), - ("$2".to_string(), Some(DataType::Utf8)), - ("$3".to_string(), Some(DataType::Utf8)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ - ScalarValue::UInt32(Some(1)), - ScalarValue::from("Alan"), - ScalarValue::from("Turing"), - ]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); +fn test_select_qualify_basic() { + let sql = "SELECT person.id, ROW_NUMBER() OVER (PARTITION BY person.age ORDER BY person.id) as rn FROM person QUALIFY rn = 1"; + let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan_with_params, + plan, @r#" - Dml: op=[Insert Into] table=[person] - Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 - Values: (UInt32(1) AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3) - "# +Projection: person.id, row_number() PARTITION BY [person.age] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn + Filter: row_number() PARTITION BY [person.age] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW = Int64(1) + WindowAggr: windowExpr=[[row_number() PARTITION BY [person.age] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: person +"# ); } #[test] -fn test_prepare_statement_to_plan_one_param() { - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $1"; - - let expected_plan = "Prepare: \"my_plan\" [Int32] \ - \n Projection: person.id, person.age\ - \n Filter: person.age = $1\ - \n TableScan: person"; - - let expected_dt = "[Int32]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - - /////////////////// - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - - let plan_with_params = plan.with_param_values(param_values).unwrap(); +fn test_select_qualify_aggregate_reference() { + let sql = " + SELECT + person.id, + ROW_NUMBER() OVER (PARTITION BY person.id ORDER BY person.id) as rn + FROM person + GROUP BY + person.id + QUALIFY rn = 1 AND SUM(person.age) > 0"; + let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan_with_params, + plan, @r" - Projection: person.id, person.age - Filter: person.age = Int32(10) - TableScan: person + Projection: person.id, row_number() PARTITION BY [person.id] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn + Filter: row_number() PARTITION BY [person.id] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW = Int64(1) AND sum(person.age) > Int64(0) + WindowAggr: windowExpr=[[row_number() PARTITION BY [person.id] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]] + TableScan: person " ); } #[test] -fn test_prepare_statement_to_plan_data_type() { - let sql = "PREPARE my_plan(DOUBLE) AS SELECT id, age FROM person WHERE age = $1"; - - // age is defined as Int32 but prepare statement declares it as DOUBLE/Float64 - // Prepare statement and its logical plan should be created successfully - let expected_plan = "Prepare: \"my_plan\" [Float64] \ - \n Projection: person.id, person.age\ - \n Filter: person.age = $1\ - \n TableScan: person"; - - let expected_dt = "[Float64]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - - /////////////////// - // replace params with values still succeed and use Float64 - let param_values = vec![ScalarValue::Float64(Some(10.0))]; - - let plan_with_params = plan.with_param_values(param_values).unwrap(); +fn test_select_qualify_aggregate_reference_within_window_function() { + let sql = " + SELECT + person.id + FROM person + GROUP BY + person.id + QUALIFY ROW_NUMBER() OVER (PARTITION BY person.id ORDER BY SUM(person.age) DESC) = 1"; + let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan_with_params, + plan, @r" - Projection: person.id, person.age - Filter: person.age = Float64(10) - TableScan: person + Projection: person.id + Filter: row_number() PARTITION BY [person.id] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW = Int64(1) + WindowAggr: windowExpr=[[row_number() PARTITION BY [person.id] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]] + TableScan: person " ); } #[test] -fn test_prepare_statement_to_plan_multi_params() { - let sql = "PREPARE my_plan(INT, STRING, DOUBLE, INT, DOUBLE, STRING) AS - SELECT id, age, $6 +fn test_select_qualify_aggregate_invalid_column_reference() { + let sql = " + SELECT + person.id FROM person - WHERE age IN ($1, $4) AND salary > $3 and salary < $5 OR first_name < $2"; - - let expected_plan = "Prepare: \"my_plan\" [Int32, Utf8, Float64, Int32, Float64, Utf8] \ - \n Projection: person.id, person.age, $6\ - \n Filter: person.age IN ([$1, $4]) AND person.salary > $3 AND person.salary < $5 OR person.first_name < $2\ - \n TableScan: person"; - - let expected_dt = "[Int32, Utf8, Float64, Int32, Float64, Utf8]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - - /////////////////// - // replace params with values - let param_values = vec![ - ScalarValue::Int32(Some(10)), - ScalarValue::from("abc"), - ScalarValue::Float64(Some(100.0)), - ScalarValue::Int32(Some(20)), - ScalarValue::Float64(Some(200.0)), - ScalarValue::from("xyz"), - ]; + GROUP BY + person.id + QUALIFY ROW_NUMBER() OVER (PARTITION BY person.id ORDER BY person.age DESC) = 1"; + let err = logical_plan(sql).unwrap_err(); + assert_snapshot!( + err.strip_backtrace(), + @r#"Error during planning: Column in QUALIFY must be in GROUP BY or an aggregate function: While expanding wildcard, column "person.age" must appear in the GROUP BY clause or must be part of an aggregate function, currently only "person.id" appears in the SELECT clause satisfies this requirement"# + ); +} + +#[test] +fn test_select_qualify_without_window_function() { + let sql = "SELECT person.id FROM person QUALIFY person.id > 1"; + let err = logical_plan(sql).unwrap_err(); + assert_eq!( + err.strip_backtrace(), + "Error during planning: QUALIFY clause requires window functions in the SELECT list or QUALIFY clause" + ); +} - let plan_with_params = plan.with_param_values(param_values).unwrap(); +#[test] +fn test_select_qualify_complex_condition() { + let sql = "SELECT person.id, person.age, ROW_NUMBER() OVER (PARTITION BY person.age ORDER BY person.id) as rn, RANK() OVER (ORDER BY person.salary) as rank FROM person QUALIFY rn <= 2 AND rank <= 5"; + let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan_with_params, + plan, @r#" - Projection: person.id, person.age, Utf8("xyz") AS $6 - Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary > Float64(100) AND person.salary < Float64(200) OR person.first_name < Utf8("abc") +Projection: person.id, person.age, row_number() PARTITION BY [person.age] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn, rank() ORDER BY [person.salary ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rank + Filter: row_number() PARTITION BY [person.age] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW <= Int64(2) AND rank() ORDER BY [person.salary ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW <= Int64(5) + WindowAggr: windowExpr=[[rank() ORDER BY [person.salary ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + WindowAggr: windowExpr=[[row_number() PARTITION BY [person.age] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] TableScan: person - "# +"# ); } +#[rstest] +#[case::select_cluster_by_unsupported( + "SELECT customer_name, sum(order_total) as total_order_amount FROM orders CLUSTER BY customer_name", + "This feature is not implemented: CLUSTER BY" +)] +#[case::select_lateral_view_unsupported( + "SELECT id, number FROM person LATERAL VIEW explode(numbers) exploded_table AS number", + "This feature is not implemented: LATERAL VIEWS" +)] +#[case::select_top_unsupported( + "SELECT TOP (5) * FROM person", + "This feature is not implemented: TOP" +)] +#[case::select_sort_by_unsupported( + "SELECT * FROM person SORT BY id", + "This feature is not implemented: SORT BY" +)] #[test] -fn test_prepare_statement_to_plan_having() { - let sql = "PREPARE my_plan(INT, DOUBLE, DOUBLE, DOUBLE) AS - SELECT id, sum(age) - FROM person \ - WHERE salary > $2 - GROUP BY id - HAVING sum(age) < $1 AND sum(age) > 10 OR sum(age) in ($3, $4)\ - "; - - let expected_plan = "Prepare: \"my_plan\" [Int32, Float64, Float64, Float64] \ - \n Projection: person.id, sum(person.age)\ - \n Filter: sum(person.age) < $1 AND sum(person.age) > Int64(10) OR sum(person.age) IN ([$3, $4])\ - \n Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]]\ - \n Filter: person.salary > $2\ - \n TableScan: person"; - - let expected_dt = "[Int32, Float64, Float64, Float64]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - - /////////////////// - // replace params with values - let param_values = vec![ - ScalarValue::Int32(Some(10)), - ScalarValue::Float64(Some(100.0)), - ScalarValue::Float64(Some(200.0)), - ScalarValue::Float64(Some(300.0)), - ]; +fn test_select_unsupported_syntax_errors(#[case] sql: &str, #[case] error: &str) { + let err = logical_plan(sql).unwrap_err(); + assert_eq!(err.strip_backtrace(), error) +} - let plan_with_params = plan.with_param_values(param_values).unwrap(); +#[test] +fn select_order_by_with_cast() { + let sql = + "SELECT first_name AS first_name FROM (SELECT first_name AS first_name FROM person) ORDER BY CAST(first_name as INT)"; + let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan_with_params, + plan, @r#" - Projection: person.id, sum(person.age) - Filter: sum(person.age) < Int32(10) AND sum(person.age) > Int64(10) OR sum(person.age) IN ([Float64(200), Float64(300)]) - Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]] - Filter: person.salary > Float64(100) - TableScan: person - "# +Sort: CAST(person.first_name AS Int32) ASC NULLS LAST + Projection: person.first_name + Projection: person.first_name + TableScan: person +"# ); } #[test] -fn test_prepare_statement_to_plan_limit() { - let sql = "PREPARE my_plan(BIGINT, BIGINT) AS - SELECT id FROM person \ - OFFSET $1 LIMIT $2"; - - let expected_plan = "Prepare: \"my_plan\" [Int64, Int64] \ - \n Limit: skip=$1, fetch=$2\ - \n Projection: person.id\ - \n TableScan: person"; - - let expected_dt = "[Int64, Int64]"; - - let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); +fn test_avoid_add_alias() { + // avoiding adding an alias if the column name is the same. + // plan1 = plan2 + let sql = "select person.id as id from person order by person.id"; + let plan1 = logical_plan(sql).unwrap(); + let sql = "select id from person order by id"; + let plan2 = logical_plan(sql).unwrap(); + assert_eq!(format!("{plan1:?}"), format!("{plan2:?}")); +} - // replace params with values - let param_values = vec![ScalarValue::Int64(Some(10)), ScalarValue::Int64(Some(200))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); +#[test] +fn test_duplicated_left_join_key_inner_join() { + // person.id * 2 happen twice in left side. + let sql = "SELECT person.id, person.age + FROM person + INNER JOIN orders + ON person.id * 2 = orders.customer_id + 10 and person.id * 2 = orders.order_id"; + let plan = logical_plan(sql).unwrap(); assert_snapshot!( - plan_with_params, + plan, @r#" - Limit: skip=10, fetch=200 - Projection: person.id - TableScan: person - "# +Projection: person.id, person.age + Inner Join: Filter: person.id * Int64(2) = orders.customer_id + Int64(10) AND person.id * Int64(2) = orders.order_id + TableScan: person + TableScan: orders +"# ); } #[test] -fn test_prepare_statement_unknown_list_param() { - let sql = "SELECT id from person where id = $2"; +fn test_duplicated_right_join_key_inner_join() { + // orders.customer_id + 10 happen twice in right side. + let sql = "SELECT person.id, person.age + FROM person + INNER JOIN orders + ON person.id * 2 = orders.customer_id + 10 and person.id = orders.customer_id + 10"; let plan = logical_plan(sql).unwrap(); - let param_values = ParamValues::List(vec![]); - let err = plan.replace_params_with_values(¶m_values).unwrap_err(); - assert_contains!( - err.to_string(), - "Error during planning: No value found for placeholder with id $2" + assert_snapshot!( + plan, + @r#" +Projection: person.id, person.age + Inner Join: Filter: person.id * Int64(2) = orders.customer_id + Int64(10) AND person.id = orders.customer_id + Int64(10) + TableScan: person + TableScan: orders +"# ); } #[test] -fn test_prepare_statement_unknown_hash_param() { - let sql = "SELECT id from person where id = $bar"; - let plan = logical_plan(sql).unwrap(); - let param_values = ParamValues::Map(HashMap::new()); - let err = plan.replace_params_with_values(¶m_values).unwrap_err(); - assert_contains!( - err.to_string(), - "Error during planning: No value found for placeholder with name $bar" +fn test_ambiguous_column_references_in_on_join() { + let sql = "select p1.id, p1.age, p2.id + from person as p1 + INNER JOIN person as p2 + ON id = 1"; + + // It should return error. + let result = logical_plan(sql); + assert!(result.is_err()); + let err = result.err().unwrap().strip_backtrace(); + + assert_snapshot!( + err, + @r###" + Schema error: Ambiguous reference to unqualified field id + "### ); } #[test] -fn test_prepare_statement_bad_list_idx() { - let sql = "SELECT id from person where id = $foo"; +fn test_ambiguous_column_references_with_in_using_join() { + let sql = "select p1.id, p1.age, p2.id + from person as p1 + INNER JOIN person as p2 + using(id)"; let plan = logical_plan(sql).unwrap(); - let param_values = ParamValues::List(vec![]); - - let err = plan.replace_params_with_values(¶m_values).unwrap_err(); - assert_contains!(err.to_string(), "Error during planning: Failed to parse placeholder id: invalid digit found in string"); + assert_snapshot!( + plan, + @r#" +Projection: p1.id, p1.age, p2.id + Inner Join: Using p1.id = p2.id + SubqueryAlias: p1 + TableScan: person + SubqueryAlias: p2 + TableScan: person +"# + ); } #[test] @@ -3989,12 +4423,16 @@ fn test_inner_join_with_cast_key() { FROM person INNER JOIN orders ON cast(person.id as Int) = cast(orders.customer_id as Int)"; - - let expected = "Projection: person.id, person.age\ - \n Inner Join: Filter: CAST(person.id AS Int32) = CAST(orders.customer_id AS Int32)\ - \n TableScan: person\ - \n TableScan: orders"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, person.age + Inner Join: Filter: CAST(person.id AS Int32) = CAST(orders.customer_id AS Int32) + TableScan: person + TableScan: orders +"# + ); } #[test] @@ -4004,29 +4442,30 @@ fn test_multi_grouping_sets() { GROUP BY person.id, GROUPING SETS ((person.age,person.salary),(person.age))"; - - let expected = "Projection: person.id, person.age\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id, person.age, person.salary), (person.id, person.age))]], aggr=[[]]\ - \n TableScan: person"; - quick_test(sql, expected); - + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, person.age + Aggregate: groupBy=[[GROUPING SETS ((person.id, person.age, person.salary), (person.id, person.age))]], aggr=[[]] + TableScan: person +"# + ); let sql = "SELECT person.id, person.age FROM person GROUP BY person.id, GROUPING SETS ((person.age, person.salary),(person.age)), ROLLUP(person.state, person.birth_date)"; - - let expected = "Projection: person.id, person.age\ - \n Aggregate: groupBy=[[GROUPING SETS (\ - (person.id, person.age, person.salary), \ - (person.id, person.age, person.salary, person.state), \ - (person.id, person.age, person.salary, person.state, person.birth_date), \ - (person.id, person.age), \ - (person.id, person.age, person.state), \ - (person.id, person.age, person.state, person.birth_date))]], aggr=[[]]\ - \n TableScan: person"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: person.id, person.age + Aggregate: groupBy=[[GROUPING SETS ((person.id, person.age, person.salary), (person.id, person.age, person.salary, person.state), (person.id, person.age, person.salary, person.state, person.birth_date), (person.id, person.age), (person.id, person.age, person.state), (person.id, person.age, person.state, person.birth_date))]], aggr=[[]] + TableScan: person +"# + ); } #[test] @@ -4055,30 +4494,46 @@ fn test_field_not_found_window_function() { "### ); - let qualified_sql = - "SELECT order_id, MAX(qty) OVER (PARTITION BY orders.order_id) from orders"; - let expected = "Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\n WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\n TableScan: orders"; - quick_test(qualified_sql, expected); + let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY orders.order_id) from orders"; + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" +Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: orders +"# + ); } #[test] fn test_parse_escaped_string_literal_value() { let sql = r"SELECT character_length('\r\n') AS len"; - let expected = "Projection: character_length(Utf8(\"\\r\\n\")) AS len\ - \n EmptyRelation"; - quick_test(sql, expected); - - let sql = r"SELECT character_length(E'\r\n') AS len"; - let expected = "Projection: character_length(Utf8(\"\r\n\")) AS len\ - \n EmptyRelation"; - quick_test(sql, expected); - + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: character_length(Utf8("\r\n")) AS len + EmptyRelation: rows=1 + "# + ); + let sql = "SELECT character_length(E'\r\n') AS len"; + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: character_length(Utf8(" + ")) AS len + EmptyRelation: rows=1 + "# + ); let sql = r"SELECT character_length(E'\445') AS len, E'\x4B' AS hex, E'\u0001' AS unicode"; - let expected = - "Projection: character_length(Utf8(\"%\")) AS len, Utf8(\"\u{004b}\") AS hex, Utf8(\"\u{0001}\") AS unicode\ - \n EmptyRelation"; - quick_test(sql, expected); + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @"Projection: character_length(Utf8(\"%\")) AS len, Utf8(\"K\") AS hex, Utf8(\"\u{1}\") AS unicode\n EmptyRelation: rows=1" + ); let sql = r"SELECT character_length(E'\000') AS len"; @@ -4132,7 +4587,7 @@ fn assert_field_not_found(mut err: DataFusionError, name: &str) { } }; match err { - DataFusionError::SchemaError { .. } => { + DataFusionError::SchemaError(_, _) => { let msg = format!("{err}"); let expected = format!("Schema error: No field named {name}."); if !msg.starts_with(&expected) { @@ -4170,6 +4625,30 @@ fn test_no_functions_registered() { ); } +#[test] +fn test_no_substring_registered() { + // substring requires an expression planner + let sql = "SELECT SUBSTRING(foo, bar, baz) FROM person"; + let err = logical_plan(sql).expect_err("query should have failed"); + + assert_snapshot!( + err.strip_backtrace(), + @"This feature is not implemented: Substring could not be planned by registered expr planner. Hint: Please try with `unicode_expressions` DataFusion feature enabled" + ); +} + +#[test] +fn test_no_substring_registered_alt_syntax() { + // Alternate syntax for substring + let sql = "SELECT SUBSTRING(foo FROM bar) FROM person"; + let err = logical_plan(sql).expect_err("query should have failed"); + + assert_snapshot!( + err.strip_backtrace(), + @"This feature is not implemented: Substring could not be planned by registered expr planner. Hint: Please try with `unicode_expressions` DataFusion feature enabled" + ); +} + #[test] fn test_custom_type_plan() -> Result<()> { let sql = "SELECT DATETIME '2001-01-01 18:00:00'"; @@ -4185,7 +4664,7 @@ fn test_custom_type_plan() -> Result<()> { let err = planner.statement_to_plan(ast.pop_front().unwrap()); assert_contains!( err.unwrap_err().to_string(), - "This feature is not implemented: Unsupported SQL type Datetime(None)" + "This feature is not implemented: Unsupported SQL type DATETIME" ); fn plan_sql(sql: &str) -> LogicalPlan { @@ -4206,20 +4685,20 @@ fn test_custom_type_plan() -> Result<()> { assert_snapshot!( plan, - @r###" - Projection: CAST(Utf8("2001-01-01 18:00:00") AS Timestamp(Nanosecond, None)) - EmptyRelation - "### + @r#" + Projection: CAST(Utf8("2001-01-01 18:00:00") AS Timestamp(Nanosecond, None)) + EmptyRelation: rows=1 + "# ); let plan = plan_sql("SELECT CAST(TIMESTAMP '2001-01-01 18:00:00' AS DATETIME)"); assert_snapshot!( plan, - @r###" - Projection: CAST(CAST(Utf8("2001-01-01 18:00:00") AS Timestamp(Nanosecond, None)) AS Timestamp(Nanosecond, None)) - EmptyRelation - "### + @r#" + Projection: CAST(CAST(Utf8("2001-01-01 18:00:00") AS Timestamp(Nanosecond, None)) AS Timestamp(Nanosecond, None)) + EmptyRelation: rows=1 + "# ); let plan = plan_sql( @@ -4228,10 +4707,10 @@ fn test_custom_type_plan() -> Result<()> { assert_snapshot!( plan, - @r###" - Projection: make_array(CAST(Utf8("2001-01-01 18:00:00") AS Timestamp(Nanosecond, None)), CAST(Utf8("2001-01-02 18:00:00") AS Timestamp(Nanosecond, None))) - EmptyRelation - "### + @r#" + Projection: make_array(CAST(Utf8("2001-01-01 18:00:00") AS Timestamp(Nanosecond, None)), CAST(Utf8("2001-01-02 18:00:00") AS Timestamp(Nanosecond, None))) + EmptyRelation: rows=1 + "# ); Ok(()) diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 4c7ee6c1bb865..d02d5f9cb5e44 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -42,27 +42,29 @@ async-trait = { workspace = true } bigdecimal = { workspace = true } bytes = { workspace = true, optional = true } chrono = { workspace = true, optional = true } -clap = { version = "4.5.34", features = ["derive", "env"] } -datafusion = { workspace = true, default-features = true, features = ["avro"] } +clap = { version = "4.5.47", features = ["derive", "env"] } +datafusion = { workspace = true, default-features = true, features = ["avro", "parquet_encryption"] } +datafusion-spark = { workspace = true, default-features = true } +datafusion-substrait = { workspace = true, default-features = true } futures = { workspace = true } half = { workspace = true, default-features = true } -indicatif = "0.17" +indicatif = "0.18" itertools = { workspace = true } log = { workspace = true } object_store = { workspace = true } postgres-protocol = { version = "0.6.7", optional = true } -postgres-types = { version = "0.2.8", features = ["derive", "with-chrono-0_4"], optional = true } -rust_decimal = { version = "1.37.1", features = ["tokio-pg"] } +postgres-types = { version = "0.2.10", features = ["derive", "with-chrono-0_4"], optional = true } +rust_decimal = { version = "1.38.0", features = ["tokio-pg"] } # When updating the following dependency verify that sqlite test file regeneration works correctly # by running the regenerate_sqlite_files.sh script. -sqllogictest = "0.28.0" +sqllogictest = "0.28.4" sqlparser = { workspace = true } tempfile = { workspace = true } -testcontainers = { version = "0.23", features = ["default"], optional = true } -testcontainers-modules = { version = "0.11", features = ["postgres"], optional = true } -thiserror = "2.0.12" +testcontainers = { workspace = true, optional = true } +testcontainers-modules = { workspace = true, features = ["postgres"], optional = true } +thiserror = "2.0.17" tokio = { workspace = true } -tokio-postgres = { version = "0.7.12", optional = true } +tokio-postgres = { version = "0.7.14", optional = true } [features] avro = ["datafusion/avro"] @@ -79,6 +81,7 @@ postgres = [ [dev-dependencies] env_logger = { workspace = true } +regex = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread"] } [[test]] diff --git a/datafusion/sqllogictest/README.md b/datafusion/sqllogictest/README.md index 77162f4001ae9..a389ae1ef60e2 100644 --- a/datafusion/sqllogictest/README.md +++ b/datafusion/sqllogictest/README.md @@ -17,23 +17,29 @@ under the License. --> -# DataFusion sqllogictest +# Apache DataFusion sqllogictest -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. -This crate is a submodule of DataFusion that contains an implementation of [sqllogictest](https://www.sqlite.org/sqllogictest/doc/trunk/about.wiki). +This crate is a submodule of DataFusion that contains an implementation of [sqllogictest]. -[df]: https://crates.io/crates/datafusion +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ +[sqllogictest]: https://www.sqlite.org/sqllogictest/doc/trunk/about.wiki ## Overview -This crate uses [sqllogictest-rs](https://github.com/risinglightdb/sqllogictest-rs) to parse and run `.slt` files in the -[`test_files`](test_files) directory of this crate or the [`data/sqlite`](https://github.com/apache/datafusion-testing/tree/main/data/sqlite) -directory of the [datafusion-testing](https://github.com/apache/datafusion-testing) crate. +This crate uses [sqllogictest-rs] to parse and run `.slt` files in the [`test_files`] directory of +this crate or the [`data/sqlite`] directory of the [datafusion-testing] repository. + +[sqllogictest-rs]: https://github.com/risinglightdb/sqllogictest-rs +[`test_files`]: test_files +[`data/sqlite`]: https://github.com/apache/datafusion-testing/tree/main/data/sqlite +[datafusion-testing]: https://github.com/apache/datafusion-testing ## Testing setup -1. `rustup update stable` DataFusion uses the latest stable release of rust +1. `rustup update stable` DataFusion uses the latest stable release of Rust 2. `git submodule init` 3. `git submodule update --init --remote --recursive` @@ -156,6 +162,14 @@ sqllogictests also supports `cargo test` style substring matches on file names t cargo test --test sqllogictests -- information ``` +Additionally, executing specific tests within a file is also supported. Tests are identified by line number within +the .slt file; for example, the following command will run the test in line `709` for file `information.slt` along +with any other preparatory statements: + +```shell +cargo test --test sqllogictests -- information:709 +``` + ## Running tests: Postgres compatibility Test files that start with prefix `pg_compat_` verify compatibility @@ -283,6 +297,27 @@ Tests that need to write temporary files should write (only) to this directory to ensure they do not interfere with others concurrently running tests. +## Running tests: Substrait round-trip mode + +This mode will run all the .slt test files in validation mode, adding a Substrait conversion round-trip for each +generated DataFusion logical plan (SQL statement → DF logical → Substrait → DF logical → DF physical → execute). + +Not all statements will be round-tripped, some statements like CREATE, INSERT, SET or EXPLAIN statements will be +issued as is, but any other statement will be round-tripped to/from Substrait. + +_WARNING_: as there are still a lot of failures in this mode (https://github.com/apache/datafusion/issues/16248), +it is not enforced in the CI, instead, it needs to be run manually with the following command: + +```shell +cargo test --test sqllogictests -- --substrait-round-trip +``` + +For focusing on one specific failing test, a file:line filter can be used: + +```shell +cargo test --test sqllogictests -- --substrait-round-trip binary.slt:23 +``` + ## `.slt` file format [`sqllogictest`] was originally written for SQLite to verify the diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 5894ec056a2eb..7aca0fdd6e8d4 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -18,10 +18,11 @@ use clap::Parser; use datafusion::common::instant::Instant; use datafusion::common::utils::get_available_parallelism; -use datafusion::common::{exec_err, DataFusionError, Result}; +use datafusion::common::{exec_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_sqllogictest::{ - df_value_validator, read_dir_recursive, setup_scratch_dir, value_normalizer, - DataFusion, TestContext, + df_value_validator, read_dir_recursive, setup_scratch_dir, should_skip_file, + should_skip_record, value_normalizer, DataFusion, DataFusionSubstraitRoundTrip, + Filter, TestContext, }; use futures::stream::StreamExt; use indicatif::{ @@ -31,8 +32,8 @@ use itertools::Itertools; use log::Level::Info; use log::{info, log_enabled}; use sqllogictest::{ - parse_file, strict_column_validator, AsyncDB, Condition, Normalizer, Record, - Validator, + parse_file, strict_column_validator, AsyncDB, Condition, MakeConnection, Normalizer, + Record, Validator, }; #[cfg(feature = "postgres")] @@ -41,6 +42,7 @@ use crate::postgres_container::{ }; use datafusion::common::runtime::SpawnedTask; use std::ffi::OsStr; +use std::fs; use std::path::{Path, PathBuf}; #[cfg(feature = "postgres")] @@ -50,6 +52,7 @@ const TEST_DIRECTORY: &str = "test_files/"; const DATAFUSION_TESTING_TEST_DIRECTORY: &str = "../../datafusion-testing/data/"; const PG_COMPAT_FILE_PREFIX: &str = "pg_compat_"; const SQLITE_PREFIX: &str = "sqlite"; +const ERRS_PER_FILE_LIMIT: usize = 10; pub fn main() -> Result<()> { tokio::runtime::Builder::new_multi_thread() @@ -101,6 +104,7 @@ async fn run_tests() -> Result<()> { // to stdout and return OK so they can continue listing other tests. return Ok(()); } + options.warn_on_ignored(); #[cfg(feature = "postgres")] @@ -121,6 +125,20 @@ async fn run_tests() -> Result<()> { let start = Instant::now(); let test_files = read_test_files(&options)?; + + // Perform scratch file sanity check + let scratch_errors = scratch_file_check(&test_files)?; + if !scratch_errors.is_empty() { + eprintln!("Scratch file sanity check failed:"); + for error in &scratch_errors { + eprintln!(" {error}"); + } + + eprintln!("\nTemporary file check failed. Please ensure that within each test file, any scratch file created is placed under a folder with the same name as the test file (without extension).\nExample: inside `join.slt`, temporary files must be created under `.../scratch/join/`\n"); + + return exec_err!("sqllogictests scratch file check failed"); + } + let num_tests = test_files.len(); let errors: Vec<_> = futures::stream::iter(test_files) .map(|test_file| { @@ -134,27 +152,49 @@ async fn run_tests() -> Result<()> { let m_clone = m.clone(); let m_style_clone = m_style.clone(); + let filters = options.filters.clone(); SpawnedTask::spawn(async move { - match (options.postgres_runner, options.complete) { - (false, false) => { - run_test_file(test_file, validator, m_clone, m_style_clone) - .await? + match ( + options.postgres_runner, + options.complete, + options.substrait_round_trip, + ) { + (_, _, true) => { + run_test_file_substrait_round_trip( + test_file, + validator, + m_clone, + m_style_clone, + filters.as_ref(), + ) + .await? } - (false, true) => { + (false, false, _) => { + run_test_file( + test_file, + validator, + m_clone, + m_style_clone, + filters.as_ref(), + ) + .await? + } + (false, true, _) => { run_complete_file(test_file, validator, m_clone, m_style_clone) .await? } - (true, false) => { + (true, false, _) => { run_test_file_with_postgres( test_file, validator, m_clone, m_style_clone, + filters.as_ref(), ) .await? } - (true, true) => { + (true, true, _) => { run_complete_file_with_postgres( test_file, validator, @@ -169,18 +209,13 @@ async fn run_tests() -> Result<()> { .join() }) // run up to num_cpus streams in parallel - .buffer_unordered(get_available_parallelism()) + .buffer_unordered(options.test_threads) .flat_map(|result| { // Filter out any Ok() leaving only the DataFusionErrors futures::stream::iter(match result { // Tokio panic error Err(e) => Some(DataFusionError::External(Box::new(e))), - Ok(thread_result) => match thread_result { - // Test run error - Err(e) => Some(e), - // success - Ok(_) => None, - }, + Ok(thread_result) => thread_result.err(), }) }) .collect() @@ -206,11 +241,51 @@ async fn run_tests() -> Result<()> { } } +async fn run_test_file_substrait_round_trip( + test_file: TestFile, + validator: Validator, + mp: MultiProgress, + mp_style: ProgressStyle, + filters: &[Filter], +) -> Result<()> { + let TestFile { + path, + relative_path, + } = test_file; + let Some(test_ctx) = TestContext::try_new_for_test_file(&relative_path).await else { + info!("Skipping: {}", path.display()); + return Ok(()); + }; + setup_scratch_dir(&relative_path)?; + + let count: u64 = get_record_count(&path, "DatafusionSubstraitRoundTrip".to_string()); + let pb = mp.add(ProgressBar::new(count)); + + pb.set_style(mp_style); + pb.set_message(format!("{:?}", &relative_path)); + + let mut runner = sqllogictest::Runner::new(|| async { + Ok(DataFusionSubstraitRoundTrip::new( + test_ctx.session_ctx().clone(), + relative_path.clone(), + pb.clone(), + )) + }); + runner.add_label("DatafusionSubstraitRoundTrip"); + runner.with_column_validator(strict_column_validator); + runner.with_normalizer(value_normalizer); + runner.with_validator(validator); + let res = run_file_in_runner(path, runner, filters).await; + pb.finish_and_clear(); + res +} + async fn run_test_file( test_file: TestFile, validator: Validator, mp: MultiProgress, mp_style: ProgressStyle, + filters: &[Filter], ) -> Result<()> { let TestFile { path, @@ -239,15 +314,49 @@ async fn run_test_file( runner.with_column_validator(strict_column_validator); runner.with_normalizer(value_normalizer); runner.with_validator(validator); + let result = run_file_in_runner(path, runner, filters).await; + pb.finish_and_clear(); + result +} - let res = runner - .run_file_async(path) - .await - .map_err(|e| DataFusionError::External(Box::new(e))); +async fn run_file_in_runner>( + path: PathBuf, + mut runner: sqllogictest::Runner, + filters: &[Filter], +) -> Result<()> { + let path = path.canonicalize()?; + let records = + parse_file(&path).map_err(|e| DataFusionError::External(Box::new(e)))?; + let mut errs = vec![]; + for record in records.into_iter() { + if let Record::Halt { .. } = record { + break; + } + if should_skip_record::(&record, filters) { + continue; + } + if let Err(err) = runner.run_async(record).await { + errs.push(format!("{err}")); + } + } - pb.finish_and_clear(); + if !errs.is_empty() { + let mut msg = format!("{} errors in file {}\n\n", errs.len(), path.display()); + for (i, err) in errs.iter().enumerate() { + if i >= ERRS_PER_FILE_LIMIT { + msg.push_str(&format!( + "... other {} errors in {} not shown ...\n\n", + errs.len() - ERRS_PER_FILE_LIMIT, + path.display() + )); + break; + } + msg.push_str(&format!("{}. {err}\n\n", i + 1)); + } + return Err(DataFusionError::External(msg.into())); + } - res + Ok(()) } fn get_record_count(path: &PathBuf, label: String) -> u64 { @@ -292,6 +401,7 @@ async fn run_test_file_with_postgres( validator: Validator, mp: MultiProgress, mp_style: ProgressStyle, + filters: &[Filter], ) -> Result<()> { use datafusion_sqllogictest::Postgres; let TestFile { @@ -313,14 +423,9 @@ async fn run_test_file_with_postgres( runner.with_column_validator(strict_column_validator); runner.with_normalizer(value_normalizer); runner.with_validator(validator); - runner - .run_file_async(path) - .await - .map_err(|e| DataFusionError::External(Box::new(e)))?; - + let result = run_file_in_runner(path, runner, filters).await; pb.finish_and_clear(); - - Ok(()) + result } #[cfg(not(feature = "postgres"))] @@ -329,6 +434,7 @@ async fn run_test_file_with_postgres( _validator: Validator, _mp: MultiProgress, _mp_style: ProgressStyle, + _filters: &[Filter], ) -> Result<()> { use datafusion::common::plan_err; plan_err!("Can not run with postgres as postgres feature is not enabled") @@ -378,9 +484,7 @@ async fn run_complete_file( ) .await // Can't use e directly because it isn't marked Send, so turn it into a string. - .map_err(|e| { - DataFusionError::Execution(format!("Error completing {relative_path:?}: {e}")) - }); + .map_err(|e| exec_datafusion_err!("Error completing {relative_path:?}: {e}")); pb.finish_and_clear(); @@ -430,9 +534,7 @@ async fn run_complete_file_with_postgres( ) .await // Can't use e directly because it isn't marked Send, so turn it into a string. - .map_err(|e| { - DataFusionError::Execution(format!("Error completing {relative_path:?}: {e}")) - }); + .map_err(|e| exec_datafusion_err!("Error completing {relative_path:?}: {e}")); pb.finish_and_clear(); @@ -542,14 +644,25 @@ struct Options { )] postgres_runner: bool, + #[clap( + long, + conflicts_with = "complete", + conflicts_with = "postgres_runner", + help = "Before executing each query, convert its logical plan to Substrait and from Substrait back to its logical plan" + )] + substrait_round_trip: bool, + #[clap(long, env = "INCLUDE_SQLITE", help = "Include sqlite files")] include_sqlite: bool, #[clap(long, env = "INCLUDE_TPCH", help = "Include tpch files")] include_tpch: bool, - #[clap(action, help = "test filter (substring match on filenames)")] - filters: Vec, + #[clap( + action, + help = "test filter (substring match on filenames with optional :{line_number} suffix)" + )] + filters: Vec, #[clap( long, @@ -587,6 +700,13 @@ struct Options { help = "IGNORED (for compatibility with built-in rust test runner)" )] nocapture: bool, + + #[clap( + long, + help = "Number of threads used for running tests in parallel", + default_value_t = get_available_parallelism() + )] + test_threads: usize, } impl Options { @@ -602,15 +722,7 @@ impl Options { /// filter and that does a substring match on each input. returns /// true f this path should be run fn check_test_file(&self, path: &Path) -> bool { - if self.filters.is_empty() { - return true; - } - - // otherwise check if any filter matches - let path_string = path.to_string_lossy(); - self.filters - .iter() - .any(|filter| path_string.contains(filter)) + !should_skip_file(path, &self.filters) } /// Postgres runner executes only tests in files with specific names or in @@ -637,3 +749,67 @@ impl Options { } } } + +/// Performs scratch file check for all test files. +/// +/// Scratch file rule: In each .slt test file, the temporary file created must +/// be under a folder that is has the same name as the test file. +/// e.g. In `join.slt`, temporary files must be created under `.../scratch/join/` +/// +/// See: +/// +/// This function searches for `scratch/[target]/...` patterns and verifies +/// that the target matches the file name. +/// +/// Returns a vector of error strings for incorrectly created scratch files. +fn scratch_file_check(test_files: &[TestFile]) -> Result> { + let mut errors = Vec::new(); + + // Search for any scratch/[target]/... patterns and check if they match the file name + let scratch_pattern = regex::Regex::new(r"scratch/([^/]+)/").unwrap(); + + for test_file in test_files { + // Get the file content + let content = match fs::read_to_string(&test_file.path) { + Ok(content) => content, + Err(e) => { + errors.push(format!( + "Failed to read file {}: {}", + test_file.path.display(), + e + )); + continue; + } + }; + + // Get the expected target name (file name without extension) + let expected_target = match test_file.path.file_stem() { + Some(stem) => stem.to_string_lossy().to_string(), + None => { + errors.push(format!("File {} has no stem", test_file.path.display())); + continue; + } + }; + + let lines: Vec<&str> = content.lines().collect(); + + for (line_num, line) in lines.iter().enumerate() { + if let Some(captures) = scratch_pattern.captures(line) { + if let Some(found_target) = captures.get(1) { + let found_target = found_target.as_str(); + if found_target != expected_target { + errors.push(format!( + "File {}:{}: scratch target '{}' does not match file name '{}'", + test_file.path.display(), + line_num + 1, + found_target, + expected_target + )); + } + } + } + } + } + + Ok(errors) +} diff --git a/datafusion/sqllogictest/data/composite_order.csv b/datafusion/sqllogictest/data/composite_order.csv new file mode 100644 index 0000000000000..b2c5e881bd605 --- /dev/null +++ b/datafusion/sqllogictest/data/composite_order.csv @@ -0,0 +1,8 @@ +a,b +1,0 +0,2 +1,2 +0,4 +5,0 +3,3 +4,3 diff --git a/datafusion/sqllogictest/regenerate/sqllogictests.rs b/datafusion/sqllogictest/regenerate/sqllogictests.rs index edad16bc84b1c..a50c4ae1cb7b1 100644 --- a/datafusion/sqllogictest/regenerate/sqllogictests.rs +++ b/datafusion/sqllogictest/regenerate/sqllogictests.rs @@ -497,7 +497,7 @@ async fn run_complete_file_with_postgres( .await // Can't use e directly because it isn't marked Send, so turn it into a string. .map_err(|e| { - DataFusionError::Execution(format!("Error completing {relative_path:?}: {e}")) + exec_datafusion_err!("Failed to complete test file {relative_path:?}: {e}") }); pb.finish_and_clear(); diff --git a/datafusion/sqllogictest/src/engines/conversion.rs b/datafusion/sqllogictest/src/engines/conversion.rs index 516ec69e0b07d..de3acbee93b1a 100644 --- a/datafusion/sqllogictest/src/engines/conversion.rs +++ b/datafusion/sqllogictest/src/engines/conversion.rs @@ -35,7 +35,8 @@ pub(crate) fn varchar_to_str(value: &str) -> String { if value.is_empty() { "(empty)".to_string() } else { - value.trim_end_matches('\n').to_string() + // Escape nulls so that github renders them correctly in the webui + value.trim_end_matches('\n').replace("\u{0000}", "\\0") } } @@ -49,7 +50,7 @@ pub(crate) fn f16_to_str(value: f16) -> String { } else if value == f16::NEG_INFINITY { "-Infinity".to_string() } else { - big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), None) } } @@ -63,7 +64,7 @@ pub(crate) fn f32_to_str(value: f32) -> String { } else if value == f32::NEG_INFINITY { "-Infinity".to_string() } else { - big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), None) } } @@ -77,7 +78,21 @@ pub(crate) fn f64_to_str(value: f64) -> String { } else if value == f64::NEG_INFINITY { "-Infinity".to_string() } else { - big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), None) + } +} + +pub(crate) fn spark_f64_to_str(value: f64) -> String { + if value.is_nan() { + // The sign of NaN can be different depending on platform. + // So the string representation of NaN ignores the sign. + "NaN".to_string() + } else if value == f64::INFINITY { + "Infinity".to_string() + } else if value == f64::NEG_INFINITY { + "-Infinity".to_string() + } else { + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), Some(15)) } } @@ -86,6 +101,7 @@ pub(crate) fn decimal_128_to_str(value: i128, scale: i8) -> String { big_decimal_to_str( BigDecimal::from_str(&Decimal128Type::format_decimal(value, precision, scale)) .unwrap(), + None, ) } @@ -94,17 +110,21 @@ pub(crate) fn decimal_256_to_str(value: i256, scale: i8) -> String { big_decimal_to_str( BigDecimal::from_str(&Decimal256Type::format_decimal(value, precision, scale)) .unwrap(), + None, ) } #[cfg(feature = "postgres")] pub(crate) fn decimal_to_str(value: Decimal) -> String { - big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), None) } -pub(crate) fn big_decimal_to_str(value: BigDecimal) -> String { +/// Converts a `BigDecimal` to its plain string representation, optionally rounding to a specified number of decimal places. +/// +/// If `round_digits` is `None`, the value is rounded to 12 decimal places by default. +pub(crate) fn big_decimal_to_str(value: BigDecimal, round_digits: Option) -> String { // Round the value to limit the number of decimal places - let value = value.round(12).normalized(); + let value = value.round(round_digits.unwrap_or(12)).normalized(); // Format the value to a string value.to_plain_string() } @@ -115,12 +135,12 @@ mod tests { use bigdecimal::{num_bigint::BigInt, BigDecimal}; macro_rules! assert_decimal_str_eq { - ($integer:expr, $scale:expr, $expected:expr) => { + ($integer:expr, $scale:expr, $round_digits:expr, $expected:expr) => { assert_eq!( - big_decimal_to_str(BigDecimal::from_bigint( - BigInt::from($integer), - $scale - )), + big_decimal_to_str( + BigDecimal::from_bigint(BigInt::from($integer), $scale), + $round_digits + ), $expected ); }; @@ -128,44 +148,51 @@ mod tests { #[test] fn test_big_decimal_to_str() { - assert_decimal_str_eq!(110, 3, "0.11"); - assert_decimal_str_eq!(11, 3, "0.011"); - assert_decimal_str_eq!(11, 2, "0.11"); - assert_decimal_str_eq!(11, 1, "1.1"); - assert_decimal_str_eq!(11, 0, "11"); - assert_decimal_str_eq!(11, -1, "110"); - assert_decimal_str_eq!(0, 0, "0"); + assert_decimal_str_eq!(110, 3, None, "0.11"); + assert_decimal_str_eq!(11, 3, None, "0.011"); + assert_decimal_str_eq!(11, 2, None, "0.11"); + assert_decimal_str_eq!(11, 1, None, "1.1"); + assert_decimal_str_eq!(11, 0, None, "11"); + assert_decimal_str_eq!(11, -1, None, "110"); + assert_decimal_str_eq!(0, 0, None, "0"); assert_decimal_str_eq!( 12345678901234567890123456789012345678_i128, 0, + None, "12345678901234567890123456789012345678" ); assert_decimal_str_eq!( 12345678901234567890123456789012345678_i128, 38, + None, "0.123456789012" ); // Negative cases - assert_decimal_str_eq!(-110, 3, "-0.11"); - assert_decimal_str_eq!(-11, 3, "-0.011"); - assert_decimal_str_eq!(-11, 2, "-0.11"); - assert_decimal_str_eq!(-11, 1, "-1.1"); - assert_decimal_str_eq!(-11, 0, "-11"); - assert_decimal_str_eq!(-11, -1, "-110"); + assert_decimal_str_eq!(-110, 3, None, "-0.11"); + assert_decimal_str_eq!(-11, 3, None, "-0.011"); + assert_decimal_str_eq!(-11, 2, None, "-0.11"); + assert_decimal_str_eq!(-11, 1, None, "-1.1"); + assert_decimal_str_eq!(-11, 0, None, "-11"); + assert_decimal_str_eq!(-11, -1, None, "-110"); assert_decimal_str_eq!( -12345678901234567890123456789012345678_i128, 0, + None, "-12345678901234567890123456789012345678" ); assert_decimal_str_eq!( -12345678901234567890123456789012345678_i128, 38, + None, "-0.123456789012" ); // Round to 12 decimal places // 1.0000000000011 -> 1.000000000001 - assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, "1.000000000001"); + assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, None, "1.000000000001"); + assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, Some(12), "1.000000000001"); + + assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, Some(13), "1.0000000000011"); } } diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/error.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/error.rs index a60ae1012f9cf..f4e1a967e4834 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/error.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/error.rs @@ -28,7 +28,7 @@ pub type Result = std::result::Result; pub enum DFSqlLogicTestError { /// Error from sqllogictest-rs #[error("SqlLogicTest error(from sqllogictest-rs crate): {0}")] - SqlLogicTest(#[from] TestError), + SqlLogicTest(#[from] Box), /// Error from datafusion #[error("DataFusion error: {}", .0.strip_backtrace())] DataFusion(#[from] DataFusionError), diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs index eeb34186ea208..87108b67424b2 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs @@ -19,41 +19,47 @@ use super::super::conversion::*; use super::error::{DFSqlLogicTestError, Result}; use crate::engines::output::DFColumnType; use arrow::array::{Array, AsArray}; -use arrow::datatypes::Fields; +use arrow::datatypes::{Fields, Schema}; use arrow::util::display::ArrayFormatter; use arrow::{array, array::ArrayRef, datatypes::DataType, record_batch::RecordBatch}; -use datafusion::common::format::DEFAULT_CLI_FORMAT_OPTIONS; -use datafusion::common::DataFusionError; +use datafusion::common::internal_datafusion_err; +use datafusion::config::ConfigField; use std::path::PathBuf; use std::sync::LazyLock; /// Converts `batches` to a result as expected by sqllogictest. -pub fn convert_batches(batches: Vec) -> Result>> { - if batches.is_empty() { - Ok(vec![]) - } else { - let schema = batches[0].schema(); - let mut rows = vec![]; - for batch in batches { - // Verify schema - if !schema.contains(&batch.schema()) { - return Err(DFSqlLogicTestError::DataFusion(DataFusionError::Internal( - format!( - "Schema mismatch. Previously had\n{:#?}\n\nGot:\n{:#?}", - &schema, - batch.schema() - ), - ))); - } - - let new_rows = convert_batch(batch)? - .into_iter() - .flat_map(expand_row) - .map(normalize_paths); - rows.extend(new_rows); +pub fn convert_batches( + schema: &Schema, + batches: Vec, + is_spark_path: bool, +) -> Result>> { + let mut rows = vec![]; + for batch in batches { + // Verify schema + if !schema.contains(&batch.schema()) { + return Err(DFSqlLogicTestError::DataFusion(internal_datafusion_err!( + "Schema mismatch. Previously had\n{:#?}\n\nGot:\n{:#?}", + &schema, + batch.schema() + ))); } - Ok(rows) + + // Convert a single batch to a `Vec>` for comparison, flatten expanded rows, and normalize each. + let new_rows = (0..batch.num_rows()) + .map(|row| { + batch + .columns() + .iter() + .map(|col| cell_to_string(col, row, is_spark_path)) + .collect::>>() + }) + .collect::>>>()? + .into_iter() + .flat_map(expand_row) + .map(normalize_paths); + rows.extend(new_rows); } + Ok(rows) } /// special case rows that have newlines in them (like explain plans) @@ -162,19 +168,6 @@ static WORKSPACE_ROOT: LazyLock = LazyLock::new(|| { object_store::path::Path::parse(sanitized_workplace_root).unwrap() }); -/// Convert a single batch to a `Vec>` for comparison -fn convert_batch(batch: RecordBatch) -> Result>> { - (0..batch.num_rows()) - .map(|row| { - batch - .columns() - .iter() - .map(|col| cell_to_string(col, row)) - .collect::>>() - }) - .collect() -} - macro_rules! get_row_value { ($array_type:ty, $column: ident, $row: ident) => {{ let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); @@ -193,7 +186,7 @@ macro_rules! get_row_value { /// /// Floating numbers are rounded to have a consistent representation with the Postgres runner. /// -pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result { +pub fn cell_to_string(col: &ArrayRef, row: usize, is_spark_path: bool) -> Result { if !col.is_valid(row) { // represent any null value with the string "NULL" Ok(NULL_STR.to_string()) @@ -210,7 +203,12 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result { Ok(f32_to_str(get_row_value!(array::Float32Array, col, row))) } DataType::Float64 => { - Ok(f64_to_str(get_row_value!(array::Float64Array, col, row))) + let result = get_row_value!(array::Float64Array, col, row); + if is_spark_path { + Ok(spark_f64_to_str(result)) + } else { + Ok(f64_to_str(result)) + } } DataType::Decimal128(_, scale) => { let value = get_row_value!(array::Decimal128Array, col, row); @@ -236,12 +234,20 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result { DataType::Dictionary(_, _) => { let dict = col.as_any_dictionary(); let key = dict.normalized_keys()[row]; - Ok(cell_to_string(dict.values(), key)?) + Ok(cell_to_string(dict.values(), key, is_spark_path)?) } _ => { - let f = - ArrayFormatter::try_new(col.as_ref(), &DEFAULT_CLI_FORMAT_OPTIONS); - Ok(f.unwrap().value(row).to_string()) + let mut datafusion_format_options = + datafusion::config::FormatOptions::default(); + + datafusion_format_options.set("null", "NULL").unwrap(); + + let arrow_format_options: arrow::util::display::FormatOptions = + (&datafusion_format_options).try_into().unwrap(); + + let f = ArrayFormatter::try_new(col.as_ref(), &arrow_format_options)?; + + Ok(f.value(row).to_string()) } } .map_err(DFSqlLogicTestError::Arrow) @@ -280,7 +286,9 @@ pub fn convert_schema_to_types(columns: &Fields) -> Vec { if key_type.is_integer() { // mapping dictionary string types to Text match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => DFColumnType::Text, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + DFColumnType::Text + } _ => DFColumnType::Another, } } else { diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs index a3a29eda2ee9c..45deefdc9bbdf 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs @@ -31,6 +31,7 @@ use sqllogictest::DBOutput; use tokio::time::Instant; use crate::engines::output::{DFColumnType, DFOutput}; +use crate::is_spark_path; pub struct DataFusion { ctx: SessionContext, @@ -79,7 +80,7 @@ impl sqllogictest::AsyncDB for DataFusion { } let start = Instant::now(); - let result = run_query(&self.ctx, sql).await; + let result = run_query(&self.ctx, is_spark_path(&self.relative_path), sql).await; let duration = start.elapsed(); if duration.gt(&Duration::from_millis(500)) { @@ -115,15 +116,20 @@ impl sqllogictest::AsyncDB for DataFusion { async fn shutdown(&mut self) {} } -async fn run_query(ctx: &SessionContext, sql: impl Into) -> Result { +async fn run_query( + ctx: &SessionContext, + is_spark_path: bool, + sql: impl Into, +) -> Result { let df = ctx.sql(sql.into().as_str()).await?; let task_ctx = Arc::new(df.task_ctx()); let plan = df.create_physical_plan().await?; + let schema = plan.schema(); let stream = execute_stream(plan, task_ctx)?; let types = normalize::convert_schema_to_types(stream.schema().fields()); let results: Vec = collect(stream).await?; - let rows = normalize::convert_batches(results)?; + let rows = normalize::convert_batches(&schema, results, is_spark_path)?; if rows.is_empty() && types.is_empty() { Ok(DBOutput::StatementComplete(0)) diff --git a/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/mod.rs b/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/mod.rs new file mode 100644 index 0000000000000..9ff077c67d8c1 --- /dev/null +++ b/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod runner; + +pub use runner::*; diff --git a/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/runner.rs b/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/runner.rs new file mode 100644 index 0000000000000..2df93f0dede33 --- /dev/null +++ b/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/runner.rs @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; +use std::{path::PathBuf, time::Duration}; + +use crate::engines::datafusion_engine::Result; +use crate::engines::output::{DFColumnType, DFOutput}; +use crate::{convert_batches, convert_schema_to_types, DFSqlLogicTestError}; +use arrow::record_batch::RecordBatch; +use async_trait::async_trait; +use datafusion::logical_expr::LogicalPlan; +use datafusion::physical_plan::common::collect; +use datafusion::physical_plan::execute_stream; +use datafusion::prelude::SessionContext; +use datafusion_substrait::logical_plan::consumer::from_substrait_plan; +use datafusion_substrait::logical_plan::producer::to_substrait_plan; +use indicatif::ProgressBar; +use log::Level::{Debug, Info}; +use log::{debug, log_enabled, warn}; +use sqllogictest::DBOutput; +use tokio::time::Instant; + +pub struct DataFusionSubstraitRoundTrip { + ctx: SessionContext, + relative_path: PathBuf, + pb: ProgressBar, +} + +impl DataFusionSubstraitRoundTrip { + pub fn new(ctx: SessionContext, relative_path: PathBuf, pb: ProgressBar) -> Self { + Self { + ctx, + relative_path, + pb, + } + } + + fn update_slow_count(&self) { + let msg = self.pb.message(); + let split: Vec<&str> = msg.split(" ").collect(); + let mut current_count = 0; + + if split.len() > 2 { + // third match will be current slow count + current_count = split[2].parse::().unwrap(); + } + + current_count += 1; + + self.pb + .set_message(format!("{} - {} took > 500 ms", split[0], current_count)); + } +} + +#[async_trait] +impl sqllogictest::AsyncDB for DataFusionSubstraitRoundTrip { + type Error = DFSqlLogicTestError; + type ColumnType = DFColumnType; + + async fn run(&mut self, sql: &str) -> Result { + if log_enabled!(Debug) { + debug!( + "[{}] Running query: \"{}\"", + self.relative_path.display(), + sql + ); + } + + let start = Instant::now(); + let result = run_query_substrait_round_trip(&self.ctx, sql).await; + let duration = start.elapsed(); + + if duration.gt(&Duration::from_millis(500)) { + self.update_slow_count(); + } + + self.pb.inc(1); + + if log_enabled!(Info) && duration.gt(&Duration::from_secs(2)) { + warn!( + "[{}] Running query took more than 2 sec ({duration:?}): \"{sql}\"", + self.relative_path.display() + ); + } + + result + } + + /// Engine name of current database. + fn engine_name(&self) -> &str { + "DataFusionSubstraitRoundTrip" + } + + /// `DataFusion` calls this function to perform sleep. + /// + /// The default implementation is `std::thread::sleep`, which is universal to any async runtime + /// but would block the current thread. If you are running in tokio runtime, you should override + /// this by `tokio::time::sleep`. + async fn sleep(dur: Duration) { + tokio::time::sleep(dur).await; + } + + async fn shutdown(&mut self) {} +} + +async fn run_query_substrait_round_trip( + ctx: &SessionContext, + sql: impl Into, +) -> Result { + let df = ctx.sql(sql.into().as_str()).await?; + let task_ctx = Arc::new(df.task_ctx()); + + let state = ctx.state(); + let round_tripped_plan = match df.logical_plan() { + // Substrait does not handle these plans + LogicalPlan::Ddl(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Copy(_) + | LogicalPlan::Statement(_) => df.logical_plan().clone(), + // For any other plan, convert to Substrait + logical_plan => { + let plan = to_substrait_plan(logical_plan, &state)?; + from_substrait_plan(&state, &plan).await? + } + }; + + let physical_plan = state.create_physical_plan(&round_tripped_plan).await?; + let schema = physical_plan.schema(); + let stream = execute_stream(physical_plan, task_ctx)?; + let types = convert_schema_to_types(stream.schema().fields()); + let results: Vec = collect(stream).await?; + let rows = convert_batches(&schema, results, false)?; + + if rows.is_empty() && types.is_empty() { + Ok(DBOutput::StatementComplete(0)) + } else { + Ok(DBOutput::Rows { types, rows }) + } +} diff --git a/datafusion/sqllogictest/src/engines/mod.rs b/datafusion/sqllogictest/src/engines/mod.rs index 3569dea701761..ef6335ddbed66 100644 --- a/datafusion/sqllogictest/src/engines/mod.rs +++ b/datafusion/sqllogictest/src/engines/mod.rs @@ -18,12 +18,14 @@ /// Implementation of sqllogictest for datafusion. mod conversion; mod datafusion_engine; +mod datafusion_substrait_roundtrip_engine; mod output; pub use datafusion_engine::convert_batches; pub use datafusion_engine::convert_schema_to_types; pub use datafusion_engine::DFSqlLogicTestError; pub use datafusion_engine::DataFusion; +pub use datafusion_substrait_roundtrip_engine::DataFusionSubstraitRoundTrip; pub use output::DFColumnType; pub use output::DFOutput; diff --git a/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs b/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs index 68816626bf672..375f06d34b44f 100644 --- a/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs +++ b/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs @@ -93,7 +93,7 @@ impl Postgres { let spawned_task = SpawnedTask::spawn(async move { if let Err(e) = connection.await { - log::error!("Postgres connection error: {:?}", e); + log::error!("Postgres connection error: {e:?}"); } }); diff --git a/datafusion/sqllogictest/src/filters.rs b/datafusion/sqllogictest/src/filters.rs new file mode 100644 index 0000000000000..44482236f7c5b --- /dev/null +++ b/datafusion/sqllogictest/src/filters.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::sql::parser::{DFParserBuilder, Statement}; +use sqllogictest::{AsyncDB, Record}; +use sqlparser::ast::{SetExpr, Statement as SqlStatement}; +use sqlparser::dialect::dialect_from_str; +use std::path::Path; +use std::str::FromStr; + +/// Filter specification that determines whether a certain sqllogictest record in +/// a certain file should be filtered. In order for a [`Filter`] to match a test case: +/// +/// - The test must belong to a file whose absolute path contains the `file_substring` substring. +/// - If a `line_number` is specified, the test must be declared in that same line number. +/// +/// If a [`Filter`] matches a specific test case, then the record is executed, if there's +/// no match, the record is skipped. +/// +/// Filters can be parsed from strings of the form `:line_number`. For example, +/// `foo.slt:100` matches any test whose name contains `foo.slt` and the test starts on line +/// number 100. +#[derive(Debug, Clone)] +pub struct Filter { + file_substring: String, + line_number: Option, +} + +impl FromStr for Filter { + type Err = String; + + fn from_str(s: &str) -> Result { + let parts: Vec<&str> = s.rsplitn(2, ':').collect(); + if parts.len() == 2 { + match parts[0].parse::() { + Ok(line) => Ok(Filter { + file_substring: parts[1].to_string(), + line_number: Some(line), + }), + Err(_) => Err(format!("Cannot parse line number from '{s}'")), + } + } else { + Ok(Filter { + file_substring: s.to_string(), + line_number: None, + }) + } + } +} + +/// Given a list of [`Filter`]s, determines if the whole file in the provided +/// path can be skipped. +/// +/// - If there's at least 1 filter whose file name is a substring of the provided path, +/// it returns true. +/// - If the provided filter list is empty, it returns false. +pub fn should_skip_file(path: &Path, filters: &[Filter]) -> bool { + if filters.is_empty() { + return false; + } + + let path_string = path.to_string_lossy(); + for filter in filters { + if path_string.contains(&filter.file_substring) { + return false; + } + } + true +} + +/// Determines whether a certain sqllogictest record should be skipped given the provided +/// filters. +/// +/// If there's at least 1 matching filter, or the filter list is empty, it returns false. +/// +/// There are certain records that will never be skipped even if they are not matched +/// by any filters, like CREATE TABLE, INSERT INTO, DROP or SELECT * INTO statements, +/// as they populate tables necessary for other tests to work. +pub fn should_skip_record( + record: &Record, + filters: &[Filter], +) -> bool { + if filters.is_empty() { + return false; + } + + let (sql, loc) = match record { + Record::Statement { sql, loc, .. } => (sql, loc), + Record::Query { sql, loc, .. } => (sql, loc), + _ => return false, + }; + + let statement = if let Some(statement) = parse_or_none(sql, "Postgres") { + statement + } else if let Some(statement) = parse_or_none(sql, "generic") { + statement + } else { + return false; + }; + + if !statement_is_skippable(&statement) { + return false; + } + + for filter in filters { + if !loc.file().contains(&filter.file_substring) { + continue; + } + if let Some(line_num) = filter.line_number { + if loc.line() != line_num { + continue; + } + } + + // This filter matches both file name substring and the exact + // line number (if one was provided), so don't skip it. + return false; + } + + true +} + +fn statement_is_skippable(statement: &Statement) -> bool { + // Only SQL statements can be skipped. + let Statement::Statement(sql_stmt) = statement else { + return false; + }; + + // Cannot skip SELECT INTO statements, as they can also create tables + // that further test cases will use. + if let SqlStatement::Query(v) = sql_stmt.as_ref() { + if let SetExpr::Select(v) = v.body.as_ref() { + if v.into.is_some() { + return false; + } + } + } + + // Only SELECT and EXPLAIN statements can be skipped, as any other + // statement might be populating tables that future test cases will use. + matches!( + sql_stmt.as_ref(), + SqlStatement::Query(_) | SqlStatement::Explain { .. } + ) +} + +fn parse_or_none(sql: &str, dialect: &str) -> Option { + let Ok(Ok(Some(statement))) = DFParserBuilder::new(sql) + .with_dialect(dialect_from_str(dialect).unwrap().as_ref()) + .build() + .map(|mut v| v.parse_statements().map(|mut v| v.pop_front())) + else { + return None; + }; + Some(statement) +} diff --git a/datafusion/sqllogictest/src/lib.rs b/datafusion/sqllogictest/src/lib.rs index 1a208aa3fac2c..f3a78607242ce 100644 --- a/datafusion/sqllogictest/src/lib.rs +++ b/datafusion/sqllogictest/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] @@ -34,12 +34,15 @@ pub use engines::DFColumnType; pub use engines::DFOutput; pub use engines::DFSqlLogicTestError; pub use engines::DataFusion; +pub use engines::DataFusionSubstraitRoundTrip; #[cfg(feature = "postgres")] pub use engines::Postgres; +mod filters; mod test_context; mod util; +pub use filters::*; pub use test_context::TestContext; pub use util::*; diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index ce819f1864544..b499401e5589c 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; use std::collections::HashMap; use std::fs::File; use std::io::Write; @@ -31,8 +32,13 @@ use arrow::record_batch::RecordBatch; use datafusion::catalog::{ CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, Session, }; -use datafusion::common::DataFusionError; -use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarUDF, Volatility}; +use datafusion::common::{not_impl_err, DataFusionError, Result}; +use datafusion::functions::math::abs; +use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; +use datafusion::logical_expr::{ + create_udf, ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, +}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::*; use datafusion::{ @@ -40,8 +46,11 @@ use datafusion::{ prelude::{CsvReadOptions, SessionContext}, }; +use crate::is_spark_path; use async_trait::async_trait; use datafusion::common::cast::as_float64_array; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::SessionStateBuilder; use log::info; use tempfile::TempDir; @@ -70,8 +79,20 @@ impl TestContext { let config = SessionConfig::new() // hardcode target partitions so plans are deterministic .with_target_partitions(4); + let runtime = Arc::new(RuntimeEnv::default()); + let mut state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); + + if is_spark_path(relative_path) { + info!("Registering Spark functions"); + datafusion_spark::register_all(&mut state) + .expect("Can not register Spark functions"); + } - let mut test_ctx = TestContext::new(SessionContext::new_with_config(config)); + let mut test_ctx = TestContext::new(SessionContext::new_with_state(state)); let file_name = relative_path.file_name().unwrap().to_str().unwrap(); match file_name { @@ -118,10 +139,15 @@ impl TestContext { info!("Registering table with union column"); register_union_table(test_ctx.session_ctx()) } + "async_udf.slt" => { + info!("Registering dummy async udf"); + register_async_abs_udf(test_ctx.session_ctx()) + } _ => { info!("Using default SessionContext"); } }; + Some(test_ctx) } @@ -219,18 +245,18 @@ pub async fn register_temp_table(ctx: &SessionContext) { #[async_trait] impl TableProvider for TestTable { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } - fn table_type(&self) -> TableType { - self.0 - } - fn schema(&self) -> SchemaRef { unimplemented!() } + fn table_type(&self) -> TableType { + self.0 + } + async fn scan( &self, _state: &dyn Session, @@ -410,10 +436,24 @@ fn create_example_udf() -> ScalarUDF { fn register_union_table(ctx: &SessionContext) { let union = UnionArray::try_new( - UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]), - ScalarBuffer::from(vec![3, 3]), + UnionFields::new( + // typeids: 3 for int, 1 for string + vec![3, 1], + vec![ + Field::new("int", DataType::Int32, false), + Field::new("string", DataType::Utf8, false), + ], + ), + ScalarBuffer::from(vec![3, 1, 3]), None, - vec![Arc::new(Int32Array::from(vec![1, 2]))], + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![ + Some("foo"), + Some("bar"), + Some("baz"), + ])), + ], ) .unwrap(); @@ -428,3 +468,48 @@ fn register_union_table(ctx: &SessionContext) { ctx.register_batch("union_table", batch).unwrap(); } + +fn register_async_abs_udf(ctx: &SessionContext) { + #[derive(Debug, PartialEq, Eq, Hash)] + struct AsyncAbs { + inner_abs: Arc, + } + impl AsyncAbs { + fn new() -> Self { + AsyncAbs { inner_abs: abs() } + } + } + impl ScalarUDFImpl for AsyncAbs { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "async_abs" + } + + fn signature(&self) -> &Signature { + self.inner_abs.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner_abs.return_type(arg_types) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + not_impl_err!("{} can only be called from async contexts", self.name()) + } + } + #[async_trait] + impl AsyncScalarUDFImpl for AsyncAbs { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + return self.inner_abs.invoke_with_args(args); + } + } + let async_abs = AsyncAbs::new(); + let udf = AsyncScalarUDF::new(Arc::new(async_abs)); + ctx.register_udf(udf.into_scalar_udf()); +} diff --git a/datafusion/sqllogictest/src/util.rs b/datafusion/sqllogictest/src/util.rs index 5ae640cc98a90..695fe463fa676 100644 --- a/datafusion/sqllogictest/src/util.rs +++ b/datafusion/sqllogictest/src/util.rs @@ -106,3 +106,7 @@ pub fn df_value_validator( normalized_actual == normalized_expected } + +pub fn is_spark_path(relative_path: &Path) -> bool { + relative_path.starts_with("spark/") +} diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 621e212ebc718..9d6c7b11add6d 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -32,10 +32,12 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c10 BIGINT UNSIGNED NOT NULL, c11 FLOAT NOT NULL, c12 DOUBLE NOT NULL, - c13 VARCHAR NOT NULL + c13 VARCHAR NOT NULL, + c14 DATE NOT NULL, + c15 TIMESTAMP NOT NULL, ) STORED AS CSV -LOCATION '../../testing/data/csv/aggregate_test_100.csv' +LOCATION '../../testing/data/csv/aggregate_test_100_with_dates.csv' OPTIONS ('format.has_header' 'true'); statement ok @@ -132,37 +134,48 @@ statement error DataFusion error: Schema error: Schema contains duplicate unqual SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_weight -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function: coercion from \[Utf8, Int8, Float64\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont_with_weight(c1, c2, 0.95) FROM aggregate_test_100 +statement error Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function +SELECT approx_percentile_cont_with_weight(c2, 0.95) WITHIN GROUP (ORDER BY c1) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function: coercion from \[Int16, Utf8, Float64\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont_with_weight(c3, c1, 0.95) FROM aggregate_test_100 +statement error Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function +SELECT approx_percentile_cont_with_weight(c1, 0.95) WITHIN GROUP (ORDER BY c3) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function: coercion from \[Int16, Int8, Utf8\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont_with_weight(c3, c2, c1) FROM aggregate_test_100 +statement error Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function +SELECT approx_percentile_cont_with_weight(c2, c1) WITHIN GROUP (ORDER BY c3) FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_histogram_bins statement error DataFusion error: This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\)\. -SELECT c1, approx_percentile_cont(c3, 0.95, -1000) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont(0.95, -1000) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont' function: coercion from \[Int16, Float64, Utf8\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100 +statement error Failed to coerce arguments to satisfy a call to 'approx_percentile_cont' function +SELECT approx_percentile_cont(0.95, c1) WITHIN GROUP (ORDER BY c3) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont' function: coercion from \[Int16, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100 +statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont' function: coercion from Int16, Float64, Float64 to the signature OneOf(.*) failed(.|\n)* +SELECT approx_percentile_cont(0.95, 111.1) WITHIN GROUP (ORDER BY c3) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont' function: coercion from \[Float64, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100 +statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont' function: coercion from Float64, Float64, Float64 to the signature OneOf(.*) failed(.|\n)* +SELECT approx_percentile_cont(0.95, 111.1) WITHIN GROUP (ORDER BY c12) FROM aggregate_test_100 statement error DataFusion error: This feature is not implemented: Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal -SELECT approx_percentile_cont(c12, c12) FROM aggregate_test_100 +SELECT approx_percentile_cont(c12) WITHIN GROUP (ORDER BY c12) FROM aggregate_test_100 statement error DataFusion error: This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal -SELECT approx_percentile_cont(c12, 0.95, c5) FROM aggregate_test_100 +SELECT approx_percentile_cont(0.95, c5) WITHIN GROUP (ORDER BY c12) FROM aggregate_test_100 + +statement error DataFusion error: Error during planning: \[IGNORE | RESPECT\] NULLS are not permitted for approx_percentile_cont +SELECT approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c5) IGNORE NULLS FROM aggregate_test_100 + +statement error DataFusion error: Error during planning: \[IGNORE | RESPECT\] NULLS are not permitted for approx_percentile_cont +SELECT approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c5) RESPECT NULLS FROM aggregate_test_100 + +statement error DataFusion error: This feature is not implemented: Only a single ordering expression is permitted in a WITHIN GROUP clause +SELECT approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c5, c12) FROM aggregate_test_100 # Not supported over sliding windows -query error This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented -SELECT approx_percentile_cont(c3, 0.5) OVER (ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) +query error DataFusion error: Error during planning: OVER and WITHIN GROUP clause cannot be used together. OVER is for window functions, whereas WITHIN GROUP is for ordered set aggregate functions +SELECT approx_percentile_cont(0.5) +WITHIN GROUP (ORDER BY c3) +OVER (ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) FROM aggregate_test_100 # array agg can use order by @@ -195,6 +208,56 @@ query error Execution error: In an aggregate with DISTINCT, ORDER BY expressions SELECT array_agg(DISTINCT c13 ORDER BY c13, c12) FROM aggregate_test_100 +query ?? rowsort +with tbl as (SELECT * FROM (VALUES ('xxx', 'yyy'), ('xxx', 'yyy'), ('xxx2', 'yyy2')) AS t(x, y)) +select + array_agg(x order by x) as x_agg, + array_agg(y order by y) as y_agg +from tbl +group by all +---- +[xxx, xxx, xxx2] [yyy, yyy, yyy2] + +query ?? +SELECT + (SELECT array_agg(c12 ORDER BY c12) FROM aggregate_test_100), + (SELECT array_agg(c13 ORDER BY c13) FROM aggregate_test_100) +---- +[0.01479305307777301, 0.02182578039211991, 0.03968347085780355, 0.04429073092078406, 0.047343434291126085, 0.04893135681998029, 0.0494924465469434, 0.05573662213439634, 0.05636955101974106, 0.061029375346466685, 0.07260475960924484, 0.09465635123783445, 0.12357539988406441, 0.152498292971736, 0.16301110515739792, 0.1640882545084913, 0.1754261586710173, 0.17592486905979987, 0.17909035118828576, 0.18628859265874176, 0.19113293583306745, 0.2145232647388039, 0.21535402343780985, 0.24899794314659673, 0.2537253407987472, 0.2667177795079635, 0.27159190516490006, 0.2739938529235548, 0.28534428578703896, 0.2944158618048994, 0.296036538664718, 0.3051364088814128, 0.30585375151301186, 0.3114712539863804, 0.3231750610081745, 0.32869374687050157, 0.33639590659276175, 0.3600766362333053, 0.36936304600612724, 0.38870280983958583, 0.39144436569161134, 0.40342283197779727, 0.4094218353587008, 0.40975383525297016, 0.42073125331890115, 0.4273123318932347, 0.42950521730777025, 0.4830878559436823, 0.5081765563442366, 0.5437595540422571, 0.5590205548347534, 0.5593249815276734, 0.5603062368164834, 0.560333188635217, 0.5614503754617461, 0.565352842229935, 0.574210838214554, 0.5759450483859969, 0.5773498217058918, 0.5991138115095911, 0.6009475544728957, 0.6108938307533, 0.6316565296547284, 0.6404495093354053, 0.6405262429561641, 0.6425694115212065, 0.658671129040488, 0.6668423897406515, 0.6864391962767343, 0.7035635283169166, 0.7325106678655877, 0.7328050041291218, 0.7614304100703713, 0.7631239070049998, 0.7670021786149205, 0.7697753383420857, 0.7764360990307122, 0.7784918983501654, 0.7973920072996036, 0.819715865079681, 0.8506721053047003, 0.8813167497816289, 0.8824879447595726, 0.9185813970744787, 0.9231889896940375, 0.9237877978193884, 0.9255031346434324, 0.9293883502480845, 0.9294097332465232, 0.9463098243875633, 0.946325164889271, 0.9491397432856566, 0.9567595541247681, 0.9706712283358269, 0.9723580396501548, 0.9748360509016578, 0.9800193410444061, 0.980809631269599, 0.991517828651004, 0.9965400387585364] [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB, 0og6hSkhbX8AC1ktFS4kounvTzy8Vo, 1aOcrEGd0cOqZe2I5XBOm0nDcwtBZO, 2T3wSlHdEmASmO0xcXHnndkKEt6bz8, 3BEOHQsMEFZ58VcNTOJYShTBpAPzbt, 4HX6feIvmNXBN7XGqgO4YVBkhu8GDI, 4JznSdBajNWhu4hRQwjV1FjTTxY68i, 52mKlRE3aHCBZtjECq6sY9OqVf8Dze, 56MZa5O1hVtX4c5sbnCfxuX5kDChqI, 6FPJlLAcaQ5uokyOWZ9HGdLZObFvOZ, 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW, 6oIXZuIPIqEoPBvFmbt2Nxy3tryGUE, 6x93sxYioWuq5c9Kkk8oTAAORM7cH0, 802bgTGl6Bk5TlkPYYTxp5JkKyaYUA, 8LIh0b6jmDGm87BmIyjdxNIpX4ugjD, 90gAtmGEeIqUTbo1ZrxCvWtsseukXC, 9UbObCsVkmYpJGcGrgfK90qOnwb2Lj, AFGCj7OWlEB5QfniEFgonMq90Tq5uH, ALuRhobVWbnQTTWZdSOk0iVe8oYFhW, Amn2K87Db5Es3dFQO9cw9cvpAM6h35, AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz, BJqx5WokrmrrezZA0dUbleMYkG5U2O, BPtQMxnuSPpxMExYV9YkDa6cAN7GP3, BsM5ZAYifRh5Lw3Y8X1r53I0cTJnfE, C2GT5KVyOPZpgKVl110TyZO0NcJ434, DuJNG8tufSqW0ZstHqWj3aGvFLMg4A, EcCuckwsF3gV1Ecgmh5v4KM8g1ozif, ErJFw6hzZ5fmI5r8bhE4JzlscnhKZU, F7NSTjWvQJyBburN7CXRUlbgp2dIrA, Fi4rJeTQq4eXj8Lxg3Hja5hBVTVV5u, H5j5ZHy1FGesOAHjkQEDYCucbpKWRu, HKSMQ9nTnwXCJIte1JrM1dtYnDtJ8g, IWl0G3ZlMNf7WT8yjIB49cx7MmYOmr, IZTkHMLvIKuiLjhDjYMmIHxh166we4, Ig1QcuKsjHXkproePdERo2w0mYzIqd, JHNgc2UCaiXOdmkxwDDyGhRlO0mnBQ, JN0VclewmjwYlSl8386MlWv5rEhWCz, JafwVLSVk5AVoXFuzclesQ000EE2k1, KJFcmTVjdkCMv94wYCtfHMFhzyRsmH, Ktb7GQ0N1DrxwkCkEUsTaIXk0xYinn, Ld2ej8NEv5zNcqU60FwpHeZKBhfpiV, LiEBxds3X0Uw0lxiYjDqrkAaAwoiIW, MXhhH1Var3OzzJCtI9VNyYvA0q8UyJ, MeSTAXq8gVxVjbEjgkvU9YLte0X9uE, NEhyk8uIx4kEULJGa8qIyFjjBcP2G6, O66j6PaYuZhEUtqV6fuU7TyjM2WxC5, OF7fQ37GzaZ5ikA2oMyvleKtgnLjXh, OPwBqCEK5PWTjWaiOyL45u2NLTaDWv, Oq6J4Rx6nde0YlhOIJkFsX2MsSvAQ0, Ow5PGpfTm4dXCfTDsXAOTatXRoAydR, QEHVvcP8gxI6EMJIrvcnIhgzPNjIvv, QJYm7YRA3YetcBHI5wkMZeLXVmfuNy, QYlaIAnJA6r8rlAb6f59wcxvcPcWFf, RilTlL1tKkPOUFuzmLydHAVZwv1OGl, Sfx0vxv1skzZWT1PqVdoRDdO6Sb6xH, TTQUwpMNSXZqVBKAFvXu7OlWvKXJKX, TtDKUZxzVxsq758G6AWPSYuZgVgbcl, VDhtJkYjAYPykCgOU9x3v7v3t4SO1a, VY0zXmXeksCT8BzvpzpPLbmU9Kp9Y4, Vp3gmWunM5A7wOC9YW2JroFqTWjvTi, WHmjWk2AY4c6m7DA4GitUx6nmb1yYS, XemNcT1xp61xcM1Qz3wZ1VECCnq06O, Z2sWcQr0qyCJRMHDpRy3aQr7PkHtkK, aDxBtor7Icd9C5hnTvvw5NrIre740e, akiiY5N0I44CMwEnBL6RTBk7BRkxEj, b3b9esRhTzFEawbs6XhpKnD9ojutHB, bgK1r6v3BCTh0aejJUhkA1Hn6idXGp, cBGc0kSm32ylBDnxogG727C0uhZEYZ, cq4WSAIFwx3wwTUS5bp1wCe71R6U5I, dVdvo6nUD5FgCgsbOZLds28RyGTpnx, e2Gh6Ov8XkXoFdJWhl0EjwEHlMDYyG, f9ALCzwDAKmdu7Rk2msJaB1wxe5IBX, fuyvs0w7WsKSlXqJ1e6HFSoLmx03AG, gTpyQnEODMcpsPnJMZC66gh33i3m0b, gpo8K5qtYePve6jyPt6xgJx4YOVjms, gxfHWUF8XgY2KdFxigxvNEXe2V2XMl, i6RQVXKUh7MzuGMDaNclUYnFUAireU, ioEncce3mPOXD2hWhpZpCPWGATG6GU, jQimhdepw3GKmioWUlVSWeBVRKFkY3, l7uwDoTepWwnAP0ufqtHJS3CRi7RfP, lqhzgLsXZ8JhtpeeUWWNbMz8PHI705, m6jD0LBIQWaMfenwRCTANI9eOdyyto, mhjME0zBHbrK6NMkytMTQzOssOa1gF, mzbkwXKrPeZnxg2Kn1LRF5hYSsmksS, nYVJnVicpGRqKZibHyBAmtmzBXAFfT, oHJMNvWuunsIMIWFnYG31RCfkOo2V7, oLZ21P2JEDooxV1pU31cIxQHEeeoLu, okOkcWflkNXIy4R8LzmySyY1EC3sYd, pLk3i59bZwd5KBZrI1FiweYTd5hteG, pTeu0WMjBRTaNRT15rLCuEh3tBJVc5, qnPOOmslCJaT45buUisMRnM0rc77EK, t6fQUjJejPcjc04wHvHTPe55S65B4V, ukOiFGGFnQJDHFgZxHMpvhD3zybF0M, ukyD7b0Efj7tNlFSRmzZ0IqkEzg2a8, waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs, wwXqSGKLyBQyPkonlzBNYUJTCo4LRS, xipQ93429ksjNcXPX5326VSg1xJZcW, y7C453hRWd4E7ImjNDWlpexB8nUqjh, ydkwycaISlYSlEq3TlkS2m15I2pcp8] + +query ?? +SELECT + array_agg(c12 ORDER BY c12), + array_agg(c13 ORDER BY c13) +FROM aggregate_test_100 +---- +[0.01479305307777301, 0.02182578039211991, 0.03968347085780355, 0.04429073092078406, 0.047343434291126085, 0.04893135681998029, 0.0494924465469434, 0.05573662213439634, 0.05636955101974106, 0.061029375346466685, 0.07260475960924484, 0.09465635123783445, 0.12357539988406441, 0.152498292971736, 0.16301110515739792, 0.1640882545084913, 0.1754261586710173, 0.17592486905979987, 0.17909035118828576, 0.18628859265874176, 0.19113293583306745, 0.2145232647388039, 0.21535402343780985, 0.24899794314659673, 0.2537253407987472, 0.2667177795079635, 0.27159190516490006, 0.2739938529235548, 0.28534428578703896, 0.2944158618048994, 0.296036538664718, 0.3051364088814128, 0.30585375151301186, 0.3114712539863804, 0.3231750610081745, 0.32869374687050157, 0.33639590659276175, 0.3600766362333053, 0.36936304600612724, 0.38870280983958583, 0.39144436569161134, 0.40342283197779727, 0.4094218353587008, 0.40975383525297016, 0.42073125331890115, 0.4273123318932347, 0.42950521730777025, 0.4830878559436823, 0.5081765563442366, 0.5437595540422571, 0.5590205548347534, 0.5593249815276734, 0.5603062368164834, 0.560333188635217, 0.5614503754617461, 0.565352842229935, 0.574210838214554, 0.5759450483859969, 0.5773498217058918, 0.5991138115095911, 0.6009475544728957, 0.6108938307533, 0.6316565296547284, 0.6404495093354053, 0.6405262429561641, 0.6425694115212065, 0.658671129040488, 0.6668423897406515, 0.6864391962767343, 0.7035635283169166, 0.7325106678655877, 0.7328050041291218, 0.7614304100703713, 0.7631239070049998, 0.7670021786149205, 0.7697753383420857, 0.7764360990307122, 0.7784918983501654, 0.7973920072996036, 0.819715865079681, 0.8506721053047003, 0.8813167497816289, 0.8824879447595726, 0.9185813970744787, 0.9231889896940375, 0.9237877978193884, 0.9255031346434324, 0.9293883502480845, 0.9294097332465232, 0.9463098243875633, 0.946325164889271, 0.9491397432856566, 0.9567595541247681, 0.9706712283358269, 0.9723580396501548, 0.9748360509016578, 0.9800193410444061, 0.980809631269599, 0.991517828651004, 0.9965400387585364] [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB, 0og6hSkhbX8AC1ktFS4kounvTzy8Vo, 1aOcrEGd0cOqZe2I5XBOm0nDcwtBZO, 2T3wSlHdEmASmO0xcXHnndkKEt6bz8, 3BEOHQsMEFZ58VcNTOJYShTBpAPzbt, 4HX6feIvmNXBN7XGqgO4YVBkhu8GDI, 4JznSdBajNWhu4hRQwjV1FjTTxY68i, 52mKlRE3aHCBZtjECq6sY9OqVf8Dze, 56MZa5O1hVtX4c5sbnCfxuX5kDChqI, 6FPJlLAcaQ5uokyOWZ9HGdLZObFvOZ, 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW, 6oIXZuIPIqEoPBvFmbt2Nxy3tryGUE, 6x93sxYioWuq5c9Kkk8oTAAORM7cH0, 802bgTGl6Bk5TlkPYYTxp5JkKyaYUA, 8LIh0b6jmDGm87BmIyjdxNIpX4ugjD, 90gAtmGEeIqUTbo1ZrxCvWtsseukXC, 9UbObCsVkmYpJGcGrgfK90qOnwb2Lj, AFGCj7OWlEB5QfniEFgonMq90Tq5uH, ALuRhobVWbnQTTWZdSOk0iVe8oYFhW, Amn2K87Db5Es3dFQO9cw9cvpAM6h35, AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz, BJqx5WokrmrrezZA0dUbleMYkG5U2O, BPtQMxnuSPpxMExYV9YkDa6cAN7GP3, BsM5ZAYifRh5Lw3Y8X1r53I0cTJnfE, C2GT5KVyOPZpgKVl110TyZO0NcJ434, DuJNG8tufSqW0ZstHqWj3aGvFLMg4A, EcCuckwsF3gV1Ecgmh5v4KM8g1ozif, ErJFw6hzZ5fmI5r8bhE4JzlscnhKZU, F7NSTjWvQJyBburN7CXRUlbgp2dIrA, Fi4rJeTQq4eXj8Lxg3Hja5hBVTVV5u, H5j5ZHy1FGesOAHjkQEDYCucbpKWRu, HKSMQ9nTnwXCJIte1JrM1dtYnDtJ8g, IWl0G3ZlMNf7WT8yjIB49cx7MmYOmr, IZTkHMLvIKuiLjhDjYMmIHxh166we4, Ig1QcuKsjHXkproePdERo2w0mYzIqd, JHNgc2UCaiXOdmkxwDDyGhRlO0mnBQ, JN0VclewmjwYlSl8386MlWv5rEhWCz, JafwVLSVk5AVoXFuzclesQ000EE2k1, KJFcmTVjdkCMv94wYCtfHMFhzyRsmH, Ktb7GQ0N1DrxwkCkEUsTaIXk0xYinn, Ld2ej8NEv5zNcqU60FwpHeZKBhfpiV, LiEBxds3X0Uw0lxiYjDqrkAaAwoiIW, MXhhH1Var3OzzJCtI9VNyYvA0q8UyJ, MeSTAXq8gVxVjbEjgkvU9YLte0X9uE, NEhyk8uIx4kEULJGa8qIyFjjBcP2G6, O66j6PaYuZhEUtqV6fuU7TyjM2WxC5, OF7fQ37GzaZ5ikA2oMyvleKtgnLjXh, OPwBqCEK5PWTjWaiOyL45u2NLTaDWv, Oq6J4Rx6nde0YlhOIJkFsX2MsSvAQ0, Ow5PGpfTm4dXCfTDsXAOTatXRoAydR, QEHVvcP8gxI6EMJIrvcnIhgzPNjIvv, QJYm7YRA3YetcBHI5wkMZeLXVmfuNy, QYlaIAnJA6r8rlAb6f59wcxvcPcWFf, RilTlL1tKkPOUFuzmLydHAVZwv1OGl, Sfx0vxv1skzZWT1PqVdoRDdO6Sb6xH, TTQUwpMNSXZqVBKAFvXu7OlWvKXJKX, TtDKUZxzVxsq758G6AWPSYuZgVgbcl, VDhtJkYjAYPykCgOU9x3v7v3t4SO1a, VY0zXmXeksCT8BzvpzpPLbmU9Kp9Y4, Vp3gmWunM5A7wOC9YW2JroFqTWjvTi, WHmjWk2AY4c6m7DA4GitUx6nmb1yYS, XemNcT1xp61xcM1Qz3wZ1VECCnq06O, Z2sWcQr0qyCJRMHDpRy3aQr7PkHtkK, aDxBtor7Icd9C5hnTvvw5NrIre740e, akiiY5N0I44CMwEnBL6RTBk7BRkxEj, b3b9esRhTzFEawbs6XhpKnD9ojutHB, bgK1r6v3BCTh0aejJUhkA1Hn6idXGp, cBGc0kSm32ylBDnxogG727C0uhZEYZ, cq4WSAIFwx3wwTUS5bp1wCe71R6U5I, dVdvo6nUD5FgCgsbOZLds28RyGTpnx, e2Gh6Ov8XkXoFdJWhl0EjwEHlMDYyG, f9ALCzwDAKmdu7Rk2msJaB1wxe5IBX, fuyvs0w7WsKSlXqJ1e6HFSoLmx03AG, gTpyQnEODMcpsPnJMZC66gh33i3m0b, gpo8K5qtYePve6jyPt6xgJx4YOVjms, gxfHWUF8XgY2KdFxigxvNEXe2V2XMl, i6RQVXKUh7MzuGMDaNclUYnFUAireU, ioEncce3mPOXD2hWhpZpCPWGATG6GU, jQimhdepw3GKmioWUlVSWeBVRKFkY3, l7uwDoTepWwnAP0ufqtHJS3CRi7RfP, lqhzgLsXZ8JhtpeeUWWNbMz8PHI705, m6jD0LBIQWaMfenwRCTANI9eOdyyto, mhjME0zBHbrK6NMkytMTQzOssOa1gF, mzbkwXKrPeZnxg2Kn1LRF5hYSsmksS, nYVJnVicpGRqKZibHyBAmtmzBXAFfT, oHJMNvWuunsIMIWFnYG31RCfkOo2V7, oLZ21P2JEDooxV1pU31cIxQHEeeoLu, okOkcWflkNXIy4R8LzmySyY1EC3sYd, pLk3i59bZwd5KBZrI1FiweYTd5hteG, pTeu0WMjBRTaNRT15rLCuEh3tBJVc5, qnPOOmslCJaT45buUisMRnM0rc77EK, t6fQUjJejPcjc04wHvHTPe55S65B4V, ukOiFGGFnQJDHFgZxHMpvhD3zybF0M, ukyD7b0Efj7tNlFSRmzZ0IqkEzg2a8, waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs, wwXqSGKLyBQyPkonlzBNYUJTCo4LRS, xipQ93429ksjNcXPX5326VSg1xJZcW, y7C453hRWd4E7ImjNDWlpexB8nUqjh, ydkwycaISlYSlEq3TlkS2m15I2pcp8] + +query ?? rowsort +with tbl as (SELECT * FROM (VALUES ('xxx', 'yyy'), ('xxx', 'yyy'), ('xxx2', 'yyy2')) AS t(x, y)) +select + array_agg(distinct x order by x) as x_agg, + array_agg(distinct y order by y) as y_agg +from tbl +group by all +---- +[xxx, xxx2] [yyy, yyy2] + +query ?? +SELECT + (SELECT array_agg(DISTINCT c12 ORDER BY c12) FROM aggregate_test_100), + (SELECT array_agg(DISTINCT c13 ORDER BY c13) FROM aggregate_test_100) +---- +[0.01479305307777301, 0.02182578039211991, 0.03968347085780355, 0.04429073092078406, 0.047343434291126085, 0.04893135681998029, 0.0494924465469434, 0.05573662213439634, 0.05636955101974106, 0.061029375346466685, 0.07260475960924484, 0.09465635123783445, 0.12357539988406441, 0.152498292971736, 0.16301110515739792, 0.1640882545084913, 0.1754261586710173, 0.17592486905979987, 0.17909035118828576, 0.18628859265874176, 0.19113293583306745, 0.2145232647388039, 0.21535402343780985, 0.24899794314659673, 0.2537253407987472, 0.2667177795079635, 0.27159190516490006, 0.2739938529235548, 0.28534428578703896, 0.2944158618048994, 0.296036538664718, 0.3051364088814128, 0.30585375151301186, 0.3114712539863804, 0.3231750610081745, 0.32869374687050157, 0.33639590659276175, 0.3600766362333053, 0.36936304600612724, 0.38870280983958583, 0.39144436569161134, 0.40342283197779727, 0.4094218353587008, 0.40975383525297016, 0.42073125331890115, 0.4273123318932347, 0.42950521730777025, 0.4830878559436823, 0.5081765563442366, 0.5437595540422571, 0.5590205548347534, 0.5593249815276734, 0.5603062368164834, 0.560333188635217, 0.5614503754617461, 0.565352842229935, 0.574210838214554, 0.5759450483859969, 0.5773498217058918, 0.5991138115095911, 0.6009475544728957, 0.6108938307533, 0.6316565296547284, 0.6404495093354053, 0.6405262429561641, 0.6425694115212065, 0.658671129040488, 0.6668423897406515, 0.6864391962767343, 0.7035635283169166, 0.7325106678655877, 0.7328050041291218, 0.7614304100703713, 0.7631239070049998, 0.7670021786149205, 0.7697753383420857, 0.7764360990307122, 0.7784918983501654, 0.7973920072996036, 0.819715865079681, 0.8506721053047003, 0.8813167497816289, 0.8824879447595726, 0.9185813970744787, 0.9231889896940375, 0.9237877978193884, 0.9255031346434324, 0.9293883502480845, 0.9294097332465232, 0.9463098243875633, 0.946325164889271, 0.9491397432856566, 0.9567595541247681, 0.9706712283358269, 0.9723580396501548, 0.9748360509016578, 0.9800193410444061, 0.980809631269599, 0.991517828651004, 0.9965400387585364] [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB, 0og6hSkhbX8AC1ktFS4kounvTzy8Vo, 1aOcrEGd0cOqZe2I5XBOm0nDcwtBZO, 2T3wSlHdEmASmO0xcXHnndkKEt6bz8, 3BEOHQsMEFZ58VcNTOJYShTBpAPzbt, 4HX6feIvmNXBN7XGqgO4YVBkhu8GDI, 4JznSdBajNWhu4hRQwjV1FjTTxY68i, 52mKlRE3aHCBZtjECq6sY9OqVf8Dze, 56MZa5O1hVtX4c5sbnCfxuX5kDChqI, 6FPJlLAcaQ5uokyOWZ9HGdLZObFvOZ, 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW, 6oIXZuIPIqEoPBvFmbt2Nxy3tryGUE, 6x93sxYioWuq5c9Kkk8oTAAORM7cH0, 802bgTGl6Bk5TlkPYYTxp5JkKyaYUA, 8LIh0b6jmDGm87BmIyjdxNIpX4ugjD, 90gAtmGEeIqUTbo1ZrxCvWtsseukXC, 9UbObCsVkmYpJGcGrgfK90qOnwb2Lj, AFGCj7OWlEB5QfniEFgonMq90Tq5uH, ALuRhobVWbnQTTWZdSOk0iVe8oYFhW, Amn2K87Db5Es3dFQO9cw9cvpAM6h35, AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz, BJqx5WokrmrrezZA0dUbleMYkG5U2O, BPtQMxnuSPpxMExYV9YkDa6cAN7GP3, BsM5ZAYifRh5Lw3Y8X1r53I0cTJnfE, C2GT5KVyOPZpgKVl110TyZO0NcJ434, DuJNG8tufSqW0ZstHqWj3aGvFLMg4A, EcCuckwsF3gV1Ecgmh5v4KM8g1ozif, ErJFw6hzZ5fmI5r8bhE4JzlscnhKZU, F7NSTjWvQJyBburN7CXRUlbgp2dIrA, Fi4rJeTQq4eXj8Lxg3Hja5hBVTVV5u, H5j5ZHy1FGesOAHjkQEDYCucbpKWRu, HKSMQ9nTnwXCJIte1JrM1dtYnDtJ8g, IWl0G3ZlMNf7WT8yjIB49cx7MmYOmr, IZTkHMLvIKuiLjhDjYMmIHxh166we4, Ig1QcuKsjHXkproePdERo2w0mYzIqd, JHNgc2UCaiXOdmkxwDDyGhRlO0mnBQ, JN0VclewmjwYlSl8386MlWv5rEhWCz, JafwVLSVk5AVoXFuzclesQ000EE2k1, KJFcmTVjdkCMv94wYCtfHMFhzyRsmH, Ktb7GQ0N1DrxwkCkEUsTaIXk0xYinn, Ld2ej8NEv5zNcqU60FwpHeZKBhfpiV, LiEBxds3X0Uw0lxiYjDqrkAaAwoiIW, MXhhH1Var3OzzJCtI9VNyYvA0q8UyJ, MeSTAXq8gVxVjbEjgkvU9YLte0X9uE, NEhyk8uIx4kEULJGa8qIyFjjBcP2G6, O66j6PaYuZhEUtqV6fuU7TyjM2WxC5, OF7fQ37GzaZ5ikA2oMyvleKtgnLjXh, OPwBqCEK5PWTjWaiOyL45u2NLTaDWv, Oq6J4Rx6nde0YlhOIJkFsX2MsSvAQ0, Ow5PGpfTm4dXCfTDsXAOTatXRoAydR, QEHVvcP8gxI6EMJIrvcnIhgzPNjIvv, QJYm7YRA3YetcBHI5wkMZeLXVmfuNy, QYlaIAnJA6r8rlAb6f59wcxvcPcWFf, RilTlL1tKkPOUFuzmLydHAVZwv1OGl, Sfx0vxv1skzZWT1PqVdoRDdO6Sb6xH, TTQUwpMNSXZqVBKAFvXu7OlWvKXJKX, TtDKUZxzVxsq758G6AWPSYuZgVgbcl, VDhtJkYjAYPykCgOU9x3v7v3t4SO1a, VY0zXmXeksCT8BzvpzpPLbmU9Kp9Y4, Vp3gmWunM5A7wOC9YW2JroFqTWjvTi, WHmjWk2AY4c6m7DA4GitUx6nmb1yYS, XemNcT1xp61xcM1Qz3wZ1VECCnq06O, Z2sWcQr0qyCJRMHDpRy3aQr7PkHtkK, aDxBtor7Icd9C5hnTvvw5NrIre740e, akiiY5N0I44CMwEnBL6RTBk7BRkxEj, b3b9esRhTzFEawbs6XhpKnD9ojutHB, bgK1r6v3BCTh0aejJUhkA1Hn6idXGp, cBGc0kSm32ylBDnxogG727C0uhZEYZ, cq4WSAIFwx3wwTUS5bp1wCe71R6U5I, dVdvo6nUD5FgCgsbOZLds28RyGTpnx, e2Gh6Ov8XkXoFdJWhl0EjwEHlMDYyG, f9ALCzwDAKmdu7Rk2msJaB1wxe5IBX, fuyvs0w7WsKSlXqJ1e6HFSoLmx03AG, gTpyQnEODMcpsPnJMZC66gh33i3m0b, gpo8K5qtYePve6jyPt6xgJx4YOVjms, gxfHWUF8XgY2KdFxigxvNEXe2V2XMl, i6RQVXKUh7MzuGMDaNclUYnFUAireU, ioEncce3mPOXD2hWhpZpCPWGATG6GU, jQimhdepw3GKmioWUlVSWeBVRKFkY3, l7uwDoTepWwnAP0ufqtHJS3CRi7RfP, lqhzgLsXZ8JhtpeeUWWNbMz8PHI705, m6jD0LBIQWaMfenwRCTANI9eOdyyto, mhjME0zBHbrK6NMkytMTQzOssOa1gF, mzbkwXKrPeZnxg2Kn1LRF5hYSsmksS, nYVJnVicpGRqKZibHyBAmtmzBXAFfT, oHJMNvWuunsIMIWFnYG31RCfkOo2V7, oLZ21P2JEDooxV1pU31cIxQHEeeoLu, okOkcWflkNXIy4R8LzmySyY1EC3sYd, pLk3i59bZwd5KBZrI1FiweYTd5hteG, pTeu0WMjBRTaNRT15rLCuEh3tBJVc5, qnPOOmslCJaT45buUisMRnM0rc77EK, t6fQUjJejPcjc04wHvHTPe55S65B4V, ukOiFGGFnQJDHFgZxHMpvhD3zybF0M, ukyD7b0Efj7tNlFSRmzZ0IqkEzg2a8, waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs, wwXqSGKLyBQyPkonlzBNYUJTCo4LRS, xipQ93429ksjNcXPX5326VSg1xJZcW, y7C453hRWd4E7ImjNDWlpexB8nUqjh, ydkwycaISlYSlEq3TlkS2m15I2pcp8] + +query ?? +SELECT + array_agg(DISTINCT c12 ORDER BY c12), + array_agg(DISTINCT c13 ORDER BY c13) +FROM aggregate_test_100 +---- +[0.01479305307777301, 0.02182578039211991, 0.03968347085780355, 0.04429073092078406, 0.047343434291126085, 0.04893135681998029, 0.0494924465469434, 0.05573662213439634, 0.05636955101974106, 0.061029375346466685, 0.07260475960924484, 0.09465635123783445, 0.12357539988406441, 0.152498292971736, 0.16301110515739792, 0.1640882545084913, 0.1754261586710173, 0.17592486905979987, 0.17909035118828576, 0.18628859265874176, 0.19113293583306745, 0.2145232647388039, 0.21535402343780985, 0.24899794314659673, 0.2537253407987472, 0.2667177795079635, 0.27159190516490006, 0.2739938529235548, 0.28534428578703896, 0.2944158618048994, 0.296036538664718, 0.3051364088814128, 0.30585375151301186, 0.3114712539863804, 0.3231750610081745, 0.32869374687050157, 0.33639590659276175, 0.3600766362333053, 0.36936304600612724, 0.38870280983958583, 0.39144436569161134, 0.40342283197779727, 0.4094218353587008, 0.40975383525297016, 0.42073125331890115, 0.4273123318932347, 0.42950521730777025, 0.4830878559436823, 0.5081765563442366, 0.5437595540422571, 0.5590205548347534, 0.5593249815276734, 0.5603062368164834, 0.560333188635217, 0.5614503754617461, 0.565352842229935, 0.574210838214554, 0.5759450483859969, 0.5773498217058918, 0.5991138115095911, 0.6009475544728957, 0.6108938307533, 0.6316565296547284, 0.6404495093354053, 0.6405262429561641, 0.6425694115212065, 0.658671129040488, 0.6668423897406515, 0.6864391962767343, 0.7035635283169166, 0.7325106678655877, 0.7328050041291218, 0.7614304100703713, 0.7631239070049998, 0.7670021786149205, 0.7697753383420857, 0.7764360990307122, 0.7784918983501654, 0.7973920072996036, 0.819715865079681, 0.8506721053047003, 0.8813167497816289, 0.8824879447595726, 0.9185813970744787, 0.9231889896940375, 0.9237877978193884, 0.9255031346434324, 0.9293883502480845, 0.9294097332465232, 0.9463098243875633, 0.946325164889271, 0.9491397432856566, 0.9567595541247681, 0.9706712283358269, 0.9723580396501548, 0.9748360509016578, 0.9800193410444061, 0.980809631269599, 0.991517828651004, 0.9965400387585364] [0VVIHzxWtNOFLtnhjHEKjXaJOSLJfm, 0keZ5G8BffGwgF2RwQD59TFzMStxCB, 0og6hSkhbX8AC1ktFS4kounvTzy8Vo, 1aOcrEGd0cOqZe2I5XBOm0nDcwtBZO, 2T3wSlHdEmASmO0xcXHnndkKEt6bz8, 3BEOHQsMEFZ58VcNTOJYShTBpAPzbt, 4HX6feIvmNXBN7XGqgO4YVBkhu8GDI, 4JznSdBajNWhu4hRQwjV1FjTTxY68i, 52mKlRE3aHCBZtjECq6sY9OqVf8Dze, 56MZa5O1hVtX4c5sbnCfxuX5kDChqI, 6FPJlLAcaQ5uokyOWZ9HGdLZObFvOZ, 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW, 6oIXZuIPIqEoPBvFmbt2Nxy3tryGUE, 6x93sxYioWuq5c9Kkk8oTAAORM7cH0, 802bgTGl6Bk5TlkPYYTxp5JkKyaYUA, 8LIh0b6jmDGm87BmIyjdxNIpX4ugjD, 90gAtmGEeIqUTbo1ZrxCvWtsseukXC, 9UbObCsVkmYpJGcGrgfK90qOnwb2Lj, AFGCj7OWlEB5QfniEFgonMq90Tq5uH, ALuRhobVWbnQTTWZdSOk0iVe8oYFhW, Amn2K87Db5Es3dFQO9cw9cvpAM6h35, AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz, BJqx5WokrmrrezZA0dUbleMYkG5U2O, BPtQMxnuSPpxMExYV9YkDa6cAN7GP3, BsM5ZAYifRh5Lw3Y8X1r53I0cTJnfE, C2GT5KVyOPZpgKVl110TyZO0NcJ434, DuJNG8tufSqW0ZstHqWj3aGvFLMg4A, EcCuckwsF3gV1Ecgmh5v4KM8g1ozif, ErJFw6hzZ5fmI5r8bhE4JzlscnhKZU, F7NSTjWvQJyBburN7CXRUlbgp2dIrA, Fi4rJeTQq4eXj8Lxg3Hja5hBVTVV5u, H5j5ZHy1FGesOAHjkQEDYCucbpKWRu, HKSMQ9nTnwXCJIte1JrM1dtYnDtJ8g, IWl0G3ZlMNf7WT8yjIB49cx7MmYOmr, IZTkHMLvIKuiLjhDjYMmIHxh166we4, Ig1QcuKsjHXkproePdERo2w0mYzIqd, JHNgc2UCaiXOdmkxwDDyGhRlO0mnBQ, JN0VclewmjwYlSl8386MlWv5rEhWCz, JafwVLSVk5AVoXFuzclesQ000EE2k1, KJFcmTVjdkCMv94wYCtfHMFhzyRsmH, Ktb7GQ0N1DrxwkCkEUsTaIXk0xYinn, Ld2ej8NEv5zNcqU60FwpHeZKBhfpiV, LiEBxds3X0Uw0lxiYjDqrkAaAwoiIW, MXhhH1Var3OzzJCtI9VNyYvA0q8UyJ, MeSTAXq8gVxVjbEjgkvU9YLte0X9uE, NEhyk8uIx4kEULJGa8qIyFjjBcP2G6, O66j6PaYuZhEUtqV6fuU7TyjM2WxC5, OF7fQ37GzaZ5ikA2oMyvleKtgnLjXh, OPwBqCEK5PWTjWaiOyL45u2NLTaDWv, Oq6J4Rx6nde0YlhOIJkFsX2MsSvAQ0, Ow5PGpfTm4dXCfTDsXAOTatXRoAydR, QEHVvcP8gxI6EMJIrvcnIhgzPNjIvv, QJYm7YRA3YetcBHI5wkMZeLXVmfuNy, QYlaIAnJA6r8rlAb6f59wcxvcPcWFf, RilTlL1tKkPOUFuzmLydHAVZwv1OGl, Sfx0vxv1skzZWT1PqVdoRDdO6Sb6xH, TTQUwpMNSXZqVBKAFvXu7OlWvKXJKX, TtDKUZxzVxsq758G6AWPSYuZgVgbcl, VDhtJkYjAYPykCgOU9x3v7v3t4SO1a, VY0zXmXeksCT8BzvpzpPLbmU9Kp9Y4, Vp3gmWunM5A7wOC9YW2JroFqTWjvTi, WHmjWk2AY4c6m7DA4GitUx6nmb1yYS, XemNcT1xp61xcM1Qz3wZ1VECCnq06O, Z2sWcQr0qyCJRMHDpRy3aQr7PkHtkK, aDxBtor7Icd9C5hnTvvw5NrIre740e, akiiY5N0I44CMwEnBL6RTBk7BRkxEj, b3b9esRhTzFEawbs6XhpKnD9ojutHB, bgK1r6v3BCTh0aejJUhkA1Hn6idXGp, cBGc0kSm32ylBDnxogG727C0uhZEYZ, cq4WSAIFwx3wwTUS5bp1wCe71R6U5I, dVdvo6nUD5FgCgsbOZLds28RyGTpnx, e2Gh6Ov8XkXoFdJWhl0EjwEHlMDYyG, f9ALCzwDAKmdu7Rk2msJaB1wxe5IBX, fuyvs0w7WsKSlXqJ1e6HFSoLmx03AG, gTpyQnEODMcpsPnJMZC66gh33i3m0b, gpo8K5qtYePve6jyPt6xgJx4YOVjms, gxfHWUF8XgY2KdFxigxvNEXe2V2XMl, i6RQVXKUh7MzuGMDaNclUYnFUAireU, ioEncce3mPOXD2hWhpZpCPWGATG6GU, jQimhdepw3GKmioWUlVSWeBVRKFkY3, l7uwDoTepWwnAP0ufqtHJS3CRi7RfP, lqhzgLsXZ8JhtpeeUWWNbMz8PHI705, m6jD0LBIQWaMfenwRCTANI9eOdyyto, mhjME0zBHbrK6NMkytMTQzOssOa1gF, mzbkwXKrPeZnxg2Kn1LRF5hYSsmksS, nYVJnVicpGRqKZibHyBAmtmzBXAFfT, oHJMNvWuunsIMIWFnYG31RCfkOo2V7, oLZ21P2JEDooxV1pU31cIxQHEeeoLu, okOkcWflkNXIy4R8LzmySyY1EC3sYd, pLk3i59bZwd5KBZrI1FiweYTd5hteG, pTeu0WMjBRTaNRT15rLCuEh3tBJVc5, qnPOOmslCJaT45buUisMRnM0rc77EK, t6fQUjJejPcjc04wHvHTPe55S65B4V, ukOiFGGFnQJDHFgZxHMpvhD3zybF0M, ukyD7b0Efj7tNlFSRmzZ0IqkEzg2a8, waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs, wwXqSGKLyBQyPkonlzBNYUJTCo4LRS, xipQ93429ksjNcXPX5326VSg1xJZcW, y7C453hRWd4E7ImjNDWlpexB8nUqjh, ydkwycaISlYSlEq3TlkS2m15I2pcp8] + statement ok CREATE EXTERNAL TABLE agg_order ( c1 INT NOT NULL, @@ -289,17 +352,19 @@ CREATE TABLE array_agg_distinct_list_table AS VALUES ('b', [1,0]), ('b', [1,0]), ('b', [1,0]), - ('b', [0,1]) + ('b', [0,1]), + (NULL, [0,1]), + ('b', NULL) ; # Apply array_sort to have deterministic result, higher dimension nested array also works but not for array sort, # so they are covered in `datafusion/functions-aggregate/src/array_agg.rs` query ?? select array_sort(c1), array_sort(c2) from ( - select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2 from array_agg_distinct_list_table + select array_agg(distinct column1) as c1, array_agg(distinct column2) ignore nulls as c2 from array_agg_distinct_list_table ); ---- -[b, w] [[0, 1], [1, 0]] +[NULL, b, w] [[0, 1], [1, 0]] statement ok drop table array_agg_distinct_list_table; @@ -347,15 +412,15 @@ logical_plan 04)------SubqueryAlias: a 05)--------Union 06)----------Projection: Int64(1) AS id, Int64(2) AS foo -07)------------EmptyRelation +07)------------EmptyRelation: rows=1 08)----------Projection: Int64(1) AS id, Int64(NULL) AS foo -09)------------EmptyRelation +09)------------EmptyRelation: rows=1 10)----------Projection: Int64(1) AS id, Int64(NULL) AS foo -11)------------EmptyRelation +11)------------EmptyRelation: rows=1 12)----------Projection: Int64(1) AS id, Int64(3) AS foo -13)------------EmptyRelation +13)------------EmptyRelation: rows=1 14)----------Projection: Int64(1) AS id, Int64(2) AS foo -15)------------EmptyRelation +15)------------EmptyRelation: rows=1 physical_plan 01)ProjectionExec: expr=[array_length(array_agg(DISTINCT a.foo)@1) as array_length(array_agg(DISTINCT a.foo)), sum(DISTINCT Int64(1))@2 as sum(DISTINCT Int64(1))] 02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))], ordering_mode=Sorted @@ -492,7 +557,7 @@ SELECT corr(c2, c12) FROM aggregate_test_100 query R select corr(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq ---- -0 +NULL # all_nulls_query_correlation query R @@ -660,10 +725,6 @@ SELECT c2, var_samp(c12) FROM aggregate_test_100 WHERE c12 > 0.90 GROUP BY c2 OR 4 NULL 5 0.000269544643 -# Use PostgresSQL dialect -statement ok -set datafusion.sql_parser.dialect = 'Postgres'; - # csv_query_stddev_12 query IR SELECT c2, var_samp(c12) FILTER (WHERE c12 > 0.95) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2 @@ -674,9 +735,30 @@ SELECT c2, var_samp(c12) FILTER (WHERE c12 > 0.95) FROM aggregate_test_100 GROUP 4 NULL 5 NULL -# Restore the default dialect statement ok -set datafusion.sql_parser.dialect = 'Generic'; +CREATE TABLE t ( + a DOUBLE, + b BIGINT, + c INT +) AS VALUES +(1.0, 10, -5), +(2.0, 20, -5), +(3.0, 20, 4); + +# https://github.com/apache/datafusion/issues/15291 +query III +WITH s AS ( + SELECT + COUNT(a) FILTER (WHERE (b * b) - 3600 <= b), + COUNT(a) FILTER (WHERE (b * b) - 3000 <= b AND (c >= 0)), + COUNT(a) FILTER (WHERE (b * b) - 3000 <= b AND (c >= 0) AND (c >= 0)) + FROM t +) SELECT * FROM s +---- +3 1 1 + +statement ok +DROP TABLE t # csv_query_stddev_13 query IR @@ -1244,12 +1326,24 @@ SELECT COUNT(2) FROM aggregate_test_100 # ---- # 100 99 +# csv_query_approx_count_literal_null +query I +SELECT approx_distinct(null) +---- +0 + # csv_query_approx_count_dupe_expr_aliased query II SELECT approx_distinct(c9) AS a, approx_distinct(c9) AS b FROM aggregate_test_100 ---- 100 100 +# csv_query_approx_count_date_timestamp +query IIIII +SELECT approx_distinct(c14) AS a, approx_distinct(c15) AS b, approx_distinct(arrow_cast(c15, 'Date64')), approx_distinct(arrow_cast(c15, 'Time32(Second)')) as c, approx_distinct(arrow_cast(c15, 'Time64(Nanosecond)')) AS d FROM aggregate_test_100 +---- +18 60 60 60 60 + ## This test executes the APPROX_PERCENTILE_CONT aggregation against the test ## data, asserting the estimated quantiles are ±5% their actual values. ## @@ -1274,7 +1368,24 @@ SELECT approx_distinct(c9) AS a, approx_distinct(c9) AS b FROM aggregate_test_10 ## Column `c12` is omitted due to a large relative error (~10%) due to the small ## float values. -#csv_query_approx_percentile_cont (c2) +# csv_query_approx_percentile_cont (c2) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c2) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c2) AS DOUBLE) / 3.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c2) AS DOUBLE) / 5.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + + +# csv_query_approx_percentile_cont (c2, alternate syntax, should be the same as above) query B SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.1) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 ---- @@ -1292,157 +1403,170 @@ true # csv_query_approx_percentile_cont (c3) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.1) AS DOUBLE) / -95.3) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c3) AS DOUBLE) / -95.3) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.5) AS DOUBLE) / 15.5) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c3) AS DOUBLE) / 15.5) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.9) AS DOUBLE) / 102.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c3) AS DOUBLE) / 102.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c4) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.1) AS DOUBLE) / -22925.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c4) AS DOUBLE) / -22925.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.5) AS DOUBLE) / 4599.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c4) AS DOUBLE) / 4599.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.9) AS DOUBLE) / 25334.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c4) AS DOUBLE) / 25334.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c5) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.1) AS DOUBLE) / -1882606710.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c5) AS DOUBLE) / -1882606710.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.5) AS DOUBLE) / 377164262.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c5) AS DOUBLE) / 377164262.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.9) AS DOUBLE) / 1991374996.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c5) AS DOUBLE) / 1991374996.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c6) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.1) AS DOUBLE) / -7250000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c6) AS DOUBLE) / -7250000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.5) AS DOUBLE) / 1130000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c6) AS DOUBLE) / 1130000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.9) AS DOUBLE) / 7370000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c6) AS DOUBLE) / 7370000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c7) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.1) AS DOUBLE) / 18.9) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c7) AS DOUBLE) / 18.9) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.5) AS DOUBLE) / 134.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c7) AS DOUBLE) / 134.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.9) AS DOUBLE) / 231.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c7) AS DOUBLE) / 231.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c8) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.1) AS DOUBLE) / 2671.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c8) AS DOUBLE) / 2671.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.5) AS DOUBLE) / 30634.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c8) AS DOUBLE) / 30634.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.9) AS DOUBLE) / 57518.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c8) AS DOUBLE) / 57518.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c9) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.1) AS DOUBLE) / 472608672.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c9) AS DOUBLE) / 472608672.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.5) AS DOUBLE) / 2365817608.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c9) AS DOUBLE) / 2365817608.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.9) AS DOUBLE) / 3776538487.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c9) AS DOUBLE) / 3776538487.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c10) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.1) AS DOUBLE) / 1830000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c10) AS DOUBLE) / 1830000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.5) AS DOUBLE) / 9300000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c10) AS DOUBLE) / 9300000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.9) AS DOUBLE) / 16100000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c10) AS DOUBLE) / 16100000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c11) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.1) AS DOUBLE) / 0.109) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c11) AS DOUBLE) / 0.109) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.5) AS DOUBLE) / 0.491) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c11) AS DOUBLE) / 0.491) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.9) AS DOUBLE) / 0.834) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c11) AS DOUBLE) / 0.834) < 0.05) AS q FROM aggregate_test_100 ---- true # percentile_cont_with_nulls query I -SELECT APPROX_PERCENTILE_CONT(v, 0.5) FROM (VALUES (1), (2), (3), (NULL), (NULL), (NULL)) as t (v); +SELECT APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY v) FROM (VALUES (1), (2), (3), (NULL), (NULL), (NULL)) as t (v); ---- 2 # percentile_cont_with_nulls_only query I -SELECT APPROX_PERCENTILE_CONT(v, 0.5) FROM (VALUES (CAST(NULL as INT))) as t (v); +SELECT APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY v) FROM (VALUES (CAST(NULL as INT))) as t (v); +---- +NULL + +# percentile_cont_with_weight_with_nulls +query I +SELECT APPROX_PERCENTILE_CONT_WITH_WEIGHT(w, 0.5) WITHIN GROUP (ORDER BY v) +FROM (VALUES (1, 1), (2, 1), (3, 1), (4, NULL), (NULL, 1), (NULL, NULL)) as t (v, w); +---- +2 + +# percentile_cont_with_weight_nulls_only +query I +SELECT APPROX_PERCENTILE_CONT_WITH_WEIGHT(1, 0.5) WITHIN GROUP (ORDER BY v) FROM (VALUES (CAST(NULL as INT))) as t (v); ---- NULL @@ -1465,7 +1589,7 @@ NaN # ISSUE: https://github.com/apache/datafusion/issues/11870 query R -select APPROX_PERCENTILE_CONT(v2, 0.8) from tmp_percentile_cont; +select APPROX_PERCENTILE_CONT(0.8) WITHIN GROUP (ORDER BY v2) from tmp_percentile_cont; ---- NaN @@ -1473,10 +1597,10 @@ NaN # Note: `approx_percentile_cont_with_weight()` uses the same implementation as `approx_percentile_cont()` query R SELECT APPROX_PERCENTILE_CONT_WITH_WEIGHT( - v2, '+Inf'::Double, 0.9 ) +WITHIN GROUP (ORDER BY v2) FROM tmp_percentile_cont; ---- NaN @@ -1495,7 +1619,7 @@ INSERT INTO t1 VALUES (TRUE); # ISSUE: https://github.com/apache/datafusion/issues/12716 # This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN' and returns 'inf' query R -SELECT approx_percentile_cont_with_weight('NaN'::DOUBLE, 0, 0) FROM t1 WHERE t1.v1; +SELECT approx_percentile_cont_with_weight(0, 0) WITHIN GROUP (ORDER BY 'NaN'::DOUBLE) FROM t1 WHERE t1.v1; ---- Infinity @@ -1722,6 +1846,17 @@ b NULL NULL 7732.315789473684 # csv_query_approx_percentile_cont_with_weight query TI +SELECT c1, approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 73 +b 68 +c 122 +d 124 +e 115 + + +# csv_query_approx_percentile_cont_with_weight (should be the same as above) +query TI SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 73 @@ -1730,8 +1865,50 @@ c 122 d 124 e 115 + +# using approx_percentile_cont on 2 columns with same signature +query TII +SELECT c1, approx_percentile_cont(c2, 0.95) AS c2, approx_percentile_cont(c3, 0.95) AS c3 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 5 73 +b 5 68 +c 5 122 +d 5 124 +e 5 115 + +# error is unique to this UDAF +query TRR +SELECT c1, avg(c2) AS c2, avg(c3) AS c3 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 2.857142857143 -18.333333333333 +b 3.263157894737 -5.842105263158 +c 2.666666666667 -1.333333333333 +d 2.444444444444 25.444444444444 +e 3 40.333333333333 + + + +query TI +SELECT c1, approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a -101 +b -114 +c -109 +d -98 +e -93 + # csv_query_approx_percentile_cont_with_weight (2) query TI +SELECT c1, approx_percentile_cont_with_weight(1, 0.95) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 73 +b 68 +c 122 +d 124 +e 115 + +# csv_query_approx_percentile_cont_with_weight alternate syntax +query TI SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 73 @@ -1740,9 +1917,19 @@ c 122 d 124 e 115 + +query TI +SELECT c1, approx_percentile_cont_with_weight(1, 0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a -101 +b -114 +c -109 +d -98 +e -93 + # csv_query_approx_percentile_cont_with_histogram_bins query TI -SELECT c1, approx_percentile_cont(c3, 0.95, 200) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont(0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 73 b 68 @@ -1751,7 +1938,17 @@ d 124 e 115 query TI -SELECT c1, approx_percentile_cont_with_weight(c3, c2, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont_with_weight(c2, 0.95) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 74 +b 68 +c 123 +d 124 +e 115 + +# approx_percentile_cont_with_weight with centroids +query TI +SELECT c1, approx_percentile_cont_with_weight(c2, 0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 74 b 68 @@ -2237,7 +2434,7 @@ drop table t; # test count with largeutf8 statement ok -create table t (c string) as values +create table t (c string) as values (arrow_cast('a', 'LargeUtf8')), (arrow_cast('b', 'LargeUtf8')), (arrow_cast(null, 'LargeUtf8')), @@ -2247,10 +2444,10 @@ create table t (c string) as values query T select arrow_typeof(c) from t; ---- -Utf8 -Utf8 -Utf8 -Utf8 +Utf8View +Utf8View +Utf8View +Utf8View query IT select count(c), arrow_typeof(count(c)) from t; @@ -2517,7 +2714,117 @@ select covar_samp(c1, c2), arrow_typeof(covar_samp(c1, c2)) from t; statement ok drop table t; +# correlation_f64_1 +statement ok +create table t (c1 double, c2 double) as values (1, 4), (2, 5), (3, 6); + +query RT rowsort +select corr(c1, c2), arrow_typeof(corr(c1, c2)) from t; +---- +1 Float64 + +# correlation with different numeric types (create test data) +statement ok +CREATE OR REPLACE TABLE corr_test( + int8_col TINYINT, + int16_col SMALLINT, + int32_col INT, + int64_col BIGINT, + uint32_col INT UNSIGNED, + float32_col FLOAT, + float64_col DOUBLE +) as VALUES +(1, 10, 100, 1000, 10000, 1.1, 10.1), +(2, 20, 200, 2000, 20000, 2.2, 20.2), +(3, 30, 300, 3000, 30000, 3.3, 30.3), +(4, 40, 400, 4000, 40000, 4.4, 40.4), +(5, 50, 500, 5000, 50000, 5.5, 50.5); + +# correlation using int32 and float64 +query R +SELECT corr(int32_col, float64_col) FROM corr_test; +---- +1 + +# correlation using int64 and int32 +query R +SELECT corr(int64_col, int32_col) FROM corr_test; +---- +1 + +# correlation using float32 and int8 +query R +SELECT corr(float32_col, int8_col) FROM corr_test; +---- +1 + +# correlation using uint32 and int16 +query R +SELECT corr(uint32_col, int16_col) FROM corr_test; +---- +1 + +# correlation with nulls +statement ok +CREATE OR REPLACE TABLE corr_nulls( + x INT, + y DOUBLE +) as VALUES +(1, 10.0), +(2, 20.0), +(NULL, 30.0), +(4, NULL), +(5, 50.0); + +# correlation with some nulls (should skip null pairs) +query R +SELECT corr(x, y) FROM corr_nulls; +---- +1 + +# correlation with single row (should return NULL) +statement ok +CREATE OR REPLACE TABLE corr_single_row( + x INT, + y DOUBLE +) as VALUES +(1, 10.0); + +query R +SELECT corr(x, y) FROM corr_single_row; +---- +NULL + +# correlation with all nulls +statement ok +CREATE OR REPLACE TABLE corr_all_nulls( + x INT, + y DOUBLE +) as VALUES +(NULL, NULL), +(NULL, NULL); + +query R +SELECT corr(x, y) FROM corr_all_nulls; +---- +NULL + +statement ok +drop table corr_test; + +statement ok +drop table corr_nulls; + +statement ok +drop table corr_single_row; + +statement ok +drop table corr_all_nulls; + # covariance_f64_4 +statement ok +drop table if exists t; + statement ok create table t (c1 double, c2 double) as values (1.1, 4.1), (2.0, 5.0), (3.0, 6.0); @@ -3041,7 +3348,7 @@ SELECT COUNT(DISTINCT c1) FROM test # test_approx_percentile_cont_decimal_support query TI -SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont(cast(0.85 as decimal(10,2))) WITHIN GROUP (ORDER BY c2) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 4 b 5 @@ -3194,6 +3501,33 @@ select array_agg(column1) from t; statement ok drop table t; +# array_agg_ignore_nulls +statement ok +create table t as values (NULL, ''), (1, 'c'), (2, 'a'), (NULL, 'b'), (4, NULL), (NULL, NULL), (5, 'a'); + +query ? +select array_agg(column1) ignore nulls as c1 from t; +---- +[1, 2, 4, 5] + +query II +select count(*), array_length(array_agg(distinct column2) ignore nulls) from t; +---- +7 4 + +query ? +select array_agg(column2 order by column1) ignore nulls from t; +---- +[c, a, a, , b] + +query ? +select array_agg(DISTINCT column2 order by column2) ignore nulls from t; +---- +[, a, b, c] + +statement ok +drop table t; + # variance_single_value query RRRR select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0)) as sq; @@ -4183,6 +4517,50 @@ DROP VIEW binary_views statement ok DROP TABLE strings; +############ FixedSizeBinary ############ + +statement ok +CREATE TABLE binaries +AS VALUES + (X'000103', 1), + (X'000104', 1), + (X'000101', 3), + (X'000103', 1), + (X'000102', 1), + (NULL, 1), + (NULL, 4), + (X'000104', 1), + (X'000109', 2), + (X'000103', 1), + (X'000101', 2); + +statement ok +CREATE VIEW fixed_size_binary_views +AS SELECT arrow_cast(column1, 'FixedSizeBinary(3)') as value, column2 as id FROM binaries; + +query I? +SELECT id, MIN(value) FROM fixed_size_binary_views GROUP BY id ORDER BY id; +---- +1 000102 +2 000101 +3 000101 +4 NULL + +query I? +SELECT id, MAX(value) FROM fixed_size_binary_views GROUP BY id ORDER BY id; +---- +1 000104 +2 000109 +3 000101 +4 NULL + +statement ok +DROP VIEW fixed_size_binary_views; + +statement ok +DROP TABLE binaries; + + ################# # End min_max on strings/binary with null values and groups ################# @@ -4408,9 +4786,7 @@ statement ok create table t as select arrow_cast(column1, 'Date32') as date32, - -- Workaround https://github.com/apache/arrow-rs/issues/4512 is fixed, can use this - -- arrow_cast(column1, 'Date64') as date64, - arrow_cast(arrow_cast(column1, 'Date32'), 'Date64') as date64, + arrow_cast(column1, 'Date64') as date64, column2 as names, column3 as tag from t_source; @@ -4736,7 +5112,7 @@ statement ok create table t (c1 decimal(10, 0), c2 int) as values (null, null), (null, null), (null, null); query RTIT -select +select sum(c1), arrow_typeof(sum(c1)), sum(c2), arrow_typeof(sum(c2)) from t; @@ -4813,10 +5189,6 @@ select c2, count(DISTINCT cast(c1 AS DECIMAL(10, 2))) from d_table GROUP BY c2 O A 2 B 2 -# Use PostgresSQL dialect -statement ok -set datafusion.sql_parser.dialect = 'Postgres'; - # Creating the table statement ok CREATE TABLE test_table (c1 INT, c2 INT, c3 INT) @@ -4946,29 +5318,130 @@ select c3, count(c2), avg(c2), sum(c2), min(c2), max(c2), count(c4), sum(c4) fro 700.1 2 15.15 30.3 10.1 20.2 0 NULL NULL 1 10.1 10.1 10.1 10.1 0 NULL -# Restore the default dialect -statement ok -set datafusion.sql_parser.dialect = 'Generic'; - ## Multiple distinct aggregates and dictionaries statement ok -create table dict_test as values (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('bar', 'Dictionary(Int32, Utf8)')); +create table dict_test as values (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('bar', 'Dictionary(Int32, Utf8)')), (1, arrow_cast('bar', 'Dictionary(Int32, Utf8)')); query IT -select * from dict_test; +select * from dict_test order by column1, column2; ---- +1 bar +1 foo 1 foo 2 bar query II -select count(distinct column1), count(distinct column2) from dict_test group by column1; +select count(distinct column1), count(distinct column2) from dict_test group by column1 order by column1; ---- -1 1 +1 2 1 1 statement ok drop table dict_test; +## count distinct dictionary with null values +statement ok +create table dict_null_test as + select arrow_cast(NULL, 'Dictionary(Int32, Utf8)') as d + from (values (1), (2), (3), (4), (5)); + +query I +select count(distinct d) from dict_null_test; +---- +0 + +statement ok +drop table dict_null_test; + +# avg_duration + +statement ok +create table d as values + (arrow_cast(1, 'Duration(Second)'), arrow_cast(2, 'Duration(Millisecond)'), arrow_cast(3, 'Duration(Microsecond)'), arrow_cast(4, 'Duration(Nanosecond)'), 1), + (arrow_cast(11, 'Duration(Second)'), arrow_cast(22, 'Duration(Millisecond)'), arrow_cast(33, 'Duration(Microsecond)'), arrow_cast(44, 'Duration(Nanosecond)'), 1); + +query ???? +SELECT avg(column1), avg(column2), avg(column3), avg(column4) FROM d; +---- +0 days 0 hours 0 mins 6 secs 0 days 0 hours 0 mins 0.012 secs 0 days 0 hours 0 mins 0.000018 secs 0 days 0 hours 0 mins 0.000000024 secs + +query ????I +SELECT avg(column1), avg(column2), avg(column3), avg(column4), column5 FROM d GROUP BY column5; +---- +0 days 0 hours 0 mins 6 secs 0 days 0 hours 0 mins 0.012 secs 0 days 0 hours 0 mins 0.000018 secs 0 days 0 hours 0 mins 0.000000024 secs 1 + +statement ok +drop table d; + +statement ok +create table d as values + (arrow_cast(1, 'Duration(Second)'), arrow_cast(2, 'Duration(Millisecond)'), arrow_cast(3, 'Duration(Microsecond)'), arrow_cast(4, 'Duration(Nanosecond)'), 1), + (arrow_cast(11, 'Duration(Second)'), arrow_cast(22, 'Duration(Millisecond)'), arrow_cast(33, 'Duration(Microsecond)'), arrow_cast(44, 'Duration(Nanosecond)'), 1), + (arrow_cast(5, 'Duration(Second)'), arrow_cast(10, 'Duration(Millisecond)'), arrow_cast(15, 'Duration(Microsecond)'), arrow_cast(20, 'Duration(Nanosecond)'), 2), + (arrow_cast(25, 'Duration(Second)'), arrow_cast(50, 'Duration(Millisecond)'), arrow_cast(75, 'Duration(Microsecond)'), arrow_cast(100, 'Duration(Nanosecond)'), 2), + (NULL, NULL, NULL, NULL, 1), + (NULL, NULL, NULL, NULL, 2); + + +query I? rowsort +SELECT column5, avg(column1) FROM d GROUP BY column5; +---- +1 0 days 0 hours 0 mins 6 secs +2 0 days 0 hours 0 mins 15 secs + +query I?? rowsort +SELECT column5, column1, avg(column1) OVER (PARTITION BY column5 ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) as window_avg +FROM d WHERE column1 IS NOT NULL; +---- +1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 1 secs +1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 6 secs +2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 15 secs +2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 5 secs + +# Cumulative average window function +query I?? +SELECT column5, column1, avg(column1) OVER (ORDER BY column5, column1 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as cumulative_avg +FROM d WHERE column1 IS NOT NULL; +---- +1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 1 secs +1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 6 secs +2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 5 secs +2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 10 secs + +# Centered average window function +query I?? +SELECT column5, column1, avg(column1) OVER (ORDER BY column5 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as centered_avg +FROM d WHERE column1 IS NOT NULL; +---- +1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 6 secs +1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 5 secs +2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 13 secs +2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 15 secs + +statement ok +drop table d; + +statement ok +create table dn as values + (arrow_cast(10, 'Duration(Second)'), 'a', 1), + (arrow_cast(20, 'Duration(Second)'), 'a', 2), + (NULL, 'b', 1), + (arrow_cast(40, 'Duration(Second)'), 'b', 2), + (arrow_cast(50, 'Duration(Second)'), 'c', 1), + (NULL, 'c', 2); + +query T?I +SELECT column2, avg(column1), column3 FROM dn GROUP BY column2, column3 ORDER BY column2, column3; +---- +a 0 days 0 hours 0 mins 10 secs 1 +a 0 days 0 hours 0 mins 20 secs 2 +b NULL 1 +b 0 days 0 hours 0 mins 40 secs 2 +c 0 days 0 hours 0 mins 50 secs 1 +c NULL 2 + +statement ok +drop table dn; # Prepare the table with dictionary values for testing statement ok @@ -5005,8 +5478,10 @@ select avg(distinct x_dict) from value_dict; ---- 3 -query error +query RR select avg(x_dict), avg(distinct x_dict) from value_dict; +---- +2.625 3 query I select min(x_dict) from value_dict; @@ -5131,7 +5606,7 @@ physical_plan 08)--------------RepartitionExec: partitioning=Hash([c3@0], 4), input_partitions=4 09)----------------AggregateExec: mode=Partial, gby=[c3@1 as c3], aggr=[min(aggregate_test_100.c1)] 10)------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -11)--------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c3], file_type=csv, has_header=true +11)--------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c1, c3], file_type=csv, has_header=true # @@ -5156,16 +5631,16 @@ physical_plan 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], file_type=csv, has_header=true +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c3], file_type=csv, has_header=true query I -SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 order by c3 limit 5; ---- -1 --40 -29 --85 --82 +-117 +-111 +-107 +-106 +-101 query TT EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5 offset 4; @@ -5180,16 +5655,16 @@ physical_plan 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[9] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], file_type=csv, has_header=true +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2, c3], file_type=csv, has_header=true query II -SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5 offset 4; +SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 order by c2, c3 limit 5 offset 4; ---- -5 -82 -4 -111 -3 104 -3 13 -1 38 +1 -56 +1 -25 +1 -24 +1 -8 +1 -5 # The limit should only apply to the aggregations which group by c3 query TT @@ -5215,15 +5690,15 @@ physical_plan 10)------------------CoalesceBatchesExec: target_batch_size=8192 11)--------------------FilterExec: c3@1 >= 10 AND c3@1 <= 20 12)----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -13)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], file_type=csv, has_header=true +13)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2, c3], file_type=csv, has_header=true query I -SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c2, c3 limit 4; +SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c3 order by c3 limit 4; ---- -13 -17 12 +13 14 +17 # An aggregate expression causes the limit to not be pushed to the aggregation query TT @@ -5241,7 +5716,7 @@ physical_plan 04)------CoalescePartitionsExec 05)--------AggregateExec: mode=Partial, gby=[c2@1 as c2, c3@2 as c3], aggr=[max(aggregate_test_100.c1)] 06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], file_type=csv, has_header=true +07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c1, c2, c3], file_type=csv, has_header=true # TODO(msirek): Extend checking in LimitedDistinctAggregation equal groupings to ignore the order of columns # in the group-by column lists, so the limit could be pushed to the lowest AggregateExec in this case @@ -5265,14 +5740,14 @@ physical_plan 08)--------------CoalescePartitionsExec 09)----------------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[] 10)------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -11)--------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], file_type=csv, has_header=true +11)--------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2, c3], file_type=csv, has_header=true query II -SELECT DISTINCT c3, c2 FROM aggregate_test_100 group by c2, c3 limit 3 offset 10; +SELECT DISTINCT c3, c2 FROM aggregate_test_100 group by c3, c2 order by c3, c2 limit 3 offset 10; ---- -57 1 --54 4 -112 3 +-95 3 +-94 5 +-90 4 query TT EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; @@ -5289,7 +5764,7 @@ physical_plan 04)------CoalescePartitionsExec 05)--------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[] 06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], file_type=csv, has_header=true +07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2, c3], file_type=csv, has_header=true query II SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; @@ -5316,7 +5791,7 @@ physical_plan 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], file_type=csv, has_header=true +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c3], file_type=csv, has_header=true statement ok set datafusion.optimizer.enable_distinct_aggregation_soft_limit = true; @@ -5388,7 +5863,7 @@ NULL NULL 3 NULL 1 4 0 8 0 # regr_*() basic tests query RRIRRRRRR -select +select regr_slope(column2, column1), regr_intercept(column2, column1), regr_count(column2, column1), @@ -5403,7 +5878,7 @@ from (values (1,2), (2,4), (3,6)); 2 0 3 1 2 4 2 8 4 query RRIRRRRRR -select +select regr_slope(c12, c11), regr_intercept(c12, c11), regr_count(c12, c11), @@ -5421,7 +5896,7 @@ from aggregate_test_100; # regr_*() functions ignore NULLs query RRIRRRRRR -select +select regr_slope(column2, column1), regr_intercept(column2, column1), regr_count(column2, column1), @@ -5436,7 +5911,7 @@ from (values (1,NULL), (2,4), (3,6)); 2 0 2 1 2.5 5 0.5 2 1 query RRIRRRRRR -select +select regr_slope(column2, column1), regr_intercept(column2, column1), regr_count(column2, column1), @@ -5451,7 +5926,7 @@ from (values (1,NULL), (NULL,4), (3,6)); NULL NULL 1 NULL 3 6 0 0 0 query RRIRRRRRR -select +select regr_slope(column2, column1), regr_intercept(column2, column1), regr_count(column2, column1), @@ -5466,8 +5941,8 @@ from (values (1,NULL), (NULL,4), (NULL,NULL)); NULL NULL 0 NULL NULL NULL NULL NULL NULL query TRRIRRRRRR rowsort -select - column3, +select + column3, regr_slope(column2, column1), regr_intercept(column2, column1), regr_count(column2, column1), @@ -5491,7 +5966,7 @@ statement ok set datafusion.execution.batch_size = 1; query RRIRRRRRR -select +select regr_slope(c12, c11), regr_intercept(c12, c11), regr_count(c12, c11), @@ -5509,7 +5984,7 @@ statement ok set datafusion.execution.batch_size = 2; query RRIRRRRRR -select +select regr_slope(c12, c11), regr_intercept(c12, c11), regr_count(c12, c11), @@ -5527,7 +6002,7 @@ statement ok set datafusion.execution.batch_size = 3; query RRIRRRRRR -select +select regr_slope(c12, c11), regr_intercept(c12, c11), regr_count(c12, c11), @@ -5620,6 +6095,11 @@ SELECT STRING_AGG(column1, '|') FROM (values (''), (null), ('')); ---- | +query T +SELECT STRING_AGG(DISTINCT column1, '|') FROM (values (''), (null), ('')); +---- +(empty) + statement ok CREATE TABLE strings(g INTEGER, x VARCHAR, y VARCHAR) @@ -5641,6 +6121,22 @@ SELECT STRING_AGG(x,',') FROM strings WHERE g > 100 ---- NULL +query T +SELECT STRING_AGG(DISTINCT x,',') FROM strings WHERE g > 100 +---- +NULL + +query T +SELECT STRING_AGG(DISTINCT x,'|' ORDER BY x) FROM strings +---- +a|b|i|j|p|x|y|z + +query error This feature is not implemented: The second argument of the string_agg function must be a string literal +SELECT STRING_AGG(DISTINCT x,y) FROM strings + +query error Execution error: In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list +SELECT STRING_AGG(DISTINCT x,'|' ORDER BY y) FROM strings + statement ok drop table strings @@ -5655,6 +6151,17 @@ FROM my_data ---- text1, text1, text1 +query T +WITH my_data as ( +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column +) +SELECT string_agg(DISTINCT my_column,', ') as my_string_agg +FROM my_data +---- +text1 + query T WITH my_data as ( SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all @@ -5667,6 +6174,148 @@ GROUP BY dummy ---- text1, text1, text1 +query T +WITH my_data as ( +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column +) +SELECT string_agg(DISTINCT my_column,', ') as my_string_agg +FROM my_data +GROUP BY dummy +---- +text1 + + +# Test string_agg with ORDER BY clasuses (issue #17011) +statement ok +create table t (k varchar, v int); + +statement ok +insert into t values ('a', 2), ('b', 3), ('c', 1), ('d', null); + +query T +select string_agg(k, ',' order by k) from t; +---- +a,b,c,d + +query T +select string_agg(k, ',' order by k desc) from t; +---- +d,c,b,a + +query T +select string_agg(k, ',' order by v) from t; +---- +c,a,b,d + +query T +select string_agg(k, ',' order by v nulls first) from t; +---- +d,c,a,b + +query T +select string_agg(k, ',' order by v desc) from t; +---- +d,b,a,c + +query T +select string_agg(k, ',' order by v desc nulls last) from t; +---- +b,a,c,d + +query T +-- odd indexes should appear first, ties solved by v +select string_agg(k, ',' order by v % 2 == 0, v) from t; +---- +c,b,a,d + +query T +-- odd indexes should appear first, ties solved by v desc +select string_agg(k, ',' order by v % 2 == 0, v desc) from t; +---- +b,c,a,d + +query T +select string_agg(k, ',' order by + case + when k = 'a' then 3 + when k = 'b' then 0 + when k = 'c' then 2 + when k = 'd' then 1 + end) +from t; +---- +b,d,c,a + +query T +select string_agg(k, ',' order by + case + when k = 'a' then 3 + when k = 'b' then 0 + when k = 'c' then 2 + when k = 'd' then 1 + end desc) +from t; +---- +a,c,d,b + +# Test explain / reverse_expr for string_agg +query TT +explain select string_agg(k, ',' order by v) from t; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[string_agg(t.k, Utf8(",")) ORDER BY [t.v ASC NULLS LAST]]] +02)--TableScan: t projection=[k, v] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[string_agg(t.k,Utf8(",")) ORDER BY [t.v ASC NULLS LAST]] +02)--SortExec: expr=[v@1 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +query T +select string_agg(k, ',' order by v) from t; +---- +c,a,b,d + +query TT +explain select string_agg(k, ',' order by v desc) from t; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[string_agg(t.k, Utf8(",")) ORDER BY [t.v DESC NULLS FIRST]]] +02)--TableScan: t projection=[k, v] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[string_agg(t.k,Utf8(",")) ORDER BY [t.v DESC NULLS FIRST]] +02)--SortExec: expr=[v@1 DESC], preserve_partitioning=[false] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +query T +select string_agg(k, ',' order by v desc) from t; +---- +d,b,a,c + +# Call string_agg with both ASC and DESC orderings, and expect only one sort +# (because the aggregate can handle reversed inputs) +query TT +explain select string_agg(k, ',' order by v asc), string_agg(k, ',' order by v desc) from t; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[string_agg(t.k, Utf8(",")) ORDER BY [t.v ASC NULLS LAST], string_agg(t.k, Utf8(",")) ORDER BY [t.v DESC NULLS FIRST]]] +02)--TableScan: t projection=[k, v] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[string_agg(t.k,Utf8(",")) ORDER BY [t.v ASC NULLS LAST], string_agg(t.k,Utf8(",")) ORDER BY [t.v DESC NULLS FIRST]] +02)--SortExec: expr=[v@1 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +select string_agg(k, ',' order by v asc), string_agg(k, ',' order by v desc) from t; +---- +c,a,b,d d,b,a,c + + +statement ok +drop table t; + + # Tests for aggregating with NaN values statement ok CREATE TABLE float_table ( @@ -5681,7 +6330,7 @@ CREATE TABLE float_table ( # Test string_agg with largeutf8 statement ok -create table string_agg_large_utf8 (c string) as values +create table string_agg_large_utf8 (c string) as values (arrow_cast('a', 'LargeUtf8')), (arrow_cast('b', 'LargeUtf8')), (arrow_cast('c', 'LargeUtf8')) @@ -5736,7 +6385,7 @@ select count(*) from (select count(*) a, count(*) b from (select 1)); # UTF8 string matters for string to &[u8] conversion, add it to prevent regression statement ok -create table distinct_count_string_table as values +create table distinct_count_string_table as values (1, 'a', 'longstringtest_a', '台灣'), (2, 'b', 'longstringtest_b1', '日本'), (2, 'b', 'longstringtest_b2', '中國'), @@ -6321,15 +6970,15 @@ logical_plan 04)------SubqueryAlias: a 05)--------Union 06)----------Projection: Int64(1) AS id, Int64(2) AS foo -07)------------EmptyRelation +07)------------EmptyRelation: rows=1 08)----------Projection: Int64(1) AS id, Int64(4) AS foo -09)------------EmptyRelation +09)------------EmptyRelation: rows=1 10)----------Projection: Int64(1) AS id, Int64(5) AS foo -11)------------EmptyRelation +11)------------EmptyRelation: rows=1 12)----------Projection: Int64(1) AS id, Int64(3) AS foo -13)------------EmptyRelation +13)------------EmptyRelation: rows=1 14)----------Projection: Int64(1) AS id, Int64(2) AS foo -15)------------EmptyRelation +15)------------EmptyRelation: rows=1 physical_plan 01)ProjectionExec: expr=[last_value(a.foo) ORDER BY [a.foo ASC NULLS LAST]@1 as last_value(a.foo) ORDER BY [a.foo ASC NULLS LAST], sum(DISTINCT Int64(1))@2 as sum(DISTINCT Int64(1))] 02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[last_value(a.foo) ORDER BY [a.foo ASC NULLS LAST], sum(DISTINCT Int64(1))], ordering_mode=Sorted @@ -6363,7 +7012,7 @@ physical_plan 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(aggregate_test_100.c5)] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c5], file_type=csv, has_header=true +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c5], file_type=csv, has_header=true statement count 0 drop table aggregate_test_100; @@ -6714,7 +7363,7 @@ group1 0.0003 # median with all nulls statement ok create table group_median_all_nulls( - a STRING NOT NULL, + a STRING NOT NULL, b INT ) AS VALUES ( 'group0', NULL), @@ -6729,3 +7378,220 @@ SELECT a, median(b), arrow_typeof(median(b)) FROM group_median_all_nulls GROUP B ---- group0 NULL Int32 group1 NULL Int32 + +statement ok +create table t_decimal (c decimal(10, 4)) as values (100.00), (125.00), (175.00), (200.00), (200.00), (300.00), (null), (null); + +# Test avg_distinct for Decimal128 +query RT +select avg(distinct c), arrow_typeof(avg(distinct c)) from t_decimal; +---- +180 Decimal128(14, 8) + +statement ok +drop table t_decimal; + +# Test avg_distinct for Decimal256 +statement ok +create table t_decimal256 (c decimal(50, 2)) as values + (100.00), + (125.00), + (175.00), + (200.00), + (200.00), + (300.00), + (null), + (null); + +query RT +select avg(distinct c), arrow_typeof(avg(distinct c)) from t_decimal256; +---- +180 Decimal256(54, 6) + +statement ok +drop table t_decimal256; + +query I +with test AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +select count(*) from test WHERE 1 = 1; +---- +10 + +query I +with test AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +select count(c1) from test WHERE 1 = 1; +---- +10 + +query II rowsort +with test AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 5) t(i)) +select c2, count(*) from test WHERE 1 = 1 group by c2; +---- +2 1 +3 1 +4 1 +5 1 +6 1 + +# Min/Max struct +query ?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a', c2 AS 'b') AS c FROM t) +---- +{a: 1, b: 2} {a: 10, b: 11} + +# Min/Max struct with NULL +query ?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT CASE WHEN c1 % 2 == 0 THEN STRUCT(c1 AS 'a', c2 AS 'b') ELSE NULL END AS c FROM t) +---- +{a: 2, b: 3} {a: 10, b: 11} + +# Min/Max struct with two recordbatch +query ?? rowsort +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(1 as 'a', 2 as 'b') AS c UNION SELECT STRUCT(3 as 'a', 4 as 'b') AS c ) +---- +{a: 1, b: 2} {a: 3, b: 4} + +# Min/Max struct empty +query ?? rowsort +SELECT MIN(c), MAX(c) FROM (SELECT * FROM (SELECT STRUCT(1 as 'a', 2 as 'b') AS c) LIMIT 0) +---- +NULL NULL + +# Min/Max group struct +query I?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT key, MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a', c2 AS 'b') AS c, (c1 % 2) AS key FROM t) GROUP BY key +---- +0 {a: 2, b: 3} {a: 10, b: 11} +1 {a: 1, b: 2} {a: 9, b: 10} + +# Min/Max group struct with NULL +query I?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT key, MIN(c), MAX(c) FROM (SELECT CASE WHEN c1 % 2 == 0 THEN STRUCT(c1 AS 'a', c2 AS 'b') ELSE NULL END AS c, (c1 % 2) AS key FROM t) GROUP BY key +---- +0 {a: 2, b: 3} {a: 10, b: 11} +1 NULL NULL + +# Min/Max group struct with NULL +query I?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT key, MIN(c), MAX(c) FROM (SELECT CASE WHEN c1 % 3 == 0 THEN STRUCT(c1 AS 'a', c2 AS 'b') ELSE NULL END AS c, (c1 % 2) AS key FROM t) GROUP BY key +---- +0 {a: 6, b: 7} {a: 6, b: 7} +1 {a: 3, b: 4} {a: 9, b: 10} + +# Min/Max struct empty +query ?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a', c2 AS 'b') AS c, (c1 % 2) AS key FROM t LIMIT 0) GROUP BY key +---- + +# Min/Max aggregation on struct with a single field +query ?? +WITH t AS (SELECT i as c1 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a') AS c FROM t); +---- +{a: 1} {a: 10} + +# Min/Max aggregation on struct with identical first fields but different last fields +query ?? +SELECT MIN(column1),MAX(column1) FROM ( +VALUES + (STRUCT(1 AS 'a',2 AS 'b', 3 AS 'c')), + (STRUCT(1 AS 'a',2 AS 'b', 4 AS 'c')) +); +---- +{a: 1, b: 2, c: 3} {a: 1, b: 2, c: 4} + +query TI +SELECT column1, COUNT(DISTINCT column2) FROM ( +VALUES + ('x', arrow_cast('NAN','Float64')), + ('x', arrow_cast('NAN','Float64')) +) GROUP BY 1 ORDER BY 1; +---- +x 1 + +query ? +SELECT array_agg(a_varchar) WITHIN GROUP (ORDER BY a_varchar) +FROM (VALUES ('a'), ('d'), ('c'), ('a')) t(a_varchar); +---- +[a, a, c, d] + +query ? +SELECT array_agg(DISTINCT a_varchar) WITHIN GROUP (ORDER BY a_varchar) +FROM (VALUES ('a'), ('d'), ('c'), ('a')) t(a_varchar); +---- +[a, c, d] + +query error Error during planning: ORDER BY and WITHIN GROUP clauses cannot be used together in the same aggregate function +SELECT array_agg(a_varchar order by a_varchar) WITHIN GROUP (ORDER BY a_varchar) +FROM (VALUES ('a'), ('d'), ('c'), ('a')) t(a_varchar); + +# distinct average +statement ok +create table distinct_avg (a int, b double, c decimal(10, 4), d decimal(50, 2)) as values + (3, null, 100.2562, 90251.21), + (2, null, 100.2562, null), + (5, 100.5, null, 10000000.11), + (5, 1.0, 100.2563, -1.0), + (5, 44.112, -132.12, null), + (null, 1.0, 100.2562, 90251.21), + (5, 100.5, -100.2562, -10000000.11), + (1, 4.09, 4222.124, 0.0), + (5, 100.5, null, 10000000.11), + (5, 100.5, 1.1, 1.0), + (4, null, 4222.124, null), + (null, null, null, null) +; + +# Need two columns to ensure single_distinct_to_group_by rule doesn't kick in, so we know our actual avg(distinct) code is being tested +query RTRTRTRTRRRR +select + avg(distinct a), + arrow_typeof(avg(distinct a)), + avg(distinct b), + arrow_typeof(avg(distinct b)), + avg(distinct c), + arrow_typeof(avg(distinct c)), + avg(distinct d), + arrow_typeof(avg(distinct d)), + avg(a), + avg(b), + avg(c), + avg(d) +from distinct_avg; +---- +3 Float64 37.4255 Float64 698.56005 Decimal128(14, 8) 15041.868333 Decimal256(54, 6) 4 56.52525 957.11074444 1272562.81625 + +query RRRR rowsort +select + avg(distinct a), + avg(distinct b), + avg(distinct c), + avg(distinct d) +from distinct_avg +group by b; +---- +1 4.09 4222.124 0 +3 NULL 2161.1901 90251.21 +5 1 100.25625 45125.105 +5 100.5 -49.5781 0.333333 +5 44.112 -132.12 NULL + +query RRRR +select + avg(distinct a), + avg(distinct b), + avg(distinct c), + avg(distinct d) +from distinct_avg +where a is null and b is null and c is null and d is null; +---- +NULL NULL NULL NULL + +statement ok +drop table distinct_avg; diff --git a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt index 8755918cd16c2..5dcb72b7055b8 100644 --- a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt +++ b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt @@ -69,9 +69,6 @@ set datafusion.execution.target_partitions = 2; statement ok set datafusion.execution.batch_size = 1; -statement ok -set datafusion.sql_parser.dialect = 'Postgres'; - # Grouping by unique fields allows to check all accumulators query ITIIII SELECT c5, c1, @@ -420,10 +417,6 @@ c true false NULL d NULL false NULL e true false NULL -# Enabling PG dialect for filtered aggregates tests -statement ok -set datafusion.sql_parser.dialect = 'Postgres'; - # Test count with filter query III SELECT diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index cb56686b64373..d8c29a323e945 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -310,7 +310,7 @@ AS VALUES statement ok CREATE TABLE fixed_size_array_has_table_2D AS VALUES - (arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(1,3), 'FixedSizeList(2, Int64)'), arrow_cast(make_array([1,2,3], [4,5], [6,7]), 'FixedSizeList(3, List(Int64))'), arrow_cast(make_array([4,5], [6,7], [1,2]), 'FixedSizeList(3, List(Int64))')), + (arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(1,3), 'FixedSizeList(2, Int64)'), arrow_cast(make_array([1,2,3], [4,5], [6,7]), 'FixedSizeList(3, List(Int64))'), arrow_cast(make_array([4,5], [6,7], [1,2,3]), 'FixedSizeList(3, List(Int64))')), (arrow_cast(make_array([3,4], [5]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(5, 3), 'FixedSizeList(2, Int64)'), arrow_cast(make_array([1,2,3,4], [5,6,7], [8,9,10]), 'FixedSizeList(3, List(Int64))'), arrow_cast(make_array([1,2,3], [5,6,7], [8,9,10]), 'FixedSizeList(3, List(Int64))')) ; @@ -362,6 +362,14 @@ AS VALUES (make_array(NULL, NULL, NULL), 2) ; +statement ok +CREATE TABLE array_has_table_empty +AS VALUES + (make_array(1, 3, 5), 1), + (make_array(), 1), + (NULL, 1) +; + statement ok CREATE TABLE array_distinct_table_1D AS VALUES @@ -687,7 +695,7 @@ SELECT array_length([now()]) query ? select [abs(-1.2), sin(-1), log(2), ceil(3.141)] ---- -[1.2, -0.8414709848078965, 0.3010299801826477, 4.0] +[1.2, -0.8414709848078965, 0.30102999566398114, 4.0] ## array literal with nested types query ??? @@ -1204,7 +1212,7 @@ select array_element([1, 2], NULL); ---- NULL -query I +query ? select array_element(NULL, 2); ---- NULL @@ -1435,6 +1443,12 @@ NULL 23 NULL 43 5 NULL +# array_element of empty array +query T +select coalesce(array_element([], 1), array_element(NULL, 1), 'ok'); +---- +ok + ## array_max # array_max scalar function #1 (with positive index) @@ -1448,7 +1462,7 @@ select array_max(make_array(5, 3, 4, NULL, 6, NULL)); ---- 6 -query I +query ? select array_max(make_array(NULL, NULL)); ---- NULL @@ -1507,12 +1521,17 @@ select input, array_max(input) from (select make_array(d - 1, d, d + 1) input fr [29, 30, 31] 31 [NULL, NULL, NULL] NULL +query II +select array_max(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array_max(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +3 1 + query II select array_max(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), array_max(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)')); ---- 3 1 -query I +query ? select array_max(make_array()); ---- NULL @@ -1521,6 +1540,96 @@ NULL query error DataFusion error: Error during planning: 'array_max' does not support zero arguments select array_max(); +## array_min + +query I +select array_min(make_array(5, 3, 6, 4)); +---- +3 + +query I +select array_min(make_array(5, 3, 4, NULL, 6, NULL)); +---- +3 + +query ? +select array_min(make_array(NULL, NULL)); +---- +NULL + +query T +select array_min(make_array('h', 'e', 'o', 'l', 'l')); +---- +e + +query T +select array_min(make_array('h', 'e', 'l', NULL, 'l', 'o', NULL)); +---- +e + +query B +select array_min(make_array(false, true, false, true)); +---- +false + +query B +select array_min(make_array(false, true, NULL, false, true)); +---- +false + +query D +select array_min(make_array(DATE '1992-09-01', DATE '1993-03-01', DATE '1999-05-01', DATE '1985-11-01')); +---- +1985-11-01 + +query D +select array_min(make_array(DATE '1995-09-01', DATE '1999-05-01', DATE '1993-03-01', NULL)); +---- +1993-03-01 + +query P +select array_min(make_array(TIMESTAMP '1992-09-01', TIMESTAMP '1995-06-01', TIMESTAMP '1984-10-01')); +---- +1984-10-01T00:00:00 + +query P +select array_min(make_array(NULL, TIMESTAMP '1996-10-01', TIMESTAMP '1995-06-01')); +---- +1995-06-01T00:00:00 + +query R +select array_min(make_array(5.1, -3.2, 6.3, 4.9)); +---- +-3.2 + +query ?I +select input, array_min(input) from (select make_array(d - 1, d, d + 1) input from (values (0), (10), (20), (30), (NULL)) t(d)) +---- +[-1, 0, 1] -1 +[9, 10, 11] 9 +[19, 20, 21] 19 +[29, 30, 31] 29 +[NULL, NULL, NULL] NULL + +query II +select array_min(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array_min(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +1 1 + +query II +select array_min(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), array_min(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)')); +---- +1 1 + +query ? +select array_min(make_array()); +---- +NULL + +# Testing with empty arguments should result in an error +query error DataFusion error: Error during planning: 'array_min' does not support zero arguments +select array_min(); + ## array_pop_back (aliases: `list_pop_back`) @@ -1839,6 +1948,12 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, ---- [2, 3, 4] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), 2, 4), + array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), 1, 2); +---- +[2, 3, 4] [h, e] + # array_slice scalar function #2 (with positive indexes; full array) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 5); @@ -1850,6 +1965,12 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, ---- [1, 2, 3, 4, 5] [h, e, l, l, o] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), 0, 6), + array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), 0, 5); +---- +[1, 2, 3, 4, 5] [h, e, l, l, o] + # array_slice scalar function #3 (with positive indexes; first index = second index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 4, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 3); @@ -2177,7 +2298,7 @@ select array_any_value(1), array_any_value('a'), array_any_value(NULL); # array_any_value scalar function #1 (with null and non-null elements) -query ITII +query IT?I select array_any_value(make_array(NULL, 1, 2, 3, 4, 5)), array_any_value(make_array(NULL, 'h', 'e', 'l', 'l', 'o')), array_any_value(make_array(NULL, NULL)), array_any_value(make_array(NULL, NULL, 1, 2, 3)); ---- 1 h NULL 1 @@ -2324,6 +2445,20 @@ select array_sort(make_array(1, 3, null, 5, NULL, -5)), array_sort(make_array(1, ---- [NULL, NULL, -5, 1, 3, 5] [NULL, 1, 2, 3] [NULL, 3, 2, 1] +query ??? +select array_sort(arrow_cast(make_array(1, 3, null, 5, NULL, -5), 'LargeList(Int64)')), + array_sort(arrow_cast(make_array(1, 3, null, 2), 'LargeList(Int64)'), 'ASC'), + array_sort(arrow_cast(make_array(1, 3, null, 2), 'LargeList(Int64)'), 'desc', 'NULLS FIRST'); +---- +[NULL, NULL, -5, 1, 3, 5] [NULL, 1, 2, 3] [NULL, 3, 2, 1] + +query ??? +select array_sort(arrow_cast(make_array(1, 3, null, 5, NULL, -5), 'FixedSizeList(6, Int64)')), + array_sort(arrow_cast(make_array(1, 3, null, 2), 'FixedSizeList(4, Int64)'), 'ASC'), + array_sort(arrow_cast(make_array(1, 3, null, 2), 'FixedSizeList(4, Int64)'), 'desc', 'NULLS FIRST'); +---- +[NULL, NULL, -5, 1, 3, 5] [NULL, 1, 2, 3] [NULL, 3, 2, 1] + query ? select array_sort(column1, 'DESC', 'NULLS LAST') from arrays_values; ---- @@ -2348,6 +2483,11 @@ NULL [NULL, 51, 52, 54, 55, 56, 57, 58, 59, 60] [61, 62, 63, 64, 65, 66, 67, 68, 69, 70] +# test with empty table +query ? +select array_sort(column1, 'DESC', 'NULLS FIRST') from arrays_values where false; +---- + # test with empty array query ? select array_sort([]); @@ -2396,6 +2536,11 @@ NULL NULL NULL NULL NULL NULL +query ? +select array_sort([struct('foo', 3), struct('foo', 1), struct('bar', 1)]) +---- +[{c0: bar, c1: 1}, {c0: foo, c1: 1}, {c0: foo, c1: 3}] + ## test with argument of incorrect types query error DataFusion error: Execution error: the second parameter of array_sort expects DESC or ASC select array_sort([1, 3, null, 5, NULL, -5], 1), array_sort([1, 3, null, 5, NULL, -5], 'DESC', 1), array_sort([1, 3, null, 5, NULL, -5], 1, 1); @@ -2430,11 +2575,15 @@ select array_append(null, 1); ---- [1] -query error +query ? select array_append(null, [2, 3]); +---- +[[2, 3]] -query error +query ? select array_append(null, [[4]]); +---- +[[[4]]] query ???? select @@ -2675,7 +2824,6 @@ select array_append(column1, arrow_cast(make_array(1, 11, 111), 'FixedSizeList(3 # DuckDB: [4] # ClickHouse: Null -# Since they dont have the same result, we just follow Postgres, return error query ? select array_prepend(4, NULL); ---- @@ -2711,8 +2859,10 @@ select array_prepend(null, [[1,2,3]]); # DuckDB: [[]] # ClickHouse: [[]] # TODO: We may also return [[]] -query error +query ? select array_prepend([], []); +---- +[[]] query ? select array_prepend(null, null); @@ -3048,6 +3198,42 @@ select array_concat([]); ---- [] +# test with NULL array +query ? +select array_concat(NULL::integer[]); +---- +NULL + +# test with multiple NULL arrays +query ? +select array_concat(NULL::integer[], NULL::integer[]); +---- +NULL + +# test with NULL LargeList +query ? +select array_concat(arrow_cast(NULL::string[], 'LargeList(Utf8)')); +---- +NULL + +# test with NULL FixedSizeList +query ? +select array_concat(arrow_cast(NULL::string[], 'FixedSizeList(2, Utf8)')); +---- +NULL + +# test with mix of NULL and empty arrays +query ? +select array_concat(NULL::integer[], []); +---- +[] + +# test with mix of NULL and non-empty arrays +query ? +select array_concat(NULL::integer[], [1, 2, 3]); +---- +[1, 2, 3] + # Concatenating strings arrays query ? select array_concat( @@ -3057,6 +3243,22 @@ select array_concat( ---- [1, 2, 3] +query ? +select array_concat( + arrow_cast(['1', '2'], 'LargeList(Utf8)'), + arrow_cast(['3'], 'LargeList(Utf8)') +); +---- +[1, 2, 3] + +query ? +select array_concat( + arrow_cast(['1', '2'], 'FixedSizeList(2, Utf8)'), + arrow_cast(['3'], 'FixedSizeList(1, Utf8)') +); +---- +[1, 2, 3] + # Concatenating string arrays query ? select array_concat( @@ -3075,22 +3277,25 @@ select array_concat( ---- [1, 2, 3] -# Concatenating Mixed types (doesn't work) -query error DataFusion error: Error during planning: It is not possible to concatenate arrays of different types\. Expected: List\(Field \{ name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), got: List\(Field \{ name: "item", data_type: LargeUtf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +# Concatenating Mixed types +query ? select array_concat( [arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')], [arrow_cast('3', 'LargeUtf8')] ); +---- +[1, 2, 3] -# Concatenating Mixed types (doesn't work) -query error DataFusion error: Error during planning: It is not possible to concatenate arrays of different types\. Expected: List\(Field \{ name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), got: List\(Field \{ name: "item", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) -select array_concat( - [arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')], - [arrow_cast('3', 'Utf8View')] -); +# Concatenating Mixed types +query ?T +select + array_concat([arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')], [arrow_cast('3', 'Utf8View')]), + arrow_typeof(array_concat([arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')], [arrow_cast('3', 'Utf8View')])); +---- +[1, 2, 3] List(Field { name: "item", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) # array_concat error -query error DataFusion error: Error during planning: The array_concat function can only accept list as the args\. +query error DataFusion error: Error during planning: Execution error: Function 'array_concat' user-defined coercion failed with "Error during planning: array_concat does not support type Int64" select array_concat(1, 2); # array_concat scalar function #1 @@ -3342,10 +3547,16 @@ select array_concat(make_array(column3), column1, column2) from arrays_values_v2 ## array_position (aliases: `list_position`, `array_indexof`, `list_indexof`) ## array_position with NULL (follow PostgreSQL) -#query I -#select array_position([1, 2, 3, 4, 5], null), array_position(NULL, 1); -#---- -#NULL NULL +query II +select array_position([1, 2, 3, 4, 5], arrow_cast(NULL, 'Int64')), array_position(arrow_cast(NULL, 'List(Int64)'), 1); +---- +NULL NULL + +# array_position with no match (incl. empty array) returns NULL +query II +select array_position([], 1), array_position([2], 1); +---- +NULL NULL # array_position scalar function #1 query III @@ -3358,6 +3569,11 @@ select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ---- 3 5 1 +query III +select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'FixedSizeList(5, Utf8)'), 'l'), array_position(arrow_cast([1, 2, 3, 4, 5], 'FixedSizeList(5, Int64)'), 5), array_position(arrow_cast([1, 1, 1], 'FixedSizeList(3, Int64)'), 1); +---- +3 5 1 + # array_position scalar function #2 (with optional argument) query III select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2); @@ -3369,6 +3585,11 @@ select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ---- 4 5 2 +query III +select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'FixedSizeList(5, Utf8)'), 'l', 4), array_position(arrow_cast([1, 2, 3, 4, 5], 'FixedSizeList(5, Int64)'), 5, 4), array_position(arrow_cast([1, 1, 1], 'FixedSizeList(3, Int64)'), 1, 2); +---- +4 5 2 + # array_position scalar function #3 (element is list) query II select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]); @@ -3401,15 +3622,11 @@ SELECT array_position(arrow_cast([1, 1, 100, 1, 1], 'LargeList(Int32)'), 100) ---- 3 -query I +query error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'array_position' function: coercion from SELECT array_position([1, 2, 3], 'foo') ----- -NULL -query I +query error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'array_position' function: coercion from SELECT array_position([1, 2, 3], 'foo', 2) ----- -NULL # list_position scalar function #5 (function alias `array_position`) query III @@ -3675,6 +3892,14 @@ select ---- [1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] +query ??? +select + array_replace(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)'), 2, 3), + array_replace(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'FixedSizeList(7, Int64)'), 4, 0), + array_replace(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 4, 5, 4, 6, 7] [1, 2, 3] + # array_replace scalar function #2 (element is list) query ?? select @@ -3706,6 +3931,21 @@ select ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]] +query ?? +select + array_replace( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'FixedSizeList(5, FixedSizeList(3, Int64))'), + [4, 5, 6], + [1, 1, 1] + ), + array_replace( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'FixedSizeList(5, FixedSizeList(3, Int64))'), + [2, 3, 4], + [3, 1, 4] + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]] + # list_replace scalar function #3 (function alias `list_replace`) query ??? select list_replace( @@ -3847,6 +4087,14 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] +query ??? +select + array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)'), 2, 3, 2), + array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'FixedSizeList(7, Int64)'), 4, 0, 2), + array_replace_n(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 4, 0, 3); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] + # array_replace_n scalar function #2 (element is list) query ?? select @@ -3882,6 +4130,23 @@ select ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] +query ?? +select + array_replace_n( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'FixedSizeList(5, FixedSizeList(3, Int64))'), + [4, 5, 6], + [1, 1, 1], + 2 + ), + array_replace_n( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'FixedSizeList(5, FixedSizeList(3, Int64))'), + [2, 3, 4], + [3, 1, 4], + 2 + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] + # list_replace_n scalar function #3 (function alias `array_replace_n`) query ??? select @@ -4038,6 +4303,14 @@ select ---- [1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] +query ??? +select + array_replace_all(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)'), 2, 3), + array_replace_all(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'FixedSizeList(7, Int64)'), 4, 0), + array_replace_all(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 4, 0); +---- +[1, 3, 3, 4] [1, 0, 0, 5, 0, 6, 7] [1, 2, 3] + # array_replace_all scalar function #2 (element is list) query ?? select @@ -4069,6 +4342,21 @@ select ---- [[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] +query ?? +select + array_replace_all( + arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'FixedSizeList(5, FixedSizeList(3, Int64))'), + [4, 5, 6], + [1, 1, 1] + ), + array_replace_all( + arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'FixedSizeList(5, FixedSizeList(3, Int64))'), + [2, 3, 4], + [3, 1, 4] + ); +---- +[[1, 2, 3], [1, 1, 1], [5, 5, 5], [1, 1, 1], [7, 8, 9]] [[1, 3, 2], [3, 1, 4], [3, 1, 4], [5, 3, 1], [1, 3, 2]] + # list_replace_all scalar function #3 (function alias `array_replace_all`) query ??? select @@ -4334,6 +4622,16 @@ select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, ---- [1, 2, 3, 4, 5, 6] +query ? +select array_union(arrow_cast([1, 2, 3, 4], 'FixedSizeList(4, Int64)'), arrow_cast([5, 6, 3, 4], 'FixedSizeList(4, Int64)')); +---- +[1, 2, 3, 4, 5, 6] + +query ? +select array_union(arrow_cast([1, 2, 3, 4], 'FixedSizeList(4, Int64)'), arrow_cast([5, 6], 'FixedSizeList(2, Int64)')); +---- +[1, 2, 3, 4, 5, 6] + # array_union scalar function #2 query ? select array_union([1, 2, 3, 4], [5, 6, 7, 8]); @@ -4371,7 +4669,8 @@ select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, statement ok CREATE TABLE arrays_with_repeating_elements_for_union AS VALUES - ([1], [2]), + ([0, 1, 1], []), + ([1, 1], [2]), ([2, 3], [3]), ([3], [3, 4]) ; @@ -4379,6 +4678,7 @@ AS VALUES query ? select array_union(column1, column2) from arrays_with_repeating_elements_for_union; ---- +[0, 1] [1, 2] [2, 3] [3, 4] @@ -4386,6 +4686,7 @@ select array_union(column1, column2) from arrays_with_repeating_elements_for_uni query ? select array_union(arrow_cast(column1, 'LargeList(Int64)'), arrow_cast(column2, 'LargeList(Int64)')) from arrays_with_repeating_elements_for_union; ---- +[0, 1] [1, 2] [2, 3] [3, 4] @@ -4408,12 +4709,10 @@ select array_union(arrow_cast([], 'LargeList(Int64)'), arrow_cast([], 'LargeList query ? select array_union([[null]], []); ---- -[[NULL]] +[[]] -query ? +query error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'array_union' function: select array_union(arrow_cast([[null]], 'LargeList(List(Int64))'), arrow_cast([], 'LargeList(Int64)')); ----- -[[NULL]] # array_union scalar function #8 query ? @@ -4532,6 +4831,11 @@ select array_to_string(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), ---- h,e,l,l,o 1-2-3-4-5 1|2|3 +query TTT +select array_to_string(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'FixedSizeList(5, Utf8)'), ','), array_to_string(arrow_cast([1, 2, 3, 4, 5], 'FixedSizeList(5, Int64)'), '-'), array_to_string(arrow_cast([1.0, 2.0, 3.0], 'FixedSizeList(3, Float64)'), '|'); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + # array_to_string scalar function with nulls #2 query TTT select array_to_string(make_array('h', NULL, NULL, NULL, 'o'), ',', '-'), array_to_string(make_array(NULL, 2, NULL, 4, 5), '-', 'nil'), array_to_string(make_array(1.0, NULL, 3.0), '|', '0'); @@ -4543,6 +4847,16 @@ select array_to_string(arrow_cast(make_array('h', NULL, NULL, NULL, 'o'), 'Large ---- h,-,-,-,o nil-2-nil-4-5 1|0|3 +query TTT +select array_to_string(arrow_cast(make_array('h', NULL, NULL, NULL, 'o'), 'FixedSizeList(5, Utf8)'), ',', '-'), array_to_string(arrow_cast(make_array(NULL, 2, NULL, 4, 5), 'FixedSizeList(5, Int64)'), '-', 'nil'), array_to_string(arrow_cast(make_array(1.0, NULL, 3.0), 'FixedSizeList(3, Float64)'), '|', '0'); +---- +h,-,-,-,o nil-2-nil-4-5 1|0|3 + +query T +select array_to_string(arrow_cast([arrow_cast([NULL, 'a'], 'FixedSizeList(2, Utf8)'), NULL], 'FixedSizeList(2, FixedSizeList(2, Utf8))'), ',', '-'); +---- +-,a,-,- + # array_to_string with columns #1 # For reference @@ -4808,6 +5122,12 @@ select array_remove(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5 ---- [[1, 2, 3], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [2, 3, 4], [5, 3, 1], [1, 3, 2]] +query ?? +select array_remove(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'FixedSizeList(5, FixedSizeList(3, Int64))'), [4, 5, 6]), + array_remove(arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'FixedSizeList(5, FixedSizeList(3, Int64))'), [2, 3, 4]); +---- +[[1, 2, 3], [5, 5, 5], [4, 5, 6], [7, 8, 9]] [[1, 3, 2], [2, 3, 4], [5, 3, 1], [1, 3, 2]] + # list_remove scalar function #3 (function alias `array_remove`) query ??? select list_remove(make_array(1, 2, 2, 1, 1), 2), list_remove(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 1.0), list_remove(make_array('h', 'e', 'l', 'l', 'o'), 'l'); @@ -4931,12 +5251,38 @@ select array_remove_n(make_array(1, 2, 2, 1, 1), 2, 2), array_remove_n(make_arra ---- [1, 1, 1] [2.0, 2.0, 1.0] [h, e, o] +query ??? +select array_remove_n(arrow_cast(make_array(1, 2, 2, 1, 1), 'LargeList(Int32)'), 2, 2), + array_remove_n(arrow_cast(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 'LargeList(Float32)'), 1.0, 2), + array_remove_n(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 'l', 3); +---- +[1, 1, 1] [2.0, 2.0, 1.0] [h, e, o] + +query ??? +select array_remove_n(arrow_cast(make_array(1, 2, 2, 1, 1), 'FixedSizeList(5, Int32)'), 2, 2), + array_remove_n(arrow_cast(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 'FixedSizeList(5, Float32)'), 1.0, 2), + array_remove_n(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), 'l', 3); +---- +[1, 1, 1] [2.0, 2.0, 1.0] [h, e, o] + # array_remove_n scalar function #2 (element is list) query ?? select array_remove_n(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6], 2), array_remove_n(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], 2); ---- [[1, 2, 3], [5, 5, 5], [7, 8, 9]] [[1, 3, 2], [5, 3, 1], [1, 3, 2]] +query ?? +select array_remove_n(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), [4, 5, 6], 2), + array_remove_n(arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), [2, 3, 4], 2); +---- +[[1, 2, 3], [5, 5, 5], [7, 8, 9]] [[1, 3, 2], [5, 3, 1], [1, 3, 2]] + +query ?? +select array_remove_n(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'FixedSizeList(5, FixedSizeList(3, Int64))'), [4, 5, 6], 2), + array_remove_n(arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'FixedSizeList(5, FixedSizeList(3, Int64))'), [2, 3, 4], 2); +---- +[[1, 2, 3], [5, 5, 5], [7, 8, 9]] [[1, 3, 2], [5, 3, 1], [1, 3, 2]] + # list_remove_n scalar function #3 (function alias `array_remove_n`) query ??? select list_remove_n(make_array(1, 2, 2, 1, 1), 2, 2), list_remove_n(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 1.0, 2), list_remove_n(make_array('h', 'e', 'l', 'l', 'o'), 'l', 3); @@ -4999,6 +5345,13 @@ select array_remove_all(make_array(1, 2, 2, 1, 1), 2), array_remove_all(make_arr ---- [1, 1, 1] [2.0, 2.0] [h, e, o] +query ??? +select array_remove_all(arrow_cast(make_array(1, 2, 2, 1, 1), 'LargeList(Int64)'), 2), + array_remove_all(arrow_cast(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 'LargeList(Float64)'), 1.0), + array_remove_all(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 'l'); +---- +[1, 1, 1] [2.0, 2.0] [h, e, o] + query ??? select array_remove_all(arrow_cast(make_array(1, 2, 2, 1, 1), 'FixedSizeList(5, Int64)'), 2), array_remove_all(arrow_cast(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 'FixedSizeList(5, Float64)'), 1.0), array_remove_all(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'FixedSizeList(5, Utf8)'), 'l'); ---- @@ -5016,6 +5369,12 @@ select array_remove_all(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [ ---- [[1, 2, 3], [5, 5, 5], [7, 8, 9]] [[1, 3, 2], [5, 3, 1], [1, 3, 2]] +query ?? +select array_remove_all(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'FixedSizeList(5, FixedSizeList(3, Int64))'), [4, 5, 6]), + array_remove_all(arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'FixedSizeList(5, FixedSizeList(3, Int64))'), [2, 3, 4]); +---- +[[1, 2, 3], [5, 5, 5], [7, 8, 9]] [[1, 3, 2], [5, 3, 1], [1, 3, 2]] + # list_remove_all scalar function #3 (function alias `array_remove_all`) query ??? select list_remove_all(make_array(1, 2, 2, 1, 1), 2), list_remove_all(make_array(1.0, 2.0, 2.0, 1.0, 1.0), 1.0), list_remove_all(make_array('h', 'e', 'l', 'l', 'o'), 'l'); @@ -5218,6 +5577,19 @@ NULL 10 NULL 10 NULL 10 +# array_length for fixed sized list + +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)')), array_length(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'FixedSizeList(3, List(Int64))')); +---- +5 3 3 + +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), 1), array_length(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 1), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'FixedSizeList(3, List(Int64))'), 1); +---- +5 3 3 + + query RRR select array_distance([2], [3]), list_distance([1], [2]), list_distance([1], [-2]); ---- @@ -5665,6 +6037,30 @@ false false false +# array_has([1, 3, 5], 1) -> true (array contains element) +# array_has([], 1) -> false (empty array, not null) +# array_has(null, 1) -> null (null array) +query B +select array_has(column1, column2) +from array_has_table_empty; +---- +true +false +NULL + +# Test for issue: array_has should return false for empty arrays, not null +# This test demonstrates the correct behavior with COALESCE to show the distinction +# array_has([1, 3, 5], 1) -> 'true' +# array_has([], 1) -> 'false' (empty array should return false) +# array_has(null, 1) -> 'null' (null array should return null) +query ?T +SELECT column1, COALESCE(CAST(array_has(column1, column2) AS VARCHAR), 'null') +from array_has_table_empty; +---- +[1, 3, 5] true +[] false +NULL null + query B select array_has(column1, column2) from fixed_size_array_has_table_1D; @@ -5672,14 +6068,13 @@ from fixed_size_array_has_table_1D; true false -#TODO: array_has_all and array_has_any cannot handle FixedSizeList -#query BB -#select array_has_all(column3, column4), -# array_has_any(column5, column6) -#from fixed_size_array_has_table_1D; -#---- -#true true -#false false +query BB +select array_has_all(column3, column4), + array_has_any(column5, column6) +from fixed_size_array_has_table_1D; +---- +true true +false false query BBB select array_has(column1, column2), @@ -5706,14 +6101,13 @@ from fixed_size_array_has_table_1D_Float; true false -#TODO: array_has_all and array_has_any cannot handle FixedSizeList -#query BB -#select array_has_all(column3, column4), -# array_has_any(column5, column6) -#from fixed_size_array_has_table_1D_Float; -#---- -#true true -#false true +query BB +select array_has_all(column3, column4), + array_has_any(column5, column6) +from fixed_size_array_has_table_1D_Float; +---- +true true +false true query BBB select array_has(column1, column2), @@ -5740,14 +6134,27 @@ from fixed_size_array_has_table_1D_Boolean; false true -#TODO: array_has_all and array_has_any cannot handle FixedSizeList -#query BB -#select array_has_all(column3, column4), -# array_has_any(column5, column6) -#from fixed_size_array_has_table_1D_Boolean; -#---- -#true true -#true true +query BB +select array_has_all(column3, column4), + array_has_any(column5, column6) +from fixed_size_array_has_table_1D_Boolean; +---- +true true +true true + +query BBBBBBBB +select array_has_all(column3, arrow_cast(column4,'LargeList(Boolean)')), + array_has_any(column5, arrow_cast(column6,'LargeList(Boolean)')), + array_has_all(column3, arrow_cast(column4,'List(Boolean)')), + array_has_any(column5, arrow_cast(column6,'List(Boolean)')), + array_has_all(arrow_cast(column3, 'LargeList(Boolean)'), column4), + array_has_any(arrow_cast(column5, 'LargeList(Boolean)'), column6), + array_has_all(arrow_cast(column3, 'List(Boolean)'), column4), + array_has_any(arrow_cast(column5, 'List(Boolean)'), column6) +from fixed_size_array_has_table_1D_Boolean; +---- +true true true true true true true true +true true true true true true true true query BBB select array_has(column1, column2), @@ -5797,13 +6204,12 @@ from fixed_size_array_has_table_2D; false false -#TODO: array_has_all and array_has_any cannot handle FixedSizeList -#query B -#select array_has_all(arrow_cast(column3, 'LargeList(List(Int64))'), arrow_cast(column4, 'LargeList(List(Int64))')) -#from fixed_size_array_has_table_2D; -#---- -#true -#false +query B +select array_has_all(arrow_cast(column3, 'LargeList(List(Int64))'), arrow_cast(column4, 'LargeList(List(Int64))')) +from fixed_size_array_has_table_2D; +---- +true +false query B select array_has_all(column1, column2) @@ -5819,13 +6225,12 @@ from array_has_table_2D_float; true false -#TODO: array_has_all and array_has_any cannot handle FixedSizeList -#query B -#select array_has_all(column1, column2) -#from fixed_size_array_has_table_2D_float; -#---- -#false -#false +query B +select array_has_all(column1, column2) +from fixed_size_array_has_table_2D_float; +---- +false +false query B select array_has(column1, column2) from array_has_table_3D; @@ -5890,6 +6295,13 @@ NULL NULL false false false false NULL false false false false NULL +# Row 1: [[NULL,2],[3,NULL]], [1.1,2.2,3.3], ['L','o','r','e','m'] +# Row 2: [[3,4],[5,6]], [NULL,5.5,6.6], ['i','p',NULL,'u','m'] +# Row 3: [[5,6],[7,8]], [7.7,8.8,9.9], ['d',NULL,'l','o','r'] +# Row 4: [[7,NULL],[9,10]], [10.1,NULL,12.2], ['s','i','t','a','b'] +# Row 5: NULL, [13.3,14.4,15.5], ['a','m','e','t','x'] +# Row 6: [[11,12],[13,14]], NULL, [',','a','b','c','d'] +# Row 7: [[15,16],[NULL,18]], [16.6,17.7,18.8], NULL query BBBB select array_has(column1, make_array(5, 6)), array_has(column1, make_array(7, NULL)), @@ -5901,9 +6313,9 @@ false false false true true false true false true false false true false true false false -false false false false -false false false false -false false false false +NULL NULL false false +false false NULL false +false false false NULL query BBBB select array_has_all(make_array(1,2,3), []), @@ -5950,24 +6362,23 @@ select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_ca ---- true false true false false false true true false false true false true -#TODO: array_has_all and array_has_any cannot handle FixedSizeList -#query BBBBBBBBBBBBB -#select array_has_all(arrow_cast(make_array(1,2,3), 'FixedSizeList(3, Int64)'), arrow_cast(make_array(1, 3), 'FixedSizeList(2, Int64)')), -# array_has_all(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(1, 4), 'FixedSizeList(2, Int64)')), -# array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,2]), 'FixedSizeList(1, List(Int64))')), -# array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,3]), 'FixedSizeList(1, List(Int64))')), -# array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'FixedSizeList(3, List(Int64))')), -# array_has_all(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1]]), 'FixedSizeList(1, List(List(Int64)))')), -# array_has_all(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))')), -# array_has_any(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(1,10,100), 'FixedSizeList(3, Int64)')), -# array_has_any(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(10, 100),'FixedSizeList(2, Int64)')), -# array_has_any(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'FixedSizeList(2, List(Int64))')), -# array_has_any(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'FixedSizeList(2, List(Int64))')), -# array_has_any(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'FixedSizeList(1, List(List(Int64)))')), -# array_has_any(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'FixedSizeList(2, List(List(Int64)))')) -#; -#---- -#true false true false false false true true false false true false true +query BBBBBBBBBBBBB +select array_has_all(arrow_cast(make_array(1,2,3), 'FixedSizeList(3, Int64)'), arrow_cast(make_array(1, 3), 'FixedSizeList(2, Int64)')), + array_has_all(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(1, 4), 'FixedSizeList(2, Int64)')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,2]), 'FixedSizeList(1, List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,3]), 'FixedSizeList(1, List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'FixedSizeList(3, List(Int64))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1]]), 'FixedSizeList(1, List(List(Int64)))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))')), + array_has_any(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(1,10,100), 'FixedSizeList(3, Int64)')), + array_has_any(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(10, 100),'FixedSizeList(2, Int64)')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'FixedSizeList(2, List(Int64))')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'FixedSizeList(2, List(Int64))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'FixedSizeList(1, List(List(Int64)))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'FixedSizeList(2, List(List(Int64)))')) +; +---- +true false true false false false true true false false true false true # rewrite various array_has operations to InList where the haystack is a literal list # NB that `col in (a, b, c)` is simplified to OR if there are <= 3 elements, so we make 4-element haystack lists @@ -5988,8 +6399,8 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) -07)------------TableScan: tmp_table projection=[value] +06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] @@ -5997,7 +6408,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6017,8 +6428,8 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) -07)------------TableScan: tmp_table projection=[value] +06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] @@ -6026,7 +6437,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6046,8 +6457,8 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) -07)------------TableScan: tmp_table projection=[value] +06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] @@ -6055,18 +6466,16 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] -# FIXME: due to rewrite below not working, this is _extremely_ slow to evaluate -# query I -# with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) -# select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle); -# ---- -# 1 +query I +with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle); +---- +1 -# FIXME: array_has with large list haystack not currently rewritten to InList query TT explain with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle); @@ -6077,8 +6486,8 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: array_has(LargeList([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32))) -07)------------TableScan: tmp_table projection=[value] +06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] @@ -6086,7 +6495,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: array_has([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c], substr(md5(CAST(value@0 AS Utf8)), 1, 32)) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6106,8 +6515,8 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) -07)------------TableScan: tmp_table projection=[value] +06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] @@ -6115,7 +6524,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6125,7 +6534,8 @@ select count(*) from test WHERE array_has([needle], needle); ---- 100000 -# TODO: this should probably be possible to completely remove the filter as always true? +# The optimizer does not currently eliminate the filter; +# Instead, it's rewritten as `IS NULL OR NOT NULL` due to SQL null semantics query TT explain with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE array_has([needle], needle); @@ -6136,9 +6546,8 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: __common_expr_3 = __common_expr_3 -07)------------Projection: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) AS __common_expr_3 -08)--------------TableScan: tmp_table projection=[value] +06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IS NOT NULL OR Boolean(NULL) +07)------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] @@ -6146,10 +6555,9 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: __common_expr_3@0 = __common_expr_3@0 -08)--------------ProjectionExec: expr=[substr(md5(CAST(value@0 AS Utf8)), 1, 32) as __common_expr_3] -09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -10)------------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IS NOT NULL OR NULL +08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] # any operator query ? @@ -6400,6 +6808,17 @@ SELECT array_intersect(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow ---- [2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] +query ?????? +SELECT array_intersect(arrow_cast(make_array(1,2,3), 'FixedSizeList(3, Int64)'), arrow_cast(make_array(2,3,4), 'FixedSizeList(3, Int64)')), + array_intersect(arrow_cast(make_array(1,3,5), 'FixedSizeList(3, Int64)'), arrow_cast(make_array(2,4,6), 'FixedSizeList(3, Int64)')), + array_intersect(arrow_cast(make_array('aa','bb','cc'), 'FixedSizeList(3, Utf8)'), arrow_cast(make_array('cc','aa','dd'), 'FixedSizeList(3, Utf8)')), + array_intersect(arrow_cast(make_array(true, false), 'FixedSizeList(2, Boolean)'), arrow_cast(make_array(true), 'FixedSizeList(1, Boolean)')), + array_intersect(arrow_cast(make_array(1.1, 2.2, 3.3), 'FixedSizeList(3, Float64)'), arrow_cast(make_array(2.2, 3.3, 4.4), 'FixedSizeList(3, Float64)')), + array_intersect(arrow_cast(make_array([1, 1], [2, 2], [3, 3]), 'FixedSizeList(3, List(Int64))'), arrow_cast(make_array([2, 2], [3, 3], [4, 4]), 'FixedSizeList(3, List(Int64))')) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + query ? select array_intersect([], []); ---- @@ -6423,12 +6842,12 @@ select array_intersect(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null) query ? select array_intersect(null, [1, 1, 2, 2, 3, 3]); ---- -NULL +[] query ? select array_intersect(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); ---- -NULL +[] query ? select array_intersect([], null); @@ -6453,12 +6872,12 @@ select array_intersect(arrow_cast([], 'LargeList(Int64)'), null); query ? select array_intersect(null, []); ---- -NULL +[] query ? select array_intersect(null, arrow_cast([], 'LargeList(Int64)')); ---- -NULL +[] query ? select array_intersect(null, null); @@ -7077,6 +7496,16 @@ select array_except(null, null) ---- NULL +query ? +select array_except(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, 6, 3, 4], 'LargeList(Int64)')); +---- +[1, 2] + +query ? +select array_except(arrow_cast([1, 2, 3, 4], 'FixedSizeList(4, Int64)'), arrow_cast([5, 6, 3, 4], 'FixedSizeList(4, Int64)')); +---- +[1, 2] + ### Array operators tests @@ -7159,7 +7588,7 @@ explain select [1,2,3] @> [1,3]; ---- logical_plan 01)Projection: Boolean(true) AS array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3))) -02)--EmptyRelation +02)--EmptyRelation: rows=1 physical_plan 01)ProjectionExec: expr=[true as array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))] 02)--PlaceholderRowExec @@ -7182,7 +7611,7 @@ explain select [1,3] <@ [1,2,3]; ---- logical_plan 01)Projection: Boolean(true) AS array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3))) -02)--EmptyRelation +02)--EmptyRelation: rows=1 physical_plan 01)ProjectionExec: expr=[true as array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))] 02)--PlaceholderRowExec @@ -7281,12 +7710,10 @@ select array_concat(column1, [7]) from arrays_values_v2; # flatten -#TODO: https://github.com/apache/datafusion/issues/7142 -# follow DuckDB -#query ? -#select flatten(NULL); -#---- -#NULL +query ? +select flatten(NULL); +---- +NULL # flatten with scalar values #1 query ??? @@ -7294,21 +7721,21 @@ select flatten(make_array(1, 2, 1, 3, 2)), flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))), flatten(make_array([[1.1]], [[2.2]], [[3.3], [4.4]])); ---- -[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4] +[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]] query ??? select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'LargeList(Int64)')), flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 5)), 'LargeList(LargeList(Int64))')), flatten(arrow_cast(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]), 'LargeList(LargeList(LargeList(Float64)))')); ---- -[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4] +[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]] query ??? select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'FixedSizeList(5, Int64)')), flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 5)), 'FixedSizeList(4, List(Int64))')), flatten(arrow_cast(make_array([[1.1], [2.2]], [[3.3], [4.4]]), 'FixedSizeList(2, List(List(Float64)))')); ---- -[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4] +[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]] # flatten with column values query ???? @@ -7318,8 +7745,8 @@ select flatten(column1), flatten(column4) from flatten_table; ---- -[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] -[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] +[1, 2, 3, 4, 5, 6] [[8]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] query ???? select flatten(column1), @@ -7328,8 +7755,8 @@ select flatten(column1), flatten(column4) from large_flatten_table; ---- -[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] -[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] +[1, 2, 3, 4, 5, 6] [[8]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] query ???? select flatten(column1), @@ -7338,8 +7765,19 @@ select flatten(column1), flatten(column4) from fixed_size_flatten_table; ---- -[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] -[1, 2, 3, 4, 5, 6] [8, 9, 10, 11, 12, 13] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] +[1, 2, 3, 4, 5, 6] [[8], [9, 10], [11, 12, 13]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + +# flatten with different inner list type +query ?????? +select flatten(arrow_cast(make_array([1, 2], [3, 4]), 'List(FixedSizeList(2, Int64))')), + flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'List(FixedSizeList(1, List(Int64)))')), + flatten(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), + flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'LargeList(List(List(Int64)))')), + flatten(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(FixedSizeList(2, Int64))')), + flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'LargeList(FixedSizeList(1, List(Int64)))')) +---- +[1, 2, 3, 4] [[1, 2], [3, 4]] [1, 2, 3, 4] [[1, 2], [3, 4]] [1, 2, 3, 4] [[1, 2], [3, 4]] ## empty (aliases: `array_empty`, `list_empty`) # empty scalar function #1 @@ -7631,6 +8069,11 @@ select array_resize(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 1); ---- [1] +query ? +select array_resize(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 1); +---- +[1] + # array_resize scalar function #2 query ? select array_resize(make_array(1, 2, 3), 5); @@ -7642,6 +8085,11 @@ select array_resize(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 5); ---- [1, 2, 3, NULL, NULL] +query ? +select array_resize(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 5); +---- +[1, 2, 3, NULL, NULL] + # array_resize scalar function #3 query ? select array_resize(make_array(1, 2, 3), 5, 4); @@ -7760,11 +8208,13 @@ select array_reverse(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array ---- [3, 2, 1] [1] -#TODO: support after FixedSizeList type coercion -#query ?? -#select array_reverse(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), array_reverse(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)')); -#---- -#[3, 2, 1] [1] +query ???? +select array_reverse(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), + array_reverse(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)')), + array_reverse(arrow_cast(make_array(1, NULL, 3), 'FixedSizeList(3, Int64)')), + array_reverse(arrow_cast(make_array(NULL, NULL, NULL), 'FixedSizeList(3, Int64)')); +---- +[3, 2, 1] [1] [3, NULL, 1] [NULL, NULL, NULL] query ?? select array_reverse(NULL), array_reverse([]); @@ -7783,6 +8233,23 @@ NULL NULL [60, 59, 58, 57, 56, 55, 54, NULL, 52, 51] [51, 52, NULL, 54, 55, 56, 57, 58, 59, 60] [70, 69, 68, 67, 66, 65, 64, 63, 62, 61] [61, 62, 63, 64, 65, 66, 67, 68, 69, 70] +statement ok +CREATE TABLE test_reverse_fixed_size AS VALUES + (arrow_cast([1, 2, 3], 'FixedSizeList(3, Int64)')), + (arrow_cast([4, 5, 6], 'FixedSizeList(3, Int64)')), + (arrow_cast([NULL, 8, 9], 'FixedSizeList(3, Int64)')), + (NULL); + +query ? +SELECT array_reverse(column1) FROM test_reverse_fixed_size; +---- +[3, 2, 1] +[6, 5, 4] +[9, 8, NULL] +NULL + +statement ok +DROP TABLE test_reverse_fixed_size; # Test defining a table with array columns statement ok @@ -7820,7 +8287,7 @@ List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int3 query ??T select [1,2,3]::int[], [['1']]::int[][], arrow_typeof([]::text[]); ---- -[1, 2, 3] [[1]] List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +[1, 2, 3] [[1]] List(Field { name: "item", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) # test empty arrays return length # issue: https://github.com/apache/datafusion/pull/12459 @@ -7843,9 +8310,40 @@ select arrow_typeof(a) from fixed_size_col_table; FixedSizeList(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 3) FixedSizeList(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 3) -statement error +query ? rowsort +SELECT DISTINCT a FROM fixed_size_col_table +---- +[1, 2, 3] +[4, 5, 6] + +query ?I rowsort +SELECT a, count(*) FROM fixed_size_col_table GROUP BY a +---- +[1, 2, 3] 1 +[4, 5, 6] 1 + +statement error Cast error: Cannot cast to FixedSizeList\(3\): value at index 0 has length 2 create table varying_fixed_size_col_table (a int[3]) as values ([1,2,3]), ([4,5]); +# https://github.com/apache/datafusion/issues/16187 +# should be NULL in case of out of bounds for Null Type +query ? +select [named_struct('a', 1, 'b', null)][-2]; +---- +NULL + +statement ok +COPY (select [[true, false], [false, true]] a, [false, true] b union select [[null, null]], null) to 'test_files/scratch/array/array_has/single_file.parquet' stored as parquet; + +statement ok +CREATE EXTERNAL TABLE array_has STORED AS PARQUET location 'test_files/scratch/array/array_has/single_file.parquet'; + +query B +select array_contains(a, b) from array_has order by 1 nulls last; +---- +true +NULL + ### Delete tables statement ok @@ -8024,3 +8522,6 @@ drop table values_all_empty; statement ok drop table fixed_size_col_table; + +statement ok +drop table array_has; diff --git a/datafusion/sqllogictest/test_files/array_query.slt b/datafusion/sqllogictest/test_files/array_query.slt index 8fde295e6051f..65d4fa495e3b3 100644 --- a/datafusion/sqllogictest/test_files/array_query.slt +++ b/datafusion/sqllogictest/test_files/array_query.slt @@ -108,11 +108,15 @@ SELECT * FROM data WHERE column2 is not distinct from null; # Aggregates ########### -query error Internal error: Min/Max accumulator not implemented for type List +query ? SELECT min(column1) FROM data; +---- +[1, 2, 3] -query error Internal error: Min/Max accumulator not implemented for type List +query ? SELECT max(column1) FROM data; +---- +[2, 3] query I SELECT count(column1) FROM data; diff --git a/datafusion/sqllogictest/test_files/arrow_files.slt b/datafusion/sqllogictest/test_files/arrow_files.slt index 30f322cf98fcd..62453ec4bf3e6 100644 --- a/datafusion/sqllogictest/test_files/arrow_files.slt +++ b/datafusion/sqllogictest/test_files/arrow_files.slt @@ -19,6 +19,11 @@ ## Arrow Files Format support ############# +# We using fixed arrow file to test for sqllogictests, and this arrow field is writing with arrow-ipc utf8, +# so when we decode to read it's also loading utf8. +# Currently, so we disable the map_string_types_to_utf8view +statement ok +set datafusion.sql_parser.map_string_types_to_utf8view = false; statement ok diff --git a/datafusion/sqllogictest/test_files/async_udf.slt b/datafusion/sqllogictest/test_files/async_udf.slt new file mode 100644 index 0000000000000..c61d02cfecfd4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/async_udf.slt @@ -0,0 +1,107 @@ + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +create table data(x int) as values (-10), (2); + +# Async udf can be used in aggregation +query I +select min(async_abs(x)) from data; +---- +2 + +query TT +explain select min(async_abs(x)) from data; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[min(async_abs(data.x))]] +02)--TableScan: data projection=[x] +physical_plan +01)AggregateExec: mode=Final, gby=[], aggr=[min(async_abs(data.x))] +02)--CoalescePartitionsExec +03)----AggregateExec: mode=Partial, gby=[], aggr=[min(async_abs(data.x))] +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------DataSourceExec: partitions=1, partition_sizes=[1] + +# Async udf can be used in aggregation with group by +query I rowsort +select min(async_abs(x)) from data group by async_abs(x); +---- +10 +2 + +query TT +explain select min(async_abs(x)) from data group by async_abs(x); +---- +logical_plan +01)Projection: min(async_abs(data.x)) +02)--Aggregate: groupBy=[[__common_expr_1 AS async_abs(data.x)]], aggr=[[min(__common_expr_1 AS async_abs(data.x))]] +03)----Projection: async_abs(data.x) AS __common_expr_1 +04)------TableScan: data projection=[x] +physical_plan +01)ProjectionExec: expr=[min(async_abs(data.x))@1 as min(async_abs(data.x))] +02)--AggregateExec: mode=FinalPartitioned, gby=[async_abs(data.x)@0 as async_abs(data.x)], aggr=[min(async_abs(data.x))] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------RepartitionExec: partitioning=Hash([async_abs(data.x)@0], 4), input_partitions=4 +05)--------AggregateExec: mode=Partial, gby=[__common_expr_1@0 as async_abs(data.x)], aggr=[min(async_abs(data.x))] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------ProjectionExec: expr=[__async_fn_0@1 as __common_expr_1] +08)--------------AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))] +09)----------------CoalesceBatchesExec: target_batch_size=8192 +10)------------------DataSourceExec: partitions=1, partition_sizes=[1] + +# Async udf can be used in filter +query I +select * from data where async_abs(x) < 5; +---- +2 + +query TT +explain select * from data where async_abs(x) < 5; +---- +logical_plan +01)Filter: async_abs(data.x) < Int32(5) +02)--TableScan: data projection=[x] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: __async_fn_0@1 < 5, projection=[x@0] +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +# Async udf can be used in projection +query I rowsort +select async_abs(x) from data; +---- +10 +2 + +query TT +explain select async_abs(x) from data; +---- +logical_plan +01)Projection: async_abs(data.x) +02)--TableScan: data projection=[x] +physical_plan +01)ProjectionExec: expr=[__async_fn_0@1 as async_abs(data.x)] +02)--AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------DataSourceExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/avro.slt b/datafusion/sqllogictest/test_files/avro.slt index 1b4150b074ccd..2ad60c0082e87 100644 --- a/datafusion/sqllogictest/test_files/avro.slt +++ b/datafusion/sqllogictest/test_files/avro.slt @@ -15,6 +15,10 @@ # specific language governing permissions and limitations # under the License. +# Currently, the avro not support Utf8View type, so we disable the map_string_types_to_utf8view +# After https://github.com/apache/arrow-rs/issues/7262 released, we can remove this setting +statement ok +set datafusion.sql_parser.map_string_types_to_utf8view = false; statement ok CREATE EXTERNAL TABLE alltypes_plain ( @@ -253,3 +257,13 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/avro/alltypes_plain.avro]]}, file_type=avro + +# test column projection order from avro file +query ITII +SELECT id, string_col, int_col, bigint_col FROM alltypes_plain ORDER BY id LIMIT 5 +---- +0 0 0 0 +1 1 1 10 +2 0 0 0 +3 1 1 10 +4 0 0 0 diff --git a/datafusion/sqllogictest/test_files/binary.slt b/datafusion/sqllogictest/test_files/binary.slt index 5c5f9d510e554..1077c32e46f35 100644 --- a/datafusion/sqllogictest/test_files/binary.slt +++ b/datafusion/sqllogictest/test_files/binary.slt @@ -147,8 +147,45 @@ query error DataFusion error: Error during planning: Cannot infer common argumen SELECT column1, column1 = arrow_cast(X'0102', 'FixedSizeBinary(2)') FROM t # Comparison to different sized Binary -query error DataFusion error: Error during planning: Cannot infer common argument type for comparison operation FixedSizeBinary\(3\) = Binary +query ?B SELECT column1, column1 = X'0102' FROM t +---- +000102 false +003102 false +NULL NULL +ff0102 false +000102 false + +query ?B +SELECT column1, column1 = X'000102' FROM t +---- +000102 true +003102 false +NULL NULL +ff0102 false +000102 true + +query ?B +SELECT arrow_cast(column1, 'FixedSizeBinary(3)'), arrow_cast(column1, 'FixedSizeBinary(3)') = arrow_cast(arrow_cast(X'000102', 'FixedSizeBinary(3)'), 'BinaryView') FROM t; +---- +000102 true +003102 false +NULL NULL +ff0102 false +000102 true + +# Plan should not have a cast of the column (should have casted the literal +# to FixedSizeBinary as that is much faster) + +query TT +explain SELECT column1, column1 = X'000102' FROM t +---- +logical_plan +01)Projection: t.column1, t.column1 = FixedSizeBinary(3, "0,1,2") AS t.column1 = Binary("0,1,2") +02)--TableScan: t projection=[column1] +physical_plan +01)ProjectionExec: expr=[column1@0 as column1, column1@0 = 000102 as t.column1 = Binary("0,1,2")] +02)--DataSourceExec: partitions=1, partition_sizes=[1] statement ok drop table t_source diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 21913005e26ba..9bc1f83ed1196 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -467,6 +467,7 @@ FROM t; ---- [{foo: blarg}] +# mix of then and else query II SELECT v, CASE WHEN v != 0 THEN 10/v ELSE 42 END FROM (VALUES (0), (1), (2)) t(v) ---- @@ -474,11 +475,47 @@ SELECT v, CASE WHEN v != 0 THEN 10/v ELSE 42 END FROM (VALUES (0), (1), (2)) t(v 1 10 2 5 +# when expressions is always false, then branch should never be evaluated query II SELECT v, CASE WHEN v < 0 THEN 10/0 ELSE 1 END FROM (VALUES (1), (2)) t(v) ---- 1 1 2 1 +# when expressions is always true, else branch should never be evaluated +query II +SELECT v, CASE WHEN v > 0 THEN 1 ELSE 10/0 END FROM (VALUES (1), (2)) t(v) +---- +1 1 +2 1 + + +# lazy evaluation of multiple when branches, else branch should never be evaluated +query II +SELECT v, CASE WHEN v == 1 THEN -1 WHEN v == 2 THEN -2 WHEN v == 3 THEN -3 ELSE 10/0 END FROM (VALUES (1), (2), (3)) t(v) +---- +1 -1 +2 -2 +3 -3 + +# covers the InfallibleExprOrNull evaluation strategy +query II +SELECT v, CASE WHEN v THEN 1 END FROM (VALUES (1), (2), (3), (NULL)) t(v) +---- +1 1 +2 1 +3 1 +NULL NULL + statement ok drop table t + +query I +SELECT case when true then 1 / 1 else 1 / 0 end; +---- +1 + +query I +SELECT case when false then 1 / 0 else 1 / 1 end; +---- +1 diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt index dfcd924758574..4c60a4365ee26 100644 --- a/datafusion/sqllogictest/test_files/clickbench.slt +++ b/datafusion/sqllogictest/test_files/clickbench.slt @@ -64,10 +64,10 @@ SELECT COUNT(DISTINCT "SearchPhrase") FROM hits; ---- 1 -query DD -SELECT MIN("EventDate"::INT::DATE), MAX("EventDate"::INT::DATE) FROM hits; +query II +SELECT MIN("EventDate"), MAX("EventDate") FROM hits; ---- -2013-07-15 2013-07-15 +15901 15901 query II SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC; @@ -168,11 +168,11 @@ SELECT "SearchPhrase", MIN("URL"), MIN("Title"), COUNT(*) AS c, COUNT(DISTINCT " ---- query IITIIIIIIIIIITTIIIIIIIIIITIIITIIIITTIIITIIIIIIIIIITIIIIITIIIIIITIIIIIIIIIITTTTIIIIIIIITITTITTTTTTTTTTIIII -SELECT * FROM hits WHERE "URL" LIKE '%google%' ORDER BY to_timestamp_seconds("EventTime") LIMIT 10; +SELECT * FROM hits WHERE "URL" LIKE '%google%' ORDER BY "EventTime" LIMIT 10; ---- query T -SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY to_timestamp_seconds("EventTime") LIMIT 10; +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "EventTime" LIMIT 10; ---- query T @@ -180,7 +180,7 @@ SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "SearchPhras ---- query T -SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY to_timestamp_seconds("EventTime"), "SearchPhrase" LIMIT 10; +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "EventTime", "SearchPhrase" LIMIT 10; ---- query IRI @@ -247,31 +247,31 @@ SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c 1615432634 1615432633 1615432632 1615432631 1 query TI -SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "URL" <> '' GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10; +SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "URL" <> '' GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10; ---- query TI -SELECT "Title", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "Title" <> '' GROUP BY "Title" ORDER BY PageViews DESC LIMIT 10; +SELECT "Title", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "Title" <> '' GROUP BY "Title" ORDER BY PageViews DESC LIMIT 10; ---- query TI -SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "IsLink" <> 0 AND "IsDownload" = 0 GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; +SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "IsLink" <> 0 AND "IsDownload" = 0 GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; ---- query IIITTI -SELECT "TraficSourceID", "SearchEngineID", "AdvEngineID", CASE WHEN ("SearchEngineID" = 0 AND "AdvEngineID" = 0) THEN "Referer" ELSE '' END AS Src, "URL" AS Dst, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 GROUP BY "TraficSourceID", "SearchEngineID", "AdvEngineID", Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; +SELECT "TraficSourceID", "SearchEngineID", "AdvEngineID", CASE WHEN ("SearchEngineID" = 0 AND "AdvEngineID" = 0) THEN "Referer" ELSE '' END AS Src, "URL" AS Dst, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 GROUP BY "TraficSourceID", "SearchEngineID", "AdvEngineID", Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; ---- -query IDI -SELECT "URLHash", "EventDate"::INT::DATE, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "TraficSourceID" IN (-1, 6) AND "RefererHash" = 3594120000172545465 GROUP BY "URLHash", "EventDate"::INT::DATE ORDER BY PageViews DESC LIMIT 10 OFFSET 100; +query III +SELECT "URLHash", "EventDate", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "TraficSourceID" IN (-1, 6) AND "RefererHash" = 3594120000172545465 GROUP BY "URLHash", "EventDate" ORDER BY PageViews DESC LIMIT 10 OFFSET 100; ---- query III -SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-01' AND "EventDate"::INT::DATE <= '2013-07-31' AND "IsRefresh" = 0 AND "DontCountHits" = 0 AND "URLHash" = 2868770270353813622 GROUP BY "WindowClientWidth", "WindowClientHeight" ORDER BY PageViews DESC LIMIT 10 OFFSET 10000; +SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "DontCountHits" = 0 AND "URLHash" = 2868770270353813622 GROUP BY "WindowClientWidth", "WindowClientHeight" ORDER BY PageViews DESC LIMIT 10 OFFSET 10000; ---- query PI -SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-14' AND "EventDate"::INT::DATE <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000; +SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-14' AND "EventDate" <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000; ---- # Clickbench "Extended" queries that test count distinct diff --git a/datafusion/sqllogictest/test_files/clickbench_extended.slt b/datafusion/sqllogictest/test_files/clickbench_extended.slt new file mode 100644 index 0000000000000..ee3e33551ee3e --- /dev/null +++ b/datafusion/sqllogictest/test_files/clickbench_extended.slt @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# DataFusion specific ClickBench "Extended" Queries +# See data provenance notes in clickbench.slt + +statement ok +CREATE EXTERNAL TABLE hits +STORED AS PARQUET +LOCATION '../core/tests/data/clickbench_hits_10.parquet'; + +# If you change any of these queries, please change the corresponding query in +# benchmarks/queries/clickbench/extended.sql and update the README. + +query III +SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; +---- +1 1 1 + +query III +SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; +---- +1 1 1 + +query TIIII +SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; +---- +� 1 1 1 1 + +query IIIRRRR +SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") FROM hits GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL ORDER BY s DESC LIMIT 10; +---- +0 839 6 0 0 0 0 +0 197 2 0 0 0 0 + +query IIIIII +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, MEDIAN("ResponseStartTiming") tmed, MAX("ResponseStartTiming") tmax FROM hits WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tmed DESC LIMIT 10; +---- + +query IIIIII +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY "ResponseStartTiming") tp95, MAX("ResponseStartTiming") tmax FROM 'hits' WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tp95 DESC LIMIT 10; +---- + +query I +SELECT COUNT(*) AS ShareCount FROM hits WHERE "IsMobile" = 1 AND "MobilePhoneModel" LIKE 'iPhone%' AND "SocialAction" = 'share' AND "SocialSourceNetworkID" IN (5, 12) AND "ClientTimeZone" BETWEEN -5 AND 5 AND regexp_match("Referer", '\/campaign\/(spring|summer)_promo') IS NOT NULL AND CASE WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT ELSE 0 END > 1920 AND levenshtein(CAST("UTMSource" AS STRING), CAST("UTMCampaign" AS STRING)) < 3; +---- +0 + + +statement ok +drop table hits; diff --git a/datafusion/sqllogictest/test_files/coalesce.slt b/datafusion/sqllogictest/test_files/coalesce.slt index e7cf31dc690b7..9740bade5e27b 100644 --- a/datafusion/sqllogictest/test_files/coalesce.slt +++ b/datafusion/sqllogictest/test_files/coalesce.slt @@ -260,8 +260,8 @@ select arrow_typeof(coalesce(c, arrow_cast('b', 'Dictionary(Int32, Utf8)'))) from t; ---- -a Dictionary(Int32, Utf8) -b Dictionary(Int32, Utf8) +a Utf8View +b Utf8View statement ok drop table t; diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 925f96bd4ac0c..096cde86f26f5 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -332,7 +332,6 @@ OPTIONS ( 'format.dictionary_enabled' false, 'format.statistics_enabled' page, 'format.statistics_enabled::col2' none, -'format.max_statistics_size' 123, 'format.bloom_filter_fpp' 0.001, 'format.bloom_filter_ndv' 100, 'format.metadata::key' 'value' @@ -637,7 +636,7 @@ query error DataFusion error: SQL error: ParserError\("Expected: \), found: EOF" COPY (select col2, sum(col1) from source_table # Copy from table with non literal -query error DataFusion error: SQL error: ParserError\("Unexpected token \("\) +query error DataFusion error: SQL error: ParserError\("Expected: end of statement or ;, found: \( at Line: 1, Column: 44"\) COPY source_table to '/tmp/table.parquet' (row_group_size 55 + 102); # Copy using execution.keep_partition_by_columns with an invalid value diff --git a/datafusion/sqllogictest/test_files/count_star_rule.slt b/datafusion/sqllogictest/test_files/count_star_rule.slt index d38d3490fed47..826742267290c 100644 --- a/datafusion/sqllogictest/test_files/count_star_rule.slt +++ b/datafusion/sqllogictest/test_files/count_star_rule.slt @@ -34,7 +34,7 @@ logical_plan 01)Projection: count(Int64(1)) AS count() 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----SubqueryAlias: t -04)------EmptyRelation +04)------EmptyRelation: rows=1 physical_plan 01)ProjectionExec: expr=[1 as count()] 02)--PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index bb66aef2514c9..1e6183f48bac7 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -77,7 +77,7 @@ statement error DataFusion error: SQL error: ParserError\("Expected: HEADER, fou CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH LOCATION 'foo.csv'; # Unrecognized random clause -statement error DataFusion error: SQL error: ParserError\("Unexpected token FOOBAR"\) +statement error DataFusion error: SQL error: ParserError\("Expected: end of statement or ;, found: FOOBAR at Line: 1, Column: 47"\) CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV FOOBAR BARBAR BARFOO LOCATION 'foo.csv'; # Missing partition column @@ -297,3 +297,9 @@ CREATE EXTERNAL TABLE staging.foo STORED AS parquet LOCATION '../../parquet-test # Create external table with qualified name, but no schema should error statement error DataFusion error: Error during planning: failed to resolve schema: release CREATE EXTERNAL TABLE release.bar STORED AS parquet LOCATION '../../parquet-testing/data/alltypes_plain.parquet'; + +# Cannot create external table alongside `if_not_exists` and `or_replace` +statement error DataFusion error: SQL error: ParserError\("'IF NOT EXISTS' cannot coexist with 'REPLACE'"\) +CREATE OR REPLACE EXTERNAL TABLE IF NOT EXISTS t_conflict(c1 int) +STORED AS CSV +LOCATION 'foo.csv'; diff --git a/datafusion/sqllogictest/test_files/create_function.slt b/datafusion/sqllogictest/test_files/create_function.slt index 4f0c53c36ca1a..4e82c0866ee23 100644 --- a/datafusion/sqllogictest/test_files/create_function.slt +++ b/datafusion/sqllogictest/test_files/create_function.slt @@ -21,11 +21,6 @@ ## Note that DataFusion provides a pluggable system for creating functions ## but has no built in support for doing so. -# Use PostgresSQL dialect (until we upgrade to sqlparser 0.44, where CREATE FUNCTION) -# is supported in the Generic dialect (the default) -statement ok -set datafusion.sql_parser.dialect = 'Postgres'; - # Create function will fail unless a user supplied function factory is supplied statement error DataFusion error: Invalid or Unsupported Configuration: Function factory has not been configured CREATE FUNCTION foo (DOUBLE) RETURNS DOUBLE RETURN $1 + $2; diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index e019af9775a42..a581bcb539a91 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -33,7 +33,7 @@ EXPLAIN WITH "NUMBERS" AS (SELECT 1 as a, 2 as b, 3 as c) SELECT "NUMBERS".* FRO logical_plan 01)SubqueryAlias: NUMBERS 02)--Projection: Int64(1) AS a, Int64(2) AS b, Int64(3) AS c -03)----EmptyRelation +03)----EmptyRelation: rows=1 physical_plan 01)ProjectionExec: expr=[1 as a, 2 as b, 3 as c] 02)--PlaceholderRowExec @@ -107,10 +107,10 @@ logical_plan 01)SubqueryAlias: nodes 02)--RecursiveQuery: is_distinct=false 03)----Projection: Int64(1) AS id -04)------EmptyRelation +04)------EmptyRelation: rows=1 05)----Projection: nodes.id + Int64(1) AS id 06)------Filter: nodes.id < Int64(10) -07)--------TableScan: nodes +07)--------TableScan: nodes projection=[id] physical_plan 01)RecursiveQueryExec: name=nodes, is_distinct=false 02)--ProjectionExec: expr=[1 as id] @@ -152,11 +152,10 @@ logical_plan 01)Sort: balances.time ASC NULLS LAST, balances.name ASC NULLS LAST, balances.account_balance ASC NULLS LAST 02)--SubqueryAlias: balances 03)----RecursiveQuery: is_distinct=false -04)------Projection: balance.time, balance.name, balance.account_balance -05)--------TableScan: balance -06)------Projection: balances.time + Int64(1) AS time, balances.name, balances.account_balance + Int64(10) AS account_balance -07)--------Filter: balances.time < Int64(10) -08)----------TableScan: balances +04)------TableScan: balance projection=[time, name, account_balance] +05)------Projection: balances.time + Int64(1) AS time, balances.name, balances.account_balance + Int64(10) AS account_balance +06)--------Filter: balances.time < Int64(10) +07)----------TableScan: balances projection=[time, name, account_balance] physical_plan 01)SortExec: expr=[time@0 ASC NULLS LAST, name@1 ASC NULLS LAST, account_balance@2 ASC NULLS LAST], preserve_partitioning=[false] 02)--RecursiveQueryExec: name=balances, is_distinct=false @@ -720,14 +719,14 @@ logical_plan 01)SubqueryAlias: recursive_cte 02)--RecursiveQuery: is_distinct=false 03)----Projection: Int64(1) AS val -04)------EmptyRelation +04)------EmptyRelation: rows=1 05)----Projection: Int64(2) AS val -06)------Cross Join: +06)------Cross Join: 07)--------Filter: recursive_cte.val < Int64(2) 08)----------TableScan: recursive_cte 09)--------SubqueryAlias: sub_cte 10)----------Projection: Int64(2) AS val -11)------------EmptyRelation +11)------------EmptyRelation: rows=1 physical_plan 01)RecursiveQueryExec: name=recursive_cte, is_distinct=false 02)--ProjectionExec: expr=[1 as val] @@ -869,7 +868,7 @@ explain with numbers(a,b,c) as (select 1 as x, 2 as y, 3 as z) select * from num logical_plan 01)SubqueryAlias: numbers 02)--Projection: Int64(1) AS a, Int64(2) AS b, Int64(3) AS c -03)----EmptyRelation +03)----EmptyRelation: rows=1 physical_plan 01)ProjectionExec: expr=[1 as a, 2 as b, 3 as c] 02)--PlaceholderRowExec @@ -880,7 +879,7 @@ explain with numbers(a,b,c) as (select 1,2,3) select * from numbers; logical_plan 01)SubqueryAlias: numbers 02)--Projection: Int64(1) AS a, Int64(2) AS b, Int64(3) AS c -03)----EmptyRelation +03)----EmptyRelation: rows=1 physical_plan 01)ProjectionExec: expr=[1 as a, 2 as b, 3 as c] 02)--PlaceholderRowExec @@ -891,7 +890,7 @@ explain with numbers as (select 1 as a, 2 as b, 3 as c) select * from numbers; logical_plan 01)SubqueryAlias: numbers 02)--Projection: Int64(1) AS a, Int64(2) AS b, Int64(3) AS c -03)----EmptyRelation +03)----EmptyRelation: rows=1 physical_plan 01)ProjectionExec: expr=[1 as a, 2 as b, 3 as c] 02)--PlaceholderRowExec @@ -931,7 +930,7 @@ logical_plan 02)--TableScan: j1 projection=[a] 03)--SubqueryAlias: j2 04)----Projection: Int64(1) -05)------EmptyRelation +05)------EmptyRelation: rows=1 physical_plan 01)CrossJoinExec 02)--DataSourceExec: partitions=1, partition_sizes=[0] @@ -955,10 +954,10 @@ logical_plan 01)SubqueryAlias: numbers 02)--RecursiveQuery: is_distinct=false 03)----Projection: Int64(1) AS n -04)------EmptyRelation +04)------EmptyRelation: rows=1 05)----Projection: numbers.n + Int64(1) 06)------Filter: numbers.n < Int64(10) -07)--------TableScan: numbers +07)--------TableScan: numbers projection=[n] physical_plan 01)RecursiveQueryExec: name=numbers, is_distinct=false 02)--ProjectionExec: expr=[1 as n] @@ -981,10 +980,10 @@ logical_plan 01)SubqueryAlias: numbers 02)--RecursiveQuery: is_distinct=false 03)----Projection: Int64(1) AS n -04)------EmptyRelation +04)------EmptyRelation: rows=1 05)----Projection: numbers.n + Int64(1) 06)------Filter: numbers.n < Int64(10) -07)--------TableScan: numbers +07)--------TableScan: numbers projection=[n] physical_plan 01)RecursiveQueryExec: name=numbers, is_distinct=false 02)--ProjectionExec: expr=[1 as n] @@ -996,6 +995,60 @@ physical_plan 08)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)------------WorkTableExec: name=numbers +# Test for issue #16998: SortExec shares DynamicFilterPhysicalExpr across multiple executions +query II +with recursive r as ( + select 0 as k, 0 as v + union all + ( + select * + from r + order by v + limit 1 + ) +) +select * +from r +limit 5; +---- +0 0 +0 0 +0 0 +0 0 +0 0 + +query TT +explain +with recursive r as ( + select 0 as k, 0 as v + union all + ( + select * + from r + order by v + limit 1 + ) +) +select * +from r +limit 5; +---- +logical_plan +01)SubqueryAlias: r +02)--Limit: skip=0, fetch=5 +03)----RecursiveQuery: is_distinct=false +04)------Projection: Int64(0) AS k, Int64(0) AS v +05)--------EmptyRelation: rows=1 +06)------Sort: r.v ASC NULLS LAST, fetch=1 +07)--------TableScan: r projection=[k, v] +physical_plan +01)GlobalLimitExec: skip=0, fetch=5 +02)--RecursiveQueryExec: name=r, is_distinct=false +03)----ProjectionExec: expr=[0 as k, 0 as v] +04)------PlaceholderRowExec +05)----SortExec: TopK(fetch=1), expr=[v@1 ASC NULLS LAST], preserve_partitioning=[false] +06)------WorkTableExec: name=r + statement count 0 set datafusion.execution.enable_recursive_ctes = false; diff --git a/datafusion/sqllogictest/test_files/current_date_timezone.slt b/datafusion/sqllogictest/test_files/current_date_timezone.slt new file mode 100644 index 0000000000000..b30373acfaa0e --- /dev/null +++ b/datafusion/sqllogictest/test_files/current_date_timezone.slt @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## current_date with timezone tests +########## + +# CI Fails https://github.com/apache/datafusion/issues/18062 + +# Test 1: Verify current_date is consistent within the same query (default UTC) +# query B +# SELECT current_date() = current_date(); +# ---- +# true + +# Test 2: Verify alias 'today' works the same as current_date +# query B +# SELECT current_date() = today(); +# ---- +# true + +# Test 3: Set timezone to +05:00 and verify current_date is still stable +# statement ok +# SET datafusion.execution.time_zone = '+05:00'; + +# query B +# SELECT current_date() = current_date(); +# ---- +# true + +# Test 4: Verify current_date matches cast(now() as date) in the same timezone +# query B +# SELECT current_date() = cast(now() as date); +# ---- +# true + +# Test 5: Test with negative offset timezone +# statement ok +# SET datafusion.execution.time_zone = '-08:00'; + +# query B +# SELECT current_date() = today(); +# ---- +# true + +# Test 6: Test with named timezone (America/New_York) +# statement ok +# SET datafusion.execution.time_zone = 'America/New_York'; + +# query B +# SELECT current_date() = current_date(); +# ---- +# true + +# Test 7: Verify date type is preserved +# query T +# SELECT arrow_typeof(current_date()); +# ---- +# Date32 + +# Test 8: Reset to UTC +# statement ok +# SET datafusion.execution.time_zone = '+00:00'; + +# query B +# SELECT current_date() = today(); +# ---- +# true diff --git a/datafusion/sqllogictest/test_files/dates.slt b/datafusion/sqllogictest/test_files/dates.slt index 4425eee333735..2e91a0363db06 100644 --- a/datafusion/sqllogictest/test_files/dates.slt +++ b/datafusion/sqllogictest/test_files/dates.slt @@ -108,6 +108,17 @@ SELECT '2023-01-01T00:00:00'::timestamp - DATE '2021-01-01'; ---- 730 days 0 hours 0 mins 0.000000000 secs +# NULL with DATE arithmetic should yield NULL +query ? +SELECT NULL - DATE '1984-02-28'; +---- +NULL + +query ? +SELECT DATE '1984-02-28' - NULL +---- +NULL + # to_date_test statement ok create table to_date_t1(ts bigint) as VALUES @@ -183,7 +194,7 @@ query error input contains invalid characters SELECT to_date('2020-09-08 12/00/00+00:00', '%c', '%+') # to_date with broken formatting -query error bad or unsupported format string +query error DataFusion error: Execution error: Error parsing timestamp from '2020\-09\-08 12/00/00\+00:00' using format '%q': trailing input SELECT to_date('2020-09-08 12/00/00+00:00', '%q') statement ok diff --git a/datafusion/sqllogictest/test_files/ddl.slt b/datafusion/sqllogictest/test_files/ddl.slt index 088d0155a66f3..03ef08e1a5f83 100644 --- a/datafusion/sqllogictest/test_files/ddl.slt +++ b/datafusion/sqllogictest/test_files/ddl.slt @@ -272,7 +272,7 @@ drop table my_table # select_into statement ok -SELECT* INTO my_table FROM (SELECT * FROM aggregate_simple) +SELECT * INTO my_table FROM (SELECT * FROM aggregate_simple) query RRB rowsort SELECT * FROM my_table order by c1 LIMIT 1 @@ -587,7 +587,7 @@ statement ok CREATE EXTERNAL TABLE aggregate_simple STORED AS CSV LOCATION '../core/tests/data/aggregate_simple.csv' OPTIONS ('format.has_header' 'true'); # Should not recreate the same EXTERNAL table -statement error Execution error: Table 'aggregate_simple' already exists +statement error Execution error: External table 'aggregate_simple' already exists CREATE EXTERNAL TABLE aggregate_simple STORED AS CSV LOCATION '../core/tests/data/aggregate_simple.csv' OPTIONS ('format.has_header' 'true'); statement ok @@ -607,6 +607,55 @@ CREATE TABLE table_without_values(field1 BIGINT, field2 BIGINT); statement error Execution error: 'IF NOT EXISTS' cannot coexist with 'REPLACE' CREATE OR REPLACE TABLE IF NOT EXISTS table_without_values(field1 BIGINT, field2 BIGINT); +# CREATE OR REPLACE +statement ok +CREATE OR REPLACE EXTERNAL TABLE aggregate_simple_repl +STORED AS CSV +LOCATION '../core/tests/data/aggregate_simple.csv' +OPTIONS ('format.has_header' 'true'); + +statement ok +CREATE OR REPLACE EXTERNAL TABLE aggregate_simple_repl +STORED AS CSV +LOCATION '../core/tests/data/aggregate_simple.csv' +OPTIONS ('format.has_header' 'true'); + +# Create replacement table for table that doesn't already exist +statement ok +DROP TABLE IF EXISTS aggregate_table; + +statement ok +CREATE OR REPLACE EXTERNAL TABLE aggregate_table +STORED AS CSV +LOCATION '../core/tests/data/aggregate_simple.csv' +OPTIONS ('format.has_header' 'true'); + +query TTT +DESCRIBE aggregate_table; +---- +c1 Float64 YES +c2 Float64 YES +c3 Boolean YES + +# Create replacement table with different format for table that doesn't already exist +query I +COPY (SELECT * FROM (VALUES (1),(2),(3)) AS t(id)) +TO 'test_files/scratch/ddl/test_table' +STORED AS PARQUET; +---- +3 + +statement ok +CREATE OR REPLACE EXTERNAL TABLE aggregate_table +STORED AS PARQUET +LOCATION 'test_files/scratch/ddl/test_table'; + + +query TTT +DESCRIBE aggregate_table; +---- +id Int64 YES + # Should insert into an empty table statement ok insert into table_without_values values (1, 2), (2, 3), (2, 4); @@ -658,9 +707,9 @@ CREATE EXTERNAL TABLE empty STORED AS CSV LOCATION '../core/tests/data/empty.csv query TTI select column_name, data_type, ordinal_position from information_schema.columns where table_name='empty';; ---- -c1 Utf8 0 -c2 Utf8 1 -c3 Utf8 2 +c1 Null 0 +c2 Null 1 +c3 Null 2 ## should allow any type of exprs as values @@ -819,7 +868,7 @@ show columns FROM table_with_pk; ---- datafusion public table_with_pk sn Int32 NO datafusion public table_with_pk ts Timestamp(Nanosecond, Some("+00:00")) NO -datafusion public table_with_pk currency Utf8 NO +datafusion public table_with_pk currency Utf8View NO datafusion public table_with_pk amount Float32 YES statement ok @@ -828,18 +877,18 @@ drop table table_with_pk; statement ok set datafusion.catalog.information_schema = false; -# Test VARCHAR is mapped to Utf8View during SQL planning when setting map_varchar_to_utf8view to true +# Test VARCHAR is mapped to Utf8View during SQL planning when setting map_string_types_to_utf8view to true statement ok CREATE TABLE t1(c1 VARCHAR(10) NOT NULL, c2 VARCHAR); query TTT DESCRIBE t1; ---- -c1 Utf8 NO -c2 Utf8 YES +c1 Utf8View NO +c2 Utf8View YES statement ok -set datafusion.sql_parser.map_varchar_to_utf8view = true; +set datafusion.sql_parser.map_string_types_to_utf8view = true; statement ok CREATE TABLE t2(c1 VARCHAR(10) NOT NULL, c2 VARCHAR); diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index 089910785ad9d..502821fcc3043 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -747,3 +747,163 @@ SELECT cast(cast('5.20' as decimal(4,2)) as decimal(3,2)) ---- 0 5.2 + +query RR +SELECT + arrow_cast(1.23,'Decimal128(3,2)') - arrow_cast(123, 'UInt64') as subtration_uint, + arrow_cast(1.23,'Decimal128(3,2)') - arrow_cast(123, 'Int64') as subtration_int +---- +-121.77 -121.77 + +query RR +SELECT + arrow_cast(1.23,'Decimal128(3,2)') + arrow_cast(123, 'UInt64') as addition_uint, + arrow_cast(1.23,'Decimal128(3,2)') + arrow_cast(123, 'Int64') as addition_int +---- +124.23 124.23 + +query RR +SELECT + arrow_cast(1.23,'Decimal128(3,2)') * arrow_cast(123, 'UInt64') as mulitplication_uint, + arrow_cast(1.23,'Decimal128(3,2)') * arrow_cast(123, 'Int64') as multiplication_int +---- +151.29 151.29 + +query RR +SELECT + arrow_cast(1.23,'Decimal128(3,2)') / arrow_cast(123, 'UInt64') as divison_uint, + arrow_cast(1.23,'Decimal128(3,2)') / arrow_cast(123, 'Int64') as divison_int +---- +0.01 0.01 + +query TR +with tt as ( + select arrow_cast(133333333333333333333333333333333333333333333.34, 'Decimal256(50, 2)') as v1 +) select arrow_typeof(v1 + 1.5), v1 + 1.5 from tt; +---- +Float64 133333333333333330000000000000000000000000000 + +# Following tests only make sense if numbers are parsed as decimals +# Remove when `parse_float_as_decimal` is true by default (#14612) +statement ok +set datafusion.sql_parser.parse_float_as_decimal = true; + +# smoke test for decimal parsing +query RT +select 100000000000000000000000000000000000::decimal(38,0), arrow_typeof(100000000000000000000000000000000000::decimal(38,0)); +---- +100000000000000000000000000000000000 Decimal128(38, 0) + +# log for small decimal128 +query R +select log(100::decimal(38,0)); +---- +2 + +# log for small decimal256 +query R +select log(100::decimal(76,0)); +---- +2 + +# log(10^21) for large decimal128 +query R +select log(10, 1000000000000000000000::decimal(38,0)); +---- +21 + +# log(10^35) for large decimal128 +# Must be 35 if parsed as decimal; 34 for floats +query R +select log(100000000000000000000000000000000000::decimal(38,0)) +---- +35 + +# Decimal overflow for 10^38 +query error Arrow error: Invalid argument error: .* is too large to store in a Decimal128 of precision 38. Max is +select log(100000000000000000000000000000000000000::decimal(38,0)) + +# log(10^35) for decimal256 for a value able to fit i128 +query R +select log(100000000000000000000000000000000000::decimal(76,0)); +---- +35 + +# log(10^50) for decimal256 for a value larger than i128 +query error Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported +select log(100000000000000000000000000000000000000000000000000::decimal(76,0)); + +# log(10^35) for decimal128 with explicit base +query R +select log(10, 100000000000000000000000000000000000::decimal(38,0)); +---- +35 + +# log(10^35) for decimal256 with explicit base - only float as a base +query R +select log(10.0, 100000000000000000000000000000000000::decimal(76,0)); +---- +35 + +# log(10^35) for decimal128 with explicit decimal base +query R +select log(10::decimal(38, 0), 100000000000000000000000000000000000::decimal(38,0)); +---- +35 + +# log(10^35) for decimal128 with another base +query R +select log(2, 100000000000000000000000000000000000::decimal(38,0)); +---- +116 + +# log(10^35) for decimal128 with another base +query R +select log(2.0, 100000000000000000000000000000000000::decimal(38,0)); +---- +116.267483321058 + +# null cases +query R +select log(null, 100); +---- +NULL + +query R +select log(null, 100000000000000000000000000000000000::decimal(38,0)); +---- +NULL + +query R +select log(null); +---- +NULL + +query R +select log(2.0, null); +---- +NULL + +# Set parse_float_as_decimal to false to test float parsing +statement ok +set datafusion.sql_parser.parse_float_as_decimal = false; + +# smoke test for decimal parsing +query R +select 100000000000000000000000000000000000::decimal(38,0) +---- +99999999999999996863366107917975552 + +# log(10^35) for decimal128 with explicit decimal base +# Float parsing is rounding down +query R +select log(10, 100000000000000000000000000000000000::decimal(38,0)); +---- +34 + +# log(10^35) for large decimal128 if parsed as float +# Float parsing is rounding down +query R +select log(100000000000000000000000000000000000::decimal(38,0)) +---- +34 diff --git a/datafusion/sqllogictest/test_files/delete.slt b/datafusion/sqllogictest/test_files/delete.slt new file mode 100644 index 0000000000000..258318f09423c --- /dev/null +++ b/datafusion/sqllogictest/test_files/delete.slt @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Delete Tests +########## + +statement ok +create table t1(a int, b varchar, c double, d int); + +# Turn off the optimizer to make the logical plan closer to the initial one +statement ok +set datafusion.optimizer.max_passes = 0; + + +# Delete all +query TT +explain delete from t1; +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--TableScan: t1 +physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Delete) + + +# Filtered by existing columns +query TT +explain delete from t1 where a = 1 and b = 2 and c > 3 and d != 4; +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: CAST(t1.a AS Int64) = Int64(1) AND t1.b = CAST(Int64(2) AS Utf8View) AND t1.c > CAST(Int64(3) AS Float64) AND CAST(t1.d AS Int64) != Int64(4) +03)----TableScan: t1 +physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Delete) + + +# Filtered by existing columns, using qualified and unqualified names +query TT +explain delete from t1 where t1.a = 1 and b = 2 and t1.c > 3 and d != 4; +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: CAST(t1.a AS Int64) = Int64(1) AND t1.b = CAST(Int64(2) AS Utf8View) AND t1.c > CAST(Int64(3) AS Float64) AND CAST(t1.d AS Int64) != Int64(4) +03)----TableScan: t1 +physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Delete) + + +# Filtered by a mix of columns and literal predicates +query TT +explain delete from t1 where a = 1 and 1 = 1 and true; +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: CAST(t1.a AS Int64) = Int64(1) AND Int64(1) = Int64(1) AND Boolean(true) +03)----TableScan: t1 +physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Delete) + + +# Deleting by columns that do not exist returns an error +query error DataFusion error: Schema error: No field named e. Valid fields are t1.a, t1.b, t1.c, t1.d. +explain delete from t1 where e = 1; + + +# Filtering using subqueries + +statement ok +create table t2(a int, b varchar, c double, d int); + +query TT +explain delete from t1 where a = (select max(a) from t2 where t1.b = t2.b); +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: t1.a = () +03)----Subquery: +04)------Projection: max(t2.a) +05)--------Aggregate: groupBy=[[]], aggr=[[max(t2.a)]] +06)----------Filter: outer_ref(t1.b) = t2.b +07)------------TableScan: t2 +08)----TableScan: t1 +physical_plan_error This feature is not implemented: Physical plan does not support logical expression ScalarSubquery() + +query TT +explain delete from t1 where a in (select a from t2); +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: t1.a IN () +03)----Subquery: +04)------Projection: t2.a +05)--------TableScan: t2 +06)----TableScan: t1 +physical_plan_error This feature is not implemented: Physical plan does not support logical expression InSubquery(InSubquery { expr: Column(Column { relation: Some(Bare { table: "t1" }), name: "a" }), subquery: , negated: false }) diff --git a/datafusion/sqllogictest/test_files/describe.slt b/datafusion/sqllogictest/test_files/describe.slt index e4cb30628eec5..de5208b5483aa 100644 --- a/datafusion/sqllogictest/test_files/describe.slt +++ b/datafusion/sqllogictest/test_files/describe.slt @@ -86,3 +86,33 @@ string_col Utf8View YES timestamp_col Timestamp(Nanosecond, None) YES year Int32 YES month Int32 YES + +# Test DESC alias functionality +statement ok +CREATE TABLE test_desc_table (id INT, name VARCHAR); + +# Test DESC works the same as DESCRIBE +query TTT +DESC test_desc_table; +---- +id Int32 YES +name Utf8View YES + +query TTT +DESCRIBE test_desc_table; +---- +id Int32 YES +name Utf8View YES + +# Test with qualified table names +statement ok +CREATE TABLE public.test_qualified (col1 INT); + +query TTT +DESC public.test_qualified; +---- +col1 Int32 YES + +# Test error cases +statement error +DESC nonexistent_table; diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index 778b3537d1bff..9e8a39494095f 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -450,3 +450,10 @@ query I select dense_rank() over (order by arrow_cast('abc', 'Dictionary(UInt16, Utf8)')); ---- 1 + +# Test dictionary encoded column to partition column casting +statement ok +CREATE TABLE test0 AS VALUES ('foo',1), ('bar',2), ('foo',3); + +statement ok +COPY (SELECT arrow_cast(column1, 'Dictionary(Int32, Utf8)') AS column1, column2 FROM test0) TO 'test_files/scratch/dictionary/part_dict_test' STORED AS PARQUET PARTITIONED BY (column1); diff --git a/datafusion/sqllogictest/test_files/encrypted_parquet.slt b/datafusion/sqllogictest/test_files/encrypted_parquet.slt new file mode 100644 index 0000000000000..d580b7d1ad2b8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/encrypted_parquet.slt @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test parquet encryption and decryption in DataFusion SQL. +# See datafusion/common/src/config.rs for equivalent rust code + +statement count 0 +CREATE EXTERNAL TABLE encrypted_parquet_table +( +double_field double, +float_field float +) +STORED AS PARQUET LOCATION 'test_files/scratch/encrypted_parquet/' OPTIONS ( + -- Configure encryption for reading and writing Parquet files + -- Encryption properties + 'format.crypto.file_encryption.encrypt_footer' 'true', + 'format.crypto.file_encryption.footer_key_as_hex' '30313233343536373839303132333435', -- b"0123456789012345" + 'format.crypto.file_encryption.column_key_as_hex::double_field' '31323334353637383930313233343530', -- b"1234567890123450" + 'format.crypto.file_encryption.column_key_as_hex::float_field' '31323334353637383930313233343531', -- b"1234567890123451" + -- Decryption properties + 'format.crypto.file_decryption.footer_key_as_hex' '30313233343536373839303132333435', -- b"0123456789012345" + 'format.crypto.file_decryption.column_key_as_hex::double_field' '31323334353637383930313233343530', -- b"1234567890123450" + 'format.crypto.file_decryption.column_key_as_hex::float_field' '31323334353637383930313233343531', -- b"1234567890123451" +) + +statement count 0 +CREATE TABLE temp_table ( + double_field double, + float_field float +) + +query I +INSERT INTO temp_table VALUES(-1.0, -1.0) +---- +1 + +query I +INSERT INTO temp_table VALUES(1.0, 2.0) +---- +1 + +query I +INSERT INTO temp_table VALUES(3.0, 4.0) +---- +1 + +query I +INSERT INTO temp_table VALUES(5.0, 6.0) +---- +1 + +query I +INSERT INTO TABLE encrypted_parquet_table(double_field, float_field) SELECT * FROM temp_table +---- +4 + +query RR +SELECT * FROM encrypted_parquet_table +WHERE double_field > 0.0 AND float_field > 0.0 +ORDER BY double_field +---- +1 2 +3 4 +5 6 + +statement count 0 +CREATE EXTERNAL TABLE parquet_table +( +double_field double, +float_field float +) +STORED AS PARQUET LOCATION 'test_files/scratch/encrypted_parquet/' + +query error DataFusion error: Parquet error: Parquet error: Parquet file has an encrypted footer but decryption properties were not provided +SELECT * FROM parquet_table diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index dc7a53adf889d..3e60423df8a0e 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -120,7 +120,7 @@ from aggregate_test_100 order by c9 # WindowFunction wrong signature -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'nth_value' function: coercion from \[Int32, Int64, Int64\] to the signature OneOf\(\[Any\(0\), Any\(1\), Any\(2\)\]\) failed +statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'nth_value' function: coercion from Int32, Int64, Int64 to the signature OneOf\(\[Any\(0\), Any\(1\), Any\(2\)\]\) failed select c9, nth_value(c5, 2, 3) over (order by c9) as nv1 @@ -148,7 +148,7 @@ SELECT query error DataFusion error: Arrow error: Cast error: Cannot cast string 'foo' to value of Int64 type create table foo as values (1), ('foo'); -query error user-defined coercion failed +query error DataFusion error: Error during planning: Substring without for/from is not valid select 1 group by substr(''); # Error in filter should be reported @@ -168,8 +168,9 @@ CREATE TABLE tab0(col0 INTEGER, col1 INTEGER, col2 INTEGER); statement ok INSERT INTO tab0 VALUES(83,0,38); -query error DataFusion error: Arrow error: Divide by zero error +query I SELECT DISTINCT - 84 FROM tab0 AS cor0 WHERE NOT + 96 / + col1 <= NULL GROUP BY col1, col0; +---- statement ok create table a(timestamp int, birthday int, ts int, tokens int, amp int, staamp int); diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index deff793e51106..a3b6d40aea2d1 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -183,6 +183,7 @@ logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE logical_plan after eliminate_join SAME TEXT AS ABOVE logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE +logical_plan after decorrelate_lateral_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE @@ -204,6 +205,7 @@ logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE logical_plan after eliminate_join SAME TEXT AS ABOVE logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE +logical_plan after decorrelate_lateral_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE @@ -224,21 +226,26 @@ initial_physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/dat initial_physical_plan_with_stats DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] initial_physical_plan_with_schema DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true, schema=[a:Int32;N, b:Int32;N, c:Int32;N] physical_plan after OutputRequirements -01)OutputRequirementExec +01)OutputRequirementExec: order_by=[], dist_by=Unspecified 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true physical_plan after aggregate_statistics SAME TEXT AS ABOVE physical_plan after join_selection SAME TEXT AS ABOVE physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after FilterPushdown SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after coalesce_async_exec_input SAME TEXT AS ABOVE physical_plan after OutputRequirements DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after LimitPushPastWindows SAME TEXT AS ABOVE physical_plan after LimitPushdown SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after EnsureCooperative SAME TEXT AS ABOVE +physical_plan after FilterPushdown(Post) SAME TEXT AS ABOVE physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true physical_plan_with_stats DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] @@ -297,24 +304,29 @@ initial_physical_plan_with_schema 01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] physical_plan after OutputRequirements -01)OutputRequirementExec, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +01)OutputRequirementExec: order_by=[], dist_by=Unspecified, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] physical_plan after aggregate_statistics SAME TEXT AS ABOVE physical_plan after join_selection SAME TEXT AS ABOVE physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after FilterPushdown SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after coalesce_async_exec_input SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after LimitPushPastWindows SAME TEXT AS ABOVE physical_plan after LimitPushdown DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after EnsureCooperative SAME TEXT AS ABOVE +physical_plan after FilterPushdown(Post) SAME TEXT AS ABOVE physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] physical_plan_with_schema DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] @@ -337,24 +349,29 @@ initial_physical_plan_with_schema 01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] physical_plan after OutputRequirements -01)OutputRequirementExec +01)OutputRequirementExec: order_by=[], dist_by=Unspecified 02)--GlobalLimitExec: skip=0, fetch=10 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet physical_plan after aggregate_statistics SAME TEXT AS ABOVE physical_plan after join_selection SAME TEXT AS ABOVE physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after FilterPushdown SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after coalesce_async_exec_input SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after LimitPushPastWindows SAME TEXT AS ABOVE physical_plan after LimitPushdown DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after EnsureCooperative SAME TEXT AS ABOVE +physical_plan after FilterPushdown(Post) SAME TEXT AS ABOVE physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet physical_plan_with_stats DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] @@ -374,7 +391,7 @@ explain select make_array(make_array(1, 2, 3), make_array(4, 5, 6)); ---- logical_plan 01)Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6))) -02)--EmptyRelation +02)--EmptyRelation: rows=1 physical_plan 01)ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] 02)--PlaceholderRowExec @@ -384,7 +401,7 @@ explain select [[1, 2, 3], [4, 5, 6]]; ---- logical_plan 01)Projection: List([[1, 2, 3], [4, 5, 6]]) AS make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6))) -02)--EmptyRelation +02)--EmptyRelation: rows=1 physical_plan 01)ProjectionExec: expr=[[[1, 2, 3], [4, 5, 6]] as make_array(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(4),Int64(5),Int64(6)))] 02)--PlaceholderRowExec @@ -396,7 +413,7 @@ explain select struct(1, 2.3, 'abc'); ---- logical_plan 01)Projection: Struct({c0:1,c1:2.3,c2:abc}) AS struct(Int64(1),Float64(2.3),Utf8("abc")) -02)--EmptyRelation +02)--EmptyRelation: rows=1 physical_plan 01)ProjectionExec: expr=[{c0:1,c1:2.3,c2:abc} as struct(Int64(1),Float64(2.3),Utf8("abc"))] 02)--PlaceholderRowExec @@ -415,14 +432,11 @@ logical_plan 01)LeftSemi Join: 02)--TableScan: t1 projection=[a] 03)--SubqueryAlias: __correlated_sq_1 -04)----Projection: -05)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -06)--------TableScan: t2 projection=[] +04)----EmptyRelation: rows=1 physical_plan 01)NestedLoopJoinExec: join_type=LeftSemi 02)--DataSourceExec: partitions=1, partition_sizes=[0] -03)--ProjectionExec: expr=[] -04)----PlaceholderRowExec +03)--PlaceholderRowExec statement ok drop table t1; @@ -509,11 +523,107 @@ explain format 123 select * from values (1); query error DataFusion error: Error during planning: EXPLAIN VERBOSE with FORMAT is not supported explain verbose format tree select * from values (1); +# valid explain format +query error DataFusion error: Invalid or Unsupported Configuration: Invalid explain format. Expected 'indent', 'tree', 'pgjson' or 'graphviz'. Got 'xxx' +set datafusion.explain.format = "xxx"; + +# verbose uses indent mode even when a different mode (e.g tree) is set + +statement ok +set datafusion.explain.format = "tree"; + +query TT +EXPLAIN VERBOSE SELECT a, b, c FROM simple_explain_test +---- +initial_logical_plan +01)Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c +02)--TableScan: simple_explain_test +logical_plan after resolve_grouping_function SAME TEXT AS ABOVE +logical_plan after type_coercion SAME TEXT AS ABOVE +analyzed_logical_plan SAME TEXT AS ABOVE +logical_plan after eliminate_nested_union SAME TEXT AS ABOVE +logical_plan after simplify_expressions SAME TEXT AS ABOVE +logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE +logical_plan after eliminate_join SAME TEXT AS ABOVE +logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE +logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE +logical_plan after decorrelate_lateral_join SAME TEXT AS ABOVE +logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE +logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE +logical_plan after eliminate_filter SAME TEXT AS ABOVE +logical_plan after eliminate_cross_join SAME TEXT AS ABOVE +logical_plan after eliminate_limit SAME TEXT AS ABOVE +logical_plan after propagate_empty_relation SAME TEXT AS ABOVE +logical_plan after eliminate_one_union SAME TEXT AS ABOVE +logical_plan after filter_null_join_keys SAME TEXT AS ABOVE +logical_plan after eliminate_outer_join SAME TEXT AS ABOVE +logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after push_down_filter SAME TEXT AS ABOVE +logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE +logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE +logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] +logical_plan after eliminate_nested_union SAME TEXT AS ABOVE +logical_plan after simplify_expressions SAME TEXT AS ABOVE +logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE +logical_plan after eliminate_join SAME TEXT AS ABOVE +logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE +logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE +logical_plan after decorrelate_lateral_join SAME TEXT AS ABOVE +logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE +logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE +logical_plan after eliminate_filter SAME TEXT AS ABOVE +logical_plan after eliminate_cross_join SAME TEXT AS ABOVE +logical_plan after eliminate_limit SAME TEXT AS ABOVE +logical_plan after propagate_empty_relation SAME TEXT AS ABOVE +logical_plan after eliminate_one_union SAME TEXT AS ABOVE +logical_plan after filter_null_join_keys SAME TEXT AS ABOVE +logical_plan after eliminate_outer_join SAME TEXT AS ABOVE +logical_plan after push_down_limit SAME TEXT AS ABOVE +logical_plan after push_down_filter SAME TEXT AS ABOVE +logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE +logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE +logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after optimize_projections SAME TEXT AS ABOVE +logical_plan TableScan: simple_explain_test projection=[a, b, c] +initial_physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true +initial_physical_plan_with_stats DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] +initial_physical_plan_with_schema DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true, schema=[a:Int32;N, b:Int32;N, c:Int32;N] +physical_plan after OutputRequirements +01)OutputRequirementExec: order_by=[], dist_by=Unspecified +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true +physical_plan after aggregate_statistics SAME TEXT AS ABOVE +physical_plan after join_selection SAME TEXT AS ABOVE +physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after FilterPushdown SAME TEXT AS ABOVE +physical_plan after EnforceDistribution SAME TEXT AS ABOVE +physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE +physical_plan after EnforceSorting SAME TEXT AS ABOVE +physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE +physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after coalesce_async_exec_input SAME TEXT AS ABOVE +physical_plan after OutputRequirements DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true +physical_plan after LimitAggregation SAME TEXT AS ABOVE +physical_plan after LimitPushPastWindows SAME TEXT AS ABOVE +physical_plan after LimitPushdown SAME TEXT AS ABOVE +physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after EnsureCooperative SAME TEXT AS ABOVE +physical_plan after FilterPushdown(Post) SAME TEXT AS ABOVE +physical_plan after SanityCheckPlan SAME TEXT AS ABOVE +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true +physical_plan_with_stats DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] +physical_plan_with_schema DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true, schema=[a:Int32;N, b:Int32;N, c:Int32;N] + +# Set back to original default value +statement ok +set datafusion.explain.format = "indent"; + # no such thing as json mode -query error DataFusion error: Error during planning: Invalid explain format\. Expected 'indent', 'tree', 'pgjson' or 'graphviz'\. Got 'json' +query error DataFusion error: Invalid or Unsupported Configuration: Invalid explain format\. Expected 'indent', 'tree', 'pgjson' or 'graphviz'\. Got 'json' explain format json select * from values (1); -query error DataFusion error: Error during planning: Invalid explain format\. Expected 'indent', 'tree', 'pgjson' or 'graphviz'\. Got 'foo' +query error DataFusion error: Invalid or Unsupported Configuration: Invalid explain format\. Expected 'indent', 'tree', 'pgjson' or 'graphviz'\. Got 'foo' explain format foo select * from values (1); # pgjson mode diff --git a/datafusion/sqllogictest/test_files/explain_tree.slt b/datafusion/sqllogictest/test_files/explain_tree.slt index 7a0e322eb8bcd..22f19a0af32e4 100644 --- a/datafusion/sqllogictest/test_files/explain_tree.slt +++ b/datafusion/sqllogictest/test_files/explain_tree.slt @@ -180,8 +180,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ output_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -218,8 +218,8 @@ physical_plan 18)┌─────────────┴─────────────┐ 19)│ RepartitionExec │ 20)│ -------------------- │ -21)│ output_partition_count: │ -22)│ 4 │ +21)│ partition_count(in->out): │ +22)│ 4 -> 4 │ 23)│ │ 24)│ partitioning_scheme: │ 25)│ Hash([string_col@0], 4) │ @@ -236,8 +236,8 @@ physical_plan 36)┌─────────────┴─────────────┐ 37)│ RepartitionExec │ 38)│ -------------------- │ -39)│ output_partition_count: │ -40)│ 1 │ +39)│ partition_count(in->out): │ +40)│ 1 -> 4 │ 41)│ │ 42)│ partitioning_scheme: │ 43)│ RoundRobinBatch(4) │ @@ -280,58 +280,60 @@ physical_plan 06)┌─────────────┴─────────────┐ 07)│ DataSourceExec │ 08)│ -------------------- │ -09)│ bytes: 3120 │ +09)│ bytes: 1040 │ 10)│ format: memory │ 11)│ rows: 2 │ 12)└───────────────────────────┘ # 2 Joins query TT -explain SELECT table1.string_col, table2.date_col FROM table1 JOIN table2 ON table1.int_col = table2.int_col; +EXPLAIN SELECT table1.string_col, table2.date_col +FROM table1 +JOIN table2 +ON + (table1.int_col = table2.int_col) + AND (((table1.int_col + table2.int_col) % 2) = 0) ---- physical_plan 01)┌───────────────────────────┐ -02)│ CoalesceBatchesExec │ +02)│ ProjectionExec │ 03)│ -------------------- │ -04)│ target_batch_size: │ -05)│ 8192 │ -06)└─────────────┬─────────────┘ -07)┌─────────────┴─────────────┐ -08)│ HashJoinExec │ -09)│ -------------------- │ -10)│ on: ├──────────────┐ -11)│ (int_col = int_col) │ │ -12)└─────────────┬─────────────┘ │ -13)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -14)│ CoalesceBatchesExec ││ CoalesceBatchesExec │ -15)│ -------------------- ││ -------------------- │ -16)│ target_batch_size: ││ target_batch_size: │ -17)│ 8192 ││ 8192 │ -18)└─────────────┬─────────────┘└─────────────┬─────────────┘ -19)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -20)│ RepartitionExec ││ RepartitionExec │ -21)│ -------------------- ││ -------------------- │ -22)│ output_partition_count: ││ output_partition_count: │ -23)│ 4 ││ 4 │ -24)│ ││ │ -25)│ partitioning_scheme: ││ partitioning_scheme: │ -26)│ Hash([int_col@0], 4) ││ Hash([int_col@0], 4) │ -27)└─────────────┬─────────────┘└─────────────┬─────────────┘ -28)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -29)│ RepartitionExec ││ RepartitionExec │ -30)│ -------------------- ││ -------------------- │ -31)│ output_partition_count: ││ output_partition_count: │ -32)│ 1 ││ 1 │ -33)│ ││ │ -34)│ partitioning_scheme: ││ partitioning_scheme: │ -35)│ RoundRobinBatch(4) ││ RoundRobinBatch(4) │ -36)└─────────────┬─────────────┘└─────────────┬─────────────┘ -37)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -38)│ DataSourceExec ││ DataSourceExec │ -39)│ -------------------- ││ -------------------- │ -40)│ files: 1 ││ files: 1 │ -41)│ format: csv ││ format: parquet │ -42)└───────────────────────────┘└───────────────────────────┘ +04)│ date_col: date_col │ +05)│ │ +06)│ string_col: │ +07)│ string_col │ +08)└─────────────┬─────────────┘ +09)┌─────────────┴─────────────┐ +10)│ CoalesceBatchesExec │ +11)│ -------------------- │ +12)│ target_batch_size: │ +13)│ 8192 │ +14)└─────────────┬─────────────┘ +15)┌─────────────┴─────────────┐ +16)│ HashJoinExec │ +17)│ -------------------- │ +18)│ filter: │ +19)│ CAST(int_col + int_col AS │ +20)│ Int64) % 2 = 0 ├──────────────┐ +21)│ │ │ +22)│ on: │ │ +23)│ (int_col = int_col) │ │ +24)└─────────────┬─────────────┘ │ +25)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +26)│ DataSourceExec ││ RepartitionExec │ +27)│ -------------------- ││ -------------------- │ +28)│ files: 1 ││ partition_count(in->out): │ +29)│ format: parquet ││ 1 -> 4 │ +30)│ ││ │ +31)│ ││ partitioning_scheme: │ +32)│ ││ RoundRobinBatch(4) │ +33)└───────────────────────────┘└─────────────┬─────────────┘ +34)-----------------------------┌─────────────┴─────────────┐ +35)-----------------------------│ DataSourceExec │ +36)-----------------------------│ -------------------- │ +37)-----------------------------│ files: 1 │ +38)-----------------------------│ format: csv │ +39)-----------------------------└───────────────────────────┘ # 3 Joins query TT @@ -365,48 +367,41 @@ physical_plan 19)│ (int_col = int_col) │ │ 20)└─────────────┬─────────────┘ │ 21)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -22)│ DataSourceExec ││ CoalesceBatchesExec │ +22)│ DataSourceExec ││ ProjectionExec │ 23)│ -------------------- ││ -------------------- │ -24)│ bytes: 1560 ││ target_batch_size: │ -25)│ format: memory ││ 8192 │ +24)│ bytes: 520 ││ date_col: date_col │ +25)│ format: memory ││ int_col: int_col │ 26)│ rows: 1 ││ │ -27)└───────────────────────────┘└─────────────┬─────────────┘ -28)-----------------------------┌─────────────┴─────────────┐ -29)-----------------------------│ HashJoinExec │ -30)-----------------------------│ -------------------- │ -31)-----------------------------│ on: ├──────────────┐ -32)-----------------------------│ (int_col = int_col) │ │ -33)-----------------------------└─────────────┬─────────────┘ │ -34)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -35)-----------------------------│ CoalesceBatchesExec ││ CoalesceBatchesExec │ -36)-----------------------------│ -------------------- ││ -------------------- │ -37)-----------------------------│ target_batch_size: ││ target_batch_size: │ -38)-----------------------------│ 8192 ││ 8192 │ -39)-----------------------------└─────────────┬─────────────┘└─────────────┬─────────────┘ -40)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -41)-----------------------------│ RepartitionExec ││ RepartitionExec │ -42)-----------------------------│ -------------------- ││ -------------------- │ -43)-----------------------------│ output_partition_count: ││ output_partition_count: │ -44)-----------------------------│ 4 ││ 4 │ -45)-----------------------------│ ││ │ -46)-----------------------------│ partitioning_scheme: ││ partitioning_scheme: │ -47)-----------------------------│ Hash([int_col@0], 4) ││ Hash([int_col@0], 4) │ -48)-----------------------------└─────────────┬─────────────┘└─────────────┬─────────────┘ -49)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -50)-----------------------------│ RepartitionExec ││ RepartitionExec │ -51)-----------------------------│ -------------------- ││ -------------------- │ -52)-----------------------------│ output_partition_count: ││ output_partition_count: │ -53)-----------------------------│ 1 ││ 1 │ -54)-----------------------------│ ││ │ -55)-----------------------------│ partitioning_scheme: ││ partitioning_scheme: │ -56)-----------------------------│ RoundRobinBatch(4) ││ RoundRobinBatch(4) │ -57)-----------------------------└─────────────┬─────────────┘└─────────────┬─────────────┘ -58)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -59)-----------------------------│ DataSourceExec ││ DataSourceExec │ -60)-----------------------------│ -------------------- ││ -------------------- │ -61)-----------------------------│ files: 1 ││ files: 1 │ -62)-----------------------------│ format: csv ││ format: parquet │ -63)-----------------------------└───────────────────────────┘└───────────────────────────┘ +27)│ ││ string_col: │ +28)│ ││ string_col │ +29)└───────────────────────────┘└─────────────┬─────────────┘ +30)-----------------------------┌─────────────┴─────────────┐ +31)-----------------------------│ CoalesceBatchesExec │ +32)-----------------------------│ -------------------- │ +33)-----------------------------│ target_batch_size: │ +34)-----------------------------│ 8192 │ +35)-----------------------------└─────────────┬─────────────┘ +36)-----------------------------┌─────────────┴─────────────┐ +37)-----------------------------│ HashJoinExec │ +38)-----------------------------│ -------------------- │ +39)-----------------------------│ on: ├──────────────┐ +40)-----------------------------│ (int_col = int_col) │ │ +41)-----------------------------└─────────────┬─────────────┘ │ +42)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +43)-----------------------------│ DataSourceExec ││ RepartitionExec │ +44)-----------------------------│ -------------------- ││ -------------------- │ +45)-----------------------------│ files: 1 ││ partition_count(in->out): │ +46)-----------------------------│ format: parquet ││ 1 -> 4 │ +47)-----------------------------│ ││ │ +48)-----------------------------│ predicate: ││ partitioning_scheme: │ +49)-----------------------------│ DynamicFilter [ empty ] ││ RoundRobinBatch(4) │ +50)-----------------------------└───────────────────────────┘└─────────────┬─────────────┘ +51)----------------------------------------------------------┌─────────────┴─────────────┐ +52)----------------------------------------------------------│ DataSourceExec │ +53)----------------------------------------------------------│ -------------------- │ +54)----------------------------------------------------------│ files: 1 │ +55)----------------------------------------------------------│ format: csv │ +56)----------------------------------------------------------└───────────────────────────┘ # Long Filter (demonstrate what happens with wrapping) query TT @@ -434,8 +429,8 @@ physical_plan 17)┌─────────────┴─────────────┐ 18)│ RepartitionExec │ 19)│ -------------------- │ -20)│ output_partition_count: │ -21)│ 1 │ +20)│ partition_count(in->out): │ +21)│ 1 -> 4 │ 22)│ │ 23)│ partitioning_scheme: │ 24)│ RoundRobinBatch(4) │ @@ -496,8 +491,8 @@ physical_plan 41)┌─────────────┴─────────────┐ 42)│ RepartitionExec │ 43)│ -------------------- │ -44)│ output_partition_count: │ -45)│ 1 │ +44)│ partition_count(in->out): │ +45)│ 1 -> 4 │ 46)│ │ 47)│ partitioning_scheme: │ 48)│ RoundRobinBatch(4) │ @@ -530,8 +525,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ output_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -566,8 +561,8 @@ physical_plan 15)┌─────────────┴─────────────┐ 16)│ RepartitionExec │ 17)│ -------------------- │ -18)│ output_partition_count: │ -19)│ 1 │ +18)│ partition_count(in->out): │ +19)│ 1 -> 4 │ 20)│ │ 21)│ partitioning_scheme: │ 22)│ RoundRobinBatch(4) │ @@ -599,8 +594,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ output_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -633,8 +628,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ output_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -669,7 +664,7 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ DataSourceExec │ 15)│ -------------------- │ -16)│ bytes: 1560 │ +16)│ bytes: 520 │ 17)│ format: memory │ 18)│ rows: 1 │ 19)└───────────────────────────┘ @@ -694,8 +689,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ output_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -727,8 +722,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ output_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -889,7 +884,7 @@ explain SELECT * FROM table1 ORDER BY string_col LIMIT 1; ---- physical_plan 01)┌───────────────────────────┐ -02)│ SortExec │ +02)│ SortExec(TopK) │ 03)│ -------------------- │ 04)│ limit: 1 │ 05)│ │ @@ -922,8 +917,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ output_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -1029,21 +1024,11 @@ physical_plan 11)│ bigint_col │ 12)└─────────────┬─────────────┘ 13)┌─────────────┴─────────────┐ -14)│ RepartitionExec │ +14)│ DataSourceExec │ 15)│ -------------------- │ -16)│ output_partition_count: │ -17)│ 1 │ -18)│ │ -19)│ partitioning_scheme: │ -20)│ RoundRobinBatch(4) │ -21)└─────────────┬─────────────┘ -22)┌─────────────┴─────────────┐ -23)│ DataSourceExec │ -24)│ -------------------- │ -25)│ files: 1 │ -26)│ format: parquet │ -27)└───────────────────────────┘ - +16)│ files: 1 │ +17)│ format: parquet │ +18)└───────────────────────────┘ # Query with projection on memory query TT @@ -1065,7 +1050,7 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ DataSourceExec │ 15)│ -------------------- │ -16)│ bytes: 1560 │ +16)│ bytes: 520 │ 17)│ format: memory │ 18)│ rows: 1 │ 19)└───────────────────────────┘ @@ -1089,8 +1074,8 @@ physical_plan 12)┌─────────────┴─────────────┐ 13)│ RepartitionExec │ 14)│ -------------------- │ -15)│ output_partition_count: │ -16)│ 1 │ +15)│ partition_count(in->out): │ +16)│ 1 -> 4 │ 17)│ │ 18)│ partitioning_scheme: │ 19)│ RoundRobinBatch(4) │ @@ -1123,8 +1108,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ output_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -1186,69 +1171,46 @@ explain select * from table1 inner join table2 on table1.int_col = table2.int_co ---- physical_plan 01)┌───────────────────────────┐ -02)│ CoalesceBatchesExec │ +02)│ ProjectionExec │ 03)│ -------------------- │ -04)│ target_batch_size: │ -05)│ 8192 │ -06)└─────────────┬─────────────┘ -07)┌─────────────┴─────────────┐ -08)│ HashJoinExec │ -09)│ -------------------- │ -10)│ on: │ -11)│ (int_col = int_col), (CAST├──────────────┐ -12)│ (table1.string_col AS │ │ -13)│ Utf8View) = │ │ -14)│ string_col) │ │ -15)└─────────────┬─────────────┘ │ -16)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -17)│ CoalesceBatchesExec ││ CoalesceBatchesExec │ -18)│ -------------------- ││ -------------------- │ -19)│ target_batch_size: ││ target_batch_size: │ -20)│ 8192 ││ 8192 │ -21)└─────────────┬─────────────┘└─────────────┬─────────────┘ -22)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -23)│ RepartitionExec ││ RepartitionExec │ -24)│ -------------------- ││ -------------------- │ -25)│ output_partition_count: ││ output_partition_count: │ -26)│ 4 ││ 4 │ -27)│ ││ │ -28)│ partitioning_scheme: ││ partitioning_scheme: │ -29)│ Hash([int_col@0, CAST ││ Hash([int_col@0, │ -30)│ (table1.string_col ││ string_col@1], │ -31)│ AS Utf8View)@4], 4) ││ 4) │ -32)└─────────────┬─────────────┘└─────────────┬─────────────┘ -33)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -34)│ ProjectionExec ││ RepartitionExec │ -35)│ -------------------- ││ -------------------- │ -36)│ CAST(table1.string_col AS ││ output_partition_count: │ -37)│ Utf8View): ││ 1 │ -38)│ CAST(string_col AS ││ │ -39)│ Utf8View) ││ partitioning_scheme: │ -40)│ ││ RoundRobinBatch(4) │ -41)│ bigint_col: ││ │ -42)│ bigint_col ││ │ -43)│ ││ │ -44)│ date_col: date_col ││ │ -45)│ int_col: int_col ││ │ -46)│ ││ │ -47)│ string_col: ││ │ -48)│ string_col ││ │ -49)└─────────────┬─────────────┘└─────────────┬─────────────┘ -50)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -51)│ RepartitionExec ││ DataSourceExec │ -52)│ -------------------- ││ -------------------- │ -53)│ output_partition_count: ││ files: 1 │ -54)│ 1 ││ format: parquet │ -55)│ ││ │ -56)│ partitioning_scheme: ││ │ -57)│ RoundRobinBatch(4) ││ │ -58)└─────────────┬─────────────┘└───────────────────────────┘ -59)┌─────────────┴─────────────┐ -60)│ DataSourceExec │ -61)│ -------------------- │ -62)│ files: 1 │ -63)│ format: csv │ -64)└───────────────────────────┘ +04)│ bigint_col: │ +05)│ bigint_col │ +06)│ │ +07)│ date_col: date_col │ +08)│ int_col: int_col │ +09)│ │ +10)│ string_col: │ +11)│ string_col │ +12)└─────────────┬─────────────┘ +13)┌─────────────┴─────────────┐ +14)│ CoalesceBatchesExec │ +15)│ -------------------- │ +16)│ target_batch_size: │ +17)│ 8192 │ +18)└─────────────┬─────────────┘ +19)┌─────────────┴─────────────┐ +20)│ HashJoinExec │ +21)│ -------------------- │ +22)│ on: │ +23)│ (int_col = int_col), ├──────────────┐ +24)│ (string_col = │ │ +25)│ string_col) │ │ +26)└─────────────┬─────────────┘ │ +27)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +28)│ DataSourceExec ││ RepartitionExec │ +29)│ -------------------- ││ -------------------- │ +30)│ files: 1 ││ partition_count(in->out): │ +31)│ format: parquet ││ 1 -> 4 │ +32)│ ││ │ +33)│ ││ partitioning_scheme: │ +34)│ ││ RoundRobinBatch(4) │ +35)└───────────────────────────┘└─────────────┬─────────────┘ +36)-----------------------------┌─────────────┴─────────────┐ +37)-----------------------------│ DataSourceExec │ +38)-----------------------------│ -------------------- │ +39)-----------------------------│ files: 1 │ +40)-----------------------------│ format: csv │ +41)-----------------------------└───────────────────────────┘ # Query with outer hash join. query TT @@ -1256,71 +1218,48 @@ explain select * from table1 left outer join table2 on table1.int_col = table2.i ---- physical_plan 01)┌───────────────────────────┐ -02)│ CoalesceBatchesExec │ +02)│ ProjectionExec │ 03)│ -------------------- │ -04)│ target_batch_size: │ -05)│ 8192 │ -06)└─────────────┬─────────────┘ -07)┌─────────────┴─────────────┐ -08)│ HashJoinExec │ -09)│ -------------------- │ -10)│ join_type: Left │ -11)│ │ -12)│ on: ├──────────────┐ -13)│ (int_col = int_col), (CAST│ │ -14)│ (table1.string_col AS │ │ -15)│ Utf8View) = │ │ -16)│ string_col) │ │ -17)└─────────────┬─────────────┘ │ -18)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -19)│ CoalesceBatchesExec ││ CoalesceBatchesExec │ -20)│ -------------------- ││ -------------------- │ -21)│ target_batch_size: ││ target_batch_size: │ -22)│ 8192 ││ 8192 │ -23)└─────────────┬─────────────┘└─────────────┬─────────────┘ -24)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -25)│ RepartitionExec ││ RepartitionExec │ -26)│ -------------------- ││ -------------------- │ -27)│ output_partition_count: ││ output_partition_count: │ -28)│ 4 ││ 4 │ -29)│ ││ │ -30)│ partitioning_scheme: ││ partitioning_scheme: │ -31)│ Hash([int_col@0, CAST ││ Hash([int_col@0, │ -32)│ (table1.string_col ││ string_col@1], │ -33)│ AS Utf8View)@4], 4) ││ 4) │ -34)└─────────────┬─────────────┘└─────────────┬─────────────┘ -35)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -36)│ ProjectionExec ││ RepartitionExec │ -37)│ -------------------- ││ -------------------- │ -38)│ CAST(table1.string_col AS ││ output_partition_count: │ -39)│ Utf8View): ││ 1 │ -40)│ CAST(string_col AS ││ │ -41)│ Utf8View) ││ partitioning_scheme: │ -42)│ ││ RoundRobinBatch(4) │ -43)│ bigint_col: ││ │ -44)│ bigint_col ││ │ -45)│ ││ │ -46)│ date_col: date_col ││ │ -47)│ int_col: int_col ││ │ -48)│ ││ │ -49)│ string_col: ││ │ -50)│ string_col ││ │ -51)└─────────────┬─────────────┘└─────────────┬─────────────┘ -52)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -53)│ RepartitionExec ││ DataSourceExec │ -54)│ -------------------- ││ -------------------- │ -55)│ output_partition_count: ││ files: 1 │ -56)│ 1 ││ format: parquet │ -57)│ ││ │ -58)│ partitioning_scheme: ││ │ -59)│ RoundRobinBatch(4) ││ │ -60)└─────────────┬─────────────┘└───────────────────────────┘ -61)┌─────────────┴─────────────┐ -62)│ DataSourceExec │ -63)│ -------------------- │ -64)│ files: 1 │ -65)│ format: csv │ -66)└───────────────────────────┘ +04)│ bigint_col: │ +05)│ bigint_col │ +06)│ │ +07)│ date_col: date_col │ +08)│ int_col: int_col │ +09)│ │ +10)│ string_col: │ +11)│ string_col │ +12)└─────────────┬─────────────┘ +13)┌─────────────┴─────────────┐ +14)│ CoalesceBatchesExec │ +15)│ -------------------- │ +16)│ target_batch_size: │ +17)│ 8192 │ +18)└─────────────┬─────────────┘ +19)┌─────────────┴─────────────┐ +20)│ HashJoinExec │ +21)│ -------------------- │ +22)│ join_type: Right │ +23)│ │ +24)│ on: ├──────────────┐ +25)│ (int_col = int_col), │ │ +26)│ (string_col = │ │ +27)│ string_col) │ │ +28)└─────────────┬─────────────┘ │ +29)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +30)│ DataSourceExec ││ RepartitionExec │ +31)│ -------------------- ││ -------------------- │ +32)│ files: 1 ││ partition_count(in->out): │ +33)│ format: parquet ││ 1 -> 4 │ +34)│ ││ │ +35)│ ││ partitioning_scheme: │ +36)│ ││ RoundRobinBatch(4) │ +37)└───────────────────────────┘└─────────────┬─────────────┘ +38)-----------------------------┌─────────────┴─────────────┐ +39)-----------------------------│ DataSourceExec │ +40)-----------------------------│ -------------------- │ +41)-----------------------------│ files: 1 │ +42)-----------------------------│ format: csv │ +43)-----------------------------└───────────────────────────┘ # Query with nested loop join. query TT @@ -1333,41 +1272,11 @@ physical_plan 04)│ join_type: LeftSemi │ │ 05)└─────────────┬─────────────┘ │ 06)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -07)│ DataSourceExec ││ ProjectionExec │ +07)│ DataSourceExec ││ PlaceholderRowExec │ 08)│ -------------------- ││ │ 09)│ files: 1 ││ │ 10)│ format: csv ││ │ -11)└───────────────────────────┘└─────────────┬─────────────┘ -12)-----------------------------┌─────────────┴─────────────┐ -13)-----------------------------│ AggregateExec │ -14)-----------------------------│ -------------------- │ -15)-----------------------------│ aggr: count(1) │ -16)-----------------------------│ mode: Final │ -17)-----------------------------└─────────────┬─────────────┘ -18)-----------------------------┌─────────────┴─────────────┐ -19)-----------------------------│ CoalescePartitionsExec │ -20)-----------------------------└─────────────┬─────────────┘ -21)-----------------------------┌─────────────┴─────────────┐ -22)-----------------------------│ AggregateExec │ -23)-----------------------------│ -------------------- │ -24)-----------------------------│ aggr: count(1) │ -25)-----------------------------│ mode: Partial │ -26)-----------------------------└─────────────┬─────────────┘ -27)-----------------------------┌─────────────┴─────────────┐ -28)-----------------------------│ RepartitionExec │ -29)-----------------------------│ -------------------- │ -30)-----------------------------│ output_partition_count: │ -31)-----------------------------│ 1 │ -32)-----------------------------│ │ -33)-----------------------------│ partitioning_scheme: │ -34)-----------------------------│ RoundRobinBatch(4) │ -35)-----------------------------└─────────────┬─────────────┘ -36)-----------------------------┌─────────────┴─────────────┐ -37)-----------------------------│ DataSourceExec │ -38)-----------------------------│ -------------------- │ -39)-----------------------------│ files: 1 │ -40)-----------------------------│ format: parquet │ -41)-----------------------------└───────────────────────────┘ +11)└───────────────────────────┘└───────────────────────────┘ # Query with cross join. query TT @@ -1378,21 +1287,11 @@ physical_plan 02)│ CrossJoinExec ├──────────────┐ 03)└─────────────┬─────────────┘ │ 04)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -05)│ DataSourceExec ││ RepartitionExec │ +05)│ DataSourceExec ││ DataSourceExec │ 06)│ -------------------- ││ -------------------- │ -07)│ files: 1 ││ output_partition_count: │ -08)│ format: csv ││ 1 │ -09)│ ││ │ -10)│ ││ partitioning_scheme: │ -11)│ ││ RoundRobinBatch(4) │ -12)└───────────────────────────┘└─────────────┬─────────────┘ -13)-----------------------------┌─────────────┴─────────────┐ -14)-----------------------------│ DataSourceExec │ -15)-----------------------------│ -------------------- │ -16)-----------------------------│ files: 1 │ -17)-----------------------------│ format: parquet │ -18)-----------------------------└───────────────────────────┘ - +07)│ files: 1 ││ files: 1 │ +08)│ format: csv ││ format: parquet │ +09)└───────────────────────────┘└───────────────────────────┘ # Query with sort merge join. statement ok @@ -1415,7 +1314,7 @@ physical_plan 11)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 12)│ DataSourceExec ││ DataSourceExec │ 13)│ -------------------- ││ -------------------- │ -14)│ bytes: 6040 ││ bytes: 6040 │ +14)│ bytes: 5932 ││ bytes: 5932 │ 15)│ format: memory ││ format: memory │ 16)│ rows: 1 ││ rows: 1 │ 17)└───────────────────────────┘└───────────────────────────┘ @@ -1505,8 +1404,8 @@ physical_plan 33)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 34)│ RepartitionExec ││ RepartitionExec │ 35)│ -------------------- ││ -------------------- │ -36)│ output_partition_count: ││ output_partition_count: │ -37)│ 4 ││ 4 │ +36)│ partition_count(in->out): ││ partition_count(in->out): │ +37)│ 4 -> 4 ││ 4 -> 4 │ 38)│ ││ │ 39)│ partitioning_scheme: ││ partitioning_scheme: │ 40)│ Hash([name@0], 4) ││ Hash([name@0], 4) │ @@ -1514,8 +1413,8 @@ physical_plan 42)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 43)│ RepartitionExec ││ RepartitionExec │ 44)│ -------------------- ││ -------------------- │ -45)│ output_partition_count: ││ output_partition_count: │ -46)│ 1 ││ 1 │ +45)│ partition_count(in->out): ││ partition_count(in->out): │ +46)│ 1 -> 4 ││ 1 -> 4 │ 47)│ ││ │ 48)│ partitioning_scheme: ││ partitioning_scheme: │ 49)│ RoundRobinBatch(4) ││ RoundRobinBatch(4) │ @@ -1529,7 +1428,7 @@ physical_plan 57)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 58)│ DataSourceExec ││ DataSourceExec │ 59)│ -------------------- ││ -------------------- │ -60)│ bytes: 1320 ││ bytes: 1312 │ +60)│ bytes: 296 ││ bytes: 288 │ 61)│ format: memory ││ format: memory │ 62)│ rows: 1 ││ rows: 1 │ 63)└───────────────────────────┘└───────────────────────────┘ @@ -1548,14 +1447,14 @@ physical_plan 04)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 05)│ DataSourceExec ││ ProjectionExec │ 06)│ -------------------- ││ -------------------- │ -07)│ bytes: 1320 ││ id: CAST(id AS Int32) │ +07)│ bytes: 296 ││ id: CAST(id AS Int32) │ 08)│ format: memory ││ name: name │ 09)│ rows: 1 ││ │ 10)└───────────────────────────┘└─────────────┬─────────────┘ 11)-----------------------------┌─────────────┴─────────────┐ 12)-----------------------------│ DataSourceExec │ 13)-----------------------------│ -------------------- │ -14)-----------------------------│ bytes: 1312 │ +14)-----------------------------│ bytes: 288 │ 15)-----------------------------│ format: memory │ 16)-----------------------------│ rows: 1 │ 17)-----------------------------└───────────────────────────┘ @@ -1606,8 +1505,8 @@ physical_plan 18)┌─────────────┴─────────────┐ 19)│ RepartitionExec │ 20)│ -------------------- │ -21)│ output_partition_count: │ -22)│ 1 │ +21)│ partition_count(in->out): │ +22)│ 1 -> 4 │ 23)│ │ 24)│ partitioning_scheme: │ 25)│ RoundRobinBatch(4) │ @@ -1648,8 +1547,8 @@ physical_plan 19)┌─────────────┴─────────────┐ 20)│ RepartitionExec │ 21)│ -------------------- │ -22)│ output_partition_count: │ -23)│ 1 │ +22)│ partition_count(in->out): │ +23)│ 1 -> 4 │ 24)│ │ 25)│ partitioning_scheme: │ 26)│ RoundRobinBatch(4) │ @@ -1689,8 +1588,8 @@ physical_plan 19)┌─────────────┴─────────────┐ 20)│ RepartitionExec │ 21)│ -------------------- │ -22)│ output_partition_count: │ -23)│ 1 │ +22)│ partition_count(in->out): │ +23)│ 1 -> 4 │ 24)│ │ 25)│ partitioning_scheme: │ 26)│ RoundRobinBatch(4) │ @@ -1728,8 +1627,8 @@ physical_plan 17)┌─────────────┴─────────────┐ 18)│ RepartitionExec │ 19)│ -------------------- │ -20)│ output_partition_count: │ -21)│ 1 │ +20)│ partition_count(in->out): │ +21)│ 1 -> 4 │ 22)│ │ 23)│ partitioning_scheme: │ 24)│ RoundRobinBatch(4) │ @@ -1771,8 +1670,8 @@ physical_plan 20)┌─────────────┴─────────────┐ 21)│ RepartitionExec │ 22)│ -------------------- │ -23)│ output_partition_count: │ -24)│ 1 │ +23)│ partition_count(in->out): │ +24)│ 1 -> 4 │ 25)│ │ 26)│ partitioning_scheme: │ 27)│ RoundRobinBatch(4) │ @@ -1786,6 +1685,54 @@ physical_plan +# query +query TT +explain SELECT * FROM data +WHERE date = '2006-01-02' +ORDER BY "ticker", "time" +LIMIT 5; +---- +physical_plan +01)┌───────────────────────────┐ +02)│ SortPreservingMergeExec │ +03)│ -------------------- │ +04)│ limit: 5 │ +05)│ │ +06)│ ticker ASC NULLS LAST, │ +07)│ time ASC NULLS LAST │ +08)└─────────────┬─────────────┘ +09)┌─────────────┴─────────────┐ +10)│ CoalesceBatchesExec │ +11)│ -------------------- │ +12)│ limit: 5 │ +13)│ │ +14)│ target_batch_size: │ +15)│ 8192 │ +16)└─────────────┬─────────────┘ +17)┌─────────────┴─────────────┐ +18)│ FilterExec │ +19)│ -------------------- │ +20)│ predicate: │ +21)│ date = 2006-01-02 │ +22)└─────────────┬─────────────┘ +23)┌─────────────┴─────────────┐ +24)│ RepartitionExec │ +25)│ -------------------- │ +26)│ partition_count(in->out): │ +27)│ 1 -> 4 │ +28)│ │ +29)│ partitioning_scheme: │ +30)│ RoundRobinBatch(4) │ +31)└─────────────┬─────────────┘ +32)┌─────────────┴─────────────┐ +33)│ StreamingTableExec │ +34)│ -------------------- │ +35)│ infinite: true │ +36)│ limit: None │ +37)└───────────────────────────┘ + + + # query query TT @@ -1815,8 +1762,8 @@ physical_plan 19)┌─────────────┴─────────────┐ 20)│ RepartitionExec │ 21)│ -------------------- │ -22)│ output_partition_count: │ -23)│ 1 │ +22)│ partition_count(in->out): │ +23)│ 1 -> 4 │ 24)│ │ 25)│ partitioning_scheme: │ 26)│ RoundRobinBatch(4) │ @@ -1869,8 +1816,8 @@ physical_plan 25)-----------------------------┌─────────────┴─────────────┐ 26)-----------------------------│ RepartitionExec │ 27)-----------------------------│ -------------------- │ -28)-----------------------------│ output_partition_count: │ -29)-----------------------------│ 1 │ +28)-----------------------------│ partition_count(in->out): │ +29)-----------------------------│ 1 -> 4 │ 30)-----------------------------│ │ 31)-----------------------------│ partitioning_scheme: │ 32)-----------------------------│ RoundRobinBatch(4) │ @@ -1899,7 +1846,7 @@ physical_plan 11)┌─────────────┴─────────────┐ 12)│ DataSourceExec │ 13)│ -------------------- │ -14)│ bytes: 2672 │ +14)│ bytes: 2576 │ 15)│ format: memory │ 16)│ rows: 1 │ 17)└───────────────────────────┘ @@ -1922,7 +1869,7 @@ physical_plan 11)┌─────────────┴─────────────┐ 12)│ DataSourceExec │ 13)│ -------------------- │ -14)│ bytes: 2672 │ +14)│ bytes: 2576 │ 15)│ format: memory │ 16)│ rows: 1 │ 17)└───────────────────────────┘ @@ -1945,7 +1892,7 @@ physical_plan 11)┌─────────────┴─────────────┐ 12)│ DataSourceExec │ 13)│ -------------------- │ -14)│ bytes: 2672 │ +14)│ bytes: 2576 │ 15)│ format: memory │ 16)│ rows: 1 │ 17)└───────────────────────────┘ @@ -1983,8 +1930,8 @@ physical_plan 22)┌─────────────┴─────────────┐ 23)│ RepartitionExec │ 24)│ -------------------- │ -25)│ output_partition_count: │ -26)│ 1 │ +25)│ partition_count(in->out): │ +26)│ 1 -> 4 │ 27)│ │ 28)│ partitioning_scheme: │ 29)│ RoundRobinBatch(4) │ @@ -2062,8 +2009,8 @@ physical_plan 19)┌─────────────┴─────────────┐ 20)│ RepartitionExec │ 21)│ -------------------- │ -22)│ output_partition_count: │ -23)│ 1 │ +22)│ partition_count(in->out): │ +23)│ 1 -> 4 │ 24)│ │ 25)│ partitioning_scheme: │ 26)│ RoundRobinBatch(4) │ @@ -2088,3 +2035,234 @@ physical_plan 06)┌─────────────┴─────────────┐ 07)│ PlaceholderRowExec │ 08)└───────────────────────────┘ + + +# Test explain for large plans + +statement ok +CREATE TABLE t (k int) + +# By default, the plan of this large query is cropped +query TT +EXPLAIN SELECT * FROM t t1, t t2, t t3, t t4, t t5, t t6, t t7, t t8, t t9, t t10 +---- +physical_plan +01)┌───────────────────────────┐ +02)│ CrossJoinExec ├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── +03)└─────────────┬─────────────┘ +04)┌─────────────┴─────────────┐ +05)│ CrossJoinExec │ +06)│ │ +07)│ ├─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +08)│ │ │ +09)│ │ │ +10)└─────────────┬─────────────┘ │ +11)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +12)│ CrossJoinExec │ │ DataSourceExec │ +13)│ │ │ -------------------- │ +14)│ ├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ │ bytes: 0 │ +15)│ │ │ │ format: memory │ +16)│ │ │ │ rows: 0 │ +17)└─────────────┬─────────────┘ │ └───────────────────────────┘ +18)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +19)│ CrossJoinExec │ │ DataSourceExec │ +20)│ │ │ -------------------- │ +21)│ ├───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ │ bytes: 0 │ +22)│ │ │ │ format: memory │ +23)│ │ │ │ rows: 0 │ +24)└─────────────┬─────────────┘ │ └───────────────────────────┘ +25)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +26)│ CrossJoinExec │ │ DataSourceExec │ +27)│ │ │ -------------------- │ +28)│ ├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ │ bytes: 0 │ +29)│ │ │ │ format: memory │ +30)│ │ │ │ rows: 0 │ +31)└─────────────┬─────────────┘ │ └───────────────────────────┘ +32)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +33)│ CrossJoinExec │ │ DataSourceExec │ +34)│ │ │ -------------------- │ +35)│ ├─────────────────────────────────────────────────────────────────────────────────────────────────────┐ │ bytes: 0 │ +36)│ │ │ │ format: memory │ +37)│ │ │ │ rows: 0 │ +38)└─────────────┬─────────────┘ │ └───────────────────────────┘ +39)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +40)│ CrossJoinExec │ │ DataSourceExec │ +41)│ │ │ -------------------- │ +42)│ ├────────────────────────────────────────────────────────────────────────┐ │ bytes: 0 │ +43)│ │ │ │ format: memory │ +44)│ │ │ │ rows: 0 │ +45)└─────────────┬─────────────┘ │ └───────────────────────────┘ +46)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +47)│ CrossJoinExec │ │ DataSourceExec │ +48)│ │ │ -------------------- │ +49)│ ├───────────────────────────────────────────┐ │ bytes: 0 │ +50)│ │ │ │ format: memory │ +51)│ │ │ │ rows: 0 │ +52)└─────────────┬─────────────┘ │ └───────────────────────────┘ +53)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +54)│ CrossJoinExec │ │ DataSourceExec │ +55)│ │ │ -------------------- │ +56)│ ├──────────────┐ │ bytes: 0 │ +57)│ │ │ │ format: memory │ +58)│ │ │ │ rows: 0 │ +59)└─────────────┬─────────────┘ │ └───────────────────────────┘ +60)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +61)│ DataSourceExec ││ DataSourceExec │ +62)│ -------------------- ││ -------------------- │ +63)│ bytes: 0 ││ bytes: 0 │ +64)│ format: memory ││ format: memory │ +65)│ rows: 0 ││ rows: 0 │ +66)└───────────────────────────┘└───────────────────────────┘ + +# Setting the tree_maximum_render_size to 0 will allow the entire plan to be rendered +statement ok +SET datafusion.explain.tree_maximum_render_width = 0 + +query TT +EXPLAIN SELECT * FROM t t1, t t2, t t3, t t4, t t5, t t6, t t7, t t8, t t9, t t10 +---- +physical_plan +01)┌───────────────────────────┐ +02)│ CrossJoinExec ├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +03)└─────────────┬─────────────┘ │ +04)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +05)│ CrossJoinExec │ │ DataSourceExec │ +06)│ │ │ -------------------- │ +07)│ ├─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ │ bytes: 0 │ +08)│ │ │ │ format: memory │ +09)│ │ │ │ rows: 0 │ +10)└─────────────┬─────────────┘ │ └───────────────────────────┘ +11)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +12)│ CrossJoinExec │ │ DataSourceExec │ +13)│ │ │ -------------------- │ +14)│ ├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ │ bytes: 0 │ +15)│ │ │ │ format: memory │ +16)│ │ │ │ rows: 0 │ +17)└─────────────┬─────────────┘ │ └───────────────────────────┘ +18)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +19)│ CrossJoinExec │ │ DataSourceExec │ +20)│ │ │ -------------------- │ +21)│ ├───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ │ bytes: 0 │ +22)│ │ │ │ format: memory │ +23)│ │ │ │ rows: 0 │ +24)└─────────────┬─────────────┘ │ └───────────────────────────┘ +25)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +26)│ CrossJoinExec │ │ DataSourceExec │ +27)│ │ │ -------------------- │ +28)│ ├──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ │ bytes: 0 │ +29)│ │ │ │ format: memory │ +30)│ │ │ │ rows: 0 │ +31)└─────────────┬─────────────┘ │ └───────────────────────────┘ +32)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +33)│ CrossJoinExec │ │ DataSourceExec │ +34)│ │ │ -------------------- │ +35)│ ├─────────────────────────────────────────────────────────────────────────────────────────────────────┐ │ bytes: 0 │ +36)│ │ │ │ format: memory │ +37)│ │ │ │ rows: 0 │ +38)└─────────────┬─────────────┘ │ └───────────────────────────┘ +39)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +40)│ CrossJoinExec │ │ DataSourceExec │ +41)│ │ │ -------------------- │ +42)│ ├────────────────────────────────────────────────────────────────────────┐ │ bytes: 0 │ +43)│ │ │ │ format: memory │ +44)│ │ │ │ rows: 0 │ +45)└─────────────┬─────────────┘ │ └───────────────────────────┘ +46)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +47)│ CrossJoinExec │ │ DataSourceExec │ +48)│ │ │ -------------------- │ +49)│ ├───────────────────────────────────────────┐ │ bytes: 0 │ +50)│ │ │ │ format: memory │ +51)│ │ │ │ rows: 0 │ +52)└─────────────┬─────────────┘ │ └───────────────────────────┘ +53)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +54)│ CrossJoinExec │ │ DataSourceExec │ +55)│ │ │ -------------------- │ +56)│ ├──────────────┐ │ bytes: 0 │ +57)│ │ │ │ format: memory │ +58)│ │ │ │ rows: 0 │ +59)└─────────────┬─────────────┘ │ └───────────────────────────┘ +60)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +61)│ DataSourceExec ││ DataSourceExec │ +62)│ -------------------- ││ -------------------- │ +63)│ bytes: 0 ││ bytes: 0 │ +64)│ format: memory ││ format: memory │ +65)│ rows: 0 ││ rows: 0 │ +66)└───────────────────────────┘└───────────────────────────┘ + +# Setting the tree_maximum_render_size to a smaller size +statement ok +SET datafusion.explain.tree_maximum_render_width = 60 + +query TT +EXPLAIN SELECT * FROM t t1, t t2, t t3, t t4, t t5, t t6, t t7, t t8, t t9, t t10 +---- +physical_plan +01)┌───────────────────────────┐ +02)│ CrossJoinExec ├────────────────────────────────────────────────────────── +03)└─────────────┬─────────────┘ +04)┌─────────────┴─────────────┐ +05)│ CrossJoinExec │ +06)│ │ +07)│ ├────────────────────────────────────────────────────────── +08)│ │ +09)│ │ +10)└─────────────┬─────────────┘ +11)┌─────────────┴─────────────┐ +12)│ CrossJoinExec │ +13)│ │ +14)│ ├────────────────────────────────────────────────────────── +15)│ │ +16)│ │ +17)└─────────────┬─────────────┘ +18)┌─────────────┴─────────────┐ +19)│ CrossJoinExec │ +20)│ │ +21)│ ├────────────────────────────────────────────────────────── +22)│ │ +23)│ │ +24)└─────────────┬─────────────┘ +25)┌─────────────┴─────────────┐ +26)│ CrossJoinExec │ +27)│ │ +28)│ ├────────────────────────────────────────────────────────── +29)│ │ +30)│ │ +31)└─────────────┬─────────────┘ +32)┌─────────────┴─────────────┐ +33)│ CrossJoinExec │ +34)│ │ +35)│ ├────────────────────────────────────────────────────────── +36)│ │ +37)│ │ +38)└─────────────┬─────────────┘ +39)┌─────────────┴─────────────┐ +40)│ CrossJoinExec │ +41)│ │ +42)│ ├────────────────────────────────────────────────────────── +43)│ │ +44)│ │ +45)└─────────────┬─────────────┘ +46)┌─────────────┴─────────────┐ +47)│ CrossJoinExec │ +48)│ │ +49)│ ├───────────────────────────────────────────┐ +50)│ │ │ +51)│ │ │ +52)└─────────────┬─────────────┘ │ +53)┌─────────────┴─────────────┐ ┌─────────────┴─────────────┐ +54)│ CrossJoinExec │ │ DataSourceExec │ +55)│ │ │ -------------------- │ +56)│ ├──────────────┐ │ bytes: 0 │ +57)│ │ │ │ format: memory │ +58)│ │ │ │ rows: 0 │ +59)└─────────────┬─────────────┘ │ └───────────────────────────┘ +60)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +61)│ DataSourceExec ││ DataSourceExec │ +62)│ -------------------- ││ -------------------- │ +63)│ bytes: 0 ││ bytes: 0 │ +64)│ format: memory ││ format: memory │ +65)│ rows: 0 ││ rows: 0 │ +66)└───────────────────────────┘└───────────────────────────┘ + +statement ok +DROP TABLE t diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index e4d0b72338569..87345b833e264 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -424,10 +424,12 @@ SELECT chr(CAST(NULL AS int)) ---- NULL -statement error DataFusion error: Execution error: null character not permitted. +query T SELECT chr(CAST(0 AS int)) +---- +\0 -statement error DataFusion error: Execution error: requested character too large for encoding. +statement error DataFusion error: Execution error: invalid Unicode scalar value: 9223372036854775807 SELECT chr(CAST(9223372036854775807 AS bigint)) query T @@ -698,6 +700,11 @@ SELECT to_hex(2147483647) ---- 7fffffff +query T +SELECT to_hex(CAST(2147483647 as BIGINT UNSIGNED)) +---- +7fffffff + query T SELECT to_hex(9223372036854775807) ---- @@ -2072,9 +2079,6 @@ host1 1.1 101 host2 2.2 202 host3 3.3 303 -statement ok -set datafusion.sql_parser.dialect = 'Postgres'; - statement ok create table t (a float) as values (1), (2), (3); @@ -2094,9 +2098,6 @@ physical_plan statement ok drop table t; -statement ok -set datafusion.sql_parser.dialect = 'Generic'; - # test between expression with null query I select 1 where null between null and null; @@ -2126,3 +2127,26 @@ query T select E'foo\t\tbar'; ---- foo bar + +statement ok +create table t (a float) as values (1), (null), (3); + +# https://github.com/apache/datafusion/issues/17055 +# is not null did not correctly infer as boolean in udf argument position +query B +select greatest(a is not null, false) from t; +---- +true +false +true + +# same for is null +query B +select greatest(a is null, false) from t; +---- +false +true +false + +statement ok +drop table t; diff --git a/datafusion/sqllogictest/test_files/expr/date_part.slt b/datafusion/sqllogictest/test_files/expr/date_part.slt index dec796aa59cb5..64f16f72421a0 100644 --- a/datafusion/sqllogictest/test_files/expr/date_part.slt +++ b/datafusion/sqllogictest/test_files/expr/date_part.slt @@ -884,7 +884,7 @@ SELECT extract(day from arrow_cast('14400 minutes', 'Interval(DayTime)')) query I SELECT extract(minute from arrow_cast('14400 minutes', 'Interval(DayTime)')) ---- -14400 +0 query I SELECT extract(second from arrow_cast('5.1 seconds', 'Interval(DayTime)')) @@ -894,7 +894,7 @@ SELECT extract(second from arrow_cast('5.1 seconds', 'Interval(DayTime)')) query I SELECT extract(second from arrow_cast('14400 minutes', 'Interval(DayTime)')) ---- -864000 +0 query I SELECT extract(second from arrow_cast('2 months', 'Interval(MonthDayNano)')) @@ -954,7 +954,7 @@ from t order by id; ---- 0 0 5 -1 0 15 +1 0 3 2 0 0 3 2 0 4 0 8 @@ -1070,3 +1070,23 @@ true query error DataFusion error: This feature is not implemented: Date part Nanosecond not supported SELECT (date_part('nanosecond', now()) = EXTRACT(nanosecond FROM now())) + +query I +SELECT date_part('ISODOW', CAST('2000-01-01' AS DATE)) +---- +5 + +query I +SELECT EXTRACT(isodow FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +1 + +query I +SELECT EXTRACT("isodow" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +1 + +query I +SELECT EXTRACT('isodow' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +1 diff --git a/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt b/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt index d96044fda8c05..a09d8ce26ddfb 100644 --- a/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt +++ b/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt @@ -34,7 +34,7 @@ ORDER BY "date", "time"; ---- logical_plan 01)Sort: data.date ASC NULLS LAST, data.time ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") +02)--Filter: data.ticker = Utf8View("A") 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)SortPreservingMergeExec: [date@0 ASC NULLS LAST, time@2 ASC NULLS LAST] @@ -51,7 +51,7 @@ ORDER BY "time" ---- logical_plan 01)Sort: data.time ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) = data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) = data.date 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)SortPreservingMergeExec: [time@2 ASC NULLS LAST] @@ -68,7 +68,7 @@ ORDER BY "date" ---- logical_plan 01)Sort: data.date ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) = data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) = data.date 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)SortPreservingMergeExec: [date@0 ASC NULLS LAST] @@ -85,7 +85,7 @@ ORDER BY "ticker" ---- logical_plan 01)Sort: data.ticker ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) = data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) = data.date 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)CoalescePartitionsExec @@ -102,7 +102,7 @@ ORDER BY "time", "date"; ---- logical_plan 01)Sort: data.time ASC NULLS LAST, data.date ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) = data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) = data.date 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)SortPreservingMergeExec: [time@2 ASC NULLS LAST, date@0 ASC NULLS LAST] @@ -120,7 +120,7 @@ ORDER BY "time" ---- logical_plan 01)Sort: data.time ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) != data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) != data.date 03)----TableScan: data projection=[date, ticker, time] # no relation between time & date @@ -132,7 +132,7 @@ ORDER BY "time" ---- logical_plan 01)Sort: data.time ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") +02)--Filter: data.ticker = Utf8View("A") 03)----TableScan: data projection=[date, ticker, time] # query diff --git a/datafusion/sqllogictest/test_files/float16.slt b/datafusion/sqllogictest/test_files/float16.slt new file mode 100644 index 0000000000000..5e59c730f0787 --- /dev/null +++ b/datafusion/sqllogictest/test_files/float16.slt @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Basic tests Tests for Float16 Type + +statement ok +create table floats as values (1.0), (2.0), (3.0), (NULL), ('Nan'); + +statement ok +create table float16s as select arrow_cast(column1, 'Float16') as column1 from floats; + +query RT +select column1, arrow_typeof(column1) as type from float16s; +---- +1 Float16 +2 Float16 +3 Float16 +NULL Float16 +NaN Float16 + +# Test coercions with arithmetic + +query RRRRRR +SELECT + column1 + 1::tinyint as column1_plus_int8, + column1 + 1::smallint as column1_plus_int16, + column1 + 1::int as column1_plus_int32, + column1 + 1::bigint as column1_plus_int64, + column1 + 1.0::float as column1_plus_float32, + column1 + 1.0 as column1_plus_float64 +FROM float16s; +---- +2 2 2 2 2 2 +3 3 3 3 3 3 +4 4 4 4 4 4 +NULL NULL NULL NULL NULL NULL +NaN NaN NaN NaN NaN NaN + +# Try coercing with literal NULL +query error +select column1 + NULL from float16s; +---- +DataFusion error: type_coercion +caused by +Error during planning: Cannot automatically convert Null to Float16 + + +# Test coercions with equality +query BBBBBB +SELECT + column1 = 1::tinyint as column1_equals_int8, + column1 = 1::smallint as column1_equals_int16, + column1 = 1::int as column1_equals_int32, + column1 = 1::bigint as column1_equals_int64, + column1 = 1.0::float as column1_equals_float32, + column1 = 1.0 as column1_equals_float64 +FROM float16s; +---- +true true true true true true +false false false false false false +false false false false false false +NULL NULL NULL NULL NULL NULL +false false false false false false + + +# Try coercing with literal NULL +query error +select column1 = NULL from float16s; +---- +DataFusion error: Error during planning: Cannot infer common argument type for comparison operation Float16 = Null + + +# Cleanup +statement ok +drop table floats; + +statement ok +drop table float16s; diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 4c4999a364d12..b72f73d44698f 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -2232,7 +2232,7 @@ physical_plan 03)----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III -SELECT a, b, LAST_VALUE(c) as last_c +SELECT a, b, LAST_VALUE(c order by c) as last_c FROM annotated_data_infinite2 GROUP BY a, b ---- @@ -2506,12 +2506,16 @@ TUR [100.0, 75.0] 175 # test_ordering_sensitive_aggregation3 # When different aggregators have conflicting requirements, we cannot satisfy all of them in current implementation. # test below should raise Plan Error. -statement error DataFusion error: This feature is not implemented: Conflicting ordering requirements in aggregate functions is not supported +query ??? rowsort SELECT ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, ARRAY_AGG(s.amount ORDER BY s.amount ASC) AS amounts2, ARRAY_AGG(s.amount ORDER BY s.sn ASC) AS amounts3 FROM sales_global AS s GROUP BY s.country +---- +[100.0, 75.0] [75.0, 100.0] [75.0, 100.0] +[200.0, 50.0] [50.0, 200.0] [50.0, 200.0] +[80.0, 30.0] [30.0, 80.0] [30.0, 80.0] # test_ordering_sensitive_aggregation4 # If aggregators can work with bounded memory (Sorted or PartiallySorted mode), we should append requirement to @@ -2706,6 +2710,29 @@ select k, first_value(val order by o) respect NULLS from first_null group by k; 1 1 +statement ok +CREATE TABLE last_null ( + k INT, + val INT, + o int + ) as VALUES + (0, NULL, 9), + (0, 1, 1), + (1, 1, 1); + +query II rowsort +select k, last_value(val order by o) IGNORE NULLS from last_null group by k; +---- +0 1 +1 1 + +query II rowsort +select k, last_value(val order by o) respect NULLS from last_null group by k; +---- +0 NULL +1 1 + + query TT EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, @@ -3402,7 +3429,7 @@ physical_plan 06)----------RepartitionExec: partitioning=Hash([sn@0, amount@1], 8), input_partitions=8 07)------------AggregateExec: mode=Partial, gby=[sn@0 as sn, amount@1 as amount], aggr=[] 08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -09)----------------DataSourceExec: partitions=1, partition_sizes=[1] +09)----------------DataSourceExec: partitions=1, partition_sizes=[2] query IRI SELECT s.sn, s.amount, 2*s.sn @@ -3471,9 +3498,9 @@ physical_plan 06)----------RepartitionExec: partitioning=Hash([sn@0, amount@1], 8), input_partitions=8 07)------------AggregateExec: mode=Partial, gby=[sn@1 as sn, amount@2 as amount], aggr=[sum(l.amount)] 08)--------------NestedLoopJoinExec: join_type=Inner, filter=sn@0 >= sn@1, projection=[amount@1, sn@2, amount@3] -09)----------------DataSourceExec: partitions=1, partition_sizes=[1] +09)----------------DataSourceExec: partitions=1, partition_sizes=[2] 10)----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -11)------------------DataSourceExec: partitions=1, partition_sizes=[1] +11)------------------DataSourceExec: partitions=1, partition_sizes=[2] query IRR SELECT r.sn, SUM(l.amount), r.amount @@ -3619,8 +3646,8 @@ physical_plan 07)------------AggregateExec: mode=Partial, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum_amount@6 as sum_amount], aggr=[] 08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 09)----------------ProjectionExec: expr=[zip_code@0 as zip_code, country@1 as country, sn@2 as sn, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@6 as sum_amount] -10)------------------BoundedWindowAggExec: wdw=[sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -11)--------------------DataSourceExec: partitions=1, partition_sizes=[1] +10)------------------BoundedWindowAggExec: wdw=[sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Field { name: "sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] +11)--------------------DataSourceExec: partitions=1, partition_sizes=[2] query ITIPTRR @@ -3775,7 +3802,7 @@ ORDER BY x; 2 2 query II -SELECT y, LAST_VALUE(x) +SELECT y, LAST_VALUE(x order by x desc) FROM FOO GROUP BY y ORDER BY y; @@ -3916,7 +3943,7 @@ physical_plan 04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10, projection=[a@0, d@1, row_n@4] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], file_type=csv, has_header=true 06)--------ProjectionExec: expr=[a@0 as a, d@1 as d, row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] -07)----------BoundedWindowAggExec: wdw=[row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +07)----------BoundedWindowAggExec: wdw=[row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 08)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], file_type=csv, has_header=true # reset partition number to 8. @@ -4448,10 +4475,6 @@ physical_plan 12)----------------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 13)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4], file_type=csv, has_header=true -# Use PostgreSQL dialect -statement ok -set datafusion.sql_parser.dialect = 'Postgres'; - query II SELECT c2, count(distinct c3) FILTER (WHERE c1 != 'a') FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; ---- @@ -4470,10 +4493,6 @@ SELECT c2, count(distinct c3) FILTER (WHERE c1 != 'a'), count(c5) FILTER (WHERE 4 19 18 5 11 9 -# Restore the default dialect -statement ok -set datafusion.sql_parser.dialect = 'Generic'; - statement ok drop table aggregate_test_100; @@ -4512,19 +4531,20 @@ LIMIT 5 query ITIPTR rowsort SELECT r.* FROM sales_global_with_pk as l, sales_global_with_pk as r +ORDER BY 1, 2, 3, 4, 5, 6 LIMIT 5 ---- 0 GRC 0 2022-01-01T06:00:00 EUR 30 -1 FRA 1 2022-01-01T08:00:00 EUR 50 -1 FRA 3 2022-01-02T12:00:00 EUR 200 -1 TUR 2 2022-01-01T11:30:00 TRY 75 -1 TUR 4 2022-01-03T10:00:00 TRY 100 +0 GRC 0 2022-01-01T06:00:00 EUR 30 +0 GRC 0 2022-01-01T06:00:00 EUR 30 +0 GRC 0 2022-01-01T06:00:00 EUR 30 +0 GRC 0 2022-01-01T06:00:00 EUR 30 # Create a table with timestamp data statement ok CREATE TABLE src_table ( - t1 TIMESTAMP, - c2 INT + t1 TIMESTAMP, + c2 INT ) AS VALUES ('2020-12-10T00:00:00.00Z', 0), ('2020-12-11T00:00:00.00Z', 1), @@ -4569,8 +4589,8 @@ STORED AS CSV; # Create a table from the generated CSV files: statement ok CREATE EXTERNAL TABLE timestamp_table ( - t1 TIMESTAMP, - c2 INT, + t1 TIMESTAMP, + c2 INT, ) STORED AS CSV LOCATION 'test_files/scratch/group_by/timestamp_table' @@ -5153,8 +5173,8 @@ physical_plan 02)--AggregateExec: mode=Single, gby=[date_bin(IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }, ts@0, 946684800000000000) as date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))], aggr=[count(keywords_stream.keyword)] 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(keyword@0, keyword@1)] -05)--------DataSourceExec: partitions=1, partition_sizes=[1] -06)--------DataSourceExec: partitions=1, partition_sizes=[1] +05)--------DataSourceExec: partitions=1, partition_sizes=[3] +06)--------DataSourceExec: partitions=1, partition_sizes=[3] query PI SELECT @@ -5196,17 +5216,17 @@ statement ok create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, null), (null, 0xb), (null, 0xb); query I?I -select a, b, count(*) from t group by grouping sets ((a, b), (a), (b)); +select a, b, count(*) from t group by grouping sets ((a, b), (a), (b)) order by a, b; ---- 1 0a 2 -2 NULL 1 -NULL 0b 2 1 NULL 2 2 NULL 1 -NULL NULL 2 +2 NULL 1 NULL 0a 2 -NULL NULL 1 NULL 0b 2 +NULL 0b 2 +NULL NULL 2 +NULL NULL 1 statement ok drop table t; @@ -5216,13 +5236,13 @@ statement ok create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, 0xb), (3, 0xb), (3, 0xb); query I?I -select a, b, count(*) from t group by grouping sets ((a, b), (a), (b)); +select a, b, count(*) from t group by grouping sets ((a, b), (a), (b)) order by a, b; ---- 1 0a 2 -2 0b 1 -3 0b 2 1 NULL 2 +2 0b 1 2 NULL 1 +3 0b 2 3 NULL 2 NULL 0a 2 NULL 0b 3 diff --git a/datafusion/sqllogictest/test_files/imdb.slt b/datafusion/sqllogictest/test_files/imdb.slt new file mode 100644 index 0000000000000..c17f9c47c745a --- /dev/null +++ b/datafusion/sqllogictest/test_files/imdb.slt @@ -0,0 +1,4040 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file contains IMDB test queries against a small sample dataset. +# The test creates tables with sample data and runs all the IMDB benchmark queries. + +# company_type table +statement ok +CREATE TABLE company_type ( + id INT NOT NULL, + kind VARCHAR NOT NULL +); + +statement ok +INSERT INTO company_type VALUES + (1, 'production companies'), + (2, 'distributors'), + (3, 'special effects companies'), + (4, 'other companies'), + (5, 'miscellaneous companies'), + (6, 'film distributors'), + (7, 'theaters'), + (8, 'sales companies'), + (9, 'producers'), + (10, 'publishers'), + (11, 'visual effects companies'), + (12, 'makeup departments'), + (13, 'costume designers'), + (14, 'movie studios'), + (15, 'sound departments'), + (16, 'talent agencies'), + (17, 'casting companies'), + (18, 'film commissions'), + (19, 'production services'), + (20, 'digital effects studios'); + +# info_type table +statement ok +CREATE TABLE info_type ( + id INT NOT NULL, + info VARCHAR NOT NULL +); + +statement ok +INSERT INTO info_type VALUES + (1, 'runtimes'), + (2, 'color info'), + (3, 'genres'), + (4, 'languages'), + (5, 'certificates'), + (6, 'sound mix'), + (7, 'countries'), + (8, 'top 250 rank'), + (9, 'bottom 10 rank'), + (10, 'release dates'), + (11, 'filming locations'), + (12, 'production companies'), + (13, 'technical info'), + (14, 'trivia'), + (15, 'goofs'), + (16, 'martial-arts'), + (17, 'quotes'), + (18, 'movie connections'), + (19, 'plot description'), + (20, 'biography'), + (21, 'plot summary'), + (22, 'box office'), + (23, 'ratings'), + (24, 'taglines'), + (25, 'keywords'), + (26, 'soundtrack'), + (27, 'votes'), + (28, 'height'), + (30, 'mini biography'), + (31, 'budget'), + (32, 'rating'); + +# title table +statement ok +CREATE TABLE title ( + id INT NOT NULL, + title VARCHAR NOT NULL, + imdb_index VARCHAR, + kind_id INT NOT NULL, + production_year INT, + imdb_id INT, + phonetic_code VARCHAR, + episode_of_id INT, + season_nr INT, + episode_nr INT, + series_years VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO title VALUES + (1, 'The Shawshank Redemption', NULL, 1, 1994, 111161, NULL, NULL, NULL, NULL, NULL, NULL), + (2, 'The Godfather', NULL, 1, 1985, 68646, NULL, NULL, NULL, NULL, NULL, NULL), + (3, 'The Dark Knight', NULL, 1, 2008, 468569, NULL, NULL, NULL, NULL, NULL, NULL), + (4, 'The Godfather Part II', NULL, 1, 2012, 71562, NULL, NULL, NULL, NULL, NULL, NULL), + (5, 'Pulp Fiction', NULL, 1, 1994, 110912, NULL, NULL, NULL, NULL, NULL, NULL), + (6, 'Schindler''s List', NULL, 1, 1993, 108052, NULL, NULL, NULL, NULL, NULL, NULL), + (7, 'The Lord of the Rings: The Return of the King', NULL, 1, 2003, 167260, NULL, NULL, NULL, NULL, NULL, NULL), + (8, '12 Angry Men', NULL, 1, 1957, 50083, NULL, NULL, NULL, NULL, NULL, NULL), + (9, 'Inception', NULL, 1, 2010, 1375666, NULL, NULL, NULL, NULL, NULL, NULL), + (10, 'Fight Club', NULL, 1, 1999, 137523, NULL, NULL, NULL, NULL, NULL, NULL), + (11, 'The Matrix', NULL, 1, 2014, 133093, NULL, NULL, NULL, NULL, NULL, NULL), + (12, 'Goodfellas', NULL, 1, 1990, 99685, NULL, NULL, NULL, NULL, NULL, NULL), + (13, 'Avengers: Endgame', NULL, 1, 2019, 4154796, NULL, NULL, NULL, NULL, NULL, NULL), + (14, 'Interstellar', NULL, 1, 2014, 816692, NULL, NULL, NULL, NULL, NULL, NULL), + (15, 'The Silence of the Lambs', NULL, 1, 1991, 102926, NULL, NULL, NULL, NULL, NULL, NULL), + (16, 'Saving Private Ryan', NULL, 1, 1998, 120815, NULL, NULL, NULL, NULL, NULL, NULL), + (17, 'The Green Mile', NULL, 1, 1999, 120689, NULL, NULL, NULL, NULL, NULL, NULL), + (18, 'Forrest Gump', NULL, 1, 1994, 109830, NULL, NULL, NULL, NULL, NULL, NULL), + (19, 'Joker', NULL, 1, 2019, 7286456, NULL, NULL, NULL, NULL, NULL, NULL), + (20, 'Parasite', NULL, 1, 2019, 6751668, NULL, NULL, NULL, NULL, NULL, NULL), + (21, 'The Iron Giant', NULL, 1, 1999, 129167, NULL, NULL, NULL, NULL, NULL, NULL), + (22, 'Spider-Man: Into the Spider-Verse', NULL, 1, 2018, 4633694, NULL, NULL, NULL, NULL, NULL, NULL), + (23, 'Iron Man', NULL, 1, 2008, 371746, NULL, NULL, NULL, NULL, NULL, NULL), + (24, 'Black Panther', NULL, 1, 2018, 1825683, NULL, NULL, NULL, NULL, NULL, NULL), + (25, 'Titanic', NULL, 1, 1997, 120338, NULL, NULL, NULL, NULL, NULL, NULL), + (26, 'Kung Fu Panda 2', NULL, 1, 2011, 0441773, NULL, NULL, NULL, NULL, NULL, NULL), + (27, 'Halloween', NULL, 1, 2008, 1311067, NULL, NULL, NULL, NULL, NULL, NULL), + (28, 'Breaking Bad', NULL, 2, 2003, 903254, NULL, NULL, NULL, NULL, NULL, NULL), + (29, 'Breaking Bad: The Final Season', NULL, 2, 2007, 903255, NULL, NULL, NULL, NULL, NULL, NULL), + (30, 'Amsterdam Detective', NULL, 2, 2005, 905001, NULL, NULL, NULL, NULL, NULL, NULL), + (31, 'Amsterdam Detective: Cold Case', NULL, 2, 2007, 905002, NULL, NULL, NULL, NULL, NULL, NULL), + (32, 'Saw IV', NULL, 1, 2007, 905003, NULL, NULL, NULL, NULL, NULL, NULL), + (33, 'Shrek 2', NULL, 1, 2004, 906001, NULL, NULL, NULL, NULL, NULL, NULL), + (35, 'Dark Blood', NULL, 1, 2005, 907001, NULL, NULL, NULL, NULL, NULL, NULL), + (36, 'The Nordic Murders', NULL, 1, 2008, 908002, NULL, NULL, NULL, NULL, NULL, NULL), + (37, 'Scandinavian Crime', NULL, 1, 2009, 909001, NULL, NULL, NULL, NULL, NULL, NULL), + (38, 'The Western Sequel', NULL, 1, 1998, NULL, NULL, NULL, NULL, NULL, NULL, NULL), + (39, 'Marvel Superhero Epic', NULL, 1, 2010, NULL, NULL, NULL, NULL, NULL, NULL, NULL), + (40, 'The Champion', NULL, 1, 2016, 999555, NULL, NULL, NULL, NULL, NULL, NULL), + (41, 'Champion Boxer', NULL, 1, 2018, 999556, NULL, NULL, NULL, NULL, NULL, NULL), + (42, 'Avatar', NULL, 5, 2010, 499549, NULL, NULL, NULL, NULL, NULL, NULL), + (43, 'The Godfather Connection', NULL, 1, 1985, 68647, NULL, NULL, NULL, NULL, NULL, NULL), + (44, 'Digital Connection', NULL, 1, 2005, 888999, NULL, NULL, NULL, NULL, NULL, NULL), + (45, 'Berlin Noir', NULL, 1, 2010, NULL, NULL, NULL, NULL, NULL, NULL, NULL), + (46, 'YouTube Documentary', NULL, 1, 2008, 777999, NULL, NULL, NULL, NULL, NULL, NULL), + (47, 'The Swedish Murder Case', NULL, 1, 2012, 666777, NULL, NULL, NULL, NULL, NULL, NULL), + (48, 'Nordic Noir', NULL, 1, 2015, 555666, NULL, NULL, NULL, NULL, NULL, NULL), + (49, 'Derek Jacobi Story', NULL, 1, 1982, 444555, NULL, NULL, NULL, NULL, NULL, NULL), + (50, 'Woman in Black', NULL, 1, 2010, 987654, NULL, NULL, NULL, NULL, NULL, NULL), + (51, 'Kung Fu Panda', NULL, 1, 2008, 441772, NULL, NULL, NULL, NULL, NULL, NULL), + (52, 'Bruno', NULL, 1, 2009, NULL, NULL, NULL, NULL, NULL, NULL, NULL), + (53, 'Character Series', NULL, 2, 2020, 999888, NULL, NULL, NULL, 55, NULL, NULL), + (54, 'Vampire Chronicles', NULL, 1, 2015, 999999, NULL, NULL, NULL, NULL, NULL, NULL), + (55, 'Alien Invasion', NULL, 1, 2020, 888888, NULL, NULL, NULL, NULL, NULL, NULL), + (56, 'Dragon Warriors', NULL, 1, 2015, 888889, NULL, NULL, NULL, NULL, NULL, NULL), + (57, 'One Piece: Grand Adventure', NULL, 1, 2007, 777777, NULL, NULL, NULL, NULL, NULL, NULL), + (58, 'Moscow Nights', NULL, 1, 2010, 777778, NULL, NULL, NULL, NULL, NULL, NULL), + (59, 'Money Talks', NULL, 1, 1998, 888888, NULL, NULL, NULL, NULL, NULL, NULL), + (60, 'Fox Novel Movie', NULL, 1, 2005, 777888, NULL, NULL, NULL, NULL, NULL, NULL), + (61, 'Bad Movie Sequel', NULL, 1, 2010, 888777, NULL, NULL, NULL, NULL, NULL, NULL); + +# movie_companies table +statement ok +CREATE TABLE movie_companies ( + id INT NOT NULL, + movie_id INT NOT NULL, + company_id INT NOT NULL, + company_type_id INT NOT NULL, + note VARCHAR +); + +statement ok +INSERT INTO movie_companies VALUES + (1, 1, 4, 1, '(presents) (co-production)'), + (2, 2, 5, 1, '(presents)'), + (3, 3, 6, 1, '(co-production)'), + (4, 4, 7, 1, '(as Metro-Goldwyn-Mayer Pictures)'), + (5, 5, 8, 1, '(presents) (co-production)'), + (6, 6, 9, 1, '(presents)'), + (7, 7, 10, 1, '(co-production)'), + (8, 8, 11, 2, '(distributor)'), + (9, 9, 12, 1, '(presents) (co-production)'), + (10, 10, 13, 1, '(presents)'), + (11, 11, 14, 1, '(presents) (co-production)'), + (12, 12, 15, 1, '(presents)'), + (13, 13, 16, 1, '(co-production)'), + (14, 14, 17, 1, '(presents)'), + (15, 15, 18, 1, '(co-production)'), + (16, 16, 19, 1, '(presents)'), + (17, 17, 20, 1, '(co-production)'), + (18, 18, 21, 1, '(presents)'), + (19, 19, 22, 1, '(co-production)'), + (20, 20, 23, 1, '(presents)'), + (21, 21, 24, 1, '(presents) (co-production)'), + (22, 22, 25, 1, '(presents)'), + (23, 23, 26, 1, '(co-production)'), + (24, 24, 27, 1, '(presents)'), + (25, 25, 28, 1, '(presents) (co-production)'), + (26, 3, 35, 1, '(as Warner Bros. Pictures)'), + (27, 9, 35, 1, '(as Warner Bros. Pictures)'), + (28, 23, 14, 1, '(as Marvel Studios)'), + (29, 24, 14, 1, '(as Marvel Studios)'), + (30, 13, 14, 1, '(as Marvel Studios)'), + (31, 26, 23, 1, '(as DreamWorks Animation)'), + (32, 3, 6, 2, '(distributor)'), + (33, 2, 8, 2, '(distributor)'), + (34, 3, 6, 1, '(as Warner Bros.) (2008) (USA) (worldwide)'), + (35, 44, 36, 1, NULL), + (36, 40, 9, 1, '(production) (USA) (2016)'), + (37, 56, 18, 1, '(production)'), + (38, 2, 6, 1, NULL), + (39, 13, 14, 2, '(as Marvel Studios)'), + (40, 19, 25, 1, '(co-production)'), + (41, 23, 26, 1, '(co-production)'), + (42, 19, 27, 1, '(co-production)'), + (43, 11, 18, 1, '(theatrical) (France)'), + (44, 11, 8, 1, '(VHS) (USA) (1994)'), + (45, 11, 4, 1, '(USA)'), + (46, 9, 28, 1, '(co-production)'), + (47, 28, 5, 1, '(production)'), + (48, 29, 5, 1, '(production)'), + (49, 30, 29, 1, '(production)'), + (50, 31, 30, 1, '(production)'), + (51, 27, 22, 1, '(production)'), + (52, 32, 22, 1, '(distribution) (Blu-ray)'), + (53, 33, 31, 1, '(production)'), + (54, 33, 31, 2, '(distribution)'), + (55, 35, 32, 1, NULL), + (56, 36, 33, 1, '(production) (2008)'), + (57, 37, 34, 1, '(production) (2009) (Norway)'), + (58, 38, 35, 1, NULL), + (59, 25, 9, 1, '(production)'), + (60, 52, 19, 1, NULL), + (61, 26, 37, 1, '(voice: English version)'), + (62, 21, 3, 1, '(production) (Japan) (anime)'), + (63, 57, 2, 1, '(production) (Japan) (2007) (anime)'), + (64, 58, 1, 1, '(production) (Russia) (2010)'), + (65, 59, 35, 1, NULL), + (66, 60, 13, 2, '(distribution) (DVD) (US)'), + (67, 61, 14, 1, '(production)'), + (68, 41, 9, 1, '(production) (USA) (2018)'), + (69, 46, 16, 1, '(production) (2008) (worldwide)'), + (70, 51, 31, 1, '(production) (2008) (USA) (worldwide)'), + (71, 45, 32, 1, 'Studio (2000) Berlin'), + (72, 53, 6, 1, '(production) (2020) (USA)'), + (73, 62, 9, 1, '(production) (USA) (2010) (worldwide)'); + +# movie_info_idx table +statement ok +CREATE TABLE movie_info_idx ( + id INT NOT NULL, + movie_id INT NOT NULL, + info_type_id INT NOT NULL, + info VARCHAR NOT NULL, + note VARCHAR +); + +statement ok +INSERT INTO movie_info_idx VALUES + (1, 1, 8, '1', NULL), + (2, 2, 8, '2', NULL), + (3, 3, 8, '3', NULL), + (4, 4, 8, '4', NULL), + (5, 5, 8, '5', NULL), + (6, 6, 8, '6', NULL), + (7, 7, 8, '7', NULL), + (8, 8, 8, '8', NULL), + (9, 9, 8, '9', NULL), + (10, 10, 8, '10', NULL), + (11, 11, 8, '11', NULL), + (12, 12, 8, '12', NULL), + (13, 13, 8, '13', NULL), + (14, 14, 8, '14', NULL), + (15, 15, 8, '15', NULL), + (16, 16, 8, '16', NULL), + (17, 17, 8, '17', NULL), + (18, 18, 8, '18', NULL), + (19, 19, 8, '19', NULL), + (20, 20, 8, '20', NULL), + (21, 21, 8, '21', NULL), + (22, 22, 8, '22', NULL), + (23, 23, 8, '23', NULL), + (24, 24, 8, '24', NULL), + (25, 25, 8, '25', NULL), + (26, 40, 32, '8.6', NULL), + (27, 41, 32, '7.5', NULL), + (28, 45, 32, '6.8', NULL), + (29, 45, 22, '$10,000,000', NULL), + (30, 1, 22, '9.3', NULL), + (31, 2, 22, '9.2', NULL), + (32, 1, 27, '2,345,678', NULL), + (33, 3, 22, '9.0', NULL), + (34, 9, 22, '8.8', NULL), + (35, 23, 22, '8.5', NULL), + (36, 20, 9, '1', NULL), + (37, 25, 9, '2', NULL), + (38, 3, 9, '10', NULL), + (39, 28, 32, '8.2', NULL), + (40, 29, 32, '2.8', NULL), + (41, 30, 32, '8.5', NULL), + (42, 31, 32, '2.5', NULL), + (43, 27, 27, '45000', NULL), + (44, 32, 27, '52000', NULL), + (45, 33, 27, '120000', NULL), + (46, 35, 32, '7.2', NULL), + (47, 36, 32, '7.8', NULL), + (48, 37, 32, '7.5', NULL), + (49, 37, 27, '100000', NULL), + (50, 39, 32, '8.5', NULL), + (51, 54, 27, '1000', NULL), + (52, 3, 3002, '500', NULL), + (53, 3, 999, '9.5', NULL), + (54, 4, 999, '9.1', NULL), + (55, 13, 999, '8.9', NULL), + (56, 3, 32, '9.5', NULL), + (57, 4, 32, '9.1', NULL), + (58, 13, 32, '8.9', NULL), + (59, 4, 32, '9.3', NULL), + (60, 61, 9, '3', NULL), + (61, 35, 22, '8.4', NULL), + (62, 50, 32, '8.5', NULL), + (63, 48, 32, '7.5', NULL), + (64, 48, 27, '85000', NULL), + (65, 47, 32, '7.8', NULL), + (66, 46, 3, 'Documentary', NULL), + (67, 46, 10, 'USA: 2008-05-15', 'internet release'); + +# movie_info table +statement ok +CREATE TABLE movie_info ( + id INT NOT NULL, + movie_id INT NOT NULL, + info_type_id INT NOT NULL, + info VARCHAR NOT NULL, + note VARCHAR +); + +statement ok +INSERT INTO movie_info VALUES + (1, 1, 1, '113', NULL), + (2, 4, 7, 'Germany', NULL), + (3, 3, 7, 'Bulgaria', NULL), + (4, 2, 1, '175', NULL), + (5, 3, 1, '152', NULL), + (6, 4, 1, '202', NULL), + (7, 5, 1, '154', NULL), + (8, 6, 1, '195', NULL), + (9, 7, 1, '201', NULL), + (10, 8, 1, '139', NULL), + (11, 9, 1, '148', NULL), + (12, 10, 1, '139', NULL), + (13, 11, 1, '136', NULL), + (14, 12, 1, '146', NULL), + (15, 13, 1, '181', NULL), + (16, 14, 1, '141', NULL), + (17, 15, 1, '159', NULL), + (18, 16, 1, '150', NULL), + (19, 17, 1, '156', NULL), + (20, 18, 1, '164', NULL), + (21, 19, 1, '122', NULL), + (22, 20, 1, '140', NULL), + (23, 40, 1, '125', NULL), + (24, 21, 1, '86', NULL), + (25, 22, 1, '117', NULL), + (26, 23, 1, '126', NULL), + (27, 24, 1, '134', NULL), + (28, 25, 1, '194', NULL), + (29, 1, 10, '1994-10-14', 'internet release'), + (30, 2, 10, '1972-03-24', 'internet release'), + (31, 3, 10, '2008-07-18', 'internet release'), + (32, 9, 10, '2010-07-16', 'internet release'), + (33, 13, 10, '2019-04-26', 'internet release'), + (34, 23, 10, '2008-05-02', 'internet release'), + (35, 24, 10, '2018-02-16', 'internet release'), + (36, 1, 2, 'Color', NULL), + (37, 3, 2, 'Color', NULL), + (38, 8, 2, 'Black and White', NULL), + (39, 9, 2, 'Color', NULL), + (40, 1, 19, 'Story about hope and redemption', NULL), + (41, 3, 19, 'Batman faces his greatest challenge', NULL), + (42, 19, 19, 'Origin story of the Batman villain', NULL), + (43, 1, 3, 'Drama', NULL), + (44, 3, 3, 'Action', NULL), + (45, 3, 3, 'Crime', NULL), + (46, 3, 3, 'Drama', NULL), + (47, 9, 3, 'Action', NULL), + (48, 9, 3, 'Adventure', NULL), + (49, 9, 3, 'Sci-Fi', NULL), + (50, 23, 3, 'Action', NULL), + (51, 23, 3, 'Adventure', NULL), + (52, 23, 3, 'Sci-Fi', NULL), + (53, 24, 3, 'Action', NULL), + (54, 24, 3, 'Adventure', NULL), + (55, 9, 7, 'Germany', NULL), + (56, 19, 7, 'German', NULL), + (57, 24, 7, 'Germany', NULL), + (58, 13, 7, 'USA', NULL), + (59, 3, 7, 'USA', NULL), + (60, 3, 22, '2343110', NULL), + (61, 3, 27, '2343110', NULL), + (62, 26, 10, 'USA:2011-05-26', NULL), + (63, 19, 20, 'Batman faces his greatest challenge', NULL), + (64, 3, 3, 'Drama', NULL), + (65, 13, 3, 'Action', NULL), + (66, 13, 19, 'Epic conclusion to the Infinity Saga', NULL), + (67, 2, 8, '1972-03-24', 'Released via internet in 2001'), + (68, 13, 4, 'English', NULL), + (69, 13, 3, 'Animation', NULL), + (70, 26, 3, 'Animation', NULL), + (71, 27, 3, '$15 million', NULL), + (72, 27, 3, 'Horror', NULL), + (73, 32, 3, 'Horror', NULL), + (74, 33, 10, 'USA: 2004', NULL), + (75, 33, 3, 'Animation', NULL), + (76, 35, 7, 'Germany', NULL), + (77, 35, 10, '2005-09-15', NULL), + (78, 44, 10, 'USA: 15 May 2005', 'This movie explores internet culture and digital connections that emerged in the early 2000s.'), + (79, 40, 10, '2016-08-12', 'internet release'), + (80, 1, 31, '$25,000,000', NULL), + (81, 45, 7, 'Germany', NULL), + (82, 45, 32, 'Germany', NULL), + (83, 13, 32, '8.5', NULL), + (84, 3, 32, '9.2', NULL), + (85, 3, 102, '9.2', NULL), + (86, 3, 25, 'sequel', NULL), + (87, 3, 102, '9.2', NULL), + (88, 3, 102, '9.2', NULL), + (89, 4, 102, '9.5', NULL), + (90, 33, 102, '8.7', NULL), + (91, 4, 32, '9.5', NULL), + (92, 11, 32, '8.7', NULL), + (93, 3, 32, '9.2', NULL), + (94, 3, 102, '9.2', NULL), + (95, 3, 32, '9.0', NULL), + (96, 26, 32, '8.2', NULL), + (97, 26, 32, '8.5', NULL), + (98, 27, 27, '8231', NULL), + (99, 27, 10, '2008-10-31', NULL), + (100, 13, 1, '182', NULL), + (101, 11, 2, 'Germany', NULL), + (102, 11, 1, '120', NULL), + (103, 3, 3, 'Drama', NULL), + (104, 11, 7, 'USA', NULL), + (105, 11, 7, 'Bulgaria', NULL), + (106, 50, 3, 'Horror', NULL), + (107, 36, 7, 'Sweden', NULL), + (108, 37, 7, 'Norway', NULL), + (109, 38, 7, 'Sweden', NULL), + (110, 54, 3, 'Horror', NULL), + (111, 55, 3, 'Sci-Fi', NULL), + (112, 56, 30, 'Japan:2015-06-15', NULL), + (113, 56, 30, 'USA:2015-07-20', NULL), + (114, 26, 10, 'Japan:2011-05-29', NULL), + (115, 26, 10, 'USA:2011-05-26', NULL), + (116, 61, 31, '$500,000', NULL), + (117, 41, 10, '2018-05-25', 'USA theatrical release'), + (118, 41, 7, 'Germany', 'Filmed on location'), + (119, 48, 7, 'Sweden', 'Filmed on location'), + (120, 48, 10, '2015-06-15', 'theatrical release'), + (121, 48, 3, 'Thriller', NULL), + (122, 47, 7, 'Sweden', 'Principal filming location'), + (123, 47, 10, '2012-09-21', 'theatrical release'), + (124, 47, 3, 'Crime', NULL), + (125, 47, 3, 'Thriller', NULL), + (126, 47, 7, 'Sweden', NULL), + (127, 3, 10, 'USA: 2008-07-14', 'internet release'), + (128, 46, 10, 'USA: 2008-05-15', 'internet release'), + (129, 40, 10, 'USA:\ 2006', 'internet release'), + (130, 51, 10, 'USA: 2008-06-06', 'theatrical release'), + (131, 51, 10, 'Japan: 2007-12-20', 'preview screening'); + +# kind_type table +statement ok +CREATE TABLE kind_type ( + id INT NOT NULL, + kind VARCHAR NOT NULL +); + +statement ok +INSERT INTO kind_type VALUES + (1, 'movie'), + (2, 'tv series'), + (3, 'video movie'), + (4, 'tv movie'), + (5, 'video game'), + (6, 'episode'), + (7, 'documentary'), + (8, 'short movie'), + (9, 'tv mini series'), + (10, 'reality-tv'); + +# cast_info table +statement ok +CREATE TABLE cast_info ( + id INT NOT NULL, + person_id INT NOT NULL, + movie_id INT NOT NULL, + person_role_id INT, + note VARCHAR, + nr_order INT, + role_id INT NOT NULL +); + +statement ok +INSERT INTO cast_info VALUES + (1, 29, 53, NULL, NULL, 1, 1), + (2, 3, 1, 54, NULL, 1, 1), + (3, 3, 1, NULL, '(producer)', 1, 3), + (4, 4, 2, 2, NULL, 1, 1), + (5, 5, 3, 3, NULL, 1, 1), + (6, 6, 4, 4, NULL, 1, 1), + (7, 2, 50, NULL, '(writer)', 1, 4), + (8, 18, 51, 15, '(voice)', 1, 2), + (9, 1, 19, NULL, NULL, 1, 1), + (10, 6, 100, 1985, '(as Special Actor)', 1, 1), + (11, 15, 19, NULL, NULL, 1, 1), + (12, 8, 5, 5, NULL, 1, 1), + (13, 9, 6, 6, NULL, 1, 1), + (14, 10, 7, 7, NULL, 1, 1), + (15, 11, 8, 8, NULL, 1, 1), + (16, 12, 9, 9, NULL, 1, 1), + (17, 13, 10, 10, NULL, 1, 1), + (18, 14, 9, 55, NULL, 1, 1), + (19, 14, 14, 29, NULL, 1, 1), + (20, 27, 58, 28, '(producer)', 1, 1), + (21, 16, 3, 23, '(producer)', 2, 1), + (22, 20, 49, NULL, NULL, 1, 1), + (23, 13, 23, 14, NULL, 1, 1), + (24, 28, 13, NULL, '(costume design)', 1, 7), + (25, 25, 58, 31, '(voice) (uncredited)', 1, 1), + (26, 18, 3, 24, '(voice)', 1, 2), + (27, 29, 26, 24, '(voice)', 1, 2), + (28, 13, 13, 47, '(writer)', 1, 1), + (29, 17, 3, 25, '(producer)', 3, 8), + (30, 18, 3, 11, '(voice)', 1, 2), + (31, 18, 26, 11, '(voice)', 1, 2), + (32, 18, 26, 12, '(voice: original film)', 1, 2), + (33, 22, 27, 12, '(writer)', 4, 8), + (34, 23, 32, 12, '(writer)', 4, 8), + (35, 21, 33, 13, '(voice)', 2, 2), + (36, 21, 33, 13, '(voice: English version)', 2, 2), + (37, 21, 33, 13, '(voice) (uncredited)', 2, 2), + (38, 22, 39, 25, 'Superman', 1, 1), + (39, 22, 39, 26, 'Ironman', 1, 1), + (40, 22, 39, 27, 'Spiderman', 1, 1), + (41, 19, 52, NULL, NULL, 2, 1), + (42, 14, 19, NULL, NULL, 3, 1), + (43, 6, 2, 2, NULL, 1, 1), + (44, 16, 54, NULL, '(writer)', 1, 4), + (45, 24, 55, NULL, '(director)', 1, 8), + (46, 25, 56, 29, '(voice: English version)', 1, 2), + (47, 18, 26, 30, '(voice: English version)', 1, 2), + (48, 26, 21, 24, '(voice: English version)', 1, 2), + (49, 26, 57, 25, '(voice: English version)', 1, 2), + (50, 27, 25, NULL, NULL, 1, 4), + (51, 18, 62, 32, '(voice)', 1, 2); + +# char_name table +statement ok +CREATE TABLE char_name ( + id INT NOT NULL, + name VARCHAR NOT NULL, + imdb_index VARCHAR, + imdb_id INT, + name_pcode_nf VARCHAR, + surname_pcode VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO char_name VALUES + (1, 'Andy Dufresne', NULL, NULL, NULL, NULL, NULL), + (2, 'Don Vito Corleone', NULL, NULL, NULL, NULL, NULL), + (3, 'Joker', NULL, NULL, NULL, NULL, NULL), + (4, 'Michael Corleone', NULL, NULL, NULL, NULL, NULL), + (5, 'Vincent Vega', NULL, NULL, NULL, NULL, NULL), + (6, 'Oskar Schindler', NULL, NULL, NULL, NULL, NULL), + (7, 'Gandalf', NULL, NULL, NULL, NULL, NULL), + (8, 'Juror 8', NULL, NULL, NULL, NULL, NULL), + (9, 'Cobb', NULL, NULL, NULL, NULL, NULL), + (10, 'Tyler Durden', NULL, NULL, NULL, NULL, NULL), + (11, 'Batman''s Assistant', NULL, NULL, NULL, NULL, NULL), + (12, 'Tiger', NULL, NULL, NULL, NULL, NULL), + (13, 'Queen', NULL, NULL, NULL, NULL, NULL), + (14, 'Iron Man', NULL, NULL, NULL, NULL, NULL), + (15, 'Master Tigress', NULL, NULL, NULL, NULL, NULL), + (16, 'Dom Cobb', NULL, NULL, NULL, NULL, NULL), + (17, 'Rachel Dawes', NULL, NULL, NULL, NULL, NULL), + (18, 'Arthur Fleck', NULL, NULL, NULL, NULL, NULL), + (19, 'Pepper Potts', NULL, NULL, NULL, NULL, NULL), + (20, 'T''Challa', NULL, NULL, NULL, NULL, NULL), + (21, 'Steve Rogers', NULL, NULL, NULL, NULL, NULL), + (22, 'Ellis Boyd Redding', NULL, NULL, NULL, NULL, NULL), + (23, 'Bruce Wayne', NULL, NULL, NULL, NULL, NULL), + (24, 'Tigress', NULL, NULL, NULL, NULL, NULL), + (25, 'Superman', NULL, NULL, NULL, NULL, NULL), + (26, 'Ironman', NULL, NULL, NULL, NULL, NULL), + (27, 'Spiderman', NULL, NULL, NULL, NULL, NULL), + (28, 'Director', NULL, NULL, NULL, NULL, NULL), + (29, 'Tiger Warrior', NULL, NULL, NULL, NULL, NULL), + (30, 'Tigress', NULL, NULL, NULL, NULL, NULL), + (31, 'Nikolai', NULL, NULL, NULL, NULL, NULL), + (32, 'Princess Dragon', NULL, NULL, NULL, NULL, NULL); + +# keyword table +statement ok +CREATE TABLE keyword ( + id INT NOT NULL, + keyword VARCHAR NOT NULL, + phonetic_code VARCHAR +); + +statement ok +INSERT INTO keyword VALUES + (1, 'prison', NULL), + (2, 'mafia', NULL), + (3, 'superhero', NULL), + (4, 'sequel', NULL), + (5, 'crime', NULL), + (6, 'holocaust', NULL), + (7, 'fantasy', NULL), + (8, 'jury', NULL), + (9, 'dream', NULL), + (10, 'fight', NULL), + (11, 'marvel-cinematic-universe', NULL), + (12, 'character-name-in-title', NULL), + (13, 'female-name-in-title', NULL), + (14, 'murder', NULL), + (15, 'noir', NULL), + (16, 'space', NULL), + (17, 'time-travel', NULL), + (18, 'artificial-intelligence', NULL), + (19, 'robot', NULL), + (20, 'alien', NULL), + (21, '10,000-mile-club', NULL), + (22, 'martial-arts', NULL), + (23, 'computer-animation', NULL), + (24, 'violence', NULL), + (25, 'based-on-novel', NULL), + (26, 'nerd', NULL), + (27, 'marvel-comics', NULL), + (28, 'based-on-comic', NULL), + (29, 'superhero-movie', NULL); + +# movie_keyword table +statement ok +CREATE TABLE movie_keyword ( + id INT NOT NULL, + movie_id INT NOT NULL, + keyword_id INT NOT NULL +); + +statement ok +INSERT INTO movie_keyword VALUES + (1, 1, 1), + (2, 2, 2), + (3, 3, 3), + (4, 4, 4), + (5, 5, 5), + (6, 6, 6), + (7, 7, 7), + (8, 8, 8), + (9, 9, 9), + (10, 10, 10), + (11, 3, 5), + (12, 19, 3), + (13, 19, 12), + (14, 23, 11), + (15, 13, 11), + (16, 24, 11), + (17, 11, 1), + (18, 11, 20), + (19, 11, 20), + (20, 14, 16), + (21, 9, 3), + (22, 3, 14), + (23, 25, 13), + (24, 23, 12), + (25, 2, 4), + (26, 23, 19), + (27, 19, 5), + (28, 23, 3), + (29, 23, 28), + (30, 3, 4), + (31, 3, 4), + (32, 2, 4), + (33, 4, 4), + (34, 11, 4), + (35, 3, 3), + (36, 26, 16), + (37, 13, 11), + (38, 13, 3), + (39, 13, 4), + (40, 9, 17), + (41, 9, 18), + (42, 3, 12), + (43, 13, 13), + (44, 26, 21), + (45, 24, 3), + (46, 9, 14), + (47, 2, 4), + (48, 14, 21), + (49, 27, 14), + (50, 32, 14), + (51, 33, 23), + (52, 33, 23), + (55, 35, 24), + (56, 36, 14), + (57, 36, 25), + (58, 35, 4), + (59, 37, 14), + (60, 37, 25), + (61, 45, 24), + (62, 2, 4), + (63, 14, 21), + (64, 27, 14), + (65, 32, 14), + (66, 33, 23), + (67, 33, 23), + (68, 35, 24), + (69, 38, 4), + (70, 39, 3), + (71, 39, 27), + (72, 39, 28), + (73, 39, 29), + (74, 44, 26), + (75, 52, 12), + (76, 54, 14), + (77, 55, 20), + (78, 55, 16), + (79, 56, 22), + (80, 26, 22), + (81, 3, 4), + (82, 4, 4), + (83, 13, 4), + (84, 3, 4), + (85, 40, 29), + (86, 4, 4), + (87, 13, 4), + (88, 59, 4), + (89, 60, 25), + (90, 48, 14), + (91, 47, 14), + (92, 45, 24), + (93, 46, 3), + (94, 53, 12); + +# company_name table +statement ok +CREATE TABLE company_name ( + id INT NOT NULL, + name VARCHAR NOT NULL, + country_code VARCHAR, + imdb_id INT, + name_pcode_nf VARCHAR, + name_pcode_sf VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO company_name VALUES + (1, 'Mosfilm', '[ru]', NULL, NULL, NULL, NULL), + (2, 'Toei Animation', '[jp]', NULL, NULL, NULL, NULL), + (3, 'Tokyo Animation Studio', '[jp]', NULL, NULL, NULL, NULL), + (4, 'Castle Rock Entertainment', '[us]', NULL, NULL, NULL, NULL), + (5, 'Paramount Pictures', '[us]', NULL, NULL, NULL, NULL), + (6, 'Warner Bros.', '[us]', NULL, NULL, NULL, NULL), + (7, 'Metro-Goldwyn-Mayer', '[us]', NULL, NULL, NULL, NULL), + (8, 'Miramax Films', '[us]', NULL, NULL, NULL, NULL), + (9, 'Universal Pictures', '[us]', NULL, NULL, NULL, NULL), + (10, 'New Line Cinema', '[us]', NULL, NULL, NULL, NULL), + (11, 'United Artists', '[us]', NULL, NULL, NULL, NULL), + (12, 'Columbia Pictures', '[us]', NULL, NULL, NULL, NULL), + (13, 'Twentieth Century Fox', '[us]', NULL, NULL, NULL, NULL), + (14, 'Marvel Studios', '[us]', NULL, NULL, NULL, NULL), + (15, 'DC Films', '[us]', NULL, NULL, NULL, NULL), + (16, 'YouTube', '[us]', NULL, NULL, NULL, NULL), + (17, 'DreamWorks Pictures', '[us]', NULL, NULL, NULL, NULL), + (18, 'Walt Disney Pictures', '[us]', NULL, NULL, NULL, NULL), + (19, 'Netflix', '[us]', NULL, NULL, NULL, NULL), + (20, 'Amazon Studios', '[us]', NULL, NULL, NULL, NULL), + (21, 'A24', '[us]', NULL, NULL, NULL, NULL), + (22, 'Lionsgate Films', '[us]', NULL, NULL, NULL, NULL), + (23, 'DreamWorks Animation', '[us]', NULL, NULL, NULL, NULL), + (24, 'Sony Pictures', '[us]', NULL, NULL, NULL, NULL), + (25, 'Bavaria Film', '[de]', NULL, NULL, NULL, NULL), + (26, 'Dutch FilmWorks', '[nl]', NULL, NULL, NULL, NULL), + (27, 'San Marino Films', '[sm]', NULL, NULL, NULL, NULL), + (28, 'Legendary Pictures', '[us]', NULL, NULL, NULL, NULL), + (29, 'Dutch Entertainment Group', '[nl]', NULL, NULL, NULL, NULL), + (30, 'Amsterdam Studios', '[nl]', NULL, NULL, NULL, NULL), + (31, 'DreamWorks Animation', '[us]', NULL, NULL, NULL, NULL), + (32, 'Berlin Film Studio', '[de]', NULL, NULL, NULL, NULL), + (33, 'Stockholm Productions', '[se]', NULL, NULL, NULL, NULL), + (34, 'Oslo Films', '[no]', NULL, NULL, NULL, NULL), + (35, 'Warner Bros. Pictures', '[us]', NULL, NULL, NULL, NULL), + (36, 'Silicon Entertainment', '[us]', NULL, NULL, NULL, NULL), + (37, 'DreamWorks Animation', '[us]', NULL, NULL, NULL, NULL); + +# name table for actors/directors information +statement ok +CREATE TABLE name ( + id INT NOT NULL, + name VARCHAR NOT NULL, + imdb_index VARCHAR, + imdb_id INT, + gender VARCHAR, + name_pcode_cf VARCHAR, + name_pcode_nf VARCHAR, + surname_pcode VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO name VALUES + (1, 'Xavier Thompson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (2, 'Susan Hill', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (3, 'Tim Robbins', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (4, 'Marlon Brando', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (5, 'Heath Ledger', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (6, 'Al Pacino', NULL, NULL, 'm', 'A', NULL, NULL, NULL), + (7, 'Downey Pacino', NULL, NULL, 'm', 'D', NULL, NULL, NULL), + (8, 'John Travolta', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (9, 'Liam Neeson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (10, 'Ian McKellen', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (11, 'Henry Fonda', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (12, 'Leonardo DiCaprio', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (13, 'Downey Robert Jr.', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (14, 'Zach Wilson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (15, 'Bert Wilson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (29, 'Alex Morgan', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (16, 'Christian Bale', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (17, 'Christopher Nolan', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (18, 'Angelina Jolie', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (19, 'Brad Wilson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (20, 'Derek Jacobi', NULL, NULL, 'm', 'D624', NULL, NULL, NULL), + (21, 'Anne Hathaway', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (22, 'John Carpenter', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (23, 'James Wan', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (24, 'Ridley Scott', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (25, 'Angelina Jolie', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (26, 'Yoko Tanaka', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (27, 'James Cameron', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (28, 'Edith Head', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (29, 'Anne Hathaway', NULL, NULL, 'f', NULL, NULL, NULL, NULL); + +# aka_name table +statement ok +CREATE TABLE aka_name ( + id INT NOT NULL, + person_id INT NOT NULL, + name VARCHAR NOT NULL, + imdb_index VARCHAR, + name_pcode_cf VARCHAR, + name_pcode_nf VARCHAR, + surname_pcode VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO aka_name VALUES + (1, 2, 'Marlon Brando Jr.', NULL, NULL, NULL, NULL, NULL), + (2, 2, 'Marlon Brando', NULL, NULL, NULL, NULL, NULL), + (3, 3, 'Heath Andrew Ledger', NULL, NULL, NULL, NULL, NULL), + (4, 6, 'Alfredo James Pacino', NULL, NULL, NULL, NULL, NULL), + (5, 5, 'John Joseph Travolta', NULL, NULL, NULL, NULL, NULL), + (6, 6, 'Liam John Neeson', NULL, NULL, NULL, NULL, NULL), + (7, 7, 'Ian Murray McKellen', NULL, NULL, NULL, NULL, NULL), + (8, 8, 'Henry Jaynes Fonda', NULL, NULL, NULL, NULL, NULL), + (9, 9, 'Leonardo Wilhelm DiCaprio', NULL, NULL, NULL, NULL, NULL), + (10, 10, 'Robert John Downey Jr.', NULL, NULL, NULL, NULL, NULL), + (11, 16, 'Christian Charles Philip Bale', NULL, NULL, NULL, NULL, NULL), + (12, 29, 'Christopher Jonathan James Nolan', NULL, NULL, NULL, NULL, NULL), + (13, 47, 'Joaquin Rafael Bottom', NULL, NULL, NULL, NULL, NULL), + (14, 26, 'Yoko Shimizu', NULL, NULL, NULL, NULL, NULL), + (15, 48, 'Chadwick Aaron Boseman', NULL, NULL, NULL, NULL, NULL), + (16, 29, 'Scarlett Ingrid Johansson', NULL, NULL, NULL, NULL, NULL), + (17, 31, 'Christopher Robert Evans', NULL, NULL, NULL, NULL, NULL), + (18, 32, 'Christopher Hemsworth', NULL, NULL, NULL, NULL, NULL), + (19, 33, 'Mark Alan Ruffalo', NULL, NULL, NULL, NULL, NULL), + (20, 20, 'Sir Derek Jacobi', NULL, NULL, NULL, NULL, NULL), + (21, 34, 'Samuel Leroy Jackson', NULL, NULL, NULL, NULL, NULL), + (22, 35, 'Gwyneth Kate Paltrow', NULL, NULL, NULL, NULL, NULL), + (23, 36, 'Thomas William Hiddleston', NULL, NULL, NULL, NULL, NULL), + (24, 37, 'Morgan Porterfield Freeman', NULL, NULL, NULL, NULL, NULL), + (25, 38, 'William Bradley Pitt', NULL, NULL, NULL, NULL, NULL), + (26, 39, 'Edward John Norton Jr.', NULL, NULL, NULL, NULL, NULL), + (27, 40, 'Marion Cotillard', NULL, NULL, NULL, NULL, NULL), + (28, 41, 'Joseph Leonard Gordon-Levitt', NULL, NULL, NULL, NULL, NULL), + (29, 42, 'Matthew David McConaughey', NULL, NULL, NULL, NULL, NULL), + (30, 43, 'Anne Jacqueline Hathaway', NULL, NULL, NULL, NULL, NULL), + (31, 44, 'Kevin Feige', NULL, NULL, NULL, NULL, NULL), + (32, 45, 'Margaret Ruth Gyllenhaal', NULL, NULL, NULL, NULL, NULL), + (33, 46, 'Kate Elizabeth Winslet', NULL, NULL, NULL, NULL, NULL), + (34, 28, 'E. Head', NULL, NULL, NULL, NULL, NULL), + (35, 29, 'Anne Jacqueline Hathaway', NULL, NULL, NULL, NULL, NULL), + (36, 29, 'Alexander Morgan', NULL, NULL, NULL, NULL, NULL), + (37, 2, 'Brando, M.', NULL, NULL, NULL, NULL, NULL), + (38, 21, 'Annie Hathaway', NULL, NULL, NULL, NULL, NULL), + (39, 21, 'Annie H', NULL, NULL, NULL, NULL, NULL), + (40, 25, 'Angie Jolie', NULL, NULL, NULL, NULL, NULL), + (41, 27, 'Jim Cameron', NULL, NULL, NULL, NULL, NULL), + (42, 18, 'Angelina Jolie', NULL, NULL, NULL, NULL, NULL); + +# role_type table +statement ok +CREATE TABLE role_type ( + id INT NOT NULL, + role VARCHAR NOT NULL +); + +statement ok +INSERT INTO role_type VALUES + (1, 'actor'), + (2, 'actress'), + (3, 'producer'), + (4, 'writer'), + (5, 'cinematographer'), + (6, 'composer'), + (7, 'costume designer'), + (8, 'director'), + (9, 'editor'), + (10, 'miscellaneous crew'); + +# link_type table +statement ok +CREATE TABLE link_type ( + id INT NOT NULL, + link VARCHAR NOT NULL +); + +statement ok +INSERT INTO link_type VALUES + (1, 'sequel'), + (2, 'follows'), + (3, 'remake of'), + (4, 'version of'), + (5, 'spin off from'), + (6, 'reference to'), + (7, 'featured in'), + (8, 'spoofed in'), + (9, 'edited into'), + (10, 'alternate language version of'), + (11, 'features'); + +# movie_link table +statement ok +CREATE TABLE movie_link ( + id INT NOT NULL, + movie_id INT NOT NULL, + linked_movie_id INT NOT NULL, + link_type_id INT NOT NULL +); + +statement ok +INSERT INTO movie_link VALUES + (1, 2, 4, 1), + (2, 3, 5, 6), + (3, 6, 7, 4), + (4, 8, 9, 8), + (5, 10, 1, 3), + (6, 28, 29, 1), + (7, 30, 31, 2), + (8, 1, 3, 6), + (9, 23, 13, 1), + (10, 13, 24, 2), + (11, 20, 3, 1), + (12, 3, 22, 1), + (13, 2, 4, 2), + (14, 19, 19, 6), + (15, 14, 16, 6), + (16, 13, 23, 2), + (17, 25, 9, 4), + (18, 17, 1, 8), + (19, 24, 23, 2), + (20, 21, 22, 1), + (21, 15, 9, 6), + (22, 11, 13, 1), + (23, 13, 11, 2), + (24, 100, 100, 7), + (25, 1, 2, 7), + (26, 23, 2, 7), + (27, 14, 25, 9), + (28, 4, 6, 4), + (29, 5, 8, 6), + (30, 7, 10, 6), + (31, 9, 2, 8), + (32, 38, 39, 2), + (33, 59, 5, 2), + (34, 60, 9, 2), + (35, 49, 49, 11), + (36, 35, 36, 2); + +# complete_cast table +statement ok +CREATE TABLE complete_cast ( + id INT NOT NULL, + movie_id INT NOT NULL, + subject_id INT NOT NULL, + status_id INT NOT NULL +); + +statement ok +INSERT INTO complete_cast VALUES + (1, 1, 1, 1), + (2, 2, 1, 1), + (3, 3, 1, 1), + (4, 4, 1, 1), + (5, 5, 1, 1), + (6, 6, 1, 1), + (7, 7, 1, 1), + (8, 8, 1, 1), + (9, 9, 1, 1), + (10, 10, 1, 1), + (11, 11, 1, 1), + (12, 12, 1, 1), + (13, 13, 1, 1), + (14, 14, 1, 1), + (15, 15, 1, 1), + (16, 16, 1, 1), + (17, 17, 1, 1), + (18, 18, 1, 1), + (19, 19, 1, 2), + (20, 20, 2, 1), + (21, 21, 1, 1), + (22, 22, 1, 1), + (23, 23, 1, 3), + (24, 24, 1, 1), + (25, 25, 1, 1), + (26, 26, 1, 1), + (27, 13, 2, 4), + (28, 44, 1, 4), + (29, 33, 1, 4), + (30, 31, 1, 1), + (31, 32, 1, 4), + (32, 33, 1, 4), + (33, 35, 2, 3), + (34, 36, 2, 3), + (35, 37, 1, 4), + (36, 37, 1, 3), + (37, 38, 1, 3), + (38, 39, 1, 3), + (39, 39, 1, 11), + (40, 40, 1, 4); + +# comp_cast_type table +statement ok +CREATE TABLE comp_cast_type ( + id INT NOT NULL, + kind VARCHAR NOT NULL +); + +statement ok +INSERT INTO comp_cast_type VALUES + (1, 'cast'), + (2, 'crew'), + (3, 'complete'), + (4, 'complete+verified'), + (5, 'pending'), + (6, 'unverified'), + (7, 'uncredited cast'), + (8, 'uncredited crew'), + (9, 'unverified cast'), + (10, 'unverified crew'), + (11, 'complete cast'); + +# person_info table +statement ok +CREATE TABLE person_info ( + id INT NOT NULL, + person_id INT NOT NULL, + info_type_id INT NOT NULL, + info VARCHAR NOT NULL, + note VARCHAR +); + +statement ok +INSERT INTO person_info VALUES + (1, 1, 3, 'actor,producer', NULL), + (2, 2, 3, 'actor,director', NULL), + (3, 3, 3, 'actor', NULL), + (4, 6, 3, 'actor,producer', NULL), + (5, 5, 3, 'actor', NULL), + (6, 6, 3, 'actor', NULL), + (7, 7, 3, 'actor', NULL), + (8, 8, 3, 'actor', NULL), + (9, 20, 30, 'Renowned Shakespearean actor and stage performer', 'Volker Boehm'), + (10, 10, 3, 'actor,producer', 'marvel-cinematic-universe'), + (11, 3, 1, 'Won Academy Award for portrayal of Joker', NULL), + (12, 10, 1, 'Played Iron Man in the Marvel Cinematic Universe', NULL), + (13, 16, 3, 'actor', NULL), + (14, 16, 1, 'Played Batman in The Dark Knight trilogy', NULL), + (15, 29, 3, 'director,producer,writer', NULL), + (16, 29, 1, 'Directed The Dark Knight trilogy', NULL), + (17, 47, 3, 'actor', NULL), + (18, 47, 1, 'Won Academy Award for portrayal of Joker', NULL), + (19, 48, 3, 'actor', NULL), + (20, 48, 1, 'Played Black Panther in the Marvel Cinematic Universe', NULL), + (21, 29, 3, 'actress', NULL), + (22, 29, 1, 'Played Black Widow in the Marvel Cinematic Universe', NULL), + (23, 31, 3, 'actor', NULL), + (24, 31, 1, 'Played Captain America in the Marvel Cinematic Universe', NULL), + (25, 32, 3, 'actor', NULL), + (26, 32, 1, 'Played Thor in the Marvel Cinematic Universe', NULL), + (27, 9, 1, 'Won Academy Award for The Revenant', NULL), + (28, 9, 7, '1974-11-11', NULL), + (29, 10, 7, '1965-04-04', NULL), + (30, 16, 7, '1974-01-30', NULL), + (31, 47, 7, '1974-10-28', NULL), + (32, 48, 7, '1976-11-29', NULL), + (33, 29, 7, '1984-11-22', NULL), + (34, 31, 7, '1981-06-13', NULL), + (35, 32, 7, '1983-08-11', NULL), + (36, 21, 14, 'Won an Oscar for Les Miserables.', 'IMDB staff'), + (37, 21, 14, 'Voiced Queen in Shrek 2.', 'IMDB staff'), + (38, 21, 28, '5 ft 8 in (1.73 m)', 'IMDB staff'), + (39, 6, 30, 'Famous for his role in The Godfather', 'Volker Boehm'); + +# aka_title table +statement ok +CREATE TABLE aka_title ( + id INT NOT NULL, + movie_id INT NOT NULL, + title VARCHAR, + imdb_index VARCHAR, + kind_id INT NOT NULL, + production_year INT, + phonetic_code VARCHAR, + episode_of_id INT, + season_nr INT, + episode_nr INT, + note VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO aka_title VALUES + (1, 1, 'Shawshank', NULL, 1, 1994, NULL, NULL, NULL, NULL, NULL, NULL), + (2, 2, 'Der Pate', NULL, 1, 1972, NULL, NULL, NULL, NULL, 'German title', NULL), + (3, 3, 'The Dark Knight', NULL, 1, 2008, NULL, NULL, NULL, NULL, NULL, NULL), + (4, 4, 'Der Pate II', NULL, 1, 1974, NULL, NULL, NULL, NULL, 'German title', NULL), + (5, 5, 'Pulp Fiction', NULL, 1, 1994, NULL, NULL, NULL, NULL, NULL, NULL), + (6, 6, 'La lista di Schindler', NULL, 1, 1993, NULL, NULL, NULL, NULL, 'Italian title', NULL), + (7, 7, 'LOTR: ROTK', NULL, 1, 2003, NULL, NULL, NULL, NULL, 'Abbreviated', NULL), + (8, 8, '12 Angry Men', NULL, 1, 1957, NULL, NULL, NULL, NULL, NULL, NULL), + (9, 9, 'Dream Heist', NULL, 1, 2010, NULL, NULL, NULL, NULL, 'Working title', NULL), + (10, 10, 'Fight Club', NULL, 1, 1999, NULL, NULL, NULL, NULL, NULL, NULL), + (11, 3, 'Batman: The Dark Knight', NULL, 1, 2008, NULL, NULL, NULL, NULL, 'Full title', NULL), + (12, 13, 'Avengers 4', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Abbreviated', NULL), + (13, 19, 'The Joker', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Working title', NULL), + (14, 23, 'Iron Man: Birth of a Hero', NULL, 1, 2008, NULL, NULL, NULL, NULL, 'Extended title', NULL), + (15, 24, 'Black Panther: Wakanda Forever', NULL, 1, 2018, NULL, NULL, NULL, NULL, 'Alternate title', NULL), + (16, 11, 'Avengers 3', NULL, 1, 2018, NULL, NULL, NULL, NULL, 'Abbreviated', NULL), + (17, 3, 'Batman 2', NULL, 1, 2008, NULL, NULL, NULL, NULL, 'Sequel numbering', NULL), + (18, 20, 'Batman: Year One', NULL, 1, 2005, NULL, NULL, NULL, NULL, 'Working title', NULL), + (19, 14, 'Journey to the Stars', NULL, 1, 2014, NULL, NULL, NULL, NULL, 'Working title', NULL), + (20, 25, 'Rose and Jack', NULL, 1, 1997, NULL, NULL, NULL, NULL, 'Character-based title', NULL), + (21, 19, 'Joker: A Descent Into Madness', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Extended title', NULL), + (22, 22, 'Batman 3', NULL, 1, 2012, NULL, NULL, NULL, NULL, 'Sequel numbering', NULL), + (23, 1, 'The Shawshank Redemption', NULL, 1, 1994, NULL, NULL, NULL, NULL, 'Full title', NULL), + (24, 19, 'El Joker', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Spanish title', NULL), + (25, 13, 'Los Vengadores: Endgame', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Spanish title', NULL), + (26, 19, 'The Batman', NULL, 1, 2022, NULL, NULL, NULL, NULL, 'Working title', NULL), + (27, 41, 'Champion Boxer: The Rise of a Legend', NULL, 1, 2018, NULL, NULL, NULL, NULL, 'Extended title', NULL), + (28, 47, 'The Swedish Murder Case', NULL, 1, 2012, NULL, NULL, NULL, NULL, 'Full title', NULL), + (29, 46, 'Viral Documentary', NULL, 1, 2008, NULL, NULL, NULL, NULL, 'Alternate title', NULL), + (30, 45, 'Berlin Noir', NULL, 1, 2010, 989898, NULL, NULL, NULL, NULL, NULL), + (31, 44, 'Digital Connection', NULL, 1, 2005, NULL, NULL, NULL, NULL, NULL, NULL), + (32, 62, 'Animated Feature', NULL, 1, 2010, 123456, NULL, NULL, NULL, NULL, NULL); + +# 1a - Query with production companies and top 250 rank +query TTI +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t +WHERE ct.kind = 'production companies' + AND it.info = 'top 250 rank' + AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' and (mc.note like '%(co-production)%' or mc.note like '%(presents)%') + AND ct.id = mc.company_type_id + AND t.id = mc.movie_id + AND t.id = mi_idx.movie_id + AND mc.movie_id = mi_idx.movie_id + AND it.id = mi_idx.info_type_id +---- +(co-production) Avengers: Endgame 1985 + +# 1b - Query with production companies and bottom 10 rank +query TTI +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t +WHERE ct.kind = 'production companies' + AND it.info = 'bottom 10 rank' + AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' + AND t.production_year between 2005 and 2010 + AND ct.id = mc.company_type_id + AND t.id = mc.movie_id + AND t.id = mi_idx.movie_id + AND mc.movie_id = mi_idx.movie_id + AND it.id = mi_idx.info_type_id +---- +(as Warner Bros. Pictures) Bad Movie Sequel 2008 + +# 1c - Query with distributors and top 250 rank +query TTI +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t +WHERE ct.kind = 'production companies' + AND it.info = 'top 250 rank' + AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' and (mc.note like '%(co-production)%') + AND t.production_year >2010 + AND ct.id = mc.company_type_id + AND t.id = mc.movie_id + AND t.id = mi_idx.movie_id + AND mc.movie_id = mi_idx.movie_id + AND it.id = mi_idx.info_type_id +---- +(co-production) Avengers: Endgame 2014 + +# 1d - Query with production companies and top 250 rank (different production year) +query TTI +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t +WHERE ct.kind = 'production companies' + AND it.info = 'bottom 10 rank' + AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' + AND t.production_year >2000 + AND ct.id = mc.company_type_id + AND t.id = mc.movie_id + AND t.id = mi_idx.movie_id + AND mc.movie_id = mi_idx.movie_id + AND it.id = mi_idx.info_type_id +---- +(as Warner Bros. Pictures) Bad Movie Sequel 2008 + +# 2a - Query with German companies and character-name-in-title +query T +SELECT MIN(t.title) AS movie_title +FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t +WHERE cn.country_code ='[de]' + AND k.keyword ='character-name-in-title' + AND cn.id = mc.company_id + AND mc.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND mc.movie_id = mk.movie_id +---- +Joker + +# 2b - Query with Dutch companies and character-name-in-title +query T +SELECT MIN(t.title) AS movie_title +FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t +WHERE cn.country_code ='[nl]' + AND k.keyword ='character-name-in-title' + AND cn.id = mc.company_id + AND mc.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND mc.movie_id = mk.movie_id +---- +Iron Man + +# 2c - Query with Slovenian companies and female name in title +query T +SELECT MIN(t.title) AS movie_title +FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t +WHERE cn.country_code ='[sm]' + AND k.keyword ='character-name-in-title' + AND cn.id = mc.company_id + AND mc.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND mc.movie_id = mk.movie_id +---- +Joker + +# 2d - Query with US companies and murder movies +query T +SELECT MIN(t.title) AS movie_title +FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND cn.id = mc.company_id + AND mc.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND mc.movie_id = mk.movie_id +---- +Bruno + +# 3a - Query with runtimes > 100 +query T +SELECT MIN(t.title) AS movie_title +FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE k.keyword like '%sequel%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') + AND t.production_year > 2005 + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi.movie_id + AND k.id = mk.keyword_id +---- +The Godfather Part II + +# 3b - Query with Bulgarian movies +query T +SELECT MIN(t.title) AS movie_title +FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE k.keyword like '%sequel%' + AND mi.info IN ('Bulgaria') + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi.movie_id + AND k.id = mk.keyword_id +---- +The Dark Knight + +# 3c - Query with biographies +query T +SELECT MIN(t.title) AS movie_title +FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE k.keyword like '%sequel%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') + AND t.production_year > 1990 + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi.movie_id + AND k.id = mk.keyword_id +---- +Avengers: Endgame + +# 4a - Query with certain actor names +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title +FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it.info ='rating' + AND k.keyword like '%sequel%' + AND mi_idx.info > '5.0' + AND t.production_year > 2005 + AND t.id = mi_idx.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it.id = mi_idx.info_type_id +---- +8.9 Avengers: Endgame + +# 4b - Query with certain actor names (revised) +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title +FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it.info ='rating' + AND k.keyword like '%sequel%' + AND mi_idx.info > '9.0' + AND t.production_year > 2000 + AND t.id = mi_idx.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it.id = mi_idx.info_type_id +---- +9.1 The Dark Knight + +# 4c - Query with actors from certain period +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title +FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it.info ='rating' + AND k.keyword like '%sequel%' + AND mi_idx.info > '2.0' + AND t.production_year > 1990 + AND t.id = mi_idx.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it.id = mi_idx.info_type_id +---- +7.2 Avengers: Endgame + +# 5a - Query with keyword and movie links +query T +SELECT MIN(t.title) AS typical_european_movie +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t +WHERE ct.kind = 'production companies' + AND mc.note like '%(theatrical)%' and mc.note like '%(France)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') + AND t.production_year > 2005 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND mc.movie_id = mi.movie_id + AND ct.id = mc.company_type_id + AND it.id = mi.info_type_id +---- +The Matrix + +# 5b - Query with keyword and directors +query T +SELECT MIN(t.title) AS american_vhs_movie +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t +WHERE ct.kind = 'production companies' + AND mc.note like '%(VHS)%' and mc.note like '%(USA)%' and mc.note like '%(1994)%' + AND mi.info IN ('USA', 'America') + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND mc.movie_id = mi.movie_id + AND ct.id = mc.company_type_id + AND it.id = mi.info_type_id +---- +The Matrix + +# 5c - Query with female leading roles +query T +SELECT MIN(t.title) AS american_movie +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t +WHERE ct.kind = 'production companies' + AND mc.note not like '%(TV)%' and mc.note like '%(USA)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') + AND t.production_year > 1990 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND mc.movie_id = mi.movie_id + AND ct.id = mc.company_type_id + AND it.id = mi.info_type_id +---- +Champion Boxer + +# 6a - Query for Marvel movies with Robert Downey +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword = 'marvel-cinematic-universe' + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2010 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +marvel-cinematic-universe Downey Robert Jr. Avengers: Endgame + +# 6b - Query for male actors in movies after 2009 +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2014 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +sequel Downey Robert Jr. Avengers: Endgame + +# 6c - Query for superhero movies from specific year +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword = 'marvel-cinematic-universe' + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2014 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +marvel-cinematic-universe Downey Robert Jr. Avengers: Endgame + +# 6d - Query for specific director +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2000 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +based-on-comic Downey Robert Jr. Avengers: Endgame + +# 6e - Query for advanced superhero movies +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword = 'marvel-cinematic-universe' + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2000 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +marvel-cinematic-universe Downey Robert Jr. Avengers: Endgame + +# 6f - Query for complex superhero movies +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND t.production_year > 2000 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +based-on-comic Al Pacino Avengers: Endgame + +# 7a - Query about character names +query TT +SELECT MIN(n.name) AS of_person, MIN(t.title) AS biography_movie +FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t +WHERE an.name LIKE '%a%' + AND it.info ='mini biography' + AND lt.link ='features' + AND n.name_pcode_cf BETWEEN 'A' + AND 'F' + AND (n.gender='m' OR (n.gender = 'f' + AND n.name LIKE 'B%')) + AND pi.note ='Volker Boehm' + AND t.production_year BETWEEN 1980 + AND 1995 + AND n.id = an.person_id + AND n.id = pi.person_id + AND ci.person_id = n.id + AND t.id = ci.movie_id + AND ml.linked_movie_id = t.id + AND lt.id = ml.link_type_id + AND it.id = pi.info_type_id + AND pi.person_id = an.person_id + AND pi.person_id = ci.person_id + AND an.person_id = ci.person_id + AND ci.movie_id = ml.linked_movie_id -- #Al Pacino The Godfather +---- +Derek Jacobi Derek Jacobi Story + +# 7b - Query for person with biography +query TT +SELECT MIN(n.name) AS of_person, MIN(t.title) AS biography_movie +FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t +WHERE an.name LIKE '%a%' + AND it.info ='mini biography' + AND lt.link ='features' + AND n.name_pcode_cf LIKE 'D%' + AND n.gender='m' + AND pi.note ='Volker Boehm' + AND t.production_year BETWEEN 1980 + AND 1984 + AND n.id = an.person_id + AND n.id = pi.person_id + AND ci.person_id = n.id + AND t.id = ci.movie_id + AND ml.linked_movie_id = t.id + AND lt.id = ml.link_type_id + AND it.id = pi.info_type_id + AND pi.person_id = an.person_id + AND pi.person_id = ci.person_id + AND an.person_id = ci.person_id + AND ci.movie_id = ml.linked_movie_id +---- +Derek Jacobi Derek Jacobi Story + +# 7c - Query for extended character names and biographies +query TT +SELECT MIN(n.name) AS cast_member_name, MIN(pi.info) AS cast_member_info +FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t +WHERE an.name is not NULL and (an.name LIKE '%a%' or an.name LIKE 'A%') + AND it.info ='mini biography' + AND lt.link in ('references', 'referenced in', 'features', 'featured in') + AND n.name_pcode_cf BETWEEN 'A' + AND 'F' + AND (n.gender='m' OR (n.gender = 'f' + AND n.name LIKE 'A%')) + AND pi.note is not NULL + AND t.production_year BETWEEN 1980 + AND 2010 + AND n.id = an.person_id + AND n.id = pi.person_id + AND ci.person_id = n.id + AND t.id = ci.movie_id + AND ml.linked_movie_id = t.id + AND lt.id = ml.link_type_id + AND it.id = pi.info_type_id + AND pi.person_id = an.person_id + AND pi.person_id = ci.person_id + AND an.person_id = ci.person_id + AND ci.movie_id = ml.linked_movie_id +---- +Al Pacino Famous for his role in The Godfather + +# 8a - Find movies by keyword +query TT +SELECT MIN(an1.name) AS actress_pseudonym, MIN(t.title) AS japanese_movie_dubbed +FROM aka_name AS an1, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t +WHERE ci.note ='(voice: English version)' + AND cn.country_code ='[jp]' + AND mc.note like '%(Japan)%' and mc.note not like '%(USA)%' + AND n1.name like '%Yo%' and n1.name not like '%Yu%' + AND rt.role ='actress' + AND an1.person_id = n1.id + AND n1.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND an1.person_id = ci.person_id + AND ci.movie_id = mc.movie_id +---- +Yoko Shimizu One Piece: Grand Adventure + +# 8b - Query for anime voice actors +query TT +SELECT MIN(an.name) AS acress_pseudonym, MIN(t.title) AS japanese_anime_movie +FROM aka_name AS an, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note ='(voice: English version)' + AND cn.country_code ='[jp]' + AND mc.note like '%(Japan)%' and mc.note not like '%(USA)%' and (mc.note like '%(2006)%' or mc.note like '%(2007)%') + AND n.name like '%Yo%' and n.name not like '%Yu%' + AND rt.role ='actress' + AND t.production_year between 2006 and 2007 and (t.title like 'One Piece%' or t.title like 'Dragon Ball Z%') + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id +---- +Yoko Shimizu One Piece: Grand Adventure + +# 8c - Query for extended movies by keyword and voice actors +query TT +SELECT MIN(a1.name) AS writer_pseudo_name, MIN(t.title) AS movie_title +FROM aka_name AS a1, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t +WHERE cn.country_code ='[us]' + AND rt.role ='writer' + AND a1.person_id = n1.id + AND n1.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND a1.person_id = ci.person_id + AND ci.movie_id = mc.movie_id +---- +Jim Cameron Titanic + +# 8d - Query for specialized movies by keyword and voice actors +query TT +SELECT MIN(an1.name) AS costume_designer_pseudo, MIN(t.title) AS movie_with_costumes +FROM aka_name AS an1, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t +WHERE cn.country_code ='[us]' + AND rt.role ='costume designer' + AND an1.person_id = n1.id + AND n1.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND an1.person_id = ci.person_id + AND ci.movie_id = mc.movie_id +---- +E. Head Avengers: Endgame + +# 9a - Query for movie sequels +query TTT +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS character_name, MIN(t.title) AS movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND mc.note is not NULL and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') + AND n.gender ='f' and n.name like '%Ang%' + AND rt.role ='actress' + AND t.production_year between 2005 and 2015 + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND ci.movie_id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND n.id = ci.person_id + AND chn.id = ci.person_role_id + AND an.person_id = n.id + AND an.person_id = ci.person_id +---- +Angelina Jolie Batman's Assistant Kung Fu Panda + +# 9b - Query for voice actors in American movies +query TTTT +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_character, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note = '(voice)' + AND cn.country_code ='[us]' + AND mc.note like '%(200%)%' and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') + AND n.gender ='f' and n.name like '%Angel%' + AND rt.role ='actress' + AND t.production_year between 2007 and 2010 + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND ci.movie_id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND n.id = ci.person_id + AND chn.id = ci.person_role_id + AND an.person_id = n.id + AND an.person_id = ci.person_id +---- +Angelina Jolie Batman's Assistant Angelina Jolie Kung Fu Panda + +# 9c - Query for extended movie sequels and voice actors +query TTTT +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_character_name, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND ci.movie_id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND n.id = ci.person_id + AND chn.id = ci.person_role_id + AND an.person_id = n.id + AND an.person_id = ci.person_id +---- +Alexander Morgan Batman's Assistant Angelina Jolie Dragon Warriors + +# 9d - Query for specialized movie sequels and voice actors +query TTTT +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND n.gender ='f' + AND rt.role ='actress' + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND ci.movie_id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND n.id = ci.person_id + AND chn.id = ci.person_role_id + AND an.person_id = n.id + AND an.person_id = ci.person_id +---- +Alexander Morgan Batman's Assistant Angelina Jolie Dragon Warriors + +# 10a - Query for cast combinations +query TT +SELECT MIN(chn.name) AS uncredited_voiced_character, MIN(t.title) AS russian_movie +FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t +WHERE ci.note like '%(voice)%' and ci.note like '%(uncredited)%' + AND cn.country_code = '[ru]' + AND rt.role = 'actor' + AND t.production_year > 2005 + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mc.movie_id + AND chn.id = ci.person_role_id + AND rt.id = ci.role_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +Nikolai Moscow Nights + +# 10b - Query for Russian movie producers who are also actors +query TT +SELECT MIN(chn.name) AS character, MIN(t.title) AS russian_mov_with_actor_producer +FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t +WHERE ci.note like '%(producer)%' + AND cn.country_code = '[ru]' + AND rt.role = 'actor' + AND t.production_year > 2000 + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mc.movie_id + AND chn.id = ci.person_role_id + AND rt.id = ci.role_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +Director Moscow Nights + +# 10c - Query for American producers in movies +query TT +SELECT MIN(chn.name) AS character, MIN(t.title) AS movie_with_american_producer +FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t +WHERE ci.note like '%(producer)%' + AND cn.country_code = '[us]' + AND t.production_year > 1990 + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mc.movie_id + AND chn.id = ci.person_role_id + AND rt.id = ci.role_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +Bruce Wayne The Dark Knight + +# 11a - Query for non-Polish companies with sequels +query TTT +SELECT MIN(cn.name) AS from_company, MIN(lt.link) AS movie_link_type, MIN(t.title) AS non_polish_sequel_movie +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND t.production_year BETWEEN 1950 + AND 2000 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id +---- +Warner Bros. follows Money Talks + +# 11b - Query for non-Polish companies with Money sequels from 1998 +query TTT +SELECT MIN(cn.name) AS from_company, MIN(lt.link) AS movie_link_type, MIN(t.title) AS sequel_movie +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follows%' + AND mc.note IS NULL + AND t.production_year = 1998 and t.title like '%Money%' + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id +---- +Warner Bros. Pictures follows Money Talks + +# 11c - Query for Fox movies based on novels +query TTT +SELECT MIN(cn.name) AS from_company, MIN(mc.note) AS production_note, MIN(t.title) AS movie_based_on_book +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' and (cn.name like '20th Century Fox%' or cn.name like 'Twentieth Century Fox%') + AND ct.kind != 'production companies' and ct.kind is not NULL + AND k.keyword in ('sequel', 'revenge', 'based-on-novel') + AND mc.note is not NULL + AND t.production_year > 1950 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id +---- +Twentieth Century Fox (distribution) (DVD) (US) Fox Novel Movie + +# 11d - Query for movies based on novels from non-Polish companies +query TTT +SELECT MIN(cn.name) AS from_company, MIN(mc.note) AS production_note, MIN(t.title) AS movie_based_on_book +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND ct.kind != 'production companies' and ct.kind is not NULL + AND k.keyword in ('sequel', 'revenge', 'based-on-novel') + AND mc.note is not NULL + AND t.production_year > 1950 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id +---- +Marvel Studios (as Marvel Studios) Avengers: Endgame + +# 12a - Query for cast in movies with specific genres +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS drama_horror_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t +WHERE cn.country_code = '[us]' + AND ct.kind = 'production companies' + AND it1.info = 'genres' + AND it2.info = 'rating' + AND mi.info in ('Drama', 'Horror') + AND mi_idx.info > '8.0' + AND t.production_year between 2005 and 2008 + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND mi.info_type_id = it1.id + AND mi_idx.info_type_id = it2.id + AND t.id = mc.movie_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id +---- +Warner Bros. 9.5 The Dark Knight + +# 12b - Query for unsuccessful movies with specific budget criteria +query TT +SELECT MIN(mi.info) AS budget, MIN(t.title) AS unsuccsessful_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t +WHERE cn.country_code ='[us]' + AND ct.kind is not NULL and (ct.kind ='production companies' or ct.kind = 'distributors') + AND it1.info ='budget' + AND it2.info ='bottom 10 rank' + AND t.production_year >2000 + AND (t.title LIKE 'Birdemic%' OR t.title LIKE '%Movie%') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND mi.info_type_id = it1.id + AND mi_idx.info_type_id = it2.id + AND t.id = mc.movie_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id +---- +$500,000 Bad Movie Sequel + +# 12c - Query for highly rated mainstream movies +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS mainstream_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t +WHERE cn.country_code = '[us]' + AND ct.kind = 'production companies' + AND it1.info = 'genres' + AND it2.info = 'rating' + AND mi.info in ('Drama', 'Horror', 'Western', 'Family') + AND mi_idx.info > '7.0' + AND t.production_year between 2000 and 2010 + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND mi.info_type_id = it1.id + AND mi_idx.info_type_id = it2.id + AND t.id = mc.movie_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id +---- +Warner Bros. 9.5 The Dark Knight + +# 13a - Query for movies with specific genre combinations +query TTT +SELECT MIN(mi.info) AS release_date, MIN(miidx.info) AS rating, MIN(t.title) AS german_movie +FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t +WHERE cn.country_code ='[de]' + AND ct.kind ='production companies' + AND it.info ='rating' + AND it2.info ='release dates' + AND kt.kind ='movie' + AND mi.movie_id = t.id + AND it2.id = mi.info_type_id + AND kt.id = t.kind_id + AND mc.movie_id = t.id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND miidx.movie_id = t.id + AND it.id = miidx.info_type_id + AND mi.movie_id = miidx.movie_id + AND mi.movie_id = mc.movie_id + AND miidx.movie_id = mc.movie_id +---- +2005-09-15 7.2 Dark Blood + +# 13b - Query for movies about winning with specific criteria +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie_about_winning +FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t +WHERE cn.country_code ='[us]' + AND ct.kind ='production companies' + AND it.info ='rating' + AND it2.info ='release dates' + AND kt.kind ='movie' + AND t.title != '' + AND (t.title LIKE '%Champion%' OR t.title LIKE '%Loser%') + AND mi.movie_id = t.id + AND it2.id = mi.info_type_id + AND kt.id = t.kind_id + AND mc.movie_id = t.id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND miidx.movie_id = t.id + AND it.id = miidx.info_type_id + AND mi.movie_id = miidx.movie_id + AND mi.movie_id = mc.movie_id + AND miidx.movie_id = mc.movie_id +---- +Universal Pictures 7.5 Champion Boxer + +# 13c - Query for movies with Champion in the title +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie_about_winning +FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t +WHERE cn.country_code ='[us]' + AND ct.kind ='production companies' + AND it.info ='rating' + AND it2.info ='release dates' + AND kt.kind ='movie' + AND t.title != '' + AND (t.title LIKE 'Champion%' OR t.title LIKE 'Loser%') + AND mi.movie_id = t.id + AND it2.id = mi.info_type_id + AND kt.id = t.kind_id + AND mc.movie_id = t.id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND miidx.movie_id = t.id + AND it.id = miidx.info_type_id + AND mi.movie_id = miidx.movie_id + AND mi.movie_id = mc.movie_id + AND miidx.movie_id = mc.movie_id +---- +Universal Pictures 7.5 Champion Boxer + +# 13d - Query for all US movies +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie +FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t +WHERE cn.country_code ='[us]' + AND ct.kind ='production companies' + AND it.info ='rating' + AND it2.info ='release dates' + AND kt.kind ='movie' + AND mi.movie_id = t.id + AND it2.id = mi.info_type_id + AND kt.id = t.kind_id + AND mc.movie_id = t.id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND miidx.movie_id = t.id + AND it.id = miidx.info_type_id + AND mi.movie_id = miidx.movie_id + AND mi.movie_id = mc.movie_id + AND miidx.movie_id = mc.movie_id +---- +Marvel Studios 7.5 Avengers: Endgame + +# 14a - Query for actors in specific movie types +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS northern_dark_movie +FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind = 'movie' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2010 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +7.5 Nordic Noir + +# 14b - Query for dark western productions with specific criteria +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS western_dark_production +FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title') + AND kt.kind = 'movie' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info > '6.0' + AND t.production_year > 2010 and (t.title like '%murder%' or t.title like '%Murder%' or t.title like '%Mord%') + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +7.8 The Swedish Murder Case + +# 14c - Query for extended movie types and dark themes +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS north_european_dark_production +FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword is not null and k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +6.8 Berlin Noir + +# 15a - Query for US movies with internet releases +query TT +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS internet_movie +FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cn.country_code = '[us]' + AND it1.info = 'release dates' + AND mc.note like '%(200%)%' and mc.note like '%(worldwide)%' + AND mi.note like '%internet%' + AND mi.info like 'USA:% 200%' + AND t.production_year > 2000 + AND t.id = at.movie_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = at.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = at.movie_id + AND mc.movie_id = at.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +USA: 2008-05-15 The Dark Knight + +# 15b - Query for YouTube movies with specific release criteria +query TT +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS youtube_movie +FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cn.country_code = '[us]' and cn.name = 'YouTube' + AND it1.info = 'release dates' + AND mc.note like '%(200%)%' and mc.note like '%(worldwide)%' + AND mi.note like '%internet%' + AND mi.info like 'USA:% 200%' + AND t.production_year between 2005 and 2010 + AND t.id = at.movie_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = at.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = at.movie_id + AND mc.movie_id = at.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +USA: 2008-05-15 YouTube Documentary + +# 15c - Query for extended internet releases +query TT +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS modern_american_internet_movie +FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cn.country_code = '[us]' + AND it1.info = 'release dates' + AND mi.note like '%internet%' + AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') + AND t.production_year > 1990 + AND t.id = at.movie_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = at.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = at.movie_id + AND mc.movie_id = at.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +USA: 15 May 2005 Digital Connection + +# 15d - Query for specialized internet releases +query TT +SELECT MIN(at.title) AS aka_title, MIN(t.title) AS internet_movie_title +FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cn.country_code = '[us]' + AND it1.info = 'release dates' + AND mi.note like '%internet%' + AND t.production_year > 1990 + AND t.id = at.movie_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = at.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = at.movie_id + AND mc.movie_id = at.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +Avengers 4 Avengers: Endgame + +# 16a - Query for movies in specific languages +query TT +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char +FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND t.episode_nr >= 50 + AND t.episode_nr < 100 + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alexander Morgan Character Series + +# 16b - Query for series named after characters +query TT +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char +FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alexander Morgan Character Series + +# 16c - Query for extended languages and character-named series +query TT +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char +FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND t.episode_nr < 100 + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alexander Morgan Character Series + +# 16d - Query for specialized languages and character-named series +query TT +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char +FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND t.episode_nr >= 5 + AND t.episode_nr < 100 + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alexander Morgan Character Series + +# 17a - Query for actor/actress combinations +query TT +SELECT MIN(n.name) AS member_in_charnamed_american_movie, MIN(n.name) AS a1 +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND n.name LIKE 'B%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Bert Wilson Bert Wilson + +# 17b - Query for actors with names starting with Z in character-named movies +query TT +SELECT MIN(n.name) AS member_in_charnamed_movie, MIN(n.name) AS a1 +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword ='character-name-in-title' + AND n.name LIKE 'Z%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Zach Wilson Zach Wilson + +# 17c - Query for extended actor/actress combinations +query TT +SELECT MIN(n.name) AS member_in_charnamed_movie, MIN(n.name) AS a1 +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword ='character-name-in-title' + AND n.name LIKE 'X%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Xavier Thompson Xavier Thompson + +# 17d - Query for specialized actor/actress combinations +query T +SELECT MIN(n.name) AS member_in_charnamed_movie +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword ='character-name-in-title' + AND n.name LIKE '%Bert%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Bert Wilson + +# 17e - Query for advanced actor/actress combinations +query T +SELECT MIN(n.name) AS member_in_charnamed_movie +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alex Morgan + +# 17f - Query for complex actor/actress combinations +query T +SELECT MIN(n.name) AS member_in_charnamed_movie +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword ='character-name-in-title' + AND n.name LIKE '%B%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Bert Wilson + +# 18a - Query with complex genre filtering +query TTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t +WHERE ci.note in ('(producer)', '(executive producer)') + AND it1.info = 'budget' + AND it2.info = 'votes' + AND n.gender = 'm' and n.name like '%Tim%' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +$25,000,000 2,345,678 The Shawshank Redemption + +# 18b - Query for horror movies by female writers with high ratings +query TTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'rating' + AND mi.info in ('Horror', 'Thriller') and mi.note is NULL + AND mi_idx.info > '8.0' + AND n.gender is not null and n.gender = 'f' + AND t.production_year between 2008 and 2014 + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +Horror 8.5 Woman in Black + +# 18c - Query for extended genre filtering +query TTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') + AND n.gender = 'm' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +Horror 1000 Halloween + +# 19a - Query for character name patterns +query TT +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND mc.note is not NULL and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') + AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') + AND n.gender ='f' and n.name like '%Ang%' + AND rt.role ='actress' + AND t.production_year between 2005 and 2009 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mi.movie_id = ci.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id +---- +Angelina Jolie Kung Fu Panda + +# 19b - Query for Angelina Jolie as voice actress in Kung Fu Panda series +query TT +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS kung_fu_panda +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t +WHERE ci.note = '(voice)' + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND mc.note like '%(200%)%' and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') + AND mi.info is not null and (mi.info like 'Japan:%2007%' or mi.info like 'USA:%2008%') + AND n.gender ='f' and n.name like '%Angel%' + AND rt.role ='actress' + AND t.production_year between 2007 and 2008 and t.title like '%Kung%Fu%Panda%' + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mi.movie_id = ci.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id +---- +Angelina Jolie Kung Fu Panda + +# 19c - Query for extended character patterns +query TT +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS jap_engl_voiced_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mi.movie_id = ci.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id +---- +Angelina Jolie Kung Fu Panda + +# 19d - Query for specialized character patterns +query TT +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS jap_engl_voiced_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND n.gender ='f' + AND rt.role ='actress' + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mi.movie_id = ci.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id +---- +Angelina Jolie Kung Fu Panda + +# 20a - Query for movies with specific actor roles +query T +SELECT MIN(t.title) AS complete_downey_ironman_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name not like '%Sherlock%' and (chn.name like '%Tony%Stark%' or chn.name like '%Iron%Man%') + AND k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND kt.kind = 'movie' + AND t.production_year > 1950 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND ci.movie_id = cc.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Iron Man + +# 20b - Query for complete Downey Iron Man movies +query T +SELECT MIN(t.title) AS complete_downey_ironman_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name not like '%Sherlock%' and (chn.name like '%Tony%Stark%' or chn.name like '%Iron%Man%') + AND k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND kt.kind = 'movie' + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND ci.movie_id = cc.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Iron Man + +# 20c - Query for extended specific actor roles +query TT +SELECT MIN(n.name) AS cast_member, MIN(t.title) AS complete_dynamic_hero_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') + AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') + AND kt.kind = 'movie' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND ci.movie_id = cc.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Downey Robert Jr. Iron Man + +# 21a - Query for movies with specific production years +query TTT +SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS western_follow_up +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') + AND t.production_year BETWEEN 1950 + AND 2000 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id +---- +Warner Bros. Pictures follows The Western Sequel + +# 21b - Query for German follow-up movies +query TTT +SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS german_follow_up +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Germany', 'German') + AND t.production_year BETWEEN 2000 + AND 2010 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id +---- +Berlin Film Studio follows Dark Blood + +# 21c - Query for extended specific production years +query TTT +SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS western_follow_up +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'English') + AND t.production_year BETWEEN 1950 + AND 2010 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id +---- +Berlin Film Studio follows Dark Blood + +# 22a - Query for movies with specific actor roles +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Germany', 'German', 'USA', 'American') + AND mi_idx.info < '7.0' + AND t.production_year > 2008 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id +---- +Berlin Film Studio 6.8 Berlin Noir + +# 22b - Query for western violent movies by non-US companies +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Germany', 'German', 'USA', 'American') + AND mi_idx.info < '7.0' + AND t.production_year > 2009 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id +---- +Berlin Film Studio 6.8 Berlin Noir + +# 22c - Query for extended actor roles +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id +---- +Berlin Film Studio 6.8 Berlin Noir + +# 22d - Query for specialized actor roles +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id +---- +Berlin Film Studio 6.8 Berlin Noir + +# 23a - Query for sequels with specific character names +query TT +SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_us_internet_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'complete+verified' + AND cn.country_code = '[us]' + AND it1.info = 'release dates' + AND kt.kind in ('movie') + AND mi.note like '%internet%' + AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND cct1.id = cc.status_id +---- +movie Digital Connection + +# 23b - Query for complete nerdy internet movies +query TT +SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_nerdy_internet_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'complete+verified' + AND cn.country_code = '[us]' + AND it1.info = 'release dates' + AND k.keyword in ('nerd', 'loner', 'alienation', 'dignity') + AND kt.kind in ('movie') + AND mi.note like '%internet%' + AND mi.info like 'USA:% 200%' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND cct1.id = cc.status_id +---- +movie Digital Connection + +# 23c - Query for extended sequels with specific attributes +query TT +SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_us_internet_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'complete+verified' + AND cn.country_code = '[us]' + AND it1.info = 'release dates' + AND kt.kind in ('movie', 'tv movie', 'video movie', 'video game') + AND mi.note like '%internet%' + AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') + AND t.production_year > 1990 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND cct1.id = cc.status_id +---- +movie Digital Connection + +# 24a - Query for movies with specific budgets +query TTT +SELECT MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress_name, MIN(t.title) AS voiced_action_movie_jap_eng +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND k.keyword in ('hero', 'martial-arts', 'hand-to-hand-combat') + AND mi.info is not null and (mi.info like 'Japan:%201%' or mi.info like 'USA:%201%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.production_year > 2010 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND ci.movie_id = mk.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND k.id = mk.keyword_id +---- +Batman's Assistant Angelina Jolie Kung Fu Panda 2 + +# 24b - Query for voiced characters in Kung Fu Panda +query TTT +SELECT MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress_name, MIN(t.title) AS kung_fu_panda +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND cn.name = 'DreamWorks Animation' + AND it.info = 'release dates' + AND k.keyword in ('hero', 'martial-arts', 'hand-to-hand-combat', 'computer-animated-movie') + AND mi.info is not null and (mi.info like 'Japan:%201%' or mi.info like 'USA:%201%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.production_year > 2010 + AND t.title like 'Kung Fu Panda%' + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND ci.movie_id = mk.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND k.id = mk.keyword_id +---- +Batman's Assistant Angelina Jolie Kung Fu Panda 2 + +# 25a - Query for cast combinations in specific movies +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'blood', 'gore', 'death', 'female-nudity') + AND mi.info = 'Horror' + AND n.gender = 'm' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi_idx.movie_id = mk.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id +---- +Horror 1000 Christian Bale Halloween + +# 25b - Query for violent horror films with male writers +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'blood', 'gore', 'death', 'female-nudity') + AND mi.info = 'Horror' + AND n.gender = 'm' + AND t.production_year > 2010 + AND t.title like 'Vampire%' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi_idx.movie_id = mk.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id +---- +Horror 1000 Christian Bale Vampire Chronicles + +# 25c - Query for extended cast combinations +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') + AND n.gender = 'm' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi_idx.movie_id = mk.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id +---- +Horror 1000 Christian Bale Halloween + +# 26a - Query for specific movie genres with ratings +query TTTT +SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(n.name) AS playing_actor, MIN(t.title) AS complete_hero_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') + AND it2.info = 'rating' + AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') + AND kt.kind = 'movie' + AND mi_idx.info > '7.0' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND mk.movie_id = mi_idx.movie_id + AND ci.movie_id = cc.movie_id + AND ci.movie_id = mi_idx.movie_id + AND cc.movie_id = mi_idx.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND it2.id = mi_idx.info_type_id +---- +Ironman 8.5 John Carpenter Marvel Superhero Epic + +# 26b - Query for complete hero movies with Man in character name +query TTT +SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_hero_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') + AND it2.info = 'rating' + AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'fight') + AND kt.kind = 'movie' + AND mi_idx.info > '8.0' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND mk.movie_id = mi_idx.movie_id + AND ci.movie_id = cc.movie_id + AND ci.movie_id = mi_idx.movie_id + AND cc.movie_id = mi_idx.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND it2.id = mi_idx.info_type_id +---- +Ironman 8.5 Marvel Superhero Epic + +# 26c - Query for extended movie genres and ratings +query TTT +SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_hero_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') + AND it2.info = 'rating' + AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') + AND kt.kind = 'movie' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND mk.movie_id = mi_idx.movie_id + AND ci.movie_id = cc.movie_id + AND ci.movie_id = mi_idx.movie_id + AND cc.movie_id = mi_idx.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND it2.id = mi_idx.info_type_id +---- +Ironman 8.5 Marvel Superhero Epic + +# 27a - Query for movies with specific person roles +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cct1.kind in ('cast', 'crew') + AND cct2.kind = 'complete' + AND cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Germany','Swedish', 'German') + AND t.production_year BETWEEN 1950 + AND 2000 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND t.id = cc.movie_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id + AND ml.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = cc.movie_id +---- +Warner Bros. Pictures follows The Western Sequel + +# 27b - Query for complete western sequel films by non-Polish companies +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cct1.kind in ('cast', 'crew') + AND cct2.kind = 'complete' + AND cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Germany','Swedish', 'German') + AND t.production_year = 1998 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND t.id = cc.movie_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id + AND ml.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = cc.movie_id +---- +Warner Bros. Pictures follows The Western Sequel + +# 27c - Query for extended person roles +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like 'complete%' + AND cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'English') + AND t.production_year BETWEEN 1950 + AND 2010 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND t.id = cc.movie_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id + AND ml.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = cc.movie_id +---- +Warner Bros. Pictures follows The Western Sequel + +# 28a - Query for movies with specific production years +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'crew' + AND cct2.kind != 'complete+verified' + AND cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mc.movie_id = cc.movie_id + AND mi_idx.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Stockholm Productions 7.8 The Nordic Murders + +# 28b - Query for Euro dark movies with complete crew +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'crew' + AND cct2.kind != 'complete+verified' + AND cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Sweden', 'Germany', 'Swedish', 'German') + AND mi_idx.info > '6.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mc.movie_id = cc.movie_id + AND mi_idx.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Stockholm Productions 7.8 The Nordic Murders + +# 28c - Query for extended movies with specific criteria +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind = 'complete' + AND cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mc.movie_id = cc.movie_id + AND mi_idx.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Oslo Films 7.5 Scandinavian Crime + +# 29a - Query for movies with specific combinations +query TTT +SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation +FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t +WHERE cct1.kind ='cast' + AND cct2.kind ='complete+verified' + AND chn.name = 'Queen' + AND ci.note in ('(voice)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND it3.info = 'trivia' + AND k.keyword = 'computer-animation' + AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.title = 'Shrek 2' + AND t.production_year between 2000 and 2010 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND n.id = pi.person_id + AND ci.person_id = pi.person_id + AND it3.id = pi.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Queen Anne Hathaway Shrek 2 + +# 29b - Query for specific Queen character voice actress +query TTT +SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation +FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t +WHERE cct1.kind ='cast' + AND cct2.kind ='complete+verified' + AND chn.name = 'Queen' + AND ci.note in ('(voice)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND it3.info = 'height' + AND k.keyword = 'computer-animation' + AND mi.info like 'USA:%200%' + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.title = 'Shrek 2' + AND t.production_year between 2000 and 2005 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND n.id = pi.person_id + AND ci.person_id = pi.person_id + AND it3.id = pi.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Queen Anne Hathaway Shrek 2 + +# 29c - Query for extended specific combinations +query TTT +SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation +FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t +WHERE cct1.kind ='cast' + AND cct2.kind ='complete+verified' + AND ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND it3.info = 'trivia' + AND k.keyword = 'computer-animation' + AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.production_year between 2000 and 2010 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND n.id = pi.person_id + AND ci.person_id = pi.person_id + AND it3.id = pi.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Queen Anne Hathaway Shrek 2 + +# 30a - Query for top-rated action movies +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_violent_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind in ('cast', 'crew') + AND cct2.kind ='complete+verified' + AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Thriller') + AND n.gender = 'm' + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Horror 52000 James Wan Saw IV + +# 30b - Query for ratings of female-cast-only movies +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_gore_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind in ('cast', 'crew') + AND cct2.kind ='complete+verified' + AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Thriller') + AND n.gender = 'm' + AND t.production_year > 2000 and (t.title like '%Freddy%' or t.title like '%Jason%' or t.title like 'Saw%') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Horror 52000 James Wan Saw IV + +# 30c - Query for extended action movies +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_violent_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind ='complete+verified' + AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Horror 52000 James Wan Saw IV + +# 31a - Query for movies with specific language and production values +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie +FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND cn.name like 'Lionsgate%' + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Thriller') + AND n.gender = 'm' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = mc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cn.id = mc.company_id +---- +Horror 45000 James Wan Halloween + +# 31b - Query for sci-fi female-focused movies +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie +FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND cn.name like 'Lionsgate%' + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mc.note like '%(Blu-ray)%' + AND mi.info in ('Horror', 'Thriller') + AND n.gender = 'm' + AND t.production_year > 2000 and (t.title like '%Freddy%' or t.title like '%Jason%' or t.title like 'Saw%') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = mc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cn.id = mc.company_id +---- +Horror 52000 James Wan Saw IV + +# 31c - Query for extended language and production values +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie +FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND cn.name like 'Lionsgate%' + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = mc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cn.id = mc.company_id +---- +Horror 45000 James Wan Halloween + +# 32a - Query for action movies with specific actor roles +query TTT +SELECT MIN(lt.link) AS link_type, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM keyword AS k, link_type AS lt, movie_keyword AS mk, movie_link AS ml, title AS t1, title AS t2 +WHERE k.keyword ='10,000-mile-club' + AND mk.keyword_id = k.id + AND t1.id = mk.movie_id + AND ml.movie_id = t1.id + AND ml.linked_movie_id = t2.id + AND lt.id = ml.link_type_id + AND mk.movie_id = t1.id +---- +edited into Interstellar Saving Private Ryan + +# 32b - Query for character-name-in-title movies and their connections +query TTT +SELECT MIN(lt.link) AS link_type, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM keyword AS k, link_type AS lt, movie_keyword AS mk, movie_link AS ml, title AS t1, title AS t2 +WHERE k.keyword ='character-name-in-title' + AND mk.keyword_id = k.id + AND t1.id = mk.movie_id + AND ml.movie_id = t1.id + AND ml.linked_movie_id = t2.id + AND lt.id = ml.link_type_id + AND mk.movie_id = t1.id +---- +featured in Iron Man Avengers: Endgame + +# 33a - Query for directors of sequels with specific ratings +query TTTTTT +SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 +WHERE cn1.country_code = '[us]' + AND it1.info = 'rating' + AND it2.info = 'rating' + AND kt1.kind in ('tv series') + AND kt2.kind in ('tv series') + AND lt.link in ('sequel', 'follows', 'followed by') + AND mi_idx2.info < '3.0' + AND t2.production_year between 2005 and 2008 + AND lt.id = ml.link_type_id + AND t1.id = ml.movie_id + AND t2.id = ml.linked_movie_id + AND it1.id = mi_idx1.info_type_id + AND t1.id = mi_idx1.movie_id + AND kt1.id = t1.kind_id + AND cn1.id = mc1.company_id + AND t1.id = mc1.movie_id + AND ml.movie_id = mi_idx1.movie_id + AND ml.movie_id = mc1.movie_id + AND mi_idx1.movie_id = mc1.movie_id + AND it2.id = mi_idx2.info_type_id + AND t2.id = mi_idx2.movie_id + AND kt2.id = t2.kind_id + AND cn2.id = mc2.company_id + AND t2.id = mc2.movie_id + AND ml.linked_movie_id = mi_idx2.movie_id + AND ml.linked_movie_id = mc2.movie_id + AND mi_idx2.movie_id = mc2.movie_id +---- +Paramount Pictures Paramount Pictures 8.2 2.8 Breaking Bad Breaking Bad: The Final Season + +# 33b - Query for linked TV series by country code +query TTTTTT +SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 +WHERE cn1.country_code = '[nl]' + AND it1.info = 'rating' + AND it2.info = 'rating' + AND kt1.kind in ('tv series') + AND kt2.kind in ('tv series') + AND lt.link LIKE '%follow%' + AND mi_idx2.info < '3.0' + AND t2.production_year = 2007 + AND lt.id = ml.link_type_id + AND t1.id = ml.movie_id + AND t2.id = ml.linked_movie_id + AND it1.id = mi_idx1.info_type_id + AND t1.id = mi_idx1.movie_id + AND kt1.id = t1.kind_id + AND cn1.id = mc1.company_id + AND t1.id = mc1.movie_id + AND ml.movie_id = mi_idx1.movie_id + AND ml.movie_id = mc1.movie_id + AND mi_idx1.movie_id = mc1.movie_id + AND it2.id = mi_idx2.info_type_id + AND t2.id = mi_idx2.movie_id + AND kt2.id = t2.kind_id + AND cn2.id = mc2.company_id + AND t2.id = mc2.movie_id + AND ml.linked_movie_id = mi_idx2.movie_id + AND ml.linked_movie_id = mc2.movie_id + AND mi_idx2.movie_id = mc2.movie_id +---- +Dutch Entertainment Group Amsterdam Studios 8.5 2.5 Amsterdam Detective Amsterdam Detective: Cold Case + +# 33c - Query for linked TV series and episodes with specific ratings +query TTTTTT +SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 +WHERE cn1.country_code != '[us]' + AND it1.info = 'rating' + AND it2.info = 'rating' + AND kt1.kind in ('tv series', 'episode') + AND kt2.kind in ('tv series', 'episode') + AND lt.link in ('sequel', 'follows', 'followed by') + AND mi_idx2.info < '3.5' + AND t2.production_year between 2000 and 2010 + AND lt.id = ml.link_type_id + AND t1.id = ml.movie_id + AND t2.id = ml.linked_movie_id + AND it1.id = mi_idx1.info_type_id + AND t1.id = mi_idx1.movie_id + AND kt1.id = t1.kind_id + AND cn1.id = mc1.company_id + AND t1.id = mc1.movie_id + AND ml.movie_id = mi_idx1.movie_id + AND ml.movie_id = mc1.movie_id + AND mi_idx1.movie_id = mc1.movie_id + AND it2.id = mi_idx2.info_type_id + AND t2.id = mi_idx2.movie_id + AND kt2.id = t2.kind_id + AND cn2.id = mc2.company_id + AND t2.id = mc2.movie_id + AND ml.linked_movie_id = mi_idx2.movie_id + AND ml.linked_movie_id = mc2.movie_id + AND mi_idx2.movie_id = mc2.movie_id +---- +Dutch Entertainment Group Amsterdam Studios 8.5 2.5 Amsterdam Detective Amsterdam Detective: Cold Case + +# Clean up all tables +statement ok +DROP TABLE company_type; + +statement ok +DROP TABLE info_type; + +statement ok +DROP TABLE title; + +statement ok +DROP TABLE movie_companies; + +statement ok +DROP TABLE movie_info_idx; + +statement ok +DROP TABLE movie_info; + +statement ok +DROP TABLE kind_type; + +statement ok +DROP TABLE cast_info; + +statement ok +DROP TABLE char_name; + +statement ok +DROP TABLE keyword; + +statement ok +DROP TABLE movie_keyword; + +statement ok +DROP TABLE company_name; + +statement ok +DROP TABLE name; + +statement ok +DROP TABLE role_type; + +statement ok +DROP TABLE link_type; + +statement ok +DROP TABLE movie_link; + +statement ok +DROP TABLE complete_cast; + +statement ok +DROP TABLE comp_cast_type; + +statement ok +DROP TABLE person_info; + +statement ok +DROP TABLE aka_title; + +statement ok +DROP TABLE aka_name; diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 496f24abf6ed7..670992633bb85 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -149,6 +149,39 @@ drop table t statement ok drop table t2 + +############ +## 0 to represent the default value (target_partitions and planning_concurrency) +########### + +statement ok +SET datafusion.execution.target_partitions = 3; + +statement ok +SET datafusion.execution.planning_concurrency = 3; + +# when setting target_partitions and planning_concurrency to 3, their values will be 3 +query TB rowsort +SELECT name, value = 3 FROM information_schema.df_settings WHERE name IN ('datafusion.execution.target_partitions', 'datafusion.execution.planning_concurrency'); +---- +datafusion.execution.planning_concurrency true +datafusion.execution.target_partitions true + +statement ok +SET datafusion.execution.target_partitions = 0; + +statement ok +SET datafusion.execution.planning_concurrency = 0; + +# when setting target_partitions and planning_concurrency to 0, their values will be equal to the +# default values, which are different from 0 (which is invalid) +query TB rowsort +SELECT name, value = 0 FROM information_schema.df_settings WHERE name IN ('datafusion.execution.target_partitions', 'datafusion.execution.planning_concurrency'); +---- +datafusion.execution.planning_concurrency false +datafusion.execution.target_partitions false + + ############ ## SHOW VARIABLES should work ########### @@ -183,20 +216,23 @@ datafusion.catalog.location NULL datafusion.catalog.newlines_in_values false datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true -datafusion.execution.collect_statistics false +datafusion.execution.collect_statistics true datafusion.execution.enable_recursive_ctes true datafusion.execution.enforce_batch_size_in_joins false datafusion.execution.keep_partition_by_columns false +datafusion.execution.listing_table_factory_infer_partitions true datafusion.execution.listing_table_ignore_subdirectory true datafusion.execution.max_buffered_batches_per_output_file 2 datafusion.execution.meta_fetch_concurrency 32 datafusion.execution.minimum_parallel_output_files 4 +datafusion.execution.objectstore_writer_buffer_size 10485760 datafusion.execution.parquet.allow_single_file_parallelism true datafusion.execution.parquet.binary_as_string false datafusion.execution.parquet.bloom_filter_fpp NULL datafusion.execution.parquet.bloom_filter_ndv NULL datafusion.execution.parquet.bloom_filter_on_read true datafusion.execution.parquet.bloom_filter_on_write false +datafusion.execution.parquet.coerce_int96 NULL datafusion.execution.parquet.column_index_truncate_length 64 datafusion.execution.parquet.compression zstd(3) datafusion.execution.parquet.created_by datafusion @@ -206,8 +242,8 @@ datafusion.execution.parquet.dictionary_enabled true datafusion.execution.parquet.dictionary_page_size_limit 1048576 datafusion.execution.parquet.enable_page_index true datafusion.execution.parquet.encoding NULL +datafusion.execution.parquet.max_predicate_cache_size NULL datafusion.execution.parquet.max_row_group_size 1048576 -datafusion.execution.parquet.max_statistics_size 4096 datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 datafusion.execution.parquet.maximum_parallel_row_group_writers 1 datafusion.execution.parquet.metadata_size_hint NULL @@ -218,7 +254,7 @@ datafusion.execution.parquet.schema_force_view_types true datafusion.execution.parquet.skip_arrow_metadata false datafusion.execution.parquet.skip_metadata true datafusion.execution.parquet.statistics_enabled page -datafusion.execution.parquet.statistics_truncate_length NULL +datafusion.execution.parquet.statistics_truncate_length 64 datafusion.execution.parquet.write_batch_size 1024 datafusion.execution.parquet.writer_version 1.0 datafusion.execution.planning_concurrency 13 @@ -228,6 +264,7 @@ datafusion.execution.skip_physical_aggregate_schema_check false datafusion.execution.soft_max_rows_per_output_file 50000000 datafusion.execution.sort_in_place_threshold_bytes 1048576 datafusion.execution.sort_spill_reservation_bytes 10485760 +datafusion.execution.spill_compression uncompressed datafusion.execution.split_file_groups_by_statistics false datafusion.execution.target_partitions 7 datafusion.execution.time_zone +00:00 @@ -238,11 +275,23 @@ datafusion.explain.physical_plan_only false datafusion.explain.show_schema false datafusion.explain.show_sizes true datafusion.explain.show_statistics false +datafusion.explain.tree_maximum_render_width 240 +datafusion.format.date_format %Y-%m-%d +datafusion.format.datetime_format %Y-%m-%dT%H:%M:%S%.f +datafusion.format.duration_format pretty +datafusion.format.null (empty) +datafusion.format.safe true +datafusion.format.time_format %H:%M:%S%.f +datafusion.format.timestamp_format %Y-%m-%dT%H:%M:%S%.f +datafusion.format.timestamp_tz_format NULL +datafusion.format.types_info false datafusion.optimizer.allow_symmetric_joins_without_pruning true datafusion.optimizer.default_filter_selectivity 20 datafusion.optimizer.enable_distinct_aggregation_soft_limit true +datafusion.optimizer.enable_dynamic_filter_pushdown true datafusion.optimizer.enable_round_robin_repartition true datafusion.optimizer.enable_topk_aggregation true +datafusion.optimizer.enable_window_limits true datafusion.optimizer.expand_views_at_output false datafusion.optimizer.filter_null_join_keys false datafusion.optimizer.hash_join_single_partition_threshold 1048576 @@ -260,10 +309,11 @@ datafusion.optimizer.repartition_windows true datafusion.optimizer.skip_failed_rules false datafusion.optimizer.top_down_join_key_reordering true datafusion.sql_parser.collect_spans false +datafusion.sql_parser.default_null_ordering nulls_max datafusion.sql_parser.dialect generic datafusion.sql_parser.enable_ident_normalization true datafusion.sql_parser.enable_options_value_normalization false -datafusion.sql_parser.map_varchar_to_utf8view false +datafusion.sql_parser.map_string_types_to_utf8view true datafusion.sql_parser.parse_float_as_decimal false datafusion.sql_parser.recursion_limit 50 datafusion.sql_parser.support_varchar_with_length true @@ -282,20 +332,23 @@ datafusion.catalog.location NULL Location scanned to load tables for `default` s datafusion.catalog.newlines_in_values false Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting -datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files +datafusion.execution.collect_statistics true Should DataFusion collect statistics when first creating a table. Has no effect after the table is created. Applies to the default `ListingTableProvider` in DataFusion. Defaults to true. datafusion.execution.enable_recursive_ctes true Should DataFusion support recursive CTEs datafusion.execution.enforce_batch_size_in_joins false Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. datafusion.execution.keep_partition_by_columns false Should DataFusion keep the columns used for partition_by in the output RecordBatches +datafusion.execution.listing_table_factory_infer_partitions true Should a `ListingTable` created through the `ListingTableFactory` infer table partitions from Hive compliant directories. Defaults to true (partition columns are inferred and will be represented in the table schema). datafusion.execution.listing_table_ignore_subdirectory true Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. +datafusion.execution.objectstore_writer_buffer_size 10485760 Size (bytes) of data buffer DataFusion uses when writing output files. This affects the size of the data chunks that are uploaded to remote object stores (e.g. AWS S3). If very large (>= 100 GiB) output files are being written, it may be necessary to increase this size to avoid errors from the remote end point. datafusion.execution.parquet.allow_single_file_parallelism true (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. datafusion.execution.parquet.binary_as_string false (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. datafusion.execution.parquet.bloom_filter_fpp NULL (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting datafusion.execution.parquet.bloom_filter_ndv NULL (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting -datafusion.execution.parquet.bloom_filter_on_read true (writing) Use any available bloom filters when reading parquet files +datafusion.execution.parquet.bloom_filter_on_read true (reading) Use any available bloom filters when reading parquet files datafusion.execution.parquet.bloom_filter_on_write false (writing) Write bloom filters for all columns when creating parquet files +datafusion.execution.parquet.coerce_int96 NULL (reading) If true, parquet reader will read columns of physical type int96 as originating from a different resolution than nanosecond. This is useful for reading data from systems like Spark which stores microsecond resolution timestamps in an int96 allowing it to write values with a larger date range than 64-bit timestamps with nanosecond resolution. datafusion.execution.parquet.column_index_truncate_length 64 (writing) Sets column index truncate length datafusion.execution.parquet.compression zstd(3) (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. datafusion.execution.parquet.created_by datafusion (writing) Sets "created by" property @@ -305,8 +358,8 @@ datafusion.execution.parquet.dictionary_enabled true (writing) Sets if dictionar datafusion.execution.parquet.dictionary_page_size_limit 1048576 (writing) Sets best effort maximum dictionary page size, in bytes datafusion.execution.parquet.enable_page_index true (reading) If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. datafusion.execution.parquet.encoding NULL (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting +datafusion.execution.parquet.max_predicate_cache_size NULL (reading) The maximum predicate cache size, in bytes. When `pushdown_filters` is enabled, sets the maximum memory used to cache the results of predicate evaluation between filter evaluation and output generation. Decreasing this value will reduce memory usage, but may increase IO and CPU usage. None means use the default parquet reader setting. 0 means no caching. datafusion.execution.parquet.max_row_group_size 1048576 (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. -datafusion.execution.parquet.max_statistics_size 4096 (writing) Sets max statistics size for any column. If NULL, uses default parquet writer setting max_statistics_size is deprecated, currently it is not being used datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. datafusion.execution.parquet.maximum_parallel_row_group_writers 1 (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. datafusion.execution.parquet.metadata_size_hint NULL (reading) If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer @@ -317,7 +370,7 @@ datafusion.execution.parquet.schema_force_view_types true (reading) If true, par datafusion.execution.parquet.skip_arrow_metadata false (writing) Skip encoding the embedded arrow metadata in the KV_meta This is analogous to the `ArrowWriterOptions::with_skip_arrow_metadata`. Refer to datafusion.execution.parquet.skip_metadata true (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata datafusion.execution.parquet.statistics_enabled page (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting -datafusion.execution.parquet.statistics_truncate_length NULL (writing) Sets statictics truncate length. If NULL, uses default parquet writer setting +datafusion.execution.parquet.statistics_truncate_length 64 (writing) Sets statistics truncate length. If NULL, uses default parquet writer setting datafusion.execution.parquet.write_batch_size 1024 (writing) Sets write_batch_size in bytes datafusion.execution.parquet.writer_version 1.0 (writing) Sets parquet writer version valid values are "1.0" and "2.0" datafusion.execution.planning_concurrency 13 Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system @@ -327,6 +380,7 @@ datafusion.execution.skip_physical_aggregate_schema_check false When set to true datafusion.execution.soft_max_rows_per_output_file 50000000 Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max datafusion.execution.sort_in_place_threshold_bytes 1048576 When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. datafusion.execution.sort_spill_reservation_bytes 10485760 Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). +datafusion.execution.spill_compression uncompressed Sets the compression codec used when spilling data to disk. Since datafusion writes spill files using the Arrow IPC Stream format, only codecs supported by the Arrow IPC Stream Writer are allowed. Valid values are: uncompressed, lz4_frame, zstd. Note: lz4_frame offers faster (de)compression, but typically results in larger spill files. In contrast, zstd achieves higher compression ratios at the cost of slower (de)compression speed. datafusion.execution.split_file_groups_by_statistics false Attempt to eliminate sorts by packing & sorting files with non-overlapping statistics into the same file groups. Currently experimental datafusion.execution.target_partitions 7 Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system datafusion.execution.time_zone +00:00 The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour @@ -337,11 +391,23 @@ datafusion.explain.physical_plan_only false When set to true, the explain statem datafusion.explain.show_schema false When set to true, the explain statement will print schema information datafusion.explain.show_sizes true When set to true, the explain statement will print the partition sizes datafusion.explain.show_statistics false When set to true, the explain statement will print operator statistics for physical plans +datafusion.explain.tree_maximum_render_width 240 (format=tree only) Maximum total width of the rendered tree. When set to 0, the tree will have no width limit. +datafusion.format.date_format %Y-%m-%d Date format for date arrays +datafusion.format.datetime_format %Y-%m-%dT%H:%M:%S%.f Format for DateTime arrays +datafusion.format.duration_format pretty Duration format. Can be either `"pretty"` or `"ISO8601"` +datafusion.format.null (empty) Format string for nulls +datafusion.format.safe true If set to `true` any formatting errors will be written to the output instead of being converted into a [`std::fmt::Error`] +datafusion.format.time_format %H:%M:%S%.f Time format for time arrays +datafusion.format.timestamp_format %Y-%m-%dT%H:%M:%S%.f Timestamp format for timestamp arrays +datafusion.format.timestamp_tz_format NULL Timestamp format for timestamp with timezone arrays. When `None`, ISO 8601 format is used. +datafusion.format.types_info false Show types in visual representation batches datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. datafusion.optimizer.default_filter_selectivity 20 The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. +datafusion.optimizer.enable_dynamic_filter_pushdown true When set to true attempts to push down dynamic filters generated by operators into the file scan phase. For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. This means that if we already have 10 timestamps in the year 2025 any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores datafusion.optimizer.enable_topk_aggregation true When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible +datafusion.optimizer.enable_window_limits true When set to true, the optimizer will attempt to push limit operations past window functions, if possible datafusion.optimizer.expand_views_at_output false When set to true, if the returned type is a view type then the output will be coerced to a non-view. Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. datafusion.optimizer.filter_null_join_keys false When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. datafusion.optimizer.hash_join_single_partition_threshold 1048576 The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition @@ -352,17 +418,18 @@ datafusion.optimizer.prefer_existing_union false When set to true, the optimizer datafusion.optimizer.prefer_hash_join true When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory datafusion.optimizer.repartition_aggregations true Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level datafusion.optimizer.repartition_file_min_size 10485760 Minimum total files size in bytes to perform file scan repartitioning. -datafusion.optimizer.repartition_file_scans true When set to `true`, file groups will be repartitioned to achieve maximum parallelism. Currently Parquet and CSV formats are supported. If set to `true`, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false`, different files will be read in parallel, but repartitioning won't happen within a single file. +datafusion.optimizer.repartition_file_scans true When set to `true`, datasource partitions will be repartitioned to achieve maximum parallelism. This applies to both in-memory partitions and FileSource's file groups (1 group is 1 partition). For FileSources, only Parquet and CSV formats are currently supported. If set to `true` for a FileSource, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false` for a FileSource, different files will be read in parallel, but repartitioning won't happen within a single file. If set to `true` for an in-memory source, all memtable's partitions will have their batches repartitioned evenly to the desired number of `target_partitions`. Repartitioning can change the total number of partitions and batches per partition, but does not slice the initial record tables provided to the MemTable on creation. datafusion.optimizer.repartition_joins true Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level datafusion.optimizer.repartition_sorts true Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. With this flag is enabled, plans in the form below ```text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ``` would turn into the plan below which performs better in multithreaded environments ```text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ``` datafusion.optimizer.repartition_windows true Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level datafusion.optimizer.skip_failed_rules false When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule. When set to false, any rules that produce errors will cause the query to fail datafusion.optimizer.top_down_join_key_reordering true When set to true, the physical plan optimizer will run a top down process to reorder the join keys datafusion.sql_parser.collect_spans false When set to true, the source locations relative to the original SQL query (i.e. [`Span`](https://docs.rs/sqlparser/latest/sqlparser/tokenizer/struct.Span.html)) will be collected and recorded in the logical plan nodes. +datafusion.sql_parser.default_null_ordering nulls_max Specifies the default null ordering for query results. There are 4 options: - `nulls_max`: Nulls appear last in ascending order. - `nulls_min`: Nulls appear first in ascending order. - `nulls_first`: Nulls always be first in any order. - `nulls_last`: Nulls always be last in any order. By default, `nulls_max` is used to follow Postgres's behavior. postgres rule: datafusion.sql_parser.dialect generic Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks. datafusion.sql_parser.enable_ident_normalization true When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) datafusion.sql_parser.enable_options_value_normalization false When set to true, SQL parser will normalize options value (convert value to lowercase). Note that this option is ignored and will be removed in the future. All case-insensitive values are normalized automatically. -datafusion.sql_parser.map_varchar_to_utf8view false If true, `VARCHAR` is mapped to `Utf8View` during SQL planning. If false, `VARCHAR` is mapped to `Utf8` during SQL planning. Default is false. +datafusion.sql_parser.map_string_types_to_utf8view true If true, string types (VARCHAR, CHAR, Text, and String) are mapped to `Utf8View` during SQL planning. If false, they are mapped to `Utf8`. Default is true. datafusion.sql_parser.parse_float_as_decimal false When set to true, SQL parser will parse float as decimal type datafusion.sql_parser.recursion_limit 50 Specifies the recursion depth limit when parsing complex SQL Queries datafusion.sql_parser.support_varchar_with_length true If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion can not enforce such limits. @@ -634,7 +701,7 @@ datafusion public abc CREATE EXTERNAL TABLE abc STORED AS CSV LOCATION ../../tes query TTT select routine_name, data_type, function_type from information_schema.routines where routine_name = 'string_agg'; ---- -string_agg LargeUtf8 AGGREGATE +string_agg String AGGREGATE # test every function type are included in the result query TTTTTTTBTTTT rowsort @@ -649,7 +716,7 @@ datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestam datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Second, None) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Second, Some("+TZ")) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) datafusion public rank datafusion public rank FUNCTION true NULL WINDOW Returns the rank of the current row within its partition, allowing gaps between ranks. This function provides a ranking similar to `row_number`, but skips ranks for identical values. rank() -datafusion public string_agg datafusion public string_agg FUNCTION true LargeUtf8 AGGREGATE Concatenates the values of string expressions and places separator values between them. string_agg(expression, delimiter) +datafusion public string_agg datafusion public string_agg FUNCTION true String AGGREGATE Concatenates the values of string expressions and places separator values between them. If ordering is required, strings are concatenated in the specified order. This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression. string_agg([DISTINCT] expression, delimiter [ORDER BY expression]) query B select is_deterministic from information_schema.routines where routine_name = 'now'; @@ -658,110 +725,65 @@ false # test every function type are included in the result query TTTITTTTBI -select * from information_schema.parameters where specific_name = 'date_trunc' OR specific_name = 'string_agg' OR specific_name = 'rank' ORDER BY specific_name, rid; ----- -datafusion public date_trunc 1 IN precision Utf8 NULL false 0 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, None) NULL false 0 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, None) NULL false 0 -datafusion public date_trunc 1 IN precision Utf8View NULL false 1 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, None) NULL false 1 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, None) NULL false 1 -datafusion public date_trunc 1 IN precision Utf8 NULL false 2 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, Some("+TZ")) NULL false 2 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, Some("+TZ")) NULL false 2 -datafusion public date_trunc 1 IN precision Utf8View NULL false 3 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, Some("+TZ")) NULL false 3 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, Some("+TZ")) NULL false 3 -datafusion public date_trunc 1 IN precision Utf8 NULL false 4 -datafusion public date_trunc 2 IN expression Timestamp(Microsecond, None) NULL false 4 -datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, None) NULL false 4 -datafusion public date_trunc 1 IN precision Utf8View NULL false 5 -datafusion public date_trunc 2 IN expression Timestamp(Microsecond, None) NULL false 5 -datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, None) NULL false 5 -datafusion public date_trunc 1 IN precision Utf8 NULL false 6 -datafusion public date_trunc 2 IN expression Timestamp(Microsecond, Some("+TZ")) NULL false 6 -datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, Some("+TZ")) NULL false 6 -datafusion public date_trunc 1 IN precision Utf8View NULL false 7 -datafusion public date_trunc 2 IN expression Timestamp(Microsecond, Some("+TZ")) NULL false 7 -datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, Some("+TZ")) NULL false 7 -datafusion public date_trunc 1 IN precision Utf8 NULL false 8 -datafusion public date_trunc 2 IN expression Timestamp(Millisecond, None) NULL false 8 -datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, None) NULL false 8 -datafusion public date_trunc 1 IN precision Utf8View NULL false 9 -datafusion public date_trunc 2 IN expression Timestamp(Millisecond, None) NULL false 9 -datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, None) NULL false 9 -datafusion public date_trunc 1 IN precision Utf8 NULL false 10 -datafusion public date_trunc 2 IN expression Timestamp(Millisecond, Some("+TZ")) NULL false 10 -datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, Some("+TZ")) NULL false 10 -datafusion public date_trunc 1 IN precision Utf8View NULL false 11 -datafusion public date_trunc 2 IN expression Timestamp(Millisecond, Some("+TZ")) NULL false 11 -datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, Some("+TZ")) NULL false 11 -datafusion public date_trunc 1 IN precision Utf8 NULL false 12 -datafusion public date_trunc 2 IN expression Timestamp(Second, None) NULL false 12 -datafusion public date_trunc 1 OUT NULL Timestamp(Second, None) NULL false 12 -datafusion public date_trunc 1 IN precision Utf8View NULL false 13 -datafusion public date_trunc 2 IN expression Timestamp(Second, None) NULL false 13 -datafusion public date_trunc 1 OUT NULL Timestamp(Second, None) NULL false 13 -datafusion public date_trunc 1 IN precision Utf8 NULL false 14 -datafusion public date_trunc 2 IN expression Timestamp(Second, Some("+TZ")) NULL false 14 -datafusion public date_trunc 1 OUT NULL Timestamp(Second, Some("+TZ")) NULL false 14 -datafusion public date_trunc 1 IN precision Utf8View NULL false 15 -datafusion public date_trunc 2 IN expression Timestamp(Second, Some("+TZ")) NULL false 15 -datafusion public date_trunc 1 OUT NULL Timestamp(Second, Some("+TZ")) NULL false 15 -datafusion public string_agg 1 IN expression LargeUtf8 NULL false 0 -datafusion public string_agg 2 IN delimiter Utf8 NULL false 0 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 0 -datafusion public string_agg 1 IN expression LargeUtf8 NULL false 1 -datafusion public string_agg 2 IN delimiter LargeUtf8 NULL false 1 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 1 -datafusion public string_agg 1 IN expression LargeUtf8 NULL false 2 -datafusion public string_agg 2 IN delimiter Null NULL false 2 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 2 +select * from information_schema.parameters where specific_name = 'date_trunc' OR specific_name = 'string_agg' OR specific_name = 'rank' ORDER BY specific_name, rid, data_type; +---- +datafusion public date_trunc 1 IN precision String NULL false 0 +datafusion public date_trunc 2 IN expression Timestamp(Microsecond, None) NULL false 0 +datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, None) NULL false 0 +datafusion public date_trunc 1 IN precision String NULL false 1 +datafusion public date_trunc 2 IN expression Timestamp(Microsecond, Some("+TZ")) NULL false 1 +datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, Some("+TZ")) NULL false 1 +datafusion public date_trunc 1 IN precision String NULL false 2 +datafusion public date_trunc 2 IN expression Timestamp(Millisecond, None) NULL false 2 +datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, None) NULL false 2 +datafusion public date_trunc 1 IN precision String NULL false 3 +datafusion public date_trunc 2 IN expression Timestamp(Millisecond, Some("+TZ")) NULL false 3 +datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, Some("+TZ")) NULL false 3 +datafusion public date_trunc 1 IN precision String NULL false 4 +datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, None) NULL false 4 +datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, None) NULL false 4 +datafusion public date_trunc 1 IN precision String NULL false 5 +datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, Some("+TZ")) NULL false 5 +datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, Some("+TZ")) NULL false 5 +datafusion public date_trunc 1 IN precision String NULL false 6 +datafusion public date_trunc 2 IN expression Timestamp(Second, None) NULL false 6 +datafusion public date_trunc 1 OUT NULL Timestamp(Second, None) NULL false 6 +datafusion public date_trunc 1 IN precision String NULL false 7 +datafusion public date_trunc 2 IN expression Timestamp(Second, Some("+TZ")) NULL false 7 +datafusion public date_trunc 1 OUT NULL Timestamp(Second, Some("+TZ")) NULL false 7 +datafusion public string_agg 2 IN delimiter Null NULL false 0 +datafusion public string_agg 1 IN expression String NULL false 0 +datafusion public string_agg 1 OUT NULL String NULL false 0 +datafusion public string_agg 1 IN expression String NULL false 1 +datafusion public string_agg 2 IN delimiter String NULL false 1 +datafusion public string_agg 1 OUT NULL String NULL false 1 # test variable length arguments query TTTBI rowsort select specific_name, data_type, parameter_mode, is_variadic, rid from information_schema.parameters where specific_name = 'concat'; ---- -concat LargeUtf8 IN true 2 -concat LargeUtf8 OUT false 2 -concat Utf8 IN true 1 -concat Utf8 OUT false 1 -concat Utf8View IN true 0 -concat Utf8View OUT false 0 +concat String IN true 0 +concat String OUT false 0 # test ceorcion signature query TTITI rowsort select specific_name, data_type, ordinal_position, parameter_mode, rid from information_schema.parameters where specific_name = 'repeat'; ---- repeat Int64 2 IN 0 -repeat Int64 2 IN 1 -repeat Int64 2 IN 2 -repeat LargeUtf8 1 IN 1 -repeat LargeUtf8 1 OUT 1 -repeat Utf8 1 IN 0 -repeat Utf8 1 OUT 0 -repeat Utf8 1 OUT 2 -repeat Utf8View 1 IN 2 +repeat String 1 IN 0 +repeat String 1 OUT 0 query TT??TTT rowsort show functions like 'date_trunc'; ---- -date_trunc Timestamp(Microsecond, None) [precision, expression] [Utf8, Timestamp(Microsecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Microsecond, None) [precision, expression] [Utf8View, Timestamp(Microsecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Microsecond, Some("+TZ")) [precision, expression] [Utf8, Timestamp(Microsecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Microsecond, Some("+TZ")) [precision, expression] [Utf8View, Timestamp(Microsecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Millisecond, None) [precision, expression] [Utf8, Timestamp(Millisecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Millisecond, None) [precision, expression] [Utf8View, Timestamp(Millisecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Millisecond, Some("+TZ")) [precision, expression] [Utf8, Timestamp(Millisecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Millisecond, Some("+TZ")) [precision, expression] [Utf8View, Timestamp(Millisecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, None) [precision, expression] [Utf8, Timestamp(Nanosecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, None) [precision, expression] [Utf8View, Timestamp(Nanosecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, Some("+TZ")) [precision, expression] [Utf8, Timestamp(Nanosecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, Some("+TZ")) [precision, expression] [Utf8View, Timestamp(Nanosecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Second, None) [precision, expression] [Utf8, Timestamp(Second, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Second, None) [precision, expression] [Utf8View, Timestamp(Second, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Second, Some("+TZ")) [precision, expression] [Utf8, Timestamp(Second, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Second, Some("+TZ")) [precision, expression] [Utf8View, Timestamp(Second, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Microsecond, None) [precision, expression] [String, Timestamp(Microsecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Microsecond, Some("+TZ")) [precision, expression] [String, Timestamp(Microsecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Millisecond, None) [precision, expression] [String, Timestamp(Millisecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Millisecond, Some("+TZ")) [precision, expression] [String, Timestamp(Millisecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Nanosecond, None) [precision, expression] [String, Timestamp(Nanosecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Nanosecond, Some("+TZ")) [precision, expression] [String, Timestamp(Nanosecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Second, None) [precision, expression] [String, Timestamp(Second, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Second, Some("+TZ")) [precision, expression] [String, Timestamp(Second, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) statement ok show functions diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index 8a9c01d36308d..9a3c959884aa0 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -68,7 +68,7 @@ physical_plan 02)--ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] 03)----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] 04)------ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] -05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Field { name: "count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] 06)----------SortExec: expr=[c1@0 ASC NULLS LAST, c9@2 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 @@ -128,7 +128,7 @@ physical_plan 01)DataSinkExec: sink=MemoryTable (partitions=1) 02)--CoalescePartitionsExec 03)----ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] -04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Field { name: "count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] 05)--------SortExec: expr=[c1@0 ASC NULLS LAST, c9@2 ASC NULLS LAST], preserve_partitioning=[true] 06)----------CoalesceBatchesExec: target_batch_size=8192 07)------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 @@ -179,7 +179,7 @@ physical_plan 02)--ProjectionExec: expr=[a1@0 as a1, a2@1 as a2] 03)----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] 04)------ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as a1, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as a2, c1@0 as c1] -05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Field { name: "count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] 06)----------SortExec: expr=[c1@0 ASC NULLS LAST, c9@2 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index 24982dfc28a75..075256ae4b92d 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -175,6 +175,34 @@ select * from partitioned_insert_test order by a,b,c 1 20 200 2 20 200 +statement count 0 +CREATE EXTERNAL TABLE +partitioned_insert_test_readback +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned/'; + +query TTT +describe partitioned_insert_test_readback; +---- +c Int64 YES +a Dictionary(UInt16, Utf8) NO +b Dictionary(UInt16, Utf8) NO + +query ITT +select * from partitioned_insert_test_readback order by a,b,c; +---- +1 10 100 +1 10 200 +1 20 100 +2 20 100 +1 20 200 +2 20 200 + +query I +select count(*) from partitioned_insert_test_readback where b=100; +---- +3 + statement ok CREATE EXTERNAL TABLE partitioned_insert_test_verify(c bigint) @@ -333,6 +361,41 @@ select * from directory_test; 1 2 3 4 +statement count 0 +CREATE EXTERNAL TABLE +directory_with_dots_test(a bigint, b bigint) +STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_versioned_parquet_table.v0/'; + +query I +INSERT INTO directory_with_dots_test values (1, 2), (3, 4); +---- +2 + +query II +select * from directory_with_dots_test; +---- +1 2 +3 4 + +statement count 0 +CREATE EXTERNAL TABLE +directory_with_dots_readback +STORED AS parquet +LOCATION 'test_files/scratch/insert_to_external/external_versioned_parquet_table.v0/'; + +query TTT +describe directory_with_dots_readback +---- +a Int64 YES +b Int64 YES + +query II +select * from directory_with_dots_readback +---- +1 2 +3 4 + statement ok CREATE EXTERNAL TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL) @@ -359,7 +422,7 @@ physical_plan 02)--ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] 03)----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] 04)------ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] -05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Field { name: "count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] 06)----------SortExec: expr=[c1@0 ASC NULLS LAST, c9@2 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 @@ -420,7 +483,7 @@ physical_plan 01)DataSinkExec: sink=ParquetSink(file_groups=[]) 02)--CoalescePartitionsExec 03)----ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] -04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Field { name: "count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] 05)--------SortExec: expr=[c1@0 ASC NULLS LAST, c9@2 ASC NULLS LAST], preserve_partitioning=[true] 06)----------CoalesceBatchesExec: target_batch_size=8192 07)------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 diff --git a/datafusion/sqllogictest/test_files/issue_17138.slt b/datafusion/sqllogictest/test_files/issue_17138.slt new file mode 100644 index 0000000000000..de9cb4bcf77bb --- /dev/null +++ b/datafusion/sqllogictest/test_files/issue_17138.slt @@ -0,0 +1,36 @@ +statement ok +CREATE TABLE tab1(col0 INTEGER, col1 INTEGER, col2 INTEGER) + +statement ok +INSERT INTO tab1 VALUES(51,14,96) + +query R +SELECT NULL * AVG(DISTINCT 4) + SUM(col1) AS col0 FROM tab1 +---- +NULL + +query TT +EXPLAIN SELECT NULL * AVG(DISTINCT 4) + SUM(col1) AS col0 FROM tab1 +---- +logical_plan +01)Projection: Float64(NULL) AS col0 +02)--EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[NULL as col0] +02)--PlaceholderRowExec + +# Similar, with a few more arithmetic operations +query R +SELECT + CAST ( NULL AS INTEGER ) * + + AVG ( DISTINCT 4 ) + - SUM ( ALL + col1 ) AS col0 FROM tab1 +---- +NULL + +query TT +EXPLAIN SELECT + CAST ( NULL AS INTEGER ) * + + AVG ( DISTINCT 4 ) + - SUM ( ALL + col1 ) AS col0 FROM tab1 +---- +logical_plan +01)Projection: Float64(NULL) AS col0 +02)--EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[NULL as col0] +02)--PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/join.slt.part b/datafusion/sqllogictest/test_files/join.slt.part index 972dd2265343d..2abe654a96c8c 100644 --- a/datafusion/sqllogictest/test_files/join.slt.part +++ b/datafusion/sqllogictest/test_files/join.slt.part @@ -681,7 +681,7 @@ select col2, col4 from t1 full outer join t2 on col1 = col3 query TT explain select * from t1 join t2 on false; ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 physical_plan EmptyExec # Make batch size smaller than table row number. to introduce parallelism to the plan. @@ -842,7 +842,7 @@ LEFT JOIN department AS d ON (e.name = 'Alice' OR e.name = 'Bob'); ---- logical_plan -01)Left Join: Filter: e.name = Utf8("Alice") OR e.name = Utf8("Bob") +01)Left Join: Filter: e.name = Utf8View("Alice") OR e.name = Utf8View("Bob") 02)--SubqueryAlias: e 03)----TableScan: employees projection=[emp_id, name] 04)--SubqueryAlias: d @@ -853,47 +853,47 @@ physical_plan 03)----DataSourceExec: partitions=1, partition_sizes=[1] 04)----DataSourceExec: partitions=1, partition_sizes=[1] -query ITT +query ITT rowsort SELECT e.emp_id, e.name, d.dept_name FROM employees AS e LEFT JOIN department AS d ON (e.name = 'Alice' OR e.name = 'Bob'); ---- -1 Alice HR 1 Alice Engineering +1 Alice HR 1 Alice Sales -2 Bob HR 2 Bob Engineering +2 Bob HR 2 Bob Sales 3 Carol NULL # neither RIGHT OUTER JOIN -query ITT +query ITT rowsort SELECT e.emp_id, e.name, d.dept_name FROM department AS d RIGHT JOIN employees AS e ON (e.name = 'Alice' OR e.name = 'Bob'); ---- -1 Alice HR 1 Alice Engineering +1 Alice HR 1 Alice Sales -2 Bob HR 2 Bob Engineering +2 Bob HR 2 Bob Sales 3 Carol NULL # neither FULL OUTER JOIN -query ITT +query ITT rowsort SELECT e.emp_id, e.name, d.dept_name FROM department AS d FULL JOIN employees AS e ON (e.name = 'Alice' OR e.name = 'Bob'); ---- -1 Alice HR 1 Alice Engineering +1 Alice HR 1 Alice Sales -2 Bob HR 2 Bob Engineering +2 Bob HR 2 Bob Sales 3 Carol NULL @@ -929,7 +929,7 @@ ON (e.name = 'Alice' OR e.name = 'Bob'); logical_plan 01)Cross Join: 02)--SubqueryAlias: e -03)----Filter: employees.name = Utf8("Alice") OR employees.name = Utf8("Bob") +03)----Filter: employees.name = Utf8View("Alice") OR employees.name = Utf8View("Bob") 04)------TableScan: employees projection=[emp_id, name] 05)--SubqueryAlias: d 06)----TableScan: department projection=[dept_name] @@ -974,11 +974,11 @@ ON e.emp_id = d.emp_id WHERE ((dept_name != 'Engineering' AND e.name = 'Alice') OR (name != 'Alice' AND e.name = 'Carol')); ---- logical_plan -01)Filter: d.dept_name != Utf8("Engineering") AND e.name = Utf8("Alice") OR e.name != Utf8("Alice") AND e.name = Utf8("Carol") +01)Filter: d.dept_name != Utf8View("Engineering") AND e.name = Utf8View("Alice") OR e.name != Utf8View("Alice") AND e.name = Utf8View("Carol") 02)--Projection: e.emp_id, e.name, d.dept_name 03)----Left Join: e.emp_id = d.emp_id 04)------SubqueryAlias: e -05)--------Filter: employees.name = Utf8("Alice") OR employees.name != Utf8("Alice") AND employees.name = Utf8("Carol") +05)--------Filter: employees.name = Utf8View("Alice") OR employees.name != Utf8View("Alice") AND employees.name = Utf8View("Carol") 06)----------TableScan: employees projection=[emp_id, name] 07)------SubqueryAlias: d 08)--------TableScan: department projection=[emp_id, dept_name] @@ -1404,3 +1404,112 @@ set datafusion.execution.target_partitions = 4; statement ok set datafusion.optimizer.repartition_joins = false; + +statement ok +CREATE TABLE t1(v0 BIGINT, v1 BIGINT); + +statement ok +CREATE TABLE t0(v0 BIGINT, v1 BIGINT); + +statement ok +INSERT INTO t0(v0, v1) VALUES (1, 1), (1, 2), (3, 3), (4, 4); + +statement ok +INSERT INTO t1(v0, v1) VALUES (1, 1), (3, 2), (3, 5); + +query TT +explain SELECT * +FROM t0, +LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0); +---- +logical_plan +01)Projection: t0.v0, t0.v1, sum(t1.v1) +02)--Left Join: t0.v0 = t1.v0 +03)----TableScan: t0 projection=[v0, v1] +04)----Projection: sum(t1.v1), t1.v0 +05)------Aggregate: groupBy=[[t1.v0]], aggr=[[sum(t1.v1)]] +06)--------TableScan: t1 projection=[v0, v1] +physical_plan +01)ProjectionExec: expr=[v0@1 as v0, v1@2 as v1, sum(t1.v1)@0 as sum(t1.v1)] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----HashJoinExec: mode=CollectLeft, join_type=Right, on=[(v0@1, v0@0)], projection=[sum(t1.v1)@0, v0@2, v1@3] +04)------CoalescePartitionsExec +05)--------ProjectionExec: expr=[sum(t1.v1)@1 as sum(t1.v1), v0@0 as v0] +06)----------AggregateExec: mode=FinalPartitioned, gby=[v0@0 as v0], aggr=[sum(t1.v1)] +07)------------CoalesceBatchesExec: target_batch_size=8192 +08)--------------RepartitionExec: partitioning=Hash([v0@0], 4), input_partitions=4 +09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)------------------AggregateExec: mode=Partial, gby=[v0@0 as v0], aggr=[sum(t1.v1)] +11)--------------------DataSourceExec: partitions=1, partition_sizes=[1] +12)------DataSourceExec: partitions=1, partition_sizes=[1] + +query III +SELECT * +FROM t0, +LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0); +---- +1 1 1 +1 2 1 +3 3 7 +4 4 NULL + +query TT +explain SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0); +---- +logical_plan +01)Inner Join: t0.v0 = t1.v0 +02)--TableScan: t0 projection=[v0, v1] +03)--TableScan: t1 projection=[v0, v1] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(v0@0, v0@0)] +03)----DataSourceExec: partitions=1, partition_sizes=[1] +04)----DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0); +---- +1 1 1 1 +1 2 1 1 +3 3 3 2 +3 3 3 5 + +query III +SELECT * FROM t0, LATERAL (SELECT 1); +---- +1 1 1 +1 2 1 +3 3 1 +4 4 1 + +query IIII +SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t1.v0 = 1); +---- +1 1 1 1 +1 2 1 1 +3 3 1 1 +4 4 1 1 + +query IIII +SELECT * FROM t0 JOIN LATERAL (SELECT * FROM t1 WHERE t1.v0 = 1) on true; +---- +1 1 1 1 +1 2 1 1 +3 3 1 1 +4 4 1 1 + +statement ok +drop table t1; + +statement ok +drop table t0; + +# SQLancer fuzzed query (https://github.com/apache/datafusion/issues/14015) +statement ok +create table t1(v1 int, v2 int); + +query error DataFusion error: Schema error: No field named tt1.v2. Valid fields are tt1.v1. +select v1 from t1 as tt1 natural join t1 as tt2 group by v1 order by v2; + +statement ok +drop table t1; diff --git a/datafusion/sqllogictest/test_files/join_is_not_distinct_from.slt b/datafusion/sqllogictest/test_files/join_is_not_distinct_from.slt new file mode 100644 index 0000000000000..0336cfc2d3314 --- /dev/null +++ b/datafusion/sqllogictest/test_files/join_is_not_distinct_from.slt @@ -0,0 +1,321 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test IS NOT DISTINCT FROM join functionality +# This tests the optimizer's ability to convert IS NOT DISTINCT FROM joins +# to equijoins with proper null equality handling + +statement ok +CREATE TABLE t0 ( + id INT, + val INT +) + +statement ok +CREATE TABLE t1 ( + id INT, + val INT +) + +statement ok +CREATE TABLE t2 ( + id INT, + val INT +) + +statement ok +INSERT INTO t0 VALUES +(1, 10), +(2, NULL), +(5, 50) + +statement ok +INSERT INTO t1 VALUES +(1, 10), +(2, NULL), +(3, 30), +(6, NULL) + +statement ok +INSERT INTO t2 VALUES +(1, 10), +(2, NULL), +(4, 40), +(6, 6) + +# Test basic IS NOT DISTINCT FROM join functionality +query IIII rowsort +SELECT t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +FROM t1 +JOIN t2 ON t1.val IS NOT DISTINCT FROM t2.val +---- +1 1 10 10 +2 2 NULL NULL +6 2 NULL NULL + +# Test that IS NOT DISTINCT FROM join produces HashJoin when used alone +query TT +EXPLAIN SELECT t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +FROM t1 +JOIN t2 ON t1.val IS NOT DISTINCT FROM t2.val +---- +logical_plan +01)Projection: t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +02)--Inner Join: t1.val = t2.val +03)----TableScan: t1 projection=[id, val] +04)----TableScan: t2 projection=[id, val] +physical_plan +01)ProjectionExec: expr=[id@0 as t1_id, id@2 as t2_id, val@1 as val, val@3 as val] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(val@1, val@1)], NullsEqual: true +04)------DataSourceExec: partitions=1, partition_sizes=[1] +05)------DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +set datafusion.explain.format = "tree"; + +# Tree explain should highlight null equality semantics +query TT +EXPLAIN SELECT t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +FROM t1 +JOIN t2 ON t1.val IS NOT DISTINCT FROM t2.val +---- +physical_plan +01)┌───────────────────────────┐ +02)│ ProjectionExec │ +03)│ -------------------- │ +04)│ t1_id: id │ +05)│ t2_id: id │ +06)│ val: val │ +07)└─────────────┬─────────────┘ +08)┌─────────────┴─────────────┐ +09)│ CoalesceBatchesExec │ +10)│ -------------------- │ +11)│ target_batch_size: │ +12)│ 8192 │ +13)└─────────────┬─────────────┘ +14)┌─────────────┴─────────────┐ +15)│ HashJoinExec │ +16)│ -------------------- │ +17)│ NullsEqual: true ├──────────────┐ +18)│ │ │ +19)│ on: (val = val) │ │ +20)└─────────────┬─────────────┘ │ +21)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +22)│ DataSourceExec ││ DataSourceExec │ +23)│ -------------------- ││ -------------------- │ +24)│ bytes: 288 ││ bytes: 288 │ +25)│ format: memory ││ format: memory │ +26)│ rows: 1 ││ rows: 1 │ +27)└───────────────────────────┘└───────────────────────────┘ + +statement ok +set datafusion.explain.format = "indent"; + +# For nested expression comparision, it should still able to be converted to Hash Join +query IIII rowsort +SELECT t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +FROM t1 +JOIN t2 ON ((t1.val+1) IS NOT DISTINCT FROM (t2.val+1)) AND ((t1.val + 1) IS NOT DISTINCT FROM 11); +---- +1 1 10 10 + +# The plan should include HashJoin +query TT +EXPLAIN SELECT t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +FROM t1 +JOIN t2 ON ((t1.val+1) IS NOT DISTINCT FROM (t2.val+1)) AND ((t1.val + 1) IS NOT DISTINCT FROM 11); +---- +logical_plan +01)Projection: t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +02)--Inner Join: CAST(t1.val AS Int64) + Int64(1) = CAST(t2.val AS Int64) + Int64(1) +03)----Filter: CAST(t1.val AS Int64) + Int64(1) IS NOT DISTINCT FROM Int64(11) +04)------TableScan: t1 projection=[id, val] +05)----TableScan: t2 projection=[id, val] +physical_plan +01)ProjectionExec: expr=[id@0 as t1_id, id@2 as t2_id, val@1 as val, val@3 as val] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1.val + Int64(1)@2, t2.val + Int64(1)@2)], projection=[id@0, val@1, id@3, val@4], NullsEqual: true +04)------CoalescePartitionsExec +05)--------ProjectionExec: expr=[id@0 as id, val@1 as val, CAST(val@1 AS Int64) + 1 as t1.val + Int64(1)] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------CoalesceBatchesExec: target_batch_size=8192 +08)--------------FilterExec: CAST(val@1 AS Int64) + 1 IS NOT DISTINCT FROM 11 +09)----------------DataSourceExec: partitions=1, partition_sizes=[1] +10)------ProjectionExec: expr=[id@0 as id, val@1 as val, CAST(val@1 AS Int64) + 1 as t2.val + Int64(1)] +11)--------DataSourceExec: partitions=1, partition_sizes=[1] + +# Mixed join predicate with `IS DISTINCT FROM` and `IS NOT DISTINCT FROM` +query IIII rowsort +SELECT t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +FROM t1 +JOIN t2 ON ((t1.val+1) IS NOT DISTINCT FROM (t2.val+1)) AND ((t1.val % 3) IS DISTINCT FROM (t2.val % 3)); +---- + +# The plan should include HashJoin +query TT +EXPLAIN SELECT t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +FROM t1 +JOIN t2 ON ((t1.val+1) IS NOT DISTINCT FROM (t2.val+1)) AND ((t1.val % 3) IS DISTINCT FROM (t2.val % 3)); +---- +logical_plan +01)Projection: t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +02)--Inner Join: CAST(t1.val AS Int64) + Int64(1) = CAST(t2.val AS Int64) + Int64(1) Filter: CAST(t1.val AS Int64) % Int64(3) IS DISTINCT FROM CAST(t2.val AS Int64) % Int64(3) +03)----TableScan: t1 projection=[id, val] +04)----TableScan: t2 projection=[id, val] +physical_plan +01)ProjectionExec: expr=[id@0 as t1_id, id@2 as t2_id, val@1 as val, val@3 as val] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1.val + Int64(1)@2, t2.val + Int64(1)@2)], filter=CAST(val@0 AS Int64) % 3 IS DISTINCT FROM CAST(val@1 AS Int64) % 3, projection=[id@0, val@1, id@3, val@4], NullsEqual: true +04)------ProjectionExec: expr=[id@0 as id, val@1 as val, CAST(val@1 AS Int64) + 1 as t1.val + Int64(1)] +05)--------DataSourceExec: partitions=1, partition_sizes=[1] +06)------ProjectionExec: expr=[id@0 as id, val@1 as val, CAST(val@1 AS Int64) + 1 as t2.val + Int64(1)] +07)--------DataSourceExec: partitions=1, partition_sizes=[1] + +# Test mixed equal and IS NOT DISTINCT FROM conditions +# The `IS NOT DISTINCT FROM` expr should NOT in HashJoin's `on` predicate +query TT +EXPLAIN SELECT t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +FROM t1 +JOIN t2 ON t1.id = t2.id AND t1.val IS NOT DISTINCT FROM t2.val +---- +logical_plan +01)Projection: t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +02)--Inner Join: t1.id = t2.id Filter: t1.val IS NOT DISTINCT FROM t2.val +03)----TableScan: t1 projection=[id, val] +04)----TableScan: t2 projection=[id, val] +physical_plan +01)ProjectionExec: expr=[id@0 as t1_id, id@2 as t2_id, val@1 as val, val@3 as val] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(id@0, id@0)], filter=val@0 IS NOT DISTINCT FROM val@1 +04)------DataSourceExec: partitions=1, partition_sizes=[1] +05)------DataSourceExec: partitions=1, partition_sizes=[1] + +# Test the mixed condition join result +query IIII rowsort +SELECT t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +FROM t1 +JOIN t2 ON t1.id = t2.id AND t1.val IS NOT DISTINCT FROM t2.val +---- +1 1 10 10 +2 2 NULL NULL + +# Test 3 table join +query IIII rowsort +SELECT t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +FROM t1 +JOIN t2 ON t1.val IS NOT DISTINCT FROM t2.val +JOIN t0 ON t1.val IS NOT DISTINCT FROM t0.val +---- +1 1 10 10 +2 2 NULL NULL +6 2 NULL NULL + +# Ensure there is HashJoin in the plan +query TT +EXPLAIN SELECT t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +FROM t1 +JOIN t2 ON t1.val IS NOT DISTINCT FROM t2.val +JOIN t0 ON t1.val IS NOT DISTINCT FROM t0.val +---- +logical_plan +01)Projection: t1.id AS t1_id, t2.id AS t2_id, t1.val, t2.val +02)--Inner Join: t1.val = t0.val +03)----Inner Join: t1.val = t2.val +04)------TableScan: t1 projection=[id, val] +05)------TableScan: t2 projection=[id, val] +06)----TableScan: t0 projection=[val] +physical_plan +01)ProjectionExec: expr=[id@0 as t1_id, id@2 as t2_id, val@1 as val, val@3 as val] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(val@0, val@1)], projection=[id@1, val@2, id@3, val@4], NullsEqual: true +04)------DataSourceExec: partitions=1, partition_sizes=[1] +05)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +06)--------CoalesceBatchesExec: target_batch_size=8192 +07)----------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(val@1, val@1)], NullsEqual: true +08)------------DataSourceExec: partitions=1, partition_sizes=[1] +09)------------DataSourceExec: partitions=1, partition_sizes=[1] + +# Test IS NOT DISTINCT FROM with multiple columns +statement ok +CREATE TABLE t3 ( + id INT, + val1 INT, + val2 INT +) + +statement ok +CREATE TABLE t4 ( + id INT, + val1 INT, + val2 INT +) + +statement ok +INSERT INTO t3 VALUES +(1, 10, 100), +(2, NULL, 200), +(3, 30, NULL) + +statement ok +INSERT INTO t4 VALUES +(1, 10, 100), +(2, NULL, 200), +(3, 30, NULL) + +# Test multiple IS NOT DISTINCT FROM conditions - should produce HashJoin +query TT rowsort +EXPLAIN SELECT t3.id AS t3_id, t4.id AS t4_id, t3.val1, t4.val1, t3.val2, t4.val2 +FROM t3 +JOIN t4 ON (t3.val1 IS NOT DISTINCT FROM t4.val1) AND (t3.val2 IS NOT DISTINCT FROM t4.val2) +---- +01)Projection: t3.id AS t3_id, t4.id AS t4_id, t3.val1, t4.val1, t3.val2, t4.val2 +01)ProjectionExec: expr=[id@0 as t3_id, id@3 as t4_id, val1@1 as val1, val1@4 as val1, val2@2 as val2, val2@5 as val2] +02)--CoalesceBatchesExec: target_batch_size=8192 +02)--Inner Join: t3.val1 = t4.val1, t3.val2 = t4.val2 +03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(val1@1, val1@1), (val2@2, val2@2)], NullsEqual: true +03)----TableScan: t3 projection=[id, val1, val2] +04)------DataSourceExec: partitions=1, partition_sizes=[1] +04)----TableScan: t4 projection=[id, val1, val2] +05)------DataSourceExec: partitions=1, partition_sizes=[1] +logical_plan +physical_plan + +# Test the multiple IS NOT DISTINCT FROM join result +query IIIIII +SELECT t3.id AS t3_id, t4.id AS t4_id, t3.val1, t4.val1, t3.val2, t4.val2 +FROM t3 +JOIN t4 ON (t3.val1 IS NOT DISTINCT FROM t4.val1) AND (t3.val2 IS NOT DISTINCT FROM t4.val2) +---- +1 1 10 10 100 100 +2 2 NULL NULL 200 200 +3 3 30 30 NULL NULL + +statement ok +drop table t0; + +statement ok +drop table t1; + +statement ok +drop table t2; + +statement ok +drop table t3; + +statement ok +drop table t4; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index ca86dbfcc3c16..96d2bad086e66 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -24,7 +24,7 @@ statement ok set datafusion.execution.target_partitions = 2; statement ok -set datafusion.execution.batch_size = 2; +set datafusion.execution.batch_size = 8192; statement ok set datafusion.explain.logical_plan_only = true; @@ -549,64 +549,64 @@ statement ok set datafusion.optimizer.repartition_joins = true query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1, t2 ORDER BY t1_id +SELECT t1_id, t1_name, t2_name FROM t1, t2 ORDER BY t1_id, t1_name, t2_name ---- -11 a z -11 a y -11 a x 11 a w -22 b z -22 b y -22 b x +11 a x +11 a y +11 a z 22 b w -33 c z -33 c y -33 c x +22 b x +22 b y +22 b z 33 c w -44 d z -44 d y -44 d x +33 c x +33 c y +33 c z 44 d w +44 d x +44 d y +44 d z query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE 1=1 ORDER BY t1_id +SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE 1=1 ORDER BY t1_id, t1_name, t2_name ---- -11 a z -11 a y -11 a x 11 a w -22 b z -22 b y -22 b x +11 a x +11 a y +11 a z 22 b w -33 c z -33 c y -33 c x +22 b x +22 b y +22 b z 33 c w -44 d z -44 d y -44 d x +33 c x +33 c y +33 c z 44 d w +44 d x +44 d y +44 d z query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id +SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id, t1_name, t2_name ---- -11 a z -11 a y -11 a x 11 a w -22 b z -22 b y -22 b x +11 a x +11 a y +11 a z 22 b w -33 c z -33 c y -33 c x +22 b x +22 b y +22 b z 33 c w -44 d z -44 d y -44 d x +33 c x +33 c y +33 c z 44 d w +44 d x +44 d y +44 d z query ITITI rowsort SELECT * FROM (SELECT t1_id, t1_name FROM t1 UNION ALL SELECT t1_id, t1_name FROM t1) AS t1 CROSS JOIN t2 @@ -685,64 +685,64 @@ statement ok set datafusion.optimizer.repartition_joins = false query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1, t2 ORDER BY t1_id +SELECT t1_id, t1_name, t2_name FROM t1, t2 ORDER BY t1_id, t1_name, t2_name ---- -11 a z -11 a y -11 a x 11 a w -22 b z -22 b y -22 b x +11 a x +11 a y +11 a z 22 b w -33 c z -33 c y -33 c x +22 b x +22 b y +22 b z 33 c w -44 d z -44 d y -44 d x +33 c x +33 c y +33 c z 44 d w +44 d x +44 d y +44 d z query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE 1=1 ORDER BY t1_id +SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE 1=1 ORDER BY t1_id, t1_name, t2_name ---- -11 a z -11 a y -11 a x 11 a w -22 b z -22 b y -22 b x +11 a x +11 a y +11 a z 22 b w -33 c z -33 c y -33 c x +22 b x +22 b y +22 b z 33 c w -44 d z -44 d y -44 d x +33 c x +33 c y +33 c z 44 d w +44 d x +44 d y +44 d z query ITT nosort -SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id +SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id, t1_name, t2_name ---- -11 a z -11 a y -11 a x 11 a w -22 b z -22 b y -22 b x +11 a x +11 a y +11 a z 22 b w -33 c z -33 c y -33 c x +22 b x +22 b y +22 b z 33 c w -44 d z -44 d y -44 d x +33 c x +33 c y +33 c z 44 d w +44 d x +44 d y +44 d z query ITITI rowsort SELECT * FROM (SELECT t1_id, t1_name FROM t1 UNION ALL SELECT t1_id, t1_name FROM t1) AS t1 CROSS JOIN t2 @@ -1067,9 +1067,9 @@ LEFT JOIN join_t2 on join_t1.t1_id = join_t2.t2_id WHERE join_t2.t2_int < 10 or (join_t1.t1_int > 2 and join_t2.t2_name != 'w') ---- logical_plan -01)Inner Join: join_t1.t1_id = join_t2.t2_id Filter: join_t2.t2_int < UInt32(10) OR join_t1.t1_int > UInt32(2) AND join_t2.t2_name != Utf8("w") +01)Inner Join: join_t1.t1_id = join_t2.t2_id Filter: join_t2.t2_int < UInt32(10) OR join_t1.t1_int > UInt32(2) AND join_t2.t2_name != Utf8View("w") 02)--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] -03)--Filter: join_t2.t2_int < UInt32(10) OR join_t2.t2_name != Utf8("w") +03)--Filter: join_t2.t2_int < UInt32(10) OR join_t2.t2_name != Utf8View("w") 04)----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] # Reduce left join 3 (to inner join) @@ -1153,7 +1153,7 @@ WHERE join_t1.t1_name != 'b' ---- logical_plan 01)Left Join: join_t1.t1_id = join_t2.t2_id -02)--Filter: join_t1.t1_name != Utf8("b") +02)--Filter: join_t1.t1_name != Utf8View("b") 03)----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] 04)--TableScan: join_t2 projection=[t2_id, t2_name, t2_int] @@ -1168,9 +1168,9 @@ WHERE join_t1.t1_name != 'b' and join_t2.t2_name = 'x' ---- logical_plan 01)Inner Join: join_t1.t1_id = join_t2.t2_id -02)--Filter: join_t1.t1_name != Utf8("b") +02)--Filter: join_t1.t1_name != Utf8View("b") 03)----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] -04)--Filter: join_t2.t2_name = Utf8("x") +04)--Filter: join_t2.t2_name = Utf8View("x") 05)----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] ### @@ -1373,7 +1373,7 @@ inner join join_t4 on join_t3.s3 = join_t4.s4 {id: 2} {id: 2} # join with struct key and nulls -# Note that intersect or except applies `null_equals_null` as true for Join. +# Note that intersect or except applies `null_equality` as `NullEquality::NullEqualsNull` for Join. query ? SELECT * FROM join_t3 EXCEPT @@ -2066,6 +2066,7 @@ SELECT join_t1.t1_id, join_t2.t2_id FROM join_t1 INNER JOIN join_t2 ON join_t1.t1_id > join_t2.t2_id WHERE join_t1.t1_id > 10 AND join_t2.t2_int > 1 +ORDER BY 1 ---- 22 11 33 11 @@ -2105,6 +2106,7 @@ SELECT join_t1.t1_id, join_t2.t2_id FROM (select t1_id from join_t1 where join_t1.t1_id > 22) as join_t1 RIGHT JOIN (select t2_id from join_t2 where join_t2.t2_id > 11) as join_t2 ON join_t1.t1_id < join_t2.t2_id +ORDER BY 1, 2 ---- 33 44 33 55 @@ -2151,6 +2153,7 @@ WHERE EXISTS ( FROM join_t2 WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 ) +ORDER BY 1 ---- 22 b 2 33 c 3 @@ -2167,6 +2170,7 @@ WHERE EXISTS ( FROM join_t2 WHERE join_t1.t1_id + 1 > join_t2.t2_id * 2 ) +ORDER BY 1 ---- 22 b 2 33 c 3 @@ -3195,7 +3199,7 @@ physical_plan 04)------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST, rn1@5 ASC NULLS LAST 05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 06)----------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] -07)------------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +07)------------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] 08)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], file_type=csv, has_header=true 09)----CoalesceBatchesExec: target_batch_size=2 10)------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST @@ -3233,7 +3237,7 @@ physical_plan 08)------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST, rn1@5 ASC NULLS LAST 09)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 10)----------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] -11)------------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +11)------------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] 12)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], file_type=csv, has_header=true statement ok @@ -3272,14 +3276,14 @@ physical_plan 06)----------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 08)--------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] -09)----------------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +09)----------------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] 10)------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], file_type=csv, has_header=true 11)------SortExec: expr=[a@1 ASC], preserve_partitioning=[true] 12)--------CoalesceBatchesExec: target_batch_size=2 13)----------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 14)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 15)--------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] -16)----------------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +16)----------------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] 17)------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], file_type=csv, has_header=true statement ok @@ -3314,7 +3318,7 @@ physical_plan 02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@1, a@1)] 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], file_type=csv, has_header=true 04)----ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] -05)------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +05)------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] 06)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], file_type=csv, has_header=true # hash join should propagate ordering equivalence of the right side for RIGHT ANTI join. @@ -3341,9 +3345,30 @@ physical_plan 02)--HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(a@0, a@1)] 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a], output_ordering=[a@0 ASC], file_type=csv, has_header=true 04)----ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] -05)------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +05)------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] 06)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], file_type=csv, has_header=true +# Test ordering preservation for RIGHT join +query TT +EXPLAIN SELECT * +FROM annotated_data as l_table +RIGHT JOIN (SELECT * FROM annotated_data) as r_table +ON l_table.b = r_table.b +ORDER BY r_table.a ASC NULLS FIRST, r_table.b, r_table.c, l_table.a ASC NULLS FIRST; +---- +logical_plan +01)Sort: r_table.a ASC NULLS FIRST, r_table.b ASC NULLS LAST, r_table.c ASC NULLS LAST, l_table.a ASC NULLS FIRST +02)--Right Join: l_table.b = r_table.b +03)----SubqueryAlias: l_table +04)------TableScan: annotated_data projection=[a0, a, b, c, d] +05)----SubqueryAlias: r_table +06)------TableScan: annotated_data projection=[a0, a, b, c, d] +physical_plan +01)CoalesceBatchesExec: target_batch_size=2 +02)--HashJoinExec: mode=CollectLeft, join_type=Right, on=[(b@2, b@2)] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], file_type=csv, has_header=true +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], file_type=csv, has_header=true + query TT EXPLAIN SELECT l.a, LAST_VALUE(r.b ORDER BY r.a ASC NULLS FIRST) as last_col1 FROM annotated_data as l @@ -3416,7 +3441,7 @@ physical_plan 04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10, projection=[a@0, d@1, row_n@4] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], file_type=csv, has_header=true 06)--------ProjectionExec: expr=[a@0 as a, d@1 as d, row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] -07)----------BoundedWindowAggExec: wdw=[row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +07)----------BoundedWindowAggExec: wdw=[row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 08)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], file_type=csv, has_header=true # run query above in multiple partitions @@ -3604,12 +3629,12 @@ logical_plan 02)--SubqueryAlias: a 03)----Union 04)------Projection: Int64(1) AS c, Int64(2) AS d -05)--------EmptyRelation +05)--------EmptyRelation: rows=1 06)------Projection: Int64(1) AS c, Int64(3) AS d -07)--------EmptyRelation +07)--------EmptyRelation: rows=1 08)--SubqueryAlias: rhs 09)----Projection: Int64(1) AS e, Int64(3) AS f -10)------EmptyRelation +10)------EmptyRelation: rows=1 physical_plan 01)ProjectionExec: expr=[c@2 as c, d@3 as d, e@0 as e, f@1 as f] 02)--CoalesceBatchesExec: target_batch_size=2 @@ -3647,12 +3672,12 @@ logical_plan 02)--SubqueryAlias: a 03)----Union 04)------Projection: Int64(1) AS c, Int64(2) AS d -05)--------EmptyRelation +05)--------EmptyRelation: rows=1 06)------Projection: Int64(1) AS c, Int64(3) AS d -07)--------EmptyRelation +07)--------EmptyRelation: rows=1 08)--SubqueryAlias: rhs 09)----Projection: Int64(1) AS e, Int64(3) AS f -10)------EmptyRelation +10)------EmptyRelation: rows=1 physical_plan 01)ProjectionExec: expr=[c@2 as c, d@3 as d, e@0 as e, f@1 as f] 02)--CoalesceBatchesExec: target_batch_size=2 @@ -3687,7 +3712,7 @@ EXPLAIN SELECT * FROM ( SELECT 1 as a WHERE 1=0 ) AS a INNER JOIN (SELECT 1 as a) AS b ON a.a=b.a; ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 # Inner join with empty right table query TT @@ -3695,7 +3720,7 @@ EXPLAIN SELECT * FROM ( SELECT 1 AS a ) AS a INNER JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 # Left join with empty left table query TT @@ -3703,7 +3728,7 @@ EXPLAIN SELECT * FROM ( SELECT 1 as a WHERE 1=0 ) AS a LEFT JOIN (SELECT 1 as a) AS b ON a.a=b.a; ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 # Left join with empty left and empty right table query TT @@ -3711,7 +3736,7 @@ EXPLAIN SELECT * FROM ( SELECT 1 as a WHERE 1=0 ) AS a LEFT JOIN (SELECT 1 as a WHERE 1=0) AS b ON a.a=b.a; ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 # Right join with empty right table query TT @@ -3719,7 +3744,7 @@ EXPLAIN SELECT * FROM ( SELECT 1 AS a ) AS a RIGHT JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 # Right join with empty right and empty left table query TT @@ -3727,7 +3752,7 @@ EXPLAIN SELECT * FROM ( SELECT 1 as a WHERE 1=0 ) AS a RIGHT JOIN (SELECT 1 as a WHERE 1=0) AS b ON a.a=b.a; ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 # Left SEMI join with empty left table query TT @@ -3735,7 +3760,7 @@ EXPLAIN SELECT * FROM ( SELECT 1 AS a ) AS a LEFT SEMI JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 # Left SEMI join with empty right table query TT @@ -3743,7 +3768,7 @@ EXPLAIN SELECT * FROM ( SELECT 1 AS a WHERE 1=0 ) AS a LEFT SEMI JOIN (SELECT 1 AS a) AS b ON a.a=b.a; ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 # Right SEMI join with empty left table query TT @@ -3751,7 +3776,7 @@ EXPLAIN SELECT * FROM ( SELECT 1 AS a WHERE 1=0 ) AS a RIGHT SEMI JOIN (SELECT 1 AS a) AS b ON a.a=b.a; ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 # Right SEMI join with empty right table query TT @@ -3759,7 +3784,7 @@ EXPLAIN SELECT * FROM ( SELECT 1 AS a ) AS a RIGHT SEMI JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 # Left ANTI join with empty left table query TT @@ -3767,7 +3792,7 @@ EXPLAIN SELECT * FROM ( SELECT 1 AS a WHERE 1=0 ) AS a LEFT ANTI JOIN (SELECT 1 AS a) AS b ON a.a=b.a; ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 # Right ANTI join with empty right table query TT @@ -3775,7 +3800,7 @@ EXPLAIN SELECT * FROM ( SELECT 1 AS a ) AS a RIGHT ANTI JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 # FULL OUTER join with empty left and empty right table query TT @@ -3783,7 +3808,7 @@ EXPLAIN SELECT * FROM ( SELECT 1 as a WHERE 1=0 ) AS a FULL JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 # Left ANTI join with empty right table query TT @@ -3794,7 +3819,7 @@ EXPLAIN SELECT * FROM ( logical_plan 01)SubqueryAlias: a 02)--Projection: Int64(1) AS a -03)----EmptyRelation +03)----EmptyRelation: rows=1 # Right ANTI join with empty left table query TT @@ -3805,7 +3830,7 @@ EXPLAIN SELECT * FROM ( logical_plan 01)SubqueryAlias: b 02)--Projection: Int64(1) AS a -03)----EmptyRelation +03)----EmptyRelation: rows=1 statement ok @@ -3871,8 +3896,8 @@ physical_plan 02)--CoalesceBatchesExec: target_batch_size=3 03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(b@1, b@1)] 04)------SortExec: TopK(fetch=10), expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] -05)--------DataSourceExec: partitions=1, partition_sizes=[1] -06)------DataSourceExec: partitions=1, partition_sizes=[1] +05)--------DataSourceExec: partitions=1, partition_sizes=[2] +06)------DataSourceExec: partitions=1, partition_sizes=[2] @@ -3928,8 +3953,8 @@ physical_plan 01)ProjectionExec: expr=[a@2 as a, b@3 as b, a@0 as a, b@1 as b] 02)--CoalesceBatchesExec: target_batch_size=3 03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(b@1, b@1)] -04)------DataSourceExec: partitions=1, partition_sizes=[1] -05)------DataSourceExec: partitions=1, partition_sizes=[1] +04)------DataSourceExec: partitions=1, partition_sizes=[2] +05)------DataSourceExec: partitions=1, partition_sizes=[2] # Null build indices: @@ -3989,8 +4014,8 @@ physical_plan 02)--CoalesceBatchesExec: target_batch_size=3 03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(b@1, b@1)] 04)------SortExec: TopK(fetch=10), expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] -05)--------DataSourceExec: partitions=1, partition_sizes=[1] -06)------DataSourceExec: partitions=1, partition_sizes=[1] +05)--------DataSourceExec: partitions=1, partition_sizes=[2] +06)------DataSourceExec: partitions=1, partition_sizes=[2] # Test CROSS JOIN LATERAL syntax (planning) @@ -4008,13 +4033,13 @@ logical_plan 08)----------Projection: __unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)),depth=1) AS UNNEST(generate_series(Int64(1),outer_ref(t1.t1_int))) 09)------------Unnest: lists[__unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)))|depth=1] structs[] 10)--------------Projection: generate_series(Int64(1), CAST(outer_ref(t1.t1_int) AS Int64)) AS __unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int))) -11)----------------EmptyRelation -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t1" }), name: "t1_int" }) +11)----------------EmptyRelation: rows=1 +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Field { name: "t1_int", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Column { relation: Some(Bare { table: "t1" }), name: "t1_int" }) # Test CROSS JOIN LATERAL syntax (execution) # TODO: https://github.com/apache/datafusion/issues/10048 -query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(UInt32, Column \{ relation: Some\(Bare \{ table: "t1" \}\), name: "t1_int" \}\) +query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(Field \{ name: "t1_int", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}, Column \{ relation: Some\(Bare \{ table: "t1" \}\), name: "t1_int" \}\) select t1_id, t1_name, i from join_t1 t1 cross join lateral (select * from unnest(generate_series(1, t1_int))) as series(i); @@ -4033,13 +4058,13 @@ logical_plan 08)----------Projection: __unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)),depth=1) AS UNNEST(generate_series(Int64(1),outer_ref(t2.t1_int))) 09)------------Unnest: lists[__unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)))|depth=1] structs[] 10)--------------Projection: generate_series(Int64(1), CAST(outer_ref(t2.t1_int) AS Int64)) AS __unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int))) -11)----------------EmptyRelation -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t2" }), name: "t1_int" }) +11)----------------EmptyRelation: rows=1 +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Field { name: "t1_int", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Column { relation: Some(Bare { table: "t2" }), name: "t1_int" }) # Test INNER JOIN LATERAL syntax (execution) # TODO: https://github.com/apache/datafusion/issues/10048 -query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(UInt32, Column \{ relation: Some\(Bare \{ table: "t2" \}\), name: "t1_int" \}\) +query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(Field \{ name: "t1_int", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}, Column \{ relation: Some\(Bare \{ table: "t2" \}\), name: "t1_int" \}\) select t1_id, t1_name, i from join_t1 t2 inner join lateral (select * from unnest(generate_series(1, t1_int))) as series(i) on(t1_id > i); # Test RIGHT JOIN LATERAL syntax (unsupported) @@ -4087,7 +4112,7 @@ logical_plan 07)------------TableScan: sales_global projection=[ts, sn, amount, currency] 08)----------SubqueryAlias: e 09)------------Projection: exchange_rates.ts, exchange_rates.currency_from, exchange_rates.rate -10)--------------Filter: exchange_rates.currency_to = Utf8("USD") +10)--------------Filter: exchange_rates.currency_to = Utf8View("USD") 11)----------------TableScan: exchange_rates projection=[ts, currency_from, currency_to, rate] physical_plan 01)SortExec: expr=[sn@1 ASC NULLS LAST], preserve_partitioning=[false] @@ -4123,9 +4148,9 @@ logical_plan 03)----TableScan: left_table projection=[a, b, c] 04)----TableScan: right_table projection=[x, y, z] physical_plan -01)NestedLoopJoinExec: join_type=Inner, filter=a@0 < x@1 -02)--DataSourceExec: partitions=1, partition_sizes=[0] -03)--SortExec: expr=[x@0 ASC NULLS LAST], preserve_partitioning=[false] +01)SortExec: expr=[x@3 ASC NULLS LAST], preserve_partitioning=[false] +02)--NestedLoopJoinExec: join_type=Inner, filter=a@0 < x@1 +03)----DataSourceExec: partitions=1, partition_sizes=[0] 04)----DataSourceExec: partitions=1, partition_sizes=[0] query TT @@ -4160,23 +4185,43 @@ AS VALUES (3, 3, true), (3, 3, false); -query IIIIB -SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 LIMIT 2; +query IIIIB rowsort +-- Note: using LIMIT value higher than cardinality before LIMIT to avoid query non-determinism +SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 LIMIT 20; ---- -2 2 2 2 true +1 1 NULL NULL NULL 2 2 2 2 false - -query IIIIB -SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 2; ----- 2 2 2 2 true -3 3 2 2 true +3 3 3 3 false +3 3 3 3 true +4 4 NULL NULL NULL -query IIIIB -SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 AND t0.c2 >= t1.c2 LIMIT 2; +query IIIIB rowsort +-- Note: using LIMIT value higher than cardinality before LIMIT to avoid query non-determinism +SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 20; ---- +1 1 NULL NULL NULL +2 2 2 2 false 2 2 2 2 true +3 3 2 2 false +3 3 2 2 true +3 3 3 3 false +3 3 3 3 true +4 4 2 2 false +4 4 2 2 true +4 4 3 3 false +4 4 3 3 true + +query IIIIB rowsort +-- Note: using LIMIT value higher than cardinality before LIMIT to avoid query non-determinism +SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 AND t0.c2 >= t1.c2 LIMIT 20; +---- +1 1 NULL NULL NULL 2 2 2 2 false +2 2 2 2 true +3 3 3 3 false +3 3 3 3 true +4 4 NULL NULL NULL ## Test !join.on.is_empty() && join.filter.is_none() query TT @@ -4190,8 +4235,8 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=3, fetch=2 02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)] -03)----DataSourceExec: partitions=1, partition_sizes=[1] -04)----DataSourceExec: partitions=1, partition_sizes=[1] +03)----DataSourceExec: partitions=1, partition_sizes=[2] +04)----DataSourceExec: partitions=1, partition_sizes=[2] ## Test join.on.is_empty() && join.filter.is_some() query TT @@ -4205,8 +4250,8 @@ logical_plan physical_plan 01)GlobalLimitExec: skip=0, fetch=2 02)--NestedLoopJoinExec: join_type=Full, filter=c2@0 >= c2@1 -03)----DataSourceExec: partitions=1, partition_sizes=[1] -04)----DataSourceExec: partitions=1, partition_sizes=[1] +03)----DataSourceExec: partitions=1, partition_sizes=[2] +04)----DataSourceExec: partitions=1, partition_sizes=[2] ## Test !join.on.is_empty() && join.filter.is_some() query TT @@ -4220,8 +4265,8 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=3, fetch=2 02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)], filter=c2@0 >= c2@1 -03)----DataSourceExec: partitions=1, partition_sizes=[1] -04)----DataSourceExec: partitions=1, partition_sizes=[1] +03)----DataSourceExec: partitions=1, partition_sizes=[2] +04)----DataSourceExec: partitions=1, partition_sizes=[2] ## Add more test cases for join limit pushdown statement ok @@ -4236,23 +4281,23 @@ set datafusion.execution.target_partitions = 1; # Note we use csv as MemoryExec does not support limit push down (so doesn't manifest # bugs if limits are improperly pushed down) query I -COPY (values (1), (2), (3), (4), (5)) TO 'test_files/scratch/limit/t1.csv' +COPY (values (1), (2), (3), (4), (5)) TO 'test_files/scratch/joins/t1.csv' STORED AS CSV ---- 5 # store t2 in different order so the top N rows are not the same as the top N rows of t1 query I -COPY (values (5), (4), (3), (2), (1)) TO 'test_files/scratch/limit/t2.csv' +COPY (values (5), (4), (3), (2), (1)) TO 'test_files/scratch/joins/t2.csv' STORED AS CSV ---- 5 statement ok -create external table t1(a int) stored as CSV location 'test_files/scratch/limit/t1.csv'; +create external table t1(a int) stored as CSV location 'test_files/scratch/joins/t1.csv'; statement ok -create external table t2(b int) stored as CSV location 'test_files/scratch/limit/t2.csv'; +create external table t2(b int) stored as CSV location 'test_files/scratch/joins/t2.csv'; ###### ## LEFT JOIN w/ LIMIT @@ -4284,8 +4329,8 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=3, fetch=2 02)--HashJoinExec: mode=CollectLeft, join_type=Left, on=[(a@0, b@0)] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/limit/t1.csv]]}, projection=[a], limit=2, file_type=csv, has_header=true -04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/limit/t2.csv]]}, projection=[b], file_type=csv, has_header=true +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], limit=2, file_type=csv, has_header=true +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, projection=[b], file_type=csv, has_header=true ###### ## RIGHT JOIN w/ LIMIT @@ -4318,8 +4363,8 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=3, fetch=2 02)--HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@0, b@0)] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/limit/t1.csv]]}, projection=[a], file_type=csv, has_header=true -04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/limit/t2.csv]]}, projection=[b], limit=2, file_type=csv, has_header=true +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], file_type=csv, has_header=true +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, projection=[b], limit=2, file_type=csv, has_header=true ###### ## FULL JOIN w/ LIMIT @@ -4355,8 +4400,8 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=3, fetch=2 02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(a@0, b@0)] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/limit/t1.csv]]}, projection=[a], file_type=csv, has_header=true -04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/limit/t2.csv]]}, projection=[b], file_type=csv, has_header=true +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t1.csv]]}, projection=[a], file_type=csv, has_header=true +04)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/joins/t2.csv]]}, projection=[b], file_type=csv, has_header=true statement ok drop table t1; @@ -4385,7 +4430,7 @@ JOIN my_catalog.my_schema.table_with_many_types AS r ON l.binary_col = r.binary_ logical_plan 01)Projection: count(Int64(1)) AS count(*) 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -03)----Projection: +03)----Projection: 04)------Inner Join: l.binary_col = r.binary_col 05)--------SubqueryAlias: l 06)----------TableScan: my_catalog.my_schema.table_with_many_types projection=[binary_col] @@ -4429,11 +4474,9 @@ physical_plan 04)------CoalescePartitionsExec 05)--------CoalesceBatchesExec: target_batch_size=3 06)----------FilterExec: b@1 > 3, projection=[a@0] -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------DataSourceExec: partitions=1, partition_sizes=[1] -09)------SortExec: expr=[c@2 DESC], preserve_partitioning=[true] -10)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -11)----------DataSourceExec: partitions=1, partition_sizes=[1] +07)------------DataSourceExec: partitions=2, partition_sizes=[1, 1] +08)------SortExec: expr=[c@2 DESC], preserve_partitioning=[true] +09)--------DataSourceExec: partitions=2, partition_sizes=[1, 1] query TT explain select * from test where a in (select a from test where b > 3) order by c desc nulls last; @@ -4453,11 +4496,9 @@ physical_plan 04)------CoalescePartitionsExec 05)--------CoalesceBatchesExec: target_batch_size=3 06)----------FilterExec: b@1 > 3, projection=[a@0] -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------DataSourceExec: partitions=1, partition_sizes=[1] -09)------SortExec: expr=[c@2 DESC NULLS LAST], preserve_partitioning=[true] -10)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -11)----------DataSourceExec: partitions=1, partition_sizes=[1] +07)------------DataSourceExec: partitions=2, partition_sizes=[1, 1] +08)------SortExec: expr=[c@2 DESC NULLS LAST], preserve_partitioning=[true] +09)--------DataSourceExec: partitions=2, partition_sizes=[1, 1] query III select * from test where a in (select a from test where b > 3) order by c desc nulls first; @@ -4628,7 +4669,7 @@ logical_plan 05)------Subquery: 06)--------Filter: outer_ref(j1.j1_id) < j2.j2_id 07)----------TableScan: j2 projection=[j2_string, j2_id] -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Int32, Column { relation: Some(Bare { table: "j1" }), name: "j1_id" }) +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Field { name: "j1_id", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Column { relation: Some(Bare { table: "j1" }), name: "j1_id" }) query TT explain SELECT * FROM j1 JOIN (j2 JOIN j3 ON(j2_id = j3_id - 2)) ON(j1_id = j2_id), LATERAL (SELECT * FROM j3 WHERE j3_string = j2_string) as j4 @@ -4644,7 +4685,7 @@ logical_plan 08)----Subquery: 09)------Filter: j3.j3_string = outer_ref(j2.j2_string) 10)--------TableScan: j3 projection=[j3_string, j3_id] -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Utf8, Column { relation: Some(Bare { table: "j2" }), name: "j2_string" }) +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Field { name: "j2_string", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Column { relation: Some(Bare { table: "j2" }), name: "j2_string" }) query TT explain SELECT * FROM j1, LATERAL (SELECT * FROM j1, LATERAL (SELECT * FROM j2 WHERE j1_id = j2_id) as j2) as j2; @@ -4660,7 +4701,7 @@ logical_plan 08)----------Subquery: 09)------------Filter: outer_ref(j1.j1_id) = j2.j2_id 10)--------------TableScan: j2 projection=[j2_string, j2_id] -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Int32, Column { relation: Some(Bare { table: "j1" }), name: "j1_id" }) +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Field { name: "j1_id", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Column { relation: Some(Bare { table: "j1" }), name: "j1_id" }) query TT explain SELECT j1_string, j2_string FROM j1 LEFT JOIN LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2 ON(true); @@ -4673,7 +4714,7 @@ logical_plan 05)------Subquery: 06)--------Filter: outer_ref(j1.j1_id) < j2.j2_id 07)----------TableScan: j2 projection=[j2_string, j2_id] -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Int32, Column { relation: Some(Bare { table: "j1" }), name: "j1_id" }) +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Field { name: "j1_id", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Column { relation: Some(Bare { table: "j1" }), name: "j1_id" }) query TT explain SELECT * FROM j1, (j2 LEFT JOIN LATERAL (SELECT * FROM j3 WHERE j1_id + j2_id = j3_id) AS j3 ON(true)); @@ -4687,7 +4728,7 @@ logical_plan 06)------Subquery: 07)--------Filter: outer_ref(j1.j1_id) + outer_ref(j2.j2_id) = j3.j3_id 08)----------TableScan: j3 projection=[j3_string, j3_id] -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Int32, Column { relation: Some(Bare { table: "j1" }), name: "j1_id" }) +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Field { name: "j1_id", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Column { relation: Some(Bare { table: "j1" }), name: "j1_id" }) query TT explain SELECT * FROM j1, LATERAL (SELECT 1) AS j2; @@ -4697,7 +4738,7 @@ logical_plan 02)--TableScan: j1 projection=[j1_string, j1_id] 03)--SubqueryAlias: j2 04)----Projection: Int64(1) -05)------EmptyRelation +05)------EmptyRelation: rows=1 physical_plan 01)CrossJoinExec 02)--DataSourceExec: partitions=1, partition_sizes=[0] @@ -4742,3 +4783,419 @@ drop table person; statement count 0 drop table orders; + +# Create tables for testing compound field access in JOIN conditions +statement ok +CREATE TABLE compound_field_table_t +AS VALUES +({r: 'a', c: 1}), +({r: 'b', c: 2.3}); + +statement ok +CREATE TABLE compound_field_table_u +AS VALUES +({r: 'a', c: 1}), +({r: 'b', c: 2.3}); + +# Test compound field access in JOIN condition with table aliases +query ?? +SELECT * FROM compound_field_table_t tee JOIN compound_field_table_u you ON tee.column1['r'] = you.column1['r'] +---- +{r: a, c: 1.0} {r: a, c: 1.0} +{r: b, c: 2.3} {r: b, c: 2.3} + +# Test compound field access in JOIN condition without table aliases +query ?? +SELECT * FROM compound_field_table_t JOIN compound_field_table_u ON compound_field_table_t.column1['r'] = compound_field_table_u.column1['r'] +---- +{r: a, c: 1.0} {r: a, c: 1.0} +{r: b, c: 2.3} {r: b, c: 2.3} + +# Test compound field access with numeric field access +query ?? +SELECT * FROM compound_field_table_t tee JOIN compound_field_table_u you ON tee.column1['c'] = you.column1['c'] +---- +{r: a, c: 1.0} {r: a, c: 1.0} +{r: b, c: 2.3} {r: b, c: 2.3} + +# Test compound field access with mixed field types +query ?? +SELECT * FROM compound_field_table_t tee JOIN compound_field_table_u you ON tee.column1['r'] = you.column1['r'] AND tee.column1['c'] = you.column1['c'] +---- +{r: a, c: 1.0} {r: a, c: 1.0} +{r: b, c: 2.3} {r: b, c: 2.3} + +# Clean up compound field tables +statement ok +DROP TABLE compound_field_table_t; + +statement ok +DROP TABLE compound_field_table_u; + + +statement ok +CREATE TABLE t1 (k INT, v INT); + +statement ok +CREATE TABLE t2 (k INT, v INT); + +statement ok +INSERT INTO t1 + SELECT value AS k, value AS v + FROM range(1, 10001) AS t(value); + +statement ok +INSERT INTO t2 VALUES (1, 1); + +## The TopK(Sort with fetch) should not be pushed down to the hash join +query TT +explain +SELECT * +FROM t1 +LEFT ANTI JOIN t2 ON t1.k = t2.k +ORDER BY t1.k +LIMIT 2; +---- +logical_plan +01)Sort: t1.k ASC NULLS LAST, fetch=2 +02)--LeftAnti Join: t1.k = t2.k +03)----TableScan: t1 projection=[k, v] +04)----TableScan: t2 projection=[k] +physical_plan +01)SortExec: TopK(fetch=2), expr=[k@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(k@0, k@0)] +04)------DataSourceExec: partitions=1, partition_sizes=[1] +05)------DataSourceExec: partitions=1, partition_sizes=[3334] + + +query II +SELECT * +FROM t1 +LEFT ANTI JOIN t2 ON t1.k = t2.k +ORDER BY t1.k +LIMIT 2; +---- +2 2 +3 3 + + +## Test left anti join without limit, we should support push down sort to the left side +query TT +explain +SELECT * +FROM t1 +LEFT ANTI JOIN t2 ON t1.k = t2.k +ORDER BY t1.k; +---- +logical_plan +01)Sort: t1.k ASC NULLS LAST +02)--LeftAnti Join: t1.k = t2.k +03)----TableScan: t1 projection=[k, v] +04)----TableScan: t2 projection=[k] +physical_plan +01)CoalesceBatchesExec: target_batch_size=3 +02)--HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(k@0, k@0)] +03)----DataSourceExec: partitions=1, partition_sizes=[1] +04)----SortExec: expr=[k@0 ASC NULLS LAST], preserve_partitioning=[false] +05)------DataSourceExec: partitions=1, partition_sizes=[3334] + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; + + +# Test hash joins with an empty build relation (empty build relation optimization) + +statement ok +CREATE TABLE t1 (k1 int, v1 int); + +statement ok +CREATE TABLE t2 (k2 int, v2 int); + +statement ok +INSERT INTO t1 SELECT i AS k, 1 FROM generate_series(1, 30000) t(i); + +statement ok +set datafusion.explain.physical_plan_only = true; + +# INNER JOIN +query TT +EXPLAIN +SELECT * +FROM t1 +JOIN t2 ON k1 = k2 +---- +physical_plan +01)ProjectionExec: expr=[k1@2 as k1, v1@3 as v1, k2@0 as k2, v2@1 as v2] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(k2@0, k1@0)] +04)------DataSourceExec: partitions=1, partition_sizes=[0] +05)------DataSourceExec: partitions=1, partition_sizes=[10000] + +query IIII +SELECT sum(k1), sum(v1), sum(k2), sum(v2) +FROM t1 +JOIN t2 ON k1 = k2 +---- +NULL NULL NULL NULL + +# LEFT JOIN +query TT +EXPLAIN +SELECT * +FROM t1 +LEFT JOIN t2 ON k1 = k2 +---- +physical_plan +01)ProjectionExec: expr=[k1@2 as k1, v1@3 as v1, k2@0 as k2, v2@1 as v2] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=Right, on=[(k2@0, k1@0)] +04)------DataSourceExec: partitions=1, partition_sizes=[0] +05)------DataSourceExec: partitions=1, partition_sizes=[10000] + +query IIII +SELECT sum(k1), sum(v1), sum(k2), sum(v2) +FROM t1 +LEFT JOIN t2 ON k1 = k2 +---- +450015000 30000 NULL NULL + +# RIGHT JOIN +query TT +EXPLAIN +SELECT * +FROM t1 +RIGHT JOIN t2 ON k1 = k2 +---- +physical_plan +01)ProjectionExec: expr=[k1@2 as k1, v1@3 as v1, k2@0 as k2, v2@1 as v2] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(k2@0, k1@0)] +04)------DataSourceExec: partitions=1, partition_sizes=[0] +05)------DataSourceExec: partitions=1, partition_sizes=[10000] + +query IIII +SELECT sum(k1), sum(v1), sum(k2), sum(v2) +FROM t1 +RIGHT JOIN t2 ON k1 = k2 +---- +NULL NULL NULL NULL + +# LEFT SEMI JOIN +query TT +EXPLAIN +SELECT * +FROM t1 +LEFT SEMI JOIN t2 ON k1 = k2 +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=3 +02)--HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(k2@0, k1@0)] +03)----DataSourceExec: partitions=1, partition_sizes=[0] +04)----DataSourceExec: partitions=1, partition_sizes=[10000] + +query II +SELECT sum(k1), sum(v1) +FROM t1 +LEFT SEMI JOIN t2 ON k1 = k2 +---- +NULL NULL + +# RIGHT SEMI JOIN +query TT +EXPLAIN +SELECT * +FROM t1 +RIGHT SEMI JOIN t2 ON k1 = k2 +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=3 +02)--HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(k2@0, k1@0)] +03)----DataSourceExec: partitions=1, partition_sizes=[0] +04)----DataSourceExec: partitions=1, partition_sizes=[10000] + +query II +SELECT sum(k2), sum(v2) +FROM t1 +RIGHT SEMI JOIN t2 ON k1 = k2 +---- +NULL NULL + +# LEFT ANTI JOIN +query TT +EXPLAIN +SELECT * +FROM t1 +LEFT ANTI JOIN t2 ON k1 = k2 +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=3 +02)--HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(k2@0, k1@0)] +03)----DataSourceExec: partitions=1, partition_sizes=[0] +04)----DataSourceExec: partitions=1, partition_sizes=[10000] + +query II +SELECT sum(k1), sum(v1) +FROM t1 +LEFT ANTI JOIN t2 ON k1 = k2 +---- +450015000 30000 + +# RIGHT ANTI JOIN +query TT +EXPLAIN +SELECT * +FROM t1 +RIGHT ANTI JOIN t2 ON k1 = k2 +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=3 +02)--HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(k2@0, k1@0)] +03)----DataSourceExec: partitions=1, partition_sizes=[0] +04)----DataSourceExec: partitions=1, partition_sizes=[10000] + +query II +SELECT sum(k2), sum(v2) +FROM t1 +RIGHT ANTI JOIN t2 ON k1 = k2 +---- +NULL NULL + +# FULL JOIN +query TT +EXPLAIN +SELECT * +FROM t1 +FULL JOIN t2 ON k1 = k2 +---- +physical_plan +01)ProjectionExec: expr=[k1@2 as k1, v1@3 as v1, k2@0 as k2, v2@1 as v2] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=Full, on=[(k2@0, k1@0)] +04)------DataSourceExec: partitions=1, partition_sizes=[0] +05)------DataSourceExec: partitions=1, partition_sizes=[10000] + +query IIII +SELECT sum(k1), sum(v1), sum(k2), sum(v2) +FROM t1 +FULL JOIN t2 ON k1 = k2 +---- +450015000 30000 NULL NULL + +# LEFT MARK JOIN +query TT +EXPLAIN +SELECT * +FROM t2 +WHERE k2 > 0 + OR EXISTS ( + SELECT * + FROM t1 + WHERE k2 = k1 + ) +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=3 +02)--FilterExec: k2@0 > 0 OR mark@2, projection=[k2@0, v2@1] +03)----CoalesceBatchesExec: target_batch_size=3 +04)------HashJoinExec: mode=CollectLeft, join_type=LeftMark, on=[(k2@0, k1@0)] +05)--------DataSourceExec: partitions=1, partition_sizes=[0] +06)--------DataSourceExec: partitions=1, partition_sizes=[10000] + +query II +SELECT * +FROM t2 +WHERE k2 > 0 + OR EXISTS ( + SELECT * + FROM t1 + WHERE k2 = k1 + ) +---- + +# Projection inside the join (changes the output schema) +query TT +EXPLAIN +SELECT distinct(v1) +FROM t1 +LEFT ANTI JOIN t2 ON k1 = k2 +---- +physical_plan +01)AggregateExec: mode=Single, gby=[v1@0 as v1], aggr=[] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(k2@0, k1@0)], projection=[v1@1] +04)------DataSourceExec: partitions=1, partition_sizes=[0] +05)------DataSourceExec: partitions=1, partition_sizes=[10000] + +query I +SELECT distinct(v1) +FROM t1 +LEFT ANTI JOIN t2 ON k1 = k2 +---- +1 + +# Both sides empty +query TT +EXPLAIN +SELECT * +FROM t1 +LEFT ANTI JOIN t2 ON k1 = k2 +WHERE k1 < 0 +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=3 +02)--HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(k2@0, k1@0)] +03)----DataSourceExec: partitions=1, partition_sizes=[0] +04)----CoalesceBatchesExec: target_batch_size=3 +05)------FilterExec: k1@0 < 0 +06)--------DataSourceExec: partitions=1, partition_sizes=[10000] + +query II +SELECT * +FROM t1 +LEFT ANTI JOIN t2 ON k1 = k2 +WHERE k1 < 0 +---- + +# Mark testing +statement ok +CREATE OR REPLACE TABLE t1(b INT, c INT, d INT); + +statement ok +INSERT INTO t1 VALUES + (10, 5, 3), + ( 1, 7, 8), + ( 2, 9, 7), + ( 3, 8,10), + ( 5, 6, 6), + ( 0, 4, 9), + ( 4, 8, 7), + (100,6, 5); + +query I rowsort +SELECT c + FROM t1 + WHERE c > d + OR EXISTS(SELECT 1 FROM t1 AS x WHERE x.b= d+2) +---- +4 +5 +6 +6 +7 +8 +8 +9 + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; + +statement ok +set datafusion.explain.physical_plan_only = false; diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 93ffa313b8f70..ae82aee5e1559 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -365,7 +365,7 @@ EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6); logical_plan 01)Projection: count(Int64(1)) AS count(*) 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -03)----Projection: +03)----Projection: 04)------Limit: skip=6, fetch=3 05)--------Filter: t1.a > Int32(3) 06)----------TableScan: t1 projection=[a] @@ -663,15 +663,14 @@ logical_plan physical_plan 01)GlobalLimitExec: skip=4, fetch=10 02)--SortPreservingMergeExec: [c@0 DESC], fetch=14 -03)----UnionExec -04)------SortExec: TopK(fetch=14), expr=[c@0 DESC], preserve_partitioning=[true] +03)----SortExec: TopK(fetch=14), expr=[c@0 DESC], preserve_partitioning=[true] +04)------UnionExec 05)--------ProjectionExec: expr=[CAST(c@0 AS Int64) as c] 06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], file_type=csv, has_header=true -08)------SortExec: TopK(fetch=14), expr=[c@0 DESC], preserve_partitioning=[true] -09)--------ProjectionExec: expr=[CAST(d@0 AS Int64) as c] -10)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -11)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[d], file_type=csv, has_header=true +08)--------ProjectionExec: expr=[CAST(d@0 AS Int64) as c] +09)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[d], file_type=csv, has_header=true # Applying LIMIT & OFFSET to subquery. query III @@ -800,7 +799,7 @@ CREATE TABLE src_table ( # File 1: query I COPY (SELECT * FROM src_table where part_key = 1) -TO 'test_files/scratch/parquet/test_limit_with_partitions/part-0.parquet' +TO 'test_files/scratch/limit/test_limit_with_partitions/part-0.parquet' STORED AS PARQUET; ---- 3 @@ -808,7 +807,7 @@ STORED AS PARQUET; # File 2: query I COPY (SELECT * FROM src_table where part_key = 2) -TO 'test_files/scratch/parquet/test_limit_with_partitions/part-1.parquet' +TO 'test_files/scratch/limit/test_limit_with_partitions/part-1.parquet' STORED AS PARQUET; ---- 4 @@ -816,7 +815,7 @@ STORED AS PARQUET; # File 3: query I COPY (SELECT * FROM src_table where part_key = 3) -TO 'test_files/scratch/parquet/test_limit_with_partitions/part-2.parquet' +TO 'test_files/scratch/limit/test_limit_with_partitions/part-2.parquet' STORED AS PARQUET; ---- 3 @@ -828,13 +827,14 @@ CREATE EXTERNAL TABLE test_limit_with_partitions value INT ) STORED AS PARQUET -LOCATION 'test_files/scratch/parquet/test_limit_with_partitions/'; +LOCATION 'test_files/scratch/limit/test_limit_with_partitions/'; query TT explain with selection as ( select * from test_limit_with_partitions + order by part_key limit 1 ) select 1 as foo @@ -847,19 +847,19 @@ logical_plan 02)--Sort: selection.part_key ASC NULLS LAST, fetch=1000 03)----Projection: Int64(1) AS foo, selection.part_key 04)------SubqueryAlias: selection -05)--------Limit: skip=0, fetch=1 -06)----------TableScan: test_limit_with_partitions projection=[part_key], fetch=1 +05)--------Sort: test_limit_with_partitions.part_key ASC NULLS LAST, fetch=1 +06)----------TableScan: test_limit_with_partitions projection=[part_key] physical_plan -01)ProjectionExec: expr=[foo@0 as foo] -02)--SortExec: TopK(fetch=1000), expr=[part_key@1 ASC NULLS LAST], preserve_partitioning=[false] -03)----ProjectionExec: expr=[1 as foo, part_key@0 as part_key] -04)------CoalescePartitionsExec: fetch=1 -05)--------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-0.parquet:0..794], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-1.parquet:0..794], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-2.parquet:0..794]]}, projection=[part_key], limit=1, file_type=parquet +01)ProjectionExec: expr=[1 as foo] +02)--SortPreservingMergeExec: [part_key@0 ASC NULLS LAST], fetch=1 +03)----SortExec: TopK(fetch=1), expr=[part_key@0 ASC NULLS LAST], preserve_partitioning=[true] +04)------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/limit/test_limit_with_partitions/part-0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/limit/test_limit_with_partitions/part-1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/limit/test_limit_with_partitions/part-2.parquet]]}, projection=[part_key], file_type=parquet, predicate=DynamicFilter [ empty ] query I with selection as ( select * from test_limit_with_partitions + order by part_key limit 1 ) select 1 as foo diff --git a/datafusion/sqllogictest/test_files/listing_table_partitions.slt b/datafusion/sqllogictest/test_files/listing_table_partitions.slt new file mode 100644 index 0000000000000..52433429cfe80 --- /dev/null +++ b/datafusion/sqllogictest/test_files/listing_table_partitions.slt @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query I +copy (values('foo'), ('bar')) +to 'test_files/scratch/listing_table_partitions/single_part/a=1/file1.parquet'; +---- +2 + +query I +copy (values('baz')) +to 'test_files/scratch/listing_table_partitions/single_part/a=1/file2.parquet'; +---- +1 + +statement count 0 +create external table single_part +stored as parquet location 'test_files/scratch/listing_table_partitions/single_part/'; + +query TT +select * from single_part order by (column1); +---- +bar 1 +baz 1 +foo 1 + +query I +copy (values('foo'), ('bar')) to 'test_files/scratch/listing_table_partitions/multi_part/a=1/b=100/file1.parquet'; +---- +2 + +query I +copy (values('baz')) to 'test_files/scratch/listing_table_partitions/multi_part/a=1/b=200/file1.parquet'; +---- +1 + +statement count 0 +create external table multi_part +stored as parquet location 'test_files/scratch/listing_table_partitions/multi_part/'; + +query TTT +select * from multi_part where b=200; +---- +baz 1 200 + +statement count 0 +set datafusion.execution.listing_table_factory_infer_partitions = false; + +statement count 0 +create external table infer_disabled +stored as parquet location 'test_files/scratch/listing_table_partitions/multi_part/'; + +query T +select * from infer_disabled order by (column1); +---- +bar +baz +foo + +statement count 0 +set datafusion.execution.listing_table_factory_infer_partitions = true; diff --git a/datafusion/sqllogictest/test_files/listing_table_statistics.slt b/datafusion/sqllogictest/test_files/listing_table_statistics.slt new file mode 100644 index 0000000000000..37daf551c2c39 --- /dev/null +++ b/datafusion/sqllogictest/test_files/listing_table_statistics.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test file with different schema order but generating correct statistics for table +statement ok +COPY (SELECT * FROM values (1, 'a'), (2, 'b') t(int_col, str_col)) to 'test_files/scratch/listing_table_statistics/1.parquet'; + +statement ok +COPY (SELECT * FROM values ('c', 3), ('d', -1) t(str_col, int_col)) to 'test_files/scratch/listing_table_statistics/2.parquet'; + +statement ok +set datafusion.execution.collect_statistics = true; + +statement ok +set datafusion.explain.show_statistics = true; + +statement ok +create external table t stored as parquet location 'test_files/scratch/listing_table_statistics'; + +query TT +explain format indent select * from t; +---- +logical_plan TableScan: t projection=[int_col, str_col] +physical_plan DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/listing_table_statistics/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/listing_table_statistics/2.parquet]]}, projection=[int_col, str_col], file_type=parquet, statistics=[Rows=Exact(4), Bytes=Exact(212), [(Col[0]: Min=Exact(Int64(-1)) Max=Exact(Int64(3)) Null=Exact(0)),(Col[1]: Min=Exact(Utf8View("a")) Max=Exact(Utf8View("d")) Null=Exact(0))]] + +statement ok +drop table t; + +statement ok +set datafusion.execution.collect_statistics = false; + +statement ok +set datafusion.explain.show_statistics = false; diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 42a4ba6218016..4f1e5ef39a00d 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -651,6 +651,57 @@ select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) [NULL] [[4, NULL, 6]] [NULL] [NULL] [NULL] [[1, NULL, 3]] +# Tests for map_entries + +query ? +SELECT map_entries(MAP { 'a': 1, 'b': 3 }); +---- +[{key: a, value: 1}, {key: b, value: 3}] + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type +SELECT map_entries(MAP { 'a': 1, 2: 3 }); + +query ? +SELECT map_entries(MAP {'a':1, 'b':2, 'c':3 }) FROM t; +---- +[{key: a, value: 1}, {key: b, value: 2}, {key: c, value: 3}] +[{key: a, value: 1}, {key: b, value: 2}, {key: c, value: 3}] +[{key: a, value: 1}, {key: b, value: 2}, {key: c, value: 3}] + +query ? +SELECT map_entries(Map{column1: column2, column3: column4}) FROM t; +---- +[{key: a, value: 1}, {key: k1, value: 10}] +[{key: b, value: 2}, {key: k3, value: 30}] +[{key: d, value: 4}, {key: k5, value: 50}] + +query ? +SELECT map_entries(map(column5, column6)) FROM t; +---- +[{key: k1, value: 1}, {key: k2, value: 2}] +[{key: k3, value: 3}] +[{key: k5, value: 5}] + +query ? +SELECT map_entries(map(column8, column9)) FROM t; +---- +[{key: [1, 2, 3], value: a}] +[{key: [4], value: b}] +[{key: [1, 2], value: c}] + +query ? +SELECT map_entries(Map{}); +---- +[] + +query ? +SELECT map_entries(column1) from map_array_table_1; +---- +[{key: 1, value: [1, NULL, 3]}, {key: 2, value: [4, NULL, 6]}, {key: 3, value: [7, 8, 9]}] +[{key: 4, value: [1, NULL, 3]}, {key: 5, value: [4, NULL, 6]}, {key: 6, value: [7, 8, 9]}] +[{key: 7, value: [1, NULL, 3]}, {key: 8, value: [9, NULL, 6]}, {key: 9, value: [7, 8, 9]}] +NULL + # Tests for map_keys query ? @@ -782,5 +833,12 @@ select column3[true] from tt; ---- 3 +# https://github.com/apache/datafusion/issues/16187 +# should be NULL in case of out of bounds for Null Type +query ? +select map_values(map([named_struct('a', 1, 'b', null)], [named_struct('a', 1, 'b', null)]))[0] as a; +---- +NULL + statement ok drop table tt; diff --git a/datafusion/sqllogictest/test_files/min_max/fixed_size_list.slt b/datafusion/sqllogictest/test_files/min_max/fixed_size_list.slt new file mode 100644 index 0000000000000..aa623b63cdc72 --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/fixed_size_list.slt @@ -0,0 +1,133 @@ +# Min/Max with FixedSizeList over integers +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)')), +(arrow_cast(make_array(1, 2), 'FixedSizeList(2, Int64)')); +---- +[1, 2] [1, 2, 3, 4] + +# Min/Max with FixedSizeList over strings +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array('a', 'b', 'c'), 'FixedSizeList(3, Utf8)')), +(arrow_cast(make_array('a', 'b'), 'LargeList(Utf8)')); +---- +[a, b] [a, b, c] + +# Min/Max with FixedSizeList over booleans +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(true, false, true), 'FixedSizeList(3, Boolean)')), +(arrow_cast(make_array(true, false), 'FixedSizeList(2, Boolean)')); +---- +[true, false] [true, false, true] + +# Min/Max with FixedSizeList over nullable integers +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(NULL, 1, 2), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(1, 2), 'FixedSizeList(2, Int64)')); +---- +[1, 2] [NULL, 1, 2] + +# Min/Max FixedSizeList with different lengths and nulls +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)')), +(arrow_cast(make_array(1, 2), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(1, NULL, 3), 'FixedSizeList(3, Int64)')); +---- +[1, 2] [1, NULL, 3] + +# Min/Max FixedSizeList with only NULLs +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(NULL, NULL), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')); +---- +[NULL] [NULL, NULL] + + +# Min/Max FixedSizeList of varying types (integers and NULLs) +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(NULL, 2, 3), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(1, 2, NULL), 'FixedSizeList(3, Int64)')); +---- +[1, 2, 3] [NULL, 2, 3] + +# Min/Max FixedSizeList grouped by key with NULLs and differing lengths +query I?? rowsort +SELECT column1, MIN(column2), MAX(column2) FROM VALUES +(0, arrow_cast(make_array(1, NULL, 3), 'FixedSizeList(3, Int64)')), +(0, arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)')), +(1, arrow_cast(make_array(1, 2), 'FixedSizeList(2, Int64)')), +(1, arrow_cast(make_array(NULL, 5), 'FixedSizeList(2, Int64)')) +GROUP BY column1; +---- +0 [1, 2, 3, 4] [1, NULL, 3] +1 [1, 2] [NULL, 5] + +# Min/Max FixedSizeList grouped by key with NULLs and differing lengths +query I?? rowsort +SELECT column1, MIN(column2), MAX(column2) FROM VALUES +(0, arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')), +(0, arrow_cast(make_array(NULL, NULL), 'FixedSizeList(2, Int64)')), +(1, arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')) +GROUP BY column1; +---- +0 [NULL] [NULL, NULL] +1 [NULL] [NULL] + +# Min/Max grouped FixedSizeList with empty and non-empty +query I?? rowsort +SELECT column1, MIN(column2), MAX(column2) FROM VALUES +(0, arrow_cast(make_array(1), 'FixedSizeList(1, Int64)')), +(1, arrow_cast(make_array(5, 6), 'FixedSizeList(2, Int64)')) +GROUP BY column1; +---- +0 [1] [1] +1 [5, 6] [5, 6] + +# Min/Max over FixedSizeList with a window function +query ? +SELECT min(column1) OVER (ORDER BY column1) FROM VALUES +(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)')) +---- +[1, 2, 3] +[1, 2, 3] +[1, 2, 3] + +# Min/Max over FixedSizeList with a window function and nulls +query ? +SELECT min(column1) OVER (ORDER BY column1) FROM VALUES +(arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')), +(arrow_cast(make_array(4, 5), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)')) +---- +[2, 3] +[2, 3] +[2, 3] + +# Min/Max over FixedSizeList with a window function, nulls and ROWS BETWEEN statement +query ? +SELECT min(column1) OVER (ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM VALUES +(arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')), +(arrow_cast(make_array(4, 5), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)')) +---- +[2, 3] +[2, 3] +[4, 5] + +# Min/Max over FixedSizeList with a window function using a different column +query ? +SELECT max(column2) OVER (ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM VALUES +(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), arrow_cast(make_array(4, 5), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)'), arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)')) +---- +[4, 5] +[4, 5] diff --git a/datafusion/sqllogictest/test_files/min_max/init_data.slt.part b/datafusion/sqllogictest/test_files/min_max/init_data.slt.part new file mode 100644 index 0000000000000..57e14f6993d46 --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/init_data.slt.part @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# -------------------------------------- +# 1. Min/Max over integers +# -------------------------------------- +statement ok +create table min_max_base_int as values + (make_array(1, 2, 3, 4)), + (make_array(1, 2)) +; + +# -------------------------------------- +# 2. Min/Max over strings +# -------------------------------------- +statement ok +create table min_max_base_string as values + (make_array('a', 'b', 'c')), + (make_array('a', 'b')) +; + +# -------------------------------------- +# 3. Min/Max over booleans +# -------------------------------------- +statement ok +create table min_max_base_bool as values + (make_array(true, false, true)), + (make_array(true, false)) +; + +# -------------------------------------- +# 4. Min/Max over nullable integers +# -------------------------------------- +statement ok +create table min_max_base_nullable_int as values + (make_array(NULL, 1, 2)), + (make_array(1, 2)) +; + +# -------------------------------------- +# 5. Min/Max with mixed lengths and nulls +# -------------------------------------- +statement ok +create table min_max_base_mixed_lengths_nulls as values + (make_array(1, 2, 3, 4)), + (make_array(1, 2)), + (make_array(1, NULL, 3)) +; + +# -------------------------------------- +# 6. Min/Max with only NULLs +# -------------------------------------- +statement ok +create table min_max_base_all_nulls as values + (make_array(NULL, NULL)), + (make_array(NULL)) +; + +# -------------------------------------- +# 7. Min/Max with partial NULLs +# -------------------------------------- +statement ok +create table min_max_base_null_variants as values + (make_array(1, 2, 3)), + (make_array(NULL, 2, 3)), + (make_array(1, 2, NULL)) +; + +# -------------------------------------- +# 8. Min/Max grouped by key with NULLs and differing lengths +# -------------------------------------- +statement ok +create table min_max_base_grouped_nulls as values + (0, make_array(1, NULL, 3)), + (0, make_array(1, 2, 3, 4)), + (1, make_array(1, 2)), + (1, make_array(NULL, 5)), + (1, make_array()) +; + +# -------------------------------------- +# 9. Min/Max grouped by key with only NULLs +# -------------------------------------- +statement ok +create table min_max_base_grouped_all_null as values + (0, make_array(NULL)), + (0, make_array(NULL, NULL)), + (1, make_array(NULL)) +; + +# -------------------------------------- +# 10. Min/Max grouped with empty and non-empty lists +# -------------------------------------- +statement ok +create table min_max_base_grouped_simple as values + (0, make_array()), + (0, make_array(1)), + (0, make_array()), + (1, make_array()), + (1, make_array(5, 6)) +; + +# -------------------------------------- +# 11. Min over with window function +# -------------------------------------- +statement ok +create table min_base_window_simple as values + (make_array(1, 2, 3)), + (make_array(1, 2, 3)), + (make_array(2, 3)) +; + +# -------------------------------------- +# 12. Min over with window + NULLs +# -------------------------------------- +statement ok +create table min_base_window_with_null as values + (make_array(NULL)), + (make_array(4, 5)), + (make_array(2, 3)) +; + +# -------------------------------------- +# 13. Min over with ROWS BETWEEN clause +# -------------------------------------- +statement ok +create table min_base_window_rows_between as values + (make_array(NULL)), + (make_array(4, 5)), + (make_array(2, 3)) +; + +# -------------------------------------- +# 14. Max over using different order column +# -------------------------------------- +statement ok +create table max_base_window_different_column as values + (make_array(1, 2, 3), make_array(4, 5)), + (make_array(2, 3), make_array(2, 3)), + (make_array(2, 3), NULL) +; diff --git a/datafusion/sqllogictest/test_files/min_max/large_list.slt b/datafusion/sqllogictest/test_files/min_max/large_list.slt new file mode 100644 index 0000000000000..44789e9dd786c --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/large_list.slt @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include ./init_data.slt.part + +## -------------------------------------- +## 1. Min/Max over integers +## -------------------------------------- +statement ok +create table min_max_int as ( + select + arrow_cast(column1, 'LargeList(Int64)') as column1 + from min_max_base_int + ); + +## -------------------------------------- +## 2. Min/Max over strings +## -------------------------------------- +statement ok +create table min_max_string as ( + select + arrow_cast(column1, 'LargeList(Utf8)') as column1 +from min_max_base_string); + +## -------------------------------------- +## 3. Min/Max over booleans +## -------------------------------------- +statement ok +create table min_max_bool as +( + select + arrow_cast(column1, 'LargeList(Boolean)') as column1 +from min_max_base_bool); + +## -------------------------------------- +## 4. Min/Max over nullable integers +## -------------------------------------- +statement ok +create table min_max_nullable_int as ( + select + arrow_cast(column1, 'LargeList(Int64)') as column1 + from min_max_base_nullable_int +); + +## -------------------------------------- +## 5. Min/Max with mixed lengths and nulls +## -------------------------------------- +statement ok +create table min_max_mixed_lengths_nulls as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_max_base_mixed_lengths_nulls); + +## -------------------------------------- +## 6. Min/Max with only NULLs +## -------------------------------------- +statement ok +create table min_max_all_nulls as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_max_base_all_nulls); + +## -------------------------------------- +## 7. Min/Max with partial NULLs +## -------------------------------------- +statement ok +create table min_max_null_variants as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_max_base_null_variants); + +## -------------------------------------- +## 8. Min/Max grouped by key with NULLs and differing lengths +## -------------------------------------- +statement ok +create table min_max_grouped_nulls as (select + column1, + arrow_cast(column2, 'LargeList(Int64)') as column2 +from min_max_base_grouped_nulls); + +## -------------------------------------- +## 9. Min/Max grouped by key with only NULLs +## -------------------------------------- +statement ok +create table min_max_grouped_all_null as (select + column1, + arrow_cast(column2, 'LargeList(Int64)') as column2 +from min_max_base_grouped_all_null); + +## -------------------------------------- +## 10. Min/Max grouped with simple sizes +## -------------------------------------- +statement ok +create table min_max_grouped_simple as (select + column1, + arrow_cast(column2, 'LargeList(Int64)') as column2 +from min_max_base_grouped_simple); + +## -------------------------------------- +## 11. Min over with window function +## -------------------------------------- +statement ok +create table min_window_simple as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_base_window_simple); + +## -------------------------------------- +## 12. Min over with window + NULLs +## -------------------------------------- +statement ok +create table min_window_with_null as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_base_window_with_null); + +## -------------------------------------- +## 13. Min over with ROWS BETWEEN clause +## -------------------------------------- +statement ok +create table min_window_rows_between as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_base_window_rows_between); + +## -------------------------------------- +## 14. Max over using different order column +## -------------------------------------- +statement ok +create table max_window_different_column as (select + arrow_cast(column1, 'LargeList(Int64)') as column1, + arrow_cast(column2, 'LargeList(Int64)') as column2 +from max_base_window_different_column); + +include ./queries.slt.part diff --git a/datafusion/sqllogictest/test_files/min_max/list.slt b/datafusion/sqllogictest/test_files/min_max/list.slt new file mode 100644 index 0000000000000..e63e8303c7d5f --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/list.slt @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include ./init_data.slt.part + +# -------------------------------------- +# 1. Min/Max over integers +# -------------------------------------- +statement ok +create table min_max_int as ( + select * from min_max_base_int ) +; + +# -------------------------------------- +# 2. Min/Max over strings +# -------------------------------------- +statement ok +create table min_max_string as ( + select * from min_max_base_string ) +; + +# -------------------------------------- +# 3. Min/Max over booleans +# -------------------------------------- +statement ok +create table min_max_bool as ( + select * from min_max_base_bool ) +; + +# -------------------------------------- +# 4. Min/Max over nullable integers +# -------------------------------------- +statement ok +create table min_max_nullable_int as ( + select * from min_max_base_nullable_int ) +; + +# -------------------------------------- +# 5. Min/Max with mixed lengths and nulls +# -------------------------------------- +statement ok +create table min_max_mixed_lengths_nulls as ( + select * from min_max_base_mixed_lengths_nulls ) +; + +# -------------------------------------- +# 6. Min/Max with only NULLs +# -------------------------------------- +statement ok +create table min_max_all_nulls as ( + select * from min_max_base_all_nulls ) +; + +# -------------------------------------- +# 7. Min/Max with partial NULLs +# -------------------------------------- +statement ok +create table min_max_null_variants as ( + select * from min_max_base_null_variants ) +; + +# -------------------------------------- +# 8. Min/Max grouped by key with NULLs and differing lengths +# -------------------------------------- +statement ok +create table min_max_grouped_nulls as ( + select * from min_max_base_grouped_nulls ) +; + +# -------------------------------------- +# 9. Min/Max grouped by key with only NULLs +# -------------------------------------- +statement ok +create table min_max_grouped_all_null as ( + select * from min_max_base_grouped_all_null ) +; + +# -------------------------------------- +# 10. Min/Max grouped with simple sizes +# -------------------------------------- +statement ok +create table min_max_grouped_simple as ( + select * from min_max_base_grouped_simple ) +; + +# -------------------------------------- +# 11. Min over with window function +# -------------------------------------- +statement ok +create table min_window_simple as ( + select * from min_base_window_simple ) +; + +# -------------------------------------- +# 12. Min over with window + NULLs +# -------------------------------------- +statement ok +create table min_window_with_null as ( + select * from min_base_window_with_null ) +; + +# -------------------------------------- +# 13. Min over with ROWS BETWEEN clause +# -------------------------------------- +statement ok +create table min_window_rows_between as ( + select * from min_base_window_rows_between ) +; + +# -------------------------------------- +# 14. Max over using different order column +# -------------------------------------- +statement ok +create table max_window_different_column as ( + select * from max_base_window_different_column ) +; + +include ./queries.slt.part diff --git a/datafusion/sqllogictest/test_files/min_max/queries.slt.part b/datafusion/sqllogictest/test_files/min_max/queries.slt.part new file mode 100644 index 0000000000000..bc7fb840bf977 --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/queries.slt.part @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +## 1. Min/Max List over integers +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_int; +---- +[1, 2] [1, 2, 3, 4] + +## 2. Min/Max List over strings +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_string; +---- +[a, b] [a, b, c] + +## 3. Min/Max List over booleans +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_bool; +---- +[true, false] [true, false, true] + +## 4. Min/Max List over nullable integers +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_nullable_int; +---- +[1, 2] [NULL, 1, 2] + +## 5. Min/Max List with mixed lengths and nulls +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_mixed_lengths_nulls; +---- +[1, 2] [1, NULL, 3] + +## 6. Min/Max List with only NULLs +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_all_nulls; +---- +[NULL] [NULL, NULL] + +## 7. Min/Max List with partial NULLs +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_null_variants; +---- +[1, 2, 3] [NULL, 2, 3] + +## 8. Min/Max List grouped by key with NULLs and differing lengths +query I?? +SELECT column1, MIN(column2), MAX(column2) FROM min_max_grouped_nulls GROUP BY column1 ORDER BY column1; +---- +0 [1, 2, 3, 4] [1, NULL, 3] +1 [] [NULL, 5] + +## 9. Min/Max List grouped by key with only NULLs +query I?? +SELECT column1, MIN(column2), MAX(column2) FROM min_max_grouped_all_null GROUP BY column1 ORDER BY column1; +---- +0 [NULL] [NULL, NULL] +1 [NULL] [NULL] + +## 10. Min/Max grouped List with simple sizes +query I?? +SELECT column1, MIN(column2), MAX(column2) FROM min_max_grouped_simple GROUP BY column1 ORDER BY column1; +---- +0 [] [1] +1 [] [5, 6] + +## 11. Min over List with window function +query ? +SELECT MIN(column1) OVER (ORDER BY column1) FROM min_window_simple; +---- +[1, 2, 3] +[1, 2, 3] +[1, 2, 3] + +## 12. Min over List with window + NULLs +query ? +SELECT MIN(column1) OVER (ORDER BY column1) FROM min_window_with_null; +---- +[2, 3] +[2, 3] +[2, 3] + +## 13. Min over List with ROWS BETWEEN clause +query ? +SELECT MIN(column1) OVER (ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM min_window_rows_between; +---- +[2, 3] +[2, 3] +[4, 5] + +## 14. Max over List using different order column +query ? +SELECT MAX(column2) OVER (ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM max_window_different_column; +---- +[4, 5] +[4, 5] +[2, 3] diff --git a/datafusion/sqllogictest/test_files/monotonic_projection_test.slt b/datafusion/sqllogictest/test_files/monotonic_projection_test.slt index e8700b1fea275..9c806cfa0d8aa 100644 --- a/datafusion/sqllogictest/test_files/monotonic_projection_test.slt +++ b/datafusion/sqllogictest/test_files/monotonic_projection_test.slt @@ -129,12 +129,12 @@ ORDER BY a_str ASC, b ASC; ---- logical_plan 01)Sort: a_str ASC NULLS LAST, multiple_ordered_table.b ASC NULLS LAST -02)--Projection: CAST(multiple_ordered_table.a AS Utf8) AS a_str, multiple_ordered_table.b +02)--Projection: CAST(multiple_ordered_table.a AS Utf8View) AS a_str, multiple_ordered_table.b 03)----TableScan: multiple_ordered_table projection=[a, b] physical_plan 01)SortPreservingMergeExec: [a_str@0 ASC NULLS LAST, b@1 ASC NULLS LAST] 02)--SortExec: expr=[a_str@0 ASC NULLS LAST, b@1 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[CAST(a@0 AS Utf8) as a_str, b@1 as b] +03)----ProjectionExec: expr=[CAST(a@0 AS Utf8View) as a_str, b@1 as b] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], file_type=csv, has_header=true diff --git a/datafusion/sqllogictest/test_files/operator.slt b/datafusion/sqllogictest/test_files/operator.slt index a651eda99684b..6f3c40188172d 100644 --- a/datafusion/sqllogictest/test_files/operator.slt +++ b/datafusion/sqllogictest/test_files/operator.slt @@ -262,6 +262,15 @@ from numeric_types; ---- Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 Float64 +############### NULL arithmetic ############### + +# select both nulls with basic arithmetic operations +query IIIII +select null + null, null - null, null * null, null / null, null % null; +---- +NULL NULL NULL NULL NULL + + ############### # Test for comparison with constants uses efficient types # Expect the physical plans to compare with constants of the same type diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 4e8be56f3377d..04a7615c764b8 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -94,6 +94,98 @@ NULL three 1 one 2 two +statement ok +set datafusion.sql_parser.default_null_ordering = 'nulls_min'; + +# test asc with `nulls_min` null ordering + +query IT +SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num +---- +NULL three +1 one +2 two + +# test desc with `nulls_min` null ordering + +query IT +SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num DESC +---- +2 two +1 one +NULL three + +statement ok +set datafusion.sql_parser.default_null_ordering = 'nulls_first'; + +# test asc with `nulls_first` null ordering + +query IT +SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num +---- +NULL three +1 one +2 two + +# test desc with `nulls_first` null ordering + +query IT +SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num DESC +---- +NULL three +2 two +1 one + + +statement ok +set datafusion.sql_parser.default_null_ordering = 'nulls_last'; + +# test asc with `nulls_last` null ordering + +query IT +SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num +---- +1 one +2 two +NULL three + +# test desc with `nulls_last` null ordering + +query IT +SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num DESC +---- +2 two +1 one +NULL three + +statement ok +set datafusion.sql_parser.default_null_ordering = ''; + +# test asc with an empty `default_null_ordering`. Expected to use the default null ordering which is `nulls_max` + +query IT +SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num +---- +1 one +2 two +NULL three + +# test desc with an empty `default_null_ordering`. Expected to use the default null ordering which is `nulls_max` + +query IT +SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num DESC +---- +NULL three +2 two +1 one + +statement error DataFusion error: Error during planning: Unsupported value Null +set datafusion.sql_parser.default_null_ordering = null; + +# reset to default null ordering +statement ok +set datafusion.sql_parser.default_null_ordering = 'nulls_max'; + # sort statement ok @@ -674,6 +766,13 @@ physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/ query error DataFusion error: Error during planning: Column a is not in schema CREATE EXTERNAL TABLE dt (a_id integer, a_str string, a_bool boolean) STORED AS CSV WITH ORDER (a ASC) LOCATION 'file://path/to/table'; + +# Create external table with order column expression that can't be planned +# This is currently expected to fail, but should not panic +query error DataFusion error: Schema error: No field named a\. +CREATE EXTERNAL TABLE dt STORED AS CSV WITH ORDER (a || b) LOCATION 'file://path/to/table'; + + # Sort with duplicate sort expressions # Table is sorted multiple times on the same column name and should not fail statement ok @@ -1040,12 +1139,12 @@ limit 5; ---- logical_plan 01)Sort: c_str ASC NULLS LAST, fetch=5 -02)--Projection: CAST(ordered_table.c AS Utf8) AS c_str +02)--Projection: CAST(ordered_table.c AS Utf8View) AS c_str 03)----TableScan: ordered_table projection=[c] physical_plan 01)SortPreservingMergeExec: [c_str@0 ASC NULLS LAST], fetch=5 02)--SortExec: TopK(fetch=5), expr=[c_str@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[CAST(c@0 AS Utf8) as c_str] +03)----ProjectionExec: expr=[CAST(c@0 AS Utf8View) as c_str] 04)------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], file_type=csv, has_header=true @@ -1258,13 +1357,12 @@ logical_plan 08)--------TableScan: ordered_table projection=[a0, b, c, d] physical_plan 01)SortPreservingMergeExec: [d@4 ASC NULLS LAST, c@1 ASC NULLS LAST, a@2 ASC NULLS LAST, a0@3 ASC NULLS LAST, b@0 ASC NULLS LAST], fetch=2 -02)--UnionExec -03)----SortExec: TopK(fetch=2), expr=[d@4 ASC NULLS LAST, c@1 ASC NULLS LAST, a@2 ASC NULLS LAST, b@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--SortExec: TopK(fetch=2), expr=[d@4 ASC NULLS LAST, c@1 ASC NULLS LAST, a@2 ASC NULLS LAST, a0@3 ASC NULLS LAST, b@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----UnionExec 04)------ProjectionExec: expr=[b@1 as b, c@2 as c, a@0 as a, NULL as a0, d@3 as d] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[c@2 ASC NULLS LAST], file_type=csv, has_header=true -06)----SortExec: TopK(fetch=2), expr=[d@4 ASC NULLS LAST, c@1 ASC NULLS LAST, a0@3 ASC NULLS LAST, b@0 ASC NULLS LAST], preserve_partitioning=[false] -07)------ProjectionExec: expr=[b@1 as b, c@2 as c, NULL as a, a0@0 as a0, d@3 as d] -08)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, b, c, d], output_ordering=[c@2 ASC NULLS LAST], file_type=csv, has_header=true +06)------ProjectionExec: expr=[b@1 as b, c@2 as c, NULL as a, a0@0 as a0, d@3 as d] +07)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, b, c, d], output_ordering=[c@2 ASC NULLS LAST], file_type=csv, has_header=true # Test: run the query from above query IIIII @@ -1299,9 +1397,9 @@ logical_plan 02)--Projection: CASE WHEN name = Utf8("name1") THEN Float64(0) WHEN name = Utf8("name2") THEN Float64(0.5) END AS a 03)----Union 04)------Projection: Utf8("name1") AS name -05)--------EmptyRelation +05)--------EmptyRelation: rows=1 06)------Projection: Utf8("name2") AS name -07)--------EmptyRelation +07)--------EmptyRelation: rows=1 physical_plan 01)SortPreservingMergeExec: [a@0 DESC] 02)--ProjectionExec: expr=[CASE WHEN name@0 = name1 THEN 0 WHEN name@0 = name2 THEN 0.5 END as a] @@ -1380,3 +1478,79 @@ physical_plan statement ok drop table table_with_ordered_not_null; + +# ORDER BY ALL +statement ok +set datafusion.sql_parser.dialect = 'DuckDB'; + +statement ok +CREATE OR REPLACE TABLE addresses AS + SELECT '123 Quack Blvd' AS address, 'DuckTown' AS city, '11111' AS zip + UNION ALL + SELECT '111 Duck Duck Goose Ln', 'DuckTown', '11111' + UNION ALL + SELECT '111 Duck Duck Goose Ln', 'Duck Town', '11111' + UNION ALL + SELECT '111 Duck Duck Goose Ln', 'Duck Town', '11111-0001'; + + +query TTT +SELECT * FROM addresses ORDER BY ALL; +---- +111 Duck Duck Goose Ln Duck Town 11111 +111 Duck Duck Goose Ln Duck Town 11111-0001 +111 Duck Duck Goose Ln DuckTown 11111 +123 Quack Blvd DuckTown 11111 + +query TTT +SELECT * FROM addresses ORDER BY ALL DESC; +---- +123 Quack Blvd DuckTown 11111 +111 Duck Duck Goose Ln DuckTown 11111 +111 Duck Duck Goose Ln Duck Town 11111-0001 +111 Duck Duck Goose Ln Duck Town 11111 + +query TT +SELECT address, zip FROM addresses ORDER BY ALL; +---- +111 Duck Duck Goose Ln 11111 +111 Duck Duck Goose Ln 11111 +111 Duck Duck Goose Ln 11111-0001 +123 Quack Blvd 11111 + +# Create a table with an order clause that's not a simple column reference +statement ok +CREATE EXTERNAL TABLE ordered ( + a BIGINT NOT NULL, + b BIGINT NOT NULL +) +STORED AS CSV +LOCATION 'data/composite_order.csv' +OPTIONS ('format.has_header' 'true') +WITH ORDER (a + b); + +# Simple query should be just a table scan +query TT +EXPLAIN SELECT * from ordered; +---- +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/data/composite_order.csv]]}, projection=[a, b], output_ordering=[a@0 + b@1 ASC NULLS LAST], file_type=csv, has_header=true + +# Query ordered by the declared order should be just a table scan +query TT +EXPLAIN SELECT * from ordered ORDER BY (a + b); +---- +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/data/composite_order.csv]]}, projection=[a, b], output_ordering=[a@0 + b@1 ASC NULLS LAST], file_type=csv, has_header=true + +# Order equivalence handling should make this query a simple table scan +query TT +EXPLAIN SELECT * from ordered ORDER BY -(a + b) desc nulls last; +---- +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/data/composite_order.csv]]}, projection=[a, b], output_ordering=[a@0 + b@1 ASC NULLS LAST], file_type=csv, has_header=true + +# Ordering by another column requires a sort +query TT +EXPLAIN SELECT * from ordered ORDER BY a; +---- +physical_plan +01)SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/data/composite_order.csv]]}, projection=[a, b], output_ordering=[a@0 + b@1 ASC NULLS LAST], file_type=csv, has_header=true diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt index 2970b2effb3e9..e722005bf0f0d 100644 --- a/datafusion/sqllogictest/test_files/parquet.slt +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -130,8 +130,7 @@ STORED AS PARQUET; ---- 3 -# Check output plan again, expect no "output_ordering" clause in the physical_plan -> ParquetExec, -# due to there being more files than partitions: +# Check output plan again query TT EXPLAIN SELECT int_col, string_col FROM test_table @@ -142,8 +141,7 @@ logical_plan 02)--TableScan: test_table projection=[int_col, string_col] physical_plan 01)SortPreservingMergeExec: [string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST] -02)--SortExec: expr=[string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/2.parquet]]}, projection=[int_col, string_col], file_type=parquet +02)--DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/2.parquet]]}, projection=[int_col, string_col], output_ordering=[string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST], file_type=parquet # Perform queries using MIN and MAX @@ -304,6 +302,54 @@ select count(*) from listing_table; ---- 12 +# Test table pointing to the folder with parquet files(ends with /) +statement ok +CREATE EXTERNAL TABLE listing_table_folder_0 +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet/test_table/'; + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = true; + +# scan file: 0.parquet 1.parquet 2.parquet +query I +select count(*) from listing_table_folder_0; +---- +9 + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = false; + +# scan file: 0.parquet 1.parquet 2.parquet 3.parquet +query I +select count(*) from listing_table_folder_0; +---- +12 + +# Test table pointing to the folder with parquet files(doesn't end with /) +statement ok +CREATE EXTERNAL TABLE listing_table_folder_1 +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet/test_table'; + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = true; + +# scan file: 0.parquet 1.parquet 2.parquet +query I +select count(*) from listing_table_folder_1; +---- +9 + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = false; + +# scan file: 0.parquet 1.parquet 2.parquet 3.parquet +query I +select count(*) from listing_table_folder_1; +---- +12 + # Clean up statement ok DROP TABLE timestamp_with_tz; @@ -629,3 +675,190 @@ physical_plan statement ok drop table foo + + +# Tests for int96 timestamps written by spark +# See https://github.com/apache/datafusion/issues/9981 + +statement ok +CREATE EXTERNAL TABLE int96_from_spark +STORED AS PARQUET +LOCATION '../../parquet-testing/data/int96_from_spark.parquet'; + +# by default the value is read as nanosecond precision +query TTT +describe int96_from_spark +---- +a Timestamp(Nanosecond, None) YES + +# Note that the values are read as nanosecond precision +query P +select * from int96_from_spark +---- +2024-01-01T20:34:56.123456 +2024-01-01T01:00:00 +1816-03-29T08:56:08.066277376 +2024-12-30T23:00:00 +NULL +1815-11-08T16:01:01.191053312 + +statement ok +drop table int96_from_spark; + +# Enable coercion of int96 to microseconds +statement ok +set datafusion.execution.parquet.coerce_int96 = ms; + +statement ok +CREATE EXTERNAL TABLE int96_from_spark +STORED AS PARQUET +LOCATION '../../parquet-testing/data/int96_from_spark.parquet'; + +# Print schema +query TTT +describe int96_from_spark; +---- +a Timestamp(Millisecond, None) YES + +# Per https://github.com/apache/parquet-testing/blob/6e851ddd768d6af741c7b15dc594874399fc3cff/data/int96_from_spark.md?plain=1#L37 +# these values should be +# +# Some("2024-01-01T12:34:56.123456"), +# Some("2024-01-01T01:00:00Z"), +# Some("9999-12-31T01:00:00-02:00"), +# Some("2024-12-31T01:00:00+02:00"), +# None, +# Some("290000-12-31T01:00:00+02:00")) +# +# However, printing the large dates (9999-12-31 and 290000-12-31) is not supported by +# arrow yet +# +# See https://github.com/apache/arrow-rs/issues/7287 +query P +select * from int96_from_spark +---- +2024-01-01T20:34:56.123 +2024-01-01T01:00:00 +9999-12-31T03:00:00 +2024-12-30T23:00:00 +NULL +ERROR: Cast error: Failed to convert -9357363680509551 to datetime for Timestamp(Millisecond, None) + +# Cleanup / reset default setting +statement ok +drop table int96_from_spark; + +statement ok +set datafusion.execution.parquet.coerce_int96 = ns; + + +### Tests for metadata caching + +# Create temporary data +query I +COPY ( + SELECT 'k-' || i as k, i as v + FROM generate_series(1, 20000) t(i) + ORDER BY k +) +TO 'test_files/scratch/parquet/cache_metadata.parquet' +OPTIONS (MAX_ROW_GROUP_SIZE 4096, DATA_PAGE_ROW_COUNT_LIMIT 2048); +---- +20000 + +statement ok +CREATE EXTERNAL TABLE t +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet/cache_metadata.parquet'; + +query TI +select * from t where k = 'k-1000' or k = 'k-9999' order by k +---- +k-1000 1000 +k-9999 9999 + +query IT +select v, k from t where (v between 1 and 2) or (v between 9999 and 10000) order by v +---- +1 k-1 +2 k-2 +9999 k-9999 +10000 k-10000 + +# Updating the file should invalidate the cache. Otherwise, the following queries would fail +# (e.g., with "Arrow: Parquet argument error: External: incomplete frame"). +query I +COPY ( + SELECT 'k-' || i as k, 20000 - i as v + FROM generate_series(1, 20000) t(i) + ORDER BY k +) +TO 'test_files/scratch/parquet/cache_metadata.parquet' +OPTIONS (MAX_ROW_GROUP_SIZE 4096, DATA_PAGE_ROW_COUNT_LIMIT 2048); +---- +20000 + +query TI +select * from t where k = 'k-1000' or k = 'k-9999' order by k +---- +k-1000 19000 +k-9999 10001 + +query IT +select v, k from t where (v between 1 and 2) or (v between 9999 and 10000) order by v +---- +1 k-19999 +2 k-19998 +9999 k-10001 +10000 k-10000 + +statement ok +DROP TABLE t; + +# Partitioned files should be independently cached. Otherwise, the following queries might fail. +statement ok +COPY ( + SELECT i % 10 as part, 'k-' || i as k, i as v + FROM generate_series(0, 9) t(i) + ORDER BY k +) +TO 'test_files/scratch/parquet/cache_metadata_partitioned.parquet' +PARTITIONED BY (part); + +statement ok +CREATE EXTERNAL TABLE t +STORED AS PARQUET +PARTITIONED BY (part) +LOCATION 'test_files/scratch/parquet/cache_metadata_partitioned.parquet'; + +query TTI +select part, k, v from t where k = 'k-0' +---- +0 k-0 0 + +query TTI +select part, k, v from t where k = 'k-5' +---- +5 k-5 5 + +query TTI +select part, k, v from t where k = 'k-9' +---- +9 k-9 9 + +query TTI +select part, k, v from t order by k +---- +0 k-0 0 +1 k-1 1 +2 k-2 2 +3 k-3 3 +4 k-4 4 +5 k-5 5 +6 k-6 6 +7 k-7 7 +8 k-8 8 +9 k-9 9 + +statement ok +DROP TABLE t; diff --git a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt index 758113b708355..6dc2c264aeb85 100644 --- a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt +++ b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt @@ -54,7 +54,6 @@ LOCATION 'test_files/scratch/parquet_filter_pushdown/parquet_table/'; statement ok set datafusion.execution.parquet.pushdown_filters = true; -## Create table without pushdown statement ok CREATE EXTERNAL TABLE t_pushdown(a varchar, b int, c float) STORED AS PARQUET LOCATION 'test_files/scratch/parquet_filter_pushdown/parquet_table/'; @@ -76,17 +75,252 @@ NULL NULL NULL +query T +select a from t_pushdown where b > 2 ORDER BY a; +---- +baz +foo +NULL +NULL +NULL + +query TT +EXPLAIN select a from t where b > 2 ORDER BY a; +---- +logical_plan +01)Sort: t.a ASC NULLS LAST +02)--Projection: t.a +03)----Filter: t.b > Int32(2) +04)------TableScan: t projection=[a, b], partial_filters=[t.b > Int32(2)] +physical_plan +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] +02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------FilterExec: b@1 > 2, projection=[a@0] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 +06)----------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a, b], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] + +query TT +EXPLAIN select a from t_pushdown where b > 2 ORDER BY a; +---- +logical_plan +01)Sort: t_pushdown.a ASC NULLS LAST +02)--Projection: t_pushdown.a +03)----Filter: t_pushdown.b > Int32(2) +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.b > Int32(2)] +physical_plan +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] +02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] + +query T +select a from t where b = 2 ORDER BY b; +---- +bar + +query T +select a from t_pushdown where b = 2 ORDER BY b; +---- +bar + +query TT +EXPLAIN select a from t where b = 2 ORDER BY b; +---- +logical_plan +01)Projection: t.a +02)--Sort: t.b ASC NULLS LAST +03)----Filter: t.b = Int32(2) +04)------TableScan: t projection=[a, b], partial_filters=[t.b = Int32(2)] +physical_plan +01)CoalescePartitionsExec +02)--ProjectionExec: expr=[a@0 as a] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------FilterExec: b@1 = 2 +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 +06)----------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a, b], file_type=parquet, predicate=b@1 = 2, pruning_predicate=b_null_count@2 != row_count@3 AND b_min@0 <= 2 AND 2 <= b_max@1, required_guarantees=[b in (2)] + +query TT +EXPLAIN select a from t_pushdown where b = 2 ORDER BY b; +---- +logical_plan +01)Projection: t_pushdown.a +02)--Sort: t_pushdown.b ASC NULLS LAST +03)----Filter: t_pushdown.b = Int32(2) +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.b = Int32(2)] +physical_plan +01)CoalescePartitionsExec +02)--DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], file_type=parquet, predicate=b@1 = 2, pruning_predicate=b_null_count@2 != row_count@3 AND b_min@0 <= 2 AND 2 <= b_max@1, required_guarantees=[b in (2)] + +# If we set the setting to `true` it override's the table's setting +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + +query T +select a from t where b > 2 ORDER BY a; +---- +baz +foo +NULL +NULL +NULL + +query T +select a from t_pushdown where b > 2 ORDER BY a; +---- +baz +foo +NULL +NULL +NULL + +query TT +EXPLAIN select a from t where b > 2 ORDER BY a; +---- +logical_plan +01)Sort: t.a ASC NULLS LAST +02)--Projection: t.a +03)----Filter: t.b > Int32(2) +04)------TableScan: t projection=[a, b], partial_filters=[t.b > Int32(2)] +physical_plan +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] +02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] + +query TT +EXPLAIN select a from t_pushdown where b > 2 ORDER BY a; +---- +logical_plan +01)Sort: t_pushdown.a ASC NULLS LAST +02)--Projection: t_pushdown.a +03)----Filter: t_pushdown.b > Int32(2) +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.b > Int32(2)] +physical_plan +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] +02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] + +query T +select a from t where b = 2 ORDER BY b; +---- +bar + +query T +select a from t_pushdown where b = 2 ORDER BY b; +---- +bar + +query TT +EXPLAIN select a from t where b = 2 ORDER BY b; +---- +logical_plan +01)Projection: t.a +02)--Sort: t.b ASC NULLS LAST +03)----Filter: t.b = Int32(2) +04)------TableScan: t projection=[a, b], partial_filters=[t.b = Int32(2)] +physical_plan +01)CoalescePartitionsExec +02)--DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], file_type=parquet, predicate=b@1 = 2, pruning_predicate=b_null_count@2 != row_count@3 AND b_min@0 <= 2 AND 2 <= b_max@1, required_guarantees=[b in (2)] + +query TT +EXPLAIN select a from t_pushdown where b = 2 ORDER BY b; +---- +logical_plan +01)Projection: t_pushdown.a +02)--Sort: t_pushdown.b ASC NULLS LAST +03)----Filter: t_pushdown.b = Int32(2) +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.b = Int32(2)] +physical_plan +01)CoalescePartitionsExec +02)--DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], file_type=parquet, predicate=b@1 = 2, pruning_predicate=b_null_count@2 != row_count@3 AND b_min@0 <= 2 AND 2 <= b_max@1, required_guarantees=[b in (2)] + +# If we reset the default the table created without pushdown goes back to disabling it +statement ok +set datafusion.execution.parquet.pushdown_filters = false; + +query T +select a from t where b > 2 ORDER BY a; +---- +baz +foo +NULL +NULL +NULL + +query T +select a from t_pushdown where b > 2 ORDER BY a; +---- +baz +foo +NULL +NULL +NULL + +query TT +EXPLAIN select a from t where b > 2 ORDER BY a; +---- +logical_plan +01)Sort: t.a ASC NULLS LAST +02)--Projection: t.a +03)----Filter: t.b > Int32(2) +04)------TableScan: t projection=[a, b], partial_filters=[t.b > Int32(2)] +physical_plan +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] +02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------FilterExec: b@1 > 2, projection=[a@0] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 +06)----------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a, b], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] + query TT EXPLAIN select a from t_pushdown where b > 2 ORDER BY a; ---- logical_plan 01)Sort: t_pushdown.a ASC NULLS LAST -02)--TableScan: t_pushdown projection=[a], full_filters=[t_pushdown.b > Int32(2)] +02)--Projection: t_pushdown.a +03)----Filter: t_pushdown.b > Int32(2) +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.b > Int32(2)] physical_plan 01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] 02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] +query T +select a from t where b = 2 ORDER BY b; +---- +bar + +query T +select a from t_pushdown where b = 2 ORDER BY b; +---- +bar + +query TT +EXPLAIN select a from t where b = 2 ORDER BY b; +---- +logical_plan +01)Projection: t.a +02)--Sort: t.b ASC NULLS LAST +03)----Filter: t.b = Int32(2) +04)------TableScan: t projection=[a, b], partial_filters=[t.b = Int32(2)] +physical_plan +01)CoalescePartitionsExec +02)--ProjectionExec: expr=[a@0 as a] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------FilterExec: b@1 = 2 +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 +06)----------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a, b], file_type=parquet, predicate=b@1 = 2, pruning_predicate=b_null_count@2 != row_count@3 AND b_min@0 <= 2 AND 2 <= b_max@1, required_guarantees=[b in (2)] + +query TT +EXPLAIN select a from t_pushdown where b = 2 ORDER BY b; +---- +logical_plan +01)Projection: t_pushdown.a +02)--Sort: t_pushdown.b ASC NULLS LAST +03)----Filter: t_pushdown.b = Int32(2) +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.b = Int32(2)] +physical_plan +01)CoalescePartitionsExec +02)--DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], file_type=parquet, predicate=b@1 = 2, pruning_predicate=b_null_count@2 != row_count@3 AND b_min@0 <= 2 AND 2 <= b_max@1, required_guarantees=[b in (2)] # When filter pushdown *is* enabled, ParquetExec can filter exactly, # not just metadata, so we expect to see no FilterExec @@ -115,6 +349,23 @@ physical_plan 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 06)----------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a, b], file_type=parquet, predicate=b@1 > 2, pruning_predicate=b_null_count@1 != row_count@2 AND b_max@0 > 2, required_guarantees=[] +query T +select a from t_pushdown where b = 2 ORDER BY b; +---- +bar + +query TT +EXPLAIN select a from t_pushdown where b = 2 ORDER BY b; +---- +logical_plan +01)Projection: t_pushdown.a +02)--Sort: t_pushdown.b ASC NULLS LAST +03)----Filter: t_pushdown.b = Int32(2) +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.b = Int32(2)] +physical_plan +01)CoalescePartitionsExec +02)--DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], file_type=parquet, predicate=b@1 = 2, pruning_predicate=b_null_count@2 != row_count@3 AND b_min@0 <= 2 AND 2 <= b_max@1, required_guarantees=[b in (2)] + # also test querying on columns that are not in all the files query T select a from t_pushdown where b > 2 AND a IS NOT NULL order by a; @@ -127,7 +378,9 @@ EXPLAIN select a from t_pushdown where b > 2 AND a IS NOT NULL order by a; ---- logical_plan 01)Sort: t_pushdown.a ASC NULLS LAST -02)--TableScan: t_pushdown projection=[a], full_filters=[t_pushdown.b > Int32(2), t_pushdown.a IS NOT NULL] +02)--Projection: t_pushdown.a +03)----Filter: t_pushdown.b > Int32(2) AND t_pushdown.a IS NOT NULL +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.b > Int32(2), t_pushdown.a IS NOT NULL] physical_plan 01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] 02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] @@ -144,15 +397,190 @@ EXPLAIN select b from t_pushdown where a = 'bar' order by b; ---- logical_plan 01)Sort: t_pushdown.b ASC NULLS LAST -02)--TableScan: t_pushdown projection=[b], full_filters=[t_pushdown.a = Utf8("bar")] +02)--Projection: t_pushdown.b +03)----Filter: t_pushdown.a = Utf8View("bar") +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.a = Utf8View("bar")] physical_plan 01)SortPreservingMergeExec: [b@0 ASC NULLS LAST] 02)--SortExec: expr=[b@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[b], file_type=parquet, predicate=a@0 = bar, pruning_predicate=a_null_count@2 != row_count@3 AND a_min@0 <= bar AND bar <= a_max@1, required_guarantees=[a in (bar)] + +# should not push down volatile predicates such as RANDOM +# expect that the random predicate is evaluated after the scan +query TT +EXPLAIN select a from t_pushdown where b > random(); +---- +logical_plan +01)Projection: t_pushdown.a +02)--Filter: CAST(t_pushdown.b AS Float64) > random() +03)----TableScan: t_pushdown projection=[a, b] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: CAST(b@1 AS Float64) > random(), projection=[a@0] +03)----DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a, b], file_type=parquet + ## cleanup statement ok DROP TABLE t; statement ok DROP TABLE t_pushdown; + +## Test filter pushdown with a predicate that references both a partition column and a file column +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + +## Create table +statement ok +CREATE EXTERNAL TABLE t_pushdown(part text, val text) +STORED AS PARQUET +PARTITIONED BY (part) +LOCATION 'test_files/scratch/parquet_filter_pushdown/parquet_part_test/'; + +statement ok +COPY ( + SELECT arrow_cast('a', 'Utf8') AS val +) TO 'test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet' +STORED AS PARQUET; + +statement ok +COPY ( + SELECT arrow_cast('b', 'Utf8') AS val +) TO 'test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=b/file.parquet' +STORED AS PARQUET; + +statement ok +COPY ( + SELECT arrow_cast('xyz', 'Utf8') AS val +) TO 'test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=c/file.parquet' +STORED AS PARQUET; + +query TT +select * from t_pushdown where part == val order by part, val; +---- +a a +b b + +query TT +select * from t_pushdown where part != val order by part, val; +---- +xyz c + +# If we reference both a file and partition column the predicate cannot be pushed down +query TT +EXPLAIN select * from t_pushdown where part != val +---- +logical_plan +01)Filter: t_pushdown.val != t_pushdown.part +02)--TableScan: t_pushdown projection=[val, part], partial_filters=[t_pushdown.val != t_pushdown.part] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: val@0 != part@1 +03)----DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=b/file.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=c/file.parquet]]}, projection=[val, part], file_type=parquet + +# If we reference only a partition column it gets evaluated during the listing phase +query TT +EXPLAIN select * from t_pushdown where part != 'a'; +---- +logical_plan TableScan: t_pushdown projection=[val, part], full_filters=[t_pushdown.part != Utf8View("a")] +physical_plan DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=b/file.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=c/file.parquet]]}, projection=[val, part], file_type=parquet + +# And if we reference only a file column it gets pushed down +query TT +EXPLAIN select * from t_pushdown where val != 'c'; +---- +logical_plan +01)Filter: t_pushdown.val != Utf8View("c") +02)--TableScan: t_pushdown projection=[val, part], partial_filters=[t_pushdown.val != Utf8View("c")] +physical_plan DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=b/file.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=c/file.parquet]]}, projection=[val, part], file_type=parquet, predicate=val@0 != c, pruning_predicate=val_null_count@2 != row_count@3 AND (val_min@0 != c OR c != val_max@1), required_guarantees=[val not in (c)] + +# If we have a mix of filters: +# - The partition filters get evaluated during planning +# - The mixed filters end up in a FilterExec +# - The file filters get pushed down into the scan +query TT +EXPLAIN select * from t_pushdown where val != 'd' AND val != 'c' AND part = 'a' AND part != val; +---- +logical_plan +01)Filter: t_pushdown.val != Utf8View("d") AND t_pushdown.val != Utf8View("c") AND t_pushdown.val != t_pushdown.part +02)--TableScan: t_pushdown projection=[val, part], full_filters=[t_pushdown.part = Utf8View("a")], partial_filters=[t_pushdown.val != Utf8View("d"), t_pushdown.val != Utf8View("c"), t_pushdown.val != t_pushdown.part] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: val@0 != part@1 +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet]]}, projection=[val, part], file_type=parquet, predicate=val@0 != d AND val@0 != c, pruning_predicate=val_null_count@2 != row_count@3 AND (val_min@0 != d OR d != val_max@1) AND val_null_count@2 != row_count@3 AND (val_min@0 != c OR c != val_max@1), required_guarantees=[val not in (c, d)] + +# The order of filters should not matter +query TT +EXPLAIN select val, part from t_pushdown where part = 'a' AND part = val; +---- +logical_plan +01)Filter: t_pushdown.val = t_pushdown.part +02)--TableScan: t_pushdown projection=[val, part], full_filters=[t_pushdown.part = Utf8View("a")], partial_filters=[t_pushdown.val = t_pushdown.part] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: val@0 = part@1 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet]]}, projection=[val, part], file_type=parquet + +query TT +select val, part from t_pushdown where part = 'a' AND part = val; +---- +a a + +query TT +EXPLAIN select val, part from t_pushdown where part = val AND part = 'a'; +---- +logical_plan +01)Filter: t_pushdown.val = t_pushdown.part +02)--TableScan: t_pushdown projection=[val, part], full_filters=[t_pushdown.part = Utf8View("a")], partial_filters=[t_pushdown.val = t_pushdown.part] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: val@0 = part@1 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet]]}, projection=[val, part], file_type=parquet + +query TT +select val, part from t_pushdown where part = val AND part = 'a'; +---- +a a + +statement ok +COPY ( + SELECT + '00000000000000000000000000000001' AS trace_id, + '2023-10-01 00:00:00'::timestamptz AS start_timestamp, + 'prod' as deployment_environment +) +TO 'test_files/scratch/parquet_filter_pushdown/data/1.parquet'; + +statement ok +COPY ( + SELECT + '00000000000000000000000000000002' AS trace_id, + '2024-10-01 00:00:00'::timestamptz AS start_timestamp, + 'staging' as deployment_environment +) +TO 'test_files/scratch/parquet_filter_pushdown/data/2.parquet'; + +statement ok +CREATE EXTERNAL TABLE t1 STORED AS PARQUET LOCATION 'test_files/scratch/parquet_filter_pushdown/data/'; + +statement ok +SET datafusion.execution.parquet.pushdown_filters = true; + +query T +SELECT deployment_environment +FROM t1 +WHERE trace_id = '00000000000000000000000000000002' +ORDER BY start_timestamp, trace_id; +---- +staging + +query P +SELECT start_timestamp +FROM t1 +WHERE trace_id = '00000000000000000000000000000002' AND deployment_environment = 'staging' +ORDER BY start_timestamp, trace_id +LIMIT 1; +---- +2024-10-01T00:00:00Z diff --git a/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt b/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt index d325ca423daca..fe909e70ffb00 100644 --- a/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt +++ b/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt @@ -38,20 +38,22 @@ CREATE TABLE src_table ( bigint_col BIGINT, date_col DATE, overlapping_col INT, - constant_col INT + constant_col INT, + nulls_first_col INT, + nulls_last_col INT ) AS VALUES -- first file -(1, 3, 'aaa', 100, 1, 0, 0), -(2, 2, 'bbb', 200, 2, 1, 0), -(3, 1, 'ccc', 300, 3, 2, 0), +(1, 3, 'aaa', 100, 1, 0, 0, NULL, 1), +(2, 2, 'bbb', 200, 2, 1, 0, NULL, 2), +(3, 1, 'ccc', 300, 3, 2, 0, 1, 3), -- second file -(4, 6, 'ddd', 400, 4, 0, 0), -(5, 5, 'eee', 500, 5, 1, 0), -(6, 4, 'fff', 600, 6, 2, 0), +(4, 6, 'ddd', 400, 4, 0, 0, 2, 4), +(5, 5, 'eee', 500, 5, 1, 0, 3, 5), +(6, 4, 'fff', 600, 6, 2, 0, 4, 6), -- third file -(7, 9, 'ggg', 700, 7, 3, 0), -(8, 8, 'hhh', 800, 8, 4, 0), -(9, 7, 'iii', 900, 9, 5, 0); +(7, 9, 'ggg', 700, 7, 3, 0, 5, 7), +(8, 8, 'hhh', 800, 8, 4, 0, 6, NULL), +(9, 7, 'iii', 900, 9, 5, 0, 7, NULL); # Setup 3 files, in particular more files than there are partitions @@ -90,11 +92,18 @@ CREATE EXTERNAL TABLE test_table ( bigint_col BIGINT NOT NULL, date_col DATE NOT NULL, overlapping_col INT NOT NULL, - constant_col INT NOT NULL + constant_col INT NOT NULL, + nulls_first_col INT, + nulls_last_col INT ) STORED AS PARQUET PARTITIONED BY (partition_col) -WITH ORDER (int_col ASC NULLS LAST, bigint_col ASC NULLS LAST) +WITH ORDER ( + int_col ASC NULLS LAST, + bigint_col ASC NULLS LAST, + nulls_first_col ASC NULLS FIRST, + nulls_last_col ASC NULLS LAST +) LOCATION 'test_files/scratch/parquet_sorted_statistics/test_table'; # Order by numeric columns @@ -102,28 +111,33 @@ LOCATION 'test_files/scratch/parquet_sorted_statistics/test_table'; # DataFusion doesn't currently support string column statistics # This should not require a sort. query TT -EXPLAIN SELECT int_col, bigint_col +EXPLAIN SELECT int_col, bigint_col, nulls_first_col, nulls_last_col FROM test_table -ORDER BY int_col, bigint_col; +ORDER BY int_col, bigint_col, nulls_first_col NULLS FIRST, nulls_last_col NULLS LAST; ---- logical_plan -01)Sort: test_table.int_col ASC NULLS LAST, test_table.bigint_col ASC NULLS LAST -02)--TableScan: test_table projection=[int_col, bigint_col] -physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet]]}, projection=[int_col, bigint_col], output_ordering=[int_col@0 ASC NULLS LAST, bigint_col@1 ASC NULLS LAST], file_type=parquet +01)Sort: test_table.int_col ASC NULLS LAST, test_table.bigint_col ASC NULLS LAST, test_table.nulls_first_col ASC NULLS FIRST, test_table.nulls_last_col ASC NULLS LAST +02)--TableScan: test_table projection=[int_col, bigint_col, nulls_first_col, nulls_last_col] +physical_plan +01)SortPreservingMergeExec: [int_col@0 ASC NULLS LAST, bigint_col@1 ASC NULLS LAST, nulls_first_col@2 ASC, nulls_last_col@3 ASC NULLS LAST] +02)--DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet]]}, projection=[int_col, bigint_col, nulls_first_col, nulls_last_col], output_ordering=[int_col@0 ASC NULLS LAST, bigint_col@1 ASC NULLS LAST, nulls_first_col@2 ASC, nulls_last_col@3 ASC NULLS LAST], file_type=parquet # Another planning test, but project on a column with unsupported statistics # We should be able to ignore this and look at only the relevant statistics query TT EXPLAIN SELECT string_col FROM test_table -ORDER BY int_col, bigint_col; +ORDER BY int_col, bigint_col, nulls_first_col NULLS FIRST, nulls_last_col NULLS LAST; ---- logical_plan 01)Projection: test_table.string_col -02)--Sort: test_table.int_col ASC NULLS LAST, test_table.bigint_col ASC NULLS LAST -03)----Projection: test_table.string_col, test_table.int_col, test_table.bigint_col -04)------TableScan: test_table projection=[int_col, string_col, bigint_col] -physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet]]}, projection=[string_col], file_type=parquet +02)--Sort: test_table.int_col ASC NULLS LAST, test_table.bigint_col ASC NULLS LAST, test_table.nulls_first_col ASC NULLS FIRST, test_table.nulls_last_col ASC NULLS LAST +03)----Projection: test_table.string_col, test_table.int_col, test_table.bigint_col, test_table.nulls_first_col, test_table.nulls_last_col +04)------TableScan: test_table projection=[int_col, string_col, bigint_col, nulls_first_col, nulls_last_col] +physical_plan +01)ProjectionExec: expr=[string_col@0 as string_col] +02)--SortPreservingMergeExec: [int_col@1 ASC NULLS LAST, bigint_col@2 ASC NULLS LAST, nulls_first_col@3 ASC, nulls_last_col@4 ASC NULLS LAST] +03)----DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet]]}, projection=[string_col, int_col, bigint_col, nulls_first_col, nulls_last_col], output_ordering=[int_col@1 ASC NULLS LAST, bigint_col@2 ASC NULLS LAST, nulls_first_col@3 ASC, nulls_last_col@4 ASC NULLS LAST], file_type=parquet # Clean up & recreate but sort on descending column statement ok @@ -155,7 +169,9 @@ ORDER BY descending_col DESC NULLS LAST, bigint_col ASC NULLS LAST; logical_plan 01)Sort: test_table.descending_col DESC NULLS LAST, test_table.bigint_col ASC NULLS LAST 02)--TableScan: test_table projection=[descending_col, bigint_col] -physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet]]}, projection=[descending_col, bigint_col], output_ordering=[descending_col@0 DESC NULLS LAST, bigint_col@1 ASC NULLS LAST], file_type=parquet +physical_plan +01)SortPreservingMergeExec: [descending_col@0 DESC NULLS LAST, bigint_col@1 ASC NULLS LAST] +02)--DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet]]}, projection=[descending_col, bigint_col], output_ordering=[descending_col@0 DESC NULLS LAST, bigint_col@1 ASC NULLS LAST], file_type=parquet # Clean up & re-create with partition columns in sort order statement ok @@ -189,7 +205,9 @@ ORDER BY partition_col, int_col, bigint_col; logical_plan 01)Sort: test_table.partition_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST, test_table.bigint_col ASC NULLS LAST 02)--TableScan: test_table projection=[int_col, bigint_col, partition_col] -physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet]]}, projection=[int_col, bigint_col, partition_col], output_ordering=[partition_col@2 ASC NULLS LAST, int_col@0 ASC NULLS LAST, bigint_col@1 ASC NULLS LAST], file_type=parquet +physical_plan +01)SortPreservingMergeExec: [partition_col@2 ASC NULLS LAST, int_col@0 ASC NULLS LAST, bigint_col@1 ASC NULLS LAST] +02)--DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet]]}, projection=[int_col, bigint_col, partition_col], output_ordering=[partition_col@2 ASC NULLS LAST, int_col@0 ASC NULLS LAST, bigint_col@1 ASC NULLS LAST], file_type=parquet # Clean up & re-create with overlapping column in sort order # This will test the ability to sort files with overlapping statistics diff --git a/datafusion/sqllogictest/test_files/parquet_statistics.slt b/datafusion/sqllogictest/test_files/parquet_statistics.slt new file mode 100644 index 0000000000000..c04235ef4ee6f --- /dev/null +++ b/datafusion/sqllogictest/test_files/parquet_statistics.slt @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for statistics in parquet files. +# Writes data into two files: +# * test_table/0.parquet +# * test_table/1.parquet +# +# And verifies statistics are correctly calculated for the table +# +# NOTE that statistics are ONLY gathered when the table is first created +# so the table must be recreated to see the effects of the setting + +query I +COPY (values (1), (2), (3)) +TO 'test_files/scratch/parquet_statistics/test_table/0.parquet' +STORED AS PARQUET; +---- +3 + +query I +COPY (values (3), (4)) +TO 'test_files/scratch/parquet_statistics/test_table/1.parquet' +STORED AS PARQUET; +---- +2 + +statement ok +set datafusion.explain.physical_plan_only = true; + +statement ok +set datafusion.explain.show_statistics = true; + +###### +# By default, the statistics are gathered +###### + +# Recreate the table to pick up the current setting +statement ok +CREATE EXTERNAL TABLE test_table +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_statistics/test_table'; + +query TT +EXPLAIN SELECT * FROM test_table WHERE column1 = 1; +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192, statistics=[Rows=Inexact(2), Bytes=Inexact(31), [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)) Null=Inexact(0))]] +02)--FilterExec: column1@0 = 1, statistics=[Rows=Inexact(2), Bytes=Inexact(31), [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)) Null=Inexact(0))]] +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2, statistics=[Rows=Inexact(5), Bytes=Inexact(121), [(Col[0]: Min=Inexact(Int64(1)) Max=Inexact(Int64(4)) Null=Inexact(0))]] +04)------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/1.parquet]]}, projection=[column1], file_type=parquet, predicate=column1@0 = 1, pruning_predicate=column1_null_count@2 != row_count@3 AND column1_min@0 <= 1 AND 1 <= column1_max@1, required_guarantees=[column1 in (1)], statistics=[Rows=Inexact(5), Bytes=Inexact(121), [(Col[0]: Min=Inexact(Int64(1)) Max=Inexact(Int64(4)) Null=Inexact(0))]] + +# cleanup +statement ok +DROP TABLE test_table; + +###### +# When the setting is true, statistics are gathered +###### + +statement ok +set datafusion.execution.collect_statistics = true; + +# Recreate the table to pick up the current setting +statement ok +CREATE EXTERNAL TABLE test_table +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_statistics/test_table'; + +query TT +EXPLAIN SELECT * FROM test_table WHERE column1 = 1; +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192, statistics=[Rows=Inexact(2), Bytes=Inexact(31), [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)) Null=Inexact(0))]] +02)--FilterExec: column1@0 = 1, statistics=[Rows=Inexact(2), Bytes=Inexact(31), [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)) Null=Inexact(0))]] +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2, statistics=[Rows=Inexact(5), Bytes=Inexact(121), [(Col[0]: Min=Inexact(Int64(1)) Max=Inexact(Int64(4)) Null=Inexact(0))]] +04)------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/1.parquet]]}, projection=[column1], file_type=parquet, predicate=column1@0 = 1, pruning_predicate=column1_null_count@2 != row_count@3 AND column1_min@0 <= 1 AND 1 <= column1_max@1, required_guarantees=[column1 in (1)], statistics=[Rows=Inexact(5), Bytes=Inexact(121), [(Col[0]: Min=Inexact(Int64(1)) Max=Inexact(Int64(4)) Null=Inexact(0))]] + +# cleanup +statement ok +DROP TABLE test_table; + + +###### +# When the setting is false, the statistics are NOT gathered +###### + +statement ok +set datafusion.execution.collect_statistics = false; + +# Recreate the table to pick up the current setting +statement ok +CREATE EXTERNAL TABLE test_table +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_statistics/test_table'; + +query TT +EXPLAIN SELECT * FROM test_table WHERE column1 = 1; +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]: Min=Inexact(Int64(1)) Max=Inexact(Int64(1)))]] +02)--FilterExec: column1@0 = 1, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)))]] +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:)]] +04)------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/1.parquet]]}, projection=[column1], file_type=parquet, predicate=column1@0 = 1, pruning_predicate=column1_null_count@2 != row_count@3 AND column1_min@0 <= 1 AND 1 <= column1_max@1, required_guarantees=[column1 in (1)], statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:)]] + +# cleanup +statement ok +DROP TABLE test_table; diff --git a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_null.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_null.slt index d14b6ca81f67e..fcc12226e47c5 100644 --- a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_null.slt +++ b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_null.slt @@ -48,7 +48,7 @@ COPY aggregate_test_100_by_sql ### ## Setup test for datafusion ### -onlyif DataFusion +skipif postgres statement ok CREATE EXTERNAL TABLE aggregate_test_100_by_sql ( c1 VARCHAR NOT NULL, diff --git a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_simple.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_simple.slt index 25b4924715caa..4453aa1489a1b 100644 --- a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_simple.slt +++ b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_simple.slt @@ -49,7 +49,7 @@ COPY aggregate_test_100_by_sql ### ## Setup test for datafusion ### -onlyif DataFusion +skipif postgres statement ok CREATE EXTERNAL TABLE aggregate_test_100_by_sql ( c1 VARCHAR NOT NULL, diff --git a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_union.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_union.slt index e02c19016790d..f8e0770271309 100644 --- a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_union.slt +++ b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_union.slt @@ -46,7 +46,7 @@ COPY aggregate_test_100_by_sql ### ## Setup test for datafusion ### -onlyif DataFusion +skipif postgres statement ok CREATE EXTERNAL TABLE aggregate_test_100_by_sql ( c1 VARCHAR NOT NULL, diff --git a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_window.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_window.slt index edad3747a2030..f967d79a6d952 100644 --- a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_window.slt +++ b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_window.slt @@ -46,7 +46,7 @@ COPY aggregate_test_100_by_sql ### ## Setup test for datafusion ### -onlyif DataFusion +skipif postgres statement ok CREATE EXTERNAL TABLE aggregate_test_100_by_sql ( c1 VARCHAR NOT NULL, diff --git a/datafusion/sqllogictest/test_files/pipe_operator.slt b/datafusion/sqllogictest/test_files/pipe_operator.slt new file mode 100644 index 0000000000000..5908b3d6b2a4d --- /dev/null +++ b/datafusion/sqllogictest/test_files/pipe_operator.slt @@ -0,0 +1,197 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# BigQuery supports the pipe operator syntax +# TODO: Make the Generic dialect support the pipe operator syntax +statement ok +set datafusion.sql_parser.dialect = 'BigQuery'; + +statement ok +CREATE TABLE test( + a INT, + b FLOAT, + c VARCHAR, + n VARCHAR +) AS VALUES + (1, 1.1, 'a', NULL), + (2, 2.2, 'b', NULL), + (3, 3.3, 'c', NULL) +; + +# WHERE pipe +query IRTT +SELECT * +FROM test +|> WHERE a > 1 +---- +2 2.2 b NULL +3 3.3 c NULL + +# ORDER BY pipe +query IRTT +SELECT * +FROM test +|> ORDER BY a DESC +---- +3 3.3 c NULL +2 2.2 b NULL +1 1.1 a NULL + +# ORDER BY pipe, limit +query IRTT +SELECT * +FROM test +|> ORDER BY a DESC +|> LIMIT 1 +---- +3 3.3 c NULL + +# SELECT pipe +query I +SELECT * +FROM test +|> SELECT a +---- +1 +2 +3 + +# EXTEND pipe +query IRR +SELECT * +FROM test +|> SELECT a, b +|> EXTEND a + b AS a_plus_b +---- +1 1.1 2.1 +2 2.2 4.2 +3 3.3 6.3 + +query IRR +SELECT * +FROM test +|> SELECT a, b +|> where a = 1 +|> EXTEND a + b AS a_plus_b +---- +1 1.1 2.1 + +# AS pipe +query I +SELECT * +FROM test +|> as test_pipe +|> select test_pipe.a +---- +1 +2 +3 + +# UNION pipe +query I +SELECT * +FROM test +|> select a +|> UNION ALL ( + SELECT a FROM test +); +---- +1 +2 +3 +1 +2 +3 + +# INTERSECT pipe +query I rowsort +SELECT * FROM range(0,3) +|> INTERSECT DISTINCT + (SELECT * FROM range(1,3)); +---- +1 +2 + +# EXCEPT pipe +query I rowsort +select * from range(0,10) +|> EXCEPT DISTINCT (select * from range(5,10)); +---- +0 +1 +2 +3 +4 + +# AGGREGATE pipe +query II +( + SELECT 'apples' AS item, 2 AS sales + UNION ALL + SELECT 'bananas' AS item, 5 AS sales + UNION ALL + SELECT 'apples' AS item, 7 AS sales +) +|> AGGREGATE COUNT(*) AS num_items, SUM(sales) AS total_sales; +---- +3 14 + +query TII rowsort +( + SELECT 'apples' AS item, 2 AS sales + UNION ALL + SELECT 'bananas' AS item, 5 AS sales + UNION ALL + SELECT 'apples' AS item, 7 AS sales +) +|> AGGREGATE COUNT(*) AS num_items, SUM(sales) AS total_sales + GROUP BY item; +---- +apples 2 9 +bananas 1 5 + +query TII rowsort +( + SELECT 'apples' AS item, 2 AS sales + UNION ALL + SELECT 'bananas' AS item, 5 AS sales + UNION ALL + SELECT 'apples' AS item, 7 AS sales +) +|> AGGREGATE COUNT(*) AS num_items, SUM(sales) AS total_sales + GROUP BY item +|> WHERE num_items > 1; +---- +apples 2 9 + +# JOIN pipe +query TII +( + SELECT 'apples' AS item, 2 AS sales + UNION ALL + SELECT 'bananas' AS item, 5 AS sales +) +|> AS produce_sales +|> LEFT JOIN + ( + SELECT "apples" AS item, 123 AS id + ) AS produce_data + ON produce_sales.item = produce_data.item +|> SELECT produce_sales.item, sales, id; +---- +apples 2 123 +bananas 5 NULL diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index b263e39f3b11b..77ee3e4f05a0d 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -662,11 +662,11 @@ OR ---- logical_plan 01)Projection: lineitem.l_partkey -02)--Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8("Brand#12") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) +02)--Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) 03)----Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) 04)------TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] -05)----Filter: (part.p_brand = Utf8("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) -06)------TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_size <= Int32(15)] +05)----Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) +06)------TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)] physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_partkey@0] @@ -755,8 +755,8 @@ logical_plan 05)--------Inner Join: lineitem.l_partkey = part.p_partkey 06)----------TableScan: lineitem projection=[l_partkey, l_extendedprice, l_discount] 07)----------Projection: part.p_partkey -08)------------Filter: part.p_brand = Utf8("Brand#12") OR part.p_brand = Utf8("Brand#23") -09)--------------TableScan: part projection=[p_partkey, p_brand], partial_filters=[part.p_brand = Utf8("Brand#12") OR part.p_brand = Utf8("Brand#23")] +08)------------Filter: part.p_brand = Utf8View("Brand#12") OR part.p_brand = Utf8View("Brand#23") +09)--------------TableScan: part projection=[p_partkey, p_brand], partial_filters=[part.p_brand = Utf8View("Brand#12") OR part.p_brand = Utf8View("Brand#23")] 10)------TableScan: partsupp projection=[ps_partkey, ps_suppkey] physical_plan 01)AggregateExec: mode=SinglePartitioned, gby=[p_partkey@2 as p_partkey], aggr=[sum(lineitem.l_extendedprice), avg(lineitem.l_discount), count(DISTINCT partsupp.ps_suppkey)] @@ -777,6 +777,52 @@ physical_plan 16)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 17)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/part.csv]]}, projection=[p_partkey, p_brand], file_type=csv, has_header=true +# Simplification of a binary operator with a NULL value + +statement ok +create table t(x int) as values (1), (2), (3); + +query TT +EXPLAIN FORMAT INDENT SELECT x > NULL FROM t; +---- +logical_plan +01)Projection: Boolean(NULL) AS t.x > NULL +02)--TableScan: t projection=[] +physical_plan +01)ProjectionExec: expr=[NULL as t.x > NULL] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN FORMAT INDENT SELECT * FROM t WHERE x > NULL; +---- +logical_plan EmptyRelation: rows=0 +physical_plan EmptyExec + +query TT +EXPLAIN FORMAT INDENT SELECT * FROM t WHERE x < 5 AND (10 * NULL < x); +---- +logical_plan +01)Filter: t.x < Int32(5) AND Boolean(NULL) +02)--TableScan: t projection=[x] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: x@0 < 5 AND NULL +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN FORMAT INDENT SELECT * FROM t WHERE x < 5 OR (10 * NULL < x); +---- +logical_plan +01)Filter: t.x < Int32(5) OR Boolean(NULL) +02)--TableScan: t projection=[x] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: x@0 < 5 OR NULL +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +drop table t; + # Inlist simplification statement ok @@ -785,7 +831,7 @@ create table t(x int) as values (1), (2), (3); query TT explain select x from t where x IN (1,2,3) AND x IN (4,5); ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 physical_plan EmptyExec query TT @@ -808,8 +854,29 @@ physical_plan query TT explain select x from t where x NOT IN (1,2,3,4,5) AND x IN (1,2,3); ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 physical_plan EmptyExec +query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression InSubquery\(InSubquery \{ expr: Literal\(Int64\(NULL\), None\), subquery: , negated: false \}\) +WITH empty AS (SELECT 10 WHERE false) +SELECT + NULL IN (SELECT * FROM empty), -- should be false, as the right side is empty relation + NULL NOT IN (SELECT * FROM empty) -- should be true, as the right side is empty relation +FROM (SELECT 1) t; + +query I +WITH empty AS (SELECT 10 WHERE false) +SELECT * FROM (SELECT 1) t +WHERE NOT (NULL IN (SELECT * FROM empty)); -- all rows should be returned +---- +1 + +query I +WITH empty AS (SELECT 10 WHERE false) +SELECT * FROM (SELECT 1) t +WHERE NULL NOT IN (SELECT * FROM empty); -- all rows should be returned +---- +1 + statement ok drop table t; diff --git a/datafusion/sqllogictest/test_files/prepare.slt b/datafusion/sqllogictest/test_files/prepare.slt index 33df0d26f3610..d61603ae65588 100644 --- a/datafusion/sqllogictest/test_files/prepare.slt +++ b/datafusion/sqllogictest/test_files/prepare.slt @@ -92,7 +92,7 @@ DEALLOCATE my_plan statement ok PREPARE my_plan AS SELECT * FROM person WHERE id < $1; -statement error No value found for placeholder with id \$1 +statement error Prepared statement 'my_plan' expects 1 parameters, but 0 provided EXECUTE my_plan statement ok diff --git a/datafusion/sqllogictest/test_files/projection.slt b/datafusion/sqllogictest/test_files/projection.slt index 0f0cbac1fa323..9f840e7bdc2f0 100644 --- a/datafusion/sqllogictest/test_files/projection.slt +++ b/datafusion/sqllogictest/test_files/projection.slt @@ -252,3 +252,31 @@ physical_plan statement ok drop table t; + +# Regression test for +# https://github.com/apache/datafusion/issues/17513 + +query I +COPY (select 1 as a, 2 as b) +TO 'test_files/scratch/projection/17513.parquet' +STORED AS PARQUET; +---- +1 + +statement ok +create external table t1 stored as parquet location 'test_files/scratch/projection/17513.parquet'; + +query TT +explain format indent +select from t1 where t1.a > 1; +---- +logical_plan +01)Projection: +02)--Filter: t1.a > Int64(1) +03)----TableScan: t1 projection=[a], partial_filters=[t1.a > Int64(1)] +physical_plan +01)ProjectionExec: expr=[] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: a@0 > 1 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection/17513.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 > 1, pruning_predicate=a_null_count@1 != row_count@2 AND a_max@0 > 1, required_guarantees=[] diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt index 67965146e76b3..47095d92d9376 100644 --- a/datafusion/sqllogictest/test_files/push_down_filter.slt +++ b/datafusion/sqllogictest/test_files/push_down_filter.slt @@ -18,7 +18,7 @@ # Test push down filter statement ok -set datafusion.explain.logical_plan_only = true; +set datafusion.explain.physical_plan_only = true; statement ok CREATE TABLE IF NOT EXISTS v AS VALUES(1,[1,2,3]),(2,[3,4,5]); @@ -35,12 +35,14 @@ select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = query TT explain select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = 2; ---- -logical_plan -01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2 -02)--Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] -03)----Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 -04)------Filter: v.column1 = Int64(2) -05)--------TableScan: v projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2] +02)--UnnestExec +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------ProjectionExec: expr=[column2@0 as __unnest_placeholder(v.column2)] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------FilterExec: column1@0 = 2, projection=[column2@1] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] query I select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; @@ -52,13 +54,14 @@ select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; query TT explain select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; ---- -logical_plan -01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2 -02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) -03)----Projection: __unnest_placeholder(v.column2,depth=1) -04)------Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] -05)--------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 -06)----------TableScan: v projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: __unnest_placeholder(v.column2,depth=1)@0 > 3 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------UnnestExec +06)----------ProjectionExec: expr=[column2@0 as __unnest_placeholder(v.column2)] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] query II select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2; @@ -70,13 +73,16 @@ select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where query TT explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2; ---- -logical_plan -01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 -02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) -03)----Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] -04)------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 -05)--------Filter: v.column1 = Int64(2) -06)----------TableScan: v projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2, column1@1 as column1] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: __unnest_placeholder(v.column2,depth=1)@0 > 3 +04)------UnnestExec +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +06)----------ProjectionExec: expr=[column2@1 as __unnest_placeholder(v.column2), column1@0 as column1] +07)------------CoalesceBatchesExec: target_batch_size=8192 +08)--------------FilterExec: column1@0 = 2 +09)----------------DataSourceExec: partitions=1, partition_sizes=[1] query II select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; @@ -89,12 +95,14 @@ select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where query TT explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; ---- -logical_plan -01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 -02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) OR v.column1 = Int64(2) -03)----Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] -04)------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 -05)--------TableScan: v projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2, column1@1 as column1] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: __unnest_placeholder(v.column2,depth=1)@0 > 3 OR column1@1 = 2 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------UnnestExec +06)----------ProjectionExec: expr=[column2@1 as __unnest_placeholder(v.column2), column1@0 as column1] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] statement ok drop table v; @@ -111,19 +119,40 @@ select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; query TT explain select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; ---- -logical_plan -01)Projection: d.column1, __unnest_placeholder(d.column2,depth=1) AS o -02)--Filter: get_field(__unnest_placeholder(d.column2,depth=1), Utf8("a")) = Int64(1) -03)----Unnest: lists[__unnest_placeholder(d.column2)|depth=1] structs[] -04)------Projection: d.column1, d.column2 AS __unnest_placeholder(d.column2) -05)--------TableScan: d projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[column1@0 as column1, __unnest_placeholder(d.column2,depth=1)@1 as o] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: get_field(__unnest_placeholder(d.column2,depth=1)@1, a) = 1 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------UnnestExec +06)----------ProjectionExec: expr=[column1@0 as column1, column2@1 as __unnest_placeholder(d.column2)] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] +statement ok +drop table d; + +statement ok +CREATE TABLE d AS VALUES (named_struct('a', 1, 'b', 2)), (named_struct('a', 3, 'b', 4)), (named_struct('a', 5, 'b', 6)); + +query II +select * from (select unnest(column1) from d) where "__unnest_placeholder(d.column1).b" > 5; +---- +5 6 +query TT +explain select * from (select unnest(column1) from d) where "__unnest_placeholder(d.column1).b" > 5; +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: __unnest_placeholder(d.column1).b@1 > 5 +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------UnnestExec +05)--------ProjectionExec: expr=[column1@0 as __unnest_placeholder(d.column1)] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] statement ok drop table d; - # Test push down filter with limit for parquet statement ok set datafusion.execution.parquet.pushdown_filters = true; @@ -146,7 +175,7 @@ CREATE TABLE src_table ( # File 1: query I COPY (SELECT * FROM src_table where part_key = 1) -TO 'test_files/scratch/parquet/test_filter_with_limit/part-0.parquet' +TO 'test_files/scratch/push_down_filter/test_filter_with_limit/part-0.parquet' STORED AS PARQUET; ---- 3 @@ -154,7 +183,7 @@ STORED AS PARQUET; # File 2: query I COPY (SELECT * FROM src_table where part_key = 2) -TO 'test_files/scratch/parquet/test_filter_with_limit/part-1.parquet' +TO 'test_files/scratch/push_down_filter/test_filter_with_limit/part-1.parquet' STORED AS PARQUET; ---- 4 @@ -162,7 +191,7 @@ STORED AS PARQUET; # File 3: query I COPY (SELECT * FROM src_table where part_key = 3) -TO 'test_files/scratch/parquet/test_filter_with_limit/part-2.parquet' +TO 'test_files/scratch/push_down_filter/test_filter_with_limit/part-2.parquet' STORED AS PARQUET; ---- 3 @@ -174,14 +203,14 @@ CREATE EXTERNAL TABLE test_filter_with_limit value INT ) STORED AS PARQUET -LOCATION 'test_files/scratch/parquet/test_filter_with_limit/'; +LOCATION 'test_files/scratch/push_down_filter/test_filter_with_limit/'; query TT explain select * from test_filter_with_limit where value = 2 limit 1; ---- -logical_plan -01)Limit: skip=0, fetch=1 -02)--TableScan: test_filter_with_limit projection=[part_key, value], full_filters=[test_filter_with_limit.value = Int32(2)], fetch=1 +physical_plan +01)CoalescePartitionsExec: fetch=1 +02)--DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/test_filter_with_limit/part-0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/test_filter_with_limit/part-1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/test_filter_with_limit/part-2.parquet]]}, projection=[part_key, value], limit=1, file_type=parquet, predicate=value@1 = 2, pruning_predicate=value_null_count@2 != row_count@3 AND value_min@0 <= 2 AND 2 <= value_max@1, required_guarantees=[value in (2)] query II select * from test_filter_with_limit where value = 2 limit 1; @@ -218,44 +247,176 @@ LOCATION 'test_files/scratch/push_down_filter/t.parquet'; query TT explain select a from t where a = '100'; ---- -logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 = 100, pruning_predicate=a_null_count@2 != row_count@3 AND a_min@0 <= 100 AND 100 <= a_max@1, required_guarantees=[a in (100)] # The predicate should not have a column cast when the value is a valid i32 query TT explain select a from t where a != '100'; ---- -logical_plan TableScan: t projection=[a], full_filters=[t.a != Int32(100)] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 != 100, pruning_predicate=a_null_count@2 != row_count@3 AND (a_min@0 != 100 OR 100 != a_max@1), required_guarantees=[a not in (100)] # The predicate should still have the column cast when the value is a NOT valid i32 query TT explain select a from t where a = '99999999999'; ---- -logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99999999999")] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = 99999999999 # The predicate should still have the column cast when the value is a NOT valid i32 query TT explain select a from t where a = '99.99'; ---- -logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99.99")] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = 99.99 # The predicate should still have the column cast when the value is a NOT valid i32 query TT explain select a from t where a = ''; ---- -logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("")] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = # The predicate should not have a column cast when the operator is = or != and the literal can be round-trip casted without losing information. query TT explain select a from t where cast(a as string) = '100'; ---- -logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 = 100, pruning_predicate=a_null_count@2 != row_count@3 AND a_min@0 <= 100 AND 100 <= a_max@1, required_guarantees=[a in (100)] # The predicate should still have the column cast when the literal alters its string representation after round-trip casting (leading zero lost). query TT explain select a from t where CAST(a AS string) = '0123'; ---- -logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("0123")] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8View) = 0123 + +# Test dynamic filter pushdown with swapped join inputs (issue #17196) +# Create tables with different sizes to force join input swapping +statement ok +copy (select i as k from generate_series(1, 100) t(i)) to 'test_files/scratch/push_down_filter/small_table.parquet'; + +statement ok +copy (select i as k, i as v from generate_series(1, 1000) t(i)) to 'test_files/scratch/push_down_filter/large_table.parquet'; + +statement ok +create external table small_table stored as parquet location 'test_files/scratch/push_down_filter/small_table.parquet'; + +statement ok +create external table large_table stored as parquet location 'test_files/scratch/push_down_filter/large_table.parquet'; + +# Test that dynamic filter is applied to the correct table after join input swapping +# The small_table should be the build side, large_table should be the probe side with dynamic filter +query TT +explain select * from small_table join large_table on small_table.k = large_table.k where large_table.v >= 50; +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(k@0, k@0)] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/small_table.parquet]]}, projection=[k], file_type=parquet +04)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/large_table.parquet]]}, projection=[k, v], file_type=parquet, predicate=v@1 >= 50 AND DynamicFilter [ empty ], pruning_predicate=v_null_count@1 != row_count@2 AND v_max@0 >= 50, required_guarantees=[] + +statement ok +drop table small_table; + +statement ok +drop table large_table; statement ok drop table t; + +# Regression test for https://github.com/apache/datafusion/issues/17188 +query I +COPY (select i as k from generate_series(1, 10000000) as t(i)) +TO 'test_files/scratch/push_down_filter/t1.parquet' +STORED AS PARQUET; +---- +10000000 + +query I +COPY (select i as k, i as v from generate_series(1, 10000000) as t(i)) +TO 'test_files/scratch/push_down_filter/t2.parquet' +STORED AS PARQUET; +---- +10000000 + +statement ok +create external table t1 stored as parquet location 'test_files/scratch/push_down_filter/t1.parquet'; + +statement ok +create external table t2 stored as parquet location 'test_files/scratch/push_down_filter/t2.parquet'; + +# The failure before https://github.com/apache/datafusion/pull/17197 was non-deterministic and random +# So we'll run the same query a couple of times just to have more certainty it's fixed +# Sorry about the spam in this slt test... + +query III rowsort +select * +from t1 +join t2 on t1.k = t2.k +where v = 1 or v = 10000000 +order by t1.k, t2.v; +---- +1 1 1 +10000000 10000000 10000000 + +query III rowsort +select * +from t1 +join t2 on t1.k = t2.k +where v = 1 or v = 10000000 +order by t1.k, t2.v; +---- +1 1 1 +10000000 10000000 10000000 + +query III rowsort +select * +from t1 +join t2 on t1.k = t2.k +where v = 1 or v = 10000000 +order by t1.k, t2.v; +---- +1 1 1 +10000000 10000000 10000000 + +query III rowsort +select * +from t1 +join t2 on t1.k = t2.k +where v = 1 or v = 10000000 +order by t1.k, t2.v; +---- +1 1 1 +10000000 10000000 10000000 + +query III rowsort +select * +from t1 +join t2 on t1.k = t2.k +where v = 1 or v = 10000000 +order by t1.k, t2.v; +---- +1 1 1 +10000000 10000000 10000000 + +# Regression test for https://github.com/apache/datafusion/issues/17512 + +query I +COPY ( + SELECT arrow_cast('2025-01-01T00:00:00Z'::timestamptz, 'Timestamp(Microsecond, Some("UTC"))') AS start_timestamp +) +TO 'test_files/scratch/push_down_filter/17512.parquet' +STORED AS PARQUET; +---- +1 + +statement ok +CREATE EXTERNAL TABLE records STORED AS PARQUET LOCATION 'test_files/scratch/push_down_filter/17512.parquet'; + +query I +SELECT 1 +FROM ( + SELECT start_timestamp + FROM records + WHERE start_timestamp <= '2025-01-01T00:00:00Z'::timestamptz +) AS t +WHERE t.start_timestamp::time < '00:00:01'::time; +---- +1 diff --git a/datafusion/sqllogictest/test_files/qualify.slt b/datafusion/sqllogictest/test_files/qualify.slt new file mode 100644 index 0000000000000..d53b56ce58de1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/qualify.slt @@ -0,0 +1,373 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## QUALIFY Clause Tests +########## + +# Create test data +statement ok +CREATE TABLE users ( + id INT, + name VARCHAR, + age INT, + salary DECIMAL(10,2), + dept VARCHAR +) AS VALUES +(1, 'Alice', 25, 50000.00, 'Engineering'), +(2, 'Bob', 30, 60000.00, 'Engineering'), +(3, 'Charlie', 25, 55000.00, 'Engineering'), +(4, 'Diana', 35, 70000.00, 'Marketing'), +(5, 'Eve', 30, 65000.00, 'Marketing'), +(6, 'Frank', 25, 52000.00, 'Engineering'), +(7, 'Grace', 35, 75000.00, 'Marketing'), +(8, 'Henry', 30, 62000.00, 'Engineering'); + +# Basic QUALIFY with ROW_NUMBER +query ITI +SELECT id, name, ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) as rn +FROM users +QUALIFY rn = 1 +ORDER BY dept, id; +---- +8 Henry 1 +7 Grace 1 + +# QUALIFY with RANK +query ITI +SELECT id, name, RANK() OVER (ORDER BY salary DESC) as rank +FROM users +QUALIFY rank <= 3 +ORDER BY rank, id; +---- +7 Grace 1 +4 Diana 2 +5 Eve 3 + +# QUALIFY with DENSE_RANK +query ITI +SELECT id, name, DENSE_RANK() OVER (PARTITION BY dept ORDER BY age) as dense_rank +FROM users +QUALIFY dense_rank <= 2 +ORDER BY dept, dense_rank, id; +---- +1 Alice 1 +3 Charlie 1 +6 Frank 1 +2 Bob 2 +8 Henry 2 +5 Eve 1 +4 Diana 2 +7 Grace 2 + +# QUALIFY with complex condition +query ITII +SELECT id, name, ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) as rn, + RANK() OVER (ORDER BY age) as age_rank +FROM users +QUALIFY rn <= 2 AND age_rank <= 5 +ORDER BY dept, rn, id; +---- +8 Henry 1 4 +2 Bob 2 4 + +# QUALIFY with LAG function +query ITRR +SELECT id, name, salary, LAG(salary) OVER (PARTITION BY dept ORDER BY id) as prev_salary +FROM users +QUALIFY prev_salary IS NOT NULL AND salary > prev_salary +ORDER BY dept, id; +---- +2 Bob 60000 50000 +8 Henry 62000 52000 +7 Grace 75000 65000 + +# QUALIFY with LEAD function +query ITRR +SELECT id, name, salary, LEAD(salary) OVER (PARTITION BY dept ORDER BY id) as next_salary +FROM users +QUALIFY next_salary IS NOT NULL AND salary < next_salary +ORDER BY dept, id; +---- +1 Alice 50000 60000 +6 Frank 52000 62000 +5 Eve 65000 75000 + +# QUALIFY with NTILE +query ITI +SELECT id, name, NTILE(3) OVER (PARTITION BY dept ORDER BY salary DESC) as tile +FROM users +QUALIFY tile = 1 +ORDER BY dept, id; +---- +2 Bob 1 +8 Henry 1 +7 Grace 1 + +# QUALIFY with PERCENT_RANK +query ITR +SELECT id, name, PERCENT_RANK() OVER (PARTITION BY dept ORDER BY salary) as pct_rank +FROM users +QUALIFY pct_rank >= 0.5 +ORDER BY dept, pct_rank, id; +---- +3 Charlie 0.5 +2 Bob 0.75 +8 Henry 1 +4 Diana 0.5 +7 Grace 1 + +# QUALIFY with CUME_DIST +query ITR +SELECT id, name, CUME_DIST() OVER (PARTITION BY dept ORDER BY age) as cume_dist +FROM users +QUALIFY cume_dist >= 0.75 +ORDER BY dept, cume_dist, id; +---- +2 Bob 1 +8 Henry 1 +4 Diana 1 +7 Grace 1 + +# QUALIFY with multiple window functions +query ITIII +SELECT id, name, + ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) as rn, + RANK() OVER (ORDER BY age) as age_rank, + DENSE_RANK() OVER (PARTITION BY dept ORDER BY age) as dept_age_rank +FROM users +QUALIFY rn <= 2 AND age_rank <= 4 AND dept_age_rank <= 2 +ORDER BY dept, rn, id; +---- +8 Henry 1 4 2 +2 Bob 2 4 2 + +# QUALIFY with arithmetic expressions +query ITRI +SELECT id, name, salary, + ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) as rn +FROM users +QUALIFY rn = 1 AND salary > 60000 +ORDER BY dept, id; +---- +8 Henry 62000 1 +7 Grace 75000 1 + +# QUALIFY with string functions +query ITI +SELECT id, name, + ROW_NUMBER() OVER (PARTITION BY dept ORDER BY name) as rn +FROM users +QUALIFY rn = 1 +ORDER BY dept, id; +---- +1 Alice 1 +4 Diana 1 + +# window function with aggregate function +query ITI +SELECT id, name, COUNT(*) OVER (PARTITION BY dept) as cnt +FROM users +QUALIFY cnt > 4 +ORDER BY dept, id; +---- +1 Alice 5 +2 Bob 5 +3 Charlie 5 +6 Frank 5 +8 Henry 5 + +# QUALIFY with HAVING +query TR +SELECT dept, AVG(salary) OVER (PARTITION BY dept) as r +FROM users +WHERE salary > 5000 +GROUP BY dept, salary +HAVING SUM(salary) > 20000 +QUALIFY r > 60000 +---- +Marketing 70000 +Marketing 70000 +Marketing 70000 + +# QUALIFY with aggregate function reference from projection +query TR +SELECT dept, SUM(salary) AS s +FROM users +GROUP BY dept +QUALIFY RANK() OVER (ORDER BY dept DESC) = 1 AND s > 1000 +ORDER BY dept; +---- +Marketing 210000 + +# QUALIFY with aggregate function +query T +SELECT dept +FROM users +GROUP BY dept +QUALIFY RANK() OVER (ORDER BY dept DESC) = 1 AND SUM(salary) > 1000 +ORDER BY dept; +---- +Marketing + +# QUALIFY with aggregate function within window function +query TR +SELECT dept, SUM(salary) AS s +FROM users +GROUP BY dept +QUALIFY RANK() OVER (ORDER BY SUM(salary) DESC) = 1 +ORDER BY dept; +---- +Engineering 279000 + +# QUALIFY with aggregate function reference from projection within window function +query TR +SELECT dept, SUM(salary) AS s +FROM users +GROUP BY dept +QUALIFY RANK() OVER (ORDER BY s DESC) = 1 +ORDER BY dept; +---- +Engineering 279000 + +# Error: QUALIFY without window functions +query error +SELECT id, name FROM users QUALIFY id > 1; + +# Window function in QUALIFY +query IT +SELECT id, name FROM users QUALIFY COUNT(*) OVER () > 1 ORDER BY id; +---- +1 Alice +2 Bob +3 Charlie +4 Diana +5 Eve +6 Frank +7 Grace +8 Henry + +# verify the logical plan and physical plan +query TT +EXPLAIN SELECT id, name FROM users QUALIFY COUNT(*) OVER () > 1 ORDER BY id; +---- +logical_plan +01)Sort: users.id ASC NULLS LAST +02)--Projection: users.id, users.name +03)----Filter: count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING > Int64(1) +04)------WindowAggr: windowExpr=[[count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +05)--------TableScan: users projection=[id, name] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@2 > 1, projection=[id@0, name@1] +04)------WindowAggExec: wdw=[count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] +05)--------DataSourceExec: partitions=1, partition_sizes=[1] + +# plan row_number() +query TT +explain select row_number() over (PARTITION BY dept) as rk from users qualify rk > 1; +---- +logical_plan +01)Projection: row_number() PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rk +02)--Filter: row_number() PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING > UInt64(1) +03)----Projection: row_number() PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING +04)------WindowAggr: windowExpr=[[row_number() PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +05)--------TableScan: users projection=[dept] +physical_plan +01)ProjectionExec: expr=[row_number() PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@0 as rk] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: row_number() PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@0 > 1 +04)------ProjectionExec: expr=[row_number() PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as row_number() PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING] +05)--------BoundedWindowAggExec: wdw=[row_number() PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "row_number() PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] +06)----------SortExec: expr=[dept@0 ASC NULLS LAST], preserve_partitioning=[false] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] + +# plan with window function and group by +query TT +EXPLAIN SELECT dept, AVG(salary) OVER (PARTITION BY dept) as r +FROM users +WHERE salary > 5000 +GROUP BY dept, salary +HAVING SUM(salary) > 20000 +QUALIFY r > 60000 +---- +logical_plan +01)Projection: users.dept, avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS r +02)--Filter: avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING > Decimal128(Some(60000000000),14,6) +03)----Projection: users.dept, avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING +04)------WindowAggr: windowExpr=[[avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +05)--------Projection: users.dept, users.salary +06)----------Filter: sum(users.salary) > Decimal128(Some(2000000),20,2) +07)------------Aggregate: groupBy=[[users.dept, users.salary]], aggr=[[sum(users.salary)]] +08)--------------Filter: users.salary > Decimal128(Some(500000),10,2) +09)----------------TableScan: users projection=[salary, dept] +physical_plan +01)ProjectionExec: expr=[dept@0 as dept, avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as r] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 > Some(60000000000),14,6 +04)------ProjectionExec: expr=[dept@0 as dept, avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@2 as avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING] +05)--------WindowAggExec: wdw=[avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Decimal128(14, 6), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] +06)----------SortExec: expr=[dept@0 ASC NULLS LAST], preserve_partitioning=[true] +07)------------CoalesceBatchesExec: target_batch_size=8192 +08)--------------RepartitionExec: partitioning=Hash([dept@0], 4), input_partitions=4 +09)----------------CoalesceBatchesExec: target_batch_size=8192 +10)------------------FilterExec: sum(users.salary)@2 > Some(2000000),20,2, projection=[dept@0, salary@1] +11)--------------------AggregateExec: mode=FinalPartitioned, gby=[dept@0 as dept, salary@1 as salary], aggr=[sum(users.salary)] +12)----------------------CoalesceBatchesExec: target_batch_size=8192 +13)------------------------RepartitionExec: partitioning=Hash([dept@0, salary@1], 4), input_partitions=4 +14)--------------------------AggregateExec: mode=Partial, gby=[dept@1 as dept, salary@0 as salary], aggr=[sum(users.salary)] +15)----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +16)------------------------------CoalesceBatchesExec: target_batch_size=8192 +17)--------------------------------FilterExec: salary@0 > Some(500000),10,2 +18)----------------------------------DataSourceExec: partitions=1, partition_sizes=[1] + +# plan with aggregate function +query TT +EXPLAIN SELECT dept, SUM(salary) AS s +FROM users +GROUP BY dept +QUALIFY RANK() OVER (ORDER BY s DESC) = 1 +ORDER BY dept; +---- +logical_plan +01)Sort: users.dept ASC NULLS LAST +02)--Projection: users.dept, sum(users.salary) AS s +03)----Filter: rank() ORDER BY [sum(users.salary) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW = UInt64(1) +04)------WindowAggr: windowExpr=[[rank() ORDER BY [sum(users.salary) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +05)--------Aggregate: groupBy=[[users.dept]], aggr=[[sum(users.salary)]] +06)----------TableScan: users projection=[salary, dept] +physical_plan +01)SortPreservingMergeExec: [dept@0 ASC NULLS LAST] +02)--SortExec: expr=[dept@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[dept@0 as dept, sum(users.salary)@1 as s] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------FilterExec: rank() ORDER BY [sum(users.salary) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 = 1, projection=[dept@0, sum(users.salary)@1] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------BoundedWindowAggExec: wdw=[rank() ORDER BY [sum(users.salary) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "rank() ORDER BY [sum(users.salary) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +08)--------------SortPreservingMergeExec: [sum(users.salary)@1 DESC] +09)----------------SortExec: expr=[sum(users.salary)@1 DESC], preserve_partitioning=[true] +10)------------------AggregateExec: mode=FinalPartitioned, gby=[dept@0 as dept], aggr=[sum(users.salary)] +11)--------------------CoalesceBatchesExec: target_batch_size=8192 +12)----------------------RepartitionExec: partitioning=Hash([dept@0], 4), input_partitions=4 +13)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)--------------------------AggregateExec: mode=Partial, gby=[dept@1 as dept], aggr=[sum(users.salary)] +15)----------------------------DataSourceExec: partitions=1, partition_sizes=[1] + +# Clean up +statement ok +DROP TABLE users; diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt deleted file mode 100644 index 44ba61e877d97..0000000000000 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ /dev/null @@ -1,898 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -statement ok -CREATE TABLE t (str varchar, pattern varchar, start int, flags varchar) AS VALUES - ('abc', '^(a)', 1, 'i'), - ('ABC', '^(A).*', 1, 'i'), - ('aBc', '(b|d)', 1, 'i'), - ('AbC', '(B|D)', 2, null), - ('aBC', '^(b|c)', 3, null), - ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 1, null), - ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 2, null), - ('Düsseldorf','[\p{Letter}-]+', 3, null), - ('Москва', '[\p{L}-]+', 4, null), - ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', 1, null), - ('إسرائيل', '^\p{Arabic}+$', 2, null); - -# -# regexp_like tests -# - -query B -SELECT regexp_like(str, pattern, flags) FROM t; ----- -true -true -true -false -false -false -true -true -true -true -true - -query B -SELECT str ~ NULL FROM t; ----- -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL - -query B -select str ~ right('foo', NULL) FROM t; ----- -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL - -query B -select right('foo', NULL) !~ str FROM t; ----- -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL - -query B -SELECT regexp_like('foobarbequebaz', ''); ----- -true - -query B -SELECT regexp_like('', ''); ----- -true - -query B -SELECT regexp_like('foobarbequebaz', '(bar)(beque)'); ----- -true - -query B -SELECT regexp_like('fooBarb -eQuebaz', '(bar).*(que)', 'is'); ----- -true - -query B -SELECT regexp_like('foobarbequebaz', '(ba3r)(bequ34e)'); ----- -false - -query B -SELECT regexp_like('foobarbequebaz', '^.*(barbequ[0-9]*e).*$', 'm'); ----- -true - -query B -SELECT regexp_like('aaa-0', '.*-(\d)'); ----- -true - -query B -SELECT regexp_like('bb-1', '.*-(\d)'); ----- -true - -query B -SELECT regexp_like('aa', '.*-(\d)'); ----- -false - -query B -SELECT regexp_like(NULL, '.*-(\d)'); ----- -NULL - -query B -SELECT regexp_like('aaa-0', NULL); ----- -NULL - -query B -SELECT regexp_like(null, '.*-(\d)'); ----- -NULL - -query error Error during planning: regexp_like\(\) does not support the "global" option -SELECT regexp_like('bb-1', '.*-(\d)', 'g'); - -query error Error during planning: regexp_like\(\) does not support the "global" option -SELECT regexp_like('bb-1', '.*-(\d)', 'g'); - -query error Arrow error: Compute error: Regular expression did not compile: CompiledTooBig\(10485760\) -SELECT regexp_like('aaaaa', 'a{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}'); - -# look-around is not supported and will just return false -query B -SELECT regexp_like('(?<=[A-Z]\w )Smith', 'John Smith', 'i'); ----- -false - -query B -select regexp_like('aaa-555', '.*-(\d*)'); ----- -true - -# -# regexp_match tests -# - -query ? -SELECT regexp_match(str, pattern, flags) FROM t; ----- -[a] -[A] -[B] -NULL -NULL -NULL -[010] -[Düsseldorf] -[Москва] -[Köln] -[إسرائيل] - -# test string view -statement ok -CREATE TABLE t_stringview AS -SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM t; - -query ? -SELECT regexp_match(str, pattern, flags) FROM t_stringview; ----- -[a] -[A] -[B] -NULL -NULL -NULL -[010] -[Düsseldorf] -[Москва] -[Köln] -[إسرائيل] - -statement ok -DROP TABLE t_stringview; - -query ? -SELECT regexp_match('foobarbequebaz', ''); ----- -[] - -query ? -SELECT regexp_match('', ''); ----- -[] - -query ? -SELECT regexp_match('foobarbequebaz', '(bar)(beque)'); ----- -[bar, beque] - -query ? -SELECT regexp_match('fooBarb -eQuebaz', '(bar).*(que)', 'is'); ----- -[Bar, Que] - -query ? -SELECT regexp_match('foobarbequebaz', '(ba3r)(bequ34e)'); ----- -NULL - -query ? -SELECT regexp_match('foobarbequebaz', '^.*(barbequ[0-9]*e).*$', 'm'); ----- -[barbeque] - -query ? -SELECT regexp_match('aaa-0', '.*-(\d)'); ----- -[0] - -query ? -SELECT regexp_match('bb-1', '.*-(\d)'); ----- -[1] - -query ? -SELECT regexp_match('aa', '.*-(\d)'); ----- -NULL - -query ? -SELECT regexp_match(NULL, '.*-(\d)'); ----- -NULL - -query ? -SELECT regexp_match('aaa-0', NULL); ----- -NULL - -query ? -SELECT regexp_match(null, '.*-(\d)'); ----- -NULL - -query error Error during planning: regexp_match\(\) does not support the "global" option -SELECT regexp_match('bb-1', '.*-(\d)', 'g'); - -query error Error during planning: regexp_match\(\) does not support the "global" option -SELECT regexp_match('bb-1', '.*-(\d)', 'g'); - -query error Arrow error: Compute error: Regular expression did not compile: CompiledTooBig\(10485760\) -SELECT regexp_match('aaaaa', 'a{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}'); - -# look-around is not supported and will just return null -query ? -SELECT regexp_match('(?<=[A-Z]\w )Smith', 'John Smith', 'i'); ----- -NULL - -# ported test -query ? -SELECT regexp_match('aaa-555', '.*-(\d*)'); ----- -[555] - -query B -select 'abc' ~ null; ----- -NULL - -query B -select null ~ null; ----- -NULL - -query B -select null ~ 'abc'; ----- -NULL - -query B -select 'abc' ~* null; ----- -NULL - -query B -select null ~* null; ----- -NULL - -query B -select null ~* 'abc'; ----- -NULL - -query B -select 'abc' !~ null; ----- -NULL - -query B -select null !~ null; ----- -NULL - -query B -select null !~ 'abc'; ----- -NULL - -query B -select 'abc' !~* null; ----- -NULL - -query B -select null !~* null; ----- -NULL - -query B -select null !~* 'abc'; ----- -NULL - -# -# regexp_replace tests -# - -query T -SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM t; ----- -Xbc -X -aXc -AbC -aBC -4000 -X -X -X -X -X - -# test string view -statement ok -CREATE TABLE t_stringview AS -SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM t; - -query T -SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM t_stringview; ----- -Xbc -X -aXc -AbC -aBC -4000 -X -X -X -X -X - -statement ok -DROP TABLE t_stringview; - -query T -SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'gi'); ----- -XXX - -query T -SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'i'); ----- -XabcABC - -query T -SELECT regexp_replace('foobarbaz', 'b..', 'X', 'g'); ----- -fooXX - -query T -SELECT regexp_replace('foobarbaz', 'b..', 'X'); ----- -fooXbaz - -query T -SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); ----- -fooXarYXazY - -query T -SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', NULL); ----- -NULL - -query T -SELECT regexp_replace('foobarbaz', 'b(..)', NULL, 'g'); ----- -NULL - -query T -SELECT regexp_replace('foobarbaz', NULL, 'X\\1Y', 'g'); ----- -NULL - -query T -SELECT regexp_replace('Thomas', '.[mN]a.', 'M'); ----- -ThM - -query T -SELECT regexp_replace(NULL, 'b(..)', 'X\\1Y', 'g'); ----- -NULL - -query T -SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') ----- -fooxx - -query T -SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'xx', 'gi') ----- -fooxx - -query TTT -select - regexp_replace(col, NULL, 'c'), - regexp_replace(col, 'a', NULL), - regexp_replace(col, 'a', 'c', NULL) -from (values ('a'), ('b')) as tbl(col); ----- -NULL NULL NULL -NULL NULL NULL - -# multiline string -query B -SELECT 'foo\nbar\nbaz' ~ 'bar'; ----- -true - -statement error -Error during planning: Cannot infer common argument type for regex operation List(Field { name: "item", data_type: Int64, nullable: true, dict_is_ordered: false, metadata -: {} }) ~ List(Field { name: "item", data_type: Int64, nullable: true, dict_is_ordered: false, metadata: {} }) -select [1,2] ~ [3]; - -query B -SELECT 'foo\nbar\nbaz' LIKE '%bar%'; ----- -true - -query B -SELECT NULL LIKE NULL; ----- -NULL - -query B -SELECT NULL iLIKE NULL; ----- -NULL - -query B -SELECT NULL not LIKE NULL; ----- -NULL - -query B -SELECT NULL not iLIKE NULL; ----- -NULL - -# regexp_count tests - -# regexp_count tests from postgresql -# https://github.com/postgres/postgres/blob/56d23855c864b7384970724f3ad93fb0fc319e51/src/test/regress/sql/strings.sql#L226-L235 - -query I -SELECT regexp_count('123123123123123', '(12)3'); ----- -5 - -query I -SELECT regexp_count('123123123123', '123', 1); ----- -4 - -query I -SELECT regexp_count('123123123123', '123', 3); ----- -3 - -query I -SELECT regexp_count('123123123123', '123', 33); ----- -0 - -query I -SELECT regexp_count('ABCABCABCABC', 'Abc', 1, ''); ----- -0 - -query I -SELECT regexp_count('ABCABCABCABC', 'Abc', 1, 'i'); ----- -4 - -statement error -External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based -SELECT regexp_count('123123123123', '123', 0); - -statement error -External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based -SELECT regexp_count('123123123123', '123', -3); - -statement error -External error: statement failed: DataFusion error: Arrow error: Compute error: regexp_count() does not support global flag -SELECT regexp_count('123123123123', '123', 1, 'g'); - -query I -SELECT regexp_count(str, '\w') from t; ----- -3 -3 -3 -3 -3 -4 -4 -10 -6 -4 -7 - -query I -SELECT regexp_count(str, '\w{2}', start) from t; ----- -1 -1 -1 -1 -0 -2 -1 -4 -1 -2 -3 - -query I -SELECT regexp_count(str, 'ab', 1, 'i') from t; ----- -1 -1 -1 -1 -1 -0 -0 -0 -0 -0 -0 - - -query I -SELECT regexp_count(str, pattern) from t; ----- -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 -1 - -query I -SELECT regexp_count(str, pattern, start) from t; ----- -1 -1 -0 -0 -0 -0 -0 -1 -1 -1 -1 - -query I -SELECT regexp_count(str, pattern, start, flags) from t; ----- -1 -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 - -# test type coercion -query I -SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; ----- -1 -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 - -# test string views - -statement ok -CREATE TABLE t_stringview AS -SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM t; - -query I -SELECT regexp_count(str, '\w') from t_stringview; ----- -3 -3 -3 -3 -3 -4 -4 -10 -6 -4 -7 - -query I -SELECT regexp_count(str, '\w{2}', start) from t_stringview; ----- -1 -1 -1 -1 -0 -2 -1 -4 -1 -2 -3 - -query I -SELECT regexp_count(str, 'ab', 1, 'i') from t_stringview; ----- -1 -1 -1 -1 -1 -0 -0 -0 -0 -0 -0 - - -query I -SELECT regexp_count(str, pattern) from t_stringview; ----- -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 -1 - -query I -SELECT regexp_count(str, pattern, start) from t_stringview; ----- -1 -1 -0 -0 -0 -0 -0 -1 -1 -1 -1 - -query I -SELECT regexp_count(str, pattern, start, flags) from t_stringview; ----- -1 -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 - -# test type coercion -query I -SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t_stringview; ----- -1 -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 - -# NULL tests - -query I -SELECT regexp_count(NULL, NULL); ----- -0 - -query I -SELECT regexp_count(NULL, 'a'); ----- -0 - -query I -SELECT regexp_count('a', NULL); ----- -0 - -query I -SELECT regexp_count(NULL, NULL, NULL, NULL); ----- -0 - -statement ok -CREATE TABLE empty_table (str varchar, pattern varchar, start int, flags varchar); - -query I -SELECT regexp_count(str, pattern, start, flags) from empty_table; ----- - -statement ok -INSERT INTO empty_table VALUES ('a', NULL, 1, 'i'), (NULL, 'a', 1, 'i'), (NULL, NULL, 1, 'i'), (NULL, NULL, NULL, 'i'); - -query I -SELECT regexp_count(str, pattern, start, flags) from empty_table; ----- -0 -0 -0 -0 - -statement ok -drop table t; - -statement ok -create or replace table strings as values - ('FooBar'), - ('Foo'), - ('Foo'), - ('Bar'), - ('FooBar'), - ('Bar'), - ('Baz'); - -statement ok -create or replace table dict_table as -select arrow_cast(column1, 'Dictionary(Int32, Utf8)') as column1 -from strings; - -query T -select column1 from dict_table where column1 LIKE '%oo%'; ----- -FooBar -Foo -Foo -FooBar - -query T -select column1 from dict_table where column1 NOT LIKE '%oo%'; ----- -Bar -Bar -Baz - -query T -select column1 from dict_table where column1 ILIKE '%oO%'; ----- -FooBar -Foo -Foo -FooBar - -query T -select column1 from dict_table where column1 NOT ILIKE '%oO%'; ----- -Bar -Bar -Baz - - -# plan should not cast the column, instead it should use the dictionary directly -query TT -explain select column1 from dict_table where column1 LIKE '%oo%'; ----- -logical_plan -01)Filter: dict_table.column1 LIKE Utf8("%oo%") -02)--TableScan: dict_table projection=[column1] -physical_plan -01)CoalesceBatchesExec: target_batch_size=8192 -02)--FilterExec: column1@0 LIKE %oo% -03)----DataSourceExec: partitions=1, partition_sizes=[1] - -# Ensure casting / coercion works for all operators -# (there should be no casts to Utf8) -query TT -explain select - column1 LIKE '%oo%', - column1 NOT LIKE '%oo%', - column1 ILIKE '%oo%', - column1 NOT ILIKE '%oo%' -from dict_table; ----- -logical_plan -01)Projection: dict_table.column1 LIKE Utf8("%oo%"), dict_table.column1 NOT LIKE Utf8("%oo%"), dict_table.column1 ILIKE Utf8("%oo%"), dict_table.column1 NOT ILIKE Utf8("%oo%") -02)--TableScan: dict_table projection=[column1] -physical_plan -01)ProjectionExec: expr=[column1@0 LIKE %oo% as dict_table.column1 LIKE Utf8("%oo%"), column1@0 NOT LIKE %oo% as dict_table.column1 NOT LIKE Utf8("%oo%"), column1@0 ILIKE %oo% as dict_table.column1 ILIKE Utf8("%oo%"), column1@0 NOT ILIKE %oo% as dict_table.column1 NOT ILIKE Utf8("%oo%")] -02)--DataSourceExec: partitions=1, partition_sizes=[1] - -statement ok -drop table strings - -statement ok -drop table dict_table diff --git a/datafusion/sqllogictest/test_files/regexp/README.md b/datafusion/sqllogictest/test_files/regexp/README.md new file mode 100644 index 0000000000000..7e5efc5b5ddf2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/README.md @@ -0,0 +1,59 @@ + + +# Regexp Test Files + +This directory contains test files for regular expression (regexp) functions in DataFusion. + +## Directory Structure + +``` +regexp/ + - init_data.slt.part // Shared test data for regexp functions + - regexp_like.slt // Tests for regexp_like function + - regexp_count.slt // Tests for regexp_count function + - regexp_match.slt // Tests for regexp_match function + - regexp_replace.slt // Tests for regexp_replace function +``` + +## Tested Functions + +1. `regexp_like`: Check if a string matches a regular expression +2. `regexp_count`: Count occurrences of a pattern in a string +3. `regexp_match`: Extract matching substrings +4. `regexp_replace`: Replace matched substrings + +## Test Data + +Test data is centralized in the `init_data.slt.part` file and imported into each test file using the `include` directive. This approach ensures: + +Consistent test data across different regexp function tests +Easy maintenance of test data +Reduced duplication + +## Test Coverage + +Each test file covers: + +Basic functionality +Case-insensitive matching +Null handling +Start position tests +Capture group handling +Different string types (UTF-8, Unicode) diff --git a/datafusion/sqllogictest/test_files/regexp/init_data.slt.part b/datafusion/sqllogictest/test_files/regexp/init_data.slt.part new file mode 100644 index 0000000000000..ed6fb0e872df9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/init_data.slt.part @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +create table regexp_test_data (str varchar, pattern varchar, start int, flags varchar) as values + (NULL, '^(a)', 1, 'i'), + ('abc', '^(a)', 1, 'i'), + ('ABC', '^(A).*', 1, 'i'), + ('aBc', '(b|d)', 1, 'i'), + ('AbC', '(B|D)', 2, null), + ('aBC', '^(b|c)', 3, null), + ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 1, null), + ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 2, null), + ('Düsseldorf','[\p{Letter}-]+', 3, null), + ('Москва', '[\p{L}-]+', 4, null), + ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', 1, null), + ('إسرائيل', '^\p{Arabic}+$', 2, null); diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_count.slt b/datafusion/sqllogictest/test_files/regexp/regexp_count.slt new file mode 100644 index 0000000000000..d842a1ee81dfb --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_count.slt @@ -0,0 +1,344 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Import common test data +include ./init_data.slt.part + +# regexp_count tests from postgresql +# https://github.com/postgres/postgres/blob/56d23855c864b7384970724f3ad93fb0fc319e51/src/test/regress/sql/strings.sql#L226-L235 + +query I +SELECT regexp_count('123123123123123', '(12)3'); +---- +5 + +query I +SELECT regexp_count('123123123123', '123', 1); +---- +4 + +query I +SELECT regexp_count('123123123123', '123', 3); +---- +3 + +query I +SELECT regexp_count('123123123123', '123', 33); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, ''); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, 'i'); +---- +4 + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', 0); + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', -3); + +statement error +External error: statement failed: DataFusion error: Arrow error: Compute error: regexp_count() does not support global flag +SELECT regexp_count('123123123123', '123', 1, 'g'); + +query I +SELECT regexp_count(str, '\w') from regexp_test_data; +---- +0 +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from regexp_test_data; +---- +0 +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from regexp_test_data; +---- +0 +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from regexp_test_data; +---- +0 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from regexp_test_data; +---- +0 +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from regexp_test_data; +---- +0 +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from regexp_test_data; +---- +0 +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test string views + +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM regexp_test_data; + +query I +SELECT regexp_count(str, '\w') from t_stringview; +---- +0 +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from t_stringview; +---- +0 +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t_stringview; +---- +0 +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from t_stringview; +---- +0 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t_stringview; +---- +0 +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t_stringview; +---- +0 +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t_stringview; +---- +0 +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# NULL tests + +query I +SELECT regexp_count(NULL, NULL); +---- +0 + +query I +SELECT regexp_count(NULL, 'a'); +---- +0 + +query I +SELECT regexp_count('a', NULL); +---- +0 + +query I +SELECT regexp_count(NULL, NULL, NULL, NULL); +---- +0 + +statement ok +CREATE TABLE empty_table (str varchar, pattern varchar, start int, flags varchar); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- + +statement ok +INSERT INTO empty_table VALUES ('a', NULL, 1, 'i'), (NULL, 'a', 1, 'i'), (NULL, NULL, 1, 'i'), (NULL, NULL, NULL, 'i'); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- +0 +0 +0 +0 + +statement ok +drop table t_stringview; + +statement ok +drop table empty_table; diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_instr.slt b/datafusion/sqllogictest/test_files/regexp/regexp_instr.slt new file mode 100644 index 0000000000000..d4e98e6431678 --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_instr.slt @@ -0,0 +1,196 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Import common test data +include ./init_data.slt.part + +query I +SELECT regexp_instr('123123123123123', '(12)3'); +---- +1 + +query I +SELECT regexp_instr('123123123123', '123', 1); +---- +1 + +query I +SELECT regexp_instr('123123123123', '123', 3); +---- +4 + +query I +SELECT regexp_instr('123123123123', '123', 33); +---- +0 + +query I +SELECT regexp_instr('ABCABCABCABC', 'Abc', 1, 2, ''); +---- +0 + +query I +SELECT regexp_instr('ABCABCABCABC', 'Abc', 1, 2, 'i'); +---- +4 + +query I +SELECT + regexp_instr( + 'The quick brown fox jumps over the lazy dog.', + ' (quick) (brown) (fox)', + 1, + 1, + 'i', + 2 -- subexpression_number (2 for second group) + ); +---- +11 + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_instr() requires start to be 1 based +SELECT regexp_instr('123123123123', '123', 0); + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_instr() requires start to be 1 based +SELECT regexp_instr('123123123123', '123', -3); + +query I +SELECT regexp_instr(str, pattern) FROM regexp_test_data; +---- +NULL +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_instr(str, pattern, start) FROM regexp_test_data; +---- +NULL +1 +1 +0 +0 +0 +0 +0 +3 +4 +1 +2 + + +statement ok +CREATE TABLE t_stringview AS +SELECT + arrow_cast(str, 'Utf8View') AS str, + arrow_cast(pattern, 'Utf8View') AS pattern, + arrow_cast(start, 'Int64') AS start +FROM regexp_test_data; + +query I +SELECT regexp_instr(str, pattern, start) FROM t_stringview; +---- +NULL +1 +1 +0 +0 +0 +0 +0 +3 +4 +1 +2 + +query I +SELECT regexp_instr( + arrow_cast(str, 'Utf8'), + arrow_cast(pattern, 'LargeUtf8'), + arrow_cast(start, 'Int32') +) FROM t_stringview; +---- +NULL +1 +1 +0 +0 +0 +0 +0 +3 +4 +1 +2 + +query I +SELECT regexp_instr(NULL, NULL); +---- +NULL + +query I +SELECT regexp_instr(NULL, 'a'); +---- +NULL + +query I +SELECT regexp_instr('a', NULL); +---- +NULL + +query I +SELECT regexp_instr('😀abcdef', 'abc'); +---- +2 + + +statement ok +CREATE TABLE empty_table (str varchar, pattern varchar, start int); + +query I +SELECT regexp_instr(str, pattern, start) FROM empty_table; +---- + +statement ok +INSERT INTO empty_table VALUES + ('a', NULL, 1), + (NULL, 'a', 1), + (NULL, NULL, 1), + (NULL, NULL, NULL); + +query I +SELECT regexp_instr(str, pattern, start) FROM empty_table; +---- +NULL +NULL +NULL +NULL + +statement ok +DROP TABLE t_stringview; + +statement ok +DROP TABLE empty_table; diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_like.slt b/datafusion/sqllogictest/test_files/regexp/regexp_like.slt new file mode 100644 index 0000000000000..dd42511eade93 --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_like.slt @@ -0,0 +1,339 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Import common test data +include ./init_data.slt.part + +query B +SELECT regexp_like(str, pattern, flags) FROM regexp_test_data; +---- +NULL +true +true +true +false +false +false +true +true +true +true +true + +query B +SELECT str ~ NULL FROM regexp_test_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +query B +select str ~ right('foo', NULL) FROM regexp_test_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +query B +select right('foo', NULL) !~ str FROM regexp_test_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +query B +SELECT regexp_like('foobarbequebaz', ''); +---- +true + +query B +SELECT regexp_like('', ''); +---- +true + +query B +SELECT regexp_like('foobarbequebaz', '(bar)(beque)'); +---- +true + +query B +SELECT regexp_like('fooBarbeQuebaz', '(bar).*(que)', 'is'); +---- +true + +query B +SELECT regexp_like('foobarbequebaz', '(ba3r)(bequ34e)'); +---- +false + +query B +SELECT regexp_like('foobarbequebaz', '^.*(barbequ[0-9]*e).*$', 'm'); +---- +true + +query B +SELECT regexp_like('aaa-0', '.*-(\d)'); +---- +true + +query B +SELECT regexp_like('bb-1', '.*-(\d)'); +---- +true + +query B +SELECT regexp_like('aa', '.*-(\d)'); +---- +false + +query B +SELECT regexp_like(NULL, '.*-(\d)'); +---- +NULL + +query B +SELECT regexp_like('aaa-0', NULL); +---- +NULL + +query B +SELECT regexp_like(null, '.*-(\d)'); +---- +NULL + +query error Error during planning: regexp_like\(\) does not support the "global" option +SELECT regexp_like('bb-1', '.*-(\d)', 'g'); + +query error Error during planning: regexp_like\(\) does not support the "global" option +SELECT regexp_like('bb-1', '.*-(\d)', 'g'); + +query error Arrow error: Compute error: Regular expression did not compile: CompiledTooBig\(10485760\) +SELECT regexp_like('aaaaa', 'a{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}'); + +# look-around is not supported and will just return false +query B +SELECT regexp_like('(?<=[A-Z]\w )Smith', 'John Smith', 'i'); +---- +false + +query B +select regexp_like('aaa-555', '.*-(\d*)'); +---- +true + +# multiline string +query B +SELECT 'foo\nbar\nbaz' ~ 'bar'; +---- +true + +statement error +Error during planning: Cannot infer common argument type for regex operation List(Field { name: "item", data_type: Int64, nullable: true, metadata: {} }) ~ List(Field { name: "item", data_type: Int64, nullable: true, metadata: {} }) +select [1,2] ~ [3]; + +query B +SELECT 'foo\nbar\nbaz' LIKE '%bar%'; +---- +true + +query B +SELECT NULL LIKE NULL; +---- +NULL + +query B +SELECT NULL iLIKE NULL; +---- +NULL + +query B +SELECT NULL not LIKE NULL; +---- +NULL + +query B +SELECT NULL not iLIKE NULL; +---- +NULL + +statement ok +create or replace table strings as values + ('FooBar'), + ('Foo'), + ('Foo'), + ('Bar'), + ('FooBar'), + ('Bar'), + ('Baz'); + +statement ok +create or replace table dict_table as +select arrow_cast(column1, 'Dictionary(Int32, Utf8)') as column1 +from strings; + +query T +select column1 from dict_table where column1 LIKE '%oo%'; +---- +FooBar +Foo +Foo +FooBar + +query T +select column1 from dict_table where column1 NOT LIKE '%oo%'; +---- +Bar +Bar +Baz + +query T +select column1 from dict_table where column1 ILIKE '%oO%'; +---- +FooBar +Foo +Foo +FooBar + +query T +select column1 from dict_table where column1 NOT ILIKE '%oO%'; +---- +Bar +Bar +Baz + + +# plan should not cast the column, instead it should use the dictionary directly +query TT +explain select column1 from dict_table where column1 LIKE '%oo%'; +---- +logical_plan +01)Filter: dict_table.column1 LIKE Utf8("%oo%") +02)--TableScan: dict_table projection=[column1] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: column1@0 LIKE %oo% +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +# Ensure casting / coercion works for all operators +# (there should be no casts to Utf8) +query TT +explain select + column1 LIKE '%oo%', + column1 NOT LIKE '%oo%', + column1 ILIKE '%oo%', + column1 NOT ILIKE '%oo%' +from dict_table; +---- +logical_plan +01)Projection: dict_table.column1 LIKE Utf8("%oo%"), dict_table.column1 NOT LIKE Utf8("%oo%"), dict_table.column1 ILIKE Utf8("%oo%"), dict_table.column1 NOT ILIKE Utf8("%oo%") +02)--TableScan: dict_table projection=[column1] +physical_plan +01)ProjectionExec: expr=[column1@0 LIKE %oo% as dict_table.column1 LIKE Utf8("%oo%"), column1@0 NOT LIKE %oo% as dict_table.column1 NOT LIKE Utf8("%oo%"), column1@0 ILIKE %oo% as dict_table.column1 ILIKE Utf8("%oo%"), column1@0 NOT ILIKE %oo% as dict_table.column1 NOT ILIKE Utf8("%oo%")] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +drop table strings + +statement ok +drop table dict_table + +# Ensure that regexp_like is rewritten to use the (more optimized) regex operators +statement ok +create table regexp_test as values + ('foobar', 'i'), + ('Foo', 'i'), + ('bar', 'mi') ; + +# Expressions that can be rewritten to use the ~ operator (which is more optimized) +# (expect the plans to use the ~ / ~* operators, not the REGEXP_LIKE function) +query TT +explain select + regexp_like(column1, 'fo.*'), + regexp_like(column1, 'fo.*', 'i'), +from regexp_test; +---- +logical_plan +01)Projection: regexp_test.column1 ~ Utf8("fo.*") AS regexp_like(regexp_test.column1,Utf8("fo.*")), regexp_test.column1 ~* Utf8("fo.*") AS regexp_like(regexp_test.column1,Utf8("fo.*"),Utf8("i")) +02)--TableScan: regexp_test projection=[column1] +physical_plan +01)ProjectionExec: expr=[column1@0 ~ fo.* as regexp_like(regexp_test.column1,Utf8("fo.*")), column1@0 ~* fo.* as regexp_like(regexp_test.column1,Utf8("fo.*"),Utf8("i"))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query BB +select + regexp_like(column1, 'fo.*'), + regexp_like(column1, 'fo.*', 'i'), +from regexp_test; +---- +true true +false true +false false + +# Expressions that can not be rewritten to use the ~ / ~* operators +# (expect the plans to use the REGEXP_LIKE function) +query TT +explain select + regexp_like(column1, 'f.*r', 'mi'), -- args + regexp_like(column1, 'f.*r', column2) -- non scalar flags +from regexp_test; +---- +logical_plan +01)Projection: regexp_like(regexp_test.column1, Utf8("f.*r"), Utf8("mi")), regexp_like(regexp_test.column1, Utf8("f.*r"), regexp_test.column2) +02)--TableScan: regexp_test projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[regexp_like(column1@0, f.*r, mi) as regexp_like(regexp_test.column1,Utf8("f.*r"),Utf8("mi")), regexp_like(column1@0, f.*r, column2@1) as regexp_like(regexp_test.column1,Utf8("f.*r"),regexp_test.column2)] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query BB +select + regexp_like(column1, 'f.*r', 'mi'), -- args + regexp_like(column1, 'f.*r', column2) -- non scalar flags +from regexp_test; +---- +true true +false false +false false + +statement ok +drop table if exists dict_table; diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_match.slt b/datafusion/sqllogictest/test_files/regexp/regexp_match.slt new file mode 100644 index 0000000000000..e79af4774aa21 --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_match.slt @@ -0,0 +1,201 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Import common test data +include ./init_data.slt.part + +query ? +SELECT regexp_match(str, pattern, flags) FROM regexp_test_data; +---- +NULL +[a] +[A] +[B] +NULL +NULL +NULL +[010] +[Düsseldorf] +[Москва] +[Köln] +[إسرائيل] + +# test string view +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM regexp_test_data; + +query ? +SELECT regexp_match(str, pattern, flags) FROM t_stringview; +---- +NULL +[a] +[A] +[B] +NULL +NULL +NULL +[010] +[Düsseldorf] +[Москва] +[Köln] +[إسرائيل] + +statement ok +DROP TABLE t_stringview; + +query ? +SELECT regexp_match('foobarbequebaz', ''); +---- +[] + +query ? +SELECT regexp_match('', ''); +---- +[] + +query ? +SELECT regexp_match('foobarbequebaz', '(bar)(beque)'); +---- +[bar, beque] + +query ? +SELECT regexp_match('fooBarb +eQuebaz', '(bar).*(que)', 'is'); +---- +[Bar, Que] + +query ? +SELECT regexp_match('foobarbequebaz', '(ba3r)(bequ34e)'); +---- +NULL + +query ? +SELECT regexp_match('foobarbequebaz', '^.*(barbequ[0-9]*e).*$', 'm'); +---- +[barbeque] + +query ? +SELECT regexp_match('aaa-0', '.*-(\d)'); +---- +[0] + +query ? +SELECT regexp_match('bb-1', '.*-(\d)'); +---- +[1] + +query ? +SELECT regexp_match('aa', '.*-(\d)'); +---- +NULL + +query ? +SELECT regexp_match(NULL, '.*-(\d)'); +---- +NULL + +query ? +SELECT regexp_match('aaa-0', NULL); +---- +NULL + +query ? +SELECT regexp_match(null, '.*-(\d)'); +---- +NULL + +query error Error during planning: regexp_match\(\) does not support the "global" option +SELECT regexp_match('bb-1', '.*-(\d)', 'g'); + +query error Error during planning: regexp_match\(\) does not support the "global" option +SELECT regexp_match('bb-1', '.*-(\d)', 'g'); + +query error Arrow error: Compute error: Regular expression did not compile: CompiledTooBig\(10485760\) +SELECT regexp_match('aaaaa', 'a{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}'); + +# look-around is not supported and will just return null +query ? +SELECT regexp_match('(?<=[A-Z]\w )Smith', 'John Smith', 'i'); +---- +NULL + +# ported test +query ? +SELECT regexp_match('aaa-555', '.*-(\d*)'); +---- +[555] + +query B +select 'abc' ~ null; +---- +NULL + +query B +select null ~ null; +---- +NULL + +query B +select null ~ 'abc'; +---- +NULL + +query B +select 'abc' ~* null; +---- +NULL + +query B +select null ~* null; +---- +NULL + +query B +select null ~* 'abc'; +---- +NULL + +query B +select 'abc' !~ null; +---- +NULL + +query B +select null !~ null; +---- +NULL + +query B +select null !~ 'abc'; +---- +NULL + +query B +select 'abc' !~* null; +---- +NULL + +query B +select null !~* null; +---- +NULL + +query B +select null !~* 'abc'; +---- +NULL diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt b/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt new file mode 100644 index 0000000000000..a16801adcef78 --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Import common test data +include ./init_data.slt.part + +query T +SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM regexp_test_data; +---- +NULL +Xbc +X +aXc +AbC +aBC +4000 +X +X +X +X +X + +# test string view +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM regexp_test_data; + +query T +SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM t_stringview; +---- +NULL +Xbc +X +aXc +AbC +aBC +4000 +X +X +X +X +X + +statement ok +DROP TABLE t_stringview; + +query T +SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'gi'); +---- +XXX + +query T +SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'i'); +---- +XabcABC + +query T +SELECT regexp_replace('foobarbaz', 'b..', 'X', 'g'); +---- +fooXX + +query T +SELECT regexp_replace('foobarbaz', 'b..', 'X'); +---- +fooXbaz + +query T +SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); +---- +fooXarYXazY + +query T +SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', NULL); +---- +NULL + +query T +SELECT regexp_replace('foobarbaz', 'b(..)', NULL, 'g'); +---- +NULL + +query T +SELECT regexp_replace('foobarbaz', NULL, 'X\\1Y', 'g'); +---- +NULL + +query T +SELECT regexp_replace('Thomas', '.[mN]a.', 'M'); +---- +ThM + +query T +SELECT regexp_replace(NULL, 'b(..)', 'X\\1Y', 'g'); +---- +NULL + +query T +SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'xx', 'gi') +---- +fooxx + +query TTT +select + regexp_replace(col, NULL, 'c'), + regexp_replace(col, 'a', NULL), + regexp_replace(col, 'a', 'c', NULL) +from (values ('a'), ('b')) as tbl(col); +---- +NULL NULL NULL +NULL NULL NULL diff --git a/datafusion/sqllogictest/test_files/repartition.slt b/datafusion/sqllogictest/test_files/repartition.slt index 70666346e2cab..29d20d10b6715 100644 --- a/datafusion/sqllogictest/test_files/repartition.slt +++ b/datafusion/sqllogictest/test_files/repartition.slt @@ -46,8 +46,8 @@ physical_plan 01)AggregateExec: mode=FinalPartitioned, gby=[column1@0 as column1], aggr=[sum(parquet_table.column2)] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----RepartitionExec: partitioning=Hash([column1@0], 4), input_partitions=4 -04)------AggregateExec: mode=Partial, gby=[column1@0 as column1], aggr=[sum(parquet_table.column2)] -05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[column1@0 as column1], aggr=[sum(parquet_table.column2)] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition/parquet_table/2.parquet]]}, projection=[column1, column2], file_type=parquet # disable round robin repartitioning diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 2b30de572c8cc..c536c8165c5a3 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -61,7 +61,7 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--FilterExec: column1@0 != 42 -03)----DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..137], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:137..274], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:274..411], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:411..547]]}, projection=[column1], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] +03)----DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..135], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:135..270], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:270..405], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:405..537]]}, projection=[column1], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] # disable round robin repartitioning statement ok @@ -77,7 +77,7 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--FilterExec: column1@0 != 42 -03)----DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..137], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:137..274], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:274..411], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:411..547]]}, projection=[column1], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] +03)----DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..135], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:135..270], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:270..405], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:405..537]]}, projection=[column1], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] # enable round robin repartitioning again statement ok @@ -102,7 +102,7 @@ physical_plan 02)--SortExec: expr=[column1@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=8192 04)------FilterExec: column1@0 != 42 -05)--------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..272], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:272..538, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..278], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:278..547]]}, projection=[column1], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] +05)--------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..266], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:266..526, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..272], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:272..537]]}, projection=[column1], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] ## Read the files as though they are ordered @@ -138,7 +138,7 @@ physical_plan 01)SortPreservingMergeExec: [column1@0 ASC NULLS LAST] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----FilterExec: column1@0 != 42 -04)------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..269], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..273], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:273..547], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:269..538]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] +04)------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..263], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..268], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:268..537], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:263..526]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] # Cleanup statement ok diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index f583d659fd4f5..b0e200015dfd8 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -523,7 +523,7 @@ query RRR rowsort select log(a, 64) a, log(b), log(10, b) from unsigned_integers; ---- 3 NULL NULL -3.7855785 4 4 +3.785578521429 4 4 6 3 3 Infinity 2 2 @@ -1832,7 +1832,7 @@ query TT EXPLAIN SELECT letter, letter = LEFT('APACHE', 1) FROM simple_string; ---- logical_plan -01)Projection: simple_string.letter, simple_string.letter = Utf8("A") AS simple_string.letter = left(Utf8("APACHE"),Int64(1)) +01)Projection: simple_string.letter, simple_string.letter = Utf8View("A") AS simple_string.letter = left(Utf8("APACHE"),Int64(1)) 02)--TableScan: simple_string projection=[letter] physical_plan 01)ProjectionExec: expr=[letter@0 as letter, letter@0 = A as simple_string.letter = left(Utf8("APACHE"),Int64(1))] @@ -1851,10 +1851,10 @@ query TT EXPLAIN SELECT letter, letter = LEFT(letter2, 1) FROM simple_string; ---- logical_plan -01)Projection: simple_string.letter, simple_string.letter = left(simple_string.letter2, Int64(1)) +01)Projection: simple_string.letter, simple_string.letter = CAST(left(simple_string.letter2, Int64(1)) AS Utf8View) 02)--TableScan: simple_string projection=[letter, letter2] physical_plan -01)ProjectionExec: expr=[letter@0 as letter, letter@0 = left(letter2@1, 1) as simple_string.letter = left(simple_string.letter2,Int64(1))] +01)ProjectionExec: expr=[letter@0 as letter, letter@0 = CAST(left(letter2@1, 1) AS Utf8View) as simple_string.letter = left(simple_string.letter2,Int64(1))] 02)--DataSourceExec: partitions=1, partition_sizes=[1] query TB diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index aa14faf984e40..cd1f90c42efd3 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -408,7 +408,7 @@ VALUES (1,2,3,4,5,6,7,8,9,10,11,12,13,NULL,'F',3.5) # Test non-literal expressions in VALUES query II -VALUES (1, CASE WHEN RANDOM() > 0.5 THEN 1 ELSE 1 END), +VALUES (1, CASE WHEN RANDOM() > 0.5 THEN 1 ELSE 1 END), (2, CASE WHEN RANDOM() > 0.5 THEN 2 ELSE 2 END); ---- 1 1 @@ -558,7 +558,7 @@ EXPLAIN SELECT * FROM ((SELECT column1 FROM foo) "T1" CROSS JOIN (SELECT column2 ---- logical_plan 01)SubqueryAlias: F -02)--Cross Join: +02)--Cross Join: 03)----SubqueryAlias: T1 04)------TableScan: foo projection=[column1] 05)----SubqueryAlias: T2 @@ -1641,7 +1641,7 @@ query II SELECT CASE WHEN B.x > 0 THEN A.x / B.x ELSE 0 END AS value1, CASE WHEN B.x > 0 AND B.y > 0 THEN A.x / B.x ELSE 0 END AS value3 -FROM t AS A, (SELECT * FROM t WHERE x = 0) AS B; +FROM t AS A, (SELECT * FROM t WHERE x = 0) AS B; ---- 0 0 0 0 @@ -1656,10 +1656,10 @@ query TT explain select coalesce(1, y/x), coalesce(2, y/x) from t; ---- logical_plan -01)Projection: coalesce(Int64(1), CAST(t.y / t.x AS Int64)), coalesce(Int64(2), CAST(t.y / t.x AS Int64)) -02)--TableScan: t projection=[x, y] +01)Projection: Int64(1) AS coalesce(Int64(1),t.y / t.x), Int64(2) AS coalesce(Int64(2),t.y / t.x) +02)--TableScan: t projection=[] physical_plan -01)ProjectionExec: expr=[coalesce(1, CAST(y@1 / x@0 AS Int64)) as coalesce(Int64(1),t.y / t.x), coalesce(2, CAST(y@1 / x@0 AS Int64)) as coalesce(Int64(2),t.y / t.x)] +01)ProjectionExec: expr=[1 as coalesce(Int64(1),t.y / t.x), 2 as coalesce(Int64(2),t.y / t.x)] 02)--DataSourceExec: partitions=1, partition_sizes=[1] query TT @@ -1686,11 +1686,17 @@ physical_plan 02)--ProjectionExec: expr=[y@1 = 0 as __common_expr_1, x@0 as x, y@1 as y] 03)----DataSourceExec: partitions=1, partition_sizes=[1] -# due to the reason describe in https://github.com/apache/datafusion/issues/8927, -# the following queries will fail -query error +query II select coalesce(1, y/x), coalesce(2, y/x) from t; +---- +1 2 +1 2 +1 2 +1 2 +1 2 +# due to the reason describe in https://github.com/apache/datafusion/issues/8927, +# the following queries will fail query error SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t; @@ -1871,3 +1877,61 @@ select *, count(*) over() as ta from t; statement count 0 drop table t; + +# test "user" column +# See https://github.com/apache/datafusion/issues/14141 +statement count 0 +create table t_with_user(a int, user text) as values (1,'test'), (2,null), (3,'foo'); + +query T +select t_with_user.user from t_with_user; +---- +test +NULL +foo + +query IT +select * from t_with_user where t_with_user.user = 'foo'; +---- +3 foo + +query T +select user from t_with_user; +---- +test +NULL +foo + +query IT +select * from t_with_user where user = 'foo'; +---- +3 foo + +# test "current_time" column +# See https://github.com/apache/datafusion/issues/14141 +statement count 0 +create table t_with_current_time(a int, current_time text) as values (1,'now'), (2,null), (3,'later'); + +# here it's clear the the column was meant +query B +select t_with_current_time.current_time is not null from t_with_current_time; +---- +true +false +true + +# here it's the function +query B +select current_time is not null from t_with_current_time; +---- +true +true +true + +# and here it's the column again +query B +select "current_time" is not null from t_with_current_time; +---- +true +false +true diff --git a/datafusion/sqllogictest/test_files/simplify_expr.slt b/datafusion/sqllogictest/test_files/simplify_expr.slt index 43193fb41cfad..c77163dc996dc 100644 --- a/datafusion/sqllogictest/test_files/simplify_expr.slt +++ b/datafusion/sqllogictest/test_files/simplify_expr.slt @@ -35,22 +35,22 @@ query TT explain select b from t where b ~ '.*' ---- logical_plan -01)Filter: t.b IS NOT NULL +01)Filter: t.b ~ Utf8View(".*") 02)--TableScan: t projection=[b] physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 -02)--FilterExec: b@0 IS NOT NULL +02)--FilterExec: b@0 ~ .* 03)----DataSourceExec: partitions=1, partition_sizes=[1] query TT explain select b from t where b !~ '.*' ---- logical_plan -01)Filter: t.b = Utf8("") +01)Filter: t.b !~ Utf8View(".*") 02)--TableScan: t projection=[b] physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 -02)--FilterExec: b@0 = +02)--FilterExec: b@0 !~ .* 03)----DataSourceExec: partitions=1, partition_sizes=[1] query T @@ -63,5 +63,47 @@ query T select b from t where b !~ '.*' ---- +query TT +explain select * from t where a = a; +---- +logical_plan +01)Filter: t.a IS NOT NULL OR Boolean(NULL) +02)--TableScan: t projection=[a, b] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: a@0 IS NOT NULL OR NULL +03)----DataSourceExec: partitions=1, partition_sizes=[1] + statement ok drop table t; + +# test decimal precision +query B +SELECT a * 1.000::DECIMAL(4,3) > 1.2::decimal(2,1) FROM VALUES (1) AS t(a); +---- +false + +query B +SELECT 1.000::DECIMAL(4,3) * a > 1.2::decimal(2,1) FROM VALUES (1) AS t(a); +---- +false + +query B +SELECT NULL::DECIMAL(4,3) * a > 1.2::decimal(2,1) FROM VALUES (1) AS t(a); +---- +NULL + +query B +SELECT a * NULL::DECIMAL(4,3) > 1.2::decimal(2,1) FROM VALUES (1) AS t(a); +---- +NULL + +query B +SELECT a / 1.000::DECIMAL(4,3) > 1.2::decimal(2,1) FROM VALUES (1) AS t(a); +---- +false + +query B +SELECT a / NULL::DECIMAL(4,3) > 1.2::decimal(2,1) FROM VALUES (1) AS t(a); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/simplify_predicates.slt b/datafusion/sqllogictest/test_files/simplify_predicates.slt new file mode 100644 index 0000000000000..31ce1efd21c72 --- /dev/null +++ b/datafusion/sqllogictest/test_files/simplify_predicates.slt @@ -0,0 +1,234 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test cases for predicate simplification feature +# Basic redundant comparison simplification + +statement ok +set datafusion.explain.logical_plan_only=true; + +statement ok +CREATE TABLE test_data ( + int_col INT, + float_col FLOAT, + str_col VARCHAR, + date_col DATE, + bool_col BOOLEAN +); + +# x > 5 AND x > 6 should simplify to x > 6 +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND int_col > 6; +---- +logical_plan +01)Filter: test_data.int_col > Int32(6) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x > 5 AND x >= 6 should simplify to x >= 6 +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND int_col >= 6; +---- +logical_plan +01)Filter: test_data.int_col >= Int32(6) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x < 10 AND x <= 8 should simplify to x <= 8 +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col < 10 AND int_col <= 8; +---- +logical_plan +01)Filter: test_data.int_col <= Int32(8) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x > 5 AND x > 6 AND x > 7 should simplify to x > 7 +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND int_col > 6 AND int_col > 7; +---- +logical_plan +01)Filter: test_data.int_col > Int32(7) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x > 5 AND y < 10 AND x > 6 AND y < 8 should simplify to x > 6 AND y < 8 +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND float_col < 10 AND int_col > 6 AND float_col < 8; +---- +logical_plan +01)Filter: test_data.float_col < Float32(8) AND test_data.int_col > Int32(6) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x = 7 AND x = 7 should simplify to x = 7 +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col = 7 AND int_col = 7; +---- +logical_plan +01)Filter: test_data.int_col = Int32(7) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x = 7 AND x = 6 should simplify to false +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col = 7 AND int_col = 6; +---- +logical_plan EmptyRelation: rows=0 + +# TODO: x = 7 AND x < 2 should simplify to false +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col = 7 AND int_col < 2; +---- +logical_plan +01)Filter: test_data.int_col = Int32(7) AND test_data.int_col < Int32(2) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + + +# TODO: x = 7 AND x > 5 should simplify to x = 7 +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col = 7 AND int_col > 5; +---- +logical_plan +01)Filter: test_data.int_col = Int32(7) AND test_data.int_col > Int32(5) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# str_col > 'apple' AND str_col > 'banana' should simplify to str_col > 'banana' +query TT +EXPLAIN SELECT * FROM test_data WHERE str_col > 'apple' AND str_col > 'banana'; +---- +logical_plan +01)Filter: test_data.str_col > Utf8View("banana") +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# date_col > '2023-01-01' AND date_col > '2023-02-01' should simplify to date_col > '2023-02-01' +query TT +EXPLAIN SELECT * FROM test_data WHERE date_col > '2023-01-01' AND date_col > '2023-02-01'; +---- +logical_plan +01)Filter: test_data.date_col > Date32("2023-02-01") +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +query TT +EXPLAIN SELECT * FROM test_data WHERE bool_col = true AND bool_col = false; +---- +logical_plan +01)Filter: test_data.bool_col AND NOT test_data.bool_col +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + + +# This shouldn't be simplified since they're different relationships +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col > float_col AND int_col > 5; +---- +logical_plan +01)Filter: CAST(test_data.int_col AS Float32) > test_data.float_col AND test_data.int_col > Int32(5) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# Should simplify the int_col predicates but preserve the others +query TT +EXPLAIN SELECT * FROM test_data +WHERE int_col > 5 + AND int_col > 10 + AND str_col LIKE 'A%' + AND float_col BETWEEN 1 AND 100; +---- +logical_plan +01)Filter: test_data.str_col LIKE Utf8View("A%") AND test_data.float_col >= Float32(1) AND test_data.float_col <= Float32(100) AND test_data.int_col > Int32(10) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +statement ok +CREATE TABLE test_data2 ( + id INT, + value INT +); + +query TT +EXPLAIN SELECT t1.int_col, t2.value +FROM test_data t1 +JOIN test_data2 t2 ON t1.int_col = t2.id +WHERE t1.int_col > 5 + AND t1.int_col > 10 + AND t2.value < 100 + AND t2.value < 50; +---- +logical_plan +01)Projection: t1.int_col, t2.value +02)--Inner Join: t1.int_col = t2.id +03)----SubqueryAlias: t1 +04)------Filter: test_data.int_col > Int32(10) +05)--------TableScan: test_data projection=[int_col] +06)----SubqueryAlias: t2 +07)------Filter: test_data2.value < Int32(50) AND test_data2.id > Int32(10) +08)--------TableScan: test_data2 projection=[id, value] + +# Handling negated predicates +# NOT (x < 10) AND NOT (x < 5) should simplify to NOT (x < 10) +query TT +EXPLAIN SELECT * FROM test_data WHERE NOT (int_col < 10) AND NOT (int_col < 5); +---- +logical_plan +01)Filter: test_data.int_col >= Int32(10) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x > 5 AND x < 10 should be preserved (can't be simplified) +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND int_col < 10; +---- +logical_plan +01)Filter: test_data.int_col > Int32(5) AND test_data.int_col < Int32(10) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# 5 < x AND 3 < x should simplify to 5 < x +query TT +EXPLAIN SELECT * FROM test_data WHERE 5 < int_col AND 3 < int_col; +---- +logical_plan +01)Filter: test_data.int_col > Int32(5) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# CAST(x AS FLOAT) > 5.0 AND CAST(x AS FLOAT) > 6.0 should simplify +query TT +EXPLAIN SELECT * FROM test_data WHERE CAST(int_col AS FLOAT) > 5.0 AND CAST(int_col AS FLOAT) > 6.0; +---- +logical_plan +01)Filter: CAST(CAST(test_data.int_col AS Float32) AS Float64) > Float64(6) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x = 5 AND x = 6 (logically impossible) +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col = 5 AND int_col = 6; +---- +logical_plan EmptyRelation: rows=0 + +# (x > 5 OR y < 10) AND (x > 6 OR y < 8) +# This is more complex but could still benefit from some simplification +query TT +EXPLAIN SELECT * FROM test_data +WHERE (int_col > 5 OR float_col < 10) + AND (int_col > 6 OR float_col < 8); +---- +logical_plan +01)Filter: (test_data.int_col > Int32(5) OR test_data.float_col < Float32(10)) AND (test_data.int_col > Int32(6) OR test_data.float_col < Float32(8)) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# Combination of AND and OR with simplifiable predicates +query TT +EXPLAIN SELECT * FROM test_data +WHERE (int_col > 5 AND int_col > 6) + OR (float_col < 10 AND float_col < 8); +---- +logical_plan +01)Filter: test_data.int_col > Int32(5) AND test_data.int_col > Int32(6) OR test_data.float_col < Float32(10) AND test_data.float_col < Float32(8) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +statement ok +set datafusion.explain.logical_plan_only=false; diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 162c9a17b61f3..ed463333217af 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -695,9 +695,199 @@ select t2.* from t1 right anti join t2 on t1.a = t2.a and t1.b = t2.b ---- 51 54 -# return sql params back to default values -statement ok -set datafusion.optimizer.prefer_hash_join = true; +# RIGHTSEMI join tests + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b = t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t2.* from t1 right semi join t1 t2 on t1.a = t2.a and t1.b = t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t2.* from t1 right semi join t1 t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b) + select t2.* from t1 right semi join t1 t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 14 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 14 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 12 b union all + select 11 a, 14 b + ), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 13 + +# Test RIGHTSEMI with cross batch data distribution + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b union all + select 12 a, 14 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 14 b union all + select 12 a, 15 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 14 +12 15 statement ok set datafusion.execution.batch_size = 8192; + + +###### +## Tests for Binary, LargeBinary, BinaryView, FixedSizeBinary join keys +###### +statement ok +create table t1(x varchar, id1 int) as values ('aa', 1), ('bb', 2), ('aa', 3), (null, 4), ('ee', 5); + +statement ok +create table t2(y varchar, id2 int) as values ('ee', 10), ('bb', 20), ('cc', 30), ('cc', 40), (null, 50); + +# Binary join keys +query ?I?I +with t1 as (select arrow_cast(x, 'Binary') as x, id1 from t1), + t2 as (select arrow_cast(y, 'Binary') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# LargeBinary join keys +query ?I?I +with t1 as (select arrow_cast(x, 'LargeBinary') as x, id1 from t1), + t2 as (select arrow_cast(y, 'LargeBinary') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# BinaryView join keys +query ?I?I +with t1 as (select arrow_cast(x, 'BinaryView') as x, id1 from t1), + t2 as (select arrow_cast(y, 'BinaryView') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# FixedSizeBinary join keys +query ?I?I +with t1 as (select arrow_cast(arrow_cast(x, 'Binary'), 'FixedSizeBinary(2)') as x, id1 from t1), + t2 as (select arrow_cast(arrow_cast(y, 'Binary'), 'FixedSizeBinary(2)') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +statement ok +drop table t1; + +statement ok +drop table t2; + +# return sql params back to default values +statement ok +set datafusion.optimizer.prefer_hash_join = true; diff --git a/datafusion/sqllogictest/test_files/spark/README.md b/datafusion/sqllogictest/test_files/spark/README.md new file mode 100644 index 0000000000000..cffd28009889d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/README.md @@ -0,0 +1,67 @@ + + +# Spark Test Files + +This directory contains test files for the `spark` test suite. + +## RoadMap + +Implementing the `datafusion-spark` compatible functions project is still a work in progress. +Many of the tests in this directory are commented out and are waiting for help with implementation. + +For more information please see: + +- [The `datafusion-spark` Epic](https://github.com/apache/datafusion/issues/15914) +- [Spark Test Generation Script] (https://github.com/apache/datafusion/pull/16409#issuecomment-2972618052) + +## Testing Guide + +When testing Spark functions: + +- Functions must be tested on both `Scalar` and `Array` inputs +- Test cases should only contain `SELECT` statements with the function being tested +- Add explicit casts to input values to ensure the correct data type is used (e.g., `0::INT`) + - Explicit casting is necessary because DataFusion and Spark do not infer data types in the same way + +### Finding Test Cases + +To verify and compare function behavior at a minimum, you can refer to the following documentation sources: + +1. Databricks SQL Function Reference: + https://docs.databricks.com/aws/en/sql/language-manual/functions/NAME +2. Apache Spark SQL Function Reference: + https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.NAME.html +3. PySpark SQL Function Reference: + https://spark.apache.org/docs/latest/api/sql/#NAME + +**Note:** Replace `NAME` in each URL with the actual function name (e.g., for the `ASCII` function, use `ascii` instead +of `NAME`). + +### Scalar Example: + +```sql +SELECT expm1(0::INT); +``` + +### Array Example: + +```sql +SELECT expm1(a) FROM (VALUES (0::INT), (1::INT)) AS t(a); +``` diff --git a/datafusion/sqllogictest/test_files/spark/aggregate/avg.slt b/datafusion/sqllogictest/test_files/spark/aggregate/avg.slt new file mode 100644 index 0000000000000..a5bed6ea324a7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/aggregate/avg.slt @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query R +SELECT avg(a) FROM (VALUES (10::INT), (20::INT), (30::INT), (40::INT), (50::INT)) AS t(a); +---- +30 + +query R +SELECT avg(a) FROM (VALUES (40::INT), (23::INT), (17::INT), (40::INT), (NULL)) AS t(a); +---- +30 + +query R +SELECT avg(a) FROM (VALUES (0::INT), (0::INT)) AS t(a); +---- +0 + +query IR +SELECT a % 2 AS g, avg(a) +FROM (VALUES (40), (23), (17), (40), (30)) AS t(a) +GROUP BY g +ORDER BY g; +---- +0 36.666666666666664 +1 20 + +query IR +SELECT a % 2 AS g, avg(a) +FROM (VALUES (10::INT), (20::INT), (30::INT), (40::INT), (50::INT)) AS t(a) +GROUP BY g +ORDER BY g; +---- +0 30 + +query IR +SELECT a, avg(a) +FROM (VALUES (0::INT), (0::INT)) AS t(a) +GROUP BY a +ORDER BY a; +---- +0 0 \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/spark/array/array.slt b/datafusion/sqllogictest/test_files/spark/array/array.slt new file mode 100644 index 0000000000000..09821e6d582d2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/array/array.slt @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query ? +SELECT array(1, 2, 3); +---- +[1, 2, 3] + + +query ? +SELECT array('a', 'b'); +---- +[a, b] + + +query ? +SELECT array(); +---- +[] + +query ?? +SELECT array(), array(array()); +---- +[] [[]] + + +query ? +SELECT array(null); +---- +[NULL] + + +query ? +SELECT array(1, NULL, 3); +---- +[1, NULL, 3] + + +query ? +SELECT array['hello', '', null, 'nULl', 'nULlx', 'aa"bb', 'mm\nn', 'uu,vv', 'yy zz']; +---- +[hello, , NULL, nULl, nULlx, aa"bb, mm\nn, uu,vv, yy zz] + +query ? +SELECT array(array(1,2),array(3,4)); +---- +[[1, 2], [3, 4]] + + +query ? +SELECT array(array(1), array(2,3,4)); +---- +[[1], [2, 3, 4]] + +query ? +SELECT array(array(1,2)); +---- +[[1, 2]] diff --git a/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt b/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt new file mode 100644 index 0000000000000..544c39608f33b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT array_repeat('123', 2); +## PySpark 3.5.5 Result: {'array_repeat(123, 2)': ['123', '123'], 'typeof(array_repeat(123, 2))': 'array', 'typeof(123)': 'string', 'typeof(2)': 'int'} +#query +#SELECT array_repeat('123'::string, 2::int); diff --git a/datafusion/sqllogictest/test_files/spark/array/sequence.slt b/datafusion/sqllogictest/test_files/spark/array/sequence.slt new file mode 100644 index 0000000000000..bb4aa06bfd257 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/array/sequence.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sequence(1, 5); +## PySpark 3.5.5 Result: {'sequence(1, 5)': [1, 2, 3, 4, 5], 'typeof(sequence(1, 5))': 'array', 'typeof(1)': 'int', 'typeof(5)': 'int'} +#query +#SELECT sequence(1::int, 5::int); + +## Original Query: SELECT sequence(5, 1); +## PySpark 3.5.5 Result: {'sequence(5, 1)': [5, 4, 3, 2, 1], 'typeof(sequence(5, 1))': 'array', 'typeof(5)': 'int', 'typeof(1)': 'int'} +#query +#SELECT sequence(5::int, 1::int); diff --git a/datafusion/sqllogictest/test_files/spark/bitmap/bitmap_count.slt b/datafusion/sqllogictest/test_files/spark/bitmap/bitmap_count.slt new file mode 100644 index 0000000000000..2789efef7bf36 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitmap/bitmap_count.slt @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query I +SELECT bitmap_count(X'1010'); +---- +2 + +query I +SELECT bitmap_count(X'FFFF'); +---- +16 + +query I +SELECT bitmap_count(X'0'); +---- +0 + +query I +SELECT bitmap_count(a) FROM (VALUES (X'0AB0'), (X'0AB0CD'), (NULL)) AS t(a); +---- +5 +10 +NULL + +# Tests with different binary types +query I +SELECT bitmap_count(arrow_cast(a, 'LargeBinary')) FROM (VALUES (X'0AB0'), (X'0AB0CD'), (NULL)) AS t(a); +---- +5 +10 +NULL + +query I +SELECT bitmap_count(arrow_cast(a, 'BinaryView')) FROM (VALUES (X'0AB0'), (X'0AB0CD'), (NULL)) AS t(a); +---- +5 +10 +NULL + +query I +SELECT bitmap_count(arrow_cast(a, 'FixedSizeBinary(2)')) FROM (VALUES (X'1010'), (X'0AB0'), (X'FFFF'), (NULL)) AS t(a); +---- +2 +5 +16 +NULL diff --git a/datafusion/sqllogictest/test_files/spark/bitwise/bit_count.slt b/datafusion/sqllogictest/test_files/spark/bitwise/bit_count.slt new file mode 100644 index 0000000000000..2a75c7648d409 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitwise/bit_count.slt @@ -0,0 +1,227 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT bit_count(0); +## PySpark 3.5.5 Result: {'bit_count(0)': 0, 'typeof(bit_count(0))': 'int', 'typeof(0)': 'int'} + +# Basic tests with different integer types +query I +SELECT bit_count(0::int); +---- +0 + +query I +SELECT bit_count(1::int); +---- +1 + +query I +SELECT bit_count(7::int); +---- +3 + +query I +SELECT bit_count(15::int); +---- +4 + +query I +SELECT bit_count(255::int); +---- +8 + +query I +SELECT bit_count(1023::int); +---- +10 + +# Tests with negative numbers (two's complement) +query I +SELECT bit_count(-1::int); +---- +32 + +query I +SELECT bit_count(-2::int); +---- +31 + +query I +SELECT bit_count(-3::int); +---- +31 + +# Tests with different integer types +query I +SELECT bit_count(arrow_cast(0, 'Int8')); +---- +0 + +query I +SELECT bit_count(arrow_cast(15, 'Int8')); +---- +4 + +query I +SELECT bit_count(arrow_cast(-1, 'Int8')); +---- +8 + +query I +SELECT bit_count(arrow_cast(0, 'Int16')); +---- +0 + +query I +SELECT bit_count(arrow_cast(255, 'Int16')); +---- +8 + +query I +SELECT bit_count(arrow_cast(-1, 'Int16')); +---- +16 + +query I +SELECT bit_count(arrow_cast(0, 'Int64')); +---- +0 + +query I +SELECT bit_count(arrow_cast(255, 'Int64')); +---- +8 + +query I +SELECT bit_count(arrow_cast(-1, 'Int64')); +---- +64 + +# Tests with unsigned integer types +query I +SELECT bit_count(arrow_cast(0, 'UInt8')); +---- +0 + +query I +SELECT bit_count(arrow_cast(255, 'UInt8')); +---- +8 + +query I +SELECT bit_count(arrow_cast(0, 'UInt16')); +---- +0 + +query I +SELECT bit_count(arrow_cast(65535, 'UInt16')); +---- +16 + +query I +SELECT bit_count(arrow_cast(0, 'UInt32')); +---- +0 + +query I +SELECT bit_count(arrow_cast(4294967295, 'UInt32')); +---- +32 + +query I +SELECT bit_count(arrow_cast(0, 'UInt64')); +---- +0 + +query I +SELECT bit_count(arrow_cast(18446744073709551615, 'UInt64')); +---- +64 + +# Tests with NULL values +query I +SELECT bit_count(arrow_cast(NULL, 'Int32')); +---- +NULL + +query I +SELECT bit_count(arrow_cast(NULL, 'Int8')); +---- +NULL + +query I +SELECT bit_count(arrow_cast(NULL, 'UInt64')); +---- +NULL + +# Tests with edge cases +query I +SELECT bit_count(arrow_cast(0, 'Int32')) as zero_count; +---- +0 + +query I +SELECT bit_count(arrow_cast(1, 'Int32')) as one_count; +---- +1 + +query I +SELECT bit_count(arrow_cast(2, 'Int32')) as two_count; +---- +1 + +query I +SELECT bit_count(arrow_cast(3, 'Int32')) as three_count; +---- +2 + +query I +SELECT bit_count(arrow_cast(4, 'Int32')) as four_count; +---- +1 + +query I +SELECT bit_count(arrow_cast(5, 'Int32')) as five_count; +---- +2 + +# Tests with large numbers +query I +SELECT bit_count(arrow_cast(2147483647, 'Int32')); +---- +31 + +query I +SELECT bit_count(arrow_cast(-2147483648, 'Int32')); +---- +1 + +query I +SELECT bit_count(arrow_cast(9223372036854775807, 'Int64')); +---- +63 + +query I +SELECT bit_count(arrow_cast(-9223372036854775808, 'Int64')); +---- +1 diff --git a/datafusion/sqllogictest/test_files/spark/bitwise/bit_get.slt b/datafusion/sqllogictest/test_files/spark/bitwise/bit_get.slt new file mode 100644 index 0000000000000..6a2b244d58e69 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitwise/bit_get.slt @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT bit_get(11, 0); +## PySpark 3.5.5 Result: {'bit_get(11, 0)': 1, 'typeof(bit_get(11, 0))': 'tinyint', 'typeof(11)': 'int', 'typeof(0)': 'int'} +query I +SELECT bit_get(11, 0); +---- +1 + +## Original Query: SELECT bit_get(11, 2); +## PySpark 3.5.5 Result: {'bit_get(11, 2)': 0, 'typeof(bit_get(11, 2))': 'tinyint', 'typeof(11)': 'int', 'typeof(2)': 'int'} +query I +SELECT bit_get(11, 2); +---- +0 + +## Test additional cases +query I +SELECT bit_get(11, 3); +---- +1 + +query I +SELECT bit_get(255, 7); +---- +1 + +query I +SELECT bit_get(255, 8); +---- +0 + +query I +SELECT bit_get(0, 0); +---- +0 + +## Test edge cases +statement error DataFusion error: Arrow error: Compute error: bit_get: position -1 is out of bounds. Expected pos < 64 and pos >= 0 +SELECT bit_get(11, -1); + +statement error DataFusion error: Arrow error: Compute error: bit_get: position 64 is out of bounds. Expected pos < 64 and pos >= 0 +SELECT bit_get(11, 64); + +## Test null inputs +query I +SELECT bit_get(NULL, 0); +---- +NULL + +query I +SELECT bit_get(11, NULL); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/bitwise/bitwise_not.slt b/datafusion/sqllogictest/test_files/spark/bitwise/bitwise_not.slt new file mode 100644 index 0000000000000..5f51cd68ef94f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitwise/bitwise_not.slt @@ -0,0 +1,201 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT bitwise_not(0); +## PySpark 3.5.5 Result: {'bitwise_not(0)': -1, 'typeof(bitwise_not(0))': 'int', 'typeof(0)': 'int'} + +# Basic tests with different integer types +query I +SELECT bitwise_not(0::int); +---- +-1 + +query I +SELECT bitwise_not(1::int); +---- +-2 + +query I +SELECT bitwise_not(7::int); +---- +-8 + +query I +SELECT bitwise_not(15::int); +---- +-16 + +query I +SELECT bitwise_not(255::int); +---- +-256 + +query I +SELECT bitwise_not(1023::int); +---- +-1024 + +# Tests with negative numbers (two's complement) +query I +SELECT bitwise_not(-1::int); +---- +0 + +query I +SELECT bitwise_not(-2::int); +---- +1 + +query I +SELECT bitwise_not(-3::int); +---- +2 + +# Tests with different integer types +query I +SELECT bitwise_not(arrow_cast(0, 'Int8')); +---- +-1 + +query I +SELECT bitwise_not(arrow_cast(15, 'Int8')); +---- +-16 + +query I +SELECT bitwise_not(arrow_cast(-1, 'Int8')); +---- +0 + +query I +SELECT bitwise_not(arrow_cast(0, 'Int16')); +---- +-1 + +query I +SELECT bitwise_not(arrow_cast(255, 'Int16')); +---- +-256 + +query I +SELECT bitwise_not(arrow_cast(-1, 'Int16')); +---- +0 + +query I +SELECT bitwise_not(arrow_cast(0, 'Int32')); +---- +-1 + +query I +SELECT bitwise_not(arrow_cast(255, 'Int32')); +---- +-256 + +query I +SELECT bitwise_not(arrow_cast(-1, 'Int32')); +---- +0 + +query I +SELECT bitwise_not(arrow_cast(0, 'Int64')); +---- +-1 + +query I +SELECT bitwise_not(arrow_cast(255, 'Int64')); +---- +-256 + +query I +SELECT bitwise_not(arrow_cast(-1, 'Int64')); +---- +0 + +# Tests with NULL values +query I +SELECT bitwise_not(arrow_cast(NULL, 'Int32')); +---- +NULL + +query I +SELECT bitwise_not(arrow_cast(NULL, 'Int8')); +---- +NULL + +query I +SELECT bitwise_not(arrow_cast(NULL, 'Int64')); +---- +NULL + +# Tests with edge cases +query I +SELECT bitwise_not(arrow_cast(0, 'Int32')) as zero_not; +---- +-1 + +query I +SELECT bitwise_not(arrow_cast(1, 'Int32')) as one_not; +---- +-2 + +query I +SELECT bitwise_not(arrow_cast(2, 'Int32')) as two_not; +---- +-3 + +query I +SELECT bitwise_not(arrow_cast(3, 'Int32')) as three_not; +---- +-4 + +query I +SELECT bitwise_not(arrow_cast(4, 'Int32')) as four_not; +---- +-5 + +query I +SELECT bitwise_not(arrow_cast(5, 'Int32')) as five_not; +---- +-6 + +# Tests with large numbers +query I +SELECT bitwise_not(arrow_cast(2147483647, 'Int32')); +---- +-2147483648 + +query I +SELECT bitwise_not(arrow_cast(-2147483648, 'Int32')); +---- +2147483647 + +query I +SELECT bitwise_not(arrow_cast(9223372036854775807, 'Int64')); +---- +-9223372036854775808 + +query I +SELECT bitwise_not(arrow_cast(-9223372036854775808, 'Int64')); +---- +9223372036854775807 diff --git a/datafusion/sqllogictest/test_files/spark/bitwise/getbit.slt b/datafusion/sqllogictest/test_files/spark/bitwise/getbit.slt new file mode 100644 index 0000000000000..7cfdfe8257277 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitwise/getbit.slt @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT getbit(11, 0); +## PySpark 3.5.5 Result: {'getbit(11, 0)': 1, 'typeof(getbit(11, 0))': 'tinyint', 'typeof(11)': 'int', 'typeof(0)': 'int'} +query I +SELECT getbit(11, 0); +---- +1 + +## Original Query: SELECT getbit(11, 2); +## PySpark 3.5.5 Result: {'getbit(11, 2)': 0, 'typeof(getbit(11, 2))': 'tinyint', 'typeof(11)': 'int', 'typeof(2)': 'int'} +query I +SELECT getbit(11, 2); +---- +0 + +## Test additional cases +query I +SELECT getbit(11, 3); +---- +1 + +query I +SELECT getbit(255, 7); +---- +1 + +query I +SELECT getbit(255, 8); +---- +0 + +query I +SELECT getbit(0, 0); +---- +0 + +## Test edge cases +statement error DataFusion error: Arrow error: Compute error: bit_get: position -1 is out of bounds. Expected pos < 64 and pos >= 0 +SELECT getbit(11, -1); + +statement error DataFusion error: Arrow error: Compute error: bit_get: position 64 is out of bounds. Expected pos < 64 and pos >= 0 +SELECT getbit(11, 64); + +## Test null inputs +query I +SELECT getbit(NULL, 0); +---- +NULL + +query I +SELECT getbit(11, NULL); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/bitwise/shiftright.slt b/datafusion/sqllogictest/test_files/spark/bitwise/shiftright.slt new file mode 100644 index 0000000000000..3587bcc7ca52b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitwise/shiftright.slt @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT shiftright(4, 1); +## PySpark 3.5.5 Result: {'shiftright(4, 1)': 2, 'typeof(shiftright(4, 1))': 'int', 'typeof(4)': 'int', 'typeof(1)': 'int'} + +# Basic shiftright tests +query I +SELECT shiftright(4::int, 1::int); +---- +2 + +query I +SELECT shiftright(8::int, 2::int); +---- +2 + +query I +SELECT shiftright(16::int, 3::int); +---- +2 + +# Different data types +query I +SELECT shiftright(4::bigint, 1::int); +---- +2 + +query I +SELECT shiftright(8::bigint, 2::int); +---- +2 + +query I +SELECT shiftright(4::int, 1::bigint); +---- +2 + +# Large shifts (should handle modulo correctly) +query I +SELECT shiftright(1::int, 32::int); +---- +1 + +query I +SELECT shiftright(2::int, 33::int); +---- +1 + +query I +SELECT shiftright(3::int, 64::int); +---- +3 + +# Negative shifts +query I +SELECT shiftright(4::int, -1::int); +---- +0 + +query I +SELECT shiftright(8::int, -2::int); +---- +0 + +query I +SELECT shiftright(16::int, -3::int); +---- +0 + +# Zero shifts +query I +SELECT shiftright(5::int, 0::int); +---- +5 + +query I +SELECT shiftright(0::int, 5::int); +---- +0 + +# Edge cases - signed right shift preserves sign +query I +SELECT shiftright(-4::int, 1::int); +---- +-2 + +query I +SELECT shiftright(-8::int, 2::int); +---- +-2 + +query I +SELECT shiftright(-16::int, 3::int); +---- +-2 + +query I +SELECT shiftright(2147483647::int, 1::int); +---- +1073741823 + +# Null handling +query I +SELECT shiftright(NULL::int, 1::int); +---- +NULL + +query I +SELECT shiftright(1::int, NULL::int); +---- +NULL + +query I +SELECT shiftright(NULL::int, NULL::int); +---- +NULL + +query I +select shiftright(3::int,-31); +---- +1 + +query I +select shiftright(3::int,-32); +---- +3 diff --git a/datafusion/sqllogictest/test_files/spark/bitwise/shiftrightunsigned.slt b/datafusion/sqllogictest/test_files/spark/bitwise/shiftrightunsigned.slt new file mode 100644 index 0000000000000..b0d4cfaec7021 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitwise/shiftrightunsigned.slt @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT shiftrightunsigned(4, 1); +## PySpark 3.5.5 Result: {'shiftrightunsigned(4, 1)': 2, 'typeof(shiftrightunsigned(4, 1))': 'int', 'typeof(4)': 'int', 'typeof(1)': 'int'} + +# Basic shiftrightunsigned tests +query I +SELECT shiftrightunsigned(4::int, 1::int); +---- +2 + +query I +SELECT shiftrightunsigned(8::int, 2::int); +---- +2 + +query I +SELECT shiftrightunsigned(16::int, 3::int); +---- +2 + +# Different data types +query I +SELECT shiftrightunsigned(4::bigint, 1::int); +---- +2 + +query I +SELECT shiftrightunsigned(8::bigint, 2::int); +---- +2 + +query I +SELECT shiftrightunsigned(4::int, 1::bigint); +---- +2 + +# Large shifts (should handle modulo correctly) +query I +SELECT shiftrightunsigned(1::int, 32::int); +---- +1 + +query I +SELECT shiftrightunsigned(2::int, 33::int); +---- +1 + +query I +SELECT shiftrightunsigned(3::int, 64::int); +---- +3 + +# Negative shifts +query I +SELECT shiftrightunsigned(4::int, -1::int); +---- +0 + +query I +SELECT shiftrightunsigned(8::int, -2::int); +---- +0 + +query I +SELECT shiftrightunsigned(16::int, -3::int); +---- +0 + +# Zero shifts +query I +SELECT shiftrightunsigned(5::int, 0::int); +---- +5 + +query I +SELECT shiftrightunsigned(0::int, 5::int); +---- +0 + +# Edge cases - unsigned right shift treats negative values as large positive +query I +SELECT shiftrightunsigned(-4::int, 1::int); +---- +2147483646 + +query I +SELECT shiftrightunsigned(-8::int, 2::int); +---- +1073741822 + +query I +SELECT shiftrightunsigned(-16::int, 3::int); +---- +536870910 + +query I +SELECT shiftrightunsigned(2147483647::int, 1::int); +---- +1073741823 + + +# Null handling +query I +SELECT shiftrightunsigned(NULL::int, 1::int); +---- +NULL + +query I +SELECT shiftrightunsigned(1::int, NULL::int); +---- +NULL + +query I +SELECT shiftrightunsigned(NULL::int, NULL::int); +---- +NULL + +query I +select shiftrightunsigned(3::int,-31); +---- +1 + +query I +select shiftrightunsigned(3::int,-32); +---- +3 diff --git a/datafusion/sqllogictest/test_files/spark/collection/concat.slt b/datafusion/sqllogictest/test_files/spark/collection/concat.slt new file mode 100644 index 0000000000000..911975d9c72d9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/collection/concat.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT concat('Spark', 'SQL'); +## PySpark 3.5.5 Result: {'concat(Spark, SQL)': 'SparkSQL', 'typeof(concat(Spark, SQL))': 'string', 'typeof(Spark)': 'string', 'typeof(SQL)': 'string'} +#query +#SELECT concat('Spark'::string, 'SQL'::string); diff --git a/datafusion/sqllogictest/test_files/spark/collection/reverse.slt b/datafusion/sqllogictest/test_files/spark/collection/reverse.slt new file mode 100644 index 0000000000000..f49c7c2a8c2b0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/collection/reverse.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT reverse('Spark SQL'); +## PySpark 3.5.5 Result: {'reverse(Spark SQL)': 'LQS krapS', 'typeof(reverse(Spark SQL))': 'string', 'typeof(Spark SQL)': 'string'} +#query +#SELECT reverse('Spark SQL'::string); diff --git a/datafusion/sqllogictest/test_files/spark/conditional/coalesce.slt b/datafusion/sqllogictest/test_files/spark/conditional/coalesce.slt new file mode 100644 index 0000000000000..3af8110ad6f38 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/conditional/coalesce.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT coalesce(NULL, 1, NULL); +## PySpark 3.5.5 Result: {'coalesce(NULL, 1, NULL)': 1, 'typeof(coalesce(NULL, 1, NULL))': 'int', 'typeof(NULL)': 'void', 'typeof(1)': 'int'} +#query +#SELECT coalesce(NULL::void, 1::int); diff --git a/datafusion/sqllogictest/test_files/spark/conditional/if.slt b/datafusion/sqllogictest/test_files/spark/conditional/if.slt new file mode 100644 index 0000000000000..b4380e065b987 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/conditional/if.slt @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Basic IF function tests + +# Test basic true condition +query T +SELECT if(true, 'yes', 'no'); +---- +yes + +# Test basic false condition +query T +SELECT if(false, 'yes', 'no'); +---- +no + +# Test with comparison operators +query T +SELECT if(1 < 2, 'a', 'b'); +---- +a + +query T +SELECT if(1 > 2, 'a', 'b'); +---- +b + + +## Numeric type tests + +# Test with integers +query I +SELECT if(true, 10, 20); +---- +10 + +query I +SELECT if(false, 10, 20); +---- +20 + +# Test with different integer types +query I +SELECT if(true, 100, 200); +---- +100 + +## Float type tests + +# Test with floating point numbers +query R +SELECT if(true, 1.5, 2.5); +---- +1.5 + +query R +SELECT if(false, 1.5, 2.5); +---- +2.5 + +## String type tests + +# Test with different string values +query T +SELECT if(true, 'hello', 'world'); +---- +hello + +query T +SELECT if(false, 'hello', 'world'); +---- +world + +## NULL handling tests + +# Test with NULL condition +query T +SELECT if(NULL, 'yes', 'no'); +---- +no + +query T +SELECT if(NOT NULL, 'yes', 'no'); +---- +no + +# Test with NULL true value +query T +SELECT if(true, NULL, 'no'); +---- +NULL + +# Test with NULL false value +query T +SELECT if(false, 'yes', NULL); +---- +NULL + +# Test with all NULL +query ? +SELECT if(true, NULL, NULL); +---- +NULL + +## Type coercion tests + +# Test integer to float coercion +query R +SELECT if(true, 10, 20.5); +---- +10 + +query R +SELECT if(false, 10, 20.5); +---- +20.5 + +# Test float to integer coercion +query R +SELECT if(true, 10.5, 20); +---- +10.5 + +query R +SELECT if(false, 10.5, 20); +---- +20 + +statement error Int64 is not a boolean or null +SELECT if(1, 10.5, 20); + + +statement error Utf8 is not a boolean or null +SELECT if('x', 10.5, 20); + +query II +SELECT v, IF(v < 0, 10/0, 1) FROM (VALUES (1), (2)) t(v) +---- +1 1 +2 1 + +query I +SELECT IF(true, 1 / 1, 1 / 0); +---- +1 diff --git a/datafusion/sqllogictest/test_files/spark/conditional/nullif.slt b/datafusion/sqllogictest/test_files/spark/conditional/nullif.slt new file mode 100644 index 0000000000000..1a4c80e3baaeb --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/conditional/nullif.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT nullif(2, 2); +## PySpark 3.5.5 Result: {'nullif(2, 2)': None, 'typeof(nullif(2, 2))': 'int', 'typeof(2)': 'int'} +#query +#SELECT nullif(2::int); diff --git a/datafusion/sqllogictest/test_files/spark/conditional/nvl2.slt b/datafusion/sqllogictest/test_files/spark/conditional/nvl2.slt new file mode 100644 index 0000000000000..c5ea2f8f1f360 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/conditional/nvl2.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT nvl2(NULL, 2, 1); +## PySpark 3.5.5 Result: {'nvl2(NULL, 2, 1)': 1, 'typeof(nvl2(NULL, 2, 1))': 'int', 'typeof(NULL)': 'void', 'typeof(2)': 'int', 'typeof(1)': 'int'} +#query +#SELECT nvl2(NULL::void, 2::int, 1::int); diff --git a/datafusion/sqllogictest/test_files/spark/csv/schema_of_csv.slt b/datafusion/sqllogictest/test_files/spark/csv/schema_of_csv.slt new file mode 100644 index 0000000000000..eaa31c9d5c9cb --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/csv/schema_of_csv.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT schema_of_csv('1,abc'); +## PySpark 3.5.5 Result: {'schema_of_csv(1,abc)': 'STRUCT<_c0: INT, _c1: STRING>', 'typeof(schema_of_csv(1,abc))': 'string', 'typeof(1,abc)': 'string'} +#query +#SELECT schema_of_csv('1,abc'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt b/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt new file mode 100644 index 0000000000000..cae9b21dd4766 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT add_months('2016-08-31', 1); +## PySpark 3.5.5 Result: {'add_months(2016-08-31, 1)': datetime.date(2016, 9, 30), 'typeof(add_months(2016-08-31, 1))': 'date', 'typeof(2016-08-31)': 'string', 'typeof(1)': 'int'} +#query +#SELECT add_months('2016-08-31'::string, 1::int); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/convert_timezone.slt b/datafusion/sqllogictest/test_files/spark/datetime/convert_timezone.slt new file mode 100644 index 0000000000000..54c9e616cf05e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/convert_timezone.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT convert_timezone('Europe/Brussels', 'America/Los_Angeles', timestamp_ntz'2021-12-06 00:00:00'); +## PySpark 3.5.5 Result: {"convert_timezone(Europe/Brussels, America/Los_Angeles, TIMESTAMP_NTZ '2021-12-06 00:00:00')": datetime.datetime(2021, 12, 5, 15, 0), "typeof(convert_timezone(Europe/Brussels, America/Los_Angeles, TIMESTAMP_NTZ '2021-12-06 00:00:00'))": 'timestamp_ntz', 'typeof(Europe/Brussels)': 'string', 'typeof(America/Los_Angeles)': 'string', "typeof(TIMESTAMP_NTZ '2021-12-06 00:00:00')": 'timestamp_ntz'} +#query +#SELECT convert_timezone('Europe/Brussels'::string, 'America/Los_Angeles'::string, TIMESTAMP_NTZ '2021-12-06 00:00:00'::timestamp_ntz); + +## Original Query: SELECT convert_timezone('Europe/Brussels', timestamp_ntz'2021-12-05 15:00:00'); +## PySpark 3.5.5 Result: {"convert_timezone(current_timezone(), Europe/Brussels, TIMESTAMP_NTZ '2021-12-05 15:00:00')": datetime.datetime(2021, 12, 6, 0, 0), "typeof(convert_timezone(current_timezone(), Europe/Brussels, TIMESTAMP_NTZ '2021-12-05 15:00:00'))": 'timestamp_ntz', 'typeof(Europe/Brussels)': 'string', "typeof(TIMESTAMP_NTZ '2021-12-05 15:00:00')": 'timestamp_ntz'} +#query +#SELECT convert_timezone('Europe/Brussels'::string, TIMESTAMP_NTZ '2021-12-05 15:00:00'::timestamp_ntz); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/curdate.slt b/datafusion/sqllogictest/test_files/spark/datetime/curdate.slt new file mode 100644 index 0000000000000..21ec4c0305aa0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/curdate.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT curdate(); +## PySpark 3.5.5 Result: {'current_date()': datetime.date(2025, 6, 14), 'typeof(current_date())': 'date'} +#query +#SELECT curdate(); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/current_date.slt b/datafusion/sqllogictest/test_files/spark/datetime/current_date.slt new file mode 100644 index 0000000000000..cd187901777f4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/current_date.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT current_date(); +## PySpark 3.5.5 Result: {'current_date()': datetime.date(2025, 6, 14), 'typeof(current_date())': 'date'} +#query +#SELECT current_date(); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/current_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/current_timestamp.slt new file mode 100644 index 0000000000000..f3e4f5856aca6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/current_timestamp.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT current_timestamp(); +## PySpark 3.5.5 Result: {'current_timestamp()': datetime.datetime(2025, 6, 14, 23, 57, 38, 948981), 'typeof(current_timestamp())': 'timestamp'} +#query +#SELECT current_timestamp(); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/current_timezone.slt b/datafusion/sqllogictest/test_files/spark/datetime/current_timezone.slt new file mode 100644 index 0000000000000..db3d8d40742d7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/current_timezone.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT current_timezone(); +## PySpark 3.5.5 Result: {'current_timezone()': 'America/Los_Angeles', 'typeof(current_timezone())': 'string'} +#query +#SELECT current_timezone(); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt new file mode 100644 index 0000000000000..2e9851ca1e595 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT date_add('2016-07-30', 1); +## PySpark 3.5.5 Result: {'date_add(2016-07-30, 1)': datetime.date(2016, 7, 31), 'typeof(date_add(2016-07-30, 1))': 'date', 'typeof(2016-07-30)': 'string', 'typeof(1)': 'int'} + +# Basic date_add tests +query D +SELECT date_add('2016-07-30'::date, 1::int); +---- +2016-07-31 + +query D +SELECT date_add('2016-07-30'::date, arrow_cast(1, 'Int8')); +---- +2016-07-31 + +query D +SELECT date_add('2016-07-30'::date, arrow_cast(1, 'Int8')); +---- +2016-07-31 + +query D +SELECT date_sub('2016-07-30'::date, 0::int); +---- +2016-07-30 + +query error DataFusion error: Arrow error: Arithmetic overflow: date_add +SELECT date_add('2016-07-30'::date, 2147483647::int); + +query error DataFusion error: Arrow error: Arithmetic overflow: date_sub +SELECT date_sub('1969-01-01'::date, 2147483647::int); + +query D +SELECT date_add('2016-07-30'::date, 100000::int); +---- +2290-05-15 + +query D +SELECT date_sub('2016-07-30'::date, 100000::int); +---- +1742-10-15 + +# Test with negative day values (should subtract days) +query D +SELECT date_add('2016-07-30'::date, -5::int); +---- +2016-07-25 + +# Test with NULL values +query D +SELECT date_add(NULL::date, 1::int); +---- +NULL + +query D +SELECT date_add('2016-07-30'::date, NULL::int); +---- +NULL + +query D +SELECT date_add(NULL::date, NULL::int); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_diff.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_diff.slt new file mode 100644 index 0000000000000..c5871ab41e183 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_diff.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT date_diff('2009-07-30', '2009-07-31'); +## PySpark 3.5.5 Result: {'date_diff(2009-07-30, 2009-07-31)': -1, 'typeof(date_diff(2009-07-30, 2009-07-31))': 'int', 'typeof(2009-07-30)': 'string', 'typeof(2009-07-31)': 'string'} +#query +#SELECT date_diff('2009-07-30'::string, '2009-07-31'::string); + +## Original Query: SELECT date_diff('2009-07-31', '2009-07-30'); +## PySpark 3.5.5 Result: {'date_diff(2009-07-31, 2009-07-30)': 1, 'typeof(date_diff(2009-07-31, 2009-07-30))': 'int', 'typeof(2009-07-31)': 'string', 'typeof(2009-07-30)': 'string'} +#query +#SELECT date_diff('2009-07-31'::string, '2009-07-30'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_format.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_format.slt new file mode 100644 index 0000000000000..1242518dee3f5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_format.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT date_format('2016-04-08', 'y'); +## PySpark 3.5.5 Result: {'date_format(2016-04-08, y)': '2016', 'typeof(date_format(2016-04-08, y))': 'string', 'typeof(2016-04-08)': 'string', 'typeof(y)': 'string'} +#query +#SELECT date_format('2016-04-08'::string, 'y'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_part.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_part.slt new file mode 100644 index 0000000000000..cd3271cdc7df8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_part.slt @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT date_part('MINUTE', INTERVAL '123 23:55:59.002001' DAY TO SECOND); +## PySpark 3.5.5 Result: {"date_part(MINUTE, INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 55, "typeof(date_part(MINUTE, INTERVAL '123 23:55:59.002001' DAY TO SECOND))": 'tinyint', 'typeof(MINUTE)': 'string', "typeof(INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 'interval day to second'} +#query +#SELECT date_part('MINUTE'::string, INTERVAL '123 23:55:59.002001' DAY TO SECOND::interval day to second); + +## Original Query: SELECT date_part('MONTH', INTERVAL '2021-11' YEAR TO MONTH); +## PySpark 3.5.5 Result: {"date_part(MONTH, INTERVAL '2021-11' YEAR TO MONTH)": 11, "typeof(date_part(MONTH, INTERVAL '2021-11' YEAR TO MONTH))": 'tinyint', 'typeof(MONTH)': 'string', "typeof(INTERVAL '2021-11' YEAR TO MONTH)": 'interval year to month'} +#query +#SELECT date_part('MONTH'::string, INTERVAL '2021-11' YEAR TO MONTH::interval year to month); + +## Original Query: SELECT date_part('SECONDS', timestamp'2019-10-01 00:00:01.000001'); +## PySpark 3.5.5 Result: {"date_part(SECONDS, TIMESTAMP '2019-10-01 00:00:01.000001')": Decimal('1.000001'), "typeof(date_part(SECONDS, TIMESTAMP '2019-10-01 00:00:01.000001'))": 'decimal(8,6)', 'typeof(SECONDS)': 'string', "typeof(TIMESTAMP '2019-10-01 00:00:01.000001')": 'timestamp'} +#query +#SELECT date_part('SECONDS'::string, TIMESTAMP '2019-10-01 00:00:01.000001'::timestamp); + +## Original Query: SELECT date_part('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456'); +## PySpark 3.5.5 Result: {"date_part(YEAR, TIMESTAMP '2019-08-12 01:00:00.123456')": 2019, "typeof(date_part(YEAR, TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(YEAR)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} +#query +#SELECT date_part('YEAR'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); + +## Original Query: SELECT date_part('days', interval 5 days 3 hours 7 minutes); +## PySpark 3.5.5 Result: {"date_part(days, INTERVAL '5 03:07' DAY TO MINUTE)": 5, "typeof(date_part(days, INTERVAL '5 03:07' DAY TO MINUTE))": 'int', 'typeof(days)': 'string', "typeof(INTERVAL '5 03:07' DAY TO MINUTE)": 'interval day to minute'} +#query +#SELECT date_part('days'::string, INTERVAL '5 03:07' DAY TO MINUTE::interval day to minute); + +## Original Query: SELECT date_part('doy', DATE'2019-08-12'); +## PySpark 3.5.5 Result: {"date_part(doy, DATE '2019-08-12')": 224, "typeof(date_part(doy, DATE '2019-08-12'))": 'int', 'typeof(doy)': 'string', "typeof(DATE '2019-08-12')": 'date'} +#query +#SELECT date_part('doy'::string, DATE '2019-08-12'::date); + +## Original Query: SELECT date_part('seconds', interval 5 hours 30 seconds 1 milliseconds 1 microseconds); +## PySpark 3.5.5 Result: {"date_part(seconds, INTERVAL '05:00:30.001001' HOUR TO SECOND)": Decimal('30.001001'), "typeof(date_part(seconds, INTERVAL '05:00:30.001001' HOUR TO SECOND))": 'decimal(8,6)', 'typeof(seconds)': 'string', "typeof(INTERVAL '05:00:30.001001' HOUR TO SECOND)": 'interval hour to second'} +#query +#SELECT date_part('seconds'::string, INTERVAL '05:00:30.001001' HOUR TO SECOND::interval hour to second); + +## Original Query: SELECT date_part('week', timestamp'2019-08-12 01:00:00.123456'); +## PySpark 3.5.5 Result: {"date_part(week, TIMESTAMP '2019-08-12 01:00:00.123456')": 33, "typeof(date_part(week, TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(week)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} +#query +#SELECT date_part('week'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_sub.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_sub.slt new file mode 100644 index 0000000000000..cb5e77c3b4f1e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_sub.slt @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT date_sub('2016-07-30', 1); +## PySpark 3.5.5 Result: {'date_sub(2016-07-30, 1)': datetime.date(2016, 7, 29), 'typeof(date_sub(2016-07-30, 1))': 'date', 'typeof(2016-07-30)': 'string', 'typeof(1)': 'int'} + +# Basic date_sub tests +query D +SELECT date_sub('2016-07-30'::date, 1::int); +---- +2016-07-29 + +query D +SELECT date_sub('2016-07-30'::date, arrow_cast(1, 'Int8')); +---- +2016-07-29 + +query D +SELECT date_sub('2016-07-30'::date, arrow_cast(1, 'Int16')); +---- +2016-07-29 + +query D +SELECT date_sub('2016-07-30'::date, 0::int); +---- +2016-07-30 + +# Test with negative day values (should add days) +query D +SELECT date_sub('2016-07-30'::date, -1::int); +---- +2016-07-31 + +query D +SELECT date_sub('2016-07-30'::date, -5::int); +---- +2016-08-04 + +# Test with NULL values +query D +SELECT date_sub(NULL::date, 1::int); +---- +NULL + +query D +SELECT date_sub('2016-07-30'::date, NULL::int); +---- +NULL + +query D +SELECT date_sub(NULL::date, NULL::int); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_trunc.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_trunc.slt new file mode 100644 index 0000000000000..8a15254e6795e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_trunc.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT date_trunc('DD', '2015-03-05T09:32:05.359'); +## PySpark 3.5.5 Result: {'date_trunc(DD, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 3, 5, 0, 0), 'typeof(date_trunc(DD, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(DD)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} +#query +#SELECT date_trunc('DD'::string, '2015-03-05T09:32:05.359'::string); + +## Original Query: SELECT date_trunc('HOUR', '2015-03-05T09:32:05.359'); +## PySpark 3.5.5 Result: {'date_trunc(HOUR, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 3, 5, 9, 0), 'typeof(date_trunc(HOUR, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(HOUR)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} +#query +#SELECT date_trunc('HOUR'::string, '2015-03-05T09:32:05.359'::string); + +## Original Query: SELECT date_trunc('MILLISECOND', '2015-03-05T09:32:05.123456'); +## PySpark 3.5.5 Result: {'date_trunc(MILLISECOND, 2015-03-05T09:32:05.123456)': datetime.datetime(2015, 3, 5, 9, 32, 5, 123000), 'typeof(date_trunc(MILLISECOND, 2015-03-05T09:32:05.123456))': 'timestamp', 'typeof(MILLISECOND)': 'string', 'typeof(2015-03-05T09:32:05.123456)': 'string'} +#query +#SELECT date_trunc('MILLISECOND'::string, '2015-03-05T09:32:05.123456'::string); + +## Original Query: SELECT date_trunc('MM', '2015-03-05T09:32:05.359'); +## PySpark 3.5.5 Result: {'date_trunc(MM, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 3, 1, 0, 0), 'typeof(date_trunc(MM, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(MM)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} +#query +#SELECT date_trunc('MM'::string, '2015-03-05T09:32:05.359'::string); + +## Original Query: SELECT date_trunc('YEAR', '2015-03-05T09:32:05.359'); +## PySpark 3.5.5 Result: {'date_trunc(YEAR, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 1, 1, 0, 0), 'typeof(date_trunc(YEAR, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(YEAR)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} +#query +#SELECT date_trunc('YEAR'::string, '2015-03-05T09:32:05.359'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/dateadd.slt b/datafusion/sqllogictest/test_files/spark/datetime/dateadd.slt new file mode 100644 index 0000000000000..c369989616f6c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/dateadd.slt @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT dateadd('2016-07-30', 1); +## PySpark 3.5.5 Result: {'date_add(2016-07-30, 1)': datetime.date(2016, 7, 31), 'typeof(date_add(2016-07-30, 1))': 'date', 'typeof(2016-07-30)': 'string', 'typeof(1)': 'int'} + +# Basic dateadd tests (alias for date_add) +query D +SELECT dateadd('2016-07-30'::date, 1::int); +---- +2016-07-31 + +query D +SELECT dateadd('2016-07-30'::date, 0::int); +---- +2016-07-30 + +# Test with negative day values (should subtract days) + +query D +SELECT dateadd('2016-07-30'::date, -5::int); +---- +2016-07-25 + +# Test with NULL values +query D +SELECT dateadd(NULL::date, 1::int); +---- +NULL + +query D +SELECT dateadd('2016-07-30'::date, NULL::int); +---- +NULL + +query D +SELECT dateadd(NULL::date, NULL::int); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/datetime/datediff.slt b/datafusion/sqllogictest/test_files/spark/datetime/datediff.slt new file mode 100644 index 0000000000000..223e2c313ae86 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/datediff.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT datediff('2009-07-30', '2009-07-31'); +## PySpark 3.5.5 Result: {'datediff(2009-07-30, 2009-07-31)': -1, 'typeof(datediff(2009-07-30, 2009-07-31))': 'int', 'typeof(2009-07-30)': 'string', 'typeof(2009-07-31)': 'string'} +#query +#SELECT datediff('2009-07-30'::string, '2009-07-31'::string); + +## Original Query: SELECT datediff('2009-07-31', '2009-07-30'); +## PySpark 3.5.5 Result: {'datediff(2009-07-31, 2009-07-30)': 1, 'typeof(datediff(2009-07-31, 2009-07-30))': 'int', 'typeof(2009-07-31)': 'string', 'typeof(2009-07-30)': 'string'} +#query +#SELECT datediff('2009-07-31'::string, '2009-07-30'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/datepart.slt b/datafusion/sqllogictest/test_files/spark/datetime/datepart.slt new file mode 100644 index 0000000000000..b2dd0089c2823 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/datepart.slt @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT datepart('MINUTE', INTERVAL '123 23:55:59.002001' DAY TO SECOND); +## PySpark 3.5.5 Result: {"datepart(MINUTE FROM INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 55, "typeof(datepart(MINUTE FROM INTERVAL '123 23:55:59.002001' DAY TO SECOND))": 'tinyint', 'typeof(MINUTE)': 'string', "typeof(INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 'interval day to second'} +#query +#SELECT datepart('MINUTE'::string, INTERVAL '123 23:55:59.002001' DAY TO SECOND::interval day to second); + +## Original Query: SELECT datepart('MONTH', INTERVAL '2021-11' YEAR TO MONTH); +## PySpark 3.5.5 Result: {"datepart(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH)": 11, "typeof(datepart(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH))": 'tinyint', 'typeof(MONTH)': 'string', "typeof(INTERVAL '2021-11' YEAR TO MONTH)": 'interval year to month'} +#query +#SELECT datepart('MONTH'::string, INTERVAL '2021-11' YEAR TO MONTH::interval year to month); + +## Original Query: SELECT datepart('SECONDS', timestamp'2019-10-01 00:00:01.000001'); +## PySpark 3.5.5 Result: {"datepart(SECONDS FROM TIMESTAMP '2019-10-01 00:00:01.000001')": Decimal('1.000001'), "typeof(datepart(SECONDS FROM TIMESTAMP '2019-10-01 00:00:01.000001'))": 'decimal(8,6)', 'typeof(SECONDS)': 'string', "typeof(TIMESTAMP '2019-10-01 00:00:01.000001')": 'timestamp'} +#query +#SELECT datepart('SECONDS'::string, TIMESTAMP '2019-10-01 00:00:01.000001'::timestamp); + +## Original Query: SELECT datepart('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456'); +## PySpark 3.5.5 Result: {"datepart(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456')": 2019, "typeof(datepart(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(YEAR)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} +#query +#SELECT datepart('YEAR'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); + +## Original Query: SELECT datepart('days', interval 5 days 3 hours 7 minutes); +## PySpark 3.5.5 Result: {"datepart(days FROM INTERVAL '5 03:07' DAY TO MINUTE)": 5, "typeof(datepart(days FROM INTERVAL '5 03:07' DAY TO MINUTE))": 'int', 'typeof(days)': 'string', "typeof(INTERVAL '5 03:07' DAY TO MINUTE)": 'interval day to minute'} +#query +#SELECT datepart('days'::string, INTERVAL '5 03:07' DAY TO MINUTE::interval day to minute); + +## Original Query: SELECT datepart('doy', DATE'2019-08-12'); +## PySpark 3.5.5 Result: {"datepart(doy FROM DATE '2019-08-12')": 224, "typeof(datepart(doy FROM DATE '2019-08-12'))": 'int', 'typeof(doy)': 'string', "typeof(DATE '2019-08-12')": 'date'} +#query +#SELECT datepart('doy'::string, DATE '2019-08-12'::date); + +## Original Query: SELECT datepart('seconds', interval 5 hours 30 seconds 1 milliseconds 1 microseconds); +## PySpark 3.5.5 Result: {"datepart(seconds FROM INTERVAL '05:00:30.001001' HOUR TO SECOND)": Decimal('30.001001'), "typeof(datepart(seconds FROM INTERVAL '05:00:30.001001' HOUR TO SECOND))": 'decimal(8,6)', 'typeof(seconds)': 'string', "typeof(INTERVAL '05:00:30.001001' HOUR TO SECOND)": 'interval hour to second'} +#query +#SELECT datepart('seconds'::string, INTERVAL '05:00:30.001001' HOUR TO SECOND::interval hour to second); + +## Original Query: SELECT datepart('week', timestamp'2019-08-12 01:00:00.123456'); +## PySpark 3.5.5 Result: {"datepart(week FROM TIMESTAMP '2019-08-12 01:00:00.123456')": 33, "typeof(datepart(week FROM TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(week)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} +#query +#SELECT datepart('week'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/day.slt b/datafusion/sqllogictest/test_files/spark/datetime/day.slt new file mode 100644 index 0000000000000..35b73d67f5fd1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/day.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT day('2009-07-30'); +## PySpark 3.5.5 Result: {'day(2009-07-30)': 30, 'typeof(day(2009-07-30))': 'int', 'typeof(2009-07-30)': 'string'} +#query +#SELECT day('2009-07-30'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/dayofmonth.slt b/datafusion/sqllogictest/test_files/spark/datetime/dayofmonth.slt new file mode 100644 index 0000000000000..4e4e9ff4a23b3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/dayofmonth.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT dayofmonth('2009-07-30'); +## PySpark 3.5.5 Result: {'dayofmonth(2009-07-30)': 30, 'typeof(dayofmonth(2009-07-30))': 'int', 'typeof(2009-07-30)': 'string'} +#query +#SELECT dayofmonth('2009-07-30'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/dayofweek.slt b/datafusion/sqllogictest/test_files/spark/datetime/dayofweek.slt new file mode 100644 index 0000000000000..cc885818f62ff --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/dayofweek.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT dayofweek('2009-07-30'); +## PySpark 3.5.5 Result: {'dayofweek(2009-07-30)': 5, 'typeof(dayofweek(2009-07-30))': 'int', 'typeof(2009-07-30)': 'string'} +#query +#SELECT dayofweek('2009-07-30'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/dayofyear.slt b/datafusion/sqllogictest/test_files/spark/datetime/dayofyear.slt new file mode 100644 index 0000000000000..7ffab98dac84a --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/dayofyear.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT dayofyear('2016-04-09'); +## PySpark 3.5.5 Result: {'dayofyear(2016-04-09)': 100, 'typeof(dayofyear(2016-04-09))': 'int', 'typeof(2016-04-09)': 'string'} +#query +#SELECT dayofyear('2016-04-09'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/hour.slt b/datafusion/sqllogictest/test_files/spark/datetime/hour.slt new file mode 100644 index 0000000000000..e129b271658ab --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/hour.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT hour('2009-07-30 12:58:59'); +## PySpark 3.5.5 Result: {'hour(2009-07-30 12:58:59)': 12, 'typeof(hour(2009-07-30 12:58:59))': 'int', 'typeof(2009-07-30 12:58:59)': 'string'} +#query +#SELECT hour('2009-07-30 12:58:59'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/last_day.slt b/datafusion/sqllogictest/test_files/spark/datetime/last_day.slt new file mode 100644 index 0000000000000..6dee48de9555d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/last_day.slt @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +query D +SELECT last_day('2009-01-12'::DATE); +---- +2009-01-31 + + +query D +SELECT last_day('2015-02-28'::DATE); +---- +2015-02-28 + +query D +SELECT last_day('2015-03-27'::DATE); +---- +2015-03-31 + +query D +SELECT last_day('2015-04-26'::DATE); +---- +2015-04-30 + +query D +SELECT last_day('2015-05-25'::DATE); +---- +2015-05-31 + +query D +SELECT last_day('2015-06-24'::DATE); +---- +2015-06-30 + +query D +SELECT last_day('2015-07-23'::DATE); +---- +2015-07-31 + +query D +SELECT last_day('2015-08-01'::DATE); +---- +2015-08-31 + +query D +SELECT last_day('2015-09-02'::DATE); +---- +2015-09-30 + +query D +SELECT last_day('2015-10-03'::DATE); +---- +2015-10-31 + +query D +SELECT last_day('2015-11-04'::DATE); +---- +2015-11-30 + +query D +SELECT last_day('2015-12-05'::DATE); +---- +2015-12-31 + + +query D +SELECT last_day('2016-01-06'::DATE); +---- +2016-01-31 + +query D +SELECT last_day('2016-02-07'::DATE); +---- +2016-02-29 + + +query D +SELECT last_day(null::DATE); +---- +NULL + + +statement error Failed to coerce arguments to satisfy a call to 'last_day' function +select last_day('foo'); + + +statement error Failed to coerce arguments to satisfy a call to 'last_day' function +select last_day(123); + + +statement error 'last_day' does not support zero arguments +select last_day(); + +statement error Failed to coerce arguments to satisfy a call to 'last_day' function +select last_day(last_day('2016-02-07'::string, 'foo')); + +statement error Failed to coerce arguments to satisfy a call to 'last_day' function +select last_day(last_day('2016-02-31'::string)); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/localtimestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/localtimestamp.slt new file mode 100644 index 0000000000000..36fd451382d04 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/localtimestamp.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT localtimestamp(); +## PySpark 3.5.5 Result: {'localtimestamp()': datetime.datetime(2025, 6, 14, 23, 57, 39, 529742), 'typeof(localtimestamp())': 'timestamp_ntz'} +#query +#SELECT localtimestamp(); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/make_date.slt b/datafusion/sqllogictest/test_files/spark/datetime/make_date.slt new file mode 100644 index 0000000000000..b95347f976e95 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/make_date.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT make_date(2013, 7, 15); +## PySpark 3.5.5 Result: {'make_date(2013, 7, 15)': datetime.date(2013, 7, 15), 'typeof(make_date(2013, 7, 15))': 'date', 'typeof(2013)': 'int', 'typeof(7)': 'int', 'typeof(15)': 'int'} +#query +#SELECT make_date(2013::int, 7::int, 15::int); + +## Original Query: SELECT make_date(2019, 7, NULL); +## PySpark 3.5.5 Result: {'make_date(2019, 7, NULL)': None, 'typeof(make_date(2019, 7, NULL))': 'date', 'typeof(2019)': 'int', 'typeof(7)': 'int', 'typeof(NULL)': 'void'} +#query +#SELECT make_date(2019::int, 7::int, NULL::void); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/make_dt_interval.slt b/datafusion/sqllogictest/test_files/spark/datetime/make_dt_interval.slt new file mode 100644 index 0000000000000..dc6c33caa9b4c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/make_dt_interval.slt @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT make_dt_interval(1, 12, 30, 01.001001); +## PySpark 3.5.5 Result: {'make_dt_interval(1, 12, 30, 1.001001)': datetime.timedelta(days=1, seconds=45001, microseconds=1001), 'typeof(make_dt_interval(1, 12, 30, 1.001001))': 'interval day to second', 'typeof(1)': 'int', 'typeof(12)': 'int', 'typeof(30)': 'int', 'typeof(1.001001)': 'decimal(7,6)'} +query ? +SELECT make_dt_interval(1::int, 12::int, 30::int, 1.001001::decimal(7,6)); +---- +1 days 12 hours 30 mins 1.001001 secs + +## Original Query: SELECT make_dt_interval(100, null, 3); +## PySpark 3.5.5 Result: {'make_dt_interval(100, NULL, 3, 0.000000)': None, 'typeof(make_dt_interval(100, NULL, 3, 0.000000))': 'interval day to second', 'typeof(100)': 'int', 'typeof(NULL)': 'void', 'typeof(3)': 'int'} +query ? +SELECT make_dt_interval(100::int, NULL, 3::int); +---- +NULL + +## Original Query: SELECT make_dt_interval(2); +## PySpark 3.5.5 Result: {'make_dt_interval(2, 0, 0, 0.000000)': datetime.timedelta(days=2), 'typeof(make_dt_interval(2, 0, 0, 0.000000))': 'interval day to second', 'typeof(2)': 'int'} +query ? +SELECT make_dt_interval(2::int); +---- +2 days 0 hours 0 mins 0.000000 secs + +# null +query ? +SELECT (make_dt_interval(null, 0, 0, 0)) +---- +NULL + +query ? +SELECT (make_dt_interval(0, null, 0, 0)) +---- +NULL + +query ? +SELECT (make_dt_interval(0, 0, null, 0)) +---- +NULL + +query ? +SELECT (make_dt_interval(0, 0, 0, null)) +---- +NULL + +# missing params +query ? +SELECT (make_dt_interval()) AS make_dt_interval +---- +0 days 0 hours 0 mins 0.000000 secs + +query ? +SELECT (make_dt_interval(1)) AS make_dt_interval +---- +1 days 0 hours 0 mins 0.000000 secs + +query ? +SELECT (make_dt_interval(1, 1)) AS make_dt_interval +---- +1 days 1 hours 0 mins 0.000000 secs + +query ? +SELECT (make_dt_interval(1, 1, 1)) AS make_dt_interval +---- +1 days 1 hours 1 mins 0.000000 secs + +query ? +SELECT (make_dt_interval(1, 1, 1, 1)) AS make_dt_interval +---- +1 days 1 hours 1 mins 1.000000 secs + + +# all 0 values +query ? +SELECT (make_dt_interval(0, 0, 0, 0)) +---- +0 days 0 hours 0 mins 0.000000 secs + +query ? +SELECT (make_dt_interval(-1, 24, 0, 0)) df +---- +0 days 0 hours 0 mins 0.000000 secs + +query ? +SELECT (make_dt_interval(1, -24, 0, 0)) dt +---- +0 days 0 hours 0 mins 0.000000 secs + +query ? +SELECT (make_dt_interval(0, 0, 0, 0.1)) +---- +0 days 0 hours 0 mins 0.100000 secs + + +# doctest https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.make_dt_interval.html +# extract only the value make_dt_interval + +query ? +SELECT MAKE_DT_INTERVAL(day) AS interval_val +FROM VALUES (1, 12, 30, 1.001001) AS t(day, hour, min, sec); +---- +1 days 0 hours 0 mins 0.000000 secs + +query ? +SELECT MAKE_DT_INTERVAL(day, hour) AS interval_val +FROM VALUES (1, 12, 30, 1.001001) AS t(day, hour, min, sec); +---- +1 days 12 hours 0 mins 0.000000 secs + +query ? +SELECT MAKE_DT_INTERVAL(day, hour, min) AS interval_val +FROM VALUES (1, 12, 30, 1.001001) AS t(day, hour, min, sec); +---- +1 days 12 hours 30 mins 0.000000 secs + +query ? +SELECT MAKE_DT_INTERVAL(day, hour, min, sec) AS interval_val +FROM VALUES (1, 12, 30, 1.001001) AS t(day, hour, min, sec); +---- +1 days 12 hours 30 mins 1.001001 secs + +query ? +SELECT MAKE_DT_INTERVAL(1, 12, 30, 1.001001) +---- +1 days 12 hours 30 mins 1.001001 secs + +query ? +SELECT MAKE_DT_INTERVAL(1, 12, 30, 1.001001); +---- +1 days 12 hours 30 mins 1.001001 secs diff --git a/datafusion/sqllogictest/test_files/spark/datetime/make_interval.slt b/datafusion/sqllogictest/test_files/spark/datetime/make_interval.slt new file mode 100644 index 0000000000000..d6c5199b87b75 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/make_interval.slt @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +query IIIIIIR? +SELECT + y, m, w, d, h, mi, s, + make_interval(y, m, w, d, h, mi, s) AS interval +FROM VALUES + (NULL,2, 3, 4, 5, 6, 7.5), + (1, NULL,3, 4, 5, 6, 7.5), + (1, 2, NULL,4, 5, 6, 7.5), + (1, 2, 3, NULL,5, 6, 7.5), + (1, 2, 3, 4, NULL,6, 7.5), + (1, 2, 3, 4, 5, NULL,7.5), + (1, 2, 3, 4, 5, 6, CAST(NULL AS DOUBLE)), + (1, 1, 1, 1, 1, 1, 1.0) +AS v(y, m, w, d, h, mi, s); +---- +NULL 2 3 4 5 6 7.5 NULL +1 NULL 3 4 5 6 7.5 NULL +1 2 NULL 4 5 6 7.5 NULL +1 2 3 NULL 5 6 7.5 NULL +1 2 3 4 NULL 6 7.5 NULL +1 2 3 4 5 NULL 7.5 NULL +1 2 3 4 5 6 NULL NULL +1 1 1 1 1 1 1 13 mons 8 days 1 hours 1 mins 1.000000000 secs + +query IIIIIIR? +SELECT + y, m, w, d, h, mi, s, + make_interval(y, m, w, d, h, mi, s) AS interval +FROM VALUES + (0, 0, 0, 0, 0, 0, arrow_cast('NaN','Float64')) +AS v(y, m, w, d, h, mi, s); +---- +0 0 0 0 0 0 NaN NULL + +query IIIIIIR? +SELECT + y, m, w, d, h, mi, s, + make_interval(y, m, w, d, h, mi, s) AS interval +FROM VALUES + (0, 0, 0, 0, 0, 0, CAST('Infinity' AS DOUBLE)) +AS v(y, m, w, d, h, mi, s); +---- +0 0 0 0 0 0 Infinity NULL + +query IIIIIIR? +SELECT + y, m, w, d, h, mi, s, + make_interval(y, m, w, d, h, mi, s) AS interval +FROM VALUES + (0, 0, 0, 0, 0, 0, CAST('-Infinity' AS DOUBLE)) +AS v(y, m, w, d, h, mi, s); +---- +0 0 0 0 0 0 -Infinity NULL + +query ? +SELECT make_interval(2147483647, 1, 0, 0, 0, 0, 0.0); +---- +NULL + +query ? +SELECT make_interval(0, 0, 2147483647, 1, 0, 0, 0.0); +---- +NULL + +query ? +SELECT make_interval(0, 0, 0, 0, 2147483647, 1, 0.0); +---- +NULL + +# Intervals being rendered as empty string, see issue: +# https://github.com/apache/datafusion/issues/17455 +# We expect something like 0.00 secs with query ? +query T +SELECT make_interval(0, 0, 0, 0, 0, 0, 0.0) || ''; +---- +(empty) + +# Intervals being rendered as empty string, see issue: +# https://github.com/apache/datafusion/issues/17455 +# We expect something like 0.00 secs with query ? +query T +SELECT make_interval() || ''; +---- +(empty) + +query ? +SELECT INTERVAL '1' SECOND AS iv; +---- +1.000000000 secs diff --git a/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp.slt new file mode 100644 index 0000000000000..262154186c8e0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT make_timestamp(2014, 12, 28, 6, 30, 45.887); +## PySpark 3.5.5 Result: {'make_timestamp(2014, 12, 28, 6, 30, 45.887)': datetime.datetime(2014, 12, 28, 6, 30, 45, 887000), 'typeof(make_timestamp(2014, 12, 28, 6, 30, 45.887))': 'timestamp', 'typeof(2014)': 'int', 'typeof(12)': 'int', 'typeof(28)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(45.887)': 'decimal(5,3)'} +#query +#SELECT make_timestamp(2014::int, 12::int, 28::int, 6::int, 30::int, 45.887::decimal(5,3)); + +## Original Query: SELECT make_timestamp(2014, 12, 28, 6, 30, 45.887, 'CET'); +## PySpark 3.5.5 Result: {'make_timestamp(2014, 12, 28, 6, 30, 45.887, CET)': datetime.datetime(2014, 12, 27, 21, 30, 45, 887000), 'typeof(make_timestamp(2014, 12, 28, 6, 30, 45.887, CET))': 'timestamp', 'typeof(2014)': 'int', 'typeof(12)': 'int', 'typeof(28)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(45.887)': 'decimal(5,3)', 'typeof(CET)': 'string'} +#query +#SELECT make_timestamp(2014::int, 12::int, 28::int, 6::int, 30::int, 45.887::decimal(5,3), 'CET'::string); + +## Original Query: SELECT make_timestamp(2019, 6, 30, 23, 59, 1); +## PySpark 3.5.5 Result: {'make_timestamp(2019, 6, 30, 23, 59, 1)': datetime.datetime(2019, 6, 30, 23, 59, 1), 'typeof(make_timestamp(2019, 6, 30, 23, 59, 1))': 'timestamp', 'typeof(2019)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(23)': 'int', 'typeof(59)': 'int', 'typeof(1)': 'int'} +#query +#SELECT make_timestamp(2019::int, 6::int, 30::int, 23::int, 59::int, 1::int); + +## Original Query: SELECT make_timestamp(2019, 6, 30, 23, 59, 60); +## PySpark 3.5.5 Result: {'make_timestamp(2019, 6, 30, 23, 59, 60)': datetime.datetime(2019, 7, 1, 0, 0), 'typeof(make_timestamp(2019, 6, 30, 23, 59, 60))': 'timestamp', 'typeof(2019)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(23)': 'int', 'typeof(59)': 'int', 'typeof(60)': 'int'} +#query +#SELECT make_timestamp(2019::int, 6::int, 30::int, 23::int, 59::int, 60::int); + +## Original Query: SELECT make_timestamp(null, 7, 22, 15, 30, 0); +## PySpark 3.5.5 Result: {'make_timestamp(NULL, 7, 22, 15, 30, 0)': None, 'typeof(make_timestamp(NULL, 7, 22, 15, 30, 0))': 'timestamp', 'typeof(NULL)': 'void', 'typeof(7)': 'int', 'typeof(22)': 'int', 'typeof(15)': 'int', 'typeof(30)': 'int', 'typeof(0)': 'int'} +#query +#SELECT make_timestamp(NULL::void, 7::int, 22::int, 15::int, 30::int, 0::int); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp_ltz.slt b/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp_ltz.slt new file mode 100644 index 0000000000000..ce5e07f663c4d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp_ltz.slt @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887); +## PySpark 3.5.5 Result: {'make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887)': datetime.datetime(2014, 12, 28, 6, 30, 45, 887000), 'typeof(make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887))': 'timestamp', 'typeof(2014)': 'int', 'typeof(12)': 'int', 'typeof(28)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(45.887)': 'decimal(5,3)'} +#query +#SELECT make_timestamp_ltz(2014::int, 12::int, 28::int, 6::int, 30::int, 45.887::decimal(5,3)); + +## Original Query: SELECT make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887, 'CET'); +## PySpark 3.5.5 Result: {'make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887, CET)': datetime.datetime(2014, 12, 27, 21, 30, 45, 887000), 'typeof(make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887, CET))': 'timestamp', 'typeof(2014)': 'int', 'typeof(12)': 'int', 'typeof(28)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(45.887)': 'decimal(5,3)', 'typeof(CET)': 'string'} +#query +#SELECT make_timestamp_ltz(2014::int, 12::int, 28::int, 6::int, 30::int, 45.887::decimal(5,3), 'CET'::string); + +## Original Query: SELECT make_timestamp_ltz(2019, 6, 30, 23, 59, 60); +## PySpark 3.5.5 Result: {'make_timestamp_ltz(2019, 6, 30, 23, 59, 60)': datetime.datetime(2019, 7, 1, 0, 0), 'typeof(make_timestamp_ltz(2019, 6, 30, 23, 59, 60))': 'timestamp', 'typeof(2019)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(23)': 'int', 'typeof(59)': 'int', 'typeof(60)': 'int'} +#query +#SELECT make_timestamp_ltz(2019::int, 6::int, 30::int, 23::int, 59::int, 60::int); + +## Original Query: SELECT make_timestamp_ltz(null, 7, 22, 15, 30, 0); +## PySpark 3.5.5 Result: {'make_timestamp_ltz(NULL, 7, 22, 15, 30, 0)': None, 'typeof(make_timestamp_ltz(NULL, 7, 22, 15, 30, 0))': 'timestamp', 'typeof(NULL)': 'void', 'typeof(7)': 'int', 'typeof(22)': 'int', 'typeof(15)': 'int', 'typeof(30)': 'int', 'typeof(0)': 'int'} +#query +#SELECT make_timestamp_ltz(NULL::void, 7::int, 22::int, 15::int, 30::int, 0::int); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp_ntz.slt b/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp_ntz.slt new file mode 100644 index 0000000000000..fbbe37655eb7a --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp_ntz.slt @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887); +## PySpark 3.5.5 Result: {'make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887)': datetime.datetime(2014, 12, 28, 6, 30, 45, 887000), 'typeof(make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887))': 'timestamp_ntz', 'typeof(2014)': 'int', 'typeof(12)': 'int', 'typeof(28)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(45.887)': 'decimal(5,3)'} +#query +#SELECT make_timestamp_ntz(2014::int, 12::int, 28::int, 6::int, 30::int, 45.887::decimal(5,3)); + +## Original Query: SELECT make_timestamp_ntz(2019, 6, 30, 23, 59, 60); +## PySpark 3.5.5 Result: {'make_timestamp_ntz(2019, 6, 30, 23, 59, 60)': datetime.datetime(2019, 7, 1, 0, 0), 'typeof(make_timestamp_ntz(2019, 6, 30, 23, 59, 60))': 'timestamp_ntz', 'typeof(2019)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(23)': 'int', 'typeof(59)': 'int', 'typeof(60)': 'int'} +#query +#SELECT make_timestamp_ntz(2019::int, 6::int, 30::int, 23::int, 59::int, 60::int); + +## Original Query: SELECT make_timestamp_ntz(null, 7, 22, 15, 30, 0); +## PySpark 3.5.5 Result: {'make_timestamp_ntz(NULL, 7, 22, 15, 30, 0)': None, 'typeof(make_timestamp_ntz(NULL, 7, 22, 15, 30, 0))': 'timestamp_ntz', 'typeof(NULL)': 'void', 'typeof(7)': 'int', 'typeof(22)': 'int', 'typeof(15)': 'int', 'typeof(30)': 'int', 'typeof(0)': 'int'} +#query +#SELECT make_timestamp_ntz(NULL::void, 7::int, 22::int, 15::int, 30::int, 0::int); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/make_ym_interval.slt b/datafusion/sqllogictest/test_files/spark/datetime/make_ym_interval.slt new file mode 100644 index 0000000000000..9429a3a5306ed --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/make_ym_interval.slt @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT make_ym_interval(-1, 1); +## PySpark 3.5.5 Result: {'make_ym_interval(-1, 1)': -11, 'typeof(make_ym_interval(-1, 1))': 'interval year to month', 'typeof(-1)': 'int', 'typeof(1)': 'int'} +#query +#SELECT make_ym_interval(-1::int, 1::int); + +## Original Query: SELECT make_ym_interval(1, 0); +## PySpark 3.5.5 Result: {'make_ym_interval(1, 0)': 12, 'typeof(make_ym_interval(1, 0))': 'interval year to month', 'typeof(1)': 'int', 'typeof(0)': 'int'} +#query +#SELECT make_ym_interval(1::int, 0::int); + +## Original Query: SELECT make_ym_interval(1, 2); +## PySpark 3.5.5 Result: {'make_ym_interval(1, 2)': 14, 'typeof(make_ym_interval(1, 2))': 'interval year to month', 'typeof(1)': 'int', 'typeof(2)': 'int'} +#query +#SELECT make_ym_interval(1::int, 2::int); + +## Original Query: SELECT make_ym_interval(2); +## PySpark 3.5.5 Result: {'make_ym_interval(2, 0)': 24, 'typeof(make_ym_interval(2, 0))': 'interval year to month', 'typeof(2)': 'int'} +#query +#SELECT make_ym_interval(2::int); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/minute.slt b/datafusion/sqllogictest/test_files/spark/datetime/minute.slt new file mode 100644 index 0000000000000..dbe1e64be8377 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/minute.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT minute('2009-07-30 12:58:59'); +## PySpark 3.5.5 Result: {'minute(2009-07-30 12:58:59)': 58, 'typeof(minute(2009-07-30 12:58:59))': 'int', 'typeof(2009-07-30 12:58:59)': 'string'} +#query +#SELECT minute('2009-07-30 12:58:59'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/month.slt b/datafusion/sqllogictest/test_files/spark/datetime/month.slt new file mode 100644 index 0000000000000..17a34352d16f3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/month.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT month('2016-07-30'); +## PySpark 3.5.5 Result: {'month(2016-07-30)': 7, 'typeof(month(2016-07-30))': 'int', 'typeof(2016-07-30)': 'string'} +#query +#SELECT month('2016-07-30'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/months_between.slt b/datafusion/sqllogictest/test_files/spark/datetime/months_between.slt new file mode 100644 index 0000000000000..c2526761655db --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/months_between.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT months_between('1997-02-28 10:30:00', '1996-10-30'); +## PySpark 3.5.5 Result: {'months_between(1997-02-28 10:30:00, 1996-10-30, true)': 3.94959677, 'typeof(months_between(1997-02-28 10:30:00, 1996-10-30, true))': 'double', 'typeof(1997-02-28 10:30:00)': 'string', 'typeof(1996-10-30)': 'string'} +#query +#SELECT months_between('1997-02-28 10:30:00'::string, '1996-10-30'::string); + +## Original Query: SELECT months_between('1997-02-28 10:30:00', '1996-10-30', false); +## PySpark 3.5.5 Result: {'months_between(1997-02-28 10:30:00, 1996-10-30, false)': 3.9495967741935485, 'typeof(months_between(1997-02-28 10:30:00, 1996-10-30, false))': 'double', 'typeof(1997-02-28 10:30:00)': 'string', 'typeof(1996-10-30)': 'string', 'typeof(false)': 'boolean'} +#query +#SELECT months_between('1997-02-28 10:30:00'::string, '1996-10-30'::string, false::boolean); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/next_day.slt b/datafusion/sqllogictest/test_files/spark/datetime/next_day.slt new file mode 100644 index 0000000000000..872d1f2b58eb6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/next_day.slt @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +query D +SELECT next_day('2015-01-14'::DATE, 'TU'::string); +---- +2015-01-20 + +query D +SELECT next_day('2015-07-27'::DATE, 'Sun'::string); +---- +2015-08-02 + +query D +SELECT next_day('2015-07-27'::DATE, 'Sat'::string); +---- +2015-08-01 + +query error Failed to coerce arguments to satisfy a call to 'next_day' function +SELECT next_day('2015-07-27'::DATE); + +query error Failed to coerce arguments to satisfy a call to 'next_day' function +SELECT next_day('Sun'::string); + +query error 'next_day' does not support zero arguments +SELECT next_day(); + +query error Failed to coerce arguments to satisfy a call to 'next_day' function +SELECT next_day(1::int, 'Sun'::string); + +query error Failed to coerce arguments to satisfy a call to 'next_day' function +SELECT next_day('2015-07-27'::DATE, 'Sat'::string, 'Sun'::string); + +query error Failed to coerce arguments to satisfy a call to 'next_day' function +SELECT next_day('invalid_date'::string, 'Mon'::string); + +query D +SELECT next_day('2000-01-01'::DATE, 2.0::float); +---- +NULL + +query D +SELECT next_day('2020-01-01'::DATE, 'invalid_day'::string); +---- +NULL + +query error Cast error: Cannot cast string '2015-13-32' to value of Date32 type +SELECT next_day('2015-13-32'::DATE, 'Sun'::string); + +query D +SELECT next_day(a, b) +FROM VALUES + ('2000-01-01'::DATE, 'Mon'::string), + (NULL::DATE, NULL::string), + (NULL::DATE, 'Mon'::string), + ('2015-01-14'::DATE, NULL::string) as t(a, b); +---- +2000-01-03 +NULL +NULL +NULL diff --git a/datafusion/sqllogictest/test_files/spark/datetime/now.slt b/datafusion/sqllogictest/test_files/spark/datetime/now.slt new file mode 100644 index 0000000000000..985140c1ac442 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/now.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT now(); +## PySpark 3.5.5 Result: {'now()': datetime.datetime(2025, 6, 14, 23, 57, 39, 982956), 'typeof(now())': 'timestamp'} +#query +#SELECT now(); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/quarter.slt b/datafusion/sqllogictest/test_files/spark/datetime/quarter.slt new file mode 100644 index 0000000000000..27b6728b0b7bb --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/quarter.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT quarter('2016-08-31'); +## PySpark 3.5.5 Result: {'quarter(2016-08-31)': 3, 'typeof(quarter(2016-08-31))': 'int', 'typeof(2016-08-31)': 'string'} +#query +#SELECT quarter('2016-08-31'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/second.slt b/datafusion/sqllogictest/test_files/spark/datetime/second.slt new file mode 100644 index 0000000000000..f69c9af4a62d9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/second.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT second('2009-07-30 12:58:59'); +## PySpark 3.5.5 Result: {'second(2009-07-30 12:58:59)': 59, 'typeof(second(2009-07-30 12:58:59))': 'int', 'typeof(2009-07-30 12:58:59)': 'string'} +#query +#SELECT second('2009-07-30 12:58:59'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/timestamp_micros.slt b/datafusion/sqllogictest/test_files/spark/datetime/timestamp_micros.slt new file mode 100644 index 0000000000000..19a52c981075f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/timestamp_micros.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT timestamp_micros(1230219000123123); +## PySpark 3.5.5 Result: {'timestamp_micros(1230219000123123)': datetime.datetime(2008, 12, 25, 7, 30, 0, 123123), 'typeof(timestamp_micros(1230219000123123))': 'timestamp', 'typeof(1230219000123123)': 'bigint'} +#query +#SELECT timestamp_micros(1230219000123123::bigint); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/timestamp_millis.slt b/datafusion/sqllogictest/test_files/spark/datetime/timestamp_millis.slt new file mode 100644 index 0000000000000..7dc092549fffa --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/timestamp_millis.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT timestamp_millis(1230219000123); +## PySpark 3.5.5 Result: {'timestamp_millis(1230219000123)': datetime.datetime(2008, 12, 25, 7, 30, 0, 123000), 'typeof(timestamp_millis(1230219000123))': 'timestamp', 'typeof(1230219000123)': 'bigint'} +#query +#SELECT timestamp_millis(1230219000123::bigint); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/timestamp_seconds.slt b/datafusion/sqllogictest/test_files/spark/datetime/timestamp_seconds.slt new file mode 100644 index 0000000000000..8e14c1dfe1f2b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/timestamp_seconds.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT timestamp_seconds(1230219000); +## PySpark 3.5.5 Result: {'timestamp_seconds(1230219000)': datetime.datetime(2008, 12, 25, 7, 30), 'typeof(timestamp_seconds(1230219000))': 'timestamp', 'typeof(1230219000)': 'int'} +#query +#SELECT timestamp_seconds(1230219000::int); + +## Original Query: SELECT timestamp_seconds(1230219000.123); +## PySpark 3.5.5 Result: {'timestamp_seconds(1230219000.123)': datetime.datetime(2008, 12, 25, 7, 30, 0, 123000), 'typeof(timestamp_seconds(1230219000.123))': 'timestamp', 'typeof(1230219000.123)': 'decimal(13,3)'} +#query +#SELECT timestamp_seconds(1230219000.123::decimal(13,3)); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/to_date.slt b/datafusion/sqllogictest/test_files/spark/datetime/to_date.slt new file mode 100644 index 0000000000000..3863cfb2baae7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/to_date.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_date('2009-07-30 04:17:52'); +## PySpark 3.5.5 Result: {'to_date(2009-07-30 04:17:52)': datetime.date(2009, 7, 30), 'typeof(to_date(2009-07-30 04:17:52))': 'date', 'typeof(2009-07-30 04:17:52)': 'string'} +#query +#SELECT to_date('2009-07-30 04:17:52'::string); + +## Original Query: SELECT to_date('2016-12-31', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'to_date(2016-12-31, yyyy-MM-dd)': datetime.date(2016, 12, 31), 'typeof(to_date(2016-12-31, yyyy-MM-dd))': 'date', 'typeof(2016-12-31)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT to_date('2016-12-31'::string, 'yyyy-MM-dd'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp.slt new file mode 100644 index 0000000000000..39f77620fa771 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_timestamp('2016-12-31 00:12:00'); +## PySpark 3.5.5 Result: {'to_timestamp(2016-12-31 00:12:00)': datetime.datetime(2016, 12, 31, 0, 12), 'typeof(to_timestamp(2016-12-31 00:12:00))': 'timestamp', 'typeof(2016-12-31 00:12:00)': 'string'} +#query +#SELECT to_timestamp('2016-12-31 00:12:00'::string); + +## Original Query: SELECT to_timestamp('2016-12-31', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'to_timestamp(2016-12-31, yyyy-MM-dd)': datetime.datetime(2016, 12, 31, 0, 0), 'typeof(to_timestamp(2016-12-31, yyyy-MM-dd))': 'timestamp', 'typeof(2016-12-31)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT to_timestamp('2016-12-31'::string, 'yyyy-MM-dd'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp_ltz.slt b/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp_ltz.slt new file mode 100644 index 0000000000000..c7c43a2bcc56d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp_ltz.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_timestamp_ltz('2016-12-31 00:12:00'); +## PySpark 3.5.5 Result: {'to_timestamp_ltz(2016-12-31 00:12:00)': datetime.datetime(2016, 12, 31, 0, 12), 'typeof(to_timestamp_ltz(2016-12-31 00:12:00))': 'timestamp', 'typeof(2016-12-31 00:12:00)': 'string'} +#query +#SELECT to_timestamp_ltz('2016-12-31 00:12:00'::string); + +## Original Query: SELECT to_timestamp_ltz('2016-12-31', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'to_timestamp_ltz(2016-12-31, yyyy-MM-dd)': datetime.datetime(2016, 12, 31, 0, 0), 'typeof(to_timestamp_ltz(2016-12-31, yyyy-MM-dd))': 'timestamp', 'typeof(2016-12-31)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT to_timestamp_ltz('2016-12-31'::string, 'yyyy-MM-dd'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp_ntz.slt b/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp_ntz.slt new file mode 100644 index 0000000000000..11c4e4cbe257f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp_ntz.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_timestamp_ntz('2016-12-31 00:12:00'); +## PySpark 3.5.5 Result: {'to_timestamp_ntz(2016-12-31 00:12:00)': datetime.datetime(2016, 12, 31, 0, 12), 'typeof(to_timestamp_ntz(2016-12-31 00:12:00))': 'timestamp_ntz', 'typeof(2016-12-31 00:12:00)': 'string'} +#query +#SELECT to_timestamp_ntz('2016-12-31 00:12:00'::string); + +## Original Query: SELECT to_timestamp_ntz('2016-12-31', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'to_timestamp_ntz(2016-12-31, yyyy-MM-dd)': datetime.datetime(2016, 12, 31, 0, 0), 'typeof(to_timestamp_ntz(2016-12-31, yyyy-MM-dd))': 'timestamp_ntz', 'typeof(2016-12-31)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT to_timestamp_ntz('2016-12-31'::string, 'yyyy-MM-dd'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/to_unix_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/to_unix_timestamp.slt new file mode 100644 index 0000000000000..53c1902094a50 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/to_unix_timestamp.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_unix_timestamp('2016-04-08', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'to_unix_timestamp(2016-04-08, yyyy-MM-dd)': 1460098800, 'typeof(to_unix_timestamp(2016-04-08, yyyy-MM-dd))': 'bigint', 'typeof(2016-04-08)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT to_unix_timestamp('2016-04-08'::string, 'yyyy-MM-dd'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/to_utc_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/to_utc_timestamp.slt new file mode 100644 index 0000000000000..24693016be1a7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/to_utc_timestamp.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_utc_timestamp('2016-08-31', 'Asia/Seoul'); +## PySpark 3.5.5 Result: {'to_utc_timestamp(2016-08-31, Asia/Seoul)': datetime.datetime(2016, 8, 30, 15, 0), 'typeof(to_utc_timestamp(2016-08-31, Asia/Seoul))': 'timestamp', 'typeof(2016-08-31)': 'string', 'typeof(Asia/Seoul)': 'string'} +#query +#SELECT to_utc_timestamp('2016-08-31'::string, 'Asia/Seoul'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/trunc.slt b/datafusion/sqllogictest/test_files/spark/datetime/trunc.slt new file mode 100644 index 0000000000000..a502e2f7f7b00 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/trunc.slt @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT trunc('2009-02-12', 'MM'); +## PySpark 3.5.5 Result: {'trunc(2009-02-12, MM)': datetime.date(2009, 2, 1), 'typeof(trunc(2009-02-12, MM))': 'date', 'typeof(2009-02-12)': 'string', 'typeof(MM)': 'string'} +#query +#SELECT trunc('2009-02-12'::string, 'MM'::string); + +## Original Query: SELECT trunc('2015-10-27', 'YEAR'); +## PySpark 3.5.5 Result: {'trunc(2015-10-27, YEAR)': datetime.date(2015, 1, 1), 'typeof(trunc(2015-10-27, YEAR))': 'date', 'typeof(2015-10-27)': 'string', 'typeof(YEAR)': 'string'} +#query +#SELECT trunc('2015-10-27'::string, 'YEAR'::string); + +## Original Query: SELECT trunc('2019-08-04', 'quarter'); +## PySpark 3.5.5 Result: {'trunc(2019-08-04, quarter)': datetime.date(2019, 7, 1), 'typeof(trunc(2019-08-04, quarter))': 'date', 'typeof(2019-08-04)': 'string', 'typeof(quarter)': 'string'} +#query +#SELECT trunc('2019-08-04'::string, 'quarter'::string); + +## Original Query: SELECT trunc('2019-08-04', 'week'); +## PySpark 3.5.5 Result: {'trunc(2019-08-04, week)': datetime.date(2019, 7, 29), 'typeof(trunc(2019-08-04, week))': 'date', 'typeof(2019-08-04)': 'string', 'typeof(week)': 'string'} +#query +#SELECT trunc('2019-08-04'::string, 'week'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/try_to_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/try_to_timestamp.slt new file mode 100644 index 0000000000000..23b788125ed0e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/try_to_timestamp.slt @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT try_to_timestamp('2016-12-31 00:12:00'); +## PySpark 3.5.5 Result: {'try_to_timestamp(2016-12-31 00:12:00)': datetime.datetime(2016, 12, 31, 0, 12), 'typeof(try_to_timestamp(2016-12-31 00:12:00))': 'timestamp', 'typeof(2016-12-31 00:12:00)': 'string'} +#query +#SELECT try_to_timestamp('2016-12-31 00:12:00'::string); + +## Original Query: SELECT try_to_timestamp('2016-12-31', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'try_to_timestamp(2016-12-31, yyyy-MM-dd)': datetime.datetime(2016, 12, 31, 0, 0), 'typeof(try_to_timestamp(2016-12-31, yyyy-MM-dd))': 'timestamp', 'typeof(2016-12-31)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT try_to_timestamp('2016-12-31'::string, 'yyyy-MM-dd'::string); + +## Original Query: SELECT try_to_timestamp('foo', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'try_to_timestamp(foo, yyyy-MM-dd)': None, 'typeof(try_to_timestamp(foo, yyyy-MM-dd))': 'timestamp', 'typeof(foo)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT try_to_timestamp('foo'::string, 'yyyy-MM-dd'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/unix_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/unix_timestamp.slt new file mode 100644 index 0000000000000..bc597912bc85b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/unix_timestamp.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT unix_timestamp('2016-04-08', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'unix_timestamp(2016-04-08, yyyy-MM-dd)': 1460098800, 'typeof(unix_timestamp(2016-04-08, yyyy-MM-dd))': 'bigint', 'typeof(2016-04-08)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT unix_timestamp('2016-04-08'::string, 'yyyy-MM-dd'::string); + +## Original Query: SELECT unix_timestamp(); +## PySpark 3.5.5 Result: {'unix_timestamp(current_timestamp(), yyyy-MM-dd HH:mm:ss)': 1749970660, 'typeof(unix_timestamp(current_timestamp(), yyyy-MM-dd HH:mm:ss))': 'bigint'} +#query +#SELECT unix_timestamp(); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/weekday.slt b/datafusion/sqllogictest/test_files/spark/datetime/weekday.slt new file mode 100644 index 0000000000000..b4f5444e8a2da --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/weekday.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT weekday('2009-07-30'); +## PySpark 3.5.5 Result: {'weekday(2009-07-30)': 3, 'typeof(weekday(2009-07-30))': 'int', 'typeof(2009-07-30)': 'string'} +#query +#SELECT weekday('2009-07-30'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/weekofyear.slt b/datafusion/sqllogictest/test_files/spark/datetime/weekofyear.slt new file mode 100644 index 0000000000000..30e69341d97d1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/weekofyear.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT weekofyear('2008-02-20'); +## PySpark 3.5.5 Result: {'weekofyear(2008-02-20)': 8, 'typeof(weekofyear(2008-02-20))': 'int', 'typeof(2008-02-20)': 'string'} +#query +#SELECT weekofyear('2008-02-20'::string); diff --git a/datafusion/sqllogictest/test_files/spark/datetime/year.slt b/datafusion/sqllogictest/test_files/spark/datetime/year.slt new file mode 100644 index 0000000000000..6577522736c07 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/year.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT year('2016-07-30'); +## PySpark 3.5.5 Result: {'year(2016-07-30)': 2016, 'typeof(year(2016-07-30))': 'int', 'typeof(2016-07-30)': 'string'} +#query +#SELECT year('2016-07-30'::string); diff --git a/datafusion/sqllogictest/test_files/spark/hash/crc32.slt b/datafusion/sqllogictest/test_files/spark/hash/crc32.slt new file mode 100644 index 0000000000000..87b69d8d404ea --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/hash/crc32.slt @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT crc32('Spark'); +## PySpark 3.5.5 Result: {'crc32(Spark)': 1557323817, 'typeof(crc32(Spark))': 'bigint', 'typeof(Spark)': 'string'} + +# Basic crc32 tests +query I +SELECT crc32('Spark'); +---- +1557323817 + +query I +SELECT crc32(NULL); +---- +NULL + +query I +SELECT crc32(''); +---- +0 + +query I +SELECT crc32(arrow_cast('', 'Binary')); +---- +0 + +# Test with LargeUtf8 (using CAST to ensure type) +query I +SELECT crc32(arrow_cast('Spark', 'LargeUtf8')); +---- +1557323817 + +# Test with Utf8View (using CAST to ensure type) +query I +SELECT crc32(arrow_cast('Spark', 'Utf8View')); +---- +1557323817 + +# Test with different binary types +query I +SELECT crc32(arrow_cast('Spark', 'Binary')); +---- +1557323817 + +# Test with LargeBinary +query I +SELECT crc32(arrow_cast('Spark', 'LargeBinary')); +---- +1557323817 + +# Test with BinaryView +query I +SELECT crc32(arrow_cast('Spark', 'BinaryView')); +---- +1557323817 diff --git a/datafusion/sqllogictest/test_files/spark/hash/md5.slt b/datafusion/sqllogictest/test_files/spark/hash/md5.slt new file mode 100644 index 0000000000000..f1a4b82e291ad --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/hash/md5.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT md5('Spark'); +## PySpark 3.5.5 Result: {'md5(Spark)': '8cde774d6f7333752ed72cacddb05126', 'typeof(md5(Spark))': 'string', 'typeof(Spark)': 'string'} +#query +#SELECT md5('Spark'::string); diff --git a/datafusion/sqllogictest/test_files/spark/hash/sha.slt b/datafusion/sqllogictest/test_files/spark/hash/sha.slt new file mode 100644 index 0000000000000..c7710aa6a763f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/hash/sha.slt @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sha('Spark'); +## PySpark 3.5.5 Result: {'sha(Spark)': '85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c', 'typeof(sha(Spark))': 'string', 'typeof(Spark)': 'string'} + +# Basic sha tests +query T +SELECT sha('Spark'); +---- +85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c + +query T +SELECT sha(NULL); +---- +NULL + +query T +SELECT sha(''); +---- +da39a3ee5e6b4b0d3255bfef95601890afd80709 + +# Test with LargeUtf8 (using CAST to ensure type) +query T +SELECT sha(arrow_cast('Spark', 'LargeUtf8')); +---- +85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c + +# Test with Utf8View (using CAST to ensure type) +query T +SELECT sha(arrow_cast('Spark', 'Utf8View')); +---- +85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c + +# Test with Binary +query T +SELECT sha(arrow_cast('Spark', 'Binary')); +---- +85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c + +# Test with LargeBinary +query T +SELECT sha(arrow_cast('Spark', 'LargeBinary')); +---- +85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c + +# Test with BinaryView +query T +SELECT sha(arrow_cast('Spark', 'BinaryView')); +---- +85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c diff --git a/datafusion/sqllogictest/test_files/spark/hash/sha1.slt b/datafusion/sqllogictest/test_files/spark/hash/sha1.slt new file mode 100644 index 0000000000000..1ce7346160726 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/hash/sha1.slt @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sha1('Spark'); +## PySpark 3.5.5 Result: {'sha1(Spark)': '85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c', 'typeof(sha1(Spark))': 'string', 'typeof(Spark)': 'string'} + +# Basic sha1 tests +query T +SELECT sha1('Spark'); +---- +85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c + +query T +SELECT sha1(NULL); +---- +NULL + +query T +SELECT sha1(''); +---- +da39a3ee5e6b4b0d3255bfef95601890afd80709 + +# Test with LargeUtf8 (using CAST to ensure type) +query T +SELECT sha1(arrow_cast('Spark', 'LargeUtf8')); +---- +85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c + +# Test with Utf8View (using CAST to ensure type) +query T +SELECT sha1(arrow_cast('Spark', 'Utf8View')); +---- +85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c + +# Test with Binary +query T +SELECT sha1(arrow_cast('Spark', 'Binary')); +---- +85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c + +# Test with LargeBinary +query T +SELECT sha1(arrow_cast('Spark', 'LargeBinary')); +---- +85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c + +# Test with BinaryView +query T +SELECT sha1(arrow_cast('Spark', 'BinaryView')); +---- +85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c diff --git a/datafusion/sqllogictest/test_files/spark/hash/sha2.slt b/datafusion/sqllogictest/test_files/spark/hash/sha2.slt new file mode 100644 index 0000000000000..7690a38773b04 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/hash/sha2.slt @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query T +SELECT sha2('Spark', 0::INT); +---- +529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b + +query T +SELECT sha2('Spark', 256::INT); +---- +529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b + +query T +SELECT sha2('Spark', 224::INT); +---- +dbeab94971678d36af2195851c0f7485775a2a7c60073d62fc04549c + +query T +SELECT sha2('Spark', 384::INT); +---- +1e40b8d06c248a1cc32428c22582b6219d072283078fa140d9ad297ecadf2cabefc341b857ad36226aa8d6d79f2ab67d + +query T +SELECT sha2('Spark', 512::INT); +---- +44844a586c54c9a212da1dbfe05c5f1705de1af5fda1f0d36297623249b279fd8f0ccec03f888f4fb13bf7cd83fdad58591c797f81121a23cfdd5e0897795238 + +query T +SELECT sha2('Spark', 128::INT); +---- +NULL + +query T +SELECT sha2(expr, 256::INT) FROM VALUES ('foo'), ('bar') AS t(expr); +---- +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae +fcde2b2edba56bf408601fb721fe9b5c338d10ee429ea04fae5511b68fbf8fb9 + +query T +SELECT sha2(expr, 128::INT) FROM VALUES ('foo'), ('bar') AS t(expr); +---- +NULL +NULL + +query T +SELECT sha2('foo', bit_length) FROM VALUES (0::INT), (256::INT), (224::INT), (384::INT), (512::INT), (128::INT) AS t(bit_length); +---- +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +98c11ffdfdd540676b1a137cb1a22b2a70350c9a44171d6b1180c6be5cbb2ee3f79d532c8a1dd9ef2e8e08e752a3babb +f7fbba6e0636f890e56fbbf3283e524c6fa3204ae298382d624741d0dc6638326e282c41be5e4254d8820772c5518a2c5a8c0c7f7eda19594a7eb539453e1ed7 +NULL + +query T +SELECT sha2(expr, bit_length) FROM VALUES ('foo',0::INT), ('bar',224::INT), ('baz',384::INT), ('qux',512::INT), ('qux',128::INT) AS t(expr, bit_length); +---- +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae +07daf010de7f7f0d8d76a76eb8d1eb40182c8d1e7a3877a6686c9bf0 +967004d25de4abc1bd6a7c9a216254a5ac0733e8ad96dc9f1ea0fad9619da7c32d654ec8ad8ba2f9b5728fed6633bd91 +8c6be9ed448a34883a13a13f4ead4aefa036b67dcda59020c01e57ea075ea8a4792d428f2c6fd0c09d1c49994d6c22789336e062188df29572ed07e7f9779c52 +NULL diff --git a/datafusion/sqllogictest/test_files/spark/json/get_json_object.slt b/datafusion/sqllogictest/test_files/spark/json/get_json_object.slt new file mode 100644 index 0000000000000..7917ee1168766 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/json/get_json_object.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT get_json_object('{"a":"b"}', '$.a'); +## PySpark 3.5.5 Result: {'get_json_object({"a":"b"}, $.a)': 'b', 'typeof(get_json_object({"a":"b"}, $.a))': 'string', 'typeof({"a":"b"})': 'string', 'typeof($.a)': 'string'} +#query +#SELECT get_json_object('{"a":"b"}'::string, '$.a'::string); diff --git a/datafusion/sqllogictest/test_files/spark/json/json_object_keys.slt b/datafusion/sqllogictest/test_files/spark/json/json_object_keys.slt new file mode 100644 index 0000000000000..ce399c5820a27 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/json/json_object_keys.slt @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT json_object_keys('{"f1":"abc","f2":{"f3":"a", "f4":"b"}}'); +## PySpark 3.5.5 Result: {'json_object_keys({"f1":"abc","f2":{"f3":"a", "f4":"b"}})': ['f1', 'f2'], 'typeof(json_object_keys({"f1":"abc","f2":{"f3":"a", "f4":"b"}}))': 'array', 'typeof({"f1":"abc","f2":{"f3":"a", "f4":"b"}})': 'string'} +#query +#SELECT json_object_keys('{"f1":"abc","f2":{"f3":"a", "f4":"b"}}'::string); + +## Original Query: SELECT json_object_keys('{"key": "value"}'); +## PySpark 3.5.5 Result: {'json_object_keys({"key": "value"})': ['key'], 'typeof(json_object_keys({"key": "value"}))': 'array', 'typeof({"key": "value"})': 'string'} +#query +#SELECT json_object_keys('{"key": "value"}'::string); + +## Original Query: SELECT json_object_keys('{}'); +## PySpark 3.5.5 Result: {'json_object_keys({})': [], 'typeof(json_object_keys({}))': 'array', 'typeof({})': 'string'} +#query +#SELECT json_object_keys('{}'::string); diff --git a/datafusion/sqllogictest/test_files/spark/map/map_from_arrays.slt b/datafusion/sqllogictest/test_files/spark/map/map_from_arrays.slt new file mode 100644 index 0000000000000..a26b0435c9291 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/map/map_from_arrays.slt @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Spark doctests +query ? +SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')); +---- +{1.0: 2, 3.0: 4} + +query ? +SELECT map_from_arrays(array(2, 5), array('a', 'b')); +---- +{2: a, 5: b} + +query ? +SELECT map_from_arrays(array(1, 2), array('a', NULL)); +---- +{1: a, 2: NULL} + +query ? +SELECT map_from_arrays(cast(array() as array), cast(array() as array)); +---- +{} + +# Tests with DataType:Null input arrays +query ? +SELECT map_from_arrays(NULL, NULL); +---- +NULL + +query ? +SELECT map_from_arrays(array(1), NULL); +---- +NULL + +query ? +SELECT map_from_arrays(NULL, array(1)); +---- +NULL + +# Tests with different inner lists lengths +query error DataFusion error: Execution error: map_deduplicate_keys: keys and values lists in the same row must have equal lengths +SELECT map_from_arrays(array(1, 2, 3), array('a', 'b')); + +query error DataFusion error: Execution error: map_deduplicate_keys: keys and values lists in the same row must have equal lengths +SELECT map_from_arrays(array(), array('a', 'b')); + +query error DataFusion error: Execution error: map_deduplicate_keys: keys and values lists in the same row must have equal lengths +SELECT map_from_arrays(array(1, 2, 3), array()); + +query error DataFusion error: Execution error: map_deduplicate_keys: keys and values lists in the same row must have equal lengths +select map_from_arrays(a, b) +from values + (array[1], array[1]), + (array[2, 3, 4], array[2, 3]), + (array[5], array[4]) +as tab(a, b); + +#Test with multiple rows: good, empty and nullable +query ? +select map_from_arrays(a, b) +from values + (array[1], array['a']), + (NULL, NULL), + (array[1,2,3], NULL), + (NULL, array['b', 'c']), + (array[4, 5], array['d', 'e']), + (array[], array[]), + (array[6, 7, 8], array['f', 'g', 'h']) +as tab(a, b); +---- +{1: a} +NULL +NULL +NULL +{4: d, 5: e} +{} +{6: f, 7: g, 8: h} + +# Test with complex types +query ? +SELECT map_from_arrays(array(array('a', 'b'), array('c', 'd')), array(struct(1, 2, 3), struct(4, 5, 6))); +---- +{[a, b]: {c0: 1, c1: 2, c2: 3}, [c, d]: {c0: 4, c1: 5, c2: 6}} + +# Test with nested function calls +query ? +SELECT + map_from_arrays( + array['outer_key1', 'outer_key2'], + array[ + -- value for outer_key1: a map itself + map_from_arrays( + array['inner_a', 'inner_b'], + array[1, 2] + ), + -- value for outer_key2: another map + map_from_arrays( + array['inner_x', 'inner_y', 'inner_z'], + array[10, 20, 30] + ) + ] + ) AS nested_map; +---- +{outer_key1: {inner_a: 1, inner_b: 2}, outer_key2: {inner_x: 10, inner_y: 20, inner_z: 30}} + +# Test with duplicate keys +query ? +SELECT map_from_arrays(array(true, false, true), array('a', NULL, 'b')); +---- +{false: NULL, true: b} + +# Tests with different list types +query ? +SELECT map_from_arrays(arrow_cast(array(2, 5), 'LargeList(Int32)'), arrow_cast(array('a', 'b'), 'FixedSizeList(2, Utf8)')); +---- +{2: a, 5: b} + +query ? +SELECT map_from_arrays(arrow_cast(array('a', 'b', 'c'), 'FixedSizeList(3, Utf8)'), arrow_cast(array(1, 2, 3), 'LargeList(Int32)')); +---- +{a: 1, b: 2, c: 3} diff --git a/datafusion/sqllogictest/test_files/spark/map/map_from_entries.slt b/datafusion/sqllogictest/test_files/spark/map/map_from_entries.slt new file mode 100644 index 0000000000000..19b46886a027e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/map/map_from_entries.slt @@ -0,0 +1,164 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Spark doctests +query ? +SELECT map_from_entries(array[struct(1, 'a'), struct(2, 'b')]); +---- +{1: a, 2: b} + +query ? +SELECT map_from_entries(array[struct(1, cast(null as string)), struct(2, 'b')]); +---- +{1: NULL, 2: b} + +query ? +SELECT map_from_entries(data) +from values + (array[struct(1, 'a'), struct(2, 'b')]), + (array[struct(3, 'c')]) +as tab(data); +---- +{1: a, 2: b} +{3: c} + +# Tests with NULL and empty input structarrays +query ? +SELECT map_from_entries(data) +from values + (cast(array[] as array>)), + (cast(NULL as array>)) +as tab(data); +---- +{} +NULL + +# Test with NULL key, should fail +query error DataFusion error: Arrow error: Invalid argument error: Found unmasked nulls for non-nullable StructArray field "key" +SELECT map_from_entries(array[struct(NULL, 1)]); + +# Tests with NULL and array of Null type, should fail +query error DataFusion error: Execution error: map_from_entries: expected array>, got Null +SELECT map_from_entries(NULL); + +query error DataFusion error: Execution error: map_from_entries: expected array>, got Null +SELECT map_from_entries(array[NULL]); + +# Test with NULL array and NULL entries in arrays +# output is NULL if any entry is NULL +query ? +SELECT map_from_entries(data) +from values + ( + array[ + struct(1 as a, 'a' as b), + cast(NULL as struct
), + cast(NULL as struct) + ] + ), + (NULL), + ( + array[ + struct(2 as a, 'b' as b), + struct(3 as a, 'c' as b) + ] + ), + ( + array[ + struct(4 as a, 'd' as b), + cast(NULL as struct), + struct(5 as a, 'e' as b), + struct(6 as a, 'f' as b) + ] + ) +as tab(data); +---- +NULL +NULL +{2: b, 3: c} +NULL + +#Test with multiple rows: good, empty and nullable +query ? +SELECT map_from_entries(data) +from values + (NULL), + (array[ + struct(1 as a, 'b' as b), + struct(2 as a, cast(NULL as string) as b), + struct(3 as a, 'd' as b) + ]), + (array[]), + (NULL) +as tab(data); +---- +NULL +{1: b, 2: NULL, 3: d} +{} +NULL + +# Test with complex types +query ? +SELECT map_from_entries(array[ + struct(array('a', 'b'), struct(1, 2, 3)), + struct(array('c', 'd'), struct(4, 5, 6)) +]); +---- +{[a, b]: {c0: 1, c1: 2, c2: 3}, [c, d]: {c0: 4, c1: 5, c2: 6}} + +# Test with nested function calls +query ? +SELECT + map_from_entries( + array[ + struct( + 'outer_key1', + -- value for outer_key1: a map itself + map_from_entries( + array[ + struct('inner_a', 1), + struct('inner_b', 2) + ] + ) + ), + struct( + 'outer_key2', + -- value for outer_key2: another map + map_from_entries( + array[ + struct('inner_x', 10), + struct('inner_y', 20), + struct('inner_z', 30) + ] + ) + ) + ] + ) AS nested_map; +---- +{outer_key1: {inner_a: 1, inner_b: 2}, outer_key2: {inner_x: 10, inner_y: 20, inner_z: 30}} + +# Test with duplicate keys +query ? +SELECT map_from_entries(array( + struct(true, 'a'), + struct(false, 'b'), + struct(true, 'c'), + struct(false, cast(NULL as string)), + struct(true, 'd') +)); +---- +{false: NULL, true: d} diff --git a/datafusion/sqllogictest/test_files/spark/math/abs.slt b/datafusion/sqllogictest/test_files/spark/math/abs.slt new file mode 100644 index 0000000000000..4b9edf7e29f27 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/abs.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT abs(-1); +## PySpark 3.5.5 Result: {'abs(-1)': 1, 'typeof(abs(-1))': 'int', 'typeof(-1)': 'int'} +#query +#SELECT abs(-1::int); + +## Original Query: SELECT abs(INTERVAL -'1-1' YEAR TO MONTH); +## PySpark 3.5.5 Result: {"abs(INTERVAL '-1-1' YEAR TO MONTH)": 13, "typeof(abs(INTERVAL '-1-1' YEAR TO MONTH))": 'interval year to month', "typeof(INTERVAL '-1-1' YEAR TO MONTH)": 'interval year to month'} +#query +#SELECT abs(INTERVAL '-1-1' YEAR TO MONTH::interval year to month); diff --git a/datafusion/sqllogictest/test_files/spark/math/acos.slt b/datafusion/sqllogictest/test_files/spark/math/acos.slt new file mode 100644 index 0000000000000..76ee5694254b8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/acos.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT acos(1); +## PySpark 3.5.5 Result: {'ACOS(1)': 0.0, 'typeof(ACOS(1))': 'double', 'typeof(1)': 'int'} +#query +#SELECT acos(1::int); + +## Original Query: SELECT acos(2); +## PySpark 3.5.5 Result: {'ACOS(2)': nan, 'typeof(ACOS(2))': 'double', 'typeof(2)': 'int'} +#query +#SELECT acos(2::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/acosh.slt b/datafusion/sqllogictest/test_files/spark/math/acosh.slt new file mode 100644 index 0000000000000..45b4537419ea6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/acosh.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT acosh(0); +## PySpark 3.5.5 Result: {'ACOSH(0)': nan, 'typeof(ACOSH(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT acosh(0::int); + +## Original Query: SELECT acosh(1); +## PySpark 3.5.5 Result: {'ACOSH(1)': 0.0, 'typeof(ACOSH(1))': 'double', 'typeof(1)': 'int'} +#query +#SELECT acosh(1::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/asin.slt b/datafusion/sqllogictest/test_files/spark/math/asin.slt new file mode 100644 index 0000000000000..5c6d265ff036e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/asin.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT asin(0); +## PySpark 3.5.5 Result: {'ASIN(0)': 0.0, 'typeof(ASIN(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT asin(0::int); + +## Original Query: SELECT asin(2); +## PySpark 3.5.5 Result: {'ASIN(2)': nan, 'typeof(ASIN(2))': 'double', 'typeof(2)': 'int'} +#query +#SELECT asin(2::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/asinh.slt b/datafusion/sqllogictest/test_files/spark/math/asinh.slt new file mode 100644 index 0000000000000..7d965dea2bd77 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/asinh.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT asinh(0); +## PySpark 3.5.5 Result: {'ASINH(0)': 0.0, 'typeof(ASINH(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT asinh(0::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/atan.slt b/datafusion/sqllogictest/test_files/spark/math/atan.slt new file mode 100644 index 0000000000000..b5817b08049ce --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/atan.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT atan(0); +## PySpark 3.5.5 Result: {'ATAN(0)': 0.0, 'typeof(ATAN(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT atan(0::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/atan2.slt b/datafusion/sqllogictest/test_files/spark/math/atan2.slt new file mode 100644 index 0000000000000..eb644854c402d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/atan2.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT atan2(0, 0); +## PySpark 3.5.5 Result: {'ATAN2(0, 0)': 0.0, 'typeof(ATAN2(0, 0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT atan2(0::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/atanh.slt b/datafusion/sqllogictest/test_files/spark/math/atanh.slt new file mode 100644 index 0000000000000..7e79f8c7bee58 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/atanh.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT atanh(0); +## PySpark 3.5.5 Result: {'ATANH(0)': 0.0, 'typeof(ATANH(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT atanh(0::int); + +## Original Query: SELECT atanh(2); +## PySpark 3.5.5 Result: {'ATANH(2)': nan, 'typeof(ATANH(2))': 'double', 'typeof(2)': 'int'} +#query +#SELECT atanh(2::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/bin.slt b/datafusion/sqllogictest/test_files/spark/math/bin.slt new file mode 100644 index 0000000000000..1fa24e6cda6b0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/bin.slt @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT bin(-13); +## PySpark 3.5.5 Result: {'bin(-13)': '1111111111111111111111111111111111111111111111111111111111110011', 'typeof(bin(-13))': 'string', 'typeof(-13)': 'int'} +#query +#SELECT bin(-13::int); + +## Original Query: SELECT bin(13); +## PySpark 3.5.5 Result: {'bin(13)': '1101', 'typeof(bin(13))': 'string', 'typeof(13)': 'int'} +#query +#SELECT bin(13::int); + +## Original Query: SELECT bin(13.3); +## PySpark 3.5.5 Result: {'bin(13.3)': '1101', 'typeof(bin(13.3))': 'string', 'typeof(13.3)': 'decimal(3,1)'} +#query +#SELECT bin(13.3::decimal(3,1)); diff --git a/datafusion/sqllogictest/test_files/spark/math/bround.slt b/datafusion/sqllogictest/test_files/spark/math/bround.slt new file mode 100644 index 0000000000000..afdc9c635c9a7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/bround.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT bround(2.5, 0); +## PySpark 3.5.5 Result: {'bround(2.5, 0)': Decimal('2'), 'typeof(bround(2.5, 0))': 'decimal(2,0)', 'typeof(2.5)': 'decimal(2,1)', 'typeof(0)': 'int'} +#query +#SELECT bround(2.5::decimal(2,1), 0::int); + +## Original Query: SELECT bround(25, -1); +## PySpark 3.5.5 Result: {'bround(25, -1)': 20, 'typeof(bround(25, -1))': 'int', 'typeof(25)': 'int', 'typeof(-1)': 'int'} +#query +#SELECT bround(25::int, -1::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/cbrt.slt b/datafusion/sqllogictest/test_files/spark/math/cbrt.slt new file mode 100644 index 0000000000000..f0aea17ff0b9b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/cbrt.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT cbrt(27.0); +## PySpark 3.5.5 Result: {'CBRT(27.0)': 3.0, 'typeof(CBRT(27.0))': 'double', 'typeof(27.0)': 'decimal(3,1)'} +#query +#SELECT cbrt(27.0::decimal(3,1)); diff --git a/datafusion/sqllogictest/test_files/spark/math/ceil.slt b/datafusion/sqllogictest/test_files/spark/math/ceil.slt new file mode 100644 index 0000000000000..c87a29b61fd49 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/ceil.slt @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT ceil(-0.1); +## PySpark 3.5.5 Result: {'CEIL(-0.1)': Decimal('0'), 'typeof(CEIL(-0.1))': 'decimal(1,0)', 'typeof(-0.1)': 'decimal(1,1)'} +#query +#SELECT ceil(-0.1::decimal(1,1)); + +## Original Query: SELECT ceil(3.1411, -3); +## PySpark 3.5.5 Result: {'ceil(3.1411, -3)': Decimal('1000'), 'typeof(ceil(3.1411, -3))': 'decimal(4,0)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(-3)': 'int'} +#query +#SELECT ceil(3.1411::decimal(5,4), -3::int); + +## Original Query: SELECT ceil(3.1411, 3); +## PySpark 3.5.5 Result: {'ceil(3.1411, 3)': Decimal('3.142'), 'typeof(ceil(3.1411, 3))': 'decimal(5,3)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(3)': 'int'} +#query +#SELECT ceil(3.1411::decimal(5,4), 3::int); + +## Original Query: SELECT ceil(5); +## PySpark 3.5.5 Result: {'CEIL(5)': 5, 'typeof(CEIL(5))': 'bigint', 'typeof(5)': 'int'} +#query +#SELECT ceil(5::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/ceiling.slt b/datafusion/sqllogictest/test_files/spark/math/ceiling.slt new file mode 100644 index 0000000000000..2b761faef47df --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/ceiling.slt @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT ceiling(-0.1); +## PySpark 3.5.5 Result: {'ceiling(-0.1)': Decimal('0'), 'typeof(ceiling(-0.1))': 'decimal(1,0)', 'typeof(-0.1)': 'decimal(1,1)'} +#query +#SELECT ceiling(-0.1::decimal(1,1)); + +## Original Query: SELECT ceiling(3.1411, -3); +## PySpark 3.5.5 Result: {'ceiling(3.1411, -3)': Decimal('1000'), 'typeof(ceiling(3.1411, -3))': 'decimal(4,0)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(-3)': 'int'} +#query +#SELECT ceiling(3.1411::decimal(5,4), -3::int); + +## Original Query: SELECT ceiling(3.1411, 3); +## PySpark 3.5.5 Result: {'ceiling(3.1411, 3)': Decimal('3.142'), 'typeof(ceiling(3.1411, 3))': 'decimal(5,3)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(3)': 'int'} +#query +#SELECT ceiling(3.1411::decimal(5,4), 3::int); + +## Original Query: SELECT ceiling(5); +## PySpark 3.5.5 Result: {'ceiling(5)': 5, 'typeof(ceiling(5))': 'bigint', 'typeof(5)': 'int'} +#query +#SELECT ceiling(5::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/conv.slt b/datafusion/sqllogictest/test_files/spark/math/conv.slt new file mode 100644 index 0000000000000..371fd3e746bd3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/conv.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT conv('100', 2, 10); +## PySpark 3.5.5 Result: {'conv(100, 2, 10)': '4', 'typeof(conv(100, 2, 10))': 'string', 'typeof(100)': 'string', 'typeof(2)': 'int', 'typeof(10)': 'int'} +#query +#SELECT conv('100'::string, 2::int, 10::int); + +## Original Query: SELECT conv(-10, 16, -10); +## PySpark 3.5.5 Result: {'conv(-10, 16, -10)': '-16', 'typeof(conv(-10, 16, -10))': 'string', 'typeof(-10)': 'int', 'typeof(16)': 'int'} +#query +#SELECT conv(-10::int, 16::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/cos.slt b/datafusion/sqllogictest/test_files/spark/math/cos.slt new file mode 100644 index 0000000000000..a473c257553b1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/cos.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT cos(0); +## PySpark 3.5.5 Result: {'COS(0)': 1.0, 'typeof(COS(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT cos(0::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/cosh.slt b/datafusion/sqllogictest/test_files/spark/math/cosh.slt new file mode 100644 index 0000000000000..97b3a2eb01cb8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/cosh.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT cosh(0); +## PySpark 3.5.5 Result: {'COSH(0)': 1.0, 'typeof(COSH(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT cosh(0::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/cot.slt b/datafusion/sqllogictest/test_files/spark/math/cot.slt new file mode 100644 index 0000000000000..5bb010337addf --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/cot.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT cot(1); +## PySpark 3.5.5 Result: {'COT(1)': 0.6420926159343306, 'typeof(COT(1))': 'double', 'typeof(1)': 'int'} +#query +#SELECT cot(1::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/csc.slt b/datafusion/sqllogictest/test_files/spark/math/csc.slt new file mode 100644 index 0000000000000..b11986c3e1b9f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/csc.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT csc(1); +## PySpark 3.5.5 Result: {'CSC(1)': 1.1883951057781212, 'typeof(CSC(1))': 'double', 'typeof(1)': 'int'} +#query +#SELECT csc(1::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/degrees.slt b/datafusion/sqllogictest/test_files/spark/math/degrees.slt new file mode 100644 index 0000000000000..5ca7bacb8a6a6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/degrees.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT degrees(3.141592653589793); +## PySpark 3.5.5 Result: {'DEGREES(3.141592653589793)': 180.0, 'typeof(DEGREES(3.141592653589793))': 'double', 'typeof(3.141592653589793)': 'decimal(16,15)'} +#query +#SELECT degrees(3.141592653589793::decimal(16,15)); diff --git a/datafusion/sqllogictest/test_files/spark/math/e.slt b/datafusion/sqllogictest/test_files/spark/math/e.slt new file mode 100644 index 0000000000000..c8e23d3b0900b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/e.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT e(); +## PySpark 3.5.5 Result: {'E()': 2.718281828459045, 'typeof(E())': 'double'} +#query +#SELECT e(); diff --git a/datafusion/sqllogictest/test_files/spark/math/exp.slt b/datafusion/sqllogictest/test_files/spark/math/exp.slt new file mode 100644 index 0000000000000..671684f9855da --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/exp.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT exp(0); +## PySpark 3.5.5 Result: {'EXP(0)': 1.0, 'typeof(EXP(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT exp(0::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/expm1.slt b/datafusion/sqllogictest/test_files/spark/math/expm1.slt new file mode 100644 index 0000000000000..96d4abb0414b3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/expm1.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query R +SELECT expm1(0::INT); +---- +0 + +query R +SELECT expm1(1::INT); +---- +1.718281828459045 + +query R +SELECT expm1(a) FROM (VALUES (0::INT), (1::INT)) AS t(a); +---- +0 +1.718281828459045 diff --git a/datafusion/sqllogictest/test_files/spark/math/factorial.slt b/datafusion/sqllogictest/test_files/spark/math/factorial.slt new file mode 100644 index 0000000000000..f8eae5d95ab85 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/factorial.slt @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT factorial(5); +## PySpark 3.5.5 Result: {'factorial(5)': 120, 'typeof(factorial(5))': 'bigint', 'typeof(5)': 'int'} +query I +SELECT factorial(5::INT); +---- +120 + +query I +SELECT factorial(a) +FROM VALUES + (-1::INT), + (0::INT), (1::INT), (2::INT), (3::INT), (4::INT), (5::INT), (6::INT), (7::INT), (8::INT), (9::INT), (10::INT), + (11::INT), (12::INT), (13::INT), (14::INT), (15::INT), (16::INT), (17::INT), (18::INT), (19::INT), (20::INT), + (21::INT), + (NULL) AS t(a); +---- +NULL +1 +1 +2 +6 +24 +120 +720 +5040 +40320 +362880 +3628800 +39916800 +479001600 +6227020800 +87178291200 +1307674368000 +20922789888000 +355687428096000 +6402373705728000 +121645100408832000 +2432902008176640000 +NULL +NULL + +query error Error during planning: Failed to coerce arguments to satisfy a call to 'factorial' function +SELECT factorial(5::BIGINT); diff --git a/datafusion/sqllogictest/test_files/spark/math/floor.slt b/datafusion/sqllogictest/test_files/spark/math/floor.slt new file mode 100644 index 0000000000000..d39d47ab1fee8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/floor.slt @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT floor(-0.1); +## PySpark 3.5.5 Result: {'FLOOR(-0.1)': Decimal('-1'), 'typeof(FLOOR(-0.1))': 'decimal(1,0)', 'typeof(-0.1)': 'decimal(1,1)'} +#query +#SELECT floor(-0.1::decimal(1,1)); + +## Original Query: SELECT floor(3.1411, -3); +## PySpark 3.5.5 Result: {'floor(3.1411, -3)': Decimal('0'), 'typeof(floor(3.1411, -3))': 'decimal(4,0)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(-3)': 'int'} +#query +#SELECT floor(3.1411::decimal(5,4), -3::int); + +## Original Query: SELECT floor(3.1411, 3); +## PySpark 3.5.5 Result: {'floor(3.1411, 3)': Decimal('3.141'), 'typeof(floor(3.1411, 3))': 'decimal(5,3)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(3)': 'int'} +#query +#SELECT floor(3.1411::decimal(5,4), 3::int); + +## Original Query: SELECT floor(5); +## PySpark 3.5.5 Result: {'FLOOR(5)': 5, 'typeof(FLOOR(5))': 'bigint', 'typeof(5)': 'int'} +#query +#SELECT floor(5::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/greatest.slt b/datafusion/sqllogictest/test_files/spark/math/greatest.slt new file mode 100644 index 0000000000000..ff1143d5fcafa --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/greatest.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT greatest(10, 9, 2, 4, 3); +## PySpark 3.5.5 Result: {'greatest(10, 9, 2, 4, 3)': 10, 'typeof(greatest(10, 9, 2, 4, 3))': 'int', 'typeof(10)': 'int', 'typeof(9)': 'int', 'typeof(2)': 'int', 'typeof(4)': 'int', 'typeof(3)': 'int'} +#query +#SELECT greatest(10::int, 9::int, 2::int, 4::int, 3::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/hex.slt b/datafusion/sqllogictest/test_files/spark/math/hex.slt new file mode 100644 index 0000000000000..0fb8b92de02d4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/hex.slt @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query T +SELECT hex('Spark SQL'); +---- +537061726B2053514C + +query T +SELECT hex(1234::INT); +---- +4D2 + +query T +SELECT hex(a) from VALUES (1234::INT), (NULL), (456::INT) AS t(a); +---- +4D2 +NULL +1C8 + +query T +SELECT hex(a) from VALUES ('foo'), (NULL), ('foobarbaz') AS t(a); +---- +666F6F +NULL +666F6F62617262617A + +statement ok +CREATE TABLE t_utf8view as VALUES (arrow_cast('foo', 'Utf8View')), (NULL), (arrow_cast('foobarbaz', 'Utf8View')); + +query T +SELECT hex(column1) FROM t_utf8view; +---- +666F6F +NULL +666F6F62617262617A diff --git a/datafusion/sqllogictest/test_files/spark/math/hypot.slt b/datafusion/sqllogictest/test_files/spark/math/hypot.slt new file mode 100644 index 0000000000000..1349be0a95ee7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/hypot.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT hypot(3, 4); +## PySpark 3.5.5 Result: {'HYPOT(3, 4)': 5.0, 'typeof(HYPOT(3, 4))': 'double', 'typeof(3)': 'int', 'typeof(4)': 'int'} +#query +#SELECT hypot(3::int, 4::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/least.slt b/datafusion/sqllogictest/test_files/spark/math/least.slt new file mode 100644 index 0000000000000..f17bc2aed9885 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/least.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT least(10, 9, 2, 4, 3); +## PySpark 3.5.5 Result: {'least(10, 9, 2, 4, 3)': 2, 'typeof(least(10, 9, 2, 4, 3))': 'int', 'typeof(10)': 'int', 'typeof(9)': 'int', 'typeof(2)': 'int', 'typeof(4)': 'int', 'typeof(3)': 'int'} +#query +#SELECT least(10::int, 9::int, 2::int, 4::int, 3::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/ln.slt b/datafusion/sqllogictest/test_files/spark/math/ln.slt new file mode 100644 index 0000000000000..d3245f76736e7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/ln.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT ln(1); +## PySpark 3.5.5 Result: {'ln(1)': 0.0, 'typeof(ln(1))': 'double', 'typeof(1)': 'int'} +#query +#SELECT ln(1::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/log.slt b/datafusion/sqllogictest/test_files/spark/math/log.slt new file mode 100644 index 0000000000000..0ea3de7f1bf0d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/log.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT log(10, 100); +## PySpark 3.5.5 Result: {'LOG(10, 100)': 2.0, 'typeof(LOG(10, 100))': 'double', 'typeof(10)': 'int', 'typeof(100)': 'int'} +#query +#SELECT log(10::int, 100::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/log10.slt b/datafusion/sqllogictest/test_files/spark/math/log10.slt new file mode 100644 index 0000000000000..95e518f2eb804 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/log10.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT log10(10); +## PySpark 3.5.5 Result: {'LOG10(10)': 1.0, 'typeof(LOG10(10))': 'double', 'typeof(10)': 'int'} +#query +#SELECT log10(10::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/log1p.slt b/datafusion/sqllogictest/test_files/spark/math/log1p.slt new file mode 100644 index 0000000000000..359051c62120e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/log1p.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT log1p(0); +## PySpark 3.5.5 Result: {'LOG1P(0)': 0.0, 'typeof(LOG1P(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT log1p(0::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/log2.slt b/datafusion/sqllogictest/test_files/spark/math/log2.slt new file mode 100644 index 0000000000000..2706c0fad4bdd --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/log2.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT log2(2); +## PySpark 3.5.5 Result: {'LOG2(2)': 1.0, 'typeof(LOG2(2))': 'double', 'typeof(2)': 'int'} +#query +#SELECT log2(2::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/mod.slt b/datafusion/sqllogictest/test_files/spark/math/mod.slt new file mode 100644 index 0000000000000..2780b3e1053df --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/mod.slt @@ -0,0 +1,244 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT MOD(2, 1.8); +## PySpark 3.5.5 Result: {'mod(2, 1.8)': Decimal('0.2'), 'typeof(mod(2, 1.8))': 'decimal(2,1)', 'typeof(2)': 'int', 'typeof(1.8)': 'decimal(2,1)'} +query R +SELECT MOD(2::int, 1.8::decimal(2,1)); +---- +0.2 + +# Basic integer modulo operations +query I +SELECT MOD(10::int, 3::int) as mod_1; +---- +1 + +query I +SELECT MOD(7::int, 2::int) as mod_2; +---- +1 + +query I +SELECT MOD(15::int, 4::int) as mod_3; +---- +3 + +query I +SELECT MOD(100::int, 30::int) as mod_4; +---- +10 + +query I +SELECT MOD(50::int, 25::int) as mod_5; +---- +0 + +query I +SELECT MOD(200::int, 60::int) as mod_6; +---- +20 + +# Float modulo operations +query R +SELECT MOD(10.5::float8, 3.0::float8) as mod_float_1; +---- +1.5 + +query R +SELECT MOD(7.2::float8, 2.5::float8) as mod_float_2; +---- +2.2 + +query R +SELECT MOD(15.8::float8, 4.2::float8) as mod_float_3; +---- +3.2 + +# Mixed type operations +query R +SELECT MOD(10::int, 3.0::float8) as mod_mixed_1; +---- +1 + +query R +SELECT MOD(10.5::float8, 3::int) as mod_mixed_2; +---- +1.5 + +# NULL value handling +query I +SELECT MOD(NULL::int, 3::int) as mod_null_1; +---- +NULL + +query I +SELECT MOD(10::int, NULL::int) as mod_null_2; +---- +NULL + +query I +SELECT MOD(NULL::int, NULL::int) as mod_null_3; +---- +NULL + +# Special values: NaN and Infinity +query R +SELECT MOD(5.0::float8, 'NaN'::float8) as mod_nan_1; +---- +NaN + +query R +SELECT MOD('NaN'::float8, 2.0::float8) as mod_nan_2; +---- +NaN + +query R +SELECT MOD('NaN'::float8, 'Infinity'::float8) as mod_nan_3; +---- +NaN + +query R +SELECT MOD('Infinity'::float8, 'NaN'::float8) as mod_nan_4; +---- +NaN + +query R +SELECT MOD(5.0::float8, 'Infinity'::float8) as mod_inf_1; +---- +5 + +query R +SELECT MOD('Infinity'::float8, 2.0::float8) as mod_inf_2; +---- +NaN + +# Decimal operations +query R +SELECT MOD(2.5::decimal(3,1), 1.2::decimal(2,1)) as mod_decimal_1; +---- +0.1 + +query R +SELECT MOD(10.0::decimal(3,1), 3.0::decimal(2,1)) as mod_decimal_2; +---- +1 + +# Edge cases +query I +SELECT MOD(0::int, 5::int) as mod_zero_1; +---- +0 + +query I +SELECT MOD(5::int, 1::int) as mod_zero_2; +---- +0 + +query I +SELECT MOD(-10::int, 3::int) as mod_negative_1; +---- +-1 + +query I +SELECT MOD(10::int, -3::int) as mod_negative_2; +---- +1 + +query I +SELECT MOD(-10::int, -3::int) as mod_negative_3; +---- +-1 + +# Multiple MOD operations +query I +SELECT MOD(MOD(100::int, 30::int), 5::int) as mod_nested_1; +---- +0 + +query I +SELECT MOD(10::int, MOD(7::int, 3::int)) as mod_nested_2; +---- +0 + +# MOD with different data types +query I +SELECT MOD(10::int8, 3::int8) as mod_int8; +---- +1 + +query I +SELECT MOD(arrow_cast(10, 'Int16'), arrow_cast(3, 'Int16')) as mod_int16; +---- +1 + +query I +SELECT MOD(arrow_cast(10, 'Int32'), arrow_cast(3, 'Int32')) as mod_int32; +---- +1 + +query I +SELECT MOD(arrow_cast(10, 'Int64'), arrow_cast(3, 'Int64')) as mod_int64; +---- +1 + +query I +SELECT MOD(arrow_cast(10, 'UInt16'), arrow_cast(3, 'UInt16')) as mod_int16; +---- +1 + +query I +SELECT MOD(arrow_cast(10, 'UInt32'), arrow_cast(3, 'UInt32')) as mod_int32; +---- +1 + +query I +SELECT MOD(arrow_cast(10, 'UInt64'), arrow_cast(3, 'UInt64')) as mod_int64; +---- +1 + +query R +SELECT MOD(10::float4, 3::float4) as mod_float4; +---- +1 + +query R +SELECT MOD(10::float8, 3::float8) as mod_float8; +---- +1 + +# MOD in expressions +query I +SELECT MOD(10::int + 5::int, 3::int) as mod_expr_1; +---- +0 + +query I +SELECT MOD(10::int, 2::int + 1::int) as mod_expr_2; +---- +1 + +query I +SELECT MOD(10::int * 2::int, 5::int) as mod_expr_3; +---- +0 diff --git a/datafusion/sqllogictest/test_files/spark/math/negative.slt b/datafusion/sqllogictest/test_files/spark/math/negative.slt new file mode 100644 index 0000000000000..aa8e558e9895e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/negative.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT negative(1); +## PySpark 3.5.5 Result: {'negative(1)': -1, 'typeof(negative(1))': 'int', 'typeof(1)': 'int'} +#query +#SELECT negative(1::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/pi.slt b/datafusion/sqllogictest/test_files/spark/math/pi.slt new file mode 100644 index 0000000000000..4b94e09bc9383 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/pi.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT pi(); +## PySpark 3.5.5 Result: {'PI()': 3.141592653589793, 'typeof(PI())': 'double'} +#query +#SELECT pi(); diff --git a/datafusion/sqllogictest/test_files/spark/math/pmod.slt b/datafusion/sqllogictest/test_files/spark/math/pmod.slt new file mode 100644 index 0000000000000..cf273c2d78f53 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/pmod.slt @@ -0,0 +1,334 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +# Basic PMOD tests with positive integers +query I +SELECT pmod(10::int, 3::int) as pmod_1; +---- +1 + +query I +SELECT pmod(7::int, 2::int) as pmod_2; +---- +1 + +query I +SELECT pmod(15::int, 4::int) as pmod_3; +---- +3 + +# PMOD tests with negative integers (should return positive remainder) +query I +SELECT pmod(-10::int, 3::int) as pmod_negative_1; +---- +2 + +query I +SELECT pmod(-7::int, 3::int) as pmod_negative_2; +---- +2 + +query I +SELECT pmod(-15::int, 4::int) as pmod_negative_3; +---- +1 + +query I +SELECT pmod(-5::int, 5::int) as pmod_negative_4; +---- +0 + +# PMOD tests with zero +query I +SELECT pmod(0::int, 5::int) as pmod_zero_1; +---- +0 + +statement error DataFusion error: Arrow error: Divide by zero error +SELECT pmod(10::int, 0::int) as pmod_zero_2; + +# PMOD tests with NULL values +query I +SELECT pmod(NULL::int, 3::int) as pmod_null_1; +---- +NULL + +query I +SELECT pmod(10::int, NULL::int) as pmod_null_2; +---- +NULL + +query I +SELECT pmod(NULL::int, NULL::int) as pmod_null_3; +---- +NULL + +# PMOD tests with large integers +query I +SELECT pmod(100::int, 30::int) as pmod_large_1; +---- +10 + +query I +SELECT pmod(-100::int, 30::int) as pmod_large_2; +---- +20 + +query I +SELECT pmod(200::int, 60::int) as pmod_large_3; +---- +20 + +query I +SELECT pmod(-200::int, 60::int) as pmod_large_4; +---- +40 + +# PMOD tests with edge cases +query I +SELECT pmod(-1::int, 5::int) as pmod_edge_1; +---- +4 + +query I +SELECT pmod(1::int, 5::int) as pmod_edge_2; +---- +1 + +query I +SELECT pmod(-5::int, 5::int) as pmod_edge_3; +---- +0 + +query I +SELECT pmod(5::int, 5::int) as pmod_edge_4; +---- +0 + +query I +SELECT pmod(-6::int, 5::int) as pmod_edge_5; +---- +4 + +query I +SELECT pmod(6::int, 5::int) as pmod_edge_6; +---- +1 + +# PMOD tests with negative divisors +query I +SELECT pmod(10::int, -3::int) as pmod_neg_div_1; +---- +1 + +query I +SELECT pmod(-7::int, -3::int) as pmod_neg_div_2; +---- +-1 + +query I +SELECT pmod(15::int, -4::int) as pmod_neg_div_3; +---- +3 + +# PMOD tests with floating point numbers +query R +SELECT pmod(10.5::float8, 3.0::float8) as pmod_float_1; +---- +1.5 + +query R +SELECT pmod(-7.2::float8, 3.0::float8) as pmod_float_2; +---- +1.8 + +query R +SELECT pmod(15.8::float8, 4.2::float8) as pmod_float_3; +---- +3.2 + +query R +SELECT pmod(-15.8::float8, 4.2::float8) as pmod_float_4; +---- +1 + +query R +SELECT pmod(5.0::float8, 2.5::float8) as pmod_float_5; +---- +0 + +query R +SELECT pmod(-5.0::float8, 2.5::float8) as pmod_float_6; +---- +0 + +# PMOD tests with float32 +query R +SELECT pmod(10.5::float4, 3.0::float4) as pmod_float32_1; +---- +1.5 + +query R +SELECT CAST(pmod(CAST(-7.2 AS float4), CAST(3.0 AS float4)) AS DECIMAL(3,1)) as pmod_float32_2; +---- +1.8 + +query R +SELECT CAST(pmod(15.8::float4, 4.2::float4) AS DECIMAL(3,1)) as pmod_float32_3; +---- +3.2 + +query R +SELECT CAST(pmod(-15.8::float4, 4.2::float4) AS DECIMAL(3,1)) as pmod_float32_4; +---- +1 + +# PMOD tests with special float values +query R +SELECT pmod('NaN'::float8, 2.0::float8) as pmod_nan_1; +---- +NaN + +query R +SELECT pmod(5.0::float8, 'NaN'::float8) as pmod_nan_2; +---- +NaN + +query R +SELECT pmod('Infinity'::float8, 2.0::float8) as pmod_inf_1; +---- +NaN + +query R +SELECT pmod(5.0::float8, 'Infinity'::float8) as pmod_inf_2; +---- +5 + +query R +SELECT pmod(-5.0::float8, 'Infinity'::float8) as pmod_inf_3; +---- +NaN + +query R +SELECT pmod('NaN'::float8, 'Infinity'::float8) as pmod_nan_inf_1; +---- +NaN + +query R +SELECT pmod('Infinity'::float8, 'NaN'::float8) as pmod_inf_nan_1; +---- +NaN + +# PMOD tests with decimal types +query R +SELECT pmod(2.5::decimal(3,1), 1.2::decimal(2,1)) as pmod_decimal_1; +---- +0.1 + +query R +SELECT pmod(-2.5::decimal(3,1), 1.2::decimal(2,1)) as pmod_decimal_2; +---- +1.1 + +query R +SELECT pmod(10.0::decimal(3,1), 3.0::decimal(2,1)) as pmod_decimal_3; +---- +1 + +query R +SELECT pmod(-10.0::decimal(3,1), 3.0::decimal(2,1)) as pmod_decimal_4; +---- +2 + +# PMOD tests with different integer types +query I +SELECT pmod(10::int8, 3::int8) as pmod_int8_1; +---- +1 + +query I +SELECT pmod(-10::int8, 3::int8) as pmod_int8_2; +---- +2 + +query I +SELECT pmod(arrow_cast(10, 'Int16'), arrow_cast(3, 'Int16')) as pmod_int16_1; +---- +1 + +query I +SELECT pmod(arrow_cast(-10, 'Int16'), arrow_cast(3, 'Int16')) as pmod_int16_2; +---- +2 + +query I +SELECT pmod(arrow_cast(10, 'Int64'), arrow_cast(3, 'Int64')) as pmod_int64_1; +---- +1 + +query I +SELECT pmod(arrow_cast(-10, 'Int64'), arrow_cast(3, 'Int64')) as pmod_int64_2; +---- +2 + +# PMOD tests with unsigned integers +query I +SELECT pmod(arrow_cast(10, 'UInt8'), arrow_cast(3, 'UInt8')) as pmod_uint8_1; +---- +1 + +query I +SELECT pmod(arrow_cast(10, 'UInt16'), arrow_cast(3, 'UInt16')) as pmod_uint16_1; +---- +1 + +query I +SELECT pmod(arrow_cast(10, 'UInt32'), arrow_cast(3, 'UInt32')) as pmod_uint32_1; +---- +1 + +query I +SELECT pmod(arrow_cast(10, 'UInt64'), arrow_cast(3, 'UInt64')) as pmod_uint64_1; +---- +1 + +# PMOD tests with scalar values +query I +SELECT pmod(10, 3) as pmod_scalar_1; +---- +1 + +query I +SELECT pmod(-10, 3) as pmod_scalar_2; +---- +2 + +query R +SELECT pmod(10.5, 3.0) as pmod_scalar_3; +---- +1.5 + +query R +SELECT pmod(-7.2, 3.0) as pmod_scalar_4; +---- +1.8 diff --git a/datafusion/sqllogictest/test_files/spark/math/positive.slt b/datafusion/sqllogictest/test_files/spark/math/positive.slt new file mode 100644 index 0000000000000..5e1be0f4b4678 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/positive.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT positive(1); +## PySpark 3.5.5 Result: {'(+ 1)': 1, 'typeof((+ 1))': 'int', 'typeof(1)': 'int'} +#query +#SELECT positive(1::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/pow.slt b/datafusion/sqllogictest/test_files/spark/math/pow.slt new file mode 100644 index 0000000000000..55b6f65b81235 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/pow.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT pow(2, 3); +## PySpark 3.5.5 Result: {'pow(2, 3)': 8.0, 'typeof(pow(2, 3))': 'double', 'typeof(2)': 'int', 'typeof(3)': 'int'} +#query +#SELECT pow(2::int, 3::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/power.slt b/datafusion/sqllogictest/test_files/spark/math/power.slt new file mode 100644 index 0000000000000..f82056c6d941b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/power.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT power(2, 3); +## PySpark 3.5.5 Result: {'POWER(2, 3)': 8.0, 'typeof(POWER(2, 3))': 'double', 'typeof(2)': 'int', 'typeof(3)': 'int'} +#query +#SELECT power(2::int, 3::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/radians.slt b/datafusion/sqllogictest/test_files/spark/math/radians.slt new file mode 100644 index 0000000000000..bccda62c542ff --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/radians.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT radians(180); +## PySpark 3.5.5 Result: {'RADIANS(180)': 3.141592653589793, 'typeof(RADIANS(180))': 'double', 'typeof(180)': 'int'} +#query +#SELECT radians(180::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/rand.slt b/datafusion/sqllogictest/test_files/spark/math/rand.slt new file mode 100644 index 0000000000000..53b4c6f822218 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/rand.slt @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT rand(); +## PySpark 3.5.5 Result: {'rand()': 0.949892358232337, 'typeof(rand())': 'double'} +#query +#SELECT rand(); + +## Original Query: SELECT rand(0); +## PySpark 3.5.5 Result: {'rand(0)': 0.7604953758285915, 'typeof(rand(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT rand(0::int); + +## Original Query: SELECT rand(null); +## PySpark 3.5.5 Result: {'rand(NULL)': 0.7604953758285915, 'typeof(rand(NULL))': 'double', 'typeof(NULL)': 'void'} +#query +#SELECT rand(NULL::void); diff --git a/datafusion/sqllogictest/test_files/spark/math/randn.slt b/datafusion/sqllogictest/test_files/spark/math/randn.slt new file mode 100644 index 0000000000000..daf81babd02c4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/randn.slt @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT randn(); +## PySpark 3.5.5 Result: {'randn()': 1.498983714060803, 'typeof(randn())': 'double'} +#query +#SELECT randn(); + +## Original Query: SELECT randn(0); +## PySpark 3.5.5 Result: {'randn(0)': 1.6034991609278433, 'typeof(randn(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT randn(0::int); + +## Original Query: SELECT randn(null); +## PySpark 3.5.5 Result: {'randn(NULL)': 1.6034991609278433, 'typeof(randn(NULL))': 'double', 'typeof(NULL)': 'void'} +#query +#SELECT randn(NULL::void); diff --git a/datafusion/sqllogictest/test_files/spark/math/random.slt b/datafusion/sqllogictest/test_files/spark/math/random.slt new file mode 100644 index 0000000000000..280a81b8888c0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/random.slt @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT random(); +## PySpark 3.5.5 Result: {'rand()': 0.7460731389309176, 'typeof(rand())': 'double'} +#query +#SELECT random(); + +## Original Query: SELECT random(0); +## PySpark 3.5.5 Result: {'rand(0)': 0.7604953758285915, 'typeof(rand(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT random(0::int); + +## Original Query: SELECT random(null); +## PySpark 3.5.5 Result: {'rand(NULL)': 0.7604953758285915, 'typeof(rand(NULL))': 'double', 'typeof(NULL)': 'void'} +#query +#SELECT random(NULL::void); diff --git a/datafusion/sqllogictest/test_files/spark/math/rint.slt b/datafusion/sqllogictest/test_files/spark/math/rint.slt new file mode 100644 index 0000000000000..2cae3cbf58fd3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/rint.slt @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT rint(12.3456); +## PySpark 3.5.5 Result: {'rint(12.3456)': 12.0, 'typeof(rint(12.3456))': 'double', 'typeof(12.3456)': 'decimal(6,4)'} +query R +SELECT rint(12.3456); +---- +12 + +## Test additional cases +query R +SELECT rint(-12.3456); +---- +-12 + +query R +SELECT rint(arrow_cast(-12.3456, 'Float32')); +---- +-12 + +## Test int +query R +SELECT rint(arrow_cast(12, 'UInt8')); +---- +12 + +query R +SELECT rint(arrow_cast(-12, 'Int8')); +---- +-12 + +query R +SELECT rint(arrow_cast(12, 'UInt16')); +---- +12 + +query R +SELECT rint(arrow_cast(-12, 'Int16')); +---- +-12 + +query R +SELECT rint(arrow_cast(12, 'UInt32')); +---- +12 + +query R +SELECT rint(arrow_cast(-12, 'Int32')); +---- +-12 + +query R +SELECT rint(arrow_cast(12, 'UInt64')); +---- +12 + +query R +SELECT rint(arrow_cast(-12, 'Int64')); +---- +-12 + +query R +SELECT rint(2.5); +---- +2 + +query R +SELECT rint(3.5); +---- +4 + +query R +SELECT rint(-2.5); +---- +-2 + +query R +SELECT rint(-3.5); +---- +-4 + +query R +SELECT rint(0.0); +---- +0 + +query R +SELECT rint(42); +---- +42 + +## Test with null +query R +SELECT rint(NULL); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/math/round.slt b/datafusion/sqllogictest/test_files/spark/math/round.slt new file mode 100644 index 0000000000000..bc1f6b72247a0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/round.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT round(2.5, 0); +## PySpark 3.5.5 Result: {'round(2.5, 0)': Decimal('3'), 'typeof(round(2.5, 0))': 'decimal(2,0)', 'typeof(2.5)': 'decimal(2,1)', 'typeof(0)': 'int'} +#query +#SELECT round(2.5::decimal(2,1), 0::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/sec.slt b/datafusion/sqllogictest/test_files/spark/math/sec.slt new file mode 100644 index 0000000000000..6c49a34549f0f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/sec.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sec(0); +## PySpark 3.5.5 Result: {'SEC(0)': 1.0, 'typeof(SEC(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT sec(0::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/shiftleft.slt b/datafusion/sqllogictest/test_files/spark/math/shiftleft.slt new file mode 100644 index 0000000000000..3676e4c18153c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/shiftleft.slt @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT shiftleft(2, 1); +## PySpark 3.5.5 Result: {'shiftleft(2, 1)': 4, 'typeof(shiftleft(2, 1))': 'int', 'typeof(2)': 'int', 'typeof(1)': 'int'} + +# Basic shiftleft tests +query I +SELECT shiftleft(2::int, 1::int); +---- +4 + +query I +SELECT shiftleft(1::int, 2::int); +---- +4 + +query I +SELECT shiftleft(3::int, 3::int); +---- +24 + +# Different data types +query I +SELECT shiftleft(2::bigint, 1::int); +---- +4 + +query I +SELECT shiftleft(1::bigint, 2::int); +---- +4 + +query I +SELECT shiftleft(2::int, 1::bigint); +---- +4 + +# Large shifts (should handle modulo correctly) +query I +SELECT shiftleft(1::int, 32::int); +---- +1 + +query I +SELECT shiftleft(2::int, 33::int); +---- +4 + +query I +SELECT shiftleft(3::int, 64::int); +---- +3 + +# Negative shifts +query I +SELECT shiftleft(4::int, -1::int); +---- +0 + +query I +SELECT shiftleft(8::int, -2::int); +---- +0 + +query I +SELECT shiftleft(16::int, -3::int); +---- +0 + +# Zero shifts +query I +SELECT shiftleft(5::int, 0::int); +---- +5 + +query I +SELECT shiftleft(0::int, 5::int); +---- +0 + +# Edge cases +query I +SELECT shiftleft(2147483647::int, 1::int); +---- +-2 + +query I +SELECT shiftleft(-1::int, 1::int); +---- +-2 + +# Multiple values in a table +query I +SELECT shiftleft(value, shift) FROM (VALUES (1, 1), (2, 2), (3, 3), (4, 4)) AS t(value, shift); +---- +2 +8 +24 +64 + +# Null handling +query I +SELECT shiftleft(NULL::int, 1::int); +---- +NULL + +query I +SELECT shiftleft(1::int, NULL::int); +---- +NULL + +query I +SELECT shiftleft(NULL::int, NULL::int); +---- +NULL + +query I +select shiftleft(3::int,-31); +---- +6 + +query I +select shiftleft(3::int,-32); +---- +3 diff --git a/datafusion/sqllogictest/test_files/spark/math/sign.slt b/datafusion/sqllogictest/test_files/spark/math/sign.slt new file mode 100644 index 0000000000000..e135f4b13d063 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/sign.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sign(40); +## PySpark 3.5.5 Result: {'sign(40)': 1.0, 'typeof(sign(40))': 'double', 'typeof(40)': 'int'} +#query +#SELECT sign(40::int); + +## Original Query: SELECT sign(INTERVAL -'100' YEAR); +## PySpark 3.5.5 Result: {"sign(INTERVAL '-100' YEAR)": -1.0, "typeof(sign(INTERVAL '-100' YEAR))": 'double', "typeof(INTERVAL '-100' YEAR)": 'interval year'} +#query +#SELECT sign(INTERVAL '-100' YEAR::interval year); diff --git a/datafusion/sqllogictest/test_files/spark/math/signum.slt b/datafusion/sqllogictest/test_files/spark/math/signum.slt new file mode 100644 index 0000000000000..5557f5fe32721 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/signum.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT signum(40); +## PySpark 3.5.5 Result: {'SIGNUM(40)': 1.0, 'typeof(SIGNUM(40))': 'double', 'typeof(40)': 'int'} +#query +#SELECT signum(40::int); + +## Original Query: SELECT signum(INTERVAL -'100' YEAR); +## PySpark 3.5.5 Result: {"SIGNUM(INTERVAL '-100' YEAR)": -1.0, "typeof(SIGNUM(INTERVAL '-100' YEAR))": 'double', "typeof(INTERVAL '-100' YEAR)": 'interval year'} +#query +#SELECT signum(INTERVAL '-100' YEAR::interval year); diff --git a/datafusion/sqllogictest/test_files/spark/math/sin.slt b/datafusion/sqllogictest/test_files/spark/math/sin.slt new file mode 100644 index 0000000000000..418a6fafdff8d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/sin.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sin(0); +## PySpark 3.5.5 Result: {'SIN(0)': 0.0, 'typeof(SIN(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT sin(0::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/sinh.slt b/datafusion/sqllogictest/test_files/spark/math/sinh.slt new file mode 100644 index 0000000000000..6d24d387e210c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/sinh.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sinh(0); +## PySpark 3.5.5 Result: {'SINH(0)': 0.0, 'typeof(SINH(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT sinh(0::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/sqrt.slt b/datafusion/sqllogictest/test_files/spark/math/sqrt.slt new file mode 100644 index 0000000000000..10b896eec9651 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/sqrt.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sqrt(4); +## PySpark 3.5.5 Result: {'SQRT(4)': 2.0, 'typeof(SQRT(4))': 'double', 'typeof(4)': 'int'} +#query +#SELECT sqrt(4::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/tan.slt b/datafusion/sqllogictest/test_files/spark/math/tan.slt new file mode 100644 index 0000000000000..4699893d2bd59 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/tan.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT tan(0); +## PySpark 3.5.5 Result: {'TAN(0)': 0.0, 'typeof(TAN(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT tan(0::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/tanh.slt b/datafusion/sqllogictest/test_files/spark/math/tanh.slt new file mode 100644 index 0000000000000..1511adb5b3724 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/tanh.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT tanh(0); +## PySpark 3.5.5 Result: {'TANH(0)': 0.0, 'typeof(TANH(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT tanh(0::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/try_add.slt b/datafusion/sqllogictest/test_files/spark/math/try_add.slt new file mode 100644 index 0000000000000..f3f83158289fa --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/try_add.slt @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT try_add(1, 2); +## PySpark 3.5.5 Result: {'try_add(1, 2)': 3, 'typeof(try_add(1, 2))': 'int', 'typeof(1)': 'int', 'typeof(2)': 'int'} +#query +#SELECT try_add(1::int, 2::int); + +## Original Query: SELECT try_add(2147483647, 1); +## PySpark 3.5.5 Result: {'try_add(2147483647, 1)': None, 'typeof(try_add(2147483647, 1))': 'int', 'typeof(2147483647)': 'int', 'typeof(1)': 'int'} +#query +#SELECT try_add(2147483647::int, 1::int); + +## Original Query: SELECT try_add(date'2021-01-01', 1); +## PySpark 3.5.5 Result: {"try_add(DATE '2021-01-01', 1)": datetime.date(2021, 1, 2), "typeof(try_add(DATE '2021-01-01', 1))": 'date', "typeof(DATE '2021-01-01')": 'date', 'typeof(1)': 'int'} +#query +#SELECT try_add(DATE '2021-01-01'::date, 1::int); + +## Original Query: SELECT try_add(date'2021-01-01', interval 1 year); +## PySpark 3.5.5 Result: {"try_add(DATE '2021-01-01', INTERVAL '1' YEAR)": datetime.date(2022, 1, 1), "typeof(try_add(DATE '2021-01-01', INTERVAL '1' YEAR))": 'date', "typeof(DATE '2021-01-01')": 'date', "typeof(INTERVAL '1' YEAR)": 'interval year'} +#query +#SELECT try_add(DATE '2021-01-01'::date, INTERVAL '1' YEAR::interval year); + +## Original Query: SELECT try_add(interval 1 year, interval 2 year); +## PySpark 3.5.5 Result: {"try_add(INTERVAL '1' YEAR, INTERVAL '2' YEAR)": 36, "typeof(try_add(INTERVAL '1' YEAR, INTERVAL '2' YEAR))": 'interval year', "typeof(INTERVAL '1' YEAR)": 'interval year', "typeof(INTERVAL '2' YEAR)": 'interval year'} +#query +#SELECT try_add(INTERVAL '1' YEAR::interval year, INTERVAL '2' YEAR::interval year); + +## Original Query: SELECT try_add(timestamp'2021-01-01 00:00:00', interval 1 day); +## PySpark 3.5.5 Result: {"try_add(TIMESTAMP '2021-01-01 00:00:00', INTERVAL '1' DAY)": datetime.datetime(2021, 1, 2, 0, 0), "typeof(try_add(TIMESTAMP '2021-01-01 00:00:00', INTERVAL '1' DAY))": 'timestamp', "typeof(TIMESTAMP '2021-01-01 00:00:00')": 'timestamp', "typeof(INTERVAL '1' DAY)": 'interval day'} +#query +#SELECT try_add(TIMESTAMP '2021-01-01 00:00:00'::timestamp, INTERVAL '1' DAY::interval day); diff --git a/datafusion/sqllogictest/test_files/spark/math/try_divide.slt b/datafusion/sqllogictest/test_files/spark/math/try_divide.slt new file mode 100644 index 0000000000000..405872f9ca0f8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/try_divide.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT try_divide(1, 0); +## PySpark 3.5.5 Result: {'try_divide(1, 0)': None, 'typeof(try_divide(1, 0))': 'double', 'typeof(1)': 'int', 'typeof(0)': 'int'} +#query +#SELECT try_divide(1::int, 0::int); + +## Original Query: SELECT try_divide(2L, 2L); +## PySpark 3.5.5 Result: {'try_divide(2, 2)': 1.0, 'typeof(try_divide(2, 2))': 'double', 'typeof(2)': 'bigint'} +#query +#SELECT try_divide(2::bigint); + +## Original Query: SELECT try_divide(3, 2); +## PySpark 3.5.5 Result: {'try_divide(3, 2)': 1.5, 'typeof(try_divide(3, 2))': 'double', 'typeof(3)': 'int', 'typeof(2)': 'int'} +#query +#SELECT try_divide(3::int, 2::int); + +## Original Query: SELECT try_divide(interval 2 month, 0); +## PySpark 3.5.5 Result: {"try_divide(INTERVAL '2' MONTH, 0)": None, "typeof(try_divide(INTERVAL '2' MONTH, 0))": 'interval year to month', "typeof(INTERVAL '2' MONTH)": 'interval month', 'typeof(0)': 'int'} +#query +#SELECT try_divide(INTERVAL '2' MONTH::interval month, 0::int); + +## Original Query: SELECT try_divide(interval 2 month, 2); +## PySpark 3.5.5 Result: {"try_divide(INTERVAL '2' MONTH, 2)": 1, "typeof(try_divide(INTERVAL '2' MONTH, 2))": 'interval year to month', "typeof(INTERVAL '2' MONTH)": 'interval month', 'typeof(2)': 'int'} +#query +#SELECT try_divide(INTERVAL '2' MONTH::interval month, 2::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/try_multiply.slt b/datafusion/sqllogictest/test_files/spark/math/try_multiply.slt new file mode 100644 index 0000000000000..c495a758e2346 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/try_multiply.slt @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT try_multiply(-2147483648, 10); +## PySpark 3.5.5 Result: {'try_multiply(-2147483648, 10)': None, 'typeof(try_multiply(-2147483648, 10))': 'int', 'typeof(-2147483648)': 'int', 'typeof(10)': 'int'} +#query +#SELECT try_multiply(-2147483648::int, 10::int); + +## Original Query: SELECT try_multiply(2, 3); +## PySpark 3.5.5 Result: {'try_multiply(2, 3)': 6, 'typeof(try_multiply(2, 3))': 'int', 'typeof(2)': 'int', 'typeof(3)': 'int'} +#query +#SELECT try_multiply(2::int, 3::int); + +## Original Query: SELECT try_multiply(interval 2 year, 3); +## PySpark 3.5.5 Result: {"try_multiply(INTERVAL '2' YEAR, 3)": 72, "typeof(try_multiply(INTERVAL '2' YEAR, 3))": 'interval year to month', "typeof(INTERVAL '2' YEAR)": 'interval year', 'typeof(3)': 'int'} +#query +#SELECT try_multiply(INTERVAL '2' YEAR::interval year, 3::int); diff --git a/datafusion/sqllogictest/test_files/spark/math/try_subtract.slt b/datafusion/sqllogictest/test_files/spark/math/try_subtract.slt new file mode 100644 index 0000000000000..4ce4c480b91c2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/try_subtract.slt @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT try_subtract(-2147483648, 1); +## PySpark 3.5.5 Result: {'try_subtract(-2147483648, 1)': None, 'typeof(try_subtract(-2147483648, 1))': 'int', 'typeof(-2147483648)': 'int', 'typeof(1)': 'int'} +#query +#SELECT try_subtract(-2147483648::int, 1::int); + +## Original Query: SELECT try_subtract(2, 1); +## PySpark 3.5.5 Result: {'try_subtract(2, 1)': 1, 'typeof(try_subtract(2, 1))': 'int', 'typeof(2)': 'int', 'typeof(1)': 'int'} +#query +#SELECT try_subtract(2::int, 1::int); + +## Original Query: SELECT try_subtract(date'2021-01-01', interval 1 year); +## PySpark 3.5.5 Result: {"try_subtract(DATE '2021-01-01', INTERVAL '1' YEAR)": datetime.date(2020, 1, 1), "typeof(try_subtract(DATE '2021-01-01', INTERVAL '1' YEAR))": 'date', "typeof(DATE '2021-01-01')": 'date', "typeof(INTERVAL '1' YEAR)": 'interval year'} +#query +#SELECT try_subtract(DATE '2021-01-01'::date, INTERVAL '1' YEAR::interval year); + +## Original Query: SELECT try_subtract(date'2021-01-02', 1); +## PySpark 3.5.5 Result: {"try_subtract(DATE '2021-01-02', 1)": datetime.date(2021, 1, 1), "typeof(try_subtract(DATE '2021-01-02', 1))": 'date', "typeof(DATE '2021-01-02')": 'date', 'typeof(1)': 'int'} +#query +#SELECT try_subtract(DATE '2021-01-02'::date, 1::int); + +## Original Query: SELECT try_subtract(interval 2 year, interval 1 year); +## PySpark 3.5.5 Result: {"try_subtract(INTERVAL '2' YEAR, INTERVAL '1' YEAR)": 12, "typeof(try_subtract(INTERVAL '2' YEAR, INTERVAL '1' YEAR))": 'interval year', "typeof(INTERVAL '2' YEAR)": 'interval year', "typeof(INTERVAL '1' YEAR)": 'interval year'} +#query +#SELECT try_subtract(INTERVAL '2' YEAR::interval year, INTERVAL '1' YEAR::interval year); + +## Original Query: SELECT try_subtract(timestamp'2021-01-02 00:00:00', interval 1 day); +## PySpark 3.5.5 Result: {"try_subtract(TIMESTAMP '2021-01-02 00:00:00', INTERVAL '1' DAY)": datetime.datetime(2021, 1, 1, 0, 0), "typeof(try_subtract(TIMESTAMP '2021-01-02 00:00:00', INTERVAL '1' DAY))": 'timestamp', "typeof(TIMESTAMP '2021-01-02 00:00:00')": 'timestamp', "typeof(INTERVAL '1' DAY)": 'interval day'} +#query +#SELECT try_subtract(TIMESTAMP '2021-01-02 00:00:00'::timestamp, INTERVAL '1' DAY::interval day); diff --git a/datafusion/sqllogictest/test_files/spark/math/width_bucket.slt b/datafusion/sqllogictest/test_files/spark/math/width_bucket.slt new file mode 100644 index 0000000000000..d2661ceb9d3bb --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/width_bucket.slt @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +query I +SELECT width_bucket(-0.9, 5.2, 0.5, 2) +---- +3 + +query I +SELECT width_bucket(-2.1, 1.3, 3.4, 3) +---- +0 + +query I +SELECT width_bucket(5.3, 0.2, 10.6, 5) +---- +3 + +query I +SELECT width_bucket(8.1, 0.0, 5.7, 4) +---- +5 + +query I +SELECT width_bucket(INTERVAL '0' DAY, INTERVAL '0' DAY, INTERVAL '10' DAY, 10) +---- +1 + +query I +SELECT width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10) +---- +1 + +query I +SELECT width_bucket(INTERVAL '1' DAY, INTERVAL '0' DAY, INTERVAL '10' DAY, 10) +---- +2 + +query I +SELECT width_bucket(INTERVAL '1' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10) +---- +2 + +# test of sail +query I +SELECT width_bucket(0.0, 10.0, 0.0, 5) +---- +6 + +query I +SELECT width_bucket(10.0, 0.0, 10.0, 5) +---- +6 + +query I +SELECT width_bucket(10.0, 0.0, 0.0, 5) +---- +NULL + +# lo == hi +query I +SELECT width_bucket(10.0, 0.0, 0.0, 5); +---- +NULL + +# n <= 0 +query I +SELECT width_bucket(5.0, 0.0, 10.0, 0); +---- +NULL + +query I +SELECT width_bucket(arrow_cast('NaN','Float64'),5.0, 0.0, 5) +---- +NULL + +query I +SELECT width_bucket(5.0, arrow_cast('NaN','Float64'), 0.0, 5) +---- +NULL + +query I +SELECT width_bucket(5.0, 0.0, arrow_cast('NaN','Float64'), 5) +---- +NULL + +query I +SELECT width_bucket(INTERVAL '1' YEAR, INTERVAL '5' YEAR, INTERVAL '5' YEAR, 10) +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/misc/assert_true.slt b/datafusion/sqllogictest/test_files/spark/misc/assert_true.slt new file mode 100644 index 0000000000000..99330233aabdd --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/assert_true.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT assert_true(0 < 1); +## PySpark 3.5.5 Result: {"assert_true((0 < 1), '(0 < 1)' is not true!)": None, "typeof(assert_true((0 < 1), '(0 < 1)' is not true!))": 'void', 'typeof((0 < 1))': 'boolean'} +#query +#SELECT assert_true((0 < 1)::boolean); diff --git a/datafusion/sqllogictest/test_files/spark/misc/current_catalog.slt b/datafusion/sqllogictest/test_files/spark/misc/current_catalog.slt new file mode 100644 index 0000000000000..b0cb488233c93 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/current_catalog.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT current_catalog(); +## PySpark 3.5.5 Result: {'current_catalog()': 'spark_catalog', 'typeof(current_catalog())': 'string'} +#query +#SELECT current_catalog(); diff --git a/datafusion/sqllogictest/test_files/spark/misc/current_database.slt b/datafusion/sqllogictest/test_files/spark/misc/current_database.slt new file mode 100644 index 0000000000000..0883db29a0a64 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/current_database.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT current_database(); +## PySpark 3.5.5 Result: {'current_database()': 'default', 'typeof(current_database())': 'string'} +#query +#SELECT current_database(); diff --git a/datafusion/sqllogictest/test_files/spark/misc/current_schema.slt b/datafusion/sqllogictest/test_files/spark/misc/current_schema.slt new file mode 100644 index 0000000000000..630734431df35 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/current_schema.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT current_schema(); +## PySpark 3.5.5 Result: {'current_database()': 'default', 'typeof(current_database())': 'string'} +#query +#SELECT current_schema(); diff --git a/datafusion/sqllogictest/test_files/spark/misc/current_user.slt b/datafusion/sqllogictest/test_files/spark/misc/current_user.slt new file mode 100644 index 0000000000000..17cfbd292e1db --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/current_user.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT current_user(); +## PySpark 3.5.5 Result: {'current_user()': 'r', 'typeof(current_user())': 'string'} +#query +#SELECT current_user(); diff --git a/datafusion/sqllogictest/test_files/spark/misc/equal_null.slt b/datafusion/sqllogictest/test_files/spark/misc/equal_null.slt new file mode 100644 index 0000000000000..88999d997d2db --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/equal_null.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT equal_null(1, '11'); +## PySpark 3.5.5 Result: {'equal_null(1, 11)': False, 'typeof(equal_null(1, 11))': 'boolean', 'typeof(1)': 'int', 'typeof(11)': 'string'} +#query +#SELECT equal_null(1::int, '11'::string); + +## Original Query: SELECT equal_null(3, 3); +## PySpark 3.5.5 Result: {'equal_null(3, 3)': True, 'typeof(equal_null(3, 3))': 'boolean', 'typeof(3)': 'int'} +#query +#SELECT equal_null(3::int); + +## Original Query: SELECT equal_null(NULL, 'abc'); +## PySpark 3.5.5 Result: {'equal_null(NULL, abc)': False, 'typeof(equal_null(NULL, abc))': 'boolean', 'typeof(NULL)': 'void', 'typeof(abc)': 'string'} +#query +#SELECT equal_null(NULL::void, 'abc'::string); + +## Original Query: SELECT equal_null(NULL, NULL); +## PySpark 3.5.5 Result: {'equal_null(NULL, NULL)': True, 'typeof(equal_null(NULL, NULL))': 'boolean', 'typeof(NULL)': 'void'} +#query +#SELECT equal_null(NULL::void); + +## Original Query: SELECT equal_null(true, NULL); +## PySpark 3.5.5 Result: {'equal_null(true, NULL)': False, 'typeof(equal_null(true, NULL))': 'boolean', 'typeof(true)': 'boolean', 'typeof(NULL)': 'void'} +#query +#SELECT equal_null(true::boolean, NULL::void); diff --git a/datafusion/sqllogictest/test_files/spark/misc/input_file_block_length.slt b/datafusion/sqllogictest/test_files/spark/misc/input_file_block_length.slt new file mode 100644 index 0000000000000..4f227d7c4d779 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/input_file_block_length.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT input_file_block_length(); +## PySpark 3.5.5 Result: {'input_file_block_length()': -1, 'typeof(input_file_block_length())': 'bigint'} +#query +#SELECT input_file_block_length(); diff --git a/datafusion/sqllogictest/test_files/spark/misc/input_file_block_start.slt b/datafusion/sqllogictest/test_files/spark/misc/input_file_block_start.slt new file mode 100644 index 0000000000000..c60c616328b57 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/input_file_block_start.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT input_file_block_start(); +## PySpark 3.5.5 Result: {'input_file_block_start()': -1, 'typeof(input_file_block_start())': 'bigint'} +#query +#SELECT input_file_block_start(); diff --git a/datafusion/sqllogictest/test_files/spark/misc/input_file_name.slt b/datafusion/sqllogictest/test_files/spark/misc/input_file_name.slt new file mode 100644 index 0000000000000..0379d6d0f5db8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/input_file_name.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT input_file_name(); +## PySpark 3.5.5 Result: {'input_file_name()': '', 'typeof(input_file_name())': 'string'} +#query +#SELECT input_file_name(); diff --git a/datafusion/sqllogictest/test_files/spark/misc/java_method.slt b/datafusion/sqllogictest/test_files/spark/misc/java_method.slt new file mode 100644 index 0000000000000..bb6db98de7e9b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/java_method.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT java_method('java.util.UUID', 'randomUUID'); +## PySpark 3.5.5 Result: {'java_method(java.util.UUID, randomUUID)': 'e0d43859-1003-4f43-bfff-f2e3c34981e2', 'typeof(java_method(java.util.UUID, randomUUID))': 'string', 'typeof(java.util.UUID)': 'string', 'typeof(randomUUID)': 'string'} +#query +#SELECT java_method('java.util.UUID'::string, 'randomUUID'::string); diff --git a/datafusion/sqllogictest/test_files/spark/misc/monotonically_increasing_id.slt b/datafusion/sqllogictest/test_files/spark/misc/monotonically_increasing_id.slt new file mode 100644 index 0000000000000..00f6b4a1192ad --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/monotonically_increasing_id.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT monotonically_increasing_id(); +## PySpark 3.5.5 Result: {'monotonically_increasing_id()': 0, 'typeof(monotonically_increasing_id())': 'bigint'} +#query +#SELECT monotonically_increasing_id(); diff --git a/datafusion/sqllogictest/test_files/spark/misc/reflect.slt b/datafusion/sqllogictest/test_files/spark/misc/reflect.slt new file mode 100644 index 0000000000000..223f692f7abda --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/reflect.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT reflect('java.util.UUID', 'randomUUID'); +## PySpark 3.5.5 Result: {'reflect(java.util.UUID, randomUUID)': 'bcf8f6e4-0d46-41a1-bc3c-9f793c8f8aa8', 'typeof(reflect(java.util.UUID, randomUUID))': 'string', 'typeof(java.util.UUID)': 'string', 'typeof(randomUUID)': 'string'} +#query +#SELECT reflect('java.util.UUID'::string, 'randomUUID'::string); diff --git a/datafusion/sqllogictest/test_files/spark/misc/spark_partition_id.slt b/datafusion/sqllogictest/test_files/spark/misc/spark_partition_id.slt new file mode 100644 index 0000000000000..57993103f8c4b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/spark_partition_id.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT spark_partition_id(); +## PySpark 3.5.5 Result: {'SPARK_PARTITION_ID()': 0, 'typeof(SPARK_PARTITION_ID())': 'int'} +#query +#SELECT spark_partition_id(); diff --git a/datafusion/sqllogictest/test_files/spark/misc/typeof.slt b/datafusion/sqllogictest/test_files/spark/misc/typeof.slt new file mode 100644 index 0000000000000..e930b65baa052 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/typeof.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT typeof(1); +## PySpark 3.5.5 Result: {'typeof(1)': 'int', 'typeof(typeof(1))': 'string'} +#query +#SELECT typeof(1::int); diff --git a/datafusion/sqllogictest/test_files/spark/misc/user.slt b/datafusion/sqllogictest/test_files/spark/misc/user.slt new file mode 100644 index 0000000000000..fc63c6108536a --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/user.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT user(); +## PySpark 3.5.5 Result: {'current_user()': 'r', 'typeof(current_user())': 'string'} +#query +#SELECT user(); diff --git a/datafusion/sqllogictest/test_files/spark/misc/uuid.slt b/datafusion/sqllogictest/test_files/spark/misc/uuid.slt new file mode 100644 index 0000000000000..223bd71447ca0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/uuid.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT uuid(); +## PySpark 3.5.5 Result: {'uuid()': '96981e67-62f6-49bc-a6f4-2f9bc676edda', 'typeof(uuid())': 'string'} +#query +#SELECT uuid(); diff --git a/datafusion/sqllogictest/test_files/spark/misc/version.slt b/datafusion/sqllogictest/test_files/spark/misc/version.slt new file mode 100644 index 0000000000000..d01e0c9d962d6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/version.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT version(); +## PySpark 3.5.5 Result: {'version()': '3.5.5 7c29c664cdc9321205a98a14858aaf8daaa19db2', 'typeof(version())': 'string'} +#query +#SELECT version(); diff --git a/datafusion/sqllogictest/test_files/spark/predicate/ilike.slt b/datafusion/sqllogictest/test_files/spark/predicate/ilike.slt new file mode 100644 index 0000000000000..68e8b1c59aeb6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/predicate/ilike.slt @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT ilike('Spark', '_Park'); +## PySpark 3.5.5 Result: {'ilike(Spark, _Park)': True, 'typeof(ilike(Spark, _Park))': 'boolean', 'typeof(Spark)': 'string', 'typeof(_Park)': 'string'} +query B +SELECT ilike('Spark'::string, '_Park'::string); +---- +true + +query B +SELECT ilike('Spark', arrow_cast('_Park', 'LargeUtf8')); +---- +true + +query B +SELECT ilike(arrow_cast('Spark', 'Utf8View'), arrow_cast('_Park', 'LargeUtf8')); +---- +true + +query B +SELECT ilike('Spark'::string, '_park'::string); +---- +true + +query B +SELECT ilike('SPARK'::string, '_park'::string); +---- +true + +query B +SELECT ilike('Spark'::string, 'SP%'::string); +---- +true + +query B +SELECT ilike('Spark'::string, '%ARK'::string); +---- +true + +query B +SELECT ilike('Spark'::string, 'xyz'::string); +---- +false + +query B +SELECT ilike(NULL::string, '_park'::string); +---- +NULL + +query B +SELECT ilike('Spark'::string, NULL::string); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/predicate/isnotnull.slt b/datafusion/sqllogictest/test_files/spark/predicate/isnotnull.slt new file mode 100644 index 0000000000000..3fd5d6cea0719 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/predicate/isnotnull.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT isnotnull(1); +## PySpark 3.5.5 Result: {'(1 IS NOT NULL)': True, 'typeof((1 IS NOT NULL))': 'boolean', 'typeof(1)': 'int'} +#query +#SELECT isnotnull(1::int); diff --git a/datafusion/sqllogictest/test_files/spark/predicate/isnull.slt b/datafusion/sqllogictest/test_files/spark/predicate/isnull.slt new file mode 100644 index 0000000000000..7c2290fa3d026 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/predicate/isnull.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT isnull(1); +## PySpark 3.5.5 Result: {'(1 IS NULL)': False, 'typeof((1 IS NULL))': 'boolean', 'typeof(1)': 'int'} +#query +#SELECT isnull(1::int); diff --git a/datafusion/sqllogictest/test_files/spark/predicate/like.slt b/datafusion/sqllogictest/test_files/spark/predicate/like.slt new file mode 100644 index 0000000000000..35cd8a4eaf3ed --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/predicate/like.slt @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT like('Spark', '_park'); +## PySpark 3.5.5 Result: {'Spark LIKE _park': True, 'typeof(Spark LIKE _park)': 'boolean', 'typeof(Spark)': 'string', 'typeof(_park)': 'string'} +query B +SELECT like('Spark'::string, '_park'::string); +---- +true + +query B +SELECT ilike('Spark', arrow_cast('_park', 'LargeUtf8')); +---- +true + +query B +SELECT ilike(arrow_cast('Spark', 'Utf8View'), arrow_cast('_park', 'LargeUtf8')); +---- +true + +query B +SELECT like('Spark'::string, '_Park'::string); +---- +false + +query B +SELECT like('SPARK'::string, '_park'::string); +---- +false + +query B +SELECT like('Spark'::string, 'Sp%'::string); +---- +true + +query B +SELECT like('Spark'::string, 'SP%'::string); +---- +false + +query B +SELECT like('Spark'::string, '%ark'::string); +---- +true + +query B +SELECT like('Spark'::string, '%ARK'::string); +---- +false + +query B +SELECT like('Spark'::string, 'xyz'::string); +---- +false + +query B +SELECT like(NULL::string, '_park'::string); +---- +NULL + +query B +SELECT like('Spark'::string, NULL::string); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/string/ascii.slt b/datafusion/sqllogictest/test_files/spark/string/ascii.slt new file mode 100644 index 0000000000000..623154ffaa7bf --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/ascii.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query I +SELECT ascii('234'); +---- +50 + +query I +SELECT ascii(''); +---- +0 + +query I +SELECT ascii('222'); +---- +50 + +query I +SELECT ascii('😀'); +---- +128512 + +query I +SELECT ascii(2::INT); +---- +50 + +query I +SELECT ascii(a) FROM (VALUES ('Spark'), ('PySpark'), ('Pandas API')) AS t(a); +---- +83 +80 +80 diff --git a/datafusion/sqllogictest/test_files/spark/string/base64.slt b/datafusion/sqllogictest/test_files/spark/string/base64.slt new file mode 100644 index 0000000000000..66edbe8442158 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/base64.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT base64('Spark SQL'); +## PySpark 3.5.5 Result: {'base64(Spark SQL)': 'U3BhcmsgU1FM', 'typeof(base64(Spark SQL))': 'string', 'typeof(Spark SQL)': 'string'} +#query +#SELECT base64('Spark SQL'::string); + +## Original Query: SELECT base64(x'537061726b2053514c'); +## PySpark 3.5.5 Result: {"base64(X'537061726B2053514C')": 'U3BhcmsgU1FM', "typeof(base64(X'537061726B2053514C'))": 'string', "typeof(X'537061726B2053514C')": 'binary'} +#query +#SELECT base64(X'537061726B2053514C'::binary); diff --git a/datafusion/sqllogictest/test_files/spark/string/bit_length.slt b/datafusion/sqllogictest/test_files/spark/string/bit_length.slt new file mode 100644 index 0000000000000..457d8cf034719 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/bit_length.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT bit_length('Spark SQL'); +## PySpark 3.5.5 Result: {'bit_length(Spark SQL)': 72, 'typeof(bit_length(Spark SQL))': 'int', 'typeof(Spark SQL)': 'string'} +#query +#SELECT bit_length('Spark SQL'::string); + +## Original Query: SELECT bit_length(x'537061726b2053514c'); +## PySpark 3.5.5 Result: {"bit_length(X'537061726B2053514C')": 72, "typeof(bit_length(X'537061726B2053514C'))": 'int', "typeof(X'537061726B2053514C')": 'binary'} +#query +#SELECT bit_length(X'537061726B2053514C'::binary); diff --git a/datafusion/sqllogictest/test_files/spark/string/btrim.slt b/datafusion/sqllogictest/test_files/spark/string/btrim.slt new file mode 100644 index 0000000000000..bf25bd652c81e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/btrim.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT btrim(' SparkSQL '); +## PySpark 3.5.5 Result: {'btrim( SparkSQL )': 'SparkSQL', 'typeof(btrim( SparkSQL ))': 'string', 'typeof( SparkSQL )': 'string'} +#query +#SELECT btrim(' SparkSQL '::string); + +## Original Query: SELECT btrim('SSparkSQLS', 'SL'); +## PySpark 3.5.5 Result: {'btrim(SSparkSQLS, SL)': 'parkSQ', 'typeof(btrim(SSparkSQLS, SL))': 'string', 'typeof(SSparkSQLS)': 'string', 'typeof(SL)': 'string'} +#query +#SELECT btrim('SSparkSQLS'::string, 'SL'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/char.slt b/datafusion/sqllogictest/test_files/spark/string/char.slt new file mode 100644 index 0000000000000..299e2a04136ed --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/char.slt @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query T +SELECT char(65::INT); +---- +A + +query T +SELECT char(321::INT); +---- +A + +query T +SELECT char(a) FROM (VALUES (-1::INT), (0::INT), (65::INT), (321::INT)) AS t(a); +---- +(empty) +\0 +A +A diff --git a/datafusion/sqllogictest/test_files/spark/string/char_length.slt b/datafusion/sqllogictest/test_files/spark/string/char_length.slt new file mode 100644 index 0000000000000..d9f86d45d291d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/char_length.slt @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query I +SELECT CHAR_LENGTH('Spark SQL '); +---- +10 + +query I +SELECT char_length('Spark SQL '); +---- +10 + +query I +SELECT char_length(x'537061726b2053514c'); +---- +9 diff --git a/datafusion/sqllogictest/test_files/spark/string/character_length.slt b/datafusion/sqllogictest/test_files/spark/string/character_length.slt new file mode 100644 index 0000000000000..644741416e53c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/character_length.slt @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query I +SELECT CHARACTER_LENGTH('Spark SQL '); +---- +10 + +query I +SELECT character_length('Spark SQL '); +---- +10 + +query I +SELECT character_length(x'537061726b2053514c'); +---- +9 diff --git a/datafusion/sqllogictest/test_files/spark/string/chr.slt b/datafusion/sqllogictest/test_files/spark/string/chr.slt new file mode 100644 index 0000000000000..69ec4fca394b2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/chr.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT chr(65); +## PySpark 3.5.5 Result: {'chr(65)': 'A', 'typeof(chr(65))': 'string', 'typeof(65)': 'int'} +#query +#SELECT chr(65::int); diff --git a/datafusion/sqllogictest/test_files/spark/string/concat_ws.slt b/datafusion/sqllogictest/test_files/spark/string/concat_ws.slt new file mode 100644 index 0000000000000..62df636bba9ce --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/concat_ws.slt @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT concat_ws(' ', 'Spark', 'SQL'); +## PySpark 3.5.5 Result: {'concat_ws( , Spark, SQL)': 'Spark SQL', 'typeof(concat_ws( , Spark, SQL))': 'string', 'typeof( )': 'string', 'typeof(Spark)': 'string', 'typeof(SQL)': 'string'} +#query +#SELECT concat_ws(' '::string, 'Spark'::string, 'SQL'::string); + +## Original Query: SELECT concat_ws('/', 'foo', null, 'bar'); +## PySpark 3.5.5 Result: {'concat_ws(/, foo, NULL, bar)': 'foo/bar', 'typeof(concat_ws(/, foo, NULL, bar))': 'string', 'typeof(/)': 'string', 'typeof(foo)': 'string', 'typeof(NULL)': 'void', 'typeof(bar)': 'string'} +#query +#SELECT concat_ws('/'::string, 'foo'::string, NULL::void, 'bar'::string); + +## Original Query: SELECT concat_ws('s'); +## PySpark 3.5.5 Result: {'concat_ws(s)': '', 'typeof(concat_ws(s))': 'string', 'typeof(s)': 'string'} +#query +#SELECT concat_ws('s'::string); + +## Original Query: SELECT concat_ws(null, 'Spark', 'SQL'); +## PySpark 3.5.5 Result: {'concat_ws(NULL, Spark, SQL)': None, 'typeof(concat_ws(NULL, Spark, SQL))': 'string', 'typeof(NULL)': 'void', 'typeof(Spark)': 'string', 'typeof(SQL)': 'string'} +#query +#SELECT concat_ws(NULL::void, 'Spark'::string, 'SQL'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/contains.slt b/datafusion/sqllogictest/test_files/spark/string/contains.slt new file mode 100644 index 0000000000000..1bfb61fc00e37 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/contains.slt @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT contains('Spark SQL', 'SPARK'); +## PySpark 3.5.5 Result: {'contains(Spark SQL, SPARK)': False, 'typeof(contains(Spark SQL, SPARK))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(SPARK)': 'string'} +#query +#SELECT contains('Spark SQL'::string, 'SPARK'::string); + +## Original Query: SELECT contains('Spark SQL', 'Spark'); +## PySpark 3.5.5 Result: {'contains(Spark SQL, Spark)': True, 'typeof(contains(Spark SQL, Spark))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(Spark)': 'string'} +#query +#SELECT contains('Spark SQL'::string, 'Spark'::string); + +## Original Query: SELECT contains('Spark SQL', null); +## PySpark 3.5.5 Result: {'contains(Spark SQL, NULL)': None, 'typeof(contains(Spark SQL, NULL))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(NULL)': 'void'} +#query +#SELECT contains('Spark SQL'::string, NULL::void); + +## Original Query: SELECT contains(x'537061726b2053514c', x'537061726b'); +## PySpark 3.5.5 Result: {"contains(X'537061726B2053514C', X'537061726B')": True, "typeof(contains(X'537061726B2053514C', X'537061726B'))": 'boolean', "typeof(X'537061726B2053514C')": 'binary', "typeof(X'537061726B')": 'binary'} +#query +#SELECT contains(X'537061726B2053514C'::binary, X'537061726B'::binary); diff --git a/datafusion/sqllogictest/test_files/spark/string/decode.slt b/datafusion/sqllogictest/test_files/spark/string/decode.slt new file mode 100644 index 0000000000000..a427fe40389e8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/decode.slt @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT decode(2, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic'); +## PySpark 3.5.5 Result: {'decode(2, 1, Southlake, 2, San Francisco, 3, New Jersey, 4, Seattle, Non domestic)': 'San Francisco', 'typeof(decode(2, 1, Southlake, 2, San Francisco, 3, New Jersey, 4, Seattle, Non domestic))': 'string', 'typeof(2)': 'int', 'typeof(1)': 'int', 'typeof(Southlake)': 'string', 'typeof(San Francisco)': 'string', 'typeof(3)': 'int', 'typeof(New Jersey)': 'string', 'typeof(4)': 'int', 'typeof(Seattle)': 'string', 'typeof(Non domestic)': 'string'} +#query +#SELECT decode(2::int, 1::int, 'Southlake'::string, 'San Francisco'::string, 3::int, 'New Jersey'::string, 4::int, 'Seattle'::string, 'Non domestic'::string); + +## Original Query: SELECT decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle'); +## PySpark 3.5.5 Result: {'decode(6, 1, Southlake, 2, San Francisco, 3, New Jersey, 4, Seattle)': None, 'typeof(decode(6, 1, Southlake, 2, San Francisco, 3, New Jersey, 4, Seattle))': 'string', 'typeof(6)': 'int', 'typeof(1)': 'int', 'typeof(Southlake)': 'string', 'typeof(2)': 'int', 'typeof(San Francisco)': 'string', 'typeof(3)': 'int', 'typeof(New Jersey)': 'string', 'typeof(4)': 'int', 'typeof(Seattle)': 'string'} +#query +#SELECT decode(6::int, 1::int, 'Southlake'::string, 2::int, 'San Francisco'::string, 3::int, 'New Jersey'::string, 4::int, 'Seattle'::string); + +## Original Query: SELECT decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic'); +## PySpark 3.5.5 Result: {'decode(6, 1, Southlake, 2, San Francisco, 3, New Jersey, 4, Seattle, Non domestic)': 'Non domestic', 'typeof(decode(6, 1, Southlake, 2, San Francisco, 3, New Jersey, 4, Seattle, Non domestic))': 'string', 'typeof(6)': 'int', 'typeof(1)': 'int', 'typeof(Southlake)': 'string', 'typeof(2)': 'int', 'typeof(San Francisco)': 'string', 'typeof(3)': 'int', 'typeof(New Jersey)': 'string', 'typeof(4)': 'int', 'typeof(Seattle)': 'string', 'typeof(Non domestic)': 'string'} +#query +#SELECT decode(6::int, 1::int, 'Southlake'::string, 2::int, 'San Francisco'::string, 3::int, 'New Jersey'::string, 4::int, 'Seattle'::string, 'Non domestic'::string); + +## Original Query: SELECT decode(null, 6, 'Spark', NULL, 'SQL', 4, 'rocks'); +## PySpark 3.5.5 Result: {'decode(NULL, 6, Spark, NULL, SQL, 4, rocks)': 'SQL', 'typeof(decode(NULL, 6, Spark, NULL, SQL, 4, rocks))': 'string', 'typeof(NULL)': 'void', 'typeof(6)': 'int', 'typeof(Spark)': 'string', 'typeof(SQL)': 'string', 'typeof(4)': 'int', 'typeof(rocks)': 'string'} +#query +#SELECT decode(NULL::void, 6::int, 'Spark'::string, 'SQL'::string, 4::int, 'rocks'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/elt.slt b/datafusion/sqllogictest/test_files/spark/string/elt.slt new file mode 100644 index 0000000000000..12917d17e1e47 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/elt.slt @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT elt(1, 'scala', 'java'); +## PySpark 3.5.5 Result: {'elt(1, scala, java)': 'scala', 'typeof(elt(1, scala, java))': 'string', 'typeof(1)': 'int', 'typeof(scala)': 'string', 'typeof(java)': 'string'} +query T +SELECT elt(1::int, 'scala'::string, 'java'::string); +---- +scala + +## Original Query: SELECT elt(2, 'a', 1); +## PySpark 3.5.5 Result: {'elt(2, a, 1)': '1', 'typeof(elt(2, a, 1))': 'string', 'typeof(2)': 'int', 'typeof(a)': 'string', 'typeof(1)': 'int'} +query T +SELECT elt(2::int, 'a'::string, 1::int); +---- +1 + +query T +SELECT elt(11::int, 10, 20) +---- +NULL + +query T +SELECT elt(1::int, 10, 20) +---- +10 + +query T +SELECT elt(1::int, null, 20) +---- +NULL + +query T +SELECT elt(1::int, 10, null) +---- +10 + +query T +SELECT elt(1, 10, null) +---- +10 diff --git a/datafusion/sqllogictest/test_files/spark/string/encode.slt b/datafusion/sqllogictest/test_files/spark/string/encode.slt new file mode 100644 index 0000000000000..4ad02316f4f3f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/encode.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT encode('abc', 'utf-8'); +## PySpark 3.5.5 Result: {'encode(abc, utf-8)': bytearray(b'abc'), 'typeof(encode(abc, utf-8))': 'binary', 'typeof(abc)': 'string', 'typeof(utf-8)': 'string'} +#query +#SELECT encode('abc'::string, 'utf-8'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/endswith.slt b/datafusion/sqllogictest/test_files/spark/string/endswith.slt new file mode 100644 index 0000000000000..35ada546f8bf4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/endswith.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT endswith('Spark SQL', 'SQL'); +## PySpark 3.5.5 Result: {'endswith(Spark SQL, SQL)': True, 'typeof(endswith(Spark SQL, SQL))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(SQL)': 'string'} +#query +#SELECT endswith('Spark SQL'::string, 'SQL'::string); + +## Original Query: SELECT endswith('Spark SQL', 'Spark'); +## PySpark 3.5.5 Result: {'endswith(Spark SQL, Spark)': False, 'typeof(endswith(Spark SQL, Spark))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(Spark)': 'string'} +#query +#SELECT endswith('Spark SQL'::string, 'Spark'::string); + +## Original Query: SELECT endswith('Spark SQL', null); +## PySpark 3.5.5 Result: {'endswith(Spark SQL, NULL)': None, 'typeof(endswith(Spark SQL, NULL))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(NULL)': 'void'} +#query +#SELECT endswith('Spark SQL'::string, NULL::void); + +## Original Query: SELECT endswith(x'537061726b2053514c', x'53514c'); +## PySpark 3.5.5 Result: {"endswith(X'537061726B2053514C', X'53514C')": True, "typeof(endswith(X'537061726B2053514C', X'53514C'))": 'boolean', "typeof(X'537061726B2053514C')": 'binary', "typeof(X'53514C')": 'binary'} +#query +#SELECT endswith(X'537061726B2053514C'::binary, X'53514C'::binary); + +## Original Query: SELECT endswith(x'537061726b2053514c', x'537061726b'); +## PySpark 3.5.5 Result: {"endswith(X'537061726B2053514C', X'537061726B')": False, "typeof(endswith(X'537061726B2053514C', X'537061726B'))": 'boolean', "typeof(X'537061726B2053514C')": 'binary', "typeof(X'537061726B')": 'binary'} +#query +#SELECT endswith(X'537061726B2053514C'::binary, X'537061726B'::binary); diff --git a/datafusion/sqllogictest/test_files/spark/string/find_in_set.slt b/datafusion/sqllogictest/test_files/spark/string/find_in_set.slt new file mode 100644 index 0000000000000..690d03ffa475f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/find_in_set.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT find_in_set('ab','abc,b,ab,c,def'); +## PySpark 3.5.5 Result: {'find_in_set(ab, abc,b,ab,c,def)': 3, 'typeof(find_in_set(ab, abc,b,ab,c,def))': 'int', 'typeof(ab)': 'string', 'typeof(abc,b,ab,c,def)': 'string'} +#query +#SELECT find_in_set('ab'::string, 'abc,b,ab,c,def'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/format_number.slt b/datafusion/sqllogictest/test_files/spark/string/format_number.slt new file mode 100644 index 0000000000000..a56b8d004c912 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/format_number.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT format_number(12332.123456, '##################.###'); +## PySpark 3.5.5 Result: {'format_number(12332.123456, ##################.###)': '12332.123', 'typeof(format_number(12332.123456, ##################.###))': 'string', 'typeof(12332.123456)': 'decimal(11,6)', 'typeof(##################.###)': 'string'} +#query +#SELECT format_number(12332.123456::decimal(11,6), '##################.###'::string); + +## Original Query: SELECT format_number(12332.123456, 4); +## PySpark 3.5.5 Result: {'format_number(12332.123456, 4)': '12,332.1235', 'typeof(format_number(12332.123456, 4))': 'string', 'typeof(12332.123456)': 'decimal(11,6)', 'typeof(4)': 'int'} +#query +#SELECT format_number(12332.123456::decimal(11,6), 4::int); diff --git a/datafusion/sqllogictest/test_files/spark/string/format_string.slt b/datafusion/sqllogictest/test_files/spark/string/format_string.slt new file mode 100644 index 0000000000000..07c8cd10d1a96 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/format_string.slt @@ -0,0 +1,2315 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +# ================================ +# Basic format_string tests +# ================================ + +## Basic string formatting +query T +SELECT format_string('Hello World %s', 'DataFusion'); +---- +Hello World DataFusion + +query T +SELECT format_string(arrow_cast('Hello World %s', 'LargeUtf8'), 'DataFusion'); +---- +Hello World DataFusion + +query T +SELECT format_string(arrow_cast('Hello World %s', 'Utf8View'), 'DataFusion'); +---- +Hello World DataFusion + +## Basic integer formatting +query T +SELECT format_string('Hello World %d %s', 100, 'days'); +---- +Hello World 100 days + +## Multiple string arguments +query T +SELECT format_string('%s %s %s', 'Hello', 'World', 'Test'); +---- +Hello World Test + +## Format without arguments (just return the format string) +query T +SELECT format_string('Hello World'); +---- +Hello World + +# ================================ +# Integer formatting tests +# ================================ + +## Decimal integer formatting +query T +SELECT format_string('Value: %d', 42); +---- +Value: 42 + +## Hexadecimal integer formatting (lowercase) +query T +SELECT format_string('Hex: %x', 255); +---- +Hex: ff + +## Hexadecimal integer formatting (uppercase) +query T +SELECT format_string('Hex: %X', 255); +---- +Hex: FF + +## Octal integer formatting +query T +SELECT format_string('Octal: %o', 64); +---- +Octal: 100 + +## Integer with width padding +query T +SELECT format_string('Padded: %5d', 42); +---- +Padded: 42 + +## Integer with zero padding +query T +SELECT format_string('Zero padded: %05d', 42); +---- +Zero padded: 00042 + +## Left-aligned integer +query T +SELECT format_string('Left: %-5d|', 42); +---- +Left: 42 | + +## Integer with force sign +query T +SELECT format_string('Signed: %+d', 42); +---- +Signed: +42 + +## Negative integer +query T +SELECT format_string('Negative: %d', -42); +---- +Negative: -42 + +# ================================ +# Float formatting tests +# ================================ + +## Basic float formatting +query T +SELECT format_string('Float: %f', 3.14159); +---- +Float: 3.141590 + +query T +SELECT format_string('Float: %f', 30.0); +---- +Float: 30.000000 + +## Float with precision +query T +SELECT format_string('Precision: %.2f', 3.14159); +---- +Precision: 3.14 + +## Scientific notation (lowercase) +query T +SELECT format_string('Scientific: %e', 1234.5); +---- +Scientific: 1.234500e+03 + +## Scientific notation (uppercase) +query T +SELECT format_string('Scientific: %E', 1234.5); +---- +Scientific: 1.234500E+03 + +## Compact float (lowercase) +query T +SELECT format_string('Compact: %g', 1234.5); +---- +Compact: 1234.5 + +query T +SELECT format_string('Compact: %g', CAST(123456789.1 AS DOUBLE)); +---- +Compact: 1.23457e+08 + +## Compact float (uppercase) +query T +SELECT format_string('Compact: %G', 1234.5); +---- +Compact: 1234.5 + +query T +SELECT format_string('Compact: %G', CAST(123456789.1 AS DOUBLE)); +---- +Compact: 1.23457E+08 + +## Float with width and precision +query T +SELECT format_string('Formatted: %10.2f', 3.14159); +---- +Formatted: 3.14 + +## Float zero padding +query T +SELECT format_string('Zero: %08.2f', 3.14); +---- +Zero: 00003.14 + +## Float with left alignment +query T +SELECT format_string('Left: %-10.2f|', 3.14); +---- +Left: 3.14 | + +## Float with space sign (positive) +query T +SELECT format_string('Space: % .2f', 3.14); +---- +Space: 3.14 + +## Float with space sign (negative) +query T +SELECT format_string('Space: % .2f', -3.14); +---- +Space: -3.14 + +## Float with force sign (positive) +query T +SELECT format_string('Force: %+.2f', 3.14); +---- +Force: +3.14 + +## Float with force sign (negative) +query T +SELECT format_string('Force: %+.2f', -3.14); +---- +Force: -3.14 + +## Float with precision 0 +query T +SELECT format_string('Precision 0: %.0f', 3.14); +---- +Precision 0: 3 + +## Float with precision 0 (rounds up) +query T +SELECT format_string('Precision 0: %.0f', 3.6); +---- +Precision 0: 4 + +## Float with precision 0 and alternate form +query T +SELECT format_string('Alt form: %#.0f', 3.14); +---- +Alt form: 3. + +## Scientific notation with precision 0 +query T +SELECT format_string('Sci: %.0e', 1234.5); +---- +Sci: 1e+03 + +## Compact format with precision 0 +query T +SELECT format_string('Compact: %.0g', 1234.5); +---- +Compact: 1e+03 + +# ================================ +# Boolean formatting tests +# ================================ + +## Boolean lowercase +query T +SELECT format_string('Bool: %b', true); +---- +Bool: true + +## Boolean uppercase +query T +SELECT format_string('Bool: %B', false); +---- +Bool: FALSE + +## Boolean with width +query T +SELECT format_string('Bool: %6b', true); +---- +Bool: true + +## Boolean with invalid ARGUMENT +statement error +SELECT format_string('Bool: %6b', 1) + +# ================================ +# String formatting tests +# ================================ + +## String formatting +query T +SELECT format_string('String: %s', 'DataFusion'); +---- +String: DataFusion + +## String with width +query T +SELECT format_string('Padded: %10s|', 'test'); +---- +Padded: test| + +## String left-aligned +query T +SELECT format_string('Left: %-10s|', 'test'); +---- +Left: test | + +## String with precision (truncation) +query T +SELECT format_string('Truncated: %.3s', 'DataFusion'); +---- +Truncated: Dat + +## String uppercase conversion +query T +SELECT format_string('Upper: %S', 'datafusion'); +---- +Upper: DATAFUSION + +# ================================ +# Character formatting tests +# ================================ + +## Character formatting from integer +query T +SELECT format_string('Char: %c', 97); +---- +Char: a + +## Character uppercase +query T +SELECT format_string('Char: %C', 97); +---- +Char: A + +## Character with width padding +query T +SELECT format_string('Char: %5c', 65); +---- +Char: A + +## Character with left alignment +query T +SELECT format_string('Char: %-5c|', 65); +---- +Char: A | + +## Character uppercase with width +query T +SELECT format_string('Char: %5C', 97); +---- +Char: A + +## Character uppercase with left alignment +query T +SELECT format_string('Char: %-5C|', 97); +---- +Char: A | + +## Character with invalid ARGUMENT +statement error +SELECT format_string('Char: %5c', true); + +# ================================ +# Time formatting tests +# ================================ + +## Hour formatting (24-hour) +query T +SELECT format_string('Hour: %tH', TIMESTAMP '2023-12-25 14:30:45'); +---- +Hour: 14 + +## Hour formatting (12-hour) +query T +SELECT format_string('Hour: %tI', TIMESTAMP '2023-12-25 14:30:45'); +---- +Hour: 02 + +## Minute formatting +query T +SELECT format_string('Minute: %tM', TIMESTAMP '2023-12-25 14:30:45'); +---- +Minute: 30 + +## Second formatting +query T +SELECT format_string('Second: %tS', TIMESTAMP '2023-12-25 14:30:45'); +---- +Second: 45 + +## AM/PM marker +query T +SELECT format_string('AM/PM: %tp', TIMESTAMP '2023-12-25 14:30:45'); +---- +AM/PM: pm + +## AM/PM marker uppercase +query T +SELECT format_string('AM/PM: %Tp', TIMESTAMP '2023-12-25 14:30:45'); +---- +AM/PM: PM + +## AM/PM marker uppercase (morning) +query T +SELECT format_string('AM/PM: %Tp', TIMESTAMP '2023-12-25 09:30:45'); +---- +AM/PM: AM + +## Year formatting +query T +SELECT format_string('Year: %tY', TIMESTAMP '2023-12-25 14:30:45'); +---- +Year: 2023 + +## Year formatting uppercase +query T +SELECT format_string('Year: %TY', TIMESTAMP '2023-12-25 14:30:45'); +---- +Year: 2023 + +## Month formatting +query T +SELECT format_string('Month: %tm', TIMESTAMP '2023-12-25 14:30:45'); +---- +Month: 12 + +## Day formatting +query T +SELECT format_string('Day: %td', TIMESTAMP '2023-12-25 14:30:45'); +---- +Day: 25 + +## Time formatting (HH:MM) +query T +SELECT format_string('Time: %tR', TIMESTAMP '2023-12-25 14:30:45'); +---- +Time: 14:30 + +## Time formatting (HH:MM:SS) +query T +SELECT format_string('Time: %tT', TIMESTAMP '2023-12-25 14:30:45'); +---- +Time: 14:30:45 + +## Date formatting (MM/DD/YY) +query T +SELECT format_string('Date: %tD', TIMESTAMP '2023-12-25 14:30:45'); +---- +Date: 12/25/23 + +## ISO date formatting (YYYY-MM-DD) +query T +SELECT format_string('ISO Date: %tF', TIMESTAMP '2023-12-25 14:30:45'); +---- +ISO Date: 2023-12-25 + +## Complex date formatting (Sun Jul 20 16:17:00 EDT 1969) +query B +SELECT format_string('Date: %tc', TIMESTAMP '1969-07-20 16:17:00') LIKE 'Date: Sun Jul 20 16:17:00 % 1969'; +---- +true + + +## Hour formatting (24-hour no padding) +query T +SELECT format_string('Hour: %tk', TIMESTAMP '2023-12-25 04:30:45'); +---- +Hour: 4 + +## Hour formatting (12-hour no padding) +query T +SELECT format_string('Hour: %tl', TIMESTAMP '2023-12-25 14:30:45'); +---- +Hour: 2 + +## Milliseconds formatting +query T +SELECT format_string('Milliseconds: %tL', TIMESTAMP '2023-12-25 14:30:45.123'); +---- +Milliseconds: 123 + +## Nanoseconds formatting +query T +SELECT format_string('Nanoseconds: %tN', TIMESTAMP '2023-12-25 14:30:45.123456789'); +---- +Nanoseconds: 123456789 + +## Timezone offset (RFC 822) +query T +SELECT format_string('Timezone: %tz', TIMESTAMP '2023-12-25 14:30:45'); +---- +Timezone: +0000 + +## Timezone abbreviation +query T +SELECT format_string('Timezone: %tZ', from_unixtime(1599572549, 'America/New_York')); +---- +Timezone: UTC + +## Seconds since epoch +query T +SELECT format_string('Epoch seconds: %ts', TIMESTAMP '1970-01-01 00:00:01'); +---- +Epoch seconds: 1 + +## Milliseconds since epoch +query T +SELECT format_string('Epoch millis: %tQ', TIMESTAMP '1970-01-01 00:00:01'); +---- +Epoch millis: 1000 + +## Full month name +query T +SELECT format_string('Month: %tB', TIMESTAMP '2023-12-25 14:30:45'); +---- +Month: December + +## Full month name uppercase +query T +SELECT format_string('Month: %TB', TIMESTAMP '2023-12-25 14:30:45'); +---- +Month: DECEMBER + +## Abbreviated month name +query T +SELECT format_string('Month: %tb', TIMESTAMP '2023-12-25 14:30:45'); +---- +Month: Dec + +## Abbreviated month name uppercase +query T +SELECT format_string('Month: %Tb', TIMESTAMP '2023-12-25 14:30:45'); +---- +Month: DEC + +## Same as %tb +query T +SELECT format_string('Month: %th', TIMESTAMP '2023-12-25 14:30:45'); +---- +Month: Dec + +## Full day of week +query T +SELECT format_string('Day: %tA', TIMESTAMP '2023-12-25 14:30:45'); +---- +Day: Monday + +## Full day of week uppercase +query T +SELECT format_string('Day: %TA', TIMESTAMP '2023-12-25 14:30:45'); +---- +Day: MONDAY + +## Abbreviated day of week +query T +SELECT format_string('Day: %ta', TIMESTAMP '2023-12-25 14:30:45'); +---- +Day: Mon + +## Abbreviated day of week uppercase +query T +SELECT format_string('Day: %Ta', TIMESTAMP '2023-12-25 14:30:45'); +---- +Day: MON + +## Century (year/100) +query T +SELECT format_string('Century: %tC', TIMESTAMP '2023-12-25 14:30:45'); +---- +Century: 20 + +## Two-digit year +query T +SELECT format_string('Year: %ty', TIMESTAMP '2023-12-25 14:30:45'); +---- +Year: 23 + +## Day of year +query T +SELECT format_string('Day of year: %tj', TIMESTAMP '2023-12-25 14:30:45'); +---- +Day of year: 359 + +## Day of month (no padding) +query T +SELECT format_string('Day: %te', TIMESTAMP '2023-12-05 14:30:45'); +---- +Day: 5 + +## 12-hour time with AM/PM +query T +SELECT format_string('Time: %tr', TIMESTAMP '2023-12-25 14:30:45'); +---- +Time: 02:30:45 PM + +statement error +SELECT format_string('Time: %t', TIMESTAMP '2023-12-25 14:30:45'); + +statement error +SELECT format_string('Time: %T', TIMESTAMP '2023-12-25 14:30:45'); + + +statement error +SELECT format_string('Time: %tx', TIMESTAMP '2023-12-25 14:30:45'); + +statement error +SELECT format_string('Time: %Tx', TIMESTAMP '2023-12-25 14:30:45'); + + + +# ================================ +# Decimal formatting tests +# ================================ + +## Decimal formatting +query T +SELECT format_string('Decimal: %f', CAST(123.456 AS DECIMAL(10,3))); +---- +Decimal: 123.456000 + +## Decimal with precision +query T +SELECT format_string('Decimal: %.2f', CAST(123.456 AS DECIMAL(10,3))); +---- +Decimal: 123.46 + +## Decimal scientific notation +query T +SELECT format_string('Scientific: %e', CAST(1234.5 AS DECIMAL(10,2))); +---- +Scientific: 1.234500e+03 + +## Decimal with width padding +query T +SELECT format_string('Padded: %10.2f', CAST(123.456 AS DECIMAL(10,3))); +---- +Padded: 123.46 + +## Decimal with zero padding +query T +SELECT format_string('Zero padded: %010.2f', CAST(123.456 AS DECIMAL(10,3))); +---- +Zero padded: 0000123.46 + +## Decimal with left adjustment +query T +SELECT format_string('Left: %-10.2f', CAST(123.456 AS DECIMAL(10,3))); +---- +Left: 123.46 + +## Decimal with plus sign +query T +SELECT format_string('Plus: %+.2f', CAST(123.456 AS DECIMAL(10,3))); +---- +Plus: +123.46 + +## Decimal with space sign +query T +SELECT format_string('Space: % .2f', CAST(123.456 AS DECIMAL(10,3))); +---- +Space: 123.46 + +## Negative decimal with plus sign +query T +SELECT format_string('Negative: %+.2f', CAST(-123.456 AS DECIMAL(10,3))); +---- +Negative: -123.46 + +## Negative decimal with space sign +query T +SELECT format_string('Negative: % .2f', CAST(-123.456 AS DECIMAL(10,3))); +---- +Negative: -123.46 + +## Decimal with width and plus sign +query T +SELECT format_string('Width+Plus: %+10.2f', CAST(123.456 AS DECIMAL(10,3))); +---- +Width+Plus: +123.46 + +## Decimal with zero padding and plus sign +query T +SELECT format_string('Zero+Plus: %+010.2f', CAST(123.456 AS DECIMAL(10,3))); +---- +Zero+Plus: +000123.46 + +## Decimal with left adjustment and plus sign +query T +SELECT format_string('Left+Plus: %-+10.2f', CAST(123.456 AS DECIMAL(10,3))); +---- +Left+Plus: +123.46 + +## Decimal scientific notation with width +query T +SELECT format_string('Sci Width: %15.2e', CAST(1234.5 AS DECIMAL(10,2))); +---- +Sci Width: 1.23e+03 + +## Decimal scientific notation with zero padding +query T +SELECT format_string('Sci Zero: %015.2e', CAST(1234.5 AS DECIMAL(10,2))); +---- +Sci Zero: 00000001.23e+03 + +## Decimal scientific notation with plus sign +query T +SELECT format_string('Sci Plus: %+.2e', CAST(1234.5 AS DECIMAL(10,2))); +---- +Sci Plus: +1.23e+03 + +## Decimal compact format with width +query T +SELECT format_string('Compact: %10.2g', CAST(123.456 AS DECIMAL(10,3))); +---- +Compact: 1.2e+02 + +## Decimal compact format with plus sign +query T +SELECT format_string('Compact+: %+.2g', CAST(123.456 AS DECIMAL(10,3))); +---- +Compact+: +1.2e+02 + +statement error +SELECT format_string('Compact+: %+.2g', 1); + +# ================================ +# Special cases and edge cases +# ================================ + +## Literal percent sign +query T +SELECT format_string('Percent: %%'); +---- +Percent: % + +## Newline character +query T +SELECT format_string('Line1%nLine2'); +---- + +01)Line1 +02)Line2 + +## Multiple format specifiers +query T +SELECT format_string('String: %s, Integer: %d, Float: %.2f', 'test', 42, 3.14159); +---- +String: test, Integer: 42, Float: 3.14 + +## Mixed width and precision +query T +SELECT format_string('Mixed: %10s %5d %.2f', 'hello', 123, 45.678); +---- +Mixed: hello 123 45.68 + +# ================================ +# NULL handling tests +# ================================ + +## NULL format string +query T +SELECT format_string(NULL, 'test'); +---- +NULL + +query T +SELECT format_string(arrow_cast(NULL, 'Utf8'), 'test'); +---- +NULL + +query T +SELECT format_string(arrow_cast(NULL, 'LargeUtf8'), 'test'); +---- +NULL + +query T +SELECT format_string(arrow_cast(NULL, 'Utf8View'), 'test'); +---- +NULL + +## NULL argument with string format +query T +SELECT format_string('Value: %s', NULL); +---- +Value: null + +## NULL with string format (uppercase) +query T +SELECT format_string('Upper: %S', NULL); +---- +Upper: NULL + + +## NULL argument with string format +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'Utf8')); +---- +Value: null + +## NULL with string format (uppercase) +query T +SELECT format_string('Upper: %S', arrow_cast(NULL, 'Utf8')); +---- +Upper: NULL + +## NULL argument with string format +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'LargeUtf8')); +---- +Value: null + +## NULL with string format (uppercase) +query T +SELECT format_string('Upper: %S', arrow_cast(NULL, 'LargeUtf8')); +---- +Upper: NULL + +## NULL argument with string format +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'Utf8View')); +---- +Value: null + +## NULL with string format (uppercase) +query T +SELECT format_string('Upper: %S', arrow_cast(NULL, 'Utf8View')); +---- +Upper: NULL + +## NULL with integer format using arrow_cast +query T +SELECT format_string('Integer: %d', arrow_cast(NULL, 'Int32')); +---- +Integer: null + +## NULL with hex format (lowercase) using arrow_cast +query T +SELECT format_string('Hex: %x', arrow_cast(NULL, 'Int32')); +---- +Hex: null + +## NULL with hex format (uppercase) using arrow_cast +query T +SELECT format_string('Hex: %X', arrow_cast(NULL, 'Int32')); +---- +Hex: NULL + +## NULL with octal format using arrow_cast +query T +SELECT format_string('Octal: %o', arrow_cast(NULL, 'Int32')); +---- +Octal: null + +## NULL with float format using arrow_cast +query T +SELECT format_string('Float: %f', arrow_cast(NULL, 'Float64')); +---- +Float: null + +## NULL with float and precision using arrow_cast +query T +SELECT format_string('Float: %.2f', arrow_cast(NULL, 'Float64')); +---- +Float: nu + +## NULL with scientific notation (lowercase) using arrow_cast +query T +SELECT format_string('Scientific: %e', arrow_cast(NULL, 'Float64')); +---- +Scientific: null + +## NULL with scientific notation (uppercase) using arrow_cast +query T +SELECT format_string('Scientific: %E', arrow_cast(NULL, 'Float64')); +---- +Scientific: NULL + +## NULL with compact float (lowercase) using arrow_cast +query T +SELECT format_string('Compact: %g', arrow_cast(NULL, 'Float64')); +---- +Compact: null + +## NULL with compact float and precision (lowercase) using arrow_cast +query T +SELECT format_string('Float: %.3g', arrow_cast(NULL, 'Float64')); +---- +Float: nul + +## NULL with compact float (uppercase) using arrow_cast +query T +SELECT format_string('Compact: %G', arrow_cast(NULL, 'Float64')); +---- +Compact: NULL + +## NULL with compact float and precision (uppercase) using arrow_cast +query T +SELECT format_string('Float: %.3G', arrow_cast(NULL, 'Float64')); +---- +Float: NUL + +## NULL with hex float (lowercase) using arrow_cast +query T +SELECT format_string('Hex float: %a', arrow_cast(NULL, 'Float64')); +---- +Hex float: null + +## NULL with hex float (uppercase) using arrow_cast +query T +SELECT format_string('Hex float: %A', arrow_cast(NULL, 'Float64')); +---- +Hex float: NULL + +# ## NULL with float and precision using arrow_cast +# query T +# SELECT format_string('Float: %.2f', arrow_cast(NULL, 'Float16')); +# ---- +# Float: nu + +## NULL with boolean format (lowercase) using arrow_cast +query T +SELECT format_string('Bool: %b', arrow_cast(NULL, 'Boolean')); +---- +Bool: false + +## NULL with boolean format (uppercase) using arrow_cast +query T +SELECT format_string('Bool: %B', arrow_cast(NULL, 'Boolean')); +---- +Bool: FALSE + +## NULL with character format (lowercase) using arrow_cast +query T +SELECT format_string('Char: %c', arrow_cast(NULL, 'Int32')); +---- +Char: null + +## NULL with character format (uppercase) using arrow_cast +query T +SELECT format_string('Char: %C', arrow_cast(NULL, 'Int32')); +---- +Char: NULL + +## NULL with timestamp format using arrow_cast +query T +SELECT format_string('Hour: %tH', arrow_cast(NULL, 'Timestamp(Nanosecond, None)')); +---- +Hour: null + +## NULL with timestamp format using arrow_cast +query T +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Nanosecond, None)')); +---- +Month: null + +## NULL with timestamp format using arrow_cast +query T +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Time32(Second)')); +---- +Month: null + +## NULL with timestamp format using arrow_cast +query T +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Time32(Millisecond)')); +---- +Month: null + +## NULL with timestamp format using arrow_cast +query T +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Time64(Microsecond)')); +---- +Month: null + +## NULL with timestamp format using arrow_cast +query T +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Time64(Nanosecond)')); +---- +Month: null + +## NULL with timestamp format using arrow_cast +query T +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Second, None)')); +---- +Month: null + +## NULL with timestamp format using arrow_cast +query T +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Millisecond, None)')); +---- +Month: null + +## NULL with timestamp format using arrow_cast +query T +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Microsecond, None)')); +---- +Month: null + +## NULL with timestamp format using arrow_cast +query T +SELECT format_string('Month: %tB', arrow_cast(NULL, 'Timestamp(Nanosecond, None)')); +---- +Month: null + +## NULL with decimal format using arrow_cast +query T +SELECT format_string('Decimal: %f', arrow_cast(NULL, 'Decimal128(10, 2)')); +---- +Decimal: null + +## NULL Int8 with string format using arrow_cast +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'Int8')); +---- +Value: null + +## NULL Int16 with string format using arrow_cast +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'Int16')); +---- +Value: null + +## NULL Int64 with string format using arrow_cast +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'Int64')); +---- +Value: null + +## NULL UInt8 with string format using arrow_cast +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'UInt8')); +---- +Value: null + +## NULL UInt16 with string format using arrow_cast +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'UInt16')); +---- +Value: null + +## NULL UInt32 with string format using arrow_cast +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'UInt32')); +---- +Value: null + +## NULL UInt64 with string format using arrow_cast +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'UInt64')); +---- +Value: null + +## NULL Float32 with string format using arrow_cast +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'Float32')); +---- +Value: null + +## NULL Float64 with string format using arrow_cast +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'Float64')); +---- +Value: null + +## NULL Timestamp with string format using arrow_cast +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'Timestamp(Nanosecond, None)')); +---- +Value: null + +## NULL Date32 with string format using arrow_cast +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'Date32')); +---- +Value: null + +## NULL Date64 with string format using arrow_cast +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'Date64')); +---- +Value: null + +## NULL Decimal128 with string format using arrow_cast +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'Decimal128(10, 2)')); +---- +Value: null + +## NULL Decimal256 with string format using arrow_cast +query T +SELECT format_string('Value: %s', arrow_cast(NULL, 'Decimal256(20, 3)')); +---- +Value: null + + +# ================================ +# Error cases (should fail) +# ================================ + +## Format string expects arguments but none provided +statement error +SELECT format_string('Value: %d'); + +statement error +SELECT format_string(1); + +## Too few arguments for format specifiers +statement error +SELECT format_string('Values: %d %s', 42); + +## Invalid conversion for data type +statement error +SELECT format_string('Value: %d', 'not_a_number'); + +statement error +SELECT format_string('Value: %k', 'string'); + +# ================================ +# Positional argument tests +# ================================ + +## Positional arguments +query T +SELECT format_string('%2$s %1$d', 42, 'test'); +---- +test 42 + +## Reuse positional arguments +query T +SELECT format_string('%1$s %1$s', 'repeat'); +---- +repeat repeat + +## Mixed positional and sequential +query T +SELECT format_string('%2$s %s %1$d', 42, 'middle', 'end'); +---- +middle 42 42 + +statement error +SELECT format_string('%$s', 'test'); + +# ================================ +# Flag combination tests +# ================================ + +## Alternate form with hex +query T +SELECT format_string('Hex: %#x', 255); +---- +Hex: 0xff + +## Alternate form with octal +query T +SELECT format_string('Octal: %#o', 64); +---- +Octal: 0100 + +## Space sign with positive number +query T +SELECT format_string('Space: % d', 42); +---- +Space: 42 + +## Grouping separator (if supported) +query T +SELECT format_string('Grouped: %,d', 1234567); +---- +Grouped: 1,234,567 + +## Parentheses for negative numbers +query T +SELECT format_string('Negative: %(d', -42); +---- +Negative: (42) + +# ================================ +# Array/Column tests +# ================================ + +## Test with array values +statement ok +CREATE TABLE test_format(fmt STRING, val1 STRING, val2 INT) AS VALUES + ('Hello %s %d', 'World', 1), + ('Float: %2$d %1$s', '3.14159', 2), + (NULL, '3.14159', 3); + +query T +SELECT format_string(arrow_cast(fmt, 'Utf8'), val1, val2) FROM test_format; +---- +Hello World 1 +Float: 2 3.14159 +NULL + +query T +SELECT format_string(arrow_cast(fmt, 'LargeUtf8'), val1, val2) FROM test_format; +---- +Hello World 1 +Float: 2 3.14159 +NULL + +query T +SELECT format_string(arrow_cast(fmt, 'Utf8View'), val1, val2) FROM test_format; +---- +Hello World 1 +Float: 2 3.14159 +NULL + +query T +SELECT format_string(fmt, arrow_cast(val1, 'LargeUtf8'), val2) FROM test_format; +---- +Hello World 1 +Float: 2 3.14159 +NULL + +query T +SELECT format_string(fmt, arrow_cast(val1, 'Utf8'), val2) FROM test_format; +---- +Hello World 1 +Float: 2 3.14159 +NULL + +query T +SELECT format_string(fmt, arrow_cast(val1, 'Utf8View'), val2) FROM test_format; +---- +Hello World 1 +Float: 2 3.14159 +NULL + +query T +SELECT format_string(arrow_cast('Hello %s %d', 'Utf8'), val1, val2) FROM test_format; +---- +Hello World 1 +Hello 3.14159 2 +Hello 3.14159 3 + +query T +SELECT format_string(arrow_cast('Hello %s %d', 'LargeUtf8'), val1, val2) FROM test_format; +---- +Hello World 1 +Hello 3.14159 2 +Hello 3.14159 3 + +query T +SELECT format_string(arrow_cast('Hello %s %d', 'Utf8View'), val1, val2) FROM test_format; +---- +Hello World 1 +Hello 3.14159 2 +Hello 3.14159 3 + +statement ok +DROP TABLE test_format; + +# ================================ +# Type-specific conversion tests +# ================================ + +## Boolean with string formats +query T +SELECT format_string('Value: %s', arrow_cast(true, 'Boolean')); +---- +Value: true + +query T +SELECT format_string('Value: %S', arrow_cast(false, 'Boolean')); +---- +Value: FALSE + +## Int8 with various formats +query T +SELECT format_string('Decimal: %d', arrow_cast(127, 'Int8')); +---- +Decimal: 127 + +query T +SELECT format_string('Hex: %x', arrow_cast(127, 'Int8')); +---- +Hex: 7f + +query T +SELECT format_string('Hex: %X', arrow_cast(127, 'Int8')); +---- +Hex: 7F + +query T +SELECT format_string('Octal: %o', arrow_cast(127, 'Int8')); +---- +Octal: 177 + +query T +SELECT format_string('Char: %c', arrow_cast(65, 'Int8')); +---- +Char: A + +query T +SELECT format_string('Char: %C', arrow_cast(97, 'Int8')); +---- +Char: A + +query T +SELECT format_string('Char: %c', arrow_cast(65, 'UInt32')); +---- +Char: A + +query T +SELECT format_string('Char: %C', arrow_cast(97, 'UInt32')); +---- +Char: A + +query T +SELECT format_string('Char: %c', arrow_cast(65, 'UInt64')); +---- +Char: A + +query T +SELECT format_string('Char: %C', arrow_cast(97, 'UInt64')); +---- +Char: A + +query T +SELECT format_string('String: %s', arrow_cast(127, 'Int8')); +---- +String: 127 + +query T +SELECT format_string('String: %S', arrow_cast(127, 'Int8')); +---- +String: 127 + +query T +SELECT format_string('String: %s', arrow_cast(127, 'UInt8')); +---- +String: 127 + +query T +SELECT format_string('String: %S', arrow_cast(127, 'UInt8')); +---- +String: 127 + +query T +SELECT format_string('String: %s', arrow_cast(127, 'UInt16')); +---- +String: 127 + +query T +SELECT format_string('String: %S', arrow_cast(127, 'UInt16')); +---- +String: 127 + +query T +SELECT format_string('String: %s', arrow_cast(127, 'Int32')); +---- +String: 127 + +query T +SELECT format_string('String: %S', arrow_cast(127, 'Int32')); +---- +String: 127 + +query T +SELECT format_string('String: %s', arrow_cast(127, 'UInt64')); +---- +String: 127 + +query T +SELECT format_string('String: %S', arrow_cast(127, 'UInt64')); +---- +String: 127 + +## Int16 with various formats +query T +SELECT format_string('Decimal: %d', arrow_cast(32767, 'Int16')); +---- +Decimal: 32767 + +query T +SELECT format_string('Hex: %x', arrow_cast(32767, 'Int16')); +---- +Hex: 7fff + +query T +SELECT format_string('Hex: %X', arrow_cast(32767, 'Int16')); +---- +Hex: 7FFF + +query T +SELECT format_string('Octal: %o', arrow_cast(32767, 'Int16')); +---- +Octal: 77777 + +query T +SELECT format_string('Char: %c', arrow_cast(8364, 'Int16')); +---- +Char: € + +query T +SELECT format_string('String: %s', arrow_cast(32767, 'Int16')); +---- +String: 32767 + +query T +SELECT format_string('NaN: %s', CAST('NaN' AS DOUBLE)); +---- +NaN: NaN + +query T +SELECT format_string('Infinity: %s', CAST('+Inf' AS DOUBLE)); +---- +Infinity: Infinity + +query T +SELECT format_string('Infinity: %s', CAST('-Inf' AS DOUBLE)); +---- +Infinity: -Infinity + +query T +SELECT format_string('NaN: %S', CAST('NaN' AS DOUBLE)); +---- +NaN: NAN + +query T +SELECT format_string('Infinity: %S', CAST('+Inf' AS DOUBLE)); +---- +Infinity: INFINITY + +query T +SELECT format_string('Infinity: %S', CAST('-Inf' AS DOUBLE)); +---- +Infinity: -INFINITY + +## Int32 with various formats +query T +SELECT format_string('Decimal: %d', arrow_cast(2147483647, 'Int32')); +---- +Decimal: 2147483647 + +query T +SELECT format_string('Hex: %x', arrow_cast(255, 'Int32')); +---- +Hex: ff + +query T +SELECT format_string('Octal: %o', arrow_cast(511, 'Int32')); +---- +Octal: 777 + +query T +SELECT format_string('Char: %c', arrow_cast(128512, 'Int32')); +---- +Char: 😀 + +## UInt8 with various formats +query T +SELECT format_string('Decimal: %d', arrow_cast(255, 'UInt8')); +---- +Decimal: 255 + +query T +SELECT format_string('Hex: %x', arrow_cast(255, 'UInt8')); +---- +Hex: ff + +query T +SELECT format_string('Octal: %o', arrow_cast(255, 'UInt8')); +---- +Octal: 377 + +query T +SELECT format_string('Char: %c', arrow_cast(65, 'UInt8')); +---- +Char: A + +## UInt16 with various formats +query T +SELECT format_string('Decimal: %d', arrow_cast(65535, 'UInt16')); +---- +Decimal: 65535 + +query T +SELECT format_string('Hex: %X', arrow_cast(65535, 'UInt16')); +---- +Hex: FFFF + +query T +SELECT format_string('Char: %c', arrow_cast(9733, 'UInt16')); +---- +Char: ★ + +## UInt32 with various formats +query T +SELECT format_string('Decimal: %d', arrow_cast(4294967295, 'UInt32')); +---- +Decimal: 4294967295 + +query T +SELECT format_string('Hex: %x', arrow_cast(4294967295, 'UInt32')); +---- +Hex: ffffffff + +query T +SELECT format_string('String: %s', arrow_cast(4294967295, 'UInt32')); +---- +String: 4294967295 + +## UInt64 with various formats +query T +SELECT format_string('Decimal: %d', arrow_cast(18446744073709551615, 'UInt64')); +---- +Decimal: 18446744073709551615 + +query T +SELECT format_string('Hex: %X', arrow_cast(18446744073709551615, 'UInt64')); +---- +Hex: FFFFFFFFFFFFFFFF + +## Float16 with various formats +query T +SELECT format_string('Float: %f', arrow_cast(3.14, 'Float16')); +---- +Float: 3.140625 + +query T +SELECT format_string('Scientific: %e', arrow_cast(3.14, 'Float16')); +---- +Scientific: 3.140625e+00 + +query T +SELECT format_string('Scientific: %E', arrow_cast(3.14, 'Float16')); +---- +Scientific: 3.140625E+00 + +query T +SELECT format_string('Compact: %g', arrow_cast(3.14, 'Float16')); +---- +Compact: 3.14063 + +query T +SELECT format_string('Compact: %G', arrow_cast(3.14, 'Float16')); +---- +Compact: 3.14063 + +query T +SELECT format_string('String: %s', arrow_cast(3.14, 'Float16')); +---- +String: 3.140625 + +query T +SELECT format_string('String: %S', arrow_cast(3.14, 'Float16')); +---- +String: 3.140625 + +query T +SELECT format_string('Hex float: %a', arrow_cast(3.14, 'Float16')); +---- +Hex float: 0x1.92p1 + +query T +SELECT format_string('Hex float: %A', arrow_cast(3.14, 'Float16')); +---- +Hex float: 0X1.92P1 + +## Float32 with various formats +query T +SELECT format_string('Float: %f', arrow_cast(3.14159, 'Float32')); +---- +Float: 3.141590 + +query T +SELECT format_string('Scientific: %e', arrow_cast(1234.5, 'Float32')); +---- +Scientific: 1.234500e+03 + +query T +SELECT format_string('Compact: %g', arrow_cast(1234.5, 'Float32')); +---- +Compact: 1234.5 + +query T +SELECT format_string('String: %s', arrow_cast(3.14159, 'Float32')); +---- +String: 3.14159 + +query T +SELECT format_string('Hex float: %a', arrow_cast(3.14, 'Float32')); +---- +Hex float: 0x1.91eb86p1 + +query T +SELECT format_string('Hex float: %A', arrow_cast(3.14, 'Float32')); +---- +Hex float: 0X1.91EB86P1 + +## Float64 with various formats + +query T +SELECT format_string('String: %s', arrow_cast(3.14159, 'Float64')); +---- +String: 3.14159 + +query T +SELECT format_string('String: %S', arrow_cast(3.14159, 'Float64')); +---- +String: 3.14159 + +## Decimal128 with various formats +query T +SELECT format_string('Float: %f', arrow_cast(123.456, 'Decimal128(10, 3)')); +---- +Float: 123.456000 + +query T +SELECT format_string('Scientific: %e', arrow_cast(1234.5, 'Decimal128(10, 2)')); +---- +Scientific: 1.234500e+03 + +query T +SELECT format_string('Scientific: %E', arrow_cast(1234.5, 'Decimal128(10, 2)')); +---- +Scientific: 1.234500E+03 + +query T +SELECT format_string('Compact: %g', arrow_cast(1234.5, 'Decimal128(10, 2)')); +---- +Compact: 1234.5 + +query T +SELECT format_string('Compact: %G', arrow_cast(1234.5, 'Decimal128(10, 2)')); +---- +Compact: 1234.5 + +query T +SELECT format_string('String: %s', arrow_cast(123.456, 'Decimal128(10, 3)')); +---- +String: 123456 + +query T +SELECT format_string('String: %S', arrow_cast(123.456, 'Decimal128(10, 3)')); +---- +String: 123456 + +## Decimal256 with various formats +query T +SELECT format_string('Float: %f', arrow_cast(123.456, 'Decimal256(20, 3)')); +---- +Float: 123.456000 + +query T +SELECT format_string('Scientific: %e', arrow_cast(1234.5, 'Decimal256(20, 2)')); +---- +Scientific: 1.234500e+03 + +query T +SELECT format_string('Compact: %g', arrow_cast(1234.5, 'Decimal256(20, 2)')); +---- +Compact: 1234.5 + +query T +SELECT format_string('String: %s', arrow_cast(123.456, 'Decimal256(20, 3)')); +---- +String: 123456 + +## Time32Second with time formats +query T +SELECT format_string('Hour: %tH', arrow_cast(52245::int, 'Time32(Second)')); +---- +Hour: 14 + +query T +SELECT format_string('Minute: %tM', arrow_cast(52245::int, 'Time32(Second)')); +---- +Minute: 30 + +query T +SELECT format_string('String: %s', arrow_cast(52245::int, 'Time32(Second)')); +---- +String: 52245 + +query T +SELECT format_string('String: %S', arrow_cast(52245::int, 'Time32(Second)')); +---- +String: 52245 + +## Time32Millisecond with time formats +query T +SELECT format_string('Hour: %tH', arrow_cast(52245000::int, 'Time32(Millisecond)')); +---- +Hour: 14 + +query T +SELECT format_string('Second: %tS', arrow_cast(52245000::int, 'Time32(Millisecond)')); +---- +Second: 45 + +query T +SELECT format_string('String: %s', arrow_cast(52245000::int, 'Time32(Millisecond)')); +---- +String: 52245000 + +## Time64Microsecond with time formats +query T +SELECT format_string('Hour: %tH', arrow_cast(52245000000, 'Time64(Microsecond)')); +---- +Hour: 14 + +query T +SELECT format_string('Time: %tT', arrow_cast(52245000000, 'Time64(Microsecond)')); +---- +Time: 14:30:45 + +query T +SELECT format_string('String: %s', arrow_cast(52245000000, 'Time64(Microsecond)')); +---- +String: 52245000000 + +## Time64Nanosecond with time formats +query T +SELECT format_string('Hour: %tH', arrow_cast(52245000000000, 'Time64(Nanosecond)')); +---- +Hour: 14 + +query T +SELECT format_string('AM/PM: %tp', arrow_cast(52245000000000, 'Time64(Nanosecond)')); +---- +AM/PM: pm + +query T +SELECT format_string('String: %s', arrow_cast(52245000000000, 'Time64(Nanosecond)')); +---- +String: 52245000000000 + +## TimestampSecond with time formats +query T +SELECT format_string('Year: %tY', arrow_cast(1703512245, 'Timestamp(Second, None)')); +---- +Year: 2023 + +query T +SELECT format_string('Month: %tm', arrow_cast(1703512245, 'Timestamp(Second, None)')); +---- +Month: 12 + +query T +SELECT format_string('String: %s', arrow_cast(1703512245, 'Timestamp(Second, None)')); +---- +String: 1703512245 + +query T +SELECT format_string('String: %S', arrow_cast(1703512245, 'Timestamp(Second, None)')); +---- +String: 1703512245 + +## TimestampMillisecond with time formats +query T +SELECT format_string('ISO Date: %tF', arrow_cast(1703512245000, 'Timestamp(Millisecond, None)')); +---- +ISO Date: 2023-12-25 + +query T +SELECT format_string('String: %s', arrow_cast(1703512245000, 'Timestamp(Millisecond, None)')); +---- +String: 1703512245000 + +## TimestampMicrosecond with time formats +query T +SELECT format_string('Date: %tD', arrow_cast(1703512245000000, 'Timestamp(Microsecond, None)')); +---- +Date: 12/25/23 + +query T +SELECT format_string('String: %s', arrow_cast(1703512245000000, 'Timestamp(Microsecond, None)')); +---- +String: 1703512245000000 + +query T +SELECT format_string('String: %s', arrow_cast('2020-01-02 01:01:11.1234567890Z', 'Timestamp(Nanosecond, None)')); +---- +String: 1577926871123456789 + +## Date32 with time formats +query T +SELECT format_string('Year: %tY', arrow_cast(19716, 'Date32')); +---- +Year: 2023 + +query T +SELECT format_string('Month: %tB', arrow_cast(19716, 'Date32')); +---- +Month: December + +query T +SELECT format_string('String: %s', arrow_cast(19716, 'Date32')); +---- +String: 19716 + +query T +SELECT format_string('String: %S', arrow_cast(19716, 'Date32')); +---- +String: 19716 + +## Date64 with time formats +query T +SELECT format_string('Year: %tY', arrow_cast(19716, 'Date64')); +---- +Year: 2023 + +query T +SELECT format_string('Month: %tB', arrow_cast(19716, 'Date64')); +---- +Month: December + +query T +SELECT format_string('String: %s', arrow_cast(19716, 'Date64')); +---- +String: 19716 + +query T +SELECT format_string('String: %S', arrow_cast(19716, 'Date64')); +---- +String: 19716 + +## Date64 with invalid ARGUMENT +statement error +SELECT format_string('String: %tY', true); + +# ================================ +# General formatting tests (%h, %H) +# ================================ + +# Not implemented yet. Can be implemented after https://github.com/apache/datafusion/pull/17093 is merged +## Hash value formatting (lowercase) +statement error +SELECT format_string('Hash: %h', 'test'); +# ---- +# Hash: ec06e15a + +## Hash value formatting (uppercase) +statement error +SELECT format_string('Hash: %H', 'test'); +# ---- +# Hash: EC06E15A + +## Hash with width +statement error +SELECT format_string('Hash: %10h', 'test'); +# ---- +# Hash: ec06e15a + +# ================================ +# Hexadecimal floating point tests +# ================================ + +## Hexadecimal float (lowercase) +query T +SELECT format_string('Hex float: %a', 15.9375); +---- +Hex float: 0x1.fep3 + +## Hexadecimal float (uppercase) +query T +SELECT format_string('Hex float: %A', 15.9375); +---- +Hex float: 0X1.FEP3 + +## Hexadecimal float with precision +query T +SELECT format_string('Hex float: %.10a', 15.9375); +---- +Hex float: 0x1.fe00000000p3 + +query T +SELECT format_string('%a', 12.3456); +---- +0x1.8b0f27bb2fec5p3 + +## Hexadecimal float with zero +query T +SELECT format_string('Hex float: %a', 0.0); +---- +Hex float: 0x0.0p0 + +## Hexadecimal float with negative value +query T +SELECT format_string('Hex float: %a', -15.9375); +---- +Hex float: -0x1.fep3 + +## Hexadecimal float with very small value +query T +SELECT format_string('Hex float: %a', 0.0000152587890625); +---- +Hex float: 0x1.0p-16 + +## Hexadecimal float with force sign +query T +SELECT format_string('Hex float: %+a', 15.9375); +---- +Hex float: +0x1.fep3 + +## Hexadecimal float with space sign (positive) +query T +SELECT format_string('Hex float: % a', 15.9375); +---- +Hex float: 0x1.fep3 + +## Hexadecimal float with space sign (negative) +query T +SELECT format_string('Hex float: % a', -15.9375); +---- +Hex float: -0x1.fep3 + +## Hexadecimal float uppercase with space sign +query T +SELECT format_string('Hex float: % A', 15.9375); +---- +Hex float: 0X1.FEP3 + +## Hexadecimal float with width +query T +SELECT format_string('Hex float: %20a', 15.9375); +---- +Hex float: 0x1.fep3 + +## Hexadecimal float with zero padding +query T +SELECT format_string('Hex float: %020a', 15.9375); +---- +Hex float: 0x0000000000001.fep3 + +## Hexadecimal float with alternate form and precision +query T +SELECT format_string('Hex float: %#.5a', 1.0); +---- +Hex float: 0x1.00000p0 + +## Hexadecimal float uppercase with force sign +query T +SELECT format_string('Hex float: %+A', -15.9375); +---- +Hex float: -0X1.FEP3 + +## Hexadecimal float with left alignment +query T +SELECT format_string('Hex float: %-20a', 15.9375); +---- +Hex float: 0x1.fep3 + +## Hexadecimal float with subnormal number (Float64) +query T +SELECT format_string('Hex float: %a', 2.2250738585072014e-308); +---- +Hex float: 0x1.0p-1022 + +## Hexadecimal float with smallest subnormal (Float64) +query T +SELECT format_string('Hex float: %a', 5.0e-324); +---- +Hex float: 0x0.0000000000001p-1022 + +## Hexadecimal float uppercase with subnormal +query T +SELECT format_string('Hex float: %A', 5.0e-324); +---- +Hex float: 0X0.0000000000001P-1022 + +## Hexadecimal float with subnormal and precision +query T +SELECT format_string('Hex float: %.20a', 2.2250738585072014e-308); +---- +Hex float: 0x1.00000000000000000000p-1022 + +## Hexadecimal float with negative subnormal +query T +SELECT format_string('Hex float: %a', -5.0e-324); +---- +Hex float: -0x0.0000000000001p-1022 + +## Hexadecimal float with subnormal and precision 5 +query T +SELECT format_string('Hex float: %.5a', 5.0e-324); +---- +Hex float: 0x1.00000p-1074 + +## Hexadecimal float with subnormal and precision 10 +query T +SELECT format_string('Hex float: %.10a', 5.0e-324); +---- +Hex float: 0x1.0000000000p-1074 + +## Hexadecimal float with subnormal and precision 13 (full) +query T +SELECT format_string('Hex float: %.13a', 5.0e-324); +---- +Hex float: 0x0.0000000000001p-1022 + +## Hexadecimal float with larger subnormal and precision +query T +SELECT format_string('Hex float: %.5a', 2.225e-308); +---- +Hex float: 0x1.fffbap-1023 + +## Hexadecimal float with subnormal and precision 0 +query T +SELECT format_string('Hex float: %.0a', 5.0e-324); +---- +Hex float: 0x1.0p-1074 + +query T +SELECT format_string('Hex float: %.2a', 5.0e-324); +---- +Hex float: 0x1.00p-1074 + + +query T +SELECT format_string('Hex float: %.2a', 5.0e-323); +---- +Hex float: 0x1.40p-1071 + +query T +SELECT format_string('Hex float: %.0a', 5.0e-323); +---- +Hex float: 0x1.4p-1071 + +# ================================ +# Relative indexing tests +# ================================ + +## Relative indexing with < +query T +SELECT format_string('%s %>', 'typeof(Hi there! Good morning.)': 'string'} +#query +#SELECT sentences('Hi there! Good morning.'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/soundex.slt b/datafusion/sqllogictest/test_files/spark/string/soundex.slt new file mode 100644 index 0000000000000..f0c46e10fd1de --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/soundex.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT soundex('Miller'); +## PySpark 3.5.5 Result: {'soundex(Miller)': 'M460', 'typeof(soundex(Miller))': 'string', 'typeof(Miller)': 'string'} +#query +#SELECT soundex('Miller'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/split_part.slt b/datafusion/sqllogictest/test_files/spark/string/split_part.slt new file mode 100644 index 0000000000000..0561a03ecf75d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/split_part.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT split_part('11.12.13', '.', 3); +## PySpark 3.5.5 Result: {'split_part(11.12.13, ., 3)': '13', 'typeof(split_part(11.12.13, ., 3))': 'string', 'typeof(11.12.13)': 'string', 'typeof(.)': 'string', 'typeof(3)': 'int'} +#query +#SELECT split_part('11.12.13'::string, '.'::string, 3::int); diff --git a/datafusion/sqllogictest/test_files/spark/string/startswith.slt b/datafusion/sqllogictest/test_files/spark/string/startswith.slt new file mode 100644 index 0000000000000..f75f9d080dfac --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/startswith.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT startswith('Spark SQL', 'SQL'); +## PySpark 3.5.5 Result: {'startswith(Spark SQL, SQL)': False, 'typeof(startswith(Spark SQL, SQL))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(SQL)': 'string'} +#query +#SELECT startswith('Spark SQL'::string, 'SQL'::string); + +## Original Query: SELECT startswith('Spark SQL', 'Spark'); +## PySpark 3.5.5 Result: {'startswith(Spark SQL, Spark)': True, 'typeof(startswith(Spark SQL, Spark))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(Spark)': 'string'} +#query +#SELECT startswith('Spark SQL'::string, 'Spark'::string); + +## Original Query: SELECT startswith('Spark SQL', null); +## PySpark 3.5.5 Result: {'startswith(Spark SQL, NULL)': None, 'typeof(startswith(Spark SQL, NULL))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(NULL)': 'void'} +#query +#SELECT startswith('Spark SQL'::string, NULL::void); + +## Original Query: SELECT startswith(x'537061726b2053514c', x'53514c'); +## PySpark 3.5.5 Result: {"startswith(X'537061726B2053514C', X'53514C')": False, "typeof(startswith(X'537061726B2053514C', X'53514C'))": 'boolean', "typeof(X'537061726B2053514C')": 'binary', "typeof(X'53514C')": 'binary'} +#query +#SELECT startswith(X'537061726B2053514C'::binary, X'53514C'::binary); + +## Original Query: SELECT startswith(x'537061726b2053514c', x'537061726b'); +## PySpark 3.5.5 Result: {"startswith(X'537061726B2053514C', X'537061726B')": True, "typeof(startswith(X'537061726B2053514C', X'537061726B'))": 'boolean', "typeof(X'537061726B2053514C')": 'binary', "typeof(X'537061726B')": 'binary'} +#query +#SELECT startswith(X'537061726B2053514C'::binary, X'537061726B'::binary); diff --git a/datafusion/sqllogictest/test_files/spark/string/substr.slt b/datafusion/sqllogictest/test_files/spark/string/substr.slt new file mode 100644 index 0000000000000..0942bdd86a4ef --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/substr.slt @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT substr('Spark SQL', -3); +## PySpark 3.5.5 Result: {'substr(Spark SQL, -3, 2147483647)': 'SQL', 'typeof(substr(Spark SQL, -3, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(-3)': 'int'} +#query +#SELECT substr('Spark SQL'::string, -3::int); + +## Original Query: SELECT substr('Spark SQL', 5); +## PySpark 3.5.5 Result: {'substr(Spark SQL, 5, 2147483647)': 'k SQL', 'typeof(substr(Spark SQL, 5, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int'} +#query +#SELECT substr('Spark SQL'::string, 5::int); + +## Original Query: SELECT substr('Spark SQL', 5, 1); +## PySpark 3.5.5 Result: {'substr(Spark SQL, 5, 1)': 'k', 'typeof(substr(Spark SQL, 5, 1))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int', 'typeof(1)': 'int'} +#query +#SELECT substr('Spark SQL'::string, 5::int, 1::int); diff --git a/datafusion/sqllogictest/test_files/spark/string/substring.slt b/datafusion/sqllogictest/test_files/spark/string/substring.slt new file mode 100644 index 0000000000000..847ce4b6d4739 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/substring.slt @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT substring('Spark SQL', -3); +## PySpark 3.5.5 Result: {'substring(Spark SQL, -3, 2147483647)': 'SQL', 'typeof(substring(Spark SQL, -3, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(-3)': 'int'} +#query +#SELECT substring('Spark SQL'::string, -3::int); + +## Original Query: SELECT substring('Spark SQL', 5); +## PySpark 3.5.5 Result: {'substring(Spark SQL, 5, 2147483647)': 'k SQL', 'typeof(substring(Spark SQL, 5, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int'} +#query +#SELECT substring('Spark SQL'::string, 5::int); + +## Original Query: SELECT substring('Spark SQL', 5, 1); +## PySpark 3.5.5 Result: {'substring(Spark SQL, 5, 1)': 'k', 'typeof(substring(Spark SQL, 5, 1))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int', 'typeof(1)': 'int'} +#query +#SELECT substring('Spark SQL'::string, 5::int, 1::int); diff --git a/datafusion/sqllogictest/test_files/spark/string/substring_index.slt b/datafusion/sqllogictest/test_files/spark/string/substring_index.slt new file mode 100644 index 0000000000000..b434d9fa5edc4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/substring_index.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT substring_index('www.apache.org', '.', 2); +## PySpark 3.5.5 Result: {'substring_index(www.apache.org, ., 2)': 'www.apache', 'typeof(substring_index(www.apache.org, ., 2))': 'string', 'typeof(www.apache.org)': 'string', 'typeof(.)': 'string', 'typeof(2)': 'int'} +#query +#SELECT substring_index('www.apache.org'::string, '.'::string, 2::int); diff --git a/datafusion/sqllogictest/test_files/spark/string/to_binary.slt b/datafusion/sqllogictest/test_files/spark/string/to_binary.slt new file mode 100644 index 0000000000000..d8efa323f2c52 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/to_binary.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_binary('abc', 'utf-8'); +## PySpark 3.5.5 Result: {'to_binary(abc, utf-8)': bytearray(b'abc'), 'typeof(to_binary(abc, utf-8))': 'binary', 'typeof(abc)': 'string', 'typeof(utf-8)': 'string'} +#query +#SELECT to_binary('abc'::string, 'utf-8'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/to_char.slt b/datafusion/sqllogictest/test_files/spark/string/to_char.slt new file mode 100644 index 0000000000000..88d88bbb8ad9f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/to_char.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_char(-12454.8, '99G999D9S'); +## PySpark 3.5.5 Result: {'to_char(-12454.8, 99G999D9S)': '12,454.8-', 'typeof(to_char(-12454.8, 99G999D9S))': 'string', 'typeof(-12454.8)': 'decimal(6,1)', 'typeof(99G999D9S)': 'string'} +#query +#SELECT to_char(-12454.8::decimal(6,1), '99G999D9S'::string); + +## Original Query: SELECT to_char(12454, '99G999'); +## PySpark 3.5.5 Result: {'to_char(12454, 99G999)': '12,454', 'typeof(to_char(12454, 99G999))': 'string', 'typeof(12454)': 'int', 'typeof(99G999)': 'string'} +#query +#SELECT to_char(12454::int, '99G999'::string); + +## Original Query: SELECT to_char(454, '999'); +## PySpark 3.5.5 Result: {'to_char(454, 999)': '454', 'typeof(to_char(454, 999))': 'string', 'typeof(454)': 'int', 'typeof(999)': 'string'} +#query +#SELECT to_char(454::int, '999'::string); + +## Original Query: SELECT to_char(454.00, '000D00'); +## PySpark 3.5.5 Result: {'to_char(454.00, 000D00)': '454.00', 'typeof(to_char(454.00, 000D00))': 'string', 'typeof(454.00)': 'decimal(5,2)', 'typeof(000D00)': 'string'} +#query +#SELECT to_char(454.00::decimal(5,2), '000D00'::string); + +## Original Query: SELECT to_char(78.12, '$99.99'); +## PySpark 3.5.5 Result: {'to_char(78.12, $99.99)': '$78.12', 'typeof(to_char(78.12, $99.99))': 'string', 'typeof(78.12)': 'decimal(4,2)', 'typeof($99.99)': 'string'} +#query +#SELECT to_char(78.12::decimal(4,2), '$99.99'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/to_number.slt b/datafusion/sqllogictest/test_files/spark/string/to_number.slt new file mode 100644 index 0000000000000..ffbee15aca4d2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/to_number.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_number('$78.12', '$99.99'); +## PySpark 3.5.5 Result: {'to_number($78.12, $99.99)': Decimal('78.12'), 'typeof(to_number($78.12, $99.99))': 'decimal(4,2)', 'typeof($78.12)': 'string', 'typeof($99.99)': 'string'} +#query +#SELECT to_number('$78.12'::string, '$99.99'::string); + +## Original Query: SELECT to_number('12,454', '99,999'); +## PySpark 3.5.5 Result: {'to_number(12,454, 99,999)': Decimal('12454'), 'typeof(to_number(12,454, 99,999))': 'decimal(5,0)', 'typeof(12,454)': 'string', 'typeof(99,999)': 'string'} +#query +#SELECT to_number('12,454'::string, '99,999'::string); + +## Original Query: SELECT to_number('12,454.8-', '99,999.9S'); +## PySpark 3.5.5 Result: {'to_number(12,454.8-, 99,999.9S)': Decimal('-12454.8'), 'typeof(to_number(12,454.8-, 99,999.9S))': 'decimal(6,1)', 'typeof(12,454.8-)': 'string', 'typeof(99,999.9S)': 'string'} +#query +#SELECT to_number('12,454.8-'::string, '99,999.9S'::string); + +## Original Query: SELECT to_number('454', '999'); +## PySpark 3.5.5 Result: {'to_number(454, 999)': Decimal('454'), 'typeof(to_number(454, 999))': 'decimal(3,0)', 'typeof(454)': 'string', 'typeof(999)': 'string'} +#query +#SELECT to_number('454'::string, '999'::string); + +## Original Query: SELECT to_number('454.00', '000.00'); +## PySpark 3.5.5 Result: {'to_number(454.00, 000.00)': Decimal('454.00'), 'typeof(to_number(454.00, 000.00))': 'decimal(5,2)', 'typeof(454.00)': 'string', 'typeof(000.00)': 'string'} +#query +#SELECT to_number('454.00'::string, '000.00'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/to_varchar.slt b/datafusion/sqllogictest/test_files/spark/string/to_varchar.slt new file mode 100644 index 0000000000000..51662b89e5580 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/to_varchar.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_varchar(-12454.8, '99G999D9S'); +## PySpark 3.5.5 Result: {'to_char(-12454.8, 99G999D9S)': '12,454.8-', 'typeof(to_char(-12454.8, 99G999D9S))': 'string', 'typeof(-12454.8)': 'decimal(6,1)', 'typeof(99G999D9S)': 'string'} +#query +#SELECT to_varchar(-12454.8::decimal(6,1), '99G999D9S'::string); + +## Original Query: SELECT to_varchar(12454, '99G999'); +## PySpark 3.5.5 Result: {'to_char(12454, 99G999)': '12,454', 'typeof(to_char(12454, 99G999))': 'string', 'typeof(12454)': 'int', 'typeof(99G999)': 'string'} +#query +#SELECT to_varchar(12454::int, '99G999'::string); + +## Original Query: SELECT to_varchar(454, '999'); +## PySpark 3.5.5 Result: {'to_char(454, 999)': '454', 'typeof(to_char(454, 999))': 'string', 'typeof(454)': 'int', 'typeof(999)': 'string'} +#query +#SELECT to_varchar(454::int, '999'::string); + +## Original Query: SELECT to_varchar(454.00, '000D00'); +## PySpark 3.5.5 Result: {'to_char(454.00, 000D00)': '454.00', 'typeof(to_char(454.00, 000D00))': 'string', 'typeof(454.00)': 'decimal(5,2)', 'typeof(000D00)': 'string'} +#query +#SELECT to_varchar(454.00::decimal(5,2), '000D00'::string); + +## Original Query: SELECT to_varchar(78.12, '$99.99'); +## PySpark 3.5.5 Result: {'to_char(78.12, $99.99)': '$78.12', 'typeof(to_char(78.12, $99.99))': 'string', 'typeof(78.12)': 'decimal(4,2)', 'typeof($99.99)': 'string'} +#query +#SELECT to_varchar(78.12::decimal(4,2), '$99.99'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/translate.slt b/datafusion/sqllogictest/test_files/spark/string/translate.slt new file mode 100644 index 0000000000000..53ea41a7ac31e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/translate.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT translate('AaBbCc', 'abc', '123'); +## PySpark 3.5.5 Result: {'translate(AaBbCc, abc, 123)': 'A1B2C3', 'typeof(translate(AaBbCc, abc, 123))': 'string', 'typeof(AaBbCc)': 'string', 'typeof(abc)': 'string', 'typeof(123)': 'string'} +#query +#SELECT translate('AaBbCc'::string, 'abc'::string, '123'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/trim.slt b/datafusion/sqllogictest/test_files/spark/string/trim.slt new file mode 100644 index 0000000000000..725bab5e69623 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/trim.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT trim(' SparkSQL '); +## PySpark 3.5.5 Result: {'trim( SparkSQL )': 'SparkSQL', 'typeof(trim( SparkSQL ))': 'string', 'typeof( SparkSQL )': 'string'} +#query +#SELECT trim(' SparkSQL '::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/try_to_binary.slt b/datafusion/sqllogictest/test_files/spark/string/try_to_binary.slt new file mode 100644 index 0000000000000..211520be1e48b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/try_to_binary.slt @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT try_to_binary('abc', 'utf-8'); +## PySpark 3.5.5 Result: {'try_to_binary(abc, utf-8)': bytearray(b'abc'), 'typeof(try_to_binary(abc, utf-8))': 'binary', 'typeof(abc)': 'string', 'typeof(utf-8)': 'string'} +#query +#SELECT try_to_binary('abc'::string, 'utf-8'::string); + +## Original Query: select try_to_binary('a!', 'base64'); +## PySpark 3.5.5 Result: {'try_to_binary(a!, base64)': None, 'typeof(try_to_binary(a!, base64))': 'binary', 'typeof(a!)': 'string', 'typeof(base64)': 'string'} +#query +#SELECT try_to_binary('a!'::string, 'base64'::string); + +## Original Query: select try_to_binary('abc', 'invalidFormat'); +## PySpark 3.5.5 Result: {'try_to_binary(abc, invalidFormat)': None, 'typeof(try_to_binary(abc, invalidFormat))': 'binary', 'typeof(abc)': 'string', 'typeof(invalidFormat)': 'string'} +#query +#SELECT try_to_binary('abc'::string, 'invalidFormat'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/try_to_number.slt b/datafusion/sqllogictest/test_files/spark/string/try_to_number.slt new file mode 100644 index 0000000000000..10be9e2180be8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/try_to_number.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT try_to_number('$78.12', '$99.99'); +## PySpark 3.5.5 Result: {'try_to_number($78.12, $99.99)': Decimal('78.12'), 'typeof(try_to_number($78.12, $99.99))': 'decimal(4,2)', 'typeof($78.12)': 'string', 'typeof($99.99)': 'string'} +#query +#SELECT try_to_number('$78.12'::string, '$99.99'::string); + +## Original Query: SELECT try_to_number('12,454', '99,999'); +## PySpark 3.5.5 Result: {'try_to_number(12,454, 99,999)': Decimal('12454'), 'typeof(try_to_number(12,454, 99,999))': 'decimal(5,0)', 'typeof(12,454)': 'string', 'typeof(99,999)': 'string'} +#query +#SELECT try_to_number('12,454'::string, '99,999'::string); + +## Original Query: SELECT try_to_number('12,454.8-', '99,999.9S'); +## PySpark 3.5.5 Result: {'try_to_number(12,454.8-, 99,999.9S)': Decimal('-12454.8'), 'typeof(try_to_number(12,454.8-, 99,999.9S))': 'decimal(6,1)', 'typeof(12,454.8-)': 'string', 'typeof(99,999.9S)': 'string'} +#query +#SELECT try_to_number('12,454.8-'::string, '99,999.9S'::string); + +## Original Query: SELECT try_to_number('454', '999'); +## PySpark 3.5.5 Result: {'try_to_number(454, 999)': Decimal('454'), 'typeof(try_to_number(454, 999))': 'decimal(3,0)', 'typeof(454)': 'string', 'typeof(999)': 'string'} +#query +#SELECT try_to_number('454'::string, '999'::string); + +## Original Query: SELECT try_to_number('454.00', '000.00'); +## PySpark 3.5.5 Result: {'try_to_number(454.00, 000.00)': Decimal('454.00'), 'typeof(try_to_number(454.00, 000.00))': 'decimal(5,2)', 'typeof(454.00)': 'string', 'typeof(000.00)': 'string'} +#query +#SELECT try_to_number('454.00'::string, '000.00'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/ucase.slt b/datafusion/sqllogictest/test_files/spark/string/ucase.slt new file mode 100644 index 0000000000000..00860c697399e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/ucase.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT ucase('SparkSql'); +## PySpark 3.5.5 Result: {'ucase(SparkSql)': 'SPARKSQL', 'typeof(ucase(SparkSql))': 'string', 'typeof(SparkSql)': 'string'} +#query +#SELECT ucase('SparkSql'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/unbase64.slt b/datafusion/sqllogictest/test_files/spark/string/unbase64.slt new file mode 100644 index 0000000000000..5cf3fbee0455d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/unbase64.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT unbase64('U3BhcmsgU1FM'); +## PySpark 3.5.5 Result: {'unbase64(U3BhcmsgU1FM)': bytearray(b'Spark SQL'), 'typeof(unbase64(U3BhcmsgU1FM))': 'binary', 'typeof(U3BhcmsgU1FM)': 'string'} +#query +#SELECT unbase64('U3BhcmsgU1FM'::string); diff --git a/datafusion/sqllogictest/test_files/spark/string/upper.slt b/datafusion/sqllogictest/test_files/spark/string/upper.slt new file mode 100644 index 0000000000000..91c92940332a7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/upper.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT upper('SparkSql'); +## PySpark 3.5.5 Result: {'upper(SparkSql)': 'SPARKSQL', 'typeof(upper(SparkSql))': 'string', 'typeof(SparkSql)': 'string'} +#query +#SELECT upper('SparkSql'::string); diff --git a/datafusion/sqllogictest/test_files/spark/struct/named_struct.slt b/datafusion/sqllogictest/test_files/spark/struct/named_struct.slt new file mode 100644 index 0000000000000..83b24f6d041f2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/struct/named_struct.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT named_struct("a", 1, "b", 2, "c", 3); +## PySpark 3.5.5 Result: {'named_struct(a, 1, b, 2, c, 3)': Row(a=1, b=2, c=3), 'typeof(named_struct(a, 1, b, 2, c, 3))': 'struct', 'typeof(a)': 'string', 'typeof(1)': 'int', 'typeof(b)': 'string', 'typeof(2)': 'int', 'typeof(c)': 'string', 'typeof(3)': 'int'} +#query +#SELECT named_struct('a'::string, 1::int, 'b'::string, 2::int, 'c'::string, 3::int); diff --git a/datafusion/sqllogictest/test_files/spark/struct/struct.slt b/datafusion/sqllogictest/test_files/spark/struct/struct.slt new file mode 100644 index 0000000000000..fe23e249701f5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/struct/struct.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT struct(1, 2, 3); +## PySpark 3.5.5 Result: {'struct(1, 2, 3)': Row(col1=1, col2=2, col3=3), 'typeof(struct(1, 2, 3))': 'struct', 'typeof(1)': 'int', 'typeof(2)': 'int', 'typeof(3)': 'int'} +#query +#SELECT struct(1::int, 2::int, 3::int); diff --git a/datafusion/sqllogictest/test_files/spark/url/parse_url.slt b/datafusion/sqllogictest/test_files/spark/url/parse_url.slt new file mode 100644 index 0000000000000..f2dc55f75598a --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/url/parse_url.slt @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query T +SELECT parse_url('http://spark.apache.org/path?query=1'::string, 'HOST'::string); +---- +spark.apache.org + +query T +SELECT parse_url('http://spark.apache.org/path?query=1'::string, 'QUERY'::string); +---- +query=1 + +query T +SELECT parse_url('http://spark.apache.org/path?query=1'::string, 'QUERY'::string, 'query'::string); +---- +1 + +query T +SELECT parse_url('http://userinfo@spark.apache.org/path?query=1#Ref'::string, 'HOST'::string); +---- +spark.apache.org + +query T +SELECT parse_url('http://userinfo@spark.apache.org/path?query=1#Ref'::string, 'PATH'::string); +---- +/path + +query T +SELECT parse_url('http://userinfo@spark.apache.org/path?query=1#Ref'::string, 'QUERY'::string); +---- +query=1 + +query T +SELECT parse_url('http://userinfo@spark.apache.org/path?query=1#Ref'::string, 'REF'::string); +---- +Ref + +query T +SELECT parse_url('http://userinfo@spark.apache.org/path?query=1#Ref'::string, 'PROTOCOL'::string); +---- +http + +query T +SELECT parse_url('http://userinfo@spark.apache.org/path?query=1#Ref'::string, 'FILE'::string); +---- +/path?query=1 + +query T +SELECT parse_url('http://userinfo@spark.apache.org/path?query=1#Ref'::string, 'AUTHORITY'::string); +---- +userinfo@spark.apache.org + +query T +SELECT parse_url('http://userinfo@spark.apache.org/path?query=1#Ref'::string, 'USERINFO'::string); +---- +userinfo + +query T +SELECT parse_url('https://example.com/a?x=1', 'QUERY', 'x'); +---- +1 + +query T +SELECT parse_url('https://example.com/a?x=1', 'query', 'x'); +---- +NULL + +query T +SELECT parse_url('www.example.com/path?x=1', 'HOST'); +---- +NULL + +query T +SELECT parse_url('www.example.com/path?x=1', 'host'); +---- +NULL + +query T +SELECT parse_url('https://example.com/?a=1', 'QUERY', 'b'); +---- +NULL + +query T +SELECT parse_url('https://example.com/?a=1', 'query', 'b'); +---- +NULL + +query T +SELECT parse_url('https://example.com/path#frag', 'REF'); +---- +frag + +query T +SELECT parse_url('https://example.com/path#frag', 'ref'); +---- +NULL + +query T +SELECT parse_url('ftp://user:pwd@ftp.example.com:21/files', 'USERINFO'); +---- +user:pwd + +query T +SELECT parse_url('ftp://user:pwd@ftp.example.com:21/files', 'userinfo'); +---- +NULL + +query T +SELECT parse_url('http://[2001:db8::2]:8080/index.html?ok=1', 'HOST'); +---- +[2001:db8::2] + +query T +SELECT parse_url('http://[2001:db8::2]:8080/index.html?ok=1', 'host'); +---- +NULL + +query T +SELECT parse_url('notaurl', 'HOST'); +---- +NULL + +query T +SELECT parse_url('notaurl', 'host'); +---- +NULL + +query T +SELECT parse_url('https://example.com', 'PATH'); +---- +(empty) + +query T +SELECT parse_url('https://example.com', 'path'); +---- +NULL + +query T +SELECT parse_url('https://example.com/a/b?x=1&y=2#frag', 'PROTOCOL'); +---- +https + +query T +SELECT parse_url('https://example.com/a/b?x=1&y=2#frag', 'protocol'); +---- +NULL + +query T +SELECT parse_url('https://ex.com/?Tag=ok', 'QUERY', 'tag'); +---- +NULL + +query T +SELECT parse_url('https://ex.com/?Tag=ok', 'query', 'tag'); +---- +NULL + +statement error 'parse_url' does not support zero arguments +SELECT parse_url(); + +query error DataFusion error: Execution error: The url is invalid: inva lid://spark\.apache\.org/path\?query=1\. Use `try_parse_url` to tolerate invalid URL and return NULL instead\. SQLSTATE: 22P02 +SELECT parse_url('inva lid://spark.apache.org/path?query=1', 'QUERY'); diff --git a/datafusion/sqllogictest/test_files/spark/url/try_parse_url.slt b/datafusion/sqllogictest/test_files/spark/url/try_parse_url.slt new file mode 100644 index 0000000000000..403747c63c77c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/url/try_parse_url.slt @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/blob/b6095cc7fccaf016b47f009ba93b2357dc781a7d/python/pysail/tests/spark/function/test_try_parse_url.txt +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +query T +SELECT try_parse_url('https://example.com/a?x=1', 'QUERY', 'x'); +---- +1 + +query T +SELECT try_parse_url('https://example.com/a?x=1', 'query', 'x'); +---- +NULL + +query T +SELECT try_parse_url('www.example.com/path?x=1', 'HOST'); +---- +NULL + +query T +SELECT try_parse_url('www.example.com/path?x=1', 'host'); +---- +NULL + +query T +SELECT try_parse_url('https://example.com/?a=1', 'QUERY', 'b'); +---- +NULL + +query T +SELECT try_parse_url('https://example.com/?a=1', 'query', 'b'); +---- +NULL + +query T +SELECT try_parse_url('https://example.com/path#frag', 'REF'); +---- +frag + +query T +SELECT try_parse_url('https://example.com/path#frag', 'ref'); +---- +NULL + +query T +SELECT try_parse_url('ftp://user:pwd@ftp.example.com:21/files', 'USERINFO'); +---- +user:pwd + +query T +SELECT try_parse_url('ftp://user:pwd@ftp.example.com:21/files', 'userinfo'); +---- +NULL + +query T +SELECT try_parse_url('http://[2001:db8::2]:8080/index.html?ok=1', 'HOST'); +---- +[2001:db8::2] + +query T +SELECT try_parse_url('http://[2001:db8::2]:8080/index.html?ok=1', 'host'); +---- +NULL + +query T +SELECT try_parse_url('notaurl', 'HOST'); +---- +NULL + +query T +SELECT try_parse_url('notaurl', 'host'); +---- +NULL + +query T +SELECT try_parse_url('https://example.com', 'PATH'); +---- +(empty) + +query T +SELECT try_parse_url('https://example.com', 'path'); +---- +NULL + +query T +SELECT try_parse_url('https://example.com/a/b?x=1&y=2#frag', 'PROTOCOL'); +---- +https + +query T +SELECT try_parse_url('https://example.com/a/b?x=1&y=2#frag', 'protocol'); +---- +NULL + +query T +SELECT try_parse_url('https://ex.com/?Tag=ok', 'QUERY', 'tag'); +---- +NULL + +query T +SELECT try_parse_url('https://ex.com/?Tag=ok', 'query', 'tag'); +---- +NULL + +query T +SELECT try_parse_url('inva lid://spark.apache.org/path?query=1', 'QUERY'); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/url/url_decode.slt b/datafusion/sqllogictest/test_files/spark/url/url_decode.slt new file mode 100644 index 0000000000000..fa5028b647dc3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/url/url_decode.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT url_decode('https%3A%2F%2Fspark.apache.org'); +## PySpark 3.5.5 Result: {'url_decode(https%3A%2F%2Fspark.apache.org)': 'https://spark.apache.org', 'typeof(url_decode(https%3A%2F%2Fspark.apache.org))': 'string', 'typeof(https%3A%2F%2Fspark.apache.org)': 'string'} +#query +#SELECT url_decode('https%3A%2F%2Fspark.apache.org'::string); diff --git a/datafusion/sqllogictest/test_files/spark/url/url_encode.slt b/datafusion/sqllogictest/test_files/spark/url/url_encode.slt new file mode 100644 index 0000000000000..6aef87dcb4c0f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/url/url_encode.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT url_encode('https://spark.apache.org'); +## PySpark 3.5.5 Result: {'url_encode(https://spark.apache.org)': 'https%3A%2F%2Fspark.apache.org', 'typeof(url_encode(https://spark.apache.org))': 'string', 'typeof(https://spark.apache.org)': 'string'} +#query +#SELECT url_encode('https://spark.apache.org'::string); diff --git a/datafusion/sqllogictest/test_files/spark/xml/xpath.slt b/datafusion/sqllogictest/test_files/spark/xml/xpath.slt new file mode 100644 index 0000000000000..d1ff9239216c9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/xml/xpath.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT xpath('b1b2b3c1c2','a/b'); +## PySpark 3.5.5 Result: {'xpath(b1b2b3c1c2, a/b)': [None, None, None], 'typeof(xpath(b1b2b3c1c2, a/b))': 'array', 'typeof(b1b2b3c1c2)': 'string', 'typeof(a/b)': 'string'} +#query +#SELECT xpath('b1b2b3c1c2'::string, 'a/b'::string); diff --git a/datafusion/sqllogictest/test_files/spark/xml/xpath_boolean.slt b/datafusion/sqllogictest/test_files/spark/xml/xpath_boolean.slt new file mode 100644 index 0000000000000..8a5dc693eb893 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/xml/xpath_boolean.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT xpath_boolean('1','a/b'); +## PySpark 3.5.5 Result: {'xpath_boolean(1, a/b)': True, 'typeof(xpath_boolean(1, a/b))': 'boolean', 'typeof(1)': 'string', 'typeof(a/b)': 'string'} +#query +#SELECT xpath_boolean('1'::string, 'a/b'::string); diff --git a/datafusion/sqllogictest/test_files/spark/xml/xpath_string.slt b/datafusion/sqllogictest/test_files/spark/xml/xpath_string.slt new file mode 100644 index 0000000000000..cfabf467edfaa --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/xml/xpath_string.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT xpath_string('bcc','a/c'); +## PySpark 3.5.5 Result: {'xpath_string(bcc, a/c)': 'cc', 'typeof(xpath_string(bcc, a/c))': 'string', 'typeof(bcc)': 'string', 'typeof(a/c)': 'string'} +#query +#SELECT xpath_string('bcc'::string, 'a/c'::string); diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt b/datafusion/sqllogictest/test_files/string/string_literal.slt index 79b783f89a614..f602dbb54b081 100644 --- a/datafusion/sqllogictest/test_files/string/string_literal.slt +++ b/datafusion/sqllogictest/test_files/string/string_literal.slt @@ -303,6 +303,26 @@ SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'x ---- fooxx +query T +SELECT regexp_replace(arrow_cast('foobar', 'LargeUtf8'), 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT regexp_replace(arrow_cast('foobar', 'Utf8View'), 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT regexp_replace('foobar', arrow_cast('bar', 'LargeUtf8'), 'xx', 'gi') +---- +fooxx + +query T +SELECT regexp_replace('foobar', arrow_cast('bar', 'Utf8View'), 'xx', 'gi') +---- +fooxx + query T SELECT repeat('foo', 3) ---- diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index a72c8f5744849..fb67daa0b8405 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -784,7 +784,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: regexp_like(test.column1_utf8view, Utf8("^https?://(?:www\.)?([^/]+)/.*$")) AS k +01)Projection: test.column1_utf8view ~ Utf8View("^https?://(?:www\.)?([^/]+)/.*$") AS k 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for REGEXP_MATCH @@ -804,7 +804,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: regexp_replace(test.column1_utf8view, Utf8("^https?://(?:www\.)?([^/]+)/.*$"), Utf8("\1")) AS k +01)Projection: regexp_replace(test.column1_utf8view, Utf8View("^https?://(?:www\.)?([^/]+)/.*$"), Utf8View("\1")) AS k 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for REPEAT diff --git a/datafusion/sqllogictest/test_files/strings.slt b/datafusion/sqllogictest/test_files/strings.slt index 81b8f4b2da9a1..9fa453fa02523 100644 --- a/datafusion/sqllogictest/test_files/strings.slt +++ b/datafusion/sqllogictest/test_files/strings.slt @@ -115,6 +115,12 @@ p1 p1e1 p1m1e1 +query T rowsort +SELECT s FROM test WHERE s ILIKE 'p1'; +---- +P1 +p1 + # NOT ILIKE query T rowsort SELECT s FROM test WHERE s NOT ILIKE 'p1%'; diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index bdba738761034..95eeffc31903f 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -53,9 +53,9 @@ select * from struct_values; query TT select arrow_typeof(s1), arrow_typeof(s2) from struct_values; ---- -Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(c0 Int32) Struct(a Int32, b Utf8View) +Struct(c0 Int32) Struct(a Int32, b Utf8View) +Struct(c0 Int32) Struct(a Int32, b Utf8View) # struct[i] @@ -229,12 +229,12 @@ select named_struct('field_a', 1, 'field_b', 2); query T select arrow_typeof(named_struct('first', 1, 'second', 2, 'third', 3)); ---- -Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(first Int64, second Int64, third Int64) query T select arrow_typeof({'first': 1, 'second': 2, 'third': 3}); ---- -Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(first Int64, second Int64, third Int64) # test nested struct literal query ? @@ -271,12 +271,33 @@ select a from values where (a, c) = (1, 'a'); ---- 1 +query I +select a from values as v where (v.a, v.c) = (1, 'a'); +---- +1 + +query I +select a from values as v where (v.a, v.c) != (1, 'a'); +---- +2 +3 + +query I +select a from values as v where (v.a, v.c) = (1, 'b'); +---- + query I select a from values where (a, c) IN ((1, 'a'), (2, 'b')); ---- 1 2 +query I +select a from values as v where (v.a, v.c) IN ((1, 'a'), (2, 'b')); +---- +1 +2 + statement ok drop table values; @@ -392,7 +413,7 @@ create table t(a struct, b struct) as valu query T select arrow_typeof([a, b]) from t; ---- -List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) query ? select [a, b] from t; @@ -443,12 +464,12 @@ select * from t; query T select arrow_typeof(c1) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8View, b Int32) query T select arrow_typeof(c2) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8View, b Float32) statement ok drop table t; @@ -465,8 +486,8 @@ select * from t; query T select arrow_typeof(column1) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8, c Float64) +Struct(r Utf8, c Float64) statement ok drop table t; @@ -498,9 +519,9 @@ select coalesce(s1) from t; query T select arrow_typeof(coalesce(s1, s2)) from t; ---- -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(a Float32, b Utf8View) +Struct(a Float32, b Utf8View) +Struct(a Float32, b Utf8View) statement ok drop table t; @@ -525,9 +546,9 @@ select coalesce(s1, s2) from t; query T select arrow_typeof(coalesce(s1, s2)) from t; ---- -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(a Float32, b Utf8View) +Struct(a Float32, b Utf8View) +Struct(a Float32, b Utf8View) statement ok drop table t; @@ -562,7 +583,7 @@ create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as valu query T select arrow_typeof([a, b]) from t; ---- -List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) statement ok drop table t; @@ -585,13 +606,13 @@ create table t(a struct(r varchar, c int, g float), b struct(r varchar, c float, query T select arrow_typeof(a) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "g", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8View, c Int32, g Float32) # type of each column should not coerced but perserve as it is query T select arrow_typeof(b) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "g", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8View, c Float32, g Int32) statement ok drop table t; diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index aaccaaa43ce49..dec9357495356 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -210,9 +210,9 @@ physical_plan 08)--------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 09)----------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int)] 10)------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -11)--------------------DataSourceExec: partitions=1, partition_sizes=[1] +11)--------------------DataSourceExec: partitions=1, partition_sizes=[2] 12)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -13)--------DataSourceExec: partitions=1, partition_sizes=[1] +13)--------DataSourceExec: partitions=1, partition_sizes=[2] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 @@ -245,9 +245,9 @@ physical_plan 08)--------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 09)----------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int * Float64(1))] 10)------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -11)--------------------DataSourceExec: partitions=1, partition_sizes=[1] +11)--------------------DataSourceExec: partitions=1, partition_sizes=[2] 12)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -13)--------DataSourceExec: partitions=1, partition_sizes=[1] +13)--------DataSourceExec: partitions=1, partition_sizes=[2] query IR rowsort SELECT t1_id, (SELECT sum(t2_int * 1.0) + 1 FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 @@ -280,9 +280,9 @@ physical_plan 08)--------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 09)----------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int)] 10)------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -11)--------------------DataSourceExec: partitions=1, partition_sizes=[1] +11)--------------------DataSourceExec: partitions=1, partition_sizes=[2] 12)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -13)--------DataSourceExec: partitions=1, partition_sizes=[1] +13)--------DataSourceExec: partitions=1, partition_sizes=[2] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id group by t2_id, 'a') as t2_sum from t1 @@ -318,9 +318,9 @@ physical_plan 10)------------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 11)--------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int)] 12)----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -13)------------------------DataSourceExec: partitions=1, partition_sizes=[1] +13)------------------------DataSourceExec: partitions=1, partition_sizes=[2] 14)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -15)--------DataSourceExec: partitions=1, partition_sizes=[1] +15)--------DataSourceExec: partitions=1, partition_sizes=[2] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id having sum(t2_int) < 3) as t2_sum from t1 @@ -400,7 +400,7 @@ logical_plan 01)LeftSemi Join: 02)--TableScan: t1 projection=[t1_id, t1_name, t1_int] 03)--SubqueryAlias: __correlated_sq_1 -04)----Projection: +04)----Projection: 05)------Filter: t1.t1_int < t1.t1_id 06)--------TableScan: t1 projection=[t1_id, t1_int] @@ -499,7 +499,7 @@ logical_plan 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 05)------Projection: Int64(1) AS a -06)--------EmptyRelation +06)--------EmptyRelation: rows=1 query II rowsort SELECT t1_id, (SELECT a FROM (select 1 as a) WHERE a = t1.t1_int) as t2_int from t1 @@ -619,7 +619,7 @@ logical_plan 01)LeftSemi Join: 02)--TableScan: t1 projection=[t1_id, t1_name] 03)--SubqueryAlias: __correlated_sq_1 -04)----EmptyRelation +04)----EmptyRelation: rows=1 #exists_subquery_with_limit #de-correlated, limit is removed @@ -644,7 +644,7 @@ SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id query TT explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0) ---- -logical_plan EmptyRelation +logical_plan EmptyRelation: rows=0 query IT rowsort SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0) @@ -701,7 +701,7 @@ logical_plan 01)Projection: t1.t1_id, __scalar_sq_1.t2_id AS t2_id 02)--Left Join: 03)----TableScan: t1 projection=[t1_id] -04)----EmptyRelation +04)----EmptyRelation: rows=0 query II rowsort SELECT t1_id, (SELECT t2_id FROM t2 limit 0) FROM t1 @@ -921,7 +921,7 @@ query TT explain SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) = 0) from t1 ---- logical_plan -01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN __scalar_sq_1.count(Int64(1)) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2 +01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN __scalar_sq_1.count(Int64(1)) != Int64(0) THEN Int64(NULL) ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 @@ -995,7 +995,7 @@ select t1.t1_int from t1 where ( ---- logical_plan 01)Projection: t1.t1_int -02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN __scalar_sq_1.count(Int64(1)) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_two END = Int64(2) +02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN __scalar_sq_1.count(Int64(1)) != Int64(0) THEN Int64(NULL) ELSE __scalar_sq_1.cnt_plus_two END = Int64(2) 03)----Projection: t1.t1_int, __scalar_sq_1.cnt_plus_two, __scalar_sq_1.count(Int64(1)), __scalar_sq_1.__always_true 04)------Left Join: t1.t1_int = __scalar_sq_1.t2_int 05)--------TableScan: t1 projection=[t1_int] @@ -1049,6 +1049,46 @@ false true true +query IT rowsort +SELECT t1_id, (SELECT case when max(t2.t2_id) > 1 then 'a' else 'b' end FROM t2 WHERE t2.t2_int = t1.t1_int) x from t1 +---- +11 a +22 b +33 a +44 b + +query IB rowsort +SELECT t1_id, (SELECT max(t2.t2_id) is null FROM t2 WHERE t2.t2_int = t1.t1_int) x from t1 +---- +11 false +22 true +33 false +44 true + +query TT +explain SELECT t1_id, (SELECT max(t2.t2_id) is null FROM t2 WHERE t2.t2_int = t1.t1_int) x from t1 +---- +logical_plan +01)Projection: t1.t1_id, __scalar_sq_1.__always_true IS NULL OR __scalar_sq_1.__always_true IS NOT NULL AND __scalar_sq_1.max(t2.t2_id) IS NULL AS x +02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int +03)----TableScan: t1 projection=[t1_id, t1_int] +04)----SubqueryAlias: __scalar_sq_1 +05)------Projection: max(t2.t2_id) IS NULL, t2.t2_int, Boolean(true) AS __always_true +06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[max(t2.t2_id)]] +07)----------TableScan: t2 projection=[t2_id, t2_int] + +query TT +explain SELECT t1_id, (SELECT max(t2.t2_id) FROM t2 WHERE t2.t2_int = t1.t1_int) x from t1 +---- +logical_plan +01)Projection: t1.t1_id, __scalar_sq_1.max(t2.t2_id) AS x +02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int +03)----TableScan: t1 projection=[t1_id, t1_int] +04)----SubqueryAlias: __scalar_sq_1 +05)------Projection: max(t2.t2_id), t2.t2_int +06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[max(t2.t2_id)]] +07)----------TableScan: t2 projection=[t2_id, t2_int] + # in_subquery_to_join_with_correlated_outer_filter_disjunction query TT explain select t1.t1_id, @@ -1152,10 +1192,10 @@ physical_plan 01)CoalesceBatchesExec: target_batch_size=2 02)--FilterExec: t1_id@0 > 40 OR NOT mark@3, projection=[t1_id@0, t1_name@1, t1_int@2] 03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=LeftMark, on=[(t1_id@0, t2_id@0)] -05)--------DataSourceExec: partitions=1, partition_sizes=[1] +04)------HashJoinExec: mode=CollectLeft, join_type=RightMark, on=[(t2_id@0, t1_id@0)] +05)--------DataSourceExec: partitions=1, partition_sizes=[2] 06)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -07)----------DataSourceExec: partitions=1, partition_sizes=[1] +07)----------DataSourceExec: partitions=1, partition_sizes=[2] statement ok set datafusion.explain.logical_plan_only = true; @@ -1413,9 +1453,7 @@ logical_plan 01)LeftSemi Join: 02)--TableScan: t1 projection=[a] 03)--SubqueryAlias: __correlated_sq_1 -04)----Projection: -05)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -06)--------TableScan: t2 projection=[] +04)----EmptyRelation: rows=1 statement count 0 drop table t1; diff --git a/datafusion/sqllogictest/test_files/subquery_sort.slt b/datafusion/sqllogictest/test_files/subquery_sort.slt index 5d22bf92e7e65..1e5a3c8f526ac 100644 --- a/datafusion/sqllogictest/test_files/subquery_sort.slt +++ b/datafusion/sqllogictest/test_files/subquery_sort.slt @@ -100,7 +100,7 @@ physical_plan 01)ProjectionExec: expr=[c1@0 as c1, r@1 as r] 02)--SortExec: TopK(fetch=2), expr=[c1@0 ASC NULLS LAST, c3@2 ASC NULLS LAST, c9@3 ASC NULLS LAST], preserve_partitioning=[false] 03)----ProjectionExec: expr=[c1@0 as c1, rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as r, c3@1 as c3, c9@2 as c9] -04)------BoundedWindowAggExec: wdw=[rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Utf8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 05)--------SortExec: expr=[c1@0 DESC], preserve_partitioning=[false] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c3, c9], file_type=csv, has_header=true @@ -126,10 +126,9 @@ physical_plan 01)ProjectionExec: expr=[c1@0 as c1, r@1 as r] 02)--SortExec: TopK(fetch=2), expr=[c1@0 ASC NULLS LAST, c3@2 ASC NULLS LAST, c9@3 ASC NULLS LAST], preserve_partitioning=[false] 03)----ProjectionExec: expr=[c1@0 as c1, rank() ORDER BY [sink_table_with_utf8view.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as r, c3@1 as c3, c9@2 as c9] -04)------BoundedWindowAggExec: wdw=[rank() ORDER BY [sink_table_with_utf8view.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() ORDER BY [sink_table_with_utf8view.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Utf8View(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -05)--------SortPreservingMergeExec: [c1@0 DESC] -06)----------SortExec: expr=[c1@0 DESC], preserve_partitioning=[true] -07)------------DataSourceExec: partitions=4, partition_sizes=[1, 0, 0, 0] +04)------BoundedWindowAggExec: wdw=[rank() ORDER BY [sink_table_with_utf8view.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "rank() ORDER BY [sink_table_with_utf8view.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +05)--------SortExec: expr=[c1@0 DESC], preserve_partitioning=[false] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] statement ok DROP TABLE sink_table_with_utf8view; diff --git a/datafusion/sqllogictest/test_files/table_functions.slt b/datafusion/sqllogictest/test_files/table_functions.slt index 7d318c50bacf4..0159abe8d06b7 100644 --- a/datafusion/sqllogictest/test_files/table_functions.slt +++ b/datafusion/sqllogictest/test_files/table_functions.slt @@ -153,23 +153,23 @@ SELECT * FROM generate_series(1, 5, NULL) query TT EXPLAIN SELECT * FROM generate_series(1, 5) ---- -logical_plan TableScan: tmp_table projection=[value] +logical_plan TableScan: generate_series() projection=[value] physical_plan LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=5, batch_size=8192] # # Test generate_series with invalid arguments # -query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series +query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series SELECT * FROM generate_series(5, 1) -query error DataFusion error: Error during planning: start is smaller than end, but increment is negative: cannot generate infinite series +query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series SELECT * FROM generate_series(-6, 6, -1) -query error DataFusion error: Error during planning: step cannot be zero +query error DataFusion error: Error during planning: Step cannot be zero SELECT * FROM generate_series(-6, 6, 0) -query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series +query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series SELECT * FROM generate_series(6, -6, 1) @@ -177,7 +177,7 @@ statement error DataFusion error: Error during planning: generate_series functio SELECT * FROM generate_series(1, 2, 3, 4) -statement error DataFusion error: Error during planning: First argument must be an integer literal +statement error DataFusion error: Error during planning: Argument \#1 must be an INTEGER, TIMESTAMP, DATE or NULL, got Utf8 SELECT * FROM generate_series('foo', 'bar') # UDF and UDTF `generate_series` can be used simultaneously @@ -220,6 +220,12 @@ SELECT * FROM range(3, 6) 4 5 +query I rowsort +SELECT * FROM range(1, 1+2) +---- +1 +2 + # #generated_data > batch_size query I SELECT count(v1) FROM range(-66666,66666) t1(v1) @@ -270,23 +276,23 @@ SELECT * FROM range(1, 5, NULL) query TT EXPLAIN SELECT * FROM range(1, 5) ---- -logical_plan TableScan: tmp_table projection=[value] +logical_plan TableScan: range() projection=[value] physical_plan LazyMemoryExec: partitions=1, batch_generators=[range: start=1, end=5, batch_size=8192] # # Test range with invalid arguments # -query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series +query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series SELECT * FROM range(5, 1) -query error DataFusion error: Error during planning: start is smaller than end, but increment is negative: cannot generate infinite series +query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series SELECT * FROM range(-6, 6, -1) -query error DataFusion error: Error during planning: step cannot be zero +query error DataFusion error: Error during planning: Step cannot be zero SELECT * FROM range(-6, 6, 0) -query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series +query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series SELECT * FROM range(6, -6, 1) @@ -294,12 +300,197 @@ statement error DataFusion error: Error during planning: range function requires SELECT * FROM range(1, 2, 3, 4) -statement error DataFusion error: Error during planning: First argument must be an integer literal +statement error DataFusion error: Error during planning: Argument \#1 must be an INTEGER, TIMESTAMP, DATE or NULL, got Utf8 SELECT * FROM range('foo', 'bar') +statement error DataFusion error: Error during planning: Argument #2 must be an INTEGER or NULL, got Literal\(Utf8\("bar"\), None\) +SELECT * FROM range(1, 'bar') + # UDF and UDTF `range` can be used simultaneously query ? rowsort SELECT range(1, t1.end) FROM range(3, 5) as t1(end) ---- [1, 2, 3] [1, 2] + +# +# Test timestamp ranges +# + +# Basic timestamp range with 1 day interval +query P rowsort +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-04T00:00:00', INTERVAL '1' DAY) +---- +2023-01-01T00:00:00 +2023-01-02T00:00:00 +2023-01-03T00:00:00 + +# Timestamp range with hour interval +query P rowsort +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-01T03:00:00', INTERVAL '1' HOUR) +---- +2023-01-01T00:00:00 +2023-01-01T01:00:00 +2023-01-01T02:00:00 + +# Timestamp range with month interval +query P rowsort +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-04-01T00:00:00', INTERVAL '1' MONTH) +---- +2023-01-01T00:00:00 +2023-02-01T00:00:00 +2023-03-01T00:00:00 + +# Timestamp generate_series (includes end) +query P rowsort +SELECT * FROM generate_series(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-03T00:00:00', INTERVAL '1' DAY) +---- +2023-01-01T00:00:00 +2023-01-02T00:00:00 +2023-01-03T00:00:00 + +# Timestamp range with timezone +query P +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00+00:00', TIMESTAMP '2023-01-03T00:00:00+00:00', INTERVAL '1' DAY) +---- +2023-01-01T00:00:00 +2023-01-02T00:00:00 + +# Negative timestamp range (going backwards) +query P +SELECT * FROM range(TIMESTAMP '2023-01-03T00:00:00', TIMESTAMP '2023-01-01T00:00:00', INTERVAL '-1' DAY) +---- +2023-01-03T00:00:00 +2023-01-02T00:00:00 + +query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +SELECT * FROM range(TIMESTAMP '2023-01-03T00:00:00', TIMESTAMP '2023-01-01T00:00:00', INTERVAL '1' DAY) + +query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-02T00:00:00', INTERVAL '-1' DAY) + +query error DataFusion error: Error during planning: range function with timestamps requires exactly 3 arguments +SELECT * FROM range(TIMESTAMP '2023-01-03T00:00:00', TIMESTAMP '2023-01-01T00:00:00') + +# Single timestamp (start == end) +query P +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-01T00:00:00', INTERVAL '1' DAY) +---- + +# Timestamp range with NULL values +query P +SELECT * FROM range(NULL::TIMESTAMP, TIMESTAMP '2023-01-03T00:00:00', INTERVAL '1' DAY) +---- + +query P +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', NULL::TIMESTAMP, INTERVAL '1' DAY) +---- + +# No interval gives no rows +query P +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-03T00:00:00', NULL::INTERVAL) +---- + +# Zero-length interval gives error +query error DataFusion error: Error during planning: Step interval cannot be zero +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-03T00:00:00', INTERVAL '0' DAY) + +# Timezone-aware +query P +SELECT * FROM range(TIMESTAMPTZ '2023-02-01T00:00:00-07:00', TIMESTAMPTZ '2023-02-01T09:00:00+01:00', INTERVAL '1' HOUR); +---- +2023-02-01T07:00:00Z + +# Basic date range with hour interval +query P +SELECT * FROM range(DATE '1992-01-01', DATE '1992-01-03', INTERVAL '6' HOUR); +---- +1992-01-01T00:00:00 +1992-01-01T06:00:00 +1992-01-01T12:00:00 +1992-01-01T18:00:00 +1992-01-02T00:00:00 +1992-01-02T06:00:00 +1992-01-02T12:00:00 +1992-01-02T18:00:00 + +# Date range with day interval +query P +SELECT * FROM range(DATE '1992-09-01', DATE '1992-09-05', INTERVAL '1' DAY); +---- +1992-09-01T00:00:00 +1992-09-02T00:00:00 +1992-09-03T00:00:00 +1992-09-04T00:00:00 + +# Date range with month interval +query P +SELECT * FROM range(DATE '1992-09-01', DATE '1993-01-01', INTERVAL '1' MONTH); +---- +1992-09-01T00:00:00 +1992-10-01T00:00:00 +1992-11-01T00:00:00 +1992-12-01T00:00:00 + +# Date range generate_series includes end +query P +SELECT * FROM generate_series(DATE '1992-09-01', DATE '1992-09-03', INTERVAL '1' DAY); +---- +1992-09-01T00:00:00 +1992-09-02T00:00:00 +1992-09-03T00:00:00 + +query TT +EXPLAIN SELECT * FROM generate_series(DATE '1992-09-01', DATE '1992-09-03', INTERVAL '1' DAY); +---- +logical_plan TableScan: generate_series() projection=[value] +physical_plan LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=715305600000000000, end=715478400000000000, batch_size=8192] + +# Backwards date range +query P +SELECT * FROM range(DATE '1992-09-05', DATE '1992-09-01', INTERVAL '-1' DAY); +---- +1992-09-05T00:00:00 +1992-09-04T00:00:00 +1992-09-03T00:00:00 +1992-09-02T00:00:00 + +# NULL handling for dates +query P +SELECT * FROM range(DATE '1992-09-01', NULL::DATE, INTERVAL '1' MONTH) +---- + +query TT +EXPLAIN SELECT * FROM range(DATE '1992-09-01', NULL::DATE, INTERVAL '1' MONTH) +---- +logical_plan TableScan: range() projection=[value] +physical_plan LazyMemoryExec: partitions=1, batch_generators=[range: empty] + +query P +SELECT * FROM range(NULL::DATE, DATE '1992-09-01', INTERVAL '1' MONTH) +---- + +query P +SELECT * FROM range(DATE '1992-09-01', DATE '1992-10-01', NULL::INTERVAL) +---- + +query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +SELECT * FROM range(DATE '2023-01-03', DATE '2023-01-01', INTERVAL '1' DAY) + +query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series +SELECT * FROM range(DATE '2023-01-01', DATE '2023-01-02', INTERVAL '-1' DAY) + +query error DataFusion error: Error during planning: range function with dates requires exactly 3 arguments +SELECT * FROM range(DATE '2023-01-01', DATE '2023-01-03') + +# Table function as relation +statement ok +CREATE OR REPLACE TABLE json_table (c INT) AS VALUES (1), (2); + +query II +SELECT c, f.* FROM json_table, LATERAL generate_series(1,2) f; +---- +1 1 +1 2 +2 1 +2 2 diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index e042e3863b69f..1a7ff41d64a66 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -176,6 +176,115 @@ SELECT TIMESTAMPTZ '2000-01-01T01:01:01' 2000-01-01T01:01:01Z +########## +## cast tests +########## + +query BPPPPPP +SELECT t1 = t2 AND t1 = t3 AND t1 = t4 AND t1 = t5 AND t1 = t6, * +FROM (SELECT + (SELECT CAST(CAST(1 AS float) AS timestamp(0))) AS t1, + (SELECT CAST(CAST(one AS float) AS timestamp(0)) FROM (SELECT 1 AS one)) AS t2, + (SELECT CAST(CAST(one AS float) AS timestamp(0)) FROM (VALUES (1)) t(one)) AS t3, + (SELECT CAST(CAST(1 AS double) AS timestamp(0))) AS t4, + (SELECT CAST(CAST(one AS double) AS timestamp(0)) FROM (SELECT 1 AS one)) AS t5, + (SELECT CAST(CAST(one AS double) AS timestamp(0)) FROM (VALUES (1)) t(one)) AS t6 +) +---- +true 1970-01-01T00:00:01 1970-01-01T00:00:01 1970-01-01T00:00:01 1970-01-01T00:00:01 1970-01-01T00:00:01 1970-01-01T00:00:01 + +query BPPPPPP +SELECT t1 = t2 AND t1 = t3 AND t1 = t4 AND t1 = t5 AND t1 = t6, * +FROM (SELECT + (SELECT CAST(CAST(1 AS float) AS timestamp(3))) AS t1, + (SELECT CAST(CAST(one AS float) AS timestamp(3)) FROM (SELECT 1 AS one)) AS t2, + (SELECT CAST(CAST(one AS float) AS timestamp(3)) FROM (VALUES (1)) t(one)) AS t3, + (SELECT CAST(CAST(1 AS double) AS timestamp(3))) AS t4, + (SELECT CAST(CAST(one AS double) AS timestamp(3)) FROM (SELECT 1 AS one)) AS t5, + (SELECT CAST(CAST(one AS double) AS timestamp(3)) FROM (VALUES (1)) t(one)) AS t6 +) +---- +true 1970-01-01T00:00:00.001 1970-01-01T00:00:00.001 1970-01-01T00:00:00.001 1970-01-01T00:00:00.001 1970-01-01T00:00:00.001 1970-01-01T00:00:00.001 + +query BPPPPPP +SELECT t1 = t2 AND t1 = t3 AND t1 = t4 AND t1 = t5 AND t1 = t6, * +FROM (SELECT + (SELECT CAST(CAST(1 AS float) AS timestamp(6))) AS t1, + (SELECT CAST(CAST(one AS float) AS timestamp(6)) FROM (SELECT 1 AS one)) AS t2, + (SELECT CAST(CAST(one AS float) AS timestamp(6)) FROM (VALUES (1)) t(one)) AS t3, + (SELECT CAST(CAST(1 AS double) AS timestamp(6))) AS t4, + (SELECT CAST(CAST(one AS double) AS timestamp(6)) FROM (SELECT 1 AS one)) AS t5, + (SELECT CAST(CAST(one AS double) AS timestamp(6)) FROM (VALUES (1)) t(one)) AS t6 +) +---- +true 1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000001 + +query BPPPPPP +SELECT t1 = t2 AND t1 = t3 AND t1 = t4 AND t1 = t5 AND t1 = t6, * +FROM (SELECT + (SELECT CAST(CAST(1 AS float) AS timestamp(9))) AS t1, + (SELECT CAST(CAST(one AS float) AS timestamp(9)) FROM (SELECT 1 AS one)) AS t2, + (SELECT CAST(CAST(one AS float) AS timestamp(9)) FROM (VALUES (1)) t(one)) AS t3, + (SELECT CAST(CAST(1 AS double) AS timestamp(9))) AS t4, + (SELECT CAST(CAST(one AS double) AS timestamp(9)) FROM (SELECT 1 AS one)) AS t5, + (SELECT CAST(CAST(one AS double) AS timestamp(9)) FROM (VALUES (1)) t(one)) AS t6 +) +---- +true 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 + +query BPPPPPP +SELECT t1 = t2 AND t1 = t3 AND t1 = t4 AND t1 = t5 AND t1 = t6, * +FROM (SELECT + (SELECT CAST(CAST(1.125 AS float) AS timestamp(0))) AS t1, + (SELECT CAST(CAST(one_and_a_bit AS float) AS timestamp(0)) FROM (SELECT 1.125 AS one_and_a_bit)) AS t2, + (SELECT CAST(CAST(one_and_a_bit AS float) AS timestamp(0)) FROM (VALUES (1.125)) t(one_and_a_bit)) AS t3, + (SELECT CAST(CAST(1.125 AS double) AS timestamp(0))) AS t4, + (SELECT CAST(CAST(one_and_a_bit AS double) AS timestamp(0)) FROM (SELECT 1.125 AS one_and_a_bit)) AS t5, + (SELECT CAST(CAST(one_and_a_bit AS double) AS timestamp(0)) FROM (VALUES (1.125)) t(one_and_a_bit)) AS t6 +) +---- +true 1970-01-01T00:00:01 1970-01-01T00:00:01 1970-01-01T00:00:01 1970-01-01T00:00:01 1970-01-01T00:00:01 1970-01-01T00:00:01 + +query BPPPPPP +SELECT t1 = t2 AND t1 = t3 AND t1 = t4 AND t1 = t5 AND t1 = t6, * +FROM (SELECT + (SELECT CAST(CAST(1.125 AS float) AS timestamp(3))) AS t1, + (SELECT CAST(CAST(one_and_a_bit AS float) AS timestamp(3)) FROM (SELECT 1.125 AS one_and_a_bit)) AS t2, + (SELECT CAST(CAST(one_and_a_bit AS float) AS timestamp(3)) FROM (VALUES (1.125)) t(one_and_a_bit)) AS t3, + (SELECT CAST(CAST(1.125 AS double) AS timestamp(3))) AS t4, + (SELECT CAST(CAST(one_and_a_bit AS double) AS timestamp(3)) FROM (SELECT 1.125 AS one_and_a_bit)) AS t5, + (SELECT CAST(CAST(one_and_a_bit AS double) AS timestamp(3)) FROM (VALUES (1.125)) t(one_and_a_bit)) AS t6 +) +---- +true 1970-01-01T00:00:00.001 1970-01-01T00:00:00.001 1970-01-01T00:00:00.001 1970-01-01T00:00:00.001 1970-01-01T00:00:00.001 1970-01-01T00:00:00.001 + +query BPPPPPP +SELECT t1 = t2 AND t1 = t3 AND t1 = t4 AND t1 = t5 AND t1 = t6, * +FROM (SELECT + (SELECT CAST(CAST(1.125 AS float) AS timestamp(6))) AS t1, + (SELECT CAST(CAST(one_and_a_bit AS float) AS timestamp(6)) FROM (SELECT 1.125 AS one_and_a_bit)) AS t2, + (SELECT CAST(CAST(one_and_a_bit AS float) AS timestamp(6)) FROM (VALUES (1.125)) t(one_and_a_bit)) AS t3, + (SELECT CAST(CAST(1.125 AS double) AS timestamp(6))) AS t4, + (SELECT CAST(CAST(one_and_a_bit AS double) AS timestamp(6)) FROM (SELECT 1.125 AS one_and_a_bit)) AS t5, + (SELECT CAST(CAST(one_and_a_bit AS double) AS timestamp(6)) FROM (VALUES (1.125)) t(one_and_a_bit)) AS t6 +) +---- +true 1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000001 + +query BPPPPPP +SELECT t1 = t2 AND t1 = t3 AND t1 = t4 AND t1 = t5 AND t1 = t6, * +FROM (SELECT + (SELECT CAST(CAST(1.125 AS float) AS timestamp(9))) AS t1, + (SELECT CAST(CAST(one_and_a_bit AS float) AS timestamp(9)) FROM (SELECT 1.125 AS one_and_a_bit)) AS t2, + (SELECT CAST(CAST(one_and_a_bit AS float) AS timestamp(9)) FROM (VALUES (1.125)) t(one_and_a_bit)) AS t3, + (SELECT CAST(CAST(1.125 AS double) AS timestamp(9))) AS t4, + (SELECT CAST(CAST(one_and_a_bit AS double) AS timestamp(9)) FROM (SELECT 1.125 AS one_and_a_bit)) AS t5, + (SELECT CAST(CAST(one_and_a_bit AS double) AS timestamp(9)) FROM (VALUES (1.125)) t(one_and_a_bit)) AS t6 +) +---- +true 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 + + ########## ## to_timestamp tests ########## @@ -394,12 +503,12 @@ SELECT COUNT(*) FROM ts_data_secs where ts > to_timestamp_seconds('2020-09-08 12 query PPP SELECT to_timestamp(1.1) as c1, cast(1.1 as timestamp) as c2, 1.1::timestamp as c3; ---- -1970-01-01T00:00:01.100 1970-01-01T00:00:01.100 1970-01-01T00:00:01.100 +1970-01-01T00:00:01.100 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 query PPP SELECT to_timestamp(-1.1) as c1, cast(-1.1 as timestamp) as c2, (-1.1)::timestamp as c3; ---- -1969-12-31T23:59:58.900 1969-12-31T23:59:58.900 1969-12-31T23:59:58.900 +1969-12-31T23:59:58.900 1969-12-31T23:59:59.999999999 1969-12-31T23:59:59.999999999 query PPP SELECT to_timestamp(0.0) as c1, cast(0.0 as timestamp) as c2, 0.0::timestamp as c3; @@ -409,24 +518,24 @@ SELECT to_timestamp(0.0) as c1, cast(0.0 as timestamp) as c2, 0.0::timestamp as query PPP SELECT to_timestamp(1.23456789) as c1, cast(1.23456789 as timestamp) as c2, 1.23456789::timestamp as c3; ---- -1970-01-01T00:00:01.234567890 1970-01-01T00:00:01.234567890 1970-01-01T00:00:01.234567890 +1970-01-01T00:00:01.234567890 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 query PPP SELECT to_timestamp(123456789.123456789) as c1, cast(123456789.123456789 as timestamp) as c2, 123456789.123456789::timestamp as c3; ---- -1973-11-29T21:33:09.123456784 1973-11-29T21:33:09.123456784 1973-11-29T21:33:09.123456784 +1973-11-29T21:33:09.123456784 1970-01-01T00:00:00.123456789 1970-01-01T00:00:00.123456789 # to_timestamp Decimal128 inputs query PPP SELECT to_timestamp(arrow_cast(1.1, 'Decimal128(2,1)')) as c1, cast(arrow_cast(1.1, 'Decimal128(2,1)') as timestamp) as c2, arrow_cast(1.1, 'Decimal128(2,1)')::timestamp as c3; ---- -1970-01-01T00:00:01.100 1970-01-01T00:00:01.100 1970-01-01T00:00:01.100 +1970-01-01T00:00:01.100 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 query PPP SELECT to_timestamp(arrow_cast(-1.1, 'Decimal128(2,1)')) as c1, cast(arrow_cast(-1.1, 'Decimal128(2,1)') as timestamp) as c2, arrow_cast(-1.1, 'Decimal128(2,1)')::timestamp as c3; ---- -1969-12-31T23:59:58.900 1969-12-31T23:59:58.900 1969-12-31T23:59:58.900 +1969-12-31T23:59:58.900 1969-12-31T23:59:59.999999999 1969-12-31T23:59:59.999999999 query PPP SELECT to_timestamp(arrow_cast(0.0, 'Decimal128(2,1)')) as c1, cast(arrow_cast(0.0, 'Decimal128(2,1)') as timestamp) as c2, arrow_cast(0.0, 'Decimal128(2,1)')::timestamp as c3; @@ -436,12 +545,12 @@ SELECT to_timestamp(arrow_cast(0.0, 'Decimal128(2,1)')) as c1, cast(arrow_cast(0 query PPP SELECT to_timestamp(arrow_cast(1.23456789, 'Decimal128(9,8)')) as c1, cast(arrow_cast(1.23456789, 'Decimal128(9,8)') as timestamp) as c2, arrow_cast(1.23456789, 'Decimal128(9,8)')::timestamp as c3; ---- -1970-01-01T00:00:01.234567890 1970-01-01T00:00:01.234567890 1970-01-01T00:00:01.234567890 +1970-01-01T00:00:01.234567890 1970-01-01T00:00:00.000000001 1970-01-01T00:00:00.000000001 query PPP SELECT to_timestamp(arrow_cast(123456789.123456789, 'Decimal128(18,9)')) as c1, cast(arrow_cast(123456789.123456789, 'Decimal128(18,9)') as timestamp) as c2, arrow_cast(123456789.123456789, 'Decimal128(18,9)')::timestamp as c3; ---- -1973-11-29T21:33:09.123456784 1973-11-29T21:33:09.123456784 1973-11-29T21:33:09.123456784 +1973-11-29T21:33:09.123456784 1970-01-01T00:00:00.123456789 1970-01-01T00:00:00.123456789 # from_unixtime @@ -2268,23 +2377,23 @@ query error input contains invalid characters SELECT to_timestamp_seconds('2020-09-08 12/00/00+00:00', '%c', '%+') # to_timestamp with broken formatting -query error bad or unsupported format string +query error DataFusion error: Execution error: Error parsing timestamp from '2020\-09\-08 12/00/00\+00:00' using format '%q': trailing input SELECT to_timestamp('2020-09-08 12/00/00+00:00', '%q') # to_timestamp_nanos with broken formatting -query error bad or unsupported format string +query error DataFusion error: Execution error: Error parsing timestamp from '2020\-09\-08 12/00/00\+00:00' using format '%q': trailing input SELECT to_timestamp_nanos('2020-09-08 12/00/00+00:00', '%q') # to_timestamp_millis with broken formatting -query error bad or unsupported format string +query error DataFusion error: Execution error: Error parsing timestamp from '2020\-09\-08 12/00/00\+00:00' using format '%q': trailing input SELECT to_timestamp_millis('2020-09-08 12/00/00+00:00', '%q') # to_timestamp_micros with broken formatting -query error bad or unsupported format string +query error DataFusion error: Execution error: Error parsing timestamp from '2020\-09\-08 12/00/00\+00:00' using format '%q': trailing input SELECT to_timestamp_micros('2020-09-08 12/00/00+00:00', '%q') # to_timestamp_seconds with broken formatting -query error bad or unsupported format string +query error DataFusion error: Execution error: Error parsing timestamp from '2020\-09\-08 12/00/00\+00:00' using format '%q': trailing input SELECT to_timestamp_seconds('2020-09-08 12/00/00+00:00', '%q') # Create string timestamp table with different formats @@ -2794,6 +2903,18 @@ select date_format(dates, date_format) from formats; 01:01:2000 05:04:2003 +query T +select date_format(dates, time_format) from formats; +---- +00-00-00 +00::00::00 + +query T +select date_format(dates, timestamp_format) from formats; +---- +01:01:2000 00-00-00 +05:04:2003 00-00-00 + query T select to_char(times, time_format) from formats; ---- @@ -2842,6 +2963,11 @@ select to_char(arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Second, N ---- 03-08-2023 14-38-50 +query T +select to_char(arrow_cast('2023-09-04'::date, 'Timestamp(Second, Some("UTC"))'), '%Y-%m-%dT%H:%M:%S%.3f'); +---- +2023-09-04T00:00:00.000 + query T select to_char(arrow_cast(123456, 'Duration(Second)'), 'pretty'); ---- @@ -3110,6 +3236,11 @@ select to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brus ---- 2024-04-01T00:00:20 +query P +select to_local_time(NULL); +---- +NULL + query PTPT select time, @@ -3137,6 +3268,7 @@ select date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestam statement ok create table t AS VALUES + (NULL), ('2024-01-01T00:00:01Z'), ('2024-02-01T00:00:01Z'), ('2024-03-01T00:00:01Z'), @@ -3164,6 +3296,7 @@ from t; query PPT select column1, to_local_time(column1::timestamp), arrow_typeof(to_local_time(column1::timestamp)) from t_utc; ---- +NULL NULL Timestamp(Nanosecond, None) 2024-01-01T00:00:01Z 2024-01-01T00:00:01 Timestamp(Nanosecond, None) 2024-02-01T00:00:01Z 2024-02-01T00:00:01 Timestamp(Nanosecond, None) 2024-03-01T00:00:01Z 2024-03-01T00:00:01 Timestamp(Nanosecond, None) @@ -3180,6 +3313,7 @@ select column1, to_local_time(column1::timestamp), arrow_typeof(to_local_time(co query PPT select column1, to_local_time(column1), arrow_typeof(to_local_time(column1)) from t_utc; ---- +NULL NULL Timestamp(Nanosecond, None) 2024-01-01T00:00:01Z 2024-01-01T00:00:01 Timestamp(Nanosecond, None) 2024-02-01T00:00:01Z 2024-02-01T00:00:01 Timestamp(Nanosecond, None) 2024-03-01T00:00:01Z 2024-03-01T00:00:01 Timestamp(Nanosecond, None) @@ -3196,6 +3330,7 @@ select column1, to_local_time(column1), arrow_typeof(to_local_time(column1)) fro query PPT select column1, to_local_time(column1), arrow_typeof(to_local_time(column1)) from t_timezone; ---- +NULL NULL Timestamp(Nanosecond, None) 2024-01-01T00:00:01+01:00 2024-01-01T00:00:01 Timestamp(Nanosecond, None) 2024-02-01T00:00:01+01:00 2024-02-01T00:00:01 Timestamp(Nanosecond, None) 2024-03-01T00:00:01+01:00 2024-03-01T00:00:01 Timestamp(Nanosecond, None) @@ -3213,6 +3348,7 @@ select column1, to_local_time(column1), arrow_typeof(to_local_time(column1)) fro query P select date_bin(interval '1 day', to_local_time(column1)) AT TIME ZONE 'Europe/Brussels' as date_bin from t_utc; ---- +NULL 2024-01-01T00:00:00+01:00 2024-02-01T00:00:00+01:00 2024-03-01T00:00:00+01:00 @@ -3229,6 +3365,7 @@ select date_bin(interval '1 day', to_local_time(column1)) AT TIME ZONE 'Europe/B query P select date_bin(interval '1 day', to_local_time(column1)) AT TIME ZONE 'Europe/Brussels' as date_bin from t_timezone; ---- +NULL 2024-01-01T00:00:00+01:00 2024-02-01T00:00:00+01:00 2024-03-01T00:00:00+01:00 @@ -3415,3 +3552,100 @@ select to_timestamp('-1'); query error DataFusion error: Arrow error: Parser error: Error parsing timestamp from '\-1': timestamp must contain at least 10 characters select to_timestamp(arrow_cast('-1', 'Utf8')); + +query P +SELECT CAST(CAST(1 AS decimal(17,2)) AS timestamp(3)) AS a UNION ALL +SELECT CAST(CAST(one AS decimal(17,2)) AS timestamp(3)) AS a FROM (VALUES (1)) t(one); +---- +1970-01-01T00:00:00.001 +1970-01-01T00:00:00.001 + +query P +SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(Nanosecond, None)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(Nanosecond, None)') AS a FROM (VALUES (1)) t(one); +---- +1970-01-01T00:00:00.000000001 +1970-01-01T00:00:00.000000001 + +query P +SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(Microsecond, None)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(Microsecond, None)') AS a FROM (VALUES (1)) t(one); +---- +1970-01-01T00:00:00.000001 +1970-01-01T00:00:00.000001 + +query P +SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(Millisecond, None)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(Millisecond, None)') AS a FROM (VALUES (1)) t(one); +---- +1970-01-01T00:00:00.001 +1970-01-01T00:00:00.001 + +query P +SELECT arrow_cast(CAST(1 AS decimal(17,2)), 'Timestamp(Second, None)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,2)), 'Timestamp(Second, None)') AS a FROM (VALUES (1)) t(one); +---- +1970-01-01T00:00:01 +1970-01-01T00:00:01 + + +query P +SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(Nanosecond, None)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(Nanosecond, None)') AS a FROM (VALUES (1.123)) t(one); +---- +1970-01-01T00:00:00.000000001 +1970-01-01T00:00:00.000000001 + +query P +SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(Microsecond, None)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(Microsecond, None)') AS a FROM (VALUES (1.123)) t(one); +---- +1970-01-01T00:00:00.000001 +1970-01-01T00:00:00.000001 + +query P +SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(Millisecond, None)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(Millisecond, None)') AS a FROM (VALUES (1.123)) t(one); +---- +1970-01-01T00:00:00.001 +1970-01-01T00:00:00.001 + +query P +SELECT arrow_cast(CAST(1.123 AS decimal(17,3)), 'Timestamp(Second, None)') AS a UNION ALL +SELECT arrow_cast(CAST(one AS decimal(17,3)), 'Timestamp(Second, None)') AS a FROM (VALUES (1.123)) t(one); +---- +1970-01-01T00:00:01 +1970-01-01T00:00:01 + +query TTTTT +SELECT + arrow_typeof(a), + CAST(a AS varchar), + arrow_cast(a, 'Utf8'), + arrow_cast(a, 'Utf8View'), + arrow_cast(a, 'LargeUtf8') +FROM (SELECT DATE '2005-09-10' AS a) +---- +Date32 2005-09-10 2005-09-10 2005-09-10 2005-09-10 + +query TTTTT +SELECT + arrow_typeof(a), + CAST(a AS varchar), + arrow_cast(a, 'Utf8'), + arrow_cast(a, 'Utf8View'), + arrow_cast(a, 'LargeUtf8') +FROM (SELECT TIMESTAMP '2005-09-10 13:31:00' AS a) +---- +Timestamp(Nanosecond, None) 2005-09-10T13:31:00 2005-09-10T13:31:00 2005-09-10T13:31:00 2005-09-10T13:31:00 + +query TTTTT +SELECT + arrow_typeof(a), + CAST(a AS varchar), + arrow_cast(a, 'Utf8'), + arrow_cast(a, 'Utf8View'), + arrow_cast(a, 'LargeUtf8') +FROM (SELECT CAST('2005-09-10 13:31:00 +02:00' AS timestamp with time zone) AS a) +---- +Timestamp(Nanosecond, Some("+00")) 2005-09-10T11:31:00Z 2005-09-10T11:31:00Z 2005-09-10T11:31:00Z 2005-09-10T11:31:00Z diff --git a/datafusion/sqllogictest/test_files/topk.slt b/datafusion/sqllogictest/test_files/topk.slt index b5ff95c358d8e..8a08cc17d4172 100644 --- a/datafusion/sqllogictest/test_files/topk.slt +++ b/datafusion/sqllogictest/test_files/topk.slt @@ -53,7 +53,7 @@ query I select * from (select * from topk limit 8) order by x limit 3; ---- 0 -1 +2 2 @@ -233,3 +233,165 @@ d 1 -98 y7C453hRWd4E7ImjNDWlpexB8nUqjh y7C453hRWd4E7ImjNDWlpexB8nUqjh e 2 52 xipQ93429ksjNcXPX5326VSg1xJZcW xipQ93429ksjNcXPX5326VSg1xJZcW d 1 -72 wwXqSGKLyBQyPkonlzBNYUJTCo4LRS wwXqSGKLyBQyPkonlzBNYUJTCo4LRS a 1 -5 waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs waIGbOGl1PM6gnzZ4uuZt4E2yDWRHs + +##################################### +## Test TopK with Partially Sorted Inputs +##################################### + + +# Create an external table where data is pre-sorted by (number DESC, letter ASC) only. +statement ok +CREATE EXTERNAL TABLE partial_sorted ( + number INT, + letter VARCHAR, + age INT +) +STORED AS parquet +LOCATION 'test_files/scratch/topk/partial_sorted/1.parquet' +WITH ORDER (number DESC, letter ASC); + +# Insert test data into the external table. +query I +COPY ( + SELECT * + FROM ( + VALUES + (1, 'F', 100), + (1, 'B', 50), + (2, 'C', 70), + (2, 'D', 80), + (3, 'A', 60), + (3, 'E', 90) + ) AS t(number, letter, age) + ORDER BY number DESC, letter ASC +) +TO 'test_files/scratch/topk/partial_sorted/1.parquet'; +---- +6 + +## explain physical_plan only +statement ok +set datafusion.explain.physical_plan_only = true + +## batch size smaller than number of rows in the table and result +statement ok +set datafusion.execution.batch_size = 2 + +# Run a TopK query that orders by all columns. +# Although the table is only guaranteed to be sorted by (number DESC, letter ASC), +# DataFusion should use the common prefix optimization +# and return the correct top 3 rows when ordering by all columns. +query ITI +select number, letter, age from partial_sorted order by number desc, letter asc, age desc limit 3; +---- +3 A 60 +3 E 90 +2 C 70 + +# A more complex example with a projection that includes an expression (see further down for the explained plan) +query IIITI +select + number + 1 as number_plus, + number, + number + 1 as other_number_plus, + letter, + age +from partial_sorted +order by + number_plus desc, + number desc, + other_number_plus desc, + letter asc, + age desc +limit 3; +---- +4 3 4 A 60 +4 3 4 E 90 +3 2 3 C 70 + +# Verify that the physical plan includes the sort prefix. +# The output should display a "sort_prefix" in the SortExec node. +query TT +explain select number, letter, age from partial_sorted order by number desc, letter asc, age desc limit 3; +---- +physical_plan +01)SortExec: TopK(fetch=3), expr=[number@0 DESC, letter@1 ASC NULLS LAST, age@2 DESC], preserve_partitioning=[false], sort_prefix=[number@0 DESC, letter@1 ASC NULLS LAST] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilter [ empty ] + + +# Explain variations of the above query with different orderings, and different sort prefixes. +# The "sort_prefix" in the SortExec node should only be present if the TopK's ordering starts with either (number DESC, letter ASC) or just (number DESC). +query TT +explain select number, letter, age from partial_sorted order by age desc limit 3; +---- +physical_plan +01)SortExec: TopK(fetch=3), expr=[age@2 DESC], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilter [ empty ] + +query TT +explain select number, letter, age from partial_sorted order by number desc, letter desc limit 3; +---- +physical_plan +01)SortExec: TopK(fetch=3), expr=[number@0 DESC, letter@1 DESC], preserve_partitioning=[false], sort_prefix=[number@0 DESC] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilter [ empty ] + +query TT +explain select number, letter, age from partial_sorted order by number asc limit 3; +---- +physical_plan +01)SortExec: TopK(fetch=3), expr=[number@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilter [ empty ] + +query TT +explain select number, letter, age from partial_sorted order by letter asc, number desc limit 3; +---- +physical_plan +01)SortExec: TopK(fetch=3), expr=[letter@1 ASC NULLS LAST, number@0 DESC], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Explicit NULLS ordering cases (reversing the order of the NULLS on the number and letter orderings) +query TT +explain select number, letter, age from partial_sorted order by number desc, letter asc NULLS FIRST limit 3; +---- +physical_plan +01)SortExec: TopK(fetch=3), expr=[number@0 DESC, letter@1 ASC], preserve_partitioning=[false], sort_prefix=[number@0 DESC] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilter [ empty ] + +query TT +explain select number, letter, age from partial_sorted order by number desc NULLS LAST, letter asc limit 3; +---- +physical_plan +01)SortExec: TopK(fetch=3), expr=[number@0 DESC NULLS LAST, letter@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilter [ empty ] + + +# Verify that the sort prefix is correctly computed on the normalized ordering (removing redundant aliased columns) +query TT +explain select number, letter, age, number as column4, letter as column5 from partial_sorted order by number desc, column4 desc, letter asc, column5 asc, age desc limit 3; +---- +physical_plan +01)SortExec: TopK(fetch=3), expr=[number@0 DESC, letter@1 ASC NULLS LAST, age@2 DESC], preserve_partitioning=[false], sort_prefix=[number@0 DESC, letter@1 ASC NULLS LAST] +02)--ProjectionExec: expr=[number@0 as number, letter@1 as letter, age@2 as age, number@0 as column4, letter@1 as column5] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilter [ empty ] + +# Verify that the sort prefix is correctly computed over normalized, order-maintaining projections (number + 1, number, number + 1, age) +query TT +explain select number + 1 as number_plus, number, number + 1 as other_number_plus, age from partial_sorted order by number_plus desc, number desc, other_number_plus desc, age asc limit 3; +---- +physical_plan +01)SortPreservingMergeExec: [number_plus@0 DESC, number@1 DESC, other_number_plus@2 DESC, age@3 ASC NULLS LAST], fetch=3 +02)--SortExec: TopK(fetch=3), expr=[number_plus@0 DESC, number@1 DESC, age@3 ASC NULLS LAST], preserve_partitioning=[true], sort_prefix=[number_plus@0 DESC, number@1 DESC] +03)----ProjectionExec: expr=[__common_expr_1@0 as number_plus, number@1 as number, __common_expr_1@0 as other_number_plus, age@2 as age] +04)------ProjectionExec: expr=[CAST(number@0 AS Int64) + 1 as __common_expr_1, number@0 as number, age@1 as age] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, age], output_ordering=[number@0 DESC], file_type=parquet + +# Cleanup +statement ok +DROP TABLE partial_sorted; + +statement ok +set datafusion.explain.physical_plan_only = false + +statement ok +set datafusion.execution.batch_size = 8192 diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q10.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q10.slt.part index fee496f92055e..04de9153a0474 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q10.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q10.slt.part @@ -65,8 +65,8 @@ logical_plan 12)--------------------Filter: orders.o_orderdate >= Date32("1993-10-01") AND orders.o_orderdate < Date32("1994-01-01") 13)----------------------TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate], partial_filters=[orders.o_orderdate >= Date32("1993-10-01"), orders.o_orderdate < Date32("1994-01-01")] 14)--------------Projection: lineitem.l_orderkey, lineitem.l_extendedprice, lineitem.l_discount -15)----------------Filter: lineitem.l_returnflag = Utf8("R") -16)------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], partial_filters=[lineitem.l_returnflag = Utf8("R")] +15)----------------Filter: lineitem.l_returnflag = Utf8View("R") +16)------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], partial_filters=[lineitem.l_returnflag = Utf8View("R")] 17)----------TableScan: nation projection=[n_nationkey, n_name] physical_plan 01)SortPreservingMergeExec: [revenue@2 DESC], fetch=10 diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part index 1dba8c0537209..a6225daae4362 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part @@ -58,8 +58,8 @@ logical_plan 09)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], partial_filters=[Boolean(true)] 10)----------------TableScan: supplier projection=[s_suppkey, s_nationkey] 11)------------Projection: nation.n_nationkey -12)--------------Filter: nation.n_name = Utf8("GERMANY") -13)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] +12)--------------Filter: nation.n_name = Utf8View("GERMANY") +13)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] 14)------SubqueryAlias: __scalar_sq_1 15)--------Projection: CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15)) 16)----------Aggregate: groupBy=[[]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] @@ -70,8 +70,8 @@ logical_plan 21)--------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] 22)--------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 23)----------------Projection: nation.n_nationkey -24)------------------Filter: nation.n_name = Utf8("GERMANY") -25)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] +24)------------------Filter: nation.n_name = Utf8View("GERMANY") +25)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] physical_plan 01)SortExec: TopK(fetch=10), expr=[value@1 DESC], preserve_partitioning=[false] 02)--ProjectionExec: expr=[ps_partkey@0 as ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q12.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q12.slt.part index 3757fc48dba0a..f7344daed8c7a 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q12.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q12.slt.part @@ -51,12 +51,12 @@ order by logical_plan 01)Sort: lineitem.l_shipmode ASC NULLS LAST 02)--Projection: lineitem.l_shipmode, sum(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS high_line_count, sum(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS low_line_count -03)----Aggregate: groupBy=[[lineitem.l_shipmode]], aggr=[[sum(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), sum(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]] +03)----Aggregate: groupBy=[[lineitem.l_shipmode]], aggr=[[sum(CASE WHEN orders.o_orderpriority = Utf8View("1-URGENT") OR orders.o_orderpriority = Utf8View("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS sum(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), sum(CASE WHEN orders.o_orderpriority != Utf8View("1-URGENT") AND orders.o_orderpriority != Utf8View("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS sum(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]] 04)------Projection: lineitem.l_shipmode, orders.o_orderpriority 05)--------Inner Join: lineitem.l_orderkey = orders.o_orderkey 06)----------Projection: lineitem.l_orderkey, lineitem.l_shipmode -07)------------Filter: (lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP")) AND lineitem.l_receiptdate > lineitem.l_commitdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("1994-01-01") AND lineitem.l_receiptdate < Date32("1995-01-01") -08)--------------TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP"), lineitem.l_receiptdate > lineitem.l_commitdate, lineitem.l_shipdate < lineitem.l_commitdate, lineitem.l_receiptdate >= Date32("1994-01-01"), lineitem.l_receiptdate < Date32("1995-01-01")] +07)------------Filter: (lineitem.l_shipmode = Utf8View("MAIL") OR lineitem.l_shipmode = Utf8View("SHIP")) AND lineitem.l_receiptdate > lineitem.l_commitdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("1994-01-01") AND lineitem.l_receiptdate < Date32("1995-01-01") +08)--------------TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8View("MAIL") OR lineitem.l_shipmode = Utf8View("SHIP"), lineitem.l_receiptdate > lineitem.l_commitdate, lineitem.l_shipdate < lineitem.l_commitdate, lineitem.l_receiptdate >= Date32("1994-01-01"), lineitem.l_receiptdate < Date32("1995-01-01")] 09)----------TableScan: orders projection=[o_orderkey, o_orderpriority] physical_plan 01)SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part index e9d9cf141d103..96f3bd6edf324 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part @@ -50,8 +50,8 @@ logical_plan 08)--------------Left Join: customer.c_custkey = orders.o_custkey 09)----------------TableScan: customer projection=[c_custkey] 10)----------------Projection: orders.o_orderkey, orders.o_custkey -11)------------------Filter: orders.o_comment NOT LIKE Utf8("%special%requests%") -12)--------------------TableScan: orders projection=[o_orderkey, o_custkey, o_comment], partial_filters=[orders.o_comment NOT LIKE Utf8("%special%requests%")] +11)------------------Filter: orders.o_comment NOT LIKE Utf8View("%special%requests%") +12)--------------------TableScan: orders projection=[o_orderkey, o_custkey, o_comment], partial_filters=[orders.o_comment NOT LIKE Utf8View("%special%requests%")] physical_plan 01)SortPreservingMergeExec: [custdist@1 DESC, c_count@0 DESC], fetch=10 02)--SortExec: TopK(fetch=10), expr=[custdist@1 DESC, c_count@0 DESC], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q14.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q14.slt.part index 1104af2bdc643..8d8dd68c3d7bd 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q14.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q14.slt.part @@ -33,7 +33,7 @@ where ---- logical_plan 01)Projection: Float64(100) * CAST(sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue -02)--Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN __common_expr_1 ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(__common_expr_1) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +02)--Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN part.p_type LIKE Utf8View("PROMO%") THEN __common_expr_1 ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(__common_expr_1) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] 03)----Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS __common_expr_1, part.p_type 04)------Inner Join: lineitem.l_partkey = part.p_partkey 05)--------Projection: lineitem.l_partkey, lineitem.l_extendedprice, lineitem.l_discount diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part index c648f164c8094..53d637ea3f510 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part @@ -58,12 +58,12 @@ logical_plan 06)----------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size 07)------------Inner Join: partsupp.ps_partkey = part.p_partkey 08)--------------TableScan: partsupp projection=[ps_partkey, ps_suppkey] -09)--------------Filter: part.p_brand != Utf8("Brand#45") AND part.p_type NOT LIKE Utf8("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)]) -10)----------------TableScan: part projection=[p_partkey, p_brand, p_type, p_size], partial_filters=[part.p_brand != Utf8("Brand#45"), part.p_type NOT LIKE Utf8("MEDIUM POLISHED%"), part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])] +09)--------------Filter: part.p_brand != Utf8View("Brand#45") AND part.p_type NOT LIKE Utf8View("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)]) +10)----------------TableScan: part projection=[p_partkey, p_brand, p_type, p_size], partial_filters=[part.p_brand != Utf8View("Brand#45"), part.p_type NOT LIKE Utf8View("MEDIUM POLISHED%"), part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])] 11)----------SubqueryAlias: __correlated_sq_1 12)------------Projection: supplier.s_suppkey -13)--------------Filter: supplier.s_comment LIKE Utf8("%Customer%Complaints%") -14)----------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")] +13)--------------Filter: supplier.s_comment LIKE Utf8View("%Customer%Complaints%") +14)----------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8View("%Customer%Complaints%")] physical_plan 01)SortPreservingMergeExec: [supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST, p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST], fetch=10 02)--SortExec: TopK(fetch=10), expr=[supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST, p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST], preserve_partitioning=[true] @@ -88,7 +88,7 @@ physical_plan 21)----------------------------------CoalesceBatchesExec: target_batch_size=8192 22)------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 23)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 -24)----------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49) }, Literal { value: Int32(14) }, Literal { value: Int32(23) }, Literal { value: Int32(45) }, Literal { value: Int32(19) }, Literal { value: Int32(3) }, Literal { value: Int32(36) }, Literal { value: Int32(9) }]) +24)----------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND p_size@3 IN (SET) ([49, 14, 23, 45, 19, 3, 36, 9]) 25)------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 26)--------------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], file_type=csv, has_header=false 27)--------------------------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part index 02553890bcf5a..51a0d096428c0 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part @@ -44,8 +44,8 @@ logical_plan 06)----------Inner Join: lineitem.l_partkey = part.p_partkey 07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice] 08)------------Projection: part.p_partkey -09)--------------Filter: part.p_brand = Utf8("Brand#23") AND part.p_container = Utf8("MED BOX") -10)----------------TableScan: part projection=[p_partkey, p_brand, p_container], partial_filters=[part.p_brand = Utf8("Brand#23"), part.p_container = Utf8("MED BOX")] +09)--------------Filter: part.p_brand = Utf8View("Brand#23") AND part.p_container = Utf8View("MED BOX") +10)----------------TableScan: part projection=[p_partkey, p_brand, p_container], partial_filters=[part.p_brand = Utf8View("Brand#23"), part.p_container = Utf8View("MED BOX")] 11)--------SubqueryAlias: __scalar_sq_1 12)----------Projection: CAST(Float64(0.2) * CAST(avg(lineitem.l_quantity) AS Float64) AS Decimal128(30, 15)), lineitem.l_partkey 13)------------Aggregate: groupBy=[[lineitem.l_partkey]], aggr=[[avg(lineitem.l_quantity)]] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part index b0e5b2e904d00..4960ad1f4a914 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part @@ -57,19 +57,19 @@ logical_plan 01)Projection: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue 02)--Aggregate: groupBy=[[]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] 03)----Projection: lineitem.l_extendedprice, lineitem.l_discount -04)------Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) +04)------Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) 05)--------Projection: lineitem.l_partkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount -06)----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR") OR lineitem.l_shipmode = Utf8("AIR REG")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON") -07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("AIR") OR lineitem.l_shipmode = Utf8("AIR REG"), lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] -08)--------Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) -09)----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)] +06)----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG")) AND lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON") +07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG"), lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] +08)--------Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) +09)----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15)] physical_plan 01)ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@0 as revenue] 02)--AggregateExec: mode=Final, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("SM CASE") }, Literal { value: Utf8("SM BOX") }, Literal { value: Utf8("SM PACK") }, Literal { value: Utf8("SM PKG") }]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("MED BAG") }, Literal { value: Utf8("MED BOX") }, Literal { value: Utf8("MED PKG") }, Literal { value: Utf8("MED PACK") }]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("LG CASE") }, Literal { value: Utf8("LG BOX") }, Literal { value: Utf8("LG PACK") }, Literal { value: Utf8("LG PKG") }]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3] +06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND p_container@3 IN ([SM CASE, SM BOX, SM PACK, SM PKG]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN ([MED BAG, MED BOX, MED PKG, MED PACK]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN ([LG CASE, LG BOX, LG PACK, LG PKG]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 09)----------------CoalesceBatchesExec: target_batch_size=8192 @@ -78,6 +78,6 @@ physical_plan 12)------------CoalesceBatchesExec: target_batch_size=8192 13)--------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 14)----------------CoalesceBatchesExec: target_batch_size=8192 -15)------------------FilterExec: (p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("SM CASE") }, Literal { value: Utf8("SM BOX") }, Literal { value: Utf8("SM PACK") }, Literal { value: Utf8("SM PKG") }]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("MED BAG") }, Literal { value: Utf8("MED BOX") }, Literal { value: Utf8("MED PKG") }, Literal { value: Utf8("MED PACK") }]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("LG CASE") }, Literal { value: Utf8("LG BOX") }, Literal { value: Utf8("LG PACK") }, Literal { value: Utf8("LG PKG") }]) AND p_size@2 <= 15) AND p_size@2 >= 1 +15)------------------FilterExec: (p_brand@1 = Brand#12 AND p_container@3 IN ([SM CASE, SM BOX, SM PACK, SM PKG]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN ([MED BAG, MED BOX, MED PKG, MED PACK]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN ([LG CASE, LG BOX, LG PACK, LG PKG]) AND p_size@2 <= 15) AND p_size@2 >= 1 16)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 17)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_size, p_container], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q2.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q2.slt.part index 2a8ee9f229b7b..b2e0fb0cd1cc0 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q2.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q2.slt.part @@ -75,14 +75,14 @@ logical_plan 10)------------------Projection: part.p_partkey, part.p_mfgr, partsupp.ps_suppkey, partsupp.ps_supplycost 11)--------------------Inner Join: part.p_partkey = partsupp.ps_partkey 12)----------------------Projection: part.p_partkey, part.p_mfgr -13)------------------------Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS") -14)--------------------------TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8("%BRASS")] +13)------------------------Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8View("%BRASS") +14)--------------------------TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8View("%BRASS")] 15)----------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] 16)------------------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] 17)--------------TableScan: nation projection=[n_nationkey, n_name, n_regionkey] 18)----------Projection: region.r_regionkey -19)------------Filter: region.r_name = Utf8("EUROPE") -20)--------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")] +19)------------Filter: region.r_name = Utf8View("EUROPE") +20)--------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8View("EUROPE")] 21)------SubqueryAlias: __scalar_sq_1 22)--------Projection: min(partsupp.ps_supplycost), partsupp.ps_partkey 23)----------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[min(partsupp.ps_supplycost)]] @@ -96,8 +96,8 @@ logical_plan 31)------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 32)--------------------TableScan: nation projection=[n_nationkey, n_regionkey] 33)----------------Projection: region.r_regionkey -34)------------------Filter: region.r_name = Utf8("EUROPE") -35)--------------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")] +34)------------------Filter: region.r_name = Utf8View("EUROPE") +35)--------------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8View("EUROPE")] physical_plan 01)SortPreservingMergeExec: [s_acctbal@0 DESC, n_name@2 ASC NULLS LAST, s_name@1 ASC NULLS LAST, p_partkey@3 ASC NULLS LAST], fetch=10 02)--SortExec: TopK(fetch=10), expr=[s_acctbal@0 DESC, n_name@2 ASC NULLS LAST, s_name@1 ASC NULLS LAST, p_partkey@3 ASC NULLS LAST], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q20.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q20.slt.part index 4844d5fae60bd..0b994de411ea3 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q20.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q20.slt.part @@ -63,8 +63,8 @@ logical_plan 05)--------Inner Join: supplier.s_nationkey = nation.n_nationkey 06)----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey] 07)----------Projection: nation.n_nationkey -08)------------Filter: nation.n_name = Utf8("CANADA") -09)--------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("CANADA")] +08)------------Filter: nation.n_name = Utf8View("CANADA") +09)--------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("CANADA")] 10)------SubqueryAlias: __correlated_sq_2 11)--------Projection: partsupp.ps_suppkey 12)----------Inner Join: partsupp.ps_partkey = __scalar_sq_3.l_partkey, partsupp.ps_suppkey = __scalar_sq_3.l_suppkey Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_3.Float64(0.5) * sum(lineitem.l_quantity) @@ -72,8 +72,8 @@ logical_plan 14)--------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] 15)--------------SubqueryAlias: __correlated_sq_1 16)----------------Projection: part.p_partkey -17)------------------Filter: part.p_name LIKE Utf8("forest%") -18)--------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")] +17)------------------Filter: part.p_name LIKE Utf8View("forest%") +18)--------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8View("forest%")] 19)------------SubqueryAlias: __scalar_sq_3 20)--------------Projection: Float64(0.5) * CAST(sum(lineitem.l_quantity) AS Float64), lineitem.l_partkey, lineitem.l_suppkey 21)----------------Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[sum(lineitem.l_quantity)]] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q21.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q21.slt.part index bb3e884e27bef..e52171524007e 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q21.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q21.slt.part @@ -76,11 +76,11 @@ logical_plan 16)----------------------------Filter: lineitem.l_receiptdate > lineitem.l_commitdate 17)------------------------------TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate], partial_filters=[lineitem.l_receiptdate > lineitem.l_commitdate] 18)--------------------Projection: orders.o_orderkey -19)----------------------Filter: orders.o_orderstatus = Utf8("F") -20)------------------------TableScan: orders projection=[o_orderkey, o_orderstatus], partial_filters=[orders.o_orderstatus = Utf8("F")] +19)----------------------Filter: orders.o_orderstatus = Utf8View("F") +20)------------------------TableScan: orders projection=[o_orderkey, o_orderstatus], partial_filters=[orders.o_orderstatus = Utf8View("F")] 21)----------------Projection: nation.n_nationkey -22)------------------Filter: nation.n_name = Utf8("SAUDI ARABIA") -23)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("SAUDI ARABIA")] +22)------------------Filter: nation.n_name = Utf8View("SAUDI ARABIA") +23)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("SAUDI ARABIA")] 24)------------SubqueryAlias: __correlated_sq_1 25)--------------SubqueryAlias: l2 26)----------------TableScan: lineitem projection=[l_orderkey, l_suppkey] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part index 828bf967d8f4a..fc9c01843cc75 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part @@ -90,7 +90,7 @@ physical_plan 14)--------------------------CoalesceBatchesExec: target_batch_size=8192 15)----------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 16)------------------------------CoalesceBatchesExec: target_batch_size=8192 -17)--------------------------------FilterExec: substr(c_phone@1, 1, 2) IN ([Literal { value: Utf8View("13") }, Literal { value: Utf8View("31") }, Literal { value: Utf8View("23") }, Literal { value: Utf8View("29") }, Literal { value: Utf8View("30") }, Literal { value: Utf8View("18") }, Literal { value: Utf8View("17") }]) +17)--------------------------------FilterExec: substr(c_phone@1, 1, 2) IN ([13, 31, 23, 29, 30, 18, 17]) 18)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 19)------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false 20)--------------------------CoalesceBatchesExec: target_batch_size=8192 @@ -100,6 +100,6 @@ physical_plan 24)----------------------CoalescePartitionsExec 25)------------------------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] 26)--------------------------CoalesceBatchesExec: target_batch_size=8192 -27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN ([Literal { value: Utf8View("13") }, Literal { value: Utf8View("31") }, Literal { value: Utf8View("23") }, Literal { value: Utf8View("29") }, Literal { value: Utf8View("30") }, Literal { value: Utf8View("18") }, Literal { value: Utf8View("17") }]), projection=[c_acctbal@1] +27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN ([13, 31, 23, 29, 30, 18, 17]), projection=[c_acctbal@1] 28)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 29)--------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q3.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q3.slt.part index 2ad496ef26fdf..d982ec32e9547 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q3.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q3.slt.part @@ -50,8 +50,8 @@ logical_plan 06)----------Projection: orders.o_orderkey, orders.o_orderdate, orders.o_shippriority 07)------------Inner Join: customer.c_custkey = orders.o_custkey 08)--------------Projection: customer.c_custkey -09)----------------Filter: customer.c_mktsegment = Utf8("BUILDING") -10)------------------TableScan: customer projection=[c_custkey, c_mktsegment], partial_filters=[customer.c_mktsegment = Utf8("BUILDING")] +09)----------------Filter: customer.c_mktsegment = Utf8View("BUILDING") +10)------------------TableScan: customer projection=[c_custkey, c_mktsegment], partial_filters=[customer.c_mktsegment = Utf8View("BUILDING")] 11)--------------Filter: orders.o_orderdate < Date32("1995-03-15") 12)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate, o_shippriority], partial_filters=[orders.o_orderdate < Date32("1995-03-15")] 13)----------Projection: lineitem.l_orderkey, lineitem.l_extendedprice, lineitem.l_discount diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q5.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q5.slt.part index f192f987b3ef9..15636056b8714 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q5.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q5.slt.part @@ -64,8 +64,8 @@ logical_plan 19)------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 20)--------------TableScan: nation projection=[n_nationkey, n_name, n_regionkey] 21)----------Projection: region.r_regionkey -22)------------Filter: region.r_name = Utf8("ASIA") -23)--------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("ASIA")] +22)------------Filter: region.r_name = Utf8View("ASIA") +23)--------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8View("ASIA")] physical_plan 01)SortPreservingMergeExec: [revenue@1 DESC] 02)--SortExec: expr=[revenue@1 DESC], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q7.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q7.slt.part index e03de9596fbef..291d56e43f2df 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q7.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q7.slt.part @@ -63,7 +63,7 @@ logical_plan 03)----Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[sum(shipping.volume)]] 04)------SubqueryAlias: shipping 05)--------Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, date_part(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS volume -06)----------Inner Join: customer.c_nationkey = n2.n_nationkey Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE") +06)----------Inner Join: customer.c_nationkey = n2.n_nationkey Filter: n1.n_name = Utf8View("FRANCE") AND n2.n_name = Utf8View("GERMANY") OR n1.n_name = Utf8View("GERMANY") AND n2.n_name = Utf8View("FRANCE") 07)------------Projection: lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_shipdate, customer.c_nationkey, n1.n_name 08)--------------Inner Join: supplier.s_nationkey = n1.n_nationkey 09)----------------Projection: supplier.s_nationkey, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_shipdate, customer.c_nationkey @@ -78,11 +78,11 @@ logical_plan 18)------------------------TableScan: orders projection=[o_orderkey, o_custkey] 19)--------------------TableScan: customer projection=[c_custkey, c_nationkey] 20)----------------SubqueryAlias: n1 -21)------------------Filter: nation.n_name = Utf8("FRANCE") OR nation.n_name = Utf8("GERMANY") -22)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("FRANCE") OR nation.n_name = Utf8("GERMANY")] +21)------------------Filter: nation.n_name = Utf8View("FRANCE") OR nation.n_name = Utf8View("GERMANY") +22)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("FRANCE") OR nation.n_name = Utf8View("GERMANY")] 23)------------SubqueryAlias: n2 -24)--------------Filter: nation.n_name = Utf8("GERMANY") OR nation.n_name = Utf8("FRANCE") -25)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY") OR nation.n_name = Utf8("FRANCE")] +24)--------------Filter: nation.n_name = Utf8View("GERMANY") OR nation.n_name = Utf8View("FRANCE") +25)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY") OR nation.n_name = Utf8View("FRANCE")] physical_plan 01)SortPreservingMergeExec: [supp_nation@0 ASC NULLS LAST, cust_nation@1 ASC NULLS LAST, l_year@2 ASC NULLS LAST] 02)--SortExec: expr=[supp_nation@0 ASC NULLS LAST, cust_nation@1 ASC NULLS LAST, l_year@2 ASC NULLS LAST], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q8.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q8.slt.part index 88ceffd62ad35..50171c528db6d 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q8.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q8.slt.part @@ -58,7 +58,7 @@ order by logical_plan 01)Sort: all_nations.o_year ASC NULLS LAST 02)--Projection: all_nations.o_year, CAST(CAST(sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END) AS Decimal128(12, 2)) / CAST(sum(all_nations.volume) AS Decimal128(12, 2)) AS Decimal128(15, 2)) AS mkt_share -03)----Aggregate: groupBy=[[all_nations.o_year]], aggr=[[sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), sum(all_nations.volume)]] +03)----Aggregate: groupBy=[[all_nations.o_year]], aggr=[[sum(CASE WHEN all_nations.nation = Utf8View("BRAZIL") THEN all_nations.volume ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), sum(all_nations.volume)]] 04)------SubqueryAlias: all_nations 05)--------Projection: date_part(Utf8("YEAR"), orders.o_orderdate) AS o_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS volume, n2.n_name AS nation 06)----------Inner Join: n1.n_regionkey = region.r_regionkey @@ -75,8 +75,8 @@ logical_plan 17)--------------------------------Projection: lineitem.l_orderkey, lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount 18)----------------------------------Inner Join: part.p_partkey = lineitem.l_partkey 19)------------------------------------Projection: part.p_partkey -20)--------------------------------------Filter: part.p_type = Utf8("ECONOMY ANODIZED STEEL") -21)----------------------------------------TableScan: part projection=[p_partkey, p_type], partial_filters=[part.p_type = Utf8("ECONOMY ANODIZED STEEL")] +20)--------------------------------------Filter: part.p_type = Utf8View("ECONOMY ANODIZED STEEL") +21)----------------------------------------TableScan: part projection=[p_partkey, p_type], partial_filters=[part.p_type = Utf8View("ECONOMY ANODIZED STEEL")] 22)------------------------------------TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_extendedprice, l_discount] 23)--------------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 24)----------------------------Filter: orders.o_orderdate >= Date32("1995-01-01") AND orders.o_orderdate <= Date32("1996-12-31") @@ -87,8 +87,8 @@ logical_plan 29)----------------SubqueryAlias: n2 30)------------------TableScan: nation projection=[n_nationkey, n_name] 31)------------Projection: region.r_regionkey -32)--------------Filter: region.r_name = Utf8("AMERICA") -33)----------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("AMERICA")] +32)--------------Filter: region.r_name = Utf8View("AMERICA") +33)----------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8View("AMERICA")] physical_plan 01)SortPreservingMergeExec: [o_year@0 ASC NULLS LAST] 02)--SortExec: expr=[o_year@0 ASC NULLS LAST], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q9.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q9.slt.part index 8ccf967187d7d..3b31c1bc2e8e3 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q9.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q9.slt.part @@ -67,8 +67,8 @@ logical_plan 13)------------------------Projection: lineitem.l_orderkey, lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount 14)--------------------------Inner Join: part.p_partkey = lineitem.l_partkey 15)----------------------------Projection: part.p_partkey -16)------------------------------Filter: part.p_name LIKE Utf8("%green%") -17)--------------------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("%green%")] +16)------------------------------Filter: part.p_name LIKE Utf8View("%green%") +17)--------------------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8View("%green%")] 18)----------------------------TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount] 19)------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 20)--------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] diff --git a/datafusion/sqllogictest/test_files/type_coercion.slt b/datafusion/sqllogictest/test_files/type_coercion.slt index 2c6079bc7039d..3175a0646b799 100644 --- a/datafusion/sqllogictest/test_files/type_coercion.slt +++ b/datafusion/sqllogictest/test_files/type_coercion.slt @@ -128,9 +128,9 @@ EXPLAIN SELECT 1, 2 UNION ALL SELECT 3, 4 logical_plan 01)Union 02)--Projection: Int64(1) AS Int64(1), Int64(2) AS Int64(2) -03)----EmptyRelation +03)----EmptyRelation: rows=1 04)--Projection: Int64(3) AS Int64(1), Int64(4) AS Int64(2) -05)----EmptyRelation +05)----EmptyRelation: rows=1 # union_with_incompatible_data_type() query error Incompatible inputs for Union: Previous inputs were of type Interval\(MonthDayNano\), but got incompatible type Int64 on column 'Int64\(1\)' @@ -143,9 +143,9 @@ EXPLAIN SELECT 1 a UNION ALL SELECT 1.1 a logical_plan 01)Union 02)--Projection: CAST(Int64(1) AS Float64) AS a -03)----EmptyRelation +03)----EmptyRelation: rows=1 04)--Projection: Float64(1.1) AS a -05)----EmptyRelation +05)----EmptyRelation: rows=1 # union_with_null() query TT @@ -154,9 +154,9 @@ EXPLAIN SELECT NULL a UNION ALL SELECT 1.1 a logical_plan 01)Union 02)--Projection: CAST(NULL AS Float64) AS a -03)----EmptyRelation +03)----EmptyRelation: rows=1 04)--Projection: Float64(1.1) AS a -05)----EmptyRelation +05)----EmptyRelation: rows=1 # union_with_float_and_string() query TT @@ -165,9 +165,9 @@ EXPLAIN SELECT 'a' a UNION ALL SELECT 1.1 a logical_plan 01)Union 02)--Projection: Utf8("a") AS a -03)----EmptyRelation +03)----EmptyRelation: rows=1 04)--Projection: CAST(Float64(1.1) AS Utf8) AS a -05)----EmptyRelation +05)----EmptyRelation: rows=1 # union_with_multiply_cols() query TT @@ -176,9 +176,9 @@ EXPLAIN SELECT 'a' a, 1 b UNION ALL SELECT 1.1 a, 1.1 b logical_plan 01)Union 02)--Projection: Utf8("a") AS a, CAST(Int64(1) AS Float64) AS b -03)----EmptyRelation +03)----EmptyRelation: rows=1 04)--Projection: CAST(Float64(1.1) AS Utf8) AS a, Float64(1.1) AS b -05)----EmptyRelation +05)----EmptyRelation: rows=1 # sorted_union_with_different_types_and_group_by() query TT @@ -193,12 +193,12 @@ logical_plan 04)------Aggregate: groupBy=[[x.a]], aggr=[[]] 05)--------SubqueryAlias: x 06)----------Projection: Int64(1) AS a -07)------------EmptyRelation +07)------------EmptyRelation: rows=1 08)----Projection: x.a 09)------Aggregate: groupBy=[[x.a]], aggr=[[]] 10)--------SubqueryAlias: x 11)----------Projection: Float64(1.1) AS a -12)------------EmptyRelation +12)------------EmptyRelation: rows=1 # union_with_binary_expr_and_cast() query TT @@ -212,12 +212,12 @@ logical_plan 03)----Aggregate: groupBy=[[CAST(Float64(0) + CAST(x.a AS Float64) AS Int32)]], aggr=[[]] 04)------SubqueryAlias: x 05)--------Projection: Int64(1) AS a -06)----------EmptyRelation +06)----------EmptyRelation: rows=1 07)--Projection: Float64(2.1) + x.a AS Float64(0) + x.a 08)----Aggregate: groupBy=[[Float64(2.1) + CAST(x.a AS Float64)]], aggr=[[]] 09)------SubqueryAlias: x 10)--------Projection: Int64(1) AS a -11)----------EmptyRelation +11)----------EmptyRelation: rows=1 # union_with_aliases() query TT @@ -231,12 +231,12 @@ logical_plan 03)----Aggregate: groupBy=[[x.a]], aggr=[[]] 04)------SubqueryAlias: x 05)--------Projection: Int64(1) AS a -06)----------EmptyRelation +06)----------EmptyRelation: rows=1 07)--Projection: x.a AS a1 08)----Aggregate: groupBy=[[x.a]], aggr=[[]] 09)------SubqueryAlias: x 10)--------Projection: Float64(1.1) AS a -11)----------EmptyRelation +11)----------EmptyRelation: rows=1 # union_with_incompatible_data_types() query error Incompatible inputs for Union: Previous inputs were of type Utf8, but got incompatible type Boolean on column 'a' diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 356f1598bc0fa..1f7605d220c5e 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -230,7 +230,7 @@ logical_plan 02)--Union 03)----TableScan: t1 projection=[name] 04)----TableScan: t2 projection=[name] -05)----Projection: t2.name || Utf8("_new") AS name +05)----Projection: t2.name || Utf8View("_new") AS name 06)------TableScan: t2 projection=[name] physical_plan 01)AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] @@ -266,7 +266,7 @@ logical_plan 01)Union 02)--TableScan: t1 projection=[name] 03)--TableScan: t2 projection=[name] -04)--Projection: t2.name || Utf8("_new") AS name +04)--Projection: t2.name || Utf8View("_new") AS name 05)----TableScan: t2 projection=[name] physical_plan 01)UnionExec @@ -308,7 +308,7 @@ logical_plan physical_plan 01)UnionExec 02)--CoalesceBatchesExec: target_batch_size=2 -03)----HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(id@0, CAST(t2.id AS Int32)@2), (name@1, name@1)] +03)----HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(id@0, CAST(t2.id AS Int32)@2), (name@1, name@1)], NullsEqual: true 04)------CoalescePartitionsExec 05)--------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] 06)----------CoalesceBatchesExec: target_batch_size=2 @@ -321,7 +321,7 @@ physical_plan 13)----------DataSourceExec: partitions=1, partition_sizes=[1] 14)--ProjectionExec: expr=[CAST(id@0 AS Int32) as id, name@1 as name] 15)----CoalesceBatchesExec: target_batch_size=2 -16)------HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(CAST(t2.id AS Int32)@2, id@0), (name@1, name@1)], projection=[id@0, name@1] +16)------HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(CAST(t2.id AS Int32)@2, id@0), (name@1, name@1)], projection=[id@0, name@1], NullsEqual: true 17)--------CoalescePartitionsExec 18)----------ProjectionExec: expr=[id@0 as id, name@1 as name, CAST(id@0 AS Int32) as CAST(t2.id AS Int32)] 19)------------AggregateExec: mode=FinalPartitioned, gby=[id@0 as id, name@1 as name], aggr=[] @@ -378,7 +378,7 @@ logical_plan physical_plan 01)UnionExec 02)--CoalesceBatchesExec: target_batch_size=2 -03)----HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(name@0, name@0)] +03)----HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(name@0, name@0)], NullsEqual: true 04)------CoalescePartitionsExec 05)--------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] 06)----------CoalesceBatchesExec: target_batch_size=2 @@ -389,7 +389,7 @@ physical_plan 11)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 12)--------DataSourceExec: partitions=1, partition_sizes=[1] 13)--CoalesceBatchesExec: target_batch_size=2 -14)----HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(name@0, name@0)] +14)----HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(name@0, name@0)], NullsEqual: true 15)------CoalescePartitionsExec 16)--------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] 17)----------CoalesceBatchesExec: target_batch_size=2 @@ -413,15 +413,14 @@ logical_plan 06)------TableScan: aggregate_test_100 projection=[c1, c3] physical_plan 01)SortPreservingMergeExec: [c9@1 DESC], fetch=5 -02)--UnionExec -03)----SortExec: TopK(fetch=5), expr=[c9@1 DESC], preserve_partitioning=[true] +02)--SortExec: TopK(fetch=5), expr=[c9@1 DESC], preserve_partitioning=[true] +03)----UnionExec 04)------ProjectionExec: expr=[c1@0 as c1, CAST(c9@1 AS Decimal128(20, 0)) as c9] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c9], file_type=csv, has_header=true -07)----SortExec: TopK(fetch=5), expr=[c9@1 DESC], preserve_partitioning=[true] -08)------ProjectionExec: expr=[c1@0 as c1, CAST(c3@1 AS Decimal128(20, 0)) as c9] -09)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -10)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c3], file_type=csv, has_header=true +07)------ProjectionExec: expr=[c1@0 as c1, CAST(c3@1 AS Decimal128(20, 0)) as c9] +08)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c3], file_type=csv, has_header=true query TR SELECT c1, c9 FROM aggregate_test_100 UNION ALL SELECT c1, c3 FROM aggregate_test_100 ORDER BY c9 DESC LIMIT 5 @@ -489,20 +488,20 @@ logical_plan 04)------Limit: skip=0, fetch=3 05)--------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 06)----------SubqueryAlias: a -07)------------Projection: +07)------------Projection: 08)--------------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] 09)----------------Projection: aggregate_test_100.c1 -10)------------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") -11)--------------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] +10)------------------Filter: aggregate_test_100.c13 != Utf8View("C2GT5KVyOPZpgKVl110TyZO0NcJ434") +11)--------------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8View("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] 12)----Projection: Int64(1) AS cnt 13)------Limit: skip=0, fetch=3 -14)--------EmptyRelation +14)--------EmptyRelation: rows=1 15)----Projection: lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS cnt 16)------Limit: skip=0, fetch=3 17)--------WindowAggr: windowExpr=[[lead(b.c1, Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] 18)----------SubqueryAlias: b 19)------------Projection: Int64(1) AS c1 -20)--------------EmptyRelation +20)--------------EmptyRelation: rows=1 physical_plan 01)CoalescePartitionsExec: fetch=3 02)--UnionExec @@ -522,7 +521,7 @@ physical_plan 16)----ProjectionExec: expr=[1 as cnt] 17)------PlaceholderRowExec 18)----ProjectionExec: expr=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt] -19)------BoundedWindowAggExec: wdw=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +19)------BoundedWindowAggExec: wdw=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] 20)--------ProjectionExec: expr=[1 as c1] 21)----------PlaceholderRowExec @@ -621,11 +620,11 @@ logical_plan 01)Union 02)--Projection: Int64(1) AS a 03)----Aggregate: groupBy=[[Int64(1)]], aggr=[[]] -04)------EmptyRelation +04)------EmptyRelation: rows=1 05)--Projection: Int64(2) AS a -06)----EmptyRelation +06)----EmptyRelation: rows=1 07)--Projection: Int64(3) AS a -08)----EmptyRelation +08)----EmptyRelation: rows=1 physical_plan 01)UnionExec 02)--ProjectionExec: expr=[Int64(1)@0 as a] @@ -648,12 +647,12 @@ logical_plan 03)----Aggregate: groupBy=[[a.n]], aggr=[[count(Int64(1))]] 04)------SubqueryAlias: a 05)--------Projection: Int64(5) AS n -06)----------EmptyRelation +06)----------EmptyRelation: rows=1 07)--Projection: b.x AS count, b.y AS n 08)----SubqueryAlias: b 09)------Projection: Int64(1) AS x, max(Int64(10)) AS y 10)--------Aggregate: groupBy=[[]], aggr=[[max(Int64(10))]] -11)----------EmptyRelation +11)----------EmptyRelation: rows=1 physical_plan 01)UnionExec 02)--ProjectionExec: expr=[count(Int64(1))@1 as count, n@0 as n] @@ -829,10 +828,10 @@ ORDER BY c1 logical_plan 01)Sort: c1 ASC NULLS LAST 02)--Union -03)----Filter: aggregate_test_100.c1 = Utf8("a") -04)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8("a")] -05)----Filter: aggregate_test_100.c1 = Utf8("a") -06)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8("a")] +03)----Filter: aggregate_test_100.c1 = Utf8View("a") +04)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8View("a")] +05)----Filter: aggregate_test_100.c1 = Utf8View("a") +06)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8View("a")] physical_plan 01)CoalescePartitionsExec 02)--UnionExec @@ -916,19 +915,19 @@ physical_plan 03)----SortExec: expr=[y@0 ASC NULLS LAST], preserve_partitioning=[true] 04)------ProjectionExec: expr=[CAST(y@0 AS Int64) as y] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -06)----------DataSourceExec: partitions=1, partition_sizes=[1] +06)----------DataSourceExec: partitions=1, partition_sizes=[2] 07)----SortExec: expr=[y@0 ASC NULLS LAST], preserve_partitioning=[false] 08)------DataSourceExec: partitions=1, partition_sizes=[1] # optimize_subquery_sort in create_relation removes Sort so the result is not sorted. query I -SELECT * FROM v1; +SELECT * FROM v1 ORDER BY 1; ---- -20 -40 +1 3 3 -1 +20 +40 query TT explain SELECT * FROM v1; @@ -943,7 +942,7 @@ physical_plan 01)UnionExec 02)--ProjectionExec: expr=[CAST(y@0 AS Int64) as y] 03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------DataSourceExec: partitions=1, partition_sizes=[1] +04)------DataSourceExec: partitions=1, partition_sizes=[2] 05)--DataSourceExec: partitions=1, partition_sizes=[1] statement count 0 diff --git a/datafusion/sqllogictest/test_files/union_by_name.slt b/datafusion/sqllogictest/test_files/union_by_name.slt index 9572e6efc3e67..6a1608d5d1348 100644 --- a/datafusion/sqllogictest/test_files/union_by_name.slt +++ b/datafusion/sqllogictest/test_files/union_by_name.slt @@ -334,90 +334,8 @@ select x, y, z from t3 union all by name select z, y, x from t4 order by x; a b c a b c - -# FIXME: The following should pass without error, but currently it is failing -# due to differing record batch schemas when the SLT runner collects results. -# This is due to the following issue: https://github.com/apache/datafusion/issues/15394#issue-2943811768 -# -# More context can be found here: https://github.com/apache/datafusion/pull/15242#issuecomment-2746563234 -query error +query TTTT rowsort select x, y, z from t3 union all by name select z, y, x, 'd' as zz from t3; ---- -DataFusion error: Internal error: Schema mismatch. Previously had -Schema { - fields: [ - Field { - name: "x", - data_type: Utf8, - nullable: true, - dict_id: 0, - dict_is_ordered: false, - metadata: {}, - }, - Field { - name: "y", - data_type: Utf8, - nullable: true, - dict_id: 0, - dict_is_ordered: false, - metadata: {}, - }, - Field { - name: "z", - data_type: Utf8, - nullable: true, - dict_id: 0, - dict_is_ordered: false, - metadata: {}, - }, - Field { - name: "zz", - data_type: Utf8, - nullable: false, - dict_id: 0, - dict_is_ordered: false, - metadata: {}, - }, - ], - metadata: {}, -} - -Got: -Schema { - fields: [ - Field { - name: "x", - data_type: Utf8, - nullable: true, - dict_id: 0, - dict_is_ordered: false, - metadata: {}, - }, - Field { - name: "y", - data_type: Utf8, - nullable: true, - dict_id: 0, - dict_is_ordered: false, - metadata: {}, - }, - Field { - name: "z", - data_type: Utf8, - nullable: true, - dict_id: 0, - dict_is_ordered: false, - metadata: {}, - }, - Field { - name: "zz", - data_type: Utf8, - nullable: true, - dict_id: 0, - dict_is_ordered: false, - metadata: {}, - }, - ], - metadata: {}, -}. -This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker +a b c NULL +a b c d diff --git a/datafusion/sqllogictest/test_files/union_function.slt b/datafusion/sqllogictest/test_files/union_function.slt index 9c70b1011f58a..74616490ab707 100644 --- a/datafusion/sqllogictest/test_files/union_function.slt +++ b/datafusion/sqllogictest/test_files/union_function.slt @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. +# Note: union_table is registered via Rust code in the sqllogictest test harness +# because there is no way to create a union type in SQL today + ########## ## UNION DataType Tests ########## @@ -23,7 +26,8 @@ query ?I select union_column, union_extract(union_column, 'int') from union_table; ---- {int=1} 1 -{int=2} 2 +{string=bar} NULL +{int=3} 3 query error DataFusion error: Execution error: field bool not found on union select union_extract(union_column, 'bool') from union_table; @@ -45,3 +49,19 @@ select union_extract(union_column, 1) from union_table; query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 3 select union_extract(union_column, 'a', 'b') from union_table; + +query ?T +select union_column, union_tag(union_column) from union_table; +---- +{int=1} int +{string=bar} string +{int=3} int + +query error DataFusion error: Error during planning: 'union_tag' does not support zero arguments +select union_tag() from union_table; + +query error DataFusion error: Error during planning: The function 'union_tag' expected 1 arguments but received 2 +select union_tag(union_column, 'int') from union_table; + +query error DataFusion error: Execution error: union_tag only support unions, got Utf8 +select union_tag('int') from union_table; diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index b9c13582952a6..67b3a7cf56665 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -91,12 +91,12 @@ select * from unnest(null); ## Unnest empty array in select list -query I +query ? select unnest([]); ---- ## Unnest empty array in from clause -query I +query ? select * from unnest([]); ---- @@ -243,7 +243,7 @@ query error DataFusion error: This feature is not implemented: unnest\(\) does n select unnest(null) from unnest_table; ## Multiple unnest functions in selection -query II +query ?I select unnest([]), unnest(NULL::int[]); ---- @@ -263,10 +263,10 @@ NULL 10 NULL NULL NULL 17 NULL NULL 18 -query IIIT -select - unnest(column1), unnest(column2) + 2, - column3 * 10, unnest(array_remove(column1, '4')) +query IIII +select + unnest(column1), unnest(column2) + 2, + column3 * 10, unnest(array_remove(column1, 4)) from unnest_table; ---- 1 9 10 1 @@ -316,7 +316,7 @@ select * from unnest( 2 b NULL NULL NULL c NULL NULL -query II +query ?I select * from unnest([], NULL::int[]); ---- @@ -863,11 +863,11 @@ select count(*) from (select unnest(range(0, 100000)) id) t inner join (select u # Test implicit LATERAL support for UNNEST # Issue: https://github.com/apache/datafusion/issues/13659 # TODO: https://github.com/apache/datafusion/issues/10048 -query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), Column \{ relation: Some\(Bare \{ table: "u" \}\), name: "column1" \}\) +query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(Field \{ name: "column1", data_type: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}, Column \{ relation: Some\(Bare \{ table: "u" \}\), name: "column1" \}\) select * from unnest_table u, unnest(u.column1); # Test implicit LATERAL support for UNNEST (INNER JOIN) -query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), Column \{ relation: Some\(Bare \{ table: "u" \}\), name: "column1" \}\) +query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(Field \{ name: "column1", data_type: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}, Column \{ relation: Some\(Bare \{ table: "u" \}\), name: "column1" \}\) select * from unnest_table u INNER JOIN unnest(u.column1) AS t(column1) ON u.column3 = t.column1; # Test implicit LATERAL planning for UNNEST @@ -875,15 +875,15 @@ query TT explain select * from unnest_table u, unnest(u.column1); ---- logical_plan -01)Cross Join: +01)Cross Join: 02)--SubqueryAlias: u 03)----TableScan: unnest_table projection=[column1, column2, column3, column4, column5] 04)--Subquery: 05)----Projection: __unnest_placeholder(outer_ref(u.column1),depth=1) AS UNNEST(outer_ref(u.column1)) 06)------Unnest: lists[__unnest_placeholder(outer_ref(u.column1))|depth=1] structs[] 07)--------Projection: outer_ref(u.column1) AS __unnest_placeholder(outer_ref(u.column1)) -08)----------EmptyRelation -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), Column { relation: Some(Bare { table: "u" }), name: "column1" }) +08)----------EmptyRelation: rows=1 +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Field { name: "column1", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Column { relation: Some(Bare { table: "u" }), name: "column1" }) # Test implicit LATERAL planning for UNNEST (INNER JOIN) query TT @@ -898,8 +898,8 @@ logical_plan 06)------Projection: __unnest_placeholder(outer_ref(u.column1),depth=1) AS column1 07)--------Unnest: lists[__unnest_placeholder(outer_ref(u.column1))|depth=1] structs[] 08)----------Projection: outer_ref(u.column1) AS __unnest_placeholder(outer_ref(u.column1)) -09)------------EmptyRelation -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), Column { relation: Some(Bare { table: "u" }), name: "column1" }) +09)------------EmptyRelation: rows=1 +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Field { name: "column1", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Column { relation: Some(Bare { table: "u" }), name: "column1" }) # uncorrelated EXISTS with unnest query I diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index 908d2b34aea46..9f2c16b21106f 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -31,7 +31,7 @@ explain update t1 set a=1, b=2, c=3.0, d=NULL; ---- logical_plan 01)Dml: op=[Update] table=[t1] -02)--Projection: CAST(Int64(1) AS Int32) AS a, CAST(Int64(2) AS Utf8) AS b, Float64(3) AS c, CAST(NULL AS Int32) AS d +02)--Projection: CAST(Int64(1) AS Int32) AS a, CAST(Int64(2) AS Utf8View) AS b, Float64(3) AS c, CAST(NULL AS Int32) AS d 03)----TableScan: t1 physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Update) @@ -40,7 +40,7 @@ explain update t1 set a=c+1, b=a, c=c+1.0, d=b; ---- logical_plan 01)Dml: op=[Update] table=[t1] -02)--Projection: CAST(t1.c + CAST(Int64(1) AS Float64) AS Int32) AS a, CAST(t1.a AS Utf8) AS b, t1.c + Float64(1) AS c, CAST(t1.b AS Int32) AS d +02)--Projection: CAST(t1.c + CAST(Int64(1) AS Float64) AS Int32) AS a, CAST(t1.a AS Utf8View) AS b, t1.c + Float64(1) AS c, CAST(t1.b AS Int32) AS d 03)----TableScan: t1 physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Update) @@ -69,7 +69,7 @@ explain update t1 set b = t2.b, c = t2.a, d = 1 from t2 where t1.a = t2.a and t1 logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d -03)----Filter: t1.a = t2.a AND t1.b > Utf8("foo") AND t2.c > Float64(1) +03)----Filter: t1.a = t2.a AND t1.b > CAST(Utf8("foo") AS Utf8View) AND t2.c > Float64(1) 04)------Cross Join: 05)--------TableScan: t1 06)--------TableScan: t2 @@ -89,7 +89,7 @@ explain update t1 as T set b = t2.b, c = t.a, d = 1 from t2 where t.a = t2.a and logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t.a AS a, t2.b AS b, CAST(t.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d -03)----Filter: t.a = t2.a AND t.b > Utf8("foo") AND t2.c > Float64(1) +03)----Filter: t.a = t2.a AND t.b > CAST(Utf8("foo") AS Utf8View) AND t2.c > Float64(1) 04)------Cross Join: 05)--------SubqueryAlias: t 06)----------TableScan: t1 diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 76e3751e4b8e4..f1a708d84dd3c 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -263,13 +263,13 @@ logical_plan 07)------------SubqueryAlias: _sample_data 08)--------------Union 09)----------------Projection: Int64(1) AS a, Utf8("aa") AS b -10)------------------EmptyRelation +10)------------------EmptyRelation: rows=1 11)----------------Projection: Int64(3) AS a, Utf8("aa") AS b -12)------------------EmptyRelation +12)------------------EmptyRelation: rows=1 13)----------------Projection: Int64(5) AS a, Utf8("bb") AS b -14)------------------EmptyRelation +14)------------------EmptyRelation: rows=1 15)----------------Projection: Int64(7) AS a, Utf8("bb") AS b -16)------------------EmptyRelation +16)------------------EmptyRelation: rows=1 physical_plan 01)SortPreservingMergeExec: [b@0 ASC NULLS LAST] 02)--SortExec: expr=[b@0 ASC NULLS LAST], preserve_partitioning=[true] @@ -348,19 +348,19 @@ logical_plan 09)----------------SubqueryAlias: _sample_data 10)------------------Union 11)--------------------Projection: Int64(1) AS a, Utf8("aa") AS b -12)----------------------EmptyRelation +12)----------------------EmptyRelation: rows=1 13)--------------------Projection: Int64(3) AS a, Utf8("aa") AS b -14)----------------------EmptyRelation +14)----------------------EmptyRelation: rows=1 15)--------------------Projection: Int64(5) AS a, Utf8("bb") AS b -16)----------------------EmptyRelation +16)----------------------EmptyRelation: rows=1 17)--------------------Projection: Int64(7) AS a, Utf8("bb") AS b -18)----------------------EmptyRelation +18)----------------------EmptyRelation: rows=1 physical_plan 01)SortPreservingMergeExec: [b@0 ASC NULLS LAST] 02)--ProjectionExec: expr=[b@0 as b, max(d.a)@1 as max_a, max(d.seq)@2 as max(d.seq)] 03)----AggregateExec: mode=SinglePartitioned, gby=[b@2 as b], aggr=[max(d.a), max(d.seq)], ordering_mode=Sorted 04)------ProjectionExec: expr=[row_number() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as seq, a@0 as a, b@1 as b] -05)--------BoundedWindowAggExec: wdw=[row_number() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[row_number() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "row_number() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 06)----------SortExec: expr=[b@1 ASC NULLS LAST, a@0 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([b@1], 4), input_partitions=4 @@ -1241,9 +1241,9 @@ logical_plan 05)--------TableScan: aggregate_test_100 projection=[c8, c9] physical_plan 01)ProjectionExec: expr=[c9@0 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum2] -02)--BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +02)--BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 03)----ProjectionExec: expr=[c9@1 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 05)--------SortExec: expr=[c9@1 ASC NULLS LAST, c8@0 ASC NULLS LAST], preserve_partitioning=[false] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c8, c9], file_type=csv, has_header=true @@ -1263,8 +1263,8 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c2@0 as c2, max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@4 as sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] 02)--WindowAggExec: wdw=[sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] -03)----BoundedWindowAggExec: wdw=[max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 05)--------SortExec: expr=[c2@0 ASC NULLS LAST, c9@1 ASC NULLS LAST], preserve_partitioning=[false] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c9], file_type=csv, has_header=true @@ -1287,9 +1287,9 @@ physical_plan 01)SortExec: expr=[c2@0 ASC NULLS LAST], preserve_partitioning=[false] 02)--ProjectionExec: expr=[c2@0 as c2, max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@4 as sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] 03)----WindowAggExec: wdw=[sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] -04)------BoundedWindowAggExec: wdw=[max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 05)--------SortExec: expr=[c9@1 ASC NULLS LAST, c2@0 ASC NULLS LAST], preserve_partitioning=[false] -06)----------BoundedWindowAggExec: wdw=[min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +06)----------BoundedWindowAggExec: wdw=[min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 07)------------SortExec: expr=[c2@0 ASC NULLS LAST, c9@1 ASC NULLS LAST], preserve_partitioning=[false] 08)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c9], file_type=csv, has_header=true @@ -1311,12 +1311,12 @@ logical_plan 05)--------TableScan: aggregate_test_100 projection=[c1, c2, c4] physical_plan 01)ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@2 as sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING] -02)--BoundedWindowAggExec: wdw=[count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +02)--BoundedWindowAggExec: wdw=[count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Field { name: "count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] 03)----SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[true] 04)------CoalesceBatchesExec: target_batch_size=4096 05)--------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 06)----------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING] -07)------------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +07)------------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] 08)--------------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[true] 09)----------------CoalesceBatchesExec: target_batch_size=4096 10)------------------RepartitionExec: partitioning=Hash([c1@0, c2@1], 2), input_partitions=2 @@ -1343,9 +1343,9 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c9@0 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 5 PRECEDING AND 1 FOLLOWING], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] +05)--------SortExec: TopK(fetch=10), expr=[c9@0 DESC], preserve_partitioning=[false] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], file_type=csv, has_header=true query III @@ -1362,6 +1362,110 @@ SELECT 4144173353 20935849039 28472563256 4076864659 24997484146 28118515915 +# Only 1 SortExec was added, and limit 100 was turned into limit 10 +query TT +EXPLAIN SELECT + c9, + SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2 + FROM ( + SELECT c9, + FROM aggregate_test_100 + ORDER BY c9 DESC + LIMIT 100 + ) + LIMIT 5 +---- +logical_plan +01)Projection: aggregate_test_100.c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2 +02)--Limit: skip=0, fetch=5 +03)----WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +04)------WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +05)--------Sort: aggregate_test_100.c9 DESC NULLS FIRST, fetch=100 +06)----------TableScan: aggregate_test_100 projection=[c9] +physical_plan +01)ProjectionExec: expr=[c9@0 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2] +02)--GlobalLimitExec: skip=0, fetch=5 +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 5 PRECEDING AND 1 FOLLOWING], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] +05)--------SortExec: TopK(fetch=10), expr=[c9@0 DESC], preserve_partitioning=[false] +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], file_type=csv, has_header=true + +# ensure limit pushdown can handle bigger preceding instead of following +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query III +SELECT + c9, + SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 5 PRECEDING AND 1 FOLLOWING) as sum1, + SUM(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 5 PRECEDING AND 1 FOLLOWING) as sum2 + FROM aggregate_test_100 + LIMIT 5 +---- +4268716378 24997484146 8498370520 +4229654142 29012926487 12714811027 +4216440507 28743001064 16858984380 +4144173353 28472563256 20935849039 +4076864659 28118515915 24997484146 + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query III +SELECT + c9, + SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 5 PRECEDING AND 1 FOLLOWING) as sum1, + SUM(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 5 PRECEDING AND 1 FOLLOWING) as sum2 + FROM aggregate_test_100 + LIMIT 5 +---- +4268716378 24997484146 8498370520 +4229654142 29012926487 12714811027 +4216440507 28743001064 16858984380 +4144173353 28472563256 20935849039 +4076864659 28118515915 24997484146 + +# test_window_agg_sort_reversed_plan +# Only 1 SortExec was added, limit & skip are pushed down +query TT +EXPLAIN SELECT + c9, + SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2 + FROM aggregate_test_100 + LIMIT 5 + OFFSET 5 +---- +logical_plan +01)Projection: aggregate_test_100.c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2 +02)--Limit: skip=5, fetch=5 +03)----WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +04)------WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +05)--------TableScan: aggregate_test_100 projection=[c9] +physical_plan +01)ProjectionExec: expr=[c9@0 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2] +02)--GlobalLimitExec: skip=5, fetch=5 +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 5 PRECEDING AND 1 FOLLOWING], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] +05)--------SortExec: TopK(fetch=15), expr=[c9@0 DESC], preserve_partitioning=[false] +06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], file_type=csv, has_header=true + +query III +SELECT + c9, + SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2 + FROM aggregate_test_100 + LIMIT 5 + OFFSET 5 +---- +4061635107 29012926487 27741341640 +4015442341 28743001064 27423817254 +3998790955 28472563256 27079733310 +3959216334 28118515915 26689577379 +3862393166 27741341640 26284746231 + # test_window_agg_sort_reversed_plan_builtin query TT EXPLAIN SELECT @@ -1384,8 +1488,8 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c9@0 as c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as fv2, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as lag1, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lag2, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as lead1, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lead2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 5 PRECEDING AND 1 FOLLOWING, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING], mode=[Sorted] 05)--------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], file_type=csv, has_header=true @@ -1427,9 +1531,9 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as rn1, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -04)------SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false] -05)--------BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] +04)------SortExec: TopK(fetch=10), expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] 06)----------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], file_type=csv, has_header=true @@ -1469,10 +1573,10 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c9@2 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum2, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as rn2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -04)------SortExec: expr=[c9@2 ASC NULLS LAST, c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[false] -05)--------BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] +04)------SortExec: TopK(fetch=10), expr=[c9@2 ASC NULLS LAST, c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[false] +05)--------BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] +06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] 07)------------SortExec: expr=[c9@2 DESC, c1@0 DESC], preserve_partitioning=[false] 08)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c9], file_type=csv, has_header=true @@ -1553,17 +1657,17 @@ physical_plan 02)--GlobalLimitExec: skip=0, fetch=5 03)----WindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)), is_causal: false }] 04)------ProjectionExec: expr=[c1@0 as c1, c3@2 as c3, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@4 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@6 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@7 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@8 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@10 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@11 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@12 as sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@13 as sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@14 as sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@15 as sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@17 as sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@18 as sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -05)--------BoundedWindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 06)----------SortExec: expr=[c3@2 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[false] -07)------------BoundedWindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +07)------------BoundedWindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 08)--------------SortExec: expr=[c3@2 ASC NULLS LAST, c1@0 ASC], preserve_partitioning=[false] -09)----------------BoundedWindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +09)----------------BoundedWindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 10)------------------SortExec: expr=[c3@2 ASC NULLS LAST, c1@0 DESC], preserve_partitioning=[false] 11)--------------------WindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }] 12)----------------------WindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)), is_causal: false }] 13)------------------------SortExec: expr=[c3@2 DESC NULLS LAST], preserve_partitioning=[false] 14)--------------------------WindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)), is_causal: false }] -15)----------------------------BoundedWindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +15)----------------------------BoundedWindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 16)------------------------------SortExec: expr=[c3@2 DESC, c1@0 ASC NULLS LAST], preserve_partitioning=[false] 17)--------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/null_cases.csv]]}, projection=[c1, c2, c3], file_type=csv, has_header=true @@ -1637,9 +1741,9 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c9@1 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c1@0 ASC NULLS LAST, c9@1 DESC], preserve_partitioning=[false] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] +05)--------SortExec: TopK(fetch=10), expr=[c1@0 ASC NULLS LAST, c9@1 DESC], preserve_partitioning=[false] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c9], file_type=csv, has_header=true @@ -1681,9 +1785,9 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c9@1 as c9, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c1@0 ASC NULLS LAST, c9@1 DESC], preserve_partitioning=[false] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 5 PRECEDING AND 1 FOLLOWING], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] +05)--------SortExec: TopK(fetch=10), expr=[c1@0 ASC NULLS LAST, c9@1 DESC], preserve_partitioning=[false] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c9], file_type=csv, has_header=true query III @@ -1729,7 +1833,7 @@ physical_plan 02)--GlobalLimitExec: skip=0, fetch=5 03)----WindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int16(NULL)), is_causal: false }] 04)------ProjectionExec: expr=[__common_expr_1@0 as __common_expr_1, c3@2 as c3, c9@3 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 06)----------SortPreservingMergeExec: [__common_expr_1@0 DESC, c9@3 DESC, c2@1 ASC NULLS LAST] 07)------------SortExec: expr=[__common_expr_1@0 DESC, c9@3 DESC, c2@1 ASC NULLS LAST], preserve_partitioning=[true] 08)--------------ProjectionExec: expr=[c3@1 + c4@2 as __common_expr_1, c2@0 as c2, c3@1 as c3, c9@3 as c9] @@ -1767,11 +1871,11 @@ logical_plan 01)Projection: count(Int64(1)) AS count(*) AS global_count 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----SubqueryAlias: a -04)------Projection: +04)------Projection: 05)--------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] 06)----------Projection: aggregate_test_100.c1 -07)------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") -08)--------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] +07)------------Filter: aggregate_test_100.c13 != Utf8View("C2GT5KVyOPZpgKVl110TyZO0NcJ434") +08)--------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8View("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as global_count] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] @@ -1822,13 +1926,13 @@ logical_plan physical_plan 01)SortPreservingMergeExec: [c3@0 ASC NULLS LAST], fetch=5 02)--ProjectionExec: expr=[c3@0 as c3, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum2] -03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 04)------SortExec: expr=[c3@0 ASC NULLS LAST, c9@1 DESC], preserve_partitioning=[true] 05)--------CoalesceBatchesExec: target_batch_size=4096 06)----------RepartitionExec: partitioning=Hash([c3@0], 2), input_partitions=2 07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 08)--------------ProjectionExec: expr=[c3@1 as c3, c9@2 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -09)----------------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +09)----------------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 10)------------------SortExec: expr=[c3@1 DESC, c9@2 DESC, c2@0 ASC NULLS LAST], preserve_partitioning=[false] 11)--------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3, c9], file_type=csv, has_header=true @@ -1864,7 +1968,7 @@ logical_plan physical_plan 01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST] 02)--ProjectionExec: expr=[c1@0 as c1, row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as rn1] -03)----BoundedWindowAggExec: wdw=[row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] 04)------SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[true] 05)--------CoalesceBatchesExec: target_batch_size=4096 06)----------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 @@ -1993,7 +2097,7 @@ logical_plan physical_plan 01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST, rn1@1 ASC NULLS LAST] 02)--ProjectionExec: expr=[c1@0 as c1, row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as rn1] -03)----BoundedWindowAggExec: wdw=[row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] 04)------SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[true] 05)--------CoalesceBatchesExec: target_batch_size=4096 06)----------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 @@ -2019,10 +2123,10 @@ logical_plan physical_plan 01)SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[false] 02)--ProjectionExec: expr=[c1@0 as c1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING@2 as sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum2] -03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] 04)------SortPreservingMergeExec: [c9@1 ASC NULLS LAST] 05)--------SortExec: expr=[c9@1 ASC NULLS LAST], preserve_partitioning=[true] -06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] +06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING], mode=[Sorted] 07)------------SortExec: expr=[c1@0 ASC NULLS LAST, c9@1 ASC NULLS LAST], preserve_partitioning=[true] 08)--------------CoalesceBatchesExec: target_batch_size=4096 09)----------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 @@ -2107,10 +2211,10 @@ logical_plan physical_plan 01)SortExec: TopK(fetch=5), expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false] 02)--ProjectionExec: expr=[c9@2 as c9, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sum1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as sum2, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum3, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as sum4] -03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] 04)------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c9@3 as c9, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@4 as sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@6 as sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING] 05)--------WindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)), is_causal: false }] -06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] 07)------------WindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)), is_causal: false }] 08)--------------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, c9@3 ASC NULLS LAST, c8@2 ASC NULLS LAST], preserve_partitioning=[false] 09)----------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c8, c9], file_type=csv, has_header=true @@ -2162,11 +2266,11 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c9@1 as c9, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sum1, sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as sum2, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum3, sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as sum4] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] 04)------ProjectionExec: expr=[c2@0 as c2, c9@2 as c9, c1_alias@3 as c1_alias, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@4 as sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@6 as sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING] 05)--------WindowAggExec: wdw=[sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)), is_causal: false }] 06)----------ProjectionExec: expr=[c2@1 as c2, c8@2 as c8, c9@3 as c9, c1_alias@4 as c1_alias, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING] -07)------------BoundedWindowAggExec: wdw=[sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +07)------------BoundedWindowAggExec: wdw=[sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING], mode=[Sorted] 08)--------------WindowAggExec: wdw=[sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)), is_causal: false }] 09)----------------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, c9@3 ASC NULLS LAST, c8@2 ASC NULLS LAST], preserve_partitioning=[false] 10)------------------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c8@2 as c8, c9@3 as c9, c1@0 as c1_alias] @@ -2208,9 +2312,9 @@ physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2] 02)--SortExec: TopK(fetch=5), expr=[c9@2 ASC NULLS LAST], preserve_partitioning=[false] 03)----ProjectionExec: expr=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum1, sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING@4 as sum2, c9@1 as c9] -04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)), is_causal: true }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING: Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING], mode=[Sorted] 05)--------ProjectionExec: expr=[c1@0 as c1, c9@2 as c9, c12@3 as c12, sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING] -06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING], mode=[Sorted] 07)------------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[false] 08)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c9, c12], file_type=csv, has_header=true @@ -2244,7 +2348,7 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 04)------SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], file_type=csv, has_header=true @@ -2281,7 +2385,7 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 04)------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], file_type=csv, has_header=true @@ -2318,7 +2422,7 @@ logical_plan physical_plan 01)SortExec: TopK(fetch=5), expr=[rn1@1 DESC], preserve_partitioning=[false] 02)--ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] -03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 04)------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], file_type=csv, has_header=true @@ -2356,9 +2460,9 @@ logical_plan 03)----WindowAggr: windowExpr=[[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 04)------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)SortExec: TopK(fetch=5), expr=[rn1@1 ASC NULLS LAST, c9@0 ASC NULLS LAST], preserve_partitioning=[false] +01)SortExec: TopK(fetch=5), expr=[rn1@1 ASC NULLS LAST, c9@0 ASC NULLS LAST], preserve_partitioning=[false], sort_prefix=[rn1@1 ASC NULLS LAST] 02)--ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] -03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 04)------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], file_type=csv, has_header=true @@ -2433,7 +2537,7 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 04)------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], file_type=csv, has_header=true @@ -2455,7 +2559,7 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c5@0 as c5, c9@1 as c9, row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rn1] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Decimal128(None,21,0)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 04)------SortExec: expr=[CAST(c9@1 AS Decimal128(20, 0)) + CAST(c5@0 AS Decimal128(20, 0)) DESC], preserve_partitioning=[false] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c5, c9], file_type=csv, has_header=true @@ -2476,7 +2580,7 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c9@0 as c9, CAST(row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 AS Int64) as rn1] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 04)------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], file_type=csv, has_header=true @@ -2498,7 +2602,7 @@ SELECT FROM aggregate_test_100; statement ok -set datafusion.optimizer.skip_failed_rules = true +set datafusion.optimizer.skip_failed_rules = false # Error is returned from the logical plan. query error Cannot cast Utf8\("1 DAY"\) to Int8 @@ -2581,10 +2685,10 @@ physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, sum3@2 as sum3, min1@3 as min1, min2@4 as min2, min3@5 as min3, max1@6 as max1, max2@7 as max2, max3@8 as max3, cnt1@9 as cnt1, cnt2@10 as cnt2, sumr1@11 as sumr1, sumr2@12 as sumr2, sumr3@13 as sumr3, minr1@14 as minr1, minr2@15 as minr2, minr3@16 as minr3, maxr1@17 as maxr1, maxr2@18 as maxr2, maxr3@19 as maxr3, cntr1@20 as cntr1, cntr2@21 as cntr2, sum4@22 as sum4, cnt3@23 as cnt3] 02)--SortExec: TopK(fetch=5), expr=[inc_col@24 DESC], preserve_partitioning=[false] 03)----ProjectionExec: expr=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as sum1, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@14 as sum2, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@15 as sum3, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as min1, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as min2, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as min3, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as max1, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as max2, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as max3, count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@22 as cnt1, count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@23 as cnt2, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@2 as sumr1, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@3 as sumr2, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sumr3, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as minr1, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@6 as minr2, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@7 as minr3, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as maxr1, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as maxr2, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as maxr3, count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@11 as cntr1, count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@12 as cntr2, sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@24 as sum4, count(Int64(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@25 as cnt3, inc_col@1 as inc_col] -04)------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }, count(Int64(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(Int64(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, count(Int64(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Field { name: "count(Int64(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING], mode=[Sorted] 05)--------ProjectionExec: expr=[__common_expr_1@0 as __common_expr_1, inc_col@3 as inc_col, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@5 as sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@6 as sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@7 as sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@12 as max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@13 as max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@14 as count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@15 as count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@22 as max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@23 as max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@25 as count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@26 as count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING] -06)----------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING: Ok(Field { name: "count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(8)), is_causal: false }, count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -07)------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(8)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)), is_causal: false }, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)), is_causal: false }, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { name: "count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(2)), end_bound: Following(Int32(6)), is_causal: false }, count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(8)), is_causal: false }], mode=[Sorted] +06)----------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Field { name: "min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Field { name: "max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING: Field { name: "count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Field { name: "count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING], mode=[Sorted] +07)------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING: Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 4 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING: Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 8 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 5 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Field { name: "min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 1 PRECEDING AND 5 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Field { name: "max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 1 PRECEDING AND 5 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Field { name: "count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 2 PRECEDING AND 6 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Field { name: "count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 8 FOLLOWING], mode=[Sorted] 08)--------------ProjectionExec: expr=[CAST(desc_col@2 AS Int64) as __common_expr_1, CAST(inc_col@1 AS Int64) as __common_expr_2, ts@0 as ts, inc_col@1 as inc_col, desc_col@2 as desc_col] 09)----------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col, desc_col], output_ordering=[ts@0 ASC NULLS LAST], file_type=csv, has_header=true @@ -2667,8 +2771,8 @@ logical_plan physical_plan 01)SortExec: TopK(fetch=5), expr=[ts@0 DESC], preserve_partitioning=[false] 02)--ProjectionExec: expr=[ts@0 as ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, nth_value(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, nth_value(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] -03)----BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, nth_value(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "nth_value(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, nth_value(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "nth_value(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, nth_value(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "nth_value(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, nth_value(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "nth_value(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Field { name: "lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Field { name: "lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Field { name: "lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Field { name: "lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Field { name: "lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING], mode=[Sorted] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], file_type=csv, has_header=true query IIIIIIIIIIIIIIIIIIIIIIIII @@ -2739,8 +2843,8 @@ physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, min1@2 as min1, min2@3 as min2, max1@4 as max1, max2@5 as max2, count1@6 as count1, count2@7 as count2, avg1@8 as avg1, avg2@9 as avg2] 02)--SortExec: TopK(fetch=5), expr=[inc_col@10 ASC NULLS LAST], preserve_partitioning=[false] 03)----ProjectionExec: expr=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@9 as sum1, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as sum2, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@10 as min1, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@5 as min2, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@11 as max1, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@6 as max2, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@12 as count1, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@7 as count2, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@13 as avg1, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@8 as avg2, inc_col@3 as inc_col] -04)------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }], mode=[Sorted] -05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Field { name: "count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Field { name: "avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING], mode=[Sorted] 06)----------ProjectionExec: expr=[CAST(inc_col@1 AS Int64) as __common_expr_1, CAST(inc_col@1 AS Float64) as __common_expr_2, ts@0 as ts, inc_col@1 as inc_col] 07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], file_type=csv, has_header=true @@ -2791,8 +2895,8 @@ physical_plan 01)ProjectionExec: expr=[first_value1@0 as first_value1, first_value2@1 as first_value2, last_value1@2 as last_value1, last_value2@3 as last_value2, nth_value1@4 as nth_value1] 02)--SortExec: TopK(fetch=5), expr=[inc_col@5 ASC NULLS LAST], preserve_partitioning=[false] 03)----ProjectionExec: expr=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as first_value1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as first_value2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as last_value1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as last_value2, nth_value(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as nth_value1, inc_col@1 as inc_col] -04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, nth_value(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "nth_value(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -05)--------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, nth_value(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Field { name: "nth_value(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING], mode=[Sorted] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], file_type=csv, has_header=true query IIIII @@ -2835,8 +2939,8 @@ logical_plan physical_plan 01)ProjectionExec: expr=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as sum1, sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum2, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as count1, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as count2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING], mode=[Sorted] 05)--------ProjectionExec: expr=[CAST(inc_col@1 AS Int64) as __common_expr_1, ts@0 as ts, inc_col@1 as inc_col] 06)----------StreamingTableExec: partition_sizes=1, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST] @@ -2880,8 +2984,8 @@ logical_plan physical_plan 01)ProjectionExec: expr=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as sum1, sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum2, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as count1, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as count2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND 3 FOLLOWING], mode=[Sorted] 05)--------ProjectionExec: expr=[CAST(inc_col@1 AS Int64) as __common_expr_1, ts@0 as ts, inc_col@1 as inc_col] 06)----------StreamingTableExec: partition_sizes=1, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST] @@ -2980,12 +3084,12 @@ logical_plan physical_plan 01)ProjectionExec: expr=[a@1 as a, b@2 as b, c@3 as c, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@9 as sum1, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING@10 as sum2, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@15 as sum3, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING@16 as sum4, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@5 as sum5, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@6 as sum6, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@11 as sum7, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@12 as sum8, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@7 as sum9, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW@8 as sum10, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@13 as sum11, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING@14 as sum12] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(1)), is_causal: true }], mode=[Linear] -04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[PartiallySorted([1, 0])] -05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -06)----------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Following(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[PartiallySorted([0])] -07)------------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: CurrentRow, is_causal: true }], mode=[PartiallySorted([0, 1])] -08)--------------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING: Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING], mode=[Linear] +04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING: Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING], mode=[PartiallySorted([1, 0])] +05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING], mode=[Sorted] +06)----------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING: Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING], mode=[PartiallySorted([0])] +07)------------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW: Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 5 PRECEDING AND CURRENT ROW], mode=[PartiallySorted([0, 1])] +08)--------------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING], mode=[Sorted] 09)----------------ProjectionExec: expr=[CAST(c@2 AS Int64) as __common_expr_1, a@0 as a, b@1 as b, c@2 as c, d@3 as d] 10)------------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] @@ -3048,17 +3152,17 @@ logical_plan physical_plan 01)SortExec: TopK(fetch=5), expr=[c@2 ASC NULLS LAST], preserve_partitioning=[false] 02)--ProjectionExec: expr=[a@1 as a, b@2 as b, c@3 as c, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@9 as sum1, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING@10 as sum2, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@15 as sum3, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING@16 as sum4, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@5 as sum5, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@6 as sum6, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@11 as sum7, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@12 as sum8, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@7 as sum9, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW@8 as sum10, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@13 as sum11, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING@14 as sum12] -03)----BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(1)), is_causal: true }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING: Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING], mode=[Sorted] 04)------SortExec: expr=[d@4 ASC NULLS LAST, a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], preserve_partitioning=[false] -05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING: Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING], mode=[Sorted] 06)----------SortExec: expr=[b@2 ASC NULLS LAST, a@1 ASC NULLS LAST, d@4 ASC NULLS LAST, c@3 ASC NULLS LAST], preserve_partitioning=[false] -07)------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +07)------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING], mode=[Sorted] 08)--------------SortExec: expr=[b@2 ASC NULLS LAST, a@1 ASC NULLS LAST, c@3 ASC NULLS LAST], preserve_partitioning=[false] -09)----------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Following(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +09)----------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING: Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING], mode=[Sorted] 10)------------------SortExec: expr=[a@1 ASC NULLS LAST, d@4 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], preserve_partitioning=[false] -11)--------------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted] +11)--------------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW: Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 5 PRECEDING AND CURRENT ROW], mode=[Sorted] 12)----------------------SortExec: expr=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, d@4 ASC NULLS LAST, c@3 ASC NULLS LAST], preserve_partitioning=[false] -13)------------------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +13)------------------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING], mode=[Sorted] 14)--------------------------ProjectionExec: expr=[CAST(c@2 AS Int64) as __common_expr_1, a@0 as a, b@1 as b, c@2 as c, d@3 as d] 15)----------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], file_type=csv, has_header=true @@ -3122,7 +3226,7 @@ physical_plan 01)ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as rn1] 02)--CoalesceBatchesExec: target_batch_size=4096, fetch=5 03)----FilterExec: row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 < 50 -04)------BoundedWindowAggExec: wdw=[row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 05)--------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] # Top level sort is pushed down through BoundedWindowAggExec as its SUM result does already satisfy the required @@ -3144,7 +3248,7 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c9@0 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum1] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 04)------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], file_type=csv, has_header=true @@ -3229,11 +3333,11 @@ logical_plan 08)--------------TableScan: annotated_data_infinite2 projection=[a, b, c, d] physical_plan 01)ProjectionExec: expr=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum1, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum2, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum3, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as sum4] -02)--BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear] +02)--BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Linear] 03)----ProjectionExec: expr=[__common_expr_1@0 as __common_expr_1, a@1 as a, d@4 as d, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@7 as sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[PartiallySorted([0])] -06)----------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[PartiallySorted([0])] +06)----------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 07)------------ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, a@0 as a, b@1 as b, c@2 as c, d@3 as d] 08)--------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] @@ -3260,17 +3364,17 @@ logical_plan 08)--------------TableScan: annotated_data_infinite2 projection=[a, b, c, d] physical_plan 01)ProjectionExec: expr=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum1, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum2, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum3, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as sum4] -02)--BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear] +02)--BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Linear] 03)----CoalesceBatchesExec: target_batch_size=4096 04)------RepartitionExec: partitioning=Hash([d@2], 2), input_partitions=2, preserve_order=true, sort_exprs=__common_expr_1@0 ASC NULLS LAST, a@1 ASC NULLS LAST 05)--------ProjectionExec: expr=[__common_expr_1@0 as __common_expr_1, a@1 as a, d@4 as d, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@7 as sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -06)----------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +06)----------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 07)------------CoalesceBatchesExec: target_batch_size=4096 08)--------------RepartitionExec: partitioning=Hash([b@2, a@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST, __common_expr_1@0 ASC NULLS LAST -09)----------------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[PartiallySorted([0])] +09)----------------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[PartiallySorted([0])] 10)------------------CoalesceBatchesExec: target_batch_size=4096 11)--------------------RepartitionExec: partitioning=Hash([a@1, d@4], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST, __common_expr_1@0 ASC NULLS LAST -12)----------------------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +12)----------------------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 13)------------------------CoalesceBatchesExec: target_batch_size=4096 14)--------------------------RepartitionExec: partitioning=Hash([a@1, b@2], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST, __common_expr_1@0 ASC NULLS LAST 15)----------------------------ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, a@0 as a, b@1 as b, c@2 as c, d@3 as d] @@ -3329,7 +3433,7 @@ logical_plan physical_plan 01)SortExec: TopK(fetch=5), expr=[c3@0 ASC NULLS LAST], preserve_partitioning=[false] 02)--ProjectionExec: expr=[c3@0 as c3, max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as min1, min(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@2 as max1] -03)----BoundedWindowAggExec: wdw=[max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 04)------SortExec: expr=[c12@1 ASC NULLS LAST], preserve_partitioning=[false] 05)--------ProjectionExec: expr=[c3@0 as c3, c12@2 as c12, min(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as min(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING] 06)----------WindowAggExec: wdw=[min(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "min(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] @@ -3373,7 +3477,7 @@ physical_plan 01)ProjectionExec: expr=[min1@0 as min1, max1@1 as max1] 02)--SortExec: TopK(fetch=5), expr=[c3@2 ASC NULLS LAST], preserve_partitioning=[false] 03)----ProjectionExec: expr=[max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as min1, min(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as max1, c3@0 as c3] -04)------BoundedWindowAggExec: wdw=[max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow, is_causal: false }, min(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "min(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "min(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 05)--------SortExec: expr=[c12@1 ASC NULLS LAST], preserve_partitioning=[false] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3, c12], file_type=csv, has_header=true @@ -3425,7 +3529,7 @@ logical_plan 02)--Filter: multiple_ordered_table.b = Int32(0) 03)----TableScan: multiple_ordered_table projection=[a0, a, b, c, d], partial_filters=[multiple_ordered_table.b = Int32(0)] physical_plan -01)BoundedWindowAggExec: wdw=[sum(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +01)BoundedWindowAggExec: wdw=[sum(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 02)--CoalesceBatchesExec: target_batch_size=4096 03)----FilterExec: b@2 = 0 04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_orderings=[[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST], [c@3 ASC NULLS LAST]], file_type=csv, has_header=true @@ -3443,7 +3547,7 @@ logical_plan 02)--Filter: multiple_ordered_table.b = Int32(0) 03)----TableScan: multiple_ordered_table projection=[a0, a, b, c, d], partial_filters=[multiple_ordered_table.b = Int32(0)] physical_plan -01)BoundedWindowAggExec: wdw=[sum(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST, multiple_ordered_table.d ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST, multiple_ordered_table.d ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +01)BoundedWindowAggExec: wdw=[sum(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST, multiple_ordered_table.d ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST, multiple_ordered_table.d ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 02)--SortExec: expr=[d@4 ASC NULLS LAST], preserve_partitioning=[false] 03)----CoalesceBatchesExec: target_batch_size=4096 04)------FilterExec: b@2 = 0 @@ -3480,9 +3584,9 @@ logical_plan 05)--------TableScan: multiple_ordered_table projection=[a, b, c, d] physical_plan 01)ProjectionExec: expr=[min(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as min1, max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as max1] -02)--BoundedWindowAggExec: wdw=[min(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "min(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +02)--BoundedWindowAggExec: wdw=[min(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "min(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 03)----ProjectionExec: expr=[c@2 as c, d@3 as d, max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -04)------BoundedWindowAggExec: wdw=[max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], file_type=csv, has_header=true query TT @@ -3499,7 +3603,7 @@ logical_plan 04)------TableScan: multiple_ordered_table projection=[c, d], partial_filters=[multiple_ordered_table.d = Int32(0)] physical_plan 01)ProjectionExec: expr=[max(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as max_c] -02)--BoundedWindowAggExec: wdw=[max(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "max(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +02)--BoundedWindowAggExec: wdw=[max(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "max(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 03)----CoalesceBatchesExec: target_batch_size=4096 04)------FilterExec: d@1 = 0 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], file_type=csv, has_header=true @@ -3514,7 +3618,7 @@ logical_plan 03)----TableScan: multiple_ordered_table projection=[a, c, d] physical_plan 01)ProjectionExec: expr=[sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -02)--BoundedWindowAggExec: wdw=[sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +02)--BoundedWindowAggExec: wdw=[sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], file_type=csv, has_header=true query TT @@ -3527,7 +3631,7 @@ logical_plan 03)----TableScan: multiple_ordered_table projection=[a, b, c, d] physical_plan 01)ProjectionExec: expr=[sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -02)--BoundedWindowAggExec: wdw=[sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +02)--BoundedWindowAggExec: wdw=[sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], file_type=csv, has_header=true query I @@ -3620,7 +3724,7 @@ logical_plan physical_plan 01)SortPreservingMergeExec: [c@3 ASC NULLS LAST] 02)--ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, avg(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW@5 as avg_d] -03)----BoundedWindowAggExec: wdw=[avg(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW: Ok(Field { name: "avg(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: CurrentRow, is_causal: false }], mode=[Linear] +03)----BoundedWindowAggExec: wdw=[avg(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW: Field { name: "avg(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN 10 PRECEDING AND CURRENT ROW], mode=[Linear] 04)------CoalesceBatchesExec: target_batch_size=4096 05)--------RepartitionExec: partitioning=Hash([d@4], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST 06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -3955,7 +4059,7 @@ logical_plan 03)----TableScan: table_with_pk projection=[sn, ts, currency, amount] physical_plan 01)ProjectionExec: expr=[sn@0 as sn, ts@1 as ts, currency@2 as currency, amount@3 as amount, sum(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum1] -02)--BoundedWindowAggExec: wdw=[sum(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted] +02)--BoundedWindowAggExec: wdw=[sum(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 03)----SortExec: expr=[sn@0 ASC NULLS LAST], preserve_partitioning=[false] 04)------DataSourceExec: partitions=1, partition_sizes=[1] @@ -4076,7 +4180,7 @@ physical_plan 02)--GlobalLimitExec: skip=0, fetch=5 03)----WindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int16(NULL)), is_causal: false }] 04)------ProjectionExec: expr=[c3@0 as c3, c4@1 as c4, c9@2 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum1] -05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 06)----------SortExec: expr=[c3@0 + c4@1 DESC], preserve_partitioning=[false] 07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3, c4, c9], file_type=csv, has_header=true @@ -4115,7 +4219,7 @@ logical_plan 04)------TableScan: a projection=[a] physical_plan 01)ProjectionExec: expr=[count(Int64(1)) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as count(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -02)--BoundedWindowAggExec: wdw=[count(Int64(1)) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "count(Int64(1)) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +02)--BoundedWindowAggExec: wdw=[count(Int64(1)) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "count(Int64(1)) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 03)----CoalesceBatchesExec: target_batch_size=4096 04)------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -4138,7 +4242,7 @@ logical_plan 04)------TableScan: a projection=[a] physical_plan 01)ProjectionExec: expr=[row_number() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as row_number() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING] -02)--BoundedWindowAggExec: wdw=[row_number() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +02)--BoundedWindowAggExec: wdw=[row_number() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Field { name: "row_number() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], mode=[Sorted] 03)----CoalesceBatchesExec: target_batch_size=4096 04)------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -4341,6 +4445,9 @@ LIMIT 5; 24 31 14 94 +statement ok +set datafusion.execution.batch_size = 100; + # Tests schema and data are in sync for mixed nulls and not nulls values for builtin window function query T select lag(a) over (order by a ASC NULLS FIRST) as x1 @@ -4938,11 +5045,11 @@ FROM (SELECT c1, c2, ROW_NUMBER() OVER() as rn FROM t LIMIT 5) GROUP BY rn -ORDER BY rn; +ORDER BY 1, 2, 3 ---- 1 a 1 -2 b 2 1 a 3 +2 b 2 3 NULL 4 NULL a4 5 @@ -5181,6 +5288,10 @@ order by c1; 3 1 1 3 10 2 + +statement ok +set datafusion.execution.batch_size = 1; + # push filter since it uses a partition column query TT explain select c1, c2, rank @@ -5200,7 +5311,7 @@ logical_plan physical_plan 01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, rank@2 ASC NULLS LAST] 02)--ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rank] -03)----BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 04)------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[true] 05)--------CoalesceBatchesExec: target_batch_size=1 06)----------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 @@ -5244,7 +5355,7 @@ physical_plan 02)--ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rank] 03)----CoalesceBatchesExec: target_batch_size=1 04)------FilterExec: c2@1 >= 10 -05)--------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 06)----------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=1 08)--------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 @@ -5286,7 +5397,7 @@ physical_plan 02)--ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rank] 03)----CoalesceBatchesExec: target_batch_size=1 04)------FilterExec: c2@1 = 10 -05)--------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 06)----------SortExec: expr=[c2@1 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=1 08)--------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 @@ -5327,7 +5438,7 @@ physical_plan 02)--ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rank] 03)----CoalesceBatchesExec: target_batch_size=1 04)------FilterExec: c1@0 = 1 OR c2@1 = 10 -05)--------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 06)----------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=1 08)--------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 @@ -5370,11 +5481,11 @@ physical_plan 01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, rank1@2 ASC NULLS LAST, rank2@3 ASC NULLS LAST] 02)--SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, rank1@2 ASC NULLS LAST, rank2@3 ASC NULLS LAST], preserve_partitioning=[true] 03)----ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rank1, rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as rank2] -04)------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 05)--------SortExec: expr=[c2@1 ASC NULLS LAST, c1@0 ASC NULLS LAST], preserve_partitioning=[true] 06)----------CoalesceBatchesExec: target_batch_size=1 07)------------RepartitionExec: partitioning=Hash([c2@1, c1@0], 2), input_partitions=2 -08)--------------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +08)--------------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 09)----------------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[true] 10)------------------CoalesceBatchesExec: target_batch_size=1 11)--------------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 @@ -5421,13 +5532,13 @@ physical_plan 01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, rank1@2 ASC NULLS LAST, rank2@3 ASC NULLS LAST] 02)--SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, rank1@2 ASC NULLS LAST, rank2@3 ASC NULLS LAST], preserve_partitioning=[true] 03)----ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rank1, rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as rank2] -04)------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "rank() PARTITION BY [t1.c2, t1.c1] ORDER BY [t1.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 05)--------SortExec: expr=[c2@1 ASC NULLS LAST, c1@0 ASC NULLS LAST], preserve_partitioning=[true] 06)----------CoalesceBatchesExec: target_batch_size=1 07)------------RepartitionExec: partitioning=Hash([c2@1, c1@0], 2), input_partitions=2 08)--------------CoalesceBatchesExec: target_batch_size=1 09)----------------FilterExec: c2@1 > 1 -10)------------------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +10)------------------BoundedWindowAggExec: wdw=[rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "rank() PARTITION BY [t1.c1] ORDER BY [t1.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] 11)--------------------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[true] 12)----------------------CoalesceBatchesExec: target_batch_size=1 13)------------------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 @@ -5507,6 +5618,7 @@ physical_plan 02)--WindowAggExec: wdw=[sum(aggregate_test_100_ordered.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(aggregate_test_100_ordered.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], file_type=csv, has_header=true + query TT EXPLAIN SELECT c1, MIN(c5) OVER(PARTITION BY c1) as min_c5 FROM aggregate_test_100_ordered ORDER BY c1, min_c5 DESC NULLS LAST; ---- @@ -5537,6 +5649,21 @@ physical_plan 02)--WindowAggExec: wdw=[max(aggregate_test_100_ordered.c5) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "max(aggregate_test_100_ordered.c5) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c5], file_type=csv, has_header=true +query II rowsort +SELECT + t1.v1, + SUM(t1.v1) OVER w + 1 +FROM + generate_series(1, 5) AS t1(v1) +WINDOW + w AS (ORDER BY t1.v1); +---- +1 2 +2 4 +3 7 +4 11 +5 16 + # Testing Utf8View with window statement ok CREATE TABLE aggregate_test_100_utf8view AS SELECT @@ -5595,3 +5722,423 @@ DROP TABLE aggregate_test_100_utf8view; statement ok DROP TABLE aggregate_test_100 + +# window definitions with aliases +query II rowsort +SELECT + t1.v1, + SUM(t1.v1) OVER W + 1 +FROM + generate_series(1, 5) AS t1(v1) +WINDOW + w AS (ORDER BY t1.v1); +---- +1 2 +2 4 +3 7 +4 11 +5 16 + +# window definitions with aliases +query II rowsort +SELECT + t1.v1, + SUM(t1.v1) OVER w + 1 +FROM + generate_series(1, 5) AS t1(v1) +WINDOW + W AS (ORDER BY t1.v1); +---- +1 2 +2 4 +3 7 +4 11 +5 16 + + +# window with distinct operation +statement ok +CREATE TABLE table_test_distinct_count ( + k VARCHAR, + v Int, + time TIMESTAMP WITH TIME ZONE +); + +statement ok +INSERT INTO table_test_distinct_count (k, v, time) VALUES + ('a', 1, '1970-01-01T00:01:00.00Z'), + ('a', 1, '1970-01-01T00:02:00.00Z'), + ('a', 1, '1970-01-01T00:03:00.00Z'), + ('a', 2, '1970-01-01T00:03:00.00Z'), + ('a', 1, '1970-01-01T00:04:00.00Z'), + ('b', 3, '1970-01-01T00:01:00.00Z'), + ('b', 3, '1970-01-01T00:02:00.00Z'), + ('b', 4, '1970-01-01T00:03:00.00Z'), + ('b', 4, '1970-01-01T00:03:00.00Z'); + +query TPII +SELECT + k, + time, + COUNT(v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS normal_count, + COUNT(DISTINCT v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS distinct_count +FROM table_test_distinct_count +ORDER BY k, time; +---- +a 1970-01-01T00:01:00Z 1 1 +a 1970-01-01T00:02:00Z 2 1 +a 1970-01-01T00:03:00Z 4 2 +a 1970-01-01T00:03:00Z 4 2 +a 1970-01-01T00:04:00Z 4 2 +b 1970-01-01T00:01:00Z 1 1 +b 1970-01-01T00:02:00Z 2 1 +b 1970-01-01T00:03:00Z 4 2 +b 1970-01-01T00:03:00Z 4 2 + + +query TT +EXPLAIN SELECT + k, + time, + COUNT(v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS normal_count, + COUNT(DISTINCT v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS distinct_count +FROM table_test_distinct_count +ORDER BY k, time; +---- +logical_plan +01)Sort: table_test_distinct_count.k ASC NULLS LAST, table_test_distinct_count.time ASC NULLS LAST +02)--Projection: table_test_distinct_count.k, table_test_distinct_count.time, count(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS normal_count, count(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS distinct_count +03)----WindowAggr: windowExpr=[[count(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS count(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW, count(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS count(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW]] +04)------TableScan: table_test_distinct_count projection=[k, v, time] +physical_plan +01)SortPreservingMergeExec: [k@0 ASC NULLS LAST, time@1 ASC NULLS LAST] +02)--ProjectionExec: expr=[k@0 as k, time@2 as time, count(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@3 as normal_count, count(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@4 as distinct_count] +03)----BoundedWindowAggExec: wdw=[count(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "count(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW, count(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "count(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------SortExec: expr=[k@0 ASC NULLS LAST, time@2 ASC NULLS LAST], preserve_partitioning=[true] +05)--------CoalesceBatchesExec: target_batch_size=1 +06)----------RepartitionExec: partitioning=Hash([k@0], 2), input_partitions=2 +07)------------DataSourceExec: partitions=2, partition_sizes=[5, 4] + + +# Add testing for distinct sum +query TPII +SELECT + k, + time, + SUM(v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS sum_v, + SUM(DISTINCT v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS sum_distinct_v +FROM table_test_distinct_count +ORDER BY k, time; +---- +a 1970-01-01T00:01:00Z 1 1 +a 1970-01-01T00:02:00Z 2 1 +a 1970-01-01T00:03:00Z 5 3 +a 1970-01-01T00:03:00Z 5 3 +a 1970-01-01T00:04:00Z 5 3 +b 1970-01-01T00:01:00Z 3 3 +b 1970-01-01T00:02:00Z 6 3 +b 1970-01-01T00:03:00Z 14 7 +b 1970-01-01T00:03:00Z 14 7 + + + +query TT +EXPLAIN SELECT + k, + time, + SUM(v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS sum_v, + SUM(DISTINCT v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS sum_distinct_v +FROM table_test_distinct_count +ORDER BY k, time; +---- +logical_plan +01)Sort: table_test_distinct_count.k ASC NULLS LAST, table_test_distinct_count.time ASC NULLS LAST +02)--Projection: table_test_distinct_count.k, table_test_distinct_count.time, sum(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS sum_v, sum(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS sum_distinct_v +03)----WindowAggr: windowExpr=[[sum(__common_expr_1) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS sum(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW, sum(DISTINCT __common_expr_1) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS sum(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW]] +04)------Projection: CAST(table_test_distinct_count.v AS Int64) AS __common_expr_1, table_test_distinct_count.k, table_test_distinct_count.time +05)--------TableScan: table_test_distinct_count projection=[k, v, time] +physical_plan +01)SortPreservingMergeExec: [k@0 ASC NULLS LAST, time@1 ASC NULLS LAST] +02)--ProjectionExec: expr=[k@1 as k, time@2 as time, sum(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@3 as sum_v, sum(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@4 as sum_distinct_v] +03)----BoundedWindowAggExec: wdw=[sum(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "sum(table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW, sum(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "sum(DISTINCT table_test_distinct_count.v) PARTITION BY [table_test_distinct_count.k] ORDER BY [table_test_distinct_count.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------SortExec: expr=[k@1 ASC NULLS LAST, time@2 ASC NULLS LAST], preserve_partitioning=[true] +05)--------CoalesceBatchesExec: target_batch_size=1 +06)----------RepartitionExec: partitioning=Hash([k@1], 2), input_partitions=2 +07)------------ProjectionExec: expr=[CAST(v@1 AS Int64) as __common_expr_1, k@0 as k, time@2 as time] +08)--------------DataSourceExec: partitions=2, partition_sizes=[5, 4] + + +# FILTER clause with window functions + +# Verify FILTER clause with non-aggregate window functions fails with a clear message +query error DataFusion error: Error during planning: FILTER clause can only be used with aggregate window functions\. Found in 'row_number\(\) FILTER \(WHERE test\.c1 > Int64\(0\)\) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING' +SELECT +c1, +ROW_NUMBER() FILTER(WHERE c1 > 0) OVER () as rn1 +FROM test +LIMIT 5 + + +query error DataFusion error: Error during planning: FILTER clause can only be used with aggregate window functions\. Found in 'first_value\(test\.c1\) FILTER \(WHERE test\.c1 > Int64\(0\)\) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING' +SELECT +c1, +FIRST_VALUE(c1) FILTER(WHERE c1 > 0) OVER () as rn1 +FROM test +LIMIT 5 + + +query error DataFusion error: Error during planning: FILTER clause can only be used with aggregate window functions\. Found in 'lag\(test\.c1\) FILTER \(WHERE test\.c1 > Int64\(0\)\) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING' +SELECT +c1, +LAG(c1) FILTER(WHERE c1 > 0) OVER () as rn1 +FROM test +LIMIT 5 + + +# Check error propagation from filter to window function +query error +SELECT +c1, +SUM(c2) FILTER (WHERE c2 >= []) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum1, +FROM test +LIMIT 5 +---- +DataFusion error: type_coercion +caused by +Error during planning: Cannot infer common argument type for comparison operation Int64 >= List(Field { name: "item", data_type: Null, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + + + +# EXPLAIN should display the filters +query TT +EXPLAIN SELECT +c1, +c2, +SUM(c2) FILTER (WHERE c2 >= 2) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum1, +SUM(c2) FILTER (WHERE c2 >= 2 AND c2 < 4 AND c1 > 0) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum2, +COUNT(c2) FILTER (WHERE c2 >= 2) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as count1, +ARRAY_AGG(c2) FILTER (WHERE c2 >= 2) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as array_agg1, +ARRAY_AGG(c2) FILTER (WHERE c2 >= 2 AND c2 < 4 AND c1 > 0) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as array_agg2, +FROM test +ORDER BY c1, c2 +LIMIT 5 +---- +logical_plan +01)Sort: test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST, fetch=5 +02)--Projection: test.c1, test.c2, sum(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, sum(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2, count(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS count1, array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS array_agg1, array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS array_agg2 +03)----WindowAggr: windowExpr=[[sum(test.c2) FILTER (WHERE __common_expr_1 AS test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(test.c2) FILTER (WHERE __common_expr_2) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, count(test.c2) FILTER (WHERE __common_expr_1 AS test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS count(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, array_agg(test.c2) FILTER (WHERE __common_expr_1 AS test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, array_agg(test.c2) FILTER (WHERE __common_expr_2) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------Projection: __common_expr_3 AS __common_expr_1, __common_expr_3 AND test.c2 < Int64(4) AND test.c1 > Int32(0) AS __common_expr_2, test.c1, test.c2 +05)--------Projection: test.c2 >= Int64(2) AS __common_expr_3, test.c1, test.c2 +06)----------TableScan: test projection=[c1, c2] +physical_plan +01)ProjectionExec: expr=[c1@2 as c1, c2@3 as c2, sum(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum1, sum(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum2, count(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as count1, array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@7 as array_agg1, array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as array_agg2] +02)--GlobalLimitExec: skip=0, fetch=5 +03)----BoundedWindowAggExec: wdw=[sum(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, count(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "count(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------SortPreservingMergeExec: [c1@2 ASC NULLS LAST, c2@3 ASC NULLS LAST], fetch=5 +05)--------SortExec: TopK(fetch=5), expr=[c1@2 ASC NULLS LAST, c2@3 ASC NULLS LAST], preserve_partitioning=[true] +06)----------ProjectionExec: expr=[__common_expr_3@0 as __common_expr_1, __common_expr_3@0 AND c2@2 < 4 AND c1@1 > 0 as __common_expr_2, c1@1 as c1, c2@2 as c2] +07)------------ProjectionExec: expr=[c2@1 >= 2 as __common_expr_3, c1@0 as c1, c2@1 as c2] +08)--------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_csv/partition-0.csv], [WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_csv/partition-1.csv], [WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_csv/partition-2.csv], [WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_csv/partition-3.csv]]}, projection=[c1, c2], file_type=csv, has_header=false + + +# FILTER filters out some rows +query IIIII?? +SELECT +c1, +c2, +SUM(c2) FILTER (WHERE c2 >= 2) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum1, +SUM(c2) FILTER (WHERE c2 >= 2 AND c2 < 4 AND c1 >= 0) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum2, +COUNT(c2) FILTER (WHERE c2 >= 2) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as count1, +ARRAY_AGG(c2) FILTER (WHERE c2 >= 2) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as array_agg1, +ARRAY_AGG(c2) FILTER (WHERE c2 >= 2 AND c2 < 4 AND c1 >= 0) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as array_agg2, +FROM test +ORDER BY c1, c2 +LIMIT 5 +---- +0 0 NULL NULL 0 NULL NULL +0 1 NULL NULL 0 NULL NULL +0 2 2 2 1 [2] [2] +0 3 5 5 2 [2, 3] [2, 3] +0 4 9 5 3 [2, 3, 4] [2, 3] + + +# FILTER filters out no rows +query IIIII?? +SELECT +c1, +c2, +SUM(c2) FILTER (WHERE c2 >= 0) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum1, +SUM(c2) FILTER (WHERE c2 >= 0 AND c2 < 1000 AND c1 >= 0) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum2, +COUNT(c2) FILTER (WHERE c2 >= 0) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as count1, +ARRAY_AGG(c2) FILTER (WHERE c2 >= 0) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as array_agg1, +ARRAY_AGG(c2) FILTER (WHERE c2 >= 0 AND c2 < 1000 AND c1 >= 0) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as array_agg2, +FROM test +ORDER BY c1, c2 +LIMIT 5 +---- +0 0 0 0 1 [0] [0] +0 1 1 1 2 [0, 1] [0, 1] +0 2 3 3 3 [0, 1, 2] [0, 1, 2] +0 3 6 6 4 [0, 1, 2, 3] [0, 1, 2, 3] +0 4 10 10 5 [0, 1, 2, 3, 4] [0, 1, 2, 3, 4] + + +# FILTER filters out every row +query IIIII?? +SELECT +c1, +c2, +SUM(c2) FILTER (WHERE c2 == -1) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum1, +SUM(c2) FILTER (WHERE c2 >= 0 AND c2 < 0 AND c1 >= 0) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum2, +COUNT(c2) FILTER (WHERE c2 == -1) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as count1, +ARRAY_AGG(c2) FILTER (WHERE c2 >= 1000) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as array_agg1, +ARRAY_AGG(c2) FILTER (WHERE c2 >= 0 AND c2 < 1000 AND c1 >= 1000) OVER (ORDER BY c1, c2 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as array_agg2, +FROM test +ORDER BY c1, c2 +LIMIT 5 +---- +0 0 NULL NULL 0 NULL NULL +0 1 NULL NULL 0 NULL NULL +0 2 NULL NULL 0 NULL NULL +0 3 NULL NULL 0 NULL NULL +0 4 NULL NULL 0 NULL NULL + +# regression test for https://github.com/apache/datafusion/issues/17401 +query I +WITH source AS ( + SELECT + 1 AS n, + '' AS a1, '' AS a2, '' AS a3, '' AS a4, '' AS a5, '' AS a6, '' AS a7, '' AS a8, + '' AS a9, '' AS a10, '' AS a11, '' AS a12 +) +SELECT + sum(n) OVER (PARTITION BY + a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12 + ) +FROM source; +---- +1 + +# regression test for https://github.com/apache/datafusion/issues/17401 +query I +WITH source AS ( + SELECT + 1 AS n, + '' AS a1, '' AS a2, '' AS a3, '' AS a4, '' AS a5, '' AS a6, '' AS a7, '' AS a8, + '' AS a9, '' AS a10, '' AS a11, '' AS a12, '' AS a13, '' AS a14, '' AS a15, '' AS a16, + '' AS a17, '' AS a18, '' AS a19, '' AS a20, '' AS a21, '' AS a22, '' AS a23, '' AS a24, + '' AS a25, '' AS a26, '' AS a27, '' AS a28, '' AS a29, '' AS a30, '' AS a31, '' AS a32, + '' AS a33, '' AS a34, '' AS a35, '' AS a36, '' AS a37, '' AS a38, '' AS a39, '' AS a40 +) +SELECT + sum(n) OVER (PARTITION BY + a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, + a21, a22, a23, a24, a25, a26, a27, a28, a29, a30, a31, a32, a33, a34, a35, a36, a37, a38, a39, a40 + ) +FROM source; +---- +1 + +# regression test for https://github.com/apache/datafusion/issues/17401 +query I +WITH source AS ( + SELECT + 1 AS n, + '' AS a1, '' AS a2, '' AS a3, '' AS a4, '' AS a5, '' AS a6, '' AS a7, '' AS a8, + '' AS a9, '' AS a10, '' AS a11, '' AS a12, '' AS a13, '' AS a14, '' AS a15, '' AS a16, + '' AS a17, '' AS a18, '' AS a19, '' AS a20, '' AS a21, '' AS a22, '' AS a23, '' AS a24, + '' AS a25, '' AS a26, '' AS a27, '' AS a28, '' AS a29, '' AS a30, '' AS a31, '' AS a32, + '' AS a33, '' AS a34, '' AS a35, '' AS a36, '' AS a37, '' AS a38, '' AS a39, '' AS a40 +) +SELECT + sum(n) OVER (PARTITION BY + a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, + a21, a22, a23, a24, a25, a26, a27, a28, a29, a30, a31, a32, a33, a34, a35, a36, a37, a38, a39, a40 + ) +FROM ( + SELECT * FROM source + ORDER BY a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, + a21, a22, a23, a24, a25, a26, a27, a28, a29, a30, a31, a32, a33, a34, a35, a36, a37, a38, a39, a40 +); +---- +1 + +# regression test for https://github.com/apache/datafusion/issues/17401 +query I +WITH source AS ( + SELECT + 1 AS n, + '' AS a1, '' AS a2, '' AS a3, '' AS a4, '' AS a5, '' AS a6, '' AS a7, '' AS a8, + '' AS a9, '' AS a10, '' AS a11, '' AS a12, '' AS a13, '' AS a14, '' AS a15, '' AS a16, + '' AS a17, '' AS a18, '' AS a19, '' AS a20, '' AS a21, '' AS a22, '' AS a23, '' AS a24, + '' AS a25, '' AS a26, '' AS a27, '' AS a28, '' AS a29, '' AS a30, '' AS a31, '' AS a32, + '' AS a33, '' AS a34, '' AS a35, '' AS a36, '' AS a37, '' AS a38, '' AS a39, '' AS a40 +) +SELECT + sum(n) OVER (PARTITION BY + a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, + a21, a22, a23, a24, a25, a26, a27, a28, a29, a30, a31, a32, a33, a34, a35, a36, a37, a38, a39, a40 + ) +FROM ( + SELECT * FROM source + WHERE a1 = '' AND a2 = '' AND a3 = '' AND a4 = '' AND a5 = '' AND a6 = '' AND a7 = '' AND a8 = '' + AND a9 = '' AND a10 = '' AND a11 = '' AND a12 = '' AND a13 = '' AND a14 = '' AND a15 = '' AND a16 = '' + AND a17 = '' AND a18 = '' AND a19 = '' AND a20 = '' AND a21 = '' AND a22 = '' AND a23 = '' AND a24 = '' + AND a25 = '' AND a26 = '' AND a27 = '' AND a28 = '' AND a29 = '' AND a30 = '' AND a31 = '' AND a32 = '' + AND a33 = '' AND a34 = '' AND a35 = '' AND a36 = '' AND a37 = '' AND a38 = '' AND a39 = '' AND a40 = '' + ORDER BY a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, + a21, a22, a23, a24, a25, a26, a27, a28, a29, a30, a31, a32, a33, a34, a35, a36, a37, a38, a39, a40 +); +---- +1 + +# window_with_subquery_rewritten_to_join +# the optimizer `scalar_subquery_to_join` rewrites +# `WHERE acctbal > ( SELECT AVG(acctbal) FROM suppliers)` into a Join, +# breaking the input schema passed to the window function above. +# See: https://github.com/apache/datafusion/issues/17770 +query I +WITH suppliers AS ( + SELECT * + FROM (VALUES (1, 10.0), (1, 20.0)) AS t(nation, acctbal) +) +SELECT + ROW_NUMBER() OVER (PARTITION BY nation ORDER BY acctbal DESC) AS rn +FROM suppliers AS s +WHERE acctbal > ( + SELECT AVG(acctbal) FROM suppliers +); +---- +1 diff --git a/datafusion/sqllogictest/test_files/window_limits.slt b/datafusion/sqllogictest/test_files/window_limits.slt new file mode 100644 index 0000000000000..c1e680084f4b7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/window_limits.slt @@ -0,0 +1,769 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# see https://datafusion.apache.org/user-guide/sql/window_functions.html#syntax for field names & examples +statement ok +CREATE EXTERNAL TABLE employees ( + depname VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + empno INT NOT NULL, + salary BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL, + hire_date DATE NOT NULL, + c15 TIMESTAMP NOT NULL, +) +STORED AS CSV +LOCATION '../../testing/data/csv/aggregate_test_100_with_dates.csv' +OPTIONS ('format.has_header' 'true'); + +# lead defaults to 1 and should grow limit +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query I +SELECT LEAD(empno) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM employees LIMIT 3 +---- +299 +363 +417 + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query I +SELECT LEAD(empno) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM employees LIMIT 3 +---- +299 +363 +417 + +query TT +EXPLAIN +SELECT LEAD(empno) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM employees LIMIT 3 +---- +logical_plan +01)Projection: lead(employees.empno) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +02)--Limit: skip=0, fetch=3 +03)----WindowAggr: windowExpr=[[lead(employees.empno) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: employees projection=[empno] +physical_plan +01)ProjectionExec: expr=[lead(employees.empno) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as lead(employees.empno) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +02)--GlobalLimitExec: skip=0, fetch=3 +03)----BoundedWindowAggExec: wdw=[lead(employees.empno) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "lead(employees.empno) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------SortExec: TopK(fetch=4), expr=[empno@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[empno], file_type=csv, has_header=true + +# lead defaults can lookahead by any amount and should grow limit +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query I +SELECT LEAD(empno, 2) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM employees LIMIT 3 +---- +363 +417 +794 + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query I +SELECT LEAD(empno, 2) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM employees LIMIT 3 +---- +363 +417 +794 + +query TT +EXPLAIN +SELECT LEAD(empno, 2) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM employees LIMIT 3 +---- +logical_plan +01)Projection: lead(employees.empno,Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +02)--Limit: skip=0, fetch=3 +03)----WindowAggr: windowExpr=[[lead(employees.empno, Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: employees projection=[empno] +physical_plan +01)ProjectionExec: expr=[lead(employees.empno,Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as lead(employees.empno,Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +02)--GlobalLimitExec: skip=0, fetch=3 +03)----BoundedWindowAggExec: wdw=[lead(employees.empno,Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "lead(employees.empno,Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------SortExec: TopK(fetch=5), expr=[empno@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[empno], file_type=csv, has_header=true + +# Should use the max of leads +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query IIII +SELECT + empno, + LEAD(salary, 1) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lead1, + LEAD(salary, 3) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lead3, + LEAD(salary, 5) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lead5 +FROM employees +ORDER BY empno +LIMIT 5; +---- +102 28774375 557517119 4015442341 +299 1865307672 4061635107 3542840110 +363 557517119 4015442341 1088543984 +417 4061635107 3542840110 1362369177 +794 4015442341 1088543984 145294611 + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query IIII +SELECT + empno, + LEAD(salary, 1) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lead1, + LEAD(salary, 3) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lead3, + LEAD(salary, 5) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lead5 +FROM employees +ORDER BY empno +LIMIT 5; +---- +102 28774375 557517119 4015442341 +299 1865307672 4061635107 3542840110 +363 557517119 4015442341 1088543984 +417 4061635107 3542840110 1362369177 +794 4015442341 1088543984 145294611 + +query TT +EXPLAIN +SELECT + empno, + LEAD(salary, 1) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lead1, + LEAD(salary, 3) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lead3, + LEAD(salary, 5) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lead5 +FROM employees +ORDER BY empno +LIMIT 5; +---- +logical_plan +01)Sort: employees.empno ASC NULLS LAST, fetch=5 +02)--Projection: employees.empno, lead(employees.salary,Int64(1)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lead1, lead(employees.salary,Int64(3)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lead3, lead(employees.salary,Int64(5)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lead5 +03)----WindowAggr: windowExpr=[[lead(employees.salary, Int64(1)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, lead(employees.salary, Int64(3)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, lead(employees.salary, Int64(5)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: employees projection=[empno, salary] +physical_plan +01)ProjectionExec: expr=[empno@0 as empno, lead(employees.salary,Int64(1)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as lead1, lead(employees.salary,Int64(3)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as lead3, lead(employees.salary,Int64(5)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as lead5] +02)--GlobalLimitExec: skip=0, fetch=5 +03)----BoundedWindowAggExec: wdw=[lead(employees.salary,Int64(1)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "lead(employees.salary,Int64(1)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, lead(employees.salary,Int64(3)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "lead(employees.salary,Int64(3)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, lead(employees.salary,Int64(5)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "lead(employees.salary,Int64(5)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------SortExec: TopK(fetch=10), expr=[empno@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[empno, salary], file_type=csv, has_header=true + +# 2 < 3... nth_value should not grow the limit +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query I +SELECT NTH_VALUE(empno, 2) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM employees LIMIT 3 +---- +NULL +299 +299 + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query I +SELECT NTH_VALUE(empno, 2) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM employees LIMIT 3 +---- +NULL +299 +299 + +query TT +EXPLAIN +SELECT NTH_VALUE(empno, 2) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM employees LIMIT 3 +---- +logical_plan +01)Projection: nth_value(employees.empno,Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +02)--Limit: skip=0, fetch=3 +03)----WindowAggr: windowExpr=[[nth_value(employees.empno, Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: employees projection=[empno] +physical_plan +01)ProjectionExec: expr=[nth_value(employees.empno,Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as nth_value(employees.empno,Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +02)--GlobalLimitExec: skip=0, fetch=3 +03)----BoundedWindowAggExec: wdw=[nth_value(employees.empno,Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "nth_value(employees.empno,Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------SortExec: TopK(fetch=3), expr=[empno@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[empno], file_type=csv, has_header=true + +# 5 > 3... nth_value still won't grow the limit - it's causal +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query I +SELECT NTH_VALUE(empno, 5) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM employees LIMIT 3 +---- +NULL +NULL +NULL + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query I +SELECT NTH_VALUE(empno, 5) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM employees LIMIT 3 +---- +NULL +NULL +NULL + +query TT +EXPLAIN +SELECT NTH_VALUE(empno, 5) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM employees LIMIT 3 +---- +logical_plan +01)Projection: nth_value(employees.empno,Int64(5)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +02)--Limit: skip=0, fetch=3 +03)----WindowAggr: windowExpr=[[nth_value(employees.empno, Int64(5)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: employees projection=[empno] +physical_plan +01)ProjectionExec: expr=[nth_value(employees.empno,Int64(5)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as nth_value(employees.empno,Int64(5)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +02)--GlobalLimitExec: skip=0, fetch=3 +03)----BoundedWindowAggExec: wdw=[nth_value(employees.empno,Int64(5)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "nth_value(employees.empno,Int64(5)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------SortExec: TopK(fetch=3), expr=[empno@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[empno], file_type=csv, has_header=true + +# aggregate functions shouldn't affect the window +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query TIIRII +SELECT + depname, + empno, + SUM(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_sum, + AVG(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_avg, + MIN(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_min, + MAX(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_max +FROM employees +LIMIT 5; +---- +a 102 3276123488 3276123488 3276123488 3276123488 +e 299 3304897863 1652448931.5 28774375 3276123488 +a 363 5170205535 1723401845 28774375 3276123488 +e 417 5727722654 1431930663.5 28774375 3276123488 +d 794 9789357761 1957871552.2 28774375 4061635107 + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query TIIRII +SELECT + depname, + empno, + SUM(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_sum, + AVG(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_avg, + MIN(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_min, + MAX(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_max +FROM employees +LIMIT 5; +---- +a 102 3276123488 3276123488 3276123488 3276123488 +e 299 3304897863 1652448931.5 28774375 3276123488 +a 363 5170205535 1723401845 28774375 3276123488 +e 417 5727722654 1431930663.5 28774375 3276123488 +d 794 9789357761 1957871552.2 28774375 4061635107 + +query TT +EXPLAIN +SELECT + depname, + empno, + SUM(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_sum, + AVG(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_avg, + MIN(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_min, + MAX(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_max +FROM employees +LIMIT 5; +---- +logical_plan +01)Projection: employees.depname, employees.empno, sum(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS running_sum, avg(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS running_avg, min(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS running_min, max(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS running_max +02)--Limit: skip=0, fetch=5 +03)----WindowAggr: windowExpr=[[sum(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, avg(CAST(employees.salary AS Float64)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, max(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: employees projection=[depname, empno, salary] +physical_plan +01)ProjectionExec: expr=[depname@0 as depname, empno@1 as empno, sum(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as running_sum, avg(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as running_avg, min(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as running_min, max(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as running_max] +02)--GlobalLimitExec: skip=0, fetch=5 +03)----BoundedWindowAggExec: wdw=[sum(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, avg(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "avg(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "min(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, max(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "max(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------SortExec: TopK(fetch=5), expr=[empno@1 ASC NULLS LAST], preserve_partitioning=[false] +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[depname, empno, salary], file_type=csv, has_header=true + +# ranking functions that don't affect the limit +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query IIII +SELECT + empno, + row_number() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rn, + rank() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rnk, + dense_rank() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS drnk +FROM employees +ORDER BY empno +LIMIT 5; +---- +102 1 1 1 +299 2 2 2 +363 3 3 3 +417 4 4 4 +794 5 5 5 + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query IIII +SELECT + empno, + row_number() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rn, + rank() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rnk, + dense_rank() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS drnk +FROM employees +ORDER BY empno +LIMIT 5; +---- +102 1 1 1 +299 2 2 2 +363 3 3 3 +417 4 4 4 +794 5 5 5 + +query TT +EXPLAIN +SELECT + empno, + row_number() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rn, + rank() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rnk, + dense_rank() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS drnk +FROM employees +ORDER BY empno +LIMIT 5; +---- +logical_plan +01)Sort: employees.empno ASC NULLS LAST, fetch=5 +02)--Projection: employees.empno, row_number() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn, rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rnk, dense_rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS drnk +03)----WindowAggr: windowExpr=[[row_number() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, dense_rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: employees projection=[empno] +physical_plan +01)ProjectionExec: expr=[empno@0 as empno, row_number() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn, rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rnk, dense_rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as drnk] +02)--GlobalLimitExec: skip=0, fetch=5 +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "row_number() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, dense_rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "dense_rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------SortExec: TopK(fetch=5), expr=[empno@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[empno], file_type=csv, has_header=true + +# Unoptimizable global ranking functions +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query IRRI +SELECT + empno, + percent_rank() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS pr, + cume_dist() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cd, + ntile(4) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS nt +FROM employees +ORDER BY empno +LIMIT 5; +---- +102 0 0.01 1 +299 0.010101010101 0.02 1 +363 0.020202020202 0.03 1 +417 0.030303030303 0.04 1 +794 0.040404040404 0.05 1 + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query IRRI +SELECT + empno, + percent_rank() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS pr, + cume_dist() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cd, + ntile(4) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS nt +FROM employees +ORDER BY empno +LIMIT 5; +---- +102 0 0.01 1 +299 0.010101010101 0.02 1 +363 0.020202020202 0.03 1 +417 0.030303030303 0.04 1 +794 0.040404040404 0.05 1 + +query TT +EXPLAIN +SELECT + empno, + percent_rank() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS pr, + cume_dist() OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cd, + ntile(4) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS nt +FROM employees +ORDER BY empno +LIMIT 5; +---- +logical_plan +01)Sort: employees.empno ASC NULLS LAST, fetch=5 +02)--Projection: employees.empno, percent_rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS pr, cume_dist() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS cd, ntile(Int64(4)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS nt +03)----WindowAggr: windowExpr=[[percent_rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, cume_dist() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, ntile(Int64(4)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: employees projection=[empno] +physical_plan +01)ProjectionExec: expr=[empno@0 as empno, percent_rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as pr, cume_dist() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as cd, ntile(Int64(4)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as nt] +02)--GlobalLimitExec: skip=0, fetch=5 +03)----WindowAggExec: wdw=[percent_rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "percent_rank() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }, cume_dist() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "cume_dist() ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }, ntile(Int64(4)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ntile(Int64(4)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }] +04)------SortExec: expr=[empno@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[empno], file_type=csv, has_header=true + +# Analytical functions that don't lookahead +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query IIIII +SELECT + empno, + first_value(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS fv, + lag(salary, 1) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS l1, + last_value(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lv, + nth_value(salary, 3) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS n3 +FROM employees +ORDER BY empno +LIMIT 5; +---- +102 3276123488 NULL 3276123488 NULL +299 3276123488 3276123488 28774375 NULL +363 3276123488 28774375 1865307672 1865307672 +417 3276123488 1865307672 557517119 1865307672 +794 3276123488 557517119 4061635107 1865307672 + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query IIIII +SELECT + empno, + first_value(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS fv, + lag(salary, 1) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS l1, + last_value(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lv, + nth_value(salary, 3) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS n3 +FROM employees +ORDER BY empno +LIMIT 5; +---- +102 3276123488 NULL 3276123488 NULL +299 3276123488 3276123488 28774375 NULL +363 3276123488 28774375 1865307672 1865307672 +417 3276123488 1865307672 557517119 1865307672 +794 3276123488 557517119 4061635107 1865307672 + +query TT +EXPLAIN +SELECT + empno, + first_value(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS fv, + lag(salary, 1) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS l1, + last_value(salary) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lv, + nth_value(salary, 3) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS n3 +FROM employees +ORDER BY empno +LIMIT 5; +---- +logical_plan +01)Sort: employees.empno ASC NULLS LAST, fetch=5 +02)--Projection: employees.empno, first_value(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS fv, lag(employees.salary,Int64(1)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS l1, last_value(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lv, nth_value(employees.salary,Int64(3)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS n3 +03)----WindowAggr: windowExpr=[[first_value(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, lag(employees.salary, Int64(1)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, last_value(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, nth_value(employees.salary, Int64(3)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: employees projection=[empno, salary] +physical_plan +01)ProjectionExec: expr=[empno@0 as empno, first_value(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as fv, lag(employees.salary,Int64(1)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as l1, last_value(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as lv, nth_value(employees.salary,Int64(3)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as n3] +02)--GlobalLimitExec: skip=0, fetch=5 +03)----BoundedWindowAggExec: wdw=[first_value(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "first_value(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, lag(employees.salary,Int64(1)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "lag(employees.salary,Int64(1)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, last_value(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "last_value(employees.salary) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, nth_value(employees.salary,Int64(3)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "nth_value(employees.salary,Int64(3)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------SortExec: TopK(fetch=5), expr=[empno@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[empno, salary], file_type=csv, has_header=true + +# should handle partition by unoptimized +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query TIII +SELECT depname, empno, salary, SUM(salary) OVER ( + PARTITION BY depname + ORDER BY empno + ROWS BETWEEN 1 PRECEDING AND CURRENT ROW + ) AS running_sum +FROM employees +ORDER BY depname +LIMIT 5 +---- +a 102 3276123488 3276123488 +a 363 1865307672 5141431160 +a 829 4015442341 5880750013 +a 2555 145294611 4160736952 +a 2809 754775609 900070220 + +query TT +EXPLAIN +SELECT depname, empno, salary, SUM(salary) OVER ( + PARTITION BY depname + ORDER BY empno + ROWS BETWEEN 1 PRECEDING AND CURRENT ROW + ) AS running_sum +FROM employees +ORDER BY depname +LIMIT 5 +---- +logical_plan +01)Sort: employees.depname ASC NULLS LAST, fetch=5 +02)--Projection: employees.depname, employees.empno, employees.salary, sum(employees.salary) PARTITION BY [employees.depname] ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND CURRENT ROW AS running_sum +03)----WindowAggr: windowExpr=[[sum(employees.salary) PARTITION BY [employees.depname] ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND CURRENT ROW]] +04)------TableScan: employees projection=[depname, empno, salary] +physical_plan +01)SortPreservingMergeExec: [depname@0 ASC NULLS LAST], fetch=5 +02)--ProjectionExec: expr=[depname@0 as depname, empno@1 as empno, salary@2 as salary, sum(employees.salary) PARTITION BY [employees.depname] ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND CURRENT ROW@3 as running_sum] +03)----BoundedWindowAggExec: wdw=[sum(employees.salary) PARTITION BY [employees.depname] ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND CURRENT ROW: Field { name: "sum(employees.salary) PARTITION BY [employees.depname] ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------SortExec: expr=[depname@0 ASC NULLS LAST, empno@1 ASC NULLS LAST], preserve_partitioning=[true] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------RepartitionExec: partitioning=Hash([depname@0], 4), input_partitions=4 +07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +08)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[depname, empno, salary], file_type=csv, has_header=true + +# should handle partition by optimized +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query TIII +SELECT depname, empno, salary, SUM(salary) OVER ( + PARTITION BY depname + ORDER BY empno + ROWS BETWEEN 1 PRECEDING AND CURRENT ROW + ) AS running_sum +FROM employees +ORDER BY depname +LIMIT 5 +---- +a 102 3276123488 3276123488 +a 363 1865307672 5141431160 +a 829 4015442341 5880750013 +a 2555 145294611 4160736952 +a 2809 754775609 900070220 + +query TT +EXPLAIN +SELECT depname, empno, salary, SUM(salary) OVER ( + PARTITION BY depname + ORDER BY empno + ROWS BETWEEN 1 PRECEDING AND CURRENT ROW + ) AS running_sum +FROM employees +ORDER BY depname +LIMIT 5 +---- +logical_plan +01)Sort: employees.depname ASC NULLS LAST, fetch=5 +02)--Projection: employees.depname, employees.empno, employees.salary, sum(employees.salary) PARTITION BY [employees.depname] ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND CURRENT ROW AS running_sum +03)----WindowAggr: windowExpr=[[sum(employees.salary) PARTITION BY [employees.depname] ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND CURRENT ROW]] +04)------TableScan: employees projection=[depname, empno, salary] +physical_plan +01)SortPreservingMergeExec: [depname@0 ASC NULLS LAST], fetch=5 +02)--ProjectionExec: expr=[depname@0 as depname, empno@1 as empno, salary@2 as salary, sum(employees.salary) PARTITION BY [employees.depname] ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND CURRENT ROW@3 as running_sum] +03)----BoundedWindowAggExec: wdw=[sum(employees.salary) PARTITION BY [employees.depname] ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND CURRENT ROW: Field { name: "sum(employees.salary) PARTITION BY [employees.depname] ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN 1 PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------SortExec: TopK(fetch=5), expr=[depname@0 ASC NULLS LAST, empno@1 ASC NULLS LAST], preserve_partitioning=[true] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------RepartitionExec: partitioning=Hash([depname@0], 4), input_partitions=4 +07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +08)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[depname, empno, salary], file_type=csv, has_header=true + +# unbounded following +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query I +SELECT LEAD(salary) OVER (ORDER BY empno ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) +FROM employees +LIMIT 5; +---- +28774375 +1865307672 +557517119 +4061635107 +4015442341 + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query I +SELECT LEAD(salary) OVER (ORDER BY empno ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) +FROM employees +LIMIT 5; +---- +28774375 +1865307672 +557517119 +4061635107 +4015442341 + +# RANGE +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query I +SELECT LEAD(salary) OVER (ORDER BY empno RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +FROM employees +LIMIT 5; +---- +28774375 +1865307672 +557517119 +4061635107 +4015442341 + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query I +SELECT LEAD(salary) OVER (ORDER BY empno RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +FROM employees +LIMIT 5; +---- +28774375 +1865307672 +557517119 +4061635107 +4015442341 + +# multiple windows +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query II +SELECT + LEAD(salary, 1) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), + LEAD(salary, 5) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +FROM employees +LIMIT 5; +---- +28774375 4015442341 +1865307672 3542840110 +557517119 1088543984 +4061635107 1362369177 +4015442341 145294611 + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query II +SELECT + LEAD(salary, 1) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), + LEAD(salary, 5) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +FROM employees +LIMIT 5; +---- +28774375 4015442341 +1865307672 3542840110 +557517119 1088543984 +4061635107 1362369177 +4015442341 145294611 + +# sliding +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query III +SELECT + empno, + salary, + SUM(salary) OVER (ORDER BY empno ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS sliding_sum +FROM employees +LIMIT 3; +---- +102 3276123488 3276123488 +299 28774375 3304897863 +363 1865307672 5170205535 + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query III +SELECT + empno, + salary, + SUM(salary) OVER (ORDER BY empno ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) AS sliding_sum +FROM employees +LIMIT 3; +---- +102 3276123488 3276123488 +299 28774375 3304897863 +363 1865307672 5170205535 + +# sliding lead +statement ok +set datafusion.optimizer.enable_window_limits = false; + +query III +SELECT + empno, + salary, + LEAD(salary, 2) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lead2 +FROM employees +LIMIT 3; +---- +102 3276123488 1865307672 +299 28774375 557517119 +363 1865307672 4061635107 + +statement ok +set datafusion.optimizer.enable_window_limits = true; + +query III +SELECT + empno, + salary, + LEAD(salary, 2) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lead2 +FROM employees +LIMIT 3; +---- +102 3276123488 1865307672 +299 28774375 557517119 +363 1865307672 4061635107 + +query TT +EXPLAIN +SELECT + empno, + salary, + LEAD(salary, 2) OVER (ORDER BY empno ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS lead2 +FROM employees +LIMIT 3; +---- +logical_plan +01)Projection: employees.empno, employees.salary, lead(employees.salary,Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lead2 +02)--Limit: skip=0, fetch=3 +03)----WindowAggr: windowExpr=[[lead(employees.salary, Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: employees projection=[empno, salary] +physical_plan +01)ProjectionExec: expr=[empno@0 as empno, salary@1 as salary, lead(employees.salary,Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as lead2] +02)--GlobalLimitExec: skip=0, fetch=3 +03)----BoundedWindowAggExec: wdw=[lead(employees.salary,Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "lead(employees.salary,Int64(2)) ORDER BY [employees.empno ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +04)------SortExec: TopK(fetch=5), expr=[empno@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[empno, salary], file_type=csv, has_header=true diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index d23e986914fc4..16bb5cff4ad79 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -34,17 +34,18 @@ workspace = true async-recursion = "1.0" async-trait = { workspace = true } chrono = { workspace = true } -datafusion = { workspace = true } +datafusion = { workspace = true, features = ["sql"] } itertools = { workspace = true } object_store = { workspace = true } pbjson-types = { workspace = true } prost = { workspace = true } -substrait = { version = "0.55", features = ["serde"] } +substrait = { version = "0.58", features = ["serde"] } url = { workspace = true } tokio = { workspace = true, features = ["fs"] } +uuid = { version = "1.17.0", features = ["v4"] } [dev-dependencies] -datafusion = { workspace = true, features = ["nested_expressions"] } +datafusion = { workspace = true, features = ["nested_expressions", "unicode_expressions"] } datafusion-functions-aggregate = { workspace = true } serde_json = "1.0" tokio = { workspace = true } diff --git a/datafusion/substrait/README.md b/datafusion/substrait/README.md index 92bb9abcc6901..d18d7bda5e3b0 100644 --- a/datafusion/substrait/README.md +++ b/datafusion/substrait/README.md @@ -19,8 +19,12 @@ # Apache DataFusion Substrait -This crate contains a [Substrait] producer and consumer for Apache Arrow -[DataFusion] plans. See [API Docs] for details and examples. +[Apache DataFusion] is an extensible query execution framework, written in Rust, that uses [Apache Arrow] as its in-memory format. +This crate is a submodule of DataFusion that provides a [Substrait] producer and consumer for DataFusion +plans. See [API Docs] for details and examples. + +[apache arrow]: https://arrow.apache.org/ +[apache datafusion]: https://datafusion.apache.org/ [substrait]: https://substrait.io [api docs]: https://docs.rs/datafusion-substrait/latest diff --git a/datafusion/substrait/src/extensions.rs b/datafusion/substrait/src/extensions.rs index c74061f2c9f3c..f9a2e0fb82556 100644 --- a/datafusion/substrait/src/extensions.rs +++ b/datafusion/substrait/src/extensions.rs @@ -45,6 +45,8 @@ impl Extensions { // Rename those to match the Substrait extensions for interoperability let function_name = match function_name.as_str() { "substr" => "substring".to_string(), + "log" => "logb".to_string(), + "isnan" => "is_nan".to_string(), _ => function_name, }; diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs index 0f2fbf199be35..9a4f44e81df23 100644 --- a/datafusion/substrait/src/lib.rs +++ b/datafusion/substrait/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs deleted file mode 100644 index 61f3379735c7d..0000000000000 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ /dev/null @@ -1,3453 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; -use arrow::buffer::OffsetBuffer; -use async_recursion::async_recursion; -use datafusion::arrow::array::MapArray; -use datafusion::arrow::datatypes::{ - DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, -}; -use datafusion::common::{ - not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, - substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, Spans, - TableReference, -}; -use datafusion::datasource::provider_as_source; -use datafusion::logical_expr::expr::{Exists, InSubquery, Sort, WindowFunctionParams}; - -use datafusion::logical_expr::{ - Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension, - LogicalPlan, Operator, Projection, SortExpr, Subquery, TryCast, Values, -}; -use substrait::proto::aggregate_rel::Grouping; -use substrait::proto::expression as substrait_expression; -use substrait::proto::expression::subquery::set_predicate::PredicateOp; -use substrait::proto::expression_reference::ExprType; -use url::Url; - -use crate::extensions::Extensions; -use crate::variation_const::{ - DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, - DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, - VIEW_CONTAINER_TYPE_VARIATION_REF, -}; -#[allow(deprecated)] -use crate::variation_const::{ - INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, - INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, - TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, - TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, -}; -use async_trait::async_trait; -use datafusion::arrow::array::{new_empty_array, AsArray}; -use datafusion::arrow::temporal_conversions::NANOSECONDS; -use datafusion::catalog::TableProvider; -use datafusion::common::scalar::ScalarStructBuilder; -use datafusion::execution::{FunctionRegistry, SessionState}; -use datafusion::logical_expr::builder::project; -use datafusion::logical_expr::expr::InList; -use datafusion::logical_expr::{ - col, expr, GroupingSet, Like, LogicalPlanBuilder, Partitioning, Repartition, - WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, -}; -use datafusion::prelude::{lit, JoinType}; -use datafusion::{ - arrow, error::Result, logical_expr::utils::split_conjunction, - logical_expr::utils::split_conjunction_owned, prelude::Column, scalar::ScalarValue, -}; -use std::collections::HashSet; -use std::sync::Arc; -use substrait::proto; -use substrait::proto::exchange_rel::ExchangeKind; -use substrait::proto::expression::cast::FailureBehavior::ReturnNull; -use substrait::proto::expression::literal::user_defined::Val; -use substrait::proto::expression::literal::{ - interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, -}; -use substrait::proto::expression::subquery::SubqueryType; -use substrait::proto::expression::{ - Enum, FieldReference, IfThen, Literal, MultiOrList, Nested, ScalarFunction, - SingularOrList, SwitchExpression, WindowFunction, -}; -use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; -use substrait::proto::rel_common::{Emit, EmitKind}; -use substrait::proto::set_rel::SetOp; -use substrait::proto::{ - aggregate_function::AggregationInvocation, - expression::{ - field_reference::ReferenceType::DirectReference, literal::LiteralType, - reference_segment::ReferenceType::StructField, - window_function::bound as SubstraitBound, - window_function::bound::Kind as BoundKind, window_function::Bound, - window_function::BoundsType, MaskExpression, RexType, - }, - fetch_rel, - function_argument::ArgType, - join_rel, plan_rel, r#type, - read_rel::ReadType, - rel::RelType, - rel_common, - sort_field::{SortDirection, SortKind::*}, - AggregateFunction, AggregateRel, ConsistentPartitionWindowRel, CrossRel, - DynamicParameter, ExchangeRel, Expression, ExtendedExpression, ExtensionLeafRel, - ExtensionMultiRel, ExtensionSingleRel, FetchRel, FilterRel, FunctionArgument, - JoinRel, NamedStruct, Plan, ProjectRel, ReadRel, Rel, RelCommon, SetRel, SortField, - SortRel, Type, -}; - -#[async_trait] -/// This trait is used to consume Substrait plans, converting them into DataFusion Logical Plans. -/// It can be implemented by users to allow for custom handling of relations, expressions, etc. -/// -/// Combined with the [crate::logical_plan::producer::SubstraitProducer] this allows for fully -/// customizable Substrait serde. -/// -/// # Example Usage -/// -/// ``` -/// # use async_trait::async_trait; -/// # use datafusion::catalog::TableProvider; -/// # use datafusion::common::{not_impl_err, substrait_err, DFSchema, ScalarValue, TableReference}; -/// # use datafusion::error::Result; -/// # use datafusion::execution::{FunctionRegistry, SessionState}; -/// # use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; -/// # use std::sync::Arc; -/// # use substrait::proto; -/// # use substrait::proto::{ExtensionLeafRel, FilterRel, ProjectRel}; -/// # use datafusion::arrow::datatypes::DataType; -/// # use datafusion::logical_expr::expr::ScalarFunction; -/// # use datafusion_substrait::extensions::Extensions; -/// # use datafusion_substrait::logical_plan::consumer::{ -/// # from_project_rel, from_substrait_rel, from_substrait_rex, SubstraitConsumer -/// # }; -/// -/// struct CustomSubstraitConsumer { -/// extensions: Arc, -/// state: Arc, -/// } -/// -/// #[async_trait] -/// impl SubstraitConsumer for CustomSubstraitConsumer { -/// async fn resolve_table_ref( -/// &self, -/// table_ref: &TableReference, -/// ) -> Result>> { -/// let table = table_ref.table().to_string(); -/// let schema = self.state.schema_for_ref(table_ref.clone())?; -/// let table_provider = schema.table(&table).await?; -/// Ok(table_provider) -/// } -/// -/// fn get_extensions(&self) -> &Extensions { -/// self.extensions.as_ref() -/// } -/// -/// fn get_function_registry(&self) -> &impl FunctionRegistry { -/// self.state.as_ref() -/// } -/// -/// // You can reuse existing consumer code to assist in handling advanced extensions -/// async fn consume_project(&self, rel: &ProjectRel) -> Result { -/// let df_plan = from_project_rel(self, rel).await?; -/// if let Some(advanced_extension) = rel.advanced_extension.as_ref() { -/// not_impl_err!( -/// "decode and handle an advanced extension: {:?}", -/// advanced_extension -/// ) -/// } else { -/// Ok(df_plan) -/// } -/// } -/// -/// // You can implement a fully custom consumer method if you need special handling -/// async fn consume_filter(&self, rel: &FilterRel) -> Result { -/// let input = self.consume_rel(rel.input.as_ref().unwrap()).await?; -/// let expression = -/// self.consume_expression(rel.condition.as_ref().unwrap(), input.schema()) -/// .await?; -/// // though this one is quite boring -/// LogicalPlanBuilder::from(input).filter(expression)?.build() -/// } -/// -/// // You can add handlers for extension relations -/// async fn consume_extension_leaf( -/// &self, -/// rel: &ExtensionLeafRel, -/// ) -> Result { -/// not_impl_err!( -/// "handle protobuf Any {} as you need", -/// rel.detail.as_ref().unwrap().type_url -/// ) -/// } -/// -/// // and handlers for user-define types -/// fn consume_user_defined_type(&self, typ: &proto::r#type::UserDefined) -> Result { -/// let type_string = self.extensions.types.get(&typ.type_reference).unwrap(); -/// match type_string.as_str() { -/// "u!foo" => not_impl_err!("handle foo conversion"), -/// "u!bar" => not_impl_err!("handle bar conversion"), -/// _ => substrait_err!("unexpected type") -/// } -/// } -/// -/// // and user-defined literals -/// fn consume_user_defined_literal(&self, literal: &proto::expression::literal::UserDefined) -> Result { -/// let type_string = self.extensions.types.get(&literal.type_reference).unwrap(); -/// match type_string.as_str() { -/// "u!foo" => not_impl_err!("handle foo conversion"), -/// "u!bar" => not_impl_err!("handle bar conversion"), -/// _ => substrait_err!("unexpected type") -/// } -/// } -/// } -/// ``` -/// -pub trait SubstraitConsumer: Send + Sync + Sized { - async fn resolve_table_ref( - &self, - table_ref: &TableReference, - ) -> Result>>; - - // TODO: Remove these two methods - // Ideally, the abstract consumer should not place any constraints on implementations. - // The functionality for which the Extensions and FunctionRegistry is needed should be abstracted - // out into methods on the trait. As an example, resolve_table_reference is such a method. - // See: https://github.com/apache/datafusion/issues/13863 - fn get_extensions(&self) -> &Extensions; - fn get_function_registry(&self) -> &impl FunctionRegistry; - - // Relation Methods - // There is one method per Substrait relation to allow for easy overriding of consumer behaviour. - // These methods have default implementations calling the common handler code, to allow for users - // to re-use common handling logic. - - /// All [Rel]s to be converted pass through this method. - /// You can provide your own implementation if you wish to customize the conversion behaviour. - async fn consume_rel(&self, rel: &Rel) -> Result { - from_substrait_rel(self, rel).await - } - - async fn consume_read(&self, rel: &ReadRel) -> Result { - from_read_rel(self, rel).await - } - - async fn consume_filter(&self, rel: &FilterRel) -> Result { - from_filter_rel(self, rel).await - } - - async fn consume_fetch(&self, rel: &FetchRel) -> Result { - from_fetch_rel(self, rel).await - } - - async fn consume_aggregate(&self, rel: &AggregateRel) -> Result { - from_aggregate_rel(self, rel).await - } - - async fn consume_sort(&self, rel: &SortRel) -> Result { - from_sort_rel(self, rel).await - } - - async fn consume_join(&self, rel: &JoinRel) -> Result { - from_join_rel(self, rel).await - } - - async fn consume_project(&self, rel: &ProjectRel) -> Result { - from_project_rel(self, rel).await - } - - async fn consume_set(&self, rel: &SetRel) -> Result { - from_set_rel(self, rel).await - } - - async fn consume_cross(&self, rel: &CrossRel) -> Result { - from_cross_rel(self, rel).await - } - - async fn consume_consistent_partition_window( - &self, - _rel: &ConsistentPartitionWindowRel, - ) -> Result { - not_impl_err!("Consistent Partition Window Rel not supported") - } - - async fn consume_exchange(&self, rel: &ExchangeRel) -> Result { - from_exchange_rel(self, rel).await - } - - // Expression Methods - // There is one method per Substrait expression to allow for easy overriding of consumer behaviour - // These methods have default implementations calling the common handler code, to allow for users - // to re-use common handling logic. - - /// All [Expression]s to be converted pass through this method. - /// You can provide your own implementation if you wish to customize the conversion behaviour. - async fn consume_expression( - &self, - expr: &Expression, - input_schema: &DFSchema, - ) -> Result { - from_substrait_rex(self, expr, input_schema).await - } - - async fn consume_literal(&self, expr: &Literal) -> Result { - from_literal(self, expr).await - } - - async fn consume_field_reference( - &self, - expr: &FieldReference, - input_schema: &DFSchema, - ) -> Result { - from_field_reference(self, expr, input_schema).await - } - - async fn consume_scalar_function( - &self, - expr: &ScalarFunction, - input_schema: &DFSchema, - ) -> Result { - from_scalar_function(self, expr, input_schema).await - } - - async fn consume_window_function( - &self, - expr: &WindowFunction, - input_schema: &DFSchema, - ) -> Result { - from_window_function(self, expr, input_schema).await - } - - async fn consume_if_then( - &self, - expr: &IfThen, - input_schema: &DFSchema, - ) -> Result { - from_if_then(self, expr, input_schema).await - } - - async fn consume_switch( - &self, - _expr: &SwitchExpression, - _input_schema: &DFSchema, - ) -> Result { - not_impl_err!("Switch expression not supported") - } - - async fn consume_singular_or_list( - &self, - expr: &SingularOrList, - input_schema: &DFSchema, - ) -> Result { - from_singular_or_list(self, expr, input_schema).await - } - - async fn consume_multi_or_list( - &self, - _expr: &MultiOrList, - _input_schema: &DFSchema, - ) -> Result { - not_impl_err!("Multi Or List expression not supported") - } - - async fn consume_cast( - &self, - expr: &substrait_expression::Cast, - input_schema: &DFSchema, - ) -> Result { - from_cast(self, expr, input_schema).await - } - - async fn consume_subquery( - &self, - expr: &substrait_expression::Subquery, - input_schema: &DFSchema, - ) -> Result { - from_subquery(self, expr, input_schema).await - } - - async fn consume_nested( - &self, - _expr: &Nested, - _input_schema: &DFSchema, - ) -> Result { - not_impl_err!("Nested expression not supported") - } - - async fn consume_enum(&self, _expr: &Enum, _input_schema: &DFSchema) -> Result { - not_impl_err!("Enum expression not supported") - } - - async fn consume_dynamic_parameter( - &self, - _expr: &DynamicParameter, - _input_schema: &DFSchema, - ) -> Result { - not_impl_err!("Dynamic Parameter expression not supported") - } - - // User-Defined Functionality - - // The details of extension relations, and how to handle them, are fully up to users to specify. - // The following methods allow users to customize the consumer behaviour - - async fn consume_extension_leaf( - &self, - rel: &ExtensionLeafRel, - ) -> Result { - if let Some(detail) = rel.detail.as_ref() { - return substrait_err!( - "Missing handler for ExtensionLeafRel: {}", - detail.type_url - ); - } - substrait_err!("Missing handler for ExtensionLeafRel") - } - - async fn consume_extension_single( - &self, - rel: &ExtensionSingleRel, - ) -> Result { - if let Some(detail) = rel.detail.as_ref() { - return substrait_err!( - "Missing handler for ExtensionSingleRel: {}", - detail.type_url - ); - } - substrait_err!("Missing handler for ExtensionSingleRel") - } - - async fn consume_extension_multi( - &self, - rel: &ExtensionMultiRel, - ) -> Result { - if let Some(detail) = rel.detail.as_ref() { - return substrait_err!( - "Missing handler for ExtensionMultiRel: {}", - detail.type_url - ); - } - substrait_err!("Missing handler for ExtensionMultiRel") - } - - // Users can bring their own types to Substrait which require custom handling - - fn consume_user_defined_type( - &self, - user_defined_type: &r#type::UserDefined, - ) -> Result { - substrait_err!( - "Missing handler for user-defined type: {}", - user_defined_type.type_reference - ) - } - - fn consume_user_defined_literal( - &self, - user_defined_literal: &proto::expression::literal::UserDefined, - ) -> Result { - substrait_err!( - "Missing handler for user-defined literals {}", - user_defined_literal.type_reference - ) - } -} - -/// Convert Substrait Rel to DataFusion DataFrame -#[async_recursion] -pub async fn from_substrait_rel( - consumer: &impl SubstraitConsumer, - relation: &Rel, -) -> Result { - let plan: Result = match &relation.rel_type { - Some(rel_type) => match rel_type { - RelType::Read(rel) => consumer.consume_read(rel).await, - RelType::Filter(rel) => consumer.consume_filter(rel).await, - RelType::Fetch(rel) => consumer.consume_fetch(rel).await, - RelType::Aggregate(rel) => consumer.consume_aggregate(rel).await, - RelType::Sort(rel) => consumer.consume_sort(rel).await, - RelType::Join(rel) => consumer.consume_join(rel).await, - RelType::Project(rel) => consumer.consume_project(rel).await, - RelType::Set(rel) => consumer.consume_set(rel).await, - RelType::ExtensionSingle(rel) => consumer.consume_extension_single(rel).await, - RelType::ExtensionMulti(rel) => consumer.consume_extension_multi(rel).await, - RelType::ExtensionLeaf(rel) => consumer.consume_extension_leaf(rel).await, - RelType::Cross(rel) => consumer.consume_cross(rel).await, - RelType::Window(rel) => { - consumer.consume_consistent_partition_window(rel).await - } - RelType::Exchange(rel) => consumer.consume_exchange(rel).await, - rt => not_impl_err!("{rt:?} rel not supported yet"), - }, - None => return substrait_err!("rel must set rel_type"), - }; - apply_emit_kind(retrieve_rel_common(relation), plan?) -} - -/// Default SubstraitConsumer for converting standard Substrait without user-defined extensions. -/// -/// Used as the consumer in [from_substrait_plan] -pub struct DefaultSubstraitConsumer<'a> { - extensions: &'a Extensions, - state: &'a SessionState, -} - -impl<'a> DefaultSubstraitConsumer<'a> { - pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self { - DefaultSubstraitConsumer { extensions, state } - } -} - -#[async_trait] -impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { - async fn resolve_table_ref( - &self, - table_ref: &TableReference, - ) -> Result>> { - let table = table_ref.table().to_string(); - let schema = self.state.schema_for_ref(table_ref.clone())?; - let table_provider = schema.table(&table).await?; - Ok(table_provider) - } - - fn get_extensions(&self) -> &Extensions { - self.extensions - } - - fn get_function_registry(&self) -> &impl FunctionRegistry { - self.state - } - - async fn consume_extension_leaf( - &self, - rel: &ExtensionLeafRel, - ) -> Result { - let Some(ext_detail) = &rel.detail else { - return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); - }; - let plan = self - .state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - Ok(LogicalPlan::Extension(Extension { node: plan })) - } - - async fn consume_extension_single( - &self, - rel: &ExtensionSingleRel, - ) -> Result { - let Some(ext_detail) = &rel.detail else { - return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); - }; - let plan = self - .state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - let Some(input_rel) = &rel.input else { - return substrait_err!( - "ExtensionSingleRel missing input rel, try using ExtensionLeafRel instead" - ); - }; - let input_plan = self.consume_rel(input_rel).await?; - let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; - Ok(LogicalPlan::Extension(Extension { node: plan })) - } - - async fn consume_extension_multi( - &self, - rel: &ExtensionMultiRel, - ) -> Result { - let Some(ext_detail) = &rel.detail else { - return substrait_err!("Unexpected empty detail in ExtensionMultiRel"); - }; - let plan = self - .state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - let mut inputs = Vec::with_capacity(rel.inputs.len()); - for input in &rel.inputs { - let input_plan = self.consume_rel(input).await?; - inputs.push(input_plan); - } - let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; - Ok(LogicalPlan::Extension(Extension { node: plan })) - } -} - -// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which -// is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone -// results in correct points on the timeline, and we pick UTC as a reasonable default. -// However, DF uses the timezone also for some arithmetic and display purposes (see e.g. -// https://github.com/apache/arrow-rs/blob/ee5694078c86c8201549654246900a4232d531a9/arrow-cast/src/cast/mod.rs#L1749). -const DEFAULT_TIMEZONE: &str = "UTC"; - -pub fn name_to_op(name: &str) -> Option { - match name { - "equal" => Some(Operator::Eq), - "not_equal" => Some(Operator::NotEq), - "lt" => Some(Operator::Lt), - "lte" => Some(Operator::LtEq), - "gt" => Some(Operator::Gt), - "gte" => Some(Operator::GtEq), - "add" => Some(Operator::Plus), - "subtract" => Some(Operator::Minus), - "multiply" => Some(Operator::Multiply), - "divide" => Some(Operator::Divide), - "mod" => Some(Operator::Modulo), - "modulus" => Some(Operator::Modulo), - "and" => Some(Operator::And), - "or" => Some(Operator::Or), - "is_distinct_from" => Some(Operator::IsDistinctFrom), - "is_not_distinct_from" => Some(Operator::IsNotDistinctFrom), - "regex_match" => Some(Operator::RegexMatch), - "regex_imatch" => Some(Operator::RegexIMatch), - "regex_not_match" => Some(Operator::RegexNotMatch), - "regex_not_imatch" => Some(Operator::RegexNotIMatch), - "bitwise_and" => Some(Operator::BitwiseAnd), - "bitwise_or" => Some(Operator::BitwiseOr), - "str_concat" => Some(Operator::StringConcat), - "at_arrow" => Some(Operator::AtArrow), - "arrow_at" => Some(Operator::ArrowAt), - "bitwise_xor" => Some(Operator::BitwiseXor), - "bitwise_shift_right" => Some(Operator::BitwiseShiftRight), - "bitwise_shift_left" => Some(Operator::BitwiseShiftLeft), - _ => None, - } -} - -pub fn substrait_fun_name(name: &str) -> &str { - let name = match name.rsplit_once(':') { - // Since 0.32.0, Substrait requires the function names to be in a compound format - // https://substrait.io/extensions/#function-signature-compound-names - // for example, `add:i8_i8`. - // On the consumer side, we don't really care about the signature though, just the name. - Some((name, _)) => name, - None => name, - }; - name -} - -fn split_eq_and_noneq_join_predicate_with_nulls_equality( - filter: &Expr, -) -> (Vec<(Column, Column)>, bool, Option) { - let exprs = split_conjunction(filter); - - let mut accum_join_keys: Vec<(Column, Column)> = vec![]; - let mut accum_filters: Vec = vec![]; - let mut nulls_equal_nulls = false; - - for expr in exprs { - #[allow(clippy::collapsible_match)] - match expr { - Expr::BinaryExpr(binary_expr) => match binary_expr { - x @ (BinaryExpr { - left, - op: Operator::Eq, - right, - } - | BinaryExpr { - left, - op: Operator::IsNotDistinctFrom, - right, - }) => { - nulls_equal_nulls = match x.op { - Operator::Eq => false, - Operator::IsNotDistinctFrom => true, - _ => unreachable!(), - }; - - match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => { - accum_join_keys.push((l.clone(), r.clone())); - } - _ => accum_filters.push(expr.clone()), - } - } - _ => accum_filters.push(expr.clone()), - }, - _ => accum_filters.push(expr.clone()), - } - } - - let join_filter = accum_filters.into_iter().reduce(Expr::and); - (accum_join_keys, nulls_equal_nulls, join_filter) -} - -async fn union_rels( - consumer: &impl SubstraitConsumer, - rels: &[Rel], - is_all: bool, -) -> Result { - let mut union_builder = Ok(LogicalPlanBuilder::from( - consumer.consume_rel(&rels[0]).await?, - )); - for input in &rels[1..] { - let rel_plan = consumer.consume_rel(input).await?; - - union_builder = if is_all { - union_builder?.union(rel_plan) - } else { - union_builder?.union_distinct(rel_plan) - }; - } - union_builder?.build() -} - -async fn intersect_rels( - consumer: &impl SubstraitConsumer, - rels: &[Rel], - is_all: bool, -) -> Result { - let mut rel = consumer.consume_rel(&rels[0]).await?; - - for input in &rels[1..] { - rel = LogicalPlanBuilder::intersect( - rel, - consumer.consume_rel(input).await?, - is_all, - )? - } - - Ok(rel) -} - -async fn except_rels( - consumer: &impl SubstraitConsumer, - rels: &[Rel], - is_all: bool, -) -> Result { - let mut rel = consumer.consume_rel(&rels[0]).await?; - - for input in &rels[1..] { - rel = LogicalPlanBuilder::except(rel, consumer.consume_rel(input).await?, is_all)? - } - - Ok(rel) -} - -/// Convert Substrait Plan to DataFusion LogicalPlan -pub async fn from_substrait_plan( - state: &SessionState, - plan: &Plan, -) -> Result { - // Register function extension - let extensions = Extensions::try_from(&plan.extensions)?; - if !extensions.type_variations.is_empty() { - return not_impl_err!("Type variation extensions are not supported"); - } - - let consumer = DefaultSubstraitConsumer { - extensions: &extensions, - state, - }; - from_substrait_plan_with_consumer(&consumer, plan).await -} - -/// Convert Substrait Plan to DataFusion LogicalPlan using the given consumer -pub async fn from_substrait_plan_with_consumer( - consumer: &impl SubstraitConsumer, - plan: &Plan, -) -> Result { - match plan.relations.len() { - 1 => { - match plan.relations[0].rel_type.as_ref() { - Some(rt) => match rt { - plan_rel::RelType::Rel(rel) => Ok(consumer.consume_rel(rel).await?), - plan_rel::RelType::Root(root) => { - let plan = consumer.consume_rel(root.input.as_ref().unwrap()).await?; - if root.names.is_empty() { - // Backwards compatibility for plans missing names - return Ok(plan); - } - let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?; - if renamed_schema.has_equivalent_names_and_types(plan.schema()).is_ok() { - // Nothing to do if the schema is already equivalent - return Ok(plan); - } - match plan { - // If the last node of the plan produces expressions, bake the renames into those expressions. - // This isn't necessary for correctness, but helps with roundtrip tests. - LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema.fields())?, p.input)?)), - LogicalPlan::Aggregate(a) => { - let (group_fields, expr_fields) = renamed_schema.fields().split_at(a.group_expr.len()); - let new_group_exprs = rename_expressions(a.group_expr, a.input.schema(), group_fields)?; - let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), expr_fields)?; - Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, new_group_exprs, new_aggr_exprs)?)) - }, - // There are probably more plans where we could bake things in, can add them later as needed. - // Otherwise, add a new Project to handle the renaming. - _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema.fields())?, Arc::new(plan))?)) - } - } - }, - None => plan_err!("Cannot parse plan relation: None") - } - }, - _ => not_impl_err!( - "Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}", - plan.relations.len() - ) - } -} - -/// An ExprContainer is a container for a collection of expressions with a common input schema -/// -/// In addition, each expression is associated with a field, which defines the -/// expression's output. The data type and nullability of the field are calculated from the -/// expression and the input schema. However the names of the field (and its nested fields) are -/// derived from the Substrait message. -pub struct ExprContainer { - /// The input schema for the expressions - pub input_schema: DFSchemaRef, - /// The expressions - /// - /// Each item contains an expression and the field that defines the expected nullability and name of the expr's output - pub exprs: Vec<(Expr, Field)>, -} - -/// Convert Substrait ExtendedExpression to ExprContainer -/// -/// A Substrait ExtendedExpression message contains one or more expressions, -/// with names for the outputs, and an input schema. These pieces are all included -/// in the ExprContainer. -/// -/// This is a top-level message and can be used to send expressions (not plans) -/// between systems. This is often useful for scenarios like pushdown where filter -/// expressions need to be sent to remote systems. -pub async fn from_substrait_extended_expr( - state: &SessionState, - extended_expr: &ExtendedExpression, -) -> Result { - // Register function extension - let extensions = Extensions::try_from(&extended_expr.extensions)?; - if !extensions.type_variations.is_empty() { - return not_impl_err!("Type variation extensions are not supported"); - } - - let consumer = DefaultSubstraitConsumer { - extensions: &extensions, - state, - }; - - let input_schema = DFSchemaRef::new(match &extended_expr.base_schema { - Some(base_schema) => from_substrait_named_struct(&consumer, base_schema), - None => { - plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message") - } - }?); - - // Parse expressions - let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len()); - for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() { - let scalar_expr = match &substrait_expr.expr_type { - Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr), - Some(ExprType::Measure(_)) => { - not_impl_err!("Measure expressions are not yet supported") - } - None => { - plan_err!("required property `expr_type` missing from Substrait ExpressionReference message") - } - }?; - let expr = consumer - .consume_expression(scalar_expr, &input_schema) - .await?; - let (output_type, expected_nullability) = - expr.data_type_and_nullable(&input_schema)?; - let output_field = Field::new("", output_type, expected_nullability); - let mut names_idx = 0; - let output_field = rename_field( - &output_field, - &substrait_expr.output_names, - expr_idx, - &mut names_idx, - /*rename_self=*/ true, - )?; - exprs.push((expr, output_field)); - } - - Ok(ExprContainer { - input_schema, - exprs, - }) -} - -pub fn apply_masking( - schema: DFSchema, - mask_expression: &::core::option::Option, -) -> Result { - match mask_expression { - Some(MaskExpression { select, .. }) => match &select.as_ref() { - Some(projection) => { - let column_indices: Vec = projection - .struct_items - .iter() - .map(|item| item.field as usize) - .collect(); - - let fields = column_indices - .iter() - .map(|i| schema.qualified_field(*i)) - .map(|(qualifier, field)| { - (qualifier.cloned(), Arc::new(field.clone())) - }) - .collect(); - - Ok(DFSchema::new_with_metadata( - fields, - schema.metadata().clone(), - )?) - } - None => Ok(schema), - }, - None => Ok(schema), - } -} - -/// Ensure the expressions have the right name(s) according to the new schema. -/// This includes the top-level (column) name, which will be renamed through aliasing if needed, -/// as well as nested names (if the expression produces any struct types), which will be renamed -/// through casting if needed. -fn rename_expressions( - exprs: impl IntoIterator, - input_schema: &DFSchema, - new_schema_fields: &[Arc], -) -> Result> { - exprs - .into_iter() - .zip(new_schema_fields) - .map(|(old_expr, new_field)| { - // Check if type (i.e. nested struct field names) match, use Cast to rename if needed - let new_expr = if &old_expr.get_type(input_schema)? != new_field.data_type() { - Expr::Cast(Cast::new( - Box::new(old_expr), - new_field.data_type().to_owned(), - )) - } else { - old_expr - }; - // Alias column if needed to fix the top-level name - match &new_expr { - // If expr is a column reference, alias_if_changed would cause an aliasing if the old expr has a qualifier - Expr::Column(c) if &c.name == new_field.name() => Ok(new_expr), - _ => new_expr.alias_if_changed(new_field.name().to_owned()), - } - }) - .collect() -} - -fn rename_field( - field: &Field, - dfs_names: &Vec, - unnamed_field_suffix: usize, // If Substrait doesn't provide a name, we'll use this "c{unnamed_field_suffix}" - name_idx: &mut usize, // Index into dfs_names - rename_self: bool, // Some fields (e.g. list items) don't have names in Substrait and this will be false to keep old name -) -> Result { - let name = if rename_self { - next_struct_field_name(unnamed_field_suffix, dfs_names, name_idx)? - } else { - field.name().to_string() - }; - match field.data_type() { - DataType::Struct(children) => { - let children = children - .iter() - .enumerate() - .map(|(child_idx, f)| { - rename_field( - f.as_ref(), - dfs_names, - child_idx, - name_idx, - /*rename_self=*/ true, - ) - }) - .collect::>()?; - Ok(field - .to_owned() - .with_name(name) - .with_data_type(DataType::Struct(children))) - } - DataType::List(inner) => { - let renamed_inner = rename_field( - inner.as_ref(), - dfs_names, - 0, - name_idx, - /*rename_self=*/ false, - )?; - Ok(field - .to_owned() - .with_data_type(DataType::List(FieldRef::new(renamed_inner))) - .with_name(name)) - } - DataType::LargeList(inner) => { - let renamed_inner = rename_field( - inner.as_ref(), - dfs_names, - 0, - name_idx, - /*rename_self= */ false, - )?; - Ok(field - .to_owned() - .with_data_type(DataType::LargeList(FieldRef::new(renamed_inner))) - .with_name(name)) - } - _ => Ok(field.to_owned().with_name(name)), - } -} - -/// Produce a version of the given schema with names matching the given list of names. -/// Substrait doesn't deal with column (incl. nested struct field) names within the schema, -/// but it does give us the list of expected names at the end of the plan, so we use this -/// to rename the schema to match the expected names. -fn make_renamed_schema( - schema: &DFSchemaRef, - dfs_names: &Vec, -) -> Result { - let mut name_idx = 0; - - let (qualifiers, fields): (_, Vec) = schema - .iter() - .enumerate() - .map(|(field_idx, (q, f))| { - let renamed_f = rename_field( - f.as_ref(), - dfs_names, - field_idx, - &mut name_idx, - /*rename_self=*/ true, - )?; - Ok((q.cloned(), renamed_f)) - }) - .collect::>>()? - .into_iter() - .unzip(); - - if name_idx != dfs_names.len() { - return substrait_err!( - "Names list must match exactly to nested schema, but found {} uses for {} names", - name_idx, - dfs_names.len()); - } - - DFSchema::from_field_specific_qualified_schema( - qualifiers, - &Arc::new(Schema::new(fields)), - ) -} - -#[async_recursion] -pub async fn from_project_rel( - consumer: &impl SubstraitConsumer, - p: &ProjectRel, -) -> Result { - if let Some(input) = p.input.as_ref() { - let input = consumer.consume_rel(input).await?; - let original_schema = Arc::clone(input.schema()); - - // Ensure that all expressions have a unique display name, so that - // validate_unique_names does not fail when constructing the project. - let mut name_tracker = NameTracker::new(); - - // By default, a Substrait Project emits all inputs fields followed by all expressions. - // We build the explicit expressions first, and then the input expressions to avoid - // adding aliases to the explicit expressions (as part of ensuring unique names). - // - // This is helpful for plan visualization and tests, because when DataFusion produces - // Substrait Projects it adds an output mapping that excludes all input columns - // leaving only explicit expressions. - - let mut explicit_exprs: Vec = vec![]; - // For WindowFunctions, we need to wrap them in a Window relation. If there are duplicates, - // we can do the window'ing only once, then the project will duplicate the result. - // Order here doesn't matter since LPB::window_plan sorts the expressions. - let mut window_exprs: HashSet = HashSet::new(); - for expr in &p.expressions { - let e = consumer - .consume_expression(expr, input.clone().schema()) - .await?; - // if the expression is WindowFunction, wrap in a Window relation - if let Expr::WindowFunction(_) = &e { - // Adding the same expression here and in the project below - // works because the project's builder uses columnize_expr(..) - // to transform it into a column reference - window_exprs.insert(e.clone()); - } - explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?); - } - - let input = if !window_exprs.is_empty() { - LogicalPlanBuilder::window_plan(input, window_exprs)? - } else { - input - }; - - let mut final_exprs: Vec = vec![]; - for index in 0..original_schema.fields().len() { - let e = Expr::Column(Column::from(original_schema.qualified_field(index))); - final_exprs.push(name_tracker.get_uniquely_named_expr(e)?); - } - final_exprs.append(&mut explicit_exprs); - project(input, final_exprs) - } else { - not_impl_err!("Projection without an input is not supported") - } -} - -#[async_recursion] -pub async fn from_filter_rel( - consumer: &impl SubstraitConsumer, - filter: &FilterRel, -) -> Result { - if let Some(input) = filter.input.as_ref() { - let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); - if let Some(condition) = filter.condition.as_ref() { - let expr = consumer - .consume_expression(condition, input.schema()) - .await?; - input.filter(expr)?.build() - } else { - not_impl_err!("Filter without an condition is not valid") - } - } else { - not_impl_err!("Filter without an input is not valid") - } -} - -#[async_recursion] -pub async fn from_fetch_rel( - consumer: &impl SubstraitConsumer, - fetch: &FetchRel, -) -> Result { - if let Some(input) = fetch.input.as_ref() { - let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); - let empty_schema = DFSchemaRef::new(DFSchema::empty()); - let offset = match &fetch.offset_mode { - Some(fetch_rel::OffsetMode::Offset(offset)) => Some(lit(*offset)), - Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => { - Some(consumer.consume_expression(expr, &empty_schema).await?) - } - None => None, - }; - let count = match &fetch.count_mode { - Some(fetch_rel::CountMode::Count(count)) => { - // -1 means that ALL records should be returned, equivalent to None - (*count != -1).then(|| lit(*count)) - } - Some(fetch_rel::CountMode::CountExpr(expr)) => { - Some(consumer.consume_expression(expr, &empty_schema).await?) - } - None => None, - }; - input.limit_by_expr(offset, count)?.build() - } else { - not_impl_err!("Fetch without an input is not valid") - } -} - -pub async fn from_sort_rel( - consumer: &impl SubstraitConsumer, - sort: &SortRel, -) -> Result { - if let Some(input) = sort.input.as_ref() { - let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); - let sorts = from_substrait_sorts(consumer, &sort.sorts, input.schema()).await?; - input.sort(sorts)?.build() - } else { - not_impl_err!("Sort without an input is not valid") - } -} - -pub async fn from_aggregate_rel( - consumer: &impl SubstraitConsumer, - agg: &AggregateRel, -) -> Result { - if let Some(input) = agg.input.as_ref() { - let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); - let mut ref_group_exprs = vec![]; - - for e in &agg.grouping_expressions { - let x = consumer.consume_expression(e, input.schema()).await?; - ref_group_exprs.push(x); - } - - let mut group_exprs = vec![]; - let mut aggr_exprs = vec![]; - - match agg.groupings.len() { - 1 => { - group_exprs.extend_from_slice( - &from_substrait_grouping( - consumer, - &agg.groupings[0], - &ref_group_exprs, - input.schema(), - ) - .await?, - ); - } - _ => { - let mut grouping_sets = vec![]; - for grouping in &agg.groupings { - let grouping_set = from_substrait_grouping( - consumer, - grouping, - &ref_group_exprs, - input.schema(), - ) - .await?; - grouping_sets.push(grouping_set); - } - // Single-element grouping expression of type Expr::GroupingSet. - // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when - // parsed by the producer and consumer, since Substrait does not have a type dedicated - // to ROLLUP. Only vector of Groupings (grouping sets) is available. - group_exprs - .push(Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets))); - } - }; - - for m in &agg.measures { - let filter = match &m.filter { - Some(fil) => Some(Box::new( - consumer.consume_expression(fil, input.schema()).await?, - )), - None => None, - }; - let agg_func = match &m.measure { - Some(f) => { - let distinct = match f.invocation { - _ if f.invocation == AggregationInvocation::Distinct as i32 => { - true - } - _ if f.invocation == AggregationInvocation::All as i32 => false, - _ => false, - }; - let order_by = if !f.sorts.is_empty() { - Some( - from_substrait_sorts(consumer, &f.sorts, input.schema()) - .await?, - ) - } else { - None - }; - - from_substrait_agg_func( - consumer, - f, - input.schema(), - filter, - order_by, - distinct, - ) - .await - } - None => { - not_impl_err!("Aggregate without aggregate function is not supported") - } - }; - aggr_exprs.push(agg_func?.as_ref().clone()); - } - input.aggregate(group_exprs, aggr_exprs)?.build() - } else { - not_impl_err!("Aggregate without an input is not valid") - } -} - -pub async fn from_join_rel( - consumer: &impl SubstraitConsumer, - join: &JoinRel, -) -> Result { - if join.post_join_filter.is_some() { - return not_impl_err!("JoinRel with post_join_filter is not yet supported"); - } - - let left: LogicalPlanBuilder = LogicalPlanBuilder::from( - consumer.consume_rel(join.left.as_ref().unwrap()).await?, - ); - let right = LogicalPlanBuilder::from( - consumer.consume_rel(join.right.as_ref().unwrap()).await?, - ); - let (left, right) = requalify_sides_if_needed(left, right)?; - - let join_type = from_substrait_jointype(join.r#type)?; - // The join condition expression needs full input schema and not the output schema from join since we lose columns from - // certain join types such as semi and anti joins - let in_join_schema = left.schema().join(right.schema())?; - - // If join expression exists, parse the `on` condition expression, build join and return - // Otherwise, build join with only the filter, without join keys - match &join.expression.as_ref() { - Some(expr) => { - let on = consumer.consume_expression(expr, &in_join_schema).await?; - // The join expression can contain both equal and non-equal ops. - // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. - // So we extract each part as follows: - // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector - // - Otherwise we add the expression to join_filter (use conjunction if filter already exists) - let (join_ons, nulls_equal_nulls, join_filter) = - split_eq_and_noneq_join_predicate_with_nulls_equality(&on); - let (left_cols, right_cols): (Vec<_>, Vec<_>) = - itertools::multiunzip(join_ons); - left.join_detailed( - right.build()?, - join_type, - (left_cols, right_cols), - join_filter, - nulls_equal_nulls, - )? - .build() - } - None => { - let on: Vec = vec![]; - left.join_detailed(right.build()?, join_type, (on.clone(), on), None, false)? - .build() - } - } -} - -pub async fn from_cross_rel( - consumer: &impl SubstraitConsumer, - cross: &CrossRel, -) -> Result { - let left = LogicalPlanBuilder::from( - consumer.consume_rel(cross.left.as_ref().unwrap()).await?, - ); - let right = LogicalPlanBuilder::from( - consumer.consume_rel(cross.right.as_ref().unwrap()).await?, - ); - let (left, right) = requalify_sides_if_needed(left, right)?; - left.cross_join(right.build()?)?.build() -} - -#[allow(deprecated)] -pub async fn from_read_rel( - consumer: &impl SubstraitConsumer, - read: &ReadRel, -) -> Result { - async fn read_with_schema( - consumer: &impl SubstraitConsumer, - table_ref: TableReference, - schema: DFSchema, - projection: &Option, - filter: &Option>, - ) -> Result { - let schema = schema.replace_qualifier(table_ref.clone()); - - let filters = if let Some(f) = filter { - let filter_expr = consumer.consume_expression(f, &schema).await?; - split_conjunction_owned(filter_expr) - } else { - vec![] - }; - - let plan = { - let provider = match consumer.resolve_table_ref(&table_ref).await? { - Some(ref provider) => Arc::clone(provider), - _ => return plan_err!("No table named '{table_ref}'"), - }; - - LogicalPlanBuilder::scan_with_filters( - table_ref, - provider_as_source(Arc::clone(&provider)), - None, - filters, - )? - .build()? - }; - - ensure_schema_compatibility(plan.schema(), schema.clone())?; - - let schema = apply_masking(schema, projection)?; - - apply_projection(plan, schema) - } - - let named_struct = read.base_schema.as_ref().ok_or_else(|| { - substrait_datafusion_err!("No base schema provided for Read Relation") - })?; - - let substrait_schema = from_substrait_named_struct(consumer, named_struct)?; - - match &read.read_type { - Some(ReadType::NamedTable(nt)) => { - let table_reference = match nt.names.len() { - 0 => { - return plan_err!("No table name found in NamedTable"); - } - 1 => TableReference::Bare { - table: nt.names[0].clone().into(), - }, - 2 => TableReference::Partial { - schema: nt.names[0].clone().into(), - table: nt.names[1].clone().into(), - }, - _ => TableReference::Full { - catalog: nt.names[0].clone().into(), - schema: nt.names[1].clone().into(), - table: nt.names[2].clone().into(), - }, - }; - - read_with_schema( - consumer, - table_reference, - substrait_schema, - &read.projection, - &read.filter, - ) - .await - } - Some(ReadType::VirtualTable(vt)) => { - if vt.values.is_empty() { - return Ok(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: DFSchemaRef::new(substrait_schema), - })); - } - - let values = vt - .values - .iter() - .map(|row| { - let mut name_idx = 0; - let lits = row - .fields - .iter() - .map(|lit| { - name_idx += 1; // top-level names are provided through schema - Ok(Expr::Literal(from_substrait_literal( - consumer, - lit, - &named_struct.names, - &mut name_idx, - )?)) - }) - .collect::>()?; - if name_idx != named_struct.names.len() { - return substrait_err!( - "Names list must match exactly to nested schema, but found {} uses for {} names", - name_idx, - named_struct.names.len() - ); - } - Ok(lits) - }) - .collect::>()?; - - Ok(LogicalPlan::Values(Values { - schema: DFSchemaRef::new(substrait_schema), - values, - })) - } - Some(ReadType::LocalFiles(lf)) => { - fn extract_filename(name: &str) -> Option { - let corrected_url = - if name.starts_with("file://") && !name.starts_with("file:///") { - name.replacen("file://", "file:///", 1) - } else { - name.to_string() - }; - - Url::parse(&corrected_url).ok().and_then(|url| { - let path = url.path(); - std::path::Path::new(path) - .file_name() - .map(|filename| filename.to_string_lossy().to_string()) - }) - } - - // we could use the file name to check the original table provider - // TODO: currently does not support multiple local files - let filename: Option = - lf.items.first().and_then(|x| match x.path_type.as_ref() { - Some(UriFile(name)) => extract_filename(name), - _ => None, - }); - - if lf.items.len() > 1 || filename.is_none() { - return not_impl_err!("Only single file reads are supported"); - } - let name = filename.unwrap(); - // directly use unwrap here since we could determine it is a valid one - let table_reference = TableReference::Bare { table: name.into() }; - - read_with_schema( - consumer, - table_reference, - substrait_schema, - &read.projection, - &read.filter, - ) - .await - } - _ => { - not_impl_err!("Unsupported ReadType: {:?}", read.read_type) - } - } -} - -pub async fn from_set_rel( - consumer: &impl SubstraitConsumer, - set: &SetRel, -) -> Result { - if set.inputs.len() < 2 { - substrait_err!("Set operation requires at least two inputs") - } else { - match set.op() { - SetOp::UnionAll => union_rels(consumer, &set.inputs, true).await, - SetOp::UnionDistinct => union_rels(consumer, &set.inputs, false).await, - SetOp::IntersectionPrimary => LogicalPlanBuilder::intersect( - consumer.consume_rel(&set.inputs[0]).await?, - union_rels(consumer, &set.inputs[1..], true).await?, - false, - ), - SetOp::IntersectionMultiset => { - intersect_rels(consumer, &set.inputs, false).await - } - SetOp::IntersectionMultisetAll => { - intersect_rels(consumer, &set.inputs, true).await - } - SetOp::MinusPrimary => except_rels(consumer, &set.inputs, false).await, - SetOp::MinusPrimaryAll => except_rels(consumer, &set.inputs, true).await, - set_op => not_impl_err!("Unsupported set operator: {set_op:?}"), - } - } -} - -pub async fn from_exchange_rel( - consumer: &impl SubstraitConsumer, - exchange: &ExchangeRel, -) -> Result { - let Some(input) = exchange.input.as_ref() else { - return substrait_err!("Unexpected empty input in ExchangeRel"); - }; - let input = Arc::new(consumer.consume_rel(input).await?); - - let Some(exchange_kind) = &exchange.exchange_kind else { - return substrait_err!("Unexpected empty input in ExchangeRel"); - }; - - // ref: https://substrait.io/relations/physical_relations/#exchange-types - let partitioning_scheme = match exchange_kind { - ExchangeKind::ScatterByFields(scatter_fields) => { - let mut partition_columns = vec![]; - let input_schema = input.schema(); - for field_ref in &scatter_fields.fields { - let column = from_substrait_field_reference(field_ref, input_schema)?; - partition_columns.push(column); - } - Partitioning::Hash(partition_columns, exchange.partition_count as usize) - } - ExchangeKind::RoundRobin(_) => { - Partitioning::RoundRobinBatch(exchange.partition_count as usize) - } - ExchangeKind::SingleTarget(_) - | ExchangeKind::MultiTarget(_) - | ExchangeKind::Broadcast(_) => { - return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); - } - }; - Ok(LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - })) -} - -fn retrieve_rel_common(rel: &Rel) -> Option<&RelCommon> { - match rel.rel_type.as_ref() { - None => None, - Some(rt) => match rt { - RelType::Read(r) => r.common.as_ref(), - RelType::Filter(f) => f.common.as_ref(), - RelType::Fetch(f) => f.common.as_ref(), - RelType::Aggregate(a) => a.common.as_ref(), - RelType::Sort(s) => s.common.as_ref(), - RelType::Join(j) => j.common.as_ref(), - RelType::Project(p) => p.common.as_ref(), - RelType::Set(s) => s.common.as_ref(), - RelType::ExtensionSingle(e) => e.common.as_ref(), - RelType::ExtensionMulti(e) => e.common.as_ref(), - RelType::ExtensionLeaf(e) => e.common.as_ref(), - RelType::Cross(c) => c.common.as_ref(), - RelType::Reference(_) => None, - RelType::Write(w) => w.common.as_ref(), - RelType::Ddl(d) => d.common.as_ref(), - RelType::HashJoin(j) => j.common.as_ref(), - RelType::MergeJoin(j) => j.common.as_ref(), - RelType::NestedLoopJoin(j) => j.common.as_ref(), - RelType::Window(w) => w.common.as_ref(), - RelType::Exchange(e) => e.common.as_ref(), - RelType::Expand(e) => e.common.as_ref(), - RelType::Update(_) => None, - }, - } -} - -fn retrieve_emit_kind(rel_common: Option<&RelCommon>) -> EmitKind { - // the default EmitKind is Direct if it is not set explicitly - let default = EmitKind::Direct(rel_common::Direct {}); - rel_common - .and_then(|rc| rc.emit_kind.as_ref()) - .map_or(default, |ek| ek.clone()) -} - -fn contains_volatile_expr(proj: &Projection) -> bool { - proj.expr.iter().any(|e| e.is_volatile()) -} - -fn apply_emit_kind( - rel_common: Option<&RelCommon>, - plan: LogicalPlan, -) -> Result { - match retrieve_emit_kind(rel_common) { - EmitKind::Direct(_) => Ok(plan), - EmitKind::Emit(Emit { output_mapping }) => { - // It is valid to reference the same field multiple times in the Emit - // In this case, we need to provide unique names to avoid collisions - let mut name_tracker = NameTracker::new(); - match plan { - // To avoid adding a projection on top of a projection, we apply special case - // handling to flatten Substrait Emits. This is only applicable if none of the - // expressions in the projection are volatile. This is to avoid issues like - // converting a single call of the random() function into multiple calls due to - // duplicate fields in the output_mapping. - LogicalPlan::Projection(proj) if !contains_volatile_expr(&proj) => { - let mut exprs: Vec = vec![]; - for field in output_mapping { - let expr = proj.expr - .get(field as usize) - .ok_or_else(|| substrait_datafusion_err!( - "Emit output field {} cannot be resolved in input schema {}", - field, proj.input.schema() - ))?; - exprs.push(name_tracker.get_uniquely_named_expr(expr.clone())?); - } - - let input = Arc::unwrap_or_clone(proj.input); - project(input, exprs) - } - // Otherwise we just handle the output_mapping as a projection - _ => { - let input_schema = plan.schema(); - - let mut exprs: Vec = vec![]; - for index in output_mapping.into_iter() { - let column = Expr::Column(Column::from( - input_schema.qualified_field(index as usize), - )); - let expr = name_tracker.get_uniquely_named_expr(column)?; - exprs.push(expr); - } - - project(plan, exprs) - } - } - } - } -} - -struct NameTracker { - seen_names: HashSet, -} - -enum NameTrackerStatus { - NeverSeen, - SeenBefore, -} - -impl NameTracker { - fn new() -> Self { - NameTracker { - seen_names: HashSet::default(), - } - } - fn get_unique_name(&mut self, name: String) -> (String, NameTrackerStatus) { - match self.seen_names.insert(name.clone()) { - true => (name, NameTrackerStatus::NeverSeen), - false => { - let mut counter = 0; - loop { - let candidate_name = format!("{}__temp__{}", name, counter); - if self.seen_names.insert(candidate_name.clone()) { - return (candidate_name, NameTrackerStatus::SeenBefore); - } - counter += 1; - } - } - } - } - - fn get_uniquely_named_expr(&mut self, expr: Expr) -> Result { - match self.get_unique_name(expr.name_for_alias()?) { - (_, NameTrackerStatus::NeverSeen) => Ok(expr), - (name, NameTrackerStatus::SeenBefore) => Ok(expr.alias(name)), - } - } -} - -/// Ensures that the given Substrait schema is compatible with the schema as given by DataFusion -/// -/// This means: -/// 1. All fields present in the Substrait schema are present in the DataFusion schema. The -/// DataFusion schema may have MORE fields, but not the other way around. -/// 2. All fields are compatible. See [`ensure_field_compatibility`] for details -fn ensure_schema_compatibility( - table_schema: &DFSchema, - substrait_schema: DFSchema, -) -> Result<()> { - substrait_schema - .strip_qualifiers() - .fields() - .iter() - .try_for_each(|substrait_field| { - let df_field = - table_schema.field_with_unqualified_name(substrait_field.name())?; - ensure_field_compatibility(df_field, substrait_field) - }) -} - -/// This function returns a DataFrame with fields adjusted if necessary in the event that the -/// Substrait schema is a subset of the DataFusion schema. -fn apply_projection( - plan: LogicalPlan, - substrait_schema: DFSchema, -) -> Result { - let df_schema = plan.schema(); - - if df_schema.logically_equivalent_names_and_types(&substrait_schema) { - return Ok(plan); - } - - let df_schema = df_schema.to_owned(); - - match plan { - LogicalPlan::TableScan(mut scan) => { - let column_indices: Vec = substrait_schema - .strip_qualifiers() - .fields() - .iter() - .map(|substrait_field| { - Ok(df_schema - .index_of_column_by_name(None, substrait_field.name().as_str()) - .unwrap()) - }) - .collect::>()?; - - let fields = column_indices - .iter() - .map(|i| df_schema.qualified_field(*i)) - .map(|(qualifier, field)| (qualifier.cloned(), Arc::new(field.clone()))) - .collect(); - - scan.projected_schema = DFSchemaRef::new(DFSchema::new_with_metadata( - fields, - df_schema.metadata().clone(), - )?); - scan.projection = Some(column_indices); - - Ok(LogicalPlan::TableScan(scan)) - } - _ => plan_err!("DataFrame passed to apply_projection must be a TableScan"), - } -} - -/// Ensures that the given Substrait field is compatible with the given DataFusion field -/// -/// A field is compatible between Substrait and DataFusion if: -/// 1. They have logically equivalent types. -/// 2. They have the same nullability OR the Substrait field is nullable and the DataFusion fields -/// is not nullable. -/// -/// If a Substrait field is not nullable, the Substrait plan may be built around assuming it is not -/// nullable. As such if DataFusion has that field as nullable the plan should be rejected. -fn ensure_field_compatibility( - datafusion_field: &Field, - substrait_field: &Field, -) -> Result<()> { - if !DFSchema::datatype_is_logically_equal( - datafusion_field.data_type(), - substrait_field.data_type(), - ) { - return substrait_err!( - "Field '{}' in Substrait schema has a different type ({}) than the corresponding field in the table schema ({}).", - substrait_field.name(), - substrait_field.data_type(), - datafusion_field.data_type() - ); - } - - if !compatible_nullabilities( - datafusion_field.is_nullable(), - substrait_field.is_nullable(), - ) { - // TODO: from_substrait_struct_type needs to be updated to set the nullability correctly. It defaults to true for now. - return substrait_err!( - "Field '{}' is nullable in the DataFusion schema but not nullable in the Substrait schema.", - substrait_field.name() - ); - } - Ok(()) -} - -/// Returns true if the DataFusion and Substrait nullabilities are compatible, false otherwise -fn compatible_nullabilities( - datafusion_nullability: bool, - substrait_nullability: bool, -) -> bool { - // DataFusion and Substrait have the same nullability - (datafusion_nullability == substrait_nullability) - // DataFusion is not nullable and Substrait is nullable - || (!datafusion_nullability && substrait_nullability) -} - -/// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise -/// conflict with the columns from the other. -/// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. For -/// Substrait the names don't matter since it only refers to columns by indices, however DataFusion -/// requires columns to be uniquely identifiable, in some places (see e.g. DFSchema::check_names). -fn requalify_sides_if_needed( - left: LogicalPlanBuilder, - right: LogicalPlanBuilder, -) -> Result<(LogicalPlanBuilder, LogicalPlanBuilder)> { - let left_cols = left.schema().columns(); - let right_cols = right.schema().columns(); - if left_cols.iter().any(|l| { - right_cols.iter().any(|r| { - l == r || (l.name == r.name && (l.relation.is_none() || r.relation.is_none())) - }) - }) { - // These names have no connection to the original plan, but they'll make the columns - // (mostly) unique. There may be cases where this still causes duplicates, if either left - // or right side itself contains duplicate names with different qualifiers. - Ok(( - left.alias(TableReference::bare("left"))?, - right.alias(TableReference::bare("right"))?, - )) - } else { - Ok((left, right)) - } -} - -fn from_substrait_jointype(join_type: i32) -> Result { - if let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) { - match substrait_join_type { - join_rel::JoinType::Inner => Ok(JoinType::Inner), - join_rel::JoinType::Left => Ok(JoinType::Left), - join_rel::JoinType::Right => Ok(JoinType::Right), - join_rel::JoinType::Outer => Ok(JoinType::Full), - join_rel::JoinType::LeftAnti => Ok(JoinType::LeftAnti), - join_rel::JoinType::LeftSemi => Ok(JoinType::LeftSemi), - join_rel::JoinType::LeftMark => Ok(JoinType::LeftMark), - _ => plan_err!("unsupported join type {substrait_join_type:?}"), - } - } else { - plan_err!("invalid join type variant {join_type:?}") - } -} - -/// Convert Substrait Sorts to DataFusion Exprs -pub async fn from_substrait_sorts( - consumer: &impl SubstraitConsumer, - substrait_sorts: &Vec, - input_schema: &DFSchema, -) -> Result> { - let mut sorts: Vec = vec![]; - for s in substrait_sorts { - let expr = consumer - .consume_expression(s.expr.as_ref().unwrap(), input_schema) - .await?; - let asc_nullfirst = match &s.sort_kind { - Some(k) => match k { - Direction(d) => { - let Ok(direction) = SortDirection::try_from(*d) else { - return not_impl_err!( - "Unsupported Substrait SortDirection value {d}" - ); - }; - - match direction { - SortDirection::AscNullsFirst => Ok((true, true)), - SortDirection::AscNullsLast => Ok((true, false)), - SortDirection::DescNullsFirst => Ok((false, true)), - SortDirection::DescNullsLast => Ok((false, false)), - SortDirection::Clustered => not_impl_err!( - "Sort with direction clustered is not yet supported" - ), - SortDirection::Unspecified => { - not_impl_err!("Unspecified sort direction is invalid") - } - } - } - ComparisonFunctionReference(_) => not_impl_err!( - "Sort using comparison function reference is not supported" - ), - }, - None => not_impl_err!("Sort without sort kind is invalid"), - }; - let (asc, nulls_first) = asc_nullfirst.unwrap(); - sorts.push(Sort { - expr, - asc, - nulls_first, - }); - } - Ok(sorts) -} - -/// Convert Substrait Expressions to DataFusion Exprs -pub async fn from_substrait_rex_vec( - consumer: &impl SubstraitConsumer, - exprs: &Vec, - input_schema: &DFSchema, -) -> Result> { - let mut expressions: Vec = vec![]; - for expr in exprs { - let expression = consumer.consume_expression(expr, input_schema).await?; - expressions.push(expression); - } - Ok(expressions) -} - -/// Convert Substrait FunctionArguments to DataFusion Exprs -pub async fn from_substrait_func_args( - consumer: &impl SubstraitConsumer, - arguments: &Vec, - input_schema: &DFSchema, -) -> Result> { - let mut args: Vec = vec![]; - for arg in arguments { - let arg_expr = match &arg.arg_type { - Some(ArgType::Value(e)) => consumer.consume_expression(e, input_schema).await, - _ => not_impl_err!("Function argument non-Value type not supported"), - }; - args.push(arg_expr?); - } - Ok(args) -} - -/// Convert Substrait AggregateFunction to DataFusion Expr -pub async fn from_substrait_agg_func( - consumer: &impl SubstraitConsumer, - f: &AggregateFunction, - input_schema: &DFSchema, - filter: Option>, - order_by: Option>, - distinct: bool, -) -> Result> { - let Some(fn_signature) = consumer - .get_extensions() - .functions - .get(&f.function_reference) - else { - return plan_err!( - "Aggregate function not registered: function anchor = {:?}", - f.function_reference - ); - }; - - let fn_name = substrait_fun_name(fn_signature); - let udaf = consumer.get_function_registry().udaf(fn_name); - let udaf = udaf.map_err(|_| { - not_impl_datafusion_err!( - "Aggregate function {} is not supported: function anchor = {:?}", - fn_signature, - f.function_reference - ) - })?; - - let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; - - // Datafusion does not support aggregate functions with no arguments, so - // we inject a dummy argument that does not affect the query, but allows - // us to bypass this limitation. - let args = if udaf.name() == "count" && args.is_empty() { - vec![Expr::Literal(ScalarValue::Int64(Some(1)))] - } else { - args - }; - - Ok(Arc::new(Expr::AggregateFunction( - expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None), - ))) -} - -/// Convert Substrait Rex to DataFusion Expr -pub async fn from_substrait_rex( - consumer: &impl SubstraitConsumer, - expression: &Expression, - input_schema: &DFSchema, -) -> Result { - match &expression.rex_type { - Some(t) => match t { - RexType::Literal(expr) => consumer.consume_literal(expr).await, - RexType::Selection(expr) => { - consumer.consume_field_reference(expr, input_schema).await - } - RexType::ScalarFunction(expr) => { - consumer.consume_scalar_function(expr, input_schema).await - } - RexType::WindowFunction(expr) => { - consumer.consume_window_function(expr, input_schema).await - } - RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await, - RexType::SwitchExpression(expr) => { - consumer.consume_switch(expr, input_schema).await - } - RexType::SingularOrList(expr) => { - consumer.consume_singular_or_list(expr, input_schema).await - } - - RexType::MultiOrList(expr) => { - consumer.consume_multi_or_list(expr, input_schema).await - } - - RexType::Cast(expr) => { - consumer.consume_cast(expr.as_ref(), input_schema).await - } - - RexType::Subquery(expr) => { - consumer.consume_subquery(expr.as_ref(), input_schema).await - } - RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await, - RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await, - RexType::DynamicParameter(expr) => { - consumer.consume_dynamic_parameter(expr, input_schema).await - } - }, - None => substrait_err!("Expression must set rex_type: {:?}", expression), - } -} - -pub async fn from_singular_or_list( - consumer: &impl SubstraitConsumer, - expr: &SingularOrList, - input_schema: &DFSchema, -) -> Result { - let substrait_expr = expr.value.as_ref().unwrap(); - let substrait_list = expr.options.as_ref(); - Ok(Expr::InList(InList { - expr: Box::new( - consumer - .consume_expression(substrait_expr, input_schema) - .await?, - ), - list: from_substrait_rex_vec(consumer, substrait_list, input_schema).await?, - negated: false, - })) -} - -pub async fn from_field_reference( - _consumer: &impl SubstraitConsumer, - field_ref: &FieldReference, - input_schema: &DFSchema, -) -> Result { - from_substrait_field_reference(field_ref, input_schema) -} - -pub async fn from_if_then( - consumer: &impl SubstraitConsumer, - if_then: &IfThen, - input_schema: &DFSchema, -) -> Result { - // Parse `ifs` - // If the first element does not have a `then` part, then we can assume it's a base expression - let mut when_then_expr: Vec<(Box, Box)> = vec![]; - let mut expr = None; - for (i, if_expr) in if_then.ifs.iter().enumerate() { - if i == 0 { - // Check if the first element is type base expression - if if_expr.then.is_none() { - expr = Some(Box::new( - consumer - .consume_expression(if_expr.r#if.as_ref().unwrap(), input_schema) - .await?, - )); - continue; - } - } - when_then_expr.push(( - Box::new( - consumer - .consume_expression(if_expr.r#if.as_ref().unwrap(), input_schema) - .await?, - ), - Box::new( - consumer - .consume_expression(if_expr.then.as_ref().unwrap(), input_schema) - .await?, - ), - )); - } - // Parse `else` - let else_expr = match &if_then.r#else { - Some(e) => Some(Box::new( - consumer.consume_expression(e, input_schema).await?, - )), - None => None, - }; - Ok(Expr::Case(Case { - expr, - when_then_expr, - else_expr, - })) -} - -pub async fn from_scalar_function( - consumer: &impl SubstraitConsumer, - f: &ScalarFunction, - input_schema: &DFSchema, -) -> Result { - let Some(fn_signature) = consumer - .get_extensions() - .functions - .get(&f.function_reference) - else { - return plan_err!( - "Scalar function not found: function reference = {:?}", - f.function_reference - ); - }; - let fn_name = substrait_fun_name(fn_signature); - let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; - - // try to first match the requested function into registered udfs, then built-in ops - // and finally built-in expressions - if let Ok(func) = consumer.get_function_registry().udf(fn_name) { - Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( - func.to_owned(), - args, - ))) - } else if let Some(op) = name_to_op(fn_name) { - if f.arguments.len() < 2 { - return not_impl_err!( - "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", - f.arguments.len() - ); - } - // Some expressions are binary in DataFusion but take in a variadic number of args in Substrait. - // In those cases we iterate through all the arguments, applying the binary expression against them all - let combined_expr = args - .into_iter() - .fold(None, |combined_expr: Option, arg: Expr| { - Some(match combined_expr { - Some(expr) => Expr::BinaryExpr(BinaryExpr { - left: Box::new(expr), - op, - right: Box::new(arg), - }), - None => arg, - }) - }) - .unwrap(); - - Ok(combined_expr) - } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { - builder.build(consumer, f, input_schema).await - } else { - not_impl_err!("Unsupported function name: {fn_name:?}") - } -} - -pub async fn from_literal( - consumer: &impl SubstraitConsumer, - expr: &Literal, -) -> Result { - let scalar_value = from_substrait_literal_without_names(consumer, expr)?; - Ok(Expr::Literal(scalar_value)) -} - -pub async fn from_cast( - consumer: &impl SubstraitConsumer, - cast: &substrait_expression::Cast, - input_schema: &DFSchema, -) -> Result { - match cast.r#type.as_ref() { - Some(output_type) => { - let input_expr = Box::new( - consumer - .consume_expression( - cast.input.as_ref().unwrap().as_ref(), - input_schema, - ) - .await?, - ); - let data_type = from_substrait_type_without_names(consumer, output_type)?; - if cast.failure_behavior() == ReturnNull { - Ok(Expr::TryCast(TryCast::new(input_expr, data_type))) - } else { - Ok(Expr::Cast(Cast::new(input_expr, data_type))) - } - } - None => substrait_err!("Cast expression without output type is not allowed"), - } -} - -pub async fn from_window_function( - consumer: &impl SubstraitConsumer, - window: &WindowFunction, - input_schema: &DFSchema, -) -> Result { - let Some(fn_signature) = consumer - .get_extensions() - .functions - .get(&window.function_reference) - else { - return plan_err!( - "Window function not found: function reference = {:?}", - window.function_reference - ); - }; - let fn_name = substrait_fun_name(fn_signature); - - // check udwf first, then udaf, then built-in window and aggregate functions - let fun = if let Ok(udwf) = consumer.get_function_registry().udwf(fn_name) { - Ok(WindowFunctionDefinition::WindowUDF(udwf)) - } else if let Ok(udaf) = consumer.get_function_registry().udaf(fn_name) { - Ok(WindowFunctionDefinition::AggregateUDF(udaf)) - } else { - not_impl_err!( - "Window function {} is not supported: function anchor = {:?}", - fn_name, - window.function_reference - ) - }?; - - let mut order_by = - from_substrait_sorts(consumer, &window.sorts, input_schema).await?; - - let bound_units = match BoundsType::try_from(window.bounds_type).map_err(|e| { - plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) - })? { - BoundsType::Rows => WindowFrameUnits::Rows, - BoundsType::Range => WindowFrameUnits::Range, - BoundsType::Unspecified => { - // If the plan does not specify the bounds type, then we use a simple logic to determine the units - // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary - // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row - if order_by.is_empty() { - WindowFrameUnits::Rows - } else { - WindowFrameUnits::Range - } - } - }; - let window_frame = datafusion::logical_expr::WindowFrame::new_bounds( - bound_units, - from_substrait_bound(&window.lower_bound, true)?, - from_substrait_bound(&window.upper_bound, false)?, - ); - - window_frame.regularize_order_bys(&mut order_by)?; - - // Datafusion does not support aggregate functions with no arguments, so - // we inject a dummy argument that does not affect the query, but allows - // us to bypass this limitation. - let args = if fun.name() == "count" && window.arguments.is_empty() { - vec![Expr::Literal(ScalarValue::Int64(Some(1)))] - } else { - from_substrait_func_args(consumer, &window.arguments, input_schema).await? - }; - - Ok(Expr::WindowFunction(expr::WindowFunction { - fun, - params: WindowFunctionParams { - args, - partition_by: from_substrait_rex_vec( - consumer, - &window.partitions, - input_schema, - ) - .await?, - order_by, - window_frame, - null_treatment: None, - }, - })) -} - -pub async fn from_subquery( - consumer: &impl SubstraitConsumer, - subquery: &substrait_expression::Subquery, - input_schema: &DFSchema, -) -> Result { - match &subquery.subquery_type { - Some(subquery_type) => match subquery_type { - SubqueryType::InPredicate(in_predicate) => { - if in_predicate.needles.len() != 1 { - substrait_err!("InPredicate Subquery type must have exactly one Needle expression") - } else { - let needle_expr = &in_predicate.needles[0]; - let haystack_expr = &in_predicate.haystack; - if let Some(haystack_expr) = haystack_expr { - let haystack_expr = consumer.consume_rel(haystack_expr).await?; - let outer_refs = haystack_expr.all_out_ref_exprs(); - Ok(Expr::InSubquery(InSubquery { - expr: Box::new( - consumer - .consume_expression(needle_expr, input_schema) - .await?, - ), - subquery: Subquery { - subquery: Arc::new(haystack_expr), - outer_ref_columns: outer_refs, - spans: Spans::new(), - }, - negated: false, - })) - } else { - substrait_err!( - "InPredicate Subquery type must have a Haystack expression" - ) - } - } - } - SubqueryType::Scalar(query) => { - let plan = consumer - .consume_rel(&(query.input.clone()).unwrap_or_default()) - .await?; - let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(plan), - outer_ref_columns, - spans: Spans::new(), - })) - } - SubqueryType::SetPredicate(predicate) => { - match predicate.predicate_op() { - // exist - PredicateOp::Exists => { - let relation = &predicate.tuples; - let plan = consumer - .consume_rel(&relation.clone().unwrap_or_default()) - .await?; - let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Expr::Exists(Exists::new( - Subquery { - subquery: Arc::new(plan), - outer_ref_columns, - spans: Spans::new(), - }, - false, - ))) - } - other_type => substrait_err!( - "unimplemented type {:?} for set predicate", - other_type - ), - } - } - other_type => { - substrait_err!("Subquery type {:?} not implemented", other_type) - } - }, - None => { - substrait_err!("Subquery expression without SubqueryType is not allowed") - } - } -} - -pub(crate) fn from_substrait_type_without_names( - consumer: &impl SubstraitConsumer, - dt: &Type, -) -> Result { - from_substrait_type(consumer, dt, &[], &mut 0) -} - -fn from_substrait_type( - consumer: &impl SubstraitConsumer, - dt: &Type, - dfs_names: &[String], - name_idx: &mut usize, -) -> Result { - match &dt.kind { - Some(s_kind) => match s_kind { - r#type::Kind::Bool(_) => Ok(DataType::Boolean), - r#type::Kind::I8(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int8), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt8), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::I16(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int16), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt16), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::I32(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int32), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt32), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::I64(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int64), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt64), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::Fp32(_) => Ok(DataType::Float32), - r#type::Kind::Fp64(_) => Ok(DataType::Float64), - r#type::Kind::Timestamp(ts) => { - // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead - #[allow(deprecated)] - match ts.type_variation_reference { - TIMESTAMP_SECOND_TYPE_VARIATION_REF => { - Ok(DataType::Timestamp(TimeUnit::Second, None)) - } - TIMESTAMP_MILLI_TYPE_VARIATION_REF => { - Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) - } - TIMESTAMP_MICRO_TYPE_VARIATION_REF => { - Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) - } - TIMESTAMP_NANO_TYPE_VARIATION_REF => { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - } - } - r#type::Kind::PrecisionTimestamp(pts) => { - let unit = match pts.precision { - 0 => Ok(TimeUnit::Second), - 3 => Ok(TimeUnit::Millisecond), - 6 => Ok(TimeUnit::Microsecond), - 9 => Ok(TimeUnit::Nanosecond), - p => not_impl_err!( - "Unsupported Substrait precision {p} for PrecisionTimestamp" - ), - }?; - Ok(DataType::Timestamp(unit, None)) - } - r#type::Kind::PrecisionTimestampTz(pts) => { - let unit = match pts.precision { - 0 => Ok(TimeUnit::Second), - 3 => Ok(TimeUnit::Millisecond), - 6 => Ok(TimeUnit::Microsecond), - 9 => Ok(TimeUnit::Nanosecond), - p => not_impl_err!( - "Unsupported Substrait precision {p} for PrecisionTimestampTz" - ), - }?; - Ok(DataType::Timestamp(unit, Some(DEFAULT_TIMEZONE.into()))) - } - r#type::Kind::Date(date) => match date.type_variation_reference { - DATE_32_TYPE_VARIATION_REF => Ok(DataType::Date32), - DATE_64_TYPE_VARIATION_REF => Ok(DataType::Date64), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::Binary(binary) => match binary.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Binary), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeBinary), - VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::BinaryView), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::FixedBinary(fixed) => { - Ok(DataType::FixedSizeBinary(fixed.length)) - } - r#type::Kind::String(string) => match string.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeUtf8), - VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8View), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::List(list) => { - let inner_type = list.r#type.as_ref().ok_or_else(|| { - substrait_datafusion_err!("List type must have inner type") - })?; - let field = Arc::new(Field::new_list_field( - from_substrait_type(consumer, inner_type, dfs_names, name_idx)?, - // We ignore Substrait's nullability here to match to_substrait_literal - // which always creates nullable lists - true, - )); - match list.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::List(field)), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeList(field)), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - )?, - } - } - r#type::Kind::Map(map) => { - let key_type = map.key.as_ref().ok_or_else(|| { - substrait_datafusion_err!("Map type must have key type") - })?; - let value_type = map.value.as_ref().ok_or_else(|| { - substrait_datafusion_err!("Map type must have value type") - })?; - let key_field = Arc::new(Field::new( - "key", - from_substrait_type(consumer, key_type, dfs_names, name_idx)?, - false, - )); - let value_field = Arc::new(Field::new( - "value", - from_substrait_type(consumer, value_type, dfs_names, name_idx)?, - true, - )); - Ok(DataType::Map( - Arc::new(Field::new_struct( - "entries", - [key_field, value_field], - false, // The inner map field is always non-nullable (Arrow #1697), - )), - false, // whether keys are sorted - )) - } - r#type::Kind::Decimal(d) => match d.type_variation_reference { - DECIMAL_128_TYPE_VARIATION_REF => { - Ok(DataType::Decimal128(d.precision as u8, d.scale as i8)) - } - DECIMAL_256_TYPE_VARIATION_REF => { - Ok(DataType::Decimal256(d.precision as u8, d.scale as i8)) - } - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::IntervalYear(_) => { - Ok(DataType::Interval(IntervalUnit::YearMonth)) - } - r#type::Kind::IntervalDay(_) => Ok(DataType::Interval(IntervalUnit::DayTime)), - r#type::Kind::IntervalCompound(_) => { - Ok(DataType::Interval(IntervalUnit::MonthDayNano)) - } - r#type::Kind::UserDefined(u) => { - if let Ok(data_type) = consumer.consume_user_defined_type(u) { - return Ok(data_type); - } - - // TODO: remove the code below once the producer has been updated - if let Some(name) = consumer.get_extensions().types.get(&u.type_reference) - { - #[allow(deprecated)] - match name.as_ref() { - // Kept for backwards compatibility, producers should use IntervalCompound instead - INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), - _ => not_impl_err!( - "Unsupported Substrait user defined type with ref {} and variation {}", - u.type_reference, - u.type_variation_reference - ), - } - } else { - #[allow(deprecated)] - match u.type_reference { - // Kept for backwards compatibility, producers should use IntervalYear instead - INTERVAL_YEAR_MONTH_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::YearMonth)) - } - // Kept for backwards compatibility, producers should use IntervalDay instead - INTERVAL_DAY_TIME_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::DayTime)) - } - // Kept for backwards compatibility, producers should use IntervalCompound instead - INTERVAL_MONTH_DAY_NANO_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::MonthDayNano)) - } - _ => not_impl_err!( - "Unsupported Substrait user defined type with ref {} and variation {}", - u.type_reference, - u.type_variation_reference - ), - } - } - } - r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( - consumer, s, dfs_names, name_idx, - )?)), - r#type::Kind::Varchar(_) => Ok(DataType::Utf8), - r#type::Kind::FixedChar(_) => Ok(DataType::Utf8), - _ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"), - }, - _ => not_impl_err!("`None` Substrait kind is not supported"), - } -} - -fn from_substrait_struct_type( - consumer: &impl SubstraitConsumer, - s: &r#type::Struct, - dfs_names: &[String], - name_idx: &mut usize, -) -> Result { - let mut fields = vec![]; - for (i, f) in s.types.iter().enumerate() { - let field = Field::new( - next_struct_field_name(i, dfs_names, name_idx)?, - from_substrait_type(consumer, f, dfs_names, name_idx)?, - true, // We assume everything to be nullable since that's easier than ensuring it matches - ); - fields.push(field); - } - Ok(fields.into()) -} - -fn next_struct_field_name( - column_idx: usize, - dfs_names: &[String], - name_idx: &mut usize, -) -> Result { - if dfs_names.is_empty() { - // If names are not given, create dummy names - // c0, c1, ... align with e.g. SqlToRel::create_named_struct - Ok(format!("c{column_idx}")) - } else { - let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| { - substrait_datafusion_err!("Named schema must contain names for all fields") - })?; - *name_idx += 1; - Ok(name) - } -} - -/// Convert Substrait NamedStruct to DataFusion DFSchemaRef -pub fn from_substrait_named_struct( - consumer: &impl SubstraitConsumer, - base_schema: &NamedStruct, -) -> Result { - let mut name_idx = 0; - let fields = from_substrait_struct_type( - consumer, - base_schema.r#struct.as_ref().ok_or_else(|| { - substrait_datafusion_err!("Named struct must contain a struct") - })?, - &base_schema.names, - &mut name_idx, - ); - if name_idx != base_schema.names.len() { - return substrait_err!( - "Names list must match exactly to nested schema, but found {} uses for {} names", - name_idx, - base_schema.names.len() - ); - } - DFSchema::try_from(Schema::new(fields?)) -} - -fn from_substrait_bound( - bound: &Option, - is_lower: bool, -) -> Result { - match bound { - Some(b) => match &b.kind { - Some(k) => match k { - BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => { - Ok(WindowFrameBound::CurrentRow) - } - BoundKind::Preceding(SubstraitBound::Preceding { offset }) => { - if *offset <= 0 { - return plan_err!("Preceding bound must be positive"); - } - Ok(WindowFrameBound::Preceding(ScalarValue::UInt64(Some( - *offset as u64, - )))) - } - BoundKind::Following(SubstraitBound::Following { offset }) => { - if *offset <= 0 { - return plan_err!("Following bound must be positive"); - } - Ok(WindowFrameBound::Following(ScalarValue::UInt64(Some( - *offset as u64, - )))) - } - BoundKind::Unbounded(SubstraitBound::Unbounded {}) => { - if is_lower { - Ok(WindowFrameBound::Preceding(ScalarValue::Null)) - } else { - Ok(WindowFrameBound::Following(ScalarValue::Null)) - } - } - }, - None => substrait_err!("WindowFunction missing Substrait Bound kind"), - }, - None => { - if is_lower { - Ok(WindowFrameBound::Preceding(ScalarValue::Null)) - } else { - Ok(WindowFrameBound::Following(ScalarValue::Null)) - } - } - } -} - -pub(crate) fn from_substrait_literal_without_names( - consumer: &impl SubstraitConsumer, - lit: &Literal, -) -> Result { - from_substrait_literal(consumer, lit, &vec![], &mut 0) -} - -fn from_substrait_literal( - consumer: &impl SubstraitConsumer, - lit: &Literal, - dfs_names: &Vec, - name_idx: &mut usize, -) -> Result { - let scalar_value = match &lit.literal_type { - Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)), - Some(LiteralType::I8(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int8(Some(*n as i8)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt8(Some(*n as u8)), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::I16(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int16(Some(*n as i16)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt16(Some(*n as u16)), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::I32(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int32(Some(*n)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt32(Some(*n as u32)), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::I64(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int64(Some(*n)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt64(Some(*n as u64)), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), - Some(LiteralType::Fp64(f)) => ScalarValue::Float64(Some(*f)), - Some(LiteralType::Timestamp(t)) => { - // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead - #[allow(deprecated)] - match lit.type_variation_reference { - TIMESTAMP_SECOND_TYPE_VARIATION_REF => { - ScalarValue::TimestampSecond(Some(*t), None) - } - TIMESTAMP_MILLI_TYPE_VARIATION_REF => { - ScalarValue::TimestampMillisecond(Some(*t), None) - } - TIMESTAMP_MICRO_TYPE_VARIATION_REF => { - ScalarValue::TimestampMicrosecond(Some(*t), None) - } - TIMESTAMP_NANO_TYPE_VARIATION_REF => { - ScalarValue::TimestampNanosecond(Some(*t), None) - } - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - } - } - Some(LiteralType::PrecisionTimestamp(pt)) => match pt.precision { - 0 => ScalarValue::TimestampSecond(Some(pt.value), None), - 3 => ScalarValue::TimestampMillisecond(Some(pt.value), None), - 6 => ScalarValue::TimestampMicrosecond(Some(pt.value), None), - 9 => ScalarValue::TimestampNanosecond(Some(pt.value), None), - p => { - return not_impl_err!( - "Unsupported Substrait precision {p} for PrecisionTimestamp" - ); - } - }, - Some(LiteralType::PrecisionTimestampTz(pt)) => match pt.precision { - 0 => ScalarValue::TimestampSecond( - Some(pt.value), - Some(DEFAULT_TIMEZONE.into()), - ), - 3 => ScalarValue::TimestampMillisecond( - Some(pt.value), - Some(DEFAULT_TIMEZONE.into()), - ), - 6 => ScalarValue::TimestampMicrosecond( - Some(pt.value), - Some(DEFAULT_TIMEZONE.into()), - ), - 9 => ScalarValue::TimestampNanosecond( - Some(pt.value), - Some(DEFAULT_TIMEZONE.into()), - ), - p => { - return not_impl_err!( - "Unsupported Substrait precision {p} for PrecisionTimestamp" - ); - } - }, - Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), - Some(LiteralType::String(s)) => match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8(Some(s.clone())), - LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeUtf8(Some(s.clone())), - VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8View(Some(s.clone())), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::Binary(b)) => match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Binary(Some(b.clone())), - LARGE_CONTAINER_TYPE_VARIATION_REF => { - ScalarValue::LargeBinary(Some(b.clone())) - } - VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::BinaryView(Some(b.clone())), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::FixedBinary(b)) => { - ScalarValue::FixedSizeBinary(b.len() as _, Some(b.clone())) - } - Some(LiteralType::Decimal(d)) => { - let value: [u8; 16] = d - .value - .clone() - .try_into() - .or(substrait_err!("Failed to parse decimal value"))?; - let p = d.precision.try_into().map_err(|e| { - substrait_datafusion_err!("Failed to parse decimal precision: {e}") - })?; - let s = d.scale.try_into().map_err(|e| { - substrait_datafusion_err!("Failed to parse decimal scale: {e}") - })?; - ScalarValue::Decimal128(Some(i128::from_le_bytes(value)), p, s) - } - Some(LiteralType::List(l)) => { - // Each element should start the name index from the same value, then we increase it - // once at the end - let mut element_name_idx = *name_idx; - let elements = l - .values - .iter() - .map(|el| { - element_name_idx = *name_idx; - from_substrait_literal(consumer, el, dfs_names, &mut element_name_idx) - }) - .collect::>>()?; - *name_idx = element_name_idx; - if elements.is_empty() { - return substrait_err!( - "Empty list must be encoded as EmptyList literal type, not List" - ); - } - let element_type = elements[0].data_type(); - match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::List( - ScalarValue::new_list_nullable(elements.as_slice(), &element_type), - ), - LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( - ScalarValue::new_large_list(elements.as_slice(), &element_type), - ), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - } - } - Some(LiteralType::EmptyList(l)) => { - let element_type = from_substrait_type( - consumer, - l.r#type.clone().unwrap().as_ref(), - dfs_names, - name_idx, - )?; - match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => { - ScalarValue::List(ScalarValue::new_list_nullable(&[], &element_type)) - } - LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( - ScalarValue::new_large_list(&[], &element_type), - ), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - } - } - Some(LiteralType::Map(m)) => { - // Each entry should start the name index from the same value, then we increase it - // once at the end - let mut entry_name_idx = *name_idx; - let entries = m - .key_values - .iter() - .map(|kv| { - entry_name_idx = *name_idx; - let key_sv = from_substrait_literal( - consumer, - kv.key.as_ref().unwrap(), - dfs_names, - &mut entry_name_idx, - )?; - let value_sv = from_substrait_literal( - consumer, - kv.value.as_ref().unwrap(), - dfs_names, - &mut entry_name_idx, - )?; - ScalarStructBuilder::new() - .with_scalar(Field::new("key", key_sv.data_type(), false), key_sv) - .with_scalar( - Field::new("value", value_sv.data_type(), true), - value_sv, - ) - .build() - }) - .collect::>>()?; - *name_idx = entry_name_idx; - - if entries.is_empty() { - return substrait_err!( - "Empty map must be encoded as EmptyMap literal type, not Map" - ); - } - - ScalarValue::Map(Arc::new(MapArray::new( - Arc::new(Field::new("entries", entries[0].data_type(), false)), - OffsetBuffer::new(vec![0, entries.len() as i32].into()), - ScalarValue::iter_to_array(entries)?.as_struct().to_owned(), - None, - false, - ))) - } - Some(LiteralType::EmptyMap(m)) => { - let key = match &m.key { - Some(k) => Ok(k), - _ => plan_err!("Missing key type for empty map"), - }?; - let value = match &m.value { - Some(v) => Ok(v), - _ => plan_err!("Missing value type for empty map"), - }?; - let key_type = from_substrait_type(consumer, key, dfs_names, name_idx)?; - let value_type = from_substrait_type(consumer, value, dfs_names, name_idx)?; - - // new_empty_array on a MapType creates a too empty array - // We want it to contain an empty struct array to align with an empty MapBuilder one - let entries = Field::new_struct( - "entries", - vec![ - Field::new("key", key_type, false), - Field::new("value", value_type, true), - ], - false, - ); - let struct_array = - new_empty_array(entries.data_type()).as_struct().to_owned(); - ScalarValue::Map(Arc::new(MapArray::new( - Arc::new(entries), - OffsetBuffer::new(vec![0, 0].into()), - struct_array, - None, - false, - ))) - } - Some(LiteralType::Struct(s)) => { - let mut builder = ScalarStructBuilder::new(); - for (i, field) in s.fields.iter().enumerate() { - let name = next_struct_field_name(i, dfs_names, name_idx)?; - let sv = from_substrait_literal(consumer, field, dfs_names, name_idx)?; - // We assume everything to be nullable, since Arrow's strict about things matching - // and it's hard to match otherwise. - builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); - } - builder.build()? - } - Some(LiteralType::Null(null_type)) => { - let data_type = - from_substrait_type(consumer, null_type, dfs_names, name_idx)?; - ScalarValue::try_from(&data_type)? - } - Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { - days, - seconds, - subseconds, - precision_mode, - })) => { - use interval_day_to_second::PrecisionMode; - // DF only supports millisecond precision, so for any more granular type we lose precision - let milliseconds = match precision_mode { - Some(PrecisionMode::Microseconds(ms)) => ms / 1000, - None => - if *subseconds != 0 { - return substrait_err!("Cannot set subseconds field of IntervalDayToSecond without setting precision"); - } else { - 0_i32 - } - Some(PrecisionMode::Precision(0)) => *subseconds as i32 * 1000, - Some(PrecisionMode::Precision(3)) => *subseconds as i32, - Some(PrecisionMode::Precision(6)) => (subseconds / 1000) as i32, - Some(PrecisionMode::Precision(9)) => (subseconds / 1000 / 1000) as i32, - _ => { - return not_impl_err!( - "Unsupported Substrait interval day to second precision mode: {precision_mode:?}") - } - }; - - ScalarValue::new_interval_dt(*days, (seconds * 1000) + milliseconds) - } - Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => { - ScalarValue::new_interval_ym(*years, *months) - } - Some(LiteralType::IntervalCompound(IntervalCompound { - interval_year_to_month, - interval_day_to_second, - })) => match (interval_year_to_month, interval_day_to_second) { - ( - Some(IntervalYearToMonth { years, months }), - Some(IntervalDayToSecond { - days, - seconds, - subseconds, - precision_mode: - Some(interval_day_to_second::PrecisionMode::Precision(p)), - }), - ) => { - if *p < 0 || *p > 9 { - return plan_err!( - "Unsupported Substrait interval day to second precision: {}", - p - ); - } - let nanos = *subseconds * i64::pow(10, (9 - p) as u32); - ScalarValue::new_interval_mdn( - *years * 12 + months, - *days, - *seconds as i64 * NANOSECONDS + nanos, - ) - } - _ => return plan_err!("Substrait compound interval missing components"), - }, - Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), - Some(LiteralType::UserDefined(user_defined)) => { - if let Ok(value) = consumer.consume_user_defined_literal(user_defined) { - return Ok(value); - } - - // TODO: remove the code below once the producer has been updated - - // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed - let interval_month_day_nano = - |user_defined: &proto::expression::literal::UserDefined| -> Result { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval month day nano value is empty"); - }; - let value_slice: [u8; 16] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval month day nano value" - ) - })?; - let months = - i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); - let days = i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); - let nanoseconds = - i64::from_le_bytes(value_slice[8..16].try_into().unwrap()); - Ok(ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNano { - months, - days, - nanoseconds, - }, - ))) - }; - - if let Some(name) = consumer - .get_extensions() - .types - .get(&user_defined.type_reference) - { - match name.as_ref() { - // Kept for backwards compatibility - producers should use IntervalCompound instead - #[allow(deprecated)] - INTERVAL_MONTH_DAY_NANO_TYPE_NAME => { - interval_month_day_nano(user_defined)? - } - _ => { - return not_impl_err!( - "Unsupported Substrait user defined type with ref {} and name {}", - user_defined.type_reference, - name - ) - } - } - } else { - #[allow(deprecated)] - match user_defined.type_reference { - // Kept for backwards compatibility, producers should useIntervalYearToMonth instead - INTERVAL_YEAR_MONTH_TYPE_REF => { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval year month value is empty"); - }; - let value_slice: [u8; 4] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval year month value" - ) - })?; - ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes( - value_slice, - ))) - } - // Kept for backwards compatibility, producers should useIntervalDayToSecond instead - INTERVAL_DAY_TIME_TYPE_REF => { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval day time value is empty"); - }; - let value_slice: [u8; 8] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval day time value" - ) - })?; - let days = - i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); - let milliseconds = - i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); - ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days, - milliseconds, - })) - } - // Kept for backwards compatibility, producers should useIntervalCompound instead - INTERVAL_MONTH_DAY_NANO_TYPE_REF => { - interval_month_day_nano(user_defined)? - } - _ => { - return not_impl_err!( - "Unsupported Substrait user defined type literal with ref {}", - user_defined.type_reference - ) - } - } - } - } - _ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type), - }; - - Ok(scalar_value) -} - -#[allow(deprecated)] -async fn from_substrait_grouping( - consumer: &impl SubstraitConsumer, - grouping: &Grouping, - expressions: &[Expr], - input_schema: &DFSchemaRef, -) -> Result> { - let mut group_exprs = vec![]; - if !grouping.grouping_expressions.is_empty() { - for e in &grouping.grouping_expressions { - let expr = consumer.consume_expression(e, input_schema).await?; - group_exprs.push(expr); - } - return Ok(group_exprs); - } - for idx in &grouping.expression_references { - let e = &expressions[*idx as usize]; - group_exprs.push(e.clone()); - } - Ok(group_exprs) -} - -fn from_substrait_field_reference( - field_ref: &FieldReference, - input_schema: &DFSchema, -) -> Result { - match &field_ref.reference_type { - Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { - Some(StructField(x)) => match &x.child.as_ref() { - Some(_) => not_impl_err!( - "Direct reference StructField with child is not supported" - ), - None => Ok(Expr::Column(Column::from( - input_schema.qualified_field(x.field as usize), - ))), - }, - _ => not_impl_err!( - "Direct reference with types other than StructField is not supported" - ), - }, - _ => not_impl_err!("unsupported field ref type"), - } -} - -/// Build [`Expr`] from its name and required inputs. -struct BuiltinExprBuilder { - expr_name: String, -} - -impl BuiltinExprBuilder { - pub fn try_from_name(name: &str) -> Option { - match name { - "not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true" - | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" - | "is_not_unknown" | "negative" | "negate" => Some(Self { - expr_name: name.to_string(), - }), - _ => None, - } - } - - pub async fn build( - self, - consumer: &impl SubstraitConsumer, - f: &ScalarFunction, - input_schema: &DFSchema, - ) -> Result { - match self.expr_name.as_str() { - "like" => Self::build_like_expr(consumer, false, f, input_schema).await, - "ilike" => Self::build_like_expr(consumer, true, f, input_schema).await, - "not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true" - | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" - | "is_not_unknown" => { - Self::build_unary_expr(consumer, &self.expr_name, f, input_schema).await - } - _ => { - not_impl_err!("Unsupported builtin expression: {}", self.expr_name) - } - } - } - - async fn build_unary_expr( - consumer: &impl SubstraitConsumer, - fn_name: &str, - f: &ScalarFunction, - input_schema: &DFSchema, - ) -> Result { - if f.arguments.len() != 1 { - return substrait_err!("Expect one argument for {fn_name} expr"); - } - let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return substrait_err!("Invalid arguments type for {fn_name} expr"); - }; - let arg = consumer - .consume_expression(expr_substrait, input_schema) - .await?; - let arg = Box::new(arg); - - let expr = match fn_name { - "not" => Expr::Not(arg), - "negative" | "negate" => Expr::Negative(arg), - "is_null" => Expr::IsNull(arg), - "is_not_null" => Expr::IsNotNull(arg), - "is_true" => Expr::IsTrue(arg), - "is_false" => Expr::IsFalse(arg), - "is_not_true" => Expr::IsNotTrue(arg), - "is_not_false" => Expr::IsNotFalse(arg), - "is_unknown" => Expr::IsUnknown(arg), - "is_not_unknown" => Expr::IsNotUnknown(arg), - _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), - }; - - Ok(expr) - } - - async fn build_like_expr( - consumer: &impl SubstraitConsumer, - case_insensitive: bool, - f: &ScalarFunction, - input_schema: &DFSchema, - ) -> Result { - let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; - if f.arguments.len() != 2 && f.arguments.len() != 3 { - return substrait_err!("Expect two or three arguments for `{fn_name}` expr"); - } - - let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return substrait_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let expr = consumer - .consume_expression(expr_substrait, input_schema) - .await?; - let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { - return substrait_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let pattern = consumer - .consume_expression(pattern_substrait, input_schema) - .await?; - - // Default case: escape character is Literal(Utf8(None)) - let escape_char = if f.arguments.len() == 3 { - let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type - else { - return substrait_err!("Invalid arguments type for `{fn_name}` expr"); - }; - - let escape_char_expr = consumer - .consume_expression(escape_char_substrait, input_schema) - .await?; - - match escape_char_expr { - Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { - // Convert Option to Option - escape_char_string.and_then(|s| s.chars().next()) - } - _ => { - return substrait_err!( - "Expect Utf8 literal for escape char, but found {escape_char_expr:?}" - ) - } - } - } else { - None - }; - - Ok(Expr::Like(Like { - negated: false, - expr: Box::new(expr), - pattern: Box::new(pattern), - escape_char, - case_insensitive, - })) - } -} - -#[cfg(test)] -mod test { - use crate::extensions::Extensions; - use crate::logical_plan::consumer::{ - from_substrait_literal_without_names, from_substrait_rex, - DefaultSubstraitConsumer, - }; - use arrow::array::types::IntervalMonthDayNano; - use datafusion::arrow; - use datafusion::common::DFSchema; - use datafusion::error::Result; - use datafusion::execution::SessionState; - use datafusion::prelude::{Expr, SessionContext}; - use datafusion::scalar::ScalarValue; - use std::sync::LazyLock; - use substrait::proto::expression::literal::{ - interval_day_to_second, IntervalCompound, IntervalDayToSecond, - IntervalYearToMonth, LiteralType, - }; - use substrait::proto::expression::window_function::BoundsType; - use substrait::proto::expression::Literal; - - static TEST_SESSION_STATE: LazyLock = - LazyLock::new(|| SessionContext::default().state()); - static TEST_EXTENSIONS: LazyLock = LazyLock::new(Extensions::default); - fn test_consumer() -> DefaultSubstraitConsumer<'static> { - let extensions = &TEST_EXTENSIONS; - let state = &TEST_SESSION_STATE; - DefaultSubstraitConsumer::new(extensions, state) - } - - #[test] - fn interval_compound_different_precision() -> Result<()> { - // DF producer (and thus roundtrip) always uses precision = 9, - // this test exists to test with some other value. - let substrait = Literal { - nullable: false, - type_variation_reference: 0, - literal_type: Some(LiteralType::IntervalCompound(IntervalCompound { - interval_year_to_month: Some(IntervalYearToMonth { - years: 1, - months: 2, - }), - interval_day_to_second: Some(IntervalDayToSecond { - days: 3, - seconds: 4, - subseconds: 5, - precision_mode: Some( - interval_day_to_second::PrecisionMode::Precision(6), - ), - }), - })), - }; - - let consumer = test_consumer(); - assert_eq!( - from_substrait_literal_without_names(&consumer, &substrait)?, - ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { - months: 14, - days: 3, - nanoseconds: 4_000_005_000 - })) - ); - - Ok(()) - } - - #[tokio::test] - async fn window_function_with_range_unit_and_no_order_by() -> Result<()> { - let substrait = substrait::proto::Expression { - rex_type: Some(substrait::proto::expression::RexType::WindowFunction( - substrait::proto::expression::WindowFunction { - function_reference: 0, - bounds_type: BoundsType::Range as i32, - sorts: vec![], - ..Default::default() - }, - )), - }; - - let mut consumer = test_consumer(); - - // Just registering a single function (index 0) so that the plan - // does not throw a "function not found" error. - let mut extensions = Extensions::default(); - extensions.register_function("count".to_string()); - consumer.extensions = &extensions; - - match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? { - Expr::WindowFunction(window_function) => { - assert_eq!(window_function.params.order_by.len(), 1) - } - _ => panic!("expr was not a WindowFunction"), - }; - - Ok(()) - } - - #[tokio::test] - async fn window_function_with_count() -> Result<()> { - let substrait = substrait::proto::Expression { - rex_type: Some(substrait::proto::expression::RexType::WindowFunction( - substrait::proto::expression::WindowFunction { - function_reference: 0, - ..Default::default() - }, - )), - }; - - let mut consumer = test_consumer(); - - let mut extensions = Extensions::default(); - extensions.register_function("count".to_string()); - consumer.extensions = &extensions; - - match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? { - Expr::WindowFunction(window_function) => { - assert_eq!(window_function.params.args.len(), 1) - } - _ => panic!("expr was not a WindowFunction"), - }; - - Ok(()) - } -} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/aggregate_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/aggregate_function.rs new file mode 100644 index 0000000000000..62e140acc27b3 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/aggregate_function.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{ + from_substrait_func_args, substrait_fun_name, SubstraitConsumer, +}; +use datafusion::common::{not_impl_datafusion_err, plan_err, DFSchema, ScalarValue}; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::{expr, Expr, SortExpr}; +use std::sync::Arc; +use substrait::proto::AggregateFunction; + +/// Convert Substrait AggregateFunction to DataFusion Expr +pub async fn from_substrait_agg_func( + consumer: &impl SubstraitConsumer, + f: &AggregateFunction, + input_schema: &DFSchema, + filter: Option>, + order_by: Vec, + distinct: bool, +) -> datafusion::common::Result> { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&f.function_reference) + else { + return plan_err!( + "Aggregate function not registered: function anchor = {:?}", + f.function_reference + ); + }; + + let fn_name = substrait_fun_name(fn_signature); + let udaf = consumer.get_function_registry().udaf(fn_name); + let udaf = udaf.map_err(|_| { + not_impl_datafusion_err!( + "Aggregate function {} is not supported: function anchor = {:?}", + fn_signature, + f.function_reference + ) + })?; + + let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; + + // Datafusion does not support aggregate functions with no arguments, so + // we inject a dummy argument that does not affect the query, but allows + // us to bypass this limitation. + let args = if udaf.name() == "count" && args.is_empty() { + vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)] + } else { + args + }; + + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None), + ))) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs new file mode 100644 index 0000000000000..5e8d3d93065f4 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::types::from_substrait_type_without_names; +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{substrait_err, DFSchema}; +use datafusion::logical_expr::{Cast, Expr, TryCast}; +use substrait::proto::expression as substrait_expression; +use substrait::proto::expression::cast::FailureBehavior::ReturnNull; + +pub async fn from_cast( + consumer: &impl SubstraitConsumer, + cast: &substrait_expression::Cast, + input_schema: &DFSchema, +) -> datafusion::common::Result { + match cast.r#type.as_ref() { + Some(output_type) => { + let input_expr = Box::new( + consumer + .consume_expression( + cast.input.as_ref().unwrap().as_ref(), + input_schema, + ) + .await?, + ); + let data_type = from_substrait_type_without_names(consumer, output_type)?; + if cast.failure_behavior() == ReturnNull { + Ok(Expr::TryCast(TryCast::new(input_expr, data_type))) + } else { + Ok(Expr::Cast(Cast::new(input_expr, data_type))) + } + } + None => substrait_err!("Cast expression without output type is not allowed"), + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs b/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs new file mode 100644 index 0000000000000..90b5b6418149b --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{not_impl_err, Column, DFSchema}; +use datafusion::logical_expr::Expr; +use substrait::proto::expression::field_reference::ReferenceType::DirectReference; +use substrait::proto::expression::reference_segment::ReferenceType::StructField; +use substrait::proto::expression::FieldReference; + +pub async fn from_field_reference( + _consumer: &impl SubstraitConsumer, + field_ref: &FieldReference, + input_schema: &DFSchema, +) -> datafusion::common::Result { + from_substrait_field_reference(field_ref, input_schema) +} + +pub(crate) fn from_substrait_field_reference( + field_ref: &FieldReference, + input_schema: &DFSchema, +) -> datafusion::common::Result { + match &field_ref.reference_type { + Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { + Some(StructField(x)) => match &x.child.as_ref() { + Some(_) => not_impl_err!( + "Direct reference StructField with child is not supported" + ), + None => Ok(Expr::Column(Column::from( + input_schema.qualified_field(x.field as usize), + ))), + }, + _ => not_impl_err!( + "Direct reference with types other than StructField is not supported" + ), + }, + _ => not_impl_err!("unsupported field ref type"), + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/function_arguments.rs b/datafusion/substrait/src/logical_plan/consumer/expr/function_arguments.rs new file mode 100644 index 0000000000000..0b610b61b1dea --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/function_arguments.rs @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{not_impl_err, DFSchema}; +use datafusion::logical_expr::Expr; +use substrait::proto::function_argument::ArgType; +use substrait::proto::FunctionArgument; + +/// Convert Substrait FunctionArguments to DataFusion Exprs +pub async fn from_substrait_func_args( + consumer: &impl SubstraitConsumer, + arguments: &Vec, + input_schema: &DFSchema, +) -> datafusion::common::Result> { + let mut args: Vec = vec![]; + for arg in arguments { + let arg_expr = match &arg.arg_type { + Some(ArgType::Value(e)) => consumer.consume_expression(e, input_schema).await, + _ => not_impl_err!("Function argument non-Value type not supported"), + }; + args.push(arg_expr?); + } + Ok(args) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/if_then.rs b/datafusion/substrait/src/logical_plan/consumer/expr/if_then.rs new file mode 100644 index 0000000000000..c4cc6c2fcd24f --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/if_then.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::DFSchema; +use datafusion::logical_expr::{Case, Expr}; +use substrait::proto::expression::IfThen; + +pub async fn from_if_then( + consumer: &impl SubstraitConsumer, + if_then: &IfThen, + input_schema: &DFSchema, +) -> datafusion::common::Result { + // Parse `ifs` + // If the first element does not have a `then` part, then we can assume it's a base expression + let mut when_then_expr: Vec<(Box, Box)> = vec![]; + let mut expr = None; + for (i, if_expr) in if_then.ifs.iter().enumerate() { + if i == 0 { + // Check if the first element is type base expression + if if_expr.then.is_none() { + expr = Some(Box::new( + consumer + .consume_expression(if_expr.r#if.as_ref().unwrap(), input_schema) + .await?, + )); + continue; + } + } + when_then_expr.push(( + Box::new( + consumer + .consume_expression(if_expr.r#if.as_ref().unwrap(), input_schema) + .await?, + ), + Box::new( + consumer + .consume_expression(if_expr.then.as_ref().unwrap(), input_schema) + .await?, + ), + )); + } + // Parse `else` + let else_expr = match &if_then.r#else { + Some(e) => Some(Box::new( + consumer.consume_expression(e, input_schema).await?, + )), + None => None, + }; + Ok(Expr::Case(Case { + expr, + when_then_expr, + else_expr, + })) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs b/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs new file mode 100644 index 0000000000000..dc7a5935c0149 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs @@ -0,0 +1,587 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::types::from_substrait_type; +use crate::logical_plan::consumer::utils::{next_struct_field_name, DEFAULT_TIMEZONE}; +use crate::logical_plan::consumer::SubstraitConsumer; +#[allow(deprecated)] +use crate::variation_const::{ + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, + INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, + TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, + TIMESTAMP_SECOND_TYPE_VARIATION_REF, TIME_32_TYPE_VARIATION_REF, + TIME_64_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; +use datafusion::arrow::array::{new_empty_array, AsArray, MapArray}; +use datafusion::arrow::buffer::OffsetBuffer; +use datafusion::arrow::datatypes::{Field, IntervalDayTime, IntervalMonthDayNano}; +use datafusion::arrow::temporal_conversions::NANOSECONDS; +use datafusion::common::scalar::ScalarStructBuilder; +use datafusion::common::{ + not_impl_err, plan_err, substrait_datafusion_err, substrait_err, ScalarValue, +}; +use datafusion::logical_expr::Expr; +use std::sync::Arc; +use substrait::proto; +use substrait::proto::expression::literal::user_defined::Val; +use substrait::proto::expression::literal::{ + interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, + LiteralType, +}; +use substrait::proto::expression::Literal; + +pub async fn from_literal( + consumer: &impl SubstraitConsumer, + expr: &Literal, +) -> datafusion::common::Result { + let scalar_value = from_substrait_literal_without_names(consumer, expr)?; + Ok(Expr::Literal(scalar_value, None)) +} + +pub(crate) fn from_substrait_literal_without_names( + consumer: &impl SubstraitConsumer, + lit: &Literal, +) -> datafusion::common::Result { + from_substrait_literal(consumer, lit, &vec![], &mut 0) +} + +pub(crate) fn from_substrait_literal( + consumer: &impl SubstraitConsumer, + lit: &Literal, + dfs_names: &Vec, + name_idx: &mut usize, +) -> datafusion::common::Result { + let scalar_value = match &lit.literal_type { + Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)), + Some(LiteralType::I8(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int8(Some(*n as i8)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt8(Some(*n as u8)), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::I16(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int16(Some(*n as i16)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt16(Some(*n as u16)), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::I32(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int32(Some(*n)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt32(Some(*n as u32)), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::I64(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int64(Some(*n)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt64(Some(*n as u64)), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), + Some(LiteralType::Fp64(f)) => ScalarValue::Float64(Some(*f)), + Some(LiteralType::Timestamp(t)) => { + // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead + #[allow(deprecated)] + match lit.type_variation_reference { + TIMESTAMP_SECOND_TYPE_VARIATION_REF => { + ScalarValue::TimestampSecond(Some(*t), None) + } + TIMESTAMP_MILLI_TYPE_VARIATION_REF => { + ScalarValue::TimestampMillisecond(Some(*t), None) + } + TIMESTAMP_MICRO_TYPE_VARIATION_REF => { + ScalarValue::TimestampMicrosecond(Some(*t), None) + } + TIMESTAMP_NANO_TYPE_VARIATION_REF => { + ScalarValue::TimestampNanosecond(Some(*t), None) + } + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } + Some(LiteralType::PrecisionTimestamp(pt)) => match pt.precision { + 0 => ScalarValue::TimestampSecond(Some(pt.value), None), + 3 => ScalarValue::TimestampMillisecond(Some(pt.value), None), + 6 => ScalarValue::TimestampMicrosecond(Some(pt.value), None), + 9 => ScalarValue::TimestampNanosecond(Some(pt.value), None), + p => { + return not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ); + } + }, + Some(LiteralType::PrecisionTimestampTz(pt)) => match pt.precision { + 0 => ScalarValue::TimestampSecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 3 => ScalarValue::TimestampMillisecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 6 => ScalarValue::TimestampMicrosecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 9 => ScalarValue::TimestampNanosecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + p => { + return not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ); + } + }, + Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), + Some(LiteralType::PrecisionTime(pt)) => match pt.precision { + 0 => match lit.type_variation_reference { + TIME_32_TYPE_VARIATION_REF => { + ScalarValue::Time32Second(Some(pt.value as i32)) + } + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + 3 => match lit.type_variation_reference { + TIME_32_TYPE_VARIATION_REF => { + ScalarValue::Time32Millisecond(Some(pt.value as i32)) + } + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + 6 => match lit.type_variation_reference { + TIME_64_TYPE_VARIATION_REF => { + ScalarValue::Time64Microsecond(Some(pt.value)) + } + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + 9 => match lit.type_variation_reference { + TIME_64_TYPE_VARIATION_REF => { + ScalarValue::Time64Nanosecond(Some(pt.value)) + } + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + p => { + return not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTime" + ); + } + }, + Some(LiteralType::String(s)) => match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8(Some(s.clone())), + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeUtf8(Some(s.clone())), + VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8View(Some(s.clone())), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::Binary(b)) => match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Binary(Some(b.clone())), + LARGE_CONTAINER_TYPE_VARIATION_REF => { + ScalarValue::LargeBinary(Some(b.clone())) + } + VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::BinaryView(Some(b.clone())), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::FixedBinary(b)) => { + ScalarValue::FixedSizeBinary(b.len() as _, Some(b.clone())) + } + Some(LiteralType::Decimal(d)) => { + let value: [u8; 16] = d + .value + .clone() + .try_into() + .or(substrait_err!("Failed to parse decimal value"))?; + let p = d.precision.try_into().map_err(|e| { + substrait_datafusion_err!("Failed to parse decimal precision: {e}") + })?; + let s = d.scale.try_into().map_err(|e| { + substrait_datafusion_err!("Failed to parse decimal scale: {e}") + })?; + ScalarValue::Decimal128(Some(i128::from_le_bytes(value)), p, s) + } + Some(LiteralType::List(l)) => { + // Each element should start the name index from the same value, then we increase it + // once at the end + let mut element_name_idx = *name_idx; + let elements = l + .values + .iter() + .map(|el| { + element_name_idx = *name_idx; + from_substrait_literal(consumer, el, dfs_names, &mut element_name_idx) + }) + .collect::>>()?; + *name_idx = element_name_idx; + if elements.is_empty() { + return substrait_err!( + "Empty list must be encoded as EmptyList literal type, not List" + ); + } + let element_type = elements[0].data_type(); + match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::List( + ScalarValue::new_list_nullable(elements.as_slice(), &element_type), + ), + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( + ScalarValue::new_large_list(elements.as_slice(), &element_type), + ), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } + Some(LiteralType::EmptyList(l)) => { + let element_type = from_substrait_type( + consumer, + l.r#type.clone().unwrap().as_ref(), + dfs_names, + name_idx, + )?; + match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => { + ScalarValue::List(ScalarValue::new_list_nullable(&[], &element_type)) + } + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( + ScalarValue::new_large_list(&[], &element_type), + ), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } + Some(LiteralType::Map(m)) => { + // Each entry should start the name index from the same value, then we increase it + // once at the end + let mut entry_name_idx = *name_idx; + let entries = m + .key_values + .iter() + .map(|kv| { + entry_name_idx = *name_idx; + let key_sv = from_substrait_literal( + consumer, + kv.key.as_ref().unwrap(), + dfs_names, + &mut entry_name_idx, + )?; + let value_sv = from_substrait_literal( + consumer, + kv.value.as_ref().unwrap(), + dfs_names, + &mut entry_name_idx, + )?; + ScalarStructBuilder::new() + .with_scalar(Field::new("key", key_sv.data_type(), false), key_sv) + .with_scalar( + Field::new("value", value_sv.data_type(), true), + value_sv, + ) + .build() + }) + .collect::>>()?; + *name_idx = entry_name_idx; + + if entries.is_empty() { + return substrait_err!( + "Empty map must be encoded as EmptyMap literal type, not Map" + ); + } + + ScalarValue::Map(Arc::new(MapArray::new( + Arc::new(Field::new("entries", entries[0].data_type(), false)), + OffsetBuffer::new(vec![0, entries.len() as i32].into()), + ScalarValue::iter_to_array(entries)?.as_struct().to_owned(), + None, + false, + ))) + } + Some(LiteralType::EmptyMap(m)) => { + let key = match &m.key { + Some(k) => Ok(k), + _ => plan_err!("Missing key type for empty map"), + }?; + let value = match &m.value { + Some(v) => Ok(v), + _ => plan_err!("Missing value type for empty map"), + }?; + let key_type = from_substrait_type(consumer, key, dfs_names, name_idx)?; + let value_type = from_substrait_type(consumer, value, dfs_names, name_idx)?; + + // new_empty_array on a MapType creates a too empty array + // We want it to contain an empty struct array to align with an empty MapBuilder one + let entries = Field::new_struct( + "entries", + vec![ + Field::new("key", key_type, false), + Field::new("value", value_type, true), + ], + false, + ); + let struct_array = + new_empty_array(entries.data_type()).as_struct().to_owned(); + ScalarValue::Map(Arc::new(MapArray::new( + Arc::new(entries), + OffsetBuffer::new(vec![0, 0].into()), + struct_array, + None, + false, + ))) + } + Some(LiteralType::Struct(s)) => { + let mut builder = ScalarStructBuilder::new(); + for (i, field) in s.fields.iter().enumerate() { + let name = next_struct_field_name(i, dfs_names, name_idx)?; + let sv = from_substrait_literal(consumer, field, dfs_names, name_idx)?; + // We assume everything to be nullable, since Arrow's strict about things matching + // and it's hard to match otherwise. + builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); + } + builder.build()? + } + Some(LiteralType::Null(null_type)) => { + let data_type = + from_substrait_type(consumer, null_type, dfs_names, name_idx)?; + ScalarValue::try_from(&data_type)? + } + Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { + days, + seconds, + subseconds, + precision_mode, + })) => { + use interval_day_to_second::PrecisionMode; + // DF only supports millisecond precision, so for any more granular type we lose precision + let milliseconds = match precision_mode { + Some(PrecisionMode::Microseconds(ms)) => ms / 1000, + None => + if *subseconds != 0 { + return substrait_err!("Cannot set subseconds field of IntervalDayToSecond without setting precision"); + } else { + 0_i32 + } + Some(PrecisionMode::Precision(0)) => *subseconds as i32 * 1000, + Some(PrecisionMode::Precision(3)) => *subseconds as i32, + Some(PrecisionMode::Precision(6)) => (subseconds / 1000) as i32, + Some(PrecisionMode::Precision(9)) => (subseconds / 1000 / 1000) as i32, + _ => { + return not_impl_err!( + "Unsupported Substrait interval day to second precision mode: {precision_mode:?}") + } + }; + + ScalarValue::new_interval_dt(*days, (seconds * 1000) + milliseconds) + } + Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => { + ScalarValue::new_interval_ym(*years, *months) + } + Some(LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month, + interval_day_to_second, + })) => match (interval_year_to_month, interval_day_to_second) { + ( + Some(IntervalYearToMonth { years, months }), + Some(IntervalDayToSecond { + days, + seconds, + subseconds, + precision_mode: + Some(interval_day_to_second::PrecisionMode::Precision(p)), + }), + ) => { + if *p < 0 || *p > 9 { + return plan_err!( + "Unsupported Substrait interval day to second precision: {}", + p + ); + } + let nanos = *subseconds * i64::pow(10, (9 - p) as u32); + ScalarValue::new_interval_mdn( + *years * 12 + months, + *days, + *seconds as i64 * NANOSECONDS + nanos, + ) + } + _ => return plan_err!("Substrait compound interval missing components"), + }, + Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), + Some(LiteralType::UserDefined(user_defined)) => { + if let Ok(value) = consumer.consume_user_defined_literal(user_defined) { + return Ok(value); + } + + // TODO: remove the code below once the producer has been updated + + // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed + let interval_month_day_nano = + |user_defined: &proto::expression::literal::UserDefined| -> datafusion::common::Result { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval month day nano value is empty"); + }; + let value_slice: [u8; 16] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval month day nano value" + ) + })?; + let months = + i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); + let days = i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); + let nanoseconds = + i64::from_le_bytes(value_slice[8..16].try_into().unwrap()); + Ok(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano { + months, + days, + nanoseconds, + }, + ))) + }; + + if let Some(name) = consumer + .get_extensions() + .types + .get(&user_defined.type_reference) + { + match name.as_ref() { + // Kept for backwards compatibility - producers should use IntervalCompound instead + #[allow(deprecated)] + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => { + interval_month_day_nano(user_defined)? + } + _ => { + return not_impl_err!( + "Unsupported Substrait user defined type with ref {} and name {}", + user_defined.type_reference, + name + ) + } + } + } else { + #[allow(deprecated)] + match user_defined.type_reference { + // Kept for backwards compatibility, producers should useIntervalYearToMonth instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval year month value is empty"); + }; + let value_slice: [u8; 4] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval year month value" + ) + })?; + ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes( + value_slice, + ))) + } + // Kept for backwards compatibility, producers should useIntervalDayToSecond instead + INTERVAL_DAY_TIME_TYPE_REF => { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval day time value is empty"); + }; + let value_slice: [u8; 8] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval day time value" + ) + })?; + let days = + i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); + let milliseconds = + i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days, + milliseconds, + })) + } + // Kept for backwards compatibility, producers should useIntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + interval_month_day_nano(user_defined)? + } + _ => { + return not_impl_err!( + "Unsupported Substrait user defined type literal with ref {}", + user_defined.type_reference + ) + } + } + } + } + _ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type), + }; + + Ok(scalar_value) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::utils::tests::test_consumer; + + #[test] + fn interval_compound_different_precision() -> datafusion::common::Result<()> { + // DF producer (and thus roundtrip) always uses precision = 9, + // this test exists to test with some other value. + let substrait = Literal { + nullable: false, + type_variation_reference: 0, + literal_type: Some(LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month: Some(IntervalYearToMonth { + years: 1, + months: 2, + }), + interval_day_to_second: Some(IntervalDayToSecond { + days: 3, + seconds: 4, + subseconds: 5, + precision_mode: Some( + interval_day_to_second::PrecisionMode::Precision(6), + ), + }), + })), + }; + + let consumer = test_consumer(); + assert_eq!( + from_substrait_literal_without_names(&consumer, &substrait)?, + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { + months: 14, + days: 3, + nanoseconds: 4_000_005_000 + })) + ); + + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs new file mode 100644 index 0000000000000..7358f1422f1b4 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod aggregate_function; +mod cast; +mod field_reference; +mod function_arguments; +mod if_then; +mod literal; +mod scalar_function; +mod singular_or_list; +mod subquery; +mod window_function; + +pub use aggregate_function::*; +pub use cast::*; +pub use field_reference::*; +pub use function_arguments::*; +pub use if_then::*; +pub use literal::*; +pub use scalar_function::*; +pub use singular_or_list::*; +pub use subquery::*; +pub use window_function::*; + +use crate::extensions::Extensions; +use crate::logical_plan::consumer::{ + from_substrait_named_struct, rename_field, DefaultSubstraitConsumer, + SubstraitConsumer, +}; +use datafusion::arrow::datatypes::Field; +use datafusion::common::{not_impl_err, plan_err, substrait_err, DFSchema, DFSchemaRef}; +use datafusion::execution::SessionState; +use datafusion::logical_expr::{Expr, ExprSchemable}; +use substrait::proto::expression::RexType; +use substrait::proto::expression_reference::ExprType; +use substrait::proto::{Expression, ExtendedExpression}; + +/// Convert Substrait Rex to DataFusion Expr +pub async fn from_substrait_rex( + consumer: &impl SubstraitConsumer, + expression: &Expression, + input_schema: &DFSchema, +) -> datafusion::common::Result { + match &expression.rex_type { + Some(t) => match t { + RexType::Literal(expr) => consumer.consume_literal(expr).await, + RexType::Selection(expr) => { + consumer.consume_field_reference(expr, input_schema).await + } + RexType::ScalarFunction(expr) => { + consumer.consume_scalar_function(expr, input_schema).await + } + RexType::WindowFunction(expr) => { + consumer.consume_window_function(expr, input_schema).await + } + RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await, + RexType::SwitchExpression(expr) => { + consumer.consume_switch(expr, input_schema).await + } + RexType::SingularOrList(expr) => { + consumer.consume_singular_or_list(expr, input_schema).await + } + + RexType::MultiOrList(expr) => { + consumer.consume_multi_or_list(expr, input_schema).await + } + + RexType::Cast(expr) => { + consumer.consume_cast(expr.as_ref(), input_schema).await + } + + RexType::Subquery(expr) => { + consumer.consume_subquery(expr.as_ref(), input_schema).await + } + RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await, + RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await, + RexType::DynamicParameter(expr) => { + consumer.consume_dynamic_parameter(expr, input_schema).await + } + }, + None => substrait_err!("Expression must set rex_type: {expression:?}"), + } +} + +/// Convert Substrait ExtendedExpression to ExprContainer +/// +/// A Substrait ExtendedExpression message contains one or more expressions, +/// with names for the outputs, and an input schema. These pieces are all included +/// in the ExprContainer. +/// +/// This is a top-level message and can be used to send expressions (not plans) +/// between systems. This is often useful for scenarios like pushdown where filter +/// expressions need to be sent to remote systems. +pub async fn from_substrait_extended_expr( + state: &SessionState, + extended_expr: &ExtendedExpression, +) -> datafusion::common::Result { + // Register function extension + let extensions = Extensions::try_from(&extended_expr.extensions)?; + if !extensions.type_variations.is_empty() { + return not_impl_err!("Type variation extensions are not supported"); + } + + let consumer = DefaultSubstraitConsumer { + extensions: &extensions, + state, + }; + + let input_schema = DFSchemaRef::new(match &extended_expr.base_schema { + Some(base_schema) => from_substrait_named_struct(&consumer, base_schema), + None => { + plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message") + } + }?); + + // Parse expressions + let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len()); + for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() { + let scalar_expr = match &substrait_expr.expr_type { + Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr), + Some(ExprType::Measure(_)) => { + not_impl_err!("Measure expressions are not yet supported") + } + None => { + plan_err!("required property `expr_type` missing from Substrait ExpressionReference message") + } + }?; + let expr = consumer + .consume_expression(scalar_expr, &input_schema) + .await?; + let (output_type, expected_nullability) = + expr.data_type_and_nullable(&input_schema)?; + let output_field = Field::new("", output_type, expected_nullability); + let mut names_idx = 0; + let output_field = rename_field( + &output_field, + &substrait_expr.output_names, + expr_idx, + &mut names_idx, + )?; + exprs.push((expr, output_field)); + } + + Ok(ExprContainer { + input_schema, + exprs, + }) +} + +/// An ExprContainer is a container for a collection of expressions with a common input schema +/// +/// In addition, each expression is associated with a field, which defines the +/// expression's output. The data type and nullability of the field are calculated from the +/// expression and the input schema. However the names of the field (and its nested fields) are +/// derived from the Substrait message. +pub struct ExprContainer { + /// The input schema for the expressions + pub input_schema: DFSchemaRef, + /// The expressions + /// + /// Each item contains an expression and the field that defines the expected nullability and name of the expr's output + pub exprs: Vec<(Expr, Field)>, +} + +/// Convert Substrait Expressions to DataFusion Exprs +pub async fn from_substrait_rex_vec( + consumer: &impl SubstraitConsumer, + exprs: &Vec, + input_schema: &DFSchema, +) -> datafusion::common::Result> { + let mut expressions: Vec = vec![]; + for expr in exprs { + let expression = consumer.consume_expression(expr, input_schema).await?; + expressions.push(expression); + } + Ok(expressions) +} + +#[cfg(test)] +mod tests { + use crate::extensions::Extensions; + use crate::logical_plan::consumer::utils::tests::test_consumer; + use crate::logical_plan::consumer::*; + use datafusion::common::DFSchema; + use datafusion::logical_expr::Expr; + use substrait::proto::expression::window_function::BoundsType; + use substrait::proto::expression::RexType; + use substrait::proto::Expression; + + #[tokio::test] + async fn window_function_with_range_unit_and_no_order_by( + ) -> datafusion::common::Result<()> { + let substrait = Expression { + rex_type: Some(RexType::WindowFunction( + substrait::proto::expression::WindowFunction { + function_reference: 0, + bounds_type: BoundsType::Range as i32, + sorts: vec![], + ..Default::default() + }, + )), + }; + + let mut consumer = test_consumer(); + + // Just registering a single function (index 0) so that the plan + // does not throw a "function not found" error. + let mut extensions = Extensions::default(); + extensions.register_function("count".to_string()); + consumer.extensions = &extensions; + + match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? { + Expr::WindowFunction(window_function) => { + assert_eq!(window_function.params.order_by.len(), 1) + } + _ => panic!("expr was not a WindowFunction"), + }; + + Ok(()) + } + + #[tokio::test] + async fn window_function_with_count() -> datafusion::common::Result<()> { + let substrait = Expression { + rex_type: Some(RexType::WindowFunction( + substrait::proto::expression::WindowFunction { + function_reference: 0, + ..Default::default() + }, + )), + }; + + let mut consumer = test_consumer(); + + let mut extensions = Extensions::default(); + extensions.register_function("count".to_string()); + consumer.extensions = &extensions; + + match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? { + Expr::WindowFunction(window_function) => { + assert_eq!(window_function.params.args.len(), 1) + } + _ => panic!("expr was not a WindowFunction"), + }; + + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs new file mode 100644 index 0000000000000..f80cf43eb81eb --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs @@ -0,0 +1,479 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{from_substrait_func_args, SubstraitConsumer}; +use datafusion::common::Result; +use datafusion::common::{ + not_impl_err, plan_err, substrait_err, DFSchema, DataFusionError, ScalarValue, +}; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::{expr, Between, BinaryExpr, Expr, Like, Operator}; +use std::vec::Drain; +use substrait::proto::expression::ScalarFunction; + +pub async fn from_scalar_function( + consumer: &impl SubstraitConsumer, + f: &ScalarFunction, + input_schema: &DFSchema, +) -> Result { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&f.function_reference) + else { + return plan_err!( + "Scalar function not found: function reference = {:?}", + f.function_reference + ); + }; + + let fn_name = substrait_fun_name(fn_signature); + let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; + + let udf_func = consumer.get_function_registry().udf(fn_name).or_else(|e| { + if let Some(alt_name) = substrait_to_df_name(fn_name) { + consumer.get_function_registry().udf(alt_name).or(Err(e)) + } else { + Err(e) + } + }); + + // try to first match the requested function into registered udfs, then built-in ops + // and finally built-in expressions + if let Ok(func) = udf_func { + Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( + func.to_owned(), + args, + ))) + } else if let Some(op) = name_to_op(fn_name) { + if args.len() < 2 { + return not_impl_err!( + "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", + f.arguments.len() + ); + } + // In those cases we build a balanced tree of BinaryExprs + arg_list_to_binary_op_tree(op, args) + } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { + builder.build(consumer, f, args).await + } else { + not_impl_err!("Unsupported function name: {fn_name:?}") + } +} + +pub fn substrait_fun_name(name: &str) -> &str { + let name = match name.rsplit_once(':') { + // Since 0.32.0, Substrait requires the function names to be in a compound format + // https://substrait.io/extensions/#function-signature-compound-names + // for example, `add:i8_i8`. + // On the consumer side, we don't really care about the signature though, just the name. + Some((name, _)) => name, + None => name, + }; + name +} + +pub fn name_to_op(name: &str) -> Option { + match name { + "equal" => Some(Operator::Eq), + "not_equal" => Some(Operator::NotEq), + "lt" => Some(Operator::Lt), + "lte" => Some(Operator::LtEq), + "gt" => Some(Operator::Gt), + "gte" => Some(Operator::GtEq), + "add" => Some(Operator::Plus), + "subtract" => Some(Operator::Minus), + "multiply" => Some(Operator::Multiply), + "divide" => Some(Operator::Divide), + "mod" => Some(Operator::Modulo), + "modulus" => Some(Operator::Modulo), + "and" => Some(Operator::And), + "or" => Some(Operator::Or), + "is_distinct_from" => Some(Operator::IsDistinctFrom), + "is_not_distinct_from" => Some(Operator::IsNotDistinctFrom), + "regex_match" => Some(Operator::RegexMatch), + "regex_imatch" => Some(Operator::RegexIMatch), + "regex_not_match" => Some(Operator::RegexNotMatch), + "regex_not_imatch" => Some(Operator::RegexNotIMatch), + "bitwise_and" => Some(Operator::BitwiseAnd), + "bitwise_or" => Some(Operator::BitwiseOr), + "str_concat" => Some(Operator::StringConcat), + "at_arrow" => Some(Operator::AtArrow), + "arrow_at" => Some(Operator::ArrowAt), + "bitwise_xor" => Some(Operator::BitwiseXor), + "bitwise_shift_right" => Some(Operator::BitwiseShiftRight), + "bitwise_shift_left" => Some(Operator::BitwiseShiftLeft), + _ => None, + } +} + +pub fn substrait_to_df_name(name: &str) -> Option<&str> { + match name { + "is_nan" => Some("isnan"), + _ => None, + } +} + +/// Build a balanced tree of binary operations from a binary operator and a list of arguments. +/// +/// For example, `OR` `(a, b, c, d, e)` will be converted to: `OR(OR(a, OR(b, c)), OR(d, e))`. +/// +/// `args` must not be empty. +fn arg_list_to_binary_op_tree(op: Operator, mut args: Vec) -> Result { + let n_args = args.len(); + let mut drained_args = args.drain(..); + arg_list_to_binary_op_tree_inner(op, &mut drained_args, n_args) +} + +/// Helper function for [`arg_list_to_binary_op_tree`] implementation +/// +/// `take_len` represents the number of elements to take from `args` before returning. +/// We use `take_len` to avoid recursively building a `Take>>` type. +fn arg_list_to_binary_op_tree_inner( + op: Operator, + args: &mut Drain, + take_len: usize, +) -> Result { + if take_len == 1 { + return args.next().ok_or_else(|| { + DataFusionError::Substrait( + "Expected one more available element in iterator, found none".to_string(), + ) + }); + } else if take_len == 0 { + return substrait_err!("Cannot build binary operation tree with 0 arguments"); + } + // Cut argument list in 2 balanced parts + let left_take = take_len / 2; + let right_take = take_len - left_take; + let left = arg_list_to_binary_op_tree_inner(op, args, left_take)?; + let right = arg_list_to_binary_op_tree_inner(op, args, right_take)?; + Ok(Expr::BinaryExpr(BinaryExpr { + left: Box::new(left), + op, + right: Box::new(right), + })) +} + +/// Build [`Expr`] from its name and required inputs. +struct BuiltinExprBuilder { + expr_name: String, +} + +impl BuiltinExprBuilder { + pub fn try_from_name(name: &str) -> Option { + match name { + "not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true" + | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" + | "is_not_unknown" | "negative" | "negate" | "and_not" | "xor" + | "between" | "logb" => Some(Self { + expr_name: name.to_string(), + }), + _ => None, + } + } + + pub async fn build( + self, + consumer: &impl SubstraitConsumer, + f: &ScalarFunction, + args: Vec, + ) -> Result { + match self.expr_name.as_str() { + "like" => Self::build_like_expr(false, f, args).await, + "ilike" => Self::build_like_expr(true, f, args).await, + "not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true" + | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" + | "is_not_unknown" => Self::build_unary_expr(&self.expr_name, args).await, + "and_not" | "xor" => Self::build_binary_expr(&self.expr_name, args).await, + "between" => Self::build_between_expr(&self.expr_name, args).await, + "logb" => { + Self::build_custom_handling_expr(consumer, &self.expr_name, args).await + } + _ => { + not_impl_err!("Unsupported builtin expression: {}", self.expr_name) + } + } + } + + async fn build_unary_expr(fn_name: &str, args: Vec) -> Result { + let [arg] = match args.try_into() { + Ok(args_arr) => args_arr, + Err(_) => return substrait_err!("Expected one argument for {fn_name} expr"), + }; + let arg = Box::new(arg); + + let expr = match fn_name { + "not" => Expr::Not(arg), + "negative" | "negate" => Expr::Negative(arg), + "is_null" => Expr::IsNull(arg), + "is_not_null" => Expr::IsNotNull(arg), + "is_true" => Expr::IsTrue(arg), + "is_false" => Expr::IsFalse(arg), + "is_not_true" => Expr::IsNotTrue(arg), + "is_not_false" => Expr::IsNotFalse(arg), + "is_unknown" => Expr::IsUnknown(arg), + "is_not_unknown" => Expr::IsNotUnknown(arg), + _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), + }; + + Ok(expr) + } + + async fn build_like_expr( + case_insensitive: bool, + f: &ScalarFunction, + args: Vec, + ) -> Result { + let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; + if args.len() != 2 && args.len() != 3 { + return substrait_err!("Expect two or three arguments for `{fn_name}` expr"); + } + + let mut args_iter = args.into_iter(); + let Some(expr) = args_iter.next() else { + return substrait_err!("Missing first argument for {fn_name} expression"); + }; + let Some(pattern) = args_iter.next() else { + return substrait_err!("Missing second argument for {fn_name} expression"); + }; + + // Default case: escape character is Literal(Utf8(None)) + let escape_char = if f.arguments.len() == 3 { + let Some(escape_char_expr) = args_iter.next() else { + return substrait_err!("Missing third argument for {fn_name} expression"); + }; + + match escape_char_expr { + Expr::Literal(ScalarValue::Utf8(escape_char_string), _) => { + // Convert Option to Option + escape_char_string.and_then(|s| s.chars().next()) + } + _ => { + return substrait_err!( + "Expect Utf8 literal for escape char, but found {escape_char_expr:?}" + ) + } + } + } else { + None + }; + + Ok(Expr::Like(Like { + negated: false, + expr: Box::new(expr), + pattern: Box::new(pattern), + escape_char, + case_insensitive, + })) + } + + async fn build_binary_expr(fn_name: &str, args: Vec) -> Result { + let [a, b] = match args.try_into() { + Ok(args_arr) => args_arr, + Err(_) => { + return substrait_err!("Expected two arguments for `{fn_name}` expr") + } + }; + match fn_name { + "and_not" => Ok(Self::build_and_not_expr(a, b)), + "xor" => Ok(Self::build_xor_expr(a, b)), + _ => not_impl_err!("Unsupported builtin expression: {}", fn_name), + } + } + + fn build_and_not_expr(a: Expr, b: Expr) -> Expr { + a.and(Expr::Not(Box::new(b))) + } + + fn build_xor_expr(a: Expr, b: Expr) -> Expr { + let or_expr = a.clone().or(b.clone()); + let and_expr = a.and(b); + Self::build_and_not_expr(or_expr, and_expr) + } + + async fn build_between_expr(fn_name: &str, args: Vec) -> Result { + let [expression, low, high] = match args.try_into() { + Ok(args_arr) => args_arr, + Err(_) => { + return substrait_err!("Expected three arguments for `{fn_name}` expr") + } + }; + + Ok(Expr::Between(Between { + expr: Box::new(expression), + negated: false, + low: Box::new(low), + high: Box::new(high), + })) + } + + //This handles any functions that require custom handling + async fn build_custom_handling_expr( + consumer: &impl SubstraitConsumer, + fn_name: &str, + args: Vec, + ) -> Result { + match fn_name { + "logb" => Self::build_logb_expr(consumer, args).await, + _ => not_impl_err!("Unsupported custom handled expression: {}", fn_name), + } + } + + async fn build_logb_expr( + consumer: &impl SubstraitConsumer, + args: Vec, + ) -> Result { + if args.len() != 2 { + return substrait_err!("Expect two arguments for logb function"); + } + + let mut args = args; + args.swap(0, 1); + + //The equivalent of logb in DataFusion is the log function (which has its arguments in reverse order) + if let Ok(func) = consumer.get_function_registry().udf("log") { + Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( + func.to_owned(), + args, + ))) + } else { + not_impl_err!("Unsupported function name: logb") + } + } +} + +#[cfg(test)] +mod tests { + use super::arg_list_to_binary_op_tree; + use crate::extensions::Extensions; + use crate::logical_plan::consumer::tests::TEST_SESSION_STATE; + use crate::logical_plan::consumer::{DefaultSubstraitConsumer, SubstraitConsumer}; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::common::{DFSchema, Result, ScalarValue}; + use datafusion::logical_expr::{Expr, Operator}; + use insta::assert_snapshot; + use substrait::proto::expression::literal::LiteralType; + use substrait::proto::expression::{Literal, RexType, ScalarFunction}; + use substrait::proto::function_argument::ArgType; + use substrait::proto::{Expression, FunctionArgument}; + + /// Test that large argument lists for binary operations do not crash the consumer + #[tokio::test] + async fn test_binary_op_large_argument_list() -> Result<()> { + // Build substrait extensions (we are using only one function) + let mut extensions = Extensions::default(); + extensions.functions.insert(0, String::from("or:bool_bool")); + // Build substrait consumer + let consumer = DefaultSubstraitConsumer::new(&extensions, &TEST_SESSION_STATE); + + // Build arguments for the function call, this is basically an OR(true, true, ..., true) + let arg = FunctionArgument { + arg_type: Some(ArgType::Value(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: false, + type_variation_reference: 0, + literal_type: Some(LiteralType::Boolean(true)), + })), + })), + }; + let arguments = vec![arg; 50000]; + let func = ScalarFunction { + function_reference: 0, + arguments, + ..Default::default() + }; + // Trivial input schema + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let df_schema = DFSchema::try_from(schema).unwrap(); + + // Consume the expression and ensure we don't crash + let _ = consumer.consume_scalar_function(&func, &df_schema).await?; + Ok(()) + } + + fn int64_literals(integers: &[i64]) -> Vec { + integers + .iter() + .map(|value| Expr::Literal(ScalarValue::Int64(Some(*value)), None)) + .collect() + } + + #[test] + fn arg_list_to_binary_op_tree_1_arg() -> Result<()> { + let expr = arg_list_to_binary_op_tree(Operator::Or, int64_literals(&[1]))?; + assert_snapshot!(expr.to_string(), @"Int64(1)"); + Ok(()) + } + + #[test] + fn arg_list_to_binary_op_tree_2_args() -> Result<()> { + let expr = arg_list_to_binary_op_tree(Operator::Or, int64_literals(&[1, 2]))?; + assert_snapshot!(expr.to_string(), @"Int64(1) OR Int64(2)"); + Ok(()) + } + + #[test] + fn arg_list_to_binary_op_tree_3_args() -> Result<()> { + let expr = arg_list_to_binary_op_tree(Operator::Or, int64_literals(&[1, 2, 3]))?; + assert_snapshot!(expr.to_string(), @"Int64(1) OR Int64(2) OR Int64(3)"); + Ok(()) + } + + #[test] + fn arg_list_to_binary_op_tree_4_args() -> Result<()> { + let expr = + arg_list_to_binary_op_tree(Operator::Or, int64_literals(&[1, 2, 3, 4]))?; + assert_snapshot!(expr.to_string(), @"Int64(1) OR Int64(2) OR Int64(3) OR Int64(4)"); + Ok(()) + } + + //Test that DataFusion can consume scalar functions that have a different name in Substrait + #[tokio::test] + async fn test_substrait_to_df_name_mapping() -> Result<()> { + // Build substrait extensions (we are using only one function) + let mut extensions = Extensions::default(); + //is_nan is one of the functions that has a different name in Substrait (mapping is in substrait_to_df_name()) + extensions.functions.insert(0, String::from("is_nan:fp32")); + // Build substrait consumer + let consumer = DefaultSubstraitConsumer::new(&extensions, &TEST_SESSION_STATE); + + // Build arguments for the function call + let arg = FunctionArgument { + arg_type: Some(ArgType::Value(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: false, + type_variation_reference: 0, + literal_type: Some(LiteralType::Fp32(1.0)), + })), + })), + }; + let arguments = vec![arg]; + let func = ScalarFunction { + function_reference: 0, + arguments, + ..Default::default() + }; + // Trivial input schema + let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); + let df_schema = DFSchema::try_from(schema).unwrap(); + + // Consume the expression and ensure we don't get an error + let _ = consumer.consume_scalar_function(&func, &df_schema).await?; + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/singular_or_list.rs b/datafusion/substrait/src/logical_plan/consumer/expr/singular_or_list.rs new file mode 100644 index 0000000000000..6d44ebcce5908 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/singular_or_list.rs @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{from_substrait_rex_vec, SubstraitConsumer}; +use datafusion::common::DFSchema; +use datafusion::logical_expr::expr::InList; +use datafusion::logical_expr::Expr; +use substrait::proto::expression::SingularOrList; + +pub async fn from_singular_or_list( + consumer: &impl SubstraitConsumer, + expr: &SingularOrList, + input_schema: &DFSchema, +) -> datafusion::common::Result { + let substrait_expr = expr.value.as_ref().unwrap(); + let substrait_list = expr.options.as_ref(); + Ok(Expr::InList(InList { + expr: Box::new( + consumer + .consume_expression(substrait_expr, input_schema) + .await?, + ), + list: from_substrait_rex_vec(consumer, substrait_list, input_schema).await?, + negated: false, + })) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs new file mode 100644 index 0000000000000..917bcc007716b --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{substrait_err, DFSchema, Spans}; +use datafusion::logical_expr::expr::{Exists, InSubquery}; +use datafusion::logical_expr::{Expr, Subquery}; +use std::sync::Arc; +use substrait::proto::expression as substrait_expression; +use substrait::proto::expression::subquery::set_predicate::PredicateOp; +use substrait::proto::expression::subquery::SubqueryType; + +pub async fn from_subquery( + consumer: &impl SubstraitConsumer, + subquery: &substrait_expression::Subquery, + input_schema: &DFSchema, +) -> datafusion::common::Result { + match &subquery.subquery_type { + Some(subquery_type) => match subquery_type { + SubqueryType::InPredicate(in_predicate) => { + if in_predicate.needles.len() != 1 { + substrait_err!("InPredicate Subquery type must have exactly one Needle expression") + } else { + let needle_expr = &in_predicate.needles[0]; + let haystack_expr = &in_predicate.haystack; + if let Some(haystack_expr) = haystack_expr { + let haystack_expr = consumer.consume_rel(haystack_expr).await?; + let outer_refs = haystack_expr.all_out_ref_exprs(); + Ok(Expr::InSubquery(InSubquery { + expr: Box::new( + consumer + .consume_expression(needle_expr, input_schema) + .await?, + ), + subquery: Subquery { + subquery: Arc::new(haystack_expr), + outer_ref_columns: outer_refs, + spans: Spans::new(), + }, + negated: false, + })) + } else { + substrait_err!( + "InPredicate Subquery type must have a Haystack expression" + ) + } + } + } + SubqueryType::Scalar(query) => { + let plan = consumer + .consume_rel(&(query.input.clone()).unwrap_or_default()) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: Spans::new(), + })) + } + SubqueryType::SetPredicate(predicate) => { + match predicate.predicate_op() { + // exist + PredicateOp::Exists => { + let relation = &predicate.tuples; + let plan = consumer + .consume_rel(&relation.clone().unwrap_or_default()) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::Exists(Exists::new( + Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: Spans::new(), + }, + false, + ))) + } + other_type => substrait_err!( + "unimplemented type {other_type:?} for set predicate" + ), + } + } + other_type => { + substrait_err!("Subquery type {other_type:?} not implemented") + } + }, + None => { + substrait_err!("Subquery expression without SubqueryType is not allowed") + } + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs new file mode 100644 index 0000000000000..3399d660df62b --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{ + from_substrait_func_args, from_substrait_rex_vec, from_substrait_sorts, + substrait_fun_name, SubstraitConsumer, +}; +use datafusion::common::{ + not_impl_err, plan_datafusion_err, plan_err, substrait_err, DFSchema, ScalarValue, +}; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::expr::WindowFunctionParams; +use datafusion::logical_expr::{ + expr, Expr, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, +}; +use substrait::proto::expression::window_function::{Bound, BoundsType}; +use substrait::proto::expression::WindowFunction; +use substrait::proto::expression::{ + window_function::bound as SubstraitBound, window_function::bound::Kind as BoundKind, +}; + +pub async fn from_window_function( + consumer: &impl SubstraitConsumer, + window: &WindowFunction, + input_schema: &DFSchema, +) -> datafusion::common::Result { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&window.function_reference) + else { + return plan_err!( + "Window function not found: function reference = {:?}", + window.function_reference + ); + }; + let fn_name = substrait_fun_name(fn_signature); + + // check udwf first, then udaf, then built-in window and aggregate functions + let fun = if let Ok(udwf) = consumer.get_function_registry().udwf(fn_name) { + Ok(WindowFunctionDefinition::WindowUDF(udwf)) + } else if let Ok(udaf) = consumer.get_function_registry().udaf(fn_name) { + Ok(WindowFunctionDefinition::AggregateUDF(udaf)) + } else { + not_impl_err!( + "Window function {} is not supported: function anchor = {:?}", + fn_name, + window.function_reference + ) + }?; + + let mut order_by = + from_substrait_sorts(consumer, &window.sorts, input_schema).await?; + + let bound_units = match BoundsType::try_from(window.bounds_type).map_err(|e| { + plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) + })? { + BoundsType::Rows => WindowFrameUnits::Rows, + BoundsType::Range => WindowFrameUnits::Range, + BoundsType::Unspecified => { + // If the plan does not specify the bounds type, then we use a simple logic to determine the units + // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary + // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row + if order_by.is_empty() { + WindowFrameUnits::Rows + } else { + WindowFrameUnits::Range + } + } + }; + let window_frame = datafusion::logical_expr::WindowFrame::new_bounds( + bound_units, + from_substrait_bound(&window.lower_bound, true)?, + from_substrait_bound(&window.upper_bound, false)?, + ); + + window_frame.regularize_order_bys(&mut order_by)?; + + // Datafusion does not support aggregate functions with no arguments, so + // we inject a dummy argument that does not affect the query, but allows + // us to bypass this limitation. + let args = if fun.name() == "count" && window.arguments.is_empty() { + vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)] + } else { + from_substrait_func_args(consumer, &window.arguments, input_schema).await? + }; + + Ok(Expr::from(expr::WindowFunction { + fun, + params: WindowFunctionParams { + args, + partition_by: from_substrait_rex_vec( + consumer, + &window.partitions, + input_schema, + ) + .await?, + order_by, + window_frame, + filter: None, + null_treatment: None, + distinct: false, + }, + })) +} + +fn from_substrait_bound( + bound: &Option, + is_lower: bool, +) -> datafusion::common::Result { + match bound { + Some(b) => match &b.kind { + Some(k) => match k { + BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => { + Ok(WindowFrameBound::CurrentRow) + } + BoundKind::Preceding(SubstraitBound::Preceding { offset }) => { + if *offset <= 0 { + return plan_err!("Preceding bound must be positive"); + } + Ok(WindowFrameBound::Preceding(ScalarValue::UInt64(Some( + *offset as u64, + )))) + } + BoundKind::Following(SubstraitBound::Following { offset }) => { + if *offset <= 0 { + return plan_err!("Following bound must be positive"); + } + Ok(WindowFrameBound::Following(ScalarValue::UInt64(Some( + *offset as u64, + )))) + } + BoundKind::Unbounded(SubstraitBound::Unbounded {}) => { + if is_lower { + Ok(WindowFrameBound::Preceding(ScalarValue::Null)) + } else { + Ok(WindowFrameBound::Following(ScalarValue::Null)) + } + } + }, + None => substrait_err!("WindowFunction missing Substrait Bound kind"), + }, + None => { + if is_lower { + Ok(WindowFrameBound::Preceding(ScalarValue::Null)) + } else { + Ok(WindowFrameBound::Following(ScalarValue::Null)) + } + } + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/mod.rs b/datafusion/substrait/src/logical_plan/consumer/mod.rs new file mode 100644 index 0000000000000..0e01d6ded6e4e --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/mod.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod expr; +mod plan; +mod rel; +mod substrait_consumer; +mod types; +mod utils; + +pub use expr::*; +pub use plan::*; +pub use rel::*; +pub use substrait_consumer::*; +pub use types::*; +pub use utils::*; diff --git a/datafusion/substrait/src/logical_plan/consumer/plan.rs b/datafusion/substrait/src/logical_plan/consumer/plan.rs new file mode 100644 index 0000000000000..f994f792a17ea --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/plan.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::utils::{make_renamed_schema, rename_expressions}; +use super::{DefaultSubstraitConsumer, SubstraitConsumer}; +use crate::extensions::Extensions; +use datafusion::common::{not_impl_err, plan_err}; +use datafusion::execution::SessionState; +use datafusion::logical_expr::{col, Aggregate, LogicalPlan, Projection}; +use std::sync::Arc; +use substrait::proto::{plan_rel, Plan}; + +/// Convert Substrait Plan to DataFusion LogicalPlan +pub async fn from_substrait_plan( + state: &SessionState, + plan: &Plan, +) -> datafusion::common::Result { + // Register function extension + let extensions = Extensions::try_from(&plan.extensions)?; + if !extensions.type_variations.is_empty() { + return not_impl_err!("Type variation extensions are not supported"); + } + + let consumer = DefaultSubstraitConsumer { + extensions: &extensions, + state, + }; + from_substrait_plan_with_consumer(&consumer, plan).await +} + +/// Convert Substrait Plan to DataFusion LogicalPlan using the given consumer +pub async fn from_substrait_plan_with_consumer( + consumer: &impl SubstraitConsumer, + plan: &Plan, +) -> datafusion::common::Result { + match plan.relations.len() { + 1 => { + match plan.relations[0].rel_type.as_ref() { + Some(rt) => match rt { + plan_rel::RelType::Rel(rel) => Ok(consumer.consume_rel(rel).await?), + plan_rel::RelType::Root(root) => { + let plan = consumer.consume_rel(root.input.as_ref().unwrap()).await?; + if root.names.is_empty() { + // Backwards compatibility for plans missing names + return Ok(plan); + } + let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?; + if renamed_schema.has_equivalent_names_and_types(plan.schema()).is_ok() { + // Nothing to do if the schema is already equivalent + return Ok(plan); + } + match plan { + // If the last node of the plan produces expressions, bake the renames into those expressions. + // This isn't necessary for correctness, but helps with roundtrip tests. + LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema.fields())?, p.input)?)), + LogicalPlan::Aggregate(a) => { + let (group_fields, expr_fields) = renamed_schema.fields().split_at(a.group_expr.len()); + let new_group_exprs = rename_expressions(a.group_expr, a.input.schema(), group_fields)?; + let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), expr_fields)?; + Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, new_group_exprs, new_aggr_exprs)?)) + }, + // There are probably more plans where we could bake things in, can add them later as needed. + // Otherwise, add a new Project to handle the renaming. + _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema.fields())?, Arc::new(plan))?)) + } + } + }, + None => plan_err!("Cannot parse plan relation: None") + } + }, + _ => not_impl_err!( + "Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}", + plan.relations.len() + ) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/aggregate_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/aggregate_rel.rs new file mode 100644 index 0000000000000..c919bd038936d --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/aggregate_rel.rs @@ -0,0 +1,144 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{from_substrait_agg_func, from_substrait_sorts}; +use crate::logical_plan::consumer::{NameTracker, SubstraitConsumer}; +use datafusion::common::{not_impl_err, DFSchemaRef}; +use datafusion::logical_expr::{Expr, GroupingSet, LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::aggregate_function::AggregationInvocation; +use substrait::proto::aggregate_rel::Grouping; +use substrait::proto::AggregateRel; + +pub async fn from_aggregate_rel( + consumer: &impl SubstraitConsumer, + agg: &AggregateRel, +) -> datafusion::common::Result { + if let Some(input) = agg.input.as_ref() { + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); + let mut ref_group_exprs = vec![]; + + for e in &agg.grouping_expressions { + let x = consumer.consume_expression(e, input.schema()).await?; + ref_group_exprs.push(x); + } + + let mut group_exprs = vec![]; + let mut aggr_exprs = vec![]; + + match agg.groupings.len() { + 1 => { + group_exprs.extend_from_slice( + &from_substrait_grouping( + consumer, + &agg.groupings[0], + &ref_group_exprs, + input.schema(), + ) + .await?, + ); + } + _ => { + let mut grouping_sets = vec![]; + for grouping in &agg.groupings { + let grouping_set = from_substrait_grouping( + consumer, + grouping, + &ref_group_exprs, + input.schema(), + ) + .await?; + grouping_sets.push(grouping_set); + } + // Single-element grouping expression of type Expr::GroupingSet. + // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when + // parsed by the producer and consumer, since Substrait does not have a type dedicated + // to ROLLUP. Only vector of Groupings (grouping sets) is available. + group_exprs + .push(Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets))); + } + }; + + for m in &agg.measures { + let filter = match &m.filter { + Some(fil) => Some(Box::new( + consumer.consume_expression(fil, input.schema()).await?, + )), + None => None, + }; + let agg_func = match &m.measure { + Some(f) => { + let distinct = match f.invocation { + _ if f.invocation == AggregationInvocation::Distinct as i32 => { + true + } + _ if f.invocation == AggregationInvocation::All as i32 => false, + _ => false, + }; + let order_by = + from_substrait_sorts(consumer, &f.sorts, input.schema()).await?; + + from_substrait_agg_func( + consumer, + f, + input.schema(), + filter, + order_by, + distinct, + ) + .await + } + None => { + not_impl_err!("Aggregate without aggregate function is not supported") + } + }; + aggr_exprs.push(agg_func?.as_ref().clone()); + } + + // Ensure that all expressions have a unique name + let mut name_tracker = NameTracker::new(); + let group_exprs = group_exprs + .iter() + .map(|e| name_tracker.get_uniquely_named_expr(e.clone())) + .collect::, _>>()?; + + input.aggregate(group_exprs, aggr_exprs)?.build() + } else { + not_impl_err!("Aggregate without an input is not valid") + } +} + +#[allow(deprecated)] +async fn from_substrait_grouping( + consumer: &impl SubstraitConsumer, + grouping: &Grouping, + expressions: &[Expr], + input_schema: &DFSchemaRef, +) -> datafusion::common::Result> { + let mut group_exprs = vec![]; + if !grouping.grouping_expressions.is_empty() { + for e in &grouping.grouping_expressions { + let expr = consumer.consume_expression(e, input_schema).await?; + group_exprs.push(expr); + } + return Ok(group_exprs); + } + for idx in &grouping.expression_references { + let e = &expressions[*idx as usize]; + group_exprs.push(e.clone()); + } + Ok(group_exprs) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/cross_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/cross_rel.rs new file mode 100644 index 0000000000000..25c66a8e22972 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/cross_rel.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; + +use datafusion::logical_expr::requalify_sides_if_needed; + +use substrait::proto::CrossRel; + +pub async fn from_cross_rel( + consumer: &impl SubstraitConsumer, + cross: &CrossRel, +) -> datafusion::common::Result { + let left = LogicalPlanBuilder::from( + consumer.consume_rel(cross.left.as_ref().unwrap()).await?, + ); + let right = LogicalPlanBuilder::from( + consumer.consume_rel(cross.right.as_ref().unwrap()).await?, + ); + let (left, right, _requalified) = requalify_sides_if_needed(left, right)?; + left.cross_join(right.build()?)?.build() +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs new file mode 100644 index 0000000000000..d326fff44bbbd --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::from_substrait_field_reference; +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{not_impl_err, substrait_err}; +use datafusion::logical_expr::{LogicalPlan, Partitioning, Repartition}; +use std::sync::Arc; +use substrait::proto::exchange_rel::ExchangeKind; +use substrait::proto::ExchangeRel; + +pub async fn from_exchange_rel( + consumer: &impl SubstraitConsumer, + exchange: &ExchangeRel, +) -> datafusion::common::Result { + let Some(input) = exchange.input.as_ref() else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + let input = Arc::new(consumer.consume_rel(input).await?); + + let Some(exchange_kind) = &exchange.exchange_kind else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let partitioning_scheme = match exchange_kind { + ExchangeKind::ScatterByFields(scatter_fields) => { + let mut partition_columns = vec![]; + let input_schema = input.schema(); + for field_ref in &scatter_fields.fields { + let column = from_substrait_field_reference(field_ref, input_schema)?; + partition_columns.push(column); + } + Partitioning::Hash(partition_columns, exchange.partition_count as usize) + } + ExchangeKind::RoundRobin(_) => { + Partitioning::RoundRobinBatch(exchange.partition_count as usize) + } + ExchangeKind::SingleTarget(_) + | ExchangeKind::MultiTarget(_) + | ExchangeKind::Broadcast(_) => { + return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); + } + }; + Ok(LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + })) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs new file mode 100644 index 0000000000000..74161d8600ea6 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::SubstraitConsumer; +use async_recursion::async_recursion; +use datafusion::common::{not_impl_err, DFSchema, DFSchemaRef}; +use datafusion::logical_expr::{lit, LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::{fetch_rel, FetchRel}; + +#[async_recursion] +pub async fn from_fetch_rel( + consumer: &impl SubstraitConsumer, + fetch: &FetchRel, +) -> datafusion::common::Result { + if let Some(input) = fetch.input.as_ref() { + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let offset = match &fetch.offset_mode { + Some(fetch_rel::OffsetMode::Offset(offset)) => Some(lit(*offset)), + Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => { + Some(consumer.consume_expression(expr, &empty_schema).await?) + } + None => None, + }; + let count = match &fetch.count_mode { + Some(fetch_rel::CountMode::Count(count)) => { + // -1 means that ALL records should be returned, equivalent to None + (*count != -1).then(|| lit(*count)) + } + Some(fetch_rel::CountMode::CountExpr(expr)) => { + Some(consumer.consume_expression(expr, &empty_schema).await?) + } + None => None, + }; + input.limit_by_expr(offset, count)?.build() + } else { + not_impl_err!("Fetch without an input is not valid") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/filter_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/filter_rel.rs new file mode 100644 index 0000000000000..645b98278208d --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/filter_rel.rs @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::SubstraitConsumer; +use async_recursion::async_recursion; +use datafusion::common::not_impl_err; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::FilterRel; + +#[async_recursion] +pub async fn from_filter_rel( + consumer: &impl SubstraitConsumer, + filter: &FilterRel, +) -> datafusion::common::Result { + if let Some(input) = filter.input.as_ref() { + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); + if let Some(condition) = filter.condition.as_ref() { + let expr = consumer + .consume_expression(condition, input.schema()) + .await?; + input.filter(expr)?.build() + } else { + not_impl_err!("Filter without an condition is not valid") + } + } else { + not_impl_err!("Filter without an input is not valid") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs new file mode 100644 index 0000000000000..5681c92326e1a --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{not_impl_err, plan_err, Column, JoinType, NullEquality}; +use datafusion::logical_expr::requalify_sides_if_needed; +use datafusion::logical_expr::utils::split_conjunction; +use datafusion::logical_expr::{ + BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, +}; + +use substrait::proto::{join_rel, JoinRel}; + +pub async fn from_join_rel( + consumer: &impl SubstraitConsumer, + join: &JoinRel, +) -> datafusion::common::Result { + if join.post_join_filter.is_some() { + return not_impl_err!("JoinRel with post_join_filter is not yet supported"); + } + + let left: LogicalPlanBuilder = LogicalPlanBuilder::from( + consumer.consume_rel(join.left.as_ref().unwrap()).await?, + ); + let right = LogicalPlanBuilder::from( + consumer.consume_rel(join.right.as_ref().unwrap()).await?, + ); + let (left, right, _requalified) = requalify_sides_if_needed(left, right)?; + + let join_type = from_substrait_jointype(join.r#type)?; + // The join condition expression needs full input schema and not the output schema from join since we lose columns from + // certain join types such as semi and anti joins + let in_join_schema = left.schema().join(right.schema())?; + + // If join expression exists, parse the `on` condition expression, build join and return + // Otherwise, build join with only the filter, without join keys + match &join.expression.as_ref() { + Some(expr) => { + let on = consumer.consume_expression(expr, &in_join_schema).await?; + // The join expression can contain both equal and non-equal ops. + // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. + // So we extract each part as follows: + // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector + // - Otherwise we add the expression to join_filter (use conjunction if filter already exists) + let (join_ons, nulls_equal_nulls, join_filter) = + split_eq_and_noneq_join_predicate_with_nulls_equality(&on); + let (left_cols, right_cols): (Vec<_>, Vec<_>) = + itertools::multiunzip(join_ons); + let null_equality = if nulls_equal_nulls { + NullEquality::NullEqualsNull + } else { + NullEquality::NullEqualsNothing + }; + left.join_detailed( + right.build()?, + join_type, + (left_cols, right_cols), + join_filter, + null_equality, + )? + .build() + } + None => { + let on: Vec = vec![]; + left.join_detailed( + right.build()?, + join_type, + (on.clone(), on), + None, + NullEquality::NullEqualsNothing, + )? + .build() + } + } +} + +fn split_eq_and_noneq_join_predicate_with_nulls_equality( + filter: &Expr, +) -> (Vec<(Column, Column)>, bool, Option) { + let exprs = split_conjunction(filter); + + let mut accum_join_keys: Vec<(Column, Column)> = vec![]; + let mut accum_filters: Vec = vec![]; + let mut nulls_equal_nulls = false; + + for expr in exprs { + #[allow(clippy::collapsible_match)] + match expr { + Expr::BinaryExpr(binary_expr) => match binary_expr { + x @ (BinaryExpr { + left, + op: Operator::Eq, + right, + } + | BinaryExpr { + left, + op: Operator::IsNotDistinctFrom, + right, + }) => { + nulls_equal_nulls = match x.op { + Operator::Eq => false, + Operator::IsNotDistinctFrom => true, + _ => unreachable!(), + }; + + match (left.as_ref(), right.as_ref()) { + (Expr::Column(l), Expr::Column(r)) => { + accum_join_keys.push((l.clone(), r.clone())); + } + _ => accum_filters.push(expr.clone()), + } + } + _ => accum_filters.push(expr.clone()), + }, + _ => accum_filters.push(expr.clone()), + } + } + + let join_filter = accum_filters.into_iter().reduce(Expr::and); + (accum_join_keys, nulls_equal_nulls, join_filter) +} + +fn from_substrait_jointype(join_type: i32) -> datafusion::common::Result { + if let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) { + match substrait_join_type { + join_rel::JoinType::Inner => Ok(JoinType::Inner), + join_rel::JoinType::Left => Ok(JoinType::Left), + join_rel::JoinType::Right => Ok(JoinType::Right), + join_rel::JoinType::Outer => Ok(JoinType::Full), + join_rel::JoinType::LeftAnti => Ok(JoinType::LeftAnti), + join_rel::JoinType::LeftSemi => Ok(JoinType::LeftSemi), + join_rel::JoinType::LeftMark => Ok(JoinType::LeftMark), + join_rel::JoinType::RightMark => Ok(JoinType::RightMark), + join_rel::JoinType::RightAnti => Ok(JoinType::RightAnti), + join_rel::JoinType::RightSemi => Ok(JoinType::RightSemi), + _ => plan_err!("unsupported join type {substrait_join_type:?}"), + } + } else { + plan_err!("invalid join type variant {join_type}") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/mod.rs b/datafusion/substrait/src/logical_plan/consumer/rel/mod.rs new file mode 100644 index 0000000000000..a83ddd8997b29 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/mod.rs @@ -0,0 +1,173 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod aggregate_rel; +mod cross_rel; +mod exchange_rel; +mod fetch_rel; +mod filter_rel; +mod join_rel; +mod project_rel; +mod read_rel; +mod set_rel; +mod sort_rel; + +pub use aggregate_rel::*; +pub use cross_rel::*; +pub use exchange_rel::*; +pub use fetch_rel::*; +pub use filter_rel::*; +pub use join_rel::*; +pub use project_rel::*; +pub use read_rel::*; +pub use set_rel::*; +pub use sort_rel::*; + +use crate::logical_plan::consumer::utils::NameTracker; +use crate::logical_plan::consumer::SubstraitConsumer; +use async_recursion::async_recursion; +use datafusion::common::{not_impl_err, substrait_datafusion_err, substrait_err, Column}; +use datafusion::logical_expr::builder::project; +use datafusion::logical_expr::{Expr, LogicalPlan, Projection}; +use std::sync::Arc; +use substrait::proto::rel::RelType; +use substrait::proto::rel_common::{Emit, EmitKind}; +use substrait::proto::{rel_common, Rel, RelCommon}; + +/// Convert Substrait Rel to DataFusion DataFrame +#[async_recursion] +pub async fn from_substrait_rel( + consumer: &impl SubstraitConsumer, + relation: &Rel, +) -> datafusion::common::Result { + let plan: datafusion::common::Result = match &relation.rel_type { + Some(rel_type) => match rel_type { + RelType::Read(rel) => consumer.consume_read(rel).await, + RelType::Filter(rel) => consumer.consume_filter(rel).await, + RelType::Fetch(rel) => consumer.consume_fetch(rel).await, + RelType::Aggregate(rel) => consumer.consume_aggregate(rel).await, + RelType::Sort(rel) => consumer.consume_sort(rel).await, + RelType::Join(rel) => consumer.consume_join(rel).await, + RelType::Project(rel) => consumer.consume_project(rel).await, + RelType::Set(rel) => consumer.consume_set(rel).await, + RelType::ExtensionSingle(rel) => consumer.consume_extension_single(rel).await, + RelType::ExtensionMulti(rel) => consumer.consume_extension_multi(rel).await, + RelType::ExtensionLeaf(rel) => consumer.consume_extension_leaf(rel).await, + RelType::Cross(rel) => consumer.consume_cross(rel).await, + RelType::Window(rel) => { + consumer.consume_consistent_partition_window(rel).await + } + RelType::Exchange(rel) => consumer.consume_exchange(rel).await, + rt => not_impl_err!("{rt:?} rel not supported yet"), + }, + None => return substrait_err!("rel must set rel_type"), + }; + apply_emit_kind(retrieve_rel_common(relation), plan?) +} + +fn apply_emit_kind( + rel_common: Option<&RelCommon>, + plan: LogicalPlan, +) -> datafusion::common::Result { + match retrieve_emit_kind(rel_common) { + EmitKind::Direct(_) => Ok(plan), + EmitKind::Emit(Emit { output_mapping }) => { + // It is valid to reference the same field multiple times in the Emit + // In this case, we need to provide unique names to avoid collisions + let mut name_tracker = NameTracker::new(); + match plan { + // To avoid adding a projection on top of a projection, we apply special case + // handling to flatten Substrait Emits. This is only applicable if none of the + // expressions in the projection are volatile. This is to avoid issues like + // converting a single call of the random() function into multiple calls due to + // duplicate fields in the output_mapping. + LogicalPlan::Projection(proj) if !contains_volatile_expr(&proj) => { + let mut exprs: Vec = vec![]; + for field in output_mapping { + let expr = proj.expr + .get(field as usize) + .ok_or_else(|| substrait_datafusion_err!( + "Emit output field {} cannot be resolved in input schema {}", + field, proj.input.schema() + ))?; + exprs.push(name_tracker.get_uniquely_named_expr(expr.clone())?); + } + + let input = Arc::unwrap_or_clone(proj.input); + project(input, exprs) + } + // Otherwise we just handle the output_mapping as a projection + _ => { + let input_schema = plan.schema(); + + let mut exprs: Vec = vec![]; + for index in output_mapping.into_iter() { + let column = Expr::Column(Column::from( + input_schema.qualified_field(index as usize), + )); + let expr = name_tracker.get_uniquely_named_expr(column)?; + exprs.push(expr); + } + + project(plan, exprs) + } + } + } + } +} + +fn retrieve_rel_common(rel: &Rel) -> Option<&RelCommon> { + match rel.rel_type.as_ref() { + None => None, + Some(rt) => match rt { + RelType::Read(r) => r.common.as_ref(), + RelType::Filter(f) => f.common.as_ref(), + RelType::Fetch(f) => f.common.as_ref(), + RelType::Aggregate(a) => a.common.as_ref(), + RelType::Sort(s) => s.common.as_ref(), + RelType::Join(j) => j.common.as_ref(), + RelType::Project(p) => p.common.as_ref(), + RelType::Set(s) => s.common.as_ref(), + RelType::ExtensionSingle(e) => e.common.as_ref(), + RelType::ExtensionMulti(e) => e.common.as_ref(), + RelType::ExtensionLeaf(e) => e.common.as_ref(), + RelType::Cross(c) => c.common.as_ref(), + RelType::Reference(_) => None, + RelType::Write(w) => w.common.as_ref(), + RelType::Ddl(d) => d.common.as_ref(), + RelType::HashJoin(j) => j.common.as_ref(), + RelType::MergeJoin(j) => j.common.as_ref(), + RelType::NestedLoopJoin(j) => j.common.as_ref(), + RelType::Window(w) => w.common.as_ref(), + RelType::Exchange(e) => e.common.as_ref(), + RelType::Expand(e) => e.common.as_ref(), + RelType::Update(_) => None, + }, + } +} + +fn retrieve_emit_kind(rel_common: Option<&RelCommon>) -> EmitKind { + // the default EmitKind is Direct if it is not set explicitly + let default = EmitKind::Direct(rel_common::Direct {}); + rel_common + .and_then(|rc| rc.emit_kind.as_ref()) + .map_or(default, |ek| ek.clone()) +} + +fn contains_volatile_expr(proj: &Projection) -> bool { + proj.expr.iter().any(|e| e.is_volatile()) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs new file mode 100644 index 0000000000000..239073108ce50 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::utils::NameTracker; +use crate::logical_plan::consumer::SubstraitConsumer; +use async_recursion::async_recursion; +use datafusion::common::{not_impl_err, Column}; +use datafusion::logical_expr::builder::project; +use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +use std::collections::HashSet; +use std::sync::Arc; +use substrait::proto::ProjectRel; + +#[async_recursion] +pub async fn from_project_rel( + consumer: &impl SubstraitConsumer, + p: &ProjectRel, +) -> datafusion::common::Result { + if let Some(input) = p.input.as_ref() { + let input = consumer.consume_rel(input).await?; + let original_schema = Arc::clone(input.schema()); + + // Ensure that all expressions have a unique display name, so that + // validate_unique_names does not fail when constructing the project. + let mut name_tracker = NameTracker::new(); + + // By default, a Substrait Project emits all inputs fields followed by all expressions. + // We build the explicit expressions first, and then the input expressions to avoid + // adding aliases to the explicit expressions (as part of ensuring unique names). + // + // This is helpful for plan visualization and tests, because when DataFusion produces + // Substrait Projects it adds an output mapping that excludes all input columns + // leaving only explicit expressions. + + let mut explicit_exprs: Vec = vec![]; + // For WindowFunctions, we need to wrap them in a Window relation. If there are duplicates, + // we can do the window'ing only once, then the project will duplicate the result. + // Order here doesn't matter since LPB::window_plan sorts the expressions. + let mut window_exprs: HashSet = HashSet::new(); + for expr in &p.expressions { + let e = consumer + .consume_expression(expr, input.clone().schema()) + .await?; + // if the expression is WindowFunction, wrap in a Window relation + if let Expr::WindowFunction(_) = &e { + // Adding the same expression here and in the project below + // works because the project's builder uses columnize_expr(..) + // to transform it into a column reference + window_exprs.insert(e.clone()); + } + // Substrait plans are ordinal based, so they do not provide names for columns. + // Names for columns are generated by Datafusion during conversion, and for literals + // Datafusion produces names based on the literal value. It is possible to construct + // valid Substrait plans that result in duplicated names if the same literal value is + // used in multiple relations. To avoid this issue, we alias literals with unique names. + // The name tracker will ensure that two literals in the same project would have + // unique names but, it does not ensure that if a literal column exists in a previous + // project say before a join that it is deduplicated with respect to those columns. + // See: https://github.com/apache/datafusion/pull/17299 + let maybe_apply_alias = match e { + lit @ Expr::Literal(_, _) => lit.alias(uuid::Uuid::new_v4().to_string()), + _ => e, + }; + explicit_exprs.push(name_tracker.get_uniquely_named_expr(maybe_apply_alias)?); + } + + let input = if !window_exprs.is_empty() { + LogicalPlanBuilder::window_plan(input, window_exprs)? + } else { + input + }; + + let mut final_exprs: Vec = vec![]; + for index in 0..original_schema.fields().len() { + let e = Expr::Column(Column::from(original_schema.qualified_field(index))); + final_exprs.push(name_tracker.get_uniquely_named_expr(e)?); + } + final_exprs.append(&mut explicit_exprs); + project(input, final_exprs) + } else { + not_impl_err!("Projection without an input is not supported") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs new file mode 100644 index 0000000000000..48e93c04bb034 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs @@ -0,0 +1,304 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::from_substrait_literal; +use crate::logical_plan::consumer::from_substrait_named_struct; +use crate::logical_plan::consumer::utils::ensure_schema_compatibility; +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{ + not_impl_err, plan_err, substrait_datafusion_err, substrait_err, DFSchema, + DFSchemaRef, TableReference, +}; +use datafusion::datasource::provider_as_source; +use datafusion::logical_expr::utils::split_conjunction_owned; +use datafusion::logical_expr::{ + EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, Values, +}; +use std::sync::Arc; +use substrait::proto::expression::MaskExpression; +use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; +use substrait::proto::read_rel::ReadType; +use substrait::proto::{Expression, ReadRel}; +use url::Url; + +#[allow(deprecated)] +pub async fn from_read_rel( + consumer: &impl SubstraitConsumer, + read: &ReadRel, +) -> datafusion::common::Result { + async fn read_with_schema( + consumer: &impl SubstraitConsumer, + table_ref: TableReference, + schema: DFSchema, + projection: &Option, + filter: &Option>, + ) -> datafusion::common::Result { + let schema = schema.replace_qualifier(table_ref.clone()); + + let filters = if let Some(f) = filter { + let filter_expr = consumer.consume_expression(f, &schema).await?; + split_conjunction_owned(filter_expr) + } else { + vec![] + }; + + let plan = { + let provider = match consumer.resolve_table_ref(&table_ref).await? { + Some(ref provider) => Arc::clone(provider), + _ => return plan_err!("No table named '{table_ref}'"), + }; + + LogicalPlanBuilder::scan_with_filters( + table_ref, + provider_as_source(Arc::clone(&provider)), + None, + filters, + )? + .build()? + }; + + ensure_schema_compatibility(plan.schema(), schema.clone())?; + + let schema = apply_masking(schema, projection)?; + + apply_projection(plan, schema) + } + + let named_struct = read.base_schema.as_ref().ok_or_else(|| { + substrait_datafusion_err!("No base schema provided for Read Relation") + })?; + + let substrait_schema = from_substrait_named_struct(consumer, named_struct)?; + + match &read.read_type { + Some(ReadType::NamedTable(nt)) => { + let table_reference = match nt.names.len() { + 0 => { + return plan_err!("No table name found in NamedTable"); + } + 1 => TableReference::Bare { + table: nt.names[0].clone().into(), + }, + 2 => TableReference::Partial { + schema: nt.names[0].clone().into(), + table: nt.names[1].clone().into(), + }, + _ => TableReference::Full { + catalog: nt.names[0].clone().into(), + schema: nt.names[1].clone().into(), + table: nt.names[2].clone().into(), + }, + }; + + read_with_schema( + consumer, + table_reference, + substrait_schema, + &read.projection, + &read.filter, + ) + .await + } + Some(ReadType::VirtualTable(vt)) => { + if vt.values.is_empty() && vt.expressions.is_empty() { + return Ok(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(substrait_schema), + })); + } + + let values = if !vt.expressions.is_empty() { + let mut exprs = vec![]; + for row in &vt.expressions { + let mut name_idx = 0; + let mut row_exprs = vec![]; + for expression in &row.fields { + name_idx += 1; + let expr = consumer + .consume_expression(expression, &DFSchema::empty()) + .await?; + row_exprs.push(expr); + } + if name_idx != named_struct.names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + named_struct.names.len() + ); + } + exprs.push(row_exprs); + } + exprs + } else { + vt + .values + .iter() + .map(|row| { + let mut name_idx = 0; + let lits = row + .fields + .iter() + .map(|lit| { + name_idx += 1; // top-level names are provided through schema + Ok(Expr::Literal(from_substrait_literal( + consumer, + lit, + &named_struct.names, + &mut name_idx, + )?, None)) + }) + .collect::>()?; + if name_idx != named_struct.names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + named_struct.names.len() + ); + } + Ok(lits) + }) + .collect::>()? + }; + + Ok(LogicalPlan::Values(Values { + schema: DFSchemaRef::new(substrait_schema), + values, + })) + } + Some(ReadType::LocalFiles(lf)) => { + fn extract_filename(name: &str) -> Option { + let corrected_url = + if name.starts_with("file://") && !name.starts_with("file:///") { + name.replacen("file://", "file:///", 1) + } else { + name.to_string() + }; + + Url::parse(&corrected_url).ok().and_then(|url| { + let path = url.path(); + std::path::Path::new(path) + .file_name() + .map(|filename| filename.to_string_lossy().to_string()) + }) + } + + // we could use the file name to check the original table provider + // TODO: currently does not support multiple local files + let filename: Option = + lf.items.first().and_then(|x| match x.path_type.as_ref() { + Some(UriFile(name)) => extract_filename(name), + _ => None, + }); + + if lf.items.len() > 1 || filename.is_none() { + return not_impl_err!("Only single file reads are supported"); + } + let name = filename.unwrap(); + // directly use unwrap here since we could determine it is a valid one + let table_reference = TableReference::Bare { table: name.into() }; + + read_with_schema( + consumer, + table_reference, + substrait_schema, + &read.projection, + &read.filter, + ) + .await + } + _ => { + not_impl_err!("Unsupported Readtype: {:?}", read.read_type) + } + } +} + +pub fn apply_masking( + schema: DFSchema, + mask_expression: &::core::option::Option, +) -> datafusion::common::Result { + match mask_expression { + Some(MaskExpression { select, .. }) => match &select.as_ref() { + Some(projection) => { + let column_indices: Vec = projection + .struct_items + .iter() + .map(|item| item.field as usize) + .collect(); + + let fields = column_indices + .iter() + .map(|i| schema.qualified_field(*i)) + .map(|(qualifier, field)| { + (qualifier.cloned(), Arc::new(field.clone())) + }) + .collect(); + + Ok(DFSchema::new_with_metadata( + fields, + schema.metadata().clone(), + )?) + } + None => Ok(schema), + }, + None => Ok(schema), + } +} + +/// This function returns a DataFrame with fields adjusted if necessary in the event that the +/// Substrait schema is a subset of the DataFusion schema. +fn apply_projection( + plan: LogicalPlan, + substrait_schema: DFSchema, +) -> datafusion::common::Result { + let df_schema = plan.schema(); + + if df_schema.logically_equivalent_names_and_types(&substrait_schema) { + return Ok(plan); + } + + let df_schema = df_schema.to_owned(); + + match plan { + LogicalPlan::TableScan(mut scan) => { + let column_indices: Vec = substrait_schema + .strip_qualifiers() + .fields() + .iter() + .map(|substrait_field| { + Ok(df_schema + .index_of_column_by_name(None, substrait_field.name().as_str()) + .unwrap()) + }) + .collect::>()?; + + let fields = column_indices + .iter() + .map(|i| df_schema.qualified_field(*i)) + .map(|(qualifier, field)| (qualifier.cloned(), Arc::new(field.clone()))) + .collect(); + + scan.projected_schema = DFSchemaRef::new(DFSchema::new_with_metadata( + fields, + df_schema.metadata().clone(), + )?); + scan.projection = Some(column_indices); + + Ok(LogicalPlan::TableScan(scan)) + } + _ => plan_err!("DataFrame passed to apply_projection must be a TableScan"), + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/set_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/set_rel.rs new file mode 100644 index 0000000000000..6688a80f52746 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/set_rel.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{not_impl_err, substrait_err}; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::set_rel::SetOp; +use substrait::proto::{Rel, SetRel}; + +pub async fn from_set_rel( + consumer: &impl SubstraitConsumer, + set: &SetRel, +) -> datafusion::common::Result { + if set.inputs.len() < 2 { + substrait_err!("Set operation requires at least two inputs") + } else { + match set.op() { + SetOp::UnionAll => union_rels(consumer, &set.inputs, true).await, + SetOp::UnionDistinct => union_rels(consumer, &set.inputs, false).await, + SetOp::IntersectionPrimary => LogicalPlanBuilder::intersect( + consumer.consume_rel(&set.inputs[0]).await?, + union_rels(consumer, &set.inputs[1..], true).await?, + false, + ), + SetOp::IntersectionMultiset => { + intersect_rels(consumer, &set.inputs, false).await + } + SetOp::IntersectionMultisetAll => { + intersect_rels(consumer, &set.inputs, true).await + } + SetOp::MinusPrimary => except_rels(consumer, &set.inputs, false).await, + SetOp::MinusPrimaryAll => except_rels(consumer, &set.inputs, true).await, + set_op => not_impl_err!("Unsupported set operator: {set_op:?}"), + } + } +} + +async fn union_rels( + consumer: &impl SubstraitConsumer, + rels: &[Rel], + is_all: bool, +) -> datafusion::common::Result { + let mut union_builder = Ok(LogicalPlanBuilder::from( + consumer.consume_rel(&rels[0]).await?, + )); + for input in &rels[1..] { + let rel_plan = consumer.consume_rel(input).await?; + + union_builder = if is_all { + union_builder?.union(rel_plan) + } else { + union_builder?.union_distinct(rel_plan) + }; + } + union_builder?.build() +} + +async fn intersect_rels( + consumer: &impl SubstraitConsumer, + rels: &[Rel], + is_all: bool, +) -> datafusion::common::Result { + let mut rel = consumer.consume_rel(&rels[0]).await?; + + for input in &rels[1..] { + rel = LogicalPlanBuilder::intersect( + rel, + consumer.consume_rel(input).await?, + is_all, + )? + } + + Ok(rel) +} + +async fn except_rels( + consumer: &impl SubstraitConsumer, + rels: &[Rel], + is_all: bool, +) -> datafusion::common::Result { + let mut rel = consumer.consume_rel(&rels[0]).await?; + + for input in &rels[1..] { + rel = LogicalPlanBuilder::except(rel, consumer.consume_rel(input).await?, is_all)? + } + + Ok(rel) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/sort_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/sort_rel.rs new file mode 100644 index 0000000000000..56ca0ba03857d --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/sort_rel.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{from_substrait_sorts, SubstraitConsumer}; +use datafusion::common::not_impl_err; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::SortRel; + +pub async fn from_sort_rel( + consumer: &impl SubstraitConsumer, + sort: &SortRel, +) -> datafusion::common::Result { + if let Some(input) = sort.input.as_ref() { + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); + let sorts = from_substrait_sorts(consumer, &sort.sorts, input.schema()).await?; + input.sort(sorts)?.build() + } else { + not_impl_err!("Sort without an input is not valid") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs b/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs new file mode 100644 index 0000000000000..5392dd77b576b --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs @@ -0,0 +1,523 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{ + from_aggregate_rel, from_cast, from_cross_rel, from_exchange_rel, from_fetch_rel, + from_field_reference, from_filter_rel, from_if_then, from_join_rel, from_literal, + from_project_rel, from_read_rel, from_scalar_function, from_set_rel, + from_singular_or_list, from_sort_rel, from_subquery, from_substrait_rel, + from_substrait_rex, from_window_function, +}; +use crate::extensions::Extensions; +use async_trait::async_trait; +use datafusion::arrow::datatypes::DataType; +use datafusion::catalog::TableProvider; +use datafusion::common::{ + not_impl_err, substrait_err, DFSchema, ScalarValue, TableReference, +}; +use datafusion::execution::{FunctionRegistry, SessionState}; +use datafusion::logical_expr::{Expr, Extension, LogicalPlan}; +use std::sync::Arc; +use substrait::proto; +use substrait::proto::expression as substrait_expression; +use substrait::proto::expression::{ + Enum, FieldReference, IfThen, Literal, MultiOrList, Nested, ScalarFunction, + SingularOrList, SwitchExpression, WindowFunction, +}; +use substrait::proto::{ + r#type, AggregateRel, ConsistentPartitionWindowRel, CrossRel, DynamicParameter, + ExchangeRel, Expression, ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, + FetchRel, FilterRel, JoinRel, ProjectRel, ReadRel, Rel, SetRel, SortRel, +}; + +#[async_trait] +/// This trait is used to consume Substrait plans, converting them into DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// Combined with the [crate::logical_plan::producer::SubstraitProducer] this allows for fully +/// customizable Substrait serde. +/// +/// # Example Usage +/// +/// ``` +/// # use async_trait::async_trait; +/// # use datafusion::catalog::TableProvider; +/// # use datafusion::common::{not_impl_err, substrait_err, DFSchema, ScalarValue, TableReference}; +/// # use datafusion::error::Result; +/// # use datafusion::execution::{FunctionRegistry, SessionState}; +/// # use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +/// # use std::sync::Arc; +/// # use substrait::proto; +/// # use substrait::proto::{ExtensionLeafRel, FilterRel, ProjectRel}; +/// # use datafusion::arrow::datatypes::DataType; +/// # use datafusion::logical_expr::expr::ScalarFunction; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::consumer::{ +/// # from_project_rel, from_substrait_rel, from_substrait_rex, SubstraitConsumer +/// # }; +/// +/// struct CustomSubstraitConsumer { +/// extensions: Arc, +/// state: Arc, +/// } +/// +/// #[async_trait] +/// impl SubstraitConsumer for CustomSubstraitConsumer { +/// async fn resolve_table_ref( +/// &self, +/// table_ref: &TableReference, +/// ) -> Result>> { +/// let table = table_ref.table().to_string(); +/// let schema = self.state.schema_for_ref(table_ref.clone())?; +/// let table_provider = schema.table(&table).await?; +/// Ok(table_provider) +/// } +/// +/// fn get_extensions(&self) -> &Extensions { +/// self.extensions.as_ref() +/// } +/// +/// fn get_function_registry(&self) -> &impl FunctionRegistry { +/// self.state.as_ref() +/// } +/// +/// // You can reuse existing consumer code to assist in handling advanced extensions +/// async fn consume_project(&self, rel: &ProjectRel) -> Result { +/// let df_plan = from_project_rel(self, rel).await?; +/// if let Some(advanced_extension) = rel.advanced_extension.as_ref() { +/// not_impl_err!( +/// "decode and handle an advanced extension: {:?}", +/// advanced_extension +/// ) +/// } else { +/// Ok(df_plan) +/// } +/// } +/// +/// // You can implement a fully custom consumer method if you need special handling +/// async fn consume_filter(&self, rel: &FilterRel) -> Result { +/// let input = self.consume_rel(rel.input.as_ref().unwrap()).await?; +/// let expression = +/// self.consume_expression(rel.condition.as_ref().unwrap(), input.schema()) +/// .await?; +/// // though this one is quite boring +/// LogicalPlanBuilder::from(input).filter(expression)?.build() +/// } +/// +/// // You can add handlers for extension relations +/// async fn consume_extension_leaf( +/// &self, +/// rel: &ExtensionLeafRel, +/// ) -> Result { +/// not_impl_err!( +/// "handle protobuf Any {} as you need", +/// rel.detail.as_ref().unwrap().type_url +/// ) +/// } +/// +/// // and handlers for user-define types +/// fn consume_user_defined_type(&self, typ: &proto::r#type::UserDefined) -> Result { +/// let type_string = self.extensions.types.get(&typ.type_reference).unwrap(); +/// match type_string.as_str() { +/// "u!foo" => not_impl_err!("handle foo conversion"), +/// "u!bar" => not_impl_err!("handle bar conversion"), +/// _ => substrait_err!("unexpected type") +/// } +/// } +/// +/// // and user-defined literals +/// fn consume_user_defined_literal(&self, literal: &proto::expression::literal::UserDefined) -> Result { +/// let type_string = self.extensions.types.get(&literal.type_reference).unwrap(); +/// match type_string.as_str() { +/// "u!foo" => not_impl_err!("handle foo conversion"), +/// "u!bar" => not_impl_err!("handle bar conversion"), +/// _ => substrait_err!("unexpected type") +/// } +/// } +/// } +/// ``` +/// +pub trait SubstraitConsumer: Send + Sync + Sized { + async fn resolve_table_ref( + &self, + table_ref: &TableReference, + ) -> datafusion::common::Result>>; + + // TODO: Remove these two methods + // Ideally, the abstract consumer should not place any constraints on implementations. + // The functionality for which the Extensions and FunctionRegistry is needed should be abstracted + // out into methods on the trait. As an example, resolve_table_reference is such a method. + // See: https://github.com/apache/datafusion/issues/13863 + fn get_extensions(&self) -> &Extensions; + fn get_function_registry(&self) -> &impl FunctionRegistry; + + // Relation Methods + // There is one method per Substrait relation to allow for easy overriding of consumer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + /// All [Rel]s to be converted pass through this method. + /// You can provide your own implementation if you wish to customize the conversion behaviour. + async fn consume_rel(&self, rel: &Rel) -> datafusion::common::Result { + from_substrait_rel(self, rel).await + } + + async fn consume_read( + &self, + rel: &ReadRel, + ) -> datafusion::common::Result { + from_read_rel(self, rel).await + } + + async fn consume_filter( + &self, + rel: &FilterRel, + ) -> datafusion::common::Result { + from_filter_rel(self, rel).await + } + + async fn consume_fetch( + &self, + rel: &FetchRel, + ) -> datafusion::common::Result { + from_fetch_rel(self, rel).await + } + + async fn consume_aggregate( + &self, + rel: &AggregateRel, + ) -> datafusion::common::Result { + from_aggregate_rel(self, rel).await + } + + async fn consume_sort( + &self, + rel: &SortRel, + ) -> datafusion::common::Result { + from_sort_rel(self, rel).await + } + + async fn consume_join( + &self, + rel: &JoinRel, + ) -> datafusion::common::Result { + from_join_rel(self, rel).await + } + + async fn consume_project( + &self, + rel: &ProjectRel, + ) -> datafusion::common::Result { + from_project_rel(self, rel).await + } + + async fn consume_set(&self, rel: &SetRel) -> datafusion::common::Result { + from_set_rel(self, rel).await + } + + async fn consume_cross( + &self, + rel: &CrossRel, + ) -> datafusion::common::Result { + from_cross_rel(self, rel).await + } + + async fn consume_consistent_partition_window( + &self, + _rel: &ConsistentPartitionWindowRel, + ) -> datafusion::common::Result { + not_impl_err!("Consistent Partition Window Rel not supported") + } + + async fn consume_exchange( + &self, + rel: &ExchangeRel, + ) -> datafusion::common::Result { + from_exchange_rel(self, rel).await + } + + // Expression Methods + // There is one method per Substrait expression to allow for easy overriding of consumer behaviour + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + /// All [Expression]s to be converted pass through this method. + /// You can provide your own implementation if you wish to customize the conversion behaviour. + async fn consume_expression( + &self, + expr: &Expression, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_substrait_rex(self, expr, input_schema).await + } + + async fn consume_literal(&self, expr: &Literal) -> datafusion::common::Result { + from_literal(self, expr).await + } + + async fn consume_field_reference( + &self, + expr: &FieldReference, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_field_reference(self, expr, input_schema).await + } + + async fn consume_scalar_function( + &self, + expr: &ScalarFunction, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_scalar_function(self, expr, input_schema).await + } + + async fn consume_window_function( + &self, + expr: &WindowFunction, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_window_function(self, expr, input_schema).await + } + + async fn consume_if_then( + &self, + expr: &IfThen, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_if_then(self, expr, input_schema).await + } + + async fn consume_switch( + &self, + _expr: &SwitchExpression, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Switch expression not supported") + } + + async fn consume_singular_or_list( + &self, + expr: &SingularOrList, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_singular_or_list(self, expr, input_schema).await + } + + async fn consume_multi_or_list( + &self, + _expr: &MultiOrList, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Multi Or List expression not supported") + } + + async fn consume_cast( + &self, + expr: &substrait_expression::Cast, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_cast(self, expr, input_schema).await + } + + async fn consume_subquery( + &self, + expr: &substrait_expression::Subquery, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_subquery(self, expr, input_schema).await + } + + async fn consume_nested( + &self, + _expr: &Nested, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Nested expression not supported") + } + + async fn consume_enum( + &self, + _expr: &Enum, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Enum expression not supported") + } + + async fn consume_dynamic_parameter( + &self, + _expr: &DynamicParameter, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Dynamic Parameter expression not supported") + } + + // User-Defined Functionality + + // The details of extension relations, and how to handle them, are fully up to users to specify. + // The following methods allow users to customize the consumer behaviour + + async fn consume_extension_leaf( + &self, + rel: &ExtensionLeafRel, + ) -> datafusion::common::Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionLeafRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionLeafRel") + } + + async fn consume_extension_single( + &self, + rel: &ExtensionSingleRel, + ) -> datafusion::common::Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionSingleRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionSingleRel") + } + + async fn consume_extension_multi( + &self, + rel: &ExtensionMultiRel, + ) -> datafusion::common::Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionMultiRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionMultiRel") + } + + // Users can bring their own types to Substrait which require custom handling + + fn consume_user_defined_type( + &self, + user_defined_type: &r#type::UserDefined, + ) -> datafusion::common::Result { + substrait_err!( + "Missing handler for user-defined type: {}", + user_defined_type.type_reference + ) + } + + fn consume_user_defined_literal( + &self, + user_defined_literal: &proto::expression::literal::UserDefined, + ) -> datafusion::common::Result { + substrait_err!( + "Missing handler for user-defined literals {}", + user_defined_literal.type_reference + ) + } +} + +/// Default SubstraitConsumer for converting standard Substrait without user-defined extensions. +/// +/// Used as the consumer in [crate::logical_plan::consumer::from_substrait_plan] +pub struct DefaultSubstraitConsumer<'a> { + pub(super) extensions: &'a Extensions, + pub(super) state: &'a SessionState, +} + +impl<'a> DefaultSubstraitConsumer<'a> { + pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self { + DefaultSubstraitConsumer { extensions, state } + } +} + +#[async_trait] +impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { + async fn resolve_table_ref( + &self, + table_ref: &TableReference, + ) -> datafusion::common::Result>> { + let table = table_ref.table().to_string(); + let schema = self.state.schema_for_ref(table_ref.clone())?; + let table_provider = schema.table(&table).await?; + Ok(table_provider) + } + + fn get_extensions(&self) -> &Extensions { + self.extensions + } + + fn get_function_registry(&self) -> &impl FunctionRegistry { + self.state + } + + async fn consume_extension_leaf( + &self, + rel: &ExtensionLeafRel, + ) -> datafusion::common::Result { + let Some(ext_detail) = &rel.detail else { + return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); + }; + let plan = self + .state + .serializer_registry() + .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; + Ok(LogicalPlan::Extension(Extension { node: plan })) + } + + async fn consume_extension_single( + &self, + rel: &ExtensionSingleRel, + ) -> datafusion::common::Result { + let Some(ext_detail) = &rel.detail else { + return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); + }; + let plan = self + .state + .serializer_registry() + .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; + let Some(input_rel) = &rel.input else { + return substrait_err!( + "ExtensionSingleRel missing input rel, try using ExtensionLeafRel instead" + ); + }; + let input_plan = self.consume_rel(input_rel).await?; + let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; + Ok(LogicalPlan::Extension(Extension { node: plan })) + } + + async fn consume_extension_multi( + &self, + rel: &ExtensionMultiRel, + ) -> datafusion::common::Result { + let Some(ext_detail) = &rel.detail else { + return substrait_err!("Unexpected empty detail in ExtensionMultiRel"); + }; + let plan = self + .state + .serializer_registry() + .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; + let mut inputs = Vec::with_capacity(rel.inputs.len()); + for input in &rel.inputs { + let input_plan = self.consume_rel(input).await?; + inputs.push(input_plan); + } + let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; + Ok(LogicalPlan::Extension(Extension { node: plan })) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs b/datafusion/substrait/src/logical_plan/consumer/types.rs new file mode 100644 index 0000000000000..80300af24ac4a --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/types.rs @@ -0,0 +1,334 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::utils::{from_substrait_precision, next_struct_field_name, DEFAULT_TIMEZONE}; +use super::SubstraitConsumer; +#[allow(deprecated)] +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF, + DEFAULT_MAP_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + DICTIONARY_MAP_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF, + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, + INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, + TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, + TIMESTAMP_SECOND_TYPE_VARIATION_REF, TIME_32_TYPE_VARIATION_REF, + TIME_64_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; +use datafusion::arrow::datatypes::{ + DataType, Field, Fields, IntervalUnit, Schema, TimeUnit, +}; +use datafusion::common::{ + not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, +}; +use std::sync::Arc; +use substrait::proto::{r#type, NamedStruct, Type}; + +pub(crate) fn from_substrait_type_without_names( + consumer: &impl SubstraitConsumer, + dt: &Type, +) -> datafusion::common::Result { + from_substrait_type(consumer, dt, &[], &mut 0) +} + +pub fn from_substrait_type( + consumer: &impl SubstraitConsumer, + dt: &Type, + dfs_names: &[String], + name_idx: &mut usize, +) -> datafusion::common::Result { + match &dt.kind { + Some(s_kind) => match s_kind { + r#type::Kind::Bool(_) => Ok(DataType::Boolean), + r#type::Kind::I8(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int8), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt8), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::I16(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int16), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt16), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::I32(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt32), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::I64(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int64), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt64), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::Fp32(_) => Ok(DataType::Float32), + r#type::Kind::Fp64(_) => Ok(DataType::Float64), + r#type::Kind::Timestamp(ts) => { + // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead + #[allow(deprecated)] + match ts.type_variation_reference { + TIMESTAMP_SECOND_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Second, None)) + } + TIMESTAMP_MILLI_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) + } + TIMESTAMP_MICRO_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) + } + TIMESTAMP_NANO_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + } + } + r#type::Kind::PrecisionTimestamp(pts) => { + let unit = from_substrait_precision(pts.precision, "PrecisionTimestamp")?; + Ok(DataType::Timestamp(unit, None)) + } + r#type::Kind::PrecisionTimestampTz(pts) => { + let unit = + from_substrait_precision(pts.precision, "PrecisionTimestampTz")?; + Ok(DataType::Timestamp(unit, Some(DEFAULT_TIMEZONE.into()))) + } + r#type::Kind::PrecisionTime(pt) => { + let time_unit = from_substrait_precision(pt.precision, "PrecisionTime")?; + match pt.type_variation_reference { + TIME_32_TYPE_VARIATION_REF => Ok(DataType::Time32(time_unit)), + TIME_64_TYPE_VARIATION_REF => Ok(DataType::Time64(time_unit)), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + } + } + r#type::Kind::Date(date) => match date.type_variation_reference { + DATE_32_TYPE_VARIATION_REF => Ok(DataType::Date32), + DATE_64_TYPE_VARIATION_REF => Ok(DataType::Date64), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::Binary(binary) => match binary.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Binary), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeBinary), + VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::BinaryView), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::FixedBinary(fixed) => { + Ok(DataType::FixedSizeBinary(fixed.length)) + } + r#type::Kind::String(string) => match string.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeUtf8), + VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8View), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::List(list) => { + let inner_type = list.r#type.as_ref().ok_or_else(|| { + substrait_datafusion_err!("List type must have inner type") + })?; + let field = Arc::new(Field::new_list_field( + from_substrait_type(consumer, inner_type, dfs_names, name_idx)?, + // We ignore Substrait's nullability here to match to_substrait_literal + // which always creates nullable lists + true, + )); + match list.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::List(field)), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeList(field)), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + )?, + } + } + r#type::Kind::Map(map) => { + let key_type = map.key.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have key type") + })?; + let value_type = map.value.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have value type") + })?; + let key_type = + from_substrait_type(consumer, key_type, dfs_names, name_idx)?; + let value_type = + from_substrait_type(consumer, value_type, dfs_names, name_idx)?; + + match map.type_variation_reference { + DEFAULT_MAP_TYPE_VARIATION_REF => { + let key_field = Arc::new(Field::new("key", key_type, false)); + let value_field = Arc::new(Field::new("value", value_type, true)); + Ok(DataType::Map( + Arc::new(Field::new_struct( + "entries", + [key_field, value_field], + false, // The inner map field is always non-nullable (Arrow #1697), + )), + false, // whether keys are sorted + )) + } + DICTIONARY_MAP_TYPE_VARIATION_REF => Ok(DataType::Dictionary( + Box::new(key_type), + Box::new(value_type), + )), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + } + } + r#type::Kind::Decimal(d) => match d.type_variation_reference { + DECIMAL_128_TYPE_VARIATION_REF => { + Ok(DataType::Decimal128(d.precision as u8, d.scale as i8)) + } + DECIMAL_256_TYPE_VARIATION_REF => { + Ok(DataType::Decimal256(d.precision as u8, d.scale as i8)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::IntervalYear(_) => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + r#type::Kind::IntervalDay(i) => match i.type_variation_reference { + DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF => { + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + DURATION_INTERVAL_DAY_TYPE_VARIATION_REF => { + let duration_unit = match i.precision { + Some(p) => from_substrait_precision(p, "Duration"), + None => { + not_impl_err!("Missing Substrait precision for Duration") + } + }?; + Ok(DataType::Duration(duration_unit)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::IntervalCompound(_) => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + r#type::Kind::UserDefined(u) => { + if let Ok(data_type) = consumer.consume_user_defined_type(u) { + return Ok(data_type); + } + + // TODO: remove the code below once the producer has been updated + if let Some(name) = consumer.get_extensions().types.get(&u.type_reference) + { + #[allow(deprecated)] + match name.as_ref() { + // Kept for backwards compatibility, producers should use IntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + _ => not_impl_err!( + "Unsupported Substrait user defined type with ref {} and variation {}", + u.type_reference, + u.type_variation_reference + ), + } + } else { + #[allow(deprecated)] + match u.type_reference { + // Kept for backwards compatibility, producers should use IntervalYear instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + // Kept for backwards compatibility, producers should use IntervalDay instead + INTERVAL_DAY_TIME_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + // Kept for backwards compatibility, producers should use IntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + _ => not_impl_err!( + "Unsupported Substrait user defined type with ref {} and variation {}", + u.type_reference, + u.type_variation_reference + ), + } + } + } + r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( + consumer, s, dfs_names, name_idx, + )?)), + r#type::Kind::Varchar(_) => Ok(DataType::Utf8), + r#type::Kind::FixedChar(_) => Ok(DataType::Utf8), + _ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"), + }, + _ => not_impl_err!("`None` Substrait kind is not supported"), + } +} + +/// Convert Substrait NamedStruct to DataFusion DFSchemaRef +pub fn from_substrait_named_struct( + consumer: &impl SubstraitConsumer, + base_schema: &NamedStruct, +) -> datafusion::common::Result { + let mut name_idx = 0; + let fields = from_substrait_struct_type( + consumer, + base_schema.r#struct.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Named struct must contain a struct") + })?, + &base_schema.names, + &mut name_idx, + ); + if name_idx != base_schema.names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + base_schema.names.len() + ); + } + DFSchema::try_from(Schema::new(fields?)) +} + +fn from_substrait_struct_type( + consumer: &impl SubstraitConsumer, + s: &r#type::Struct, + dfs_names: &[String], + name_idx: &mut usize, +) -> datafusion::common::Result { + let mut fields = vec![]; + for (i, f) in s.types.iter().enumerate() { + let field = Field::new( + next_struct_field_name(i, dfs_names, name_idx)?, + from_substrait_type(consumer, f, dfs_names, name_idx)?, + true, // We assume everything to be nullable since that's easier than ensuring it matches + ); + fields.push(field); + } + Ok(fields.into()) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/utils.rs b/datafusion/substrait/src/logical_plan/consumer/utils.rs new file mode 100644 index 0000000000000..f7eedcb7a2b25 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/utils.rs @@ -0,0 +1,643 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit, UnionFields}; +use datafusion::common::{ + exec_err, not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, + DFSchemaRef, +}; +use datafusion::logical_expr::expr::Sort; +use datafusion::logical_expr::{Cast, Expr, ExprSchemable}; +use std::collections::HashSet; +use std::sync::Arc; +use substrait::proto::sort_field::SortDirection; +use substrait::proto::sort_field::SortKind::{ComparisonFunctionReference, Direction}; +use substrait::proto::SortField; + +// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which +// is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone +// results in correct points on the timeline, and we pick UTC as a reasonable default. +// However, DF uses the timezone also for some arithmetic and display purposes (see e.g. +// https://github.com/apache/arrow-rs/blob/ee5694078c86c8201549654246900a4232d531a9/arrow-cast/src/cast/mod.rs#L1749). +pub(super) const DEFAULT_TIMEZONE: &str = "UTC"; + +pub(super) fn next_struct_field_name( + column_idx: usize, + dfs_names: &[String], + name_idx: &mut usize, +) -> datafusion::common::Result { + if dfs_names.is_empty() { + // If names are not given, create dummy names + // c0, c1, ... align with e.g. SqlToRel::create_named_struct + Ok(format!("c{column_idx}")) + } else { + let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| { + substrait_datafusion_err!("Named schema must contain names for all fields") + })?; + *name_idx += 1; + Ok(name) + } +} + +/// Traverse through the field, renaming the provided field itself and all its inner struct fields. +pub fn rename_field( + field: &Field, + dfs_names: &Vec, + unnamed_field_suffix: usize, // If Substrait doesn't provide a name, we'll use this "c{unnamed_field_suffix}" + name_idx: &mut usize, // Index into dfs_names +) -> datafusion::common::Result { + let name = next_struct_field_name(unnamed_field_suffix, dfs_names, name_idx)?; + rename_fields_data_type(field.clone().with_name(name), dfs_names, name_idx) +} + +/// Rename the field's data type but not the field itself. +pub fn rename_fields_data_type( + field: Field, + dfs_names: &Vec, + name_idx: &mut usize, // Index into dfs_names +) -> datafusion::common::Result { + let dt = rename_data_type(field.data_type(), dfs_names, name_idx)?; + Ok(field.with_data_type(dt)) +} + +/// Traverse through the data type (incl. lists/maps/etc), renaming all inner struct fields. +pub fn rename_data_type( + data_type: &DataType, + dfs_names: &Vec, + name_idx: &mut usize, // Index into dfs_names +) -> datafusion::common::Result { + match data_type { + DataType::Struct(children) => { + let children = children + .iter() + .enumerate() + .map(|(field_idx, f)| { + rename_field(f.as_ref(), dfs_names, field_idx, name_idx) + }) + .collect::>()?; + Ok(DataType::Struct(children)) + } + DataType::List(inner) => Ok(DataType::List(Arc::new(rename_fields_data_type( + inner.as_ref().to_owned(), + dfs_names, + name_idx, + )?))), + DataType::LargeList(inner) => Ok(DataType::LargeList(Arc::new( + rename_fields_data_type(inner.as_ref().to_owned(), dfs_names, name_idx)?, + ))), + DataType::ListView(inner) => Ok(DataType::ListView(Arc::new( + rename_fields_data_type(inner.as_ref().to_owned(), dfs_names, name_idx)?, + ))), + DataType::LargeListView(inner) => Ok(DataType::LargeListView(Arc::new( + rename_fields_data_type(inner.as_ref().to_owned(), dfs_names, name_idx)?, + ))), + DataType::FixedSizeList(inner, len) => Ok(DataType::FixedSizeList( + Arc::new(rename_fields_data_type( + inner.as_ref().to_owned(), + dfs_names, + name_idx, + )?), + *len, + )), + DataType::Map(entries, sorted) => { + let entries_data_type = match entries.data_type() { + DataType::Struct(fields) => { + // This should be two fields, normally "key" and "value", but not guaranteed + let fields = fields + .iter() + .map(|f| { + rename_fields_data_type( + f.as_ref().to_owned(), + dfs_names, + name_idx, + ) + }) + .collect::>()?; + Ok(DataType::Struct(fields)) + } + _ => exec_err!("Expected map type to contain an inner struct type"), + }?; + Ok(DataType::Map( + Arc::new( + entries + .as_ref() + .to_owned() + .with_data_type(entries_data_type), + ), + *sorted, + )) + } + DataType::Dictionary(key_type, value_type) => { + // Dicts probably shouldn't contain structs, but support them just in case one does + Ok(DataType::Dictionary( + Box::new(rename_data_type(key_type, dfs_names, name_idx)?), + Box::new(rename_data_type(value_type, dfs_names, name_idx)?), + )) + } + DataType::RunEndEncoded(run_ends_field, values_field) => { + // At least the run_ends_field shouldn't contain names (since it should be i16/i32/i64), + // but we'll try renaming its datatype just in case. + let run_ends_field = rename_fields_data_type( + run_ends_field.as_ref().clone(), + dfs_names, + name_idx, + )?; + let values_field = rename_fields_data_type( + values_field.as_ref().clone(), + dfs_names, + name_idx, + )?; + + Ok(DataType::RunEndEncoded( + Arc::new(run_ends_field), + Arc::new(values_field), + )) + } + DataType::Union(fields, mode) => { + let fields = fields + .iter() + .map(|(i, f)| { + Ok(( + i, + Arc::new(rename_fields_data_type( + f.as_ref().clone(), + dfs_names, + name_idx, + )?), + )) + }) + .collect::>()?; + Ok(DataType::Union(fields, *mode)) + } + // Explicitly listing the rest (which can not contain inner fields needing renaming) + // to ensure we're exhaustive + DataType::Null + | DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Timestamp(_, _) + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Interval(_) + | DataType::Binary + | DataType::FixedSizeBinary(_) + | DataType::LargeBinary + | DataType::BinaryView + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Utf8View + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => Ok(data_type.clone()), + } +} + +/// Produce a version of the given schema with names matching the given list of names. +/// Substrait doesn't deal with column (incl. nested struct field) names within the schema, +/// but it does give us the list of expected names at the end of the plan, so we use this +/// to rename the schema to match the expected names. +pub(super) fn make_renamed_schema( + schema: &DFSchemaRef, + dfs_names: &Vec, +) -> datafusion::common::Result { + let mut name_idx = 0; + + let (qualifiers, fields): (_, Vec) = schema + .iter() + .enumerate() + .map(|(field_idx, (q, f))| { + let renamed_f = + rename_field(f.as_ref(), dfs_names, field_idx, &mut name_idx)?; + Ok((q.cloned(), renamed_f)) + }) + .collect::>>()? + .into_iter() + .unzip(); + + if name_idx != dfs_names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + dfs_names.len()); + } + + DFSchema::from_field_specific_qualified_schema( + qualifiers, + &Arc::new(Schema::new(fields)), + ) +} + +/// Ensure the expressions have the right name(s) according to the new schema. +/// This includes the top-level (column) name, which will be renamed through aliasing if needed, +/// as well as nested names (if the expression produces any struct types), which will be renamed +/// through casting if needed. +pub(super) fn rename_expressions( + exprs: impl IntoIterator, + input_schema: &DFSchema, + new_schema_fields: &[Arc], +) -> datafusion::common::Result> { + exprs + .into_iter() + .zip(new_schema_fields) + .map(|(old_expr, new_field)| { + // Check if type (i.e. nested struct field names) match, use Cast to rename if needed + let new_expr = if &old_expr.get_type(input_schema)? != new_field.data_type() { + Expr::Cast(Cast::new( + Box::new(old_expr), + new_field.data_type().to_owned(), + )) + } else { + old_expr + }; + // Alias column if needed to fix the top-level name + match &new_expr { + // If expr is a column reference, alias_if_changed would cause an aliasing if the old expr has a qualifier + Expr::Column(c) if &c.name == new_field.name() => Ok(new_expr), + _ => new_expr.alias_if_changed(new_field.name().to_owned()), + } + }) + .collect() +} + +/// Ensures that the given Substrait schema is compatible with the schema as given by DataFusion +/// +/// This means: +/// 1. All fields present in the Substrait schema are present in the DataFusion schema. The +/// DataFusion schema may have MORE fields, but not the other way around. +/// 2. All fields are compatible. See [`ensure_field_compatibility`] for details +pub(super) fn ensure_schema_compatibility( + table_schema: &DFSchema, + substrait_schema: DFSchema, +) -> datafusion::common::Result<()> { + substrait_schema + .strip_qualifiers() + .fields() + .iter() + .try_for_each(|substrait_field| { + let df_field = + table_schema.field_with_unqualified_name(substrait_field.name())?; + ensure_field_compatibility(df_field, substrait_field) + }) +} + +/// Ensures that the given Substrait field is compatible with the given DataFusion field +/// +/// A field is compatible between Substrait and DataFusion if: +/// 1. They have logically equivalent types. +/// 2. They have the same nullability OR the Substrait field is nullable and the DataFusion fields +/// is not nullable. +/// +/// If a Substrait field is not nullable, the Substrait plan may be built around assuming it is not +/// nullable. As such if DataFusion has that field as nullable the plan should be rejected. +fn ensure_field_compatibility( + datafusion_field: &Field, + substrait_field: &Field, +) -> datafusion::common::Result<()> { + if !DFSchema::datatype_is_logically_equal( + datafusion_field.data_type(), + substrait_field.data_type(), + ) { + return substrait_err!( + "Field '{}' in Substrait schema has a different type ({}) than the corresponding field in the table schema ({}).", + substrait_field.name(), + substrait_field.data_type(), + datafusion_field.data_type() + ); + } + + if !compatible_nullabilities( + datafusion_field.is_nullable(), + substrait_field.is_nullable(), + ) { + // TODO: from_substrait_struct_type needs to be updated to set the nullability correctly. It defaults to true for now. + return substrait_err!( + "Field '{}' is nullable in the DataFusion schema but not nullable in the Substrait schema.", + substrait_field.name() + ); + } + Ok(()) +} + +/// Returns true if the DataFusion and Substrait nullabilities are compatible, false otherwise +fn compatible_nullabilities( + datafusion_nullability: bool, + substrait_nullability: bool, +) -> bool { + // DataFusion and Substrait have the same nullability + (datafusion_nullability == substrait_nullability) + // DataFusion is not nullable and Substrait is nullable + || (!datafusion_nullability && substrait_nullability) +} + +pub(super) struct NameTracker { + seen_names: HashSet, +} + +pub(super) enum NameTrackerStatus { + NeverSeen, + SeenBefore, +} + +impl NameTracker { + pub(super) fn new() -> Self { + NameTracker { + seen_names: HashSet::default(), + } + } + pub(super) fn get_unique_name( + &mut self, + name: String, + ) -> (String, NameTrackerStatus) { + match self.seen_names.insert(name.clone()) { + true => (name, NameTrackerStatus::NeverSeen), + false => { + let mut counter = 0; + loop { + let candidate_name = format!("{name}__temp__{counter}"); + if self.seen_names.insert(candidate_name.clone()) { + return (candidate_name, NameTrackerStatus::SeenBefore); + } + counter += 1; + } + } + } + } + + pub(super) fn get_uniquely_named_expr( + &mut self, + expr: Expr, + ) -> datafusion::common::Result { + match self.get_unique_name(expr.name_for_alias()?) { + (_, NameTrackerStatus::NeverSeen) => Ok(expr), + (name, NameTrackerStatus::SeenBefore) => Ok(expr.alias(name)), + } + } +} + +/// Convert Substrait Sorts to DataFusion Exprs +pub async fn from_substrait_sorts( + consumer: &impl SubstraitConsumer, + substrait_sorts: &Vec, + input_schema: &DFSchema, +) -> datafusion::common::Result> { + let mut sorts: Vec = vec![]; + for s in substrait_sorts { + let expr = consumer + .consume_expression(s.expr.as_ref().unwrap(), input_schema) + .await?; + let asc_nullfirst = match &s.sort_kind { + Some(k) => match k { + Direction(d) => { + let Ok(direction) = SortDirection::try_from(*d) else { + return not_impl_err!( + "Unsupported Substrait SortDirection value {d}" + ); + }; + + match direction { + SortDirection::AscNullsFirst => Ok((true, true)), + SortDirection::AscNullsLast => Ok((true, false)), + SortDirection::DescNullsFirst => Ok((false, true)), + SortDirection::DescNullsLast => Ok((false, false)), + SortDirection::Clustered => not_impl_err!( + "Sort with direction clustered is not yet supported" + ), + SortDirection::Unspecified => { + not_impl_err!("Unspecified sort direction is invalid") + } + } + } + ComparisonFunctionReference(_) => not_impl_err!( + "Sort using comparison function reference is not supported" + ), + }, + None => not_impl_err!("Sort without sort kind is invalid"), + }; + let (asc, nulls_first) = asc_nullfirst.unwrap(); + sorts.push(Sort { + expr, + asc, + nulls_first, + }); + } + Ok(sorts) +} + +pub(crate) fn from_substrait_precision( + precision: i32, + type_name: &str, +) -> datafusion::common::Result { + match precision { + 0 => Ok(TimeUnit::Second), + 3 => Ok(TimeUnit::Millisecond), + 6 => Ok(TimeUnit::Microsecond), + 9 => Ok(TimeUnit::Nanosecond), + precision => { + not_impl_err!("Unsupported Substrait precision {precision}, for {type_name}") + } + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::make_renamed_schema; + use crate::extensions::Extensions; + use crate::logical_plan::consumer::DefaultSubstraitConsumer; + use datafusion::arrow::datatypes::{DataType, Field}; + use datafusion::common::DFSchema; + use datafusion::error::Result; + use datafusion::execution::SessionState; + use datafusion::prelude::SessionContext; + use datafusion::sql::TableReference; + use std::collections::HashMap; + use std::sync::{Arc, LazyLock}; + + pub(crate) static TEST_SESSION_STATE: LazyLock = + LazyLock::new(|| SessionContext::default().state()); + pub(crate) static TEST_EXTENSIONS: LazyLock = + LazyLock::new(Extensions::default); + pub(crate) fn test_consumer() -> DefaultSubstraitConsumer<'static> { + let extensions = &TEST_EXTENSIONS; + let state = &TEST_SESSION_STATE; + DefaultSubstraitConsumer::new(extensions, state) + } + + #[tokio::test] + async fn rename_schema() -> Result<()> { + let table_ref = TableReference::bare("test"); + let fields = vec![ + ( + Some(table_ref.clone()), + Arc::new(Field::new("0", DataType::Int32, false)), + ), + ( + Some(table_ref.clone()), + Arc::new(Field::new_struct( + "1", + vec![ + Field::new("2", DataType::Int32, false), + Field::new_struct( + "3", + vec![Field::new("4", DataType::Int32, false)], + false, + ), + ], + false, + )), + ), + ( + Some(table_ref.clone()), + Arc::new(Field::new_list( + "5", + Arc::new(Field::new_struct( + "item", + vec![Field::new("6", DataType::Int32, false)], + false, + )), + false, + )), + ), + ( + Some(table_ref.clone()), + Arc::new(Field::new_large_list( + "7", + Arc::new(Field::new_struct( + "item", + vec![Field::new("8", DataType::Int32, false)], + false, + )), + false, + )), + ), + ( + Some(table_ref.clone()), + Arc::new(Field::new_map( + "9", + "entries", + Arc::new(Field::new_struct( + "keys", + vec![Field::new("10", DataType::Int32, false)], + false, + )), + Arc::new(Field::new_struct( + "values", + vec![Field::new("11", DataType::Int32, false)], + false, + )), + false, + false, + )), + ), + ]; + + let schema = Arc::new(DFSchema::new_with_metadata(fields, HashMap::default())?); + let dfs_names = vec![ + "a".to_string(), + "b".to_string(), + "c".to_string(), + "d".to_string(), + "e".to_string(), + "f".to_string(), + "g".to_string(), + "h".to_string(), + "i".to_string(), + "j".to_string(), + "k".to_string(), + "l".to_string(), + ]; + let renamed_schema = make_renamed_schema(&schema, &dfs_names)?; + + assert_eq!(renamed_schema.fields().len(), 5); + assert_eq!( + *renamed_schema.field(0), + Field::new("a", DataType::Int32, false) + ); + assert_eq!( + *renamed_schema.field(1), + Field::new_struct( + "b", + vec![ + Field::new("c", DataType::Int32, false), + Field::new_struct( + "d", + vec![Field::new("e", DataType::Int32, false)], + false, + ) + ], + false, + ) + ); + assert_eq!( + *renamed_schema.field(2), + Field::new_list( + "f", + Arc::new(Field::new_struct( + "item", + vec![Field::new("g", DataType::Int32, false)], + false, + )), + false, + ) + ); + assert_eq!( + *renamed_schema.field(3), + Field::new_large_list( + "h", + Arc::new(Field::new_struct( + "item", + vec![Field::new("i", DataType::Int32, false)], + false, + )), + false, + ) + ); + assert_eq!( + *renamed_schema.field(4), + Field::new_map( + "j", + "entries", + Arc::new(Field::new_struct( + "keys", + vec![Field::new("k", DataType::Int32, false)], + false, + )), + Arc::new(Field::new_struct( + "values", + vec![Field::new("l", DataType::Int32, false)], + false, + )), + false, + false, + ) + ); + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs deleted file mode 100644 index 07bf0cb96aa33..0000000000000 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ /dev/null @@ -1,2915 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::Arc; -use substrait::proto::expression_reference::ExprType; - -use datafusion::arrow::datatypes::{Field, IntervalUnit}; -use datafusion::logical_expr::{ - Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit, - Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan, - TryCast, Union, Values, Window, WindowFrameUnits, -}; -use datafusion::{ - arrow::datatypes::{DataType, TimeUnit}, - error::{DataFusionError, Result}, - logical_expr::{WindowFrame, WindowFrameBound}, - prelude::JoinType, - scalar::ScalarValue, -}; - -use crate::extensions::Extensions; -use crate::variation_const::{ - DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, - DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, - VIEW_CONTAINER_TYPE_VARIATION_REF, -}; -use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; -use datafusion::arrow::temporal_conversions::NANOSECONDS; -use datafusion::common::{ - exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, - substrait_err, Column, DFSchema, DFSchemaRef, ToDFSchema, -}; -use datafusion::execution::registry::SerializerRegistry; -use datafusion::execution::SessionState; -use datafusion::logical_expr::expr::{ - AggregateFunctionParams, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, - InSubquery, WindowFunction, WindowFunctionParams, -}; -use datafusion::logical_expr::utils::conjunction; -use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; -use datafusion::prelude::Expr; -use pbjson_types::Any as ProtoAny; -use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; -use substrait::proto::expression::cast::FailureBehavior; -use substrait::proto::expression::field_reference::{RootReference, RootType}; -use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; -use substrait::proto::expression::literal::map::KeyValue; -use substrait::proto::expression::literal::{ - IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, List, Map, - PrecisionTimestamp, Struct, -}; -use substrait::proto::expression::subquery::InPredicate; -use substrait::proto::expression::window_function::BoundsType; -use substrait::proto::expression::ScalarFunction; -use substrait::proto::read_rel::VirtualTable; -use substrait::proto::rel_common::EmitKind; -use substrait::proto::rel_common::EmitKind::Emit; -use substrait::proto::{ - fetch_rel, rel_common, ExchangeRel, ExpressionReference, ExtendedExpression, - RelCommon, -}; -use substrait::{ - proto::{ - aggregate_function::AggregationInvocation, - aggregate_rel::{Grouping, Measure}, - expression::{ - field_reference::ReferenceType, - if_then::IfClause, - literal::{Decimal, LiteralType}, - mask_expression::{StructItem, StructSelect}, - reference_segment, - window_function::bound as SubstraitBound, - window_function::bound::Kind as BoundKind, - window_function::Bound, - FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, - SingularOrList, WindowFunction as SubstraitWindowFunction, - }, - function_argument::ArgType, - join_rel, plan_rel, r#type, - read_rel::{NamedTable, ReadType}, - rel::RelType, - set_rel, - sort_field::{SortDirection, SortKind}, - AggregateFunction, AggregateRel, AggregationPhase, Expression, ExtensionLeafRel, - ExtensionMultiRel, ExtensionSingleRel, FetchRel, FilterRel, FunctionArgument, - JoinRel, NamedStruct, Plan, PlanRel, ProjectRel, ReadRel, Rel, RelRoot, SetRel, - SortField, SortRel, - }, - version, -}; - -/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. -/// It can be implemented by users to allow for custom handling of relations, expressions, etc. -/// -/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully -/// customizable Substrait serde. -/// -/// # Example Usage -/// -/// ``` -/// # use std::sync::Arc; -/// # use substrait::proto::{Expression, Rel}; -/// # use substrait::proto::rel::RelType; -/// # use datafusion::common::DFSchemaRef; -/// # use datafusion::error::Result; -/// # use datafusion::execution::SessionState; -/// # use datafusion::logical_expr::{Between, Extension, Projection}; -/// # use datafusion_substrait::extensions::Extensions; -/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer}; -/// -/// struct CustomSubstraitProducer { -/// extensions: Extensions, -/// state: Arc, -/// } -/// -/// impl SubstraitProducer for CustomSubstraitProducer { -/// -/// fn register_function(&mut self, signature: String) -> u32 { -/// self.extensions.register_function(signature) -/// } -/// -/// fn get_extensions(self) -> Extensions { -/// self.extensions -/// } -/// -/// // You can set additional metadata on the Rels you produce -/// fn handle_projection(&mut self, plan: &Projection) -> Result> { -/// let mut rel = from_projection(self, plan)?; -/// match rel.rel_type { -/// Some(RelType::Project(mut project)) => { -/// let mut project = project.clone(); -/// // set common metadata or advanced extension -/// project.common = None; -/// project.advanced_extension = None; -/// Ok(Box::new(Rel { -/// rel_type: Some(RelType::Project(project)), -/// })) -/// } -/// rel_type => Ok(Box::new(Rel { rel_type })), -/// } -/// } -/// -/// // You can tweak how you convert expressions for your target system -/// fn handle_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result { -/// // add your own encoding for Between -/// todo!() -/// } -/// -/// // You can fully control how you convert UserDefinedLogicalNodes into Substrait -/// fn handle_extension(&mut self, _plan: &Extension) -> Result> { -/// // implement your own serializer into Substrait -/// todo!() -/// } -/// } -/// ``` -pub trait SubstraitProducer: Send + Sync + Sized { - /// Within a Substrait plan, functions are referenced using function anchors that are stored at - /// the top level of the [Plan] within - /// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction) - /// messages. - /// - /// When given a function signature, this method should return the existing anchor for it if - /// there is one. Otherwise, it should generate a new anchor. - fn register_function(&mut self, signature: String) -> u32; - - /// Consume the producer to generate the [Extensions] for the Substrait plan based on the - /// functions that have been registered - fn get_extensions(self) -> Extensions; - - // Logical Plan Methods - // There is one method per LogicalPlan to allow for easy overriding of producer behaviour. - // These methods have default implementations calling the common handler code, to allow for users - // to re-use common handling logic. - - fn handle_plan(&mut self, plan: &LogicalPlan) -> Result> { - to_substrait_rel(self, plan) - } - - fn handle_projection(&mut self, plan: &Projection) -> Result> { - from_projection(self, plan) - } - - fn handle_filter(&mut self, plan: &Filter) -> Result> { - from_filter(self, plan) - } - - fn handle_window(&mut self, plan: &Window) -> Result> { - from_window(self, plan) - } - - fn handle_aggregate(&mut self, plan: &Aggregate) -> Result> { - from_aggregate(self, plan) - } - - fn handle_sort(&mut self, plan: &Sort) -> Result> { - from_sort(self, plan) - } - - fn handle_join(&mut self, plan: &Join) -> Result> { - from_join(self, plan) - } - - fn handle_repartition(&mut self, plan: &Repartition) -> Result> { - from_repartition(self, plan) - } - - fn handle_union(&mut self, plan: &Union) -> Result> { - from_union(self, plan) - } - - fn handle_table_scan(&mut self, plan: &TableScan) -> Result> { - from_table_scan(self, plan) - } - - fn handle_empty_relation(&mut self, plan: &EmptyRelation) -> Result> { - from_empty_relation(plan) - } - - fn handle_subquery_alias(&mut self, plan: &SubqueryAlias) -> Result> { - from_subquery_alias(self, plan) - } - - fn handle_limit(&mut self, plan: &Limit) -> Result> { - from_limit(self, plan) - } - - fn handle_values(&mut self, plan: &Values) -> Result> { - from_values(self, plan) - } - - fn handle_distinct(&mut self, plan: &Distinct) -> Result> { - from_distinct(self, plan) - } - - fn handle_extension(&mut self, _plan: &Extension) -> Result> { - substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait") - } - - // Expression Methods - // There is one method per DataFusion Expr to allow for easy overriding of producer behaviour - // These methods have default implementations calling the common handler code, to allow for users - // to re-use common handling logic. - - fn handle_expr(&mut self, expr: &Expr, schema: &DFSchemaRef) -> Result { - to_substrait_rex(self, expr, schema) - } - - fn handle_alias( - &mut self, - alias: &Alias, - schema: &DFSchemaRef, - ) -> Result { - from_alias(self, alias, schema) - } - - fn handle_column( - &mut self, - column: &Column, - schema: &DFSchemaRef, - ) -> Result { - from_column(column, schema) - } - - fn handle_literal(&mut self, value: &ScalarValue) -> Result { - from_literal(self, value) - } - - fn handle_binary_expr( - &mut self, - expr: &BinaryExpr, - schema: &DFSchemaRef, - ) -> Result { - from_binary_expr(self, expr, schema) - } - - fn handle_like(&mut self, like: &Like, schema: &DFSchemaRef) -> Result { - from_like(self, like, schema) - } - - /// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative - fn handle_unary_expr( - &mut self, - expr: &Expr, - schema: &DFSchemaRef, - ) -> Result { - from_unary_expr(self, expr, schema) - } - - fn handle_between( - &mut self, - between: &Between, - schema: &DFSchemaRef, - ) -> Result { - from_between(self, between, schema) - } - - fn handle_case(&mut self, case: &Case, schema: &DFSchemaRef) -> Result { - from_case(self, case, schema) - } - - fn handle_cast(&mut self, cast: &Cast, schema: &DFSchemaRef) -> Result { - from_cast(self, cast, schema) - } - - fn handle_try_cast( - &mut self, - cast: &TryCast, - schema: &DFSchemaRef, - ) -> Result { - from_try_cast(self, cast, schema) - } - - fn handle_scalar_function( - &mut self, - scalar_fn: &expr::ScalarFunction, - schema: &DFSchemaRef, - ) -> Result { - from_scalar_function(self, scalar_fn, schema) - } - - fn handle_aggregate_function( - &mut self, - agg_fn: &expr::AggregateFunction, - schema: &DFSchemaRef, - ) -> Result { - from_aggregate_function(self, agg_fn, schema) - } - - fn handle_window_function( - &mut self, - window_fn: &WindowFunction, - schema: &DFSchemaRef, - ) -> Result { - from_window_function(self, window_fn, schema) - } - - fn handle_in_list( - &mut self, - in_list: &InList, - schema: &DFSchemaRef, - ) -> Result { - from_in_list(self, in_list, schema) - } - - fn handle_in_subquery( - &mut self, - in_subquery: &InSubquery, - schema: &DFSchemaRef, - ) -> Result { - from_in_subquery(self, in_subquery, schema) - } -} - -pub struct DefaultSubstraitProducer<'a> { - extensions: Extensions, - serializer_registry: &'a dyn SerializerRegistry, -} - -impl<'a> DefaultSubstraitProducer<'a> { - pub fn new(state: &'a SessionState) -> Self { - DefaultSubstraitProducer { - extensions: Extensions::default(), - serializer_registry: state.serializer_registry().as_ref(), - } - } -} - -impl SubstraitProducer for DefaultSubstraitProducer<'_> { - fn register_function(&mut self, fn_name: String) -> u32 { - self.extensions.register_function(fn_name) - } - - fn get_extensions(self) -> Extensions { - self.extensions - } - - fn handle_extension(&mut self, plan: &Extension) -> Result> { - let extension_bytes = self - .serializer_registry - .serialize_logical_plan(plan.node.as_ref())?; - let detail = ProtoAny { - type_url: plan.node.name().to_string(), - value: extension_bytes.into(), - }; - let mut inputs_rel = plan - .node - .inputs() - .into_iter() - .map(|plan| self.handle_plan(plan)) - .collect::>>()?; - let rel_type = match inputs_rel.len() { - 0 => RelType::ExtensionLeaf(ExtensionLeafRel { - common: None, - detail: Some(detail), - }), - 1 => RelType::ExtensionSingle(Box::new(ExtensionSingleRel { - common: None, - detail: Some(detail), - input: Some(inputs_rel.pop().unwrap()), - })), - _ => RelType::ExtensionMulti(ExtensionMultiRel { - common: None, - detail: Some(detail), - inputs: inputs_rel.into_iter().map(|r| *r).collect(), - }), - }; - Ok(Box::new(Rel { - rel_type: Some(rel_type), - })) - } -} - -/// Convert DataFusion LogicalPlan to Substrait Plan -pub fn to_substrait_plan(plan: &LogicalPlan, state: &SessionState) -> Result> { - // Parse relation nodes - // Generate PlanRel(s) - // Note: Only 1 relation tree is currently supported - - let mut producer: DefaultSubstraitProducer = DefaultSubstraitProducer::new(state); - let plan_rels = vec![PlanRel { - rel_type: Some(plan_rel::RelType::Root(RelRoot { - input: Some(*producer.handle_plan(plan)?), - names: to_substrait_named_struct(plan.schema())?.names, - })), - }]; - - // Return parsed plan - let extensions = producer.get_extensions(); - Ok(Box::new(Plan { - version: Some(version::version_with_producer("datafusion")), - extension_uris: vec![], - extensions: extensions.into(), - relations: plan_rels, - advanced_extensions: None, - expected_type_urls: vec![], - parameter_bindings: vec![], - })) -} - -/// Serializes a collection of expressions to a Substrait ExtendedExpression message -/// -/// The ExtendedExpression message is a top-level message that can be used to send -/// expressions (not plans) between systems. -/// -/// Each expression is also given names for the output type. These are provided as a -/// field and not a String (since the names may be nested, e.g. a struct). The data -/// type and nullability of this field is redundant (those can be determined by the -/// Expr) and will be ignored. -/// -/// Substrait also requires the input schema of the expressions to be included in the -/// message. The field names of the input schema will be serialized. -pub fn to_substrait_extended_expr( - exprs: &[(&Expr, &Field)], - schema: &DFSchemaRef, - state: &SessionState, -) -> Result> { - let mut producer = DefaultSubstraitProducer::new(state); - let substrait_exprs = exprs - .iter() - .map(|(expr, field)| { - let substrait_expr = producer.handle_expr(expr, schema)?; - let mut output_names = Vec::new(); - flatten_names(field, false, &mut output_names)?; - Ok(ExpressionReference { - output_names, - expr_type: Some(ExprType::Expression(substrait_expr)), - }) - }) - .collect::>>()?; - let substrait_schema = to_substrait_named_struct(schema)?; - - let extensions = producer.get_extensions(); - Ok(Box::new(ExtendedExpression { - advanced_extensions: None, - expected_type_urls: vec![], - extension_uris: vec![], - extensions: extensions.into(), - version: Some(version::version_with_producer("datafusion")), - referred_expr: substrait_exprs, - base_schema: Some(substrait_schema), - })) -} - -pub fn to_substrait_rel( - producer: &mut impl SubstraitProducer, - plan: &LogicalPlan, -) -> Result> { - match plan { - LogicalPlan::Projection(plan) => producer.handle_projection(plan), - LogicalPlan::Filter(plan) => producer.handle_filter(plan), - LogicalPlan::Window(plan) => producer.handle_window(plan), - LogicalPlan::Aggregate(plan) => producer.handle_aggregate(plan), - LogicalPlan::Sort(plan) => producer.handle_sort(plan), - LogicalPlan::Join(plan) => producer.handle_join(plan), - LogicalPlan::Repartition(plan) => producer.handle_repartition(plan), - LogicalPlan::Union(plan) => producer.handle_union(plan), - LogicalPlan::TableScan(plan) => producer.handle_table_scan(plan), - LogicalPlan::EmptyRelation(plan) => producer.handle_empty_relation(plan), - LogicalPlan::Subquery(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::SubqueryAlias(plan) => producer.handle_subquery_alias(plan), - LogicalPlan::Limit(plan) => producer.handle_limit(plan), - LogicalPlan::Statement(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Values(plan) => producer.handle_values(plan), - LogicalPlan::Explain(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Analyze(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Extension(plan) => producer.handle_extension(plan), - LogicalPlan::Distinct(plan) => producer.handle_distinct(plan), - LogicalPlan::Dml(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Ddl(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Copy(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::DescribeTable(plan) => { - not_impl_err!("Unsupported plan type: {plan:?}")? - } - LogicalPlan::Unnest(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::RecursiveQuery(plan) => { - not_impl_err!("Unsupported plan type: {plan:?}")? - } - } -} - -pub fn from_table_scan( - producer: &mut impl SubstraitProducer, - scan: &TableScan, -) -> Result> { - let projection = scan.projection.as_ref().map(|p| { - p.iter() - .map(|i| StructItem { - field: *i as i32, - child: None, - }) - .collect() - }); - - let projection = projection.map(|struct_items| MaskExpression { - select: Some(StructSelect { struct_items }), - maintain_singular_struct: false, - }); - - let table_schema = scan.source.schema().to_dfschema_ref()?; - let base_schema = to_substrait_named_struct(&table_schema)?; - - let filter_option = if scan.filters.is_empty() { - None - } else { - let table_schema_qualified = Arc::new( - DFSchema::try_from_qualified_schema( - scan.table_name.clone(), - &(scan.source.schema()), - ) - .unwrap(), - ); - - let combined_expr = conjunction(scan.filters.clone()).unwrap(); - let filter_expr = - producer.handle_expr(&combined_expr, &table_schema_qualified)?; - Some(Box::new(filter_expr)) - }; - - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(base_schema), - filter: filter_option, - best_effort_filter: None, - projection, - advanced_extension: None, - read_type: Some(ReadType::NamedTable(NamedTable { - names: scan.table_name.to_vec(), - advanced_extension: None, - })), - }))), - })) -} - -pub fn from_empty_relation(e: &EmptyRelation) -> Result> { - if e.produce_one_row { - return not_impl_err!("Producing a row from empty relation is unsupported"); - } - #[allow(deprecated)] - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(to_substrait_named_struct(&e.schema)?), - filter: None, - best_effort_filter: None, - projection: None, - advanced_extension: None, - read_type: Some(ReadType::VirtualTable(VirtualTable { - values: vec![], - expressions: vec![], - })), - }))), - })) -} - -pub fn from_values( - producer: &mut impl SubstraitProducer, - v: &Values, -) -> Result> { - let values = v - .values - .iter() - .map(|row| { - let fields = row - .iter() - .map(|v| match v { - Expr::Literal(sv) => to_substrait_literal(producer, sv), - Expr::Alias(alias) => match alias.expr.as_ref() { - // The schema gives us the names, so we can skip aliases - Expr::Literal(sv) => to_substrait_literal(producer, sv), - _ => Err(substrait_datafusion_err!( - "Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name() - )), - }, - _ => Err(substrait_datafusion_err!( - "Only literal types and aliases are supported in Virtual Tables, got: {}", v.variant_name() - )), - }) - .collect::>()?; - Ok(Struct { fields }) - }) - .collect::>()?; - #[allow(deprecated)] - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(to_substrait_named_struct(&v.schema)?), - filter: None, - best_effort_filter: None, - projection: None, - advanced_extension: None, - read_type: Some(ReadType::VirtualTable(VirtualTable { - values, - expressions: vec![], - })), - }))), - })) -} - -pub fn from_projection( - producer: &mut impl SubstraitProducer, - p: &Projection, -) -> Result> { - let expressions = p - .expr - .iter() - .map(|e| producer.handle_expr(e, p.input.schema())) - .collect::>>()?; - - let emit_kind = create_project_remapping( - expressions.len(), - p.input.as_ref().schema().fields().len(), - ); - let common = RelCommon { - emit_kind: Some(emit_kind), - hint: None, - advanced_extension: None, - }; - - Ok(Box::new(Rel { - rel_type: Some(RelType::Project(Box::new(ProjectRel { - common: Some(common), - input: Some(producer.handle_plan(p.input.as_ref())?), - expressions, - advanced_extension: None, - }))), - })) -} - -pub fn from_filter( - producer: &mut impl SubstraitProducer, - filter: &Filter, -) -> Result> { - let input = producer.handle_plan(filter.input.as_ref())?; - let filter_expr = producer.handle_expr(&filter.predicate, filter.input.schema())?; - Ok(Box::new(Rel { - rel_type: Some(RelType::Filter(Box::new(FilterRel { - common: None, - input: Some(input), - condition: Some(Box::new(filter_expr)), - advanced_extension: None, - }))), - })) -} - -pub fn from_limit( - producer: &mut impl SubstraitProducer, - limit: &Limit, -) -> Result> { - let input = producer.handle_plan(limit.input.as_ref())?; - let empty_schema = Arc::new(DFSchema::empty()); - let offset_mode = limit - .skip - .as_ref() - .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) - .transpose()? - .map(Box::new) - .map(fetch_rel::OffsetMode::OffsetExpr); - let count_mode = limit - .fetch - .as_ref() - .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) - .transpose()? - .map(Box::new) - .map(fetch_rel::CountMode::CountExpr); - Ok(Box::new(Rel { - rel_type: Some(RelType::Fetch(Box::new(FetchRel { - common: None, - input: Some(input), - offset_mode, - count_mode, - advanced_extension: None, - }))), - })) -} - -pub fn from_sort(producer: &mut impl SubstraitProducer, sort: &Sort) -> Result> { - let Sort { expr, input, fetch } = sort; - let sort_fields = expr - .iter() - .map(|e| substrait_sort_field(producer, e, input.schema())) - .collect::>>()?; - - let input = producer.handle_plan(input.as_ref())?; - - let sort_rel = Box::new(Rel { - rel_type: Some(RelType::Sort(Box::new(SortRel { - common: None, - input: Some(input), - sorts: sort_fields, - advanced_extension: None, - }))), - }); - - match fetch { - Some(amount) => { - let count_mode = - Some(fetch_rel::CountMode::CountExpr(Box::new(Expression { - rex_type: Some(RexType::Literal(Literal { - nullable: false, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - literal_type: Some(LiteralType::I64(*amount as i64)), - })), - }))); - Ok(Box::new(Rel { - rel_type: Some(RelType::Fetch(Box::new(FetchRel { - common: None, - input: Some(sort_rel), - offset_mode: None, - count_mode, - advanced_extension: None, - }))), - })) - } - None => Ok(sort_rel), - } -} - -pub fn from_aggregate( - producer: &mut impl SubstraitProducer, - agg: &Aggregate, -) -> Result> { - let input = producer.handle_plan(agg.input.as_ref())?; - let (grouping_expressions, groupings) = - to_substrait_groupings(producer, &agg.group_expr, agg.input.schema())?; - let measures = agg - .aggr_expr - .iter() - .map(|e| to_substrait_agg_measure(producer, e, agg.input.schema())) - .collect::>>()?; - - Ok(Box::new(Rel { - rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { - common: None, - input: Some(input), - grouping_expressions, - groupings, - measures, - advanced_extension: None, - }))), - })) -} - -pub fn from_distinct( - producer: &mut impl SubstraitProducer, - distinct: &Distinct, -) -> Result> { - match distinct { - Distinct::All(plan) => { - // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = producer.handle_plan(plan.as_ref())?; - // Get grouping keys from the input relation's number of output fields - let grouping = (0..plan.schema().fields().len()) - .map(substrait_field_ref) - .collect::>>()?; - - #[allow(deprecated)] - Ok(Box::new(Rel { - rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { - common: None, - input: Some(input), - grouping_expressions: vec![], - groupings: vec![Grouping { - grouping_expressions: grouping, - expression_references: vec![], - }], - measures: vec![], - advanced_extension: None, - }))), - })) - } - Distinct::On(_) => not_impl_err!("Cannot convert Distinct::On"), - } -} - -pub fn from_join(producer: &mut impl SubstraitProducer, join: &Join) -> Result> { - let left = producer.handle_plan(join.left.as_ref())?; - let right = producer.handle_plan(join.right.as_ref())?; - let join_type = to_substrait_jointype(join.join_type); - // we only support basic joins so return an error for anything not yet supported - match join.join_constraint { - JoinConstraint::On => {} - JoinConstraint::Using => return not_impl_err!("join constraint: `using`"), - } - let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?); - - // convert filter if present - let join_filter = match &join.filter { - Some(filter) => Some(producer.handle_expr(filter, &in_join_schema)?), - None => None, - }; - - // map the left and right columns to binary expressions in the form `l = r` - // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` - let eq_op = if join.null_equals_null { - Operator::IsNotDistinctFrom - } else { - Operator::Eq - }; - let join_on = to_substrait_join_expr(producer, &join.on, eq_op, &in_join_schema)?; - - // create conjunction between `join_on` and `join_filter` to embed all join conditions, - // whether equal or non-equal in a single expression - let join_expr = match &join_on { - Some(on_expr) => match &join_filter { - Some(filter) => Some(Box::new(make_binary_op_scalar_func( - producer, - on_expr, - filter, - Operator::And, - ))), - None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist - }, - None => match &join_filter { - Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist - None => None, - }, - }; - - Ok(Box::new(Rel { - rel_type: Some(RelType::Join(Box::new(JoinRel { - common: None, - left: Some(left), - right: Some(right), - r#type: join_type as i32, - expression: join_expr, - post_join_filter: None, - advanced_extension: None, - }))), - })) -} - -pub fn from_subquery_alias( - producer: &mut impl SubstraitProducer, - alias: &SubqueryAlias, -) -> Result> { - // Do nothing if encounters SubqueryAlias - // since there is no corresponding relation type in Substrait - producer.handle_plan(alias.input.as_ref()) -} - -pub fn from_union( - producer: &mut impl SubstraitProducer, - union: &Union, -) -> Result> { - let input_rels = union - .inputs - .iter() - .map(|input| producer.handle_plan(input.as_ref())) - .collect::>>()? - .into_iter() - .map(|ptr| *ptr) - .collect(); - Ok(Box::new(Rel { - rel_type: Some(RelType::Set(SetRel { - common: None, - inputs: input_rels, - op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL - advanced_extension: None, - })), - })) -} - -pub fn from_window( - producer: &mut impl SubstraitProducer, - window: &Window, -) -> Result> { - let input = producer.handle_plan(window.input.as_ref())?; - - // create a field reference for each input field - let mut expressions = (0..window.input.schema().fields().len()) - .map(substrait_field_ref) - .collect::>>()?; - - // process and add each window function expression - for expr in &window.window_expr { - expressions.push(producer.handle_expr(expr, window.input.schema())?); - } - - let emit_kind = - create_project_remapping(expressions.len(), window.input.schema().fields().len()); - let common = RelCommon { - emit_kind: Some(emit_kind), - hint: None, - advanced_extension: None, - }; - let project_rel = Box::new(ProjectRel { - common: Some(common), - input: Some(input), - expressions, - advanced_extension: None, - }); - - Ok(Box::new(Rel { - rel_type: Some(RelType::Project(project_rel)), - })) -} - -pub fn from_repartition( - producer: &mut impl SubstraitProducer, - repartition: &Repartition, -) -> Result> { - let input = producer.handle_plan(repartition.input.as_ref())?; - let partition_count = match repartition.partitioning_scheme { - Partitioning::RoundRobinBatch(num) => num, - Partitioning::Hash(_, num) => num, - Partitioning::DistributeBy(_) => { - return not_impl_err!( - "Physical plan does not support DistributeBy partitioning" - ) - } - }; - // ref: https://substrait.io/relations/physical_relations/#exchange-types - let exchange_kind = match &repartition.partitioning_scheme { - Partitioning::RoundRobinBatch(_) => { - ExchangeKind::RoundRobin(RoundRobin::default()) - } - Partitioning::Hash(exprs, _) => { - let fields = exprs - .iter() - .map(|e| try_to_substrait_field_reference(e, repartition.input.schema())) - .collect::>>()?; - ExchangeKind::ScatterByFields(ScatterFields { fields }) - } - Partitioning::DistributeBy(_) => { - return not_impl_err!( - "Physical plan does not support DistributeBy partitioning" - ) - } - }; - let exchange_rel = ExchangeRel { - common: None, - input: Some(input), - exchange_kind: Some(exchange_kind), - advanced_extension: None, - partition_count: partition_count as i32, - targets: vec![], - }; - Ok(Box::new(Rel { - rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), - })) -} - -/// By default, a Substrait Project outputs all input fields followed by all expressions. -/// A DataFusion Projection only outputs expressions. In order to keep the Substrait -/// plan consistent with DataFusion, we must apply an output mapping that skips the input -/// fields so that the Substrait Project will only output the expression fields. -fn create_project_remapping(expr_count: usize, input_field_count: usize) -> EmitKind { - let expression_field_start = input_field_count; - let expression_field_end = expression_field_start + expr_count; - let output_mapping = (expression_field_start..expression_field_end) - .map(|i| i as i32) - .collect(); - Emit(rel_common::Emit { output_mapping }) -} - -// Substrait wants a list of all field names, including nested fields from structs, -// also from within e.g. lists and maps. However, it does not want the list and map field names -// themselves - only proper structs fields are considered to have useful names. -fn flatten_names(field: &Field, skip_self: bool, names: &mut Vec) -> Result<()> { - if !skip_self { - names.push(field.name().to_string()); - } - match field.data_type() { - DataType::Struct(fields) => { - for field in fields { - flatten_names(field, false, names)?; - } - Ok(()) - } - DataType::List(l) => flatten_names(l, true, names), - DataType::LargeList(l) => flatten_names(l, true, names), - DataType::Map(m, _) => match m.data_type() { - DataType::Struct(key_and_value) if key_and_value.len() == 2 => { - flatten_names(&key_and_value[0], true, names)?; - flatten_names(&key_and_value[1], true, names) - } - _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), - }, - _ => Ok(()), - }?; - Ok(()) -} - -fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { - let mut names = Vec::with_capacity(schema.fields().len()); - for field in schema.fields() { - flatten_names(field, false, &mut names)?; - } - - let field_types = r#type::Struct { - types: schema - .fields() - .iter() - .map(|f| to_substrait_type(f.data_type(), f.is_nullable())) - .collect::>()?, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability: r#type::Nullability::Required as i32, - }; - - Ok(NamedStruct { - names, - r#struct: Some(field_types), - }) -} - -fn to_substrait_join_expr( - producer: &mut impl SubstraitProducer, - join_conditions: &Vec<(Expr, Expr)>, - eq_op: Operator, - join_schema: &DFSchemaRef, -) -> Result> { - // Only support AND conjunction for each binary expression in join conditions - let mut exprs: Vec = vec![]; - for (left, right) in join_conditions { - let l = producer.handle_expr(left, join_schema)?; - let r = producer.handle_expr(right, join_schema)?; - // AND with existing expression - exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op)); - } - - let join_expr: Option = - exprs.into_iter().reduce(|acc: Expression, e: Expression| { - make_binary_op_scalar_func(producer, &acc, &e, Operator::And) - }); - Ok(join_expr) -} - -fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { - match join_type { - JoinType::Inner => join_rel::JoinType::Inner, - JoinType::Left => join_rel::JoinType::Left, - JoinType::Right => join_rel::JoinType::Right, - JoinType::Full => join_rel::JoinType::Outer, - JoinType::LeftAnti => join_rel::JoinType::LeftAnti, - JoinType::LeftSemi => join_rel::JoinType::LeftSemi, - JoinType::LeftMark => join_rel::JoinType::LeftMark, - JoinType::RightAnti | JoinType::RightSemi => { - unimplemented!() - } - } -} - -pub fn operator_to_name(op: Operator) -> &'static str { - match op { - Operator::Eq => "equal", - Operator::NotEq => "not_equal", - Operator::Lt => "lt", - Operator::LtEq => "lte", - Operator::Gt => "gt", - Operator::GtEq => "gte", - Operator::Plus => "add", - Operator::Minus => "subtract", - Operator::Multiply => "multiply", - Operator::Divide => "divide", - Operator::Modulo => "modulus", - Operator::And => "and", - Operator::Or => "or", - Operator::IsDistinctFrom => "is_distinct_from", - Operator::IsNotDistinctFrom => "is_not_distinct_from", - Operator::RegexMatch => "regex_match", - Operator::RegexIMatch => "regex_imatch", - Operator::RegexNotMatch => "regex_not_match", - Operator::RegexNotIMatch => "regex_not_imatch", - Operator::LikeMatch => "like_match", - Operator::ILikeMatch => "like_imatch", - Operator::NotLikeMatch => "like_not_match", - Operator::NotILikeMatch => "like_not_imatch", - Operator::BitwiseAnd => "bitwise_and", - Operator::BitwiseOr => "bitwise_or", - Operator::StringConcat => "str_concat", - Operator::AtArrow => "at_arrow", - Operator::ArrowAt => "arrow_at", - Operator::Arrow => "arrow", - Operator::LongArrow => "long_arrow", - Operator::HashArrow => "hash_arrow", - Operator::HashLongArrow => "hash_long_arrow", - Operator::AtAt => "at_at", - Operator::IntegerDivide => "integer_divide", - Operator::HashMinus => "hash_minus", - Operator::AtQuestion => "at_question", - Operator::Question => "question", - Operator::QuestionAnd => "question_and", - Operator::QuestionPipe => "question_pipe", - Operator::BitwiseXor => "bitwise_xor", - Operator::BitwiseShiftRight => "bitwise_shift_right", - Operator::BitwiseShiftLeft => "bitwise_shift_left", - } -} - -pub fn parse_flat_grouping_exprs( - producer: &mut impl SubstraitProducer, - exprs: &[Expr], - schema: &DFSchemaRef, - ref_group_exprs: &mut Vec, -) -> Result { - let mut expression_references = vec![]; - let mut grouping_expressions = vec![]; - - for e in exprs { - let rex = producer.handle_expr(e, schema)?; - grouping_expressions.push(rex.clone()); - ref_group_exprs.push(rex); - expression_references.push((ref_group_exprs.len() - 1) as u32); - } - #[allow(deprecated)] - Ok(Grouping { - grouping_expressions, - expression_references, - }) -} - -pub fn to_substrait_groupings( - producer: &mut impl SubstraitProducer, - exprs: &[Expr], - schema: &DFSchemaRef, -) -> Result<(Vec, Vec)> { - let mut ref_group_exprs = vec![]; - let groupings = match exprs.len() { - 1 => match &exprs[0] { - Expr::GroupingSet(gs) => match gs { - GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented( - "GroupingSet CUBE is not yet supported".to_string(), - )), - GroupingSet::GroupingSets(sets) => Ok(sets - .iter() - .map(|set| { - parse_flat_grouping_exprs( - producer, - set, - schema, - &mut ref_group_exprs, - ) - }) - .collect::>>()?), - GroupingSet::Rollup(set) => { - let mut sets: Vec> = vec![vec![]]; - for i in 0..set.len() { - sets.push(set[..=i].to_vec()); - } - Ok(sets - .iter() - .rev() - .map(|set| { - parse_flat_grouping_exprs( - producer, - set, - schema, - &mut ref_group_exprs, - ) - }) - .collect::>>()?) - } - }, - _ => Ok(vec![parse_flat_grouping_exprs( - producer, - exprs, - schema, - &mut ref_group_exprs, - )?]), - }, - _ => Ok(vec![parse_flat_grouping_exprs( - producer, - exprs, - schema, - &mut ref_group_exprs, - )?]), - }?; - Ok((ref_group_exprs, groupings)) -} - -pub fn from_aggregate_function( - producer: &mut impl SubstraitProducer, - agg_fn: &expr::AggregateFunction, - schema: &DFSchemaRef, -) -> Result { - let expr::AggregateFunction { - func, - params: - AggregateFunctionParams { - args, - distinct, - filter, - order_by, - null_treatment: _null_treatment, - }, - } = agg_fn; - let sorts = if let Some(order_by) = order_by { - order_by - .iter() - .map(|expr| to_substrait_sort_field(producer, expr, schema)) - .collect::>>()? - } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), - }); - } - let function_anchor = producer.register_function(func.name().to_string()); - #[allow(deprecated)] - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: match distinct { - true => AggregationInvocation::Distinct as i32, - false => AggregationInvocation::All as i32, - }, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(producer.handle_expr(f, schema)?), - None => None, - }, - }) -} - -pub fn to_substrait_agg_measure( - producer: &mut impl SubstraitProducer, - expr: &Expr, - schema: &DFSchemaRef, -) -> Result { - match expr { - Expr::AggregateFunction(agg_fn) => from_aggregate_function(producer, agg_fn, schema), - Expr::Alias(Alias { expr, .. }) => { - to_substrait_agg_measure(producer, expr, schema) - } - _ => internal_err!( - "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", - expr, - expr.variant_name() - ), - } -} - -/// Converts sort expression to corresponding substrait `SortField` -fn to_substrait_sort_field( - producer: &mut impl SubstraitProducer, - sort: &expr::Sort, - schema: &DFSchemaRef, -) -> Result { - let sort_kind = match (sort.asc, sort.nulls_first) { - (true, true) => SortDirection::AscNullsFirst, - (true, false) => SortDirection::AscNullsLast, - (false, true) => SortDirection::DescNullsFirst, - (false, false) => SortDirection::DescNullsLast, - }; - Ok(SortField { - expr: Some(producer.handle_expr(&sort.expr, schema)?), - sort_kind: Some(SortKind::Direction(sort_kind.into())), - }) -} - -/// Return Substrait scalar function with two arguments -pub fn make_binary_op_scalar_func( - producer: &mut impl SubstraitProducer, - lhs: &Expression, - rhs: &Expression, - op: Operator, -) -> Expression { - let function_anchor = producer.register_function(operator_to_name(op).to_string()); - #[allow(deprecated)] - Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![ - FunctionArgument { - arg_type: Some(ArgType::Value(lhs.clone())), - }, - FunctionArgument { - arg_type: Some(ArgType::Value(rhs.clone())), - }, - ], - output_type: None, - args: vec![], - options: vec![], - })), - } -} - -/// Convert DataFusion Expr to Substrait Rex -/// -/// # Arguments -/// * `producer` - SubstraitProducer implementation which the handles the actual conversion -/// * `expr` - DataFusion expression to convert into a Substrait expression -/// * `schema` - DataFusion input schema for looking up columns -pub fn to_substrait_rex( - producer: &mut impl SubstraitProducer, - expr: &Expr, - schema: &DFSchemaRef, -) -> Result { - match expr { - Expr::Alias(expr) => producer.handle_alias(expr, schema), - Expr::Column(expr) => producer.handle_column(expr, schema), - Expr::ScalarVariable(_, _) => { - not_impl_err!("Cannot convert {expr:?} to Substrait") - } - Expr::Literal(expr) => producer.handle_literal(expr), - Expr::BinaryExpr(expr) => producer.handle_binary_expr(expr, schema), - Expr::Like(expr) => producer.handle_like(expr, schema), - Expr::SimilarTo(_) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::Not(_) => producer.handle_unary_expr(expr, schema), - Expr::IsNotNull(_) => producer.handle_unary_expr(expr, schema), - Expr::IsNull(_) => producer.handle_unary_expr(expr, schema), - Expr::IsTrue(_) => producer.handle_unary_expr(expr, schema), - Expr::IsFalse(_) => producer.handle_unary_expr(expr, schema), - Expr::IsUnknown(_) => producer.handle_unary_expr(expr, schema), - Expr::IsNotTrue(_) => producer.handle_unary_expr(expr, schema), - Expr::IsNotFalse(_) => producer.handle_unary_expr(expr, schema), - Expr::IsNotUnknown(_) => producer.handle_unary_expr(expr, schema), - Expr::Negative(_) => producer.handle_unary_expr(expr, schema), - Expr::Between(expr) => producer.handle_between(expr, schema), - Expr::Case(expr) => producer.handle_case(expr, schema), - Expr::Cast(expr) => producer.handle_cast(expr, schema), - Expr::TryCast(expr) => producer.handle_try_cast(expr, schema), - Expr::ScalarFunction(expr) => producer.handle_scalar_function(expr, schema), - Expr::AggregateFunction(_) => { - internal_err!( - "AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate" - ) - } - Expr::WindowFunction(expr) => producer.handle_window_function(expr, schema), - Expr::InList(expr) => producer.handle_in_list(expr, schema), - Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), - Expr::ScalarSubquery(expr) => { - not_impl_err!("Cannot convert {expr:?} to Substrait") - } - #[expect(deprecated)] - Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::OuterReferenceColumn(_, _) => { - not_impl_err!("Cannot convert {expr:?} to Substrait") - } - Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - } -} - -pub fn from_in_list( - producer: &mut impl SubstraitProducer, - in_list: &InList, - schema: &DFSchemaRef, -) -> Result { - let InList { - expr, - list, - negated, - } = in_list; - let substrait_list = list - .iter() - .map(|x| producer.handle_expr(x, schema)) - .collect::>>()?; - let substrait_expr = producer.handle_expr(expr, schema)?; - - let substrait_or_list = Expression { - rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { - value: Some(Box::new(substrait_expr)), - options: substrait_list, - }))), - }; - - if *negated { - let function_anchor = producer.register_function("not".to_string()); - - #[allow(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_or_list)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_or_list) - } -} - -pub fn from_scalar_function( - producer: &mut impl SubstraitProducer, - fun: &expr::ScalarFunction, - schema: &DFSchemaRef, -) -> Result { - let mut arguments: Vec = vec![]; - for arg in &fun.args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), - }); - } - - let function_anchor = producer.register_function(fun.name().to_string()); - #[allow(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - options: vec![], - args: vec![], - })), - }) -} - -pub fn from_between( - producer: &mut impl SubstraitProducer, - between: &Between, - schema: &DFSchemaRef, -) -> Result { - let Between { - expr, - negated, - low, - high, - } = between; - if *negated { - // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; - let substrait_low = producer.handle_expr(low.as_ref(), schema)?; - let substrait_high = producer.handle_expr(high.as_ref(), schema)?; - - let l_expr = make_binary_op_scalar_func( - producer, - &substrait_expr, - &substrait_low, - Operator::Lt, - ); - let r_expr = make_binary_op_scalar_func( - producer, - &substrait_high, - &substrait_expr, - Operator::Lt, - ); - - Ok(make_binary_op_scalar_func( - producer, - &l_expr, - &r_expr, - Operator::Or, - )) - } else { - // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; - let substrait_low = producer.handle_expr(low.as_ref(), schema)?; - let substrait_high = producer.handle_expr(high.as_ref(), schema)?; - - let l_expr = make_binary_op_scalar_func( - producer, - &substrait_low, - &substrait_expr, - Operator::LtEq, - ); - let r_expr = make_binary_op_scalar_func( - producer, - &substrait_expr, - &substrait_high, - Operator::LtEq, - ); - - Ok(make_binary_op_scalar_func( - producer, - &l_expr, - &r_expr, - Operator::And, - )) - } -} -pub fn from_column(col: &Column, schema: &DFSchemaRef) -> Result { - let index = schema.index_of_column(col)?; - substrait_field_ref(index) -} - -pub fn from_binary_expr( - producer: &mut impl SubstraitProducer, - expr: &BinaryExpr, - schema: &DFSchemaRef, -) -> Result { - let BinaryExpr { left, op, right } = expr; - let l = producer.handle_expr(left, schema)?; - let r = producer.handle_expr(right, schema)?; - Ok(make_binary_op_scalar_func(producer, &l, &r, *op)) -} -pub fn from_case( - producer: &mut impl SubstraitProducer, - case: &Case, - schema: &DFSchemaRef, -) -> Result { - let Case { - expr, - when_then_expr, - else_expr, - } = case; - let mut ifs: Vec = vec![]; - // Parse base - if let Some(e) = expr { - // Base expression exists - ifs.push(IfClause { - r#if: Some(producer.handle_expr(e, schema)?), - then: None, - }); - } - // Parse `when`s - for (r#if, then) in when_then_expr { - ifs.push(IfClause { - r#if: Some(producer.handle_expr(r#if, schema)?), - then: Some(producer.handle_expr(then, schema)?), - }); - } - - // Parse outer `else` - let r#else: Option> = match else_expr { - Some(e) => Some(Box::new(producer.handle_expr(e, schema)?)), - None => None, - }; - - Ok(Expression { - rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), - }) -} - -pub fn from_cast( - producer: &mut impl SubstraitProducer, - cast: &Cast, - schema: &DFSchemaRef, -) -> Result { - let Cast { expr, data_type } = cast; - Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(producer.handle_expr(expr, schema)?)), - failure_behavior: FailureBehavior::ThrowException.into(), - }, - ))), - }) -} - -pub fn from_try_cast( - producer: &mut impl SubstraitProducer, - cast: &TryCast, - schema: &DFSchemaRef, -) -> Result { - let TryCast { expr, data_type } = cast; - Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(producer.handle_expr(expr, schema)?)), - failure_behavior: FailureBehavior::ReturnNull.into(), - }, - ))), - }) -} - -pub fn from_literal( - producer: &mut impl SubstraitProducer, - value: &ScalarValue, -) -> Result { - to_substrait_literal_expr(producer, value) -} - -pub fn from_alias( - producer: &mut impl SubstraitProducer, - alias: &Alias, - schema: &DFSchemaRef, -) -> Result { - producer.handle_expr(alias.expr.as_ref(), schema) -} - -pub fn from_window_function( - producer: &mut impl SubstraitProducer, - window_fn: &WindowFunction, - schema: &DFSchemaRef, -) -> Result { - let WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }, - } = window_fn; - // function reference - let function_anchor = producer.register_function(fun.to_string()); - // arguments - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), - }); - } - // partition by expressions - let partition_by = partition_by - .iter() - .map(|e| producer.handle_expr(e, schema)) - .collect::>>()?; - // order by expressions - let order_by = order_by - .iter() - .map(|e| substrait_sort_field(producer, e, schema)) - .collect::>>()?; - // window frame - let bounds = to_substrait_bounds(window_frame)?; - let bound_type = to_substrait_bound_type(window_frame)?; - Ok(make_substrait_window_function( - function_anchor, - arguments, - partition_by, - order_by, - bounds, - bound_type, - )) -} - -pub fn from_like( - producer: &mut impl SubstraitProducer, - like: &Like, - schema: &DFSchemaRef, -) -> Result { - let Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - } = like; - make_substrait_like_expr( - producer, - *case_insensitive, - *negated, - expr, - pattern, - *escape_char, - schema, - ) -} - -pub fn from_in_subquery( - producer: &mut impl SubstraitProducer, - subquery: &InSubquery, - schema: &DFSchemaRef, -) -> Result { - let InSubquery { - expr, - subquery, - negated, - } = subquery; - let substrait_expr = producer.handle_expr(expr, schema)?; - - let subquery_plan = producer.handle_plan(subquery.subquery.as_ref())?; - - let substrait_subquery = Expression { - rex_type: Some(RexType::Subquery(Box::new( - substrait::proto::expression::Subquery { - subquery_type: Some( - substrait::proto::expression::subquery::SubqueryType::InPredicate( - Box::new(InPredicate { - needles: (vec![substrait_expr]), - haystack: Some(subquery_plan), - }), - ), - ), - }, - ))), - }; - if *negated { - let function_anchor = producer.register_function("not".to_string()); - - #[allow(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_subquery)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_subquery) - } -} - -pub fn from_unary_expr( - producer: &mut impl SubstraitProducer, - expr: &Expr, - schema: &DFSchemaRef, -) -> Result { - let (fn_name, arg) = match expr { - Expr::Not(arg) => ("not", arg), - Expr::IsNull(arg) => ("is_null", arg), - Expr::IsNotNull(arg) => ("is_not_null", arg), - Expr::IsTrue(arg) => ("is_true", arg), - Expr::IsFalse(arg) => ("is_false", arg), - Expr::IsUnknown(arg) => ("is_unknown", arg), - Expr::IsNotTrue(arg) => ("is_not_true", arg), - Expr::IsNotFalse(arg) => ("is_not_false", arg), - Expr::IsNotUnknown(arg) => ("is_not_unknown", arg), - Expr::Negative(arg) => ("negate", arg), - expr => not_impl_err!("Unsupported expression: {expr:?}")?, - }; - to_substrait_unary_scalar_fn(producer, fn_name, arg, schema) -} - -fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { - let nullability = if nullable { - r#type::Nullability::Nullable as i32 - } else { - r#type::Nullability::Required as i32 - }; - match dt { - DataType::Null => internal_err!("Null cast is not valid"), - DataType::Boolean => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Bool(r#type::Boolean { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Int8 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::UInt8 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Int16 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::UInt16 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Int32 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::UInt32 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Int64 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::UInt64 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, - nullability, - })), - }), - // Float16 is not supported in Substrait - DataType::Float32 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Fp32(r#type::Fp32 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Float64 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Fp64(r#type::Fp64 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Timestamp(unit, tz) => { - let precision = match unit { - TimeUnit::Second => 0, - TimeUnit::Millisecond => 3, - TimeUnit::Microsecond => 6, - TimeUnit::Nanosecond => 9, - }; - let kind = match tz { - None => r#type::Kind::PrecisionTimestamp(r#type::PrecisionTimestamp { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - precision, - }), - Some(_) => { - // If timezone is present, no matter what the actual tz value is, it indicates the - // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. - // As the timezone is lost, this conversion may be lossy for downstream use of the value. - r#type::Kind::PrecisionTimestampTz(r#type::PrecisionTimestampTz { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - precision, - }) - } - }; - Ok(substrait::proto::Type { kind: Some(kind) }) - } - DataType::Date32 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_32_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Date64 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_64_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Interval(interval_unit) => { - match interval_unit { - IntervalUnit::YearMonth => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::IntervalYear(r#type::IntervalYear { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - IntervalUnit::DayTime => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - precision: Some(3), // DayTime precision is always milliseconds - })), - }), - IntervalUnit::MonthDayNano => { - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::IntervalCompound( - r#type::IntervalCompound { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - precision: 9, // nanos - }, - )), - }) - } - } - } - DataType::Binary => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary { - length: *length, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::LargeBinary => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::BinaryView => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Utf8 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::LargeUtf8 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Utf8View => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::List(inner) => { - let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::List(Box::new(r#type::List { - r#type: Some(Box::new(inner_type)), - type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, - nullability, - }))), - }) - } - DataType::LargeList(inner) => { - let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::List(Box::new(r#type::List { - r#type: Some(Box::new(inner_type)), - type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, - nullability, - }))), - }) - } - DataType::Map(inner, _) => match inner.data_type() { - DataType::Struct(key_and_value) if key_and_value.len() == 2 => { - let key_type = to_substrait_type( - key_and_value[0].data_type(), - key_and_value[0].is_nullable(), - )?; - let value_type = to_substrait_type( - key_and_value[1].data_type(), - key_and_value[1].is_nullable(), - )?; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Map(Box::new(r#type::Map { - key: Some(Box::new(key_type)), - value: Some(Box::new(value_type)), - type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, - nullability, - }))), - }) - } - _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), - }, - DataType::Struct(fields) => { - let field_types = fields - .iter() - .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) - .collect::>>()?; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Struct(r#type::Struct { - types: field_types, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }) - } - DataType::Decimal128(p, s) => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Decimal(r#type::Decimal { - type_variation_reference: DECIMAL_128_TYPE_VARIATION_REF, - nullability, - scale: *s as i32, - precision: *p as i32, - })), - }), - DataType::Decimal256(p, s) => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Decimal(r#type::Decimal { - type_variation_reference: DECIMAL_256_TYPE_VARIATION_REF, - nullability, - scale: *s as i32, - precision: *p as i32, - })), - }), - _ => not_impl_err!("Unsupported cast type: {dt:?}"), - } -} - -fn make_substrait_window_function( - function_reference: u32, - arguments: Vec, - partitions: Vec, - sorts: Vec, - bounds: (Bound, Bound), - bounds_type: BoundsType, -) -> Expression { - #[allow(deprecated)] - Expression { - rex_type: Some(RexType::WindowFunction(SubstraitWindowFunction { - function_reference, - arguments, - partitions, - sorts, - options: vec![], - output_type: None, - phase: 0, // default to AGGREGATION_PHASE_UNSPECIFIED - invocation: 0, // TODO: fix - lower_bound: Some(bounds.0), - upper_bound: Some(bounds.1), - args: vec![], - bounds_type: bounds_type as i32, - })), - } -} - -fn make_substrait_like_expr( - producer: &mut impl SubstraitProducer, - ignore_case: bool, - negated: bool, - expr: &Expr, - pattern: &Expr, - escape_char: Option, - schema: &DFSchemaRef, -) -> Result { - let function_anchor = if ignore_case { - producer.register_function("ilike".to_string()) - } else { - producer.register_function("like".to_string()) - }; - let expr = producer.handle_expr(expr, schema)?; - let pattern = producer.handle_expr(pattern, schema)?; - let escape_char = to_substrait_literal_expr( - producer, - &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), - )?; - let arguments = vec![ - FunctionArgument { - arg_type: Some(ArgType::Value(expr)), - }, - FunctionArgument { - arg_type: Some(ArgType::Value(pattern)), - }, - FunctionArgument { - arg_type: Some(ArgType::Value(escape_char)), - }, - ]; - - #[allow(deprecated)] - let substrait_like = Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - args: vec![], - options: vec![], - })), - }; - - if negated { - let function_anchor = producer.register_function("not".to_string()); - - #[allow(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_like)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_like) - } -} - -fn to_substrait_bound_offset(value: &ScalarValue) -> Option { - match value { - ScalarValue::UInt8(Some(v)) => Some(*v as i64), - ScalarValue::UInt16(Some(v)) => Some(*v as i64), - ScalarValue::UInt32(Some(v)) => Some(*v as i64), - ScalarValue::UInt64(Some(v)) => Some(*v as i64), - ScalarValue::Int8(Some(v)) => Some(*v as i64), - ScalarValue::Int16(Some(v)) => Some(*v as i64), - ScalarValue::Int32(Some(v)) => Some(*v as i64), - ScalarValue::Int64(Some(v)) => Some(*v), - _ => None, - } -} - -fn to_substrait_bound(bound: &WindowFrameBound) -> Bound { - match bound { - WindowFrameBound::CurrentRow => Bound { - kind: Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})), - }, - WindowFrameBound::Preceding(s) => match to_substrait_bound_offset(s) { - Some(offset) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { offset })), - }, - None => Bound { - kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), - }, - }, - WindowFrameBound::Following(s) => match to_substrait_bound_offset(s) { - Some(offset) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { offset })), - }, - None => Bound { - kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), - }, - }, - } -} - -fn to_substrait_bound_type(window_frame: &WindowFrame) -> Result { - match window_frame.units { - WindowFrameUnits::Rows => Ok(BoundsType::Rows), // ROWS - WindowFrameUnits::Range => Ok(BoundsType::Range), // RANGE - // TODO: Support GROUPS - unit => not_impl_err!("Unsupported window frame unit: {unit:?}"), - } -} - -fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { - Ok(( - to_substrait_bound(&window_frame.start_bound), - to_substrait_bound(&window_frame.end_bound), - )) -} - -fn to_substrait_literal( - producer: &mut impl SubstraitProducer, - value: &ScalarValue, -) -> Result { - if value.is_null() { - return Ok(Literal { - nullable: true, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - literal_type: Some(LiteralType::Null(to_substrait_type( - &value.data_type(), - true, - )?)), - }); - } - let (literal_type, type_variation_reference) = match value { - ScalarValue::Boolean(Some(b)) => { - (LiteralType::Boolean(*b), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::Int8(Some(n)) => { - (LiteralType::I8(*n as i32), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::UInt8(Some(n)) => ( - LiteralType::I8(*n as i32), - UNSIGNED_INTEGER_TYPE_VARIATION_REF, - ), - ScalarValue::Int16(Some(n)) => { - (LiteralType::I16(*n as i32), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::UInt16(Some(n)) => ( - LiteralType::I16(*n as i32), - UNSIGNED_INTEGER_TYPE_VARIATION_REF, - ), - ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_VARIATION_REF), - ScalarValue::UInt32(Some(n)) => ( - LiteralType::I32(*n as i32), - UNSIGNED_INTEGER_TYPE_VARIATION_REF, - ), - ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_VARIATION_REF), - ScalarValue::UInt64(Some(n)) => ( - LiteralType::I64(*n as i64), - UNSIGNED_INTEGER_TYPE_VARIATION_REF, - ), - ScalarValue::Float32(Some(f)) => { - (LiteralType::Fp32(*f), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::Float64(Some(f)) => { - (LiteralType::Fp64(*f), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::TimestampSecond(Some(t), None) => ( - LiteralType::PrecisionTimestamp(PrecisionTimestamp { - precision: 0, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampMillisecond(Some(t), None) => ( - LiteralType::PrecisionTimestamp(PrecisionTimestamp { - precision: 3, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampMicrosecond(Some(t), None) => ( - LiteralType::PrecisionTimestamp(PrecisionTimestamp { - precision: 6, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampNanosecond(Some(t), None) => ( - LiteralType::PrecisionTimestamp(PrecisionTimestamp { - precision: 9, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - // If timezone is present, no matter what the actual tz value is, it indicates the - // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. - // As the timezone is lost, this conversion may be lossy for downstream use of the value. - ScalarValue::TimestampSecond(Some(t), Some(_)) => ( - LiteralType::PrecisionTimestampTz(PrecisionTimestamp { - precision: 0, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampMillisecond(Some(t), Some(_)) => ( - LiteralType::PrecisionTimestampTz(PrecisionTimestamp { - precision: 3, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampMicrosecond(Some(t), Some(_)) => ( - LiteralType::PrecisionTimestampTz(PrecisionTimestamp { - precision: 6, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampNanosecond(Some(t), Some(_)) => ( - LiteralType::PrecisionTimestampTz(PrecisionTimestamp { - precision: 9, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::Date32(Some(d)) => { - (LiteralType::Date(*d), DATE_32_TYPE_VARIATION_REF) - } - // Date64 literal is not supported in Substrait - ScalarValue::IntervalYearMonth(Some(i)) => ( - LiteralType::IntervalYearToMonth(IntervalYearToMonth { - // DF only tracks total months, but there should always be 12 months in a year - years: *i / 12, - months: *i % 12, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::IntervalMonthDayNano(Some(i)) => ( - LiteralType::IntervalCompound(IntervalCompound { - interval_year_to_month: Some(IntervalYearToMonth { - years: i.months / 12, - months: i.months % 12, - }), - interval_day_to_second: Some(IntervalDayToSecond { - days: i.days, - seconds: (i.nanoseconds / NANOSECONDS) as i32, - subseconds: i.nanoseconds % NANOSECONDS, - precision_mode: Some(PrecisionMode::Precision(9)), // nanoseconds - }), - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::IntervalDayTime(Some(i)) => ( - LiteralType::IntervalDayToSecond(IntervalDayToSecond { - days: i.days, - seconds: i.milliseconds / 1000, - subseconds: (i.milliseconds % 1000) as i64, - precision_mode: Some(PrecisionMode::Precision(3)), // 3 for milliseconds - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::Binary(Some(b)) => ( - LiteralType::Binary(b.clone()), - DEFAULT_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::LargeBinary(Some(b)) => ( - LiteralType::Binary(b.clone()), - LARGE_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::BinaryView(Some(b)) => ( - LiteralType::Binary(b.clone()), - VIEW_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::FixedSizeBinary(_, Some(b)) => ( - LiteralType::FixedBinary(b.clone()), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::Utf8(Some(s)) => ( - LiteralType::String(s.clone()), - DEFAULT_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::LargeUtf8(Some(s)) => ( - LiteralType::String(s.clone()), - LARGE_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::Utf8View(Some(s)) => ( - LiteralType::String(s.clone()), - VIEW_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::Decimal128(v, p, s) if v.is_some() => ( - LiteralType::Decimal(Decimal { - value: v.unwrap().to_le_bytes().to_vec(), - precision: *p as i32, - scale: *s as i32, - }), - DECIMAL_128_TYPE_VARIATION_REF, - ), - ScalarValue::List(l) => ( - convert_array_to_literal_list(producer, l)?, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::LargeList(l) => ( - convert_array_to_literal_list(producer, l)?, - LARGE_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::Map(m) => { - let map = if m.is_empty() || m.value(0).is_empty() { - let mt = to_substrait_type(m.data_type(), m.is_nullable())?; - let mt = match mt { - substrait::proto::Type { - kind: Some(r#type::Kind::Map(mt)), - } => Ok(mt.as_ref().to_owned()), - _ => exec_err!("Unexpected type for a map: {mt:?}"), - }?; - LiteralType::EmptyMap(mt) - } else { - let keys = (0..m.keys().len()) - .map(|i| { - to_substrait_literal( - producer, - &ScalarValue::try_from_array(&m.keys(), i)?, - ) - }) - .collect::>>()?; - let values = (0..m.values().len()) - .map(|i| { - to_substrait_literal( - producer, - &ScalarValue::try_from_array(&m.values(), i)?, - ) - }) - .collect::>>()?; - - let key_values = keys - .into_iter() - .zip(values.into_iter()) - .map(|(k, v)| { - Ok(KeyValue { - key: Some(k), - value: Some(v), - }) - }) - .collect::>>()?; - LiteralType::Map(Map { key_values }) - }; - (map, DEFAULT_CONTAINER_TYPE_VARIATION_REF) - } - ScalarValue::Struct(s) => ( - LiteralType::Struct(Struct { - fields: s - .columns() - .iter() - .map(|col| { - to_substrait_literal( - producer, - &ScalarValue::try_from_array(col, 0)?, - ) - }) - .collect::>>()?, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - _ => ( - not_impl_err!("Unsupported literal: {value:?}")?, - DEFAULT_TYPE_VARIATION_REF, - ), - }; - - Ok(Literal { - nullable: false, - type_variation_reference, - literal_type: Some(literal_type), - }) -} - -fn convert_array_to_literal_list( - producer: &mut impl SubstraitProducer, - array: &GenericListArray, -) -> Result { - assert_eq!(array.len(), 1); - let nested_array = array.value(0); - - let values = (0..nested_array.len()) - .map(|i| { - to_substrait_literal( - producer, - &ScalarValue::try_from_array(&nested_array, i)?, - ) - }) - .collect::>>()?; - - if values.is_empty() { - let lt = match to_substrait_type(array.data_type(), array.is_nullable())? { - substrait::proto::Type { - kind: Some(r#type::Kind::List(lt)), - } => lt.as_ref().to_owned(), - _ => unreachable!(), - }; - Ok(LiteralType::EmptyList(lt)) - } else { - Ok(LiteralType::List(List { values })) - } -} - -fn to_substrait_literal_expr( - producer: &mut impl SubstraitProducer, - value: &ScalarValue, -) -> Result { - let literal = to_substrait_literal(producer, value)?; - Ok(Expression { - rex_type: Some(RexType::Literal(literal)), - }) -} - -/// Util to generate substrait [RexType::ScalarFunction] with one argument -fn to_substrait_unary_scalar_fn( - producer: &mut impl SubstraitProducer, - fn_name: &str, - arg: &Expr, - schema: &DFSchemaRef, -) -> Result { - let function_anchor = producer.register_function(fn_name.to_string()); - let substrait_expr = producer.handle_expr(arg, schema)?; - - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_expr)), - }], - output_type: None, - options: vec![], - ..Default::default() - })), - }) -} - -/// Try to convert an [Expr] to a [FieldReference]. -/// Returns `Err` if the [Expr] is not a [Expr::Column]. -fn try_to_substrait_field_reference( - expr: &Expr, - schema: &DFSchemaRef, -) -> Result { - match expr { - Expr::Column(col) => { - let index = schema.index_of_column(col)?; - Ok(FieldReference { - reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { - reference_type: Some(reference_segment::ReferenceType::StructField( - Box::new(reference_segment::StructField { - field: index as i32, - child: None, - }), - )), - })), - root_type: Some(RootType::RootReference(RootReference {})), - }) - } - _ => substrait_err!("Expect a `Column` expr, but found {expr:?}"), - } -} - -fn substrait_sort_field( - producer: &mut impl SubstraitProducer, - sort: &SortExpr, - schema: &DFSchemaRef, -) -> Result { - let SortExpr { - expr, - asc, - nulls_first, - } = sort; - let e = producer.handle_expr(expr, schema)?; - let d = match (asc, nulls_first) { - (true, true) => SortDirection::AscNullsFirst, - (true, false) => SortDirection::AscNullsLast, - (false, true) => SortDirection::DescNullsFirst, - (false, false) => SortDirection::DescNullsLast, - }; - Ok(SortField { - expr: Some(e), - sort_kind: Some(SortKind::Direction(d as i32)), - }) -} - -fn substrait_field_ref(index: usize) -> Result { - Ok(Expression { - rex_type: Some(RexType::Selection(Box::new(FieldReference { - reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { - reference_type: Some(reference_segment::ReferenceType::StructField( - Box::new(reference_segment::StructField { - field: index as i32, - child: None, - }), - )), - })), - root_type: Some(RootType::RootReference(RootReference {})), - }))), - }) -} - -#[cfg(test)] -mod test { - use super::*; - use crate::logical_plan::consumer::{ - from_substrait_extended_expr, from_substrait_literal_without_names, - from_substrait_named_struct, from_substrait_type_without_names, - DefaultSubstraitConsumer, - }; - use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; - use datafusion::arrow; - use datafusion::arrow::array::{ - GenericListArray, Int64Builder, MapBuilder, StringBuilder, - }; - use datafusion::arrow::datatypes::{Field, Fields, Schema}; - use datafusion::common::scalar::ScalarStructBuilder; - use datafusion::common::DFSchema; - use datafusion::execution::{SessionState, SessionStateBuilder}; - use datafusion::prelude::SessionContext; - use std::sync::LazyLock; - - static TEST_SESSION_STATE: LazyLock = - LazyLock::new(|| SessionContext::default().state()); - static TEST_EXTENSIONS: LazyLock = LazyLock::new(Extensions::default); - fn test_consumer() -> DefaultSubstraitConsumer<'static> { - let extensions = &TEST_EXTENSIONS; - let state = &TEST_SESSION_STATE; - DefaultSubstraitConsumer::new(extensions, state) - } - - #[test] - fn round_trip_literals() -> Result<()> { - round_trip_literal(ScalarValue::Boolean(None))?; - round_trip_literal(ScalarValue::Boolean(Some(true)))?; - round_trip_literal(ScalarValue::Boolean(Some(false)))?; - - round_trip_literal(ScalarValue::Int8(None))?; - round_trip_literal(ScalarValue::Int8(Some(i8::MIN)))?; - round_trip_literal(ScalarValue::Int8(Some(i8::MAX)))?; - round_trip_literal(ScalarValue::UInt8(None))?; - round_trip_literal(ScalarValue::UInt8(Some(u8::MIN)))?; - round_trip_literal(ScalarValue::UInt8(Some(u8::MAX)))?; - - round_trip_literal(ScalarValue::Int16(None))?; - round_trip_literal(ScalarValue::Int16(Some(i16::MIN)))?; - round_trip_literal(ScalarValue::Int16(Some(i16::MAX)))?; - round_trip_literal(ScalarValue::UInt16(None))?; - round_trip_literal(ScalarValue::UInt16(Some(u16::MIN)))?; - round_trip_literal(ScalarValue::UInt16(Some(u16::MAX)))?; - - round_trip_literal(ScalarValue::Int32(None))?; - round_trip_literal(ScalarValue::Int32(Some(i32::MIN)))?; - round_trip_literal(ScalarValue::Int32(Some(i32::MAX)))?; - round_trip_literal(ScalarValue::UInt32(None))?; - round_trip_literal(ScalarValue::UInt32(Some(u32::MIN)))?; - round_trip_literal(ScalarValue::UInt32(Some(u32::MAX)))?; - - round_trip_literal(ScalarValue::Int64(None))?; - round_trip_literal(ScalarValue::Int64(Some(i64::MIN)))?; - round_trip_literal(ScalarValue::Int64(Some(i64::MAX)))?; - round_trip_literal(ScalarValue::UInt64(None))?; - round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?; - round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?; - - for (ts, tz) in [ - (Some(12345), None), - (None, None), - (Some(12345), Some("UTC".into())), - (None, Some("UTC".into())), - ] { - round_trip_literal(ScalarValue::TimestampSecond(ts, tz.clone()))?; - round_trip_literal(ScalarValue::TimestampMillisecond(ts, tz.clone()))?; - round_trip_literal(ScalarValue::TimestampMicrosecond(ts, tz.clone()))?; - round_trip_literal(ScalarValue::TimestampNanosecond(ts, tz))?; - } - - round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( - &[ScalarValue::Float32(Some(1.0))], - &DataType::Float32, - )))?; - round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( - &[], - &DataType::Float32, - )))?; - round_trip_literal(ScalarValue::List(Arc::new(GenericListArray::new_null( - Field::new_list_field(DataType::Float32, true).into(), - 1, - ))))?; - round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( - &[ScalarValue::Float32(Some(1.0))], - &DataType::Float32, - )))?; - round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( - &[], - &DataType::Float32, - )))?; - round_trip_literal(ScalarValue::LargeList(Arc::new( - GenericListArray::new_null( - Field::new_list_field(DataType::Float32, true).into(), - 1, - ), - )))?; - - // Null map - let mut map_builder = - MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); - map_builder.append(false)?; - round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; - - // Empty map - let mut map_builder = - MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); - map_builder.append(true)?; - round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; - - // Valid map - let mut map_builder = - MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); - map_builder.keys().append_value("key1"); - map_builder.keys().append_value("key2"); - map_builder.values().append_value(1); - map_builder.values().append_value(2); - map_builder.append(true)?; - round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; - - let c0 = Field::new("c0", DataType::Boolean, true); - let c1 = Field::new("c1", DataType::Int32, true); - let c2 = Field::new("c2", DataType::Utf8, true); - round_trip_literal( - ScalarStructBuilder::new() - .with_scalar(c0.to_owned(), ScalarValue::Boolean(Some(true))) - .with_scalar(c1.to_owned(), ScalarValue::Int32(Some(1))) - .with_scalar(c2.to_owned(), ScalarValue::Utf8(None)) - .build()?, - )?; - round_trip_literal(ScalarStructBuilder::new_null(vec![c0, c1, c2]))?; - - round_trip_literal(ScalarValue::IntervalYearMonth(Some(17)))?; - round_trip_literal(ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNano::new(17, 25, 1234567890), - )))?; - round_trip_literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new( - 57, 123456, - ))))?; - - Ok(()) - } - - fn round_trip_literal(scalar: ScalarValue) -> Result<()> { - println!("Checking round trip of {scalar:?}"); - let state = SessionContext::default().state(); - let mut producer = DefaultSubstraitProducer::new(&state); - let substrait_literal = to_substrait_literal(&mut producer, &scalar)?; - let roundtrip_scalar = - from_substrait_literal_without_names(&test_consumer(), &substrait_literal)?; - assert_eq!(scalar, roundtrip_scalar); - Ok(()) - } - - #[test] - fn round_trip_types() -> Result<()> { - round_trip_type(DataType::Boolean)?; - round_trip_type(DataType::Int8)?; - round_trip_type(DataType::UInt8)?; - round_trip_type(DataType::Int16)?; - round_trip_type(DataType::UInt16)?; - round_trip_type(DataType::Int32)?; - round_trip_type(DataType::UInt32)?; - round_trip_type(DataType::Int64)?; - round_trip_type(DataType::UInt64)?; - round_trip_type(DataType::Float32)?; - round_trip_type(DataType::Float64)?; - - for tz in [None, Some("UTC".into())] { - round_trip_type(DataType::Timestamp(TimeUnit::Second, tz.clone()))?; - round_trip_type(DataType::Timestamp(TimeUnit::Millisecond, tz.clone()))?; - round_trip_type(DataType::Timestamp(TimeUnit::Microsecond, tz.clone()))?; - round_trip_type(DataType::Timestamp(TimeUnit::Nanosecond, tz))?; - } - - round_trip_type(DataType::Date32)?; - round_trip_type(DataType::Date64)?; - round_trip_type(DataType::Binary)?; - round_trip_type(DataType::FixedSizeBinary(10))?; - round_trip_type(DataType::LargeBinary)?; - round_trip_type(DataType::BinaryView)?; - round_trip_type(DataType::Utf8)?; - round_trip_type(DataType::LargeUtf8)?; - round_trip_type(DataType::Utf8View)?; - round_trip_type(DataType::Decimal128(10, 2))?; - round_trip_type(DataType::Decimal256(30, 2))?; - - round_trip_type(DataType::List( - Field::new_list_field(DataType::Int32, true).into(), - ))?; - round_trip_type(DataType::LargeList( - Field::new_list_field(DataType::Int32, true).into(), - ))?; - - round_trip_type(DataType::Map( - Field::new_struct( - "entries", - [ - Field::new("key", DataType::Utf8, false).into(), - Field::new("value", DataType::Int32, true).into(), - ], - false, - ) - .into(), - false, - ))?; - - round_trip_type(DataType::Struct( - vec![ - Field::new("c0", DataType::Int32, true), - Field::new("c1", DataType::Utf8, true), - ] - .into(), - ))?; - - round_trip_type(DataType::Interval(IntervalUnit::YearMonth))?; - round_trip_type(DataType::Interval(IntervalUnit::MonthDayNano))?; - round_trip_type(DataType::Interval(IntervalUnit::DayTime))?; - - Ok(()) - } - - fn round_trip_type(dt: DataType) -> Result<()> { - println!("Checking round trip of {dt:?}"); - - // As DataFusion doesn't consider nullability as a property of the type, but field, - // it doesn't matter if we set nullability to true or false here. - let substrait = to_substrait_type(&dt, true)?; - let consumer = test_consumer(); - let roundtrip_dt = from_substrait_type_without_names(&consumer, &substrait)?; - assert_eq!(dt, roundtrip_dt); - Ok(()) - } - - #[test] - fn to_field_reference() -> Result<()> { - let expression = substrait_field_ref(2)?; - - match &expression.rex_type { - Some(RexType::Selection(field_ref)) => { - assert_eq!( - field_ref - .root_type - .clone() - .expect("root type should be set"), - RootType::RootReference(RootReference {}) - ); - } - - _ => panic!("Should not be anything other than field reference"), - } - Ok(()) - } - - #[test] - fn named_struct_names() -> Result<()> { - let schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ - Field::new("int", DataType::Int32, true), - Field::new( - "struct", - DataType::Struct(Fields::from(vec![Field::new( - "inner", - DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, true))), - true, - )])), - true, - ), - Field::new("trailer", DataType::Float64, true), - ]))?); - - let named_struct = to_substrait_named_struct(&schema)?; - - // Struct field names should be flattened DFS style - // List field names should be omitted - assert_eq!( - named_struct.names, - vec!["int", "struct", "inner", "trailer"] - ); - - let roundtrip_schema = - from_substrait_named_struct(&test_consumer(), &named_struct)?; - assert_eq!(schema.as_ref(), &roundtrip_schema); - Ok(()) - } - - #[tokio::test] - async fn extended_expressions() -> Result<()> { - let state = SessionStateBuilder::default().build(); - - // One expression, empty input schema - let expr = Expr::Literal(ScalarValue::Int32(Some(42))); - let field = Field::new("out", DataType::Int32, false); - let empty_schema = DFSchemaRef::new(DFSchema::empty()); - let substrait = - to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state)?; - let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; - - assert_eq!(roundtrip_expr.input_schema, empty_schema); - assert_eq!(roundtrip_expr.exprs.len(), 1); - - let (rt_expr, rt_field) = roundtrip_expr.exprs.first().unwrap(); - assert_eq!(rt_field, &field); - assert_eq!(rt_expr, &expr); - - // Multiple expressions, with column references - let expr1 = Expr::Column("c0".into()); - let expr2 = Expr::Column("c1".into()); - let out1 = Field::new("out1", DataType::Int32, true); - let out2 = Field::new("out2", DataType::Utf8, true); - let input_schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ - Field::new("c0", DataType::Int32, true), - Field::new("c1", DataType::Utf8, true), - ]))?); - - let substrait = to_substrait_extended_expr( - &[(&expr1, &out1), (&expr2, &out2)], - &input_schema, - &state, - )?; - let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; - - assert_eq!(roundtrip_expr.input_schema, input_schema); - assert_eq!(roundtrip_expr.exprs.len(), 2); - - let mut exprs = roundtrip_expr.exprs.into_iter(); - - let (rt_expr, rt_field) = exprs.next().unwrap(); - assert_eq!(rt_field, out1); - assert_eq!(rt_expr, expr1); - - let (rt_expr, rt_field) = exprs.next().unwrap(); - assert_eq!(rt_field, out2); - assert_eq!(rt_expr, expr2); - - Ok(()) - } - - #[tokio::test] - async fn invalid_extended_expression() { - let state = SessionStateBuilder::default().build(); - - // Not ok if input schema is missing field referenced by expr - let expr = Expr::Column("missing".into()); - let field = Field::new("out", DataType::Int32, false); - let empty_schema = DFSchemaRef::new(DFSchema::empty()); - - let err = to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state); - - assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); - } -} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/aggregate_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/aggregate_function.rs new file mode 100644 index 0000000000000..1e79897a1b770 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/aggregate_function.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchemaRef; +use datafusion::logical_expr::expr; +use datafusion::logical_expr::expr::AggregateFunctionParams; +use substrait::proto::aggregate_function::AggregationInvocation; +use substrait::proto::aggregate_rel::Measure; +use substrait::proto::function_argument::ArgType; +use substrait::proto::sort_field::{SortDirection, SortKind}; +use substrait::proto::{ + AggregateFunction, AggregationPhase, FunctionArgument, SortField, +}; + +pub fn from_aggregate_function( + producer: &mut impl SubstraitProducer, + agg_fn: &expr::AggregateFunction, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let expr::AggregateFunction { + func, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment: _null_treatment, + }, + } = agg_fn; + let sorts = order_by + .iter() + .map(|expr| to_substrait_sort_field(producer, expr, schema)) + .collect::>>()?; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), + }); + } + let function_anchor = producer.register_function(func.name().to_string()); + #[allow(deprecated)] + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: match distinct { + true => AggregationInvocation::Distinct as i32, + false => AggregationInvocation::All as i32, + }, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(producer.handle_expr(f, schema)?), + None => None, + }, + }) +} + +/// Converts sort expression to corresponding substrait `SortField` +fn to_substrait_sort_field( + producer: &mut impl SubstraitProducer, + sort: &expr::Sort, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let sort_kind = match (sort.asc, sort.nulls_first) { + (true, true) => SortDirection::AscNullsFirst, + (true, false) => SortDirection::AscNullsLast, + (false, true) => SortDirection::DescNullsFirst, + (false, false) => SortDirection::DescNullsLast, + }; + Ok(SortField { + expr: Some(producer.handle_expr(&sort.expr, schema)?), + sort_kind: Some(SortKind::Direction(sort_kind.into())), + }) +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs new file mode 100644 index 0000000000000..9741dcdd10951 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{to_substrait_type, SubstraitProducer}; +use crate::variation_const::DEFAULT_TYPE_VARIATION_REF; +use datafusion::common::{DFSchemaRef, ScalarValue}; +use datafusion::logical_expr::{Cast, Expr, TryCast}; +use substrait::proto::expression::cast::FailureBehavior; +use substrait::proto::expression::literal::LiteralType; +use substrait::proto::expression::{Literal, RexType}; +use substrait::proto::Expression; + +pub fn from_cast( + producer: &mut impl SubstraitProducer, + cast: &Cast, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let Cast { expr, data_type } = cast; + // since substrait Null must be typed, so if we see a cast(null, dt), we make it a typed null + if let Expr::Literal(lit, _) = expr.as_ref() { + // only the untyped(a null scalar value) null literal need this special handling + // since all other kind of nulls are already typed and can be handled by substrait + // e.g. null:: or null:: + if matches!(lit, ScalarValue::Null) { + let lit = Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::Null(to_substrait_type( + data_type, true, + )?)), + }; + return Ok(Expression { + rex_type: Some(RexType::Literal(lit)), + }); + } + } + Ok(Expression { + rex_type: Some(RexType::Cast(Box::new( + substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(data_type, true)?), + input: Some(Box::new(producer.handle_expr(expr, schema)?)), + failure_behavior: FailureBehavior::ThrowException.into(), + }, + ))), + }) +} + +pub fn from_try_cast( + producer: &mut impl SubstraitProducer, + cast: &TryCast, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let TryCast { expr, data_type } = cast; + Ok(Expression { + rex_type: Some(RexType::Cast(Box::new( + substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(data_type, true)?), + input: Some(Box::new(producer.handle_expr(expr, schema)?)), + failure_behavior: FailureBehavior::ReturnNull.into(), + }, + ))), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::producer::to_substrait_extended_expr; + use datafusion::arrow::datatypes::{DataType, Field}; + use datafusion::common::DFSchema; + use datafusion::execution::SessionStateBuilder; + use datafusion::logical_expr::ExprSchemable; + use substrait::proto::expression_reference::ExprType; + + #[tokio::test] + async fn fold_cast_null() { + let state = SessionStateBuilder::default().build(); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let field = Field::new("out", DataType::Int32, false); + + let expr = Expr::Literal(ScalarValue::Null, None) + .cast_to(&DataType::Int32, &empty_schema) + .unwrap(); + + let typed_null = + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state) + .unwrap(); + + if let ExprType::Expression(expr) = + typed_null.referred_expr[0].expr_type.as_ref().unwrap() + { + let lit = Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::Null( + to_substrait_type(&DataType::Int32, true).unwrap(), + )), + }; + let expected = Expression { + rex_type: Some(RexType::Literal(lit)), + }; + assert_eq!(*expr, expected); + } else { + panic!("Expected expression type"); + } + + // a typed null should not be folded + let expr = Expr::Literal(ScalarValue::Int64(None), None) + .cast_to(&DataType::Int32, &empty_schema) + .unwrap(); + + let typed_null = + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state) + .unwrap(); + + if let ExprType::Expression(expr) = + typed_null.referred_expr[0].expr_type.as_ref().unwrap() + { + let cast_expr = substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(&DataType::Int32, true).unwrap()), + input: Some(Box::new(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::Null( + to_substrait_type(&DataType::Int64, true).unwrap(), + )), + })), + })), + failure_behavior: FailureBehavior::ThrowException as i32, + }; + let expected = Expression { + rex_type: Some(RexType::Cast(Box::new(cast_expr))), + }; + assert_eq!(*expr, expected); + } else { + panic!("Expected expression type"); + } + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs b/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs new file mode 100644 index 0000000000000..d1d80ca545ff2 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::{substrait_err, Column, DFSchemaRef}; +use datafusion::logical_expr::Expr; +use substrait::proto::expression::field_reference::{ + ReferenceType, RootReference, RootType, +}; +use substrait::proto::expression::{ + reference_segment, FieldReference, ReferenceSegment, RexType, +}; +use substrait::proto::Expression; + +pub fn from_column( + col: &Column, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let index = schema.index_of_column(col)?; + substrait_field_ref(index) +} + +pub(crate) fn substrait_field_ref( + index: usize, +) -> datafusion::common::Result { + Ok(Expression { + rex_type: Some(RexType::Selection(Box::new(FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField( + Box::new(reference_segment::StructField { + field: index as i32, + child: None, + }), + )), + })), + root_type: Some(RootType::RootReference(RootReference {})), + }))), + }) +} + +/// Try to convert an [Expr] to a [FieldReference]. +/// Returns `Err` if the [Expr] is not a [Expr::Column]. +pub(crate) fn try_to_substrait_field_reference( + expr: &Expr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + match expr { + Expr::Column(col) => { + let index = schema.index_of_column(col)?; + Ok(FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField( + Box::new(reference_segment::StructField { + field: index as i32, + child: None, + }), + )), + })), + root_type: Some(RootType::RootReference(RootReference {})), + }) + } + _ => substrait_err!("Expect a `Column` expr, but found {expr:?}"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::common::Result; + + #[test] + fn to_field_reference() -> Result<()> { + let expression = substrait_field_ref(2)?; + + match &expression.rex_type { + Some(RexType::Selection(field_ref)) => { + assert_eq!( + field_ref + .root_type + .clone() + .expect("root type should be set"), + RootType::RootReference(RootReference {}) + ); + } + + _ => panic!("Should not be anything other than field reference"), + } + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/if_then.rs b/datafusion/substrait/src/logical_plan/producer/expr/if_then.rs new file mode 100644 index 0000000000000..a34959ead76de --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/if_then.rs @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchemaRef; +use datafusion::logical_expr::Case; +use substrait::proto::expression::if_then::IfClause; +use substrait::proto::expression::{IfThen, RexType}; +use substrait::proto::Expression; + +pub fn from_case( + producer: &mut impl SubstraitProducer, + case: &Case, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let Case { + expr, + when_then_expr, + else_expr, + } = case; + let mut ifs: Vec = vec![]; + // Parse base + if let Some(e) = expr { + // Base expression exists + ifs.push(IfClause { + r#if: Some(producer.handle_expr(e, schema)?), + then: None, + }); + } + // Parse `when`s + for (r#if, then) in when_then_expr { + ifs.push(IfClause { + r#if: Some(producer.handle_expr(r#if, schema)?), + then: Some(producer.handle_expr(then, schema)?), + }); + } + + // Parse outer `else` + let r#else: Option> = match else_expr { + Some(e) => Some(Box::new(producer.handle_expr(e, schema)?)), + None => None, + }; + + Ok(Expression { + rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), + }) +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/literal.rs b/datafusion/substrait/src/logical_plan/producer/expr/literal.rs new file mode 100644 index 0000000000000..2c66e9f6b03c2 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/literal.rs @@ -0,0 +1,524 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{to_substrait_type, SubstraitProducer}; +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, TIME_32_TYPE_VARIATION_REF, + TIME_64_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; +use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; +use datafusion::arrow::temporal_conversions::NANOSECONDS; +use datafusion::common::{exec_err, not_impl_err, ScalarValue}; +use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; +use substrait::proto::expression::literal::map::KeyValue; +use substrait::proto::expression::literal::{ + Decimal, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, List, + LiteralType, Map, PrecisionTime, PrecisionTimestamp, Struct, +}; +use substrait::proto::expression::{Literal, RexType}; +use substrait::proto::{r#type, Expression}; + +pub fn from_literal( + producer: &mut impl SubstraitProducer, + value: &ScalarValue, +) -> datafusion::common::Result { + to_substrait_literal_expr(producer, value) +} + +pub(crate) fn to_substrait_literal_expr( + producer: &mut impl SubstraitProducer, + value: &ScalarValue, +) -> datafusion::common::Result { + let literal = to_substrait_literal(producer, value)?; + Ok(Expression { + rex_type: Some(RexType::Literal(literal)), + }) +} + +pub(crate) fn to_substrait_literal( + producer: &mut impl SubstraitProducer, + value: &ScalarValue, +) -> datafusion::common::Result { + if value.is_null() { + return Ok(Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::Null(to_substrait_type( + &value.data_type(), + true, + )?)), + }); + } + let (literal_type, type_variation_reference) = match value { + ScalarValue::Boolean(Some(b)) => { + (LiteralType::Boolean(*b), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::Int8(Some(n)) => { + (LiteralType::I8(*n as i32), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::UInt8(Some(n)) => ( + LiteralType::I8(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int16(Some(n)) => { + (LiteralType::I16(*n as i32), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::UInt16(Some(n)) => ( + LiteralType::I16(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_VARIATION_REF), + ScalarValue::UInt32(Some(n)) => ( + LiteralType::I32(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_VARIATION_REF), + ScalarValue::UInt64(Some(n)) => ( + LiteralType::I64(*n as i64), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Float32(Some(f)) => { + (LiteralType::Fp32(*f), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::Float64(Some(f)) => { + (LiteralType::Fp64(*f), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::TimestampSecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 0, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMillisecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 3, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMicrosecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 6, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampNanosecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 9, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + // If timezone is present, no matter what the actual tz value is, it indicates the + // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. + // As the timezone is lost, this conversion may be lossy for downstream use of the value. + ScalarValue::TimestampSecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 0, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMillisecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 3, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMicrosecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 6, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampNanosecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 9, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::Date32(Some(d)) => { + (LiteralType::Date(*d), DATE_32_TYPE_VARIATION_REF) + } + // Date64 literal is not supported in Substrait + ScalarValue::IntervalYearMonth(Some(i)) => ( + LiteralType::IntervalYearToMonth(IntervalYearToMonth { + // DF only tracks total months, but there should always be 12 months in a year + years: *i / 12, + months: *i % 12, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::IntervalMonthDayNano(Some(i)) => ( + LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month: Some(IntervalYearToMonth { + years: i.months / 12, + months: i.months % 12, + }), + interval_day_to_second: Some(IntervalDayToSecond { + days: i.days, + seconds: (i.nanoseconds / NANOSECONDS) as i32, + subseconds: i.nanoseconds % NANOSECONDS, + precision_mode: Some(PrecisionMode::Precision(9)), // nanoseconds + }), + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::IntervalDayTime(Some(i)) => ( + LiteralType::IntervalDayToSecond(IntervalDayToSecond { + days: i.days, + seconds: i.milliseconds / 1000, + subseconds: (i.milliseconds % 1000) as i64, + precision_mode: Some(PrecisionMode::Precision(3)), // 3 for milliseconds + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::Binary(Some(b)) => ( + LiteralType::Binary(b.clone()), + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeBinary(Some(b)) => ( + LiteralType::Binary(b.clone()), + LARGE_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::BinaryView(Some(b)) => ( + LiteralType::Binary(b.clone()), + VIEW_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::FixedSizeBinary(_, Some(b)) => ( + LiteralType::FixedBinary(b.clone()), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::Utf8(Some(s)) => ( + LiteralType::String(s.clone()), + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeUtf8(Some(s)) => ( + LiteralType::String(s.clone()), + LARGE_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::Utf8View(Some(s)) => ( + LiteralType::String(s.clone()), + VIEW_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::Decimal128(v, p, s) if v.is_some() => ( + LiteralType::Decimal(Decimal { + value: v.unwrap().to_le_bytes().to_vec(), + precision: *p as i32, + scale: *s as i32, + }), + DECIMAL_128_TYPE_VARIATION_REF, + ), + ScalarValue::List(l) => ( + convert_array_to_literal_list(producer, l)?, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeList(l) => ( + convert_array_to_literal_list(producer, l)?, + LARGE_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::Map(m) => { + let map = if m.is_empty() || m.value(0).is_empty() { + let mt = to_substrait_type(m.data_type(), m.is_nullable())?; + let mt = match mt { + substrait::proto::Type { + kind: Some(r#type::Kind::Map(mt)), + } => Ok(mt.as_ref().to_owned()), + _ => exec_err!("Unexpected type for a map: {mt:?}"), + }?; + LiteralType::EmptyMap(mt) + } else { + let keys = (0..m.keys().len()) + .map(|i| { + to_substrait_literal( + producer, + &ScalarValue::try_from_array(&m.keys(), i)?, + ) + }) + .collect::>>()?; + let values = (0..m.values().len()) + .map(|i| { + to_substrait_literal( + producer, + &ScalarValue::try_from_array(&m.values(), i)?, + ) + }) + .collect::>>()?; + + let key_values = keys + .into_iter() + .zip(values.into_iter()) + .map(|(k, v)| { + Ok(KeyValue { + key: Some(k), + value: Some(v), + }) + }) + .collect::>>()?; + LiteralType::Map(Map { key_values }) + }; + (map, DEFAULT_CONTAINER_TYPE_VARIATION_REF) + } + ScalarValue::Time32Second(Some(t)) => ( + LiteralType::PrecisionTime(PrecisionTime { + precision: 0, + value: *t as i64, + }), + TIME_32_TYPE_VARIATION_REF, + ), + ScalarValue::Time32Millisecond(Some(t)) => ( + LiteralType::PrecisionTime(PrecisionTime { + precision: 3, + value: *t as i64, + }), + TIME_32_TYPE_VARIATION_REF, + ), + ScalarValue::Time64Microsecond(Some(t)) => ( + LiteralType::PrecisionTime(PrecisionTime { + precision: 6, + value: *t, + }), + TIME_64_TYPE_VARIATION_REF, + ), + ScalarValue::Time64Nanosecond(Some(t)) => ( + LiteralType::PrecisionTime(PrecisionTime { + precision: 9, + value: *t, + }), + TIME_64_TYPE_VARIATION_REF, + ), + ScalarValue::Struct(s) => ( + LiteralType::Struct(Struct { + fields: s + .columns() + .iter() + .map(|col| { + to_substrait_literal( + producer, + &ScalarValue::try_from_array(col, 0)?, + ) + }) + .collect::>>()?, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + _ => ( + not_impl_err!("Unsupported literal: {value:?}")?, + DEFAULT_TYPE_VARIATION_REF, + ), + }; + + Ok(Literal { + nullable: false, + type_variation_reference, + literal_type: Some(literal_type), + }) +} + +fn convert_array_to_literal_list( + producer: &mut impl SubstraitProducer, + array: &GenericListArray, +) -> datafusion::common::Result { + assert_eq!(array.len(), 1); + let nested_array = array.value(0); + + let values = (0..nested_array.len()) + .map(|i| { + to_substrait_literal( + producer, + &ScalarValue::try_from_array(&nested_array, i)?, + ) + }) + .collect::>>()?; + + if values.is_empty() { + let lt = match to_substrait_type(array.data_type(), array.is_nullable())? { + substrait::proto::Type { + kind: Some(r#type::Kind::List(lt)), + } => lt.as_ref().to_owned(), + _ => unreachable!(), + }; + Ok(LiteralType::EmptyList(lt)) + } else { + Ok(LiteralType::List(List { values })) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::from_substrait_literal_without_names; + use crate::logical_plan::consumer::tests::test_consumer; + use crate::logical_plan::producer::DefaultSubstraitProducer; + use datafusion::arrow::array::{Int64Builder, MapBuilder, StringBuilder}; + use datafusion::arrow::datatypes::{ + DataType, Field, IntervalDayTime, IntervalMonthDayNano, + }; + use datafusion::common::scalar::ScalarStructBuilder; + use datafusion::common::Result; + use datafusion::prelude::SessionContext; + use std::sync::Arc; + + #[test] + fn round_trip_literals() -> Result<()> { + round_trip_literal(ScalarValue::Boolean(None))?; + round_trip_literal(ScalarValue::Boolean(Some(true)))?; + round_trip_literal(ScalarValue::Boolean(Some(false)))?; + + round_trip_literal(ScalarValue::Int8(None))?; + round_trip_literal(ScalarValue::Int8(Some(i8::MIN)))?; + round_trip_literal(ScalarValue::Int8(Some(i8::MAX)))?; + round_trip_literal(ScalarValue::UInt8(None))?; + round_trip_literal(ScalarValue::UInt8(Some(u8::MIN)))?; + round_trip_literal(ScalarValue::UInt8(Some(u8::MAX)))?; + + round_trip_literal(ScalarValue::Int16(None))?; + round_trip_literal(ScalarValue::Int16(Some(i16::MIN)))?; + round_trip_literal(ScalarValue::Int16(Some(i16::MAX)))?; + round_trip_literal(ScalarValue::UInt16(None))?; + round_trip_literal(ScalarValue::UInt16(Some(u16::MIN)))?; + round_trip_literal(ScalarValue::UInt16(Some(u16::MAX)))?; + + round_trip_literal(ScalarValue::Int32(None))?; + round_trip_literal(ScalarValue::Int32(Some(i32::MIN)))?; + round_trip_literal(ScalarValue::Int32(Some(i32::MAX)))?; + round_trip_literal(ScalarValue::UInt32(None))?; + round_trip_literal(ScalarValue::UInt32(Some(u32::MIN)))?; + round_trip_literal(ScalarValue::UInt32(Some(u32::MAX)))?; + + round_trip_literal(ScalarValue::Int64(None))?; + round_trip_literal(ScalarValue::Int64(Some(i64::MIN)))?; + round_trip_literal(ScalarValue::Int64(Some(i64::MAX)))?; + round_trip_literal(ScalarValue::UInt64(None))?; + round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?; + round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?; + + for (ts, tz) in [ + (Some(12345), None), + (None, None), + (Some(12345), Some("UTC".into())), + (None, Some("UTC".into())), + ] { + round_trip_literal(ScalarValue::TimestampSecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampMillisecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampMicrosecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampNanosecond(ts, tz))?; + } + + // Test Time32 literals + round_trip_literal(ScalarValue::Time32Second(Some(45296)))?; + round_trip_literal(ScalarValue::Time32Second(None))?; + round_trip_literal(ScalarValue::Time32Millisecond(Some(45296789)))?; + round_trip_literal(ScalarValue::Time32Millisecond(None))?; + + // Test Time64 literals + round_trip_literal(ScalarValue::Time64Microsecond(Some(45296789123)))?; + round_trip_literal(ScalarValue::Time64Microsecond(None))?; + round_trip_literal(ScalarValue::Time64Nanosecond(Some(45296789123000)))?; + round_trip_literal(ScalarValue::Time64Nanosecond(None))?; + + round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( + &[ScalarValue::Float32(Some(1.0))], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( + &[], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::List(Arc::new(GenericListArray::new_null( + Field::new_list_field(DataType::Float32, true).into(), + 1, + ))))?; + round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( + &[ScalarValue::Float32(Some(1.0))], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( + &[], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::LargeList(Arc::new( + GenericListArray::new_null( + Field::new_list_field(DataType::Float32, true).into(), + 1, + ), + )))?; + + // Null map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.append(false)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + // Empty map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.append(true)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + // Valid map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.keys().append_value("key1"); + map_builder.keys().append_value("key2"); + map_builder.values().append_value(1); + map_builder.values().append_value(2); + map_builder.append(true)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + let c0 = Field::new("c0", DataType::Boolean, true); + let c1 = Field::new("c1", DataType::Int32, true); + let c2 = Field::new("c2", DataType::Utf8, true); + round_trip_literal( + ScalarStructBuilder::new() + .with_scalar(c0.to_owned(), ScalarValue::Boolean(Some(true))) + .with_scalar(c1.to_owned(), ScalarValue::Int32(Some(1))) + .with_scalar(c2.to_owned(), ScalarValue::Utf8(None)) + .build()?, + )?; + round_trip_literal(ScalarStructBuilder::new_null(vec![c0, c1, c2]))?; + + round_trip_literal(ScalarValue::IntervalYearMonth(Some(17)))?; + round_trip_literal(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano::new(17, 25, 1234567890), + )))?; + round_trip_literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new( + 57, 123456, + ))))?; + + Ok(()) + } + + fn round_trip_literal(scalar: ScalarValue) -> Result<()> { + println!("Checking round trip of {scalar:?}"); + let state = SessionContext::default().state(); + let mut producer = DefaultSubstraitProducer::new(&state); + let substrait_literal = to_substrait_literal(&mut producer, &scalar)?; + let roundtrip_scalar = + from_substrait_literal_without_names(&test_consumer(), &substrait_literal)?; + assert_eq!(scalar, roundtrip_scalar); + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs new file mode 100644 index 0000000000000..42e1f962f1d1f --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod aggregate_function; +mod cast; +mod field_reference; +mod if_then; +mod literal; +mod scalar_function; +mod singular_or_list; +mod subquery; +mod window_function; + +pub use aggregate_function::*; +pub use cast::*; +pub use field_reference::*; +pub use if_then::*; +pub use literal::*; +pub use scalar_function::*; +pub use singular_or_list::*; +pub use subquery::*; +pub use window_function::*; + +use crate::logical_plan::producer::utils::flatten_names; +use crate::logical_plan::producer::{ + to_substrait_named_struct, DefaultSubstraitProducer, SubstraitProducer, +}; +use datafusion::arrow::datatypes::Field; +use datafusion::common::{internal_err, not_impl_err, DFSchemaRef}; +use datafusion::execution::SessionState; +use datafusion::logical_expr::expr::Alias; +use datafusion::logical_expr::Expr; +use substrait::proto::expression_reference::ExprType; +use substrait::proto::{Expression, ExpressionReference, ExtendedExpression}; +use substrait::version; + +/// Serializes a collection of expressions to a Substrait ExtendedExpression message +/// +/// The ExtendedExpression message is a top-level message that can be used to send +/// expressions (not plans) between systems. +/// +/// Each expression is also given names for the output type. These are provided as a +/// field and not a String (since the names may be nested, e.g. a struct). The data +/// type and nullability of this field is redundant (those can be determined by the +/// Expr) and will be ignored. +/// +/// Substrait also requires the input schema of the expressions to be included in the +/// message. The field names of the input schema will be serialized. +pub fn to_substrait_extended_expr( + exprs: &[(&Expr, &Field)], + schema: &DFSchemaRef, + state: &SessionState, +) -> datafusion::common::Result> { + let mut producer = DefaultSubstraitProducer::new(state); + let substrait_exprs = exprs + .iter() + .map(|(expr, field)| { + let substrait_expr = producer.handle_expr(expr, schema)?; + let mut output_names = Vec::new(); + flatten_names(field, false, &mut output_names)?; + Ok(ExpressionReference { + output_names, + expr_type: Some(ExprType::Expression(substrait_expr)), + }) + }) + .collect::>>()?; + let substrait_schema = to_substrait_named_struct(schema)?; + + let extensions = producer.get_extensions(); + Ok(Box::new(ExtendedExpression { + advanced_extensions: None, + expected_type_urls: vec![], + extension_uris: vec![], + extensions: extensions.into(), + version: Some(version::version_with_producer("datafusion")), + referred_expr: substrait_exprs, + base_schema: Some(substrait_schema), + })) +} + +/// Convert DataFusion Expr to Substrait Rex +/// +/// # Arguments +/// * `producer` - SubstraitProducer implementation which the handles the actual conversion +/// * `expr` - DataFusion expression to convert into a Substrait expression +/// * `schema` - DataFusion input schema for looking up columns +pub fn to_substrait_rex( + producer: &mut impl SubstraitProducer, + expr: &Expr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + match expr { + Expr::Alias(expr) => producer.handle_alias(expr, schema), + Expr::Column(expr) => producer.handle_column(expr, schema), + Expr::ScalarVariable(_, _) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } + Expr::Literal(expr, _) => producer.handle_literal(expr), + Expr::BinaryExpr(expr) => producer.handle_binary_expr(expr, schema), + Expr::Like(expr) => producer.handle_like(expr, schema), + Expr::SimilarTo(_) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Not(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotNull(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNull(_) => producer.handle_unary_expr(expr, schema), + Expr::IsTrue(_) => producer.handle_unary_expr(expr, schema), + Expr::IsFalse(_) => producer.handle_unary_expr(expr, schema), + Expr::IsUnknown(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotTrue(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotFalse(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotUnknown(_) => producer.handle_unary_expr(expr, schema), + Expr::Negative(_) => producer.handle_unary_expr(expr, schema), + Expr::Between(expr) => producer.handle_between(expr, schema), + Expr::Case(expr) => producer.handle_case(expr, schema), + Expr::Cast(expr) => producer.handle_cast(expr, schema), + Expr::TryCast(expr) => producer.handle_try_cast(expr, schema), + Expr::ScalarFunction(expr) => producer.handle_scalar_function(expr, schema), + Expr::AggregateFunction(_) => { + internal_err!( + "AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate" + ) + } + Expr::WindowFunction(expr) => producer.handle_window_function(expr, schema), + Expr::InList(expr) => producer.handle_in_list(expr, schema), + Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), + Expr::ScalarSubquery(expr) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } + #[expect(deprecated)] + Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::OuterReferenceColumn(_, _) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } + Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + } +} + +pub fn from_alias( + producer: &mut impl SubstraitProducer, + alias: &Alias, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + producer.handle_expr(alias.expr.as_ref(), schema) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::from_substrait_extended_expr; + use datafusion::arrow::datatypes::{DataType, Schema}; + use datafusion::common::{DFSchema, DataFusionError, ScalarValue}; + use datafusion::execution::SessionStateBuilder; + + #[tokio::test] + async fn extended_expressions() -> datafusion::common::Result<()> { + let state = SessionStateBuilder::default().build(); + + // One expression, empty input schema + let expr = Expr::Literal(ScalarValue::Int32(Some(42)), None); + let field = Field::new("out", DataType::Int32, false); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let substrait = + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state)?; + let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; + + assert_eq!(roundtrip_expr.input_schema, empty_schema); + assert_eq!(roundtrip_expr.exprs.len(), 1); + + let (rt_expr, rt_field) = roundtrip_expr.exprs.first().unwrap(); + assert_eq!(rt_field, &field); + assert_eq!(rt_expr, &expr); + + // Multiple expressions, with column references + let expr1 = Expr::Column("c0".into()); + let expr2 = Expr::Column("c1".into()); + let out1 = Field::new("out1", DataType::Int32, true); + let out2 = Field::new("out2", DataType::Utf8, true); + let input_schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ + Field::new("c0", DataType::Int32, true), + Field::new("c1", DataType::Utf8, true), + ]))?); + + let substrait = to_substrait_extended_expr( + &[(&expr1, &out1), (&expr2, &out2)], + &input_schema, + &state, + )?; + let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; + + assert_eq!(roundtrip_expr.input_schema, input_schema); + assert_eq!(roundtrip_expr.exprs.len(), 2); + + let mut exprs = roundtrip_expr.exprs.into_iter(); + + let (rt_expr, rt_field) = exprs.next().unwrap(); + assert_eq!(rt_field, out1); + assert_eq!(rt_expr, expr1); + + let (rt_expr, rt_field) = exprs.next().unwrap(); + assert_eq!(rt_field, out2); + assert_eq!(rt_expr, expr2); + + Ok(()) + } + + #[tokio::test] + async fn invalid_extended_expression() { + let state = SessionStateBuilder::default().build(); + + // Not ok if input schema is missing field referenced by expr + let expr = Expr::Column("missing".into()); + let field = Field::new("out", DataType::Int32, false); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + + let err = to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state); + + assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs new file mode 100644 index 0000000000000..abb26f6f66822 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs @@ -0,0 +1,348 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{to_substrait_literal_expr, SubstraitProducer}; +use datafusion::common::{not_impl_err, DFSchemaRef, ScalarValue}; +use datafusion::logical_expr::{expr, Between, BinaryExpr, Expr, Like, Operator}; +use substrait::proto::expression::{RexType, ScalarFunction}; +use substrait::proto::function_argument::ArgType; +use substrait::proto::{Expression, FunctionArgument}; + +pub fn from_scalar_function( + producer: &mut impl SubstraitProducer, + fun: &expr::ScalarFunction, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let mut arguments: Vec = vec![]; + for arg in &fun.args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), + }); + } + + let arguments = custom_argument_handler(fun.name(), arguments); + + let function_anchor = producer.register_function(fun.name().to_string()); + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + options: vec![], + args: vec![], + })), + }) +} + +// Handle functions that require custom handling for their arguments (e.g. log) +pub fn custom_argument_handler( + name: &str, + args: Vec, +) -> Vec { + match name { + "log" => { + if args.len() == 2 { + let mut args = args; + args.swap(0, 1); + args + } else { + args + } + } + _ => args, + } +} + +pub fn from_unary_expr( + producer: &mut impl SubstraitProducer, + expr: &Expr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let (fn_name, arg) = match expr { + Expr::Not(arg) => ("not", arg), + Expr::IsNull(arg) => ("is_null", arg), + Expr::IsNotNull(arg) => ("is_not_null", arg), + Expr::IsTrue(arg) => ("is_true", arg), + Expr::IsFalse(arg) => ("is_false", arg), + Expr::IsUnknown(arg) => ("is_unknown", arg), + Expr::IsNotTrue(arg) => ("is_not_true", arg), + Expr::IsNotFalse(arg) => ("is_not_false", arg), + Expr::IsNotUnknown(arg) => ("is_not_unknown", arg), + Expr::Negative(arg) => ("negate", arg), + expr => not_impl_err!("Unsupported expression: {expr:?}")?, + }; + to_substrait_unary_scalar_fn(producer, fn_name, arg, schema) +} + +pub fn from_binary_expr( + producer: &mut impl SubstraitProducer, + expr: &BinaryExpr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let BinaryExpr { left, op, right } = expr; + let l = producer.handle_expr(left, schema)?; + let r = producer.handle_expr(right, schema)?; + Ok(make_binary_op_scalar_func(producer, &l, &r, *op)) +} + +pub fn from_like( + producer: &mut impl SubstraitProducer, + like: &Like, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + } = like; + make_substrait_like_expr( + producer, + *case_insensitive, + *negated, + expr, + pattern, + *escape_char, + schema, + ) +} + +fn make_substrait_like_expr( + producer: &mut impl SubstraitProducer, + ignore_case: bool, + negated: bool, + expr: &Expr, + pattern: &Expr, + escape_char: Option, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let function_anchor = if ignore_case { + producer.register_function("ilike".to_string()) + } else { + producer.register_function("like".to_string()) + }; + let expr = producer.handle_expr(expr, schema)?; + let pattern = producer.handle_expr(pattern, schema)?; + let escape_char = to_substrait_literal_expr( + producer, + &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), + )?; + let arguments = vec![ + FunctionArgument { + arg_type: Some(ArgType::Value(expr)), + }, + FunctionArgument { + arg_type: Some(ArgType::Value(pattern)), + }, + FunctionArgument { + arg_type: Some(ArgType::Value(escape_char)), + }, + ]; + + #[allow(deprecated)] + let substrait_like = Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + args: vec![], + options: vec![], + })), + }; + + if negated { + let function_anchor = producer.register_function("not".to_string()); + + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_like)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_like) + } +} + +/// Util to generate substrait [RexType::ScalarFunction] with one argument +fn to_substrait_unary_scalar_fn( + producer: &mut impl SubstraitProducer, + fn_name: &str, + arg: &Expr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let function_anchor = producer.register_function(fn_name.to_string()); + let substrait_expr = producer.handle_expr(arg, schema)?; + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_expr)), + }], + output_type: None, + options: vec![], + ..Default::default() + })), + }) +} + +/// Return Substrait scalar function with two arguments +pub fn make_binary_op_scalar_func( + producer: &mut impl SubstraitProducer, + lhs: &Expression, + rhs: &Expression, + op: Operator, +) -> Expression { + let function_anchor = producer.register_function(operator_to_name(op).to_string()); + #[allow(deprecated)] + Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![ + FunctionArgument { + arg_type: Some(ArgType::Value(lhs.clone())), + }, + FunctionArgument { + arg_type: Some(ArgType::Value(rhs.clone())), + }, + ], + output_type: None, + args: vec![], + options: vec![], + })), + } +} + +pub fn from_between( + producer: &mut impl SubstraitProducer, + between: &Between, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let Between { + expr, + negated, + low, + high, + } = between; + if *negated { + // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) + let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; + let substrait_low = producer.handle_expr(low.as_ref(), schema)?; + let substrait_high = producer.handle_expr(high.as_ref(), schema)?; + + let l_expr = make_binary_op_scalar_func( + producer, + &substrait_expr, + &substrait_low, + Operator::Lt, + ); + let r_expr = make_binary_op_scalar_func( + producer, + &substrait_high, + &substrait_expr, + Operator::Lt, + ); + + Ok(make_binary_op_scalar_func( + producer, + &l_expr, + &r_expr, + Operator::Or, + )) + } else { + // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) + let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; + let substrait_low = producer.handle_expr(low.as_ref(), schema)?; + let substrait_high = producer.handle_expr(high.as_ref(), schema)?; + + let l_expr = make_binary_op_scalar_func( + producer, + &substrait_low, + &substrait_expr, + Operator::LtEq, + ); + let r_expr = make_binary_op_scalar_func( + producer, + &substrait_expr, + &substrait_high, + Operator::LtEq, + ); + + Ok(make_binary_op_scalar_func( + producer, + &l_expr, + &r_expr, + Operator::And, + )) + } +} + +pub fn operator_to_name(op: Operator) -> &'static str { + match op { + Operator::Eq => "equal", + Operator::NotEq => "not_equal", + Operator::Lt => "lt", + Operator::LtEq => "lte", + Operator::Gt => "gt", + Operator::GtEq => "gte", + Operator::Plus => "add", + Operator::Minus => "subtract", + Operator::Multiply => "multiply", + Operator::Divide => "divide", + Operator::Modulo => "modulus", + Operator::And => "and", + Operator::Or => "or", + Operator::IsDistinctFrom => "is_distinct_from", + Operator::IsNotDistinctFrom => "is_not_distinct_from", + Operator::RegexMatch => "regex_match", + Operator::RegexIMatch => "regex_imatch", + Operator::RegexNotMatch => "regex_not_match", + Operator::RegexNotIMatch => "regex_not_imatch", + Operator::LikeMatch => "like_match", + Operator::ILikeMatch => "like_imatch", + Operator::NotLikeMatch => "like_not_match", + Operator::NotILikeMatch => "like_not_imatch", + Operator::BitwiseAnd => "bitwise_and", + Operator::BitwiseOr => "bitwise_or", + Operator::StringConcat => "str_concat", + Operator::AtArrow => "at_arrow", + Operator::ArrowAt => "arrow_at", + Operator::Arrow => "arrow", + Operator::LongArrow => "long_arrow", + Operator::HashArrow => "hash_arrow", + Operator::HashLongArrow => "hash_long_arrow", + Operator::AtAt => "at_at", + Operator::IntegerDivide => "integer_divide", + Operator::HashMinus => "hash_minus", + Operator::AtQuestion => "at_question", + Operator::Question => "question", + Operator::QuestionAnd => "question_and", + Operator::QuestionPipe => "question_pipe", + Operator::BitwiseXor => "bitwise_xor", + Operator::BitwiseShiftRight => "bitwise_shift_right", + Operator::BitwiseShiftLeft => "bitwise_shift_left", + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs b/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs new file mode 100644 index 0000000000000..1c0b6dcc154bc --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchemaRef; +use datafusion::logical_expr::expr::InList; +use substrait::proto::expression::{RexType, ScalarFunction, SingularOrList}; +use substrait::proto::function_argument::ArgType; +use substrait::proto::{Expression, FunctionArgument}; + +pub fn from_in_list( + producer: &mut impl SubstraitProducer, + in_list: &InList, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let InList { + expr, + list, + negated, + } = in_list; + let substrait_list = list + .iter() + .map(|x| producer.handle_expr(x, schema)) + .collect::>>()?; + let substrait_expr = producer.handle_expr(expr, schema)?; + + let substrait_or_list = Expression { + rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { + value: Some(Box::new(substrait_expr)), + options: substrait_list, + }))), + }; + + if *negated { + let function_anchor = producer.register_function("not".to_string()); + + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_or_list)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_or_list) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs new file mode 100644 index 0000000000000..c1ee78c68c258 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchemaRef; +use datafusion::logical_expr::expr::InSubquery; +use substrait::proto::expression::subquery::InPredicate; +use substrait::proto::expression::{RexType, ScalarFunction}; +use substrait::proto::function_argument::ArgType; +use substrait::proto::{Expression, FunctionArgument}; + +pub fn from_in_subquery( + producer: &mut impl SubstraitProducer, + subquery: &InSubquery, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let InSubquery { + expr, + subquery, + negated, + } = subquery; + let substrait_expr = producer.handle_expr(expr, schema)?; + + let subquery_plan = producer.handle_plan(subquery.subquery.as_ref())?; + + let substrait_subquery = Expression { + rex_type: Some(RexType::Subquery(Box::new( + substrait::proto::expression::Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::InPredicate( + Box::new(InPredicate { + needles: (vec![substrait_expr]), + haystack: Some(subquery_plan), + }), + ), + ), + }, + ))), + }; + if *negated { + let function_anchor = producer.register_function("not".to_string()); + + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_subquery)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_subquery) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs new file mode 100644 index 0000000000000..465479e1e0488 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs @@ -0,0 +1,164 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::utils::substrait_sort_field; +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::{not_impl_err, DFSchemaRef, ScalarValue}; +use datafusion::logical_expr::expr::{WindowFunction, WindowFunctionParams}; +use datafusion::logical_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; +use substrait::proto::expression::window_function::bound as SubstraitBound; +use substrait::proto::expression::window_function::bound::Kind as BoundKind; +use substrait::proto::expression::window_function::{Bound, BoundsType}; +use substrait::proto::expression::RexType; +use substrait::proto::expression::WindowFunction as SubstraitWindowFunction; +use substrait::proto::function_argument::ArgType; +use substrait::proto::{Expression, FunctionArgument, SortField}; + +pub fn from_window_function( + producer: &mut impl SubstraitProducer, + window_fn: &WindowFunction, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment: _, + distinct: _, + filter: _, + }, + } = window_fn; + // function reference + let function_anchor = producer.register_function(fun.to_string()); + // arguments + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), + }); + } + // partition by expressions + let partition_by = partition_by + .iter() + .map(|e| producer.handle_expr(e, schema)) + .collect::>>()?; + // order by expressions + let order_by = order_by + .iter() + .map(|e| substrait_sort_field(producer, e, schema)) + .collect::>>()?; + // window frame + let bounds = to_substrait_bounds(window_frame)?; + let bound_type = to_substrait_bound_type(window_frame)?; + Ok(make_substrait_window_function( + function_anchor, + arguments, + partition_by, + order_by, + bounds, + bound_type, + )) +} + +fn make_substrait_window_function( + function_reference: u32, + arguments: Vec, + partitions: Vec, + sorts: Vec, + bounds: (Bound, Bound), + bounds_type: BoundsType, +) -> Expression { + #[allow(deprecated)] + Expression { + rex_type: Some(RexType::WindowFunction(SubstraitWindowFunction { + function_reference, + arguments, + partitions, + sorts, + options: vec![], + output_type: None, + phase: 0, // default to AGGREGATION_PHASE_UNSPECIFIED + invocation: 0, // TODO: fix + lower_bound: Some(bounds.0), + upper_bound: Some(bounds.1), + args: vec![], + bounds_type: bounds_type as i32, + })), + } +} + +fn to_substrait_bound_type( + window_frame: &WindowFrame, +) -> datafusion::common::Result { + match window_frame.units { + WindowFrameUnits::Rows => Ok(BoundsType::Rows), // ROWS + WindowFrameUnits::Range => Ok(BoundsType::Range), // RANGE + // TODO: Support GROUPS + unit => not_impl_err!("Unsupported window frame unit: {unit:?}"), + } +} + +fn to_substrait_bounds( + window_frame: &WindowFrame, +) -> datafusion::common::Result<(Bound, Bound)> { + Ok(( + to_substrait_bound(&window_frame.start_bound), + to_substrait_bound(&window_frame.end_bound), + )) +} + +fn to_substrait_bound(bound: &WindowFrameBound) -> Bound { + match bound { + WindowFrameBound::CurrentRow => Bound { + kind: Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})), + }, + WindowFrameBound::Preceding(s) => match to_substrait_bound_offset(s) { + Some(offset) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { offset })), + }, + None => Bound { + kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), + }, + }, + WindowFrameBound::Following(s) => match to_substrait_bound_offset(s) { + Some(offset) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { offset })), + }, + None => Bound { + kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), + }, + }, + } +} + +fn to_substrait_bound_offset(value: &ScalarValue) -> Option { + match value { + ScalarValue::UInt8(Some(v)) => Some(*v as i64), + ScalarValue::UInt16(Some(v)) => Some(*v as i64), + ScalarValue::UInt32(Some(v)) => Some(*v as i64), + ScalarValue::UInt64(Some(v)) => Some(*v as i64), + ScalarValue::Int8(Some(v)) => Some(*v as i64), + ScalarValue::Int16(Some(v)) => Some(*v as i64), + ScalarValue::Int32(Some(v)) => Some(*v as i64), + ScalarValue::Int64(Some(v)) => Some(*v), + _ => None, + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/mod.rs b/datafusion/substrait/src/logical_plan/producer/mod.rs new file mode 100644 index 0000000000000..fc4af94a25fe4 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/mod.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod expr; +mod plan; +mod rel; +mod substrait_producer; +mod types; +mod utils; + +pub use expr::*; +pub use plan::*; +pub use rel::*; +pub use substrait_producer::*; +pub(crate) use types::*; +pub(crate) use utils::*; diff --git a/datafusion/substrait/src/logical_plan/producer/plan.rs b/datafusion/substrait/src/logical_plan/producer/plan.rs new file mode 100644 index 0000000000000..7d5b7754122d6 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/plan.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{ + to_substrait_named_struct, DefaultSubstraitProducer, SubstraitProducer, +}; +use datafusion::execution::SessionState; +use datafusion::logical_expr::{LogicalPlan, SubqueryAlias}; +use substrait::proto::{plan_rel, Plan, PlanRel, Rel, RelRoot}; +use substrait::version; + +/// Convert DataFusion LogicalPlan to Substrait Plan +pub fn to_substrait_plan( + plan: &LogicalPlan, + state: &SessionState, +) -> datafusion::common::Result> { + // Parse relation nodes + // Generate PlanRel(s) + // Note: Only 1 relation tree is currently supported + + let mut producer: DefaultSubstraitProducer = DefaultSubstraitProducer::new(state); + let plan_rels = vec![PlanRel { + rel_type: Some(plan_rel::RelType::Root(RelRoot { + input: Some(*producer.handle_plan(plan)?), + names: to_substrait_named_struct(plan.schema())?.names, + })), + }]; + + // Return parsed plan + let extensions = producer.get_extensions(); + Ok(Box::new(Plan { + version: Some(version::version_with_producer("datafusion")), + extension_uris: vec![], + extensions: extensions.into(), + relations: plan_rels, + advanced_extensions: None, + expected_type_urls: vec![], + parameter_bindings: vec![], + })) +} + +pub fn from_subquery_alias( + producer: &mut impl SubstraitProducer, + alias: &SubqueryAlias, +) -> datafusion::common::Result> { + // Do nothing if encounters SubqueryAlias + // since there is no corresponding relation type in Substrait + producer.handle_plan(alias.input.as_ref()) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/aggregate_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/aggregate_rel.rs new file mode 100644 index 0000000000000..917959ea7ddae --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/aggregate_rel.rs @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{ + from_aggregate_function, substrait_field_ref, SubstraitProducer, +}; +use datafusion::common::{internal_err, not_impl_err, DFSchemaRef, DataFusionError}; +use datafusion::logical_expr::expr::Alias; +use datafusion::logical_expr::{Aggregate, Distinct, Expr, GroupingSet}; +use substrait::proto::aggregate_rel::{Grouping, Measure}; +use substrait::proto::rel::RelType; +use substrait::proto::{AggregateRel, Expression, Rel}; + +pub fn from_aggregate( + producer: &mut impl SubstraitProducer, + agg: &Aggregate, +) -> datafusion::common::Result> { + let input = producer.handle_plan(agg.input.as_ref())?; + let (grouping_expressions, groupings) = + to_substrait_groupings(producer, &agg.group_expr, agg.input.schema())?; + let measures = agg + .aggr_expr + .iter() + .map(|e| to_substrait_agg_measure(producer, e, agg.input.schema())) + .collect::>>()?; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { + common: None, + input: Some(input), + grouping_expressions, + groupings, + measures, + advanced_extension: None, + }))), + })) +} + +pub fn from_distinct( + producer: &mut impl SubstraitProducer, + distinct: &Distinct, +) -> datafusion::common::Result> { + match distinct { + Distinct::All(plan) => { + // Use Substrait's AggregateRel with empty measures to represent `select distinct` + let input = producer.handle_plan(plan.as_ref())?; + // Get grouping keys from the input relation's number of output fields + let grouping = (0..plan.schema().fields().len()) + .map(substrait_field_ref) + .collect::>>()?; + + #[allow(deprecated)] + Ok(Box::new(Rel { + rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { + common: None, + input: Some(input), + grouping_expressions: vec![], + groupings: vec![Grouping { + grouping_expressions: grouping, + expression_references: vec![], + }], + measures: vec![], + advanced_extension: None, + }))), + })) + } + Distinct::On(_) => not_impl_err!("Cannot convert Distinct::On"), + } +} + +pub fn to_substrait_groupings( + producer: &mut impl SubstraitProducer, + exprs: &[Expr], + schema: &DFSchemaRef, +) -> datafusion::common::Result<(Vec, Vec)> { + let mut ref_group_exprs = vec![]; + let groupings = match exprs.len() { + 1 => match &exprs[0] { + Expr::GroupingSet(gs) => match gs { + GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented( + "GroupingSet CUBE is not yet supported".to_string(), + )), + GroupingSet::GroupingSets(sets) => Ok(sets + .iter() + .map(|set| { + parse_flat_grouping_exprs( + producer, + set, + schema, + &mut ref_group_exprs, + ) + }) + .collect::>>()?), + GroupingSet::Rollup(set) => { + let mut sets: Vec> = vec![vec![]]; + for i in 0..set.len() { + sets.push(set[..=i].to_vec()); + } + Ok(sets + .iter() + .rev() + .map(|set| { + parse_flat_grouping_exprs( + producer, + set, + schema, + &mut ref_group_exprs, + ) + }) + .collect::>>()?) + } + }, + _ => Ok(vec![parse_flat_grouping_exprs( + producer, + exprs, + schema, + &mut ref_group_exprs, + )?]), + }, + _ => Ok(vec![parse_flat_grouping_exprs( + producer, + exprs, + schema, + &mut ref_group_exprs, + )?]), + }?; + Ok((ref_group_exprs, groupings)) +} + +pub fn parse_flat_grouping_exprs( + producer: &mut impl SubstraitProducer, + exprs: &[Expr], + schema: &DFSchemaRef, + ref_group_exprs: &mut Vec, +) -> datafusion::common::Result { + let mut expression_references = vec![]; + let mut grouping_expressions = vec![]; + + for e in exprs { + let rex = producer.handle_expr(e, schema)?; + grouping_expressions.push(rex.clone()); + ref_group_exprs.push(rex); + expression_references.push((ref_group_exprs.len() - 1) as u32); + } + #[allow(deprecated)] + Ok(Grouping { + grouping_expressions, + expression_references, + }) +} + +pub fn to_substrait_agg_measure( + producer: &mut impl SubstraitProducer, + expr: &Expr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + match expr { + Expr::AggregateFunction(agg_fn) => from_aggregate_function(producer, agg_fn, schema), + Expr::Alias(Alias { expr, .. }) => { + to_substrait_agg_measure(producer, expr, schema) + } + _ => internal_err!( + "Expression must be compatible with aggregation. Unsupported expression: {:?}. Expressiontype: {}", + expr, + expr.variant_name() + ), + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/exchange_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/exchange_rel.rs new file mode 100644 index 0000000000000..9e0ef8905f432 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/exchange_rel.rs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{ + try_to_substrait_field_reference, SubstraitProducer, +}; +use datafusion::common::not_impl_err; +use datafusion::logical_expr::{Partitioning, Repartition}; +use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::rel::RelType; +use substrait::proto::{ExchangeRel, Rel}; + +pub fn from_repartition( + producer: &mut impl SubstraitProducer, + repartition: &Repartition, +) -> datafusion::common::Result> { + let input = producer.handle_plan(repartition.input.as_ref())?; + let partition_count = match repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(num) => num, + Partitioning::Hash(_, num) => num, + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let exchange_kind = match &repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(_) => { + ExchangeKind::RoundRobin(RoundRobin::default()) + } + Partitioning::Hash(exprs, _) => { + let fields = exprs + .iter() + .map(|e| try_to_substrait_field_reference(e, repartition.input.schema())) + .collect::>>()?; + ExchangeKind::ScatterByFields(ScatterFields { fields }) + } + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + let exchange_rel = ExchangeRel { + common: None, + input: Some(input), + exchange_kind: Some(exchange_kind), + advanced_extension: None, + partition_count: partition_count as i32, + targets: vec![], + }; + Ok(Box::new(Rel { + rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/fetch_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/fetch_rel.rs new file mode 100644 index 0000000000000..4706401d558ec --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/fetch_rel.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchema; +use datafusion::logical_expr::Limit; +use std::sync::Arc; +use substrait::proto::rel::RelType; +use substrait::proto::{fetch_rel, FetchRel, Rel}; + +pub fn from_limit( + producer: &mut impl SubstraitProducer, + limit: &Limit, +) -> datafusion::common::Result> { + let input = producer.handle_plan(limit.input.as_ref())?; + let empty_schema = Arc::new(DFSchema::empty()); + let offset_mode = limit + .skip + .as_ref() + .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) + .transpose()? + .map(Box::new) + .map(fetch_rel::OffsetMode::OffsetExpr); + let count_mode = limit + .fetch + .as_ref() + .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) + .transpose()? + .map(Box::new) + .map(fetch_rel::CountMode::CountExpr); + Ok(Box::new(Rel { + rel_type: Some(RelType::Fetch(Box::new(FetchRel { + common: None, + input: Some(input), + offset_mode, + count_mode, + advanced_extension: None, + }))), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/filter_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/filter_rel.rs new file mode 100644 index 0000000000000..770696dfe1a93 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/filter_rel.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::logical_expr::Filter; +use substrait::proto::rel::RelType; +use substrait::proto::{FilterRel, Rel}; + +pub fn from_filter( + producer: &mut impl SubstraitProducer, + filter: &Filter, +) -> datafusion::common::Result> { + let input = producer.handle_plan(filter.input.as_ref())?; + let filter_expr = producer.handle_expr(&filter.predicate, filter.input.schema())?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Filter(Box::new(FilterRel { + common: None, + input: Some(input), + condition: Some(Box::new(filter_expr)), + advanced_extension: None, + }))), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/join.rs b/datafusion/substrait/src/logical_plan/producer/rel/join.rs new file mode 100644 index 0000000000000..835d3ee37a459 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/join.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{make_binary_op_scalar_func, SubstraitProducer}; +use datafusion::common::{ + not_impl_err, DFSchemaRef, JoinConstraint, JoinType, NullEquality, +}; +use datafusion::logical_expr::{Expr, Join, Operator}; +use std::sync::Arc; +use substrait::proto::rel::RelType; +use substrait::proto::{join_rel, Expression, JoinRel, Rel}; + +pub fn from_join( + producer: &mut impl SubstraitProducer, + join: &Join, +) -> datafusion::common::Result> { + let left = producer.handle_plan(join.left.as_ref())?; + let right = producer.handle_plan(join.right.as_ref())?; + let join_type = to_substrait_jointype(join.join_type); + // we only support basic joins so return an error for anything not yet supported + match join.join_constraint { + JoinConstraint::On => {} + JoinConstraint::Using => return not_impl_err!("join constraint: `using`"), + } + let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?); + + // convert filter if present + let join_filter = match &join.filter { + Some(filter) => Some(producer.handle_expr(filter, &in_join_schema)?), + None => None, + }; + + // map the left and right columns to binary expressions in the form `l = r` + // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` + let eq_op = match join.null_equality { + NullEquality::NullEqualsNothing => Operator::Eq, + NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom, + }; + let join_on = to_substrait_join_expr(producer, &join.on, eq_op, &in_join_schema)?; + + // create conjunction between `join_on` and `join_filter` to embed all join conditions, + // whether equal or non-equal in a single expression + let join_expr = match &join_on { + Some(on_expr) => match &join_filter { + Some(filter) => Some(Box::new(make_binary_op_scalar_func( + producer, + on_expr, + filter, + Operator::And, + ))), + None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist + }, + None => match &join_filter { + Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist + None => None, + }, + }; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Join(Box::new(JoinRel { + common: None, + left: Some(left), + right: Some(right), + r#type: join_type as i32, + expression: join_expr, + post_join_filter: None, + advanced_extension: None, + }))), + })) +} + +fn to_substrait_join_expr( + producer: &mut impl SubstraitProducer, + join_conditions: &Vec<(Expr, Expr)>, + eq_op: Operator, + join_schema: &DFSchemaRef, +) -> datafusion::common::Result> { + // Only support AND conjunction for each binary expression in join conditions + let mut exprs: Vec = vec![]; + for (left, right) in join_conditions { + let l = producer.handle_expr(left, join_schema)?; + let r = producer.handle_expr(right, join_schema)?; + // AND with existing expression + exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op)); + } + + let join_expr: Option = + exprs.into_iter().reduce(|acc: Expression, e: Expression| { + make_binary_op_scalar_func(producer, &acc, &e, Operator::And) + }); + Ok(join_expr) +} + +fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { + match join_type { + JoinType::Inner => join_rel::JoinType::Inner, + JoinType::Left => join_rel::JoinType::Left, + JoinType::Right => join_rel::JoinType::Right, + JoinType::Full => join_rel::JoinType::Outer, + JoinType::LeftAnti => join_rel::JoinType::LeftAnti, + JoinType::LeftSemi => join_rel::JoinType::LeftSemi, + JoinType::LeftMark => join_rel::JoinType::LeftMark, + JoinType::RightMark => join_rel::JoinType::RightMark, + JoinType::RightAnti => join_rel::JoinType::RightAnti, + JoinType::RightSemi => join_rel::JoinType::RightSemi, + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs new file mode 100644 index 0000000000000..c3599a2635ffa --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod aggregate_rel; +mod exchange_rel; +mod fetch_rel; +mod filter_rel; +mod join; +mod project_rel; +mod read_rel; +mod set_rel; +mod sort_rel; + +pub use aggregate_rel::*; +pub use exchange_rel::*; +pub use fetch_rel::*; +pub use filter_rel::*; +pub use join::*; +pub use project_rel::*; +pub use read_rel::*; +pub use set_rel::*; +pub use sort_rel::*; + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::not_impl_err; +use datafusion::logical_expr::LogicalPlan; +use substrait::proto::Rel; + +pub fn to_substrait_rel( + producer: &mut impl SubstraitProducer, + plan: &LogicalPlan, +) -> datafusion::common::Result> { + match plan { + LogicalPlan::Projection(plan) => producer.handle_projection(plan), + LogicalPlan::Filter(plan) => producer.handle_filter(plan), + LogicalPlan::Window(plan) => producer.handle_window(plan), + LogicalPlan::Aggregate(plan) => producer.handle_aggregate(plan), + LogicalPlan::Sort(plan) => producer.handle_sort(plan), + LogicalPlan::Join(plan) => producer.handle_join(plan), + LogicalPlan::Repartition(plan) => producer.handle_repartition(plan), + LogicalPlan::Union(plan) => producer.handle_union(plan), + LogicalPlan::TableScan(plan) => producer.handle_table_scan(plan), + LogicalPlan::EmptyRelation(plan) => producer.handle_empty_relation(plan), + LogicalPlan::Subquery(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::SubqueryAlias(plan) => producer.handle_subquery_alias(plan), + LogicalPlan::Limit(plan) => producer.handle_limit(plan), + LogicalPlan::Statement(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Values(plan) => producer.handle_values(plan), + LogicalPlan::Explain(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Analyze(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Extension(plan) => producer.handle_extension(plan), + LogicalPlan::Distinct(plan) => producer.handle_distinct(plan), + LogicalPlan::Dml(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Ddl(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Copy(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::DescribeTable(plan) => { + not_impl_err!("Unsupported plan type: {plan:?}")? + } + LogicalPlan::Unnest(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::RecursiveQuery(plan) => { + not_impl_err!("Unsupported plan type: {plan:?}")? + } + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/project_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/project_rel.rs new file mode 100644 index 0000000000000..0190dca12bf53 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/project_rel.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{substrait_field_ref, SubstraitProducer}; +use datafusion::logical_expr::{Projection, Window}; +use substrait::proto::rel::RelType; +use substrait::proto::rel_common::EmitKind; +use substrait::proto::rel_common::EmitKind::Emit; +use substrait::proto::{rel_common, ProjectRel, Rel, RelCommon}; + +pub fn from_projection( + producer: &mut impl SubstraitProducer, + p: &Projection, +) -> datafusion::common::Result> { + let expressions = p + .expr + .iter() + .map(|e| producer.handle_expr(e, p.input.schema())) + .collect::>>()?; + + let emit_kind = create_project_remapping( + expressions.len(), + p.input.as_ref().schema().fields().len(), + ); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Project(Box::new(ProjectRel { + common: Some(common), + input: Some(producer.handle_plan(p.input.as_ref())?), + expressions, + advanced_extension: None, + }))), + })) +} + +pub fn from_window( + producer: &mut impl SubstraitProducer, + window: &Window, +) -> datafusion::common::Result> { + let input = producer.handle_plan(window.input.as_ref())?; + + // create a field reference for each input field + let mut expressions = (0..window.input.schema().fields().len()) + .map(substrait_field_ref) + .collect::>>()?; + + // process and add each window function expression + for expr in &window.window_expr { + expressions.push(producer.handle_expr(expr, window.input.schema())?); + } + + let emit_kind = + create_project_remapping(expressions.len(), window.input.schema().fields().len()); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + let project_rel = Box::new(ProjectRel { + common: Some(common), + input: Some(input), + expressions, + advanced_extension: None, + }); + + Ok(Box::new(Rel { + rel_type: Some(RelType::Project(project_rel)), + })) +} + +/// By default, a Substrait Project outputs all input fields followed by all expressions. +/// A DataFusion Projection only outputs expressions. In order to keep the Substrait +/// plan consistent with DataFusion, we must apply an output mapping that skips the input +/// fields so that the Substrait Project will only output the expression fields. +fn create_project_remapping(expr_count: usize, input_field_count: usize) -> EmitKind { + let expression_field_start = input_field_count; + let expression_field_end = expression_field_start + expr_count; + let output_mapping = (expression_field_start..expression_field_end) + .map(|i| i as i32) + .collect(); + Emit(rel_common::Emit { output_mapping }) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs new file mode 100644 index 0000000000000..212874e7913b5 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{ + to_substrait_literal, to_substrait_named_struct, SubstraitProducer, +}; +use datafusion::common::{not_impl_err, substrait_datafusion_err, DFSchema, ToDFSchema}; +use datafusion::logical_expr::utils::conjunction; +use datafusion::logical_expr::{EmptyRelation, Expr, TableScan, Values}; +use std::sync::Arc; +use substrait::proto::expression::literal::Struct; +use substrait::proto::expression::mask_expression::{StructItem, StructSelect}; +use substrait::proto::expression::MaskExpression; +use substrait::proto::read_rel::{NamedTable, ReadType, VirtualTable}; +use substrait::proto::rel::RelType; +use substrait::proto::{ReadRel, Rel}; + +pub fn from_table_scan( + producer: &mut impl SubstraitProducer, + scan: &TableScan, +) -> datafusion::common::Result> { + let projection = scan.projection.as_ref().map(|p| { + p.iter() + .map(|i| StructItem { + field: *i as i32, + child: None, + }) + .collect() + }); + + let projection = projection.map(|struct_items| MaskExpression { + select: Some(StructSelect { struct_items }), + maintain_singular_struct: false, + }); + + let table_schema = scan.source.schema().to_dfschema_ref()?; + let base_schema = to_substrait_named_struct(&table_schema)?; + + let filter_option = if scan.filters.is_empty() { + None + } else { + let table_schema_qualified = Arc::new( + DFSchema::try_from_qualified_schema( + scan.table_name.clone(), + &(scan.source.schema()), + ) + .unwrap(), + ); + + let combined_expr = conjunction(scan.filters.clone()).unwrap(); + let filter_expr = + producer.handle_expr(&combined_expr, &table_schema_qualified)?; + Some(Box::new(filter_expr)) + }; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(base_schema), + filter: filter_option, + best_effort_filter: None, + projection, + advanced_extension: None, + read_type: Some(ReadType::NamedTable(NamedTable { + names: scan.table_name.to_vec(), + advanced_extension: None, + })), + }))), + })) +} + +pub fn from_empty_relation(e: &EmptyRelation) -> datafusion::common::Result> { + if e.produce_one_row { + return not_impl_err!("Producing a row from empty relation is unsupported"); + } + #[allow(deprecated)] + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(to_substrait_named_struct(&e.schema)?), + filter: None, + best_effort_filter: None, + projection: None, + advanced_extension: None, + read_type: Some(ReadType::VirtualTable(VirtualTable { + values: vec![], + expressions: vec![], + })), + }))), + })) +} + +pub fn from_values( + producer: &mut impl SubstraitProducer, + v: &Values, +) -> datafusion::common::Result> { + let values = v + .values + .iter() + .map(|row| { + let fields = row + .iter() + .map(|v| match v { + Expr::Literal(sv, _) => to_substrait_literal(producer, sv), + Expr::Alias(alias) => match alias.expr.as_ref() { + // The schema gives us the names, so we can skip aliases + Expr::Literal(sv, _) => to_substrait_literal(producer, sv), + _ => Err(substrait_datafusion_err!( + "Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name() + )), + }, + _ => Err(substrait_datafusion_err!( + "Only literal types and aliases are supported in Virtual Tables, got: {}", v.variant_name() + )), + }) + .collect::>()?; + Ok(Struct { fields }) + }) + .collect::>()?; + #[allow(deprecated)] + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(to_substrait_named_struct(&v.schema)?), + filter: None, + best_effort_filter: None, + projection: None, + advanced_extension: None, + read_type: Some(ReadType::VirtualTable(VirtualTable { + values, + expressions: vec![], + })), + }))), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/set_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/set_rel.rs new file mode 100644 index 0000000000000..58ddfca3617ae --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/set_rel.rs @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::logical_expr::Union; +use substrait::proto::rel::RelType; +use substrait::proto::{set_rel, Rel, SetRel}; + +pub fn from_union( + producer: &mut impl SubstraitProducer, + union: &Union, +) -> datafusion::common::Result> { + let input_rels = union + .inputs + .iter() + .map(|input| producer.handle_plan(input.as_ref())) + .collect::>>()? + .into_iter() + .map(|ptr| *ptr) + .collect(); + Ok(Box::new(Rel { + rel_type: Some(RelType::Set(SetRel { + common: None, + inputs: input_rels, + op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL + advanced_extension: None, + })), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/sort_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/sort_rel.rs new file mode 100644 index 0000000000000..aaa8be1635600 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/sort_rel.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{substrait_sort_field, SubstraitProducer}; +use crate::variation_const::DEFAULT_TYPE_VARIATION_REF; +use datafusion::logical_expr::Sort; +use substrait::proto::expression::literal::LiteralType; +use substrait::proto::expression::{Literal, RexType}; +use substrait::proto::rel::RelType; +use substrait::proto::{fetch_rel, Expression, FetchRel, Rel, SortRel}; + +pub fn from_sort( + producer: &mut impl SubstraitProducer, + sort: &Sort, +) -> datafusion::common::Result> { + let Sort { expr, input, fetch } = sort; + let sort_fields = expr + .iter() + .map(|e| substrait_sort_field(producer, e, input.schema())) + .collect::>>()?; + + let input = producer.handle_plan(input.as_ref())?; + + let sort_rel = Box::new(Rel { + rel_type: Some(RelType::Sort(Box::new(SortRel { + common: None, + input: Some(input), + sorts: sort_fields, + advanced_extension: None, + }))), + }); + + match fetch { + Some(amount) => { + let count_mode = + Some(fetch_rel::CountMode::CountExpr(Box::new(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: false, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::I64(*amount as i64)), + })), + }))); + Ok(Box::new(Rel { + rel_type: Some(RelType::Fetch(Box::new(FetchRel { + common: None, + input: Some(sort_rel), + offset_mode: None, + count_mode, + advanced_extension: None, + }))), + })) + } + None => Ok(sort_rel), + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs new file mode 100644 index 0000000000000..56edfac5769cf --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -0,0 +1,411 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::extensions::Extensions; +use crate::logical_plan::producer::{ + from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, + from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter, + from_in_list, from_in_subquery, from_join, from_like, from_limit, from_literal, + from_projection, from_repartition, from_scalar_function, from_sort, + from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, + from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex, +}; +use datafusion::common::{substrait_err, Column, DFSchemaRef, ScalarValue}; +use datafusion::execution::registry::SerializerRegistry; +use datafusion::execution::SessionState; +use datafusion::logical_expr::expr::{Alias, InList, InSubquery, WindowFunction}; +use datafusion::logical_expr::{ + expr, Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, + Extension, Filter, Join, Like, Limit, LogicalPlan, Projection, Repartition, Sort, + SubqueryAlias, TableScan, TryCast, Union, Values, Window, +}; +use pbjson_types::Any as ProtoAny; +use substrait::proto::aggregate_rel::Measure; +use substrait::proto::rel::RelType; +use substrait::proto::{ + Expression, ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, Rel, +}; + +/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully +/// customizable Substrait serde. +/// +/// # Example Usage +/// +/// ``` +/// # use std::sync::Arc; +/// # use substrait::proto::{Expression, Rel}; +/// # use substrait::proto::rel::RelType; +/// # use datafusion::common::DFSchemaRef; +/// # use datafusion::error::Result; +/// # use datafusion::execution::SessionState; +/// # use datafusion::logical_expr::{Between, Extension, Projection}; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer}; +/// +/// struct CustomSubstraitProducer { +/// extensions: Extensions, +/// state: Arc, +/// } +/// +/// impl SubstraitProducer for CustomSubstraitProducer { +/// +/// fn register_function(&mut self, signature: String) -> u32 { +/// self.extensions.register_function(signature) +/// } +/// +/// fn get_extensions(self) -> Extensions { +/// self.extensions +/// } +/// +/// // You can set additional metadata on the Rels you produce +/// fn handle_projection(&mut self, plan: &Projection) -> Result> { +/// let mut rel = from_projection(self, plan)?; +/// match rel.rel_type { +/// Some(RelType::Project(mut project)) => { +/// let mut project = project.clone(); +/// // set common metadata or advanced extension +/// project.common = None; +/// project.advanced_extension = None; +/// Ok(Box::new(Rel { +/// rel_type: Some(RelType::Project(project)), +/// })) +/// } +/// rel_type => Ok(Box::new(Rel { rel_type })), +/// } +/// } +/// +/// // You can tweak how you convert expressions for your target system +/// fn handle_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result { +/// // add your own encoding for Between +/// todo!() +/// } +/// +/// // You can fully control how you convert UserDefinedLogicalNodes into Substrait +/// fn handle_extension(&mut self, _plan: &Extension) -> Result> { +/// // implement your own serializer into Substrait +/// todo!() +/// } +/// } +/// ``` +pub trait SubstraitProducer: Send + Sync + Sized { + /// Within a Substrait plan, functions are referenced using function anchors that are stored at + /// the top level of the [Plan](substrait::proto::Plan) within + /// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction) + /// messages. + /// + /// When given a function signature, this method should return the existing anchor for it if + /// there is one. Otherwise, it should generate a new anchor. + fn register_function(&mut self, signature: String) -> u32; + + /// Consume the producer to generate the [Extensions] for the Substrait plan based on the + /// functions that have been registered + fn get_extensions(self) -> Extensions; + + // Logical Plan Methods + // There is one method per LogicalPlan to allow for easy overriding of producer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + fn handle_plan( + &mut self, + plan: &LogicalPlan, + ) -> datafusion::common::Result> { + to_substrait_rel(self, plan) + } + + fn handle_projection( + &mut self, + plan: &Projection, + ) -> datafusion::common::Result> { + from_projection(self, plan) + } + + fn handle_filter(&mut self, plan: &Filter) -> datafusion::common::Result> { + from_filter(self, plan) + } + + fn handle_window(&mut self, plan: &Window) -> datafusion::common::Result> { + from_window(self, plan) + } + + fn handle_aggregate( + &mut self, + plan: &Aggregate, + ) -> datafusion::common::Result> { + from_aggregate(self, plan) + } + + fn handle_sort(&mut self, plan: &Sort) -> datafusion::common::Result> { + from_sort(self, plan) + } + + fn handle_join(&mut self, plan: &Join) -> datafusion::common::Result> { + from_join(self, plan) + } + + fn handle_repartition( + &mut self, + plan: &Repartition, + ) -> datafusion::common::Result> { + from_repartition(self, plan) + } + + fn handle_union(&mut self, plan: &Union) -> datafusion::common::Result> { + from_union(self, plan) + } + + fn handle_table_scan( + &mut self, + plan: &TableScan, + ) -> datafusion::common::Result> { + from_table_scan(self, plan) + } + + fn handle_empty_relation( + &mut self, + plan: &EmptyRelation, + ) -> datafusion::common::Result> { + from_empty_relation(plan) + } + + fn handle_subquery_alias( + &mut self, + plan: &SubqueryAlias, + ) -> datafusion::common::Result> { + from_subquery_alias(self, plan) + } + + fn handle_limit(&mut self, plan: &Limit) -> datafusion::common::Result> { + from_limit(self, plan) + } + + fn handle_values(&mut self, plan: &Values) -> datafusion::common::Result> { + from_values(self, plan) + } + + fn handle_distinct( + &mut self, + plan: &Distinct, + ) -> datafusion::common::Result> { + from_distinct(self, plan) + } + + fn handle_extension( + &mut self, + _plan: &Extension, + ) -> datafusion::common::Result> { + substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait") + } + + // Expression Methods + // There is one method per DataFusion Expr to allow for easy overriding of producer behaviour + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + fn handle_expr( + &mut self, + expr: &Expr, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + to_substrait_rex(self, expr, schema) + } + + fn handle_alias( + &mut self, + alias: &Alias, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_alias(self, alias, schema) + } + + fn handle_column( + &mut self, + column: &Column, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_column(column, schema) + } + + fn handle_literal( + &mut self, + value: &ScalarValue, + ) -> datafusion::common::Result { + from_literal(self, value) + } + + fn handle_binary_expr( + &mut self, + expr: &BinaryExpr, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_binary_expr(self, expr, schema) + } + + fn handle_like( + &mut self, + like: &Like, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_like(self, like, schema) + } + + /// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative + fn handle_unary_expr( + &mut self, + expr: &Expr, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_unary_expr(self, expr, schema) + } + + fn handle_between( + &mut self, + between: &Between, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_between(self, between, schema) + } + + fn handle_case( + &mut self, + case: &Case, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_case(self, case, schema) + } + + fn handle_cast( + &mut self, + cast: &Cast, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_cast(self, cast, schema) + } + + fn handle_try_cast( + &mut self, + cast: &TryCast, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_try_cast(self, cast, schema) + } + + fn handle_scalar_function( + &mut self, + scalar_fn: &expr::ScalarFunction, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_scalar_function(self, scalar_fn, schema) + } + + fn handle_aggregate_function( + &mut self, + agg_fn: &expr::AggregateFunction, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_aggregate_function(self, agg_fn, schema) + } + + fn handle_window_function( + &mut self, + window_fn: &WindowFunction, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_window_function(self, window_fn, schema) + } + + fn handle_in_list( + &mut self, + in_list: &InList, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_in_list(self, in_list, schema) + } + + fn handle_in_subquery( + &mut self, + in_subquery: &InSubquery, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_in_subquery(self, in_subquery, schema) + } +} + +pub struct DefaultSubstraitProducer<'a> { + extensions: Extensions, + serializer_registry: &'a dyn SerializerRegistry, +} + +impl<'a> DefaultSubstraitProducer<'a> { + pub fn new(state: &'a SessionState) -> Self { + DefaultSubstraitProducer { + extensions: Extensions::default(), + serializer_registry: state.serializer_registry().as_ref(), + } + } +} + +impl SubstraitProducer for DefaultSubstraitProducer<'_> { + fn register_function(&mut self, fn_name: String) -> u32 { + self.extensions.register_function(fn_name) + } + + fn get_extensions(self) -> Extensions { + self.extensions + } + + fn handle_extension( + &mut self, + plan: &Extension, + ) -> datafusion::common::Result> { + let extension_bytes = self + .serializer_registry + .serialize_logical_plan(plan.node.as_ref())?; + let detail = ProtoAny { + type_url: plan.node.name().to_string(), + value: extension_bytes.into(), + }; + let mut inputs_rel = plan + .node + .inputs() + .into_iter() + .map(|plan| self.handle_plan(plan)) + .collect::>>()?; + let rel_type = match inputs_rel.len() { + 0 => RelType::ExtensionLeaf(ExtensionLeafRel { + common: None, + detail: Some(detail), + }), + 1 => RelType::ExtensionSingle(Box::new(ExtensionSingleRel { + common: None, + detail: Some(detail), + input: Some(inputs_rel.pop().unwrap()), + })), + _ => RelType::ExtensionMulti(ExtensionMultiRel { + common: None, + detail: Some(detail), + inputs: inputs_rel.into_iter().map(|r| *r).collect(), + }), + }; + Ok(Box::new(Rel { + rel_type: Some(rel_type), + })) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs new file mode 100644 index 0000000000000..3da9269c5b9e3 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -0,0 +1,490 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::to_substrait_precision; +use crate::logical_plan::producer::utils::flatten_names; +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF, + DEFAULT_MAP_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + DICTIONARY_MAP_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, TIME_32_TYPE_VARIATION_REF, + TIME_64_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; +use datafusion::arrow::datatypes::{DataType, IntervalUnit}; +use datafusion::common::{internal_err, not_impl_err, plan_err, DFSchemaRef}; +use substrait::proto::{r#type, NamedStruct}; + +pub(crate) fn to_substrait_type( + dt: &DataType, + nullable: bool, +) -> datafusion::common::Result { + let nullability = if nullable { + r#type::Nullability::Nullable as i32 + } else { + r#type::Nullability::Required as i32 + }; + match dt { + DataType::Null => internal_err!("Null cast is not valid"), + DataType::Boolean => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Bool(r#type::Boolean { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Int8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I8(r#type::I8 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::UInt8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I8(r#type::I8 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Int16 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I16(r#type::I16 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::UInt16 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I16(r#type::I16 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Int32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I32(r#type::I32 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::UInt32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I32(r#type::I32 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Int64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I64(r#type::I64 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::UInt64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I64(r#type::I64 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, + })), + }), + // Float16 is not supported in Substrait + DataType::Float32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Fp32(r#type::Fp32 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Float64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Fp64(r#type::Fp64 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Timestamp(unit, tz) => { + let precision = to_substrait_precision(unit); + let kind = match tz { + None => r#type::Kind::PrecisionTimestamp(r#type::PrecisionTimestamp { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision, + }), + Some(_) => { + // If timezone is present, no matter what the actual tz value is, it indicates the + // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. + // As the timezone is lost, this conversion may be lossy for downstream use of the value. + r#type::Kind::PrecisionTimestampTz(r#type::PrecisionTimestampTz { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision, + }) + } + }; + Ok(substrait::proto::Type { kind: Some(kind) }) + } + DataType::Time32(unit) => { + let precision = to_substrait_precision(unit); + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::PrecisionTime(r#type::PrecisionTime { + precision, + type_variation_reference: TIME_32_TYPE_VARIATION_REF, + nullability, + })), + }) + } + DataType::Time64(unit) => { + let precision = to_substrait_precision(unit); + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::PrecisionTime(r#type::PrecisionTime { + precision, + type_variation_reference: TIME_64_TYPE_VARIATION_REF, + nullability, + })), + }) + } + DataType::Date32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Date(r#type::Date { + type_variation_reference: DATE_32_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Date64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Date(r#type::Date { + type_variation_reference: DATE_64_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Interval(interval_unit) => { + match interval_unit { + IntervalUnit::YearMonth => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalYear(r#type::IntervalYear { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + IntervalUnit::DayTime => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { + type_variation_reference: DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF, + nullability, + precision: Some(3), // DayTime precision is always milliseconds + })), + }), + IntervalUnit::MonthDayNano => { + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalCompound( + r#type::IntervalCompound { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision: 9, // nanos + }, + )), + }) + } + } + } + DataType::Duration(duration_unit) => { + let precision = to_substrait_precision(duration_unit); + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { + type_variation_reference: DURATION_INTERVAL_DAY_TYPE_VARIATION_REF, + nullability, + precision: Some(precision), + })), + }) + } + DataType::Binary => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary { + length: *length, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::LargeBinary => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::BinaryView => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Utf8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::LargeUtf8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Utf8View => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::List(inner) => { + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::List(Box::new(r#type::List { + r#type: Some(Box::new(inner_type)), + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + }))), + }) + } + DataType::LargeList(inner) => { + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::List(Box::new(r#type::List { + r#type: Some(Box::new(inner_type)), + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + }))), + }) + } + DataType::Map(inner, _) => match inner.data_type() { + DataType::Struct(key_and_value) if key_and_value.len() == 2 => { + let key_type = to_substrait_type( + key_and_value[0].data_type(), + key_and_value[0].is_nullable(), + )?; + let value_type = to_substrait_type( + key_and_value[1].data_type(), + key_and_value[1].is_nullable(), + )?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Map(Box::new(r#type::Map { + key: Some(Box::new(key_type)), + value: Some(Box::new(value_type)), + type_variation_reference: DEFAULT_MAP_TYPE_VARIATION_REF, + nullability, + }))), + }) + } + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + }, + DataType::Dictionary(key_type, value_type) => { + let key_type = to_substrait_type(key_type, nullable)?; + let value_type = to_substrait_type(value_type, nullable)?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Map(Box::new(r#type::Map { + key: Some(Box::new(key_type)), + value: Some(Box::new(value_type)), + type_variation_reference: DICTIONARY_MAP_TYPE_VARIATION_REF, + nullability, + }))), + }) + } + DataType::Struct(fields) => { + let field_types = fields + .iter() + .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) + .collect::>>()?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Struct(r#type::Struct { + types: field_types, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }) + } + DataType::Decimal128(p, s) => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Decimal(r#type::Decimal { + type_variation_reference: DECIMAL_128_TYPE_VARIATION_REF, + nullability, + scale: *s as i32, + precision: *p as i32, + })), + }), + DataType::Decimal256(p, s) => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Decimal(r#type::Decimal { + type_variation_reference: DECIMAL_256_TYPE_VARIATION_REF, + nullability, + scale: *s as i32, + precision: *p as i32, + })), + }), + _ => not_impl_err!("Unsupported cast type: {dt}"), + } +} + +pub(crate) fn to_substrait_named_struct( + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let mut names = Vec::with_capacity(schema.fields().len()); + for field in schema.fields() { + flatten_names(field, false, &mut names)?; + } + + let field_types = r#type::Struct { + types: schema + .fields() + .iter() + .map(|f| to_substrait_type(f.data_type(), f.is_nullable())) + .collect::>()?, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability: r#type::Nullability::Required as i32, + }; + + Ok(NamedStruct { + names, + r#struct: Some(field_types), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::tests::test_consumer; + use crate::logical_plan::consumer::{ + from_substrait_named_struct, from_substrait_type_without_names, + }; + use datafusion::arrow::datatypes::{Field, Fields, Schema, TimeUnit}; + use datafusion::common::{DFSchema, Result}; + use std::sync::Arc; + + #[test] + fn round_trip_types() -> Result<()> { + round_trip_type(DataType::Boolean)?; + round_trip_type(DataType::Int8)?; + round_trip_type(DataType::UInt8)?; + round_trip_type(DataType::Int16)?; + round_trip_type(DataType::UInt16)?; + round_trip_type(DataType::Int32)?; + round_trip_type(DataType::UInt32)?; + round_trip_type(DataType::Int64)?; + round_trip_type(DataType::UInt64)?; + round_trip_type(DataType::Float32)?; + round_trip_type(DataType::Float64)?; + + for tz in [None, Some("UTC".into())] { + round_trip_type(DataType::Timestamp(TimeUnit::Second, tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Millisecond, tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Microsecond, tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Nanosecond, tz))?; + } + + round_trip_type(DataType::Time32(TimeUnit::Second))?; + round_trip_type(DataType::Time32(TimeUnit::Millisecond))?; + round_trip_type(DataType::Time64(TimeUnit::Microsecond))?; + round_trip_type(DataType::Time64(TimeUnit::Nanosecond))?; + round_trip_type(DataType::Date32)?; + round_trip_type(DataType::Date64)?; + round_trip_type(DataType::Binary)?; + round_trip_type(DataType::FixedSizeBinary(10))?; + round_trip_type(DataType::LargeBinary)?; + round_trip_type(DataType::BinaryView)?; + round_trip_type(DataType::Utf8)?; + round_trip_type(DataType::LargeUtf8)?; + round_trip_type(DataType::Utf8View)?; + round_trip_type(DataType::Decimal128(10, 2))?; + round_trip_type(DataType::Decimal256(30, 2))?; + + round_trip_type(DataType::List( + Field::new_list_field(DataType::Int32, true).into(), + ))?; + round_trip_type(DataType::LargeList( + Field::new_list_field(DataType::Int32, true).into(), + ))?; + + round_trip_type(DataType::Map( + Field::new_struct( + "entries", + [ + Field::new("key", DataType::Utf8, false).into(), + Field::new("value", DataType::Int32, true).into(), + ], + false, + ) + .into(), + false, + ))?; + round_trip_type(DataType::Dictionary( + Box::new(DataType::Utf8), + Box::new(DataType::Int32), + ))?; + + round_trip_type(DataType::Struct( + vec![ + Field::new("c0", DataType::Int32, true), + Field::new("c1", DataType::Utf8, true), + ] + .into(), + ))?; + + round_trip_type(DataType::Interval(IntervalUnit::YearMonth))?; + round_trip_type(DataType::Interval(IntervalUnit::MonthDayNano))?; + round_trip_type(DataType::Interval(IntervalUnit::DayTime))?; + + round_trip_type(DataType::Duration(TimeUnit::Second))?; + round_trip_type(DataType::Duration(TimeUnit::Millisecond))?; + round_trip_type(DataType::Duration(TimeUnit::Microsecond))?; + round_trip_type(DataType::Duration(TimeUnit::Nanosecond))?; + + Ok(()) + } + + fn round_trip_type(dt: DataType) -> Result<()> { + println!("Checking round trip of {dt}"); + + // As DataFusion doesn't consider nullability as a property of the type, but field, + // it doesn't matter if we set nullability to true or false here. + let substrait = to_substrait_type(&dt, true)?; + let consumer = test_consumer(); + let roundtrip_dt = from_substrait_type_without_names(&consumer, &substrait)?; + assert_eq!(dt, roundtrip_dt); + Ok(()) + } + + #[test] + fn named_struct_names() -> Result<()> { + let schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ + Field::new("int", DataType::Int32, true), + Field::new( + "struct", + DataType::Struct(Fields::from(vec![Field::new( + "inner", + DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, true))), + true, + )])), + true, + ), + Field::new("trailer", DataType::Float64, true), + ]))?); + + let named_struct = to_substrait_named_struct(&schema)?; + + // Struct field names should be flattened DFS style + // List field names should be omitted + assert_eq!( + named_struct.names, + vec!["int", "struct", "inner", "trailer"] + ); + + let roundtrip_schema = + from_substrait_named_struct(&test_consumer(), &named_struct)?; + assert_eq!(schema.as_ref(), &roundtrip_schema); + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/utils.rs b/datafusion/substrait/src/logical_plan/producer/utils.rs new file mode 100644 index 0000000000000..9f96b88d084fe --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/utils.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; +use datafusion::common::{plan_err, DFSchemaRef}; +use datafusion::logical_expr::SortExpr; +use substrait::proto::sort_field::{SortDirection, SortKind}; +use substrait::proto::SortField; + +// Substrait wants a list of all field names, including nested fields from structs, +// also from within e.g. lists and maps. However, it does not want the list and map field names +// themselves - only proper structs fields are considered to have useful names. +pub(crate) fn flatten_names( + field: &Field, + skip_self: bool, + names: &mut Vec, +) -> datafusion::common::Result<()> { + if !skip_self { + names.push(field.name().to_string()); + } + match field.data_type() { + DataType::Struct(fields) => { + for field in fields { + flatten_names(field, false, names)?; + } + Ok(()) + } + DataType::List(l) => flatten_names(l, true, names), + DataType::LargeList(l) => flatten_names(l, true, names), + DataType::Map(m, _) => match m.data_type() { + DataType::Struct(key_and_value) if key_and_value.len() == 2 => { + flatten_names(&key_and_value[0], true, names)?; + flatten_names(&key_and_value[1], true, names) + } + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + }, + _ => Ok(()), + }?; + Ok(()) +} + +pub(crate) fn substrait_sort_field( + producer: &mut impl SubstraitProducer, + sort: &SortExpr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let SortExpr { + expr, + asc, + nulls_first, + } = sort; + let e = producer.handle_expr(expr, schema)?; + let d = match (asc, nulls_first) { + (true, true) => SortDirection::AscNullsFirst, + (true, false) => SortDirection::AscNullsLast, + (false, true) => SortDirection::DescNullsFirst, + (false, false) => SortDirection::DescNullsLast, + }; + Ok(SortField { + expr: Some(e), + sort_kind: Some(SortKind::Direction(d as i32)), + }) +} + +pub(crate) fn to_substrait_precision(time_unit: &TimeUnit) -> i32 { + match time_unit { + TimeUnit::Second => 0, + TimeUnit::Millisecond => 3, + TimeUnit::Microsecond => 6, + TimeUnit::Nanosecond => 9, + } +} diff --git a/datafusion/substrait/src/physical_plan/consumer.rs b/datafusion/substrait/src/physical_plan/consumer.rs index 4990054ac7fc7..ecf465dd3f18d 100644 --- a/datafusion/substrait/src/physical_plan/consumer.rs +++ b/datafusion/substrait/src/physical_plan/consumer.rs @@ -166,7 +166,7 @@ pub async fn from_substrait_rel( ), } } - _ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type), + _ => not_impl_err!("Unsupported Reltype: {:?}", rel.rel_type), } } diff --git a/datafusion/substrait/src/physical_plan/producer.rs b/datafusion/substrait/src/physical_plan/producer.rs index 9ba0e0c964e9e..cb725a7277fd3 100644 --- a/datafusion/substrait/src/physical_plan/producer.rs +++ b/datafusion/substrait/src/physical_plan/producer.rs @@ -61,7 +61,7 @@ pub fn to_substrait_rel( substrait_files.push(FileOrFiles { partition_index: partition_index.try_into().unwrap(), start: 0, - length: file.object_meta.size as u64, + length: file.object_meta.size, path_type: Some(PathType::UriPath( file.object_meta.location.as_ref().to_string(), )), diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs index e5bebf8e11819..f78b3d785303c 100644 --- a/datafusion/substrait/src/variation_const.rs +++ b/datafusion/substrait/src/variation_const.rs @@ -50,11 +50,24 @@ pub const TIMESTAMP_NANO_TYPE_VARIATION_REF: u32 = 3; pub const DATE_32_TYPE_VARIATION_REF: u32 = 0; pub const DATE_64_TYPE_VARIATION_REF: u32 = 1; +pub const TIME_32_TYPE_VARIATION_REF: u32 = 0; +pub const TIME_64_TYPE_VARIATION_REF: u32 = 1; pub const DEFAULT_CONTAINER_TYPE_VARIATION_REF: u32 = 0; pub const LARGE_CONTAINER_TYPE_VARIATION_REF: u32 = 1; pub const VIEW_CONTAINER_TYPE_VARIATION_REF: u32 = 2; +pub const DEFAULT_MAP_TYPE_VARIATION_REF: u32 = 0; +pub const DICTIONARY_MAP_TYPE_VARIATION_REF: u32 = 1; pub const DECIMAL_128_TYPE_VARIATION_REF: u32 = 0; pub const DECIMAL_256_TYPE_VARIATION_REF: u32 = 1; +/// Used for the arrow type [`DataType::Interval`] with [`IntervalUnit::DayTime`]. +/// +/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval +/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime +pub const DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF: u32 = 0; +/// Used for the arrow type [`DataType::Duration`]. +/// +/// [`DataType::Duration`]: datafusion::arrow::datatypes::DataType::Duration +pub const DURATION_INTERVAL_DAY_TYPE_VARIATION_REF: u32 = 1; // For [user-defined types](https://substrait.io/types/type_classes/#user-defined-types). /// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`]. @@ -96,7 +109,7 @@ pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2; /// [`ScalarValue::IntervalMonthDayNano`]: datafusion::common::ScalarValue::IntervalMonthDayNano #[deprecated( since = "41.0.0", - note = "Use Substrait `IntervalCompund` type instead" + note = "Use Substrait `IntervalCompound` type instead" )] pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3; @@ -106,6 +119,6 @@ pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3; /// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano #[deprecated( since = "43.0.0", - note = "Use Substrait `IntervalCompund` type instead" + note = "Use Substrait `IntervalCompound` type instead" )] pub const INTERVAL_MONTH_DAY_NANO_TYPE_NAME: &str = "interval-month-day-nano"; diff --git a/datafusion/substrait/tests/cases/builtin_expr_semantics_tests.rs b/datafusion/substrait/tests/cases/builtin_expr_semantics_tests.rs new file mode 100644 index 0000000000000..c7ca669b27c84 --- /dev/null +++ b/datafusion/substrait/tests/cases/builtin_expr_semantics_tests.rs @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! There are some Substrait functions that are semantically equivalent to nested built-in expressions, such as xor:bool_bool and and_not:bool_bool +//! This module tests that the semantics of these functions are correct roundtripped + +#[cfg(test)] +mod tests { + use crate::utils::test::add_plan_schemas_to_ctx; + use datafusion::arrow::util::pretty; + use datafusion::common::Result; + use datafusion::prelude::DataFrame; + use datafusion::prelude::SessionContext; + use datafusion_substrait::logical_plan::consumer::from_substrait_plan; + use datafusion_substrait::logical_plan::producer::to_substrait_plan; + use std::fs::File; + use std::io::BufReader; + use substrait::proto::Plan; + + // Helper function to test scalar function semantics and roundtrip conversion + async fn test_scalar_fn_semantics( + file_path: &str, + expected_results: Vec<&str>, + ) -> Result<()> { + let path = format!("tests/testdata/test_plans/{file_path}"); + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; + let plan = from_substrait_plan(&ctx.state(), &proto).await?; + + // Test correct semantics of function + let df = DataFrame::new(ctx.state().clone(), plan.clone()); + let results = df.collect().await?; + let pretty_results = pretty::pretty_format_batches(&results)?.to_string(); + assert_eq!( + pretty_results.trim().lines().collect::>(), + expected_results + ); + + // Test roundtrip semantics + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; + let df2 = DataFrame::new(ctx.state().clone(), plan2.clone()); + let results2 = df2.collect().await?; + let pretty_results2 = pretty::pretty_format_batches(&results2)?.to_string(); + assert_eq!( + pretty_results2.trim().lines().collect::>(), + expected_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_xor_semantics() -> Result<()> { + let expected = vec![ + "+-------+-------+--------+", + "| a | b | result |", + "+-------+-------+--------+", + "| true | true | false |", + "| true | false | true |", + "| false | true | true |", + "| false | false | false |", + "+-------+-------+--------+", + ]; + + test_scalar_fn_semantics( + "scalar_fn_to_built_in_binary_expr_xor.substrait.json", + expected, + ) + .await + } + + #[tokio::test] + async fn test_and_not_semantics() -> Result<()> { + let expected = vec![ + "+-------+-------+--------+", + "| a | b | result |", + "+-------+-------+--------+", + "| true | true | false |", + "| true | false | true |", + "| false | true | false |", + "| false | false | false |", + "+-------+-------+--------+", + ]; + + test_scalar_fn_semantics( + "scalar_fn_to_built_in_binary_expr_and_not.substrait.json", + expected, + ) + .await + } + + #[tokio::test] + async fn test_logb_semantics() -> Result<()> { + let expected = vec![ + "+-------+------+--------+", + "| x | base | result |", + "+-------+------+--------+", + "| 1.0 | 10.0 | 0.0 |", + "| 100.0 | 10.0 | 2.0 |", + "+-------+------+--------+", + ]; + + test_scalar_fn_semantics("scalar_fn_logb_expr.substrait.json", expected).await + } +} diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index af9d92378298a..a92fc2957cae3 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -44,7 +44,7 @@ mod tests { let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; let plan = from_substrait_plan(&ctx.state(), &proto).await?; ctx.state().create_physical_plan(&plan).await?; - Ok(format!("{}", plan)) + Ok(format!("{plan}")) } #[tokio::test] @@ -501,7 +501,7 @@ mod tests { let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; let plan = from_substrait_plan(&ctx.state(), &proto).await?; ctx.state().create_physical_plan(&plan).await?; - Ok(format!("{}", plan)) + Ok(format!("{plan}")) } #[tokio::test] @@ -519,6 +519,120 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_expressions_in_virtual_table() -> Result<()> { + let plan_str = + test_plan_to_string("virtual_table_with_expressions.substrait.json").await?; + + assert_snapshot!( + plan_str, + @r#" + Projection: dummy1 AS result1, dummy2 AS result2 + Values: (Int64(0), Utf8("temp")), (Int64(1), Utf8("test")) + "# + ); + Ok(()) + } + + #[tokio::test] + //There are some Substrait functions that can be represented with nested built-in expressions + //xor:bool_bool is implemented in the consumer with binary expressions + //This tests that the consumer correctly builds the nested expressions for this function + async fn test_built_in_binary_exprs_for_xor() -> Result<()> { + let plan_str = + test_plan_to_string("scalar_fn_to_built_in_binary_expr_xor.substrait.json") + .await?; + + //Test correct plan structure + assert_snapshot!(plan_str, + @r#" + Projection: a, b, (a OR b) AND NOT a AND b AS result + Values: (Boolean(true), Boolean(true)), (Boolean(true), Boolean(false)), (Boolean(false), Boolean(true)), (Boolean(false), Boolean(false)) + "# + ); + + Ok(()) + } + + #[tokio::test] + //There are some Substrait functions that can be represented with nested built-in expressions + //and_not:bool_bool is implemented in the consumer as binary expressions + //This tests that the consumer correctly builds the nested expressions for this function + async fn test_built_in_binary_exprs_for_and_not() -> Result<()> { + let plan_str = test_plan_to_string( + "scalar_fn_to_built_in_binary_expr_and_not.substrait.json", + ) + .await?; + + //Test correct plan structure + assert_snapshot!(plan_str, + @r#" + Projection: a, b, a AND NOT b AS result + Values: (Boolean(true), Boolean(true)), (Boolean(true), Boolean(false)), (Boolean(false), Boolean(true)), (Boolean(false), Boolean(false)) + "# + ); + + Ok(()) + } + + //The between:any_any_any function is implemented as Expr::Between in the Substrait consumer + //This test tests that the consumer correctly builds the Expr::Between expression for this function + #[tokio::test] + async fn test_between_expr() -> Result<()> { + let plan_str = + test_plan_to_string("scalar_fn_to_between_expr.substrait.json").await?; + assert_snapshot!(plan_str, + @r#" + Projection: expr BETWEEN low AND high AS result + Values: (Int8(2), Int8(1), Int8(3)), (Int8(4), Int8(1), Int8(2)) + "# + ); + Ok(()) + } + + #[tokio::test] + async fn test_logb_expr() -> Result<()> { + let plan_str = test_plan_to_string("scalar_fn_logb_expr.substrait.json").await?; + assert_snapshot!(plan_str, + @r#" + Projection: x, base, log(base, x) AS result + Values: (Float32(1), Float32(10)), (Float32(100), Float32(10)) + "# + ); + Ok(()) + } + + #[tokio::test] + async fn test_multiple_joins() -> Result<()> { + let plan_str = test_plan_to_string("multiple_joins.json").await?; + assert_snapshot!( + plan_str, + @r#" + Projection: left.count(Int64(1)) AS count_first, left.category, left.count(Int64(1)):1 AS count_second, right.count(Int64(1)) AS count_third + Left Join: left.id = right.id + SubqueryAlias: left + Projection: left.id, left.count(Int64(1)), left.id:1, left.category, right.id AS id:2, right.count(Int64(1)) AS count(Int64(1)):1 + Left Join: left.id = right.id + SubqueryAlias: left + Projection: left.id, left.count(Int64(1)), right.id AS id:1, right.category + Left Join: left.id = right.id + SubqueryAlias: left + Aggregate: groupBy=[[id]], aggr=[[count(Int64(1))]] + Values: (Int64(1)), (Int64(2)) + SubqueryAlias: right + Aggregate: groupBy=[[id, category]], aggr=[[]] + Values: (Int64(1), Utf8("info")), (Int64(2), Utf8("low")) + SubqueryAlias: right + Aggregate: groupBy=[[id]], aggr=[[count(Int64(1))]] + Values: (Int64(1)), (Int64(2)) + SubqueryAlias: right + Aggregate: groupBy=[[id]], aggr=[[count(Int64(1))]] + Values: (Int64(1)), (Int64(2)) + "# + ); + Ok(()) + } + #[tokio::test] async fn test_select_window_count() -> Result<()> { let plan_str = test_plan_to_string("select_window_count.substrait.json").await?; @@ -533,4 +647,65 @@ mod tests { ); Ok(()) } + + #[tokio::test] + async fn test_multiple_unions() -> Result<()> { + let plan_str = test_plan_to_string("multiple_unions.json").await?; + + let mut settings = insta::Settings::clone_current(); + settings.add_filter( + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", + "[UUID]", + ); + settings.bind(|| { + assert_snapshot!( + plan_str, + @r#" + Projection: [UUID] AS product_category, [UUID] AS product_type, product_key + Union + Projection: Utf8("people") AS [UUID], Utf8("people") AS [UUID], sales.product_key + Left Join: sales.product_key = food.@food_id + TableScan: sales + TableScan: food + Union + Projection: people.$f3, people.$f5, people.product_key0 + Left Join: people.product_key0 = food.@food_id + TableScan: people + TableScan: food + TableScan: more_products + "# + ); + }); + + Ok(()) + } + + #[tokio::test] + async fn test_join_with_expression_key() -> Result<()> { + let plan_str = test_plan_to_string("join_with_expression_key.json").await?; + assert_snapshot!( + plan_str, + @r#" + Projection: left.index_name AS index, right.upper(host) AS host, left.max(size_bytes) AS idx_size, right.max(total_bytes) AS db_size, CAST(left.max(size_bytes) AS Float64) / CAST(right.max(total_bytes) AS Float64) * Float64(100) AS pct_of_db + Inner Join: left.upper(host) = right.upper(host) + SubqueryAlias: left + Aggregate: groupBy=[[index_name, upper(host)]], aggr=[[max(size_bytes)]] + Projection: size_bytes, index_name, upper(host) + Filter: index_name = Utf8("aaa") + Values: (Utf8("aaa"), Utf8("host-a"), Int64(128)), (Utf8("bbb"), Utf8("host-b"), Int64(256)) + SubqueryAlias: right + Aggregate: groupBy=[[upper(host)]], aggr=[[max(total_bytes)]] + Projection: total_bytes, upper(host) + Inner Join: Filter: upper(host) = upper(host) + Values: (Utf8("host-a"), Int64(107)), (Utf8("host-b"), Int64(214)) + Projection: upper(host) + Aggregate: groupBy=[[index_name, upper(host)]], aggr=[[max(size_bytes)]] + Projection: size_bytes, index_name, upper(host) + Filter: index_name = Utf8("aaa") + Values: (Utf8("aaa"), Utf8("host-a"), Int64(128)), (Utf8("bbb"), Utf8("host-b"), Int64(256)) + "# + ); + + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/emit_kind_tests.rs b/datafusion/substrait/tests/cases/emit_kind_tests.rs index 88db2bc34d7f6..e916b4cb0e1a9 100644 --- a/datafusion/substrait/tests/cases/emit_kind_tests.rs +++ b/datafusion/substrait/tests/cases/emit_kind_tests.rs @@ -126,8 +126,8 @@ mod tests { let plan1str = format!("{plan}"); let plan2str = format!("{plan2}"); - println!("{}", plan1str); - println!("{}", plan2str); + println!("{plan1str}"); + println!("{plan2str}"); assert_eq!(plan1str, plan2str); Ok(()) diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 4dd97193034bd..426f3c12e5a15 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -144,6 +144,47 @@ mod tests { Ok(()) } + #[tokio::test] + async fn null_literal_before_and_after_joins() -> Result<()> { + // Confirms that literals used before and after a join but for different columns + // are correctly handled. + + // File generated with substrait-java's Isthmus: + // ./isthmus-cli/build/graal/isthmus --create "create table A (a int); create table B (a int, c int); create table C (a int, d int)" "select t.*, C.d, CAST(NULL AS VARCHAR) as e from (select a, CAST(NULL AS VARCHAR) as c from A UNION ALL select a, c from B) t LEFT JOIN C ON t.a = C.a" + let proto_plan = + read_json("tests/testdata/test_plans/disambiguate_literals_with_same_name.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; + + let mut settings = insta::Settings::clone_current(); + settings.add_filter( + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", + "[UUID]", + ); + settings.bind(|| { + assert_snapshot!( + plan, + @r#" + Projection: left.A, left.[UUID] AS C, right.D, Utf8(NULL) AS [UUID] AS E + Left Join: left.A = right.A + SubqueryAlias: left + Union + Projection: A.A, Utf8(NULL) AS [UUID] + TableScan: A + Projection: B.A, CAST(B.C AS Utf8) + TableScan: B + SubqueryAlias: right + TableScan: C + "# + ); + }); + + // Trigger execution to ensure plan validity + DataFrame::new(ctx.state(), plan).show().await?; + + Ok(()) + } + #[tokio::test] async fn non_nullable_lists() -> Result<()> { // DataFusion's Substrait consumer treats all lists as nullable, even if the Substrait plan specifies them as non-nullable. diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs index 777246e4139bf..9e69bb4edd854 100644 --- a/datafusion/substrait/tests/cases/mod.rs +++ b/datafusion/substrait/tests/cases/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +mod builtin_expr_semantics_tests; mod consumer_integration; mod emit_kind_tests; mod function_test; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 36ee78fe5d9a7..39e4984ab9f79 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -92,6 +92,8 @@ impl PartialOrd for MockUserDefinedLogicalPlan { Some(Ordering::Equal) => self.inputs.partial_cmp(&other.inputs), cmp => cmp, } + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) } } @@ -112,11 +114,7 @@ impl UserDefinedLogicalNode for MockUserDefinedLogicalPlan { &self.empty_schema } - fn check_invariants( - &self, - _check: InvariantLevel, - _plan: &LogicalPlan, - ) -> Result<()> { + fn check_invariants(&self, _check: InvariantLevel) -> Result<()> { Ok(()) } @@ -348,7 +346,7 @@ async fn decimal_literal() -> Result<()> { #[tokio::test] async fn null_decimal_literal() -> Result<()> { - roundtrip("SELECT * FROM data WHERE b = NULL").await + roundtrip("SELECT *, CAST(NULL AS decimal(10, 2)) FROM data").await } #[tokio::test] @@ -426,6 +424,41 @@ async fn simple_scalar_function_substr() -> Result<()> { roundtrip("SELECT SUBSTR(f, 1, 3) FROM data").await } +// Test that DataFusion functions gets correctly mapped to Substrait names (when the names are different) +// Follows the same structure as existing roundtrip tests, but more explicitly tests for name mappings +async fn test_substrait_to_df_name_mapping( + substrait_name: &str, + sql: &str, +) -> Result<()> { + let ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.into_optimized_plan()?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + + let function_name = match proto.extensions[0].mapping_type.as_ref().unwrap() { + MappingType::ExtensionFunction(ext_f) => &ext_f.name, + _ => unreachable!("Expected function extension"), + }; + + assert_eq!(function_name, substrait_name); + + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + let plan1str = format!("{plan}"); + let plan2str = format!("{plan2}"); + assert_eq!(plan1str, plan2str); + + assert_eq!(plan.schema(), plan2.schema()); + + Ok(()) +} + +#[tokio::test] +async fn scalar_function_is_nan_mapping() -> Result<()> { + test_substrait_to_df_name_mapping("is_nan", "SELECT ISNAN(a) FROM data").await +} + #[tokio::test] async fn simple_scalar_function_is_null() -> Result<()> { roundtrip("SELECT * FROM data WHERE a IS NULL").await @@ -593,6 +626,66 @@ async fn roundtrip_exists_filter() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_not_exists_filter_left_anti_join() -> Result<()> { + let plan = generate_plan_from_sql( + "SELECT ba.isbn, ba.author FROM book_author ba WHERE NOT EXISTS (SELECT 1 FROM book b WHERE b.isbn = ba.isbn)", + false, + true, + ) + .await?; + + assert_snapshot!( + plan, + @r#" + LeftAnti Join: book_author.isbn = book.isbn + TableScan: book_author projection=[isbn, author] + TableScan: book projection=[isbn] + "# + ); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_right_anti_join() -> Result<()> { + let plan = generate_plan_from_sql( + "SELECT * FROM book b RIGHT ANTI JOIN book_author ba ON b.isbn = ba.isbn", + false, + true, + ) + .await?; + + assert_snapshot!( + plan, + @r#" + RightAnti Join: book.isbn = book_author.isbn + TableScan: book projection=[isbn] + TableScan: book_author projection=[isbn, author] + "# + ); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_right_semi_join() -> Result<()> { + let plan = generate_plan_from_sql( + "SELECT * FROM book b RIGHT SEMI JOIN book_author ba ON b.isbn = ba.isbn", + false, + true, + ) + .await?; + + assert_snapshot!( + plan, + @r#" + RightSemi Join: book.isbn = book_author.isbn + TableScan: book projection=[isbn] + TableScan: book_author projection=[isbn, author] + "# + ); + Ok(()) +} + #[tokio::test] async fn inner_join() -> Result<()> { let plan = generate_plan_from_sql( @@ -763,7 +856,7 @@ async fn simple_intersect() -> Result<()> { let expected_plan_str = format!( "Projection: count(Int64(1)) AS {syntax}\ \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ - \n Projection: \ + \n Projection:\ \n LeftSemi Join: data.a = data2.a\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ \n TableScan: data projection=[a]\ @@ -780,7 +873,7 @@ async fn simple_intersect() -> Result<()> { async fn check_constant(sql_syntax: &str, plan_expr: &str) -> Result<()> { let expected_plan_str = format!( "Aggregate: groupBy=[[]], aggr=[[{plan_expr}]]\ - \n Projection: \ + \n Projection:\ \n LeftSemi Join: data.a = data2.a\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ \n TableScan: data projection=[a]\ @@ -854,6 +947,22 @@ async fn aggregate_wo_projection_sorted_consume() -> Result<()> { Ok(()) } +#[tokio::test] +async fn aggregate_identical_grouping_expressions() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/aggregate_identical_grouping_expressions.substrait.json"); + + let plan = generate_plan_from_substrait(proto_plan).await?; + assert_snapshot!( + plan, + @r#" + Aggregate: groupBy=[[Int32(1) AS grouping_col_1, Int32(1) AS grouping_col_2]], aggr=[[]] + TableScan: data projection=[] + "# + ); + Ok(()) +} + #[tokio::test] async fn simple_intersect_consume() -> Result<()> { let proto_plan = read_json("tests/testdata/test_plans/intersect.substrait.json"); @@ -942,7 +1051,7 @@ async fn simple_intersect_table_reuse() -> Result<()> { let expected_plan_str = format!( "Projection: count(Int64(1)) AS {syntax}\ \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ - \n Projection: \ + \n Projection:\ \n LeftSemi Join: left.a = right.a\ \n SubqueryAlias: left\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ @@ -961,7 +1070,7 @@ async fn simple_intersect_table_reuse() -> Result<()> { async fn check_constant(sql_syntax: &str, plan_expr: &str) -> Result<()> { let expected_plan_str = format!( "Aggregate: groupBy=[[]], aggr=[[{plan_expr}]]\ - \n Projection: \ + \n Projection:\ \n LeftSemi Join: left.a = right.a\ \n SubqueryAlias: left\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ @@ -1061,7 +1170,7 @@ async fn roundtrip_literal_list() -> Result<()> { async fn roundtrip_literal_struct() -> Result<()> { let plan = generate_plan_from_sql( "SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data", - false, + true, true, ) .await?; @@ -1076,6 +1185,46 @@ async fn roundtrip_literal_struct() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_literal_named_struct() -> Result<()> { + let plan = generate_plan_from_sql( + "SELECT STRUCT(1 as int_field, true as boolean_field, CAST(NULL AS STRING) as string_field) FROM data", + true, + true, + ) + .await?; + + assert_snapshot!( + plan, + @r#" + Projection: Struct({int_field:1,boolean_field:true,string_field:}) AS named_struct(Utf8("int_field"),Int64(1),Utf8("boolean_field"),Boolean(true),Utf8("string_field"),NULL) + TableScan: data projection=[] + "# + ); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_literal_renamed_struct() -> Result<()> { + // This test aims to hit a case where the struct column itself has the expected name, but its + // inner field needs to be renamed. + let plan = generate_plan_from_sql( + "SELECT CAST((STRUCT(1)) AS Struct<\"int_field\"Int>) AS 'Struct({c0:1})' FROM data", + true, + true, + ) + .await?; + + assert_snapshot!( + plan, + @r#" + Projection: Struct({int_field:1}) AS Struct({c0:1}) + TableScan: data projection=[] + "# + ); + Ok(()) +} + #[tokio::test] async fn roundtrip_values() -> Result<()> { // TODO: would be nice to have a struct inside the LargeList, but arrow_cast doesn't support that currently @@ -1386,7 +1535,7 @@ fn check_post_join_filters(rel: &Rel) -> Result<()> { } Some(RelType::ExtensionLeaf(_)) | Some(RelType::Read(_)) => Ok(()), _ => not_impl_err!( - "Unsupported RelType: {:?} in post join filter check", + "Unsupported Reltype: {:?} in post join filter check", rel.rel_type ), } @@ -1662,6 +1811,34 @@ async fn create_context() -> Result { ctx.register_csv("data2", "tests/testdata/data.csv", CsvReadOptions::new()) .await?; + // Register test tables for anti join tests + let book_fields = vec![ + Field::new("isbn", DataType::Int64, false), + Field::new("title", DataType::Utf8, true), + Field::new("genre", DataType::Utf8, true), + ]; + let book_schema = Schema::new(book_fields); + let mut book_options = CsvReadOptions::new(); + book_options.schema = Some(&book_schema); + book_options.has_header = false; + ctx.register_csv("book", "tests/testdata/empty.csv", book_options) + .await?; + + let book_author_fields = vec![ + Field::new("isbn", DataType::Int64, true), + Field::new("author", DataType::Utf8, true), + ]; + let book_author_schema = Schema::new(book_author_fields); + let mut book_author_options = CsvReadOptions::new(); + book_author_options.schema = Some(&book_author_schema); + book_author_options.has_header = false; + ctx.register_csv( + "book_author", + "tests/testdata/empty.csv", + book_author_options, + ) + .await?; + Ok(ctx) } diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs index a31b3ca385e9c..c8cc3fe9940ce 100644 --- a/datafusion/substrait/tests/cases/substrait_validations.rs +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -51,7 +51,7 @@ mod tests { let ctx = SessionContext::new(); ctx.register_table( table_ref, - Arc::new(EmptyTable::new(df_schema.inner().clone())), + Arc::new(EmptyTable::new(Arc::clone(df_schema.inner()))), )?; Ok(ctx) } diff --git a/datafusion/substrait/tests/testdata/test_plans/aggregate_identical_grouping_expressions.substrait.json b/datafusion/substrait/tests/testdata/test_plans/aggregate_identical_grouping_expressions.substrait.json new file mode 100644 index 0000000000000..15c0b0505fa68 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/aggregate_identical_grouping_expressions.substrait.json @@ -0,0 +1,53 @@ +{ + "extensionUris": [], + "extensions": [], + "relations": [ + { + "root": { + "input": { + "aggregate": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [], + "struct": { + "types": [], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": ["data"] + } + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "literal": { + "i32": 1 + } + }, + { + "literal": { + "i32": 1 + } + } + ] + } + ], + "measures": [] + } + }, + "names": ["grouping_col_1", "grouping_col_2"] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "manual" + } +} diff --git a/datafusion/substrait/tests/testdata/test_plans/disambiguate_literals_with_same_name.substrait.json b/datafusion/substrait/tests/testdata/test_plans/disambiguate_literals_with_same_name.substrait.json new file mode 100644 index 0000000000000..d72830898f913 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/disambiguate_literals_with_same_name.substrait.json @@ -0,0 +1,287 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "equal:any_any" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [4, 5, 6, 7] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "set": { + "common": { + "direct": { + } + }, + "inputs": [{ + "project": { + "common": { + "emit": { + "outputMapping": [1, 2] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["A"], + "struct": { + "types": [{ + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["A"] + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "literal": { + "null": { + "string": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "nullable": false, + "typeVariationReference": 0 + } + }] + } + }, { + "project": { + "common": { + "emit": { + "outputMapping": [2, 3] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["A", "C"], + "struct": { + "types": [{ + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["B"] + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "cast": { + "type": { + "string": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + }] + } + }], + "op": "SET_OP_UNION_ALL" + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["A", "D"], + "struct": { + "types": [{ + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["C"] + } + } + }, + "expression": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, { + "literal": { + "null": { + "string": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "nullable": false, + "typeVariationReference": 0 + } + }] + } + }, + "names": ["A", "C", "D", "E"] + } + }], + "expectedTypeUrls": [], + "version": { + "majorNumber": 0, + "minorNumber": 74, + "patchNumber": 0, + "gitHash": "", + "producer": "isthmus" + }, + "parameterBindings": [] +} diff --git a/datafusion/substrait/tests/testdata/test_plans/join_with_expression_key.json b/datafusion/substrait/tests/testdata/test_plans/join_with_expression_key.json new file mode 100644 index 0000000000000..73fa06eea5f05 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/join_with_expression_key.json @@ -0,0 +1,814 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 3, + "uri": "/functions_arithmetic.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_string.yaml" + }, { + "extensionUriAnchor": 1, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "upper:str" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "max:i64" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "multiply:fp64_fp64" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 4, + "name": "divide:fp64_fp64" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [5, 6, 7, 8, 9] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [3, 4, 5] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["index_name", "host", "size_bytes"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [{ + "fields": [{ + "string": "aaa", + "nullable": true + }, { + "string": "host-a", + "nullable": true + }, { + "i64": "128", + "nullable": true + }] + }, { + "fields": [{ + "string": "bbb", + "nullable": true + }, { + "string": "host-b", + "nullable": true + }, { + "i64": "256", + "nullable": true + }] + }] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "aaa" + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 2, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "right": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [3, 4] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["host", "total_bytes"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [{ + "fields": [{ + "string": "host-a", + "nullable": true + }, { + "i64": "107", + "nullable": true + }] + }, { + "fields": [{ + "string": "host-b", + "nullable": true + }, { + "i64": "214", + "nullable": true + }] + }] + } + } + }, + "right": { + "project": { + "common": { + "emit": { + "outputMapping": [3] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [3, 4, 5] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["index_name", "host", "size_bytes"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [{ + "fields": [{ + "string": "aaa", + "nullable": true + }, { + "string": "host-a", + "nullable": true + }, { + "i64": "128", + "nullable": true + }] + }, { + "fields": [{ + "string": "bbb", + "nullable": true + }, { + "string": "host-b", + "nullable": true + }, { + "i64": "256", + "nullable": true + }] + }] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "aaa" + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 2, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }] + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 2, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "cast": { + "type": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "literal": { + "fp64": 100.0 + } + } + }] + } + }] + } + }, + "names": ["index", "host", "idx_size", "db_size", "pct_of_db"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/multiple_joins.json b/datafusion/substrait/tests/testdata/test_plans/multiple_joins.json new file mode 100644 index 0000000000000..e88cce648da7c --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/multiple_joins.json @@ -0,0 +1,536 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_aggregate_generic.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "count:" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [8, 9, 10, 11] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["id"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [{ + "fields": [{ + "i64": "1", + "nullable": true, + "typeVariationReference": 0 + }] + }, { + "fields": [{ + "i64": "2", + "nullable": true, + "typeVariationReference": 0 + }] + }] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [], + "options": [] + } + }], + "groupingExpressions": [] + } + }, + "right": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["id", "category"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "string": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [{ + "fields": [{ + "i64": "1", + "nullable": true, + "typeVariationReference": 0 + }, { + "string": "info", + "nullable": true, + "typeVariationReference": 0 + }] + }, { + "fields": [{ + "i64": "2", + "nullable": true, + "typeVariationReference": 0 + }, { + "string": "low", + "nullable": true, + "typeVariationReference": 0 + }] + }] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [], + "groupingExpressions": [] + } + }, + "expression": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "right": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["id"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [{ + "fields": [{ + "i64": "1", + "nullable": true, + "typeVariationReference": 0 + }] + }, { + "fields": [{ + "i64": "2", + "nullable": true, + "typeVariationReference": 0 + }] + }] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [], + "options": [] + } + }], + "groupingExpressions": [] + } + }, + "expression": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "right": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["id"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [{ + "fields": [{ + "i64": "1", + "nullable": true, + "typeVariationReference": 0 + }] + }, { + "fields": [{ + "i64": "2", + "nullable": true, + "typeVariationReference": 0 + }] + }] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [], + "options": [] + } + }], + "groupingExpressions": [] + } + }, + "expression": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": { + } + } + }] + } + }, + "names": ["count_first", "category", "count_second", "count_third"] + } + }], + "expectedTypeUrls": [], + "version": { + "majorNumber": 0, + "minorNumber": 52, + "patchNumber": 0, + "gitHash": "" + } +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json b/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json new file mode 100644 index 0000000000000..8b82d6eec7552 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json @@ -0,0 +1,328 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "equal:any_any" + } + }], + "relations": [{ + "root": { + "input": { + "set": { + "common": { + "direct": { + } + }, + "inputs": [{ + "project": { + "common": { + "emit": { + "outputMapping": [2, 3, 4] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["product_key"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "sales" + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["@food_id"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "food" + ] + } + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [{ + "literal": { + "string": "people" + } + }, { + "literal": { + "string": "people" + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }] + } + }, { + "set": { + "common": { + "direct": { + } + }, + "inputs": [{ + "project": { + "common": { + "emit": { + "outputMapping": [4, 5, 6] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["$f3", "$f5", "product_key0"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "people" + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["@food_id"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "food" + ] + } + + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + } + }, { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["$f1000", "$f2000", "more_products_key0000"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "more_products" + ] + } + + } + }], + "op": "SET_OP_UNION_ALL" + } + }], + "op": "SET_OP_UNION_ALL" + } + }, + "names": ["product_category", "product_type", "product_key"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/scalar_fn_logb_expr.substrait.json b/datafusion/substrait/tests/testdata/test_plans/scalar_fn_logb_expr.substrait.json new file mode 100644 index 0000000000000..eeaf5a3dd8476 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/scalar_fn_logb_expr.substrait.json @@ -0,0 +1,116 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_boolean.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "logb:fp32_fp32" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "expressions": [ + { + "scalarFunction": { + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }], + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + } + } + } + ], + "input": { + "read": { + "baseSchema": { + "names": [ + "x", "base" + ], + "struct": { + "types": [ + { + "fp32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fp32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "common": { + "direct": {} + }, + "virtualTable": { + "values": [{ + "fields": [{ + "fp32": 1.0, + "nullable": false + }, { + "fp32": 10.0, + "nullable": false + }] + }, { + "fields": [{ + "fp32": 100.0, + "nullable": false + }, { + "fp32": 10.0, + "nullable": false + }] + }] + } + } + } + } + }, + "names": [ + "x", "base", "result" + ] + } + } + ] + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/scalar_fn_to_between_expr.substrait.json b/datafusion/substrait/tests/testdata/test_plans/scalar_fn_to_between_expr.substrait.json new file mode 100644 index 0000000000000..6749a301b17df --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/scalar_fn_to_between_expr.substrait.json @@ -0,0 +1,143 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "between:any_any_any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 3 + ] + } + }, + "expressions": [ + { + "scalarFunction": { + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }], + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + } + } + } + ], + "input": { + "read": { + "baseSchema": { + "names": [ + "expr", "low", "high" + ], + "struct": { + "types": [ + { + "i8": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i8": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i8": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "common": { + "direct": {} + }, + "virtualTable": { + "values": [{ + "fields": [{ + "i8": 2, + "nullable": false + }, { + "i8": 1, + "nullable": false + }, { + "i8": 3, + "nullable": false + }] + }, { + "fields": [{ + "i8": 4, + "nullable": false + }, { + "i8": 1, + "nullable": false + }, { + "i8": 2, + "nullable": false + }] + }] + } + } + } + } + }, + "names": [ + "result" + ] + } + } + ] + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/scalar_fn_to_built_in_binary_expr_and_not.substrait.json b/datafusion/substrait/tests/testdata/test_plans/scalar_fn_to_built_in_binary_expr_and_not.substrait.json new file mode 100644 index 0000000000000..8365b1edfe250 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/scalar_fn_to_built_in_binary_expr_and_not.substrait.json @@ -0,0 +1,132 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_boolean.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "and_not:bool_bool" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "expressions": [ + { + "scalarFunction": { + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }], + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + } + } + } + ], + "input": { + "read": { + "baseSchema": { + "names": [ + "a", "b" + ], + "struct": { + "types": [ + { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "common": { + "direct": {} + }, + "virtualTable": { + "values": [{ + "fields": [{ + "boolean": true, + "nullable": false + }, { + "boolean": true, + "nullable": false + }] + }, { + "fields": [{ + "boolean": true, + "nullable": false + }, { + "boolean": false, + "nullable": false + }] + }, { + "fields": [{ + "boolean": false, + "nullable": false + }, { + "boolean": true, + "nullable": false + }] + }, { + "fields": [{ + "boolean": false, + "nullable": false + }, { + "boolean": false, + "nullable": false + }] + }] + } + } + } + } + }, + "names": [ + "a", "b", "result" + ] + } + } + ] + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/scalar_fn_to_built_in_binary_expr_xor.substrait.json b/datafusion/substrait/tests/testdata/test_plans/scalar_fn_to_built_in_binary_expr_xor.substrait.json new file mode 100644 index 0000000000000..cfd760de890c0 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/scalar_fn_to_built_in_binary_expr_xor.substrait.json @@ -0,0 +1,132 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_boolean.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "xor:bool_bool" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "expressions": [ + { + "scalarFunction": { + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }], + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + } + } + } + ], + "input": { + "read": { + "baseSchema": { + "names": [ + "a", "b" + ], + "struct": { + "types": [ + { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "common": { + "direct": {} + }, + "virtualTable": { + "values": [{ + "fields": [{ + "boolean": true, + "nullable": false + }, { + "boolean": true, + "nullable": false + }] + }, { + "fields": [{ + "boolean": true, + "nullable": false + }, { + "boolean": false, + "nullable": false + }] + }, { + "fields": [{ + "boolean": false, + "nullable": false + }, { + "boolean": true, + "nullable": false + }] + }, { + "fields": [{ + "boolean": false, + "nullable": false + }, { + "boolean": false, + "nullable": false + }] + }] + } + } + } + } + }, + "names": [ + "a", "b", "result" + ] + } + } + ] + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/virtual_table_with_expressions.substrait.json b/datafusion/substrait/tests/testdata/test_plans/virtual_table_with_expressions.substrait.json new file mode 100644 index 0000000000000..2c634fa957579 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/virtual_table_with_expressions.substrait.json @@ -0,0 +1,75 @@ +{ + "relations": [ + { + "root": { + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "dummy1", "dummy2" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "expressions": [ + { + "fields": [ + { + "literal": { + "i64": "0", + "nullable": false + } + }, + { + "literal": { + "string": "temp", + "nullable": false + } + } + ] + }, + { + "fields": [ + { + "literal": { + "i64": "1", + "nullable": false + } + }, + { + "literal": { + "string": "test", + "nullable": false + } + } + ] + } + ] + } + } + }, + "names": [ + "result1", "result2" + ] + } + } + ] + } \ No newline at end of file diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs index e3e3ec3fab018..f84594312b634 100644 --- a/datafusion/substrait/tests/utils.rs +++ b/datafusion/substrait/tests/utils.rs @@ -150,7 +150,7 @@ pub mod test { let df_schema = from_substrait_named_struct(self.consumer, substrait_schema)? .replace_qualifier(table_reference.clone()); - let table = EmptyTable::new(df_schema.inner().clone()); + let table = EmptyTable::new(Arc::clone(df_schema.inner())); self.schemas.push((table_reference, Arc::new(table))); Ok(()) } diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index 10eab025734c9..c1b2f927e30c7 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -39,27 +39,28 @@ crate-type = ["cdylib", "rlib"] [dependencies] # chrono must be compiled with wasmbind feature chrono = { version = "0.4", features = ["wasmbind"] } - # The `console_error_panic_hook` crate provides better debugging of panics by # logging them with `console.error`. This is great for development, but requires # all the `std::fmt` and `std::panicking` infrastructure, so isn't great for # code size when deploying. console_error_panic_hook = { version = "0.1.1", optional = true } -datafusion = { workspace = true, features = ["parquet"] } -datafusion-common = { workspace = true, default-features = true } +datafusion = { workspace = true, features = ["parquet", "sql"] } +datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-optimizer = { workspace = true, default-features = true } datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } -# getrandom must be compiled with js feature -getrandom = { version = "0.2.8", features = ["js"] } - +# needs to be compiled +getrandom = { version = "0.3", features = ["wasm_js"] } wasm-bindgen = "0.2.99" [dev-dependencies] -insta = { workspace = true } object_store = { workspace = true } +# needs to be compiled tokio = { workspace = true } url = { workspace = true } -wasm-bindgen-test = "0.3.49" +wasm-bindgen-test = "0.3.54" + +[package.metadata.cargo-machete] +ignored = ["chrono", "getrandom"] diff --git a/datafusion/wasmtest/README.md b/datafusion/wasmtest/README.md index 8843eed697eca..57a12ef8b8321 100644 --- a/datafusion/wasmtest/README.md +++ b/datafusion/wasmtest/README.md @@ -32,7 +32,7 @@ Some of DataFusion's downstream projects compile to WASM to run in the browser. ## Setup -First, [install wasm-pack](https://rustwasm.github.io/wasm-pack/installer/) +First, [install wasm-pack](https://drager.github.io/wasm-pack/installer/) Then use wasm-pack to compile the crate from within this directory @@ -40,6 +40,20 @@ Then use wasm-pack to compile the crate from within this directory wasm-pack build ``` +### Apple silicon + +The default installation of Clang on Apple silicon does not support wasm, so you'll need to install LLVM Clang. For example via Homebrew: + +```sh +brew install llvm +# You will also need to install wasm-bindgen-cli separately, changing version as needed (0.3.53 = 0.2.103) +cargo install wasm-bindgen-cli@0.2.103 +# Need to run commands like so, unless you edit your PATH to prepend the LLVM version of Clang +PATH="/opt/homebrew/opt/llvm/bin:$PATH" RUSTFLAGS='--cfg getrandom_backend="wasm_js"' wasm-pack build +``` + +- For reference: https://github.com/briansmith/ring/issues/1824 + ## Try it out The `datafusion-wasm-app` directory contains a simple app (created with [`create-wasm-app`](https://github.com/rustwasm/create-wasm-app) and then manually updated to WebPack 5) that invokes DataFusion and writes results to the browser console. @@ -71,8 +85,6 @@ wasm-pack test --headless --chrome wasm-pack test --headless --safari ``` -**Note:** In GitHub Actions we test the compilation with `wasm-build`, but we don't currently invoke `wasm-pack test`. This is because the headless mode is not yet working. Document of adding a GitHub Action job: https://rustwasm.github.io/docs/wasm-bindgen/wasm-bindgen-test/continuous-integration.html#github-actions. - To tweak timeout setting, use `WASM_BINDGEN_TEST_TIMEOUT` environment variable. E.g., `WASM_BINDGEN_TEST_TIMEOUT=300 wasm-pack test --firefox --headless`. ## Compatibility diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index 65d8bdbb5e931..80d3d7b473bca 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -15,7 +15,7 @@ "copy-webpack-plugin": "12.0.2", "webpack": "5.94.0", "webpack-cli": "5.1.4", - "webpack-dev-server": "4.15.1" + "webpack-dev-server": "5.2.1" } }, "../pkg": { @@ -90,10 +90,11 @@ } }, "node_modules/@leichtgewicht/ip-codec": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/@leichtgewicht/ip-codec/-/ip-codec-2.0.4.tgz", - "integrity": "sha512-Hcv+nVC0kZnQ3tD9GVu5xSMR4VVYOteQIr/hwFPVEvPdlXqgGEuRjiheChHgdM+JyqdgNcmzZOX/tnl0JOiI7A==", - "dev": true + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@leichtgewicht/ip-codec/-/ip-codec-2.0.5.tgz", + "integrity": "sha512-Vo+PSpZG2/fmgmiNzYK9qWRh8h/CHrwD0mo1h1DzL4yzHNSfWYujGTYsWGreD000gcgmZ7K4Ys6Tx9TxtsKdDw==", + "dev": true, + "license": "MIT" }, "node_modules/@nodelib/fs.scandir": { "version": "2.1.5", @@ -157,10 +158,11 @@ } }, "node_modules/@types/bonjour": { - "version": "3.5.11", - "resolved": "https://registry.npmjs.org/@types/bonjour/-/bonjour-3.5.11.tgz", - "integrity": "sha512-isGhjmBtLIxdHBDl2xGwUzEM8AOyOvWsADWq7rqirdi/ZQoHnLWErHvsThcEzTX8juDRiZtzp2Qkv5bgNh6mAg==", + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@types/bonjour/-/bonjour-3.5.13.tgz", + "integrity": "sha512-z9fJ5Im06zvUL548KvYNecEVlA7cVDkGUi6kZusb04mpyEFKCIZJvloCcmpmLaIahDpOQGHaHmG6imtPMmPXGQ==", "dev": true, + "license": "MIT", "dependencies": { "@types/node": "*" } @@ -175,10 +177,11 @@ } }, "node_modules/@types/connect-history-api-fallback": { - "version": "1.5.1", - "resolved": "https://registry.npmjs.org/@types/connect-history-api-fallback/-/connect-history-api-fallback-1.5.1.tgz", - "integrity": "sha512-iaQslNbARe8fctL5Lk+DsmgWOM83lM+7FzP0eQUJs1jd3kBE8NWqBTIT2S8SqQOJjxvt2eyIjpOuYeRXq2AdMw==", + "version": "1.5.4", + "resolved": "https://registry.npmjs.org/@types/connect-history-api-fallback/-/connect-history-api-fallback-1.5.4.tgz", + "integrity": "sha512-n6Cr2xS1h4uAulPRdlw6Jl6s1oG8KrVilPN2yUITEs+K48EzMJJ3W1xy8K5eWuFvjp3R74AOIGSmp2UfBJ8HFw==", "dev": true, + "license": "MIT", "dependencies": { "@types/express-serve-static-core": "*", "@types/node": "*" @@ -191,10 +194,11 @@ "dev": true }, "node_modules/@types/express": { - "version": "4.17.17", - "resolved": "https://registry.npmjs.org/@types/express/-/express-4.17.17.tgz", - "integrity": "sha512-Q4FmmuLGBG58btUnfS1c1r/NQdlp3DMfGDGig8WhfpA2YRUtEkxAjkZb0yvplJGYdF1fsQ81iMDcH24sSCNC/Q==", + "version": "4.17.22", + "resolved": "https://registry.npmjs.org/@types/express/-/express-4.17.22.tgz", + "integrity": "sha512-eZUmSnhRX9YRSkplpz0N+k6NljUUn5l3EWZIKZvYzhvMphEuNiyyy1viH/ejgt66JWgALwC/gtSUAeQKtSwW/w==", "dev": true, + "license": "MIT", "dependencies": { "@types/body-parser": "*", "@types/express-serve-static-core": "^4.17.33", @@ -247,6 +251,16 @@ "integrity": "sha512-HksnYH4Ljr4VQgEy2lTStbCKv/P590tmPe5HqOnv9Gprffgv5WXAY+Y5Gqniu0GGqeTCUdBnzC3QSrzPkBkAMA==", "dev": true }, + "node_modules/@types/node-forge": { + "version": "1.3.11", + "resolved": "https://registry.npmjs.org/@types/node-forge/-/node-forge-1.3.11.tgz", + "integrity": "sha512-FQx220y22OKNTqaByeBGqHWYz4cl94tpcxeFdvBo3wjG6XPBuZ0BNgNZRV5J5TFmmcsJ4IzsLkmGRiQbnYsBEQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@types/qs": { "version": "6.9.8", "resolved": "https://registry.npmjs.org/@types/qs/-/qs-6.9.8.tgz", @@ -260,10 +274,11 @@ "dev": true }, "node_modules/@types/retry": { - "version": "0.12.0", - "resolved": "https://registry.npmjs.org/@types/retry/-/retry-0.12.0.tgz", - "integrity": "sha512-wWKOClTTiizcZhXnPY4wikVAwmdYHp8q6DmC+EJUzAMsycb7HB32Kh9RN4+0gExjmPmZSAQjgURXIGATPegAvA==", - "dev": true + "version": "0.12.2", + "resolved": "https://registry.npmjs.org/@types/retry/-/retry-0.12.2.tgz", + "integrity": "sha512-XISRgDJ2Tc5q4TRqvgJtzsRkFYNJzZrhTdtMoGVBttwzzQJkPnS3WWTFc7kuDRoPtPakl+T+OfdEUjYJj7Jbow==", + "dev": true, + "license": "MIT" }, "node_modules/@types/send": { "version": "0.17.1", @@ -276,39 +291,43 @@ } }, "node_modules/@types/serve-index": { - "version": "1.9.1", - "resolved": "https://registry.npmjs.org/@types/serve-index/-/serve-index-1.9.1.tgz", - "integrity": "sha512-d/Hs3nWDxNL2xAczmOVZNj92YZCS6RGxfBPjKzuu/XirCgXdpKEb88dYNbrYGint6IVWLNP+yonwVAuRC0T2Dg==", + "version": "1.9.4", + "resolved": "https://registry.npmjs.org/@types/serve-index/-/serve-index-1.9.4.tgz", + "integrity": "sha512-qLpGZ/c2fhSs5gnYsQxtDEq3Oy8SXPClIXkW5ghvAvsNuVSA8k+gCONcUCS/UjLEYvYps+e8uBtfgXgvhwfNug==", "dev": true, + "license": "MIT", "dependencies": { "@types/express": "*" } }, "node_modules/@types/serve-static": { - "version": "1.15.2", - "resolved": "https://registry.npmjs.org/@types/serve-static/-/serve-static-1.15.2.tgz", - "integrity": "sha512-J2LqtvFYCzaj8pVYKw8klQXrLLk7TBZmQ4ShlcdkELFKGwGMfevMLneMMRkMgZxotOD9wg497LpC7O8PcvAmfw==", + "version": "1.15.7", + "resolved": "https://registry.npmjs.org/@types/serve-static/-/serve-static-1.15.7.tgz", + "integrity": "sha512-W8Ym+h8nhuRwaKPaDw34QUkwsGi6Rc4yYqvKFo5rm2FUEhCFbzVWrxXUxuKK8TASjWsysJY0nsmNCGhCOIsrOw==", "dev": true, + "license": "MIT", "dependencies": { "@types/http-errors": "*", - "@types/mime": "*", - "@types/node": "*" + "@types/node": "*", + "@types/send": "*" } }, "node_modules/@types/sockjs": { - "version": "0.3.33", - "resolved": "https://registry.npmjs.org/@types/sockjs/-/sockjs-0.3.33.tgz", - "integrity": "sha512-f0KEEe05NvUnat+boPTZ0dgaLZ4SfSouXUgv5noUiefG2ajgKjmETo9ZJyuqsl7dfl2aHlLJUiki6B4ZYldiiw==", + "version": "0.3.36", + "resolved": "https://registry.npmjs.org/@types/sockjs/-/sockjs-0.3.36.tgz", + "integrity": "sha512-MK9V6NzAS1+Ud7JV9lJLFqW85VbC9dq3LmwZCuBe4wBDgKC0Kj/jd8Xl+nSviU+Qc3+m7umHHyHg//2KSa0a0Q==", "dev": true, + "license": "MIT", "dependencies": { "@types/node": "*" } }, "node_modules/@types/ws": { - "version": "8.5.5", - "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.5.5.tgz", - "integrity": "sha512-lwhs8hktwxSjf9UaZ9tG5M03PGogvFaH8gUgLNbN9HKIg0dvv6q+gkSuJ8HN4/VbyxkuLzCjlN7GquQ0gUJfIg==", + "version": "8.18.1", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz", + "integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==", "dev": true, + "license": "MIT", "dependencies": { "@types/node": "*" } @@ -630,6 +649,7 @@ "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz", "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", "dev": true, + "license": "ISC", "dependencies": { "normalize-path": "^3.0.0", "picomatch": "^2.0.4" @@ -639,16 +659,11 @@ } }, "node_modules/array-flatten": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-2.1.2.tgz", - "integrity": "sha512-hNfzcOV8W4NdualtqBFPyVO+54DSJuZGY9qT4pRroB6S9e3iiido2ISIC5h9R2sPJ8H3FHCIiEnsv1lPXO3KtQ==", - "dev": true - }, - "node_modules/balanced-match": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", - "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", - "dev": true + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", + "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==", + "dev": true, + "license": "MIT" }, "node_modules/batch": { "version": "0.6.1", @@ -657,12 +672,16 @@ "dev": true }, "node_modules/binary-extensions": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", - "integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==", + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.3.0.tgz", + "integrity": "sha512-Ceh+7ox5qe7LJuLHoY0feh3pHuUDHAcRUeyL2VYghZwfpkNIy/+8Ocg0a3UuSoYzavmylwuLWQOf3hl0jjMMIw==", "dev": true, + "license": "MIT", "engines": { "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/body-parser": { @@ -670,6 +689,7 @@ "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.3.tgz", "integrity": "sha512-7rAxByjUMqQ3/bHJy7D6OGXvx/MMc4IqBn/X0fcM1QUcAItpZrBEYhWGem+tzXH90c+G01ypMcYJBO9Y30203g==", "dev": true, + "license": "MIT", "dependencies": { "bytes": "3.1.2", "content-type": "~1.0.5", @@ -689,20 +709,12 @@ "npm": "1.2.8000 || >= 1.4.16" } }, - "node_modules/body-parser/node_modules/bytes": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", - "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", - "dev": true, - "engines": { - "node": ">= 0.8" - } - }, "node_modules/body-parser/node_modules/debug": { "version": "2.6.9", "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", "dev": true, + "license": "MIT", "dependencies": { "ms": "2.0.0" } @@ -712,32 +724,22 @@ "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } }, "node_modules/bonjour-service": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/bonjour-service/-/bonjour-service-1.1.1.tgz", - "integrity": "sha512-Z/5lQRMOG9k7W+FkeGTNjh7htqn/2LMnfOvBZ8pynNZCM9MwkQkI3zeI4oz09uWdcgmgHugVvBqxGg4VQJ5PCg==", + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/bonjour-service/-/bonjour-service-1.3.0.tgz", + "integrity": "sha512-3YuAUiSkWykd+2Azjgyxei8OWf8thdn8AITIog2M4UICzoqfjlqr64WIjEXZllf/W6vK1goqleSR6brGomxQqA==", "dev": true, + "license": "MIT", "dependencies": { - "array-flatten": "^2.1.2", - "dns-equal": "^1.0.0", "fast-deep-equal": "^3.1.3", "multicast-dns": "^7.2.5" } }, - "node_modules/brace-expansion": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", - "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", - "dev": true, - "dependencies": { - "balanced-match": "^1.0.0", - "concat-map": "0.0.1" - } - }, "node_modules/braces": { "version": "3.0.3", "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", @@ -788,26 +790,54 @@ "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", "dev": true }, + "node_modules/bundle-name": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/bundle-name/-/bundle-name-4.1.0.tgz", + "integrity": "sha512-tjwM5exMg6BGRI+kNmTntNsvdZS1X8BFYS6tnJ2hdH0kVxM6/eVZ2xy+FqStSWvYmtfFMDLIxurorHwDKfDz5Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "run-applescript": "^7.0.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/bytes": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.0.0.tgz", - "integrity": "sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg=", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", + "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", "dev": true, "engines": { "node": ">= 0.8" } }, - "node_modules/call-bind": { - "version": "1.0.7", - "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", - "integrity": "sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==", + "node_modules/call-bind-apply-helpers": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", + "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", "dev": true, + "license": "MIT", "dependencies": { - "es-define-property": "^1.0.0", "es-errors": "^1.3.0", - "function-bind": "^1.1.2", - "get-intrinsic": "^1.2.4", - "set-function-length": "^1.2.1" + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/call-bound": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/call-bound/-/call-bound-1.0.4.tgz", + "integrity": "sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "get-intrinsic": "^1.3.0" }, "engines": { "node": ">= 0.4" @@ -837,16 +867,11 @@ ] }, "node_modules/chokidar": { - "version": "3.5.3", - "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", - "integrity": "sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw==", + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.6.0.tgz", + "integrity": "sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==", "dev": true, - "funding": [ - { - "type": "individual", - "url": "https://paulmillr.com/funding/" - } - ], + "license": "MIT", "dependencies": { "anymatch": "~3.1.2", "braces": "~3.0.2", @@ -859,6 +884,9 @@ "engines": { "node": ">= 8.10.0" }, + "funding": { + "url": "https://paulmillr.com/funding/" + }, "optionalDependencies": { "fsevents": "~2.3.2" } @@ -914,17 +942,17 @@ } }, "node_modules/compression": { - "version": "1.7.4", - "resolved": "https://registry.npmjs.org/compression/-/compression-1.7.4.tgz", - "integrity": "sha512-jaSIDzP9pZVS4ZfQ+TzvtiWhdpFhE2RDHz8QJkpX9SIpLq88VueF5jJw6t+6CUQcAoA6t+x89MLrWAqpfDE8iQ==", + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.1.tgz", + "integrity": "sha512-9mAqGPHLakhCLeNyxPkK4xVo746zQ/czLH1Ky+vkitMnWfWZps8r0qXuwhwizagCRttsL4lfG4pIOvaWLpAP0w==", "dev": true, "dependencies": { - "accepts": "~1.3.5", - "bytes": "3.0.0", - "compressible": "~2.0.16", + "bytes": "3.1.2", + "compressible": "~2.0.18", "debug": "2.6.9", - "on-headers": "~1.0.2", - "safe-buffer": "5.1.2", + "negotiator": "~0.6.4", + "on-headers": "~1.1.0", + "safe-buffer": "5.2.1", "vary": "~1.1.2" }, "engines": { @@ -940,11 +968,34 @@ "ms": "2.0.0" } }, - "node_modules/concat-map": { - "version": "0.0.1", - "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", - "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", - "dev": true + "node_modules/compression/node_modules/negotiator": { + "version": "0.6.4", + "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.4.tgz", + "integrity": "sha512-myRT3DiWPHqho5PrJaIRyaMv2kgYf0mUVgBNOYMuCH5Ki1yEiQaf/ZJuQ62nvpc44wL5WDbTX7yGJi1Neevw8w==", + "dev": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/compression/node_modules/safe-buffer": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", + "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ] }, "node_modules/connect-history-api-fallback": { "version": "2.0.0", @@ -960,6 +1011,7 @@ "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.4.tgz", "integrity": "sha512-FveZTNuGw04cxlAiWbzi6zTAL/lhehaWbTtgluJh4/E95DqMwTmha3KZN1aAWA8cFIhHzMZUvLevkw5Rqk+tSQ==", "dev": true, + "license": "MIT", "dependencies": { "safe-buffer": "5.2.1" }, @@ -985,13 +1037,15 @@ "type": "consulting", "url": "https://feross.org/support" } - ] + ], + "license": "MIT" }, "node_modules/content-type": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.5.tgz", "integrity": "sha512-nTjqfcBFEipKdXCv4YDQWCfmcLZKm81ldF0pAopTvyrFGVbcR6P/VAAd5G7N+0tTr8QqiU0tFadD6FK4NtJwOA==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -1001,6 +1055,7 @@ "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.1.tgz", "integrity": "sha512-6DnInpx7SJ2AK3+CTUE/ZM0vWTUboZCegxhC2xiIydHR9jNuTAASBrfEpHhiGOZw/nX51bHt6YQl8jsGo4y/0w==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -1009,7 +1064,8 @@ "version": "1.0.6", "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz", "integrity": "sha512-QADzlaHc8icV8I7vbaJXJwod9HWYp8uCqf1xa4OfNu1T7JVxQIrUgOWtHdNDtPiywmFbiS12VjotIXLrKM3orQ==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/copy-webpack-plugin": { "version": "12.0.2", @@ -1146,42 +1202,47 @@ "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", "dev": true }, - "node_modules/default-gateway": { - "version": "6.0.3", - "resolved": "https://registry.npmjs.org/default-gateway/-/default-gateway-6.0.3.tgz", - "integrity": "sha512-fwSOJsbbNzZ/CUFpqFBqYfYNLj1NbMPm8MMCIzHjC83iSJRBEGmDUxU+WP661BaBQImeC2yHwXtz+P/O9o+XEg==", + "node_modules/default-browser": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/default-browser/-/default-browser-5.2.1.tgz", + "integrity": "sha512-WY/3TUME0x3KPYdRRxEJJvXRHV4PyPoUsxtZa78lwItwRQRHhd2U9xOscaT/YTf8uCXIAjeJOFBVEh/7FtD8Xg==", "dev": true, + "license": "MIT", "dependencies": { - "execa": "^5.0.0" + "bundle-name": "^4.1.0", + "default-browser-id": "^5.0.0" }, "engines": { - "node": ">= 10" + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/define-data-property": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", - "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", + "node_modules/default-browser-id": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/default-browser-id/-/default-browser-id-5.0.0.tgz", + "integrity": "sha512-A6p/pu/6fyBcA1TRz/GqWYPViplrftcW2gZC9q79ngNCKAeR/X3gcEdXQHl4KNXV+3wgIJ1CPkJQ3IHM6lcsyA==", "dev": true, - "dependencies": { - "es-define-property": "^1.0.0", - "es-errors": "^1.3.0", - "gopd": "^1.0.1" - }, + "license": "MIT", "engines": { - "node": ">= 0.4" + "node": ">=18" }, "funding": { - "url": "https://github.com/sponsors/ljharb" + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/define-lazy-prop": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", - "integrity": "sha512-Ds09qNh8yw3khSjiJjiUInaGX9xlqZDY7JVryGxdxV7NPeuqQfplOpQ66yJFZut3jLa5zOwkXw1g9EI2uKh4Og==", + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-3.0.0.tgz", + "integrity": "sha512-N+MeXYoqr3pOgn8xfyRPREN7gHakLYjhsHhWGT3fWAiL4IkAt0iDw14QiiEm2bE30c5XX5q0FtAA3CK5f9/BUg==", "dev": true, + "license": "MIT", "engines": { - "node": ">=8" + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/depd": { @@ -1198,6 +1259,7 @@ "resolved": "https://registry.npmjs.org/destroy/-/destroy-1.2.0.tgz", "integrity": "sha512-2sJGJTaXIIaR1w4iJSNoN0hnMY7Gpc/n8D4qSCJw8QqFWXf7cuAgnEHxBpweaVcPevC2l3KpjYCx3NypQQgaJg==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8", "npm": "1.2.8000 || >= 1.4.16" @@ -1209,17 +1271,12 @@ "integrity": "sha512-ZIzRpLJrOj7jjP2miAtgqIfmzbxa4ZOr5jJc601zklsfEx9oTzmmj2nVpIPRpNlRTIh8lc1kyViIY7BWSGNmKw==", "dev": true }, - "node_modules/dns-equal": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/dns-equal/-/dns-equal-1.0.0.tgz", - "integrity": "sha512-z+paD6YUQsk+AbGCEM4PrOXSss5gd66QfcVBFTKR/HpFL9jCqikS94HYwKww6fQyO7IxrIIyUu+g0Ka9tUS2Cg==", - "dev": true - }, "node_modules/dns-packet": { "version": "5.6.1", "resolved": "https://registry.npmjs.org/dns-packet/-/dns-packet-5.6.1.tgz", "integrity": "sha512-l4gcSouhcgIKRvyy99RNVOgxXiicE+2jZoNmaNmZ6JXiGajBOJAesk1OBlJuM5k2c+eudGdLxDqXuPCKIj6kpw==", "dev": true, + "license": "MIT", "dependencies": { "@leichtgewicht/ip-codec": "^2.0.1" }, @@ -1227,11 +1284,27 @@ "node": ">=6" } }, + "node_modules/dunder-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", + "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/ee-first": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz", "integrity": "sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/electron-to-chromium": { "version": "1.4.528", @@ -1240,10 +1313,11 @@ "dev": true }, "node_modules/encodeurl": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", - "integrity": "sha512-TPJXq8JqFaVYm2CWmPvnP2Iyo4ZSM7/QKcSmuMLDObfpH5fi7RUGmd/rTDf+rut/saiDiQEeVTNgAmJEdAOx0w==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", + "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -1274,13 +1348,11 @@ } }, "node_modules/es-define-property": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.0.tgz", - "integrity": "sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==", + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz", + "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", "dev": true, - "dependencies": { - "get-intrinsic": "^1.2.4" - }, + "license": "MIT", "engines": { "node": ">= 0.4" } @@ -1290,6 +1362,7 @@ "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.4" } @@ -1300,6 +1373,19 @@ "integrity": "sha512-JUFAyicQV9mXc3YRxPnDlrfBKpqt6hUYzz9/boprUJHs4e4KVr3XwOF70doO6gwXUor6EWZJAyWAfKki84t20Q==", "dev": true }, + "node_modules/es-object-atoms": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", + "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/escalade": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz", @@ -1363,6 +1449,7 @@ "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz", "integrity": "sha512-aIL5Fx7mawVa300al2BnEE4iNvo1qETxLrPI/o05L7z6go7fCw1J6EQmbK4FmJ2AS7kgVF/KEZWufBfdClMcPg==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -1382,34 +1469,12 @@ "node": ">=0.8.x" } }, - "node_modules/execa": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/execa/-/execa-5.1.1.tgz", - "integrity": "sha512-8uSpZZocAZRBAPIEINJj3Lo9HyGitllczc27Eh5YYojjMFMn8yHMDMaUHE2Jqfq05D/wucwI4JGURyXt1vchyg==", - "dev": true, - "dependencies": { - "cross-spawn": "^7.0.3", - "get-stream": "^6.0.0", - "human-signals": "^2.1.0", - "is-stream": "^2.0.0", - "merge-stream": "^2.0.0", - "npm-run-path": "^4.0.1", - "onetime": "^5.1.2", - "signal-exit": "^3.0.3", - "strip-final-newline": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sindresorhus/execa?sponsor=1" - } - }, "node_modules/express": { - "version": "4.21.1", - "resolved": "https://registry.npmjs.org/express/-/express-4.21.1.tgz", - "integrity": "sha512-YSFlK1Ee0/GC8QaO91tHcDxJiE/X4FbpAyQWkxAvG6AXCuR65YzK8ua6D9hvi/TzUfZMpc+BwuM1IPw8fmQBiQ==", + "version": "4.21.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.21.2.tgz", + "integrity": "sha512-28HqgMZAmih1Czt9ny7qr6ek2qddF4FclbMzwhCREB6OFfH+rXAnuNCwo1/wFvrtbgsQDb4kSbX9de9lFbrXnA==", "dev": true, + "license": "MIT", "dependencies": { "accepts": "~1.3.8", "array-flatten": "1.1.1", @@ -1430,7 +1495,7 @@ "methods": "~1.1.2", "on-finished": "2.4.1", "parseurl": "~1.3.3", - "path-to-regexp": "0.1.10", + "path-to-regexp": "0.1.12", "proxy-addr": "~2.0.7", "qs": "6.13.0", "range-parser": "~1.2.1", @@ -1445,19 +1510,18 @@ }, "engines": { "node": ">= 0.10.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" } }, - "node_modules/express/node_modules/array-flatten": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", - "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==", - "dev": true - }, "node_modules/express/node_modules/debug": { "version": "2.6.9", "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", "dev": true, + "license": "MIT", "dependencies": { "ms": "2.0.0" } @@ -1467,15 +1531,7 @@ "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", "dev": true, - "engines": { - "node": ">= 0.8" - } - }, - "node_modules/express/node_modules/encodeurl": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", - "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", - "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -1498,13 +1554,15 @@ "type": "consulting", "url": "https://feross.org/support" } - ] + ], + "license": "MIT" }, "node_modules/express/node_modules/statuses": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -1603,6 +1661,7 @@ "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-1.3.1.tgz", "integrity": "sha512-6BN9trH7bp3qvnrRyzsBz+g3lZxTNZTbVO2EV1CS0WIcDbawYVdYvGflME/9QP0h0pYlCDBCTjYa9nZzMDpyxQ==", "dev": true, + "license": "MIT", "dependencies": { "debug": "2.6.9", "encodeurl": "~2.0.0", @@ -1621,24 +1680,17 @@ "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", "dev": true, + "license": "MIT", "dependencies": { "ms": "2.0.0" } }, - "node_modules/finalhandler/node_modules/encodeurl": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", - "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", - "dev": true, - "engines": { - "node": ">= 0.8" - } - }, "node_modules/finalhandler/node_modules/statuses": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -1681,6 +1733,7 @@ "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz", "integrity": "sha512-buRG0fpBtRHSTCOASe6hD258tEubFoRLb4ZNA6NxMVHNw2gOcwHo9wyablzMzOA5z9xA9L1KNjk/Nt6MT9aYow==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -1690,28 +1743,18 @@ "resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz", "integrity": "sha512-zJ2mQYM18rEFOudeV4GShTGIQ7RbzA7ozbU9I/XBpm7kqgMywgmylMwXHxZJmkVoYkna9d2pVXVXPdYTP9ej8Q==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } }, - "node_modules/fs-monkey": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/fs-monkey/-/fs-monkey-1.0.4.tgz", - "integrity": "sha512-INM/fWAxMICjttnD0DX1rBvinKskj5G1w+oy/pnm9u/tSlnBrzFonJMcalKJ30P8RRsPzKcCG7Q8l0jx5Fh9YQ==", - "dev": true - }, - "node_modules/fs.realpath": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", - "integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==", - "dev": true - }, "node_modules/fsevents": { "version": "2.3.3", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", "dev": true, "hasInstallScript": true, + "license": "MIT", "optional": true, "os": [ "darwin" @@ -1730,16 +1773,22 @@ } }, "node_modules/get-intrinsic": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.4.tgz", - "integrity": "sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==", + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", + "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", "dev": true, + "license": "MIT", "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", "function-bind": "^1.1.2", - "has-proto": "^1.0.1", - "has-symbols": "^1.0.3", - "hasown": "^2.0.0" + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" }, "engines": { "node": ">= 0.4" @@ -1748,36 +1797,18 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/get-stream": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-6.0.1.tgz", - "integrity": "sha512-ts6Wi+2j3jQjqi70w5AlN8DFnkSwC+MqmxEzdEALB2qXZYV3X/b1CTfgPLGJNMeAWxdPfU8FO1ms3NUfaHCPYg==", - "dev": true, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/glob": { - "version": "7.2.3", - "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", - "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "node_modules/get-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", + "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", "dev": true, + "license": "MIT", "dependencies": { - "fs.realpath": "^1.0.0", - "inflight": "^1.0.4", - "inherits": "2", - "minimatch": "^3.1.1", - "once": "^1.3.0", - "path-is-absolute": "^1.0.0" + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" }, "engines": { - "node": "*" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" + "node": ">= 0.4" } }, "node_modules/glob-parent": { @@ -1820,12 +1851,13 @@ } }, "node_modules/gopd": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz", - "integrity": "sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==", + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz", + "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", "dev": true, - "dependencies": { - "get-intrinsic": "^1.1.3" + "license": "MIT", + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -1864,35 +1896,12 @@ "node": ">=8" } }, - "node_modules/has-property-descriptors": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", - "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", - "dev": true, - "dependencies": { - "es-define-property": "^1.0.0" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/has-proto": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.3.tgz", - "integrity": "sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==", - "dev": true, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, "node_modules/has-symbols": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.0.3.tgz", - "integrity": "sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", + "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.4" }, @@ -1905,6 +1914,7 @@ "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", "dev": true, + "license": "MIT", "dependencies": { "function-bind": "^1.1.2" }, @@ -1924,22 +1934,6 @@ "wbuf": "^1.1.0" } }, - "node_modules/html-entities": { - "version": "2.4.0", - "resolved": "https://registry.npmjs.org/html-entities/-/html-entities-2.4.0.tgz", - "integrity": "sha512-igBTJcNNNhvZFRtm8uA6xMY6xYleeDwn3PeBCkDz7tHttv4F2hsDI2aPgNERWzvRcNYHNT3ymRaQzllmXj4YsQ==", - "dev": true, - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/mdevils" - }, - { - "type": "patreon", - "url": "https://patreon.com/mdevils" - } - ] - }, "node_modules/http-deceiver": { "version": "1.2.7", "resolved": "https://registry.npmjs.org/http-deceiver/-/http-deceiver-1.2.7.tgz", @@ -1951,6 +1945,7 @@ "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-2.0.0.tgz", "integrity": "sha512-FtwrG/euBzaEjYeRqOgly7G0qviiXoJWnvEH2Z1plBdXgbyjv34pHTSb9zoeHMyDy33+DWy5Wt9Wo+TURtOYSQ==", "dev": true, + "license": "MIT", "dependencies": { "depd": "2.0.0", "inherits": "2.0.4", @@ -1967,6 +1962,7 @@ "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -1975,13 +1971,15 @@ "version": "2.0.4", "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", - "dev": true + "dev": true, + "license": "ISC" }, "node_modules/http-errors/node_modules/statuses": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -2007,10 +2005,11 @@ } }, "node_modules/http-proxy-middleware": { - "version": "2.0.6", - "resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.6.tgz", - "integrity": "sha512-ya/UeJ6HVBYxrgYotAZo1KvPWlgB48kUJLDePFeneHsVujFaW5WNj2NgWCAE//B1Dl02BIfYlpNgBy8Kf8Rjmw==", + "version": "2.0.9", + "resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.9.tgz", + "integrity": "sha512-c1IyJYLYppU574+YI7R4QyX2ystMtVXZwIdzazUIPIJsHuWNd+mho2j+bKoHftndicGj9yh+xjd+l0yj7VeT1Q==", "dev": true, + "license": "MIT", "dependencies": { "@types/http-proxy": "^1.17.8", "http-proxy": "^1.18.1", @@ -2030,13 +2029,14 @@ } } }, - "node_modules/human-signals": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/human-signals/-/human-signals-2.1.0.tgz", - "integrity": "sha512-B4FFZ6q/T2jhhksgkbEW3HBvWIfDW85snkQgawt07S7J5QXTk6BkNV+0yAeZrM5QpMAdYlocGoljn0sJ/WQkFw==", + "node_modules/hyperdyperid": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/hyperdyperid/-/hyperdyperid-1.2.0.tgz", + "integrity": "sha512-Y93lCzHYgGWdrJ66yIktxiaGULYc6oGiABxhcO5AufBeOyoIdZF7bIfLaOrbM0iGIOXQQgxxRrFEnb+Y6w1n4A==", "dev": true, + "license": "MIT", "engines": { - "node": ">=10.17.0" + "node": ">=10.18" } }, "node_modules/iconv-lite": { @@ -2044,6 +2044,7 @@ "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.24.tgz", "integrity": "sha512-v3MXnZAcvnywkTUEZomIActle7RXXeedOR31wwl7VlyoXO4Qi9arvSenNQWne1TcRwhCL1HwLI21bEqdpj8/rA==", "dev": true, + "license": "MIT", "dependencies": { "safer-buffer": ">= 2.1.2 < 3" }, @@ -2080,16 +2081,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/inflight": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", - "integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==", - "dev": true, - "dependencies": { - "once": "^1.3.0", - "wrappy": "1" - } - }, "node_modules/inherits": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.3.tgz", @@ -2119,6 +2110,7 @@ "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", "dev": true, + "license": "MIT", "dependencies": { "binary-extensions": "^2.0.0" }, @@ -2139,15 +2131,16 @@ } }, "node_modules/is-docker": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/is-docker/-/is-docker-2.2.1.tgz", - "integrity": "sha512-F+i2BKsFrH66iaUFc0woD8sLy8getkwTwtOBjvs56Cx4CgJDeKQeqfz8wAYiSb8JOprWhHH5p77PbmYCvvUuXQ==", + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-docker/-/is-docker-3.0.0.tgz", + "integrity": "sha512-eljcgEDlEns/7AXFosB5K/2nCM4P7FQPkGc/DWLy5rmFEWvZayGrik1d9/QIY5nJ4f9YsVvBkA6kJpHn9rISdQ==", "dev": true, + "license": "MIT", "bin": { "is-docker": "cli.js" }, "engines": { - "node": ">=8" + "node": "^12.20.0 || ^14.13.1 || >=16.0.0" }, "funding": { "url": "https://github.com/sponsors/sindresorhus" @@ -2174,6 +2167,38 @@ "node": ">=0.10.0" } }, + "node_modules/is-inside-container": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/is-inside-container/-/is-inside-container-1.0.0.tgz", + "integrity": "sha512-KIYLCCJghfHZxqjYBE7rEy0OBuTd5xCHS7tHVgvCLkx7StIoaxwNW3hCALgEUjFfeRk+MG/Qxmp/vtETEF3tRA==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-docker": "^3.0.0" + }, + "bin": { + "is-inside-container": "cli.js" + }, + "engines": { + "node": ">=14.16" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/is-network-error": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/is-network-error/-/is-network-error-1.1.0.tgz", + "integrity": "sha512-tUdRRAnhT+OtCZR/LxZelH/C7QtjtFrTu5tXCA8pl55eTUElUHT+GPYV8MBMBvea/j+NxQqVt3LbWMRir7Gx9g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=16" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/is-number": { "version": "7.0.0", "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", @@ -2207,28 +2232,20 @@ "node": ">=0.10.0" } }, - "node_modules/is-stream": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-2.0.1.tgz", - "integrity": "sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==", - "dev": true, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/is-wsl": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-2.2.0.tgz", - "integrity": "sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww==", + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-3.1.0.tgz", + "integrity": "sha512-UcVfVfaK4Sc4m7X3dUSoHoozQGBEFeDC+zVo06t98xe8CzHSZZBekNXH+tu0NalHolcJ/QAGqS46Hef7QXBIMw==", "dev": true, + "license": "MIT", "dependencies": { - "is-docker": "^2.0.0" + "is-inside-container": "^1.0.0" }, "engines": { - "node": ">=8" + "node": ">=16" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/isarray": { @@ -2288,13 +2305,14 @@ } }, "node_modules/launch-editor": { - "version": "2.6.0", - "resolved": "https://registry.npmjs.org/launch-editor/-/launch-editor-2.6.0.tgz", - "integrity": "sha512-JpDCcQnyAAzZZaZ7vEiSqL690w7dAEyLao+KC96zBplnYbJS7TYNjvM3M7y3dGz+v7aIsJk3hllWuc0kWAjyRQ==", + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/launch-editor/-/launch-editor-2.10.0.tgz", + "integrity": "sha512-D7dBRJo/qcGX9xlvt/6wUYzQxjh5G1RvZPgPv8vi4KRU99DVQL/oW7tnVOCCTm2HGeo3C5HvGE5Yrh6UBoZ0vA==", "dev": true, + "license": "MIT", "dependencies": { "picocolors": "^1.0.0", - "shell-quote": "^1.7.3" + "shell-quote": "^1.8.1" } }, "node_modules/loader-runner": { @@ -2318,32 +2336,146 @@ "node": ">=8" } }, + "node_modules/math-intrinsics": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", + "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, "node_modules/media-typer": { "version": "0.3.0", "resolved": "https://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz", "integrity": "sha512-dq+qelQ9akHpcOl/gUVRTxVIOkAJ1wR3QAvb4RsVjS8oVoFjDGTc679wJYmUmknUF5HwMLOgb5O+a3KxfWapPQ==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } }, "node_modules/memfs": { - "version": "3.5.3", - "resolved": "https://registry.npmjs.org/memfs/-/memfs-3.5.3.tgz", - "integrity": "sha512-UERzLsxzllchadvbPs5aolHh65ISpKpM+ccLbOJ8/vvpBKmAWf+la7dXFy7Mr0ySHbdHrFv5kGFCUHHe6GFEmw==", + "version": "4.17.2", + "resolved": "https://registry.npmjs.org/memfs/-/memfs-4.17.2.tgz", + "integrity": "sha512-NgYhCOWgovOXSzvYgUW0LQ7Qy72rWQMGGFJDoWg4G30RHd3z77VbYdtJ4fembJXBy8pMIUA31XNAupobOQlwdg==", "dev": true, + "license": "Apache-2.0", "dependencies": { - "fs-monkey": "^1.0.4" + "@jsonjoy.com/json-pack": "^1.0.3", + "@jsonjoy.com/util": "^1.3.0", + "tree-dump": "^1.0.1", + "tslib": "^2.0.0" }, "engines": { "node": ">= 4.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/streamich" + } + }, + "node_modules/memfs/node_modules/@jsonjoy.com/base64": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@jsonjoy.com/base64/-/base64-1.1.2.tgz", + "integrity": "sha512-q6XAnWQDIMA3+FTiOYajoYqySkO+JSat0ytXGSuRdq9uXE7o92gzuQwQM14xaCRlBLGq3v5miDGC4vkVTn54xA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=10.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/streamich" + }, + "peerDependencies": { + "tslib": "2" + } + }, + "node_modules/memfs/node_modules/@jsonjoy.com/json-pack": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@jsonjoy.com/json-pack/-/json-pack-1.2.0.tgz", + "integrity": "sha512-io1zEbbYcElht3tdlqEOFxZ0dMTYrHz9iMf0gqn1pPjZFTCgM5R4R5IMA20Chb2UPYYsxjzs8CgZ7Nb5n2K2rA==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@jsonjoy.com/base64": "^1.1.1", + "@jsonjoy.com/util": "^1.1.2", + "hyperdyperid": "^1.2.0", + "thingies": "^1.20.0" + }, + "engines": { + "node": ">=10.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/streamich" + }, + "peerDependencies": { + "tslib": "2" + } + }, + "node_modules/memfs/node_modules/@jsonjoy.com/util": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@jsonjoy.com/util/-/util-1.6.0.tgz", + "integrity": "sha512-sw/RMbehRhN68WRtcKCpQOPfnH6lLP4GJfqzi3iYej8tnzpZUDr6UkZYJjcjjC0FWEJOJbyM3PTIwxucUmDG2A==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=10.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/streamich" + }, + "peerDependencies": { + "tslib": "2" } }, + "node_modules/memfs/node_modules/thingies": { + "version": "1.21.0", + "resolved": "https://registry.npmjs.org/thingies/-/thingies-1.21.0.tgz", + "integrity": "sha512-hsqsJsFMsV+aD4s3CWKk85ep/3I9XzYV/IXaSouJMYIoDlgyi11cBhsqYe9/geRfB0YIikBQg6raRaM+nIMP9g==", + "dev": true, + "license": "Unlicense", + "engines": { + "node": ">=10.18" + }, + "peerDependencies": { + "tslib": "^2" + } + }, + "node_modules/memfs/node_modules/tree-dump": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/tree-dump/-/tree-dump-1.0.3.tgz", + "integrity": "sha512-il+Cv80yVHFBwokQSfd4bldvr1Md951DpgAGfmhydt04L+YzHgubm2tQ7zueWDcGENKHq0ZvGFR/hjvNXilHEg==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=10.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/streamich" + }, + "peerDependencies": { + "tslib": "2" + } + }, + "node_modules/memfs/node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "dev": true, + "license": "0BSD" + }, "node_modules/merge-descriptors": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.3.tgz", "integrity": "sha512-gaNvAS7TZ897/rVaZ0nMtAyxNyi/pdbjbAwUpFQpN70GqnVfOiXpeUUMKRBmzXaSQ8DdTX4/0ms62r2K+hE6mQ==", "dev": true, + "license": "MIT", "funding": { "url": "https://github.com/sponsors/sindresorhus" } @@ -2369,6 +2501,7 @@ "resolved": "https://registry.npmjs.org/methods/-/methods-1.1.2.tgz", "integrity": "sha512-iclAHeNqNm68zFtnZ0e+1L2yUIdvzNoauKU4WBA3VvH/vPFieF7qfRlwUZU+DA9P9bPXIS90ulxoUoCH23sV2w==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -2392,6 +2525,7 @@ "resolved": "https://registry.npmjs.org/mime/-/mime-1.6.0.tgz", "integrity": "sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg==", "dev": true, + "license": "MIT", "bin": { "mime": "cli.js" }, @@ -2414,19 +2548,10 @@ "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", "dev": true, "dependencies": { - "mime-db": "1.52.0" - }, - "engines": { - "node": ">= 0.6" - } - }, - "node_modules/mimic-fn": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/mimic-fn/-/mimic-fn-2.1.0.tgz", - "integrity": "sha512-OqbOk5oEQeAZ8WXWydlu9HJjz9WVdEIvamMCcXmuqUYjTknH/sqsWvhQ3vgwKFRR1HpjvNBKQ37nbJgYzGqGcg==", - "dev": true, + "mime-db": "1.52.0" + }, "engines": { - "node": ">=6" + "node": ">= 0.6" } }, "node_modules/minimalistic-assert": { @@ -2435,18 +2560,6 @@ "integrity": "sha512-UtJcAD4yEaGtjPezWuO9wC4nwUnVH/8/Im3yEHQP4b67cXlD/Qr9hdITCU1xDbSEXg2XKNaP8jsReV7vQd00/A==", "dev": true }, - "node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", - "dev": true, - "dependencies": { - "brace-expansion": "^1.1.7" - }, - "engines": { - "node": "*" - } - }, "node_modules/ms": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", @@ -2458,6 +2571,7 @@ "resolved": "https://registry.npmjs.org/multicast-dns/-/multicast-dns-7.2.5.tgz", "integrity": "sha512-2eznPJP8z2BFLX50tf0LuODrpINqP1RVIm/CObbTcBRITQgmC/TjcREF1NeTBzIcR5XO/ukWo+YHOjBbFwIupg==", "dev": true, + "license": "MIT", "dependencies": { "dns-packet": "^5.2.2", "thunky": "^1.0.2" @@ -2486,6 +2600,7 @@ "resolved": "https://registry.npmjs.org/node-forge/-/node-forge-1.3.1.tgz", "integrity": "sha512-dPEtOeMvF9VMcYV/1Wb8CPoVAXtp6MKMlcbAt4ddqmGqUJ6fQZFXkNZNkNlfevtNkGtaSoXf/vNNNSvgrdXwtA==", "dev": true, + "license": "(BSD-3-Clause OR GPL-2.0)", "engines": { "node": ">= 6.13.0" } @@ -2505,23 +2620,12 @@ "node": ">=0.10.0" } }, - "node_modules/npm-run-path": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/npm-run-path/-/npm-run-path-4.0.1.tgz", - "integrity": "sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw==", - "dev": true, - "dependencies": { - "path-key": "^3.0.0" - }, - "engines": { - "node": ">=8" - } - }, "node_modules/object-inspect": { - "version": "1.13.2", - "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.2.tgz", - "integrity": "sha512-IRZSRuzJiynemAXPYtPe5BoI/RESNYR7TYm50MC5Mqbd3Jmw5y790sErYw3V6SryFJD64b74qQQs9wn5Bg/k3g==", + "version": "1.13.4", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz", + "integrity": "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.4" }, @@ -2540,6 +2644,7 @@ "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz", "integrity": "sha512-oVlzkg3ENAhCk2zdv7IJwd/QUD4z2RxRwpkcGY8psCVcCYZNq4wYnVWALHM+brtuJjePWiYF/ClmuDr8Ch5+kg==", "dev": true, + "license": "MIT", "dependencies": { "ee-first": "1.1.1" }, @@ -2548,50 +2653,28 @@ } }, "node_modules/on-headers": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.0.2.tgz", - "integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.1.0.tgz", + "integrity": "sha512-737ZY3yNnXy37FHkQxPzt4UZ2UWPWiCZWLvFZ4fu5cueciegX0zGPnrlY6bwRg4FdQOe9YU8MkmJwGhoMybl8A==", "dev": true, "engines": { "node": ">= 0.8" } }, - "node_modules/once": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", - "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", - "dev": true, - "dependencies": { - "wrappy": "1" - } - }, - "node_modules/onetime": { - "version": "5.1.2", - "resolved": "https://registry.npmjs.org/onetime/-/onetime-5.1.2.tgz", - "integrity": "sha512-kbpaSSGJTWdAY5KPVeMOKXSrPtr8C8C7wodJbcsd51jRnmD+GZu8Y0VoU6Dm5Z4vWr0Ig/1NKuWRKf7j5aaYSg==", - "dev": true, - "dependencies": { - "mimic-fn": "^2.1.0" - }, - "engines": { - "node": ">=6" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/open": { - "version": "8.4.2", - "resolved": "https://registry.npmjs.org/open/-/open-8.4.2.tgz", - "integrity": "sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==", + "version": "10.1.2", + "resolved": "https://registry.npmjs.org/open/-/open-10.1.2.tgz", + "integrity": "sha512-cxN6aIDPz6rm8hbebcP7vrQNhvRcveZoJU72Y7vskh4oIm+BZwBECnx5nTmrlres1Qapvx27Qo1Auukpf8PKXw==", "dev": true, + "license": "MIT", "dependencies": { - "define-lazy-prop": "^2.0.0", - "is-docker": "^2.1.1", - "is-wsl": "^2.2.0" + "default-browser": "^5.2.1", + "define-lazy-prop": "^3.0.0", + "is-inside-container": "^1.0.0", + "is-wsl": "^3.1.0" }, "engines": { - "node": ">=12" + "node": ">=18" }, "funding": { "url": "https://github.com/sponsors/sindresorhus" @@ -2625,16 +2708,21 @@ } }, "node_modules/p-retry": { - "version": "4.6.2", - "resolved": "https://registry.npmjs.org/p-retry/-/p-retry-4.6.2.tgz", - "integrity": "sha512-312Id396EbJdvRONlngUx0NydfrIQ5lsYu0znKVUzVvArzEIt08V1qhtyESbGVd1FGX7UKtiFp5uwKZdM8wIuQ==", + "version": "6.2.1", + "resolved": "https://registry.npmjs.org/p-retry/-/p-retry-6.2.1.tgz", + "integrity": "sha512-hEt02O4hUct5wtwg4H4KcWgDdm+l1bOaEy/hWzd8xtXB9BqxTWBBhb+2ImAtH4Cv4rPjV76xN3Zumqk3k3AhhQ==", "dev": true, + "license": "MIT", "dependencies": { - "@types/retry": "0.12.0", + "@types/retry": "0.12.2", + "is-network-error": "^1.0.0", "retry": "^0.13.1" }, "engines": { - "node": ">=8" + "node": ">=16.17" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/p-try": { @@ -2664,15 +2752,6 @@ "node": ">=8" } }, - "node_modules/path-is-absolute": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", - "integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==", - "dev": true, - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/path-key": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", @@ -2689,10 +2768,11 @@ "dev": true }, "node_modules/path-to-regexp": { - "version": "0.1.10", - "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.10.tgz", - "integrity": "sha512-7lf7qcQidTku0Gu3YDPc8DJ1q7OOucfa/BSsIwjuh56VU7katFvuM8hULfkwB3Fns/rsVF7PwPKVw1sl5KQS9w==", - "dev": true + "version": "0.1.12", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.12.tgz", + "integrity": "sha512-RA1GjUVMnvYFxuqovrEqZoxxW5NUZqbwKtYz/Tt7nXerk0LbLblQmrsgdeOxV5SFHf0UDggjS/bSeOZwt1pmEQ==", + "dev": true, + "license": "MIT" }, "node_modules/path-type": { "version": "6.0.0", @@ -2748,6 +2828,7 @@ "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", "integrity": "sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg==", "dev": true, + "license": "MIT", "dependencies": { "forwarded": "0.2.0", "ipaddr.js": "1.9.1" @@ -2761,6 +2842,7 @@ "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz", "integrity": "sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.10" } @@ -2779,6 +2861,7 @@ "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz", "integrity": "sha512-+38qI9SOr8tfZ4QmJNplMUxqjbe7LKvvZgWdExBOmd+egZTtjLB67Gu0HRX3u/XOq7UU2Nx6nsjvS16Z9uwfpg==", "dev": true, + "license": "BSD-3-Clause", "dependencies": { "side-channel": "^1.0.6" }, @@ -2824,6 +2907,7 @@ "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.1.tgz", "integrity": "sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -2833,6 +2917,7 @@ "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.2.tgz", "integrity": "sha512-8zGqypfENjCIqGhgXToC8aB2r7YrBX+AQAfIPs/Mlk+BtPTztOvTS01NRW/3Eh60J+a48lt8qsCzirQ6loCVfA==", "dev": true, + "license": "MIT", "dependencies": { "bytes": "3.1.2", "http-errors": "2.0.0", @@ -2843,15 +2928,6 @@ "node": ">= 0.8" } }, - "node_modules/raw-body/node_modules/bytes": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", - "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", - "dev": true, - "engines": { - "node": ">= 0.8" - } - }, "node_modules/readable-stream": { "version": "2.3.6", "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-2.3.6.tgz", @@ -2872,6 +2948,7 @@ "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", "dev": true, + "license": "MIT", "dependencies": { "picomatch": "^2.2.1" }, @@ -2949,6 +3026,7 @@ "resolved": "https://registry.npmjs.org/retry/-/retry-0.13.1.tgz", "integrity": "sha512-XQBQ3I8W1Cge0Seh+6gjj03LbmRFWuoszgK9ooCpwYIrhhoO80pfq4cUkU5DkknwfOfFteRwlZ56PYOGYyFWdg==", "dev": true, + "license": "MIT", "engines": { "node": ">= 4" } @@ -2964,19 +3042,17 @@ "node": ">=0.10.0" } }, - "node_modules/rimraf": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz", - "integrity": "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==", + "node_modules/run-applescript": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/run-applescript/-/run-applescript-7.0.0.tgz", + "integrity": "sha512-9by4Ij99JUr/MCFBUkDKLWK3G9HVXmabKz9U5MlIAIuvuzkiOicRYs8XJLxX+xahD+mLiiCYDqF9dKAgtzKP1A==", "dev": true, - "dependencies": { - "glob": "^7.1.3" - }, - "bin": { - "rimraf": "bin.js" + "license": "MIT", + "engines": { + "node": ">=18" }, "funding": { - "url": "https://github.com/sponsors/isaacs" + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/run-parallel": { @@ -3013,7 +3089,8 @@ "version": "2.1.2", "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/schema-utils": { "version": "3.3.0", @@ -3040,11 +3117,13 @@ "dev": true }, "node_modules/selfsigned": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/selfsigned/-/selfsigned-2.1.1.tgz", - "integrity": "sha512-GSL3aowiF7wa/WtSFwnUrludWFoNhftq8bUkH9pkzjpN2XSPOAYEgg6e0sS9s0rZwgJzJiQRPU18A6clnoW5wQ==", + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/selfsigned/-/selfsigned-2.4.1.tgz", + "integrity": "sha512-th5B4L2U+eGLq1TVh7zNRGBapioSORUeymIydxgFpwww9d2qyKvtuPU2jJuHvYAwwqi2Y596QBL3eEqcPEYL8Q==", "dev": true, + "license": "MIT", "dependencies": { + "@types/node-forge": "^1.3.0", "node-forge": "^1" }, "engines": { @@ -3056,6 +3135,7 @@ "resolved": "https://registry.npmjs.org/send/-/send-0.19.0.tgz", "integrity": "sha512-dW41u5VfLXu8SJh5bwRmyYUbAoSB3c9uQh6L8h/KtsFREPWpbX1lrljJo186Jc4nmci/sGUZ9a0a0J2zgfq2hw==", "dev": true, + "license": "MIT", "dependencies": { "debug": "2.6.9", "depd": "2.0.0", @@ -3080,6 +3160,7 @@ "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", "dev": true, + "license": "MIT", "dependencies": { "ms": "2.0.0" } @@ -3088,13 +3169,25 @@ "version": "2.0.0", "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", "integrity": "sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/send/node_modules/depd": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/send/node_modules/encodeurl": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", + "integrity": "sha512-TPJXq8JqFaVYm2CWmPvnP2Iyo4ZSM7/QKcSmuMLDObfpH5fi7RUGmd/rTDf+rut/saiDiQEeVTNgAmJEdAOx0w==", + "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -3103,13 +3196,15 @@ "version": "2.1.3", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/send/node_modules/statuses": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -3177,6 +3272,7 @@ "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.16.2.tgz", "integrity": "sha512-VqpjJZKadQB/PEbEwvFdO43Ax5dFBZ2UECszz8bQ7pi7wt//PWe1P6MN7eCnjsatYtBT6EuiClbjSWP2WrIoTw==", "dev": true, + "license": "MIT", "dependencies": { "encodeurl": "~2.0.0", "escape-html": "~1.0.3", @@ -3187,37 +3283,12 @@ "node": ">= 0.8.0" } }, - "node_modules/serve-static/node_modules/encodeurl": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", - "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", - "dev": true, - "engines": { - "node": ">= 0.8" - } - }, - "node_modules/set-function-length": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", - "integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", - "dev": true, - "dependencies": { - "define-data-property": "^1.1.4", - "es-errors": "^1.3.0", - "function-bind": "^1.1.2", - "get-intrinsic": "^1.2.4", - "gopd": "^1.0.1", - "has-property-descriptors": "^1.0.2" - }, - "engines": { - "node": ">= 0.4" - } - }, "node_modules/setprototypeof": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", "integrity": "sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw==", - "dev": true + "dev": true, + "license": "ISC" }, "node_modules/shallow-clone": { "version": "3.0.1", @@ -3253,24 +3324,30 @@ } }, "node_modules/shell-quote": { - "version": "1.8.1", - "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.1.tgz", - "integrity": "sha512-6j1W9l1iAs/4xYBI1SYOVZyFcCis9b4KCLQ8fgAGG07QvzaRLVVRQvAy85yNmmZSjYjg4MWh4gNvlPujU/5LpA==", + "version": "1.8.3", + "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.3.tgz", + "integrity": "sha512-ObmnIF4hXNg1BqhnHmgbDETF8dLPCggZWBjkQfhZpbszZnYur5DUljTcCHii5LC3J5E0yeO/1LIMyH+UvHQgyw==", "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, "funding": { "url": "https://github.com/sponsors/ljharb" } }, "node_modules/side-channel": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.6.tgz", - "integrity": "sha512-fDW/EZ6Q9RiO8eFG8Hj+7u/oW+XrPTIChwCOM2+th2A6OblDtYYIpve9m+KvI9Z4C9qSEXlaGR6bTEYHReuglA==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.1.0.tgz", + "integrity": "sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==", "dev": true, + "license": "MIT", "dependencies": { - "call-bind": "^1.0.7", "es-errors": "^1.3.0", - "get-intrinsic": "^1.2.4", - "object-inspect": "^1.13.1" + "object-inspect": "^1.13.3", + "side-channel-list": "^1.0.0", + "side-channel-map": "^1.0.1", + "side-channel-weakmap": "^1.0.2" }, "engines": { "node": ">= 0.4" @@ -3279,11 +3356,61 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/signal-exit": { - "version": "3.0.7", - "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-3.0.7.tgz", - "integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", - "dev": true + "node_modules/side-channel-list": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/side-channel-list/-/side-channel-list-1.0.0.tgz", + "integrity": "sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-map": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/side-channel-map/-/side-channel-map-1.0.1.tgz", + "integrity": "sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-weakmap": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/side-channel-weakmap/-/side-channel-weakmap-1.0.2.tgz", + "integrity": "sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3", + "side-channel-map": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } }, "node_modules/slash": { "version": "5.1.0", @@ -3390,15 +3517,6 @@ "safe-buffer": "~5.1.0" } }, - "node_modules/strip-final-newline": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/strip-final-newline/-/strip-final-newline-2.0.0.tgz", - "integrity": "sha512-BrpvfNAE3dcvq7ll3xVumzjKjZQ5tI1sEUIKr3Uoks0XUl45St3FlatVqef9prk4jRDzhW6WZg+3bk93y6pLjA==", - "dev": true, - "engines": { - "node": ">=6" - } - }, "node_modules/supports-color": { "version": "8.1.1", "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", @@ -3491,7 +3609,8 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/thunky/-/thunky-1.1.0.tgz", "integrity": "sha512-eHY7nBftgThBqOyHGVN+l8gF0BucP09fMo0oO/Lb0w1OF80dJv+lDVpXG60WMQvkcxAkNybKsrEIE3ZtKGmPrA==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/to-regex-range": { "version": "5.0.1", @@ -3510,6 +3629,7 @@ "resolved": "https://registry.npmjs.org/toidentifier/-/toidentifier-1.0.1.tgz", "integrity": "sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA==", "dev": true, + "license": "MIT", "engines": { "node": ">=0.6" } @@ -3525,6 +3645,7 @@ "resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.18.tgz", "integrity": "sha512-TkRKr9sUTxEH8MdfuCSP7VizJyzRNMjj2J2do2Jr3Kym598JVdEksuzPQCnlFPW4ky9Q+iA+ma9BGm06XQBy8g==", "dev": true, + "license": "MIT", "dependencies": { "media-typer": "0.3.0", "mime-types": "~2.1.24" @@ -3551,6 +3672,7 @@ "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz", "integrity": "sha512-pjy2bYhSsufwWlKwPc+l3cN7+wuJlK6uz0YdJEOlQDbl6jo/YlPi4mb8agUkVC8BF7V8NuzeyPNqRksA3hztKQ==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -3605,6 +3727,7 @@ "resolved": "https://registry.npmjs.org/utils-merge/-/utils-merge-1.0.1.tgz", "integrity": "sha512-pMZTvIkT1d+TFGvDOqodOclx0QWkkgi6Tdoa8gC8ffGAAqz9pzPTZWAybbsHHoED/ztMtkv/VoYTYyShUn81hA==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.4.0" } @@ -3750,38 +3873,46 @@ } }, "node_modules/webpack-dev-middleware": { - "version": "5.3.4", - "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.4.tgz", - "integrity": "sha512-BVdTqhhs+0IfoeAf7EoH5WE+exCmqGerHfDM0IL096Px60Tq2Mn9MAbnaGUe6HiMa41KMCYF19gyzZmBcq/o4Q==", + "version": "7.4.2", + "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-7.4.2.tgz", + "integrity": "sha512-xOO8n6eggxnwYpy1NlzUKpvrjfJTvae5/D6WOK0S2LSo7vjmo5gCM1DbLUmFqrMTJP+W/0YZNctm7jasWvLuBA==", "dev": true, + "license": "MIT", "dependencies": { "colorette": "^2.0.10", - "memfs": "^3.4.3", + "memfs": "^4.6.0", "mime-types": "^2.1.31", + "on-finished": "^2.4.1", "range-parser": "^1.2.1", "schema-utils": "^4.0.0" }, "engines": { - "node": ">= 12.13.0" + "node": ">= 18.12.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/webpack" }, "peerDependencies": { - "webpack": "^4.0.0 || ^5.0.0" + "webpack": "^5.0.0" + }, + "peerDependenciesMeta": { + "webpack": { + "optional": true + } } }, "node_modules/webpack-dev-middleware/node_modules/ajv": { - "version": "8.12.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", - "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", "dev": true, + "license": "MIT", "dependencies": { - "fast-deep-equal": "^3.1.1", + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" + "require-from-string": "^2.0.2" }, "funding": { "type": "github", @@ -3793,6 +3924,7 @@ "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", "dev": true, + "license": "MIT", "dependencies": { "fast-deep-equal": "^3.1.3" }, @@ -3804,13 +3936,15 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/webpack-dev-middleware/node_modules/schema-utils": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.2.0.tgz", - "integrity": "sha512-L0jRsrPpjdckP3oPug3/VxNKt2trR8TcabrM6FOAAlvC/9Phcmm+cuAgTlxBqdBR1WJx7Naj9WHw+aOmheSVbw==", + "version": "4.3.2", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.2.tgz", + "integrity": "sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==", "dev": true, + "license": "MIT", "dependencies": { "@types/json-schema": "^7.0.9", "ajv": "^8.9.0", @@ -3818,7 +3952,7 @@ "ajv-keywords": "^5.1.0" }, "engines": { - "node": ">= 12.13.0" + "node": ">= 10.13.0" }, "funding": { "type": "opencollective", @@ -3826,54 +3960,53 @@ } }, "node_modules/webpack-dev-server": { - "version": "4.15.1", - "resolved": "https://registry.npmjs.org/webpack-dev-server/-/webpack-dev-server-4.15.1.tgz", - "integrity": "sha512-5hbAst3h3C3L8w6W4P96L5vaV0PxSmJhxZvWKYIdgxOQm8pNZ5dEOmmSLBVpP85ReeyRt6AS1QJNyo/oFFPeVA==", - "dev": true, - "dependencies": { - "@types/bonjour": "^3.5.9", - "@types/connect-history-api-fallback": "^1.3.5", - "@types/express": "^4.17.13", - "@types/serve-index": "^1.9.1", - "@types/serve-static": "^1.13.10", - "@types/sockjs": "^0.3.33", - "@types/ws": "^8.5.5", + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/webpack-dev-server/-/webpack-dev-server-5.2.1.tgz", + "integrity": "sha512-ml/0HIj9NLpVKOMq+SuBPLHcmbG+TGIjXRHsYfZwocUBIqEvws8NnS/V9AFQ5FKP+tgn5adwVwRrTEpGL33QFQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/bonjour": "^3.5.13", + "@types/connect-history-api-fallback": "^1.5.4", + "@types/express": "^4.17.21", + "@types/express-serve-static-core": "^4.17.21", + "@types/serve-index": "^1.9.4", + "@types/serve-static": "^1.15.5", + "@types/sockjs": "^0.3.36", + "@types/ws": "^8.5.10", "ansi-html-community": "^0.0.8", - "bonjour-service": "^1.0.11", - "chokidar": "^3.5.3", + "bonjour-service": "^1.2.1", + "chokidar": "^3.6.0", "colorette": "^2.0.10", "compression": "^1.7.4", "connect-history-api-fallback": "^2.0.0", - "default-gateway": "^6.0.3", - "express": "^4.17.3", + "express": "^4.21.2", "graceful-fs": "^4.2.6", - "html-entities": "^2.3.2", - "http-proxy-middleware": "^2.0.3", - "ipaddr.js": "^2.0.1", - "launch-editor": "^2.6.0", - "open": "^8.0.9", - "p-retry": "^4.5.0", - "rimraf": "^3.0.2", - "schema-utils": "^4.0.0", - "selfsigned": "^2.1.1", + "http-proxy-middleware": "^2.0.7", + "ipaddr.js": "^2.1.0", + "launch-editor": "^2.6.1", + "open": "^10.0.3", + "p-retry": "^6.2.0", + "schema-utils": "^4.2.0", + "selfsigned": "^2.4.1", "serve-index": "^1.9.1", "sockjs": "^0.3.24", "spdy": "^4.0.2", - "webpack-dev-middleware": "^5.3.1", - "ws": "^8.13.0" + "webpack-dev-middleware": "^7.4.2", + "ws": "^8.18.0" }, "bin": { "webpack-dev-server": "bin/webpack-dev-server.js" }, "engines": { - "node": ">= 12.13.0" + "node": ">= 18.12.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/webpack" }, "peerDependencies": { - "webpack": "^4.37.0 || ^5.0.0" + "webpack": "^5.0.0" }, "peerDependenciesMeta": { "webpack": { @@ -4003,17 +4136,12 @@ "integrity": "sha512-CC1bOL87PIWSBhDcTrdeLo6eGT7mCFtrg0uIJtqJUFyK+eJnzl8A1niH56uu7KMa5XFrtiV+AQuHO3n7DsHnLQ==", "dev": true }, - "node_modules/wrappy": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", - "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", - "dev": true - }, "node_modules/ws": { - "version": "8.17.1", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", - "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", + "version": "8.18.2", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.2.tgz", + "integrity": "sha512-DMricUmwGZUVr++AEAe2uiVM7UoO9MAVZMDu05UQOaUII0lp+zOzLLU4Xqh/JvTqklB1T4uELaaPBKyjE1r4fQ==", "dev": true, + "license": "MIT", "engines": { "node": ">=10.0.0" }, @@ -4088,9 +4216,9 @@ } }, "@leichtgewicht/ip-codec": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/@leichtgewicht/ip-codec/-/ip-codec-2.0.4.tgz", - "integrity": "sha512-Hcv+nVC0kZnQ3tD9GVu5xSMR4VVYOteQIr/hwFPVEvPdlXqgGEuRjiheChHgdM+JyqdgNcmzZOX/tnl0JOiI7A==", + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@leichtgewicht/ip-codec/-/ip-codec-2.0.5.tgz", + "integrity": "sha512-Vo+PSpZG2/fmgmiNzYK9qWRh8h/CHrwD0mo1h1DzL4yzHNSfWYujGTYsWGreD000gcgmZ7K4Ys6Tx9TxtsKdDw==", "dev": true }, "@nodelib/fs.scandir": { @@ -4136,9 +4264,9 @@ } }, "@types/bonjour": { - "version": "3.5.11", - "resolved": "https://registry.npmjs.org/@types/bonjour/-/bonjour-3.5.11.tgz", - "integrity": "sha512-isGhjmBtLIxdHBDl2xGwUzEM8AOyOvWsADWq7rqirdi/ZQoHnLWErHvsThcEzTX8juDRiZtzp2Qkv5bgNh6mAg==", + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@types/bonjour/-/bonjour-3.5.13.tgz", + "integrity": "sha512-z9fJ5Im06zvUL548KvYNecEVlA7cVDkGUi6kZusb04mpyEFKCIZJvloCcmpmLaIahDpOQGHaHmG6imtPMmPXGQ==", "dev": true, "requires": { "@types/node": "*" @@ -4154,9 +4282,9 @@ } }, "@types/connect-history-api-fallback": { - "version": "1.5.1", - "resolved": "https://registry.npmjs.org/@types/connect-history-api-fallback/-/connect-history-api-fallback-1.5.1.tgz", - "integrity": "sha512-iaQslNbARe8fctL5Lk+DsmgWOM83lM+7FzP0eQUJs1jd3kBE8NWqBTIT2S8SqQOJjxvt2eyIjpOuYeRXq2AdMw==", + "version": "1.5.4", + "resolved": "https://registry.npmjs.org/@types/connect-history-api-fallback/-/connect-history-api-fallback-1.5.4.tgz", + "integrity": "sha512-n6Cr2xS1h4uAulPRdlw6Jl6s1oG8KrVilPN2yUITEs+K48EzMJJ3W1xy8K5eWuFvjp3R74AOIGSmp2UfBJ8HFw==", "dev": true, "requires": { "@types/express-serve-static-core": "*", @@ -4170,9 +4298,9 @@ "dev": true }, "@types/express": { - "version": "4.17.17", - "resolved": "https://registry.npmjs.org/@types/express/-/express-4.17.17.tgz", - "integrity": "sha512-Q4FmmuLGBG58btUnfS1c1r/NQdlp3DMfGDGig8WhfpA2YRUtEkxAjkZb0yvplJGYdF1fsQ81iMDcH24sSCNC/Q==", + "version": "4.17.22", + "resolved": "https://registry.npmjs.org/@types/express/-/express-4.17.22.tgz", + "integrity": "sha512-eZUmSnhRX9YRSkplpz0N+k6NljUUn5l3EWZIKZvYzhvMphEuNiyyy1viH/ejgt66JWgALwC/gtSUAeQKtSwW/w==", "dev": true, "requires": { "@types/body-parser": "*", @@ -4226,6 +4354,15 @@ "integrity": "sha512-HksnYH4Ljr4VQgEy2lTStbCKv/P590tmPe5HqOnv9Gprffgv5WXAY+Y5Gqniu0GGqeTCUdBnzC3QSrzPkBkAMA==", "dev": true }, + "@types/node-forge": { + "version": "1.3.11", + "resolved": "https://registry.npmjs.org/@types/node-forge/-/node-forge-1.3.11.tgz", + "integrity": "sha512-FQx220y22OKNTqaByeBGqHWYz4cl94tpcxeFdvBo3wjG6XPBuZ0BNgNZRV5J5TFmmcsJ4IzsLkmGRiQbnYsBEQ==", + "dev": true, + "requires": { + "@types/node": "*" + } + }, "@types/qs": { "version": "6.9.8", "resolved": "https://registry.npmjs.org/@types/qs/-/qs-6.9.8.tgz", @@ -4239,9 +4376,9 @@ "dev": true }, "@types/retry": { - "version": "0.12.0", - "resolved": "https://registry.npmjs.org/@types/retry/-/retry-0.12.0.tgz", - "integrity": "sha512-wWKOClTTiizcZhXnPY4wikVAwmdYHp8q6DmC+EJUzAMsycb7HB32Kh9RN4+0gExjmPmZSAQjgURXIGATPegAvA==", + "version": "0.12.2", + "resolved": "https://registry.npmjs.org/@types/retry/-/retry-0.12.2.tgz", + "integrity": "sha512-XISRgDJ2Tc5q4TRqvgJtzsRkFYNJzZrhTdtMoGVBttwzzQJkPnS3WWTFc7kuDRoPtPakl+T+OfdEUjYJj7Jbow==", "dev": true }, "@types/send": { @@ -4255,38 +4392,38 @@ } }, "@types/serve-index": { - "version": "1.9.1", - "resolved": "https://registry.npmjs.org/@types/serve-index/-/serve-index-1.9.1.tgz", - "integrity": "sha512-d/Hs3nWDxNL2xAczmOVZNj92YZCS6RGxfBPjKzuu/XirCgXdpKEb88dYNbrYGint6IVWLNP+yonwVAuRC0T2Dg==", + "version": "1.9.4", + "resolved": "https://registry.npmjs.org/@types/serve-index/-/serve-index-1.9.4.tgz", + "integrity": "sha512-qLpGZ/c2fhSs5gnYsQxtDEq3Oy8SXPClIXkW5ghvAvsNuVSA8k+gCONcUCS/UjLEYvYps+e8uBtfgXgvhwfNug==", "dev": true, "requires": { "@types/express": "*" } }, "@types/serve-static": { - "version": "1.15.2", - "resolved": "https://registry.npmjs.org/@types/serve-static/-/serve-static-1.15.2.tgz", - "integrity": "sha512-J2LqtvFYCzaj8pVYKw8klQXrLLk7TBZmQ4ShlcdkELFKGwGMfevMLneMMRkMgZxotOD9wg497LpC7O8PcvAmfw==", + "version": "1.15.7", + "resolved": "https://registry.npmjs.org/@types/serve-static/-/serve-static-1.15.7.tgz", + "integrity": "sha512-W8Ym+h8nhuRwaKPaDw34QUkwsGi6Rc4yYqvKFo5rm2FUEhCFbzVWrxXUxuKK8TASjWsysJY0nsmNCGhCOIsrOw==", "dev": true, "requires": { "@types/http-errors": "*", - "@types/mime": "*", - "@types/node": "*" + "@types/node": "*", + "@types/send": "*" } }, "@types/sockjs": { - "version": "0.3.33", - "resolved": "https://registry.npmjs.org/@types/sockjs/-/sockjs-0.3.33.tgz", - "integrity": "sha512-f0KEEe05NvUnat+boPTZ0dgaLZ4SfSouXUgv5noUiefG2ajgKjmETo9ZJyuqsl7dfl2aHlLJUiki6B4ZYldiiw==", + "version": "0.3.36", + "resolved": "https://registry.npmjs.org/@types/sockjs/-/sockjs-0.3.36.tgz", + "integrity": "sha512-MK9V6NzAS1+Ud7JV9lJLFqW85VbC9dq3LmwZCuBe4wBDgKC0Kj/jd8Xl+nSviU+Qc3+m7umHHyHg//2KSa0a0Q==", "dev": true, "requires": { "@types/node": "*" } }, "@types/ws": { - "version": "8.5.5", - "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.5.5.tgz", - "integrity": "sha512-lwhs8hktwxSjf9UaZ9tG5M03PGogvFaH8gUgLNbN9HKIg0dvv6q+gkSuJ8HN4/VbyxkuLzCjlN7GquQ0gUJfIg==", + "version": "8.18.1", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz", + "integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==", "dev": true, "requires": { "@types/node": "*" @@ -4559,15 +4696,9 @@ } }, "array-flatten": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-2.1.2.tgz", - "integrity": "sha512-hNfzcOV8W4NdualtqBFPyVO+54DSJuZGY9qT4pRroB6S9e3iiido2ISIC5h9R2sPJ8H3FHCIiEnsv1lPXO3KtQ==", - "dev": true - }, - "balanced-match": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", - "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", + "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==", "dev": true }, "batch": { @@ -4577,9 +4708,9 @@ "dev": true }, "binary-extensions": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", - "integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==", + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.3.0.tgz", + "integrity": "sha512-Ceh+7ox5qe7LJuLHoY0feh3pHuUDHAcRUeyL2VYghZwfpkNIy/+8Ocg0a3UuSoYzavmylwuLWQOf3hl0jjMMIw==", "dev": true }, "body-parser": { @@ -4602,12 +4733,6 @@ "unpipe": "1.0.0" }, "dependencies": { - "bytes": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", - "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", - "dev": true - }, "debug": { "version": "2.6.9", "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", @@ -4626,25 +4751,13 @@ } }, "bonjour-service": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/bonjour-service/-/bonjour-service-1.1.1.tgz", - "integrity": "sha512-Z/5lQRMOG9k7W+FkeGTNjh7htqn/2LMnfOvBZ8pynNZCM9MwkQkI3zeI4oz09uWdcgmgHugVvBqxGg4VQJ5PCg==", - "dev": true, - "requires": { - "array-flatten": "^2.1.2", - "dns-equal": "^1.0.0", - "fast-deep-equal": "^3.1.3", - "multicast-dns": "^7.2.5" - } - }, - "brace-expansion": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", - "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/bonjour-service/-/bonjour-service-1.3.0.tgz", + "integrity": "sha512-3YuAUiSkWykd+2Azjgyxei8OWf8thdn8AITIog2M4UICzoqfjlqr64WIjEXZllf/W6vK1goqleSR6brGomxQqA==", "dev": true, "requires": { - "balanced-match": "^1.0.0", - "concat-map": "0.0.1" + "fast-deep-equal": "^3.1.3", + "multicast-dns": "^7.2.5" } }, "braces": { @@ -4674,23 +4787,39 @@ "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", "dev": true }, + "bundle-name": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/bundle-name/-/bundle-name-4.1.0.tgz", + "integrity": "sha512-tjwM5exMg6BGRI+kNmTntNsvdZS1X8BFYS6tnJ2hdH0kVxM6/eVZ2xy+FqStSWvYmtfFMDLIxurorHwDKfDz5Q==", + "dev": true, + "requires": { + "run-applescript": "^7.0.0" + } + }, "bytes": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.0.0.tgz", - "integrity": "sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg=", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", + "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", "dev": true }, - "call-bind": { - "version": "1.0.7", - "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", - "integrity": "sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==", + "call-bind-apply-helpers": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", + "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", "dev": true, "requires": { - "es-define-property": "^1.0.0", "es-errors": "^1.3.0", - "function-bind": "^1.1.2", - "get-intrinsic": "^1.2.4", - "set-function-length": "^1.2.1" + "function-bind": "^1.1.2" + } + }, + "call-bound": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/call-bound/-/call-bound-1.0.4.tgz", + "integrity": "sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==", + "dev": true, + "requires": { + "call-bind-apply-helpers": "^1.0.2", + "get-intrinsic": "^1.3.0" } }, "caniuse-lite": { @@ -4700,9 +4829,9 @@ "dev": true }, "chokidar": { - "version": "3.5.3", - "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", - "integrity": "sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw==", + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.6.0.tgz", + "integrity": "sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==", "dev": true, "requires": { "anymatch": "~3.1.2", @@ -4757,17 +4886,17 @@ } }, "compression": { - "version": "1.7.4", - "resolved": "https://registry.npmjs.org/compression/-/compression-1.7.4.tgz", - "integrity": "sha512-jaSIDzP9pZVS4ZfQ+TzvtiWhdpFhE2RDHz8QJkpX9SIpLq88VueF5jJw6t+6CUQcAoA6t+x89MLrWAqpfDE8iQ==", + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.1.tgz", + "integrity": "sha512-9mAqGPHLakhCLeNyxPkK4xVo746zQ/czLH1Ky+vkitMnWfWZps8r0qXuwhwizagCRttsL4lfG4pIOvaWLpAP0w==", "dev": true, "requires": { - "accepts": "~1.3.5", - "bytes": "3.0.0", - "compressible": "~2.0.16", + "bytes": "3.1.2", + "compressible": "~2.0.18", "debug": "2.6.9", - "on-headers": "~1.0.2", - "safe-buffer": "5.1.2", + "negotiator": "~0.6.4", + "on-headers": "~1.1.0", + "safe-buffer": "5.2.1", "vary": "~1.1.2" }, "dependencies": { @@ -4779,15 +4908,21 @@ "requires": { "ms": "2.0.0" } + }, + "negotiator": { + "version": "0.6.4", + "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.4.tgz", + "integrity": "sha512-myRT3DiWPHqho5PrJaIRyaMv2kgYf0mUVgBNOYMuCH5Ki1yEiQaf/ZJuQ62nvpc44wL5WDbTX7yGJi1Neevw8w==", + "dev": true + }, + "safe-buffer": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", + "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", + "dev": true } } }, - "concat-map": { - "version": "0.0.1", - "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", - "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", - "dev": true - }, "connect-history-api-fallback": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/connect-history-api-fallback/-/connect-history-api-fallback-2.0.0.tgz", @@ -4930,30 +5065,26 @@ } } }, - "default-gateway": { - "version": "6.0.3", - "resolved": "https://registry.npmjs.org/default-gateway/-/default-gateway-6.0.3.tgz", - "integrity": "sha512-fwSOJsbbNzZ/CUFpqFBqYfYNLj1NbMPm8MMCIzHjC83iSJRBEGmDUxU+WP661BaBQImeC2yHwXtz+P/O9o+XEg==", + "default-browser": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/default-browser/-/default-browser-5.2.1.tgz", + "integrity": "sha512-WY/3TUME0x3KPYdRRxEJJvXRHV4PyPoUsxtZa78lwItwRQRHhd2U9xOscaT/YTf8uCXIAjeJOFBVEh/7FtD8Xg==", "dev": true, "requires": { - "execa": "^5.0.0" + "bundle-name": "^4.1.0", + "default-browser-id": "^5.0.0" } }, - "define-data-property": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", - "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", - "dev": true, - "requires": { - "es-define-property": "^1.0.0", - "es-errors": "^1.3.0", - "gopd": "^1.0.1" - } + "default-browser-id": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/default-browser-id/-/default-browser-id-5.0.0.tgz", + "integrity": "sha512-A6p/pu/6fyBcA1TRz/GqWYPViplrftcW2gZC9q79ngNCKAeR/X3gcEdXQHl4KNXV+3wgIJ1CPkJQ3IHM6lcsyA==", + "dev": true }, "define-lazy-prop": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", - "integrity": "sha512-Ds09qNh8yw3khSjiJjiUInaGX9xlqZDY7JVryGxdxV7NPeuqQfplOpQ66yJFZut3jLa5zOwkXw1g9EI2uKh4Og==", + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-3.0.0.tgz", + "integrity": "sha512-N+MeXYoqr3pOgn8xfyRPREN7gHakLYjhsHhWGT3fWAiL4IkAt0iDw14QiiEm2bE30c5XX5q0FtAA3CK5f9/BUg==", "dev": true }, "depd": { @@ -4974,12 +5105,6 @@ "integrity": "sha512-ZIzRpLJrOj7jjP2miAtgqIfmzbxa4ZOr5jJc601zklsfEx9oTzmmj2nVpIPRpNlRTIh8lc1kyViIY7BWSGNmKw==", "dev": true }, - "dns-equal": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/dns-equal/-/dns-equal-1.0.0.tgz", - "integrity": "sha512-z+paD6YUQsk+AbGCEM4PrOXSss5gd66QfcVBFTKR/HpFL9jCqikS94HYwKww6fQyO7IxrIIyUu+g0Ka9tUS2Cg==", - "dev": true - }, "dns-packet": { "version": "5.6.1", "resolved": "https://registry.npmjs.org/dns-packet/-/dns-packet-5.6.1.tgz", @@ -4989,6 +5114,17 @@ "@leichtgewicht/ip-codec": "^2.0.1" } }, + "dunder-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", + "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", + "dev": true, + "requires": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + } + }, "ee-first": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz", @@ -5002,9 +5138,9 @@ "dev": true }, "encodeurl": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", - "integrity": "sha512-TPJXq8JqFaVYm2CWmPvnP2Iyo4ZSM7/QKcSmuMLDObfpH5fi7RUGmd/rTDf+rut/saiDiQEeVTNgAmJEdAOx0w==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", + "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", "dev": true }, "enhanced-resolve": { @@ -5024,13 +5160,10 @@ "dev": true }, "es-define-property": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.0.tgz", - "integrity": "sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==", - "dev": true, - "requires": { - "get-intrinsic": "^1.2.4" - } + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz", + "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", + "dev": true }, "es-errors": { "version": "1.3.0", @@ -5044,6 +5177,15 @@ "integrity": "sha512-JUFAyicQV9mXc3YRxPnDlrfBKpqt6hUYzz9/boprUJHs4e4KVr3XwOF70doO6gwXUor6EWZJAyWAfKki84t20Q==", "dev": true }, + "es-object-atoms": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", + "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", + "dev": true, + "requires": { + "es-errors": "^1.3.0" + } + }, "escalade": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz", @@ -5107,27 +5249,10 @@ "integrity": "sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==", "dev": true }, - "execa": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/execa/-/execa-5.1.1.tgz", - "integrity": "sha512-8uSpZZocAZRBAPIEINJj3Lo9HyGitllczc27Eh5YYojjMFMn8yHMDMaUHE2Jqfq05D/wucwI4JGURyXt1vchyg==", - "dev": true, - "requires": { - "cross-spawn": "^7.0.3", - "get-stream": "^6.0.0", - "human-signals": "^2.1.0", - "is-stream": "^2.0.0", - "merge-stream": "^2.0.0", - "npm-run-path": "^4.0.1", - "onetime": "^5.1.2", - "signal-exit": "^3.0.3", - "strip-final-newline": "^2.0.0" - } - }, "express": { - "version": "4.21.1", - "resolved": "https://registry.npmjs.org/express/-/express-4.21.1.tgz", - "integrity": "sha512-YSFlK1Ee0/GC8QaO91tHcDxJiE/X4FbpAyQWkxAvG6AXCuR65YzK8ua6D9hvi/TzUfZMpc+BwuM1IPw8fmQBiQ==", + "version": "4.21.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.21.2.tgz", + "integrity": "sha512-28HqgMZAmih1Czt9ny7qr6ek2qddF4FclbMzwhCREB6OFfH+rXAnuNCwo1/wFvrtbgsQDb4kSbX9de9lFbrXnA==", "dev": true, "requires": { "accepts": "~1.3.8", @@ -5149,7 +5274,7 @@ "methods": "~1.1.2", "on-finished": "2.4.1", "parseurl": "~1.3.3", - "path-to-regexp": "0.1.10", + "path-to-regexp": "0.1.12", "proxy-addr": "~2.0.7", "qs": "6.13.0", "range-parser": "~1.2.1", @@ -5163,12 +5288,6 @@ "vary": "~1.1.2" }, "dependencies": { - "array-flatten": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", - "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==", - "dev": true - }, "debug": { "version": "2.6.9", "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", @@ -5184,12 +5303,6 @@ "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", "dev": true }, - "encodeurl": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", - "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", - "dev": true - }, "safe-buffer": { "version": "5.2.1", "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", @@ -5292,12 +5405,6 @@ "ms": "2.0.0" } }, - "encodeurl": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", - "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", - "dev": true - }, "statuses": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", @@ -5334,18 +5441,6 @@ "integrity": "sha512-zJ2mQYM18rEFOudeV4GShTGIQ7RbzA7ozbU9I/XBpm7kqgMywgmylMwXHxZJmkVoYkna9d2pVXVXPdYTP9ej8Q==", "dev": true }, - "fs-monkey": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/fs-monkey/-/fs-monkey-1.0.4.tgz", - "integrity": "sha512-INM/fWAxMICjttnD0DX1rBvinKskj5G1w+oy/pnm9u/tSlnBrzFonJMcalKJ30P8RRsPzKcCG7Q8l0jx5Fh9YQ==", - "dev": true - }, - "fs.realpath": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", - "integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==", - "dev": true - }, "fsevents": { "version": "2.3.3", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", @@ -5360,36 +5455,31 @@ "dev": true }, "get-intrinsic": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.4.tgz", - "integrity": "sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==", + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", + "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", "dev": true, "requires": { + "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", "function-bind": "^1.1.2", - "has-proto": "^1.0.1", - "has-symbols": "^1.0.3", - "hasown": "^2.0.0" + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" } }, - "get-stream": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-6.0.1.tgz", - "integrity": "sha512-ts6Wi+2j3jQjqi70w5AlN8DFnkSwC+MqmxEzdEALB2qXZYV3X/b1CTfgPLGJNMeAWxdPfU8FO1ms3NUfaHCPYg==", - "dev": true - }, - "glob": { - "version": "7.2.3", - "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", - "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "get-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", + "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", "dev": true, "requires": { - "fs.realpath": "^1.0.0", - "inflight": "^1.0.4", - "inherits": "2", - "minimatch": "^3.1.1", - "once": "^1.3.0", - "path-is-absolute": "^1.0.0" + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" } }, "glob-parent": { @@ -5422,13 +5512,10 @@ } }, "gopd": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz", - "integrity": "sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==", - "dev": true, - "requires": { - "get-intrinsic": "^1.1.3" - } + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz", + "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", + "dev": true }, "graceful-fs": { "version": "4.2.11", @@ -5457,25 +5544,10 @@ "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", "dev": true }, - "has-property-descriptors": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", - "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", - "dev": true, - "requires": { - "es-define-property": "^1.0.0" - } - }, - "has-proto": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.3.tgz", - "integrity": "sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==", - "dev": true - }, "has-symbols": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.0.3.tgz", - "integrity": "sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", + "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", "dev": true }, "hasown": { @@ -5499,12 +5571,6 @@ "wbuf": "^1.1.0" } }, - "html-entities": { - "version": "2.4.0", - "resolved": "https://registry.npmjs.org/html-entities/-/html-entities-2.4.0.tgz", - "integrity": "sha512-igBTJcNNNhvZFRtm8uA6xMY6xYleeDwn3PeBCkDz7tHttv4F2hsDI2aPgNERWzvRcNYHNT3ymRaQzllmXj4YsQ==", - "dev": true - }, "http-deceiver": { "version": "1.2.7", "resolved": "https://registry.npmjs.org/http-deceiver/-/http-deceiver-1.2.7.tgz", @@ -5562,9 +5628,9 @@ } }, "http-proxy-middleware": { - "version": "2.0.6", - "resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.6.tgz", - "integrity": "sha512-ya/UeJ6HVBYxrgYotAZo1KvPWlgB48kUJLDePFeneHsVujFaW5WNj2NgWCAE//B1Dl02BIfYlpNgBy8Kf8Rjmw==", + "version": "2.0.9", + "resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.9.tgz", + "integrity": "sha512-c1IyJYLYppU574+YI7R4QyX2ystMtVXZwIdzazUIPIJsHuWNd+mho2j+bKoHftndicGj9yh+xjd+l0yj7VeT1Q==", "dev": true, "requires": { "@types/http-proxy": "^1.17.8", @@ -5574,10 +5640,10 @@ "micromatch": "^4.0.2" } }, - "human-signals": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/human-signals/-/human-signals-2.1.0.tgz", - "integrity": "sha512-B4FFZ6q/T2jhhksgkbEW3HBvWIfDW85snkQgawt07S7J5QXTk6BkNV+0yAeZrM5QpMAdYlocGoljn0sJ/WQkFw==", + "hyperdyperid": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/hyperdyperid/-/hyperdyperid-1.2.0.tgz", + "integrity": "sha512-Y93lCzHYgGWdrJ66yIktxiaGULYc6oGiABxhcO5AufBeOyoIdZF7bIfLaOrbM0iGIOXQQgxxRrFEnb+Y6w1n4A==", "dev": true }, "iconv-lite": { @@ -5605,16 +5671,6 @@ "resolve-cwd": "^3.0.0" } }, - "inflight": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", - "integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==", - "dev": true, - "requires": { - "once": "^1.3.0", - "wrappy": "1" - } - }, "inherits": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.3.tgz", @@ -5652,9 +5708,9 @@ } }, "is-docker": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/is-docker/-/is-docker-2.2.1.tgz", - "integrity": "sha512-F+i2BKsFrH66iaUFc0woD8sLy8getkwTwtOBjvs56Cx4CgJDeKQeqfz8wAYiSb8JOprWhHH5p77PbmYCvvUuXQ==", + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-docker/-/is-docker-3.0.0.tgz", + "integrity": "sha512-eljcgEDlEns/7AXFosB5K/2nCM4P7FQPkGc/DWLy5rmFEWvZayGrik1d9/QIY5nJ4f9YsVvBkA6kJpHn9rISdQ==", "dev": true }, "is-extglob": { @@ -5672,6 +5728,21 @@ "is-extglob": "^2.1.1" } }, + "is-inside-container": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/is-inside-container/-/is-inside-container-1.0.0.tgz", + "integrity": "sha512-KIYLCCJghfHZxqjYBE7rEy0OBuTd5xCHS7tHVgvCLkx7StIoaxwNW3hCALgEUjFfeRk+MG/Qxmp/vtETEF3tRA==", + "dev": true, + "requires": { + "is-docker": "^3.0.0" + } + }, + "is-network-error": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/is-network-error/-/is-network-error-1.1.0.tgz", + "integrity": "sha512-tUdRRAnhT+OtCZR/LxZelH/C7QtjtFrTu5tXCA8pl55eTUElUHT+GPYV8MBMBvea/j+NxQqVt3LbWMRir7Gx9g==", + "dev": true + }, "is-number": { "version": "7.0.0", "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", @@ -5693,19 +5764,13 @@ "isobject": "^3.0.1" } }, - "is-stream": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-2.0.1.tgz", - "integrity": "sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==", - "dev": true - }, "is-wsl": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-2.2.0.tgz", - "integrity": "sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww==", + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-3.1.0.tgz", + "integrity": "sha512-UcVfVfaK4Sc4m7X3dUSoHoozQGBEFeDC+zVo06t98xe8CzHSZZBekNXH+tu0NalHolcJ/QAGqS46Hef7QXBIMw==", "dev": true, "requires": { - "is-docker": "^2.0.0" + "is-inside-container": "^1.0.0" } }, "isarray": { @@ -5756,13 +5821,13 @@ "dev": true }, "launch-editor": { - "version": "2.6.0", - "resolved": "https://registry.npmjs.org/launch-editor/-/launch-editor-2.6.0.tgz", - "integrity": "sha512-JpDCcQnyAAzZZaZ7vEiSqL690w7dAEyLao+KC96zBplnYbJS7TYNjvM3M7y3dGz+v7aIsJk3hllWuc0kWAjyRQ==", + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/launch-editor/-/launch-editor-2.10.0.tgz", + "integrity": "sha512-D7dBRJo/qcGX9xlvt/6wUYzQxjh5G1RvZPgPv8vi4KRU99DVQL/oW7tnVOCCTm2HGeo3C5HvGE5Yrh6UBoZ0vA==", "dev": true, "requires": { "picocolors": "^1.0.0", - "shell-quote": "^1.7.3" + "shell-quote": "^1.8.1" } }, "loader-runner": { @@ -5780,6 +5845,12 @@ "p-locate": "^4.1.0" } }, + "math-intrinsics": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", + "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", + "dev": true + }, "media-typer": { "version": "0.3.0", "resolved": "https://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz", @@ -5787,12 +5858,63 @@ "dev": true }, "memfs": { - "version": "3.5.3", - "resolved": "https://registry.npmjs.org/memfs/-/memfs-3.5.3.tgz", - "integrity": "sha512-UERzLsxzllchadvbPs5aolHh65ISpKpM+ccLbOJ8/vvpBKmAWf+la7dXFy7Mr0ySHbdHrFv5kGFCUHHe6GFEmw==", + "version": "4.17.2", + "resolved": "https://registry.npmjs.org/memfs/-/memfs-4.17.2.tgz", + "integrity": "sha512-NgYhCOWgovOXSzvYgUW0LQ7Qy72rWQMGGFJDoWg4G30RHd3z77VbYdtJ4fembJXBy8pMIUA31XNAupobOQlwdg==", "dev": true, "requires": { - "fs-monkey": "^1.0.4" + "@jsonjoy.com/json-pack": "^1.0.3", + "@jsonjoy.com/util": "^1.3.0", + "tree-dump": "^1.0.1", + "tslib": "^2.0.0" + }, + "dependencies": { + "@jsonjoy.com/base64": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@jsonjoy.com/base64/-/base64-1.1.2.tgz", + "integrity": "sha512-q6XAnWQDIMA3+FTiOYajoYqySkO+JSat0ytXGSuRdq9uXE7o92gzuQwQM14xaCRlBLGq3v5miDGC4vkVTn54xA==", + "dev": true, + "requires": {} + }, + "@jsonjoy.com/json-pack": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@jsonjoy.com/json-pack/-/json-pack-1.2.0.tgz", + "integrity": "sha512-io1zEbbYcElht3tdlqEOFxZ0dMTYrHz9iMf0gqn1pPjZFTCgM5R4R5IMA20Chb2UPYYsxjzs8CgZ7Nb5n2K2rA==", + "dev": true, + "requires": { + "@jsonjoy.com/base64": "^1.1.1", + "@jsonjoy.com/util": "^1.1.2", + "hyperdyperid": "^1.2.0", + "thingies": "^1.20.0" + } + }, + "@jsonjoy.com/util": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@jsonjoy.com/util/-/util-1.6.0.tgz", + "integrity": "sha512-sw/RMbehRhN68WRtcKCpQOPfnH6lLP4GJfqzi3iYej8tnzpZUDr6UkZYJjcjjC0FWEJOJbyM3PTIwxucUmDG2A==", + "dev": true, + "requires": {} + }, + "thingies": { + "version": "1.21.0", + "resolved": "https://registry.npmjs.org/thingies/-/thingies-1.21.0.tgz", + "integrity": "sha512-hsqsJsFMsV+aD4s3CWKk85ep/3I9XzYV/IXaSouJMYIoDlgyi11cBhsqYe9/geRfB0YIikBQg6raRaM+nIMP9g==", + "dev": true, + "requires": {} + }, + "tree-dump": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/tree-dump/-/tree-dump-1.0.3.tgz", + "integrity": "sha512-il+Cv80yVHFBwokQSfd4bldvr1Md951DpgAGfmhydt04L+YzHgubm2tQ7zueWDcGENKHq0ZvGFR/hjvNXilHEg==", + "dev": true, + "requires": {} + }, + "tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "dev": true + } } }, "merge-descriptors": { @@ -5850,27 +5972,12 @@ "mime-db": "1.52.0" } }, - "mimic-fn": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/mimic-fn/-/mimic-fn-2.1.0.tgz", - "integrity": "sha512-OqbOk5oEQeAZ8WXWydlu9HJjz9WVdEIvamMCcXmuqUYjTknH/sqsWvhQ3vgwKFRR1HpjvNBKQ37nbJgYzGqGcg==", - "dev": true - }, "minimalistic-assert": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/minimalistic-assert/-/minimalistic-assert-1.0.1.tgz", "integrity": "sha512-UtJcAD4yEaGtjPezWuO9wC4nwUnVH/8/Im3yEHQP4b67cXlD/Qr9hdITCU1xDbSEXg2XKNaP8jsReV7vQd00/A==", "dev": true }, - "minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", - "dev": true, - "requires": { - "brace-expansion": "^1.1.7" - } - }, "ms": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", @@ -5917,19 +6024,10 @@ "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", "dev": true }, - "npm-run-path": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/npm-run-path/-/npm-run-path-4.0.1.tgz", - "integrity": "sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw==", - "dev": true, - "requires": { - "path-key": "^3.0.0" - } - }, "object-inspect": { - "version": "1.13.2", - "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.2.tgz", - "integrity": "sha512-IRZSRuzJiynemAXPYtPe5BoI/RESNYR7TYm50MC5Mqbd3Jmw5y790sErYw3V6SryFJD64b74qQQs9wn5Bg/k3g==", + "version": "1.13.4", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz", + "integrity": "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==", "dev": true }, "obuf": { @@ -5948,38 +6046,21 @@ } }, "on-headers": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.0.2.tgz", - "integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.1.0.tgz", + "integrity": "sha512-737ZY3yNnXy37FHkQxPzt4UZ2UWPWiCZWLvFZ4fu5cueciegX0zGPnrlY6bwRg4FdQOe9YU8MkmJwGhoMybl8A==", "dev": true }, - "once": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", - "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", - "dev": true, - "requires": { - "wrappy": "1" - } - }, - "onetime": { - "version": "5.1.2", - "resolved": "https://registry.npmjs.org/onetime/-/onetime-5.1.2.tgz", - "integrity": "sha512-kbpaSSGJTWdAY5KPVeMOKXSrPtr8C8C7wodJbcsd51jRnmD+GZu8Y0VoU6Dm5Z4vWr0Ig/1NKuWRKf7j5aaYSg==", - "dev": true, - "requires": { - "mimic-fn": "^2.1.0" - } - }, "open": { - "version": "8.4.2", - "resolved": "https://registry.npmjs.org/open/-/open-8.4.2.tgz", - "integrity": "sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==", + "version": "10.1.2", + "resolved": "https://registry.npmjs.org/open/-/open-10.1.2.tgz", + "integrity": "sha512-cxN6aIDPz6rm8hbebcP7vrQNhvRcveZoJU72Y7vskh4oIm+BZwBECnx5nTmrlres1Qapvx27Qo1Auukpf8PKXw==", "dev": true, "requires": { - "define-lazy-prop": "^2.0.0", - "is-docker": "^2.1.1", - "is-wsl": "^2.2.0" + "default-browser": "^5.2.1", + "define-lazy-prop": "^3.0.0", + "is-inside-container": "^1.0.0", + "is-wsl": "^3.1.0" } }, "p-locate": { @@ -6003,12 +6084,13 @@ } }, "p-retry": { - "version": "4.6.2", - "resolved": "https://registry.npmjs.org/p-retry/-/p-retry-4.6.2.tgz", - "integrity": "sha512-312Id396EbJdvRONlngUx0NydfrIQ5lsYu0znKVUzVvArzEIt08V1qhtyESbGVd1FGX7UKtiFp5uwKZdM8wIuQ==", + "version": "6.2.1", + "resolved": "https://registry.npmjs.org/p-retry/-/p-retry-6.2.1.tgz", + "integrity": "sha512-hEt02O4hUct5wtwg4H4KcWgDdm+l1bOaEy/hWzd8xtXB9BqxTWBBhb+2ImAtH4Cv4rPjV76xN3Zumqk3k3AhhQ==", "dev": true, "requires": { - "@types/retry": "0.12.0", + "@types/retry": "0.12.2", + "is-network-error": "^1.0.0", "retry": "^0.13.1" } }, @@ -6030,12 +6112,6 @@ "integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==", "dev": true }, - "path-is-absolute": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", - "integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==", - "dev": true - }, "path-key": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", @@ -6049,9 +6125,9 @@ "dev": true }, "path-to-regexp": { - "version": "0.1.10", - "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.10.tgz", - "integrity": "sha512-7lf7qcQidTku0Gu3YDPc8DJ1q7OOucfa/BSsIwjuh56VU7katFvuM8hULfkwB3Fns/rsVF7PwPKVw1sl5KQS9w==", + "version": "0.1.12", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.12.tgz", + "integrity": "sha512-RA1GjUVMnvYFxuqovrEqZoxxW5NUZqbwKtYz/Tt7nXerk0LbLblQmrsgdeOxV5SFHf0UDggjS/bSeOZwt1pmEQ==", "dev": true }, "path-type": { @@ -6151,14 +6227,6 @@ "http-errors": "2.0.0", "iconv-lite": "0.4.24", "unpipe": "1.0.0" - }, - "dependencies": { - "bytes": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", - "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", - "dev": true - } } }, "readable-stream": { @@ -6244,14 +6312,11 @@ "integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==", "dev": true }, - "rimraf": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz", - "integrity": "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==", - "dev": true, - "requires": { - "glob": "^7.1.3" - } + "run-applescript": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/run-applescript/-/run-applescript-7.0.0.tgz", + "integrity": "sha512-9by4Ij99JUr/MCFBUkDKLWK3G9HVXmabKz9U5MlIAIuvuzkiOicRYs8XJLxX+xahD+mLiiCYDqF9dKAgtzKP1A==", + "dev": true }, "run-parallel": { "version": "1.2.0", @@ -6292,11 +6357,12 @@ "dev": true }, "selfsigned": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/selfsigned/-/selfsigned-2.1.1.tgz", - "integrity": "sha512-GSL3aowiF7wa/WtSFwnUrludWFoNhftq8bUkH9pkzjpN2XSPOAYEgg6e0sS9s0rZwgJzJiQRPU18A6clnoW5wQ==", + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/selfsigned/-/selfsigned-2.4.1.tgz", + "integrity": "sha512-th5B4L2U+eGLq1TVh7zNRGBapioSORUeymIydxgFpwww9d2qyKvtuPU2jJuHvYAwwqi2Y596QBL3eEqcPEYL8Q==", "dev": true, "requires": { + "@types/node-forge": "^1.3.0", "node-forge": "^1" } }, @@ -6344,6 +6410,12 @@ "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", "dev": true }, + "encodeurl": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", + "integrity": "sha512-TPJXq8JqFaVYm2CWmPvnP2Iyo4ZSM7/QKcSmuMLDObfpH5fi7RUGmd/rTDf+rut/saiDiQEeVTNgAmJEdAOx0w==", + "dev": true + }, "ms": { "version": "2.1.3", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", @@ -6421,28 +6493,6 @@ "escape-html": "~1.0.3", "parseurl": "~1.3.3", "send": "0.19.0" - }, - "dependencies": { - "encodeurl": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", - "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", - "dev": true - } - } - }, - "set-function-length": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", - "integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", - "dev": true, - "requires": { - "define-data-property": "^1.1.4", - "es-errors": "^1.3.0", - "function-bind": "^1.1.2", - "get-intrinsic": "^1.2.4", - "gopd": "^1.0.1", - "has-property-descriptors": "^1.0.2" } }, "setprototypeof": { @@ -6476,28 +6526,58 @@ "dev": true }, "shell-quote": { - "version": "1.8.1", - "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.1.tgz", - "integrity": "sha512-6j1W9l1iAs/4xYBI1SYOVZyFcCis9b4KCLQ8fgAGG07QvzaRLVVRQvAy85yNmmZSjYjg4MWh4gNvlPujU/5LpA==", + "version": "1.8.3", + "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.3.tgz", + "integrity": "sha512-ObmnIF4hXNg1BqhnHmgbDETF8dLPCggZWBjkQfhZpbszZnYur5DUljTcCHii5LC3J5E0yeO/1LIMyH+UvHQgyw==", "dev": true }, "side-channel": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.6.tgz", - "integrity": "sha512-fDW/EZ6Q9RiO8eFG8Hj+7u/oW+XrPTIChwCOM2+th2A6OblDtYYIpve9m+KvI9Z4C9qSEXlaGR6bTEYHReuglA==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.1.0.tgz", + "integrity": "sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==", "dev": true, "requires": { - "call-bind": "^1.0.7", "es-errors": "^1.3.0", - "get-intrinsic": "^1.2.4", - "object-inspect": "^1.13.1" + "object-inspect": "^1.13.3", + "side-channel-list": "^1.0.0", + "side-channel-map": "^1.0.1", + "side-channel-weakmap": "^1.0.2" } }, - "signal-exit": { - "version": "3.0.7", - "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-3.0.7.tgz", - "integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", - "dev": true + "side-channel-list": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/side-channel-list/-/side-channel-list-1.0.0.tgz", + "integrity": "sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==", + "dev": true, + "requires": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3" + } + }, + "side-channel-map": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/side-channel-map/-/side-channel-map-1.0.1.tgz", + "integrity": "sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==", + "dev": true, + "requires": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3" + } + }, + "side-channel-weakmap": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/side-channel-weakmap/-/side-channel-weakmap-1.0.2.tgz", + "integrity": "sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==", + "dev": true, + "requires": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3", + "side-channel-map": "^1.0.1" + } }, "slash": { "version": "5.1.0", @@ -6587,12 +6667,6 @@ "safe-buffer": "~5.1.0" } }, - "strip-final-newline": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/strip-final-newline/-/strip-final-newline-2.0.0.tgz", - "integrity": "sha512-BrpvfNAE3dcvq7ll3xVumzjKjZQ5tI1sEUIKr3Uoks0XUl45St3FlatVqef9prk4jRDzhW6WZg+3bk93y6pLjA==", - "dev": true - }, "supports-color": { "version": "8.1.1", "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", @@ -6819,28 +6893,29 @@ } }, "webpack-dev-middleware": { - "version": "5.3.4", - "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.4.tgz", - "integrity": "sha512-BVdTqhhs+0IfoeAf7EoH5WE+exCmqGerHfDM0IL096Px60Tq2Mn9MAbnaGUe6HiMa41KMCYF19gyzZmBcq/o4Q==", + "version": "7.4.2", + "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-7.4.2.tgz", + "integrity": "sha512-xOO8n6eggxnwYpy1NlzUKpvrjfJTvae5/D6WOK0S2LSo7vjmo5gCM1DbLUmFqrMTJP+W/0YZNctm7jasWvLuBA==", "dev": true, "requires": { "colorette": "^2.0.10", - "memfs": "^3.4.3", + "memfs": "^4.6.0", "mime-types": "^2.1.31", + "on-finished": "^2.4.1", "range-parser": "^1.2.1", "schema-utils": "^4.0.0" }, "dependencies": { "ajv": { - "version": "8.12.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", - "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", "dev": true, "requires": { - "fast-deep-equal": "^3.1.1", + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" + "require-from-string": "^2.0.2" } }, "ajv-keywords": { @@ -6859,9 +6934,9 @@ "dev": true }, "schema-utils": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.2.0.tgz", - "integrity": "sha512-L0jRsrPpjdckP3oPug3/VxNKt2trR8TcabrM6FOAAlvC/9Phcmm+cuAgTlxBqdBR1WJx7Naj9WHw+aOmheSVbw==", + "version": "4.3.2", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.2.tgz", + "integrity": "sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==", "dev": true, "requires": { "@types/json-schema": "^7.0.9", @@ -6873,41 +6948,39 @@ } }, "webpack-dev-server": { - "version": "4.15.1", - "resolved": "https://registry.npmjs.org/webpack-dev-server/-/webpack-dev-server-4.15.1.tgz", - "integrity": "sha512-5hbAst3h3C3L8w6W4P96L5vaV0PxSmJhxZvWKYIdgxOQm8pNZ5dEOmmSLBVpP85ReeyRt6AS1QJNyo/oFFPeVA==", - "dev": true, - "requires": { - "@types/bonjour": "^3.5.9", - "@types/connect-history-api-fallback": "^1.3.5", - "@types/express": "^4.17.13", - "@types/serve-index": "^1.9.1", - "@types/serve-static": "^1.13.10", - "@types/sockjs": "^0.3.33", - "@types/ws": "^8.5.5", + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/webpack-dev-server/-/webpack-dev-server-5.2.1.tgz", + "integrity": "sha512-ml/0HIj9NLpVKOMq+SuBPLHcmbG+TGIjXRHsYfZwocUBIqEvws8NnS/V9AFQ5FKP+tgn5adwVwRrTEpGL33QFQ==", + "dev": true, + "requires": { + "@types/bonjour": "^3.5.13", + "@types/connect-history-api-fallback": "^1.5.4", + "@types/express": "^4.17.21", + "@types/express-serve-static-core": "^4.17.21", + "@types/serve-index": "^1.9.4", + "@types/serve-static": "^1.15.5", + "@types/sockjs": "^0.3.36", + "@types/ws": "^8.5.10", "ansi-html-community": "^0.0.8", - "bonjour-service": "^1.0.11", - "chokidar": "^3.5.3", + "bonjour-service": "^1.2.1", + "chokidar": "^3.6.0", "colorette": "^2.0.10", "compression": "^1.7.4", "connect-history-api-fallback": "^2.0.0", - "default-gateway": "^6.0.3", - "express": "^4.17.3", + "express": "^4.21.2", "graceful-fs": "^4.2.6", - "html-entities": "^2.3.2", - "http-proxy-middleware": "^2.0.3", - "ipaddr.js": "^2.0.1", - "launch-editor": "^2.6.0", - "open": "^8.0.9", - "p-retry": "^4.5.0", - "rimraf": "^3.0.2", - "schema-utils": "^4.0.0", - "selfsigned": "^2.1.1", + "http-proxy-middleware": "^2.0.7", + "ipaddr.js": "^2.1.0", + "launch-editor": "^2.6.1", + "open": "^10.0.3", + "p-retry": "^6.2.0", + "schema-utils": "^4.2.0", + "selfsigned": "^2.4.1", "serve-index": "^1.9.1", "sockjs": "^0.3.24", "spdy": "^4.0.2", - "webpack-dev-middleware": "^5.3.1", - "ws": "^8.13.0" + "webpack-dev-middleware": "^7.4.2", + "ws": "^8.18.0" }, "dependencies": { "ajv": { @@ -6993,16 +7066,10 @@ "integrity": "sha512-CC1bOL87PIWSBhDcTrdeLo6eGT7mCFtrg0uIJtqJUFyK+eJnzl8A1niH56uu7KMa5XFrtiV+AQuHO3n7DsHnLQ==", "dev": true }, - "wrappy": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", - "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", - "dev": true - }, "ws": { - "version": "8.17.1", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", - "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", + "version": "8.18.2", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.2.tgz", + "integrity": "sha512-DMricUmwGZUVr++AEAe2uiVM7UoO9MAVZMDu05UQOaUII0lp+zOzLLU4Xqh/JvTqklB1T4uELaaPBKyjE1r4fQ==", "dev": true, "requires": {} } diff --git a/datafusion/wasmtest/datafusion-wasm-app/package.json b/datafusion/wasmtest/datafusion-wasm-app/package.json index 5a2262400cfd5..b46993de77d9b 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package.json @@ -29,7 +29,7 @@ "devDependencies": { "webpack": "5.94.0", "webpack-cli": "5.1.4", - "webpack-dev-server": "4.15.1", + "webpack-dev-server": "5.2.1", "copy-webpack-plugin": "12.0.2" } } diff --git a/datafusion/wasmtest/src/lib.rs b/datafusion/wasmtest/src/lib.rs index 6c7be9056eb43..d2efe995f100d 100644 --- a/datafusion/wasmtest/src/lib.rs +++ b/datafusion/wasmtest/src/lib.rs @@ -19,7 +19,7 @@ html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] extern crate wasm_bindgen; @@ -82,7 +82,6 @@ pub fn basic_parse() { #[cfg(test)] mod test { use super::*; - use datafusion::execution::options::ParquetReadOptions; use datafusion::{ arrow::{ array::{ArrayRef, Int32Array, RecordBatch, StringArray}, @@ -93,12 +92,12 @@ mod test { }; use datafusion_common::test_util::batches_to_string; use datafusion_execution::{ - config::SessionConfig, disk_manager::DiskManagerConfig, + config::SessionConfig, + disk_manager::{DiskManagerBuilder, DiskManagerMode}, runtime_env::RuntimeEnvBuilder, }; use datafusion_physical_plan::collect; use datafusion_sql::parser::DFParser; - use insta::assert_snapshot; use object_store::{memory::InMemory, path::Path, ObjectStore}; use url::Url; use wasm_bindgen_test::wasm_bindgen_test; @@ -114,7 +113,9 @@ mod test { fn get_ctx() -> Arc { let rt = RuntimeEnvBuilder::new() - .with_disk_manager(DiskManagerConfig::Disabled) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled), + ) .build_arc() .unwrap(); let session_config = SessionConfig::new().with_target_partitions(1); @@ -240,22 +241,24 @@ mod test { let url = Url::parse("memory://").unwrap(); session_ctx.register_object_store(&url, Arc::new(store)); - - let df = session_ctx - .read_parquet("memory:///", ParquetReadOptions::new()) + session_ctx + .register_parquet("a", "memory:///a.parquet", Default::default()) .await .unwrap(); + let df = session_ctx.sql("SELECT * FROM a").await.unwrap(); + let result = df.collect().await.unwrap(); - assert_snapshot!(batches_to_string(&result), @r" - +----+-------+ - | id | value | - +----+-------+ - | 1 | a | - | 2 | b | - | 3 | c | - +----+-------+ - "); + assert_eq!( + batches_to_string(&result), + "+----+-------+\n\ + | id | value |\n\ + +----+-------+\n\ + | 1 | a |\n\ + | 2 | b |\n\ + | 3 | c |\n\ + +----+-------+" + ); } } diff --git a/datafusion/wasmtest/webdriver.json b/datafusion/wasmtest/webdriver.json new file mode 100644 index 0000000000000..f59a2be9955f1 --- /dev/null +++ b/datafusion/wasmtest/webdriver.json @@ -0,0 +1,15 @@ +{ + "moz:firefoxOptions": { + "prefs": { + "media.navigator.streams.fake": true, + "media.navigator.permission.disabled": true + }, + "args": [] + }, + "goog:chromeOptions": { + "args": [ + "--use-fake-device-for-media-stream", + "--use-fake-ui-for-media-stream" + ] + } +} \ No newline at end of file diff --git a/dev/changelog/47.0.0.md b/dev/changelog/47.0.0.md new file mode 100644 index 0000000000000..64ca2e157a9e3 --- /dev/null +++ b/dev/changelog/47.0.0.md @@ -0,0 +1,506 @@ + + +# Apache DataFusion 47.0.0 Changelog + +This release consists of 364 commits from 94 contributors. See credits at the end of this changelog for more information. + +**Breaking changes:** + +- chore: cleanup deprecated API since `version <= 40` [#15027](https://github.com/apache/datafusion/pull/15027) (qazxcdswe123) +- fix: mark ScalarUDFImpl::invoke_batch as deprecated [#15049](https://github.com/apache/datafusion/pull/15049) (Blizzara) +- feat: support customize metadata in alias for dataframe api [#15120](https://github.com/apache/datafusion/pull/15120) (chenkovsky) +- Refactor: add `FileGroup` structure for `Vec` [#15379](https://github.com/apache/datafusion/pull/15379) (xudong963) +- Change default `EXPLAIN` format in `datafusion-cli` to `tree` format [#15427](https://github.com/apache/datafusion/pull/15427) (alamb) +- Support computing statistics for FileGroup [#15432](https://github.com/apache/datafusion/pull/15432) (xudong963) +- Remove redundant statistics from FileScanConfig [#14955](https://github.com/apache/datafusion/pull/14955) (Standing-Man) +- parquet reader: move pruning predicate creation from ParquetSource to ParquetOpener [#15561](https://github.com/apache/datafusion/pull/15561) (adriangb) +- feat: Add unique id for every memory consumer [#15613](https://github.com/apache/datafusion/pull/15613) (EmilyMatt) + +**Performance related:** + +- Fix sequential metadata fetching in ListingTable causing high latency [#14918](https://github.com/apache/datafusion/pull/14918) (geoffreyclaude) +- Implement GroupsAccumulator for min/max Duration [#15322](https://github.com/apache/datafusion/pull/15322) (shruti2522) +- [Minor] Remove/reorder logical plan rules [#15421](https://github.com/apache/datafusion/pull/15421) (Dandandan) +- Improve performance of `first_value` by implementing special `GroupsAccumulator` [#15266](https://github.com/apache/datafusion/pull/15266) (UBarney) +- perf: unwrap cast for comparing ints =/!= strings [#15110](https://github.com/apache/datafusion/pull/15110) (alan910127) +- Improve performance sort TPCH q3 with Utf8Vew ( Sort-preserving mergi… [#15447](https://github.com/apache/datafusion/pull/15447) (zhuqi-lucas) +- perf: Reuse row converter during sort [#15302](https://github.com/apache/datafusion/pull/15302) (2010YOUY01) +- perf: Add TopK benchmarks as variation over the `sort_tpch` benchmarks [#15560](https://github.com/apache/datafusion/pull/15560) (geoffreyclaude) +- Perf: remove `clone` on `uninitiated_partitions` in SortPreservingMergeStream [#15562](https://github.com/apache/datafusion/pull/15562) (rluvaton) +- Add short circuit evaluation for `AND` and `OR` [#15462](https://github.com/apache/datafusion/pull/15462) (acking-you) +- perf: Introduce sort prefix computation for early TopK exit optimization on partially sorted input (10x speedup on top10 bench) [#15563](https://github.com/apache/datafusion/pull/15563) (geoffreyclaude) +- Improve performance of `last_value` by implementing special `GroupsAccumulator` [#15542](https://github.com/apache/datafusion/pull/15542) (UBarney) +- Enhance: simplify `x=x` --> `x IS NOT NULL OR NULL` [#15589](https://github.com/apache/datafusion/pull/15589) (ding-young) + +**Implemented enhancements:** + +- feat: Add `tree` / pretty explain mode [#14677](https://github.com/apache/datafusion/pull/14677) (irenjj) +- feat: Add `array_max` function support [#14470](https://github.com/apache/datafusion/pull/14470) (erenavsarogullari) +- feat: implement tree explain for `ProjectionExec` [#15082](https://github.com/apache/datafusion/pull/15082) (Standing-Man) +- feat: support ApproxDistinct with utf8view [#15200](https://github.com/apache/datafusion/pull/15200) (zhuqi-lucas) +- feat: Attach `Diagnostic` to more than one column errors in scalar_subquery and in_subquery [#15143](https://github.com/apache/datafusion/pull/15143) (changsun20) +- feat: topk functionality for aggregates should support utf8view and largeutf8 [#15152](https://github.com/apache/datafusion/pull/15152) (zhuqi-lucas) +- feat: Native support utf8view for regex string operators [#15275](https://github.com/apache/datafusion/pull/15275) (zhuqi-lucas) +- feat: introduce `JoinSetTracer` trait for tracing context propagation in spawned tasks [#14547](https://github.com/apache/datafusion/pull/14547) (geoffreyclaude) +- feat: Support serde for JsonSource PhysicalPlan [#15311](https://github.com/apache/datafusion/pull/15311) (westhide) +- feat: Support serde for FileScanConfig `batch_size` [#15335](https://github.com/apache/datafusion/pull/15335) (westhide) +- feat: simplify regex wildcard pattern [#15299](https://github.com/apache/datafusion/pull/15299) (waynexia) +- feat: Add union_by_name, union_by_name_distinct to DataFrame api [#15489](https://github.com/apache/datafusion/pull/15489) (Omega359) +- feat: Add config `max_temp_directory_size` to limit max disk usage for spilling queries [#15520](https://github.com/apache/datafusion/pull/15520) (2010YOUY01) +- feat: Add tracing regression tests [#15673](https://github.com/apache/datafusion/pull/15673) (geoffreyclaude) + +**Fixed bugs:** + +- fix: External sort failing on an edge case [#15017](https://github.com/apache/datafusion/pull/15017) (2010YOUY01) +- fix: graceful NULL and type error handling in array functions [#14737](https://github.com/apache/datafusion/pull/14737) (alan910127) +- fix: Support datatype cast for insert api same as insert into sql [#15091](https://github.com/apache/datafusion/pull/15091) (zhuqi-lucas) +- fix: unparse for subqueryalias [#15068](https://github.com/apache/datafusion/pull/15068) (chenkovsky) +- fix: date_trunc bench broken by #15049 [#15169](https://github.com/apache/datafusion/pull/15169) (Blizzara) +- fix: compound_field_access doesn't identifier qualifier. [#15153](https://github.com/apache/datafusion/pull/15153) (chenkovsky) +- fix: unparsing left/ right semi/mark join [#15212](https://github.com/apache/datafusion/pull/15212) (chenkovsky) +- fix: handle duplicate WindowFunction expressions in Substrait consumer [#15211](https://github.com/apache/datafusion/pull/15211) (Blizzara) +- fix: write hive partitions for any int/uint/float [#15337](https://github.com/apache/datafusion/pull/15337) (christophermcdermott) +- fix: `core_expressions` feature flag broken, move `overlay` into `core` functions [#15217](https://github.com/apache/datafusion/pull/15217) (shruti2522) +- fix: Redundant files spilled during external sort + introduce `SpillManager` [#15355](https://github.com/apache/datafusion/pull/15355) (2010YOUY01) +- fix: typo of DropFunction [#15434](https://github.com/apache/datafusion/pull/15434) (chenkovsky) +- fix: Unconditionally wrap UNION BY NAME input nodes w/ `Projection` [#15242](https://github.com/apache/datafusion/pull/15242) (rkrishn7) +- fix: the average time for clickbench query compute should use new vec to make it compute for each query [#15472](https://github.com/apache/datafusion/pull/15472) (zhuqi-lucas) +- fix: Assertion fail in external sort [#15469](https://github.com/apache/datafusion/pull/15469) (2010YOUY01) +- fix: aggregation corner case [#15457](https://github.com/apache/datafusion/pull/15457) (chenkovsky) +- fix: update group by columns for merge phase after spill [#15531](https://github.com/apache/datafusion/pull/15531) (rluvaton) +- fix: Queries similar to `count-bug` produce incorrect results [#15281](https://github.com/apache/datafusion/pull/15281) (suibianwanwank) +- fix: ffi aggregation [#15576](https://github.com/apache/datafusion/pull/15576) (chenkovsky) +- fix: nested window function [#15033](https://github.com/apache/datafusion/pull/15033) (chenkovsky) +- fix: dictionary encoded column to partition column casting bug [#15652](https://github.com/apache/datafusion/pull/15652) (haruband) +- fix: recursion protection for physical plan node [#15600](https://github.com/apache/datafusion/pull/15600) (chenkovsky) +- fix: add map coercion for binary ops [#15551](https://github.com/apache/datafusion/pull/15551) (alexwilcoxson-rel) +- fix: Rewrite `date_trunc` and `from_unixtime` for the SQLite unparser [#15630](https://github.com/apache/datafusion/pull/15630) (peasee) +- fix(substrait): fix regressed edge case in renaming inner struct fields [#15634](https://github.com/apache/datafusion/pull/15634) (Blizzara) +- fix: normalize window ident [#15639](https://github.com/apache/datafusion/pull/15639) (chenkovsky) +- fix: unparse join without projection [#15693](https://github.com/apache/datafusion/pull/15693) (chenkovsky) + +**Documentation updates:** + +- MINOR fix(docs): set the proper link for dev-env setup in contrib guide [#14960](https://github.com/apache/datafusion/pull/14960) (clflushopt) +- Add Upgrade Guide for DataFusion 46.0.0 [#14891](https://github.com/apache/datafusion/pull/14891) (alamb) +- Improve `SessionStateBuilder::new` documentation [#14980](https://github.com/apache/datafusion/pull/14980) (alamb) +- Minor: Replace Star and Fork buttons in docs with static versions [#14988](https://github.com/apache/datafusion/pull/14988) (amoeba) +- Fix documentation warnings and error if anymore occur [#14952](https://github.com/apache/datafusion/pull/14952) (AmosAidoo) +- docs: Improve docs on AggregateFunctionExpr construction [#15044](https://github.com/apache/datafusion/pull/15044) (ctsk) +- Minor: More comment to aggregation fuzzer [#15048](https://github.com/apache/datafusion/pull/15048) (2010YOUY01) +- Improve benchmark documentation [#15054](https://github.com/apache/datafusion/pull/15054) (carols10cents) +- doc: update RecordBatchReceiverStreamBuilder::spawn_blocking task behaviour [#14995](https://github.com/apache/datafusion/pull/14995) (shruti2522) +- doc: Correct benchmark command [#15094](https://github.com/apache/datafusion/pull/15094) (qazxcdswe123) +- Add `insta` / snapshot testing to CLI & set up AWS mock [#13672](https://github.com/apache/datafusion/pull/13672) (blaginin) +- Config: Add support default sql varchar to view types [#15104](https://github.com/apache/datafusion/pull/15104) (zhuqi-lucas) +- Support `EXPLAIN ... FORMAT ...` [#15166](https://github.com/apache/datafusion/pull/15166) (alamb) +- Update version to 46.0.1, add CHANGELOG (#15243) [#15244](https://github.com/apache/datafusion/pull/15244) (xudong963) +- docs: update documentation for Final GroupBy in accumulator.rs [#15279](https://github.com/apache/datafusion/pull/15279) (qazxcdswe123) +- minor: fix `data/sqlite` link [#15286](https://github.com/apache/datafusion/pull/15286) (sdht0) +- Add upgrade notes for array signatures [#15237](https://github.com/apache/datafusion/pull/15237) (jkosh44) +- Add doc for the `statistics_from_parquet_meta_calc method` [#15330](https://github.com/apache/datafusion/pull/15330) (xudong963) +- added explaination for Schema and DFSchema to documentation [#15329](https://github.com/apache/datafusion/pull/15329) (Jiashu-Hu) +- Documentation: Plan custom expressions [#15353](https://github.com/apache/datafusion/pull/15353) (Jiashu-Hu) +- Update concepts-readings-events.md [#15440](https://github.com/apache/datafusion/pull/15440) (berkaysynnada) +- Add support for DISTINCT + ORDER BY in `ARRAY_AGG` [#14413](https://github.com/apache/datafusion/pull/14413) (gabotechs) +- Update the copyright year [#15453](https://github.com/apache/datafusion/pull/15453) (omkenge) +- Docs: Formatting and Added Extra resources [#15450](https://github.com/apache/datafusion/pull/15450) (2SpaceMasterRace) +- Add documentation for `Run extended tests` command [#15463](https://github.com/apache/datafusion/pull/15463) (alamb) +- bench: Document how to use cross platform Samply profiler [#15481](https://github.com/apache/datafusion/pull/15481) (comphead) +- Update user guide to note decimal is not experimental anymore [#15515](https://github.com/apache/datafusion/pull/15515) (Jiashu-Hu) +- datafusion-cli: document reading partitioned parquet [#15505](https://github.com/apache/datafusion/pull/15505) (marvelshan) +- Update concepts-readings-events.md [#15541](https://github.com/apache/datafusion/pull/15541) (oznur-synnada) +- Add documentation example for `AggregateExprBuilder` [#15504](https://github.com/apache/datafusion/pull/15504) (Shreyaskr1409) +- Docs : Added Sql examples for window Functions : `nth_val` , etc [#15555](https://github.com/apache/datafusion/pull/15555) (Adez017) +- Add disk usage limit configuration to datafusion-cli [#15586](https://github.com/apache/datafusion/pull/15586) (jsai28) +- Bug fix : fix the bug in docs in 'cum_dist()' Example [#15618](https://github.com/apache/datafusion/pull/15618) (Adez017) +- Make tree the Default EXPLAIN Format and Reorder Documentation Sections [#15706](https://github.com/apache/datafusion/pull/15706) (kosiew) +- Add coerce int96 option for Parquet to support different TimeUnits, test int96_from_spark.parquet from parquet-testing [#15537](https://github.com/apache/datafusion/pull/15537) (mbutrovich) +- STRING_AGG missing functionality [#14412](https://github.com/apache/datafusion/pull/14412) (gabotechs) +- doc : update RepartitionExec display tree [#15710](https://github.com/apache/datafusion/pull/15710) (getChan) +- Update version to 47.0.0, add CHANGELOG [#15731](https://github.com/apache/datafusion/pull/15731) (xudong963) + +**Other:** + +- Improve documentation for `DataSourceExec`, `FileScanConfig`, `DataSource` etc [#14941](https://github.com/apache/datafusion/pull/14941) (alamb) +- Do not swap with projection when file is partitioned [#14956](https://github.com/apache/datafusion/pull/14956) (blaginin) +- Minor: Add more projection pushdown tests, clarify comments [#14963](https://github.com/apache/datafusion/pull/14963) (alamb) +- Update labeler components [#14942](https://github.com/apache/datafusion/pull/14942) (alamb) +- Deprecate `Expr::Wildcard` [#14959](https://github.com/apache/datafusion/pull/14959) (linhr) +- Minor: use FileScanConfig builder API in some tests [#14938](https://github.com/apache/datafusion/pull/14938) (alamb) +- Minor: improve documentation of `AggregateMode` [#14946](https://github.com/apache/datafusion/pull/14946) (alamb) +- chore(deps): bump thiserror from 2.0.11 to 2.0.12 [#14971](https://github.com/apache/datafusion/pull/14971) (dependabot[bot]) +- chore(deps): bump pyo3 from 0.23.4 to 0.23.5 [#14972](https://github.com/apache/datafusion/pull/14972) (dependabot[bot]) +- chore(deps): bump async-trait from 0.1.86 to 0.1.87 [#14973](https://github.com/apache/datafusion/pull/14973) (dependabot[bot]) +- Fix verification script and extended tests due to `rustup` changes [#14990](https://github.com/apache/datafusion/pull/14990) (alamb) +- Split out avro, parquet, json and csv into individual crates [#14951](https://github.com/apache/datafusion/pull/14951) (AdamGS) +- Minor: Add `backtrace` feature in datafusion-cli [#14997](https://github.com/apache/datafusion/pull/14997) (2010YOUY01) +- chore: Update `SessionStateBuilder::with_default_features` does not replace existing features [#14935](https://github.com/apache/datafusion/pull/14935) (irenjj) +- Make `create_ordering` pub and add doc for it [#14996](https://github.com/apache/datafusion/pull/14996) (xudong963) +- Simplify Between expression to Eq [#14994](https://github.com/apache/datafusion/pull/14994) (jayzhan211) +- Count wildcard alias [#14927](https://github.com/apache/datafusion/pull/14927) (jayzhan211) +- replace TypeSignature::String with TypeSignature::Coercible [#14917](https://github.com/apache/datafusion/pull/14917) (zjregee) +- Minor: Add indentation to EnforceDistribution test plans. [#15007](https://github.com/apache/datafusion/pull/15007) (wiedld) +- Minor: add method `SessionStateBuilder::new_with_default_features()` [#14998](https://github.com/apache/datafusion/pull/14998) (shruti2522) +- Implement `tree` explain for FilterExec [#15001](https://github.com/apache/datafusion/pull/15001) (alamb) +- Unparser add `AtArrow` and `ArrowAt` conversion to BinaryOperator [#14968](https://github.com/apache/datafusion/pull/14968) (cetra3) +- Add dependency checks to verify-release-candidate script [#15009](https://github.com/apache/datafusion/pull/15009) (waynexia) +- Fix: to_char Function Now Correctly Handles DATE Values in DataFusion [#14970](https://github.com/apache/datafusion/pull/14970) (kosiew) +- Make Substrait Schema Structs always non-nullable [#15011](https://github.com/apache/datafusion/pull/15011) (amoeba) +- Adjust physical optimizer rule order, put `ProjectionPushdown` at last [#15040](https://github.com/apache/datafusion/pull/15040) (xudong963) +- Move `UnwrapCastInComparison` into `Simplifier` [#15012](https://github.com/apache/datafusion/pull/15012) (jayzhan211) +- chore(deps): bump aws-config from 1.5.17 to 1.5.18 [#15041](https://github.com/apache/datafusion/pull/15041) (dependabot[bot]) +- chore(deps): bump bytes from 1.10.0 to 1.10.1 [#15042](https://github.com/apache/datafusion/pull/15042) (dependabot[bot]) +- Minor: Deprecate `ScalarValue::raw_data` [#15016](https://github.com/apache/datafusion/pull/15016) (qazxcdswe123) +- Implement tree explain for `DataSourceExec` [#15029](https://github.com/apache/datafusion/pull/15029) (alamb) +- Refactor test suite in EnforceDistribution, to use standard test config. [#15010](https://github.com/apache/datafusion/pull/15010) (wiedld) +- Update ring to v0.17.13 [#15063](https://github.com/apache/datafusion/pull/15063) (alamb) +- Remove deprecated function `OptimizerRule::try_optimize` [#15051](https://github.com/apache/datafusion/pull/15051) (qazxcdswe123) +- Minor: fix CI to make the sqllogic testing result consistent [#15059](https://github.com/apache/datafusion/pull/15059) (zhuqi-lucas) +- Refactor SortPushdown using the standard top-down visitor and using `EquivalenceProperties` [#14821](https://github.com/apache/datafusion/pull/14821) (wiedld) +- Improve explain tree formatting for longer lines / word wrap [#15031](https://github.com/apache/datafusion/pull/15031) (irenjj) +- chore(deps): bump sqllogictest from 0.27.2 to 0.28.0 [#15060](https://github.com/apache/datafusion/pull/15060) (dependabot[bot]) +- chore(deps): bump async-compression from 0.4.18 to 0.4.19 [#15061](https://github.com/apache/datafusion/pull/15061) (dependabot[bot]) +- Handle columns in with_new_exprs with a Join [#15055](https://github.com/apache/datafusion/pull/15055) (delamarch3) +- Minor: Improve documentation of `need_handle_count_bug` [#15050](https://github.com/apache/datafusion/pull/15050) (suibianwanwank) +- Implement `tree` explain for `HashJoinExec` [#15079](https://github.com/apache/datafusion/pull/15079) (irenjj) +- Implement tree explain for PartialSortExec [#15066](https://github.com/apache/datafusion/pull/15066) (irenjj) +- Implement `tree` explain for `SortExec` [#15077](https://github.com/apache/datafusion/pull/15077) (irenjj) +- Minor: final `46.0.0` release tweaks: changelog + instructions [#15073](https://github.com/apache/datafusion/pull/15073) (alamb) +- Implement tree explain for `NestedLoopJoinExec`, `CrossJoinExec`, `So… [#15081](https://github.com/apache/datafusion/pull/15081) (irenjj) +- Implement `tree` explain for `BoundedWindowAggExec` and `WindowAggExec` [#15084](https://github.com/apache/datafusion/pull/15084) (irenjj) +- implement tree rendering for StreamingTableExec [#15085](https://github.com/apache/datafusion/pull/15085) (Standing-Man) +- chore(deps): bump semver from 1.0.25 to 1.0.26 [#15116](https://github.com/apache/datafusion/pull/15116) (dependabot[bot]) +- chore(deps): bump clap from 4.5.30 to 4.5.31 [#15115](https://github.com/apache/datafusion/pull/15115) (dependabot[bot]) +- implement tree explain for GlobalLimitExec [#15100](https://github.com/apache/datafusion/pull/15100) (zjregee) +- Minor: Cleanup useless/duplicated code in gen tools [#15113](https://github.com/apache/datafusion/pull/15113) (lewiszlw) +- Refactor EnforceDistribution test cases to demonstrate dependencies across optimizer runs. [#15074](https://github.com/apache/datafusion/pull/15074) (wiedld) +- Improve parsing `extra_info` in tree explain [#15125](https://github.com/apache/datafusion/pull/15125) (irenjj) +- Add tests for simplification and coercion of `SessionContext::create_physical_expr` [#15034](https://github.com/apache/datafusion/pull/15034) (alamb) +- Minor: Fix invalid query in test [#15131](https://github.com/apache/datafusion/pull/15131) (alamb) +- Do not display logical_plan win explain `tree` mode 🧹 [#15132](https://github.com/apache/datafusion/pull/15132) (alamb) +- Substrait support for propagating TableScan.filters to Substrait ReadRel.filter [#14194](https://github.com/apache/datafusion/pull/14194) (jamxia155) +- Fix wasm32 build on version 46 [#15102](https://github.com/apache/datafusion/pull/15102) (XiangpengHao) +- Fix broken `serde` feature [#15124](https://github.com/apache/datafusion/pull/15124) (vadimpiven) +- chore(deps): bump tempfile from 3.17.1 to 3.18.0 [#15146](https://github.com/apache/datafusion/pull/15146) (dependabot[bot]) +- chore(deps): bump syn from 2.0.98 to 2.0.100 [#15147](https://github.com/apache/datafusion/pull/15147) (dependabot[bot]) +- Implement tree explain for AggregateExec [#15103](https://github.com/apache/datafusion/pull/15103) (zebsme) +- Implement tree explain for `RepartitionExec` and `WorkTableExec` [#15137](https://github.com/apache/datafusion/pull/15137) (Standing-Man) +- Expand wildcard to actual expressions in `prepare_select_exprs` [#15090](https://github.com/apache/datafusion/pull/15090) (jayzhan211) +- fixed PushDownFilter bug [15047] [#15142](https://github.com/apache/datafusion/pull/15142) (Jiashu-Hu) +- Bump `env_logger` from `0.11.6` to `0.11.7` [#15148](https://github.com/apache/datafusion/pull/15148) (mbrobbel) +- Minor: fix extend sqllogical consistent with main test [#15145](https://github.com/apache/datafusion/pull/15145) (zhuqi-lucas) +- Implement tree rendering for `SortPreservingMergeExec` [#15140](https://github.com/apache/datafusion/pull/15140) (Standing-Man) +- Remove expand wildcard rule [#15170](https://github.com/apache/datafusion/pull/15170) (jayzhan211) +- chore: remove ScalarUDFImpl::return_type_from_exprs [#15130](https://github.com/apache/datafusion/pull/15130) (Blizzara) +- chore(deps): bump libc from 0.2.170 to 0.2.171 [#15176](https://github.com/apache/datafusion/pull/15176) (dependabot[bot]) +- chore(deps): bump serde_json from 1.0.139 to 1.0.140 [#15175](https://github.com/apache/datafusion/pull/15175) (dependabot[bot]) +- chore(deps): bump substrait from 0.53.2 to 0.54.0 [#15043](https://github.com/apache/datafusion/pull/15043) (dependabot[bot]) +- Minor: split EXPLAIN and ANALYZE planning into different functions [#15188](https://github.com/apache/datafusion/pull/15188) (alamb) +- Implement `tree` explain for `JsonSink` [#15185](https://github.com/apache/datafusion/pull/15185) (irenjj) +- Split out `datafusion-substrait` and `datafusion-proto` CI feature checks, increase coverage [#15156](https://github.com/apache/datafusion/pull/15156) (alamb) +- Remove unused wildcard expanding methods [#15180](https://github.com/apache/datafusion/pull/15180) (goldmedal) +- #15108 issue: "Non Panic Task error" is not an internal error [#15109](https://github.com/apache/datafusion/pull/15109) (Satyam018) +- Implement tree explain for LazyMemoryExec [#15187](https://github.com/apache/datafusion/pull/15187) (zebsme) +- implement tree explain for CoalesceBatchesExec [#15194](https://github.com/apache/datafusion/pull/15194) (Standing-Man) +- Implement `tree` explain for `CsvSink` [#15204](https://github.com/apache/datafusion/pull/15204) (irenjj) +- chore(deps): bump blake3 from 1.6.0 to 1.6.1 [#15198](https://github.com/apache/datafusion/pull/15198) (dependabot[bot]) +- chore(deps): bump clap from 4.5.31 to 4.5.32 [#15199](https://github.com/apache/datafusion/pull/15199) (dependabot[bot]) +- chore(deps): bump serde from 1.0.218 to 1.0.219 [#15197](https://github.com/apache/datafusion/pull/15197) (dependabot[bot]) +- Fix datafusion proto crate `json` feature [#15172](https://github.com/apache/datafusion/pull/15172) (Owen-CH-Leung) +- Add blog link to `EquivalenceProperties` docs [#15215](https://github.com/apache/datafusion/pull/15215) (alamb) +- Minor: split datafusion-cli testing into its own CI job [#15075](https://github.com/apache/datafusion/pull/15075) (alamb) +- Implement tree explain for InterleaveExec [#15219](https://github.com/apache/datafusion/pull/15219) (zebsme) +- Move catalog_common out of core [#15193](https://github.com/apache/datafusion/pull/15193) (logan-keede) +- chore(deps): bump tokio-util from 0.7.13 to 0.7.14 [#15223](https://github.com/apache/datafusion/pull/15223) (dependabot[bot]) +- chore(deps): bump aws-config from 1.5.18 to 1.6.0 [#15222](https://github.com/apache/datafusion/pull/15222) (dependabot[bot]) +- chore(deps): bump bzip2 from 0.5.1 to 0.5.2 [#15221](https://github.com/apache/datafusion/pull/15221) (dependabot[bot]) +- Document guidelines for physical operator yielding [#15030](https://github.com/apache/datafusion/pull/15030) (carols10cents) +- Implement `tree` explain for `ArrowFileSink`, fix original URL [#15206](https://github.com/apache/datafusion/pull/15206) (irenjj) +- Implement tree explain for `LocalLimitExec` [#15232](https://github.com/apache/datafusion/pull/15232) (shruti2522) +- Use insta for `DataFrame` tests [#15165](https://github.com/apache/datafusion/pull/15165) (blaginin) +- Re-enable github discussion [#15241](https://github.com/apache/datafusion/pull/15241) (2010YOUY01) +- Minor: exclude datafusion-cli testing for mac [#15240](https://github.com/apache/datafusion/pull/15240) (zhuqi-lucas) +- Implement tree explain for CoalescePartitionsExec [#15225](https://github.com/apache/datafusion/pull/15225) (Shreyaskr1409) +- Enable `used_underscore_binding` clippy lint [#15189](https://github.com/apache/datafusion/pull/15189) (Shreyaskr1409) +- Simpler to see expressions in explain `tree` mode [#15163](https://github.com/apache/datafusion/pull/15163) (irenjj) +- Fix invalid schema for unions in ViewTables [#15135](https://github.com/apache/datafusion/pull/15135) (Friede80) +- Make `ListingTableUrl::try_new` public [#15250](https://github.com/apache/datafusion/pull/15250) (linhr) +- Fix wildcard dataframe case [#15230](https://github.com/apache/datafusion/pull/15230) (jayzhan211) +- Simplify the printing of all plans containing `expr` in `tree` mode [#15249](https://github.com/apache/datafusion/pull/15249) (irenjj) +- Support utf8view datatype for window [#15257](https://github.com/apache/datafusion/pull/15257) (zhuqi-lucas) +- chore: remove deprecated variants of UDF's invoke (invoke, invoke_no_args, invoke_batch) [#15123](https://github.com/apache/datafusion/pull/15123) (Blizzara) +- Improve feature flag CI coverage `datafusion` and `datafusion-functions` [#15203](https://github.com/apache/datafusion/pull/15203) (alamb) +- Add debug logging for default catalog overwrite in SessionState build [#15251](https://github.com/apache/datafusion/pull/15251) (byte-sourcerer) +- Implement tree explain for PlaceholderRowExec [#15270](https://github.com/apache/datafusion/pull/15270) (zebsme) +- Implement tree explain for UnionExec [#15278](https://github.com/apache/datafusion/pull/15278) (zebsme) +- Migrate dataframe tests to `insta` [#15262](https://github.com/apache/datafusion/pull/15262) (jsai28) +- Minor: consistently apply `clippy::clone_on_ref_ptr` in all crates [#15284](https://github.com/apache/datafusion/pull/15284) (alamb) +- chore(deps): bump async-trait from 0.1.87 to 0.1.88 [#15294](https://github.com/apache/datafusion/pull/15294) (dependabot[bot]) +- chore(deps): bump uuid from 1.15.1 to 1.16.0 [#15292](https://github.com/apache/datafusion/pull/15292) (dependabot[bot]) +- Add CatalogProvider and SchemaProvider to FFI Crate [#15280](https://github.com/apache/datafusion/pull/15280) (timsaucer) +- Refactor file schema type coercions [#15268](https://github.com/apache/datafusion/pull/15268) (xudong963) +- chore(deps): bump rust_decimal from 1.36.0 to 1.37.0 [#15293](https://github.com/apache/datafusion/pull/15293) (dependabot[bot]) +- chore: Attach Diagnostic to "incompatible type in unary expression" error [#15209](https://github.com/apache/datafusion/pull/15209) (onlyjackfrost) +- Support logic optimize rule to pass the case that Utf8view datatype combined with Utf8 datatype [#15239](https://github.com/apache/datafusion/pull/15239) (zhuqi-lucas) +- Migrate user_defined tests to insta [#15255](https://github.com/apache/datafusion/pull/15255) (shruti2522) +- Remove inline table scan analyzer rule [#15201](https://github.com/apache/datafusion/pull/15201) (jayzhan211) +- CI Red: Fix union in view table test [#15300](https://github.com/apache/datafusion/pull/15300) (jayzhan211) +- refactor: Move view and stream from `datasource` to `catalog`, deprecate `View::try_new` [#15260](https://github.com/apache/datafusion/pull/15260) (logan-keede) +- chore(deps): bump substrait from 0.54.0 to 0.55.0 [#15305](https://github.com/apache/datafusion/pull/15305) (dependabot[bot]) +- chore(deps): bump half from 2.4.1 to 2.5.0 [#15303](https://github.com/apache/datafusion/pull/15303) (dependabot[bot]) +- chore(deps): bump mimalloc from 0.1.43 to 0.1.44 [#15304](https://github.com/apache/datafusion/pull/15304) (dependabot[bot]) +- Fix predicate pushdown for custom SchemaAdapters [#15263](https://github.com/apache/datafusion/pull/15263) (adriangb) +- Fix extended tests by restore datafusion-testing submodule [#15318](https://github.com/apache/datafusion/pull/15318) (alamb) +- Support Duration in min/max agg functions [#15310](https://github.com/apache/datafusion/pull/15310) (svranesevic) +- Migrate tests to insta [#15288](https://github.com/apache/datafusion/pull/15288) (jsai28) +- chore(deps): bump quote from 1.0.38 to 1.0.40 [#15332](https://github.com/apache/datafusion/pull/15332) (dependabot[bot]) +- chore(deps): bump blake3 from 1.6.1 to 1.7.0 [#15331](https://github.com/apache/datafusion/pull/15331) (dependabot[bot]) +- Simplify display format of `AggregateFunctionExpr`, add `Expr::sql_name` [#15253](https://github.com/apache/datafusion/pull/15253) (irenjj) +- chore(deps): bump indexmap from 2.7.1 to 2.8.0 [#15333](https://github.com/apache/datafusion/pull/15333) (dependabot[bot]) +- chore(deps): bump tokio from 1.43.0 to 1.44.1 [#15347](https://github.com/apache/datafusion/pull/15347) (dependabot[bot]) +- chore(deps): bump tempfile from 3.18.0 to 3.19.1 [#15346](https://github.com/apache/datafusion/pull/15346) (dependabot[bot]) +- Minor: Keep debug symbols for `release-nonlto` build [#15350](https://github.com/apache/datafusion/pull/15350) (2010YOUY01) +- Use `any` instead of `for_each` [#15289](https://github.com/apache/datafusion/pull/15289) (xudong963) +- refactor: move `CteWorkTable`, `default_table_source` a bunch of files out of core [#15316](https://github.com/apache/datafusion/pull/15316) (logan-keede) +- Fix empty aggregation function count() in Substrait [#15345](https://github.com/apache/datafusion/pull/15345) (gabotechs) +- Improved error for expand wildcard rule [#15287](https://github.com/apache/datafusion/pull/15287) (Jiashu-Hu) +- Added tests with are writing into parquet files in memory for issue #… [#15325](https://github.com/apache/datafusion/pull/15325) (pranavJibhakate) +- Migrate physical plan tests to `insta` (Part-1) [#15313](https://github.com/apache/datafusion/pull/15313) (Shreyaskr1409) +- Fix array_has_all and array_has_any with empty array [#15039](https://github.com/apache/datafusion/pull/15039) (LuQQiu) +- Update datafusion-testing pin to fix extended tests [#15368](https://github.com/apache/datafusion/pull/15368) (alamb) +- chore(deps): Update sqlparser to 0.55.0 [#15183](https://github.com/apache/datafusion/pull/15183) (PokIsemaine) +- Only unnest source for `EmptyRelation` [#15159](https://github.com/apache/datafusion/pull/15159) (blaginin) +- chore(deps): bump rust_decimal from 1.37.0 to 1.37.1 [#15378](https://github.com/apache/datafusion/pull/15378) (dependabot[bot]) +- chore(deps): bump chrono-tz from 0.10.1 to 0.10.2 [#15377](https://github.com/apache/datafusion/pull/15377) (dependabot[bot]) +- remove the duplicate test for unparser [#15385](https://github.com/apache/datafusion/pull/15385) (goldmedal) +- Minor: add average time for clickbench benchmark query [#15381](https://github.com/apache/datafusion/pull/15381) (zhuqi-lucas) +- include some BinaryOperator from sqlparser [#15327](https://github.com/apache/datafusion/pull/15327) (waynexia) +- Add "end to end parquet reading test" for WASM [#15362](https://github.com/apache/datafusion/pull/15362) (jsai28) +- Migrate physical plan tests to `insta` (Part-2) [#15364](https://github.com/apache/datafusion/pull/15364) (Shreyaskr1409) +- Migrate physical plan tests to `insta` (Part-3 / Final) [#15399](https://github.com/apache/datafusion/pull/15399) (Shreyaskr1409) +- Restore lazy evaluation of fallible CASE [#15390](https://github.com/apache/datafusion/pull/15390) (findepi) +- chore(deps): bump log from 0.4.26 to 0.4.27 [#15410](https://github.com/apache/datafusion/pull/15410) (dependabot[bot]) +- chore(deps): bump chrono-tz from 0.10.2 to 0.10.3 [#15412](https://github.com/apache/datafusion/pull/15412) (dependabot[bot]) +- Perf: Support Utf8View datatype single column comparisons for SortPreservingMergeStream [#15348](https://github.com/apache/datafusion/pull/15348) (zhuqi-lucas) +- Enforce JOIN plan to require condition [#15334](https://github.com/apache/datafusion/pull/15334) (goldmedal) +- Fix type coercion for unsigned and signed integers (`Int64` vs `UInt64`, etc) [#15341](https://github.com/apache/datafusion/pull/15341) (Omega359) +- simplify `array_has` UDF to `InList` expr when haystack is constant [#15354](https://github.com/apache/datafusion/pull/15354) (davidhewitt) +- Move `DataSink` to `datasource` and add session crate [#15371](https://github.com/apache/datafusion/pull/15371) (jayzhan-synnada) +- refactor: SpillManager into a separate file [#15407](https://github.com/apache/datafusion/pull/15407) (Weijun-H) +- Always use `PartitionMode::Auto` in planner [#15339](https://github.com/apache/datafusion/pull/15339) (Dandandan) +- Fix link to Volcano paper [#15437](https://github.com/apache/datafusion/pull/15437) (JackKelly) +- minor: Add new crates to labeler [#15426](https://github.com/apache/datafusion/pull/15426) (logan-keede) +- refactor: Use SpillManager for all spilling scenarios [#15405](https://github.com/apache/datafusion/pull/15405) (2010YOUY01) +- refactor(hash_join): Move JoinHashMap to separate mod [#15419](https://github.com/apache/datafusion/pull/15419) (ctsk) +- Migrate datasource tests to insta [#15258](https://github.com/apache/datafusion/pull/15258) (shruti2522) +- Add `downcast_to_source` method for `DataSourceExec` [#15416](https://github.com/apache/datafusion/pull/15416) (xudong963) +- refactor: use TypeSignature::Coercible for crypto functions [#14826](https://github.com/apache/datafusion/pull/14826) (Chen-Yuan-Lai) +- Minor: fix doc for `FileGroupPartitioner` [#15448](https://github.com/apache/datafusion/pull/15448) (xudong963) +- chore(deps): bump clap from 4.5.32 to 4.5.34 [#15452](https://github.com/apache/datafusion/pull/15452) (dependabot[bot]) +- Fix roundtrip bug with empty projection in DataSourceExec [#15449](https://github.com/apache/datafusion/pull/15449) (XiangpengHao) +- Triggering extended tests through PR comment: `Run extended tests` [#15101](https://github.com/apache/datafusion/pull/15101) (danila-b) +- Use `equals_datatype` to compare type when type coercion [#15366](https://github.com/apache/datafusion/pull/15366) (goldmedal) +- Fix no effect metrics bug in ParquetSource [#15460](https://github.com/apache/datafusion/pull/15460) (XiangpengHao) +- chore(deps): bump aws-config from 1.6.0 to 1.6.1 [#15470](https://github.com/apache/datafusion/pull/15470) (dependabot[bot]) +- minor: Allow to run TPCH bench for a specific query [#15467](https://github.com/apache/datafusion/pull/15467) (comphead) +- Migrate subtraits tests to insta, part1 [#15444](https://github.com/apache/datafusion/pull/15444) (qstommyshu) +- Add `FileScanConfigBuilder` [#15352](https://github.com/apache/datafusion/pull/15352) (blaginin) +- Update ClickBench queries to avoid to_timestamp_seconds [#15475](https://github.com/apache/datafusion/pull/15475) (acking-you) +- Remove CoalescePartitions insertion from HashJoinExec [#15476](https://github.com/apache/datafusion/pull/15476) (ctsk) +- Migrate-substrait-tests-to-insta, part2 [#15480](https://github.com/apache/datafusion/pull/15480) (qstommyshu) +- Revert #15476 to fix the datafusion-examples CI fail [#15496](https://github.com/apache/datafusion/pull/15496) (goldmedal) +- Migrate datafusion/sql tests to insta, part1 [#15497](https://github.com/apache/datafusion/pull/15497) (qstommyshu) +- Allow type coersion of zero input arrays to nullary [#15487](https://github.com/apache/datafusion/pull/15487) (timsaucer) +- Decimal type support for `to_timestamp` [#15486](https://github.com/apache/datafusion/pull/15486) (jatin510) +- refactor: Move `Memtable` to catalog [#15459](https://github.com/apache/datafusion/pull/15459) (logan-keede) +- Migrate optimizer tests to insta [#15446](https://github.com/apache/datafusion/pull/15446) (qstommyshu) +- FIX : some benchmarks are failing [#15367](https://github.com/apache/datafusion/pull/15367) (getChan) +- Add query to extended clickbench suite for "complex filter" [#15500](https://github.com/apache/datafusion/pull/15500) (acking-you) +- Extract tokio runtime creation from hot loop in benchmarks [#15508](https://github.com/apache/datafusion/pull/15508) (Omega359) +- chore(deps): bump blake3 from 1.7.0 to 1.8.0 [#15502](https://github.com/apache/datafusion/pull/15502) (dependabot[bot]) +- Minor: clone and debug for FileSinkConfig [#15516](https://github.com/apache/datafusion/pull/15516) (jayzhan211) +- use state machine to refactor the `get_files_with_limit` method [#15521](https://github.com/apache/datafusion/pull/15521) (xudong963) +- Migrate `datafusion/sql` tests to insta, part2 [#15499](https://github.com/apache/datafusion/pull/15499) (qstommyshu) +- Disable sccache action to fix gh cache issue [#15536](https://github.com/apache/datafusion/pull/15536) (Omega359) +- refactor: Cleanup unused `fetch` field inside `ExternalSorter` [#15525](https://github.com/apache/datafusion/pull/15525) (2010YOUY01) +- Fix duplicate unqualified Field name (schema error) on join queries [#15438](https://github.com/apache/datafusion/pull/15438) (LiaCastaneda) +- Add utf8view benchmark for aggregate topk [#15518](https://github.com/apache/datafusion/pull/15518) (zhuqi-lucas) +- ArraySort: support structs [#15527](https://github.com/apache/datafusion/pull/15527) (cht42) +- Migrate datafusion/sql tests to insta, part3 [#15533](https://github.com/apache/datafusion/pull/15533) (qstommyshu) +- Migrate datafusion/sql tests to insta, part4 [#15548](https://github.com/apache/datafusion/pull/15548) (qstommyshu) +- Add topk information into tree explain plans [#15547](https://github.com/apache/datafusion/pull/15547) (kumarlokesh) +- Minor: add Arc for statistics in FileGroup [#15564](https://github.com/apache/datafusion/pull/15564) (xudong963) +- Test: configuration fuzzer for (external) sort queries [#15501](https://github.com/apache/datafusion/pull/15501) (2010YOUY01) +- minor: Organize fields inside SortMergeJoinStream [#15557](https://github.com/apache/datafusion/pull/15557) (suibianwanwank) +- Minor: rm session downcast [#15575](https://github.com/apache/datafusion/pull/15575) (jayzhan211) +- Migrate datafusion/sql tests to insta, part5 [#15567](https://github.com/apache/datafusion/pull/15567) (qstommyshu) +- Add SQL logic tests for compound field access in JOIN conditions [#15556](https://github.com/apache/datafusion/pull/15556) (kosiew) +- Run audit CI check on all pushes to main [#15572](https://github.com/apache/datafusion/pull/15572) (alamb) +- Introduce load-balanced `split_groups_by_statistics` method [#15473](https://github.com/apache/datafusion/pull/15473) (xudong963) +- chore: update clickbench [#15574](https://github.com/apache/datafusion/pull/15574) (chenkovsky) +- Improve spill performance: Disable re-validation of spilled files [#15454](https://github.com/apache/datafusion/pull/15454) (zebsme) +- chore: rm duplicated `JoinOn` type [#15590](https://github.com/apache/datafusion/pull/15590) (jayzhan211) +- Chore: Call arrow's methods `row_count` and `skipped_row_count` [#15587](https://github.com/apache/datafusion/pull/15587) (jayzhan211) +- Actually run wasm test in ci [#15595](https://github.com/apache/datafusion/pull/15595) (XiangpengHao) +- Migrate datafusion/sql tests to insta, part6 [#15578](https://github.com/apache/datafusion/pull/15578) (qstommyshu) +- Add test case for new casting feature from date to tz-aware timestamps [#15609](https://github.com/apache/datafusion/pull/15609) (friendlymatthew) +- Remove CoalescePartitions insertion from Joins [#15570](https://github.com/apache/datafusion/pull/15570) (ctsk) +- fix doc and broken api [#15602](https://github.com/apache/datafusion/pull/15602) (logan-keede) +- Migrate datafusion/sql tests to insta, part7 [#15621](https://github.com/apache/datafusion/pull/15621) (qstommyshu) +- ignore security_audit CI check proc-macro-error warning [#15626](https://github.com/apache/datafusion/pull/15626) (Jiashu-Hu) +- chore(deps): bump tokio from 1.44.1 to 1.44.2 [#15627](https://github.com/apache/datafusion/pull/15627) (dependabot[bot]) +- Upgrade toolchain to Rust-1.86 [#15625](https://github.com/apache/datafusion/pull/15625) (jsai28) +- chore(deps): bump bigdecimal from 0.4.7 to 0.4.8 [#15523](https://github.com/apache/datafusion/pull/15523) (dependabot[bot]) +- chore(deps): bump the arrow-parquet group across 1 directory with 7 updates [#15593](https://github.com/apache/datafusion/pull/15593) (dependabot[bot]) +- chore: improve RepartitionExec display tree [#15606](https://github.com/apache/datafusion/pull/15606) (getChan) +- Move back schema not matching check and workaround [#15580](https://github.com/apache/datafusion/pull/15580) (LiaCastaneda) +- Minor: refine comments for statistics compution [#15647](https://github.com/apache/datafusion/pull/15647) (xudong963) +- Remove uneeded binary_op benchmarks [#15632](https://github.com/apache/datafusion/pull/15632) (alamb) +- chore(deps): bump blake3 from 1.8.0 to 1.8.1 [#15650](https://github.com/apache/datafusion/pull/15650) (dependabot[bot]) +- chore(deps): bump mimalloc from 0.1.44 to 0.1.46 [#15651](https://github.com/apache/datafusion/pull/15651) (dependabot[bot]) +- chore: avoid erroneuous warning for FFI table operation (only not default value) [#15579](https://github.com/apache/datafusion/pull/15579) (chenkovsky) +- Update datafusion-testing pin (to fix extended test on main) [#15655](https://github.com/apache/datafusion/pull/15655) (alamb) +- Ignore false positive only_used_in_recursion Clippy warning [#15635](https://github.com/apache/datafusion/pull/15635) (DerGut) +- chore: Rename protobuf Java package [#15658](https://github.com/apache/datafusion/pull/15658) (andygrove) +- Remove redundant `Precision` combination code in favor of `Precision::min/max/add` [#15659](https://github.com/apache/datafusion/pull/15659) (alamb) +- Introduce DynamicFilterSource and DynamicPhysicalExpr [#15568](https://github.com/apache/datafusion/pull/15568) (adriangb) +- Public some projected methods in `FileScanConfig` [#15671](https://github.com/apache/datafusion/pull/15671) (xudong963) +- fix decimal precision issue in simplify expression optimize rule [#15588](https://github.com/apache/datafusion/pull/15588) (jayzhan211) +- Implement Future for SpawnedTask. [#15653](https://github.com/apache/datafusion/pull/15653) (ashdnazg) +- chore(deps): bump crossbeam-channel from 0.5.14 to 0.5.15 [#15674](https://github.com/apache/datafusion/pull/15674) (dependabot[bot]) +- chore(deps): bump clap from 4.5.34 to 4.5.35 [#15668](https://github.com/apache/datafusion/pull/15668) (dependabot[bot]) +- [Minor] Use interleave_record_batch in TopK implementation [#15677](https://github.com/apache/datafusion/pull/15677) (Dandandan) +- Consolidate statistics merging code (try 2) [#15661](https://github.com/apache/datafusion/pull/15661) (alamb) +- Add Table Functions to FFI Crate [#15581](https://github.com/apache/datafusion/pull/15581) (timsaucer) +- Remove waits from blocking threads reading spill files. [#15654](https://github.com/apache/datafusion/pull/15654) (ashdnazg) +- chore(deps): bump sysinfo from 0.33.1 to 0.34.2 [#15682](https://github.com/apache/datafusion/pull/15682) (dependabot[bot]) +- Minor: add order by arg for last value [#15695](https://github.com/apache/datafusion/pull/15695) (jayzhan211) +- Upgrade to arrow/parquet 55, and `object_store` to `0.12.0` and pyo3 to `0.24.0` [#15466](https://github.com/apache/datafusion/pull/15466) (alamb) +- tests: only refresh the minimum sysinfo in mem limit tests. [#15702](https://github.com/apache/datafusion/pull/15702) (ashdnazg) +- ci: fix workflow triggering extended tests from pr comments. [#15704](https://github.com/apache/datafusion/pull/15704) (ashdnazg) +- chore(deps): bump flate2 from 1.1.0 to 1.1.1 [#15703](https://github.com/apache/datafusion/pull/15703) (dependabot[bot]) +- Fix internal error in sort when hitting memory limit [#15692](https://github.com/apache/datafusion/pull/15692) (DerGut) +- Update checked in Cargo.lock file to get clean CI [#15725](https://github.com/apache/datafusion/pull/15725) (alamb) +- chore(deps): bump indexmap from 2.8.0 to 2.9.0 [#15732](https://github.com/apache/datafusion/pull/15732) (dependabot[bot]) +- Minor: include output partition count of `RepartitionExec` to tree explain [#15717](https://github.com/apache/datafusion/pull/15717) (2010YOUY01) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 48 dependabot[bot] + 34 Andrew Lamb + 16 xudong.w + 15 Jay Zhan + 15 Qi Zhu + 15 irenjj + 13 Chen Chongchen + 13 Yongting You + 10 Tommy shu + 7 Shruti Sharma + 6 Alan Tang + 6 Arttu + 6 Jiashu Hu + 6 Shreyas (Lua) + 6 logan-keede + 6 zeb + 5 Dmitrii Blaginin + 5 Geoffrey Claude + 5 Jax Liu + 5 YuNing Chen + 4 Bruce Ritchie + 4 Christian + 4 Eshed Schacham + 4 Xiangpeng Hao + 4 wiedld + 3 Adrian Garcia Badaracco + 3 Daniël Heres + 3 Gabriel + 3 LB7666 + 3 Namgung Chan + 3 Ruihang Xia + 3 Tim Saucer + 3 jsai28 + 3 kosiew + 3 suibianwanwan + 2 Bryce Mecum + 2 Carol (Nichols || Goulding) + 2 Heran Lin + 2 Jannik Steinmann + 2 Jyotir Sai + 2 Li-Lun Lin + 2 Lía Adriana + 2 Oleks V + 2 Raz Luvaton + 2 UBarney + 2 aditya singh rathore + 2 westhide + 2 zjregee + 1 @clflushopt + 1 Adam Gutglick + 1 Alex Huang + 1 Alex Wilcoxson + 1 Amos Aidoo + 1 Andy Grove + 1 Andy Yen + 1 Berkay Şahin + 1 Chang + 1 Danila Baklazhenko + 1 David Hewitt + 1 Emily Matheys + 1 Eren Avsarogullari + 1 Hari Varsha + 1 Ian Lai + 1 Jack Kelly + 1 Jagdish Parihar + 1 Joseph Koshakow + 1 Lokesh + 1 LuQQiu + 1 Matt Butrovich + 1 Matt Friede + 1 Matthew Kim + 1 Matthijs Brobbel + 1 Om Kenge + 1 Owen Leung + 1 Peter L + 1 Piotr Findeisen + 1 Rohan Krishnaswamy + 1 Satyam018 + 1 Sava Vranešević + 1 Siddhartha Sahu + 1 Sile Zhou + 1 Vadim Piven + 1 Zaki + 1 christophermcdermott + 1 cht42 + 1 cjw + 1 delamarch3 + 1 ding-young + 1 haruband + 1 jamxia155 + 1 oznur-synnada + 1 peasee + 1 pranavJibhakate + 1 张林伟 +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/48.0.0.md b/dev/changelog/48.0.0.md new file mode 100644 index 0000000000000..9cf6c03b7acf0 --- /dev/null +++ b/dev/changelog/48.0.0.md @@ -0,0 +1,405 @@ + + +# Apache DataFusion 48.0.0 Changelog + +This release consists of 267 commits from 89 contributors. See credits at the end of this changelog for more information. + +**Breaking changes:** + +- Attach Diagnostic to syntax errors [#15680](https://github.com/apache/datafusion/pull/15680) (logan-keede) +- Change `flatten` so it does only a level, not recursively [#15160](https://github.com/apache/datafusion/pull/15160) (delamarch3) +- Improve `simplify_expressions` rule [#15735](https://github.com/apache/datafusion/pull/15735) (xudong963) +- Support WITHIN GROUP syntax to standardize certain existing aggregate functions [#13511](https://github.com/apache/datafusion/pull/13511) (Garamda) +- Add Extension Type / Metadata support for Scalar UDFs [#15646](https://github.com/apache/datafusion/pull/15646) (timsaucer) +- chore: fix clippy::large_enum_variant for DataFusionError [#15861](https://github.com/apache/datafusion/pull/15861) (rroelke) +- Feat: introduce `ExecutionPlan::partition_statistics` API [#15852](https://github.com/apache/datafusion/pull/15852) (xudong963) +- refactor: remove deprecated `ParquetExec` [#15973](https://github.com/apache/datafusion/pull/15973) (miroim) +- refactor: remove deprecated `ArrowExec` [#16006](https://github.com/apache/datafusion/pull/16006) (miroim) +- refactor: remove deprecated `MemoryExec` [#16007](https://github.com/apache/datafusion/pull/16007) (miroim) +- refactor: remove deprecated `JsonExec` [#16005](https://github.com/apache/datafusion/pull/16005) (miroim) +- feat: metadata handling for aggregates and window functions [#15911](https://github.com/apache/datafusion/pull/15911) (timsaucer) +- Remove `Filter::having` field [#16154](https://github.com/apache/datafusion/pull/16154) (findepi) +- Shift from Field to FieldRef for all user defined functions [#16122](https://github.com/apache/datafusion/pull/16122) (timsaucer) +- Change default SQL mapping for `VARCAHR` from `Utf8` to `Utf8View` [#16142](https://github.com/apache/datafusion/pull/16142) (zhuqi-lucas) +- Minor: remove unused IPCWriter [#16215](https://github.com/apache/datafusion/pull/16215) (alamb) +- Reduce size of `Expr` struct [#16207](https://github.com/apache/datafusion/pull/16207) (hendrikmakait) + +**Performance related:** + +- Apply pre-selection and computation skipping to short-circuit optimization [#15694](https://github.com/apache/datafusion/pull/15694) (acking-you) +- Add a fast path for `optimize_projection` [#15746](https://github.com/apache/datafusion/pull/15746) (xudong963) +- Speed up `optimize_projection` by improving `is_projection_unnecessary` [#15761](https://github.com/apache/datafusion/pull/15761) (xudong963) +- Speed up `optimize_projection` [#15787](https://github.com/apache/datafusion/pull/15787) (xudong963) +- Support `GroupsAccumulator` for Avg duration [#15748](https://github.com/apache/datafusion/pull/15748) (shruti2522) +- Optimize performance of `string::ascii` function [#16087](https://github.com/apache/datafusion/pull/16087) (tlm365) + +**Implemented enhancements:** + +- Set DataFusion runtime configurations through SQL interface [#15594](https://github.com/apache/datafusion/pull/15594) (kumarlokesh) +- feat: Add option to adjust writer buffer size for query output [#15747](https://github.com/apache/datafusion/pull/15747) (m09526) +- feat: Add `datafusion-spark` crate [#15168](https://github.com/apache/datafusion/pull/15168) (shehabgamin) +- feat: create helpers to set the max_temp_directory_size [#15919](https://github.com/apache/datafusion/pull/15919) (jdrouet) +- feat: ORDER BY ALL [#15772](https://github.com/apache/datafusion/pull/15772) (PokIsemaine) +- feat: support min/max for struct [#15667](https://github.com/apache/datafusion/pull/15667) (chenkovsky) +- feat(proto): udf decoding fallback [#15997](https://github.com/apache/datafusion/pull/15997) (leoyvens) +- feat: make error handling in indent explain consistent with that in tree [#16097](https://github.com/apache/datafusion/pull/16097) (chenkovsky) +- feat: coerce to/from fixed size binary to binary view [#16110](https://github.com/apache/datafusion/pull/16110) (chenkovsky) +- feat: array_length for fixed size list [#16167](https://github.com/apache/datafusion/pull/16167) (chenkovsky) +- feat: ADD sha2 spark function [#16168](https://github.com/apache/datafusion/pull/16168) (getChan) +- feat: create builder for disk manager [#16191](https://github.com/apache/datafusion/pull/16191) (jdrouet) +- feat: Add Aggregate UDF to FFI crate [#14775](https://github.com/apache/datafusion/pull/14775) (timsaucer) +- feat(small): Add `BaselineMetrics` to `generate_series()` table function [#16255](https://github.com/apache/datafusion/pull/16255) (2010YOUY01) +- feat: Add Window UDFs to FFI Crate [#16261](https://github.com/apache/datafusion/pull/16261) (timsaucer) + +**Fixed bugs:** + +- fix: serialize listing table without partition column [#15737](https://github.com/apache/datafusion/pull/15737) (chenkovsky) +- fix: describe Parquet schema with coerce_int96 [#15750](https://github.com/apache/datafusion/pull/15750) (chenkovsky) +- fix: clickbench type err [#15773](https://github.com/apache/datafusion/pull/15773) (chenkovsky) +- Fix: fetch is missing in `replace_order_preserving_variants` method during `EnforceDistribution` optimizer [#15808](https://github.com/apache/datafusion/pull/15808) (xudong963) +- Fix: fetch is missing in `EnforceSorting` optimizer (two places) [#15822](https://github.com/apache/datafusion/pull/15822) (xudong963) +- fix: Avoid mistaken ILike to string equality optimization [#15836](https://github.com/apache/datafusion/pull/15836) (srh) +- Map file-level column statistics to the table-level [#15865](https://github.com/apache/datafusion/pull/15865) (xudong963) +- fix(avro): Respect projection order in Avro reader [#15840](https://github.com/apache/datafusion/pull/15840) (nantunes) +- fix: correctly specify the nullability of `map_values` return type [#15901](https://github.com/apache/datafusion/pull/15901) (rluvaton) +- Fix CI in main [#15917](https://github.com/apache/datafusion/pull/15917) (blaginin) +- fix: sqllogictest on Windows [#15932](https://github.com/apache/datafusion/pull/15932) (nuno-faria) +- fix: fold cast null to substrait typed null [#15854](https://github.com/apache/datafusion/pull/15854) (discord9) +- Fix: `build_predicate_expression` method doesn't process `false` expr correctly [#15995](https://github.com/apache/datafusion/pull/15995) (xudong963) +- fix: add an "expr_planners" method to SessionState [#15119](https://github.com/apache/datafusion/pull/15119) (niebayes) +- fix: overcounting of memory in first/last. [#15924](https://github.com/apache/datafusion/pull/15924) (ashdnazg) +- fix: track timing for coalescer's in execution time [#16048](https://github.com/apache/datafusion/pull/16048) (waynexia) +- fix: stack overflow for substrait functions with large argument lists that translate to DataFusion binary operators [#16031](https://github.com/apache/datafusion/pull/16031) (fmonjalet) +- fix: coerce int96 resolution inside of list, struct, and map types [#16058](https://github.com/apache/datafusion/pull/16058) (mbutrovich) +- fix: Add coercion rules for Float16 types [#15816](https://github.com/apache/datafusion/pull/15816) (etseidl) +- fix: describe escaped quoted identifiers [#16082](https://github.com/apache/datafusion/pull/16082) (jfahne) +- fix: Remove trailing whitespace in `Display` for `LogicalPlan::Projection` [#16164](https://github.com/apache/datafusion/pull/16164) (atahanyorganci) +- fix: metadata of join schema [#16221](https://github.com/apache/datafusion/pull/16221) (chenkovsky) +- fix: add missing row count limits to TPC-H queries [#16230](https://github.com/apache/datafusion/pull/16230) (0ax1) +- fix: NaN semantics in GROUP BY [#16256](https://github.com/apache/datafusion/pull/16256) (chenkovsky) + +**Documentation updates:** + +- Add DataFusion 47.0.0 Upgrade Guide [#15749](https://github.com/apache/datafusion/pull/15749) (alamb) +- Improve documentation for format `OPTIONS` clause [#15708](https://github.com/apache/datafusion/pull/15708) (marvelshan) +- doc: Adding Feldera as known user [#15799](https://github.com/apache/datafusion/pull/15799) (comphead) +- docs: add ArkFlow [#15826](https://github.com/apache/datafusion/pull/15826) (chenquan) +- Fix `from_unixtime` function documentation [#15844](https://github.com/apache/datafusion/pull/15844) (Viicos) +- Upgrade-guide: Downgrade "FileScanConfig –> FileScanConfigBuilder" headline [#15883](https://github.com/apache/datafusion/pull/15883) (simonvandel) +- doc: Update known users docs [#15895](https://github.com/apache/datafusion/pull/15895) (comphead) +- Add `union_tag` scalar function [#14687](https://github.com/apache/datafusion/pull/14687) (gstvg) +- Fix typo in introduction.md [#15910](https://github.com/apache/datafusion/pull/15910) (tom-mont) +- Add `FormatOptions` to Config [#15793](https://github.com/apache/datafusion/pull/15793) (blaginin) +- docs: Label `bloom_filter_on_read` as a reading config [#15933](https://github.com/apache/datafusion/pull/15933) (nuno-faria) +- Implement Parquet filter pushdown via new filter pushdown APIs [#15769](https://github.com/apache/datafusion/pull/15769) (adriangb) +- Enable repartitioning on MemTable. [#15409](https://github.com/apache/datafusion/pull/15409) (wiedld) +- Updated extending operators documentation [#15612](https://github.com/apache/datafusion/pull/15612) (the0ninjas) +- chore: Replace MSRV link on main page with Github badge [#16020](https://github.com/apache/datafusion/pull/16020) (comphead) +- Add note to upgrade guide for removal of `ParquetExec`, `AvroExec`, `CsvExec`, `JsonExec` [#16034](https://github.com/apache/datafusion/pull/16034) (alamb) +- docs: Clarify that it is only the name of the field that is ignored [#16052](https://github.com/apache/datafusion/pull/16052) (alamb) +- [Docs]: Added SQL example for all window functions [#16074](https://github.com/apache/datafusion/pull/16074) (Adez017) +- Fix CI on main: Add window function examples in code [#16102](https://github.com/apache/datafusion/pull/16102) (alamb) +- chore: Remove SMJ experimental status in docs [#16072](https://github.com/apache/datafusion/pull/16072) (comphead) +- doc: fix indent format explain [#16085](https://github.com/apache/datafusion/pull/16085) (chenkovsky) +- Update documentation for `datafusion.execution.collect_statistics` [#16100](https://github.com/apache/datafusion/pull/16100) (alamb) +- Make `SessionContext::register_parquet` obey `collect_statistics` config [#16080](https://github.com/apache/datafusion/pull/16080) (adriangb) +- Improve the DML / DDL Documentation [#16115](https://github.com/apache/datafusion/pull/16115) (alamb) +- docs: Fix typos and minor grammatical issues in Architecture docs [#16119](https://github.com/apache/datafusion/pull/16119) (patrickcsullivan) +- Set `TrackConsumersPool` as default in datafusion-cli [#16081](https://github.com/apache/datafusion/pull/16081) (ding-young) +- Minor: Fix links in substrait readme [#16156](https://github.com/apache/datafusion/pull/16156) (alamb) +- Add macro for creating DataFrame (#16090) [#16104](https://github.com/apache/datafusion/pull/16104) (cj-zhukov) +- doc: Move `dataframe!` example into dedicated example [#16197](https://github.com/apache/datafusion/pull/16197) (comphead) +- doc: add diagram to describe how DataSource, FileSource, and DataSourceExec are related [#16181](https://github.com/apache/datafusion/pull/16181) (onlyjackfrost) +- Clarify documentation about gathering statistics for parquet files [#16157](https://github.com/apache/datafusion/pull/16157) (alamb) +- Add change to VARCHAR in the upgrade guide [#16216](https://github.com/apache/datafusion/pull/16216) (alamb) +- Add iceberg-rust to user list [#16246](https://github.com/apache/datafusion/pull/16246) (jonathanc-n) +- Prepare for 48.0.0 release: Version and Changelog [#16238](https://github.com/apache/datafusion/pull/16238) (xudong963) + +**Other:** + +- Enable setting default values for target_partitions and planning_concurrency [#15712](https://github.com/apache/datafusion/pull/15712) (nuno-faria) +- minor: fix doc comment [#15733](https://github.com/apache/datafusion/pull/15733) (niebayes) +- chore(deps-dev): bump http-proxy-middleware from 2.0.6 to 2.0.9 in /datafusion/wasmtest/datafusion-wasm-app [#15738](https://github.com/apache/datafusion/pull/15738) (dependabot[bot]) +- Avoid computing unnecessary statstics [#15729](https://github.com/apache/datafusion/pull/15729) (xudong963) +- chore(deps): bump libc from 0.2.171 to 0.2.172 [#15745](https://github.com/apache/datafusion/pull/15745) (dependabot[bot]) +- Final release note touchups [#15741](https://github.com/apache/datafusion/pull/15741) (alamb) +- Refactor regexp slt tests [#15709](https://github.com/apache/datafusion/pull/15709) (kumarlokesh) +- ExecutionPlan: add APIs for filter pushdown & optimizer rule to apply them [#15566](https://github.com/apache/datafusion/pull/15566) (adriangb) +- Coerce and simplify FixedSizeBinary equality to literal binary [#15726](https://github.com/apache/datafusion/pull/15726) (leoyvens) +- Minor: simplify code in datafusion-proto [#15752](https://github.com/apache/datafusion/pull/15752) (alamb) +- chore(deps): bump clap from 4.5.35 to 4.5.36 [#15759](https://github.com/apache/datafusion/pull/15759) (dependabot[bot]) +- Support `Accumulator` for avg duration [#15468](https://github.com/apache/datafusion/pull/15468) (shruti2522) +- Show current SQL recursion limit in RecursionLimitExceeded error message [#15644](https://github.com/apache/datafusion/pull/15644) (kumarlokesh) +- Minor: fix flaky test in `aggregate.slt` [#15786](https://github.com/apache/datafusion/pull/15786) (xudong963) +- Minor: remove unused logic for limit pushdown [#15730](https://github.com/apache/datafusion/pull/15730) (zhuqi-lucas) +- chore(deps): bump sqllogictest from 0.28.0 to 0.28.1 [#15788](https://github.com/apache/datafusion/pull/15788) (dependabot[bot]) +- Add try_new for LogicalPlan::Join [#15757](https://github.com/apache/datafusion/pull/15757) (kumarlokesh) +- Minor: eliminate unnecessary struct creation in session state build [#15800](https://github.com/apache/datafusion/pull/15800) (Rachelint) +- chore(deps): bump half from 2.5.0 to 2.6.0 [#15806](https://github.com/apache/datafusion/pull/15806) (dependabot[bot]) +- Add `or_fun_call` and `unnecessary_lazy_evaluations` lints on `core` [#15807](https://github.com/apache/datafusion/pull/15807) (Rachelint) +- chore(deps): bump env_logger from 0.11.7 to 0.11.8 [#15823](https://github.com/apache/datafusion/pull/15823) (dependabot[bot]) +- Support unparsing `UNION` for distinct results [#15814](https://github.com/apache/datafusion/pull/15814) (phillipleblanc) +- Add `MemoryPool::memory_limit` to expose setting memory usage limit [#15828](https://github.com/apache/datafusion/pull/15828) (Rachelint) +- Preserve projection for inline scan [#15825](https://github.com/apache/datafusion/pull/15825) (jayzhan211) +- Minor: cleanup hash table after emit all [#15834](https://github.com/apache/datafusion/pull/15834) (jayzhan211) +- chore(deps): bump pyo3 from 0.24.1 to 0.24.2 [#15838](https://github.com/apache/datafusion/pull/15838) (dependabot[bot]) +- Minor: fix potential flaky test in aggregate.slt [#15829](https://github.com/apache/datafusion/pull/15829) (bikbov) +- Fix `ILIKE` expression support in SQL unparser [#15820](https://github.com/apache/datafusion/pull/15820) (ewgenius) +- Make `Diagnostic` easy/convinient to attach by using macro and avoiding `map_err` [#15796](https://github.com/apache/datafusion/pull/15796) (logan-keede) +- Feature/benchmark config from env [#15782](https://github.com/apache/datafusion/pull/15782) (ctsk) +- predicate pruning: support cast and try_cast for more types [#15764](https://github.com/apache/datafusion/pull/15764) (adriangb) +- Fix: fetch is missing in `plan_with_order_breaking_variants` method [#15842](https://github.com/apache/datafusion/pull/15842) (xudong963) +- Fix `CoalescePartitionsExec` proto serialization [#15824](https://github.com/apache/datafusion/pull/15824) (lewiszlw) +- Fix build failure caused by new `CoalescePartitionsExec::with_fetch` method [#15849](https://github.com/apache/datafusion/pull/15849) (lewiszlw) +- Fix ScalarValue::List comparison when the compared lists have different lengths [#15856](https://github.com/apache/datafusion/pull/15856) (gabotechs) +- chore: More details to `No UDF registered` error [#15843](https://github.com/apache/datafusion/pull/15843) (comphead) +- chore(deps): bump clap from 4.5.36 to 4.5.37 [#15853](https://github.com/apache/datafusion/pull/15853) (dependabot[bot]) +- Remove usage of `dbg!` [#15858](https://github.com/apache/datafusion/pull/15858) (phillipleblanc) +- Minor: Interval singleton [#15859](https://github.com/apache/datafusion/pull/15859) (jayzhan211) +- Make aggr fuzzer query builder more configurable [#15851](https://github.com/apache/datafusion/pull/15851) (Rachelint) +- chore(deps): bump aws-config from 1.6.1 to 1.6.2 [#15874](https://github.com/apache/datafusion/pull/15874) (dependabot[bot]) +- Add slt tests for `datafusion.execution.parquet.coerce_int96` setting [#15723](https://github.com/apache/datafusion/pull/15723) (alamb) +- Improve `ListingTable` / `ListingTableOptions` docs [#15767](https://github.com/apache/datafusion/pull/15767) (alamb) +- Migrate Optimizer tests to insta, part2 [#15884](https://github.com/apache/datafusion/pull/15884) (qstommyshu) +- Improve documentation for `FileSource`, `DataSource` and `DataSourceExec` [#15766](https://github.com/apache/datafusion/pull/15766) (alamb) +- Implement min max for dictionary types [#15827](https://github.com/apache/datafusion/pull/15827) (XiangpengHao) +- chore(deps): bump blake3 from 1.8.1 to 1.8.2 [#15890](https://github.com/apache/datafusion/pull/15890) (dependabot[bot]) +- Respect ignore_nulls in array_agg [#15544](https://github.com/apache/datafusion/pull/15544) (joroKr21) +- Set HashJoin seed [#15783](https://github.com/apache/datafusion/pull/15783) (ctsk) +- Saner handling of nulls inside arrays [#15149](https://github.com/apache/datafusion/pull/15149) (joroKr21) +- Keeping pull request in sync with the base branch [#15894](https://github.com/apache/datafusion/pull/15894) (xudong963) +- Fix `flatten` scalar function when inner list is `FixedSizeList` [#15898](https://github.com/apache/datafusion/pull/15898) (gstvg) +- support OR operator in binary `evaluate_bounds` [#15716](https://github.com/apache/datafusion/pull/15716) (davidhewitt) +- infer placeholder datatype for IN lists [#15864](https://github.com/apache/datafusion/pull/15864) (kczimm) +- Fix allow_update_branch [#15904](https://github.com/apache/datafusion/pull/15904) (xudong963) +- chore(deps): bump tokio from 1.44.1 to 1.44.2 [#15900](https://github.com/apache/datafusion/pull/15900) (dependabot[bot]) +- chore(deps): bump assert_cmd from 2.0.16 to 2.0.17 [#15909](https://github.com/apache/datafusion/pull/15909) (dependabot[bot]) +- Factor out Substrait consumers into separate files [#15794](https://github.com/apache/datafusion/pull/15794) (gabotechs) +- Unparse `UNNEST` projection with the table column alias [#15879](https://github.com/apache/datafusion/pull/15879) (goldmedal) +- Migrate Optimizer tests to insta, part3 [#15893](https://github.com/apache/datafusion/pull/15893) (qstommyshu) +- Minor: cleanup datafusion-spark scalar functions [#15921](https://github.com/apache/datafusion/pull/15921) (alamb) +- Fix ClickBench extended queries after update to APPROX_PERCENTILE_CONT [#15929](https://github.com/apache/datafusion/pull/15929) (alamb) +- Add extended query for checking improvement for blocked groups optimization [#15936](https://github.com/apache/datafusion/pull/15936) (Rachelint) +- Speedup `character_length` [#15931](https://github.com/apache/datafusion/pull/15931) (Dandandan) +- chore(deps): bump tokio-util from 0.7.14 to 0.7.15 [#15918](https://github.com/apache/datafusion/pull/15918) (dependabot[bot]) +- Migrate Optimizer tests to insta, part4 [#15937](https://github.com/apache/datafusion/pull/15937) (qstommyshu) +- fix query results for predicates referencing partition columns and data columns [#15935](https://github.com/apache/datafusion/pull/15935) (adriangb) +- chore(deps): bump substrait from 0.55.0 to 0.55.1 [#15941](https://github.com/apache/datafusion/pull/15941) (dependabot[bot]) +- Fix main CI by adding `rowsort` to slt test [#15942](https://github.com/apache/datafusion/pull/15942) (xudong963) +- Improve sqllogictest error reporting [#15905](https://github.com/apache/datafusion/pull/15905) (gabotechs) +- refactor filter pushdown apis [#15801](https://github.com/apache/datafusion/pull/15801) (adriangb) +- Add additional tests for filter pushdown apis [#15955](https://github.com/apache/datafusion/pull/15955) (adriangb) +- Improve filter pushdown optimizer rule performance [#15959](https://github.com/apache/datafusion/pull/15959) (adriangb) +- Reduce rehashing cost for primitive grouping by also reusing hash value [#15962](https://github.com/apache/datafusion/pull/15962) (Rachelint) +- chore(deps): bump chrono from 0.4.40 to 0.4.41 [#15956](https://github.com/apache/datafusion/pull/15956) (dependabot[bot]) +- refactor: replace `unwrap_or` with `unwrap_or_else` for improved lazy… [#15841](https://github.com/apache/datafusion/pull/15841) (NevroHelios) +- add benchmark code for `Reuse rows in row cursor stream` [#15913](https://github.com/apache/datafusion/pull/15913) (acking-you) +- [Update] : Removal of duplicate CI jobs [#15966](https://github.com/apache/datafusion/pull/15966) (Adez017) +- Segfault in ByteGroupValueBuilder [#15968](https://github.com/apache/datafusion/pull/15968) (thinkharderdev) +- make can_expr_be_pushed_down_with_schemas public again [#15971](https://github.com/apache/datafusion/pull/15971) (adriangb) +- re-export can_expr_be_pushed_down_with_schemas to be public [#15974](https://github.com/apache/datafusion/pull/15974) (adriangb) +- Migrate Optimizer tests to insta, part5 [#15945](https://github.com/apache/datafusion/pull/15945) (qstommyshu) +- Show LogicalType name for `INFORMATION_SCHEMA` [#15965](https://github.com/apache/datafusion/pull/15965) (goldmedal) +- chore(deps): bump sha2 from 0.10.8 to 0.10.9 [#15970](https://github.com/apache/datafusion/pull/15970) (dependabot[bot]) +- chore(deps): bump insta from 1.42.2 to 1.43.1 [#15988](https://github.com/apache/datafusion/pull/15988) (dependabot[bot]) +- [datafusion-spark] Add Spark-compatible hex function [#15947](https://github.com/apache/datafusion/pull/15947) (andygrove) +- refactor: remove deprecated `AvroExec` [#15987](https://github.com/apache/datafusion/pull/15987) (miroim) +- Substrait: Handle inner map fields in schema renaming [#15869](https://github.com/apache/datafusion/pull/15869) (cht42) +- refactor: remove deprecated `CsvExec` [#15991](https://github.com/apache/datafusion/pull/15991) (miroim) +- Migrate Optimizer tests to insta, part6 [#15984](https://github.com/apache/datafusion/pull/15984) (qstommyshu) +- chore(deps): bump nix from 0.29.0 to 0.30.1 [#16002](https://github.com/apache/datafusion/pull/16002) (dependabot[bot]) +- Implement RightSemi join for SortMergeJoin [#15972](https://github.com/apache/datafusion/pull/15972) (irenjj) +- Migrate Optimizer tests to insta, part7 [#16010](https://github.com/apache/datafusion/pull/16010) (qstommyshu) +- chore(deps): bump sysinfo from 0.34.2 to 0.35.1 [#16027](https://github.com/apache/datafusion/pull/16027) (dependabot[bot]) +- refactor: move `should_enable_page_index` from `mod.rs` to `opener.rs` [#16026](https://github.com/apache/datafusion/pull/16026) (miroim) +- chore(deps): bump sqllogictest from 0.28.1 to 0.28.2 [#16037](https://github.com/apache/datafusion/pull/16037) (dependabot[bot]) +- chores: Add lint rule to enforce string formatting style [#16024](https://github.com/apache/datafusion/pull/16024) (Lordworms) +- Use human-readable byte sizes in `EXPLAIN` [#16043](https://github.com/apache/datafusion/pull/16043) (tlm365) +- Docs: Add example of creating a field in `return_field_from_args` [#16039](https://github.com/apache/datafusion/pull/16039) (alamb) +- Support `MIN` and `MAX` for `DataType::List` [#16025](https://github.com/apache/datafusion/pull/16025) (gabotechs) +- Improve docs for Exprs and scalar functions [#16036](https://github.com/apache/datafusion/pull/16036) (alamb) +- Add h2o window benchmark [#16003](https://github.com/apache/datafusion/pull/16003) (2010YOUY01) +- Fix Infer prepare statement type tests [#15743](https://github.com/apache/datafusion/pull/15743) (brayanjuls) +- style: simplify some strings for readability [#15999](https://github.com/apache/datafusion/pull/15999) (hamirmahal) +- support simple/cross lateral joins [#16015](https://github.com/apache/datafusion/pull/16015) (jayzhan211) +- Improve error message on Out of Memory [#16050](https://github.com/apache/datafusion/pull/16050) (ding-young) +- chore(deps): bump the arrow-parquet group with 7 updates [#16047](https://github.com/apache/datafusion/pull/16047) (dependabot[bot]) +- chore(deps): bump petgraph from 0.7.1 to 0.8.1 [#15669](https://github.com/apache/datafusion/pull/15669) (dependabot[bot]) +- [datafusion-spark] Add Spark-compatible `char` expression [#15994](https://github.com/apache/datafusion/pull/15994) (andygrove) +- chore(deps): bump substrait from 0.55.1 to 0.56.0 [#16091](https://github.com/apache/datafusion/pull/16091) (dependabot[bot]) +- Add test that demonstrate behavior for `collect_statistics` [#16098](https://github.com/apache/datafusion/pull/16098) (alamb) +- Refactor substrait producer into multiple files [#16089](https://github.com/apache/datafusion/pull/16089) (gabotechs) +- Fix temp dir leak in tests [#16094](https://github.com/apache/datafusion/pull/16094) (findepi) +- Label Spark functions PRs with spark label [#16095](https://github.com/apache/datafusion/pull/16095) (findepi) +- Added SLT tests for IMDB benchmark queries [#16067](https://github.com/apache/datafusion/pull/16067) (kumarlokesh) +- chore(CI) Upgrade toolchain to Rust-1.87 [#16068](https://github.com/apache/datafusion/pull/16068) (kadai0308) +- minor: Add benchmark query and corresponding documentation for Average Duration [#16105](https://github.com/apache/datafusion/pull/16105) (logan-keede) +- Use qualified names on DELETE selections [#16033](https://github.com/apache/datafusion/pull/16033) (nuno-faria) +- chore(deps): bump testcontainers from 0.23.3 to 0.24.0 [#15989](https://github.com/apache/datafusion/pull/15989) (dependabot[bot]) +- Clean up ExternalSorter and use upstream kernel [#16109](https://github.com/apache/datafusion/pull/16109) (alamb) +- Test Duration in aggregation `fuzz` tests [#16111](https://github.com/apache/datafusion/pull/16111) (alamb) +- Move PruningStatistics into datafusion::common [#16069](https://github.com/apache/datafusion/pull/16069) (adriangb) +- Revert use file schema in parquet pruning [#16086](https://github.com/apache/datafusion/pull/16086) (adriangb) +- Minor: Add `ScalarFunctionArgs::return_type` method [#16113](https://github.com/apache/datafusion/pull/16113) (alamb) +- Fix `contains` function expression [#16046](https://github.com/apache/datafusion/pull/16046) (liamzwbao) +- chore: Use materialized data for filter pushdown tests [#16123](https://github.com/apache/datafusion/pull/16123) (comphead) +- chore: Upgrade rand crate and some other minor crates [#16062](https://github.com/apache/datafusion/pull/16062) (comphead) +- Include data types in logical plans of inferred prepare statements [#16019](https://github.com/apache/datafusion/pull/16019) (brayanjuls) +- CI: Fix extended test failure [#16144](https://github.com/apache/datafusion/pull/16144) (2010YOUY01) +- Fix: handle column name collisions when combining UNION logical inputs & nested Column expressions in maybe_fix_physical_column_name [#16064](https://github.com/apache/datafusion/pull/16064) (LiaCastaneda) +- adding support for Min/Max over LargeList and FixedSizeList [#16071](https://github.com/apache/datafusion/pull/16071) (logan-keede) +- Move prepare/parameter handling tests into `params.rs` [#16141](https://github.com/apache/datafusion/pull/16141) (liamzwbao) +- Minor: Add `Accumulator::return_type` and `StateFieldsArgs::return_type` to help with upgrade to 48 [#16112](https://github.com/apache/datafusion/pull/16112) (alamb) +- Support filtering specific sqllogictests identified by line number [#16029](https://github.com/apache/datafusion/pull/16029) (gabotechs) +- Enrich GroupedHashAggregateStream name to ease debugging Resources exhausted errors [#16152](https://github.com/apache/datafusion/pull/16152) (ahmed-mez) +- chore(deps): bump uuid from 1.16.0 to 1.17.0 [#16162](https://github.com/apache/datafusion/pull/16162) (dependabot[bot]) +- Clarify docs and names in parquet predicate pushdown tests [#16155](https://github.com/apache/datafusion/pull/16155) (alamb) +- Minor: Fix name() for FilterPushdown physical optimizer rule [#16175](https://github.com/apache/datafusion/pull/16175) (adriangb) +- migrate tests in `pool.rs` to use insta [#16145](https://github.com/apache/datafusion/pull/16145) (lifan-ake) +- refactor(optimizer): Add support for dynamically adding test tables [#16138](https://github.com/apache/datafusion/pull/16138) (atahanyorganci) +- [Minor] Speedup TPC-H benchmark run with memtable option [#16159](https://github.com/apache/datafusion/pull/16159) (Dandandan) +- Fast path for joins with distinct values in build side [#16153](https://github.com/apache/datafusion/pull/16153) (Dandandan) +- chore: Reduce repetition in the parameter type inference tests [#16079](https://github.com/apache/datafusion/pull/16079) (jsai28) +- chore(deps): bump tokio from 1.45.0 to 1.45.1 [#16190](https://github.com/apache/datafusion/pull/16190) (dependabot[bot]) +- Improve `unproject_sort_expr` to handle arbitrary expressions [#16127](https://github.com/apache/datafusion/pull/16127) (phillipleblanc) +- chore(deps): bump rustyline from 15.0.0 to 16.0.0 [#16194](https://github.com/apache/datafusion/pull/16194) (dependabot[bot]) +- migrate `logical_plan` tests to insta [#16184](https://github.com/apache/datafusion/pull/16184) (lifan-ake) +- chore(deps): bump clap from 4.5.38 to 4.5.39 [#16204](https://github.com/apache/datafusion/pull/16204) (dependabot[bot]) +- implement `AggregateExec.partition_statistics` [#15954](https://github.com/apache/datafusion/pull/15954) (UBarney) +- Propagate .execute() calls immediately in `RepartitionExec` [#16093](https://github.com/apache/datafusion/pull/16093) (gabotechs) +- Set aggregation hash seed [#16165](https://github.com/apache/datafusion/pull/16165) (ctsk) +- Fix ScalarStructBuilder::build() for an empty struct [#16205](https://github.com/apache/datafusion/pull/16205) (Blizzara) +- Return an error on overflow in `do_append_val_inner` [#16201](https://github.com/apache/datafusion/pull/16201) (liamzwbao) +- chore(deps): bump testcontainers-modules from 0.12.0 to 0.12.1 [#16212](https://github.com/apache/datafusion/pull/16212) (dependabot[bot]) +- Substrait: handle identical grouping expressions [#16189](https://github.com/apache/datafusion/pull/16189) (cht42) +- Add new stats pruning helpers to allow combining partition values in file level stats [#16139](https://github.com/apache/datafusion/pull/16139) (adriangb) +- Implement schema adapter support for FileSource and add integration tests [#16148](https://github.com/apache/datafusion/pull/16148) (kosiew) +- Minor: update documentation for PrunableStatistics [#16213](https://github.com/apache/datafusion/pull/16213) (alamb) +- Remove use of deprecated dict_ordered in datafusion-proto (#16218) [#16220](https://github.com/apache/datafusion/pull/16220) (cj-zhukov) +- Minor: Print cargo command in bench script [#16236](https://github.com/apache/datafusion/pull/16236) (2010YOUY01) +- Simplify FileSource / SchemaAdapterFactory API [#16214](https://github.com/apache/datafusion/pull/16214) (alamb) +- Add dicts to aggregation fuzz testing [#16232](https://github.com/apache/datafusion/pull/16232) (blaginin) +- chore(deps): bump sysinfo from 0.35.1 to 0.35.2 [#16247](https://github.com/apache/datafusion/pull/16247) (dependabot[bot]) +- Improve performance of constant aggregate window expression [#16234](https://github.com/apache/datafusion/pull/16234) (suibianwanwank) +- Support compound identifier when parsing tuples [#16225](https://github.com/apache/datafusion/pull/16225) (hozan23) +- Schema adapter helper [#16108](https://github.com/apache/datafusion/pull/16108) (kosiew) +- Update tpch, clickbench, sort_tpch to mark failed queries [#16182](https://github.com/apache/datafusion/pull/16182) (ding-young) +- Adjust slttest to pass without RUST_BACKTRACE enabled [#16251](https://github.com/apache/datafusion/pull/16251) (alamb) +- Handle dicts for distinct count [#15871](https://github.com/apache/datafusion/pull/15871) (blaginin) +- Add `--substrait-round-trip` option in sqllogictests [#16183](https://github.com/apache/datafusion/pull/16183) (gabotechs) +- Minor: fix upgrade papercut `pub use PruningStatistics` [#16264](https://github.com/apache/datafusion/pull/16264) (alamb) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 30 dependabot[bot] + 29 Andrew Lamb + 16 xudong.w + 14 Adrian Garcia Badaracco + 10 Chen Chongchen + 8 Gabriel + 8 Oleks V + 7 miro + 6 Tommy shu + 6 kamille + 5 Lokesh + 5 Tim Saucer + 4 Dmitrii Blaginin + 4 Jay Zhan + 4 Nuno Faria + 4 Yongting You + 4 logan-keede + 3 Christian + 3 Daniël Heres + 3 Liam Bao + 3 Phillip LeBlanc + 3 Piotr Findeisen + 3 ding-young + 2 Andy Grove + 2 Atahan Yorgancı + 2 Brayan Jules + 2 Georgi Krastev + 2 Jax Liu + 2 Jérémie Drouet + 2 LB7666 + 2 Leonardo Yvens + 2 Qi Zhu + 2 Sergey Zhukov + 2 Shruti Sharma + 2 Tai Le Manh + 2 aditya singh rathore + 2 ake + 2 cht42 + 2 gstvg + 2 kosiew + 2 niebayes + 2 张林伟 + 1 Ahmed Mezghani + 1 Alexander Droste + 1 Andy Yen + 1 Arka Dash + 1 Arttu + 1 Dan Harris + 1 David Hewitt + 1 Davy + 1 Ed Seidl + 1 Eshed Schacham + 1 Evgenii Khramkov + 1 Florent Monjalet + 1 Galim Bikbov + 1 Garam Choi + 1 Hamir Mahal + 1 Hendrik Makait + 1 Jonathan Chen + 1 Joseph Fahnestock + 1 Kevin Zimmerman + 1 Lordworms + 1 Lía Adriana + 1 Matt Butrovich + 1 Namgung Chan + 1 Nelson Antunes + 1 Patrick Sullivan + 1 Raz Luvaton + 1 Ruihang Xia + 1 Ryan Roelke + 1 Sam Hughes + 1 Shehab Amin + 1 Sile Zhou + 1 Simon Vandel Sillesen + 1 Tom Montgomery + 1 UBarney + 1 Victorien + 1 Xiangpeng Hao + 1 Zaki + 1 chen quan + 1 delamarch3 + 1 discord9 + 1 hozan23 + 1 irenjj + 1 jsai28 + 1 m09526 + 1 suibianwanwan + 1 the0ninjas + 1 wiedld +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/49.0.0.md b/dev/changelog/49.0.0.md new file mode 100644 index 0000000000000..239c7c9dfc973 --- /dev/null +++ b/dev/changelog/49.0.0.md @@ -0,0 +1,387 @@ + + +# Apache DataFusion 49.0.0 Changelog + +This release consists of 253 commits from 71 contributors. See credits at the end of this changelog for more information. + +See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) for information on how to upgrade from previous versions. + +**Breaking changes:** + +- feat: add metadata to literal expressions [#16170](https://github.com/apache/datafusion/pull/16170) (timsaucer) +- [MAJOR] Equivalence System Overhaul [#16217](https://github.com/apache/datafusion/pull/16217) (ozankabak) +- remove unused methods in SortExec [#16457](https://github.com/apache/datafusion/pull/16457) (adriangb) +- Move Pruning Logic to a Dedicated datafusion-pruning Crate for Improved Modularity [#16549](https://github.com/apache/datafusion/pull/16549) (kosiew) +- Fix type of ExecutionOptions::time_zone [#16569](https://github.com/apache/datafusion/pull/16569) (findepi) +- Convert Option> to Vec [#16615](https://github.com/apache/datafusion/pull/16615) (ViggoC) +- Refactor error handling to use boxed errors for DataFusionError variants [#16672](https://github.com/apache/datafusion/pull/16672) (kosiew) +- Reuse Rows allocation in RowCursorStream [#16647](https://github.com/apache/datafusion/pull/16647) (Dandandan) +- refactor: shrink `SchemaError` [#16653](https://github.com/apache/datafusion/pull/16653) (crepererum) +- Remove unused AggregateUDF struct [#16683](https://github.com/apache/datafusion/pull/16683) (ViggoC) +- Bump the MSRV to `1.85.1` due to transitive dependencies (`aws-sdk`) [#16728](https://github.com/apache/datafusion/pull/16728) (rtyler) + +**Performance related:** + +- Add late pruning of Parquet files based on file level statistics [#16014](https://github.com/apache/datafusion/pull/16014) (adriangb) +- Add fast paths for try_process_unnest [#16389](https://github.com/apache/datafusion/pull/16389) (simonvandel) +- Set the default value of `datafusion.execution.collect_statistics` to `true` [#16447](https://github.com/apache/datafusion/pull/16447) (AdamGS) +- Perf: Optimize CursorValues compare performance for StringViewArray (1.4X faster for sort-tpch Q11) [#16509](https://github.com/apache/datafusion/pull/16509) (zhuqi-lucas) +- Simplify predicates in `PushDownFilter` optimizer rule [#16362](https://github.com/apache/datafusion/pull/16362) (xudong963) +- optimize `ScalarValue::to_array_of_size` for structural types [#16706](https://github.com/apache/datafusion/pull/16706) (ding-young) +- Refactor filter pushdown APIs to enable joins to pass through filters [#16732](https://github.com/apache/datafusion/pull/16732) (adriangb) +- perf: Optimize hash joins with an empty build side [#16716](https://github.com/apache/datafusion/pull/16716) (nuno-faria) +- Per file filter evaluation [#15057](https://github.com/apache/datafusion/pull/15057) (adriangb) + +**Implemented enhancements:** + +- feat: Support defining custom MetricValues in PhysicalPlans [#16195](https://github.com/apache/datafusion/pull/16195) (sfluor) +- feat: Allow cancelling of grouping operations which are CPU bound [#16196](https://github.com/apache/datafusion/pull/16196) (zhuqi-lucas) +- feat: support FixedSizeList for array_has [#16333](https://github.com/apache/datafusion/pull/16333) (chenkovsky) +- feat: Support tpch and tpch10 benchmark for csv format [#16373](https://github.com/apache/datafusion/pull/16373) (zhuqi-lucas) +- feat: Support RightMark join for NestedLoop and Hash join [#16083](https://github.com/apache/datafusion/pull/16083) (jonathanc-n) +- feat: mapping sql Char/Text/String default to Utf8View [#16290](https://github.com/apache/datafusion/pull/16290) (zhuqi-lucas) +- feat: support fixed size list for array reverse [#16423](https://github.com/apache/datafusion/pull/16423) (chenkovsky) +- feat: add SchemaProvider::table_type(table_name: &str) [#16401](https://github.com/apache/datafusion/pull/16401) (epgif) +- feat: derive `Debug` and `Clone` for `ScalarFunctionArgs` [#16471](https://github.com/apache/datafusion/pull/16471) (crepererum) +- feat: support `map_entries` builtin function [#16557](https://github.com/apache/datafusion/pull/16557) (comphead) +- feat: add `array_min` scalar function and associated tests [#16574](https://github.com/apache/datafusion/pull/16574) (dharanad) +- feat: Finalize support for `RightMark` join + `Mark` join swap [#16488](https://github.com/apache/datafusion/pull/16488) (jonathanc-n) +- feat: Parquet modular encryption [#16351](https://github.com/apache/datafusion/pull/16351) (corwinjoy) +- feat: Support `u32` indices for `HashJoinExec` [#16434](https://github.com/apache/datafusion/pull/16434) (jonathanc-n) +- feat: expose intersect distinct/except distinct in dataframe api [#16578](https://github.com/apache/datafusion/pull/16578) (chenkovsky) +- feat: Add a configuration to make parquet encryption optional [#16649](https://github.com/apache/datafusion/pull/16649) (corwinjoy) + +**Fixed bugs:** + +- fix: preserve null_equals_null flag in eliminate_cross_join rule [#16356](https://github.com/apache/datafusion/pull/16356) (waynexia) +- fix: Fix SparkSha2 to be compliant with Spark response and add support for Int32 [#16350](https://github.com/apache/datafusion/pull/16350) (rishvin) +- fix: Fixed error handling for `generate_series/range` [#16391](https://github.com/apache/datafusion/pull/16391) (jonathanc-n) +- fix: Enable WASM compilation by making sqlparser's recursive-protection optional [#16418](https://github.com/apache/datafusion/pull/16418) (jonmmease) +- fix: create file for empty stream [#16342](https://github.com/apache/datafusion/pull/16342) (chenkovsky) +- fix: document and fix macro hygiene for `config_field!` [#16473](https://github.com/apache/datafusion/pull/16473) (crepererum) +- fix: make `with_new_state` a trait method for `ExecutionPlan` [#16469](https://github.com/apache/datafusion/pull/16469) (geoffreyclaude) +- fix: column indices in FFI partition evaluator [#16480](https://github.com/apache/datafusion/pull/16480) (timsaucer) +- fix: support within_group [#16538](https://github.com/apache/datafusion/pull/16538) (chenkovsky) +- fix: disallow specify both order_by and within_group [#16606](https://github.com/apache/datafusion/pull/16606) (watchingthewheelsgo) +- fix: format within_group error message [#16613](https://github.com/apache/datafusion/pull/16613) (watchingthewheelsgo) +- fix: reserved keywords in qualified column names [#16584](https://github.com/apache/datafusion/pull/16584) (crepererum) +- fix: support scalar function nested in get_field in Unparser [#16610](https://github.com/apache/datafusion/pull/16610) (chenkovsky) +- fix: sqllogictest runner label condition mismatch [#16633](https://github.com/apache/datafusion/pull/16633) (lliangyu-lin) +- fix: port arrow inline fast key fix to datafusion [#16698](https://github.com/apache/datafusion/pull/16698) (zhuqi-lucas) +- fix: try to lower plain reserved functions to columns as well [#16669](https://github.com/apache/datafusion/pull/16669) (crepererum) +- fix: Fix CI failing due to #16686 [#16718](https://github.com/apache/datafusion/pull/16718) (jonathanc-n) +- fix: return NULL if any of the param to make_date is NULL [#16759](https://github.com/apache/datafusion/pull/16759) (feniljain) +- fix: add `order_requirement` & `dist_requirement` to `OutputRequirementExec` display [#16726](https://github.com/apache/datafusion/pull/16726) (Loaki07) +- fix: support nullable columns in pre-sorted data sources [#16783](https://github.com/apache/datafusion/pull/16783) (crepererum) +- fix: The inconsistency between scalar and array on the cast decimal to timestamp [#16539](https://github.com/apache/datafusion/pull/16539) (chenkovsky) +- fix: unit test for object_storage [#16824](https://github.com/apache/datafusion/pull/16824) (chenkovsky) +- fix(docs): Update broken links to `TableProvider` docs [#16830](https://github.com/apache/datafusion/pull/16830) (jcsherin) + +**Documentation updates:** + +- Minor: Add upgrade guide for `Expr::WindowFunction` [#16313](https://github.com/apache/datafusion/pull/16313) (alamb) +- Fix `array_position` on empty list [#16292](https://github.com/apache/datafusion/pull/16292) (Blizzara) +- Fix: mark "Spilling (to disk) Joins" as supported in features [#16343](https://github.com/apache/datafusion/pull/16343) (kosiew) +- Fix cp_solver doc formatting [#16352](https://github.com/apache/datafusion/pull/16352) (xudong963) +- docs: Expand `MemoryPool` docs with related structs [#16289](https://github.com/apache/datafusion/pull/16289) (2010YOUY01) +- Support datafusion-cli access to public S3 buckets that do not require authentication [#16300](https://github.com/apache/datafusion/pull/16300) (alamb) +- Document Table Constraint Enforcement Behavior in Custom Table Providers Guide [#16340](https://github.com/apache/datafusion/pull/16340) (kosiew) +- doc: Add SQL examples for SEMI + ANTI Joins [#16316](https://github.com/apache/datafusion/pull/16316) (jonathanc-n) +- [datafusion-spark] Example of using Spark compatible function library [#16384](https://github.com/apache/datafusion/pull/16384) (alamb) +- Add note in upgrade guide about changes to `Expr::Scalar` in 48.0.0 [#16360](https://github.com/apache/datafusion/pull/16360) (alamb) +- Update PMC management instructions to follow new ASF process [#16417](https://github.com/apache/datafusion/pull/16417) (alamb) +- Add design process section to the docs [#16397](https://github.com/apache/datafusion/pull/16397) (alamb) +- Unify Metadata Handing: use `FieldMetadata` in `Expr::Alias` and `ExprSchemable` [#16320](https://github.com/apache/datafusion/pull/16320) (alamb) +- TopK dynamic filter pushdown attempt 2 [#15770](https://github.com/apache/datafusion/pull/15770) (adriangb) +- Update Roadmap documentation [#16399](https://github.com/apache/datafusion/pull/16399) (alamb) +- doc: Add comments to clarify algorithm for `MarkJoin`s [#16436](https://github.com/apache/datafusion/pull/16436) (jonathanc-n) +- Add compression option to SpillManager [#16268](https://github.com/apache/datafusion/pull/16268) (ding-young) +- Redirect user defined function webpage [#16475](https://github.com/apache/datafusion/pull/16475) (alamb) +- Use Tokio's task budget consistently, better APIs to support task cancellation [#16398](https://github.com/apache/datafusion/pull/16398) (pepijnve) +- doc: upgrade guide for new compression option for spill files [#16472](https://github.com/apache/datafusion/pull/16472) (2010YOUY01) +- Introduce Async User Defined Functions [#14837](https://github.com/apache/datafusion/pull/14837) (goldmedal) +- Minor: Add more links to cooperative / scheduling docs [#16484](https://github.com/apache/datafusion/pull/16484) (alamb) +- doc: Document DESCRIBE comman in ddl.md [#16524](https://github.com/apache/datafusion/pull/16524) (krikera) +- Add more doc for physical filter pushdown [#16504](https://github.com/apache/datafusion/pull/16504) (xudong963) +- chore: fix CI failures on `ddl.md` [#16526](https://github.com/apache/datafusion/pull/16526) (comphead) +- Add some comments about adding new dependencies in datafusion-sql [#16543](https://github.com/apache/datafusion/pull/16543) (alamb) +- Add note for planning release in Upgrade Guides [#16534](https://github.com/apache/datafusion/pull/16534) (xudong963) +- Consolidate configuration sections in docs [#16544](https://github.com/apache/datafusion/pull/16544) (alamb) +- Minor: add clearer link to the main website from intro paragraph. [#16556](https://github.com/apache/datafusion/pull/16556) (alamb) +- Simplify AsyncScalarUdfImpl so it extends ScalarUdfImpl [#16523](https://github.com/apache/datafusion/pull/16523) (alamb) +- docs: Minor grammatical fixes for the scalar UDF docs [#16618](https://github.com/apache/datafusion/pull/16618) (ianthetechie) +- Implementation for regex_instr [#15928](https://github.com/apache/datafusion/pull/15928) (nirnayroy) +- Update Upgrade Guide for 48.0.1 [#16699](https://github.com/apache/datafusion/pull/16699) (alamb) +- ensure MemTable has at least one partition [#16754](https://github.com/apache/datafusion/pull/16754) (waynexia) +- Restore custom SchemaAdapter functionality for Parquet [#16791](https://github.com/apache/datafusion/pull/16791) (adriangb) +- Update `upgrading.md` for new unified config for sql string mapping to utf8view [#16809](https://github.com/apache/datafusion/pull/16809) (zhuqi-lucas) +- docs: Remove reference to forthcoming example (#16817) [#16818](https://github.com/apache/datafusion/pull/16818) (m09526) +- docs: Fix broken links [#16839](https://github.com/apache/datafusion/pull/16839) (2010YOUY01) +- Add note to upgrade guide about MSRV update [#16845](https://github.com/apache/datafusion/pull/16845) (alamb) + +**Other:** + +- chore(deps): bump sqllogictest from 0.28.2 to 0.28.3 [#16286](https://github.com/apache/datafusion/pull/16286) (dependabot[bot]) +- chore(deps-dev): bump webpack-dev-server from 4.15.1 to 5.2.1 in /datafusion/wasmtest/datafusion-wasm-app [#16253](https://github.com/apache/datafusion/pull/16253) (dependabot[bot]) +- Improve DataFusion subcrate readme files [#16263](https://github.com/apache/datafusion/pull/16263) (alamb) +- Fix intermittent SQL logic test failure in limit.slt by adding ORDER BY clause [#16257](https://github.com/apache/datafusion/pull/16257) (kosiew) +- Extend benchmark comparison script with more detailed statistics [#16262](https://github.com/apache/datafusion/pull/16262) (pepijnve) +- chore(deps): bump flate2 from 1.1.1 to 1.1.2 [#16338](https://github.com/apache/datafusion/pull/16338) (dependabot[bot]) +- chore(deps): bump petgraph from 0.8.1 to 0.8.2 [#16337](https://github.com/apache/datafusion/pull/16337) (dependabot[bot]) +- chore(deps): bump substrait from 0.56.0 to 0.57.0 [#16143](https://github.com/apache/datafusion/pull/16143) (dependabot[bot]) +- Add test for ordering of predicate pushdown into parquet [#16169](https://github.com/apache/datafusion/pull/16169) (adriangb) +- Fix distinct count for DictionaryArray to correctly account for nulls in values array [#16258](https://github.com/apache/datafusion/pull/16258) (kosiew) +- Fix inconsistent schema projection in ListingTable even when schema is specified [#16305](https://github.com/apache/datafusion/pull/16305) (kosiew) +- tpch: move reading of SQL queries out of timed span. [#16357](https://github.com/apache/datafusion/pull/16357) (pepijnve) +- chore(deps): bump clap from 4.5.39 to 4.5.40 [#16354](https://github.com/apache/datafusion/pull/16354) (dependabot[bot]) +- chore(deps): bump syn from 2.0.101 to 2.0.102 [#16355](https://github.com/apache/datafusion/pull/16355) (dependabot[bot]) +- Encapsulate metadata for literals on to a `FieldMetadata` structure [#16317](https://github.com/apache/datafusion/pull/16317) (alamb) +- Add support `UInt64` and other integer data types for `to_hex` [#16335](https://github.com/apache/datafusion/pull/16335) (tlm365) +- Document `copy_array_data` function with example [#16361](https://github.com/apache/datafusion/pull/16361) (alamb) +- Fix array_agg memory over use [#16346](https://github.com/apache/datafusion/pull/16346) (gabotechs) +- Update publish command [#16377](https://github.com/apache/datafusion/pull/16377) (xudong963) +- Add more context to error message for datafusion-cli config failure [#16379](https://github.com/apache/datafusion/pull/16379) (alamb) +- Fix: datafusion-sqllogictest 48.0.0 can't be published [#16376](https://github.com/apache/datafusion/pull/16376) (xudong963) +- bug: remove busy-wait while sort is ongoing [#16322](https://github.com/apache/datafusion/pull/16322) (pepijnve) +- chore: refactor Substrait consumer's "rename_field" and implement the rest of types [#16345](https://github.com/apache/datafusion/pull/16345) (Blizzara) +- chore(deps): bump object_store from 0.12.1 to 0.12.2 [#16368](https://github.com/apache/datafusion/pull/16368) (dependabot[bot]) +- Disable `datafusion-cli` tests for hash_collision tests, fix extended CI [#16382](https://github.com/apache/datafusion/pull/16382) (alamb) +- Fix array_concat with NULL arrays [#16348](https://github.com/apache/datafusion/pull/16348) (alexanderbianchi) +- Minor: add testing case for add YieldStreamExec and polish docs [#16369](https://github.com/apache/datafusion/pull/16369) (zhuqi-lucas) +- chore(deps): bump aws-config from 1.6.3 to 1.8.0 [#16394](https://github.com/apache/datafusion/pull/16394) (dependabot[bot]) +- fix typo in test file name [#16403](https://github.com/apache/datafusion/pull/16403) (adriangb) +- Add topk_tpch benchmark [#16410](https://github.com/apache/datafusion/pull/16410) (Dandandan) +- Reduce some cloning [#16404](https://github.com/apache/datafusion/pull/16404) (simonvandel) +- chore(deps): bump syn from 2.0.102 to 2.0.103 [#16393](https://github.com/apache/datafusion/pull/16393) (dependabot[bot]) +- Simplify expressions passed to table functions [#16388](https://github.com/apache/datafusion/pull/16388) (simonvandel) +- Minor: Clean-up `bench.sh` usage message [#16416](https://github.com/apache/datafusion/pull/16416) (2010YOUY01) +- chore(deps): bump rust_decimal from 1.37.1 to 1.37.2 [#16422](https://github.com/apache/datafusion/pull/16422) (dependabot[bot]) +- Migrate core test to insta, part1 [#16324](https://github.com/apache/datafusion/pull/16324) (Chen-Yuan-Lai) +- chore(deps): bump mimalloc from 0.1.46 to 0.1.47 [#16426](https://github.com/apache/datafusion/pull/16426) (dependabot[bot]) +- chore(deps): bump libc from 0.2.172 to 0.2.173 [#16421](https://github.com/apache/datafusion/pull/16421) (dependabot[bot]) +- Use dedicated NullEquality enum instead of null_equals_null boolean [#16419](https://github.com/apache/datafusion/pull/16419) (tobixdev) +- chore: generate basic spark function tests [#16409](https://github.com/apache/datafusion/pull/16409) (shehabgamin) +- Fix CI Failure: replace false with NullEqualsNothing [#16437](https://github.com/apache/datafusion/pull/16437) (ding-young) +- chore(deps): bump bzip2 from 0.5.2 to 0.6.0 [#16441](https://github.com/apache/datafusion/pull/16441) (dependabot[bot]) +- chore(deps): bump libc from 0.2.173 to 0.2.174 [#16440](https://github.com/apache/datafusion/pull/16440) (dependabot[bot]) +- Remove redundant license-header-check CI job [#16451](https://github.com/apache/datafusion/pull/16451) (alamb) +- Remove unused feature in `physical-plan` and fix compilation error in benchmark [#16449](https://github.com/apache/datafusion/pull/16449) (AdamGS) +- Temporarily fix bug in dynamic top-k optimization [#16465](https://github.com/apache/datafusion/pull/16465) (AdamGS) +- Ignore `sort_query_fuzzer_runner` [#16462](https://github.com/apache/datafusion/pull/16462) (blaginin) +- Revert "Ignore `sort_query_fuzzer_runner` (#16462)" [#16470](https://github.com/apache/datafusion/pull/16470) (2010YOUY01) +- Reapply "Ignore `sort_query_fuzzer_runner` (#16462)" (#16470) [#16485](https://github.com/apache/datafusion/pull/16485) (alamb) +- Fix constant window for evaluate stateful [#16430](https://github.com/apache/datafusion/pull/16430) (suibianwanwank) +- Use UDTF name in logical plan table scan [#16468](https://github.com/apache/datafusion/pull/16468) (Jeadie) +- refactor reassign_predicate_columns to accept an &Schema instead of &Arc [#16499](https://github.com/apache/datafusion/pull/16499) (adriangb) +- re-enable `sort_query_fuzzer_runner` [#16491](https://github.com/apache/datafusion/pull/16491) (adriangb) +- Example for using a separate threadpool for CPU bound work (try 3) [#16331](https://github.com/apache/datafusion/pull/16331) (alamb) +- chore(deps): bump syn from 2.0.103 to 2.0.104 [#16507](https://github.com/apache/datafusion/pull/16507) (dependabot[bot]) +- use 'lit' as the field name for literal values [#16498](https://github.com/apache/datafusion/pull/16498) (adriangb) +- [datafusion-spark] Implement `factorical` function [#16125](https://github.com/apache/datafusion/pull/16125) (tlm365) +- Add DESC alias for DESCRIBE command. [#16514](https://github.com/apache/datafusion/pull/16514) (lucqui) +- Split clickbench query set into one file per query [#16476](https://github.com/apache/datafusion/pull/16476) (pepijnve) +- Support query filter on all benchmarks [#16477](https://github.com/apache/datafusion/pull/16477) (pepijnve) +- `TableProvider` to skip files in the folder which non relevant to selected reader [#16487](https://github.com/apache/datafusion/pull/16487) (comphead) +- Reuse `BaselineMetrics` in `UnnestMetrics` [#16497](https://github.com/apache/datafusion/pull/16497) (hendrikmakait) +- Fix array_has to return false for empty arrays instead of null [#16529](https://github.com/apache/datafusion/pull/16529) (kosiew) +- Minor: Add documentation to `AggregateWindowExpr::get_result_column` [#16479](https://github.com/apache/datafusion/pull/16479) (alamb) +- Fix WindowFrame::new with order_by [#16537](https://github.com/apache/datafusion/pull/16537) (findepi) +- chore(deps): bump object_store from 0.12.1 to 0.12.2 [#16548](https://github.com/apache/datafusion/pull/16548) (dependabot[bot]) +- chore(deps): bump mimalloc from 0.1.46 to 0.1.47 [#16547](https://github.com/apache/datafusion/pull/16547) (dependabot[bot]) +- Add support for Arrow Duration type in Substrait [#16503](https://github.com/apache/datafusion/pull/16503) (jkosh44) +- Allow unparser to override the alias name for the specific dialect [#16540](https://github.com/apache/datafusion/pull/16540) (goldmedal) +- Avoid clones when calling find_window_exprs [#16551](https://github.com/apache/datafusion/pull/16551) (findepi) +- Update `spilled_bytes` metric to reflect actual disk usage [#16535](https://github.com/apache/datafusion/pull/16535) (ding-young) +- adapt filter expressions to file schema during parquet scan [#16461](https://github.com/apache/datafusion/pull/16461) (adriangb) +- datafusion-cli: Use correct S3 region if it is not specified [#16502](https://github.com/apache/datafusion/pull/16502) (liamzwbao) +- Add nested struct casting support and integrate into SchemaAdapter [#16371](https://github.com/apache/datafusion/pull/16371) (kosiew) +- Improve err message grammar [#16566](https://github.com/apache/datafusion/pull/16566) (findepi) +- refactor: move PruningPredicate into its own module [#16587](https://github.com/apache/datafusion/pull/16587) (adriangb) +- chore(deps): bump indexmap from 2.9.0 to 2.10.0 [#16582](https://github.com/apache/datafusion/pull/16582) (dependabot[bot]) +- Skip re-pruning based on partition values and file level stats if there are no dynamic filters [#16424](https://github.com/apache/datafusion/pull/16424) (adriangb) +- Support timestamp and date arguments for `range` and `generate_series` table functions [#16552](https://github.com/apache/datafusion/pull/16552) (simonvandel) +- Fix normalization of columns in JOIN ... USING. [#16560](https://github.com/apache/datafusion/pull/16560) (brunal) +- Revert Finalize support for `RightMark` join + `Mark` join [#16597](https://github.com/apache/datafusion/pull/16597) (comphead) +- move min_batch/max_batch to functions-aggregate-common [#16593](https://github.com/apache/datafusion/pull/16593) (adriangb) +- Allow usage of table functions in relations [#16571](https://github.com/apache/datafusion/pull/16571) (osipovartem) +- Update to arrow/parquet 55.2.0 [#16575](https://github.com/apache/datafusion/pull/16575) (alamb) +- Improve field naming in first_value, last_value implementation [#16631](https://github.com/apache/datafusion/pull/16631) (findepi) +- Fix spurious failure in convert_batches test helper [#16627](https://github.com/apache/datafusion/pull/16627) (findepi) +- Aggregate UDF cleanup [#16628](https://github.com/apache/datafusion/pull/16628) (findepi) +- Avoid treating incomparable scalars as equal [#16624](https://github.com/apache/datafusion/pull/16624) (findepi) +- restore topk pre-filtering of batches and make sort query fuzzer less sensitive to expected non determinism [#16501](https://github.com/apache/datafusion/pull/16501) (alamb) +- Add support for Arrow Time types in Substrait [#16558](https://github.com/apache/datafusion/pull/16558) (jkosh44) +- chore(deps): bump substrait from 0.57.0 to 0.58.0 [#16640](https://github.com/apache/datafusion/pull/16640) (dependabot[bot]) +- Support explain tree format debug for benchmark debug [#16604](https://github.com/apache/datafusion/pull/16604) (zhuqi-lucas) +- Add microbenchmark for spilling with compression [#16512](https://github.com/apache/datafusion/pull/16512) (ding-young) +- Fix parquet filter_pushdown: respect parquet filter pushdown config in scan [#16646](https://github.com/apache/datafusion/pull/16646) (adriangb) +- chore(deps): bump aws-config from 1.8.0 to 1.8.1 [#16651](https://github.com/apache/datafusion/pull/16651) (dependabot[bot]) +- Migrate core test to insta, part 2 [#16617](https://github.com/apache/datafusion/pull/16617) (Chen-Yuan-Lai) +- Update all spark SLT files [#16637](https://github.com/apache/datafusion/pull/16637) (findepi) +- Add PhysicalExpr optimizer and cast unwrapping [#16530](https://github.com/apache/datafusion/pull/16530) (adriangb) +- benchmark: Support sort_tpch10 for benchmark [#16671](https://github.com/apache/datafusion/pull/16671) (zhuqi-lucas) +- chore(deps): bump tokio from 1.45.1 to 1.46.0 [#16666](https://github.com/apache/datafusion/pull/16666) (dependabot[bot]) +- Fix TopK Sort incorrectly pushed down past Join with anti join [#16641](https://github.com/apache/datafusion/pull/16641) (zhuqi-lucas) +- Improve error message when ScalarValue fails to cast array [#16670](https://github.com/apache/datafusion/pull/16670) (findepi) +- Add an example of embedding indexes inside a parquet file [#16395](https://github.com/apache/datafusion/pull/16395) (zhuqi-lucas) +- `datafusion-cli`: Refactor statement execution logic [#16634](https://github.com/apache/datafusion/pull/16634) (liamzwbao) +- Add SchemaAdapterFactory Support for ListingTable with Schema Evolution and Mapping [#16583](https://github.com/apache/datafusion/pull/16583) (kosiew) +- Perf: fast CursorValues compare for StringViewArray using inline*key*… [#16630](https://github.com/apache/datafusion/pull/16630) (zhuqi-lucas) +- Update to Rust 1.88 [#16663](https://github.com/apache/datafusion/pull/16663) (melroy12) +- Refactor StreamJoinMetrics to reuse BaselineMetrics [#16674](https://github.com/apache/datafusion/pull/16674) (Standing-Man) +- chore: refactor `BuildProbeJoinMetrics` to use `BaselineMetrics` [#16500](https://github.com/apache/datafusion/pull/16500) (Samyak2) +- Use compression type in CSV file suffices [#16609](https://github.com/apache/datafusion/pull/16609) (theirix) +- Clarify the generality of the embedded parquet index [#16692](https://github.com/apache/datafusion/pull/16692) (alamb) +- Refactor SortMergeJoinMetrics to reuse BaselineMetrics [#16675](https://github.com/apache/datafusion/pull/16675) (Standing-Man) +- Add support for Arrow Dictionary type in Substrait [#16608](https://github.com/apache/datafusion/pull/16608) (jkosh44) +- Fix duplicate field name error in Join::try_new_with_project_input during physical planning [#16454](https://github.com/apache/datafusion/pull/16454) (LiaCastaneda) +- chore(deps): bump tokio from 1.46.0 to 1.46.1 [#16700](https://github.com/apache/datafusion/pull/16700) (dependabot[bot]) +- Add reproducer for tpch Q16 deserialization bug [#16662](https://github.com/apache/datafusion/pull/16662) (NGA-TRAN) +- Minor: Update release instructions [#16701](https://github.com/apache/datafusion/pull/16701) (alamb) +- refactor filter pushdown APIs [#16642](https://github.com/apache/datafusion/pull/16642) (adriangb) +- Add comments to ClickBench queries about setting binary_as_string [#16605](https://github.com/apache/datafusion/pull/16605) (alamb) +- minor: improve display output for FFI execution plans [#16713](https://github.com/apache/datafusion/pull/16713) (timsaucer) +- Revert "fix: create file for empty stream" [#16682](https://github.com/apache/datafusion/pull/16682) (brunal) +- Add the missing equivalence info for filter pushdown [#16686](https://github.com/apache/datafusion/pull/16686) (liamzwbao) +- Fix sqllogictests test running compatibility (ignore `--test-threads`) [#16694](https://github.com/apache/datafusion/pull/16694) (mjgarton) +- Fix: Make `CopyTo` logical plan output schema consistent with physical schema [#16705](https://github.com/apache/datafusion/pull/16705) (bert-beyondloops) +- chore(devcontainer): use debian's `protobuf-compiler` package [#16687](https://github.com/apache/datafusion/pull/16687) (fvj) +- Add link to upgrade guide in changelog script [#16680](https://github.com/apache/datafusion/pull/16680) (alamb) +- Improve display format of BoundedWindowAggExec [#16645](https://github.com/apache/datafusion/pull/16645) (geetanshjuneja) +- Fix: optimize projections for unnest logical plan. [#16632](https://github.com/apache/datafusion/pull/16632) (bert-beyondloops) +- Use the `test-threads` option in sqllogictests [#16722](https://github.com/apache/datafusion/pull/16722) (mjgarton) +- chore(deps): bump clap from 4.5.40 to 4.5.41 [#16735](https://github.com/apache/datafusion/pull/16735) (dependabot[bot]) +- chore: make more clarity for internal errors [#16741](https://github.com/apache/datafusion/pull/16741) (comphead) +- Remove parquet_filter and parquet `sort` benchmarks [#16730](https://github.com/apache/datafusion/pull/16730) (alamb) +- Perform type coercion for corr aggregate function [#15776](https://github.com/apache/datafusion/pull/15776) (kumarlokesh) +- Improve dictionary null handling in hashing and expand aggregate test coverage for nulls [#16466](https://github.com/apache/datafusion/pull/16466) (kosiew) +- Improve Ci cache [#16709](https://github.com/apache/datafusion/pull/16709) (blaginin) +- Fix in list round trip in df proto [#16744](https://github.com/apache/datafusion/pull/16744) (XiangpengHao) +- chore: Make `GroupValues` and APIs on `PhysicalGroupBy` aggregation APIs public [#16733](https://github.com/apache/datafusion/pull/16733) (haohuaijin) +- Extend binary coercion rules to support Decimal arithmetic operations with integer(signed and unsigned) types [#16668](https://github.com/apache/datafusion/pull/16668) (jatin510) +- Support Type Coercion for NULL in Binary Arithmetic Expressions [#16761](https://github.com/apache/datafusion/pull/16761) (kosiew) +- chore(deps): bump chrono-tz from 0.10.3 to 0.10.4 [#16769](https://github.com/apache/datafusion/pull/16769) (dependabot[bot]) +- limit intermediate batch size in nested_loop_join [#16443](https://github.com/apache/datafusion/pull/16443) (UBarney) +- Add serialization/deserialization and round-trip tests for all tpc-h queries [#16742](https://github.com/apache/datafusion/pull/16742) (NGA-TRAN) +- Auto start testcontainers for `datafusion-cli` [#16644](https://github.com/apache/datafusion/pull/16644) (blaginin) +- Refactor BinaryTypeCoercer to Handle Null Coercion Early and Avoid Redundant Checks [#16768](https://github.com/apache/datafusion/pull/16768) (kosiew) +- Remove fixed version from MSRV check [#16786](https://github.com/apache/datafusion/pull/16786) (findepi) +- Add `clickbench_pushdown` benchmark [#16731](https://github.com/apache/datafusion/pull/16731) (alamb) +- add filter to handle backtrace [#16752](https://github.com/apache/datafusion/pull/16752) (geetanshjuneja) +- Support min/max aggregates for FixedSizeBinary type [#16765](https://github.com/apache/datafusion/pull/16765) (theirix) +- fix tests in page_pruning when filter pushdown is enabled by default [#16794](https://github.com/apache/datafusion/pull/16794) (XiangpengHao) +- Automatically split large single RecordBatches in `MemorySource` into smaller batches [#16734](https://github.com/apache/datafusion/pull/16734) (kosiew) +- CI: Fix slow join test [#16796](https://github.com/apache/datafusion/pull/16796) (2010YOUY01) +- Benchmark for char expression [#16743](https://github.com/apache/datafusion/pull/16743) (ajita-asthana) +- Add example of custom file schema casting rules [#16803](https://github.com/apache/datafusion/pull/16803) (adriangb) +- Fix discrepancy in Float64 to timestamp(9) casts for constants [#16639](https://github.com/apache/datafusion/pull/16639) (findepi) +- Fix: Preserve sorting for the COPY TO plan [#16785](https://github.com/apache/datafusion/pull/16785) (bert-beyondloops) +- chore(deps): bump object_store from 0.12.2 to 0.12.3 [#16807](https://github.com/apache/datafusion/pull/16807) (dependabot[bot]) +- Implement equals for stateful functions [#16781](https://github.com/apache/datafusion/pull/16781) (findepi) +- benchmark: Add parquet h2o support [#16804](https://github.com/apache/datafusion/pull/16804) (zhuqi-lucas) +- chore: use `equals_datatype` for `BinaryExpr` [#16813](https://github.com/apache/datafusion/pull/16813) (comphead) +- chore: add tests for out of bounds for NullArray [#16802](https://github.com/apache/datafusion/pull/16802) (comphead) +- Refactor binary.rs tests into modular submodules under `binary/tests` [#16782](https://github.com/apache/datafusion/pull/16782) (kosiew) +- cache generation of dictionary keys and null arrays for ScalarValue [#16789](https://github.com/apache/datafusion/pull/16789) (adriangb) +- refactor(examples): remove redundant call to create directory in `parquet_embedded_index.rs` [#16825](https://github.com/apache/datafusion/pull/16825) (jcsherin) +- Add benchmark for ByteViewGroupValueBuilder [#16826](https://github.com/apache/datafusion/pull/16826) (zhuqi-lucas) +- Simplify try cast expr evaluation [#16834](https://github.com/apache/datafusion/pull/16834) (lewiszlw) +- Fix flaky test case in joins.slt [#16849](https://github.com/apache/datafusion/pull/16849) (findepi) +- chore(deps): bump sysinfo from 0.35.2 to 0.36.1 [#16850](https://github.com/apache/datafusion/pull/16850) (dependabot[bot]) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 33 Andrew Lamb + 26 dependabot[bot] + 19 Adrian Garcia Badaracco + 14 kosiew + 13 Piotr Findeisen + 13 Qi Zhu + 7 Jonathan Chen + 6 Chen Chongchen + 6 Marco Neumann + 6 Oleks V + 6 Pepijn Van Eeckhoudt + 6 xudong.w + 5 Yongting You + 5 ding-young + 4 Simon Vandel Sillesen + 3 Adam Gutglick + 3 Bert Vermeiren + 3 Dmitrii Blaginin + 3 Joseph Koshakow + 3 Liam Bao + 3 Tim Saucer + 2 Alan Tang + 2 Arttu + 2 Bruno + 2 Corwin Joy + 2 Daniël Heres + 2 Geetansh Juneja + 2 Ian Lai + 2 Jax Liu + 2 Martin Garton + 2 Nga Tran + 2 Ruihang Xia + 2 Tai Le Manh + 2 ViggoC + 2 Xiangpeng Hao + 2 haiywu + 2 theirix + 1 Ajeeta Asthana + 1 Artem Osipov + 1 Dharan Aditya + 1 Gabriel + 1 Geoffrey Claude + 1 Hendrik Makait + 1 Huaijin + 1 Ian Wagner + 1 Jack Eadie + 1 Jagdish Parihar + 1 Jon Mease + 1 Julius von Froreich + 1 K + 1 Leon Lin + 1 Loakesh Indiran + 1 Lokesh + 1 Lucas Earl + 1 Lía Adriana + 1 Mehmet Ozan Kabak + 1 Melroy dsilva + 1 Nirnay Roy + 1 Nuno Faria + 1 R. Tyler Croy + 1 Rishab Joshi + 1 Sami Tabet + 1 Samyak Sarnayak + 1 Shehab Amin + 1 Tobias Schwarzinger + 1 UBarney + 1 alexanderbianchi + 1 epgif + 1 feniljain + 1 m09526 + 1 suibianwanwan +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/49.0.1.md b/dev/changelog/49.0.1.md new file mode 100644 index 0000000000000..06d7c1e2c77a6 --- /dev/null +++ b/dev/changelog/49.0.1.md @@ -0,0 +1,48 @@ + + +# Apache DataFusion 49.0.1 Changelog + +This release consists of 5 commits from 5 contributors. See credits at the end of this changelog for more information. + +See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) for information on how to upgrade from previous versions. + +**Other:** + +- [branch-49] Final Changelog Tweaks [#16852](https://github.com/apache/datafusion/pull/16852) (alamb) +- [branch-49] remove warning from every file open [#17059](https://github.com/apache/datafusion/pull/17059) (mbutrovich) +- [branch-49] Backport PR #16995 to branch-49 [#17068](https://github.com/apache/datafusion/pull/17068) (pepijnve) +- [branch-49] Backport "Add ExecutionPlan::reset_state (apache#17028)" to v49 [#17096](https://github.com/apache/datafusion/pull/17096) (adriangb) +- [branch-49] Backport #17129 to branch 49 [#17143](https://github.com/apache/datafusion/pull/17143) (AdamGS) +- [branch-49] Backport Pass the input schema to stats_projection for ProjectionExpr (#17123) [#17174](https://github.com/apache/datafusion/pull/17174) (alamb) +- [branch-49] fix: string_agg not respecting ORDER BY [#17058](https://github.com/apache/datafusion/pull/17058) (nuno-faria) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 1 Adam Gutglick + 1 Adrian Garcia Badaracco + 1 Andrew Lamb + 1 Matt Butrovich + 1 Pepijn Van Eeckhoudt +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/49.0.2.md b/dev/changelog/49.0.2.md new file mode 100644 index 0000000000000..7e6fc3e7eb487 --- /dev/null +++ b/dev/changelog/49.0.2.md @@ -0,0 +1,45 @@ + + +# Apache DataFusion 49.0.2 Changelog + +This release consists of 3 commits from 3 contributors. See credits at the end of this changelog for more information. + +See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) for information on how to upgrade from previous versions. + +**Fixed bugs:** + +- fix: align `array_has` null buffer for scalar (#17272) [#17274](https://github.com/apache/datafusion/pull/17274) (comphead) + +**Other:** + +- [branch-49] Backport fix: deserialization error for FilterExec (predicates with inlist) [#17254](https://github.com/apache/datafusion/pull/17254) (haohuaijin) +- [branch-49] FFI_RecordBatchStream was causing a memory leak (#17190) [#17270](https://github.com/apache/datafusion/pull/17270) (timsaucer) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 1 Huaijin + 1 Oleks V + 1 Tim Saucer +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/50.0.0.md b/dev/changelog/50.0.0.md new file mode 100644 index 0000000000000..7563d57777d56 --- /dev/null +++ b/dev/changelog/50.0.0.md @@ -0,0 +1,445 @@ + + +# Apache DataFusion 50.0.0 Changelog + +This release consists of 315 commits from 79 contributors. See credits at the end of this changelog for more information. + +See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) for information on how to upgrade from previous versions. + +**Breaking changes:** + +- Support multiple ordered `array_agg` aggregations [#16625](https://github.com/apache/datafusion/pull/16625) (findepi) +- Make `AsyncScalarUDFImpl::invoke_async_with_args` consistent with `ScalarUDFImpl::invoke_with_args` [#16902](https://github.com/apache/datafusion/pull/16902) (geetanshjuneja) +- Derive `WindowUDFImpl` equality, hash from `Eq`, `Hash` traits [#17081](https://github.com/apache/datafusion/pull/17081) (findepi) +- Remove redundant `plan` from extension's check_invariants [#17199](https://github.com/apache/datafusion/pull/17199) (findepi) +- feat: Make parquet_encryption a non-default feature [#17137](https://github.com/apache/datafusion/pull/17137) (miroim) +- chore: fix typos [#17135](https://github.com/apache/datafusion/pull/17135) (waynexia) +- Use a struct for ProjectionExpr [#17398](https://github.com/apache/datafusion/pull/17398) (adriangb) +- Use DataFusionError instead of ArrowError in FileOpenFuture [#17397](https://github.com/apache/datafusion/pull/17397) (adriangb) +- Use return_field instead of return_type for calling aggregates via FFI [#17407](https://github.com/apache/datafusion/pull/17407) (timsaucer) + +**Performance related:** + +- feat: improve LiteralGuarantee for the case like `(a=1 AND b=1) OR (a=2 AND b=3)` [#16762](https://github.com/apache/datafusion/pull/16762) (haohuaijin) +- optimize `initcap` function by avoiding memory allocation [#16878](https://github.com/apache/datafusion/pull/16878) (waynexia) +- speedup `date_trunc` (~7x faster) in some cases [#16859](https://github.com/apache/datafusion/pull/16859) (waynexia) +- Feature: Improve hash Expr performance [#16977](https://github.com/apache/datafusion/pull/16977) (tobixdev) +- Perf: Port arrow-rs optimization for get_buffer_memory_size and add fast path for no buffer for gc string view [#17008](https://github.com/apache/datafusion/pull/17008) (zhuqi-lucas-001) +- Simplify comparisons and binary operations involving NULL [#17088](https://github.com/apache/datafusion/pull/17088) (findepi) +- Eliminate all redundant aggregations [#17139](https://github.com/apache/datafusion/pull/17139) (findepi) + +**Implemented enhancements:** + +- feat: Allow tree explain format width to be customizable [#16827](https://github.com/apache/datafusion/pull/16827) (nuno-faria) +- feat(spark): Implement Spark `string` function `luhn_check` [#16848](https://github.com/apache/datafusion/pull/16848) (Standing-Man) +- feat(spark): implement Spark datetime function last_day [#16828](https://github.com/apache/datafusion/pull/16828) (Standing-Man) +- feat: Add `ScalarValue::{new_one,new_zero,new_ten,distance}` support for `Decimal128` and `Decimal256` [#16831](https://github.com/apache/datafusion/pull/16831) (theirix) +- feat: support distinct for window [#16925](https://github.com/apache/datafusion/pull/16925) (zhuqi-lucas-001) +- feat: add multi level merge sort that will always fit in memory [#15700](https://github.com/apache/datafusion/pull/15700) (rluvaton) +- feat: [datafusion-spark] Implement `next_day` function [#16780](https://github.com/apache/datafusion/pull/16780) (petern48) +- feat: Support distinct window for sum [#16943](https://github.com/apache/datafusion/pull/16943) (zhuqi-lucas-001) +- feat(spark): implement Spark math function rint [#16924](https://github.com/apache/datafusion/pull/16924) (chenkovsky) +- feat(spark): implement Spark string function like/ilike [#16962](https://github.com/apache/datafusion/pull/16962) (chenkovsky) +- feat: Cache Parquet metadata in built in parquet reader [#16971](https://github.com/apache/datafusion/pull/16971) (nuno-faria) +- feat: Add `Arc` to `ScalarFunctionArgs`, don't copy `ConfigOptions` on each query [#16970](https://github.com/apache/datafusion/pull/16970) (Omega359) +- feat(spark): implement spark hash function crc32/sha1 [#17032](https://github.com/apache/datafusion/pull/17032) (chenkovsky) +- feat: Limit the memory used in the file metadata cache [#17031](https://github.com/apache/datafusion/pull/17031) (nuno-faria) +- feat: Dynamic Parquet encryption and decryption properties [#16779](https://github.com/apache/datafusion/pull/16779) (adamreeve) +- feat: Use Cached Metadata for ListingTable Statistics [#17022](https://github.com/apache/datafusion/pull/17022) (shehabgamin) +- feat(spark): implement Spark math function mod/pmod [#16829](https://github.com/apache/datafusion/pull/16829) (chenkovsky) +- feat(spark): implement Spark math function bit_get/bit_count [#16942](https://github.com/apache/datafusion/pull/16942) (chenkovsky) +- feat: add `isodow` (ISO day-of-week) support to date_part (Monday = 0) [#17112](https://github.com/apache/datafusion/pull/17112) (ayemjay) +- feat(spark): implement spark datetime function date_add/date_sub [#17024](https://github.com/apache/datafusion/pull/17024) (chenkovsky) +- feat: Add the ability to review the contents of the Metadata Cache [#17126](https://github.com/apache/datafusion/pull/17126) (nuno-faria) +- feat: add `datafusion-physical-adapter`, implement predicate adaptation missing fields of structs [#16589](https://github.com/apache/datafusion/pull/16589) (adriangb) +- feat: implement QUALIFY clause [#16933](https://github.com/apache/datafusion/pull/16933) (haohuaijin) +- feat: allow to `spawn`/`spawn_blocking` on a provided runtime in `RecordBatchReceiverStreamBuilder` [#17239](https://github.com/apache/datafusion/pull/17239) (rluvaton) +- feat: Support SortMergeJoin proto serde [#17296](https://github.com/apache/datafusion/pull/17296) (milenkovicm) +- feat(spark): implement Spark `bitmap` function `bitmap_count` [#17179](https://github.com/apache/datafusion/pull/17179) (SparkApplicationMaster) +- feat: Track peak value in tracked consumer [#17327](https://github.com/apache/datafusion/pull/17327) (wForget) +- feat(spark): implement Spark conditional function if [#16946](https://github.com/apache/datafusion/pull/16946) (chenkovsky) +- feat(spark): implement Spark `width_bucket` function [#17331](https://github.com/apache/datafusion/pull/17331) (davidlghellin) +- feat: Make Parquet EncryptionFactory async [#17342](https://github.com/apache/datafusion/pull/17342) (adamreeve) +- feat: Support `FILTER` clause in aggregate window functions [#17378](https://github.com/apache/datafusion/pull/17378) (geoffreyclaude) +- feat: Support binary data types for `SortMergeJoin` `on` clause [#17431](https://github.com/apache/datafusion/pull/17431) (stuartcarnie) + +**Fixed bugs:** + +- fix: The inconsistency between scalar and array on the cast decimal to timestamp [#16539](https://github.com/apache/datafusion/pull/16539) (chenkovsky) +- fix: unit test for object_storage [#16824](https://github.com/apache/datafusion/pull/16824) (chenkovsky) +- fix(docs): Update broken links to `TableProvider` docs [#16830](https://github.com/apache/datafusion/pull/16830) (jcsherin) +- fix: `PlaceholderRowExec::partition_statistics` [#16851](https://github.com/apache/datafusion/pull/16851) (crepererum) +- fix: skip predicates on struct unnest in PushDownFilter [#16790](https://github.com/apache/datafusion/pull/16790) (akoshchiy) +- fix: regex bench [#16890](https://github.com/apache/datafusion/pull/16890) (chenkovsky) +- fix: `ComposedPhysicalExtensionCodec` does not use the same codec as encoding when decoding [#16986](https://github.com/apache/datafusion/pull/16986) (Thearas) +- fix: Remove `datafusion.execution.parquet.cache_metadata` config [#17062](https://github.com/apache/datafusion/pull/17062) (jonathanc-n) +- fix: Add missing member to visitor for ConfigFileEncryptionProperties [#17103](https://github.com/apache/datafusion/pull/17103) (corwinjoy) +- fix(ci): update `datafusion-physical-expr-adapter` version to 49.0.1in Cargo.lock [#17209](https://github.com/apache/datafusion/pull/17209) (miroim) +- fix: respect inexact flags in row group metadata [#16412](https://github.com/apache/datafusion/pull/16412) (CookiePieWw) +- fix: deserialization error for `FilterExec` (predicates with inlist) [#17224](https://github.com/apache/datafusion/pull/17224) (haohuaijin) +- FFI_RecordBatchStream was causing a memory leak [#17190](https://github.com/apache/datafusion/pull/17190) (timsaucer) +- fix: Windows paths crashing core tests [#17231](https://github.com/apache/datafusion/pull/17231) (nuno-faria) +- fix: sort should always output batches with `batch_size` rows [#17244](https://github.com/apache/datafusion/pull/17244) (rluvaton) +- fix: align `array_has` null buffer for scalar [#17272](https://github.com/apache/datafusion/pull/17272) (comphead) +- fix: dataframe function count_all with alias [#17282](https://github.com/apache/datafusion/pull/17282) (Loaki07) +- fix: correct readme field in `Cargo.toml` [#17310](https://github.com/apache/datafusion/pull/17310) (Weijun-H) +- fix(doc): update the link of deprecation guidelines (#17328) [#17329](https://github.com/apache/datafusion/pull/17329) (ivila) +- fix: lazy case else evaluation [#17311](https://github.com/apache/datafusion/pull/17311) (chenkovsky) +- fix: set distinct_count to Absent when merging statistics [#17385](https://github.com/apache/datafusion/pull/17385) (adriangb) +- fix: Remove duplicate filter from `CrossJoin` unparsing [#17382](https://github.com/apache/datafusion/pull/17382) (jonathanc-n) +- fix: set IPC alignment based on schema [#17363](https://github.com/apache/datafusion/pull/17363) (ding-young) +- fix: return ALL constants in `EquivalenceProperties::constants` [#17404](https://github.com/apache/datafusion/pull/17404) (crepererum) +- fix: align `map_keys` nullability flag [#17454](https://github.com/apache/datafusion/pull/17454) (comphead) + +**Documentation updates:** + +- docs: Fix broken links [#16839](https://github.com/apache/datafusion/pull/16839) (2010YOUY01) +- Add note to upgrade guide about MSRV update [#16845](https://github.com/apache/datafusion/pull/16845) (alamb) +- [main] Update version to 49.0.0, add 49.0.0 changelog [#16855](https://github.com/apache/datafusion/pull/16855) (alamb) +- Improve async_udf example and docs [#16846](https://github.com/apache/datafusion/pull/16846) (alamb) +- Docs: Update Upgrading.md to reflect 49.0.0 is released [#16853](https://github.com/apache/datafusion/pull/16853) (alamb) +- docs: Remove references to DataFusion for Ray sub project [#16966](https://github.com/apache/datafusion/pull/16966) (andygrove) +- Add `temp_directory` and `max_temp_directory_size` runtime config variables [#16934](https://github.com/apache/datafusion/pull/16934) (delamarch3) +- Add `sql_parser.default_null_ordering` config option to customize the default null ordering [#16963](https://github.com/apache/datafusion/pull/16963) (goldmedal) +- Added Example for `Statistical Functions` in Docs [#16927](https://github.com/apache/datafusion/pull/16927) (Adez017) +- Fix window_functions docs formatting [#17005](https://github.com/apache/datafusion/pull/17005) (mattmatravers) +- docs: Fix 'Analaysis' typo in query optimizer docs [#17015](https://github.com/apache/datafusion/pull/17015) (petern48) +- docs: Fix random extra bullet for 'Analytical Functions' [#17014](https://github.com/apache/datafusion/pull/17014) (petern48) +- docs: Fix failing documentation check in CI [#17026](https://github.com/apache/datafusion/pull/17026) (adamreeve) +- Upgrade arrow/parquet to 56.0.0 [#16690](https://github.com/apache/datafusion/pull/16690) (alamb) +- fix error result in execute&pre_selection [#16930](https://github.com/apache/datafusion/pull/16930) (acking-you) +- docs: Fix failing CI [#17041](https://github.com/apache/datafusion/pull/17041) (liamzwbao) +- Docs: Add Examples to Config Options page [#17039](https://github.com/apache/datafusion/pull/17039) (alamb) +- Docs: Add Tuning Guide for small data / short queries [#17040](https://github.com/apache/datafusion/pull/17040) (alamb) +- Docs: Update the crate configuration / build settings page [#17038](https://github.com/apache/datafusion/pull/17038) (alamb) +- Support `centroids` config for `approx_percentile_cont_with_weight` [#17003](https://github.com/apache/datafusion/pull/17003) (liamzwbao) +- Add ExecutionPlan::reset_state [#17028](https://github.com/apache/datafusion/pull/17028) (adriangb) +- Docs: Add Tuning Guide for larger-than-memory queries [#17069](https://github.com/apache/datafusion/pull/17069) (2010YOUY01) +- Link UdfEq and PtrEq to help understand relationship [#17082](https://github.com/apache/datafusion/pull/17082) (findepi) +- Derive `AggregateUDFImpl` equality, hash from `Eq`, `Hash` traits [#17130](https://github.com/apache/datafusion/pull/17130) (findepi) +- chore: Clarify `EmptyRelation` description [#17157](https://github.com/apache/datafusion/pull/17157) (comphead) +- Update dev env documentation to reflect pinned rust version [#17107](https://github.com/apache/datafusion/pull/17107) (Jefffrey) +- Differentiate 0-row and 1-row EmptyRelation in EXPLAIN [#17145](https://github.com/apache/datafusion/pull/17145) (findepi) +- (Re)Support old syntax for `approx_percentile_cont` and `approx_percentile_cont_with_weight` [#16999](https://github.com/apache/datafusion/pull/16999) (alamb) +- Derive `ScalarUDFImpl` equality, hash from `Eq`, `Hash` traits [#17164](https://github.com/apache/datafusion/pull/17164) (findepi) +- #17128 Add support for chr(0) [#17131](https://github.com/apache/datafusion/pull/17131) (pepijnve) +- [main] Update version to 49.0.1 and add changelog (#17175) [#17191](https://github.com/apache/datafusion/pull/17191) (alamb) +- Docs: Consolidate feature proposal content into roadmap [#17156](https://github.com/apache/datafusion/pull/17156) (alamb) +- Doc: Update upgrade guide for the rewritten NLJ operator [#17202](https://github.com/apache/datafusion/pull/17202) (2010YOUY01) +- Support serializing `generate_series` in `datafusion-proto` [#17200](https://github.com/apache/datafusion/pull/17200) (cetra3) +- Fix broken links in user docs [#17228](https://github.com/apache/datafusion/pull/17228) (AdamGS) +- Format `Date32` to string given timestamp specifiers [#15361](https://github.com/apache/datafusion/pull/15361) (friendlymatthew) +- Improve documentation for Signature, Volatility, and TypeSignature [#17264](https://github.com/apache/datafusion/pull/17264) (alamb) +- [main] Forward port `49.0.2` version and changelog (#17277) [#17287](https://github.com/apache/datafusion/pull/17287) (alamb) +- Document schema merging. [#17249](https://github.com/apache/datafusion/pull/17249) (wiedld) +- Support from-first SQL syntax [#17295](https://github.com/apache/datafusion/pull/17295) (simonvandel) +- Add `cfg(feature = "avro")` attribute to Avro example in SQL API docs [#17142](https://github.com/apache/datafusion/pull/17142) (kosiew) +- Push the limits past window functions [#17347](https://github.com/apache/datafusion/pull/17347) (avantgardnerio) +- Refactor DataSourceExec::try_swapping_with_projection to simplify and remove abstraction leakage [#17395](https://github.com/apache/datafusion/pull/17395) (adriangb) +- doc: Document caveats of `swap_inputs()` interface in join executors [#17373](https://github.com/apache/datafusion/pull/17373) (2010YOUY01) +- Fix syntax error in DDL documentation example [#17412](https://github.com/apache/datafusion/pull/17412) (pepijnve) +- Add MSRV change to upgrade guide [#17406](https://github.com/apache/datafusion/pull/17406) (findepi) +- Add PhysicalExpr::is_volatile_node to upgrade guide [#17443](https://github.com/apache/datafusion/pull/17443) (adriangb) +- docs: Render `--` properly in profiling docs [#17430](https://github.com/apache/datafusion/pull/17430) (petern48) + +**Other:** + +- chore: use `equals_datatype` for `BinaryExpr` [#16813](https://github.com/apache/datafusion/pull/16813) (comphead) +- chore: add tests for out of bounds for NullArray [#16802](https://github.com/apache/datafusion/pull/16802) (comphead) +- Refactor binary.rs tests into modular submodules under `binary/tests` [#16782](https://github.com/apache/datafusion/pull/16782) (kosiew) +- cache generation of dictionary keys and null arrays for ScalarValue [#16789](https://github.com/apache/datafusion/pull/16789) (adriangb) +- refactor(examples): remove redundant call to create directory in `parquet_embedded_index.rs` [#16825](https://github.com/apache/datafusion/pull/16825) (jcsherin) +- Add benchmark for ByteViewGroupValueBuilder [#16826](https://github.com/apache/datafusion/pull/16826) (zhuqi-lucas-001) +- Simplify try cast expr evaluation [#16834](https://github.com/apache/datafusion/pull/16834) (lewiszlw) +- Fix flaky test case in joins.slt [#16849](https://github.com/apache/datafusion/pull/16849) (findepi) +- chore(deps): bump sysinfo from 0.35.2 to 0.36.1 [#16850](https://github.com/apache/datafusion/pull/16850) (dependabot[bot]) +- chore(deps): bump aws-credential-types from 1.2.3 to 1.2.4 [#16815](https://github.com/apache/datafusion/pull/16815) (dependabot[bot]) +- fix(build-wasm): put `arrow-ipc/zstd` dep under `compression` feature [#16844](https://github.com/apache/datafusion/pull/16844) (chrisvander) +- chore(deps): bump serde_json from 1.0.140 to 1.0.141 [#16863](https://github.com/apache/datafusion/pull/16863) (dependabot[bot]) +- chore(deps): bump aws-config from 1.8.1 to 1.8.2 [#16864](https://github.com/apache/datafusion/pull/16864) (dependabot[bot]) +- test: Fix flaky join tests [#16860](https://github.com/apache/datafusion/pull/16860) (2010YOUY01) +- chore(deps): bump rand from 0.9.1 to 0.9.2 [#16882](https://github.com/apache/datafusion/pull/16882) (dependabot[bot]) +- Report error when `SessionState::sql_to_expr_with_alias` does not consume all input [#16811](https://github.com/apache/datafusion/pull/16811) (pepijnve) +- test: fix more flaky join tests [#16880](https://github.com/apache/datafusion/pull/16880) (2010YOUY01) +- MINOR: add unit tests for chr function [#16856](https://github.com/apache/datafusion/pull/16856) (waynexia) +- remove deprecated methods from FileScanConfig / DataSourceExec [#16901](https://github.com/apache/datafusion/pull/16901) (adriangb) +- Support utf8view for spark hex [#16885](https://github.com/apache/datafusion/pull/16885) (xudong963) +- Fixes 3 bugs during serialization and deserialization of physical plans [#16858](https://github.com/apache/datafusion/pull/16858) (NGA-TRAN) +- chore(deps): bump aws-config from 1.8.2 to 1.8.3 [#16912](https://github.com/apache/datafusion/pull/16912) (dependabot[bot]) +- Derive UDF equality from PartialEq, Hash [#16842](https://github.com/apache/datafusion/pull/16842) (findepi) +- Ensure Substrait consumer can handle expressions in VirtualTable [#16857](https://github.com/apache/datafusion/pull/16857) (lorenarosati) +- Mutable Join Unwind [#16883](https://github.com/apache/datafusion/pull/16883) (berkaysynnada) +- fix(datafusion-proto): support serializing/deserilizing ArrowFormat tables [#16875](https://github.com/apache/datafusion/pull/16875) (colinmarc) +- ScalarValue Default + Min + Max [#16891](https://github.com/apache/datafusion/pull/16891) (berkaysynnada) +- minor: add is_superset() method for Interval's [#16895](https://github.com/apache/datafusion/pull/16895) (berkaysynnada) +- minor: implement with_new_expressions for AggregateFunctionExpr [#16897](https://github.com/apache/datafusion/pull/16897) (berkaysynnada) +- minor: Rename add_spm_on_top as add_merge_on_top [#16913](https://github.com/apache/datafusion/pull/16913) (berkaysynnada) +- Implement Helpers for ScopedTimerGuard and Time Structs [#16911](https://github.com/apache/datafusion/pull/16911) (berkaysynnada) +- Fix Partial Sort Get Slice Point Between Batches [#16881](https://github.com/apache/datafusion/pull/16881) (berkaysynnada) +- Fix `schema_adapter` integration tests not running [#16835](https://github.com/apache/datafusion/pull/16835) (kosiew) +- Update release process [#16929](https://github.com/apache/datafusion/pull/16929) (xudong963) +- Fix `next_up` and `next_down` behavior for zero float values [#16745](https://github.com/apache/datafusion/pull/16745) (liamzwbao) +- Add Fetch Property to OutputRequirementExec [#16892](https://github.com/apache/datafusion/pull/16892) (berkaysynnada) +- chore(deps): bump tokio from 1.46.1 to 1.47.0 [#16952](https://github.com/apache/datafusion/pull/16952) (dependabot[bot]) +- chore(deps): bump serde_json from 1.0.140 to 1.0.141 [#16951](https://github.com/apache/datafusion/pull/16951) (dependabot[bot]) +- chore: Remove attributes to allow dead_code that aren't relevant anymore [#16953](https://github.com/apache/datafusion/pull/16953) (AdamGS) +- chore(deps): bump rand from 0.9.1 to 0.9.2 [#16960](https://github.com/apache/datafusion/pull/16960) (dependabot[bot]) +- chore(deps): bump ctor from 0.4.2 to 0.4.3 [#16961](https://github.com/apache/datafusion/pull/16961) (dependabot[bot]) +- disallow pushdown of volatile functions [#16861](https://github.com/apache/datafusion/pull/16861) (adriangb) +- remove warning from every file open [#16968](https://github.com/apache/datafusion/pull/16968) (adriangb) +- Pin github actions to commit sha [#16964](https://github.com/apache/datafusion/pull/16964) (gopidesupavan) +- Enable physical filter pushdown for hash joins [#16954](https://github.com/apache/datafusion/pull/16954) (adriangb) +- Fix [Bug] Aggregate + TopK fails when asc = false [#16972](https://github.com/apache/datafusion/pull/16972) (avantgardnerio) +- Use tokio::task::coop::poll_proceed by default in CooperativeStream [#16748](https://github.com/apache/datafusion/pull/16748) (pepijnve) +- Add benchmark utility to profile peak memory usage [#16814](https://github.com/apache/datafusion/pull/16814) (ding-young) +- chore(deps): bump indicatif from 0.17.11 to 0.18.0 [#16992](https://github.com/apache/datafusion/pull/16992) (dependabot[bot]) +- test(datafusion-cli): migrate tests to `insta` in `print_format.rs` [#16993](https://github.com/apache/datafusion/pull/16993) (Thearas) +- Chore: remove 'spill_record_batch_by_size' api [#16958](https://github.com/apache/datafusion/pull/16958) (ding-young) +- chore(deps): bump serde_json from 1.0.141 to 1.0.142 [#17006](https://github.com/apache/datafusion/pull/17006) (dependabot[bot]) +- Add tests for yielding in `SpillManager::read_spill_as_stream` [#16616](https://github.com/apache/datafusion/pull/16616) (ding-young) +- #16994 Ensure CooperativeExec#maintains_input_order returns a Vec of the correct size [#16995](https://github.com/apache/datafusion/pull/16995) (pepijnve) +- test: Add logic tests for string_agg with order [#17033](https://github.com/apache/datafusion/pull/17033) (nuno-faria) +- Implement `From>' for `ScalarValue` [#17043](https://github.com/apache/datafusion/pull/17043) (findepi) +- chore(deps): bump tokio-util from 0.7.15 to 0.7.16 [#17030](https://github.com/apache/datafusion/pull/17030) (dependabot[bot]) +- Add missing Substrait to DataFusion function name mappings [#16950](https://github.com/apache/datafusion/pull/16950) (lorenarosati) +- refactor: use upstream arrow-rs inline_key_fast [#17044](https://github.com/apache/datafusion/pull/17044) (zhuqi-lucas-001) +- Implement spark `array` function `array` [#16936](https://github.com/apache/datafusion/pull/16936) (Standing-Man) +- Address memory over-accounting in array_agg [#16816](https://github.com/apache/datafusion/pull/16816) (gabotechs) +- chore(deps): bump aws-credential-types from 1.2.4 to 1.2.5 [#17053](https://github.com/apache/datafusion/pull/17053) (dependabot[bot]) +- Support Substrait functions and_not, xor, and between in consumer built-in expression builder [#16984](https://github.com/apache/datafusion/pull/16984) (lorenarosati) +- Derive UDWF equality from PartialEq, Hash [#17057](https://github.com/apache/datafusion/pull/17057) (findepi) +- fix return field for `is_not_null` expression [#17056](https://github.com/apache/datafusion/pull/17056) (davidhewitt) +- chore(deps): bump tokio from 1.47.0 to 1.47.1 [#17063](https://github.com/apache/datafusion/pull/17063) (dependabot[bot]) +- Optimize char expression [#16076](https://github.com/apache/datafusion/pull/16076) (ajita-asthana) +- Fix equality of parametrizable ArrayAgg function [#17065](https://github.com/apache/datafusion/pull/17065) (findepi) +- Implement Spark `url` function `parse_url` [#16937](https://github.com/apache/datafusion/pull/16937) (Standing-Man) +- Derive UDAF equality from Eq, Hash [#17067](https://github.com/apache/datafusion/pull/17067) (findepi) +- Remove elements deprecated since v 45 [#17075](https://github.com/apache/datafusion/pull/17075) (findepi) +- Deprecate ScalarUDF::is_nullable [#17074](https://github.com/apache/datafusion/pull/17074) (findepi) +- Re-export `object_store` crate via DataFusion Core and Common [#17070](https://github.com/apache/datafusion/pull/17070) (kosiew) +- Fix hash/equality issues for ScalarFunctionExpr [#17078](https://github.com/apache/datafusion/pull/17078) (findepi) +- Fill missing methods in aliased UDF impls [#17080](https://github.com/apache/datafusion/pull/17080) (findepi) +- Improve Hash speed for ScalarFunctionExpr [#17099](https://github.com/apache/datafusion/pull/17099) (findepi) +- chore(deps): bump clap from 4.5.42 to 4.5.43 [#17079](https://github.com/apache/datafusion/pull/17079) (dependabot[bot]) +- minor: remove unused import in docstring of datafusion_common::record_batch [#17106](https://github.com/apache/datafusion/pull/17106) (Jefffrey) +- Make macros in common::test_util hygenic and not dependent on user dependencies [#17102](https://github.com/apache/datafusion/pull/17102) (AdamGS) +- minor: remove unnecessary clippy:large_enum_variant allows [#17108](https://github.com/apache/datafusion/pull/17108) (Jefffrey) +- minor: Improve equivalence handling of joins [#16893](https://github.com/apache/datafusion/pull/16893) (berkaysynnada) +- Fix incorrect `NULL IN ()` optimization [#17092](https://github.com/apache/datafusion/pull/17092) (findepi) +- Add `prettier` to the devcontainer (GitHub codespaces) [#17019](https://github.com/apache/datafusion/pull/17019) (alamb) +- Set a lower threshold for clippy to flag large error variants [#17109](https://github.com/apache/datafusion/pull/17109) (Jefffrey) +- chore(deps): bump rustyline from 16.0.0 to 17.0.0 [#17116](https://github.com/apache/datafusion/pull/17116) (dependabot[bot]) +- Add dynamic filter (bounds) pushdown to HashJoinExec [#16445](https://github.com/apache/datafusion/pull/16445) (adriangb) +- Remove the "run extended tests" github PR commend action [#17119](https://github.com/apache/datafusion/pull/17119) (alamb) +- chore(deps): bump sysinfo from 0.36.1 to 0.37.0 [#17124](https://github.com/apache/datafusion/pull/17124) (dependabot[bot]) +- chore(deps): bump libc from 0.2.174 to 0.2.175 [#17121](https://github.com/apache/datafusion/pull/17121) (dependabot[bot]) +- ff: Preserve cached plan information when pushing projection [#17129](https://github.com/apache/datafusion/pull/17129) (friendlymatthew) +- chore: Enforce checks for RC branches [#17132](https://github.com/apache/datafusion/pull/17132) (comphead) +- chore(deps): bump actions/checkout from 4.2.2 to 5.0.0 [#17149](https://github.com/apache/datafusion/pull/17149) (dependabot[bot]) +- minor: enhance comment in SortPreservingMergeStream.abort [#17115](https://github.com/apache/datafusion/pull/17115) (mapleFU) +- Update workspace to use Rust 1.89 [#17100](https://github.com/apache/datafusion/pull/17100) (shruti2522) +- chore(deps): bump on-headers and compression in /datafusion/wasmtest/datafusion-wasm-app [#16812](https://github.com/apache/datafusion/pull/16812) (dependabot[bot]) +- chore(deps): bump slab from 0.4.10 to 0.4.11 [#17161](https://github.com/apache/datafusion/pull/17161) (dependabot[bot]) +- refactor `character_length` impl by unifying null handling logic [#16877](https://github.com/apache/datafusion/pull/16877) (waynexia) +- chore(deps): bump clap from 4.5.43 to 4.5.44 [#17148](https://github.com/apache/datafusion/pull/17148) (dependabot[bot]) +- Pass the input schema to stats_projection for ProjectionExpr [#17123](https://github.com/apache/datafusion/pull/17123) (hareshkh) +- Fix extended tests failure on main by updating `datafusion-testing` pin [#17176](https://github.com/apache/datafusion/pull/17176) (alamb) +- Minor: display filter in HashJoin's tree explain [#17170](https://github.com/apache/datafusion/pull/17170) (2010YOUY01) +- add test for multi-column topk dynamic filter pushdown [#17162](https://github.com/apache/datafusion/pull/17162) (adriangb) +- Test: Add checks to sqllogictest temporary file creations [#17017](https://github.com/apache/datafusion/pull/17017) (2010YOUY01) +- Deprecate unused `ScalarUDF::display_name` [#17168](https://github.com/apache/datafusion/pull/17168) (findepi) +- CI: Fix extended test failure by updating `datafusion-testing` submodule [#17187](https://github.com/apache/datafusion/pull/17187) (2010YOUY01) +- Normalize `NUL` to `\0` in sqllogictests [#17181](https://github.com/apache/datafusion/pull/17181) (Jefffrey) +- Simplify `GetFieldFunc`'s `display_name`, `schema_name` [#17167](https://github.com/apache/datafusion/pull/17167) (findepi) +- Rewrite Nested Loop Join executor for 5× speed and 1% memory usage [#16996](https://github.com/apache/datafusion/pull/16996) (2010YOUY01) +- Minor: Fix compiler warning when compiling `datafusion-cli` [#17205](https://github.com/apache/datafusion/pull/17205) (2010YOUY01) +- Refactor: Do not silently ignore errors in `stats_projection` [#17154](https://github.com/apache/datafusion/pull/17154) (alamb) +- Miscellaneous cleanups [#17189](https://github.com/apache/datafusion/pull/17189) (findepi) +- [Parquet Metadata Cache] Document the ListingTable cache [#17133](https://github.com/apache/datafusion/pull/17133) (alamb) +- Fix: Show backtrace for ArrowError [#17204](https://github.com/apache/datafusion/pull/17204) (2010YOUY01) +- minor: clean up distinct window code [#17215](https://github.com/apache/datafusion/pull/17215) (zhuqi-lucas-001) +- chore: Add drop table test on create_drop.rs [#17219](https://github.com/apache/datafusion/pull/17219) (caicancai) +- chore(deps): bump async-trait from 0.1.88 to 0.1.89 [#17203](https://github.com/apache/datafusion/pull/17203) (dependabot[bot]) +- Bump MSRV to 1.86.0 [#17230](https://github.com/apache/datafusion/pull/17230) (adriangb) +- Minor: improve error message when file creation failed [#17217](https://github.com/apache/datafusion/pull/17217) (2010YOUY01) +- Fix dynamic filter pushdown in HashJoinExec [#17201](https://github.com/apache/datafusion/pull/17201) (adriangb) +- Fix Analyze Exec protobuf roundtrip [#17234](https://github.com/apache/datafusion/pull/17234) (cetra3) +- Preserve `distinct` and `ignore_nulls` in window expressions during proto serde [#17235](https://github.com/apache/datafusion/pull/17235) (cetra3) +- chore(deps): bump serde_json from 1.0.142 to 1.0.143 [#17240](https://github.com/apache/datafusion/pull/17240) (dependabot[bot]) +- chore(deps): bump syn from 2.0.105 to 2.0.106 [#17243](https://github.com/apache/datafusion/pull/17243) (dependabot[bot]) +- Push dynamic pushdown through cooperative and projection execs [#17238](https://github.com/apache/datafusion/pull/17238) (jackkleeman) +- Configure cli test that requires backtrace to be optional [#17236](https://github.com/apache/datafusion/pull/17236) (Jefffrey) +- chore(deps): Update sqlparser to 0.58 [#16456](https://github.com/apache/datafusion/pull/16456) (Dimchikkk) +- chore(deps): bump rustyline from 17.0.0 to 17.0.1 [#17252](https://github.com/apache/datafusion/pull/17252) (dependabot[bot]) +- chore(deps): bump thiserror from 2.0.14 to 2.0.16 [#17257](https://github.com/apache/datafusion/pull/17257) (dependabot[bot]) +- Fix HashJoinExec sideways information passing for partitioned queries [#17197](https://github.com/apache/datafusion/pull/17197) (adriangb) +- Fix HashJoinExec test snapshot under force_hash_collisions=true [#17265](https://github.com/apache/datafusion/pull/17265) (adriangb) +- Deprecate confusingly named `UserDefinedFunctionPlanner` [#17247](https://github.com/apache/datafusion/pull/17247) (alamb) +- Fix: ListingTableFactory paths with dots [#17233](https://github.com/apache/datafusion/pull/17233) (BlakeOrth) +- chore(deps): bump tempfile from 3.20.0 to 3.21.0 [#17268](https://github.com/apache/datafusion/pull/17268) (dependabot[bot]) +- Fix PartialOrd for ScalarUDF [#17182](https://github.com/apache/datafusion/pull/17182) (findepi) +- chore(deps): bump url from 2.5.4 to 2.5.6 [#17283](https://github.com/apache/datafusion/pull/17283) (dependabot[bot]) +- Make dynamic filter creation in HashJoinExec deterministic against partition evaluation order [#17280](https://github.com/apache/datafusion/pull/17280) (adriangb) +- Consolidate Parquet Metadata handling into its own module and struct `DFParquetMetadata` [#17127](https://github.com/apache/datafusion/pull/17127) (alamb) +- Only update TopK dynamic filters if the new ones are more selective [#16433](https://github.com/apache/datafusion/pull/16433) (adriangb) +- Add documentation for UNION schema handling. [#17248](https://github.com/apache/datafusion/pull/17248) (wiedld) +- Replace π-related bound constants with next_up/next_down [#16823](https://github.com/apache/datafusion/pull/16823) (rthummaluru) +- chore: add example for how to use TrackConsumersPool [#17213](https://github.com/apache/datafusion/pull/17213) (wiedld) +- minor: Remove extra line break in explain physical plan [#17303](https://github.com/apache/datafusion/pull/17303) (nuno-faria) +- Support `avg(distinct)` for `float64` type [#17255](https://github.com/apache/datafusion/pull/17255) (Jefffrey) +- chore: check the error message log [#17308](https://github.com/apache/datafusion/pull/17308) (caicancai) +- Expand sql_planner benchmark for benchmarking physical and logical optimization. [#17276](https://github.com/apache/datafusion/pull/17276) (Omega359) +- Encapsulate early File pruning in parquet opener in its own stream [#17293](https://github.com/apache/datafusion/pull/17293) (alamb) +- Implement `partition_statistics` API for `RepartitionExec` [#17061](https://github.com/apache/datafusion/pull/17061) (liamzwbao) +- chore: replace Schema with SchemaRef in PruningExpressionBuilder [#17216](https://github.com/apache/datafusion/pull/17216) (etolbakov) +- chore(deps): bump regex-syntax from 0.8.5 to 0.8.6 [#17320](https://github.com/apache/datafusion/pull/17320) (dependabot[bot]) +- chore(deps): bump indexmap from 2.10.0 to 2.11.0 [#17316](https://github.com/apache/datafusion/pull/17316) (dependabot[bot]) +- refactor: Split `SortMergeJoin` into multiple modules [#17304](https://github.com/apache/datafusion/pull/17304) (jonathanc-n) +- MINOR: add missing examples to example list [#17333](https://github.com/apache/datafusion/pull/17333) (waynexia) +- chore: split hash join to smaller modules [#17300](https://github.com/apache/datafusion/pull/17300) (2010YOUY01) +- chore(deps): bump url from 2.5.6 to 2.5.7 [#17324](https://github.com/apache/datafusion/pull/17324) (dependabot[bot]) +- chore(deps): bump regex from 1.11.1 to 1.11.2 [#17325](https://github.com/apache/datafusion/pull/17325) (dependabot[bot]) +- add a ci job for typo checking [#17339](https://github.com/apache/datafusion/pull/17339) (waynexia) +- chore(deps): bump clap from 4.5.45 to 4.5.46 [#17338](https://github.com/apache/datafusion/pull/17338) (dependabot[bot]) +- chore(deps): bump korandoru/hawkeye from 6.1.1 to 6.2.0 [#17321](https://github.com/apache/datafusion/pull/17321) (dependabot[bot]) +- chore: avoid very cheap copy in `SchemaMapping` [#17344](https://github.com/apache/datafusion/pull/17344) (rluvaton) +- chore(deps): bump actions/checkout from 4.2.2 to 5.0.0 [#17345](https://github.com/apache/datafusion/pull/17345) (dependabot[bot]) +- chore(deps): bump libmimalloc-sys from 0.1.43 to 0.1.44 [#17343](https://github.com/apache/datafusion/pull/17343) (dependabot[bot]) +- fix EquivalenceProperties calculation in DataSourceExec [#17323](https://github.com/apache/datafusion/pull/17323) (adriangb) +- chore(deps): bump mimalloc from 0.1.47 to 0.1.48 [#17353](https://github.com/apache/datafusion/pull/17353) (dependabot[bot]) +- chore(deps): bump tracing-subscriber from 0.3.19 to 0.3.20 [#17355](https://github.com/apache/datafusion/pull/17355) (dependabot[bot]) +- refactor: simplify json_shredding example by using ListingTable [#17369](https://github.com/apache/datafusion/pull/17369) (waynexia) +- Fix incorrect memory accounting for sliced `StringViewArray` [#17315](https://github.com/apache/datafusion/pull/17315) (ding-young) +- chore(deps): bump aws-credential-types from 1.2.5 to 1.2.6 [#17368](https://github.com/apache/datafusion/pull/17368) (dependabot[bot]) +- minor: use debug level log for physical optimizer [#17383](https://github.com/apache/datafusion/pull/17383) (waynexia) +- chore(deps): bump uuid from 1.18.0 to 1.18.1 [#17384](https://github.com/apache/datafusion/pull/17384) (dependabot[bot]) +- chore(deps): bump aws-config from 1.8.5 to 1.8.6 [#17386](https://github.com/apache/datafusion/pull/17386) (dependabot[bot]) +- minor: make dict_from_values public [#17376](https://github.com/apache/datafusion/pull/17376) (parthchandra) +- chore: add memory catalog test to handle table removal before schema deregistration [#17307](https://github.com/apache/datafusion/pull/17307) (caicancai) +- chore(deps): bump actions/setup-node from 4.4.0 to 5.0.0 [#17410](https://github.com/apache/datafusion/pull/17410) (dependabot[bot]) +- chore(deps): bump actions/stale from 9.1.0 to 10.0.0 [#17409](https://github.com/apache/datafusion/pull/17409) (dependabot[bot]) +- chore(deps): bump actions/labeler from 5.0.0 to 6.0.0 [#17408](https://github.com/apache/datafusion/pull/17408) (dependabot[bot]) +- Avoid panic when 'with order' expression could not be converted to a logical expression [#17394](https://github.com/apache/datafusion/pull/17394) (pepijnve) +- chore(deps): bump apache-avro from 0.17.0 to 0.20.0 [#16092](https://github.com/apache/datafusion/pull/16092) (dependabot[bot]) +- chore(deps): bump actions/setup-python from 5.6.0 to 6.0.0 [#17413](https://github.com/apache/datafusion/pull/17413) (dependabot[bot]) +- Test grouping by FixedSizeList [#17415](https://github.com/apache/datafusion/pull/17415) (findepi) +- re-export physical_expr_adapter [#17414](https://github.com/apache/datafusion/pull/17414) (adriangb) +- Benchmark window function with multiple partitioning columns [#17402](https://github.com/apache/datafusion/pull/17402) (findepi) +- Fix PartialOrd for Window [#17393](https://github.com/apache/datafusion/pull/17393) (findepi) +- Memory datasource protobuf support [#17290](https://github.com/apache/datafusion/pull/17290) (lewiszlw) +- fix bounds accumulator reset in HashJoinExec dynamic filter pushdown [#17371](https://github.com/apache/datafusion/pull/17371) (adriangb) +- Unimplement `PartialOrd` for `TDigest`'s `Centroid` [#17440](https://github.com/apache/datafusion/pull/17440) (findepi) +- Unimplement `PartialEq`, `PartialOrd` from `ToRepartition`, `RePartition` [#17441](https://github.com/apache/datafusion/pull/17441) (findepi) +- chore(deps): bump insta from 1.43.1 to 1.43.2 [#17436](https://github.com/apache/datafusion/pull/17436) (dependabot[bot]) +- chore(deps): bump actions/labeler from 6.0.0 to 6.0.1 [#17433](https://github.com/apache/datafusion/pull/17433) (dependabot[bot]) +- chore(deps): bump clap from 4.5.46 to 4.5.47 [#17435](https://github.com/apache/datafusion/pull/17435) (dependabot[bot]) +- Add PhysicalExpr::is_volatile [#17351](https://github.com/apache/datafusion/pull/17351) (adriangb) +- refactor: Use `BufferedBatchState` enum for SMJ spilling [#17429](https://github.com/apache/datafusion/pull/17429) (jonathanc-n) +- Re-enable page index for encrypted Parquet [#17426](https://github.com/apache/datafusion/pull/17426) (adamreeve) +- Re-export apache-avro when avro feature flag is set [#17388](https://github.com/apache/datafusion/pull/17388) (shivbhatia10) +- Improved experience when remote object store URL does not end in / [#17364](https://github.com/apache/datafusion/pull/17364) (xiedeyantu) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 51 dependabot[bot] + 31 Piotr Findeisen + 24 Adrian Garcia Badaracco + 21 Andrew Lamb + 14 Yongting You + 11 Chen Chongchen + 9 Berkay Şahin + 9 Ruihang Xia + 7 Jeffrey Vo + 7 Nuno Faria + 6 Oleks V + 6 Pepijn Van Eeckhoudt + 6 Qi Zhu + 5 ding-young + 4 Adam Reeve + 4 Alan Tang + 4 Jonathan Chen + 4 Liam Bao + 4 Peter Nguyen + 4 Raz Luvaton + 4 kosiew + 3 Adam Gutglick + 3 Cancai Cai + 3 Huaijin + 3 Peter L + 3 lorenarosati + 3 wiedld + 2 Brent Gardner + 2 Bruce Ritchie + 2 Marco Neumann + 2 Matthew Kim + 2 Sherin Jacob + 2 Thearas + 2 Tim Saucer + 2 miro + 2 xudong.w + 2 张林伟 + 1 Ajeeta Asthana + 1 Alex Huang + 1 Andrey Koshchiy + 1 Andy Grove + 1 Blake Orth + 1 Christian van der Loo + 1 Colin Marc + 1 Corwin Joy + 1 David Hewitt + 1 David López + 1 Dima + 1 Eugene Tolbakov + 1 Evgenii Glotov + 1 GPK + 1 Gabriel + 1 Geetansh Juneja + 1 Geoffrey Claude + 1 Haresh Khanna + 1 Jack Kleeman + 1 Jax Liu + 1 Jensen + 1 LB7666 + 1 Loakesh Indiran + 1 Marko Milenković + 1 Matt Matravers + 1 Nga Tran + 1 Parth Chandra + 1 Ronit Thummaluru + 1 Shehab Amin + 1 Shiv Bhatia + 1 Shruti Sharma + 1 Simon Vandel Sillesen + 1 Stuart Carnie + 1 Tobias Schwarzinger + 1 Yuhan Wang + 1 ZC + 1 Zhen Wang + 1 aditya singh rathore + 1 ayemjay + 1 delamarch3 + 1 mwish + 1 theirix +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/50.1.0.md b/dev/changelog/50.1.0.md new file mode 100644 index 0000000000000..e4ead4cb456c3 --- /dev/null +++ b/dev/changelog/50.1.0.md @@ -0,0 +1,47 @@ + + +# Apache DataFusion 50.1.0 Changelog + +This release consists of 4 commits from 4 contributors. See credits at the end of this changelog for more information. + +See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) for information on how to upgrade from previous versions. + +**Other:** + +- [branch-50] fix: ignore non-existent columns when adding filter equivalence info in FileScanConfig (#17546) [#17600](https://github.com/apache/datafusion/pull/17600) (rkrishn7) +- [branch-50] fix: Ensure the CachedParquetFileReader respects the metadata prefetch hint (#17302) [#17613](https://github.com/apache/datafusion/pull/17613) (shehabgamin) +- [branch-50] Partial AggregateMode will generate duplicate field names which will fail DFSchema construct to branch-50 [#17717](https://github.com/apache/datafusion/pull/17717) (zhuqi-lucas) +- [branch-50]: feat: expose `udafs` and `udwfs` methods on `FunctionRegistry` (#17650) [#17725](https://github.com/apache/datafusion/pull/17725) (milenkovicm) +- [branch-50] Backport change to avoid debug symbols in ci builds to 50.0.0 [#17795](https://github.com/apache/datafusion/pull/17795) (alamb) +- [branch-50] Backport Prevent exponential planning time for Window functions - v2 #17684 [#17778](https://github.com/apache/datafusion/pull/17778) (alamb) +- [branch-50] Fix potential overflow when we print verbose physical plan [#17804](https://github.com/apache/datafusion/pull/17804) (zhuqi-lucas) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 1 Marko Milenković + 1 Qi Zhu + 1 Rohan Krishnaswamy + 1 Shehab Amin +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/50.2.0.md b/dev/changelog/50.2.0.md new file mode 100644 index 0000000000000..6d16ace832ab7 --- /dev/null +++ b/dev/changelog/50.2.0.md @@ -0,0 +1,43 @@ + + +# Apache DataFusion 50.2.0 Changelog + +This release consists of 3 commits from 1 contributors. See credits at the end of this changelog for more information. + +See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) for information on how to upgrade from previous versions. + +**Documentation updates:** + +- [branch-50] Backport: fix typos & pin action hashes (#17855) [#17892](https://github.com/apache/datafusion/pull/17892) (AdamGS) + +**Other:** + +- [branch-50] Backport: Fix docs.rs build: Replace auto_doc_cfg with doc_cfg [#17890](https://github.com/apache/datafusion/pull/17890) (AdamGS) +- [branch-50] Backport: `avg(distinct)` support for decimal types (#17560) [#17885](https://github.com/apache/datafusion/pull/17885) (AdamGS) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 3 Adam Gutglick +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/release/README.md b/dev/release/README.md index 6e4079de8f069..d70e256f73831 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -36,8 +36,8 @@ If you would like to propose your change for inclusion in a release branch for a patch release: 1. Find (or create) the issue for the incremental release ([example release issue]) and discuss the proposed change there with the maintainers. -1. Follow normal workflow to create PR to `main` branch and wait for its approval and merge. -1. After PR is squash merged to `main`, branch from most recent release branch (e.g. `branch-37`), cherry-pick the commit and create a PR targeting the release branch [example backport PR]. +2. Follow normal workflow to create PR to `main` branch and wait for its approval and merge. +3. After PR is squash merged to `main`, branch from most recent release branch (e.g. `branch-37`), cherry-pick the commit and create a PR targeting the release branch [example backport PR]. For example, to backport commit `12345` from `main` to `branch-43`: @@ -128,6 +128,15 @@ release. See [#9697](https://github.com/apache/datafusion/pull/9697) for an example. +Modify `asf.yaml` to protect future release candidate branch to prevent accidental merges: + +```yaml +# needs to be updated as part of the release process +branch-50: + required_pull_request_reviews: + required_approving_review_count: 1 +``` + Here are the commands that could be used to prepare the `38.0.0` release: ### Update Version @@ -271,6 +280,7 @@ Verify that the Cargo.toml in the tarball contains the correct version (cd datafusion/execution && cargo publish) (cd datafusion/functions && cargo publish) (cd datafusion/physical-expr && cargo publish) +(cd datafusion/physical-expr-adapter && cargo publish) (cd datafusion/functions-aggregate && cargo publish) (cd datafusion/functions-window && cargo publish) (cd datafusion/functions-nested && cargo publish) @@ -278,63 +288,31 @@ Verify that the Cargo.toml in the tarball contains the correct version (cd datafusion/optimizer && cargo publish) (cd datafusion/common-runtime && cargo publish) (cd datafusion/physical-plan && cargo publish) +(cd datafusion/pruning && cargo publish) (cd datafusion/physical-optimizer && cargo publish) -(cd datafusion/catalog && cargo publish) +(cd datafusion/session && cargo publish) (cd datafusion/datasource && cargo publish) +(cd datafusion/catalog && cargo publish) (cd datafusion/catalog-listing && cargo publish) (cd datafusion/functions-table && cargo publish) +(cd datafusion/datasource-csv && cargo publish) +(cd datafusion/datasource-json && cargo publish) +(cd datafusion/datasource-parquet && cargo publish) (cd datafusion/core && cargo publish) (cd datafusion/proto-common && cargo publish) (cd datafusion/proto && cargo publish) +(cd datafusion/datasource-avro && cargo publish) (cd datafusion/substrait && cargo publish) (cd datafusion/ffi && cargo publish) (cd datafusion-cli && cargo publish) +(cd datafusion/spark && cargo publish) (cd datafusion/sqllogictest && cargo publish) ``` ### Publish datafusion-cli on Homebrew -Run `publish_homebrew.sh` to publish `datafusion-cli` on Homebrew. In order to do so it is necessary to -fork the `homebrew-core` repo https://github.com/Homebrew/homebrew-core/, have Homebrew installed on your -macOS/Linux/WSL2 and properly configured and have a Github Personal Access Token that has permission to file pull requests in the `homebrew-core` repo. - -#### Fork the `homebrew-core` repo - -Go to https://github.com/Homebrew/homebrew-core/ and fork the repo. - -#### Install and configure Homebrew - -Please visit https://brew.sh/ to obtain Homebrew. In addition to that please check out https://docs.brew.sh/Homebrew-on-Linux if you are on Linux or WSL2. - -Before running the script make sure that you can run the following command in your bash to make sure -that `brew` has been installed and configured properly: - -```shell -brew --version -``` - -#### Create a Github Personal Access Token - -To create a Github Personal Access Token, please visit https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token for instructions. - -- Make sure to select either **All repositories** or **Only selected repositories** so that you have access to **Repository permissions**. -- If you only use the token for selected repos make sure you include your - fork of `homebrew-core` in the list of repos under **Selected repositories**. -- Make sure to have **Read and write** access enabled for pull requests in your **Repository permissions**. - -After all of the above is complete execute the following command: - -```shell -dev/release/publish_homebrew.sh -``` - -Note that sometimes someone else has already submitted a PR to update the datafusion formula in homebrew. -In this case you will get an error with a message that your PR is a duplicate of an existing one. In this -case no further action is required. - -Alternatively manually submit a simple PR to update tag and commit hash for the datafusion -formula in homebrew-core. Here is an example PR: -https://github.com/Homebrew/homebrew-core/pull/89562. +[`datafusion` formula](https://formulae.brew.sh/formula/datafusion) is [updated automatically](https://github.com/Homebrew/homebrew-core/pulls?q=is%3Apr+datafusion+is%3Aclosed), +so no action is needed. ### Call the vote diff --git a/dev/release/generate-changelog.py b/dev/release/generate-changelog.py index 1349416bcaa59..830d329f73c4f 100755 --- a/dev/release/generate-changelog.py +++ b/dev/release/generate-changelog.py @@ -124,6 +124,9 @@ def generate_changelog(repo, repo_name, tag1, tag2, version): print(f"This release consists of {commit_count} commits from {contributor_count} contributors. " f"See credits at the end of this changelog for more information.\n") + print("See the [upgrade guide](https://datafusion.apache.org/library-user-guide/upgrading.html) " + "for information on how to upgrade from previous versions.\n") + print_pulls(repo_name, "Breaking changes", breaking) print_pulls(repo_name, "Performance related", performance) print_pulls(repo_name, "Implemented enhancements", enhancements) diff --git a/dev/release/publish_homebrew.sh b/dev/release/publish_homebrew.sh deleted file mode 100644 index 20955953e85a7..0000000000000 --- a/dev/release/publish_homebrew.sh +++ /dev/null @@ -1,92 +0,0 @@ -#!/usr/bin/env bash -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -set -ue - -if [ "$#" -ne 4 ]; then - echo "Usage: $0 " - exit 1 -fi - -version=$1 -github_user=$2 -github_token=$3 -# Prepare for possible renaming of the default branch on Homebrew -homebrew_default_branch_name=$4 - -# Git parallel fetch -if sysctl -n hw.ncpu 2>/dev/null; then # macOS - num_processing_units=$(sysctl -n hw.ncpu) -elif [ -x "$(command -v nproc)" ]; then # Linux - num_processing_units=$(nproc) -else # Fallback - num_processing_units=1 -fi - -url="https://www.apache.org/dyn/closer.lua?path=datafusion/datafusion-${version}/apache-datafusion-${version}.tar.gz" -sha256="$(curl https://dist.apache.org/repos/dist/release/datafusion/datafusion-${version}/apache-datafusion-${version}.tar.gz.sha256 | cut -d' ' -f1)" - -pushd "$(brew --repository homebrew/core)" - -if ! git remote | grep -q --fixed-strings ${github_user}; then - echo "Setting ''${github_user}' remote" - git remote add ${github_user} git@github.com:${github_user}/homebrew-core.git -fi - -echo "Updating working copy" -git fetch --all --prune --tags --force -j$num_processing_units - -branch=apache-datafusion-${version} -echo "Creating branch: ${branch}" -git branch -D ${branch} || : -git checkout -b ${branch} origin/master - -echo "Updating datafusion formulae" -brew bump-formula-pr \ - --commit \ - --no-audit \ - --sha256="${sha256}" \ - --url="${url}" \ - --verbose \ - --write-only \ - datafusion - -echo "Testing datafusion formulae" -brew uninstall datafusion || : -brew install --build-from-source datafusion -brew test datafusion -brew audit --strict datafusion - -git push -u $github_user ${branch} - -git checkout - - -popd - -echo "Create the pull request" -title="datafusion ${version}" -body="Created using \`bump-formula-pr\`" -data="{\"title\":\"$title\", \"body\":\"$body\", \"head\":\"$github_username:$branch\", \"base\":\"$homebrew_default_branch_name\"}" -curl -X POST \ - -H "Accept: application/vnd.github+json" \ - -H "Authorization: Bearer $github_token" \ - https://api.github.com/repos/Homebrew/homebrew-core/pulls \ - -d "$data" - -echo "Complete!" diff --git a/dev/rust_lint.sh b/dev/rust_lint.sh index af0fce72ccfa5..8fe7220085c93 100755 --- a/dev/rust_lint.sh +++ b/dev/rust_lint.sh @@ -20,13 +20,21 @@ # This script runs all the Rust lints locally the same way the # DataFusion CI does +# For `.toml` format checking set -e if ! command -v taplo &> /dev/null; then echo "Installing taplo using cargo" cargo install taplo-cli fi +# For Apache licence header checking +if ! command -v hawkeye &> /dev/null; then + echo "Installing hawkeye using cargo" + cargo install hawkeye --locked +fi + ci/scripts/rust_fmt.sh ci/scripts/rust_clippy.sh ci/scripts/rust_toml_fmt.sh ci/scripts/rust_docs.sh +ci/scripts/license_header.sh \ No newline at end of file diff --git a/dev/update_config_docs.sh b/dev/update_config_docs.sh index 585cb77839f98..ed3e699c1413a 100755 --- a/dev/update_config_docs.sh +++ b/dev/update_config_docs.sh @@ -25,6 +25,8 @@ cd "${SOURCE_DIR}/../" && pwd TARGET_FILE="docs/source/user-guide/configs.md" PRINT_CONFIG_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_config_docs" +PRINT_RUNTIME_CONFIG_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_runtime_config_docs" + echo "Inserting header" cat <<'EOF' > "$TARGET_FILE" @@ -48,28 +50,134 @@ cat <<'EOF' > "$TARGET_FILE" --> # Configuration Settings -The following configuration options can be passed to `SessionConfig` to control various aspects of query execution. +DataFusion configurations control various aspects of DataFusion planning and execution + +## Setting Configuration Options + +### Programmatically +You can set the options programmatically via the [`ConfigOptions`] object. For +example, to configure the `datafusion.execution.target_partitions` using the API: + +```rust +use datafusion::common::config::ConfigOptions; +let mut config = ConfigOptions::new(); +config.execution.target_partitions = 1; +``` + +### Via Environment Variables + +You can also set configuration options via environment variables using +[`ConfigOptions::from_env`], for example + +```shell +DATAFUSION_EXECUTION_TARGET_PARTITIONS=1 ./your_program +``` + +### Via SQL + +You can also set configuration options via SQL using the `SET` command. For +example, to configure `datafusion.execution.target_partitions`: + +```sql +SET datafusion.execution.target_partitions = '1'; +``` -For applications which do not expose `SessionConfig`, like `datafusion-cli`, these options may also be set via environment variables. -To construct a session with options from the environment, use `SessionConfig::from_env`. -The name of the environment variable is the option's key, transformed to uppercase and with periods replaced with underscores. -For example, to configure `datafusion.execution.batch_size` you would set the `DATAFUSION_EXECUTION_BATCH_SIZE` environment variable. -Values are parsed according to the [same rules used in casts from Utf8](https://docs.rs/arrow/latest/arrow/compute/kernels/cast/fn.cast.html). -If the value in the environment variable cannot be cast to the type of the configuration option, the default value will be used instead and a warning emitted. -Environment variables are read during `SessionConfig` initialisation so they must be set beforehand and will not affect running sessions. +[`ConfigOptions`]: https://docs.rs/datafusion/latest/datafusion/common/config/struct.ConfigOptions.html +[`ConfigOptions::from_env`]: https://docs.rs/datafusion/latest/datafusion/common/config/struct.ConfigOptions.html#method.from_env + +The following configuration settings are available: EOF echo "Running CLI and inserting config docs table" $PRINT_CONFIG_DOCS_COMMAND >> "$TARGET_FILE" +echo "Inserting runtime config header" +cat <<'EOF' >> "$TARGET_FILE" + +# Runtime Configuration Settings + +DataFusion runtime configurations can be set via SQL using the `SET` command. + +For example, to configure `datafusion.runtime.memory_limit`: + +```sql +SET datafusion.runtime.memory_limit = '2G'; +``` + +The following runtime configuration settings are available: + +EOF + +echo "Running CLI and inserting runtime config docs table" +$PRINT_RUNTIME_CONFIG_DOCS_COMMAND >> "$TARGET_FILE" + +cat <<'EOF' >> "$TARGET_FILE" + +# Tuning Guide + +## Short Queries + +By default DataFusion will attempt to maximize parallelism and use all cores -- +For example, if you have 32 cores, each plan will split the data into 32 +partitions. However, if your data is small, the overhead of splitting the data +to enable parallelization can dominate the actual computation. + +You can find out how many cores are being used via the [`EXPLAIN`] command and look +at the number of partitions in the plan. + +[`EXPLAIN`]: sql/explain.md + +The `datafusion.optimizer.repartition_file_min_size` option controls the minimum file size the +[`ListingTable`] provider will attempt to repartition. However, this +does not apply to user defined data sources and only works when DataFusion has accurate statistics. + +If you know your data is small, you can set the `datafusion.execution.target_partitions` +option to a smaller number to reduce the overhead of repartitioning. For very small datasets (e.g. less +than 1MB), we recommend setting `target_partitions` to 1 to avoid repartitioning altogether. + +```sql +SET datafusion.execution.target_partitions = '1'; +``` + +[`ListingTable`]: https://docs.rs/datafusion/latest/datafusion/datasource/listing/struct.ListingTable.html + +## Memory-limited Queries + +When executing a memory-consuming query under a tight memory limit, DataFusion +will spill intermediate results to disk. + +When the [`FairSpillPool`] is used, memory is divided evenly among partitions. +The higher the value of `datafusion.execution.target_partitions`, the less memory +is allocated to each partition, and the out-of-core execution path may trigger +more frequently, possibly slowing down execution. + +Additionally, while spilling, data is read back in `datafusion.execution.batch_size` size batches. +The larger this value, the fewer spilled sorted runs can be merged. Decreasing this setting +can help reduce the number of subsequent spills required. + +In conclusion, for queries under a very tight memory limit, it's recommended to +set `target_partitions` and `batch_size` to smaller values. + +```sql +-- Query still gets parallelized, but each partition will have more memory to use +SET datafusion.execution.target_partitions = 4; +-- Smaller than the default '8192', while still keep the benefit of vectorized execution +SET datafusion.execution.batch_size = 1024; +``` + +[`FairSpillPool`]: https://docs.rs/datafusion/latest/datafusion/execution/memory_pool/struct.FairSpillPool.html + +EOF + + echo "Running prettier" npx prettier@2.3.2 --write "$TARGET_FILE" diff --git a/dev/update_function_docs.sh b/dev/update_function_docs.sh index a9e87aacf5ad1..6ed760bd22ff4 100755 --- a/dev/update_function_docs.sh +++ b/dev/update_function_docs.sh @@ -59,6 +59,25 @@ dev/update_function_docs.sh file for updating surrounding text. # Aggregate Functions Aggregate functions operate on a set of values to compute a single result. + +## Filter clause + +Aggregate functions support the SQL `FILTER (WHERE ...)` clause to restrict which input rows contribute to the aggregate result. + +```sql +function([exprs]) FILTER (WHERE condition) +``` + +Example: + +```sql +SELECT + sum(salary) FILTER (WHERE salary > 0) AS sum_positive_salaries, + count(*) FILTER (WHERE active) AS active_count +FROM employees; +``` + +Note: When no rows pass the filter, `COUNT` returns `0` while `SUM`/`AVG`/`MIN`/`MAX` return `NULL`. EOF echo "Running CLI and inserting aggregate function docs table" @@ -266,6 +285,17 @@ where **offset** is an non-negative integer. RANGE and GROUPS modes require an ORDER BY clause (with RANGE the ORDER BY must specify exactly one column). +## Filter clause for aggregate window functions + +Aggregate window functions support the SQL `FILTER (WHERE ...)` clause to include only rows that satisfy the predicate from the window frame in the aggregation. + +```sql +sum(salary) FILTER (WHERE salary > 0) + OVER (PARTITION BY depname ORDER BY salary ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +``` + +If no rows in the frame satisfy the filter for a given output row, `COUNT` yields `0` while `SUM`/`AVG`/`MIN`/`MAX` yield `NULL`. + ## Aggregate functions All [aggregate functions](aggregate_functions.md) can be used as window functions. diff --git a/docs/.gitignore b/docs/.gitignore index e2a54c053edf9..a3adddc690ab0 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -build -temp +temp/ +build/ venv/ .python-version +__pycache__/ diff --git a/docs/Makefile b/docs/Makefile index 6bce19911da5b..20ccd822f59c7 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -21,7 +21,7 @@ # You can set these variables from the command line, and also # from the environment for the first two. -SPHINXOPTS ?= +SPHINXOPTS ?= -W SPHINXBUILD ?= sphinx-build SOURCEDIR = source BUILDDIR = build diff --git a/docs/README.md b/docs/README.md index acf3cb754c008..c3d87ee8e84a3 100644 --- a/docs/README.md +++ b/docs/README.md @@ -28,22 +28,36 @@ https://datafusion.apache.org/ as part of the release process. It's recommended to install build dependencies and build the documentation inside a Python virtualenv. -- Python -- `pip install -r requirements.txt` +```sh +python3 -m venv venv +pip install -r requirements.txt +``` + +If using [uv](https://docs.astral.sh/uv/) the script can be run like so without +needing to create a virtual environment: + +```sh +uv run --with-requirements requirements.txt bash build.sh +``` ## Build & Preview Run the provided script to build the HTML pages. ```bash +# If using venv, ensure you have activated it ./build.sh ``` -The HTML will be generated into a `build` directory. +The HTML will be generated into a `build` directory. Open `build/html/index.html` +in your preferred browser, e.g. Preview the site on Linux by running this command. ```bash +# On macOS +open build/html/index.html +# On Linux with Firefox firefox build/html/index.html ``` diff --git a/docs/build.sh b/docs/build.sh index 73516e8e9c68c..9e4a118580cab 100755 --- a/docs/build.sh +++ b/docs/build.sh @@ -20,12 +20,5 @@ set -e rm -rf build 2> /dev/null -rm -rf temp 2> /dev/null -mkdir temp -cp -rf source/* temp/ -# replace relative URLs with absolute URLs -sed -i -e 's/\.\.\/\.\.\/\.\.\//https:\/\/github.com\/apache\/arrow-datafusion\/blob\/main\//g' temp/contributor-guide/index.md -python rustdoc_trim.py - -make SOURCEDIR=`pwd`/temp SPHINXOPTS=-W html +make html diff --git a/docs/make.bat b/docs/make.bat index ded5b4a3e2b67..33e25e4ee4651 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -23,7 +23,8 @@ REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build -) +) +set SPHINXOPTS=-W set SOURCEDIR=source set BUILDDIR=build diff --git a/docs/requirements.txt b/docs/requirements.txt index bd030fb670446..78206d2c19866 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. -sphinx +sphinx==8.2.3 +sphinx-reredirects==1.0.0 pydata-sphinx-theme==0.8.0 -myst-parser -maturin -jinja2 -setuptools>=48.0.0 +myst-parser==4.0.1 +maturin==1.9.6 +jinja2==3.1.6 +setuptools==80.9.0 diff --git a/docs/rustdoc_trim.py b/docs/rustdoc_trim.py index 7ea96dbb44a54..70becc45ee760 100644 --- a/docs/rustdoc_trim.py +++ b/docs/rustdoc_trim.py @@ -16,8 +16,7 @@ # under the License. import re - -from pathlib import Path +from sphinx.application import Sphinx # Regex pattern to match Rust code blocks in Markdown RUST_CODE_BLOCK_PATTERN = re.compile(r"```rust\s*(.*?)```", re.DOTALL) @@ -46,30 +45,16 @@ def _process_code_block(match): return RUST_CODE_BLOCK_PATTERN.sub(_process_code_block, markdown_content) -# Example usage -def process_markdown_file(file_path): - # Read the Markdown file - with open(file_path, "r", encoding="utf-8") as file: - markdown_content = file.read() - +def process_source_file(app: Sphinx, docname: str, source: list[str]): + original_content = source[0] # Remove lines starting with '#' in Rust code blocks - updated_markdown_content = remove_hashtag_lines_in_rust_blocks(markdown_content) - - # Write the updated content back to the Markdown file - with open(file_path, "w", encoding="utf-8") as file: - file.write(updated_markdown_content) - - print(f"Done processing file: {file_path}") - - -root_directory = Path("./temp/library-user-guide") -for file_path in root_directory.rglob("*.md"): - print(f"Processing file: {file_path}") - process_markdown_file(file_path) + modified_content = remove_hashtag_lines_in_rust_blocks(original_content) + source[0] = modified_content -root_directory = Path("./temp/user-guide") -for file_path in root_directory.rglob("*.md"): - print(f"Processing file: {file_path}") - process_markdown_file(file_path) -print("All Markdown files processed.") +def setup(app: Sphinx): + app.connect("source-read", process_source_file) + return dict( + parallel_read_safe=True, + parallel_write_safe=True, + ) diff --git a/docs/scripts/update_committer_list.py b/docs/scripts/update_committer_list.py new file mode 100755 index 0000000000000..c66eb52468523 --- /dev/null +++ b/docs/scripts/update_committer_list.py @@ -0,0 +1,266 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +""" +Utility for updating the committer list in the governance documentation +by reading from the Apache DataFusion phonebook and combining with existing data. +""" + +import re +import requests +import sys +import os +from typing import Dict, List, NamedTuple, Set + + +class Committer(NamedTuple): + name: str + apache: str + github: str + affiliation: str + role: str + + +# Return (pmc, committers) each a dictionary like +# key: apache id +# value: Real name + +def get_asf_roster(): + """Get the current roster from Apache phonebook.""" + # See https://home.apache.org/phonebook-about.html + committers_url = "https://whimsy.apache.org/public/public_ldap_projects.json" + + # people https://whimsy.apache.org/public/public_ldap_people.json + people_url = "https://whimsy.apache.org/public/public_ldap_people.json" + + try: + r = requests.get(committers_url) + r.raise_for_status() + j = r.json() + proj = j['projects']['datafusion'] + + # Get PMC members and committers + pmc_ids = set(proj['owners']) + committer_ids = set(proj['members']) - pmc_ids + + except Exception as e: + print(f"Error fetching ASF roster: {e}") + return set(), set() + + # Fetch people to get github handles and affiliations + # + # The data looks like this: + # { + # "lastCreateTimestamp": "20250913131506Z", + # "people_count": 9932, + # "people": { + # "a_budroni": { + # "name": "Alessandro Budroni", + # "createTimestamp": "20160720223917Z" + # }, + # ... + # } + try: + r = requests.get(people_url) + r.raise_for_status() + j = r.json() + people = j['people'] + + # make a dictionary with each pmc_id and value their real name + pmcs = {p: people[p]['name'] for p in pmc_ids} + committers = {c: people[c]['name'] for c in committer_ids} + + except Exception as e: + print(f"Error fetching ASF people: {e}") + + + return pmcs, committers + + + +def parse_existing_table(content: str) -> List[Committer]: + """Parse the existing committer table from the markdown content.""" + committers = [] + + # Find the table between the markers + start_marker = "" + end_marker = "" + + start_idx = content.find(start_marker) + end_idx = content.find(end_marker) + + if start_idx == -1 or end_idx == -1: + return committers + + table_content = content[start_idx:end_idx] + + # Parse table rows (skip header and separator) + lines = table_content.split('\n') + for line in lines: + line = line.strip() + if line.startswith('|') and '---' not in line and line.count('|') >= 4: + # Split by | and clean up + parts = [part.strip() for part in line.split('|')] + if len(parts) >= 5: + name = parts[1].strip() + apache = parts[2].strip() + github = parts[3].strip() + affiliation = parts[4].strip() + role = parts[5].strip() + + if name and name != 'Name' and (not '-----' in name): + committers.append(Committer(name, apache, github, affiliation, role)) + + return committers + + +def generate_table_row(committer: Committer) -> str: + """Generate a markdown table row for a committer.""" + github_link = f"[{committer.github}](https://github.com/{committer.github})" + return f"| {committer.name:<23} | {committer.apache:<39} |{committer.github:<39} | {committer.affiliation:<11} | {committer.role:<9} |" + + +def sort_committers(committers: List[Committer]) -> List[Committer]: + """Sort committers by role ('PMC Chair', PMC, Committer) then by apache id.""" + role_order = {'PMC Chair': 0, 'PMC': 1, 'Committer': 2} + + return sorted(committers, key=lambda c: (role_order.get(c.role, 3), c.apache.lower())) + + +def update_governance_file(file_path: str): + """Update the governance file with the latest committer information.""" + try: + with open(file_path, 'r') as f: + content = f.read() + except FileNotFoundError: + print(f"Error: File {file_path} not found") + return False + + # Parse existing committers + existing_committers = parse_existing_table(content) + print(f"Found {len(existing_committers)} existing committers") + + # Get ASF roster + asf_pmcs, asf_committers = get_asf_roster() + print(f"Found {len(asf_pmcs)} PMCs and {len(asf_committers)} committers in ASF roster") + + + # Create a map of existing committers by apache id + existing_by_apache = {c.apache: c for c in existing_committers} + + # Update the entries based on the ASF roster + updated_committers = [] + for apache_id, name in {**asf_pmcs, **asf_committers}.items(): + role = 'PMC' if apache_id in asf_pmcs else 'Committer' + if apache_id in existing_by_apache: + existing = existing_by_apache[apache_id] + # Preserve PMC Chair role if already set + if existing.role == 'PMC Chair': + role = 'PMC Chair' + updated_committers.append(Committer( + name=existing.name, + apache=apache_id, + github=existing.github, + affiliation=existing.affiliation, + role=role + )) + # add a new entry for new committers with placeholder values + else: + print(f"New entry found: {name} ({apache_id})") + # Placeholder github and affiliation + updated_committers.append(Committer( + name=name, + apache=apache_id, + github="", # user should update + affiliation="", # User should update + role=role + )) + + + # Sort the committers + sorted_committers = sort_committers(updated_committers) + + # Generate new table + table_lines = [ + "| Name | Apache ID | github | Affiliation | Role |", + "|-------------------------|-----------|----------------------------|-------------|-----------|" + ] + + for committer in sorted_committers: + table_lines.append(generate_table_row(committer)) + + new_table = '\n'.join(table_lines) + + # Replace the table in the content + start_marker = "" + end_marker = "" + + start_idx = content.find(start_marker) + end_idx = content.find(end_marker) + + if start_idx == -1 or end_idx == -1: + print("Error: Could not find table markers in file") + return False + + # Find the end of the start marker line + start_line_end = content.find('\n', start_idx) + 1 + + new_content = ( + content[:start_line_end] + + new_table + '\n' + + content[end_idx:] + ) + + # Write back to file + try: + with open(file_path, 'w') as f: + f.write(new_content) + print(f"Successfully updated {file_path}") + return True + except Exception as e: + print(f"Error writing file: {e}") + return False + + +def main(): + """Main function.""" + # Default path to governance file + script_dir = os.path.dirname(os.path.abspath(__file__)) + repo_root = os.path.dirname(script_dir) + governance_file = os.path.join(repo_root, "source", "contributor-guide", "governance.md") + + if len(sys.argv) > 1: + governance_file = sys.argv[1] + + if not os.path.exists(governance_file): + print(f"Error: Governance file not found at {governance_file}") + sys.exit(1) + + print(f"Updating committer list in {governance_file}") + + if update_governance_file(governance_file): + print("Committer list updated successfully") + else: + print("Failed to update committer list") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/docs/source/_static/theme_overrides.css b/docs/source/_static/theme_overrides.css index 3b1b86daac6aa..0859beb788aa4 100644 --- a/docs/source/_static/theme_overrides.css +++ b/docs/source/_static/theme_overrides.css @@ -84,3 +84,29 @@ Details: 8rem for search box etc*/ white-space: normal !important; } } + +/* Make wide tables scroll within the content area to avoid overlapping the + right sidebar. Prevents tables from bleeding underneath the sticky sidebar. */ +.bd-content table { + display: block; + overflow-x: auto; + -webkit-overflow-scrolling: touch; + max-width: 100%; +} + +/* Restore proper table display to maintain column alignment */ +.bd-content table thead, +.bd-content table tbody { display: table-row-group; } + +.bd-content table tr { display: table-row; } + +.bd-content table th, +.bd-content table td { + display: table-cell; + white-space: normal; +} + +/* Maintain striped styling when table scrolls */ +.bd-content table tbody tr:nth-of-type(odd) { + background-color: rgba(0, 0, 0, 0.03); +} diff --git a/docs/source/conf.py b/docs/source/conf.py index 00037867a0923..36556e74e69c4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,16 +26,17 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -# -# import os -# import sys -# sys.path.insert(0, os.path.abspath('.')) +import os +import sys + +# To pickup rustdoc_trim.py +sys.path.insert(0, os.path.abspath("..")) # -- Project information ----------------------------------------------------- -project = 'Apache DataFusion' -copyright = '2019-2025, Apache Software Foundation' -author = 'Apache Software Foundation' +project = "Apache DataFusion" +copyright = "2019-2025, Apache Software Foundation" +author = "Apache Software Foundation" # -- General configuration --------------------------------------------------- @@ -44,23 +45,25 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.doctest', - 'sphinx.ext.ifconfig', - 'sphinx.ext.mathjax', - 'sphinx.ext.viewcode', - 'sphinx.ext.napoleon', - 'myst_parser', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.ifconfig", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", + "sphinx.ext.napoleon", + "myst_parser", + "sphinx_reredirects", + "rustdoc_trim", ] source_suffix = { - '.rst': 'restructuredtext', - '.md': 'markdown', + ".rst": "restructuredtext", + ".md": "markdown", } # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -82,7 +85,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'pydata_sphinx_theme' +html_theme = "pydata_sphinx_theme" html_theme_options = { "use_edit_page_button": True, @@ -98,13 +101,11 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] html_logo = "_static/images/2x_bgwhite_original.png" -html_css_files = [ - "theme_overrides.css" -] +html_css_files = ["theme_overrides.css"] html_sidebars = { "**": ["docs-sidebar.html"], @@ -120,4 +121,9 @@ # presence of some special characters like: 🚀, å, {,... But this isn’t a major # issue for our documentation. So, suppress these warnings to keep our build # log cleaner. -suppress_warnings = ['misc.highlighting_failure'] +suppress_warnings = ["misc.highlighting_failure"] + +redirects = { + "library-user-guide/adding-udfs": "functions/index.html", + "user-guide/runtime_configs": "configs.html", +} diff --git a/docs/source/contributor-guide/development_environment.md b/docs/source/contributor-guide/development_environment.md index cd1b8ea356427..53f2eb97c6fb2 100644 --- a/docs/source/contributor-guide/development_environment.md +++ b/docs/source/contributor-guide/development_environment.md @@ -75,18 +75,20 @@ Alternatively a binary release can be downloaded from the [Release Page](https:/ DataFusion is written in Rust and it uses a standard rust toolkit: +- `rustup update stable` DataFusion generally uses the latest stable release of Rust, though it may lag when new Rust toolchains release + - See which toolchain is currently pinned in the [`rust-toolchain.toml`](https://github.com/apache/datafusion/blob/main/rust-toolchain.toml) file + - This can cause issues such as not having the rust-analyzer component installed for the specified toolchain, in which case just install it manually, e.g. `rustup component add --toolchain 1.88 rust-analyzer` - `cargo build` - `cargo fmt` to format the code -- `cargo test` to test - etc. -Note that running `cargo test` requires significant memory resources, due to cargo running many tests in parallel by default. If you run into issues with slow tests or system lock ups, you can significantly reduce the memory required by instead running `cargo test -- --test-threads=1`. For more information see [this issue](https://github.com/apache/datafusion/issues/5347). - Testing setup: -- `rustup update stable` DataFusion uses the latest stable release of rust - `git submodule init` - `git submodule update --init --remote --recursive` +- `cargo test` to run tests + +Note that running `cargo test` requires significant memory resources, due to cargo running many tests in parallel by default. If you run into issues with slow tests or system lock ups, you can significantly reduce the memory required by instead running `cargo test -- --test-threads=1`. For more information see [this issue](https://github.com/apache/datafusion/issues/5347). Formatting instructions: diff --git a/docs/source/contributor-guide/governance.md b/docs/source/contributor-guide/governance.md index 27ff90eb92c8d..857a82fa9613f 100644 --- a/docs/source/contributor-guide/governance.md +++ b/docs/source/contributor-guide/governance.md @@ -19,10 +19,6 @@ # Governance -The current PMC and committers are listed in the [Apache Phonebook]. - -[apache phonebook]: https://projects.apache.org/committee.html?datafusion - ## Overview DataFusion is part of the [Apache Software Foundation] and is governed following @@ -38,6 +34,87 @@ As much as practicable, we strive to make decisions by consensus, and anyone in the community is encouraged to propose ideas, start discussions, and contribute to the project. +## People + +DataFusion is currently governed by the following individuals + + + + + +| Name | Apache ID | github | Affiliation | Role | +| ----------------------- | ---------------- | ------------------------------------------------------- | -------------- | --------- | +| Andrew Lamb | alamb | [alamb](https://github.com/alamb) | InfluxData | PMC Chair | +| Andrew Grove | agrove | [andygrove](https://github.com/andygrove) | Apple | PMC | +| Mustafa Akur | akurmustafa | [akurmustafa](https://github.com/akurmustafa) | OHSU | PMC | +| Berkay Şahin | berkay | [berkaysynnada](https://github.com/berkaysynnada) | Synnada | PMC | +| Oleksandr Voievodin | comphead | [comphead](https://github.com/comphead) | Apple | PMC | +| Daniël Heres | dheres | [Dandandan](https://github.com/Dandandan) | | PMC | +| QP Hou | houqp | [houqp](https://github.com/houqp) | | PMC | +| Jie Wen | jakevin | [jackwener](https://github.com/jackwener) | | PMC | +| Jay Zhan | jayzhan | [jayzhan211](https://github.com/jayzhan211) | | PMC | +| Jonah Gao | jonah | [jonahgao](https://github.com/jonahgao) | | PMC | +| Kun Liu | liukun | [liukun4515](https://github.com/liukun4515) | | PMC | +| Mehmet Ozan Kabak | ozankabak | [ozankabak](https://github.com/ozankabak) | Synnada, Inc | PMC | +| Tim Saucer | timsaucer | [timsaucer](https://github.com/timsaucer) | | PMC | +| L. C. Hsieh | viirya | [viirya](https://github.com/viirya) | Databricks | PMC | +| Ruihang Xia | wayne | [waynexia](https://github.com/waynexia) | Greptime | PMC | +| Wes McKinney | wesm | [wesm](https://github.com/wesm) | Posit | PMC | +| Will Jones | wjones127 | [wjones127](https://github.com/wjones127) | LanceDB | PMC | +| Xudong Wang | xudong963 | [xudong963](https://github.com/xudong963) | Polygon.io | PMC | +| Adrian Garcia Badaracco | adriangb | [adriangb](https://github.com/adriangb) | Pydantic | Committer | +| Brent Gardner | avantgardner | [avantgardnerio](https://github.com/avantgardnerio) | Coralogix | Committer | +| Dmitrii Blaginin | blaginin | [blaginin](https://github.com/blaginin) | SpiralDB | Committer | +| Piotr Findeisen | findepi | [findepi](https://github.com/findepi) | dbt Labs | Committer | +| Jax Liu | goldmedal | [goldmedal](https://github.com/goldmedal) | Canner | Committer | +| Huaxin Gao | huaxingao | [huaxingao](https://github.com/huaxingao) | | Committer | +| Ifeanyi Ubah | iffyio | [iffyio](https://github.com/iffyio) | Validio | Committer | +| Jeffrey Vo | jeffreyvo | [Jefffrey](https://github.com/Jefffrey) | | Committer | +| Liu Jiayu | jiayuliu | [jimexist](https://github.com/jimexist) | | Committer | +| Ruiqiu Cao | kamille | [Rachelint](https://github.com/Rachelint) | Tencent | Committer | +| Kazuyuki Tanimura | kazuyukitanimura | [kazuyukitanimura](https://github.com/kazuyukitanimura) | | Committer | +| Eduard Karacharov | korowa | [korowa](https://github.com/korowa) | | Committer | +| Siew Kam Onn | kosiew | [kosiew](https://github.com/kosiew) | | Committer | +| Lewis Zhang | linwei | [lewiszlw](https://github.com/lewiszlw) | diit.cn | Committer | +| Matt Butrovich | mbutrovich | [mbutrovich](https://github.com/mbutrovich) | Apple | Committer | +| Metehan Yildirim | mete | [metegenez](https://github.com/metegenez) | | Committer | +| Marko Milenković | milenkovicm | [milenkovicm](https://github.com/milenkovicm) | | Committer | +| Wang Mingming | mingmwang | [mingmwang](https://github.com/mingmwang) | | Committer | +| Michael Ward | mjward | [Michael-J-Ward ](https://github.com/Michael-J-Ward) | | Committer | +| Marco Neumann | mneumann | [crepererum](https://github.com/crepererum) | InfluxData | Committer | +| Zhong Yanghong | nju_yaho | [yahoNanJing](https://github.com/yahoNanJing) | | Committer | +| Paddy Horan | paddyhoran | [paddyhoran](https://github.com/paddyhoran) | Assured Allies | Committer | +| Parth Chandra | parthc | [parthchandra](https://github.com/parthchandra) | Apple | Committer | +| Rémi Dettai | rdettai | [rdettai](https://github.com/rdettai) | | Committer | +| Raz Luvaton | rluvaton | [rluvaton](https://github.com/rluvaton) | | Committer | +| Chao Sun | sunchao | [sunchao](https://github.com/sunchao) | OpenAI | Committer | +| Daniel Harris | thinkharderdev | [thinkharderdev](https://github.com/thinkharderdev) | Coralogix | Committer | +| Raphael Taylor-Davies | tustvold | [tustvold](https://github.com/tustvold) | | Committer | +| Zhen Wang | wangzhen | [wForget](https://github.com/wForget) | | Committer | +| Weijun Huang | weijun | [Weijun-H](https://github.com/Weijun-H) | OrbDB | Committer | +| Yang Jiang | yangjiang | [Ted-jiang](https://github.com/Ted-jiang) | Ebay | Committer | +| Yoav Cohen | ycohen | [yoavcloud](https://github.com/yoavcloud) | | Committer | +| Yijie Shen | yjshen | [yjshen](https://github.com/yjshen) | DataPelago | Committer | +| Yongting You | ytyou | [2010YOUY01](https://github.com/2010YOUY01) | Independent | Committer | +| Qi Zhu | zhuqi | [zhuqi-lucas](https://github.com/zhuqi-lucas) | Polygon.io | Committer | + + + +Note that the authoritative list of PMC and committers is the [Apache Phonebook] + +[apache phonebook]: https://projects.apache.org/committee.html?datafusion + ## Roles - **Contributors**: Anyone who contributes to the project, whether it be code, diff --git a/docs/source/contributor-guide/gsoc_application_guidelines.md b/docs/source/contributor-guide/gsoc/gsoc_application_guidelines_2025.md similarity index 99% rename from docs/source/contributor-guide/gsoc_application_guidelines.md rename to docs/source/contributor-guide/gsoc/gsoc_application_guidelines_2025.md index e8ca9703a5ddf..c127b4231b8e1 100644 --- a/docs/source/contributor-guide/gsoc_application_guidelines.md +++ b/docs/source/contributor-guide/gsoc/gsoc_application_guidelines_2025.md @@ -1,4 +1,4 @@ -# GSoC Application Guidelines +# GSoC Application Guidelines (2025) ## Introduction diff --git a/docs/source/contributor-guide/gsoc_project_ideas.md b/docs/source/contributor-guide/gsoc/gsoc_project_ideas_2025.md similarity index 99% rename from docs/source/contributor-guide/gsoc_project_ideas.md rename to docs/source/contributor-guide/gsoc/gsoc_project_ideas_2025.md index da6c24e2921b1..d81d9eb9adab5 100644 --- a/docs/source/contributor-guide/gsoc_project_ideas.md +++ b/docs/source/contributor-guide/gsoc/gsoc_project_ideas_2025.md @@ -1,4 +1,4 @@ -# GSoC Project Ideas +# GSoC Project Ideas (2025) ## Introduction diff --git a/docs/source/contributor-guide/gsoc/index.rst b/docs/source/contributor-guide/gsoc/index.rst new file mode 100644 index 0000000000000..10b0013e9b169 --- /dev/null +++ b/docs/source/contributor-guide/gsoc/index.rst @@ -0,0 +1,36 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +Google Summer of Code (GSOC) +============================ + +DataFusion has participated in +`Google Summer of Code (GSOC) `_ +since 2025. GSOC is a global program that offers students stipends to +write code for open source projects. + +If you are a interested in contributing to DataFusion, we encourage you +to apply. You can find more information about the application process and +project ideas in the sections below. + + +.. toctree:: + :maxdepth: 1 + + gsoc_application_guidelines_2025 + gsoc_project_ideas_2025 + diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index e38898db5a92a..383827893c70f 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -37,14 +37,15 @@ You can find how to setup build and testing environment [here](https://datafusio ## Finding and Creating Issues to Work On You can find a curated [good-first-issue] list to help you get started. +You can read about how we plan larger projects in the [Roadmap and Improvement Proposals](roadmap.md) section. [good-first-issue]: https://github.com/apache/datafusion/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22 ### Open Contribution and Assigning tickets DataFusion is an open contribution project, and thus there is no particular -project imposed deadline for completing any issue or any restriction on who can -work on an issue, nor how many people can work on an issue at the same time. +project imposed deadline for completing issues or restrictions on who can +work on an issue, nor limits to how many people can work on an issue at the same time. Contributors drive the project forward based on their own priorities and interests and thus you are free to work on any issue that interests you. @@ -62,52 +63,6 @@ unable to make progress you should unassign the issue by using the `unassign me` link at the top of the issue page (and ask for help if are stuck) so that someone else can get involved in the work. -### Discussing New Features - -If you plan to work on a new feature that doesn't have an existing ticket, it is -a good idea to open a ticket to discuss the feature. Advanced discussion often -helps avoid wasted effort by determining early if the feature is a good fit for -DataFusion before too much time is invested. Discussion on a ticket can help -gather feedback from the community and is likely easier to discuss than a 1000 -line PR. - -If you open a ticket and it doesn't get any response, you can try `@`-mentioning -recently active community members in the ticket to get their attention. - -### What Contributions are Good Fits? - -DataFusion is designed to be highly extensible, and many features can be -implemented as extensions without changes or additions to the core. Support for -new functions, data formats, and similar functionality can be added using those -extension APIs, and there are already many existing community supported -extensions listed in the [extensions list]. - -Query engines are complex pieces of software to develop and maintain. Given our -limited maintenance bandwidth, we try to keep the DataFusion core as simple and -focused as possible, while still satisfying the [design goal] of an easy to -start initial experience. - -With that in mind, contributions that meet the following criteria are more likely -to be accepted: - -1. Bug fixes for existing features -2. Test coverage for existing features -3. Documentation improvements / examples -4. Performance improvements to existing features (with benchmarks) -5. "Small" functional improvements to existing features (if they don't change existing behavior) -6. Additional APIs for extending DataFusion's capabilities -7. CI improvements - -Contributions that will likely involve more discussion (see Discussing New -Features above) prior to acceptance include: - -1. Major new functionality (even if it is part of the "standard SQL") -2. New functions, especially if they aren't part of "standard SQL" -3. New data sources (e.g. support for Apache ORC) - -[extensions list]: ../library-user-guide/extensions.md -[design goal]: https://docs.rs/datafusion/latest/datafusion/index.html#design-goals - # Developer's guide ## Pull Request Overview diff --git a/docs/source/contributor-guide/inviting.md b/docs/source/contributor-guide/inviting.md index c6ed2695cfc12..a61e16c9a65b7 100644 --- a/docs/source/contributor-guide/inviting.md +++ b/docs/source/contributor-guide/inviting.md @@ -175,8 +175,8 @@ Of course, you can decline and instead remain as a contributor, participating as you do now. A. This personal invitation is a chance for you to accept or decline -in private. Either way, please let us know in reply to the -private@datafusion.apache.org address only. +in private. Either way, please let us know in reply to this email, make sure to reply-all (it should send a copy to +private@datafusion.apache.org) for record keeping / log of the project. B. If you accept, the next step is to register an ICLA: @@ -232,8 +232,8 @@ Of course, you can decline and instead remain as a contributor, participating as you do now. This personal invitation is a chance for you to accept or decline -in private. Either way, please let us know in reply to the -private@datafusion.apache.org address only. We will have to request an +in private. Either way, please let us know in reply to this email, make sure to reply-all (it should send a copy to +private@datafusion.apache.org) for record keeping / log of the project. We will have to request an Apache account be created for you, so please let us know what user id you would prefer. @@ -275,14 +275,16 @@ probably find that you spend more time here. Of course, you can decline and instead remain as a contributor, participating as you do now. -If you accept, please let us know by replying to private@datafusion.apache.org. +If you accept, please let us know in reply to this email, make sure to reply-all (it should send a copy to +private@datafusion.apache.org) for record keeping / log of the project. ``` ## New PMC Members -See also the ASF instructions on [how to add a PMC member]. +This is a DataFusion specific cookbook for the Apache Software Foundation +instructions on [how to add a PMC member]. -[how to add a pmc member]: https://www.apache.org/dev/pmc.html#newpmc +[how to add a pmc member]: https://www.apache.org/dev/pmc.html#pmcmembers ### Step 1: Start a Discussion Thread @@ -333,29 +335,18 @@ Thanks, Your Name ``` -### Step 3: Send Notice to ASF Board - -The DataFusion PMC Chair then sends a NOTICE to `board@apache.org` (cc'ing -`private@`) like this: +If this vote succeeds, send a "RESULT" email to `private@` like this: ``` -To: board@apache.org -Cc: private@datafusion.apache.org -Subject: [NOTICE] $NEW_PMC_MEMBER to join DataFusion PMC - -DataFusion proposes to invite $NEW_PMC_MEMBER ($NEW_PMC_MEMBER_APACHE_ID) to join the PMC. - -The vote result is available here: -$VOTE_RESULT_URL +To: private@datafusion.apache.org +Subject: [RESULT][VOTE] $NEW_PMC_MEMBER for PMC -FYI: Full vote details: -$VOTE_URL +The vote carries with N +1 votes and no -1 votes. I will send an invitation ``` -### Step 4: Send invitation email +### Step 3: Send invitation email -Once, the PMC chair has confirmed that the email sent to `board@apache.org` has -made it to the archives, the Chair sends an invitation e-mail to the new PMC +Assuming the vote passes, the Chair sends an invitation e-mail to the new PMC member (cc'ing `private@`) like this: ``` @@ -405,11 +396,11 @@ With the expectation of your acceptance, welcome! The Apache DataFusion PMC ``` -### Step 5: Chair Promotes the Committer to PMC +### Step 4: Chair Promotes the Committer to PMC The PMC chair adds the user to the PMC using the [Whimsy Roster Tool]. -### Step 6: Announce and Celebrate the New PMC Member +### Step 5: Announce and Celebrate the New PMC Member Send an email such as the following to `dev@datafusion.apache.org` to celebrate: diff --git a/docs/source/contributor-guide/roadmap.md b/docs/source/contributor-guide/roadmap.md index 3d9c1ee371fe6..073682008047d 100644 --- a/docs/source/contributor-guide/roadmap.md +++ b/docs/source/contributor-guide/roadmap.md @@ -17,7 +17,7 @@ specific language governing permissions and limitations under the License. --> -# Roadmap +# Roadmap and Improvement Proposals The [project introduction](../user-guide/introduction) explains the overview and goals of DataFusion, and our development efforts largely @@ -25,102 +25,127 @@ align to that vision. ## Planning `EPIC`s -DataFusion uses [GitHub -issues](https://github.com/apache/datafusion/issues) to track -planned work. We collect related tickets using tracking issues labeled -with `[EPIC]` which contain discussion and links to more detailed items. - -Epics offer a high level roadmap of what the DataFusion -community is thinking about. The epics are not meant to restrict -possibilities, but rather help the community see where development is -headed, align our work, and inspire additional contributions. - -As this project is entirely driven by volunteers, we welcome -contributions for items not currently covered by epics. However, -before submitting a large PR, we strongly suggest and request you -start a conversation using a github issue or the -[dev@arrow.apache.org](mailto:dev@arrow.apache.org) mailing list to -make review efficient and avoid surprises. - -[The current list of `EPIC`s can be found here](https://github.com/apache/datafusion/issues?q=is%3Aissue+is%3Aopen+epic). - -# Quarterly Roadmap - -A quarterly roadmap will be published to give the DataFusion community -visibility into the priorities of the projects contributors. This roadmap is not -binding and we would welcome any/all contributions to help keep this list up to -date. - -## 2023 Q4 - -- Improve data output (`COPY`, `INSERT` and DataFrame) output capability [#6569](https://github.com/apache/datafusion/issues/6569) -- Implementation of `ARRAY` types and related functions [#6980](https://github.com/apache/datafusion/issues/6980) -- Write an industrial paper about DataFusion for SIGMOD [#6782](https://github.com/apache/datafusion/issues/6782) - -## 2022 Q2 - -### DataFusion Core - -- IO Improvements - - Reading, registering, and writing more file formats from both DataFrame API and SQL - - Additional options for IO including partitioning and metadata support -- Work Scheduling - - Improve predictability, observability and performance of IO and CPU-bound work - - Develop a more explicit story for managing parallelism during plan execution -- Memory Management - - Add more operators for memory limited execution -- Performance - - Incorporate row-format into operators such as aggregate - - Add row-format benchmarks - - Explore JIT-compiling complex expressions - - Explore LLVM for JIT, with inline Rust functions as the primary goal - - Improve performance of Sort and Merge using Row Format / JIT expressions -- Documentation - - General improvements to DataFusion website - - Publish design documents -- Streaming - - Create `StreamProvider` trait - -### Ballista - -- Make production ready - - Shuffle file cleanup - - Fill functional gaps between DataFusion and Ballista - - Improve task scheduling and data exchange efficiency - - Better error handling - - Task failure - - Executor lost - - Schedule restart - - Improve monitoring and logging - - Auto scaling support -- Support for multi-scheduler deployments. Initially for resiliency and fault tolerance but ultimately to support sharding for scalability and more efficient caching. -- Executor deployment grouping based on resource allocation - -### Extensions ([datafusion-contrib](https://github.com/datafusion-contrib)) - -### [DataFusion-Python](https://github.com/datafusion-contrib/datafusion-python) - -- Add missing functionality to DataFrame and SessionContext -- Improve documentation - -### [DataFusion-S3](https://github.com/datafusion-contrib/datafusion-objectstore-s3) - -- Create Python bindings to use with datafusion-python - -### [DataFusion-Tui](https://github.com/datafusion-contrib/datafusion-tui) - -- Create multiple SQL editors -- Expose more Context and query metadata -- Support new data sources - - BigTable, HDFS, HTTP APIs - -### [DataFusion-BigTable](https://github.com/datafusion-contrib/datafusion-bigtable) - -- Python binding to use with datafusion-python -- Timestamp range predicate pushdown -- Multi-threaded partition aware execution -- Production ready Rust SDK - -### [DataFusion-Streams](https://github.com/datafusion-contrib/datafusion-streams) - -- Create experimental implementation of `StreamProvider` trait +DataFusion uses [GitHub issues] to track planned work. We collect related +tickets using tracking issues marked with the `EPIC` label, containing +discussion and links to more detailed items: + +[github issues]: https://github.com/apache/datafusion/issues + +- [The current list of `EPIC`s can be found here.](https://github.com/apache/datafusion/issues?q=is%3Aissue%20state%3Aopen%20label%3AEPIC) + +- [The current list of `PROPOSAL EPIC` (that are not yet underway) can be found here.](https://github.com/apache/datafusion/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22PROPOSAL%20EPIC%22) + +Epics offer a high level roadmap of what the DataFusion community is thinking +about. The epics are not meant to restrict possibilities, but rather help +organize the community and make it easier to see where development is headed, +align our work, and inspire additional contributions. + +We also welcome contributions for items not covered by epics. However, before +submitting a large PR, we strongly suggest and request you start a conversation as described in [Discussing New Features](#discussing-new-features) below. + +[dev@arrow.apache.org]: mailto:dev@arrow.apache.org + +## Quarterly Roadmap + +The DataFusion roadmap is driven by the priorities of contributors rather than +any single organization or coordinating committee. We typically discuss our +roadmap using GitHub issues, approximately quarterly, and invite you to join the +discussion. + +For more information: + +1. [Search for issues labeled `roadmap`](https://github.com/apache/datafusion/issues?q=is%3Aissue%20%20%20roadmap) +2. [DataFusion Road Map: Q3-Q4 2025](https://github.com/apache/datafusion/issues/15878) +3. [2024 Q4 / 2025 Q1 Roadmap](https://github.com/apache/datafusion/issues/13274) + +## Improvement Proposals + +### Discussing New Features + +If you plan to work on a new feature that doesn't have an existing ticket, it is +a good idea to open one for discussion. Advanced discussion helps avoid wasted +effort by determining if the feature is a good fit for DataFusion before too +much time is invested. Discussion on a ticket can help gather feedback from the +community and is likely easier to discuss than a 1000 line PR. + +Maintainers will mark major proposals as `PROPOSED EPIC` to make them more +visible, but we are very limited on review bandwidth. If you open a ticket and it +doesn't get any response, try `@`-mentioning recently active community members +in the ticket, or [posting to the mailing list or Discord](communication.md). + +### Supervising Maintainers + +We have found that most successful epics have one or more "supervising +maintainers", a committer ([see here for current list]) who take the lead on +reviewing and committing PRs, helps with design, and coordinates and +communicates with the community. If you want to ship a large feature, we +recommend finding such maintainer upfront; otherwise, your PRs may +remain unreviewed for a very long time. + +Supervising maintainers have no additional formal authority and there is +currently no formal process for appointing, approving or tracking who has that +role for a given epic. Instead, we rely on discussion on the ticket or PR. +Helping complete an epic is a significant time commitment, so maintainers are +more likely to help features they are particularly interested in or align with +their own project's use of DataFusion. + +If you are willing to be a supervising maintainer for a feature, please say so +explicitly. If you are unsure, we suggest asking directly who is willing to take +the role, as it can be hard to tell sometimes whether a committer is simply +participating and giving general feedback. + +[see here for current list]: governance.md + +### What Contributions are Good Fits? + +DataFusion is designed to be highly extensible, and many features can be +implemented as extensions without changes or additions to the core. Support for +new functions, data formats, and similar functionality can be added using those +extension APIs, and there are already many existing community supported +extensions listed in the [extensions list]. + +Query engines are complex pieces of software to develop and maintain. Given our +limited maintenance bandwidth, we try to keep the DataFusion core as simple and +focused as possible, while still satisfying the [design goal] of an easy to +start initial experience. + +With that in mind, contributions that meet the following criteria are more likely +to be accepted: + +1. Bug fixes for existing features +2. Test coverage for existing features +3. Documentation improvements / examples +4. Performance improvements to existing features (with benchmarks) +5. "Small" functional improvements to existing features (if they don't change existing behavior) +6. Additional APIs for extending DataFusion's capabilities +7. CI improvements + +Contributions that will likely involve more discussion (see Discussing New +Features above) prior to acceptance include: + +1. Major new functionality (even if it is part of the "standard SQL") +2. New functions, especially if they aren't part of "standard SQL" +3. New data sources (e.g. support for Apache ORC) + +[extensions list]: ../library-user-guide/extensions.md +[design goal]: https://docs.rs/datafusion/latest/datafusion/index.html#design-goals + +### Design Build vs. Big Up Front Design + +Typically, the DataFusion community attacks large problems by solving them bit +by bit and refining a solution iteratively on the `main` branch as a series of +Pull Requests. This is different from projects which front-load the effort +with a more comprehensive design process. + +By "advancing the front" the community always makes tangible progress, and the strategy is +especially effective in a project that relies on individual contributors who may +not have the time or resources to invest in a large upfront design effort. +However, this "bit by bit approach" doesn't always succeed, and sometimes we get +stuck or go down the wrong path and then change directions. + +Our process necessarily results in imperfect solutions being the "state of the +code" in some cases, and larger visions are not yet fully realized. However, the +community is good at driving things to completion in the long run. If you see +something that needs improvement or an area that is not yet fully realized, +please consider submitting an issue or PR to improve it. We are always looking +for more contributions. diff --git a/docs/source/contributor-guide/testing.md b/docs/source/contributor-guide/testing.md index eeed2a0c5d76c..dd22e1236081a 100644 --- a/docs/source/contributor-guide/testing.md +++ b/docs/source/contributor-guide/testing.md @@ -75,23 +75,17 @@ cargo insta review In addition to the standard CI test suite that is run on all PRs prior to merge, DataFusion has "extended" tests (defined in [extended.yml]) that are run on each commit to `main`. These tests rarely fail but take significantly longer to run -than the standard test suite and add important test coverage such as that the -code works when there are hash collisions as well as running the relevant -portions of the entire [sqlite test suite]. +than the standard test suite and add important test coverage such as ensuring +correctness when there are hash collisions and running the relevant portions of +the entire [sqlite test suite]. You can run the extended tests +locally by following the [instructions in the documentation]. -You can run the extended tests on any PR by leaving the following comment (see [example here]): - -``` -Run extended tests -``` - -[extended.yml]: https://github.com/apache/datafusion/blob/main/.github/workflows/extended.yml [sqlite test suite]: https://www.sqlite.org/sqllogictest/dir?ci=tip -[example here]: https://github.com/apache/datafusion/pull/15427#issuecomment-2759160812 +[instructions in the documentation]: https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest#running-tests-sqlite ## Rust Integration Tests -There are several tests of the public interface of the DataFusion library in the [tests](https://github.com/apache/datafusion/tree/main/datafusion/core/tests) directory. +There are several public interface tests for the DataFusion library in the [tests](https://github.com/apache/datafusion/tree/main/datafusion/core/tests) directory. You can run these tests individually using `cargo` as normal command such as diff --git a/docs/source/index.rst b/docs/source/index.rst index 0dc947fdea579..574c285b0e65e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -49,16 +49,16 @@ The following related subprojects target end users and have separate documentati - `DataFusion Python `_ offers a Python interface for SQL and DataFrame queries. -- `DataFusion Ray `_ provides a distributed version of DataFusion - that scales out on `Ray `_ clusters. - `DataFusion Comet `_ is an accelerator for Apache Spark based on DataFusion. +- `DataFusion Ballista `_ is distributed processing extension for DataFusion. "Out of the box," DataFusion offers `SQL `_ and `Dataframe `_ APIs, excellent `performance `_, built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and a great community. `Python Bindings `_ are also available. +`Ballista `_ is Apache DataFusion extension enabling the parallelized execution of workloads across multiple nodes in a distributed environment. DataFusion features a full query planner, a columnar, streaming, multi-threaded, vectorized execution engine, and partitioned data sources. You can @@ -126,18 +126,19 @@ To get started, see :caption: Library User Guide library-user-guide/index + library-user-guide/upgrading library-user-guide/extensions library-user-guide/using-the-sql-api library-user-guide/working-with-exprs library-user-guide/using-the-dataframe-api library-user-guide/building-logical-plans library-user-guide/catalogs - library-user-guide/adding-udfs + library-user-guide/functions/index library-user-guide/custom-table-providers + library-user-guide/table-constraints library-user-guide/extending-operators library-user-guide/profiling library-user-guide/query-optimizer - library-user-guide/upgrading .. .. _toc.contributor-guide: @@ -156,8 +157,7 @@ To get started, see contributor-guide/governance contributor-guide/inviting contributor-guide/specification/index - contributor-guide/gsoc_application_guidelines - contributor-guide/gsoc_project_ideas + contributor-guide/gsoc/index .. _toc.subprojects: @@ -165,6 +165,6 @@ To get started, see :maxdepth: 1 :caption: DataFusion Subprojects - DataFusion Ballista + DataFusion Ballista DataFusion Comet DataFusion Python diff --git a/docs/source/library-user-guide/building-logical-plans.md b/docs/source/library-user-guide/building-logical-plans.md index e1e75b3e4bdbd..9dc0fcbf31578 100644 --- a/docs/source/library-user-guide/building-logical-plans.md +++ b/docs/source/library-user-guide/building-logical-plans.md @@ -153,9 +153,9 @@ Filter: person.id > Int32(500) [id:Int32;N, name:Utf8;N] Logical plans can not be directly executed. They must be "compiled" into an [`ExecutionPlan`], which is often referred to as a "physical plan". -Compared to `LogicalPlan`s `ExecutionPlans` have many more details such as -specific algorithms and detailed optimizations compared to. Given a -`LogicalPlan` the easiest way to create an `ExecutionPlan` is using +Compared to `LogicalPlan`s, `ExecutionPlan`s have many more details such as +specific algorithms and detailed optimizations. Given a +`LogicalPlan`, the easiest way to create an `ExecutionPlan` is using [`SessionState::create_physical_plan`] as shown below ```rust @@ -181,7 +181,7 @@ async fn main() -> Result<(), DataFusionError> { // TableProvider. For this example, we don't provide any data // but in production code, this would have `RecordBatch`es with // in memory data - let table_provider = Arc::new(MemTable::try_new(Arc::new(schema), vec![])?); + let table_provider = Arc::new(MemTable::try_new(Arc::new(schema), vec![vec![]])?); // Use the provider_as_source function to convert the TableProvider to a table source let table_source = provider_as_source(table_provider); @@ -220,7 +220,7 @@ However, it is more common to use a [TableProvider]. To get a [TableSource] from [logicaltablesource]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/builder/struct.LogicalTableSource.html [defaulttablesource]: https://docs.rs/datafusion/latest/datafusion/datasource/default_table_source/struct.DefaultTableSource.html [provider_as_source]: https://docs.rs/datafusion/latest/datafusion/datasource/default_table_source/fn.provider_as_source.html -[tableprovider]: https://docs.rs/datafusion/latest/datafusion/datasource/provider/trait.TableProvider.html +[tableprovider]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html [tablesource]: https://docs.rs/datafusion-expr/latest/datafusion_expr/trait.TableSource.html [`executionplan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html [`sessionstate::create_physical_plan`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html#method.create_physical_plan diff --git a/docs/source/library-user-guide/catalogs.md b/docs/source/library-user-guide/catalogs.md index 906039ba23003..d4e6633d40ba7 100644 --- a/docs/source/library-user-guide/catalogs.md +++ b/docs/source/library-user-guide/catalogs.md @@ -23,11 +23,14 @@ This section describes how to create and manage catalogs, schemas, and tables in ## General Concepts -CatalogProviderList, Catalogs, schemas, and tables are organized in a hierarchy. A CatalogProviderList contains catalog providers, a catalog provider contains schemas and a schema contains tables. +Catalog providers, catalogs, schemas, and tables are organized in a hierarchy. A `CatalogProviderList` contains `CatalogProvider`s, a `CatalogProvider` contains `SchemaProviders` and a `SchemaProvider` contains `TableProvider`s. DataFusion comes with a basic in memory catalog functionality in the [`catalog` module]. You can use these in memory implementations as is, or extend DataFusion with your own catalog implementations, for example based on local files or files on remote object storage. +DataFusion supports DDL queries (e.g. `CREATE TABLE`) using the catalog API described in this section. See the [TableProvider] section for information on DML queries (e.g. `INSERT INTO`). + [`catalog` module]: https://docs.rs/datafusion/latest/datafusion/catalog/index.html +[tableprovider]: ./custom-table-providers.md Similarly to other concepts in DataFusion, you'll implement various traits to create your own catalogs, schemas, and tables. The following sections describe the traits you'll need to implement. diff --git a/docs/source/library-user-guide/custom-table-providers.md b/docs/source/library-user-guide/custom-table-providers.md index 886ac96295662..695cb16ac8604 100644 --- a/docs/source/library-user-guide/custom-table-providers.md +++ b/docs/source/library-user-guide/custom-table-providers.md @@ -19,17 +19,25 @@ # Custom Table Provider -Like other areas of DataFusion, you extend DataFusion's functionality by implementing a trait. The `TableProvider` and associated traits, have methods that allow you to implement a custom table provider, i.e. use DataFusion's other functionality with your custom data source. +Like other areas of DataFusion, you extend DataFusion's functionality by implementing a trait. The [`TableProvider`] and associated traits allow you to implement a custom table provider, i.e. use DataFusion's other functionality with your custom data source. -This section will also touch on how to have DataFusion use the new `TableProvider` implementation. +This section describes how to create a [`TableProvider`] and how to configure DataFusion to use it for reading. + +For details on how table constraints such as primary keys or unique +constraints are handled, see [Table Constraint Enforcement](table-constraints.md). ## Table Provider and Scan -The `scan` method on the `TableProvider` is likely its most important. It returns an `ExecutionPlan` that DataFusion will use to read the actual data during execution of the query. +The [`TableProvider::scan`] method reads data from the table and is likely the most important. It returns an [`ExecutionPlan`] that DataFusion will use to read the actual data during execution of the query. The [`TableProvider::insert_into`] method is used to `INSERT` data into the table. ### Scan -As mentioned, `scan` returns an execution plan, and in particular a `Result>`. The core of this is returning something that can be dynamically dispatched to an `ExecutionPlan`. And as per the general DataFusion idea, we'll need to implement it. +As mentioned, [`TableProvider::scan`] returns an execution plan, and in particular a `Result>`. The core of this is returning something that can be dynamically dispatched to an `ExecutionPlan`. And as per the general DataFusion idea, we'll need to implement it. + +[`tableprovider`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html +[`tableprovider::scan`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html#tymethod.scan +[`tableprovider::insert_into`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html#tymethod.insert_into +[`executionplan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html #### Execution Plan diff --git a/docs/source/library-user-guide/extending-operators.md b/docs/source/library-user-guide/extending-operators.md index 631bdc67975a4..5c28d1e670586 100644 --- a/docs/source/library-user-guide/extending-operators.md +++ b/docs/source/library-user-guide/extending-operators.md @@ -19,4 +19,41 @@ # Extending DataFusion's operators: custom LogicalPlan and Execution Plans -Coming soon +DataFusion supports extension of operators by transforming logical plan and execution plan through customized [optimizer rules](https://docs.rs/datafusion/latest/datafusion/optimizer/trait.OptimizerRule.html). This section will use the µWheel project to illustrate such capabilities. + +## About DataFusion µWheel + +[DataFusion µWheel](https://github.com/uwheel/datafusion-uwheel/tree/main) is a native DataFusion optimizer which improves query performance for time-based analytics through fast temporal aggregation and pruning using custom indices. The integration of µWheel into DataFusion is a joint effort with the DataFusion community. + +### Optimizing Logical Plan + +The `rewrite` function transforms logical plans by identifying temporal patterns and aggregation functions that match the stored wheel indices. When match is found, it queries the corresponding index to retrieve pre-computed aggregate values, stores these results in a [MemTable](https://docs.rs/datafusion/latest/datafusion/datasource/memory/struct.MemTable.html), and returns as a new `LogicalPlan::TableScan`. If no match is found, the original plan proceeds unchanged through DataFusion's standard execution path. + +```rust,ignore +fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, +) -> Result> { + // Attempts to rewrite a logical plan to a uwheel-based plan that either provides + // plan-time aggregates or skips execution based on min/max pruning. + if let Some(rewritten) = self.try_rewrite(&plan) { + Ok(Transformed::yes(rewritten)) + } else { + Ok(Transformed::no(plan)) + } +} +``` + +```rust,ignore +// Converts a uwheel aggregate result to a TableScan with a MemTable as source +fn agg_to_table_scan(result: f64, schema: SchemaRef) -> Result { + let data = Float64Array::from(vec![result]); + let record_batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(data)])?; + let df_schema = Arc::new(DFSchema::try_from(schema.clone())?); + let mem_table = MemTable::try_new(schema, vec![vec![record_batch]])?; + mem_table_as_table_scan(mem_table, df_schema) +} +``` + +To get a deeper dive into the usage of the µWheel project, visit the [blog post](https://uwheel.rs/post/datafusion_uwheel/) by Max Meldrum. diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md similarity index 82% rename from docs/source/library-user-guide/adding-udfs.md rename to docs/source/library-user-guide/functions/adding-udfs.md index 8fb8a59fb8609..2335105882a10 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/functions/adding-udfs.md @@ -23,19 +23,29 @@ User Defined Functions (UDFs) are functions that can be used in the context of D This page covers how to add UDFs to DataFusion. In particular, it covers how to add Scalar, Window, and Aggregate UDFs. -| UDF Type | Description | Example | -| --------- | ---------------------------------------------------------------------------------------------------------- | ------------------- | -| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs][1] | -| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs][2] | -| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs][3] | -| Table | A function that takes parameters and returns a `TableProvider` to be used in an query plan. | [simple_udtf.rs][4] | +| UDF Type | Description | Example(s) | +| -------------- | ---------------------------------------------------------------------------------------------------------- | ------------------------------------- | +| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs] / [advanced_udf.rs] | +| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs] / [advanced_udwf.rs] | +| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs] / [advanced_udaf.rs] | +| Table | A function that takes parameters and returns a `TableProvider` to be used in an query plan. | [simple_udtf.rs] | +| Scalar (async) | A scalar function for performing `async` operations (such as network or I/O calls) within the UDF. | [async_udf.rs] | + +[simple_udf.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udf.rs +[advanced_udf.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +[simple_udwf.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs +[advanced_udwf.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs +[simple_udaf.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs +[advanced_udaf.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs +[simple_udtf.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udtf.rs +[async_udf.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/async_udf.rs First we'll talk about adding an Scalar UDF end-to-end, then we'll talk about the differences between the different types of UDFs. ## Adding a Scalar UDF -A Scalar UDF is a function that takes a row of data and returns a single value. In order for good performance +A Scalar UDF is a function that takes a row of data and returns a single value. To achieve good performance, such functions are "vectorized" in DataFusion, meaning they get one or more Arrow Arrays as input and produce an Arrow Array with the same number of rows as output. @@ -47,8 +57,8 @@ To create a Scalar UDF, you In the following example, we will add a function takes a single i64 and returns a single i64 with 1 added to it: -For brevity, we'll skipped some error handling, but e.g. you may want to check that `args.len()` is the expected number -of arguments. +For brevity, we'll skip some error handling. +For production code, you may want to check, for example, that `args.len()` matches the expected number of arguments. ### Adding by `impl ScalarUDFImpl` @@ -73,7 +83,7 @@ use datafusion_doc::Documentation; description = "Add one udf", syntax_example = "add_one(1)" )] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct AddOne { signature: Signature, } @@ -136,7 +146,7 @@ We now need to register the function with DataFusion so that it can be used in t # description = "Add one udf", # syntax_example = "add_one(1)" # )] -# #[derive(Debug)] +# #[derive(Debug, PartialEq, Eq, Hash)] # struct AddOne { # signature: Signature, # } @@ -344,6 +354,233 @@ async fn main() { } ``` +## Adding a Async Scalar UDF + +An Async Scalar UDF allows you to implement user-defined functions that support +asynchronous execution, such as performing network or I/O operations within the +UDF. + +To add a Scalar Async UDF, you need to: + +1. Implement the `AsyncScalarUDFImpl` trait to define your async function logic, signature, and types. +2. Wrap your implementation with `AsyncScalarUDF::new` and register it with the `SessionContext`. + +### Adding by `impl AsyncScalarUDFImpl` + +```rust +# use arrow::array::{ArrayIter, ArrayRef, AsArray, StringArray}; +# use arrow_schema::DataType; +# use async_trait::async_trait; +# use datafusion::common::error::Result; +# use datafusion::common::{internal_err, not_impl_err}; +# use datafusion::common::types::logical_string; +# use datafusion::config::ConfigOptions; +# use datafusion_expr::ScalarUDFImpl; +# use datafusion::logical_expr::async_udf::AsyncScalarUDFImpl; +# use datafusion::logical_expr::{ +# ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility, ScalarFunctionArgs +# }; +# use datafusion::logical_expr_common::signature::Coercion; +# use std::any::Any; +# use std::sync::Arc; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct AsyncUpper { + signature: Signature, +} + +impl Default for AsyncUpper { + fn default() -> Self { + Self::new() + } +} + +impl AsyncUpper { + pub fn new() -> Self { + Self { + signature: Signature::new( + TypeSignature::Coercible(vec![Coercion::Exact { + desired_type: TypeSignatureClass::Native(logical_string()), + }]), + Volatility::Volatile, + ), + } + } +} + +/// Implement the normal ScalarUDFImpl trait for AsyncUpper +#[async_trait] +impl ScalarUDFImpl for AsyncUpper { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "async_upper" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + // Note the normal invoke_with_args method is not called for Async UDFs + fn invoke_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> Result { + not_impl_err!("AsyncUpper can only be called from async contexts") + } +} + +/// The actual implementation of the async UDF +#[async_trait] +impl AsyncScalarUDFImpl for AsyncUpper { + fn ideal_batch_size(&self) -> Option { + Some(10) + } + + /// This method is called to execute the async UDF and is similar + /// to the normal `invoke_with_args` except it is `async`. + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + let value = &args.args[0]; + // This function simply implements a simple string to uppercase conversion + // but can be used for any async operation such as network calls. + let result = match value { + ColumnarValue::Array(array) => { + let string_array = array.as_string::(); + let iter = ArrayIter::new(string_array); + let result = iter + .map(|string| string.map(|s| s.to_uppercase())) + .collect::(); + Arc::new(result) as ArrayRef + } + _ => return internal_err!("Expected a string argument, got {:?}", value), + }; + Ok(ColumnarValue::from(result)) + } +} +``` + +We can now transfer the async UDF into the normal scalar using `into_scalar_udf` to register the function with DataFusion so that it can be used in the context of a query. + +```rust +# use arrow::array::{ArrayIter, ArrayRef, AsArray, StringArray}; +# use arrow_schema::DataType; +# use async_trait::async_trait; +# use datafusion::common::error::Result; +# use datafusion::common::{internal_err, not_impl_err}; +# use datafusion::common::types::logical_string; +# use datafusion::config::ConfigOptions; +# use datafusion_expr::ScalarUDFImpl; +# use datafusion::logical_expr::async_udf::AsyncScalarUDFImpl; +# use datafusion::logical_expr::{ +# ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility, ScalarFunctionArgs +# }; +# use datafusion::logical_expr_common::signature::Coercion; +# use log::trace; +# use std::any::Any; +# use std::sync::Arc; +# +# #[derive(Debug, PartialEq, Eq, Hash)] +# pub struct AsyncUpper { +# signature: Signature, +# } +# +# impl Default for AsyncUpper { +# fn default() -> Self { +# Self::new() +# } +# } +# +# impl AsyncUpper { +# pub fn new() -> Self { +# Self { +# signature: Signature::new( +# TypeSignature::Coercible(vec![Coercion::Exact { +# desired_type: TypeSignatureClass::Native(logical_string()), +# }]), +# Volatility::Volatile, +# ), +# } +# } +# } +# +# #[async_trait] +# impl ScalarUDFImpl for AsyncUpper { +# fn as_any(&self) -> &dyn Any { +# self +# } +# +# fn name(&self) -> &str { +# "async_upper" +# } +# +# fn signature(&self) -> &Signature { +# &self.signature +# } +# +# fn return_type(&self, _arg_types: &[DataType]) -> Result { +# Ok(DataType::Utf8) +# } +# +# fn invoke_with_args( +# &self, +# _args: ScalarFunctionArgs, +# ) -> Result { +# not_impl_err!("AsyncUpper can only be called from async contexts") +# } +# } +# +# #[async_trait] +# impl AsyncScalarUDFImpl for AsyncUpper { +# fn ideal_batch_size(&self) -> Option { +# Some(10) +# } +# +# async fn invoke_async_with_args( +# &self, +# args: ScalarFunctionArgs, +# ) -> Result { +# trace!("Invoking async_upper with args: {:?}", args); +# let value = &args.args[0]; +# let result = match value { +# ColumnarValue::Array(array) => { +# let string_array = array.as_string::(); +# let iter = ArrayIter::new(string_array); +# let result = iter +# .map(|string| string.map(|s| s.to_uppercase())) +# .collect::(); +# Arc::new(result) as ArrayRef +# } +# _ => return internal_err!("Expected a string argument, got {:?}", value), +# }; +# Ok(ColumnarValue::from(result)) +# } +# } +use datafusion::execution::context::SessionContext; +use datafusion::logical_expr::async_udf::AsyncScalarUDF; + +let async_upper = AsyncUpper::new(); +let udf = AsyncScalarUDF::new(Arc::new(async_upper)); +let mut ctx = SessionContext::new(); +ctx.register_udf(udf.into_scalar_udf()); +``` + +After registration, you can use these async UDFs directly in SQL queries, for example: + +```sql +SELECT async_upper('datafusion'); +``` + +For async UDF implementation details, see [`async_udf.rs`](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/async_udf.rs). + [`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html [`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html [`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html @@ -1076,7 +1313,7 @@ pub struct EchoFunction {} impl TableFunctionImpl for EchoFunction { fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { + let Some(Expr::Literal(ScalarValue::Int64(Some(value)), _)) = exprs.get(0) else { return plan_err!("First argument must be an integer"); }; @@ -1117,7 +1354,7 @@ With the UDTF implemented, you can register it with the `SessionContext`: # # impl TableFunctionImpl for EchoFunction { # fn call(&self, exprs: &[Expr]) -> Result> { -# let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { +# let Some(Expr::Literal(ScalarValue::Int64(Some(value)), _)) = exprs.get(0) else { # return plan_err!("First argument must be an integer"); # }; # @@ -1244,8 +1481,3 @@ async fn main() -> Result<()> { Ok(()) } ``` - -[1]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udf.rs -[2]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs -[3]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs -[4]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udtf.rs diff --git a/docs/source/library-user-guide/functions/index.rst b/docs/source/library-user-guide/functions/index.rst new file mode 100644 index 0000000000000..d6127446c2286 --- /dev/null +++ b/docs/source/library-user-guide/functions/index.rst @@ -0,0 +1,25 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +Functions +============= + +.. toctree:: + :maxdepth: 2 + + adding-udfs + spark diff --git a/docs/source/library-user-guide/functions/spark.md b/docs/source/library-user-guide/functions/spark.md new file mode 100644 index 0000000000000..c371ae1cb5a86 --- /dev/null +++ b/docs/source/library-user-guide/functions/spark.md @@ -0,0 +1,29 @@ + + +# Spark Compatible Functions + +The [`datafusion-spark`] crate provides Apache Spark-compatible expressions for +use with DataFusion. + +[`datafusion-spark`]: https://crates.io/crates/datafusion-spark + +Please see the documentation for the [`datafusion-spark` crate] for more details. + +[`datafusion-spark` crate]: https://docs.rs/datafusion-spark/latest/datafusion_spark/ diff --git a/docs/source/library-user-guide/profiling.md b/docs/source/library-user-guide/profiling.md index 61e848a2b7d9b..a2ea6723e55a7 100644 --- a/docs/source/library-user-guide/profiling.md +++ b/docs/source/library-user-guide/profiling.md @@ -48,7 +48,7 @@ Ensure that you're in the directory containing the necessary data files for your ### Step 3: Running the Flamegraph Tool -To generate a flamegraph, you'll need to use the -- separator to pass arguments to the binary you're profiling. For datafusion-cli, you need to make sure to run the command with sudo permissions (especially on macOS, where DTrace requires elevated privileges). +To generate a flamegraph, you'll need to use the `--` separator to pass arguments to the binary you're profiling. For datafusion-cli, you need to make sure to run the command with sudo permissions (especially on macOS, where DTrace requires elevated privileges). Here is a general example: diff --git a/docs/source/library-user-guide/query-optimizer.md b/docs/source/library-user-guide/query-optimizer.md index 03cd7b5bbbbe3..877ff8c754ad5 100644 --- a/docs/source/library-user-guide/query-optimizer.md +++ b/docs/source/library-user-guide/query-optimizer.md @@ -68,7 +68,7 @@ fn observer(plan: &LogicalPlan, rule: &dyn OptimizerRule) { ## Writing Optimization Rules Please refer to the -[optimizer_rule.rs](../../../datafusion-examples/examples/optimizer_rule.rs) +[optimizer_rule.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/optimizer_rule.rs) example to learn more about the general approach to writing optimizer rules and then move onto studying the existing rules. @@ -193,7 +193,7 @@ Looking at the `EXPLAIN` output we can see that the optimizer has effectively re `3 as "1 + 2"`: ```text -> explain select 1 + 2; +> explain format indent select 1 + 2; +---------------+-------------------------------------------------+ | plan_type | plan | +---------------+-------------------------------------------------+ @@ -428,7 +428,7 @@ Each of these statistics is wrapped in a `Precision` type that indicates whether exact or estimated, allowing the optimizer to make informed decisions about the reliability of its cardinality estimates. -### Boundary Analaysis Flow +### Boundary Analysis Flow The boundary analysis process flows through several stages, with each stage building upon the information gathered in previous stages. The `AnalysisContext` is continuously diff --git a/docs/source/library-user-guide/table-constraints.md b/docs/source/library-user-guide/table-constraints.md new file mode 100644 index 0000000000000..dea746463d234 --- /dev/null +++ b/docs/source/library-user-guide/table-constraints.md @@ -0,0 +1,42 @@ + + +# Table Constraint Enforcement + +Table providers can describe table constraints using the +[`TableConstraint`] and [`Constraints`] APIs. These constraints include +primary keys, unique keys, foreign keys and check constraints. + +DataFusion does **not** currently enforce these constraints at runtime. +They are provided for informational purposes and can be used by custom +`TableProvider` implementations or other parts of the system. + +- **Nullability**: The only property enforced by DataFusion is the + nullability of each [`Field`] in a schema. Returning data with null values + for Columns marked as not nullable will result in runtime errors during execution. DataFusion + does not check or enforce nullability when data is ingested. +- **Primary and unique keys**: DataFusion does not verify that the data + satisfies primary or unique key constraints. Table providers that + require this behaviour must implement their own checks. +- **Foreign keys and check constraints**: These constraints are parsed + but are not validated or used during query planning. + +[`tableconstraint`]: https://docs.rs/datafusion/latest/datafusion/sql/planner/enum.TableConstraint.html +[`constraints`]: https://docs.rs/datafusion/latest/datafusion/common/functional_dependencies/struct.Constraints.html +[`field`]: https://docs.rs/arrow/latest/arrow/datatype/struct.Field.html diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md index 11fd495665225..e93659872565b 100644 --- a/docs/source/library-user-guide/upgrading.md +++ b/docs/source/library-user-guide/upgrading.md @@ -19,6 +19,943 @@ # Upgrade Guides +## DataFusion `51.0.0` + +**Note:** DataFusion `51.0.0` has not been released yet. The information provided in this section pertains to features and changes that have already been merged to the main branch and are awaiting release in this version. + +You can see the current [status of the `51.0.0`release here](https://github.com/apache/datafusion/issues/17558) + +### `MSRV` updated to 1.87.0 + +The Minimum Supported Rust Version (MSRV) has been updated to [`1.87.0`]. + +[`1.87.0`]: https://releases.rs/docs/1.87.0/ + +### `FunctionRegistry` exposes two additional methods + +`FunctionRegistry` exposes two additional methods `udafs` and `udwfs` which expose set of registered user defined aggregation and window function names. To upgrade implement methods returning set of registered function names: + +```diff +impl FunctionRegistry for FunctionRegistryImpl { + fn udfs(&self) -> HashSet { + self.scalar_functions.keys().cloned().collect() + } ++ fn udafs(&self) -> HashSet { ++ self.aggregate_functions.keys().cloned().collect() ++ } ++ ++ fn udwfs(&self) -> HashSet { ++ self.window_functions.keys().cloned().collect() ++ } +} +``` + +### `datafusion-proto` use `TaskContext` rather than `SessionContext` in physical plan serde methods + +There have been changes in the public API methods of `datafusion-proto` which handle physical plan serde. + +Methods like `physical_plan_from_bytes`, `parse_physical_expr` and similar, expect `TaskContext` instead of `SessionContext` + +```diff +- let plan2 = physical_plan_from_bytes(&bytes, &ctx)?; ++ let plan2 = physical_plan_from_bytes(&bytes, &ctx.task_ctx())?; +``` + +as `TaskContext` contains `RuntimeEnv` methods such as `try_into_physical_plan` will not have explicit `RuntimeEnv` parameter. + +```diff +let result_exec_plan: Arc = proto +- .try_into_physical_plan(&ctx, runtime.deref(), &composed_codec) ++. .try_into_physical_plan(&ctx.task_ctx(), &composed_codec) +``` + +`PhysicalExtensionCodec::try_decode()` expects `TaskContext` instead of `FunctionRegistry`: + +```diff +pub trait PhysicalExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], +- registry: &dyn FunctionRegistry, ++ ctx: &TaskContext, + ) -> Result>; +``` + +See [issue #17601] for more details. + +[issue #17601]: https://github.com/apache/datafusion/issues/17601 + +## DataFusion `50.0.0` + +### ListingTable automatically detects Hive Partitioned tables + +DataFusion 50.0.0 automatically infers Hive partitions when using the `ListingTableFactory` and `CREATE EXTERNAL TABLE`. Previously, +when creating a `ListingTable`, datasets that use Hive partitioning (e.g. +`/table_root/column1=value1/column2=value2/data.parquet`) would not have the Hive columns reflected in +the table's schema or data. The previous behavior can be +restored by setting the `datafusion.execution.listing_table_factory_infer_partitions` configuration option to `false`. +See [issue #17049] for more details. + +[issue #17049]: https://github.com/apache/datafusion/issues/17049 + +### `MSRV` updated to 1.86.0 + +The Minimum Supported Rust Version (MSRV) has been updated to [`1.86.0`]. +See [#17230] for details. + +[`1.86.0`]: https://releases.rs/docs/1.86.0/ +[#17230]: https://github.com/apache/datafusion/pull/17230 + +### `ScalarUDFImpl`, `AggregateUDFImpl` and `WindowUDFImpl` traits now require `PartialEq`, `Eq`, and `Hash` traits + +To address error-proneness of `ScalarUDFImpl::equals`, `AggregateUDFImpl::equals`and +`WindowUDFImpl::equals` methods and to make it easy to implement function equality correctly, +the `equals` and `hash_value` methods have been removed from `ScalarUDFImpl`, `AggregateUDFImpl` +and `WindowUDFImpl` traits. They are replaced the requirement to implement the `PartialEq`, `Eq`, +and `Hash` traits on any type implementing `ScalarUDFImpl`, `AggregateUDFImpl` or `WindowUDFImpl`. +Please see [issue #16677] for more details. + +Most of the scalar functions are stateless and have a `signature` field. These can be migrated +using regular expressions + +- search for `\#\[derive\(Debug\)\](\n *(pub )?struct \w+ \{\n *signature\: Signature\,\n *\})`, +- replace with `#[derive(Debug, PartialEq, Eq, Hash)]$1`, +- review all the changes and make sure only function structs were changed. + +[issue #16677]: https://github.com/apache/datafusion/issues/16677 + +### `AsyncScalarUDFImpl::invoke_async_with_args` returns `ColumnarValue` + +In order to enable single value optimizations and be consistent with other +user defined function APIs, the `AsyncScalarUDFImpl::invoke_async_with_args` method now +returns a `ColumnarValue` instead of a `ArrayRef`. + +To upgrade, change the return type of your implementation + +```rust +# /* comment to avoid running +impl AsyncScalarUDFImpl for AskLLM { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + _option: &ConfigOptions, + ) -> Result { + .. + return array_ref; // old code + } +} +# */ +``` + +To return a `ColumnarValue` + +```rust +# /* comment to avoid running +impl AsyncScalarUDFImpl for AskLLM { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + _option: &ConfigOptions, + ) -> Result { + .. + return ColumnarValue::from(array_ref); // new code + } +} +# */ +``` + +See [#16896](https://github.com/apache/datafusion/issues/16896) for more details. + +### `ProjectionExpr` changed from type alias to struct + +`ProjectionExpr` has been changed from a type alias to a struct with named fields to improve code clarity and maintainability. + +**Before:** + +```rust,ignore +pub type ProjectionExpr = (Arc, String); +``` + +**After:** + +```rust,ignore +#[derive(Debug, Clone)] +pub struct ProjectionExpr { + pub expr: Arc, + pub alias: String, +} +``` + +To upgrade your code: + +- Replace tuple construction `(expr, alias)` with `ProjectionExpr::new(expr, alias)` or `ProjectionExpr { expr, alias }` +- Replace tuple field access `.0` and `.1` with `.expr` and `.alias` +- Update pattern matching from `(expr, alias)` to `ProjectionExpr { expr, alias }` + +This mainly impacts use of `ProjectionExec`. + +This change was done in [#17398] + +[#17398]: https://github.com/apache/datafusion/pull/17398 + +### `SessionState`, `SessionConfig`, and `OptimizerConfig` returns `&Arc` instead of `&ConfigOptions` + +To provide broader access to `ConfigOptions` and reduce required clones, some +APIs have been changed to return a `&Arc` instead of a +`&ConfigOptions`. This allows sharing the same `ConfigOptions` across multiple +threads without needing to clone the entire `ConfigOptions` structure unless it +is modified. + +Most users will not be impacted by this change since the Rust compiler typically +automatically dereference the `Arc` when needed. However, in some cases you may +have to change your code to explicitly call `as_ref()` for example, from + +```rust +# /* comment to avoid running +let optimizer_config: &ConfigOptions = state.options(); +# */ +``` + +To + +```rust +# /* comment to avoid running +let optimizer_config: &ConfigOptions = state.options().as_ref(); +# */ +``` + +See PR [#16970](https://github.com/apache/datafusion/pull/16970) + +### API Change to `AsyncScalarUDFImpl::invoke_async_with_args` + +The `invoke_async_with_args` method of the `AsyncScalarUDFImpl` trait has been +updated to remove the `_option: &ConfigOptions` parameter to simplify the API +now that the `ConfigOptions` can be accessed through the `ScalarFunctionArgs` +parameter. + +You can change your code like this + +```rust +# /* comment to avoid running +impl AsyncScalarUDFImpl for AskLLM { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + _option: &ConfigOptions, + ) -> Result { + .. + } + ... +} +# */ +``` + +To this: + +```rust +# /* comment to avoid running + +impl AsyncScalarUDFImpl for AskLLM { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + let options = &args.config_options; + .. + } + ... +} +# */ +``` + +### Schema Rewriter Module Moved to New Crate + +The `schema_rewriter` module and its associated symbols have been moved from `datafusion_physical_expr` to a new crate `datafusion_physical_expr_adapter`. This affects the following symbols: + +- `DefaultPhysicalExprAdapter` +- `DefaultPhysicalExprAdapterFactory` +- `PhysicalExprAdapter` +- `PhysicalExprAdapterFactory` + +To upgrade, change your imports to: + +```rust +use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, + PhysicalExprAdapter, PhysicalExprAdapterFactory +}; +``` + +### Upgrade to arrow `56.0.0` and parquet `56.0.0` + +This version of DataFusion upgrades the underlying Apache Arrow implementation +to version `56.0.0`. See the [release notes](https://github.com/apache/arrow-rs/releases/tag/56.0.0) +for more details. + +### Added `ExecutionPlan::reset_state` + +In order to fix a bug in DataFusion `49.0.0` where dynamic filters (currently only generated in the presence of a query such as `ORDER BY ... LIMIT ...`) +produced incorrect results in recursive queries, a new method `reset_state` has been added to the `ExecutionPlan` trait. + +Any `ExecutionPlan` that needs to maintain internal state or references to other nodes in the execution plan tree should implement this method to reset that state. +See [#17028] for more details and an example implementation for `SortExec`. + +[#17028]: https://github.com/apache/datafusion/pull/17028 + +### Nested Loop Join input sort order cannot be preserved + +The Nested Loop Join operator has been rewritten from scratch to improve performance and memory efficiency. From the micro-benchmarks: this change introduces up to 5X speed-up and uses only 1% memory in extreme cases compared to the previous implementation. + +However, the new implementation cannot preserve input sort order like the old version could. This is a fundamental design trade-off that prioritizes performance and memory efficiency over sort order preservation. + +See [#16996] for details. + +[#16996]: https://github.com/apache/datafusion/pull/16996 + +### Add `as_any()` method to `LazyBatchGenerator` + +To help with protobuf serialization, the `as_any()` method has been added to the `LazyBatchGenerator` trait. This means you will need to add `as_any()` to your implementation of `LazyBatchGenerator`: + +```rust +# /* comment to avoid running + +impl LazyBatchGenerator for MyBatchGenerator { + fn as_any(&self) -> &dyn Any { + self + } + + ... +} + +# */ +``` + +See [#17200](https://github.com/apache/datafusion/pull/17200) for details. + +### Refactored `DataSource::try_swapping_with_projection` + +We refactored `DataSource::try_swapping_with_projection` to simplify the method and minimize leakage across the ExecutionPlan <-> DataSource abstraction layer. +Reimplementation for any custom `DataSource` should be relatively straightforward, see [#17395] for more details. + +[#17395]: https://github.com/apache/datafusion/pull/17395/ + +### `FileOpenFuture` now uses `DataFusionError` instead of `ArrowError` + +The `FileOpenFuture` type alias has been updated to use `DataFusionError` instead of `ArrowError` for its error type. This change affects the `FileOpener` trait and any implementations that work with file streaming operations. + +**Before:** + +```rust,ignore +pub type FileOpenFuture = BoxFuture<'static, Result>>>; +``` + +**After:** + +```rust,ignore +pub type FileOpenFuture = BoxFuture<'static, Result>>>; +``` + +If you have custom implementations of `FileOpener` or work directly with `FileOpenFuture`, you'll need to update your error handling to use `DataFusionError` instead of `ArrowError`. The `FileStreamState` enum's `Open` variant has also been updated accordingly. See [#17397] for more details. + +[#17397]: https://github.com/apache/datafusion/pull/17397 + +### FFI user defined aggregate function signature change + +The Foreign Function Interface (FFI) signature for user defined aggregate functions +has been updated to call `return_field` instead of `return_type` on the underlying +aggregate function. This is to support metadata handling with these aggregate functions. +This change should be transparent to most users. If you have written unit tests to call +`return_type` directly, you may need to change them to calling `return_field` instead. + +This update is a breaking change to the FFI API. The current best practice when using the +FFI crate is to ensure that all libraries that are interacting are using the same +underlying Rust version. Issue [#17374] has been opened to discuss stabilization of +this interface so that these libraries can be used across different DataFusion versions. + +See [#17407] for details. + +[#17407]: https://github.com/apache/datafusion/pull/17407 +[#17374]: https://github.com/apache/datafusion/issues/17374 + +### Added `PhysicalExpr::is_volatile_node` + +We added a method to `PhysicalExpr` to mark a `PhysicalExpr` as volatile: + +```rust,ignore +impl PhysicalExpr for MyRandomExpr { + fn is_volatile_node(&self) -> bool { + true + } +} +``` + +We've shipped this with a default value of `false` to minimize breakage but we highly recommend that implementers of `PhysicalExpr` opt into a behavior, even if it is returning `false`. + +You can see more discussion and example implementations in [#17351]. + +[#17351]: https://github.com/apache/datafusion/pull/17351 + +## DataFusion `49.0.0` + +### `MSRV` updated to 1.85.1 + +The Minimum Supported Rust Version (MSRV) has been updated to [`1.85.1`]. See +[#16728] for details. + +[`1.85.1`]: https://releases.rs/docs/1.85.1/ +[#16728]: https://github.com/apache/datafusion/pull/16728 + +### `DataFusionError` variants are now `Box`ed + +To reduce the size of `DataFusionError`, several variants that were previously stored inline are now `Box`ed. This reduces the size of `Result` and thus stack usage and async state machine size. Please see [#16652] for more details. + +The following variants of `DataFusionError` are now boxed: + +- `ArrowError` +- `SQL` +- `SchemaError` + +This is a breaking change. Code that constructs or matches on these variants will need to be updated. + +For example, to create a `SchemaError`, instead of: + +```rust +# /* comment to avoid running +use datafusion_common::{DataFusionError, SchemaError}; +DataFusionError::SchemaError( + SchemaError::DuplicateUnqualifiedField { name: "foo".to_string() }, + Box::new(None) +) +# */ +``` + +You now need to `Box` the inner error: + +```rust +# /* comment to avoid running +use datafusion_common::{DataFusionError, SchemaError}; +DataFusionError::SchemaError( + Box::new(SchemaError::DuplicateUnqualifiedField { name: "foo".to_string() }), + Box::new(None) +) +# */ +``` + +[#16652]: https://github.com/apache/datafusion/issues/16652 + +### Metadata on Arrow Types is now represented by `FieldMetadata` + +Metadata from the Arrow `Field` is now stored using the `FieldMetadata` +structure. In prior versions it was stored as both a `HashMap` +and a `BTreeMap`. `FieldMetadata` is a easier to work with and +is more efficient. + +To create `FieldMetadata` from a `Field`: + +```rust +# /* comment to avoid running + let metadata = FieldMetadata::from(&field); +# */ +``` + +To add metadata to a `Field`, use the `add_to_field` method: + +```rust +# /* comment to avoid running +let updated_field = metadata.add_to_field(field); +# */ +``` + +See [#16317] for details. + +[#16317]: https://github.com/apache/datafusion/pull/16317 + +### New `datafusion.execution.spill_compression` configuration option + +DataFusion 49.0.0 adds support for compressing spill files when data is written to disk during spilling query execution. A new configuration option `datafusion.execution.spill_compression` controls the compression codec used. + +**Configuration:** + +- **Key**: `datafusion.execution.spill_compression` +- **Default**: `uncompressed` +- **Valid values**: `uncompressed`, `lz4_frame`, `zstd` + +**Usage:** + +```rust +# /* comment to avoid running +use datafusion::prelude::*; +use datafusion_common::config::SpillCompression; + +let config = SessionConfig::default() + .with_spill_compression(SpillCompression::Zstd); +let ctx = SessionContext::new_with_config(config); +# */ +``` + +Or via SQL: + +```sql +SET datafusion.execution.spill_compression = 'zstd'; +``` + +For more details about this configuration option, including performance trade-offs between different compression codecs, see the [Configuration Settings](../user-guide/configs.md) documentation. + +### Deprecated `map_varchar_to_utf8view` configuration option + +See [issue #16290](https://github.com/apache/datafusion/pull/16290) for more information +The old configuration + +```text +datafusion.sql_parser.map_varchar_to_utf8view +``` + +is now **deprecated** in favor of the unified option below.\ +If you previously used this to control only `VARCHAR`→`Utf8View` mapping, please migrate to `map_string_types_to_utf8view`. + +--- + +### New `map_string_types_to_utf8view` configuration option + +To unify **all** SQL string types (`CHAR`, `VARCHAR`, `TEXT`, `STRING`) to Arrow’s zero‑copy `Utf8View`, DataFusion 49.0.0 introduces: + +- **Key**: `datafusion.sql_parser.map_string_types_to_utf8view` +- **Default**: `true` + +**Description:** + +- When **true** (default), **all** SQL string types are mapped to `Utf8View`, avoiding full‑copy UTF‑8 allocations and improving performance. +- When **false**, DataFusion falls back to the legacy `Utf8` mapping for **all** string types. + +#### Examples + +```rust +# /* comment to avoid running +// Disable Utf8View mapping for all SQL string types +let opts = datafusion::sql::planner::ParserOptions::new() + .with_map_string_types_to_utf8view(false); + +// Verify the setting is applied +assert!(!opts.map_string_types_to_utf8view); +# */ +``` + +--- + +```sql +-- Disable Utf8View mapping globally +SET datafusion.sql_parser.map_string_types_to_utf8view = false; + +-- Now VARCHAR, CHAR, TEXT, STRING all use Utf8 rather than Utf8View +CREATE TABLE my_table (a VARCHAR, b TEXT, c STRING); +DESCRIBE my_table; +``` + +### Deprecating `SchemaAdapterFactory` and `SchemaAdapter` + +We are moving away from converting data (using `SchemaAdapter`) to converting the expressions themselves (which is more efficient and flexible). + +See [issue #16800](https://github.com/apache/datafusion/issues/16800) for more information +The first place this change has taken place is in predicate pushdown for Parquet. +By default if you do not use a custom `SchemaAdapterFactory` we will use expression conversion instead. +If you do set a custom `SchemaAdapterFactory` we will continue to use it but emit a warning about that code path being deprecated. + +To resolve this you need to implement a custom `PhysicalExprAdapterFactory` and use that instead of a `SchemaAdapterFactory`. +See the [default values](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/default_column_values.rs) for an example of how to do this. +Opting into the new APIs will set you up for future changes since we plan to expand use of `PhysicalExprAdapterFactory` to other areas of DataFusion. + +See [#16800] for details. + +[#16800]: https://github.com/apache/datafusion/issues/16800 + +### `TableParquetOptions` Updated + +The `TableParquetOptions` struct has a new `crypto` field to specify encryption +options for Parquet files. The `ParquetEncryptionOptions` implements `Default` +so you can upgrade your existing code like this: + +```rust +# /* comment to avoid running +TableParquetOptions { + global, + column_specific_options, + key_value_metadata, +} +# */ +``` + +To this: + +```rust +# /* comment to avoid running +TableParquetOptions { + global, + column_specific_options, + key_value_metadata, + crypto: Default::default(), // New crypto field +} +# */ +``` + +## DataFusion `48.0.1` + +### `datafusion.execution.collect_statistics` now defaults to `true` + +The default value of the `datafusion.execution.collect_statistics` configuration +setting is now true. This change impacts users that use that value directly and relied +on its default value being `false`. + +This change also restores the default behavior of `ListingTable` to its previous. If you use it directly +you can maintain the current behavior by overriding the default value in your code. + +```rust +# /* comment to avoid running +ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(false) + // other options +# */ +``` + +## DataFusion `48.0.0` + +### `Expr::Literal` has optional metadata + +The [`Expr::Literal`] variant now includes optional metadata, which allows for +carrying through Arrow field metadata to support extension types and other uses. + +This means code such as + +```rust +# /* comment to avoid running +match expr { +... + Expr::Literal(scalar) => ... +... +} +# */ +``` + +Should be updated to: + +```rust +# /* comment to avoid running +match expr { +... + Expr::Literal(scalar, _metadata) => ... +... +} +# */ +``` + +Likewise constructing `Expr::Literal` requires metadata as well. The [`lit`] function +has not changed and returns an `Expr::Literal` with no metadata. + +[`expr::literal`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html#variant.Literal +[`lit`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.lit.html + +### `Expr::WindowFunction` is now `Box`ed + +`Expr::WindowFunction` is now a `Box` instead of a `WindowFunction` directly. +This change was made to reduce the size of `Expr` and improve performance when +planning queries (see [details on #16207]). + +This is a breaking change, so you will need to update your code if you match +on `Expr::WindowFunction` directly. For example, if you have code like this: + +```rust +# /* comment to avoid running +match expr { + Expr::WindowFunction(WindowFunction { + params: + WindowFunctionParams { + partition_by, + order_by, + .. + } + }) => { + // Use partition_by and order_by as needed + } + _ => { + // other expr + } +} +# */ +``` + +You will need to change it to: + +```rust +# /* comment to avoid running +match expr { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: WindowFunctionParams { + args, + partition_by, + .. + }, + } = window_fun.as_ref(); + // Use partition_by and order_by as needed + } + _ => { + // other expr + } +} +# */ +``` + +[details on #16207]: https://github.com/apache/datafusion/pull/16207#issuecomment-2922659103 + +### The `VARCHAR` SQL type is now represented as `Utf8View` in Arrow + +The mapping of the SQL `VARCHAR` type has been changed from `Utf8` to `Utf8View` +which improves performance for many string operations. You can read more about +`Utf8View` in the [DataFusion blog post on German-style strings] + +[datafusion blog post on german-style strings]: https://datafusion.apache.org/blog/2024/09/13/string-view-german-style-strings-part-1/ + +This means that when you create a table with a `VARCHAR` column, it will now use +`Utf8View` as the underlying data type. For example: + +```sql +> CREATE TABLE my_table (my_column VARCHAR); +0 row(s) fetched. +Elapsed 0.001 seconds. + +> DESCRIBE my_table; ++-------------+-----------+-------------+ +| column_name | data_type | is_nullable | ++-------------+-----------+-------------+ +| my_column | Utf8View | YES | ++-------------+-----------+-------------+ +1 row(s) fetched. +Elapsed 0.000 seconds. +``` + +You can restore the old behavior of using `Utf8` by changing the +`datafusion.sql_parser.map_varchar_to_utf8view` configuration setting. For +example + +```sql +> set datafusion.sql_parser.map_varchar_to_utf8view = false; +0 row(s) fetched. +Elapsed 0.001 seconds. + +> CREATE TABLE my_table (my_column VARCHAR); +0 row(s) fetched. +Elapsed 0.014 seconds. + +> DESCRIBE my_table; ++-------------+-----------+-------------+ +| column_name | data_type | is_nullable | ++-------------+-----------+-------------+ +| my_column | Utf8 | YES | ++-------------+-----------+-------------+ +1 row(s) fetched. +Elapsed 0.004 seconds. +``` + +### `ListingOptions` default for `collect_stat` changed from `true` to `false` + +This makes it agree with the default for `SessionConfig`. +Most users won't be impacted by this change but if you were using `ListingOptions` directly +and relied on the default value of `collect_stat` being `true`, you will need to +explicitly set it to `true` in your code. + +```rust +# /* comment to avoid running +ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(true) + // other options +# */ +``` + +### Processing `FieldRef` instead of `DataType` for user defined functions + +In order to support metadata handling and extension types, user defined functions are +now switching to traits which use `FieldRef` rather than a `DataType` and nullability. +This gives a single interface to both of these parameters and additionally allows +access to metadata fields, which can be used for extension types. + +To upgrade structs which implement `ScalarUDFImpl`, if you have implemented +`return_type_from_args` you need instead to implement `return_field_from_args`. +If your functions do not need to handle metadata, this should be straightforward +repackaging of the output data into a `FieldRef`. The name you specify on the +field is not important. It will be overwritten during planning. `ReturnInfo` +has been removed, so you will need to remove all references to it. + +`ScalarFunctionArgs` now contains a field called `arg_fields`. You can use this +to access the metadata associated with the columnar values during invocation. + +To upgrade user defined aggregate functions, there is now a function +`return_field` that will allow you to specify both metadata and nullability of +your function. You are not required to implement this if you do not need to +handle metadata. + +The largest change to aggregate functions happens in the accumulator arguments. +Both the `AccumulatorArgs` and `StateFieldsArgs` now contain `FieldRef` rather +than `DataType`. + +To upgrade window functions, `ExpressionArgs` now contains input fields instead +of input data types. When setting these fields, the name of the field is +not important since this gets overwritten during the planning stage. All you +should need to do is wrap your existing data types in fields with nullability +set depending on your use case. + +### Physical Expression return `Field` + +To support the changes to user defined functions processing metadata, the +`PhysicalExpr` trait, which now must specify a return `Field` based on the input +schema. To upgrade structs which implement `PhysicalExpr` you need to implement +the `return_field` function. There are numerous examples in the `physical-expr` +crate. + +### `FileFormat::supports_filters_pushdown` replaced with `FileSource::try_pushdown_filters` + +To support more general filter pushdown, the `FileFormat::supports_filters_pushdown` was replaced with +`FileSource::try_pushdown_filters`. +If you implemented a custom `FileFormat` that uses a custom `FileSource` you will need to implement +`FileSource::try_pushdown_filters`. +See `ParquetSource::try_pushdown_filters` for an example of how to implement this. + +`FileFormat::supports_filters_pushdown` has been removed. + +### `ParquetExec`, `AvroExec`, `CsvExec`, `JsonExec` Removed + +`ParquetExec`, `AvroExec`, `CsvExec`, and `JsonExec` were deprecated in +DataFusion 46 and are removed in DataFusion 48. This is sooner than the normal +process described in the [API Deprecation Guidelines] because all the tests +cover the new `DataSourceExec` rather than the older structures. As we evolve +`DataSource`, the old structures began to show signs of "bit rotting" (not +working but no one knows due to lack of test coverage). + +[api deprecation guidelines]: https://datafusion.apache.org/contributor-guide/api-health.html#deprecation-guidelines + +### `PartitionedFile` added as an argument to the `FileOpener` trait + +This is necessary to properly fix filter pushdown for filters that combine partition +columns and file columns (e.g. `day = username['dob']`). + +If you implemented a custom `FileOpener` you will need to add the `PartitionedFile` argument +but are not required to use it in any way. + +## DataFusion `47.0.0` + +This section calls out some of the major changes in the `47.0.0` release of DataFusion. + +Here are some example upgrade PRs that demonstrate changes required when upgrading from DataFusion 46.0.0: + +- [delta-rs Upgrade to `47.0.0`](https://github.com/delta-io/delta-rs/pull/3378) +- [DataFusion Comet Upgrade to `47.0.0`](https://github.com/apache/datafusion-comet/pull/1563) +- [Sail Upgrade to `47.0.0`](https://github.com/lakehq/sail/pull/434) + +### Upgrades to `arrow-rs` and `arrow-parquet` 55.0.0 and `object_store` 0.12.0 + +Several APIs are changed in the underlying arrow and parquet libraries to use a +`u64` instead of `usize` to better support WASM (See [#7371] and [#6961]) + +Additionally `ObjectStore::list` and `ObjectStore::list_with_offset` have been changed to return `static` lifetimes (See [#6619]) + +[#6619]: https://github.com/apache/arrow-rs/pull/6619 +[#7371]: https://github.com/apache/arrow-rs/pull/7371 + +This requires converting from `usize` to `u64` occasionally as well as changes to `ObjectStore` implementations such as + +```rust +# /* comment to avoid running +impl Objectstore { + ... + // The range is now a u64 instead of usize + async fn get_range(&self, location: &Path, range: Range) -> ObjectStoreResult { + self.inner.get_range(location, range).await + } + ... + // the lifetime is now 'static instead of `_ (meaning the captured closure can't contain references) + // (this also applies to list_with_offset) + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, ObjectStoreResult> { + self.inner.list(prefix) + } +} +# */ +``` + +The `ParquetObjectReader` has been updated to no longer require the object size +(it can be fetched using a single suffix request). See [#7334] for details + +[#7334]: https://github.com/apache/arrow-rs/pull/7334 + +Pattern in DataFusion `46.0.0`: + +```rust +# /* comment to avoid running +let meta: ObjectMeta = ...; +let reader = ParquetObjectReader::new(store, meta); +# */ +``` + +Pattern in DataFusion `47.0.0`: + +```rust +# /* comment to avoid running +let meta: ObjectMeta = ...; +let reader = ParquetObjectReader::new(store, location) + .with_file_size(meta.size); +# */ +``` + +### `DisplayFormatType::TreeRender` + +DataFusion now supports [`tree` style explain plans]. Implementations of +`Executionplan` must also provide a description in the +`DisplayFormatType::TreeRender` format. This can be the same as the existing +`DisplayFormatType::Default`. + +[`tree` style explain plans]: https://datafusion.apache.org/user-guide/sql/explain.html#tree-format-default + +### Removed Deprecated APIs + +Several APIs have been removed in this release. These were either deprecated +previously or were hard to use correctly such as the multiple different +`ScalarUDFImpl::invoke*` APIs. See [#15130], [#15123], and [#15027] for more +details. + +[#15130]: https://github.com/apache/datafusion/pull/15130 +[#15123]: https://github.com/apache/datafusion/pull/15123 +[#15027]: https://github.com/apache/datafusion/pull/15027 + +### `FileScanConfig` --> `FileScanConfigBuilder` + +Previously, `FileScanConfig::build()` directly created ExecutionPlans. In +DataFusion 47.0.0 this has been changed to use `FileScanConfigBuilder`. See +[#15352] for details. + +[#15352]: https://github.com/apache/datafusion/pull/15352 + +Pattern in DataFusion `46.0.0`: + +```rust +# /* comment to avoid running +let plan = FileScanConfig::new(url, schema, Arc::new(file_source)) + .with_statistics(stats) + ... + .build() +# */ +``` + +Pattern in DataFusion `47.0.0`: + +```rust +# /* comment to avoid running +let config = FileScanConfigBuilder::new(url, schema, Arc::new(file_source)) + .with_statistics(stats) + ... + .build(); +let scan = DataSourceExec::from_data_source(config); +# */ +``` + ## DataFusion `46.0.0` ### Use `invoke_with_args` instead of `invoke()` and `invoke_batch()` @@ -39,7 +976,7 @@ below. See [PR 14876] for an example. Given existing code like this: ```rust -# /* +# /* comment to avoid running impl ScalarUDFImpl for SparkConcat { ... fn invoke_batch(&self, args: &[ColumnarValue], number_rows: usize) -> Result { @@ -59,7 +996,7 @@ impl ScalarUDFImpl for SparkConcat { To ```rust -# /* comment out so they don't run +# /* comment to avoid running impl ScalarUDFImpl for SparkConcat { ... fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -164,7 +1101,7 @@ let mut file_source = ParquetSource::new(parquet_options) // Add filter if let Some(predicate) = logical_filter { if config.enable_parquet_pushdown { - file_source = file_source.with_predicate(Arc::clone(&file_schema), predicate); + file_source = file_source.with_predicate(predicate); } }; @@ -217,8 +1154,8 @@ Elapsed 0.005 seconds. DataFusion 46 has changed the way scalar array function signatures are declared. Previously, functions needed to select from a list of predefined signatures within the `ArrayFunctionSignature` enum. Now the signatures -can be defined via a `Vec` of psuedo-types, which each correspond to a -single argument. Those psuedo-types are the variants of the +can be defined via a `Vec` of pseudo-types, which each correspond to a +single argument. Those pseudo-types are the variants of the `ArrayFunctionArgument` enum and are as follows: - `Array`: An argument of type List/LargeList/FixedSizeList. All Array diff --git a/docs/source/library-user-guide/using-the-sql-api.md b/docs/source/library-user-guide/using-the-sql-api.md index f78cf16f4cb67..8b8ba2a3716a3 100644 --- a/docs/source/library-user-guide/using-the-sql-api.md +++ b/docs/source/library-user-guide/using-the-sql-api.md @@ -119,6 +119,8 @@ async fn main() -> Result<()> { DataFusion can also read Avro files using the `register_avro` method. ```rust +# #[cfg(feature = "avro")] +{ use datafusion::arrow::util::pretty; use datafusion::error::Result; use datafusion::prelude::*; @@ -154,6 +156,7 @@ async fn main() -> Result<()> { ); Ok(()) } +} ``` ## Reading Multiple Files as a table diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index df4e5e3940aa6..bdcaaeae0a6e2 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -52,13 +52,13 @@ As the writer of a library, you can use `Expr`s to represent computations that y ## Arrow Schema and DataFusion DFSchema -Apache Arrow `Schema` provides a lightweight structure for defining data, and Apache Datafusion`DFSchema` extends it with extra information such as column qualifiers and functional dependencies. Column qualifiers are multi part path to the table e.g table, schema, catalog. Functional Dependency is the relationship between attributes(characteristics) of a table related to each other. +Apache Arrow `Schema` provides a lightweight structure for defining data, and Apache Datafusion `DFSchema` extends it with extra information such as column qualifiers and functional dependencies. Column qualifiers are multi part path to the table e.g table, schema, catalog. Functional Dependency is the relationship between attributes(characteristics) of a table related to each other. ### Difference between Schema and DFSchema - Schema: A fundamental component of Apache Arrow, `Schema` defines a dataset's structure, specifying column names and their data types. - > Please see [Struct Schema](https://docs.rs/arrow-schema/54.2.1/arrow_schema/struct.Schema.html) for a detailed document of Arrow Schema. + > Please see [Struct Schema](https://docs.rs/arrow-schema/latest/arrow_schema/struct.Schema.html) for a detailed document of Arrow Schema. - DFSchema: Extending `Schema`, `DFSchema` incorporates qualifiers such as table names, enabling it to carry additional context when required. This is particularly valuable for managing queries across multiple tables. > Please see [Struct DFSchema](https://docs.rs/datafusion/latest/datafusion/common/struct.DFSchema.html) for a detailed document of DFSchema. @@ -75,7 +75,7 @@ Please see [expr_api.rs](https://github.com/apache/datafusion/blob/main/datafusi ## A Scalar UDF Example -We'll use a `ScalarUDF` expression as our example. This necessitates implementing an actual UDF, and for ease we'll use the same example from the [adding UDFs](./adding-udfs.md) guide. +We'll use a `ScalarUDF` expression as our example. This necessitates implementing an actual UDF, and for ease we'll use the same example from the [adding UDFs](functions/adding-udfs.md) guide. So assuming you've written that function, you can use it to create an `Expr`: @@ -121,7 +121,7 @@ If you'd like to learn more about `Expr`s, before we get into the details of cre ## Rewriting `Expr`s -There are several examples of rewriting and working with `Exprs`: +There are several examples of rewriting and working with `Expr`s: - [expr_api.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs) - [analyzer_rule.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/analyzer_rule.rs) @@ -162,7 +162,7 @@ fn rewrite_add_one(expr: Expr) -> Result> { ### Creating an `OptimizerRule` -In DataFusion, an `OptimizerRule` is a trait that supports rewriting`Expr`s that appear in various parts of the `LogicalPlan`. It follows DataFusion's general mantra of trait implementations to drive behavior. +In DataFusion, an `OptimizerRule` is a trait that supports rewriting `Expr`s that appear in various parts of the `LogicalPlan`. It follows DataFusion's general mantra of trait implementations to drive behavior. We'll call our rule `AddOneInliner` and implement the `OptimizerRule` trait. The `OptimizerRule` trait has two methods: @@ -322,7 +322,7 @@ async fn main() -> Result<()> { let plan = ctx.sql(sql).await?.into_optimized_plan()?.clone(); let expected = r#"Projection: Int64(6) AS added_one - EmptyRelation"#; + EmptyRelation: rows=1"#; assert_eq!(plan.to_string(), expected); diff --git a/docs/source/user-guide/cli/datasources.md b/docs/source/user-guide/cli/datasources.md index 39172e94e5f80..6b1a4887a8a0f 100644 --- a/docs/source/user-guide/cli/datasources.md +++ b/docs/source/user-guide/cli/datasources.md @@ -82,23 +82,29 @@ select count(*) from 'https://datasets.clickhouse.com/hits_compatible/athena_par To read from an AWS S3 or GCS, use `s3` or `gs` as a protocol prefix. For example, to read a file in an S3 bucket named `my-data-bucket` use the URL `s3://my-data-bucket`and set the relevant access credentials as environmental -variables (e.g. for AWS S3 you need to at least `AWS_ACCESS_KEY_ID` and +variables (e.g. for AWS S3 you can use `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`). ```sql -select count(*) from 's3://my-data-bucket/athena_partitioned/hits.parquet' +> select count(*) from 's3://altinity-clickhouse-data/nyc_taxi_rides/data/tripdata_parquet/'; ++------------+ +| count(*) | ++------------+ +| 1310903963 | ++------------+ ``` -See the [`CREATE EXTERNAL TABLE`](#create-external-table) section for +See the [`CREATE EXTERNAL TABLE`](#create-external-table) section below for additional configuration options. # `CREATE EXTERNAL TABLE` It is also possible to create a table backed by files or remote locations via -`CREATE EXTERNAL TABLE` as shown below. Note that wildcards (e.g. `*`) are also -supported +`CREATE EXTERNAL TABLE` as shown below. Note that DataFusion does not support +wildcards (e.g. `*`) in file paths; instead, specify the directory path directly +to read all compatible files in that directory. -For example, to create a table `hits` backed by a local parquet file, use: +For example, to create a table `hits` backed by a local parquet file named `hits.parquet`: ```sql CREATE EXTERNAL TABLE hits @@ -106,7 +112,7 @@ STORED AS PARQUET LOCATION 'hits.parquet'; ``` -To create a table `hits` backed by a remote parquet file via HTTP(S), use +To create a table `hits` backed by a remote parquet file via HTTP(S): ```sql CREATE EXTERNAL TABLE hits @@ -126,6 +132,60 @@ select count(*) from hits; 1 row in set. Query took 0.344 seconds. ``` +**Why Wildcards Are Not Supported** + +Although wildcards (e.g., _.parquet or \*\*/_.parquet) may work for local +filesystems in some cases, they are not supported by DataFusion CLI. This +is because wildcards are not universally applicable across all storage backends +(e.g., S3, GCS). Instead, DataFusion expects the user to specify the directory +path, and it will automatically read all compatible files within that directory. + +For example, the following usage is not supported: + +```sql +CREATE EXTERNAL TABLE test ( + message TEXT, + day DATE +) +STORED AS PARQUET +LOCATION 'gs://bucket/*.parquet'; +``` + +Instead, you should use: + +```sql +CREATE EXTERNAL TABLE test ( + message TEXT, + day DATE +) +STORED AS PARQUET +LOCATION 'gs://bucket/my_table/'; +``` + +When specifying a directory path that has a Hive compliant partition structure, by default, DataFusion CLI will +automatically parse and incorporate the Hive columns and their values into the table's schema and data. Given the +following remote object paths: + +```console +gs://bucket/my_table/a=1/b=100/file1.parquet +gs://bucket/my_table/a=2/b=200/file2.parquet +``` + +`my_table` can be queried and filtered on the Hive columns: + +```sql +CREATE EXTERNAL TABLE my_table +STORED AS PARQUET +LOCATION 'gs://bucket/my_table/'; + +SELECT count(*) FROM my_table WHERE b=200; ++----------+ +| count(*) | ++----------+ +| 1 | ++----------+ +``` + # Formats ## Parquet @@ -143,25 +203,63 @@ LOCATION '/mnt/nyctaxi/tripdata.parquet'; Register a single folder parquet datasource. Note: All files inside must be valid parquet files and have compatible schemas +:::{note} +Paths must end in Slash `/` +: The path must end in `/` otherwise DataFusion will treat the path as a file and not a directory +::: + ```sql CREATE EXTERNAL TABLE taxi STORED AS PARQUET LOCATION '/mnt/nyctaxi/'; ``` -Register a single folder parquet datasource by specifying a wildcard for files to read +### Parquet Specific Options + +You can specify additional options for parquet files using the `OPTIONS` clause. +For example, to read and write a parquet directory with encryption settings you could use: ```sql -CREATE EXTERNAL TABLE taxi -STORED AS PARQUET -LOCATION '/mnt/nyctaxi/*.parquet'; +CREATE EXTERNAL TABLE encrypted_parquet_table +( +double_field double, +float_field float +) +STORED AS PARQUET LOCATION 'pq/' OPTIONS ( + -- encryption + 'format.crypto.file_encryption.encrypt_footer' 'true', + 'format.crypto.file_encryption.footer_key_as_hex' '30313233343536373839303132333435', -- b"0123456789012345" + 'format.crypto.file_encryption.column_key_as_hex::double_field' '31323334353637383930313233343530', -- b"1234567890123450" + 'format.crypto.file_encryption.column_key_as_hex::float_field' '31323334353637383930313233343531', -- b"1234567890123451" + -- decryption + 'format.crypto.file_decryption.footer_key_as_hex' '30313233343536373839303132333435', -- b"0123456789012345" + 'format.crypto.file_decryption.column_key_as_hex::double_field' '31323334353637383930313233343530', -- b"1234567890123450" + 'format.crypto.file_decryption.column_key_as_hex::float_field' '31323334353637383930313233343531', -- b"1234567890123451" +); ``` +Here the keys are specified in hexadecimal format because they are binary data. These can be encoded in SQL using: + +```sql +select encode('0123456789012345', 'hex'); +/* ++----------------------------------------------+ +| encode(Utf8("0123456789012345"),Utf8("hex")) | ++----------------------------------------------+ +| 30313233343536373839303132333435 | ++----------------------------------------------+ +*/ +``` + +For more details on the available options, refer to the Rust +[TableParquetOptions](https://docs.rs/datafusion/latest/datafusion/common/config/struct.TableParquetOptions.html) +documentation in DataFusion. + ## CSV DataFusion will infer the CSV schema automatically or you can provide it explicitly. -Register a single file csv datasource with a header row. +Register a single file csv datasource with a header row: ```sql CREATE EXTERNAL TABLE test @@ -170,7 +268,7 @@ LOCATION '/path/to/aggregate_test_100.csv' OPTIONS ('has_header' 'true'); ``` -Register a single file csv datasource with explicitly defined schema. +Register a single file csv datasource with explicitly defined schema: ```sql CREATE EXTERNAL TABLE test ( @@ -196,7 +294,7 @@ LOCATION '/path/to/aggregate_test_100.csv'; ## HTTP(s) -To read from a remote parquet file via HTTP(S) you can use the following: +To read from a remote parquet file via HTTP(S): ```sql CREATE EXTERNAL TABLE hits @@ -206,9 +304,12 @@ LOCATION 'https://datasets.clickhouse.com/hits_compatible/athena_partitioned/hit ## S3 -[AWS S3](https://aws.amazon.com/s3/) data sources must have connection credentials configured. +DataFusion CLI supports configuring [AWS S3](https://aws.amazon.com/s3/) via the +`CREATE EXTERNAL TABLE` statement and standard AWS configuration methods (via the +[`aws-config`] AWS SDK crate). -To create an external table from a file in an S3 bucket: +To create an external table from a file in an S3 bucket with explicit +credentials: ```sql CREATE EXTERNAL TABLE test @@ -221,7 +322,7 @@ OPTIONS( LOCATION 's3://bucket/path/file.parquet'; ``` -It is also possible to specify the access information using environment variables: +To create an external table using environment variables: ```bash $ export AWS_DEFAULT_REGION=us-east-2 @@ -230,7 +331,7 @@ $ export AWS_ACCESS_KEY_ID=****** $ datafusion-cli `datafusion-cli v21.0.0 -> create external table test stored as parquet location 's3://bucket/path/file.parquet'; +> create CREATE TABLE test STORED AS PARQUET LOCATION 's3://bucket/path/file.parquet'; 0 rows in set. Query took 0.374 seconds. > select * from test; +----------+----------+ @@ -241,19 +342,39 @@ $ datafusion-cli 1 row in set. Query took 0.171 seconds. ``` +To read from a public S3 bucket without signatures, use the +`aws.SKIP_SIGNATURE` option: + +```sql +CREATE EXTERNAL TABLE nyc_taxi_rides +STORED AS PARQUET LOCATION 's3://altinity-clickhouse-data/nyc_taxi_rides/data/tripdata_parquet/' +OPTIONS(aws.SKIP_SIGNATURE true); +``` + +Credentials are taken in this order of precedence: + +1. Explicitly specified in the `OPTIONS` clause of the `CREATE EXTERNAL TABLE` statement. +2. Determined by [`aws-config`] crate (standard environment variables such as `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` as well as other AWS specific features). + +If no credentials are specified, DataFusion CLI will use unsigned requests to S3, +which allows reading from public buckets. + Supported configuration options are: -| Environment Variable | Configuration Option | Description | -| ---------------------------------------- | ----------------------- | ---------------------------------------------------- | -| `AWS_ACCESS_KEY_ID` | `aws.access_key_id` | | -| `AWS_SECRET_ACCESS_KEY` | `aws.secret_access_key` | | -| `AWS_DEFAULT_REGION` | `aws.region` | | -| `AWS_ENDPOINT` | `aws.endpoint` | | -| `AWS_SESSION_TOKEN` | `aws.token` | | -| `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI` | | See [IAM Roles] | -| `AWS_ALLOW_HTTP` | | set to "true" to permit HTTP connections without TLS | +| Environment Variable | Configuration Option | Description | +| ---------------------------------------- | ----------------------- | ---------------------------------------------- | +| `AWS_ACCESS_KEY_ID` | `aws.access_key_id` | | +| `AWS_SECRET_ACCESS_KEY` | `aws.secret_access_key` | | +| `AWS_DEFAULT_REGION` | `aws.region` | | +| `AWS_ENDPOINT` | `aws.endpoint` | | +| `AWS_SESSION_TOKEN` | `aws.token` | | +| `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI` | | See [IAM Roles] | +| `AWS_ALLOW_HTTP` | | If "true", permit HTTP connections without TLS | +| `AWS_SKIP_SIGNATURE` | `aws.skip_signature` | If "true", does not sign requests | +| | `aws.nosign` | Alias for `skip_signature` | [iam roles]: https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html +[`aws-config`]: https://docs.rs/aws-config/latest/aws_config/ ## OSS diff --git a/docs/source/user-guide/cli/functions.md b/docs/source/user-guide/cli/functions.md new file mode 100644 index 0000000000000..305b53c16f65e --- /dev/null +++ b/docs/source/user-guide/cli/functions.md @@ -0,0 +1,142 @@ + + +# CLI Specific Functions + +`datafusion-cli` comes with build-in functions that are not included in the +DataFusion SQL engine by default. These functions are: + +## `parquet_metadata` + +The `parquet_metadata` table function can be used to inspect detailed metadata +about a parquet file such as statistics, sizes, and other information. This can +be helpful to understand how parquet files are structured. + +For example, to see information about the `"WatchID"` column in the +`hits.parquet` file, you can use: + +```sql +SELECT path_in_schema, row_group_id, row_group_num_rows, stats_min, stats_max, total_compressed_size +FROM parquet_metadata('hits.parquet') +WHERE path_in_schema = '"WatchID"' +LIMIT 3; + ++----------------+--------------+--------------------+---------------------+---------------------+-----------------------+ +| path_in_schema | row_group_id | row_group_num_rows | stats_min | stats_max | total_compressed_size | ++----------------+--------------+--------------------+---------------------+---------------------+-----------------------+ +| "WatchID" | 0 | 450560 | 4611687214012840539 | 9223369186199968220 | 3883759 | +| "WatchID" | 1 | 612174 | 4611689135232456464 | 9223371478009085789 | 5176803 | +| "WatchID" | 2 | 344064 | 4611692774829951781 | 9223363791697310021 | 3031680 | ++----------------+--------------+--------------------+---------------------+---------------------+-----------------------+ +3 rows in set. Query took 0.053 seconds. +``` + +The returned table has the following columns for each row for each column chunk +in the file. Please refer to the [Parquet Documentation] for more information in +the meaning of these fields. + +[parquet documentation]: https://parquet.apache.org/ + +| column_name | data_type | Description | +| ----------------------- | --------- | --------------------------------------------------------------------------------------------------- | +| filename | Utf8 | Name of the file | +| row_group_id | Int64 | Row group index the column chunk belongs to | +| row_group_num_rows | Int64 | Count of rows stored in the row group | +| row_group_num_columns | Int64 | Total number of columns in the row group (same for all row groups) | +| row_group_bytes | Int64 | Number of bytes used to store the row group (not including metadata) | +| column_id | Int64 | ID of the column | +| file_offset | Int64 | Offset within the file that this column chunk's data begins | +| num_values | Int64 | Total number of values in this column chunk | +| path_in_schema | Utf8 | "Path" (column name) of the column chunk in the schema | +| type | Utf8 | Parquet data type of the column chunk | +| stats_min | Utf8 | The minimum value for this column chunk, if stored in the statistics, cast to a string | +| stats_max | Utf8 | The maximum value for this column chunk, if stored in the statistics, cast to a string | +| stats_null_count | Int64 | Number of null values in this column chunk, if stored in the statistics | +| stats_distinct_count | Int64 | Number of distinct values in this column chunk, if stored in the statistics | +| stats_min_value | Utf8 | Same as `stats_min` | +| stats_max_value | Utf8 | Same as `stats_max` | +| compression | Utf8 | Block level compression (e.g. `SNAPPY`) used for this column chunk | +| encodings | Utf8 | All block level encodings (e.g. `[PLAIN_DICTIONARY, PLAIN, RLE]`) used for this column chunk | +| index_page_offset | Int64 | Offset in the file of the [`page index`], if any | +| dictionary_page_offset | Int64 | Offset in the file of the dictionary page, if any | +| data_page_offset | Int64 | Offset in the file of the first data page, if any | +| total_compressed_size | Int64 | Number of bytes the column chunk's data after encoding and compression (what is stored in the file) | +| total_uncompressed_size | Int64 | Number of bytes the column chunk's data after encoding | + +[`page index`]: https://github.com/apache/parquet-format/blob/master/PageIndex.md + +## `metadata_cache` + +The `metadata_cache` function shows information about the default File Metadata Cache that is used by the +[`ListingTable`] implementation in DataFusion. This cache is used to speed up +reading metadata from files when scanning directories with many files. + +For example, after creating a table with the [CREATE EXTERNAL TABLE](../sql/ddl.md#create-external-table) +command: + +```sql +> create external table hits + stored as parquet + location 's3://clickhouse-public-datasets/hits_compatible/athena_partitioned/'; +``` + +You can inspect the metadata cache by querying the `metadata_cache` function: + +```sql +> select * from metadata_cache(); ++----------------------------------------------------+---------------------+-----------------+---------------------------------------+---------+---------------------+------+------------------+ +| path | file_modified | file_size_bytes | e_tag | version | metadata_size_bytes | hits | extra | ++----------------------------------------------------+---------------------+-----------------+---------------------------------------+---------+---------------------+------+------------------+ +| hits_compatible/athena_partitioned/hits_61.parquet | 2022-07-03T15:40:34 | 117270944 | "5db11cad1ca0d80d748fc92c914b010a-6" | NULL | 212949 | 0 | page_index=false | +| hits_compatible/athena_partitioned/hits_32.parquet | 2022-07-03T15:37:17 | 94506004 | "2f7db49a9fe242179590b615b94a39d2-5" | NULL | 278157 | 0 | page_index=false | +| hits_compatible/athena_partitioned/hits_40.parquet | 2022-07-03T15:38:07 | 142508647 | "9e5852b45a469d5a05bf270a286eab8a-8" | NULL | 212917 | 0 | page_index=false | +| hits_compatible/athena_partitioned/hits_93.parquet | 2022-07-03T15:44:07 | 127987774 | "751100bf0dac7d489b9836abf3108b99-7" | NULL | 278318 | 0 | page_index=false | +| . | ++----------------------------------------------------+---------------------+-----------------+---------------------------------------+---------+---------------------+------+------------------+ +``` + +Since `metadata_cache` is a normal table function, you can use it in most places you can use +a table reference. + +For example, to get the total size consumed by the cached entries: + +```sql +> select sum(metadata_size_bytes) from metadata_cache(); ++-------------------------------------------+ +| sum(metadata_cache().metadata_size_bytes) | ++-------------------------------------------+ +| 22972345 | ++-------------------------------------------+ +``` + +The columns of the returned table are: + +| column_name | data_type | Description | +| ------------------- | --------- | ----------------------------------------------------------------------------------------- | +| path | Utf8 | File path relative to the object store / filesystem root | +| file_modified | Timestamp | Last modified time of the file | +| file_size_bytes | UInt64 | Size of the file in bytes | +| e_tag | Utf8 | [Entity Tag] (ETag) of the file if available | +| version | Utf8 | Version of the file if available (for object stores that support versioning) | +| metadata_size_bytes | UInt64 | Size of the cached metadata in memory (not its thrift encoded form) | +| hits | UInt64 | Number of times the cached metadata has been accessed | +| extra | Utf8 | Extra information about the cached metadata (e.g., if page index information is included) | + +[`listingtable`]: https://docs.rs/datafusion/latest/datafusion/datasource/listing/struct.ListingTable.html +[entity tag]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag diff --git a/docs/source/user-guide/cli/index.rst b/docs/source/user-guide/cli/index.rst index 874cfc0eae868..325b0dce3fb19 100644 --- a/docs/source/user-guide/cli/index.rst +++ b/docs/source/user-guide/cli/index.rst @@ -25,3 +25,4 @@ DataFusion CLI installation usage datasources + functions diff --git a/docs/source/user-guide/cli/usage.md b/docs/source/user-guide/cli/usage.md index fb238dad10bb1..57a96c5d79003 100644 --- a/docs/source/user-guide/cli/usage.md +++ b/docs/source/user-guide/cli/usage.md @@ -57,6 +57,16 @@ OPTIONS: --mem-pool-type Specify the memory pool type 'greedy' or 'fair', default to 'greedy' + --top-memory-consumers + The number of top memory consumers to display when query fails due to memory exhaustion. To disable memory consumer tracking, set this value to 0 [default: 3] + + -d, --disk-limit + Available disk space for spilling queries (e.g. '10g'), default to None (uses DataFusion's default value of '100g') + + --object-store-profiling + Specify the default object_store_profiling mode, defaults to 'disabled'. + [possible values: disabled, enabled] [default: Disabled] + -p, --data-path Path to your data, default to current directory @@ -116,6 +126,12 @@ Available commands inside DataFusion CLI are: > \h function ``` +- Object Store Profiling Mode + +```bash +> \object_store_profiling [disabled|enabled] +``` + ## Supported SQL In addition to the normal [SQL supported in DataFusion], `datafusion-cli` also @@ -225,64 +241,5 @@ DataFusion CLI v13.0.0 ## Functions `datafusion-cli` comes with build-in functions that are not included in the -DataFusion SQL engine. These functions are: - -### `parquet_metadata` - -The `parquet_metadata` table function can be used to inspect detailed metadata -about a parquet file such as statistics, sizes, and other information. This can -be helpful to understand how parquet files are structured. - -For example, to see information about the `"WatchID"` column in the -`hits.parquet` file, you can use: - -```sql -SELECT path_in_schema, row_group_id, row_group_num_rows, stats_min, stats_max, total_compressed_size -FROM parquet_metadata('hits.parquet') -WHERE path_in_schema = '"WatchID"' -LIMIT 3; - -+----------------+--------------+--------------------+---------------------+---------------------+-----------------------+ -| path_in_schema | row_group_id | row_group_num_rows | stats_min | stats_max | total_compressed_size | -+----------------+--------------+--------------------+---------------------+---------------------+-----------------------+ -| "WatchID" | 0 | 450560 | 4611687214012840539 | 9223369186199968220 | 3883759 | -| "WatchID" | 1 | 612174 | 4611689135232456464 | 9223371478009085789 | 5176803 | -| "WatchID" | 2 | 344064 | 4611692774829951781 | 9223363791697310021 | 3031680 | -+----------------+--------------+--------------------+---------------------+---------------------+-----------------------+ -3 rows in set. Query took 0.053 seconds. -``` - -The returned table has the following columns for each row for each column chunk -in the file. Please refer to the [Parquet Documentation] for more information. - -[parquet documentation]: https://parquet.apache.org/ - -| column_name | data_type | Description | -| ----------------------- | --------- | --------------------------------------------------------------------------------------------------- | -| filename | Utf8 | Name of the file | -| row_group_id | Int64 | Row group index the column chunk belongs to | -| row_group_num_rows | Int64 | Count of rows stored in the row group | -| row_group_num_columns | Int64 | Total number of columns in the row group (same for all row groups) | -| row_group_bytes | Int64 | Number of bytes used to store the row group (not including metadata) | -| column_id | Int64 | ID of the column | -| file_offset | Int64 | Offset within the file that this column chunk's data begins | -| num_values | Int64 | Total number of values in this column chunk | -| path_in_schema | Utf8 | "Path" (column name) of the column chunk in the schema | -| type | Utf8 | Parquet data type of the column chunk | -| stats_min | Utf8 | The minimum value for this column chunk, if stored in the statistics, cast to a string | -| stats_max | Utf8 | The maximum value for this column chunk, if stored in the statistics, cast to a string | -| stats_null_count | Int64 | Number of null values in this column chunk, if stored in the statistics | -| stats_distinct_count | Int64 | Number of distinct values in this column chunk, if stored in the statistics | -| stats_min_value | Utf8 | Same as `stats_min` | -| stats_max_value | Utf8 | Same as `stats_max` | -| compression | Utf8 | Block level compression (e.g. `SNAPPY`) used for this column chunk | -| encodings | Utf8 | All block level encodings (e.g. `[PLAIN_DICTIONARY, PLAIN, RLE]`) used for this column chunk | -| index_page_offset | Int64 | Offset in the file of the [`page index`], if any | -| dictionary_page_offset | Int64 | Offset in the file of the dictionary page, if any | -| data_page_offset | Int64 | Offset in the file of the first data page, if any | -| total_compressed_size | Int64 | Number of bytes the column chunk's data after encoding and compression (what is stored in the file) | -| total_uncompressed_size | Int64 | Number of bytes the column chunk's data after encoding | - -+-------------------------+-----------+-------------+ - -[`page index`]: https://github.com/apache/parquet-format/blob/master/PageIndex.md +DataFusion SQL engine, see [DataFusion CLI specific functions](functions.md) section +for details. diff --git a/docs/source/user-guide/concepts-readings-events.md b/docs/source/user-guide/concepts-readings-events.md index fef677dd3a621..ad444ef91c474 100644 --- a/docs/source/user-guide/concepts-readings-events.md +++ b/docs/source/user-guide/concepts-readings-events.md @@ -37,6 +37,10 @@ This is a list of DataFusion related blog posts, articles, and other resources. Please open a PR to add any new resources you create or find +- **2025-03-21** [Blog: Efficient Filter Pushdown in Parquet](https://datafusion.apache.org/blog/2025/03/21/parquet-pushdown/) + +- **2025-03-20** [Blog: Parquet Pruning in DataFusion: Read Only What Matters](https://datafusion.apache.org/blog/2025/03/20/parquet-pruning/) + - **2025-02-12** [Video: Alex Kesling on Apache Arrow DataFusion - Papers We Love NYC ](https://www.youtube.com/watch?v=6A4vFRpSq3k) - **2025-01-30** [Video: Data & Drinks: Building Next-Gen Data Systems with Apache DataFusion](https://www.youtube.com/watch?v=GruBeVDoWq4) @@ -134,6 +138,8 @@ This is a list of DataFusion related blog posts, articles, and other resources. ## 📅 Release Notes & Updates +- **2025-03-24** [Apache DataFusion 46.0.0 Released](https://datafusion.apache.org/blog/2025/03/24/datafusion-46.0.0/) + - **2024-09-14** [Apache DataFusion Python 43.1.0 Released](https://datafusion.apache.org/blog/2024/12/14/datafusion-python-43.1.0/) - **2024-08-24** [Apache DataFusion Python 40.1.0 Released, Significant usability updates](https://datafusion.apache.org/blog/2024/08/20/python-datafusion-40.0.0/) diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 68e21183938b1..6bc7b90e893ad 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -18,116 +18,234 @@ --> # Configuration Settings -The following configuration options can be passed to `SessionConfig` to control various aspects of query execution. - -For applications which do not expose `SessionConfig`, like `datafusion-cli`, these options may also be set via environment variables. -To construct a session with options from the environment, use `SessionConfig::from_env`. -The name of the environment variable is the option's key, transformed to uppercase and with periods replaced with underscores. -For example, to configure `datafusion.execution.batch_size` you would set the `DATAFUSION_EXECUTION_BATCH_SIZE` environment variable. -Values are parsed according to the [same rules used in casts from Utf8](https://docs.rs/arrow/latest/arrow/compute/kernels/cast/fn.cast.html). -If the value in the environment variable cannot be cast to the type of the configuration option, the default value will be used instead and a warning emitted. -Environment variables are read during `SessionConfig` initialisation so they must be set beforehand and will not affect running sessions. - -| key | default | description | -| ----------------------------------------------------------------------- | ------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| datafusion.catalog.create_default_catalog_and_schema | true | Whether the default catalog and schema should be created automatically. | -| datafusion.catalog.default_catalog | datafusion | The default catalog name - this impacts what SQL queries use if not specified | -| datafusion.catalog.default_schema | public | The default schema name - this impacts what SQL queries use if not specified | -| datafusion.catalog.information_schema | false | Should DataFusion provide access to `information_schema` virtual tables for displaying schema information | -| datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | -| datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | -| datafusion.catalog.has_header | true | Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. | -| datafusion.catalog.newlines_in_values | false | Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. | -| datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | -| datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | -| datafusion.execution.collect_statistics | false | Should DataFusion collect statistics after listing files | -| datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system | -| datafusion.execution.time_zone | +00:00 | The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour | -| datafusion.execution.parquet.enable_page_index | true | (reading) If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | -| datafusion.execution.parquet.pruning | true | (reading) If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | -| datafusion.execution.parquet.skip_metadata | true | (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | -| datafusion.execution.parquet.metadata_size_hint | NULL | (reading) If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | -| datafusion.execution.parquet.pushdown_filters | false | (reading) If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". | -| datafusion.execution.parquet.reorder_filters | false | (reading) If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | -| datafusion.execution.parquet.schema_force_view_types | true | (reading) If true, parquet reader will read columns of `Utf8/Utf8Large` with `Utf8View`, and `Binary/BinaryLarge` with `BinaryView`. | -| datafusion.execution.parquet.binary_as_string | false | (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. | -| datafusion.execution.parquet.data_pagesize_limit | 1048576 | (writing) Sets best effort maximum size of data page in bytes | -| datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in bytes | -| datafusion.execution.parquet.writer_version | 1.0 | (writing) Sets parquet writer version valid values are "1.0" and "2.0" | -| datafusion.execution.parquet.skip_arrow_metadata | false | (writing) Skip encoding the embedded arrow metadata in the KV_meta This is analogous to the `ArrowWriterOptions::with_skip_arrow_metadata`. Refer to | -| datafusion.execution.parquet.compression | zstd(3) | (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. | -| datafusion.execution.parquet.dictionary_enabled | true | (writing) Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | (writing) Sets best effort maximum dictionary page size, in bytes | -| datafusion.execution.parquet.statistics_enabled | page | (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.max_statistics_size | 4096 | (writing) Sets max statistics size for any column. If NULL, uses default parquet writer setting max_statistics_size is deprecated, currently it is not being used | -| datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | -| datafusion.execution.parquet.created_by | datafusion version 46.0.1 | (writing) Sets "created by" property | -| datafusion.execution.parquet.column_index_truncate_length | 64 | (writing) Sets column index truncate length | -| datafusion.execution.parquet.statistics_truncate_length | NULL | (writing) Sets statictics truncate length. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.data_page_row_count_limit | 20000 | (writing) Sets best effort maximum number of rows in data page | -| datafusion.execution.parquet.encoding | NULL | (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.bloom_filter_on_read | true | (writing) Use any available bloom filters when reading parquet files | -| datafusion.execution.parquet.bloom_filter_on_write | false | (writing) Write bloom filters for all columns when creating parquet files | -| datafusion.execution.parquet.bloom_filter_fpp | NULL | (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.bloom_filter_ndv | NULL | (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.allow_single_file_parallelism | true | (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | -| datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | -| datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | -| datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | -| datafusion.execution.skip_physical_aggregate_schema_check | false | When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. This is used to workaround bugs in the planner that are now caught by the new schema verification step. | -| datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | -| datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. | -| datafusion.execution.meta_fetch_concurrency | 32 | Number of files to read in parallel when inferring schema and statistics | -| datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | -| datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | -| datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | -| datafusion.execution.listing_table_ignore_subdirectory | true | Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). | -| datafusion.execution.enable_recursive_ctes | true | Should DataFusion support recursive CTEs | -| datafusion.execution.split_file_groups_by_statistics | false | Attempt to eliminate sorts by packing & sorting files with non-overlapping statistics into the same file groups. Currently experimental | -| datafusion.execution.keep_partition_by_columns | false | Should DataFusion keep the columns used for partition_by in the output RecordBatches | -| datafusion.execution.skip_partial_aggregation_probe_ratio_threshold | 0.8 | Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input | -| datafusion.execution.skip_partial_aggregation_probe_rows_threshold | 100000 | Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode | -| datafusion.execution.use_row_number_estimates_to_optimize_partitioning | false | Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans. if the source of statistics is accurate. We plan to make this the default in the future. | -| datafusion.execution.enforce_batch_size_in_joins | false | Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. | -| datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | -| datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | -| datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | -| datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | -| datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | -| datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. | -| datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level | -| datafusion.optimizer.allow_symmetric_joins_without_pruning | true | Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. | -| datafusion.optimizer.repartition_file_scans | true | When set to `true`, file groups will be repartitioned to achieve maximum parallelism. Currently Parquet and CSV formats are supported. If set to `true`, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false`, different files will be read in parallel, but repartitioning won't happen within a single file. | -| datafusion.optimizer.repartition_windows | true | Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level | -| datafusion.optimizer.repartition_sorts | true | Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. With this flag is enabled, plans in the form below `text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` would turn into the plan below which performs better in multithreaded environments `text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` | -| datafusion.optimizer.prefer_existing_sort | false | When true, DataFusion will opportunistically remove sorts when the data is already sorted, (i.e. setting `preserve_order` to true on `RepartitionExec` and using `SortPreservingMergeExec`) When false, DataFusion will maximize plan parallelism using `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. | -| datafusion.optimizer.skip_failed_rules | false | When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule. When set to false, any rules that produce errors will cause the query to fail | -| datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | -| datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | -| datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | -| datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | -| datafusion.optimizer.hash_join_single_partition_threshold_rows | 131072 | The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition | -| datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | -| datafusion.optimizer.prefer_existing_union | false | When set to true, the optimizer will not attempt to convert Union to Interleave | -| datafusion.optimizer.expand_views_at_output | false | When set to true, if the returned type is a view type then the output will be coerced to a non-view. Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. | -| datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | -| datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | -| datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | -| datafusion.explain.show_sizes | true | When set to true, the explain statement will print the partition sizes | -| datafusion.explain.show_schema | false | When set to true, the explain statement will print schema information | -| datafusion.explain.format | indent | Display format of explain. Default is "indent". When set to "tree", it will print the plan in a tree-rendered format. | -| datafusion.sql_parser.parse_float_as_decimal | false | When set to true, SQL parser will parse float as decimal type | -| datafusion.sql_parser.enable_ident_normalization | true | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) | -| datafusion.sql_parser.enable_options_value_normalization | false | When set to true, SQL parser will normalize options value (convert value to lowercase). Note that this option is ignored and will be removed in the future. All case-insensitive values are normalized automatically. | -| datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks. | -| datafusion.sql_parser.support_varchar_with_length | true | If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion can not enforce such limits. | -| datafusion.sql_parser.map_varchar_to_utf8view | false | If true, `VARCHAR` is mapped to `Utf8View` during SQL planning. If false, `VARCHAR` is mapped to `Utf8` during SQL planning. Default is false. | -| datafusion.sql_parser.collect_spans | false | When set to true, the source locations relative to the original SQL query (i.e. [`Span`](https://docs.rs/sqlparser/latest/sqlparser/tokenizer/struct.Span.html)) will be collected and recorded in the logical plan nodes. | -| datafusion.sql_parser.recursion_limit | 50 | Specifies the recursion depth limit when parsing complex SQL Queries | +DataFusion configurations control various aspects of DataFusion planning and execution + +## Setting Configuration Options + +### Programmatically + +You can set the options programmatically via the [`ConfigOptions`] object. For +example, to configure the `datafusion.execution.target_partitions` using the API: + +```rust +use datafusion::common::config::ConfigOptions; +let mut config = ConfigOptions::new(); +config.execution.target_partitions = 1; +``` + +### Via Environment Variables + +You can also set configuration options via environment variables using +[`ConfigOptions::from_env`], for example + +```shell +DATAFUSION_EXECUTION_TARGET_PARTITIONS=1 ./your_program +``` + +### Via SQL + +You can also set configuration options via SQL using the `SET` command. For +example, to configure `datafusion.execution.target_partitions`: + +```sql +SET datafusion.execution.target_partitions = '1'; +``` + +[`configoptions`]: https://docs.rs/datafusion/latest/datafusion/common/config/struct.ConfigOptions.html +[`configoptions::from_env`]: https://docs.rs/datafusion/latest/datafusion/common/config/struct.ConfigOptions.html#method.from_env + +The following configuration settings are available: + +| key | default | description | +| ----------------------------------------------------------------------- | ------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| datafusion.catalog.create_default_catalog_and_schema | true | Whether the default catalog and schema should be created automatically. | +| datafusion.catalog.default_catalog | datafusion | The default catalog name - this impacts what SQL queries use if not specified | +| datafusion.catalog.default_schema | public | The default schema name - this impacts what SQL queries use if not specified | +| datafusion.catalog.information_schema | false | Should DataFusion provide access to `information_schema` virtual tables for displaying schema information | +| datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | +| datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | +| datafusion.catalog.has_header | true | Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. | +| datafusion.catalog.newlines_in_values | false | Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. | +| datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | +| datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | +| datafusion.execution.collect_statistics | true | Should DataFusion collect statistics when first creating a table. Has no effect after the table is created. Applies to the default `ListingTableProvider` in DataFusion. Defaults to true. | +| datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system | +| datafusion.execution.time_zone | +00:00 | The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour | +| datafusion.execution.parquet.enable_page_index | true | (reading) If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | +| datafusion.execution.parquet.pruning | true | (reading) If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | +| datafusion.execution.parquet.skip_metadata | true | (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | +| datafusion.execution.parquet.metadata_size_hint | NULL | (reading) If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | +| datafusion.execution.parquet.pushdown_filters | false | (reading) If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". | +| datafusion.execution.parquet.reorder_filters | false | (reading) If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | +| datafusion.execution.parquet.schema_force_view_types | true | (reading) If true, parquet reader will read columns of `Utf8/Utf8Large` with `Utf8View`, and `Binary/BinaryLarge` with `BinaryView`. | +| datafusion.execution.parquet.binary_as_string | false | (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. | +| datafusion.execution.parquet.coerce_int96 | NULL | (reading) If true, parquet reader will read columns of physical type int96 as originating from a different resolution than nanosecond. This is useful for reading data from systems like Spark which stores microsecond resolution timestamps in an int96 allowing it to write values with a larger date range than 64-bit timestamps with nanosecond resolution. | +| datafusion.execution.parquet.bloom_filter_on_read | true | (reading) Use any available bloom filters when reading parquet files | +| datafusion.execution.parquet.max_predicate_cache_size | NULL | (reading) The maximum predicate cache size, in bytes. When `pushdown_filters` is enabled, sets the maximum memory used to cache the results of predicate evaluation between filter evaluation and output generation. Decreasing this value will reduce memory usage, but may increase IO and CPU usage. None means use the default parquet reader setting. 0 means no caching. | +| datafusion.execution.parquet.data_pagesize_limit | 1048576 | (writing) Sets best effort maximum size of data page in bytes | +| datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in bytes | +| datafusion.execution.parquet.writer_version | 1.0 | (writing) Sets parquet writer version valid values are "1.0" and "2.0" | +| datafusion.execution.parquet.skip_arrow_metadata | false | (writing) Skip encoding the embedded arrow metadata in the KV_meta This is analogous to the `ArrowWriterOptions::with_skip_arrow_metadata`. Refer to | +| datafusion.execution.parquet.compression | zstd(3) | (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. | +| datafusion.execution.parquet.dictionary_enabled | true | (writing) Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | (writing) Sets best effort maximum dictionary page size, in bytes | +| datafusion.execution.parquet.statistics_enabled | page | (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | +| datafusion.execution.parquet.created_by | datafusion version 50.2.0 | (writing) Sets "created by" property | +| datafusion.execution.parquet.column_index_truncate_length | 64 | (writing) Sets column index truncate length | +| datafusion.execution.parquet.statistics_truncate_length | 64 | (writing) Sets statistics truncate length. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.data_page_row_count_limit | 20000 | (writing) Sets best effort maximum number of rows in data page | +| datafusion.execution.parquet.encoding | NULL | (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_on_write | false | (writing) Write bloom filters for all columns when creating parquet files | +| datafusion.execution.parquet.bloom_filter_fpp | NULL | (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_ndv | NULL | (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.allow_single_file_parallelism | true | (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | +| datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | +| datafusion.execution.skip_physical_aggregate_schema_check | false | When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. This is used to workaround bugs in the planner that are now caught by the new schema verification step. | +| datafusion.execution.spill_compression | uncompressed | Sets the compression codec used when spilling data to disk. Since datafusion writes spill files using the Arrow IPC Stream format, only codecs supported by the Arrow IPC Stream Writer are allowed. Valid values are: uncompressed, lz4_frame, zstd. Note: lz4_frame offers faster (de)compression, but typically results in larger spill files. In contrast, zstd achieves higher compression ratios at the cost of slower (de)compression speed. | +| datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | +| datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. | +| datafusion.execution.meta_fetch_concurrency | 32 | Number of files to read in parallel when inferring schema and statistics | +| datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | +| datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | +| datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | +| datafusion.execution.listing_table_ignore_subdirectory | true | Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). | +| datafusion.execution.listing_table_factory_infer_partitions | true | Should a `ListingTable` created through the `ListingTableFactory` infer table partitions from Hive compliant directories. Defaults to true (partition columns are inferred and will be represented in the table schema). | +| datafusion.execution.enable_recursive_ctes | true | Should DataFusion support recursive CTEs | +| datafusion.execution.split_file_groups_by_statistics | false | Attempt to eliminate sorts by packing & sorting files with non-overlapping statistics into the same file groups. Currently experimental | +| datafusion.execution.keep_partition_by_columns | false | Should DataFusion keep the columns used for partition_by in the output RecordBatches | +| datafusion.execution.skip_partial_aggregation_probe_ratio_threshold | 0.8 | Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input | +| datafusion.execution.skip_partial_aggregation_probe_rows_threshold | 100000 | Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode | +| datafusion.execution.use_row_number_estimates_to_optimize_partitioning | false | Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans. if the source of statistics is accurate. We plan to make this the default in the future. | +| datafusion.execution.enforce_batch_size_in_joins | false | Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. | +| datafusion.execution.objectstore_writer_buffer_size | 10485760 | Size (bytes) of data buffer DataFusion uses when writing output files. This affects the size of the data chunks that are uploaded to remote object stores (e.g. AWS S3). If very large (>= 100 GiB) output files are being written, it may be necessary to increase this size to avoid errors from the remote end point. | +| datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | +| datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | +| datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | +| datafusion.optimizer.enable_window_limits | true | When set to true, the optimizer will attempt to push limit operations past window functions, if possible | +| datafusion.optimizer.enable_dynamic_filter_pushdown | true | When set to true attempts to push down dynamic filters generated by operators into the file scan phase. For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. This means that if we already have 10 timestamps in the year 2025 any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. | +| datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | +| datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | +| datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. | +| datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level | +| datafusion.optimizer.allow_symmetric_joins_without_pruning | true | Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. | +| datafusion.optimizer.repartition_file_scans | true | When set to `true`, datasource partitions will be repartitioned to achieve maximum parallelism. This applies to both in-memory partitions and FileSource's file groups (1 group is 1 partition). For FileSources, only Parquet and CSV formats are currently supported. If set to `true` for a FileSource, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false` for a FileSource, different files will be read in parallel, but repartitioning won't happen within a single file. If set to `true` for an in-memory source, all memtable's partitions will have their batches repartitioned evenly to the desired number of `target_partitions`. Repartitioning can change the total number of partitions and batches per partition, but does not slice the initial record tables provided to the MemTable on creation. | +| datafusion.optimizer.repartition_windows | true | Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level | +| datafusion.optimizer.repartition_sorts | true | Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. With this flag is enabled, plans in the form below `text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` would turn into the plan below which performs better in multithreaded environments `text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` | +| datafusion.optimizer.prefer_existing_sort | false | When true, DataFusion will opportunistically remove sorts when the data is already sorted, (i.e. setting `preserve_order` to true on `RepartitionExec` and using `SortPreservingMergeExec`) When false, DataFusion will maximize plan parallelism using `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. | +| datafusion.optimizer.skip_failed_rules | false | When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule. When set to false, any rules that produce errors will cause the query to fail | +| datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | +| datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | +| datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | +| datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | +| datafusion.optimizer.hash_join_single_partition_threshold_rows | 131072 | The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition | +| datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | +| datafusion.optimizer.prefer_existing_union | false | When set to true, the optimizer will not attempt to convert Union to Interleave | +| datafusion.optimizer.expand_views_at_output | false | When set to true, if the returned type is a view type then the output will be coerced to a non-view. Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. | +| datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | +| datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | +| datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | +| datafusion.explain.show_sizes | true | When set to true, the explain statement will print the partition sizes | +| datafusion.explain.show_schema | false | When set to true, the explain statement will print schema information | +| datafusion.explain.format | indent | Display format of explain. Default is "indent". When set to "tree", it will print the plan in a tree-rendered format. | +| datafusion.explain.tree_maximum_render_width | 240 | (format=tree only) Maximum total width of the rendered tree. When set to 0, the tree will have no width limit. | +| datafusion.sql_parser.parse_float_as_decimal | false | When set to true, SQL parser will parse float as decimal type | +| datafusion.sql_parser.enable_ident_normalization | true | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) | +| datafusion.sql_parser.enable_options_value_normalization | false | When set to true, SQL parser will normalize options value (convert value to lowercase). Note that this option is ignored and will be removed in the future. All case-insensitive values are normalized automatically. | +| datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks. | +| datafusion.sql_parser.support_varchar_with_length | true | If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion can not enforce such limits. | +| datafusion.sql_parser.map_string_types_to_utf8view | true | If true, string types (VARCHAR, CHAR, Text, and String) are mapped to `Utf8View` during SQL planning. If false, they are mapped to `Utf8`. Default is true. | +| datafusion.sql_parser.collect_spans | false | When set to true, the source locations relative to the original SQL query (i.e. [`Span`](https://docs.rs/sqlparser/latest/sqlparser/tokenizer/struct.Span.html)) will be collected and recorded in the logical plan nodes. | +| datafusion.sql_parser.recursion_limit | 50 | Specifies the recursion depth limit when parsing complex SQL Queries | +| datafusion.sql_parser.default_null_ordering | nulls_max | Specifies the default null ordering for query results. There are 4 options: - `nulls_max`: Nulls appear last in ascending order. - `nulls_min`: Nulls appear first in ascending order. - `nulls_first`: Nulls always be first in any order. - `nulls_last`: Nulls always be last in any order. By default, `nulls_max` is used to follow Postgres's behavior. postgres rule: | +| datafusion.format.safe | true | If set to `true` any formatting errors will be written to the output instead of being converted into a [`std::fmt::Error`] | +| datafusion.format.null | | Format string for nulls | +| datafusion.format.date_format | %Y-%m-%d | Date format for date arrays | +| datafusion.format.datetime_format | %Y-%m-%dT%H:%M:%S%.f | Format for DateTime arrays | +| datafusion.format.timestamp_format | %Y-%m-%dT%H:%M:%S%.f | Timestamp format for timestamp arrays | +| datafusion.format.timestamp_tz_format | NULL | Timestamp format for timestamp with timezone arrays. When `None`, ISO 8601 format is used. | +| datafusion.format.time_format | %H:%M:%S%.f | Time format for time arrays | +| datafusion.format.duration_format | pretty | Duration format. Can be either `"pretty"` or `"ISO8601"` | +| datafusion.format.types_info | false | Show types in visual representation batches | + +# Runtime Configuration Settings + +DataFusion runtime configurations can be set via SQL using the `SET` command. + +For example, to configure `datafusion.runtime.memory_limit`: + +```sql +SET datafusion.runtime.memory_limit = '2G'; +``` + +The following runtime configuration settings are available: + +| key | default | description | +| ------------------------------------------ | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| datafusion.runtime.max_temp_directory_size | 100G | Maximum temporary file directory size. Supports suffixes K (kilobytes), M (megabytes), and G (gigabytes). Example: '2G' for 2 gigabytes. | +| datafusion.runtime.memory_limit | NULL | Maximum memory limit for query execution. Supports suffixes K (kilobytes), M (megabytes), and G (gigabytes). Example: '2G' for 2 gigabytes. | +| datafusion.runtime.metadata_cache_limit | 50M | Maximum memory to use for file metadata cache such as Parquet metadata. Supports suffixes K (kilobytes), M (megabytes), and G (gigabytes). Example: '2G' for 2 gigabytes. | +| datafusion.runtime.temp_directory | NULL | The path to the temporary file directory. | + +# Tuning Guide + +## Short Queries + +By default DataFusion will attempt to maximize parallelism and use all cores -- +For example, if you have 32 cores, each plan will split the data into 32 +partitions. However, if your data is small, the overhead of splitting the data +to enable parallelization can dominate the actual computation. + +You can find out how many cores are being used via the [`EXPLAIN`] command and look +at the number of partitions in the plan. + +[`explain`]: sql/explain.md + +The `datafusion.optimizer.repartition_file_min_size` option controls the minimum file size the +[`ListingTable`] provider will attempt to repartition. However, this +does not apply to user defined data sources and only works when DataFusion has accurate statistics. + +If you know your data is small, you can set the `datafusion.execution.target_partitions` +option to a smaller number to reduce the overhead of repartitioning. For very small datasets (e.g. less +than 1MB), we recommend setting `target_partitions` to 1 to avoid repartitioning altogether. + +```sql +SET datafusion.execution.target_partitions = '1'; +``` + +[`listingtable`]: https://docs.rs/datafusion/latest/datafusion/datasource/listing/struct.ListingTable.html + +## Memory-limited Queries + +When executing a memory-consuming query under a tight memory limit, DataFusion +will spill intermediate results to disk. + +When the [`FairSpillPool`] is used, memory is divided evenly among partitions. +The higher the value of `datafusion.execution.target_partitions`, the less memory +is allocated to each partition, and the out-of-core execution path may trigger +more frequently, possibly slowing down execution. + +Additionally, while spilling, data is read back in `datafusion.execution.batch_size` size batches. +The larger this value, the fewer spilled sorted runs can be merged. Decreasing this setting +can help reduce the number of subsequent spills required. + +In conclusion, for queries under a very tight memory limit, it's recommended to +set `target_partitions` and `batch_size` to smaller values. + +```sql +-- Query still gets parallelized, but each partition will have more memory to use +SET datafusion.execution.target_partitions = 4; +-- Smaller than the default '8192', while still keep the benefit of vectorized execution +SET datafusion.execution.batch_size = 1024; +``` + +[`fairspillpool`]: https://docs.rs/datafusion/latest/datafusion/execution/memory_pool/struct.FairSpillPool.html diff --git a/docs/source/user-guide/crate-configuration.md b/docs/source/user-guide/crate-configuration.md index f4a1910f5f78f..eecf7f5bde6e1 100644 --- a/docs/source/user-guide/crate-configuration.md +++ b/docs/source/user-guide/crate-configuration.md @@ -19,18 +19,19 @@ # Crate Configuration -This section contains information on how to configure DataFusion in your Rust -project. See the [Configuration Settings] section for a list of options that -control DataFusion's behavior. +This section contains information on how to configure builds of DataFusion in +your Rust project. The [Configuration Settings] section lists options that +control additional aspects DataFusion's runtime behavior. [configuration settings]: configs.md -## Add latest non published DataFusion dependency +## Using the nightly DataFusion builds DataFusion changes are published to `crates.io` according to the [release schedule](https://github.com/apache/datafusion/blob/main/dev/release/README.md#release-process) -If you would like to test out DataFusion changes which are merged but not yet -published, Cargo supports adding dependency directly to GitHub branch: +If you would like to use or test versions of the DataFusion code which are +merged but not yet published, you can use Cargo's [support for adding +dependencies] directly to a GitHub branch: ```toml datafusion = { git = "https://github.com/apache/datafusion", branch = "main"} @@ -50,22 +51,58 @@ datafusion = { git = "https://github.com/apache/datafusion", branch = "main", de More on [Cargo dependencies](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#specifying-dependencies) -## Optimized Configuration +## Optimizing Builds -For an optimized build several steps are required. First, use the below in your `Cargo.toml`. It is -worth noting that using the settings in the `[profile.release]` section will significantly increase the build time. +Here are several suggestions to get the Rust compler to produce faster code when +compiling DataFusion. Note that these changes may increase compile time and +binary size. -```toml -[dependencies] -datafusion = { version = "22.0" } -tokio = { version = "^1.0", features = ["rt-multi-thread"] } -snmalloc-rs = "0.3" +### Generate Code with CPU Specific Instructions + +By default, the Rust compiler produces code that runs on a wide range of CPUs, +but may not take advantage of all the features of your specific CPU (such as +certain [SIMD instructions]). This is especially true for x86_64 CPUs, where the +default target is `x86_64-unknown-linux-gnu`, which only guarantees support for +the `SSE2` instruction set. DataFusion can benefit from the more advanced +instructions in the `AVX2` and `AVX512` to speed up operations like filtering, +aggregation, and joins. To tell the Rust compiler to use these instructions, set +the `RUSTFLAGS` environment variable to specify a more specific target CPU. +We recommend setting `target-cpu` or at least `avx2`, or preferably at least +`native` (whatever the current CPU is). For example, to build and run DataFusion +with optimizations for your current CPU: + +```shell +RUSTFLAGS='-C target-cpu=native' cargo run --release +``` + +[simd instructions]: https://en.wikipedia.org/wiki/SIMD + +### Enable Link Time Optimization / Single Codegen Unit + +You can potentially improve your performance by compiling DataFusion into a +single codegen unit which gives the Rust compiler more opportunity to optimize +across crate boundaries. To do so, modify your projects' `Cargo.toml` to include +`lto = true` and `codegen-units = 1` as shown below. Beware that using a single +codegen unit _significantly_ increases `--release` build times. + +```toml [profile.release] lto = true codegen-units = 1 ``` +### Alternate Allocator: `snmalloc` + +You can also use [snmalloc-rs](https://crates.io/crates/snmalloc-rs) crate as +the memory allocator for DataFusion to improve performance. To do so, add the +dependency to your `Cargo.toml` as shown below. + +```toml +[dependencies] +snmalloc-rs = "0.3" +``` + Then, in `main.rs.` update the memory allocator with the below after your imports: @@ -82,17 +119,10 @@ async fn main() -> datafusion::error::Result<()> { } ``` -Based on the instruction set architecture you are building on you will want to configure the `target-cpu` as well, ideally -with `native` or at least `avx2`. - -```shell -RUSTFLAGS='-C target-cpu=native' cargo run --release -``` - -## Enable backtraces +## Enable Backtraces -By default Datafusion returns errors as a plain message. There is option to enable more verbose details about the error, -like error backtrace. To enable a backtrace you need to add Datafusion `backtrace` feature to your `Cargo.toml` file: +By default, Datafusion returns errors as a plain text message. You can enable more verbose details about the error, +such as backtraces by enabling the `backtrace` feature to your `Cargo.toml` file like this: ```toml datafusion = { version = "31.0.0", features = ["backtrace"]} diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index 96be1bb9e2568..82f1eeb2823dc 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -50,13 +50,38 @@ use datafusion::prelude::*; Here is a minimal example showing the execution of a query using the DataFrame API. +Create DataFrame using macro API from in memory rows + ```rust use datafusion::prelude::*; use datafusion::error::Result; + +#[tokio::main] +async fn main() -> Result<()> { + // Create a new dataframe with in-memory data using macro + let df = dataframe!( + "a" => [1, 2, 3], + "b" => [true, true, false], + "c" => [Some("foo"), Some("bar"), None] + )?; + df.show().await?; + Ok(()) +} +``` + +Create DataFrame from file or in memory rows using standard API + +```rust +use datafusion::arrow::array::{Int32Array, RecordBatch, StringArray}; +use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::error::Result; use datafusion::functions_aggregate::expr_fn::min; +use datafusion::prelude::*; +use std::sync::Arc; #[tokio::main] async fn main() -> Result<()> { + // Read the data from a csv file let ctx = SessionContext::new(); let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; let df = df.filter(col("a").lt_eq(col("b")))? @@ -64,6 +89,22 @@ async fn main() -> Result<()> { .limit(0, Some(100))?; // Print results df.show().await?; + + // Create a new dataframe with in-memory data + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["foo", "bar", "baz"])), + ], + )?; + let df = ctx.read_batch(batch)?; + df.show().await?; + Ok(()) } ``` diff --git a/docs/source/user-guide/explain-usage.md b/docs/source/user-guide/explain-usage.md index d89ed5f0e7ea6..2288cae85dda5 100644 --- a/docs/source/user-guide/explain-usage.md +++ b/docs/source/user-guide/explain-usage.md @@ -40,7 +40,7 @@ Let's see how DataFusion runs a query that selects the top 5 watch lists for the site `http://domcheloveplanet.ru/`: ```sql -EXPLAIN SELECT "WatchID" AS wid, "hits.parquet"."ClientIP" AS ip +EXPLAIN FORMAT INDENT SELECT "WatchID" AS wid, "hits.parquet"."ClientIP" AS ip FROM 'hits.parquet' WHERE starts_with("URL", 'http://domcheloveplanet.ru/') ORDER BY wid ASC, ip DESC @@ -227,10 +227,10 @@ When predicate pushdown is enabled, `DataSourceExec` with `ParquetSource` gains - `page_index_rows_matched`: number of rows in pages that were tested by a page index filter, and passed - `page_index_rows_pruned`: number of rows in pages that were tested by a page index filter, and did not pass -- `row_groups_matched_bloom_filter`: number of rows in row groups that were tested by a Bloom Filter, and passed -- `row_groups_pruned_bloom_filter`: number of rows in row groups that were tested by a Bloom Filter, and did not pass -- `row_groups_matched_statistics`: number of rows in row groups that were tested by row group statistics (min and max value), and passed -- `row_groups_pruned_statistics`: number of rows in row groups that were tested by row group statistics (min and max value), and did not pass +- `row_groups_matched_bloom_filter`: number of row groups that were tested by a Bloom Filter, and passed +- `row_groups_pruned_bloom_filter`: number of row groups that were tested by a Bloom Filter, and did not pass +- `row_groups_matched_statistics`: number of row groups that were tested by row group statistics (min and max value), and passed +- `row_groups_pruned_statistics`: number of row groups that were tested by row group statistics (min and max value), and did not pass - `pushdown_rows_matched`: rows that were tested by any of the above filtered, and passed all of them (this should be minimum of `page_index_rows_matched`, `row_groups_pruned_bloom_filter`, and `row_groups_pruned_statistics`) - `pushdown_rows_pruned`: rows that were tested by any of the above filtered, and did not pass one of them (this should be sum of `page_index_rows_matched`, `row_groups_pruned_bloom_filter`, and `row_groups_pruned_statistics`) - `predicate_evaluation_errors`: number of times evaluating the filter expression failed (expected to be zero in normal operation) @@ -249,7 +249,7 @@ a separate core. Data crosses between cores only within certain operators such a You can read more about this in the [Partitioning Docs]. -[partitoning docs]: https://docs.rs/datafusion/latest/datafusion/physical_expr/enum.Partitioning.html +[partitioning docs]: https://docs.rs/datafusion/latest/datafusion/physical_expr/enum.Partitioning.html ## Example of an Aggregate Query @@ -268,7 +268,7 @@ LIMIT 10; We can again see the query plan by using `EXPLAIN`: ```sql -> EXPLAIN SELECT "UserID", COUNT(*) FROM 'hits.parquet' GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; +> EXPLAIN FORMAT INDENT SELECT "UserID", COUNT(*) FROM 'hits.parquet' GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | plan_type | plan | +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 03ab86eeb813a..56e4369a9b8b5 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -285,27 +285,29 @@ select log(-1), log(0), sqrt(-1); ## Aggregate Functions -| Syntax | Description | -| ----------------------------------------------------------------- | --------------------------------------------------------------------------------------- | -| avg(expr) | Сalculates the average value for `expr`. | -| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. | -| approx_median(expr) | Calculates an approximation of the median for `expr`. | -| approx_percentile_cont(expr, percentile) | Calculates an approximation of the specified `percentile` for `expr`. | -| approx_percentile_cont_with_weight(expr, weight_expr, percentile) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. | -| bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. | -| bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. | -| bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. | -| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. | -| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. | -| count(expr) | Returns the number of rows for `expr`. | -| count_distinct | Creates an expression to represent the count(distinct) aggregate function | -| cube(exprs) | Creates a grouping set for all combination of `exprs` | -| grouping_set(exprs) | Create a grouping set. | -| max(expr) | Finds the maximum value of `expr`. | -| median(expr) | Сalculates the median of `expr`. | -| min(expr) | Finds the minimum value of `expr`. | -| rollup(exprs) | Creates a grouping set for rollup sets. | -| sum(expr) | Сalculates the sum of `expr`. | +| Syntax | Description | +| ------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | +| avg(expr) | Сalculates the average value for `expr`. | +| avg_distinct(expr) | Creates an expression to represent the avg(distinct) aggregate function | +| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. | +| approx_median(expr) | Calculates an approximation of the median for `expr`. | +| approx_percentile_cont(expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr`. Optional `centroids` parameter controls accuracy (default: 100). | +| approx_percentile_cont_with_weight(expr, weight_expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. Optional `centroids` parameter controls accuracy (default: 100). | +| bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. | +| bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. | +| bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. | +| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. | +| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. | +| count(expr) | Returns the number of rows for `expr`. | +| count_distinct(expr) | Creates an expression to represent the count(distinct) aggregate function | +| cube(exprs) | Creates a grouping set for all combination of `exprs` | +| grouping_set(exprs) | Create a grouping set. | +| max(expr) | Finds the maximum value of `expr`. | +| median(expr) | Сalculates the median of `expr`. | +| min(expr) | Finds the minimum value of `expr`. | +| rollup(exprs) | Creates a grouping set for rollup sets. | +| sum(expr) | Сalculates the sum of `expr`. | +| sum_distinct(expr) | Creates an expression to represent the sum(distinct) aggregate function | ## Aggregate Function Builder diff --git a/docs/source/user-guide/features.md b/docs/source/user-guide/features.md index 1f73ce7eac113..967e81e681f50 100644 --- a/docs/source/user-guide/features.md +++ b/docs/source/user-guide/features.md @@ -43,7 +43,7 @@ - [x] Filter (`WHERE`) - [x] Filter post-aggregate (`HAVING`) - [x] Sorting (`ORDER BY`) -- [x] Limit (`LIMIT` +- [x] Limit (`LIMIT`) - [x] Aggregate (`GROUP BY`) - [x] cast /try_cast - [x] [`VALUES` lists](https://www.postgresql.org/docs/current/queries-values.html) @@ -93,7 +93,8 @@ - [x] Memory limits enforced - [x] Spilling (to disk) Sort - [x] Spilling (to disk) Grouping -- [ ] Spilling (to disk) Joins +- [x] Spilling (to disk) Sort Merge Join +- [ ] Spilling (to disk) Hash Join ## Data Sources diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 14d6ab177dc34..dc4825dc06dfb 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -40,9 +40,9 @@ Arrow](https://arrow.apache.org/). ## Features - Feature-rich [SQL support](https://datafusion.apache.org/user-guide/sql/index.html) and [DataFrame API](https://datafusion.apache.org/user-guide/dataframe.html) -- Blazingly fast, vectorized, multi-threaded, streaming execution engine. +- Blazingly fast, vectorized, multithreaded, streaming execution engine. - Native support for Parquet, CSV, JSON, and Avro file formats. Support - for custom file formats and non file datasources via the `TableProvider` trait. + for custom file formats and non-file datasources via the `TableProvider` trait. - Many extension points: user defined scalar/aggregate/window functions, DataSources, SQL, other query languages, custom plan and execution nodes, optimizer passes, and more. - Streaming, asynchronous IO directly from popular object stores, including AWS S3, @@ -68,25 +68,25 @@ DataFusion can be used without modification as an embedded SQL engine or can be customized and used as a foundation for building new systems. -While most current usecases are "analytic" or (throughput) some +While most current use cases are "analytic" or (throughput) some components of DataFusion such as the plan representations, are suitable for "streaming" and "transaction" style systems (low latency). Here are some example systems built using DataFusion: -- Specialized Analytical Database systems such as [HoraeDB] and more general Apache Spark like system such a [Ballista]. +- Specialized Analytical Database systems such as [HoraeDB] and more general Apache Spark like system such as [Ballista] - New query language engines such as [prql-query] and accelerators such as [VegaFusion] - Research platform for new Database Systems, such as [Flock] -- SQL support to another library, such as [dask sql] +- SQL support to another library, such as [Vortex] - Streaming data platforms such as [Synnada] - Tools for reading / sorting / transcoding Parquet, CSV, AVRO, and JSON files such as [qv] -- Native Spark runtime replacement such as [Blaze] +- Native Spark runtime replacement such as [Auron] By using DataFusion, projects are freed to focus on their specific features, and avoid reimplementing general (but still necessary) features such as an expression representation, standard optimizations, -parellelized streaming execution plans, file format support, etc. +parallelized streaming execution plans, file format support, etc. ## Known Users @@ -95,56 +95,67 @@ Here are some active projects using DataFusion: - [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust -- [Ballista](https://github.com/apache/datafusion-ballista) Distributed SQL Query Engine -- [Blaze](https://github.com/kwai/blaze) The Blaze accelerator for Apache Spark leverages native vectorized execution to accelerate query processing -- [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database +- [ArkFlow](https://github.com/arkflow-rs/arkflow) High-performance Rust stream processing engine +- [Auron] The Auron accelerator for big data engine (e.g., Spark, Flink) leverages native vectorized execution to accelerate query processing +- [Ballista] Distributed SQL Query Engine +- [CnosDB] Open Source Distributed Time Series Database - [Comet](https://github.com/apache/datafusion-comet) Apache Spark native query execution plugin -- [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) -- [Dask SQL](https://github.com/dask-contrib/dask-sql) Distributed SQL query engine in Python +- [Cube Store] Cube’s universal semantic layer platform is the next evolution of OLAP technology for AI, BI, spreadsheets, and embedded analytics - [datafusion-dft](https://github.com/datafusion-contrib/datafusion-dft) Batteries included CLI, TUI, and server implementations for DataFusion. -- [delta-rs](https://github.com/delta-io/delta-rs) Native Rust implementation of Delta Lake -- [Exon](https://github.com/wheretrue/exon) Analysis toolkit for life-science applications +- [dbt Fusion engine](https://github.com/dbt-labs/dbt-fusion) The dbt Fusion engine, written in Rust, designed for speed and correctness with a native SQL understanding across DWH SQL dialects. +- [delta-rs] Native Rust implementation of Delta Lake +- [EDB Postgres Lakehouse] built with [Seafowl] +- [Feldera](https://github.com/feldera/feldera) Fast query engine for incremental computation - [Funnel](https://funnel.io/) Data Platform powering Marketing Intelligence applications. - [GlareDB](https://github.com/GlareDB/glaredb) Fast SQL database for querying and analyzing distributed data. -- [GreptimeDB](https://github.com/GreptimeTeam/greptimedb) Open Source & Cloud Native Distributed Time Series Database -- [HoraeDB](https://github.com/apache/incubator-horaedb) Distributed Time-Series Database -- [InfluxDB](https://github.com/influxdata/influxdb) Time Series Database -- [Kamu](https://github.com/kamu-data/kamu-cli/) Planet-scale streaming data pipeline +- [GreptimeDB] Open Source & Cloud Native Distributed Time Series Database +- [HoraeDB] Distributed Time-Series Database +- [Iceberg-rust](https://github.com/apache/iceberg-rust) Rust implementation of Apache Iceberg +- [InfluxDB] Time Series Database +- [Kamu] Planet-scale streaming data pipeline - [LakeSoul](https://github.com/lakesoul-io/LakeSoul) Open source LakeHouse framework with native IO in Rust. - [Lance](https://github.com/lancedb/lance) Modern columnar data format for ML -- [OpenObserve](https://github.com/openobserve/openobserve) Distributed cloud native observability platform +- [OpenObserve] Distributed cloud native observability platform - [ParadeDB](https://github.com/paradedb/paradedb) PostgreSQL for Search & Analytics -- [Parseable](https://github.com/parseablehq/parseable) Log storage and observability platform +- [Parseable] Log storage and observability platform - [Polygon.io](https://polygon.io/) Stock Market API -- [qv](https://github.com/timvw/qv) Quickly view your data +- [qv] Quickly view your data +- [R2 Query Engine](https://blog.cloudflare.com/r2-sql-deep-dive/) Cloudflare's distributed engine for querying data in Iceberg Catalogs +- [rerun.io](https://rerun.io/) Visualize and query robotics logs and transform them into training data. - [Restate](https://github.com/restatedev) Easily build resilient applications using distributed durable async/await -- [ROAPI](https://github.com/roapi/roapi) -- [Sail](https://github.com/lakehq/sail) Unifying stream, batch, and AI workloads with Apache Spark compatibility -- [Seafowl](https://github.com/splitgraph/seafowl) CDN-friendly analytical database +- [ROAPI] Create full-fledged APIs for slowly moving datasets without writing a single line of code +- [Sail](https://github.com/lakehq/sail) Unifying stream, batch and AI workloads with Apache Spark compatibility +- [SedonaDB](https://github.com/apache/sedona-db) A single-node analytical database engine with geospatial as a first-class citizen - [Sleeper](https://github.com/gchq/sleeper) Serverless, cloud-native, log-structured merge tree based, scalable key-value store -- [Spice.ai](https://github.com/spiceai/spiceai) Unified SQL query interface & materialization engine -- [Synnada](https://synnada.ai/) Streaming-first framework for data products -- [VegaFusion](https://vegafusion.io/) Server-side acceleration for the [Vega](https://vega.github.io/) visualization grammar +- [Spice.ai] Building blocks for data-driven AI applications +- [Synnada] Streaming-first framework for data products +- [VegaFusion] Server-side acceleration for the [Vega](https://vega.github.io/) visualization grammar +- [Vortex] An extensible, state of the art columnar file format - [Telemetry](https://telemetry.sh/) Structured logging made easy +- [Xorq](https://github.com/xorq-labs/xorq/) Xorq is a multi-engine batch transformation framework built on Ibis, DataFusion and Arrow Here are some less active projects that used DataFusion: - [bdt](https://github.com/datafusion-contrib/bdt) Boring Data Tool -- [Cloudfuse Buzz](https://github.com/cloudfuse-io/buzz-rust) -- [Flock](https://github.com/flock-lab/flock) -- [Tensorbase](https://github.com/tensorbase/tensorbase) +- [Cloudfuse Buzz] +- [Dask SQL] Distributed SQL query engine in Python +- [Exon] Analysis toolkit for life-science applications +- [Flock] +- [Tensorbase] [ballista]: https://github.com/apache/datafusion-ballista -[blaze]: https://github.com/blaze-init/blaze +[auron]: https://github.com/apache/auron [cloudfuse buzz]: https://github.com/cloudfuse-io/buzz-rust [cnosdb]: https://github.com/cnosdb/cnosdb [cube store]: https://github.com/cube-js/cube.js/tree/master/rust [dask sql]: https://github.com/dask-contrib/dask-sql [datafusion-tui]: https://github.com/datafusion-contrib/datafusion-tui [delta-rs]: https://github.com/delta-io/delta-rs +[edb postgres lakehouse]: https://www.enterprisedb.com/products/analytics +[exon]: https://github.com/wheretrue/exon [flock]: https://github.com/flock-lab/flock [kamu]: https://github.com/kamu-data/kamu-cli -[greptime db]: https://github.com/GreptimeTeam/greptimedb +[greptimedb]: https://github.com/GreptimeTeam/greptimedb [horaedb]: https://github.com/apache/incubator-horaedb [influxdb]: https://github.com/influxdata/influxdb [openobserve]: https://github.com/openobserve/openobserve @@ -156,7 +167,8 @@ Here are some less active projects that used DataFusion: [spice.ai]: https://github.com/spiceai/spiceai [synnada]: https://synnada.ai/ [tensorbase]: https://github.com/tensorbase/tensorbase -[vegafusion]: https://vegafusion.io/ "if you know of another project, please submit a PR to add a link!" +[vegafusion]: https://vegafusion.io/ +[vortex]: https://vortex.dev/ "if you know of another project, please submit a PR to add a link!" ## Integrations and Extensions @@ -179,6 +191,20 @@ provide integrations with other systems, some of which are described below: ## Why DataFusion? - _High Performance_: Leveraging Rust and Arrow's memory model, DataFusion is very fast. -- _Easy to Connect_: Being part of the Apache Arrow ecosystem (Arrow, Parquet and Flight), DataFusion works well with the rest of the big data ecosystem +- _Easy to Connect_: Being part of the Apache Arrow ecosystem (Arrow, Parquet, and Flight), DataFusion works well with the rest of the big data ecosystem - _Easy to Embed_: Allowing extension at almost any point in its design, and published regularly as a crate on [crates.io](http://crates.io), DataFusion can be integrated and tailored for your specific usecase. - _High Quality_: Extensively tested, both by itself and with the rest of the Arrow ecosystem, DataFusion can and is used as the foundation for production systems. + +## Rust Version Compatibility Policy + +The Rust toolchain releases are tracked at [Rust Versions](https://releases.rs) and follow +[semantic versioning](https://semver.org/). A Rust toolchain release can be identified +by a version string like `1.80.0`, or more generally `major.minor.patch`. + +DataFusion supports the last 4 stable Rust minor versions released and any such versions released within the last 4 months. + +For example, given the releases `1.78.0`, `1.79.0`, `1.80.0`, `1.80.1` and `1.81.0` DataFusion will support 1.78.0, which is 3 minor versions prior to the most minor recent `1.81`. + +Note: If a Rust hotfix is released for the current MSRV, the MSRV will be updated to the specific minor version that includes all applicable hotfixes preceding other policies. + +DataFusion enforces MSRV policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index c7f5c5f674424..205962031b1d0 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -29,6 +29,25 @@ dev/update_function_docs.sh file for updating surrounding text. Aggregate functions operate on a set of values to compute a single result. +## Filter clause + +Aggregate functions support the SQL `FILTER (WHERE ...)` clause to restrict which input rows contribute to the aggregate result. + +```sql +function([exprs]) FILTER (WHERE condition) +``` + +Example: + +```sql +SELECT + sum(salary) FILTER (WHERE salary > 0) AS sum_positive_salaries, + count(*) FILTER (WHERE active) AS active_count +FROM employees; +``` + +Note: When no rows pass the filter, `COUNT` returns `0` while `SUM`/`AVG`/`MIN`/`MAX` return `NULL`. + ## General Functions - [array_agg](#array_agg) @@ -371,10 +390,10 @@ min(expression) ### `string_agg` -Concatenates the values of string expressions and places separator values between them. +Concatenates the values of string expressions and places separator values between them. If ordering is required, strings are concatenated in the specified order. This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression. ```sql -string_agg(expression, delimiter) +string_agg([DISTINCT] expression, delimiter [ORDER BY expression]) ``` #### Arguments @@ -390,7 +409,21 @@ string_agg(expression, delimiter) +--------------------------+ | names_list | +--------------------------+ -| Alice, Bob, Charlie | +| Alice, Bob, Bob, Charlie | ++--------------------------+ +> SELECT string_agg(name, ', ' ORDER BY name DESC) AS names_list + FROM employee; ++--------------------------+ +| names_list | ++--------------------------+ +| Charlie, Bob, Bob, Alice | ++--------------------------+ +> SELECT string_agg(DISTINCT name, ', ' ORDER BY name DESC) AS names_list + FROM employee; ++--------------------------+ +| names_list | ++--------------------------+ +| Charlie, Bob, Alice | +--------------------------+ ``` @@ -604,6 +637,29 @@ regr_avgx(expression_y, expression_x) - **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +create table daily_sales(day int, total_sales int) as values (1,100), (2,150), (3,200), (4,NULL), (5,250); +select * from daily_sales; ++-----+-------------+ +| day | total_sales | +| --- | ----------- | +| 1 | 100 | +| 2 | 150 | +| 3 | 200 | +| 4 | NULL | +| 5 | 250 | ++-----+-------------+ + +SELECT regr_avgx(total_sales, day) AS avg_day FROM daily_sales; ++----------+ +| avg_day | ++----------+ +| 2.75 | ++----------+ +``` + ### `regr_avgy` Computes the average of the dependent variable (output) expression_y for the non-null paired data points. @@ -617,6 +673,30 @@ regr_avgy(expression_y, expression_x) - **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +create table daily_temperature(day int, temperature int) as values (1,30), (2,32), (3, NULL), (4,35), (5,36); +select * from daily_temperature; ++-----+-------------+ +| day | temperature | +| --- | ----------- | +| 1 | 30 | +| 2 | 32 | +| 3 | NULL | +| 4 | 35 | +| 5 | 36 | ++-----+-------------+ + +-- temperature as Dependent Variable(Y), day as Independent Variable(X) +SELECT regr_avgy(temperature, day) AS avg_temperature FROM daily_temperature; ++-----------------+ +| avg_temperature | ++-----------------+ +| 33.25 | ++-----------------+ +``` + ### `regr_count` Counts the number of non-null paired data points. @@ -630,6 +710,29 @@ regr_count(expression_y, expression_x) - **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +create table daily_metrics(day int, user_signups int) as values (1,100), (2,120), (3, NULL), (4,110), (5,NULL); +select * from daily_metrics; ++-----+---------------+ +| day | user_signups | +| --- | ------------- | +| 1 | 100 | +| 2 | 120 | +| 3 | NULL | +| 4 | 110 | +| 5 | NULL | ++-----+---------------+ + +SELECT regr_count(user_signups, day) AS valid_pairs FROM daily_metrics; ++-------------+ +| valid_pairs | ++-------------+ +| 3 | ++-------------+ +``` + ### `regr_intercept` Computes the y-intercept of the linear regression line. For the equation (y = kx + b), this function returns b. @@ -643,6 +746,30 @@ regr_intercept(expression_y, expression_x) - **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +create table weekly_performance(week int, productivity_score int) as values (1,60), (2,65), (3, 70), (4,75), (5,80); +select * from weekly_performance; ++------+---------------------+ +| week | productivity_score | +| ---- | ------------------- | +| 1 | 60 | +| 2 | 65 | +| 3 | 70 | +| 4 | 75 | +| 5 | 80 | ++------+---------------------+ + +SELECT regr_intercept(productivity_score, week) AS intercept FROM weekly_performance; ++----------+ +|intercept| +|intercept | ++----------+ +| 55 | ++----------+ +``` + ### `regr_r2` Computes the square of the correlation coefficient between the independent and dependent variables. @@ -656,6 +783,29 @@ regr_r2(expression_y, expression_x) - **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +create table weekly_performance(day int ,user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80); +select * from weekly_performance; ++-----+--------------+ +| day | user_signups | ++-----+--------------+ +| 1 | 60 | +| 2 | 65 | +| 3 | 70 | +| 4 | 75 | +| 5 | 80 | ++-----+--------------+ + +SELECT regr_r2(user_signups, day) AS r_squared FROM weekly_performance; ++---------+ +|r_squared| ++---------+ +| 1.0 | ++---------+ +``` + ### `regr_slope` Returns the slope of the linear regression line for non-null pairs in aggregate columns. Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k\*X + b) using minimal RSS fitting. @@ -669,6 +819,29 @@ regr_slope(expression_y, expression_x) - **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +create table weekly_performance(day int, user_signups int) as values (1,60), (2,65), (3, 70), (4,75), (5,80); +select * from weekly_performance; ++-----+--------------+ +| day | user_signups | ++-----+--------------+ +| 1 | 60 | +| 2 | 65 | +| 3 | 70 | +| 4 | 75 | +| 5 | 80 | ++-----+--------------+ + +SELECT regr_slope(user_signups, day) AS slope FROM weekly_performance; ++--------+ +| slope | ++--------+ +| 5.0 | ++--------+ +``` + ### `regr_sxx` Computes the sum of squares of the independent variable. @@ -682,6 +855,29 @@ regr_sxx(expression_y, expression_x) - **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +create table study_hours(student_id int, hours int, test_score int) as values (1,2,55), (2,4,65), (3,6,75), (4,8,85), (5,10,95); +select * from study_hours; ++------------+-------+------------+ +| student_id | hours | test_score | ++------------+-------+------------+ +| 1 | 2 | 55 | +| 2 | 4 | 65 | +| 3 | 6 | 75 | +| 4 | 8 | 85 | +| 5 | 10 | 95 | ++------------+-------+------------+ + +SELECT regr_sxx(test_score, hours) AS sxx FROM study_hours; ++------+ +| sxx | ++------+ +| 40.0 | ++------+ +``` + ### `regr_sxy` Computes the sum of products of paired data points. @@ -695,6 +891,27 @@ regr_sxy(expression_y, expression_x) - **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +create table employee_productivity(week int, productivity_score int) as values(1,60), (2,65), (3,70); +select * from employee_productivity; ++------+--------------------+ +| week | productivity_score | ++------+--------------------+ +| 1 | 60 | +| 2 | 65 | +| 3 | 70 | ++------+--------------------+ + +SELECT regr_sxy(productivity_score, week) AS sum_product_deviations FROM employee_productivity; ++------------------------+ +| sum_product_deviations | ++------------------------+ +| 10.0 | ++------------------------+ +``` + ### `regr_syy` Computes the sum of squares of the dependent variable. @@ -708,6 +925,27 @@ regr_syy(expression_y, expression_x) - **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. - **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +create table employee_productivity(week int, productivity_score int) as values (1,60), (2,65), (3,70); +select * from employee_productivity; ++------+--------------------+ +| week | productivity_score | ++------+--------------------+ +| 1 | 60 | +| 2 | 65 | +| 3 | 70 | ++------+--------------------+ + +SELECT regr_syy(productivity_score, week) AS sum_squares_y FROM employee_productivity; ++---------------+ +| sum_squares_y | ++---------------+ +| 50.0 | ++---------------+ +``` + ### `stddev` Returns the standard deviation of a set of numbers. @@ -794,7 +1032,7 @@ approx_distinct(expression) ### `approx_median` -Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`. +Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY x)`. ```sql approx_median(expression) @@ -820,7 +1058,7 @@ approx_median(expression) Returns the approximate percentile of input values using the t-digest algorithm. ```sql -approx_percentile_cont(expression, percentile, centroids) +approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression) ``` #### Arguments @@ -832,12 +1070,36 @@ approx_percentile_cont(expression, percentile, centroids) #### Example ```sql +> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++------------------------------------------------------------------+ +| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) | ++------------------------------------------------------------------+ +| 65.0 | ++------------------------------------------------------------------+ +> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++-----------------------------------------------------------------------+ +| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) | ++-----------------------------------------------------------------------+ +| 65.0 | ++-----------------------------------------------------------------------+ +``` + +An alternate syntax is also supported: + +```sql +> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name; ++-----------------------------------------------+ +| approx_percentile_cont(column_name, 0.75) | ++-----------------------------------------------+ +| 65.0 | ++-----------------------------------------------+ + > SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; -+-------------------------------------------------+ -| approx_percentile_cont(column_name, 0.75, 100) | -+-------------------------------------------------+ -| 65.0 | -+-------------------------------------------------+ ++----------------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++----------------------------------------------------------+ +| 65.0 | ++----------------------------------------------------------+ ``` ### `approx_percentile_cont_with_weight` @@ -845,7 +1107,7 @@ approx_percentile_cont(expression, percentile, centroids) Returns the weighted approximate percentile of input values using the t-digest algorithm. ```sql -approx_percentile_cont_with_weight(expression, weight, percentile) +approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression) ``` #### Arguments @@ -853,14 +1115,32 @@ approx_percentile_cont_with_weight(expression, weight, percentile) - **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - **weight**: Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators. - **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive). +- **centroids**: Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory. #### Example +```sql +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++---------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) | ++---------------------------------------------------------------------------------------------+ +| 78.5 | ++---------------------------------------------------------------------------------------------+ +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++--------------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) | ++--------------------------------------------------------------------------------------------------+ +| 78.5 | ++--------------------------------------------------------------------------------------------------+ +``` + +An alternative syntax is also supported: + ```sql > SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; -+----------------------------------------------------------------------+ ++--------------------------------------------------+ | approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | -+----------------------------------------------------------------------+ -| 78.5 | -+----------------------------------------------------------------------+ ++--------------------------------------------------+ +| 78.5 | ++--------------------------------------------------+ ``` diff --git a/docs/source/user-guide/sql/data_types.md b/docs/source/user-guide/sql/data_types.md index 18c95cdea70ed..d977a4396e40d 100644 --- a/docs/source/user-guide/sql/data_types.md +++ b/docs/source/user-guide/sql/data_types.md @@ -60,20 +60,20 @@ select arrow_cast(now(), 'Timestamp(Second, None)'); ## Numeric Types -| SQL DataType | Arrow DataType | Notes | -| ------------------------------------ | :----------------------------- | ----------------------------------------------------------------------------------------------------- | -| `TINYINT` | `Int8` | | -| `SMALLINT` | `Int16` | | -| `INT` or `INTEGER` | `Int32` | | -| `BIGINT` | `Int64` | | -| `TINYINT UNSIGNED` | `UInt8` | | -| `SMALLINT UNSIGNED` | `UInt16` | | -| `INT UNSIGNED` or `INTEGER UNSIGNED` | `UInt32` | | -| `BIGINT UNSIGNED` | `UInt64` | | -| `FLOAT` | `Float32` | | -| `REAL` | `Float32` | | -| `DOUBLE` | `Float64` | | -| `DECIMAL(precision, scale)` | `Decimal128(precision, scale)` | Decimal support is currently experimental ([#3523](https://github.com/apache/datafusion/issues/3523)) | +| SQL DataType | Arrow DataType | +| ------------------------------------ | :----------------------------- | +| `TINYINT` | `Int8` | +| `SMALLINT` | `Int16` | +| `INT` or `INTEGER` | `Int32` | +| `BIGINT` | `Int64` | +| `TINYINT UNSIGNED` | `UInt8` | +| `SMALLINT UNSIGNED` | `UInt16` | +| `INT UNSIGNED` or `INTEGER UNSIGNED` | `UInt32` | +| `BIGINT UNSIGNED` | `UInt64` | +| `FLOAT` | `Float32` | +| `REAL` | `Float32` | +| `DOUBLE` | `Float64` | +| `DECIMAL(precision, scale)` | `Decimal128(precision, scale)` | ## Date/Time Types diff --git a/docs/source/user-guide/sql/ddl.md b/docs/source/user-guide/sql/ddl.md index 71475cff9a39b..bd41f691bf90b 100644 --- a/docs/source/user-guide/sql/ddl.md +++ b/docs/source/user-guide/sql/ddl.md @@ -74,7 +74,7 @@ LOCATION := ( , ...) ``` -For a detailed list of write related options which can be passed in the OPTIONS key_value_list, see [Write Options](write_options). +For a comprehensive list of format-specific options that can be specified in the `OPTIONS` clause, see [Format Options](format_options.md). `file_type` is one of `CSV`, `ARROW`, `PARQUET`, `AVRO` or `JSON` @@ -82,6 +82,8 @@ For a detailed list of write related options which can be passed in the OPTIONS a path to a file or directory of partitioned files locally or on an object store. +### Example: Parquet + Parquet data sources can be registered by executing a `CREATE EXTERNAL TABLE` SQL statement such as the following. It is not necessary to provide schema information for Parquet files. @@ -91,6 +93,23 @@ STORED AS PARQUET LOCATION '/mnt/nyctaxi/tripdata.parquet'; ``` +:::{note} +Statistics +: By default, when a table is created, DataFusion will read the files +to gather statistics, which can be expensive but can accelerate subsequent +queries substantially. If you don't want to gather statistics +when creating a table, set the `datafusion.execution.collect_statistics` +configuration option to `false` before creating the table. For example: + +```sql +SET datafusion.execution.collect_statistics = false; +``` + +See the [config settings docs](../configs.md) for more details. +::: + +### Example: Comma Separated Value (CSV) + CSV data sources can also be registered by executing a `CREATE EXTERNAL TABLE` SQL statement. The schema will be inferred based on scanning a subset of the file. @@ -101,6 +120,8 @@ LOCATION '/path/to/aggregate_simple.csv' OPTIONS ('has_header' 'true'); ``` +### Example: Compression + It is also possible to use compressed files, such as `.csv.gz`: ```sql @@ -111,6 +132,8 @@ LOCATION '/path/to/aggregate_simple.csv.gz' OPTIONS ('has_header' 'true'); ``` +### Example: Specifying Schema + It is also possible to specify the schema manually. ```sql @@ -134,6 +157,8 @@ LOCATION '/path/to/aggregate_test_100.csv' OPTIONS ('has_header' 'true'); ``` +### Example: Partitioned Tables + It is also possible to specify a directory that contains a partitioned table (multiple files with the same schema) @@ -144,7 +169,38 @@ LOCATION '/path/to/directory/of/files' OPTIONS ('has_header' 'true'); ``` -With `CREATE UNBOUNDED EXTERNAL TABLE` SQL statement. We can create unbounded data sources such as following: +Tables that are partitioned using a Hive compliant partitioning scheme will have their columns and values automatically +detected and incorporated into the table's schema and data. Given the following example directory structure: + +```console +hive_partitioned/ +├── a=1 +│   └── b=200 +│   └── file1.parquet +└── a=2 + └── b=100 + └── file2.parquet +``` + +Users can specify the top level `hive_partitioned` directory as an `EXTERNAL TABLE` and leverage the Hive partitions to query +and filter data. + +```sql +CREATE EXTERNAL TABLE hive_partitioned +STORED AS PARQUET +LOCATION '/path/to/hive_partitioned/'; + +SELECT count(*) FROM hive_partitioned WHERE b=100; ++------------------+ +| count(*) | ++------------------+ +| 1 | ++------------------+ +``` + +### Example: Unbounded Data Sources + +We can create unbounded data sources using the `CREATE UNBOUNDED EXTERNAL TABLE` SQL statement. ```sql CREATE UNBOUNDED EXTERNAL TABLE taxi @@ -154,6 +210,8 @@ LOCATION '/mnt/nyctaxi/tripdata.parquet'; Note that this statement actually reads data from a fixed-size file, so a better example would involve reading from a FIFO file. Nevertheless, once Datafusion sees the `UNBOUNDED` keyword in a data source, it tries to execute queries that refer to this unbounded source in streaming fashion. If this is not possible according to query specifications, plan generation fails stating it is not possible to execute given query in streaming fashion. Note that queries that can run with unbounded sources (i.e. in streaming mode) are a subset of those that can with bounded sources. A query that fails with unbounded source(s) may work with bounded source(s). +### Example: `WITH ORDER` Clause + When creating an output from a data source that is already ordered by an expression, you can pre-specify the order of the data using the `WITH ORDER` clause. This applies even if the expression used for @@ -178,7 +236,7 @@ CREATE EXTERNAL TABLE test ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH ORDER (c2 ASC, c5 + c8 DESC NULL FIRST) +WITH ORDER (c2 ASC, c5 + c8 DESC NULLS FIRST) LOCATION '/path/to/aggregate_test_100.csv' OPTIONS ('has_header' 'true'); ``` @@ -190,7 +248,7 @@ WITH ORDER (sort_expression1 [ASC | DESC] [NULLS { FIRST | LAST }] [, sort_expression2 [ASC | DESC] [NULLS { FIRST | LAST }] ...]) ``` -### Cautions when using the WITH ORDER Clause +#### Cautions when using the WITH ORDER Clause - It's important to understand that using the `WITH ORDER` clause in the `CREATE EXTERNAL TABLE` statement only specifies the order in which the data should be read from the external file. If the data in the file is not already sorted according to the specified order, then the results may not be correct. @@ -287,3 +345,78 @@ DROP VIEW [ IF EXISTS ] view_name; -- drop users_v view from the customer_a schema DROP VIEW IF EXISTS customer_a.users_v; ``` + +## DESCRIBE + +Displays the schema of a table, showing column names, data types, and nullable status. Both `DESCRIBE` and `DESC` are supported as aliases. + +

+{ DESCRIBE | DESC } table_name
+
+ +The output contains three columns: + +- `column_name`: The name of the column +- `data_type`: The data type of the column (e.g., Int32, Utf8, Boolean) +- `is_nullable`: Whether the column can contain null values (YES/NO) + +### Example: Basic table description + +```sql +-- Create a table +CREATE TABLE users AS VALUES (1, 'Alice', true), (2, 'Bob', false); + +-- Describe the table structure +DESCRIBE users; +``` + +Output: + +```sql ++--------------+-----------+-------------+ +| column_name | data_type | is_nullable | ++--------------+-----------+-------------+ +| column1 | Int64 | YES | +| column2 | Utf8 | YES | +| column3 | Boolean | YES | ++--------------+-----------+-------------+ +``` + +### Example: Using DESC alias + +```sql +-- DESC is an alias for DESCRIBE +DESC users; +``` + +### Example: Describing external tables + +```sql +-- Create an external table +CREATE EXTERNAL TABLE taxi +STORED AS PARQUET +LOCATION '/mnt/nyctaxi/tripdata.parquet'; + +-- Describe its schema +DESCRIBE taxi; +``` + +Output might show: + +```sql ++--------------------+-----------------------------+-------------+ +| column_name | data_type | is_nullable | ++--------------------+-----------------------------+-------------+ +| vendor_id | Int32 | YES | +| pickup_datetime | Timestamp(Nanosecond, None) | NO | +| passenger_count | Int32 | YES | +| trip_distance | Float64 | YES | ++--------------------+-----------------------------+-------------+ +``` + +The `DESCRIBE` command works with all table types in DataFusion, including: + +- Regular tables created with `CREATE TABLE` +- External tables created with `CREATE EXTERNAL TABLE` +- Views created with `CREATE VIEW` +- Tables in different schemas using qualified names (e.g., `DESCRIBE schema_name.table_name`) diff --git a/docs/source/user-guide/sql/dml.md b/docs/source/user-guide/sql/dml.md index 4eda59d6dea10..c29447f23cd9c 100644 --- a/docs/source/user-guide/sql/dml.md +++ b/docs/source/user-guide/sql/dml.md @@ -49,7 +49,7 @@ The output format is determined by the first match of the following rules: 1. Value of `STORED AS` 2. Filename extension (e.g. `foo.parquet` implies `PARQUET` format) -For a detailed list of valid OPTIONS, see [Write Options](write_options). +For a detailed list of valid OPTIONS, see [Format Options](format_options.md). ### Examples diff --git a/docs/source/user-guide/sql/explain.md b/docs/source/user-guide/sql/explain.md index f89e854ebffd5..c5e2e215a6b66 100644 --- a/docs/source/user-guide/sql/explain.md +++ b/docs/source/user-guide/sql/explain.md @@ -39,39 +39,7 @@ the format from the [configuration value] `datafusion.explain.format`. [configuration value]: ../configs.md -### `indent` format (default) - -The `indent` format shows both the logical and physical plan, with one line for -each operator in the plan. Child plans are indented to show the hierarchy. - -See [Reading Explain Plans](../explain-usage.md) for more information on how to interpret these plans. - -```sql -> CREATE TABLE t(x int, b int) AS VALUES (1, 2), (2, 3); -0 row(s) fetched. -Elapsed 0.004 seconds. - -> EXPLAIN SELECT SUM(x) FROM t GROUP BY b; -+---------------+-------------------------------------------------------------------------------+ -| plan_type | plan | -+---------------+-------------------------------------------------------------------------------+ -| logical_plan | Projection: sum(t.x) | -| | Aggregate: groupBy=[[t.b]], aggr=[[sum(CAST(t.x AS Int64))]] | -| | TableScan: t projection=[x, b] | -| physical_plan | ProjectionExec: expr=[sum(t.x)@1 as sum(t.x)] | -| | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[sum(t.x)] | -| | CoalesceBatchesExec: target_batch_size=8192 | -| | RepartitionExec: partitioning=Hash([b@0], 16), input_partitions=16 | -| | RepartitionExec: partitioning=RoundRobinBatch(16), input_partitions=1 | -| | AggregateExec: mode=Partial, gby=[b@1 as b], aggr=[sum(t.x)] | -| | DataSourceExec: partitions=1, partition_sizes=[1] | -| | | -+---------------+-------------------------------------------------------------------------------+ -2 row(s) fetched. -Elapsed 0.004 seconds. -``` - -### `tree` format +### `tree` format (default) The `tree` format is modeled after [DuckDB plans] and is designed to be easier to see the high level structure of the plan @@ -103,7 +71,7 @@ to see the high level structure of the plan | | ┌─────────────┴─────────────┐ | | | │ RepartitionExec │ | | | │ -------------------- │ | -| | │ output_partition_count: │ | +| | │ input_partition_count: │ | | | │ 16 │ | | | │ │ | | | │ partitioning_scheme: │ | @@ -112,7 +80,7 @@ to see the high level structure of the plan | | ┌─────────────┴─────────────┐ | | | │ RepartitionExec │ | | | │ -------------------- │ | -| | │ output_partition_count: │ | +| | │ input_partition_count: │ | | | │ 1 │ | | | │ │ | | | │ partitioning_scheme: │ | @@ -138,6 +106,38 @@ to see the high level structure of the plan Elapsed 0.016 seconds. ``` +### `indent` format + +The `indent` format shows both the logical and physical plan, with one line for +each operator in the plan. Child plans are indented to show the hierarchy. + +See [Reading Explain Plans](../explain-usage.md) for more information on how to interpret these plans. + +```sql +> CREATE TABLE t(x int, b int) AS VALUES (1, 2), (2, 3); +0 row(s) fetched. +Elapsed 0.004 seconds. + +> EXPLAIN FORMAT INDENT SELECT SUM(x) FROM t GROUP BY b; ++---------------+-------------------------------------------------------------------------------+ +| plan_type | plan | ++---------------+-------------------------------------------------------------------------------+ +| logical_plan | Projection: sum(t.x) | +| | Aggregate: groupBy=[[t.b]], aggr=[[sum(CAST(t.x AS Int64))]] | +| | TableScan: t projection=[x, b] | +| physical_plan | ProjectionExec: expr=[sum(t.x)@1 as sum(t.x)] | +| | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[sum(t.x)] | +| | CoalesceBatchesExec: target_batch_size=8192 | +| | RepartitionExec: partitioning=Hash([b@0], 16), input_partitions=16 | +| | RepartitionExec: partitioning=RoundRobinBatch(16), input_partitions=1 | +| | AggregateExec: mode=Partial, gby=[b@1 as b], aggr=[sum(t.x)] | +| | DataSourceExec: partitions=1, partition_sizes=[1] | +| | | ++---------------+-------------------------------------------------------------------------------+ +2 row(s) fetched. +Elapsed 0.004 seconds. +``` + ### `pgjson` format The `pgjson` format is modeled after [Postgres JSON] format. diff --git a/docs/source/user-guide/sql/format_options.md b/docs/source/user-guide/sql/format_options.md new file mode 100644 index 0000000000000..e8008eafb166c --- /dev/null +++ b/docs/source/user-guide/sql/format_options.md @@ -0,0 +1,180 @@ + + +# Format Options + +DataFusion supports customizing how data is read from or written to disk as a result of a `COPY`, `INSERT INTO`, or `CREATE EXTERNAL TABLE` statements. There are a few special options, file format (e.g., CSV or Parquet) specific options, and Parquet column-specific options. In some cases, Options can be specified in multiple ways with a set order of precedence. + +## Specifying Options and Order of Precedence + +Format-related options can be specified in three ways, in decreasing order of precedence: + +- `CREATE EXTERNAL TABLE` syntax +- `COPY` option tuples +- Session-level config defaults + +For a list of supported session-level config defaults, see [Configuration Settings](../configs). These defaults apply to all operations but have the lowest level of precedence. + +If creating an external table, table-specific format options can be specified when the table is created using the `OPTIONS` clause: + +```sql +CREATE EXTERNAL TABLE + my_table(a bigint, b bigint) + STORED AS csv + LOCATION '/tmp/my_csv_table/' + OPTIONS( + NULL_VALUE 'NAN', + 'has_header' 'true', + 'format.delimiter' ';' + ); +``` + +When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (e.g., gzip compression, special delimiter, and header row included). Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified, the `OPTIONS` setting will be ignored. + +For example, with the table defined above, running the following command: + +```sql +INSERT INTO my_table VALUES(1,2); +``` + +Results in a new CSV file with the specified options: + +```shell +$ cat /tmp/my_csv_table/bmC8zWFvLMtWX68R_0.csv +a;b +1;2 +``` + +Finally, options can be passed when running a `COPY` command. + +```sql +COPY source_table + TO 'test/table_with_options' + PARTITIONED BY (column3, column4) + OPTIONS ( + format parquet, + compression snappy, + 'compression::column1' 'zstd(5)', + ) +``` + +In this example, we write the entire `source_table` out to a folder of Parquet files. One Parquet file will be written in parallel to the folder for each partition in the query. The next option `compression` set to `snappy` indicates that unless otherwise specified, all columns should use the snappy compression codec. The option `compression::col1` sets an override, so that the column `col1` in the Parquet file will use the ZSTD compression codec with compression level `5`. In general, Parquet options that support column-specific settings can be specified with the syntax `OPTION::COLUMN.NESTED.PATH`. + +# Available Options + +## JSON Format Options + +The following options are available when reading or writing JSON files. Note: If any unsupported option is specified, an error will be raised and the query will fail. + +| Option | Description | Default Value | +| ----------- | ---------------------------------------------------------------------------------------------------------------------------------- | ------------- | +| COMPRESSION | Sets the compression that should be applied to the entire JSON file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. | UNCOMPRESSED | + +**Example:** + +```sql +CREATE EXTERNAL TABLE t(a int) +STORED AS JSON +LOCATION '/tmp/foo/' +OPTIONS('COMPRESSION' 'gzip'); +``` + +## CSV Format Options + +The following options are available when reading or writing CSV files. Note: If any unsupported option is specified, an error will be raised and the query will fail. + +| Option | Description | Default Value | +| -------------------- | --------------------------------------------------------------------------------------------------------------------------------- | ------------------ | +| COMPRESSION | Sets the compression that should be applied to the entire CSV file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. | UNCOMPRESSED | +| HAS_HEADER | Sets if the CSV file should include column headers. If not set, uses session or system default. | None | +| DELIMITER | Sets the character which should be used as the column delimiter within the CSV file. | `,` (comma) | +| QUOTE | Sets the character which should be used for quoting values within the CSV file. | `"` (double quote) | +| TERMINATOR | Sets the character which should be used as the line terminator within the CSV file. | None | +| ESCAPE | Sets the character which should be used for escaping special characters within the CSV file. | None | +| DOUBLE_QUOTE | Sets if quotes within quoted fields should be escaped by doubling them (e.g., `"aaa""bbb"`). | None | +| NEWLINES_IN_VALUES | Sets if newlines in quoted values are supported. If not set, uses session or system default. | None | +| DATE_FORMAT | Sets the format that dates should be encoded in within the CSV file. | None | +| DATETIME_FORMAT | Sets the format that datetimes should be encoded in within the CSV file. | None | +| TIMESTAMP_FORMAT | Sets the format that timestamps should be encoded in within the CSV file. | None | +| TIMESTAMP_TZ_FORMAT | Sets the format that timestamps with timezone should be encoded in within the CSV file. | None | +| TIME_FORMAT | Sets the format that times should be encoded in within the CSV file. | None | +| NULL_VALUE | Sets the string which should be used to indicate null values within the CSV file. | None | +| NULL_REGEX | Sets the regex pattern to match null values when loading CSVs. | None | +| SCHEMA_INFER_MAX_REC | Sets the maximum number of records to scan to infer the schema. | None | +| COMMENT | Sets the character which should be used to indicate comment lines in the CSV file. | None | + +**Example:** + +```sql +CREATE EXTERNAL TABLE t (col1 varchar, col2 int, col3 boolean) +STORED AS CSV +LOCATION '/tmp/foo/' +OPTIONS('DELIMITER' '|', 'HAS_HEADER' 'true', 'NEWLINES_IN_VALUES' 'true'); +``` + +## Parquet Format Options + +The following options are available when reading or writing Parquet files. If any unsupported option is specified, an error will be raised and the query will fail. If a column-specific option is specified for a column that does not exist, the option will be ignored without error. + +| Option | Can be Column Specific? | Description | OPTIONS Key | Default Value | +| ------------------------------------------ | ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------- | ------------------------ | +| COMPRESSION | Yes | Sets the internal Parquet **compression codec** for data pages, optionally including the compression level. Applies globally if set without `::col`, or specifically to a column if set using `'compression::column_name'`. Valid values: `uncompressed`, `snappy`, `gzip(level)`, `lzo`, `brotli(level)`, `lz4`, `zstd(level)`, `lz4_raw`. | `'compression'` or `'compression::col'` | zstd(3) | +| ENCODING | Yes | Sets the **encoding** scheme for data pages. Valid values: `plain`, `plain_dictionary`, `rle`, `bit_packed`, `delta_binary_packed`, `delta_length_byte_array`, `delta_byte_array`, `rle_dictionary`, `byte_stream_split`. Use key `'encoding'` or `'encoding::col'` in OPTIONS. | `'encoding'` or `'encoding::col'` | None | +| DICTIONARY_ENABLED | Yes | Sets whether dictionary encoding should be enabled globally or for a specific column. | `'dictionary_enabled'` or `'dictionary_enabled::col'` | true | +| STATISTICS_ENABLED | Yes | Sets the level of statistics to write (`none`, `chunk`, `page`). | `'statistics_enabled'` or `'statistics_enabled::col'` | page | +| BLOOM_FILTER_ENABLED | Yes | Sets whether a bloom filter should be written for a specific column. | `'bloom_filter_enabled::column_name'` | None | +| BLOOM_FILTER_FPP | Yes | Sets bloom filter false positive probability (global or per column). | `'bloom_filter_fpp'` or `'bloom_filter_fpp::col'` | None | +| BLOOM_FILTER_NDV | Yes | Sets bloom filter number of distinct values (global or per column). | `'bloom_filter_ndv'` or `'bloom_filter_ndv::col'` | None | +| MAX_ROW_GROUP_SIZE | No | Sets the maximum number of rows per row group. Larger groups require more memory but can improve compression and scan efficiency. | `'max_row_group_size'` | 1048576 | +| ENABLE_PAGE_INDEX | No | If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce I/O and decoding. | `'enable_page_index'` | true | +| PRUNING | No | If true, enables row group pruning based on min/max statistics. | `'pruning'` | true | +| SKIP_METADATA | No | If true, skips optional embedded metadata in the file schema. | `'skip_metadata'` | true | +| METADATA_SIZE_HINT | No | Sets the size hint (in bytes) for fetching Parquet file metadata. | `'metadata_size_hint'` | None | +| PUSHDOWN_FILTERS | No | If true, enables filter pushdown during Parquet decoding. | `'pushdown_filters'` | false | +| REORDER_FILTERS | No | If true, enables heuristic reordering of filters during Parquet decoding. | `'reorder_filters'` | false | +| SCHEMA_FORCE_VIEW_TYPES | No | If true, reads Utf8/Binary columns as view types. | `'schema_force_view_types'` | true | +| BINARY_AS_STRING | No | If true, reads Binary columns as strings. | `'binary_as_string'` | false | +| DATA_PAGESIZE_LIMIT | No | Sets best effort maximum size of data page in bytes. | `'data_pagesize_limit'` | 1048576 | +| DATA_PAGE_ROW_COUNT_LIMIT | No | Sets best effort maximum number of rows in data page. | `'data_page_row_count_limit'` | 20000 | +| DICTIONARY_PAGE_SIZE_LIMIT | No | Sets best effort maximum dictionary page size, in bytes. | `'dictionary_page_size_limit'` | 1048576 | +| WRITE_BATCH_SIZE | No | Sets write_batch_size in bytes. | `'write_batch_size'` | 1024 | +| WRITER_VERSION | No | Sets the Parquet writer version (`1.0` or `2.0`). | `'writer_version'` | 1.0 | +| SKIP_ARROW_METADATA | No | If true, skips writing Arrow schema information into the Parquet file metadata. | `'skip_arrow_metadata'` | false | +| CREATED_BY | No | Sets the "created by" string in the Parquet file metadata. | `'created_by'` | datafusion version X.Y.Z | +| COLUMN_INDEX_TRUNCATE_LENGTH | No | Sets the length (in bytes) to truncate min/max values in column indexes. | `'column_index_truncate_length'` | 64 | +| STATISTICS_TRUNCATE_LENGTH | No | Sets statistics truncate length. | `'statistics_truncate_length'` | None | +| BLOOM_FILTER_ON_WRITE | No | Sets whether bloom filters should be written for all columns by default (can be overridden per column). | `'bloom_filter_on_write'` | false | +| ALLOW_SINGLE_FILE_PARALLELISM | No | Enables parallel serialization of columns in a single file. | `'allow_single_file_parallelism'` | true | +| MAXIMUM_PARALLEL_ROW_GROUP_WRITERS | No | Maximum number of parallel row group writers. | `'maximum_parallel_row_group_writers'` | 1 | +| MAXIMUM_BUFFERED_RECORD_BATCHES_PER_STREAM | No | Maximum number of buffered record batches per stream. | `'maximum_buffered_record_batches_per_stream'` | 2 | +| KEY_VALUE_METADATA | No (Key is specific) | Adds custom key-value pairs to the file metadata. Use the format `'metadata::your_key_name' 'your_value'`. Multiple entries allowed. | `'metadata::key_name'` | None | + +**Example:** + +```sql +CREATE EXTERNAL TABLE t (id bigint, value double, category varchar) +STORED AS PARQUET +LOCATION '/tmp/parquet_data/' +OPTIONS( + 'COMPRESSION::user_id' 'snappy', + 'ENCODING::col_a' 'delta_binary_packed', + 'MAX_ROW_GROUP_SIZE' '1000000', + 'BLOOM_FILTER_ENABLED::id' 'true' +); +``` diff --git a/docs/source/user-guide/sql/index.rst b/docs/source/user-guide/sql/index.rst index 8e3f51bf8b0bc..a13d40334b639 100644 --- a/docs/source/user-guide/sql/index.rst +++ b/docs/source/user-guide/sql/index.rst @@ -33,5 +33,5 @@ SQL Reference window_functions scalar_functions special_functions - write_options + format_options prepared_statements diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 0f08934c8a9c3..9fcaac7628557 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -81,6 +81,17 @@ abs(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT abs(-5); ++----------+ +| abs(-5) | ++----------+ +| 5 | ++----------+ +``` + ### `acos` Returns the arc cosine or inverse cosine of a number. @@ -93,6 +104,17 @@ acos(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT acos(1); ++----------+ +| acos(1) | ++----------+ +| 0.0 | ++----------+ +``` + ### `acosh` Returns the area hyperbolic cosine or inverse hyperbolic cosine of a number. @@ -105,6 +127,17 @@ acosh(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT acosh(2); ++------------+ +| acosh(2) | ++------------+ +| 1.31696 | ++------------+ +``` + ### `asin` Returns the arc sine or inverse sine of a number. @@ -117,6 +150,17 @@ asin(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT asin(0.5); ++------------+ +| asin(0.5) | ++------------+ +| 0.5235988 | ++------------+ +``` + ### `asinh` Returns the area hyperbolic sine or inverse hyperbolic sine of a number. @@ -129,6 +173,17 @@ asinh(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT asinh(1); ++------------+ +| asinh(1) | ++------------+ +| 0.8813736 | ++------------+ +``` + ### `atan` Returns the arc tangent or inverse tangent of a number. @@ -141,6 +196,17 @@ atan(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql + > SELECT atan(1); ++-----------+ +| atan(1) | ++-----------+ +| 0.7853982 | ++-----------+ +``` + ### `atan2` Returns the arc tangent or inverse tangent of `expression_y / expression_x`. @@ -156,6 +222,17 @@ atan2(expression_y, expression_x) - **expression_x**: Second numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +#### Example + +```sql +> SELECT atan2(1, 1); ++------------+ +| atan2(1,1) | ++------------+ +| 0.7853982 | ++------------+ +``` + ### `atanh` Returns the area hyperbolic tangent or inverse hyperbolic tangent of a number. @@ -168,6 +245,17 @@ atanh(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql + > SELECT atanh(0.5); ++-------------+ +| atanh(0.5) | ++-------------+ +| 0.5493061 | ++-------------+ +``` + ### `cbrt` Returns the cube root of a number. @@ -180,6 +268,17 @@ cbrt(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT cbrt(27); ++-----------+ +| cbrt(27) | ++-----------+ +| 3.0 | ++-----------+ +``` + ### `ceil` Returns the nearest integer greater than or equal to a number. @@ -192,6 +291,17 @@ ceil(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql + > SELECT ceil(3.14); ++------------+ +| ceil(3.14) | ++------------+ +| 4.0 | ++------------+ +``` + ### `cos` Returns the cosine of a number. @@ -204,6 +314,17 @@ cos(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT cos(0); ++--------+ +| cos(0) | ++--------+ +| 1.0 | ++--------+ +``` + ### `cosh` Returns the hyperbolic cosine of a number. @@ -216,6 +337,17 @@ cosh(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT cosh(1); ++-----------+ +| cosh(1) | ++-----------+ +| 1.5430806 | ++-----------+ +``` + ### `cot` Returns the cotangent of a number. @@ -228,6 +360,17 @@ cot(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT cot(1); ++---------+ +| cot(1) | ++---------+ +| 0.64209 | ++---------+ +``` + ### `degrees` Converts radians to degrees. @@ -240,6 +383,17 @@ degrees(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql + > SELECT degrees(pi()); ++------------+ +| degrees(0) | ++------------+ +| 180.0 | ++------------+ +``` + ### `exp` Returns the base-e exponential of a number. @@ -252,6 +406,17 @@ exp(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT exp(1); ++---------+ +| exp(1) | ++---------+ +| 2.71828 | ++---------+ +``` + ### `factorial` Factorial. Returns 1 if value is less than 2. @@ -264,6 +429,17 @@ factorial(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT factorial(5); ++---------------+ +| factorial(5) | ++---------------+ +| 120 | ++---------------+ +``` + ### `floor` Returns the nearest integer less than or equal to a number. @@ -276,6 +452,17 @@ floor(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT floor(3.14); ++-------------+ +| floor(3.14) | ++-------------+ +| 3.0 | ++-------------+ +``` + ### `gcd` Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero. @@ -289,6 +476,17 @@ gcd(expression_x, expression_y) - **expression_x**: First numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - **expression_y**: Second numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT gcd(48, 18); ++------------+ +| gcd(48,18) | ++------------+ +| 6 | ++------------+ +``` + ### `isnan` Returns true if a given number is +NaN or -NaN otherwise returns false. @@ -301,6 +499,17 @@ isnan(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT isnan(1); ++----------+ +| isnan(1) | ++----------+ +| false | ++----------+ +``` + ### `iszero` Returns true if a given number is +0.0 or -0.0 otherwise returns false. @@ -313,6 +522,17 @@ iszero(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT iszero(0); ++------------+ +| iszero(0) | ++------------+ +| true | ++------------+ +``` + ### `lcm` Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero. @@ -326,6 +546,17 @@ lcm(expression_x, expression_y) - **expression_x**: First numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - **expression_y**: Second numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT lcm(4, 5); ++----------+ +| lcm(4,5) | ++----------+ +| 20 | ++----------+ +``` + ### `ln` Returns the natural logarithm of a number. @@ -338,6 +569,17 @@ ln(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT ln(2.71828); ++-------------+ +| ln(2.71828) | ++-------------+ +| 1.0 | ++-------------+ +``` + ### `log` Returns the base-x logarithm of a number. Can either provide a specified base, or if omitted then takes the base-10 of a number. @@ -352,6 +594,17 @@ log(numeric_expression) - **base**: Base numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT log(10); ++---------+ +| log(10) | ++---------+ +| 1.0 | ++---------+ +``` + ### `log10` Returns the base-10 logarithm of a number. @@ -364,6 +617,17 @@ log10(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT log10(100); ++-------------+ +| log10(100) | ++-------------+ +| 2.0 | ++-------------+ +``` + ### `log2` Returns the base-2 logarithm of a number. @@ -376,6 +640,17 @@ log2(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT log2(8); ++-----------+ +| log2(8) | ++-----------+ +| 3.0 | ++-----------+ +``` + ### `nanvl` Returns the first argument if it's not _NaN_. @@ -390,6 +665,17 @@ nanvl(expression_x, expression_y) - **expression_x**: Numeric expression to return if it's not _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators. - **expression_y**: Numeric expression to return if the first expression is _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators. +#### Example + +```sql +> SELECT nanvl(0, 5); ++------------+ +| nanvl(0,5) | ++------------+ +| 0 | ++------------+ +``` + ### `pi` Returns an approximate value of π. @@ -415,6 +701,17 @@ power(base, exponent) - **base**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - **exponent**: Exponent numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT power(2, 3); ++-------------+ +| power(2,3) | ++-------------+ +| 8 | ++-------------+ +``` + #### Aliases - pow @@ -431,6 +728,17 @@ radians(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT radians(180); ++----------------+ +| radians(180) | ++----------------+ +| 3.14159265359 | ++----------------+ +``` + ### `random` Returns a random float value in the range [0, 1). @@ -440,6 +748,17 @@ The random seed is unique to each row. random() ``` +#### Example + +```sql +> SELECT random(); ++------------------+ +| random() | ++------------------+ +| 0.7389238902938 | ++------------------+ +``` + ### `round` Rounds a number to the nearest integer. @@ -453,6 +772,17 @@ round(numeric_expression[, decimal_places]) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - **decimal_places**: Optional. The number of decimal places to round to. Defaults to 0. +#### Example + +```sql +> SELECT round(3.14159); ++--------------+ +| round(3.14159)| ++--------------+ +| 3.0 | ++--------------+ +``` + ### `signum` Returns the sign of a number. @@ -467,6 +797,17 @@ signum(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT signum(-42); ++-------------+ +| signum(-42) | ++-------------+ +| -1 | ++-------------+ +``` + ### `sin` Returns the sine of a number. @@ -479,6 +820,17 @@ sin(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT sin(0); ++----------+ +| sin(0) | ++----------+ +| 0.0 | ++----------+ +``` + ### `sinh` Returns the hyperbolic sine of a number. @@ -491,6 +843,17 @@ sinh(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT sinh(1); ++-----------+ +| sinh(1) | ++-----------+ +| 1.1752012 | ++-----------+ +``` + ### `sqrt` Returns the square root of a number. @@ -515,6 +878,17 @@ tan(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql +> SELECT tan(pi()/4); ++--------------+ +| tan(PI()/4) | ++--------------+ +| 1.0 | ++--------------+ +``` + ### `tanh` Returns the hyperbolic tangent of a number. @@ -527,6 +901,17 @@ tanh(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +#### Example + +```sql + > SELECT tanh(20); + +----------+ + | tanh(20) | + +----------+ + | 1.0 | + +----------+ +``` + ### `trunc` Truncates a number to a whole number or truncated to the specified decimal places. @@ -544,6 +929,17 @@ trunc(numeric_expression[, decimal_places]) right of the decimal point. If `decimal_places` is a negative integer, replaces digits to the left of the decimal point with `0`. +#### Example + +```sql +> SELECT trunc(42.738); ++----------------+ +| trunc(42.738) | ++----------------+ +| 42 | ++----------------+ +``` + ## Conditional Functions - [coalesce](#coalesce) @@ -768,7 +1164,7 @@ nvl2(expression1, expression2, expression3) ### `ascii` -Returns the Unicode character code of the first character in a string. +Returns the first Unicode scalar value of a string. ```sql ascii(str) @@ -909,7 +1305,7 @@ character_length(str) ### `chr` -Returns the character with the specified ASCII or Unicode code value. +Returns a string containing the character with the specified Unicode scalar value. ```sql chr(expression) @@ -1793,6 +2189,7 @@ regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax) The following regular expression functions are supported: - [regexp_count](#regexp_count) +- [regexp_instr](#regexp_instr) - [regexp_like](#regexp_like) - [regexp_match](#regexp_match) - [regexp_replace](#regexp_replace) @@ -1828,6 +2225,39 @@ regexp_count(str, regexp[, start, flags]) +---------------------------------------------------------------+ ``` +### `regexp_instr` + +Returns the position in a string where the specified occurrence of a POSIX regular expression is located. + +```sql +regexp_instr(str, regexp[, start[, N[, flags[, subexpr]]]]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **start**: - **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. Defaults to 1 +- **N**: - **N**: Optional The N-th occurrence of pattern to find. Defaults to 1 (first match). Can be a constant, column, or function. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? +- **subexpr**: Optional Specifies which capture group (subexpression) to return the position for. Defaults to 0, which returns the position of the entire match. + +#### Example + +```sql +> SELECT regexp_instr('ABCDEF', 'C(.)(..)'); ++---------------------------------------------------------------+ +| regexp_instr(Utf8("ABCDEF"),Utf8("C(.)(..)")) | ++---------------------------------------------------------------+ +| 3 | ++---------------------------------------------------------------+ +``` + ### `regexp_like` Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise. @@ -1973,7 +2403,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ### `current_date` -Returns the current UTC date. +Returns the current date in the session time zone. The `current_date()` return value is determined at query time and will return the same date, no matter when in the query plan the function executes. @@ -2080,9 +2510,10 @@ date_part(part, expression) - millisecond - microsecond - nanosecond - - dow (day of the week) + - dow (day of the week where Sunday is 0) - doy (day of the year) - epoch (seconds since Unix epoch) + - isodow (day of the week where Monday is 0) - **expression**: Time expression to operate on. Can be a constant, column, or function. @@ -2116,6 +2547,8 @@ date_trunc(precision, expression) - hour / HOUR - minute / MINUTE - second / SECOND + - millisecond / MILLISECOND + - microsecond / MICROSECOND - **expression**: Time expression to operate on. Can be a constant, column, or function. @@ -2133,7 +2566,7 @@ _Alias of [date_trunc](#date_trunc)._ ### `from_unixtime` -Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp. +Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp. ```sql from_unixtime(expression[, timezone]) @@ -2186,7 +2619,7 @@ make_date(year, month, day) +-----------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/make_date.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) ### `now` @@ -2227,7 +2660,7 @@ to_char(expression, format) +----------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_char.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) #### Aliases @@ -2271,7 +2704,7 @@ to_date('2017-05-31', '%Y-%m-%d') +---------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) ### `to_local_time` @@ -2366,7 +2799,7 @@ to_timestamp(expression[, ..., format_n]) +--------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) ### `to_timestamp_micros` @@ -2398,7 +2831,7 @@ to_timestamp_micros(expression[, ..., format_n]) +---------------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) ### `to_timestamp_millis` @@ -2430,7 +2863,7 @@ to_timestamp_millis(expression[, ..., format_n]) +---------------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) ### `to_timestamp_nanos` @@ -2462,7 +2895,7 @@ to_timestamp_nanos(expression[, ..., format_n]) +---------------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) ### `to_timestamp_seconds` @@ -2494,7 +2927,7 @@ to_timestamp_seconds(expression[, ..., format_n]) +----------------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) ### `to_unixtime` @@ -2552,6 +2985,7 @@ _Alias of [current_date](#current_date)._ - [array_join](#array_join) - [array_length](#array_length) - [array_max](#array_max) +- [array_min](#array_min) - [array_ndims](#array_ndims) - [array_pop_back](#array_pop_back) - [array_pop_front](#array_pop_front) @@ -3058,6 +3492,29 @@ array_max(array) - list_max +### `array_min` + +Returns the minimum value in the array. + +```sql +array_min(array) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_min([3,1,4,2]); ++-----------------------------------------+ +| array_min(List([3,1,4,2])) | ++-----------------------------------------+ +| 1 | ++-----------------------------------------+ +``` + ### `array_ndims` Returns the number of dimensions of the array. @@ -3142,7 +3599,7 @@ array_pop_front(array) ### `array_position` -Returns the position of the first occurrence of the specified element in the array. +Returns the position of the first occurrence of the specified element in the array, or NULL if not found. ```sql array_position(array, element) @@ -3153,7 +3610,7 @@ array_position(array, element, index) - **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - **element**: Element to search for position in the array. -- **index**: Index at which to start searching. +- **index**: Index at which to start searching (1-indexed). #### Example @@ -4105,6 +4562,7 @@ select struct(a as field_a, b) from t; - [element_at](#element_at) - [map](#map) +- [map_entries](#map_entries) - [map_extract](#map_extract) - [map_keys](#map_keys) - [map_values](#map_values) @@ -4162,6 +4620,30 @@ SELECT MAKE_MAP(['key1', 'key2'], ['value1', null]); {key1: value1, key2: } ``` +### `map_entries` + +Returns a list of all entries in the map. + +```sql +map_entries(map) +``` + +#### Arguments + +- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators. + +#### Example + +```sql +SELECT map_entries(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[{'key': a, 'value': 1}, {'key': b, 'value': NULL}, {'key': c, 'value': 3}] + +SELECT map_entries(map([100, 5], [42, 43])); +---- +[{'key': 100, 'value': 42}, {'key': 5, 'value': 43}] +``` + ### `map_extract` Returns a list containing the value for the given key or an empty list if the key is not present in the map. @@ -4404,6 +4886,7 @@ sha512(expression) Functions to work with the union data type, also know as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator - [union_extract](#union_extract) +- [union_tag](#union_tag) ### `union_extract` @@ -4433,6 +4916,33 @@ union_extract(union, field_name) +--------------+----------------------------------+----------------------------------+ ``` +### `union_tag` + +Returns the name of the currently selected field in the union + +```sql +union_tag(union_expression) +``` + +#### Arguments + +- **union**: Union expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +❯ select union_column, union_tag(union_column) from table_with_union; ++--------------+-------------------------+ +| union_column | union_tag(union_column) | ++--------------+-------------------------+ +| {a=1} | a | +| {b=3.0} | b | +| {a=4} | a | +| {b=} | b | +| {a=} | a | ++--------------+-------------------------+ +``` + ## Other Functions - [arrow_cast](#arrow_cast) diff --git a/docs/source/user-guide/sql/select.md b/docs/source/user-guide/sql/select.md index b2fa0a6305888..baacf432f5fde 100644 --- a/docs/source/user-guide/sql/select.md +++ b/docs/source/user-guide/sql/select.md @@ -35,10 +35,12 @@ DataFusion supports the following syntax for queries: [ [WHERE](#where-clause) condition ]
[ [GROUP BY](#group-by-clause) grouping_element [, ...] ]
[ [HAVING](#having-clause) condition]
+[ [QUALIFY](#qualify-clause) condition]
[ [UNION](#union-clause) [ ALL | select ]
[ [ORDER BY](#order-by-clause) expression [ ASC | DESC ][, ...] ]
[ [LIMIT](#limit-clause) count ]
[ [EXCLUDE | EXCEPT](#exclude-and-except-clause) ]
+[Pipe operators](#pipe-operators)
@@ -84,7 +86,7 @@ SELECT a FROM table WHERE a > 10 ## JOIN clause -DataFusion supports `INNER JOIN`, `LEFT OUTER JOIN`, `RIGHT OUTER JOIN`, `FULL OUTER JOIN`, `NATURAL JOIN` and `CROSS JOIN`. +DataFusion supports `INNER JOIN`, `LEFT OUTER JOIN`, `RIGHT OUTER JOIN`, `FULL OUTER JOIN`, `NATURAL JOIN`, `CROSS JOIN`, `LEFT SEMI JOIN`, `RIGHT SEMI JOIN`, `LEFT ANTI JOIN`, and `RIGHT ANTI JOIN`. The following examples are based on this table: @@ -102,7 +104,7 @@ select * from x; The keywords `JOIN` or `INNER JOIN` define a join that only shows rows where there is a match in both tables. ```sql -select * from x inner join x y ON x.column_1 = y.column_1; +SELECT * FROM x INNER JOIN x y ON x.column_1 = y.column_1; +----------+----------+----------+----------+ | column_1 | column_2 | column_1 | column_2 | +----------+----------+----------+----------+ @@ -116,7 +118,7 @@ The keywords `LEFT JOIN` or `LEFT OUTER JOIN` define a join that includes all ro is not a match in the right table. When there is no match, null values are produced for the right side of the join. ```sql -select * from x left join x y ON x.column_1 = y.column_2; +SELECT * FROM x LEFT JOIN x y ON x.column_1 = y.column_2; +----------+----------+----------+----------+ | column_1 | column_2 | column_1 | column_2 | +----------+----------+----------+----------+ @@ -130,7 +132,7 @@ The keywords `RIGHT JOIN` or `RIGHT OUTER JOIN` define a join that includes all is not a match in the left table. When there is no match, null values are produced for the left side of the join. ```sql -select * from x right join x y ON x.column_1 = y.column_2; +SELECT * FROM x RIGHT JOIN x y ON x.column_1 = y.column_2; +----------+----------+----------+----------+ | column_1 | column_2 | column_1 | column_2 | +----------+----------+----------+----------+ @@ -145,7 +147,7 @@ The keywords `FULL JOIN` or `FULL OUTER JOIN` define a join that is effectively either side of the join where there is not a match. ```sql -select * from x full outer join x y ON x.column_1 = y.column_2; +SELECT * FROM x FULL OUTER JOIN x y ON x.column_1 = y.column_2; +----------+----------+----------+----------+ | column_1 | column_2 | column_1 | column_2 | +----------+----------+----------+----------+ @@ -156,11 +158,11 @@ select * from x full outer join x y ON x.column_1 = y.column_2; ### NATURAL JOIN -A natural join defines an inner join based on common column names found between the input tables. When no common -column names are found, it behaves like a cross join. +A `NATURAL JOIN` defines an inner join based on common column names found between the input tables. When no common +column names are found, it behaves like a `CROSS JOIN`. ```sql -select * from x natural join x y; +SELECT * FROM x NATURAL JOIN x y; +----------+----------+ | column_1 | column_2 | +----------+----------+ @@ -170,11 +172,11 @@ select * from x natural join x y; ### CROSS JOIN -A cross join produces a cartesian product that matches every row in the left side of the join with every row in the +A `CROSS JOIN` produces a cartesian product that matches every row in the left side of the join with every row in the right side of the join. ```sql -select * from x cross join x y; +SELECT * FROM x CROSS JOIN x y; +----------+----------+----------+----------+ | column_1 | column_2 | column_1 | column_2 | +----------+----------+----------+----------+ @@ -182,6 +184,60 @@ select * from x cross join x y; +----------+----------+----------+----------+ ``` +### LEFT SEMI JOIN + +The `LEFT SEMI JOIN` returns all rows from the left table that have at least one matching row in the right table, and +projects only the columns from the left table. + +```sql +SELECT * FROM x LEFT SEMI JOIN x y ON x.column_1 = y.column_1; ++----------+----------+ +| column_1 | column_2 | ++----------+----------+ +| 1 | 2 | ++----------+----------+ +``` + +### RIGHT SEMI JOIN + +The `RIGHT SEMI JOIN` returns all rows from the right table that have at least one matching row in the left table, and +only projects the columns from the right table. + +```sql +SELECT * FROM x RIGHT SEMI JOIN x y ON x.column_1 = y.column_1; ++----------+----------+ +| column_1 | column_2 | ++----------+----------+ +| 1 | 2 | ++----------+----------+ +``` + +### LEFT ANTI JOIN + +The `LEFT ANTI JOIN` returns all rows from the left table that do not have any matching row in the right table, projecting +only the left table’s columns. + +```sql +SELECT * FROM x LEFT ANTI JOIN x y ON x.column_1 = y.column_1; ++----------+----------+ +| column_1 | column_2 | ++----------+----------+ ++----------+----------+ +``` + +### RIGHT ANTI JOIN + +The `RIGHT ANTI JOIN` returns all rows from the right table that do not have any matching row in the left table, projecting +only the right table’s columns. + +```sql +SELECT * FROM x RIGHT ANTI JOIN x y ON x.column_1 = y.column_1; ++----------+----------+ +| column_1 | column_2 | ++----------+----------+ ++----------+----------+ +``` + ## GROUP BY clause Example: @@ -207,6 +263,14 @@ Example: SELECT a, b, MAX(c) FROM table GROUP BY a, b HAVING MAX(c) > 10 ``` +## QUALIFY clause + +Example: + +```sql +SELECT ROW_NUMBER() OVER (PARTITION BY region) AS rk FROM table QUALIFY rk > 1; +``` + ## UNION clause Example: @@ -264,3 +328,215 @@ FROM table; SELECT * EXCLUDE(age, person) FROM table; ``` + +## Pipe operators + +Some SQL dialects (e.g. BigQuery) support the pipe operator `|>`. +The SQL dialect can be set like this: + +```sql +set datafusion.sql_parser.dialect = 'BigQuery'; +``` + +DataFusion currently supports the following pipe operators: + +- [WHERE](#pipe_where) +- [ORDER BY](#pipe_order_by) +- [LIMIT](#pipe_limit) +- [SELECT](#pipe_select) +- [EXTEND](#pipe_extend) +- [AS](#pipe_as) +- [UNION](#pipe_union) +- [INTERSECT](#pipe_intersect) +- [EXCEPT](#pipe_except) +- [AGGREGATE](#pipe_aggregate) +- [JOIN](#pipe_join) + +(pipe_where)= + +### WHERE + +```sql +select * from range(0,10) +|> where value < 2; ++-------+ +| value | ++-------+ +| 0 | +| 1 | ++-------+ +``` + +(pipe_order_by)= + +### ORDER BY + +```sql +select * from range(0,3) +|> order by value desc; ++-------+ +| value | ++-------+ +| 2 | +| 1 | +| 0 | ++-------+ +``` + +(pipe_limit)= + +### LIMIT + +```sql +select * from range(0,3) +|> order by value desc +|> limit 1; ++-------+ +| value | ++-------+ +| 2 | ++-------+ +``` + +(pipe_select)= + +### SELECT + +```sql +select * from range(0,3) +|> select value + 10; ++---------------------------+ +| range().value + Int64(10) | ++---------------------------+ +| 10 | +| 11 | +| 12 | ++---------------------------+ +``` + +(pipe_extend)= + +### EXTEND + +```sql +select * from range(0,3) +|> extend -value AS minus_value; ++-------+-------------+ +| value | minus_value | ++-------+-------------+ +| 0 | 0 | +| 1 | -1 | +| 2 | -2 | ++-------+-------------+ +``` + +(pipe_as)= + +### AS + +```sql +select * from range(0,3) +|> as my_range +|> SELECT my_range.value; ++-------+ +| value | ++-------+ +| 0 | +| 1 | +| 2 | ++-------+ +``` + +(pipe_union)= + +### UNION + +```sql +select * from range(0,3) +|> union all ( + select * from range(3,6) +); ++-------+ +| value | ++-------+ +| 0 | +| 1 | +| 2 | +| 3 | +| 4 | +| 5 | ++-------+ +``` + +(pipe_intersect)= + +### INTERSECT + +```sql +select * from range(0,100) +|> INTERSECT DISTINCT ( + select 3 +); ++-------+ +| value | ++-------+ +| 3 | ++-------+ +``` + +(pipe_except)= + +### EXCEPT + +```sql +select * from range(0,10) +|> EXCEPT DISTINCT (select * from range(5,10)); ++-------+ +| value | ++-------+ +| 0 | +| 1 | +| 2 | +| 3 | +| 4 | ++-------+ +``` + +(pipe_aggregate)= + +### AGGREGATE + +```sql +select * from range(0,3) +|> aggregate sum(value) AS total; ++-------+ +| total | ++-------+ +| 3 | ++-------+ +``` + +(pipe_join)= + +### JOIN + +```sql +( + SELECT 'apples' AS item, 2 AS sales + UNION ALL + SELECT 'bananas' AS item, 5 AS sales +) +|> AS produce_sales +|> LEFT JOIN + ( + SELECT 'apples' AS item, 123 AS id + ) AS produce_data + ON produce_sales.item = produce_data.item +|> SELECT produce_sales.item, sales, id; ++--------+-------+------+ +| item | sales | id | ++--------+-------+------+ +| apples | 2 | 123 | +| bananas| 5 | NULL | ++--------+-------+------+ +``` diff --git a/docs/source/user-guide/sql/special_functions.md b/docs/source/user-guide/sql/special_functions.md index 7c9efbb66218f..4f2a39f642b06 100644 --- a/docs/source/user-guide/sql/special_functions.md +++ b/docs/source/user-guide/sql/special_functions.md @@ -69,6 +69,7 @@ Expands an array or map into rows. ### `unnest (struct)` Expand a struct fields into individual columns. +Each field of the struct will be prefixed with `__unnest_placeholder` and could be accessed via `"__unnest_placeholder()."`. #### Arguments @@ -91,10 +92,10 @@ Expand a struct fields into individual columns. +---------------------------+ > select unnest(struct_column) from foov; -+------------------------------------------+------------------------------------------+ -| unnest_placeholder(foov.struct_column).a | unnest_placeholder(foov.struct_column).b | -+------------------------------------------+------------------------------------------+ -| 5 | a string | -| 6 | another string | -+------------------------------------------+------------------------------------------+ ++--------------------------------------------+--------------------------------------------+ +| __unnest_placeholder(foov.struct_column).a | __unnest_placeholder(foov.struct_column).b | ++--------------------------------------------+--------------------------------------------+ +| 5 | a string | +| 6 | another string | ++--------------------------------------------+--------------------------------------------+ ``` diff --git a/docs/source/user-guide/sql/window_functions.md b/docs/source/user-guide/sql/window_functions.md index 1c02804f0deed..2c8050ce1f9ca 100644 --- a/docs/source/user-guide/sql/window_functions.md +++ b/docs/source/user-guide/sql/window_functions.md @@ -145,6 +145,17 @@ where **offset** is an non-negative integer. RANGE and GROUPS modes require an ORDER BY clause (with RANGE the ORDER BY must specify exactly one column). +## Filter clause for aggregate window functions + +Aggregate window functions support the SQL `FILTER (WHERE ...)` clause to include only rows that satisfy the predicate from the window frame in the aggregation. + +```sql +sum(salary) FILTER (WHERE salary > 0) + OVER (PARTITION BY depname ORDER BY salary ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +``` + +If no rows in the frame satisfy the filter for a given output row, `COUNT` yields `0` while `SUM`/`AVG`/`MIN`/`MAX` yield `NULL`. + ## Aggregate functions All [aggregate functions](aggregate_functions.md) can be used as window functions. @@ -160,12 +171,29 @@ All [aggregate functions](aggregate_functions.md) can be used as window function ### `cume_dist` -Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows). +Relative rank of the current row: (number of rows preceding or peer with the current row) / (total rows). ```sql cume_dist() ``` +#### Example + +```sql +-- Example usage of the cume_dist window function: +SELECT salary, + cume_dist() OVER (ORDER BY salary) AS cume_dist +FROM employees; + ++--------+-----------+ +| salary | cume_dist | ++--------+-----------+ +| 30000 | 0.33 | +| 50000 | 0.67 | +| 70000 | 1.00 | ++--------+-----------+ +``` + ### `dense_rank` Returns the rank of the current row without gaps. This function ranks rows in a dense manner, meaning consecutive ranks are assigned even for identical values. @@ -174,6 +202,27 @@ Returns the rank of the current row without gaps. This function ranks rows in a dense_rank() ``` +#### Example + +```sql +-- Example usage of the dense_rank window function: +SELECT department, + salary, + dense_rank() OVER (PARTITION BY department ORDER BY salary DESC) AS dense_rank +FROM employees; + ++-------------+--------+------------+ +| department | salary | dense_rank | ++-------------+--------+------------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 2 | +| Sales | 30000 | 3 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+------------+ +``` + ### `ntile` Integer ranging from 1 to the argument value, dividing the partition as equally as possible @@ -186,6 +235,29 @@ ntile(expression) - **expression**: An integer describing the number groups the partition should be split into +#### Example + +```sql +-- Example usage of the ntile window function: +SELECT employee_id, + salary, + ntile(4) OVER (ORDER BY salary DESC) AS quartile +FROM employees; + ++-------------+--------+----------+ +| employee_id | salary | quartile | ++-------------+--------+----------+ +| 1 | 90000 | 1 | +| 2 | 85000 | 1 | +| 3 | 80000 | 2 | +| 4 | 70000 | 2 | +| 5 | 60000 | 3 | +| 6 | 50000 | 3 | +| 7 | 40000 | 4 | +| 8 | 30000 | 4 | ++-------------+--------+----------+ +``` + ### `percent_rank` Returns the percentage rank of the current row within its partition. The value ranges from 0 to 1 and is computed as `(rank - 1) / (total_rows - 1)`. @@ -194,6 +266,24 @@ Returns the percentage rank of the current row within its partition. The value r percent_rank() ``` +#### Example + +```sql + -- Example usage of the percent_rank window function: +SELECT employee_id, + salary, + percent_rank() OVER (ORDER BY salary) AS percent_rank +FROM employees; + ++-------------+--------+---------------+ +| employee_id | salary | percent_rank | ++-------------+--------+---------------+ +| 1 | 30000 | 0.00 | +| 2 | 50000 | 0.50 | +| 3 | 70000 | 1.00 | ++-------------+--------+---------------+ +``` + ### `rank` Returns the rank of the current row within its partition, allowing gaps between ranks. This function provides a ranking similar to `row_number`, but skips ranks for identical values. @@ -202,6 +292,27 @@ Returns the rank of the current row within its partition, allowing gaps between rank() ``` +#### Example + +```sql +-- Example usage of the rank window function: +SELECT department, + salary, + rank() OVER (PARTITION BY department ORDER BY salary DESC) AS rank +FROM employees; + ++-------------+--------+------+ +| department | salary | rank | ++-------------+--------+------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 2 | +| Sales | 30000 | 4 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+------+ +``` + ### `row_number` Number of the current row within its partition, counting from 1. @@ -210,6 +321,27 @@ Number of the current row within its partition, counting from 1. row_number() ``` +#### Example + +```sql +-- Example usage of the row_number window function: +SELECT department, + salary, + row_number() OVER (PARTITION BY department ORDER BY salary DESC) AS row_num +FROM employees; + ++-------------+--------+---------+ +| department | salary | row_num | ++-------------+--------+---------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 3 | +| Sales | 30000 | 4 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+---------+ +``` + ## Analytical Functions - [first_value](#first_value) @@ -230,6 +362,27 @@ first_value(expression) - **expression**: Expression to operate on +#### Example + +```sql +-- Example usage of the first_value window function: +SELECT department, + employee_id, + salary, + first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary +FROM employees; + ++-------------+-------------+--------+------------+ +| department | employee_id | salary | top_salary | ++-------------+-------------+--------+------------+ +| Sales | 1 | 70000 | 70000 | +| Sales | 2 | 50000 | 70000 | +| Sales | 3 | 30000 | 70000 | +| Engineering | 4 | 90000 | 90000 | +| Engineering | 5 | 80000 | 90000 | ++-------------+-------------+--------+------------+ +``` + ### `lag` Returns value evaluated at the row that is offset rows before the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). @@ -244,6 +397,25 @@ lag(expression, offset, default) - **offset**: Integer. Specifies how many rows back the value of expression should be retrieved. Defaults to 1. - **default**: The default value if the offset is not within the partition. Must be of the same type as expression. +#### Example + +```sql +-- Example usage of the lag window function: +SELECT employee_id, + salary, + lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary +FROM employees; + ++-------------+--------+-------------+ +| employee_id | salary | prev_salary | ++-------------+--------+-------------+ +| 1 | 30000 | 0 | +| 2 | 50000 | 30000 | +| 3 | 70000 | 50000 | +| 4 | 60000 | 70000 | ++-------------+--------+-------------+ +``` + ### `last_value` Returns value evaluated at the row that is the last row of the window frame. @@ -256,6 +428,27 @@ last_value(expression) - **expression**: Expression to operate on +#### Example + +```sql +-- SQL example of last_value: +SELECT department, + employee_id, + salary, + last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary +FROM employees; + ++-------------+-------------+--------+---------------------+ +| department | employee_id | salary | running_last_salary | ++-------------+-------------+--------+---------------------+ +| Sales | 1 | 30000 | 30000 | +| Sales | 2 | 50000 | 50000 | +| Sales | 3 | 70000 | 70000 | +| Engineering | 4 | 40000 | 40000 | +| Engineering | 5 | 60000 | 60000 | ++-------------+-------------+--------+---------------------+ +``` + ### `lead` Returns value evaluated at the row that is offset rows after the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). @@ -270,9 +463,31 @@ lead(expression, offset, default) - **offset**: Integer. Specifies how many rows forward the value of expression should be retrieved. Defaults to 1. - **default**: The default value if the offset is not within the partition. Must be of the same type as expression. +#### Example + +```sql +-- Example usage of lead window function: +SELECT + employee_id, + department, + salary, + lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary +FROM employees; + ++-------------+-------------+--------+--------------+ +| employee_id | department | salary | next_salary | ++-------------+-------------+--------+--------------+ +| 1 | Sales | 30000 | 50000 | +| 2 | Sales | 50000 | 70000 | +| 3 | Sales | 70000 | 0 | +| 4 | Engineering | 40000 | 60000 | +| 5 | Engineering | 60000 | 0 | ++-------------+-------------+--------+--------------+ +``` + ### `nth_value` -Returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row. +Returns the value evaluated at the nth row of the window frame (counting from 1). Returns NULL if no such row exists. ```sql nth_value(expression, n) @@ -280,5 +495,35 @@ nth_value(expression, n) #### Arguments -- **expression**: The name the column of which nth value to retrieve -- **n**: Integer. Specifies the n in nth +- **expression**: The column from which to retrieve the nth value. +- **n**: Integer. Specifies the row number (starting from 1) in the window frame. + +#### Example + +```sql +-- Sample employees table: +CREATE TABLE employees (id INT, salary INT); +INSERT INTO employees (id, salary) VALUES +(1, 30000), +(2, 40000), +(3, 50000), +(4, 60000), +(5, 70000); + +-- Example usage of nth_value: +SELECT nth_value(salary, 2) OVER ( + ORDER BY salary + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +) AS nth_value +FROM employees; + ++-----------+ +| nth_value | ++-----------+ +| 40000 | +| 40000 | +| 40000 | +| 40000 | +| 40000 | ++-----------+ +``` diff --git a/docs/source/user-guide/sql/write_options.md b/docs/source/user-guide/sql/write_options.md deleted file mode 100644 index 521e29436212d..0000000000000 --- a/docs/source/user-guide/sql/write_options.md +++ /dev/null @@ -1,127 +0,0 @@ - - -# Write Options - -DataFusion supports customizing how data is written out to disk as a result of a `COPY` or `INSERT INTO` query. There are a few special options, file format (e.g. CSV or parquet) specific options, and parquet column specific options. Options can also in some cases be specified in multiple ways with a set order of precedence. - -## Specifying Options and Order of Precedence - -Write related options can be specified in the following ways: - -- Session level config defaults -- `CREATE EXTERNAL TABLE` options -- `COPY` option tuples - -For a list of supported session level config defaults see [Configuration Settings](../configs). These defaults apply to all write operations but have the lowest level of precedence. - -If inserting to an external table, table specific write options can be specified when the table is created using the `OPTIONS` clause: - -```sql -CREATE EXTERNAL TABLE - my_table(a bigint, b bigint) - STORED AS csv - COMPRESSION TYPE gzip - LOCATION '/test/location/my_csv_table/' - OPTIONS( - NULL_VALUE 'NAN', - 'has_header' 'true', - 'format.delimiter' ';' - ) -``` - -When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (gzip compression, special delimiter, and header row included). There will be a single output file if the output path doesn't have folder format, i.e. ending with a `\`. Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified the `OPTIONS` setting will be ignored. NULL_VALUE is a CSV format specific option that determines how null values should be encoded within the CSV file. - -Finally, options can be passed when running a `COPY` command. - - - -```sql -COPY source_table - TO 'test/table_with_options' - PARTITIONED BY (column3, column4) - OPTIONS ( - format parquet, - compression snappy, - 'compression::column1' 'zstd(5)', - ) -``` - -In this example, we write the entirety of `source_table` out to a folder of parquet files. One parquet file will be written in parallel to the folder for each partition in the query. The next option `compression` set to `snappy` indicates that unless otherwise specified all columns should use the snappy compression codec. The option `compression::col1` sets an override, so that the column `col1` in the parquet file will use `ZSTD` compression codec with compression level `5`. In general, parquet options which support column specific settings can be specified with the syntax `OPTION::COLUMN.NESTED.PATH`. - -## Available Options - -### Execution Specific Options - -The following options are available when executing a `COPY` query. - -| Option | Description | Default Value | -| ----------------------------------- | ---------------------------------------------------------------------------------- | ------------- | -| execution.keep_partition_by_columns | Flag to retain the columns in the output data when using `PARTITIONED BY` queries. | false | - -Note: `execution.keep_partition_by_columns` flag can also be enabled through `ExecutionOptions` within `SessionConfig`. - -### JSON Format Specific Options - -The following options are available when writing JSON files. Note: If any unsupported option is specified, an error will be raised and the query will fail. - -| Option | Description | Default Value | -| ----------- | ---------------------------------------------------------------------------------------------------------------------------------- | ------------- | -| COMPRESSION | Sets the compression that should be applied to the entire JSON file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. | UNCOMPRESSED | - -### CSV Format Specific Options - -The following options are available when writing CSV files. Note: if any unsupported options is specified an error will be raised and the query will fail. - -| Option | Description | Default Value | -| --------------- | --------------------------------------------------------------------------------------------------------------------------------- | ---------------- | -| COMPRESSION | Sets the compression that should be applied to the entire CSV file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. | UNCOMPRESSED | -| HEADER | Sets if the CSV file should include column headers | false | -| DATE_FORMAT | Sets the format that dates should be encoded in within the CSV file | arrow-rs default | -| DATETIME_FORMAT | Sets the format that datetimes should be encoded in within the CSV file | arrow-rs default | -| TIME_FORMAT | Sets the format that times should be encoded in within the CSV file | arrow-rs default | -| RFC3339 | If true, uses RFC339 format for date and time encodings | arrow-rs default | -| NULL_VALUE | Sets the string which should be used to indicate null values within the CSV file. | arrow-rs default | -| DELIMITER | Sets the character which should be used as the column delimiter within the CSV file. | arrow-rs default | - -### Parquet Format Specific Options - -The following options are available when writing parquet files. If any unsupported option is specified an error will be raised and the query will fail. If a column specific option is specified for a column which does not exist, the option will be ignored without error. For default values, see: [Configuration Settings](https://datafusion.apache.org/user-guide/configs.html). - -| Option | Can be Column Specific? | Description | -| ---------------------------- | ----------------------- | ----------------------------------------------------------------------------------------------------------------------------------- | -| COMPRESSION | Yes | Sets the compression codec and if applicable compression level to use | -| MAX_ROW_GROUP_SIZE | No | Sets the maximum number of rows that can be encoded in a single row group. Larger row groups require more memory to write and read. | -| DATA_PAGESIZE_LIMIT | No | Sets the best effort maximum page size in bytes | -| WRITE_BATCH_SIZE | No | Maximum number of rows written for each column in a single batch | -| WRITER_VERSION | No | Parquet writer version (1.0 or 2.0) | -| DICTIONARY_PAGE_SIZE_LIMIT | No | Sets best effort maximum dictionary page size in bytes | -| CREATED_BY | No | Sets the "created by" property in the parquet file | -| COLUMN_INDEX_TRUNCATE_LENGTH | No | Sets the max length of min/max value fields in the column index. | -| DATA_PAGE_ROW_COUNT_LIMIT | No | Sets best effort maximum number of rows in a data page. | -| BLOOM_FILTER_ENABLED | Yes | Sets whether a bloom filter should be written into the file. | -| ENCODING | Yes | Sets the encoding that should be used (e.g. PLAIN or RLE) | -| DICTIONARY_ENABLED | Yes | Sets if dictionary encoding is enabled. Use this instead of ENCODING to set dictionary encoding. | -| STATISTICS_ENABLED | Yes | Sets if statistics are enabled at PAGE or ROW_GROUP level. | -| MAX_STATISTICS_SIZE | Yes | Sets the maximum size in bytes that statistics can take up. | -| BLOOM_FILTER_FPP | Yes | Sets the false positive probability (fpp) for the bloom filter. Implicitly sets BLOOM_FILTER_ENABLED to true. | -| BLOOM_FILTER_NDV | Yes | Sets the number of distinct values (ndv) for the bloom filter. Implicitly sets bloom_filter_enabled to true. | diff --git a/parquet-testing b/parquet-testing index f4d7ed772a62a..107b36603e051 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit f4d7ed772a62a95111db50fbcad2460833e8c882 +Subproject commit 107b36603e051aee26bd93e04b871034f6c756c0 diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 11f4fb798c376..7697bc1c1e259 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -19,5 +19,5 @@ # to compile this workspace and run CI jobs. [toolchain] -channel = "1.85.0" +channel = "1.90.0" components = ["rustfmt", "clippy"] diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index 811102cf6dbdb..3a161d5f4d645 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -27,7 +27,7 @@ workspace = true [dependencies] arrow = { workspace = true } -chrono-tz = { version = "0.10.3", default-features = false } +chrono-tz = { version = "0.10.4", default-features = false } datafusion-common = { workspace = true, default-features = true } env_logger = { workspace = true } rand = { workspace = true } diff --git a/test-utils/src/array_gen/binary.rs b/test-utils/src/array_gen/binary.rs index d342118fa85d3..9740eeae5e7fe 100644 --- a/test-utils/src/array_gen/binary.rs +++ b/test-utils/src/array_gen/binary.rs @@ -46,11 +46,11 @@ impl BinaryArrayGenerator { // Pick num_binaries randomly from the distinct binary table let indices: UInt32Array = (0..self.num_binaries) .map(|_| { - if self.rng.gen::() < self.null_pct { + if self.rng.random::() < self.null_pct { None } else if self.num_distinct_binaries > 1 { let range = 0..(self.num_distinct_binaries as u32); - Some(self.rng.gen_range(range)) + Some(self.rng.random_range(range)) } else { Some(0) } @@ -68,11 +68,11 @@ impl BinaryArrayGenerator { let indices: UInt32Array = (0..self.num_binaries) .map(|_| { - if self.rng.gen::() < self.null_pct { + if self.rng.random::() < self.null_pct { None } else if self.num_distinct_binaries > 1 { let range = 0..(self.num_distinct_binaries as u32); - Some(self.rng.gen_range(range)) + Some(self.rng.random_range(range)) } else { Some(0) } @@ -88,7 +88,7 @@ fn random_binary(rng: &mut StdRng, max_len: usize) -> Vec { if max_len == 0 { Vec::new() } else { - let len = rng.gen_range(1..=max_len); - (0..len).map(|_| rng.gen()).collect() + let len = rng.random_range(1..=max_len); + (0..len).map(|_| rng.random()).collect() } } diff --git a/test-utils/src/array_gen/boolean.rs b/test-utils/src/array_gen/boolean.rs index f3b83dd245f72..004d615b4caa4 100644 --- a/test-utils/src/array_gen/boolean.rs +++ b/test-utils/src/array_gen/boolean.rs @@ -34,7 +34,7 @@ impl BooleanArrayGenerator { // Table of booleans from which to draw (distinct means 1 or 2) let distinct_booleans: BooleanArray = match self.num_distinct_booleans { 1 => { - let value = self.rng.gen::(); + let value = self.rng.random::(); let mut builder = BooleanBuilder::with_capacity(1); builder.append_value(value); builder.finish() @@ -51,10 +51,10 @@ impl BooleanArrayGenerator { // Generate indices to select from the distinct booleans let indices: UInt32Array = (0..self.num_booleans) .map(|_| { - if self.rng.gen::() < self.null_pct { + if self.rng.random::() < self.null_pct { None } else if self.num_distinct_booleans > 1 { - Some(self.rng.gen_range(0..self.num_distinct_booleans as u32)) + Some(self.rng.random_range(0..self.num_distinct_booleans as u32)) } else { Some(0) } diff --git a/test-utils/src/array_gen/decimal.rs b/test-utils/src/array_gen/decimal.rs index d46ea9fe54575..c5ec8ac5e8938 100644 --- a/test-utils/src/array_gen/decimal.rs +++ b/test-utils/src/array_gen/decimal.rs @@ -62,11 +62,11 @@ impl DecimalArrayGenerator { // pick num_decimals randomly from the distinct decimal table let indices: UInt32Array = (0..self.num_decimals) .map(|_| { - if self.rng.gen::() < self.null_pct { + if self.rng.random::() < self.null_pct { None } else if self.num_distinct_decimals > 1 { let range = 1..(self.num_distinct_decimals as u32); - Some(self.rng.gen_range(range)) + Some(self.rng.random_range(range)) } else { Some(0) } diff --git a/test-utils/src/array_gen/primitive.rs b/test-utils/src/array_gen/primitive.rs index 58d39c14e65d6..62a38a1b4ce1d 100644 --- a/test-utils/src/array_gen/primitive.rs +++ b/test-utils/src/array_gen/primitive.rs @@ -18,7 +18,8 @@ use arrow::array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray, UInt32Array}; use arrow::datatypes::DataType; use chrono_tz::{Tz, TZ_VARIANTS}; -use rand::{rngs::StdRng, seq::SliceRandom, thread_rng, Rng}; +use rand::prelude::IndexedRandom; +use rand::{rng, rngs::StdRng, Rng}; use std::sync::Arc; use super::random_data::RandomNativeData; @@ -66,6 +67,7 @@ impl PrimitiveArrayGenerator { | DataType::Time32(_) | DataType::Time64(_) | DataType::Interval(_) + | DataType::Duration(_) | DataType::Binary | DataType::LargeBinary | DataType::BinaryView @@ -81,11 +83,11 @@ impl PrimitiveArrayGenerator { // pick num_primitives randomly from the distinct string table let indices: UInt32Array = (0..self.num_primitives) .map(|_| { - if self.rng.gen::() < self.null_pct { + if self.rng.random::() < self.null_pct { None } else if self.num_distinct_primitives > 1 { let range = 1..(self.num_distinct_primitives as u32); - Some(self.rng.gen_range(range)) + Some(self.rng.random_range(range)) } else { Some(0) } @@ -102,7 +104,7 @@ impl PrimitiveArrayGenerator { /// - `Some(Arc)` containing the timezone name. /// - `None` if no timezone is selected. fn generate_timezone() -> Option> { - let mut rng = thread_rng(); + let mut rng = rng(); // Allows for timezones + None let mut timezone_options: Vec> = vec![None]; diff --git a/test-utils/src/array_gen/random_data.rs b/test-utils/src/array_gen/random_data.rs index a7297d45fdf07..ea2b872f7d86f 100644 --- a/test-utils/src/array_gen/random_data.rs +++ b/test-utils/src/array_gen/random_data.rs @@ -17,15 +17,16 @@ use arrow::array::ArrowPrimitiveType; use arrow::datatypes::{ - i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTime, - IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, - IntervalYearMonthType, Time32MillisecondType, Time32SecondType, - Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, + i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Decimal32Type, + Decimal64Type, DurationMicrosecondType, DurationMillisecondType, + DurationNanosecondType, DurationSecondType, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, IntervalDayTime, IntervalDayTimeType, + IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalYearMonthType, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; -use rand::distributions::Standard; +use rand::distr::StandardUniform; use rand::prelude::Distribution; use rand::rngs::StdRng; use rand::Rng; @@ -40,11 +41,11 @@ macro_rules! basic_random_data { ($ARROW_TYPE: ty) => { impl RandomNativeData for $ARROW_TYPE where - Standard: Distribution, + StandardUniform: Distribution, { #[inline] fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { - rng.gen::() + rng.random::() } } }; @@ -66,16 +67,23 @@ basic_random_data!(Time32MillisecondType); basic_random_data!(Time64MicrosecondType); basic_random_data!(Time64NanosecondType); basic_random_data!(IntervalYearMonthType); +basic_random_data!(Decimal32Type); +basic_random_data!(Decimal64Type); basic_random_data!(Decimal128Type); basic_random_data!(TimestampSecondType); basic_random_data!(TimestampMillisecondType); basic_random_data!(TimestampMicrosecondType); basic_random_data!(TimestampNanosecondType); +// Note DurationSecondType is restricted to i64::MIN / 1000 to i64::MAX / 1000 +// due to https://github.com/apache/arrow-rs/issues/7533 so handle it specially below +basic_random_data!(DurationMillisecondType); +basic_random_data!(DurationMicrosecondType); +basic_random_data!(DurationNanosecondType); impl RandomNativeData for Date64Type { fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { // TODO: constrain this range to valid dates if necessary - let date_value = rng.gen_range(i64::MIN..=i64::MAX); + let date_value = rng.random_range(i64::MIN..=i64::MAX); let millis_per_day = 86_400_000; date_value - (date_value % millis_per_day) } @@ -84,8 +92,8 @@ impl RandomNativeData for Date64Type { impl RandomNativeData for IntervalDayTimeType { fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { IntervalDayTime { - days: rng.gen::(), - milliseconds: rng.gen::(), + days: rng.random::(), + milliseconds: rng.random::(), } } } @@ -93,15 +101,24 @@ impl RandomNativeData for IntervalDayTimeType { impl RandomNativeData for IntervalMonthDayNanoType { fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { IntervalMonthDayNano { - months: rng.gen::(), - days: rng.gen::(), - nanoseconds: rng.gen::(), + months: rng.random::(), + days: rng.random::(), + nanoseconds: rng.random::(), } } } +// Restrict Duration(Seconds) to i64::MIN / 1000 to i64::MAX / 1000 to +// avoid panics on pretty printing. See +// https://github.com/apache/arrow-rs/issues/7533 +impl RandomNativeData for DurationSecondType { + fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { + rng.random::() / 1000 + } +} + impl RandomNativeData for Decimal256Type { fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { - i256::from_parts(rng.gen::(), rng.gen::()) + i256::from_parts(rng.random::(), rng.random::()) } } diff --git a/test-utils/src/array_gen/string.rs b/test-utils/src/array_gen/string.rs index ac659ae67bc0e..546485fd8dc16 100644 --- a/test-utils/src/array_gen/string.rs +++ b/test-utils/src/array_gen/string.rs @@ -18,6 +18,7 @@ use arrow::array::{ ArrayRef, GenericStringArray, OffsetSizeTrait, StringViewArray, UInt32Array, }; +use rand::distr::StandardUniform; use rand::rngs::StdRng; use rand::Rng; @@ -47,11 +48,11 @@ impl StringArrayGenerator { // pick num_strings randomly from the distinct string table let indices: UInt32Array = (0..self.num_strings) .map(|_| { - if self.rng.gen::() < self.null_pct { + if self.rng.random::() < self.null_pct { None } else if self.num_distinct_strings > 1 { let range = 1..(self.num_distinct_strings as u32); - Some(self.rng.gen_range(range)) + Some(self.rng.random_range(range)) } else { Some(0) } @@ -71,11 +72,11 @@ impl StringArrayGenerator { // pick num_strings randomly from the distinct string table let indices: UInt32Array = (0..self.num_strings) .map(|_| { - if self.rng.gen::() < self.null_pct { + if self.rng.random::() < self.null_pct { None } else if self.num_distinct_strings > 1 { let range = 1..(self.num_distinct_strings as u32); - Some(self.rng.gen_range(range)) + Some(self.rng.random_range(range)) } else { Some(0) } @@ -92,10 +93,10 @@ fn random_string(rng: &mut StdRng, max_len: usize) -> String { // pick characters at random (not just ascii) match max_len { 0 => "".to_string(), - 1 => String::from(rng.gen::()), + 1 => String::from(rng.random::()), _ => { - let len = rng.gen_range(1..=max_len); - rng.sample_iter::(rand::distributions::Standard) + let len = rng.random_range(1..=max_len); + rng.sample_iter::(StandardUniform) .take(len) .collect() } diff --git a/test-utils/src/data_gen.rs b/test-utils/src/data_gen.rs index 7ac6f3d3e255a..2228010b28dd1 100644 --- a/test-utils/src/data_gen.rs +++ b/test-utils/src/data_gen.rs @@ -104,10 +104,11 @@ impl BatchBuilder { } fn append(&mut self, rng: &mut StdRng, host: &str, service: &str) { - let num_pods = rng.gen_range(self.options.pods_per_host.clone()); + let num_pods = rng.random_range(self.options.pods_per_host.clone()); let pods = generate_sorted_strings(rng, num_pods, 30..40); for pod in pods { - let num_containers = rng.gen_range(self.options.containers_per_pod.clone()); + let num_containers = + rng.random_range(self.options.containers_per_pod.clone()); for container_idx in 0..num_containers { let container = format!("{service}_container_{container_idx}"); let image = format!( @@ -115,7 +116,7 @@ impl BatchBuilder { ); let num_entries = - rng.gen_range(self.options.entries_per_container.clone()); + rng.random_range(self.options.entries_per_container.clone()); for i in 0..num_entries { if self.is_finished() { return; @@ -154,7 +155,7 @@ impl BatchBuilder { if self.options.include_nulls { // Append a null value if the option is set // Use both "NULL" as a string and a null value - if rng.gen_bool(0.5) { + if rng.random_bool(0.5) { self.client_addr.append_null(); } else { self.client_addr.append_value("NULL"); @@ -162,26 +163,26 @@ impl BatchBuilder { } else { self.client_addr.append_value(format!( "{}.{}.{}.{}", - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::() + rng.random::(), + rng.random::(), + rng.random::(), + rng.random::() )); } - self.request_duration.append_value(rng.gen()); + self.request_duration.append_value(rng.random()); self.request_user_agent .append_value(random_string(rng, 20..100)); self.request_method - .append_value(methods[rng.gen_range(0..methods.len())]); + .append_value(methods[rng.random_range(0..methods.len())]); self.request_host .append_value(format!("https://{service}.mydomain.com")); self.request_bytes - .append_option(rng.gen_bool(0.9).then(|| rng.gen())); + .append_option(rng.random_bool(0.9).then(|| rng.random())); self.response_bytes - .append_option(rng.gen_bool(0.9).then(|| rng.gen())); + .append_option(rng.random_bool(0.9).then(|| rng.random())); self.response_status - .append_value(status[rng.gen_range(0..status.len())]); + .append_value(status[rng.random_range(0..status.len())]); self.prices_status.append_value(self.row_count as i128); } @@ -216,9 +217,9 @@ impl BatchBuilder { } fn random_string(rng: &mut StdRng, len_range: Range) -> String { - let len = rng.gen_range(len_range); + let len = rng.random_range(len_range); (0..len) - .map(|_| rng.gen_range(b'a'..=b'z') as char) + .map(|_| rng.random_range(b'a'..=b'z') as char) .collect::() } @@ -364,7 +365,7 @@ impl Iterator for AccessLogGenerator { self.host_idx += 1; for service in &["frontend", "backend", "database", "cache"] { - if self.rng.gen_bool(0.5) { + if self.rng.random_bool(0.5) { continue; } if builder.is_finished() { diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 9db8920833ae5..be2bc0712afbd 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -67,10 +67,9 @@ pub fn add_empty_batches( .flat_map(|batch| { // insert 0, or 1 empty batches before and after the current batch let empty_batch = RecordBatch::new_empty(schema.clone()); - std::iter::repeat(empty_batch.clone()) - .take(rng.gen_range(0..2)) + std::iter::repeat_n(empty_batch.clone(), rng.random_range(0..2)) .chain(std::iter::once(batch)) - .chain(std::iter::repeat(empty_batch).take(rng.gen_range(0..2))) + .chain(std::iter::repeat_n(empty_batch, rng.random_range(0..2))) }) .collect() } @@ -101,7 +100,7 @@ pub fn stagger_batch_with_seed(batch: RecordBatch, seed: u64) -> Vec 0 { - let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + let batch_size = rng.random_range(0..remainder.num_rows() + 1); batches.push(remainder.slice(0, batch_size)); remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); diff --git a/test-utils/src/string_gen.rs b/test-utils/src/string_gen.rs index b598241db1e92..75ed03898a279 100644 --- a/test-utils/src/string_gen.rs +++ b/test-utils/src/string_gen.rs @@ -19,7 +19,7 @@ use crate::array_gen::StringArrayGenerator; use crate::stagger_batch; use arrow::record_batch::RecordBatch; use rand::rngs::StdRng; -use rand::{thread_rng, Rng, SeedableRng}; +use rand::{rng, Rng, SeedableRng}; /// Randomly generate strings pub struct StringBatchGenerator(StringArrayGenerator); @@ -56,18 +56,18 @@ impl StringBatchGenerator { stagger_batch(batch) } - /// Return an set of `BatchGenerator`s that cover a range of interesting + /// Return a set of `BatchGenerator`s that cover a range of interesting /// cases pub fn interesting_cases() -> Vec { let mut cases = vec![]; - let mut rng = thread_rng(); + let mut rng = rng(); for null_pct in [0.0, 0.01, 0.1, 0.5] { for _ in 0..10 { // max length of generated strings - let max_len = rng.gen_range(1..50); - let num_strings = rng.gen_range(1..100); + let max_len = rng.random_range(1..50); + let num_strings = rng.random_range(1..100); let num_distinct_strings = if num_strings > 1 { - rng.gen_range(1..num_strings) + rng.random_range(1..num_strings) } else { num_strings }; @@ -76,7 +76,7 @@ impl StringBatchGenerator { num_strings, num_distinct_strings, null_pct, - rng: StdRng::from_seed(rng.gen()), + rng: StdRng::from_seed(rng.random()), })) } } diff --git a/testing b/testing index d2a1371230349..0d60ccae40d0e 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit d2a13712303498963395318a4eb42872e66aead7 +Subproject commit 0d60ccae40d0e8f2d22c15fafb01c5d4be8c63a6 diff --git a/typos.toml b/typos.toml new file mode 100644 index 0000000000000..09c5c55c452ab --- /dev/null +++ b/typos.toml @@ -0,0 +1,50 @@ +[default.extend-words] +# random words from unit tests +Pn = "Pn" +fo = "fo" +nd = "nd" +Nd = "Nd" +ba = "ba" +ECT = "ECT" +Ue = "Ue" +Iy = "Iy" +hte = "hte" +numer = "numer" +abd = "abd" +aroun = "aroun" +abov = "abov" +Ois = "Ois" +alo = "alo" + +# abbreviations, common words, etc. +typ = "typ" +datas = "datas" +YOUY = "YOUY" +lits = "lits" + +# exposed to public API +Serializeable = "Serializeable" + +# from test cases like TPC-* or ClickBench +carefull = "carefull" +precentage = "precentage" +flate = "flate" +hom = "hom" +alph = "alph" +wih = "wih" +Ded = "Ded" + +# From SLT README +nteger = "nteger" + +[files] +extend-exclude = [ + "*.slt", + "*.slt.part", + "*.svg", + "*.sql", + "dev/changelog/**", + "benchmarks/**", + "*.csv", + "docs/source/contributor-guide/governance.md" +]